import calendar
import inspect
import logging
from collections import defaultdict
from datetime import datetime, timedelta, timezone

import pytz

from maps_adv.adv_store.api.schemas.enums import FixTimeIntervalEnum


class PaidTillProcessorTask:
    def __init__(self, adv_store_client, billing_proxy_client):
        self._adv_store_client = adv_store_client
        self._billing_proxy_client = billing_proxy_client

    async def __call__(self, context):
        now = datetime.now(timezone.utc)
        async with self._adv_store_client as client:
            fix_campaigns = await client.list_active_fix_campaigns(on_datetime=now)

        orders_campaigns = defaultdict(lambda: defaultdict(list))

        unaccounted_orders_ids = set()

        for campaign in fix_campaigns:
            if campaign["paid_till"] is None or campaign["paid_till"] < now:
                orders_campaigns[campaign["order_id"]]["unaccounted"].append(campaign)
                unaccounted_orders_ids.add(campaign["order_id"])
                logging.getLogger("beekeeper.paid_till_processor").info(
                    f"Unaccounted campaign {campaign['campaign_id']}, "
                    "order {campaign['order_id']}, paid till {campaign['paid_till']}"
                )
            else:
                orders_campaigns[campaign["order_id"]]["accounted"].append(campaign)

        if not unaccounted_orders_ids:
            return

        earliest_debit = now - timedelta(days=31)

        async with self._billing_proxy_client as client:
            unaccounted_orders_debits = await client.fetch_orders_debits(
                unaccounted_orders_ids, paid_from=earliest_debit
            )

        (
            unpaid_campaigns,
            unaccounted_paid_campaigns,
        ) = self._process_uncaccounted_orders(
            orders_campaigns, unaccounted_orders_debits, now
        )

        async with self._billing_proxy_client as client:
            # Charge each campaign individually in case there's not enough balance to pay for all at once
            for campaign in unpaid_campaigns:
                charges = {campaign["order_id"]: campaign["cost"]}
                applied, submit_result = await client.submit_orders_charges(
                    charges=charges, bill_due_to=now
                )
                if applied and submit_result.get(campaign["order_id"]):
                    unaccounted_paid_campaigns.append((campaign, now))

        async with self._adv_store_client as client:
            for campaign, paid_at in unaccounted_paid_campaigns:
                paid_till = self._compute_paid_till(campaign, paid_at)
                client.update_paid_till(campaign["campaign_id"], paid_till)

    @staticmethod
    def _compute_paid_till(campaign, paid_at):
        campaign_timezone = pytz.timezone(campaign["timezone"]) or pytz.utc
        paid_at_tz = paid_at.astimezone(campaign_timezone)

        match campaign["time_interval"]:
            case FixTimeIntervalEnum.DAILY:
                paid_till_tz = datetime(
                    year=paid_at_tz.year,
                    month=paid_at_tz.month,
                    day=paid_at_tz.day,
                    tzinfo=campaign_timezone,
                ) + timedelta(days=1)
            case FixTimeIntervalEnum.WEEKLY:
                paid_till_tz = datetime(
                    year=paid_at_tz.year,
                    month=paid_at_tz.month,
                    day=paid_at_tz.day,
                    tzinfo=campaign_timezone,
                ) + timedelta(days=7)
            case FixTimeIntervalEnum.MONTHLY:
                paid_till_tz = datetime(
                    year=paid_at_tz.year,
                    month=paid_at_tz.month,
                    day=calendar.monthrange(paid_at_tz.year, paid_at_tz.month)[1],
                    tzinfo=campaign_timezone,
                ) + timedelta(days=1)
            case _:
                return None

        return paid_till_tz.astimezone(pytz.utc)

    @staticmethod
    def _process_uncaccounted_orders(
        orders_campaigns: dict[int, dict], debit_data: dict[int, list], now: datetime
    ) -> tuple[list, list]:
        unpaid_campaigns = []
        paid_campaigns = []
        for order_id, order_campaigns in orders_campaigns.items():
            order_debits = debit_data.get(order_id)
            if order_debits is None:
                unpaid_campaigns.extend(order_campaigns["unaccounted"])
                continue

            for accounted_campaign in order_campaigns["accounted"]:
                PaidTillProcessorTask._pop_earliest_campaign_debit_dt(
                    accounted_campaign, order_debits, now
                )
            for unaccounted_campaign in order_campaigns["unaccounted"]:
                paid_at = PaidTillProcessorTask._pop_earliest_campaign_debit_dt(
                    unaccounted_campaign, order_debits, now
                )
                if paid_at:
                    paid_campaigns.append((unaccounted_campaign, paid_at))
                else:
                    unpaid_campaigns.append(unaccounted_campaign)

        return (unpaid_campaigns, paid_campaigns)

    @staticmethod
    def _pop_earliest_campaign_debit_dt(
        campaign: dict, debits: list[dict], now: datetime
    ) -> datetime | None:
        first_matching_debit = next(
            (
                debit
                for debit in debits
                if PaidTillProcessorTask._campaign_matches_debit(campaign, debit, now)
            ),
            None,
        )
        if first_matching_debit:
            debits.remove(first_matching_debit)
            return first_matching_debit["billed_at"]
        return None

    @staticmethod
    def _campaign_matches_debit(campaign: dict, debit: dict, now: datetime) -> bool:
        if campaign["cost"] != debit["amount"]:
            return False

        campaign_timezone = pytz.timezone(campaign["timezone"]) or pytz.utc
        now_tz = now.astimezone(campaign_timezone)
        match campaign["time_interval"]:
            case FixTimeIntervalEnum.DAILY:
                paid_after_tz = datetime(
                    year=now_tz.year,
                    month=now_tz.month,
                    day=now_tz.day,
                    tzinfo=campaign_timezone,
                )
            case FixTimeIntervalEnum.WEEKLY:
                paid_after_tz = datetime(
                    year=now_tz.year,
                    month=now_tz.month,
                    day=now_tz.day,
                    tzinfo=campaign_timezone,
                ) - timedelta(days=6)
            case FixTimeIntervalEnum.MONTHLY:
                paid_after_tz = datetime(
                    year=now_tz.year,
                    month=now_tz.month,
                    day=1,
                    tzinfo=campaign_timezone,
                )
            case _:
                return False

        return debit["billed_at"] >= paid_after_tz.astimezone(pytz.utc)
