from typing import Dict, List, Optional, cast

from mail.payments.payments.core.actions.arbitrage.notify import NotifyArbitrageAction
from mail.payments.payments.core.actions.base.db import BaseDBAction
from mail.payments.payments.core.actions.order.refund import CoreCreateRefundAction
from mail.payments.payments.core.entities.arbitrage import Arbitrage
from mail.payments.payments.core.entities.enums import ArbitrageStatus, ArbitrageVerdict
from mail.payments.payments.core.entities.order import Order
from mail.payments.payments.core.entities.product import Product
from mail.payments.payments.core.exceptions import (
    ArbitrageNotEscalateError, ArbitrageNotFoundError, OrderNotFoundError, ProductNotFoundError
)
from mail.payments.payments.storage.exceptions import OrderNotFound


class VerdictArbitrageAction(BaseDBAction):
    """Apply a verdict to an escalated arbitrage and mark it complete.

    The arbitrage is looked up by its escalate id and must be in the
    ``ESCALATE`` status. For a ``REFUND`` verdict a refund order is created
    from the supplied items; for any other verdict a notification action is
    started instead. The updated arbitrage is saved and returned.
    """

    # NOTE(review): presumably makes the base action run `handle` inside a DB
    # transaction -- confirm against BaseDBAction.
    transact = True

    # Product attributes copied onto every refund item.
    _PRODUCT_FIELDS = ('name', 'currency', 'price', 'nds')

    def __init__(self, escalate_id: int, verdict: ArbitrageVerdict, items: Optional[List[dict]] = None):
        """Store the action parameters.

        Args:
            escalate_id: Escalate id used to look up the arbitrage.
            verdict: Verdict to record on the arbitrage.
            items: Refund items, each a dict with at least a ``'product_id'``
                key. Required (non-empty) when ``verdict`` is ``REFUND``.
        """
        super().__init__()
        self.escalate_id = escalate_id
        self.verdict = verdict
        self.items = items

    async def _build_refund_items(self, order: Order) -> List[dict]:
        """Return copies of ``self.items`` enriched with product attributes.

        Raises:
            AssertionError: If no items were supplied.
            ProductNotFoundError: If an item refers to a product that storage
                did not return for the order's uid.
        """
        # Raised explicitly instead of using `assert`, which is stripped when
        # Python runs with -O and would let an empty refund slip through.
        if not self.items:
            raise AssertionError('Refund verdict requires a non-empty list of items')

        product_ids = [item['product_id'] for item in self.items]
        products: Dict[int, Product] = {
            product.product_id: product
            async for product in self.storage.product.get_many(order.uid, product_ids)
            if product.product_id
        }

        refund_items: List[dict] = []
        for item in self.items:
            product_id = item['product_id']
            if product_id not in products:
                raise ProductNotFoundError(product_id=product_id)
            product = products[product_id]
            # Build a new dict so the caller-supplied items are not mutated;
            # product attributes override same-named keys of the item.
            refund_items.append({
                **item,
                **{field: getattr(product, field) for field in self._PRODUCT_FIELDS},
            })
        return refund_items

    async def handle(self) -> Arbitrage:
        """Record the verdict and perform its follow-up step.

        Returns:
            The arbitrage as returned by the storage ``save`` call.

        Raises:
            ArbitrageNotFoundError: If no arbitrage has this escalate id.
            OrderNotFoundError: If the arbitrage's order cannot be loaded.
            ArbitrageNotEscalateError: If the arbitrage is not in ``ESCALATE``.
            AssertionError: If the verdict is ``REFUND`` and no items were given.
            ProductNotFoundError: If a refund item's product is not found.
        """
        try:
            # NOTE(review): for_update=True presumably locks the row for the
            # duration of the transaction -- confirm in the storage layer.
            arbitrage = await self.storage.arbitrage.get_by_escalate_id(self.escalate_id, for_update=True)
            order = await self.storage.order.get(arbitrage.uid, arbitrage.order_id)
        except Arbitrage.DoesNotExist as exc:
            raise ArbitrageNotFoundError from exc
        except OrderNotFound as exc:
            raise OrderNotFoundError from exc

        # The context records the status the arbitrage had before this action
        # changes it.
        self.logger.context_push(
            order_id=order.order_id,
            arbitrage_id=arbitrage.arbitrage_id,
            status=arbitrage.status.value,
            verdict=self.verdict.value
        )

        if arbitrage.status != ArbitrageStatus.ESCALATE:
            raise ArbitrageNotEscalateError

        arbitrage.status = ArbitrageStatus.COMPLETE
        arbitrage.verdict = self.verdict

        if arbitrage.verdict == ArbitrageVerdict.REFUND:
            refund_items = await self._build_refund_items(order)

            refund: Order = await CoreCreateRefundAction(
                order_id=arbitrage.order_id,
                uid=arbitrage.uid,
                caption=order.caption,
                items=refund_items,
                description=order.description
            ).run()

            arbitrage.refund_id = refund.order_id
            self.logger.info('Refund was created.')
        else:
            await NotifyArbitrageAction(arbitrage_id=cast(int, arbitrage.arbitrage_id)).run_async()
            self.logger.info('Refund was not created.')

        return await self.storage.arbitrage.save(arbitrage)
