from itertools import chain
from typing import Dict, Iterable, List, Tuple, Union, cast

from mail.payments.payments.conf import settings
from mail.payments.payments.core.actions.base.db import BaseDBAction
from mail.payments.payments.core.entities.order import Order
from mail.payments.payments.core.entities.service import ServiceMerchant
from mail.payments.payments.core.entities.transaction import Transaction, TransactionStatus
from mail.payments.payments.core.exceptions import CoreFailError
from mail.payments.payments.storage.exceptions import (
    OrderNotFound, SerialNotFound, ServiceMerchantNotFound, TransactionNotFound
)
from mail.payments.payments.storage.mappers.order.order import FindOrderParams


class SendToHistoryError(Exception):
    """
    Internal signal of SendToHistoryOrderAction meaning "there is nothing to send".

    The action's lookup helpers log the reason and then raise it.
    ``SendToHistoryOrderAction.handle`` catches it and returns without sending anything.
    """


class SendToHistoryOrderAction(BaseDBAction):
    """
    Gathers full order information and sends it to Ohio.

    The payload consists of:
      * the order;
      * its last transaction (only if it was ever held);
      * the refunds already created in Trust;
      * the order's ServiceMerchant;
      * a revision equal to the maximum revision among the order, the transaction and the refunds.

    The action returns quietly without sending when:
      * sending is disabled via ``settings.SEND_TO_OHIO``;
      * the order is a test order;
      * any precondition checked by the ``_get_*`` helpers fails.
    """
    action_name = 'send_to_history_order_action'

    def __init__(self, uid: int, order_id: int):
        """
        :param uid: uid the order is stored under (also used to lock the serial row).
        :param order_id: id of the order to send.
        """
        super().__init__()
        self._uid = uid
        self._order_id = order_id

    async def _get_order(self) -> Order:
        """
        Load the order and check that it is eligible for history.

        :raises SendToHistoryError: if any of the following holds:
            * the order does not exist;
            * it has no customer_uid;
            * it is marked ``exclude_stats``.
        """
        try:
            order = await self.storage.order.get(self._uid, self._order_id)
        except OrderNotFound:
            self.logger.info('Order not found. Nothing to send.')
            raise SendToHistoryError from None

        if order.customer_uid is None:
            self.logger.info('Order has no customer_uid. Nothing to send.')
            raise SendToHistoryError

        if order.exclude_stats:
            self.logger.info('Order marked with exclude_stats, so it must be ignored. Nothing to send.')
            raise SendToHistoryError

        return order

    async def _get_service_merchant(self, order: Order) -> ServiceMerchant:
        """
        Load the ServiceMerchant referenced by ``order``.

        :raises SendToHistoryError: if the order has no service_merchant_id or the
            ServiceMerchant cannot be found. The latter case is logged as an error.
        """
        if order.service_merchant_id is None:
            self.logger.info('Order has no service_merchant_id. Nothing to send.')
            raise SendToHistoryError
        try:
            return await self.storage.service_merchant.get(service_merchant_id=order.service_merchant_id)
        except ServiceMerchantNotFound:
            with self.logger:
                self.logger.context_push(service_merchant_id=order.service_merchant_id)
                self.logger.error('ServiceMerchant not found. Nothing to send.')
            raise SendToHistoryError from None

    async def _get_transaction(self) -> Transaction:
        """
        Load the last transaction of the order.

        :raises SendToHistoryError: if there is no transaction, or if the last one
            was never held.
        """
        try:
            transaction = cast(
                Transaction,
                await self.storage.transaction.get_last_by_order(self._uid, self._order_id),
            )
        except TransactionNotFound:
            self.logger.info('Transaction not found. Nothing to send.')
            raise SendToHistoryError from None
        if not TransactionStatus.was_held(transaction.status):
            self.logger.info('Transaction was never held. Nothing to send.')
            raise SendToHistoryError
        return transaction

    async def _lock_serial(self) -> None:
        """
        Lock the serial row of ``self._uid`` (``for_update=True``).

        The revisions read afterwards must come from a consistent state, which matters
        when the maximum revision is calculated.

        :raises CoreFailError: if the serial does not exist.
        """
        try:
            await self.storage.serial.get(self._uid, for_update=True)
        except SerialNotFound as exc:
            raise CoreFailError('Serial does not exist.') from exc

    async def _get_refunds(self) -> List[Order]:
        """Return refunds of the order that already have a trust_refund_id; log and skip the rest."""
        refunds: List[Order] = []
        async for refund in self.storage.order.find(
            FindOrderParams(uid=self._uid, original_order_id=self._order_id)
        ):
            if refund.trust_refund_id is None:
                with self.logger:
                    self.logger.context_push(refund_id=refund.order_id)
                    self.logger.info('Refund was not created in trust yet. Skipping it.')
                continue
            refunds.append(refund)
        return refunds

    async def _load_items(self, entities: Iterable[Order]) -> None:
        """Reset ``items`` of every given order and fill it from one ``get_for_orders`` call."""
        entity_by_id: Dict[int, Order] = {}
        uid_and_order_id_list: List[Tuple[int, int]] = []
        for entity in entities:
            entity.items = []
            assert entity.order_id is not None
            entity_by_id[entity.order_id] = entity
            uid_and_order_id_list.append((entity.uid, entity.order_id))
        async for item in self.storage.item.get_for_orders(uid_and_order_id_list):
            entity = entity_by_id[cast(int, item.order_id)]
            assert entity.items is not None
            entity.items.append(item)

    @staticmethod
    def _calculate_revision(order: Order, transaction: Transaction, refunds: List[Order]) -> int:
        """Return the maximum non-None revision among the order, the transaction and the refunds."""
        revision_entities: Iterable[Union[Order, Transaction]] = chain((order, transaction), refunds)
        revisions = [
            obj.revision
            for obj in revision_entities
            if obj.revision is not None
        ]
        # Invariant: at least one of the entities carries a revision.
        assert revisions
        return max(revisions)

    async def handle(self) -> None:
        """
        Send the order to history unless it is not eligible.

        :raises CoreFailError: if the serial for ``self._uid`` does not exist.
        """
        with self.logger:
            self.logger.context_push(uid=self._uid, order_id=self._order_id)

            if not settings.SEND_TO_OHIO:
                self.logger.info('Order is not sent to history. Sending to history is turned off.')
                return

            # Lock first: all revisions read below must come from a consistent state.
            await self._lock_serial()

            try:
                order = await self._get_order()
            except SendToHistoryError:
                return

            # Checked before further lookups: test orders are never sent, so there is
            # no point in querying (and possibly error-logging) their merchant/transaction.
            if order.is_test:
                self.logger.info('Order is not sent to history. Order is test.')
                return

            try:
                service_merchant = await self._get_service_merchant(order)
                transaction = await self._get_transaction()
            except SendToHistoryError:
                return

            refunds = await self._get_refunds()
            await self._load_items(chain((order,), refunds))

            revision = self._calculate_revision(order, transaction, refunds)
            self.logger.context_push(revision=revision)
            self.logger.info('Revision calculated.')

            await self.clients.ohio.post_order(
                order=order,
                transaction=transaction,
                refunds=refunds,
                service_merchant=service_merchant,
                revision=revision,
            )
            self.logger.info('Order was sent to history.')
