# -*- coding: utf-8 -*-
from collections import defaultdict
from itertools import chain
import logging
from time import time

from passport.backend.core.conf import settings
from passport.backend.core.db.query import (
    DbTransaction,
    DbTransactionContainer,
    join_queries,
    split_query_and_callback,
)
from passport.backend.core.db.schemas import central_metadata
from passport.backend.core.dbmanager.exceptions import (
    DBDataError,
    DBError,
    DBIntegrityError,
    DBOperationalError,
)
from passport.backend.core.dbmanager.sharder import get_db_name
from passport.backend.core.dbmanager.transaction_manager import get_transaction_manager
from passport.backend.core.host.host import get_current_host
from passport.backend.core.logging_utils.loggers import (
    FAILED_RESPONSE_CODE,
    GraphiteLogger,
    SUCCESS_RESPONSE_CODE,
)
import sqlalchemy
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.url import URL
from sqlalchemy.pool import QueuePool


log = logging.getLogger('passport.dbmanager')


# Константы соответствуют очередному действию в транзакции
TRX_STATE_BEGIN = 'begin'
TRX_STATE_ACTIVE = 'active'  # Выполнение обычных запросов в рамках транзакции
TRX_STATE_COMMIT = 'commit'
TRX_STATE_ROLLBACK = 'rollback'


def _format_host(host):
    """
    В настройках БД host может зависеть от текущего ДЦ
    """
    if host and '%(dc)s' in host:
        host = host % dict(dc=get_current_host().get_dc())
    return host


def get_db_errors():
    # Импортируем здесь, чтобы не зависеть от конкретной реализации драйвера в момент импорта passport/backend/core
    import MySQLdb

    return (
        sqlalchemy.exc.InterfaceError,
        sqlalchemy.exc.DatabaseError,
        MySQLdb.InterfaceError,
        MySQLdb.DatabaseError,
    )


def safe_execute(engine, executable, retries=None, graphite_logger=None, trx_state=None):
    """
    Безопасное выполнение единицы работы в БД (executable) с учетом ретраев и таймаутов.
    Оборачивает все исключения SQLAlchemy и нижележащих уровней в DBError, DBIntegrityError,
    DBOperationalError, DBDataError.
    """
    # Импортируем здесь, чтобы не зависеть от конкретной реализации драйвера в момент импорта passport/backend/core
    import MySQLdb

    DB_INTEGRITY_ERRORS = (
        sqlalchemy.exc.IntegrityError,
        MySQLdb.IntegrityError,
    )
    DB_OPERATIONAL_ERRORS = (
        sqlalchemy.exc.OperationalError,
        MySQLdb.OperationalError,
    )
    DB_DATA_ERRORS = (
        sqlalchemy.exc.DataError,
        MySQLdb.DataError,
    )
    DB_ERRORS = get_db_errors()
    FATAL_MYSQL_OPERATIONAL_ERROR_CODES = (1048,)

    graphite_logger = graphite_logger or GraphiteLogger(service='db')
    host_name = engine.url.host
    retries = retries or engine.reconnect_retries
    caught_exception = None
    is_connection_alive = True
    for i in range(retries):
        network_error = False
        is_connection_alive = True
        response_code = SUCCESS_RESPONSE_CODE
        start_time = time()
        caught_exception = None
        retries_left = retries - i - 1  # текущую попытку не учитываем
        try:
            return executable(engine)
        except DB_INTEGRITY_ERRORS as e:
            caught_exception = e
            # Нет смысла повторять попытки при ошибке целостности
            retries_left = 0
            response_code = FAILED_RESPONSE_CODE
            log.error(
                'Database integrity error: host %s, has_low_timeout=%d, trx_state=%s',
                host_name,
                engine.has_low_timeout,
                trx_state,
                exc_info=caught_exception,
            )
            raise DBIntegrityError(e)
        except DB_OPERATIONAL_ERRORS as e:
            caught_exception = e
            # Нет смысла повторять попытки
            retries_left = 0
            response_code = FAILED_RESPONSE_CODE
            log.error(
                'Database operational error: host %s, has_low_timeout=%d, trx_state=%s',
                host_name,
                engine.has_low_timeout,
                trx_state,
                exc_info=caught_exception,
            )
            orig_e = e.orig if isinstance(e, sqlalchemy.exc.DBAPIError) else e
            if orig_e.args[0] in FATAL_MYSQL_OPERATIONAL_ERROR_CODES:
                # Проверяем код ошибки, полученный от MySQL и заворачиваем фатальные ошибки
                # в DBIntegrityError
                # PASSP-23157
                raise DBIntegrityError(e)
            raise DBOperationalError(e)
        except DB_DATA_ERRORS as e:
            caught_exception = e
            # Нет смысла повторять попытки при невалидных данных
            retries_left = 0
            response_code = FAILED_RESPONSE_CODE
            log.error(
                'Database data error: host %s, has_low_timeout=%d, trx_state=%s',
                host_name,
                engine.has_low_timeout,
                trx_state,
                exc_info=e,
            )
            raise DBDataError(e)
        except DB_ERRORS as e:
            caught_exception = e
            response_code = FAILED_RESPONSE_CODE
            is_connection_alive = hasattr(e, 'connection_invalidated') and not e.connection_invalidated
            network_error = not is_connection_alive
            log.info(
                'Database attempt [%d/%d] error: host %s, has_low_timeout=%d, trx_state=%s, is_connection_alive=%s',
                i + 1,
                retries,
                host_name,
                engine.has_low_timeout,
                trx_state,
                is_connection_alive,
                exc_info=caught_exception,
            )
            if trx_state == TRX_STATE_ACTIVE and not is_connection_alive:
                # В случае выполнения запросов транзакции, если мы с достаточной уверенностью не можем сказать,
                # что соединение является рабочим, нужно перевыполнять всю транзакцию
                retries_left = 0
                raise DBError(e)
        finally:
            graphite_logger.log(
                duration=time() - start_time,
                response=response_code,
                network_error=network_error,
                srv_hostname=host_name,
                srv_ipaddress=host_name,
                retries_left=retries_left,
                with_low_timeout=engine.has_low_timeout,
                trx_state=trx_state,
                is_connection_alive=int(is_connection_alive),
            )

    log.error(
        'Database error, execute failed: host %s, has_low_timeout=%d, trx_state=%s, is_connection_alive=%s',
        host_name,
        engine.has_low_timeout,
        trx_state,
        is_connection_alive,
        exc_info=caught_exception,
    )
    raise DBError(caught_exception)


