
import json

from pprint import pprint

import requests
import random
import argparse
import logging
import os
import re
from library.python import resource

import yt.wrapper
from yp.client import YpClient, find_token
from infra.analytics.io_limits_pipeline.utils import str2bool, AUTOMATED_GENERATED_LIMITS, BUSY_DIRS_HDD, \
    BUSY_DIRS_SSD, AUTOMATED_GENERATED_LIMITS_NET
import nirvana.job_context as nv


from infra.analytics.io_limits_pipeline.get_limit_diff.prepare_balancer_limits import process_balancers


class NoActiveContainersException(Exception):
    """Raised when a service exposes no usable containers to inspect."""


def split_quota(volumes, bandwidth, disk_type):
    """Distribute ``bandwidth`` between ``volumes`` and return ``{volume: share}``.

    The rules are applied in this order:

    * A single volume receives the whole ``bandwidth``.
    * No volume is listed in the busy-dirs constant for the disk type and
      ``"/"`` is absent: every volume gets ``int(bandwidth / n)`` and whatever
      is left over is added to the last volume.
    * No busy volume but ``"/"`` is present: every other volume gets
      ``int(min(bandwidth / n, cap))`` and ``"/"`` receives the rest.
    * At least one busy volume: non-busy volumes get the capped share and the
      rest is divided (with ``round``) between the busy volumes.

    ``cap`` is 10 for ``disk_type == "ssd"`` and 5 for anything else.

    :param volumes: sized collection of mount paths.
    :param bandwidth: total amount to distribute.
    :param disk_type: ``"ssd"`` selects the SSD constants, any other value the
        HDD ones.
    :return: dict mapping each volume to its share.
    """
    volume_count = len(volumes)
    if volume_count == 1:
        return {list(volumes)[0]: bandwidth}

    if disk_type == "ssd":
        busy_dirs, share_cap = BUSY_DIRS_SSD, 10
    else:
        busy_dirs, share_cap = BUSY_DIRS_HDD, 5

    even_share = int(bandwidth / volume_count)
    contains_busy = any(vol in busy_dirs for vol in volumes)

    shares = {}
    remaining = bandwidth

    if not contains_busy and "/" not in volumes:
        # Plain even split; the integer-division leftover goes to the last volume.
        share = even_share
        for vol in volumes:
            if remaining <= share:
                share = remaining
            shares[vol] = share
            remaining -= share
        if remaining > 0:
            shares[list(shares.keys())[-1]] += remaining
        return shares

    capped_share = int(min(bandwidth / volume_count, share_cap))

    if not contains_busy:
        # "/" is present: the other volumes get the capped share, root the rest.
        for vol in volumes:
            if vol == "/":
                continue
            shares[vol] = capped_share
            remaining -= capped_share
        assert remaining >= capped_share
        shares["/"] = remaining
        return shares

    # Busy volumes present: they share whatever the non-busy ones leave behind.
    busy_found = []
    for vol in volumes:
        if vol in busy_dirs:
            busy_found.append(vol)
        else:
            shares[vol] = capped_share
            remaining -= capped_share
    assert remaining >= capped_share
    for vol in busy_found:
        shares[vol] = round(remaining / len(busy_found))
    return shares


def get_existing_limits_local(filename):
    """Return the JSON-decoded content of the bundled resource ``/<filename>``.

    The raw payload is looked up with ``resource.find`` and parsed with
    ``json.loads``.
    """
    resource_key = "/{0}".format(filename)
    raw_payload = resource.find(resource_key)
    return json.loads(raw_payload)


def get_new_limits(source_table, disk_type):
    """Read generated limits from YT and return them as ``[{service: value}]``.

    Rows come from ``source_table`` or, when it is empty/``None``, from
    ``AUTOMATED_GENERATED_LIMITS``. Each row's ``lmt`` is divided by 1,000,000
    and truncated to an int (presumably bytes -> MB, judging by the original
    ``limits_mb`` name). For ``disk_type == "ssd"`` the value is capped at 100.
    """
    yt.wrapper.config.set_proxy("hahn")

    table_path = source_table if source_table else AUTOMATED_GENERATED_LIMITS
    rows = yt.wrapper.read_table(table_path)

    cap_values = disk_type == "ssd"
    limits_mb = []
    for row in rows:
        service = row["nannyservice"]
        value = int(row["lmt"] / 1000000)
        if cap_values and value > 100:
            value = 100
        limits_mb.append({service: value})

    return limits_mb


