from decimal import Decimal
from typing import Dict, List, Optional

from sendr_utils import alist

from mail.payments.payments.core.actions.order.base import BaseOrderAction
from mail.payments.payments.core.entities.customer_subscription_transaction import CustomerSubscriptionTransaction
from mail.payments.payments.core.entities.enums import OrderKind, OrderTimelineEventType, PayStatus, RefundStatus
from mail.payments.payments.core.entities.order import Order, OrderTimelineEvent
from mail.payments.payments.core.exceptions import CoreActionDenyError
from mail.payments.payments.storage.mappers.order.order import FindOrderParams


class GetOrderTimelineAction(BaseOrderAction):
    """Build the chronologically sorted list of timeline events for a PAY order.

    The timeline contains the creation event, the PAID event (if the order is
    paid and closed), one event per completed refund, and one event per
    customer-subscription transaction (if the order has a subscription).
    """

    def __init__(self, order: Order):
        super().__init__()
        self.order = order
        # Cache of subscription transactions keyed by purchase token; stays
        # None until `_fetch_subscription_transactions` has completed a fetch.
        self.transactions: Optional[Dict[str, CustomerSubscriptionTransaction]] = None

    async def _fetch_data(self) -> None:
        """Load the order's items, its refunds, and each refund's items if not already loaded.

        Refunds are looked up as orders whose `original_order_id` is this
        order's id. `select_customer_subscription=None` is passed so that
        refunds for customer subscriptions are included as well.
        NOTE(review): items appear to be loaded because `price` (used by the
        refund events) is derived from them — confirm against `Order`.
        """
        if self.order.items is None:
            self.order.items = await self._fetch_items(self.order)

        if self.order.refunds is None:
            self.order.refunds = await alist(self.storage.order.find(
                FindOrderParams(
                    uid=self.order.uid,
                    original_order_id=self.order.order_id,
                    select_customer_subscription=None  # select refunds for customer subscriptions
                )
            ))

        for refund in self.order.refunds:
            if refund.items is None:
                refund.items = await self._fetch_items(refund)

    async def _fetch_subscription_transactions(self) -> Dict[str, CustomerSubscriptionTransaction]:
        """Return the order's subscription transactions keyed by purchase token.

        The result is cached on `self.transactions`. An order without a
        `customer_subscription_id` yields an empty dict without querying storage.
        """
        if self.transactions is not None:
            return self.transactions

        transactions: Dict[str, CustomerSubscriptionTransaction] = {}
        if self.order.customer_subscription_id is not None:
            found = await alist(
                self.storage.customer_subscription_transaction.find(
                    uid=self.order.uid,
                    customer_subscription_id=self.order.customer_subscription_id,
                )
            )
            transactions = {tx.purchase_token: tx for tx in found}

        # Publish to the cache only once the fetch has finished. Assigning an
        # empty dict before the `await` would let a concurrent caller (or a
        # retry after a failed fetch) see an incomplete cache as authoritative.
        self.transactions = transactions
        return transactions

    @staticmethod
    def _fetch_tx_amount(tx: CustomerSubscriptionTransaction) -> Optional[Decimal]:
        """Return `tx.data['amount']` as a Decimal, or None if data or the amount key is absent."""
        if tx.data is None or 'amount' not in tx.data:
            return None
        return Decimal(tx.data['amount'])

    def _timeline_item_created(self) -> OrderTimelineEvent:
        """Return the CREATED event dated at the order's creation time."""
        assert self.order.created is not None

        return OrderTimelineEvent(
            date=self.order.created,
            event_type=OrderTimelineEventType.CREATED
        )

    def _timeline_item_closed(self) -> Optional[OrderTimelineEvent]:
        """Return the PAID event, or None unless the order is closed with pay_status PAID."""
        if self.order.closed is None or self.order.pay_status != PayStatus.PAID:
            return None
        return OrderTimelineEvent(
            date=self.order.closed,
            event_type=OrderTimelineEventType.PAID
        )

    def _ordinary_refund_event(self, refund: Order) -> OrderTimelineEvent:
        """Build the event for a refund that is not tied to a subscription transaction.

        A refund cheaper than the order is PARTIALLY_REFUNDED, otherwise
        REFUNDED. Missing prices are treated as zero.
        """
        order_price = self.order.price or Decimal(0)
        refund_price = refund.price or Decimal(0)

        assert refund.closed
        return OrderTimelineEvent(
            date=refund.closed,
            event_type=(
                OrderTimelineEventType.PARTIALLY_REFUNDED
                if refund_price < order_price
                else OrderTimelineEventType.REFUNDED
            ),
            extra={
                'refund_amount': refund_price
            }
        )

    async def _subscription_refund_event(self, refund: Order) -> OrderTimelineEvent:
        """Build the PERIODIC_REFUNDED event for a refund of a subscription transaction.

        The refund amount comes from the matched transaction's data.

        Raises:
            KeyError: if no fetched transaction has the refund's purchase token.
        """
        transactions = await self._fetch_subscription_transactions()

        assert refund.customer_subscription_tx_purchase_token
        tx = transactions[refund.customer_subscription_tx_purchase_token]

        assert refund.closed
        return OrderTimelineEvent(
            date=refund.closed,
            event_type=OrderTimelineEventType.PERIODIC_REFUNDED,
            extra={
                'refund_amount': self._fetch_tx_amount(tx)
            }
        )

    async def _timeline_items_refund(self) -> List[OrderTimelineEvent]:
        """Return one event per COMPLETED refund that has a `closed` date."""
        assert self.order.refunds is not None

        completed_refunds = [
            refund
            for refund in self.order.refunds
            if refund.refund_status == RefundStatus.COMPLETED and refund.closed is not None
        ]

        refund_events = []
        for refund in completed_refunds:
            # A purchase token marks a refund of a subscription transaction.
            if refund.customer_subscription_tx_purchase_token is None:
                refund_events.append(self._ordinary_refund_event(refund))
            else:
                refund_events.append(await self._subscription_refund_event(refund))
        return refund_events

    async def _timeline_items_periodic(self) -> List[OrderTimelineEvent]:
        """Return one event per subscription transaction, or [] if the order has no subscription.

        Each event is dated at the transaction's `updated` time, falling back
        to `created`. Its type is derived from the transaction's payment status.
        """
        if self.order.customer_subscription_id is None:
            return []
        transactions = await self._fetch_subscription_transactions()
        return [
            OrderTimelineEvent(
                date=transaction.updated or transaction.created,  # type: ignore
                event_type=OrderTimelineEventType.from_transaction_status(transaction.payment_status),
                extra={
                    'periodic_amount': self._fetch_tx_amount(transaction),
                    'tx_id': {
                        'uid': transaction.uid,
                        'customer_subscription_id': transaction.customer_subscription_id,
                        'purchase_token': transaction.purchase_token
                    }
                },
            )
            for transaction in transactions.values()
        ]

    async def handle(self) -> List[OrderTimelineEvent]:
        """Return all timeline events of the order, sorted by date.

        Raises:
            CoreActionDenyError: if the order is not of kind PAY.
        """
        if self.order.kind != OrderKind.PAY:
            raise CoreActionDenyError

        await self._fetch_data()

        events: List[Optional[OrderTimelineEvent]] = [
            self._timeline_item_created(),
            self._timeline_item_closed(),
            *(await self._timeline_items_refund()),
            *(await self._timeline_items_periodic()),
        ]
        # Only `_timeline_item_closed` can yield None. Filter exactly that,
        # rather than relying on the truthiness of event objects.
        timeline = [event for event in events if event is not None]
        return sorted(timeline, key=lambda item: item.date)
