# -*- encoding: utf-8 -*-
"""
Работа с базами данных Директа
"""

from __future__ import absolute_import

import os, re, random
import simplejson, yaml, time
import MySQLdb
from MySQLdb.constants import CLIENT
import settings
import sqlalchemy
from warnings import warn
from direct.tracing import Trace


DEAD_LOCK_TRIES = 3


# Включаем возврат кол-ва записей из WHERE в rowcount для UPDATE/DELETE
MYSQL_CLIENT_FLAG = CLIENT.FOUND_ROWS


def db_engine(dbname):
    """По имени базы возвращает закешированный engine"""
    if db_engine.cache_pid != os.getpid():
        # Выкидываем все имеющиеся соединения
        disconnect_all()
        db_engine.cache_pid = os.getpid()
    engine = db_engine.cache.get(dbname)
    if engine is None:
        get_db_config(dbname) # Проверяем, что такая база существует
        def connect():
            cfg = get_db_config(dbname)
            compress = bool(cfg.get('compression', False))
            if 'host' in cfg:
                con = MySQLdb.connect(
                    host=cfg['host'], port=int(cfg['port']), db=cfg['db'],
                    user=cfg['user'], passwd=cfg['pass'],
                    use_unicode=True, charset='utf8',
                    compress=compress,
                    client_flag=MYSQL_CLIENT_FLAG,
                )
            else:
                con = MySQLdb.connect(
                    unix_socket=cfg['mysql_socket'], db=cfg['db'],
                    user=cfg['user'], passwd=cfg['pass'],
                    use_unicode=True, charset='utf8',
                    compress=compress,
                    client_flag=MYSQL_CLIENT_FLAG,
                )
            return TracedConnection(con, dbname)
        engine = sqlalchemy.create_engine('mysql+mysqldb://', creator=connect)
        db_engine.cache[dbname] = engine
    return engine
db_engine.cache = {}
db_engine.cache_pid = os.getpid()


def get_db_config(dbname):
    """
    По имени базы получить dict с конфигом
    В имени может встретиться компонента ? - это означает любой child с ненулевым(или неопределённым) весом
    """
    def hash_merge(d, s):
        for k, v in s.items():
            d[k] = v
        return d
    cur_cfg = parse_db_config()
    ret = hash_merge({}, cur_cfg)
    for part in dbname.split(":"):
        childs = cur_cfg.get('CHILDS', None)
        if childs is not None and part == '?' and childs:
            cur_cfg = next(iter(sorted(childs.values(), key=lambda v: v.get('weight', 1) * random.random(), reverse=True)))
            hash_merge(ret, cur_cfg)
        elif childs is not None and part in childs:
            cur_cfg = childs[part]
            hash_merge(ret, cur_cfg)
        else:
            raise Exception("Can't find db config for %s" % dbname)

    if 'CHILDS' in cur_cfg:
        if '_' in cur_cfg['CHILDS']:
            hash_merge(ret, cur_cfg['CHILDS']['_'])
        else:
            raise Exception("Can't find db config for %s (not leaf)" % dbname)

    if 'CHILDS' in ret:
        del ret['CHILDS']

    ret['dbname'] = dbname
    if 'db' not in ret:
        ret['db'] = dbname.split(':')[0]

    if 'pass' in ret and isinstance(ret['pass'], dict):
        if 'file' in ret['pass']:
            with open(ret['pass']['file'], 'rb') as f:
                ret['pass'] = f.read().strip()
        else:
            raise Exception('incorrect pass in '+dbname+': dict, but without file')

    return ret


def parse_db_config():
    """
    Прочитать и распарсить конфиг коннектов к БД
    """
    fn = settings.DB_CONFIG_FILE
    if fn.endswith('.json'):
        return simplejson.loads(open(fn).read())['db_config']
    elif fn.endswith('.yaml'):
        return yaml.safe_load(open(fn).read())['db_config']
    else:
        raise "Incorrect db_config_file: %s" % fn


def disconnect_all():
    """Закрыть все существующие коннекты в пуле"""
    for dbname, engine in db_engine.cache.iteritems():
        cnt = engine.pool.checkedout()
        if cnt > 0:
            warn('Database %s has %d connections checked out' % (dbname, cnt))
        engine.dispose()


class TracedConnection(object):
    # Ловим всякие неожиданные setattr
    __slots__ = ('_connection', '_dbname')

    def __init__(self, connection, dbname):
        self._connection = connection
        self._dbname = dbname

    def __getattr__(self, name):
        return getattr(self._connection, name)

    def cursor(self, *args, **kwargs):
        return TracedCursor(self._connection.cursor(*args, **kwargs), self._dbname)


class TracedCursor(object):
    # Ловим всякие неожиданные setattr
    __slots__ = ('_cursor', '_dbname')

    def __init__(self, cursor, dbname):
        self._cursor = cursor
        self._dbname = dbname

    def __getattr__(self, name):
        return getattr(self._cursor, name)

    def __unicode__(self):
        return "TracedCursor({}, '{}')".format(unicode(self._cursor), self._dbname)

    def __str__(self):
        return "TracedCursor({}, '{}')".format(str(self._cursor), self._dbname)

    def execute(self, query, *args, **kwargs):
        with Trace.current().profile("db:"+query_type(query), tags=unicode(self._dbname)):
            self._cursor.execute(query, *args, **kwargs)

    def executemany(self, query, *args, **kwargs):
        with Trace.current().profile("db:"+query_type(query), tags=unicode(self._dbname)):
            self._cursor.executemany(query, *args, **kwargs)

    def fetchone(self, *args, **kwargs):
        with Trace.current().profile("db:fetch", tags=unicode(self._dbname)):
            return self._cursor.fetchone(*args, **kwargs)

    def fetchmany(self, *args, **kwargs):
        with Trace.current().profile("db:fetch", tags=unicode(self._dbname)):
            return self._cursor.fetchmany(*args, **kwargs)

    def fetchall(self, *args, **kwargs):
        with Trace.current().profile("db:fetch", tags=unicode(self._dbname)):
            return self._cursor.fetchall(*args, **kwargs)


def now(engine):
    """
    Получить текущее время в указанной базе данных
    """
    return engine.scalar("SELECT now()")


def table_exists(engine, table_name):
    """
    Проверить существование таблицы, возвращает Bool
    """
    cnt = engine.scalar("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema=database() AND table_name = %s", (table_name,))
    return cnt > 0


def exec_sql(engine, *args, **kwargs):
    # в случае дедлока - повторяем DEAD_LOCK_TRIES раз
    try_num = 0
    while(True):
        try_num += 1
        try:
            return engine.execute(*args, **kwargs)
        except MySQLdb.OperationalError, (code, msg):
            if code not in (1213, 1205) or try_num >= DEAD_LOCK_TRIES:
                raise MySQLdb.OperationalError(code, "try %d: %s" % (try_num, msg))
            time.sleep(try_num * 3)


lock_partial_matcher = re.compile(r'(get_lock|release_lock|is_free_lock|is_used_lock)\s*\(',  re.IGNORECASE)
def query_type(query):
    splitted_query = query.split(None, 1)
    query_type = splitted_query[0].lower()

    if query_type in ('set', 'begin'):
        return 'read'
    elif query_type == 'select':
        query_tail = splitted_query[1]
        if lock_partial_matcher.match(query_tail):
            return 'lock'
        else:
            return 'read'
    elif query_type in ['delete', 'update', 'insert', 'create', 'drop', 'replace']:
        return 'write'
    else:
        return 'unknown'


#if __name__ == '__main__':
#    import direct.tables.profile_stats as p
#    print p.Base.metadata.create_all()
