import hashlib
import json
import logging
import os
import re
import tarfile
from collections import defaultdict

from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.projects.common.yabs.server.db import utils as dbutils
from sandbox.projects.common.yabs.server.util.general import check_tasks
from sandbox.projects.resource_types import YANDEX_SURF_DATA
from sandbox.projects.yabs.qa.constants import META_ROLES
from sandbox.projects.yabs.qa.sut.constants import DEFAULT_SERVER_CONFIG_MODE
from sandbox.projects.yabs.qa.sut.metastat.adapters.common import META_MODE_TO_CLUSTER_TAG
from sandbox.projects.yabs.qa.utils.general import unpack_targz_package, get_files_md5
from sandbox.projects.yabs.qa.resource_types import (
    BS_RELEASE_YT,
    BS_RELEASE_TAR,
    YABS_SERVER_TESTENV_SHARD_MAP,
    MkdbInfoResource,
    YT_ONESHOTS_PACKAGE,
)
from sandbox.projects.yabs.qa.sut.factory import YabsServerFactoryStandalone


# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)
# Prefix of the stat shard tags (role 'yabs') from which base numbers are extracted.
TAG_FOR_COLLECTING_BASENO = 'yabs_st'
# Role -> value written to the 'cluster_set_config' entry of the custom server environment.
ROLE_TO_CLUSTER = {
    'bs': 'bs',
    'bsrank': 'yabs',
    'yabs': 'yabs',
}


def _sync_resource(res_id):
    """Return the local path of the resource looked up by ``res_id`` as unicode.

    The path is read from ``sdk2.ResourceData(...).path``; presumably this is
    what makes the resource available locally - confirm against the sdk2 docs.
    """
    resource = sdk2.Resource[res_id]
    resource_data = sdk2.ResourceData(resource)
    return unicode(resource_data.path)


