import logging

from sandbox import sdk2
from sandbox import sandboxsdk
from sandbox.common.errors import TaskFailure
from sandbox.common.types.task import Status

import sandbox.projects.release_machine.components.all as rmc
import sandbox.projects.release_machine.core.const as rm_const
import sandbox.projects.release_machine.core.task_env as task_env
import sandbox.projects.release_machine.helpers.startrek_helper as st_helper

from sandbox.projects.websearch.begemot.tasks.BegemotYT.BegemotReducer import BegemotReducer
from sandbox.projects.websearch.begemot.tasks.BegemotYT.common import CommonYtParameters
from sandbox.projects.websearch.begemot.tasks.BegemotYT.MiddleSearchCacheHitGuess2 import MiddleSearchCacheHitGuess2
from sandbox.projects.websearch.begemot.tasks.BegemotYT.paths import BegemotYtPaths


def get_begemot_release_tag():
    """Return the release identifier of the currently released (stable) begemot.

    Looks up the READY ``BEGEMOT_EXECUTABLE`` resource with attribute
    ``released=stable``, loads the task that produced it and parses its
    ``checkout_arcadia_from_url`` context value: the result is the text between
    the first ``stable-`` and the following ``-`` (presumably the release branch
    number -- confirm against the URL layout).

    :raises TaskFailure: if the resource or its task cannot be found, or the
        checkout URL does not contain ``stable-``.
    """
    begemot_executable = sdk2.Resource["BEGEMOT_EXECUTABLE"].find(state='READY', attrs={'released': 'stable'}).first()
    if begemot_executable is None:
        raise TaskFailure("No READY BEGEMOT_EXECUTABLE resource released to stable found")

    release_task = sdk2.Task.find(id=begemot_executable.task_id).first()
    if release_task is None:
        raise TaskFailure("Task #{} that built the released BEGEMOT_EXECUTABLE not found".format(begemot_executable.task_id))
    logging.debug("Previous begemot release task is #{}".format(release_task.id))

    checkout_url = release_task.Context.checkout_arcadia_from_url
    try:
        return checkout_url.split('stable-')[1].split('-')[0]
    except (AttributeError, IndexError):
        # Missing context value or a URL that is not a stable-branch checkout.
        raise TaskFailure("Cannot parse release branch from checkout url {!r} of task #{}".format(checkout_url, release_task.id))


