"""DNS records common routines."""

import logging
from collections import deque

from sepelib.core.exceptions import LogicalError
from walle import network
from walle.clients import dns as dns_clients
from walle.hosts import Host
from walle.projects import Project
from walle.util import net

# Default recursion limit for _build_dns_map. Each starting name/address is looked up at depth 0.
# The names/addresses found there are looked up at depth 1. Nothing found at depth 1 is queried further.
# NOTE(review): the original author questioned this value ("really?"); confirm one extra hop is enough.
MAX_DNS_MAP_DEPTH = 1

log = logging.getLogger(__name__)


def _bfs_discover(dns_map, hostname, ip_v4_address_list, ip_v6_address_list):
    """Collect the DNS records reachable from a host's name and addresses.

    The walk starts from the (normalized) hostname, then every given IPv6 address, then every given
    IPv4 address. It follows A/AAAA records from names to addresses and PTR records from addresses to
    names, breadth-first. It consults only the pre-fetched ``dns_map`` and never queries DNS itself.
    Each name/address is expanded at most once.

    :param dns_map: ``{"A": {...}, "AAAA": {...}, "PTR": {...}}`` as built by ``_build_dns_map``.
    :param hostname: FQDN to start from; a trailing dot is added if missing.
    :param ip_v4_address_list: IPv4 addresses to start from, or ``None``/empty.
    :param ip_v6_address_list: IPv6 addresses to start from, or ``None``/empty.
    :return: list of discovered records as ``(record_type, left, right)`` tuples, in discovery order.
    """
    hostname = _normalize_fqdn(hostname)

    # (item_type, value) pairs waiting to be expanded; deque gives O(1) pops from the left.
    queue = deque([("fqdn", hostname)])
    queue.extend(("ipv6", ip_v6_address) for ip_v6_address in ip_v6_address_list or ())
    queue.extend(("ipv4", ip_v4_address) for ip_v4_address in ip_v4_address_list or ())

    seen = set()
    records = []

    while queue:
        item_type, item_value = queue.popleft()
        if item_value in seen:
            continue
        seen.add(item_value)

        if item_type in ("ipv6", "ipv4"):
            # Address -> name via PTR.
            current_hostname = dns_map["PTR"].get(item_value)
            if current_hostname:
                records.append(("PTR", item_value, current_hostname))
                queue.append(("fqdn", current_hostname))
        else:
            # Name -> addresses via AAAA and A.
            current_v6_ip = dns_map["AAAA"].get(item_value)
            if current_v6_ip:
                records.append(("AAAA", item_value, current_v6_ip))
                queue.append(("ipv6", current_v6_ip))

            current_v4_ip = dns_map["A"].get(item_value)
            if current_v4_ip:
                records.append(("A", item_value, current_v4_ip))
                queue.append(("ipv4", current_v4_ip))

    return records


