import psycopg2
import psycopg2.extras
import psycopg2.extensions
import select
import logging
import time

class CountError(Exception):
    """Raised by count() when the row-count query fails."""
class ReadError(Exception):
    """Raised by read() when the select query fails."""
class DeleteError(Exception):
    """Raised by delete() when the delete statement fails."""
class InsertError(Exception):
    """Raised by insert() when the bulk insert fails."""

# Comma-separated column list of xiva.subscriptions that is spliced into both
# select_sql and insert_sql, so rows returned by the SELECT have the same
# columns, in the same order, as the INSERT expects.
migrated_columns = '''
    gid,
    service,
    id,
    callback,
    filter,
    extra_data,
    client,
    ttl,
    session_key,
    init_local_id,
    init_time,
    platform,
    uid,
    device,
    bb_connection_id,
    uidset,
    ack_local_id,
    ack_time,
    smart_notify,
    next_retry_time,
    retry_interval,
    ack_event_ts
'''

# Counts the rows of xiva.subscriptions for a single gid. %(gid)s is a named
# placeholder filled from the parameters passed to cursor.execute().
count_sql = '''
    SELECT count(*) FROM xiva.subscriptions WHERE gid = %(gid)s;
'''

# Selects the migrated_columns of every xiva.subscriptions row for a single
# gid. %(gid)s is a named placeholder filled from the parameters passed to
# cursor.execute().
select_sql = '''
    SELECT
''' + migrated_columns + '''
    FROM
        xiva.subscriptions
    WHERE
        gid = %(gid)s;
'''

# Bulk-insert template covering migrated_columns. The single %s after VALUES
# is the placeholder that psycopg2.extras.execute_values() replaces with the
# list of row tuples (see insert()).
insert_sql = '''
    INSERT INTO xiva.subscriptions(
''' + migrated_columns + '''
    ) VALUES %s;
'''

# Deletes every xiva.subscriptions row for a single gid. RETURNING 1 makes the
# statement produce a result set (one row per deleted row), which is what the
# fetchall() call in execute() reads.
delete_sql = '''
    DELETE FROM xiva.subscriptions WHERE gid = %(gid)s RETURNING 1;
'''

# Since we're connecting to pgbouncer with transaction pooling,
# we cannot set statement_timeout in any way on the server, therefore
# we're forced to time out on the client. It is impossible in sync mode,
# that's why we wait on postgres socket manually with timeout passed to
# select.
def wait(conn):
    """Block until the pending operation on ``conn`` finishes or times out.

    Polls the connection and waits on its socket with select() until
    ``conn.poll()`` reports POLL_OK.

    The time budget is ``request_timeout`` seconds, measured from the moment
    this function is entered. NOTE(review): ``request_timeout`` is a
    module-level name that is not defined in the visible source; presumably
    it is assigned by the code that imports this module -- confirm.

    Raises psycopg2.OperationalError if poll() returns an unexpected state,
    or if the budget is exhausted (the running query is cancelled first).
    """
    # Measure elapsed time with a monotonic clock when one is available, so
    # wall-clock adjustments can neither trigger a spurious timeout nor hide
    # a real one. Fall back to time.time() on interpreters without it.
    clock = getattr(time, 'monotonic', time.time)
    started = clock()

    while True:
        state = conn.poll()
        if state == psycopg2.extensions.POLL_OK:
            return

        if state == psycopg2.extensions.POLL_WRITE:
            # Wait only for what is left of the budget; waiting for the full
            # request_timeout here could let the request overshoot it.
            remaining = max(0.0, request_timeout - (clock() - started))
            select.select([], [conn.fileno()], [], remaining)
        elif state == psycopg2.extensions.POLL_READ:
            # Not using the actual timeout here, since in some cases
            # select does not seem to trigger, although conn.poll is
            # able to read something from the socket (or maybe it has a bug
            # and returns POLL_READ when it should have returned POLL_OK?).
            # This results in timeouts on requests that should not have
            # timed out. Happens steadily on some gids in testing.
            # A short wait followed by another poll() works around it.
            select.select([conn.fileno()], [], [], 0.1)
        else:
            raise psycopg2.OperationalError('poll() returned %s' % state)

        if clock() - started >= request_timeout:
            conn.cancel()
            raise psycopg2.OperationalError('request timed out')