def safe_execute_queries(queries, engine=None, retries=None, transaction_retries=None, trx_state=None,
                         graphite_logger=None, with_low_timeout=False):
    """
    Выполнить последовательность запросов (с опциональным callback-ом) с учетом ретраев и таймаутов. Предварительно
    запросы джойнятся.
    """
    tm = get_transaction_manager()
    passed_engine = engine
    for query, callback in join_queries(queries):
        if isinstance(query, DbTransactionContainer):
            if trx_state == TRX_STATE_ACTIVE:
                raise ValueError('Nested transactions are not supported')
            _safe_execute_transaction(
                query,
                retries=retries,
                transaction_retries=transaction_retries,
                graphite_logger=graphite_logger,
                with_low_timeout=with_low_timeout,
            )
        elif tm.started and trx_state != TRX_STATE_ACTIVE:
            # Выполняем запросы под управлением менеджера транзакций. Каждый запрос нужно выполнить в контексте
            # транзакции. FIXME: сделать менее костыльно
            _safe_execute_transaction(
                DbTransaction(lambda: [(query, callback)])(),
                retries=retries,
                transaction_retries=transaction_retries,
                graphite_logger=graphite_logger,
                with_low_timeout=with_low_timeout,
            )
        else:
            raw_query = query.to_query()
            # Если мы не в транзакции и engine явно не передан, engine надо выбирать для каждого запроса.
            # Иначе, engine выбирается один раз при старте транзакции.
            if trx_state != TRX_STATE_ACTIVE and passed_engine is None:
                dbm = find_dbm_for_eav_query(query)
                engine = dbm.get_engine(force_master=not raw_query.is_selectable, with_low_timeout=with_low_timeout)

            log.debug('Query: %s', str(raw_query.compile(dialect=engine.dialect)))
            try:
                result = safe_execute(
                    engine,
                    executable=lambda engine_: engine_.execute(raw_query),
                    retries=retries,
                    trx_state=trx_state,
                    graphite_logger=graphite_logger,
                )
                try:
                    if callback:
                        callback(result)
                finally:
                    result.close()
            except Exception as e:
                if query.ignore_errors():
                    log.warning('Error executing eav query: %s', query, exc_info=e)
                    continue
                raise