def get_new_limits_net(source_table):
    """Read generated network limits from YT as ``[{service: rounded_lmt}]``.

    Rows come from ``source_table`` or, when it is empty/``None``, from
    ``AUTOMATED_GENERATED_LIMITS_NET``. The ``lmt`` column is passed through
    ``round`` without any unit conversion.
    """
    yt.wrapper.config.set_proxy("hahn")
    rows = yt.wrapper.read_table(source_table or AUTOMATED_GENERATED_LIMITS_NET)
    return [{row["nannyservice"]: round(row["lmt"])} for row in rows]


def get_limit_diff_net():
    """Return the ids of services that already carry a ``net`` section.

    Reads ``io_limits.json`` from the current working directory. A record
    counts as soon as it has a ``"net"`` key, whatever the value is.

    :return: set of ``service_id`` values.
    """
    # The context manager closes the file; the original leaked the handle.
    with open("io_limits.json") as limits_file:
        existing_limits = json.load(limits_file)

    return {obj["service_id"] for obj in existing_limits if "net" in obj}


def get_limit_diff(existing_limits, source_table, disk_type, resource_type):
    """Return the freshly generated limits for services that have none yet.

    For ``resource_type == "net"`` ``existing_limits`` is used directly as the
    collection of already-covered service ids and the candidates come from
    ``get_new_limits_net``. For any other resource type ``existing_limits`` is
    a list of limit records; a service is considered covered when one of its
    ``volumes`` has ``storage_class == disk_type``, and the candidates come
    from ``get_new_limits``.

    :return: list of single-entry dicts ``{service: limit}`` not yet covered.
    """
    if resource_type == "net":
        covered_services = existing_limits
        candidates = get_new_limits_net(source_table)
    else:
        covered_services = set()
        for record in existing_limits:
            volumes = record.get("volumes")
            if not volumes:
                continue
            storage_classes = [vol["storage_class"] for vol in volumes]
            if disk_type in storage_classes:
                covered_services.add(record["service_id"])
        candidates = get_new_limits(source_table, disk_type)

    return [lm for lm in candidates if list(lm.keys())[0] not in covered_services]


def get_increase_limit_diff_net(source_table):
    """Return network limits that are new or larger than the current guarantee.

    Generated limits come from ``get_new_limits_net(source_table)``; current
    guarantees are read from ``io_limits.json`` in the working directory
    (records with a truthy ``"net"`` section only).

    * Service without a current guarantee: the generated entry is kept as is.
    * Service whose generated limit exceeds the current guarantee: the entry
      holds the *difference* ``new - old``.
    * Otherwise the service is omitted.

    :return: list of single-entry dicts ``{service: value}``.
    """
    lms = get_new_limits_net(source_table)

    # The context manager closes the file; the original leaked the handle.
    with open("io_limits.json") as limits_file:
        existing_limits = json.load(limits_file)

    current_guarantees = {
        record["service_id"]: record["net"]["guarantee"]
        for record in existing_limits
        if record.get("net")
    }

    new_limits = []
    for lm in lms:
        nanny, new_limit = list(lm.items())[0]

        if nanny not in current_guarantees:
            new_limits.append(lm)
            continue

        old_limit = current_guarantees[nanny]
        if new_limit > old_limit:
            new_limits.append({nanny: new_limit - old_limit})

    return new_limits


def get_cluster_by_hst(hst):
    """Return the first three characters of ``hst``, or ``None`` if it cannot be sliced.

    ``extract_service_volumes`` passes the result to ``YpClient`` as the
    cluster address, so the prefix is presumably the datacenter/cluster name
    embedded in the hostname.
    """
    try:
        return hst[0:3]
    except (TypeError, KeyError):
        # Non-sliceable input such as None (missing "hostname" field). The
        # original bare ``except`` also swallowed KeyboardInterrupt/SystemExit.
        return None


