import asyncio
from dataclasses import dataclass
from timeit import default_timer as timer
from typing import Any, Awaitable, Callable, List, Optional, Tuple

import aiopg
from aiopg import Pool, Cursor


@dataclass
class Batch:
    """A planned range of event ids, written as the interval ``(lower_bound, upper_bound]``."""

    lower_bound: int  # exclusive: only ids strictly greater than this belong to the batch
    upper_bound: int  # inclusive: the largest id belonging to the batch
    size: int  # number of ids expected in the range

    def __str__(self) -> str:
        # Interval notation followed by the size, e.g. "(0, 100]:42".
        return "({}, {}]:{}".format(self.lower_bound, self.upper_bound, self.size)


@dataclass
class BatchStat:
    """Counts of batches that succeeded and failed in one migration round."""

    success: int  # number of batches reported as successful
    failed: int  # number of batches reported as failed


class Dao:
    """Data-access object that migrates ``event`` rows to ``data_version = 1``.

    All statements run on connections from an aiopg pool created by
    :meth:`connect`. Transactions are opened with an explicit ``BEGIN`` and
    committed, or rolled back when ``dry_run`` is set, so a dry run executes the
    real SQL without persisting it.

    Two strategies are offered:

    * :meth:`migrate_all_events` migrates one batch at a time; each batch starts
      after the highest id updated by the previous one.
    * :meth:`batch_migrate_events` plans several id ranges up front and migrates
      them concurrently, one transaction per range.
    """

    def __init__(self, timeout_sec: int, host: str, port: int, user: str, dbname: str, dry_run: bool, delay: float,
                 password: Optional[str] = None):
        """Store settings; no connection is opened until :meth:`connect` is awaited.

        :param timeout_sec: timeout given to every cursor and every statement.
        :param host: PostgreSQL host.
        :param port: PostgreSQL port.
        :param user: database user.
        :param dbname: database name.
        :param dry_run: roll back every transaction instead of committing it.
        :param delay: seconds to sleep between batches/rounds; ``<= 0`` disables the pause.
        :param password: database password; omitted from the connection when ``None``.
        """
        self.timeout_sec = timeout_sec
        self.host = host
        self.port = port
        self.user = user
        self.dbname = dbname
        self.password = password
        self.delay = delay
        self.dry_run = dry_run
        self._pool: Optional[Pool] = None

    async def __begin(self, cursor: Cursor) -> None:
        await cursor.execute("BEGIN", timeout=self.timeout_sec)

    async def __commit(self, cursor: Cursor) -> None:
        await cursor.execute("COMMIT;", timeout=self.timeout_sec)

    async def __rollback(self, cursor: Cursor) -> None:
        await cursor.execute("ROLLBACK", timeout=self.timeout_sec)

    async def __exec(self, query: Callable[[Cursor], Awaitable[Any]]) -> Any:
        """Run ``query`` with a cursor on a pooled connection and return its result."""
        async with self._pool.acquire() as conn:
            async with conn.cursor(timeout=self.timeout_sec) as cur:
                return await query(cur)

    async def __transactional(self, query: Callable[[Cursor], Awaitable[Any]]) -> Any:
        """Run ``query`` inside an explicit ``BEGIN`` / ``COMMIT`` transaction.

        In dry-run mode the transaction is rolled back instead of committed.
        If ``query`` or the commit raises, the transaction is rolled back and the
        original exception is re-raised.
        """
        async def tx_block(cur: Cursor) -> Any:
            await self.__begin(cur)
            try:
                res = await query(cur)
                if self.dry_run:
                    print("Dry run rollback")
                    await self.__rollback(cur)
                else:
                    await self.__commit(cur)
                return res
            except Exception as ex:
                print(f"Rollback on error: {ex}")
                try:
                    await self.__rollback(cur)
                except Exception as rollback_ex:
                    # The rollback can itself fail when the connection is no longer
                    # usable; report it, but keep the original error as the cause.
                    print(f"Rollback failed: {rollback_ex}")
                raise

        return await self.__exec(tx_block)

    async def __delay(self) -> None:
        """Sleep ``delay`` seconds (if positive) to throttle the migration."""
        if self.delay > 0:
            await asyncio.sleep(self.delay)

    async def connect(self, pool_size: int) -> None:
        """Create the connection pool with between 1 and ``pool_size`` connections.

        Connection parameters are passed as keyword arguments (forwarded by aiopg
        to psycopg2) instead of being interpolated into a DSN string. Values
        containing spaces or quotes are therefore not mangled, and an unset
        password is left out rather than sent as the literal text ``None``.
        """
        conn_params = {
            "host": self.host,
            "port": self.port,
            "dbname": self.dbname,
            "user": self.user,
        }
        if self.password is not None:
            conn_params["password"] = self.password
        self._pool = await aiopg.create_pool(maxsize=pool_size, minsize=1, **conn_params)

    async def close(self) -> None:
        """Close the pool, if one was created, and wait until its connections are released."""
        if self._pool is not None:
            self._pool.close()
            await self._pool.wait_closed()
            self._pool = None

    async def ping(self) -> None:
        """Execute ``SELECT 1`` to check that the database is reachable."""
        await self.__exec(lambda cur: cur.execute("SELECT 1", timeout=self.timeout_sec))

    async def migrate_all_events(self, batch_size: int, since_id: int = -1) -> None:
        """Migrate events sequentially, one transaction of at most ``batch_size`` rows at a time.

        Each batch starts after the highest id updated by the previous batch.
        The loop ends when a batch updates nothing. ``delay`` seconds are slept
        before every batch.

        :raises: whatever a batch raised, after logging the id it failed on.
        """
        start = timer()
        try:
            batch_counter = 0
            pivot_id = since_id

            while pivot_id is not None:
                try:
                    await self.__delay()
                    pivot_id = await self.migrate_events(batch_size, pivot_id)
                except BaseException:
                    # Log and propagate anything, including cancellation.
                    print(f"Migration failed on id = {pivot_id}, elapsed time = {timer() - start} seconds")
                    raise
                if pivot_id is not None:
                    batch_counter += 1
                    end = timer()
                    print(f"{batch_counter} batches processed, last id = {pivot_id}, "
                          f"total time = {end - start} seconds")
        finally:
            print(f"Migration elapsed time = {timer() - start} seconds")

    async def migrate_events(self, batch_size: int, since_id: int = -1) -> Optional[int]:
        """Migrate one batch of events with id > ``since_id`` in its own transaction.

        :returns: the highest updated event id, or ``None`` if nothing was updated.
        """
        start = timer()
        try:
            return await self.__transactional(lambda cur: self.__migrate_events_impl(cur, batch_size, since_id))
        finally:
            print(f"Batch elapsed time = {timer() - start} seconds")

    async def __migrate_events_impl(self, cursor: Cursor, batch_size: int, since_id: int = -1) -> Optional[int]:
        """Update up to ``batch_size`` joined rows after ``since_id`` and return the max updated id.

        Only events linked to a layer through a primary ``event_layer`` row are
        selected. ``perm_all`` is forced to ``'none'`` when that layer's
        ``perm_all`` is ``'none'``.
        """
        # NOTE(review): without an ``OF`` clause, FOR NO KEY UPDATE also locks the
        # matching event_layer and layer rows; consider ``FOR NO KEY UPDATE OF e``
        # if locking those tables is not intended.
        await cursor.execute("""WITH events_range as (
                                   SELECT e.id, l.perm_all FROM event e
                                   JOIN event_layer el on e.id = el.event_id
                                   JOIN layer l on el.layer_id = l.id
                                   WHERE e.id > %(since_id)s
                                         AND (e.data_version IS NULL OR e.data_version <> 1)
                                         AND el.is_primary_inst = TRUE
                                   ORDER BY e.id
                                   LIMIT %(batch_size)s
                                   FOR NO KEY UPDATE
                                 ),
                                 updated as (
                                   UPDATE event
                                   SET data_version = 1,
                                   perm_all = CASE WHEN er.perm_all = 'none' THEN 'none'::event_perm_all
                                                                             ELSE event.perm_all
                                              END
                                   FROM events_range as er
                                   WHERE event.id = er.id
                                   RETURNING event.id
                                 )
                                 SELECT max(id) FROM updated""",
                             timeout=self.timeout_sec,
                             parameters={'since_id': since_id, 'batch_size': batch_size})
        # An aggregate without GROUP BY always yields exactly one row; max() is NULL when nothing was updated.
        return (await cursor.fetchone())[0]

    async def __select_events(self, batches_count: int, batch_size: int, since_id: int = -1) -> List[Batch]:
        """Plan up to ``batches_count`` consecutive id ranges of not-yet-migrated events.

        Each range covers the next (at most) ``batch_size`` ids whose
        ``data_version`` is NULL or not 1, starting right after the previous
        range. Planning stops early once no such ids remain.
        """
        async def query(cur: Cursor, bound: int) -> Tuple[Optional[int], int]:
            await cur.execute("""WITH ids as (
                                    SELECT id FROM event
                                    WHERE id > %(bound)s AND (data_version IS NULL OR data_version <> 1)
                                    ORDER BY id
                                    LIMIT %(batch_size)s
                                  )
                                  SELECT max(id), count(id) FROM ids""",
                              timeout=self.timeout_sec,
                              parameters={'bound': bound, 'batch_size': batch_size})
            response = await cur.fetchone()
            return response[0], response[1]

        result: List[Batch] = []
        since = since_id
        for _ in range(batches_count):
            last, size = await self.__exec(lambda cur: query(cur, since))
            if last is None:
                break
            result.append(Batch(lower_bound=since, upper_bound=last, size=size))
            since = last
        return result

    async def __migrate_events_batch_impl(self, batch: Batch, cursor: Cursor) -> None:
        """Migrate the not-yet-migrated events of ``batch`` using ``cursor``.

        Selection is confined to ``(batch.lower_bound, batch.upper_bound]``, so
        concurrently running batches never reach into each other's ranges, even
        when some ids of this range were migrated after planning. Errors propagate
        so the surrounding transaction is rolled back.
        """
        await cursor.execute("""WITH ids as (
                                   SELECT id FROM event
                                   WHERE id > %(lower_bound)s AND id <= %(upper_bound)s
                                         AND (data_version IS NULL OR data_version <> 1)
                                   ORDER BY id
                                   LIMIT %(batch_size)s
                                   FOR NO KEY UPDATE
                                 ),
                                 events_range as (
                                   SELECT ids.id, l.perm_all FROM ids
                                   JOIN event_layer el on ids.id = el.event_id
                                   JOIN layer l on el.layer_id = l.id
                                   WHERE el.is_primary_inst = TRUE
                                 )
                                 UPDATE event
                                 SET data_version = 1,
                                     perm_all = CASE WHEN er.perm_all = 'none' THEN 'none'::event_perm_all
                                                                               ELSE event.perm_all
                                                END
                                 FROM events_range as er
                                 WHERE event.id = er.id""",
                             timeout=self.timeout_sec,
                             parameters={'lower_bound': batch.lower_bound,
                                         'upper_bound': batch.upper_bound,
                                         'batch_size': batch.size})

    async def __migrate_events_batch(self, batch: Batch) -> bool:
        """Migrate ``batch`` in its own transaction and report whether it succeeded.

        Failures are logged and returned as ``False`` rather than raised, so one
        bad batch does not abort the batches running alongside it. Success is
        only reported after the transaction has been committed (or rolled back
        in dry-run mode).
        """
        start = timer()
        try:
            await self.__transactional(lambda cur: self.__migrate_events_batch_impl(batch, cur))
        except Exception as ex:
            print(f"Batch {batch} failed: {ex}")
            return False
        else:
            print(f"Batch {batch} successfully complete")
            return True
        finally:
            print(f"Batch elapsed time {timer() - start} seconds")

    async def __migrate_events_batches(self, batches: List[Batch]) -> BatchStat:
        """Migrate all ``batches`` concurrently and count successes and failures.

        Any batch whose result is not ``True``, including one that raised, is
        counted as failed. An empty list yields ``BatchStat(0, 0)``.
        """
        # gather() wraps the coroutines in tasks itself; asyncio.wait() no longer
        # accepts bare coroutines on Python 3.11+.
        results = await asyncio.gather(*(self.__migrate_events_batch(batch) for batch in batches),
                                       return_exceptions=True)
        success = sum(1 for outcome in results if outcome is True)
        return BatchStat(success=success, failed=len(results) - success)

    async def batch_migrate_events(self, batch_size: int, simultaneous_batches_count: int, since_id: int = -1) -> None:
        """Migrate events in rounds of up to ``simultaneous_batches_count`` concurrent batches.

        Each round plans its ranges after the upper bound of the previous
        round's last range and stops when no unmigrated ids remain. Failed ranges
        are counted but not retried within this call, because the starting id
        always advances past them. ``delay`` seconds are slept between rounds.
        """
        start = timer()
        since = since_id

        try:
            batches_counter = 0
            batches_success = 0
            batches_failed = 0

            while True:
                round_start = timer()
                batches = await self.__select_events(simultaneous_batches_count, batch_size, since)
                if not batches:
                    print("There is no more useful batches. Migration complete.")
                    break

                stat = await self.__migrate_events_batches(batches)
                batches_success += stat.success
                batches_failed += stat.failed
                batches_counter += len(batches)
                since = batches[-1].upper_bound

                end = timer()
                print(f"{batches_counter} batches processed({batches_success} success, {batches_failed} failed), "
                      f"round time: {end - round_start}, total time: {end - start} seconds")

                await self.__delay()
        finally:
            print(f"Batch events migration elapsed time = {timer() - start} seconds")

