# coding=utf-8

import os
import time
import json
import math
import logging
from datetime import datetime
from collections import defaultdict

import requests

from sandbox import sdk2
from sandbox.sdk2.os import enable_turbo_boost, disable_turbo_boost
from sandbox.common.errors import TaskFailure
from sandbox.common.config import Registry

from sandbox.projects.yabs.qa.ammo_module.dplan.adapters.sandbox import AmmoDplanModuleSandboxAdapter
from sandbox.projects.yabs.qa.dolbilo_module.simple import DolbiloModule
from sandbox.projects.yabs.qa.dolbilo_module.simple.adapters.sandbox import DolbiloModuleSandboxAdapter
from sandbox.projects.yabs.qa.dumper_module.adapters.sandbox import DumperModuleSandboxAdapter
from sandbox.projects.yabs.qa.mutable_parameters import MutableParameters

from sandbox.projects.antiadblock.qa.modules.configs_stub_module.adapters.sandbox import ConfigsStubSandboxAdapter
from sandbox.projects.antiadblock.qa.modules.engine_module.adapters.sandbox import EngineSandboxAdapter
from sandbox.projects.antiadblock.qa.resource_types import ANTIADBLOCK_PERFORMANCE_DIFF
from sandbox.projects.antiadblock.qa.tasks.one_performance_shoot import AntiadblockOnePerformanceShoot
from sandbox.projects.antiadblock.qa.tasks.parameters import BaseDiffShootParameters
from sandbox.projects.antiadblock.qa.utils.constants import DiffType, FILE_NGINX_ACCESS_LOG_TMPL, REPORT_KEYS, \
    REPORT_KEYS_STAT, SECOND_PREFIX, STAT_REPORT_PATH, PERCENTILES

from sandbox.projects.antiadblock.utils import ROBOT_ANTIADB_TOKENS_YAV_ID


logger = logging.getLogger("run_task_log")


class DolbiloModuleTmpfs(DolbiloModule):
    """DolbiloModule that places its output dump on the task ramdrive when one is attached."""

    def output_dump_path(self, out_dir=''):
        """Return the parent's dump path, re-rooted under the ramdrive path if the task has a ramdrive."""
        path = super(DolbiloModuleTmpfs, self).output_dump_path(out_dir)
        ramdrive = self.adapter.task_instance.ramdrive
        if not ramdrive:
            return path
        return os.path.join(str(ramdrive.path), path)


class DolbiloModuleTmpfsSandboxAdapter(DolbiloModuleSandboxAdapter):
    """Sandbox adapter whose create_module builds a DolbiloModuleTmpfs instead of the stock DolbiloModule."""

    def create_module(self):
        """Return a DolbiloModuleTmpfs bound to this adapter."""
        module = DolbiloModuleTmpfs(self)
        return module


def make_prefix(i, session):
    """Build the per-session prefix for shoot ``i``, e.g. ``make_prefix(0, 3) == '0_session_3'``."""
    return '%s_session_%s' % (i, session)


def get_percentiles(timelines, quantilies):
    """Yield ``(quan, value_ms)`` pairs for each requested percentile of ``timelines``.

    Sorts ``timelines`` in place, so the caller's item order is lost.

    :param timelines: list of request durations; values are multiplied by 1000 to get
        milliseconds, so presumably they are in seconds.
    :param quantilies: iterable of percentiles in the ``[0, 100]`` range.
    :return: generator of ``(quan, timelines[index] * 1000)`` tuples, one per quantile.
        Yields nothing for an empty ``timelines``.
    """
    if not timelines:
        return
    timelines.sort()
    last_index = len(timelines) - 1
    for quan in quantilies:
        # Clamp to the last valid index: quan == 100 would otherwise point one past the end.
        index = min(int(len(timelines) * quan / 100), last_index)
        yield quan, timelines[index] * 1000  # in milliseconds


