from __future__ import unicode_literals
import collections
import logging

from sepelib.core import config
from infra.swatlib.gevent import geventutil as gutil
from infra.mc_rsc.src import consts


# Logger used to report why two storages / pods were found to differ.
log = logging.getLogger(name='comparer')


def compare_pods(p1, p2, compare_attr_paths):
    """Check that two pods agree on every attribute path in ``compare_attr_paths``.

    Each path is a sequence of attribute names resolved with ``getattr``
    starting from the pod itself (``['spec', 'x']`` -> ``pod.spec.x``).
    Both pods are walked in lock-step, one attribute at a time.

    Returns False (after logging both pod ids, the ``/``-joined path and the
    two differing values) at the first path whose values are not equal,
    True if every path matches.
    """
    for attr_path in compare_attr_paths:
        left, right = p1, p2
        for attr_name in attr_path:
            left, right = getattr(left, attr_name), getattr(right, attr_name)
        if left != right:
            log.error('pod %s is not equal to pod %s: %s:\n%s\n%s',
                      p1.meta.id, p2.meta.id, '/'.join(attr_path), left, right)
            return False
    return True


def compare_dict_keys(d1, d2):
    """Return True when iterating ``d1`` and ``d2`` yields exactly the same keys.

    Both inputs are walked through ``gutil.gevent_idle_iter``; presumably
    this lets other greenlets run while large dicts are scanned (verify
    against geventutil).
    """
    pending = set(gutil.gevent_idle_iter(d1))
    for key in gutil.gevent_idle_iter(d2):
        try:
            # Each key of d2 must match one not-yet-matched key of d1.
            pending.remove(key)
        except KeyError:
            return False
    # Anything left over exists only in d1.
    return len(pending) == 0


class ClusterStorage(object):
    """In-memory store of one cluster's objects, keyed by ``obj.meta.id``.

    The ``updated_time`` / ``removed_time`` arguments are accepted but not
    used here; subclasses use them to maintain extra indexes.
    """

    def __init__(self):
        # obj.meta.id -> obj
        self._index = {}

    @classmethod
    def make_from_objects(cls, objs_with_timestamps):
        """Build a fresh storage of this class from ``(obj, updated_time)`` pairs."""
        storage = cls()
        for pair in gutil.gevent_idle_iter(objs_with_timestamps):
            obj, updated_time = pair
            storage.put(obj, updated_time)
        return storage

    def get(self, obj_id):
        """Return the stored object with id ``obj_id``, or None if absent."""
        return self._index.get(obj_id, None)

    def put(self, obj, updated_time=None):
        """Insert ``obj`` under its ``meta.id``, overwriting any previous object."""
        obj_id = obj.meta.id
        self._index[obj_id] = obj

    def remove(self, obj_id, removed_time=None):
        """Forget the object with id ``obj_id``; unknown ids are ignored."""
        if obj_id in self._index:
            del self._index[obj_id]

    def replace(self, obj, updated_time=None):
        """Remove then re-put ``obj`` so subclasses can refresh their indexes."""
        self.remove(obj.meta.id, None)
        self.put(obj, updated_time)

    def list(self):
        """Return an iterator over the stored objects (Python 2 ``itervalues``)."""
        return self._index.itervalues()

    def size(self):
        """Return the number of stored objects."""
        return len(self._index)

    def sync_with_objects(self, objs_with_timestamps):
        """Replace the whole content with the given ``(obj, updated_time)`` pairs."""
        self._index = self.make_from_objects(objs_with_timestamps)._index

    def compare(self, s):
        """Return True if ``s`` is truthy and holds an equal object index."""
        return bool(s) and self._index == s._index