def prepare_defaults_for_missing(source_file, resource_type, disk_type, deploy_engine):
    """Build default limit records for the services listed in ``source_file``.

    ``source_file`` holds one service name per line. For every name:

    * names matching one of the excluded patterns are skipped;
    * names containing ``"balancer"`` get a single ``/logs`` volume with the
      default guarantee/limit and the ``deploy_engine`` argument;
    * any other service is inspected with ``extract_service_volumes`` (up to
      three tries); if mount points are found, the default guarantee is
      distributed over them with ``split_quota``. For ``disk_type == "hdd"``
      the limit is twice the guarantee, otherwise limit == guarantee.

    Defaults are guarantee=30/limit=30 when ``resource_type == "net"`` or
    ``disk_type == "ssd"``, and guarantee=15/limit=30 otherwise.

    :return: list of limit records (dicts with ``deploy_engine``,
        ``service_id`` and ``volumes``).
    """
    log = logging.getLogger(__name__)

    bad_names = [re.compile(pattern) for pattern in
                 ("mt_front-.*", "mt_vendor-.*", "mt_partner-.*", "pdb_nodejs-.*")]

    if resource_type == "net" or disk_type == "ssd":
        default_guarantee, default_limit = 30, 30
    else:
        default_guarantee, default_limit = 15, 30

    # HDD volumes are allowed to burst to twice their guarantee.
    limit_factor = 2 if disk_type == "hdd" else 1

    result = []

    with open(source_file) as fle:
        for line in fle:
            name = line.strip()
            if not name:
                # Blank line (e.g. trailing newline at the end of the file).
                continue

            # Each service is handled exactly once. The original ran the whole
            # body once per pattern, which duplicated every record up to four
            # times and still processed excluded names.
            if any(reg.search(name) for reg in bad_names):
                continue

            if "balancer" in name:
                result.append({
                    "deploy_engine": deploy_engine,
                    "service_id": name,
                    "volumes": [
                        {
                            "mount_path": "/logs",
                            "storage_class": disk_type,
                            "guarantee": default_guarantee,
                            "limit": default_limit
                        }
                    ]
                })
                print("another balancer has been processed")
                continue

            volumes_info = None
            for attempt in range(3):
                try:
                    volumes_info = extract_service_volumes(name, disk_type)
                except Exception:
                    # Best effort: record the failure and try again.
                    log.warning("attempt %d to extract volumes of %s failed",
                                attempt + 1, name, exc_info=True)
                    continue
                print("volumes extracted successfully")
                break

            if not volumes_info or len(volumes_info["mount_points"]) == 0:
                continue

            parsed_limits = split_quota(list(volumes_info["mount_points"]),
                                        default_guarantee, disk_type)

            result.append({
                "deploy_engine": volumes_info["deploy_engine"],
                "service_id": name,
                # Uses the ``disk_type`` argument; the original read the
                # module-global ``args`` that exists only when run as a script.
                "volumes": [
                    {"guarantee": guarant, "limit": guarant * limit_factor,
                     "mount_path": mp, "storage_class": disk_type}
                    for mp, guarant in parsed_limits.items()
                ]
            })

    return result


def get_deploy_engine(nannyservice, timeout=30):
    """Return the ``deploy_engine`` annotation of a Nanny service.

    Queries the service's ``target_state`` endpoint and reads
    ``content.snapshot_meta_info.annotations.deploy_engine`` from the JSON
    response.

    :param nannyservice: Nanny service id.
    :param timeout: seconds to wait for the HTTP response; without it a stalled
        connection would block the whole run indefinitely.
    :return: the annotation value, or ``""`` when any level of the path is
        missing.
    """
    url = "https://nanny.yandex-team.ru/v2/services/{0}/target_state/".format(nannyservice)
    dta = requests.get(url, timeout=timeout).json()
    annotations = dta.get("content", {}).get("snapshot_meta_info", {}).get("annotations", {})
    return annotations.get("deploy_engine", "")


