import inspect
import logging
from typing import Dict, Optional

from sendr_interactions import AbstractInteractionClient
from sendr_qlog import LoggerContext

from mail.beagle.beagle.interactions.base import BaseInteractionClient
from mail.beagle.beagle.interactions.blackbox import BlackBoxClient
from mail.beagle.beagle.interactions.directory import DirectoryClient
from mail.beagle.beagle.interactions.hound import HoundClient
from mail.beagle.beagle.interactions.mbody import MBodyClient
from mail.beagle.beagle.interactions.passport import PassportClient
from mail.beagle.beagle.interactions.sender import SenderClient

# Fallback logger wrapped into a LoggerContext when InteractionClients gets no logger.
default_logger = logging.getLogger(name='context_logger')


def _own_annotations(klass) -> Dict[str, object]:
    """Return the annotations declared directly on ``klass`` (never inherited ones).

    ``inspect.get_annotations`` (Python 3.10+) is used when available. On older
    interpreters the class ``__dict__`` is read instead, because plain
    ``klass.__annotations__`` there falls through to a base class's dict when
    ``klass`` declares no annotations of its own.
    """
    get_annotations = getattr(inspect, 'get_annotations', None)
    if get_annotations is not None:
        return get_annotations(klass)
    return klass.__dict__.get('__annotations__', {})


class InteractionClientsAnnotatedMeta(type):
    """Metaclass that registers annotated interaction-client attributes.

    Every class attribute annotated with a subclass of ``AbstractInteractionClient``
    (e.g. ``hound: HoundClient``) is recorded in ``cls._clients_cls`` as
    ``{attribute_name: client_class}``. Annotations are collected along the whole
    MRO, most generic base first. A subclass that declares additional clients
    therefore keeps the ones declared by its bases, and a redeclaration in a
    subclass replaces the base's entry.
    """

    def __init__(cls, name, superclasses, attributes):
        super().__init__(name, superclasses, attributes)

        clients_cls = {}
        # Walk from ``object`` towards ``cls`` so that more derived declarations win.
        for klass in reversed(cls.__mro__):
            for attr_name, a_cls in _own_annotations(klass).items():
                # Annotations may be arbitrary objects (typing constructs, strings),
                # so check they are classes before calling issubclass().
                if inspect.isclass(a_cls) and issubclass(a_cls, AbstractInteractionClient):
                    clients_cls[attr_name] = a_cls
        cls._clients_cls = clients_cls


class InteractionClients(metaclass=InteractionClientsAnnotatedMeta):
    """Container of interaction clients that are created lazily on first access.

    Each annotated attribute below names a client class. The metaclass collects
    them into ``_clients_cls``. Reading e.g. ``clients.hound`` builds that client
    on first access, passing this container's ``logger`` and ``request_id``. It
    then returns the same cached instance on later accesses. Use the container as
    an async context manager, or call :meth:`close`, to close every client
    created so far.
    """

    blackbox: BlackBoxClient
    directory: DirectoryClient
    hound: HoundClient
    mbody: MBodyClient
    passport: PassportClient
    sender: SenderClient

    def __init__(self, logger: Optional[LoggerContext] = None, request_id: Optional[str] = None):
        """
        :param logger: logging context handed to every client. Defaults to a
            ``LoggerContext`` over the module-level ``default_logger`` with an
            empty context.
        :param request_id: request id handed to every client (may be ``None``).
        """
        if logger is None:
            logger = LoggerContext(default_logger, {})
        self._logger = logger
        self._request_id = request_id
        # Cache of already constructed clients, keyed by attribute name.
        self._clients: Dict[str, BaseInteractionClient] = {}

    def __getattr__(self, item):
        # Only reached when normal lookup fails. Annotation-only attributes such as
        # ``hound: HoundClient`` have no class-level value, so they land here.
        if item in self._clients_cls:
            return self._get_client(item)
        raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}')

    def _get_client(self, client: str) -> BaseInteractionClient:
        """Return the cached client registered under ``client``, creating it on first use."""
        if client not in self._clients:
            self._clients[client] = self._clients_cls[client](
                logger=self.logger,
                request_id=self.request_id,
            )
        return self._clients[client]

    @property
    def logger(self) -> LoggerContext:
        """Logging context shared by all clients of this container."""
        return self._logger

    @property
    def request_id(self) -> Optional[str]:
        """Request id shared by all clients of this container."""
        return self._request_id

    async def close(self) -> None:
        """Close every client created so far.

        Every client gets a chance to close even if an earlier one fails. The
        first exception raised by a client's ``close()`` is re-raised after all
        clients have been attempted. The cache is emptied before closing, so a
        client accessed after ``close()`` is created anew instead of being
        returned in a closed state.
        """
        # Swap the cache out first: iteration is then immune to clients being
        # created while we close, and closed clients are never handed out again.
        clients, self._clients = self._clients, {}
        first_error: Optional[Exception] = None
        for client in clients.values():
            try:
                await client.close()
            except Exception as exc:  # keep closing the rest, report afterwards
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