class RelationClusterStorage(ClusterStorage):
    """Cluster storage for relations with a secondary index on ``meta.to_fqid``.

    ``_to_fqid_index`` maps a to_fqid to the set of ids of stored relations
    whose ``meta.to_fqid`` equals it. ``put`` and ``remove`` keep it in sync
    with ``_index``; empty buckets are dropped.
    """

    def __init__(self):
        super(RelationClusterStorage, self).__init__()
        self._to_fqid_index = collections.defaultdict(set)

    def _discard_from_to_fqid_index(self, to_fqid, obj_id):
        """Drop ``obj_id`` from the ``to_fqid`` bucket, deleting the bucket once empty."""
        ids = self._to_fqid_index.get(to_fqid)
        if not ids:
            return
        ids.discard(obj_id)
        if not ids:
            self._to_fqid_index.pop(to_fqid, None)

    def list_by_to_fqid(self, to_fqid):
        """Return the stored relations indexed under ``to_fqid`` (possibly empty)."""
        rv = []
        # .get() rather than [] so a lookup does not create an empty bucket
        # in the defaultdict.
        ids = self._to_fqid_index.get(to_fqid, [])
        for r_id in gutil.gevent_idle_iter(ids):
            r = self._index.get(r_id)
            if r:
                rv.append(r)
        return rv

    def find(self, from_fqid, to_fqid):
        """Return relations indexed under ``to_fqid`` whose ``meta.from_fqid`` is ``from_fqid``."""
        rv = []
        to_fqid_rels = self.list_by_to_fqid(to_fqid)
        for r in gutil.gevent_idle_iter(to_fqid_rels):
            if r.meta.from_fqid == from_fqid:
                rv.append(r)
        return rv

    def put(self, obj, updated_time=None):
        """Store ``obj`` and index its id under ``obj.meta.to_fqid``.

        If an object with the same id is already stored under a different
        to_fqid, its id is first removed from the old bucket; otherwise
        ``list_by_to_fqid(old)`` would keep returning the new object.
        """
        obj_id = obj.meta.id
        to_fqid = obj.meta.to_fqid
        prev = self.get(obj_id)
        if prev is not None and prev.meta.to_fqid != to_fqid:
            self._discard_from_to_fqid_index(prev.meta.to_fqid, obj_id)
        super(RelationClusterStorage, self).put(obj, updated_time)
        self._to_fqid_index[to_fqid].add(obj_id)

    def remove(self, obj_id, removed_time=None):
        """Remove the relation ``obj_id`` and its to_fqid index entry; unknown ids are ignored."""
        r = self.get(obj_id)
        if not r:
            return
        self._discard_from_to_fqid_index(r.meta.to_fqid, obj_id)
        super(RelationClusterStorage, self).remove(obj_id, removed_time)

    def sync_with_objects(self, objs_with_timestamps):
        """Replace all content and indexes with those built from ``(obj, updated_time)`` pairs."""
        s = self.make_from_objects(objs_with_timestamps)
        self._index = s._index
        self._to_fqid_index = s._to_fqid_index


