# coding: utf-8
import contextlib
import logging
import random
import os
import socket
import time
import threading
from functools import wraps
from itertools import chain, cycle

import six
import sqlalchemy as sal
from sqlalchemy import orm
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlalchemy.exc import TimeoutError, OperationalError

from .xml import DashElementMaker, TYPE_SERIALIZERS

from .config import get_config


def init_random_generator():
    """Seed the global ``random`` generator from the host name and the clock.

    Mixing the host name hash with the current second makes it likely that
    processes on different hosts (or started at different times) produce
    different random sequences.
    """
    host_component = hash(socket.gethostname())
    time_component = int(time.time())
    random.seed(host_component + time_component)


class WrongSessionType(Exception):
    """Raised when a session of the wrong kind is supplied.

    For example, a read-only session passed to a function that is decorated
    to require a write session.
    """


# Ids of database parts (keys of DbMgr.parts).
CORE_DB = "core"
VALIDATION_DB = "validation"
# Connection-pool role; overridable through the POOL_ROLE environment variable.
POOL_ROLE = os.environ.get("POOL_ROLE", "tasks")
# Queries running longer than this many seconds are reported by the profiler hook.
LONG_QUERY_THRESHOLD = 0.5
# Pooled connections older than this many seconds are recycled (renewed).
DB_POOL_RECYCLE = 3600


@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """Engine-wide listener: record when the cursor execution started.

    The timestamp is stored on the execution context so that the
    ``after_cursor_execute`` listener can compute the query duration.
    """
    # SQLAlchemy documents ``context`` as possibly None. There is nowhere to
    # store the timestamp then, and a profiling hook must not fail the query.
    if context is not None:
        context._query_start_time = time.time()


@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """Engine-wide listener: log statements slower than LONG_QUERY_THRESHOLD.

    Slow statements are written at WARNING level to the ``db_profile`` logger,
    with newlines flattened, together with their parameters and duration.
    """
    start = getattr(context, '_query_start_time', None)
    if start is None:
        # No start timestamp was recorded (e.g. context is None). Skip
        # profiling rather than raising and failing the query itself.
        return
    total = time.time() - start
    if total > LONG_QUERY_THRESHOLD:
        db_profiler_logger = logging.getLogger("db_profile")
        # Lazy %-args: formatting only happens if the record is emitted.
        db_profiler_logger.warning(
            'Too long execution of query "%s" with params %r:  %.3f sec',
            statement.replace("\n", " "), parameters, total)


class TestSessionMaker(object):
    """Session factory used in unit-test mode (see DbPool.set_unittest_mode).

    Wraps a real session maker and always hands out one shared, patched
    session. On that session, ``commit()`` only flushes, and ``begin()``
    reuses the current transaction. Transactions started through the patched
    ``begin()`` get no-op commit/close/rollback plus a ``real_rollback()``.
    NOTE(review): the intent appears to be keeping test changes uncommitted
    so they can be rolled back; confirm against the test harness.
    """

    def __init__(self, real_maker):
        # Shared session, created lazily on the first __call__.
        self._session = None
        # The wrapped factory (an orm.sessionmaker when installed by DbPool).
        self._get_session = real_maker

    def __call__(self, *args, **kwargs):
        """Return the shared patched session, creating it on the first call.

        Arguments are forwarded to the real maker only on the first call;
        later calls ignore them and return the cached session.
        """
        if self._session is None:

            def replace(instance, name, new_method=None):
                # Save the original attribute as ``_original_<name>`` and
                # install ``new_method`` (or a do-nothing function) instead.
                def mock(*args, **kwargs):
                    # NOTE(review): because of the ``%`` below this string is
                    # an expression statement, not a docstring, so
                    # mock.__doc__ stays None.
                    """Mock object, replacement for real %s""" % name
                    pass
                setattr(instance, '_original_' + name, getattr(instance, name))
                setattr(instance, name, new_method or mock)

            def wrap_transaction(t):
                # Turn commit/close/rollback of transaction ``t`` into no-ops.
                replace(t, 'commit')
                replace(t, 'close')
                replace(t, 'rollback')

                def real_rollback():
                    # Escape hatch: really roll back and close ``t`` if it is
                    # still active, using the saved original methods.
                    if t.is_active:
                        t._original_rollback()
                        t._original_close()
                t.real_rollback = real_rollback
                return t

            def wrap_session(s):
                def begin(*args, **kwargs):
                    # Start (and wrap) a new transaction only when the session
                    # has none; otherwise hand back the current one.
                    if s.transaction is None:
                        return wrap_transaction(
                            s._original_begin(*args, **kwargs))
                    return s.transaction
                replace(s, 'begin', begin)

                def commit():
                    # Session-level commit only flushes pending changes.
                    s.flush()
                replace(s, 'commit', commit)

                def real_begin():
                    # Forget the current transaction so the patched begin()
                    # above starts a fresh wrapped one.
                    s.transaction = None
                    return s.begin()
                s.real_begin = real_begin
                return s
            s = self._get_session(*args, **kwargs)
            self._session = wrap_session(s)
        return self._session


