import logging
import os
import datetime
import json
import time
import os.path

from sandbox import common
from sandbox import sdk2
from sandbox.sdk2.helpers import subprocess as sp
import sandbox.sandboxsdk.environments as sdk_environments
from sandbox.projects.mssngr.runtime.resources import MssngrYdbTool


BACKUP_WAIT_SEC = 4 * 3600
DEFAULT_MAX_BACKUP_COUNT = 7
WAIT_SLOT_SEC = 600
BACKUP_PARAM_RETRIES = 100


class Env:
    """Per-environment YDB/YT settings and YT path helpers used by the backup task.

    Environments are identified by the ``ENV_*`` strings. Every ``get_*``
    lookup raises ``ValueError`` for an environment it does not know.
    """
    DEFAULT_YT_PROXY = 'hahn'

    ENV_TESTING = 'testing'
    ENV_ALPHA = 'alpha'
    ENV_PRODUCTION = 'production'

    # Environment -> name of the YT sub-folder holding its backups.
    _BACKUP_ENV = {
        ENV_TESTING: 'testing',
        ENV_ALPHA: 'alpha',
        ENV_PRODUCTION: 'prod'
    }

    # Environment -> upper-case environment name.
    _YDB_ENV = {
        ENV_TESTING: 'TESTING',
        ENV_ALPHA: 'ALPHA',
        ENV_PRODUCTION: 'PRODUCTION'
    }

    # Environment -> YDB endpoint (host:port).
    _YDB_ENDPOINT = {
        ENV_TESTING: 'ydb-ru-prestable.yandex.net:2135',
        ENV_ALPHA: 'ydb-ru-prestable.yandex.net:2135',
        ENV_PRODUCTION: 'ydb-ru.yandex.net:2135'
    }

    # Environment -> YDB database path.
    _YDB_DATABASE = {
        ENV_TESTING: '/ru-prestable/messenger/testing/backend',
        ENV_ALPHA: '/ru-prestable/messenger/alpha/backend',
        ENV_PRODUCTION: '/ru/messenger/prod/backend'
    }

    @staticmethod
    def __get_by_env(env, env_dict):
        """Return ``env_dict[env]``.

        Raises:
            ValueError: if ``env`` is not a key of ``env_dict``.
        """
        if env not in env_dict:
            # Raising a plain string (as this code used to) is rejected by
            # Python with a TypeError that hides the actual problem.
            raise ValueError('Unknown env %s' % (env,))
        return env_dict[env]

    @staticmethod
    def get_ydb_endpoint(env):
        """Return the YDB endpoint for ``env``."""
        return Env.__get_by_env(env, Env._YDB_ENDPOINT)

    @staticmethod
    def get_ydb_database(env):
        """Return the YDB database path for ``env``."""
        return Env.__get_by_env(env, Env._YDB_DATABASE)

    @staticmethod
    def get_backup_env(env):
        """Return the backup sub-folder name for ``env``."""
        return Env.__get_by_env(env, Env._BACKUP_ENV)

    @staticmethod
    def get_ydb_env(env):
        """Return the upper-case environment name for ``env``."""
        return Env.__get_by_env(env, Env._YDB_ENV)

    @staticmethod
    def get_yt_backup_path(env, path):
        """Return the absolute YT folder ``path/<backup env name>`` for ``env``."""
        return Env.yt_abs_path(os.path.join(path, Env.get_backup_env(env)))

    @staticmethod
    def yt_abs_path(path):
        """Normalize ``path`` into an absolute YT path starting with ``//``.

        Leading slashes are stripped before normalizing. POSIX ``normpath``
        preserves exactly two leading slashes, so without stripping, '/x' would
        become '///x' after the extra '/' is prepended.
        """
        return '/' + os.path.normpath('/' + path.lstrip('/'))

    @staticmethod
    def make_yt_name(now):
        """Format the datetime ``now`` as a YT node name, e.g. '2020-01-31T12:00:00'."""
        return now.strftime('%Y-%m-%dT%H:%M:%S')

    @staticmethod
    def remove_old_data(yt_client, yt_path, max_count, recursive):
        """Keep the ``max_count`` lexicographically greatest children of ``yt_path``.

        All other children are removed. Names produced by ``make_yt_name`` sort
        chronologically, so this keeps the newest ones.
        """
        items = sorted(yt_client.list(Env.yt_abs_path(yt_path), absolute=True), reverse=True)
        for item in items[max_count:]:
            logging.info('Remove old data %s', item)
            yt_client.remove(item, recursive=recursive)

