# -*- coding: utf-8 -*-
import copy
import itertools
import json
import hashlib
import logging
import re
import six
from collections import namedtuple, defaultdict

from sandbox import sdk2
from sandbox.common import rest
from sandbox.common.types.task import Status as TaskStatus
from sandbox.common.types.client import Tag
from sandbox.common.errors import TaskFailure, VaultError

from sandbox.sandboxsdk.environments import SvnEnvironment, PipEnvironment
from sandbox.sandboxsdk.errors import SandboxTaskUnknownError
from sandbox.sandboxsdk.parameters import SandboxBoolParameter, ResourceSelector
from sandbox.sandboxsdk.svn import Arcadia

from sandbox.projects.yabs.qa import resource_types
from sandbox.projects.common.utils import get_or_default

from sandbox.projects.common.yabs.server.db import yt_bases
from sandbox.projects.common.yabs.server.db.task.cs import MySQLArchiveContents, UseSaveInputFromCS, SaveAllInputs, YtPool
import sandbox.projects.common.yabs.server.db.task.basegen as dbtask
import sandbox.projects.common.yabs.server.db.utils as dbutils
from sandbox.projects.common.yabs.server.util.general import check_tasks

from sandbox.projects.yabs.bases import base_producing_task, YabsServerBasesFetch, YabsServerBasesReduce
from sandbox.projects.yabs.qa.tasks.YabsServerRunCSImportWrapper import YabsServerRunCSImportWrapper
from sandbox.projects.common.yabs.server.db.task.basegen import (
    FETCH_PREFIX_KEY,
    GENERATION_ID_KEY,
    BaseTagImporterSettingsVersion,
    BaseTagImporterCodeVersion,
    BaseTagMkdbInfoVersion,
)

from sandbox.projects.yabs.qa.bases.sample_tables.parameters import SamplingStrategyParameter
from sandbox.projects.yabs.qa.resource_types import YABS_CS_IMPORT_OUT_LOCATION
from sandbox.projects.yabs.qa.tasks.YabsServerRealRunCSImport import ReuseImportResults, UseCsCycle
from sandbox.projects.yabs.qa.utils import task_run_type
from sandbox.projects.yabs.qa.utils.subtasks import split_subtasks, check_and_handle_subtask_failure, set_subtasks_info

from sandbox.projects.yabs.qa.solomon.mixin import SolomonTaskMixinParameters
from sandbox.projects.yabs.qa.solomon.push_client import SolomonPushClient

from .fetch_splitter import FetchSplitter
from .helpers import (
    base_mkdb_info_version,
    base_requires_mysql,
    base_to_importer_code_version,
    base_to_importer_settings_version,
    filter_mkdb_info_version_by_bases,
    get_importer_queries,
    intersect_dicts,
    lock_for_search,
    make_child_descr,
)
from .report import create_base_generation_report

from sandbox.projects.common.yabs.server.tracing import TRACE_WRITER_FACTORY
from sandbox.projects.yabs.sandbox_task_tracing import trace_calls, trace_entry_point
from sandbox.projects.yabs.sandbox_task_tracing.wrappers.sandbox.generic import enqueue_task


# Module logger; logging.getLogger() without a name returns the root logger.
logger = logging.getLogger()


# YT directory passed to lock_for_search() in YabsServerMakeBinBases._collect.
LOCKS_DIR = '//home/yabs-cs-sandbox/locks/make_bin_bases'
# YT proxy url used for the YtClient instances created in the visible part of this module.
YT_PROXY = 'hahn'
YABS_CS_DISK_SIZE_GB = 10

# Placeholder stored into a reuse-version attribute that must take part in the
# base search but is missing from the computed attributes (see _collect).
ATTR_VALUE_DO_NOT_REUSE = 'Do not reuse'


class BASEGEN_TASK_TYPES(object):
    # Task type identifiers of the child tasks started by YabsServerMakeBinBases:
    # FETCH and REDUCE are passed to _start_child, IMPORT to _start_child_import.
    FETCH = YabsServerBasesFetch.YabsServerBasesFetch.type
    REDUCE = YabsServerBasesReduce.YabsServerBasesReduce.type
    IMPORT = YabsServerRunCSImportWrapper.name


class ReuseExistingBases(SandboxBoolParameter):
    # Boolean task parameter. It is read in get_generation_id() (when unset the
    # generation id is the task id) and forwarded to _find_base_sources() in _collect().
    name = 'reuse_existing_bases'
    description = 'Reuse binary bases and make generated bases reusable'
    default_value = False


class ImportOutLocation(ResourceSelector):
    # Optional, single-valued selector of a YABS_CS_IMPORT_OUT_LOCATION resource.
    name = 'import_out_location_res_id'
    description = 'Resource with cs import output location'
    required = False
    multiple = False
    resource_type = YABS_CS_IMPORT_OUT_LOCATION


class DeleteFetchResults(SandboxBoolParameter):
    # Boolean task parameter checked in on_execute(): when true, the YT nodes
    # returned by _get_fetch_prefixes() are removed once all bases are ready.
    name = 'delete_fetch_results'
    description = 'Remove nodes with fetch results after reduce is successfully finished'
    default_value = True


class IsYTOneshot(SandboxBoolParameter):
    # Boolean task parameter deciding which children receive the oneshot path/args:
    # when set they are dropped from fetch/reduce children (_start_child),
    # otherwise they are dropped from the import child (_start_child_import).
    name = 'is_yt_oneshot'
    description = 'Apply oneshot to YT'
    default_value = False


class ReuseBasesByImporterSettingsVersion(SandboxBoolParameter):
    # Boolean task parameter used in _collect(): when true the base search keeps the
    # importer settings version attribute and drops 'settings_spec_md5';
    # when false the settings version attribute is dropped instead.
    name = 'reuse_bases_by_importer_settings_versions'
    description = 'Reuse bases by the versions of importer settings involved in the base generation'
    default_value = True


class ReuseBasesByImporterCodeVersion(SandboxBoolParameter):
    # Boolean task parameter used in _collect(): controls whether the importer
    # code version attribute takes part in the base search.
    name = 'reuse_bases_by_importer_code_versions'
    description = 'Reuse bases by the versions of importer code involved in the base generation'
    default_value = True


class ReuseBasesByMkdbInfoVersion(SandboxBoolParameter):
    # Boolean task parameter used in _collect(): controls whether the mkdb_info
    # version attribute takes part in the base search.
    name = 'reuse_bases_by_mkdb_info_versions'
    description = 'Reuse bases by the version of their mkdb_info'
    default_value = True


class ReuseBasesByCSImportVer(SandboxBoolParameter):
    # Boolean task parameter used in _collect(): when false the 'cs_import_ver'
    # attribute is removed from the base search attributes.
    name = 'reuse_bases_by_cs_import_ver'
    description = 'Reuse bases by cs_import_ver'
    default_value = False


class SolomonParameters(SolomonTaskMixinParameters):
    # Solomon mixin parameters with defaults overridden for this task; the values
    # are read in YabsServerMakeBinBases.send_metrics().
    solomon_project = SolomonTaskMixinParameters.solomon_project(default="yabs_testing")
    solomon_cluster = SolomonTaskMixinParameters.solomon_cluster(default="bases")
    solomon_service = SolomonTaskMixinParameters.solomon_service(default="reuse")
    solomon_token_vault_name = SolomonTaskMixinParameters.solomon_token_vault_name(default="robot-yabs-cs-sb-mon-solomon-token")


