#!/usr/bin/env python3
"""
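Disk usage watcher (PostgreSQL only for now).
Closes the master when disk usage is too high: above the soft limit user
transactions become read-only by default; above the hard limit the host is
marked closed so pgaas-proxy stops routing traffic to it.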
"""
import shutil
import logging
import os
import sys
import argparse

import psycopg2
from psycopg2.extensions import AsIs

# Default path of the RO flag file. Its presence means the soft limit has
# already been enforced (user sessions killed), so later runs don't kill again.
RO_FLAG = '/tmp/.pg_ro'
# Default limits, in percent of disk space used (compared against `usage`).
SOFT_LIMIT_DEFAULT = 97.0
HARD_LIMIT_DEFAULT = 99.0
# Default for --reserved: 104857600 bytes = 100 MiB.
# NOTE(review): the parsed --reserved value is not used anywhere in this file.
RESERVED_DEFAULT = 104857600


class ReplicaException(Exception):
    """
    Signals that the local server is a replica (pg_is_in_recovery() is not
    false), so there is nothing for this script to enforce on it.
    """


# Every way PostgreSQL lets a boolean parameter say "false": off/false/no/0,
# case-insensitively, plus unambiguous prefixes ('o' alone is ambiguous
# between 'on' and 'off', so it is not included).
_PG_FALSE_SPELLINGS = frozenset((
    'off', 'of',
    'false', 'fals', 'fal', 'fa', 'f',
    'no', 'n',
    '0',
))


def pg_bool(value):
    """
    Interpret a PostgreSQL setting string as a boolean.

    :param value: setting text as returned by SHOW.
    :returns: False for any PostgreSQL spelling of false (after stripping
              surrounding whitespace and lower-casing), True for anything
              else -- including values PostgreSQL itself would reject.
    """
    return value.strip().lower() not in _PG_FALSE_SPELLINGS


def pg_setting(cur, name, as_bool=False):
    """
    Retrieve current value of setting `name` using SHOW.

    :param cur: database cursor to run the query on.
    :param name: setting name. It is interpolated verbatim (inside double
                 quotes) into the statement, so it must be a trusted constant.
    :param as_bool: if set, convert the value with `pg_bool`.
    :returns: the setting text (or a bool when `as_bool`), or None if the
              query returned no row.
    """
    cur.execute('SHOW "%s"', (AsIs(name), ))
    row = cur.fetchone()
    if row is None:
        return None
    (setting, ) = row
    if as_bool:
        return pg_bool(setting)
    return setting


def pg_database_size(cur):
    """
    Add up the sizes, in bytes, of all user databases.

    The `postgres` database and template databases are left out. The result
    is the value of SQL sum() as returned by the driver, i.e. None when no
    database qualifies.
    """
    query = """
        SELECT sum(pg_database_size(datname))
        FROM pg_database
        WHERE datname != 'postgres'
        AND datistemplate = false;
    """
    cur.execute(query)
    row = cur.fetchone()
    return row[0]


def kill_sessions(cur):
    """
    Terminate all user sessions.

    Every backend in pg_stat_activity whose role is neither `postgres` nor
    `monitor` is terminated, except the session running this statement:
    the watcher must not kill its own connection when it runs under some
    other role. Rows with a NULL usename do not match the NOT IN filter and
    are left alone.
    """
    cur.execute("""
        SELECT pg_terminate_backend(pid)
        FROM pg_stat_activity
        WHERE usename NOT IN ('postgres', 'monitor')
        AND pid <> pg_backend_pid()
    """)


def _quote_ident(name):
    """
    Return `name` as a double-quoted SQL identifier (embedded quotes doubled).
    """
    return '"%s"' % name.replace('"', '""')


def soft_close(cur, closed=True):
    """
    Set all transactions to be RO by default in all non-system dbs
    (closed=True), or reset that default (closed=False).

    Databases come from pg_database, skipping `postgres`, templates and
    databases that refuse connections (they cannot be inspected, and one
    failing connect would abort the whole enforcement pass). The current
    default_transaction_read_only value is read through a short-lived
    connection to each database, which is always closed again. The
    ALTER DATABASE statements run on `cur`.

    :raises ReplicaException: if a per-database connection reaches a replica.
    """
    cur.execute("""
        SELECT datname from pg_database
        WHERE datname != 'postgres'
        AND datistemplate = false
        AND datallowconn
    """)
    databases = [row[0] for row in cur.fetchall()]
    for database in databases:
        db_cur = acquire_master_cursor(database)
        try:
            currently_closed = pg_setting(
                db_cur,
                'default_transaction_read_only',
                as_bool=True, )
        finally:
            # One connection per database per run would otherwise leak.
            db_cur.connection.close()
        if currently_closed is None or currently_closed == closed:
            continue
        # Always quote: unquoted names are case-folded and break on spaces,
        # dots, reserved words, etc. (not just on '-').
        escaped = _quote_ident(database)
        if closed:
            # We need to execute close command
            cur.execute(
                'ALTER DATABASE %s SET default_transaction_read_only TO %s',
                (AsIs(escaped), closed, ), )
        else:
            # We need to execute open command
            cur.execute(
                'ALTER DATABASE %s RESET default_transaction_read_only',
                (AsIs(escaped), ), )


def hard_close(cur, closed=True):
    """
    Set dbaas.closed to true (closed=True) or reset it (closed=False).

    Setting it will trigger an exception in dbaas_pg_poll() and
    pgaas-proxy will mark this host as dead.

    :returns: True when nothing had to be done (the setting already has the
              requested state, or could not be read), False after the
              setting was changed and the configuration reloaded.
    """
    currently_closed = pg_setting(cur, 'dbaas.closed', as_bool=True)
    if currently_closed is None or currently_closed == closed:
        return True
    if closed:
        cur.execute('ALTER SYSTEM SET dbaas.closed TO %s', (closed, ))
    else:
        cur.execute('ALTER SYSTEM RESET dbaas.closed')
    # Reload the configuration so the ALTER SYSTEM change is picked up.
    cur.execute('SELECT pg_reload_conf()')
    return False


def check_closed(cur):
    """
    Report whether `dbaas.closed` is currently on.

    :returns: True if the setting reads as true; False if it reads as false
              or could not be read.
    """
    state = pg_setting(cur, 'dbaas.closed', as_bool=True)
    return True if state else False


def usage(opts):
    """
    Returns used disk space in percents for the filesystem holding `opts.path`.

    Computed the way `df` does: used / (used + free). `free` is as reported by
    shutil.disk_usage, which on POSIX is the space available to unprivileged
    users. Blocks reserved for root therefore count as unusable. The PostgreSQL
    server does not run as root, so it is out of space once this reaches 100%.
    Dividing by `total` instead would cap the value at (100 - reserved)%, so
    high limits like 97/99 would never trigger on filesystems with a root reserve.
    """
    disk_usage = shutil.disk_usage(opts.path)
    usable = disk_usage.used + disk_usage.free
    if usable == 0:
        # Nothing at all is available to non-root writers: treat as full.
        return 100.0
    return disk_usage.used / usable * 100


def acquire_master_cursor(database='postgres'):
    """
    Connect to `database` and return an autocommit cursor, or throw
    ReplicaException if localhost is a replica.

    Only the database name is passed to psycopg2.connect(); every other
    connection parameter is left to libpq's defaults. The name is passed as a
    keyword argument so psycopg2 quotes it, which keeps names with spaces or
    quotes working. The connection is closed on every failure path.

    :raises ReplicaException: if pg_is_in_recovery() is not false.
    """

    def is_master(conn):
        """
        Return True only when pg_is_in_recovery() is false (NULL counts as
        not master).
        """
        cur = conn.cursor()
        try:
            cur.execute('SELECT pg_is_in_recovery()')
            (is_replica, ) = cur.fetchone()
        finally:
            cur.close()
        return is_replica is not None and not is_replica

    conn = psycopg2.connect(dbname=database)
    try:
        conn.autocommit = True
        master = is_master(conn)
    except Exception:
        # Don't leak the connection if the check itself fails.
        conn.close()
        raise
    if master:
        return conn.cursor()
    # Replica: release the connection before bailing out.
    conn.close()
    raise ReplicaException('Cannot operate on replica')


def set_ro_flag(path):
    """
    Set flag by creating an empty file (an existing file is truncated).

    The default flag location is in world-writable /tmp, so the file is
    opened with O_NOFOLLOW where the platform has it. If `path` is a symlink,
    this raises OSError instead of truncating whatever the link points to.
    """
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    flags |= getattr(os, 'O_NOFOLLOW', 0)
    # 0o666 (minus umask) matches the permissions open(path, 'w') would use.
    fd = os.open(path, flags, 0o666)
    os.close(fd)


def check_ro_flag(path):
    """
    Report whether the RO flag file is present.

    :param path: location of the flag file.
    :returns: True when `path` exists, False otherwise.
    """
    flag_present = os.path.exists(path)
    return flag_present


def unset_ro_flag(path):
    """
    "Unset" ro flag-file removing it.

    A missing file already means "unset", so that case is ignored. Any other
    error (permissions, path is a directory, ...) is propagated. Swallowing it
    would leave a stale flag, and the next time the soft limit is reached the
    user sessions would silently not be killed.
    """
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def enforce_space_usage(log, opts):
    """
    Run one enforcement pass against the local master.

    - Usage below `opts.soft`: lift every restriction (reset the read-only
      default, reset `dbaas.closed`, remove the RO flag file).
    - Usage in [`opts.soft`, `opts.hard`): make user transactions read-only
      by default. The first time (no RO flag file yet), also kill user
      sessions and create the flag file so later runs don't kill them again.
      Users can still override the read-only default.
    - Usage at or above `opts.hard`: set `dbaas.closed` so pgaas-proxy polls
      fail and requests go to replicas. User sessions are killed if the host
      was not already closed.

    :raises ReplicaException: if the local server is a replica.
    """
    master_cur = acquire_master_cursor()
    used_pct = usage(opts)
    flag_set = check_ro_flag(opts.flag)
    was_closed = check_closed(master_cur)
    log.debug('Used: %.3f, soft closed: %s, hard closed: %s', used_pct,
              flag_set, was_closed)

    if used_pct < opts.soft:
        # Normal situation: make sure no limit is enforced.
        soft_close(master_cur, False)
        hard_close(master_cur, False)
        unset_ro_flag(opts.flag)
        return

    if opts.soft <= used_pct < opts.hard:
        # Soft limit. The flag file keeps us from killing user queries on
        # every invocation.
        # NOTE(review): dbaas.closed is not reset here; it is only reset
        # once usage drops below the soft limit.
        soft_close(master_cur, True)
        if flag_set:
            return
        log.info('enforcing soft limit on %.3f%%', used_pct)
        kill_sessions(master_cur)
        set_ro_flag(opts.flag)
        return

    if used_pct >= opts.hard:
        # Hard limit: close the host; kill sessions only on the transition.
        hard_close(master_cur, True)
        if not was_closed:
            log.info('enforcing hard limit on %.3f%%', used_pct)
            kill_sessions(master_cur)


def parse_args(argv=None):
    """
    Process cmdline arguments.

    :param argv: argument list to parse; None (the default) lets argparse
                 read sys.argv[1:].
    :returns: argparse.Namespace with `reserved`, `soft`, `hard`, `path`
              and `flag` attributes.
    :raises ValueError: if the soft limit is greater than the hard limit.
    """
    arg = argparse.ArgumentParser(
        description="""
        Disk usage enforcing script
        """, )

    # NOTE(review): `reserved` is parsed but not used anywhere in this file.
    arg.add_argument(
        '-r',
        '--reserved',
        type=int,
        default=RESERVED_DEFAULT,
        help='Rootfs reserved size (in bytes)', )
    # Limits are percentages compared against a float usage value and default
    # to floats, so accept fractional values such as 97.5 (type=int did not).
    arg.add_argument(
        '-s',
        '--soft',
        type=float,
        default=SOFT_LIMIT_DEFAULT,
        help='Soft limit. RO mode is activated and can be overridden by user',
    )
    arg.add_argument(
        '-d',
        '--hard',
        type=float,
        default=HARD_LIMIT_DEFAULT,
        help='Hard limit. Host becomes unavailable for read and write access',
    )
    arg.add_argument(
        '-p',
        '--path',
        type=str,
        default='/',
        help='Pgdata root path', )
    arg.add_argument(
        '-f',
        '--flag',
        type=str,
        default=RO_FLAG,
        help='RO flag file path', )
    parsed = arg.parse_args(argv)
    if parsed.soft > parsed.hard:
        raise ValueError('soft limit cannot be greater than hard')
    return parsed


def main():
    """
    Script entry point: parse arguments and run one enforcement pass.

    Exits with status 0 when the local server is a replica, and with status
    1 (after logging the traceback) on any other error.
    """
    logging.basicConfig(level=logging.DEBUG, format='%(message)s')
    logger = logging.getLogger('disk_usage_limiter')
    try:
        enforce_space_usage(logger, parse_args())
    except ReplicaException:
        # Replicas are left alone; this is not an error.
        sys.exit(0)
    except Exception as err:  # top-level boundary: log anything unexpected
        logger.exception('error: %s', err)
        sys.exit(1)


# Run one enforcement pass when executed as a script.
if __name__ == '__main__':
    main()