class ConnectionChecker(object):
    """Checks whether the master and the replicas of a DbPool are reachable
    and refreshes the pool's availability information."""

    def __init__(self, dbpool, ping_period, fail_ping_period):
        """
        @param dbpool the DbPool whose master and slaves are checked
        @param ping_period seconds between checks while every database is available
        @param fail_ping_period seconds between checks while some database is unavailable
        """
        self.dbpool = dbpool
        self.ping_period = ping_period
        self.fail_ping_period = fail_ping_period

    def update_availability_info(self):
        """Check the master and every slave, then rebuild the pool's slave chooser.

        Logs an error to the ``db_pinger`` logger if the master is unavailable.
        """
        self.check_availability(self.dbpool.master)
        for slave in self.dbpool.slaves:
            self.check_availability(slave)
        self.dbpool.update_slaves_choosing_strategy()
        if not self.dbpool.master.available:
            log = logging.getLogger('db_pinger')
            log.error("Master db %s is not available", self.dbpool.master)

    def check_availability(self, dbinfo):
        """Probe *dbinfo* and update its ``available`` flag.

        The probe is a plain TCP connect to host:port (5 second timeout),
        followed by ``SELECT 1`` through a fresh session. Socket errors,
        SQLAlchemy TimeoutError and OperationalError mark the database
        unavailable. The first failure is logged with a traceback; repeated
        failures are logged as warnings.
        """
        log = logging.getLogger('db_pinger')
        try:
            # Timeout is in seconds; the socket only tests reachability.
            probe = socket.create_connection((dbinfo.host, dbinfo.port), timeout=5)
            probe.close()

            with contextlib.closing(dbinfo.get_session()) as session:
                session.execute("SELECT 1").scalar()
        except (socket.error, TimeoutError, OperationalError):
            if dbinfo.available:
                log.exception("Failed database connection %s", dbinfo)
            else:
                log.warning("Database %s is not available", dbinfo)
            dbinfo.available = False
        else:
            if not dbinfo.available:
                log.info('Restored connection to %s', dbinfo)
                dbinfo.available = True

    def calc_sleep_period(self):
        """Return the number of seconds until the next availability check.

        Uses the shorter ``fail_ping_period`` while any database is down.
        """
        all_up = self.dbpool.master.available and \
            all(d.available for d in self.dbpool.slaves)
        return self.ping_period if all_up else self.fail_ping_period


class ThreadConnectionChecker(ConnectionChecker, threading.Thread):
    """Monitors database availability from a separate daemon thread.
    Intended for use inside a long-running server process."""

    def __init__(self, dbpool, ping_period, fail_ping_period):
        threading.Thread.__init__(self)
        self.daemon = True
        super(ThreadConnectionChecker, self).__init__(dbpool, ping_period,
                                                      fail_ping_period)
        self.lock = threading.Lock()
        self.condition = threading.Condition(self.lock)
        # Cleared by stop(). It is deliberately not reset in run(), so a stop()
        # issued before the thread gets scheduled is still honoured.
        self.can_run = True

    def run(self):
        """Check availability, then sleep until the next period, stop() or wake_up()."""
        log = logging.getLogger('db_pinger')
        log.debug("Starting ThreadConnectionChecker")
        while self.can_run:
            try:
                self.update_availability_info()
            except Exception:
                log.exception("Error happened till db availability checking")
            # Always sleep, even after an error; otherwise a persistent failure
            # turns this loop into a busy loop. can_run is re-checked under the
            # lock so a stop() issued while checking is not missed.
            with self.condition:
                if self.can_run:
                    self.condition.wait(self.calc_sleep_period())

    def stop(self):
        """Stop the thread's loop and wake it if it is sleeping."""
        with self.condition:
            self.can_run = False
            self.condition.notify()

    def wake_up(self):
        """Wake the thread from a possible sleep so it checks availability now."""
        # Condition.notify() requires the underlying lock to be held.
        with self.condition:
            self.condition.notify()


