import contextlib
import functools
import logging
import os

import yaml

from crypta.lib.native.database.python.record import TRecord
from crypta.lib.python.yt import (
    tm_utils,
    yt_helpers,
)
from crypta.styx.services.common import db_dump
from crypta.styx.services.common.serializers.python import puid_state_record_serializer


# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)


def map_puid_state_to_last_event(row):
    """Map one serialized puid-state DB row to its last oblivion event.

    The row's ``key`` and ``value`` are wrapped in a ``TRecord`` and decoded
    into a puid state. When that state has at least one entry in
    ``OblivionEvents``, a single dict with keys ``puid``, ``hash`` and
    ``timestamp`` is yielded, built from the last entry. Otherwise nothing is
    yielded, so the row is dropped from the mapper output.
    """
    # NOTE(review): ``row["value"]`` is read via its private ``_bytes``
    # attribute; presumably a YSON string wrapper — confirm against db_dump.
    record = TRecord.Create(row["key"], row["value"]._bytes)
    puid_state = puid_state_record_serializer.FromRecord(record)

    events = puid_state.OblivionEvents
    if not events:
        return

    newest = events[-1]
    yield {
        "puid": puid_state.Puid,
        "hash": newest.Obfuscated,
        "timestamp": newest.Unixtime,
    }


def read_table_yaml(yt_client, dictionary_table):
    """Read every row of ``dictionary_table`` and render them as YAML.

    :param yt_client: client whose ``read_table`` yields the table's rows.
    :param dictionary_table: path of the table to read.
    :return: YAML text (``str``) for a list of all rows, in the order
        ``read_table`` produced them.
    """
    # NOTE(review): ``yaml.safe_dump`` only represents plain Python types, so
    # this assumes rows are plain dicts of scalars — confirm for YSON-typed
    # values coming back from ``read_table``.
    return yaml.safe_dump(list(yt_client.read_table(dictionary_table)))


def dump_table_as_yaml(yt_client, src_table, dst_file):
    """Write the YAML rendering of ``src_table`` to the YT file ``dst_file``.

    The YAML text from ``read_table_yaml`` is encoded as UTF-8 and written with
    ``force_create=True``, so the destination file node is created if needed.
    """
    yaml_text = read_table_yaml(yt_client, src_table)
    payload = yaml_text.encode("utf-8")
    yt_client.write_file(dst_file, payload, force_create=True)


def _transfer_paths_to_cluster(src_yt_client, dst_cluster, paths, tmp_dir, tm_client):
    """Copy every entry of ``paths`` from ``src_yt_client`` onto ``dst_cluster``.

    Each path is first transferred into its own temporary table under
    ``tmp_dir`` on the destination. Only after every transfer has finished are
    the temporary nodes moved onto their final paths, inside a single
    destination transaction. The temporary tables are released through their
    context managers when the ``ExitStack`` exits, including on failure.
    """
    logger.info("Preparing copy to %s", dst_cluster)
    dst_yt_client = yt_helpers.get_yt_client(dst_cluster)
    dst_yt_client.mkdir(tmp_dir, recursive=True)

    with contextlib.ExitStack() as stack:
        # Enter each TempTable into the stack as soon as it is constructed. If
        # construction creates the node (as the YT wrapper's TempTable does),
        # a failure on a later path still cleans up the earlier ones.
        temp_table_paths = {}
        for path in paths:
            temp_table_paths[path] = stack.enter_context(dst_yt_client.TempTable(tmp_dir))

        for path in paths:
            logger.info("Transfering %s to %s", path, dst_cluster)
            tm_utils.move_or_transfer_table(
                src_client=src_yt_client,
                src_path=path,
                dst_client=dst_yt_client,
                dst_path=temp_table_paths[path],
                force=True,
                tm_client=tm_client,
            )

        logger.info("Finalizing transfer with move under transaction")
        with dst_yt_client.Transaction():
            for path, tmp_path in temp_table_paths.items():
                dst_yt_client.mkdir(os.path.dirname(path), recursive=True)
                dst_yt_client.move(tmp_path, path, force=True)


def copy_to_additional_clusters(src_cluster, dst_clusters, paths, tmp_dir):
    """Replicate ``paths`` from ``src_cluster`` to each cluster in ``dst_clusters``.

    Every destination is attempted even if an earlier one fails. Failures are
    logged with their traceback and reported together at the end.

    :param src_cluster: name of the cluster holding the source nodes.
    :param dst_clusters: destination cluster names; when empty or ``None`` the
        function returns immediately without doing anything.
    :param paths: node paths to copy; the same paths are used on destinations.
    :param tmp_dir: directory on each destination for temporary tables.
    :raises ValueError: if ``src_cluster`` is also listed in ``dst_clusters``.
    :raises Exception: if copying to at least one destination cluster failed.
    """
    if not dst_clusters:
        return

    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if src_cluster in dst_clusters:
        raise ValueError(
            "Source cluster {} must not be in destination clusters {}".format(src_cluster, dst_clusters)
        )

    src_yt_client = yt_helpers.get_yt_client(src_cluster)
    tm_client = tm_utils.get_client(os.environ.get("YT_TOKEN"))

    unsuccessful_clusters = []
    for dst_cluster in dst_clusters:
        try:
            _transfer_paths_to_cluster(src_yt_client, dst_cluster, paths, tmp_dir, tm_client)
        except Exception as e:
            # Best-effort per cluster: record the failure (with traceback) and
            # continue with the remaining destinations.
            logger.exception("Failed to copy %s from %s to %s: %s", paths, src_cluster, dst_cluster, e)
            unsuccessful_clusters.append(dst_cluster)

    if unsuccessful_clusters:
        raise Exception("Failed to copy {} to the following clusters: {}".format(paths, unsuccessful_clusters))


def run(config):
    """Dump the puid-state database, export it as YAML, then replicate outputs.

    ``db_dump.run`` is driven with ``map_puid_state_to_last_event`` as the row
    mapper, sorting by ``puid``. Its postprocessor writes a YAML copy of the
    result to ``config.DstFile``. Afterwards both ``config.DstTable`` and
    ``config.DstFile`` are copied from ``config.DstCluster`` to every cluster in
    ``config.AdditionalDstClusters``.
    """
    yaml_postprocessor = functools.partial(dump_table_as_yaml, dst_file=config.DstFile)
    db_dump.run(
        config=config,
        db_row_mapper=map_puid_state_to_last_event,
        sort_by=["puid"],
        postprocessor=yaml_postprocessor,
    )

    copy_to_additional_clusters(
        src_cluster=config.DstCluster,
        dst_clusters=config.AdditionalDstClusters,
        paths=[config.DstTable, config.DstFile],
        tmp_dir=config.DstTmpDir,
    )
