import logging

from datetime import timedelta

from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.common.types.task import ReleaseStatus, Status
from sandbox.common.utils import get_task_link
from sandbox.projects.common.yabs.server.util.general import check_tasks, CustomAssert
from sandbox.projects.common.yabs.server.util import truncate_output_parameters

from sandbox.projects.yabs.base_bin_task import BaseBinTaskMixin, base_bin_task_parameters
from sandbox.projects.yabs.qa.errorbooster.decorators import track_errors
from sandbox.projects.yabs.qa.tasks.YabsServerB2BFuncShoot2 import (
    YabsServerB2BFuncShoot2,
    YabsServerB2BFuncShoot2Parameters,
    calc_input_parameters_hash,
)
from sandbox.projects.yabs.qa.tasks.YabsServerB2BFuncShootCmp import YabsServerB2BFuncShootCmp
from sandbox.projects.yabs.qa.tasks.YabsServerUploadShootResultToYt import YTUploadParameters

from sandbox.projects.common.yabs.server.tracing import TRACE_WRITER_FACTORY
from sandbox.projects.yabs.sandbox_task_tracing import trace_entry_point
from sandbox.projects.yabs.sandbox_task_tracing.wrappers.sandbox.generic import enqueue_task


# Module-level logger named after this module; used for debug output of the task below.
logger = logging.getLogger(__name__)


