from sandbox import sdk2
from sandbox.common.types import misc as ctm

from sandbox.projects.yabs.qa.errorbooster.decorators import track_errors
from sandbox.projects.yabs.qa.sut.metastat.adapters.sandbox import YabsStatSandboxAdapter
from sandbox.projects.yabs.qa.sut.metastat.adapters.sandbox.parameters import YabsSUTParameters2On1
from sandbox.projects.yabs.qa.tasks.YabsServerStatPerformance2 import YabsServerStatPerformance2, MB_IN_GB
from sandbox.projects.yabs.qa.mutable_parameters import MutableParameters
from sandbox.projects.yabs.qa.utils.base import get_bin_bases_unpacked_size, get_max_unpacking_workers

import collections
import logging

# One run of the 2-on-1 pipeline:
#   parameters     -- parameter set handed to run_performance_pipeline
#   dump_to_yt_key -- forwarded as ``dump_to_yt_key`` (e.g. 'baseline' / 'testing')
#   second_batch   -- forwarded to ``__to_context__(second_batch=...)``
# The typename must equal the module attribute name: pickle locates the class
# by that name, and a mismatch makes instances unpicklable.
ParametersFor2On1Batch = collections.namedtuple(
    'ParametersFor2On1Batch', ['parameters', 'dump_to_yt_key', 'second_batch'])


