import operator

from sandbox.common.errors import TaskFailure
from sandbox.projects.collections.recommender_base.shooting_task2 import ShootingTask2
from sandbox.projects.common import dolbilka2
from sandbox.projects.common.search import performance as search_performance
from sandbox.projects.dj.services.entity import helpers
from sandbox.projects.dj.services.entity.helpers import TemporaryDirectory
from sandbox.projects.dj.services.entity.resources import EntityRecommenderShard
from sandbox.projects.dj.services.entity.resources import EntityRecommenderShard2
from sandbox.projects.dj.services.entity.resources import EntityRecommenderBundle
from sandbox.projects.dj.services.entity.resources import EntityRecommenderArcadiaModels
from sandbox.projects.dj.services.entity.resources import EntityRecommenderShootingPlan
from sandbox.projects.resource_types import TASK_CUSTOM_LOGS
from sandbox.projects.tank import executor2 as tank_executor
from sandbox.sdk2 import Task
from sandbox.sdk2 import ResourceData
from sandbox.sdk2.helpers import ProcessLog
from sandbox.sdk2.parameters import Integer
from sandbox.sdk2.parameters import Bool
from sandbox.sdk2.parameters import Group
from sandbox.sdk2.parameters import Dict
from sandbox.sdk2.parameters import String
from sandbox.sdk2.parameters import Resource
from sandbox.sdk2.parameters import List

import os

# Tags of the two services being compared.
NEW_TAG = 'new'
OLD_TAG = 'old'
# Appended to a service tag to name its warming-up shooting session.
WARMING_UP_SUFFIX = '-warming-up'
_NEW_WARMING_UP_TAG = NEW_TAG + WARMING_UP_SUFFIX
_OLD_WARMING_UP_TAG = OLD_TAG + WARMING_UP_SUFFIX

# Names of all shooting sessions whose stats are collected.
SHOOTING_TAGS = [NEW_TAG, OLD_TAG, _NEW_WARMING_UP_TAG, _OLD_WARMING_UP_TAG]

# Deltas to compute, as (delta name, first session, second session);
# each delta is the percent change from the first session to the second.
REQUIRED_DELTAS = [
    ("diff", _OLD_WARMING_UP_TAG, _NEW_WARMING_UP_TAG),
    ("diff_retest", OLD_TAG, NEW_TAG),
    ("diff_same", _NEW_WARMING_UP_TAG, NEW_TAG),
]


