from decimal import Decimal
from operator import itemgetter
from typing import Dict, List, Optional

from maps_adv.adv_store.api.schemas.enums import (
    ReasonCampaignStoppedEnum as ReasonStopped,
)
from maps_adv.adv_store.client import Client as AdvStoreClient
from maps_adv.warden.client.lib import TaskContext

from . import schemas
from .base import BaseStep


class AdvStoreNotification(BaseStep):
    """Step that asks adv_store to stop campaigns which reach their budget.

    ``run`` takes the step state dict (declared by ``input_schema``) and works
    out which campaigns would meet their budget once the pending paid events
    are charged. It asks adv_store to stop those campaigns, then records them
    under ``"stopped_campaigns"`` in the same dict (declared by
    ``output_schema``).
    """

    # NOTE(review): slots only drop the instance __dict__ if BaseStep also
    # declares __slots__ — confirm.
    __slots__ = ("_adv_store",)

    input_schema = schemas.PaidEventsPorterState
    output_schema = schemas.AdvStoreNotificationState

    def __init__(self, adv_store: AdvStoreClient):
        # The client is entered as an async context manager only when there is
        # something to stop (see ``run``).
        self._adv_store = adv_store

    async def run(
        self, data: Optional[dict] = None, _: Optional[TaskContext] = None
    ) -> dict:
        """Stop campaigns that reach their budget and record them in ``data``.

        Args:
            data: step state. It must contain ``"orders"``. It must also
                contain ``"packet_end"`` whenever at least one campaign is to be
                stopped. Despite the ``None`` default, a dict is required:
                indexing ``None`` raises ``TypeError``.
            _: task context; not used by this step.

        Returns:
            The same ``data`` dict, mutated in place. Its
            ``"stopped_campaigns"`` entry maps campaign id to stop reason and
            is empty when nothing is stopped.
        """
        orders = data["orders"]
        # Falsy orders (empty or None) yield an empty result without
        # contacting adv_store.
        stopped = self._find_campaigns_for_stopping(orders) if orders else {}

        if stopped:
            async with self._adv_store as client:
                await client.stop_campaigns(
                    processed_at=data["packet_end"], campaigns_to_stop=stopped
                )

        data["stopped_campaigns"] = stopped
        return data

    def _find_campaigns_for_stopping(
        self, orders: List[Dict]
    ) -> Dict[int, ReasonStopped]:
        """Map campaign id to stop reason for every campaign that must stop.

        An order that has an ``order_id`` and whose ``billing_success`` is
        ``False`` is skipped entirely. Orders without an ``order_id`` are
        always checked. If a campaign id repeats, the last detected reason
        wins.
        """
        to_stop = {}
        for order in orders:
            billing_failed = (
                order["order_id"] is not None and order["billing_success"] is False
            )
            if billing_failed:
                continue

            for campaign in order["campaigns"]:
                stop_reason = self._detect_reason_for_stopping_campaign(campaign)
                if stop_reason:
                    to_stop[campaign["campaign_id"]] = stop_reason

        return to_stop

    @classmethod
    def _calculate_charge_for_campaign(cls, campaign: Dict) -> Decimal:
        """Return the amount the pending paid events add to the campaign charge.

        The final event is priced at ``last_paid_event_cost`` and every other
        event at ``paid_event_cost``. Returns ``Decimal("0")`` in two cases:
        there are no events to charge (``paid_events_to_charge`` is falsy), or
        both costs are falsy.
        """
        (
            per_event_cost,
            last_event_cost,
            events_count,
        ) = _cost_per_event_cost_last_event_and_count(campaign)

        if (last_event_cost or per_event_cost) and events_count:
            return last_event_cost + (events_count - 1) * per_event_cost
        return Decimal("0")

    @classmethod
    def _detect_reason_for_stopping_campaign(cls, campaign) -> Optional[ReasonStopped]:
        """Return ``BUDGET_REACHED`` if charging pending events meets the budget.

        Returns ``None`` when ``charged`` plus the pending charge stays strictly
        below ``budget``.
        """
        pending_charge = cls._calculate_charge_for_campaign(campaign)
        already_charged, budget = _campaign_charged_and_budget(campaign)

        budget_reached = already_charged + pending_charge >= budget
        return ReasonStopped.BUDGET_REACHED if budget_reached else None


_campaign_charged_and_budget = itemgetter("charged", "budget")
_cost_per_event_cost_last_event_and_count = itemgetter(
    "paid_event_cost", "last_paid_event_cost", "paid_events_to_charge"
)
