from __future__ import division

import logging
import os
import json
import shutil
import time
from collections import namedtuple

from sandbox.common.types import client as ctc
from sandbox.common.types import misc as ctm
from sandbox.common.types import task as ctt
from sandbox import sdk2
from sandbox.common.errors import TaskFailure, TaskError
from sandbox.common.fs import (
    get_unique_file_name,
    make_folder,
)
from sandbox.sandboxsdk.environments import PipEnvironment

from sandbox.projects.common.yabs.server.util.general import CustomAssert

from sandbox.projects.yabs.base_bin_task import BaseBinTask
from sandbox.projects.yabs.qa.errorbooster.decorators import track_errors
from sandbox.projects.yabs.qa.resource_types import (
    BaseBackupSdk2Resource,
    YABS_SERVER_TESTENV_DB_FLAGS,
    YabsServerPerformanceShootData
)
from sandbox.projects.yabs.qa.sut.metastat.adapters.sandbox import YabsStatSandboxAdapter
from sandbox.projects.yabs.qa.sut.metastat.adapters.sandbox.parameters import YabsSUTParameters
from sandbox.projects.yabs.qa.ammo_module.dplan.adapters.sandbox import AmmoDplanModuleSandboxAdapter
from sandbox.projects.yabs.qa.ammo_module.requestlog.adapters.yabs_stat.sandbox import AmmoRequestlogModuleYabsStatSandboxAdapter
from sandbox.projects.yabs.qa.ammo_module.requestlog.adapters.yabs_specific.sandbox import AmmoRequestlogModuleYabsStatPerformanceSandboxAdapter
from sandbox.projects.yabs.qa.dolbilo_module.simple.adapters.sandbox import DolbiloModuleSandboxAdapter
from sandbox.projects.yabs.qa.dumper_module.adapters.sandbox import DumperModuleSandboxAdapter
from sandbox.projects.yabs.qa.perf_module.adapters.sandbox import PerfModuleSandboxAdapter
from sandbox.projects.yabs.qa.performance.stats import dumper_timeline
from sandbox.projects.yabs.qa.sut.solomon_stats import dump_solomon_stats
from sandbox.projects.yabs.qa.tasks.duration_measure_task import BaseDurarionMeasureTask, BaseDurarionMeasureTaskParameters
from sandbox.projects.yabs.qa.utils import Contextable
from sandbox.projects.yabs.qa.utils.base import get_bin_bases_unpacked_size, get_max_unpacking_workers


logger = logging.getLogger(__name__)

# Multiplier turning the GB-denominated size parameters into MB.
MB_IN_GB = 1024
# Local working directory collecting per-session shoot artifacts; it is zipped
# into a YabsServerPerformanceShootData resource after all sessions finish.
SHOOT_DATA_FOLDER = 'perf_shoot_data_folder'


class PerformancePipelineResults(
    namedtuple(
        'PerformancePipelineResults',
        'rps rps_corrected usage_data_meta usage_data_stat '
        'perf_resource_id rps_list rps_corrected_list'
    ),
    Contextable
):
    '''Aggregated outcome of one run of the stat performance pipeline.

    Field order matters for positional construction. The instance is copied
    into the task Context via ``__to_context__``, which presumably comes from
    the ``Contextable`` mixin (confirm in sandbox.projects.yabs.qa.utils).
    '''


# Outcome of a single stat shoot session: dumper-processed statistics, the perf
# record path (None when perf collection is disabled) and the stat usage data.
StatShootResults = namedtuple(
    'StatShootResults',
    ('processed_stats', 'perf_record_path', 'usage_data')
)


class EmptyContext(object):
    '''No-op context manager.

    Entering yields ``None`` and exiting never suppresses exceptions. It is
    used instead of perf data collection when perf is disabled.
    '''

    def __enter__(self):
        return

    def __exit__(self, *exc_info):
        # A falsy return lets any exception from the body propagate.
        return None


class FlameResource(BaseBackupSdk2Resource):
    # Resource type for the perf flamegraph output. run_performance_pipeline
    # creates it from perf_module.generate_perf_results(...) for the session
    # with the highest RPS when run_perf is enabled.
    # NOTE(review): no docstring on purpose; on sdk2 resource classes the
    # docstring presumably acts as the resource-type description.
    pass


class AbandonHostException(TaskError):
    '''Dedicated TaskError subtype, presumably for abandoning an unsuitable host.

    It is not raised in the code visible here; confirm its meaning at the
    raise sites.
    '''


