from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Awaitable, Callable, List, Optional

import aiopg
from aiopg import Pool, Cursor


@dataclass
class DbSettings:
    """Connection parameters for the PostgreSQL database.

    ``password`` is excluded from ``repr()`` so that printing or logging a
    settings object does not leak the credential.
    """
    host: str
    port: int
    dbname: str
    user: str
    # repr=False keeps the secret out of repr()/str(); the field is still a
    # required positional argument in the same position and still takes part
    # in equality comparisons.
    password: str = field(repr=False)
    # Timeout in seconds applied to database statements.
    timeout_sec: int


@dataclass
class UserSettings:
    """One row of the ``settings`` table: a user's id, login and Yandex e-mail."""
    # User identifier (``settings.uid``).
    uid: int
    # Login name (``settings.user_login``).
    login: str
    # E-mail address (``settings.yandex_email``).
    yandex_email: str


@dataclass
class LayerInfo:
    """One row of the ``layer`` table: a layer's id and name."""
    # Layer identifier (``layer.id``).
    id: int
    # Layer name (``layer.name``).
    name: str


class LayerAction(Enum):
    """Layer-related actions that the LayerUserPerm permission levels are built from.

    Values are explicit integers 0..9.
    NOTE(review): if these integers are persisted or exchanged anywhere, keep
    them stable — confirm before renumbering.
    """
    # Actions on the layer listing.
    LIST = 0
    ANONYMOUS_LIST = 1
    # Actions on events within a layer.
    CREATE_EVENT = 2
    VIEW_EVENT = 3
    EDIT_EVENT = 4
    DETACH_EVENT = 5
    DELETE_EVENT = 6
    # Actions on the layer itself.
    EDIT = 7
    DELETE = 8
    GRANT = 9


class LayerUserPerm(Enum):
    """Permission level a user holds on a layer.

    Each member's value is the set of LayerAction members associated with that
    level; ADMIN covers every LayerAction. Members are looked up by name after
    upper-casing the ``layer_user.perm`` column value, e.g. ``LayerUserPerm['EDIT']``.

    Values are frozensets rather than plain sets: these permission tables are
    shared module-level constants, so they must not be mutable at runtime, and
    frozensets are also hashable (value lookups such as
    ``LayerUserPerm(frozenset({LayerAction.LIST}))`` go through the enum's hash map).
    """
    ADMIN = frozenset(LayerAction)
    EDIT = frozenset({
        LayerAction.LIST, LayerAction.CREATE_EVENT, LayerAction.VIEW_EVENT, LayerAction.EDIT_EVENT,
        LayerAction.DETACH_EVENT, LayerAction.DELETE_EVENT, LayerAction.EDIT,
    })
    VIEW = frozenset({
        LayerAction.LIST, LayerAction.VIEW_EVENT,
    })
    CREATE = frozenset({
        LayerAction.LIST, LayerAction.VIEW_EVENT, LayerAction.CREATE_EVENT,
    })
    LIST = frozenset({
        LayerAction.LIST,
    })
    ACCESS = frozenset({
        LayerAction.ANONYMOUS_LIST,
    })


@dataclass
class LayerPermissions:
    """A single user's permission on a single layer (one ``layer_user`` row)."""
    # Layer identifier (``layer_user.layer_id``).
    id: int
    # User identifier (``layer_user.uid``).
    uid: int
    # Permission level the user holds on that layer.
    perm: LayerUserPerm


