import datetime
import socket

import dns.resolver
from django.conf import settings

from infra.cauth.server.master.api.models import DNS_STATUS, ServerFlags

def _build_resolver():
    """Create the resolver shared by is_cname() and have_correct_ptr().

    The timeout/lifetime values are in seconds (see dnspython's
    ``Resolver.timeout`` and ``Resolver.lifetime``). They are deliberately
    small, presumably so a slow nameserver cannot stall get_dns_status()
    for long. Confirm against how this module is used.
    """
    resolver = dns.resolver.Resolver()
    resolver.timeout = 0.3
    resolver.lifetime = 1
    return resolver


dns_resolver = _build_resolver()


def get_family_ips(fqdn, family):
    """Resolve *fqdn* with getaddrinfo, restricted to one address *family*.

    :param fqdn: host name (or numeric address) to resolve.
    :param family: socket address family, e.g. ``socket.AF_INET`` or
        ``socket.AF_INET6``.
    :return: set of address strings (the first element of each returned
        sockaddr); an empty set when the name cannot be resolved.
    """
    try:
        addr_infos = socket.getaddrinfo(fqdn, None, family)
    # UnicodeError is raised when one of the name's labels is longer than
    # 63 characters (it happens while IDNA-encoding the host name).
    except (socket.gaierror, UnicodeError):
        return set()
    return {
        sockaddr[0]
        for _family, _type, _proto, _canonname, sockaddr in addr_infos
    }


def get_ips(fqdn):
    """Return every IPv4 and IPv6 address *fqdn* resolves to.

    IPv4 is looked up first, then IPv6. Each lookup goes through
    get_family_ips(), so an unresolvable name simply contributes nothing.
    The result is a new set of address strings.
    """
    addresses = set()
    for family in (socket.AF_INET, socket.AF_INET6):
        addresses |= get_family_ips(fqdn, family)
    return addresses


def is_cname(fqdn):
    """Tell whether *fqdn* has a CNAME record according to ``dns_resolver``.

    A CNAME query that returns without raising counts as "is a CNAME".
    Any dnspython ``DNSException`` (missing name, no answer, timeout, ...)
    is reported as False rather than propagated.
    """
    try:
        dns_resolver.query(fqdn, dns.rdatatype.CNAME)
    except dns.exception.DNSException:
        has_cname = False
    else:
        has_cname = True
    return has_cname


def have_correct_ptr(ip, fqdn):
    """Return True if one of the PTR records of *ip* points back to *fqdn*.

    DNS names are case-insensitive (RFC 4343). The textual form of a PTR
    target is absolute (ends with the root dot), while callers may pass
    *fqdn* with or without that dot. Both sides are therefore lower-cased
    and stripped of the trailing dot before comparison, so e.g.
    ``Host.Example.COM.`` matches ``host.example.com``.

    :param ip: textual IPv4 or IPv6 address whose reverse name is queried.
    :param fqdn: host name expected among the PTR targets.
    :return: True on a match. False if no PTR target matches, if the lookup
        fails with any dnspython ``DNSException``, or if a PTR target cannot
        be converted to text.
    """
    # Kept outside the try: a malformed *ip* is a caller error and should
    # surface instead of being reported as "no matching PTR".
    reverse_name = dns.reversename.from_address(ip)
    expected = fqdn.rstrip('.').lower()
    try:
        ptr_names = dns_resolver.query(reverse_name, dns.rdatatype.PTR)
        ptr_targets = {str(ptr).rstrip('.').lower() for ptr in ptr_names}
    # UnicodeDecodeError is raised when a PTR record contains characters
    # that cannot be decoded into text.
    except (dns.exception.DNSException, UnicodeDecodeError):
        return False
    return expected in ptr_targets


def get_dns_status(fqdn):
    """Classify the DNS health of *fqdn* as one of the ``DNS_STATUS`` values.

    The name is first made absolute (a trailing ``.`` is appended if
    missing). The checks are then applied in order:

    * the name has a CNAME record       -> ``DNS_STATUS.CNAME``
    * it resolves to no address at all  -> ``DNS_STATUS.CANT_RESOLVE``
    * no address has a matching PTR     -> ``DNS_STATUS.HAS_NO_PTRS``
    * only some addresses have one      -> ``DNS_STATUS.SOME_PTR_MISSING``
    * every address has one             -> ``DNS_STATUS.OK``
    """
    absolute_fqdn = fqdn if fqdn.endswith('.') else fqdn + '.'

    if is_cname(absolute_fqdn):
        return DNS_STATUS.CNAME

    addresses = get_ips(absolute_fqdn)
    if not addresses:
        return DNS_STATUS.CANT_RESOLVE

    # Every address is checked (no short-circuit), so all PTR lookups run.
    matched = sum(
        1 for address in addresses if have_correct_ptr(address, absolute_fqdn)
    )

    if matched == 0:
        return DNS_STATUS.HAS_NO_PTRS
    if matched != len(addresses):
        return DNS_STATUS.SOME_PTR_MISSING
    return DNS_STATUS.OK


def create_or_update_dns_status(session, server):
    """Return the ServerFlags for *server*, refreshing its DNS status if stale.

    The flags object comes from ``session.merge`` on a ``ServerFlags`` keyed
    by ``server.id``. When no DNS status is stored yet, or the stored one
    reports ``is_expired``, the status is recomputed with get_dns_status().
    Its expiry is then set ``settings.DNS_STATUS_TTL_DAYS`` days after the
    current naive local time (``datetime.datetime.now()``). Otherwise the
    flags are returned untouched. Persisting the change is presumably left
    to the caller's session handling. Verify at call sites.
    """
    flags = session.merge(ServerFlags(server_id=server.id))
    needs_refresh = flags.dns_status is None or flags.is_expired
    if not needs_refresh:
        return flags

    # The expiry timestamp is taken before the (potentially slow) DNS checks.
    refreshed_until = datetime.datetime.now() + datetime.timedelta(
        days=settings.DNS_STATUS_TTL_DAYS)
    flags.dns_status = get_dns_status(server.fqdn)
    flags.dns_status_expire_at = refreshed_until
    return flags
