import collections
import datetime
import itertools
import logging

import pymongo
import gevent

from libraries.online_state import InstanceState
from libraries.topology import load_tags, get_latest_success_commit, get_topology_by_version
from libraries.topology.groups import get_version_for_commit, is_trunk, get_commit_of_version
from libraries.topology.searcher_lookup import load_group_trunk, load_instances
from libraries.utils import shortname, singleton, memoize
from libraries.containers.clusters import Cluster, Allocation, ClusterType, Smooth

from reports_ctrl import ReportsController
from utils import find_group
from libraries.mongo_params import HEARTBEAT_MONGODB


class DeadMongo(Exception):
    """Signals that an update cycle exceeded its time budget.

    Used as the exception class handed to ``gevent.Timeout`` in
    ``Updater.start``; the message assumes the hang is in MongoDB.
    """

    message = 'mongo is dead'


# Sentinel "version" meaning the latest successful trunk commit. It is what
# Updater.find_version_for_group returns when a group has no pinned non-trunk
# version, and Updater.instances_for_group serves it from the cached trunk
# searcherlookup instead of loading a specific commit.
TRUNK = 'trunk'


class TopologyCache(object):
    """Cached snapshot of the trunk topology plus the range of known tags.

    :meth:`update` refreshes everything; the properties only expose the last
    successfully loaded values.
    """

    def __init__(self):
        self._trunk_commit = 0
        self._trunk_host_group_mapping = {}
        self._trunk_topology = {}
        # Per-group searcherlookup for ``_trunk_commit``, filled lazily by
        # trunk_searcherlookup() and replaced when the trunk commit changes.
        self._trunk_searcherlookup = {}
        self._min_tag = None
        self._max_tag = None

    def update(self):
        """Refresh the trunk snapshot, then the min/max tags."""
        self._update_trunk()
        self._get_tags()

    def _update_trunk(self):
        """Reload trunk topology if the latest successful commit has changed."""
        latest_commit = get_latest_success_commit()
        if latest_commit != self.trunk_commit:
            fresh_topology = get_topology_by_version(get_version_for_commit(latest_commit))
            host_group_mapping = map_groups_to_hosts(fresh_topology)
            # Everything is loaded before any attribute is assigned, so a
            # failure above leaves the previous snapshot untouched. The
            # searcherlookup cache belonged to the old commit and is dropped.
            (
                self._trunk_host_group_mapping,
                self._trunk_topology,
                self._trunk_commit,
                self._trunk_searcherlookup
            ) = (
                host_group_mapping,
                fresh_topology,
                latest_commit,
                {}
            )

    def _get_tags(self):
        """Store the tags with the smallest and the largest commit number."""
        self._min_tag, self._max_tag = self._tag_bounds(load_tags())

    @staticmethod
    def _tag_bounds(tags):
        """Return ``(min_tag, max_tag)`` ordered by numeric commit.

        :param tags: mapping whose values are dicts with a ``'commit'`` entry
            (anything ``int()`` accepts) and a ``'tag'`` entry.
        :raises IndexError: if ``tags`` is empty.
        """
        # Sort once instead of twice; the stable sort keeps the original
        # tie-breaking for both ends (first for min, last for max).
        ordered = sorted(tags.values(), key=lambda record: int(record['commit']))
        return ordered[0]['tag'], ordered[-1]['tag']

    @property
    def trunk_host_group_mapping(self):
        """``{host: [group, ...]}`` for the cached trunk topology."""
        return self._trunk_host_group_mapping

    @property
    def trunk_topology(self):
        """Topology loaded for the cached trunk commit."""
        return self._trunk_topology

    @property
    def trunk_commit(self):
        """Commit the cached trunk data belongs to (0 before first update)."""
        return self._trunk_commit

    @property
    def min_tag(self):
        """Tag with the smallest commit, or None before the first update."""
        return self._min_tag

    @property
    def max_tag(self):
        """Tag with the largest commit, or None before the first update."""
        return self._max_tag

    def trunk_searcherlookup(self, group):
        """Return the trunk searcherlookup for ``group``, loading it on first use."""
        # Bind the current dict locally: if the commit changes meanwhile, the
        # result lands in the discarded cache, not in the fresh one.
        cache = self._trunk_searcherlookup
        if group not in cache:
            cache[group] = load_group_trunk(self._trunk_commit, group)
        return cache[group]


