import asyncio
import re
from asyncio import wait_for, StreamReader
from typing import List, AsyncIterator, Dict, Optional

from pydantic import BaseModel, Field

from settings import config
from utils.common import ssl_ctx

# Field separator of a status line: a colon followed by one whitespace
# character. Written as a raw string so that "\s" reaches the regex engine
# as-is instead of being an invalid string escape sequence.
sep = re.compile(r":\s")


class Zone(BaseModel):
    """Status record for one zone, as yielded by ``NSDControl.zonestatus``.

    Instances are created with ``Zone.parse_obj`` from a dictionary of parsed
    status fields. Fields that declare an ``alias`` are populated from that
    key of the dictionary rather than from the attribute name.
    """

    # Zone name; read from the "zone" key.
    name: str = Field(alias="zone")
    pattern: str
    state: str
    # NOTE(review): the zonestatus parser stores the three serials as ``int``
    # (or ``None``) while they are declared ``Optional[str]`` here --
    # presumably pydantic coerces them to strings; confirm for the pydantic
    # version in use.
    served_serial: Optional[str] = Field(alias="served-serial")
    commit_serial: Optional[str] = Field(alias="commit-serial")
    notified_serial: Optional[str] = Field(alias="notified-serial")
    # The parser maps the literal value "none" to ``None`` for both fields.
    wait: Optional[str]
    transfer: Optional[str]