def _get_ops_from_stack(project: Project, stack, hostname, ip_v4_address_list, ip_v6_address_list):
    """Turn discovered DNS records into the operations that bring them to the wanted state.

    For each address family that has addresses, the wanted state is:
    - one A/AAAA record ``hostname -> first address``;
    - the matching PTR record ``first address -> hostname``.
    An already existing record for any of the given addresses is accepted in place of the "main" one.
    A discovered record is deleted when it mentions the hostname or one of the given addresses and is
    not acceptable. Records that do not mention them are left alone.

    :param project: project of the host. When its DNS box config uses rurikk DNS, PTR records are
        neither added nor deleted here, because that DNS creates PTR records automatically.
    :param stack: discovered records as ``(record_type, left, right)`` tuples (see ``_bfs_discover``).
        The list is not modified.
    :param hostname: wanted FQDN; a trailing dot is added if missing.
    :param ip_v4_address_list: wanted IPv4 addresses (a list, possibly empty).
    :param ip_v6_address_list: wanted IPv6 addresses (a list, possibly empty).
    :return: list of ``DnsApiOperation``: deletions (in reverse discovery order) followed by additions.
    """
    hostname = _normalize_fqdn(hostname)

    records_to_add = []
    records_to_keep = {}

    def _create_expected_records(record_type, ip_list):
        """Register the "expected" records for one address family.

        We want to keep or create only one IP per FQDN (and one FQDN per IP address),
        but we accept any of the host's IPs if an existing DNS record already uses it.
        """
        if ip_list:
            main_record = (record_type, hostname, ip_list[0])
            records_to_add.append(main_record)
            records_to_keep.update({(record_type, hostname, ip): main_record for ip in ip_list})

            main_ptr_record = ("PTR", ip_list[0], hostname)
            records_to_add.append(main_ptr_record)
            records_to_keep.update({("PTR", ip, hostname): main_ptr_record for ip in ip_list})

    _create_expected_records("A", ip_v4_address_list)
    _create_expected_records("AAAA", ip_v6_address_list)

    # Only records that refer to our desired hostname or one of our addresses are candidates for deletion
    # or keeping. Everything else reached during discovery belongs to other hosts and is left untouched.
    wanted_values = {hostname, *ip_v4_address_list, *ip_v6_address_list}
    look_at_records = {
        (record_type, left) for record_type, left, right in stack if left in wanted_values or right in wanted_values
    }

    # NOTE(rocco66): rurikk_dns has automatic for PTR creation https://st.yandex-team.ru/CLOUD-78777
    box_config = project.get_dns_box_config()
    ignore_ptr = box_config and box_config.should_use_rurikk_dns()

    operations = []
    # Walk records in reverse discovery order without consuming the caller's list.
    for item in reversed(stack):
        record_type, left, right = item
        if (record_type, left) not in look_at_records:
            continue

        if item in records_to_keep:
            # An acceptable record already exists, so its "main" counterpart need not be created.
            try:
                records_to_add.remove(records_to_keep[item])
            except ValueError:
                pass  # already removed because of another acceptable record
            continue

        if ignore_ptr and record_type == "PTR":
            continue
        operations.append(dns_clients.DnsApiOperation.delete(record_type, left.rstrip("."), right))

    for record_type, left, right in records_to_add:
        if ignore_ptr and record_type == "PTR":
            continue
        operations.append(dns_clients.DnsApiOperation.add(record_type, left.rstrip("."), right.rstrip(".")))

    return operations


def _get_operations_for_one_host(
    hostname, ip_v4_address_list, ip_v6_address_list, statbox_logger, __max_dns_map_depth=MAX_DNS_MAP_DEPTH
):
    """Compute (and log) the DNS operations that make one host's records match the wanted name/addresses.

    The steps are:
    - find the host's project;
    - fetch the surrounding DNS records (up to ``__max_dns_map_depth`` hops);
    - derive the add/delete operations;
    - log each operation through a child of ``statbox_logger``.
    With ``None``/empty address lists nothing is added, so only deletions are produced.
    """
    project = _find_project(hostname)
    dns_map = _build_dns_map(project, hostname, ip_v4_address_list, ip_v6_address_list, max_depth=__max_dns_map_depth)
    discovered_records = _bfs_discover(dns_map, hostname, ip_v4_address_list, ip_v6_address_list)

    wanted_ip_v4 = ip_v4_address_list or []
    wanted_ip_v6 = ip_v6_address_list or []
    dns_operations = _get_ops_from_stack(project, discovered_records, hostname, wanted_ip_v4, wanted_ip_v6)

    # Only the first address of each family makes it into DNS, so that is what gets logged as "wanted".
    host_logger = statbox_logger.get_child(
        wanted_hostname=hostname,
        wanted_ip_v4_address=wanted_ip_v4[0] if wanted_ip_v4 else None,
        wanted_ip_v6_address=wanted_ip_v6[0] if wanted_ip_v6 else None,
    )

    for dns_operation in dns_operations:
        host_logger.log(**dns_operation.to_statbox())

    return dns_operations


def _find_project(hostname):
    """Return the project of the host called ``hostname``.

    If no host has exactly that name, retry with the name returned by
    ``network.try_to_remove_fb_prefix`` (when it returns one). Otherwise the original
    ``Host.DoesNotExist`` propagates.

    NOTE(review): only ``id`` and ``yc_dns_zone_id`` are loaded on the project, while callers use
    ``project.get_dns()`` / ``project.get_dns_box_config()`` — confirm those need no other fields.
    """
    try:
        host = Host.objects.only("project").get(name=hostname)
    except Host.DoesNotExist:
        fallback_hostname = network.try_to_remove_fb_prefix(hostname)
        if not fallback_hostname:
            raise
        host = Host.objects.only("project").get(name=fallback_hostname)

    return Project.objects.only("id", "yc_dns_zone_id").get(id=host.project)


