# coding: utf-8

from __future__ import absolute_import, print_function

import time
import logging
from collections import defaultdict
import pickle

from yp.client import YpClient
from yt.yson import YsonEntity

import infra.rtc.iolimit_ticketer.utils as utils

# Endpoints ("<host>:<port>") of the individual YP clusters.
YP_IVA_ADDRESS = "iva.yp.yandex.net:8090"
YP_MAN_ADDRESS = "man.yp.yandex.net:8090"
YP_MAN_PRE_ADDRESS = "man-pre.yp.yandex.net:8090"
YP_MYT_ADDRESS = "myt.yp.yandex.net:8090"
YP_SAS_ADDRESS = "sas.yp.yandex.net:8090"
YP_SAS_TEST_ADDRESS = "sas-test.yp.yandex.net:8090"
YP_VLA_ADDRESS = "vla.yp.yandex.net:8090"
YP_XDC_ADDRESS = "xdc.yp.yandex.net:8090"

# Clusters polled by compute_yp_stat(), in this order.
# NOTE(review): YP_SAS_TEST_ADDRESS is defined above but not listed here —
# presumably intentional; confirm.
YP_ADDRESSES = (
    YP_MAN_PRE_ADDRESS,
    YP_IVA_ADDRESS,
    YP_MYT_ADDRESS,
    YP_MAN_ADDRESS,
    YP_VLA_ADDRESS,
    YP_SAS_ADDRESS,
    YP_XDC_ADDRESS
)

# Cluster name -> address; filled at import time by _fill_address_map().
YP_ADDRESS_MAP = {}

# Version stamp stored in every YpStat; YpStat.check_actual() rejects a
# different value.
YP_CACHE_VERSION = 14
# Maximum age, in seconds (24 hours), accepted by YpStat.check_actual().
YP_CACHE_TTL = 24 * 60 * 60

# Node segment identifiers.
DEFAULT_SEGMENT = "default"
DEV_SEGMENT = "dev"

# Only pods, account resources and nodes of these segments are collected.
VALID_SEGMENTS = {DEFAULT_SEGMENT, DEV_SEGMENT}


def _parse_account_id(account_id):
    provider_args = account_id.split(":")
    if provider_args[0] != "abc":
        raise Exception("unsupported account format: {}".format(account_id))
    return int(provider_args[2])


def get_cluster_from_address(address):
    """
    Derive the cluster name from a YP endpoint address.

    The name is the part of the address before the first dot with dashes
    replaced by underscores; ``sas-test`` is special-cased to ``test_sas``.

    :type address: str
    :rtype: str
    """
    host_prefix = address.partition(".")[0]
    if host_prefix != "sas-test":
        return host_prefix.replace("-", "_")
    return "test_sas"


def _fill_address_map():
    """Register every address of YP_ADDRESSES in YP_ADDRESS_MAP under its cluster name."""
    YP_ADDRESS_MAP.update(
        (get_cluster_from_address(yp_address), yp_address)
        for yp_address in YP_ADDRESSES
    )

_fill_address_map()


def create_yp_client(address):
    """
    Create a YpClient for the given address, authorized with the OAuth token
    returned by ``utils.get_oauth_token()``.

    :type address: str
    """
    client_config = {"token": utils.get_oauth_token()}
    return YpClient(address, config=client_config)


def select_objects(yp_client, object_type, limit=1000, **kwargs):
    """
    Lazily yield all objects of ``object_type``, fetching them page by page.

    One timestamp is generated up front and passed to every page request.
    Paging stops after the first page that holds fewer than ``limit`` results.

    :param yp_client: client providing ``generate_timestamp`` and ``select_objects``
    :param object_type: object type passed through to the client
    :param limit: page size
    :param kwargs: extra keyword arguments forwarded to ``yp_client.select_objects``
    """
    snapshot_timestamp = yp_client.generate_timestamp()
    next_token = None

    while True:
        page_options = {"limit": limit}
        if next_token is not None:
            page_options["continuation_token"] = next_token

        page = yp_client.select_objects(
            object_type,
            options=page_options,
            enable_structured_response=True,
            timestamp=snapshot_timestamp,
            **kwargs
        )
        next_token = page["continuation_token"]

        for item in page["results"]:
            yield item

        # A short page means there is nothing left to fetch.
        if len(page["results"]) < limit:
            return