class DbInfo(object):
    """Connection parameters, availability flag, SQLAlchemy engine and session
    factory of a single PostgreSQL database."""

    # Human-readable form used by __str__ and get_status. The password is masked
    # so the result is safe to write to logs and status output.
    _DISPLAY_TEMPLATE = "postgresql://%(user)s:***@%(host)s:%(port)d/%(db)s"

    def __init__(self, host, port, user, passwd, db, engine, session_maker):
        self.host = host
        self.port = port
        self.user = user
        self.passwd = passwd
        self.db = db
        # Maintained by ConnectionChecker and DbPool.session_failed.
        self.available = True
        self.engine = engine
        self.session_maker = session_maker

    @classmethod
    def from_cfg(cls, params, **common_options):
        """Create a DbInfo, with its engine and session maker, from a config entry.

        @param params connection config with host, port, user, password,
                      dbname and sslmode attributes
        @param common_options extra keyword arguments for sqlalchemy.create_engine
        """
        host = params.host
        port = params.port
        user = params.user
        passwd = params.password
        db = params.dbname
        query = "?sslmode=" + params.sslmode if params.sslmode else ""
        con_str = 'postgresql://%s:%s@%s:%d/%s' % (user, passwd, host, port, db) + query
        # Log a masked URL: the real connection string contains the password.
        safe_str = (cls._DISPLAY_TEMPLATE %
                    dict(user=user, host=host, port=port, db=db)) + query
        logging.debug("registering database %s with options %r", safe_str, common_options)
        engine = sal.create_engine(con_str, **common_options)
        session_maker = orm.sessionmaker(bind=engine, autocommit=False)
        return cls(host, port, user, passwd, db, engine, session_maker)

    def __str__(self):
        return self._DISPLAY_TEMPLATE % self.__dict__

    def get_session(self):
        """Return a new session from this database's session maker."""
        return self.session_maker()

    def get_status(self):
        """Return a string with the status of this database's connection pool."""
        return self.__str__() + ": " + self.engine.pool.status()


