from datetime import datetime
from decimal import Decimal
from functools import partial
from itertools import chain, groupby
from operator import itemgetter
from typing import List, Optional

import pytz
from aioch import Client

from maps_adv.adv_store.client import Client as AdvStoreClient
from maps_adv.billing_proxy.client import Client as BillingProxyClient
from maps_adv.statistics.beekeeper.lib.packet_size_calculator import (
    PacketSizeCalculator,
)
from maps_adv.statistics.beekeeper.lib.steps import sqls
from maps_adv.statistics.beekeeper.lib.steps.base import (
    PAID_EVENTS,
    BaseStep,
    FreeEventProcessingMode,
)
from maps_adv.warden.client.lib import TaskContext

from . import schemas


class NoPacket(Exception):
    """Raised by ``ContextCollector.run`` when the packet size calculator
    returns ``None``, i.e. there is no packet to process."""


class UnsupportedPrecision(Exception):
    """Raised when a campaign's discounted paid event cost, written as an
    exact fraction, has a denominator larger than 10**6 (finer than one
    millionth)."""


class ContextCollector(BaseStep):
    """Collect the billing context for the next statistics packet.

    For the packet bounds returned by ``packet_size_calculator`` the step
    gathers:

    * active CPM and CPA campaigns from adv_store;
    * balances and discounts of their orders from billing_proxy;
    * already charged amounts and paid event counts from ClickHouse.

    The result is a dict with ``packet_start``, ``packet_end`` and ``orders``.
    Every order entry has ``order_id``, ``balance`` and a list of
    ``campaigns`` built by ``_transform_campaign`` and extended with
    ``daily_charged``, ``charged`` and ``paid_events_count``.
    """

    input_schema = schemas.EmptyState
    output_schema = schemas.ContextState

    def __init__(
        self,
        packet_size_calculator: PacketSizeCalculator,
        adv_store_client: AdvStoreClient,
        billing_proxy_client: BillingProxyClient,
        ch_client_params: dict,
        campaigns_for_processing: Optional[List[int]] = None,  # GEOPROD-4108
        ch_query_id: str = "",
    ):
        """
        :param packet_size_calculator: called and awaited to get a dict with
            ``packet_start``/``packet_end``, or ``None`` if there is no packet.
        :param adv_store_client: source of active CPM/CPA campaigns.
        :param billing_proxy_client: source of order balances and discounts.
        :param ch_client_params: keyword arguments for ``aioch.Client``.
        :param campaigns_for_processing: when given, only campaigns with these
            ids are processed (temporary filter, GEOPROD-4108).
        :param ch_query_id: ClickHouse ``query_id`` used for the stats query.
        """
        self._packet_size_calculator = packet_size_calculator
        self._adv_store_client = adv_store_client
        self._billing_proxy_client = billing_proxy_client
        self._ch_client_params = ch_client_params
        self._campaigns_for_processing = campaigns_for_processing  # GEOPROD-4108
        self._ch_query_id = ch_query_id

    async def run(
        self, data: Optional[dict] = None, _: Optional[TaskContext] = None
    ) -> dict:
        """Fill ``data`` (or a new dict) with the context of the next packet.

        :raises NoPacket: if the packet size calculator returns ``None``.
        """
        data = data or {}

        # Get packet bounds
        packet_bounds = await self._packet_size_calculator()
        if packet_bounds is None:
            raise NoPacket

        packet_start, packet_end = itemgetter("packet_start", "packet_end")(
            packet_bounds
        )

        await self._fill_data_from_clients(data, packet_start, packet_end)
        await self._fill_data_from_stats_history(data, packet_start, packet_end)

        return data

    async def _fill_data_from_clients(
        self, data: dict, packet_start: datetime, packet_end: datetime
    ):
        """Store packet bounds and build ``data["orders"]`` from adv_store
        campaigns and billing_proxy order balances/discounts.

        Campaigns with an order are grouped per ``order_id`` (sorted by order
        and campaign id). Campaigns without an order go into one trailing
        entry with ``order_id=None`` and an infinite balance.
        """
        data["packet_start"] = packet_start
        data["packet_end"] = packet_end

        campaigns_data = await self._fetch_active_campaigns(packet_end)

        # Get orders and discount data from billing_proxy
        order_ids = {
            campaign_data["order_id"]
            for campaign_data in campaigns_data
            if campaign_data["order_id"] is not None
        }
        async with self._billing_proxy_client as client:
            orders_data = await client.fetch_orders_balance(*order_ids)
            discount_data = await client.fetch_orders_discounts(packet_end, *order_ids)

        data["orders"] = []

        # groupby only merges adjacent items, hence the sort by the same
        # leading key (order_id) first.
        campaigns_with_orders_iter = groupby(
            sorted(
                filter(lambda c: c["order_id"] is not None, campaigns_data),
                key=_order_id_and_campaign_id_getter,
            ),
            key=_order_id_getter,
        )

        for order_id, campaigns in campaigns_with_orders_iter:
            # Orders without a known discount are charged at full cost.
            discount = discount_data.get(order_id, Decimal("1.0"))
            data["orders"].append(
                {
                    "order_id": order_id,
                    "balance": orders_data[order_id],
                    "campaigns": list(
                        map(partial(self._transform_campaign, discount), campaigns)
                    ),
                }
            )

        campaigns_without_orders_info = sorted(
            filter(lambda c: c["order_id"] is None, campaigns_data),
            key=_campaign_id_getter,
        )

        if campaigns_without_orders_info:
            data["orders"].append(
                {
                    "order_id": None,
                    "balance": Decimal("Infinity"),
                    "campaigns": list(
                        map(
                            # Decimal from a string literal, consistent with the
                            # default discount above (not from a float).
                            partial(self._transform_campaign, Decimal("1.0")),
                            campaigns_without_orders_info,
                        )
                    ),
                }
            )

    async def _fetch_active_campaigns(self, packet_end: datetime) -> List[dict]:
        """Load active CPM and CPA campaigns from adv_store as of
        ``packet_end`` and annotate each with its billing parameters.

        CPM campaigns get a per-event cost of ``cost / 1000`` and the
        ``PAID_EVENTS`` names; CPA campaigns keep ``cost`` as is. The
        GEOPROD-4108 campaign filter is applied to the combined list.
        """
        async with self._adv_store_client as client:
            campaigns_data = await client.list_active_cpm_campaigns(
                on_datetime=packet_end
            )
            for campaign in campaigns_data:
                campaign.update(
                    {
                        "billing_type": "cpm",
                        "paid_event_cost": campaign["cost"] / Decimal("1000"),
                        "paid_events_names": PAID_EVENTS,
                        "free_event_processing_mode": FreeEventProcessingMode.ONLY_IF_PAID_PRESENT,  # noqa: E501
                    }
                )

            campaigns_data_cpa = await client.list_active_cpa_campaigns(
                on_datetime=packet_end
            )
            for campaign in campaigns_data_cpa:
                campaign.update(
                    {
                        "billing_type": "cpa",
                        "paid_event_cost": campaign["cost"],
                        "free_event_processing_mode": FreeEventProcessingMode.ALL_EVENTS,  # noqa: E501
                    }
                )

        campaigns_data.extend(campaigns_data_cpa)

        # TODO(megadiablo) Remove the filtering, GEOPROD-4108
        if self._campaigns_for_processing is not None:
            # Build the lookup set once instead of scanning a list per campaign.
            allowed_ids = set(self._campaigns_for_processing)
            campaigns_data = [
                item for item in campaigns_data if item["campaign_id"] in allowed_ids
            ]

        return campaigns_data

    async def _fill_data_from_stats_history(
        self, data: dict, packet_start: datetime, packet_end: datetime
    ):
        """Attach ClickHouse charge statistics to every campaign in
        ``data["orders"]``.

        Campaigns are grouped by ``(tz_name, paid_events_names)``. One
        ``sqls.get_campaign_charge_stats`` sub-query is rendered per group,
        and the sub-queries are combined with ``UNION ALL``. Result rows are
        ``(campaign_id, daily_charged, charged, events_count)``. Campaigns
        missing from the result get zero statistics.
        """
        all_campaigns = chain.from_iterable(
            map(itemgetter("campaigns"), data["orders"])
        )
        timezone_campaigns = groupby(
            sorted(all_campaigns, key=_group_key_for_ch), key=_group_key_for_ch
        )

        sql_parts = []
        for (tz_name, paid_events_names), campaigns in timezone_campaigns:
            local_day_start, local_day_end = self._local_day_bounds(
                packet_start, pytz.timezone(tz_name)
            )
            campaigns_ids = ",".join(map(str, map(_campaign_id_getter, campaigns)))
            # NOTE(review): event names are inlined as quoted SQL literals
            # without escaping; they must not contain quote characters.
            event_names = ",".join(map(lambda ev: f"'{ev}'", paid_events_names))

            sql_parts.append(
                sqls.get_campaign_charge_stats.format(
                    campaigns_ids=campaigns_ids,
                    event_names=event_names,
                    packet_start=int(packet_start.timestamp()),
                    packet_end=int(packet_end.timestamp()),
                    local_day_start=int(local_day_start.timestamp()),
                    local_day_end=int(local_day_end.timestamp()),
                )
            )

        campaigns_stats = []
        if sql_parts:
            # Only open a ClickHouse client when there is something to query,
            # and always release its connection afterwards.
            ch_client = Client(**self._ch_client_params)
            try:
                campaigns_stats = await ch_client.execute(
                    " UNION ALL ".join(sql_parts),
                    query_id=self._ch_query_id,
                    settings={"replace_running_query": 0},
                )
            finally:
                await ch_client.disconnect()

        campaigns_stats_by_id = {
            campaign_id: (daily_charged, charged, events_count)
            for campaign_id, daily_charged, charged, events_count in campaigns_stats
        }
        no_stats = (Decimal(0), Decimal(0), 0)

        for order_data in data["orders"]:
            for campaign_data in order_data["campaigns"]:
                daily_charged, charged, events_count = campaigns_stats_by_id.get(
                    campaign_data["campaign_id"], no_stats
                )
                campaign_data.update(
                    {
                        "daily_charged": daily_charged,
                        "charged": charged,
                        "paid_events_count": events_count,
                    }
                )

    @staticmethod
    def _local_day_bounds(packet_start: datetime, timezone) -> tuple:
        """Return ``(day_start, day_end)``: 00:00:00 and 23:59:59 of the local
        calendar day (in pytz ``timezone``) that contains ``packet_start``.

        Each bound is localized on its own. ``replace(hour=...)`` on an aware
        pytz datetime would keep the UTC offset in effect at ``packet_start``,
        which is wrong by the DST shift on transition days.
        """
        local_date = packet_start.astimezone(timezone).date()
        day_start = timezone.localize(
            datetime(local_date.year, local_date.month, local_date.day)
        )
        day_end = timezone.localize(
            datetime(local_date.year, local_date.month, local_date.day, 23, 59, 59)
        )
        return day_start, day_end

    @classmethod
    def _transform_campaign(cls, discount: Decimal, campaign: dict) -> dict:
        """Build the per-campaign charging record.

        :param discount: multiplier applied to the campaign's paid event cost.
        :param campaign: adv_store campaign dict already annotated with
            ``billing_type``, ``paid_event_cost``, ``paid_events_names`` and
            ``free_event_processing_mode``.
        :return: dict with the normalized discounted cost. A missing
            ``budget``/``daily_budget`` becomes ``Decimal("Infinity")``.
        :raises UnsupportedPrecision: if the discounted cost, as an exact
            fraction, has a denominator above 10**6.
        """
        campaign_data = {
            "campaign_id": campaign["campaign_id"],
            "billing_type": campaign["billing_type"],
            "tz_name": campaign["timezone"],
            "paid_event_cost": (campaign["paid_event_cost"] * discount).normalize(),
            "paid_events_names": campaign["paid_events_names"],
            "free_event_processing_mode": campaign["free_event_processing_mode"],
            "budget": campaign["budget"],
            "daily_budget": campaign["daily_budget"],
        }

        if campaign_data["budget"] is None:
            campaign_data["budget"] = Decimal("Infinity")
        if campaign_data["daily_budget"] is None:
            campaign_data["daily_budget"] = Decimal("Infinity")
        if campaign_data["paid_event_cost"].as_integer_ratio()[1] > 10 ** 6:
            raise UnsupportedPrecision

        return campaign_data


_order_id_getter = itemgetter("order_id")
_campaign_id_getter = itemgetter("campaign_id")
_order_id_and_campaign_id_getter = itemgetter("order_id", "campaign_id")
_paid_event_names_getter = itemgetter("paid_events_names")
_group_key_for_ch = itemgetter("tz_name", "paid_events_names")
