
import argparse
from collections import Counter
import logging
import json
import random
import os
import re
from math import ceil
import requests
import csv
from pprint import pprint
from threading import *

import nirvana.job_context as nv
from yp.client import YpClient, find_token
from infra.analytics.io_limits_pipeline.utils import str2bool

# YP clusters (data centers) that the lookup helpers below iterate over.
DCS = set(["iva", "myt", "vla", "sas", "man", "man-pre", "sas-test"])

logging.basicConfig(level=logging.DEBUG)

# Compiled service-name patterns ("auto services"); each service id is checked
# against them with re.search in get_nanny_abc_mapping.
auto_services_regex = list(map(re.compile, (
    "mt_front-.*",
    "mt_vendor-.*",
    "mt_partner-.*",
    "pdb_nodejs-.*",
)))


def determine_yp_cluster(host, dcs=None):
    """Guess the YP cluster (data center) name for a host name.

    First tries the ``<dc><digit>-...`` convention: the part of ``host``
    before the first ``-`` with its last character dropped (``sas1-0001...``
    gives ``sas``). If that is not a known cluster, falls back to a known
    cluster name contained anywhere in ``host``; the longest such name wins,
    so ``man-pre`` is preferred over ``man`` and ``sas-test`` over ``sas``.

    Args:
        host: host name to inspect.
        dcs: collection of known cluster names; defaults to the module-level
            ``DCS``.

    Returns:
        The matching cluster name, or ``None`` if no known cluster matches.
    """
    known = DCS if dcs is None else dcs

    cluster = host.split("-")[0][:-1]
    if cluster in known:
        return cluster

    # Substring matching is ambiguous ("man" is inside "man-pre", "sas" inside
    # "sas-test") and set iteration order varies between processes, so check
    # longer names first and break ties alphabetically for determinism.
    for dc in sorted(known, key=lambda name: (-len(name), name)):
        if dc in host:
            return dc
    return None


def get_balancer_services(timeout=60):
    """Return the ids of all Nanny services of type ``AWACS_BALANCER``.

    The Nanny OAuth token is read from the ``NANNY_TOKEN`` environment
    variable.

    Args:
        timeout: seconds to wait for the Nanny API; without a timeout a
            stalled connection would hang the whole run.

    Returns:
        A set of Nanny service ids.

    Raises:
        requests.HTTPError: if Nanny answers with an error status.
    """
    headers = {"Authorization": "OAuth {0}".format(os.environ.get("NANNY_TOKEN"))}
    response = requests.get(
        "https://nanny.yandex-team.ru/v2/services/?type=AWACS_BALANCER&exclude_runtime_attrs=1",
        headers=headers,
        timeout=timeout,
    )
    # Fail with a clear HTTP error instead of a KeyError on a missing "result".
    response.raise_for_status()

    return {record["_id"] for record in response.json()["result"]}


def get_abc_group_id(servicename, timeout=60):
    """Return the ``abc_group`` recorded in a Nanny service's ``info_attrs``.

    The Nanny OAuth token is read from the ``NANNY_TOKEN`` environment
    variable.

    Args:
        servicename: Nanny service id.
        timeout: seconds to wait for the Nanny API.

    Returns:
        ``content.abc_group`` of the ``info_attrs`` response, or ``None`` when
        the request fails, the body is not JSON, or the field is absent.
    """
    from urllib.parse import quote

    # Talk HTTP directly rather than through a shell: ``servicename`` must not
    # reach a shell command line, and the token must not show up in the
    # process list.
    headers = {"Authorization": "OAuth {0}".format(os.environ.get("NANNY_TOKEN"))}
    url = "https://nanny.yandex-team.ru/v2/services/{0}/info_attrs/".format(quote(servicename, safe=""))

    try:
        result = requests.get(url, headers=headers, timeout=timeout).json()
    except (requests.RequestException, ValueError):
        logging.debug("for service {0} we have not located any abc data at all".format(servicename))
        return None

    content = result.get("content") if isinstance(result, dict) else None
    if not isinstance(content, dict):
        return None
    return content.get("abc_group")