class PodClusterStorage(ClusterStorage):
    """Cluster storage for pods with pod-set and last-deploy-timestamp indexes.

    ``_ps_index`` maps ``meta.pod_set_id`` to the set of ids of stored pods in
    that pod set; empty sets are dropped. ``_last_deploy_timestamp_index``
    maps a pod set id to the largest truthy ``updated_time`` / ``removed_time``
    passed to ``put`` / ``remove`` for any of its pods since the last sync.
    """

    def __init__(self):
        super(PodClusterStorage, self).__init__()
        self._ps_index = collections.defaultdict(set)
        self._last_deploy_timestamp_index = {}

    def _update_last_deploy_timestamp_index(self, pod, updated_time):
        """Raise the pod set's last-deploy timestamp to ``updated_time`` if it is newer.

        Falsy ``updated_time`` values (None, 0) are ignored.
        """
        if not updated_time:
            return
        ps_id = pod.meta.pod_set_id
        deploy_ts = self._last_deploy_timestamp_index.get(ps_id)
        if deploy_ts is None or deploy_ts < updated_time:
            self._last_deploy_timestamp_index[ps_id] = updated_time

    def _discard_from_ps_index(self, ps_id, pod_id):
        """Drop ``pod_id`` from the ``ps_id`` bucket, deleting the bucket once empty."""
        ids = self._ps_index.get(ps_id)
        if not ids:
            return
        ids.discard(pod_id)
        if not ids:
            self._ps_index.pop(ps_id, None)

    def list_by_ps_id(self, ps_id):
        """Return the stored pods indexed under pod set ``ps_id`` (possibly empty)."""
        rv = []
        # .get() rather than [] so a lookup does not create an empty bucket.
        ids = self._ps_index.get(ps_id, [])
        for p_id in gutil.gevent_idle_iter(ids):
            p = self._index.get(p_id)
            if p:
                rv.append(p)
        return rv

    def get_last_deploy_timestamp(self, ps_id):
        """Return the recorded last-deploy timestamp of ``ps_id``, or None."""
        return self._last_deploy_timestamp_index.get(ps_id)

    def put(self, obj, updated_time=None):
        """Store pod ``obj``, index it by pod set and bump the pod set's timestamp.

        If a pod with the same id is already stored under a different pod set,
        its id is first removed from the old bucket so ``list_by_ps_id`` does
        not keep returning it there.
        """
        pod_id = obj.meta.id
        ps_id = obj.meta.pod_set_id
        prev = self.get(pod_id)
        if prev is not None and prev.meta.pod_set_id != ps_id:
            self._discard_from_ps_index(prev.meta.pod_set_id, pod_id)
        super(PodClusterStorage, self).put(obj, updated_time)
        self._ps_index[ps_id].add(pod_id)
        self._update_last_deploy_timestamp_index(obj, updated_time)

    def remove(self, obj_id, removed_time=None):
        """Remove pod ``obj_id``, unindex it and bump its pod set's timestamp.

        Unknown ids are ignored.
        """
        p = self.get(obj_id)
        if not p:
            return
        self._discard_from_ps_index(p.meta.pod_set_id, obj_id)
        self._update_last_deploy_timestamp_index(p, removed_time)
        super(PodClusterStorage, self).remove(obj_id, removed_time)

    def sync_with_objects(self, objs_with_timestamps):
        """Replace all content and indexes with those built from ``(obj, updated_time)`` pairs."""
        s = self.make_from_objects(objs_with_timestamps)
        self._index = s._index
        self._ps_index = s._ps_index
        self._last_deploy_timestamp_index = s._last_deploy_timestamp_index

    def compare(self, s):
        """Return True if ``s`` holds the same pods and equivalent indexes.

        The pod id sets must match. Each pod pair must agree on the selectors
        from config ``watches.compare_pod_selectors`` (default
        ``consts.POD_WATCH_SELECTORS``). The pod-set index must match, and
        last-deploy timestamps must match for every pod set in this storage's
        index. The first mismatch is logged and False is returned. A missing
        ``s`` (None) also yields False, as in ``ClusterStorage.compare``.
        """
        if s is None:
            log.error('pod storage to compare with is missing')
            return False
        if not compare_dict_keys(self._index, s._index):
            log.error('pod indexes are not equal:\n%s\n%s', self._index.keys(), s._index.keys())
            return False
        compare_selectors = config.get_value('watches.compare_pod_selectors', consts.POD_WATCH_SELECTORS)
        # Selectors look like "/spec/foo"; turn each into ['spec', 'foo'].
        compare_attr_paths = [sel.strip("/").split("/") for sel in compare_selectors]
        for p_id, p1 in gutil.gevent_idle_iter(self._index.iteritems()):
            # Safe: the key sets were verified to be equal above.
            p2 = s._index[p_id]
            if not compare_pods(p1, p2, compare_attr_paths):
                return False
        for ps_id, this_p_ids in gutil.gevent_idle_iter(self._ps_index.iteritems()):
            p_ids = s._ps_index.get(ps_id)
            if not this_p_ids and not p_ids:
                continue
            if this_p_ids != p_ids:
                log.error('ps indexes are not equal for ps %s\n%s\n%s', ps_id, this_p_ids, p_ids)
                return False
            this_deploy_ts = self._last_deploy_timestamp_index.get(ps_id)
            deploy_ts = s._last_deploy_timestamp_index.get(ps_id)
            if this_deploy_ts != deploy_ts:
                log.error('deploy ts indexes are not equal for ps %s\n%s\n%s', ps_id, this_deploy_ts, deploy_ts)
                return False
        # The loop above only walks our pod sets; catch extra ones in ``s``.
        if len(self._ps_index) != len(s._ps_index):
            log.error('ps index sizes are not equal: %d != %d', len(self._ps_index), len(s._ps_index))
            return False
        return True


