from infra.rtc.iolimit_ticketer.cli import cli
import click
from infra.rtc.iolimit_ticketer.yp_update import MEGABYTE
from yp.client import YpClient, find_token
from yp.common import YP_PRODUCTION_CLUSTERS
from yt.yson import YsonEntity
import logging
from collections import defaultdict
import json
import time
from collections import Counter


# Deploy unit IDs that apply_net_guarantee_stage skips entirely.
# NOTE(review): this set is compared against deploy unit *IDs* (the keys of
# /spec/deploy_units), so it only excludes a unit literally named
# "MultiClusterReplicaSet". If the intent was to skip MCRS-type units, confirm.
WHITELIST_DEPLOY_UNITS = {"MultiClusterReplicaSet"}
# Stage attribute path of the network bandwidth guarantee for a deploy unit
# using the `replica_set` layout; {0} is the deploy unit ID.
NET_GUARANTEES_PATH_DEPLOY = "/spec/deploy_units/{0}/replica_set/replica_set_template/pod_template_spec/spec/resource_requests/network_bandwidth_guarantee"
# Same attribute for a deploy unit using the `multi_cluster_replica_set`
# layout; {0} is the deploy unit ID.
MCRS_NET_GUARANTEES_PATH_DEPLOY = "/spec/deploy_units/{0}/multi_cluster_replica_set/replica_set/pod_template_spec/spec/resource_requests/network_bandwidth_guarantee"


def get_all_stage_deploy_units(yp_client, stage):
    """Return the IDs of every deploy unit declared in the spec of ``stage``.

    Selects ``/spec/deploy_units`` of the stage whose ``/meta/id`` equals
    ``stage`` and returns the keys of that mapping as a list.
    """
    response = yp_client.select_objects(
        "stage",
        selectors=["/spec/deploy_units"],
        filter="[/meta/id]='{0}'".format(stage),
        enable_structured_response=True,
    )
    # First matching object, first (and only) selector.
    units_spec = response["results"][0][0]["value"]
    return list(units_spec.keys())


def get_stage_map(path="io_limits.json"):
    """Load per-stage network guarantees from an io_limits JSON dump.

    Only records whose ``deploy_engine`` is ``"MCRSC"`` or ``"RSC"`` and that
    carry a truthy ``net.guarantee`` are kept.

    Args:
        path: JSON file to read. Defaults to ``io_limits.json`` in the current
            working directory, which was previously the only location read.

    Returns:
        dict mapping each record's ``service_id`` (callers use it as the stage
        ID) to ``net.guarantee * MEGABYTE``.
    """
    # Context manager so the file handle is closed even if parsing fails.
    with open(path) as io_limits_file:
        io_limits = json.load(io_limits_file)

    stage_net_guarantee_map = {}
    for record in io_limits:
        if record["deploy_engine"] not in {"MCRSC", "RSC"}:
            continue
        # Treat a missing or null "net" section the same as "no guarantee".
        guarantee = (record.get("net") or {}).get("guarantee")
        if guarantee:
            stage_net_guarantee_map[record["service_id"]] = guarantee * MEGABYTE

    return stage_net_guarantee_map


def extract_stage_account_mapping(yp_client):
    """Group all stage IDs by the account they belong to.

    Despite the name, the result is keyed by account: it maps the stage's
    ``/spec/account_id`` (with every ``"abc:service:"`` occurrence removed) to
    the list of ``/meta/id`` values of stages using that account, in the order
    YP returned them.
    """
    rows = yp_client.select_objects(
        "stage", selectors=["/spec/account_id", "/meta/id"], enable_structured_response=True
    )["results"]

    stages_by_account = {}
    for row in rows:
        # One cell per selector, in selector order.
        account_cell, stage_cell = row[0], row[1]
        account = account_cell["value"].replace("abc:service:", "")
        stages_by_account.setdefault(account, []).append(stage_cell["value"])
    return stages_by_account