class DbPool(object):
    """Manages connections to a group of databases (for example, the main or
    the social one): one master for writes plus read-only slaves."""

    def __init__(self, master, slaves, db_params, role_params, **engine_options):
        """
        @param master DbInfo of the master (read-write) database
        @param slaves list of DbInfo of the replica (read-only) databases
        @param db_params database config section (write_conn, pool_params, ...)
        @param role_params pool parameters of the selected role
        @param engine_options engine options; stored on the instance as given
        """
        self.master = master
        self.slaves = slaves
        self.connection_checker = None
        self.config_controller = None

        # Assume write_conn is always present, take basic conn params from it
        self.conn_params = db_params.write_conn.template()
        self.role_params = role_params
        self.ping_period = db_params.pool_params.ping_period
        self.fail_ping_period = db_params.pool_params.fail_ping_period

        self.engine_options = engine_options
        self.mutex = threading.RLock()

        # A falsy ping_period disables background availability checking.
        if self.ping_period:
            self.start_connection_checker()
        self.update_slaves_choosing_strategy()

    @classmethod
    def from_cfg(cls, db_params, role, **engine_options):
        """Build a pool for *role* from a database config section.

        Slave DbInfos are created only when the role allows read connections
        (read_size + read_overflow > 0).
        @raise Exception if the role has no pool parameters in the config
        """
        if role not in db_params.pool_params.roles:
            raise Exception("Could not find params for role '%s'" % role)
        role_params = db_params.pool_params.roles[role]

        master = DbInfo.from_cfg(db_params.write_conn,
                                 pool_size=role_params.write_size,
                                 max_overflow=role_params.write_overflow,
                                 **engine_options)
        slaves = []
        if role_params.read_size + role_params.read_overflow > 0:
            slaves = [DbInfo.from_cfg(conn,
                                      pool_size=role_params.read_size,
                                      max_overflow=role_params.read_overflow,
                                      **engine_options)
                      for conn in db_params.read_conns]

        return cls(master, slaves, db_params, role_params, **engine_options)

    def start_connection_checker(self):
        """Start a background ThreadConnectionChecker for this pool.

        A checker started earlier is stopped first. Otherwise its daemon
        thread would keep running with no reference left to stop it.
        """
        self.stop_connection_checker()
        self.connection_checker = ThreadConnectionChecker(
            self,
            self.ping_period,
            self.fail_ping_period)
        self.connection_checker.start()

    def stop_connection_checker(self):
        """Stop the background checker, if any, and wait for its thread to exit."""
        checker = self.connection_checker
        if checker is not None:
            checker.stop()
            checker.join()
            self.connection_checker = None

    def update_slaves_choosing_strategy(self):
        """Rebuild the round-robin chooser used to pick a slave database.

        Only slaves currently marked available are used, in random order.
        If no slave is available, the master is used instead.
        """
        with self.mutex:
            candidates = [db for db in self.slaves if db.available]
            if not candidates:
                candidates = [self.master]
            random.shuffle(candidates)
            self.slaves_chooser = cycle(candidates)

    def session_failed(self, session):
        """Called when an error occurs while working with the database.

        The database the session is bound to is marked unavailable. For a
        slave, the chooser is rebuilt so the slave stops being selected.
        """
        with self.mutex:
            engine = session.bind
            if self.master.engine == engine:
                self.master.available = False
            else:
                for db in self.slaves:
                    if db.engine == engine:
                        db.available = False
                self.update_slaves_choosing_strategy()

    def get_master_session(self):
        """Return a new session to the master database."""
        with self.mutex:
            return self.master.get_session()

    def set_unittest_mode(self):
        """Route every session (master and slaves) through one TestSessionMaker."""
        with self.mutex:
            self.master.session_maker = TestSessionMaker(self.master.session_maker)
            for slave in self.slaves:
                slave.session_maker = self.master.session_maker

    def get_slave_session(self):
        """Return a new session to the next database picked by the slave chooser."""
        with self.mutex:
            return next(self.slaves_chooser).get_session()

    def get_master_engine(self):
        """Return the master database's engine."""
        with self.mutex:
            return self.master.engine

    def get_slave_engine(self):
        """Return the engine of the next database picked by the slave chooser."""
        with self.mutex:
            return next(self.slaves_chooser).engine

    def is_write_available(self):
        """Return whether the master is currently marked available."""
        with self.mutex:
            return self.master.available

    def get_hosts(self):
        """Return the hosts of the master followed by the slaves."""
        with self.mutex:
            return [db.host for db in chain((self.master,), self.slaves)]

    def get_status(self):
        """Return the connection-pool status of every database, one per line."""
        with self.mutex:
            return '\n'.join(
                db.get_status() for db in chain([self.master], self.slaves))


