from typing import AsyncIterable, Optional

from sqlalchemy import func, select
from sqlalchemy.sql.expression import and_, or_, text

from sendr_aiopg.query_builder import CRUDQueries, Filters, RelationDescription

from mail.payments.payments.conf import settings
from mail.payments.payments.core.entities.customer_subscription import CustomerSubscription
from mail.payments.payments.core.entities.enums import PayStatus, TransactionStatus
from mail.payments.payments.storage.db.tables import \
    customer_subscription_transactions as t_customer_subscription_transactions
from mail.payments.payments.storage.db.tables import customer_subscriptions as t_customer_subscriptions
from mail.payments.payments.storage.db.tables import orders as t_orders
from mail.payments.payments.storage.exceptions import CustomerSubscriptionNotFound
from mail.payments.payments.storage.mappers.base import BaseMapper
from mail.payments.payments.storage.mappers.order.serialization import OrderDataMapper
from mail.payments.payments.storage.mappers.subscription.customer_subscription_transaction import (
    CustomerSubscriptionTransactionDataMapper
)
from mail.payments.payments.utils.db import SelectableDataMapper, TableDataDumper


class CustomerSubscriptionDataMapper(SelectableDataMapper):
    """Row mapper binding the ``customer_subscriptions`` table to ``CustomerSubscription``."""

    selectable = t_customer_subscriptions
    entity_class = CustomerSubscription


class CustomerSubscriptionDataDumper(TableDataDumper):
    """Table dumper binding ``CustomerSubscription`` to the ``customer_subscriptions`` table."""

    table = t_customer_subscriptions
    entity_class = CustomerSubscription


