import collections
import logging
import retry
import six

import yt.wrapper as yt

from crypta.lib.python.yt import (
    tm_utils,
    yt_helpers,
)
from crypta.lib.python.yt.dyntable_utils import (
    dump,
    replica_utils,
)

logger = logging.getLogger(__name__)


class YtOpsContext(object):
    """Context manager that tracks reversible changes made during a YT operation.

    Tables frozen through :meth:`freeze_table` and objects registered through
    :meth:`add_temp_replica` are remembered.  On exit every remembered table is
    unfrozen and every remembered object is removed.  Anything that could not
    be undone is logged and reported with an ``AssertionError``.
    """

    # (client, path) of a table that was frozen through this context.
    FrozenTable = collections.namedtuple("FrozenTable", ["client", "path"])
    # (client, id) of an object that is removed on exit unless it was forgotten.
    TempObject = collections.namedtuple("TempObject", ["client", "id"])

    def __init__(self):
        self.frozen_tables = set()
        self.temp_objects = set()

    def __enter__(self):
        """Reset the tracked state and return the context itself."""
        self.frozen_tables.clear()
        self.temp_objects.clear()
        return self

    def __exit__(self, *args):
        """Undo all tracked changes.

        Never suppresses an exception raised inside the ``with`` body.

        Raises:
            AssertionError: if some frozen table or temporary object could not
                be cleaned up; the leftovers are logged beforehand.
        """
        logger.info("Cleaning up")

        self._unfreeze_frozen_tables()
        self._remove_temp_objects()

        if not self.frozen_tables and not self.temp_objects:
            return

        logger.error("There were errors while cleaning up. Make sure to MANUALLY UNDO the following changes:")

        for frozen_table in self.frozen_tables:
            logger.error("  Frozen table: %s at %s", frozen_table.path, yt_helpers.get_cluster_name(frozen_table.client))

        for temp_object in self.temp_objects:
            logger.error("  Temporary object: %s at %s", temp_object.id, yt_helpers.get_cluster_name(temp_object.client))

        # Raised explicitly instead of using ``assert`` so that the check is
        # not stripped when Python runs with -O.
        raise AssertionError("There must be neither frozen tables nor temp objects left")

    def freeze_table(self, client, path):
        """Freeze ``path`` via ``client`` and remember it for unfreezing on exit."""
        logger.info("Freezing %s at %s", path, yt_helpers.get_cluster_name(client))

        client.freeze_table(path, sync=True)
        # Register the table as soon as the freeze request has gone through:
        # if the confirmation wait below fails, the table must still be
        # unfrozen during cleanup instead of being silently left frozen.
        self.frozen_tables.add(self.FrozenTable(client, path))

        yt_helpers.wait_for_frozen(client, path, timeout=60)

    def add_temp_replica(self, client, replica_id):
        """Register a replica object for removal on exit."""
        self.temp_objects.add(self.TempObject(client, replica_utils.format_replica_id(replica_id)))

    def forget_temp_replica(self, client, replica_id):
        """Stop tracking a replica object so that it survives the cleanup.

        Raises:
            KeyError: if the replica was not registered with ``add_temp_replica``.
        """
        self.temp_objects.remove(self.TempObject(client, replica_utils.format_replica_id(replica_id)))

    def _unfreeze_table(self, client, path):
        """Unfreeze ``path`` and drop it from the tracked set once it is mounted."""
        client.unfreeze_table(path, sync=True)
        yt_helpers.wait_for_mounted(client, path, timeout=60)

        self.frozen_tables.remove(self.FrozenTable(client, path))

    def _unfreeze_frozen_tables(self):
        """Try to unfreeze every tracked table; failures are logged, not raised."""
        # Iterate over a copy: successful unfreezing mutates the set.
        for client, path in list(self.frozen_tables):
            cluster = yt_helpers.get_cluster_name(client)
            logger.info("Unfreezing %s at %s", path, cluster)

            try:
                self._unfreeze_table(client, path)
            except Exception as e:
                # Best effort: keep going so the remaining tables get a chance.
                logger.error("Error while unfreezing table %s at %s: %s", path, cluster, e)

    def _remove_temp_objects(self):
        """Try to remove every tracked temporary object; failures are logged, not raised."""
        # Iterate over a copy: successful removal mutates the set.
        for temp_object in list(self.temp_objects):
            client, object_id = temp_object
            cluster = yt_helpers.get_cluster_name(client)
            logger.info("Removing %s at %s", object_id, cluster)

            try:
                client.remove(object_id)
                self.temp_objects.remove(temp_object)
            except Exception as e:
                # Best effort: the object stays tracked and is reported by __exit__.
                logger.error("Error while removing temporary object %s at %s: %s", object_id, cluster, e)


