import logging
import json
import time

from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.common.types.task import ReleaseStatus
from sandbox.projects.common.testenv_client.api_client import TestenvApiClient
from sandbox.projects.common.testenv_client import TEClient
import sandbox.projects.release_machine.components.all as rmc
from sandbox.projects.yabs.base_bin_task import BaseBinTaskMixin, base_bin_task_parameters
from sandbox.projects.yabs.qa.constants import SamplingStrategies
from sandbox.projects.yabs.qa.solomon.mixin import SolomonTaskMixin, SolomonTaskMixinParameters
from sandbox.projects.yabs.qa.tasks.YabsServerPerformancePlotter.tables import write_data, merge_data, DIFF_RPS_UNSORTED_SCHEMA
from sandbox.projects.yabs.qa.utils.arcadia import get_revision_datetime
from sandbox.projects.yabs.release.binary_search.testenv import get_interval_info_for_test
from sandbox.projects.yabs.release.binary_search.intervals import Partitioner, NO_CMP_TASK, BROKEN, FIXED
from sandbox.projects.yabs.release.common import BaseReleaseTask


logger = logging.getLogger(__name__)


def get_diff_rps_data_from_task(task_id, role, sampling_strategy):
    """Collect one diff-RPS record from a comparison task.

    Args:
        task_id: id of the comparison task; its context is read for
            ``rps_hl_median`` (baseline) and ``rps_hl_median_2`` (test).
        role: role stored in the record; for ``'stat'`` the first stat shard
            of the baseline task is stored as well.
        sampling_strategy: value stored in the ``sampling_strategy`` field.

    Returns:
        A dict describing the measurement, or an empty dict when either
        median is missing or falsy.
    """
    task = sdk2.Task[task_id]
    baseline_rps = task.Context.rps_hl_median
    test_rps = task.Context.rps_hl_median_2
    if not (baseline_rps and test_rps):
        return {}

    # The remaining fields describe the baseline ("pre") run of the comparison.
    baseline_task = task.Parameters.pre_task
    baseline_meta_mode = baseline_task.Parameters.meta_mode
    ammo_id = baseline_task.Parameters.requestlog_resource.id
    # Only stat runs carry a shard.
    stat_shard = baseline_task.Parameters.stat_shards[0] if role == 'stat' else None

    record = {
        'role': role,
        'meta_mode': baseline_meta_mode,
        'handlers': [],
        'shard': stat_shard,
        'sampling_strategy': sampling_strategy,
        'baseline_rps': baseline_rps,
        'test_rps': test_rps,
        'ammo_resource': ammo_id,
    }
    return record


def get_sensors_from_task(task_id, role, sampling_strategy, end):
    """Build sensors describing the RPS change measured by a comparison task.

    Args:
        task_id: id of the comparison task; its context is read for
            ``rps_hl_median`` (baseline) and ``rps_hl_median_2`` (test).
        role: value of the ``role`` label; for ``'stat'`` the first stat shard
            of the baseline task is added as the ``shard`` label.
        sampling_strategy: value of the ``sampling_strategy`` label.
        end: revision passed to ``get_revision_datetime`` to obtain the
            timestamp attached to the sensors.

    Returns:
        A list with the ``rps_ratio`` and ``rps_relative_change`` sensors, or
        an empty list (after logging an error) when the medians are not
        numbers or the baseline is zero.
    """
    # Resolve the task once instead of on every attribute access.
    cmp_task = sdk2.Task[task_id]
    rps_pre = cmp_task.Context.rps_hl_median
    rps_test = cmp_task.Context.rps_hl_median_2
    pre_task_parameters = cmp_task.Parameters.pre_task.Parameters

    common_labels = {
        "meta_mode": pre_task_parameters.meta_mode,
        "role": role,
        "sampling_strategy": sampling_strategy,
    }
    if role == 'stat':
        common_labels['shard'] = pre_task_parameters.stat_shards[0]
    # NOTE(review): time.mktime interprets the time tuple as local time;
    # confirm this matches what get_revision_datetime returns.
    ts = time.mktime(get_revision_datetime(end).timetuple())

    numeric_types = (int, float)
    medians_are_numbers = isinstance(rps_pre, numeric_types) and isinstance(rps_test, numeric_types)
    # A zero baseline would make both ratios undefined (ZeroDivisionError),
    # so it is reported as incorrect data as well.
    if not medians_are_numbers or rps_pre == 0:
        logger.error('Got incorrect data from task %s: rps_pre=%s, rps_test=%s', task_id, rps_pre, rps_test)
        return []

    # Force floating point division so two integer medians are not floored
    # if this module is executed by a Python 2 interpreter.
    baseline = float(rps_pre)
    return [
        {
            "labels": dict(common_labels, sensor="rps_ratio"),
            "value": rps_test / baseline,
            "ts": ts,
        },
        {
            "labels": dict(common_labels, sensor="rps_relative_change"),
            "value": (rps_test - rps_pre) / baseline,
            "ts": ts,
        },
    ]