class NSDConnection:
    """Async context manager owning one stream connection to the NSD server.

    The target host and port come from ``config.nsd`` and ``ssl_ctx`` is
    handed to ``asyncio.open_connection`` as its ``ssl`` argument.

    ``timeout`` (seconds, ``None`` = wait forever) limits only how long
    establishing the connection may take; it is not applied to later reads
    or writes on the streams.
    """

    def __init__(self, timeout: int = None):
        self.timeout = timeout
        self.host = config.nsd.host
        self.port = config.nsd.port
        self.ssl_context = ssl_ctx
        # Both streams are set by __aenter__ and stay None until then.
        self.reader = None
        self.writer = None

    async def __aenter__(self):
        """Open the connection and return ``self`` with reader/writer set.

        Raises whatever ``asyncio.wait_for`` raises when ``timeout`` expires
        before the connection is established.
        """
        streams = await wait_for(
            asyncio.open_connection(self.host, self.port, ssl=self.ssl_context),
            timeout=self.timeout,
        )
        self.reader, self.writer = streams
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the writer and wait until the close has completed.

        Returns ``None``, so an exception raised inside the ``async with``
        body is not suppressed.
        """
        stream_writer = self.writer
        stream_writer.close()
        await stream_writer.wait_closed()


class NSDControl:
    """Client for the NSD remote-control channel.

    Every public method opens its own ``NSDConnection``, sends a single
    command line that starts with ``PROTOCOL_VERSION`` and reads the reply.
    """

    # Version token that prefixes every command line.
    PROTOCOL_VERSION = "NSDCT1"

    # Bulk commands stop and read acknowledgements after every this many
    # zones (presumably so that the reply does not pile up unread while a
    # very large list is being sent -- confirm against server behaviour).
    _SYNC_INTERVAL = 10000

    # A status line is "key<colon><whitespace>value".
    _FIELD_SEP = re.compile(r":\s")

    # Keys whose value is either ``none`` or ``"<serial> since <timestamp>"``.
    _SERIAL_KEYS = ("served-serial", "commit-serial", "notified-serial")
    # Keys whose value is either ``none`` or free text.
    _NULLABLE_KEYS = ("transfer", "wait")

    async def addzones(self, zones: List[str], pattern: str = "default") -> bytes:
        """Send ``addzones`` with one ``"<zone> <pattern>"`` line per zone.

        Returns the part of the reply that is still unread once the whole
        list has been sent; acknowledgement lines consumed at the periodic
        sync points are not included.

        Raises:
            ConnectionError: the server closed the connection before
                acknowledging a zone at a sync point.
        """
        return await self._bulk_zone_command(
            "addzones", zones, "added", f" {pattern}"
        )

    async def delzones(self, zones: List[str]) -> bytes:
        """Send ``delzones`` with one zone name per line.

        Returns the part of the reply that is still unread once the whole
        list has been sent; acknowledgement lines consumed at the periodic
        sync points are not included.

        Raises:
            ConnectionError: the server closed the connection before
                acknowledging a zone at a sync point.
        """
        return await self._bulk_zone_command("delzones", zones, "removed")

    async def stats(self, timeout: Optional[int] = None) -> bytes:
        """Run the ``stats`` command and return the raw reply."""
        return await self._command("stats", timeout=timeout)

    async def zonestatus(self) -> AsyncIterator["Zone"]:
        """Run ``zonestatus`` and yield one ``Zone`` per reported zone.

        Records are parsed and yielded while the reply is being received.

        Raises:
            ValueError: a non-blank reply line is not of the form
                ``key: value`` or a serial field is malformed.
        """
        async with NSDConnection() as connection:
            connection.writer.write(f"{self.PROTOCOL_VERSION} zonestatus\n".encode())
            await connection.writer.drain()

            async for record in self._iter_zone_records(connection.reader):
                yield Zone.parse_obj(record)

    async def _command(self, cmd: str, timeout: Optional[int] = None) -> bytes:
        """Send a one-line command and return everything read until EOF.

        ``timeout`` is forwarded to ``NSDConnection`` and therefore limits
        only the time spent establishing the connection.
        """
        async with NSDConnection(timeout) as connection:
            connection.writer.write(f"{self.PROTOCOL_VERSION} {cmd}\n".encode())
            # Flush explicitly, as the other commands do, before waiting
            # for the reply.
            await connection.writer.drain()
            return await connection.reader.read()

    async def _bulk_zone_command(
        self, cmd: str, zones: List[str], ack_word: str, line_suffix: str = ""
    ) -> bytes:
        """Stream ``zones`` to the server for ``cmd`` and return the reply tail.

        Each zone is sent as ``"<zone><line_suffix>"``. After every
        ``_SYNC_INTERVAL`` zones the output is flushed and reply lines are
        read (and discarded) until one contains ``"<ack_word>: <zone>"`` for
        the zone sent last. The list is terminated with an EOT line.
        """
        async with NSDConnection() as connection:
            reader, writer = connection.reader, connection.writer
            writer.write(f"{self.PROTOCOL_VERSION} {cmd}\n".encode())
            await writer.drain()

            for index, zone in enumerate(zones):
                writer.write(f"{zone}{line_suffix}\n".encode())
                if index > 0 and index % self._SYNC_INTERVAL == 0:
                    await writer.drain()
                    marker = f"{ack_word}: {zone}".encode()
                    if not await self._wait_for_ack(reader, marker):
                        # Without this check the wait would never finish:
                        # readline() keeps returning b"" once EOF is reached.
                        raise ConnectionError(
                            f"NSD closed the connection before acknowledging {zone!r}"
                        )

            # 0x04 (EOT) on its own line marks the end of the zone list.
            writer.write(b"\x04\n")
            await writer.drain()
            return await reader.read()

    @staticmethod
    async def _wait_for_ack(reader: StreamReader, marker: bytes) -> bool:
        """Consume reply lines until one contains ``marker``.

        Returns ``True`` when such a line was read and ``False`` when the
        stream ended first.
        """
        while True:
            line = await reader.readline()
            if not line:
                return False
            if marker in line:
                return True

    async def _iter_zone_records(self, reader: StreamReader) -> AsyncIterator[Dict]:
        """Group ``key: value`` lines from ``reader`` into per-zone dicts.

        A line whose key is ``zone`` starts a new record, so records may have
        any number of fields; the last record is emitted at EOF. Blank lines
        are skipped.
        """
        record: Dict = {}
        while True:
            raw = await reader.readline()
            if not raw:  # EOF
                break
            text = raw.decode().strip()
            if not text:
                continue

            # maxsplit=1 keeps a value intact even if it contains the
            # separator itself.
            parts = self._FIELD_SEP.split(text, maxsplit=1)
            if len(parts) != 2:
                raise ValueError(f"unexpected zonestatus line: {text!r}")
            key, value = parts

            if key == "zone" and record:
                yield record
                record = {}
            self._store_field(record, key, value)

        if record:
            yield record

    @classmethod
    def _store_field(cls, record: Dict, key: str, value: str) -> None:
        """Normalise one field and store it in ``record``.

        Surrounding double quotes are removed. Serial fields become an
        ``int`` under ``key`` plus the timestamp text under ``key + "_at"``
        (both ``None`` for ``none``); ``wait``/``transfer`` map ``none`` to
        ``None``; everything else is stored as text.
        """
        value = value.strip("\"")
        if key in cls._SERIAL_KEYS:
            if value == "none":
                record[key] = None
                record[key + "_at"] = None
            else:
                serial, since = value.split(" since ", 1)
                record[key] = int(serial)
                record[key + "_at"] = since
        elif key in cls._NULLABLE_KEYS:
            record[key] = None if value == "none" else value
        else:
            record[key] = value
