import json
from datetime import timedelta
from typing import Optional, Tuple

from aioch import Client

from maps_adv.warden.client.lib import ClientWithContextManager

from . import sqls
from .exceptions import NoNewData


class NormalizerTask:
    """Normalize one packet of unprocessed app-metric/mapkit events in ClickHouse.

    Each call:

    1. asks ClickHouse for the start/end of the unprocessed packet,
    2. trims the end by ``lag`` and clamps the length to ``max_packet_size``,
    3. reports the chosen bounds to warden, and
    4. runs the normalization query built for ``normalized_events_table_name``.

    Packet bounds are combined with second-resolution integers. The timestamps
    are therefore presumably Unix seconds (TODO confirm against ``sqls``).
    """

    # Value used for any app-filter key the caller does not supply.
    MAX_BUILD_NUMBER = 2 ** 32

    # Every key is always present in ``self._app_filter``. The keys are also
    # passed as named query parameters of the normalization SQL. The order is
    # kept stable because the filter is serialized into the metadata JSON.
    _APP_FILTER_KEYS = (
        "ios_navi_build",
        "ios_maps_build",
        "ios_metro_build",
        "android_navi_build",
        "android_maps_build",
        "android_metro_build",
    )

    def __init__(
        self,
        min_packet_size: timedelta,
        max_packet_size: timedelta,
        lag: timedelta,
        deduplication_window: timedelta,
        ch_client_params: dict,
        app_filter: dict,
        recognised_apps: dict,
        normalized_events_table_name: str = "normalized_events",
        ch_query_id: str = "",
        testing_future_events: bool = False,
    ):
        """Validate and store the task configuration.

        Args:
            min_packet_size: packets shorter than this (after ``lag``) are not
                processed.
            max_packet_size: packets longer than this are truncated to it.
            lag: amount subtracted from the packet end before processing.
            deduplication_window: length of the window that ends right before
                the packet start and is passed to the normalization query.
            ch_client_params: keyword arguments for ``aioch.Client``.
            app_filter: mapping from the keys in ``_APP_FILTER_KEYS`` to build
                numbers. Missing keys default to ``MAX_BUILD_NUMBER``. Extra
                keys are ignored. Presumably an upper bound on app builds;
                confirm against the SQL in ``sqls``.
            recognised_apps: forwarded to
                ``sqls.normalize_app_metric_and_mapkit``.
            normalized_events_table_name: target table name used by the SQL.
            ch_query_id: ClickHouse ``query_id`` attached to every query.
            testing_future_events: if set and the reported packet start lies
                after its end, the end is moved to ``start + max_packet_size``.

        Raises:
            ValueError: if a duration is non-positive (``lag``: negative) or
                ``min_packet_size`` exceeds ``max_packet_size``.
        """
        if min_packet_size.total_seconds() <= 0:
            raise ValueError("min_packet_size must be positive")
        if max_packet_size.total_seconds() <= 0:
            raise ValueError("max_packet_size must be positive")
        if lag.total_seconds() < 0:
            raise ValueError("lag must be non-negative")
        if deduplication_window.total_seconds() <= 0:
            raise ValueError("deduplication_window must be positive")
        if min_packet_size > max_packet_size:
            raise ValueError(
                "max_packet_size must be greater or equal than min_packet_size"
            )

        # All lengths are stored as whole seconds (fractions are truncated).
        self._min_packet_len = int(min_packet_size.total_seconds())
        self._max_packet_len = int(max_packet_size.total_seconds())
        self._lag = int(lag.total_seconds())
        self._deduplication_window = int(deduplication_window.total_seconds())
        self._ch_client_params = ch_client_params
        self._normalized_events_table_name = normalized_events_table_name
        # Build a fresh dict so the caller's mapping is never shared or mutated.
        self._app_filter = {
            key: app_filter.get(key, self.MAX_BUILD_NUMBER)
            for key in self._APP_FILTER_KEYS
        }
        self._recognised_apps = recognised_apps
        self._ch_query_id = ch_query_id
        self._testing_future_events = testing_future_events

    def _make_metadata(self, source: str, executor_id) -> dict:
        """Return the metadata dict attached to events from ``source``."""
        # Key order matters: the result is serialized with json.dumps.
        return {
            "source": source,
            "warden_executor_id": executor_id,
            "app_filter": self._app_filter,
        }

    async def __call__(self, warden_client: ClientWithContextManager):
        """Normalize the next packet of events.

        Args:
            warden_client: used to report the packet bounds and to obtain
                ``executor_id`` for the metadata.

        Raises:
            NoNewData: if no packet is ready for processing.
        """
        ch_client = Client(**self._ch_client_params)
        try:
            packet_bounds = await self._calculate_packet_bounds(ch_client)
            if packet_bounds is None:
                raise NoNewData

            packet_start, packet_end = packet_bounds

            await warden_client.update_status(
                status="packet_bounds_calculated",
                metadata={"packet_start": packet_start, "packet_end": packet_end},
            )

            executor_id = warden_client.executor_id
            metadata_datatube = self._make_metadata("datatube", executor_id)
            metadata_mapkit = self._make_metadata("mapkittube", executor_id)

            await ch_client.execute(
                sqls.normalize_app_metric_and_mapkit(
                    self._normalized_events_table_name, self._recognised_apps
                ),
                {
                    "timestamp_start": packet_start,
                    "timestamp_end": packet_end,
                    # Window ends one unit before the packet so it does not
                    # overlap the packet itself.
                    "deduplication_window_start": packet_start - self._deduplication_window,
                    "deduplication_window_end": packet_start - 1,
                    "metadata_datatube": json.dumps(metadata_datatube),
                    "metadata_mapkit": json.dumps(metadata_mapkit),
                    **self._app_filter,
                },
                query_id=self._ch_query_id,
                settings={"replace_running_query": 0},
            )
        finally:
            # Always release the connection, including on NoNewData.
            await ch_client.disconnect()

    async def _calculate_packet_bounds(
        self, client: Client
    ) -> Optional[Tuple[int, int]]:
        """Compute ``(packet_start, packet_end)`` or ``None`` if nothing to do.

        The raw bounds come from the first row of the
        ``select_unprocessed_start_and_end_of_packet_...`` query. The end is
        reduced by ``lag``. ``None`` is returned when the query yields no row,
        a ``None`` start, or a packet shorter than ``min_packet_size``.
        Longer packets are clamped to ``max_packet_size``.
        """
        rows = await client.execute(
            sqls.select_unprocessed_start_and_end_of_packet_for_mapkit_and_appmetric.format(  # noqa: E501
                normalized_events_table_name=self._normalized_events_table_name
            ),
            query_id=self._ch_query_id,
            settings={"replace_running_query": 0},
        )
        # Guard against an empty result instead of failing with IndexError.
        if not rows:
            return None

        packet_start, packet_end = rows[0]

        if packet_start is None:
            return None

        if packet_start > packet_end and self._testing_future_events:
            packet_end = packet_start + self._max_packet_len

        packet_end -= self._lag
        packet_len = packet_end - packet_start
        if packet_len < self._min_packet_len:
            return None

        if packet_len > self._max_packet_len:
            packet_end = packet_start + self._max_packet_len

        return int(packet_start), int(packet_end)