#
# Not all YDB types can be converted transparently to YT types,
# so only the convertible ones are listed here.
#
# The full list of YDB data types is documented at: https://ydb.yandex-team.ru/docs/concepts/datatypes/
#


# YDB primitive type_id -> YT column type, for the YDB types with a direct YT
# counterpart. Any type missing here makes ydb_to_yt_type raise.
TYPE_MAP = {
    'INT64': 'int64',
    'INT32': 'int32',
    'INT16': 'int16',
    'INT8': 'int8',
    'UINT64': 'uint64',
    'UINT32': 'uint32',
    'UINT16': 'uint16',
    'UINT8': 'uint8',
    # NOTE(review): lossy -- a YT double cannot represent every decimal value
    # exactly; confirm this mapping is intended.
    'DECIMAL': 'double',
    'FLOAT': 'double',
    'DOUBLE': 'double',
    'BOOL': 'boolean',
    'STRING': 'string',
    'UTF8': 'utf8',
}


def ydb_to_yt_type(type_name):
    """Map a YDB primitive ``type_id`` (e.g. 'UINT64') to a YT column type.

    Raises:
        Exception: if ``type_name`` has no entry in TYPE_MAP. The message
            names the offending type so the failing column can be identified.
    """
    res = TYPE_MAP.get(type_name)
    if res is None:
        raise Exception('Not supported YDB type: %s' % (type_name,))

    return res


def _unwrap_ydb_type_id(col_type):
    """Return the primitive ``type_id`` of a YDB column type, unwrapping one Optional level."""
    # Nullable columns are described as {"optional_type": {"item": {...}}};
    # non-nullable ones carry "type_id" directly.
    if 'optional_type' in col_type:
        col_type = col_type['optional_type']['item']
    return col_type['type_id']


def get_yt_scheme_from_ydb_scheme(ydb_scheme):
    """Build a YT schema (list of {"name", "type"} dicts) from a YDB table description.

    ``ydb_scheme`` is the parsed JSON of ``ydb scheme describe``. Its
    ``"columns"`` list is converted in order.
    """
    return [
        {
            "name": col["name"],
            "type": ydb_to_yt_type(_unwrap_ydb_type_id(col["type"])),
        }
        for col in ydb_scheme["columns"]
    ]