class DbMgr(object):
    """Holds one DbPool per database part. The pools are created lazily, on
    the first call of a method decorated with check_initialized."""

    def __init__(self, role, part_ids):
        """
        @param role role name (selects the connection pool parameters)
        @param part_ids ids of the configured databases (parts) to create pools for
        """
        self.mutex = threading.Lock()
        self.initialized = False
        # Not used by this class; kept for code that may read the attribute.
        self.connection_checker = None
        self.role = role
        self.part_ids = set(part_ids)
        self.unittest_mode = False
        self.parts = dict()

    def initialize(self):
        """Create a DbPool for every configured database whose id is in part_ids.

        The new mapping is assigned to ``self.parts`` only after all pools are
        built, so a failure does not leave a partially filled mapping behind.
        """
        logging.debug("initializing db")
        init_random_generator()
        parts = dict()
        for db_id, db_params in six.iteritems(get_config().databases):
            if db_id in self.part_ids:
                parts[db_id] = DbPool.from_cfg(db_params, self.role,
                                               pool_recycle=DB_POOL_RECYCLE)
        self.parts = parts
        self.initialized = True

    def check_initialized(func):
        """Method decorator: run initialize() once, under self.mutex, before func."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            self = args[0]
            # Double-checked: the lock is only taken until initialization is done.
            if not self.initialized:
                with self.mutex:
                    if not self.initialized:
                        self.initialize()
            return func(*args, **kwargs)
        return wrapper

    @check_initialized
    def set_unittest_mode(self):
        """Switch every pool to unit-test sessions (see DbPool.set_unittest_mode)."""
        self.unittest_mode = True
        for part in self.parts.values():
            part.set_unittest_mode()

    @check_initialized
    def get_master_session(self, part):
        """Return a new master session for *part*."""
        return self.parts[part].get_master_session()

    @check_initialized
    def get_slave_session(self, part):
        """Return a new slave (read-only) session for *part*."""
        return self.parts[part].get_slave_session()

    @check_initialized
    def is_write_available(self):
        """Return whether the master of the core database is marked available."""
        return self.parts[CORE_DB].is_write_available()

    def session_failed(self, part, session):
        """Report a failed session of *part* to its pool."""
        self.parts[part].session_failed(session)

    def get_status(self):
        """Return the status of every pool; empty before initialization."""
        return '\n'.join('%s:\n%s' % (name, part.get_status())
                         for name, part in six.iteritems(self.parts))


# Process-wide DbMgr instance, created by init_pool().
_pool = None


def init_pool(db_ids, role=None):
    """Create the process-wide DbMgr unless one already exists.

    @param db_ids ids of the databases (parts) to connect to
    @param role pool role. Defaults to the POOL_ROLE environment variable or,
                when that is unset, to the module-level POOL_ROLE default.
    Later calls are no-ops, even when called with different arguments.
    """
    global _pool
    if role is None:
        role = os.environ.get("POOL_ROLE", POOL_ROLE)
    if _pool is None:
        _pool = DbMgr(role, db_ids)


def get_read_session(part):
    """Return a ``contextlib.closing`` wrapper around a new slave session.

    The session is tagged with ``read_only = True`` and ``part`` so the
    session decorators can recognise it when it is passed down a call chain.
    """
    slave_session = _pool.get_slave_session(part)
    slave_session.read_only = True
    slave_session.part = part
    return contextlib.closing(slave_session)


@contextlib.contextmanager
def get_write_session(part):
    """Yield a master session for *part* tagged ``read_only = False``.

    On normal exit, if the session's transaction is still active, it is
    committed (or only flushed while the pool is in unittest mode). On any
    Exception, an active transaction is rolled back and the exception is
    re-raised. The session is closed in every case.
    """
    master_session = _pool.get_master_session(part)
    master_session.read_only = False
    master_session.part = part
    tx = master_session.transaction
    try:
        yield master_session
        if tx.is_active:
            finish = master_session.flush if _pool.unittest_mode else tx.commit
            finish()
    except Exception:
        if tx.is_active:
            tx.rollback()
        raise
    finally:
        master_session.close()


# DECORATORS


def get_token(session):
    """Return the scalar result of ``public.get_token()`` on *session*'s database."""
    result = session.execute('SELECT public.get_token()')
    return result.scalar()


def check_token(session, token):
    """Return the scalar result of ``public.check_token(token)`` on *session*'s database.

    Only ``token`` is bound as a query parameter. Passing ``locals()`` would
    also send the session object itself along as a bind parameter.
    """
    return session.execute('SELECT public.check_token(:token)',
                           {'token': token}).scalar()


def _get_session(part, *args, **kwargs):
    """Return the first Session among *args* and *kwargs* values tagged with *part*.

    Sessions without a ``part`` attribute (e.g. ones not obtained through
    get_read_session / get_write_session) are skipped instead of raising
    AttributeError. Returns None when no matching session is found.
    """
    for arg in chain(args, kwargs.values()):
        if isinstance(arg, orm.session.Session) and \
                hasattr(arg, 'part') and arg.part == part:
            return arg
    return None


