# coding: utf-8

from __future__ import absolute_import, print_function

from pprint import pprint, pformat

import time
import logging
import random
import requests
import json
import os
import sys
import argparse
from collections import defaultdict
import pickle

from yp.client import YpClient
from yt.yson import YsonEntity
import nirvana.job_context as nv

import infra.analytics.io_limits_pipeline.utils as utils

# Nanny API URL and notification id.
# NOTE(review): neither constant is referenced elsewhere in this file — confirm whether still needed.
NANNY_URL = 'https://nanny.yandex-team.ru/api/repo'
NANNY_NOTIFICATION_ID = "io-read-limit-adoption"

# YP API addresses ("host:port"), one per cluster.
YP_IVA_ADDRESS = "iva.yp.yandex.net:8090"
YP_MAN_ADDRESS = "man.yp.yandex.net:8090"
YP_MAN_PRE_ADDRESS = "man-pre.yp.yandex.net:8090"
YP_MYT_ADDRESS = "myt.yp.yandex.net:8090"
YP_SAS_ADDRESS = "sas.yp.yandex.net:8090"
YP_SAS_TEST_ADDRESS = "sas-test.yp.yandex.net:8090"
YP_VLA_ADDRESS = "vla.yp.yandex.net:8090"

# Clusters polled by compute_yp_stat, in this order. The cluster name stored in the
# collected stats is the part of the address before the first dot (e.g. "man-pre").
# YP_SAS_TEST_ADDRESS is not included.
YP_ADDRESSES = (
    YP_MAN_PRE_ADDRESS,
    YP_IVA_ADDRESS,
    YP_MYT_ADDRESS,
    YP_MAN_ADDRESS,
    YP_VLA_ADDRESS,
    YP_SAS_ADDRESS
)


def filter_useless_tag_combs(res):
    """
    Predicate for ``filter``: keep a yasm tag combination unless it is a useless one.

    :param res: yasm tag string as returned by the metainfo API
    :returns: False if ``res`` contains ``itype=common`` or ``tier=self``, True otherwise
    """
    useless_markers = ("itype=common", "tier=self")
    return not any(marker in res for marker in useless_markers)


def get_tags_by_hosts(hst):
    """
    Return the longest usable yasm tag string for a host.

    Queries the yasm metainfo API, e.g.
    https://yasm.yandex-team.ru/metainfo/tags/?hosts=vla2-9427.search.yandex.net&tag_format=unified
    drops combinations rejected by ``filter_useless_tag_combs`` and returns the longest
    remaining tag string (the first one on ties), or '' if none remain.

    :param hst: hostname to look up
    :rtype: str
    """
    url = "https://yasm.yandex-team.ru/metainfo/tags/?hosts=" + hst + "&tag_format=unified"
    result = [res for res in requests.get(url).json()["response"]["result"] if filter_useless_tag_combs(res)]

    data, tagstring_len = '', 0
    for res in result:
        # Track the best length seen so far; without updating it the loop would simply
        # return the last non-empty entry regardless of its length.
        if len(res) > tagstring_len:
            data, tagstring_len = res, len(res)

    return data


def get_nanny_yasm_tagstring(nannyname, dct_to_enrich, missing_in_nanny):
    """
    Record the container hostnames of a nanny service's current instances.

    On success ``dct_to_enrich[nannyname]`` is set to the list of ``container_hostname``
    values from the nanny ``current_state/instances`` API. The request is retried once
    after 15 seconds if the first attempt fails. If the (retried) response lacks the
    expected keys, ``nannyname`` is added to ``missing_in_nanny`` instead.

    :param nannyname: nanny service id
    :param dct_to_enrich: dict updated in place with ``{nannyname: [hostname, ...]}``
    :param missing_in_nanny: set collecting services whose instances could not be read
    """
    api_url = "https://nanny.yandex-team.ru/v2/services/{0}/current_state/instances/".format(nannyname)

    def fetch_hostnames():
        return [i["container_hostname"] for i in requests.get(api_url).json()["result"]]

    try:
        try:
            hostnames = fetch_hostnames()
        except Exception:
            # ``Exception`` rather than a bare except so Ctrl-C / SystemExit are not
            # turned into a retry.
            logging.warning('retrying nanny api after 15 seconds')
            time.sleep(15)
            hostnames = fetch_hostnames()
        dct_to_enrich[nannyname] = hostnames
    except (KeyError, AttributeError):
        missing_in_nanny.add(nannyname)