def get_pod_node_mapping(ctx):
    """Map every pod seen in ``ctx.obj.yp_stat`` to the node it runs on.

    Walks each service record and, for each cluster in YP_PRODUCTION_CLUSTERS
    that the record has data for, records ``pod_key -> descriptor.node_id``.

    NOTE(review): keys are the keys of ``cluster.pods``, while
    get_stage_pod_mapping collects ``descriptor.pod_id``. They are presumably
    the same value; verify, since validate_guarantee_aplicability looks pods
    up here by ``pod_id``.
    """
    node_by_pod = {}
    for service in ctx.obj.yp_stat.service_map.values():
        for cluster_name in YP_PRODUCTION_CLUSTERS:
            cluster_stat = service.clusters.get(cluster_name)
            if not cluster_stat:
                continue
            node_by_pod.update(
                (pod_key, descriptor.node_id) for pod_key, descriptor in cluster_stat.pods.items()
            )
    return node_by_pod


def get_stage_pod_mapping(ctx):
    """Collect pod IDs per stage and deploy unit from ``ctx.obj.yp_stat``.

    Returns:
        ``defaultdict(dict)`` of ``{stage_id: {deploy_unit_id: [pod_id, ...]}}``.
        Service records without a ``deploy_stage_id`` are skipped. Pods come
        from every cluster in YP_PRODUCTION_CLUSTERS that the record has data
        for. If two records share a stage and deploy unit, the later one wins.
    """
    pods_by_stage = defaultdict(dict)

    for service in ctx.obj.yp_stat.service_map.values():
        stage_id, unit_id = service.deploy_stage_id, service.deploy_unit_id
        if not stage_id:
            continue

        unit_pods = []
        for cluster_name in YP_PRODUCTION_CLUSTERS:
            cluster_stat = service.clusters.get(cluster_name)
            if cluster_stat:
                unit_pods.extend(descriptor.pod_id for descriptor in cluster_stat.pods.values())

        pods_by_stage[stage_id][unit_id] = unit_pods

    return pods_by_stage


def validate_guarantee_aplicability(stage, stage_pod_mapping, free_bandwidth_mapping, pod_node_mapping,
                                    net_guarantee_map, stage_applicability_ratio):
    """Decide whether the network guarantee for ``stage`` fits on its current nodes.

    For each deploy unit of the stage, the unit's pods are grouped by node. A
    node counts as "applicable" when ``pods_on_node * guarantee`` does not
    exceed the node's free bandwidth. A unit passes only if its share of
    applicable nodes is strictly greater than ``stage_applicability_ratio``.

    Args:
        stage: stage ID.
        stage_pod_mapping: ``{stage: {deploy_unit: [pod_id, ...]}}``. A stage
            missing from it has no pods and trivially passes. The mapping is
            not mutated, even if it is a defaultdict.
        free_bandwidth_mapping: ``{node_id: free bandwidth}``. Assumed to use
            the same unit as the guarantee values (TODO confirm). Nodes missing
            from it are treated as having no free bandwidth.
        pod_node_mapping: ``{pod_id: node_id}``. Pods missing from it are
            grouped under an unknown node with no free bandwidth.
        net_guarantee_map: ``{stage: guarantee}``. Stages without a truthy
            entry use ``10 * MEGABYTE``, the same fallback
            apply_net_guarantee_stage writes.
        stage_applicability_ratio: threshold in ``[0, 1]``.

    Returns:
        True if every deploy unit passes, otherwise False.
    """
    # Must match the value apply_net_guarantee_stage actually sets. The
    # previous default of 10 (not 10 * MEGABYTE) validated a far smaller guarantee.
    guarantee_to_apply = net_guarantee_map.get(stage) or 10 * MEGABYTE

    # .get() rather than [] so a defaultdict argument is not grown as a side effect.
    stage_pods = stage_pod_mapping.get(stage, {})

    for deploy_unit, pods in stage_pods.items():
        pods_per_node = Counter(pod_node_mapping.get(pod) for pod in pods)
        if not pods_per_node:
            continue

        applicable_nodes = sum(
            1 for node, pod_count in pods_per_node.items()
            if pod_count * guarantee_to_apply <= free_bandwidth_mapping.get(node, 0)
        )
        # float() keeps the ratio exact on Python 2, where int / int truncates.
        if float(applicable_nodes) / len(pods_per_node) <= stage_applicability_ratio:
            return False

    return True