def _safe_execute_transaction(transaction_container, retries=None, transaction_retries=None, graphite_logger=None,
                              with_low_timeout=False):
    """
    Безопасно выполняет транзакцию, с учетом ретраев и таймаутов. Поддерживает ретраи всей транзакции.
    """
    graphite_logger = graphite_logger or GraphiteLogger(service='db')

    # Нужно узнать, в какую базу идти, и есть ли вообще запросы в транзакции.
    queries = transaction_container.get_queries()
    try:
        pair = next(queries)
    except StopIteration:
        return
    first_query, callback = split_query_and_callback(pair)

    dbm = find_dbm_for_eav_query(first_query)
    tm = get_transaction_manager()

    engine = dbm.get_engine(with_low_timeout=with_low_timeout)
    transaction = tm.enter_transaction(engine)
    transaction_retries = transaction_retries or engine.reconnect_retries
    if tm.started:
        # При работе под управлением менеджера транзакций - не умеем ретраиться
        transaction_retries = 1

    # Необходимо использовать тот же генератор, на случай если данная транзакция не поддерживает ретраи
    queries = chain([(first_query, callback)], queries)
    last_exc = None
    for i in range(transaction_retries):
        if i > 0:
            # Получаем новый генератор запросов на последующих итерациях
            queries = transaction_container.get_queries()
        begin_executed = all_queries_executed = has_unrecoverable_error = False
        try:
            # Если мы работаем под управлением менеджера транзакций, мы можем получить уже запущенную транзакцию
            if not transaction.is_started:
                # BEGIN выполняем как обычный запрос с заданным числом ретраев
                safe_execute(
                    engine,
                    transaction.begin,
                    retries=retries,
                    graphite_logger=graphite_logger,
                    trx_state=TRX_STATE_BEGIN,
                )
            begin_executed = True
            transaction_engine = transaction.connection
            safe_execute_queries(
                queries,
                engine=transaction_engine,
                trx_state=TRX_STATE_ACTIVE,
                retries=retries,
                graphite_logger=graphite_logger,
            )
            all_queries_executed = True
        except (DBIntegrityError, DBDataError) as exc:
            last_exc = exc
            # В случае любой ошибки целостности или валидации данных нет смысла делать ретраи
            log.error(
                'Database transaction attempt [%d/%d] failed due to %s, started=%s',
                i + 1,
                transaction_retries,
                exc.__class__.__name__,
                begin_executed,
                exc_info=exc,
            )
            has_unrecoverable_error = True
        except DBOperationalError as exc:
            last_exc = exc
            log.error(
                'Database transaction attempt [%d/%d] failed due to %s, started=%s',
                i + 1,
                transaction_retries,
                exc.__class__.__name__,
                begin_executed,
                exc_info=exc,
            )
        except DBError as exc:
            last_exc = exc
            log.info(
                'Database transaction attempt [%d/%d] error, started=%s',
                i + 1,
                transaction_retries,
                begin_executed,
                exc_info=exc,
            )
        except Exception as exc:
            last_exc = exc
            # Случай неизвестной ошибки (вероятнее всего, в сериализаторе)
            log.info(
                'Database transaction attempt [%d/%d] failed due to unknown error, started=%s',
                i + 1,
                transaction_retries,
                begin_executed,
                exc_info=exc,
            )
            has_unrecoverable_error = True

        if not begin_executed:
            # Транзакция даже не началась, смысла ретраить дальше нет
            break

        if all_queries_executed:
            if tm.started:
                # Если менеджер транзакций управляет транзакцией, не делаем commit
                return
            try:
                return safe_execute(
                    transaction_engine,
                    lambda engine_: transaction.commit_and_close(),
                    retries=1,
                    graphite_logger=graphite_logger,
                    trx_state=TRX_STATE_COMMIT,
                )
            except DBError as exc:
                last_exc = exc
                # Произошла ошибка при коммите транзакции - не ретраим, необходимо попытаться вызвать rollback и
                # закрыть соединение. Объект транзакции остается активным, даже если соединение инвалидировано - см.
                # http://docs.sqlalchemy.org/en/rel_0_8/core/connections.html#sqlalchemy.engine.Connection.invalidate
                log.error(
                    'Database transaction attempt [%d/%d] failed to commit',
                    i + 1,
                    transaction_retries,
                    exc_info=exc,
                )

        # Попали сюда - значит обязательно нужен rollback
        try:
            safe_execute(
                transaction_engine,
                lambda engine_: transaction.rollback_and_close(),
                retries=1,
                graphite_logger=graphite_logger,
                trx_state=TRX_STATE_ROLLBACK,
            )
        except DBError as exc:
            last_exc = exc
            # Если произошла ошибка при rollback-е, соединение гарантированно будет закрыто. Если
            # причина rollback-а не в ошибке commit, можем без опасений ретраить транзакцию.
            # При этом можем попасть на лок (idle-транзакция у нас висит 5 секунд), но попытаться стоит.
            log.info(
                'Database transaction attempt [%d/%d] failed to rollback',
                i + 1,
                transaction_retries,
                exc_info=exc,
            )
        if all_queries_executed or has_unrecoverable_error:
            # 1) Если не удалось выполнить commit с первого раза, не ретраим транзакцию. Состояние в БД
            # может быть неизвестно.
            # 2) Если получили ошибку целостности данных, ретраить бессмысленно.
            # 3) Если получили неизвестную ошибку при сериализации, также не ретраим.
            break
    raise last_exc


