from typing import Optional

from asyncpg import Connection
from asyncpg.exceptions import LockNotAvailableError
from asyncpg.transaction import Transaction

from maps_adv.stat_controller.server.lib.db import DB, DbTaskStatus

# Names exported by ``from ... import *``: the lock/connection helpers, the
# conflict exception, and the task query helpers defined below.
__all__ = [
    "BlockingManager",
    "ConflictOperation",
    "find_last_task",
    "lock",
    "retrieve_task_details",
    "select_con",
]


class ConflictOperation(Exception):
    """Raised when a named lock cannot be taken because another transaction holds it.

    The lock name is passed as the single exception argument.
    """


async def retrieve_task_details(task_id: int, con: Connection) -> dict:
    """Fetch a task together with the state of its current log entry.

    Joins ``tasks`` with ``tasks_log`` on ``tasks.current_log_id`` and returns
    the task's id, timing window and current log id, plus the log entry's
    executor id, status and execution state, as a plain dict.

    NOTE(review): there is no ``None`` check. If the task does not exist (or has
    no matching ``tasks_log`` row, since this is an inner join), ``fetchrow``
    yields ``None`` and ``dict(None)`` raises ``TypeError``. Confirm callers
    only pass ids of existing tasks.
    """
    query = (
        "SELECT tasks.id, tasks.timing_from, tasks.timing_to, tasks.current_log_id, "
        "   tasks_log.executor_id, tasks_log.status, tasks_log.execution_state "
        "FROM tasks JOIN tasks_log ON tasks.current_log_id = tasks_log.id "
        "WHERE tasks.id = $1"
    )
    return dict(await con.fetchrow(query, task_id))


class lock:
    """Async context manager holding a row lock on ``locks`` for one transaction.

    On enter it acquires a connection from ``db``, starts a transaction on it
    and runs ``SELECT ... FOR UPDATE NOWAIT`` on the ``locks`` row named
    ``name``. The connection is returned so the guarded work runs inside the
    same transaction. On exit the transaction is committed, or rolled back if
    the body raised, and the connection is always released back to ``db``.

    Raises:
        ConflictOperation: on enter, when the row is already locked by another
            transaction (``NOWAIT`` makes PostgreSQL fail with
            ``LockNotAvailableError`` instead of waiting).

    NOTE(review): if no ``locks`` row matches ``name``, the SELECT returns
    nothing, so no lock is taken and no error is raised. Confirm such rows
    are pre-seeded.
    """

    __slots__ = "db", "con", "name", "transaction"

    db: DB
    name: str
    con: Optional[Connection]  # set only while the lock is held
    transaction: Optional[Transaction]  # set only while the lock is held

    def __init__(self, db: DB, name: str):
        self.db = db
        self.name = name

        self.con = None
        self.transaction = None

    async def __aenter__(self) -> Connection:
        sql = "SELECT id FROM locks WHERE name = $1 FOR UPDATE NOWAIT"

        con = await self.db.acquire()
        transaction = con.transaction()
        try:
            await transaction.start()
        except BaseException:
            # The transaction never started, so there is nothing to roll back.
            # Hand the connection back instead of leaking it.
            await self.db.release(con)
            raise

        try:
            await con.execute(sql, self.name)
        except LockNotAvailableError as exc:
            await self._abandon(con, transaction)
            raise ConflictOperation(self.name) from exc
        except BaseException:
            # __aexit__ is not called when __aenter__ raises, so every failure
            # path has to clean up the transaction and connection here.
            await self._abandon(con, transaction)
            raise

        self.con = con
        self.transaction = transaction
        return con

    async def __aexit__(self, exc_type, *args, **kwargs):
        try:
            if exc_type:
                await self.transaction.rollback()
            else:
                await self.transaction.commit()
        finally:
            # Release even if commit/rollback fails so the connection is not leaked.
            con = self.con
            self.con = None
            self.transaction = None
            await self.db.release(con)

    async def _abandon(self, con: Connection, transaction: Transaction) -> None:
        """Roll back a started ``transaction`` and always release ``con``."""
        try:
            await transaction.rollback()
        finally:
            await self.db.release(con)


class select_con:
    """Async context manager yielding a caller-supplied or a freshly acquired connection.

    If ``con`` is given, it is returned as-is and left untouched on exit, so
    the caller keeps ownership of it. Otherwise a connection is acquired from
    ``db`` on enter and released back to ``db`` on exit.
    """

    __slots__ = "db", "con", "acquired"

    db: DB
    con: Optional[Connection]
    acquired: bool  # True only while holding a connection this object acquired itself

    def __init__(self, db: DB, con: Optional[Connection]):
        self.db = db
        self.con = con

        self.acquired = False

    async def __aenter__(self) -> Connection:
        if self.con is None:
            self.con = await self.db.acquire()
            self.acquired = True

        return self.con

    async def __aexit__(self, exc_type, *args, **kwargs):
        if self.acquired:
            # Forget the released connection. Re-entering this object then
            # acquires a fresh one instead of handing out (and re-releasing)
            # a connection that already went back to the pool.
            con = self.con
            self.con = None
            self.acquired = False
            await self.db.release(con)


async def find_last_task(status: DbTaskStatus, con: Connection) -> Optional[dict]:
    """Return the task with ``status`` whose ``timing_to`` is the latest.

    The result is a dict with the ``id``, ``timing_from`` and ``timing_to``
    columns. If no task row matches, the result is ``None``.

    NOTE(review): this filters on ``tasks.status``, while
    ``retrieve_task_details`` reads status from ``tasks_log``. Confirm
    ``tasks`` really carries its own ``status`` column.
    """
    query = (
        "SELECT id, timing_from, timing_to FROM tasks "
        "WHERE status = $1 ORDER BY timing_to DESC LIMIT 1"
    )

    latest = await con.fetchrow(query, status)
    return dict(latest) if latest else None


class BlockingManager:
    """Base helper tying a ``DB`` to a named ``lock`` and to ``select_con``.

    ``name`` is only annotated here and never assigned, and ``__slots__``
    reserves just ``_db``. A bare ``BlockingManager`` therefore has no
    ``name``, and ``lock()`` raises ``AttributeError``. Subclasses must
    provide ``name``.
    """

    __slots__ = ("_db",)

    _db: DB
    name: str  # lock name; must be supplied by subclasses

    def __init__(self, db: DB):
        self._db = db

    def lock(self):
        """Return a ``lock`` context manager for this manager's ``name``."""
        # The ``lock`` called here is the module-level class, not this method:
        # class scope is not visible from inside method bodies.
        return lock(db=self._db, name=self.name)

    def connection(self, con: Optional[Connection] = None):
        """Return a ``select_con`` that reuses ``con`` or acquires a new connection."""
        return select_con(db=self._db, con=con)
