# -*- coding: utf-8 -*-

from functools import partial
import logging
import time

from alembic import command
from flask import (
    current_app,
    g,
)
from flask_migrate import Migrate
from flask_sqlalchemy import (
    _EngineConnector,
    SQLAlchemy as OriginalSQLAlchemy,
)
from passport.backend.core.lazy_loader import LazyLoader
from passport.backend.utils.common import chunks
from sqlalchemy import (
    event,
    orm,
)
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError


def chunked_merge(
    data,
    chunk_size,
    log_message='Upserted chunk of size %s',
):
    """Merge (upsert) ``data`` into the session, committing once per chunk.

    Every element is merged inside its own nested transaction
    (``begin_nested``), so an ``IntegrityError`` raised for one element is
    logged and skipped while the remaining elements of the chunk still go in.

    :param data: objects to merge; split into groups with ``chunks``.
    :param chunk_size: number of elements per commit.
    :param log_message: %-style message logged to ``info_logger`` with the
        chunk length after each successful commit; a falsy value disables it.
    :raises Exception: any other failure is logged, the session is rolled
        back and the exception is re-raised.
    """
    db = get_db()
    failure_log = logging.getLogger('exception_logger')

    for batch in chunks(data, chunk_size):
        try:
            for item in batch:
                # Savepoint per item: a constraint violation only undoes this item.
                try:
                    with db.session.begin_nested():
                        db.session.merge(item)
                except IntegrityError as conflict:
                    failure_log.exception(conflict, exc_info=True)

            db.session.commit()
            if log_message:
                logging.getLogger('info_logger').info(log_message, len(batch))
        except Exception as error:  # pragma: no cover
            failure_log.exception(error, exc_info=True)
            db.session.rollback()
            raise


def chunked_delete(
    data,
    chunk_size,
    log_message='Deleted chunk of size %s',
):
    """Delete ``data`` from the database, committing once per chunk.

    :param data: mapped objects to delete; split into groups with ``chunks``.
    :param chunk_size: number of deletions per commit.
    :param log_message: %-style message logged to ``info_logger`` with the
        chunk length after each successful commit; a falsy value disables it.
    :raises Exception: any failure is logged, the session is rolled back
        and the exception is re-raised.
    """
    db = get_db()
    failure_log = logging.getLogger('exception_logger')

    for batch in chunks(data, chunk_size):
        try:
            for doomed in batch:
                db.session.delete(doomed)

            db.session.commit()
            if log_message:
                logging.getLogger('info_logger').info(log_message, len(batch))
        except Exception as error:  # pragma: no cover
            failure_log.exception(error, exc_info=True)
            db.session.rollback()
            raise


class AutoRouteSession(orm.Session):
    """ORM session that asks the ``db`` object for its engine on every bind.

    ``get_bind`` ignores the mapper/clause it is given and always returns
    ``get_db().get_engine(app=current_app, bind=None)``; any master/slave
    routing is therefore left to that ``get_engine`` implementation.
    """

    def __init__(self, db, autocommit=False, autoflush=False, **options):
        # NOTE(review): presumably consumed by Flask-SQLAlchemy's
        # model-change tracking — confirm against the installed version.
        self._model_changes = {}
        default_engine = db.get_engine()
        configured_binds = db.get_binds(current_app)
        super(AutoRouteSession, self).__init__(
            autocommit=autocommit,
            autoflush=autoflush,
            bind=default_engine,
            binds=configured_binds,
            **options
        )

    def get_bind(self, mapper=None, clause=None):
        # mapper and clause are intentionally unused: every statement goes to
        # whichever engine the db object selects for the current app.
        database = get_db()
        return database.get_engine(app=current_app, bind=None)