class YabsServerGetBaseList(sdk2.Task):
    """Launch yabs-server and get base lists.

    The task unpacks the yabs-server bundle, asks the meta and stat servers
    created by ``YabsServerFactoryStandalone`` for their binary base lists and
    stores the resulting tags in the task context as
    ``base_tags_meta_<role>``, ``base_tags_stat_<role>_COMMON`` and
    ``base_tags_stat_<role>_<shard key>``. Versions, hashes and the baseno list
    are published through the output parameters.
    """

    class Parameters(sdk2.Parameters):
        bs_release_tar_resource = sdk2.parameters.Resource('Resource with yabs-server-bundle', resource_type=BS_RELEASE_TAR, required=True)
        bs_release_yt_resource = sdk2.parameters.Resource('Resource with yabscs package', resource_type=BS_RELEASE_YT, required=True)
        common_oneshots_resource = sdk2.parameters.Resource('Resource with common oneshots package', resource_type=YT_ONESHOTS_PACKAGE)
        surf_data_resource = sdk2.parameters.Resource('Resource with surf data', resource_type=YANDEX_SURF_DATA, default=121355487)
        # JSON file; its values are the shards passed to the stat servers.
        shard_map = sdk2.parameters.Resource('Resource with shard map', resource_type=YABS_SERVER_TESTENV_SHARD_MAP, required=True)
        # When set, mkdb info is dumped to a MkdbInfoResource and its md5 is published.
        create_mkdb_info = sdk2.parameters.Bool('Create resource with mkdb info', default=False)

        with sdk2.parameters.Output:
            db_ver = sdk2.parameters.String('Base version')
            mkdb_info_md5 = sdk2.parameters.String('Mkdb info md5')
            cs_import_ver = sdk2.parameters.String('CS import ver')
            baseno_list = sdk2.parameters.JSON('Baseno list')
            common_oneshots_md5 = sdk2.parameters.String('Common oneshots package md5')
            common_oneshots_bases = sdk2.parameters.String('Common oneshots bases')

    class Requirements(sdk2.Requirements):

        class Caches(sdk2.Requirements.Caches):
            pass

    def get_cs_import_ver(self):
        """Return the CS import version reported by the yabscs package.

        The package is taken from the ``bs_release_yt_resource`` parameter.
        """
        yabscs_path = dbutils.get_yabscs(self, self.Parameters.bs_release_yt_resource.id)
        return dbutils.get_cs_import_ver(yabscs_path)

    def get_mkdb_info_md5(self, bs_release_yt_resource):
        """Dump mkdb info to a ``MkdbInfoResource`` and return its md5 hex digest.

        Args:
            bs_release_yt_resource: resource with the yabscs package.

        Returns:
            md5 hex digest of the mkdb info serialised as JSON with sorted keys
            (the same text that is written to the resource file).
        """
        yabscs_path = dbutils.get_yabscs(self, bs_release_yt_resource.id)
        mkdb_info_filename = 'mkdb_info.json'
        mkdb_info = dbutils.get_full_mkdb_info(yabscs_path)
        # sort_keys makes the serialisation, and therefore the digest, independent of dict order.
        mkdb_info_sorted = json.dumps(mkdb_info, sort_keys=True)
        with open(mkdb_info_filename, 'w') as mkdb_info_file:
            mkdb_info_file.write(mkdb_info_sorted)

        mkdb_info_resource = MkdbInfoResource(self, 'Mkdb info result', mkdb_info_filename)
        sdk2.ResourceData(mkdb_info_resource).ready()

        digest = hashlib.md5()
        digest.update(mkdb_info_sorted)
        return digest.hexdigest()

    def get_common_oneshots_md5(self, oneshots_package):
        """Unpack the common oneshots package and return the md5 of its files.

        Args:
            oneshots_package: resource with the oneshots tar.gz package, may be empty.

        Returns:
            '' when no package is given, otherwise the value returned by
            ``get_files_md5`` for the unpacked file list.
        """
        oneshots_dir = 'oneshots_dir'
        if not oneshots_package:
            return ''
        package_path = str(sdk2.ResourceData(oneshots_package).path)
        # The directory can be left over from an earlier attempt in the same
        # working directory; a bare os.mkdir() would raise OSError in that case.
        if not os.path.isdir(oneshots_dir):
            os.mkdir(oneshots_dir)
        file_list = unpack_targz_package(package_path, oneshots_dir)
        return get_files_md5(file_list)

    def on_execute(self):
        """Collect base tags from the servers and publish the task outputs.

        Raises:
            TaskFailure: if more than one binary base version is found, or if no
                shard-specific 'yabs' stat tag starts with
                ``TAG_FOR_COLLECTING_BASENO``.
        """
        # NOTE(review): commit_on_entrance=False presumably makes the stage count as done
        # only after the whole body succeeds - confirm against sdk2 memoize_stage semantics.
        with self.memoize_stage.get_base_tags(commit_on_entrance=False):
            with open(_sync_resource(self.Parameters.shard_map), 'r') as shard_map_file:
                shard_map = json.load(shard_map_file)

            server_resource_path = _sync_resource(self.Parameters.bs_release_tar_resource.id)
            server_path = os.path.abspath("yabs_server_bundle")
            logging.debug("Extract yabs-server bundle from %s to %s", server_resource_path, server_path)
            with tarfile.open(server_resource_path) as archive:
                archive.extractall(path=server_path)

            surf_data_path = _sync_resource(self.Parameters.surf_data_resource)

            yabs_server_factory = YabsServerFactoryStandalone(
                server_path=server_path,
                surf_data_path=surf_data_path,
            )

            meta_roles_tags, meta_binary_base_versions = get_meta_base_tags(_iter_meta_servers(yabs_server_factory))
            common_stat_tags, stat_tags, stat_binary_base_versions = get_stat_base_tags(_iter_stat_servers(yabs_server_factory, shard_map.values()))

            binary_base_versions = meta_binary_base_versions | stat_binary_base_versions

            if len(binary_base_versions) > 1:
                # Sorted so that the error text does not depend on set iteration order.
                raise TaskFailure("More than one binary base version found in yabs-db-lists: {}".format(', '.join(sorted(binary_base_versions))))

            if binary_base_versions:
                db_ver = binary_base_versions.pop()
                with self.memoize_stage.set_output_db_ver():
                    self.Parameters.db_ver = db_ver

            bases_with_baseno = set()

            for role in META_ROLES:
                setattr(self.Context, 'base_tags_meta_{}'.format(role), sorted(meta_roles_tags[role]))

            # Shard value -> shard map key, used to name the per-shard context fields.
            reversed_shard_map = {v: k for k, v in shard_map.items()}
            for role, shards_tags in stat_tags.items():
                setattr(self.Context, 'base_tags_stat_{}_COMMON'.format(role), sorted(common_stat_tags[role]))
                for shard_num, tags in shards_tags.items():
                    shard_key = reversed_shard_map[shard_num]
                    # Only the tags that are not shared by every shard of this role.
                    tags_shard = sorted(tags - common_stat_tags[role])
                    setattr(self.Context, 'base_tags_stat_{}_{}'.format(role, shard_key), list(tags_shard))

                    if role == 'yabs':
                        bases_with_baseno.update(tag for tag in tags_shard if tag.startswith(TAG_FOR_COLLECTING_BASENO))

            if not bases_with_baseno:
                raise TaskFailure('There is no tag {} for collecting baseno'.format(TAG_FOR_COLLECTING_BASENO))

            with self.memoize_stage.set_output_baseno_list():
                # The first run of digits in a tag is taken as its base number. The result is
                # a sorted list so that the output does not depend on set iteration order.
                self.Parameters.baseno_list = sorted(int(re.findall(r'\d+', tag)[0]) for tag in bases_with_baseno)

            with self.memoize_stage.set_output_cs_import_ver():
                self.Parameters.cs_import_ver = self.get_cs_import_ver()

            if self.Parameters.create_mkdb_info:
                with self.memoize_stage.set_output_mkdb_info_md5():
                    self.Parameters.mkdb_info_md5 = self.get_mkdb_info_md5(self.Parameters.bs_release_yt_resource)

            if self.Parameters.common_oneshots_resource:
                common_oneshots_md5 = self.get_common_oneshots_md5(self.Parameters.common_oneshots_resource)
                with self.memoize_stage.set_output_common_oneshots_md5():
                    self.Parameters.common_oneshots_md5 = common_oneshots_md5

                if common_oneshots_md5:
                    # NOTE(review): _run_get_oneshot_tables and _get_common_oneshots_bases are
                    # not defined in the visible part of this class - confirm where they come from.
                    task_id = self._run_get_oneshot_tables(self.Parameters.common_oneshots_resource)
                    self.Context.get_oneshot_tables_task_id = task_id
                    logging.info('Run task ExecuteYTOneshot to get common oneshots tables: #%s', str(task_id))

        # NOTE(review): common_oneshots_bases is declared as a String output but defaults
        # to a list here - confirm the expected type.
        common_oneshots_bases = []
        if self.Context.get_oneshot_tables_task_id:
            check_tasks(self, self.Context.get_oneshot_tables_task_id)
            common_oneshots_tables = sdk2.Task[self.Context.get_oneshot_tables_task_id].Parameters.oneshot_tables
            common_oneshots_bases = self._get_common_oneshots_bases(common_oneshots_tables)

        with self.memoize_stage.set_output_common_oneshots_bases():
            self.Parameters.common_oneshots_bases = common_oneshots_bases