class BaseTagsList(sdk2.parameters.JSON):
    # JSON parameter restricted to a flat list; base class for the per-role
    # base tag parameters declared below.
    default_value = []

    @classmethod
    def cast(cls, value):
        """Validate and normalize the parameter value.

        :param value: raw value accepted by ``sdk2.parameters.JSON.cast``
        :return: list with every item converted to ``str``
        :raises ValueError: if the value is not a list or contains nested
            lists or dicts
        """
        value = sdk2.parameters.JSON.cast(value)
        if not isinstance(value, list):
            raise ValueError("This parameter supports only list")
        if any(isinstance(item, (list, dict)) for item in value):
            raise ValueError("This parameter does not support nested lists or objects")
        # Build a real list: a bare map() is a one-shot iterator on Python 3.
        return [str(item) for item in value]


class BaseTagsMetaBs(BaseTagsList):
    # List of base tags; results are published under 'base_resources_meta_bs'
    # (see BIN_DB_LIST_FIELDS_TO_OUTPUT_FIELDS).
    name = 'base_tags_meta_bs'
    description = 'Base tags for bs meta'


class BaseTagsMetaBsrank(BaseTagsList):
    # List of base tags; results are published under 'base_resources_meta_bsrank'
    # (see BIN_DB_LIST_FIELDS_TO_OUTPUT_FIELDS).
    name = 'base_tags_meta_bsrank'
    description = 'Base tags for bsrank meta'


class BaseTagsMetaYabs(BaseTagsList):
    # List of base tags; results are published under 'base_resources_meta_yabs'
    # (see BIN_DB_LIST_FIELDS_TO_OUTPUT_FIELDS).
    name = 'base_tags_meta_yabs'
    description = 'Base tags for yabs meta'


class BaseTagsStatBs(BaseTagsList):
    # List of base tags; results are published under 'base_resources_stat_bs'
    # (see BIN_DB_LIST_FIELDS_TO_OUTPUT_FIELDS).
    name = 'base_tags_stat_bs'
    description = 'Base tags for bs stat'


class BaseTagsStatBsrank(BaseTagsList):
    # List of base tags; results are published under 'base_resources_stat_bsrank'
    # (see BIN_DB_LIST_FIELDS_TO_OUTPUT_FIELDS).
    name = 'base_tags_stat_bsrank'
    description = 'Base tags for bsrank stat'


class BaseTagsStatYabs(BaseTagsList):
    # List of base tags; results are published under 'base_resources_stat_yabs'
    # (see BIN_DB_LIST_FIELDS_TO_OUTPUT_FIELDS).
    name = 'base_tags_stat_yabs'
    description = 'Base tags for yabs stat'


class BinDbListBs(sdk2.parameters.String):
    # Deprecated string form of the bs base tag list; get_db_list() parses it
    # either as JSON or as a whitespace-separated string.
    name = 'bin_db_list_bs'
    default_value = ''
    default = ''
    description = 'Separated base tags for bs (deprecated)'


class BinDbListBsrank(sdk2.parameters.String):
    # Deprecated string form of the bsrank base tag list; get_db_list() parses it
    # either as JSON or as a whitespace-separated string.
    name = 'bin_db_list_bsrank'
    default_value = ''
    default = ''
    description = 'Separated base tags for bsrank (deprecated)'


class BinDbListYabs(sdk2.parameters.String):
    # Deprecated string form of the yabs base tag list; get_db_list() parses it
    # either as JSON or as a whitespace-separated string.
    name = 'bin_db_list_yabs'
    default_value = ''
    default = ''
    description = 'Separated base tags for yabs (deprecated)'


# Maps the name of every input field holding base tags to the context key
# under which on_execute() stores the sorted resource ids of those bases.
BIN_DB_LIST_FIELDS_TO_OUTPUT_FIELDS = {
    # String-based fields (per-role ones are marked deprecated in their descriptions).
    dbtask.BinDbList.name: dbtask.BIN_BASE_RES_IDS_KEY,
    BinDbListBs.name: dbtask.BIN_BASE_RES_IDS_KEY + '_bs',
    BinDbListYabs.name: dbtask.BIN_BASE_RES_IDS_KEY + '_yabs',
    BinDbListBsrank.name: dbtask.BIN_BASE_RES_IDS_KEY + '_bsrank',

    # JSON list fields split by meta/stat and by role.
    BaseTagsMetaBs.name: 'base_resources_meta_bs',
    BaseTagsMetaBsrank.name: 'base_resources_meta_bsrank',
    BaseTagsMetaYabs.name: 'base_resources_meta_yabs',
    BaseTagsStatBs.name: 'base_resources_stat_bs',
    BaseTagsStatBsrank.name: 'base_resources_stat_bsrank',
    BaseTagsStatYabs.name: 'base_resources_stat_yabs',
}


