from decimal import Decimal
from typing import AsyncIterable, Dict, List, Optional

from sqlalchemy import func

from sendr_aiopg.query_builder import CRUDQueries, Filters

from mail.payments.payments.core.entities.subscription import Subscription, SubscriptionData, SubscriptionPrice
from mail.payments.payments.storage.db.tables import subscriptions as t_subscriptions
from mail.payments.payments.storage.exceptions import SubscriptionNotFound
from mail.payments.payments.storage.mappers.base import BaseMapper
from mail.payments.payments.utils.db import SelectableDataMapper, TableDataDumper


class SubscriptionDataMapper(SelectableDataMapper):
    """Row-to-entity mapper for ``Subscription`` over the subscriptions table.

    The ``map_<field>`` static methods convert the stored representation of
    individual fields into entity objects.
    NOTE(review): presumably the base class looks these hooks up by field
    name -- confirm against ``SelectableDataMapper``.
    """

    entity_class = Subscription
    selectable = t_subscriptions

    @staticmethod
    def map_prices(rows: List[dict]) -> List[SubscriptionPrice]:
        """Convert stored price dicts into ``SubscriptionPrice`` objects.

        Each dict must carry the keys ``price``, ``currency`` and
        ``region_id``; ``price`` is passed through ``Decimal``.
        """
        prices: List[SubscriptionPrice] = []
        for item in rows:
            amount = Decimal(item['price'])
            prices.append(
                SubscriptionPrice(
                    price=amount,
                    currency=item['currency'],
                    region_id=item['region_id'],
                )
            )
        return prices

    @staticmethod
    def map_data(data: dict) -> SubscriptionData:
        """Build ``SubscriptionData`` from the stored dict.

        A missing ``fast_moderation`` key falls back to ``False``.
        """
        fast_moderation = data.get('fast_moderation', False)
        return SubscriptionData(fast_moderation=fast_moderation)


class SubscriptionDataDumper(TableDataDumper):
    """Entity-to-row dumper for ``Subscription`` over the subscriptions table.

    NOTE(review): ``dump_prices`` is presumably picked up by the base class
    as the serializer for the ``prices`` field -- confirm against
    ``TableDataDumper``.
    """

    entity_class = Subscription
    table = t_subscriptions

    def dump_prices(self, prices: List[SubscriptionPrice]) -> List[Dict]:
        """Serialize prices into plain dicts.

        Every resulting dict has the keys ``price`` (the amount rendered
        with ``str``), ``currency`` and ``region_id``.
        """
        dumped: List[Dict] = []
        for item in prices:
            record = {
                'price': str(item.price),
                'currency': item.currency,
                'region_id': item.region_id,
            }
            dumped.append(record)
        return dumped


class SubscriptionMapper(BaseMapper):
    """Storage mapper exposing create/find/get/save for ``Subscription``.

    Queries come from a ``CRUDQueries`` builder bound to the subscriptions
    table, with ``(uid, subscription_id)`` as the identifying fields.
    """

    name = 'subscriptions'

    _builder = CRUDQueries(
        t_subscriptions,
        id_fields=('uid', 'subscription_id'),
        mapper_cls=SubscriptionDataMapper,
        dumper_cls=SubscriptionDataDumper,
    )

    async def create(self, obj: Subscription) -> Subscription:
        """Insert ``obj`` and return the entity mapped from the query result.

        Runs inside ``self.conn.begin()``. Before the insert, ``obj`` is
        mutated in place: it receives a subscription id and a revision
        acquired for ``obj.uid``, and ``created``/``updated`` are both set
        to the SQL ``now()`` expression.
        """
        async with self.conn.begin():
            # Keep this order: subscription id first, then revision.
            obj.subscription_id = await self._acquire_subscription_id(obj.uid)
            obj.revision = await self._acquire_revision(obj.uid)

            timestamp = func.now()
            obj.created = timestamp
            obj.updated = timestamp

            statement, to_entity = self._builder.insert(obj)
            inserted = await self._query_one(statement)
            return to_entity(inserted)

    async def find(self,
                   uid: Optional[int] = None,
                   limit: Optional[int] = None,
                   offset: Optional[int] = None,
                   enabled: Optional[bool] = True,
                   ) -> AsyncIterable[Subscription]:
        """Yield subscriptions with ``deleted = False`` matching the filters.

        ``uid`` and ``enabled`` are handed to ``Filters.add_not_none``
        (presumably skipped when ``None`` -- confirm against ``Filters``).
        ``limit`` and ``offset`` are forwarded to the query builder as is.
        """
        conditions = Filters()
        conditions.add_not_none('uid', uid)
        conditions.add_not_none('enabled', enabled)
        conditions['deleted'] = False

        statement, to_entity = self._builder.select(
            filters=conditions,
            limit=limit,
            offset=offset,
        )

        async for record in self._query(statement):
            yield to_entity(record)

    async def get(self,
                  uid: int,
                  subscription_id: int,
                  service_merchant_id: Optional[int] = None,
                  enabled: Optional[bool] = True,
                  for_update: bool = False,
                  ) -> Subscription:
        """Fetch one subscription with ``deleted = False`` by its id pair.

        ``enabled`` and ``service_merchant_id`` go through
        ``Filters.add_not_none``; ``for_update`` is forwarded to the query
        builder. ``SubscriptionNotFound`` is passed to ``_query_one`` as
        ``raise_`` (presumably raised when no row matches).
        """
        conditions = Filters()
        conditions.add_not_none('enabled', enabled)
        conditions.add_not_none('service_merchant_id', service_merchant_id)
        conditions['deleted'] = False

        identity = (uid, subscription_id)
        statement, to_entity = self._builder.select(
            id_values=identity,
            filters=conditions,
            for_update=for_update,
        )
        found = await self._query_one(statement, raise_=SubscriptionNotFound)
        return to_entity(found)

    async def save(self, obj: Subscription) -> Subscription:
        """Update the row for ``obj`` and return the entity mapped from it.

        Runs inside ``self.conn.begin()``. ``obj`` is mutated in place: a
        new revision is acquired for ``obj.uid`` and ``updated`` is set to
        the SQL ``now()`` expression. ``SubscriptionNotFound`` is passed as
        ``raise_`` to both the revision acquisition and the update query.
        """
        # Fields handed to the builder as ``ignore_fields`` for the update.
        skipped_fields = (
            'uid',
            'subscription_id',
            'product_uuid',
            'period_amount',
            'period_units',
            'trial_period_amount',
            'trial_period_units',
            'created'
        )

        async with self.conn.begin():
            obj.revision = await self._acquire_revision(obj.uid, raise_=SubscriptionNotFound)
            obj.updated = func.now()

            statement, to_entity = self._builder.update(obj, ignore_fields=skipped_fields)
            updated = await self._query_one(statement, raise_=SubscriptionNotFound)
            return to_entity(updated)
