from collections import Counter

from cachetools.func import ttl_cache

from sepelib.core import config
from sepelib.core.constants import HOUR_SECONDS, MINUTE_SECONDS, WEEK_SECONDS
from walle.util import db_cache

# Percentage of an aggregate's hosts that must map to the same rack model
# (via their platform system) before that model is assigned to the aggregate;
# the share must be strictly greater than this value (see RackTopology.from_hosts).
RACK_IDENTIFICATION_THRESHOLD = 50
# NOTE(review): the two hot-cache settings below are not referenced in this
# part of the file; presumably they are used by callers elsewhere — confirm
# before changing them.
RACK_TOPOLOGY_HOT_CACHE_SIZE = 10000
RACK_TOPOLOGY_HOT_CACHE_TTL = 30 * MINUTE_SECONDS
# TTL passed to db_cache both when reading and when writing cached rack topologies.
RACK_TOPOLOGY_COLD_CACHE_TTL = 2 * HOUR_SECONDS
# Prefix of the db_cache key for a rack topology; the aggregate is appended after ":".
RACK_TOPOLOGY_CACHE_ID = "rack_topology"


class RackInfo:
    """Rack model description built from one entry of the ``rack_map`` config.

    Attributes:
        name: rack model name (the entry's key in ``rack_map``).
        systems: platform system names that identify hosts of this rack model.
        ranges: list of ``(min, max)`` unit (slot) ranges, inclusive on both ends,
            in configuration order.
    """

    def __init__(self, name, data):
        self.name = name
        self.systems = data.get("systems", [])
        # Each configured slot range is a mapping with "min" and "max" keys.
        self.ranges = [(slot_range["min"], slot_range["max"]) for slot_range in data.get("slot_ranges", [])]

    def get_unit_range_index(self, unit):
        """Return the index of the first slot range that contains ``unit``, or -1.

        ``unit`` goes through ``int()``, so numeric strings are accepted. The
        conversion happens only while ranges are being checked, so a rack model
        with no ranges returns -1 for any ``unit``.
        """
        for index, (low, high) in enumerate(self.ranges):
            if low <= int(unit) <= high:
                return index
        return -1


class RackTopology:
    """Placement of an aggregate's hosts over the slot ranges of a rack model.

    Attributes:
        aggregate: identifier of the host group described; also used to build
            the db_cache key when the topology is saved.
        rack_model: detected rack model name, or None when no model passed
            ``RACK_IDENTIFICATION_THRESHOLD``.
        hosts_ranges: mapping of normalized hostname -> slot range index.
            The index is 0 for every host when the rack model is unknown, and
            -1 when a host's unit lies outside every range of the model.
        total_ranges: number of slot ranges of the rack model (1 when unknown).
    """

    def __init__(self, aggregate, rack_model, hosts_ranges, total_ranges):
        self.aggregate = aggregate
        self.rack_model = rack_model
        self.hosts_ranges = hosts_ranges
        self.total_ranges = total_ranges

    def __eq__(self, other):
        # Comparing with an unrelated object must not raise AttributeError;
        # NotImplemented lets Python fall back to its default comparison.
        if not isinstance(other, RackTopology):
            return NotImplemented
        return (
            self.aggregate == other.aggregate
            and self.rack_model == other.rack_model
            and self.hosts_ranges == other.hosts_ranges
            and self.total_ranges == other.total_ranges
        )

    def __repr__(self):
        return "{}(aggregate={!r}, rack_model={!r}, hosts_ranges={!r}, total_ranges={!r})".format(
            type(self).__name__, self.aggregate, self.rack_model, self.hosts_ranges, self.total_ranges
        )

    @staticmethod
    def _detect_rack_model(hosts):
        """Return the most common rack model among ``hosts`` if its share exceeds the threshold, else None."""
        total_hosts = len(hosts)
        racks_score = Counter()
        for host in hosts:
            rack_model = get_system_rack(host.platform.system)
            if rack_model is not None:
                racks_score[rack_model] += 1
        if not racks_score:
            return None
        rack_model, score = racks_score.most_common(1)[0]
        # Exact integer comparison: `score * 100 // total_hosts` would floor
        # e.g. 50.5% down to 50% and reject a genuine majority.
        if score * 100 > RACK_IDENTIFICATION_THRESHOLD * total_hosts:
            return rack_model
        return None

    @classmethod
    def from_hosts(cls, aggregate, hosts):
        """Build the topology of ``aggregate`` from an iterable of its ``hosts``.

        The rack model is the one mapped (through the hosts' platform systems)
        for more than ``RACK_IDENTIFICATION_THRESHOLD`` percent of all hosts.
        Hosts whose ``name`` is None are counted for detection but are not
        recorded in ``hosts_ranges``.
        """
        # ``hosts`` is traversed twice; materialize it so a one-shot iterable
        # (e.g. a generator) is not exhausted by the detection pass.
        hosts = list(hosts)
        rack_model = cls._detect_rack_model(hosts)
        rack_info = get_rack_info(rack_model)

        total_ranges = 1 if rack_info is None else len(rack_info.ranges)
        hosts_ranges = {}
        for host in hosts:
            if host.name is None:
                continue
            range_index = 0 if rack_info is None else rack_info.get_unit_range_index(host.location.unit)
            hosts_ranges[normalize_hostname(host.name)] = range_index
        return cls(aggregate, rack_model, hosts_ranges, total_ranges)

    def serialize(self):
        """Return a plain dict whose keys match the constructor's parameters."""
        return {
            "aggregate": self.aggregate,
            "rack_model": self.rack_model,
            "hosts_ranges": self.hosts_ranges,
            "total_ranges": self.total_ranges,
        }

    def get_host_range(self, hostname):
        """Return the slot range index recorded for ``hostname``, or None if unknown.

        ``hostname`` is normalized the same way the stored keys were, so the
        original dotted name can be passed directly.
        """
        if self.hosts_ranges is None:
            return None
        return self.hosts_ranges.get(normalize_hostname(hostname))