# Register wait() as psycopg2's wait callback at import time, so that
# psycopg2 calls it to wait on the connection socket, which is how the
# client-side timeout is enforced.
psycopg2.extensions.set_wait_callback(wait)

def _error_str(e):
    return "%s: %s" % (type(e).__name__, e)

def execute(conninfo, sql, **args):
    """Run one statement on a fresh connection and return all result rows.

    conninfo -- connection string passed to psycopg2.connect().
    sql      -- statement text; may use %(name)s placeholders.
    **args   -- values for the placeholders, passed to cursor.execute().

    Returns whatever cursor.fetchall() returns, so the statement must
    produce a result set. The transaction is committed on success and rolled
    back if an exception is raised; the connection is closed in both cases.
    """
    connection = psycopg2.connect(conninfo)
    try:
        # ``with connection`` only ends the transaction (commit/rollback);
        # it does not close the connection, hence the explicit close below.
        with connection:
            with connection.cursor() as cursor:
                cursor.execute(sql, args)
                return cursor.fetchall()
    finally:
        connection.close()

def count(gid, conninfo):
    """Return the number of xiva.subscriptions rows stored for ``gid``.

    Any exception raised while running the query is re-raised as CountError
    carrying the original exception's class name and message.
    """
    try:
        rows = execute(conninfo, count_sql, gid=gid)
        # The count is the single column of the first result row.
        return rows[0][0]
    except Exception as exc:
        raise CountError(_error_str(exc))

def read(gid, conninfo):
    """Return the migrated columns of all xiva.subscriptions rows for ``gid``.

    The rows come back exactly as execute() returns them. Any exception
    raised while running the query is re-raised as ReadError carrying the
    original exception's class name and message.
    """
    try:
        rows = execute(conninfo, select_sql, gid=gid)
    except Exception as exc:
        raise ReadError(_error_str(exc))
    return rows

def delete(gid, conninfo):
    """Delete all xiva.subscriptions rows for ``gid``.

    Returns the rows produced by the statement's RETURNING clause, exactly
    as execute() returns them (one row per deleted row). Any exception
    raised while running the statement is re-raised as DeleteError carrying
    the original exception's class name and message.
    """
    try:
        deleted = execute(conninfo, delete_sql, gid=gid)
    except Exception as exc:
        raise DeleteError(_error_str(exc))
    return deleted

def insert(subs, conninfo):
    """Bulk-insert the rows in ``subs`` into xiva.subscriptions.

    subs     -- sequence of row values handed to
                psycopg2.extras.execute_values(); presumably each row lists
                its values in migrated_columns order -- verify against caller.
    conninfo -- connection string passed to psycopg2.connect().

    Rows are sent in pages of ``insert_page_size``. NOTE(review):
    ``insert_page_size`` is a module-level name that is not defined in the
    visible source; presumably it is assigned by the code that imports this
    module -- confirm.

    The transaction is committed on success and rolled back on error; the
    connection is closed in both cases. Any exception is re-raised as
    InsertError carrying the original exception's class name and message.
    """
    try:
        connection = psycopg2.connect(conninfo)
        try:
            # ``with connection`` only ends the transaction (commit or
            # rollback); it does not close the connection, hence the
            # explicit close below.
            with connection:
                with connection.cursor() as cursor:
                    psycopg2.extras.execute_values(
                        cursor, insert_sql, subs, page_size=insert_page_size)
        finally:
            connection.close()
    except Exception as e:
        raise InsertError(_error_str(e))