class OnlineStateCache(object):
    """Tracks recently-alive hosts and the group state derived from them.

    A host is alive when its ``hostinfo`` record in the ``heartbeat``
    database has ``last_update`` newer than ``host_liveness_timeout`` ago.
    """

    host_liveness_timeout = datetime.timedelta(minutes=30)

    def __init__(self):
        self._db = get_mongo_client()['heartbeat']
        self._alive_hosts_cache = set()
        self._instance_state = InstanceState()

    def update(self):
        """Refresh the alive-host set, then recompute group state from it."""
        self._get_alive()
        self._get_groups()

    def _get_alive(self):
        """Replace the alive-host set with short names of recently seen hosts."""
        # NOTE(review): cutoff is naive local time; if writers store
        # ``last_update`` in UTC the window is skewed by the UTC offset — confirm.
        cutoff = datetime.datetime.now() - self.host_liveness_timeout
        records = self._db['hostinfo'].find(
            {'last_update': {'$gt': cutoff}},
            {'host': 1, '_id': 0}
        )
        self._alive_hosts_cache = {shortname(record['host']) for record in records}

    def _get_groups(self):
        """Feed alive hosts into InstanceState and recalculate group liveness."""
        state = self._instance_state
        state.update(self._alive_hosts_cache, self._db, sleep_method=gevent.sleep)
        state.groups_state.recalc_alive(sleep_method=gevent.sleep)

    @property
    def groups(self):
        """Group state object held by the underlying InstanceState."""
        return self._instance_state.groups_state


class Updater(object):
    """Periodically reloads clusters, allocations, smooth clusters and slots.

    :meth:`start` runs the refresh loop; results are exposed through
    read-only properties and replaced wholesale on each cycle.
    """

    def __init__(self):
        self._active = False
        self._last_update = datetime.datetime.min
        self._sleep_time = 30  # seconds

        self._clusters = {}
        self._type_clusters = {}
        self._host_allocations = {}
        self._type_host_groups = {}
        self._group_slots = {}
        self._smooth_clusters = {}

        self._topology_cache = TopologyCache()
        self._online_state_cache = OnlineStateCache()
        self._report_ctrl = ReportsController()

    @property
    def report_ctrl(self):
        return self._report_ctrl

    @property
    def topology(self):
        return self._topology_cache

    @property
    def online_state(self):
        return self._online_state_cache

    @property
    def clusters(self):
        """``{cluster_name: Cluster}``."""
        return self._clusters

    @property
    def host_allocations(self):
        """``{host: [Allocation, ...]}`` of non-expired allocations."""
        return self._host_allocations

    @property
    def type_host_groups(self):
        """``{cluster_type: {host: set(groups)}}``."""
        return self._type_host_groups

    @property
    def type_clusters(self):
        """``{cluster_type: {cluster_name: Cluster}}``."""
        return self._type_clusters

    @property
    def smooth_clusters(self):
        """``{name: [Smooth, ...]}`` sorted by ``since`` descending."""
        return self._smooth_clusters

    @property
    def group_slots(self):
        """``{group: {host: {port: instance}}}``."""
        return self._group_slots

    # load data methods

    def _get_clusters(self):
        """Load clusters and compute, per cluster type, the groups of each host.

        A host is counted for a cluster's group unless it belongs to one of
        that cluster's ``ban_groups``; if the group config sets
        ``intersect_with``, the host must also belong to that group.
        """
        self._clusters = {a.name: a for a in Cluster.list()}
        type_host_groups = {t: {} for t in ClusterType.ALL}
        type_clusters = {t: {} for t in ClusterType.ALL}

        banned_host_clusters = self._get_banned_hosts()
        for cluster in self._clusters.values():
            type_clusters[cluster.type][cluster.name] = cluster
            host_groups = type_host_groups[cluster.type]
            for group, conf in cluster.groups.items():
                # The intersection group's hosts do not depend on the host
                # being examined: load them once per group (each load may hit
                # the topology backend) and use a set for O(1) membership.
                allowed_hosts = None
                if conf.intersect_with:
                    allowed_hosts = set(self.hosts_for_group(
                        conf.intersect_with, self.find_version_for_group(conf.intersect_with)
                    ))
                for host in self.instances_for_group(group, self.find_version_for_group(conf.agents_group)):
                    if cluster.name in banned_host_clusters[host]:
                        continue
                    if allowed_hosts is not None and host not in allowed_hosts:
                        continue
                    host_groups.setdefault(host, set()).add(group)

        self._type_host_groups = type_host_groups
        self._type_clusters = type_clusters

    def _get_banned_hosts(self):
        """Return ``{host: set(cluster_names)}`` of clusters that ban each host."""
        banned_host_clusters = collections.defaultdict(set)
        for cluster in self._clusters.values():
            for group in cluster.ban_groups:
                for host in self.hosts_for_group(group, self.find_version_for_group(group)):
                    banned_host_clusters[host].add(cluster.name)
        return banned_host_clusters

    def _get_allocations(self):
        """Group allocations by host; expired ones get ``remove()`` and are skipped."""
        host_allocations = {}
        for a in Allocation.list():
            if a.expired:
                a.remove()
                continue
            host_allocations.setdefault(a.host, []).append(a)
        self._host_allocations = host_allocations

    def _get_smooth(self):
        """Group Smooth records by name, each list sorted by ``since`` descending."""
        smooth_clusters = {}
        for smooth in Smooth.list():
            smooth_clusters.setdefault(smooth.name, []).append(smooth)
        for entries in smooth_clusters.values():
            entries.sort(key=lambda x: x.since, reverse=True)
        self._smooth_clusters = smooth_clusters

    def _get_slots(self):
        """Map every group of every regular and smooth cluster to its instances."""
        g_slots = {}
        # chain() instead of sum(lists, start): sum() over dict views raises
        # TypeError on Python 3, and repeated list concatenation is quadratic.
        all_clusters = itertools.chain(self.clusters.values(), *self.smooth_clusters.values())
        for cluster in all_clusters:
            for group, conf in cluster.groups.items():
                g_slots[group] = self.instances_for_group(group, self.find_version_for_group(conf.agents_group))
        self._group_slots = g_slots

    def _update_all(self):
        """Run one full refresh cycle; order matters (later steps use earlier data)."""
        self._online_state_cache.update()
        logging.info('updated online state')
        self._topology_cache.update()
        logging.info('updated topology')
        self._get_clusters()
        logging.info('done clusters')
        self._get_allocations()
        logging.info('done allocations')
        self._get_smooth()
        logging.info('loaded smooth')
        self._get_slots()
        logging.info('loaded slots')
        self._last_update = datetime.datetime.now()

    def start(self):
        """Run the refresh loop until :meth:`stop` is called.

        Each cycle is bounded by a 10-minute ``gevent.Timeout`` raising
        DeadMongo; after a timeout the next cycle starts immediately, after
        any other error the loop sleeps ``_sleep_time`` seconds first.
        """
        self._report_ctrl.start()
        self._active = True
        while self._active:
            try:
                with gevent.Timeout(10 * 60, DeadMongo):
                    self._update_all()
                gevent.sleep(self._sleep_time)
            except DeadMongo:
                logging.error('mongo died')
            except Exception as ex:
                logging.exception(ex)
                gevent.sleep(self._sleep_time)

    def stop(self):
        """Ask the loop to exit after the current cycle and kill the reports controller."""
        self._active = False
        self._report_ctrl.kill()

    @property
    def ready(self):
        """True once at least one full refresh cycle has completed."""
        return self._last_update > datetime.datetime.min

    def instances_for_group(self, group, version):
        """Return ``{host: {port: instance}}`` for ``group`` at ``version``.

        A ``version`` equal to TRUNK is served from the cached trunk
        searcherlookup; any other value is resolved to a commit and loaded.
        """
        # Compare by value, not identity: an equal but distinct 'trunk'
        # string must not be routed to get_commit_of_version().
        if version == TRUNK:
            data = self.topology.trunk_searcherlookup(group)
            instances = data['instances'] if data else []
        else:
            instances = load_instances(get_commit_of_version(version), group) or []
        res = {}
        for one in instances:
            res.setdefault(one['hostname'], {})[one['port']] = one
        return res

    def hosts_for_group(self, group, version):
        """Return the distinct hosts serving ``group`` at ``version`` as a list."""
        return list(set(self.instances_for_group(group, version)))

    def find_version_for_group(self, group):
        """Return the group's pinned non-trunk version, or TRUNK otherwise.

        Falsy ``group`` and groups whose version is missing or trunk map to TRUNK.
        """
        if group:
            _, ver = find_group(self.online_state.groups, group)
            if ver and not is_trunk(ver):
                return ver
        return TRUNK