class BegemotStabilityCheck(sdk2.Task):
    """Check that begemot answers stay stable between two BEGEMOT_REDUCER runs.

    Flow (each step is a memoized stage of ``on_execute``):

    1. Pick the reducer task to clone: ``task_to_clone``; in release mode a
       successful BEGEMOT_REDUCER tagged with the released branch; otherwise
       fall back to ``task_to_compare`` itself.
    2. Run the cloned reducer's configuration on ``task_to_compare``'s input.
       If ``task_to_compare``'s YT tables are gone, its configuration is rerun
       too, on the latest eventlog table, to produce a fresh baseline.
    3. Diff old and new answers with MiddleSearchCacheHitGuess2.
    4. Fail if the diff exceeds ``acceptable_response_change`` percent;
       optionally post the result to the release ticket.
    """

    class Parameters(sdk2.Parameters):
        release_run = sdk2.parameters.Bool(
            'Compare with last released begemot',
            default=False
        )
        write_to_ticket = sdk2.parameters.Bool(
            'Write about diff to priemka ticket',
            default=False
        )
        task_to_clone = sdk2.parameters.Integer(
            'BEGEMOT_YT_RESPONSES task to run with new input',
            required=False
        )
        task_to_compare = sdk2.parameters.Integer(
            'Task to compare with (its input will be used in task run)',
            required=True
        )
        full_diff = sdk2.parameters.Bool(
            'Get full diff (not middle search cache hit only)',
            default=True
        )
        acceptable_response_change = sdk2.parameters.Float(
            'Acceptable response change, %',
            required=True,
            default=1.0,
            description='If more responses change, task will fail'
        )
        yt_token_vault_name = CommonYtParameters.yt_token_vault_name()
        yt_token_vault_owner = CommonYtParameters.yt_token_vault_owner()
        yt_proxy = CommonYtParameters.yt_proxy()
        yt_pool = CommonYtParameters.yt_pool()

    class Requirements(sdk2.Task.Requirements):
        environments = [
            task_env.TaskRequirements.startrek_client,
            sandboxsdk.environments.PipEnvironment('yandex-yt', version='0.10.8'),
        ]
        client_tags = task_env.TaskTags.startrek_client

    def get_branch_tag(self):
        """Return the first tag of this task that names a branch, or None.

        A branch tag is one whose last ``-``-separated part is ``TRUNK`` or
        parses as an integer (e.g. ``...-TRUNK`` or ``...-123``).
        """
        tags = self.server.task[self.id].read()["tags"]
        for tag in tags:
            branch = tag.split("-")[-1]
            if branch == "TRUNK":
                return tag
            try:
                int(branch)
            except ValueError:
                continue
            return tag
        logging.debug("Task failed to find current branch tag")
        return None

    def find_release_test_task(self, cur_tag):
        """Find a successful BEGEMOT_REDUCER run for the released begemot branch.

        The search tag is ``cur_tag`` with only its last ``-``-separated part
        replaced by the current stable release identifier.

        :param cur_tag: branch tag of this task (see ``get_branch_tag``).
        :returns: id of a matching successful task, or None if none was found
            (or ``cur_tag`` is empty).
        """
        if not cur_tag:
            logging.debug("No current branch tag, cannot find previous release testing task")
            return None
        # rpartition replaces just the trailing segment; str.replace would also
        # rewrite any earlier occurrence of the same text inside the tag.
        head, sep, _ = cur_tag.rpartition('-')
        release_tag = head + sep + get_begemot_release_tag()
        logging.debug("Try to find an old task for tag {}".format(release_tag))
        task = sdk2.Task.find(task_type=BegemotReducer, tags=[release_tag], hidden=True, status=Status.SUCCESS).first()
        if task is None:
            logging.debug("Failed to find previous release testing task")
            return None
        return task.id

    def check_previous_task_yt_resources(self, yt_client, compare_task):
        """Return True if both the eventlog table and answers of ``compare_task`` still exist on YT.

        On the first missing path an info message is added to the task and
        False is returned.
        """
        paths = [compare_task.Parameters.eventlog_table, compare_task.Parameters.answers]
        for path in paths:
            if not yt_client.exists(path):
                self.set_info('YT path {} not found. Rerunning task_to_compare with new data'.format(path))
                return False
        return True

    def write_to_ticket(self, branch, diff):
        """Post the diff result as a grouped "Cache hit test" comment via STHelper.

        :param branch: integer branch number passed to ``write_grouped_comment``
            (presumably used to locate the release ticket -- confirm in STHelper).
        :param diff: answers diff in percent; < 15 is OK, < 50 is WARNING,
            anything else is CRITICAL.
        """
        if diff < 15:
            result = "**!!(green)OK!!**"
        elif diff < 50:
            result = "**!!(yellow)WARNING!!**"
        else:
            result = "**!!(red)CRITICAL!!**"
        message = '\n'.join([
            "Cache guess task ((https://sandbox.yandex-team.ru/task/{}/view {}))".format(self.id, self.id),
            result,
            "We expect {:.2f}% cache miss after release".format(diff)
        ])

        st = st_helper.STHelper(sdk2.Vault.data(rm_const.COMMON_TOKEN_OWNER, rm_const.COMMON_TOKEN_NAME))
        c_info = rmc.COMPONENTS["begemot"]()
        st.write_grouped_comment(
            "====Cache hit test",
            "",
            message,
            branch,
            c_info,
        )

    def _enqueue_reducer(self, description, eventlog_table, output_path, source_task):
        """Enqueue a BegemotReducer child configured like ``source_task`` and return its id."""
        return BegemotReducer(
            self, description=description,
            eventlog_table=eventlog_table,
            output_path=output_path,
            begemot_mapper=source_task.Parameters.begemot_mapper,
            shards=source_task.Parameters.shards,
            fresh=source_task.Parameters.fresh,
            yt_token_vault_owner=self.Parameters.yt_token_vault_owner,
            yt_token_vault_name=self.Parameters.yt_token_vault_name,
            yt_proxy=self.Parameters.yt_proxy,
            yt_pool=self.Parameters.yt_pool,
            wait_time=60,
            job_count=source_task.Parameters.job_count
        ).enqueue().id

    def on_execute(self):
        # Stage 1: decide which reducer task to clone and which to compare with.
        with self.memoize_stage.get_tasks(commit_on_entrance=False):
            if self.Parameters.write_to_ticket and not self.Parameters.release_run:
                raise TaskFailure("Set release run mode to write results to ticket")
            self.Context.clone_task = self.Parameters.task_to_clone
            self.Context.compare_task = self.Parameters.task_to_compare
            if self.Parameters.release_run:
                self.Context.cur_branch_tag = self.get_branch_tag()
                if self.Context.cur_branch_tag is None:
                    raise TaskFailure("Task failed to find current branch tag")
                old_task = self.find_release_test_task(self.Context.cur_branch_tag)
                if old_task is None:
                    raise TaskFailure("Task failed to find a previous successful BEGEMOT_YT_RESPONSES task")
                if self.Parameters.task_to_clone is not None:
                    self.set_info("WARNING: in release run task to clone is chosen automatically. Task to clone changed to #{}".format(old_task))
                self.Context.clone_task = old_task

            if self.Context.clone_task is None:
                self.set_info("Task to clone not found. Task will rerun task to compare to check begemot stability")
                self.Context.clone_task = self.Context.compare_task

        # Stage 2: start the new reducer (and, if needed, a baseline rerun) and wait.
        with self.memoize_stage.clone_reducer(commit_on_entrance=False):
            import yt.wrapper as yt
            token = sdk2.Vault.data(self.Parameters.yt_token_vault_owner, self.Parameters.yt_token_vault_name)
            yt_client = yt.YtClient(self.Parameters.yt_proxy, token)

            clone_task = sdk2.Task.find(id=self.Context.clone_task).first()
            compare_task = sdk2.Task.find(id=self.Context.compare_task).first()
            if clone_task is None or compare_task is None:
                raise TaskFailure("Task to clone #{} or task to compare #{} not found".format(
                    self.Context.clone_task, self.Context.compare_task))
            prev_task_valid = self.check_previous_task_yt_resources(yt_client, compare_task)

            # All outputs go to the part of the compare task's output path that
            # precedes 'BEGEMOT_REDUCER', under BEGEMOT_STABILITY_CHECK/<this task id>.
            output_root = compare_task.Parameters.output_path.split('BEGEMOT_REDUCER')[0] + 'BEGEMOT_STABILITY_CHECK/{}'.format(self.id)

            tasks_to_wait = []
            if prev_task_valid:
                eventlog_table = compare_task.Parameters.eventlog_table
                self.Context.prev_answers = compare_task.Parameters.answers
                self.Context.old_reducer = None
            else:
                # Baseline data is gone: reuse the row-range tail (text after the
                # last ':', assumed to include the closing ']') of the compare
                # task's table on the latest eventlog, and rerun the compare
                # task's configuration to get fresh baseline answers.
                count_suffix = compare_task.Parameters.eventlog_table.split(':')[-1]
                eventlog_table = BegemotYtPaths.get_last_eventlog_table() + '[#0:{}'.format(count_suffix)
                old_output_path = output_root + '_old'
                self.Context.prev_answers = '{}/Merger/answers'.format(old_output_path)
                self.Context.old_reducer = self._enqueue_reducer(
                    'Begemot stability check, rerun old reducer', eventlog_table, old_output_path, compare_task)
                tasks_to_wait.append(self.Context.old_reducer)

            self.Context.new_output_path = output_root
            # The new reducer writes its answers under its own output path
            # ('/Merger/answers' is appended when the diff task is created).
            self.Context.new_answers = self.Context.new_output_path
            self.Context.new_reducer = self._enqueue_reducer(
                'Begemot stability check', eventlog_table, self.Context.new_output_path, clone_task)
            tasks_to_wait.append(self.Context.new_reducer)
            raise sdk2.WaitTask(tasks_to_wait, Status.Group.FINISH | Status.Group.BREAK)

        # Stage 3: make sure every reducer succeeded, then start the diff task.
        with self.memoize_stage.count_diff(commit_on_entrance=False):
            for reducer_id in (self.Context.old_reducer, self.Context.new_reducer):
                if not reducer_id:
                    continue
                reducer = sdk2.Task.find(id=reducer_id, children=True).first()
                if reducer is None or reducer.status != Status.SUCCESS:
                    raise TaskFailure("Begemot reducer failed")
            self.Context.diff_task = MiddleSearchCacheHitGuess2(
                self, description='Begemot stability check',
                begemot_answers_old=self.Context.prev_answers,
                begemot_answers_new=self.Context.new_answers + "/Merger/answers",
                full_check=self.Parameters.full_diff,
                detailed=True,
                limit=10000,
                output_path=self.Context.new_output_path + "/diff_check",
                yt_token_vault_owner=self.Parameters.yt_token_vault_owner,
                yt_token_vault_name=self.Parameters.yt_token_vault_name,
                yt_proxy=self.Parameters.yt_proxy,
                yt_pool=self.Parameters.yt_pool
            ).enqueue().id
            raise sdk2.WaitTask(self.Context.diff_task, Status.Group.FINISH | Status.Group.BREAK)

        # Stage 4: read the diff (fraction -> percent) and compare with the limit.
        with self.memoize_stage.check_diff(commit_on_entrance=False):
            diff_task = sdk2.Task.find(id=self.Context.diff_task, children=True).first()
            if diff_task is None or diff_task.status != Status.SUCCESS:
                raise TaskFailure("Diff task failed. Try to restart CACHE_GUESS child task")
            diff = float(diff_task.Parameters.answers_diff) * 100
            self.Context.answers_diff = diff
            self.set_info("Diff between old and new answers: {}%".format(diff))
            if diff > self.Parameters.acceptable_response_change:
                # NOTE(review): failing here skips the ticket stage below, so the
                # WARNING/CRITICAL levels of write_to_ticket are only reachable
                # when acceptable_response_change is at least 15 -- confirm intended.
                raise TaskFailure("Diff is too large. See more in CACHE_GUESS child task")

        # Stage 5: optionally report to the release ticket (release branches only).
        with self.memoize_stage.write_to_ticket(commit_on_entrance=False):
            if self.Parameters.write_to_ticket:
                testenv_db = self.Context.cur_branch_tag.split('-')[-1]
                if testenv_db != 'TRUNK':
                    # Read the diff from Context: after a restart the check_diff
                    # stage is skipped, so no local `diff` would exist here.
                    self.write_to_ticket(int(testenv_db), self.Context.answers_diff)