# Task that collects binary bases for the requested tags: it searches for
# already existing base resources/tasks and launches import, fetch and reduce
# child tasks for the bases that are missing (see on_execute and _collect).
class YabsServerMakeBinBases(dbtask.BinBasesTask):  # pylint: disable=R0904

    type = 'YABS_SERVER_MAKE_BIN_BASES'

    execution_space = 20 * 1024  # We sync BS_RELEASE_YT, 20 GiB will be more than enough
    required_ram = 20 * 1024

    client_tags = Tag.LINUX_PRECISE & Tag.GENERIC

    # Extra environments on top of the ones required by the base task class.
    environment = (
        SvnEnvironment(),
        PipEnvironment('networkx', version='2.2', use_wheel=True),
        PipEnvironment('yandex-yt', use_wheel=True)
    ) + dbtask.BinBasesTask.environment

    # Base task parameters extended with the parameters declared in this module,
    # reuse switches from the import/fetch tasks and the solomon parameters.
    input_parameters = dbtask.BinBasesTask.input_parameters + (
        BaseTagsMetaBs,
        BaseTagsMetaBsrank,
        BaseTagsMetaYabs,
        BaseTagsStatBs,
        BaseTagsStatBsrank,
        BaseTagsStatYabs,
        BinDbListBs,
        BinDbListBsrank,
        BinDbListYabs,
        ReuseExistingBases,
        ImportOutLocation,
        DeleteFetchResults,
        ReuseImportResults,
        UseCsCycle,
        YabsServerBasesFetch.ReuseFetchResults,
        ReuseBasesByImporterSettingsVersion,
        ReuseBasesByImporterCodeVersion,
        ReuseBasesByMkdbInfoVersion,
        ReuseBasesByCSImportVer,
        IsYTOneshot,
    ) + tuple(SolomonParameters)

    def get_generation_id(self):
        """Return the generation id passed to the child tasks.

        :return: the task id when base reuse is disabled; otherwise the md5
            hex digest of the ``arcadia_patch`` context value, or ``0`` when
            no patch is set
        """
        if not get_or_default(self.ctx, ReuseExistingBases):
            return self.id
        arcadia_patch = self.ctx.get('arcadia_patch', '')
        if not arcadia_patch:
            return 0
        # hashlib works on bytes: encode text explicitly so that non-ASCII
        # patches (Python 2 unicode) and Python 3 str values do not raise.
        if isinstance(arcadia_patch, six.text_type):
            arcadia_patch = arcadia_patch.encode('utf-8')
        return hashlib.md5(arcadia_patch).hexdigest()

    @trace_calls
    def _get_fetch_prefixes(self, skip_reused=True):
        """Collect fetch prefixes stored in the contexts of the child tasks.

        :param skip_reused: when true, children whose context has a truthy
            ``fetch_reused`` value are ignored
        :return: set of non-empty values found under ``FETCH_PREFIX_KEY``
        """
        client = rest.Client()
        children = client.task[self.id].children.read()["items"]
        child_ids = [child["id"] for child in children]

        prefixes = set()
        for child_id in child_ids:
            child_ctx = client.task[child_id].context.read()
            prefix = child_ctx.get(FETCH_PREFIX_KEY)
            if skip_reused and child_ctx.get("fetch_reused"):
                logger.debug("Skip reused fetch result: %s", prefix)
                continue
            if not prefix:
                continue
            prefixes.add(prefix)
        return prefixes

    @trace_calls
    def _delete_fetch_results(self, yt_client):
        """Recursively remove the YT nodes returned by ``_get_fetch_prefixes``.

        A failure to remove a node is logged as a warning and does not stop
        the removal of the remaining nodes.

        :param yt_client: YT client used to remove the nodes
        """
        from yt.wrapper import YtHttpResponseError

        for fetch_node in self._get_fetch_prefixes():
            try:
                yt_client.remove(fetch_node, recursive=True)
            except YtHttpResponseError as error:
                logging.warning('Failed to remove node %s: %s', fetch_node, error.message)
                continue
            logging.info('Remove fetch results node: %s', fetch_node)

    def create_report(self, base_generation_info):
        """Build the base generation report and attach it to the task info.

        Report creation is best-effort: a failure is logged and the method
        returns without setting any info.

        :param base_generation_info: data passed to
            ``create_base_generation_report`` together with the task id
        """
        try:
            report = create_base_generation_report(base_generation_info, self.id)
        except Exception:
            # Catch only regular errors; a bare ``except`` would also swallow
            # SystemExit and KeyboardInterrupt.
            logging.error("Cannot create report", exc_info=True)
            return

        self.set_info(report, do_escape=False)

    @trace_entry_point(writer_factory=TRACE_WRITER_FACTORY)
    def on_execute(self):
        """Task entry point: collect bases, publish their resource ids and report.

        Steps:
          * check children and renew the expiration time of the input archive
            (only once per task, guarded by the ``input_archive_pinged`` flag);
          * find or generate bases via ``_collect`` and wait for the tasks it returns;
          * optionally delete fetch results;
          * store sorted resource ids per input field into the context;
          * create the base generation report and push reuse metrics
            (both are best-effort and never fail the task).
        """
        from yt.wrapper import YtClient

        self._ensure_switcher_revision_param()
        self.check_and_wait_tasks()  # We might have scheduled GetSQLArchive from self.process_archive_contents()
        self._check_children()

        input_spec = self.read_json_resource(self.input_spec_res_id)
        yt_token = dbutils.get_yabscs_yt_token(self)
        yt_client = YtClient(config={'token': yt_token, 'proxy': {'url': YT_PROXY}})
        if not self.ctx.get('input_archive_pinged', False):
            yt_bases.renew_input_spec_expiration_time(yt_client, input_spec, ttl=yt_bases.DEFAULT_CS_INPUT_ARCHIVE_TTL)
            self.ctx['input_archive_pinged'] = True

        resources, resources_by_internal_ver, tasks_to_wait = self._collect()

        if tasks_to_wait:
            self.set_info("Waiting tasks {} to complete".format(', '.join(str(i) for i in tasks_to_wait)))
            check_tasks(self, tasks_to_wait, wait_all=False)
            # if all tasks_to_wait are already finished we need to compute resources list again
            resources, resources_by_internal_ver, tasks_to_wait = self._collect()

        self.set_info("All bases ready.")
        if get_or_default(self.ctx, DeleteFetchResults):
            logging.info('Deleting fetch results')
            self._delete_fetch_results(yt_client)

        tags_by_role = {
            field: self.get_db_list(field)
            for field in BIN_DB_LIST_FIELDS_TO_OUTPUT_FIELDS.keys()
        }
        res_ids_by_role = {
            field: []
            for field in BIN_DB_LIST_FIELDS_TO_OUTPUT_FIELDS.keys()
        }
        base_resource_ids = []

        # Distribute every found resource between the input fields that requested its tag.
        for resources_dict in (resources, resources_by_internal_ver, ):
            for tag, res_id in resources_dict.iteritems():
                base_resource_ids.append(res_id)
                for field, tags in tags_by_role.items():
                    if tag in tags:
                        res_ids_by_role[field].append(res_id)

        for field, res_ids in res_ids_by_role.items():
            self.ctx[BIN_DB_LIST_FIELDS_TO_OUTPUT_FIELDS[field]] = sorted(res_ids)

        # Pre-bind the name: if the lookup below fails, the metrics step must
        # be skipped explicitly instead of dying with a NameError.
        base_generation_info = None
        try:
            base_generation_info = get_base_generation_info(base_resource_ids)
            self.create_report(base_generation_info)
        except Exception:
            logger.exception('Failed to create base_gen report')
            self.set_info('Failed to create base_gen report, see logs for more details')

        if base_generation_info is None:
            logger.warning('Base generation info is not available, metrics will not be pushed to solomon')
            return

        try:
            self.send_metrics(base_generation_info)
        except Exception as err:
            message = 'Failed to push metrics to solomon'
            logger.exception(message)
            self.set_info('{msg}: {err}'.format(msg=message, err=err))

    @trace_calls
    def send_metrics(self, base_generation_info):
        """Push base reuse metrics to solomon.

        A base counts as generated when its ``base_generation_task_id``
        equals the id of this task and as reused otherwise.

        :param base_generation_info: iterable of objects with
            ``base_generation_task_id`` and ``base_tag`` attributes
        """
        ctx = self.ctx
        token = get_solomon_token(get_or_default(ctx, SolomonParameters.solomon_token_vault_name))

        push_client = SolomonPushClient(
            project=get_or_default(ctx, SolomonParameters.solomon_project),
            token=token,
            default_cluster=get_or_default(ctx, SolomonParameters.solomon_cluster),
            default_service=get_or_default(ctx, SolomonParameters.solomon_service),
            solomon_hostname=get_or_default(ctx, SolomonParameters.solomon_hostname),
            max_sensors_per_push=get_or_default(ctx, SolomonParameters.max_sensors_per_push),
            spack=get_or_default(ctx, SolomonParameters.spack),
        )

        reused, generated = [], []
        for base_info in base_generation_info:
            made_here = base_info.base_generation_task_id == self.id
            bucket = generated if made_here else reused
            bucket.append(base_info.base_tag)

        push_client.add(collect_metrics(
            reused,
            generated,
            task_run_type.get_task_run_type(self.tags),
            task_run_type.get_task_testenv_database(self.tags),
        ))
        push_client.push_collected()

    def get_db_list(self, field=None):
        """Return base tags requested through the task context.

        :param field: name of a single context field to read; when ``None``
            the tags of all fields listed in
            ``BIN_DB_LIST_FIELDS_TO_OUTPUT_FIELDS`` are merged
        :return: list of base tags (deduplicated, in arbitrary order, when
            ``field`` is ``None``)
        """
        def _get_db_list(ctx, field):
            """Parse one context field holding a list, a JSON list or a whitespace-separated string."""
            field_value = ctx.get(field) or []
            if isinstance(field_value, list):
                bases = field_value
            elif isinstance(field_value, six.string_types):
                try:
                    bases = json.loads(field_value)
                except ValueError:
                    logger.debug('Parsing old format field=%s, value=`%s`', field, field_value)
                    bases = field_value.split()
                else:
                    if not isinstance(bases, list):
                        # Valid JSON that is not a list (e.g. `123` or `null`) is
                        # treated as an old-format string instead of crashing below.
                        logger.debug('Parsing old format field=%s, value=`%s`', field, field_value)
                        bases = field_value.split()
            else:
                # Unsupported value type: previously `bases` stayed unbound here
                # and the method failed with UnboundLocalError.
                logger.warning('Cannot parse field=%s, value=`%s`', field, field_value)
                bases = []

            logger.debug('%s bases are: %s', field, list(bases))
            return bases

        if field is not None:
            return _get_db_list(self.ctx, field)

        bases = set()
        for field_name in BIN_DB_LIST_FIELDS_TO_OUTPUT_FIELDS:
            bases.update(_get_db_list(self.ctx, field_name))

        return list(bases)

    @trace_calls
    def _collect(self):
        """Find sources of the requested bases and launch generation of the missing ones.

        Computes search attributes for every requested base, then, inside a YT
        transaction and after ``lock_for_search``, looks up existing
        resources/tasks with ``_find_base_sources`` and starts child tasks
        for the tags that were not found.

        :return: tuple ``(resources, resources_by_internal_ver, tasks_to_wait)``;
            ``tasks_to_wait`` is an empty list when there are neither existing
            nor newly launched tasks to wait for
        """
        from yt.wrapper import YtClient

        yt_token = dbutils.get_yabscs_yt_token(self)
        yt_client = YtClient(config={'token': yt_token, 'proxy': {'url': YT_PROXY}})

        mkdb_info = dbutils.get_full_mkdb_info(self.get_yabscs())
        mkdb_info_version = base_mkdb_info_version(mkdb_info)
        logger.debug('mkdb_info version: %s', mkdb_info_version)

        db_list = self.get_db_list()
        settings_version = base_to_importer_settings_version(
            cs_dir=self.get_yabscs(),
            bases=db_list,
            settings_spec=self.cs_settings)
        logger.debug('settings version: %s', settings_version)

        code_version = base_to_importer_code_version(
            cs_dir=self.get_yabscs(),
            bases=db_list)
        logger.debug('code version: %s', code_version)

        # Adjust the search attributes of every base according to the reuse switches.
        binary_bases_attrs = []
        for attrs in self.iter_binary_bases_attrs(
                tags=db_list,
                db_ver=self._get_db_ver(),
                db_internal_vers=self._get_db_internal_vers(),
                base_settings_version=settings_version,
                base_code_version=code_version,
                base_mkdb_info_version=mkdb_info_version
        ):
            # Search either by attribute 'settings_version' or 'settings_spec_md5'
            if get_or_default(self.ctx, ReuseBasesByImporterSettingsVersion):
                attrs.pop('settings_spec_md5', None)
                # Prevent searching for base without attribute
                if yt_bases.IMPORTER_SETTINGS_VERSION_ATTR not in attrs:
                    attrs[yt_bases.IMPORTER_SETTINGS_VERSION_ATTR] = ATTR_VALUE_DO_NOT_REUSE
            else:
                attrs.pop(yt_bases.IMPORTER_SETTINGS_VERSION_ATTR, None)

            if get_or_default(self.ctx, ReuseBasesByMkdbInfoVersion):
                # Prevent searching for base without attribute
                if yt_bases.IMPORTER_MKDB_INFO_VERSION_ATTR not in attrs:
                    attrs[yt_bases.IMPORTER_MKDB_INFO_VERSION_ATTR] = ATTR_VALUE_DO_NOT_REUSE
            else:
                attrs.pop(yt_bases.IMPORTER_MKDB_INFO_VERSION_ATTR, None)

            if get_or_default(self.ctx, ReuseBasesByImporterCodeVersion):
                # Prevent searching for base without attribute
                if yt_bases.IMPORTER_CODE_VERSION_ATTR not in attrs:
                    attrs[yt_bases.IMPORTER_CODE_VERSION_ATTR] = ATTR_VALUE_DO_NOT_REUSE
            else:
                attrs.pop(yt_bases.IMPORTER_CODE_VERSION_ATTR, None)

            if not get_or_default(self.ctx, ReuseBasesByCSImportVer):
                attrs.pop('cs_import_ver', None)

            binary_bases_attrs.append(attrs)
            logger.debug('Base "%s" attributes: %s', attrs['tag'], attrs)

        common_attrs = intersect_dicts(*binary_bases_attrs)
        logger.debug('Use bases\' common atributes to compute lock\'s child_key: %s', common_attrs)
        # Search and launch happen inside one transaction right after taking the lock;
        # presumably this prevents concurrent tasks from generating the same bases - TODO confirm.
        with yt_client.Transaction():
            lock_for_search(yt_client, LOCKS_DIR, common_attrs)

            resources, resources_by_internal_ver, existing_tasks, missing_tags = _find_base_sources(binary_bases_attrs, get_or_default(self.ctx, ReuseExistingBases))
            logging.debug('resources: %s', resources)
            logging.debug('resources_by_internal_ver: %s', resources_by_internal_ver)
            logging.debug('existing_tasks: %s', existing_tasks)
            logging.debug('missing_tags: %s', missing_tags)

            new_child_tasks = []
            if missing_tags:
                self.process_archive_contents(
                    missing_tags,
                    run_switcher=bool(self.switcher_options),
                    run_cs_import=False,
                    dry_run=True,
                    need_switcher_bases_restore=bool(self.need_switcher_bases_restore)
                )
                # If we got here, all tables are ready
                new_child_tasks = self._launch_generation(missing_tags)
                logging.info('Run tasks: %s', new_child_tasks)

        logging.debug('Lock released')

        running_child_tasks, _ = self._check_children()
        logging.info("Running child tasks: %s", running_child_tasks)

        if existing_tasks or new_child_tasks:
            return resources, resources_by_internal_ver, sorted(set(existing_tasks + running_child_tasks + new_child_tasks))
        # All bases are ready
        # Stopping still running children is best-effort: a failure is only logged.
        try:
            rest.Client().batch.tasks.stop.update(running_child_tasks)
        except Exception:
            logging.warning("Failed to stop unneeded children", exc_info=True)
        return resources, resources_by_internal_ver, []

    def _check_children(self):
        """Inspect subtasks and fail if any of them is broken.

        :return: tuple ``(running subtasks, succeeded subtasks)``
        :raises SandboxTaskUnknownError: if there are broken subtasks
        """
        children = split_subtasks(self.id)
        check_and_handle_subtask_failure(children, self.set_info, kill_only_queued_on_failure=True)

        if not children.broken:
            return children.running, children.succeed

        message = "There are BROKEN subtasks"
        set_subtasks_info(message, children.broken, self.set_info)
        raise SandboxTaskUnknownError(message)

    def _get_db_ver(self):
        """Return the base version, computing it on first use and caching it in ``self._db_ver``."""
        try:
            return self._db_ver
        except AttributeError:
            pass
        # First call: nothing cached on the instance yet.
        self._db_ver = dbutils.get_base_ver(self.get_yabscs())
        return self._db_ver

    def _get_db_internal_vers(self):
        """Return base internal versions, computing them on first use and caching them in ``self._db_internal_vers``."""
        try:
            return self._db_internal_vers
        except AttributeError:
            pass
        # First call: nothing cached on the instance yet.
        self._db_internal_vers = dbutils.get_base_internal_vers(self.get_yabscs())
        return self._db_internal_vers

    def _run_fetch_tasks(self, plan, params):
        """Start a fetch child task for every planned fetch that has bases requiring MySQL.

        :param plan: generation plan; ``plan.get_fetches()`` yields pairs
            ``(bases, additional execution space)``
        :param params: parameters copied into every fetch task
        :return: tuple ``(mapping base tag -> fetch task id, list of fetch task ids)``
        """
        yabscs_dir = self.get_yabscs()
        known_queries = set(get_importer_queries(yt_bases.get_cs_import_info(yabscs_dir)))
        logger.debug("Queries in importers info: %s", known_queries)

        fetch_task_by_tag = {}
        launched_ids = []
        for fetch_bases, extra_space in plan.get_fetches():
            needs_mysql = []
            for base_tag in fetch_bases:
                base_mkdb_info = dbutils.get_mkdb_info(yabscs_dir, base_tag)
                if not base_requires_mysql(base_mkdb_info, known_queries):
                    continue
                logger.debug("Base %s requires MySql", base_tag)
                needs_mysql.append(base_tag)

            # Nothing to fetch for this group.
            if not needs_mysql:
                continue

            logger.debug("MySql bases: %s", needs_mysql)

            child_params = copy.deepcopy(params)
            # Host scores are optional: on failure the task is started without them.
            try:
                child_params[YabsServerBasesFetch.HOST_SCORES_KEY] = self._score_hosts_for_fetch(needs_mysql)
            except Exception:
                logging.warning(
                    "Failed to get fetch host scores for mysql_archive_contents=%s and tags=%s",
                    self.mysql_archive_contents, needs_mysql, exc_info=True)

            fetch_task_id = self._start_child(
                needs_mysql,
                extra_space,
                task_type=BASEGEN_TASK_TYPES.FETCH,
                **child_params)
            fetch_task_by_tag.update(dict.fromkeys(needs_mysql, fetch_task_id))
            launched_ids.append(fetch_task_id)

        return fetch_task_by_tag, launched_ids

    def _launch_generation(self, tags):
        """Launch child tasks generating bases for the given tags.

        Tags are split into the groups 'st_update', 'heavy' (tags starting
        with 'dssm') and 'common'; for every group an import task (unless
        ``self.import_prefix`` is set), fetch tasks and reduce tasks are started.

        :param tags: base tags to generate
        :return: list of ids of all started child tasks
        :raises TaskFailure: if ``import_prefix`` is set while 'st_update'
            bases are requested together with bases of other groups
        """
        ignored_params = {dbtask.BinDbList.name, DeleteFetchResults.name}
        logger.debug('Ignored parameters: %s', ignored_params)

        params_to_copy = set(
            parameter.name
            for parameter in dbtask.BinBasesTask.input_parameters
            if parameter.name not in ignored_params
        ) | {'arcadia_patch', 'tasks_archive_resource'}
        logger.debug('Parameters to copy: %s', params_to_copy)

        # Parameters shared by all child tasks: copied context values plus computed ones.
        common_params = {
            key: self.ctx.get(key)
            for key in params_to_copy
        }
        common_params.update({
            base_producing_task.DBVer.name: self._get_db_ver(),
            base_producing_task.DBInternalVers.name: json.dumps(self._get_db_internal_vers()),
            dbtask.GENERATION_ID_KEY: self.get_generation_id(),
            SamplingStrategyParameter.name: get_or_default(self.ctx, SamplingStrategyParameter),
            '__requirements__': {
                'tasks_resource': self.ctx.get('tasks_archive_resource')
            },
        })
        if self.settings_spec:
            common_params['settings_spec'] = self.settings_spec

        # Digest calculation is enabled only for tasks carrying one of these tags.
        task_tags = rest.Client().task[self.id].read()["tags"]
        if (
            'TESTENV-JOB-YABS_SERVER_30_BASES_BS_A' in task_tags or
            'TESTENV-JOB-YABS_SERVER_30_BASES_YABS_A' in task_tags or
            'TESTENV-JOB-YABS_SERVER_30_BASES_COMMON' in task_tags
        ):
            self.set_info("Will calculate digest")
            common_params.update({
                'calc_digest': True,
                'wait_digest': False,
            })
        logger.debug('Common params: %s', common_params)

        task_ids = []
        base_tag_groups = defaultdict(list)
        for tag in tags:
            if tag.startswith('st_update'):
                base_tag_groups['st_update'].append(tag)
            elif tag.startswith('dssm'):
                base_tag_groups['heavy'].append(tag)
            else:
                base_tag_groups['common'].append(tag)

        logger.info('Grouped base tags: %s', base_tag_groups)

        if base_tag_groups.get('st_update') and (base_tag_groups.get('common') or base_tag_groups.get('heavy')) and self.import_prefix:
            raise TaskFailure("Cannot use import_prefix param, st_update cannot be generated with other bases")

        mkdb_info = dbutils.get_full_mkdb_info(self.get_yabscs())
        mkdb_info_version = base_mkdb_info_version(mkdb_info)
        logging.debug('Mkdb info version: %s', mkdb_info_version)

        # NOTE: the loop variable `tags` rebinds the method argument; the argument
        # itself is not used after the grouping above.
        for group_name, tags in base_tag_groups.items():
            if not tags:
                continue

            logger.info('Launch generation for %s group', group_name)

            plan = GluedGenerationPlan(tags, self.get_yabscs(), self._get_table_sizes())

            # run cs import
            cs_import_task_id = None
            # Do not run import if import_node already provided, pass it as is
            # to reduce task
            if not self.import_prefix:
                cs_import_task_id = self._start_child_import(tags, common_params, task_type=BASEGEN_TASK_TYPES.IMPORT)
                task_ids.append(cs_import_task_id)

            fetches, fetch_task_ids = self._run_fetch_tasks(plan, common_params)
            task_ids += fetch_task_ids

            for reduce_tags, add_execution_space, fetch_tags in plan.get_reduces():
                logging.info('Reduce tags: %s', reduce_tags)

                settings_version = base_to_importer_settings_version(
                    self.get_yabscs(),
                    bases=reduce_tags,
                    settings_spec=self.cs_settings)

                code_version = base_to_importer_code_version(
                    self.get_yabscs(),
                    bases=reduce_tags)

                params = copy.deepcopy(common_params)
                params.update({
                    dbtask.CS_IMPORT_ID_KEY: cs_import_task_id,
                    BaseTagImporterSettingsVersion.name: json.dumps(settings_version),
                    BaseTagImporterCodeVersion.name: json.dumps(code_version),
                    BaseTagMkdbInfoVersion.name: json.dumps(
                        filter_mkdb_info_version_by_bases(mkdb_info_version, reduce_tags)
                    ),
                })

                # Unique ids of the fetch tasks started for the tags this reduce depends on.
                params[dbtask.FETCH_ID_KEY] = list(set(filter(
                    None,
                    [fetches.get(tag) for tag in fetch_tags]
                )))

                logging.debug('Reduce parameters: %s', params)
                task_id = self._start_child(
                    reduce_tags,
                    add_execution_space,
                    task_type=BASEGEN_TASK_TYPES.REDUCE,
                    **params
                )
                task_ids.append(task_id)

        return task_ids

    def _get_table_sizes(self):
        """Map every table of the glued MySQL archive contents to its size.

        Sizes are looked up by the stringified resource id of the table.
        """
        archive = self.get_glued_mysql_archive_contents()
        table_sizes = {}
        for table, resource_id in archive.tables.iteritems():
            table_sizes[table] = archive.sizes[str(resource_id)]
        return table_sizes

    @staticmethod
    def _get_bin_db_list(tags):
        return ' '.join(tags)

    def _start_child_import(self, tags, common_params, task_type):
        """Create and enqueue the cs import child task.

        :param tags: base tags to import data for (sorted before use)
        :param common_params: parameters shared by all child tasks; they
            override the import-specific defaults built here
        :param task_type: sdk2 task type of the import task
        :return: id of the started task
        """
        tags = sorted(tags)
        import_params = {
            dbtask.BinDbList.name: self._get_bin_db_list(tags),
            ReuseImportResults.name: self.ctx.get(ReuseImportResults.name),
            UseCsCycle.name: self.ctx.get(UseCsCycle.name, False),
            SaveAllInputs.name: self.ctx.get(SaveAllInputs.name, SaveAllInputs.default_value),
            UseSaveInputFromCS.name: self.ctx.get(UseSaveInputFromCS.name, UseSaveInputFromCS.default_value),
            YtPool.name: get_or_default(self.ctx, YtPool),
        }

        logging.debug("Params import %s", import_params)
        logging.debug("Params common %s", common_params)
        import_params.update(common_params)

        # The tasks archive is set through task requirements, not as a parameter.
        tasks_resource = None
        if import_params.get('tasks_archive_resource') is not None:
            tasks_resource = import_params.pop('tasks_archive_resource')

        # Pass only YT oneshot to cs import tasks
        if not self.ctx.get(IsYTOneshot.name):
            # Default avoids KeyError when the caller did not provide these keys.
            import_params.pop(dbtask.OneShotPath.name, None)
            import_params.pop(dbtask.OneShotArgs.name, None)

        logging.debug("Params import after update %s", import_params)

        subtask = sdk2.Task[task_type](
            sdk2.Task.current,
            owner=self.owner,
            priority=self.priority,
            tags=self.tags,
            hints=list(sdk2.Task.current.hints),
            description=make_child_descr(self.cs_import_ver, tags, self.descr),
            **import_params
        )
        if tasks_resource:
            subtask.Requirements.tasks_resource = tasks_resource
            subtask.save()
        enqueue_task(subtask)
        return subtask.id

    def _start_child(self, tags, additional_execution_space, task_type, **_params):
        """Create and enqueue a fetch or reduce child task.

        :param tags: base tags handled by the child (sorted before use)
        :param additional_execution_space: disk space added to the 2048
            base value of the child's ``disk_space`` requirement
        :param task_type: sdk2 task type of the child
        :param _params: parameters of the child; they are deep-copied, so
            the caller's values are not modified
        :return: id of the started task
        """
        tags = sorted(tags)
        params = copy.deepcopy(_params)
        params.update({
            dbtask.BinDbList.name: self._get_bin_db_list(tags),
            YtPool.name: get_or_default(self.ctx, YtPool),
        })

        # Pass only not-YT oneshot to fetch and reduce tasks
        if self.ctx.get(IsYTOneshot.name):
            # Default avoids KeyError when the caller did not provide these keys.
            params.pop(dbtask.OneShotPath.name, None)
            params.pop(dbtask.OneShotArgs.name, None)

        # Trunk above all (because it can be reused)
        priority = ("BACKGROUND", "HIGH") if self.ctx.get('arcadia_patch') else None
        if task_type != BASEGEN_TASK_TYPES.FETCH:
            # Force priority for subtasks that can leak into GENERIC pool
            priority = priority or ("BACKGROUND", "HIGH")

        params.setdefault('__requirements__', {})['disk_space'] = 2048 + additional_execution_space  # 2 GiB + ...
        child = sdk2.Task[task_type](
            sdk2.Task.current,
            owner=self.owner,
            hints=list(sdk2.Task.current.hints),
            description=make_child_descr(self._get_db_ver(), tags, self.descr),
            priority=priority,
            tags=self.tags,
            **params
        )
        return enqueue_task(child).id

    def _score_hosts_for_fetch(self, tags):
        """Return host scores for fetching the given tags.

        The ``FetchScoreCalc`` instance is created on first use and kept in
        ``self._fetch_score_calc``.
        """
        try:
            calc = self._fetch_score_calc
        except AttributeError:
            calc = self._fetch_score_calc = FetchScoreCalc()
        return calc.get_scores(self.mysql_archive_contents, tags)

    def _ensure_switcher_revision_param(self):
        """Fill the switcher revision in the context when the switcher is requested.

        Does nothing unless the options context value mentions 'run_switcher'
        and the switcher revision is not set yet; in that case the revision is
        taken from the Arcadia info of the experiment-switcher path.
        """
        options = str(self.ctx.get(dbtask.Options.name))
        if 'run_switcher' not in options:
            return
        if self.ctx.get(dbtask.SwitcherRevision.name):
            return

        switcher_info = Arcadia.info('arcadia:/arc/trunk/arcadia/yabs/utils/experiment-switcher')
        self.ctx[dbtask.SwitcherRevision.name] = str(switcher_info['entry_revision'])


