# -*- coding: utf-8 -*-
import json
import logging
import itertools
import hashlib
import os
import re
from collections import defaultdict

from sandbox import sdk2
from sandbox.common.errors import TaskError
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.sandboxsdk.parameters import SandboxStringParameter, ResourceSelector, ListRepeater, SandboxBoolParameter

from sandbox.common.types.client import Tag

from sandbox.projects.common.utils import get_or_default
from sandbox.projects.common.yabs.server.components.task import ServerTask, CSResource
from sandbox.projects.common.yabs.server.db import utils as dbutils
from sandbox.projects.common.yabs.server.db.task.cs import SettingsArchive
from sandbox.projects.common.yabs.server.db.yt_bases import get_cs_import_info
from sandbox.projects.common.yabs.server.util.general import check_tasks
from sandbox.projects.yabs.qa.constants import META_ROLES
from sandbox.projects.yabs.qa.resource_types import MkdbInfoResource, YT_ONESHOTS_PACKAGE
from sandbox.projects.yabs.qa.sut.constants import DEFAULT_STAT_SHARD, DEFAULT_SERVER_CONFIG_MODE
from sandbox.projects.yabs.qa.sut.metastat.adapters.common import META_MODE_TO_CLUSTER_TAG
from sandbox.projects.yabs.qa.utils.general import unpack_targz_package, get_files_md5
from sandbox.projects.yabs.qa.utils.importer import get_importer_bases_by_tables

from sandbox.projects.common.yabs.server.components import hugepage_warmup

from sandbox.projects.common.yabs.server.tracing import TRACE_WRITER_FACTORY
from sandbox.projects.yabs.sandbox_task_tracing import trace, trace_entry_point
from sandbox.projects.yabs.sandbox_task_tracing.wrappers.sandbox.generic import enqueue_task
from sandbox.projects.yabs.sandbox_task_tracing.wrappers.sandbox.sdk2 import new_resource_data

# Prefixes of task-context field names; the task combines them with a role
# and/or a shard number through get_field_name(). The fields built from them
# are marked as deprecated where the task writes them.
BIN_DB_LIST_COMMON_KEY = 'bin_db_list_common'
TAGS_COMMON_KEY = 'tags_common'
TAGS_SHARD_KEY = 'tags_shard'

# Server role -> value of the 'cluster_set_config' entry that the task puts
# into the custom environment of servers started for that role.
ROLE_TO_CLUSTER = {
    'bs': 'bs',
    'bsrank': 'yabs',
    'yabs': 'yabs',
}


def get_base_name(base, role):
    """Return the base name prefixed with its role: ``<role>_<base>``."""
    name_parts = [role, base]
    return '_'.join(name_parts)


# Prefix of the per-shard base tags of the 'yabs' role from which the task
# extracts the first number of each tag into ctx['baseno_list'].
TAG_FOR_COLLECTING_BASENO = 'yabs_st'


def get_field_name(name, shard=None, role=None):
    """Build a context field name of the form ``name[_shard][_role]``.

    ``shard`` and ``role`` are appended (in that order) only when they are
    not None; an empty string is still appended.
    """
    suffixes = [part for part in (shard, role) if part is not None]
    return '_'.join([name] + suffixes)


class StatShards(SandboxStringParameter):
    # Whitespace-separated yabs-stat shards; the task falls back to this value
    # when the (cut) shard map has no values.
    name = 'stat_shards'
    description = 'Yabs-stat shards (if no shard map specified)'
    default_value = DEFAULT_STAT_SHARD


class ShardMap(ResourceSelector):
    # Optional resource whose file content the task parses as a JSON shard map;
    # when it is not set (or the file is empty) an empty map is used.
    name = 'shard_map_res_id'
    description = 'Resource with shard map'
    required = False
    multiple = False


class ShardsKeys(ListRepeater, SandboxStringParameter):
    # Keys of the shard map that the task keeps; entries with other keys are
    # dropped before the shard map is used.
    name = 'shards_keys'
    description = 'Keys in shard map resource'
    default_value = ['A', 'B', 'C']
    required = True


