# coding: utf-8
import itertools

from infra.awacs.proto import model_pb2
from awacs.model.util import newer
from .vector import L3Vector, L3BalancerVersion, BackendVersion, EndpointSetVersion
import six


class L3BalancerStateHolder(object):
    """
    In-memory index of revision versions recorded in an L3 balancer state message.

    Wraps a ``model_pb2.L3BalancerState`` and, for the L3 balancer itself and for each of its
    backends and endpoint sets, remembers four versions derived from the ``l3_statuses`` records:

    * latest             -- newest revision seen;
    * latest valid       -- newest revision whose ``validated`` condition is ``'True'``;
    * latest in progress -- newest revision whose ``in_progress`` condition is ``'True'``;
    * active             -- newest revision whose ``active`` condition is ``'True'``.

    "Newest" is decided by ``awacs.model.util.newer``. Latest/valid/active versions are combined
    with the previously stored value on every :meth:`update` (they are not reset), whereas
    in-progress versions are recomputed from the current statuses only. Backends and endpoint sets
    that are no longer present in the state are forgotten entirely.
    """

    def __init__(self, namespace_id, l3_balancer_id, l3_balancer_state_pb=None):
        """
        :param namespace_id: namespace the tracked L3 balancer belongs to
        :param l3_balancer_id: id of the tracked L3 balancer
        :param l3_balancer_state_pb: optional initial state; when given, :meth:`update` is applied to it
        """
        self._namespace_id = namespace_id
        self._l3_balancer_id = l3_balancer_id
        self._l3_balancer_state_pb = l3_balancer_state_pb  # type: model_pb2.L3BalancerState

        # Versions of the L3 balancer itself (None until a matching revision is seen).
        self._l3_balancer_latest_valid_version = None
        self._l3_balancer_latest_version = None
        self._l3_balancer_active_version = None
        self._l3_balancer_latest_in_progress_version = None

        # Per-backend versions, keyed by backend id.
        self._backend_latest_versions = {}
        self._backend_latest_valid_versions = {}
        self._backend_latest_in_progress_versions = {}
        self._backend_active_versions = {}

        # Per-endpoint-set versions, keyed by endpoint set id.
        self._endpoint_set_latest_versions = {}
        self._endpoint_set_latest_valid_versions = {}
        self._endpoint_set_active_versions = {}
        self._endpoint_set_latest_in_progress_versions = {}

        if l3_balancer_state_pb is not None:
            self.update(l3_balancer_state_pb)

    @property
    def curr_vector(self):
        """Vector of the latest known versions (dicts are copied)."""
        return L3Vector(self._l3_balancer_latest_version,
                        dict(self._backend_latest_versions),
                        dict(self._endpoint_set_latest_versions))

    @property
    def valid_vector(self):
        """Vector of the latest validated versions (dicts are copied)."""
        return L3Vector(self._l3_balancer_latest_valid_version,
                        dict(self._backend_latest_valid_versions),
                        dict(self._endpoint_set_latest_valid_versions))

    @property
    def in_progress_vector(self):
        """Vector of the latest in-progress versions (dicts are copied)."""
        return L3Vector(self._l3_balancer_latest_in_progress_version,
                        dict(self._backend_latest_in_progress_versions),
                        dict(self._endpoint_set_latest_in_progress_versions))

    @property
    def active_vector(self):
        """Vector of the active versions (dicts are copied)."""
        return L3Vector(self._l3_balancer_active_version,
                        dict(self._backend_active_versions),
                        dict(self._endpoint_set_active_versions))

    def get_in_progress_config_ids(self):
        """
        Collect the L3MGR configs referenced by in-progress revisions.

        Scans the ``l3_statuses`` of the L3 balancer itself and of every backend and endpoint set.
        For each status whose ``in_progress`` condition is ``'True'``, every config listed in its
        L3MGR transport meta is reported. Returns an empty dict if no state has been loaded yet.

        :returns: mapping of ``(service_id, config_id)`` to the config's ``ctime`` as a datetime
        :rtype: dict[(six.text_type, six.text_type), datetime]
        """
        state_pb = self._l3_balancer_state_pb
        if state_pb is None:
            # Holder was constructed without a state and update() has not been called yet.
            return {}

        status_lists = [state_pb.l3_balancer.l3_statuses]
        status_lists.extend(pb.l3_statuses for pb in six.itervalues(state_pb.backends))
        status_lists.extend(pb.l3_statuses for pb in six.itervalues(state_pb.endpoint_sets))

        rv = {}
        for rev_pb in itertools.chain.from_iterable(status_lists):
            if rev_pb.in_progress.status != 'True':
                continue
            meta_pb = rev_pb.in_progress.meta  # type: model_pb2.L3ConfigTransportMeta
            # Only the L3MGR transport is expected for in-progress L3 revisions.
            assert meta_pb.type == model_pb2.L3MGR
            for config_pb in meta_pb.l3mgr.configs:
                rv[(config_pb.service_id, config_pb.config_id)] = config_pb.ctime.ToDatetime()
        return rv

    def _forget_backend(self, backend_id):
        """Drop every stored version of the given backend."""
        self._backend_latest_valid_versions.pop(backend_id, None)
        self._backend_latest_versions.pop(backend_id, None)
        self._backend_active_versions.pop(backend_id, None)
        self._backend_latest_in_progress_versions.pop(backend_id, None)

    def _forget_endpoint_set(self, endpoint_set_id):
        """Drop every stored version of the given endpoint set."""
        self._endpoint_set_latest_valid_versions.pop(endpoint_set_id, None)
        self._endpoint_set_latest_versions.pop(endpoint_set_id, None)
        self._endpoint_set_active_versions.pop(endpoint_set_id, None)
        self._endpoint_set_latest_in_progress_versions.pop(endpoint_set_id, None)

    def _list_known_backend_ids(self):
        """Return a new set of backend ids present in any of the backend version dicts."""
        return (
            set(self._backend_latest_valid_versions) |
            set(self._backend_latest_in_progress_versions) |
            set(self._backend_latest_versions) |
            set(self._backend_active_versions)
        )

    def _list_known_endpoint_set_ids(self):
        """Return a new set of endpoint set ids present in any of the endpoint set version dicts."""
        return (
            set(self._endpoint_set_latest_valid_versions) |
            set(self._endpoint_set_latest_in_progress_versions) |
            set(self._endpoint_set_latest_versions) |
            set(self._endpoint_set_active_versions)
        )

    @staticmethod
    def _update_component_versions(states_pb, version_cls, latest_versions, latest_valid_versions,
                                   latest_in_progress_versions, active_versions):
        """
        Fold the ``l3_statuses`` of every component (backend or endpoint set) into version dicts.

        All four dicts are keyed by component id and are updated in place. The in-progress entry of
        each component present in ``states_pb`` is dropped first, so it reflects current statuses only.

        :param states_pb: mapping of component id to a state message having an ``l3_statuses`` field
        :param version_cls: class providing ``from_rev_status_pb(component_id, rev_pb)``
        """
        for component_id, state_pb in six.iteritems(states_pb):
            latest_in_progress_versions.pop(component_id, None)

            for rev_pb in state_pb.l3_statuses:
                v = version_cls.from_rev_status_pb(component_id, rev_pb)

                latest_versions[component_id] = newer(latest_versions.get(component_id), v)

                if rev_pb.validated.status == 'True':
                    latest_valid_versions[component_id] = newer(
                        latest_valid_versions.get(component_id), v)

                if rev_pb.in_progress.status == 'True':
                    latest_in_progress_versions[component_id] = newer(
                        latest_in_progress_versions.get(component_id), v)

                if rev_pb.active.status == 'True':
                    active_versions[component_id] = newer(active_versions.get(component_id), v)

    def update(self, l3_balancer_state_pb):
        """
        Load a new state message and refresh all stored versions from it.

        :type l3_balancer_state_pb: infra.awacs.proto.model_pb2.L3BalancerState
        """
        assert l3_balancer_state_pb.namespace_id == self._namespace_id
        assert l3_balancer_state_pb.l3_balancer_id == self._l3_balancer_id

        self._l3_balancer_state_pb = l3_balancer_state_pb

        # In-progress is recomputed from the current statuses only; other versions accumulate.
        self._l3_balancer_latest_in_progress_version = None
        for rev_pb in l3_balancer_state_pb.l3_balancer.l3_statuses:
            v = L3BalancerVersion.from_rev_status_pb(self._l3_balancer_id, rev_pb)
            self._l3_balancer_latest_version = newer(self._l3_balancer_latest_version, v)
            if rev_pb.validated.status == 'True':
                self._l3_balancer_latest_valid_version = newer(self._l3_balancer_latest_valid_version, v)
            if rev_pb.in_progress.status == 'True':
                self._l3_balancer_latest_in_progress_version = newer(
                    self._l3_balancer_latest_in_progress_version, v)
            if rev_pb.active.status == 'True':
                self._l3_balancer_active_version = newer(self._l3_balancer_active_version, v)

        self._update_component_versions(
            l3_balancer_state_pb.backends, BackendVersion,
            self._backend_latest_versions,
            self._backend_latest_valid_versions,
            self._backend_latest_in_progress_versions,
            self._backend_active_versions)
        # Backends removed from the state must not linger in any vector.
        for backend_id in self._list_known_backend_ids():
            if backend_id not in l3_balancer_state_pb.backends:
                self._forget_backend(backend_id)

        self._update_component_versions(
            l3_balancer_state_pb.endpoint_sets, EndpointSetVersion,
            self._endpoint_set_latest_versions,
            self._endpoint_set_latest_valid_versions,
            self._endpoint_set_latest_in_progress_versions,
            self._endpoint_set_active_versions)
        # Endpoint sets removed from the state must not linger in any vector.
        for endpoint_set_id in self._list_known_endpoint_set_ids():
            if endpoint_set_id not in l3_balancer_state_pb.endpoint_sets:
                self._forget_endpoint_set(endpoint_set_id)