def _update_stage_guarantee(yp_client, stage, path, value, timestamp, transaction_id):
    """Issue one ``set`` update of ``path`` on ``stage`` inside ``transaction_id``.

    The update carries an attribute timestamp prerequisite on ``path`` at
    ``timestamp`` (the timestamp of the read that preceded it). The
    transaction is not committed here.
    """
    yp_client.update_object(
        "stage", stage,
        set_updates=[{"path": path, "value": value}],
        transaction_id=transaction_id,
        attribute_timestamp_prerequisites=[{"path": path, "timestamp": timestamp}]
    )


def apply_net_guarantee_stage(yp_client, stage, mapping, overwrite, dry_run):
    """Set the network bandwidth guarantee on every deploy unit of ``stage``.

    Deploy units listed in WHITELIST_DEPLOY_UNITS are skipped. For each other
    unit, the guarantee is read under both the ``replica_set`` and the
    ``multi_cluster_replica_set`` layouts. If either is already set and
    ``overwrite`` is false, the unit is left alone.

    Otherwise the replica_set path is updated. If that update raises, a fresh
    transaction updates the multi_cluster_replica_set path instead. The
    resulting transaction is then committed.

    Args:
        yp_client: client for the YP cluster holding the stage.
        stage: stage ID.
        mapping: ``{stage: guarantee}``. Stages without a truthy entry get
            ``10 * MEGABYTE``.
        overwrite: replace an already-set guarantee when true.
        dry_run: only log what would be applied, without writing anything.
    """
    # Loop-invariant: the value does not depend on the deploy unit.
    guarantee_to_apply = mapping.get(stage) or 10 * MEGABYTE
    stage_filter = "[/meta/id]='{0}'".format(stage)

    for unit in get_all_stage_deploy_units(yp_client, stage):
        if unit in WHITELIST_DEPLOY_UNITS:
            continue

        rs_path = NET_GUARANTEES_PATH_DEPLOY.format(unit)
        mcrs_path = MCRS_NET_GUARANTEES_PATH_DEPLOY.format(unit)

        # Read both layouts. If only the replica_set path were checked, an MCRS
        # unit would always look unset, and the fallback below would overwrite
        # its guarantee even with overwrite=False.
        result = yp_client.select_objects(
            "stage",
            selectors=[rs_path, mcrs_path],
            filter=stage_filter,
            enable_structured_response=True
        )
        rs_guarantee, mcrs_guarantee = [cell["value"] for cell in result["results"][0]]
        transaction_timestamp = result["timestamp"]

        already_set = rs_guarantee != YsonEntity() or mcrs_guarantee != YsonEntity()
        if already_set and not overwrite:
            continue

        if dry_run:
            logging.info("Dry run: would apply guarantee {0} on stage {1}, deploy unit {2}".format(
                guarantee_to_apply, stage, unit))
            continue

        transaction_id = yp_client.start_transaction()
        try:
            _update_stage_guarantee(
                yp_client, stage, rs_path, guarantee_to_apply, transaction_timestamp, transaction_id)
        except Exception:
            # Not a bare except: KeyboardInterrupt / SystemExit must propagate.
            logging.warning(
                "Failed to set {0} on stage {1}, retrying with the multi_cluster_replica_set path".format(
                    rs_path, stage),
                exc_info=True)
            transaction_id = yp_client.start_transaction()
            _update_stage_guarantee(
                yp_client, stage, mcrs_path, guarantee_to_apply, transaction_timestamp, transaction_id)

        yp_client.commit_transaction(transaction_id=transaction_id)
        # guarantee_to_apply is already scaled by MEGABYTE; report both forms.
        logging.info("Applied guarantee {0} ({1:g} MB/s) on stage {2}, deploy unit {3}".format(
            guarantee_to_apply, guarantee_to_apply / float(MEGABYTE), stage, unit))