class YabsServerPerformancePlotter(SolomonTaskMixin, BaseBinTaskMixin, sdk2.Task):
    """Collect diff-RPS results of TestEnv performance tests and publish them.

    For every configured test the task partitions the revision range into
    intervals, reads the RPS medians from the comparison task of each
    interval, pushes the derived sensors through the Solomon push client and
    merges the raw records into YT tables.
    """

    # Lazily computed caches backing the properties below.
    __start_revision = None
    __component_info = None

    class Parameters(BaseReleaseTask.Parameters):
        _base_bin_task_parameters = base_bin_task_parameters(
            release_version_default=ReleaseStatus.STABLE,
            resource_attrs_default={"task_type": "YABS_SERVER_PERFORMANCE_PLOTTER"},
        )
        # The secret is expected to provide the 'yt_token' key (see on_execute).
        tokens = sdk2.parameters.YavSecret("YAV secret identifier", default="sec-01d6apzcex5fpzs5fcw1pxsfd5")
        solomon_parameters = SolomonTaskMixinParameters()

        with sdk2.parameters.Group('YT storage parameters') as yt_parameters:
            diff_rps_yt_table_prefix = sdk2.parameters.String('Prefix to store diff rps data', default='//home/yabs-cs-sandbox/performance/diff_rps/')

        with sdk2.parameters.Group('TestEnv parameters') as testenv_parameters:
            database = sdk2.parameters.String('TestEnv project aka database', default='yabs-2.0')
            stat_tests = sdk2.parameters.List('Stat performance tests', default=[
                "YABS_SERVER_40_PERFORMANCE_BEST_BS",
                "YABS_SERVER_40_PERFORMANCE_BEST_BS_B",
                "YABS_SERVER_40_PERFORMANCE_BEST_BSRANK",
                "YABS_SERVER_40_PERFORMANCE_BEST_BSRANK_B",
                "YABS_SERVER_40_PERFORMANCE_BEST_YABS",
                "YABS_SERVER_40_PERFORMANCE_BEST_YABS_B",
            ])
            meta_tests = sdk2.parameters.List('Meta performance tests', default=[
                "YABS_SERVER_45_PERFORMANCE_META_BS_A_B",
                "YABS_SERVER_45_PERFORMANCE_META_BSRANK_A_B",
                "YABS_SERVER_45_PERFORMANCE_META_YABS_A_B",
            ])
            stat_tests_sampled = sdk2.parameters.List('Stat performance sampled tests', default=[
                "YABS_SERVER_40_PERFORMANCE_BEST_BS_SAMPLED",
                "YABS_SERVER_40_PERFORMANCE_BEST_BSRANK_SAMPLED",
                "YABS_SERVER_40_PERFORMANCE_BEST_YABS_SAMPLED",
            ])
            meta_tests_sampled = sdk2.parameters.List('Meta performance sampled tests', default=[
                "YABS_SERVER_45_PERFORMANCE_META_BS_A_B_SAMPLED",
                "YABS_SERVER_45_PERFORMANCE_META_BSRANK_A_B_SAMPLED",
                "YABS_SERVER_45_PERFORMANCE_META_YABS_A_B_SAMPLED",
            ])

        with sdk2.parameters.Output:
            # Read by the next scheduled run to continue where this one stopped.
            last_processed_revisions = sdk2.parameters.Dict('Last processed revisions by test')

    class Requirements(sdk2.Requirements):
        ram = 512
        cores = 1

        class Caches(sdk2.Requirements.Caches):
            pass

    @property
    def component_info(self):
        """Component info for ``Parameters.component_name``, fetched once and cached."""
        if self.__component_info is None:
            self.__component_info = rmc.get_component(self.Parameters.component_name)
        return self.__component_info

    @property
    def start_revision(self):
        """Default start revision, computed once and cached.

        The explicitly provided ``Parameters.start_revision`` wins; otherwise
        the revision preceding ``component_info.first_rev`` is used.
        """
        if not self.__start_revision:
            if self.Parameters.start_revision:
                logger.info("Using explicitly provided start_revision")
                self.__start_revision = self.Parameters.start_revision

            else:
                logger.info("Automatically detecting start_revision relative to last stable release")
                self.__start_revision = self.component_info.first_rev - 1

        return self.__start_revision

    def get_test_start_revision(self, test_name):
        """Return the revision from which ``test_name`` should be processed.

        When the task runs under a scheduler, the value recorded by the latest
        successful run of the same scheduler is reused; in every other case
        the default ``start_revision`` is returned.
        """
        if self.scheduler:
            prev = self.find(scheduler=self.scheduler, status='SUCCESS', order='-id').limit(1).first()
            if not prev:
                logger.warning('No previous successful tasks found for scheduler %s', self.scheduler)
                return self.start_revision

            # Tolerate a previous task whose output dict was left unset.
            previous_revisions = prev.Parameters.last_processed_revisions or {}
            return previous_revisions.get(test_name, self.start_revision)
        return self.start_revision

    def get_test_end_revision(self, test_name, start_revision):
        """Return the last revision to process for ``test_name``.

        ``Parameters.final_revision`` wins when set; otherwise the greatest
        ``last_revision`` among the checked intervals reported for the test
        starting from ``start_revision`` is used.

        Raises:
            ValueError: if no checked interval is reported for the test.
        """
        if self.Parameters.final_revision:
            return self.Parameters.final_revision

        checked_intervals = list(TEClient.iter_checked_intervals(self.Parameters.database, test_name, start_revision))
        logger.debug('Got checked_intervals for %s:\n%s', test_name, json.dumps(checked_intervals, indent=2))

        checked_revisions = [r['last_revision'] for r in checked_intervals if r['is_checked']]
        if not checked_revisions:
            # Same exception type as max() on an empty sequence, but with a
            # message that tells which test has nothing to process.
            raise ValueError('No checked intervals found for {} starting from revision {}'.format(test_name, start_revision))

        return max(checked_revisions)

    def get_test_sensors(self, role, test_name, sampling_strategy, start_revision, end_revision, testenv_api_client):
        """Collect sensors and diff-RPS records for one test.

        The range ``(start_revision, end_revision)`` is partitioned into
        intervals; intervals without a usable comparison task are skipped with
        a warning.

        Returns:
            A ``(sensors, diff_rps_data)`` tuple of lists.
        """
        sensors = []
        diff_rps_data = []

        logger.info('Will try to find intervals for %s (%s - %s)', test_name, start_revision, end_revision)

        # problems_by_interval is returned by the helper but not used here.
        interval_map, problems_by_interval = get_interval_info_for_test(testenv_api_client, self.Parameters.database, test_name, start_revision, end_revision)
        interval_sequence = Partitioner(interval_map).partition_interval((start_revision, end_revision))

        logger.info('Got intervals sequence: %s', json.dumps(interval_sequence, indent=2))

        for interval in interval_sequence:
            if interval.status in (NO_CMP_TASK, BROKEN, FIXED) or not interval.task_id:
                logger.warning('Got weird interval for %s: %s', test_name, interval)
                continue
            diff_rps_data_from_task = get_diff_rps_data_from_task(interval.task_id, role, sampling_strategy)
            if diff_rps_data_from_task:
                diff_rps_data_from_task.update(baseline_revision=interval.begin, test_revision=interval.end)
                diff_rps_data.append(diff_rps_data_from_task)

            sensors.extend(get_sensors_from_task(interval.task_id, role, sampling_strategy, interval.end))

        return sensors, diff_rps_data

    def write_data(self, yt_client, rps_diff_data, trunk_table_path):
        """Write diff-RPS records to ``trunk_table_path`` using the unsorted schema."""
        write_data(yt_client, trunk_table_path, rps_diff_data, schema=DIFF_RPS_UNSORTED_SCHEMA)

    def merge_data(self, yt_client, trunk_data, trunk_table_path, release_table_prefix, last_processed_revisions):
        """Merge collected records into the YT tables.

        Fetches up to 10 release scopes of the ``yabs_server`` component, maps
        each scope number to the base commit of its branch and passes that
        mapping to the module-level ``merge_data`` together with the smallest
        processed revision.

        Raises:
            ValueError: if ``last_processed_revisions`` is empty.
        """
        # Imported lazily, as in the rest of this task.
        from release_machine.release_machine.services.release_engine.services.Model import ModelClient
        from sandbox.projects.release_machine.core import const
        from release_machine.release_machine.proto.structures import message_pb2
        rm_client = ModelClient.from_address(const.Urls.RM_HOST)
        scopes = rm_client.get_scopes(message_pb2.ScopesRequest(component_name="yabs_server", limit=10, start_scope_number=0))

        scope_revisions = {
            scope.scope_number: scope.branch.base_commit_id
            for scope in scopes.branch_scopes
        }

        logger.info("Got release scopes: %s", scope_revisions)
        merge_data(yt_client, trunk_data, trunk_table_path, release_table_prefix.rstrip('/'), scope_revisions, min(last_processed_revisions.values()))

    def _configured_tests(self):
        """Return ``(role, test_name, sampling_strategy)`` triples in processing order."""
        test_groups = (
            ('stat', SamplingStrategies.FULL, self.Parameters.stat_tests),
            ('meta', SamplingStrategies.FULL, self.Parameters.meta_tests),
            ('stat', SamplingStrategies.SAMPLED, self.Parameters.stat_tests_sampled),
            ('meta', SamplingStrategies.SAMPLED, self.Parameters.meta_tests_sampled),
        )
        return [
            (role, test_name, sampling_strategy)
            for role, sampling_strategy, test_names in test_groups
            for test_name in test_names
        ]

    def on_execute(self):
        """Process every configured test, then store and publish the results.

        A failure while processing one test does not stop the others: the
        test is remembered, its processed revision stays at the start revision
        and the task fails at the very end, after the collected data has been
        written.

        Raises:
            TaskFailure: if data could not be collected for at least one test.
        """
        testenv_token = sdk2.Vault.data(self.Parameters.te_vault_name)
        yt_token = self.Parameters.tokens.data()['yt_token']

        testenv_api_client = TestenvApiClient(token=testenv_token)

        sensors = []
        rps_diff_data = []
        last_processed_revisions = {}
        failed_tests = []

        for role, test_name, sampling_strategy in self._configured_tests():
            start_revision = self.get_test_start_revision(test_name)

            try:
                # Resolving the end revision talks to TestEnv and may fail
                # (e.g. no checked intervals yet), so it is isolated per test
                # together with the data collection itself.
                end_revision = self.get_test_end_revision(test_name, start_revision)
                logger.info('Will try to find intervals for %s (%s - %s)', test_name, start_revision, end_revision)
                test_sensors, test_rps_diff_data = self.get_test_sensors(role, test_name, sampling_strategy, start_revision, end_revision, testenv_api_client)
            except Exception:
                # The failure makes the whole task fail below, so it is logged
                # at error level rather than hidden in the debug log.
                logger.error('Failed to get data for %s', test_name, exc_info=True)
                failed_tests.append(test_name)
                # Nothing was processed: the next run retries the same range.
                last_processed_revisions[test_name] = start_revision
            else:
                sensors.extend(test_sensors)
                rps_diff_data.extend(test_rps_diff_data)
                last_processed_revisions[test_name] = end_revision

        # Storing data in YT is best effort: a failure is logged but does not
        # prevent the sensors from being pushed.
        try:
            from yt.wrapper import YtClient, ypath_join
            trunk_table_path = ypath_join(self.Parameters.diff_rps_yt_table_prefix, 'trunk')

            yt_client = YtClient(proxy='hahn', token=yt_token)
            self.merge_data(yt_client, rps_diff_data, trunk_table_path, self.Parameters.diff_rps_yt_table_prefix, last_processed_revisions)
        except Exception:
            logger.error('Unable to write data', exc_info=True)

        self.Parameters.last_processed_revisions = last_processed_revisions
        self.solomon_push_client.add(sensors)

        if failed_tests:
            raise TaskFailure('Failed to get performance data for {}'.format(', '.join(failed_tests)))
