import json
import logging
import time
from datetime import datetime, timedelta
from enum import Enum
from pytz import timezone

from sandbox import sdk2
from sandbox.sandboxsdk import environments
from sandbox.sdk2.helpers import subprocess as sp
from sandbox.projects.kikimr.resources import YdbCliBinary
from sandbox.projects.vins.common.resources import YdbTableRotator
from sandbox.sdk2.helpers import ProcessLog, subprocess

# Module-level logger shared by the task methods below.
logger = logging.getLogger(__name__)

# Attribute values used to look up the YDB CLI resource (see _get_ydb_cli).
STABLE = 'stable'
LINUX = 'linux'

# Format of the timestamp suffix in rotated YDB table names,
# e.g. events_data_2021-10-26T12_57_40Z.
TIMESTAMP_FORMAT = "%Y-%m-%dT%H_%M_%SZ"


class OperationAction(Enum):
    """Actions applicable to a YDB operation; the value is passed to the CLI as the ``operation`` subcommand."""
    cancel = 'cancel'
    forget = 'forget'


class AliceLogsBackupYdb2Yt(sdk2.Task):
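    """
    Back up Alice log tables from YDB to YT.

    On every run the task rotates and drops YDB tables according to the
    "Table Rotation" parameters. Then, only during the accessible hours or
    when the backup is forced, it exports one ``events_data`` table to YT
    and drops that table from YDB.
    """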

    class Requirements(sdk2.Task.Requirements):
        """Execution requirements of the task."""

        # Pip packages installed for the task; yt.wrapper is imported lazily in on_execute.
        environments = [
            environments.PipEnvironment('yandex-yt', version='0.10.8'),
        ]
        # Resource limits requested for the task.
        # NOTE(review): ram/disk_space are presumably in megabytes — confirm against Sandbox docs.
        cores = 1
        ram = 2048
        disk_space = 2048

        # disable all caches by default (make it multislot-aware)
        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Task.Parameters):
        """Input parameters of the task, grouped by the system they configure."""

        kill_timeout = 10 * 60 * 60  # 10 hours

        # Connection and credentials for the source YDB database.
        with sdk2.parameters.Group('YDB'):
            ydb_endpoint = sdk2.parameters.String(
                'YDB endpoint',
                default='ydb-ru.yandex.net:2135',
                required=True
            )
            ydb_database = sdk2.parameters.String(
                'YDB database',
                default='/ru/alice/prod/alicelogs',
                required=True
            )
            ydb_vault_token = sdk2.parameters.String(
                'YDB token secret in Sandbox Vault',
                default='robot-bassist_ydb_token',
                required=True
            )
            ydb_vault_owner = sdk2.parameters.String(
                'YDB token vault owner in Sandbox Vault',
                default='robot-bassist',
                required=True
            )
            # When set, the export operation is also forgotten (not only cancelled)
            # after polling ends in on_execute.
            ydb_forget_operation = sdk2.parameters.Bool(
                'Flag to forget operation or not',
                default=True,
                required=True
            )

        # Connection, credentials and locations for the destination YT cluster.
        with sdk2.parameters.Group('YT'):
            yt_cluster = sdk2.parameters.String(
                'YT cluster',
                default='hahn',
                required=True
            )
            # Directory that holds finished backups; it is also listed to detect
            # tables that were already exported.
            yt_directory = sdk2.parameters.String(
                'YT directory',
                default='//home/alice-setrace',
                required=True
            )
            # Export target; the table is post-processed here and then moved to yt_directory.
            yt_tmp_directory = sdk2.parameters.String(
                'YT tmp directory',
                default='//tmp/robot-bassist',
                required=True
            )
            yt_vault_token = sdk2.parameters.String(
                'YT token secret in Sandbox Vault',
                default='robot-bassist_yt_token',
                required=True
            )
            yt_vault_owner = sdk2.parameters.String(
                'YT token vault owner in Sandbox Vault',
                default='robot-bassist',
                required=True
            )
            # Used to compute the @expiration_time attribute of the result table.
            yt_table_ttl_days = sdk2.parameters.Integer(
                'YT table ttl in days',
                default=5,
                required=True
            )

        # Bypasses the accessible_hours check in on_execute.
        force_backup = sdk2.parameters.Bool(
            'force_backup',
            required=True,
            default=False
        )

        # JSON list of hours; it is compared with the current hour in the Europe/Moscow timezone.
        accessible_hours = sdk2.parameters.String(
            'List of hours when exporting can start',
            default='[0, 1, 2, 3, 4, 5, 6, 7, 8, 22, 23]',
            required=True
        )

        with sdk2.parameters.Group('Table Rotation'):
            # A rotation is performed once the newest table timestamp plus this
            # period is not in the future.
            ydb_rotation_period = sdk2.parameters.Integer(
                'Period of rotations for YDB tables (hours)',
                default=8,
                required=True
            )
            table_rotator_binary = sdk2.parameters.Resource(
                'Binary from alice/rtlog/table_rotator',
                default=2736857462,
                resource_type=YdbTableRotator
            )

            with sdk2.parameters.Group('Events data'):
                rotate_events_data_table = sdk2.parameters.Bool(
                    'Flag to rotate events_data_* tables',
                    default=False,
                    required=True
                )
                # Not read by any method of this task.
                events_data_table_drop_after = sdk2.parameters.Integer(
                    'This parameter is not used, event_data_* table are droped after backup',
                    default=0,
                    required=True
                )

            with sdk2.parameters.Group('Events index'):
                rotate_events_index_table = sdk2.parameters.Bool(
                    'Flag to rotate events_index_data_* tables',
                    default=False,
                    required=True
                )
                # Interpreted as hours (passed to timedelta(hours=...) in _drop_old_tables).
                events_index_table_drop_after = sdk2.parameters.Integer(
                    'events_index_* tables are dropped after this amount of time',
                    default=240,
                    required=True
                )

            with sdk2.parameters.Group('Special events'):
                rotate_special_events_table = sdk2.parameters.Bool(
                    'Flag to rotate special_events_data_* tables',
                    default=False,
                    required=True
                )
                # Interpreted as hours (passed to timedelta(hours=...) in _drop_old_tables).
                special_events_table_drop_after = sdk2.parameters.Integer(
                    'special_events_* tables are dropped after this amount of time',
                    default=240,
                    required=True
                )

    class Context(sdk2.Task.Context):
        """Values computed during execution and shared between the task methods."""

        # Path to the extracted ydb CLI binary; set at the start of on_execute.
        ydb_cli_path = None
        # Id of the export operation started by on_execute.
        operation_id = None
        # Path to the table_rotator binary; set at the start of on_execute.
        ydb_table_rotator_path = None

    def _get_ydb_token(self):
        """Read the YDB token from Sandbox Vault using the configured owner and secret name."""
        params = self.Parameters
        return sdk2.Vault.data(params.ydb_vault_owner, params.ydb_vault_token)

    def _get_yt_token(self):
        """Read the YT token from Sandbox Vault using the configured owner and secret name."""
        params = self.Parameters
        return sdk2.Vault.data(params.yt_vault_owner, params.yt_vault_token)

    def _get_ydb_cli(self):
        """
        Unpack the YDB CLI resource and return the path to the binary.

        The first resource released as stable for linux is taken, its archive
        is extracted with ``tar`` (no target directory is given), and the path
        ``<cwd>/ydb`` is returned.
        """
        lookup_attrs = {'released': STABLE, 'platform': LINUX}
        cli_resource = YdbCliBinary.find(attrs=lookup_attrs).first()
        archive = str(sdk2.ResourceData(cli_resource).path)
        extract_cmd = ['tar', '-xzf', archive]

        with ProcessLog(self, logger='extract_ydb_cli') as process_log:
            subprocess.check_call(extract_cmd, stdout=process_log.stdout, stderr=process_log.stderr)

        binary_path = sdk2.path.Path.cwd() / 'ydb'
        return str(binary_path)

    def _get_ydb_items(self):
        """
        List the names of the items located in the root of the YDB database.

        Runs ``ydb scheme ls <database>`` and splits its output into names.

        :return: list of item names; empty when the command prints nothing
        """
        argv = [
            self.Context.ydb_cli_path,
            '--endpoint',
            self.Parameters.ydb_endpoint,
            '--database',
            self.Parameters.ydb_database,
            'scheme',
            'ls',
            self.Parameters.ydb_database,
        ]

        with ProcessLog(self, logger='ydb_scheme_ls') as pl:
            output = sp.check_output(
                argv,
                env={
                    'YDB_TOKEN': self._get_ydb_token(),
                },
                stderr=pl.stderr,
            )
        logger.info('found YDB items: {}'.format(output))
        # output is a string like 'table1  table2'. Splitting on any run of
        # whitespace (instead of exactly two spaces) keeps a trailing newline
        # or column padding out of the names and never yields empty items.
        return output.split()

    def _get_table_info(self, table):
        """
        Describe a YDB table together with its statistics.

        :param table: table name passed to ``ydb scheme describe``
        :return: the JSON printed by the CLI, parsed into Python objects
        """
        params = self.Parameters
        describe_cmd = [self.Context.ydb_cli_path]
        describe_cmd += ['--endpoint', params.ydb_endpoint]
        describe_cmd += ['--database', params.ydb_database]
        describe_cmd += ['scheme', 'describe', table]
        describe_cmd += ['--format', 'proto-json-base64', '--stats']

        with ProcessLog(self, logger='ydb_scheme_describe') as process_log:
            token_env = {'YDB_TOKEN': self._get_ydb_token()}
            raw_info = sp.check_output(describe_cmd, env=token_env, stderr=process_log.stderr)

        logger.info('got info about table {}: {}'.format(table, raw_info))
        return json.loads(raw_info)

    def _filter_export_tables(self, tables):
        """
        Keep only the tables that were not modified during the last hour.

        :param tables: iterable of YDB table names
        :return: list of names whose ``table_stats.modification_time`` is more
            than one hour older than the current UTC time
        :raises ValueError: if a modification time cannot be parsed
        """
        result = []
        for table in tables:
            info = self._get_table_info(table)
            modified_at = self._parse_ydb_timestamp(info["table_stats"]["modification_time"])
            if modified_at + timedelta(hours=1) < datetime.utcnow():
                result.append(table)
        return result

    @staticmethod
    def _parse_ydb_timestamp(value):
        """
        Parse a ``YYYY-MM-DDTHH:MM:SS[.fraction]Z`` timestamp into a naive datetime.

        Unlike a fixed ``%f`` format this accepts a missing fractional part and
        fractions longer than six digits (extra digits are truncated).

        :raises ValueError: if the value does not match the expected layout
        """
        text = value[:-1] if value.endswith('Z') else value
        whole, _, fraction = text.partition('.')
        parsed = datetime.strptime(whole, "%Y-%m-%dT%H:%M:%S")
        if fraction:
            # Left-justified like %f: ".5" means 500000 microseconds.
            parsed = parsed.replace(microsecond=int(fraction[:6].ljust(6, '0')))
        return parsed

    def _get_table_for_export(self, yt_client):
        """
        Pick the next YDB events data table that is not present in YT yet.

        :param yt_client: YT client used to list ``yt_directory``
        :return: name of the table to export, or None when an item containing
            ``export`` exists in YDB or when every eligible table already has
            a same-named table in YT
        """
        for prefix in ['events_data']:
            logger.info('Looking for table for export with prefix {}'.format(prefix))

            events_data_ydb_tables = []
            for item in self._get_ydb_items():
                if 'export' in item:
                    # NOTE(review): presumably left by an export that is still
                    # in progress — confirm.
                    logger.info('Export directory {} is present in YDB'.format(item))
                    return None
                if item.startswith(prefix):
                    events_data_ydb_tables.append(item)
            events_data_ydb_tables.sort()

            logger.info('Existing events data YDB tables: {}'.format(', '.join(events_data_ydb_tables)))

            try:
                events_data_ydb_tables_for_export = self._filter_export_tables(events_data_ydb_tables)
            except Exception as e:
                # Deliberate best-effort fallback: when table stats cannot be
                # obtained, skip the two newest tables instead of failing.
                # logger.exception keeps the traceback for diagnostics.
                logger.exception("failed to filter tables, error: {}".format(str(e)))
                events_data_ydb_tables_for_export = events_data_ydb_tables[:-2]

            logger.info('Existing events data YDB tables available for export: {}'.format(
                ', '.join(events_data_ydb_tables_for_export)
            ))

            yt_tables = set()
            for yt_item in yt_client.list(self.Parameters.yt_directory):
                table = str(yt_item)
                if table.startswith(prefix):
                    yt_tables.add(table)

            logger.info('Existing YT tables: {}'.format(', '.join(yt_tables)))

            # Tables are sorted by name, so the first missing one is returned.
            for table_name in events_data_ydb_tables_for_export:
                if table_name not in yt_tables:
                    return table_name

        return None

    def _create_yt_table(self, yt_client, table_path):
        """
        Create a YT table at ``table_path`` with the fixed events schema.

        :param yt_client: YT client used to create the table
        :param table_path: path of the table to create
        """
        # (column name, column type) pairs, in schema order.
        columns = (
            ('reqid', 'string'),
            ('activation_id', 'string'),
            ('frame_id', 'uint64'),
            ('event_index', 'uint64'),
            ('req_ts', 'int64'),
            ('instance_id', 'uint64'),
            ('ts', 'uint64'),
            ('event_type', 'string'),
            ('event', 'string'),
            ('event_binary', 'string'),
        )
        table_schema = [
            {'name': column_name, 'type': column_type}
            for column_name, column_type in columns
        ]
        yt_client.create('table', table_path, attributes={'schema': table_schema})

    def _start_ydb_export_to_yt(self, table):
        """
        Start an export of a YDB table into the YT tmp directory.

        :param table: name of the table inside the YDB database
        :return: the JSON printed by ``ydb export yt``, parsed into Python objects
        """
        params = self.Parameters
        export_item = 'source={ydb_db}/{table},destination={yt_dir}/{table}'.format(
            ydb_db=params.ydb_database,
            yt_dir=params.yt_tmp_directory,
            table=table,
        )
        export_cmd = [
            self.Context.ydb_cli_path,
            '--endpoint', params.ydb_endpoint,
            '--database', params.ydb_database,
            'export', 'yt',
            '--proxy', params.yt_cluster,
            '--item', export_item,
            '--format', 'proto-json-base64',
        ]

        with ProcessLog(self, logger='ydb_export_yt') as process_log:
            # The subprocess environment consists of these two tokens only.
            credentials = {
                'YDB_TOKEN': self._get_ydb_token(),
                'YT_TOKEN': self._get_yt_token(),
            }
            raw_result = sp.check_output(export_cmd, env=credentials, stderr=process_log.stderr)

        operation = json.loads(raw_result)
        logger.info('Start operation export result: {}'.format(operation))
        return operation

    def _ydb_operation_get(self, stderr):
        """
        Fetch the current state of the operation stored in ``Context.operation_id``.

        :param stderr: stream that receives the stderr of the CLI process
        :return: the JSON printed by ``ydb operation get``, parsed into Python objects
        """
        params = self.Parameters
        get_cmd = [
            self.Context.ydb_cli_path,
            '--endpoint', params.ydb_endpoint,
            '--database', params.ydb_database,
            'operation', 'get', self.Context.operation_id,
            '--format', 'proto-json-base64',
        ]
        token_env = {'YDB_TOKEN': self._get_ydb_token()}

        operation = json.loads(sp.check_output(get_cmd, env=token_env, stderr=stderr))
        logger.info('Operation get result: {}'.format(operation))
        return operation

    def _ydb_apply_action_to_operation(self, action):
        """
        Run ``ydb operation <action>`` for the operation stored in ``Context.operation_id``.

        :param action: OperationAction member; its value is the CLI subcommand
        """
        argv = [
            self.Context.ydb_cli_path,
            '--endpoint',
            self.Parameters.ydb_endpoint,
            '--database',
            self.Parameters.ydb_database,
            'operation',
            action.value,
            self.Context.operation_id,
        ]

        # Use the plain value ('cancel' / 'forget') in the process log name:
        # formatting the Enum member itself yields 'OperationAction.cancel'.
        with ProcessLog(self, logger='ydb_operation_{}'.format(action.value)) as pl:
            sp.check_call(
                argv,
                env={
                    'YDB_TOKEN': self._get_ydb_token(),
                },
                stdout=pl.stdout,
                stderr=pl.stderr,
            )

    def _postprocess_table(self, yt_client, table_path):
        """
        Sort the exported table and then transform it in place.

        :param yt_client: YT client that runs the operations
        :param table_path: path of the table to process
        """
        sort_columns = ['reqid', 'activation_id', 'frame_id', 'event_index', 'req_ts']

        logger.info('Starting to sort table {}'.format(table_path))
        yt_client.run_sort(table_path, sort_by=sort_columns)

        logger.info('Starting to transform table {}'.format(table_path))
        yt_client.transform(
            table_path,
            table_path,
            compression_codec='brotli_6',
            erasure_codec='lrc_12_2_2',
        )

    def _ydb_drop_table(self, table_name):
        """
        Drop a table from the YDB database.

        :param table_name: table name relative to the database root
        """
        params = self.Parameters
        table_path = params.ydb_database + '/' + table_name
        drop_cmd = [
            self.Context.ydb_cli_path,
            '--endpoint', params.ydb_endpoint,
            '--database', params.ydb_database,
            'table', 'drop', table_path,
        ]

        logger.info("Trying to drop table {}".format(table_name))
        with ProcessLog(self, logger='ydb_drop_table') as process_log:
            token_env = {'YDB_TOKEN': self._get_ydb_token()}
            sp.check_call(drop_cmd, env=token_env, stdout=process_log.stdout, stderr=process_log.stderr)
        logger.info("Successfully droped table {}".format(table_name))

    def _get_table_rotator_binary(self):
        """Return the local path of the ``table_rotator_binary`` resource."""
        path = str(sdk2.ResourceData(self.Parameters.table_rotator_binary).path)
        # Log through the module logger, like the other methods of this task.
        logger.info("table_rotator path: " + path)
        return path

    def _get_timestamp_from_table_name(self, table_name):
        """
        Parse the timestamp embedded in a rotated table name.

        The timestamp is everything after the first ``data_`` marker, e.g.
        ``events_data_2021-10-26T12_57_40Z``.

        :param table_name: YDB table name
        :return: naive datetime parsed with TIMESTAMP_FORMAT
        :raises ValueError: if the name has no ``data_`` marker or the suffix
            does not match TIMESTAMP_FORMAT
        """
        _, marker, timestamp_str = table_name.partition("data_")
        if not marker:
            # Without this check a missing marker would silently slice the
            # name from index 4 and try to parse unrelated text.
            raise ValueError('table name {!r} does not contain "data_"'.format(table_name))
        return datetime.strptime(timestamp_str, TIMESTAMP_FORMAT)

    def _ydb_rotate_table(self, table_name, new_table_name, preserve_partitions):
        """
        Run the table rotator binary for one table.

        :param table_name: name of the current table
        :param new_table_name: name of the table to rotate to
        :param preserve_partitions: when truthy, an extra flag is passed to the binary
        """
        params = self.Parameters
        # NOTE(review): the spelling of this flag looks like a typo, but it must
        # match the option declared by the rotator binary — confirm before changing.
        extra_flags = ["--copy-patrtitions"] if preserve_partitions else []
        rotate_cmd = [
            self.Context.ydb_table_rotator_path,
            '--endpoint', params.ydb_endpoint,
            '--database', params.ydb_database,
            '--tablename', table_name,
            '--newtablename', new_table_name,
        ] + extra_flags

        logger.info("Trying to rotate table {}".format(table_name))
        with ProcessLog(self, logger='rotate_table') as process_log:
            token_env = {'YDB_TOKEN': self._get_ydb_token()}
            sp.check_call(rotate_cmd, env=token_env, stdout=process_log.stdout, stderr=process_log.stderr)

    def _handle_ydb_table_rotation(self):
        """
        Rotate the enabled YDB tables when the rotation period has elapsed.

        The newest timestamp among all ``*_data_*Z`` tables plus
        ``ydb_rotation_period`` hours gives the next rotation time. Nothing is
        done while that time is still in the future (UTC) or when no
        timestamped table exists. Otherwise, for every enabled prefix, the
        table with the greatest name is rotated to a table named after the
        next rotation time.
        """
        all_tables = self._get_ydb_items()
        logger.info('Looking through tables: {}'.format(', '.join(all_tables)))
        # endswith() is safe for empty names, unlike indexing with [-1].
        timestamps = [
            self._get_timestamp_from_table_name(table_name)
            for table_name in all_tables
            if '_data_' in table_name and table_name.endswith('Z')
        ]
        if not timestamps:
            # max() of an empty list would raise; with no source tables there
            # is nothing to rotate from.
            logger.info('No timestamped tables found, skipping rotation')
            return

        next_timestamp = max(timestamps) + timedelta(hours=self.Parameters.ydb_rotation_period)
        if next_timestamp > datetime.utcnow():
            return

        # (prefix, preserve_partitions) pairs for the enabled table families.
        table_prefixes_for_rotation = []
        if self.Parameters.rotate_special_events_table:
            table_prefixes_for_rotation.append(["special_events_data_", False])
        if self.Parameters.rotate_events_index_table:
            table_prefixes_for_rotation.append(["events_index_data_", False])
        if self.Parameters.rotate_events_data_table:
            table_prefixes_for_rotation.append(["events_data_", True])

        for prefix, preserve_partitions in table_prefixes_for_rotation:
            rotation_candidates = [
                table_name for table_name in all_tables if table_name.startswith(prefix)
            ]
            if not rotation_candidates:
                logger.error('Failed to find tables with prefix "{}"'.format(prefix))
                continue
            table_name = max(rotation_candidates)  # table with biggest timestamp, events_data_2021-10-26T12_57_40Z
            new_table_name = prefix + datetime.strftime(next_timestamp, TIMESTAMP_FORMAT)

            self._ydb_rotate_table(table_name, new_table_name, preserve_partitions)

    def _drop_old_tables(self):
        """
        Drop rotated tables that are older than their configured retention.

        Only table families whose rotation flag is enabled are considered; a
        table is dropped once its name timestamp plus the retention (hours) is
        not later than the current UTC time.
        """
        params = self.Parameters
        all_tables = self._get_ydb_items()

        # (prefix, retention in hours) pairs for the enabled table families.
        retention_rules = []
        if params.rotate_special_events_table:
            retention_rules.append(("special_events_data_", params.special_events_table_drop_after))
        if params.rotate_events_index_table:
            retention_rules.append(("events_index_data_", params.events_index_table_drop_after))

        for prefix, retention_hours in retention_rules:
            for name in all_tables:
                if not name.startswith(prefix):
                    continue
                expires_at = self._get_timestamp_from_table_name(name) + timedelta(hours=retention_hours)
                if expires_at <= datetime.utcnow():
                    self._ydb_drop_table(name)

    def on_execute(self):
        """
        Task entry point: rotate and clean up YDB tables, then back up one table to YT.

        Steps:
          1. Prepare the ydb CLI and table rotator binaries.
          2. Rotate tables and drop the old ones.
          3. Unless forced, stop when the current Europe/Moscow hour is not accessible.
          4. Select a table, export it to the YT tmp directory and poll the
             export operation until it is ready.
          5. Sort/transform the exported table, move it to the result directory,
             set its expiration time and drop the source table from YDB.

        :raises RuntimeError: if the export returns no operation id or the
            finished operation reports a non-successful status
        """
        self.Context.ydb_cli_path = self._get_ydb_cli()
        self.Context.ydb_table_rotator_path = self._get_table_rotator_binary()
        self._handle_ydb_table_rotation()
        self._drop_old_tables()

        # Backup only at accessible hours
        now = datetime.now(tz=timezone('Europe/Moscow'))
        accessible_hours = json.loads(self.Parameters.accessible_hours)
        if not self.Parameters.force_backup and now.hour not in accessible_hours:
            return

        # Imported lazily: the package is provided by the task's pip environment.
        import yt.wrapper

        yt_client = yt.wrapper.YtClient(
            proxy=self.Parameters.yt_cluster,
            token=self._get_yt_token(),
            config={
                'pool': 'setrace'
            }
        )

        table_for_export = self._get_table_for_export(yt_client)

        if table_for_export is None:
            logger.info('Nothing to export, all YDB tables are present in YT')
            return

        logger.info('Selected table {} for export'.format(table_for_export))

        # The result table is created up front, so _get_table_for_export does
        # not select this table again while the export is running.
        result_table_path = yt.wrapper.ypath_join(self.Parameters.yt_directory, table_for_export)
        yt_client.create('table', result_table_path)

        tmp_table_path = yt.wrapper.ypath_join(self.Parameters.yt_tmp_directory, table_for_export)
        self._create_yt_table(yt_client, tmp_table_path)

        result = self._start_ydb_export_to_yt(table_for_export)
        self.Context.operation_id = result.get('id')
        logger.info('Operation id = {}'.format(self.Context.operation_id))
        if not self.Context.operation_id:
            # Without an id the operation can be neither polled nor cancelled.
            raise RuntimeError('YDB export returned no operation id: {}'.format(json.dumps(result)))

        try:
            with ProcessLog(self, logger='ydb_operation_get') as pl:
                while True:
                    result = self._ydb_operation_get(pl.stderr)
                    logger.info('Operation get result {}'.format(json.dumps(result)))
                    if result.get('ready', False):
                        break
                    # Sleep only while the operation is still running.
                    time.sleep(600)
        finally:
            self._ydb_apply_action_to_operation(OperationAction.cancel)
            if self.Parameters.ydb_forget_operation:
                self._ydb_apply_action_to_operation(OperationAction.forget)

        # Do not post-process the data or drop the source table when the
        # operation finished unsuccessfully.
        # NOTE(review): assumes the CLI reports success as "SUCCESS" (or the
        # numeric code 400000) in the "status" field — confirm against the
        # actual `ydb operation get` output.
        status = result.get('status')
        if status is not None and status not in ('SUCCESS', 400000):
            raise RuntimeError('YDB export operation {} finished with status {}'.format(
                self.Context.operation_id, status
            ))

        self._postprocess_table(yt_client, tmp_table_path)

        logger.info('Moving table to result diretory')
        yt_client.move(tmp_table_path, result_table_path, force=True)

        # The timestamp carries no timezone, so it is computed from UTC rather
        # than from the host's local time.
        # NOTE(review): assumes YT interprets zone-less timestamps as UTC — confirm.
        expire_datetime = (datetime.utcnow() + timedelta(days=self.Parameters.yt_table_ttl_days)).isoformat()
        logger.info('Setting expiration_time {} to result table'.format(expire_datetime))
        yt_client.set(yt.wrapper.ypath_join(result_table_path, '@expiration_time'), expire_datetime)

        self._ydb_drop_table(table_for_export)
