from typing import Optional

from sqlalchemy import func

from sendr_aiopg import BaseMapperCRUD
from sendr_aiopg.query_builder import CRUDQueries, Filters
from sendr_utils import anext

from mail.payments.payments.core.entities.arbitrage import Arbitrage
from mail.payments.payments.core.entities.enums import ArbitrageStatus
from mail.payments.payments.storage.db.tables import arbitrages as t_arbitrages
from mail.payments.payments.utils.db import SelectableDataMapper, TableDataDumper


class ArbitrageDataMapper(SelectableDataMapper):
    """Map rows selected from the arbitrages table onto :class:`Arbitrage` entities."""

    # Source table the rows are read from.
    selectable = t_arbitrages
    # Entity type each row is converted into.
    entity_class = Arbitrage


class ArbitrageDataDumper(TableDataDumper):
    """Dump :class:`Arbitrage` entities into values for the arbitrages table."""

    # Destination table for the dumped values.
    table = t_arbitrages
    # Entity type accepted by this dumper.
    entity_class = Arbitrage


class ArbitrageMapper(BaseMapperCRUD[Arbitrage]):
    """CRUD mapper for :class:`Arbitrage` entities stored in ``t_arbitrages``.

    Rows are identified by ``arbitrage_id``. Conversion between rows and
    entities goes through ``ArbitrageDataMapper`` (reads) and
    ``ArbitrageDataDumper`` (writes).
    """

    # NOTE(review): 'merchant' looks like it was copied from the merchant
    # mapper; 'arbitrage' was probably intended. Kept as-is because other code
    # may rely on this value — confirm how ``name`` is used before renaming.
    name = 'merchant'
    model = Arbitrage

    _builder = CRUDQueries(
        base=t_arbitrages,
        id_fields=('arbitrage_id',),
        mapper_cls=ArbitrageDataMapper,
        dumper_cls=ArbitrageDataDumper,
    )

    async def create(self, obj: Arbitrage) -> Arbitrage:
        """Insert ``obj`` and return the stored entity.

        Both ``created`` and ``updated`` are set on ``obj`` (mutating it) to
        the SQL ``now()`` expression, so the database clock supplies the
        timestamps. ``arbitrage_id`` is excluded from the insert — presumably
        the database generates it; verify against the table definition.
        """
        timestamp = func.now()
        obj.created = timestamp
        obj.updated = timestamp
        return await super().create(obj, ignore_fields=('arbitrage_id',))

    async def get_by_escalate_id(self, escalate_id: int, for_update: bool = False) -> Arbitrage:
        """Fetch the first arbitrage whose ``escalate_id`` matches.

        :param escalate_id: value to match against the ``escalate_id`` column.
        :param for_update: forwarded to the query builder to lock the row.
        :raises Arbitrage.DoesNotExist: passed to ``_query_one`` as the
            not-found exception.
        """
        query, mapper = self._builder.select(
            filters={'escalate_id': escalate_id},
            limit=1,
            for_update=for_update,
        )
        row = await self._query_one(query, raise_=self.model.DoesNotExist)
        return mapper(row)

    async def get_by_refund_id(self, uid: int, refund_id: int, for_update: bool = False) -> Arbitrage:
        """Fetch the first arbitrage matching both ``uid`` and ``refund_id``.

        :param uid: value to match against the ``uid`` column.
        :param refund_id: value to match against the ``refund_id`` column.
        :param for_update: forwarded to the query builder to lock the row.
        :raises Arbitrage.DoesNotExist: passed to ``_query_one`` as the
            not-found exception.
        """
        criteria = {'uid': uid, 'refund_id': refund_id}
        query, mapper = self._builder.select(
            filters=criteria,
            limit=1,
            for_update=for_update,
        )
        row = await self._query_one(query, raise_=self.model.DoesNotExist)
        return mapper(row)

    async def get_current(self, uid: int, order_id: int, for_update: bool = False) -> Optional[Arbitrage]:
        """Return the first active arbitrage for ``(uid, order_id)``, or ``None``.

        "Active" means the ``status`` column is one of
        ``ArbitrageStatus.ACTIVE_STATUSES``.
        """
        def _is_active(field):
            # Column-level predicate: status IN (active statuses).
            return field.in_(ArbitrageStatus.ACTIVE_STATUSES)

        criteria = Filters()
        criteria['uid'] = uid
        criteria['order_id'] = order_id
        criteria['status'] = _is_active

        matches = self.find(filters=criteria, for_update=for_update)
        # Take only the first match; None when the iterator is empty.
        return await anext(matches, None)

    async def save(self, obj: Arbitrage) -> Arbitrage:
        """Persist changes to ``obj`` and return the stored entity.

        ``updated`` is refreshed to the SQL ``now()`` expression (mutating
        ``obj``). ``arbitrage_id`` and ``created`` are excluded from the
        update so they are never overwritten.
        """
        obj.updated = func.now()
        untouched = ('arbitrage_id', 'created')
        return await super().save(obj, ignore_fields=untouched)