class SQLAlchemy(OriginalSQLAlchemy):
    """Override to fix issues when doing a rollback with sqlite driver
    See http://docs.sqlalchemy.org/en/rel_1_0/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
    and https://bitbucket.org/zzzeek/sqlalchemy/issues/3561/sqlite-nested-transactions-fail-with
    for further information

    Only do this on sqlite, not on MySQL, Postgres nor other RDBMS!

    On top of that this subclass:
    * logs the duration of every cursor execution to ``vault.database.stat``;
    * forwards ``SQLALCHEMY_POOL_PRE_PING`` to the engine pool options;
    * routes to the ``slave`` bind when ``g.use_slave`` is truthy
      (non-sqlite databases only);
    * builds sessions of type ``AutoRouteSession``.
    """

    def __init__(self, config, *args, **kwargs):
        super(SQLAlchemy, self).__init__(*args, **kwargs)
        # Config mapping; SQLALCHEMY_DATABASE_URI is read from it to detect sqlite.
        self.config = config

    def _uses_sqlite(self):
        """Return True when the configured database URI points at sqlite."""
        return self.config['SQLALCHEMY_DATABASE_URI'].startswith('sqlite://')

    @staticmethod
    def _before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
        """Push the query start time onto a per-connection stack."""
        conn.info.setdefault('query_start_time', []).append(time.time())

    @staticmethod
    def _after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
        """Pop the matching start time and log the query duration."""
        start_times = conn.info.get('query_start_time')
        if not start_times:
            # No matching start recorded (e.g. listeners attached mid-query);
            # never break query execution just because we cannot time it.
            return
        total = time.time() - start_times.pop(-1)
        logging.getLogger('vault.database.stat').info(dict(
            bind=getattr(conn.engine, 'yandex_vault_bind', None) or 'master',
            duration=total,
            dialect=conn.engine.dialect.name,
            has_errors=False
        ))

    def init_app(self, app):
        super(SQLAlchemy, self).init_app(app)
        app.config.setdefault('SQLALCHEMY_POOL_PRE_PING', None)

        # These listeners hang off the Engine *class*, i.e. they are global to
        # the process. Register each one only once so repeated init_app()
        # calls do not multiply every stat record.
        for identifier, listener in (
            ('before_cursor_execute', SQLAlchemy._before_cursor_execute),
            ('after_cursor_execute', SQLAlchemy._after_cursor_execute),
        ):
            if not event.contains(Engine, identifier, listener):
                event.listen(Engine, identifier, listener)

    def apply_pool_defaults(self, app, options):
        options = super(SQLAlchemy, self).apply_pool_defaults(app, options)
        # Only override pool_pre_ping when it is explicitly configured.
        pre_ping = app.config.get('SQLALCHEMY_POOL_PRE_PING')
        if pre_ping is not None:
            options['pool_pre_ping'] = pre_ping
        return options

    def make_connector(self, app=None, bind=None):
        """Creates the connector for a given state and bind."""
        if self._uses_sqlite():
            return CustomSQLiteEngineConnector(self, self.get_app(app), bind)
        return super(SQLAlchemy, self).make_connector(app, bind)  # pragma: no cover

    def get_engine(self, app=None, bind=None):
        # g.use_slave switches the request to the 'slave' bind; sqlite has no
        # such bind, so the flag is ignored there.
        if getattr(g, 'use_slave', False) and not self._uses_sqlite():
            bind = 'slave'  # pragma: no cover
        engine = super(SQLAlchemy, self).get_engine(app=app, bind=bind)
        # Tag the engine so the stat listener can report which bind ran a query.
        engine.yandex_vault_bind = bind
        return engine

    def create_scoped_session(self, options=None):
        """Helper factory method that creates a scoped session.

        ``scopefunc`` (if present in ``options``) configures the scoped
        registry; all remaining options are passed to ``AutoRouteSession``.
        The caller's ``options`` dict is left unmodified.
        """
        options = dict(options or {})
        scopefunc = options.pop('scopefunc', None)
        return orm.scoped_session(
            partial(AutoRouteSession, self, **options),
            scopefunc=scopefunc,
        )


class CustomSQLiteEngineConnector(_EngineConnector):
    """Used by the overridden SQLAlchemy class to fix sqlite rollback issues.

    pysqlite's implicit transaction handling is turned off on every new DBAPI
    connection, and an explicit ``BEGIN`` is emitted whenever SQLAlchemy
    begins a transaction.
    """

    @staticmethod
    def _on_connect(dbapi_connection, connection_record):
        # disable pysqlite's emitting of the BEGIN statement entirely.
        # also stops it from emitting COMMIT before any DDL.
        dbapi_connection.isolation_level = None

    @staticmethod
    def _on_begin(conn):
        # emit our own BEGIN
        conn.execute('BEGIN')

    def get_engine(self):
        # Fast path: reuse the cached engine without touching listeners.
        uri = self.get_uri()
        echo = self._app.config['SQLALCHEMY_ECHO']
        if (uri, echo) == self._connected_for:
            return self._engine

        # The parent call happens outside our lock, as before. It may return an
        # engine that is already cached and instrumented, e.g. one another
        # thread created after our fast-path check. Registering the listeners
        # twice would emit BEGIN twice per transaction. So attach each
        # (stable) listener only when it is not already present.
        rv = super(CustomSQLiteEngineConnector, self).get_engine()
        with self._lock:
            for identifier, listener in (
                ('connect', CustomSQLiteEngineConnector._on_connect),
                ('begin', CustomSQLiteEngineConnector._on_begin),
            ):
                if not event.contains(rv, identifier, listener):
                    event.listen(rv, identifier, listener)
        return rv


# Module-level Flask-Migrate extension; bound to an app by configure_migrations().
migrate = Migrate()


def get_db():
    """Return the instance registered under the ``'db'`` key in ``LazyLoader``."""
    db_instance = LazyLoader.get_instance('db')
    return db_instance


def get_migrations_dir(config):
    """Return ``config['application']['migrations_dir']``.

    :raises KeyError: if either key is missing.
    """
    application_settings = config['application']
    return application_settings['migrations_dir']


def configure_migrations(app, config):
    """Bind the module-level ``migrate`` extension to ``app``.

    Uses the ``db`` returned by ``get_db()`` and the migrations directory
    taken from ``config['application']['migrations_dir']``.
    """
    database = get_db()
    migrations_directory = get_migrations_dir(config)
    migrate.init_app(app=app, db=database, directory=migrations_directory)


def upgrade_db(app):  # pragma: no cover
    """Helper that applies migrations up to ``head`` in the dev environment (``app.py run``)."""
    with app.app_context():
        migrate_state = app.extensions['migrate']
        alembic_config = migrate_state.migrate.get_config()
        command.upgrade(alembic_config, 'head')
