import logging
import itertools

import numpy as np

import sandbox.common.types.client as ctc
import sandbox.sandboxsdk.parameters as sp

from sandbox.sandboxsdk import process
from sandbox.sandboxsdk import environments as sb_env
from sandbox.projects import resource_types
from sandbox.projects.common.differ import coloring
from sandbox.projects.common.search import components as sc
from sandbox.projects.common.search import performance as search_performance
from sandbox.projects.common.search.basesearch import task as search_task
from sandbox.projects.common import file_utils as fu
from sandbox.projects.common import stats as stats
from sandbox.projects.common import utils
from sandbox.projects.websearch.basesearch import constants as bs_const
from sandbox.projects.websearch.basesearch.PriemkaBasesearchBinary import PlotImage

# Parameter sets for the two basesearch instances (#1 and #2) used by the task below.
BASESEARCH1_PARAMS = sc.create_basesearch_params(n=1)
BASESEARCH2_PARAMS = sc.create_basesearch_params(n=2)
# Database types that can be combined with query types to form a run type.
DB_TYPES = ["platinum", "web0", "web1"]
# Every "<db type>_<query type>" pair, followed by the catch-all "unknown" value.
COMBINATIONS = [
    "{}_{}".format(db_type, query_type)
    for db_type, query_type in itertools.product(DB_TYPES, bs_const.QUERY_TYPES)
]
COMBINATIONS.append("unknown")


class Plan1Parameter(sp.ResourceSelector):
    # Mandatory selector of the plan resource used to shoot basesearch #1.
    resource_type = resource_types.BASESEARCH_PLAN
    required = True
    # Context key under which the selected resource id is stored.
    name = 'dolbilo_plan1_resource_id'
    description = 'Plan for basesearch #1'


class Plan2Parameter(sp.ResourceSelector):
    # Mandatory selector of the plan resource used to shoot basesearch #2.
    resource_type = resource_types.BASESEARCH_PLAN
    required = True
    # Context key under which the selected resource id is stored.
    name = 'dolbilo_plan2_resource_id'
    description = 'Plan for basesearch #2'


class RunType(sp.SandboxStringParameter):
    """
        Type of running task
    """
    name = 'run_type'
    description = "Run type"
    # Pairs are materialized into a list: a bare zip() is a one-shot iterator on
    # Python 3 and would look empty after the first traversal of the choices.
    choices = list(zip(COMBINATIONS, COMBINATIONS))
    default_value = "unknown"


class NumberOfRuns(sp.SandboxIntegerParameter):
    # How many times each of the two basesearch instances is shot.
    name = "number_of_runs"
    description = "Number of runs"
    # Integer (not the string "50"): the value is passed to range() by the task,
    # so a string default would fail whenever the default is used as-is.
    default_value = 50