def get_solomon_token(vault_name):
    """Read the solomon token from the vault.

    :param vault_name: name of the vault item
    :return: vault item data, or ``None`` if reading failed with ``VaultError``
        (the error is logged as a warning)
    """
    try:
        token = sdk2.VaultItem(vault_name).data()
    except VaultError as error:
        logging.warning('Failed to get data from vault: %s', error)
        token = None
    return token


def collect_metrics(reused_bases, generated_bases, run_type=None, testenv_database=None):
    """Build base reuse metrics in solomon format

    Every base yields two points with sensor ``reuse``: one with status ``reused``
    and one with status ``not_reused``. Exactly one of them has value 1.
    Reused bases come first, generated bases follow.

    :param reused_bases: Reused bases
    :type reused_bases: list[str]
    :param generated_bases: Generated bases
    :type generated_bases: list[str]
    :param run_type: Value of the ``run_type`` label, ``'unknown'`` is used when it is falsy
    :type run_type: sandbox.projects.yabs.qa.utils.task_run_type.RunType or None
    :param testenv_database: If truthy, added (as a string) as the ``testenv_database`` label
    :type testenv_database: str or None
    :return: Metrics in solomon format
    :rtype: list
    """
    shared_labels = {'run_type': run_type or 'unknown'}
    if testenv_database:
        shared_labels['testenv_database'] = str(testenv_database)

    def _point(base_tag, status, value):
        # One solomon point: shared labels plus the per-base ones.
        labels = dict(shared_labels, sensor='reuse', base_tag=base_tag, status=status)
        return {'labels': labels, 'value': value}

    metrics = []
    for is_reused, base_tags in ((True, reused_bases), (False, generated_bases)):
        for base_tag in base_tags:
            metrics.append(_point(base_tag, 'reused', int(is_reused)))
            metrics.append(_point(base_tag, 'not_reused', int(not is_reused)))

    return metrics


