from timeit import default_timer as timer
from typing import Any, Awaitable, Callable, Iterable, Optional

import aiopg
from aiopg import Cursor, Pool


class Dao:
    """Async PostgreSQL data-access object backed by an aiopg connection pool.

    Every cursor and statement uses ``timeout_sec`` as its timeout. When
    ``dry_run`` is true, transactional writes are rolled back instead of
    committed, so their effects are never persisted.
    """

    def __init__(self, timeout_sec: int, host: str, port: int, user: str, dbname: str, dry_run: bool,
                 password: Optional[str] = None):
        """Store connection settings; no connection is opened until :meth:`connect`.

        :param timeout_sec: timeout passed to every cursor and ``execute`` call.
        :param host: database server host.
        :param port: database server port.
        :param user: database role to connect as.
        :param dbname: database name.
        :param dry_run: if true, transactional work is rolled back instead of committed.
        :param password: optional password; when ``None`` it is left out of the DSN
            rather than being sent as the literal text ``None``.
        """
        self.timeout_sec = timeout_sec
        self.host = host
        self.port = port
        self.user = user
        self.dbname = dbname
        self.password = password
        self.dry_run = dry_run
        # Created by connect(); None until then and again after close().
        self._pool: Optional[Pool] = None

    async def __begin(self, cursor: Cursor):
        """Start an explicit transaction (aiopg connections run in autocommit mode)."""
        await cursor.execute("BEGIN", timeout=self.timeout_sec)

    async def __commit(self, cursor: Cursor):
        """Commit the transaction opened by ``__begin``."""
        await cursor.execute("COMMIT", timeout=self.timeout_sec)

    async def __rollback(self, cursor: Cursor):
        """Roll back the transaction opened by ``__begin``."""
        await cursor.execute("ROLLBACK", timeout=self.timeout_sec)

    @staticmethod
    def __quote_dsn_value(value: Any) -> str:
        """Render ``value`` as a single-quoted libpq connection-string value.

        Inside a quoted value libpq requires backslashes and single quotes to be
        escaped with a backslash; backslashes are escaped first so the escapes
        added for quotes are not doubled.
        """
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    async def __exec(self, query: Callable[[Cursor], Awaitable[Any]]) -> Any:
        """Run ``query`` with a cursor on a pooled connection and return its result.

        The cursor and connection are released when ``query`` returns or raises.

        :raises RuntimeError: if :meth:`connect` has not been called (or the pool was closed).
        """
        if self._pool is None:
            raise RuntimeError("Dao is not connected; call connect() first")
        async with self._pool.acquire() as conn:
            async with conn.cursor(timeout=self.timeout_sec) as cur:
                return await query(cur)

    async def __transactional(self, query: Callable[[Cursor], Awaitable[Any]]) -> Any:
        """Run ``query`` inside BEGIN/COMMIT on one connection and return its result.

        In dry-run mode the transaction is rolled back instead of committed.
        If ``query`` or the commit raises, the transaction is rolled back and
        the original exception is re-raised. A failure of that rollback is
        reported but never replaces the original exception.
        """
        async def tx_block(cur: Cursor):
            await self.__begin(cur)
            try:
                res = await query(cur)
                if self.dry_run:
                    print("Dry run rollback")
                    await self.__rollback(cur)
                else:
                    await self.__commit(cur)
                return res
            except Exception as ex:
                print(f"Rollback on error: {ex}")
                try:
                    await self.__rollback(cur)
                except Exception as rollback_ex:
                    # Don't let a broken rollback mask the error the caller needs to see.
                    print(f"Rollback failed: {rollback_ex}")
                raise  # re-raises ``ex``, the exception that triggered the rollback

        return await self.__exec(tx_block)

    async def connect(self, pool_size: int) -> None:
        """Create the connection pool with between 1 and ``pool_size`` connections.

        Each DSN value is quoted for libpq so that spaces, quotes or backslashes
        (e.g. in the password) cannot break the connection string. Settings that
        are ``None`` are omitted instead of being sent as the text ``None``.
        """
        settings = (
            ("host", self.host),
            ("port", self.port),
            ("dbname", self.dbname),
            ("user", self.user),
            ("password", self.password),
        )
        dsn = " ".join(
            f"{key}={self.__quote_dsn_value(value)}" for key, value in settings if value is not None
        )
        self._pool = await aiopg.create_pool(dsn=dsn, maxsize=pool_size, minsize=1)

    async def close(self) -> None:
        """Close the pool created by :meth:`connect`, if any.

        Waits until acquired connections are released and closed.
        """
        if self._pool is None:
            return
        pool, self._pool = self._pool, None
        pool.close()
        await pool.wait_closed()

    async def ping(self):
        """Run ``SELECT 1`` on a pooled connection to check the database is reachable."""
        await self.__exec(lambda cur: cur.execute("SELECT 1", timeout=self.timeout_sec))

    async def close_events(self, event_ids: Iterable[int]) -> None:
        """Set ``perm_all = 'none'`` on every ``event`` row whose id is in ``event_ids``.

        Runs in its own transaction, which is rolled back in dry-run mode. Ids
        are coerced with ``int()``, so non-numeric input raises ``ValueError``
        before any SQL is sent. An empty collection is a no-op.
        """
        ids = [int(event_id) for event_id in event_ids]
        if not ids:
            # Nothing to update; the old "IN ()" form was a SQL syntax error.
            return
        await self.__transactional(lambda cursor: self.__close_events(ids, cursor))

    async def __close_events(self, event_ids: Iterable[int], cursor: Cursor) -> None:
        """Execute the close-events UPDATE on ``cursor`` and print how long it took.

        The ids are sent as a bound parameter: the driver adapts a Python list to
        a PostgreSQL array, which ``= ANY(...)`` matches against, so no id value
        is ever spliced into the SQL text.
        """
        start = timer()
        await cursor.execute(
            "UPDATE event SET perm_all = 'none' WHERE event.id = ANY(%s)",
            (list(event_ids),),
            timeout=self.timeout_sec,
        )
        print(f"Close events executed in {timer() - start} seconds")
