# -*- coding: utf-8 -*-
from argparse import ArgumentParser
from datetime import datetime
import logging.config

from dateutil.relativedelta import relativedelta
from yt.wrapper.client import YtClient

from travel.library.python.tools import replace_args_from_env
from travel.hotels.content_manager.config.stage_config import STAGE_LIST
from travel.hotels.content_manager.migrations.additional_migrations import ADDITIONAL_MIGRATIONS
from travel.hotels.content_manager.data_model.options import Options
from travel.hotels.content_manager.data_model.types import ProcessType
from travel.hotels.content_manager.exec.log_config import LOG_CONFIG
from travel.hotels.content_manager.lib.common import hide_secrets_dc, hide_secrets_ns, resolve_secrets_dc
from travel.hotels.content_manager.lib.common import ts_to_str_msk_tz
from travel.hotels.content_manager.lib.delayed_executor import DelayedExecutor
from travel.hotels.content_manager.lib.environment import get_available_environments
from travel.hotels.content_manager.lib.environment import get_environment_options_file, get_environment_options_resource
from travel.hotels.content_manager.lib.history_merger import HistoryMerger
from travel.hotels.content_manager.lib.migration_generator import MigrationGenerator
from travel.hotels.content_manager.lib.path_info import PathInfo
from travel.hotels.content_manager.lib.persistence_manager import PersistenceManager, YtCachePersistenceManager
from travel.hotels.content_manager.lib.storage_patcher import StoragePatcher
from travel.hotels.content_manager.lib.trigger import Trigger
from travel.hotels.content_manager.lib.yql_simple_client import YqlSimpleClient
from travel.hotels.content_manager.metrics.metrics_updater import MetricsUpdater


# Module-level logger named after this module; main() applies LOG_CONFIG via logging.config.dictConfig.
LOG = logging.getLogger(name=__name__)


class Cleaner:
    """Prune old snapshot nodes stored directly under one directory.

    Nodes are examined oldest-first by ``created_at``. A node is deleted only
    while it is older than ``keep_storage_snapshots_days`` days AND more than
    ``keep_storage_snapshots_count`` nodes would still remain. Pruning stops at
    the first node that fails either condition.
    """

    def __init__(
        self,
        persistence_manager: PersistenceManager,
        path: str,
        keep_storage_snapshots_days: int,
        keep_storage_snapshots_count: int,
    ):
        self.persistence_manager = persistence_manager
        # Directory whose child nodes are the snapshots to prune.
        self.path = path
        # Minimum age (in days) a node must reach before it may be deleted.
        self.keep_storage_snapshots_days = keep_storage_snapshots_days
        # Number of nodes that are always kept, regardless of age.
        self.keep_storage_snapshots_count = keep_storage_snapshots_count

    def clean(self) -> None:
        """Delete expired nodes under ``self.path``, oldest first, and log how many were removed."""
        LOG.info(f'Cleaning at {self.path}')
        threshold = datetime.now().astimezone() - relativedelta(days=self.keep_storage_snapshots_days)
        LOG.info(f'Removing all before {threshold}')

        candidates = sorted(self.persistence_manager.list(self.path), key=lambda item: item.created_at)
        remaining = len(candidates)
        deleted = 0

        for snapshot in candidates:
            # Sorted ascending, so every node after a fresh one is fresh too.
            if snapshot.created_at >= threshold:
                LOG.info('No more old snapshots')
                break
            # Never drop below the configured minimum number of nodes.
            if remaining <= self.keep_storage_snapshots_count:
                LOG.info(f'Keeping {self.keep_storage_snapshots_count} snapshots')
                break
            self.persistence_manager.delete(self.persistence_manager.join(self.path, snapshot.name))
            remaining -= 1
            deleted += 1

        LOG.info(f'Nodes removed: {deleted}')