def extract_fields(obj):
    """
    Unwrap one structured select result into the plain list of its
    ``"value"`` entries, keeping their order.

    :rtype: list
    """
    values = []
    for field in obj:
        values.append(field["value"])
    return values


def get_deploy_id(pod_set_id, labels):
    """
    Determine the deploy engine of a pod and the id its service is keyed by.

    ``YP_LITE`` pods are keyed by their ``nanny_service_id`` label; any other
    engine value (including a missing one) is keyed by the pod set id.

    :param pod_set_id: id of the pod set the pod belongs to
    :param labels: pod labels mapping
    :return: ``(deploy_engine, service_id)``
    :rtype: tuple
    """
    engine = labels.get("deploy_engine")
    if engine != "YP_LITE":
        return engine, pod_set_id
    return engine, labels.get("nanny_service_id")


def get_segment_map(yp_client):
    """
    Read all pod sets and index their node segment and account by pod set id.

    :return: ``(segment_map, account_map)``, both keyed by pod set id
    :rtype: tuple
    """
    pod_set_selectors = ["/meta/id", "/spec/node_segment_id", "/spec/account_id"]
    segment_by_pod_set = {}
    account_by_pod_set = {}
    for row in select_objects(yp_client, "pod_set", selectors=pod_set_selectors):
        pod_set_id, segment_id, account_id = extract_fields(row)
        segment_by_pod_set[pod_set_id] = segment_id
        account_by_pod_set[pod_set_id] = account_id
    return segment_by_pod_set, account_by_pod_set


class PodVolumeDescriptor:
    """One disk volume of a pod: storage class, mount path and quota values."""

    __slots__ = ["storage_class", "mount_path", "capacity", "bandwidth_guarantee", "bandwidth_limit"]

    def __init__(self):
        # Where the volume lives.
        self.storage_class = self.mount_path = None
        # Quota values; unset until filled in.
        self.capacity = self.bandwidth_guarantee = self.bandwidth_limit = None

    def has_guarantee(self):
        """Tell whether a non-zero bandwidth guarantee is set."""
        return True if self.bandwidth_guarantee else False

    def has_limit(self):
        """Tell whether a non-zero bandwidth limit is set."""
        return True if self.bandwidth_limit else False

    def has_guarantee_and_limit(self):
        """Tell whether both the bandwidth guarantee and the limit are set."""
        if not self.has_guarantee():
            return False
        return self.has_limit()


class PodResourceDescriptor:
    """CPU, memory and network bandwidth requests of a pod."""

    __slots__ = ["vcpu_guarantee", "vcpu_limit", "anonymous_memory_limit", "memory_guarantee", "memory_limit",
                 "network_bandwidth_guarantee", "network_bandwidth_limit"]

    def __init__(self):
        # CPU
        self.vcpu_guarantee = self.vcpu_limit = None
        # Memory
        self.anonymous_memory_limit = self.memory_guarantee = self.memory_limit = None
        # Network
        self.network_bandwidth_guarantee = self.network_bandwidth_limit = None

    def to_tuple(self):
        """
        Return all values as a 7-tuple in a fixed order, with unset (or
        otherwise falsy) values replaced by ``0``.

        :rtype: tuple
        """
        ordered_fields = (
            "vcpu_guarantee",
            "vcpu_limit",
            "anonymous_memory_limit",
            "memory_guarantee",
            "memory_limit",
            "network_bandwidth_guarantee",
            "network_bandwidth_limit",
        )
        return tuple(getattr(self, field_name) or 0 for field_name in ordered_fields)