def _find_binary_base(attrs, state='READY', find_bases_from_other_tasks=False):
    """Look up a binary base resource, optionally falling back to a shared one

    The first lookup uses ``attrs`` as is. If it finds nothing, ``attrs`` has a
    truthy generation id and ``find_bases_from_other_tasks`` is set, the lookup
    is repeated with the generation id replaced by 0.

    :param attrs: Attributes of binary base resource
    :type attrs: dict
    :param state: Resource state, defaults to 'READY'
    :type state: str, optional
    :param find_bases_from_other_tasks: Allow the second lookup with generation id 0, defaults to False
    :type find_bases_from_other_tasks: bool, optional
    :return: Resource data or None
    :rtype: Union[dict, None]
    """
    found = _find_binary_base_impl(attrs, state)
    if found is not None:
        return found

    if not attrs[GENERATION_ID_KEY] or not find_bases_from_other_tasks:
        return None

    # Retry on a copy so the caller's attributes stay untouched.
    fallback_attrs = attrs.copy()
    fallback_attrs[GENERATION_ID_KEY] = 0
    return _find_binary_base_impl(fallback_attrs, state)


def _find_binary_base_impl(attrs, state='READY'):
    """Query Sandbox for binary base resources and pick one of them

    Resources are requested ordered by descending id. For state 'READY' the first
    item is returned. For any other state the first item whose task status belongs
    to ``_GOOD_TASK_STATUS_FOR_NOT_READY`` is returned.

    :param attrs: Attributes of binary base resource
    :type attrs: dict
    :param state: Resource state, defaults to 'READY'
    :type state: str, optional
    :return: Resource data or None if nothing suitable was found
    :rtype: Union[dict, None]
    """
    logging.info("Finding binary base with state=%s and attrs=%s", state, attrs)
    want_ready = state == 'READY'
    response = rest.Client().resource.read(
        type=resource_types.YABS_SERVER_B2B_BINARY_BASE.name,
        state=state,
        attrs=attrs,
        order='-id',
        limit=100 if want_ready else 1000,
    )
    candidates = response["items"]

    if want_ready:
        return candidates[0] if candidates else None

    return next(
        (
            candidate for candidate in candidates
            if candidate.get('task', {}).get('status') in _GOOD_TASK_STATUS_FOR_NOT_READY
        ),
        None,
    )


