from itertools import chain

from sandbox import sdk2
from sandbox.projects.common.yabs.server.util.general import check_tasks
from sandbox.projects.yabs.qa.utils import task_run_type
from sandbox.projects.yabs.qa.pipeline.stage import stage
from sandbox.projects.yabs.qa.pipeline_test_framework.helpers import get_stat_base_tags_to_create, get_meta_base_tags_to_create
from sandbox.projects.yabs.qa.spec.misc import filter_not_ok_tasks
from sandbox.projects.yabs.qa.tasks.YabsServerCreateOneShotSpec.spec import OneShotSpec
from sandbox.projects.yabs.qa.tasks.YabsServerBaseSizeAggregate import YabsServerBaseSizeAggregate


def launch_make_bin_bases_task(
        task,
        gen_bin_bases_flags_resource_id,
        mysql_archive_resource_id, cs_input_spec_resource_id, cs_settings_archive_resource_id,
        bs_release_yt_resource_id, bs_release_tar_resource_id,
        cs_import_ver,
        base_tags_to_create,
):
    """Enqueue a YABS_SERVER_MAKE_BIN_BASES subtask for the given base tags.

    :param task: parent sdk2 task; its ``Parameters.tags`` (plus the
        ``create_oneshot_spec`` run-type tag) are attached to the subtask.
    :param gen_bin_bases_flags_resource_id: resource whose optional
        ``add_options`` attribute (stripped) is passed as ``options_list``;
        an empty string is used when the attribute is absent.
    :param mysql_archive_resource_id: passed as ``mysql_archive_contents``.
    :param cs_input_spec_resource_id: passed as ``input_spec``.
    :param cs_settings_archive_resource_id: passed as ``settings_archive``.
    :param bs_release_yt_resource_id: passed as ``bs_release_yt_resource``.
    :param bs_release_tar_resource_id: passed as ``server_resource``.
    :param cs_import_ver: passed as ``cs_import_ver``.
    :param base_tags_to_create: mapping of subtask parameter name -> iterable
        of base tags. It is forwarded as keyword parameters, and the union of
        all its tags becomes ``bin_db_list``.
    :return: id of the enqueued subtask.
    """
    add_base_options = getattr(sdk2.Resource[gen_bin_bases_flags_resource_id], 'add_options', None)
    options_list = add_base_options.strip() if add_base_options is not None else ''

    # Deduplicate tags across all groups and sort them so that bin_db_list is
    # reproducible between runs (plain set iteration order is arbitrary).
    bin_db_list = ' '.join(sorted(set(chain.from_iterable(base_tags_to_create.values()))))

    make_bin_bases_task = sdk2.Task['YABS_SERVER_MAKE_BIN_BASES'](
        task,
        description='Make all bases, oneshot spec generation',
        input_spec=cs_input_spec_resource_id,
        settings_archive=cs_settings_archive_resource_id,
        mysql_archive_contents=mysql_archive_resource_id,
        bs_release_yt_resource=bs_release_yt_resource_id,
        server_resource=bs_release_tar_resource_id,
        options_list=options_list,
        cs_import_ver=cs_import_ver,
        bin_db_list=bin_db_list,
        use_cs_cycle=True,
        do_not_restart=True,
        glue_reduce=True,
        reuse_existing_bases=True,
        use_save_input_from_cs=False,
        filter_input_archive_tables_by_orderid=True,
        tags=task.Parameters.tags + [task_run_type.SANDBOX_TASK_TAGS['create_oneshot_spec']],
        **base_tags_to_create
    ).enqueue()
    return make_bin_bases_task.id


