from collections import defaultdict

import datetime

from libraries import topology

# Maximum age of an instance's most recent "alive" report for the instance to
# still be counted as alive by GroupState.
_alive_threshold = datetime.timedelta(hours=2, minutes=30)


class GroupState(object):
    """Liveness bookkeeping for the instances of one group revision.

    Every instance stores the timestamp of its most recent "alive" report;
    ``datetime.datetime.min`` means "never reported / marked dead". An
    instance counts as alive while ``now - timestamp < _alive_threshold``.
    Instances are indexed by host, taken from ``instance[0]``.

    The set of hosts and instances is fixed at construction time; reports
    that mention unknown hosts or instances are ignored.
    """

    def __init__(self, topology_instances):
        """Index ``topology_instances`` by host, with every instance dead.

        :param topology_instances: sized iterable of hashable instance
            records whose first element is the host (presumably tuples
            produced by the topology library -- confirm against caller).
        """
        self._len = len(topology_instances)
        by_host = defaultdict(dict)
        for instance in topology_instances:
            by_host[instance[0]][instance] = datetime.datetime.min
        # Freeze into a plain dict: later lookups of unknown hosts must not
        # silently add empty host entries (which iter_hosts would expose).
        self._by_host = dict(by_host)

        self._alive_instances_cnt = 0
        self._all_instances_cnt = 0
        self._last_modified = datetime.datetime.min
        self._last_recalced = datetime.datetime.min
        # Earliest moment at which the cached alive count may become stale
        # because the oldest alive report crosses the threshold.
        self._next_expiry = datetime.datetime.min

    def __len__(self):
        """Return the number of instance records the group was built from."""
        return self._len

    def __repr__(self):
        alive = sum(1 for _ in self.iter_alive_instances())
        return '<{} {}/{}>'.format(self.__class__.__name__, alive, len(self))

    def iter_hosts(self):
        """Return an iterator over the hosts carrying instances of the group."""
        return iter(self._by_host)

    def set_alive(self, host, instance, when):
        """Record that ``instance`` on ``host`` reported alive at ``when``.

        The stored timestamp never moves backwards. Reports for hosts or
        instances that are not part of the group are ignored.
        """
        instances = self._by_host.get(host)
        if instances is None or instance not in instances:
            return
        instances[instance] = max(instances[instance], when)
        self._last_modified = datetime.datetime.now()

    def set_dead(self, host):
        """Mark every instance on ``host`` as dead; unknown hosts are ignored."""
        instances = self._by_host.get(host)
        if not instances:
            return
        for instance in instances:
            instances[instance] = datetime.datetime.min
        self._last_modified = datetime.datetime.now()

    def iter_alive_instances(self):
        """Yield the instances that are alive, judged against a single 'now'."""
        now = datetime.datetime.now()
        for instances in self._by_host.values():
            for instance, when in instances.items():
                if now - when < _alive_threshold:
                    yield instance

    def iter_all_instances(self):
        """Yield every instance of the group, grouped by host."""
        for instances in self._by_host.values():
            for instance in instances:
                yield instance

    def alive_instances_cnt(self):
        """Return the alive-instance count computed by the last recalc_cnt()."""
        return self._alive_instances_cnt

    def all_instances_cnt(self):
        """Return the total-instance count computed by the last recalc_cnt()."""
        return self._all_instances_cnt

    def recalc_cnt(self):
        """Refresh the cached alive/all counters if they may be out of date.

        The cache is reused only while nothing was modified since the last
        recalculation AND no alive instance can have expired yet, i.e. 'now'
        is still before the oldest alive report plus ``_alive_threshold``.
        """
        now = datetime.datetime.now()
        if self._last_modified < self._last_recalced and now < self._next_expiry:
            return

        alive_cnt, all_cnt = 0, 0
        oldest_alive = None
        for instances in self._by_host.values():
            for when in instances.values():
                all_cnt += 1
                if now - when < _alive_threshold:
                    alive_cnt += 1
                    if oldest_alive is None or when < oldest_alive:
                        oldest_alive = when

        self._all_instances_cnt = all_cnt
        self._alive_instances_cnt = alive_cnt
        self._last_recalced = now
        if oldest_alive is None:
            # Nothing alive: only a modification can change the counts.
            self._next_expiry = datetime.datetime.max
        else:
            self._next_expiry = oldest_alive + _alive_threshold


class OnlineState(object):
    """Index of GroupState objects, by group name and by host.

    ``_by_id`` maps group name -> {revision: GroupState}.
    ``_by_host`` maps host -> {(name, revision): GroupState} for every host on
    which the group's topology places instances. Groups are created lazily
    by the first ``update`` that mentions a given (name, revision).
    """

    def __init__(self):
        self._by_host = defaultdict(dict)
        self._by_id = defaultdict(dict)

    def groups_on_host(self, host):
        """Return {(name, revision): GroupState} for groups placed on ``host``.

        Unknown hosts yield an empty dict without being added to the index.
        """
        return self._by_host.get(host, {})

    def group_versions(self, group_name):
        """Return {revision: GroupState} for ``group_name`` (empty if unknown).

        Querying an unknown name does not make it appear in ``groups()``.
        """
        return self._by_id.get(group_name, {})

    def groups(self):
        """Return the names of all groups that have at least one known revision."""
        return self._by_id.keys()

    def reset_host(self, host):
        """Mark every instance on ``host`` dead in every group placed there."""
        for group in self._by_host.get(host, {}).values():
            group.set_dead(host)

    def update(self, group_name, group_revision, host, instance, when):
        """Record an "alive" report, creating the group state if needed."""
        group = self._ensure_group(group_name, group_revision)
        group.set_alive(host, instance, when)

    def _ensure_group(self, name, revision):
        """Return the GroupState for (name, revision), loading it on first use."""
        versions = self._by_id.get(name, {})
        if revision in versions:
            return versions[revision]

        # Load before touching the indexes so that a failing load leaves no
        # empty entries behind.
        group = GroupState(_load_group(name, revision))
        self._by_id[name][revision] = group

        for host in group.iter_hosts():
            self._by_host[host][name, revision] = group

        return group

    def recalc_alive(self, sleep_method=None):
        """Refresh cached counters of every group, calling ``sleep_method`` after each.

        Iterates over snapshots: ``sleep_method`` presumably yields to other
        tasks (confirm with callers), which may call ``update`` and add groups
        while this loop is suspended.
        """
        for versions in list(self._by_id.values()):
            for group in list(versions.values()):
                group.recalc_cnt()
                if sleep_method:
                    sleep_method()


def _load_group(name, revision):
    """Fetch the instance list of one group revision from the topology library.

    Thin indirection over ``topology.load_group2``; the result is passed
    as-is to ``GroupState``.
    """
    loaded = topology.load_group2(name, revision)
    return loaded
