import Queue
import contextlib
import os.path
import socket
import sys
from itertools import izip

import gevent
import gevent.queue
import psycopg2
import psycopg2.extensions as pge
import psycopg2.extras as pgex
import simplejson as json
from gevent.socket import wait_read, wait_write

from pymail.logged_connection import LoggingConnection


def describe_cursor(cur):
    """Return the column names of *cur*'s current result, lower-cased, in order."""
    columns = cur.description
    return [column.name.lower() for column in columns]


def fetch_as_dicts(cur):
    """Lazily yield each remaining row of *cur* as a ``{column_name: value}`` dict.

    Column names come from :func:`describe_cursor`, so keys are lower-cased.
    """
    names = describe_cursor(cur)
    for record in cur:
        pairs = izip(names, record)
        yield dict(pairs)


def gevent_wait_callback(conn):
    """Psycopg2 wait callback that yields to gevent instead of blocking.

    Repeatedly polls *conn* until it reports ``POLL_OK``. While libpq needs
    the socket to become readable or writable, the current greenlet waits
    via ``wait_read`` / ``wait_write`` so other greenlets can run. Any other
    poll result raises ``psycopg2.OperationalError``.
    """
    state = conn.poll()
    while state != pge.POLL_OK:
        if state == pge.POLL_READ:
            wait_read(conn.fileno())
        elif state == pge.POLL_WRITE:
            wait_write(conn.fileno())
        else:
            raise psycopg2.OperationalError("Bad result from poll: %r" % state)
        state = conn.poll()


def patch_psycopg2():
    """Install :func:`gevent_wait_callback` as psycopg2's global wait callback."""
    callback = gevent_wait_callback
    pge.set_wait_callback(callback)


class NoAvailableConnection(RuntimeError):
    """Raised by a connection pool's ``get()`` when the pool is full and no
    idle connection is returned before the wait times out."""


