import logging
import time
from datetime import datetime

from sandbox import sdk2
from sandbox.sandboxsdk import process, environments
from sandbox.projects.common import utils
from sandbox.projects.collections.mixins import (
    YasmReportable,
    _wait_all,
)
from sandbox.projects.collections.resources import (
    CollectionsMongoCopyBinary,
    CollectionsMongoBackupConverter,
)


# Per-collection value handed to DumpRunner as ``splits`` (it ends up in the
# copy tool's ``--source-splits`` flag). Collections that are not listed here
# fall back to 1 at the lookup site.
COLLECTION_SPLITS = dict(
    card=20,
    board=5,
    user=10,
    subscription=5,
)

# Module-level logger shared by the task and its helpers.
LOGGER = logging.getLogger(__name__)


class CollectionsDumpMongo(YasmReportable, sdk2.Task):
    """Sandbox task that dumps Mongo collections into YT.

    The task snapshots every configured collection into a temporary YT
    directory, converts the raw dumps into the final dump directory, copies the
    tables of the latest YDB dump next to them, replicates the result to the
    additional YT clusters and finally moves the ``latest`` link.
    """

    class Requirements(sdk2.Task.Requirements):
        environments = [
            environments.PipEnvironment('yandex-yt', version='0.9.35'),
            environments.PipEnvironment("yandex-yt-transfer-manager-client"),
        ]

        class Caches(sdk2.Requirements.Caches):
            pass  # means that the task does not use any shared caches

    class Parameters(sdk2.Task.Parameters):
        with sdk2.parameters.Group('YT parameters') as yt_parameters:
            yt_proxy = sdk2.parameters.String(
                'YT proxy',
                required=True,
            )
            yt_pool = sdk2.parameters.String(
                'YT pool',
                required=True,
            )
            yt_token_vault = sdk2.parameters.String(
                'YT token vault',
                required=True,
            )
            yt_dump_path = sdk2.parameters.String(
                'Path to dump in YT',
                required=True,
            )
        with sdk2.parameters.Group('Mongo parameters') as mongo_parameters:
            mongo_database = sdk2.parameters.String(
                'Mongo database name',
                required=True,
            )
            # Maps a vault key (holding the Mongo URI) to a comma-separated
            # list of collection names; see the ``collections`` property.
            collections = sdk2.parameters.Dict(
                'Mongo collections per uri (in vault) to dump',
                required=True,
            )
        with sdk2.parameters.Group('YDB parameters') as ydb_parameters:
            ydb_dump_path = sdk2.parameters.String(
                'Path to latest YDB dump (we want to copy them)',
                required=True,
            )
        with sdk2.parameters.Group('Dump parameters') as dump_parameters:
            env = sdk2.parameters.String(
                'Environment (prod or test)',
                required=True,
            )
            service_name = sdk2.parameters.String(
                'Legacy param for yasm',
                required=True,
            )
            start_datetime = sdk2.parameters.String(
                'Dump creation datetime',
            )
            additional_yt_destinations = sdk2.parameters.Dict(
                'additional Yt destinations (cluster: states to keep)',
                default={'hahn': 2},
            )
            monitoring_server_host = sdk2.parameters.String(
                'Monitoring server',
                default='monit.n.yandex-team.ru',
            )
            skip_apply_changes = sdk2.parameters.Bool(
                'Skip apply changes',
                default=False,
            )
        kill_timeout = 345600  # 4 days

    def on_execute(self):
        """Run the whole pipeline: dump, convert, copy YDB, replicate, link, report."""
        from yt.wrapper import YtClient
        from yt.wrapper import ypath_join

        # Read the secret once instead of hitting Vault for every consumer below.
        yt_token = self.yt_token
        yt_client = YtClient(
            proxy=self.Parameters.yt_proxy,
            token=yt_token,
        )
        timestamp = self._get_timestamp()
        current_dump_path = self._get_current_dump_path(timestamp)
        latest_dump_yt_path = self._latest_mongo_dump_path

        # Raw dumps land in a scratch directory under the final dump path and
        # are converted from there into the final location.
        tmp_dir = ypath_join(current_dump_path, 'tmp')

        if not yt_client.exists(tmp_dir):
            yt_client.mkdir(tmp_dir, recursive=True)

        self._dump(
            yt_proxy=self.Parameters.yt_proxy,
            path=tmp_dir,
            collections=self.collections,
            yt_pool=self.Parameters.yt_pool,
            database=self.Parameters.mongo_database,
        )
        self._convert(
            yt_client=yt_client,
            src=tmp_dir,
            dst=current_dump_path,
            yt_pool=self.Parameters.yt_pool,
        )
        yt_client.remove(tmp_dir, recursive=True, force=True)
        self._copy_ydb_dumps(
            yt_client=yt_client,
            ydb_dump_path=self.Parameters.ydb_dump_path,
            mongo_dump_path=current_dump_path,
        )
        # ``items()`` (rather than the Python 2-only ``iteritems()``) matches
        # the rest of the class and works on both Python 2 and 3.
        for dst_cluster, states_count in self.Parameters.additional_yt_destinations.items():
            self._copy_to_other_cluster(
                yt_token=yt_token,
                current_dump_yt_path=current_dump_path,
                latest_dump_yt_path=latest_dump_yt_path,
                src_cluster=self.Parameters.yt_proxy,
                dst_cluster=dst_cluster,
                states_count=states_count,
            )
        self._ensure_latest_link(
            yt_client=yt_client,
            dump_yt_path=current_dump_path,
            latest_dump_yt_path=latest_dump_yt_path,
        )
        self._report_lag(
            'collections_{}_copy_{}_to_yt'.format(
                self.Parameters.service_name, self.Parameters.mongo_database,
            ),
        )

    @property
    def yt_token(self):
        """YT token read from Vault under the ``yt_token_vault`` key of the task owner."""
        return sdk2.Vault.data(self.owner, self.Parameters.yt_token_vault)

    @property
    def collections(self):
        """Map each Mongo URI (resolved from Vault) to its list of collection names.

        The parameter stores ``vault key -> 'name1,name2,...'``. Names are
        stripped so that values such as ``'card, board'`` still match the keys
        of ``COLLECTION_SPLITS`` and are passed to the tool without stray
        whitespace.
        """
        return {
            sdk2.Vault.data(self.owner, mongo_uri_vault): [
                name.strip() for name in names.split(',')
            ]
            for mongo_uri_vault, names in self.Parameters.collections.items()
        }

    def _get_current_dump_path(self, timestamp):
        """Return the YT path of the dump identified by ``timestamp``."""
        return self._get_mongo_dump_path(
            base_path=self.Parameters.yt_dump_path,
            database=self.Parameters.mongo_database,
            env=self.Parameters.env,
            timestamp=timestamp,
        )

    def _get_timestamp(self):
        """Return the dump timestamp: ``start_datetime`` if set, otherwise "now"."""
        if self.Parameters.start_datetime:
            return self.Parameters.start_datetime
        # Naive local time is kept for backward compatibility with existing
        # dump names; an explicit timezone would be preferable.
        return datetime.now().isoformat()

    @property
    def _latest_mongo_dump_path(self):
        """YT path of the ``latest`` entry that sits next to the timestamped dumps."""
        return self._get_mongo_dump_path(
            base_path=self.Parameters.yt_dump_path,
            env=self.Parameters.env,
            database=self.Parameters.mongo_database,
            timestamp='latest',
        )

    def _get_mongo_dump_path(self, base_path, env, database, timestamp):
        """Build ``<base_path>/<env>/<database>/<timestamp>``."""
        return '{base_path}/{env}/{database}/{timestamp}'.format(
            base_path=base_path,
            env=env,
            database=database,
            timestamp=timestamp,
        )

    def _get_ydb_dump_path(self, base_path, cluster, env):
        """Build ``<base_path>/<cluster>/collections/<env>/collections/latest``."""
        return '{base_path}/{cluster}/collections/{env}/collections/latest'.format(
            base_path=base_path,
            cluster=cluster,
            env=env,
        )

    def _copy_ydb_dumps(self, yt_client, ydb_dump_path, mongo_dump_path):
        """Copy every table found under ``ydb_dump_path`` into ``mongo_dump_path``."""
        from yt.wrapper import ypath_join
        for table in self._get_tables(yt_client=yt_client, path=ydb_dump_path):
            ydb_path = ypath_join(ydb_dump_path, table)
            mongo_path = ypath_join(mongo_dump_path, table)
            LOGGER.info('Copy ydb table %s to %s', ydb_path, mongo_path)
            # TM can't copy data by symlinks
            yt_client.copy(
                source_path=ydb_path,
                destination_path=mongo_path,
            )

    def _dump(self, yt_proxy, path, collections, database, yt_pool):
        """Snapshot all collections into ``path`` and then apply the changes.

        :param yt_proxy: YT cluster the tool writes to.
        :param path: YT directory receiving the raw dumps.
        :param collections: mapping ``mongo uri -> list of collection names``.
        :param database: Mongo database name.
        :param yt_pool: YT pool exported to the tool via ``YT_POOL``.
        """
        targets = [
            (mongo_uri, collection)
            for mongo_uri, cols in collections.items()
            for collection in cols
        ]
        dumpers = []
        if targets:
            # Both lookups are loop-invariant (a resource lookup and a Vault
            # read), so do them once rather than once per collection. They are
            # still skipped entirely when there is nothing to dump.
            tool_path = self._get_tool_path(CollectionsMongoCopyBinary)
            yt_token = self.yt_token
            for mongo_uri, collection in targets:
                dumpers.append(
                    DumpRunner(
                        tool_path=tool_path,
                        mongo_uri=mongo_uri,
                        yt_token=yt_token,
                        yt_proxy=yt_proxy,
                        yt_pool=yt_pool,
                        dst=path,
                        collection=collection,
                        splits=COLLECTION_SPLITS.get(collection, 1),
                        database=database,
                    )
                )

        # The [start, end] window brackets the snapshot phase and is passed to
        # the apply-changes phase as --start-ts / --end-ts.
        start = int(time.time())
        tasks = [dumper.make_snapshot() for dumper in dumpers]
        LOGGER.info('start snapshot tasks')
        _wait_all(tasks)
        end = int(time.time())

        if self.Parameters.skip_apply_changes:
            # NOTE(review): a zero window presumably makes the tool replay
            # nothing -- confirm against the copy tool's behavior.
            start = 0
            end = 0
        tasks = [
            dumper.apply_changes(start=start, end=end)
            for dumper in dumpers
        ]
        LOGGER.info('start apply_changes tasks')
        _wait_all(tasks)

    def _copy_to_other_cluster(self, yt_token, src_cluster, dst_cluster, current_dump_yt_path, latest_dump_yt_path, states_count):
        """Replicate the current dump to ``dst_cluster`` and prune old dumps there.

        :param states_count: how many dumps to keep on the destination
            (converted with ``int``, so a numeric string is accepted).
        """
        from yt.transfer_manager.client import TransferManager
        from yt.wrapper import YtClient, ypath_dirname, ypath_join

        transfer_client = TransferManager(token=yt_token)
        transfer_client.add_tasks(
            source_cluster=src_cluster,
            source_pattern=current_dump_yt_path,
            destination_cluster=dst_cluster,
            destination_pattern=current_dump_yt_path,
            sync=True,
        )
        yt_client = YtClient(proxy=dst_cluster, token=yt_token)
        self._ensure_latest_link(yt_client, current_dump_yt_path, latest_dump_yt_path)

        # remove old states
        dir_name = ypath_dirname(current_dump_yt_path)
        states = [
            state
            for state in yt_client.list(dir_name, absolute=False)
            if state != 'latest'
        ]
        # Names are sorted in descending string order; everything past the
        # first ``states_count`` entries is dropped.
        states_to_remove = sorted(states, reverse=True)[int(states_count):]
        for state in states_to_remove:
            yt_client.remove(ypath_join(dir_name, state), recursive=True)

    def _convert_dump(self, tool_path, src, dst, collection, yt_pool):
        """Start the converter for one table without waiting for it; return the process."""
        args = [
            tool_path,
            '--yt-proxy', self.Parameters.yt_proxy,
            '--input-table', src,
            '--output-table', dst,
            '--collection', collection,
        ]
        LOGGER.info('Run command %s', args)
        return process.run_process(
            args,
            log_prefix='convertion_{}'.format(collection),
            environment={
                'YT_TOKEN': self.yt_token,
                'YT_POOL': yt_pool,
            },
            wait=False,
        )

    def _convert(self, yt_client, src, dst, yt_pool):
        """Convert every table under ``src`` into a same-named table under ``dst``."""
        from yt.wrapper import ypath_join
        collections = list(self._get_tables(yt_client=yt_client, path=src))
        tasks = []
        if collections:
            # Resolve the converter binary once; it does not depend on the table.
            tool_path = self._get_tool_path(CollectionsMongoBackupConverter)
            for collection in collections:
                tasks.append(
                    self._convert_dump(
                        tool_path,
                        src=ypath_join(src, collection),
                        dst=ypath_join(dst, collection),
                        collection=collection,
                        yt_pool=yt_pool,
                    ),
                )
        _wait_all(tasks)

    def _ensure_latest_link(self, yt_client, dump_yt_path, latest_dump_yt_path):
        """Point ``latest`` at ``dump_yt_path`` unless it already targets a newer dump.

        "Newer" is decided by plain string comparison of the paths. A resolve
        error while reading the current link is treated as "no link yet"; any
        other YT error is re-raised.
        """
        from yt.wrapper import YtHttpResponseError
        from yt.wrapper import ypath_join
        try:
            # NOTE(review): ``@path`` read through the link presumably yields
            # the path of the node the link points to -- confirm.
            if dump_yt_path < yt_client.get(ypath_join(latest_dump_yt_path, '@path')):
                return  # not latest
        except YtHttpResponseError as e:
            if not e.is_resolve_error():
                raise
        self._create_link(target_path=dump_yt_path, link_path=latest_dump_yt_path, yt_client=yt_client)

    def _create_link(self, yt_client, target_path, link_path):
        """Create (or overwrite, ``force=True``) ``link_path`` pointing at ``target_path``."""
        from yt.wrapper import link
        LOGGER.info('Make link from %s to %s', target_path, link_path)
        link(target_path=target_path, link_path=link_path, client=yt_client, force=True, recursive=True)

    def _get_tables(self, yt_client, path):
        """Yield the last path component of every table found under ``path``.

        Yields nothing (after logging) when ``path`` does not exist.
        """
        if not yt_client.exists(path):
            LOGGER.info('Path %s does not exist', path)
            return
        tables = yt_client.search(
            root=path,
            node_type=['table'],
            attributes=['path'],
        )
        for table in tables:
            yield self._split_for_dir_and_table(table.attributes['path'])[1]

    def _split_for_dir_and_table(self, path):
        """Split a YT path with ``ypath_split``; callers use element ``[1]`` as the table name."""
        from yt.wrapper import ypath_split
        return ypath_split(path)

    def _get_tool_path(self, resource_class):
        """Return the local filesystem path of the last released ``resource_class`` resource."""
        resource_id = utils.get_and_check_last_released_resource_id(resource_class)
        LOGGER.info('Found last released resource %s', resource_id)
        tool_path = str(sdk2.ResourceData(sdk2.Resource[resource_id]).path)
        LOGGER.info('Found tool\'s path: %s', tool_path)
        return tool_path


