# coding: utf-8

import six

from awacs.lib.strutils import flatten_full_id, to_full_id
from awacs.model.balancer.state_handler import L7BalancerStateHandler
from awacs.model.balancer.vector import (
    Vector,
    BalancerVersion,
    UpstreamVersion,
    DomainVersion,
    BackendVersion,
    EndpointSetVersion,
    KnobVersion,
    CertVersion,
    VersionName,
    copy_dict_of_dicts,
)
from awacs.model import objects
from awacs.model.util import newer


class BalancerStateHolder(object):
    """In-memory index of the revision versions recorded in one balancer's BalancerState.

    For the balancer itself and for each dependent entity kind (upstreams, domains,
    backends, endpoint sets, knobs, certificates and weight sections) four views are
    derived from the per-revision statuses of the state proto:

    * latest -- the newest revision seen (as decided by ``newer``);
    * latest valid -- the newest revision whose ``validated.status`` is ``'True'``;
    * latest in progress -- the newest revision whose ``in_progress.status`` is ``'True'``;
    * active -- the newest revision whose ``active.status`` is ``'True'``.

    Dependent-entity views are dicts keyed by the full id returned by ``to_full_id``
    (a pair of text values). ``_validated_pbs`` additionally keeps, per entity kind and
    entity, the ``validated`` condition of the revision that became the latest one.
    The views are exposed as ``Vector`` snapshots via ``curr_vector``, ``valid_vector``,
    ``in_progress_vector`` and ``active_vector``.
    """

    def __init__(self, namespace_id, balancer_id, balancer_state_pb=None):
        self._namespace_id = namespace_id
        self._balancer_id = balancer_id
        self._balancer_state_pb = balancer_state_pb

        self._reset()
        if balancer_state_pb is not None:
            self.update(balancer_state_pb)

    def _reset(self):
        """Clear every tracked version and every remembered validated condition."""
        self._balancer_latest_valid_version = None
        self._balancer_latest_version = None
        self._balancer_active_version = None
        self._balancer_latest_in_progress_version = None

        self._upstream_latest_valid_versions = {}
        self._upstream_latest_versions = {}
        self._upstream_latest_in_progress_versions = {}
        self._upstream_active_versions = {}

        self._domain_latest_valid_versions = {}
        self._domain_latest_versions = {}
        self._domain_latest_in_progress_versions = {}
        self._domain_active_versions = {}

        self._backend_latest_versions = {}
        self._backend_latest_valid_versions = {}
        self._backend_latest_in_progress_versions = {}
        self._backend_active_versions = {}

        self._endpoint_set_latest_versions = {}
        self._endpoint_set_latest_valid_versions = {}
        self._endpoint_set_active_versions = {}
        self._endpoint_set_latest_in_progress_versions = {}

        self._knob_latest_versions = {}
        self._knob_latest_valid_versions = {}
        self._knob_active_versions = {}
        self._knob_latest_in_progress_versions = {}

        self._cert_latest_versions = {}
        self._cert_latest_valid_versions = {}
        self._cert_active_versions = {}
        self._cert_latest_in_progress_versions = {}

        self._weight_section_latest_versions = {}
        self._weight_section_latest_valid_versions = {}
        self._weight_section_active_versions = {}
        self._weight_section_latest_in_progress_versions = {}

        # One dict per VersionName member: entity id -> validated condition of its latest revision.
        self._validated_pbs = {v: {} for v in VersionName.__members__.values()}

    @property
    def endpoint_set_latest_versions(self):
        return self._endpoint_set_latest_versions

    @property
    def endpoint_set_latest_valid_versions(self):
        return self._endpoint_set_latest_valid_versions

    @property
    def backend_latest_versions(self):
        return self._backend_latest_versions

    @property
    def backend_latest_valid_versions(self):
        return self._backend_latest_valid_versions

    @property
    def balancer_latest_in_progress_version(self):
        return self._balancer_latest_in_progress_version

    @property
    def balancer_active_version(self):
        return self._balancer_active_version

    @property
    def curr_vector(self):
        """Vector of the latest versions, including a copy of the remembered validated conditions."""
        return Vector(self._balancer_latest_version,
                      dict(self._upstream_latest_versions),
                      dict(self._domain_latest_versions),
                      dict(self._backend_latest_versions),
                      dict(self._endpoint_set_latest_versions),
                      dict(self._knob_latest_versions),
                      dict(self._cert_latest_versions),
                      dict(self._weight_section_latest_versions),
                      copy_dict_of_dicts(self._validated_pbs))

    @property
    def valid_vector(self):
        """Vector of the latest versions whose ``validated`` status is ``'True'``."""
        return Vector(self._balancer_latest_valid_version,
                      dict(self._upstream_latest_valid_versions),
                      dict(self._domain_latest_valid_versions),
                      dict(self._backend_latest_valid_versions),
                      dict(self._endpoint_set_latest_valid_versions),
                      dict(self._knob_latest_valid_versions),
                      dict(self._cert_latest_valid_versions),
                      dict(self._weight_section_latest_valid_versions))

    @property
    def in_progress_vector(self):
        """Vector of the latest versions whose ``in_progress`` status is ``'True'``."""
        return Vector(self._balancer_latest_in_progress_version,
                      dict(self._upstream_latest_in_progress_versions),
                      dict(self._domain_latest_in_progress_versions),
                      dict(self._backend_latest_in_progress_versions),
                      dict(self._endpoint_set_latest_in_progress_versions),
                      dict(self._knob_latest_in_progress_versions),
                      dict(self._cert_latest_in_progress_versions),
                      dict(self._weight_section_latest_in_progress_versions))

    @property
    def active_vector(self):
        """Vector of the newest versions whose ``active`` status is ``'True'``."""
        return Vector(self._balancer_active_version,
                      dict(self._upstream_active_versions),
                      dict(self._domain_active_versions),
                      dict(self._backend_active_versions),
                      dict(self._endpoint_set_active_versions),
                      dict(self._knob_active_versions),
                      dict(self._cert_active_versions),
                      dict(self._weight_section_active_versions))

    # The _forget_* methods drop every view and the validated condition of one entity.

    def _forget_upstream(self, upstream_id):
        self._upstream_latest_valid_versions.pop(upstream_id, None)
        self._upstream_latest_versions.pop(upstream_id, None)
        self._upstream_active_versions.pop(upstream_id, None)
        self._upstream_latest_in_progress_versions.pop(upstream_id, None)
        self._validated_pbs[VersionName.UPSTREAM].pop(upstream_id, None)

    def _forget_domain(self, domain_id):
        self._domain_latest_valid_versions.pop(domain_id, None)
        self._domain_latest_versions.pop(domain_id, None)
        self._domain_active_versions.pop(domain_id, None)
        self._domain_latest_in_progress_versions.pop(domain_id, None)
        self._validated_pbs[VersionName.DOMAIN].pop(domain_id, None)

    def _forget_backend(self, backend_id):
        self._backend_latest_valid_versions.pop(backend_id, None)
        self._backend_latest_versions.pop(backend_id, None)
        self._backend_active_versions.pop(backend_id, None)
        self._backend_latest_in_progress_versions.pop(backend_id, None)
        self._validated_pbs[VersionName.BACKEND].pop(backend_id, None)

    def _forget_endpoint_set(self, endpoint_set_id):
        self._endpoint_set_latest_valid_versions.pop(endpoint_set_id, None)
        self._endpoint_set_latest_versions.pop(endpoint_set_id, None)
        self._endpoint_set_active_versions.pop(endpoint_set_id, None)
        self._endpoint_set_latest_in_progress_versions.pop(endpoint_set_id, None)
        self._validated_pbs[VersionName.ENDPOINT_SET].pop(endpoint_set_id, None)

    def _forget_knob(self, knob_id):
        self._knob_latest_valid_versions.pop(knob_id, None)
        self._knob_latest_versions.pop(knob_id, None)
        self._knob_active_versions.pop(knob_id, None)
        self._knob_latest_in_progress_versions.pop(knob_id, None)
        self._validated_pbs[VersionName.KNOB].pop(knob_id, None)

    def _forget_cert(self, cert_id):
        self._cert_latest_valid_versions.pop(cert_id, None)
        self._cert_latest_versions.pop(cert_id, None)
        self._cert_active_versions.pop(cert_id, None)
        self._cert_latest_in_progress_versions.pop(cert_id, None)
        self._validated_pbs[VersionName.CERT].pop(cert_id, None)

    def _forget_weight_section(self, weight_section_id):
        self._weight_section_latest_valid_versions.pop(weight_section_id, None)
        self._weight_section_latest_versions.pop(weight_section_id, None)
        self._weight_section_active_versions.pop(weight_section_id, None)
        self._weight_section_latest_in_progress_versions.pop(weight_section_id, None)
        self._validated_pbs[VersionName.WEIGHT_SECTION].pop(weight_section_id, None)

    # The _list_known_full_*_ids methods return the union of ids present in any of the four views.

    def _list_known_full_upstream_ids(self):
        """:rtype: set[(six.text_type, six.text_type)]"""
        return (
            set(self._upstream_latest_valid_versions) |
            set(self._upstream_latest_in_progress_versions) |
            set(self._upstream_latest_versions) |
            set(self._upstream_active_versions)
        )

    def _list_known_full_domain_ids(self):
        """:rtype: set[(six.text_type, six.text_type)]"""
        return (
            set(self._domain_latest_valid_versions) |
            set(self._domain_latest_in_progress_versions) |
            set(self._domain_latest_versions) |
            set(self._domain_active_versions)
        )

    def _list_known_full_backend_ids(self):
        """:rtype: set[(six.text_type, six.text_type)]"""
        return (
            set(self._backend_latest_valid_versions) |
            set(self._backend_latest_in_progress_versions) |
            set(self._backend_latest_versions) |
            set(self._backend_active_versions)
        )

    def _list_known_full_endpoint_set_ids(self):
        """:rtype: set[(six.text_type, six.text_type)]"""
        return (
            set(self._endpoint_set_latest_valid_versions) |
            set(self._endpoint_set_latest_in_progress_versions) |
            set(self._endpoint_set_latest_versions) |
            set(self._endpoint_set_active_versions)
        )

    def _list_known_full_knob_ids(self):
        """:rtype: set[(six.text_type, six.text_type)]"""
        return (
            set(self._knob_latest_valid_versions) |
            set(self._knob_latest_in_progress_versions) |
            set(self._knob_latest_versions) |
            set(self._knob_active_versions)
        )

    def _list_known_full_cert_ids(self):
        """:rtype: set[(six.text_type, six.text_type)]"""
        return (
            set(self._cert_latest_valid_versions) |
            set(self._cert_latest_in_progress_versions) |
            set(self._cert_latest_versions) |
            set(self._cert_active_versions)
        )

    def _list_known_full_weight_section_ids(self):
        """:rtype: set[(six.text_type, six.text_type)]"""
        return (
            set(self._weight_section_latest_valid_versions) |
            set(self._weight_section_latest_in_progress_versions) |
            set(self._weight_section_latest_versions) |
            set(self._weight_section_active_versions)
        )

    def get_validated_last_transition_time_in_microseconds(self, version):
        """Return the ``validated.last_transition_time`` (in microseconds) of the balancer revision
        that ``L7BalancerStateHandler.select_rev`` picks for ``version``.

        :returns: an int, or None if no revision is selected or the field is not set
        """
        rev = L7BalancerStateHandler(self._balancer_state_pb).select_rev(version)
        if rev is None or not rev.pb.validated.HasField('last_transition_time'):
            return None
        return rev.pb.validated.last_transition_time.ToMicroseconds()

    def _fold_balancer_statuses(self, statuses):
        """Record the latest / valid / in-progress / active versions of the balancer itself.

        Expects the balancer fields to have just been cleared by ``_reset``.

        :param statuses: revision status protos of the balancer
        """
        full_balancer_id = (self._namespace_id, self._balancer_id)
        for rev_pb in statuses:
            v = BalancerVersion.from_rev_status_pb(full_balancer_id, rev_pb)
            self._balancer_latest_version = newer(self._balancer_latest_version, v)
            # Unlike dependent entities, the balancer's validated condition is keyed by
            # v.balancer_id and is also overwritten by a revision equal to the current latest.
            if v == self._balancer_latest_version:
                self._validated_pbs[VersionName.BALANCER][v.balancer_id] = rev_pb.validated
            if rev_pb.validated.status == 'True':
                self._balancer_latest_valid_version = newer(self._balancer_latest_valid_version, v)
            if rev_pb.in_progress.status == 'True':
                self._balancer_latest_in_progress_version = newer(
                    self._balancer_latest_in_progress_version, v)
            if rev_pb.active.status == 'True':
                self._balancer_active_version = newer(self._balancer_active_version, v)

    def _fold_entity_statuses(self, state_pbs, version_cls, version_name,
                              latest, latest_valid, latest_in_progress, active):
        """Fold the revision statuses of one dependent entity kind into its version dicts.

        :param state_pbs: proto map from entity id (as stored in the balancer state) to the
                          entity's state, whose ``statuses`` are the revision status protos
        :param version_cls: version class providing ``from_rev_status_pb(full_id, rev_pb)``
        :param version_name: key into ``self._validated_pbs`` for this entity kind
        :param latest: full id -> newest version; updated in place
        :param latest_valid: full id -> newest version with validated status 'True'; updated in place
        :param latest_in_progress: full id -> newest version with in_progress status 'True';
                                   updated in place
        :param active: full id -> newest version with active status 'True'; updated in place
        """
        validated_pbs = self._validated_pbs[version_name]
        for entity_id, entity_state_pb in six.iteritems(state_pbs):
            full_id = to_full_id(self._namespace_id, entity_id)
            # If an earlier map key resolved to the same full id, its in-progress version is
            # discarded here, while latest / valid / active keep merging across such keys.
            latest_in_progress.pop(full_id, None)

            for rev_pb in entity_state_pb.statuses:
                v = version_cls.from_rev_status_pb(full_id, rev_pb)

                current_latest = latest.get(full_id)
                new_latest = newer(current_latest, v)
                if new_latest != current_latest:
                    latest[full_id] = new_latest
                    validated_pbs[full_id] = rev_pb.validated

                if rev_pb.validated.status == 'True':
                    latest_valid[full_id] = newer(latest_valid.get(full_id), v)

                if rev_pb.in_progress.status == 'True':
                    latest_in_progress[full_id] = newer(latest_in_progress.get(full_id), v)

                if rev_pb.active.status == 'True':
                    active[full_id] = newer(active.get(full_id), v)

    def _forget_unlisted(self, state_pbs, known_full_ids, forget):
        """Call ``forget(full_id)`` for every known full id whose flattened id is not a key of ``state_pbs``.

        :param state_pbs: proto map keyed by entity id, as in the balancer state
        :param known_full_ids: set of full ids currently tracked for the entity kind
        :param forget: one of the ``_forget_*`` methods for the same entity kind
        """
        for full_id in known_full_ids:
            if flatten_full_id(self._namespace_id, full_id) not in state_pbs:
                forget(full_id)

    def update(self, balancer_state_pb):
        """
        Rebuild every tracked version from scratch out of ``balancer_state_pb``.

        :type balancer_state_pb: awacs.proto.model_pb2.BalancerState
        """
        assert balancer_state_pb.namespace_id == self._namespace_id
        assert balancer_state_pb.balancer_id == self._balancer_id
        self._balancer_state_pb = balancer_state_pb

        # _reset() also clears the balancer's in-progress version.
        self._reset()
        self._fold_balancer_statuses(balancer_state_pb.balancer.statuses)

        self._fold_entity_statuses(
            balancer_state_pb.domains, DomainVersion, VersionName.DOMAIN,
            self._domain_latest_versions, self._domain_latest_valid_versions,
            self._domain_latest_in_progress_versions, self._domain_active_versions)
        self._forget_unlisted(balancer_state_pb.domains,
                              self._list_known_full_domain_ids(), self._forget_domain)

        self._fold_entity_statuses(
            balancer_state_pb.upstreams, UpstreamVersion, VersionName.UPSTREAM,
            self._upstream_latest_versions, self._upstream_latest_valid_versions,
            self._upstream_latest_in_progress_versions, self._upstream_active_versions)
        self._forget_unlisted(balancer_state_pb.upstreams,
                              self._list_known_full_upstream_ids(), self._forget_upstream)

        self._fold_entity_statuses(
            balancer_state_pb.backends, BackendVersion, VersionName.BACKEND,
            self._backend_latest_versions, self._backend_latest_valid_versions,
            self._backend_latest_in_progress_versions, self._backend_active_versions)
        self._forget_unlisted(balancer_state_pb.backends,
                              self._list_known_full_backend_ids(), self._forget_backend)

        self._fold_entity_statuses(
            balancer_state_pb.endpoint_sets, EndpointSetVersion, VersionName.ENDPOINT_SET,
            self._endpoint_set_latest_versions, self._endpoint_set_latest_valid_versions,
            self._endpoint_set_latest_in_progress_versions, self._endpoint_set_active_versions)
        self._forget_unlisted(balancer_state_pb.endpoint_sets,
                              self._list_known_full_endpoint_set_ids(), self._forget_endpoint_set)

        self._fold_entity_statuses(
            balancer_state_pb.knobs, KnobVersion, VersionName.KNOB,
            self._knob_latest_versions, self._knob_latest_valid_versions,
            self._knob_latest_in_progress_versions, self._knob_active_versions)
        self._forget_unlisted(balancer_state_pb.knobs,
                              self._list_known_full_knob_ids(), self._forget_knob)

        self._fold_entity_statuses(
            balancer_state_pb.certificates, CertVersion, VersionName.CERT,
            self._cert_latest_versions, self._cert_latest_valid_versions,
            self._cert_latest_in_progress_versions, self._cert_active_versions)
        self._forget_unlisted(balancer_state_pb.certificates,
                              self._list_known_full_cert_ids(), self._forget_cert)

        self._fold_entity_statuses(
            balancer_state_pb.weight_sections, objects.WeightSection.version, VersionName.WEIGHT_SECTION,
            self._weight_section_latest_versions, self._weight_section_latest_valid_versions,
            self._weight_section_latest_in_progress_versions, self._weight_section_active_versions)
        self._forget_unlisted(balancer_state_pb.weight_sections,
                              self._list_known_full_weight_section_ids(), self._forget_weight_section)