def _find_base_sources(binary_bases_attrs, reuse):
    """Sort base tags into ready resources, running tasks and tags still to be generated

    For every attribute set a READY resource is searched first, then a NOT_READY
    one. A tag with neither is reported as requiring scheduling.

    :param binary_bases_attrs: Attributes of bases to search for resources by
    :type binary_bases_attrs: list[dict]
    :param reuse: Reuse binary bases generated by other tasks
    :type reuse: bool
    :return: Tuple (found resources by tag; found resource with 'internal_ver' attribute  by tag; running tasks; bases tags without resources or tasks)
    :rtype: tuple[dict[str -> int], dict[str -> int], list[int], list[str]]
    """
    ready_by_tag = {}
    ready_by_tag_with_internal_ver = {}
    running_task_ids = []
    missing_tags = []

    for base_attrs in binary_bases_attrs:
        base_tag = base_attrs['tag']

        ready = _find_binary_base(base_attrs, state='READY', find_bases_from_other_tasks=reuse)
        if ready:
            logging.info('Tag: %s, resource: %s, state: READY, attributes: %s', base_tag, ready['id'], base_attrs)
            # Resources looked up by 'internal_ver' are reported separately.
            target = ready_by_tag_with_internal_ver if 'internal_ver' in base_attrs else ready_by_tag
            target[base_tag] = ready['id']
            continue

        pending = _find_binary_base(base_attrs, state='NOT_READY', find_bases_from_other_tasks=reuse)
        if pending:
            producing_task_id = pending['task']['id']
            logging.info(
                'Tag: %s, resource: %s, state: NOT_READY, attributes: %s, task: %s',
                base_tag, pending['id'], base_attrs, producing_task_id)
            running_task_ids.append(producing_task_id)
            continue

        missing_tags.append(base_tag)
        logging.info('Tag: %s, no resources found', base_tag)

    return ready_by_tag, ready_by_tag_with_internal_ver, running_task_ids, missing_tags


