import six

from awacs.lib.strutils import to_full_id
from awacs.lib.vectors import cacheutil
from awacs.lib.vectors.state_handler import StateHandler
from awacs.lib.vectors.vector import Vector
from awacs.lib.vectors.vector_discovered import DiscoveredVectorWithBackends
from awacs.lib.vectors.vector_mutable import MutableVectorWithBackends
from awacs.lib.vectors.version import DnsRecordVersion, BackendVersion, EndpointSetVersion
from awacs.model.cache import IAwacsCache
from awacs.model.util import get_balancer_location
from sepelib.core import config
from infra.awacs.proto import model_pb2


class DnsRecordVector(Vector):
    """
    Vector whose main version is a DNS record version, accompanied by
    backend and endpoint set versions.
    """
    __main_version_class__ = DnsRecordVersion
    __version_classes__ = (BackendVersion, EndpointSetVersion)

    def __init__(self, dns_record_version, backend_versions, endpoint_set_versions):
        # Attributes are set before the base initializer runs (presumably it reads them — keep this order).
        self.dns_record_version, self.backend_versions, self.endpoint_set_versions = (
            dns_record_version, backend_versions, endpoint_set_versions)
        super(DnsRecordVector, self).__init__()


class MutableDnsRecordVector(MutableVectorWithBackends):
    """
    Mutable vector whose main version is a DNS record version, accompanied by
    backend and endpoint set versions.
    """
    __main_version_class__ = DnsRecordVersion
    __version_classes__ = (BackendVersion, EndpointSetVersion)

    def __init__(self, dns_record_version, backend_versions, endpoint_set_versions, validated_pbs):
        self.dns_record_version = dns_record_version
        super(MutableDnsRecordVector, self).__init__(
            backend_versions=backend_versions,
            endpoint_set_versions=endpoint_set_versions,
            validated_pbs=validated_pbs,
        )

    def validate(self, ctx):
        """
        Validate backends and endpoint sets referenced by the DNS record revision of this vector.

        :type ctx: context.OpCtx
        :raises: ValidationError
        """
        record_version = self.dns_record_version
        assert record_version
        ns_id, record_id = record_version.id
        spec_pb = cacheutil.must_get_dns_record_revision_spec_with_cache(
            self._cache, ns_id, record_id, record_version.version)
        backend_ids = get_included_backend_ids(ns_id, spec_pb)
        self.validate_backends_and_endpoint_sets(ctx, ns_id, backend_ids)


class DiscoveredDnsRecordVector(DiscoveredVectorWithBackends):
    """
    Discovered vector whose main version is a DNS record version, accompanied by
    backend and endpoint set versions of the namespace.
    """
    __main_version_class__ = DnsRecordVersion
    __version_classes__ = (BackendVersion, EndpointSetVersion)

    def __init__(self, dns_record_version, backend_versions, endpoint_set_versions):
        self.dns_record_version = dns_record_version
        super(DiscoveredDnsRecordVector, self).__init__(
            backend_versions=backend_versions,
            endpoint_set_versions=endpoint_set_versions,
        )

    @classmethod
    def from_cache(cls, namespace_id, main_id):
        """
        Build a vector from the DNS record version and the namespace's backend and
        endpoint set versions stored in the cache.

        :type namespace_id: six.text_type
        :param main_id: presumably the DNS record id (passed to must_get_dns_record_version)
        :rtype: DiscoveredDnsRecordVector
        """
        dns_record_version = cacheutil.must_get_dns_record_version(cls._cache, namespace_id, main_id)
        backend_versions = cacheutil.get_backend_versions(cls._cache, namespace_id)
        endpoint_set_versions = cacheutil.get_endpoint_set_versions(cls._cache, namespace_id)
        return cls(dns_record_version, backend_versions, endpoint_set_versions)

    def get_included_backends(self, version):
        """
        Return ids of backends included by the given revision of this vector's DNS record.

        :type version: DnsRecordVersion
        :rtype: list[tuple[six.text_type, six.text_type]]
        """
        ns_id, dns_record_id = self.dns_record_version.id
        dns_record_spec_pb = cacheutil.must_get_dns_record_revision_spec_with_cache(
            self._cache, ns_id, dns_record_id, version.version)
        return get_included_backend_ids(ns_id, dns_record_spec_pb)

    def _get_backend_and_es_versions_to_delete(self, vectors, latest_included_backend_versions):
        """
        Collect backend/endpoint set versions of the current vector that are safe to remove from state.

        A version is kept if its id is included in the current DNS record spec, or if it is present
        and not deleted in ``latest_included_backend_versions``.

        :param vectors: object exposing the ``current`` vector (type defined by the parent class)
        :param latest_included_backend_versions: mapping from id to version objects with a ``deleted`` attribute
        :rtype: set
        """
        # for DNS record "current" vector plays the role of the "active" vector from the parent implementation,
        # so we must specialize it
        versions = set()
        if not vectors.current.main_version:
            return versions
        # first, collect backends/ES that are included in current spec;
        # build a set once so the membership checks below are O(1) instead of a list scan per version
        included_current_backend_ids = set(self.get_included_backends(vectors.current.main_version))
        # then go through all current backend/ES versions and find something to delete
        for field in (u'backend_versions', u'endpoint_set_versions'):
            current_versions = vectors.current.must_get_version_dict(field)
            for current_id, current_version in six.iteritems(current_versions):

                # if it's present in current vector, it's useful
                if current_id in included_current_backend_ids:
                    continue

                # if it's present in the latest vector and not deleted there, it's also useful
                if (current_id in latest_included_backend_versions and
                        not latest_included_backend_versions[current_id].deleted):
                    continue

                # otherwise it means it's safe to delete from state
                versions.add(current_version)
        return versions


