# -*- coding: utf-8 -*-
import datetime
import operator
import os
import sandbox.common.types.task as ctt
from sandbox import sdk2
from sandbox.common.utils import get_task_link
from sandbox.projects.modadvert.common import modadvert
from sandbox.projects.modadvert.rm.constants import AUTOMODERATOR_OBJECT_FIELDS, KEY_COLUMNS
from sandbox.projects.modadvert.RunComparison.stat import get_stat_table
from sandbox.projects.release_machine.components import all as rmc
from sandbox.projects.release_machine.core import const as rm_const
from sandbox.projects.release_machine.core import task_env
from sandbox.projects.release_machine.helpers.startrek_helper import STHelper
from sandbox.sandboxsdk.environments import PipEnvironment


class ModadvertRunB2B(modadvert.ModadvertBaseRunBinaryTask):
    """
    Base task for automoderators' B2B.
    This task should NOT be run directly: `create_runner_subtask` must be implemented
    and `automoderator_name` must be set by an inherited class.

    Steps performed by the hooks below:
      1. Create a fresh comparison directory on YT and build a sampled `requests` table in it.
      2. Create one runner subtask per branch ('base' and 'feature') and wait for them.
      3. Provide the `./comparison` command line via `create_command`
         (presumably executed by the base class - verify there).
      4. Report start, finish status and comparison statistics to the Startrek release ticket.
    """

    name = 'MODADVERT_RUN_B2B'
    resource_name = 'compare-automoderator-results'
    # Passed to the comparison binary as `--automoderator`; inherited classes must override it.
    automoderator_name = None

    class Requirements(modadvert.ModadvertBaseRunBinaryTask.Requirements):
        environments = (
            task_env.TaskRequirements.startrek_client,
            PipEnvironment('yandex-yt', '0.8.38a1'),
            PipEnvironment('yandex-yt-yson-bindings-skynet', '0.3.7.post1'),
        )
        client_tags = task_env.TaskTags.startrek_client

    class Parameters(modadvert.ModadvertBaseRunBinaryTask.Parameters):

        kill_timeout = 36 * 60 * 60  # 36h

        with sdk2.parameters.Group('Comparison parameters') as comparison_group:
            comparison_root_dir = sdk2.parameters.String('Root directory where comparison results will be stored', required=True)
            comparison_yt_memory_limit = sdk2.parameters.Integer(
                'Memory limit for YT jobs in comparison script',
                default=6442450944  # 6 * 1024 ** 3
            )
            comparison_yt_max_failed_job_count = sdk2.parameters.Integer(
                'Limit of max failed YT jobs in comparison script',
                default=5
            )

        with sdk2.parameters.Group('Input requests parameters') as requests_group:
            max_tables_age = sdk2.parameters.Integer(
                'Max tables age to process in hours. Tables will be filtered by "creation_time" attribute',
                required=False,
                default=48
            )
            src_tables_dir = sdk2.parameters.String(
                'Source tables directory (NOTE: objects should be already splitted into sub-objects)',
                required=True
            )
            sampling_rate = sdk2.parameters.Float('Input requests sampling rate', default=1.0)

        with sdk2.parameters.Group('Testenv parameters') as testenv_group:
            release_number = sdk2.parameters.Integer('release number', default=None)
            # Offer every known release machine component name as a selectable value.
            with sdk2.parameters.String('Component name', multiline=True, default=None) as component_name:
                for name in rmc.get_component_names():
                    setattr(component_name.values, name, name)

    def create_comparison_dir(self):
        """
        Allocates a free subpath under `comparison_root_dir` and creates its
        'data', 'base' and 'feature' sub-directories on YT.

        Paths are stored in the task context as `comparison_dir` and
        `comparison_<subdir>_dir`.
        """
        self.Context.comparison_dir = self.yt_client.find_free_subpath(self.Parameters.comparison_root_dir)
        for subdir in ('data', 'base', 'feature'):
            subdir_path = os.path.join(self.Context.comparison_dir, subdir)
            setattr(self.Context, 'comparison_{}_dir'.format(subdir), subdir_path)
            self.yt_client.create(path=subdir_path, type='map_node', recursive=True)

    def prepare_requests(self):
        """
        Prepares YT table with requests for B2B.

        Firstly, picks tables from `src_tables_dir` which are not older than `max_tables_age` hours
        (according to the `creation_time` attribute).
        Secondly, merges the picked tables into a single table, keeping only automoderator object
        columns and sampling requests with the `sampling_rate` ratio.
        Resulting table will be approximately `1/sampling_rate` times smaller
        (but output size hugely varies depending on chunks of input tables).
        Finally, sorts the resulting table in place by `KEY_COLUMNS`.

        The list of picked tables and the resulting table path are stored in the task context
        as `raw_src_tables` and `src_table`.
        """
        # `creation_time` is parsed below as a UTC timestamp (note the trailing 'Z'), so the
        # threshold has to be computed in UTC as well; a local-time `now()` would shift the
        # window by the host's UTC offset.
        min_creation_time = datetime.datetime.utcnow() - datetime.timedelta(hours=self.Parameters.max_tables_age)

        def is_fresh(table):
            created = datetime.datetime.strptime(table.attributes.get('creation_time'), '%Y-%m-%dT%H:%M:%S.%fZ')
            return created >= min_creation_time

        self.Context.raw_src_tables = list(self.yt_client.search(
            self.Parameters.src_tables_dir,
            node_type='table',
            attributes=['creation_time'],
            object_filter=is_fresh
        ))
        src_table = self.yt_client.TablePath(
            os.path.join(self.Context.comparison_data_dir, 'requests'),
        )
        self.Context.src_table = str(src_table)

        self.yt_client.run_merge(
            source_table=[
                self.yt_client.TablePath(table, columns=list(AUTOMODERATOR_OBJECT_FIELDS.keys()))  # Remove unnecessary columns
                for table in self.Context.raw_src_tables
            ],
            destination_table=src_table,
            spec={
                'sampling': {'sampling_rate': self.Parameters.sampling_rate},
                'force_transform': True,
                'combine_chunks': True,
                'schema_inference_mode': 'from_output'  # Handle case when raw_src_tables' schema is changed
            }
        )

        self.yt_client.run_sort(
            src_table, destination_table=src_table, sort_by=KEY_COLUMNS
        )

    def create_runner_subtask(self, branch):
        """
        This method should be implemented in inherited classes.

        Called with `branch` equal to 'base' or 'feature'; the returned value is stored
        in the task context as `<branch>_runner_subtask`.
        """
        raise NotImplementedError

    def get_b2b_messages(self):
        """
        Builds the text report posted to Startrek after a successful comparison.

        The report starts with a link to the `diff_verdicts` table, followed by one
        `<{...}>` section per aggregation level (request, banner, campaign, client)
        with statistics read from `<comparison_dir>/comparison/{global,rule}/<level>/count`.

        :return: report text with chunks joined by newlines
        """
        comparison_results_dir = os.path.join(self.Context.comparison_dir, 'comparison')
        diff_verdicts_url = 'https://yt.yandex-team.ru/{cluster}/navigation?path={path}'.format(
            cluster=self.Parameters.yt_proxy_url,
            path=os.path.join(comparison_results_dir, 'diff_verdicts')
        )
        message_chunks = [u'Таблица с объектами, на которых не совпали вердикты: {}\n'.format(diff_verdicts_url)]

        for description, dir_name in [
            (u'объектам автомодератора (request_id, sub_id)', 'request'),
            (u'баннерам (meta/banner_id)', 'banner'),
            (u'кампаниям (meta/campaign_id)', 'campaign'),
            (u'клиентам (meta/client_id)', 'client')
        ]:
            # Per-type statistics.
            global_counts_path = os.path.join(comparison_results_dir, 'global', dir_name, 'count')
            global_counts_rows = sorted(self.yt_client.read_table(global_counts_path), key=operator.itemgetter('type'))
            global_message = u'В разбивке по типам:\n{}'.format(get_stat_table(['type'], global_counts_rows))

            # Per-(type, rule) statistics; the section is omitted when the rendered table is empty.
            rule_counts_path = os.path.join(comparison_results_dir, 'rule', dir_name, 'count')
            rule_counts_rows = sorted(self.yt_client.read_table(rule_counts_path), key=operator.itemgetter('type', 'rule'))
            rule_diff_table = get_stat_table(['type', 'rule'], rule_counts_rows, hide_tolerance=0.0)
            rule_message = ''
            if rule_diff_table:
                rule_message += u'Фильтры с изменившимися вердиктами:\n{}\n'.format(rule_diff_table)

            # Doubled braces render as literal `<{` ... `}>` around the section.
            message_chunks.append(u'<{{Статистика по {description}\n{global_message}\n{rule_message}}}>'.format(
                description=description,
                global_message=global_message,
                rule_message=rule_message
            ))

        return '\n'.join(message_chunks)

    def create_command(self):
        """
        Builds the command line for the `./comparison` binary.

        :return: list of command line arguments
        """
        return [
            './comparison',
            '--yt-cluster', self.Parameters.yt_proxy_url,
            '--yt-src-dir', self.Context.comparison_dir,
            '--automoderator', self.automoderator_name,
            # Integer parameters are stringified so that every argv item is a string.
            '--yt-memory-limit', str(self.Parameters.comparison_yt_memory_limit),
            '--yt-max-failed-job-count', str(self.Parameters.comparison_yt_max_failed_job_count),
        ]

    def on_prepare(self):
        """Initializes the YT client, release machine component info and Startrek helper."""
        # Imported here rather than at module level - presumably because the YT
        # packages are only available inside the task environment (TODO confirm).
        from sandbox.projects.modadvert.common.ytutils import yt_connect
        self.yt_client = yt_connect(cluster_url=self.Parameters.yt_proxy_url, yt_token=self.get_yt_token())
        self.component_info = rmc.get_component(self.Parameters.component_name)
        self.st_helper = STHelper(sdk2.Vault.data(rm_const.COMMON_TOKEN_OWNER, rm_const.COMMON_TOKEN_NAME))

    def st_comment(self, message):
        """
        Posts `message` as a Startrek comment for the configured release number and component.

        Requires `on_prepare` to have been called (uses `st_helper` and `component_info`).
        """
        self.st_helper.comment(
            self.Parameters.release_number,
            message,
            self.component_info
        )

    def on_before_execute(self):
        """
        Prepares input data and runner subtasks, then waits for the subtasks.

        Every step is guarded by a context value, so a step that has already been done
        is skipped if this hook runs again.
        """
        super(ModadvertRunB2B, self).on_before_execute()

        if not self.Context.st_initial_message:
            self.st_comment('B2B started\n{}'.format(get_task_link(self.id)))
            self.Context.st_initial_message = True

        if not self.Context.comparison_dir:
            self.create_comparison_dir()
            self.prepare_requests()

        for branch in ('base', 'feature'):
            subtask_key = '{}_runner_subtask'.format(branch)
            if not getattr(self.Context, subtask_key):
                setattr(self.Context, subtask_key, self.create_runner_subtask(branch))
        self.wait_all_subtasks()

    def on_finish(self, prev_status, status):
        """
        Reports the final task status to Startrek.

        On success the comparison statistics are attached to the status comment and
        a second comment with the component's deploy message is posted.
        """
        succeeded = status == ctt.Status.SUCCESS

        message_lines = ['B2B finished with status {}'.format(status), get_task_link(self.id)]
        if succeeded:
            message_lines.append(self.get_b2b_messages())
        self.st_comment('\n'.join(message_lines))

        if succeeded:
            self.st_comment('All tests passed\n' + self.component_info.get_deploy_message(self.Parameters.release_number))

        super(ModadvertRunB2B, self).on_finish(prev_status, status)
