# -*- coding: utf-8 -*-
from collections import defaultdict
import logging
import random
from time import (
    sleep,
    time,
)

from django.conf import settings
import MySQLdb
from passport.backend.oauth.core.common.utils import (
    escape,
    first_or_none,
)
from passport.backend.oauth.core.db.eav.errors import (
    DBIntegrityError,
    DBTemporaryError,
)
from passport.backend.oauth.core.logs.graphite import (
    FAILED_RESPONSE_CODE,
    GraphiteLogger,
    SUCCESS_RESPONSE_CODE,
)
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.url import URL
import sqlalchemy.exc
from sqlalchemy.pool import QueuePool


log = logging.getLogger('db.eav')


# Low-level driver / SQLAlchemy errors that ``_DBManager._execute`` treats as
# transient: the call is retried and, if every attempt fails, the error is
# re-raised as DBTemporaryError.
RAW_DB_EXCEPTIONS = (
    sqlalchemy.exc.DatabaseError,
    sqlalchemy.exc.InterfaceError,
    MySQLdb.InterfaceError,
    MySQLdb.DatabaseError,
    sqlalchemy.exc.ResourceClosedError,  # raised because of a bug in SQLAlchemy
)
# Constraint violations: ``_DBManager._execute`` does not retry these and
# re-raises them as DBIntegrityError. They are caught before RAW_DB_EXCEPTIONS
# there, because IntegrityError is itself a subclass of DatabaseError.
DB_INTEGRITY_EXCEPTIONS = (
    sqlalchemy.exc.IntegrityError,
    MySQLdb.IntegrityError,
)


def query_to_string(query, engine):
    """
    Render ``query`` as an escaped ``"<SQL> <params>"`` string for debug logs.

    :param query: object exposing ``to_sql()`` that returns a SQLAlchemy clause.
    :param engine: engine whose dialect is used to compile the clause.
    """
    # TODO: switch to compile_kwargs={'literal_binds': True}
    compiled = query.to_sql().compile(dialect=engine.dialect)
    rendered = '%s %s' % (compiled, compiled.params)
    return escape(rendered)


class _Router(object):
    """
    db :
    {
        master:
        {
            host: '127.0.0.1',
            database: 'socialdb',
            user: 'root',
            password: 'root',
            port: 3306,
            connection_timeout: 1,
            retries: 4,
            retry_timeout: 0.5,
            read_timeout: 1,
            write_timeout: 10,
            type: 'master',
            driver: 'mysql'
        }
    }
    """

    def __init__(self, configs):
        self.engines = []
        for config in configs:
            dsn = URL(
                config['driver'],
                host=config.get('host'),
                username=config.get('user'),
                password=config.get('password'),
                port=config.get('port'),
                database=config.get('database'),
            )

            connect_args = dict(settings.DB_DEFAULT_CONNECT_ARGS[dsn.drivername])
            for key in connect_args:
                if key in config:
                    connect_args[key] = config[key]

            engine = create_engine(
                dsn,
                connect_args=connect_args,
                poolclass=QueuePool,
                pool_size=settings.DB_POOL_SIZE,
                pool_recycle=settings.DB_POOL_RECYCLE,
                max_overflow=settings.DB_POOL_MAX_OVERFLOW,
                pool_pre_ping=settings.DB_POOL_PRE_PING,
            )

            engine.reconnect_retries = config.get('retries', settings.DB_RETRIES)
            engine.reconnect_timeout = config.get('retry_timeout', settings.DB_RETRY_TIMEOUT)
            self.engines.append(engine)

    def select_engine(self):
        """
        Выбор engine для запроса.
        Варианты: по приоритету, дефолтный, рандомный и тд
        """
        raise NotImplementedError()  # pragma: no cover


class _SlaveRouter(_Router):
    """Router that spreads queries by picking a random configured engine."""

    def select_engine(self):
        candidates = self.engines
        return random.choice(candidates)


class _MasterRouter(_Router):
    """Router that always hands out the first configured engine (via ``first_or_none``)."""

    def select_engine(self):
        candidates = self.engines
        return first_or_none(candidates)


