import apsw
import contextlib
import os
import sys
import py
import re
import textwrap
import time
import traceback

from .framework.component import Component
from .framework.utils import human_time, human_size


class VTIn(object):
    """
    Minimal APSW virtual-table module exposing ``VTIn.data`` as a one-column
    table (column ``value``).

    The same instance acts as module, table and cursor: ``Create`` returns
    ``self`` as the table and ``Open`` returns ``self`` as the cursor.
    ``Database._expand_params`` assigns a list/tuple to the class attribute
    ``VTIn.data`` and rewrites a ``??`` placeholder into
    ``SELECT value FROM _vtin`` (presumably for ``IN (??)`` style queries).
    """

    # Class-level and shared: whatever sequence was assigned last is what the
    # next scan iterates.
    data = []

    def Create(self, db, modulename, dbname, tablename, *args):  # noqa
        """Declare the table schema (a single untyped ``value`` column)."""
        schema = 'CREATE TABLE %s (value)' % (tablename, )
        return schema, self

    Connect = Create

    def Open(self):  # noqa
        """Return the cursor object (this same instance)."""
        return self

    def Filter(self, *args):  # noqa
        """Start a new scan over the current ``data``; constraints are ignored."""
        self.iter = iter(self.data)

    def Column(self, col):  # noqa
        """Return the current row's value (the table has only one column)."""
        return self.value

    def Eof(self):  # noqa
        """
        Advance to the next element and report whether the scan is finished.

        Advancing happens here (``Next`` is a no-op): the next element is
        stored in ``self.value`` and False is returned; True is returned once
        the iterator is exhausted.
        """
        try:
            # builtin next() works on Python 2.6+ and 3; iterator.next() is 2-only
            self.value = next(self.iter)
        except StopIteration:
            return True
        return False

    def Close(self, *args):  # noqa
        """No-op; also used for ``Next``, ``Destroy`` and ``BestIndex``."""
        pass

    Next = Close
    Destroy = Close
    BestIndex = Close