def parse_nginx_access_log(log_path, quantilies, logger):
    """Aggregate request counters and latency percentiles from a JSON nginx access log.

    Every line of the log is expected to be a JSON object containing the
    ``response_code`` and ``request_time`` keys. A record with an ``ar_host`` key is
    counted under ``accelredirect``; otherwise a record with a ``bamboozled`` key is
    counted under ``bamboozled``; otherwise under ``cryprox``. The ``all`` key sums
    every other key of ``REPORT_KEYS``.

    Records that cannot be parsed or lack a required field are counted only in
    ``result['all']['count_parse_errors']`` and do not contribute to any other counter.

    :param log_path: path to the access log file.
    :param quantilies: percentiles to compute timings for, or ``None`` to skip them.
    :param logger: logger used for progress messages and per-record parse warnings.
    :return: ``{key: {'count_req': int, 'count_server_errors': int, 'count_client_errors': int,
        'percentiles': {quan: value_ms}}}`` for every key in ``REPORT_KEYS``; ``result['all']``
        additionally holds ``'count_parse_errors'``.
    """
    logger.info("Check access log {}".format(log_path))
    result = {key: {'count_req': 0, 'count_server_errors': 0, 'count_client_errors': 0, 'percentiles': {}}
              for key in REPORT_KEYS}
    result['all']['count_parse_errors'] = 0
    timelines = {key: [] for key in REPORT_KEYS}
    logger.info("Try parse {}".format(log_path))
    with open(log_path) as fin:
        for i, line in enumerate(fin):
            try:
                record = json.loads(line.strip())
                key = 'cryprox'
                if record.get('ar_host') is not None:
                    key = 'accelredirect'
                elif record.get('bamboozled') is not None:
                    key = 'bamboozled'
                # Validate and classify the whole record before mutating any counter,
                # so a malformed record is counted only as a parse error.
                stats = result[key]
                response_code = record['response_code']
                is_server_error = response_code >= 500
                # TODO: fix after https://st.yandex-team.ru/ANTIADB-2637 (411 is ignored for now)
                is_client_error = not is_server_error and response_code >= 400 and response_code != 411
                request_time = record['request_time']
            except Exception as e:
                result['all']['count_parse_errors'] += 1
                # Format the exception itself: ``e.message`` does not exist on Python 3.
                logger.warning('Error when parsing record #{}: {}'.format(i + 1, e))
                continue
            stats['count_req'] += 1
            if is_server_error:
                stats['count_server_errors'] += 1
            elif is_client_error:
                stats['count_client_errors'] += 1
            timelines[key].append(request_time)

    # Fold every per-source key into the 'all' aggregate.
    for key in REPORT_KEYS:
        if key == 'all':
            continue
        result['all']['count_req'] += result[key]['count_req']
        result['all']['count_server_errors'] += result[key]['count_server_errors']
        result['all']['count_client_errors'] += result[key]['count_client_errors']
        timelines['all'].extend(timelines[key])

    if quantilies is not None:
        for key in REPORT_KEYS:
            if timelines[key]:
                for quan, val in get_percentiles(timelines[key], quantilies):
                    result[key]['percentiles'][quan] = val
    return result


def make_average(diff, keys=None):
    """Collapse per-session values into a mean (± sample standard deviation) per shoot.

    :param diff: ``{key: {section: {shoot: [value_per_session, ...]}}}``.
    :param keys: report keys to process; defaults to ``REPORT_KEYS``.
    :return: ``{key: {section: {shoot: value}}}`` where ``value`` is the string
        ``'avg ± sd'`` (both rounded to 2 digits) when the sample standard deviation
        is positive, otherwise the mean rounded to 2 digits. Shoots whose value list
        is empty are omitted.
    """
    if keys is None:
        keys = REPORT_KEYS
    result = {key: defaultdict(dict) for key in keys}
    for key in keys:
        for section, shoots in diff[key].items():
            for shoot, values in shoots.items():
                if not values:
                    continue
                count = len(values)
                # float() avoids truncating integer division of integer counters on Python 2.
                avg = float(sum(values)) / count
                if count > 1:
                    # Sample standard deviation (Bessel's correction, n - 1).
                    sd = math.sqrt(sum((v - avg) ** 2 for v in values) / (count - 1))
                else:
                    # A single session has no spread; n - 1 would divide by zero.
                    sd = 0
                if sd > 0:
                    result[key][section][shoot] = '{} ± {}'.format(round(avg, 2), round(sd, 2))
                else:
                    result[key][section][shoot] = round(avg, 2)
    return result


def median(data):
    """Return the median of ``data``.

    For an odd number of items the middle item is returned unchanged; for an even
    number the mean of the two middle items is returned as a float. Raises
    IndexError for empty input.
    """
    ordered = sorted(data)
    size = len(ordered)
    mid = (size - 1) // 2
    if size % 2:
        return ordered[mid]
    lower, upper = ordered[mid], ordered[mid + 1]
    return (lower + upper) / 2.0


