"""DNS API client"""

import ipaddress
import logging
import re
from collections import namedtuple

from sepelib.core import config
from sepelib.yandex import dns_api
from .interface import (
    DnsClientInterface,
    DnsZoneNotFound,
    DnsInvalidZone,
    DnsCommunicationError,
)
from .local_dns_resolver import LocalDNSResolver

log = logging.getLogger(__name__)

# NOTE(review): not referenced in this chunk; presumably the default record TTL in seconds — confirm.
DEFAULT_TTL = 3600

# Matches one zone ACL entry of the form ``KEY_TYPE(name);``:
# group 1 is the key-type token, group 2 is the key name.
DNS_ACL_REGEXP = re.compile(r"(\w+)\(([a-z0-9A-Z_-]+)\);")

# One parsed ACL entry. ``key_type`` is the raw token matched before the
# parenthesis (compared against DnsAclKeyType values).
# The typename must match the bound name so repr() and pickling work.
DnsAclKey = namedtuple("DnsAclKey", "name, key_type")


class DnsAclKeyType:
    """Known key-type tokens found in a zone ACL string.

    These are matched against the first group of ``DNS_ACL_REGEXP``, i.e. the
    word that precedes ``(name);`` in an ACL entry. Values are plain strings.
    """

    # Listed alphabetically; the definition order carries no meaning.
    GROUP = "GROUP_KEYS"
    ROBOT = "ROBOT_KEY"
    USER = "USER_KEY"


class DnsClient(DnsClientInterface):
    """DNS client combining the DNS API (zone ACLs, record changes) with a local resolver (lookups)."""

    def __init__(self):
        self._dns_api_client = self._get_client()
        self._local_ns_client = LocalDNSResolver()

    @staticmethod
    def _get_client():
        """Build a ``dns_api.DnsApiClient`` from the ``dns_api.*`` config values."""
        host = config.get_value("dns_api.host")
        token = config.get_value("dns_api.access_token")
        login = config.get_value("dns_api.user")
        validate_only = config.get_value("dns_api.validate_only")

        base_url = "https://{}/v2.3/".format(host)

        # We don't use cert and key anymore. See WALLE-2994
        return dns_api.DnsApiClient(
            cert=None, key=None, login=login, token=token, validate_only=validate_only, base_url=base_url
        )

    def get_zone_owners(self, zone):
        """Return the ACL entries of ``zone`` as a lazy iterator of ``DnsAclKey``.

        The DNS API is queried eagerly (before this method returns); only the
        parsing of the ACL string is lazy.

        :raises DnsZoneNotFound: the DNS API reports the zone as not found.
        :raises DnsCommunicationError: any other DNS API error.
        :raises DnsInvalidZone: the zone info has no ``acl-list`` entry.
        """
        try:
            zone_info = self._dns_api_client.zone_info(zone)
        except dns_api.DnsApiNotFound as e:
            raise DnsZoneNotFound(zone, e) from e
        except dns_api.DnsApiError as e:
            raise DnsCommunicationError(e) from e
        acl_list = zone_info.get("acl-list")
        if acl_list is None:
            raise DnsInvalidZone(zone)
        return self._parse_dns_acl(acl_list)

    def is_zone_owner(self, zone):
        """Return True if the DNS API client's login is listed as a USER key in the zone ACL.

        Raises the same exceptions as ``get_zone_owners``. That method already
        translates DNS API errors, so no extra handling is needed here.
        """
        return any(
            key.key_type == DnsAclKeyType.USER and key.name == self._dns_api_client.login
            for key in self.get_zone_owners(zone)
        )

    @staticmethod
    def _parse_dns_acl(acl_line):
        """Yield a ``DnsAclKey`` for every ``KEY_TYPE(name);`` entry in ``acl_line``."""
        for match in DNS_ACL_REGEXP.finditer(acl_line):
            key_type = match.group(1)
            name = match.group(2)
            yield DnsAclKey(name, key_type)

    def apply_operations(self, operations):
        """Submit ``operations`` to the DNS API in a single ``apply_primitives`` call.

        Each operation is converted with its ``to_slayer_api_primitive()`` method;
        the API client's result is returned unchanged.
        """
        return self._dns_api_client.apply_primitives(
            [op.to_slayer_api_primitive() for op in operations],
            show_operations=True,
        )

    def get_aaaa(self, hostname) -> str:
        """Resolve ``hostname``'s AAAA record via the local resolver."""
        return self._local_ns_client.get_ip_address(hostname, "AAAA")

    def get_a(self, hostname) -> str:
        """Resolve ``hostname``'s A record via the local resolver."""
        return self._local_ns_client.get_ip_address(hostname, "A")

    def get_ptr(self, ip_address):
        """Resolve the PTR record for ``ip_address`` via the local resolver."""
        return self._local_ns_client.get_ptr(ip_address)


