import logging
import json
import hashlib
import time

from sandbox import common
from sandbox import sdk2
from sandbox.common.errors import TaskError, TaskFailure

from sandbox.sandboxsdk.parameters import ResourceSelector, SandboxStringParameter, SandboxBoolParameter, ListRepeater
from sandbox.sandboxsdk.svn import Arcadia
from sandbox.sandboxsdk.environments import PipEnvironment

from sandbox.projects.common.utils import get_or_default

import sandbox.projects.common.yabs.server.db.utils as dbutils
from sandbox.projects.common.yabs.server.db import yt_bases
from sandbox.projects.common.yabs.server.components.task import ServerMkdbInstallTask

from sandbox.projects.common.yabs.server.parameters import (
    BinDbList,
    BackupDate,
    Option,
    OptionsParameter,
    PrevArchiveContents,
    BsReleaseYT,
    AdditionalRestoreDescription,
    AdditionalRestoreDescriptionMD5,
)
from sandbox.projects.common.yabs.server.db.task.mysql import TableProvider

from sandbox.projects.yabs.bases.table_backups import get_server_instance
from sandbox.projects.yabs.qa.bases.sample_tables.parameters import (
    SamplingStrategyParameter,
    SamplingQueryTemplateResourceParameter,
    SamplingTablesResourceParameter,
    FindTablesForSampling,
    SamplingTablesBlackListResourceParameter,
    MinRowsCountForSamplingTables,
)
from sandbox.projects.yabs.qa.resource_types import (
    YABS_MYSQL_ARCHIVE_CONTENTS,
    YABS_CS_INPUT_SPEC,
    BS_RELEASE_YT,
    YABS_MYSQL_RESTORE_DESCRIPTION,
    YABS_CS_SETTINGS_ARCHIVE,
    YABS_CS_SETTINGS_PATCH,
)
from sandbox.projects.yabs.qa.utils import general, yt_utils


class MySQLArchiveContents(ResourceSelector):
    # Mandatory, single-valued selector of a YABS_MYSQL_ARCHIVE_CONTENTS resource.
    resource_type = YABS_MYSQL_ARCHIVE_CONTENTS
    name = "mysql_archive_contents"
    description = "MySQL archive contents."
    multiple = False
    required = True


class BSReleaseYT(ResourceSelector):
    # Optional, single-valued selector of a BS_RELEASE_YT resource.
    resource_type = BS_RELEASE_YT
    name = "bs_release_yt_resource"
    description = "Resource with cs_cycle and transport (last stable BS_RELEASE_YT if empty)"
    multiple = False
    required = False


class InputSpec(ResourceSelector):
    # Mandatory selector of a YABS_CS_INPUT_SPEC resource.
    resource_type = YABS_CS_INPUT_SPEC
    name = "input_spec"
    description = "Input spec for cs_import"
    required = True


class SettingsSpec(SandboxStringParameter):
    # Optional free-form string parameter; the description states it carries JSON.
    required = False
    name = "settings_spec"
    description = "CS Settings spec (json) to be passed with --settings-spec"


class SettingsArchive(ResourceSelector):
    # Optional selector of a YABS_CS_SETTINGS_ARCHIVE resource.
    resource_type = YABS_CS_SETTINGS_ARCHIVE
    name = "settings_archive"
    description = "CS settings archive"
    required = False


class CSSettingsPatch(ResourceSelector):
    # Optional selector of a YABS_CS_SETTINGS_PATCH resource.
    resource_type = YABS_CS_SETTINGS_PATCH
    name = "cs_settings_patch"
    description = "Update cs settings archive with jsondiff patch"
    required = False


class ExecCommonYtOneshots(SandboxBoolParameter):
    # Boolean switch, off by default.
    default_value = False
    name = "exec_common_yt_oneshots"
    description = "Execute common yt oneshots"


class SaveAllInputs(sdk2.parameters.Bool):
    # Boolean switch, on by default; both `default` and `default_value` are set.
    name = "save_all_inputs"
    description = "Save all input tables regardless of the desired importers to run"
    default_value = True
    default = True


class UseSaveInputFromCS(sdk2.parameters.Bool):
    # Boolean switch, off by default; both `default` and `default_value` are set.
    name = "use_save_input_from_cs"
    description = "Use cs import --save_input to save inputs"
    default_value = False
    default = False


