import asyncio
import os
import random
import time
from abc import ABCMeta
from base64 import b64encode
from enum import Enum, Flag, auto, unique
from itertools import chain
from json import JSONDecodeError
from types import AsyncGeneratorType, GeneratorType
from typing import Any, ClassVar, Collection, Dict, Generic, Iterable, Mapping, Optional, Tuple, Type, TypeVar, Union

from aiohttp import ClientResponse, ClientSession, ClientTimeout, ContentTypeError, TCPConnector
from multidict import CIMultiDict

from sendr_interactions.deadline import Deadline
from sendr_interactions.exceptions import InteractionResponseError
from sendr_interactions.retry_budget import RetryBudgetProtocol, UnlimitedRetryBudget
from sendr_qlog import LoggerContext
from sendr_tvm.client import aiohttp as tvm_session
from sendr_utils import enum_value
from sendr_utils.requests import extract_correlation_headers
from sendr_writers.base.pusher import CommonPushers, InteractionResponseLog

# Generic parameter of AbstractInteractionClient: the type its request methods
# (get/post/put/patch/delete) return after _format_response.
ResponseType = TypeVar('ResponseType')


class LogFlag(Flag):
    """Bit mask selecting when an interaction client logs response bodies.

    Members are combined with ``|`` and tested with ``&``.
    """

    NEVER = 0  # no bits set: bodies are never logged
    ON_ERROR = 1  # log bodies of unsuccessful responses
    ON_SUCCESS = 2  # log bodies of successful responses
    ALWAYS = ON_ERROR | ON_SUCCESS  # both of the above


@unique
class ResponseFormat(Enum):
    """Representation in which an interaction client returns a successful response."""

    JSON = 'json'  # body parsed as JSON
    TEXT = 'text'  # body decoded to str
    BYTES = 'bytes'  # raw body bytes
    ORIGINAL = 'original'  # the response object itself, unconsumed