class _DBManager(object):
    def __init__(self):
        self._master = None
        self._slave = None

    def configure(self, db_config):
        slave_configs = []
        master_configs = []
        for key, config in db_config.items():
            if config.get('type', 'master') == 'master':
                master_configs.append(config)
            else:
                slave_configs.append(config)

        if not master_configs:
            log.info('No master configs passed')
        else:
            self._master = _MasterRouter(master_configs)

        if not slave_configs:
            log.info('No slave configs passed')
        else:
            self._slave = _SlaveRouter(slave_configs)

    def get_all_engines(self):
        return self._master.engines + self._slave.engines

    def select_engine(self, executable):
        """Выбирает engine (мастер или слейв) для исполнений executable"""
        is_master = True
        if self._slave is not None:
            if executable is None:
                is_master = False
            elif not executable.force_master and executable.is_selectable:
                is_master = False
        router = self._master if is_master else self._slave
        return is_master, router.select_engine()

    @staticmethod
    def _execute(engine, func, is_master, retries=None):
        """
        Безопасное выполнение функции с учетом ретраев и таймаутов.
        Функция func принимает один аргумент: engine.
        """
        host = engine.url.host
        graphite_logger = GraphiteLogger(
            service='db',
            srv_hostname=host,
            srv_ipaddress=host,
            is_master=is_master,
        )

        retries = retries or engine.reconnect_retries
        error = None
        for i in range(retries):
            retries_left = retries - i - 1
            network_error = False
            response_code = SUCCESS_RESPONSE_CODE
            start_time = time()
            try:
                return func(engine)
            except DB_INTEGRITY_EXCEPTIONS as e:
                error = e
                network_error = True
                response_code = FAILED_RESPONSE_CODE
                retries_left = 0
                break
            except RAW_DB_EXCEPTIONS as e:
                error = e
                network_error = True
                response_code = FAILED_RESPONSE_CODE
                if retries_left:
                    sleep(engine.reconnect_timeout)
            finally:
                graphite_logger.log(
                    duration=time() - start_time,
                    response=response_code,
                    network_error=network_error,
                    retries_left=retries_left,
                )

        if isinstance(error, DB_INTEGRITY_EXCEPTIONS):
            log.warning('Database integrity error (host %s): %s' % (host, error))
            raise DBIntegrityError(error, host=host)
        else:
            log.warning('Database error (host %s): %s' % (host, error))
            raise DBTemporaryError(error, host=host)

    def transaction(self, transaction):
        """
        Безопасное выполнение транзакции с учетом ретраев и таймаутов.
        """
        def func(engine):
            for i, query in enumerate(transaction.queries, start=1):
                log.debug('Queryset[%d]: %s', i, query_to_string(query, engine))
            with engine.begin() as connection:
                return [
                    connection.execute(query.to_sql())
                    for query in transaction.queries
                ]

        is_master, engine = self.select_engine(transaction)
        return self._execute(
            engine=engine,
            func=func,
            is_master=is_master,
            retries=transaction.retries,
        )

    def execute(self, query):
        """
        Безопасное выполнение запроса с учетом ретраев и таймаутов.
        """
        def func(engine):
            log.debug('Query: %s', query_to_string(query, engine))
            return engine.execute(query.to_sql().execution_options(autocommit=True))

        is_master, engine = self.select_engine(query)
        return self._execute(
            engine,
            func,
            is_master=is_master,
            retries=query.retries,
        )

    def ping(self, retries=None):
        """
        Пинг слейва (с ретраями)
        """
        def func(engine):
            engine.connect()

        is_master, engine = self.select_engine(executable=None)
        try:
            self._execute(
                engine,
                func,
                is_master=is_master,
                retries=retries,
            )
        except DBTemporaryError:
            raise DBTemporaryError(
                'Database is unavailable: "%s" (%s)' % (
                    engine.url.database,
                    engine.url.host,
                ),
            )


# Registry of lazily created managers, keyed by database name.
_dbms = defaultdict(_DBManager)


def get_dbm(db_name):
    """
    Return the ``_DBManager`` registered under ``db_name``.

    A fresh, unconfigured manager is created on first access; call its
    ``configure`` method to attach engines.
    """
    manager = _dbms[db_name]
    return manager
