from decimal import Decimal
from typing import AsyncIterable, List, Optional, Tuple

from sqlalchemy import func, select, tuple_

from sendr_aiopg.query_builder import CRUDQueries, Filters, RelationDescription

from mail.payments.payments.core.entities.enums import OrderKind, RefundStatus
from mail.payments.payments.core.entities.item import Item
from mail.payments.payments.storage.db.tables import DEFAULT_DECIMAL
from mail.payments.payments.storage.db.tables import images as t_images
from mail.payments.payments.storage.db.tables import items as t_items
from mail.payments.payments.storage.db.tables import orders as t_orders
from mail.payments.payments.storage.db.tables import products as t_products
from mail.payments.payments.storage.exceptions import ItemNotFound
from mail.payments.payments.storage.mappers.base import BaseMapper
from mail.payments.payments.storage.mappers.image import ImageDataMapper
from mail.payments.payments.storage.mappers.product import ProductDataMapper
from mail.payments.payments.utils.db import SelectableDataMapper, TableDataDumper


class ItemDataMapper(SelectableDataMapper):
    """Data-mapper configuration binding the ``t_items`` selectable to :class:`Item`."""

    selectable = t_items
    entity_class = Item


class ItemDataDumper(TableDataDumper):
    """Data-dumper configuration binding :class:`Item` entities to the ``t_items`` table."""

    table = t_items
    entity_class = Item


