# -*- coding: utf-8 -*-

import logging
import os
import re
import json
import datetime
import pytz
import time

from sandbox import sdk2

from sandbox.common import config, utils

from sandbox.common.types.client import Tag

import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt

from sandbox.common.fs import get_unique_file_name

from sandbox.sandboxsdk.parameters import SandboxBoolParameter, Container
from sandbox.sandboxsdk.environments import SvnEnvironment, PipEnvironment
from sandbox.sandboxsdk.process import run_process
from sandbox.sandboxsdk.svn import Arcadia
from sandbox.sandboxsdk.errors import SandboxSubprocessError
from sandbox.sdk2.paths import get_logs_folder

from sandbox.projects import resource_types
from sandbox.projects.common import utils2
from sandbox.projects.common.utils import get_or_default

from sandbox.projects.common.yabs.server.db.task.mysql import install_yabs_mysql, synpack_tables, un_myisampack_if_needed, prepare_perl_oneshot_environment

import sandbox.projects.common.yabs.server.db.utils as dbutils

from sandbox.projects.common.yabs.server.db.task.basegen import (
    BinBasesTask,
    FETCH_OUTPUTS_KEY,
    FETCH_PREFIX_KEY,
    GENERATION_ID_KEY,
    Options,
)

from sandbox.projects.common.yabs.server.db import yt_bases
from sandbox.projects.common.yabs.server.db.task.cs import YtPool

from sandbox.projects.common.yabs.server.components import hugepage_warmup
from sandbox.projects.yabs.qa.utils import yt_utils
from sandbox.projects.yabs.qa.utils.general import get_yt_path_html_hyperlink

from sandbox.common.errors import TaskFailure, TaskError

from sandbox.projects.common.yabs.server.tracing import TRACE_WRITER_FACTORY
from sandbox.projects.yabs.sandbox_task_tracing import trace_calls, trace_entry_point
from sandbox.projects.yabs.sandbox_task_tracing.wrappers.subprocess import check_output, STDOUT, CalledProcessError


# Key name for host scores. Not referenced in the visible part of this module;
# presumably consumed by other modules or tasks -- TODO confirm before removing.
HOST_SCORES_KEY = 'host_scores'

# Key name for SSD check failures. Not referenced in the visible part of this
# module either -- TODO confirm whether it is still needed.
_SSD_CHECK_FAILS_KEY = 'ssd_check_fails'

# Name of the YT node attribute in which the fetch outputs are stored on the
# fetch prefix node, and from which they are read back when a node is reused.
FETCH_OUTPUTS_ATTR = '__' + FETCH_OUTPUTS_KEY

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# Boolean task parameter (enabled by default). YabsServerBasesFetch.on_execute
# reads it from the task context by ``name``: when it is truthy and no
# oneshot_path is set, the task first tries to reuse existing fetch results.
class ReuseFetchResults(SandboxBoolParameter):
    name = 'reuse_fetch_results'
    description = 'Reuse fetch result'
    default_value = True


# Container parameter override: not required and without a default value.
# NOTE: the class reuses the name of the imported sandboxsdk ``Container`` it
# derives from, so from this point on ``Container`` refers to this subclass.
class Container(Container):
    required = False
    # default_value = 2094988479  # https://sandbox.yandex-team.ru/resource/2094988479 YABS_MYSQL_LXC_IMAGE 08.04.2021
    default_value = None
    description = 'Container'


# sdk2-style parameter set; it is appended to the task's legacy
# ``input_parameters`` tuple via ``tuple(Parameters)``.
class Parameters(sdk2.Parameters):
    # When true (the default) the task installs YABS MySQL itself in
    # _prepare_data; otherwise it only starts the preinstalled mysql.yabs service.
    install_yabs_mysql = sdk2.parameters.Bool(
        label='Install yabs mysql in task itself',
        description='Should be abandoned in favour of prepared LXC containter usage',
        # default=False,
        default=True,
        required=True,
    )


