import inject

from awacs.model import cache
from awacs.model.dns_records import dns_record
from infra.awacs.proto import model_pb2


class DnsRecordValidator(object):
    """Decides which current versions of a single DNS record are valid.

    The strategy depends on the type of the name server the record points to
    and on the kind of backends the record's address uses.
    """
    _cache = inject.attr(cache.IAwacsCache)  # type: cache.AwacsCache

    def __init__(self, namespace_id, dns_record_id):
        # Key of the DNS record; it is re-read from the cache on every validate() call.
        self._namespace_id = namespace_id
        self._dns_record_id = dns_record_id

    def validate(self, ctx, state_handler):
        """
        Validate the current versions of the DNS record.

        * Name server of type DNS_MANAGER or AWACS_MANAGED, or record backends of
          type L3_BALANCERS: every current version is marked valid and the result
          of ``state_handler.mark_versions_as_valid`` is returned.
        * Name server of type YP_DNS: ``transform_until_valid`` is run on the
          current vector against the valid vector; nothing is returned.
        * Any other name server type: RuntimeError is raised.

        :type state_handler: dns_record.DnsRecordStateHandler
        :type ctx: context.OpCtx
        :raises RuntimeError: if the name server type is not recognized
        """
        op_ctx = ctx.with_op(op_id='validator')
        vectors = state_handler.generate_vectors()

        record_pb = self._cache.must_get_dns_record(self._namespace_id, self._dns_record_id)
        ns_ref = record_pb.spec.name_server
        ns_pb = self._cache.must_get_name_server(ns_ref.namespace_id, ns_ref.id)

        ns_spec_cls = model_pb2.NameServerSpec
        ns_type = ns_pb.spec.type
        trusted_ns_types = (ns_spec_cls.DNS_MANAGER, ns_spec_cls.AWACS_MANAGED)

        # Backends type is only inspected when the name server type alone does not decide.
        if (ns_type in trusted_ns_types or
                record_pb.spec.address.backends.type == model_pb2.DnsBackendsSelector.L3_BALANCERS):
            # No per-version check on this path: all current versions are accepted as-is.
            return state_handler.mark_versions_as_valid(list(vectors.current))

        if ns_type != ns_spec_cls.YP_DNS:
            raise RuntimeError('Unknown name server type: {}'.format(
                ns_spec_cls.Type.Name(ns_type)))

        rolled_back = vectors.current.transform_until_valid(
            op_ctx,
            valid_vector=vectors.valid,
            state_handler=state_handler,
        )
        if rolled_back:
            op_ctx.log.debug('total number of versions that were rolled back: %s', rolled_back)