class Database(Component):
    """
    Thin database abstraction layer on top of the APSW SQLite bindings.

    All methods whose names do not start with "_" are intended for public use.
    An instance also works as a transaction context manager: ``with db:``
    delegates to the APSW connection's context manager, and ``with db(...):``
    additionally allows temporarily overriding the SQL/transaction debug flags.
    """

    # Values for the ``check`` argument of open():
    # CHECK_FORCE -- always run a full integrity check when opening.
    # CHECK_IF_DIRTY -- check only if the "<path>.dirty" marker file exists
    # (it is created on every open and removed by a successful close()).
    CHECK_FORCE = 1
    CHECK_IF_DIRTY = 2

    # Storage parameters, substituted into the pragmas issued by _open().
    PAGE_SIZE = 4096
    CACHE_SIZE = 8192               # in pages: 8192 * 4096 bytes = 32mb max cache
    JOURNAL_SIZE_LIMIT = 67108864   # truncate journal on checkpoints to 64mb
    WAL_AUTOCHECKPOINT = 500        # checkpoint the WAL once it exceeds ~2mb (4096 * 500 bytes)

    # Init, constants and set* methods {{{
    # Alias of apsw.Error, presumably so callers can catch database errors
    # without importing apsw themselves.
    Error = apsw.Error

    def __init__(self, path, lock_timeout=60, mmap=False, temp=None, parent=None):
        """
        Store configuration; the connection itself is created by open().

        :param path: filesystem path of the SQLite database file.
        :param lock_timeout: busy timeout in seconds (int or float).
        :param mmap: value for ``pragma mmap_size``; falsy skips the pragma.
        :param temp: directory used for ``pragma temp_store_directory``.
        :param parent: parent component, forwarded to Component.
        """
        assert isinstance(path, basestring)
        assert isinstance(lock_timeout, (int, float))

        super(Database, self).__init__(logname='db', parent=parent)

        # configuration
        self._path = path
        self._lock_timeout = lock_timeout
        self._mmap = mmap
        self._temp = temp

        # loggers
        self._log = self.log
        self._log_sql = self._log.getChild('sql')

        # connection state
        self._db = None
        self._opened = False
        self._transaction_level = 0

        # diagnostics
        self._debug_sql = False
        self._debug_transactions = False
        self._sql_warning_threshold = 1

    def set_debug(self, sql=True, transactions=True):
        """
        Set the SQL and transaction debug-logging flags.

        Returns the previous ``(sql, transactions)`` pair so it can be restored.
        """
        previous = (self._debug_sql, self._debug_transactions)
        self._debug_sql, self._debug_transactions = sql, transactions
        return previous

    def ping(self):
        """
        Check that the connection answers ``select 1``.

        The answer itself is verified with ``assert``. Returns False when a
        transaction is currently open, True otherwise.
        """
        chk_log = self.log.getChild('chk')
        chk_log.debug('Db ping?')
        started = time.time()
        assert self.query_one_col('select 1', log=False) == 1
        chk_log.debug('Db pong! [%0.4fs]', time.time() - started)

        if not self.in_transaction():
            return True

        chk_log.warning('Db pong -- in transaction, fail')
        return False

    @contextlib.contextmanager
    def __call__(self, transaction=True, debug_sql=None, debug_transactions=None):
        """
        Context manager: ``with db(...) as db:``.

        :param transaction: if true, the body runs inside ``with self`` (the
            APSW connection's context manager); otherwise no transaction.
        :param debug_sql: temporary value of the SQL debug flag
            (None keeps the current one).
        :param debug_transactions: temporary value of the transaction debug
            flag (None keeps the current one).

        Overridden flags are restored on exit, including on error.
        """
        saved_debug = None
        try:
            if debug_sql is not None or debug_transactions is not None:
                new_sql = self._debug_sql if debug_sql is None else debug_sql
                new_tx = self._debug_transactions if debug_transactions is None else debug_transactions
                saved_debug = self.set_debug(new_sql, new_tx)

            if not transaction:
                yield self
            else:
                with self:
                    yield self
        finally:
            if saved_debug:
                self.set_debug(*saved_debug)

    def set_warning_threshold(self, seconds):
        """Queries whose execute+fetch time exceeds ``seconds`` are logged at WARNING."""
        threshold = seconds
        self._sql_warning_threshold = threshold
    # }}}

    # open/close {{{
    def open(self, check=CHECK_FORCE, force=False, quiet=False):
        """
        Open the database at the configured path.

        :param check: CHECK_FORCE, CHECK_IF_DIRTY, or falsy for no check.
        :param force: if true, a database that fails to open is removed and
            recreated once instead of raising right away.
        :param quiet: suppress the "Opened database" info message.
        """
        assert not self._opened
        no_retry = not force
        return self._open(self._path, check=check, noretry=no_retry, quiet=quiet)

    def reopen(self):
        """Close and open the database again, quietly and without an integrity check."""
        self.log.info('Reopening db...')
        self.close()
        self.open(quiet=True, force=False, check=False)

    def _open(self, path, check=False, noretry=False, quiet=False):
        """
        Open and configure the connection to ``path``.

        Steps: connect (busy timeout from ``lock_timeout``), optionally set
        mmap_size, optionally run an integrity check, create the
        "<path>.dirty" marker, apply the storage pragmas (rebuilding the file
        with VACUUM if page_size did not take effect), create the
        ``dbmaintain`` table and register the ``vtin`` virtual table module.

        :param check: CHECK_FORCE, CHECK_IF_DIRTY or falsy (no check).
        :param noretry: on an apsw.Error, raise apsw.CantOpenError right away.
            Otherwise the file is removed, recreated empty, and opening is
            retried once with ``noretry=True``.
        :param quiet: don't log "Opened database" on success.
        :raises apsw.CantOpenError: when opening fails and no retry is allowed.
        """
        try:
            self._db = apsw.Connection(path)
            self._db.setbusytimeout(self._lock_timeout * 1000)  # seconds -> ms

            if self._mmap:
                self.log.debug('Setting mmap_size to %d', self._mmap)
                self.query('pragma mmap_size = %d;' % (self._mmap, ))

            if check:
                if check == self.CHECK_IF_DIRTY:
                    # The marker is created below on every open and removed by a
                    # successful close(); if it is still there, shutdown was unclean.
                    if os.path.exists(self._path + '.dirty'):
                        check = True
                    else:
                        self.log.info('Bypassing db check -- possibly after clear shutdown')
                        check = False

                if check:
                    if not self.check(quick=False):
                        raise apsw.Error('Check failed')

            open(self._path + '.dirty', 'wb').close()

            if sys.platform == 'cygwin':
                # For cygwin version we use temp storage in memory
                self.query('pragma temp_store = MEMORY;')
            else:
                self.query('pragma temp_store = FILE;')
                if self._temp is not None:
                    # Skipped for temp=None: formatting None would request a
                    # directory literally named 'None'. Single quotes are doubled
                    # so the path stays a valid SQL string literal.
                    temp_dir = '{0}'.format(self._temp).replace("'", "''")
                    self.query("pragma temp_store_directory = '{0}';".format(temp_dir))

            # Read the values through attribute lookup: formatting with
            # type(self).__dict__ raised KeyError for subclasses, whose own
            # __dict__ does not contain these class attributes.
            self.query(
                textwrap.dedent('''
                    pragma synchronous = OFF;
                    pragma auto_vacuum = OFF;
                    pragma foreign_keys = ON;
                    pragma page_size = {PAGE_SIZE:d};
                    pragma cache_size = {CACHE_SIZE:d};
                    pragma main.journal_mode = WAL;
                    pragma main.journal_size_limit = {JOURNAL_SIZE_LIMIT:d};
                    pragma wal_autocheckpoint = {WAL_AUTOCHECKPOINT:d};
                    pragma locking_mode = EXCLUSIVE;
                '''.format(
                    PAGE_SIZE=self.PAGE_SIZE,
                    CACHE_SIZE=self.CACHE_SIZE,
                    JOURNAL_SIZE_LIMIT=self.JOURNAL_SIZE_LIMIT,
                    WAL_AUTOCHECKPOINT=self.WAL_AUTOCHECKPOINT,
                )).strip(),
                log=False
            )

            if self.query_one_col('pragma page_size') != self.PAGE_SIZE:
                # page_size did not take effect: leave WAL mode and VACUUM to
                # rebuild the file with the configured page size, then reconnect.
                self.log.warning('Unable to set page_size, we will do real job now...')
                self.query('pragma page_size = {0}'.format(self.PAGE_SIZE))
                self.query('pragma journal_mode = delete')
                self.query('vacuum')
                # close() is a no-op unless the db is marked opened; the nested
                # open() re-runs this whole setup on a fresh connection.
                self._opened = True
                self.reopen()

            self.query(
                'CREATE TABLE IF NOT EXISTS dbmaintain ( '
                '    key TEXT, '
                '    value_text TEXT, '
                '    value_int INTEGER, '
                '    PRIMARY KEY (key) ON CONFLICT ABORT '
                ')'
            )

            self.LIMIT_VARIABLE_NUMBER = self._db.limit(apsw.SQLITE_LIMIT_VARIABLE_NUMBER)

            self._db.createmodule('vtin', VTIn())

            self.query('DROP TABLE IF EXISTS _vtin')
            self.query('CREATE VIRTUAL TABLE _vtin USING vtin()')

        except apsw.Error as ex:
            # Capture the traceback first: in Python 2 an exception handled while
            # closing below would replace the one we want to report.
            failure_tb = traceback.format_exc()

            # Release the failed connection so it does not keep the file open
            # while it is removed/recreated below (or after we give up).
            if self._db is not None:
                try:
                    self._db.close(True)
                except Exception:
                    self._log.debug('Error while closing failed db connection: {0}'.format(
                        traceback.format_exc()
                    ))
                self._db = None
            self._opened = False

            if noretry:
                self._log.error('Failed to open database {0}: {1}'.format(path, ex))
                raise apsw.CantOpenError('Unable to open database file {0}: {1}'.format(
                    path, failure_tb
                ))
            self._log.error('Failed to open database {0}: {1}, will cleanup and try again'.format(path, ex))
        else:
            if not quiet:
                self._log.info('Opened database at {0}'.format(path))
            self._opened = True
            return

        # Ok, first attempt to open was failed and noretry=False
        # So, remove db file and try again
        try:
            pypath = py.path.local(path)
            pypath.remove()
            pypath.ensure(file=1, force=1)
        except Exception:
            self._log.debug('Error while cleaning db file: {0}'.format(traceback.format_exc()))

        return self._open(path, noretry=True, quiet=quiet)

    def in_transaction(self):
        """True while a transaction is open, i.e. the connection is not in autocommit mode."""
        autocommit = self._db.getautocommit()
        return not autocommit

    def status(self):
        """
        Return a nested dict of connection statistics.

        Sections: 'memory', 'cache' (hit/miss counters and bytes written),
        'data' (used/free bytes derived from page counts) and 'maintainance'
        (every key stored in ``dbmaintain``). Stored ``*_duration`` values are
        milliseconds and are converted to seconds when present.
        """
        status_dict = {
            'memory': {
                'total': apsw.memoryused(),
                # NOTE(review): Connection.status() takes SQLITE_DBSTATUS_* codes;
                # the two SQLITE_STATUS_* constants below are interpreted as
                # whichever DBSTATUS code shares their numeric value -- confirm
                # these report what is intended.
                'generic': self._db.status(apsw.SQLITE_STATUS_MEMORY_USED)[0],
                'cache': self._db.status(apsw.SQLITE_STATUS_PAGECACHE_USED)[0],
                'prep_stmt': self._db.status(apsw.SQLITE_DBSTATUS_STMT_USED)[0],
            },
            'cache': {
                'hit': self._db.status(apsw.SQLITE_DBSTATUS_CACHE_HIT)[0],
                'miss': self._db.status(apsw.SQLITE_DBSTATUS_CACHE_MISS)[0],
                'writes': self._db.status(apsw.SQLITE_DBSTATUS_CACHE_WRITE)[0] * self.PAGE_SIZE,
            },
            'data': {
                'used_bytes': self.query_one_col('pragma page_count', log=False) * self.PAGE_SIZE,
                'free_bytes': self.query_one_col('pragma freelist_count', log=False) * self.PAGE_SIZE,
            },
            # COALESCE rather than "CASE WHEN value_text ...": the latter tests
            # the text's truthiness, so non-numeric text fell through to value_int.
            'maintainance': dict(
                self.query(
                    'SELECT key, COALESCE(value_text, value_int) '
                    'FROM dbmaintain',
                    log=False
                )
            )
        }

        maintainance = status_dict['maintainance']
        for key in ('last_vacuum_duration', 'last_analyze_duration'):
            # Absent until maintain() has run VACUUM / ANALYZE at least once;
            # an unconditional division raised KeyError on a fresh database.
            if maintainance.get(key) is not None:
                maintainance[key] /= 1000.0

        return status_dict

    def close(self):
        """
        Commit any open transaction and close the connection.

        On success the "<path>.dirty" marker is removed (so a later
        open(check=CHECK_IF_DIRTY) skips the integrity check) and True is
        returned. Returns False when the db is not opened or closing failed.
        """
        if not self._opened:
            return False

        try:
            # COMMIT only when a transaction is actually open: otherwise SQLite
            # rejects it ("no transaction is active") and _execute logs that
            # failure as a warning on every close.
            try:
                if self.in_transaction():
                    self.commit()
            except Exception:
                pass  # best effort -- closing must proceed regardless

            self._db.close()
            self._opened = False
            self._log.info('Closed db')
            try:
                os.unlink(self._path + '.dirty')
            except OSError:
                pass  # marker already gone
            return True
        except Exception as ex:
            self._log.warning('Unable to close db properly: %s', str(ex))

        return False

    def maintain(self, vacuum=False, analyze=False, grow=False):
        """
        Run periodic maintenance.

        :param vacuum: VACUUM (then reopen) if the last recorded VACUUM is at
            least 24h old; running it also forces ``grow``.
        :param analyze: ANALYZE if the last recorded ANALYZE is at least 1h old.
        :param grow: if fewer than 200mb are free inside the file, insert and
            delete a zeroblob of max(200mb, free + 100mb) bytes so that space
            ends up on the freelist.

        Timestamps (unix seconds) and durations (milliseconds) are stored as
        integers in the ``dbmaintain`` table.
        """
        log = self.log.getChild('maintain')

        def record(key, value):
            # Store an integer property (value_text stays NULL).
            self.query('REPLACE INTO dbmaintain VALUES (?, null, ?)', [key, value])

        if vacuum or analyze:
            # COALESCE rather than "CASE WHEN value_text ...": the latter tests
            # the text's truthiness, so non-numeric text fell through to value_int.
            props = dict(
                self.query(
                    'SELECT '
                    '    key, '
                    '    COALESCE(value_text, value_int) '
                    'FROM dbmaintain'
                )
            )

            if vacuum:
                last_vacuum = props.get('last_vacuum', 0)
                if time.time() - last_vacuum >= 24 * 3600:
                    time_ago = 'unknown time' if last_vacuum == 0 else human_time(time.time() - last_vacuum)
                    log.info('Last VACUUM was %s ago, will do it now', time_ago)
                    ts = time.time()
                    self.vacuum()
                    record('last_vacuum', int(time.time()))
                    record('last_vacuum_duration', int((time.time() - ts) * 1000))
                    self.reopen()
                    # make sure free space is pre-allocated again after VACUUM
                    grow = True

            if analyze:
                last_analyze = props.get('last_analyze', 0)
                if time.time() - last_analyze >= 3600:
                    time_ago = 'unknown time' if last_analyze == 0 else human_time(time.time() - last_analyze)
                    log.info('Last ANALYZE was %s ago, will do it now', time_ago)
                    ts = time.time()
                    self.analyze()
                    record('last_analyze', int(time.time()))
                    record('last_analyze_duration', int((time.time() - ts) * 1000))

        if grow:
            # Two statements, one row each: page_count, then freelist_count.
            current_bytes, free_bytes = [
                (c * self.PAGE_SIZE)
                for c in self.query_col('pragma page_count; pragma freelist_count', log=False)
            ]
            min_free = 200 * 1024 * 1024
            min_grow = 100 * 1024 * 1024

            if free_bytes < min_free:
                # In units of 10mb (min free = 20, min grow = 10); "increase" is
                # how much the file grows, i.e. increase_bytes - free_bytes:
                # we have 0, min free 20 -- increase -- 20
                # we have 5, min free 20 -- increase -- 15
                # we have 15, min free 20 -- increase -- 10
                # we have 20, min free 20 -- noop

                increase_bytes = max(min_free, min_grow + free_bytes)

                log.info(
                    'We have %s free bytes, will increase db size by %s (%s => %s)',
                    human_size(free_bytes), human_size(increase_bytes - free_bytes),
                    human_size(current_bytes), human_size(current_bytes + increase_bytes - free_bytes)
                )
                # Write a zeroblob of that size and delete it again: the pages
                # it occupied stay in the file as free pages (auto_vacuum is OFF).
                self.query(
                    'REPLACE INTO dbmaintain (key, value_text) VALUES (?, zeroblob(?))',
                    ['zero', increase_bytes]
                )
                self.query('DELETE FROM dbmaintain WHERE key = ?', ['zero'])
    # }}}

    # Migrations {{{
    def _run_migrate_script(self, script, log):
        """
        Run one migration script.

        :param script: a py.path.local pointing at a ``.py`` or ``.sql`` file.
            A ``.py`` script must define ``migrate(db, log)``, which is called
            with this Database and ``log``. A ``.sql`` script is run through
            execute_script().
        :param log: logger passed on to the script.
        :raises apsw.SchemaChangeError: for any other file extension.
        """
        log.debug('  %s', script)
        script_path = script.strpath
        if script_path.endswith('.py'):
            # Execute into an explicit namespace: reading results back from the
            # implicit function locals() after execfile() is unsupported (see the
            # execfile docs). Seeding it with this module's globals keeps every
            # name visible to the script before, and lets the script's own
            # functions see its top-level imports.
            namespace = dict(globals())
            execfile(script_path, namespace)
            namespace['migrate'](self, log)
        elif script_path.endswith('.sql'):
            queries = py.path.local(script).read(mode='rb')
            self.execute_script(queries)
        else:
            raise apsw.SchemaChangeError('Unable to migrate file %s' % (script, ))

    def migrate(self, fw, bw):
        """
        Move the schema version (``pragma user_version``) to ``max(fw)``.

        :param fw: ``{version: script or [scripts]}`` run to step *up* to that
            version, one version at a time.
        :param bw: ``{version: script or [scripts]}`` run to step *down* to that
            version, one version at a time.
        :returns: True.
        :raises apsw.SchemaChangeError: wrapping any failure as
            "<ExceptionClass>: <message>".

        Each step runs its scripts and the user_version update inside one
        ``with self`` block.
        """
        assert self._opened, 'Cant migrate not opened db'

        mlog = self._log.getChild('migrate')

        def apply_step(scripts, current, target, start_fmt, done_fmt):
            # Run one version step; returns the version reached.
            if not isinstance(scripts, list):
                scripts = [scripts]
            with self:
                mlog.info(start_fmt % (current, target))
                for script in scripts:
                    self._run_migrate_script(script, mlog)
                self.query('pragma user_version = %d' % (target, ))
                mlog.info(done_fmt % (target, ))
            return target

        try:
            version = self.query_one_col('pragma user_version')
            wanted = max(fw)

            for target in range(version + 1, wanted + 1):
                self.log.info('Want to migrate %d => %d', version, target)
                scripts = fw[target]
                version = apply_step(
                    scripts, version, target,
                    'Migrating %d -> %d', 'Done (migrated to version %d)'
                )

            for target in range(version - 1, wanted - 1, -1):
                self.log.debug('Want to degrade %d => %d', version, target)
                try:
                    scripts = bw[target]
                except KeyError:
                    msg, msgargs = 'Dont know how to degrade db version (%r -> %r)', (version, target)
                    mlog.critical(msg, *msgargs)
                    raise Exception(msg % msgargs)

                version = apply_step(
                    scripts, version, target,
                    'Degrading %d -> %d', 'Done (degraded to version %d)'
                )
        except Exception as ex:
            raise apsw.SchemaChangeError('%s: %s' % (ex.__class__.__name__, str(ex)))

        return True
    # }}}

    # Transactions {{{
    def begin(self):
        """
        Issue BEGIN.

        Its duration is logged at DEBUG when transaction debugging is on and
        SQL debugging is off (with SQL debugging, _execute logs it anyway).
        """
        started = time.time()
        try:
            self.query('BEGIN')
        finally:
            only_tx_debug = self._debug_transactions and not self._debug_sql
            if only_tx_debug:
                self._log_sql.debug('[%0.4fs]  BEGIN', time.time() - started)

    def commit(self):
        """
        Issue COMMIT.

        Its duration is logged at DEBUG when transaction debugging is on and
        SQL debugging is off (with SQL debugging, _execute logs it anyway).
        """
        started = time.time()
        try:
            self.query('COMMIT')
        finally:
            only_tx_debug = self._debug_transactions and not self._debug_sql
            if only_tx_debug:
                self._log_sql.debug('[%0.4fs]  COMMIT', time.time() - started)

    def rollback(self):
        """
        Issue ROLLBACK.

        Its duration is logged at INFO (not DEBUG, unlike begin/commit) when
        transaction debugging is on and SQL debugging is off.
        """
        started = time.time()
        try:
            self.query('ROLLBACK')
        finally:
            only_tx_debug = self._debug_transactions and not self._debug_sql
            if only_tx_debug:
                self._log_sql.info('[%0.4fs]  ROLLBACK', time.time() - started)
    # }}}

    # createFunction, vacuum, analyze and check {{{
    def create_function(self, func, name, args_count):
        """Register ``func`` as SQL scalar function ``name`` taking ``args_count`` arguments."""
        connection = self._db
        return connection.createscalarfunction(name, func, args_count)

    def vacuum(self):
        """Run VACUUM; returns True when the statement produced no rows."""
        self._log.info('VACUUM')
        rows = self.query('VACUUM')
        return rows == []

    def backup(self):
        """
        Copy the live database to "<path>.backup" using APSW's online backup.

        An existing backup file is removed first. Pages are copied 1000 per
        step; progress is logged at DEBUG at most once per second.
        """
        log = self.log.getChild('backup')
        # Sibling of the db file. The previous
        # os.path.join(os.path.dirname(path), path + '.backup') doubled the
        # directory for relative paths ("data/x" -> "data/data/x.backup").
        new_path = self._path + '.backup'
        log.info('Backup target: %s', new_path)

        if os.path.exists(new_path):
            os.unlink(new_path)

        new_db = apsw.Connection(new_path)

        try:
            ts = time.time()
            print_ts = int(ts)

            with new_db.backup('main', self._db, 'main') as backup:
                while not backup.done:
                    backup.step(1000)

                    now = int(time.time())
                    if now != print_ts:
                        total = backup.pagecount
                        done = total - backup.remaining
                        # guard against a zero page count
                        percent = (float(done) / total) * 100 if total else 100.0
                        log.debug('  done %5.2f%%', percent)
                        print_ts = now

            log.info('Backup finished in %ds', time.time() - ts)
        finally:
            new_db.close()

    def check(self, quick=True):
        """Run ``pragma quick_check`` (quick) or ``pragma integrity_check``; True if it reports 'ok'."""
        self.log.debug('check db quick=%r', quick)
        pragma = 'pragma quick_check' if quick else 'pragma integrity_check'
        return self.query_one_col(pragma) == 'ok'

    def analyze(self):
        """Run ANALYZE; returns True when the statement produced no rows."""
        self._log.info('ANALYZE')
        rows = self.query('ANALYZE')
        return rows == []
    # }}}

    # sql execution {{{
    def _to_dict(self, row, cursor):
        """Map each column name in ``cursor.description`` to the value at the same position in ``row``."""
        result = {}
        for position, column in enumerate(cursor.description):
            result[column[0]] = row[position]
        return result

    def _expand_params(self, sql, params):
        """
        Expand a list/tuple parameter through the ``_vtin`` virtual table.

        A list/tuple in ``params`` is stored in ``VTIn.data`` and the single
        ``??`` in ``sql`` is replaced by ``SELECT value FROM _vtin``; all other
        parameters are kept, in order, as positional bindings. A dict of named
        bindings is returned unchanged.

        :returns: ``(sql, bindings)`` ready for ``cursor.execute``.
        """
        if isinstance(params, dict):
            # Named bindings: iterating a dict yields its *keys*, which would
            # otherwise be bound as positional values.
            return sql, params

        good_params = []
        for param in params:
            if isinstance(param, (list, tuple)):
                # One list parameter per statement: VTIn.data is a single
                # shared slot, and a second one finds no '??' left.
                assert sql.count('??') == 1

                VTIn.data = param

                sql = sql.replace('??', 'SELECT value FROM _vtin')
            else:
                good_params.append(param)

        return sql, good_params

    def _sql2log(self, sql):
        """
        Compact ``sql`` for log output.

        Runs of ``?`` placeholders become ``<placeholders>``, runs of spaces
        after a word character or comma collapse to one space, consecutive
        newlines collapse to one, and the result is stripped.
        """
        sql = re.sub(r'\?(\s*?,\s*?\?)+', '<placeholders>', sql)  # "?, ?, ?" => "<placeholders>"
        sql = re.sub(r'(\w|,) +', r'\1 ', sql)      # collapse meaningless runs of spaces
        sql = re.sub(r'\n{2,}', '\n', sql)          # drop empty lines ('\n\n' only halved longer runs)
        return sql.strip()

    def _execute(self, sql, params=None, script=False, many=False, log=True, fetch=False, fetch_one=False):
        """
        Execute ``sql`` on a new cursor, with timing and error logging.

        :param sql: statement(s) to execute.
        :param params: bindings; with ``many``, an iterable of binding sets
            for ``executemany``. Lists/tuples are normalized to a tuple.
        :param script: only checked against ``many``; execution goes through
            ``cursor.execute`` either way.
        :param many: use ``cursor.executemany`` instead of ``execute``.
        :param log: True -- log the compacted SQL and parameters; a non-empty
            string -- log that label instead of the SQL; falsy -- don't log
            successful queries (failures are always logged). Successful queries
            are logged at WARNING when slower than the warning threshold,
            otherwise at DEBUG only if SQL debugging is on.
        :param fetch: fetch all rows.
        :param fetch_one: fetch only the first row (None when there is none).
        :returns: ``(cursor, data)`` when fetching, otherwise the cursor;
            False after a BusyError/InterruptError (see _safe_execute).
        :raises: any other exception, after logging it. A CorruptError or
            NotADBError from a non-pragma statement terminates the process via
            ``os._exit(1)``.
        """
        assert not (script and many)
        logger = self._log_sql

        ts = time.time()

        if params is None:
            params = ()
        elif isinstance(params, (list, tuple)):
            params = tuple(params)

        if log is True:
            # Build a short human-readable description of the parameters.
            if many:
                try:
                    params_len = len(params)
                except TypeError:
                    params_len = '?'  # e.g. a generator of binding sets
                params_hint = '[<?> x %s]' % (params_len, )
            else:
                params_print = []
                if params:
                    for param in params:
                        if isinstance(param, buffer):
                            # blob: length plus the first 20 bytes as hex
                            params_print.append('<b:%d:%s>' % (len(param), param[:20].encode('hex')))
                        elif isinstance(param, (list, tuple)) and len(param) > 5:
                            params_print.append('<list %d elements>' % (len(param), ))
                        else:
                            params_print.append(repr(param))
                    params_print = '[%s] (%d total)' % (', '.join(params_print), len(params))

                params_hint = '%s' % (params_print if params_print else '[]')

        cursor = self._db.cursor()
        try:
            if many:
                cursor.executemany(sql, params)
            else:
                cursor.execute(*self._expand_params(sql, params))
            te1 = time.time()

            if fetch:
                data = cursor.fetchall()
            elif fetch_one:
                try:
                    data = cursor.next()
                except StopIteration:
                    data = None
            te2 = time.time()

        except Exception as ex:
            if log is True:
                logger.warning('%s  %s:  %s', self._sql2log(sql), params_hint, ex)
            elif log:
                logger.warning('SQL: %s: %s', log, ex)
            else:
                logger.warning('%s: %s', self._sql2log(sql), ex)

            # Pragmas are exempt; lstrip() so leading whitespace does not hide one.
            if not sql.lstrip().lower().startswith('pragma'):
                if isinstance(ex, (apsw.CorruptError, apsw.NotADBError)):
                    # If we get db being corrupted -- nothing else we can do +(
                    # Just emergency exit
                    self.log.critical('Exiting immidiately')
                    # Module-level os: the former function-local "import os"
                    # made os a local name throughout this whole function.
                    os._exit(1)

            cursor.close()

            if isinstance(ex, apsw.BusyError):
                self.log.warning('Got BusyError from APSW, interrupting all queries...')
                self._db.interrupt()
                return False
            elif isinstance(ex, apsw.InterruptError):
                return False
            else:
                raise
        else:
            try:
                cum_time = te2 - ts
                exec_time = te1 - ts
                fetch_time = te2 - te1

                if fetch:
                    data_len = len(data)
                elif fetch_one:
                    data_len = 1 if data is not None else 0
                else:
                    data_len = None

                if data_len is not None:
                    data_len_str = '%3d' % (data_len, )
                else:
                    data_len_str = '   '

                if log is True:
                    msg, args2 = '[%0.4fs %0.4fs][%s]  %s  %s', (
                        exec_time, fetch_time, data_len_str, self._sql2log(sql), params_hint
                    )
                elif log:
                    msg, args2 = '[%0.4fs %0.4fs][%s]  SQL: %s', (
                        exec_time, fetch_time, data_len_str, log
                    )

                if log:
                    if cum_time > self._sql_warning_threshold:
                        logger.warning(msg, *args2)
                    elif self._debug_sql:
                        logger.debug(msg, *args2)
            except Exception:
                cursor.close()
                raise

        if fetch or fetch_one:
            return cursor, data
        return cursor

    def _safe_execute(self, *args, **kwargs):
        """
        Call _execute() until it returns something other than False.

        _execute() returns False after a BusyError/InterruptError; each time,
        wait 0.3s and try again, with no retry limit.
        """
        result = self._execute(*args, **kwargs)
        while result is False:
            time.sleep(0.3)
            result = self._execute(*args, **kwargs)
        return result

    def execute_script(self, sqlscript):
        """Execute a (possibly multi-statement) SQL script, retrying on busy/interrupt."""
        self._safe_execute(sqlscript, script=True)
        return None

    def execute_many(self, sql, params):
        """Run ``sql`` once per binding set in ``params`` (executemany), retrying on busy/interrupt."""
        self._safe_execute(sql, params=params, many=True)
        return None
    # }}}

    # query meths {{{
    def query(self, sql, params=None, one=False, as_dict=False, get_last_id=False, get_changed=False, log=True):
        """
        Execute ``sql`` and return its result.

        :param one: fetch only the first row (None when there is none).
        :param as_dict: return rows as ``{column: value}`` dicts (a tuple of
            dicts when ``one`` is false).
        :param get_last_id: return ``last_insert_rowid()`` instead of rows.
        :param get_changed: return ``changes()`` instead of rows.
        :param log: forwarded to _execute().
        :returns: otherwise, the list of fetched rows (or the single row).
        """
        cursor, data = self._safe_execute(sql, params, log=log, fetch=not one, fetch_one=one)

        try:
            if get_last_id:
                return self._db.last_insert_rowid()
            if get_changed:
                return self._db.changes()

            if not as_dict:
                return data
            if one:
                return None if data is None else self._to_dict(data, cursor)
            return tuple(self._to_dict(row, cursor) for row in data)
        finally:
            # With ``one`` the remaining rows may be unread; the cursor is
            # then closed with force=True.
            cursor.close(bool(one))

    def iquery(self, sql, params=None, as_dict=False, log=True):
        """
        Lazily iterate over the rows of ``sql`` (as dicts when ``as_dict``).

        Execution goes through _safe_execute(), so busy/interrupt errors while
        executing are retried. An InterruptError raised while iterating
        restarts the query only if no row has been yielded yet; after that it
        is re-raised, because a restart would hand already-delivered rows to
        the caller a second time. The cursor is closed when iteration ends,
        fails, or the generator is closed early.
        """
        while True:
            # _execute() returns False (not a cursor) on Busy/InterruptError;
            # using it directly ended up in "for row in False".
            cursor = self._safe_execute(sql, params, log=log)
            yielded = False
            try:
                for row in cursor:
                    yielded = True
                    if as_dict:
                        yield self._to_dict(row, cursor)
                    else:
                        yield row
            except GeneratorExit:
                # consumer stopped early -- normal, not worth a warning
                raise
            except apsw.InterruptError:
                if yielded:
                    self._log.warning('%s: iquery interrupted after returning rows, not retrying',
                                      self._sql2log(sql))
                    raise
                self._log.warning('%s: iquery interrupt, will retry', self._sql2log(sql))
                continue
            except BaseException as ex:
                self._log.warning('%s: iquery got exception: %s: %s', self._sql2log(sql), type(ex), str(ex))
                raise
            finally:
                cursor.close()
            return

    def query_one(self, sql, args=(), as_dict=False, log=True):
        """Return the first row of ``sql`` (None if there is none); shorthand for query(..., one=True)."""
        return self.query(sql, args, as_dict=as_dict, log=log, one=True)

    def query_col(self, sql, args=(), log=True):
        """Return the first column of every result row, as a list."""
        rows = self.query(sql, args, log=log)
        return list(row[0] for row in rows)

    def query_one_col(self, sql, args=(), log=True):
        """Return the first column of the first row, or None when there is no row."""
        row = self.query_one(sql, args, log=log)
        try:
            value = row[0]
        except TypeError:  # no row: query_one returned None
            value = None
        return value
    # }}}

    # misc {{{
    def __repr__(self):
        """'<Database (opened)>' while open, '<Database>' otherwise."""
        return '<Database (opened)>' if self._opened else '<Database>'
    # }}}

    # __enter/exit__ for working as context manager {{{
    def __enter__(self):
        """
        Enter a (possibly nested) transaction via the APSW connection's
        context manager, tracking the nesting level.

        Logs CRITICAL when the outermost entry finds a transaction already
        open, and WARNING for every nested entry. Returns whatever the APSW
        connection's ``__enter__`` returns.
        """
        self._transaction_level += 1
        try:
            if self._transaction_level == 1:
                if self.in_transaction():
                    self.log.critical('we should not be in transaction here (level %d)', self._transaction_level)
            else:
                self.log.warning('transaction level %d', self._transaction_level)

            return self._db.__enter__()
        except BaseException:
            # __exit__ is not called when __enter__ fails; undo the increment
            # so the level does not stay permanently off by one.
            self._transaction_level -= 1
            raise

    def __exit__(self, *args, **kwargs):
        """Leave one transaction level; finishing the transaction is delegated to the APSW connection's __exit__."""
        self._transaction_level -= 1
        db_exit = self._db.__exit__
        return db_exit(*args, **kwargs)
    # }}}