class YabsServerBasesFetch(BinBasesTask):
    # Task of type YABS_SERVER_BASES_FETCH.
    #
    # on_execute() obtains a fetch prefix and fetch outputs in YT, either by
    # reusing a matching node found under yt_bases.FETCH_ROOT or by preparing
    # MySQL tables, applying oneshots / experiment-switcher and running
    # yt_bases.fetch(). Both values are stored in the task context under
    # FETCH_PREFIX_KEY and FETCH_OUTPUTS_KEY.
    #
    # NOTE: documented with comments rather than a class docstring so that the
    # class's __doc__ stays exactly as before.

    type = 'YABS_SERVER_BASES_FETCH'

    required_ram = 40 * 1024

    execution_space = 640 * 1024

    max_restarts = 10

    @utils.singleton_classproperty
    def client_tags(self):
        """Return the client tags the task requires.

        Installations from the NONLOCAL group require
        ``LINUX_PRECISE & SSD & YABS``; any other installation gets ``GENERIC``.
        """
        return (
            Tag.LINUX_PRECISE & Tag.SSD & Tag.YABS
            if config.Registry().common.installation in ctm.Installation.Group.NONLOCAL else
            Tag.GENERIC
        )

    input_parameters = (ReuseFetchResults, Container) + tuple(Parameters) + BinBasesTask.input_parameters

    environment = (
        SvnEnvironment(),
        PipEnvironment('yandex-yt', use_wheel=True),
        PipEnvironment('MySQL-python', '1.2.5', use_wheel=True),
    ) + BinBasesTask.environment

    privileged = True

    @property
    def cs_fetch_ver(self):
        """Return the CS fetch version, computing it at most once per task.

        The value comes from ``dbutils.get_cs_fetch_ver`` and is cached in
        ``self.ctx['cs_fetch_ver']``. If the computation raises, the error is
        logged and ``None`` is cached instead.
        """
        if 'cs_fetch_ver' not in self.ctx:
            bs_release_yt_dir = self.get_yabscs()
            try:
                cs_fetch_ver = dbutils.get_cs_fetch_ver(bs_release_yt_dir)
            except Exception as e:
                logger.exception(e)
                cs_fetch_ver = None
            self.ctx['cs_fetch_ver'] = cs_fetch_ver

        return self.ctx['cs_fetch_ver']

    @trace_calls
    def get_oneshots(self):
        """Return the parsed contents of ``oneshots.json`` from the ``get_yabscs()`` directory."""
        yabscs_path = self.get_yabscs()
        oneshot_json_path = os.path.join(yabscs_path, 'oneshots.json')
        with open(oneshot_json_path) as f:
            return json.load(f)

    def on_enqueue(self):
        """Acquire a semaphore (weight 3) named after the YT pool the task uses.

        When no pool is configured, or the configured pool equals
        ``yt_bases.YT_POOL``, the semaphore ``self.YT_POOL_FETCH_SEMAPHORE`` is
        used; otherwise the semaphore name is the pool name itself.
        """
        yt_pool = get_or_default(self.ctx, YtPool)
        if not yt_pool or yt_pool == yt_bases.YT_POOL:
            yt_pool_semaphore = self.YT_POOL_FETCH_SEMAPHORE
        else:
            yt_pool_semaphore = yt_pool

        self.semaphores(ctt.Semaphores(
            acquires=[
                ctt.Semaphores.Acquire(name=yt_pool_semaphore, weight=3)
            ]
        ))

    def get_generation_id(self):
        """Return the generation id stored in the task context, or 0 if absent."""
        return self.ctx.get(GENERATION_ID_KEY, 0)

    @trace_entry_point(writer_factory=TRACE_WRITER_FACTORY)
    def on_execute(self):
        """Produce fetch results, reusing an existing YT node when allowed and possible.

        Reuse is attempted only when the ``reuse_fetch_results`` context value
        is truthy and no ``oneshot_path`` is set. If nothing suitable is found,
        the data is prepared, tables are modified and the fetch is executed.
        On reuse, the TTL of the reused node is set to
        ``yt_bases.FETCH_OUTPUT_TTL`` and ``ctx['fetch_reused']`` is set.
        """
        from yt.wrapper import YtClient

        self.check_and_wait_tasks()

        yt_token = self.get_yt_token()  # Fail fast if no token available
        yt_client = YtClient(proxy=yt_bases.YT_PROXY, token=yt_token)

        if config.Registry().client.sandbox_user:
            hugepage_warmup.check_required_ram(self)

        dbs_to_generate = self.get_db_list()

        try_reuse_fetch_results = all((
            self.ctx.get(ReuseFetchResults.name),
            not self.oneshot_path
        ))

        fetch_prefix, fetch_output = self._try_reuse_fetch_results(dbs_to_generate, yt_client) if try_reuse_fetch_results else (None, None)

        if not (fetch_prefix and fetch_output):
            logging.info('Will not reuse fetch results')
            self._prepare_data(dbs_to_generate)
            self._modify_tables(yt_token)
            fetch_prefix, fetch_output = self._execute_yabscs(yt_token, yt_client, dbs_to_generate)
        else:
            logger.debug("Reuse fetch results from \"%s\"", fetch_prefix)
            yt_utils.set_yt_node_ttl(fetch_prefix, yt_bases.FETCH_OUTPUT_TTL, yt_client)
            logger.debug("Set TTL for node %s: %s", fetch_prefix, yt_bases.FETCH_OUTPUT_TTL)
            self.ctx["fetch_reused"] = True

        self.ctx[FETCH_PREFIX_KEY] = fetch_prefix
        self.ctx[FETCH_OUTPUTS_KEY] = fetch_output

    @trace_calls
    def _prepare_data(self, dbs_to_generate):
        """Process archive contents, bring up YABS MySQL and sync the tables.

        Depending on the ``install_yabs_mysql`` context value (default true),
        YABS MySQL is either installed for the instances of the table providers
        or the existing ``mysql.yabs`` service is just started.

        :raises RuntimeError: if starting the ``mysql.yabs`` service fails.
        """
        table_providers, lm_dumps = self.process_archive_contents(
            dbs_to_generate, bool(self.switcher_options), run_cs_import=False, need_switcher_bases_restore=bool(self.need_switcher_bases_restore)
        )

        if self.ctx.get('install_yabs_mysql', True):
            logging.debug('Started installing YABS mysql')
            start_time = time.time()
            mysql_instances = frozenset(tp.inst for tp in table_providers)
            install_yabs_mysql(switcher=bool(self.switcher_options), mysql_instances=mysql_instances)
            # logging formats lazily with the % operator, so the placeholder must
            # be %-style; a str.format placeholder makes the record fail to format.
            logging.debug('Finished installing YABS mysql. Time elapsed: %0.3f', time.time() - start_time)
        else:
            logging.debug('Starting YABS mysql')
            try:
                run_process(['/etc/init.d/mysql.yabs', 'yabs', 'start'], log_prefix='mysql.yabs_start')
            except Exception as err:
                raise RuntimeError(err)
            logging.debug('Started YABS mysql')

        synpack_tables(table_providers)

    @trace_calls
    def _modify_tables(self, yt_token):
        """Apply all table modifications that must precede the fetch.

        In order: oneshots from ``oneshots.json``, the inline shell oneshot
        (``self.oneshot``), the oneshot file (``self.oneshot_path``) and one
        experiment-switcher run per entry of ``self.switcher_options``.
        """

        _apply_oneshots_json(
            self.get_yabscs(),
            self.get_db_date(),
            self.get_glued_mysql_archive_contents(),
            self._get_real_instance(),
        )

        if self.oneshot:
            run_process(['/bin/bash', '-c', self.oneshot], log_prefix='oneshot')

        if self.oneshot_path:
            apply_oneshot(
                self.oneshot_path,
                oneshot_config=self.get_oneshot_config(),
                basegen_instance=self._get_real_instance(),
                glued_mysql_archive_contents=self.get_glued_mysql_archive_contents(),
                store_oneshot_log=self._store_oneshot_log,
            )

        for switcher_opt in self.switcher_options:
            with self.current_action('Running experiment-switcher'):
                _run_experiment_switcher(
                    self.switcher_revision,
                    patch=self.ctx.get('arcadia_patch'),
                    add_args=self.switcher_args,
                    network_only=(switcher_opt == Options.RUN_SWITCHER_NETWORK),
                    yt_token=yt_token
                )

    @trace_calls
    @yt_bases.yabscs_failures_retried
    def _execute_yabscs(self, yt_token, yt_client, dbs_to_generate):
        """Run ``yt_bases.fetch`` and stamp the resulting node with lookup attributes.

        The attributes written here mirror the filter used by
        ``_try_reuse_fetch_results``. The node is marked reusable only when
        neither an ``arcadia_patch`` context value nor ``oneshot_path`` is set.

        :return: ``(fetch_prefix, fetch_output)`` as returned by ``yt_bases.fetch``.
        """
        fetch_prefix, fetch_output = yt_bases.fetch(
            yt_token,
            self.get_yabscs(),
            dbs_to_generate,
            self.client_info['fqdn'],
            str(self.id),
            self.get_db_date(),
            settings_spec=self.cs_settings,
            task_id=self.id,
            yt_pool=get_or_default(self.ctx, YtPool),
        )
        yt_bases.fill_node_attributes(
            path=fetch_prefix,
            yt_client=yt_client,
            node_attributes={
                yt_bases.BIN_DBS_NAMES_ATTR: dbs_to_generate,
                yt_bases.MYSQL_ARCHIVE_ATTR: self.mysql_archive_contents,
                yt_bases.INPUT_ARCHIVE_ATTR: self.input_spec_res_id,
                yt_bases.IS_REUSABLE_ATTR: not (self.ctx.get('arcadia_patch') or self.oneshot_path),
                FETCH_OUTPUTS_ATTR: fetch_output,
                yt_bases.SETTINGS_SPEC_MD5_ATTR: dbutils.calc_combined_settings_md5(self.cs_settings_archive_res_id,
                                                                                    self.cs_settings_patch_res_id,
                                                                                    self.settings_spec),
                yt_bases.SETTINGS_SPEC_ATTR: self.settings_spec,
                yt_bases.CS_FETCH_VER_ATTR: self.cs_fetch_ver,
            }
        )

        return fetch_prefix, fetch_output

    @trace_calls
    def _store_oneshot_log(self, oneshot_log_data):
        """Publish oneshot output as a ready ``OTHER_RESOURCE`` resource.

        The data is written to a uniquely named ``oneshot.log`` file in the
        current directory, a link to the resource is added to the task info and
        the resource id is stored in ``ctx['oneshot_log_id']``.
        """
        out_path = get_unique_file_name(os.path.abspath('./'), 'oneshot.log')

        with open(out_path, 'w') as out:
            out.write(oneshot_log_data)

        oneshot_log_res = self.create_resource(
            description='oneshot output',
            resource_path=out_path,
            resource_type=resource_types.OTHER_RESOURCE,
        )
        self.mark_resource_ready(oneshot_log_res.id)
        self.set_info(
            utils2.resource_redirect_link(oneshot_log_res.id, 'Oneshot execution log'),
            do_escape=False
        )

        self.ctx['oneshot_log_id'] = oneshot_log_res.id
        logging.info('oneshot_log resource created')

    @trace_calls
    def _get_real_instance(self):
        """Return the single MySQL instance shared by all requested db tags.

        :return: the instance, or ``None`` when the db list is empty.
        :raises RuntimeError: if the tags resolve to more than one instance.
        """
        yabscs_path = self.get_yabscs()
        tag_instances = set(dbutils.get_mysql_tag(yabscs_path, tag) for tag in self.get_db_list())
        if len(tag_instances) > 1:
            raise RuntimeError("Cannot generate bases from multiple MySql instances (%s)" % ', '.join(tag_instances))
        return tag_instances.pop() if tag_instances else None

    @trace_calls
    def _try_reuse_fetch_results(self, dbs_to_generate, yt_client):
        """Look for an existing fetch node whose attributes match this task's inputs.

        On success the node path is remembered in ``self.node_to_reuse`` and a
        link to it is added to the task info.

        :return: ``(node_path, fetch_outputs)`` when a node is found,
            otherwise ``(None, None)``.
        """
        filter_attributes = {
            yt_bases.BIN_DBS_NAMES_ATTR: dbs_to_generate,
            yt_bases.MYSQL_ARCHIVE_ATTR: self.mysql_archive_contents,
            yt_bases.INPUT_ARCHIVE_ATTR: self.input_spec_res_id,
            yt_bases.IS_REUSABLE_ATTR: True,
            yt_bases.SETTINGS_SPEC_MD5_ATTR: dbutils.calc_combined_settings_md5(self.cs_settings_archive_res_id, self.cs_settings_patch_res_id, self.settings_spec),
            yt_bases.CS_FETCH_VER_ATTR: self.cs_fetch_ver
        }
        logging.info('Search fetch results to reuse. Filter nodes by attributes: %s',
                     json.dumps(filter_attributes, indent=2))
        node_to_reuse = yt_bases.find_node_to_reuse(yt_client, yt_bases.FETCH_ROOT, check_task_status=False, filter_attributes=filter_attributes, add_attributes=[FETCH_OUTPUTS_ATTR])
        if node_to_reuse:
            self.node_to_reuse = node_to_reuse['$value']
            logger.info('Found fetch results to reuse: node %s with attributes %s',
                        self.node_to_reuse, node_to_reuse['$attributes'])
            self.set_info('Reuse fetch results from node {}'
                          .format(get_yt_path_html_hyperlink(proxy=yt_bases.YT_PROXY, path=self.node_to_reuse)),
                          do_escape=False)
            return self.node_to_reuse, node_to_reuse['$attributes'][FETCH_OUTPUTS_ATTR]
        else:
            logging.info('Not found fetch results to reuse')

        return None, None


