from datetime import datetime, timedelta, timezone
from typing import Optional

from maps_adv.stat_controller.server.lib.data_managers.normalizer import (
    TaskManager,
    TaskStatus,
)

from .exceptions import InProgressByAnotherExecutor, StatusSequenceViolation

# Public API of this module. The exceptions and ``TaskStatus`` are imported
# from other modules above and re-exported here so callers can import them
# alongside ``Domain``.
__all__ = [
    "Domain",
    "InProgressByAnotherExecutor",
    "StatusSequenceViolation",
    "TaskStatus",
]


class Domain:
    """Domain logic for normalizer tasks.

    Hands out time intervals (``timing_from`` .. ``timing_to``) to executors
    as tasks and advances tasks through their allowed status transitions.
    Persistence and locking are delegated to the ``TaskManager`` data manager.

    Interval endpoints are treated as inclusive at one-second resolution: a
    new interval starts one second after the previous one ended, and clipping
    to ``max_time_range`` subtracts one second.

    All ``time_lag`` / ``*_time_range*`` settings are in seconds.
    """

    __slots__ = (
        "_dm",
        "time_lag",
        "min_time_range",
        "max_time_range",
        "max_time_range_to_skip",
    )

    # Data manager used for task persistence and for the lock in ``find_new``.
    _dm: TaskManager

    # Seconds between "now" and the end of a freshly computed interval.
    time_lag: int
    # Length in seconds of a new interval when there is no previous task.
    # A new task is also only created once the previous task ended strictly
    # before ``now - time_lag - min_time_range``.
    min_time_range: int
    # Upper bound in seconds on the span of a single interval.
    max_time_range: int
    # If an interval lying within one hour would leave less than this many
    # seconds before that hour's :59:59, no task is created yet.
    max_time_range_to_skip: int

    # Allowed status transitions for ``update``: current status -> the only
    # status it may be moved to. Any other transition is rejected.
    _status_sequences = {TaskStatus.accepted: TaskStatus.completed}

    def __init__(
        self,
        dm: TaskManager,
        time_lag: int,
        min_time_range: int,
        max_time_range: int,
        max_time_range_to_skip: int,
    ):
        """Store the data manager and the interval settings (all in seconds)."""
        self._dm = dm

        self.time_lag = time_lag
        self.min_time_range = min_time_range
        self.max_time_range = max_time_range
        self.max_time_range_to_skip = max_time_range_to_skip

    async def find_new(self, executor_id: str) -> Optional[dict]:
        """Assign a task to ``executor_id``, creating a new one if one is due.

        Runs entirely under the data manager's lock. In order of precedence:

        * if any task is in progress, nothing is assigned;
        * if a failed task exists, it is re-accepted for ``executor_id`` and
          returned as fetched from the data manager;
        * otherwise the interval following the last task is computed (see
          ``_next_interval``) and, if one is due, a new task is created.

        :param executor_id: identifier of the executor requesting work.
        :return: the failed task record, a ``{"id", "timing_from",
            "timing_to"}`` dict for a newly created task, or ``None`` when
            there is nothing to hand out right now.
        """
        async with self._dm.lock() as con:
            if await self._dm.is_there_in_progress(con):
                return None

            now = datetime.now(timezone.utc)

            failed = await self._dm.find_failed_task(con)
            if failed:
                # NOTE(review): unlike every other data-manager call inside this
                # lock, ``update`` is not passed ``con``; confirm whether it
                # should run on the locked connection.
                await self._dm.update(executor_id, failed["id"], TaskStatus.accepted)

                return failed

            last = await self._dm.find_last(con)
            interval = self._next_interval(now, last)
            if interval is None:
                return None
            timing_from, timing_to = interval

            task_id = await self._dm.create(executor_id, timing_from, timing_to, con)

        return {"id": task_id, "timing_from": timing_from, "timing_to": timing_to}

    def _next_interval(self, now: datetime, last) -> Optional[tuple]:
        """Compute the next ``(timing_from, timing_to)`` interval, if one is due.

        The interval never crosses an hour boundary and never spans more than
        ``max_time_range`` seconds. ``datetime.replace(minute=59, second=59)``
        is used for hour ends, so sub-second parts of the inputs are kept.

        :param now: current time; by default the interval ends ``time_lag``
            seconds before it.
        :param last: the most recent task as returned by
            ``TaskManager.find_last`` (indexed by ``"timing_to"``), or a falsy
            value when there is none.
        :return: ``(timing_from, timing_to)``, or ``None`` when no interval
            should be created yet.
        """
        timing_to = now - timedelta(seconds=self.time_lag)
        timing_from = timing_to - timedelta(seconds=self.min_time_range)

        if last:
            # Not enough new time has accumulated since the previous interval.
            if last["timing_to"] >= timing_from:
                return None

            # Continue right after the previous interval (inclusive ends).
            timing_from = last["timing_to"] + timedelta(seconds=1)

        if self._are_times_inside_same_hour(timing_from, timing_to):
            # Leave enough space for the next interval inside the current hour.
            # If too little would remain, wait. Once timing_to moves into the
            # next hour, the branch below closes this hour up to :59:59.
            left_in_hour = timing_to.replace(minute=59, second=59) - timing_to
            if left_in_hour < timedelta(seconds=self.max_time_range_to_skip):
                return None
        else:
            # An interval must not cross an hour boundary: stop at the end of
            # timing_from's hour.
            timing_to = timing_from.replace(minute=59, second=59)

        max_duration = timedelta(seconds=self.max_time_range)
        if timing_to - timing_from > max_duration:
            # Inclusive ends, hence the one-second subtraction.
            timing_to = timing_from + max_duration - timedelta(seconds=1)

        return timing_from, timing_to

    async def update(self, executor_id: str, task_id: int, status: TaskStatus) -> dict:
        """Move task ``task_id`` to ``status`` on behalf of ``executor_id``.

        The status transition is validated first, then task ownership.

        :param executor_id: executor requesting the change.
        :param task_id: identifier of the task to update.
        :param status: target status.
        :raises StatusSequenceViolation: if ``status`` is not the allowed
            successor of the task's current status (see ``_status_sequences``).
        :raises InProgressByAnotherExecutor: if the task is assigned to a
            different executor.
        :return: ``{"id", "timing_from", "timing_to"}`` of the task, taken
            from the details read before the update.
        """
        details = await self._dm.retrieve_details(task_id)

        # Reject transitions not listed in _status_sequences.
        current_status = details["status"]
        if self._status_sequences.get(current_status) != status:
            raise StatusSequenceViolation(
                task_id=task_id,
                executor_id=executor_id,
                current_status=current_status,
                target_status=status,
            )

        # An executor may only update tasks assigned to itself.
        current_executor_id = details["executor_id"]
        if current_executor_id != executor_id:
            raise InProgressByAnotherExecutor(
                task_id=task_id,
                status=status,
                current_executor_id=current_executor_id,
                executor_id=executor_id,
            )

        await self._dm.update(executor_id, task_id, status)

        return {
            "id": details["id"],
            "timing_from": details["timing_from"],
            "timing_to": details["timing_to"],
        }

    @staticmethod
    def _are_times_inside_same_hour(dt_1: datetime, dt_2: datetime) -> bool:
        """Return True if both datetimes fall in the same clock hour.

        The ``hour`` fields alone could match across different days. Also
        requiring the two to be less than an hour apart makes it the same
        calendar hour.
        """
        distance = abs(dt_1 - dt_2)
        return distance < timedelta(hours=1) and dt_1.hour == dt_2.hour
