from typing import Optional

from maps_adv.stat_controller.server.lib.data_managers.charger import (
    TaskManager,
    TaskStatus,
)

from .exceptions import InProgressByAnotherExecutor, StatusSequenceViolation

# Names exported by ``from <this module> import *``: the domain class, the two
# exceptions its ``update`` method raises, and ``TaskStatus``, re-exported from
# the charger data manager so callers can reference statuses from here.
__all__ = [
    "Domain",
    "InProgressByAnotherExecutor",
    "StatusSequenceViolation",
    "TaskStatus",
]


class Domain:
    """Domain logic for handing out tasks and advancing their status.

    Wraps a ``TaskManager`` data manager. A task's status may only move one
    step along the linear chain in ``_status_sequences``, and only the
    executor currently recorded on the task may move it.
    """

    __slots__ = ("_dm",)

    # Data manager used for every read and write; set in ``__init__``.
    _dm: TaskManager

    # Allowed transitions: current status -> the single status it may move to.
    # Statuses that are not keys here (e.g. ``TaskStatus.completed``) have no
    # successor, so any update away from them is rejected by ``update``.
    _status_sequences = {
        TaskStatus.accepted: TaskStatus.context_received,
        TaskStatus.context_received: TaskStatus.calculation_completed,
        TaskStatus.calculation_completed: TaskStatus.billing_notified,
        TaskStatus.billing_notified: TaskStatus.charged_data_sent,
        TaskStatus.charged_data_sent: TaskStatus.completed,
    }

    # Available tasks found in one of these statuses are resumed with their
    # status and execution_state instead of being restarted from ``accepted``.
    # NOTE(review): presumably because these stages involve billing and must
    # not be repeated — confirm. Kept immutable since it is shared class state
    # used only for membership tests.
    _important_statuses = frozenset(
        (TaskStatus.billing_notified, TaskStatus.charged_data_sent)
    )

    def __init__(self, dm: TaskManager):
        """Store the data manager used for all task persistence."""
        self._dm = dm

    async def _accept_task_from_start(self, executor_id, task, con):
        """Restart ``task`` from ``TaskStatus.accepted`` for ``executor_id``.

        Persists the new status through ``TaskManager.update`` on ``con``
        (no execution_state is passed), then mutates ``task`` in place: its
        ``status`` becomes ``accepted`` and any ``execution_state`` key is
        removed, since the task starts over. Returns the same ``task`` object.
        """
        await self._dm.update(executor_id, task["id"], TaskStatus.accepted, con=con)
        task["status"] = TaskStatus.accepted
        task.pop("execution_state", None)
        return task

    async def _restore_from_state(self, executor_id, task, con):
        """Persist ``task`` for ``executor_id`` keeping its progress.

        Calls ``TaskManager.update`` on ``con`` with the task's current
        ``status`` and ``execution_state`` (so ``task`` must contain both
        keys). ``task`` itself is returned unchanged.
        """
        # ``con`` is passed by keyword, matching ``_accept_task_from_start``,
        # so the call does not depend on TaskManager.update's parameter order.
        await self._dm.update(
            executor_id,
            task["id"],
            task["status"],
            task["execution_state"],
            con=con,
        )
        return task

    async def find_new(self, executor_id: str) -> Optional[dict]:
        """Pick the oldest available task and give it to ``executor_id``.

        Everything runs inside ``self._dm.lock()`` using the connection it
        yields. Returns ``None`` if some task is already in progress or no
        task is available. A task whose status is in ``_important_statuses``
        is resumed with its status and execution_state; any other task is
        restarted from ``TaskStatus.accepted``.

        Returns the task mapping obtained from ``find_oldest_available``
        (possibly modified, see ``_accept_task_from_start``), or ``None``.
        """
        async with self._dm.lock() as con:
            # Do not hand out a new task while any task is still in progress.
            if await self._dm.is_there_in_progress(con):
                return None

            task = await self._dm.find_oldest_available(con)
            if not task:
                return None

            if "status" in task and task["status"] in self._important_statuses:
                return await self._restore_from_state(executor_id, task, con)

            return await self._accept_task_from_start(executor_id, task, con)

    async def update(
        self, executor_id: str, task_id: int, status: TaskStatus, execution_state: str
    ) -> dict:
        """Move task ``task_id`` to ``status`` on behalf of ``executor_id``.

        Parameters:
            executor_id: executor requesting the change; must equal the
                executor_id recorded on the task.
            task_id: id of the task to update.
            status: target status; must be the direct successor of the task's
                current status in ``_status_sequences``.
            execution_state: passed through to ``TaskManager.update``.

        Returns:
            ``{"id", "timing_from", "timing_to"}`` copied from the task
            details read before the update.

        Raises:
            StatusSequenceViolation: ``status`` is not the allowed next status
                (including when the current status has no successor).
            InProgressByAnotherExecutor: the task is recorded with a different
                executor. The status check runs first, so a foreign executor
                sending an out-of-order status gets StatusSequenceViolation.
        """
        details = await self._dm.retrieve_details(task_id)

        # Reject anything other than the single allowed next status.
        current_status = details["status"]
        if self._status_sequences.get(current_status) != status:
            raise StatusSequenceViolation(
                task_id=task_id,
                executor_id=executor_id,
                current_status=current_status,
                target_status=status,
            )

        # Reject updates from an executor that does not own the task.
        current_executor_id = details["executor_id"]
        if current_executor_id != executor_id:
            raise InProgressByAnotherExecutor(
                task_id=task_id,
                status=status,
                current_executor_id=current_executor_id,
                executor_id=executor_id,
            )

        # NOTE(review): unlike ``find_new``, the read above and this write are
        # not wrapped in ``self._dm.lock()``, so the checks are not atomic with
        # the update as written — confirm whether TaskManager guards this.
        await self._dm.update(executor_id, task_id, status, execution_state)

        return {
            "id": details["id"],
            "timing_from": details["timing_from"],
            "timing_to": details["timing_to"],
        }
