from sandbox import sdk2
import sandbox.common.types.task as ctt
import sandbox.common.types.misc as ctm
from sandbox.common.errors import TaskFailure

from sandbox.projects import resource_types
from sandbox.projects.dj.services.entity.resources import EntityRecommenderBundle
from sandbox.projects.dj.services.entity.resources import EntityRecommenderShard
from sandbox.projects.dj.services.entity.resources import EntityRecommenderArcadiaModels
from sandbox.projects.dj.services.entity.EntityRecommenderPerfTest import EntityRecommenderPerfTest


import itertools

# Stand marks compared by this task. Order matters: unpacking as ``old, new`` relies on
# index 0 being the baseline and index 1 the candidate.
MARKS = ["old", "new"]
# Pseudo-experiment that is always appended to the user-provided experiment list.
GENERAL_REQUEST_KEY = "general"
# Stats entry used for the pass/fail decision (presumably the 0.99 latency quantile).
LATENCY_KEY = "shooting.latency_0.99"
# Context attribute holding shooter stats, and the (key, title, fmt) stat descriptors
# rendered in the footer; both are borrowed from the perf-test task so they stay in sync.
NEW_STATS_KEY, NEW_STATS_TYPES = (
    EntityRecommenderPerfTest.new_stats_key,
    EntityRecommenderPerfTest.new_stats_types,
)


