# coding: utf-8

from __future__ import absolute_import, print_function

import time
from collections import defaultdict
import pickle
import logging

from yt.yson import YsonEntity
from infra.ya_salt.proto import ya_salt_pb2
import infra.rtc.iolimit_ticketer.yp_model as yp_model

# Bump whenever the pickled HostManagerStat layout changes so stale caches are rejected.
HM_CACHE_VERSION = 3
# Maximum age of the on-disk cache, in seconds (one day).
HM_CACHE_TTL = 24 * 3600
# Nodes not seen for longer than this many seconds (two weeks) are ignored.
HOST_TTL = 14 * 24 * 3600


class DmDisk(object):
    """Plain record of one disk taken from ``HostmanStatus.node_info.disks``.

    Inherits from ``object`` so that ``__slots__`` is honoured on Python 2 too;
    on an old-style class it is silently ignored and every instance gets a
    ``__dict__``.
    """

    __slots__ = ["device_path", "capacity_bytes", "storage_class"]

    def __init__(self):
        # All fields start unset; they are filled from a disk ``spec`` message.
        self.device_path = None
        self.capacity_bytes = None
        self.storage_class = None


class OopsDisk(object):
    """Plain record of one entry from ``HostmanStatus.node_info.oops_disks``.

    Inherits from ``object`` so that ``__slots__`` is honoured on Python 2 too;
    on an old-style class it is silently ignored.
    """

    __slots__ = ["mount_point", "disk_type", "fs_size", "hw_info", "slaves", "name"]

    def __init__(self):
        # All fields start unset; they are copied from an oops-disk spec message.
        self.mount_point = None
        self.disk_type = None
        self.fs_size = None
        self.hw_info = None
        self.slaves = None
        self.name = None


class HostManagerDescriptor(object):
    """Host-manager data collected for a single YP node.

    Inherits from ``object`` so that ``__slots__`` is honoured on Python 2 too;
    on an old-style class it is silently ignored.
    """

    __slots__ = ["fqdn", "cluster", "disks", "oops_disks", "kernel_version"]

    def __init__(self):
        self.fqdn = None  # YP node id
        self.cluster = None  # YP cluster name the node was read from
        self.disks = None  # list of DmDisk once set_disks() has run
        self.oops_disks = None  # list of OopsDisk once set_oops_disks() has run
        self.kernel_version = None

    @staticmethod
    def _dm_disk_from_spec(spec):
        """Build a DmDisk from a disk ``spec`` message."""
        disk = DmDisk()
        disk.device_path = spec.device_path
        disk.capacity_bytes = spec.capacity_bytes
        disk.storage_class = spec.storage_class
        return disk

    @staticmethod
    def _oops_disk_from_spec(spec):
        """Build an OopsDisk from an oops-disk message."""
        disk = OopsDisk()
        disk.mount_point = spec.mountPoint
        disk.disk_type = spec.type
        disk.fs_size = spec.fsSize
        disk.hw_info = spec.hwInfo
        # Copy into a plain list so the descriptor does not keep a reference
        # to the (possibly protobuf-backed) repeated field.
        disk.slaves = list(spec.slaves)
        disk.name = spec.name
        return disk

    def set_disks(self, disks):
        """Replace ``self.disks`` with DmDisk copies of ``disks``.

        :param disks: iterable of objects exposing ``spec.device_path``,
            ``spec.capacity_bytes`` and ``spec.storage_class``.
        """
        self.disks = [self._dm_disk_from_spec(obj.spec) for obj in disks]

    def set_oops_disks(self, oops_disks):
        """Replace ``self.oops_disks`` with OopsDisk copies of ``oops_disks``.

        :param oops_disks: iterable of objects exposing ``mountPoint``, ``type``,
            ``fsSize``, ``hwInfo``, ``slaves`` and ``name``.
        """
        self.oops_disks = [self._oops_disk_from_spec(spec) for spec in oops_disks]