class MultiClusterStorage(object):
    """Routes storage operations to a per-cluster storage keyed by cluster name.

    Operations on a cluster with no registered storage do nothing: ``get``,
    ``put`` and ``sync_with_objects`` return None, ``size`` returns 0 and
    ``list`` yields nothing.
    """

    def __init__(self):
        # cluster -> per-cluster storage object
        self._storages = {}

    def add_storage(self, s, cluster):
        """Register ``s`` as the storage for ``cluster``, replacing any previous one."""
        self._storages[cluster] = s

    def get(self, obj_id, cluster):
        """Return object ``obj_id`` from ``cluster``'s storage, or None."""
        s = self._storages.get(cluster)
        if not s:
            return
        return s.get(obj_id)

    def put(self, obj, timestamps, cluster):
        """Store ``obj`` in ``cluster``'s storage.

        ``timestamps`` is forwarded as the second positional argument of the
        storage's ``put`` (``updated_time`` for ClusterStorage).
        """
        s = self._storages.get(cluster)
        if not s:
            return
        return s.put(obj, timestamps)

    def list(self, cluster):
        """Yield the objects of ``cluster``'s storage (nothing for an unknown cluster)."""
        s = self._storages.get(cluster)
        if not s:
            return
        for o in s.list():
            yield o

    def size(self, cluster):
        """Return the number of objects in ``cluster``'s storage, 0 if unknown."""
        s = self._storages.get(cluster)
        if not s:
            return 0
        return s.size()

    def sync_with_objects(self, objs_with_timestamps, cluster):
        """Replace ``cluster``'s content with the given ``(obj, updated_time)`` pairs."""
        s = self._storages.get(cluster)
        if not s:
            return
        s.sync_with_objects(objs_with_timestamps)

    def compare(self, s):
        """Return True if ``s`` has the same clusters and every cluster's storage compares equal.

        ``s`` being None, or the two sides registering different sets of
        clusters, yields False.
        """
        if s is None:
            return False
        # Require the same cluster set on both sides. Otherwise a cluster
        # missing from ``s`` would be compared against None, and a cluster
        # present only in ``s`` would never be checked at all.
        if set(self._storages) != set(s._storages):
            return False
        for c, cluster_s in self._storages.items():
            if not cluster_s.compare(s._storages[c]):
                return False
        return True


class PodMultiClusterStorage(MultiClusterStorage):
    """Multi-cluster storage whose per-cluster storages support pod-set lookups.

    Each registered storage must provide ``list_by_ps_id`` and
    ``get_last_deploy_timestamp``; unknown clusters yield ``[]`` and None.
    """

    def list_by_ps_id(self, ps_id, cluster):
        """Return the pods of pod set ``ps_id`` in ``cluster``, ``[]`` for an unknown cluster."""
        storage = self._storages.get(cluster)
        return storage.list_by_ps_id(ps_id) if storage else []

    def get_last_deploy_timestamp(self, ps_id, cluster):
        """Return ``ps_id``'s last-deploy timestamp in ``cluster``, None for an unknown cluster."""
        storage = self._storages.get(cluster)
        return storage.get_last_deploy_timestamp(ps_id) if storage else None
