import re
import os.path
import json
import itertools
import logging
import hashlib
from sandbox import common
from sandbox.common import rest

import sandbox.sandboxsdk.paths as sdk_paths

from sandbox import sdk2

from sandbox.common.types.task import ReleaseStatus
from sandbox.projects.common.yabs.server.db.task.mysql import use_myisampack
from sandbox.projects.yabs.qa.resource_types import BS_RELEASE_YT, UCTool
from sandbox.projects.yabs.qa.utils.resource import sync_resource

from sandbox.projects.yabs.sandbox_task_tracing import trace, trace_calls, trace_subprocess
from sandbox.projects.yabs.sandbox_task_tracing.wrappers.sandbox.sandboxsdk.process import run_process


@trace_calls
def get_yabscs(task, res_id=None):
    """Fetch and unpack the yabscs tool from a BS_RELEASE_YT resource.

    Syncs a stable UCTool archive and a BS_RELEASE_YT archive (``res_id`` is
    forwarded to ``sync_resource`` together with ``release_status=STABLE``).
    Both archives are extracted with ``tar`` in run_process's working directory.
    The bundled ``./uc`` then extracts ``package/linux-yandexcs.tar.uc.lzma-2``
    into a freshly emptied ``bs_release_yt`` folder.

    :param task: unused; kept for interface compatibility
    :param res_id: optional BS_RELEASE_YT resource id
    :return: path of the ``bs_release_yt`` folder with the unpacked tool
    """
    uc_archive = sync_resource(resource_type=UCTool, release_status=ReleaseStatus.STABLE)
    yabscs_archive = sync_resource(
        resource=res_id,
        resource_type=BS_RELEASE_YT,
        release_status=ReleaseStatus.STABLE,
    )

    run_process(['tar', '-xzf', uc_archive], log_prefix='tar_uc')
    run_process(['tar', '-xf', yabscs_archive], log_prefix='tar_yabscs')

    target_dir = sdk_paths.make_folder('bs_release_yt', delete_content=True)
    uc_command = [
        './uc', '-x',
        '-f', 'package/linux-yandexcs.tar.uc.lzma-2',
        '-t', target_dir,
    ]
    run_process(uc_command, log_prefix='uc_yabscs')
    return target_dir


def get_yabscs_yt_token(task):
    """Return the ``yabscs_yt_token`` vault secret.

    The task author's secret is tried first. If the vault reports it as not
    found or not allowed, the error is logged and the robot-yabs-cs-sb secret
    is returned instead.
    """
    vault_name = 'yabscs_yt_token'
    try:
        token = task.get_vault_data(task.author, vault_name)
    except (common.errors.VaultNotFound, common.errors.VaultNotAllowed) as err:
        logging.info("%s, falling back to robot-yabs-cs-sb token", err)
        token = task.get_vault_data('robot-yabs-cs-sb', vault_name)
    return token


def _glue_resource_tree(sync_resource, rest_client, res_id, next_id_key, overriding_id_key):
    attrs = rest_client.resource[res_id].attribute.read()

    glued_data = {}

    # Resource under next_id DOES NOT override earlier mapping table->resource!
    for a in attrs:
        if a['name'] == next_id_key:
            next_id = a['value']
            glued_data = _glue_resource_tree(sync_resource, rest_client, next_id, next_id_key, overriding_id_key)

    def update_glued_data(new_data):
        for key, section in new_data.iteritems():
            if key in glued_data:
                glued_data[key].update(section)
            else:
                glued_data[key] = section

    path = sync_resource(res_id)
    with open(path) as cnt_file:
        my_data = json.load(cnt_file)
        update_glued_data(my_data)

    for a in attrs:
        if a['name'] == overriding_id_key:
            overriding_id = a['value']
            override = _glue_resource_tree(sync_resource, rest_client, overriding_id, next_id_key, overriding_id_key)
            update_glued_data(override)

    return glued_data


# Resource attribute pointing to the next YABS_MYSQL_ARCHIVE_CONTENTS resource
# in the chain; used to add new tables automatically. Entries of the resource
# it points to do NOT override the pointing resource's own entries.
MYSQL_ARCHIVE_CONTENTS_NEXT_KEY = 'next_id'

