from datetime import datetime, timedelta, timezone
from operator import itemgetter
from typing import Optional

from asyncpg import Connection
from croniter import croniter

from maps_adv.warden.server.lib.data_managers.tasks import (
    AbstractDataManager,
    ConflictOperation,
    TaskNotFound,
    UnknownTaskType,
)

# Public API: the domain service plus the exceptions callers may need to handle.
__all__ = [
    "ConflictOperation",
    "Domain",
    "ExecutorIdAlreadyUsed",
    "StatusSequenceViolation",
    "TaskInProgressByAnotherExecutor",
    "TaskNotFound",
    "TaskTypeAlreadyAssigned",
    "TooEarlyForNewTask",
    "UnknownTaskType",
    "UpdateToInitialStatus",
]


class TaskTypeAlreadyAssigned(Exception):
    """Raised by ``Domain.create_task`` when an active task of the type is
    still within its deadline; the type name is passed as the argument."""


class TaskInProgressByAnotherExecutor(Exception):
    """Raised by ``Domain.update_task`` when the task's recorded executor
    differs from the executor requesting the update."""


class TooEarlyForNewTask(Exception):
    """Raised when the next scheduled launch of a task type is in the future.

    Attributes:
        scheduled_time: the earliest time a new task of the type may start.
    """

    __slots__ = ("scheduled_time",)

    scheduled_time: datetime

    def __init__(self, scheduled_time: datetime):
        self.scheduled_time = scheduled_time
        # Forward the value to ``Exception.__init__`` so it is stored in
        # ``args``: this gives a meaningful ``str()`` and lets the exception
        # survive pickling, which reconstructs it as ``cls(*args)``.
        super().__init__(scheduled_time)


class StatusSequenceViolation(Exception):
    """Raised when a task may not move from its current status to the new one.

    Attributes:
        from_: the task's current status.
        to: the status that was requested.
    """

    __slots__ = "from_", "to"

    from_: str
    to: str

    def __init__(self, from_: str, to: str):
        self.from_ = from_
        self.to = to
        # Store both statuses in ``args`` so ``str()`` is informative and
        # pickling (which calls ``cls(*args)``) can rebuild the exception.
        super().__init__(from_, to)


class UpdateToInitialStatus(Exception):
    """Raised by ``Domain.update_task`` when asked to set status "accepted"."""


class ExecutorIdAlreadyUsed(Exception):
    """Raised by ``Domain.create_task`` when the data manager reports that the
    executor id already exists."""


# Field names pulled out of a task record, in the order they are unpacked.
_TASK_FIELDS = ("executor_id", "status", "scheduled_time")
# Returns the tuple (executor_id, status, scheduled_time) for a task record.
destruct_task = itemgetter(*_TASK_FIELDS)


