# coding: utf-8

from __future__ import absolute_import, print_function

import os
import logging
from collections import defaultdict
import pickle
import json
import requests
from jinja2 import Template, StrictUndefined
import library.python.resource as rs

# OAuth tokens: each starts as the value of the same-named environment
# variable ("" when unset); the matching read_*_token() helper below may fill
# it from a token file in the user's home directory.
OAUTH_TOKEN = os.getenv("OAUTH_TOKEN", "")
ST_OAUTH_TOKEN = os.getenv("ST_OAUTH_TOKEN", "")
NANNY_OAUTH_TOKEN = os.getenv("NANNY_OAUTH_TOKEN", "")
OK_OAUTH_TOKEN = os.getenv("OK_OAUTH_TOKEN", "")

# Versions stored next to the pickled ABC caches; a cache file carrying a
# different version is discarded and rebuilt from the API.
ABC_SERVICES_CACHE_VERSION = 3
ABC_MEMBERS_CACHE_VERSION = 3
ABC_RESPONSIBLES_CACHE_VERSION = 3


# 1024 * 1024 -- presumably for byte -> MiB conversions (not used in this part of the file).
MB_DIVISOR = 1 << 20


# Mount paths that split_quota() treats as IO-heavy ("busy") for non-SSD disks.
# When a service has at least one of these volumes, every other volume gets
# only a small capped share of the bandwidth and the busy volumes share the
# rest. Matching is by exact mount path (set membership), not by prefix.
BUSY_DIRS_HDD = {
    "/cache",
    "/cocaine",
    "/coredumps",
    "/data",
    "/data1",
    "/data2",
    "/data3",
    "/db/bsconfig/webstate",
    "/db/mysql",
    "/ephemeral",
    "/etcd_data",
    "/fresh",
    "/gear",
    "/graph",
    "/images",
    "/joker_stubs",
    "/local",
    "/morkva",
    "/perm",
    "/persistent",
    "/persistent-data",
    "/place",
    "/rem",
    "/resources",
    "/resources_storage",
    "/sandbox",
    "/srv",
    "/state",
    "/storage",
    "/tmp",
    "/var/bases",
    "/var/cache/yandex",
    "/var/exports",
    "/var/hdd_place",
    "/var/lib/ichwill",
    "/var/lib/push-client",
    "/var/lib/teamcity",
    "/var/lib/yandex/maps/ecstatic",
    "/var/lib/yandex/maps/wiki/stat",
    "/var/spool",
    "/var/spool/ichwill",
    "/webstate",
    "/worker_data",
    "/xurma",
    "/zen"
}


# Mount paths that split_quota() treats as IO-heavy ("busy") when
# disk_type == "ssd". Same semantics as BUSY_DIRS_HDD: busy volumes share
# whatever bandwidth remains after every other volume gets a small capped
# share. Matching is by exact mount path, not by prefix.
BUSY_DIRS_SSD = {
    "/amqp_data",
    "/cache",
    "/cm",
    "/dajr",
    "/data",
    "/data/db",
    "/data_ssd",
    "/db",
    "/db/bsconfig/webcache",
    "/db/bsconfig/webstate",
    "/db/mysql",
    "/ephemeral",
    "/fastbuild",
    "/fast_data",
    "/fresh",
    "/graph",
    "/local",
    "/mongo",
    "/opt/zookeeper",
    "/persistent",
    "/rem",
    "/resources",
    "/shard_root",
    "/ssd",
    "/state",
    "/storage",
    "/usr/local/teamcity-agents",
    "/var/bases",
    "/var/cache/nginx",
    "/var/cache/yandex",
    "/var/exports",
    "/var/lib/etcd",
    "/var/lib/nginx",
    "/var/lib/teamcity",
    "/var/lib/yandex/maps/carparks",
    "/var/lib/yandex/maps/ecstatic",
    "/var/lib/yandex/tanker",
    "/var/www/layers-mobile/carparks",
    "/xurma",
    "/zoo_data"
}