class PodDescriptor:
    """A single pod: identity, placement, resource requests and disk volumes."""

    __slots__ = ["pod_id", "pod_set_id", "segment_id", "node_id", "qyp_vm_node_forced",
                 "resource_desc", "volume_desc_list", "yasm_tags"]

    def __init__(self):
        # Identity and placement.
        self.pod_id = self.pod_set_id = None
        self.segment_id = self.node_id = None
        self.qyp_vm_node_forced = None

        # Filled in later by the fill_from_* helpers / process_pods.
        self.resource_desc = self.volume_desc_list = self.yasm_tags = None

    def iter_volume_desc(self, storage_class):
        """Yield the volume descriptors of this pod that have the given storage class."""
        for candidate in self.volume_desc_list or ():
            if candidate.storage_class == storage_class:
                yield candidate

    def has_guarantee(self, storage_class):
        """Tell whether at least one volume of the storage class has a bandwidth guarantee."""
        for volume in self.iter_volume_desc(storage_class):
            if volume.has_guarantee():
                return True
        return False

    def has_limit(self, storage_class):
        """Tell whether at least one volume of the storage class has a bandwidth limit."""
        for volume in self.iter_volume_desc(storage_class):
            if volume.has_limit():
                return True
        return False

    def has_guarantee_and_limit(self, storage_class):
        """Tell whether at least one volume of the storage class has both guarantee and limit."""
        for volume in self.iter_volume_desc(storage_class):
            if volume.has_guarantee_and_limit():
                return True
        return False

    def has_storage_class(self, storage_class):
        """Tell whether the pod has any volume of the given storage class."""
        return any(True for _ in self.iter_volume_desc(storage_class))

    def has_net_guarantee(self):
        """Tell whether a positive network bandwidth guarantee is requested."""
        resources = self.resource_desc
        return resources is not None and (resources.network_bandwidth_guarantee or 0) > 0

    def has_net_limit(self):
        """Tell whether a positive network bandwidth limit is requested."""
        resources = self.resource_desc
        return resources is not None and (resources.network_bandwidth_limit or 0) > 0

    def has_net_guarantee_and_limit(self):
        """Tell whether both a network bandwidth limit and a guarantee are requested."""
        if not self.has_net_limit():
            return False
        return self.has_net_guarantee()


def fill_from_disk_requests(pod_descriptor, disk_volume_requests):
    """
    Convert disk volume requests into PodVolumeDescriptor objects and store
    them on the pod as a tuple sorted by mount path.

    A missing ``mount_path`` label becomes ``""``; missing quota values
    become ``0``.

    :type pod_descriptor: PodDescriptor
    :param disk_volume_requests: iterable of mappings with ``storage_class``,
        ``labels`` and ``quota_policy`` keys
    """
    descriptors = []
    for volume_request in disk_volume_requests:
        descriptor = PodVolumeDescriptor()
        descriptor.storage_class = volume_request.get("storage_class")

        request_labels = volume_request.get("labels", {})
        descriptor.mount_path = request_labels.get("mount_path", "")

        quota = volume_request.get("quota_policy", {})
        descriptor.capacity = quota.get("capacity", 0)
        descriptor.bandwidth_guarantee = quota.get("bandwidth_guarantee", 0)
        descriptor.bandwidth_limit = quota.get("bandwidth_limit", 0)

        descriptors.append(descriptor)

    pod_descriptor.volume_desc_list = tuple(sorted(descriptors, key=lambda item: item.mount_path))


def fill_from_resource_desc(pod_descriptor, resource_requests):
    """
    Copy the resource request values into a new PodResourceDescriptor and
    attach it to the pod. Keys absent from ``resource_requests`` are stored
    as ``None``.

    :type pod_descriptor: PodDescriptor
    :param resource_requests: mapping with the pod's resource request values
    """
    copied_fields = (
        "vcpu_guarantee",
        "vcpu_limit",
        "anonymous_memory_limit",
        "memory_guarantee",
        "memory_limit",
        "network_bandwidth_guarantee",
        "network_bandwidth_limit",
    )
    descriptor = PodResourceDescriptor()
    for field_name in copied_fields:
        setattr(descriptor, field_name, resource_requests.get(field_name))
    pod_descriptor.resource_desc = descriptor


