
import os
import json
import sys
import logging
from yp.client import YpClient, find_token
from yp.common import YP_PRODUCTION_CLUSTERS
import requests
import argparse
import nirvana.job_context as nv
import infra.analytics.io_limits_pipeline.utils as utils

# Number of bytes in one megabit (10**6 bits, 8 bits per byte).
BYTES_TO_MBITS = 10 ** 6 / 8
# Number of bytes in one gigabit (10**9 bits, 8 bits per byte).
BYTES_TO_GBITS = 10 ** 9 / 8


# Send records that reach the root logger to stdout.
# NOTE(review): the root logger level is not set here; unless something else
# configures logging, the default WARNING level would drop the logging.info
# progress messages emitted in this file -- confirm.
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))



def is_normal(sw):
    """Tell whether a switch name has the expected "<prefix>-<id>" shape.

    The name is split on "-" and its second component must consist only of
    digits and the letter "s" (an empty second component also passes).

    Args:
        sw: Switch name; normally a string.

    Returns:
        True when the second dash-separated component matches, False when it
        does not, when the name contains no "-", or when ``sw`` is not a
        string.
    """
    try:
        return set(sw.split("-")[1]) <= set("s0987654321")
    except (IndexError, AttributeError, TypeError):
        # IndexError: no "-" in the name; AttributeError/TypeError: not a str.
        return False


def get_switch_uplinks():
    """Return the summed uplink port speed of every known switch.

    Downloads the per-switch port list (ports.json) and the switch/port
    listing (switchports.txt) from RackTables. A switch is kept when its
    short name (the part of the key before the first ".") appears in
    switchports.txt and passes ``is_normal``.

    Returns:
        dict mapping short switch name to the sum of ``int(port["speed"])``
        over ports flagged ``is_uplink``. Uplink ports with an empty speed
        contribute nothing.
    """
    raw_data = requests.get("https://ro.racktables.yandex-team.ru/export/ports.json").json()
    switchports_text = requests.get("https://ro.racktables.yandex-team.ru/export/switchports.txt").text

    # Every third tab/newline-separated token is taken as a switch name
    # (presumably the file has three columns per line -- TODO confirm).
    tokens = switchports_text.replace("\t", "\n").split("\n")
    # A set gives O(1) membership tests in the loop below.
    known_switches = {name for name in tokens[::3] if name != "" and is_normal(name)}

    res = {}
    for switch, ports in raw_data.items():
        short_name = switch.split(".")[0]
        if short_name not in known_switches:
            continue
        # Only count uplink ports that report a speed. Previously a port
        # without a speed reused the speed of the last port that had one (or
        # raised UnboundLocalError if none had been seen yet).
        res[short_name] = sum(
            int(port["speed"])
            for port in ports
            if port["is_uplink"] and port["speed"]
        )

    return res


def parse_overcommited(filename):
    """Parse a switch report file and return the switches whose ratio exceeds 1.

    The file is read line by line. A line containing "NAME" sets the current
    switch (the text left after removing "NAME: " and the newline). Any other
    non-blank line must look like "<int> <= <int>"; the left number is read
    as the downlink and the right one as the uplink.

    Args:
        filename: Path of the report file.

    Returns:
        dict mapping switch name to ``uplink / downlink`` for every data line
        where that ratio is greater than 1.0.

    Raises:
        ValueError: On a malformed number, or when a flagged data line
            appears before any "NAME" line.
        IndexError: When a data line has no "<=" separator.
        ZeroDivisionError: When the left-hand number is zero.
    """
    bad_switches = {}
    current_switch = None

    # The context manager closes the file even if parsing fails.
    with open(filename) as overcommited:
        for line in overcommited:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue

            if "NAME" in line:
                current_switch = line.replace("NAME: ", "").replace("\n", "")
                continue

            parts = line.split("<=")
            switch_downlink = int(parts[0].replace(" ", ""))
            switch_uplink = int(parts[1].replace(" ", "").replace("\n", ""))

            ratio = switch_uplink / switch_downlink
            if ratio > 1.0:
                if current_switch is None:
                    raise ValueError(
                        "data line precedes any NAME line in {0}".format(filename))
                bad_switches[current_switch] = ratio

    return bad_switches


def get_default_nodes(yp_client):
    """List the ids of all nodes whose ``segment`` label equals "default".

    Args:
        yp_client: Client exposing ``select_objects``.

    Returns:
        list with the ``/meta/id`` value of every matching node.
    """
    response = yp_client.select_objects(
        "node",
        selectors=["/meta/id"],
        filter="[/labels/segment]=\"default\"",
        enable_structured_response=True,
    )

    node_ids = []
    for row in response.get("results"):
        # Each row holds one structured entry per selector; only one was requested.
        node_ids.append(row[0]["value"])
    return node_ids


