# coding: utf8

from datetime import datetime
import logging

from sandbox import sdk2
from sandbox.sdk2 import environments


# strftime() pattern used to name each backup table, e.g. "20240131_2359".
# Fields go from most to least significant, so sorting the names as plain
# strings orders backups chronologically (relied on when pruning old backups).
NAME_FORMAT = "%Y%m%d_%H%M"


class BackupYtReplicatedDyntable(sdk2.Task):
    """Back up a YT replicated dynamic table to its replica clusters.

    The task reads the ``replicas`` attribute of the replicated (meta) table,
    copies the single enabled sync replica into ``<backups_dir>/<timestamp>``
    on that replica's cluster, remote-copies the backup to the other enabled,
    non-excluded replica clusters, points a ``recent`` link at it and keeps at
    most ``max_backup_count`` backups per cluster. Clusters listed in
    ``excluded_clusters`` never keep backups: every backup under
    ``backups_dir`` on them is removed at the end of the run.
    """

    class Requirements(sdk2.Task.Requirements):
        environments = [
            environments.PipEnvironment("yandex-yt"),
        ]

    class Parameters(sdk2.Task.Parameters):
        description = "Backup YT Replicated Dynamic Table"
        hidden = False
        owner = "MARKET-IDX"

        # Yav secret; its "yt-market-indexer" key holds the YT token.
        tokens = sdk2.parameters.YavSecret("YT Token Secret", required=True)
        # Cluster that hosts the replicated (meta) table.
        meta_cluster = sdk2.parameters.String("Meta-cluster", required=True)
        # Path of the replicated table on meta_cluster.
        meta_table_path = sdk2.parameters.String("Meta table path", required=True)
        # Backup directory; the same path is used on every cluster.
        backups_dir = sdk2.parameters.String("Backup directory", required=True)
        # Clusters that must not keep backups (their backups are deleted).
        excluded_clusters = sdk2.parameters.List("Excluded clusters", required=False)
        # Number of most recent backups kept on each non-excluded cluster.
        max_backup_count = sdk2.parameters.Integer("Maximal count of backups to store", default=7)
        # Compression codec applied to the backup tables.
        codec = sdk2.parameters.String("Compression codec", default="brotli_8")

    @staticmethod
    def _get_meta_info(meta_cluster, meta_table_path, clients):
        """Find the sync replica and all enabled replica clusters of a replicated table.

        :param meta_cluster: cluster that hosts the replicated table.
        :param meta_table_path: path of the replicated table on ``meta_cluster``.
        :param clients: mapping cluster name -> YT client; must contain ``meta_cluster``.
        :return: ``(sync_replica_cluster, sync_replica_path, replica_clusters)``, where
            ``replica_clusters`` is the set of clusters of all enabled replicas.
        :raises RuntimeError: if there is no enabled replica, no enabled sync
            replica, or more than one enabled sync replica.
        """
        logging.info("Retrieving info about replicas.")
        meta_client = clients[meta_cluster]
        replicas = meta_client.get_attribute(meta_table_path, "replicas")  # type: dict

        replica_clusters = set()
        sync_replica_cluster = None
        sync_replica_path = None
        for replica_id, replica in replicas.items():
            if replica["state"] != "enabled":
                logging.info("Filtered out replica %s: not \"enabled\"", replica_id)
                continue

            if replica["mode"] == "sync":
                if sync_replica_cluster:
                    raise RuntimeError("Unexpected amount of sync replicas, expected exactly 1.")
                sync_replica_cluster = replica["cluster_name"]
                sync_replica_path = replica["replica_path"]

            replica_clusters.add(replica["cluster_name"])

        # Check the broader condition first: a sync replica implies a non-empty
        # replica set, so in the reverse order this check could never fire.
        if not replica_clusters:
            raise RuntimeError("No replicas found.")
        if not sync_replica_cluster:
            raise RuntimeError("No sync replica found.")

        logging.info("Has %s replicas on clusters: %s.", len(replica_clusters), ", ".join(replica_clusters))
        logging.info("Sync replica path: %s on cluster %s.", sync_replica_path, sync_replica_cluster)

        return sync_replica_cluster, sync_replica_path, replica_clusters

    @staticmethod
    def _create_local_backup(cluster, src_path, backups_dir, backup_name, codec, clients):
        """Copy ``src_path`` into ``<backups_dir>/<backup_name>`` on the same cluster.

        Everything happens inside one YT transaction. If any step fails, the
        transaction is not committed and the half-created table does not persist.

        :param cluster: cluster holding ``src_path`` (the sync replica).
        :param src_path: dynamic table to back up.
        :param backups_dir: directory that receives the backup table.
        :param backup_name: name of the backup table inside ``backups_dir``.
        :param codec: compression codec for the backup table.
        :param clients: mapping cluster name -> YT client.
        """
        from yt.wrapper import ypath_join

        dst_path = ypath_join(backups_dir, backup_name)

        logging.info("Creating local backup on '%s'.", cluster)
        src_client = clients[cluster]
        with src_client.Transaction():
            src_client.mkdir(backups_dir, recursive=True)
            # Snapshot lock: reads below in this transaction see one consistent state.
            src_client.lock(src_path, mode="snapshot")
            src_client.create(
                "table",
                dst_path,
                recursive=True,
                attributes={
                    "optimize_for": src_client.get_attribute(src_path, "optimize_for", default="lookup"),
                    "schema": src_client.get_attribute(src_path, "schema"),
                    # Tablet boundaries, kept so a restore can reshard the same way
                    # (presumably read by dump/restore tooling; verify).
                    "_yt_dump_restore_pivot_keys": src_client.get_attribute(src_path, "pivot_keys"),
                    "compression_codec": codec,
                },
            )
            # Started asynchronously only so the operation URL can be logged first.
            op = src_client.run_merge(
                src_client.TablePath(src_path), dst_path, mode="ordered", sync=False
            )
            logging.info("Merge operation: %s", op.url)
            op.wait()
            logging.info("Successful local backup in cluster '%s' from '%s' to '%s'.", cluster, src_path, dst_path)

    @staticmethod
    def _replicate_to_clusters(src_cluster, dst_clusters, backups_dir, backup_name, codec, clients):
        """Remote-copy the backup table from ``src_cluster`` to the other clusters.

        :param src_cluster: cluster that already holds ``<backups_dir>/<backup_name>``.
        :param dst_clusters: iterable of target clusters; ``src_cluster`` is
            skipped if present.
        :param backups_dir: backup directory (same path on all clusters).
        :param backup_name: name of the backup table.
        :param codec: compression codec requested for the copies.
        :param clients: mapping cluster name -> YT client.
        """
        from yt.wrapper import OperationsTracker, TablePath, ypath_join

        # Filter first so any iterable works and the source is never a target.
        targets = [cluster for cluster in dst_clusters if cluster != src_cluster]
        if not targets:
            logging.info("No clusters for replicating.")
            return

        table_path = TablePath(ypath_join(backups_dir, backup_name), compression_codec=codec)

        logging.info("Replicating table to clusters (%s).", ', '.join(targets))
        # All copies run in parallel; the tracker waits for them when the block exits.
        with OperationsTracker() as tracker:
            for current_cluster in targets:
                current_client = clients[current_cluster]
                current_client.mkdir(backups_dir, recursive=True)
                op = current_client.run_remote_copy(table_path, table_path, src_cluster, sync=False, copy_attributes=True)
                logging.info("Remote copy operation (%s -> %s): %s", src_cluster, current_cluster, op.url)
                tracker.add(op)

        logging.info("Remote copies successfully finished.")

    @staticmethod
    def _create_recent_links(clusters, backups_dir, backup_name, clients):
        """Point ``<backups_dir>/recent`` at the new backup on every given cluster.

        Any existing ``recent`` node is overwritten (``force=True``).
        """
        from yt.wrapper import ypath_join

        if not clusters:
            logging.info("No clusters to update recent links.")
            return

        logging.info("Updating recent links.")
        for current_cluster in clusters:
            current_client = clients[current_cluster]
            current_client.create(
                "link",
                ypath_join(backups_dir, "recent"),
                attributes={"target_path": ypath_join(backups_dir, backup_name)},
                force=True,
            )
            logging.info("Updated recent link on %s.", current_cluster)

        logging.info("Recent links were updated.")

    @staticmethod
    def _leave_last_n_backups(n, clusters, backups_dir, clients):
        """Keep only the ``n`` newest backups in ``backups_dir`` on each cluster.

        Backup names follow ``NAME_FORMAT``, so a reverse string sort puts the
        newest first. The ``recent`` link is never removed. ``n == 0`` removes
        every backup. Clusters without ``backups_dir`` are skipped.

        :param n: number of newest backups to keep.
        :param clusters: iterable of cluster names to clean.
        :param backups_dir: backup directory (same path on all clusters).
        :param clients: mapping cluster name -> YT client.
        """
        from yt.wrapper import ypath_join

        if not clusters:
            logging.info("No clusters for clearing backup directory.")
            return

        logging.info("Clearing backup directory from old backups.")
        for current_cluster in clusters:
            current_client = clients[current_cluster]
            if not current_client.exists(backups_dir):
                # E.g. an excluded cluster that never got a backup. Listing it would
                # fail, and inside a ``finally`` that masks the original error.
                logging.info("No backup directory on cluster '%s', skipping.", current_cluster)
                continue
            backup_names = sorted(
                (name for name in current_client.list(backups_dir) if name != "recent"),
                reverse=True,
            )
            old_backup_names = backup_names[n:]
            if not old_backup_names:
                continue
            logging.info(
                "Found %s old backups: %s on cluster '%s'. Deleting.",
                len(old_backup_names),
                ', '.join(old_backup_names),
                current_cluster,
            )
            for name in old_backup_names:
                current_client.remove(ypath_join(backups_dir, name), force=True)
        logging.info("Removed old backups.")

    def _create_yt_client(self, cluster, yt_token=None):
        """Build a YT client for ``cluster``.

        :param cluster: proxy name of the cluster.
        :param yt_token: token to use; defaults to the "yt-market-indexer" key
            of the ``tokens`` secret.
        """
        from yt.wrapper import YtClient

        if yt_token is None:
            yt_token = self.Parameters.tokens.data()["yt-market-indexer"]

        return YtClient(proxy=cluster, config={"token": yt_token})

    def on_execute(self):
        """Back up the table, fan it out to replica clusters, then rotate old backups."""
        meta_cluster = self.Parameters.meta_cluster
        meta_table_path = self.Parameters.meta_table_path
        backups_dir = self.Parameters.backups_dir
        # NOTE(review): host-local naive time. Switching to UTC would change the
        # sort order relative to existing backup names, so it is left as-is.
        backup_name = datetime.now().strftime(NAME_FORMAT)
        # The parameter is optional, so tolerate it being unset.
        excluded_clusters = set(self.Parameters.excluded_clusters or ())
        max_backup_count = self.Parameters.max_backup_count
        codec = self.Parameters.codec

        logging.info("Starting task with following parameters")
        logging.info("Meta cluster: %s", meta_cluster)
        logging.info("Meta table path: %s", meta_table_path)
        logging.info("Backups dir: %s", backups_dir)
        logging.info("Backup name: %s", backup_name)
        logging.info("Excluded clusters: {%s}", ", ".join(excluded_clusters))
        logging.info("Maximal backup count: %s", max_backup_count)
        logging.info("Compression codec: %s", codec)

        clients = {meta_cluster: self._create_yt_client(meta_cluster)}
        sync_replica_cluster, sync_replica_path, replica_clusters = self._get_meta_info(
            meta_cluster, meta_table_path, clients=clients
        )

        for current_cluster in replica_clusters:
            if current_cluster not in clients:
                clients[current_cluster] = self._create_yt_client(current_cluster)

        # Clusters that receive and keep the backup.
        good_clusters = [
            current_cluster for current_cluster in replica_clusters if current_cluster not in excluded_clusters
        ]

        try:
            # The local backup is made on the sync replica's cluster even if that
            # cluster is excluded; it then serves as the remote-copy source.
            self._create_local_backup(
                sync_replica_cluster, sync_replica_path, backups_dir, backup_name, codec, clients=clients
            )

            self._replicate_to_clusters(
                sync_replica_cluster, good_clusters, backups_dir, backup_name, codec, clients=clients
            )
        finally:
            # Remove backup tables from clusters that must not store backups
            # (e.g. seneca clusters), whether or not the steps above succeeded.
            self._leave_last_n_backups(0, excluded_clusters.intersection(clients), backups_dir, clients=clients)

        self._create_recent_links(good_clusters, backups_dir, backup_name, clients=clients)
        self._leave_last_n_backups(max_backup_count, good_clusters, backups_dir, clients=clients)

        logging.info("Backup task has successfully finished.")