# Short alias for the task status groups.
_TSG = TaskStatus.Group
# Statuses of a producing task under which its NOT_READY resource is accepted
# by _find_binary_base_impl: queued, executing, waiting or succeeded.
_GOOD_TASK_STATUS_FOR_NOT_READY = frozenset().union(_TSG.QUEUE, _TSG.EXECUTE, _TSG.WAIT, _TSG.SUCCEED)


# Record describing where a binary base resource came from:
#   base_tag                   - value of the resource's 'tag' attribute
#   resource_id                - id of the binary base resource
#   producing_task_id/type     - task that owns the resource
#   base_generation_task_id/type - parent of the producing task
BaseGenerationInfo = namedtuple(
    'BaseGenerationInfo',
    'base_tag resource_id '
    'producing_task_id producing_task_type '
    'base_generation_task_id base_generation_task_type',
)


def get_base_generation_info(base_resource_ids):
    """Collect data about base generation

    For every resource the task that owns it (producing task) and the parent of
    that task (base generation task) are looked up through the Sandbox REST API.

    A producing task without a parent is logged as an error. Its entry is still
    returned, with ``base_generation_task_id`` and ``base_generation_task_type``
    set to None. The same None type is reported when the parent task data was
    not returned by the API.

    :param base_resource_ids: Ids of binary base resources
    :type base_resource_ids: list
    :return: One entry per resource returned by the API, empty list for empty input
    :rtype: list[BaseGenerationInfo]
    """
    if not base_resource_ids:
        # Nothing to look up; also avoids a request with an empty id list and limit=0.
        return []

    sandbox_client = rest.Client()

    base_resources = sandbox_client.resource.read(id=base_resource_ids, limit=len(base_resource_ids))['items']
    if not base_resources:
        return []

    producing_task_ids = [resource['task']['id'] for resource in base_resources]
    logger.debug('Producing task ids: %s', producing_task_ids)
    producing_tasks_data = sandbox_client.task.read(id=producing_task_ids, limit=len(producing_task_ids))['items']
    logger.debug('Producing tasks: %s', producing_tasks_data)
    producing_tasks_indexed = {
        task['id']: task
        for task in producing_tasks_data
    }

    # Resolve the parent of every producing task once, so that collecting ids and
    # building the result agree on which tasks have no parent.
    parent_id_by_task_id = {}
    base_generation_task_ids = []
    for task in producing_tasks_data:
        try:
            parent_id = (task.get('parent') or {})['id']
        except (KeyError, TypeError):
            parent_id = None
        parent_id_by_task_id[task['id']] = parent_id
        if parent_id is None:
            logger.error('Task #%d has no parent', task['id'])
            continue
        base_generation_task_ids.append(parent_id)
    logger.debug('Base generation task ids: %s', base_generation_task_ids)

    if base_generation_task_ids:
        base_generation_tasks_data = sandbox_client.task.read(
            id=base_generation_task_ids,
            limit=len(base_generation_task_ids),
        )['items']
    else:
        base_generation_tasks_data = []
    logger.debug('Base generation tasks: %s', base_generation_tasks_data)
    base_generation_tasks_indexed = {
        task['id']: task
        for task in base_generation_tasks_data
    }

    result = []
    for resource in base_resources:
        producing_task_id = resource['task']['id']
        producing_task = producing_tasks_indexed[producing_task_id]
        base_generation_task_id = parent_id_by_task_id.get(producing_task_id)
        base_generation_task = base_generation_tasks_indexed.get(base_generation_task_id)
        if base_generation_task_id is not None and base_generation_task is None:
            logger.error('No data for base generation task #%s', base_generation_task_id)

        result.append(BaseGenerationInfo(
            base_tag=resource['attributes']['tag'],
            resource_id=resource['id'],
            producing_task_id=producing_task_id,
            producing_task_type=producing_task['type'],
            base_generation_task_id=base_generation_task_id,
            base_generation_task_type=base_generation_task['type'] if base_generation_task else None
        ))

    return result