class CMRunner(object):
    """Run one content-manager iteration against the storage.

    An iteration:

    1. branches the current storage into a new timestamped snapshot and
       relinks the storage path to it;
    2. applies pending migrations and storage patches;
    3. runs every stage updater, then every stage trigger;
    4. flushes the changes.

    If neither the migrations, the patcher, nor the persistence manager report
    a change, the storage link is pointed back to the previous snapshot and
    the new snapshot is deleted. Otherwise the tasks the triggers queued on
    the ``DelayedExecutor`` are executed.
    """

    def __init__(
            self,
            yql_client: YqlSimpleClient,
            persistence_manager: PersistenceManager,
            path_info: PathInfo,
            start_ts: int,
            options: Options,
    ):
        self.yql_client = yql_client
        self.persistence_manager = persistence_manager
        self.path_info = path_info
        # Iteration start time; also used to name the storage snapshot.
        self.start_ts = start_ts
        # Set by branch_storage(): resolved target of the storage link before branching.
        self.prev_snapshot_path = None
        # Set by branch_storage(): path of the snapshot created for this iteration.
        self.snapshot_path = None
        self.options = options

    def branch_storage(self) -> None:
        """Copy the current storage into a new snapshot and link the storage path to it.

        Remembers the previous link target in ``self.prev_snapshot_path`` so
        that ``run`` can restore it when the iteration changes nothing.
        """
        self.prev_snapshot_path = self.persistence_manager.realpath(self.path_info.storage_path)

        snapshot_name = ts_to_str_msk_tz(self.start_ts)
        snapshot_path = self.persistence_manager.join(self.path_info.storage_snapshots_path, snapshot_name)
        self.snapshot_path = snapshot_path

        LOG.info(f'Copying storage from {self.path_info.storage_path} to {snapshot_path}')
        self.persistence_manager.copy_dir(self.path_info.storage_path, snapshot_path)

        LOG.info(f'Linking storage from {self.path_info.storage_path} to {self.snapshot_path}')
        self.persistence_manager.link(self.path_info.storage_path, self.snapshot_path)

    def do_migrate(self) -> bool:
        """Bring the storage up to the latest migration version.

        Reads the current version from ``storage_version_table`` (0 when that
        table does not exist). Creates empty tables for ``MigrationGenerator.new_tables``.
        Then runs each pending YQL migration in order, followed by its entry in
        ``ADDITIONAL_MIGRATIONS`` when one is registered for that version.
        The persistence-manager cache is reset before returning.

        Returns:
            True if at least one migration ran and the stored version was
            updated, False otherwise.
        """
        storage_version_changed = False

        LOG.info('Checking migrations')
        version_table = self.path_info.storage_version_table
        storage_version = 0
        migration_version = 0
        if self.persistence_manager.exists(version_table):
            # NOTE(review): raises IndexError if the version table exists but is empty.
            storage_version = list(self.persistence_manager.read(version_table))[0]['value']
        LOG.info(f'Storage version: {storage_version}')

        mg = MigrationGenerator(storage_version)
        for table in mg.new_tables:
            table_path = self.persistence_manager.join(self.path_info.storage_path, table.name)
            LOG.info(f'Creating table: {table_path}')
            self.persistence_manager.write(table_path, [], table.schema)

        # Pending migrations are numbered consecutively after the stored version.
        for migration_version, template in enumerate(mg.migrations, storage_version + 1):
            LOG.info(f'Applying migration {migration_version}')
            query = template.format(storage_path=self.path_info.storage_path)
            self.yql_client.run_query(query)

            additional_migration = ADDITIONAL_MIGRATIONS.get(migration_version)
            if additional_migration is not None:
                LOG.info(f'Applying additional migration {migration_version}')
                # Drop cached data so the callable sees the result of the YQL migration.
                self.persistence_manager.reset_cache()
                additional_migration(self.persistence_manager, self.path_info)
                self.persistence_manager.flush()

        # migration_version only exceeds storage_version if the loop above ran.
        if migration_version > storage_version:
            LOG.info(f'Updating storage version to {migration_version}')
            self.persistence_manager.write(version_table, [{'value': migration_version}], {'value': 'uint32'})
            storage_version_changed = True

        self.persistence_manager.reset_cache()

        LOG.info('Migrating complete')
        return storage_version_changed

    def run(self):
        """Execute one full iteration (see the class docstring for the sequence)."""
        self.branch_storage()

        storage_version_changed = self.do_migrate()

        storage_patcher = StoragePatcher(self.persistence_manager, self.path_info)
        storage_patched = storage_patcher.patch()

        self._run_updaters()

        delayed_executor = DelayedExecutor()
        self._run_triggers(delayed_executor)

        LOG.info('Flushing changes')
        self.persistence_manager.flush()

        if not storage_version_changed and not storage_patched and not self.persistence_manager.upstream_changed:
            self._rollback_storage()
            return

        LOG.info('Running delayed tasks')
        delayed_executor.execute()

    @staticmethod
    def _active_stages():
        """Yield stage configs from STAGE_LIST in order, skipping CATROOM stages."""
        return (stage_cfg for stage_cfg in STAGE_LIST if not stage_cfg.process_type == ProcessType.CATROOM)

    def _run_updaters(self) -> None:
        """Construct and run the updater of every active stage that declares one."""
        for stage_cfg in self._active_stages():
            updater_cls = stage_cfg.updater
            if updater_cls is None:
                continue
            updater = updater_cls(
                stage_name=stage_cfg.name,
                persistence_manager=self.persistence_manager,
                yql_client=self.yql_client,
                path_info=self.path_info,
                start_ts=self.start_ts,
                save_history=stage_cfg.save_history,
                options=self.options.stages,
            )
            updater.process()

    def _run_triggers(self, delayed_executor: DelayedExecutor) -> None:
        """Process a Trigger for every active stage.

        Stages with ``run_on_storage_change`` set are processed after all
        others. STAGE_LIST order is preserved within each group.
        """
        run_first = list()
        run_last = list()
        for stage_cfg in self._active_stages():
            if stage_cfg.run_on_storage_change:
                run_last.append(stage_cfg)
            else:
                run_first.append(stage_cfg)

        for stage_cfg in run_first + run_last:
            stage_options = self.options.stages.get(stage_cfg.name)

            trigger = Trigger(
                process_type=stage_cfg.process_type,
                stage_name=stage_cfg.name,
                producer_cls=stage_cfg.producer,
                thread_filters=stage_cfg.filters,
                persistence_manager=self.persistence_manager,
                delayed_executor=delayed_executor,
                path_info=self.path_info,
                entity_cls=stage_cfg.entity_cls,
                other_entities=stage_cfg.other_entities,
                start_ts=self.start_ts,
                options=stage_options,
                jobs_max=stage_cfg.jobs_max,
                job_size=stage_cfg.job_size,
                manual_start=stage_cfg.manual_start,
                run_on_storage_change=stage_cfg.run_on_storage_change,
                job_retry_count=stage_cfg.job_retry_count,
                job_max_run_time=stage_cfg.job_max_run_time,
            )
            trigger.process()

    def _rollback_storage(self) -> None:
        """Point the storage link back to the previous snapshot and delete this iteration's snapshot."""
        LOG.info('Nothing changed. Rolling back storage changes')

        LOG.info(f'Linking back storage from {self.path_info.storage_path} to {self.prev_snapshot_path}')
        self.persistence_manager.link(self.path_info.storage_path, self.prev_snapshot_path)

        LOG.info(f'Deleting {self.snapshot_path}')
        self.persistence_manager.delete(self.snapshot_path)


