from typing import AsyncIterable

import psycopg2
from sqlalchemy import func

from sendr_aiopg.data_mapper import SelectableDataMapper, TableDataDumper
from sendr_aiopg.query_builder import CRUDQueries

from mail.beagle.beagle.core.entities.user import User
from mail.beagle.beagle.storage.db.tables import users as t_users
from mail.beagle.beagle.storage.db.tables import users_organizations_foreign_key
from mail.beagle.beagle.storage.exceptions import OrganizationNotFound, UserNotFound
from mail.beagle.beagle.storage.mappers.base import BaseMapper


class UserDataMapper(SelectableDataMapper):
    """Build ``User`` entities from rows selected out of the ``users`` table."""

    selectable = t_users
    entity_class = User


class UserDataDumper(TableDataDumper):
    """Turn ``User`` entities into column values for the ``users`` table."""

    table = t_users
    entity_class = User


class UserMapper(BaseMapper):
    """Storage mapper for ``User`` entities backed by the ``users`` table.

    Rows are addressed by the composite key ``(org_id, uid)``. Queries come
    from a shared ``CRUDQueries`` builder; ``UserDataMapper`` converts rows to
    ``User`` entities, and ``UserDataDumper`` converts entities to rows.
    """

    name = 'user'
    _builder = CRUDQueries(
        base=t_users,
        id_fields=('org_id', 'uid'),
        mapper_cls=UserDataMapper,
        dumper_cls=UserDataDumper,
    )

    async def create(self, user: User) -> User:
        """Insert ``user`` and return the entity built from the inserted row.

        ``user.created`` and ``user.updated`` are assigned the SQL ``now()``
        expression before the insert, so the passed-in entity is modified in place.

        :raises OrganizationNotFound: if the insert violates the
            users -> organizations foreign key.
        """
        user.created = user.updated = func.now()
        query, mapper = self._builder.insert(user)
        try:
            row = await self._query_one(query)
        except psycopg2.errors.ForeignKeyViolation as exc:
            # Only the users -> organizations FK is translated into a domain
            # error; any other FK violation is re-raised unchanged.
            if exc.diag.constraint_name == users_organizations_foreign_key:
                raise OrganizationNotFound from exc
            raise
        return mapper(row)

    async def delete(self, user: User) -> None:
        """Delete the row of ``user`` (the builder keys rows by ``org_id``/``uid``).

        NOTE(review): ``_query_one`` is called without ``raise_``, so what
        happens when no row matches depends on ``BaseMapper`` — confirm.
        """
        query = self._builder.delete(user)
        await self._query_one(query)

    async def find(self, org_id: int) -> AsyncIterable[User]:
        """Yield the users selected by the leading id field ``org_id``."""
        query, mapper = self._builder.select(id_values=(org_id,))
        async for row in self._query(query):
            yield mapper(row)

    async def get(self, org_id: int, uid: int) -> User:
        """Return the user identified by ``(org_id, uid)``.

        :raises UserNotFound: if no such row exists.
        """
        query, mapper = self._builder.select(id_values=(org_id, uid))
        return mapper(await self._query_one(query, raise_=UserNotFound))

    async def save(self, user: User) -> User:
        """Update the stored row of ``user`` and return the updated entity.

        ``user.updated`` is set in place to the SQL ``now()`` expression.
        The ``created`` column is never overwritten.

        :raises UserNotFound: if no row matched the update.
        """
        user.updated = func.now()
        query, mapper = self._builder.update(user, ignore_fields=('created',))
        return mapper(await self._query_one(query, raise_=UserNotFound))