def apply_stage_change(stage, applicability_ratio, stage_pod_mapping, free_node_bandwidth_map, pod_node_mapping,
                       guarantee_map, xdc_client, overwrite, ctx):
    """Check ``stage`` with validate_guarantee_aplicability and apply the guarantee if it passes.

    ``ctx.obj.dry_run`` is forwarded to apply_net_guarantee_stage. When the
    check fails, only a log line is emitted.
    """
    logging.info("Going to apply guarantee for stage {0}".format(stage))

    fits_in_place = validate_guarantee_aplicability(
        stage, stage_pod_mapping, free_node_bandwidth_map, pod_node_mapping, guarantee_map, applicability_ratio
    )
    if fits_in_place is not True:
        logging.info("Can't apply garantee on stage {0}, reallocation rate is too high".format(stage))
        return

    apply_net_guarantee_stage(xdc_client, stage, guarantee_map, overwrite, ctx.obj.dry_run)


def _split_csv_option(raw):
    """Split a comma-separated CLI value into stripped, non-empty items."""
    # Filter after stripping so inputs like "a, " do not yield an empty name.
    return [str(item.strip()) for item in raw.split(",") if item.strip()]


def _select_target_stages(accounts, stages, stage_account_map):
    """Return the stage IDs to process, in processing order.

    ``accounts`` takes precedence over ``stages``. When both are empty, every
    stage in ``stage_account_map`` is returned. Accounts missing from
    ``stage_account_map`` are logged and skipped instead of aborting the run.
    """
    if accounts:
        targets = []
        for account in accounts:
            account_stages = stage_account_map.get(account)
            if account_stages is None:
                logging.warning("Account {0} not found in stage/account mapping, skipping".format(account))
                continue
            targets.extend(account_stages)
        return targets
    if stages:
        return list(stages)
    return [stage for account_stages in stage_account_map.values() for stage in account_stages]


@cli.command('apply_net_deploy')
@click.option('--overwrite/--no-overwrite', default=False)
@click.option('--sleep-time', type=int, default=10)
@click.option('--accounts', default="")
@click.option('--stages', default="")
@click.pass_context
def apply_net_deploy(ctx, overwrite, sleep_time, accounts, stages):
    """Apply network bandwidth guarantees to deploy stages on the xdc YP cluster.

    Stages are chosen from --accounts (all stages of the listed accounts),
    else --stages, else every stage found on xdc. Each stage is checked for
    in-place applicability before it is updated. The command sleeps
    --sleep-time seconds after each stage.
    """
    xdc_client = YpClient("xdc", config=dict(token=find_token()))
    guarantee_map, stage_account_map = get_stage_map(), extract_stage_account_mapping(xdc_client)
    account_list = _split_csv_option(accounts)
    stage_list = _split_csv_option(stages)
    free_node_bandwidth_map = {k: v.free_network_bandwidth for k, v in ctx.obj.yp_stat.node_map.items()}
    pod_node_mapping = get_pod_node_mapping(ctx)
    stage_pod_mapping = get_stage_pod_mapping(ctx)

    # A deploy unit passes validation only if more than this share of its nodes
    # can take the guarantee without rescheduling pods.
    stage_applicability_ratio = 0.85

    for stage in _select_target_stages(account_list, stage_list, stage_account_map):
        apply_stage_change(
            stage, stage_applicability_ratio, stage_pod_mapping, free_node_bandwidth_map, pod_node_mapping,
            guarantee_map, xdc_client, overwrite, ctx
        )
        time.sleep(sleep_time)
