import datetime
import logging
import os
import re
from os.path import join as pj

from sandbox import sdk2
from sandbox.sdk2 import service_resources

# Nodes, relative to the task's ``prefix`` parameter, that are scanned for
# tables to back up.
NODES_TO_BACKUP = [
    "backend",
    "workers/worker/counters/state",
    "workers/worker/processor/state",
    "imports/channel_info",
]

# Attribute path suffix written as True on a backup directory at the end of a
# backup run and checked again when old backups are cleaned up.
BACKUP_FINISHED_ATTR = "@backup_finished"

# strftime() format used to name each backup directory, e.g. 20200131-235959.
TIME_FORMAT = "%Y%m%d-%H%M%S"

# Regular expression that backup directory names are validated against.
STATE_RE = r"^[0-9]{8}-[0-9]{6}$"


class SmelterBackupTask(sdk2.Task):
    """Back up selected YT tables into a timestamped directory and prune old backups.

    A run merges every table found under ``NODES_TO_BACKUP`` (relative to the
    ``prefix`` parameter) into ``<backup_prefix>/<timestamp>``, sets
    ``BACKUP_FINISHED_ATTR`` on that directory and finally calls ``cleanup``
    to remove old backup directories.
    """

    class Parameters(sdk2.Task.Parameters):
        cluster = sdk2.parameters.String("YT cluster", default="arnold")
        prefix = sdk2.parameters.String("YT prefix", default="//home/smelter")
        backup_prefix = sdk2.parameters.String("YT prefix where backups are stored", default="//home/smelter/backup")
        vault_yt_token = sdk2.parameters.String("YT token vault item name", required=True, default="robot-smelter-yt-token")
        yt_pool = sdk2.parameters.String("YT pool", default="smelter")
        max_operations_count = sdk2.parameters.Integer("Max operations count in YT", default=3)
        # Minimum number of finished backups that cleanup() retains.
        states_to_keep = sdk2.parameters.Integer("Number of finished backup states to keep", default=5)

    def make_yt_client(self):
        """Create a YT client for the configured cluster.

        The token is read from the vault item named by the ``vault_yt_token``
        parameter. If that parameter is empty, a warning is logged and the
        client is created with ``token=None``.

        :return: ``yt.wrapper.YtClient`` bound to ``Parameters.cluster``.
        """
        import yt.wrapper as yt

        yt_token = None
        if self.Parameters.vault_yt_token:
            yt_token = sdk2.Vault.data(self.Parameters.vault_yt_token).strip()
        else:
            logging.warning("YT token was not provided via vault_yt_token parameter")

        return yt.YtClient(proxy=self.Parameters.cluster, token=yt_token)

    def on_create(self):
        """Select the tasks binary: the first stable-released SMELTER_BACKUP_TASK resource owned by SMELTER."""
        self.Requirements.tasks_resource = service_resources.SandboxTasksBinary.find(
            owner="SMELTER",
            attrs={
                "task_type": "SMELTER_BACKUP_TASK",
                "released": "stable"
            },
        ).first()

    def yt_walk(self, yt_client, paths_with_suffixes, types):
        """Collect relative paths of all nodes of the given types under several roots.

        :param yt_client: client whose ``get(path, attributes=["type"])`` returns
            a node exposing ``.attributes["type"]`` and, for map nodes, ``.items()``.
        :param paths_with_suffixes: iterable of ``(path, suffix)`` pairs; ``path``
            is fetched from YT and ``suffix`` is the relative path reported for
            that root node.
        :param types: container of node type names to collect (e.g. ``["table"]``).
        :return: set of relative paths, each a root suffix joined with child names.
        """
        found = set()

        def visit(node, relative_path):
            node_type = node.attributes["type"]
            if node_type in types:
                found.add(relative_path)
            elif node_type == "map_node":
                # Only map nodes are descended into; any other type is skipped.
                for child_name, child_node in node.items():
                    visit(child_node, pj(relative_path, child_name))

        for path, suffix in paths_with_suffixes:
            visit(yt_client.get(path, attributes=["type"]), suffix)
        return found

    def cleanup(self, yt_client, current_state):
        """Remove old backup directories under ``backup_prefix``.

        States are walked from newest to oldest and kept until
        ``Parameters.states_to_keep`` *finished* states (those whose
        ``BACKUP_FINISHED_ATTR`` exists and is truthy) have been seen; every
        older state is removed recursively. If fewer finished states exist,
        nothing is removed.

        :param yt_client: YT client used to list, inspect and remove states.
        :param current_state: name of the state created by this run; it must be
            listed under ``backup_prefix`` and is never removed.
        :raises AssertionError: if ``states_to_keep`` is not positive, a listed
            state name does not match ``STATE_RE``, ``current_state`` is
            missing, or one of the pre-removal sanity checks fails.
        """
        required_finished = self.Parameters.states_to_keep
        assert required_finished > 0

        prefix = self.Parameters.backup_prefix

        # Names follow TIME_FORMAT, so lexicographic order is chronological order.
        states = sorted(yt_client.list(prefix))

        for state in states:
            assert re.match(STATE_RE, state), "Cleanup: state {} has invalid format".format(state)

        assert current_state in states, "Cleanup: current_state {} is not there".format(current_state)

        states_to_keep = set()
        finished_states_to_keep = 0
        # Keeping at least self.Parameters.states_to_keep finished states and everything newer than any of those
        for state in reversed(states):
            states_to_keep.add(state)
            attr_path = pj(prefix, state, BACKUP_FINISHED_ATTR)
            if yt_client.exists(attr_path) and yt_client.get(attr_path):
                finished_states_to_keep += 1
                logging.info("Cleanup: finished state to keep: %s", state)
            else:
                logging.info("Cleanup: unfinished state to keep: %s", state)
            if finished_states_to_keep >= required_finished:
                break

        states_to_delete = [state for state in states if state not in states_to_keep]
        logging.info("Cleanup: %d state to delete", len(states_to_delete))

        if states_to_delete:
            # Something may only be removed once enough finished backups are
            # retained. With fewer states than the limit (e.g. the first runs
            # against a fresh backup_prefix) nothing is deleted, and that must
            # not be treated as an error.
            assert finished_states_to_keep >= required_finished, \
                "Cleanup: only {} finished states would be kept, {} required".format(finished_states_to_keep, required_finished)
        assert current_state not in states_to_delete, "Cleanup: current_state {} should not be deleted".format(current_state)

        for folder in [pj(prefix, state) for state in states_to_delete]:
            # Re-validate each path right before the recursive removal.
            assert folder.startswith(self.Parameters.backup_prefix), "Cleanup: double check: {} should start with {}".format(folder, self.Parameters.backup_prefix)
            assert re.match(STATE_RE, os.path.basename(folder)), "Cleanup: double check basename {} should be of date format".format(os.path.basename(folder))
            logging.info("Removing %s...", folder)
            yt_client.remove(folder, recursive=True)

    def on_execute(self):
        """Run one backup: copy the tables, mark the backup finished, prune old backups.

        The backup directory is named after the current local time formatted
        with ``TIME_FORMAT``. Each table is copied with a YT merge operation
        (mode ``"auto"``) submitted through an ``OperationsTrackerPool`` that
        receives ``max_operations_count`` as its first argument.
        """
        import yt.wrapper as yt

        yt_client = self.make_yt_client()

        current_state = datetime.datetime.today().strftime(TIME_FORMAT)
        backup_prefix = pj(self.Parameters.backup_prefix, current_state)

        roots = [(pj(self.Parameters.prefix, suffix), suffix) for suffix in NODES_TO_BACKUP]
        tables_to_backup = self.yt_walk(yt_client, roots, ["table"])
        logging.info("Going to backup %d tables into %s", len(tables_to_backup), backup_prefix)

        # NOTE(review): the tracker is expected to wait for the submitted
        # operations when the ``with`` block exits - confirm against yt.wrapper.
        with yt.OperationsTrackerPool(self.Parameters.max_operations_count, client=yt_client) as tracker:
            # Sorted so that operations are submitted (and logged) in a stable order.
            for table_suffix in sorted(tables_to_backup):
                src_path = pj(self.Parameters.prefix, table_suffix)
                dst_path = pj(backup_prefix, table_suffix)
                dst_dir_path = os.path.dirname(dst_path)
                if not yt_client.exists(src_path):
                    # The table was present during the walk but is no longer there.
                    logging.error("Could not find %s, ignoring...", src_path)
                    continue
                if not yt_client.exists(dst_dir_path):
                    logging.info("%s does not exist, creating it...", dst_dir_path)
                    yt_client.create("map_node", dst_dir_path, recursive=True)

                logging.info("Backing up %s into %s", src_path, dst_path)
                spec_builder = (
                    yt.spec_builders.MergeSpecBuilder()
                    .input_table_paths(src_path)
                    .output_table_path(dst_path)
                    .mode("auto")
                )
                if self.Parameters.yt_pool:
                    spec_builder = spec_builder.pool(self.Parameters.yt_pool)
                tracker.add(spec_builder)
            logging.info("All operations started. Waiting to complete.")

        yt_client.set(pj(backup_prefix, BACKUP_FINISHED_ATTR), True)  # for transfer to find finished version of backup

        self.cleanup(yt_client, current_state)

        logging.info("Finished")