def read_token():
    """Fill OAUTH_TOKEN from ``~/.yp/token`` unless it already holds a value.

    A non-empty OAUTH_TOKEN (for example one taken from the environment) is
    left untouched. Any failure while reading the file is logged with its
    traceback and otherwise ignored, leaving the token unchanged.
    """
    global OAUTH_TOKEN
    if not OAUTH_TOKEN:
        try:
            with open(os.path.expanduser("~/.yp/token")) as token_file:
                OAUTH_TOKEN = token_file.read().strip()
        except Exception:
            logging.exception("Can't get token")


def read_st_token():
    """Fill ST_OAUTH_TOKEN from ``~/.st/token`` unless it already holds a value.

    A non-empty ST_OAUTH_TOKEN is left untouched. Any failure while reading
    the file is logged with its traceback and otherwise ignored.
    """
    global ST_OAUTH_TOKEN
    if not ST_OAUTH_TOKEN:
        try:
            with open(os.path.expanduser("~/.st/token")) as token_file:
                ST_OAUTH_TOKEN = token_file.read().strip()
        except Exception:
            logging.exception("Can't get st token")


def read_nanny_token():
    """Fill NANNY_OAUTH_TOKEN from ``~/.nanny/token`` unless already set.

    A non-empty NANNY_OAUTH_TOKEN is left untouched. Any failure while
    reading the file is logged with its traceback and otherwise ignored.
    """
    global NANNY_OAUTH_TOKEN
    if not NANNY_OAUTH_TOKEN:
        try:
            with open(os.path.expanduser("~/.nanny/token")) as token_file:
                NANNY_OAUTH_TOKEN = token_file.read().strip()
        except Exception:
            logging.exception("Can't get nanny token")


def read_ok_token():
    """Fill OK_OAUTH_TOKEN from ``~/.ok/token`` unless it already holds a value.

    A non-empty OK_OAUTH_TOKEN is left untouched. Any failure while reading
    the file is logged with its traceback and otherwise ignored.
    """
    global OK_OAUTH_TOKEN
    if not OK_OAUTH_TOKEN:
        try:
            with open(os.path.expanduser("~/.ok/token")) as token_file:
                OK_OAUTH_TOKEN = token_file.read().strip()
        except Exception:
            logging.exception("Can't get OK token")


def get_oauth_token():
    """Return the current OAUTH_TOKEN value (may be "" if none was found)."""
    return OAUTH_TOKEN


def get_st_oauth_token():
    """Return the current ST_OAUTH_TOKEN value (may be "").

    The ABC API helpers in this module send this token in their
    Authorization header.
    """
    return ST_OAUTH_TOKEN


def get_nanny_oauth_token():
    """Return the current NANNY_OAUTH_TOKEN value (may be "")."""
    return NANNY_OAUTH_TOKEN


def get_ok_oauth_token():
    """Return the current OK_OAUTH_TOKEN value (may be "")."""
    return OK_OAUTH_TOKEN


def iter_abc_members(timeout=60):
    """Yield ``(abc_service_id, person_id, login)`` for human ABC service members.

    Pages through the ABC v4 ``services/members`` endpoint with an id-ordered
    cursor until an empty page comes back. Only rows whose role scope slug
    is "services_management" or "administration" are yielded, and logins
    starting with "robot-" or ending with "-robot" are skipped.

    :param timeout: per-request timeout in seconds handed to ``requests``;
        without one a stalled connection would block forever.
    :raises requests.HTTPError: when ABC answers with an error status (e.g. a
        bad ST token), instead of failing later with an opaque KeyError.
    """
    url_template = "https://abc-back.yandex-team.ru/api/v4/services/members/?cursor={}-0&ordering=id&fields=id,person.id,person.login,service.id,role.scope.slug"
    # The token does not change while paging, so build the headers once.
    headers = {"Authorization": "OAuth {}".format(get_st_oauth_token())}
    cursor_id = 0
    while True:
        response = requests.get(url_template.format(cursor_id), headers=headers, timeout=timeout)
        response.raise_for_status()
        reply = response.json()
        logging.info("Got ABC members, cursor %d", cursor_id)
        for row in reply["results"]:
            # Advance the cursor past every row, including the filtered-out ones.
            cursor_id = max(cursor_id, row["id"])
            if row["role"]["scope"]["slug"] not in ("services_management", "administration"):
                continue
            if row["person"]["login"].startswith("robot-") or row["person"]["login"].endswith("-robot"):
                continue
            yield int(row["service"]["id"]), row["person"]["id"], row["person"]["login"]
        if not reply["results"]:
            break


