import aiohttp
import urllib.parse
from contextlib import AbstractAsyncContextManager, nullcontext
from tenacity import (
    AsyncRetrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)
from typing import Optional

from .exceptions import (
    BadGateway,
    ServiceUnavailable,
)


# Total attempts per retried request, first try included; fed to tenacity's
# ``stop_after_attempt`` by ``Client._retryer``.
REQUEST_MAX_ATTEMPTS: int = 5
# Scale factor for tenacity's ``wait_exponential`` back-off between retries
# in ``Client._retryer``.
RETRY_WAIT_MULTIPLIER: float = 0.1


class Client(AbstractAsyncContextManager):
    """Asynchronous HTTP client bound to a single base URL.

    Inside ``async with Client(url) as client:`` all requests share one
    :class:`aiohttp.ClientSession`. Outside such a block, each request opens
    its own session and closes it when that request finishes.
    """

    __slots__ = "_session", "url"

    # Session shared by requests; set by __aenter__, cleared by __aexit__.
    _session: Optional[aiohttp.ClientSession]

    # Base URL that request URIs are resolved against (see _request).
    url: str

    def __init__(self, url: str):
        """Create a client for *url*; no session is opened until entered."""
        self._session = None
        self.url = url

    async def __aenter__(self) -> "Client":
        """Open the shared session and return this client."""
        # NOTE(review): entering again while a session is already open
        # replaces it without closing the previous one.
        self._session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, *args, **kwargs):
        """Close the shared session, if one is open.

        Returns ``None``, so exceptions raised in the ``async with`` body
        propagate unchanged.
        """
        # Detach first so the client ends up session-less even if close()
        # raises. Also tolerate an exit without a matching enter (or a second
        # exit), which previously failed with AttributeError on None.close().
        session, self._session = self._session, None
        if session is not None:
            await session.close()

    async def _request(
        self, method: str, uri: str, *, expected_status: int, **kwargs
    ) -> Optional[bytes]:
        """Send one HTTP request and return the response body.

        :param method: HTTP method passed to ``session.request``.
        :param uri: resolved against :attr:`url` with
            :func:`urllib.parse.urljoin`. If ``url`` has no trailing slash,
            urljoin replaces its last path segment.
        :param expected_status: the only status code treated as success.
        :param kwargs: forwarded unchanged to ``session.request``.
        :returns: the full body when the status equals *expected_status*.
            Returns ``None`` if the status differs and ``_check_response``
            returns without raising.
        :raises BadGateway: on HTTP 502.
        :raises ServiceUnavailable: on HTTP 503.
        """
        url = urllib.parse.urljoin(self.url, uri)

        # Reuse the shared session inside ``async with Client(...)``.
        # Otherwise open a one-off session, which the ``async with`` below
        # closes when the request is done. ``nullcontext`` supports
        # ``async with`` from Python 3.10.
        context = (
            nullcontext(self._session)
            if self._session is not None
            else aiohttp.ClientSession()
        )
        async with context as session, session.request(
            method, url, **kwargs
        ) as response:
            if response.status == expected_status:
                return await response.content.read()

            # 502/503 map to dedicated exception types so the retry policy in
            # _retryer can recognise them as transient.
            transient_error = {
                502: BadGateway,
                503: ServiceUnavailable,
            }.get(response.status)
            if transient_error is not None:
                raise transient_error()

            # Every other status is delegated to _check_response, which is
            # defined outside this view.
            # NOTE(review): if it returns instead of raising, this method
            # returns None; the Optional return annotation reflects that.
            await self._check_response(response)
            return None

    @property
    def _retryer(self) -> AsyncRetrying:
        """Return a new tenacity ``AsyncRetrying`` policy on each access.

        The policy:
        - makes at most ``REQUEST_MAX_ATTEMPTS`` attempts in total;
        - waits between attempts with exponential back-off scaled by
          ``RETRY_WAIT_MULTIPLIER``;
        - retries only on ``BadGateway`` / ``ServiceUnavailable`` (raised by
          ``_request`` for 502/503) and on aiohttp server-disconnect or
          server-timeout errors.

        Other exceptions are not retried. ``reraise`` is not set, so once the
        attempts are exhausted tenacity raises ``tenacity.RetryError`` rather
        than the last underlying exception.
        """
        return AsyncRetrying(
            stop=stop_after_attempt(REQUEST_MAX_ATTEMPTS),
            retry=retry_if_exception_type(
                (
                    BadGateway,
                    ServiceUnavailable,
                    aiohttp.ServerDisconnectedError,
                    aiohttp.ServerTimeoutError,
                )
            ),
            wait=wait_exponential(multiplier=RETRY_WAIT_MULTIPLIER),
        )