def _get_abc_group_id_by_pod_filter(pod_filter):
    """Return the ABC id owning the first pod that matches ``pod_filter``.

    Clusters from ``DCS`` are tried in turn. In a cluster with a matching pod,
    the pod's pod set is looked up and its ``/spec/account_id`` is returned
    with the ``abc:service:`` prefix stripped; if that pod set has no account
    the search continues with the next cluster.

    Returns:
        The ABC id string, or ``None`` if no cluster yields one.
    """
    token = find_token()
    for cluster in DCS:
        with YpClient(cluster, config=dict(token=token)) as yp_client:
            pods = yp_client.select_objects(
                "pod", selectors=["/meta/pod_set_id"], filter=pod_filter,
                enable_structured_response=True,
            ).get("results")
            if not pods:
                continue

            pod_set_id = pods[0][0]["value"]
            accounts = yp_client.select_objects(
                "pod_set", selectors=["/spec/account_id"], filter="[/meta/id]='{0}'".format(pod_set_id),
                enable_structured_response=True,
            ).get("results")
            if accounts:
                return accounts[0][0]["value"].replace("abc:service:", "")
    return None


def get_abc_group_id_qyp(servicename):
    """Return the ABC id of a QYP service, or ``None`` if it cannot be found.

    For QYP, ``servicename`` is a pod id.
    """
    return _get_abc_group_id_by_pod_filter("[/meta/id]=\"{0}\"".format(servicename))


def get_abc_group_id_rsc(servicename):
    """Return the ABC id of an (MC)RSC service, or ``None`` if it cannot be found.

    For (MC)RSC, ``servicename`` is matched against the pods'
    ``/labels/deploy/stage_id`` label.
    """
    return _get_abc_group_id_by_pod_filter("[/labels/deploy/stage_id]=\"{0}\"".format(servicename))


def get_abc_group_id_nanny(servicename, timeout=60):
    """Return the ABC id that owns a YP-lite Nanny service.

    The first instance in the service's ``current_state`` is resolved to its
    YP pod: the pod id is the first label of the container hostname and the
    YP cluster is the second. The pod's pod set ``/spec/account_id`` is
    returned with the ``abc:service:`` prefix stripped. A service that reports
    no instances falls back to ``get_abc_group_id`` (Nanny ``info_attrs``).

    The lookup is best-effort: any error is logged at debug level (with
    traceback) and reported as ``None``.

    Args:
        servicename: Nanny service id.
        timeout: seconds to wait for the Nanny API.

    Returns:
        The ABC id, or ``None`` if it could not be determined.
    """
    headers = {"Authorization": "OAuth {0}".format(os.environ.get("NANNY_TOKEN"))}
    url = "https://nanny.yandex-team.ru/v2/services/{0}/current_state/instances/".format(servicename)

    try:
        instances = requests.get(url, headers=headers, timeout=timeout).json().get("result")
        if not instances:
            return get_abc_group_id(servicename)

        hostname = instances[0]["container_hostname"]
        pod, dc = hostname.split(".")[:2]

        with YpClient(dc, config=dict(token=find_token())) as yp_client:
            pods = yp_client.select_objects(
                "pod", selectors=["/meta/pod_set_id"], filter="[/meta/id]=\"{0}\"".format(pod),
                enable_structured_response=True,
            ).get("results")
            if not pods:
                return None

            pod_set_id = pods[0][0]["value"]
            accounts = yp_client.select_objects(
                "pod_set", selectors=["/spec/account_id"], filter="[/meta/id]='{0}'".format(pod_set_id),
                enable_structured_response=True,
            ).get("results")
            if not accounts:
                return None
            return accounts[0][0]["value"].replace("abc:service:", "")
    except Exception:
        logging.debug("for service {0} we have not located any abc data at all".format(servicename),
                      exc_info=True)
        return None


def _count_pods_per_cluster(pod_filter):
    """Count the pods matching the YP filter ``pod_filter`` in every cluster of ``DCS``.

    Returns:
        A dict keyed ``"<cluster>_<segment>"`` (the format process_services
        splits on ``_``) with the number of matching pods, zero counts
        included. Every pod is attributed to the ``default`` segment; the
        selected ``/labels/segment`` values are not inspected. A cluster whose
        response carries no ``results`` is left out.
    """
    token = find_token()
    cluster_allocation = {}
    for cluster in DCS:
        with YpClient(cluster, config=dict(token=token)) as yp_client:
            response = yp_client.select_objects(
                "pod", selectors=["/labels/segment"], filter=pod_filter,
                enable_structured_response=True,
            )
        results = response.get("results")
        if results is not None:
            cluster_allocation["{0}_{1}".format(cluster, "default")] = len(results)
    return cluster_allocation