def check_diff(list_1, list_2, diff_thr, reverse=False):
    """Tell whether the median of ``list_2`` moved from the median of ``list_1`` by more than a relative threshold.

    With ``reverse=False`` this checks for growth: ``med2 - med1 > med1 * diff_thr``
    (used for p99 latency). With ``reverse=True`` it checks for a drop:
    ``med1 - med2 > med1 * diff_thr`` (used for RPS).

    Both lists must have the same length (checked with ``assert``).
    """
    assert len(list_1) == len(list_2)
    baseline = median(list_1)
    candidate = median(list_2)
    if reverse:
        delta = baseline - candidate
    else:
        delta = candidate - baseline
    threshold = baseline * diff_thr
    return delta > threshold


def render_report(diff, rev0, rev1, count_session, rps=None):
    """Render the HTML diff report for the two shoots.

    The output is a summary table followed by a per-session table for each shoot.
    When ``rps`` is None, the summary shows per-shoot averages computed by
    ``make_average``. Otherwise it shows the ``rps`` values for both shoots.

    :param diff: ``{key: {section: {shoot_name: [value_per_session, ...]}}}`` for every
        key in ``REPORT_KEYS``; shoot names are ``'shoot 0'`` and ``'shoot 1'``.
    :param rev0: revision label of the first shoot.
    :param rev1: revision label of the second shoot.
    :param count_session: number of sessions per shoot (columns of the per-session table).
    :param rps: optional ``{shoot_name: values}`` mapping for the RPS summary table.
    :return: the HTML report as a string.
    """
    # Per-session tables, one per shoot.
    per_session = ['<h2>All stats per shoot</h2>']
    header_titles = ['param name'] + ['session {}'.format(n) for n in range(count_session)]
    for shoot_name in ('shoot 0', 'shoot 1'):
        per_session.append('\n<h3>{}</h3>'.format(shoot_name))
        per_session.append('\n<table border=1 width={}px>\n\t<tr>'.format(300 * (count_session + 1)))
        per_session.extend('<th align=center>{}</th>'.format(title) for title in header_titles)
        per_session.append('</tr>')
        for report_key in REPORT_KEYS:
            per_session.append('\n\t<tr><td align=center colspan={} bgcolor=#FBF0DB>{}</td></tr>'.format(
                count_session + 1, report_key.upper()))
            for section, by_shoot in sorted(diff[report_key].items()):
                per_session.append('\n\t<tr><td align=center>{}</td>'.format(section))
                per_session.extend('<td align=center>{}</td>'.format(cell) for cell in by_shoot.get(shoot_name, []))
                per_session.append('</tr>')
        per_session.append('\n</table>')

    # Summary table: RPS values when given, averaged access-log stats otherwise.
    summary = []
    if rps is not None:
        summary.append('<h2>RPS stats for d-executor dump</h2>')
        summary.append('\n<table border=1 width=900px>\n\t<tr>')
        for title in ('shoot 0 ({})'.format(rev0), 'shoot 1 ({})'.format(rev1)):
            summary.append('<th align=center>{}</th>'.format(title))
        summary.append('</tr>')
        summary.append('\n\t<tr><td align=center>{}</td><td align=center>{}</td></tr>'.format(
            rps.get('shoot 0', []), rps.get('shoot 1', [])))
        summary.append('\n</table>')
    else:
        averaged = make_average(diff)
        summary.append('<h2>Average stats from NGINX access log</h2>')
        summary.append('\n<table border=1 width=900px>\n\t<tr>')
        for title in ('param name', 'shoot 0 ({})'.format(rev0), 'shoot 1 ({})'.format(rev1)):
            summary.append('<th align=center>{}</th>'.format(title))
        summary.append('</tr>')
        for report_key in REPORT_KEYS:
            summary.append('\n\t<tr><td align=center colspan=3 bgcolor=#FBF0DB>{}</td></tr>'.format(report_key.upper()))
            for section, by_shoot in sorted(averaged[report_key].items()):
                summary.append(
                    '\n\t<tr><td align=center>{}</td><td align=center>{}</td><td align=center>{}</td></tr>'.format(
                        section, by_shoot.get('shoot 0', ''), by_shoot.get('shoot 1', '')))
        summary.append('\n</table>')

    return ''.join(summary) + '\n' + ''.join(per_session)