class DnsRecordStateHandler(StateHandler):
    """State handler for DNS record state protobufs (model_pb2.DnsRecordState)."""
    __slots__ = ()
    __protobuf__ = model_pb2.DnsRecordState
    __vector_class__ = DnsRecordVector
    __mutable_vector_class__ = MutableDnsRecordVector
    __zk_update_method__ = u'update_dns_record_state'

    @property
    def full_id(self):
        """Full id composed of the namespace id and DNS record id stored in the state protobuf."""
        state_pb = self._pb
        return to_full_id(state_pb.namespace_id, state_pb.dns_record_id)


def get_dns_record_versions(cache, namespace_id):
    """
    Build versions of all DNS records of a namespace, keyed by (namespace id, DNS record id).

    :type cache: AwacsCache
    :param six.text_type namespace_id:
    :rtype: dict[(six.text_type, six.text_type), DnsRecordVersion]
    """
    versions_by_id = {}
    for record_pb in cache.list_all_dns_records(namespace_id):
        record_key = (record_pb.meta.namespace_id, record_pb.meta.id)
        versions_by_id[record_key] = DnsRecordVersion.from_pb(record_pb)
    return versions_by_id


def get_included_backend_ids(namespace_id, dns_record_spec_pb):
    """
    Return ids of backends referenced by a DNS record spec's backends selector.

    * EXPLICIT: backends listed in the selector.
    * BALANCERS: system backends of the listed balancers. A balancer found in the cache is
      skipped if its location has ``run.unavailable_clusters.<location>.remove_l7_from_dns_records``
      enabled in config. Balancers missing from the cache are still included.
    * L3_BALANCERS: no backends.

    :type namespace_id: six.text_type
    :type dns_record_spec_pb: infra.awacs.proto.model_pb2.DnsRecordSpec
    :rtype: list[tuple[six.text_type, six.text_type]]
    :raises AssertionError: if the backends selector type is unknown
    """
    backends_selector_pb = dns_record_spec_pb.address.backends
    selector_type = backends_selector_pb.type
    if selector_type == model_pb2.DnsBackendsSelector.EXPLICIT:
        return [(namespace_id, backend_pb.id) for backend_pb in backends_selector_pb.backends]
    if selector_type == model_pb2.DnsBackendsSelector.BALANCERS:
        # the cache is only needed to look up balancers, so fetch it just for this branch
        c = IAwacsCache.instance()
        rv = []
        for selector_balancer_pb in backends_selector_pb.balancers:
            balancer_pb = c.get_balancer(namespace_id, selector_balancer_pb.id)
            if balancer_pb is not None:
                location = get_balancer_location(balancer_pb)
                if config.get_value('run.unavailable_clusters.{}.remove_l7_from_dns_records'
                                    .format(location.lower()), False):
                    continue
            # system backend's id always matches balancer id
            rv.append((namespace_id, selector_balancer_pb.id))
        return rv
    if selector_type == model_pb2.DnsBackendsSelector.L3_BALANCERS:
        return []
    raise AssertionError(u'Unknown DNS record backend selector type {}'.format(backends_selector_pb.type))


def get_fqdn(dns_record_pb):
    """
    Build the DNS record's FQDN (with trailing dot) from its address zone and its name server's zone.

    Incomplete records are read from ``order.content`` instead of ``spec``.
    """
    source_pb = dns_record_pb.order.content if dns_record_pb.spec.incomplete else dns_record_pb.spec
    ns_ref_pb = source_pb.name_server
    name_server_pb = IAwacsCache.instance().must_get_name_server(
        namespace_id=ns_ref_pb.namespace_id,
        name_server_id=ns_ref_pb.id)
    return u'{}.{}.'.format(source_pb.address.zone, name_server_pb.spec.zone)


def get_nameserver_type(dns_record_pb):
    """
    Return the type of the name server the DNS record refers to.

    Incomplete records are read from ``order.content`` instead of ``spec``.
    """
    source_pb = dns_record_pb.order.content if dns_record_pb.spec.incomplete else dns_record_pb.spec
    ns_ref_pb = source_pb.name_server
    name_server_pb = IAwacsCache.instance().must_get_name_server(
        namespace_id=ns_ref_pb.namespace_id,
        name_server_id=ns_ref_pb.id)
    return name_server_pb.spec.type
