import collections
import json

from crypta.lib.python.yt import yt_helpers
from crypta.spine.pushers.yt_replicated_table_checker.lib import common


def serialize_replica(cluster, path):
    """Render a replica as the string "<cluster>.[<path>]"."""
    template = "{}.[{}]"
    return template.format(cluster, path)


def check_table_type(client, path):
    """Check that the node at `path` exists and looks like a replicated dynamic table.

    Returns a one-element list with a "does not exist" message when
    `client.exists(path)` is falsy; otherwise returns whatever
    `common.check_attrs` reports for the required attributes
    (presumably a list of error strings -- confirm against `common`).
    """
    if client.exists(path):
        # Attributes the node is required to carry.
        required_attrs = [
            {"type": "replicated_table"},
            {"dynamic": True},
        ]
        return common.check_attrs(client, path, required_attrs)

    return ["{} does not exist".format(path)]


def check_replicas(present_replicas, expected_replicas, replication_lag_threshold):
    """Compare the replicas actually present with the expected ones.

    Args:
        present_replicas: mapping of (cluster, path) -> replica status mapping.
            For replicas that are also expected, the keys "replication_lag_time",
            "state" and "errors" are read, plus every expected attribute name.
        expected_replicas: iterable of objects exposing `name`, `path` and
            `expected_replication_attributes` (a mapping of attribute -> value).
        replication_lag_threshold: object with `total_seconds()` (e.g. a
            `datetime.timedelta`). It is converted to milliseconds before being
            compared with "replication_lag_time" (so that value is presumably
            in milliseconds -- confirm against the YT attribute).

    Returns:
        A list of strings, one per replica that has at least one problem, in
        the form "<serialized replica>: <JSON list of problems>". Replicas
        without problems are omitted.
    """
    errors = collections.defaultdict(list)

    present_replica_keys = set(present_replicas.keys())
    expected_replica_attrs = {
        (replica.name, replica.path): replica.expected_replication_attributes
        for replica in expected_replicas
    }
    expected_replica_keys = set(expected_replica_attrs.keys())

    for key in present_replica_keys - expected_replica_keys:
        errors[key].append("extra replica")

    for key in expected_replica_keys - present_replica_keys:
        errors[key].append("missing replica")

    for key in expected_replica_keys & present_replica_keys:
        replica = present_replicas[key]
        # Touching errors[key] creates an (initially empty) entry; entries that
        # stay empty are filtered out when the result is built.
        replica_errors = errors[key]

        if replica["replication_lag_time"] > replication_lag_threshold.total_seconds() * 1000:
            replica_errors.append("replication lag time: {} > {}".format(replica["replication_lag_time"], replication_lag_threshold))

        if replica["state"] != "enabled":
            replica_errors.append("disabled")

        # items() instead of the Python-2-only iteritems(): same pairs, but it
        # also works on Python 3.
        for attr, value in expected_replica_attrs[key].items():
            actual_value = replica.get(attr)
            if value != actual_value:
                replica_errors.append("mismatched replication attribute '{}': expected '{}', actual '{}'".format(attr, repr(value), repr(actual_value)))

        replica_errors.extend(replica["errors"])

    return ["{}: {}".format(serialize_replica(*k), json.dumps(v)) for k, v in errors.items() if v]


def check_replica_sync_count(replicas, sync_count):
    """Verify that exactly `sync_count` replicas have mode "sync".

    `replicas` maps (cluster, path) keys to replica status mappings that
    carry a "mode" entry. Returns an empty list when the count matches,
    otherwise a single message listing the serialized sync replicas as JSON.
    """
    sync_replicas = []
    for key in replicas:
        if replicas[key]["mode"] == "sync":
            sync_replicas.append(serialize_replica(*key))

    if len(sync_replicas) == sync_count:
        return []

    return ["sync replica count is not {}: {}".format(sync_count, json.dumps(sync_replicas))]


def get_replication_errors_count(replica_id, master_attrs):
    """Return the number of replication errors recorded for `replica_id`.

    When `master_attrs` has a "replication_errors" key, the count is the
    length of its entry for the replica; otherwise it is the replica's
    "error_count" value under "replicas".
    """
    if "replication_errors" not in master_attrs:
        return master_attrs["replicas"][replica_id]["error_count"]

    # TODO(CRYPTAYT-3579) Remove "replication_errors" branch after Markov update to 19.8
    replica_errors = master_attrs["replication_errors"][replica_id]
    return len(replica_errors)


def get_replica_statuses(client, master_path):
    """Collect the status of every replica listed on the table at `master_path`.

    Reads the "replicas" and "replication_errors" attributes of the node via
    `yt_helpers.get_attributes`.

    Returns:
        A dict keyed by (cluster_name, replica_path). Each value is a copy of
        the replica's attributes extended with:
          * "id": the replica id;
          * "errors": a list that is empty, or holds one
            "replication errors count: N" message when N > 0.
    """
    master_attrs = yt_helpers.get_attributes(master_path, ["replicas", "replication_errors"], client)
    result = {}

    # items() instead of the Python-2-only iteritems(): same pairs, but it
    # also works on Python 3.
    for replica_id, replica_attrs in master_attrs["replicas"].items():
        error_count = get_replication_errors_count(replica_id, master_attrs)
        errors = ["replication errors count: {}".format(error_count)] if error_count > 0 else []
        replica_key = (replica_attrs["cluster_name"], replica_attrs["replica_path"])
        result[replica_key] = dict(replica_attrs, id=replica_id, errors=errors)

    return result


def get_errors(replicated_table):
    """Run every check for `replicated_table` and return the found problems.

    The checks run against `replicated_table.master` in this order: table
    type, master attributes (with `tablet_state` required to be "mounted"),
    per-replica checks, and the sync replica count. If the table type check
    reports anything, its result is returned immediately and the remaining
    checks are skipped.

    Any exception raised along the way is not propagated: it is reported as
    a single-element list holding `repr()` of the exception.
    """
    try:
        master = replicated_table.master
        client = yt_helpers.get_yt_client(master.proxy)
        table_path = master.path

        problems = check_table_type(client, table_path)
        if problems:
            return problems

        expected_master_attrs = dict(master.expected_attributes, tablet_state="mounted")
        problems += common.check_attrs(client, table_path, [expected_master_attrs])

        replica_statuses = get_replica_statuses(client, table_path)
        problems += check_replicas(replica_statuses, replicated_table.replicas, master.replication_lag_threshold)
        problems += check_replica_sync_count(replica_statuses, master.sync_count)

        return problems
    except Exception as exc:
        return [repr(exc)]