def get_host_switch_walle(fqdn):
    """Look up the switch a host is attached to via the Wall-E API.

    This is a best-effort lookup: any request failure or unexpected response
    shape yields an empty string instead of an exception.

    Args:
        fqdn: Host name, inserted into the request URL.

    Returns:
        The ``location.switch`` value from the response, or "" on failure.
    """
    url = "https://api.wall-e.yandex-team.ru/v1/hosts/{0}?fields=location.switch".format(fqdn)
    try:
        return requests.get(url).json()["location"]["switch"]
    except (requests.RequestException, ValueError, KeyError, TypeError):
        # RequestException: transport error; ValueError: body is not JSON;
        # KeyError/TypeError: the payload lacks location.switch.
        return ""


def get_hosts_switch(fqdns):
    """Map every host to its switch as reported by ``get_host_switch_walle``.

    Args:
        fqdns: Iterable of host names.

    Returns:
        dict mapping host name to switch name. The value is "" for hosts
        whose lookup failed.
    """
    mapping = {}

    for counter, fqdn in enumerate(fqdns, start=1):
        switch = get_host_switch_walle(fqdn)
        try:
            mapping[fqdn] = switch
        except TypeError:
            # Only an unhashable key can make the assignment fail; skip it
            # as before, but leave a trace instead of hiding it.
            logging.warning("Skipping unhashable host key %r", fqdn)

        # Progress message every 100 hosts.
        if counter % 100 == 0:
            logging.info("Extracted switches for %s hosts", counter)

    return mapping


def extract_host_net_capacity(fqdns, yp_client):
    """Fetch ``/spec/network/total_bandwidth`` of each host's network resource.

    One ``select_objects`` query is issued per host. Hosts for which the
    query returns no usable row are left out of the result.

    Args:
        fqdns: Iterable of node ids (host names).
        yp_client: Client exposing ``select_objects``.

    Returns:
        dict mapping host name to the selected bandwidth value.
    """
    mapping = {}

    for counter, fqdn in enumerate(fqdns, start=1):
        network = yp_client.select_objects(
            "resource",
            selectors=["/spec/network/total_bandwidth"],
            filter="[/meta/node_id]='{0}'and[/meta/kind]='network'".format(fqdn),
            enable_structured_response=True,
        ).get("results")

        try:
            mapping[fqdn] = network[0][0]["value"]
        except (IndexError, KeyError, TypeError):
            # No rows, no "value" field, or no "results" at all: the host
            # has no readable network resource, so it is skipped.
            logging.debug("No network capacity found for %s", fqdn)

        # Progress message every 100 hosts.
        if counter % 100 == 0:
            logging.info("Extracted capacities on %s hosts", counter)

    return mapping


def round_up(cp):
    """Round a capacity up to the next tier of 1, 2.5, 5 or 10 units.

    The input is first divided by ``BYTES_TO_GBITS / BYTES_TO_MBITS`` (1000).
    The chosen tier is returned multiplied by ``BYTES_TO_GBITS`` and truncated
    to int. Presumably the input is in Mbit/s and the result in bytes/s --
    verify against callers.

    Args:
        cp: Capacity to round.

    Returns:
        0 when the scaled value is exactly zero; otherwise the smallest tier
        not below the scaled value, or the scaled value itself when it
        exceeds the largest tier, expressed via ``BYTES_TO_GBITS``.
    """
    scaled = cp / (BYTES_TO_GBITS / BYTES_TO_MBITS)

    if scaled == 0.0:
        return 0

    for tier in (1.0, 2.5, 5.0, 10.0):
        if scaled <= tier:
            return int(tier * BYTES_TO_GBITS)

    # Above the largest tier the value is kept as is.
    return int(scaled * BYTES_TO_GBITS)


def round_down_complete(cp, coeff):
    """Scale a capacity down by ``coeff`` and floor it to whole 2.5-unit steps.

    The input is first divided by ``BYTES_TO_GBITS / BYTES_TO_MBITS`` (1000).
    The result is the number of whole 2.5-unit steps that fit into
    ``cp / coeff``, converted back with ``BYTES_TO_GBITS``. Presumably the
    input is in Mbit/s and the result in bytes/s -- verify against callers.

    Args:
        cp: Capacity to reduce.
        coeff: Divisor applied to the capacity; must be non-zero.

    Returns:
        The floored capacity as int, or 125000000 (one ``BYTES_TO_GBITS``)
        when the floored capacity is not positive, including ``cp == 0``.

    Raises:
        ZeroDivisionError: When ``coeff`` is zero.
    """
    cp /= BYTES_TO_GBITS / BYTES_TO_MBITS
    # Whole 2.5-unit steps that fit into the reduced capacity.
    full_twohalf_modified = int((cp / coeff) * 10 / 25)

    # Algebraically equal to cp * steps / (cp * 10 / 25). Written without the
    # division so that cp == 0 no longer raises ZeroDivisionError and float
    # rounding cannot shave a unit off before truncation.
    modified_capacity = int(full_twohalf_modified * 2.5 * BYTES_TO_GBITS)

    if modified_capacity > 0:
        return modified_capacity
    # Never go below one unit.
    return 125000000


