from __future__ import unicode_literals
import collections

from sepelib.core import config

from infra.mc_rsc.src import podutil


class ClusterRevisionIndex(object):
    """Index of pod ids grouped by pod agent payload revision within one cluster.

    Query methods (``find``, ``exclude``, ``all``, ``revisions``, ``count``,
    ``count_all``) never mutate the index: asking about an unknown revision
    returns an empty result without creating an entry for it.
    """

    def __init__(self):
        # revision -> set of pod ids
        self._d = collections.defaultdict(set)

    def add(self, pod):
        """Register ``pod.meta.id`` under ``pod.spec.pod_agent_payload.spec.revision``."""
        r = pod.spec.pod_agent_payload.spec.revision
        self._d[r].add(pod.meta.id)

    def find(self, revision):
        """Return an iterator over ids of pods that have ``revision``."""
        # .get() rather than [] so a lookup does not insert an empty set
        # into the defaultdict.
        return iter(self._d.get(revision, ()))

    def exclude(self, revision):
        """Yield ids of pods whose revision is not ``revision``.

        Revisions are visited in descending order; ids within a single
        revision come in set iteration order.
        """
        # Sort on the key only: keys are unique, and the pod-id sets must
        # never take part in the comparison.
        sorted_by_revision = sorted(self._d.items(), key=lambda item: item[0], reverse=True)
        for r, p_ids in sorted_by_revision:
            if r == revision:
                continue
            for p_id in p_ids:
                yield p_id

    def all(self):
        """Yield ids of every indexed pod."""
        for r in self.revisions():
            for p_id in self.find(r):
                yield p_id

    def revisions(self):
        """Yield each revision that currently has at least one pod."""
        for r, p_ids in self._d.items():
            if p_ids:
                yield r

    def count(self, revision):
        """Return the number of pods that have ``revision``."""
        return len(self._d.get(revision, ()))

    def count_all(self):
        """Return the total number of indexed pods."""
        return sum(len(p_ids) for p_ids in self._d.values())


class MultiClusterRevisionIndex(object):
    """Collection of per-cluster ``ClusterRevisionIndex`` objects.

    Pod-yielding queries (``find``, ``exclude``, ``all``) produce
    ``(cluster, pod_id)`` pairs, walking clusters in dict iteration order.
    """

    def __init__(self):
        # cluster -> ClusterRevisionIndex; entries are created on first access.
        self._d = collections.defaultdict(ClusterRevisionIndex)

    def _pairs(self, select):
        """Lazily yield ``(cluster, pod_id)`` for each id ``select(cluster_index)`` yields."""
        for cluster, cluster_idx in self._d.iteritems():
            for pod_id in select(cluster_idx):
                yield cluster, pod_id

    def get_cluster_index(self, cluster):
        """Return the index of ``cluster``; an empty one is created if absent."""
        return self._d[cluster]

    def add_cluster_index(self, idx, cluster):
        """Install ``idx`` as the index of ``cluster``, replacing any existing one."""
        self._d[cluster] = idx

    def add(self, pod, cluster):
        """Register ``pod`` in the index of ``cluster``."""
        self.get_cluster_index(cluster).add(pod)

    def find(self, revision):
        """Yield ``(cluster, pod_id)`` for pods that have ``revision``."""
        return self._pairs(lambda cluster_idx: cluster_idx.find(revision))

    def exclude(self, revision):
        """Yield ``(cluster, pod_id)`` for pods whose revision is not ``revision``."""
        return self._pairs(lambda cluster_idx: cluster_idx.exclude(revision))

    def all(self):
        """Yield ``(cluster, pod_id)`` for every indexed pod."""
        return self._pairs(lambda cluster_idx: cluster_idx.all())

    def first(self):
        """Return the first ``(cluster, pod_id)`` pair, or None when empty."""
        for pair in self.all():
            return pair
        return None

    def revisions(self):
        """Yield every non-empty revision across clusters, each at most once."""
        emitted = set()
        for cluster_idx in self._d.itervalues():
            for rev in cluster_idx.revisions():
                if rev not in emitted:
                    emitted.add(rev)
                    yield rev

    def count(self, revision):
        """Return the number of pods with ``revision`` across all clusters."""
        total = 0
        for cluster_idx in self._d.itervalues():
            total += cluster_idx.count(revision)
        return total

    def count_all(self):
        """Return the total number of pods across all clusters."""
        total = 0
        for cluster_idx in self._d.itervalues():
            total += cluster_idx.count_all()
        return total


