import json
from typing import Optional

from asyncpg.connection import Connection
from asyncpg.pool import Pool, create_pool

from maps_adv.stat_controller.server.lib.db.enums import TaskStatus

# Public API of this module: only DB is star-exported; DbAcquire objects are
# obtained through DB.acquire() rather than constructed directly.
__all__ = ["DB"]


class DbAcquire:
    """Acquire a connection from ``DB`` with this project's type codecs registered.

    Two usage forms are supported:

    * ``con = await db.acquire(target)`` -- the caller is responsible for
      calling ``db.release(con, target)`` afterwards;
    * ``async with db.acquire(target) as con:`` -- ``DB.release`` is called
      on exit from the block.
    """

    __slots__ = "db", "target", "con"

    db: "DB"
    # Pool selector forwarded to DB._acquire / DB.release.
    target: str
    # Connection handed out by the last successful acquire, if still held.
    con: Optional[Connection]

    def __init__(self, db: "DB", target: str):
        self.db = db
        self.target = target
        # Initialise the slot so __aexit__ never reads an unset attribute.
        self.con = None

    async def __call__(self) -> Connection:
        """Acquire a connection and register the ``json`` and ``taskstatus`` codecs.

        Returns:
            The connection obtained from ``DB._acquire(self.target)``.

        Raises:
            Whatever acquiring or codec registration raises. If registration
            fails (or is cancelled), the connection is handed back to
            ``DB.release`` before the exception propagates.
        """
        con = await self.db._acquire(self.target)
        try:
            # (De)serialise PostgreSQL ``json`` values with the stdlib json module.
            await con.set_type_codec(
                "json", encoder=json.dumps, decoder=json.loads, schema="pg_catalog"
            )
            # Map the ``public.taskstatus`` type to TaskStatus members:
            # a member is sent as its ``.value`` and received via TaskStatus(...).
            await con.set_type_codec(
                "taskstatus",
                encoder=lambda el: el.value,
                decoder=TaskStatus,
                schema="public",
            )
        except BaseException:
            # Without this the connection leaks: __aexit__ is not invoked when
            # __aenter__ fails, and an awaiting caller never receives it.
            await self.db.release(con, self.target)
            raise
        self.con = con
        return con

    def __await__(self):
        # Lets ``await db.acquire()`` work without a context manager.
        return self().__await__()

    async def __aenter__(self) -> Connection:
        return await self()

    async def __aexit__(self, *exc):
        # Clear before releasing so the same connection is never released twice.
        con, self.con = self.con, None
        if con is not None:
            await self.db.release(con, self.target)


class DB:
    """Wrapper around a read-write and a read-only asyncpg connection pool.

    When ``use_single_connection`` is true, every acquire returns the same
    connection (taken from whichever pool is asked first), and ``release``
    only gives it back when called with ``force=True`` or on ``close()``.
    """

    __slots__ = "_rw", "_ro", "_con", "_con_pool", "_use_single_connection"

    _rw: Pool
    _ro: Pool
    # Connection pinned in single-connection mode, if currently held.
    _con: Optional[Connection]
    # Pool that ``_con`` was acquired from, so it is returned to the right pool.
    _con_pool: Optional[Pool]
    _use_single_connection: bool

    def __init__(self, rw_pool: Pool, ro_pool: Pool, use_single_connection: bool):
        self._rw = rw_pool
        # Without a dedicated read-only pool, reads go through the rw pool.
        self._ro = ro_pool or rw_pool

        self._use_single_connection = use_single_connection
        self._con = None
        self._con_pool = None

    @classmethod
    async def create(
        cls,
        rw_uri: str,
        ro_uri: Optional[str] = None,
        use_single_connection: bool = False,
    ) -> "DB":
        """Create the pools from DSNs and build a ``DB`` around them.

        Args:
            rw_uri: DSN for the read-write pool.
            ro_uri: DSN for the read-only pool; when falsy the rw pool is reused.
            use_single_connection: see the class docstring.
        """
        rw = await create_pool(rw_uri)
        try:
            ro = await create_pool(ro_uri) if ro_uri else rw
        except BaseException:
            # Don't leak the already-created rw pool if the ro pool fails.
            await rw.close()
            raise
        return cls(rw, ro, use_single_connection=use_single_connection)

    async def close(self):
        """Release the pinned connection (if any) and close both pools."""
        if self._con is not None:
            # asyncpg's Pool.close() waits until every acquired connection has
            # been released, so the pinned connection would block it forever.
            await self.release(self._con, force=True)
        try:
            await self._rw.close()
        finally:
            # ``_ro`` is the very same pool as ``_rw`` when no ro pool was given.
            if self._ro is not self._rw:
                await self._ro.close()

    def _pool(self, target: str) -> Pool:
        # "rw" selects the read-write pool; any other target uses the ro pool.
        return self._rw if target == "rw" else self._ro

    def acquire(self, target: str = "rw") -> DbAcquire:
        """Return an awaitable / async context manager yielding a connection."""
        return DbAcquire(self, target)

    async def _acquire(self, target: str):
        """Acquire a raw connection for ``target`` (no codecs registered)."""
        pool = self._pool(target)

        if not self._use_single_connection:
            return await pool.acquire()

        # Single-connection mode: reuse the pinned connection regardless of
        # target; acquire it from this target's pool on first use.
        if self._con is None:
            self._con = await pool.acquire()
            self._con_pool = pool
        return self._con

    async def release(self, con: Connection, target: str = "rw", force: bool = False):
        """Return ``con`` to its pool.

        In single-connection mode this is a no-op unless ``force`` is true.
        The pinned connection is always returned to the pool it came from,
        even if ``target`` names the other pool.
        """
        if self._use_single_connection and not force:
            return

        is_pinned = self._con is not None and con is self._con
        pool = self._con_pool if is_pinned else self._pool(target)

        await pool.release(con)
        if is_pinned:
            self._con = None
            self._con_pool = None