class DumpRunner(object):
    """Launch the Mongo copy tool for a single collection, writing into YT.

    Each public method starts the tool as a background process
    (``wait=False``) and returns whatever ``process.run_process`` returns, so
    the caller can wait on several runs at once.
    """

    # Flag whose value is the Mongo URI. The URI comes from Vault, so it must
    # never be written to the task log.
    _URI_FLAG = '--source-uri'
    _REDACTED = '***'

    def __init__(self, tool_path, mongo_uri, database, collection, splits, yt_token, yt_pool, yt_proxy, dst):
        """Remember everything needed to build the tool's command line.

        :param tool_path: path to the copy tool binary.
        :param mongo_uri: Mongo connection URI (secret).
        :param database: Mongo database name.
        :param collection: Mongo collection name.
        :param splits: value for ``--source-splits`` (snapshot only).
        :param yt_token: exported to the tool as ``YT_TOKEN``.
        :param yt_pool: exported to the tool as ``YT_POOL``.
        :param yt_proxy: value for ``--yt-proxy``.
        :param dst: value for ``--target-dir``.
        """
        self._tool_path = tool_path
        self._mongo_uri = mongo_uri
        self._yt_proxy = yt_proxy
        self._yt_token = yt_token
        self._dst = dst
        self._yt_pool = yt_pool
        self._collection = collection
        self._splits = splits
        self._mongo_database = database

    def make_snapshot(self):
        """Start the snapshot phase (``--skip-changes``) and return the process."""
        return self._run(self._snapshot_args(), 'make_snapshot')

    def apply_changes(self, start, end):
        """Start the changes phase (``--skip-snapshot``) for the ``[start, end]`` window."""
        return self._run(self._changes_args(start, end), 'apply_changes')

    def _snapshot_args(self):
        """Build the argv for the snapshot phase."""
        return [
            self._tool_path,
            self._URI_FLAG, self._mongo_uri,
            '--source-splits', str(self._splits),
            '--source-collection', self._collection,
            '--skip-changes',
            '--source-database', self._mongo_database,
        ] + self._target_args()

    def _changes_args(self, start, end):
        """Build the argv for the changes phase."""
        return [
            self._tool_path,
            self._URI_FLAG, self._mongo_uri,
            '--source-collection', self._collection,
            '--skip-snapshot',
            '--start-ts', str(start),
            '--end-ts', str(end),
            '--source-database', self._mongo_database,
        ] + self._target_args()

    def _target_args(self):
        """Build the trailing ``yt`` sub-command shared by both phases."""
        return [
            'yt',
            '--yt-proxy', self._yt_proxy,
            '--target-dir', self._dst,
        ]

    @classmethod
    def _redact(cls, args):
        """Return a copy of ``args`` with the value after ``--source-uri`` masked."""
        redacted = list(args)
        for index in range(len(redacted) - 1):
            if redacted[index] == cls._URI_FLAG:
                redacted[index + 1] = cls._REDACTED
        return redacted

    def _run(self, args, stage):
        """Log the (redacted) command and start it in the background."""
        # Only the log line is redacted; the tool itself receives the real URI.
        LOGGER.info('Run command %s', self._redact(args))
        return process.run_process(
            args,
            log_prefix='{}_{}'.format(stage, self._collection),
            environment={
                'YT_TOKEN': self._yt_token,
                'YT_POOL': self._yt_pool,
            },
            wait=False,
        )


# Module-level alias for the task class defined in this file.
# NOTE(review): presumably this is what the Sandbox loader looks up to
# discover the task -- confirm before renaming or removing.
__TASK__ = CollectionsDumpMongo
