# -*- coding: utf-8 -*-
import logging
import itertools


from sandbox.common.types.client import Tag

from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk import parameters as sp

from sandbox.projects.common import error_handlers as eh
from sandbox.projects.common import utils
from sandbox.projects import resource_types
from sandbox.projects.common import dolbilka
from sandbox.projects.common.differ import coloring
from sandbox.projects.websearch.basesearch.CompareBasesearchPerformance import MaxDiff, MaxDiffPercent, CompareType, TopResults, MaxMemoryDiff


class PlanParameter(sp.LastReleasedResource):
    # Dolbilka plan resource; the same plan is fed to both binaries under test.
    name = 'dolbilo_plan_resource_id'
    description = 'Plan'
    # Only basesearch plans are accepted for this parameter.
    resource_type = [resource_types.BASESEARCH_PLAN]


class MaxSequentialSessions(sp.SandboxIntegerParameter):
    # How many dolbilka sessions are run back-to-back against one binary before
    # switching to the other one; a value <= 0 (the default) disables splitting.
    name = 'max_sequential_sessions'
    description = 'Switch binary after n sessions'
    default_value = 0


class BaseTestPerformanceParallelTask(SandboxTask):
    """
        Base class for load-testing ("shooting") binaries of two different versions
        on the same machine.

        Subclasses provide the binaries via ``_get_first_bin`` / ``_get_second_bin``.
        Dolbilka sessions are split into chunks (see ``MaxSequentialSessions``) and
        for every chunk both binaries are shot one after another. Afterwards the RPS
        and the RSS memory of both binaries are compared and the verdict is stored
        in the task context.
    """

    type = 'TEST_PERFORMANCE_PARALLEL'
    client_tags = Tag.LINUX_PRECISE

    input_parameters = (
        (
            PlanParameter,
        ) +
        dolbilka.DolbilkaExecutor.input_task_parameters +
        (
            MaxSequentialSessions,
            TopResults,
            CompareType,
            MaxDiff,
            MaxDiffPercent,
            MaxMemoryDiff,
        )
    )

    def on_execute(self):
        """
            Shoot both binaries and compare their performance and memory usage.

            Stores into ``self.ctx``: ``rps_stats`` (top1, avg1, top2, avg2, diff),
            ``memory_stats`` (top1, avg1, top2, avg2, diff, med1, med2),
            ``rps_compare_result``, ``mem_compare_result``, ``compare_result``
            and ``has_diff``, plus the per-run data consumed by ``footer``.
        """
        # The overridable hook of this class is ``_prepare_environment``, but historically
        # ``prepare_environment`` was called. Keep calling it when it exists (a subclass
        # or the SDK may define it); otherwise fall back to the hook instead of raising
        # AttributeError.
        prepare_environment = getattr(self, 'prepare_environment', None) or self._prepare_environment
        prepare_environment()

        bin1 = self._get_first_bin()
        bin2 = self._get_second_bin()

        self._init_data()

        d_executor = dolbilka.DolbilkaExecutor()
        # The plan does not change between runs, so sync it once instead of once per run.
        plan_path = self.sync_resource(self.ctx[PlanParameter.name])

        for session_length in self._get_sessions_length():
            for binary_index, binary in enumerate([bin1, bin2]):
                self._run_single_binary(d_executor, plan_path, binary, binary_index, session_length)

        top_results_number = self.ctx[TopResults.name]
        dolbilo_contexts = self.ctx["dolbilo_contexts"]
        memory_contexts = self.ctx["memory_contexts"]

        rps_compare_result, self.ctx["rps_stats"] = self._check_performance(
            self._collect_values(dolbilo_contexts[0], "requests_per_sec"),
            self._collect_values(dolbilo_contexts[1], "requests_per_sec"),
            top_results_number,
        )
        mem_compare_result, self.ctx["memory_stats"] = self._check_memory(
            self._collect_values(memory_contexts[0], "memory_rss"),
            self._collect_values(memory_contexts[1], "memory_rss"),
            top_results_number,
        )
        compare_result = rps_compare_result and mem_compare_result

        self.ctx["rps_compare_result"] = rps_compare_result
        self.ctx["mem_compare_result"] = mem_compare_result
        self.ctx["compare_result"] = compare_result
        self.ctx["has_diff"] = not compare_result

    def _run_single_binary(self, d_executor, plan_path, binary, binary_index, session_length):
        """
            Run one chunk of ``session_length`` dolbilka sessions against ``binary``.

            The RPS context and the RSS memory of the run are appended to
            ``self.ctx["dolbilo_contexts"][binary_index]`` and
            ``self.ctx["memory_contexts"][binary_index]`` respectively.
        """
        # alter sessions
        d_executor.sessions = session_length

        # clear global appendable context so the values below belong to this run only
        for global_dolbilka_key in ("memory_rss", "memory_vsz", "memory_bytes"):
            self.ctx[global_dolbilka_key] = []

        run_results = d_executor.run_sessions(plan_path, binary, run_once=True)

        local_dolbilo_ctx = {}
        dolbilka.DolbilkaPlanner.fill_rps_ctx(run_results, local_dolbilo_ctx)
        local_memory_ctx = {'memory_rss': self.ctx.get("memory_rss", [0])}

        # Binary types are 1-based in the footer table.
        self._prepare_data_for_footer(local_dolbilo_ctx, 1 + binary_index, session_length)
        self.ctx["dolbilo_contexts"][binary_index].append(local_dolbilo_ctx)
        self.ctx["memory_contexts"][binary_index].append(local_memory_ctx)

    @staticmethod
    def _collect_values(contexts, key):
        """
            Concatenate the lists stored under ``key`` in every dict of ``contexts``
            (missing keys contribute nothing) and return them as one list.
        """
        return list(itertools.chain.from_iterable(ctx.get(key, []) for ctx in contexts))

    @staticmethod
    def _diff_percent(diff, base):
        """
            Return ``diff`` as a percentage of ``base``; 0.0 when ``base`` is zero
            (the percentage is undefined then, and the footer must not crash).
        """
        if not base:
            return 0.0
        return 100.0 * diff / base

    @property
    def footer(self):
        """
            Common footer for performance tests.

            Returns None until the task is completed, otherwise a list of footer
            sections: the verdict, the RPS/memory diff and the per-session table.
        """
        if not self.is_completed():
            return None

        rps_stats = self.ctx["rps_stats"]
        memory_stats = self.ctx["memory_stats"]

        top1_rps = rps_stats[0][0]
        diff_rps = rps_stats[4]
        diff_percent_rps = self._diff_percent(diff_rps, top1_rps)

        med1_mem = memory_stats[5]
        diff_mem = memory_stats[4]
        diff_percent_mem = self._diff_percent(diff_mem, med1_mem)

        rps_list = self.ctx.get("rps_list", [0])
        bin_type_list = self.ctx.get("bin_type_list", [0])

        # Highlight the best RPS of every binary type.
        max_rps_by_type = {}
        for rps, bin_type in zip(rps_list, bin_type_list):
            max_rps_by_type[bin_type] = max(max_rps_by_type.get(bin_type, rps), rps)
        formatted_rps_list = []
        for rps, bin_type in zip(rps_list, bin_type_list):
            if bin_type in max_rps_by_type and rps == max_rps_by_type[bin_type]:
                rps = '<b style="color:red">{}</b>'.format(rps)
                del max_rps_by_type[bin_type]  # highlight only first record
            formatted_rps_list.append(rps)

        head = [
            {"key": "session", "title": "N session"},
            {"key": "binary", "title": "Binary type"},
            {"key": "rps", "title": "RPS"},
            {"key": "fail_rate", "title": "Fail rate"},
            {"key": "rss", "title": "Rss memory (vmtouch)"},
            {"key": "vsz", "title": "Vsz memory (vmtouch)"},
            {"key": "requests_ok", "title": "Requests with OK status"},
        ]
        return [
            {
                'content': {
                    '<h4>Verdict:</h4>': "{}".format('Not significant diff' if not self.ctx["has_diff"] else 'Has significant diff')
                }
            },
            {
                'content': {
                    '<h4>Diff info:</h4>': "RPS diff: {},<br>\nRSS median memory diff: {:+.2f}%".format(
                        coloring.color_diff(diff_percent_rps, max_diff=-utils.get_or_default(self.ctx, MaxDiffPercent)),
                        diff_percent_mem
                    )
                }
            },
            {
                'helperName': '',
                'content': {
                    "<h4>Executor stats:</h4>": {
                        "header": head,
                        "body": {
                            "session": self.ctx.get("dolbilka_executor_sessions_info", [0]),
                            "binary": bin_type_list,
                            "rps": formatted_rps_list,
                            "fail_rate": self.ctx.get("fail_rates_info", [0]),
                            "rss": self.ctx.get("memory_rss_info", [0]),
                            "vsz": self.ctx.get("memory_vsz_info", [0]),
                            "requests_ok": self.ctx.get("requests_ok_status_list", []),
                        }
                    }
                }
            }
        ]

    def _prepare_environment(self):
        """
            Prepare the environment before the run (hook for subclasses; no-op here).
        """

        pass

    def _get_first_bin(self):
        """
            Return the first binary (must be implemented by subclasses).
        """

        raise NotImplementedError()

    def _get_second_bin(self):
        """
            Return the second binary (must be implemented by subclasses).
        """

        raise NotImplementedError()

    def _check_performance(self, stats_task1, stats_task2, top_results_number):
        """
            Compare the performance (RPS) of the two binaries.

            :param stats_task1: iterable of RPS values of the first binary
            :param stats_task2: iterable of RPS values of the second binary
            :param top_results_number: how many best results to keep
            :return: (compare_result, (top1, avg1, top2, avg2, diff)) where ``diff``
                is the best RPS of the second binary minus the best of the first
        """
        stats_task1 = list(stats_task1)
        stats_task2 = list(stats_task2)
        if not stats_task1:
            eh.check_failed(
                "No rps data aggregated, maybe test results 1 are too old and resources expired"
            )
        if not stats_task2:
            eh.check_failed(
                "No rps data aggregated, maybe test results 2 are too old and resources expired"
            )
        top1, avg1 = self._top_results(stats_task1, top_results_number)
        top2, avg2 = self._top_results(stats_task2, top_results_number)

        # Compare the best (maximum) RPS of each binary; per the original author this is
        # "much better than anything else".
        diff = max(top2) - max(top1)

        compare_type = self.ctx[CompareType.name]
        if compare_type == 'rps':
            compare_result = (abs(diff) <= self.ctx[MaxDiff.name])
        elif compare_type == 'percent':
            if max(top1) < 1.0:
                # A percentage relative to a (near-)zero baseline is meaningless.
                logging.info('Diff failed -- First task RPSes are too low')
                compare_result = False
            else:
                diff_percent = 100.0 * diff / max(top1)
                logging.info('Diff per cent: %f', diff_percent)
                compare_result = (abs(diff_percent) <= self.ctx[MaxDiffPercent.name])
        else:
            eh.fail("Unknown compare type: {}".format(compare_type))

        stats = top1, avg1, top2, avg2, diff
        return compare_result, stats

    def _check_memory(self, memory_list_1, memory_list_2, top_results_number):
        """
            Compare the memory usage (median RSS) of the two binaries.

            The check passes when the medians differ by at most ``MaxMemoryDiff``
            percent of the first binary's median.

            :return: (compare_result, (top1, avg1, top2, avg2, diff, med1, med2))
        """

        memory_list_1 = list(memory_list_1)
        memory_list_2 = list(memory_list_2)
        top1, avg1 = self._top_results(memory_list_1, top_results_number)
        top2, avg2 = self._top_results(memory_list_2, top_results_number)
        med1 = self._median(memory_list_1)
        med2 = self._median(memory_list_2)
        diff = med2 - med1
        max_diff = med1 * self.ctx[MaxMemoryDiff.name] / 100.0

        compare_result = (abs(diff) <= max_diff)
        stats = top1, avg1, top2, avg2, diff, med1, med2

        return compare_result, stats

    @staticmethod
    def _top_results(results, num):
        """
            Return the ``num`` largest results (in descending order) and their average.

            Fails the task if there are fewer than ``num`` results. The average is
            0.0 when no results are kept (``num`` is 0).
        """
        results = list(results)
        eh.verify(len(results) >= num, 'Test results should contain at least {} results'.format(num))

        top = sorted(results, reverse=True)[:num]
        # float() avoids floor division of integer data (e.g. memory sizes) under Python 2.
        avg = float(sum(top)) / len(top) if top else 0.0

        return top, avg

    @staticmethod
    def _median(results):
        """
            Return the median of ``results``; the mean of the two middle values
            for an even count. Fails the task if ``results`` is empty.
        """
        results = sorted(results)
        n = len(results)
        eh.verify(n >= 1, 'Test results should contain at least one result')
        if n % 2 == 1:
            return results[n // 2]
        else:
            return (results[n // 2 - 1] + results[n // 2]) * 0.5

    def _init_data(self):
        """
            Reset the per-binary run contexts and the footer data in ``self.ctx``.
        """
        self.ctx["dolbilo_contexts"] = {
            0: [],
            1: [],
        }
        self.ctx["memory_contexts"] = {
            0: [],
            1: [],
        }
        self.ctx["dolbilka_executor_sessions_info"] = []
        self.ctx["bin_type_list"] = []
        self.ctx["rps_list"] = []
        self.ctx["fail_rates_info"] = []
        self.ctx["memory_rss_info"] = []
        self.ctx["memory_vsz_info"] = []
        self.ctx["requests_ok_status_list"] = []

    @staticmethod
    def _transpose_dolbilka_results(results):
        """
            Turn a list of result dicts into a dict of lists keyed by every key seen
            in any result; missing values are filled with None so all lists have
            the same length.
        """
        keys = frozenset(itertools.chain.from_iterable(result.keys() for result in results))
        rd = {key: [] for key in keys}
        for result in results:
            for key in keys:
                rd[key].append(result.get(key, None))
        return rd

    @staticmethod
    def _safe_cast(val, to_type, default=None):
        """
            Return ``to_type(val)``, or ``default`` if the conversion raises
            ValueError or TypeError.
        """
        try:
            return to_type(val)
        except (ValueError, TypeError):
            return default

    def _prepare_data_for_footer(self, ctx, bin_type, dolbilka_executor_sessions):
        """
            Append the data of one run to the footer lists in ``self.ctx``.

            :param ctx: local dolbilka context filled by ``fill_rps_ctx``
            :param bin_type: 1-based index of the binary
            :param dolbilka_executor_sessions: number of sessions in this run
        """
        self.ctx["dolbilka_executor_sessions_info"] += range(dolbilka_executor_sessions)
        self.ctx["bin_type_list"] += [bin_type] * dolbilka_executor_sessions
        self.ctx["fail_rates_info"] += ctx.get("fail_rates", [0])

        rps_list = list(ctx.get("requests_per_sec", [0]))
        self.ctx["rps_list"] += rps_list

        dlb_trans_results = self._transpose_dolbilka_results(ctx.get("results", []))

        requests_list = [self._safe_cast(requests_num, int, 0) for requests_num in dlb_trans_results.get("requests", [])]
        requests_ok_list = [self._safe_cast(requests_ok_num, int, 0) for requests_ok_num in dlb_trans_results.get("requests_ok", [])]
        assert len(requests_list) == len(requests_ok_list)
        for requests_num, requests_ok_num in zip(requests_list, requests_ok_list):
            msg = '{} / {}'.format(requests_ok_num, requests_num)
            if requests_num != requests_ok_num:
                # Red below 99% OK responses, olive for a smaller mismatch.
                ok_ratio = float(requests_ok_num) / float(requests_num) if requests_num > 0 else 0.0
                msg = '<span style="color:{};">{}</span>'.format(('red' if ok_ratio < 0.99 else 'olive'), msg)
            self.ctx["requests_ok_status_list"].append(msg)

        # These keys are reset before every run, so they hold this run's values only.
        self.ctx["memory_rss_info"] += self.ctx.get("memory_rss", [0])
        self.ctx["memory_vsz_info"] += self.ctx.get("memory_vsz", [0])

    def _get_sessions_length(self):
        """
            Split the total number of dolbilka sessions into chunks of at most
            ``MaxSequentialSessions`` sessions (the remainder chunk goes first).

            A non-positive ``MaxSequentialSessions`` yields a single chunk with all sessions.
        """
        max_sequential_sessions = int(utils.get_or_default(self.ctx, MaxSequentialSessions))
        dolbilka_sessions = int(utils.get_or_default(self.ctx, dolbilka.DolbilkaSessionsCount))
        if max_sequential_sessions <= 0:
            return [dolbilka_sessions]
        q, r = divmod(dolbilka_sessions, max_sequential_sessions)
        result = [max_sequential_sessions] * q
        if r:
            result = [r] + result
        return result