class StatLoadShootParameters(sdk2.Parameters):
    # Shoot settings for the stat server. run_performance_pipeline copies them
    # onto the dolbilo shoot module's parameters before shooting.
    # Number of shoot sessions; the best session by RPS is reported.
    stat_shoot_sessions = sdk2.parameters.Integer('Stat shoot sessions', default_value=3)
    # Copied to shoot_module.adapter.parameters.shoot_threads.
    stat_shoot_threads = sdk2.parameters.Integer('Stat shoot threads', default_value=480)
    # Copied to shoot_module.adapter.parameters.shoot_request_limit.
    stat_shoot_request_limit = sdk2.parameters.Integer('Stat shoot request limit', default_value=500000)
    # Copied to shoot_module.adapter.parameters.circular_session.
    stat_circular_session = sdk2.parameters.Bool('Circular session (will loop back if the ammo is exhausted until limit is reached)', default_value=True)
    # Passed as store_dump= to shoot_module.shoot_and_watch.
    store_dumps = sdk2.parameters.Bool('Store shoot dumps', default=False)


class YabsServerStatPerformanceParameters(BaseBinTask.Parameters):
    # Parameter set of the stat performance task. YabsServerStatPerformance2.Parameters
    # extends it. Declaration order defines the order shown in the task UI.
    kill_timeout = 60 * 60 * 5  # 5 hours, assuming seconds (sdk2 convention)
    # Binary-bundle selection inherited from BaseBinTask. Presumably it
    # auto-searches a STABLE 'yabs_server_stat_load' bundle; confirm in BaseBinTask.
    auto_search = BaseBinTask.Parameters.auto_search(default=True)
    resource_attrs = BaseBinTask.Parameters.resource_attrs(default={'task_bundle': 'yabs_server_stat_load'})
    release_version = BaseBinTask.Parameters.release_version(default=ctt.ReleaseStatus.STABLE)

    with sdk2.parameters.Group('Yabs-server module settings') as yabs_server_module_settings:
        server_module_parameters = YabsSUTParameters()
        stat_module_parameters = YabsStatSandboxAdapter.get_init_parameters_class()()
    with sdk2.parameters.Group('Ammo generation module settings') as ammo_module_settings:
        # Only one of the two ammo parameter sets is shown, depending on use_requestlog.
        use_requestlog = sdk2.parameters.Bool('Generate ammo from requestlog instead of dplan', description='Requestlog allows to update requests', default=False)
        with use_requestlog.value[True]:
            requestlog_ammo_module_parameters = AmmoRequestlogModuleYabsStatPerformanceSandboxAdapter.get_init_parameters_class()()
        with use_requestlog.value[False]:
            dplan_ammo_module_parameters = AmmoDplanModuleSandboxAdapter.get_init_parameters_class()()
    with sdk2.parameters.Group('Stat ammo generation module settings') as stat_requestlog_module_settings:
        # NOTE(review): unlike every other *_parameters entry in this class, this
        # assigns the parameters class itself (no trailing `()`). It is therefore
        # probably not embedded as a parameter set, and this group may be empty.
        # Confirm the intent before adding `()`: its parameter names may collide
        # with requestlog_ammo_module_parameters.
        stat_requestlog_parameters = AmmoRequestlogModuleYabsStatSandboxAdapter.get_init_parameters_class()
    with sdk2.parameters.Group('Shoot module settings') as shoot_module_settings:
        shoot_module_parameters = DolbiloModuleSandboxAdapter.get_init_parameters_class()()
        stat_load_shoot_parameters = StatLoadShootParameters()
    with sdk2.parameters.Group('Dumper module settings') as dumper_module_settings:
        dumper_module = DumperModuleSandboxAdapter.get_init_parameters_class()()
    with sdk2.parameters.Group('General settings') as general_settings:
        update_parameters_resource = sdk2.parameters.Resource('Resource with JSON-dumped parameter update dict', resource_type=YABS_SERVER_TESTENV_DB_FLAGS)
        # Hard-coded default resource id for the RAM warmup binary.
        ram_warmup_binary_resource = sdk2.parameters.Resource('RAM warmup binary', default_value=162057058)
        # When enabled, perf data is collected for the stat server and a
        # FlameResource is produced from the best session.
        run_perf = sdk2.parameters.Bool('Run perf', default_value=False)
        with run_perf.value[True]:
            perf_module_parameters = PerfModuleSandboxAdapter.get_init_parameters_class()()
    with sdk2.parameters.Group('Requirements settings (use these instead of requirements tab!)') as requirements_settings:
        # All three sizes are in GB. on_save converts them into RAM, ramdrive
        # and disk requirements via MB_IN_GB.
        shard_space = sdk2.parameters.Integer('Binary base space required for single shard (will account to either disk or ramdrive requirement), GB', default_value=200)
        generic_disk_space = sdk2.parameters.Integer('Generic disk space, GB', default_value=140)
        ram_space = sdk2.parameters.Integer('Ram space, GB', default_value=240)

    with sdk2.parameters.Group('Ammo preparation settings') as ammo_preparation_settings:
        # When True, validate_parameters requires stat_store_request_log.
        # stat_dplan_resource is read unconditionally by run_performance_pipeline.
        prepare_stat_dplan = sdk2.parameters.Bool('Prepare dplan for stat shoot sessions in this task', default=True)
        with prepare_stat_dplan.value[False]:
            stat_dplan_resource = sdk2.parameters.Resource('Previously prepared request log resource')
            # Reported as the single element of usage_data_meta in the results.
            prepared_usage_data_meta = sdk2.parameters.JSON('Data usage info for meta while preparing stat requests')

    duration_parameters = BaseDurarionMeasureTaskParameters()