@trace_calls
def apply_oneshot(oneshot_path, oneshot_config, basegen_instance, glued_mysql_archive_contents, store_oneshot_log):
    """Apply a single oneshot, dispatching on the extension of ``oneshot_path``.

    ``pl`` and ``py`` oneshots are executed as perl / python scripts. ``sql``
    oneshots are applied only when the oneshot config names an instance and
    ``_is_oneshot_needed`` confirms that the tables it touches come from that
    instance.

    :param oneshot_path: path of the oneshot; only its extension is used here.
    :param oneshot_config: object providing ``instances``, ``tables`` and ``query``.
    :param basegen_instance: MySQL instance the bases are generated from.
    :param glued_mysql_archive_contents: archive contents used to resolve table sources.
    :param store_oneshot_log: callable receiving the oneshot output.
    :raises TaskError: if the oneshot targets more than one instance.
    :raises Exception: if the oneshot type is not supported.
    """
    oneshot_type = oneshot_path.split('.')[-1]

    if len(oneshot_config.instances) > 1:
        # Wrap in a 1-tuple: a bare tuple on the right of % would be unpacked
        # as several format arguments and raise TypeError instead of TaskError.
        raise TaskError("Cannot apply oneshot to multiple instances (%s)" % (oneshot_config.instances,))

    un_myisampack_if_needed(oneshot_config.tables)

    if oneshot_type == 'pl':
        prepare_perl_oneshot_environment(instance=basegen_instance)
        _apply_script_oneshot('perl', oneshot_config.query, store_oneshot_log)
    elif oneshot_type == 'py':
        _apply_script_oneshot('python', oneshot_config.query, store_oneshot_log)
    elif oneshot_type == 'sql':
        if oneshot_config.instances:
            oneshot_instance = next(iter(oneshot_config.instances))
            # Keyword arguments keep the two instances from being swapped:
            # the link chain must be followed from the basegen instance and the
            # resolved source compared with the oneshot instance, exactly as
            # _apply_oneshots_json does.
            oneshot_needed = _is_oneshot_needed(
                glued_mysql_archive_contents,
                basegen_instance=basegen_instance,
                oneshot_instance=oneshot_instance,
                modified_tables=set(oneshot_config.tables),
            )
            if oneshot_needed:
                _apply_sql_oneshot(oneshot_config.query, store_oneshot_log)
    else:
        raise Exception('{}-oneshots not implemented'.format(oneshot_type))