class AbstractDatabaseConnectionPool(object):
    """Base class for a gevent-friendly pool of DB-API style connections.

    Subclasses supply :meth:`create_connection`. :meth:`get` reuses an idle
    connection when one is queued, otherwise opens a new one while fewer than
    ``maxsize`` connections exist; once the cap is reached it waits up to
    ``timeout`` seconds for another greenlet to return a connection.

    :param maxsize: cap on the number of live connections (idle plus checked
        out). ``None`` disables the cap.
    :param timeout: seconds :meth:`get` waits for an idle connection when the
        pool is at capacity.
    """

    def __init__(self, maxsize=100, timeout=1):
        self.maxsize = maxsize
        self.timeout = timeout
        # Idle connections waiting to be handed out again.
        self.pool = gevent.queue.Queue()
        # Live connections owned by the pool: idle ones plus checked-out ones.
        self.size = 0

    def create_connection(self):
        """Open and return a new connection. Subclasses must override this."""
        raise NotImplementedError()

    def _at_capacity(self):
        """Return True when creating another connection would exceed maxsize."""
        return self.maxsize is not None and self.size >= self.maxsize

    def get(self):
        """Check out a connection.

        Returns a queued idle connection if there is one. Otherwise a new
        connection is created, unless the pool is at capacity, in which case
        this waits for a connection to be returned.

        :raises NoAvailableConnection: if no connection is returned within
            ``self.timeout`` seconds.
        """
        if self._at_capacity() or self.pool.qsize():
            try:
                return self.pool.get(timeout=self.timeout)
            except Queue.Empty:
                raise NoAvailableConnection()

        # Count the new connection before creating it: create_connection()
        # may yield to other greenlets (e.g. while connecting), and they must
        # already see this slot as taken.
        self.size += 1
        try:
            return self.create_connection()
        except:
            # Release the reserved slot, then propagate whatever went wrong
            # (including greenlet kills).
            self.size -= 1
            raise

    def put(self, item):
        """Return a checked-out connection to the idle queue."""
        self.pool.put(item)

    def closeall(self):
        """Close every idle connection and remove it from the pool's count.

        Checked-out connections are not touched. Errors from ``close()`` are
        ignored: the connection is being discarded either way.
        """
        while not self.pool.empty():
            conn = self.pool.get_nowait()
            self.size -= 1
            try:
                conn.close()
            except Exception:
                pass

    @contextlib.contextmanager
    def connection(self):
        """Context manager lending a connection for one transaction.

        On normal exit the transaction is committed and the connection goes
        back to the pool. If the body raises, the transaction is rolled back
        and the exception re-raised. If the connection turned out to be
        closed, it is discarded and all idle connections are closed too,
        since they are presumably broken as well (e.g. after a server
        restart).

        :raises psycopg2.OperationalError: if the body exits normally but the
            connection has been closed, so the commit cannot happen.
        """
        conn = self.get()
        try:
            yield conn
        except:
            # Bare except on purpose: roll back on *any* exceptional exit,
            # including GeneratorExit / greenlet kills, then re-raise.
            if conn.closed:
                conn = None
                self.closeall()
            else:
                # _rollback returns None when the rollback itself failed.
                conn = self._rollback(conn)
            raise
        else:
            if conn.closed:
                raise psycopg2.OperationalError(
                    "Cannot commit because connection was closed: %r" % (conn, )
                )
            conn.commit()
        finally:
            if conn is not None and not conn.closed:
                self.put(conn)
            else:
                # The connection is gone for good; free its slot.
                self.size -= 1

    @contextlib.contextmanager
    def cursor(self, *args, **kwargs):
        """Context manager yielding ``conn.cursor(*args, **kwargs)``.

        The cursor lives inside :meth:`connection`, so the transaction is
        committed or rolled back when the block exits.
        """
        with self.connection() as conn:
            yield conn.cursor(*args, **kwargs)

    def _rollback(self, conn):
        """Roll back *conn*; return it, or None if the rollback failed.

        A failed rollback is reported to the gevent hub, and the connection
        is closed (best effort) so its socket is not leaked when the caller
        drops it.
        """
        try:
            conn.rollback()
        except:
            gevent.get_hub().handle_error(conn, *sys.exc_info())
            try:
                conn.close()
            except Exception:
                # Already reported above; the connection is discarded anyway.
                pass
            return None
        return conn

    def execute(self, *args, **kwargs):
        """Run one statement in its own transaction and return ``rowcount``.

        ``*args`` go to ``cursor.execute``; ``**kwargs`` go to ``conn.cursor``.
        """
        with self.cursor(**kwargs) as cursor:
            cursor.execute(*args)
            return cursor.rowcount

    def fetchone(self, *args, **kwargs):
        """Run a query in its own transaction and return its first row.

        ``*args`` go to ``cursor.execute``; ``**kwargs`` go to ``conn.cursor``.
        """
        with self.cursor(**kwargs) as cursor:
            cursor.execute(*args)
            return cursor.fetchone()

    def fetchall(self, *args, **kwargs):
        """Run a query in its own transaction and return all rows.

        ``*args`` go to ``cursor.execute``; ``**kwargs`` go to ``conn.cursor``.
        """
        with self.cursor(**kwargs) as cursor:
            cursor.execute(*args)
            return cursor.fetchall()

    def fetchiter(self, *args, **kwargs):
        """Run a query and lazily yield its rows, fetched in batches.

        The connection stays checked out (and the transaction open) until
        the generator is exhausted or closed, so consumers should not leave
        it half-iterated.

        ``*args`` go to ``cursor.execute``; ``**kwargs`` go to ``conn.cursor``.
        """
        with self.cursor(**kwargs) as cursor:
            cursor.execute(*args)
            while True:
                items = cursor.fetchmany()
                if not items:
                    break
                for item in items:
                    yield item