def create_yasm_tagstring(taglist):
    """
    Build a yasm tag string from nanny instance tags.

    Required tags: ctype, itype, prj. The nanny tag is not required because it is
    derived from the service name.

    Every tag starting with ``a_itype``, ``a_ctype`` or ``a_prj`` becomes a
    ``<name>=<value>;`` fragment, where the value is the tag with every ``a_<name>_``
    occurrence removed; other tags are ignored. Fragments keep the order of ``taglist``.

    :param taglist: iterable of nanny itags, e.g. ["a_itype_search", "a_ctype_prod"]
    :rtype: str
    """
    known_tags = (
        ("a_itype", "itype=", "a_itype_"),
        ("a_ctype", "ctype=", "a_ctype_"),
        ("a_prj", "prj=", "a_prj_"),
    )
    fragments = []
    for tag in taglist:
        for marker, key, prefix in known_tags:
            if tag.startswith(marker):
                fragments.append(key + tag.replace(prefix, '') + ";")
                break
    return "".join(fragments)


def get_nanny_tags_json(servicename):
    """
    Map each container hostname of a nanny service to its yasm tags and physical host.

    Returns ``{container_hostname: {"tags": <yasm tagstring>, "host": <hostname>}}`` built
    from the nanny ``current_state/instances`` API; the tag string is
    ``create_yasm_tagstring`` applied to the instance ``itags``. The request is retried
    once after 15 seconds; if the retry fails too, a warning is logged and an empty dict
    is returned (best effort).

    :param servicename: nanny service id
    :rtype: dict
    """
    base_url = "https://nanny.yandex-team.ru/v2/services/{0}/current_state/instances/".format(servicename)

    def fetch_mapping():
        return {i["container_hostname"]: {"tags": create_yasm_tagstring(i["itags"]),
                                          "host": i["hostname"]}
                for i in requests.get(base_url).json()["result"]}

    try:
        return fetch_mapping()
    except Exception:
        logging.warning("retrying nanny api after 15 seconds")

    # Outside any handler so Ctrl-C during the wait is not swallowed.
    time.sleep(15)
    try:
        return fetch_mapping()
    except Exception:
        logging.warning("failed to get instances of nanny service {0}, skipping it".format(servicename),
                        exc_info=True)
        return {}


def select_objects(yp_client, object_type, limit=1000, **kwargs):
    """
    Iterate over all YP objects of ``object_type``, following pagination.

    Issues ``yp_client.select_objects`` requests for at most ``limit`` objects each,
    passing the continuation token of the previous response, and yields every entry of
    each response's ``results``. Stops after the first page shorter than ``limit``.
    Extra keyword arguments (``selectors``, ``filter``, ...) are forwarded unchanged.
    """
    token = None
    page_is_full = True
    while page_is_full:
        page_options = {"limit": limit}
        if token is not None:
            page_options["continuation_token"] = token

        page = yp_client.select_objects(
            object_type,
            options=page_options,
            enable_structured_response=True,
            **kwargs
        )
        token = page["continuation_token"]
        results = page["results"]

        for item in results:
            yield item

        page_is_full = len(results) >= limit


def extract_fields(obj):
    """Return the ``"value"`` of every entry of a structured YP result row, in order."""
    values = []
    for field in obj:
        values.append(field["value"])
    return values


