import ipaddress
import logging
import dns.name
import dns.rdatatype
import dns.resolver

from walle.util import net

from .interface import DnsError, DnsMultipleEndpoints, DnsValueError


log = logging.getLogger(__name__)


class DnsRecordNotFound(DnsError):
    """DNS lookup failed because the name or the requested record type is missing.

    ``qname`` and the underlying resolver exception ``exc`` are forwarded to
    ``DnsError`` as keyword arguments together with any extra context
    (for example ``record_type``).
    """

    def __init__(self, qname, exc, **kwargs):
        message = "DNS record {} not found: {}".format(repr(qname), exc)
        super().__init__(message, qname=qname, exc=exc, **kwargs)


class LocalDNSResolver:
    """Resolve names through dnspython's default (system-configured) resolver.

    dnspython lookup failures are translated into this package's ``DnsError``
    subclasses so callers do not need to handle dnspython exceptions directly.
    """

    def __init__(self):
        self._resolver = dns.resolver.get_default_resolver()

    def query(self, qname, *args, **kwargs):
        """Forward a raw query to the wrapped resolver's ``query`` method unchanged."""
        return self._resolver.query(qname, *args, **kwargs)

    def get_zone_for_name(self, qname):
        """Return the zone containing *qname*, as found by ``dns.resolver.zone_for_name``."""
        return dns.resolver.zone_for_name(qname, resolver=self._resolver)

    def query_ns(self, query, record_type) -> list[str]:
        """Resolve *query* and return its record values as a sorted list of strings.

        :param query: name to look up.
        :param record_type: record type as text, e.g. ``"A"``, ``"AAAA"`` or ``"PTR"``.
        :returns: sorted record values. Only A/AAAA/PTR values are collected, so other
            record types yield an empty list.
        :raises DnsRecordNotFound: the name does not exist or has no record of this type.
        :raises DnsValueError: no nameserver was able to answer the query.
        """
        rdtype = dns.rdatatype.from_text(record_type)
        records, error = self._query_resolver(query, rdtype, record_type)
        if error is not None:
            raise error
        records.sort()
        return records

    def _query_resolver(self, query, rdtype, record_type):
        """Run the lookup and return ``(records, None)`` on success or ``(None, error)``.

        Known resolver failures are returned, not raised, so the caller decides when
        to raise them. Any other exception from ``self.query`` propagates unchanged.
        """
        try:
            dns_response = self.query(query, rdtype, raise_on_no_answer=True)
        except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer) as e:
            return None, DnsRecordNotFound(query, e, record_type=record_type)
        except dns.resolver.NoNameservers as e:
            return None, DnsValueError(
                "Can not resolve {} record {}: {}", record_type, query, e, qname=query, record_type=record_type
            )
        return self._parse_dns_response(query, dns_response, record_type), None

    @staticmethod
    def _parse_dns_response(query, dns_response, record_type) -> list[str]:
        """Extract the textual A/AAAA or PTR values from a resolver answer.

        For A/AAAA lookups each address is normalized with ``net.explode_ip``.
        A value it rejects with ``ValueError`` is logged and kept in its raw text
        form. Other rdata in the answer section (e.g. CNAME) is skipped.
        """
        results = []
        for rrset in dns_response.response.answer:
            for rdata in rrset:
                result = str(rdata.to_text())
                if record_type in ["A", "AAAA"] and rdata.rdtype in [dns.rdatatype.A, dns.rdatatype.AAAA]:
                    try:
                        result = net.explode_ip(result)
                    except ValueError:
                        # Deliberately best-effort: the raw text is still appended below.
                        log.error(
                            "Malformed response from dns for query=%r, record_type=%r: %r", query, record_type, result
                        )
                    results.append(result)
                elif record_type == "PTR" and rdata.rdtype == dns.rdatatype.PTR:
                    results.append(result)
        return results

    @staticmethod
    def _single_or_none(qname, records, **error_kwargs):
        """Collapse a lookup result to one value: ``None`` if empty, error if ambiguous.

        :raises DnsMultipleEndpoints: more than one record was returned.
        """
        if len(records) > 1:
            raise DnsMultipleEndpoints(qname, records, **error_kwargs)
        # An empty list is possible when the answer held no collectable values.
        return records[0] if records else None

    def get_ptr(self, ip_address) -> "str | None":
        """Return the single PTR hostname for *ip_address*, or ``None`` if there is none.

        :raises DnsMultipleEndpoints: the address has more than one PTR record.
        :raises DnsValueError: no nameserver was able to answer the query.
        """
        ptr = reverse_name_from_address(ip_address)
        try:
            hostnames = self.query_ns(ptr, "PTR")
        except DnsRecordNotFound:
            return None
        return self._single_or_none(ptr, hostnames)

    def get_ip_address(self, hostname, record_type) -> "str | None":
        """Return the single *record_type* address of *hostname*, or ``None`` if there is none.

        :raises DnsMultipleEndpoints: the name has more than one such record.
        :raises DnsValueError: no nameserver was able to answer the query.
        """
        try:
            addresses = self.query_ns(hostname, record_type)
        except DnsRecordNotFound:
            return None
        return self._single_or_none(hostname, addresses, record_type=record_type)

    def get_aaaa(self, hostname) -> "str | None":
        """Return the single AAAA address of *hostname*, or ``None``."""
        return self.get_ip_address(hostname, "AAAA")

    def get_a(self, hostname) -> "str | None":
        """Return the single A address of *hostname*, or ``None``."""
        return self.get_ip_address(hostname, "A")


def reverse_name_from_address(ip_addr):
    """Build the fully qualified reverse-lookup (PTR) name for an IP address.

    IPv4 octets are reversed under ``in-addr.arpa.``; IPv6 addresses are expanded
    to 32 hex nibbles, reversed, and placed under ``ip6.arpa.``. The result always
    ends with a trailing dot.

    :param ip_addr: anything ``ipaddress.ip_address`` accepts (string or int).
    :raises ValueError: *ip_addr* is not a valid IPv4/IPv6 address.
    """
    address = ipaddress.ip_address(ip_addr)
    if address.version == 4:
        labels = reversed(str(address).split("."))
        zone = "in-addr.arpa."
    elif address.version == 6:
        # Each hex nibble of the fully expanded address becomes one label.
        labels = reversed(address.exploded.replace(":", ""))
        zone = "ip6.arpa."
    else:
        raise ipaddress.AddressValueError("Unexpected IP address version: {}".format(repr(address.version)))
    return ".".join(labels) + "." + zone
