from typing import Optional

from asyncpg.connection import Connection
from asyncpg.pool import Pool, create_pool

# Public names exported by ``from <module> import *``.
__all__ = [
    "DB",
    "UnexpectedTransactionMode",
]


class UnexpectedTransactionMode(Exception):
    """Raised when a connection reports a transaction mode other than expected.

    ``DB.check_pools`` raises it when a connection from the read-write pool
    reports ``transaction_read_only = on``.
    """


class DbAcquire:
    """Handle returned by ``DB.acquire``; usable with ``await`` or ``async with``.

    * ``con = await db.acquire(target)`` acquires a connection via
      ``DB._acquire``; the caller is responsible for releasing it.
    * ``async with db.acquire(target) as con:`` acquires on entry and hands the
      connection back through ``DB.release`` on exit.
    """

    __slots__ = "db", "target", "con"

    db: "DB"
    target: str
    con: Optional[Connection]  # None until a connection has been acquired

    def __init__(self, db: "DB", target: str):
        self.db = db
        self.target = target
        # Initialize the slot so reading ``con`` before acquisition yields None
        # (matching its Optional annotation) instead of raising AttributeError.
        self.con = None

    async def __call__(self) -> Connection:
        """Acquire a connection for ``self.target`` and remember it."""
        self.con = con = await self.db._acquire(self.target)
        return con

    def __await__(self):
        # Makes ``await db.acquire(...)`` behave like ``await db.acquire(...)()``.
        return self().__await__()

    async def __aenter__(self) -> Connection:
        return await self()

    async def __aexit__(self, *exc):
        # Detach first so a repeated exit cannot release the same connection twice.
        con, self.con = self.con, None
        if con is not None:
            await self.db.release(con, self.target)


class DB:
    """Owns a read-write and a read-only asyncpg pool and hands out connections.

    When ``use_single_connection`` is true, the first connection acquired (from
    whichever pool is targeted first) is cached and returned for every later
    acquire regardless of target, and ``release`` leaves it checked out unless
    called with ``force=True``.
    """

    __slots__ = "_rw", "_ro", "_con", "_con_pool", "_use_single_connection"

    _rw: Pool
    _ro: Pool
    _con: Optional[Connection]  # cached connection in single-connection mode
    _con_pool: Optional[Pool]  # pool that ``_con`` was acquired from
    _use_single_connection: bool

    def __init__(self, rw_pool: Pool, ro_pool: Pool, use_single_connection: bool):
        self._rw = rw_pool
        self._ro = ro_pool

        self._use_single_connection = use_single_connection
        self._con = None
        self._con_pool = None

    @classmethod
    async def create(
        cls,
        rw_uri: str,
        ro_uri: Optional[str] = None,
        use_single_connection: bool = False,
    ) -> "DB":
        """Create the pool(s) and return a new ``DB``.

        If ``ro_uri`` is falsy, the read-write pool is also used as the
        read-only pool. If creating the read-only pool fails, the read-write
        pool created just before is closed and the error is re-raised.
        """
        rw = await create_pool(rw_uri, init=cls._set_codecs)
        if not ro_uri:
            return cls(rw, rw, use_single_connection=use_single_connection)

        try:
            ro = await create_pool(ro_uri, init=cls._set_codecs)
        except BaseException:
            # Don't leak the read-write pool if the second pool cannot be created.
            await rw.close()
            raise
        return cls(rw, ro, use_single_connection=use_single_connection)

    async def close(self):
        """Release the cached single connection (if any) and close the pools.

        asyncpg's ``Pool.close()`` waits for checked-out connections to be
        released, so the connection kept in single-connection mode is handed
        back first; otherwise close would wait on it indefinitely. When both
        targets share one pool it is closed only once.
        """
        if self._con is not None:
            await self.release(self._con, force=True)
        await self._rw.close()
        if self._ro is not self._rw:
            await self._ro.close()

    def acquire(self, target: str = "rw") -> DbAcquire:
        """Return an awaitable / async-context-manager handle for ``target``."""
        return DbAcquire(self, target)

    async def release(self, con: Connection, target: str = "rw", force: bool = False):
        """Return ``con`` to a pool.

        In single-connection mode this does nothing unless ``force`` is true,
        so the cached connection stays checked out between acquires. The
        cached connection is always returned to the pool it came from
        (ignoring ``target``), because asyncpg refuses to release a connection
        into a pool it was not acquired from.
        """
        if self._use_single_connection and not force:
            return

        is_cached = con is self._con
        if is_cached and self._con_pool is not None:
            pool = self._con_pool
        else:
            pool = self._pool(target)

        await pool.release(con)
        if is_cached:
            self._con = None
            self._con_pool = None

    async def check_pools(self):
        """Sanity-check both pools.

        Raises ``UnexpectedTransactionMode`` if a connection from the
        read-write target reports ``transaction_read_only = on``, then runs
        ``SELECT 1`` on the read-only target. In single-connection mode both
        checks run on the same cached connection.
        """
        async with self.acquire() as con:
            rw_read_only = await con.fetchval(
                query="SHOW transaction_read_only", column="transaction_read_only"
            )
            if rw_read_only == "on":
                raise UnexpectedTransactionMode("Unexpected read-only mode")

        async with self.acquire("ro") as con:
            await con.execute("SELECT 1")

    async def _acquire(self, target: str):
        """Acquire a connection for ``target`` (or reuse the cached one)."""
        pool = self._pool(target)

        if not self._use_single_connection:
            return await pool.acquire()

        # NOTE: no lock here; concurrent first acquires could each take a
        # connection. Only the last one is cached.
        if self._con is None:
            self._con = await pool.acquire()
            self._con_pool = pool
        return self._con

    def _pool(self, target: str) -> Pool:
        """Map a target name to a pool: ``"rw"`` -> read-write, anything else -> read-only."""
        return self._rw if target == "rw" else self._ro

    @staticmethod
    async def _set_codecs(con: Connection) -> None:
        """Per-connection ``init`` hook passed to ``create_pool``.

        Set your codecs here. Or not.
        """
