import asyncio
from typing import Any, Optional

import aiopg
from aiopg.connection import TIMEOUT
from aiopg.sa.connection import SAConnection
from aiopg.sa.engine import Engine, _dialect, _EngineContextManager
from sqlalchemy.engine.interfaces import Dialect

from sendr_aiopg.engine import EngineClosingAcquireContextManager, EngineMixin
from sendr_utils import sort_hosts_by_geo


class CustomEngine(EngineMixin, Engine):
    """aiopg ``Engine`` whose ``acquire()`` returns this project's context manager.

    Behaves like ``aiopg.sa.engine.Engine`` except that connection
    acquisition is wrapped in ``EngineClosingAcquireContextManager`` (from
    ``sendr_aiopg.engine``). Its exact exit behaviour is defined there; the
    name suggests it closes the connection on exit — TODO confirm.
    """

    def acquire(self):
        """Return a context manager that yields an ``SAConnection`` from the pool."""
        pending_conn = self._acquire()
        return EngineClosingAcquireContextManager(pending_conn, self)

    async def _acquire(self):
        """Take a raw connection from ``self._pool`` and wrap it in an ``SAConnection``."""
        return SAConnection(await self._pool.acquire(), self)


async def _create_engine(
    dsn: Optional[Any] = None,
    *,
    minsize: int = 1,
    maxsize: int = 10,
    dialect: Any = _dialect,
    timeout: float = TIMEOUT,
    pool_recycle: int = 120,
    **kwargs: Any,
) -> CustomEngine:
    """Create an aiopg connection pool and wrap it in a ``CustomEngine``.

    Mirrors aiopg's own engine factory, with two differences:

    * when ``target_session_attrs == 'any'`` and ``host`` is a
      comma-separated list of several hosts, the list is reordered with
      ``sort_hosts_by_geo`` before the pool is created;
    * the returned engine is a ``CustomEngine``.

    :param dsn: connection string forwarded to ``aiopg.create_pool``.
    :param minsize: minimum pool size, forwarded to ``aiopg.create_pool``.
    :param maxsize: maximum pool size, forwarded to ``aiopg.create_pool``.
    :param dialect: SQLAlchemy dialect handed to ``CustomEngine``.
    :param timeout: forwarded to ``aiopg.create_pool``.
    :param pool_recycle: forwarded to ``aiopg.create_pool`` (defaults to 120
        here).
    :param kwargs: extra connection keywords forwarded to
        ``aiopg.create_pool``; ``host`` may be rewritten as described above.
    :returns: a ``CustomEngine`` bound to the newly created pool.
    :raises: whatever ``aiopg.create_pool`` / ``pool.acquire`` /
        ``pool.release`` raise. If the pool was already created, it is
        terminated before the exception propagates, so its connections do
        not leak.
    """
    loop = asyncio.get_event_loop()

    host = kwargs.get('host')
    # Reorder multi-host lists only when any host is acceptable to the
    # session. A non-string host (e.g. None) is passed through untouched
    # instead of crashing on ``.split``.
    if isinstance(host, str) and kwargs.get('target_session_attrs') == 'any':
        hosts = host.split(',')
        if len(hosts) > 1:
            kwargs['host'] = ','.join(sort_hosts_by_geo(hosts))

    pool = await aiopg.create_pool(
        dsn,
        minsize=minsize, maxsize=maxsize,
        loop=loop, timeout=timeout, pool_recycle=pool_recycle,
        **kwargs
    )
    try:
        conn = await pool.acquire()
        try:
            # Probe connection: read the DSN reported by the live connection
            # and hand it to the engine.
            real_dsn = conn.dsn
            # NOTE(review): the probe connection is closed before being
            # released, so the pool will likely drop it rather than reuse it.
            # Kept from the original behaviour — confirm this is intentional.
            conn.close()
            return CustomEngine(dialect, pool, real_dsn)
        finally:
            await pool.release(conn)
    except BaseException:
        # Acquire/engine construction failed or the task was cancelled:
        # tear down the pool (and the connections it already opened)
        # instead of leaking it.
        pool.terminate()
        await pool.wait_closed()
        raise


def create_engine(
    dsn: Optional[Any] = None, *,
    minsize: int = 1,
    maxsize: int = 10,
    dialect: Dialect = _dialect,
    timeout: float = TIMEOUT,
    **kwargs: Any,
) -> _EngineContextManager:
    """Copy of ``aiopg.engine.create_engine`` that uses this module's classes.

    Calling this function does not open any connection. It only builds the
    ``_create_engine`` coroutine object (not yet executed) and wraps it in
    aiopg's ``_EngineContextManager``. Presumably that wrapper runs it when
    awaited or entered with ``async with`` — see aiopg.

    All arguments are forwarded unchanged to ``_create_engine``, including
    any extra keywords (e.g. ``pool_recycle``, ``host``,
    ``target_session_attrs``).
    """
    engine_factory = _create_engine(
        dsn,
        dialect=dialect,
        minsize=minsize,
        maxsize=maxsize,
        timeout=timeout,
        **kwargs,
    )
    return _EngineContextManager(engine_factory)