# Resource attribute pointing to a resource whose entries DO override the
# current one; used in emergency cases to fix old tables manually (column
# addition, etc.).
MYSQL_ARCHIVE_CONTENTS_OVERRIDING_KEY = 'overriding_id'

# Shard count assumed for the plain 'lm_dumps' section when the glued contents
# have no 'lm_dumps_for_shard_count' mapping.
_DEFAULT_LM_SHARD_COUNT = 18


class GluedMySQLArchiveContents(object):
    """Merged view of a tree of YABS_MYSQL_ARCHIVE_CONTENTS resources.

    The tree rooted at ``res_id`` is glued by following its ``next_id`` and
    ``overriding_id`` attributes. The resulting sections are exposed as
    ``tables``, ``lm_dumps``, ``sizes`` and ``sources``; each is an empty dict
    when absent.
    """

    def __init__(self, sync_resource, res_id, rest_client=None):
        """
        :param sync_resource: callable mapping a resource id to a local file path
        :param res_id: id of the root resource
        :param rest_client: Sandbox REST client; a new ``rest.Client()`` is used when falsy
        """
        if not rest_client:
            rest_client = rest.Client()
        self._res_id = res_id

        self._data = _glue_resource_tree(
            sync_resource,
            rest_client,
            res_id,
            MYSQL_ARCHIVE_CONTENTS_NEXT_KEY,
            MYSQL_ARCHIVE_CONTENTS_OVERRIDING_KEY,
        )
        glued = self._data

        self.tables = glued.get('tables', {})
        self.lm_dumps = glued.get('lm_dumps', {})

        # FIXME remove _fix_sizes when all MYSQL_ARCHIVE_CONTENTS provide pre-myisampack sizes
        self.sizes = _fix_sizes(self.tables, glued.get('sizes', {}))
        self.sources = glued.get('sources', {})

        # Without an explicit per-shard-count mapping, the plain lm_dumps
        # section is assumed to belong to the default shard count.
        fallback = {str(_DEFAULT_LM_SHARD_COUNT): self.lm_dumps}
        self._lm_dumps_for_shard_count = glued.get('lm_dumps_for_shard_count', fallback)

    def get_lm_dumps_for_shard_count(self, shard_count):
        """Return the linear model dumps for ``shard_count`` shards.

        :raises RuntimeError: if there are no dumps for that shard count
        """
        dumps_by_count = self._lm_dumps_for_shard_count
        try:
            return dumps_by_count[str(shard_count)]
        except KeyError:
            raise RuntimeError(
                "No linear model dumps for {} shards in resource {}".format(shard_count, self._res_id)
            )


def _fix_sizes(tables, sizes):
    def _iter_sizes():
        for table_key, res_id in tables.iteritems():
            sizes_key = str(res_id)
            try:
                size_maybe_post_myisampack = sizes[sizes_key]
            except KeyError:
                logging.warning("Unknown table size for %s (resource id %s)", table_key, res_id)
                pass
            else:
                _, _, tname = table_key.split('.')
                yield sizes_key, size_maybe_post_myisampack * (3 if use_myisampack(tname) else 1)
    return dict(_iter_sizes())


def append_mysql_archive_contents(task, chain_res_id, res_id_to_append, rest_client=None):
    """Append a YABS_MYSQL_ARCHIVE_CONTENTS resource to the end of a resource chain.

    The function tries to create the ``next_id`` attribute on ``chain_res_id``
    with value ``str(res_id_to_append)``. If that fails for any reason, the
    resource is assumed to be chained already. When it has a ``next_id``
    attribute, the append is retried on the resource it names; otherwise the
    original error is re-raised.

    :param task: unused; forwarded unchanged to the recursive call
    :param rest_client: Sandbox REST client; a new ``rest.Client()`` is used when falsy
    """
    client = rest_client or rest.Client()
    missing = object()
    try:
        client.resource[chain_res_id].attribute.create(
            {'name': MYSQL_ARCHIVE_CONTENTS_NEXT_KEY, 'value': str(res_id_to_append)}
        )
    except Exception as err:
        existing_attrs = client.resource[chain_res_id].attribute.read()
        next_value = next(
            (attr['value'] for attr in existing_attrs if attr['name'] == MYSQL_ARCHIVE_CONTENTS_NEXT_KEY),
            missing,
        )
        if next_value is missing:
            raise err
        return append_mysql_archive_contents(task, int(next_value), res_id_to_append, client)


