# -*- coding: utf-8 -*-
import os
import logging
import jinja2
from collections import OrderedDict

from sandbox import sdk2
from sandbox.sandboxsdk import environments
import sandbox.common.types.client as ctc
import sandbox.common.types.task as ctt
import sandbox.projects.resource_types as rt
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
import sandbox.projects.release_machine.core.task_env as task_env
from sandbox.projects.common.geosearch.utils_sdk2 import make_parameters_cls, create_parameters_group
from sandbox.projects.common.geosearch.search_components_sdk2 import get_addrs_basesearch_params
from sandbox.projects.common import dolbilka2
from sandbox.projects.tank import executor2 as tank_executor
from sandbox.projects.geosearch.AddrsBasesearchPerformanceParallel import AddrsBasesearchPerformanceParallel
from sandbox.projects.release_machine.helpers.startrek_helper import STHelper
from sandbox.projects.common.geosearch.startrek import StartrekClient
from sandbox.projects.release_machine.components import all as rmc
from sandbox.projects.geosearch.tools import yp_lite


def get_addrs_custom_basesearch_params(param_suffix='', name=''):
    """Build the resource parameters describing one basesearch setup.

    Args:
        param_suffix: appended to every parameter key (e.g. '1' or '2') so that
            several setups can coexist in one parameter set.
        name: human-readable prefix for the parameter labels (e.g. 'Reference').

    Returns:
        A list of (key, sdk2 parameter) pairs, in this order: shardmap, binary,
        custom plan (the custom plan is optional and typed as BASESEARCH_PLAN).
    """
    # Parameters are created in the same order as the returned pairs.
    shardmap_param = sdk2.parameters.Resource('%s shardmap' % name, required=True)
    binary_param = sdk2.parameters.Resource('%s binary' % name, required=True)
    plan_param = sdk2.parameters.Resource(
        '%s custom plan' % name,
        resource_type=rt.BASESEARCH_PLAN,
        required=False,
    )
    keyed = zip(
        ('shardmap', 'binary', 'custom_plan_id'),
        (shardmap_param, binary_param, plan_param),
    )
    return [('%s%s' % (prefix, param_suffix), param) for prefix, param in keyed]