class WebBasesearchPerformanceMultiruns(
    search_performance.NewShootingTask,
    search_task.BasesearchComponentTask
):
    """
        Multi-run version of basesearch performance task.
        1) Shoot first basesearch (baseline) "Number of runs" times and collect rps of every run
        2) Shoot second basesearch (test) the same number of times
        3) Compare the two rps samples with Welch's t-test and draw them on one plot
           (test results and plot link are shown in footer)
    """
    type = 'WEB_BASESEARCH_PERFORMANCE_MULTIRUNS'
    client_tags = search_performance.NewShootingTask.client_tags & ctc.Tag.LXC
    execution_space = 110 * 1024
    required_ram = 96 * 1024
    input_parameters = (
        (RunType,) +
        (NumberOfRuns,) +
        BASESEARCH1_PARAMS.params + (Plan1Parameter,) +
        BASESEARCH2_PARAMS.params + (Plan2Parameter,) +
        search_task.BasesearchComponentTask.basesearch_input_parameters +
        search_performance.NewShootingTask.shoot_input_parameters
    )
    environment = [
        sb_env.PipEnvironment("matplotlib", '1.5.1', use_wheel=True),
        sb_env.PipEnvironment("scipy", '0.19.0', use_wheel=True),
    ]
    # NOTE(review): the base attribute read here is `new_stats_types` (with "s"), while this
    # attribute is named `new_stat_types`, so it may not override the base value - confirm intent.
    new_stat_types = search_performance.NewShootingTask.new_stats_types + (
        ("shooting.latency_0.95", "Latency P95", "{:0.2f}"),
        ("dumper.rps", "RPS", "{:0.2f}"),
    )

    def on_enqueue(self):
        """Run the parent enqueue hook and set the kill timeout to 24 hours."""
        search_task.BasesearchComponentTask.on_enqueue(self)
        self.ctx["kill_timeout"] = 24 * 60 * 60  # 24 hours

    def on_execute(self):
        """
            Shoot both basesearch instances several times, then compare collected rps samples.

            Rps of every run is appended to ctx["bs1_rps"] (baseline) or ctx["bs2_rps"] (test).
        """
        search_task.BasesearchComponentTask.on_execute(self)
        bs1 = self._basesearch(BASESEARCH1_PARAMS)
        bs2 = self._basesearch(BASESEARCH2_PARAMS)

        self._init_virtualenv()
        self.ctx["bs1_rps"] = []
        self.ctx["bs2_rps"] = []
        # int() keeps range() working even if the parameter value arrives as a string
        number_of_runs = int(utils.get_or_default(self.ctx, NumberOfRuns))
        self._shoot_multirun_series(bs1, Plan1Parameter, "baseline", "bs1_rps", number_of_runs)
        self._shoot_multirun_series(bs2, Plan2Parameter, "test", "bs2_rps", number_of_runs)

        self.count_t_criteria()
        self.draw_graph_on_perfs(self.ctx["bs1_rps"], self.ctx["bs2_rps"])

    def _shoot_multirun_series(self, basesearch, plan_parameter, title, ctx_key, number_of_runs):
        """
            Shoot `basesearch` `number_of_runs` times with the plan taken from `plan_parameter`.

            Runs are named "<title> #01", "<title> #02", ...; "dumper.rps" of every run
            is appended to self.ctx[ctx_key] right after the run finishes.
        """
        for run_number in range(1, number_of_runs + 1):
            run_id = "{} #{:02d}".format(title, run_number)
            search_performance.NewShootingTask._dolbilo_shoot(
                self, basesearch, self.ctx[plan_parameter.name], run_id
            )
            self.ctx[ctx_key].append(self.ctx[self.new_stats_key][run_id]["dumper.rps"])

    def count_t_criteria(self):
        """
            Uses Welch's t-test.
            If the p-value is smaller than the threshold,
            e.g. 1%, 5% or 10%, then we reject the null hypothesis of equal averages.

            Both rps samples are dumped to json files and passed to t_test.py, which is run
            in a fresh virtualenv with numpy and scipy installed. Json printed by the script
            to stdout is merged into the task context.
        """
        logging.info("Try to calculate performance stats")
        with sb_env.VirtualEnvironment() as venv:
            logging.info('Installing numpy + scipy...')
            sb_env.PipEnvironment('numpy', use_wheel=True, venv=venv, version="1.12.1").prepare()
            sb_env.PipEnvironment('scipy', use_wheel=True, venv=venv, version="0.19.0").prepare()
            stats_path1 = self.abs_path('rps_sample_1.json')
            stats_path2 = self.abs_path('rps_sample_2.json')
            fu.json_dump(stats_path1, self.ctx['bs1_rps'])
            fu.json_dump(stats_path2, self.ctx['bs2_rps'])
            p = process.run_process(
                [venv.executable, stats.get_module_path() + '/t_test.py', stats_path1, stats_path2],
                log_prefix='t_test',
                # stdout must stay separate from stderr: it is parsed as json below
                outputs_to_one_file=False,
                check=True,
                wait=True,
            )
            t_stats = fu.json_load(p.stdout_path)
            # presumably provides the median_rps_* / diff_* keys read by the footer - verify against t_test.py
            self.ctx.update(t_stats)

    @staticmethod
    def _trim_outliers(points, margin=2):
        """
            Return `points` sorted, without `margin` smallest and `margin` largest values.

            Samples that are too short to keep at least one point after trimming
            are returned sorted but untrimmed, so there is always something to plot.
        """
        ordered = sorted(points)
        if len(ordered) <= 2 * margin:
            return ordered
        return ordered[margin:-margin]

    def draw_graph_on_perfs(self, y1_points, y2_points):
        """
            Plot sorted rps samples of both basesearch instances into a PlotImage resource.

            Two lowest and two highest values of every sample are dropped before plotting
            (unless the sample has 4 points or fewer). Resource id and proxy url
            are stored in ctx["plot_rps"] and ctx["plot_rps_proxy_url"].
        """
        y1_points = self._trim_outliers(y1_points)
        y2_points = self._trim_outliers(y2_points)
        # imported lazily: matplotlib is installed via `environment` of the task
        import matplotlib.pyplot as plt
        _draw_perf_together(plt, y1_points, y2_points)
        resource = self.create_resource("Plot rps", "plot_rps.png", PlotImage)
        plt.savefig(resource.path)
        plt.close()
        self.ctx["plot_rps"] = resource.id
        self.ctx["plot_rps_proxy_url"] = resource.proxy_url

    def _get_perf_results(self):
        """Build the body of the footer table from t-test results stored in the context."""
        return {
            "median_rps_1": [round(float(self.ctx.get("median_rps_1", 0)), 2)],
            "median_rps_2": [round(float(self.ctx.get("median_rps_2", 0)), 2)],
            "diff": [coloring.color_diff(
                float(self.ctx.get("diff_per_cent_median", 0)),
                max_diff=-1,
                probability=self.ctx.get("diff_probability", 0)
            )],
        }

    @property
    def footer(self):
        """
            Task footer: table with median rps comparison, link to the rps plot
            and the footer of the parent shooting task.
        """
        performance_footer = search_performance.NewShootingTask.footer.fget(self)
        plot_rps = self.ctx.get("plot_rps")
        # the link is empty until the plot resource has been created
        proxy_plot_link = '<a href="https://proxy.sandbox.yandex-team.ru/{res_id}">{res_id}</a>'.format(
            res_id=plot_rps
        ) if plot_rps else ""
        return [
            {
                'content': {
                    "<h3>Performance results (Probability of different rps: {} %)</h3>".format(
                        self.ctx.get("diff_probability", "-")
                    ): {
                        "header": [
                            {"key": "median_rps_1", "title": "Baseline median rps"},
                            {"key": "median_rps_2", "title": "Test median rps"},
                            {"key": "diff", "title": "Diff, %"},
                        ],
                        "body": self._get_perf_results(),
                    }
                }
            },
            {"content": {"<h3>Performance plot:</h3>": proxy_plot_link}},
            {'content': performance_footer},
        ]