def get_abc_members_cached():
    """Return ``{abc_service_id: [(person_id, login), ...]}`` for ABC members.

    The mapping is memoised in the pickle file ``abc_members_cache.tmp`` in
    the current working directory, tagged with ABC_MEMBERS_CACHE_VERSION.
    If the file is missing, unreadable or carries another version, the data
    is refetched via iter_abc_members() and the file is rewritten.
    """
    cache_path = "abc_members_cache.tmp"
    try:
        # NOTE(review): pickle.load can run code embedded in the file; fine for
        # a self-written cache, unsafe if the working directory is shared.
        with open(cache_path, "rb") as cache_file:
            stored_version, members = pickle.load(cache_file)
        if stored_version == ABC_MEMBERS_CACHE_VERSION:
            return members
        raise Exception("old cache")
    except Exception:
        members = {}
        for service_id, person_id, login in iter_abc_members():
            members.setdefault(service_id, []).append((person_id, login))
        with open(cache_path, "wb") as cache_file:
            pickle.dump((ABC_MEMBERS_CACHE_VERSION, members), cache_file, protocol=pickle.HIGHEST_PROTOCOL)
        return members


def iter_abc_services(timeout=60):
    """Yield ``(abc_service_id, name, slug)`` for every ABC service.

    Pages through the ABC v4 ``services`` endpoint with an id-ordered cursor
    until a page without results comes back.

    :param timeout: per-request timeout in seconds handed to ``requests``.
    :raises requests.HTTPError: when ABC answers with an error status. An
        error body has no "results" key, so without this check the generator
        would silently stop with zero services, and get_abc_services_cached()
        would then cache that empty mapping.
    """
    url_template = "https://abc-back.yandex-team.ru/api/v4/services/?cursor={}-0&ordering=id&fields=id,name,slug"
    # The token does not change while paging, so build the headers once.
    headers = {"Authorization": "OAuth {}".format(get_st_oauth_token())}
    cursor_id = 0
    while True:
        response = requests.get(url_template.format(cursor_id), headers=headers, timeout=timeout)
        response.raise_for_status()
        reply = response.json()
        logging.info("Got ABC service, cursor %d", cursor_id)
        rows = reply.get("results", [])
        for row in rows:
            cursor_id = max(cursor_id, row["id"])
            yield int(row["id"]), row["name"], row["slug"]
        if not rows:
            break


def get_abc_services_cached():
    """Return ``{abc_service_id: (name, slug)}`` for every ABC service.

    The mapping is memoised in the pickle file ``abc_services_cache.tmp`` in
    the current working directory, tagged with ABC_SERVICES_CACHE_VERSION.
    A missing, unreadable or outdated file triggers a full refetch via
    iter_abc_services(), after which the file is rewritten.
    """
    cache_path = "abc_services_cache.tmp"
    try:
        # NOTE(review): pickle.load can run code embedded in the file; fine for
        # a self-written cache, unsafe if the working directory is shared.
        with open(cache_path, "rb") as cache_file:
            stored_version, services = pickle.load(cache_file)
        if stored_version == ABC_SERVICES_CACHE_VERSION:
            return services
        raise Exception("old cache")
    except Exception:
        services = {
            service_id: (name, slug)
            for service_id, name, slug in iter_abc_services()
        }
        with open(cache_path, "wb") as cache_file:
            pickle.dump((ABC_SERVICES_CACHE_VERSION, services), cache_file, protocol=pickle.HIGHEST_PROTOCOL)
        return services