def read_session(part):
    """Decorator: supply a read-only session for *part* as the ``session`` kwarg.

    A session for the same part that the caller already passes (positionally
    or by keyword) is reused as-is. Otherwise a slave session is opened for
    the duration of the call. SQLAlchemy TimeoutError / OperationalError
    raised by the call mark that database as failed in the pool and are then
    re-raised.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if _get_session(part, *args, **kwargs) is not None:
                return func(*args, **kwargs)
            with get_read_session(part) as slave_session:
                kwargs['session'] = slave_session
                try:
                    return func(*args, **kwargs)
                except (TimeoutError, OperationalError):
                    _pool.session_failed(part, slave_session)
                    raise
        return wrapper
    return decorator


def read_session_sync(part, token_getter):
    """Pass the function a read-only session to a database that is up to date.

    The token value is obtained by calling ``token_getter()``. An empty token
    means any database state is acceptable. Sessions are tried in this order:
    1. a session for *part* already passed by the caller, if it satisfies the
       token (or no token is required); it is then used as-is;
    2. a new replica session, if it satisfies the token (or no token is required);
    3. otherwise a new master (write) session.
    SQLAlchemy TimeoutError / OperationalError raised while using the caller's
    or the replica session mark that database as failed, then are re-raised.
    """
    def impl(func):
        @wraps(func)
        def wrap(*args, **kwargs):
            token = token_getter()
            session = _get_session(part, *args, **kwargs)
            if session is not None:
                try:
                    # With no token required, keep the caller's session.
                    # Replacing it would drop the caller's transaction, or
                    # pass ``session`` twice when it was given positionally.
                    if not token or check_token(session, token):
                        return func(*args, **kwargs)
                except (TimeoutError, OperationalError):
                    _pool.session_failed(part, session)
                    raise
            with get_read_session(part) as session:
                try:
                    if not token or check_token(session, token):
                        kwargs['session'] = session
                        return func(*args, **kwargs)
                except (TimeoutError, OperationalError):
                    _pool.session_failed(part, session)
                    raise
            # The replica is behind the token: fall back to the master.
            with get_write_session(part) as session:
                kwargs['session'] = session
                return func(*args, **kwargs)
        return wrap
    return impl


def write_session(part):
    """Decorator: supply a master (write) session for *part* as the ``session`` kwarg.

    A session for the same part that the caller already passes is reused,
    but only if it is not read-only; otherwise WrongSessionType is raised.
    Without such a session, a new one is opened via get_write_session.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            existing = _get_session(part, *args, **kwargs)
            if existing is None:
                with get_write_session(part) as master_session:
                    kwargs['session'] = master_session
                    return func(*args, **kwargs)
            if getattr(existing, 'read_only', True):
                raise WrongSessionType('session is read-only')
            return func(*args, **kwargs)
        return wrapper
    return decorator


def acquire_token(func):
    """Decorator: append a ``token`` element to the element returned by *func*.

    After *func* runs, the value of ``public.get_token()`` is fetched through
    a core-database session and appended as a ``token`` element built by
    DashElementMaker (namespace http://maps.yandex.ru/wikimap/1.x). The
    caller's core session is used if one is passed; otherwise a new write
    session is opened.
    """

    def _acquire_token(_session, *args, **kwargs):
        res = func(*args, **kwargs)
        EMns = DashElementMaker(namespace="http://maps.yandex.ru/wikimap/1.x",
                                nsmap={'wm': "http://maps.yandex.ru/wikimap/1.x"},
                                typemap=TYPE_SERIALIZERS)
        # func is expected to return an ElementTree element
        res.append(EMns.token(get_token(_session)))
        return res

    @wraps(func)
    def wrap(*args, **kwargs):
        # CORE_DB is the core part id; DbMgr has no ``CORE`` attribute.
        session = _get_session(CORE_DB, *args, **kwargs)
        if session is not None:
            return _acquire_token(session, *args, **kwargs)
        with get_write_session(CORE_DB) as session:
            return _acquire_token(session, *args, **kwargs)
    return wrap


@contextlib.contextmanager
def constraint_exclusion(session):
    """Turn on the query-planner option ``constraint_exclusion`` for a block.

    The option is switched on before yielding *session* and switched off
    afterwards. Exceptions raised inside the block are logged and propagate
    to the caller; they are not swallowed. Otherwise surrounding transaction
    handling would treat a failed block as a success.
    """
    session.execute('SET constraint_exclusion = on')
    try:
        yield session
    except Exception:
        logging.exception('error in constratin exclusion block')
        raise
    finally:
        session.execute('SET constraint_exclusion = off')


def stop_connection_checker():
    """Stop the availability-checker threads of every database part.

    DbMgr itself has no stop method, so each DbPool is stopped directly.
    Does nothing if init_pool() has not been called yet.
    """
    if _pool is None:
        return
    for part in _pool.parts.values():
        part.stop_connection_checker()


def set_unittest_mode():
    """Switch the process-wide DbMgr (created by init_pool) into unit-test mode."""
    manager = _pool
    manager.set_unittest_mode()