def extract_yasm_tags_deploy(podname, address):
    """
    Read the node id and monitoring labels of a single pod directly from YP.

    NOTE(review): parse_yp_services passes the cluster name (e.g. "sas") as ``address``,
    not a "host:port" address — confirm YpClient accepts short cluster names.

    :param podname: pod id (matched against ``/meta/id``)
    :param address: address handed to ``YpClient``
    :returns: ``(node_id, monitoring_labels)`` — values of ``/spec/node_id`` and
        ``/spec/host_infra/monitoring/labels`` from the first result row
    :raises IndexError: if the response has no (complete) row; the raw response is logged first
    """
    with YpClient(address, config=dict(token=utils.get_oauth_token())) as yp_client:
        data = yp_client.select_objects(
            "pod", selectors=["/spec/node_id", "/spec/host_infra/monitoring/labels"],
            filter="[/meta/id]='{0}'".format(podname)
        )
    try:
        return data[0][0], data[0][1]
    except (IndexError, TypeError):
        # Only the errors indexing can raise; a bare except also intercepted Ctrl-C.
        logging.error("unexpected YP response for pod {0} at {1}: {2}".format(podname, address, pformat(data)))
        raise


def check_disk_requests(disk_volume_requests, disk_type="ssd"):
    """
    Summarise the disk volume requests of one storage class.

    :param disk_volume_requests: iterable of request dicts with ``storage_class``, optional
        ``labels.mount_path`` and optional ``quota_policy`` (``bandwidth_guarantee``,
        ``bandwidth_limit``, ``capacity``)
    :param disk_type: storage class to inspect; requests of other classes are ignored
    :returns: ``(has_guarantee, has_limit, volumes)`` — each flag is False if any matching
        request has a zero/missing bandwidth guarantee (resp. limit), True otherwise (also
        when nothing matches); ``volumes`` is a sorted tuple of ``(mount_path, capacity)``
    """
    guaranteed = True
    limited = True
    volumes = []
    matching = (req for req in disk_volume_requests if req.get("storage_class") == disk_type)
    for req in matching:
        mount_path = req.get("labels", {}).get("mount_path", "")
        policy = req.get("quota_policy", {})
        guarantee = policy.get("bandwidth_guarantee", 0)
        limit = policy.get("bandwidth_limit", 0)
        guaranteed = guaranteed and bool(guarantee)
        limited = limited and bool(limit)
        volumes.append((mount_path, policy.get("capacity", 0)))
    return guaranteed, limited, tuple(sorted(volumes))


def get_resource_desc(resource_requests):
    """
    Return the pod's CPU/memory requests as a tuple (missing keys become None):
    ``(vcpu_guarantee, vcpu_limit, anonymous_memory_limit, memory_guarantee, memory_limit)``.
    """
    keys = ("vcpu_guarantee", "vcpu_limit", "anonymous_memory_limit", "memory_guarantee", "memory_limit")
    return tuple(resource_requests.get(key) for key in keys)


def get_deploy_id(pod_set_id, labels):
    """
    Identify the service a pod belongs to.

    :returns: ``(deploy_engine, service_id)`` where ``deploy_engine`` is the
        ``deploy_engine`` label (None if absent) and ``service_id`` is the
        ``nanny_service_id`` label for YP_LITE pods, the pod set id otherwise
    """
    deploy_engine = labels.get("deploy_engine")
    service_id = labels.get("nanny_service_id") if deploy_engine == "YP_LITE" else pod_set_id
    return deploy_engine, service_id


