import enum

from marshmallow_enum import EnumField

from maps_adv.stat_controller.client.lib.base import (
    BaseClient,
    FindTaskSchema,
    NoTasksFound,
    TaskSchema,
    UnknownResponse,
    UpdateTaskSchema,
    with_schemas,
)

from .base.client import async_shield

# Public API of this module; UnknownResponse is re-exported from the base client.
__all__ = [
    "Client",
    "TaskStatus",
    "UnknownResponse",
    "NoNormalizerTaskFound",
]


class TaskStatus(enum.Enum):
    """Task statuses accepted by the normalizer update-task input schema.

    Member name and value are identical, so the member is encoded the same
    way whether it is serialized by name or by value.
    """

    completed = "completed"


class UpdateNormalizerTaskInputBaseSchema(UpdateTaskSchema):
    """Input schema used by ``Client.update_task``.

    Builds on ``UpdateTaskSchema`` and declares ``status`` as a
    ``TaskStatus`` enum field.
    """

    # EnumField defaults to by_value=False (serialize by member name); in
    # TaskStatus names equal values, so the wire value is the same either way.
    status = EnumField(TaskStatus)


class NoNormalizerTaskFound(NoTasksFound):
    """Raised by ``Client.find_new_task`` when the server has no normalizer task.

    Subclasses ``NoTasksFound`` so handlers for the generic exception also
    catch this one.
    """


class Client(BaseClient):
    """Stat-controller client for the ``/tasks/normalizer/`` endpoints.

    Requests go through ``BaseClient._request``. Each method is decorated
    with ``with_schemas`` (input schema, output schema), presumably to
    validate/serialize the request and response bodies — confirm in
    ``maps_adv.stat_controller.client.lib.base``.
    """

    @with_schemas(FindTaskSchema, TaskSchema)
    async def find_new_task(self, executor_id: str) -> dict:
        """Ask the server for a new normalizer task for ``executor_id``.

        Sends ``POST /tasks/normalizer/`` and expects status ``201``.

        Raises:
            NoNormalizerTaskFound: the server replied ``200`` with an empty
                JSON object (``b"{}"``), meaning there is no task to hand out.
            UnknownResponse: any other unexpected response.
        """
        body = {"executor_id": executor_id}

        try:
            return await self._request("POST", "/tasks/normalizer/", 201, body)
        except UnknownResponse as exc:
            # "No task" comes back as 200 + "{}" instead of the expected 201,
            # which surfaces here as UnknownResponse; translate it.
            if exc.status_code == 200 and exc.payload == b"{}":
                # Chain explicitly so the raw response stays attached as
                # __cause__ when this is logged or debugged.
                raise NoNormalizerTaskFound from exc
            raise

    @async_shield
    @with_schemas(UpdateNormalizerTaskInputBaseSchema, TaskSchema)
    async def update_task(self, task_id: int, status: str, executor_id: str) -> dict:
        """Update normalizer task ``task_id`` on behalf of ``executor_id``.

        Sends ``PUT /tasks/normalizer/{task_id}/`` and expects status ``200``.
        Wrapped in ``async_shield`` — presumably so an in-flight update is
        not interrupted by caller cancellation; confirm in ``base.client``.

        NOTE(review): the input schema declares ``status`` as
        ``EnumField(TaskStatus)``; confirm whether callers should pass a
        ``TaskStatus`` member rather than a plain ``str`` as annotated.
        """
        body = {"status": status, "executor_id": executor_id}

        return await self._request("PUT", f"/tasks/normalizer/{task_id}/", 200, body)