def _copy_attributes(src_client, src_table, dst_client, dst_table):
    """Copy table attributes from ``src_table`` to ``dst_table``.

    A fixed list of built-in attributes is copied together with every user
    attribute of the source table except ``forced_compaction_revision``.
    Each attribute is set with up to 3 attempts, 1 second apart, retrying on
    ``yt.YtError``.
    """
    attribute_names = [
        "account",
        "optimize_for",
        "in_memory_mode",
        "atomicity",
        "tablet_cell_bundle",
        "min_tablet_size",
        "max_tablet_size",
        "desired_tablet_size",
        "desired_tablet_count",
        "enable_dynamic_store_read",
        "enable_tablet_balancer",
    ]

    user_keys = yt_helpers.get_attribute(src_table, "user_attribute_keys", src_client)
    # forced_compaction_revision is deliberately not carried over to the destination.
    attribute_names.extend(key for key in user_keys if key != "forced_compaction_revision")

    attribute_values = yt_helpers.get_attributes(src_table, attribute_names, src_client)

    for attr_name, attr_value in six.iteritems(attribute_values):
        retry.retry_call(
            yt_helpers.set_attribute,
            fargs=[dst_table, attr_name, attr_value, dst_client],
            tries=3,
            delay=1,
            exceptions=yt.YtError,
        )


def _wait_for_replication_to_finish(master_client, master_path, replica_id):
    """Block until the given replica has received all flushed rows.

    The ``tablets`` attribute of the replica object is polled through
    ``master_client``.  The replica counts as caught up when, for every tablet,
    ``flushed_row_count`` equals ``current_replication_row_index``.  Polling is
    repeated up to 5 * 60 times with a 1 second delay.

    Raises:
        yt.YtError: if the replica has not caught up after the last attempt,
            or if reading the attribute keeps failing.
    """
    master_cluster = yt_helpers.get_cluster_name(master_client)

    logger.info("Waiting for all data to be replicated for replica #%s of %s on %s", replica_id, master_path, master_cluster)

    def ensure_replicated():
        tablets = yt_helpers.get_attribute(replica_utils.format_replica_id(replica_id), "tablets", master_client)
        lagging_tablets = sum(
            1 for tablet in tablets
            if tablet["flushed_row_count"] != tablet["current_replication_row_index"]
        )
        if lagging_tablets:
            # retry_call retries only on exceptions, so "not ready yet" has to
            # be signalled by raising; a returned False would end the wait
            # immediately.
            raise yt.YtError(
                "Replica #{} of {} on {} has not caught up yet: {} tablet(s) are still lagging".format(
                    replica_id, master_path, master_cluster, lagging_tablets
                )
            )

    retry.retry_call(ensure_replicated, tries=5 * 60, delay=1, exceptions=yt.YtError)


def _tune_new_replica(src_client, src_replica_path, dst_client, dst_replica_path, replica_attrs):
    """Prepare a freshly copied static table to act as a replica.

    Steps, in order: copy attributes from the source replica, rewrite the
    table with an ordered merge using fixed block/chunk sizes, verify that
    ``optimize_for`` matches the source, make the table dynamic, reshard it
    with the source pivot keys and finally apply ``replica_attrs``.

    Args:
        replica_attrs: mapping of attribute name to value set on the new table.

    Raises:
        AssertionError: if ``optimize_for`` differs between source and
            destination after the merge.
    """
    logger.info("Setting attributes for new replica %s at %s", dst_replica_path, yt_helpers.get_cluster_name(dst_client))
    _copy_attributes(src_client, src_replica_path, dst_client, dst_replica_path)

    logger.info("Running merge to fix chunk and block sizes")
    merge_spec = {
        "force_transform": True,
        "job_io": {"table_writer": {"block_size": 256 * 2**10, "desired_chunk_size": 100 * 2**20}},
    }
    dst_client.run_merge(dst_replica_path, dst_replica_path, mode="ordered", spec=merge_spec)

    # Check because there were some problems with this attribute.
    # Done with an explicit raise so the check survives ``python -O`` and the
    # failure reports both values.
    dst_optimize_for = yt_helpers.get_optimize_for(dst_client, dst_replica_path)
    src_optimize_for = yt_helpers.get_optimize_for(src_client, src_replica_path)
    if not (dst_optimize_for == src_optimize_for):
        raise AssertionError(
            "optimize_for mismatch: {!r} for {} vs {!r} for {}".format(
                dst_optimize_for, dst_replica_path, src_optimize_for, src_replica_path
            )
        )

    logger.info("Table copied, altering and resharding table")
    dst_client.alter_table(dst_replica_path, dynamic=True)

    pivot_keys = yt_helpers.get_attribute(src_replica_path, "pivot_keys", src_client)
    dst_client.reshard_table(dst_replica_path, pivot_keys=pivot_keys, sync=True)

    for key, value in six.iteritems(replica_attrs):
        dst_client.set_attribute(dst_replica_path, key, value)


