import itertools
import logging
import math
import os
import sys

from sandbox import sdk2
import sandbox.common.types.resource as ctr
from sandbox.common import errors
from sandbox.sdk2.helpers import subprocess as sp


class EShowGroupCorrectorBinary(sdk2.Resource):
    """Binary to build and launch eshow group corrector graphs in Nirvana"""
    # Staff logins permitted to release this resource. When no explicit binary
    # is passed, AdsEShowBuildGroupCorrectorGraph.on_execute looks up the
    # latest resource of this type released to "stable".
    releasers = ["alex-serov", "berezniker"]
    # Resources of this type can be released.
    releasable = True


class AdsEShowBuildGroupCorrectorGraph(sdk2.Task):
    """Task to run compiled binary for EShow group corrector graph builder"""

    class Parameters(sdk2.Task.Parameters):
        description = "Build EShow Group Corrector training graph"
        # NOTE(review): Sandbox kill_timeout is expressed in seconds, so this is
        # 1440 s (24 minutes). The "1 * 24 * 60" spelling reads like "one day in
        # minutes" -- confirm which limit is actually intended.
        kill_timeout = 1 * 24 * 60

        # When left empty, on_execute uses the latest EShowGroupCorrectorBinary
        # released to "stable".
        binary_resource = sdk2.parameters.Resource(
            "Group Corrector Graph Builder Binary (if empty - will find last released)",
            resource_type=EShowGroupCorrectorBinary,
            state=(ctr.State.READY, ),
        )
        # Forwarded as --timestamp only when set and non-zero.
        timestamp = sdk2.parameters.Integer(
            "Timestamp to build graph (if none or 0 - will take current time)",
            required=False,
        )

        with sdk2.parameters.Group("Model Parameters") as model_params:
            task_id = sdk2.parameters.String(
                "Model TaskID",
                default="test",
                required=True
            )
            # {age_in_hours: weight}; keys are cast to int and values to float,
            # then passed sorted by key as --breakpoints / --weights.
            samples_weighting = sdk2.parameters.Dict(
                "Weight multipliers for samples based on age in hours",
                required=True, default={24: 1.0}
            )
            min_eshow = sdk2.parameters.Float(
                "Minimal possible value of EShow (for clamping)",
                required=True, default=0.1
            )
            max_eshow = sdk2.parameters.Float(
                "Maximal possible value of EShow (for clamping)",
                required=True, default=10.0
            )
            # Arcadia trunk path; the file is exported into the working directory
            # and passed to the binary by its base name.
            override_orders_config = sdk2.parameters.String(
                "Override Parameters Config",
                default="ads/quality/reach_product/eshow_group_corrector/override_orders.yaml"
            )
            # {exp_id: dates}; each item is passed as "exp_id:dates".
            ignore_expid_dates = sdk2.parameters.Dict(
                "Ignore dates in training pool for these ExpIDs",
            )
            # When False (the default), --no-eshow-expid is passed.
            use_eshow_expid = sdk2.parameters.Bool(
                "Use EShowExpID namespace",
                default=False
            )
            use_inventory_type = sdk2.parameters.Bool(
                "Use InventoryType namespace",
                default=False
            )
            upload_to_ml_storage = sdk2.parameters.Bool(
                "Upload corrector dump to ML Storage",
                default=False
            )

            # Correction-power options are only forwarded when this period is
            # positive AND correction_size_min >= 1.
            correction_power_duplication_period = sdk2.parameters.Float(
                "Period for duplicating correction power",
                default=0.0
            )
            min_shows_per_day = sdk2.parameters.Integer(
                "Minimal number of shows per order per day to be considered",
                default=1
            )
            # No default: may be None. Passed as log10(value) via
            # --correction-size-log-min.
            correction_size_min = sdk2.parameters.Integer(
                "Minimal order size (shows / day) to be corrected at all"
            )

        with sdk2.parameters.Group("Nirvana Parameters") as nirvana_params:
            nirvana_quota = sdk2.parameters.String(
                "Nirvana Quota",
                required=False, default="ads"
            )
            # Either nirvana_ns_id (takes precedence) or nirvana_project must be
            # set, otherwise on_execute raises RuntimeError.
            nirvana_ns_id = sdk2.parameters.Integer(
                "Nirvana nsId",
                default=0
            )
            nirvana_project = sdk2.parameters.String(
                "Nirvana Project",
                default=""
            )
            nirvana_secret = sdk2.parameters.String(
                "Nirvana Secret",
                required=True,
            )
            dump_owners = sdk2.parameters.List(
                "Dump owners",
                required=True, value_type=sdk2.parameters.String,
            )

            # Email options are forwarded only when both are set.
            email_from = sdk2.parameters.Staff(
                "Login to send email from"
            )
            email_to = sdk2.parameters.List(
                "Send email to",
                value_type=sdk2.parameters.Staff
            )

        with sdk2.parameters.Group("YT Parameters") as yt_params:
            yt_proxy = sdk2.parameters.String(
                "YT cluster",
                required=True, default="hahn"
            )
            yt_pool = sdk2.parameters.String(
                "YT Pool",
                required=False, default="ml-engine"
            )
            # Forwarded as --map-reduce-weight only when it differs from 1.0.
            yt_operations_weight = sdk2.parameters.Float(
                "YT operations weight",
                required=False, default=1.0
            )
            mr_account = sdk2.parameters.String(
                "MR account",
                required=False, default="bs"
            )

        # The value under the secret's default key is exported to the binary as
        # VH_TOKEN.
        secret_name = sdk2.parameters.YavSecret(
            "Secret in YAV with that can run graphs in Nirvana. Use #key for key",
            required=True,
        )

    def on_execute(self):
        """Assemble the graph-builder command line, run it and store the workflow id.

        The binary writes the created Nirvana workflow id into a temporary file,
        which is read back into ``self.Context.wf_id``.

        Raises:
            RuntimeError: neither nirvana_ns_id nor nirvana_project is set.
            errors.TaskFailure: no binary given and none released to "stable".
        """
        # Validate the Nirvana target before any expensive work (downloading the
        # binary, exporting the config from Arcadia), so a misconfigured run
        # fails immediately.
        nirvana_location_args = self._nirvana_location_args()

        cmd = [
            self._get_binary_path(), "train-corrector",
            "--task-id", self.Parameters.task_id,
            "--corrector-min-eshow", str(self.Parameters.min_eshow),
            "--corrector-max-eshow", str(self.Parameters.max_eshow),
            "--yt-proxy", self.Parameters.yt_proxy,
            "--nirvana-secret", self.Parameters.nirvana_secret,
            "--min-shows-per-day", str(self.Parameters.min_shows_per_day),
            "--dump-owners"
        ]
        cmd.extend(self.Parameters.dump_owners)

        # correction_size_min has no default and may be None; treat unset values
        # as "disabled" instead of failing on a None comparison.
        correction_size_min = self.Parameters.correction_size_min or 0
        duplication_period = self.Parameters.correction_power_duplication_period or 0.0
        if correction_size_min >= 1 and duplication_period > 1e-6:
            cmd.extend((
                "--correction-size-log-min", str(math.log10(float(correction_size_min))),
                "--correction-power-duplication-period", str(duplication_period)
            ))

        if self.Parameters.override_orders_config:
            cmd.extend(("--override-orders", self._export_override_config()))

        if self.Parameters.ignore_expid_dates:
            cmd.append("--exps-ignore-dates")
            cmd.extend(
                "{}:{}".format(exp_id, dates)
                for (exp_id, dates) in self.Parameters.ignore_expid_dates.items()
            )

        # NOTE: the flag is inverted -- it is passed when the option is OFF.
        if not self.Parameters.use_eshow_expid:
            cmd.append("--no-eshow-expid")

        if self.Parameters.use_inventory_type:
            cmd.append("--inventory-type")

        if self.Parameters.upload_to_ml_storage:
            cmd.append("--upload-to-ml-storage")

        if self.Parameters.mr_account:
            cmd.extend(("--mr-account", self.Parameters.mr_account))

        cmd.extend(self._samples_weighting_args())

        if self.Parameters.timestamp:
            cmd.extend(["--timestamp", str(self.Parameters.timestamp)])

        if self.Parameters.nirvana_quota:
            cmd.extend(["--nirvana-quota", self.Parameters.nirvana_quota])

        if self.Parameters.yt_pool:
            cmd.extend(["--yt-pool", self.Parameters.yt_pool])

        # Only pass a non-default weight; a cleared (None) value means "default".
        operations_weight = self.Parameters.yt_operations_weight
        if operations_weight is not None and abs(operations_weight - 1.0) > 1e-6:
            cmd.extend(["--map-reduce-weight", str(operations_weight)])

        cmd.extend(nirvana_location_args)

        if self.Parameters.email_from and self.Parameters.email_to:
            cmd.extend(["--email-from", self.Parameters.email_from, "--email-to"])
            cmd.extend(self.Parameters.email_to)

        # The binary writes the created workflow id here (relative to the task cwd).
        tmp_file = "_wfid_tmp_file"
        cmd.extend(["--output-to", tmp_file])

        # NOTE(review): if sp mirrors stdlib subprocess, env= REPLACES the whole
        # inherited environment, so the binary sees only VH_TOKEN (no PATH, HOME,
        # ...). Kept as-is; confirm this isolation is intended.
        env_vars = {
            "VH_TOKEN": self.Parameters.secret_name.data()[self.Parameters.secret_name.default_key]
        }

        with sdk2.helpers.ProcessLog(self, logger="graph_builder") as pl:
            sp.check_call(cmd, stdout=pl.stdout, stderr=pl.stderr, env=env_vars)

        with open(tmp_file, "r") as fh:
            self.Context.wf_id = fh.read().strip("\n")

        self.Context.save()

    def _nirvana_location_args(self):
        """Return the CLI args selecting the Nirvana namespace or project.

        nirvana_ns_id takes precedence over nirvana_project.

        Raises:
            RuntimeError: neither is set.
        """
        if self.Parameters.nirvana_ns_id:
            return ["--nirvana-ns-id", str(self.Parameters.nirvana_ns_id)]
        if self.Parameters.nirvana_project:
            return ["--nirvana-project", str(self.Parameters.nirvana_project)]
        raise RuntimeError("No nirvana NsID or project code specified")

    def _get_binary_path(self):
        """Return the local path of the graph-builder binary.

        Uses the explicitly selected resource if any, otherwise the first
        EShowGroupCorrectorBinary found with attribute released=stable.

        Raises:
            errors.TaskFailure: no binary given and none released to "stable".
        """
        if self.Parameters.binary_resource:
            return str(sdk2.ResourceData(self.Parameters.binary_resource).path)

        sys.stdout.write("No specific binary given - will find latest released\n")
        latest_released = sdk2.Resource.find(
            type=EShowGroupCorrectorBinary,
            attrs={"released": "stable"}
        ).first()

        if not latest_released:
            raise errors.TaskFailure("Failed to find at least one released binary")

        sys.stdout.write("Using resource {}\n".format(latest_released.id))
        return str(sdk2.ResourceData(latest_released).path)

    def _export_override_config(self):
        """Export the override-orders config from Arcadia trunk into the cwd.

        The file contents are logged for debugging.

        Returns:
            The local file name (base name of the Arcadia path).
        """
        config_path = self.Parameters.override_orders_config
        sdk2.svn.Arcadia.export(sdk2.svn.Arcadia.trunk_url(config_path), os.getcwd())

        local_name = config_path.split("/")[-1]
        with open(local_name, "r") as fh:
            for line in fh:
                logging.info(line)

        return local_name

    def _samples_weighting_args(self):
        """Build ``--breakpoints b1 b2 ... --weights w1 w2 ...`` from samples_weighting.

        Keys are cast to int and values to float. Pairs are sorted by breakpoint,
        so the two lists stay aligned position by position.
        """
        pairs = sorted(
            (int(k), float(v)) for (k, v) in self.Parameters.samples_weighting.items()
        )
        return list(itertools.chain(
            ["--breakpoints"], [str(bp) for bp, _ in pairs],
            ["--weights"], [str(wt) for _, wt in pairs],
        ))

    @sdk2.footer(title="Created nirvana workflows")
    def report(self):
        """Render a link to the created Nirvana workflow, if any."""
        if self.Context.wf_id:
            return "<a href='{}'>{}</a>".format(
                "https://nirvana.yandex-team.ru/flow/{}".format(self.Context.wf_id),
                "Workflow"
            )
        else:
            return "No workflows created yet"