def get_nannyservice_segment_cluster_location(servicename, attempts=2):
    """Count the pods of a Nanny (YP-lite) service per ``<cluster>_default``.

    Pods are selected by their ``/labels/nanny_service_id`` label. The whole
    query is retried on error, ``attempts`` tries in total (at least one);
    the exception of the final attempt propagates.
    """
    pod_filter = "[/labels/nanny_service_id]=\"{0}\"".format(servicename)
    attempts = max(1, attempts)
    for attempt in range(1, attempts + 1):
        try:
            return _count_pods_per_cluster(pod_filter)
        except Exception:
            if attempt == attempts:
                raise
            logging.debug("retrying pod lookup for service {0}".format(servicename), exc_info=True)


def get_qyp_segment_cluster_location(servicename):
    """Count the pods of a QYP service (pod id ``servicename``) per ``<cluster>_default``."""
    return _count_pods_per_cluster("[/meta/id]='{0}'".format(servicename))


def get_rsc_segment_cluster_location(servicename):
    """Count the pods of an (MC)RSC stage ``servicename`` per ``<cluster>_default``."""
    return _count_pods_per_cluster("[/labels/deploy/stage_id]='{0}'".format(servicename))


def get_ssd_servicnames(fname):
    """Return the ``service_id`` of every entry in the JSON limits file ``fname``.

    The file must contain a JSON array of objects, each with a ``service_id``
    key. The file is closed before returning.
    """
    with open(fname) as limits_file:
        limits = json.load(limits_file)
    return [entry["service_id"] for entry in limits]


def _account_usage_ratio(response, used_index, limit_index):
    """Return used/limit for one resource in an account ``select_objects`` response.

    ``used_index`` and ``limit_index`` are positions in the first result row,
    i.e. in the selector list. Missing data or a zero limit gives 0.
    """
    try:
        row = response["results"][0]
        return row[used_index]["value"] / row[limit_index]["value"]
    except (IndexError, KeyError, TypeError, ZeroDivisionError):
        return 0


def get_full_capacity_used(abcid, cluster, disk_type, min_divisor=0.4):
    """Return the default-segment quota utilisation of an ABC account, floored.

    Reads account ``abc:service:<abcid>`` in YP ``cluster`` and computes the
    used/limit ratio of CPU and of memory capacity in the ``default``
    segment. The larger ratio is returned, but never less than
    ``min_divisor``; a ratio that cannot be computed counts as 0. Callers
    divide guarantees by this value, so ``min_divisor`` should be positive.

    Args:
        abcid: ABC service id.
        cluster: YP cluster name.
        disk_type: unused; kept for call compatibility.
        min_divisor: lower bound for the returned value.

    Returns:
        ``max(cpu_ratio, memory_ratio, min_divisor)``.
    """
    with YpClient(cluster, config=dict(token=find_token())) as yp_client:
        response = yp_client.select_objects(
            "account",
            selectors=[
                "/spec/resource_limits/per_segment/default/cpu/capacity",     # 0: CPU limit
                "/spec/resource_limits/per_segment/default/memory/capacity",  # 1: memory limit
                "/status/resource_usage/per_segment/default/cpu/capacity",    # 2: CPU usage
                "/status/resource_usage/per_segment/default/memory/capacity", # 3: memory usage
            ],
            filter="[/meta/id]=\"abc:service:{0}\"".format(abcid),
            enable_structured_response=True,
        )

    cpu_usage = _account_usage_ratio(response, used_index=2, limit_index=0)
    memory_usage = _account_usage_ratio(response, used_index=3, limit_index=1)
    divisor = max(cpu_usage, memory_usage, min_divisor)

    logging.debug("{0} is divisor for {1}".format(str(divisor), str(abcid)))
    return divisor

