# -*- coding: utf-8 -*-
import logging
from datetime import timedelta
from collections import defaultdict

from sandbox import sdk2
from sandbox.common.errors import TaskFailure

from sandbox.projects.yabs.qa.tasks.YabsServerRunCSImportWrapper import YabsServerRunCSImportWrapper, ImportWrapperInputParameters
from sandbox.projects.yabs.qa.tasks.YabsServerRealRunCSImport import IMPORT_DIGEST_KEY
from sandbox.projects.yabs.qa.utils.general import get_task_html_hyperlink
from sandbox.projects.common.yabs.server.util.general import check_tasks

from sandbox.projects.common.yabs.server.tracing import TRACE_WRITER_FACTORY
from sandbox.projects.yabs.sandbox_task_tracing import trace, trace_entry_point
from sandbox.projects.yabs.sandbox_task_tracing.wrappers.sandbox.generic import enqueue_task


logger = logging.getLogger(__name__)


# This task launches several YABS_SERVER_RUN_CS_IMPORT_WRAPPER subtasks with digest calculation
# enabled and then compares the received digests to check that cs_import is stable.

class YabsServerCSImportTestStability(sdk2.Task):  # pylint: disable=R0904
    """Check that cs_import produces a stable digest across repeated runs.

    Launches ``number_of_checks`` YABS_SERVER_RUN_CS_IMPORT_WRAPPER subtasks with
    identical parameters (result reuse disabled, digest calculation enabled),
    waits for them, groups the subtasks by the digest found in their context under
    ``IMPORT_DIGEST_KEY`` and reports whether all digests are equal.
    """

    name = 'YABS_SERVER_CS_IMPORT_TEST_STABILITY'

    class Requirements(sdk2.Requirements):
        cores = 1
        ram = 1024

    class Parameters(sdk2.Task.Parameters):
        kill_timeout = timedelta(hours=10).total_seconds()

        with sdk2.parameters.Group('Test parameters') as test_parameters:
            number_of_checks = sdk2.parameters.Integer('Number of cs_import-s to launch', default_value=5, required=True)
            resource_check_mode = sdk2.parameters.Bool('Launch task only if context contains testenv_resource_check', default_value=False)
            fail_on_diff = sdk2.parameters.Bool('Raise task error if stability failed', default_value=True)

        with sdk2.parameters.Group('Subtasks parameters') as subtasks_parameters:
            cs_import_parameters = ImportWrapperInputParameters()  # reusing parameters class

        with sdk2.parameters.Group('Developer parameters') as dev_params:
            arcadia_patch = sdk2.parameters.String('Arcadia patch', default='')

    class Context(sdk2.Task.Context):
        tasks_by_digests = {}  # digest: [task_id, task_id, ...]
        testenv_resource_check = False
        report = ''

    def _generate_report_by_digests(self, message):
        """Append ``message`` plus one line per digest group to the report and task info.

        Reads the grouping from ``self.Context.tasks_by_digests``, so it must be set first.
        """
        report = [message, '']
        for digest, tasks in self.Context.tasks_by_digests.items():
            report.append('Digest: {0} | Tasks: {1}'.format(digest, tasks))
        report_message = '\n'.join(report)
        self.Context.report += report_message
        self.set_info(report_message)

    @staticmethod
    def _select_bin_db_tags(bin_db_list):
        """Return unique DB tags from a whitespace-separated string, skipping ``st_update*`` ones.

        Order of first occurrence is kept so the ``bin_db_list`` passed to subtasks is
        deterministic (a plain ``set`` would give a hash-dependent order).
        """
        seen = set()
        db_tags = []
        for db in (bin_db_list or '').split():
            if db.startswith('st_update') or db in seen:
                continue
            seen.add(db)
            db_tags.append(db)
        return db_tags

    def _launch_cs_import_subtasks(self, db_tags):
        """Create and enqueue ``number_of_checks`` identical cs_import wrapper subtasks.

        :param db_tags: DB tags passed to every subtask as ``bin_db_list``.
        The created subtask ids are recorded in ``self.Context.subtask_ids``.
        """
        logger.info('Launching %s subtasks', self.Parameters.number_of_checks)

        subtask_parameters = dict(self.Parameters.cs_import_parameters)
        subtask_parameters.update(
            reuse_import_results=False,  # can't reuse here
            calc_digest=True,  # ensure digest calculation
            wait_digest=True,  # wait for digest calculation
            bin_db_list=' '.join(db_tags),
        )
        self.Context.subtask_ids = []
        for _ in range(self.Parameters.number_of_checks):
            subtask = YabsServerRunCSImportWrapper(
                self,
                description='Launched for stability test (subtask of {0})'.format(get_task_html_hyperlink(self.id)),
                owner=self.owner,
                tags=self.Parameters.tags,
                __requirements__={'tasks_resource': self.Requirements.tasks_resource},
                **subtask_parameters
            )
            enqueue_task(subtask.save())
            self.Context.subtask_ids.append(subtask.id)

    @staticmethod
    def _group_subtasks_by_digest(subtask_ids):
        """Map each digest read from a subtask's context to the ids of subtasks that reported it.

        Subtasks whose context lacks ``IMPORT_DIGEST_KEY`` are grouped under ``None``.
        """
        tasks_by_digests = defaultdict(list)
        for subtask_id in subtask_ids:
            subtask_digest = getattr(sdk2.Task[subtask_id].Context, IMPORT_DIGEST_KEY, None)
            tasks_by_digests[subtask_digest].append(subtask_id)
        return dict(tasks_by_digests)

    def _check_digests(self, tasks_by_digests):
        """Report the stability verdict for ``tasks_by_digests``.

        :raises TaskFailure: if some digest is missing or digests differ and ``fail_on_diff`` is set.
        """
        if None in tasks_by_digests:
            message = 'No digest was received from tasks: {0}'.format(tasks_by_digests[None])
        elif len(tasks_by_digests) != 1:
            message = 'Different digests was received from tasks!'
        else:
            # next(iter(...)) works on both Python 2 and 3; dict views are not indexable in Python 3.
            digest = next(iter(tasks_by_digests))
            message = 'Test passed, all digests are equal to {0}'.format(digest)
            self.Context.report += message
            self.set_info(message)
            return

        self._generate_report_by_digests(message)
        if self.Parameters.fail_on_diff:
            raise TaskFailure(message)

    @trace_entry_point(writer_factory=TRACE_WRITER_FACTORY)
    def on_execute(self):
        """Launch the subtasks once, wait for them and compare their digests.

        NOTE(review): ``check_tasks`` presumably suspends the task until the subtasks
        finish, so this method may be re-entered; the memoize stage keeps the subtasks
        from being launched twice.

        :raises TaskFailure: if fewer than two subtasks are requested, or if the digests
            are missing/different and ``fail_on_diff`` is set.
        """
        if 'TESTENV-PRECOMMIT-CHECK' in self.Parameters.tags and 'WITHOUT-PATCH' in self.Parameters.tags:
            self.set_info('Do nothing in precommit check run without patch')
            return

        if int(self.Parameters.number_of_checks) < 2:
            # Comparing digests is meaningless with fewer than two runs.
            raise TaskFailure('Test needs to be launched with at least 2 subtasks!')

        db_tags = self._select_bin_db_tags(self.Parameters.bin_db_list)

        with self.memoize_stage.launch_subtasks(commit_on_entrance=False), trace('launch_subtasks'):
            self._launch_cs_import_subtasks(db_tags)

        logger.info('Waiting for subtasks')
        check_tasks(self, self.Context.subtask_ids)

        logger.info('Reading received digests')
        tasks_by_digests = self._group_subtasks_by_digest(self.Context.subtask_ids)
        self.Context.tasks_by_digests = tasks_by_digests

        logger.info('Checking received digests and generating report')
        self._check_digests(tasks_by_digests)