def _get_custom_env(role):
    """Build the custom environment text for a server of the given role.

    Returns newline-separated ``key=value`` lines for ``cluster_set_config``
    (from ROLE_TO_CLUSTER) and ``cluster_tag`` (from META_MODE_TO_CLUSTER_TAG).
    Raises KeyError when the role is missing from either mapping.
    """
    cluster_set_config = ROLE_TO_CLUSTER[role]
    cluster_tag = META_MODE_TO_CLUSTER_TAG[role]
    env_vars = {
        'cluster_set_config': cluster_set_config,
        'cluster_tag': cluster_tag,
    }
    env_lines = ['{}={}'.format(name, value) for name, value in env_vars.items()]
    return '\n'.join(env_lines)


def get_meta_base_tags(iter_meta_servers):
    """Collect binary base tags and versions reported by meta servers.

    Args:
        iter_meta_servers: iterable of ``(role, server)`` pairs; every server is
            used as a context manager and asked for ``get_bin_db_list()``.

    Returns:
        ``(tags_by_role, versions)``: a ``defaultdict(set)`` mapping role to the
        tags seen for it, and the set of all versions seen. Each file name must
        split on '.' into exactly three parts: version, tag and a third part
        that is ignored.
    """
    tags_by_role = defaultdict(set)
    versions = set()

    for role, server in iter_meta_servers:
        with server:
            name_parts = (db_file.split('.') for db_file in server.get_bin_db_list())
            for version, base_tag, _ignored in name_parts:
                versions.add(version)
                tags_by_role[role].add(base_tag)

    logging.debug("Base tags by role and meta_mode: %s", tags_by_role)
    return tags_by_role, versions