def map_groups_to_hosts(data):
    """Invert ``{group: {'hosts': [...]}}`` into ``{host: [group, ...]}``.

    Each host's group list follows the iteration order of ``data``.
    """
    by_host = {}
    for group_name in data:
        group_info = data[group_name]
        for host in group_info['hosts']:
            by_host.setdefault(host, []).append(group_name)
    return by_host


@singleton
def get_mongo_client():
    """Build the replica-set client for the heartbeat MongoDB.

    All connection settings come from ``HEARTBEAT_MONGODB``; ``@singleton``
    presumably makes callers share one client — confirm in libraries.utils.
    """
    # NOTE(review): MongoReplicaSetClient is deprecated since PyMongo 3 and
    # removed in PyMongo 4 (MongoClient handles replica sets) — check the
    # pinned PyMongo version before upgrading.
    # NOTE(review): localThresholdMS=30000 makes every member within 30 s of
    # the fastest eligible for reads — confirm this is intended.
    uri = HEARTBEAT_MONGODB.uri
    options = {
        'replicaSet': HEARTBEAT_MONGODB.replicaset,
        'localThresholdMS': 30000,
        'connectTimeoutMS': 5000,
        'read_preference': HEARTBEAT_MONGODB.read_preference,
    }
    return pymongo.MongoReplicaSetClient(uri, **options)


@memoize
def updater():
    """Construct the Updater; ``@memoize`` presumably returns the same instance on later calls."""
    instance = Updater()
    return instance