def _get_oneshot_datetime_bound(oneshot):
    """Return the datetime bound declared by a oneshots.json entry, if any.

    The bound is read from ``apply-if-backup-is-older-than`` and, when that key
    is missing or ``None``, from the deprecated ``datetime`` key. The string is
    parsed by ``_to_date`` using the ``%Y/%m/%d %H:%M`` format.

    :return: the parsed bound, or ``None`` when the entry declares none.
    """
    raw_bound = oneshot.get('apply-if-backup-is-older-than')
    if raw_bound is None:
        # XXX: deprecated key
        raw_bound = oneshot.get('datetime')

    if raw_bound is None:
        return None
    return _to_date(raw_bound, "%Y/%m/%d %H:%M")


def _apply_oneshots_json(yabscs_path, db_date, glued_mysql_archive_contents, basegen_instance):
    """Apply the SQL oneshots from ``oneshots.json`` that the database backup still needs.

    An entry is skipped when it is the README placeholder, when it is the one
    known-broken query, when the backup date is not older than the entry's
    datetime bound, or when ``_is_oneshot_needed`` reports it is not needed.

    :param yabscs_path: directory containing ``oneshots.json``.
    :param db_date: backup date; ``str(db_date)`` must match ``%Y%m%d``.
    :raises TaskFailure: if the file cannot be parsed, an entry lacks a
        required key, or an entry's datetime bound cannot be parsed.
    """
    backup_date = _to_date(str(db_date), "%Y%m%d")

    oneshots_json_path = os.path.join(yabscs_path, 'oneshots.json')
    with open(oneshots_json_path) as oneshots_file:
        try:
            entries = json.loads(oneshots_file.read())
        except Exception as parse_error:
            raise TaskFailure('Can\'t parse oneshots json file : %s' % parse_error)

    for entry in entries:
        # skip dummy "readme" entry
        if '!!!README!!!' in entry:
            continue

        try:
            oneshot_instance = entry['instance']
            sql = entry['SQL']
            modified_tables = set(entry['tables'])
        except KeyError as missing_key:
            raise TaskFailure("Missing key %s in oneshots.json" % missing_key)

        # XXX remove this after yabs-server release with good oneshots.json
        if sql == "alter table PageGroupExperiment change `HitCostID` `HitCostParamID` int(11) unsigned NOT NULL DEFAULT '0'":
            logging.warning("SKIPPING infamous broken query form oneshots.json: %s", sql)
            continue

        try:
            datetime_bound = _get_oneshot_datetime_bound(entry)
        except Exception as bound_error:
            raise TaskFailure("Can't parse oneshot datetime: %s" % bound_error)

        # database backup is fresh enough not to need this oneshot
        backup_is_fresh = datetime_bound is not None and backup_date >= datetime_bound
        if backup_is_fresh:
            continue

        if not _is_oneshot_needed(glued_mysql_archive_contents, basegen_instance, oneshot_instance, modified_tables):
            continue

        un_myisampack_if_needed(modified_tables)
        _apply_sql_oneshot(sql)