def iter_abc_responsibles(timeout=60):
    """Yield ``(abc_service_id, person_id, login)`` for human ABC responsibles.

    Follows the "next" links of the ABC v4 ``services/responsibles`` endpoint
    until a page has an empty "next". Logins starting with "robot-" or
    ending with "-robot" are skipped.

    :param timeout: per-request timeout in seconds handed to ``requests``;
        without one a stalled connection would block forever.
    :raises requests.HTTPError: when ABC answers with an error status,
        instead of failing later with an opaque KeyError.
    """
    next_url = "https://abc-back.yandex-team.ru/api/v4/services/responsibles/?fields=id,person.id,person.login,service.id"
    # The token does not change while paging, so build the headers once.
    headers = {"Authorization": "OAuth {}".format(get_st_oauth_token())}
    while True:
        response = requests.get(next_url, headers=headers, timeout=timeout)
        response.raise_for_status()
        reply = response.json()
        logging.info("Got ABC responsibles, url %s", next_url)
        for row in reply["results"]:
            if row["person"]["login"].startswith("robot-") or row["person"]["login"].endswith("-robot"):
                continue
            yield int(row["service"]["id"]), row["person"]["id"], row["person"]["login"]
        if not reply["next"]:
            break
        next_url = reply["next"]


def get_abc_responsibles_cached():
    """Return ``{abc_service_id: [(person_id, login), ...]}`` for ABC responsibles.

    The mapping is memoised in the pickle file ``abc_responsibles_cache.tmp``
    in the current working directory, tagged with
    ABC_RESPONSIBLES_CACHE_VERSION. A missing, unreadable or outdated file
    triggers a refetch via iter_abc_responsibles(), after which the file is
    rewritten.
    """
    cache_path = "abc_responsibles_cache.tmp"
    try:
        # NOTE(review): pickle.load can run code embedded in the file; fine for
        # a self-written cache, unsafe if the working directory is shared.
        with open(cache_path, "rb") as cache_file:
            stored_version, responsibles = pickle.load(cache_file)
        if stored_version == ABC_RESPONSIBLES_CACHE_VERSION:
            return responsibles
        raise Exception("old cache")
    except Exception:
        responsibles = {}
        for service_id, person_id, login in iter_abc_responsibles():
            responsibles.setdefault(service_id, []).append((person_id, login))
        with open(cache_path, "wb") as cache_file:
            pickle.dump((ABC_RESPONSIBLES_CACHE_VERSION, responsibles), cache_file, protocol=pickle.HIGHEST_PROTOCOL)
        return responsibles


def get_io_limit_map():
    """Load per-service IO and network limits from ``io_limits.json``.

    The file is read from the current working directory and is expected to
    be a list of records with "deploy_engine" and "service_id" plus optional
    "volumes" and "net" entries. The result is nested as
    ``{deploy_engine: {service_id: {key: limits}}}``:

    * for YP_LITE, each volume is keyed by ``(mount_path, storage_class)``;
    * for MCRSC, RSC and QYP, each volume is keyed by ``storage_class`` alone;
    * volumes of any other engine are ignored, but the service still gets an
      (empty) entry;
    * a non-empty "net" record is stored under the ``"net"`` key.
    """
    with open("io_limits.json") as stream:
        records = json.load(stream)

    def volume_limits(volume):
        # The same three fields are copied whatever the key layout is.
        return {
            "guarantee": volume["guarantee"],
            "limit": volume["limit"],
            "storage_class": volume["storage_class"]
        }

    io_limit_map = {}
    for record in records:
        engine = record["deploy_engine"]
        service_limits = io_limit_map.setdefault(engine, {}).setdefault(record["service_id"], {})

        volumes = record.get("volumes")
        if volumes and engine == "YP_LITE":
            for volume in volumes:
                service_limits[(volume["mount_path"], volume["storage_class"])] = volume_limits(volume)
        elif volumes and engine in ("MCRSC", "RSC", "QYP"):
            for volume in volumes:
                service_limits[volume["storage_class"]] = volume_limits(volume)

        net = record.get("net")
        if net:
            service_limits["net"] = {"guarantee": net["guarantee"], "limit": net["limit"]}

    return io_limit_map


