from yp.client import YpClient, find_token
from yp.common import YP_PRODUCTION_CLUSTERS
import yt.wrapper
import nirvana.job_context as nv
import library.python.resource as rs

import datetime
import logging
import pandas as pd
import os
import json

logging.basicConfig(level=logging.INFO)

# Clusters that main() iterates over.
CLUSTERS = YP_PRODUCTION_CLUSTERS

# Detection thresholds. These zeros are placeholders: the __main__ section
# replaces them with floats read from environment variables of the same names
# before main() runs.
#   BN_REQ    -- per-pod guarantee (multiplied by 1024 * 1024 before comparing)
#                at or above which a pod counts for the "bad neighbours" check.
#   BN_TOT_PC -- multiplier of node net capacity that total allocation must
#                exceed for a node to be a "bad neighbours" node.
#   HP_PC     -- multiplier of node net capacity that the largest pod must
#                reach for a node to be a "heavy pods" node.
#   HP_TOT_PC -- multiplier of node net capacity that total allocation must
#                exceed for a node to be a "heavy pods" node.
BN_REQ = BN_TOT_PC = HP_PC = HP_TOT_PC = 0

    
def get_total_bandwith_mapping(client):
    """Return ``{node_id: total_bandwidth}`` for every network resource.

    Selects all YP ``resource`` objects whose kind is ``network`` and reads
    ``network.total_bandwidth`` from each resource's spec.

    :param client: YP client exposing ``select_objects``.
    :return: dict keyed by node id; if several network resources report the
        same node id, the one returned last wins.
    """
    response = client.select_objects(
        "resource",
        selectors=["/meta/node_id", "/spec"],
        filter="[/meta/kind] = 'network'",
        enable_structured_response=True,
    )
    return {
        row[0]["value"]: row[1]["value"]["network"]["total_bandwidth"]
        for row in response["results"]
    }


def get_pod_mapping(client, guars):
    """Collect network guarantees for pods of the supported deploy engines.

    Only pods whose ``/labels/deploy_engine`` is ``YP_LITE``, ``MCRSC`` or
    ``RSC`` are kept. The owning service is ``/labels/nanny_service_id`` for
    ``YP_LITE`` pods, and the part of ``/meta/pod_set_id`` before the first
    ``.`` for ``MCRSC``/``RSC`` pods.

    :param client: YP client exposing ``select_objects``.
    :param guars: mapping ``{service: guarantee}``. The guarantee is parsed
        with ``float()`` and multiplied by 1024 * 1024 (presumably MB -> bytes;
        TODO confirm units). Services missing from the mapping get
        10 * 1024 * 1024.
    :return: tuple ``(pod_mapping, node_pod_mapping, node_service_mapping)``:
        ``pod_mapping`` is ``{pod_id: guarantee}``; ``node_pod_mapping`` is
        ``{node_id: [pod_id, ...]}``; ``node_service_mapping`` is
        ``{node_id: [service, ...]}``. The latter two are parallel: the i-th
        service of a node belongs to the i-th pod of that node.
    """
    default_guarantee_mb = 10
    bytes_per_mb = 1024 * 1024

    pod_mapping = {}
    node_pod_mapping = {}
    node_service_mapping = {}

    rows = client.select_objects(
        "pod",
        selectors=["/meta/id", "/meta/pod_set_id", "/spec/node_id", "/labels/deploy_engine", "/labels/nanny_service_id"],
        enable_structured_response=True
    )["results"]

    for row in rows:
        pod_id, pod_set_id, node_id, deploy_engine, nanny_service_id = [row[i]["value"] for i in range(5)]

        if deploy_engine == "YP_LITE":
            service = nanny_service_id
        elif deploy_engine in ("MCRSC", "RSC"):
            service = pod_set_id.split(".")[0]
        else:
            # Pods of other deploy engines are not tracked at all.
            continue

        guarantee = default_guarantee_mb * bytes_per_mb
        if service in guars:
            # Scale first, then truncate, so fractional guarantees such as
            # "0.5" or "1.5" are not rounded down to whole MB (or to zero).
            guarantee = int(float(guars[service]) * bytes_per_mb)

        pod_mapping[pod_id] = guarantee
        # Plain dicts (not defaultdict): callers rely on KeyError / membership
        # tests for nodes without tracked pods.
        node_pod_mapping.setdefault(node_id, []).append(pod_id)
        node_service_mapping.setdefault(node_id, []).append(service)

    return pod_mapping, node_pod_mapping, node_service_mapping


# Columns of the per-node statistics frame; passed explicitly so that the
# filters in main() still work when a cluster yields no rows.
_NODE_COLUMNS = [
    "id",
    "net_capacity",
    "net_allocated",
    "max_pod_net",
    "max_pod_service",
    "req_pods",
    "req_pods_len",
]


