# encoding: UTF-8

import struct

import dns.flags
import dns.message
import dns.opcode
import dns.query
import dns.rdtypes.ANY.NS
import dns.rdtypes.ANY.SOA
import dns.renderer
import gevent.server
import gevent.socket
from ws_properties.utils.logs import get_logger_for_instance

from appcore.injection import Injected
from appcore.tx.plugin import tx_manager
from dns_hosting.dao.domains import DomainRepository
from dns_hosting.dao.domains import RecordRepository


class DNSMasterDNSServer(gevent.server.StreamServer):
    """TCP DNS server answering SOA, AXFR and IXFR queries for hosted zones.

    Every DNS message on the connection (in both directions) is framed with
    a two-byte, network-order length prefix (``_LEN_STRUCT``). Zone data is
    read through the injected ``domain_repository`` / ``record_repository``
    inside ``tx_manager`` blocks. IXFR requests are answered with a full
    zone transfer, exactly like AXFR.
    """

    # Two-byte, network-order length prefix preceding every TCP DNS message.
    _LEN_STRUCT = struct.Struct('!H')

    # Per-connection socket timeout, in seconds.
    _SOCKET_TIMEOUT = 5

    # TTL of the synthesized SOA and apex NS records.
    _DEFAULT_TTL = 900

    # EDNS payload sizes below this MUST be treated as 512 (RFC 6891 6.2.5).
    _MIN_EDNS_PAYLOAD = 512

    # Largest DNS message that can be described by the two-byte prefix.
    _MAX_MESSAGE_SIZE = 65535

    domains = Injected('domain_repository')  # type: DomainRepository
    records = Injected('record_repository')  # type: RecordRepository

    def __init__(self, listener, app, backlog=None, spawn='default',
                 **ssl_args):
        super(DNSMasterDNSServer, self).__init__(
            listener=listener,
            backlog=backlog,
            spawn=spawn,
            **ssl_args
        )
        # `app` must provide `app_context()`; each connection runs inside it.
        self.app = app
        self._logger = get_logger_for_instance(self)
        # NS rrset emitted at the apex of every transferred zone.
        self._ns_rdatas = [
            dns.rdtypes.ANY.NS.NS(
                dns.rdataclass.IN,
                dns.rdatatype.NS,
                dns.name.from_text('dns1.yandex.net.'),
            ),
            dns.rdtypes.ANY.NS.NS(
                dns.rdataclass.IN,
                dns.rdatatype.NS,
                dns.name.from_text('dns2.yandex.net.'),
            ),
        ]
        # MNAME and RNAME of the synthesized SOA record.
        self._ns_name = dns.name.from_text('dns1.yandex.net.')
        self._admin_name = dns.name.from_text('dns-hosting.yandex.ru.')

    def handle(self, socket, addr):
        # type: (gevent.socket.socket, tuple) -> None
        """Serve DNS queries on one TCP connection until the peer closes it.

        Queries are processed sequentially. A socket error (including a
        timeout) ends the connection. Any other error is logged; if the
        failing query had already been parsed, a SERVFAIL is attempted for
        it. The socket is always closed on exit.
        """
        socket.settimeout(self._SOCKET_TIMEOUT)
        addr = '%s@%s' % (addr[0], addr[1])
        query = None

        try:
            self._logger.debug('%s: Connection started.', addr)

            with self.app.app_context():
                while True:
                    # Reset before reading so that a failure while reading or
                    # parsing the next message is not reported (and answered)
                    # as a failure of the previously handled query.
                    query = None
                    query = self._read_query(socket)

                    if query is None:
                        break
                    elif query.edns > 0:
                        # Only EDNS version 0 is supported; edns == -1 means
                        # the query carries no OPT record at all.
                        self._answer_rcode(socket, query, dns.rcode.BADVERS)
                        self._logger.warning('%s: Query BADVERS: %s', addr, query)
                    else:
                        self._handle_query(socket, query)
        except gevent.socket.error as e:
            # strerror is None for e.g. socket timeouts; fall back to str(e).
            self._logger.error('%s: %s', addr, e.strerror or e)
        except Exception:
            if query is not None:
                self._logger.exception(
                    '%s: Query failed.\n\n%s\n\n',
                    addr,
                    query
                )
                try:
                    self._answer_rcode(socket, query, dns.rcode.SERVFAIL)
                except gevent.socket.error as e:
                    # Best effort only: the peer may already be gone.
                    self._logger.error(
                        '%s: Failed to send SERVFAIL: %s',
                        addr,
                        e.strerror or e,
                    )
            else:
                self._logger.exception('%s: Error occurred.', addr)
        finally:
            self._logger.debug('%s: Connection closed.', addr)
            socket.close()

    def _handle_query(self, socket, query):
        # type: (gevent.socket.socket, dns.message.Message) -> None
        """Dispatch a parsed query on the rdtype of its first question.

        SOA gets a single SOA answer; AXFR and IXFR both get a full zone
        transfer; any other rdtype is REFUSED; a query with no question is
        answered with FORMERR.
        """
        if query.question:
            rdtype = query.question[0].rdtype

            if rdtype == dns.rdatatype.SOA:
                self._answer_soa(socket, query)
                self._logger.info('SOA Query processing completed: %s', query)
            elif rdtype == dns.rdatatype.AXFR:
                self._answer_axfr(socket, query)
                self._logger.info('AXFR Query processing completed: %s', query)
            elif rdtype == dns.rdatatype.IXFR:
                # Incremental transfer is not implemented; a full transfer is
                # sent instead.
                self._answer_axfr(socket, query)
                self._logger.info('IXFR Query processing completed: %s', query)
            else:
                self._answer_rcode(socket, query, dns.rcode.REFUSED)
                self._logger.info('Query REFUSED: %s', query)
        else:
            self._answer_rcode(socket, query, dns.rcode.FORMERR)
            self._logger.info('Query FORMERR: %s', query)

    def _recv_exact(self, socket, size):
        # type: (gevent.socket.socket, int) -> bytes
        """Read up to ``size`` bytes, looping over short TCP reads.

        Returns fewer than ``size`` bytes only if the peer closed the
        connection first (an empty result means EOF before any byte).
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = socket.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def _read_query(self, socket):
        # type: (gevent.socket.socket) -> dns.message.Message|None
        """Read one length-prefixed DNS message from the socket.

        Returns None when the peer closed the connection cleanly between
        messages. Raises EOFError if the connection is closed in the middle
        of a message, and propagates dnspython parse errors from
        ``dns.message.from_wire``.
        """
        prefix_size = self._LEN_STRUCT.size

        raw_length = self._recv_exact(socket, prefix_size)
        if not raw_length:
            return None
        if len(raw_length) < prefix_size:
            raise EOFError('Connection closed inside a message length prefix.')

        length, = self._LEN_STRUCT.unpack(raw_length)

        wire = self._recv_exact(socket, length)
        if len(wire) < length:
            raise EOFError(
                'Connection closed after %d of %d message bytes.'
                % (len(wire), length)
            )
        return dns.message.from_wire(wire)

    def _answer_rcode(self, socket, query, rcode):
        # type: (gevent.socket.socket, dns.message.Message, int) -> None
        """Send an otherwise empty response to ``query`` carrying ``rcode``."""
        response = dns.message.make_response(query)
        response.set_rcode(rcode)
        self._write_response(socket, response)

    def _answer_soa(self, socket, query):
        # type: (gevent.socket.socket, dns.message.Message) -> None
        """Answer a SOA query authoritatively, or REFUSED on LookupError."""
        question = query.question[0]
        origin = question.name

        try:
            with tx_manager:
                soa_rdata = self._get_soa_rdata(origin)
        except LookupError:
            self._answer_rcode(socket, query, dns.rcode.REFUSED)
        else:
            response = dns.message.make_response(query)
            response.flags |= dns.flags.AA
            response.answer.append(
                dns.rrset.from_rdata(origin, self._DEFAULT_TTL, soa_rdata)
            )
            self._write_response(socket, response)

    def _answer_axfr(self, socket, query):
        # type: (gevent.socket.socket, dns.message.Message) -> None
        """Stream a full zone transfer for the queried origin.

        The answer is SOA, apex NS, every stored record, then SOA again,
        split across as many length-prefixed messages as needed. The
        question section is sent only in the first message. A LookupError
        while loading the zone results in REFUSED.
        """
        question = query.question[0]
        origin = question.name

        try:
            with tx_manager:
                soa_rdata = self._get_soa_rdata(origin)
                # _get_zone_rdatas is a generator: materialize it here so the
                # repository query actually runs inside this transaction and
                # any LookupError it raises is handled below, instead of
                # surfacing half-way through the transfer.
                zone_rdatas = list(self._get_zone_rdatas(origin))
        except LookupError:
            self._answer_rcode(socket, query, dns.rcode.REFUSED)
            return

        ttl = self._DEFAULT_TTL
        rrsets = [dns.rrset.from_rdata(origin, ttl, soa_rdata),
                  dns.rrset.from_rdata_list(origin, ttl, self._ns_rdatas)]
        rrsets.extend(
            dns.rrset.from_rdata(name, record_ttl, rdata)
            for name, record_ttl, rdata in zone_rdatas
        )
        rrsets.append(dns.rrset.from_rdata(origin, ttl, soa_rdata))

        renderer = self._make_renderer(query)
        self._copy_question(query, renderer)
        for rrset in rrsets:
            try:
                renderer.add_rrset(dns.renderer.ANSWER, rrset)
            except dns.exception.TooBig:
                # Current message is full: flush it and continue the answer
                # in a fresh message (without a question section).
                self._write_rendered(socket, renderer)
                renderer = self._make_renderer(query)
                renderer.add_rrset(dns.renderer.ANSWER, rrset)

        self._write_rendered(socket, renderer)

    def _get_soa_rdata(self, origin):
        # type: (dns.name.Name) -> dns.rdtypes.ANY.SOA.SOA
        """Build the SOA rdata for ``origin`` using its stored serial.

        Exceptions from the domain repository (callers handle LookupError)
        are propagated unchanged.
        """
        domain_serial = self.domains.get_domain_serial(origin.to_text())
        return dns.rdtypes.ANY.SOA.SOA(
            dns.rdataclass.IN,
            dns.rdatatype.SOA,
            self._ns_name,
            self._admin_name,
            domain_serial,
            refresh=900,
            retry=90,
            expire=86400,
            minimum=900,
        )

    def _get_zone_rdatas(self, origin):
        # type: (dns.name.Name) -> Iterator[Tuple[dns.name.Name, int, dns.rdata.Rdata]]
        """Yield ``(owner name, ttl, rdata)`` for each stored record of the zone.

        This is a lazy generator: the repository is queried on first
        iteration. Records whose content fails to parse with a
        UnicodeDecodeError are logged and skipped. Other parse errors
        propagate.
        """
        for record in self.records.find_all_by_domain_name(origin.to_text()):
            # Record names are interpreted relative to the zone origin.
            name = dns.name.from_text(record.name, origin)
            try:
                rdata = dns.rdata.from_text(
                    dns.rdataclass.IN,
                    dns.rdatatype.from_text(record.type),
                    record.content,
                    origin,
                    relativize=False,
                )
            except UnicodeDecodeError:
                self._logger.exception(
                    'Failed to make rdata from text for rec id %d.',
                    record.id,
                )
            else:
                yield name, record.ttl, rdata

    def _send_framed(self, socket, wire):
        # type: (gevent.socket.socket, bytes) -> None
        """Send ``wire`` prefixed with its two-byte length, in full.

        ``sendall`` is used because ``send`` may transmit only part of the
        buffer, which would corrupt the length-prefixed stream.
        """
        socket.sendall(self._LEN_STRUCT.pack(len(wire)) + wire)

    def _write_response(self, socket, response):
        # type: (gevent.socket.socket, dns.message.Message) -> None
        """Serialize ``response`` and send it as one framed TCP message."""
        self._send_framed(socket, response.to_wire())

    def _make_renderer(self, query):
        # type: (dns.message.Message) -> dns.renderer.Renderer
        """Create an authoritative-response renderer mirroring ``query``.

        Copies the query id, opcode and RD flag and sets QR and AA. The
        message size limit is the query's EDNS payload size (clamped to at
        least 512), or 65535 when the query has no usable EDNS payload.
        """
        flags = dns.flags.QR | dns.flags.AA | (query.flags & dns.flags.RD)
        flags &= 0x87FF  # clear the opcode bits before setting the query's
        flags |= dns.opcode.to_flags(query.opcode())

        if query.edns >= 0 and query.payload != 0:
            max_size = max(query.payload, self._MIN_EDNS_PAYLOAD)
        else:
            max_size = self._MAX_MESSAGE_SIZE

        return dns.renderer.Renderer(query.id, flags, max_size)

    def _copy_question(self, query, renderer):
        # type: (dns.message.Message, dns.renderer.Renderer) -> None
        """Add every question of ``query`` to ``renderer``'s question section."""
        for rrset in query.question:
            renderer.add_question(rrset.name, rrset.rdtype, rrset.rdclass)

    def _write_rendered(self, socket, renderer):
        # type: (gevent.socket.socket, dns.renderer.Renderer) -> None
        """Finalize the renderer's header and send it as one framed message."""
        renderer.write_header()
        self._send_framed(socket, renderer.get_wire())