class DbClient:
    """Asynchronous PostgreSQL client backed by an ``aiopg`` connection pool.

    Call :meth:`connect` before issuing queries and :meth:`close` when done.
    Each query borrows a pooled connection and a cursor for the duration of a
    single call. ``dry_run`` only affects transactional blocks: they are rolled
    back instead of committed.
    """

    def __init__(self, settings: DbSettings, dry_run: bool):
        """Copy connection settings; no connection is opened here.

        :param settings: database location, credentials and statement timeout.
        :param dry_run: roll back (instead of commit) transactional blocks.
        """
        self.timeout_sec = settings.timeout_sec
        self.host = settings.host
        self.port = settings.port
        self.user = settings.user
        self.dbname = settings.dbname
        self.password = settings.password
        self.dry_run = dry_run
        # Set by connect(), cleared by close().
        self._pool: Optional[Pool] = None

    async def __begin(self, cursor: Cursor):
        await cursor.execute("BEGIN", timeout=self.timeout_sec)

    async def __commit(self, cursor: Cursor):
        await cursor.execute("COMMIT", timeout=self.timeout_sec)

    async def __rollback(self, cursor: Cursor):
        await cursor.execute("ROLLBACK", timeout=self.timeout_sec)

    async def __exec(self, query: Callable[[Cursor], Awaitable[Any]]) -> Any:
        """Run ``query`` with a cursor on a pooled connection and return its result.

        The cursor is closed and the connection returned to the pool when
        ``query`` finishes or raises.

        :raises RuntimeError: if the client is not connected.
        """
        if self._pool is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise RuntimeError("DbClient is not connected; call connect() first")
        async with self._pool.acquire() as conn:
            async with conn.cursor(timeout=self.timeout_sec) as cur:
                return await query(cur)

    async def __transactional(self, query: Callable[[Cursor], Awaitable[Any]]) -> Any:
        """Run ``query`` inside an explicit transaction and return its result.

        Commits on success, or rolls back instead when ``dry_run`` is set. If
        ``query`` (or the commit) raises, the transaction is rolled back and the
        exception is re-raised.
        """
        async def tx_block(cur: Cursor):
            # psycopg2 connections in asynchronous mode (as used by aiopg) are
            # always in autocommit mode, so the transaction is delimited with
            # explicit BEGIN / COMMIT / ROLLBACK statements.
            await self.__begin(cur)
            try:
                res = await query(cur)
                if self.dry_run:
                    print("Dry run rollback")
                    await self.__rollback(cur)
                else:
                    await self.__commit(cur)
                return res
            except Exception as ex:
                print(f"Rollback on error: {ex}")
                await self.__rollback(cur)
                raise

        return await self.__exec(tx_block)

    async def connect(self, pool_size: int) -> None:
        """Create the connection pool, keeping between 1 and ``pool_size`` connections."""
        # Pass the connection parameters as keywords instead of interpolating them
        # into a DSN string. They are forwarded to psycopg2, which quotes each
        # value, so a password (or any other value) containing spaces, quotes,
        # backslashes or '=' no longer corrupts the connection string.
        self._pool = await aiopg.create_pool(
            host=self.host,
            port=self.port,
            dbname=self.dbname,
            user=self.user,
            password=self.password,
            maxsize=pool_size,
            minsize=1,
        )

    async def close(self) -> None:
        """Close the connection pool, if any, and wait for it to shut down.

        Safe to call more than once and before :meth:`connect`; after it
        returns the client is disconnected until :meth:`connect` is called again.
        """
        if self._pool is None:
            return
        pool, self._pool = self._pool, None
        pool.close()
        await pool.wait_closed()

    async def ping(self):
        """Execute ``SELECT 1`` to check that the database is reachable."""
        await self.__exec(lambda cur: cur.execute("SELECT 1", timeout=self.timeout_sec))

    async def select_user_settings(self, batch_size: int, since_id: int = -1) -> List[UserSettings]:
        """Return up to ``batch_size`` rows of ``settings`` with ``uid > since_id``, ordered by uid.

        Keyset pagination: pass the last returned ``uid`` as ``since_id`` to
        fetch the next page. An empty list means there are no more rows.
        """
        async def query(cur: Cursor) -> List[UserSettings]:
            await cur.execute("""SELECT uid, user_login, yandex_email
                                 FROM settings
                                 WHERE uid > %(since)s
                                 ORDER BY uid
                                 LIMIT %(limit)s""",
                              timeout=self.timeout_sec,
                              parameters={'since': since_id, 'limit': batch_size})
            rows = await cur.fetchall()
            return [UserSettings(uid=uid, login=login, yandex_email=yandex_email)
                    for (uid, login, yandex_email) in rows]

        return await self.__exec(query)

    async def select_layers(self, batch_size: int, since_id: int = -1) -> List[LayerInfo]:
        """Return up to ``batch_size`` rows of ``layer`` with ``id > since_id``, ordered by id.

        Keyset pagination: pass the last returned ``id`` as ``since_id`` to
        fetch the next page. An empty list means there are no more rows.
        """
        async def query(cur: Cursor) -> List[LayerInfo]:
            await cur.execute("""SELECT id, name
                                 FROM layer
                                 WHERE id > %(since)s
                                 ORDER BY id
                                 LIMIT %(limit)s""",
                              timeout=self.timeout_sec,
                              parameters={'since': since_id, 'limit': batch_size})
            rows = await cur.fetchall()
            return [LayerInfo(id=layer_id, name=name) for (layer_id, name) in rows]

        return await self.__exec(query)

    async def select_layers_permissions(self, batch_size: int, since_uid: int = -1,
                                        since_id: int = -1) -> List[LayerPermissions]:
        """Return up to ``batch_size`` ``layer_user`` permission rows after a composite cursor.

        Rows are ordered by ``(uid, layer_id)`` and start strictly after
        ``(since_uid, since_id)``, where ``since_id`` is a layer id. Only users
        whose ``settings.user_login`` does not appear as ``resource.exchange_name``
        are included. Pass the last row's ``(uid, id)`` to fetch the next page.

        :raises KeyError: if a ``perm`` value does not name a LayerUserPerm member
            (case-insensitively).
        """
        async def query(cur: Cursor) -> List[LayerPermissions]:
            # The row comparison in WHERE matches the ORDER BY columns. Ordering by
            # plain columns (not the row constructor "(lu.uid, lu.layer_id)") gives
            # the same order but lets the planner use an index on (uid, layer_id).
            await cur.execute("""SELECT lu.uid, lu.layer_id, lu.perm
                                 FROM layer_user lu
                                 JOIN settings s on lu.uid = s.uid
                                 WHERE (lu.uid, lu.layer_id) > (%(since_uid)s, %(since_id)s)
                                    AND NOT EXISTS(SELECT 1 FROM resource WHERE exchange_name = s.user_login)
                                 ORDER BY lu.uid, lu.layer_id
                                 LIMIT %(limit)s""",
                              timeout=self.timeout_sec,
                              parameters={'since_uid': since_uid, 'since_id': since_id, 'limit': batch_size})
            rows = await cur.fetchall()
            return [LayerPermissions(id=layer_id, uid=uid, perm=LayerUserPerm[perm.upper()])
                    for (uid, layer_id, perm) in rows]

        return await self.__exec(query)