class BackupMssngrYdbToYt(sdk2.Task):
    """ A task, which backups YDB database.

    Tables are exported with the ydb CLI ``export yt`` command into a
    timestamped folder under ``yt_path``. Afterwards only the newest
    ``backup_count`` folders under ``yt_path`` are kept.
    """

    class Requirements(sdk2.Task.Requirements):
        environments = [
            sdk_environments.PipEnvironment("yandex-yt"),
        ]
        cores = 1
        ram = 2048
        disk_space = 2048

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Task.Parameters):
        # common parameters
        kill_timeout = BACKUP_WAIT_SEC

        with sdk2.parameters.Group("Input parameters") as in_block:
            endpoint = sdk2.parameters.String("Database endpoint (path:port)",
                                              default=Env.get_ydb_endpoint(Env.ENV_TESTING), required=True)
            db_path = sdk2.parameters.String("Database path", default=Env.get_ydb_database(Env.ENV_TESTING),
                                             required=True)
            db_tables = sdk2.parameters.List("Tables to backup (optional)", default=[])
            ydb_client = sdk2.parameters.Resource("YDB client to use", resource_type=MssngrYdbTool, required=True)
            yt_token = sdk2.parameters.Vault("Vault secret name with YT token",
                                             default='MSSNGR:yt-token-hahn',
                                             required=True)
            ydb_token = sdk2.parameters.Vault("Vault secret name with YDB token",
                                              default='MSSNGR:ydb-auth-token',
                                              required=True)
            backup_count = sdk2.parameters.Integer("Max backup count", default=DEFAULT_MAX_BACKUP_COUNT, required=True)

        with sdk2.parameters.Group("Output parameters") as out_block:
            yt_proxy = sdk2.parameters.String("YT proxy", default=Env.DEFAULT_YT_PROXY, required=True)
            yt_path = sdk2.parameters.String('YT path where backup will be stored', required=True)

    def on_create(self):
        # No export operation exists yet; on_terminate treats a falsy id as
        # "nothing to cancel".
        self.Context.operation_id = None

    def get_ydb_cli_path(self):
        """Return the local path of the ``ydb_client`` resource as a string."""
        ydb_cli_resource = sdk2.ResourceData(self.Parameters.ydb_client)
        return str(ydb_cli_resource.path)

    def get_ydb_table_list(self, ydb_cli):
        """List the tables under ``db_path`` by parsing ``ydb scheme ls -l`` output.

        A line is taken as a table when its second whitespace-separated token
        is 'table'. The name is then read from its second-to-last token.
        """
        args = self.get_ydb_common_args(ydb_cli)
        args.extend(['scheme', 'ls', '-l', str(self.Parameters.db_path)])

        with sdk2.helpers.ProcessLog(self, logger="ydb_cli_table_list") as pl:
            output = sp.check_output(
                args,
                env={
                    'YDB_TOKEN': self.Parameters.ydb_token.data()
                },
                stderr=pl.stderr,
                timeout=60
            )
            logging.info('ydb cli list tables result: %s', str(output))

        result = []
        for line in str(output).strip().split('\n'):
            if 'table' not in line:
                continue
            line_split = line.split()
            # A line can mention "table" but be too short to parse; skip it
            # instead of failing with IndexError.
            if len(line_split) >= 2 and line_split[1] == 'table':
                result.append(line_split[-2])

        return result

    def make_yt_folder(self, folder):
        """Create the YT node ``folder`` via ``yt_client.mkdir``."""
        self.yt_client.mkdir(folder)

    def get_env(self):
        """Return the subprocess environment for ydb CLI calls (YT and YDB tokens)."""
        return {
            'YT_TOKEN': self.Parameters.yt_token.data(),
            'YDB_TOKEN': self.Parameters.ydb_token.data()
        }

    def get_ydb_common_args(self, ydb_cli):
        """Return the ydb CLI argv prefix selecting the configured endpoint and database."""
        return [
            ydb_cli,
            '--endpoint', str(self.Parameters.endpoint),
            '--database', str(self.Parameters.db_path),
        ]

    def _resolve_table_paths(self, tbl, yt_dst_path):
        """Return ``(ydb_src_path, yt_dst_table_path)`` for one entry of the table list.

        A relative name is joined to ``db_path`` on the YDB side and keeps its
        relative path under ``yt_dst_path``. An absolute name (starting with '/')
        is used as-is on the YDB side, and only its last path component is used
        on the YT side.
        """
        # e.g. src: /Root/mssngr/router/chat_states_and_history_enc
        #      dst: <yt_dst_path>/chat_states_and_history_enc
        if tbl.startswith('/'):
            return tbl, os.path.join(yt_dst_path, tbl.split('/')[-1])
        return os.path.join(self.Parameters.db_path, tbl), os.path.join(yt_dst_path, tbl)

    def create_dst_tables(self, ydb_cli, yt_dst_path, table_names):
        """Create one strict-schema YT table per YDB table, with the schema derived from YDB."""
        for tbl in table_names:
            table_path, full_path = self._resolve_table_paths(tbl, yt_dst_path)
            ydb_sch = self.get_ydb_table_scheme(ydb_cli, table_path)
            schema = get_yt_scheme_from_ydb_scheme(ydb_sch)

            logging.info('Creating dest table %s with schema: %s', full_path, schema)
            self.yt_client.create_table(full_path, attributes={"schema": schema, "strict": True})

    def start_backup(self, ydb_cli, yt_dst_path, table_names, description):
        """Create the destination tables and start a ``ydb export yt`` operation.

        Returns:
            The operation id reported by the CLI.

        Raises:
            common.errors.TaskError: if the CLI exits non-zero or reports a
                status other than SUCCESS.
        """
        self.create_dst_tables(ydb_cli, yt_dst_path, table_names)

        args = self.get_ydb_common_args(ydb_cli)
        args.extend([
            'export', 'yt',
            '--format', 'proto-json-base64',
            '--proxy', str(self.Parameters.yt_proxy),
            '--description', description,
            '--retries', str(BACKUP_PARAM_RETRIES),
        ])

        for tbl in table_names:
            src_path, dst_path = self._resolve_table_paths(tbl, yt_dst_path)
            args.append('--item')
            args.append("src=%(src)s,dst=%(dst)s" % {
                'src': src_path,
                'dst': dst_path,
            })

        with sdk2.helpers.ProcessLog(self, logger="ydb_cli_start_backup") as pl:
            p = sp.Popen(
                args,
                env=self.get_env(),
                stdout=sp.PIPE,
                stderr=pl.stderr,
            )
            out, _ = p.communicate()
            if p.returncode != 0:
                raise common.errors.TaskError("Could not start backup process")

            logging.info('Got json: %s', out)
            json_data = json.loads(out)
            if json_data['status'] != u'SUCCESS':
                raise common.errors.TaskError("Something went wrong. See ydb_cli log")

            return json_data['id']

    def call_ydb_cli_with_json_out(self, args, cli_call):
        """Run the ydb CLI with ``args`` and parse its stdout as JSON.

        Returns:
            The parsed JSON when the CLI exits with 0, otherwise None.

        Raises:
            ValueError: if the CLI exited with 0 but its stdout is not JSON.
        """
        # The timestamp presumably keeps log names distinct across retries.
        with sdk2.helpers.ProcessLog(self, logger="ydb_cli_%s.%s" % (cli_call, time.time())) as pl:
            p = sp.Popen(
                args,
                env=self.get_env(),
                stdout=sp.PIPE,
                stderr=pl.stderr,
            )
            out, err = p.communicate()
            if p.returncode == 0:
                try:
                    return json.loads(out)
                except ValueError:
                    logging.error('Not json response: %s', out)
                    raise

            logging.error('ydb cli failed with message:')
            logging.error(out)
            logging.error(err)

            return None

    def get_ydb_operation(self, ydb_cli, operation_id, retry_count=3):
        """Fetch the JSON description of YDB operation ``operation_id``.

        Makes up to ``retry_count`` attempts, waiting WAIT_SLOT_SEC between
        failed attempts (but not after the last one).

        Raises:
            common.errors.TaskError: if every attempt failed.
        """
        args = self.get_ydb_common_args(ydb_cli)
        args.extend([
            'operation', 'get', operation_id,
            '--format', 'proto-json-base64',
        ])

        for attempt in range(retry_count):
            data = self.call_ydb_cli_with_json_out(args, 'get_operation')
            if data is not None:
                return data
            if attempt + 1 < retry_count:
                time.sleep(WAIT_SLOT_SEC)

        raise common.errors.TaskError("Could not get operation status")

    def _run_ydb_cli_best_effort(self, args, logger_name, action, attempts=3):
        """Run the ydb CLI up to ``attempts`` times, stopping at the first success.

        Failures are only logged. Cancel/forget are cleanup steps and must not
        mask the error that triggered them.

        Returns:
            True if some attempt exited with 0, otherwise False.
        """
        with sdk2.helpers.ProcessLog(self, logger=logger_name) as pl:
            for i in range(1, attempts + 1):
                p = sp.Popen(
                    args,
                    env=self.get_env(),
                    stdout=pl.stdout,
                    stderr=pl.stderr,
                )
                out, err = p.communicate()
                if p.returncode == 0:
                    return True

                logging.error('ydb cli ({}) {} failed with message:'.format(i, action))
                logging.error(out)
                logging.error(err)

                # Linear backoff between attempts; no pointless wait after the last one.
                if i < attempts:
                    time.sleep(5 * i)

        logging.error('ydb cli %s gave up after %d attempts', action, attempts)
        return False

    def cancel_operation(self, ydb_cli, operation_id):
        """Best-effort cancel of YDB operation ``operation_id`` (up to 3 attempts)."""
        args = self.get_ydb_common_args(ydb_cli)
        args.extend([
            'operation', 'cancel', operation_id,
            '--format', 'proto-json-base64',
        ])
        self._run_ydb_cli_best_effort(args, "ydb_cli_cancel_operation", 'cancel')

    def forget_operation(self, ydb_cli, operation_id):
        """Best-effort ``operation forget`` for ``operation_id`` (up to 3 attempts)."""
        args = self.get_ydb_common_args(ydb_cli)
        args.extend([
            'operation', 'forget', operation_id
        ])
        self._run_ydb_cli_best_effort(args, "ydb_cli_forget_operation", 'forget')

    def get_ydb_table_scheme(self, ydb_cli, table_path, retry_count=3):
        """Return the parsed ``ydb scheme describe`` JSON for ``table_path``.

        Makes up to ``retry_count`` attempts, waiting WAIT_SLOT_SEC between
        failed ones. An empty or None result counts as a failure.

        Raises:
            common.errors.TaskError: if every attempt failed.
        """
        logging.info('ydb_cli: %s', ydb_cli)
        args = self.get_ydb_common_args(ydb_cli)
        args.extend([
            'scheme', 'describe', table_path,
            '--format', 'proto-json-base64',
        ])

        for attempt in range(retry_count):
            table_scheme = self.call_ydb_cli_with_json_out(args, 'describe_table')
            if table_scheme:
                return table_scheme
            if attempt + 1 < retry_count:
                time.sleep(WAIT_SLOT_SEC)

        raise common.errors.TaskError("Could not get ydb table scheme")

    def wait_backup_complete(self, ydb_cli, operation_id):
        """Poll ``operation_id`` every WAIT_SLOT_SEC until it reports ``ready``.

        Raises:
            common.errors.TaskError: if waiting exceeds ``kill_timeout``. The
                operation is cancelled first.
        """
        start_time = time.time()
        while True:
            # NOTE(review): this deadline equals the task's kill_timeout and is
            # measured from the start of waiting, not from task start, so the
            # task is likely killed before this branch can fire. Consider a
            # smaller deadline.
            if time.time() - start_time > self.Parameters.kill_timeout:
                logging.info('Time is out. Cancelling operation')
                self.cancel_operation(ydb_cli, operation_id)
                raise common.errors.TaskError("Backup process is timed out. Operation %s has been canceled" % operation_id)

            op_data = self.get_ydb_operation(ydb_cli, operation_id)
            if op_data.get('ready'):
                logging.info('Operation is ready now')
                break

            logging.info('Operation is still running')
            time.sleep(WAIT_SLOT_SEC)

    def check_backup(self, yt_dst_folder, ydb_table_list):
        """Check that the YT backup folder has exactly one node per requested table.

        Raises:
            common.errors.TaskError: if the folder is empty or the counts differ.
        """
        yt_file_list = self.yt_client.list(yt_dst_folder, absolute=True)

        if not yt_file_list:
            raise common.errors.TaskError('Zero number of created backup files')

        if len(ydb_table_list) != len(yt_file_list):
            logging.error('YDB tables: %d, YT backup nodes: %d', len(ydb_table_list), len(yt_file_list))
            raise common.errors.TaskError('Number of created backup files differs from number of YDB tables')

    def do_backup(self, ydb_cli, yt_dst_path, table_names, description):
        """Start the export, wait for it, and always forget the operation afterwards."""
        self.Context.operation_id = self.start_backup(ydb_cli, yt_dst_path, table_names, description)
        try:
            self.wait_backup_complete(ydb_cli, self.Context.operation_id)
        finally:
            op = self.Context.operation_id
            # Clear the id first so on_terminate does not act on a finished operation.
            self.Context.operation_id = None
            self.forget_operation(ydb_cli, op)

    def on_terminate(self):
        """Cancel and forget the in-flight export operation, if there is one."""
        logging.info("Terminating")
        if self.Context.operation_id:
            ydb_cli_path = self.get_ydb_cli_path()
            self.cancel_operation(ydb_cli_path, self.Context.operation_id)
            self.forget_operation(ydb_cli_path, self.Context.operation_id)

    def on_execute(self):
        """Export the tables to a new timestamped YT folder, verify, and prune old backups."""
        import yt.wrapper as yt

        self.yt_client = yt.YtClient(self.Parameters.yt_proxy, self.Parameters.yt_token.data())
        ydb_cli_path = self.get_ydb_cli_path()

        now = datetime.datetime.now()

        backup_description = 'database backup created on %s' % now.strftime('%Y-%m-%d %H:%M:%S')

        if self.Parameters.db_tables:
            table_names = self.Parameters.db_tables
        else:
            table_names = self.get_ydb_table_list(ydb_cli_path)

        if not table_names:
            raise common.errors.TaskError('Empty list of tables in database. Nothing to do')

        # The original call passed table_names without a %s placeholder, so
        # logging failed to format the record and the list was never shown.
        logging.info("Table list to backup: %s", table_names)

        yt_dst_folder = os.path.join(str(self.Parameters.yt_path), Env.make_yt_name(now))
        self.make_yt_folder(yt_dst_folder)
        self.do_backup(ydb_cli_path, yt_dst_folder, table_names, backup_description)
        self.check_backup(yt_dst_folder, table_names)
        Env.remove_old_data(self.yt_client, str(self.Parameters.yt_path), self.Parameters.backup_count, True)

        logging.info("Done")