def get_segment_map(yp_client, abc_services):
    """
    Read the node segment and account of pod sets in a YP cluster.

    :param yp_client: connected YP client
    :param abc_services: optional whitespace-separated ABC service ids; when given, only pod
        sets whose account is ``abc:service:<id>`` for one of them are read. ``None`` or a
        blank string means no filtering.
    :returns: ``(segment_map, account_map)`` — ``{pod_set_id: node_segment_id}`` and
        ``{pod_set_id: account_id}``
    """
    selectors = ["/meta/id", "/spec/node_segment_id", "/spec/account_id"]
    # ``split()`` (not ``split(" ")``) so repeated spaces do not produce empty ids.
    service_ids = abc_services.split() if abc_services else []

    if service_ids:
        # Filter on the same path the account is selected from, and keep spaces around
        # ``or`` so every term is a separate token.
        account_filter = " or ".join("[/spec/account_id]='abc:service:{0}'".format(i) for i in service_ids)
        yp_objects = select_objects(yp_client, "pod_set", selectors=selectors, filter=account_filter)
    else:
        yp_objects = select_objects(yp_client, "pod_set", selectors=selectors)

    segment_map = {}
    account_map = {}
    for pod_set in yp_objects:
        pod_set_id, segment_id, account_id = extract_fields(pod_set)
        segment_map[pod_set_id] = segment_id
        account_map[pod_set_id] = account_id
    return segment_map, account_map


def create_service_stat():
    """Return a fresh, empty per-service stat record (factory for ``defaultdict``)."""
    return {
        "has_guarantee": None,
        "has_limit": None,
        "account_id": None,
        "pods": {},
    }


def combine_service_stat(target, cluster_name, account_id, has_guarantee, has_limit, pod_id, resource_desc,
                         disk_volume_desc):
    """
    Merge one pod into a service stat record (mutates ``target``).

    ``account_id`` always overwrites the stored one. Each flag is stored when the pod's
    value is falsy or nothing was stored yet, so a single falsy pod makes the flag stick.
    The pod is recorded as ``target["pods"][cluster_name][pod_id] = (resource_desc,
    disk_volume_desc)``.
    """
    target["account_id"] = account_id
    if not has_guarantee or target["has_guarantee"] is None:
        target["has_guarantee"] = has_guarantee
    if not has_limit or target["has_limit"] is None:
        target["has_limit"] = has_limit
    cluster_pods = target["pods"].setdefault(cluster_name, {})
    cluster_pods[pod_id] = (resource_desc, disk_volume_desc)


def process_pods(yp_client, segment_map, account_map, cluster_name, service_map, disk_type, segment_name, research="io"):
    """
    Aggregate the pods of one YP cluster into ``service_map``.

    A pod is recorded through ``combine_service_stat`` under key
    ``(deploy_engine, service_id)`` when its pod set is present in both ``segment_map`` and
    ``account_map``, the pod set lives in node segment ``segment_name`` and the pod has a
    ``deploy_engine`` label. With ``research == "io"`` the pod must also have disk volume
    requests of ``disk_type``, and their bandwidth flags/volumes are recorded; for any other
    research mode disk checks are skipped and both flags are recorded as True.

    :param service_map: updated in place; keys must be created on access
        (e.g. ``defaultdict(create_service_stat)``)
    """
    io_research = research == "io"
    pods = select_objects(yp_client, "pod", selectors=["/meta/id", "/meta/pod_set_id", "/spec/disk_volume_requests",
                                                     "/labels", "/spec/resource_requests"])
    for pod in pods:
        pod_id, pod_set_id, volume_requests, pod_labels, resources = extract_fields(pod)

        known_pod_set = pod_set_id in segment_map and pod_set_id in account_map
        if not known_pod_set:
            continue
        # Presumably YP returns an entity (#) when /spec/disk_volume_requests is unset.
        if io_research and isinstance(volume_requests, YsonEntity):
            continue
        if segment_map[pod_set_id] != segment_name:
            continue

        deploy_engine, service_id = get_deploy_id(pod_set_id, pod_labels)
        if not deploy_engine:
            continue

        if io_research:
            has_guarantee, has_limit, volumes = check_disk_requests(volume_requests, disk_type)
            if not volumes:
                continue
        else:
            has_guarantee, has_limit, volumes = True, True, []

        combine_service_stat(service_map[(deploy_engine, service_id)], cluster_name, account_map[pod_set_id],
                             has_guarantee, has_limit, pod_id, get_resource_desc(resources), volumes)