def _to_date(date, date_format):
    """Parse ``date`` with ``date_format`` and return it as an aware Europe/Moscow datetime."""
    moscow_tz = pytz.timezone('Europe/Moscow')
    naive_datetime = datetime.datetime.strptime(date, date_format)
    return moscow_tz.localize(naive_datetime)


def _get_source_instance(glued_mysql_archive_contents, destination_instance, database, table):
    """Resolve the instance that really holds ``table``, following table links.

    ``glued_mysql_archive_contents.tables`` maps ``instance.database.table``
    keys either to a resource id (anything ``int()`` accepts) or to a link of
    the form ``=instance.database.table``. Links are followed starting from
    ``destination_instance`` until a resource id is reached.

    :return: the instance whose entry is a resource id, or ``None`` when some
        key of the chain is missing from the mapping.
    :raises TaskError: if a value is neither a resource id nor a link.
    :raises RuntimeError: if a link points to a table with a different name,
        or if the links form a cycle.
    """
    instance = destination_instance

    known_tables = glued_mysql_archive_contents.tables
    # Keys already followed; revisiting one means the links loop forever.
    visited_keys = set()

    while True:
        # Follow chain of table links (these reflect replica)
        key = '.'.join((instance, database, table))
        if key in visited_keys:
            raise RuntimeError("Bad MYSQL_ARCHIVE_CONTENTS: cyclic chain of table links at %s" % key)
        visited_keys.add(key)

        try:
            value = known_tables[key]
        except KeyError:
            logging.info("Source of %s.%s.%s not found (%s missing from MYSQL_ARCHIVE_CONTENTS)", instance, database, table, key)
            return None

        try:
            int(value)
        except (TypeError, ValueError):
            if not str(value).startswith('='):
                raise TaskError("Invalid value in MYSQL_ARCHIVE_CONTENTS for key %s: %s" % (key, value))
        else:
            # This is a real resource id
            return instance

        # Strip the leading '=' and move on to the link target.
        instance, database, real_table = value[1:].split('.')
        if real_table != table:
            raise RuntimeError("Bad MYSQL_ARCHIVE_CONTENTS: %s points to table with another name: %s" % (key, value))


