import asyncio
import struct
from typing import Iterable, List, Optional

import dns.exception
import dns.flags
import dns.message
import dns.name
import dns.opcode
import dns.query
import dns.rcode
import dns.rdata
import dns.rdataclass
import dns.rdatatype
import dns.rdtypes.ANY.NS
import dns.rdtypes.ANY.SOA
import dns.renderer
import dns.rrset
from loguru import logger
from sqlalchemy import select
from sqlalchemy.orm import selectinload

from database.db import get_session_maker_replica
from models.domain import Domain
from utils.stat import dns_request_type, tcp_connect_count, dns_rcode, metrics, dns_opcode


class TCPHandler:
    """Serve DNS-over-TCP SOA, AXFR and IXFR queries for domains stored in the database.

    Every TCP message is framed by a 2-byte big-endian length prefix. Zone
    contents come from the replica database (``find_by_name``); the SOA is
    synthesized from the domain's serial and the NS set is the fixed
    ``dns1``/``dns2.yandex.net.`` pair.
    """

    # Largest DNS message that fits behind the 2-byte TCP length prefix.
    _MAX_TCP_MESSAGE_SIZE = 65535

    def __init__(self):
        # NS records emitted right after the opening SOA of every zone transfer.
        self._ns_rdatas = [
            dns.rdtypes.ANY.NS.NS(
                dns.rdataclass.IN,
                dns.rdatatype.NS,
                dns.name.from_text('dns1.yandex.net.'),
            ),
            dns.rdtypes.ANY.NS.NS(
                dns.rdataclass.IN,
                dns.rdatatype.NS,
                dns.name.from_text('dns2.yandex.net.'),
            ),
        ]
        # MNAME and RNAME fields of the synthesized SOA record.
        self._ns_name = dns.name.from_text('dns1.yandex.net.')
        self._admin_name = dns.name.from_text('dns-hosting.yandex.ru.')

    @metrics
    async def handle_tcp(self, reader, writer):
        """Serve length-prefixed DNS messages on one TCP connection until it ends.

        Frames are read in a loop and passed to ``handle_dns``. The loop stops on
        EOF, connection reset or broken pipe, and the writer is always closed.

        :param reader: asyncio stream the client's framed queries are read from.
        :param writer: asyncio stream the framed responses are written to.
        """
        addr = writer.transport.get_extra_info('peername')
        addr = f'{addr[0]}@{addr[1]}'
        logger.debug(f'{addr} connection started')
        try:
            while True:
                try:
                    size, = struct.unpack('!H', await reader.readexactly(2))
                    # Reading the body and answering are inside the try as well:
                    # a peer that disconnects mid-message or mid-transfer must end
                    # the loop cleanly instead of raising out of the handler.
                    data = await reader.readexactly(size)
                    await self.handle_dns(data, addr, writer)
                except asyncio.IncompleteReadError:
                    logger.debug(f'{addr} connection finished')
                    break
                except ConnectionResetError:
                    tcp_connect_count.labels('reset_by_peer_error').inc()
                    logger.debug(f'{addr} connection reset by peer')
                    break
                except BrokenPipeError:
                    tcp_connect_count.labels('broken_pipe_error').inc()
                    logger.warning(f'{addr} broken pipe error')
                    break
        finally:
            # Nothing else closes the writer, so release the transport explicitly.
            writer.close()
            try:
                await writer.wait_closed()
            except ConnectionError:
                # Best effort: the peer may already have torn the connection down.
                pass

    async def write_result(self, writer, result: bytes):
        """Write one DNS message with its 2-byte length prefix and flush it.

        Awaiting ``drain()`` applies back-pressure, so a large zone transfer is
        not accumulated in the write buffer faster than the peer reads it.
        """
        writer.write(struct.pack('!H', len(result)) + result)
        await writer.drain()

    async def handle_dns(self, data: bytes, addr: str, writer):
        '''Handle one DNS request and write the response(s) to ``writer``.

        * A message that cannot be parsed is logged and dropped; no reply is sent.
        * EDNS version > 0 is answered with BADVERS; a message without a
          question is answered with FORMERR.
        * Query types other than SOA/AXFR/IXFR, and names with no matching
          domain, are answered with REFUSED.
        * SOA gets one authoritative answer; AXFR and IXFR are both answered
          with a full zone transfer.

        :param data: one DNS message in wire format, without the TCP length prefix.
        :param addr: peer label used in log lines.
        :param writer: stream the framed response(s) are written to.
        '''
        try:
            query = dns.message.from_wire(data)
        except dns.exception.DNSException as e:
            logger.warning(f'{addr} sent an unparsable DNS message: {e!r}')
            return

        query_text = query.to_text().replace("\n", " ")
        logger.info(f'{addr} has query: {query_text}')
        dns_opcode.labels(dns.opcode.to_text(query.opcode())).inc()

        # The EDNS version is checked first; only one error rcode is sent and counted.
        if query.edns > 0:
            logger.warning(f'BADVERS query: {query}')
            await self.write_result(writer, self._answer_rcode(query, dns.rcode.BADVERS))
            return
        if not query.question:
            logger.info(f'FORMERR query: {query}')
            await self.write_result(writer, self._answer_rcode(query, dns.rcode.FORMERR))
            return

        question: dns.rrset.RRset = query.question[0]
        rdtype = question.rdtype
        origin = question.name
        rdtype_text = dns.rdatatype.to_text(rdtype)
        dns_request_type.labels(rdtype_text).inc()

        # Every query gets a reply: types this server does not serve are refused
        # instead of being left unanswered until the client times out.
        if rdtype not in (dns.rdatatype.SOA, dns.rdatatype.AXFR, dns.rdatatype.IXFR):
            logger.info(f'REFUSED query, unsupported type {rdtype_text} for {origin}')
            await self.write_result(writer, self._answer_rcode(query, dns.rcode.REFUSED))
            return

        domain = await self.find_by_name(origin.to_text())
        if domain is None:
            logger.info(f'REFUSED query, domain {origin} not found')
            await self.write_result(writer, self._answer_rcode(query, dns.rcode.REFUSED))
            return

        if rdtype == dns.rdatatype.SOA:
            response = self._answer_soa(query, domain)
            await self.write_result(writer, response.to_wire())
            logger.info(f'NOERROR SOA query processing completed, domain: {origin}, serial: {domain.serial}')
        else:
            # IXFR is answered with the full zone, exactly like AXFR.
            for renderer in self._answer_axfr(query, domain, origin):
                await self.write_result(writer, renderer.get_wire())
            logger.info(
                f'NOERROR {rdtype_text} query processing completed, domain: {origin}, '
                f'serial: {domain.serial}, count records: {len(domain.records)}'
            )

        dns_rcode.labels('NOERROR').inc()

    def _answer_rcode(self, query: dns.message.Message, rcode: dns.rcode.Rcode) -> bytes:
        """Build a wire-format response to ``query`` carrying only ``rcode``; counts it in metrics."""
        dns_rcode.labels(dns.rcode.Rcode.to_text(rcode)).inc()
        response = dns.message.make_response(query)
        response.set_rcode(rcode)
        return response.to_wire()

    def _answer_axfr(
        self, query: dns.message.Message, domain: Domain, origin: dns.name.Name
    ) -> Iterable[dns.renderer.Renderer]:
        """Yield finished renderers whose messages together carry a full zone transfer.

        RRsets from ``_rrsets`` (SOA, NS, zone records, SOA) are packed into one
        message until the renderer raises ``TooBig``; the full message is then
        yielded and a new one started. The first message echoes the question,
        as RFC 5936 section 2.2.1 requires.
        """
        soa_rdata = self._get_soa_rdata(domain)
        zone_rdatas = self._get_zone_rdatas(domain, origin)

        renderer = self._make_renderer(query)
        question = query.question[0]
        renderer.add_question(question.name, question.rdtype, question.rdclass)

        for rrset in self._rrsets(origin, zone_rdatas, soa_rdata):
            try:
                renderer.add_rrset(dns.renderer.ANSWER, rrset)
            except dns.exception.TooBig:
                renderer.write_header()
                yield renderer
                renderer = self._make_renderer(query)
                renderer.add_rrset(dns.renderer.ANSWER, rrset)

        renderer.write_header()
        yield renderer

    def _answer_soa(self, query: dns.message.Message, domain: Domain) -> dns.message.Message:
        """Build an authoritative (AA) response holding the domain's SOA with TTL 900."""
        soa_rdata = self._get_soa_rdata(domain)
        response = dns.message.make_response(query)
        response.flags |= dns.flags.AA
        response.answer.append(dns.rrset.from_rdata(domain.name, 900, soa_rdata))

        return response

    def _get_soa_rdata(self, domain: Domain) -> dns.rdtypes.ANY.SOA.SOA:
        """Synthesize the zone's SOA rdata from the domain serial and fixed timers."""
        return dns.rdtypes.ANY.SOA.SOA(
            dns.rdataclass.IN,
            dns.rdatatype.SOA,
            self._ns_name,
            self._admin_name,
            domain.serial,
            refresh=900,
            retry=90,
            expire=86400,
            minimum=900,
        )

    def _get_zone_rdatas(self, domain: Domain, origin: dns.name.Name) -> Iterable[tuple]:
        """Yield ``(owner_name, ttl, rdata)`` for every record of ``domain`` that parses.

        Record names and contents are interpreted relative to ``origin``. A
        record that fails to parse is logged and skipped, so one bad row cannot
        abort a transfer that is already being streamed to the client.

        NOTE(review): rows are emitted as stored; an SOA or apex NS row in the
        database would be sent in addition to the synthesized ones — confirm the
        table never holds such rows.
        """
        for record in domain.records:
            try:
                name = dns.name.from_text(record.name, origin)
                rdata = dns.rdata.from_text(
                    dns.rdataclass.IN,
                    dns.rdatatype.from_text(record.type),
                    record.content,
                    origin,
                    relativize=False,
                )
            except UnicodeDecodeError as e:
                logger.warning(
                    f'failed to make rdata from text - unicode error {e}; record id {record.id}')
            except dns.exception.SyntaxError as e:
                logger.warning(
                    f'failed to make rdata from text - syntax error {e}; record id {record.id}')
            except (dns.exception.DNSException, ValueError) as e:
                # e.g. an unknown record type or an out-of-range field value.
                logger.warning(
                    f'failed to make rdata from text - {type(e).__name__} {e}; record id {record.id}')
            else:
                yield name, record.ttl, rdata

    def _make_renderer(self, query: dns.message.Message) -> dns.renderer.Renderer:
        """Create an empty authoritative response renderer for one transfer message.

        The EDNS payload size in a query describes the client's UDP buffer only;
        over TCP every message may use the full size the length prefix allows.
        """
        flags = dns.flags.QR | dns.flags.AA | (query.flags & dns.flags.RD)
        # Mask out the opcode field before copying in the query's opcode.
        flags &= 0x87FF
        flags |= dns.opcode.to_flags(query.opcode())

        return dns.renderer.Renderer(query.id, flags, self._MAX_TCP_MESSAGE_SIZE)

    def _rrsets(
        self,
        origin: dns.name.Name,
        zone_rdatas: Iterable[tuple],
        soa_rdata: dns.rdtypes.ANY.SOA.SOA,
    ) -> Iterable[dns.rrset.RRset]:
        """Yield the transfer's RRsets in order: SOA, NS set, zone records, closing SOA."""
        yield dns.rrset.from_rdata(origin, 900, soa_rdata)
        yield dns.rrset.from_rdata_list(origin, 900, self._ns_rdatas)
        for name, ttl, rdata in zone_rdatas:
            yield dns.rrset.from_rdata(name, ttl, rdata)
        yield dns.rrset.from_rdata(origin, 900, soa_rdata)

    @staticmethod
    async def find_by_name(name: str) -> Optional[Domain]:
        """Load a domain with its records eagerly, or ``None`` if it does not exist.

        A trailing dot is appended when missing and the name is lower-cased
        before comparing with ``Domain.name``.
        """
        if not name.endswith('.'):
            name += '.'

        # NOTE(review): the session is never closed here; if
        # get_session_maker_replica() returns a fresh AsyncSession per call it
        # should be closed (e.g. `async with session:`) — confirm.
        session = await get_session_maker_replica()
        async with session.begin():
            db_query = (
                select(Domain)
                .filter(Domain.name == name.lower())
                .options(selectinload(Domain.records))
            )
            result: Optional[Domain] = (await session.execute(db_query)).scalar()

        return result