class AntiadblockPerformanceShootsDiff(AntiadblockOnePerformanceShoot):
    """Sandbox task that shoots two cryprox engine builds and diffs their performance.

    For the first engine parameters and then the generated second engine parameters,
    the task runs ``count_shoot_sessions`` constant-RPS shoot sessions. It then parses
    the nginx access logs of every session, renders an HTML diff report and sets
    ``has_diff`` when the median p99 latency grew by more than ``perf_diff_thr``
    (relative to the first engine).
    """
    name = 'ANTIADBLOCK_PERFORMANCE_SHOOTS_DIFF'

    class Context(sdk2.Task.Context):
        has_diff = True

    class Parameters(BaseDiffShootParameters):
        description = 'Antiadblock performance shoot diff task'

        with sdk2.parameters.Group('General settings') as general:
            rps = sdk2.parameters.Integer('RPS', default_value=100, required=True)
            shoot_time = sdk2.parameters.Integer('Shoot time in seconds', default_value=60, required=True)
            error_rate_thr = sdk2.parameters.Float('Error rate threshold', default_value=0.003, required=True)
            perf_diff_thr = sdk2.parameters.Float('Performance diff threshold (p99)', default_value=0.02, required=True)
            count_shoot_sessions = sdk2.parameters.Integer('Count shoot sessions', default_value=5,  required=True)
            push_to_stat = sdk2.parameters.Bool('Push perf stat to report', default=False, required=True)
            disable_turbo_boost = sdk2.parameters.Bool('Disable Turbo Boost', default=False, required=True)

        with sdk2.parameters.Output:
            has_diff = sdk2.parameters.Bool('Shoots has diff', default=False)

    def push_perf_stat(self, values, diff_value, revision):
        """Upload performance stats of ``revision`` to the Statface report ``STAT_REPORT_PATH`` (daily scale).

        Does nothing unless the engine test type is LATENCY or CAPACITY and ``values`` is not None.

        :param values: CAPACITY: list of RPS values; LATENCY: ``{KEY: {'pNN': [values...]}}``.
        :param diff_value: CAPACITY: a number; LATENCY: ``{KEY: number}``.
        :param revision: revision reported in every row.
        """
        import statface_client

        if self.engine_service.test_type not in (DiffType.LATENCY, DiffType.CAPACITY) or values is None:
            return

        sandbox_host = Registry().this.fqdn
        date = datetime.now().strftime("%Y-%m-%d 00:00:00")
        diff_type = "\t{}\t".format(self.engine_service.test_type.name)

        if self.engine_service.test_type == DiffType.CAPACITY:
            # values = [1, 2, 3, 4, 5]
            data = [{
                "fielddate": date,
                "diff_type": diff_type,
                "value": median(values),
                "diff_value": diff_value,
                "revision": revision,
                "sandbox_host": sandbox_host,
            }]
        else:
            # values = {
            #     "ALL": {
            #         "p50": [1, 2, 3, 4, 5],
            #         ...
            #         "p99": [1, 2, 3, 4, 5],
            #     },
            #     ...
            # }
            data = []
            for key, percentiles in values.items():
                row = {
                    "fielddate": date,
                    "diff_type": "{}{}\t".format(diff_type, key),
                    "value": median(percentiles["p99"]),
                    "revision": revision,
                    "sandbox_host": sandbox_host,
                    "diff_value": diff_value[key],
                }
                for perc, val in percentiles.items():
                    row[perc] = median(val)
                data.append(row)

        stat_token = sdk2.yav.Secret(ROBOT_ANTIADB_TOKENS_YAV_ID).data()["ANTIADBLOCK_STAT_TOKEN"]
        stat_client = statface_client.StatfaceClient(host=statface_client.STATFACE_PRODUCTION, oauth_token=stat_token)
        report = stat_client.get_report(STAT_REPORT_PATH)
        upload_result = report.upload_data(scale='d', data=data)
        logger.info(upload_result)

    def run_pipeline(self, prefix='0'):
        """Run one shoot session against ``self.engine_service`` and save its resources under ``prefix``."""
        self.dumper_module = DumperModuleSandboxAdapter(self.Parameters.dumper_parameters, self).create_module()
        self.ammo_module = AmmoDplanModuleSandboxAdapter(self.Parameters.ammo_parameters, self).create_module()
        self.shoot_module = DolbiloModuleTmpfsSandboxAdapter(self.Parameters.dolbilo_parameters, self).create_module()
        # setup module for shooting with fixed rps
        self.shoot_module.adapter.parameters.circular_session = True  # reuse bullets
        self.shoot_module.adapter.parameters.mode = 'plan'  # shoot mode

        with self.engine_service as active_service:
            self.warm_up(active_service, logger)
            # main shoot
            # fixed rps and shooting time
            self.shoot_module.adapter.parameters.mode_arg = ['--rps-schedule', 'const({}, {})'.format(self.Parameters.rps,
                                                                                                      self.Parameters.shoot_time)]
            self.shoot_module.shoot(active_service, self.ammo_module.get_dplan_path(), store_dump=False)

        self.save_resources(dump_path=None, prefix=prefix)

    def generate_second_engine_parameters(self):
        """Build MutableParameters from ``second_engine_parameters``.

        Every parameter whose name ends with ``SECOND_PREFIX`` is also set under the
        name with that suffix stripped.
        """
        parameters = MutableParameters.__from_parameters__(self.Parameters.second_engine_parameters)
        for name, value in MutableParameters.__from_parameters__(self.Parameters.second_engine_parameters):
            if name.endswith(SECOND_PREFIX):
                target_name = name[:-len(SECOND_PREFIX)]
                setattr(parameters, target_name, value)
        return parameters

    def make_diff(self, quantilies, rps=None):
        """Parse every session's nginx access log, publish the HTML diff report and compute ``has_diff``.

        The report is written to ``shoots_diff.html``, published as an
        ``ANTIADBLOCK_PERFORMANCE_DIFF`` resource and stored in ``Context.report``.
        Raises ``TaskFailure`` when, in any session, the server or client error count
        exceeds ``error_rate_thr`` times the request count.

        :param quantilies: percentiles to compute from the access logs.
        :param rps: ``{'shoot 0': [...], 'shoot 1': [...]}``; the CAPACITY branch reads it without a None check.
        """
        _diff = {key: defaultdict(lambda: defaultdict(list)) for key in REPORT_KEYS}
        p99 = [[], []]
        errors = []
        engine_logs_path = getattr(self, 'engine_logs_path', '')
        for _ind in (0, 1):
            shoot = 'shoot {}'.format(_ind)
            for session in range(int(self.Parameters.count_shoot_sessions)):
                nginx_access_stat = parse_nginx_access_log(os.path.join(engine_logs_path, FILE_NGINX_ACCESS_LOG_TMPL.format(make_prefix(_ind, session))),
                                                           quantilies, logger)
                for key in REPORT_KEYS:
                    for section, val in nginx_access_stat[key].items():
                        if section == 'percentiles':
                            for quan, v in val.items():
                                # the p99 diff is computed only for the 'all' section
                                if self.engine_service.test_type == DiffType.LATENCY and quan == 99 and key == 'all':
                                    p99[_ind].append(v)
                                _diff[key]['p{} (ms)'.format(quan)][shoot].append(v)
                        else:
                            _diff[key][section][shoot].append(val)

                # errors
                for error in ('count_server_errors', 'count_client_errors'):
                    if nginx_access_stat['all'][error] > float(self.Parameters.error_rate_thr) * nginx_access_stat['all']['count_req']:
                        message = 'Too many {} in our nginx access log ({}, session {}): {}'.format(error, shoot, session, nginx_access_stat['all'][error])
                        errors.append(message)

        rev0 = getattr(self.Parameters.first_engine_parameters.cryprox_package_resource, 'svn_revision', 'unknown')
        rev1 = getattr(self.Parameters.second_engine_parameters.cryprox_package_resource_2, 'svn_revision', 'unknown')
        report = render_report(_diff, rev0, rev1, int(self.Parameters.count_shoot_sessions), rps=rps)

        filename = 'shoots_diff.html'

        with open(filename, mode="w") as fout:
            fout.write(report)
        diff_resource = ANTIADBLOCK_PERFORMANCE_DIFF(self, 'shoots diff', filename)
        sdk2.ResourceData(diff_resource).ready()
        self.Context.report = report
        self.Context.save()

        if errors:
            raise TaskFailure('\n'.join(errors))

        data = None
        diff_value = None
        if self.engine_service.test_type == DiffType.LATENCY:
            # p99 has diff
            self.Context.has_diff = check_diff(p99[0], p99[1], self.Parameters.perf_diff_thr)
            data = {key: defaultdict(list) for key in REPORT_KEYS_STAT}
            diff_value = {}
            for key in REPORT_KEYS_STAT:
                for p in PERCENTILES[:-1]:
                    perc = 'p{}'.format(p)
                    data[key][perc] = _diff[key.lower()]['p{} (ms)'.format(p)]['shoot 1']
                # p99 diff for every stat key (e.g. 'ALL', 'CRYPROX', 'ACCELREDIRECT').
                # NOTE(review): median() raises IndexError if a key had no records in some shoot.
                diff_value[key] = median(_diff[key.lower()]['p99 (ms)']['shoot 1']) - median(_diff[key.lower()]['p99 (ms)']['shoot 0'])
        elif self.engine_service.test_type == DiffType.CAPACITY:
            # rps has diff
            self.Context.has_diff = check_diff(rps.get('shoot 0', []), rps.get('shoot 1', []), self.Parameters.rps_diff_thr, reverse=True)
            data = rps['shoot 1']
            # diff rps
            diff_value = median(rps['shoot 1']) - median(rps['shoot 0'])
        self.Context.save()
        self.Parameters.has_diff = self.Context.has_diff
        # push to stat
        if self.Parameters.push_to_stat and rev0 != rev1:
            self.push_perf_stat(data, diff_value, rev1)

    @sdk2.report(title="Shoots diff")
    def report(self):
        return self.Context.report or "No report discovered in context"

    def on_execute(self):
        """Run all shoot sessions for both engines, then build the diff report.

        When ``disable_turbo_boost`` is set, Turbo Boost is disabled for the run and
        re-enabled in ``finally``, so it is restored even when a shoot fails or
        ``make_diff`` raises ``TaskFailure``.
        """
        if self.Parameters.disable_turbo_boost:
            disable_turbo_boost()
        try:
            stub_dir = str(self.ramdrive.path) if self.ramdrive else ''
            self.configs_stub_module = ConfigsStubSandboxAdapter(self.Parameters.configs_stub_parameters, self).create_module()
            with self.configs_stub_module as active_stub:
                # wait while stub started
                time.sleep(180)
                for i, engine_parameters in enumerate([self.Parameters.first_engine_parameters, self.generate_second_engine_parameters()]):
                    # reload configs
                    self.reload_configs(engine_parameters.replaced_configs, active_stub.get_port(), logger)
                    for session in range(int(self.Parameters.count_shoot_sessions)):
                        logger.info("Shoot {}, session {}".format(i, session))
                        self.engine_service = EngineSandboxAdapter(engine_parameters, self).create_module(stub_dir, DiffType.LATENCY,  active_stub.get_port())
                        self.run_pipeline(prefix=make_prefix(i, session))
                        self.engine_service.clean()
                        # wait when all ports unbinding
                        time.sleep(180)

            logger.info("Make results diff")
            self.make_diff(quantilies=PERCENTILES)
        finally:
            # return Turbo Boost even if the run failed
            if self.Parameters.disable_turbo_boost:
                enable_turbo_boost()

    @staticmethod
    def reload_configs(params, port, logger, timeout=60):
        """Best-effort POST of ``params`` as JSON to ``http://localhost:<port>/``.

        Does nothing unless ``params`` is a non-empty dict. Any failure, including an
        HTTP error status or a timeout, is logged as a warning and not raised.

        :param params: configs to send.
        :param port: local port of the configs stub.
        :param logger: logger for the success/failure messages.
        :param timeout: seconds to wait for the stub; without a timeout ``requests``
            can block indefinitely.
        """
        if params and isinstance(params, dict):
            url = "http://localhost:{}/".format(port)
            try:
                resp = requests.post(url, json=params, timeout=timeout)
                resp.raise_for_status()
                logger.info("Reload configs success.\nUrl: {},\nparams={},\n".format(url, str(params)))
            except Exception as e:
                logger.warning("Reload configs failed.\nUrl: {},\nparams={},\nerror={}\n".format(url, str(params), str(e)))