class FilterInputArchiveTablesByOrderID(sdk2.parameters.Bool):
    # Boolean switch, on by default; both `default` and `default_value` are set.
    name = "filter_input_archive_tables_by_orderid"
    description = "Filter tables by OrderID"
    default_value = True
    default = True


class YtPool(sdk2.parameters.String):
    # String parameter whose default (under both attribute names) is yt_bases.YT_POOL.
    name = "yt_pool"
    description = "YT pool to use in operations"
    default = default_value = yt_bases.YT_POOL


class KeysForSamplingTables(ListRepeater, SandboxStringParameter):
    # Repeated string parameter with four default key names.
    name = "sampling_tables_keys"
    description = "Keys for sampling tables"
    default_value = [
        "BannerID",
        "CreativeID",
        "GroupExportID",
        "OrderID",
    ]


# Abstract base task: on_execute raises NotImplementedError here, so concrete
# behaviour is expected from subclasses. The class bundles the input parameters
# and the helpers that look tables up in the MySQL archive contents.
class YabsCSTask(ServerMkdbInstallTask):
    # NOTE(review): the unit is not visible in this file - presumably MiB (i.e. 80 GiB); confirm.
    required_ram = 80 * 1024

    # These parameters affect base contents.
    # Parameters that don't should be defined in child tasks.

    # Parameters declared in this module first, then those inherited from ServerMkdbInstallTask.
    input_parameters = (
        InputSpec,
        SettingsSpec,
        SettingsArchive,
        CSSettingsPatch,
        MySQLArchiveContents,
        BSReleaseYT,
        SamplingStrategyParameter,
        SamplingQueryTemplateResourceParameter,
        SamplingTablesResourceParameter,
        FindTablesForSampling,
        SamplingTablesBlackListResourceParameter,
        MinRowsCountForSamplingTables,
        KeysForSamplingTables,
        ExecCommonYtOneshots,
        SaveAllInputs,
        UseSaveInputFromCS,
        FilterInputArchiveTablesByOrderID,
        YtPool,
    ) + ServerMkdbInstallTask.input_parameters

    # Semaphore names for YT pools. They are not referenced elsewhere in this
    # part of the file - presumably used by child tasks (TODO confirm).
    YT_POOL_DEFAULT_SEMAPHORE = 'yabscs/pools/yabs-cs-sandbox'
    YT_POOL_REDUCE_SEMAPHORE = 'yabscs/pools/yabs-cs-sandbox-reduce'
    YT_POOL_FETCH_SEMAPHORE = 'yabscs/pools/yabs-cs-sandbox-fetch'

    # jsondiff pinned to 1.2.0; the CSSettingsPatch parameter describes a jsondiff patch.
    environment = (PipEnvironment('jsondiff', version="1.2.0"), )

    def read_json_resource(self, res_id):
        path = self.sync_resource(res_id)
        with open(path) as f:
            return json.load(f)

    def read_text_resource(self, res_id):
        path = self.sync_resource(res_id)
        with open(path) as f:
            return f.read()

    @property
    def mysql_archive_contents(self):
        """MySQLArchiveContents parameter value, resolved from the task context by get_or_default."""
        value = get_or_default(self.ctx, MySQLArchiveContents)
        return value

    @property
    def input_spec_res_id(self):
        """InputSpec parameter value, resolved from the task context by get_or_default."""
        value = get_or_default(self.ctx, InputSpec)
        return value

    @property
    def settings_spec(self):
        """SettingsSpec parameter value, resolved from the task context by get_or_default."""
        value = get_or_default(self.ctx, SettingsSpec)
        return value

    @property
    def cs_settings_archive_res_id(self):
        """SettingsArchive parameter value, resolved from the task context by get_or_default."""
        value = get_or_default(self.ctx, SettingsArchive)
        return value

    @property
    def cs_settings_patch_res_id(self):
        """CSSettingsPatch parameter value, resolved from the task context by get_or_default."""
        value = get_or_default(self.ctx, CSSettingsPatch)
        return value

    @property
    def cs_settings(self):
        """Get CS settings

        :return: CS settings
        :rtype: str
        """
        if 'cs_settings' not in self.ctx:
            self.ctx['cs_settings'] = dbutils.get_cs_settings(self, self.cs_settings_archive_res_id, self.cs_settings_patch_res_id, self.settings_spec)
        return self.ctx['cs_settings']

    @property
    def exec_common_yt_oneshots(self):
        """Raw context value stored under the ExecCommonYtOneshots name (`ctx.get` without a default)."""
        ctx_key = ExecCommonYtOneshots.name
        return self.ctx.get(ctx_key)

    @property
    def sampling_parameters_hash(self):
        """Hash of the sampling-related parameters, computed by get_sampling_parameters_hash.

        Resource parameters are read straight from the context with 0 as the
        fallback; the remaining ones go through get_or_default.
        """
        from sandbox.projects.yabs.qa.bases.sample_tables.parameters import get_sampling_parameters_hash

        query_template_res = self.ctx.get(SamplingQueryTemplateResourceParameter.name, 0)
        sampling_tables_res = self.ctx.get(SamplingTablesResourceParameter.name, 0)
        min_rows_count = get_or_default(self.ctx, MinRowsCountForSamplingTables)
        sampling_keys = get_or_default(self.ctx, KeysForSamplingTables)
        black_list_res = self.ctx.get(SamplingTablesBlackListResourceParameter.name, 0)
        find_tables = get_or_default(self.ctx, FindTablesForSampling)

        return get_sampling_parameters_hash(
            query_template_res,
            sampling_tables_res,
            min_rows_count,
            sampling_keys,
            black_list_res,
            find_tables,
        )

    def get_db_date(self):
        """Return the 'date' attribute of the MySQL archive contents resource.

        :raises TaskError: when the resource has no 'date' attribute
        """
        archive_res_id = self.mysql_archive_contents
        resource_info = common.rest.Client().resource[archive_res_id].read()
        db_date = resource_info['attributes'].get('date')
        if db_date is not None:
            return db_date
        raise TaskError("Resource {} has no 'date' attribute.".format(archive_res_id))

    def get_yabscs(self):
        """Sync and unpack BS_RELEASE_YT"""
        try:
            return self._yabscs_path
        except AttributeError:
            yabscs_res_id = self.ctx[BSReleaseYT.name]
            self._yabscs_path = dbutils.get_yabscs(self, yabscs_res_id)
            return self._yabscs_path

    def get_yabscs_base_revision(self):
        """Return an integer revision read from the BS_RELEASE_YT resource attributes.

        For the 'trunk' branch the 'revision' attribute is used, otherwise 'base_revision'.
        """
        bs_release_res_id = self.ctx[BSReleaseYT.name]
        resource_info = common.rest.Client().resource[bs_release_res_id].read()
        attributes = resource_info['attributes']
        if attributes['branch'] == 'trunk':
            revision = attributes['revision']
        else:
            revision = attributes['base_revision']
        return int(revision)

    def can_basever_be_valid(self):
        try:
            rev = self.get_yabscs_base_revision()
        except (KeyError, ValueError):
            logging.info("Failed to get valid revision from attributes of BS_RELEASE_YT")
            return False
        return rev >= 3182896

    def get_yt_token(self):
        """Return whatever dbutils.get_yabscs_yt_token yields for this task."""
        yt_token = dbutils.get_yabscs_yt_token(self)
        return yt_token

    def get_glued_mysql_archive_contents(self):
        try:
            return self._glued_mysql_archive_contents
        except AttributeError:
            self._glued_mysql_archive_contents = dbutils.GluedMySQLArchiveContents(self.sync_resource, self.mysql_archive_contents)
            return self._glued_mysql_archive_contents

    def process_archive_contents(self, db_tags, run_switcher, run_cs_import, dry_run=False, need_switcher_bases_restore=False):
        """
        Check MYSQL_ARCHIVE_CONTENTS resource specified by MySQLArchiveContents parameter against needed tables for mkdb and yabs-switcher.
        Start YABS_SERVER_GET_SQL_ARCHIVE if some tables are missing,
        otherwise return TableProviders and linear model dumps

        :param db_tags: base tags forwarded to yt_bases.iter_db_tables and to the restore request
        :param run_switcher: when true, tables listed by _iter_switcher_tables are checked as well
        :param run_cs_import: forwarded to yt_bases.iter_db_tables
        :param dry_run: forwarded to _find_archived_tables; when true the method returns None
        :param need_switcher_bases_restore: not used by this method
        :return: None for a dry run, otherwise (set of table providers, contents.lm_dumps)
        """
        archive = self.get_glued_mysql_archive_contents()

        yabscs_path = self.get_yabscs()
        needed_db_tables = yt_bases.iter_db_tables(yabscs_path, db_tags, run_cs_import)
        providers, absent_db_tables = self._find_archived_tables(archive.tables, needed_db_tables, dry_run)
        if absent_db_tables:
            logging.info("Missing tables: %s", '\n'.join('.'.join(key) for key in absent_db_tables))

        oneshot = self.get_oneshot_config()
        logging.info("Restore more tables")
        oneshot_providers, absent_oneshot_tables = self._find_archived_tables(archive.tables, oneshot.iter_tables(), dry_run)

        # Maps 'instance.db.table' to None for oneshot tables and to 'bsdb' for switcher tables.
        extra_restore = dict(('.'.join(key), None) for key in absent_oneshot_tables)
        logging.info(extra_restore)
        providers.update(oneshot_providers)

        if run_switcher:  # FIXME move switcher-related logic out of this class!
            switcher_providers, absent_switcher_tables = self._find_archived_tables(archive.tables, _iter_switcher_tables(), dry_run)
            for key in absent_switcher_tables:
                extra_restore['.'.join(key)] = 'bsdb'
            providers.update(switcher_providers)

        if absent_db_tables or extra_restore:
            # FIXME where is db_missing??
            self._get_absent_tables(db_tags, self.mysql_archive_contents, extra_restore)

        return None if dry_run else (providers, archive.lm_dumps)

    def get_oneshot_config(self):
        """Override this to apply nontrivial oneshot

        The base implementation builds an OneshotConfig with no tags, tables or hosts and an empty query.
        """
        yabscs_path = self.get_yabscs()
        return OneshotConfig(yabscs_path, base_tags=[], query='', tables=[], hosts=[])

    def report_operations(self):
        """Post a task info message linking to the YT operations tagged with this task id.

        The time window starts now and lasts for the context's 'kill_timeout'
        (5 hours when absent). Runs inside the `report_operations` memoize stage -
        presumably so the message is posted only once per task (TODO confirm).
        """
        with self.memoize_stage.report_operations:
            window_start = int(time.time())
            window_end = window_start + self.ctx.get('kill_timeout', 5 * 3600)
            task_filter = '"task_id"="{}"'.format(self.id)
            filter_link = yt_utils.get_operations_filter_link('hahn', task_filter, window_start, window_end)
            hyperlink = general.html_hyperlink(filter_link, 'operations')
            self.set_info('Show operations launched by this task: {}'.format(hyperlink), do_escape=False)

    def on_execute(self):
        raise NotImplementedError("Class is abstarct")

    def _find_archived_tables(self, known_tables, needed_tables_iter, dry_run=False):
        missing = set()

        def _find_table_resource(needed_key):
            needed_key_str = '.'.join(needed_key)
            try:
                val = known_tables[needed_key_str]
            except KeyError:
                missing.add(needed_key)
                logging.info("Missing table: %s", needed_key_str)
                return None
            if str(val).startswith('='):
                linked_key_str = val[1:]
                linked_key = tuple(linked_key_str.split('.'))
                if key[2] != linked_key[2]:
                    raise RuntimeError("Linked table name mismatch in MYSQL_ARCHIVE_CONTENTS: %s %s" % (needed_key_str, val))

                logging.info("Table link: %s -> %s", needed_key_str, linked_key_str)
                return _find_table_resource(linked_key)
            else:
                logging.info("Found table: %s -> %s", needed_key_str, val)
                return int(val)

        table_providers = dict()
        for key in needed_tables_iter:
            inst, db, table = key
            try:
                res_id = _find_table_resource(key)
            except Exception:
                logging.exception("Failed _find_table_resource for table key %s", key)
                raise

            tp = TableProvider(self, inst, db, table, res_id)

            location_key = (db, table)
            if not dry_run and location_key in table_providers:
                raise RuntimeError("Table providers with duplicate output location: %s and %s" % (table_providers[location_key], tp))

            table_providers[location_key] = tp

        return set(table_providers.itervalues()), missing

    def _get_absent_tables(self, db_tags, archive_head_res_id, additional_restore_description):
        """Request restoration of absent tables through a YABS_SERVER_GET_SQL_ARCHIVE task.

        Writes additional_restore_description to a JSON file, publishes it as a
        YABS_MYSQL_RESTORE_DESCRIPTION resource, finds or starts the archive task
        and waits for it. The method is not expected to return normally: its last
        statement raises RuntimeError.

        :param db_tags: base tags, passed to the task space-joined
        :param archive_head_res_id: resource id passed as the previous archive contents
        :param additional_restore_description: dict describing extra tables to restore
        """
        if additional_restore_description:
            description_md5 = hashlib.md5(json.dumps(additional_restore_description)).hexdigest()
        else:
            description_md5 = None

        # We do not filter out existing tables: GetSQLArchive will do it on its own
        task_attrs = {
            BinDbList.name: ' '.join(db_tags),
            BackupDate.name: self.get_db_date(),
            OptionsParameter.name: Option.NO_LM_DUMPS,
            PrevArchiveContents.name: archive_head_res_id,
            BsReleaseYT.name: self.ctx[BSReleaseYT.name],
            AdditionalRestoreDescriptionMD5.name: description_md5,
        }

        # Publish the description as a resource so it can be handed to the task.
        description_path = 'additional_restore_description.json'
        with open(description_path, 'w') as description_file:
            json.dump(additional_restore_description, description_file, indent=4)
        description_resource = self.create_resource(
            description="Additional restore description for " + self.descr,
            resource_path=description_path,
            resource_type=YABS_MYSQL_RESTORE_DESCRIPTION,
            arch='any'
        )
        description_res_id = description_resource.id
        self.mark_resource_ready(description_res_id)

        archive_task = self._find_or_start_task(
            'YABS_SERVER_GET_SQL_ARCHIVE',
            task_attrs,
            start_params={AdditionalRestoreDescription.name: description_res_id},
        )
        self.check_and_wait_tasks([archive_task.id])
        raise RuntimeError("We should never get here")


