from typing import AsyncIterable, Iterable, Optional

from sqlalchemy import func

from sendr_aiopg import BaseMapperCRUD
from sendr_aiopg.query_builder import CRUDQueries, Filters

from mail.payments.payments.core.entities.category import Category
from mail.payments.payments.storage.db.tables import categories as t_categories
from mail.payments.payments.utils.db import SelectableDataMapper, TableDataDumper


class CategoryDataMapper(SelectableDataMapper):
    """Map rows selected from ``t_categories`` into ``Category`` entities."""

    selectable = t_categories
    entity_class = Category


class CategoryDataDumper(TableDataDumper):
    """Dump ``Category`` entities into values for the ``t_categories`` table."""

    table = t_categories
    entity_class = Category


class CategoryMapper(BaseMapperCRUD[Category]):
    """CRUD mapper for ``Category`` entities stored in the ``t_categories`` table."""

    model = Category
    name = 'category'
    _builder = CRUDQueries(
        t_categories,
        id_fields=('category_id',),
        mapper_cls=CategoryDataMapper,
        dumper_cls=CategoryDataDumper,
    )

    async def create(self, obj: Category) -> Category:
        """Insert ``obj`` and return the stored category.

        The id fields and the ``created``/``updated`` columns are not written
        from ``obj``; presumably the database fills them in via column
        defaults (TODO confirm against the table definition).
        """
        return await super().create(obj, ignore_fields=self._builder.id_fields + ('created', 'updated'))

    async def save(self, obj: Category) -> Category:
        """Update the stored row for ``obj`` and return the saved category.

        ``obj.updated`` is overwritten in place with the SQL ``now()``
        expression, so the timestamp is computed by the database rather than
        by the application. The id fields and ``created`` are never written
        on update.
        """
        obj.updated = func.now()
        # Derive the id columns from the builder (as ``create`` does) instead of
        # repeating the column name, so both methods stay in sync with ``_builder``.
        return await super().save(obj, ignore_fields=self._builder.id_fields + ('created',))

    async def get(self, category_id: int) -> Category:
        """Return the category with the given ``category_id``.

        Raises:
            Category.DoesNotExist: if no row matches ``category_id``.
        """
        query, mapper = self._builder.select(id_values=(category_id,))
        return mapper(await self._query_one(query, raise_=Category.DoesNotExist))

    async def find(self,
                   category_ids: Optional[Iterable[int]] = None,
                   iterator: bool = False) -> AsyncIterable[Category]:
        """Asynchronously yield categories, optionally restricted to ``category_ids``.

        Args:
            category_ids: if not ``None``, only categories whose id is in this
                collection are returned. An empty collection produces an empty
                ``IN`` clause, which matches no rows.
            iterator: passed through unchanged to ``self._query``.
        """
        # Materialize once: ``category_ids`` may be a one-shot iterator, and the
        # filter callback below must always see the same values.
        ids = None if category_ids is None else list(category_ids)
        filters = Filters()
        filters.add_not_none('category_id', ids, lambda column: column.in_(ids))
        query, mapper = self._builder.select(filters=filters)
        async for row in self._query(query, iterator=iterator):
            yield mapper(row)
