import asyncio
import urllib.parse
from functools import wraps

import aiohttp
from tenacity import (
    AsyncRetrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from .exceptions import (
    BadGateway,
    Conflict,
    GatewayTimeout,
    InternalServerError,
    InvalidContentType,
    UnknownResponse,
    WrongPayload,
)

# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["BaseClient"]


def async_shield(func):
    """Wrap a coroutine function so every call is protected by ``asyncio.shield``.

    The returned callable is a plain (non-``async``) function with the same
    metadata as *func*. Each call builds the coroutine and hands it to
    ``asyncio.shield``; the caller awaits the shielded awaitable that comes
    back. Cancelling the awaiting caller therefore does not cancel the inner
    operation.
    """

    @wraps(func)
    def shielded(*args, **kwargs):
        coro = func(*args, **kwargs)
        return asyncio.shield(coro)

    return shielded


class BaseClient:
    """Base class for async JSON-over-HTTP API clients.

    All requests go through one ``aiohttp.ClientSession``. A request is
    retried, with exponential back-off, when the server answers 500, 502 or
    504, or when aiohttp raises ``ServerDisconnectedError`` or
    ``ServerTimeoutError``.

    Release the session with :meth:`close`, or use the client as an async
    context manager::

        async with SomeClient(url, retry_settings) as client:
            ...
    """

    __slots__ = "url", "_session", "_retryer"

    url: str  # base URL that request URIs are resolved against (see _url)
    _session: aiohttp.ClientSession  # shared HTTP session, closed by close()
    _retryer: AsyncRetrying  # retry policy applied to every _request call

    def __init__(self, url: str, retry_settings: dict):
        """Build the retry policy and open the HTTP session.

        Args:
            url: Base URL of the service.
            retry_settings: Mapping that must contain:
                - ``"max_attempts"``: passed to ``stop_after_attempt``.
                - ``"wait_multiplier"``: passed to ``wait_exponential``.

        Raises:
            KeyError: if a required key is missing from *retry_settings*.
                The session is not yet created at that point, so nothing
                leaks.
        """
        self.url = url
        # The retry policy is built before the session: a bad retry_settings
        # mapping then fails before any resource that needs closing exists.
        self._retryer = AsyncRetrying(
            stop=stop_after_attempt(retry_settings["max_attempts"]),
            # Only these transient, server-side failures are retried. The
            # exceptions raised for 400/409/415 and for unknown statuses are
            # deliberately not listed.
            retry=retry_if_exception_type(
                (
                    InternalServerError,
                    BadGateway,
                    GatewayTimeout,
                    aiohttp.ServerDisconnectedError,
                    aiohttp.ServerTimeoutError,
                )
            ),
            wait=wait_exponential(multiplier=retry_settings["wait_multiplier"]),
        )
        # NOTE(review): aiohttp expects a ClientSession to be created while an
        # event loop is running; confirm this client is always constructed
        # from async code.
        self._session = aiohttp.ClientSession()

    def _url(self, uri: str) -> str:
        """Resolve *uri* against ``self.url`` with ``urllib.parse.urljoin``.

        Standard urljoin rules apply:
        - Unless ``self.url`` ends with ``/``, its last path segment is
          replaced.
        - A *uri* that starts with ``/`` replaces the whole base path.
        """
        return urllib.parse.urljoin(self.url, uri)

    async def _request(
        self, method: str, uri: str, expected_status: int, data: dict
    ) -> dict:
        """Send *data* as JSON and return the decoded JSON response body.

        The single attempt in :meth:`_request_inner` is repeated according
        to the retry policy built in ``__init__``.

        Args:
            method: HTTP method passed to ``aiohttp.ClientSession.request``.
            uri: Path resolved against ``self.url`` (see :meth:`_url`).
            expected_status: The only status code treated as success.
            data: Request body, sent via aiohttp's ``json=`` argument.

        Raises:
            WrongPayload, Conflict, InvalidContentType, UnknownResponse:
                non-retryable error statuses (see :meth:`_check_response`).
        """
        # NOTE(review): the retryer is not built with reraise=True, so when
        # every attempt fails tenacity presumably raises tenacity.RetryError
        # rather than the last server error. Confirm callers expect that.
        # NOTE(review): non-idempotent methods (e.g. POST) are retried too,
        # which may duplicate side effects after a server disconnect.
        # Calling the retrying object directly replaces the deprecated
        # ``AsyncRetrying.call()``.
        return await self._retryer(
            self._request_inner,
            method=method,
            uri=uri,
            expected_status=expected_status,
            data=data,
        )

    async def _request_inner(
        self, method: str, uri: str, expected_status: int, data: dict
    ) -> dict:
        """Perform one HTTP attempt, without retries, and decode the JSON reply."""
        async with self._session.request(method, self._url(uri), json=data) as response:
            await self._check_response(response, expected_status)
            return await response.json()

    @staticmethod
    async def _check_response(response, expected_status: int):
        """Raise the project exception that matches an unexpected status code.

        Returns ``None`` when ``response.status == expected_status``.
        Otherwise:
        - 400 / 409 / 415 raise WrongPayload / Conflict / InvalidContentType,
          each carrying the decoded JSON body.
        - 500 / 502 / 504 raise InternalServerError / BadGateway /
          GatewayTimeout, which carry no body. These are the statuses the
          retry policy retries.
        - Any other status raises UnknownResponse with the status code and
          the raw body read from ``response.content``.
        """
        if response.status == expected_status:
            return
        elif response.status == 400:
            raise WrongPayload(await response.json())
        elif response.status == 409:
            raise Conflict(await response.json())
        elif response.status == 415:
            raise InvalidContentType(await response.json())
        elif response.status == 500:
            raise InternalServerError()
        elif response.status == 502:
            raise BadGateway()
        elif response.status == 504:
            raise GatewayTimeout()
        else:
            raise UnknownResponse(
                status_code=response.status, payload=await response.content.read()
            )

    async def close(self):
        """Close the underlying aiohttp session."""
        await self._session.close()

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args, **kwargs):
        # Delegate to close() so a subclass that overrides it gets the same
        # cleanup on context-manager exit as on an explicit close().
        await self.close()