# Table names dropped from the set extracted from mkdb_info MySQL queries.
GHOST_TABLES = frozenset(
    ['Banner{:02d}'.format(n) for n in range(91, 98)]
    + ['BannerForGen{}'.format(n) for n in range(1, 150)]
)
# Added to the table set of every tag that has a MySQL instance tag.
HARDCODED_TABLES = frozenset({'DataTime'})  # FIXME do we really need this?


@trace_calls(save_arguments=(1, 'tag'))
def get_mkdb_info(tooldir, tag):
    """Return the parsed JSON printed by ``dbtool mkdb_info <tag>``."""
    raw_output = _get_dbtool_out(tooldir, 'mkdb_info', tag)
    return json.loads(raw_output)


@trace_calls
def get_full_mkdb_info(tooldir, __cache={}):
    """Return the parsed JSON printed by ``dbtool mkdb_info`` (all tags).

    Results are memoized per absolute ``tooldir`` path in the deliberately
    shared mutable default ``__cache``.
    """
    # output is too big to store in task context, so it needs to be cached somewhere else
    cache_key = os.path.abspath(tooldir)
    if cache_key not in __cache:
        __cache[cache_key] = json.loads(_get_dbtool_out(tooldir, 'mkdb_info'))
    return __cache[cache_key]


@trace_calls
def get_base_ver(tooldir):
    """Return the output of ``dbtool base_ver`` with surrounding whitespace stripped."""
    output = _get_dbtool_out(tooldir, 'base_ver')
    return output.strip()


@trace_calls
def get_base_internal_vers(tooldir):
    """Return the ``shards`` entry of the JSON printed by ``dbtool internal_ver``."""
    parsed = json.loads(_get_dbtool_out(tooldir, 'internal_ver'))
    return parsed['shards']


@trace_calls
def get_cs_import_ver(tooldir):
    """Return the output of ``dbtool cs_import_ver`` with surrounding whitespace stripped."""
    output = _get_dbtool_out(tooldir, 'cs_import_ver')
    return output.strip()


@trace_calls
def get_importer_code_ver(tooldir):
    """Return the per-importer code checksums reported by dbtool.

    Parses the JSON printed by ``dbtool cs_import_ver_separate``. Older dbtool
    versions prefix keys with ``CSImportVer``; that prefix is removed.

    :param tooldir: path to the directory containing dbtool
    :type tooldir: str
    :return: mapping from importer name to the checksum of its code and dependencies
    :rtype: dict
    """
    raw = json.loads(_get_dbtool_out(tooldir, 'cs_import_ver_separate').strip())
    legacy_prefix = 'CSImportVer'
    normalized = {}
    for importer_name, checksum in raw.items():
        if importer_name.startswith(legacy_prefix):
            importer_name = importer_name[len(legacy_prefix):]
        normalized[importer_name] = checksum
    return normalized


@trace_calls
def get_cs_fetch_ver(tooldir):
    """Return the output of ``dbtool cs_fetch_ver`` with surrounding whitespace stripped."""
    output = _get_dbtool_out(tooldir, 'cs_fetch_ver')
    return output.strip()


@trace_calls
def get_cs_advmachine_export_phrase_info_ver(tooldir):
    """Return the output of ``dbtool advmachine_export_phrase_info_ver``, stripped."""
    output = _get_dbtool_out(tooldir, 'advmachine_export_phrase_info_ver')
    return output.strip()