def _connect_replica(master_client, replica_client, replica_path, replica_id, replication_attrs):
    """Attach the prepared replica table to its replica object and enable it.

    The table at ``replica_path`` gets ``replica_id`` as its upstream replica
    and is mounted, then the replica is enabled through ``master_client`` and
    ``replication_attrs`` are set on the ``#<replica_id>`` object.
    """
    current_attributes = yt_helpers.get_attribute(replica_path, "", replica_client)
    logger.info("Dst attributes: %s", current_attributes)
    replica_client.alter_table(replica_path, upstream_replica_id=replica_id)

    replica_cluster = yt_helpers.get_cluster_name(replica_client)
    logger.info("Mounting new replica table %s at %s", replica_path, replica_cluster)
    replica_client.mount_table(replica_path, sync=True)

    logger.info("Enabling new replica, replica_id = %s", replica_id)
    master_client.alter_table_replica(replica_id, True)

    # Replication attributes live on the replica object, addressed by id.
    for attr_name, attr_value in six.iteritems(replication_attrs):
        replica_object_path = "#{}".format(replica_id)
        master_client.set_attribute(replica_object_path, attr_name, attr_value)


def _create_replica_object(master_client, master_path, reference_replica_id, replica_client, replica_path, cleanup_context):
    """Create a ``table_replica`` object for ``replica_path`` and return its id.

    The new replica starts from the per-tablet ``current_replication_row_index``
    values read from the reference replica.  The created object is registered
    in ``cleanup_context`` as a temporary replica.
    """
    reference_object = replica_utils.format_replica_id(reference_replica_id)
    reference_tablets = yt_helpers.get_attribute(reference_object, "tablets", master_client)

    start_row_indexes = []
    for tablet in reference_tablets:
        start_row_indexes.append(tablet["current_replication_row_index"])
    logger.debug("Got replication row indexes: %s", start_row_indexes)

    creation_attributes = {
        "table_path": master_path,
        "cluster_name": yt_helpers.get_cluster_name(replica_client),
        "replica_path": replica_path,
        "start_replication_row_indexes": start_row_indexes,
    }
    new_replica_id = master_client.create("table_replica", attributes=creation_attributes)

    # Track the object right away so that it is removed if a later step fails.
    cleanup_context.add_temp_replica(master_client, new_replica_id)
    logger.info("Created new replica, id = %s", new_replica_id)
    return new_replica_id


def add_table_replica(
    master_client,
    master_table,
    src_client,
    src_table,
    dst_client,
    dst_table,
    temp_prefix,
    yt_pool,
    force,
    tm_client,
    replica_attrs,
    replication_attrs,
):
    """Add a new replica of ``master_table`` at ``dst_table`` from an existing replica.

    The master table and the source replica are frozen for the duration of
    the operation.  The source replica is dumped into a temporary table, moved
    or transferred to the destination, tuned, and connected to a newly created
    replica object.  Frozen tables are unfrozen on exit; the replica object is
    removed on exit if any step after its creation fails.
    """
    with YtOpsContext() as ops:
        with src_client.TempTable(prefix=temp_prefix) as snapshot_path:
            # Stop writes on the master and let the source replica catch up.
            ops.freeze_table(master_client, master_table)
            src_cluster = yt_helpers.get_cluster_name(src_client)
            src_replica_id = replica_utils.get_replica_id_by_path(master_client, master_table, src_cluster, src_table)
            _wait_for_replication_to_finish(master_client, master_table, src_replica_id)

            ops.freeze_table(src_client, src_table)

            # Copy the source replica data to the destination and prepare it.
            dump.dump_table(src_client, src_table, snapshot_path, yt_pool)
            tm_utils.move_or_transfer_table(src_client, snapshot_path, dst_client, dst_table, force, tm_client)
            _tune_new_replica(src_client, src_table, dst_client, dst_table, replica_attrs)

            # Register the destination table as a replica of the master table.
            new_replica_id = _create_replica_object(master_client, master_table, src_replica_id, dst_client, dst_table, ops)
            _connect_replica(master_client, dst_client, dst_table, new_replica_id, replication_attrs)

            # Everything succeeded: the replica object must outlive the cleanup.
            ops.forget_temp_replica(master_client, new_replica_id)

            logger.info("Done")
