import copy
from typing import Awaitable, Callable, Iterable

from maps_adv.warden.client.lib.exceptions import UnknownWardenStatus
from maps_adv.warden.client.lib.status import ACCEPTED

from .context import TaskContext


class Step:
    """A single pipeline stage: an async callable paired with a status.

    Attributes:
        status: status string associated with this stage; ``Pipeline`` passes
            it to ``client.update_status`` after ``func`` completes and also
            uses it to decide where to resume.
        func: async callable invoked with a ``TaskContext``; ``Pipeline``
            forwards whatever it returns as the status ``metadata``.
    """

    __slots__ = "status", "func"

    def __init__(self, status: str, func: Callable[[TaskContext], Awaitable]):
        self.status = status
        self.func = func

    def __str__(self):
        # Not every callable exposes ``__name__`` (functools.partial objects,
        # callable class instances, mocks), so fall back to repr() instead of
        # letting str() raise AttributeError.
        try:
            func_name = self.func.__name__
        except AttributeError:
            func_name = repr(self.func)
        return f"status={self.status}, func={func_name}"


class Pipeline:
    """Run an ordered, fixed sequence of ``Step`` objects for a task.

    Execution starts at the first step when the incoming status is
    ``ACCEPTED``. Otherwise it resumes right after the first step whose
    status equals the incoming status. After each step the context's client
    receives ``update_status`` with that step's status and its return value
    as ``metadata``.
    """

    __slots__ = "_steps"

    def __init__(self, steps: Iterable[Step]):
        # Materialize once so the pipeline can be run any number of times.
        self._steps = tuple(steps)

    async def __call__(self, context: TaskContext):
        """Execute every step that has not yet run for ``context``.

        Raises:
            UnknownWardenStatus: if ``context.status`` is neither ``ACCEPTED``
                nor the status of any step in this pipeline.
        """
        task_client = context.client  # type: ClientWithContextManager  # noqa
        first_pending = self._find_step_by_status(context.status)
        pending = self._steps[first_pending:]
        step_context = self._generate_context(context)

        for step in pending:
            outcome = await step.func(step_context)

            await task_client.update_status(step.status, metadata=outcome)

            # The next step sees the status just reported and the previous
            # step's result. It is always derived from the ORIGINAL context,
            # so its params come from a fresh deep copy.
            step_context = self._generate_context(
                context, status=step.status, metadata=outcome
            )

    @classmethod
    def _generate_context(cls, context: TaskContext, **params) -> TaskContext:
        """Return a new ``TaskContext`` derived from ``context``.

        ``status`` and ``metadata`` keyword arguments, when present (even as
        ``None``), replace the values taken from ``context``. Other keyword
        arguments are ignored. The client object is shared, while ``params``
        is deep-copied so a step cannot mutate the source context's params.
        """
        shared_client = context.client
        new_status = params.get("status", context.status)
        new_metadata = params.get("metadata", context.metadata)
        isolated_params = copy.deepcopy(context.params)
        return TaskContext(
            client=shared_client,
            status=new_status,
            metadata=new_metadata,
            params=isolated_params,
        )

    def _find_step_by_status(self, status: str) -> int:
        """Return the index into the steps of the first step that must run.

        ``ACCEPTED`` maps to 0 and is checked before any step. Any other
        status maps to one past the FIRST step carrying that status.

        Raises:
            UnknownWardenStatus: if ``status`` matches no step.
        """
        if status == ACCEPTED:
            return 0

        resume_positions = (
            position
            for position, step in enumerate(self._steps, start=1)
            if step.status == status
        )
        resume_at = next(resume_positions, None)
        if resume_at is None:
            raise UnknownWardenStatus(status)
        return resume_at
