import collections

import dns
import dns.resolver
import inject
import ipaddr
import six


# Result record: the CNAME targets followed (in order) plus the IPv4 and IPv6
# address lists of the final name in the chain.
resolve_response = collections.namedtuple(
    'resolve_response',
    ['cnames_chain', 'ipv4_addrs', 'ipv6_addrs'],
)


class IDnsResolver(object):
    """
    Dependency-injection key for the DNS resolver service.

    Callers obtain the configured implementation through ``instance()`` instead of
    constructing a resolver themselves.
    """

    @classmethod
    def instance(cls):
        """
        Return the object that ``inject`` provides for this class.

        :rtype: DnsResolver
        """
        provided = inject.instance(cls)
        return provided


class DnsResolver(object):
    """
    Resolve domain names through one class-wide ``dns.resolver.Resolver``:
    follow CNAME chains and collect A/AAAA address records.
    """

    # Maximum number of CNAME hops followed before giving up (guards against CNAME loops).
    MAX_CNAME_CHAIN_LEN = 10
    # Created once when the class is defined and shared by every instance.
    _resolver = dns.resolver.Resolver()

    def _get_terminal_domain_name(self, domain_name):
        """
        Follow the chain of CNAME records to reach the actual A/AAAA records.
        https://tools.ietf.org/html/rfc1034: If a CNAME RR is present at a node, no other data should be present

        :type domain_name: six.text_type
        :return: the first name in the chain that has no CNAME record, and the CNAME targets
            followed to reach it (in order, as text without the trailing dot)
        :rtype: tuple[dns.name.Name, list[six.text_type]]
        :raises ValueError: if more than ``MAX_CNAME_CHAIN_LEN`` CNAME hops would be needed
        """
        cnames_chain = []
        next_domain_name = dns.name.from_text(domain_name)
        while True:
            cname = self._resolver.query(next_domain_name, dns.rdatatype.CNAME, raise_on_no_answer=False)
            if not cname.rrset:
                return next_domain_name, cnames_chain
            # Following this CNAME would make hop number len(cnames_chain) + 1.
            if len(cnames_chain) >= self.MAX_CNAME_CHAIN_LEN:
                raise ValueError(u'CNAME chain is longer than {0} entries'.format(self.MAX_CNAME_CHAIN_LEN))
            # By the same RFC, at most one CNAME record can be present.
            next_domain_name = cname.rrset[0].target
            # Always store text, never the Name object itself. to_text() may return a
            # byte string (e.g. under Python 2), so decode it in that case.
            target = next_domain_name.to_text(omit_final_dot=True)
            if isinstance(target, six.binary_type):
                target = target.decode('utf-8')
            cnames_chain.append(target)

    def _sorted_addresses(self, name, rdtype):
        """
        Return the ``address`` values of all ``rdtype`` records at ``name``, sorted as strings.

        Returns an empty list when the name has no records of that type.

        :type name: dns.name.Name
        :type rdtype: int
        :rtype: list[six.text_type]
        """
        answer = self._resolver.query(name, rdtype, raise_on_no_answer=False)
        return sorted(r.address for r in answer)

    def resolve(self, domain_name):
        """
        Find all CNAMEs that domain_name references, and the final IPv4 & IPv6 addresses it points to.

        Address lists are sorted lexicographically (as strings). Resolver errors raised by the
        underlying queries (e.g. ``dns.resolver.NXDOMAIN``) are not caught here.

        :type domain_name: six.text_type
        :rtype: resolve_response
        :raises ValueError: if the CNAME chain exceeds ``MAX_CNAME_CHAIN_LEN``
        """
        fqdn, cnames_chain = self._get_terminal_domain_name(domain_name)
        return resolve_response(
            cnames_chain=cnames_chain,
            ipv4_addrs=self._sorted_addresses(fqdn, dns.rdatatype.A),
            ipv6_addrs=self._sorted_addresses(fqdn, dns.rdatatype.AAAA),
        )

    def get_address_record(self, domain_name, record_type):
        """
        Resolve a single address record type for ``domain_name``.

        Returns an empty set if the name does not exist (NXDOMAIN) or has no records of that type.
        CNAMEs are not collected here.

        :type domain_name: six.text_type
        :param record_type: record type to query, e.g. ``'A'`` or ``'AAAA'``
            (as returned by ``get_record_type_by_address``)
        :type record_type: six.text_type | int
        :rtype: set[six.text_type]
        """
        try:
            return {r.address for r in self._resolver.query(domain_name, record_type, raise_on_no_answer=False)}
        except dns.resolver.NXDOMAIN:
            return set()

    @staticmethod
    def get_record_type_by_address(ip_address):
        """
        Get the DNS address record type matching an IP address.

        :type ip_address: six.text_type
        :return: ``'AAAA'`` for an IPv6 address, ``'A'`` otherwise
        :rtype: str
        """
        ip_addr = ipaddr.IPAddress(ip_address)
        return 'AAAA' if isinstance(ip_addr, ipaddr.IPv6Address) else 'A'
