from starlette.applications import Starlette
from psycopg2 import errors


class LockNotAvailable(Exception):
    """Raised when a requested lock is held elsewhere and could not be acquired.

    ``PgLockHolder`` raises this when PostgreSQL reports that the lock wait
    timeout expired before the advisory lock could be taken.
    """


class LockHolder:
    """Base async context manager representing a held lock.

    This default implementation performs no locking at all: entering and
    leaving the ``async with`` scope are both no-ops. ``PgLockHolder``
    overrides both hooks to do real work.
    """

    async def __aenter__(self):
        """Enter the lock scope; the base version acquires nothing."""
        return None

    async def __aexit__(self, exc_type, exc, tb):
        """Leave the lock scope; returns None so exceptions are not suppressed."""
        return None


class LockProvider:
    """Default lock provider whose locks are no-ops.

    Every call hands back a brand-new ``LockHolder``, which performs no
    locking, regardless of the arguments supplied.
    """

    def acquire_lock(
        self,
        app: Starlette,
        lock_name: str,
        transaction_timeout: int = 600,
        lock_wait_timeout: int = 1,
    ) -> LockHolder:
        """Return a holder for ``lock_name``; this provider ignores every argument.

        Args:
            app: The Starlette application (unused here).
            lock_name: Name of the lock to hold (unused here).
            transaction_timeout: Upper bound on how long the lock may be held.
                Subclasses such as ``PgAdvisoryLockProvider`` treat it as
                seconds. Unused here.
            lock_wait_timeout: How long to wait for the lock. Subclasses treat
                it as seconds. Unused here.

        Returns:
            A new no-op ``LockHolder`` intended for use with ``async with``.
        """
        noop_holder = LockHolder()
        return noop_holder


class PgAdvisoryLockProvider(LockProvider):
    """Lock provider backed by PostgreSQL advisory locks (see ``PgLockHolder``)."""

    def acquire_lock(
        self,
        app: Starlette,
        lock_name: str,
        transaction_timeout: int = 600,
        lock_wait_timeout: int = 1,
    ) -> LockHolder:
        """Build a ``PgLockHolder`` for ``lock_name``.

        No database work happens here. The holder's constructor only stores
        its arguments. The connection is acquired and the lock taken when the
        returned holder is entered with ``async with``.

        Args:
            app: Starlette application whose ``state.engine`` provides connections.
            lock_name: Name hashed into the advisory lock key.
            transaction_timeout: Seconds the holding transaction may sit idle.
            lock_wait_timeout: Seconds to wait for the lock before giving up.

        Returns:
            A ``PgLockHolder`` configured with the given arguments.
        """
        return PgLockHolder(
            app=app,
            lock_name=lock_name,
            transaction_timeout=transaction_timeout,
            lock_wait_timeout=lock_wait_timeout,
        )


class PgLockHolder(LockHolder):
    """Async context manager that holds a PostgreSQL advisory lock.

    On entry it checks a connection out of ``app.state.engine`` and opens a
    transaction. It then sets transaction-local idle and lock-wait timeouts
    and takes a transaction-scoped advisory lock keyed on a 64-bit hash of
    ``lock_name``. On exit the transaction is closed, which releases the lock
    with it, and the connection is closed.

    Raises:
        LockNotAvailable: From ``__aenter__``, when the lock could not be
            obtained within ``lock_wait_timeout`` seconds.
    """

    # Uses pg_advisory_xact_lock (not pg_advisory_lock): a session-level lock
    # would survive the transaction rollback in close() and stay held by the
    # pooled connection after it is returned. The transaction-scoped variant
    # is released automatically when the transaction ends.
    # All values are passed as query parameters instead of being formatted
    # into the SQL text, so a lock name containing quotes cannot break or
    # inject into the statement.
    _ACQUIRE_SQL = '''SET LOCAL idle_in_transaction_session_timeout TO %(idle_timeout_ms)s;
                SET LOCAL lock_timeout TO %(lock_timeout_ms)s;
                SELECT pg_advisory_xact_lock(('x' || md5(%(lock_name)s))::bit(64)::bigint);'''

    def __init__(
        self,
        app: Starlette,
        lock_name: str,
        transaction_timeout: int = 600,
        lock_wait_timeout: int = 1,
    ) -> None:
        """Store the lock configuration; no database work happens here.

        Args:
            app: Starlette application whose ``state.engine`` provides connections.
            lock_name: Name hashed (md5, truncated to 64 bits) into the lock key.
            transaction_timeout: Seconds the holding transaction may stay idle
                before the server ends the session.
            lock_wait_timeout: Seconds to wait for the lock before giving up.
        """
        super().__init__()
        self._app = app
        self._lock_name = lock_name
        self._transaction_timeout = transaction_timeout
        self._lock_wait_timeout = lock_wait_timeout
        self._connection = None
        self._transaction = None

    async def __aenter__(self):
        """Acquire a connection, open a transaction and take the advisory lock.

        Raises:
            LockNotAvailable: If the lock wait timeout expires first.
        """
        self._connection = await self._app.state.engine.acquire()
        try:
            # begin() is inside the try so a failure here still returns the
            # connection instead of leaking it.
            self._transaction = await self._connection.begin()
            await self._connection.execute(self._ACQUIRE_SQL, {
                # The timeout settings are in milliseconds; the arguments are seconds.
                'idle_timeout_ms': self._transaction_timeout * 1000,
                'lock_timeout_ms': self._lock_wait_timeout * 1000,
                'lock_name': self._lock_name,
            })
        except errors.LockNotAvailable as err:
            await self.close()
            raise LockNotAvailable() from err
        except BaseException:
            # Deliberately broad (includes task cancellation): never leave the
            # connection checked out if acquisition does not complete.
            await self.close()
            raise

    async def __aexit__(self, exc_type, exc, tb):
        """Release the lock and connection; exceptions are not suppressed."""
        await self.close()

    async def close(self):
        """Close the transaction, releasing the lock, then close the connection.

        Safe to call more than once. References are cleared before closing.
        The connection is closed even if closing the transaction fails, e.g.
        when the server already terminated an idle session.
        """
        transaction, self._transaction = self._transaction, None
        connection, self._connection = self._connection, None
        try:
            if transaction is not None:
                await transaction.close()
        finally:
            if connection is not None:
                await connection.close()