@trace_calls
def get_cs_advmachine_export_banners_ver(tooldir):
    """Return the output of ``dbtool advmachine_export_banners_ver``, stripped."""
    output = _get_dbtool_out(tooldir, 'advmachine_export_banners_ver')
    return output.strip()


def _get_dbtool_out(tooldir, *args):
    """Run ``<tooldir>/dbtool <args...>`` and return its captured stdout as a string.

    The log prefix is ``dbtool`` joined with the arguments by underscores.
    """
    command = (os.path.join(tooldir, 'dbtool'),) + args
    log_prefix = '_'.join(('dbtool',) + args)
    process = run_process(command, log_prefix=log_prefix)
    with open(process.stdout_path) as stdout_file:
        return stdout_file.read()


class BrokenMkdbInfo(Exception):
    """Raised when mkdb_info for a tag lists MySQL tables but has no MySqlTag."""


def _get_tag_tables(tooldir, tag):
    """Return ``(mysql_instance_tag, tables)`` for the base tag ``tag``.

    ``tables`` contains the tables referenced by the tag's MySql queries, minus
    GHOST_TABLES. HARDCODED_TABLES are added when the tag has a MySqlTag.

    :raises BrokenMkdbInfo: if MySQL tables were found but there is no MySqlTag
    """
    info = get_mkdb_info(tooldir, tag)
    instance = info['Shard'].get('MySqlTag')

    tables = set()
    if instance:
        tables.update(HARDCODED_TABLES)

    mysql_queries = (entry['SQL'] for entry in info['Queries'] if entry['Type'] == 'MySql')
    for query in mysql_queries:
        # FIXME BSSERVER-6412 fix tables in dbtool and use them
        tables.update(_extract_tables(query) - GHOST_TABLES)

    if not instance and tables:
        raise BrokenMkdbInfo("Broken mkdb_info for {}".format(tag))
    return instance, tables


def get_mysql_tag(tooldir, db_tag):
    """Return the ``MySqlTag`` of the ``Shard`` section of mkdb_info for ``db_tag`` (or None)."""
    shard_info = get_mkdb_info(tooldir, db_tag)['Shard']
    return shard_info.get('MySqlTag')


def _extract_tables(query):
    """Extract tables from SQL select query"""
    tablefactor_rx = re.compile(
        r'(?:FROM|[\w]*JOIN)[\s]+?([\w]+|\([\w]+?(?:[\s]*?,[\s]*?[\w]+)+?\))',
        flags=re.I)

    if re.match(r'[\s]*SELECT.+', query, flags=re.I) is None:
        return set()
    table_factors = tablefactor_rx.findall(query)
    if not table_factors:
        raise RuntimeError("Bad query: " + query)
    tables = set()
    for table_factor in table_factors:
        for tbl_candidate in table_factor.lstrip('(').rstrip(')').split(','):
            match = re.match(r'^[\s]*([\w]+)[\s]*$', tbl_candidate)
            if match is None:
                raise RuntimeError("Bad table name: %s", tbl_candidate)
            tables.add(match.group(1))
    return tables


def calc_combined_settings_md5(cs_settings_archive_res_id, cs_settings_patch_res_id=None, settings_spec=None):
    """Return an MD5 hex digest identifying a combination of CS settings inputs.

    The digest covers, in order and only for the truthy inputs: the ``md5``
    of the settings archive resource, the ``md5`` of the settings patch
    resource, and the key-sorted JSON of ``settings_spec``. ``settings_spec``
    may be a JSON string or an already-parsed object, matching what
    ``get_cs_settings`` accepts.

    :return: hex digest string, or None when all inputs are falsy
    """
    if not any([cs_settings_archive_res_id, cs_settings_patch_res_id, settings_spec]):
        return None

    def _as_bytes(value):
        # hashlib needs bytes on Python 3. ASCII text encodes to the same bytes
        # Python 2 hashed implicitly, so digests are unchanged.
        return value.encode('utf-8') if isinstance(value, type(u'')) else value

    digest = hashlib.md5()
    if cs_settings_archive_res_id:
        settings_archive = sdk2.Resource[cs_settings_archive_res_id]
        digest.update(_as_bytes(settings_archive.md5))
    if cs_settings_patch_res_id:
        settings_patch = sdk2.Resource[cs_settings_patch_res_id]
        digest.update(_as_bytes(settings_patch.md5))
    if settings_spec:
        # type(u'') is `unicode` on Python 2 and `str` on Python 3.
        if isinstance(settings_spec, (str, type(u''))):
            settings_spec = json.loads(settings_spec)
        sorted_json = json.dumps(settings_spec, sort_keys=True)
        digest.update(_as_bytes(sorted_json))
    return digest.hexdigest()