class EntityRecommenderComparePerfTestParallelParameters(Task.Parameters):
    """Input parameters of the new-vs-old entity recommender performance comparison task."""

    # Shooting plan selection, independently for each side: an explicit plan
    # resource is declared under the flag's True branch, a dict of plan
    # attributes under its False branch.
    new_specify_shooting_plan = Bool('Specify new plan', default=False)
    old_specify_shooting_plan = Bool('Specify old plan', default=False)
    with new_specify_shooting_plan.value[True]:
        new_shooting_plan = Resource("Resource with dolbilka new plan", resource_type=EntityRecommenderShootingPlan, required=True)
    with new_specify_shooting_plan.value[False]:
        new_shooting_plan_attributes = Dict("Required attribute for new plan resource", default=dict())
    with old_specify_shooting_plan.value[True]:
        old_shooting_plan = Resource("Resource with dolbilka old plan", resource_type=EntityRecommenderShootingPlan, required=True)
    with old_specify_shooting_plan.value[False]:
        old_shooting_plan_attributes = Dict("Required attribute for old plan resource", default=dict())

    # Launch options shared by both services.
    http_port = Integer("Http Port", default=15368, required=True)
    apphost_port = Integer("AppHost Port", default=15369, required=True)
    # Selects which of the two ports above is shot at.
    use_apphost = Bool("Use apphost port instead of http", default=False)
    patch = String('Text patch for config (has priority over bundle patches)')
    bundle_patches = List('Patches for config from bundle', default=[])
    verbose_level = Integer("Verbose level", default=0)
    write_access_log = Bool('Write access log', default=False)
    enable_realtime_update = Bool("Enable realtime update", default=False)

    # Resources of the two services under comparison; both groups have the same shape.
    # NOTE(review): the plain "shard" resources are forwarded to
    # _prepare_resources of the task but are not read there; only the
    # shard 2.0 resource is resolved.
    with Group('New service') as new_service:
        new_bundle_resource = Resource('Entity Recommender bundle', resource_type=EntityRecommenderBundle)
        new_shard_resource = Resource('Entity Recommender shard', resource_type=EntityRecommenderShard)
        new_shard2_resource = Resource('Entity Recommender shard 2.0', resource_type=EntityRecommenderShard2)
        new_models_resource = Resource('Entity Recommender models', resource_type=EntityRecommenderArcadiaModels)
    with Group('Old service') as old_service:
        old_bundle_resource = Resource('Entity Recommender bundle', resource_type=EntityRecommenderBundle)
        old_shard_resource = Resource('Entity Recommender shard', resource_type=EntityRecommenderShard)
        old_shard2_resource = Resource('Entity Recommender shard 2.0', resource_type=EntityRecommenderShard2)
        old_models_resource = Resource('Entity Recommender models', resource_type=EntityRecommenderArcadiaModels)

    # Pass/fail criteria applied after shooting.
    with Group('Checking parameters') as checking_parameters:
        check_perf_loss = Bool("Check performance loss", default=False)
        with check_perf_loss.value[True]:
            # For every metric a threshold of 0 turns the percent check off and
            # a "significant diff" of 0 turns the absolute-difference guard off.
            rps_threshold = Integer("RPS threshold (perc)", default=0)
            rps_min_abs_diff = Integer("RPS significant diff", default=0)
            latency_p50_threshold = Integer("Latency p50 threshold (perc)", default=0)
            latency_p50_min_abs_diff = Integer("Latency p50 significant diff", default=0)
            latency_p99_threshold = Integer("Latency p99 threshold (perc)", default=10)
            latency_p99_min_abs_diff = Integer("Latency p99 significant diff", default=5000)
        check_errors = Bool("Check response errors", default=True)
        with check_errors.value[True]:
            # A value of 0 disables the corresponding error check.
            errors_threshold_abs = Integer("Critical errors threshold (abs value)", default=0)
            errors_threshold_perc = Integer("Critical errors threshold (perc)", default=25)

    # Parameter sets reused from the dolbilka executor and the tank plugins.
    with Group('Shooter parameters') as shooter_parameters:
        dolbilka_param = dolbilka2.DolbilkaExecutor2.Parameters
        lunapark_param = tank_executor.LunaparkPlugin.Parameters
        offline_param = tank_executor.OfflinePlugin.Parameters


class EntityRecommenderComparePerfTestParallelContext(Task.Context):
    """Task context: ids of the resources resolved for each service tag."""

    # Each mapping is keyed by service tag and holds a resource id.
    bundle_ids = {}
    shard_ids = {}
    shard2_ids = {}
    models_ids = {}
    dssm_ids = {}
    shooting_plan_ids = {}


class EntityRecommenderComparePerfTestParallelRequirements(Task.Requirements):
    """Host requirements of the performance comparison task."""

    # Same host selection as the generic search shooting task.
    client_tags = search_performance.NewShootingTask.client_tags
    cores = 4
    # Units are defined by sdk2 (presumably MiB, i.e. 10 GiB each — confirm).
    ram = 10 * 1024
    disk_space = 10 * 1024


