from contextlib import contextmanager
from functools import partial

import flask
import psycopg2
import psycopg2.extensions
import psycopg2.extras
import psycopg2.pool
from flask_sqlalchemy import SQLAlchemy, SessionBase, SignallingSession, get_state


class FlaskMasterSlave(object):
    """Flask extension that flags read-only requests for replica routing.

    Registers itself as ``app.extensions['fms']``. When a ``slave`` entry
    exists in ``SQLALCHEMY_BINDS``, it installs a ``before_request`` hook.
    The hook sets ``flask.g.db_readonly_method`` to whether the matched view
    (function, view class, or the class's handler for the request method)
    was marked with ``db_readonly_method = True``.
    """

    def __init__(self, app=None):
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Register the extension on *app*; a second call for the same app is a no-op."""
        if 'fms' in app.extensions:
            return

        app.extensions['fms'] = self
        slave = (app.config.get('SQLALCHEMY_BINDS') or {}).get('slave')
        if slave is None:
            # No replica configured: the flag is never set, so readers of
            # flask.g fall back to their default (read-write) behaviour.
            return

        app.before_request(partial(self.select_conn, app))

    def select_conn(self, app):
        """Set ``flask.g.db_readonly_method`` for the current request.

        :param app: the Flask app whose ``view_functions`` resolve the
            current request's endpoint.
        """
        req = flask.request
        # endpoint is None for unmatched URLs; .get(None) then yields None.
        view_func = app.view_functions.get(req.endpoint)
        flask.g.db_readonly_method = self._view_is_readonly(view_func, req.method)

    @staticmethod
    def _view_is_readonly(view_func, http_method):
        """Return True if *view_func* handling *http_method* is marked read-only.

        The checks run in this order:

        1. the view function itself;
        2. its ``view_class`` (class-based views);
        3. the class's handler method for *http_method*.

        For HEAD requests without a ``head`` handler, ``get`` is consulted.
        That mirrors how Flask's ``MethodView`` dispatches HEAD.
        """
        if getattr(view_func, 'db_readonly_method', False):
            return True

        view_class = getattr(view_func, 'view_class', None)
        if not view_class:
            return False

        if getattr(view_class, 'db_readonly_method', False):
            return True

        handler = getattr(view_class, http_method.lower(), None)
        if handler is None and http_method.upper() == 'HEAD':
            handler = getattr(view_class, 'get', None)
        return bool(getattr(handler, 'db_readonly_method', False))


class MasterSlaveSession(SignallingSession):
    """Flask-SQLAlchemy session that sends read-only requests to the ``slave`` bind.

    Bind resolution order in :meth:`get_bind`:

    1. a mapped table whose ``info`` carries a ``bind_key`` uses that bind;
    2. otherwise, if ``flask.g.db_readonly_method`` is truthy in the active
       app context, the ``slave`` bind is used;
    3. otherwise the default from ``SessionBase.get_bind`` is used.
    """

    def get_bind(self, mapper=None, clause=None, **kw):
        if mapper is not None:
            try:
                # SQLAlchemy >= 1.3 (mapped_table is deprecated there).
                persist_selectable = mapper.persist_selectable
            except AttributeError:
                # Older SQLAlchemy.
                persist_selectable = mapper.mapped_table

            info = getattr(persist_selectable, 'info', {})
            bind_key = info.get('bind_key')
            if bind_key is not None:
                state = get_state(self.app)
                return state.db.get_engine(self.app, bind=bind_key)

        # Touching flask.g without an app context raises RuntimeError (e.g. in
        # scripts or background workers); treat that case as read-write.
        if flask.has_app_context() and getattr(flask.g, 'db_readonly_method', False):
            state = get_state(self.app)
            return state.db.get_engine(self.app, bind='slave')

        # Forward any extra keyword arguments newer SQLAlchemy versions pass.
        return SessionBase.get_bind(self, mapper, clause, **kw)


class MasterSlaveSqlAlchemy(SQLAlchemy):
    """``SQLAlchemy`` extension that uses :class:`MasterSlaveSession` and custom engine options."""

    def apply_pool_defaults(self, app, options):
        """Apply the base pool defaults, then force isolation and reset-on-return.

        :param app: the Flask app whose config feeds the base defaults.
        :param options: engine keyword arguments dict; it is mutated in place.
        :returns: the options dict. Newer Flask-SQLAlchemy versions use the
            return value of this hook, so returning ``None`` would break
            engine creation. Older versions ignore it.
        """
        result = SQLAlchemy.apply_pool_defaults(self, app, options)
        if result is not None:
            options = result
        # NOTE(review): PostgreSQL treats READ UNCOMMITTED as READ COMMITTED.
        options['isolation_level'] = 'READ_UNCOMMITTED'
        # Roll back any open transaction when a connection goes back to the pool.
        options['pool_reset_on_return'] = 'rollback'
        return options

    def create_session(self, options):
        """Return a session factory that builds :class:`MasterSlaveSession` bound to this extension."""
        return partial(MasterSlaveSession, self, **options)


def db_readonly(fn):
    """Mark *fn* as read-only and hand it back unchanged.

    Sets ``db_readonly_method = True`` on the object. This works equally on a
    view function, a class-based view's handler method, or the view class.
    It is the attribute ``FlaskMasterSlave`` inspects for each request.
    """
    setattr(fn, 'db_readonly_method', True)
    return fn


class MasterSlaveConnectionPool(object):
    """Raw psycopg2 connection pools split between master and slave databases.

    Registers itself as ``app.extensions['mscp']``. When both ``master`` and
    ``slave`` entries exist in ``SQLALCHEMY_BINDS``, each gets its own pool.
    Otherwise a single pool on ``SQLALCHEMY_DATABASE_URI`` serves both roles.
    """

    def __init__(self, app=None):
        if app is not None:
            self.init_app(app)

    @staticmethod
    def _to_libpq_dsn(uri):
        """Strip a SQLAlchemy driver suffix so libpq accepts the URI.

        ``postgresql+psycopg2://...`` becomes ``postgresql://...``. Any other
        string, including key=value DSNs, is returned unchanged.
        """
        scheme, sep, rest = uri.partition('://')
        if sep and scheme.startswith(('postgresql+', 'postgres+')):
            return scheme.split('+', 1)[0] + sep + rest
        return uri

    def init_app(self, app):
        """Create the pools for *app*; a second call for the same app is a no-op."""
        if 'mscp' in app.extensions:
            return

        app.extensions['mscp'] = self

        binds = app.config.get('SQLALCHEMY_BINDS') or {}
        master_uri = binds.get('master')
        slave_uri = binds.get('slave')
        if not master_uri or not slave_uri:
            # Without a full master/slave pair, one shared pool serves both roles.
            master_pool = slave_pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=1,
                maxconn=10,
                dsn=self._to_libpq_dsn(app.config['SQLALCHEMY_DATABASE_URI']),
            )
        else:
            master_pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=0,
                maxconn=10,
                dsn=self._to_libpq_dsn(master_uri),
            )
            slave_pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=1,
                maxconn=10,
                dsn=self._to_libpq_dsn(slave_uri),
            )

        self.pools = {
            'master': master_pool,
            'slave': slave_pool,
        }

    @contextmanager
    def connection(self, readonly=None):
        """Check out a pooled connection for the duration of a ``with`` block.

        :param readonly: ``True`` selects the slave pool and ``False`` the
            master pool. ``None`` (the default) follows
            ``flask.g.db_readonly_method`` when an app context is active,
            and otherwise uses the master pool.
        :yields: a psycopg2 connection with isolation level
            ``ISOLATION_LEVEL_READ_UNCOMMITTED`` and ``NamedTupleCursor`` as
            its cursor factory.

        The caller's block runs inside ``with conn:``, so psycopg2 commits on
        success and rolls back on an exception. After a
        ``psycopg2.InterfaceError`` or ``DatabaseError``, the connection is
        closed rather than reused. The error is then re-raised. In every
        case the connection is handed back to its pool.
        """
        if readonly is None:
            # flask.g raises RuntimeError without an app context.
            readonly = flask.has_app_context() and getattr(flask.g, 'db_readonly_method', False)

        pool = self.pools['slave' if readonly else 'master']
        conn = pool.getconn()
        discard = False
        try:
            # Configure session attributes before any transaction is begun.
            conn.isolation_level = psycopg2.extensions.ISOLATION_LEVEL_READ_UNCOMMITTED
            conn.cursor_factory = psycopg2.extras.NamedTupleCursor
            with conn:
                yield conn
        except (psycopg2.InterfaceError, psycopg2.DatabaseError):
            # The connection may be broken; don't let it be reused.
            discard = True
            raise
        finally:
            # close=True makes the pool close the connection instead of keeping it.
            pool.putconn(conn, close=discard)