def main():
    """Command-line entry point: run one content-manager iteration.

    Steps:

    1. Configure logging and parse arguments.
    2. Load options from either ``--env`` (bundled resource) or ``--env-fn``
       (file), then resolve secrets.
    3. Prune old storage snapshots.
    4. Run ``CMRunner`` inside a persistence-manager transaction.
    5. Update metrics and merge history.

    Raises:
        RuntimeError: if both ``--env`` and ``--env-fn`` are given, or if neither is.
    """
    logging.config.dictConfig(LOG_CONFIG)

    parser = ArgumentParser()

    parser.add_argument('--vault-token', default=None)
    parser.add_argument('--env', choices=get_available_environments('options/'))
    parser.add_argument('--env-fn')
    parser.add_argument('--send-metrics-to-solomon', action='store_true')
    parser.add_argument('--keep-storage-snapshots-days', type=int, default=30)
    parser.add_argument('--keep-storage-snapshots-count', type=int, default=5)

    args = parser.parse_args(replace_args_from_env())
    LOG.info(f'Args: {hide_secrets_ns(args)}')

    if args.env and args.env_fn:
        raise RuntimeError('Do not use both "--env" and "--env-fn"')

    unresolved_options = None
    if args.env:
        unresolved_options = get_environment_options_resource(args.env, 'options/', Options)
    elif args.env_fn:
        unresolved_options = get_environment_options_file(args.env_fn, Options)
    LOG.debug(f'Unresolved options: {unresolved_options}')

    if unresolved_options is None:
        raise RuntimeError('Set either "--env" or "--env-fn"')

    options: Options = resolve_secrets_dc(unresolved_options, args.vault_token)
    LOG.debug(f'Resolved options: {hide_secrets_dc(options)}')

    yql_client = YqlSimpleClient(token=options.yql.token, yt_proxy=options.yt.proxy)
    yt_client = YtClient(proxy=options.yt.proxy, token=options.yt.token)
    persistence_manager = YtCachePersistenceManager(yt_client, yql_client)
    path_info = PathInfo(
        persistence_manager=persistence_manager,
        root=options.yt.root,
        requests_path=options.yt.requests_path,
    )

    # Use an aware "now". datetime.utcnow() is naive, and .timestamp() treats naive
    # values as local time, so the old expression was off by the host's UTC offset
    # on any non-UTC machine.
    start_ts = int(datetime.now().astimezone().timestamp())

    Cleaner(
        persistence_manager,
        path_info.storage_snapshots_path,
        args.keep_storage_snapshots_days,
        args.keep_storage_snapshots_count
    ).clean()

    with persistence_manager.transaction() as t:
        yql_client.transaction_id = t.transaction_id
        try:
            CMRunner(yql_client, persistence_manager, path_info, start_ts, options).run()
        finally:
            # Unbind the YQL client from the transaction whether or not the run succeeded.
            yql_client.transaction_id = None

    MetricsUpdater(persistence_manager, path_info, options.metrics, args.send_metrics_to_solomon).run()
    HistoryMerger(persistence_manager).do_merge(path_info.history_path)
    LOG.info('All done')


# Run main() only when this module is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
