from sandbox import sdk2
from sandbox.common.types import task as ctt
from sandbox.projects.yabs.qa.mutable_parameters import MutableParameters
from sandbox.projects.yabs.qa.tasks.base_compare_task.parameters import BaseCompareTaskParameters
from sandbox.projects.yabs.qa.tasks.base_compare_task.task import BaseCompareTask
from sandbox.projects.yabs.base_bin_task import BaseBinTaskMixin


class EvenInteger(sdk2.parameters.Integer):
    """Integer task parameter that only accepts even values."""

    @classmethod
    def cast(cls, value):
        """Convert ``value`` to ``int`` and check that it is even.

        :param value: raw parameter value; anything ``int()`` accepts.
        :returns: the value as an ``int``.
        :raises ValueError: if the converted value is odd (``int()`` itself may
            also raise ``ValueError``/``TypeError`` for unconvertible input).
        """
        value = int(value)
        # An explicit raise instead of ``assert``: asserts are stripped when
        # Python runs with -O, which would silently let odd values through.
        if value % 2 != 0:
            raise ValueError('Parameter must be even')
        return value


def factory(
    shoot_2on1_task_class,
    subtasks_aggregator_class
):
    """Build a compare task class that runs paired shoots and aggregates them.

    The generated task enqueues ``subtasks_count`` subtasks of
    ``shoot_2on1_task_class``. Even-indexed (0-based) subtasks receive the
    ``pre_task`` parameters as the first run and the ``test_task`` parameters
    as the second run ("Pre vs test"). Odd-indexed subtasks receive them in
    reverse order ("Test vs pre"). The task then waits via ``sdk2.WaitTask``
    until every subtask reaches a FINISH or BREAK status. On the following
    ``on_execute`` call it builds a report with ``subtasks_aggregator_class``
    and stores ``has_diff`` in both its context and ``Parameters.has_diff``.

    :param shoot_2on1_task_class: sdk2 task class used for the subtasks. It must
        provide ``CommonParameters()`` and accept the merged parameters as
        keyword arguments.
    :param subtasks_aggregator_class: class providing
        ``get_init_parameters_class()``. It is constructed from the task's
        ``Parameters`` and its ``make_report(subtasks_data)`` must return
        ``(report, has_diff)``.
    :returns: the generated task class.
    """
    class TaskParameters(BaseCompareTaskParameters):
        push_tasks_resource = True
        with sdk2.parameters.Group('Compare task settings') as cmp_options:
            # Must be even (enforced by EvenInteger), presumably so both run
            # orders get the same number of subtasks.
            subtasks_count = EvenInteger('Number of subtasks', default_value=16)

        # Parameter groups borrowed from the subtask and aggregator classes.
        common_parameters = shoot_2on1_task_class.CommonParameters()
        subtasks_aggregator_parameters = subtasks_aggregator_class.get_init_parameters_class()()

    class TaskContext(sdk2.Task.Context):
        # Result of the aggregated comparison, set on the reporting run.
        has_diff = False
        # (subtask_id, shuffled) pairs. It is empty until subtasks are enqueued,
        # which is how on_execute tells the first run from the reporting run.
        subtasks_data = []

    class TaskClass(BaseBinTaskMixin, BaseCompareTask):
        Parameters = TaskParameters
        Context = TaskContext

        def get_merged_parameters(self, parameters_list):
            """Merge the parameters of two runs into one parameter set.

            Starts from ``common_parameters``. Values of the first run are
            stored under their own names and values of the second run under
            ``<name>_2``.

            :param parameters_list: pair ``(first_run, second_run)`` of
                parameter objects; iterating each yields ``(name, value)``.
            :returns: a ``MutableParameters`` instance.
            """
            first_run_parameters, second_run_parameters = parameters_list

            parameters = MutableParameters.__from_parameters__(self.Parameters.common_parameters)
            for name, value in first_run_parameters:
                parameters.__dict__[name] = value

            for name, value in second_run_parameters:
                parameters.__dict__[name + '_2'] = value

            return parameters

        def enqueue_subtasks(self):
            """Create, save and enqueue ``subtasks_count`` shoot subtasks.

            Each ``(subtask_id, shuffled)`` pair is appended to
            ``Context.subtasks_data`` as soon as its subtask is enqueued.

            :returns: list of the created subtask objects.
            """
            parameters_list = [
                self.Parameters.pre_task.Parameters,
                self.Parameters.test_task.Parameters,
            ]
            subtasks_count = self.Parameters.subtasks_count

            # Rebind to a fresh list so appends below never touch the
            # class-level default ``[]`` declared on TaskContext.
            self.Context.subtasks_data = list(self.Context.subtasks_data)

            subtasks_list = []
            for index in range(subtasks_count):
                # Alternate which build is the first run, presumably to even
                # out any ordering effect between pre and test.
                shuffle_parameters = index % 2 == 1
                description = 'Subtask #{} of {}. {}'.format(
                    index + 1,
                    subtasks_count,
                    'Test vs pre' if shuffle_parameters else 'Pre vs test'
                )

                subtask_parameters_list = parameters_list[::-1] if shuffle_parameters else parameters_list
                subtask_parameters = self.get_merged_parameters(subtask_parameters_list)

                subtask = shoot_2on1_task_class(
                    self,
                    description=description,
                    **dict(subtask_parameters)
                )

                subtask.save()
                subtask.enqueue()

                subtasks_list.append(subtask)
                self.Context.subtasks_data.append((subtask.id, shuffle_parameters))

            return subtasks_list

        def wait_subtasks(self, subtasks_list):
            """Suspend until every given subtask is in a FINISH or BREAK status.

            Always raises ``sdk2.WaitTask``.
            """
            statuses_to_wait = ctt.Status.Group.FINISH | ctt.Status.Group.BREAK
            # Materialize the list: on Python 3 ``filter`` returns a lazy,
            # one-shot iterator rather than a list.
            pending_subtasks = [
                task for task in subtasks_list
                if task.status not in statuses_to_wait
            ]
            raise sdk2.WaitTask(
                pending_subtasks,
                statuses_to_wait,
                wait_all=True
            )

        def on_execute(self):
            """Run one step of the compare task.

            The first step enqueues the subtasks and waits for them. It stops
            early if ``check_tasks_parameters()`` returns a truthy value. Later
            steps aggregate the subtask results into a report.
            """
            if not self.Context.subtasks_data:
                if self.check_tasks_parameters():
                    return

                subtasks = self.enqueue_subtasks()
                self.wait_subtasks(subtasks)
            else:
                subtasks_aggregator = subtasks_aggregator_class(self.Parameters)

                report, has_diff = subtasks_aggregator.make_report(self.Context.subtasks_data)
                self.set_info(report, do_escape=False)

                # Memoized stage: the result is published once, even if
                # on_execute runs again.
                with self.memoize_stage.set_output:
                    self.Context.has_diff = self.Parameters.has_diff = has_diff

    return TaskClass