def extract_service_volumes(servicename, disk_type, request_timeout=30):
    """Collect the mount paths of ``disk_type`` volumes used by a service.

    The current instances of the service are fetched from Nanny. Up to three
    randomly chosen containers (only the first one when fewer than three are
    available) are looked up in YP, and the ``mount_path`` label of every disk
    volume request whose ``storage_class`` equals ``disk_type`` is collected.
    Failures on individual pods are printed and skipped.

    :param servicename: Nanny service id.
    :param disk_type: storage class to look for.
    :param request_timeout: seconds to wait for the Nanny HTTP response.
    :return: ``{"mount_points": set_of_paths, "deploy_engine": str}``.
    :raises NoActiveContainersException: when the response has no ``result``
        key or no instance has both a container hostname and a cluster.
    """
    api_url = "https://nanny.yandex-team.ru/v2/services/{0}/current_state/instances/".format(servicename)

    payload = requests.get(api_url, timeout=request_timeout).json()
    try:
        instances = payload["result"]
    except KeyError:
        raise NoActiveContainersException("No active containers found for service {0}!".format(servicename))

    # container hostname -> cluster; later duplicates override earlier ones
    # before incomplete entries are dropped.
    by_container = {inst.get("container_hostname"): get_cluster_by_hst(inst.get("hostname"))
                    for inst in instances}
    cntnrs = {pod: cluster for pod, cluster in by_container.items() if pod and cluster}

    if not cntnrs:
        raise NoActiveContainersException("No active containers found for service {0}!".format(servicename))

    pod_hostnames = list(cntnrs)
    if len(pod_hostnames) >= 3:
        sample_pods = random.sample(pod_hostnames, 3)
    else:
        sample_pods = pod_hostnames[:1]

    mount_points = set()

    for sample_pod in sample_pods:
        # The first label of the container hostname is used as the YP pod id.
        pod_id = sample_pod.split(".")[0]
        cluster = cntnrs[sample_pod]

        try:
            with YpClient(cluster, config=dict(token=find_token())) as yp_client:
                rst = yp_client.select_objects("pod", selectors=["/spec/disk_volume_requests"],
                                               filter="[/meta/id]=\"{0}\"".format(pod_id),
                                               enable_structured_response=True
                                               )
                if len(rst["results"]) > 0:
                    for dt in rst["results"][0][0]["value"]:
                        if dt["storage_class"] == disk_type:
                            mount_points.add(dt["labels"]["mount_path"])

            print("pod parsed successfully")
        except Exception:
            # Best effort per pod; ``Exception`` (not a bare except) keeps
            # Ctrl-C and SystemExit working.
            print('unable to parse pods')
            print(pod_id)
            print(cluster)

    return {"mount_points": mount_points, "deploy_engine": get_deploy_engine(servicename)}


