from typing import AsyncIterable, Iterable, Tuple

from sqlalchemy import func

from sendr_aiopg.query_builder import CRUDQueries

from mail.payments.payments.core.entities.product import Product
from mail.payments.payments.storage.db.tables import products as t_products
from mail.payments.payments.storage.exceptions import ProductNotFound
from mail.payments.payments.storage.mappers.base import BaseMapper
from mail.payments.payments.utils.db import SelectableDataMapper, TableDataDumper


class ProductDataMapper(SelectableDataMapper):
    """Row -> entity mapper that binds the ``products`` table to :class:`Product`."""

    selectable = t_products
    entity_class = Product


class ProductDataDumper(TableDataDumper):
    """Entity -> row dumper that binds :class:`Product` to the ``products`` table."""

    table = t_products
    entity_class = Product


class ProductMapper(BaseMapper[Product]):
    """Storage mapper for :class:`Product` rows in the ``products`` table.

    Rows are addressed by the composite key ``(uid, product_id)``.
    """

    name = 'product'
    _builder = CRUDQueries(
        t_products,
        id_fields=('uid', 'product_id'),
        mapper_cls=ProductDataMapper,
        dumper_cls=ProductDataDumper,
    )

    async def create(self, obj: Product) -> Product:
        """Insert ``obj`` as a new product and return the stored product.

        Everything runs inside one transaction on ``self.conn``. In order:

        * ``product_id`` comes from ``_acquire_product_id(obj.uid)``.
        * ``revision`` comes from ``_acquire_revision(obj.uid)``.
        * ``created`` is set to the SQL ``now()`` expression, so the database
          evaluates the timestamp.
        * The row is inserted.

        Note that ``obj`` is mutated in place: its ``product_id``, ``revision``
        and ``created`` attributes are overwritten. The returned value is
        built by the data mapper from the row returned by the INSERT query.
        """
        async with self.conn.begin():
            obj.product_id = await self._acquire_product_id(obj.uid)
            obj.revision = await self._acquire_revision(obj.uid)
            obj.created = func.now()
            query, mapper = self._builder.insert(obj)
            return mapper(await self._query_one(query))

    async def get(self, uid: int, product_id: int) -> Product:
        """Fetch a single product by its composite key ``(uid, product_id)``.

        ``ProductNotFound`` is passed to ``_query_one`` as ``raise_``. It is
        presumably raised when no row matches (confirm in ``BaseMapper``).
        """
        query, mapper = self._builder.select(id_values=(uid, product_id))
        return mapper(await self._query_one(query, raise_=ProductNotFound))

    async def get_many(self, uid: int, product_ids: Iterable[int]) -> AsyncIterable[Product]:
        """Yield the products of ``uid`` whose ``product_id`` is in ``product_ids``.

        Ids that do not match any row are skipped; nothing is raised for them.
        No explicit ordering is requested here.
        """
        # Materialize once. ``product_ids`` may be a single-pass iterable
        # (e.g. a generator). We need to both test it for emptiness and hand
        # a concrete sequence to ``IN``.
        ids = tuple(product_ids)
        if not ids:
            # An empty ``IN`` list can never match; skip the database round trip.
            return
        query, mapper = self._builder.select(
            id_values=(uid,),
            filters={
                'product_id': lambda field: field.in_(ids),
            },
        )
        async for row in self._query(query):
            yield mapper(row)

    async def get_or_create(
        self,
        obj: Product,
        # A literal tuple keeps the field order deterministic. The previous
        # ``tuple({...})`` iterated a set of str, whose order depends on the
        # per-process hash seed, so the generated lookup query varied between
        # processes.
        lookup_fields: Iterable[str] = ('uid', 'name', 'nds', 'currency', 'price', 'status'),
        for_update: bool = False,
    ) -> Tuple[Product, bool]:
        """Return a product matching ``obj`` on ``lookup_fields``, creating it if absent.

        Delegates to ``BaseMapper.get_or_create``. The ``bool`` in the result
        presumably reports whether a new row was created (confirm in
        ``BaseMapper``). ``for_update`` is forwarded unchanged.
        """
        return await super().get_or_create(obj, lookup_fields=lookup_fields, for_update=for_update)
