from __future__ import annotations

import inspect
import logging
from types import TracebackType
from typing import Dict, Optional, Type, get_type_hints

from sendr_interactions.base import AbstractInteractionClient, ResponseFormat  # noqa: F401
from sendr_interactions.retry_budget import RetryBudgetProtocol
from sendr_qlog import LoggerContext
from sendr_writers.base.pusher import CommonPushers

# Fallback stdlib logger: InteractionClients wraps it in a LoggerContext
# when its constructor is not given a logger.
default_logger = logging.getLogger(name='context_logger')


class InteractionClientsMeta(type):
    """Metaclass that registers annotated interaction clients in ``cls._classes``.

    For every class it creates, it resolves the class's type hints with
    ``typing.get_type_hints`` (which also merges annotations declared on base
    classes). Each hint that is a subclass of the class's
    ``abstract_client_class`` attribute is recorded in a fresh
    ``cls._classes`` mapping of ``attribute name -> client class``. The default
    for ``abstract_client_class`` is ``AbstractInteractionClient``.

    Raises:
        TypeError: if ``abstract_client_class`` is set to something that is
            not a class.
    """

    def __init__(cls, name, bases, attributes):
        super().__init__(name, bases, attributes)
        # A new dict per class, so a subclass never mutates its parent's registry.
        cls._classes: Dict[str, Type[AbstractInteractionClient]] = {}
        # Evaluates string annotations (this module uses
        # ``from __future__ import annotations``) and includes inherited ones.
        cls_annotations = get_type_hints(cls)

        base_client_class = getattr(cls, 'abstract_client_class', AbstractInteractionClient)
        # Explicit check rather than ``assert``: assertions vanish under ``python -O``.
        if not isinstance(base_client_class, type):
            raise TypeError(
                f'{name}.abstract_client_class must be a class, got {base_client_class!r}'
            )

        for a_name, a_cls in cls_annotations.items():
            # Non-class hints (e.g. ``Dict[...]``, ``Optional[...]``) fail
            # ``inspect.isclass`` and are skipped.
            if inspect.isclass(a_cls) and issubclass(a_cls, base_client_class):
                cls._classes[a_name] = a_cls


class InteractionClients(metaclass=InteractionClientsMeta):
    """Wraps interaction clients mentioned in annotations.

    Each annotated client is exposed as an attribute. Instances are created
    lazily on first access, cached on this object, and closed by ``close()``.
    ``close()`` also runs automatically when leaving ``async with``.

    Usage example:
        ```
        class InteractionClients(sendr_interactions.InteractionClients):
            payments: PaymentsClient

        clients = InteractionClients(logger, request_id)
        async with clients:
            clients.payments.get(...)
        ```
    """

    # attribute name -> client class; populated per class by the metaclass
    _classes: Dict[str, Type[AbstractInteractionClient]]
    # attribute name -> client instance created so far by this object
    _clients: Dict[str, AbstractInteractionClient]

    def __init__(
        self,
        logger: Optional[LoggerContext] = None,
        request_id: Optional[str] = None,
        pushers: Optional[CommonPushers] = None,
        retry_budget: Optional[RetryBudgetProtocol] = None,
    ):
        """
        :param logger: context logger handed to every client; when omitted, a
            ``LoggerContext`` over the module's ``default_logger`` is used.
        :param request_id: handed to every client (as ``''`` when not given).
        :param pushers: handed to every client unchanged.
        :param retry_budget: handed to every client unchanged.
        """
        self._logger = logger or LoggerContext(default_logger, {})
        self._request_id = request_id
        self._pushers = pushers
        self._retry_budget = retry_budget
        self._clients = {}

    def __getattr__(self, name: str) -> AbstractInteractionClient:
        # Only called when normal attribute lookup fails, i.e. for client names.
        if name in self._classes:
            return self._get_client(name)
        raise AttributeError(f'{type(self).__name__!r} object has no attribute {name!r}')

    async def __aenter__(self) -> InteractionClients:  # NOQA
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Exceptions from the ``async with`` body are not suppressed (returns None).
        await self.close()

    @property
    def logger(self) -> LoggerContext:
        return self._logger

    @property
    def pushers(self) -> Optional[CommonPushers]:
        return self._pushers

    @property
    def request_id(self) -> Optional[str]:
        return self._request_id

    def _get_client(self, name: str) -> AbstractInteractionClient:
        """Return the cached client for ``name``, creating it on first use."""
        if name not in self._clients:
            self._clients[name] = self._classes[name](
                logger=self.logger,
                request_id=self.request_id or '',
                pushers=self.pushers,
                retry_budget=self._retry_budget,
            )
        return self._clients[name]

    async def close(self) -> None:
        """Close every client created so far.

        Every client gets a chance to close even if some fail. The first
        ``Exception`` raised is re-raised after all clients have been
        processed. Further failures are logged through ``default_logger``.
        """
        first_error: Optional[Exception] = None
        # Iterate over a snapshot: each ``await`` yields to the event loop, and
        # another task may lazily create a client (inserting into
        # ``self._clients``) meanwhile, which would make iteration over the
        # live dict raise ``RuntimeError``.
        for name, client in list(self._clients.items()):
            try:
                await client.close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
                else:
                    default_logger.exception('Failed to close interaction client %r', name)
        if first_error is not None:
            raise first_error