class ClusterDescriptor:
    """
    Pods of one service within one cluster.

    The storage ``has_*`` checks consider only pods that have a volume of the
    requested storage class and hold when every such pod passes, so they
    are also true when there is no such pod. The network checks consider
    all pods.
    """

    __slots__ = ["cluster_name", "pods"]

    def __init__(self):
        self.pods = {}  # type: dict[str, PodDescriptor]
        self.cluster_name = None

    def _iter_pods_with(self, storage_class):
        """Yield the pods that have at least one volume of the storage class."""
        for pod in self.pods.itervalues():
            if pod.has_storage_class(storage_class):
                yield pod

    def has_guarantee(self, storage_class):
        """Tell whether every pod using the storage class has a bandwidth guarantee on it."""
        for pod in self._iter_pods_with(storage_class):
            if not pod.has_guarantee(storage_class):
                return False
        return True

    def has_limit(self, storage_class):
        """Tell whether every pod using the storage class has a bandwidth limit on it."""
        for pod in self._iter_pods_with(storage_class):
            if not pod.has_limit(storage_class):
                return False
        return True

    def has_guarantee_and_limit(self, storage_class):
        """Tell whether every pod using the storage class has both guarantee and limit on it."""
        for pod in self._iter_pods_with(storage_class):
            if not pod.has_guarantee_and_limit(storage_class):
                return False
        return True

    def has_net_guarantee(self):
        """Tell whether every pod has a network bandwidth guarantee."""
        for pod in self.pods.itervalues():
            if not pod.has_net_guarantee():
                return False
        return True

    def has_net_limit(self):
        """Tell whether every pod has a network bandwidth limit."""
        for pod in self.pods.itervalues():
            if not pod.has_net_limit():
                return False
        return True

    def has_net_guarantee_and_limit(self):
        """Tell whether every pod has both a network bandwidth guarantee and a limit."""
        for pod in self.pods.itervalues():
            if not pod.has_net_guarantee_and_limit():
                return False
        return True

    def has_storage_class(self, storage_class):
        """Tell whether at least one pod has a volume of the storage class."""
        return any(True for _ in self._iter_pods_with(storage_class))

    def get_mount_paths(self, storage_class):
        """
        Return the sorted, de-duplicated mount paths of all volumes of the
        storage class across the pods.

        :rtype: list
        """
        unique_paths = set()
        for pod in self.pods.values():
            for volume in pod.iter_volume_desc(storage_class):
                unique_paths.add(volume.mount_path)
        return sorted(unique_paths)

    def get_first_pod_id(self, storage_class=None):
        """
        Return the smallest pod id, optionally only among pods that have a
        volume of ``storage_class``; ``None`` when no pod qualifies.
        """
        ordered_pods = sorted(self.pods.itervalues(), key=lambda pod: pod.pod_id)
        for pod in ordered_pods:
            if storage_class is None or pod.has_storage_class(storage_class):
                return pod.pod_id
        return None


class ServiceDescriptor:
    """
    A deployed service and its pods grouped per cluster.

    Every ``has_*`` check except ``has_storage_class`` holds only when all
    clusters pass it, so it is also true for a service without clusters.
    """

    __slots__ = ["deploy_engine", "service_id", "account_id", "clusters", "deploy_stage_id", "deploy_unit_id"]

    def __init__(self):
        self.deploy_engine = None
        self.service_id = None

        self.account_id = None

        # Set by fill_service_descriptor() for the "RSC" / "MCRSC" engines only.
        self.deploy_stage_id = None
        self.deploy_unit_id = None

        self.clusters = {}  # type: dict[str, ClusterDescriptor]

    def get_abc_id(self):
        """
        Return the numeric id parsed from ``account_id`` by _parse_account_id().

        :rtype: int
        """
        return _parse_account_id(self.account_id)

    def has_guarantee(self, storage_class):
        """Tell whether every cluster reports a disk bandwidth guarantee for the storage class."""
        return all(
            cluster_descriptor.has_guarantee(storage_class)
            for cluster_descriptor in self.clusters.itervalues()
        )

    def has_limit(self, storage_class):
        """Tell whether every cluster reports a disk bandwidth limit for the storage class."""
        return all(
            cluster_descriptor.has_limit(storage_class)
            for cluster_descriptor in self.clusters.itervalues()
        )

    def has_guarantee_and_limit(self, storage_class):
        """Tell whether every cluster reports both guarantee and limit for the storage class."""
        return all(
            cluster_descriptor.has_guarantee_and_limit(storage_class)
            for cluster_descriptor in self.clusters.itervalues()
        )

    def has_net_guarantee(self):
        """Tell whether every cluster reports a network bandwidth guarantee."""
        return all(
            cluster_descriptor.has_net_guarantee()
            for cluster_descriptor in self.clusters.itervalues()
        )

    def has_net_limit(self, storage_class=None):
        """
        Tell whether every cluster reports a network bandwidth limit.

        :param storage_class: ignored. It used to be a required argument even
            though network limits do not depend on a storage class; it is now
            optional so the method can be called like the other ``has_net_*``
            checks while callers that still pass it keep working.
        """
        return all(
            cluster_descriptor.has_net_limit()
            for cluster_descriptor in self.clusters.itervalues()
        )

    def has_net_guarantee_and_limit(self):
        """Tell whether every cluster reports both a network bandwidth guarantee and a limit."""
        return all(
            cluster_descriptor.has_net_guarantee_and_limit()
            for cluster_descriptor in self.clusters.itervalues()
        )

    def has_storage_class(self, storage_class):
        """Tell whether at least one cluster has a pod with a volume of the storage class."""
        for cluster_descriptor in self.clusters.itervalues():
            if cluster_descriptor.has_storage_class(storage_class):
                return True
        return False