def _draw_perf_together(plt, y_points1, y_points2):
    """
        Draw baseline and test rps samples on the current figure of `plt`.

        Every sample is drawn as a scatter of its values against the point index,
        plus a horizontal line at the constant that minimizes mean squared error
        (degree-0 polynomial fit). An empty sample gets no fit line instead of
        failing inside np.polyfit.

        :param plt: matplotlib.pyplot-like object (xlabel, ylabel, plot and legend are used)
        :param y_points1: rps values of baseline
        :param y_points2: rps values of test
    """
    series = (
        (y_points1, "b+", "baseline", ":r", "min mse for baseline"),
        (y_points2, "gx", "test", "--y", "min mse for test"),
    )
    plt.xlabel("shoots")
    plt.ylabel("rps")
    # Both scatters go first so that legend entries keep the order: points, then fit lines.
    for y_points, marker, label, _, _ in series:
        plt.plot(list(range(len(y_points))), y_points, marker, label=label)
    for y_points, _, _, line_style, fit_label in series:
        if not len(y_points):
            # nothing to fit; np.polyfit raises on empty input
            continue
        x_points = list(range(len(y_points)))
        # polyfit returns a 1-element array for degree 0; plot its plain float value
        level = float(np.polyfit(x_points, y_points, 0)[0])
        plt.plot(x_points, [level] * len(y_points), line_style, label=fit_label)
    plt.legend(loc=0)


# Module-level alias of the task class (presumably how the task loader finds it - TODO confirm).
__Task__ = WebBasesearchPerformanceMultiruns