class YabsServerStatPerformance2On1_2(YabsServerStatPerformance2):

    """
    New task for load b2b tests of yabs-server service (2 on 1).

    Runs two performance pipelines within a single task:
      * a "baseline" run configured directly from ``self.Parameters``;
      * a "testing" run configured from the same parameters, with every
        ``<name>_2`` value substituted for ``<name>``
        (see ``generate_second_server_parameters``).
    The order of the two runs may be swapped (see ``on_execute``).
    """

    name = 'YABS_SERVER_STAT_PERFORMANCE_2_ON_1_2'

    class Parameters(YabsServerStatPerformance2.Parameters):
        kill_timeout = 6 * 60 * 60  # in seconds

        with sdk2.parameters.Group('Yabs-server module settings') as yabs_server_module_settings:
            server_module_parameters = YabsSUTParameters2On1()
            stat_module_parameters = YabsStatSandboxAdapter.get_init_parameters_class(for_2_on_1=True)()

        # Parameters whose name ends with ``_2`` override their unsuffixed
        # counterparts for the second ("testing") run.
        prepare_stat_dplan_2 = sdk2.parameters.Bool('Prepare dplan for stat shoot sessions in this task (2)', default=True)
        with prepare_stat_dplan_2.value[False]:
            prepared_usage_data_meta_2 = sdk2.parameters.JSON('Data usage info for meta while preparing stat requests (2)')
            stat_dplan_resource_2 = sdk2.parameters.Resource('Previously prepared request log resource (2)')

        store_dumps_2 = sdk2.parameters.Bool('Store shoot dumps (2)', default=False)
        # Use ``default=`` like the other parameters of this class so the
        # intended ``True`` default is actually applied.
        shuffle_run_order = sdk2.parameters.Bool('Shuffle order of runs within 2 on 1', default=True)
        flush_bases_directory = sdk2.parameters.Bool('Flush yabs-server bases directory between 2 on 1 runs', default=True)

    class Context(YabsServerStatPerformance2.Context):
        # Second-run counterparts of the parent's context fields; presumably
        # written by ``__to_context__(..., second_batch=True)`` — confirm.
        rps_2 = 0
        rps_corrected_2 = 0
        rps_list_2 = []
        rps_corrected_list_2 = []
        usage_data_meta_2 = []
        usage_data_stat_2 = []
        perf_resource_id_2 = None

    def on_save(self):
        """
        Compute RAM, ramdrive/disk requirements and the unpacking worker count.

        Deliberately skips ``YabsServerStatPerformance2.on_save`` (``super`` is
        given the parent class, not this one) and replaces its sizing logic.
        When tmpfs is used, the ramdrive is sized for whichever of the two
        binary base sets has the larger total unpacked size.
        """
        super(YabsServerStatPerformance2, self).on_save()
        self.Requirements.ram = self.Parameters.ram_space * MB_IN_GB
        shard_count = len(self.Parameters.stat_shards)
        bin_bases_unpacked_size = []
        if self.Parameters.use_tmpfs:
            # ``key=sum`` picks the base set with the larger total size.
            bin_bases_unpacked_size = max(
                get_bin_bases_unpacked_size(self.Parameters.stat_binary_base_resources),
                get_bin_bases_unpacked_size(self.Parameters.stat_binary_base_resources_2),
                key=sum
            )
            # NOTE(review): only one base set is reserved here; if
            # flush_bases_directory is False both sets may coexist on the
            # ramdrive — confirm the sizing is sufficient in that case.
            self.Context.ramdrive_size = min(
                self.Parameters.shard_space * shard_count * MB_IN_GB,
                sum(bin_bases_unpacked_size) + 20 * MB_IN_GB
            )
            self.Requirements.ramdrive = ctm.RamDrive(
                ctm.RamDriveType.TMPFS,
                self.Context.ramdrive_size,
                None
            )
            self.Requirements.disk_space = self.Parameters.generic_disk_space * MB_IN_GB
        else:
            self.Requirements.ramdrive = None
            self.Requirements.disk_space = (self.Parameters.generic_disk_space + (self.Parameters.shard_space * shard_count)) * MB_IN_GB
            # NOTE(review): Context.ramdrive_size is not reset in this branch,
            # so the call below receives whatever value the context already
            # holds (possibly from an earlier save with use_tmpfs on) — confirm.
        self.Context.unpacking_workers = get_max_unpacking_workers(bin_bases_unpacked_size, self.Requirements.ram, self.Context.ramdrive_size)

    @track_errors
    def on_execute(self):
        """
        Run the baseline and testing pipelines one after another.

        With ``shuffle_run_order`` enabled, tasks on an odd ``hosts_slot_index``
        run the testing configuration first. Presumably this spreads
        run-order effects across both variants.
        """
        parameters_ordered_list = [
            ParametersFor2On1Batch(self.Parameters, 'baseline', False),
            ParametersFor2On1Batch(self.generate_second_server_parameters(), 'testing', True)
        ]
        if self.Parameters.shuffle_run_order and (self.Context.hosts_slot_index % 2):
            logging.info('Swapping test order')
            parameters_ordered_list.reverse()
        self.run_2_on_1_pipeline(*parameters_ordered_list)

    def run_2_on_1_pipeline(self, parameters_1, parameters_2):
        """
        Execute two performance pipelines sequentially and store their results.

        :param parameters_1: ``ParametersFor2On1Batch`` for the run executed first.
        :param parameters_2: ``ParametersFor2On1Batch`` for the run executed second;
            it reuses the first run's base state (``shared_base_state``).
        """
        performance_pipeline_results, base_provider = self.run_performance_pipeline(
            parameters_1.parameters,
            dump_to_yt_key=parameters_1.dump_to_yt_key,
            index_2on1=1,
        )
        performance_pipeline_results.__to_context__(self.Context, second_batch=parameters_1.second_batch)
        if self.Parameters.flush_bases_directory:
            base_provider.flush_state()
        performance_pipeline_results_2, _ = self.run_performance_pipeline(
            parameters_2.parameters,
            shared_base_state=base_provider.base_state,
            dump_to_yt_key=parameters_2.dump_to_yt_key,
            index_2on1=2,
        )
        performance_pipeline_results_2.__to_context__(self.Context, second_batch=parameters_2.second_batch)

    def generate_second_server_parameters(self):
        """
        Build the parameter set for the second ("testing") run.

        :returns: a ``MutableParameters`` object built from ``self.Parameters``
            in which, for every parameter ``<name>_2``, ``<name>`` is set to
            the value of ``<name>_2``.
        """
        parameters = MutableParameters.__from_parameters__(self.Parameters)
        # Collect the overrides first so that setattr() never mutates the
        # object while it is being iterated.
        overrides = [
            (name[:-2], value)
            for name, value in parameters
            if name.endswith('_2')
        ]
        for target_name, target_value in overrides:
            setattr(parameters, target_name, target_value)
        return parameters
