from itertools import chain

from sandbox.projects.common.yabs.server.util.general import check_tasks
from sandbox.projects.yabs.qa.pipeline_test_framework.helpers import _launch_task
from sandbox.projects.yabs.qa.tasks.YabsServerB2BFuncShootCmp import YabsServerB2BFuncShootCmp


def _launch_ft_cmp_tasks(task, baseline_ft_shoot_tasks, test_ft_shoot_tasks, cmp_task_update_parameters=None, description=''):
    """Launch a YabsServerB2BFuncShootCmp task per (meta mode, shard) pair.

    Every shoot task from both sides is first passed to ``check_tasks``.
    Baseline tasks for all meta modes come first, then test tasks for all
    meta modes. After that, for each meta mode listed in
    ``task.Parameters.ft_meta_modes``, one comparison task is launched for
    every shard number present in BOTH the baseline and the test mapping.
    Shards that exist on only one side get no comparison task.

    :param task: parent task. Its ``Parameters.ft_meta_modes``,
        ``Parameters.description`` and ``Requirements.tasks_resource`` are read.
    :param baseline_ft_shoot_tasks: mapping ``meta_mode -> {shard_num: task}``.
        Each entry is passed as ``pre_task`` of the matching comparison.
    :param test_ft_shoot_tasks: mapping ``meta_mode -> {shard_num: task}``.
        Each entry is passed as ``test_task`` of the matching comparison.
    :param cmp_task_update_parameters: optional mapping of extra keyword
        arguments forwarded to ``_launch_task`` for every comparison task.
    :param description: text appended to each comparison task's description.
    :return: nested dict ``{meta_mode: {shard_num: <_launch_task result>}}``.
    """
    # Collect every shoot task (baseline side first, then test side) and hand
    # them to check_tasks in one call.
    # NOTE(review): check_tasks is defined elsewhere; presumably it verifies
    # the shoot tasks are finished/successful — confirm.
    shoot_tasks_to_check = []
    for side_tasks in (baseline_ft_shoot_tasks, test_ft_shoot_tasks):
        for mode in task.Parameters.ft_meta_modes:
            shoot_tasks_to_check.extend(side_tasks[mode].values())
    check_tasks(task, shoot_tasks_to_check)

    launched_by_mode = {}
    for mode in task.Parameters.ft_meta_modes:
        test_by_shard = test_ft_shoot_tasks[mode]
        baseline_by_shard = baseline_ft_shoot_tasks[mode]
        # Only shards shot on both sides can be compared.
        shared_shards = set(test_by_shard.keys()).intersection(set(baseline_by_shard.keys()))
        launched_by_mode[mode] = {
            shard: _launch_task(
                task,
                YabsServerB2BFuncShootCmp,
                description='ft cmp {} | {} {}'.format(mode, task.Parameters.description, description),
                pre_task=baseline_by_shard[shard],
                test_task=test_by_shard[shard],
                compare_statuses=True,
                __requirements__={'tasks_resource': task.Requirements.tasks_resource},
                **(cmp_task_update_parameters or {})
            )
            for shard in shared_shards
        }
    return launched_by_mode