class Domain:
    """Business rules for assigning, restoring and updating warden tasks.

    Persistence is delegated to an ``AbstractDataManager``; ``create_task``
    and ``update_task`` run inside ``self._dm.lock(type_name)``.
    """

    __slots__ = ("_dm", "extra_time")

    _dm: AbstractDataManager
    extra_time: timedelta

    def __init__(self, dm: AbstractDataManager, extra_time: int):
        """
        Args:
            dm: data manager used for all persistence.
            extra_time: seconds added on top of a task type's ``time_limit``
                before an active task is considered expired.
        """
        self._dm = dm
        self.extra_time = timedelta(seconds=extra_time)

    async def create_task(
        self, executor_id: str, type_name: str, metadata: Optional[dict] = None
    ) -> dict:
        """Assign a task of ``type_name`` to ``executor_id``.

        If an active task of this type exists and its deadline
        (``intake_time + time_limit + extra_time``) has not been reached, the
        request is rejected; otherwise that task is marked as failed. For
        restorable types the expired active task, or else the last failed
        task, is handed over to the new executor. In all other cases a new
        task is created, provided its scheduled time has already arrived.

        Returns:
            dict with ``task_id``, ``status`` and ``time_limit`` (plus
            ``metadata`` when an existing task was restored).

        Raises:
            ExecutorIdAlreadyUsed: ``executor_id`` already exists.
            TaskTypeAlreadyAssigned: an active task is still within its deadline.
            TooEarlyForNewTask: the next scheduled launch is in the future.
        """
        async with self._dm.lock(type_name) as con:
            if await self._dm.is_executor_id_exists(executor_id, con):
                raise ExecutorIdAlreadyUsed()

            type_details = await self._dm.retrieve_task_type_details(type_name, con)

            active_task = await self._dm.retrieve_active_task_details(
                type_details["id"], con
            )
            if active_task:
                deadline = (
                    active_task["intake_time"]
                    + timedelta(seconds=type_details["time_limit"])
                    + self.extra_time
                )
                now = datetime.now(tz=timezone.utc)
                # The two outcomes are exhaustive: a task that has reached its
                # deadline exactly is treated as expired, so it can never stay
                # active alongside a newly created or restored task.
                if now < deadline:
                    raise TaskTypeAlreadyAssigned(type_name)
                await self._dm.mark_task_as_failed(active_task["id"], con)

            if type_details["restorable"]:
                # Reaching this point with an active task means it just expired.
                task_to_restore = active_task
                if not task_to_restore:
                    task_to_restore = await self._dm.retrieve_last_failed_task_details(
                        type_details["id"], con=con
                    )
                if task_to_restore:
                    return await self._restore_task(
                        executor_id, type_details, task_to_restore, con=con
                    )

            now = datetime.now(tz=timezone.utc)
            scheduled_time = await self._calculate_scheduled_time(
                type_details, con, now
            )

            if now < scheduled_time:
                raise TooEarlyForNewTask(scheduled_time)

            task_details = await self._dm.create_task(
                executor_id, type_details["id"], scheduled_time, metadata, con
            )

        return dict(
            task_id=task_details["task_id"],
            status=task_details["status"],
            time_limit=type_details["time_limit"],
        )

    async def _restore_task(
        self, executor_id: str, type_details: dict, failed_task: dict, con: Connection
    ) -> dict:
        """Reassign ``failed_task`` to ``executor_id``, keeping its status and
        metadata, and return the details reported back to the caller."""
        await self._dm.restore_task(
            executor_id=executor_id,
            task_id=failed_task["id"],
            status=failed_task["status"],
            metadata=failed_task["metadata"],
            con=con,
        )
        return dict(
            task_id=failed_task["id"],
            status=failed_task["status"],
            time_limit=type_details["time_limit"],
            metadata=failed_task["metadata"],
        )

    async def _calculate_scheduled_time(
        self, type_details: dict, con: Connection, now: datetime
    ) -> datetime:
        """Return the time at which a new task of this type may start.

        That is ``now`` when there is no previous task of the type or the type
        has no cron ``schedule`` (``None``, a case ``update_task`` also
        handles). Otherwise it is the next cron tick after the last task's
        ``scheduled_time``.
        """
        last_created = await self._dm.find_last_task_of_type(type_details["id"], con)
        if not last_created or type_details["schedule"] is None:
            return now

        return self._calculate_next_launch_time(
            type_details["schedule"], last_created["scheduled_time"]
        )

    def _calculate_next_launch_time(self, cron: str, base_time: datetime) -> datetime:
        """Return the next tick of ``cron`` after ``base_time`` as a UTC-aware
        datetime (croniter's ``get_next()`` yields a POSIX timestamp)."""
        schedule = croniter(cron, base_time)
        return datetime.fromtimestamp(schedule.get_next(), tz=timezone.utc)

    async def update_task(
        self,
        executor_id: str,
        type_name: str,
        task_id: int,
        status: str,
        metadata: Optional[str] = None,
    ) -> Optional[datetime]:
        """Move task ``task_id`` owned by ``executor_id`` to ``status``.

        Returns:
            When ``status`` is ``"completed"`` and the task type has a cron
            schedule, the next launch time after the task's
            ``scheduled_time``; otherwise ``None``.

        Raises:
            UpdateToInitialStatus: ``status`` is ``"accepted"``.
            TaskInProgressByAnotherExecutor: the task belongs to another executor.
            StatusSequenceViolation: the task is already ``completed`` or
                ``failed``, or already has ``status``.
        """
        if status == "accepted":
            raise UpdateToInitialStatus()

        async with self._dm.lock(type_name) as con:
            task_details = await self._dm.retrieve_task_details(task_id, con)
            current_executor, current_status, scheduled_time = destruct_task(
                task_details
            )

            if current_executor != executor_id:
                raise TaskInProgressByAnotherExecutor()

            if current_status in ("completed", "failed") or current_status == status:
                raise StatusSequenceViolation(current_status, status)

            await self._dm.update_task(executor_id, task_id, status, metadata, con)

            type_details = await self._dm.retrieve_task_type_details(type_name, con)
            type_schedule = type_details["schedule"]
            if status == "completed" and type_schedule is not None:
                return self._calculate_next_launch_time(type_schedule, scheduled_time)

        return None

    async def mark_tasks_as_failed(self) -> None:
        """Ask the data manager to fail overdue tasks, passing the grace
        period as whole seconds (fractions are truncated by ``int``)."""
        await self._dm.mark_tasks_as_failed(
            extra_time=int(self.extra_time.total_seconds())
        )