def fill_service_descriptor(service_descriptor, pod_descriptor, deploy_engine, service_id, cluster_name, account_id, labels):
    """
    Record a pod on its service descriptor.

    Overwrites the service identity fields, copies the deploy stage and unit
    ids from the ``deploy`` label for the ``RSC`` / ``MCRSC`` engines, and
    registers the pod under the given cluster (creating the cluster
    descriptor on first use).

    :type service_descriptor: ServiceDescriptor
    :type pod_descriptor: PodDescriptor
    """
    service_descriptor.account_id = account_id
    service_descriptor.service_id = service_id
    service_descriptor.deploy_engine = deploy_engine

    if deploy_engine == "RSC" or deploy_engine == "MCRSC":
        deploy_labels = labels.get("deploy", {})
        service_descriptor.deploy_stage_id = deploy_labels.get("stage_id")
        service_descriptor.deploy_unit_id = deploy_labels.get("deploy_unit_id")

    known_clusters = service_descriptor.clusters
    if cluster_name not in known_clusters:
        known_clusters[cluster_name] = ClusterDescriptor()
    target_cluster = known_clusters[cluster_name]
    target_cluster.cluster_name = cluster_name
    target_cluster.pods[pod_descriptor.pod_id] = pod_descriptor


def extract_yasm_tags_for_yp_lite(iss_properties):
    """
    Pick the ``itype`` / ``prj`` / ``ctype`` tags out of a list of
    ``{"key": ..., "value": ...}`` properties.

    The tags come from the ``INSTANCE_TAG_ITYPE``, ``INSTANCE_TAG_PRJ`` and
    ``INSTANCE_TAG_CTYPE`` keys. A tag whose key is absent is ``None``; when
    a key occurs several times the last value wins.

    :rtype: dict
    """
    tag_by_property = {
        "INSTANCE_TAG_ITYPE": "itype",
        "INSTANCE_TAG_PRJ": "prj",
        "INSTANCE_TAG_CTYPE": "ctype",
    }
    tags = {"itype": None, "prj": None, "ctype": None}
    for entry in iss_properties:
        tag_name = tag_by_property.get(entry["key"])
        if tag_name is not None:
            tags[tag_name] = entry["value"]
    return tags


def process_pods(yp_client, cluster_name, service_map):
    """
    Read all pods of a cluster and register the relevant ones in ``service_map``.

    A pod is skipped when its disk or resource requests are a YsonEntity, its
    pod set is unknown, its segment is not in VALID_SEGMENTS, it has no
    deploy engine label, or it has no disk volume requests.

    :param yp_client: YP client of the cluster
    :param cluster_name: name under which the pods are registered on the service
    :param service_map: mapping keyed by ``(deploy_engine, service_id)``; it is
        indexed directly, so a missing key must be created by the mapping
        itself (YpStat passes a ``defaultdict(ServiceDescriptor)``)
    """
    segment_map, account_map = get_segment_map(yp_client)
    for pod in select_objects(yp_client, "pod", selectors=["/meta/id", "/meta/pod_set_id", "/spec/disk_volume_requests",
                                                           "/labels", "/spec/resource_requests", "/status/scheduling/node_id",
                                                           "/spec/iss/instances/-1/properties"], limit=500):
        # The unpacking order below must match the selector order above.
        pod_id, pod_set_id, disk_volume_requests, labels, resource_requests, node_id, iss_properties = extract_fields(pod)
        pod_descriptor = PodDescriptor()
        pod_descriptor.pod_id = pod_id
        pod_descriptor.pod_set_id = pod_set_id
        pod_descriptor.node_id = node_id
        # The flag is only ever set to True; otherwise it keeps its initial None.
        if labels.get("qyp_vm_forced_node_id"):
            pod_descriptor.qyp_vm_node_forced = True

        # NOTE(review): YsonEntity presumably stands for a missing value — confirm.
        if isinstance(disk_volume_requests, YsonEntity) or isinstance(resource_requests, YsonEntity):
            continue
        if pod_set_id not in segment_map or pod_set_id not in account_map:
            continue

        fill_from_resource_desc(pod_descriptor, resource_requests)

        pod_descriptor.segment_id = segment_map[pod_set_id]
        if pod_descriptor.segment_id not in VALID_SEGMENTS:
            continue

        deploy_engine, service_id = get_deploy_id(pod_set_id, labels)
        if not deploy_engine:
            continue

        fill_from_disk_requests(pod_descriptor, disk_volume_requests)
        # Pods without any disk volume are not tracked.
        if not pod_descriptor.volume_desc_list:
            continue

        if deploy_engine == "YP_LITE" and not isinstance(iss_properties, YsonEntity):
            pod_descriptor.yasm_tags = extract_yasm_tags_for_yp_lite(iss_properties)

        account_id = account_map[pod_set_id]
        fill_service_descriptor(
            service_map[(deploy_engine, service_id)], pod_descriptor,
            deploy_engine, service_id, cluster_name, account_id, labels
        )