def get_stat_base_tags(iter_stat_servers):
    """Collect binary base tags and versions reported by stat servers.

    Args:
        iter_stat_servers: iterable of ``(role, shard, server)`` triples; every
            server is used as a context manager and asked for
            ``get_bin_db_list()``.

    Returns:
        ``(common_stat_tags, stat_shard_tags, stat_binary_base_versions)``:

        * ``common_stat_tags`` - ``defaultdict(set)``: role -> tags present in
          every shard of that role; the ``None`` key holds the union of those
          per-role common sets.
        * ``stat_shard_tags`` - role -> shard -> set of tags.
        * ``stat_binary_base_versions`` - set of all versions seen.

    Raises:
        ValueError: if a file name does not split on '.' into exactly three parts.
    """
    stat_shard_tags = defaultdict(lambda: defaultdict(set))
    stat_binary_base_versions = set()

    for role, shard, server in iter_stat_servers:
        with server:
            for filename in server.get_bin_db_list():
                # File names are '<version>.<tag>.<third part>'; the third part is ignored.
                ver, tag, _ = filename.split('.')
                stat_binary_base_versions.add(ver)
                stat_shard_tags[role][shard].add(tag)

    common_stat_tags = defaultdict(set)
    common_all_stat_tags = set()
    # items() rather than the Python 2-only iteritems(): same result, works on both versions.
    for role, shards_tags in stat_shard_tags.items():
        # A role only appears here once a tag was added, so there is at least one shard set.
        common_stat_tags[role] = set.intersection(*shards_tags.values())
        common_all_stat_tags |= common_stat_tags[role]

    common_stat_tags[None] = common_all_stat_tags

    logging.debug("Base tags by stat shard: %s", stat_shard_tags)
    return common_stat_tags, stat_shard_tags, stat_binary_base_versions


def _iter_meta_servers(yabs_server_factory):
    """Lazily yield ``(role, meta_server)`` for each config mode and each meta role.

    Servers are created with ``yabs_server_factory.create_meta`` one at a time,
    as the generator is consumed.
    """
    config_modes = ('', DEFAULT_SERVER_CONFIG_MODE)  # '' - cloud config
    for mode in config_modes:
        for meta_role in META_ROLES:
            env = _get_custom_env(meta_role)
            meta_server = yabs_server_factory.create_meta(
                meta_role,
                config_mode=mode,
                custom_env=env,
            )
            yield meta_role, meta_server


def _iter_stat_servers(yabs_server_factory, shards):
    """Lazily yield ``(role, shard, stat_server)`` for each config mode, role and shard.

    Falsy shard values are skipped. Servers are created with
    ``yabs_server_factory.create_stat`` one at a time, as the generator is
    consumed.
    """
    config_modes = ('', DEFAULT_SERVER_CONFIG_MODE)  # '' - cloud config
    for mode in config_modes:
        for meta_role in META_ROLES:
            for shard_id in shards:
                if not shard_id:
                    continue
                env = _get_custom_env(meta_role)
                stat_server = yabs_server_factory.create_stat(
                    shard=shard_id,
                    config_mode=mode,
                    custom_env=env,
                )
                yield meta_role, shard_id, stat_server
