import inject
import six
from six.moves import map

from awacs.model import cache, zk, util
from awacs.model.l3_balancer import cacheutil
from awacs.model.l3_balancer.state_handler import L3BalancerStateHandler


class Discoverer(object):
    """
    Synchronizes an L3 balancer state with the revisions stored in its namespace.

    * ``discover`` adds newly appeared L3 balancer / backend / endpoint set revisions to the state
      and requests removal of backends the active vector no longer needs;
    * ``clean_l3_balancer_state`` drops revisions superseded by the active vector.
    """
    _cache = inject.attr(cache.IAwacsCache)  # type: cache.AwacsCache
    _zk = inject.attr(zk.IZkStorage)  # type: zk.ZkStorage

    def __init__(self, namespace_id, l3_balancer_id):
        """
        :param namespace_id: id of the namespace the L3 balancer belongs to
        :param l3_balancer_id: id of the L3 balancer whose state is handled
        """
        self._namespace_id = namespace_id
        self._l3_balancer_id = l3_balancer_id
        self._is_large_namespace = util.is_large_namespace(namespace_id)

    def discover(self, ctx, l3_balancer_state_pb):
        """
        Add namespace revisions relevant to this L3 balancer to its state.

        1) gather all revisions of the L3 balancer, backends and endpoint sets in the namespace;
        2) determine which of them are referenced by the namespace, current and valid L3 balancer specs;
        3) add the new ones to the L3 balancer state (so they can be validated) and request removal
           of backends the active vector no longer needs.

        :type l3_balancer_state_pb: infra.awacs.proto.model_pb2.L3BalancerState
        :type ctx: context.OpCtx
        :returns: result of ``cacheutil.modify_state_revisions`` if anything had to change, False otherwise
        :rtype: bool
        """
        ctx = ctx.with_op(op_id=u'discoverer')

        # gather all revisions from namespace
        namespace_l3_balancer_version = cacheutil.get_l3_balancer_version(self._cache, self._namespace_id,
                                                                          self._l3_balancer_id)
        namespace_backend_versions = cacheutil.get_backend_versions(self._cache, self._namespace_id)
        namespace_endpoint_set_versions = cacheutil.get_endpoint_set_versions(self._cache, self._namespace_id)
        (curr_vector,
         valid_vector,
         in_progress_vector,
         active_vector) = cacheutil.l3_balancer_state_to_vectors(l3_balancer_state_pb)

        # the namespace L3 balancer revision is new if it is newer than the one in the current vector;
        # namespace, current and valid L3 balancer revisions are all explored for included backends
        versions_to_add = set()
        l3_balancer_versions_to_explore = set()
        if namespace_l3_balancer_version is not None:
            l3_balancer_versions_to_explore.add(namespace_l3_balancer_version)
            if curr_vector.balancer_version is None or namespace_l3_balancer_version > curr_vector.balancer_version:
                versions_to_add.add(namespace_l3_balancer_version)
        if curr_vector.balancer_version is not None:
            l3_balancer_versions_to_explore.add(curr_vector.balancer_version)
        if valid_vector.balancer_version is not None:
            l3_balancer_versions_to_explore.add(valid_vector.balancer_version)

        # find all backends that are included in interesting L3 balancers
        included_backend_ids = set()
        for l3_balancer_version in sorted(l3_balancer_versions_to_explore, reverse=True):
            l3_balancer_spec_pb = cacheutil.find_l3_balancer_revision_spec_and_use_cache(
                self._namespace_id, self._l3_balancer_id, l3_balancer_version.version)
            included_backend_ids.update(cacheutil.get_included_backend_ids(l3_balancer_spec_pb))

        # only included backends that exist in the namespace are considered: a namespace revision is added
        # if it is not deleted and absent from the current vector, or if it is newer than the current one
        existing_included_backend_ids = included_backend_ids & set(namespace_backend_versions)
        for backend_id in existing_included_backend_ids:
            namespace_backend_version = namespace_backend_versions[backend_id]
            current_backend_version = curr_vector.backend_versions.get(backend_id)
            if not namespace_backend_version.deleted and not current_backend_version:
                versions_to_add.add(namespace_backend_version)
            if current_backend_version and namespace_backend_version > current_backend_version:
                versions_to_add.add(namespace_backend_version)

        # same rule for endpoint sets; they are looked up by the ids of included backends
        existing_included_endpoint_set_ids = included_backend_ids & set(namespace_endpoint_set_versions)
        for endpoint_set_id in existing_included_endpoint_set_ids:
            namespace_endpoint_set_version = namespace_endpoint_set_versions[endpoint_set_id]
            current_endpoint_set_version = curr_vector.endpoint_set_versions.get(endpoint_set_id)
            if not namespace_endpoint_set_version.deleted and not current_endpoint_set_version:
                versions_to_add.add(namespace_endpoint_set_version)
            if current_endpoint_set_version and namespace_endpoint_set_version > current_endpoint_set_version:
                versions_to_add.add(namespace_endpoint_set_version)

        # active backends become deletion candidates when the active spec no longer includes them,
        # or when it includes them but their active revision is marked as deleted
        backend_ids_to_delete = set()
        if active_vector.balancer_version:
            l3_balancer_active_spec_pb = cacheutil.find_l3_balancer_revision_spec_and_use_cache(
                self._namespace_id, self._l3_balancer_id, active_vector.balancer_version.version)
            active_included_backend_ids = cacheutil.get_included_backend_ids(l3_balancer_active_spec_pb)

            for active_backend_id, active_backend_version in six.iteritems(active_vector.backend_versions):
                if active_backend_id in active_included_backend_ids:
                    if active_backend_version.deleted:
                        backend_ids_to_delete.add(active_backend_id)
                else:
                    backend_ids_to_delete.add(active_backend_id)
        # keep backends that have a different revision in progress...
        for backend_id, backend_version in six.iteritems(in_progress_vector.backend_versions):
            if active_vector.backend_versions.get(backend_id) != backend_version:
                backend_ids_to_delete.discard(backend_id)
        # ...and backends still included by an explored spec and present in the namespace
        backend_ids_to_delete -= existing_included_backend_ids
        if versions_to_add or backend_ids_to_delete:
            ctx.log.debug(u'versions_to_add: %s, backend_ids_to_delete: %s',
                          versions_to_add, backend_ids_to_delete)
            updated = cacheutil.modify_state_revisions(namespace_id=self._namespace_id,
                                                       l3_balancer_id=self._l3_balancer_id,
                                                       l3_balancer_state_pb=l3_balancer_state_pb,
                                                       versions_to_add=versions_to_add,
                                                       backend_ids_to_delete=backend_ids_to_delete)
        else:
            updated = False
        return updated

    @staticmethod
    def _make_omit_filter(active_version):
        """
        Build an ``omit_revs`` predicate matching revisions superseded by ``active_version``.

        A revision matches if it was created before the active one, or if it is the active
        revision itself and is marked as deleted.

        :param active_version: active backend or endpoint set version (exposes ``ctime`` and ``version``)
        :rtype: callable
        """
        def _is_superseded(l3_status_pb):
            """
            :type l3_status_pb: model_pb2.L3BalancerState.RevisionL3Status
            """
            return (
                active_version.ctime > l3_status_pb.ctime.ToMicroseconds() or
                (l3_status_pb.revision_id == active_version.version and l3_status_pb.deleted)
            )

        return _is_superseded

    def clean_l3_balancer_state(self, ctx, l3_balancer_state_pb):
        """
        Remove from state all entities that are older than the latest active vector.

        The active vector is computed once from ``l3_balancer_state_pb``. The state is then
        modified inside ``zk.update_l3_balancer_state``, which may yield a freshly read state
        on subsequent iterations; the loop stops as soon as an iteration changes nothing.

        :type ctx: context.OpCtx
        :type l3_balancer_state_pb: infra.awacs.proto.model_pb2.L3BalancerState
        :returns: whether the last processed state was modified
        :rtype: bool
        """
        ctx = ctx.with_op(op_id=u'discoverer_clean_l3_balancer_state')
        active_vector = cacheutil.l3_balancer_state_to_vectors(l3_balancer_state_pb).active
        updated = False
        for state_pb in self._zk.update_l3_balancer_state(self._namespace_id, self._l3_balancer_id,
                                                          l3_balancer_state_pb=l3_balancer_state_pb):
            h = L3BalancerStateHandler(state_pb)
            updated = False

            versions_before = list(h.iter_versions())

            if active_vector.balancer_version:
                def filter_balancer_revs(l3_status_pb):
                    """
                    :type l3_status_pb: model_pb2.L3BalancerState.RevisionL3Status
                    """
                    return active_vector.balancer_version.ctime > l3_status_pb.ctime.ToMicroseconds()

                updated |= h.select_l3_balancer().omit_revs(filter_balancer_revs)

            for backend_id, backend_active_version in six.iteritems(active_vector.backend_versions):
                updated |= h.select_backend(backend_id).omit_revs(self._make_omit_filter(backend_active_version))
            updated |= h.remove_backends_wo_revs()

            for endpoint_set_id, endpoint_set_active_version in six.iteritems(active_vector.endpoint_set_versions):
                if endpoint_set_id in h.pb.backends:
                    updated |= h.select_endpoint_set(endpoint_set_id).omit_revs(
                        self._make_omit_filter(endpoint_set_active_version))
                elif endpoint_set_id in h.pb.endpoint_sets:
                    # we don't need endpoint sets whose backends are not present in the state.
                    # The membership check matters on retries: the state re-read from ZK may no
                    # longer contain this endpoint set, and deleting a missing map key raises KeyError.
                    del h.pb.endpoint_sets[endpoint_set_id]
                    updated = True
            updated |= h.remove_endpoint_sets_wo_revs()

            if not updated:
                break
            ctx.log.debug(u'clean_l3_balancer_state()')
            ctx.log.debug(u'versions_before: %s', u', '.join(map(repr, versions_before)))
            ctx.log.debug(u'versions_after: %s', u', '.join(map(repr, h.iter_versions())))
        return updated
