import asyncio
from asyncio import wait_for
from contextlib import asynccontextmanager
from typing import Optional

from aiopg.log import logger

from sendr_aiopg.engine import EngineMixin
from sendr_aiopg.engine.single import CustomEngine, create_engine


class BrokenEngineException(Exception):
    """Raised when no database engine is available to serve a request.

    ``orig_exception`` holds the error that prevented the engine from being
    created. It may be ``None`` (for example when the engine is checked before
    any connection attempt has failed).
    """

    def __init__(self, orig_exception: Optional[Exception]):
        super().__init__(orig_exception)
        self.orig_exception: Optional[Exception] = orig_exception


class RetryableEngine(EngineMixin):
    """Lazily-connected wrapper around a ``CustomEngine`` that retries on failure.

    ``connect()`` keeps calling ``create_engine(*args, **kwargs)`` until it
    succeeds, sleeping ``reconnect_timeout`` seconds between attempts.
    ``acquire()`` waits up to ``acquire_connect_timeout`` seconds for an engine
    to become available. Any other attribute access is delegated to the
    underlying engine.

    Keyword arguments consumed by the wrapper (everything else is forwarded to
    ``create_engine``):

    * ``logger`` -- logger for connection errors (default: aiopg's logger).
    * ``reconnect_timeout`` -- seconds between connection attempts (default 1.0).
    * ``acquire_connect_timeout`` -- seconds ``acquire()`` waits for an engine
      (default 2.0).
    """

    # Attributes owned by the wrapper itself; never delegated to the engine.
    _OWN_ATTRS = frozenset(('_engine', '_exception', '_engine_ready', '_connect_cond'))

    def __init__(self, *args, **kwargs):
        self.logger = kwargs.pop('logger', logger)
        self.reconnect_timeout: float = kwargs.pop('reconnect_timeout', 1.0)
        self.acquire_connect_timeout: float = kwargs.pop('acquire_connect_timeout', 2.0)

        self._engine: Optional[CustomEngine] = None
        self._args = args
        # Last error raised while creating the engine; reported via BrokenEngineException.
        self._exception: Optional[Exception] = None
        self._kwargs = kwargs

        # Serialises connect() calls; it is held for the whole retry loop.
        self._connect_cond = asyncio.Condition()
        # Set once an engine is available. acquire() waits on this Event rather
        # than on the Condition: Condition.wait() requires holding the lock,
        # which connect() keeps for as long as it is retrying.
        self._engine_ready = asyncio.Event()

    async def connect(self):
        """Create the engine, retrying until it succeeds.

        Does nothing if an engine already exists. Concurrent callers are
        serialised by ``_connect_cond``, so only one of them creates the engine.
        Each failure is logged, stored in ``_exception`` and retried after
        ``reconnect_timeout`` seconds.
        """
        async with self._connect_cond:
            if self._engine is None:
                # (Re)connecting: make acquire() wait for this attempt instead of
                # trusting a stale "ready" flag.
                self._engine_ready.clear()
                while True:
                    try:
                        self._engine = await create_engine(*self._args, **self._kwargs).__aenter__()
                    except asyncio.CancelledError:
                        # On Python 3.7 CancelledError subclasses Exception; never retry it.
                        raise
                    except Exception as e:
                        # Deliberately broad: any connection failure is logged and retried.
                        self._exception = e
                        self.logger.error(e)
                        await asyncio.sleep(self.reconnect_timeout)
                    else:
                        break
                self._engine_ready.set()
                self._connect_cond.notify_all()

    def check_engine(self) -> None:
        """Raise ``BrokenEngineException`` if no engine has been created.

        :raises BrokenEngineException: wrapping the last connection error
            (``None`` if no attempt has failed yet).
        """
        if self._engine is None:
            raise BrokenEngineException(self._exception) from self._exception

    @asynccontextmanager
    async def acquire(self):
        """Yield a connection from the underlying engine.

        If no engine exists yet, wait up to ``acquire_connect_timeout`` seconds
        for ``connect()`` to create one.

        :raises BrokenEngineException: if no engine is available after waiting.
        """
        if self._engine is None:
            try:
                await wait_for(
                    self._engine_ready.wait(),
                    timeout=self.acquire_connect_timeout
                )
            except asyncio.TimeoutError:
                # Fall through: check_engine() reports the failure.
                pass
        self.check_engine()

        async with self._engine.acquire() as conn:
            yield conn

    def __getattr__(self, item):
        """Delegate unknown attributes to the underlying engine.

        :raises BrokenEngineException: if no engine has been created yet.
        """
        # __getattr__ only runs when normal lookup fails. If the wrapper's own
        # state is missing (e.g. __init__ never ran, as during unpickling), fail
        # with AttributeError instead of recursing through check_engine().
        if item in self._OWN_ATTRS:
            raise AttributeError(item)
        self.check_engine()
        return getattr(self._engine, item)
