import asyncio
import functools
from asyncio import Future
from logging import Logger
from typing import Any, Dict, Iterator, List, Optional, Tuple

from aiopg.log import logger as aiopg_logger

from sendr_aiopg.engine import EngineMixin
from sendr_aiopg.engine.retryable import BrokenEngineException, RetryableEngine


class MultipleEngineException(Exception):
    """Aggregate of exceptions raised by several named engines.

    ``engine_exceptions`` holds ``(engine_name, exception)`` pairs in the
    order they were collected. Any extra positional ``args`` are passed to
    ``Exception`` unchanged.
    """

    def __init__(self, engine_exceptions: List[Tuple[str, Exception]], *args: Any):
        self.engine_exceptions = engine_exceptions
        super().__init__(*args)

    def __str__(self) -> str:
        # When no explicit message is given, render the collected per-engine
        # errors so the exception is not an empty string in logs/tracebacks.
        if self.args:
            return super().__str__()
        return '; '.join(f'{name}: {exc!r}' for name, exc in self.engine_exceptions)


class MultipleEngine(EngineMixin):
    """Facade over several named ``RetryableEngine`` instances.

    Attribute lookups not defined on this class are forwarded to the engine
    registered under ``default``. Other engines are reachable through
    :meth:`using` or item access (``multiple_engine['name']``), and iterating
    the object yields every engine.
    """

    def __init__(self,
                 engines: Dict[str, RetryableEngine],
                 logger: Optional[Logger] = None,
                 default: str = 'default') -> None:
        self._engines = engines
        self.logger = logger or aiopg_logger
        self.default = default
        # Futures of the per-engine connects started by the running connect().
        self._connect_futures: List[asyncio.Future] = []
        # Task created by connect_async(); cleared once that task finishes.
        self._connect_async_future: Optional[asyncio.Future] = None

    def connect_async(self) -> Future:
        """Run :meth:`connect` in a background task and return that task.

        If the task ends with an exception, it is logged with its traceback.
        Cancellation is not logged.
        """
        def on_done(f: Future) -> None:
            # Only drop the reference if a newer connect_async() call has not
            # already replaced it with its own task.
            if self._connect_async_future is f:
                self._connect_async_future = None

            if f.cancelled():
                return

            exc = f.exception()
            if exc is not None:
                # Inside a done-callback there is no active exception, so
                # logger.exception() would lose the traceback; pass it explicitly.
                self.logger.error(exc, exc_info=exc)

        task = asyncio.create_task(self.connect())
        task.add_done_callback(on_done)
        self._connect_async_future = task
        return task

    async def connect(self) -> None:
        """Connect all engines concurrently and wait for every attempt to finish.

        A failed or cancelled engine connect is logged and marks the run as
        "with errors", but does not stop the other engines and is not raised.
        Any other ``Exception`` raised while starting or awaiting the connects
        is logged and swallowed. ``asyncio.CancelledError`` propagates.
        """
        all_success = True

        def on_one_connected(engine_name: str, f: Future) -> None:
            nonlocal all_success

            if f.cancelled():
                self.logger.error('Connection canceled engine "%s"', engine_name)
                all_success = False
                return

            exc = f.exception()
            if exc is not None:
                self.logger.error(
                    'Exception while connecting engine "%s" : %s', engine_name,
                    str(exc), exc_info=exc
                )
                all_success = False

        # A fresh list per call, so repeated connect() calls work.
        futures: List[asyncio.Future] = []
        self._connect_futures = futures
        try:
            for name, engine in self._engines.items():
                future = asyncio.ensure_future(engine.connect())
                future.add_done_callback(functools.partial(on_one_connected, name))
                futures.append(future)

            # asyncio.wait() raises ValueError on an empty collection.
            if futures:
                # Per-future callbacks were registered before asyncio.wait's own,
                # so all_success is final by the time this await returns.
                await asyncio.wait(futures)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self.logger.exception('Error while connecting: %s', str(e))
        else:
            if all_success:
                self.logger.info('Connected to all')
            else:
                self.logger.warning('Connected to all with errors')
        finally:
            if self._connect_futures is futures:
                self._connect_futures = []

    def close(self) -> None:
        """Call ``close()`` on every engine, even if some of them fail.

        ``BrokenEngineException`` is ignored. Any other exception is collected,
        and all of them are raised together as ``MultipleEngineException`` after
        every engine has been processed.
        """
        engine_exceptions = []
        for name, engine in self._engines.items():
            try:
                engine.close()
            except BrokenEngineException:
                pass
            except Exception as e:
                engine_exceptions.append((name, e))
        if engine_exceptions:
            raise MultipleEngineException(engine_exceptions)

    async def wait_closed(self) -> None:
        """Await ``wait_closed()`` on every engine, one after another.

        Error handling matches :meth:`close`: ``BrokenEngineException`` is
        ignored, and other exceptions are raised together as
        ``MultipleEngineException`` at the end.
        """
        engine_exceptions = []
        for name, engine in self._engines.items():
            try:
                await engine.wait_closed()
            except BrokenEngineException:
                pass
            except Exception as e:
                engine_exceptions.append((name, e))
        if engine_exceptions:
            raise MultipleEngineException(engine_exceptions)

    def using(self, name: Optional[str] = None) -> RetryableEngine:
        """Return the engine called ``name``, or the default engine when it is None."""
        if name is None:
            name = self.default
        return self._engines[name]

    def __getattr__(self, item: str) -> Any:
        """Forward unknown attribute lookups to the default engine.

        ``object.__getattribute__`` is used so that an instance whose
        ``__init__`` never ran (e.g. one created via ``__new__`` by
        copy/pickle) raises AttributeError instead of recursing forever.
        """
        try:
            engines = object.__getattribute__(self, '_engines')
            default = object.__getattribute__(self, 'default')
        except AttributeError:
            raise AttributeError(item) from None
        return getattr(engines[default], item)

    def __getitem__(self, item: str) -> RetryableEngine:
        return self._engines[item]

    def __iter__(self) -> Iterator[RetryableEngine]:
        return iter(self._engines.values())


def create_engine(config: Dict[str, dict], default: str = 'default') -> MultipleEngine:
    """Build a ``MultipleEngine`` with one ``RetryableEngine`` per config entry.

    :param config: mapping of engine name to the keyword arguments passed
        to ``RetryableEngine``.
    :param default: name of the engine used when no engine name is given.
    :raises ValueError: if ``default`` is not a key of ``config``.
    """
    # An explicit check rather than `assert`, which is stripped under `python -O`.
    if default not in config:
        raise ValueError(f'Unknown default engine "{default}"')
    engines: Dict[str, RetryableEngine] = {
        key: RetryableEngine(**kwargs) for key, kwargs in config.items()
    }
    return MultipleEngine(engines, default=default)