class YabsServerStatPerformance2(BaseDurarionMeasureTask, BaseBinTask):

    '''
    New task for load b2b tests of yabs-server service.
    '''

    # Shoots a single stat shard for several sessions, then stores the best RPS,
    # per-session RPS, usage data and an optional perf flamegraph in Context.

    class Parameters(YabsServerStatPerformanceParameters):
        need_host_correction = sdk2.parameters.Bool('Need host correction', default_value=False)
        need_ram_check = sdk2.parameters.Bool('Need additional RAM check', default_value=False)

    class Requirements(sdk2.Task.Requirements):
        client_tags = ctc.Tag.GENERIC & ctc.Tag.INTEL_E5_2650 & ctc.Tag.LINUX_XENIAL & ctc.Tag.LXC
        environments = (
            PipEnvironment('yandex-yt'),
            PipEnvironment('pandas'),
            PipEnvironment('plotly'),  # imported lazily by render_rps_graph
        )

    def on_save(self):
        '''Translate the GB-based requirement parameters into task requirements.

        RAM always comes from ``ram_space``. With ``use_tmpfs`` the binary bases
        go to a tmpfs ramdrive of size
        ``min(shard_space * shards, unpacked bases size + 10 GB)``, and the disk
        only needs ``generic_disk_space``. Without it, the per-shard space is
        added to the disk requirement instead. The number of base-unpacking
        workers is then stored in ``Context.unpacking_workers``.
        '''
        super(YabsServerStatPerformance2, self).on_save()
        self.Requirements.ram = self.Parameters.ram_space * MB_IN_GB
        shard_count = len(self.Parameters.stat_shards)
        bin_bases_unpacked_size = []
        if self.Parameters.use_tmpfs:
            bin_bases_unpacked_size = get_bin_bases_unpacked_size(self.Parameters.stat_binary_base_resources)
            self.Context.ramdrive_size = min((self.Parameters.shard_space * shard_count) * MB_IN_GB, sum(bin_bases_unpacked_size) + 10 * MB_IN_GB)
            self.Requirements.ramdrive = ctm.RamDrive(
                ctm.RamDriveType.TMPFS,
                self.Context.ramdrive_size,
                None
            )
            self.Requirements.disk_space = self.Parameters.generic_disk_space * MB_IN_GB
        else:
            # NOTE(review): Context.ramdrive_size is not reset in this branch.
            # It presumably keeps a value from an earlier save with use_tmpfs
            # enabled (or its unset default). Confirm get_max_unpacking_workers
            # copes with that.
            self.Requirements.ramdrive = None
            self.Requirements.disk_space = (self.Parameters.generic_disk_space + (self.Parameters.shard_space * shard_count)) * MB_IN_GB
        self.Context.unpacking_workers = get_max_unpacking_workers(bin_bases_unpacked_size, self.Requirements.ram, self.Context.ramdrive_size)

    class Context(sdk2.Task.Context):
        # Slot bookkeeping; not read by the code in this class.
        hosts_slots_count = 0
        hosts_slot_index = 0

        # Presumably filled from PerformancePipelineResults.__to_context__ in on_execute.
        rps = 0
        rps_corrected = 0
        rps_list = []
        rps_corrected_list = []
        usage_data_meta = []
        usage_data_stat = []
        perf_resource_id = None

    def validate_parameters(self, parameters):
        '''Check the parameter combinations this task supports.

        Exactly one stat shard is required. When the stat dplan is prepared in
        this task, ``stat_store_request_log`` must be enabled. Violations go
        through CustomAssert with TaskFailure (presumably raising it).
        '''
        CustomAssert(len(parameters.stat_shards) == 1, 'Wrong value for stat_shards parameter (must be a list of length 1 for this task)', TaskFailure)
        if parameters.prepare_stat_dplan:
            CustomAssert(parameters.stat_store_request_log, 'Need stat_store_request_log set to True for this task', TaskFailure)

    def run_performance_pipeline(self, parameters, shared_base_state=None, dump_to_yt_key='', index_2on1=1):
        '''Run ``stat_shoot_sessions`` shoot sessions against a single stat shard.

        Each session does the following:

        - starts the stat server together with its cachedaemon;
        - shoots it with the dplan from ``stat_dplan_resource``;
        - dumps the aggregated access log as JSON and the Solomon stats into
          SHOOT_DATA_FOLDER/<session>.

        The folder is then published as a zipped YabsServerPerformanceShootData
        resource. The session with the highest RPS defines the reported RPS and,
        with ``run_perf``, the FlameResource.

        :param parameters: task parameters (``self.Parameters`` in on_execute).
        :param shared_base_state: passed through to the stat server module.
        :param dump_to_yt_key: not used by this implementation.
        :param index_2on1: suffix for work dirs, archive, resource attribute and
            graph file, presumably to tell apart several pipeline runs in one task.
        :returns: ``(PerformancePipelineResults, base_provider)``, where
            ``base_provider`` comes from the last session (None if none ran).
        '''
        self.validate_parameters(parameters)

        shoot_module = DolbiloModuleSandboxAdapter(parameters, self).create_module()

        stat_index = parameters.stat_shards[0]
        stat_dplan_path = str(sdk2.ResourceData(parameters.stat_dplan_resource).path)

        dumper_module = DumperModuleSandboxAdapter(parameters, self).create_module()

        # TODO: Create separate shoot module for better control
        shoot_module.adapter.parameters.shoot_request_limit = parameters.stat_shoot_request_limit
        shoot_module.adapter.parameters.circular_session = parameters.stat_circular_session
        shoot_module.adapter.parameters.shoot_threads = parameters.stat_shoot_threads

        perf_module = None
        if parameters.run_perf:
            perf_module = PerfModuleSandboxAdapter(parameters, self).create_module()
        stat_shoot_results = []
        stat_shoot_rps_graphs = {}

        # The stat server is configured from the shoot module's parameters
        # object, with request-log storing switched off for these runs.
        stat_module_parameters = shoot_module.adapter.parameters
        stat_module_parameters.stat_store_request_log = False

        with self.stage_duration("create_modules"):
            stat_server_adapter = YabsStatSandboxAdapter(
                stat_module_parameters,
                task_instance=self,
                work_dir="stat_server_adapter_{}".format(index_2on1),
            )
            cachedaemon = stat_server_adapter.create_cachedaemon()
            stat_server_module = stat_server_adapter.create_module(
                cachedaemon=cachedaemon,
                shard_no=stat_index,
                shared_base_state=shared_base_state,
                maximum_keys_per_tag=1,
                use_sandbox_config=False
            )
        base_provider = None

        aggregate_access_log_time = 0
        make_folder(SHOOT_DATA_FOLDER, delete_content=True)
        for i in range(parameters.stat_shoot_sessions):
            logger.debug('Start of session %s', i)
            session_dir = os.path.join(SHOOT_DATA_FOLDER, str(i))
            make_folder(session_dir)

            solomon_stats_dir = get_unique_file_name(stat_server_adapter.get_logs_dir(), "solomon_stats")
            make_folder(solomon_stats_dir)

            with cachedaemon, stat_server_module as stat, stat.get_usage_data() as usage_data_stat:
                with (perf_module.collect_perf_data(stat.get_server_backend_object().process.pid) if parameters.run_perf else EmptyContext()) as perf_record_path:
                    dump_path = shoot_module.shoot_and_watch(stat, stat_dplan_path, store_dump=parameters.store_dumps)
                    with open(os.path.join(session_dir, "access-log.json"), 'w+') as f:
                        start_parsing_time = int(time.time())
                        json.dump(stat.get_server_backend_object().aggregate_access_log(ext_sharded=None).as_dict(), f)
                        # Rename "<name>.log" to "<name><i>.log" (assumes a 4-char
                        # ".log" suffix), presumably so the next session keeps its own log.
                        access_log_path = stat.get_server_backend_object().get_phantom_log_path(stat.get_server_backend_object().LOG_ACCESS)
                        os.rename(access_log_path, access_log_path[:-4] + str(i) + '.log')
                        finish_parsing_time = int(time.time())
                        aggregate_access_log_time += finish_parsing_time - start_parsing_time
                dump_solomon_stats(stat.get_server_backend_object(), solomon_stats_dir)
                shutil.copytree(solomon_stats_dir, os.path.join(session_dir, "solomon_stats"))

            base_provider = stat.base_provider
            timeline_data = dumper_module.get_timeline(dump_path)
            graph_data = dumper_timeline.generate_graph_data(timeline_data)
            stat_shoot_rps_graphs[i] = graph_data

            stat_shoot_results.append(
                StatShootResults(
                    processed_stats=dumper_module.get_processed_stats(dump_path),
                    perf_record_path=perf_record_path,
                    usage_data=usage_data_stat
                )
            )
            logger.debug('End of session %s', i)
        shutil.make_archive(SHOOT_DATA_FOLDER + str(index_2on1), 'zip', SHOOT_DATA_FOLDER)
        sdk2.ResourceData(YabsServerPerformanceShootData(
            self,
            'Folder with stats performance shoot data',
            SHOOT_DATA_FOLDER + str(index_2on1) + '.zip',
            server_role='stat',
            index_2on1=index_2on1
        )).ready()
        self.set_info('Aggregate access log time: {}s'.format(aggregate_access_log_time))

        if stat_shoot_results:
            max_shoot_results = max(stat_shoot_results, key=lambda results: float(results.processed_stats['rps']))
            if parameters.run_perf:
                perf_results_path = perf_module.generate_perf_results(max_shoot_results.perf_record_path)
                perf_resource_id = FlameResource(
                    self,
                    'Flamegraph resource',
                    perf_results_path
                ).id
            else:
                perf_resource_id = None
            rps_corrected = rps = float(max_shoot_results.processed_stats['rps'])
            rps_list = [float(shoot_result.processed_stats['rps']) for shoot_result in stat_shoot_results]
            # Use a separate list object (as the no-sessions branch below does),
            # so an in-place correction of rps_corrected_list cannot alter rps_list.
            rps_corrected_list = list(rps_list)
        else:
            logger.info('Did not run shoot sessions, defaulting RPS to 0')
            rps = 0
            rps_corrected = 0
            perf_resource_id = None
            rps_list = []
            rps_corrected_list = []

        if stat_shoot_rps_graphs:
            self.render_rps_graph(stat_shoot_rps_graphs, index_2on1)

        return PerformancePipelineResults(
            usage_data_meta=[parameters.prepared_usage_data_meta],
            usage_data_stat=[result.usage_data for result in stat_shoot_results],
            rps=rps,
            rps_corrected=rps_corrected,
            perf_resource_id=perf_resource_id,
            rps_list=rps_list,
            rps_corrected_list=rps_corrected_list,
        ), base_provider

    def render_rps_graph(self, graph_data, index_2on1):
        '''Write an interactive plotly chart of the per-session RPS timelines.

        :param graph_data: mapping of session number to a mapping of line type
            to a sequence of ``(x, y)`` points, as built from
            ``dumper_timeline.generate_graph_data``.
        :param index_2on1: suffix of ``rps_graph_<index_2on1>.html``, which is
            written into the task log directory.
        '''
        import plotly.graph_objects as go

        fig = go.Figure()
        for session_number, lines in graph_data.items():
            for line_type, points in lines.items():
                fig.add_trace(
                    go.Scatter(
                        x=[p[0] for p in points],
                        y=[p[1] for p in points],
                        name='{}_session_{}'.format(line_type, session_number),
                    )
                )

        fig.write_html(os.path.join(str(self.log_path()), 'rps_graph_{}.html'.format(index_2on1)))

    @track_errors
    def on_execute(self):
        '''Run the pipeline with the task parameters and copy the results into Context.'''
        performance_pipeline_results, _ = self.run_performance_pipeline(self.Parameters, dump_to_yt_key='')
        performance_pipeline_results.__to_context__(self.Context)
