# -*- coding: utf-8 -*-
import logging
import socket

from netaddr import AddrFormatError
from passport.backend.core.cache.backend.locmem import LocalMemoryCache
from passport.backend.core.conf import settings
from passport.backend.core.exceptions import BaseCoreError
from passport.backend.core.types.ip.ip import IP


# Module-level cache instance, constructed at import time with
# ttl=settings.DNS_CLEANING_INTERVAL. DNSResolver uses it when no cache is
# passed to its constructor and settings.USE_GLOBAL_DNS_CACHE is truthy.
_DNS_CACHE = LocalMemoryCache(ttl=settings.DNS_CLEANING_INTERVAL)

# Leading text of IPv6 addresses that DNSResolver.query logs as NAT64
# addresses and rejects with DNSError.
NAT64_IP6_PREFIX = '64:ff9b:'

# Logger used for DNS resolution warnings in this module.
log = logging.getLogger('passport.useragent.dns')


class DNSError(BaseCoreError):
    """Base error raised by DNSResolver.query when a host cannot be resolved.

    Raised directly when getaddrinfo fails with an error other than
    EAI_NONAME, or when a resolved address is a NAT64 address.
    """


class DNSNoNameError(DNSError):
    """Raised by DNSResolver.query when getaddrinfo reports EAI_NONAME."""


class DNSResolver(object):
    """Resolve host names to IPv6 address strings, caching the answers.

    Lookups go through ``socket.getaddrinfo`` restricted to ``AF_INET6``.
    Results are stored in ``self.cache`` keyed by the host string, so a
    repeated ``query`` for the same host is answered from the cache until the
    entry expires or ``invalidate`` is called.
    """

    def __init__(self, cache=None):
        """Choose the cache backend.

        :param cache: object providing ``get(key)``, ``set(key, value)`` and
            ``delete(key)``. When omitted, the module-level ``_DNS_CACHE`` is
            used if ``settings.USE_GLOBAL_DNS_CACHE`` is truthy; otherwise a
            fresh ``LocalMemoryCache`` private to this resolver is created.
        """
        if cache is not None:
            self.cache = cache
        elif settings.USE_GLOBAL_DNS_CACHE:
            self.cache = _DNS_CACHE
        else:
            self.cache = LocalMemoryCache()

    def query(self, host):
        """Return the list of IP address strings for ``host``.

        :param host: host name or IP address literal.
        :returns: list of address strings without duplicates, in the order
            ``getaddrinfo`` reported them. For an IP literal the result is
            ``[host]``. On a cache hit the cached list object itself is
            returned.
        :raises DNSNoNameError: ``getaddrinfo`` failed with ``EAI_NONAME``.
        :raises DNSError: ``getaddrinfo`` failed with any other ``gaierror``,
            or one of the resolved addresses starts with ``NAT64_IP6_PREFIX``.
        """
        # A falsy cached value (missing entry or empty list) counts as a miss.
        ips = self.cache.get(host)
        if ips:
            return ips

        # IP addresses are not resolved; they are cached as their own answer.
        try:
            IP(host)
        except AddrFormatError:
            pass
        else:
            self.cache.set(host, [host])
            return [host]

        # Only the getaddrinfo call can raise gaierror, so the try body is
        # limited to it.
        # NOTE(review): getaddrinfo may also raise UnicodeError for hosts that
        # fail IDNA encoding; that is not translated to DNSError here --
        # confirm whether callers need it.
        try:
            addrinfo = socket.getaddrinfo(
                host, None, socket.AF_INET6, socket.SOCK_STREAM, socket.IPPROTO_TCP
            )
        except socket.gaierror as e:
            log.warning('gaierror: query failed for %s: %s (%s)', host, e.strerror, e.errno)
            if e.errno == socket.EAI_NONAME:
                raise DNSNoNameError('Unknown host: %s' % host)
            raise DNSError(e.strerror)

        # Collect the resolved addresses for the host and store them in the
        # cache. Nothing is cached if any address turns out to be NAT64.
        ips = []
        for _family, _socktype, _proto, _canonname, sockaddr in addrinfo:
            ip = sockaddr[0]
            if ip.startswith(NAT64_IP6_PREFIX):
                log.warning('Got nat64 address for %s: %s', host, ip)
                raise DNSError('Unknown host: %s (got nat64 address)' % host)
            if ip in ips:
                continue
            ips.append(ip)
        self.cache.set(host, ips)
        return ips

    def invalidate(self, host):
        """Drop the cached answer for ``host`` so the next query resolves it again."""
        self.cache.delete(host)