class DiskDescriptor:
    """Accumulated disk capacity and bandwidth (usage and limits) for one storage class."""

    __slots__ = ["storage_class", "capacity_usage", "bandwidth_usage", "capacity_limits", "bandwidth_limits"]

    def __init__(self):
        self.storage_class = None

        # Counters start at zero and are incremented by process_accounts().
        self.capacity_usage = self.bandwidth_usage = 0
        self.capacity_limits = self.bandwidth_limits = 0


class NetworkDescriptor:
    """Accumulated network bandwidth usage and limits."""

    __slots__ = ["bandwidth_usage", "bandwidth_limits"]

    def __init__(self):
        # Both counters start at zero.
        self.bandwidth_usage = self.bandwidth_limits = 0


class CpuDescriptor:
    """Accumulated CPU usage and limits."""

    __slots__ = ["cpu_usage", "cpu_limits"]

    def __init__(self):
        # Both counters start at zero.
        self.cpu_limits = self.cpu_usage = 0


class ClusterResourcesDescriptor:
    """Disk, network and CPU resources of an account in one cluster segment."""

    __slots__ = ["cluster_name", "segment_id", "disks", "network", "cpu"]

    def __init__(self):
        self.cluster_name = self.segment_id = None

        self.disks = {}  # type: dict[str, DiskDescriptor]
        self.network = NetworkDescriptor()
        self.cpu = CpuDescriptor()

    def get_disk_descriptor(self, storage_class):
        """
        Return the DiskDescriptor registered for the storage class, creating
        and registering an empty one on first use.

        :rtype: DiskDescriptor
        """
        known_disks = self.disks
        if storage_class not in known_disks:
            known_disks[storage_class] = DiskDescriptor()
        return known_disks[storage_class]


class AccountDescriptor:
    """Resources of one account, grouped by cluster and then by segment."""

    __slots__ = ["account_id", "clusters"]

    def __init__(self):
        # cluster -> segment -> resources
        self.clusters = {}  # type: dict[str, dict[str, ClusterResourcesDescriptor]]

        self.account_id = None

    def get_resources_descriptor(self, cluster_name, segment_id):
        """
        Return the resources descriptor for the cluster and segment, creating
        it on first use. The descriptor's ``cluster_name`` and ``segment_id``
        are (re)assigned on every call.

        :rtype: ClusterResourcesDescriptor
        """
        per_segment = self.clusters.setdefault(cluster_name, {})
        if segment_id not in per_segment:
            per_segment[segment_id] = ClusterResourcesDescriptor()
        descriptor = per_segment[segment_id]
        descriptor.cluster_name = cluster_name
        descriptor.segment_id = segment_id
        return descriptor

    def get_abc_id(self):
        """
        Return the numeric id parsed from ``account_id`` by _parse_account_id().

        :rtype: int
        """
        return _parse_account_id(self.account_id)


