import logging
import json

from contextlib import contextmanager

from psycopg2.pool import SimpleConnectionPool

from ora2pg.tools import http
from mail.pypg.pypg.common import autocommit_connection, transaction, SQL_LINE_LIMIT
from mail.pypg.pypg.connect import LoggingConnection

log = logging.getLogger(__name__)
# Per-DSN caches of connection pools, keyed by the DSN string and filled
# lazily by __get_pool.
__sharddb_pools = {}
__huskydb_pools = {}


@contextmanager
def __get_conn(dsn, autocommit):
    """Yield a connection to ``dsn`` opened by the matching pypg helper.

    When ``autocommit`` is truthy the connection comes from
    ``autocommit_connection(dsn)``; otherwise it comes from ``transaction(dsn)``.
    Cleanup is left entirely to whichever helper was chosen.
    """
    open_connection = autocommit_connection if autocommit else transaction
    with open_connection(dsn) as connection:
        yield connection


class MasterNotFoundError(Exception):
    """Raised when sharpei's sharddb stats list no host with the ``master`` role."""


def get_sharddb_master_host(sharpei, dsn_suffix):
    """Ask sharpei for sharddb topology and build a DSN for the master host.

    :param sharpei: host passed to ``http.url_join`` for the ``sharddb_stat``
        request (issued with retries enabled).
    :param dsn_suffix: extra DSN text appended after ``dbname``; ``None`` is
        treated as empty.
    :returns: ``'host=... port=... dbname=... <dsn_suffix>'`` built from the
        ``address`` of the first host whose ``role`` is ``'master'``.
    :raises MasterNotFoundError: if no host has the master role (or its
        address is empty).
    """
    with http.request(
        url=http.url_join(
            host=sharpei,
            method='sharddb_stat'
        ),
        do_retries=True,
    ) as fd:
        statresp = json.load(fd)

    # Entries without a 'role' key are skipped rather than raising KeyError.
    addr = next(
        (host['address'] for host in statresp if host.get('role') == 'master'),
        None,
    )
    if not addr:
        raise MasterNotFoundError('master host not found for sharddb')
    # Build the format arguments in a copy so the parsed response is not
    # mutated. `or ''` keeps a None suffix (e.g. an argparse option with no
    # default) from being rendered as the literal text "None" in the DSN.
    params = dict(addr, dsn_suffix=dsn_suffix or '')
    return 'host={host} port={port} dbname={dbname} {dsn_suffix}'.format(**params)


def find_sharddb(args):
    """Resolve the sharddb master DSN from parsed command-line arguments.

    Reads ``args.sharpei`` and the optional ``args.sharddb_dsn_suffix``
    (``''`` when the attribute is absent), then delegates to
    ``get_sharddb_master_host``.

    :returns: the master DSN string.
    """
    sharpei = args.sharpei
    dsn_suffix = getattr(args, 'sharddb_dsn_suffix', '')
    master = get_sharddb_master_host(sharpei, dsn_suffix)
    # Lazy %-args: the message is only formatted when INFO is enabled.
    # NOTE(review): the DSN includes dsn_suffix, which may carry credentials
    # depending on configuration — confirm it is safe to log verbatim.
    log.info('Sharddb master is %s', master)
    return master


def find_huskydb(args):
    """Return the huskydb DSN stored on ``args.huskydb`` (AttributeError if missing)."""
    return args.huskydb


@contextmanager
def get_sharddb_pooled_conn(dsn, autocommit=False, pool=None):
    """Yield a pooled sharddb connection for ``dsn``.

    Uses ``pool`` if given, otherwise the module's sharddb pool cache.
    """
    pooled = __get_pooled_conn(
        pools=__sharddb_pools,
        dsn=dsn,
        autocommit=autocommit,
        pool=pool,
    )
    with pooled as connection:
        yield connection


@contextmanager
def get_huskydb_pooled_conn(dsn, autocommit=False, pool=None):
    """Yield a pooled huskydb connection for ``dsn``.

    Uses ``pool`` if given, otherwise the module's huskydb pool cache.
    """
    pooled = __get_pooled_conn(
        pools=__huskydb_pools,
        dsn=dsn,
        autocommit=autocommit,
        pool=pool,
    )
    with pooled as connection:
        yield connection


def __get_pool(pools, dsn):
    """Return the pool cached in ``pools`` under ``dsn``, creating it on first use.

    New pools hold exactly one connection (minconn = maxconn = 1) built with
    ``LoggingConnection``.
    """
    try:
        return pools[dsn]
    except KeyError:
        pass
    created = SimpleConnectionPool(minconn=1, maxconn=1, dsn=dsn, connection_factory=LoggingConnection)
    pools[dsn] = created
    return created


@contextmanager
def __get_pooled_conn(pools, dsn, autocommit, pool):
    """Check out an open connection keyed by ``dsn`` and always hand it back.

    :param pools: pool cache consulted (via ``__get_pool``) when ``pool`` is None.
    :param dsn: connection string, also used as the pool key.
    :param autocommit: assigned to ``conn.autocommit``; when falsy the
        connection is committed after the ``with`` body completes normally.
    :param pool: explicit pool to use instead of the cached one.
    :raises RuntimeError: if no open connection could be checked out.
    """
    if pool is None:
        pool = __get_pool(pools, dsn)

    # Closed connections are returned to the pool and another checkout is
    # attempted, at most maxconn + 1 times. (Presumably the pool drops closed
    # connections so the next getconn opens a fresh one — confirm with psycopg2.)
    conn = None
    for _ in range(pool.maxconn + 1):
        candidate = pool.getconn(key=dsn)
        if not candidate.closed:
            conn = candidate
            break
        pool.putconn(key=dsn, conn=candidate)
    if conn is None:
        # Previously this path surfaced as contextmanager's misleading
        # "generator didn't yield". The DSN is left out of the message since
        # it may contain credentials.
        raise RuntimeError('could not obtain an open connection from the pool')

    try:
        conn.autocommit = autocommit
        conn.SQL_LINE_LIMIT = SQL_LINE_LIMIT
        conn.wait()
        yield conn
        # Reached only when the with-body did not raise.
        if not autocommit:
            conn.commit()
    finally:
        pool.putconn(key=dsn, conn=conn)
