"""
The content of the file is copied from <skynet_root>/copier2/server/db.py.
"""

import re
import py
import time
import apsw

from kernel.util.errors import formatException


class Database(object):
    """Wrapper around an ``apsw.Connection`` configured from ``ctx.cfg.database``.

    Provides opening (with a single wipe-and-recreate retry), migrations driven
    by ``pragma user_version``, query helpers with SQL logging, and expansion
    of ``??`` placeholders into one ``?`` per element of a list/tuple parameter.
    """

    def __init__(self, ctx):
        self.ctx = ctx
        self.cfg = ctx.cfg.database
        self.log = ctx.log.getChild('db')
        self.logSQL = self.log.getChild('sql')
        self._db = None               # apsw.Connection once _open() succeeds
        self._opened = False
        self._initialized = False     # set by init() once migrations succeed
        self._limitVariables = None   # connection's SQLITE_LIMIT_VARIABLE_NUMBER
        self.lastRowId = None         # last inserted rowid, refreshed by _execute()
        self.changedRows = None       # changed-row count, refreshed by _execute()

    def open(self):
        """Open the main database (``cfg.path.main``) and attach ``cfg.attach``."""
        assert not self._opened

        path = self.cfg.path.main
        assert isinstance(path, basestring)

        return self._open(path, self.cfg.attach)

    def _open(self, path, attach=(), noretry=False):
        """Connect to *path*, apply the configured pragmas and attach databases.

        :param path: database file for the ``main`` schema.
        :param attach: names of attributes of ``cfg.path`` to ATTACH under that name.
        :param noretry: when false, a failed open closes the connection, removes
            and recreates the file at *path*, and retries once with ``noretry=True``.
        :raises apsw.CantOpenError: if the open fails and *noretry* is true.
        """
        try:
            self._db = apsw.Connection(path)
            self._db.setbusytimeout(self.cfg.lock_timeout * 1000)
            if self.cfg.quick_check and self.queryCol('pragma quick_check') != 'ok':
                raise apsw.CantOpenError('Quick check failed')
            if self.cfg.integrity_check and self.queryCol('pragma integrity_check') != 'ok':
                raise apsw.CantOpenError('Integrity check failed')
            self.query('pragma synchronous = %s' % (self.cfg.synchronous, ))
            self.query('pragma auto_vacuum = %s' % (self.cfg.auto_vacuum, ))
            self.query('pragma foreign_keys = %s' % (self.cfg.foreign_keys, ))
            self.query('pragma page_size = %s' % (self.cfg.page_size, ))
            self.query('pragma wal_autocheckpoint = %s' % (self.cfg.wal_autocheckpoint, ))

            self._limitVariables = self._db.limit(apsw.SQLITE_LIMIT_VARIABLE_NUMBER)

            for db in ['main'] + list(attach):
                if db != 'main':
                    self.query('ATTACH DATABASE ? AS ?', (getattr(self.cfg.path, db), db))
                self.query('pragma %s.journal_mode = %s' % (db, self.cfg.journal_mode, ))
        except apsw.Error as ex:
            if noretry:
                self.log.error('Failed to open database ({0}): {1}'.format(path, ex))
                # Format the traceback before any other exception handling runs.
                details = formatException()
                self._closeConnection()
                raise apsw.CantOpenError('Unable to open database file ({0}): {1}'.format(
                    path, details
                ))
            self.log.error('Failed to open database ({0}): {1}. Will cleanup and try again'.format(path, ex))
            # Release the half-open connection before the file is removed below.
            self._closeConnection()
        else:
            self.log.info('Opened database at {0}'.format(path))
            self._opened = True
            return

        try:
            pypath = py.path.local(path)
            pypath.remove()
            pypath.ensure(file=1, force=1)
        except Exception:
            self.log.debug('Error while cleaning db file: {0}'.format(formatException()))

        return self._open(path, attach, noretry=True)

    def _closeConnection(self):
        """Best-effort close of ``self._db`` after a failed open; errors are logged."""
        conn, self._db = self._db, None
        if conn is None:
            return
        try:
            conn.close()
        except Exception:
            self.log.debug('Error while closing db connection: {0}'.format(formatException()))

    def close(self):
        """Close the connection; return True on success, False if not open or on error."""
        if not self._opened:
            return False

        try:
            self._db.close()
            self._opened = False
            self.log.info('Closed db')
            return True
        except Exception as ex:
            self.log.warning('Unable to close db properly: %s', str(ex))

        return False

    def _toDict(self, row, cursor):
        """Map *row* to a dict keyed by the column names in ``cursor.description``."""
        return {column[0]: value for column, value in zip(cursor.description, row)}

    def init(self):
        """Apply forward migrations from ``cfg.migrations`` and mark the db initialized.

        ``cfg.migrations.forward`` maps a target version (int or numeric string)
        to one SQL file name or a list/tuple of them, relative to
        ``cfg.migrations.path``. Each pending version runs inside ``with self:``
        and then stores the version in ``pragma user_version``.

        :raises apsw.SchemaChangeError: wrapping any failure, including a stored
            version newer than the newest configured migration.
        """
        try:
            userVersion = self.queryCol('pragma user_version')
            config = self.cfg.migrations

            migrateLog = self.log.getChild('migrate')

            # Convert keys before sorting: string keys would otherwise sort
            # lexicographically ('10' < '2') and migrations would be skipped.
            migrations = sorted(
                ((int(version), sqlFiles) for version, sqlFiles in config.forward.iteritems()),
                key=lambda item: item[0],
            )
            latestVersion = migrations[-1][0] if migrations else 0

            for version, sqlFiles in migrations:
                if not isinstance(sqlFiles, (list, tuple)):
                    sqlFiles = [sqlFiles]

                if userVersion >= version:
                    continue

                with self:
                    migrateLog.info('Migrating %d -> %d' % (userVersion, version))
                    for sqlFile in sqlFiles:
                        migrateLog.debug('  %s' % (sqlFile, ))
                        queries = py.path.local(config.path).join(sqlFile).read(mode='rb')
                        self.executeScript(queries)

                    self.query('pragma user_version = %d' % (version, ))
                    userVersion = version
                    migrateLog.info('Done (migrated to version %d)' % (userVersion, ))

            if userVersion != latestVersion:
                raise Exception('Dont know how to degrade db version (%d -> %d)' % (userVersion, latestVersion))

        except Exception as ex:
            raise apsw.SchemaChangeError(str(ex))
        else:
            self._initialized = True

    def vacuum(self, db='main'):
        """Run VACUUM on ``'main'``, on one attached database, or on ``'all'`` of them.

        An attached database is vacuumed through a separate, temporary
        ``Database`` opened directly on its file.
        """
        if db == 'all':
            for name in ['main'] + list(self.cfg.attach):
                self.vacuum(name)
            return

        if db == 'main':
            self.log.info('VACUUM')
            return self.query('VACUUM')

        assert db in self.cfg.attach
        cdb = Database(self.ctx)
        cdb.log = cdb.log.getChild(db)
        cdb._open(getattr(self.cfg.path, db), noretry=True)
        try:
            return cdb.vacuum()
        finally:
            cdb.close()

    def commit(self):
        """Execute COMMIT, logging its duration when transaction/SQL debugging is on."""
        ts = time.time()
        try:
            self.queryOne('COMMIT')
        finally:
            if self.cfg.debug_transactions or self.cfg.debug_sql:
                self.logSQL.debug('COMMIT %0.4fs', time.time() - ts)

    def rollback(self):
        """Execute ROLLBACK, logging its duration when transaction/SQL debugging is on."""
        ts = time.time()
        try:
            self.queryOne('ROLLBACK')
        finally:
            if self.cfg.debug_transactions or self.cfg.debug_sql:
                self.logSQL.info('ROLLBACK %0.4fs', time.time() - ts)

    def createFunction(self, func, name, argsCount):
        """Register *func* as an SQL scalar function *name* taking *argsCount* args."""
        return self._db.createscalarfunction(name, func, argsCount)

    def createModule(self, name, source):
        """Register virtual-table module *source* under *name*."""
        self._db.createmodule(name, source)

    def _limitParams(self, sql, params):
        """Repeat *sql* (';'-terminated) when *params* exceed the variable limit.

        Returns *sql* unchanged when ``len(params)`` is within
        ``self._limitVariables``; otherwise ``len(params) // limit + 1`` copies.
        """
        if len(params) > self._limitVariables:
            # Floor division: '/' would yield a float on Python 3 and break str * n.
            return (sql.strip(';') + ';') * (len(params) // self._limitVariables + 1)
        return sql

    def _expandParams(self, sql, params, chunkSize=500):
        """Expand list/tuple parameters bound to ``??`` placeholders.

        Each list/tuple in *params* replaces the next ``??`` in *sql* with one
        ``?`` per element (an empty list yields ``IN ()``) and is flattened into
        the returned parameters, in order, together with all scalar parameters.

        When the only list parameter has more than *chunkSize* elements, *sql*
        (which must then contain exactly one ``??``) is emitted once per chunk,
        with the statements joined by ``'; '``. The scalar parameters are
        repeated for every statement.

        :returns: ``(sql, params)`` unchanged if no list parameter is present,
            otherwise ``(expanded_sql, flat_param_list)``.
        """
        listIdx = [idx for idx, param in enumerate(params) if isinstance(param, (list, tuple))]
        if not listIdx:
            return sql, params

        def placeholders(values):
            return ', '.join(['?'] * len(values))

        if len(listIdx) == 1 and len(params[listIdx[0]]) > chunkSize:
            assert sql.count('??') == 1
            idx = listIdx[0]
            values = params[idx]
            head, tail = list(params[:idx]), list(params[idx + 1:])
            statements, flat = [], []
            for start in range(0, len(values), chunkSize):
                chunk = values[start:start + chunkSize]
                statements.append(sql.replace('??', placeholders(chunk), 1).strip(';'))
                # Each statement consumes its own full set of bindings.
                flat.extend(head)
                flat.extend(chunk)
                flat.extend(tail)
            return '; '.join(statements), flat

        for idx in listIdx:
            # Only a lone list parameter may be split across several statements.
            assert len(params[idx]) <= chunkSize
            sql = sql.replace('??', placeholders(params[idx]), 1)

        flat = []
        for param in params:
            if isinstance(param, (list, tuple)):
                flat.extend(param)
            else:
                flat.append(param)
        return sql, flat

    def _sql2log(self, sql):
        """Compact *sql* for logging (placeholder runs, repeated spaces, blank lines)."""
        sql = re.sub(r'\?(\s*?,\s*?\?)+', '??', sql)  # "?, ?, ?" => "??"
        sql = re.sub(r'(\w|,) +', '\\1 ', sql)        # collapse redundant spaces
        sql = re.sub(r'\n\s*\n', '\n', sql)           # drop empty / blank lines
        sql = sql.strip()
        return sql

    def _execute(self, sql, params=None, paramsHint=None, script=False, many=False):
        """Run *sql* on a new cursor, log it, record row stats and return the cursor.

        :param params: bindings; ``??`` lists are expanded unless *many* is set.
        :param paramsHint: text used in logs instead of ``repr(params)``.
        :param many: use ``executemany`` with *params* as a sequence of rows.
        The caller owns (and must close) the returned cursor.
        """
        assert not (script and many)
        ts = time.time()

        params = () if params is None else tuple(params)

        if paramsHint is None:
            paramsHint = repr(params) if params else ''

        if many:
            paramsHint = '[(%s,) x %d]' % (paramsHint or '()', len(params))
        elif paramsHint:
            paramsHint = '[%s]' % (paramsHint, )

        try:
            cursor = self._db.cursor()
            if many:
                cursor.executemany(sql, params)
            else:
                cursor.execute(*self._expandParams(sql, params))
            te1 = time.time()
        except Exception as ex:
            self.logSQL.warning('%s  %s:  %s', self._sql2log(sql), paramsHint, ex)
            raise

        try:
            elapsed = te1 - ts
            if elapsed > self.cfg.long_time_queries_threshold:
                logMethod = self.logSQL.warning
            elif self.cfg.debug_sql:
                logMethod = self.logSQL.debug
            else:
                logMethod = None
            # Only pay for SQL formatting when something is actually logged.
            if logMethod is not None:
                logMethod('[%0.4fs]  %s  %s', elapsed, self._sql2log(sql), paramsHint)
        except BaseException:
            cursor.close()
            raise

        # apsw reports these on the connection (last_insert_rowid()/changes());
        # fall back to DB-API style cursor attributes for other drivers.
        try:
            self.lastRowId = self._db.last_insert_rowid()
        except Exception:
            self.lastRowId = getattr(cursor, 'lastrowid', None)

        try:
            self.changedRows = self._db.changes()
        except Exception:
            self.changedRows = getattr(cursor, 'rowcount', -1)

        return cursor

    def executeScript(self, sqlscript):
        """Execute a multi-statement SQL script."""
        self._execute(sqlscript, script=True)

    def executeMany(self, sql, params, paramsHint):
        """Execute *sql* once per binding row in *params*."""
        self._execute(sql, params, paramsHint, many=True)

    def query(self, sql, params=None, paramsHint=None, one=False, asDict=False, getLastId=False):
        """Execute *sql* and return its result; the cursor is always closed.

        :returns: the last inserted rowid if *getLastId*; the first row (or
            None when there are no rows) if *one*; otherwise all rows, as a
            tuple of dicts when *asDict*, else ``cursor.fetchall()``.
        """
        cursor = self._execute(sql, params, paramsHint)

        try:
            if getLastId:
                return self.lastRowId

            if one:
                # Statements without result rows (e.g. COMMIT) yield None
                # instead of leaking StopIteration.
                row = next(cursor, None)
                if row is None:
                    return None
                return self._toDict(row, cursor) if asDict else row

            if asDict:
                return tuple(self._toDict(row, cursor) for row in cursor)

            return cursor.fetchall()
        finally:
            cursor.close()

    def iquery(self, sql, params=None, paramsHint=None, asDict=False):
        """Lazily yield rows (dicts when *asDict*); the cursor is closed when done."""
        cursor = self._execute(sql, params, paramsHint)

        try:
            for row in cursor:
                yield self._toDict(row, cursor) if asDict else row
        finally:
            cursor.close()

    def queryOne(self, sql, args=(), asDict=False):
        """Return the first result row of *sql*, or None if it yields no rows."""
        return self.query(sql, args, one=True, asDict=asDict)

    def queryCol(self, sql, args=()):
        """Return the first column of the first row, or None if there are no rows."""
        result = self.queryOne(sql, args)
        try:
            return result[0]
        except TypeError:
            return result

    def query_col(self, sql, args=()):
        """Return a list with the first column of every result row."""
        return [row[0] for row in self.query(sql, args)]

    def __repr__(self):
        if not self._opened:
            return '<Database>'
        elif not self._initialized:
            return '<Database (opened)>'
        else:
            return '<Database (opened ok)>'

    def __enter__(self):
        # Delegates to the connection's context manager and returns whatever it
        # returns (not this wrapper).
        return self._db.__enter__()

    def __exit__(self, *args, **kwargs):
        return self._db.__exit__(*args, **kwargs)