def json_from_resource(task, res_id):
    """Sync resource ``res_id`` via ``task.sync_resource`` and return its parsed JSON content."""
    local_path = task.sync_resource(res_id)
    with open(local_path) as handle:
        parsed = json.load(handle)
    return parsed


def encode_tags(tags_list):
    """Encode tags as ``(tag1)&(tag2)&...``; an empty list gives an empty string."""
    wrapped = ['({})'.format(tag) for tag in tags_list]
    return '&'.join(wrapped)


def encode_settings(settings):
    """Turn ``{name: [value, ...]}`` into ``{name: {encoded_tags: value}}``.

    Each value's ``tags`` list (empty if missing) is encoded with
    ``encode_tags``. When two values encode to the same tags, the later one
    wins.
    """
    encoded = {}
    for setting_name, setting_values in settings.items():
        by_tags = {}
        for setting_value in setting_values:
            by_tags[encode_tags(setting_value.get('tags', []))] = setting_value
        encoded[setting_name] = by_tags
    return encoded


def is_patch_encoded(res_id):
    """Return the ``encoded_patch`` attribute of resource ``res_id``, or False if it has none."""
    resource = sdk2.Resource[res_id]
    return getattr(resource, 'encoded_patch', False)


def decode_settings(settings):
    """Turn ``{name: {encoded_tags: value}}`` back into ``{name: [value, ...]}``.

    Values are ordered by their encoded-tags key.
    """
    decoded = {}
    for setting_name, by_tags in settings.items():
        decoded[setting_name] = [by_tags[key] for key in sorted(by_tags)]
    return decoded


def get_cs_settings(task, cs_settings_archive_res_id, cs_settings_patch_res_id=None, settings_spec=None):
    """Build the CS settings and return them as a JSON string.

    1. Load the settings archive resource as JSON; use an empty dict if no
       archive id is given.
    2. If a patch resource is given together with the archive, apply it with
       jsondiff (``compact`` syntax). Patches flagged as encoded are applied to
       the tag-encoded form of the settings (``encode_settings``).
    3. Shallow-merge ``settings_spec`` on top with ``dict.update``. It may be a
       JSON string or an already-parsed object, and is tag-encoded too when
       the patch was.
    4. If the encoded form was used, decode it back, then serialize to JSON.

    :return: JSON string with the resulting settings
    """
    encoded_patch = False
    if cs_settings_archive_res_id:
        settings = json_from_resource(task, cs_settings_archive_res_id)
        if cs_settings_patch_res_id:
            # jsondiff is only needed to apply a patch, so import it lazily.
            from jsondiff import JsonDiffer
            encoded_patch = is_patch_encoded(cs_settings_patch_res_id)
            if encoded_patch:
                settings = encode_settings(settings)
            differ = JsonDiffer(syntax='compact')
            diff = differ.unmarshal(json_from_resource(task, cs_settings_patch_res_id))
            settings = differ.patch(settings, diff)
    else:
        settings = {}
    if settings_spec:
        # type(u'') is `unicode` on Python 2 and `str` on Python 3; the bare
        # `unicode` name does not exist on Python 3.
        if isinstance(settings_spec, (str, type(u''))):
            settings_spec_parsed = json.loads(settings_spec)
        else:
            settings_spec_parsed = settings_spec
        if encoded_patch:
            settings_spec_parsed = encode_settings(settings_spec_parsed)
        settings.update(settings_spec_parsed)
    if encoded_patch:
        settings = decode_settings(settings)
    return json.dumps(settings)