def address_from_reverse_name(reverse_name, raw=False):
    """Convert a reverse-lookup name (``*.in-addr.arpa`` / ``*.ip6.arpa``) to an IP address.

    The trailing dot is optional. Labels containing ``/`` (as used in
    RFC 2317-style classless delegation) are skipped. An IPv6 name must have
    the exact length of a full 32-nibble reverse name.

    :param reverse_name: the reverse-lookup owner name.
    :param raw: if true, return an ``ipaddress.IPv4Address``/``IPv6Address``
        object instead of its string form.
    :raises ipaddress.AddressValueError: the name is not a valid reverse name.
    """
    fqdn = reverse_name if reverse_name.endswith(".") else reverse_name + "."
    v4_suffix, v6_suffix = ".in-addr.arpa.", ".ip6.arpa."

    if fqdn.endswith(v4_suffix):
        labels = fqdn[: -len(v4_suffix)].split(".")
        octets = [label for label in labels[::-1] if "/" not in label]
        result = ipaddress.IPv4Address(".".join(octets))
    elif fqdn.endswith(v6_suffix):
        # 32 one-character nibble labels joined by 31 dots, plus the suffix.
        if len(fqdn) != 32 + 31 + len(v6_suffix):
            raise ipaddress.AddressValueError("Invalid reverse pointer name: {}".format(reverse_name))

        labels = fqdn[: -len(v6_suffix)].split(".")
        nibbles = [label for label in labels[::-1] if "/" not in label]
        hextets = ("".join(nibbles[start : start + 4]) for start in range(0, 32, 4))
        result = ipaddress.IPv6Address(":".join(hextets))
    else:
        raise ipaddress.AddressValueError("Invalid reverse pointer name: {}".format(reverse_name))

    if raw:
        return result
    return str(result)


def network_from_reverse_name(reverse_name, raw=False):
    """Convert a reverse zone name (``*.in-addr.arpa`` / ``*.ip6.arpa``) to the network it covers.

    Missing trailing labels are padded with zeros. The prefix length is 8 bits
    per IPv4 label or 4 bits per IPv6 nibble label, unless a label already
    carries an explicit ``/prefix`` (e.g. ``0/26.2.0.192.in-addr.arpa``).

    :param reverse_name: the reverse zone name; the trailing dot is optional.
    :param raw: if true, return an ``ipaddress.IPv4Network``/``IPv6Network``
        object instead of its string form.
    :raises ipaddress.AddressValueError: the name is not under a reverse domain.
    :raises ValueError: the derived network string is not a valid network.
    """
    fqdn = reverse_name if reverse_name.endswith(".") else reverse_name + "."
    v4_suffix, v6_suffix = ".in-addr.arpa.", ".ip6.arpa."

    if fqdn.endswith(v4_suffix):
        octets = fqdn[: -len(v4_suffix)].split(".")[::-1]
        prefix_len = 8 * len(octets)
        octets = octets + ["0"] * (4 - len(octets))
        cidr = ".".join(octets)
        if "/" not in cidr:
            cidr = "{}/{}".format(cidr, prefix_len)
        result = ipaddress.IPv4Network(cidr)
    elif fqdn.endswith(v6_suffix):
        nibbles = fqdn[: -len(v6_suffix)].split(".")[::-1]
        prefix_len = 4 * len(nibbles)
        nibbles = nibbles + ["0"] * (32 - len(nibbles))
        cidr = ":".join("".join(nibbles[start : start + 4]) for start in range(0, 32, 4))
        if "/" not in cidr:
            cidr = "{}/{}".format(cidr, prefix_len)
        result = ipaddress.IPv6Network(cidr)
    else:
        raise ipaddress.AddressValueError("Invalid reverse pointer name: {}".format(reverse_name))

    if raw:
        return result
    return str(result)


def is_domain_in_zone(qname, zone):
    """Check whether ``qname`` lies inside ``zone``.

    Both names may be given with or without the trailing dot. Comparison is
    case-insensitive, as DNS name comparison is (RFC 4343).

    When both names end in ``.arpa.``, the check is address-based. ``qname`` is
    converted to an IP address, ``zone`` to a network, and the result is
    whether the address belongs to the network.

    Otherwise ``qname`` must be a proper subdomain of ``zone``. The zone apex
    itself (``qname == zone``) is not considered part of the zone.

    :raises ValueError: (reverse-name branch only) if either name cannot be
        parsed as a reverse name, e.g. ``ipaddress.AddressValueError``.
    """
    # Normalise case once so that both the suffix checks below and the
    # reverse-name parsers (which match lowercase ".in-addr.arpa."/".ip6.arpa.")
    # treat mixed-case names correctly.
    qname = qname.lower()
    zone = zone.lower()

    # ensure we have names that end with dots.
    if not qname.endswith("."):
        qname += "."
    if not zone.endswith("."):
        zone += "."

    if zone.endswith(".arpa.") and qname.endswith(".arpa."):
        address = address_from_reverse_name(qname, raw=True)
        network = network_from_reverse_name(zone, raw=True)
        return address in network
    return qname.endswith("." + zone)