class Maintenance(object):
    """Pods touched by a single maintenance, identified by its node set id.

    Instances sort non-disruptive before disruptive, then by
    ``maintenance_node_set_id``.
    """
    __slots__ = ('acknowledged', 'waiting_in_progress', 'waiting_ready', 'maintenance_node_set_id', 'disruptive')

    def __init__(self):
        self.maintenance_node_set_id = ''
        self.disruptive = False
        # Per-cluster revision indexes of affected pods: acknowledged ones,
        # plus still-waiting ones split by readiness.
        self.acknowledged = MultiClusterRevisionIndex()
        self.waiting_in_progress = MultiClusterRevisionIndex()
        self.waiting_ready = MultiClusterRevisionIndex()

    def __lt__(self, other):
        assert isinstance(other, Maintenance)
        # Lexicographic: disruptiveness first (False < True), then node set id.
        mine = (self.disruptive, self.maintenance_node_set_id)
        theirs = (other.disruptive, other.maintenance_node_set_id)
        return mine < theirs


class MultiClusterCurrentState(object):
    """Pods of a multi-cluster set, bucketed by status and pending disruptions.

    Each pod passed to ``add`` lands in exactly one status bucket
    (``target_state_removed``, ``ready``, ``failed`` or ``in_progress``).

    A pod can also land in further buckets, unless its eviction is
    acknowledged or its target state is removed. It may go into
    ``removing_delegate_required``. It may also go into at most one
    disruption bucket: ``node_alerted``, ``maintenance_overtimed``,
    ``maintenances[...]``, or ``to_evict_ready``/``to_evict_in_progress``.
    """

    # MultiClusterRevisionIndex attributes carried over, per cluster, by
    # make_filtered_by_cluster_current_state.
    _CLUSTER_INDEX_ATTRS = (
        'in_progress',
        'failed',
        'ready',
        'to_evict_in_progress',
        'to_evict_ready',
        'node_alerted',
        'maintenance_overtimed',
        'target_state_removed',
        'removing_delegate_required',
    )

    def __init__(self, max_tolerable_downtime_seconds):
        """
        :param max_tolerable_downtime_seconds: threshold handed to the podutil
            maintenance checks. A falsy value falls back to the config value
            ``controller.default_max_tolerable_downtime_seconds``, which is
            read with a fallback of 0.
        """
        default_max_tolerable_downtime_seconds = config.get_value('controller.default_max_tolerable_downtime_seconds', 0)
        self.max_tolerable_downtime_seconds = max_tolerable_downtime_seconds or default_max_tolerable_downtime_seconds

        # Status buckets: every added pod goes into exactly one of these.
        self.in_progress = MultiClusterRevisionIndex()
        self.ready = MultiClusterRevisionIndex()
        self.failed = MultiClusterRevisionIndex()
        self.target_state_removed = MultiClusterRevisionIndex()

        # Disruption buckets.
        self.node_alerted = MultiClusterRevisionIndex()
        self.maintenance_overtimed = MultiClusterRevisionIndex()
        self.to_evict_in_progress = MultiClusterRevisionIndex()
        self.to_evict_ready = MultiClusterRevisionIndex()
        # node_set_id -> Maintenance
        self.maintenances = collections.defaultdict(Maintenance)
        # Ready pods that are node-alerted, maintenance-overtimed or under
        # in-progress maintenance; they are counted as dead despite being ready.
        self.implicitly_dead = 0
        self.removing_delegate_required = MultiClusterRevisionIndex()

    def _touch_maintenance(self, pod):
        """Return the Maintenance for ``pod``'s node set, merging in its disruptiveness."""
        node_set_id = pod.status.maintenance.info.node_set_id
        m = self.maintenances[node_set_id]
        m.disruptive |= podutil.is_maintenance_disruptive(pod, self.max_tolerable_downtime_seconds)
        m.maintenance_node_set_id = node_set_id
        return m

    def add(self, pod, cluster):
        """Classify ``pod``, which belongs to ``cluster``, into the buckets."""
        is_ready = podutil.is_pod_ready(pod)
        is_failed = podutil.is_pod_failed(pod)
        is_removed = podutil.is_target_state_removed(pod)
        if is_removed:
            self.target_state_removed.add(pod, cluster)
        elif is_ready:
            # if is_ready == True and is_failed == True, pod is alive (DEPLOY-4899)
            self.ready.add(pod, cluster)
        elif is_failed:
            self.failed.add(pod, cluster)
        else:
            self.in_progress.add(pod, cluster)

        if podutil.is_pod_eviction_acknowledged(pod) or is_removed:
            # The pod will be evicted or removed anyway; nothing else to track.
            return

        if podutil.is_pod_marked_removing_delegate(pod) and not podutil.is_pod_eviction_requested_ignore_reason(pod):
            self.removing_delegate_required.add(pod, cluster)

        if podutil.is_pod_node_alerted(pod):
            self.node_alerted.add(pod, cluster)
            if is_ready:
                self.implicitly_dead += 1
        elif podutil.is_maintenance_overtimed(pod, self.max_tolerable_downtime_seconds):
            self.maintenance_overtimed.add(pod, cluster)
            if is_ready:
                self.implicitly_dead += 1
        elif podutil.is_maintenance_in_progress(pod):
            self._touch_maintenance(pod).acknowledged.add(pod, cluster)
            if is_ready:
                self.implicitly_dead += 1
        elif podutil.is_pod_eviction_requested(pod):
            if is_ready:
                self.to_evict_ready.add(pod, cluster)
            else:
                self.to_evict_in_progress.add(pod, cluster)
        elif podutil.is_maintenance_requested(pod):
            m = self._touch_maintenance(pod)
            if is_ready:
                m.waiting_ready.add(pod, cluster)
            else:
                m.waiting_in_progress.add(pod, cluster)

    def count(self, revision):
        """Return the number of pods with ``revision`` across all status buckets."""
        return self.ready.count(revision) + self.in_progress.count(revision) \
            + self.target_state_removed.count(revision) + self.failed.count(revision)

    def count_all(self):
        """Return the number of pods across all status buckets."""
        return self.ready.count_all() + self.in_progress.count_all() \
            + self.target_state_removed.count_all() + self.failed.count_all()

    def revisions(self):
        """Yield each revision present in any status bucket, at most once.

        The buckets are the same ones that ``count``/``count_all`` cover,
        including ``failed``.
        """
        seen = set()
        # failed goes last so revisions yielded before it keep their order.
        for index in (self.ready, self.in_progress, self.target_state_removed, self.failed):
            for r in index.revisions():
                if r in seen:
                    continue
                seen.add(r)
                yield r

    def make_filtered_by_cluster_current_state(self, cluster):
        """Return a new state containing only ``cluster``'s pods.

        The per-cluster ``ClusterRevisionIndex`` objects are shared, not
        copied. A later ``add`` for ``cluster`` on either state is therefore
        visible in both.

        Looking up ``cluster`` in this state's indexes creates empty entries
        when they are absent. Every known maintenance is recreated in the
        result, even one with no pods in ``cluster``.
        """
        rv = MultiClusterCurrentState(self.max_tolerable_downtime_seconds)
        for attr in self._CLUSTER_INDEX_ATTRS:
            getattr(rv, attr).add_cluster_index(
                cluster=cluster,
                idx=getattr(self, attr).get_cluster_index(cluster)
            )
        # NOTE(review): implicitly_dead is not carried over and stays 0 in the
        # result; it is a global counter that is not tracked per cluster.
        for node_set_id, maintenance in self.maintenances.iteritems():
            m = rv.maintenances[node_set_id]
            m.maintenance_node_set_id = maintenance.maintenance_node_set_id
            m.disruptive = maintenance.disruptive
            m.acknowledged.add_cluster_index(
                cluster=cluster,
                idx=maintenance.acknowledged.get_cluster_index(cluster)
            )
            m.waiting_in_progress.add_cluster_index(
                cluster=cluster,
                idx=maintenance.waiting_in_progress.get_cluster_index(cluster)
            )
            m.waiting_ready.add_cluster_index(
                cluster=cluster,
                idx=maintenance.waiting_ready.get_cluster_index(cluster)
            )
        return rv

    def __str__(self):
        """Summarize per-revision status counts followed by eviction/removal totals."""
        parts = []
        for r in self.revisions():
            p = '{}: ready={}/in_progress={}/failed={}'.format(
                r,
                self.ready.count(r),
                self.in_progress.count(r),
                self.failed.count(r),
            )
            parts.append(p)
        parts.append('to_evict_in_progress: {}'.format(self.to_evict_in_progress.count_all()))
        parts.append('to_evict_ready: {}'.format(self.to_evict_ready.count_all()))
        parts.append('target_state_removed: {}'.format(self.target_state_removed.count_all()))
        return ', '.join(parts)
