from pydantic import BaseModel

from sqlalchemy import select
from aiopg.sa.connection import SAConnection


class RecordNotFound(Exception):
    """Raised when a database lookup yields no matching row."""


class OperationError(Exception):
    """Raised when a requested operation cannot be carried out."""


class DBGateway:
    """Base gateway that runs SQLAlchemy Core queries on one connection.

    The methods read ``table``, ``pk`` and ``model_class`` through ``self``;
    all three default to ``None`` here, so a concrete gateway is expected to
    override them.
    """

    # Table object the queries are built from (used via ``.c``, ``.insert()``,
    # ``.update()`` and ``.delete()``).
    table = None
    # Name of the primary-key column, used as a key into ``table.c`` and
    # into fetched rows.
    pk = None
    # Default class ``_get_one`` instantiates from a fetched row.
    model_class = None

    def __init__(self, conn: 'SAConnection'):
        """Keep the connection every query of this gateway is executed on."""
        self.conn = conn

    async def _fetchall(self, query):
        """Execute ``query`` and return the result of ``fetchall()``."""
        result = await self.conn.execute(query)
        return await result.fetchall()

    async def _first(self, query):
        """Execute ``query`` and return the result of ``first()``."""
        result = await self.conn.execute(query)
        return await result.first()

    async def check_if_exists(self, ids: list[int]) -> dict[int, bool]:
        """Report which of the given primary keys are present in the table.

        Args:
            ids: Primary-key values to look up.

        Returns:
            A dict mapping every requested id to ``True`` when a row with
            that primary key was fetched and ``False`` otherwise. The dict is
            empty when ``ids`` is empty.
        """
        if not ids:
            # Nothing to look up: skip the round trip and the empty IN ().
            return {}
        pk_column = self.table.c[self.pk]
        query = select([pk_column]).where(pk_column.in_(ids))
        data = await self._fetchall(query)
        db_ids = {record[self.pk] for record in data}
        return {id_: id_ in db_ids for id_ in ids}

    async def _get_one(
            self,
            select_list: list,
            where_clause,
            model_class: 'type[BaseModel] | None' = None,
            select_from=None,
    ):
        """Fetch the first row matching ``where_clause`` as a model instance.

        Args:
            select_list: Columns / expressions passed to ``select``.
            where_clause: Condition applied with ``.where``.
            model_class: Class to build from the row; falls back to
                ``self.model_class`` when omitted.
            select_from: Optional FROM clause, applied with ``.select_from``
                only when it is not ``None``.

        Returns:
            ``model_class(**record)`` for the fetched row.

        Raises:
            RecordNotFound: If the query produced no row.
        """
        model_class = model_class or self.model_class
        query = select(select_list)
        if select_from is not None:
            query = query.select_from(select_from)
        query = query.where(where_clause)
        record = await self._first(query)
        if not record:
            # Name the class actually requested: a per-call override must not
            # be reported under (or crash on) the class-level default.
            msg = f'{model_class.__name__} does not exist'
            raise RecordNotFound(msg)
        return model_class(**record)

    async def create(self, **fields) -> int:
        """Insert a row built from ``fields``.

        Returns:
            The scalar result of ``INSERT ... RETURNING <pk column>``.
        """
        query = self.table.insert().values(**fields).returning(self.table.c[self.pk])
        return await self.conn.scalar(query)

    async def update(self, pk_value: int, **fields) -> int:
        """Set ``fields`` on the row whose primary key equals ``pk_value``.

        Returns:
            The scalar result of ``UPDATE ... RETURNING <pk column>``.
            NOTE(review): presumably ``None`` when no row matches
            ``pk_value`` -- confirm against the driver before relying on it.
        """
        query = (
            self.table
            .update()
            .where(self.table.c[self.pk] == pk_value)
            .values(**fields)
            .returning(self.table.c[self.pk])
        )
        return await self.conn.scalar(query)

    async def delete(self, pk_value: int):
        """Delete the row whose primary key equals ``pk_value``.

        Nothing is returned and no check is made that a row was removed.
        """
        query = self.table.delete().where(self.table.c[self.pk] == pk_value)
        await self.conn.execute(query)