if __name__ == '__main__':
    # Entry point: compute the limits that have to be added for the selected
    # resource type and dump them as JSON, either to a local file or to the
    # Nirvana job output.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_run", type=str2bool, default=False)
    parser.add_argument("--disk_type", type=str, default="ssd")
    parser.add_argument("--source_table", type=str, default=None)
    parser.add_argument("--output_file", type=str, default="new_limits.json")
    parser.add_argument("--resource_type", type=str, default="io")
    parser.add_argument("--metadata_filename", type=str, default=None)
    parser.add_argument("--prepare_missing_defaults", type=str2bool, default=False)
    parser.add_argument("--missing_servicename_file", type=str, default=None)
    parser.add_argument("--deploy_engine", type=str, default="YP_LITE")
    parser.add_argument("--overcommit_coeff", type=int, default=2)
    parser.add_argument("--increase_limits", type=str2bool, default=False)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, filename="limit_diff_getter.log")
    logger = logging.getLogger(__name__)

    # A non-empty RESOURCE_TYPE environment variable overrides --resource_type.
    resource_type = os.environ.get("RESOURCE_TYPE") or args.resource_type

    new_limits_full = []

    if resource_type == "io":

        existing_limits = get_existing_limits_local("io_limits")
        print("about to extract limit diff")

        if args.increase_limits is False:
            limit_dff_per_service = get_limit_diff(existing_limits, args.source_table, args.disk_type, resource_type)
        else:
            raise NotImplementedError

        if args.prepare_missing_defaults:
            missing_limits_defaults = prepare_defaults_for_missing(args.missing_servicename_file, resource_type,
                                                                   args.disk_type, args.deploy_engine)
            new_limits_full.extend(missing_limits_defaults)
            print("missing services has been processed")
            print("\n\n\n\n\n\n\n\n")
        elif args.metadata_filename:
            balancer_data = process_balancers(args.metadata_filename)
            new_limits_full.extend(balancer_data)
            print("balancers has been processed")
            print("\n\n\n\n\n\n\n\n")
        else:
            is_yp_lite = args.deploy_engine == "YP_LITE"
            # Same number of tries as the original ``attempts <= 3`` bound.
            max_attempts = 4

            for lmt_guarant in limit_dff_per_service:

                print(lmt_guarant)

                serv_name, lmt = list(lmt_guarant.items())[0]

                # Small HDD IO limits get a limit above the guarantee.
                overcommit = lmt <= 78 and args.disk_type == "hdd" and args.resource_type == "io"
                limit_factor = args.overcommit_coeff if overcommit else 1

                for _attempt in range(max_attempts):
                    try:
                        if is_yp_lite:
                            result = extract_service_volumes(serv_name, args.disk_type)
                        else:
                            result = {"mount_points": [], "deploy_engine": args.deploy_engine}

                        if is_yp_lite and len(result["mount_points"]) == 0:
                            # Nothing usable in the sampled pods; sample again.
                            continue

                        lim_obj = {
                            "deploy_engine": result["deploy_engine"],
                            "service_id": serv_name,
                            "volumes": [],
                        }

                        if is_yp_lite:
                            parsed_limits = split_quota(list(result["mount_points"]), lmt, args.disk_type)

                            for mp, guarant in parsed_limits.items():
                                lim_obj["volumes"].append({"guarantee": guarant, "limit": guarant * limit_factor,
                                                           "mount_path": mp,
                                                           "storage_class": args.disk_type})
                        else:
                            # The coefficient scales the limit, as in the
                            # YP_LITE branch. The original multiplied the
                            # guarantee here, giving guarantee > limit.
                            lim_obj["volumes"].append({
                                "guarantee": lmt, "limit": lmt * limit_factor, "storage_class": args.disk_type
                            })

                        new_limits_full.append(lim_obj)
                        break

                    except NoActiveContainersException:
                        continue
                    except Exception:
                        # Every failure consumes an attempt. The original
                        # ``except: pass`` never advanced the counter, so a
                        # persistent error looped forever.
                        logger.exception("failed to build limits for service %s", serv_name)
                else:
                    # Reached only when no attempt succeeded. The original
                    # ``attempts == 3`` check never fired on exhaustion.
                    logger.info("service {0} has no active containers/containers has no valid volumes".format(serv_name))
    else:

        if args.increase_limits is False:
            existing_limits = get_limit_diff_net()
            limit_dff_per_service = get_limit_diff(existing_limits, args.source_table, args.disk_type, "net")
        else:
            limit_dff_per_service = get_increase_limit_diff_net(args.source_table)

        for lmt_guarant in limit_dff_per_service:
            serv_name, lmt = list(lmt_guarant.items())[0]

            new_limits_full.append({"service_id": serv_name, "deploy_engine": "YP_LITE",
                                    "net": {"guarantee": lmt, "limit": lmt}})

    if args.local_run is True:
        # The context manager flushes and closes the output file.
        with open(args.output_file, "w") as write_file:
            json.dump(new_limits_full, write_file)
    else:
        # Under Nirvana the output path is looked up by the name given in the
        # OUTPUT_FILE environment variable.
        job_context = nv.context()
        outputs = job_context.get_outputs()
        filename = os.environ.get("OUTPUT_FILE")
        with open(outputs.get(filename), "w") as write_file:
            json.dump(new_limits_full, write_file)
