import json
import argparse

from yp.client import YpClient, find_token
from yp.common import YP_PRODUCTION_CLUSTERS
import nirvana.job_context as nv

import infra.analytics.io_limits_pipeline.utils as utils


def get_default_nodes(yp_client):
    """Return the ids of all YP nodes whose ``segment`` label is ``"default"``.

    Args:
        yp_client: client object exposing ``select_objects`` (a ``YpClient``
            in this script) for the cluster to inspect.

    Returns:
        A set of node ids (the ``/meta/id`` selector), each converted with
        ``str``. Duplicate ids collapse into one entry.
    """
    response = yp_client.select_objects(
        "node",
        selectors=["/meta/id"],
        filter='[/labels/segment]="default"',
        enable_structured_response=True
    )
    # With the structured response each row holds one entry per selector;
    # the selected payload of the single selector sits under "value".
    return {str(row[0]["value"]) for row in response["results"]}


def get_total_bandwidth(clst):
    """Sum the network bandwidth of default-segment nodes in one YP cluster.

    Selects every resource of kind ``network`` in cluster ``clst`` and adds up
    ``/spec/network/total_bandwidth`` for the resources whose ``/meta/node_id``
    is one of the nodes returned by ``get_default_nodes``. The sum is divided
    by ``1000 * 1000`` before being returned.

    NOTE(review): the unit of ``total_bandwidth`` is not visible here; the
    division presumably rescales it to a "mega" unit -- confirm with YP docs.
    Also note the division is integer division under Python 2 and true
    division under Python 3.
    """
    client = YpClient(clst, config=dict(token=find_token()))

    # One row per network resource: [total_bandwidth, node_id].
    network_rows = client.select_objects(
        "resource",
        selectors=["/spec/network/total_bandwidth", "/meta/node_id"],
        filter="[/meta/kind] = 'network'"
    )
    allowed_node_ids = get_default_nodes(client)

    bandwidth_sum = sum(
        row[0] for row in network_rows if row[1] in allowed_node_ids
    )
    return bandwidth_sum / (1000 * 1000)
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_run", type=utils.str2bool, default=False)
    args = parser.parse_args()

    # Resolve the destination first so a missing Nirvana context fails fast,
    # before any (slow) YP queries are issued.
    if args.local_run:
        output_path = "out.json"
    else:
        job_context = nv.context()
        outputs = job_context.get_outputs()
        output_path = outputs.get("output1")

    # Gather every cluster's figure before opening the output file, so an
    # exception from a YP request does not leave an empty/truncated JSON file.
    report = [
        {
            "cluster": clst,
            "total_bandwidth": get_total_bandwidth(clst),
        }
        for clst in YP_PRODUCTION_CLUSTERS
    ]

    # `with` guarantees the file is flushed and closed; the previous code
    # opened the handle and never closed it.
    with open(output_path, "w") as fout:
        json.dump(report, fout, indent=4)
