import binascii
import logging
import re
import textwrap
import time

import psycopg2
import psycopg2.errors as errors
import psycopg2.errorcodes as errorcodes


# SQLSTATE codes (class 08, "connection exception") after which the server
# connection is treated as broken: Database closes it and, in autocommit mode,
# reconnects and retries the failed query.
PG_RETRYABLE_ERRORS = (
    errorcodes.CONNECTION_EXCEPTION,        # 08000
    errorcodes.CONNECTION_FAILURE,          # 08006
    errorcodes.CONNECTION_DOES_NOT_EXIST,   # 08003
)


class ConnectionClosed(Exception):
    """Signals that the database connection is already closed or has broken.

    Raised by Database when its psycopg2 connection is found closed, or when it
    breaks while not in autocommit mode (the open transaction is then lost and
    the query is not retried).
    """


class Database(object):
    """psycopg2 connection wrapper with query logging, reconnects and migrations.

    * ``execute*`` / ``query*`` helpers log statements with timings; statements
      slower than ``_sql_warning_threshold`` seconds are logged at WARNING.
    * One list/tuple parameter per statement can be bound to a ``%S``
      placeholder; it is expanded to ``%s, %s, ...`` (e.g. for ``IN (%S)``).
    * In autocommit mode a query that failed because the connection broke is
      retried after reconnecting; outside autocommit ``ConnectionClosed`` is
      raised instead, because the open transaction is lost.
    * ``migrate()`` moves the schema version stored in the ``dbmaintain`` table
      through the ``_migrate_<from>_to_<to>`` methods.
    """

    errors = errors  # psycopg2.errors, re-exported so callers can catch e.g. Database.errors.UndefinedTable

    # Parameter types logged as "<b:length:hex-prefix>" instead of a full repr.
    # On Python 2 ``bytes is str``, so plain strings must not be treated as binary there.
    _BINARY_PARAM_TYPES = (bytearray, memoryview) if bytes is str else (bytes, bytearray, memoryview)

    def __init__(self, host, port, user, password, dbname, logname='db'):
        """Store connection settings; no connection is opened until ``connect()``.

        :param logname: logger name; SQL statements are logged to its ``sql`` child logger.
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.dbname = dbname
        self.log = logging.getLogger(logname)
        self.log_sql = self.log.getChild('sql')

        self._db = None
        self._opened = False
        self._debug_sql = False            # when true, every logged statement is emitted at DEBUG
        self._debug_transactions = 0.2     # can be set to min secs to debug
        self._sql_warning_threshold = 0.2  # log as warning all meths > 0.2s

        assert self.host is not None, 'Host should be specified'

    @staticmethod
    def _is_connection_error(ex):
        """Return True if *ex* means the server connection itself is unusable.

        Matches the SQLSTATE codes in PG_RETRYABLE_ERRORS plus the messages seen
        when the socket was dropped without a SQLSTATE being available.
        """
        message = str(ex)
        return (
            getattr(ex, 'pgcode', None) in PG_RETRYABLE_ERRORS
            or 'EOF detected' in message
            or 'connection already closed' in message
        )

    def __enter__(self):
        """Delegate to the psycopg2 connection's context manager.

        :raises ConnectionClosed: if there is no connection or it is already closed.
        Connection-level errors close this wrapper's connection before being re-raised.
        """
        if self._db is None or self._db.closed:
            self._opened = False
            raise ConnectionClosed('Db connection already closed')

        try:
            return self._db.__enter__()
        except (errors.DatabaseError, errors.InterfaceError) as ex:
            self.log.warning('pg error: %s: %s (code %s)', type(ex).__name__, str(ex).strip(), ex.pgcode)

            if self._is_connection_error(ex):
                try:
                    self.close()
                except Exception:
                    # Closing a broken connection may fail too; either way it is unusable now.
                    self._opened = False

            raise

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Delegate to the psycopg2 connection's context-manager exit."""
        return self._db.__exit__(exc_type, exc_value, exc_tb)

    def connect(self, autocommit):
        """Open a new psycopg2 connection (replacing ``self._db``) with the given autocommit mode.

        ``target_session_attrs='read-write'`` asks libpq to accept only a server that allows writes.
        """
        self._db = psycopg2.connect(
            database=self.dbname, user=self.user, password=self.password,
            host=self.host, port=self.port, target_session_attrs='read-write'
        )
        self._db.autocommit = autocommit
        self.log.info('Connected to db: %r', self._db)
        self._opened = True

    def reconnect(self):
        """Open a fresh connection using the current connection's autocommit mode."""
        return self.connect(self._db.autocommit)

    def close(self):
        """Close the underlying connection and mark this wrapper as not opened."""
        self._db.close()
        self._opened = False

    def connected(self):
        """Return True if connect() succeeded and the psycopg2 connection is not closed."""
        return self._opened and not self._db.closed

    def _reconnect_with_deadline(self, autocommit, timeout):
        """Retry connect() once per second for up to *timeout* seconds; return True on success."""
        deadline = time.time() + timeout
        while time.time() < deadline:
            try:
                self.connect(autocommit)
            except Exception as conn_ex:
                self.log.error(
                    'Unable to connect: %s: %s, will try more', type(conn_ex).__name__, str(conn_ex).strip()
                )
                time.sleep(1)
            else:
                return True
        return False

    # Execute meths {{{
    def _sql2log(self, sql, fix_sql):
        """Return a compacted version of *sql* for a log line.

        Lists of placeholders ("?, ?" or "%s, %s") collapse to "<placeholders>"; with
        *fix_sql*, space runs after words/commas shrink to one space and blank lines are dropped.
        """
        sql = re.sub(r'(?:\?|%s)(?:\s*,\s*(?:\?|%s))+', '<placeholders>', sql)  # "%s, %s, %s" => "<placeholders>"

        if fix_sql:
            sql = re.sub(r'(\w|,) +', '\\1 ', sql)   # eat meaningful whitespace
            sql = re.sub(r'\n\s*\n', '\n', sql)      # drop empty (or whitespace-only) lines

        return sql.strip()

    def _expand_params(self, sql, params):
        """Expand the ``%S`` list placeholder and flatten its list/tuple parameter.

        At most one list/tuple parameter is supported: ``%S`` becomes ``%s, %s, ...`` (one per
        element) and the elements are spliced into the positional parameter list in place.
        Non-sequence params (e.g. a dict for ``%(name)s`` placeholders) are passed through untouched.
        Note that an empty list expands to nothing, which yields invalid SQL such as ``IN ()``.

        :returns: ``(sql, params)`` ready for ``cursor.execute``.
        """
        if not isinstance(params, (list, tuple)):
            return sql, params

        flat_params = []
        for param in params:
            if isinstance(param, (list, tuple)):
                assert sql.count('%S') == 1, 'a list parameter needs exactly one %S placeholder'
                sql = sql.replace('%S', ', '.join(['%s'] * len(param)))
                flat_params.extend(param)  # works for tuples too, at any position
            else:
                flat_params.append(param)

        return sql, flat_params

    def _format_params_hint(self, params, many):
        """Render query parameters as a compact string for log lines.

        For ``executemany`` only the number of parameter sets is shown. Otherwise each value
        is repr'd, except binary values (length + hex of the first 20 bytes) and lists/tuples
        longer than 5 items (element count only).
        """
        if many:
            try:
                params_len = len(params)
            except TypeError:  # e.g. an iterator of parameter sets
                params_len = '?'
            return '[<?> x %s]' % (params_len, )

        if not params:
            return '[]'

        rendered = []
        for param in params:
            if isinstance(param, self._BINARY_PARAM_TYPES):
                view = memoryview(param)
                hex_prefix = binascii.hexlify(view[:20].tobytes()).decode('ascii')
                rendered.append('<b:%d:%s>' % (len(view), hex_prefix))
            elif isinstance(param, (list, tuple)):
                if len(param) > 5:
                    rendered.append('<list %d elements>' % (len(param), ))
                else:
                    rendered.append(repr(list(param)))
            else:
                rendered.append(repr(param))

        return '[%s] (%d total)' % (', '.join(rendered), len(params))

    def _execute(
        self, sql, params=None, script=False, many=False, log=True, fetch=False, fetch_one=False,
        fix_sql=True
    ):
        """Run *sql* once on a new cursor, log it, and return the cursor.

        :param params: positional params (``%S`` expansion applies), a mapping, or for
            *many* a sequence of parameter sets.
        :param script: only checked for not being combined with *many*; a script is sent
            through a single ``cursor.execute``.
        :param log: ``True`` logs SQL + params; a truthy string is logged instead of the SQL;
            falsy disables success logging (failures are still logged).
        :param fetch: also return ``cursor.fetchall()``.
        :param fetch_one: also return the first row (or None).
        :returns: the cursor, or ``(cursor, data)`` when fetching; ``False`` when the
            connection broke in autocommit mode and was re-established, so the caller may retry.
        :raises ConnectionClosed: the connection broke outside autocommit mode.
        Other errors are logged, the transaction is rolled back (outside autocommit) and
        the error is re-raised.
        """
        assert not (script and many)
        logger = self.log_sql

        ts = time.time()

        if params is None:
            params = ()
        elif isinstance(params, (list, tuple)):
            params = tuple(params)

        # Rendered eagerly only when it will be part of the success log line; lazily on errors.
        params_hint = self._format_params_hint(params, many) if log is True else None

        cursor = None
        try:
            try:
                cursor = self._db.cursor()

                if many:
                    cursor.executemany(sql, params)
                else:
                    cursor.execute(*self._expand_params(sql, params))
                te1 = time.time()

                data = None
                if fetch:
                    data = cursor.fetchall()
                elif fetch_one:
                    data = cursor.fetchone()  # None when there are no rows
                te2 = time.time()

            except (errors.DatabaseError, errors.InterfaceError) as ex:
                self.log.warning('pg error: %s: %s (code %s)', type(ex).__name__, str(ex).strip(), ex.pgcode)

                if self._is_connection_error(ex):
                    if params_hint is None:
                        params_hint = self._format_params_hint(params, many)

                    # Probably connection exhausted, we need to reconnect and retry
                    logger.warning('%s  %s:  %s', self._sql2log(sql, fix_sql), params_hint, str(ex).strip())

                    if not self._db.autocommit:
                        # The open transaction is gone; a blind retry would lose its earlier work.
                        self._opened = False
                        raise ConnectionClosed()

                    self.log.warning('Will try to reconnect db again')
                    if self._reconnect_with_deadline(self._db.autocommit, 60):
                        return False  # indicates possible retry of whole query

                raise  # dont know what to do with that error

        except Exception as ex:
            if log is True:
                logger.warning('%s  %s:  %s', self._sql2log(sql, fix_sql), params_hint, str(ex).strip())
            elif log:
                logger.warning('SQL: %s: %s', log, str(ex).strip())
            else:
                logger.warning('%s: %s', self._sql2log(sql, fix_sql), str(ex).strip())

            if not self._db.autocommit and self._opened:
                try:
                    self._db.rollback()
                except Exception as rollback_ex:
                    # Let the original error propagate; the rollback failure is secondary.
                    self.log.warning(
                        'rollback after failed query failed too: %s: %s',
                        type(rollback_ex).__name__, str(rollback_ex).strip()
                    )

            if cursor is not None:
                cursor.close()

            raise
        else:
            try:
                cum_time = te2 - ts
                # Build the log line only when it is going to be emitted.
                if log and (cum_time > self._sql_warning_threshold or self._debug_sql):
                    if fetch:
                        data_len_str = '%3d' % (len(data), )
                    elif fetch_one:
                        data_len_str = '%3d' % (1 if data is not None else 0, )
                    else:
                        data_len_str = '   '

                    exec_time, fetch_time = te1 - ts, te2 - te1
                    if log is True:
                        msg, args = '[%0.4fs %0.4fs][%s]  %s  %s', (
                            exec_time, fetch_time, data_len_str, self._sql2log(sql, fix_sql), params_hint
                        )
                    else:
                        msg, args = '[%0.4fs %0.4fs][%s]  SQL: %s', (exec_time, fetch_time, data_len_str, log)

                    if cum_time > self._sql_warning_threshold:
                        logger.warning(msg, *args)
                    else:
                        logger.debug(msg, *args)
            except BaseException:
                cursor.close()
                raise

        if fetch or fetch_one:
            return cursor, data
        return cursor

    def _safe_execute(self, *args, **kwargs):
        """Call _execute, retrying while it reports a successful reconnect (returns False).

        Up to ``max_tries`` attempts are made, 0.3s apart; if they are all used up the
        process is terminated with ``os._exit(1)``.
        """
        max_tries = 10

        for _ in range(max_tries):
            result = self._execute(*args, **kwargs)
            if result is not False:
                return result
            time.sleep(0.3)

        import os
        import sys
        self.log.critical('We was unable to perform query')
        sys.stderr.write('We was unable to perform query, exit immidiately\n')
        os._exit(1)

    def execute_script(self, sqlscript, fix_sql=True):
        """Execute a (possibly multi-statement) SQL script; its cursor is closed afterwards."""
        self._safe_execute(sqlscript, script=True, fix_sql=fix_sql).close()

    def execute_many(self, sql, params, fix_sql=True):
        """Run ``cursor.executemany(sql, params)``; its cursor is closed afterwards."""
        self._safe_execute(sql, params, many=True, fix_sql=fix_sql).close()

    def execute(self, sql, params=None, log=True, fix_sql=True):
        """Execute one statement and return its open cursor (the caller must close it)."""
        cursor = self._safe_execute(sql, params, log=log, fix_sql=fix_sql)
        return cursor

    # Execute meths }}}

    # Query meths {{{
    def _to_dict(self, row, cursor):
        """Map *row* to a dict keyed by the column names in ``cursor.description``."""
        return dict(zip([column[0] for column in cursor.description], row))

    def _query_changed_or_last_id(self, sql, params, get_last_id, log, fix_sql):
        """Execute *sql* and return the new row id (*get_last_id*) or the affected row count."""
        cursor = self._safe_execute(sql, params, log=log, fix_sql=fix_sql)
        try:
            if not get_last_id:
                return cursor.rowcount
            if cursor.description is not None:
                # PostgreSQL has no last_insert_rowid(): use "INSERT ... RETURNING id" and
                # take the first column of the first returned row.
                row = cursor.fetchone()
                return None if row is None else row[0]
            # No result set: psycopg2's lastrowid (an OID, None for tables without OIDs).
            return cursor.lastrowid
        finally:
            cursor.close()

    def query(
        self, sql, params=None, one=False, as_dict=False, get_last_id=False, get_changed=False, log=True,
        fix_sql=True
    ):
        """Execute *sql* and return its result.

        :param one: return only the first row (or None) instead of all rows.
        :param as_dict: return rows as ``{column_name: value}`` dicts (a tuple of them unless *one*).
        :param get_last_id: return the first column of the first ``RETURNING`` row, or
            ``cursor.lastrowid`` when the statement returns no rows. Takes precedence over the other modes.
        :param get_changed: return ``cursor.rowcount`` (rows affected).
        """
        if get_last_id or get_changed:
            return self._query_changed_or_last_id(sql, params, get_last_id, log, fix_sql)

        cursor, data = self._safe_execute(sql, params, log=log, fetch=not one, fetch_one=one, fix_sql=fix_sql)
        try:
            if not as_dict:
                return data
            if one:
                return None if data is None else self._to_dict(data, cursor)
            return tuple([self._to_dict(row, cursor) for row in data])
        finally:
            cursor.close()

    def iquery(self, sql, params=None, as_dict=False, log=True, fix_sql=True):
        """Yield result rows one by one (as dicts when *as_dict*); the cursor is closed at the end.

        The statement is executed lazily, on the first ``next()``.
        """
        cursor = self._safe_execute(sql, params, log=log, fix_sql=fix_sql)
        try:
            for row in cursor:
                yield self._to_dict(row, cursor) if as_dict else row
        except GeneratorExit:
            # Consumer stopped iterating early (break / close()): normal, not worth a warning.
            raise
        except BaseException as ex:
            self.log.warning('%s: iquery got exception: %s: %s', self._sql2log(sql, fix_sql), type(ex), str(ex))
            raise
        finally:
            cursor.close()

    def query_one(self, sql, args=(), as_dict=False, log=True, fix_sql=True):
        """Return the first result row, or None when there are no rows."""
        return self.query(sql, args, one=True, as_dict=as_dict, log=log, fix_sql=fix_sql)

    def query_col(self, sql, args=(), log=True, fix_sql=True):
        """Return the first column of every result row as a list."""
        return [row[0] for row in self.query(sql, args, log=log, fix_sql=fix_sql)]

    def query_one_col(self, sql, args=(), log=True, fix_sql=True):
        """Return the first column of the first row, or None when there are no rows."""
        row = self.query_one(sql, args, log=log, fix_sql=fix_sql)
        return None if row is None else row[0]
    # Query meths }}}

    # Transactions {{{
    def _transaction_worth_logging(self, transaction_time):
        """Return False when ``_debug_transactions`` is a numeric threshold not reached by *transaction_time*."""
        threshold = self._debug_transactions
        return not (isinstance(threshold, (int, float)) and transaction_time < threshold)

    def commit(self):
        """Commit the current transaction (not allowed in autocommit mode).

        With ``_debug_transactions`` set, the COMMIT duration is logged at DEBUG
        (only when it reaches the threshold, if the setting is numeric).
        """
        assert not self._db.autocommit, 'This connection is in autocommit mode, you cant commit explicitly'

        ts = time.time()
        try:
            self.execute('COMMIT', log=False)
        finally:
            # No `return` in this finally block: it would silently swallow a failed COMMIT.
            if self._debug_transactions:
                transaction_time = time.time() - ts
                if self._transaction_worth_logging(transaction_time):
                    self.log_sql.debug('[%0.4fs]  COMMIT', transaction_time)

    def rollback(self):
        """Roll back the current transaction (not allowed in autocommit mode).

        With ``_debug_transactions`` set and ``_debug_sql`` off, the ROLLBACK duration is
        logged at INFO (only when it reaches the threshold, if the setting is numeric).
        """
        assert not self._db.autocommit, 'This connection is in autocommit mode, you cant rollback explicitly'

        ts = time.time()
        try:
            self.execute('ROLLBACK')
        finally:
            # No `return` in this finally block: it would silently swallow a failed ROLLBACK.
            if self._debug_transactions and not self._debug_sql:
                transaction_time = time.time() - ts
                if self._transaction_worth_logging(transaction_time):
                    self.log_sql.info('[%0.4fs]  ROLLBACK', transaction_time)

    def begin(self):
        """Send an explicit BEGIN."""
        self.execute('BEGIN')
    # Transactions }}}

    # Migrations {{{
    def _init_dbmaintain_if_needed(self):
        """Return the stored schema version, creating ``dbmaintain`` and ``lock`` (version 0) if missing."""
        try:
            return self._migrate_get_version()
        except Database.errors.UndefinedTable:
            self.begin()
            self.execute(textwrap.dedent('''
                CREATE TABLE dbmaintain (
                    key         VARCHAR         NOT NULL,
                    value_text  VARCHAR             NULL,
                    value_int   INTEGER             NULL,

                    PRIMARY KEY (key),
                    CONSTRAINT dbmaintain_value_int_or_text
                        CHECK (value_text IS NOT NULL OR value_int IS NOT NULL)
                )
            '''), fix_sql=False)
            self.execute('INSERT INTO dbmaintain (key, value_int) VALUES (%s, %s)', ('schema_version', 0))
            self.execute(textwrap.dedent('''
                CREATE TABLE lock (
                    name        VARCHAR(255)    NOT NULL,
                    deadline    INTEGER         NOT NULL,
                    client      VARCHAR(255)    NOT NULL,

                    PRIMARY KEY (name)
                )
            '''), fix_sql=False)
            self.execute('INSERT INTO lock VALUES (%s, %s, %s)', ('dbop', time.time() + 3600, 'alpha'))
            if self._db.autocommit:
                # commit() refuses to run in autocommit mode, but the BEGIN above opened a transaction.
                self.execute('COMMIT')
            else:
                self.commit()

            return 0

    def migrate(self):
        """Move the schema to the highest version reachable via ``_migrate_N_to_N+1`` methods.

        If the stored version is above that, ``_migrate_N_to_N-1`` methods are used to step down.
        """
        current_version = self._init_dbmaintain_if_needed()

        max_version = 0
        while hasattr(self, '_migrate_%d_to_%d' % (max_version, max_version + 1)):
            max_version += 1

        if current_version != max_version:
            self.log.info('Migrating db (current v%d, we want v%d)', current_version, max_version)

            # Upgrade path: runs only when current_version < max_version.
            for i in range(current_version, max_version):
                self.log.debug('  migrating %d to %d...', i, i + 1)
                getattr(self, '_migrate_%d_to_%d' % (i, i + 1))()

            # Downgrade path: runs only when current_version > max_version.
            for i in range(current_version, max_version, -1):
                self.log.debug('  migrating %d to %d...', i, i - 1)
                getattr(self, '_migrate_%d_to_%d' % (i, i - 1))()

    def _migrate_get_version(self):
        """Return ``value_int`` of the ``schema_version`` row in ``dbmaintain`` (None if absent)."""
        return self.query_one_col('SELECT value_int FROM dbmaintain WHERE key = %s', ('schema_version', ))

    def _migrate_0_to_1(self):
        """Create the v1 schema and bump ``schema_version`` to 1, in one BEGIN/COMMIT script."""
        self.execute(textwrap.dedent('''
            BEGIN;

            CREATE TYPE session_type_t AS ENUM ('upload', 'download');
            CREATE TYPE session_state_t AS ENUM ('new', 'active', 'archive');
            CREATE TYPE origin_t AS ENUM ('evoq', 'hot', 'qdm');                -- qdm -- regular user backups
            CREATE TYPE storage_revision_state_t AS ENUM (
                'draft',    -- revision is not finished yet
                'active',   -- revision is active and usable to create vm's from
                'archive'   -- revision marked as archive and will be deleted soon
            );

            CREATE TABLE storage_revision (
                vm_id               VARCHAR         NOT NULL,                   -- vm id
                rev_id              INTEGER         NOT NULL,                   -- rev id
                key                 CHAR(64)        NOT NULL,                   -- unique key which allows to use revision
                state               storage_revision_state_t    NOT NULL DEFAULT ('draft'),

                origin              origin_t        NOT NULL,                   -- revision origin type
                create_ts           INTEGER         NOT NULL,                   -- timestamp rev was created
                access_ts           INTEGER             NULL,                   -- timestamp rev was accessed (new vm launched)
                access_cnt          INTEGER         NOT NULL DEFAULT 0,         -- how many times revision was used to start new vm

                filemap             BYTEA               NULL,                   -- msgpack-ed file map dict

                vmspec              BYTEA               NULL,                   -- msgpack-ed vmspec

                PRIMARY KEY (vm_id, rev_id),
                UNIQUE (key),

                -- If we have active state we must have filemap
                CONSTRAINT storage_revision_filemap_not_null_in_active_state
                    CHECK (state != 'active' OR filemap IS NOT NULL)
            );

            CREATE TABLE session (
                key                 CHAR(64)        NOT NULL,                   -- blake2b hex key (full perms)
                type                session_type_t  NOT NULL,                   -- session type
                state               session_state_t NOT NULL    DEFAULT 'new',  -- state
                state_ts            INTEGER         NOT NULL,                   -- ts state was last changed
                modify_ts           INTEGER         NOT NULL,                   -- last timestamp this session was modified by client
                origin              origin_t        NOT NULL,                   -- session origin type

                vm_id               VARCHAR         NOT NULL,                   -- vm id this session made for
                rev_id              INTEGER             NULL,                   -- rev
                node_id             VARCHAR             NULL,                   -- (obsolete, will be removed)

                run_vm_id           VARCHAR             NULL,                   -- vm_id this session running on
                run_node_id         VARCHAR             NULL,                   -- node_id this session running on

                -- session statistics
                bytes_total         INTEGER             NULL,                   -- total bytes need to work on
                bytes_done          INTEGER             NULL,                   -- total bytes done
                speed_bps           INTEGER             NULL,                   -- current speed in bps computed by client

                -- mds tvm ticket for client
                mds_tvm_ticket      VARCHAR             NULL,                   -- tvm ticket for mds
                mds_tvm_ticket_ts   INTEGER             NULL,                   -- ts tvm ticket was generated

                PRIMARY KEY (key),

                -- session can be created with vm_id set (but no such record in storage_revision) and rev_id set to NULL
                -- after that new storage_revision will be created with same vm_id and new rev_id
                -- and session can be updated with new generated rev_id afterwards
                FOREIGN KEY (vm_id, rev_id) REFERENCES storage_revision
            );

            CREATE TYPE hashtype_t AS ENUM ('sha256', 'b2');
            CREATE TYPE mds_storage_t AS ENUM ('storage-int-mdst', 'storage-int-mds');

            CREATE TABLE storage_block (
                id                  BIGSERIAL       NOT NULL,                   -- unique block id
                hashtype            hashtype_t      NOT NULL,                   -- type of hash used
                hash                CHAR(64)        NOT NULL,                   -- hash in hex
                size                INTEGER         NOT NULL,                   -- block size in bytes
                mds_storage         mds_storage_t   NOT NULL,                   -- mds storage
                mds_key             VARCHAR         NOT NULL,                   -- mds storage key
                mds_ttl             INTEGER         NOT NULL,                   -- mds storage ttl used (deadline ts)

                PRIMARY KEY (id)
            );

            CREATE TABLE storage_block_removed (
                id                  BIGSERIAL       NOT NULL,                   -- unique block id
                mds_storage         mds_storage_t   NOT NULL,                   -- mds storage
                mds_key             VARCHAR         NOT NULL,                   -- mds storage key
                remove_success      BOOL            NOT NULL DEFAULT false,

                PRIMARY KEY (id)
            );

            CREATE INDEX ON storage_block (mds_key);                            -- used for searching duplicate blocks by mds_key

            CREATE TABLE storage_data (
                vm_id               VARCHAR         NOT NULL,
                rev_id              INTEGER         NOT NULL,
                idx                 SMALLINT        NOT NULL,                   -- with 512mb blocks this will limit max backup to 16Tb
                block_id            BIGINT          NOT NULL,

                FOREIGN KEY (block_id) REFERENCES storage_block,
                FOREIGN KEY (vm_id, rev_id) REFERENCES storage_revision,

                UNIQUE (vm_id, rev_id, idx)                                     -- we will need to sort by idx as well
            );

            CREATE INDEX ON storage_data (block_id);

            CREATE TYPE audit_severity_t AS ENUM ('info', 'warning', 'error');
            CREATE TABLE audit (
                vm_id               VARCHAR         NOT NULL,
                session             CHAR(64)        NOT NULL,
                timestamp           INTEGER         NOT NULL,
                severity            audit_severity_t    NOT NULL,
                message             TEXT            NOT NULL,

                FOREIGN KEY (session) REFERENCES session (key)
            );

            CREATE INDEX audit_timestamp_idx ON audit (timestamp ASC);
            CREATE INDEX audit_session_timestamp_idx ON audit (session, timestamp ASC);
            CREATE INDEX audit_vm_id_timestamp_idx ON audit (vm_id, timestamp ASC);

            CREATE TYPE log_severity_t AS ENUM ('debug', 'info', 'warning', 'error');
            CREATE TABLE log (
                node                VARCHAR         NOT NULL,
                name                VARCHAR         NOT NULL,
                timestamp_ms        BIGINT          NOT NULL,
                severity            log_severity_t  NOT NULL,
                message             TEXT            NOT NULL
            );
            CREATE INDEX log_severity_timestamp_ms_idx ON log (timestamp_ms ASC, severity ASC);

            CREATE TABLE job (
                name                VARCHAR         NOT NULL,
                node                VARCHAR         NULL,
                run_ts              INTEGER         NULL,
                last_node           VARCHAR         NULL,
                last_run_ts         INTEGER         NULL,
                next_run_ts         INTEGER         NOT NULL,
                cnt_success         INTEGER         NOT NULL DEFAULT 0,
                cnt_failed          INTEGER         NOT NULL DEFAULT 0
            );

            INSERT INTO job (name, next_run_ts) VALUES ('data_cleaner', 0);
            INSERT INTO job (name, next_run_ts) VALUES ('block_cleaner', 0);

            CREATE TYPE evoq_state_t AS ENUM ('init', 'wait', 'run', 'fail', 'stop', 'done');

            CREATE TABLE evoq (
                id                  BIGSERIAL       NOT NULL,
                vm_id               VARCHAR         NOT NULL,
                node_id             VARCHAR         NOT NULL,
                session_key         CHAR(64)        NULL,
                state               evoq_state_t    NOT NULL DEFAULT 'init',
                extra               BYTEA           NULL,                        -- extra payload

                active              BOOLEAN         NOT NULL DEFAULT false,     -- eviction request is active
                init_ts             INTEGER         NOT NULL DEFAULT 0,         -- ts evoq record was created

                run_node            VARCHAR         NOT NULL,
                run_cnt             INTEGER         NOT NULL,
                run_ts              INTEGER         NOT NULL,
                run_duration        INTEGER         NOT NULL,
                info                VARCHAR         NOT NULL,                   -- text info for analysing what's happening

                fail_cnt            INTEGER         NOT NULL DEFAULT 0,

                PRIMARY KEY (id),
                FOREIGN KEY (session_key) REFERENCES session (key)
            );

            CREATE INDEX evoq_vm_id_node_id_idx ON evoq (vm_id, node_id);

            CREATE TABLE evoq_log (
                node                VARCHAR         NOT NULL,
                evoq_id             BIGINT          NOT NULL,
                timestamp_ms        BIGINT          NOT NULL,
                severity            log_severity_t  NOT NULL,
                message             TEXT            NOT NULL,

                FOREIGN KEY (evoq_id) REFERENCES evoq (id)
            );

            CREATE TABLE stat_worker (
                id                  BIGSERIAL       NOT NULL,
                start_ts            INTEGER         NOT NULL,                   -- time sender was started
                run_ts              INTEGER         NOT NULL,                   -- last send time
                run_node            VARCHAR         NOT NULL,                   -- node sender running on
                solomon_send_cnt    INTEGER         NOT NULL,                   -- how many data we sent to solomon during this run

                PRIMARY KEY (id)
            );

            CREATE TABLE vm (
                id                  BIGSERIAL       NOT NULL,
                vm_id               VARCHAR         NOT NULL,
                segment             VARCHAR         NOT NULL,
                node_id             VARCHAR         NOT NULL,
                exists              BOOL            NOT NULL,                   -- false if vm removed from yp
                last_touch_ts       INTEGER         NOT NULL,                   -- last time we tried to hotbackup this vm
                update_ts           INTEGER         NOT NULL,
                hot_period          INTEGER         NOT NULL,
                hot_allow           BOOL            NOT NULL
            );

            CREATE TABLE revision_user (
                vm_id               VARCHAR         NOT NULL,
                rev_id              INTEGER         NOT NULL,
                user_id             VARCHAR         NOT NULL,

                PRIMARY KEY (vm_id, rev_id, user_id),
                FOREIGN KEY (vm_id, rev_id) REFERENCES storage_revision
            );

            CREATE TABLE revision_group (
                vm_id               VARCHAR         NOT NULL,
                rev_id              INTEGER         NOT NULL,
                group_id            VARCHAR         NOT NULL,

                PRIMARY KEY (vm_id, rev_id, group_id),
                FOREIGN KEY (vm_id, rev_id) REFERENCES storage_revision
            );

            UPDATE dbmaintain SET value_int = %s WHERE key = %s;

            COMMIT;
        '''), (1, 'schema_version'), fix_sql=False)  # noqa

    def _migrate_1_to_0(self):
        """Drop everything _migrate_0_to_1 created and set ``schema_version`` back to 0.

        Objects are dropped in reverse dependency order (referencing tables before the
        tables they reference, tables before the enum types they use).
        """
        self.execute(textwrap.dedent('''
            BEGIN;

            DROP TABLE revision_group;
            DROP TABLE revision_user;
            DROP TABLE vm;
            DROP TABLE stat_worker;
            DROP TABLE evoq_log;
            DROP TABLE evoq;
            DROP TYPE evoq_state_t;
            DROP TABLE job;
            DROP TABLE log;
            DROP TYPE log_severity_t;
            DROP TABLE audit;
            DROP TYPE audit_severity_t;
            DROP TABLE storage_data;
            DROP TABLE storage_block_removed;
            DROP TABLE storage_block;
            DROP TYPE mds_storage_t;
            DROP TYPE hashtype_t;
            DROP TABLE session;
            DROP TABLE storage_revision;
            DROP TYPE storage_revision_state_t;
            DROP TYPE origin_t;
            DROP TYPE session_state_t;
            DROP TYPE session_type_t;

            UPDATE dbmaintain SET value_int = %s WHERE key = %s;

            COMMIT;
        '''), (0, 'schema_version'), fix_sql=False)
    # Migrations }}}