def _is_oneshot_needed(
    glued_mysql_archive_contents,
    basegen_instance,
    oneshot_instance,
    modified_tables,
):
    """Tell whether a oneshot touches a table that originates from its instance.

    The tables declared as modified are intersected with the tables present in
    the local ``yabsdb`` database. The oneshot is needed when, for at least one
    of them, the source instance resolved from ``basegen_instance`` equals
    ``oneshot_instance``. The verdict is logged.

    :param modified_tables: set of table names the oneshot declares it modifies.
    :return: ``True`` if the oneshot should be applied, ``False`` otherwise.
    """
    import MySQLdb

    # FIXME works only with yabsdb databases
    db = 'yabsdb'

    show_tables_rows = _run_query(MySQLdb, _YABS_MYSQL_SOCKET, 'SHOW TABLES;', db=db)
    existing_tables = set(row[0] for row in show_tables_rows)
    touched_tables = existing_tables & modified_tables

    needed = False
    for touched_table in touched_tables:
        source_instance = _get_source_instance(glued_mysql_archive_contents, basegen_instance, db, touched_table)
        if source_instance == oneshot_instance:
            needed = True
            break

    negation = '' if needed else 'NOT '
    logging.info(
        "Oneshot for instance %s and declared modified tables %s is %sneded",
        oneshot_instance,
        ', '.join(modified_tables),
        negation
    )
    return needed