def compute_stats(cluster_name, address, service_map, disk_type, segment_name, research, abc_services):
    """Connect to the YP cluster at ``address`` and merge its pods into ``service_map``."""
    client_config = {"token": utils.get_oauth_token()}
    with YpClient(address, config=client_config) as yp_client:
        segments, accounts = get_segment_map(yp_client, abc_services)
        process_pods(yp_client, segments, accounts, cluster_name, service_map, disk_type, segment_name, research)


def compute_yp_stat(disk_type, segment_name, research, abc_services):
    """
    Aggregate service stats over every cluster in ``YP_ADDRESSES``.

    The cluster name recorded for each pod is the address part before the first dot.

    :rtype: dict mapping ``(deploy_engine, service_id)`` to a service stat dict
    """
    stats = defaultdict(create_service_stat)
    for address in YP_ADDRESSES:
        cluster = address.partition(".")[0]
        compute_stats(cluster, address, stats, disk_type, segment_name, research, abc_services)
    return dict(stats)


def get_yp_stat(disk_type, segment_name, research, abc_services):
    """Compute fresh (uncached) YP service stats; thin wrapper around ``compute_yp_stat``."""
    return compute_yp_stat(disk_type, segment_name, research, abc_services)


def parse_yp_services(service_map, data, tag_extraction, deploy_method):
    """
    Build per-pod records for every service of ``deploy_method`` and append them to ``data``.

    For each key ``(deploy_engine, service_id)`` of ``service_map`` whose engine equals
    ``deploy_method`` and whose id is non-empty, one record per pod is built:
    ``{"pod_name", "dc", "ssd", "service_name", "host"}`` plus ``"yasm_tag"`` when
    ``tag_extraction`` is "nanny", "deploy" or "yasm". Pods whose metadata cannot be read
    are logged and skipped. Pods are bucketed by the first three letters of their cluster
    name and at most 10 pods per bucket are kept (random sample).

    :param service_map: ``{(deploy_engine, service_id): {"pods": {cluster: {pod_id: ...}}}}``
    :param data: list receiving the pod records (mutated in place)
    :param tag_extraction: source of the yasm tag: "nanny", "deploy" or "yasm"
    :param deploy_method: deploy engine to keep, e.g. "YP_LITE" or "QYP"
    """
    host_suffix = "yp-c.yandex.net"
    max_pods_per_dc = 10
    # For these engines (and for "deploy" extraction, which reads the host from YP) the
    # container hostname is used as-is; otherwise the host comes from nanny instance data.
    nanny_host_lookup = deploy_method not in ("QYP", "MCRSC", "RSC") and tag_extraction != "deploy"
    # Fetch nanny data whenever it is read below; it used to be fetched only for "nanny"
    # extraction, so the host lookup failed for every pod in the other modes.
    need_nanny_tags = tag_extraction == "nanny" or nanny_host_lookup

    for (engine, servicename), service_stat in service_map.items():
        if engine != deploy_method or not servicename:
            continue

        # One fetch per service: the nanny endpoint does not depend on the cluster.
        nanny_tags = get_nanny_tags_json(servicename) if need_nanny_tags else {}

        pods_by_dc = {'iva': [], 'man': [], 'myt': [], 'sas': [], 'vla': []}
        pods_cnt = 0

        for location, pods_in_location in service_stat["pods"].items():
            for podname in pods_in_location:
                container_host = ".".join([podname, location, host_suffix])
                pod_data = {"pod_name": podname, "dc": location, "ssd": True, "service_name": servicename}
                try:
                    if tag_extraction == "nanny":
                        pod_data["yasm_tag"] = nanny_tags[container_host]["tags"][0:-1]

                    if nanny_host_lookup:
                        pod_data["host"] = nanny_tags[container_host]["host"]
                    else:
                        pod_data["host"] = container_host

                    if tag_extraction == "deploy":
                        pod_data["host"], pod_data["yasm_tag"] = extract_yasm_tags_deploy(podname, location)
                except Exception:
                    logging.warning("failed to get nanny tags for pod name {0}".format(container_host))
                    logging.debug("pod metadata error details", exc_info=True)
                    continue

                pods_cnt += 1
                # setdefault: clusters outside the five known prefixes must not raise KeyError.
                pods_by_dc.setdefault(location[:3], []).append(pod_data)

        logging.info("this service {0} has this many pods {1} (sas/vla/man/myt/iva: {2}/{3}/{4}/{5}/{6})".format(
            servicename, pods_cnt, len(pods_by_dc['sas']), len(pods_by_dc['vla']), len(pods_by_dc['man']),
            len(pods_by_dc['myt']), len(pods_by_dc['iva'])))

        pod_urls_full = []
        for dc, dc_pods in pods_by_dc.items():
            if len(dc_pods) > max_pods_per_dc:
                logging.info('service {0} has more than 10 pods in {1}, sampling 10'.format(servicename, dc))
                pod_urls_full.extend(random.sample(dc_pods, max_pods_per_dc))
            else:
                pod_urls_full.extend(dc_pods)

        if tag_extraction == "yasm":
            for pod in pod_urls_full:
                pod["yasm_tag"] = get_tags_by_hosts(".".join([pod["pod_name"], pod["dc"], host_suffix]))

        logging.debug("for servicename {0} we have data {1}".format(servicename, pod_urls_full))

        data.extend(pod_urls_full)