class ItemMapper(BaseMapper):
    """Storage mapper for order items stored in ``t_items``.

    Items are identified by ``(uid, order_id, product_id)``. Queries built with
    ``select_related`` also join the item's product and, through an outer join,
    its (optional) image.
    """

    name = 'item'

    # Item -> product, matched on (uid, product_id).
    _product_relation = RelationDescription(
        name='product',
        base=t_items,
        related=t_products,
        mapper_cls=ProductDataMapper,
        base_cols=('uid', 'product_id'),
        related_cols=('uid', 'product_id'),
    )
    # Item -> image, matched on (uid, image_id). Outer join because an item may
    # have no image (see the ``image_id is None`` check in ``_map_related``).
    _image_relation = RelationDescription(
        name='image',
        base=t_items,
        related=t_images,
        mapper_cls=ImageDataMapper,
        base_cols=('uid', 'image_id'),
        related_cols=('uid', 'image_id'),
        outer_join=True,
    )
    _builder = CRUDQueries(
        t_items,
        id_fields=('uid', 'order_id', 'product_id'),
        mapper_cls=ItemDataMapper,
        dumper_cls=ItemDataDumper,
        related=(_product_relation, _image_relation),
    )

    @staticmethod
    def _map_related(row, mapper, rel_mappers):
        """Map a joined row to an ``Item`` with ``product`` (and ``image``) attached.

        :param row: a row produced by a ``select_related`` query.
        :param mapper: mapper for the item columns.
        :param rel_mappers: mapping of relation name -> mapper; must contain
            ``'product'`` and ``'image'``.
        """
        assert rel_mappers
        item = mapper(row)
        item.product = rel_mappers['product'](row)
        # The image columns come from an outer join; only map them when the
        # item actually references an image.
        if item.image_id is not None:
            item.image = rel_mappers['image'](row)
        return item

    async def create(self, obj: Item) -> Item:
        """Insert ``obj`` and return the inserted row mapped back to an ``Item``."""
        query, mapper = self._builder.insert(obj)
        return mapper(await self._query_one(query))

    async def create_or_update(self, obj: Item) -> Item:
        """Insert ``obj``; on a primary-key conflict update the existing row.

        Only ``amount``, ``image_id`` and ``markup`` are overwritten on conflict;
        the other columns of the existing row are left as they were.
        """
        query, mapper = self._builder.insert(
            obj,
            on_conflict_do_update_constraint=t_items.primary_key,
            on_conflict_do_update_set={'amount': obj.amount, 'image_id': obj.image_id, 'markup': obj.markup}
        )
        return mapper(await self._query_one(query))

    async def save(self, obj: Item) -> Item:
        """Update the stored item matching ``obj`` and return the updated row.

        The key columns (``uid``, ``order_id``, ``product_id``) are passed as
        ``ignore_fields``, so they are not rewritten.

        :raises ItemNotFound: if no row matched.
        """
        query, mapper = self._builder.update(
            obj,
            ignore_fields=(
                'uid',
                'order_id',
                'product_id',
            ),
        )
        return mapper(await self._query_one(query, raise_=ItemNotFound))

    async def find_by_image(
        self,
        uid: Optional[int],
        image_id: Optional[int],
        for_update: bool = False
    ) -> AsyncIterable[Item]:
        """Yield items filtered by ``uid`` and/or ``image_id``.

        An argument that is ``None`` contributes no filter (``add_not_none``).
        NOTE(review): with both arguments ``None`` no filter is applied at all;
        combined with ``for_update=True`` that would lock every selected row.
        """
        filters = Filters()
        filters.add_not_none('uid', uid)
        filters.add_not_none('image_id', image_id)

        query, mapper = self._builder.select(filters=filters, for_update=for_update)
        async for row in self._query(query):
            yield mapper(row)

    async def get(self, uid: int, order_id: int, product_id: int) -> Item:
        """Return the item identified by its full key.

        :raises ItemNotFound: if no such item exists.
        """
        query, mapper = self._builder.select(id_values=(uid, order_id, product_id))
        return mapper(await self._query_one(query, raise_=ItemNotFound))

    async def get_for_order(
        self,
        uid: int,
        order_id: int,
        iterator: bool = False,
        for_update: bool = False
    ) -> AsyncIterable[Item]:
        """Yield all items of one order, with ``product``/``image`` populated."""
        query, mapper, rel_mappers = self._builder.select_related(
            id_values=(uid, order_id),
            for_update=for_update,
        )
        async for row in self._query(query, iterator=iterator):
            yield self._map_related(row, mapper, rel_mappers)

    async def get_for_orders(self,
                             uid_and_order_id_list: List[Tuple[int, int]],
                             iterator: bool = False) -> AsyncIterable[Item]:
        """Yield items of several orders, with ``product``/``image`` populated.

        :param uid_and_order_id_list: ``(uid, order_id)`` pairs to fetch items for.
            An empty list yields nothing without querying the database.
        """
        if not uid_and_order_id_list:
            # ``IN ()`` can only match nothing; SQLAlchemy would render a
            # contradiction (and may warn), so skip the round trip entirely.
            return
        query, mapper, rel_mappers = self._builder.select_related()
        query = query.where(tuple_(t_items.c.uid, t_items.c.order_id).in_(uid_and_order_id_list))
        async for row in self._query(query, iterator=iterator):
            yield self._map_related(row, mapper, rel_mappers)

    async def get_product_amount_in_refunds(self, uid: int, order_id: int) -> AsyncIterable[Tuple[int, Decimal]]:
        """Yield ``(product_id, total_amount)`` already refunded for an order.

        Sums item amounts per product over every order of kind ``REFUND`` whose
        ``original_order_id`` is ``order_id`` and whose refund status is not
        ``FAILED``.
        NOTE(review): ``refund_status != FAILED`` also excludes rows whose status
        is NULL (SQL three-valued logic) — confirm refund orders always carry one.
        """
        query, _ = self._builder.select(
            id_values=(uid,),
        )
        select_refunds_order_ids = select([t_orders.c.order_id]) \
            .where(t_orders.c.uid == uid) \
            .where(t_orders.c.original_order_id == order_id) \
            .where(t_orders.c.kind == OrderKind.REFUND) \
            .where(t_orders.c.refund_status != RefundStatus.FAILED)
        amount = func.sum(t_items.c.amount).cast(DEFAULT_DECIMAL).label('amount')
        query = (
            query.
            with_only_columns((t_items.c.product_id, amount)).
            where(t_items.c.order_id.in_(select_refunds_order_ids)).
            group_by(t_items.c.product_id)
        )

        async for row in self._query(query):
            yield row['product_id'], row['amount']

    async def delete(self, item: Item) -> Item:
        """Delete ``item`` from storage.

        NOTE(review): unlike ``create``/``save``/``get``, the result of
        ``_query_one`` is returned as-is rather than passed through a mapper,
        so this may not actually be an ``Item`` — confirm against callers.
        """
        query = self._builder.delete(item)
        return await self._query_one(query)