class _Router(object):
    """
    db :
    {
        master:
        {
            host: '127.0.0.1',
            database: 'socialdb',
            user: 'root',
            password: 'root',
            port: 3306,
            connection_timeout: 1,
            retries: 4,
            read_timeout: 1,
            write_timeout: 10,
            type: 'master',
            driver: 'mysql'
        }
    }
    """

    def __init__(self, configs, graphite_logger=None):
        self._configs = [self._override_default_config(c) for c in configs]

        self.engines = []
        for config in self._configs:
            dsn = URL(
                config['driver'],
                host=_format_host(config.get('host')),
                username=config.get('user'),
                password=config.get('password'),
                port=config.get('port'),
                database=config.get('database'),
            )

            connect_args = dict(settings.DB_DEFAULT_CONNECT_ARGS[dsn.drivername])
            for key in connect_args:
                if key in config:
                    connect_args[key] = config[key]

            engine_args = dict(
                connect_args=connect_args,
                pool_size=settings.DB_POOL_SIZE,
                pool_recycle=settings.DB_POOL_RECYCLE,
                echo=False,
            )
            if config['driver'] != 'sqlite':
                engine_args.update(dict(
                    poolclass=QueuePool,
                    max_overflow=settings.DB_POOL_MAX_OVERFLOW,
                ))
            engine = create_engine(dsn, **engine_args)

            engine.reconnect_retries = config['retries']
            engine.has_low_timeout = config['has_low_timeout']
            self.engines.append(engine)

        self.graphite_logger = graphite_logger

    def _override_default_config(self, user_config):
        """Строит исчерпывающую конфигурацию."""
        full_config = {
            'retries': settings.DB_RETRIES,
            'has_low_timeout': False,
        }
        full_config.update(
            settings.DB_DEFAULT_CONNECT_ARGS[user_config['driver']],
        )
        full_config.update(user_config)
        return full_config

    def get_configs(self):
        """Возвращает список конфигураций (конфигурация -- это словарь)."""
        return self._configs[:]

    def select_engine(self, with_low_timeout=False):
        matching_engines = [
            engine for engine in self.engines
            if engine.has_low_timeout == with_low_timeout
        ]
        if not matching_engines:
            raise RuntimeError('No matching engine found')
        return matching_engines[0]


class _DBManager(object):
    """
    Обертка над SQLAlchemy Engine, призванная упростить работу с конфигурацией нескольких БД.
    Для запроса умеет выдавать подходящий Engine.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.configured = False
        self._master = None
        self._slave = None

    def is_configured(self):
        """Функция нужна в тестировании для обхода mocks"""
        return self.configured

    def configure(self, db_config):
        self.reset()
        slave_configs = []
        master_configs = []
        for key, config in db_config.items():
            if config.get('type', 'master') == 'master':
                master_configs.append(config)
            else:
                slave_configs.append(config)

        if not master_configs:
            log.info('No master configs passed')
        else:
            self._master = _Router(master_configs)

        if not slave_configs:
            log.info('No slave configs passed')
        else:
            self._slave = _Router(slave_configs)

        self.configured = True

    def get_engine(self, force_master=True, with_low_timeout=False):
        """
        По умолчанию возвращает engine для похода в мастер-базу.
        Если разрешен поход в слейв-базу, при отсутствии слейв-конфигов откатываемся на мастер-базу.
        """
        router = self._master
        if not force_master and self._slave:
            router = self._slave
        if not router:
            raise RuntimeError('No matching router found')
        return router.select_engine(with_low_timeout=with_low_timeout)


_dbms = defaultdict(_DBManager)


def get_dbm(db_name):
    return _dbms[db_name]


METADATA_TO_DBM = {
    central_metadata: 'passportdbcentral',
}


def find_dbm(table):
    return get_dbm(METADATA_TO_DBM[table.metadata])


def find_sharded_dbm(table, key):
    return get_dbm(get_db_name(table.name, key))


def find_dbm_for_eav_query(eavquery):
    table = eavquery.get_table()
    if eavquery.is_sharded():
        return find_sharded_dbm(table, eavquery.uid)
    else:
        return find_dbm(table)
