from typing import Any, ClassVar, Iterable, Optional

from aiohttp import ClientResponse, ClientTimeout, TCPConnector
from aiosocksy.connector import ProxyClientRequest, ProxyConnector

import sendr_interactions
import sendr_qstats
from sendr_interactions.connector import sd
from sendr_tvm import qloud_async_tvm
from sendr_tvm.client.aiohttp import sessions_producer

from mail.payments.payments.conf import settings
from mail.payments.payments.utils.stats import (
    interaction_method_response_status, interaction_method_response_time, interaction_response_status,
    interaction_response_time
)

# Connection parameters for the TVM daemon, read from project settings.
TVM_CONFIG = {
    'client_name': settings.TVM_CLIENT,
    'host': settings.TVM_HOST,
    'port': settings.TVM_PORT,
}


def _create_tvm():
    """Build a new QTVM client from ``TVM_CONFIG``; handed to ``sessions_producer`` as ``get_tvm``."""
    return qloud_async_tvm.QTVM(**TVM_CONFIG)


# Session class used by interaction clients for TVM-authenticated requests.
PaymentsClientSession = sessions_producer(get_tvm=_create_tvm)


class AbstractInteractionClient(sendr_interactions.AbstractInteractionClient):
    """Base class for payments interaction clients.

    Adds on top of ``sendr_interactions.AbstractInteractionClient``:

    * TVM-authenticated sessions via ``PaymentsClientSession``;
    * a single connector shared by every subclass (see ``connector``);
    * optional SOCKS proxying when ``settings.SOCKS_PROXY`` is set;
    * per-service and per-method response time / status metrics.
    """

    TVM_SESSION_CLS = PaymentsClientSession
    # When true: TLS verification is disabled on the connector and response
    # bodies are logged for error responses.
    DEBUG = settings.DEBUG
    # Proxy URL; when set, aiosocksy's proxy connector/request classes are used
    # and every request is sent with ``proxy=PROXY``.
    PROXY = settings.SOCKS_PROXY
    # Retry delays from settings; presumably consumed by the base class — confirm there.
    REQUEST_RETRY_TIMEOUTS: ClassVar[Iterable[int]] = settings.REQUEST_RETRY_TIMEOUTS

    @property
    def connector(self) -> TCPConnector:
        """Return the shared connector, creating it on first access.

        The connector is stored on ``AbstractInteractionClient`` itself (not on
        the concrete subclass), so all clients reuse one connection pool. The
        value is read back through ``self``, so a subclass that defines its own
        ``CONNECTOR`` attribute gets that one instead.
        """
        if getattr(AbstractInteractionClient, 'CONNECTOR', None) is None:
            kwargs = {
                'keepalive_timeout': settings.KEEPALIVE_TIMEOUT,
                'limit': settings.CONNECTION_LIMIT,
            }
            if self.DEBUG:
                # Skip TLS certificate verification in debug environments.
                # NOTE(review): ``verify_ssl`` is deprecated in aiohttp 3.x in favour of
                # ``ssl=False``; kept as-is because the kwargs accepted by the wrapped
                # connector classes are not visible here.
                kwargs['verify_ssl'] = False

            connector_cls = ProxyConnector if self.PROXY else TCPConnector
            # Wrap the connector class with sendr_interactions' ``sd`` configured by
            # settings.SD_CONFIG (presumably service discovery — confirm).
            connector_cls = sd(settings.SD_CONFIG)(connector_cls)
            AbstractInteractionClient.CONNECTOR = connector_cls(**kwargs)

        assert self.CONNECTOR
        return self.CONNECTOR

    @classmethod
    async def close_connector(cls):
        """Close the shared connector and forget it so the next access recreates it.

        The attribute is cleared on the class that actually stores it (normally
        ``AbstractInteractionClient``, where ``connector`` creates it). Assigning
        ``cls.CONNECTOR = None`` on a subclass would only shadow the inherited
        value and leave the base pointing at a closed connector.
        """
        owner = next((klass for klass in cls.__mro__ if 'CONNECTOR' in vars(klass)), None)
        if owner is None:
            return
        connector = vars(owner)['CONNECTOR']
        if connector:
            # Drop the reference before awaiting close() so concurrent users of
            # ``connector`` build a fresh one instead of reusing a closing one.
            owner.CONNECTOR = None
            await connector.close()

    def _get_session_kwargs(self) -> dict:
        """Build keyword arguments for the client session.

        ``connector_owner=False`` stops sessions from closing the shared
        connector; ``close_connector`` is responsible for that.
        """
        kwargs = {
            'connector': self.connector,
            'connector_owner': False,
        }
        if self.tvm_id is not None:
            kwargs['tvm_dst'] = self.tvm_id
        kwargs['trust_env'] = True
        if self.PROXY:
            kwargs['request_class'] = ProxyClientRequest
        return kwargs

    async def _make_request(self, interaction_method: str, method: str, url: str, **kwargs: Any) -> ClientResponse:
        """Delegate to the base implementation, adding ``proxy=PROXY`` when a proxy is configured."""
        if self.PROXY:
            kwargs['proxy'] = self.PROXY
        return await super()._make_request(interaction_method, method, url, **kwargs)

    async def _handle_response_error(self, response: ClientResponse) -> None:
        """In DEBUG mode, log the response body before delegating to the base handler."""
        if self.DEBUG:
            self.logger.context_push(response=await response.text())
            self.logger.warning('Response error')
        await super()._handle_response_error(response)

    @staticmethod
    def _interaction_method_label(interaction_method: str) -> str:
        """Return the last ``/``-separated segment of *interaction_method* (the metric label)."""
        return interaction_method.rsplit('/', 1)[-1]

    def _response_time_metrics(self, interaction_method: str, response_time: float) -> None:
        """Record *response_time* in the per-service and per-service-and-method histograms."""
        label_interaction_method = self._interaction_method_label(interaction_method)
        service_response_time: sendr_qstats.Histogram = interaction_response_time.labels(self.SERVICE)
        service_response_time.observe(response_time)
        service_method_response_time: sendr_qstats.Histogram = interaction_method_response_time.labels(
            self.SERVICE,
            label_interaction_method,
        )
        service_method_response_time.observe(response_time)

    def _response_status_metrics(self, interaction_method: str, status: int) -> None:
        """Increment status counters bucketed by status class, e.g. 404 -> ``'4xx'``."""
        label_interaction_method = self._interaction_method_label(interaction_method)
        short_status = str(status // 100) + 'xx'
        interaction_response_status.labels(self.SERVICE, short_status).inc()
        interaction_method_response_status.labels(self.SERVICE, label_interaction_method, short_status).inc()

    @staticmethod
    def get_merchant_setting(uid: int, key: str, default: Optional[Any] = None) -> Optional[Any]:
        """Return ``key`` from the per-merchant overrides for *uid*, or *default*.

        Overrides come from ``settings.INTERACTION_MERCHANT_SETTINGS``, looked up by *uid*.
        """
        merchant_settings = settings.INTERACTION_MERCHANT_SETTINGS
        if uid in merchant_settings:
            return merchant_settings[uid].get(key, default)
        return default

    @staticmethod
    def _get_timeout_kwargs(total: Optional[float] = None) -> dict:
        """Return ``{'timeout': ClientTimeout(total=total)}``, or ``{}`` when *total* is None.

        An empty dict leaves the request's timeout at whatever default applies downstream.
        """
        if total is None:
            return {}
        return {'timeout': ClientTimeout(total=total)}