class OneshotConfig(object):
    """Oneshot description: base tags, query, affected tables and MySQL instances.

    Table names may be given bare or prefixed with 'yabsdb.'; the prefix is dropped.
    Instances are derived from hosts and base tags:
    - hosts resolving to exactly {'yabs'}: the instances of the base tags;
    - 'yabs' together with other hosts: RuntimeError;
    - otherwise: the intersection of host instances and tag instances.
    """

    def __init__(self, yabscs_path, base_tags, query, tables, hosts):
        self.base_tags = base_tags
        self.query = query
        self.tables = {self._strip_db_prefix(name) for name in tables}

        hosts_instances = {get_server_instance(host) for host in hosts}
        tag_instances = {dbutils.get_mysql_tag(yabscs_path, tag) for tag in base_tags}

        if hosts_instances == {'yabs'}:
            self.instances = tag_instances
        elif 'yabs' in hosts_instances:
            raise RuntimeError("Do not know how to handle oneshot for multiple hosts that include bsdb")
        else:
            self.instances = hosts_instances & tag_instances

        logging.info("Oneshot base tags: %s", sorted(self.base_tags))
        logging.info("Oneshot tables: %s", sorted(self.tables))
        logging.info("Oneshot query: %s", self.query)
        logging.info("Oneshot hosts_instances: %s", sorted(hosts_instances))
        logging.info("Oneshot tag_instances: %s", sorted(tag_instances))
        logging.info("Oneshot instances: %s", sorted(self.instances))

    @staticmethod
    def _strip_db_prefix(table):
        """Return the bare table name; only an optional 'yabsdb.' prefix is accepted."""
        parts = table.split('.')
        if len(parts) == 1:
            return table
        if len(parts) == 2 and parts[0] == 'yabsdb':
            return parts[1]
        raise TaskFailure("Bad table in oneshot parameters: %s" % table)

    def iter_tables(self):
        """Yield an (instance, 'yabsdb', table) key for every instance/table pair, logging each one."""
        for instance in self.instances:
            for table in self.tables:
                table_key = (instance, 'yabsdb', table)
                logging.info("Oneshot table: %s", '.'.join(table_key))
                yield table_key


def _iter_switcher_tables():
    """Yield ('yabs', db, table) keys for the tables listed in experiment-switcher's used_tables.json.

    The file is exported from Arcadia trunk into 'tb_list.json' in the current
    directory. It is iterated as a mapping from a db name to an iterable of
    table names. Being a generator, nothing is exported until the first item
    is requested.
    """
    local_path = 'tb_list.json'
    Arcadia.export(
        'arcadia:/arc/trunk/arcadia/yabs/utils/experiment-switcher/used_tables.json',
        local_path
    )
    # Close the file deterministically instead of leaving the handle to the garbage collector.
    with open(local_path) as tables_file:
        used_tables = json.load(tables_file)
    for db in used_tables:
        for table in used_tables[db]:
            yield 'yabs', db, table
