from datetime import datetime, timedelta, timezone
from typing import Optional

from aioch import Client


class NoNewNormalizedEventsFound(Exception):
    """Signal that no unprocessed normalized events are available.

    Raised by ``PacketSizeCalculator.__call__`` when the lookup of the
    earliest unprocessed ``receive_timestamp`` yields NULL.
    """


class PacketSizeCalculator:
    """Compute the next time window ("packet") of normalized events to process.

    A packet starts at the earliest ``receive_timestamp`` in
    ``normalized_events_distributed`` that is newer than the latest
    ``receive_timestamp`` in ``processed_events_distributed``. Its end is
    bounded by ``now - time_lag``, by the end of the start's clock hour, and
    by ``max_packet_size``. Extra keyword arguments given to the constructor
    are passed verbatim to ``aioch.Client`` for every query.
    """

    __slots__ = "time_lag", "min_packet_size", "max_packet_size", "ch_client_params"

    time_lag: timedelta
    min_packet_size: timedelta
    max_packet_size: timedelta
    ch_client_params: dict

    def __init__(
        self,
        time_lag: timedelta,
        min_packet_size: timedelta,
        max_packet_size: timedelta,
        **ch_client_params,
    ):
        """Store the packet sizing rules and the ClickHouse client parameters.

        :param time_lag: how far behind "now" a packet must end.
        :param min_packet_size: shortest window worth returning.
        :param max_packet_size: longest window allowed.
        :param ch_client_params: keyword arguments forwarded to ``aioch.Client``.
        """
        self.time_lag = time_lag
        self.min_packet_size = min_packet_size
        self.max_packet_size = max_packet_size
        self.ch_client_params = ch_client_params

    async def __call__(self, now: Optional[datetime] = None) -> Optional[dict]:
        """Return the next packet, or ``None`` if it is not ready yet.

        :param now: current time, timezone-aware; defaults to
            ``datetime.now(timezone.utc)``. Exposed mainly for testing.
        :returns: ``{"packet_start": ..., "packet_end": ...}``, where
            ``packet_end`` is the latest ``receive_timestamp`` not after the
            computed window end; ``None`` when the window would be shorter
            than ``min_packet_size`` or would leave too little of the current
            hour for a following packet.
        :raises NoNewNormalizedEventsFound: if there are no unprocessed events.
        """
        min_unprocessed_timing = await self.retrieve_min_unprocessed_timing()

        if min_unprocessed_timing is None:
            raise NoNewNormalizedEventsFound()

        if now is None:
            now = datetime.now(timezone.utc)
        packet_end = now - self.time_lag
        if min_unprocessed_timing.tzinfo is not None:
            # The hour logic below works on wall-clock fields (.hour and
            # replace(minute=..., second=...)), which only line up when both
            # datetimes share a timezone. Express packet_end in the events'
            # own timezone; this does not change the instant it denotes.
            packet_end = packet_end.astimezone(min_unprocessed_timing.tzinfo)

        if min_unprocessed_timing > packet_end - self.min_packet_size:
            return None

        if self._are_times_inside_same_hour(min_unprocessed_timing, packet_end):
            # Leave enough space for next time interval inside current hour
            duration_to_next_hour = (
                packet_end.replace(minute=59, second=59) - packet_end
            )
            if duration_to_next_hour < self.min_packet_size:
                return None
        else:
            # Time interval can't violate hour boundaries; the min() guard keeps
            # the packet from ever reaching into the time_lag window.
            packet_end = min(
                packet_end,
                min_unprocessed_timing.replace(minute=59, second=59),
            )

        if packet_end - min_unprocessed_timing > self.max_packet_size:
            packet_end = (
                min_unprocessed_timing + self.max_packet_size - timedelta(seconds=1)
            )

        max_unprocessed_timing = await self.retrieve_max_unprocessed_timing(packet_end)
        return {
            "packet_start": min_unprocessed_timing,
            "packet_end": max_unprocessed_timing,
        }

    @staticmethod
    def _are_times_inside_same_hour(dt_1: datetime, dt_2: datetime) -> bool:
        """Return True if both datetimes fall within the same clock hour.

        When both are timezone-aware, ``dt_2`` is first converted to
        ``dt_1``'s timezone so that their ``hour`` fields are comparable.
        """
        if dt_1.tzinfo is not None and dt_2.tzinfo is not None:
            dt_2 = dt_2.astimezone(dt_1.tzinfo)
        distance = abs(dt_1 - dt_2)
        return distance < timedelta(hours=1) and dt_1.hour == dt_2.hour

    async def retrieve_min_unprocessed_timing(self) -> Optional[datetime]:
        """Return the earliest ``receive_timestamp`` not yet processed.

        Returns ``None`` (SQL NULL from the ``CASE``) when no normalized
        event is newer than the latest processed one.
        """
        sql = """
            SELECT
                CASE
                    WHEN any(receive_timestamp) > 0
                    THEN min(receive_timestamp)
                    ELSE NULL
                END
            FROM normalized_events_distributed
            WHERE receive_timestamp > (
                SELECT max(receive_timestamp)
                FROM processed_events_distributed
            )
            SETTINGS distributed_product_mode='global'
        """
        rows = await self._execute(sql)
        return rows[0][0]

    async def retrieve_max_unprocessed_timing(self, up_limit: datetime) -> datetime:
        """Return the latest ``receive_timestamp`` at or before ``up_limit``.

        ``up_limit`` is sent as integer Unix seconds, so any sub-second part
        is truncated.
        """
        sql = """
            SELECT max(receive_timestamp)
            FROM normalized_events_distributed
            WHERE receive_timestamp <= %(up_limit)s
        """
        rows = await self._execute(sql, {"up_limit": int(up_limit.timestamp())})
        return rows[0][0]

    async def _execute(self, sql: str, params: Optional[dict] = None) -> list:
        """Run ``sql`` on a fresh ClickHouse client, always disconnecting it."""
        client = Client(**self.ch_client_params)
        try:
            return await client.execute(sql, params)
        finally:
            await client.disconnect()