def process_accounts(yp_client, cluster_name, account_map):
    """
    Read the accounts of a cluster and accumulate their per-segment resource
    limits and usage into ``account_map``.

    Only top-level accounts (no parent) are processed, and only those whose
    id starts with ``abc:service:`` or is listed in ``filter_accounts``.
    Segments outside VALID_SEGMENTS are ignored.

    :param yp_client: YP client of the cluster
    :param cluster_name: name under which the resources are recorded
    :param account_map: mapping keyed by account id; it is indexed directly,
        so a missing key must be created by the mapping itself (YpStat passes
        a ``defaultdict(AccountDescriptor)``)
    """
    # Account ids accepted even though they lack the "abc:service:" prefix.
    filter_accounts = ["tentacles"]
    for account in select_objects(yp_client, "account", selectors=["/meta/id", "/spec/parent_id", "/spec/resource_limits/per_segment", "/status/resource_usage/per_segment"]):
        account_id, parent_id, resource_limits, resource_usage = extract_fields(account)
        if not account_id.startswith("abc:service:") and account_id not in filter_accounts:
            continue
        # Accounts that have a parent are skipped.
        if parent_id:
            continue

        account_descriptor = account_map[account_id]  # type: AccountDescriptor
        account_descriptor.account_id = account_id

        # Limits. A resources descriptor is created only when there is
        # something to record for the segment.
        if not isinstance(resource_limits, YsonEntity):
            for segment_id, resources in resource_limits.items():
                if segment_id not in VALID_SEGMENTS:
                    continue
                for storage_class, values in resources.get("disk_per_storage_class", {}).items():
                    resources_descriptor = account_descriptor.get_resources_descriptor(cluster_name, segment_id)
                    disk_descriptor = resources_descriptor.get_disk_descriptor(storage_class)
                    disk_descriptor.capacity_limits += values.get("capacity", 0)
                    disk_descriptor.bandwidth_limits += values.get("bandwidth", 0)

                if resources.get("network", {}).get("bandwidth"):
                    resources_descriptor = account_descriptor.get_resources_descriptor(cluster_name, segment_id)
                    resources_descriptor.network.bandwidth_limits += resources["network"]["bandwidth"]

                # NOTE(review): this file does not import `division` from
                # __future__, so under Python 2 an integer capacity is
                # floor-divided here — confirm the truncation is intended.
                if resources.get("cpu", {}).get("capacity"):
                    account_descriptor.get_resources_descriptor(cluster_name, segment_id).cpu.cpu_limits += \
                        resources["cpu"]["capacity"] / 1000

        # Usage: same structure as the limits above, accumulated into the
        # *_usage counters.
        if not isinstance(resource_usage, YsonEntity):
            for segment_id, resources in resource_usage.items():
                if segment_id not in VALID_SEGMENTS:
                    continue
                for storage_class, values in resources.get("disk_per_storage_class", {}).items():
                    resources_descriptor = account_descriptor.get_resources_descriptor(cluster_name, segment_id)
                    disk_descriptor = resources_descriptor.get_disk_descriptor(storage_class)
                    disk_descriptor.capacity_usage += values.get("capacity", 0)
                    disk_descriptor.bandwidth_usage += values.get("bandwidth", 0)

                if resources.get("network", {}).get("bandwidth"):
                    resources_descriptor = account_descriptor.get_resources_descriptor(cluster_name, segment_id)
                    resources_descriptor.network.bandwidth_usage += resources["network"]["bandwidth"]

                # Same division caveat as for the CPU limits above.
                if resources.get("cpu", {}).get("capacity"):
                    account_descriptor.get_resources_descriptor(cluster_name, segment_id).cpu.cpu_usage += \
                    resources["cpu"]["capacity"] / 1000


class NodeDescriptor:
    """A node with its rack, segment and accumulated network bandwidth figures."""

    __slots__ = [
        "fqdn", "rack", "segment", "total_network_bandwidth", "used_network_bandwidth", "free_network_bandwidth"
    ]

    def __init__(self):
        # Identity and placement; filled in by process_nodes().
        self.fqdn = self.rack = self.segment = None
        # Bandwidth counters start at zero.
        self.total_network_bandwidth = 0
        self.free_network_bandwidth = 0
        self.used_network_bandwidth = 0


