import asyncio
from dataclasses import dataclass
from typing import Any, Callable, Awaitable, Optional

import aiopg
from aiopg import Pool, Cursor

from timeit import default_timer as timer


@dataclass
class Batch:
    """A half-open range ``(lower_bound, upper_bound]`` together with a size.

    The string form is ``"(lower, upper]:size"``.
    """

    # Lower edge of the range; the bound itself is excluded.
    lower_bound: int
    # Upper edge of the range; the bound itself is included.
    upper_bound: int
    # Size associated with the range (presumably the number of rows it
    # covers -- not derivable from the visible code, confirm with callers).
    size: int

    def __str__(self):
        """Render the batch in interval notation followed by its size."""
        return "({}, {}]:{}".format(self.lower_bound, self.upper_bound, self.size)


@dataclass
class BatchStat:
    """Pair of success/failure counters.

    NOTE(review): not referenced in the visible part of the file; the field
    meanings below are inferred from their names -- confirm with callers.
    """

    # Presumably the number of items processed successfully.
    success: int
    # Presumably the number of items that failed.
    failed: int


class Dao(object):
    """Async data-access object that removes outdated rows from ``profiles.profiles``.

    Connections come from an aiopg pool created by :meth:`connect`. Every
    delete batch runs inside an explicit ``BEGIN`` / ``COMMIT`` pair; when
    ``dry_run`` is true the batch is rolled back instead of committed.

    Annotations that name aiopg types are written as forward-reference
    strings so they are never evaluated when the class is defined.
    """

    def __init__(self, timeout_sec: int, host: str, port: int, user: str, dbname: str, dry_run: bool, delay: float,
                 password: Optional[str] = None):
        """Store connection settings; no I/O happens until :meth:`connect`.

        :param timeout_sec: timeout passed to every cursor and statement.
        :param host: database host.
        :param port: database port.
        :param user: database user.
        :param dbname: database name.
        :param dry_run: roll back every transaction instead of committing it.
        :param delay: seconds to sleep before each batch (no sleep when <= 0).
        :param password: optional password; left out of the DSN when ``None``.
        """
        self.timeout_sec = timeout_sec
        self.host = host
        self.port = port
        self.user = user
        self.dbname = dbname
        self.password = password
        self.delay = delay
        self.dry_run = dry_run
        self._pool: Optional["Pool"] = None

    @staticmethod
    def _dsn_value(value: Any) -> str:
        """Return ``value`` formatted as a keyword/value connection-string value.

        Empty values and values containing whitespace, single quotes or
        backslashes are wrapped in single quotes with ``\\`` and ``'``
        backslash-escaped; anything else is returned unchanged.
        """
        text = str(value)
        if text and not any(ch.isspace() or ch in "'\\" for ch in text):
            return text
        escaped = text.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def _build_dsn(self) -> str:
        """Build the connection string, omitting the password when it is ``None``."""
        settings = [
            ("host", self.host),
            ("port", self.port),
            ("dbname", self.dbname),
            ("user", self.user),
        ]
        if self.password is not None:
            # Interpolating None would send the literal password "None".
            settings.append(("password", self.password))
        return " ".join(f"{key}={self._dsn_value(value)}" for key, value in settings)

    async def __begin(self, cursor: "Cursor"):
        """Open a transaction on ``cursor``."""
        await cursor.execute("BEGIN", timeout=self.timeout_sec)

    async def __commit(self, cursor: "Cursor"):
        """Commit the transaction open on ``cursor``."""
        await cursor.execute("COMMIT;", timeout=self.timeout_sec)

    async def __rollback(self, cursor: "Cursor"):
        """Roll back the transaction open on ``cursor``."""
        await cursor.execute("ROLLBACK", timeout=self.timeout_sec)

    async def __exec(self, query: Callable[["Cursor"], Awaitable[Any]]) -> Any:
        """Run ``query`` with a cursor from a pooled connection and return its result.

        :raises RuntimeError: if :meth:`connect` has not been called yet.
        """
        if self._pool is None:
            raise RuntimeError("Dao is not connected; call connect() first")
        async with self._pool.acquire() as conn:
            async with conn.cursor(timeout=self.timeout_sec) as cur:
                return await query(cur)

    async def __transactional(self, query: Callable[["Cursor"], Awaitable[Any]]) -> Any:
        """Run ``query`` inside a transaction and return its result.

        The transaction is committed on success (rolled back instead when
        ``dry_run`` is set) and rolled back before re-raising on any
        ``Exception``.
        """
        async def tx_block(cur: "Cursor"):
            await self.__begin(cur)
            try:
                res = await query(cur)
                if self.dry_run:
                    print("Dry run rollback")
                    await self.__rollback(cur)
                else:
                    await self.__commit(cur)
                return res
            except Exception as ex:
                print(f"Rollback on error: {ex}")
                await self.__rollback(cur)
                raise

        return await self.__exec(tx_block)

    async def __delay(self) -> None:
        """Sleep for ``self.delay`` seconds when the delay is positive."""
        if self.delay > 0:
            await asyncio.sleep(self.delay)

    async def connect(self, pool_size: int) -> None:
        """Create the connection pool with at most ``pool_size`` connections."""
        self._pool = await aiopg.create_pool(dsn=self._build_dsn(), maxsize=pool_size, minsize=1)

    async def close(self) -> None:
        """Close the connection pool, if any, and wait until it is shut down."""
        if self._pool is None:
            return
        self._pool.close()
        await self._pool.wait_closed()
        self._pool = None

    async def ping(self):
        """Execute ``SELECT 1`` to check that the database is reachable."""
        await self.__exec(lambda cur: cur.execute("SELECT 1", timeout=self.timeout_sec))

    async def delete_all_outdated_profiles(self, batch_size: int) -> None:
        """Delete outdated profiles batch by batch until a batch comes back short.

        In dry-run mode only a single batch is executed: its rows are rolled
        back, so further batches would select the very same rows forever.

        :param batch_size: maximum number of rows per batch; must be positive.
        :raises ValueError: if ``batch_size`` is not positive (a zero-sized
            batch always "deletes" exactly ``batch_size`` rows and would
            never end the loop).
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        start = timer()
        try:
            batch_counter = 0
            deleted_count = batch_size

            # A full batch means more matching rows may remain.
            while deleted_count == batch_size:
                try:
                    await self.__delay()
                    deleted_count = await self.delete_outdated_profiles(batch_size)
                    batch_counter += 1
                    end = timer()
                    print(f"{batch_counter} batches processed, {deleted_count} profiles deleted, "
                          f"total time = {end - start} seconds")
                except BaseException:
                    # Log-and-reraise only; BaseException keeps cancellation reported too.
                    print(f"Migration failed, elapsed time = {timer() - start} seconds")
                    raise
                if self.dry_run:
                    print("Dry run: stopping after one batch, rolled back rows would be selected again")
                    break
        finally:
            print(f"Migration elapsed time = {timer() - start} seconds")

    async def delete_outdated_profiles(self, batch_size: int) -> Optional[int]:
        """Delete one batch of outdated profiles in a transaction.

        :param batch_size: maximum number of rows to delete.
        :return: the cursor's row count for the DELETE statement.
        """
        start = timer()
        try:
            return await self.__transactional(lambda cur: self.__delete_outdated_profiles_impl(cur, batch_size))
        finally:
            print(f"Batch elapsed time = {timer() - start} seconds")

    async def __delete_outdated_profiles_impl(self, cursor: "Cursor", batch_size: int) -> Optional[int]:
        """Delete up to ``batch_size`` profiles and return the affected row count.

        A profile matches when ``empty_update_time`` is non-null, positive and
        earlier than the epoch timestamp of "now minus two weeks".
        """
        # Plain (non-f) string: the only placeholder is the driver-side %(batch_size)s.
        await cursor.execute("""DELETE FROM profiles.profiles
                                 WHERE profile in (
                                    SELECT profile FROM profiles.profiles
                                    WHERE empty_update_time IS NOT NULL
                                         AND empty_update_time > 0
                                         AND empty_update_time < extract(epoch from now() - interval '2 WEEK')::bigint
                                    LIMIT %(batch_size)s
                                 )
                                 """,
                             timeout=self.timeout_sec,
                             parameters={'batch_size': batch_size})
        return cursor.rowcount