@trace_calls
def _apply_sql_oneshot(query, store_oneshot_log=None):
    """Run an SQL oneshot against the local ``yabsdb`` database via the ``mysql`` client.

    The query, followed by ``FLUSH TABLES``, is written to a uniquely named
    file in the logs folder and fed to ``mysql`` on stdin. The client output
    (empty if the client never ran) is always logged and, when
    ``store_oneshot_log`` is given, passed to it, including on failure.

    :param query: SQL text to execute.
    :param store_oneshot_log: optional callable receiving the client output.
    :raises TaskFailure: if writing the query file or running the client fails.
    """
    # FIXME works only with yabsdb databases
    db = 'yabsdb'

    logging.info("Applying oneshot\n%s", query)
    query_file = get_unique_file_name(get_logs_folder(), 'oneshot')
    # Bound up front so the ``finally`` clause can always report something,
    # even when the failure happens before the client produces any output.
    result = ''
    try:
        # Write file because of big query
        with open(query_file, 'w') as f:
            f.write(query + ';\nFLUSH TABLES;')
        try:
            result = check_output(
                'mysql -u root {db} --socket {socket} < {query_file}'.format(db=db, socket=_YABS_MYSQL_SOCKET, query_file=query_file),
                stderr=STDOUT,
                shell=True
            )
        except CalledProcessError as e:
            result = e.output or ''
            raise

    except Exception as e:
        raise TaskFailure('Query %s failed: %s' % (query, e))

    finally:
        if store_oneshot_log is not None:
            store_oneshot_log(result)
        logging.info('Result of oneshot:\n%s', result)

    logging.info("Done applying oneshot")


@trace_calls
def _apply_script_oneshot(runner, file_contents, store_oneshot_log):
    """Execute a script oneshot with ``perl`` or ``python`` and publish its output.

    The script text is written to a file named ``oneshot`` in the current
    directory and run with ``runner``. Its combined stdout/stderr is passed to
    ``store_oneshot_log`` and logged whether or not the script succeeded.

    :param runner: interpreter to use, ``'perl'`` or ``'python'``.
    :param file_contents: text of the script.
    :param store_oneshot_log: callable receiving the script output.
    :raises Exception: if ``runner`` is not supported.
    :raises TaskFailure: if the script exits with a non-zero status.
    """
    if runner not in ['perl', 'python']:
        raise Exception('{}-oneshots not implemented'.format(runner))
    with open('oneshot', 'w') as f:
        f.write(file_contents)

    # Keep the error in an explicitly bound name: the ``as`` target of an
    # ``except`` clause is not available after the clause on Python 3.
    failure = None
    try:
        result = check_output([runner, 'oneshot'], stderr=STDOUT)
    except CalledProcessError as e:
        failure = e
        result = e.output

    store_oneshot_log(result)

    logging.info('Result of oneshot:\n%s', result)

    if failure is not None:
        raise TaskFailure('{}-script {} failed with {}'.format(runner, file_contents, failure))