class EntityRecommenderComparePerfTestParallel(Task, ShootingTask2):
    """
        Parallel version of entity recommender performance task.
        1) Run the new entity recommender and shoot it twice: a warming-up session, then the measured one
        2) Do the same for the old entity recommender
        3) Compare the sessions pairwise (diff, diff_retest and diff_same in footer)
        4) Fail on performance loss and/or response errors when the corresponding checks are enabled
    """

    # Stats columns: the base ShootingTask2 set plus two extra ones, ordered by title.
    new_stats_types = sorted(
        ShootingTask2.new_stats_types + (
            ("shooting.latency_0.95", "Latency P95", "{:0.2f}"),
            ("dumper.total_requests", "Total requests", "{:0.2f}"),
        ),
        key=operator.itemgetter(1)
    )

    class Parameters(EntityRecommenderComparePerfTestParallelParameters):
        pass

    class Requirements(EntityRecommenderComparePerfTestParallelRequirements):
        pass

    class Context(EntityRecommenderComparePerfTestParallelContext):
        pass

    def on_execute(self):
        """Shoot the new and the old service, compute the deltas and run the checks."""
        parameters = self.Parameters
        self._init_virtualenv()
        self._init_recommender_and_shoot(
            explicit_bundle=parameters.new_bundle_resource,
            explicit_shard=parameters.new_shard_resource,
            explicit_shard2=parameters.new_shard2_resource,
            explicit_models=parameters.new_models_resource,
            explicit_shooting_plan=parameters.new_shooting_plan,
            shooting_plan_attributes=parameters.new_shooting_plan_attributes,
            tag=NEW_TAG)
        self._init_recommender_and_shoot(
            explicit_bundle=parameters.old_bundle_resource,
            explicit_shard=parameters.old_shard_resource,
            explicit_shard2=parameters.old_shard2_resource,
            explicit_models=parameters.old_models_resource,
            explicit_shooting_plan=parameters.old_shooting_plan,
            shooting_plan_attributes=parameters.old_shooting_plan_attributes,
            tag=OLD_TAG)
        self._calculate_diff()
        self._check_performance_loss()
        self._check_errors()

    def _init_recommender_and_shoot(
            self,
            explicit_bundle,
            explicit_shard,
            explicit_shard2,
            explicit_models,
            explicit_shooting_plan,
            shooting_plan_attributes,
            tag):
        """
        Prepare resources for one service in a temporary directory and shoot it.

        The temporary directory is created under the current working directory
        and the service output goes to a process log named after ``tag``.
        """
        with TemporaryDirectory(dir=os.getcwd()) as working_dir, \
                ProcessLog(self, logger="entity_recommender_" + tag) as process_log:
            self._prepare_resources(
                working_dir, explicit_shard, explicit_shard2, explicit_bundle,
                explicit_models, explicit_shooting_plan, shooting_plan_attributes, tag)
            self._shoot_entity_recommender(working_dir, process_log, tag)

    def _prepare_resources(
            self, working_dir, explicit_shard, explicit_shard2,
            explicit_bundle, explicit_models, explicit_shooting_plan,
            shooting_plan_attributes, tag):
        """
        Resolve the resources of one service and lay them out in ``working_dir``.

        The ids of the resolved resources are stored in the task context under
        ``tag``. ``explicit_shard`` is accepted but not used: only the
        shard 2.0 resource is resolved here.
        """
        context = self.Context
        shard2 = helpers.find_shard2_resource(explicit_resource=explicit_shard2)
        bundle = helpers.find_bundle_resource(explicit_resource=explicit_bundle)
        models = helpers.find_models_resource(explicit_resource=explicit_models)
        dssm = helpers.find_dssm_resource()
        shooting_plan = helpers.find_shooting_plan_resource(
            explicit_resource=explicit_shooting_plan,
            plan_attributes=shooting_plan_attributes)
        bundle_data = ResourceData(bundle)
        shard2_data = ResourceData(shard2)
        models_data = ResourceData(models)
        dssm_data = ResourceData(dssm)
        context.shard2_ids[tag] = shard2.id
        context.bundle_ids[tag] = bundle.id
        context.models_ids[tag] = models.id
        context.dssm_ids[tag] = dssm.id
        context.shooting_plan_ids[tag] = shooting_plan.id
        helpers.prepare_entity_recommender_resources(
            bundle_data, shard2_data, models_data, dssm_data, working_dir)

    def _shoot_entity_recommender(
            self,
            working_dir,
            process_log,
            tag):
        """
        Launch the service from ``working_dir`` and run two shooting sessions.

        The first session is stored under ``tag`` + WARMING_UP_SUFFIX, the
        second one under ``tag``. When access logging is enabled the log is
        published as a TASK_CUSTOM_LOGS resource once the service has stopped.
        """
        parameters = self.Parameters
        context = self.Context
        access_log_data = None
        access_log_path = None
        if parameters.write_access_log:
            access_log_data = ResourceData(TASK_CUSTOM_LOGS(
                self,
                "{tag} access log".format(tag=tag),
                "access_log_{tag}.txt".format(tag=tag)))
            access_log_path = helpers.resource_data_path(access_log_data)
        realtime_update_config_path = None
        if parameters.enable_realtime_update:
            realtime_update_config_path = os.path.join(working_dir, "realtime_update_config.pbtxt")
        with helpers.launch_entity_recommender(
                binary_path=os.path.join(working_dir, "entity_recommender_service"),
                http_port=parameters.http_port,
                apphost_port=parameters.apphost_port,
                config_path=os.path.join(working_dir, "config.pbtxt"),
                patch_paths=helpers.prepare_patches(parameters.bundle_patches, parameters.patch, working_dir),
                working_dir_path=working_dir,
                realtime_update_config_path=realtime_update_config_path,
                process_log=process_log,
                access_log_path=access_log_path,
                verbose=parameters.verbose_level):
            shooting_plan = helpers.find_shooting_plan_resource(explicit_id=context.shooting_plan_ids[tag])
            # The target port is the same for both sessions.
            port = parameters.apphost_port if parameters.use_apphost else parameters.http_port
            for shooting_name in (tag + WARMING_UP_SUFFIX, tag):
                # Point the stats key at this session before shooting.
                self._switch_new_stats_key(self._make_stats_key(shooting_name))
                self._dolbilo_shoot(port, shooting_plan, shooting_name)
        if access_log_data:
            access_log_data.ready()

    def _calculate_diff(self):
        """
        Compute every delta listed in REQUIRED_DELTAS and store it in the context.

        Each delta maps a stats key to the percent change between the two
        sessions. Only keys that are present in the NEW_TAG stats and listed
        in ``new_stats_types`` are compared.
        """
        context = self.Context
        stats = self._get_stats()
        for delta_key, first_stat, second_stat in REQUIRED_DELTAS:
            first = stats[first_stat]
            second = stats[second_stat]
            row = {
                key: self.delta_percent(first[key], second[key])
                for key in stats[NEW_TAG]
                if self._is_stats_key(key)
            }
            setattr(context, self._make_stats_key(delta_key), row)
            # Also expose the row in the local stats mapping.
            stats[delta_key] = row

    def _check_performance_loss(self):
        """
        Fail the task when the new service is critically slower than the old one.

        For each checked metric the verdict combines the signed percent change
        taken from "diff_retest" with the absolute difference (new - old).
        RPS is critical when it drops, latency is critical when it grows.
        A threshold of 0 disables the metric; a minimal absolute difference of
        0 disables only the absolute-difference guard. Per-metric verdicts are
        stored in the context.

        Raises:
            TaskFailure: when the required stats are missing or a critical
                loss is detected.
        """
        context = self.Context
        parameters = self.Parameters
        if not parameters.check_perf_loss:
            return

        # Stats live in the context under _make_stats_key(tag), so presence has
        # to be tested with the same key, and before the stats are read.
        required_tags = ("diff_retest", NEW_TAG, OLD_TAG)
        if any(not hasattr(context, self._make_stats_key(tag)) for tag in required_tags):
            raise TaskFailure("Some required fields for performance loss checking are missed")

        stats = self._get_stats(True)

        # (stats key, percent threshold, minimal absolute diff, context prefix, sign);
        # sign is -1 when a decrease is a loss and 1 when an increase is a loss.
        checks = (
            (
                "shooting.rps_0.5",
                parameters.rps_threshold,
                parameters.rps_min_abs_diff,
                "rps_",
                -1,
            ),
            (
                "shooting.latency_0.5",
                parameters.latency_p50_threshold,
                parameters.latency_p50_min_abs_diff,
                "p50_",
                1,
            ),
            (
                "shooting.latency_0.99",
                parameters.latency_p99_threshold,
                parameters.latency_p99_min_abs_diff,
                "p99_",
                1,
            ),
        )

        crit = False
        for key, threshold, min_abs_diff, prefix, sign in checks:
            comp = operator.ge if sign == 1 else operator.le
            new = stats[NEW_TAG].get(key, 0.0)
            old = stats[OLD_TAG].get(key, 0.0)
            # Keep the sign of the percent change: it tells a loss from a gain.
            # With abs() the RPS check could never fire and a latency
            # improvement was reported as a loss.
            diff = stats["diff_retest"].get(key, 0.0)
            check_abs_diff_crit = min_abs_diff != 0 and comp(new - old, abs(min_abs_diff) * sign)
            check_threshold_crit = threshold != 0 and comp(diff, abs(threshold) * sign)
            check_crit = (check_abs_diff_crit or min_abs_diff == 0) and check_threshold_crit
            crit = crit or check_crit
            setattr(context, prefix + "threshold_crit", check_threshold_crit)
            setattr(context, prefix + "check_abs_diff_crit", check_abs_diff_crit)
            setattr(context, prefix + "crit", check_crit)

        if crit:
            raise TaskFailure("Critical perf loss")

    def _check_errors(self):
        """
        Fail the task when the measured session of the new service has too many errors.

        The error count is compared with the absolute threshold and, as a
        share of all requests, with the percent threshold; a threshold of 0
        disables the corresponding check. Both verdicts are stored in the
        context.

        Raises:
            TaskFailure: when the stats are missing, no requests were sent or
                a threshold is reached.
        """
        context = self.Context
        parameters = self.Parameters
        if not parameters.check_errors:
            return

        # Stats live in the context under _make_stats_key(tag).
        if not hasattr(context, self._make_stats_key(NEW_TAG)):
            raise TaskFailure("New bundle stats are missed")
        stats = self._get_stats()

        total = stats[NEW_TAG].get("dumper.total_requests", 0.0)
        errors = stats[NEW_TAG].get("shooting.errors", 0.0)

        if total == 0.0:
            raise TaskFailure("No requests has been send")

        # float() keeps the share exact when both counters are integers.
        errors_perc = float(errors) / total * 100.0
        abs_crit = (
            parameters.errors_threshold_abs != 0.0 and
            errors >= parameters.errors_threshold_abs
        )
        perc_crit = (
            parameters.errors_threshold_perc != 0.0 and
            errors_perc >= parameters.errors_threshold_perc
        )

        setattr(context, "errors_check_abs_crit", abs_crit)
        setattr(context, "errors_check_perc_crit", perc_crit)

        if abs_crit or perc_crit:
            raise TaskFailure("Critical errors count")

    def _get_stats(self, with_diffs=False):
        """
        Read the stored stats from the context.

        Returns a dict that maps every tag of SHOOTING_TAGS (and every delta
        name when ``with_diffs`` is true) to the value stored in the context
        under the corresponding stats key.
        """
        context = self.Context
        tags = SHOOTING_TAGS
        if with_diffs:
            # Builds a new list, SHOOTING_TAGS itself is not modified.
            tags = tags + self._get_deltas_keys()
        return {tag: getattr(context, self._make_stats_key(tag)) for tag in tags}

    def _switch_new_stats_key(self, new_stats_key):
        """Set ``new_stats_key`` on this instance."""
        self.new_stats_key = new_stats_key

    def _is_stats_key(self, key):
        """Return True when ``key`` is one of the keys listed in ``new_stats_types``."""
        return any(key == stat_key for stat_key, _, _ in self.new_stats_types)

    @staticmethod
    def _get_deltas_keys():
        """Return the delta names of REQUIRED_DELTAS in declaration order."""
        return [delta_key for delta_key, _, _ in REQUIRED_DELTAS]

    @staticmethod
    def _make_stats_key(tag):
        """Return the context attribute name used for the stats of ``tag``."""
        return "{}_{}".format(ShootingTask2.new_stats_key, tag)

    @staticmethod
    def _sync_resources(*args):
        """Return a dict of resource id -> ResourceData, one entry per distinct id."""
        res_data = dict()
        for res in args:
            if res.id not in res_data:
                res_data[res.id] = ResourceData(res)
        return res_data

    @property
    def footer(self):
        """
        Build the "Performance stats" footer table.

        The table has one row per session or delta, sorted by name, and one
        column per entry of ``new_stats_types``.
        """
        stats = self._get_stats(with_diffs=True)
        variants = sorted(stats.keys())

        header = [
            {"key": "num", "title": "&nbsp;"}
        ] + [
            {"key": key, "title": title} for key, title, _ in self.new_stats_types
        ]

        body = {"num": []}
        for variant in variants:
            body["num"].append(self._format_report(variant, stats[variant], add_hash=False))
            for key, _, fmt in self.new_stats_types:
                body.setdefault(key, []).append(self._format_stats(fmt, stats, variant, key))

        return {
            "<h3>Performance stats</h3>": {
                "header": header,
                "body": body,
            }
        }

    @staticmethod
    def _format_report(num, results, add_hash=True, strong=False):
        """
        Render the row label ``num``, linked to the report when one is known.

        The link comes from ``results["report_url"]`` or is built from
        ``results["report_resource"]`` and ``results["report_path"]``; without
        either the label is returned as plain text. ``add_hash`` prefixes the
        label with "#", ``strong`` wraps the result in <strong>.
        """
        if "report_url" in results:
            report_url = results["report_url"]
        elif "report_resource" in results:
            report_url = "//proxy.sandbox.yandex-team.ru/{}/{}".format(results["report_resource"], results["report_path"])
        else:  # No additional formatting
            report_url = None

        num = "#{}".format(num) if add_hash else num
        report = "<a href='{}'>{}</a>".format(report_url, num) if report_url else num
        return "<strong>{}</strong>".format(report) if strong else report

    @staticmethod
    def _format_stats(fmt, stats, variant, key):
        """Format ``stats[variant][key]`` with ``fmt``; return "-" when the key is absent."""
        if key not in stats[variant]:
            return "-"
        return fmt.format(stats[variant][key])

    @staticmethod
    def delta_percent(a, b):
        """
        Return the change from ``a`` to ``b`` in percent of ``a``.

        When ``a`` is zero the result is 100.0 for a non-zero ``b`` and 0.0
        when both values are zero.
        """
        if a != 0:
            return (float(b - a) / a) * 100.0
        if b != 0:
            return 100.0
        return 0.0

