import uuid
from functools import partial, wraps
from typing import List, Optional

import aiohttp
import google.protobuf.message
from tenacity import (
    AsyncRetrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from maps_adv.warden.common.schemas import (
    CreateTaskInputProtoSchema,
    TaskDetailsProtoSchema,
    UpdateTaskInputProtoSchema,
    UpdateTaskOutputProtoSchema,
)
from maps_adv.warden.proto.errors_pb2 import Error

from .exceptions import (
    BadGateway,
    Conflict,
    ExecutorIdAlreadyUsed,
    GatewayTimeout,
    InternalServerError,
    StatusSequenceViolation,
    TaskInProgressByAnotherExecutor,
    TaskTypeAlreadyAssigned,
    TooEarlyForNewTask,
    UnknownError,
    UnknownResponse,
    UnknownResponseBody,
    UnknownTaskOrType,
    UpdateStatusToInitial,
    ValidationError,
)
from .status import COMPLETED, FAILED

# Names exported by ``from <this module> import *``.
__all__ = ["Client", "ClientWithContextManager", "ClientFactory"]


# Retry policy used by the ``Client`` request methods below:
# RETRY_MAX_ATTEMPTS is passed to tenacity's ``stop_after_attempt`` and
# RETRY_WAIT_MULTIPLIER is passed as ``multiplier`` to ``wait_exponential``.
# NOTE(review): the multiplier is presumably expressed in seconds — confirm
# against the tenacity documentation.
RETRY_MAX_ATTEMPTS = 10
RETRY_WAIT_MULTIPLIER = 0.1


# Maps a protobuf ``Error`` code to a one-argument factory that receives the
# decoded ``Error`` message and returns the client exception to raise.
# Most factories ignore the message; the ones that use it forward either
# ``scheduled_time`` or ``description`` to the exception constructor.
# Codes missing from this table are reported as ``UnknownError`` by
# ``Client._process_proto_error_response``.
MAP_PROTO_ERROR_TO_EXCEPTION = {
    Error.ERROR_CODE.TOO_EARLY_FOR_NEW_TASK_OF_REQUESTED_TYPE: lambda e: TooEarlyForNewTask(  # noqa: E501
        next_try_proto_dt=e.scheduled_time
    ),
    Error.ERROR_CODE.CONFLICT: lambda _: Conflict(),
    Error.ERROR_CODE.STATUS_SEQUENCE_VIOLATION: lambda e: StatusSequenceViolation(
        e.description
    ),
    Error.ERROR_CODE.TASK_IN_PROGRESS_BY_ANOTHER_EXECUTOR: lambda _: TaskInProgressByAnotherExecutor(),  # noqa: E501
    Error.ERROR_CODE.TASK_TYPE_ALREADY_ASSIGNED: lambda _: TaskTypeAlreadyAssigned(),
    Error.ERROR_CODE.UNKNOWN_TASK_OR_TYPE: lambda _: UnknownTaskOrType(),
    Error.ERROR_CODE.UPDATE_STATUS_TO_INITIAL: lambda _: UpdateStatusToInitial(),
    Error.ERROR_CODE.EXECUTOR_ID_ALREADY_USED: lambda _: ExecutorIdAlreadyUsed(),
    Error.ERROR_CODE.VALIDATION_ERROR: lambda e: ValidationError(e.description),
}


def with_retrier(*a, **kwa):
    """Build a method decorator that injects a tenacity ``AsyncRetrying``.

    All positional and keyword arguments are forwarded unchanged to
    ``AsyncRetrying``. One retrier instance is created per ``with_retrier``
    call and is reused by every invocation of the decorated method.

    The decorated coroutine method must accept the retrier as its first
    argument after ``self``; callers of the resulting wrapper do not pass it.
    The same retrier object is also exposed as the wrapper's ``retry``
    attribute.
    """
    shared_retrier = AsyncRetrying(*a, **kwa)

    def decorator(func):
        @wraps(func)
        async def wrapper(self, *call_args, **call_kwargs):
            # Slot the retrier in right after ``self``; everything the caller
            # supplied follows in its original order.
            return await func(self, shared_retrier, *call_args, **call_kwargs)

        # Expose the retrier on the wrapper so it can be reached from outside.
        wrapper.retry = shared_retrier
        return wrapper

    return decorator


class Client:
    """HTTP client for the task endpoint, bound to one executor and task type.

    Requests carry payloads serialized by the proto schemas and are sent to
    ``{server_url}/tasks/``. A response whose status is not among the
    expected ones is converted into an exception (see
    ``_process_bad_response``).

    Attributes:
        server_url: Base URL of the server; ``/tasks/`` is appended to it.
        executor_id: Identifier sent as ``executor_id`` in every payload.
        task_type: Task type name sent as ``type_name`` in every payload.
    """

    __slots__ = "server_url", "executor_id", "task_type"

    server_url: str
    executor_id: str
    task_type: str

    def __init__(self, server_url: str, *, executor_id: str, task_type: str):
        self.server_url = server_url
        self.executor_id = executor_id
        self.task_type = task_type

    # NOTE(review): no ``reraise=True`` is passed to the retrier, so once the
    # attempts are exhausted tenacity presumably raises its own ``RetryError``
    # rather than the last underlying exception — confirm before relying on
    # the exception type seen by callers.
    @with_retrier(
        stop=stop_after_attempt(RETRY_MAX_ATTEMPTS),
        retry=retry_if_exception_type(
            (
                BadGateway,
                GatewayTimeout,
                InternalServerError,
                aiohttp.ServerDisconnectedError,
                aiohttp.ServerTimeoutError,
            )
        ),
        wait=wait_exponential(multiplier=RETRY_WAIT_MULTIPLIER),
    )
    async def create_task(self, retrier, metadata: Optional[dict] = None) -> dict:
        """Request a new task of ``self.task_type`` for ``self.executor_id``.

        Sends a ``POST`` to ``{server_url}/tasks/`` and expects status 201.
        The request is executed through ``retrier``, which retries on the
        exception types listed in the decorator above.

        Args:
            retrier: Retrier injected by ``with_retrier``; callers do not
                pass it.
            metadata: Optional metadata sent along as ``metadata``.

        Returns:
            The response decoded by ``TaskDetailsProtoSchema.from_bytes``.
        """
        url = f"{self.server_url}/tasks/"

        payload = {
            "type_name": self.task_type,
            "executor_id": self.executor_id,
            "metadata": metadata,
        }
        data = CreateTaskInputProtoSchema().to_bytes(payload)

        got = await retrier(self._request, "POST", url, data, [201])

        return TaskDetailsProtoSchema().from_bytes(got)

    # Same retry policy as ``create_task`` except that ``Conflict`` is also
    # treated as retryable here.
    @with_retrier(
        stop=stop_after_attempt(RETRY_MAX_ATTEMPTS),
        retry=retry_if_exception_type(
            (
                BadGateway,
                Conflict,
                GatewayTimeout,
                InternalServerError,
                aiohttp.ServerDisconnectedError,
                aiohttp.ServerTimeoutError,
            )
        ),
        wait=wait_exponential(multiplier=RETRY_WAIT_MULTIPLIER),
    )
    async def update_task(
        self, retrier, task_id: int, status: str, metadata: Optional[dict] = None
    ) -> dict:
        """Report a new status for an existing task.

        Sends a ``PUT`` to ``{server_url}/tasks/`` and accepts status 204 or
        200. The request is executed through ``retrier``, which retries on
        the exception types listed in the decorator above.

        Args:
            retrier: Retrier injected by ``with_retrier``; callers do not
                pass it.
            task_id: Identifier of the task being updated.
            status: New status value, sent as ``status``.
            metadata: Optional metadata sent along as ``metadata``.

        Returns:
            The response decoded by ``UpdateTaskOutputProtoSchema.from_bytes``
            (presumably a dict, like the value returned by ``create_task`` —
            verify against the schema).
        """
        url = f"{self.server_url}/tasks/"

        payload = {
            "type_name": self.task_type,
            "task_id": task_id,
            "executor_id": self.executor_id,
            "status": status,
            "metadata": metadata,
        }

        data = UpdateTaskInputProtoSchema().to_bytes(payload)

        got = await retrier(self._request, "PUT", url, data, [204, 200])

        return UpdateTaskOutputProtoSchema().from_bytes(got)

    async def _request(
        self, method: str, url: str, payload: bytes, expected_statuses: List[int]
    ):
        """Perform one HTTP request and return the response content.

        A new ``aiohttp.ClientSession`` is opened for the request and closed
        when it finishes.

        Args:
            method: HTTP method name.
            url: Full request URL.
            payload: Serialized request body, passed to aiohttp as ``data``.
            expected_statuses: Statuses treated as success.

        Returns:
            Whatever ``_extract_response_content`` yields for the response.

        Raises:
            The exceptions documented on ``_process_bad_response`` when the
            status is not in ``expected_statuses``.
        """
        async with aiohttp.ClientSession() as session:
            async with session.request(method, url, data=payload) as resp:
                await self._check_response(resp, expected_statuses)
                return await self._extract_response_content(resp)

    @staticmethod
    async def _extract_response_content(response):
        """Return parsed JSON for JSON responses, otherwise the raw body.

        A response counts as JSON when its ``Content-Type`` header contains
        ``application/json``; a missing header is treated as non-JSON.
        """
        if "application/json" in response.headers.get("Content-Type", ""):
            return await response.json()
        else:
            return await response.content.read()

    @classmethod
    async def _check_response(
        cls, resp: aiohttp.ClientResponse, expected_statuses: List[int]
    ):
        """Raise via ``_process_bad_response`` unless the status is expected."""
        if resp.status not in expected_statuses:
            await cls._process_bad_response(resp)

    @classmethod
    async def _process_bad_response(cls, resp: aiohttp.ClientResponse):
        """Translate an unexpected response into an exception; always raises.

        Raises:
            InternalServerError: Status 500.
            BadGateway: Status 502.
            GatewayTimeout: Status 504.
            UnknownResponse: Any status not handled here; carries the status
                and the raw body.
            Other: For 400, 403, 404 and 409 the body is handed to
                ``_process_proto_error_response``, which raises.
        """
        if resp.status == 500:
            raise InternalServerError()
        elif resp.status == 502:
            raise BadGateway()
        elif resp.status == 504:
            raise GatewayTimeout()
        elif resp.status in (400, 403, 404, 409):
            # These statuses are expected to carry a protobuf ``Error`` body.
            await cls._process_proto_error_response(resp)
        else:
            raise UnknownResponse(resp.status, await resp.read())

    @classmethod
    async def _process_proto_error_response(cls, resp: aiohttp.ClientResponse):
        """Decode a protobuf ``Error`` body and raise the matching exception.

        Raises:
            UnknownResponseBody: The body cannot be decoded as ``Error``.
            UnknownError: The decoded code has no entry in
                ``MAP_PROTO_ERROR_TO_EXCEPTION``.
            Other: The exception built by the mapped factory for the code.
        """
        response_body = await resp.read()
        try:
            error = Error.FromString(response_body)
        except google.protobuf.message.DecodeError:
            raise UnknownResponseBody(resp.status, response_body)

        if error.code in MAP_PROTO_ERROR_TO_EXCEPTION:
            raise MAP_PROTO_ERROR_TO_EXCEPTION[error.code](error)
        else:
            raise UnknownError(resp.status, error)


class ClientWithContextManager:
    """Async context manager that wraps a ``Client`` around one task.

    Entering the context creates a task through the client. Leaving it
    reports the task as failed when an exception escaped the body, and as
    completed otherwise — unless the task was already finalized through
    ``failed()`` / ``completed()``, in which case nothing is sent.

    After ``failed()`` or ``completed()`` succeeds, both the client reference
    and the task info are dropped.
    """

    __slots__ = "_client", "_task_info"

    def __init__(self, client: Client):
        self._client = client
        self._task_info = None

    def _clean(self):
        """Forget the client and the current task details."""
        self._client = self._task_info = None

    @property
    def time_limit(self) -> int:
        """``time_limit`` of the current task (requires task info to be set)."""
        return self._task_info["time_limit"]

    @property
    def uid(self) -> Optional[int]:
        """``task_id`` of the current task, or ``None`` without task info."""
        return self._task_info["task_id"] if self._task_info else None

    @property
    def status(self) -> Optional[int]:
        """``status`` of the current task, or ``None`` without task info."""
        return self._task_info["status"] if self._task_info else None

    @property
    def metadata(self) -> Optional[dict]:
        """``metadata`` of the current task if present, otherwise ``None``."""
        return self._task_info.get("metadata") if self._task_info else None

    @property
    def executor_id(self) -> str:
        """Executor id of the wrapped client."""
        return self._client.executor_id

    async def update_status(self, status: str, metadata: Optional[dict] = None):
        """Send a status update for the current task via the client.

        Raises:
            UnknownTaskOrType: There is no current task id.
        """
        task_id = self.uid
        if not task_id:
            raise UnknownTaskOrType()

        return await self._client.update_task(
            task_id=task_id, status=status, metadata=metadata
        )

    async def failed(self, metadata: Optional[dict] = None):
        """Report the task as ``FAILED`` and then reset this object."""
        await self.update_status(FAILED, metadata)
        self._clean()

    async def completed(self, metadata: Optional[dict] = None):
        """Report the task as ``COMPLETED`` and then reset this object."""
        await self.update_status(COMPLETED, metadata)
        self._clean()

    async def __aenter__(self):
        """Create a task through the client and remember its details."""
        self._task_info = await self._client.create_task()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Finalize the task, if one is still tracked, based on the outcome."""
        if not self.uid:
            # Nothing is tracked (e.g. already finalized inside the body).
            return

        if exc_type:
            failure_details = {
                "exception_type": exc_type.__name__,
                "exception_value": str(exc_val),
            }
            await self.failed(failure_details)
        else:
            await self.completed()


class ClientFactory:
    """Factory producing clients that all target the same server URL."""

    __slots__ = "_server_url"

    # Class used by ``client()``; kept as a class attribute so it can be
    # swapped on a subclass.
    _client_class = Client

    def __init__(self, server_url: str):
        self._server_url = server_url

    def client(self, task_type: str) -> Client:
        """Create a client for ``task_type`` with a new random executor id.

        The executor id is the string form of a freshly generated UUID4.
        """
        fresh_executor_id = str(uuid.uuid4())
        return self._client_class(
            server_url=self._server_url,
            executor_id=fresh_executor_id,
            task_type=task_type,
        )

    def context(self, task_type: str) -> ClientWithContextManager:
        """Create a new client for ``task_type`` wrapped in a context manager."""
        new_client = self.client(task_type)
        return ClientWithContextManager(new_client)