@trace_calls
def _run_experiment_switcher(revision, patch=None, add_args=None, network_only=False, yt_token=''):
    """Export experiment-switcher from Arcadia and run it, retrying once on failure.

    The tool is exported at ``revision`` into a fresh
    ``yabs/utils/experiment-switcher`` directory under the current directory,
    optionally patched, and run with ``YT_TOKEN`` set in its environment.

    :param revision: Arcadia revision to export.
    :param patch: optional Arcadia patch to apply after the export.
    :param add_args: extra arguments placed before the default ones; a default
        argument already present here is not repeated. If the first argument
        ends with ``.pl`` it is used as the script name instead of
        ``experiment-switcher.pl``.
    :param network_only: use the ``test network`` default arguments.
    :param yt_token: value for the ``YT_TOKEN`` environment variable.
    :raises SandboxSubprocessError: if the script fails on every attempt.
    """
    # Function-scope import: shutil is not among the module-level imports.
    import shutil

    add_args = list(add_args or [])
    es_dir = os.path.abspath('yabs/utils/experiment-switcher')
    if os.path.exists(es_dir):
        # The directory is populated by a previous export when the switcher is
        # run more than once per task; os.rmdir only removes empty directories.
        shutil.rmtree(es_dir)
    os.makedirs(es_dir)

    Arcadia.export('arcadia:/arc/trunk/arcadia/yabs/utils/experiment-switcher', es_dir, revision=revision)
    arcadia_path = re.search('^(.*)/yabs/utils/experiment-switcher', es_dir).group(1)
    if patch is not None:
        Arcadia.apply_patch(arcadia_path, patch, es_dir)

    env = os.environ.copy()
    env['YT_TOKEN'] = yt_token
    default_args = ['test', 'network', '--tested'] if network_only else ['--new-ctr-prediction', '--tested']  # del --refresh
    # A list comprehension (unlike filter) yields a list on every Python
    # version, so it can be concatenated and popped from below.
    args = add_args + [arg for arg in default_args if arg not in add_args]
    if args and args[0].endswith('.pl'):
        switcher_script = args.pop(0)
    else:
        switcher_script = 'experiment-switcher.pl'

    logging.info('Experiment-switcher args: %s', ' '.join(args))
    # sometimes experiment-switcher.pl temporary fails. restart should help
    num_of_tries = 2
    for attempt in range(num_of_tries):
        try:
            run_process(
                [os.path.join(es_dir, switcher_script)] + list(args),
                log_prefix='experiment-switcher',
                work_dir=es_dir,
                environment=env,
            )
        except SandboxSubprocessError:
            if attempt == num_of_tries - 1:
                # Bare raise keeps the original traceback.
                raise
            logging.warning('Experiment-switcher failed on attempt %d of %d, retrying', attempt + 1, num_of_tries)
        else:
            break


# Unix socket used by this module both for MySQLdb connections and for the
# ``mysql`` command-line client.
_YABS_MYSQL_SOCKET = '/var/run/mysqld.yabs/mysqld.sock'


@trace_calls
def _run_query(mysqldb_module, socket, query, db='yabsdb'):
    """Execute ``query`` as ``root`` over a unix socket and return all rows as a list.

    The cursor and the connection are closed even when the query fails.

    :param mysqldb_module: module providing ``connect`` (the caller passes ``MySQLdb``).
    :param socket: path of the MySQL unix socket.
    :param query: SQL text to execute.
    :param db: database to connect to.
    :return: list of the rows returned by ``cursor.fetchall()``.
    """
    connection = mysqldb_module.connect(unix_socket=socket, user='root', db=db)
    try:
        cursor = connection.cursor()
        try:
            cursor.execute(query)
            return list(cursor.fetchall())
        finally:
            cursor.close()
    finally:
        connection.close()