class PostgresConnectionPool(AbstractDatabaseConnectionPool):
    """Connection pool that opens psycopg2 connections.

    Positional and keyword arguments are passed to the ``connect`` callable
    for every new connection, except for these keywords, consumed here:

    :param connect: connection factory; defaults to ``psycopg2.connect``.
    :param maxsize: connection cap forwarded to the base class. When omitted
        the base-class default is used.
    """

    def __init__(self, *args, **kwargs):
        self.connect = kwargs.pop('connect', psycopg2.connect)
        # Forward maxsize only when the caller supplied one, so that the
        # base-class default applies otherwise.
        pool_kwargs = {}
        if 'maxsize' in kwargs:
            pool_kwargs['maxsize'] = kwargs.pop('maxsize')
        self.args = args
        self.kwargs = kwargs
        AbstractDatabaseConnectionPool.__init__(self, **pool_kwargs)

    @staticmethod
    def set_keepalive(conn):
        """Enable aggressive TCP keepalive on the socket behind *conn*.

        Probing starts after 3 s idle, repeats every 1 s and gives up after 3
        unanswered probes. Tuning options that this platform's ``socket``
        module does not define are skipped.
        """
        # fromfd() duplicates the descriptor: options set on the duplicate
        # apply to the shared socket, and closing the duplicate leaves
        # conn's own descriptor open.
        sock = socket.fromfd(conn.fileno(),
                             socket.AF_INET, socket.SOCK_STREAM)
        try:
            # Enable sending of keep-alive messages.
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            for option_name, value in (
                ('TCP_KEEPIDLE', 3),   # idle seconds before the first probe
                ('TCP_KEEPINTVL', 1),  # seconds between individual probes
                ('TCP_KEEPCNT', 3),    # unanswered probes before dropping
            ):
                option = getattr(socket, option_name, None)
                if option is not None:
                    sock.setsockopt(socket.IPPROTO_TCP, option, value)
        finally:
            sock.close()

    def create_connection(self):
        """Open a psycopg2 connection with keepalive and JSONB decoding set up.

        JSONB values are decoded with the module-level ``json.loads``. If
        configuring the new connection fails, it is closed before the error
        propagates so its socket is not leaked.
        """
        conn = self.connect(*self.args, **self.kwargs)
        configured = False
        try:
            self.set_keepalive(conn)
            pgex.register_default_jsonb(conn, loads=json.loads)
            configured = True
        finally:
            if not configured:
                conn.close()
        return conn

    @classmethod
    def from_conf(cls, conf):
        """Build a pool from a config mapping.

        Uses :func:`make_dsn` for the DSN, ``LoggingConnection`` as the
        connection factory, and ``conf['db']['pool_size']`` (default 1) as the
        connection cap.
        """
        return cls(
            dsn=make_dsn(conf),
            connection_factory=LoggingConnection,
            maxsize=conf['db'].get('pool_size', 1),
        )


def _quote_dsn_value(value):
    """Render *value* for a libpq ``keyword=value`` connection string.

    Values that are empty or contain whitespace, single quotes or
    backslashes are wrapped in single quotes, with ``'`` and ``\\`` escaped
    by a backslash as libpq requires. Anything else is emitted verbatim.
    """
    text = '{}'.format(value)
    if text and not any(ch.isspace() or ch in "'\\" for ch in text):
        return text
    return "'" + text.replace('\\', '\\\\').replace("'", "\\'") + "'"


def make_dsn(conf):
    """Build a libpq DSN string from ``conf['db']``.

    Reads ``dbname``, ``user`` and ``port``, plus the optional ``host`` and
    ``password``. When ``host`` is missing or empty, the current datacenter
    is read (lower-cased) from the JSON file at
    ``conf['qloud']['meta_filepath']`` and looked up in
    ``conf['db']['host_by_dc']``. The password is appended only when it is
    truthy.

    :raises IOError: if the qloud meta file does not exist.
    :raises ValueError: if the current datacenter has no entry in
        ``host_by_dc``.
    """
    db_conf = conf['db']
    host = db_conf.get('host')
    if not host:
        meta_path = conf['qloud']['meta_filepath']
        # Explicit checks instead of asserts: asserts vanish under -O.
        if not os.path.exists(meta_path):
            raise IOError('Qloud meta file not found: %s' % meta_path)
        with open(meta_path) as fd:
            current_dc = json.load(fd)['datacenter'].lower()
        host_by_dc = db_conf['host_by_dc']
        if current_dc not in host_by_dc:
            raise ValueError('Unknown dc: %s' % current_dc)
        host = host_by_dc[current_dc]

    dbname = db_conf['dbname']
    user = db_conf['user']
    port = db_conf['port']
    password = db_conf.get('password')

    settings = [('host', host), ('port', port), ('user', user), ('dbname', dbname)]
    if password:
        settings.append(('password', password))
    # Quote values so spaces, quotes or backslashes (typically in passwords)
    # cannot break or change the meaning of the connection string.
    return ' '.join(
        '{}={}'.format(key, _quote_dsn_value(value)) for key, value in settings
    )