def _build_dns_map(project, hostname, ip_v4_address_list, ip_v6_address_list, max_depth=MAX_DNS_MAP_DEPTH):
    """Fetch the DNS records around a host into an in-memory map.

    Starting from the hostname and every given address, this looks up A/AAAA records for names and
    PTR records for addresses. It then recurses into what was found, stopping once the depth exceeds
    ``max_depth``.

    :param project: project whose ``get_dns()`` client performs the lookups.
    :param hostname: FQDN to start from; a trailing dot is added if missing.
    :param ip_v4_address_list: IPv4 addresses to start from, or ``None``/empty.
    :param ip_v6_address_list: IPv6 addresses to start from, or ``None``/empty.
    :param max_depth: deepest level that is still looked up; starting points are at depth 0.
    :return: ``{"A": {fqdn: ip}, "AAAA": {fqdn: ip}, "PTR": {ip: fqdn}}`` holding every non-empty
        lookup result.
    """
    dns_map = {
        "A": {},
        "AAAA": {},
        "PTR": {},
    }

    dns_client = project.get_dns()

    def lookup(record_type, key, resolve):
        # Reuse an earlier result instead of querying DNS again. Note that
        # ``dict.get(key, resolve(key))`` would evaluate ``resolve`` eagerly and query DNS on every call.
        cached = dns_map[record_type].get(key)
        if cached:
            return cached
        return resolve(key)

    def build_hostname_map(fqdn, depth):
        if depth > max_depth:
            return

        current_ipv4 = lookup("A", fqdn, dns_client.get_a)
        if current_ipv4:
            dns_map["A"][fqdn] = current_ipv4
            build_ip_map(current_ipv4, depth + 1)

        current_ipv6 = lookup("AAAA", fqdn, dns_client.get_aaaa)
        if current_ipv6:
            dns_map["AAAA"][fqdn] = current_ipv6
            build_ip_map(current_ipv6, depth + 1)

    def build_ip_map(ip, depth):
        # The PTR lookup is the same for IPv4 and IPv6 addresses.
        if depth > max_depth:
            return

        current_hostname = lookup("PTR", ip, dns_client.get_ptr)
        if current_hostname:
            dns_map["PTR"][ip] = current_hostname
            build_hostname_map(current_hostname, depth + 1)

    build_hostname_map(_normalize_fqdn(hostname), 0)
    for ip_v4_address in ip_v4_address_list or ():
        build_ip_map(ip_v4_address, 0)
    for ip_v6_address in ip_v6_address_list or ():
        build_ip_map(ip_v6_address, 0)

    return dns_map


def get_operations_for_dns_records(records, statbox_logger) -> list[dns_clients.DnsApiOperation]:
    """Compute DNS operations for a batch of wanted A/AAAA records.

    The records are grouped by name, and their addresses are expanded with ``net.explode_ip``. Each
    host is then reconciled on its own, and the resulting operations are concatenated in host order.
    A later record of the same type and name replaces an earlier one.

    :raises LogicalError: if a record's type is neither ``"A"`` nor ``"AAAA"``.
    """
    addresses_by_host = {}
    for record in records:
        host_addresses = addresses_by_host.setdefault(record.name, {})
        if record.type == "A":
            family_key = "ip_v4"
        elif record.type == "AAAA":
            family_key = "ip_v6"
        else:
            raise LogicalError()
        host_addresses[family_key] = [net.explode_ip(v) for v in record.value]

    all_operations = []
    for host_name, host_addresses in addresses_by_host.items():
        all_operations.extend(
            _get_operations_for_one_host(
                host_name, host_addresses.get("ip_v4"), host_addresses.get("ip_v6"), statbox_logger
            )
        )
    return all_operations


def get_delete_operations_for_fqdns(fqdns, statbox_logger):
    """Compute operations that remove the DNS records found for each FQDN in ``fqdns``.

    No addresses are wanted, so nothing is added and only deletions are produced, in ``fqdns`` order.
    """
    return [
        operation
        for fqdn in fqdns
        for operation in _get_operations_for_one_host(fqdn, None, None, statbox_logger)
    ]


def _normalize_fqdn(fqdn):
    if not fqdn.endswith("."):
        fqdn += "."

    return fqdn