def process_nodes(yp_client, cluster_name, node_map):
    """
    Read the nodes of a cluster and their network resources into ``node_map``.

    First pass: every node in a valid segment with a non-empty last-seen time
    gets a descriptor. Second pass: the bandwidth values of ``network``
    resources are added to the descriptor of their node; resources of nodes
    without a descriptor are ignored.

    :param yp_client: YP client of the cluster
    :param cluster_name: accepted for symmetry with the other ``process_*``
        functions; not used here
    :param node_map: mapping keyed by node id; it is indexed directly in the
        first pass, so a missing key must be created by the mapping itself
        (YpStat passes a ``defaultdict(NodeDescriptor)``)
    """
    node_selectors = ["/meta/id", "/labels/segment", "/labels/topology/rack", "/status/last_seen_time"]
    for node_row in select_objects(yp_client, "node", selectors=node_selectors, limit=5000):
        node_id, segment, rack, last_seen_time = extract_fields(node_row)
        if segment not in VALID_SEGMENTS or not last_seen_time:
            continue
        descriptor = node_map[node_id]  # type: NodeDescriptor
        descriptor.fqdn = node_id
        descriptor.rack = rack
        descriptor.segment = segment

    bandwidth_selectors = [
        "/meta/node_id",
        "/spec/network/total_bandwidth",
        "/status/free/network/guaranteed_bandwidth",
        "/status/used/network/guaranteed_bandwidth",
    ]
    network_rows = select_objects(
        yp_client, "resource", selectors=bandwidth_selectors, filter="[/meta/kind]='network'", limit=5000
    )
    for resource_row in network_rows:
        node_id, total_bandwidth, free_bandwidth, used_bandwidth = extract_fields(resource_row)
        descriptor = node_map.get(node_id)
        if descriptor:
            descriptor.total_network_bandwidth += total_bandwidth
            descriptor.free_network_bandwidth += free_bandwidth
            descriptor.used_network_bandwidth += used_bandwidth


def compute_stats(cluster_name, address, service_map, account_map, node_map):
    """
    Collect pods, accounts and nodes of one cluster into the given maps.

    A client is created for ``address`` and used as a context manager around
    the three collection steps, which run in the order pods, accounts, nodes.
    """
    with create_yp_client(address) as yp_client:
        collection_steps = (
            (process_pods, service_map),
            (process_accounts, account_map),
            (process_nodes, node_map),
        )
        for collect, target_map in collection_steps:
            collect(yp_client, cluster_name, target_map)


class YpStat:
    """
    Statistics collected from all YP clusters, stamped with the cache version
    and the creation time so a stored copy can be validated later.
    """

    def __init__(self):
        self._version = YP_CACHE_VERSION
        self._timestamp = time.time()

        # Auto-creating maps, filled by the process_* functions.
        self.service_map = defaultdict(ServiceDescriptor)
        self.account_map = defaultdict(AccountDescriptor)
        self.node_map = defaultdict(NodeDescriptor)

    def freeze(self):
        """Replace the three maps with plain ``dict`` copies."""
        for map_name in ("service_map", "account_map", "node_map"):
            setattr(self, map_name, dict(getattr(self, map_name)))

    def check_actual(self):
        """
        Validate that the object matches the current cache version and is not
        older than YP_CACHE_TTL seconds.

        :raises Exception: ``"version mismatch"`` or ``"cache expired"``
        """
        if self._version != YP_CACHE_VERSION:
            raise Exception("version mismatch")
        age = time.time() - self._timestamp
        if age > YP_CACHE_TTL:
            raise Exception("cache expired")


def compute_yp_stat():
    """
    Collect statistics from every cluster in YP_ADDRESSES into a new YpStat
    and freeze it before returning.

    :rtype: YpStat
    """
    yp_stat = YpStat()
    for yp_address in YP_ADDRESSES:
        name = get_cluster_from_address(yp_address)
        logging.info("Get stats from YP %s", name)
        compute_stats(
            name,
            yp_address,
            yp_stat.service_map,
            yp_stat.account_map,
            yp_stat.node_map,
        )
    yp_stat.freeze()
    return yp_stat


def get_yp_stat_cached():
    """
    Return YP statistics, reusing the on-disk pickle cache while it is valid.

    The cache file ``yp_stat.tmp`` is resolved relative to the current working
    directory. When it cannot be read or unpickled, or fails
    YpStat.check_actual(), the reason is logged, the statistics are recomputed
    with compute_yp_stat() and the file is rewritten.

    :rtype: YpStat
    """
    file_name = "yp_stat.tmp"
    try:
        # NOTE: unpickling can execute arbitrary code, so this file must only
        # ever be produced by this tool itself.
        with open(file_name, "rb") as stream:
            result = pickle.load(stream)  # type: YpStat
        result.check_actual()
        return result
    except Exception as err:
        # Any failure just means the cache is unusable (missing, corrupt,
        # outdated); record why instead of swallowing it silently.
        logging.info("YP stat cache %s is not usable (%s), recomputing", file_name, err)

    # Recompute outside the handler so that a failure here is reported on its
    # own rather than as a consequence of the cache miss.
    result = compute_yp_stat()
    result.check_actual()
    with open(file_name, "wb") as stream:
        pickle.dump(result, stream, protocol=pickle.HIGHEST_PROTOCOL)
    return result
