import aiopg
import psycopg2
import logging
from contextlib import asynccontextmanager
from psycopg2.extras import RealDictCursor

from mail.shiva.stages.api.props.services.sharpei import get_shard_dsn, get_shard_dsn_by_uid, get_sharddb_dsn
from mail.shiva.stages.api.settings.sharpei import SharpeiSettings

log = logging.getLogger(__name__)


class DbConnectionError(RuntimeError):
    """Signal that a database connection failed while serving a request.

    The cursor providers in this module raise it in place of
    ``psycopg2.OperationalError`` / ``psycopg2.InterfaceError``. Callers can
    therefore handle connection loss with a single exception type.
    """


class CursorProvider(object):
    """Hand out cursors backed by connections from an aiopg pool.

    If a connection-level psycopg2 error (``OperationalError`` or
    ``InterfaceError``) escapes a cursor block, the connection is closed so it
    is not handed out again. The error is then re-raised as
    :class:`DbConnectionError`.
    """

    def __init__(self, pg: aiopg.Pool):
        self.pg_pool = pg

    @asynccontextmanager
    async def cursor(self, **kwargs):
        """Acquire a pooled connection and yield a cursor on it.

        :param kwargs: forwarded unchanged to ``Connection.cursor``.
        :raises DbConnectionError: on ``psycopg2.OperationalError`` or
            ``psycopg2.InterfaceError``. The original error is kept as
            ``__cause__``.
        """
        async with self.pg_pool.acquire() as conn:
            try:
                async with conn.cursor(**kwargs) as cur:
                    yield cur
            except (
                psycopg2.OperationalError,
                psycopg2.InterfaceError
            ) as exc:
                log.error(f'Exception during request to db: {exc}')
                # aiopg's Connection.close() returns an awaitable; await it for
                # consistency with ShardDbCursorProvider.close().
                await conn.close()
                raise DbConnectionError(f'Db connection error: {exc}') from exc


@asynccontextmanager
async def create_cursor_provider(params, stats):
    """Yield a :class:`CursorProvider` for the shard ``params.shard_id``.

    The DSN is resolved via ``get_shard_dsn`` using ``params.sharpei`` and
    ``params.db_user``. The backing aiopg pool holds exactly one connection
    (``minsize == maxsize == 1``) and builds ``RealDictCursor`` cursors. The
    pool is closed when the context exits.
    """
    shard_dsn = await get_shard_dsn(params.sharpei, params.db_user, params.shard_id, stats)
    pool_options = dict(minsize=1, maxsize=1, cursor_factory=RealDictCursor)
    async with aiopg.create_pool(dsn=shard_dsn, **pool_options) as db_pool:
        yield CursorProvider(db_pool)


@asynccontextmanager
async def create_uid_cursor_provider(params, stats):
    """Yield a :class:`CursorProvider` for the shard that holds ``params.uid``.

    The DSN is resolved via ``get_shard_dsn_by_uid`` using ``params.sharpei``
    and ``params.db_user``. The backing aiopg pool holds exactly one connection
    (``minsize == maxsize == 1``) and builds ``RealDictCursor`` cursors. The
    pool is closed when the context exits.
    """
    uid_dsn = await get_shard_dsn_by_uid(params.sharpei, params.db_user, params.uid, stats)
    pool_options = dict(minsize=1, maxsize=1, cursor_factory=RealDictCursor)
    async with aiopg.create_pool(dsn=uid_dsn, **pool_options) as db_pool:
        yield CursorProvider(db_pool)


@asynccontextmanager
async def locked_transactional_cursor(cursor_provider, uid):
    """Yield a cursor inside an explicit transaction for ``uid``.

    The flow is:

    1. Issue ``BEGIN``.
    2. Call ``code.acquire_current_revision(uid)``. Presumably this takes a
       per-uid lock/revision; confirm against the SQL function.
    3. Yield the cursor.

    On normal exit the transaction is committed. If *anything* escapes the
    block, including ``asyncio.CancelledError`` and other ``BaseException``
    subclasses, it is rolled back and the exception is re-raised. It is never
    committed in that case.

    :param cursor_provider: object whose ``cursor()`` is an async context
        manager yielding a cursor with an awaitable ``execute``.
    :param uid: value passed to ``code.acquire_current_revision``.
    """
    async with cursor_provider.cursor() as cur:
        try:
            await cur.execute('BEGIN')
            await cur.execute('SELECT code.acquire_current_revision(%(uid)s)', dict(uid=uid))
            yield cur
        # BaseException, not Exception: CancelledError (a BaseException since
        # Python 3.8) must roll back rather than fall through to COMMIT.
        except BaseException as exc:
            log.warning(f'Transaction for uid={uid} was rollbacked: {exc}')
            await cur.execute('ROLLBACK')
            raise
        else:
            # Commit only on success; the original `finally` also committed
            # after a ROLLBACK, or after a cancellation that skipped it.
            await cur.execute('COMMIT')


class ShardDbCursorProvider(object):
    """Lazily open and reuse a single aiopg connection to sharddb.

    The DSN is obtained via ``get_sharddb_dsn`` on first use, or after the
    previous connection was closed. Cursors use ``RealDictCursor``.
    Connection-level psycopg2 errors drop the connection (the next call
    reconnects) and are re-raised as :class:`DbConnectionError`.
    """

    def __init__(self, sharpei: SharpeiSettings, stats, dbuser='sharpei'):
        self._conn = None
        self._sharpei = sharpei
        self._stats = stats
        self._dbuser = dbuser

    async def _get_connection(self):
        """Return the cached connection, (re)connecting if needed."""
        # A connection can end up closed without going through self.close()
        # (e.g. server-side disconnect). Discard it instead of handing it out
        # again and failing the next request.
        if self._conn is not None and self._conn.closed:
            await self.close()
        if self._conn is None:
            dsn = await get_sharddb_dsn(self._sharpei, self._dbuser, self._stats)
            self._conn = await aiopg.connect(dsn, cursor_factory=RealDictCursor)
        return self._conn

    async def close(self):
        """Close the cached connection, if any, and forget it."""
        # Detach before awaiting so the instance is reset even if close() raises.
        conn, self._conn = self._conn, None
        if conn is not None:
            await conn.close()

    @asynccontextmanager
    async def cursor(self, **kwargs):
        """Yield a cursor on the shared sharddb connection.

        :param kwargs: forwarded unchanged to ``Connection.cursor``.
        :raises DbConnectionError: on ``psycopg2.OperationalError`` or
            ``psycopg2.InterfaceError``. The connection is closed first, and
            the original error is kept as ``__cause__``.
        """
        conn = await self._get_connection()
        try:
            async with conn.cursor(**kwargs) as cur:
                yield cur
        except (
            psycopg2.OperationalError,
            psycopg2.InterfaceError
        ) as exc:
            log.error(f'Exception during request to db: {exc}')
            await self.close()
            raise DbConnectionError(f'Db connection error: {exc}') from exc