# Results gathered by ServiceInfoGetter threads in this process, keyed by
# service name: {"guarantee": ..., "abc_id": ..., "pod_dislocation": ...}.
service_guarantee_mapping = {}
# NOTE(review): not read or written anywhere else in this file;
# get_nanny_abc_mapping builds its own local set.
no_abc = set()

# Serializes appends to the no-ABC file.
no_abc_lock = BoundedSemaphore(1)
# Serializes appends to the backup CSV.
backup_lock = BoundedSemaphore(1)
# Caps the number of services being looked up at the same time.
general_lock = BoundedSemaphore(10)


class ServiceInfoGetter(Thread):
    """Thread that resolves the ABC id and pod placement of one service.

    A resolved service is stored in ``service_guarantee_mapping`` and appended
    to the ``backup`` CSV as ``name, guarantee, abc_id, json(pod_dislocation)``.
    A service without an ABC id (including an unknown deploy engine) is
    appended to ``no_abc_file``. A lookup that raises is logged and recorded
    nowhere, so a later run using the same backup files retries it.
    """

    def __init__(self, serv_name, deploy_engine, guarantee, backup, no_abc_file):
        Thread.__init__(self)
        self.serv_name = serv_name
        self.deploy_engine = deploy_engine
        self.guarantee = guarantee
        self.backup = backup
        self.no_abc_file = no_abc_file

    def _lookup(self):
        """Return ``(abc_id, pod_dislocation)`` via the engine-specific helpers.

        Unknown deploy engines give ``(None, None)``.
        """
        if self.deploy_engine == "YP_LITE":
            return (get_abc_group_id_nanny(self.serv_name),
                    get_nannyservice_segment_cluster_location(self.serv_name))
        if self.deploy_engine in ("MCRSC", "RSC"):
            return (get_abc_group_id_rsc(self.serv_name),
                    get_rsc_segment_cluster_location(self.serv_name))
        if self.deploy_engine == "QYP":
            return (get_abc_group_id_qyp(self.serv_name),
                    get_qyp_segment_cluster_location(self.serv_name))
        return None, None

    def _record_no_abc(self):
        """Append the service name to the no-ABC file."""
        with no_abc_lock, open(self.no_abc_file, "a") as no_abc_out:
            no_abc_out.write(self.serv_name + "\n")

    def _record(self, abc_id, pod_dislocation):
        """Publish the result in ``service_guarantee_mapping`` and the backup CSV."""
        service_guarantee_mapping[self.serv_name] = {
            "guarantee": self.guarantee,
            "abc_id": abc_id,
            "pod_dislocation": pod_dislocation,
        }
        # Take the lock before opening so concurrent appends cannot interleave.
        with backup_lock, open(self.backup, "a", newline="") as backup_out:
            csv.writer(backup_out).writerow(
                [self.serv_name, self.guarantee, abc_id, json.dumps(pod_dislocation)]
            )

    def run(self):
        # The ``with`` releases the slot even on error: a leaked slot would
        # eventually block every remaining thread and hang the caller's join().
        with general_lock:
            logging.debug("starting to process service {0} with engine {1}".format(self.serv_name, self.deploy_engine))
            try:
                abc_id, pod_dislocation = self._lookup()
            except Exception:
                logging.exception("failed to look up service {0}".format(self.serv_name))
                return
            logging.debug("Downloaded needed information")

            if abc_id is None:
                self._record_no_abc()
                logging.debug("Passed lock 1")
                return

            # Store the id as str, the type read back from the backup CSV, so
            # fresh and restored entries group under the same key downstream.
            abc_id = str(abc_id)
            self._record(abc_id, pod_dislocation)
            logging.debug("Passed lock 2")
            logging.debug(service_guarantee_mapping[self.serv_name])

def _load_backup_mapping(backup):
    """Read services resolved by earlier runs from the backup CSV.

    Rows are ``name, guarantee, abc_id, json(pod_dislocation)``; empty rows
    are skipped. A missing or unreadable file yields ``{}``.
    """
    try:
        with open(backup, "r", newline="") as bckread:
            return {
                row[0]: {"guarantee": int(row[1]), "abc_id": row[2], "pod_dislocation": json.loads(row[3])}
                for row in csv.reader(bckread)
                if row
            }
    except (OSError, TypeError, ValueError, IndexError, csv.Error):
        logging.debug("No backup file found or backup file is broken.")
        return {}