class CustomerSubscriptionMapper(BaseMapper):
    """Storage mapper for the ``customer_subscriptions`` table.

    Rows are addressed by the composite key ``(uid, customer_subscription_id)``.
    Two relations are declared for ``select_related``:

    * ``tx`` -- transactions of the subscription (``outer_join=True``);
    * ``order`` -- orders of the subscription.

    Both are matched on ``uid`` and ``customer_subscription_id``.
    """

    name = 'customer_subscriptions'
    # Declared with outer_join=True so a subscription without transactions can
    # still appear in related selects.
    _tx_relation = RelationDescription(
        name='tx',
        base=t_customer_subscriptions,
        related=t_customer_subscription_transactions,
        mapper_cls=CustomerSubscriptionTransactionDataMapper,
        base_cols=('uid', 'customer_subscription_id'),
        related_cols=('uid', 'customer_subscription_id'),
        outer_join=True,
    )
    # No outer_join flag here -- presumably an inner join, i.e. only
    # subscriptions that have an order match (confirm against RelationDescription).
    _order_relation = RelationDescription(
        name='order',
        base=t_customer_subscriptions,
        related=t_orders,
        mapper_cls=OrderDataMapper,
        base_cols=('uid', 'customer_subscription_id'),
        related_cols=('uid', 'customer_subscription_id'),
    )
    _builder = CRUDQueries(
        t_customer_subscriptions,
        id_fields=('uid', 'customer_subscription_id'),
        mapper_cls=CustomerSubscriptionDataMapper,
        dumper_cls=CustomerSubscriptionDataDumper,
        related=(_tx_relation, _order_relation)
    )

    async def create(self, obj: CustomerSubscription) -> CustomerSubscription:
        """Insert ``obj`` and return the stored entity.

        Runs inside a transaction. ``obj.customer_subscription_id`` is set from
        ``self._acquire_customer_subscription_id(obj.uid)``. ``obj.created`` and
        ``obj.updated`` are set to the database ``now()``. Note that ``obj`` is
        mutated in place.
        """
        async with self.conn.begin():
            obj.customer_subscription_id = await self._acquire_customer_subscription_id(obj.uid)
            obj.created = obj.updated = func.now()
            query, mapper = self._builder.insert(obj)
            return mapper(await self._query_one(query))

    async def find(self,
                   uid: Optional[int] = None,
                   subscription_id: Optional[int] = None,
                   limit: Optional[int] = None,
                   offset: Optional[int] = None,
                   enabled: Optional[bool] = None,
                   ) -> AsyncIterable[CustomerSubscription]:
        """Yield subscriptions matching the given filters.

        ``uid``, ``subscription_id`` and ``enabled`` are added through
        ``Filters.add_not_none``, so arguments left as ``None`` do not filter.
        ``limit``/``offset`` are passed straight to the query builder.
        """
        # NOTE(review): no explicit ORDER BY is applied here. limit/offset
        # pagination is only stable if the builder orders the rows -- confirm.
        filters = Filters()
        filters.add_not_none('uid', uid)
        filters.add_not_none('subscription_id', subscription_id)
        filters.add_not_none('enabled', enabled)

        query, mapper = self._builder.select(filters=filters, limit=limit, offset=offset)

        async for row in self._query(query):
            yield mapper(row)

    async def get(self,
                  uid: int,
                  customer_subscription_id: int,
                  service_merchant_id: Optional[int] = None,
                  for_update: bool = False,
                  ) -> CustomerSubscription:
        """Fetch one subscription by ``(uid, customer_subscription_id)``.

        If ``service_merchant_id`` is given, the row must also match it.
        ``for_update`` is forwarded to the builder (presumably emitting
        ``SELECT ... FOR UPDATE``).

        :raises CustomerSubscriptionNotFound: when no row matches (via
            ``_query_one(raise_=...)``).
        """
        filters = Filters()
        filters.add_not_none('service_merchant_id', service_merchant_id)

        query, mapper = self._builder.select(id_values=(uid, customer_subscription_id),
                                             filters=filters,
                                             for_update=for_update)
        return mapper(await self._query_one(query, raise_=CustomerSubscriptionNotFound))

    async def save(self, obj: CustomerSubscription) -> CustomerSubscription:
        """Update the stored row for ``obj`` and return the stored entity.

        ``obj.updated`` is set to the database ``now()``. The key fields and
        ``created`` are excluded from the update.

        :raises CustomerSubscriptionNotFound: when no row is updated (via
            ``_query_one(raise_=...)``).
        """
        obj.updated = func.now()
        query, mapper = self._builder.update(
            obj,
            ignore_fields=('uid', 'customer_subscription_id', 'created')
        )
        return mapper(await self._query_one(query, raise_=CustomerSubscriptionNotFound))

    async def get_for_check(self) -> CustomerSubscription:
        """Return the least recently updated subscription that needs a check.

        A subscription qualifies when all of the following hold:

        * its joined order has ``pay_status == PayStatus.PAID``;
        * its ``updated`` is at least ``CUSTOMER_SUBSCRIPTION_UPDATER_PAUSE``
          seconds in the past;
        * either ``time_until`` is at least ``CUSTOMER_SUBSCRIPTION_CALLBACK_LAG``
          seconds in the past while ``time_finish`` is NULL, or a joined
          transaction has a ``payment_status`` outside
          ``TransactionStatus.FINAL_STATUSES``.

        Rows are grouped by the composite key, so several matching
        transactions/orders yield a single subscription.

        :raises CustomerSubscriptionNotFound: when nothing qualifies (via
            ``_query_one(raise_=...)``).
        """
        pause = settings.CUSTOMER_SUBSCRIPTION_UPDATER_PAUSE
        lag = settings.CUSTOMER_SUBSCRIPTION_CALLBACK_LAG

        query, mapper, _ = self._builder.select_related(order=('updated',))
        query = (
            query.
            with_only_columns(mapper.columns).
            where(
                and_(
                    or_(
                        and_(
                            t_customer_subscriptions.c.time_until <= func.now() - text(f"INTERVAL '{lag} SECONDS'"),
                            t_customer_subscriptions.c.time_finish.is_(None),
                        ),
                        t_customer_subscription_transactions.c.payment_status.notin_(TransactionStatus.FINAL_STATUSES)
                    ),
                    t_orders.c.pay_status == PayStatus.PAID,
                    t_customer_subscriptions.c.updated <= func.now() - text(f"INTERVAL '{pause} SECONDS'"),
                )
            ).
            group_by(t_customer_subscriptions.c.uid, t_customer_subscriptions.c.customer_subscription_id).
            order_by(t_customer_subscriptions.c.updated).
            # Only the first (oldest-updated) row is consumed below. Limiting
            # on the server avoids grouping, sorting and shipping the whole
            # candidate set on every poll.
            limit(1)
        )

        return mapper(await self._query_one(query, raise_=CustomerSubscriptionNotFound))

    async def count_enabled(self, uid: int, subscription_id: int) -> int:
        """Count enabled customer subscriptions of ``uid`` to ``subscription_id``."""
        query = (
            select([func.count()]).
            select_from(t_customer_subscriptions).
            where(
                and_(
                    t_customer_subscriptions.c.uid == uid,
                    t_customer_subscriptions.c.subscription_id == subscription_id,
                    t_customer_subscriptions.c.enabled.is_(True)
                )
            )
        )
        return await self._query_scalar(query)