class AbstractInteractionClient(Generic[ResponseType], metaclass=ABCMeta):
    """Base class for HTTP clients of a single external service.

    Subclasses configure the client through the class attributes below and issue
    requests through :meth:`get`, :meth:`post`, :meth:`put`, :meth:`patch` and
    :meth:`delete`. Every request is retried according to ``REQUEST_RETRY_TIMEOUTS``
    (subject to the retry budget), logged, optionally pushed to
    ``pushers.response_log`` and converted according to ``RESPONSE_FORMAT``.
    Responses with status >= 400 are turned into :class:`InteractionResponseError`.
    """

    # --- Subclass configuration. Attributes without a default must be set by subclasses. ---

    # When True, request parameters and processed responses are logged as well.
    DEBUG: ClassVar[bool]
    # Connector used by every session of the class; sessions are created with
    # connector_owner=False, so closing a session does not close it.
    CONNECTOR: ClassVar[TCPConnector]
    # Target delays (seconds) before each retry; the first attempt is always made immediately.
    REQUEST_RETRY_TIMEOUTS: ClassVar[Iterable[float]]

    # Total / connect timeouts in seconds; a default ClientTimeout is built only if REQUEST_TIMEOUT is set.
    REQUEST_TIMEOUT: ClassVar[Optional[float]] = None
    CONNECT_TIMEOUT: ClassVar[Optional[float]] = None
    # How _format_response converts a successful response.
    RESPONSE_FORMAT: ClassVar[ResponseFormat] = ResponseFormat.JSON

    # Subtracted from the timeout advertised in the X-Request-Timeout header.
    RTT_ESTIMATE: ClassVar[float] = 0.03  # 30ms

    SERVICE: ClassVar[str]
    BASE_URL: ClassVar[str]
    # Default TVM destination id; whenever an id is in effect, sessions are created with TVM_SESSION_CLS.
    TVM_ID: ClassVar[Optional[int]] = None
    TVM_SESSION_CLS: ClassVar[Optional[Type[tvm_session.TvmSession]]] = None

    # Both tuples must contain lowercase names (checked in __init__).
    # Top-level keys of dict-valued request kwargs that are dropped from "safe" request logs.
    LOGGING_SENSITIVE_FIELDS: ClassVar[Tuple[str, ...]] = ()
    # Headers whose values are logged as-is; values of all other headers are blanked.
    LOGGING_EXPOSED_HEADERS: ClassVar[Tuple[str, ...]] = ()

    # When True, correlation headers are attached to raised errors and to logs.
    LOG_CORRELATION_HEADERS: ClassVar[bool] = False
    # When to log response bodies: ON_ERROR for status >= 400, ON_SUCCESS for processed responses.
    LOG_RESPONSE_BODY: ClassVar[LogFlag] = LogFlag.NEVER

    logger: LoggerContext
    request_id: str
    tvm_id: Optional[int]

    # Created lazily by the ``session`` property and reset by ``close``.
    _session: Optional[ClientSession] = None

    def __init__(
        self,
        logger: LoggerContext,
        request_id: str,
        pushers: Optional[CommonPushers] = None,
        tvm_id: Optional[int] = None,
        retry_budget: Optional[RetryBudgetProtocol] = None
    ):
        """
        :param logger: logger context; ``service`` is pushed into it immediately.
        :param request_id: sent as ``X-Request-Id`` with every request when non-empty.
        :param pushers: optional pushers; ``pushers.response_log`` (if set) receives
            InteractionResponseLog records for requests made with ``response_log=True``.
        :param tvm_id: TVM destination id; overrides ``TVM_ID`` when not None.
        :param retry_budget: limits retries per service; unlimited when omitted.
        """
        self.logger = logger
        self.logger.context_push(service=self.SERVICE)
        self.request_id = request_id
        self.pushers = pushers
        self.tvm_id = tvm_id if tvm_id is not None else self.TVM_ID
        self.retry_budget = retry_budget or UnlimitedRetryBudget()

        self.default_timeout: Optional[ClientTimeout] = None
        if self.REQUEST_TIMEOUT:
            self.default_timeout = ClientTimeout(total=self.REQUEST_TIMEOUT, connect=self.CONNECT_TIMEOUT)

        # Lookups lowercase the candidate name, so the configured names must already be lowercase.
        assert all(value.lower() == value for value in self.LOGGING_EXPOSED_HEADERS)
        assert all(value.lower() == value for value in self.LOGGING_SENSITIVE_FIELDS)

    @property
    def ip(self):
        """Host IP from the ``QLOUD_IPV6`` env var without the ``/prefix`` part; ``127.0.0.1`` if unset."""
        # Deploy has no suitable direct equivalent; where to find one is discussed in #SWATTOOLS-43.
        return os.environ.get('QLOUD_IPV6', '127.0.0.1').split('/')[0]

    def _get_session_cls(self) -> Type[ClientSession]:
        """Returns ClientSession if no tvm_id specified and TvmSession otherwise.

        :raises RuntimeError: if a TVM id is in effect but ``TVM_SESSION_CLS`` is not set.
        """
        if self.tvm_id is None:
            return ClientSession
        if self.TVM_SESSION_CLS is None:
            raise RuntimeError('TVM_ID specified, but TVM_SESSION_CLS is None')
        return self.TVM_SESSION_CLS

    def _get_timeout(self, deadline: Optional[Deadline]) -> Optional[ClientTimeout]:
        """Timeout for the next attempt.

        Without a deadline this is the default timeout. With a deadline, the default
        timeout is used if it expires no later than the deadline; otherwise a timeout
        ending exactly at the deadline is built.

        :raises asyncio.TimeoutError: if the deadline has already been reached.
        """
        if deadline is None:
            return self.default_timeout

        if (seconds_to := deadline.seconds_to()) > .0:
            if self.default_timeout and self.default_timeout.total <= seconds_to:
                return self.default_timeout

            return ClientTimeout(total=seconds_to, connect=self.CONNECT_TIMEOUT)

        raise asyncio.TimeoutError('Request deadline reached')

    def _get_session_kwargs(self) -> Dict[str, Any]:
        """Returns kwargs necessary for creating a session instance."""
        kwargs: Dict[str, Any] = {
            'connector': self.CONNECTOR,
            # The connector is shared at class level and must outlive individual sessions.
            'connector_owner': False,
            'trust_env': True,
        }
        if self.tvm_id is not None:
            # Use the resolved per-instance id (constructor argument or TVM_ID), the same
            # value _get_session_cls checks; the class attribute alone ignores the override.
            kwargs['tvm_dst'] = self.tvm_id
        if self.default_timeout:
            kwargs['timeout'] = self.default_timeout
        return kwargs

    @property
    def session(self) -> ClientSession:
        """Session of this client, created on first access."""
        if self._session is None:
            self._session = self.create_session()
        return self._session

    def create_session(self) -> ClientSession:
        """Build a new session of the class chosen by ``_get_session_cls``."""
        session_cls = self._get_session_cls()
        kwargs = self._get_session_kwargs()
        return session_cls(**kwargs)

    @staticmethod
    async def _get_response_body(response: ClientResponse) -> Dict[str, Any]:
        """Read the response body in the most structured representation available.

        :returns: ``{'response': <parsed JSON>}`` if the body parses as JSON, otherwise
            ``{'response_text': <text>}`` if it decodes as text, otherwise
            ``{'response_base64': <base64 of the raw bytes>}``.
        """
        try:
            return {'response': await response.json()}
        except (JSONDecodeError, UnicodeDecodeError, ContentTypeError):
            # aiohttp decodes the body before parsing JSON, so a JSON content type with
            # undecodable bytes raises UnicodeDecodeError; fall through to the next
            # representation instead of failing the logging helper.
            pass

        try:
            return {'response_text': await response.text()}
        except UnicodeDecodeError:
            pass

        return {'response_base64': b64encode(await response.read()).decode('utf-8')}

    async def _try_log_error_response_body(self, response: ClientResponse) -> None:
        """Log the body of an unsuccessful response if ``LOG_RESPONSE_BODY`` includes ON_ERROR."""
        if self.LOG_RESPONSE_BODY & LogFlag.ON_ERROR:
            with self.logger:
                self.logger.context_push(**await self._get_response_body(response))
                self.logger.error('Unsuccessful interaction response body logged')

    async def _try_log_success_response_body(self, response: ClientResponse) -> None:
        """Log the body of a successful response if ``LOG_RESPONSE_BODY`` includes ON_SUCCESS."""
        if self.LOG_RESPONSE_BODY & LogFlag.ON_SUCCESS:
            with self.logger:
                self.logger.context_push(client=self.__class__.__name__, **await self._get_response_body(response))
                self.logger.info('Successful interaction response body logged')

    async def _handle_response_error(self, response: ClientResponse) -> None:
        """Convert an error response into an exception.

        :raises InteractionResponseError: always; carries the status, method, service and,
            when ``LOG_CORRELATION_HEADERS`` is set, the correlation headers.
        """
        params: Dict[str, Any] = {}

        if self.LOG_CORRELATION_HEADERS:
            params['headers'] = extract_correlation_headers(response.headers)

        await self._try_log_error_response_body(response)
        raise InteractionResponseError(
            status_code=response.status,
            method=response.method,
            service=self.SERVICE,
            params=params or None,
        )

    async def _format_response(self, response: ClientResponse) -> ResponseType:
        """Convert a successful response according to ``RESPONSE_FORMAT``.

        :raises NotImplementedError: for a format without a conversion.
        """
        if self.RESPONSE_FORMAT == ResponseFormat.JSON:
            return await response.json()
        if self.RESPONSE_FORMAT == ResponseFormat.TEXT:
            return await response.text()  # type: ignore
        if self.RESPONSE_FORMAT == ResponseFormat.BYTES:
            return await response.read()  # type: ignore
        if self.RESPONSE_FORMAT == ResponseFormat.ORIGINAL:
            return response  # type: ignore
        raise NotImplementedError(f'Response format {enum_value(self.RESPONSE_FORMAT)} not implemented')

    async def _process_response(self, response: ClientResponse, interaction_method: str) -> ResponseType:
        """Raise for status >= 400, otherwise return the formatted response."""
        if response.status >= 400:
            await self._handle_response_error(response)
        return await self._format_response(response)

    @classmethod
    def _response_time_metrics(cls, interaction_method: str, response_time: float) -> None:
        """Hook called after every attempt with its duration in seconds; no-op by default."""
        pass

    @classmethod
    def _response_status_metrics(cls, interaction_method: str, status: int) -> None:
        """Hook called after every attempt with its status (599 if no response); no-op by default."""
        pass

    async def _is_failed(
        self,
        response: ClientResponse,
    ) -> bool:
        """Whether a received response counts as a failed attempt (status >= 500 by default)."""
        if response.status >= 500:  # type: ignore
            return True

        return False

    async def _should_retry_failed_request(
        self,
        interaction_method: str,
        exc: Optional[Exception] = None,
        response: Optional[ClientResponse] = None,
    ) -> bool:
        """Decide whether a failed attempt may be retried.

        The base implementation only consults the retry budget for ``SERVICE``;
        ``exc`` / ``response`` describe the failed attempt for overriding subclasses.
        """
        can_retry = self.retry_budget.can_retry(self.SERVICE)
        if not can_retry:
            self.logger.info('Retry  budget for "%s" spent', self.SERVICE)
        return can_retry

    async def _make_request(
        self,
        interaction_method: str,
        method: str,
        url: str,
        deadline: Optional[Deadline] = None,
        user_ticket: Optional[str] = None,
        **kwargs: Any
    ) -> ClientResponse:
        """Wraps ClientSession.request allowing retries, updating metrics.

        The first attempt is made immediately; one retry follows per entry of
        ``REQUEST_RETRY_TIMEOUTS`` while an attempt fails (raised, or ``_is_failed``
        returned True), the failure is not a timeout and
        ``_should_retry_failed_request`` allows it.

        :param interaction_method: logical name of the call, used in logs and metrics.
        :param method: HTTP method.
        :param url: absolute request URL.
        :param deadline: optional overall deadline capping each attempt's timeout.
        :param user_ticket: optional value of the ``x-ya-user-ticket`` header.
        :param kwargs: forwarded to ``ClientSession.request``; ``headers`` is copied, not mutated.
        :returns: the response of the last attempt (its status may still be >= 500).
        :raises asyncio.TimeoutError: if the deadline is reached before an attempt.
        :raises Exception: whatever the last attempt raised.
        """
        self.logger.context_push(interaction_method=interaction_method)

        # Work on a copy: a headers dict supplied (and possibly reused) by the caller must not
        # accumulate the headers added below. ``or {}`` also tolerates an explicit headers=None.
        headers: Dict[str, Any] = dict(kwargs.get('headers') or {})
        if self.request_id:
            headers['X-Request-Id'] = self.request_id
        if user_ticket:
            headers['x-ya-user-ticket'] = user_ticket
        kwargs['headers'] = headers

        response_time = 0.0
        response: Optional[ClientResponse] = None
        exc: Optional[Exception] = None
        for retry_number, retry_delay in enumerate(chain((0.0,), self.REQUEST_RETRY_TIMEOUTS)):
            if retry_delay:
                # Wait the configured delay minus the time the previous attempt already took,
                # with +/-50% jitter; a non-positive value makes asyncio.sleep return at once.
                delay = retry_delay - response_time
                await asyncio.sleep(delay + random.uniform(-delay / 2, delay / 2))
            self.logger.context_push(retry_number=retry_number)

            exc = None
            response = None
            before = time.monotonic()
            try:
                if (timeout := self._get_timeout(deadline)) is not None and timeout.total is not None:
                    kwargs['timeout'] = timeout
                    # Advertise the remaining timeout in milliseconds, reduced by RTT_ESTIMATE.
                    headers['X-Request-Timeout'] = str(int(max(0, timeout.total - self.RTT_ESTIMATE) * 1000))
                    self.logger.context_push(request_timeout=timeout.total)

                response = await self.session.request(method, url, **kwargs)

                assert response is not None
                success = not await self._is_failed(response)
            except Exception as e:
                exc = e
                success = False

            response_time = time.monotonic() - before
            response_status = response.status if response else 599
            self.logger.context_push(
                client=self.__class__.__name__,
                method=interaction_method,
                response_time=response_time,
                response_status=response_status,
                response_headers=self._scrub_logging_headers(response.headers) if response else None,
                exc=str(exc),
            )

            self._response_time_metrics(interaction_method, response_time)
            self._response_status_metrics(interaction_method, response_status)

            if success:
                self.logger.info('Interaction client request success')
                self.retry_budget.success(self.SERVICE)
            else:
                self.logger.warning('Interaction client request failed')
                self.retry_budget.fail(self.SERVICE)

            # Timeouts (including a reached deadline) are never retried.
            if (
                success
                or isinstance(exc, asyncio.TimeoutError)
                or not await self._should_retry_failed_request(interaction_method, exc, response)
            ):
                break

        if exc:
            raise exc

        return response  # type: ignore

    @classmethod
    def _scrub_logging_request_kwargs_value(cls, value: Any) -> Any:
        """Replace generators and bytes with short placeholders; return other values unchanged."""
        if isinstance(value, AsyncGeneratorType):
            return '<AsyncGenerator>'
        elif isinstance(value, GeneratorType):
            return '<Generator>'
        elif isinstance(value, bytes):
            return f'<Bytes: len {len(value)}>'
        else:
            return value

    @classmethod
    def _scrub_logging_headers(cls, headers: Mapping[str, Any]) -> CIMultiDict[str]:
        """Copy of ``headers`` safe for logging.

        Every key is kept, but values of headers not listed in ``LOGGING_EXPOSED_HEADERS``
        are replaced with ``''``. With ``LOG_CORRELATION_HEADERS`` the correlation headers
        are put back with their real values.
        """
        filtered: CIMultiDict[str] = CIMultiDict()
        for key, value in headers.items():
            if key.lower() not in cls.LOGGING_EXPOSED_HEADERS:
                value = ''
            filtered.add(key, value)
        if cls.LOG_CORRELATION_HEADERS:
            filtered.update(extract_correlation_headers(headers))
        return filtered

    @classmethod
    def _scrub_logging_request_kwargs(cls, request_kwargs: dict, safe: bool = True) -> dict:
        """Copy of request kwargs suitable for logging.

        ``headers`` is scrubbed with ``_scrub_logging_headers`` and omitted when empty.
        For other dict values, top-level keys listed in ``LOGGING_SENSITIVE_FIELDS`` are
        dropped when ``safe`` is True (nested values are not inspected). Generators and
        bytes are replaced with placeholders.
        """
        scrubbed_kwargs = {}
        for key, value in request_kwargs.items():
            if isinstance(value, dict):
                if key == 'headers':
                    value = dict(cls._scrub_logging_headers(value))
                    if not value:
                        continue
                else:
                    value = {
                        k: cls._scrub_logging_request_kwargs_value(v)
                        for k, v in value.items() if not safe or k.lower() not in cls.LOGGING_SENSITIVE_FIELDS
                    }
            else:
                value = cls._scrub_logging_request_kwargs_value(value)
            scrubbed_kwargs[key] = value
        return scrubbed_kwargs

    async def _get_body(self, response: ResponseType) -> Union[str, bytes, ResponseType]:
        """Body to push to the response log.

        For a raw ClientResponse (``ResponseFormat.ORIGINAL``) the body is read as UTF-8
        text, falling back to raw bytes; any other processed value is returned as-is.
        """
        response_data: Union[str, bytes, ResponseType]
        if isinstance(response, ClientResponse):
            try:
                response_data = await response.text('utf-8')
            except ValueError:
                response_data = await response.read()
        else:
            response_data = response
        return response_data

    async def _request(  # noqa: C901
        self,
        interaction_method: str,
        method: str,
        url: str,
        response_log: bool = True,
        response_log_body: bool = True,
        **kwargs: Any,
    ) -> ResponseType:
        """Perform a request, process the response and push it to the response log.

        :param interaction_method: logical name of the call, used in logs and metrics.
        :param method: HTTP method.
        :param url: absolute request URL.
        :param response_log: whether to push the response to pushers
        :param response_log_body: whether to log the response body (if response_log = True)
        :param kwargs: forwarded to ``_make_request`` (e.g. ``deadline``, ``user_ticket``)
            and from there to ``ClientSession.request``.
        :returns: the response converted according to ``RESPONSE_FORMAT``.
        :raises InteractionResponseError: if the response status is >= 400.
        """

        with self.logger:
            if self.DEBUG:
                self.logger.context_push(
                    method=method,
                    request_url=url,
                    request_kwargs=self._scrub_logging_request_kwargs(kwargs),
                )

            # TODO: add tests for pushers PAYBACK-614
            response_log_enabled = self.pushers is not None and self.pushers.response_log is not None
            response = await self._make_request(interaction_method, method, url, **kwargs)

            try:
                processed = await self._process_response(response, interaction_method)
            except Exception as exc:
                if response_log_enabled and response_log:
                    # Best effort: a failing push is logged and must not replace the original error.
                    try:
                        assert self.pushers is not None
                        await self.pushers.response_log.push(
                            InteractionResponseLog(
                                response=await response.read(),
                                response_headers=self._scrub_logging_headers(response.headers),
                                request_id=self.request_id,
                                request_url=url,
                                request_method=method,
                                request_kwargs=self._scrub_logging_request_kwargs(kwargs, safe=False),
                                status=response.status,
                                exception_type=type(exc).__name__,
                                exception_message=str(exc)
                            )
                        )
                    except Exception:
                        self.logger.exception('Push to InteractionResponseLog failed')
                raise

            if response_log_enabled and response_log:
                response_data = await self._get_body(processed) if response_log_body else '<logging disabled>'
                try:
                    assert self.pushers is not None
                    await self.pushers.response_log.push(
                        InteractionResponseLog(
                            response=self._scrub_logging_request_kwargs_value(response_data),
                            response_headers=self._scrub_logging_headers(response.headers),
                            request_id=self.request_id,
                            request_url=url,
                            request_method=method,
                            request_kwargs=self._scrub_logging_request_kwargs(kwargs, safe=False),
                            status=response.status
                        )
                    )
                except Exception:
                    self.logger.exception('Push to InteractionResponseLog failed')

            if self.DEBUG:
                self.logger.context_push(response=processed)
                self.logger.debug('Interaction client processed response')

        # NOTE(review): this push happens after the ``with self.logger:`` block above has
        # exited; confirm it is meant to stay in the logger context beyond this request.
        if self.LOG_CORRELATION_HEADERS:
            self.logger.context_push(
                response_headers=extract_correlation_headers(response.headers)
            )

        await self._try_log_success_response_body(response)
        return processed

    # Convenience wrappers around _request, one per HTTP method.

    async def get(self, interaction_method: str, url: str, **kwargs: Any) -> ResponseType:
        return await self._request(interaction_method, 'GET', url, **kwargs)

    async def post(self, interaction_method: str, url: str, **kwargs: Any) -> ResponseType:
        return await self._request(interaction_method, 'POST', url, **kwargs)

    async def put(self, interaction_method: str, url: str, **kwargs: Any) -> ResponseType:
        return await self._request(interaction_method, 'PUT', url, **kwargs)

    async def patch(self, interaction_method: str, url: str, **kwargs: Any) -> ResponseType:
        return await self._request(interaction_method, 'PATCH', url, **kwargs)

    async def delete(self, interaction_method: str, url: str, **kwargs: Any) -> ResponseType:
        return await self._request(interaction_method, 'DELETE', url, **kwargs)

    async def close(self) -> None:
        """Close the session if one was created; the next access to ``session`` creates a new one."""
        if self._session:
            await self._session.close()
            self._session = None

    def endpoint_url(self, relative_url: str, base_url_override: Optional[str] = None) -> str:
        """Join ``BASE_URL`` (or ``base_url_override``) and ``relative_url`` with exactly one ``/``."""
        base_url = (base_url_override or self.BASE_URL).rstrip('/')
        relative_url = relative_url.lstrip('/')
        return f'{base_url}/{relative_url}'

    @staticmethod
    def assert_string_urlsafe_for_path(string: str, custom_whitelist: Collection[str] = '') -> None:
        """
        Make sure this string can be inserted into the path of a URL.

        :param string: value to be inserted as a path segment.
        :param custom_whitelist: if non-empty, every character of ``string`` must belong to it.
        :raises AssertionError: if ``string`` is ``'..'``, contains ``'/'`` or has characters
            outside a non-empty ``custom_whitelist``.
        """
        # Explicit raises instead of ``assert`` statements: this safety check must keep working
        # under ``python -O``, which strips asserts. AssertionError is kept for existing callers.
        if string == '..':
            raise AssertionError('Path segment must not be ".."')
        if '/' in string:
            raise AssertionError('Path segment must not contain "/"')
        if custom_whitelist and not all(c in custom_whitelist for c in string):
            raise AssertionError('Path segment contains characters outside the whitelist')