class FetchScoreCalc(object):
    """Scores clients by the recent successful fetch tasks they executed.

    The list of fetch tasks is requested once, in the constructor, and then
    reused by every ``get_scores`` call.
    """

    # Names of the task fields requested from the API and read back from each item.
    CTX_TAGS_FIELD = 'context.' + dbtask.BinDbList.name
    CTX_ARCHIVE_CONTENTS_FIELD = 'context.' + MySQLArchiveContents.name
    CLIENT_ID_FIELD = 'execution.client.id'

    def __init__(self):
        # FIXME use input_parameters search after switching to SDK2
        requested_fields = (self.CTX_TAGS_FIELD, self.CTX_ARCHIVE_CONTENTS_FIELD, self.CLIENT_ID_FIELD)
        response = rest.Client().task.read(
            type=BASEGEN_TASK_TYPES.FETCH,
            limit=200,
            order_by='-id',
            hidden=True,
            children=True,
            status='SUCCESS',
            fields=','.join(requested_fields),
        )
        self._fetch_tasks_items = response['items']

    def get_scores(self, mysql_archive_contents, tags):
        """Score clients that already fetched exactly these tags from this archive.

        Tasks are scanned in the order they were received. The n-th distinct
        matching client gets ``int(100 / n)``, and the scan stops after five clients.

        :param mysql_archive_contents: Archive contents a task must have been run with to count
        :param tags: Base tags; a task counts only if its tag set is equal to this one
        :return: Scores by client id
        :rtype: dict
        """
        wanted_tags = frozenset(tags)
        host_scores = {}

        for fetch_task in self._fetch_tasks_items:
            same_tags = frozenset(fetch_task[self.CTX_TAGS_FIELD].split()) == wanted_tags
            if not (same_tags and fetch_task[self.CTX_ARCHIVE_CONTENTS_FIELD] == mysql_archive_contents):
                logging.debug("FetchScoreCalc skipped fetch task %s", fetch_task)
                continue

            client_id = fetch_task[self.CLIENT_ID_FIELD]
            if client_id not in host_scores:
                host_scores[client_id] = int(100 / (len(host_scores) + 1))
            if len(host_scores) >= 5:
                break

        return host_scores


class GluedGenerationPlan(object):
    """Fetch and reduce chunks computed for a set of base tags.

    ``fetch_chunks`` is a list of ``(chunk, size)`` pairs, where size is the data
    size reported by the splitter multiplied by 1.5. ``reduce_chunks`` is a list of
    ``(tags, execution space, tags without piggyback tags)`` triples and holds at
    most two entries: tags matching the stat-tag pattern together with piggyback
    tags, and all other tags.
    """

    def __init__(self, tags, tooldir, table_sizes):
        """Compute the chunks

        :param tags: Requested base tags
        :param tooldir: Passed to the ``yt_bases`` / ``dbutils`` helpers that describe the bases
        :param table_sizes: Passed to every ``FetchSplitter``
        """
        piggyback_map = yt_bases.get_piggyback_info(tooldir)['piggyback_tags']

        # Piggyback tags are not fetched themselves, the tags they map to are.
        piggyback_tags = set(tags) & piggyback_map.viewkeys()
        tags_to_fetch = set(tags) - piggyback_tags
        for piggyback_tag in piggyback_tags:
            tags_to_fetch |= set(piggyback_map[piggyback_tag])

        # One splitter per RealMySqlTag value.
        splitters_by_mysql_tag = {}
        for fetch_tag in tags_to_fetch:
            mysql_tag = dbutils.get_mkdb_info(tooldir, fetch_tag)['Shard'].get('RealMySqlTag')

            if mysql_tag not in splitters_by_mysql_tag:
                splitters_by_mysql_tag[mysql_tag] = FetchSplitter(
                    table_sizes,
                    group_base_size=10 << 30,
                    target_size=70 << 30,
                )

            table_names = [
                '.'.join(table_key)
                for table_key in yt_bases.iter_db_tables(tooldir, base_tags=[fetch_tag], cs_import=False)
            ]
            splitters_by_mysql_tag[mysql_tag].add_tag(fetch_tag, table_names)

        self.fetch_chunks = []
        for tag_splitter in splitters_by_mysql_tag.itervalues():
            self.fetch_chunks.extend(
                (split_chunk, int(split_size * 1.5))
                for split_chunk, split_size in tag_splitter.iter_split()
            )
        logging.debug("Fetch chunks: %s", self.fetch_chunks)

        stat_tag_re = re.compile(r'(bs_|yabs_)?st\d+')
        stat_chunk = set(piggyback_tags)
        others_chunk = set()
        for chunk_tags, _ in self.fetch_chunks:
            for chunk_tag in chunk_tags:
                destination = stat_chunk if stat_tag_re.match(chunk_tag) else others_chunk
                destination.add(chunk_tag)

        base_execution_space = 1024 * YABS_CS_DISK_SIZE_GB
        self.reduce_chunks = []
        # Stat tags get five times more space per tag than the rest.
        for reduce_tags, space_per_tag in ((stat_chunk, 5 * 1024), (others_chunk, 1024)):
            if not reduce_tags:
                continue
            self.reduce_chunks.append((
                reduce_tags,
                base_execution_space + space_per_tag * len(reduce_tags),
                reduce_tags - piggyback_tags,
            ))

        logging.debug("Reduce chunks: %s", self.reduce_chunks)

    def get_fetches(self):
        """Return the list of fetch chunks built by the constructor."""
        return self.fetch_chunks

    def get_reduces(self):
        """Return the list of reduce chunks built by the constructor."""
        return self.reduce_chunks


# Module-level alias of the YabsServerMakeBinBases task class.
# NOTE(review): presumably the name the Sandbox task loader looks up to find
# the task class of this module — confirm.
__Task__ = YabsServerMakeBinBases