# MongoDB field names must not contain dots, so hostnames used as keys in stored documents are normalized.
def normalize_hostname(name):
    """Return ``name`` with every "." replaced by "_", making it usable as a MongoDB field name."""
    return "_".join(name.split("."))


@ttl_cache(maxsize=1, ttl=WEEK_SECONDS)
def _get_system_rack_map():
    """Build a ``{platform system: rack model name}`` lookup from the ``rack_map`` config.

    The result is cached for a week. If several rack models list the same
    system, the model that comes later in ``rack_map`` wins.
    """
    return {
        system: model_name
        for model_name, model_data in config.get_value("rack_map").items()
        for system in RackInfo(model_name, model_data).systems
    }


@ttl_cache(maxsize=1, ttl=WEEK_SECONDS)
def _get_rack_models():
    """Parse every ``rack_map`` config entry into a ``RackInfo``, keyed by model name (cached for a week)."""
    rack_map = config.get_value("rack_map")
    return {model_name: RackInfo(model_name, model_data) for model_name, model_data in rack_map.items()}


def _get_rack_topology_cache_id(aggretate):
    """Return the db_cache key used to store the rack topology of ``aggretate``.

    NOTE(review): the parameter name is a misspelling of "aggregate". It is
    kept unchanged in case some caller passes it by keyword.
    """
    return "{prefix}:{aggregate}".format(prefix=RACK_TOPOLOGY_CACHE_ID, aggregate=aggretate)


def get_system_rack(system):
    """Return the rack model name configured for platform ``system``, or None if it is not mapped."""
    system_to_rack = _get_system_rack_map()
    return system_to_rack.get(system)


def get_rack_info(rack_model):
    """Return the ``RackInfo`` for ``rack_model``; None when the model is None or not configured."""
    return None if rack_model is None else _get_rack_models().get(rack_model)


def get_rack_topology(aggregate):
    """Load the cached ``RackTopology`` of ``aggregate`` from db_cache.

    The cached value is expected to be the dict produced by
    ``RackTopology.serialize`` (as written by ``save_rack_topology``).

    NOTE(review): behavior on a cache miss depends on
    ``db_cache.get_cache_value``. If it returns None rather than raising,
    ``RackTopology(**None)`` raises TypeError here — confirm.
    """
    cache_key = _get_rack_topology_cache_id(aggregate)
    cached = db_cache.get_cache_value(cache_key, RACK_TOPOLOGY_COLD_CACHE_TTL)
    return RackTopology(**cached)


def save_rack_topology(rack_topology):
    """Write ``rack_topology`` to db_cache under its aggregate's key, using the cold-cache TTL."""
    cache_key = _get_rack_topology_cache_id(rack_topology.aggregate)
    payload = rack_topology.serialize()
    db_cache.set_value(cache_key, payload, RACK_TOPOLOGY_COLD_CACHE_TTL)