class EntityRecommenderComparePerfTests(sdk2.Task):
    """ Task for comparing Entity Recommender performance

    Enqueues one EntityRecommenderPerfTest subtask per (experiment, stand) pair, where the
    stands are the "old" and "new" bundle/shard/models resource sets, waits for all of them
    and fails if, for any experiment, the "new" latency is at least
    ``critical_performance_loss`` percent higher than the "old" one.
    """

    class Context(sdk2.Context):
        # Set once the shooter subtasks are enqueued; lets on_execute resume after WaitTask.
        shooters_started = False
        # "UNKNOWN" until the comparison runs, then "OK" or "CRITICAL".
        performance_loss = "UNKNOWN"
        # Resolved resource ids, keyed by stand mark ("old" / "new").
        bundle_ids = dict()
        shard_ids = dict()
        models_ids = dict()
        # Request plan resource id per experiment name.
        request_plan_ids = dict()
        # Shooter subtask ids: {mark: {experiment: task_id}}.
        shooter_task_ids = {m: dict() for m in MARKS}

    class Parameters(sdk2.Task.Parameters):
        request_count = sdk2.parameters.Integer("Count of requests", default=100000)
        critical_performance_loss = sdk2.parameters.Integer("Critical performance loss in percent", default=5)
        bundle_resource_old = sdk2.parameters.Resource("Resource with old Entity Recommender bundle", resource_type=EntityRecommenderBundle)
        shard_resource_old = sdk2.parameters.Resource("Resource with old Entity Recommender shard", resource_type=EntityRecommenderShard)
        models_resource_old = sdk2.parameters.Resource("Resource with old Entity Recommender models", resource_type=EntityRecommenderArcadiaModels)
        bundle_resource_new = sdk2.parameters.Resource("Resource with new Entity Recommender bundle", resource_type=EntityRecommenderBundle)
        shard_resource_new = sdk2.parameters.Resource("Resource with new Entity Recommender shard", resource_type=EntityRecommenderShard)
        models_resource_new = sdk2.parameters.Resource("Resource with new Entity Recommender models", resource_type=EntityRecommenderArcadiaModels)
        experiments = sdk2.parameters.List("Experiments", default=[])
        required_attributes = sdk2.parameters.Dict("Required attribute for basesearch plan resource", default=dict())

    def on_execute(self):
        # First run: resolve inputs, enqueue shooters and suspend via sdk2.WaitTask.
        # When the task is executed again after the wait, the first two steps are no-ops
        # (guarded by Context.shooters_started) and only the comparison runs.
        self._init_context()
        self._start_shooters()
        self._check_shooters()

    def _init_context(self):
        """Resolve the resources and request plans for the shooters and store their ids in Context.

        Raises:
            TaskFailure: if a resource or a request plan cannot be found.
        """
        if self.Context.shooters_started:
            # Everything resolved here is consumed only by _start_shooters, which will not
            # enqueue anything again; re-resolving could only fail a finished comparison.
            return

        experiments = self._get_experiments()

        # Prepare resources: the explicitly specified one, or one released as stable.
        self.Context.bundle_ids["old"] = self._resolve_resource_id(EntityRecommenderBundle, self.Parameters.bundle_resource_old)
        self.Context.bundle_ids["new"] = self._resolve_resource_id(EntityRecommenderBundle, self.Parameters.bundle_resource_new)
        self.Context.shard_ids["old"] = self._resolve_resource_id(EntityRecommenderShard, self.Parameters.shard_resource_old)
        self.Context.shard_ids["new"] = self._resolve_resource_id(EntityRecommenderShard, self.Parameters.shard_resource_new)
        self.Context.models_ids["old"] = self._resolve_resource_id(EntityRecommenderArcadiaModels, self.Parameters.models_resource_old)
        self.Context.models_ids["new"] = self._resolve_resource_id(EntityRecommenderArcadiaModels, self.Parameters.models_resource_new)

        # Find plans: one per experiment, selected by the user's required attributes plus
        # the experiment name.
        for exp in experiments:
            attrs = dict(self.Parameters.required_attributes or {})
            attrs['entity_recommender_experiment'] = exp
            plan = sdk2.Resource.find(type=resource_types.BASESEARCH_PLAN,
                                      attrs=attrs).limit(1).first()
            if not plan:
                raise TaskFailure("Failed to find ENTITY RECOMMENDER plan with requests for experiment {}".format(exp))
            self.Context.request_plan_ids[exp] = plan.id

    def _resolve_resource_id(self, resource_type, specified):
        """Return the id of ``specified``, or of a stable ``resource_type`` resource if it is None.

        Raises:
            TaskFailure: if nothing was specified and no stable resource was found.
        """
        resource = self._get_prod_or_specified_resource(resource_type, specified)
        if resource is None:
            raise TaskFailure("Failed to find released stable {} resource".format(
                getattr(resource_type, "__name__", resource_type)))
        return resource.id

    def _get_experiments(self):
        """Return the configured experiments followed by the always-present general plan key."""
        return list(self.Parameters.experiments or []) + [GENERAL_REQUEST_KEY]

    def _start_shooters(self):
        """Enqueue one perf-test subtask per (experiment, mark) pair and wait for all of them.

        Does nothing if the shooters were already started.

        Raises:
            sdk2.WaitTask: always, after enqueueing, to suspend until every subtask
                finishes or breaks.
        """
        if self.Context.shooters_started:
            return

        subtasks = []
        for exp, mark in itertools.product(self._get_experiments(), MARKS):
            task_parameters = {
                "request_count": self.Parameters.request_count,
                "dolbilka_plan_resource": self.Context.request_plan_ids[exp],
                "bundle_resource": self.Context.bundle_ids[mark],
                "shard_resource": self.Context.shard_ids[mark],
                "models_resource": self.Context.models_ids[mark]
            }
            perf_test_task = EntityRecommenderPerfTest(
                self,
                description="Shooter task for compare performance of entity recommender (stand={}, experiment={})".format(mark, exp),
                owner=self.Parameters.owner,
                priority=self.Parameters.priority,
                notifications=self.Parameters.notifications,
                **task_parameters
            )
            perf_test_task.enqueue()
            subtasks.append(perf_test_task)

            perf_test_tasks = self.Context.shooter_task_ids.setdefault(mark, dict())
            perf_test_tasks[exp] = perf_test_task.id
        self.Context.shooters_started = True
        raise sdk2.WaitTask(subtasks, ctt.Status.Group.FINISH | ctt.Status.Group.BREAK, wait_all=True)

    @staticmethod
    def _get_prod_or_specified_resource(resource_type, default=None):
        """Return ``default`` if given, else a ``resource_type`` resource with released=stable (or None)."""
        if default is None:
            return sdk2.Resource.find(resource_type, attrs={"released": "stable"}).limit(1).first()
        return default

    def _check_shooters(self):
        """Compare old vs new latency per experiment and store the verdict in Context.performance_loss.

        Raises:
            TaskFailure: if a shooter task, its stats or its latency value is missing, or if
                the new latency is at least ``critical_performance_loss`` percent above the old one.
        """
        shooter_tasks = self._get_shooter_tasks()
        for exp in self._get_experiments():
            if any(shooter_tasks[mark][exp] is None for mark in MARKS):
                raise TaskFailure("Can not find shooter tasks for {} experiment".format(exp))
            stats = {mark: getattr(shooter_tasks[mark][exp].Context, NEW_STATS_KEY) for mark in MARKS}
            if any(st is ctm.NotExists for st in stats.values()):
                raise TaskFailure("Can not find stats for {} experiment".format(exp))
            old, new = [stats[mark].get(LATENCY_KEY) for mark in MARKS]
            if old is None or new is None:
                raise TaskFailure("Can not find latency info for {} experiment".format(exp))
            # Float arithmetic so integer latencies are not truncated by integer division.
            if new >= old * (100.0 + self.Parameters.critical_performance_loss) / 100.0:
                self.Context.performance_loss = "CRITICAL"
                raise TaskFailure("Critical performance loss!")
        self.Context.performance_loss = "OK"

    @property
    def footer(self):
        """Footer data: a per-experiment table of shooter statuses and old/new stats, plus the verdict."""
        content = []
        shooter_tasks = self._get_shooter_tasks()
        for exp in self._get_experiments():
            stats = dict()
            task_statuses = []
            for mark in MARKS:
                # A shooter that does not exist yet is None, whose 'status' falls back to None.
                task_status = getattr(shooter_tasks[mark][exp], 'status', None)
                if not task_status:
                    continue
                task_statuses.append("{} <span class='status status_{}'>{}</span>".format(mark, task_status.lower(), task_status))
                if NEW_STATS_KEY in shooter_tasks[mark][exp].Context:
                    stats[mark] = getattr(shooter_tasks[mark][exp].Context, NEW_STATS_KEY)

            if not task_statuses:
                continue

            row = {
                "Experiment": exp,
                "Status": "<br/>".join(task_statuses),
            }

            row.update({title: self._make_ratio_str(stats, key, fmt) for key, title, fmt in NEW_STATS_TYPES})
            content.append(row)

        return {
            "<h3>Performance stats</h3>": content or "Waiting for shooting tasks...",
            "<h3>Performance loss</h3>": self.Context.performance_loss,
        }

    def _get_shooter_tasks(self):
        """Return {mark: {experiment: shooter subtask or None}} for every expected shooter.

        A shooter is None when its id was not recorded yet or is absent from ``self.find()``;
        the lookup never raises so that the footer keeps rendering.
        """
        sub_tasks = {task.id: task for task in self.find()}
        shooter_tasks = {mark: dict() for mark in MARKS}
        for exp, mark in itertools.product(self._get_experiments(), MARKS):
            shooter_task_id = self.Context.shooter_task_ids.get(mark, {}).get(exp)
            shooter_tasks[mark][exp] = sub_tasks.get(shooter_task_id) if shooter_task_id is not None else None
        return shooter_tasks

    def _make_ratio_str(self, stats, key, fmt):
        """Render old/new values of ``key`` as "<mark> <value>" lines plus the relative change.

        Missing values are shown as "???". The "(+x.x%)" suffix is omitted when either value
        is missing or the old (baseline) value is zero.
        """
        values = [stats.get(mark, dict()).get(key) for mark in MARKS]
        old, new = values
        formatted_values = [fmt.format(val) if val is not None else "???" for val in values]
        diff = ""
        # The baseline is the divisor, so it is the value that must be non-zero.
        if old is not None and new is not None and old != 0:
            diff_val = float(new - old) / old * 100
            diff = " ({0:+.1f}%)".format(diff_val)
        return " <br/>".join([' '.join(data) for data in zip(MARKS, formatted_values)]) + diff