class YabsServerB2BFuncShootStability(BaseBinTaskMixin, sdk2.Task):

    '''
    New task for functional b2b stability tests of yabs-server service.

    Creates ``stability_runs`` YabsServerB2BFuncShoot2 child tasks with the same
    parameters (the first one may be replaced by an already existing task found
    by input parameters hash), then creates YabsServerB2BFuncShootCmp child tasks
    comparing every shoot task with the next one (the last with the first).

    The task fails when more than one compare task ends in FAILURE or a BREAK
    status, or when any successfully finished compare task has ``has_diff`` set
    in its context.
    '''

    name = 'YABS_SERVER_B2B_FUNC_SHOOT_STABILITY'

    class Requirements(sdk2.Requirements):
        cores = 1
        ram = 1024

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Task.Parameters):
        kill_timeout = int(timedelta(minutes=20).total_seconds())
        push_tasks_resource = True

        _base_bin_task_parameters = base_bin_task_parameters(
            release_version_default=ReleaseStatus.STABLE,
            resource_attrs_default={"task_bundle": "yabs_server_shoot"},
        )

        # Both parameter groups are forwarded to every child shoot task.
        func_shoot_parameters = YabsServerB2BFuncShoot2Parameters()
        yt_upload_parameters = YTUploadParameters()
        with sdk2.parameters.Group('Stability test parameters') as stability_parameters:
            stability_runs = sdk2.parameters.Integer('Number of stability runs', default_value=4)
            reuse_shoot_task = sdk2.parameters.Bool('Reuse one of shoot tasks by input parameters hash', default=True)

    class Context(sdk2.Task.Context):
        # Ids of shoot tasks (created or reused), in run order.
        shoot_tasks = []
        # Ids of compare tasks; cmp_tasks[i] compares shoot_tasks[i] with shoot_tasks[(i + 1) % runs].
        cmp_tasks = []
        # Ids of compare tasks seen in SUCCESS status.
        finished_cmp_tasks = []
        # Ids of compare tasks seen in FAILURE or one of the BREAK statuses.
        failed_cmp_tasks = []
        # Ids of compare tasks that were in neither of the states above at the last check.
        cmp_tasks_to_wait = []

    @track_errors
    @trace_entry_point(writer_factory=TRACE_WRITER_FACTORY)
    def on_execute(self):
        '''
        Run shoot tasks, compare them pairwise and fail on instability.

        Child tasks are created only when the corresponding context list is
        empty, so the already created children are reused when this method
        runs again with a populated context.

        Raises TaskFailure (via report_failure) when more than one compare task
        failed or when a finished compare task reports a diff.
        '''
        input_parameters_hash = calc_input_parameters_hash(self.server, self.id)
        logger.debug('Input parameters hash is %s', input_parameters_hash)

        CustomAssert(self.Parameters.stability_runs >= 2, 'Invalid stablity runs parameters, should be >= 2', TaskFailure)
        common_task_params = {
            'tags': self.Parameters.tags,
            'hints': list(self.hints),
            'response_dumps_ttl': 1,
        }
        if not self.Context.shoot_tasks:
            self._launch_shoot_tasks(common_task_params, input_parameters_hash)

        # NOTE(review): check_tasks presumably suspends this task until the shoot
        # tasks are done and fails on broken children -- confirm against the helper.
        check_tasks(self, self.Context.shoot_tasks)

        if not self.Context.cmp_tasks:
            self._launch_cmp_tasks(common_task_params)

        self._wait_for_cmp_tasks()

        cmp_task_ids_with_diff = [
            cmp_task_id
            for cmp_task_id in self.Context.finished_cmp_tasks
            if sdk2.Task[cmp_task_id].Context.has_diff
        ]
        if cmp_task_ids_with_diff:
            self.report_failure(cmp_task_ids_with_diff, 'found unexpected diff')

    def _launch_shoot_tasks(self, common_task_params, input_parameters_hash):
        '''
        Fill Context.shoot_tasks with one shoot task id per stability run.

        :param common_task_params: dict of parameters applied to every child task
            (overrides same-named shoot parameters)
        :param input_parameters_hash: hash used to look up a reusable shoot task
        '''
        # Every run uses exactly the same parameters, so build them once.
        shoot_task_parameters = dict(self.Parameters.func_shoot_parameters)
        shoot_task_parameters.update(dict(self.Parameters.yt_upload_parameters))
        shoot_task_params = dict(
            description='Stability run launched by task #{}'.format(self.id),
            **truncate_output_parameters(shoot_task_parameters, YabsServerB2BFuncShoot2.Parameters)
        )
        shoot_task_params.update(common_task_params)

        for idx in range(self.Parameters.stability_runs):
            # Only the first run may be substituted by an existing task.
            if self.Parameters.reuse_shoot_task and idx == 0:
                shoot_task = YabsServerB2BFuncShoot2.find(
                    input_parameters={'allow_reuse_of_this_task': True},
                    output_parameters={'input_parameters_hash': input_parameters_hash},
                    hidden=True,
                    children=True,
                ).limit(1).first()
                if shoot_task:
                    self.Context.shoot_tasks.append(shoot_task.id)
                    self.set_info(
                        'Will reuse <a href="{task_link}" target="_blank">{task.type} #{task.id}</a>'.format(
                            task_link=get_task_link(shoot_task.id),
                            task=shoot_task,
                        ),
                        do_escape=False,
                    )
                    continue

            shoot_task = YabsServerB2BFuncShoot2(self, **shoot_task_params)
            enqueue_task(shoot_task)
            self.Context.shoot_tasks.append(shoot_task.id)

    def _launch_cmp_tasks(self, common_task_params):
        '''
        Fill Context.cmp_tasks with compare tasks forming a ring over the shoot tasks.

        :param common_task_params: dict of parameters applied to every child task
        '''
        for cmp_index in range(self.Parameters.stability_runs):
            cmp_task = YabsServerB2BFuncShootCmp(
                self,
                description='Stability cmp launched by task #{}'.format(self.id),
                pre_task=self.Context.shoot_tasks[cmp_index],
                # Wrap around so that the last shoot task is compared with the first one.
                test_task=self.Context.shoot_tasks[(cmp_index + 1) % self.Parameters.stability_runs],
                **common_task_params
            )
            enqueue_task(cmp_task)
            self.Context.cmp_tasks.append(cmp_task.id)

    def _wait_for_cmp_tasks(self):
        '''
        Sort compare tasks into finished / failed / pending until enough are finished.

        Returns as soon as at least ``stability_runs - 1`` compare tasks are in
        SUCCESS status or no pending compare tasks are left. Raises TaskFailure
        (via report_failure) when more than one compare task failed.
        '''
        # Start from the full list every time: statuses are re-read below anyway.
        self.Context.cmp_tasks_to_wait = self.Context.cmp_tasks

        while self.Context.cmp_tasks_to_wait:
            subtasks = [sdk2.Task[task_id] for task_id in self.Context.cmp_tasks_to_wait]
            self.Context.cmp_tasks_to_wait = []
            for task in subtasks:
                logger.debug('Task %d is in status %s', task.id, task.status)
                if task.status == Status.SUCCESS:
                    self.Context.finished_cmp_tasks = self._with_task_id(self.Context.finished_cmp_tasks, task.id)
                elif task.status in Status.Group.BREAK | {Status.FAILURE}:
                    self.Context.failed_cmp_tasks = self._with_task_id(self.Context.failed_cmp_tasks, task.id)
                else:
                    self.Context.cmp_tasks_to_wait = self._with_task_id(self.Context.cmp_tasks_to_wait, task.id)

            # A single failed compare task is tolerated, two or more are not.
            if len(self.Context.failed_cmp_tasks) > 1:
                self.report_failure(self.Context.failed_cmp_tasks, 'failed')
            if len(self.Context.finished_cmp_tasks) >= self.Parameters.stability_runs - 1:
                return

            # NOTE(review): presumably waits for any of the pending tasks to
            # change state without failing on broken ones -- confirm against check_tasks.
            check_tasks(self, self.Context.cmp_tasks_to_wait, wait_all=False, raise_on_fail=False)

    @staticmethod
    def _with_task_id(task_ids, task_id):
        '''Return a new list: task_ids plus task_id appended unless already present (order kept).'''
        known_ids = list(task_ids)
        if task_id not in known_ids:
            known_ids.append(task_id)
        return known_ids

    def report_failure(self, tasks, desc):
        '''
        Publish an info message with links to the given tasks and fail this task.

        :param tasks: ids of compare tasks to mention in the message
        :param desc: short description of what went wrong, used both in the
            info message and in the exception text
        :raises TaskFailure: always
        '''
        self.set_info(
            '{count} stability compare tasks {desc}, see child tasks {tasks} for more info'.format(
                desc=desc,
                count=len(tasks),
                tasks=", ".join([
                    "<a href=\"{task_link}\" target=\"_blank\">{task_id}</a>".format(
                        task_id=task_id,
                        task_link=get_task_link(task_id),
                    ) for task_id in tasks
                ]),
            ),
            # The message contains HTML links, so it must not be escaped.
            do_escape=False,
        )
        raise TaskFailure('Stability run {}'.format(desc))