def get_net_limit_map():
    """Return ``{deploy_engine: {service_id: {"guarantee": ..., "limit": ...}}}``.

    Reads ``io_limits.json`` from the current working directory. Records
    whose "net" entry is missing, null or empty are skipped. This is the same
    rule get_io_limit_map() applies, and it avoids crashing on
    ``None["guarantee"]`` or a KeyError for ``{}``.
    """
    with open("io_limits.json") as stream:
        records = json.load(stream)
    net_limit_map = {}
    for record in records:
        net = record.get("net")
        if not net:
            continue
        per_deploy_engine = net_limit_map.setdefault(record["deploy_engine"], {})
        per_deploy_engine[record["service_id"]] = {
            "guarantee": net["guarantee"],
            "limit": net["limit"]
        }
    return net_limit_map


def get_quota_map():
    """Load IO quotas from ``abc_quotas.json`` as ``{abc_id: {cluster: {storage_class: quota}}}``.

    The JSON object is keyed by ABC service id strings; keys that are not
    purely digits are ignored. Only entries whose "quota_type" is "io" are
    kept, but every numeric ABC id still gets an entry, possibly empty.
    """
    with open("abc_quotas.json") as stream:
        quotas = json.load(stream)
    quota_map = {}
    for raw_service_id, cluster_quotas in quotas.items():
        if raw_service_id.isdigit():
            per_cluster = quota_map.setdefault(int(raw_service_id), {})
            io_quotas = (entry for entry in cluster_quotas if entry["quota_type"] == "io")
            for entry in io_quotas:
                per_cluster.setdefault(entry["cluster"], {})[entry["storage_class"]] = entry["quota"]
    return quota_map


def check_if_full_vm(service_stat):
    """Return True if any pod in any cluster of ``service_stat`` has ``qyp_vm_node_forced`` set.

    Scanning stops at the first matching pod.
    """
    return any(
        pod.qyp_vm_node_forced
        for cluster in service_stat.clusters.values()
        for pod in cluster.pods.values()
    )


def group_services_by_abc(service_map, target_storage_class=None, target_deploy_engines=None):
    """Bucket services by the numeric ABC id parsed from their ``account_id``.

    :param service_map: mapping whose keys are 2-tuples starting with the
        deploy engine, and whose values are service descriptors.
    :param target_storage_class: if truthy, keep only services whose
        ``has_storage_class(target_storage_class)`` is true.
    :param target_deploy_engines: if truthy, keep only services whose deploy
        engine is contained in it.

    QYP services for which check_if_full_vm() is true are always dropped.
    So are services whose first ``:``-separated ``account_id`` field is not
    ``"abc"``. For the remaining services, the third field is the ABC id.

    :rtype: dict[int, dict[(string, string), yp_model.ServiceDescriptor]]
    """
    by_abc = {}
    for key, service_stat in service_map.items():
        deploy_engine, _ = key

        if target_storage_class and not service_stat.has_storage_class(target_storage_class):
            continue
        if target_deploy_engines and deploy_engine not in target_deploy_engines:
            continue
        if deploy_engine == "QYP" and check_if_full_vm(service_stat):
            continue

        account_parts = service_stat.account_id.split(":")
        if account_parts[0] == "abc":
            by_abc.setdefault(int(account_parts[2]), {})[key] = service_stat
    return by_abc


def have_all_services_io_limits(services, storage_class):
    """Return True if every service value reports a guarantee and limit for ``storage_class``.

    An empty mapping yields True; scanning stops at the first failing service.
    """
    for service_stat in services.values():
        if not service_stat.has_guarantee_and_limit(storage_class):
            return False
    return True


