from decimal import Decimal
from operator import itemgetter
from typing import Optional

from maps_adv.billing_proxy.client.lib import Client as BillingClient
from maps_adv.statistics.beekeeper.lib.steps.base import BaseStep
from maps_adv.warden.client.lib import TaskContext

from . import schemas


class BillingNotification(BaseStep):
    """Pipeline step that submits per-order charges to the billing client.

    Reads ``orders`` (and ``packet_end`` when something is billed) from the
    incoming state. Every billable order's ``amount_to_bill`` is sent in a
    single ``submit_orders_charges`` call. The outcome is written back into
    the same state dict:

    * ``data["billing_applied"]``: the ``applied`` flag returned by the
      billing client, or ``False`` when nothing was submitted.
    * ``order["billing_success"]``: the billing result for that order's id,
      or ``None`` for orders that were not submitted.

    An order is billable when its ``order_id`` is not ``None`` and its
    ``amount_to_bill`` is non-zero.
    """

    __slots__ = ("_billing",)

    input_schema = schemas.ChargesCalculatorState
    output_schema = schemas.BillingNotificationState

    def __init__(self, billing: BillingClient):
        self._billing = billing

    @staticmethod
    def _is_billable(order_id, amount_to_bill) -> bool:
        """Return whether an order must be included in the billing submission.

        This is the single source of truth for the filter used both when
        building the charges and when mapping results back onto orders.
        Keeping it in one place matters: an order that passes the filter is
        looked up in the billing result by id.
        """
        return order_id is not None and amount_to_bill != Decimal("0")

    async def run(
        self, data: Optional[dict] = None, _: Optional[TaskContext] = None
    ) -> dict:
        """Submit billable charges and annotate ``data`` with the results.

        :param data: step state. Despite the ``None`` default it is required.
            It must contain ``orders``, a sequence of dicts each holding
            ``order_id`` and ``amount_to_bill``. When at least one order is
            billable it must also contain ``packet_end``, which is passed to
            billing as ``bill_due_to``.
        :param _: task context; unused by this step.
        :return: the same ``data`` dict, mutated in place.
        :raises KeyError: if a required key is missing from ``data`` or from
            an order, or if the billing result lacks an entry for a
            submitted order id.
        """
        orders = data["orders"]
        if not orders:
            data["billing_applied"] = False
            return data

        # NOTE(review): if two orders share an order_id, the later amount
        # silently replaces the earlier one in this dict. Confirm that ids
        # are unique upstream.
        charges = {
            order_id: amount_to_bill
            for order_id, amount_to_bill in map(
                _order_id_and_amount_to_bill_getter, orders
            )
            if self._is_billable(order_id, amount_to_bill)
        }

        # Skip the billing round-trip entirely when nothing is billable.
        applied, submit_result = False, {}
        if charges:
            async with self._billing as client:
                applied, submit_result = await client.submit_orders_charges(
                    charges=charges, bill_due_to=data["packet_end"]
                )

        data["billing_applied"] = applied
        for order in orders:
            order_id, amount_to_bill = _order_id_and_amount_to_bill_getter(order)
            # Orders that were not submitted get an explicit None, so every
            # order carries the key regardless of outcome.
            order["billing_success"] = (
                submit_result[order_id]
                if self._is_billable(order_id, amount_to_bill)
                else None
            )

        return data


_order_id_and_amount_to_bill_getter = itemgetter("order_id", "amount_to_bill")