def transpose_pods_dictionary(pddct):
    """
    Flatten ``{host: [pod_record, ...]}`` into a list of per-pod dicts.

    Each output dict contains ``host`` (the mapping key), ``pod_name``, ``service_name``,
    ``yasm_tag`` and ``dc`` copied from the record, and ``ssd`` always set to True.
    """
    return [
        {
            "host": host,
            "pod_name": pod["pod_name"],
            "service_name": pod["service_name"],
            "yasm_tag": pod["yasm_tag"],
            "dc": pod["dc"],
            "ssd": True,
        }
        for host, pods in pddct.items()
        for pod in pods
    ]


def main(output_file, tag_extraction="nanny", disk_type="ssd", test_run=False, local_run=True, deploy_system="YP_LITE",
         segment_name="default", research="io", abc_services=None):
    """
    Collect pod metadata for the services of ``deploy_system`` and save it grouped by host.

    :param output_file: with ``local_run`` a path receiving a pickle of
        ``{host: [pod_record, ...]}``; otherwise the name of a nirvana job output receiving
        the same data flattened to JSON by ``transpose_pods_dictionary``
    :param tag_extraction: yasm tag source ("nanny", "deploy" or "yasm")
    :param disk_type: storage class whose disk volume requests are analysed
    :param test_run: if true, only a random sample of at most 50 services is processed
    :param local_run: write a local pickle instead of a nirvana output
    :param deploy_system: deploy engine of the services to process, e.g. "YP_LITE"
    :param segment_name: YP node segment to inspect
    :param research: "io" to require and inspect disk requests; other values skip disk checks
    :param abc_services: optional whitespace-separated ABC service ids limiting the pod sets read
    """
    logging.info("launching data collection for following disk type {0}".format(disk_type))

    # ``research`` is passed explicitly: the stats depend on it, and no module-level
    # ``research`` exists when this module is imported instead of run as a script.
    stat = get_yp_stat_cached(disk_type, segment_name, abc_services, research=research)

    service_map = {k: v for k, v in stat.items() if k[0] == deploy_system and k[1]}

    if test_run:
        sample_size = min(len(service_map), 50)
        selected_keys = set(random.sample(list(service_map), sample_size))
        service_map = {k: v for k, v in service_map.items() if k in selected_keys}

    data = []
    parse_yp_services(service_map, data, tag_extraction, deploy_system)

    # Group pod records by host; records without a host are skipped.
    final_data = {}
    for obj in data:
        if "host" in obj:
            final_data.setdefault(obj["host"], []).append(obj)

    if local_run:
        with open(output_file, 'wb') as f:
            logging.info("this is what we are going to insert")
            pickle.dump(final_data, f)
    else:
        job_context = nv.context()
        outputs = job_context.get_outputs()

        with open(outputs.get(output_file), "w") as write_file:
            json.dump(transpose_pods_dictionary(final_data), write_file)
    logging.info("data about {0} pods on {1} hosts saved into {2}".format(len(data), len(final_data), output_file))