class CreateMkdbInfo(SandboxBoolParameter):
    # When enabled, the task dumps the full mkdb info as sorted JSON into a
    # MkdbInfoResource and stores the md5 of that JSON in ctx['mkdb_info_md5'].
    name = 'create_mkdb_info'
    description = 'Run mkdb_info and dump results to resource'
    default_value = False


class YabsServerSetupYaMake(ServerTask):
    # Starts meta and stat yabs-server instances (created through the ServerTask
    # helpers), reads the binary base list each of them reports and stores the
    # resulting base version, base tags and related metadata in the task context.
    execution_space = 10 * 1024
    required_ram = 8 * 1024
    type = 'YABS_SERVER_SETUP_YA_MAKE'
    description = 'Check that yabs-server (built with YA_MAKE) starts and extract binary base lists'

    client_tags = Tag.LINUX_PRECISE
    cores = 1

    input_parameters = ServerTask.input_parameters + (
        CSResource,
        CreateMkdbInfo,
        ShardMap,
        ShardsKeys,
        StatShards,
        SettingsArchive,
    )

    @property
    def yabscs_path(self):
        """Path returned by ``dbutils.get_yabscs`` for the CS resource, fetched once per process.

        The value is cached on the instance. The cache is read through the same
        (name-mangled) private attribute it is written to; a ``hasattr`` check
        with the literal string ``"__yabscs_path"`` would never see it.
        """
        try:
            return self.__yabscs_path
        except AttributeError:
            pass
        yabscs_res_id = self.ctx.get(CSResource.name)
        self.__yabscs_path = dbutils.get_yabscs(self, yabscs_res_id)
        return self.__yabscs_path

    @staticmethod
    def _parse_bin_db_list(server):
        """Return ``(versions, tags)`` sets parsed from the binary base list of ``server``.

        The server is used as a context manager while its list is read. Every
        file name must consist of exactly three dot-separated parts
        (version, tag and a third part that is ignored), otherwise the
        unpacking raises ValueError.
        """
        versions = set()
        tags = set()
        with server:
            for filename in server.get_bin_db_list():
                ver, tag, _ = filename.split('.')
                versions.add(ver)
                tags.add(tag)
        return versions, tags

    def get_meta_base_tags(self):
        """Collect base tags requested by meta servers.

        :return: tuple ``(meta_tags, versions)`` where ``meta_tags`` is a
            defaultdict ``role -> set of tags`` and ``versions`` is the set of
            base versions seen in all meta base lists.
        """
        meta_tags = defaultdict(set)
        meta_binary_base_versions = set()

        for role, server in self._iter_meta_servers():
            versions, tags = self._parse_bin_db_list(server)
            meta_binary_base_versions |= versions
            if tags:  # a role gets an entry only when some base was reported for it
                meta_tags[role] |= tags

        logging.debug("Base tags by role and meta_mode: %s", meta_tags)
        return meta_tags, meta_binary_base_versions

    def get_stat_base_tags(self, shard_map):
        """Collect base tags requested by stat servers of the needed shards.

        :param shard_map: dict ``shard key -> comma-separated shard numbers``;
            when it has no values the StatShards parameter is used instead.
        :return: tuple ``(common_stat_tags, stat_shard_tags, versions)``:
            ``common_stat_tags`` maps role to the tags common for its shards
            (key ``None`` holds the tags common for all roles),
            ``stat_shard_tags`` maps ``role -> shard -> set of tags`` and
            ``versions`` is the set of base versions seen in all stat base lists.
        """
        stat_shard_tags = defaultdict(lambda: defaultdict(set))
        stat_binary_base_versions = set()

        needed_shards = set(shard_map.values() or get_or_default(self.ctx, StatShards).split())
        # ['01', '05', '16', '02,03,04'] -> ['01', '05', '16', '02', '03', '04']
        needed_shards = ','.join(needed_shards).split(',')  # TODO: shard_map must be {string: list}

        for role, shard, server in self._iter_stat_servers(needed_shards):
            versions, tags = self._parse_bin_db_list(server)
            stat_binary_base_versions |= versions
            if tags:  # a shard gets an entry only when some base was reported for it
                stat_shard_tags[role][shard] |= tags

        common_stat_tags = defaultdict(set)
        common_all_stat_tags = set()
        if len(needed_shards) > 1:  # More fair. For TE
            for role, shards_tags in stat_shard_tags.iteritems():
                common_stat_tags[role] = set.intersection(*shards_tags.values())
                common_all_stat_tags |= common_stat_tags[role]
        else:  # For Oneshots
            all_tag_sets = list(itertools.chain.from_iterable(
                shards_tags.values() for shards_tags in stat_shard_tags.values()
            ))
            # set.intersection() needs at least one argument; with no stat bases at all
            # keep the empty set and let the caller report the missing tags.
            if all_tag_sets:
                common_all_stat_tags = set.intersection(*all_tag_sets)
            for role, shards_tags in stat_shard_tags.iteritems():
                common_stat_tags[role] = shards_tags[needed_shards[0]] & common_all_stat_tags

        common_stat_tags[None] = common_all_stat_tags

        logging.debug("Base tags by stat shard: %s", stat_shard_tags)
        return common_stat_tags, stat_shard_tags, stat_binary_base_versions

    @trace_entry_point(writer_factory=TRACE_WRITER_FACTORY)
    def on_execute(self):
        """
        Output resource (DB_INFO) is empty and is only used to pass db_ver

        Inside the memoized ``main`` stage the task collects base tags from meta
        and stat servers, checks that at most one base version is used and fills
        the context (``db_ver``, ``base_tags_*``, ``baseno_list``,
        ``cs_import_ver``, the deprecated ``bin_db_list*``/``tags_*`` fields and,
        optionally, mkdb info). If a YT oneshots package is found, a child task
        is enqueued to list its tables; after the stage its result is turned
        into ``common_oneshots_bases``.

        Raises SandboxTaskFailureError when several base versions are found and
        TaskError when no tag for collecting baseno is present.
        """
        with self.memoize_stage.main(commit_on_entrance=False), trace('main'):
            if not hugepage_warmup.check_required_ram(self):
                self.abandon_host("Required RAM check failed")

            shard_map = self._cut_shard_map(self._get_shard_map())
            reversed_shard_map = {v: k for k, v in shard_map.items()}  # OTHER tag key used to contain multiple shard_nums, but it's okay because it is not used anymore

            meta_roles_tags, meta_binary_base_versions = self.get_meta_base_tags()
            common_stat_tags, stat_tags, stat_binary_base_versions = self.get_stat_base_tags(shard_map)

            binary_base_versions = meta_binary_base_versions | stat_binary_base_versions

            if len(binary_base_versions) > 1:
                raise SandboxTaskFailureError(
                    "More than one binary base version found in yabs-db-lists: {}".format(
                        ', '.join(sorted(binary_base_versions))
                    )
                )
            if binary_base_versions:
                self.ctx["db_ver"] = binary_base_versions.pop()

            bases_with_baseno = set()

            common_tags = {}
            for role in META_ROLES:
                common_tags[role] = meta_roles_tags[role] | common_stat_tags[role]
                tags_in_role = sorted(common_tags[role])
                self.ctx[get_field_name(BIN_DB_LIST_COMMON_KEY, role)] = ' '.join(tags_in_role)  # deprecated
                self.ctx[get_field_name(TAGS_COMMON_KEY, role)] = tags_in_role  # deprecated
                self.ctx['base_tags_meta_{}'.format(role)] = sorted(meta_roles_tags[role])

            for role, shards_tags in stat_tags.iteritems():
                self.ctx['base_tags_stat_{}_COMMON'.format(role)] = sorted(common_stat_tags[role])
                for shard_num, tags in shards_tags.iteritems():
                    shard_key = reversed_shard_map.get(shard_num, str(shard_num))
                    tags_shard = sorted(tags - common_tags[role])
                    # Store a separate copy so the two context fields do not share one list.
                    self.ctx['base_tags_stat_{}_{}'.format(role, shard_key)] = list(tags_shard)
                    self.ctx[get_field_name(TAGS_SHARD_KEY, shard_num, role)] = tags_shard  # deprecated

                    if role == 'yabs':
                        bases_with_baseno.update(
                            tag for tag in tags_shard if tag.startswith(TAG_FOR_COLLECTING_BASENO)
                        )

            if not bases_with_baseno:
                raise TaskError('There is no tag {} for collecting baseno'.format(TAG_FOR_COLLECTING_BASENO))
            # The first run of digits in each tag is taken as its baseno.
            self.ctx['baseno_list'] = [int(re.findall(r'\d+', tag)[0]) for tag in bases_with_baseno]

            for shard_key, shard_value in shard_map.iteritems():
                shards_list = shard_value.split(',')
                for role in META_ROLES:
                    bin_db_list_field = get_field_name('bin_db_list', shard_key, role)  # deprecated
                    tags_of_shards = itertools.chain.from_iterable(
                        self.ctx[get_field_name(TAGS_SHARD_KEY, shard, role)] for shard in shards_list
                    )
                    self.ctx[bin_db_list_field] = ' '.join(sorted(tags_of_shards))  # deprecated

            if self.ctx.get(CreateMkdbInfo.name):
                self._dump_mkdb_info()

            self.ctx['cs_import_ver'] = self._get_cs_import_ver()

            oneshots_package = YT_ONESHOTS_PACKAGE.find(attrs=self.global_key).first()
            common_oneshots_md5 = self._get_common_oneshots_md5(oneshots_package)
            self.ctx['common_oneshots_md5'] = common_oneshots_md5
            if common_oneshots_md5:
                task_id = self._run_get_oneshot_tables(oneshots_package)
                self.ctx['get_oneshot_tables_task_id'] = task_id
                logging.info('Run task ExecuteYTOneshot to get common oneshots tables: #%s', task_id)

        # Runs on every execution, i.e. also after the memoized stage above has been passed.
        if self.ctx.get('common_oneshots_md5'):
            check_tasks(self, self.ctx['get_oneshot_tables_task_id'])
            common_oneshots_tables = sdk2.Task[self.ctx['get_oneshot_tables_task_id']].Parameters.oneshot_tables
            self.ctx['common_oneshots_bases'] = self._get_common_oneshots_bases(common_oneshots_tables)
        else:
            self.ctx['common_oneshots_bases'] = []

    def _dump_mkdb_info(self):
        """Write mkdb info as key-sorted JSON to a MkdbInfoResource and store its md5 in the context."""
        mkdb_info_filename = 'mkdb_info.json'
        mkdb_info_sorted = json.dumps(self._get_mkdb_info(), sort_keys=True)
        with open(mkdb_info_filename, 'w') as mkdb_info_file:
            mkdb_info_file.write(mkdb_info_sorted)

        mkdb_info_resource = self.create_resource(
            description='Mkdb info result',
            resource_path=mkdb_info_filename,
            resource_type=MkdbInfoResource,
        )
        self.mark_resource_ready(mkdb_info_resource.id)

        self.ctx['mkdb_info_md5'] = hashlib.md5(mkdb_info_sorted).hexdigest()

    def _cut_shard_map(self, shard_map):
        """Return a copy of ``shard_map`` restricted to the keys listed in the ShardsKeys parameter."""
        allowed_keys = get_or_default(self.ctx, ShardsKeys)
        return {key: value for key, value in shard_map.iteritems() if key in allowed_keys}

    @staticmethod
    def _get_custom_env(role):
        """Return newline-separated ``key=value`` custom environment lines for ``role``.

        Raises KeyError for a role missing from ROLE_TO_CLUSTER or
        META_MODE_TO_CLUSTER_TAG.
        """
        custom_env_dict = {
            'cluster_set_config': ROLE_TO_CLUSTER[role],
            'cluster_tag': META_MODE_TO_CLUSTER_TAG[role],
        }
        return '\n'.join('{}={}'.format(k, v) for k, v in custom_env_dict.items())

    def _iter_meta_servers(self):
        """Yield ``(role, meta server)`` for every config mode and every role of META_ROLES."""
        for config_mode in ('', DEFAULT_SERVER_CONFIG_MODE):  # '' - cloud config
            for role in META_ROLES:
                yield role, self.create_meta(
                    role,
                    config_mode=config_mode,
                    custom_env=self._get_custom_env(role),
                )

    def _iter_stat_servers(self, shards):
        """Yield ``(role, shard, stat server)`` for every config mode, role and non-empty shard."""
        for config_mode in ('', DEFAULT_SERVER_CONFIG_MODE):  # '' - cloud config
            for role in META_ROLES:
                for shard in shards:
                    if shard:  # empty names appear when the shard list is empty
                        yield role, shard, self.create_stat(
                            shard=shard,
                            config_mode=config_mode,
                            custom_env=self._get_custom_env(role),
                        )

    def _get_shard_map(self):
        """Load the shard map from the ShardMap resource.

        :return: the parsed JSON content of the resource file, or ``{}`` when
            the resource is not set or the file is empty.
        """
        res_id = get_or_default(self.ctx, ShardMap)
        if not res_id:
            return {}

        shard_map_path = self.sync_resource(res_id)
        with open(shard_map_path) as shard_map_file:
            shard_map_str = shard_map_file.read().strip()

        return json.loads(shard_map_str) if shard_map_str else {}

    def _get_cs_import_ver(self):
        """Return the result of ``dbutils.get_cs_import_ver`` for ``yabscs_path``."""
        return dbutils.get_cs_import_ver(self.yabscs_path)

    def _get_mkdb_info(self):
        """Return the result of ``dbutils.get_full_mkdb_info`` for ``yabscs_path``."""
        return dbutils.get_full_mkdb_info(self.yabscs_path)

    def _get_common_oneshots_md5(self, oneshots_package):
        """Return the md5 of the files in the oneshots package, or ``''`` when there is no package.

        The package is unpacked into a freshly created ``oneshots_dir``
        directory; ``os.mkdir`` raises OSError if that directory already exists.
        """
        oneshots_dir = 'oneshots_dir'
        if not oneshots_package:
            return ''
        package_path = str(new_resource_data(oneshots_package).path)
        os.mkdir(oneshots_dir)
        file_list = unpack_targz_package(package_path, oneshots_dir)
        return get_files_md5(file_list)

    def _run_get_oneshot_tables(self, oneshots_package):
        """Enqueue an EXECUTE_YT_ONESHOT task in print-tables-only mode and return its id."""
        task_class = sdk2.Task['EXECUTE_YT_ONESHOT']
        execute_yt_oneshot_task = enqueue_task(task_class(
            task_class.current,
            owner=self.owner,
            priority=self.priority,
            description='Get tables from YT oneshot: {}'.format(oneshots_package),
            oneshots_package=oneshots_package,
            print_tables_only=True,
        ))
        return execute_yt_oneshot_task.id

    def _get_common_oneshots_bases(self, common_oneshots_tables):
        """Return what ``get_importer_bases_by_tables`` yields for the given oneshot tables.

        CS settings are loaded only when the SettingsArchive parameter is set;
        otherwise ``None`` is passed to ``get_cs_import_info``.
        """
        settings_archive = self.ctx.get(SettingsArchive.name, None)
        cs_settings = dbutils.get_cs_settings(self, settings_archive) if settings_archive else None
        importers_info = get_cs_import_info(self.yabscs_path, cs_settings)
        mkdb_info = self._get_mkdb_info()
        return get_importer_bases_by_tables(common_oneshots_tables, importers_info, mkdb_info)


# Module-level alias of the task class.
# NOTE(review): presumably this is how the Sandbox loader discovers the task — confirm.
__Task__ = YabsServerSetupYaMake