@stage(provides=('stat_make_bin_bases_task_id', 'meta_make_bin_bases_task_id'))
def launch_make_bin_bases(
        task,
        gen_bin_bases_flags_resource_id,
        mysql_archive_resource_id, cs_input_spec_resource_id, cs_settings_archive_resource_id,
        stat_bs_release_yt_resource_id, stat_bs_release_tar_resource_id,
        meta_bs_release_yt_resource_id, meta_bs_release_tar_resource_id,
        stat_cs_import_ver,
        meta_cs_import_ver,
        stat_base_tags_map,
        meta_base_tags_map,
        ft_shard_keys,
        stat_load_shard_keys,
        meta_load_shard_keys,
):
    """Launch separate make-bin-bases subtasks for the stat and meta bases.

    Stat base tags are selected for the union of functional-test and
    load-test shard keys. Meta base tags come from ``meta_base_tags_map``.
    Each subtask uses its own release resources and ``cs_import_ver``.

    :return: ``(stat_make_bin_bases_task_id, meta_make_bin_bases_task_id)``.
    """
    # Deduplicate shard keys while keeping first-occurrence order, so the
    # result is deterministic (list(set(...)) yields an arbitrary order).
    shard_keys = []
    seen_shard_keys = set()
    for shard_key in chain(ft_shard_keys, stat_load_shard_keys, meta_load_shard_keys):
        if shard_key not in seen_shard_keys:
            seen_shard_keys.add(shard_key)
            shard_keys.append(shard_key)

    stat_base_tags_to_create = get_stat_base_tags_to_create(stat_base_tags_map, shard_keys)
    meta_base_tags_to_create = get_meta_base_tags_to_create(meta_base_tags_map)

    stat_task_id = launch_make_bin_bases_task(
        task,
        gen_bin_bases_flags_resource_id,
        mysql_archive_resource_id, cs_input_spec_resource_id, cs_settings_archive_resource_id,
        stat_bs_release_yt_resource_id, stat_bs_release_tar_resource_id,
        stat_cs_import_ver,
        stat_base_tags_to_create,
    )
    meta_task_id = launch_make_bin_bases_task(
        task,
        gen_bin_bases_flags_resource_id,
        mysql_archive_resource_id, cs_input_spec_resource_id, cs_settings_archive_resource_id,
        meta_bs_release_yt_resource_id, meta_bs_release_tar_resource_id,
        meta_cs_import_ver,
        meta_base_tags_to_create,
    )
    return stat_task_id, meta_task_id


def get_binary_bases_from_task(make_bin_bases_task_id):
    """Return a ``{resource.tag: resource.id}`` map of a task's binary bases.

    Reads the resource ids stored in the task's ``Context.bin_base_res_ids``
    and looks them up via ``sdk2.Resource.find``. If several resources share
    a tag, the last one returned by the query wins.

    :param make_bin_bases_task_id: id of a finished make-bin-bases task.
    """
    resource_ids = sdk2.Task[make_bin_bases_task_id].Context.bin_base_res_ids
    # The limit equals the number of requested ids, presumably so the query is
    # not truncated by a default page size.
    found_resources = sdk2.Resource.find(id=resource_ids).limit(len(resource_ids))

    resource_id_by_tag = {}
    for found in found_resources:
        resource_id_by_tag[found.tag] = found.id
    return resource_id_by_tag


@stage(provides=('stat_binary_base_resource_id_by_tag', 'meta_binary_base_resource_id_by_tag'))
def get_binary_bases(
        self,
        stat_make_bin_bases_task_id,
        meta_make_bin_bases_task_id,
):
    """Collect tag -> resource id maps for the stat and meta binary bases.

    Both make-bin-bases subtasks are first passed to ``check_tasks``.
    NOTE(review): presumably it waits for or validates them; confirm against
    ``yabs.server.util.general``.

    :return: ``(stat_map, meta_map)``, as built by ``get_binary_bases_from_task``.
    """
    make_bin_bases_task_ids = (stat_make_bin_bases_task_id, meta_make_bin_bases_task_id)
    check_tasks(self, list(make_bin_bases_task_ids))
    return tuple(
        get_binary_bases_from_task(make_bin_bases_task_id)
        for make_bin_bases_task_id in make_bin_bases_task_ids
    )


@stage(provides=('stat_chkdb_task_id', 'meta_chkdb_task_id'))
def launch_chkdb(
        task,
        ft_shard_keys,
        stat_base_tags_map,
        meta_base_tags_map,
        stat_binary_base_resource_id_by_tag,
        meta_binary_base_resource_id_by_tag,
):
    """Enqueue YabsServerBaseSizeAggregate subtasks for stat and meta bases.

    For each meta mode ('bs', 'bsrank', 'yabs'):
      * meta: every tag in ``meta_base_tags_map['base_tags_meta_<mode>']``
        (key must exist) is resolved to its resource id;
      * stat: tags from ``stat_base_tags_map['base_tags_stat_<mode>_<shard>']``
        for each functional-test shard plus 'COMMON' (missing keys are
        skipped) are resolved to resource ids, in shard order.
    The resulting lists are passed as ``base_resources_<stat|meta>_<mode>``.

    :return: ``(stat_chkdb_task_id, meta_chkdb_task_id)``.
    """
    meta_modes = ('bs', 'bsrank', 'yabs')

    # Meta bases are resolved first, as in the original ordering.
    meta_bases = {}
    for mode in meta_modes:
        mode_tags = meta_base_tags_map['base_tags_meta_{meta_mode}'.format(meta_mode=mode)]
        meta_bases['base_resources_meta_{meta_mode}'.format(meta_mode=mode)] = [
            meta_binary_base_resource_id_by_tag[tag] for tag in mode_tags
        ]

    stat_bases = {}
    for mode in meta_modes:
        mode_resource_ids = []
        for shard in ft_shard_keys + ['COMMON']:
            tags_key = 'base_tags_stat_{meta_mode}_{shard}'.format(meta_mode=mode, shard=shard)
            for tag in stat_base_tags_map.get(tags_key, []):
                mode_resource_ids.append(stat_binary_base_resource_id_by_tag[tag])
        stat_bases['base_resources_stat_{meta_mode}'.format(meta_mode=mode)] = mode_resource_ids

    def enqueue_aggregate(description, base_resources):
        # One aggregate subtask per base kind, sharing the parent's tags.
        return YabsServerBaseSizeAggregate(
            task,
            tags=task.Parameters.tags,
            description=description,
            **base_resources
        ).enqueue().id

    stat_chkdb_task_id = enqueue_aggregate('stat chkdb | oneshot spec generation', stat_bases)
    meta_chkdb_task_id = enqueue_aggregate('meta chkdb | oneshot spec generation', meta_bases)
    return stat_chkdb_task_id, meta_chkdb_task_id


