import json
from datetime import timedelta
from typing import Optional, Tuple

from aioch import Client

from maps_adv.warden.client.lib import ClientWithContextManager

from . import sqls
from .exceptions import NoNewData


class NormalizerTask:
    """Normalize one packet of data in ClickHouse per invocation.

    Each call:
      1. runs ``sqls.select_unprocessed_start_and_end_of_packet`` to get a
         candidate ``(start, end)`` range,
      2. shifts the end back by ``lag`` and clamps the range length to
         ``[min_packet_size, max_packet_size]``,
      3. reports the bounds to warden (status ``"packet_bounds_calculated"``),
      4. runs ``sqls.normalize`` over the range.

    Packet bounds are compared against whole-second integers, so they are
    presumably unix timestamps in seconds -- TODO confirm against ``sqls``.
    """

    def __init__(
        self,
        min_packet_size: timedelta,
        max_packet_size: timedelta,
        lag: timedelta,
        ch_client_params: dict,
        ch_query_id: str = "",
    ):
        """Validate and store packet-sizing settings and ClickHouse params.

        :param min_packet_size: shortest packet worth processing; shorter
            ranges make the task raise ``NoNewData``. Must be >= 1 second.
        :param max_packet_size: longest packet processed in one call; longer
            ranges are truncated. Must be >= ``min_packet_size``.
        :param lag: amount subtracted from the packet end before sizing.
            Must be non-negative.
        :param ch_client_params: keyword arguments for ``aioch.Client``.
        :param ch_query_id: ClickHouse ``query_id`` used for both queries.
        :raises ValueError: if any of the constraints above is violated.
        """
        if min_packet_size.total_seconds() <= 0:
            raise ValueError("min_packet_size must be positive")
        if max_packet_size.total_seconds() <= 0:
            raise ValueError("max_packet_size must be positive")
        if lag.total_seconds() < 0:
            raise ValueError("lag must be non-negative")
        if min_packet_size > max_packet_size:
            raise ValueError(
                "max_packet_size must be greater or equal than min_packet_size"
            )

        # Lengths are kept in whole seconds (fractions are truncated).
        self._min_packet_len = int(min_packet_size.total_seconds())
        self._max_packet_len = int(max_packet_size.total_seconds())
        # A positive sub-second size truncates to 0, which would let
        # zero-length packets through (and, for max, cap every packet to zero
        # length). Since min <= max and truncation is monotonic, checking min
        # also covers max.
        if self._min_packet_len == 0:
            raise ValueError("min_packet_size must be at least one second")
        self._lag = int(lag.total_seconds())
        self._ch_client_params = ch_client_params
        self._ch_query_id = ch_query_id

    async def __call__(self, warden_client: ClientWithContextManager):
        """Compute the next packet, report it to warden and normalize it.

        :param warden_client: used to report the packet bounds and to read
            ``executor_id`` for the normalization metadata.
        :raises NoNewData: when there is no packet long enough to process.
        """
        ch_client = Client(**self._ch_client_params)
        try:
            packet_bounds = await self._calculate_packet_bounds(ch_client)
            if packet_bounds is None:
                raise NoNewData

            packet_start, packet_end = packet_bounds

            await warden_client.update_status(
                status="packet_bounds_calculated",
                metadata={"packet_start": packet_start, "packet_end": packet_end},
            )

            # Passed to the normalize query as a JSON string; presumably
            # recorded with the normalized rows -- confirm against sqls.normalize.
            metadata_datatube = {
                "source": "datatube",
                "warden_executor_id": warden_client.executor_id,
            }

            await ch_client.execute(
                sqls.normalize,
                {
                    "timestamp_start": packet_start,
                    "timestamp_end": packet_end,
                    "metadata_datatube": json.dumps(metadata_datatube),
                },
                query_id=self._ch_query_id,
                # With replace_running_query=0 ClickHouse rejects a query whose
                # query_id is already running instead of cancelling the old one.
                settings={"replace_running_query": 0},
            )
        finally:
            # Always release the connection, including on NoNewData.
            await ch_client.disconnect()

    async def _calculate_packet_bounds(
        self, client: Client
    ) -> Optional[Tuple[int, int]]:
        """Return ``(start, end)`` of the next packet, or ``None`` if none.

        ``None`` is returned when the bounds query yields no row or NULL
        bounds, or when the lag-adjusted range is shorter than
        ``min_packet_size``. Ranges longer than ``max_packet_size`` are
        truncated to ``start + max_packet_size``.
        """
        rows = await client.execute(
            sqls.select_unprocessed_start_and_end_of_packet,
            query_id=self._ch_query_id,
            settings={"replace_running_query": 0},
        )
        # Guard against an empty result instead of failing with IndexError.
        if not rows:
            return None

        packet_start, packet_end = rows[0]
        # NULL bounds mean there is nothing to process; checking both avoids a
        # TypeError on the arithmetic below if only one side is NULL.
        if packet_start is None or packet_end is None:
            return None

        # Hold back the most recent `lag` seconds (presumably so late-arriving
        # data is not cut off -- confirm).
        packet_end -= self._lag
        if packet_end - packet_start < self._min_packet_len:
            return None

        # Cap the packet so a large backlog is processed in bounded chunks.
        packet_end = min(packet_end, packet_start + self._max_packet_len)

        return int(packet_start), int(packet_end)
