import json
import logging
from operator import itemgetter
from typing import Optional

from aioch import Client

from maps_adv.statistics.beekeeper.lib.steps import sqls
from maps_adv.statistics.beekeeper.lib.steps.base import BaseStep
from maps_adv.warden.client.lib import TaskContext

from . import schemas


class PaidEventsPorter(BaseStep):
    """Port paid events of billed orders into the processed-events table.

    For every eligible order in the input state, builds SELECT subqueries over
    normalized events (``sqls.select_paid_from_normalized_for_processed``) for
    each campaign's charged events and inserts their UNION ALL via
    ``sqls.insert_to_processed_from_subq``. Orders for which
    ``sqls.events_in_processed_exist`` returns a non-empty result are skipped
    and logged as ignored duplicates.
    """

    input_schema = schemas.BillingNotificationState
    output_schema = schemas.PaidEventsPorterState

    def __init__(
        self,
        ch_client_params: dict,
        ch_max_memory_usage: int,
        ch_query_id: str = "",
    ):
        """
        :param ch_client_params: keyword arguments forwarded to ``aioch.Client``.
        :param ch_max_memory_usage: value of the ClickHouse ``max_memory_usage``
            client setting.
        :param ch_query_id: ``query_id`` passed to every ``execute`` call.
        """
        self._ch_client_params = ch_client_params
        self._ch_max_memory_usage = ch_max_memory_usage
        self._ch_query_id = ch_query_id

    async def run(
        self, data: Optional[dict] = None, context: Optional[TaskContext] = None
    ):
        """Insert paid events for every eligible order in ``data``.

        :param data: input state; the code reads ``packet_start`` and
            ``packet_end`` (objects with ``.timestamp()``) and ``orders``.
            Despite the ``Optional`` hint it is required.
        :param context: optional warden task context; when truthy, its
            ``client.executor_id`` is stored in the processing metadata of the
            inserted rows.
        :return: ``data``, unchanged.
        """
        logger = logging.getLogger("beekeeper.paid_events_porter")

        packet_start_ts = int(data["packet_start"].timestamp())
        packet_end_ts = int(data["packet_end"].timestamp())

        # NOTE(review): this client is never explicitly disconnected.
        ch_client = Client(
            settings={"max_memory_usage": self._ch_max_memory_usage},
            **self._ch_client_params,
        )

        processing_metadata = {}
        if context:
            processing_metadata["warden_executor_id"] = context.client.executor_id

        for order in data["orders"]:
            # Failed billing only disqualifies orders that have an id; orders
            # with order_id None are ported regardless of billing_success.
            if not order["billing_success"] and order["order_id"] is not None:
                continue

            campaigns_ids = list(map(itemgetter("campaign_id"), order["campaigns"]))
            if not campaigns_ids:
                continue
            campaigns_ids_str = ",".join(map(str, campaigns_ids))

            events_exist = await ch_client.execute(
                sqls.events_in_processed_exist.format(
                    packet_start=packet_start_ts,
                    packet_end=packet_end_ts,
                    campaigns_ids=campaigns_ids_str,
                ),
                query_id=self._ch_query_id,
                settings={"replace_running_query": 0},
            )
            if events_exist:
                logger.info(
                    "Campaigns for processing in batch (%s)", campaigns_ids_str
                )
                # order_id may be None here (see the billing check above), so it
                # is formatted with %s: %d would fail and drop this record.
                logger.error(
                    "Ignored duplicate paid events saving for packet (%d, %d, %s)",
                    packet_start_ts,
                    packet_end_ts,
                    order["order_id"],
                )
                continue

            query_parts = []
            for campaign in order["campaigns"]:
                query_parts.extend(
                    self._build_campaign_subqueries(
                        campaign, packet_start_ts, packet_end_ts
                    )
                )

            # One INSERT per order, combining all of its campaigns' selections.
            if query_parts:
                query = sqls.insert_to_processed_from_subq.format(
                    processing_metadata=json.dumps(processing_metadata),
                    subq=" UNION ALL ".join(query_parts),
                )
                await ch_client.execute(
                    query,
                    query_id=self._ch_query_id,
                    settings={"replace_running_query": 0},
                )

        return data

    @staticmethod
    def _build_campaign_subqueries(
        campaign: dict, packet_start_ts: int, packet_end_ts: int
    ) -> list:
        """Return the SELECT subqueries (zero, one or two) for one campaign.

        ``paid_events_to_charge`` events are selected in total. When
        ``last_paid_event_cost`` differs from ``paid_event_cost``, the last
        event is selected separately (count 1, offset right after the regular
        events) at ``last_paid_event_cost``; the rest use ``paid_event_cost``.
        Returns an empty list when ``paid_events_to_charge`` is 0.
        """
        paid_events_to_charge = campaign["paid_events_to_charge"]
        if paid_events_to_charge == 0:
            return []

        def subquery(cost, events_count: int, events_offset: int) -> str:
            # NOTE(review): event names are interpolated into SQL unescaped;
            # fine only while they come from trusted configuration.
            event_names = ",".join(
                f"'{name}'" for name in campaign["paid_events_names"]
            )
            return sqls.select_paid_from_normalized_for_processed.format(
                packet_start=packet_start_ts,
                packet_end=packet_end_ts,
                campaign_id=campaign["campaign_id"],
                event_names=event_names,
                cost=cost,
                timezone=campaign["tz_name"],
                events_count=events_count,
                events_offset=events_offset,
            )

        custom_last_event = (
            campaign["last_paid_event_cost"] != campaign["paid_event_cost"]
        )
        regular_events_count = (
            paid_events_to_charge - 1 if custom_last_event else paid_events_to_charge
        )

        subqueries = []
        # Regular (full-cost) campaign events
        if regular_events_count > 0:
            subqueries.append(
                subquery(campaign["paid_event_cost"], regular_events_count, 0)
            )
        # Last (partial-cost) campaign event
        if custom_last_event:
            subqueries.append(
                subquery(campaign["last_paid_event_cost"], 1, regular_events_count)
            )
        return subqueries
