# -*- coding: utf-8 -*-

import json
import time
import re
import logging

from sandbox import sdk2
from sandbox.projects.kikimr.resources import YdbCliBinary
from sandbox.sandboxsdk import environments
from sandbox.sdk2.helpers import ProcessLog, subprocess
from sandbox.projects.maps.common.juggler_alerts import TaskJugglerReportWithParameters


# Module-level logger used by the task for its own diagnostic messages.
logger = logging.getLogger(__name__)

# Pause (passed to time.sleep, i.e. seconds) between two status polls of a
# running YDB operation.
OPERATION_POLLING_INTERVAL = 60
# Pause (passed to time.sleep, i.e. seconds) before a failed ydb CLI call is
# retried.
OPERATION_RETRIES_INTERVAL = 60
# Name of the child of the YT backups directory that receives the export
# while the backup is still in progress.
WORK_DIR = 'current'
# Name of the link inside the YT backups directory that is pointed at the
# most recently finished backup.
LATEST_DIR = 'latest'
# Value passed as `sort_by` when the exported YT tables are sorted.
SORT_BY = 'id'


class BackupUacYdb(TaskJugglerReportWithParameters):
    """
    Backup UAC YDB to YT.

    The task exports the selected YDB tables into ``<yt_directory>/current``
    with ``ydb export yt``, sorts every exported table by ``SORT_BY``, renames
    the directory to a timestamped name, points the ``latest`` link at it and
    finally removes the oldest backups so that at most ``max_backup_count``
    backup directories are kept.
    """

    class Requirements(sdk2.Task.Requirements):
        # yt.wrapper is imported lazily in on_execute, so it is installed via pip.
        environments = [
            environments.PipEnvironment('yandex-yt'),
        ]

    class Parameters(TaskJugglerReportWithParameters.Parameters):
        kill_timeout = 5 * 60 * 60

        with sdk2.parameters.Group('YDB parameters') as ydb_parameters:
            ydb_token_vault = sdk2.parameters.String(
                'Token vault',
                default='ydb_token',
                required=True,
            )
            ydb_endpoint = sdk2.parameters.String(
                'Endpoint',
                default='ydb-ru-prestable.yandex.net:2135',
                required=True,
            )
            ydb_database = sdk2.parameters.String(
                'Database',
                default='/ru-prestable/mobileproducts/test/rmp',
                required=True,
            )
            # Tables whose name matches this regexp are backed up (see
            # is_table_included); an empty value selects every table.
            ydb_included_tables = sdk2.parameters.String(
                'Included tables regexp',
                required=False
            )
        with sdk2.parameters.Group('YT parameters') as yt_parameters:
            yt_token_vault = sdk2.parameters.String(
                'Token vault',
                default='yt_token',
                required=True,
            )
            yt_proxy = sdk2.parameters.String(
                'Proxy',
                required=True,
            )
            yt_directory = sdk2.parameters.String(
                'Directory',
                required=True,
            )
            yt_is_typed = sdk2.parameters.Bool(
                'Enable typed tables',
                default=True,
                required=True,
            )
        with sdk2.parameters.Group('Common parameters') as common_parameters:
            release_name = sdk2.parameters.String(
                'Release name',
                required=False,
            )
            max_backup_count = sdk2.parameters.Integer(
                'Backups number to keep',
                default=30,
                required=True,
            )

    def on_execute(self):
        """Run the whole backup: export, sort, publish and rotate old backups."""
        from yt.wrapper import YtClient, ypath_join, OperationsTracker

        yt = YtClient(
            proxy=self.Parameters.yt_proxy,
            token=self.yt_token(),
        )

        backups_dir = self.Parameters.yt_directory
        work_dir = ypath_join(backups_dir, WORK_DIR)
        # Always start from an empty working directory, dropping whatever a
        # previous run may have left behind.
        if yt.exists(work_dir):
            yt.remove(work_dir, recursive=True)
        yt.mkdir(work_dir)

        ydb_cli_path = self.get_ydb_cli()
        self.ydb_export(work_dir, ydb_cli_path)

        with OperationsTracker() as tracker:
            for table in self.ydb_tables_list(ydb_cli_path):
                if self.is_table_included(table):
                    table_path = ypath_join(work_dir, table)
                    # sync=False returns right after the operation is started,
                    # so the tracker really runs the sorts concurrently
                    # instead of receiving operations that already finished.
                    tracker.add(yt.run_sort(table_path, sort_by=SORT_BY, sync=False))
            # Wait while the tracker is still active.
            tracker.wait_all()

        result_dir_name = time.strftime('%Y-%m-%dT%H:%M:%S')
        if self.Parameters.release_name:
            result_dir_name += ' - ' + self.Parameters.release_name

        result_dir = ypath_join(backups_dir, result_dir_name)
        yt.move(work_dir, result_dir)

        latest_dir = ypath_join(backups_dir, LATEST_DIR)
        yt.link(result_dir, latest_dir, force=True)

        nodes = yt.list(backups_dir, sort=True, attributes=['type'])
        for node in self.list_nodes_to_remove(nodes):
            node_path = ypath_join(backups_dir, node)
            logger.info('removing %s', node_path)
            yt.remove(node_path, recursive=True)

    def yt_token(self):
        """Return the YT token stored in the vault item named by ``yt_token_vault``."""
        return sdk2.Vault.data(self.owner, self.Parameters.yt_token_vault)

    def ydb_token(self):
        """Return the YDB token stored in the vault item named by ``ydb_token_vault``."""
        return sdk2.Vault.data(self.owner, self.Parameters.ydb_token_vault)

    def list_nodes_to_remove(self, nodes):
        """
        Select the backup directories that exceed ``max_backup_count``.

        :param nodes: name-sorted listing of the backups directory; every item
            must expose ``attributes['type']``.
        :return: the ``map_node`` items that remain after skipping the last
            ``max_backup_count`` ones of the listing, last-listed first.
            Items of any other type (e.g. the ``latest`` link) are never
            returned.
        """
        directories = [
            node for node in reversed(nodes)
            if node.attributes['type'] == 'map_node'
        ]
        # Clamp at zero: a negative slice start would count from the end of
        # the list instead of meaning "keep nothing".
        keep = max(self.Parameters.max_backup_count, 0)
        return directories[keep:]

    def get_ydb_cli(self):
        """
        Fetch and unpack the stable linux YDB CLI resource.

        :return: path of the ``ydb`` binary inside the current working directory.
        :raises RuntimeError: if no matching resource is found.
        """
        ydb_cli_resource = YdbCliBinary.find(
            attrs=dict(released='stable', platform='linux')
        ).first()
        if ydb_cli_resource is None:
            # Fail with a clear message rather than with an obscure error
            # from ResourceData(None).
            raise RuntimeError('Stable linux YdbCliBinary resource is not found')

        archive_path = str(sdk2.ResourceData(ydb_cli_resource).path)

        with ProcessLog(self, logger='extract_ydb_cli') as pl:
            subprocess.check_call(
                ['tar', '-xzf', archive_path],
                stdout=pl.stdout,
                stderr=pl.stderr,
            )

        return str(sdk2.path.Path.cwd() / 'ydb')

    def run_ydb_process(self, argv, parse_result=True, retries=0):
        """
        Run a ydb CLI command and return its standard output.

        :param argv: full command line, starting with the ydb binary path.
        :param parse_result: when true the output is parsed as JSON and the
            parsed value is returned; otherwise the raw output is returned.
        :param retries: how many times a failed call is repeated, with a pause
            of ``OPERATION_RETRIES_INTERVAL`` before every retry.
        :raises subprocess.CalledProcessError: when the command still fails
            after the last allowed attempt.
        """
        # NOTE: `env` replaces the environment of the child process, so it
        # sees only these two variables.
        env = {
            'YDB_TOKEN': self.ydb_token(),
            'YT_TOKEN': self.yt_token(),
        }
        retries_left = retries
        while True:
            with ProcessLog(self, logger='run ydb process') as pl:
                try:
                    output = subprocess.check_output(
                        argv,
                        env=env,
                        stderr=pl.stderr,
                    )
                    break
                except subprocess.CalledProcessError as e:
                    if retries_left <= 0:
                        raise
                    logger.info('Process failed with {code} {out} {err}'.format(
                        code=e.returncode,
                        out=e.output,
                        # Not every CalledProcessError implementation has
                        # a `stderr` attribute.
                        err=getattr(e, 'stderr', None),
                    ))
            # Sleep outside of the ProcessLog context so the log is not kept
            # open while waiting.
            retries_left -= 1
            time.sleep(OPERATION_RETRIES_INTERVAL)

        if not parse_result:
            return output
        result = json.loads(output)
        logger.info('Get process result: {}'.format(result))
        return result

    def is_table_included(self, table):
        """
        Tell whether ``table`` has to be backed up.

        :return: True when the ``ydb_included_tables`` regexp is found in the
            table name, or when the parameter is empty / not set.
        """
        pattern = self.Parameters.ydb_included_tables
        if not pattern:
            # The parameter is optional: no filter means "every table".
            return True
        return re.search(pattern, table) is not None

    def ydb_export(self, work_dir, ydb_cli_path=None):
        """
        Export the selected tables to ``work_dir`` and wait for completion.

        :param work_dir: YT directory that receives the exported tables.
        :param ydb_cli_path: path of the ydb binary; fetched via
            ``get_ydb_cli`` when omitted.
        """
        if not ydb_cli_path:
            ydb_cli_path = self.get_ydb_cli()
        result = self.ydb_start_export(ydb_cli_path, work_dir)
        self.ydb_wait_operation(ydb_cli_path, result['id'])

    def ydb_wait_operation(self, ydb_cli_path, operation_id):
        """
        Poll a YDB operation until it is ready, then forget it.

        The operation is forgotten even when polling fails.

        :raises RuntimeError: when the finished operation reports a status
            other than ``SUCCESS``.
        """
        try:
            while True:
                result = self.ydb_get_operation(ydb_cli_path, operation_id)
                if result.get('ready', False):
                    break
                time.sleep(OPERATION_POLLING_INTERVAL)
            # "ready" only means the operation is over; without this check a
            # failed export would be published as a good backup.
            status = result.get('status')
            if status is not None and status != 'SUCCESS':
                raise RuntimeError(
                    'YDB operation {id} finished with status {status}: {issues}'.format(
                        id=operation_id,
                        status=status,
                        issues=result.get('issues'),
                    )
                )
        finally:
            self.ydb_forget(ydb_cli_path, operation_id)

    def ydb_start_export(self, ydb_cli_path, work_dir):
        """
        Start ``ydb export yt`` for every included table.

        :param work_dir: YT directory used as the destination of every table.
        :return: parsed JSON output of the command (``ydb_export`` reads its
            ``id`` key).
        :raises subprocess.CalledProcessError: when the command fails; an
            operation id found in the failed output is forgotten first.
        """
        argv = self._ydb_base_argv(ydb_cli_path) + [
            'export', 'yt',
            '--proxy', self.Parameters.yt_proxy,
            '--format', 'proto-json-base64',
        ]
        if self.Parameters.yt_is_typed:
            argv.append('--use-type-v3')
        for table in self.ydb_tables_list(ydb_cli_path):
            if self.is_table_included(table):
                item = 'source={ydb_db}/{table},destination={yt_dir}/{table}'.format(
                    ydb_db=self.Parameters.ydb_database,
                    yt_dir=work_dir,
                    table=table,
                )
                argv.extend(['--item', item])
        try:
            return self.run_ydb_process(argv)
        except subprocess.CalledProcessError as e:
            # The cleanup lives in its own method so that an error handled
            # there cannot replace the exception re-raised below.
            self._forget_failed_export(ydb_cli_path, e.output)
            raise

    def _forget_failed_export(self, ydb_cli_path, process_output):
        """Best-effort: forget the operation whose id is found in a failed export output."""
        try:
            operation_id = json.loads(process_output).get('id')
        except (TypeError, ValueError, AttributeError):
            # No output, or output that is not a JSON object: nothing to forget.
            return
        if not operation_id:
            return
        try:
            self.ydb_forget(ydb_cli_path, operation_id)
        except subprocess.CalledProcessError:
            logger.exception('Failed to forget operation %s', operation_id)

    def _ydb_base_argv(self, ydb_cli_path):
        """Return the argv prefix (binary, endpoint, database) shared by all ydb calls."""
        return [
            ydb_cli_path,
            '--endpoint', self.Parameters.ydb_endpoint,
            '--database', self.Parameters.ydb_database,
        ]

    def ydb_get_operation(self, ydb_cli_path, operation_id):
        """Return the parsed JSON output of ``ydb operation get`` for ``operation_id``."""
        argv = self._ydb_base_argv(ydb_cli_path) + [
            'operation', 'get', operation_id,
            '--format', 'proto-json-base64',
        ]
        return self.run_ydb_process(argv, retries=15)

    def ydb_forget(self, ydb_cli_path, operation_id):
        """Run ``ydb operation forget`` for ``operation_id``."""
        argv = self._ydb_base_argv(ydb_cli_path) + [
            'operation', 'forget', operation_id,
        ]
        self.run_ydb_process(argv, False, retries=15)

    def ydb_tables_list(self, ydb_cli_path):
        """Return the whitespace-separated names printed by ``ydb scheme ls`` for the database."""
        argv = self._ydb_base_argv(ydb_cli_path) + [
            'scheme',
            'ls',
            self.Parameters.ydb_database
        ]
        output = self.run_ydb_process(argv, False, retries=15)
        return output.split()