def get_yp_stat_cached(disk_type, segment_name, abc_services, research="io"):
    """
    Return YP service stats, reusing a local pickle cache built with the same parameters.

    The cache file ``yp_stat.tmp`` (working directory) stores the parameters it was built
    with next to the result. A cache that is missing, unreadable, in the old format or
    built with other parameters is recomputed via ``get_yp_stat`` and overwritten.
    The cache never expires — delete the file to force a refresh.

    :rtype: dict mapping ``(deploy_engine, service_id)`` to a service stat dict
    """
    file_name = "yp_stat.tmp"
    cache_key = (disk_type, segment_name, research, abc_services)
    try:
        # NOTE: unpickling is acceptable only because this is our own local cache file.
        with open(file_name, "rb") as stream:
            cached = pickle.load(stream)
        if isinstance(cached, dict) and cached.get("key") == cache_key:
            return cached["result"]
        logging.info("{0} was built with other parameters, recomputing".format(file_name))
    except Exception:
        logging.info("no usable cache in {0}, recomputing".format(file_name))

    result = get_yp_stat(disk_type, segment_name, research, abc_services)
    with open(file_name, "wb") as stream:
        pickle.dump({"key": cache_key, "result": result}, stream, protocol=pickle.HIGHEST_PROTOCOL)
    return result


if __name__ == "__main__":
    logging.basicConfig(
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler("pod_meta_extraction.log")
        ],
        level=logging.INFO,
        format='%(asctime)s %(levelname)s %(module)s - %(funcName)s: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    utils.read_token()
    parser = argparse.ArgumentParser()
    parser.add_argument("--yasm_tag_extraction", type=str, default="nanny")
    parser.add_argument("--disk_type", type=str, default="ssd")
    parser.add_argument("--deploy_system", type=str, default="YP_LITE")
    parser.add_argument("--output_filename", type=str, default="host_pod_tag_hdd_full.pickle")
    parser.add_argument("--test_run", type=utils.str2bool, default=False)
    parser.add_argument("--local_run", type=utils.str2bool, default=False)
    parser.add_argument("--segment", type=str, default="default")
    parser.add_argument("--research", type=str, default="io")
    parser.add_argument("--abc_services", type=str, default=None,
                        help="Ids of abc services, separated by whitespace")
    args = parser.parse_args()

    def _env_or(name, default):
        """Return environment variable ``name`` if it is set and non-empty, else ``default``."""
        return os.environ.get(name) or default

    # Parameters from the nirvana environment override the command line.
    disk_type = _env_or("DISK_TYPE", args.disk_type)
    output_file = _env_or("OUTPUT_FILE", args.output_filename)
    tag_extraction = _env_or("TAGS_EXTRACTION", args.yasm_tag_extraction)
    deploy_system = _env_or("DEPLOY_SYSTEM", args.deploy_system)
    segment = _env_or("SEGMENT", args.segment)
    research = _env_or("RESEARCH", args.research)
    # Same override mechanism as the other parameters (it previously had none).
    abc_services = _env_or("ABC_SERVICES", args.abc_services)

    main(output_file, tag_extraction=tag_extraction, disk_type=disk_type, local_run=args.local_run,
         test_run=args.test_run, deploy_system=deploy_system, segment_name=segment, research=research,
         abc_services=abc_services)