def get_switch_downlinks(net_caps):
    """Group host capacities by the switch each host is attached to.

    Every capacity is divided by 125000 (1000 * 1000 / 8) and truncated to
    int, then the host's switch is looked up with ``get_host_switch_walle``.
    Hosts whose lookup yields "" are ignored.

    Args:
        net_caps: dict mapping host name to its raw capacity.

    Returns:
        dict mapping switch name to
        ``{"total_downlink": <sum of host capacities>, "hosts": {host: capacity}}``.
    """
    bytes_per_mbit = 1000 * 1000 / 8
    per_switch = {}

    for processed, (host, raw_cap) in enumerate(net_caps.items(), start=1):
        cap_mbits = int(raw_cap / bytes_per_mbit)
        switch = get_host_switch_walle(host)

        if switch != "":
            # Create the switch entry on first sight, then accumulate into it.
            entry = per_switch.setdefault(switch, {"total_downlink": 0, "hosts": {}})
            entry["total_downlink"] += cap_mbits
            entry["hosts"][host] = cap_mbits

        # Progress message every 100 hosts, whether or not the lookup succeeded.
        if processed % 100 == 0:
            logging.info("Extracted downlinks on {0} hosts".format(str(processed)))

    return per_switch


def main():
    """Compute calibrated network capacities for hosts on overcommitted switches.

    For every cluster in ``YP_PRODUCTION_CLUSTERS`` (sorted by name) the
    default-segment nodes and their network capacities are read from YP and
    grouped by switch. A switch whose total downlink exceeds its uplink is
    treated as overcommitted, and one record per host on it is produced.

    Switches missing from the uplink map are skipped with a warning.

    Returns:
        list with one entry per cluster, in sorted cluster order; each entry
        is a list of per-host dicts.
    """
    uplinks = get_switch_uplinks()
    net_cap_normed = []

    for clst in sorted(YP_PRODUCTION_CLUSTERS):  # iva, man, myt, sas, vla
        clst_data = []

        yp_client = YpClient(clst, config=dict(token=find_token()))
        def_nodes = get_default_nodes(yp_client)
        net_caps = extract_host_net_capacity(def_nodes, yp_client)
        downlinks = get_switch_downlinks(net_caps)

        for sw, sw_info in downlinks.items():
            if sw not in uplinks:
                # The uplink map only holds switches that passed the
                # RackTables filters; a host may sit on any other switch.
                logging.warning("No uplink data for switch {0}, skipping".format(sw))
                continue

            uplink = uplinks[sw]
            total_downlink = sw_info["total_downlink"]
            if total_downlink <= uplink:
                # Not overcommitted: nothing to calibrate.
                continue

            coeff = uplink / total_downlink
            if coeff == 0.0:
                logging.info("Zero overcommitment coeff on {0}".format(sw))
                continue

            overcommitment = total_downlink / uplink

            for hst, cap in sw_info["hosts"].items():
                calibrated_up = round_up(cap * coeff)
                # True when calibration drops the host from the 10-unit tier
                # (or above) to a lower one.
                round_capacity = (calibrated_up < 10 * BYTES_TO_GBITS
                                  and round_up(cap) >= 10 * BYTES_TO_GBITS)

                clst_data.append(
                    {
                        "host": hst,
                        "calibrated_up_capacity": calibrated_up,
                        # The key spelling is kept as is; it is part of the output.
                        "intitial_capacity": cap,
                        "calibrated_down_capacity": round_down_complete(cap, overcommitment),
                        "round_capacity": round_capacity,
                        "geo": clst,
                        "switch": sw,
                        "overcommitment_coeff": overcommitment,
                    }
                )

        net_cap_normed.append(clst_data)

    return net_cap_normed


if __name__ == '__main__':
    # Script entry point: write one JSON file per production cluster, either
    # to local out<N>.json files (--local_run) or to the Nirvana job outputs
    # named output<N>.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_run", type=utils.str2bool, default=False)
    args = parser.parse_args()

    cluster_count = len(YP_PRODUCTION_CLUSTERS)
    if args.local_run is True:
        out_paths = ["out{0}.json".format(str(i + 1)) for i in range(cluster_count)]
    else:
        job_context = nv.context()
        outputs = job_context.get_outputs()
        out_paths = [outputs.get("output" + str(i + 1)) for i in range(cluster_count)]

    fouts = []
    try:
        # Open every output before the long-running computation so that a bad
        # path fails early, as it did before.
        for out_path in out_paths:
            fouts.append(open(out_path, "w"))

        ncm = main()

        for i in range(cluster_count):
            json.dump(ncm[i], fouts[i], indent=4, sort_keys=True)
    finally:
        # Close (and thereby flush) all outputs, even if main() failed.
        for fout in fouts:
            fout.close()