def _load_no_abc(no_abc_file):
    """Read the names of services known to have no ABC id (one per line).

    A missing or unreadable file yields an empty set.
    """
    try:
        with open(no_abc_file, "r") as bckread:
            return {svc.strip() for svc in bckread}
    except (OSError, TypeError, ValueError):
        logging.debug("No backup file found or backup file is broken.")
        return set()


def _service_guarantee(serv, resource_type):
    """Return a service's guarantee for ``resource_type``, or ``None`` if absent.

    For ``"io"`` it is the sum of the volumes' guarantees (0 without volumes);
    for anything else it is ``serv["net"]["guarantee"]``.
    """
    if resource_type == "io":
        return sum(int(vol["guarantee"]) for vol in serv.get("volumes", []))
    try:
        return int(serv["net"]["guarantee"])
    except (KeyError, TypeError):
        return None


def get_nanny_abc_mapping(fle, test_run, resource_type, include_awacs, blacklist, backup, no_abc_file,
                          sample_size=20):
    """Resolve the ABC id and pod placement of every service in the limits file.

    Services already in the ``backup`` CSV or the ``no_abc_file`` list are not
    looked up again. Also skipped are blacklisted services, services matching
    ``auto_services_regex`` and, when ``include_awacs`` is ``False``, awacs
    balancers. Every remaining service with a non-zero guarantee is resolved by
    its own ``ServiceInfoGetter`` thread.

    Args:
        fle: open file with a JSON array of service limit records.
        test_run: when ``True``, only a random sample of services is processed.
        resource_type: ``"io"`` to use volume guarantees, otherwise network.
        include_awacs: when ``False``, balancer services are skipped.
        blacklist: collection of service ids to skip.
        backup: path of the backup CSV (read here, appended to by the threads).
        no_abc_file: path of the no-ABC list (read here, appended to by the threads).
        sample_size: maximum number of services in a test run.

    Returns:
        ``{service: {"guarantee", "abc_id", "pod_dislocation"}}`` for services
        from the backup plus those resolved by this call.
    """
    mapping = _load_backup_mapping(backup)
    known_without_abc = _load_no_abc(no_abc_file)

    services = json.load(fle)
    if test_run is True:
        services = random.sample(services, min(sample_size, len(services)))

    awacs_services = get_balancer_services() if include_awacs is False else set()

    threads = []
    for serv in services:
        serv_name = serv["service_id"]
        if serv_name in blacklist or serv_name in mapping or serv_name in known_without_abc:
            continue
        if include_awacs is False and (
                serv_name in awacs_services
                or serv_name.startswith(("rtc_balancer_", "production_balancer_"))):
            continue
        if any(pattern.search(serv_name) for pattern in auto_services_regex):
            continue

        guarantee = _service_guarantee(serv, resource_type)
        if guarantee:
            th = ServiceInfoGetter(serv_name, serv["deploy_engine"], guarantee, backup, no_abc_file)
            th.start()
            threads.append(th)

    for th in threads:
        th.join()

    # The threads publish into the module-level service_guarantee_mapping;
    # merge the services resolved by this call into the returned mapping.
    for th in threads:
        fetched = service_guarantee_mapping.get(th.serv_name)
        if fetched is not None:
            mapping[th.serv_name] = fetched

    return mapping


def _abc_cluster_guarantees(mapping, services):
    """Return a Counter of raw consumption per cluster for ``services``.

    For each service and each ``<cluster>_default`` entry of its
    ``pod_dislocation``, ``ceil(pod count * guarantee)`` is added to that
    cluster; entries for other segments are ignored.
    """
    abc_cluster = Counter()
    for serv in services:
        serv_pod_guarantee = int(mapping[serv]["guarantee"])
        pod_dislocation = mapping[serv]["pod_dislocation"] or {}
        for segm, podnum in pod_dislocation.items():
            cluster, segment = segm.split("_")[0:2]
            if segment != "default":
                continue
            consumption = ceil(podnum * serv_pod_guarantee)
            abc_cluster[cluster] += consumption
            logging.debug("for service {0} and cluster {1} we have raw consumption {2}".format(
                serv, cluster, str(consumption)))
    return abc_cluster


