import json
import logging
from datetime import timedelta
from itertools import chain, groupby
from operator import itemgetter
from typing import Optional

from aioch import Client

from maps_adv.statistics.beekeeper.lib.steps import sqls
from maps_adv.statistics.beekeeper.lib.steps.base import (
    BaseStep,
    FreeEventProcessingMode,
)
from maps_adv.warden.client.lib import TaskContext

from . import schemas


class FreeEventsPorter(BaseStep):
    """Step that ports free events of a packet into the processed storage.

    For the packet described by ``data`` the step first runs a ClickHouse
    query built from ``sqls.free_events_for_campaigns_exist_in_processed``;
    if that query returns any rows the packet is treated as already saved
    (an error is logged and nothing is inserted). Otherwise one ``SELECT`` is
    built per group of campaigns and their ``UNION ALL`` is inserted through
    ``sqls.insert_to_processed_from_subq``.

    NOTE(review): the exact semantics of the SQL templates live in ``sqls``
    and are not visible from this module.
    """

    input_schema = schemas.AdvStoreNotificationState
    output_schema = schemas.FreeEventsPorterState

    def __init__(
        self,
        ch_client_params: dict,
        ch_max_memory_usage: int,
        event_group_id_time_threshold: timedelta,
        ch_query_id: str = "",
    ):
        """Store the ClickHouse connection and query settings.

        Args:
            ch_client_params: Keyword arguments forwarded to ``aioch.Client``.
            ch_max_memory_usage: Value of the ``max_memory_usage`` client
                setting.
            event_group_id_time_threshold: Subtracted from the packet start to
                get the lower bound passed to the SQL template as
                ``event_group_id_packet_start``.
            ch_query_id: ``query_id`` passed to every executed query.
        """
        self._ch_client_params = ch_client_params
        self._ch_max_memory_usage = ch_max_memory_usage
        self._event_group_id_time_threshold = event_group_id_time_threshold
        self._ch_query_id = ch_query_id

    @staticmethod
    def _join_campaign_ids(campaigns) -> str:
        """Render the ids of ``campaigns`` as a comma-separated string."""
        return ",".join(str(campaign["campaign_id"]) for campaign in campaigns)

    @staticmethod
    def _quote_event_names(event_names) -> str:
        """Render event names as comma-separated single-quoted literals."""
        return ",".join(f"'{name}'" for name in event_names)

    async def run(
        self, data: Optional[dict] = None, context: Optional[TaskContext] = None
    ):
        """Insert free events of the packet unless they are already saved.

        Args:
            data: Step state. The keys read here are ``packet_start`` and
                ``packet_end`` (objects providing ``timestamp()``) and
                ``orders``; every order provides ``campaigns``, and every
                campaign provides ``campaign_id``, ``paid_events_names``,
                ``tz_name`` and ``free_event_processing_mode``.
            context: Optional task context; when given, its
                ``client.executor_id`` is written to the processing metadata.

        Returns:
            ``data`` itself, unchanged.
        """
        packet_start_ts = int(data["packet_start"].timestamp())
        packet_end_ts = int(data["packet_end"].timestamp())

        if not data["orders"]:
            return data

        campaigns = list(chain.from_iterable(map(_campaigns_getter, data["orders"])))
        if not campaigns:
            # Without campaigns both queries would be empty strings, which is
            # not valid SQL; there is nothing to port anyway.
            return data

        # We assume "paid_events_names" is always sorted (it is likely to be,
        # because these lists are pregenerated based on the campaign billing
        # type). Even if this is wrong, we only create a bigger SQL than we
        # could.
        existence_parts = [
            sqls.free_events_for_campaigns_exist_in_processed.format(
                campaign_ids=self._join_campaign_ids(group),
                event_names=self._quote_event_names(paid_events_names),
                packet_start=packet_start_ts,
                packet_end=packet_end_ts,
            )
            for paid_events_names, group in groupby(
                sorted(campaigns, key=_paid_events_names_getter),
                key=_paid_events_names_getter,
            )
        ]

        ch_client = Client(
            settings={"max_memory_usage": self._ch_max_memory_usage},
            **self._ch_client_params,
        )
        try:
            events_exist = await ch_client.execute(
                " UNION ALL ".join(existence_parts),
                query_id=self._ch_query_id,
                settings={"replace_running_query": 0},
            )
            if events_exist:
                logging.getLogger("beekeeper.free_events_porter").error(
                    "Ignored duplicate free events saving for packet (%d, %d)",
                    packet_start_ts,
                    packet_end_ts,
                )
                return data

            event_group_id_packet_start_ts = int(
                (
                    data["packet_start"] - self._event_group_id_time_threshold
                ).timestamp()
            )
            event_group_id_packet_end_ts = packet_end_ts

            select_parts = []
            for (paid_events_names, tz_name), group in groupby(
                sorted(campaigns, key=_paid_events_and_tz_name_getter),
                key=_paid_events_and_tz_name_getter,
            ):
                # The sort key does not include the processing mode, so
                # campaigns with different modes may be interleaved inside
                # one sorted group. Partition them explicitly to get exactly
                # one SELECT per (events, timezone, mode) combination.
                campaigns_by_mode = {}
                for campaign in group:
                    campaigns_by_mode.setdefault(
                        campaign["free_event_processing_mode"], []
                    ).append(campaign)

                event_names = self._quote_event_names(paid_events_names)
                for (
                    free_event_processing_mode,
                    mode_campaigns,
                ) in campaigns_by_mode.items():
                    campaigns_ids = self._join_campaign_ids(mode_campaigns)
                    if (
                        free_event_processing_mode
                        == FreeEventProcessingMode.ONLY_IF_PAID_PRESENT
                    ):
                        select_parts.append(
                            sqls.select_free_from_normalized_for_processed_only_if_paid_present.format(  # noqa: E501
                                campaigns_ids=campaigns_ids,
                                event_names=event_names,
                                packet_start=packet_start_ts,
                                packet_end=packet_end_ts,
                                timezone=tz_name,
                                event_group_id_packet_start=event_group_id_packet_start_ts,
                                event_group_id_packet_end=event_group_id_packet_end_ts,
                            )
                        )
                    elif (
                        free_event_processing_mode
                        == FreeEventProcessingMode.ALL_EVENTS
                    ):
                        select_parts.append(
                            sqls.select_free_from_normalized_for_processed_all_events.format(
                                campaigns_ids=campaigns_ids,
                                event_names=event_names,
                                packet_start=packet_start_ts,
                                packet_end=packet_end_ts,
                                timezone=tz_name,
                            )
                        )
                    else:
                        logging.getLogger(__name__).error(
                            f"Unknown free event processing mode {free_event_processing_mode}"
                        )

            if not select_parts:
                # Every campaign had an unknown mode (already logged above);
                # an INSERT with an empty sub-query would be invalid SQL.
                return data

            processing_metadata = {}
            if context:
                processing_metadata["warden_executor_id"] = context.client.executor_id
            insert_query = sqls.insert_to_processed_from_subq.format(
                processing_metadata=json.dumps(processing_metadata),
                subq=" UNION ALL ".join(select_parts),
            )
            await ch_client.execute(
                insert_query,
                query_id=self._ch_query_id,
                settings={"replace_running_query": 0},
            )
        finally:
            # Release the connection on every exit path, including errors.
            await ch_client.disconnect()

        return data


# Key extractors used when sorting and grouping campaign mappings.
# Multi-field getters return a tuple of the requested values, in order.
_paid_events_tz_name_free_event_processing_mode_getter = itemgetter(
    "paid_events_names",
    "tz_name",
    "free_event_processing_mode",
)
_paid_events_and_tz_name_getter = itemgetter(
    "paid_events_names",
    "tz_name",
)
_paid_events_names_getter = itemgetter("paid_events_names")

# Extracts the "campaigns" value of an order mapping.
_campaigns_getter = itemgetter("campaigns")