class AddrsBasePerformanceParallelAcceptance(sdk2.Task):
    """Compare addrs basesearch performance of a reference and a test setup.

    One AddrsBasesearchPerformanceParallel subtask is started per shard pair
    taken from the reference/test shardmaps; their per-session stats are
    aggregated, compared, optionally pushed to YT and reported to Startrek.
    """

    class Requirements(sdk2.Task.Requirements):
        environments = [task_env.TaskRequirements.startrek_client,
                        environments.PipEnvironment('yandex-yt')]
        client_tags = task_env.TaskTags.startrek_client & ctc.Tag.INTEL_E5_2650
        cores = 1
        ram = 8192

        class Caches(sdk2.Requirements.Caches):
            pass

    # Resource parameters for the reference ('1') and test ('2') setups.
    base_diff_params = get_addrs_custom_basesearch_params('1', 'Reference') + get_addrs_custom_basesearch_params('2', 'Test')
    base_diff_group = create_parameters_group(OrderedDict(base_diff_params), 'Shardmaps and binary')
    # Shared basesearch parameters. The per-setup ones (binary, shard, plan) are
    # dropped here because start_tasks() passes them separately for each side.
    base_params = get_addrs_basesearch_params()
    del base_params['geobasesearch_executable_resource_id']
    del base_params['shard_rbtorrent']
    del base_params['custom_plan_id']
    base_params_group = create_parameters_group(base_params, 'Addrs Basesearch parameters')
    launch_type_param = sdk2.parameters.String('Launch type', required=True)
    launch_type_param.choices = (('Release Machine', 'RM'), ('Database', 'DB'))
    launch_type_param.default_value = 'DB'
    launch_type_param.__default_value__ = 'DB'

    # Context.has_diff is set when the test RPS is more than this many percent
    # below the reference RPS.
    RPS_DEGRADATION_PERCENT = 1.2

    Parameters = make_parameters_cls(OrderedDict(
        [('base_diff_group', base_diff_group)] +
        base_diff_params +
        [
            ('startrek_task', sdk2.parameters.String('Startrek task', required=False)),
            ('launch_type', launch_type_param),
            ('push_data', sdk2.parameters.Bool('Push perf data to YT', default=False)),
            ('release_number', sdk2.parameters.Integer('Release number', default=0, required=False)),
            ('component_name', sdk2.parameters.String('Component name', default='addrs_base', required=True)),
            ('dolbilo_plan_resource_id', sdk2.parameters.Resource('Plan', resource_type=rt.BASESEARCH_PLAN))
        ] +
        [
            ('dolbilka_group', sdk2.parameters.Group('Dolbilka parameters')),  # empty group for separating parameters
            ('dolbilka_param', dolbilka2.DolbilkaExecutor2.Parameters),
            ('lunapark_param', tank_executor.LunaparkPlugin.Parameters),
            ('offline_param', tank_executor.OfflinePlugin.Parameters)
        ] +
        [('base_params_group', base_params_group)] +
        list(base_params.iteritems())
    ))

    class Context(sdk2.Task.Context):
        shards = 0  # number of shard pairs; 0 until subtasks have been started
        has_diff = False  # True when test RPS regressed beyond RPS_DEGRADATION_PERCENT
        result = []  # compare_stats() output: list of (metric, comparison dict) pairs

    def _render_result_template(self, template_name):
        """Render a Jinja2 template from this module's directory with Context.result."""
        template_path = os.path.dirname(os.path.abspath(__file__))
        env = jinja2.Environment(loader=jinja2.FileSystemLoader(template_path))
        return env.get_template(template_name).render(result=OrderedDict(self.Context.result))

    @sdk2.footer()
    def footer(self):
        """Render the comparison table for the task page, or a placeholder while running."""
        if not self.Context.result:
            return 'Processing...'
        return self._render_result_template("footer.html")

    def parse_shardmap(self, resource):
        """Return the shard identifiers listed in a shardmap resource.

        Blank lines are skipped. For every other line the second whitespace-separated
        field is taken and cut at the first '('. The result is passed to subtasks as
        shard_rbtorrent*; presumably it is an rbtorrent link (confirm against the
        shardmap format).
        """
        # Close the resource file deterministically instead of leaking the handle.
        with sdk2.ResourceData(resource).path.open('r') as shardmap_file:
            shard_lines = [line.strip() for line in shardmap_file if line.strip()]
        return [shard_line.split()[1].split('(')[0] for shard_line in shard_lines]

    def start_tasks(self):
        """Start one comparison subtask per (reference, test) shard pair and wait for them.

        Raises:
            SandboxTaskFailureError: if the shardmaps list a different number of
                shards or are empty.
            sdk2.WaitTask: after enqueuing the subtasks, to wait until all of them
                are finished or broken.
        """
        shards1 = self.parse_shardmap(self.Parameters.shardmap1)
        shards2 = self.parse_shardmap(self.Parameters.shardmap2)
        if len(shards1) != len(shards2):
            raise SandboxTaskFailureError("Different number of shards: %s and %s" % (len(shards1), len(shards2)))
        if not shards1:
            raise SandboxTaskFailureError("Empty shardmap")
        self.Context.shards = len(shards1)

        # Arguments shared by every subtask (they do not depend on the shard pair):
        # the common basesearch parameters duplicated for both sides, plus
        # dolbilka/lunapark/offline settings. Built once, outside the loop.
        common_args = {}
        for key in type(self).base_params.iterkeys():
            value = getattr(self.Parameters, key)
            common_args['%s1' % key] = value
            common_args['%s2' % key] = value
        common_args.update(self.Parameters.dolbilka_param)
        common_args.update(self.Parameters.lunapark_param)
        common_args.update(self.Parameters.offline_param)

        tasks = []
        for shard1, shard2 in zip(shards1, shards2):
            self.logger.info("Create task %s %s", shard1, shard2)
            task = AddrsBasesearchPerformanceParallel(
                self,
                description="Subtask for #%s (%s)" % (self.id, self.Parameters.description),
                create_sub_task=True,
                dolbilo_plan_resource_id=self.Parameters.dolbilo_plan_resource_id,
                shard_rbtorrent1=shard1,
                shard_rbtorrent2=shard2,
                custom_plan_id1=self.Parameters.custom_plan_id1,
                custom_plan_id2=self.Parameters.custom_plan_id2,
                geobasesearch_executable_resource_id1=self.Parameters.binary1,
                geobasesearch_executable_resource_id2=self.Parameters.binary2,
                **common_args)
            task.enqueue()
            tasks.append(task)
        raise sdk2.WaitTask(tasks, ctt.Status.Group.FINISH + ctt.Status.Group.BREAK, wait_all=True)

    def aggregate_shard_stats(self, shard_stats):
        """Collapse the per-session stats of one shard run into a single stats dict.

        Args:
            shard_stats: list of per-session dicts with 'rps', 'latency_0.95',
                'latency_0.99', 'rss' and 'vsz' keys.

        Returns:
            dict with 'rps' (mean of the upper half of the session RPS values),
            'latency_95' / 'latency_99' (minimum over sessions) and
            'rss' / 'vsz' (maximum over sessions).

        Raises:
            SandboxTaskFailureError: if shard_stats is empty.
        """
        if not shard_stats:
            raise SandboxTaskFailureError('No session stats for shard')
        # Average top half to avoid down outliers and average rps jumps over time.
        # Floor division keeps the slice index an int on both Python 2 and 3.
        sorted_rps = sorted(session_stats['rps'] for session_stats in shard_stats)
        top_half = sorted_rps[len(sorted_rps) // 2:]
        stats = {}
        stats['rps'] = sum(top_half) / float(len(top_half))
        stats['latency_95'] = min(session_stats['latency_0.95'] for session_stats in shard_stats)
        stats['latency_99'] = min(session_stats['latency_0.99'] for session_stats in shard_stats)
        stats['rss'] = max(session_stats['rss'] for session_stats in shard_stats)
        stats['vsz'] = max(session_stats['vsz'] for session_stats in shard_stats)
        return stats

    def aggregate_shard_all_stats(self, all_stats):
        """Apply aggregate_shard_stats to every side ('1', '2', ...) of every subtask's stats."""
        return [
            {key: self.aggregate_shard_stats(value) for key, value in stat.iteritems()}
            for stat in all_stats
        ]

    def aggregate_stats(self, all_stats_agg, key):
        """Combine the per-shard aggregates of side `key` ('1' reference, '2' test).

        RPS is averaged over shards, latencies take the minimum and memory
        (rss/vsz) takes the maximum.
        """
        per_shard = [stat[key] for stat in all_stats_agg]
        stats = {}
        stats['rps'] = sum(shard['rps'] for shard in per_shard) / float(len(per_shard))
        stats['latency_95'] = min(shard['latency_95'] for shard in per_shard)
        stats['latency_99'] = min(shard['latency_99'] for shard in per_shard)
        stats['rss'] = max(shard['rss'] for shard in per_shard)
        stats['vsz'] = max(shard['vsz'] for shard in per_shard)
        return stats

    def compare_stats(self, reference_stats, test_stats):
        """Build a per-metric comparison of the test stats against the reference stats.

        Returns:
            list of (metric, {'reference', 'test', 'difference', 'percent'}) pairs,
            where 'percent' is test relative to reference minus 100.
        """
        result = []
        for key in ('rps', 'latency_95', 'latency_99', 'rss', 'vsz'):
            reference = reference_stats.get(key)
            test = test_stats.get(key)
            result.append((key, {'reference': reference,
                                 'test': test,
                                 'difference': float(test) - float(reference),
                                 'percent': float(test) * 100 / float(reference) - 100}))
        return result

    def push_to_yt(self, stats):
        """Push aggregated test stats to YT (hahn) via the geosearch `stat` helper."""
        from sandbox.projects.geosearch.tools import stat
        import yt.wrapper as yt
        yt_config = {'proxy': {'url': 'hahn.yt.yandex.net'},
                     'token': sdk2.Vault.data('GEOMETA-SEARCH', 'yt-token')}
        client = yt.YtClient(config=yt_config)
        # NOTE(review): relies on the private `stat._add` helper.
        stat._add(
            {
                'dumper.rps': stats['rps'],
                'latency_0.95': stats['latency_95'],
                'latency_0.99': stats['latency_99'],
                'memory_rss': stats['rss'],
                'memory_vsz': stats['vsz']
            }, 'perf', client)

    def get_startrek_token(self):
        """Return the Startrek OAuth token stored in Vault."""
        return sdk2.Vault.data('robot-geosearch', 'robot_geosearch_startrek_token')

    def get_memory_warning(self, stats):
        """Return a Startrek warning if measured RSS exceeds 98% of the smallest YP pod memory limit.

        Returns None when no warning is needed or no pods are found.
        """
        nanny_token = sdk2.Vault.data('GEOMETA-SEARCH', 'nanny_token')
        yp_api = yp_lite.YpLightAPI(nanny_token)
        pods = yp_api.get_pods('addrs_base')
        logging.debug('addrs_base pods: {}'.format(pods))
        mem_limits = [pod['spec']['resourceRequests']['memoryLimit'] for pod in pods]
        if mem_limits:
            logging.info('Memory limits on YP pods: {}'.format(mem_limits))
            mem_limit = min(mem_limits)
            # NOTE(review): presumably converts RSS from KB to the unit of
            # memoryLimit (bytes?) — confirm whether the factor should be 1024.
            memory_rss = stats['rss'] * 1000
            limit_perc = float(memory_rss) / float(mem_limit) * 100
            logging.info('YP lite limit = %s' % limit_perc)
            if limit_perc > 98:
                # Warning text (Russian): "Attention! Performance tests used
                # <N> % of the YP memory limit."
                return (u'!!(red)Внимание! В тестах производительности '
                        u'использовано %d %% памяти от лимита в '
                        u'YP!!') % limit_perc
        return None

    def make_report(self):
        """Render the Startrek comment body from Context.result."""
        return self._render_result_template("startrek_template")

    def update_startrek_ticket_RM(self, stats):
        """Post the report (and memory warning, if any) to the release ticket.

        Skipped when no component name or release number is set. The check
        treats both an empty (None) and a zero release number as unset.
        """
        if not self.Parameters.component_name or not self.Parameters.release_number:
            return
        startrek_helper = STHelper(self.get_startrek_token())
        component_info = rmc.COMPONENTS[self.Parameters.component_name]()
        release_num = self.Parameters.release_number
        mem_limit_warning = self.get_memory_warning(stats)
        if mem_limit_warning:
            startrek_helper.comment(release_num, mem_limit_warning, component_info)
        startrek_helper.write_grouped_comment('====Performance test',
                                              'Launch performance test',
                                              self.make_report(),
                                              release_num,
                                              component_info)

    def update_startrek_ticket_DB(self, stats):
        """Post the report (and memory warning, if any) to Parameters.startrek_task, if set."""
        startrek_ticket = self.Parameters.startrek_task
        if not startrek_ticket:
            return
        startrek_client = StartrekClient(self.get_startrek_token())
        mem_limit_warning = self.get_memory_warning(stats)
        if mem_limit_warning:
            startrek_client.add_comment(startrek_ticket, mem_limit_warning)
        # Report wrapped in `<{...}>` markup (presumably a collapsible Startrek block).
        startrek_client.add_comment(startrek_ticket, '<{%s}>' % self.make_report())

    def on_execute(self):
        """Start subtasks on the first run, then aggregate, compare and report their stats.

        Raises:
            SandboxTaskFailureError: if not every shard subtask produced stats.
        """
        self.logger = logging.getLogger('addrs')
        if self.Context.shards == 0:
            # Raises sdk2.WaitTask; on the next execution Context.shards is set,
            # so this branch is skipped.
            self.start_tasks()
        self.logger.info('Completed')
        # Query subtasks once; the same list is used for logging and stats collection.
        subtasks = list(self.find())
        self.logger.info(subtasks)
        all_stats = [task.Context.stats for task in subtasks if task.Context.stats]
        self.logger.info(all_stats)
        if len(all_stats) != self.Context.shards:
            raise SandboxTaskFailureError('Subtask failure')
        all_stats_agg = self.aggregate_shard_all_stats(all_stats)
        stats1 = self.aggregate_stats(all_stats_agg, '1')
        stats2 = self.aggregate_stats(all_stats_agg, '2')
        if self.Parameters.push_data:
            self.push_to_yt(stats2)
        self.Context.result = self.compare_stats(stats1, stats2)
        if stats2['rps'] < (1.0 - self.RPS_DEGRADATION_PERCENT * 0.01) * stats1['rps']:
            self.Context.has_diff = True
        if self.Parameters.launch_type == 'RM':
            self.update_startrek_ticket_RM(stats2)
        elif self.Parameters.launch_type == 'DB':
            self.update_startrek_ticket_DB(stats2)