def _summarize_node(node_id, node_capacity, node_pods, node_services, pod_mapping):
    """Build one statistics row for a node from the guarantees of its pods.

    ``node_pods`` and ``node_services`` are parallel lists (i-th service owns
    the i-th pod); ``pod_mapping`` maps every pod id to its guarantee.
    """
    req_threshold = BN_REQ * 1024 * 1024
    allocated_capacity = 0
    max_pod_net = 0
    max_pod_service = ""
    req_pods = []

    for pod, service in zip(node_pods, node_services):
        pod_net = int(pod_mapping[pod])
        allocated_capacity += pod_net
        # Strict comparison: on ties the first pod seen keeps the maximum.
        if pod_net > max_pod_net:
            max_pod_net = pod_net
            max_pod_service = service
        if pod_net >= req_threshold:
            req_pods.append(service)

    return {
        "id": node_id,
        "net_capacity": node_capacity,
        "net_allocated": allocated_capacity,
        "max_pod_net": max_pod_net,
        "max_pod_service": max_pod_service,
        "req_pods": req_pods,
        "req_pods_len": len(req_pods),
    }


def _collect_node_stats(client, guars):
    """Return statistics rows for the ``default``-segment nodes of one cluster.

    Nodes that host no tracked pod, or that have no network resource, are
    skipped.
    """
    total_node_mapping = get_total_bandwith_mapping(client)
    pod_mapping, node_pod_mapping, node_service_mapping = get_pod_mapping(client, guars)

    rows = []
    for node in client.select_objects(
        "node",
        selectors=["/meta/id"],
        filter='[/labels/segment] = "default"',
        enable_structured_response=True
    )["results"]:
        node_id = node[0]["value"]
        if node_id not in node_pod_mapping or node_id not in total_node_mapping:
            continue
        rows.append(_summarize_node(
            node_id,
            total_node_mapping[node_id],
            node_pod_mapping[node_id],
            node_service_mapping[node_id],
            pod_mapping,
        ))
    return rows


def main(guars):
    """Find "heavy pods" and "bad neighbours" nodes in every cluster of CLUSTERS.

    A node is a heavy-pods node when its largest pod guarantee is at least
    ``HP_PC * net_capacity`` and the total allocation exceeds
    ``HP_TOT_PC * net_capacity``. It is a bad-neighbours node when at least
    two pods have a guarantee of at least ``BN_REQ * 1024 * 1024`` and the
    total allocation exceeds ``BN_TOT_PC * net_capacity``.

    :param guars: mapping ``{service: guarantee}`` forwarded to get_pod_mapping.
    :return: tuple of four lists:
        per-cluster heavy-pods counts, per-cluster bad-neighbours counts,
        one heavy-pods record per offending node (with its largest service),
        and one bad-neighbours record per (offending node, service) pair.
    """
    # Computed once so every record of a run carries the same date.
    today = datetime.date.today().strftime("%Y-%m-%d")

    result_heavypods = []
    result_badneighbours = []
    dump_heavypods = []
    dump_badneighbours = []

    for cluster in CLUSTERS:
        logging.info("Get accounts from cluster %s", cluster)
        client = YpClient(cluster, config=dict(token=find_token()))

        frame = pd.DataFrame(_collect_node_stats(client, guars), columns=_NODE_COLUMNS)
        capacity = frame["net_capacity"]
        heavy_pods_frame = frame[
            (frame["max_pod_net"] >= HP_PC * capacity)
            & (frame["net_allocated"] > HP_TOT_PC * capacity)
        ]
        bad_neighbours_frame = frame[
            (frame["req_pods_len"] >= 2)
            & (frame["net_allocated"] > BN_TOT_PC * capacity)
        ]

        result_heavypods.append({"cluster": cluster, "bad_nodes": len(heavy_pods_frame), "date": today})
        result_badneighbours.append({"cluster": cluster, "bad_nodes": len(bad_neighbours_frame), "date": today})

        dump_heavypods.extend(
            {"cluster": cluster, "node": row["id"], "service": row["max_pod_service"], "date": today}
            for row in heavy_pods_frame.to_dict('records')
        )
        dump_badneighbours.extend(
            {"cluster": cluster, "node": row["id"], "service": srv, "date": today}
            for row in bad_neighbours_frame.to_dict('records')
            for srv in row["req_pods"]
        )

    return result_heavypods, result_badneighbours, dump_heavypods, dump_badneighbours


if __name__ == '__main__':
    job_context = nv.context()
    outputs = job_context.get_outputs()

    # The thresholds are module globals read by main(); fail with a clear
    # message rather than float(None) when one of them is not configured.
    thresholds = {}
    for env_name in ("BN_REQ", "BN_TOT_PC", "HP_PC", "HP_TOT_PC"):
        raw_value = os.environ.get(env_name)
        if raw_value is None:
            raise RuntimeError("Required environment variable %s is not set" % env_name)
        thresholds[env_name] = float(raw_value)
    BN_REQ = thresholds["BN_REQ"]
    BN_TOT_PC = thresholds["BN_TOT_PC"]
    HP_PC = thresholds["HP_PC"]
    HP_TOT_PC = thresholds["HP_TOT_PC"]

    io_limits = json.loads(rs.find("io_limits.json").decode("utf-8"))
    # `or {}` also covers an explicit "net": null entry in the resource file.
    guars = {entry["service_id"]: (entry.get("net") or {}).get("guarantee", 0) for entry in io_limits}

    heavypods, badneighbours, dump_heavypods, dump_badneighbours = main(guars)

    for output_name, payload in (
        ("heavy_pods", heavypods),
        ("bad_neighbours", badneighbours),
        ("dump_heavy_pods", dump_heavypods),
        ("dump_bad_neighbours", dump_badneighbours),
    ):
        with open(outputs.get(output_name), "w", encoding="utf-8") as output_file:
            json.dump(payload, output_file)