def process_services(fle, disk_type, test_run, resource_type, include_awacs, blacklist, backup, no_abc_file):
    """Turn per-service guarantees into per-ABC, per-cluster quotas.

    Services are resolved with ``get_nanny_abc_mapping`` (same arguments) and
    grouped by ABC id. For every ABC id and cluster, the summed raw
    consumption is divided by ``get_full_capacity_used`` and rounded up.

    Returns:
        ``{abc_id: [{"quota": int, "cluster": str}, ...]}`` keyed by the ABC id
        as a string; pairs whose corrected quota is not positive are omitted.
    """
    mapping = get_nanny_abc_mapping(fle, test_run, resource_type, include_awacs, blacklist, backup, no_abc_file)

    # Group by the string form of the id: ids restored from the backup CSV are
    # strings while fresh lookups are not guaranteed to be, and
    # json.dump(sort_keys=True) cannot sort mixed int/str keys.
    services_by_abc = {}
    for serv_name, info in mapping.items():
        services_by_abc.setdefault(str(info["abc_id"]), []).append(serv_name)

    resulting_quotes = {}
    for abc, services in services_by_abc.items():
        logging.debug("processing abc {0}".format(abc))
        logging.debug("services: {0}".format(str(services)))

        for clst, guarantee in _abc_cluster_guarantees(mapping, services).items():
            divisor = get_full_capacity_used(abc, clst, disk_type)

            logging.debug("guarantees before correction equals to {0}".format(str(guarantee)))
            # Divide directly: multiplying by 1/divisor adds a rounding step
            # that can push an exact result over an integer before ceil().
            guarantee_corrected = guarantee / divisor
            logging.debug("guarantees corrected equlas to {0}".format(str(guarantee_corrected)))

            if guarantee_corrected > 0:
                resulting_quotes.setdefault(abc, []).append({"quota": ceil(guarantee_corrected), "cluster": clst})

    return resulting_quotes

def _load_blacklist(path):
    """Return the set of service ids in the JSON array file ``path``.

    An empty set is returned when no path is given.
    """
    if not path:
        return set()
    with open(path, "r") as blacklist_file:
        return set(json.load(blacklist_file))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_run", type=str2bool, default=False)
    parser.add_argument("--test_run", type=str2bool, default=False)
    parser.add_argument("--limits_filename", type=str)
    parser.add_argument("--output_filename", type=str)
    parser.add_argument("--disk_type", type=str, default="ssd")
    parser.add_argument("--resource_type", type=str, default="io")
    parser.add_argument("--include_awacs", type=str2bool, default=True)
    parser.add_argument("--blacklist", type=str)
    parser.add_argument("--backup", type=str)
    parser.add_argument("--no_abc_file", type=str)
    args = parser.parse_args()

    blacklist = _load_blacklist(args.blacklist)

    if args.local_run is True:
        with open(args.limits_filename) as limits_file:
            processed_services = process_services(
                limits_file, args.disk_type, args.test_run, args.resource_type,
                args.include_awacs, blacklist, args.backup, args.no_abc_file,
            )
        with open(args.output_filename, "w") as write_file:
            json.dump(processed_services, write_file, indent="\t", sort_keys=True)
    else:
        job_context = nv.context()
        inputs = job_context.get_inputs()
        outputs = job_context.get_outputs()

        output_filename = os.environ.get("OUTPUT_FILE")
        input_filename = os.environ.get("INPUT_FILE")
        resource_type = os.environ.get("RESOURCE_TYPE")
        # NOTE(review): unclear whether Nirvana runs should take the blacklist
        # from a BLACKLIST environment variable instead of --blacklist; confirm.

        with open(inputs.get(input_filename)) as limits_file:
            processed_services = process_services(
                limits_file, args.disk_type, args.test_run, resource_type,
                args.include_awacs, blacklist, args.backup, args.no_abc_file,
            )

        with open(outputs.get(output_filename), "w") as write_file:
            json.dump(processed_services, write_file, indent="\t", sort_keys=True)