def process_nodes(yp_client, cluster_name, node_map):
    """Fill ``node_map`` with host-manager data for recently seen default-segment nodes.

    Every ``node`` object selected through ``yp_client`` is skipped when it is
    outside ``yp_model.DEFAULT_SEGMENT``, was last seen more than ``HOST_TTL``
    seconds ago, or has an empty (``YsonEntity``) ``/status/host_manager``.

    For each remaining node, ``node_map[node_id]`` gets ``fqdn`` and ``cluster``
    set. If that descriptor has no ``oops_disks`` yet, the serialized
    ``HostmanStatus`` protobuf is decoded and the kernel version, disks and
    OOPS disks are copied into it.

    :param yp_client: client handed to ``yp_model.select_objects``.
    :param cluster_name: value stored into each descriptor's ``cluster``.
    :param node_map: node id -> HostManagerDescriptor; it is indexed directly,
        so presumably a defaultdict creating missing descriptors — confirm for callers.
    """
    wanted_fields = ["/meta/id", "/labels/segment", "/status/last_seen_time", "/status/host_manager"]
    for node in yp_model.select_objects(yp_client, "node", selectors=wanted_fields):
        node_id, segment, last_seen_time, host_manager = yp_model.extract_fields(node)

        # Short-circuit keeps the original check order: the timestamp is only
        # inspected for default-segment nodes. last_seen_time is treated as
        # microseconds (divided by 10**6 before comparing with time.time()).
        if (segment != yp_model.DEFAULT_SEGMENT
                or last_seen_time / 10**6 + HOST_TTL < time.time()
                or isinstance(host_manager, YsonEntity)):
            continue

        descriptor = node_map[node_id]  # type: HostManagerDescriptor
        descriptor.fqdn, descriptor.cluster = node_id, cluster_name

        # NOTE(review): when the descriptor was already filled (oops_disks set),
        # only fqdn/cluster are refreshed, so ``cluster`` may end up naming a
        # different cluster than the one the disk data came from — confirm intended.
        if descriptor.oops_disks is None:
            status = ya_salt_pb2.HostmanStatus()
            status.MergeFromString(host_manager["value"])
            node_info = status.node_info
            descriptor.kernel_version = node_info.os_info.kernel
            descriptor.set_disks(node_info.disks)
            descriptor.set_oops_disks(node_info.oops_disks)


def compute_stats(cluster_name, address, node_map):
    """Connect to the YP endpoint at ``address`` and collect its nodes into ``node_map``.

    The client from ``yp_model.create_yp_client`` is used as a context manager,
    so it is released when ``process_nodes`` finishes or raises.
    """
    client_ctx = yp_model.create_yp_client(address)
    with client_ctx as client:
        process_nodes(client, cluster_name, node_map)


class HostManagerStat:
    """Container for per-node host-manager descriptors plus cache metadata.

    ``_version`` and ``_timestamp`` are captured at construction time and
    checked by ``check_actual`` after the object is loaded back from a cache.
    """

    def __init__(self):
        # Wall-clock creation time in seconds, used for TTL checks.
        self._timestamp = time.time()
        self._version = HM_CACHE_VERSION
        # Missing node ids get a fresh HostManagerDescriptor on first access.
        self.node_map = defaultdict(HostManagerDescriptor)

    def freeze(self):
        """Turn ``node_map`` into a plain dict so lookups no longer create entries."""
        frozen = dict(self.node_map)
        self.node_map = frozen

    def check_actual(self):
        """Raise ``Exception`` if this object has a foreign version or is older than HM_CACHE_TTL."""
        same_version = self._version == HM_CACHE_VERSION
        if not same_version:
            raise Exception("version mismatch")
        age = time.time() - self._timestamp
        if age > HM_CACHE_TTL:
            raise Exception("cache expired")


def compute_hm_stat():
    """Build a fresh HostManagerStat by querying every address in ``yp_model.YP_ADDRESSES``.

    Clusters are processed in the order of ``YP_ADDRESSES``; the resulting
    node map is frozen into a plain dict before being returned.
    """
    stat = HostManagerStat()
    node_map = stat.node_map
    for yp_address in yp_model.YP_ADDRESSES:
        cluster = yp_model.get_cluster_from_address(yp_address)
        logging.info("Get host manager from YP %s", cluster)
        compute_stats(cluster, yp_address, node_map)
    stat.freeze()
    return stat


def get_hm_stat_cached(file_name="hm_stat.tmp"):
    """Return host-manager statistics, reusing an on-disk pickle cache when possible.

    The cache file is unpickled and validated with ``check_actual``. If it is
    missing, unreadable, not unpicklable or stale, the statistics are rebuilt
    with ``compute_hm_stat`` and written back to ``file_name``.

    :param file_name: path of the pickle cache file; relative paths resolve
        against the current working directory.
    :rtype: HostManagerStat
    :raises Exception: anything raised by ``compute_hm_stat``, ``check_actual``
        or the cache write while rebuilding.
    """
    try:
        with open(file_name, "rb") as stream:
            # NOTE(review): pickle.load can execute arbitrary code; this is only
            # safe while nobody untrusted can write to the cache file.
            cached = pickle.load(stream)  # type: HostManagerStat
        cached.check_actual()
    except Exception as exc:
        # Any failure here just means the cache is unusable; record why and rebuild.
        logging.info("Host manager cache %s is unusable, recomputing: %s", file_name, exc)
    else:
        return cached

    # Rebuild outside the except handler so rebuild errors are not chained onto
    # (and reported as happening "during handling of") the cache-load error.
    result = compute_hm_stat()
    result.check_actual()
    with open(file_name, "wb") as stream:
        pickle.dump(result, stream, protocol=pickle.HIGHEST_PROTOCOL)
    return result