def have_all_services_net_limits(services):
    """Return True if every service value reports both a net guarantee and a net limit.

    An empty mapping yields True; scanning stops at the first failing service.
    """
    for service_stat in services.values():
        if not service_stat.has_net_guarantee_and_limit():
            return False
    return True


def have_all_services_net_guarantees(services):
    """Return True if every service value reports a net guarantee.

    An empty mapping yields True; scanning stops at the first failing service.
    """
    for service_stat in services.values():
        if not service_stat.has_net_guarantee():
            return False
    return True


def render_template(content, **kwargs):
    """Render the Jinja2 template text ``content`` with ``kwargs`` as context.

    StrictUndefined is used, so referencing a variable that is not in
    ``kwargs`` raises instead of rendering as an empty string.
    """
    return Template(content, undefined=StrictUndefined).render(**kwargs)


def render_template_from_resource(name, **kwargs):
    """Render the bundled resource ``name`` as a template and strip the result.

    The resource payload is decoded as UTF-8 (assumes ``rs.find`` returns
    bytes), rendered through render_template(), and leading/trailing
    whitespace is removed.
    """
    raw_resource = rs.find(name)
    template_text = raw_resource.decode('utf-8')
    rendered = render_template(template_text, **kwargs)
    return rendered.strip()


def _split_evenly(volumes, bandwidth):
    """Split ``bandwidth`` into equal integer shares over the ``volumes`` list.

    The truncation remainder is added to the last volume, so the shares
    always add up to ``bandwidth`` exactly.
    """
    share = int(bandwidth / len(volumes))
    result = {vol: share for vol in volumes}
    remainder = bandwidth - share * len(volumes)
    if remainder > 0:
        result[list(result)[-1]] += remainder
    return result


def split_quota(volumes, bandwidth, disk_type, busy_dirs=None):
    """Split an IO bandwidth quota between the volumes of a service.

    :param volumes: collection of volume mount paths.
    :param bandwidth: total bandwidth to distribute.
    :param disk_type: "ssd" selects BUSY_DIRS_SSD and a cap of 10 for small
        shares; any other value selects BUSY_DIRS_HDD and a cap of 5.
    :param busy_dirs: optional override for the set of busy mount paths;
        defaults to the set chosen by ``disk_type``.
    :returns: dict mapping each volume to its share. Shares always add up
        to ``bandwidth``.

    Rules:

    * no volumes -> ``{}``; a single volume receives everything;
    * no busy volume and no "/" -> equal integer shares, with the remainder
      added to the last volume;
    * otherwise, every "small" volume gets
      ``int(min(bandwidth / len(volumes), cap))``. The rest is split evenly
      between the busy volumes, or given to "/" when there are no busy
      volumes. Any truncation remainder goes to the last of them.
    """
    if not volumes:
        return {}
    if len(volumes) == 1:
        return {list(volumes)[0]: bandwidth}

    # Upper bound for the share given to each non-main volume.
    min_volume = 10 if disk_type == "ssd" else 5
    if busy_dirs is None:
        busy_dirs = BUSY_DIRS_SSD if disk_type == "ssd" else BUSY_DIRS_HDD

    busy_volumes = [vol for vol in volumes if vol in busy_dirs]
    if not busy_volumes and "/" not in volumes:
        return _split_evenly(list(volumes), bandwidth)

    # Volumes that absorb whatever the small shares leave over: the busy ones
    # if there are any, otherwise the root volume.
    main_volumes = busy_volumes or ["/"]
    small_share = int(min(bandwidth / len(volumes), min_volume))
    result = {}
    for vol in volumes:
        if vol not in main_volumes:
            result[vol] = small_share
            bandwidth -= small_share
    # Invariant: small_share <= bandwidth / len(volumes) and at least one
    # volume is a main volume, so the leftover never drops below small_share.
    assert bandwidth >= small_share
    # Floor shares plus remainder (rather than round() per volume) keep the
    # total exactly equal to the quota.
    result.update(_split_evenly(main_volumes, bandwidth))
    return result