@stage(provides=('broken_subtasks', 'failed_subtasks'), force_run=True)
def check_subtasks_status(
        task,
        stat_chkdb_task_id,
        meta_chkdb_task_id,
        ft_shoot_tasks,
        ft_validation_tasks,
        ft_stability_shoot_tasks,
        meta_load_baseline_tasks,
        stat_load_baseline_tasks,
):
    """Gather subtask ids launched by the pipeline and filter out the not-ok ones.

    Always checked: functional shoot and validation tasks, plus both chkdb tasks.
    Also checked, when the matching task parameter is set:
      * ``sanity_check``: stability shoot tasks;
      * ``stat_load``: stat load baseline tasks;
      * ``meta_load``: meta load baseline tasks.

    :return: result of ``filter_not_ok_tasks`` (declared by the stage as
        ``broken_subtasks``, ``failed_subtasks``).
    """
    def nested_task_ids(*mappings):
        # Each mapping is two levels deep ({key: {key: task_id}}, judging by the
        # .values().values() access); yield every leaf value in order.
        return chain.from_iterable(
            inner.values()
            for mapping in mappings
            for inner in mapping.values()
        )

    # chain() instead of `a.values() + b.values()`: dict views cannot be
    # concatenated with `+` on Python 3. The result is identical on Python 2.
    tasks_to_check = list(nested_task_ids(ft_shoot_tasks, ft_validation_tasks))
    tasks_to_check.extend([stat_chkdb_task_id, meta_chkdb_task_id])

    if task.Parameters.sanity_check:
        tasks_to_check.extend(nested_task_ids(ft_stability_shoot_tasks))

    if task.Parameters.stat_load:
        tasks_to_check.extend(nested_task_ids(stat_load_baseline_tasks))

    if task.Parameters.meta_load:
        tasks_to_check.extend(meta_load_baseline_tasks.values())

    return filter_not_ok_tasks(task, tasks_to_check)


@stage(provides='spec')
def get_spec_data(
        task,
        stat_bs_release_tar_resource_id,
        stat_bs_release_yt_resource_id,
        meta_bs_release_tar_resource_id,
        meta_bs_release_yt_resource_id,
        cs_input_spec_resource_id,
        cs_settings_archive_resource_id,
        dolbilka_plan_resource_id_map,
        ft_request_log_resource_id_map,
        ft_shoot_settings_resource_id,
        ft_shoot_tasks,
        ft_validation_tasks,
        gen_bin_bases_flags_resource_id,
        linear_models_binary_resource_id,
        load_shoot_settings_resource_id,
        meta_load_baseline_tasks,
        meta_load_shoot_settings_resource_id,
        mysql_archive_resource_id,
        stat_setup_ya_make_task_id,
        meta_setup_ya_make_task_id,
        shard_map_resource_id,
        stat_load_baseline_tasks,
        stat_load_request_log_resource_id_map,
        meta_load_request_log_resource_id_map,
        ext_service_endpoint_resource_ids,
        use_separated_meta_and_stat=True,
):
    """Build the one-shot spec as a dict from the parent task's context.

    Only ``task`` and ``use_separated_meta_and_stat`` are read in the body.
    NOTE(review): the other parameters are unused here; presumably they are
    declared so that the ``stage`` framework runs the stages providing them
    first. Confirm before removing any of them.

    :return: ``OneShotSpec(task.Context, ...).as_dict()``.
    """
    return OneShotSpec(
        task.Context,
        use_separated_meta_and_stat=use_separated_meta_and_stat,
    ).as_dict()
