#!/usr/bin/env python
"""
Check whether the local postgresql has been dead for long enough and, if so, set it up again from scratch (resetup)
"""

import argparse
import json
import logging
import os
import subprocess
import time
import traceback
from datetime import datetime

import psycopg2

from kazoo.client import KazooClient

# NOTE: several values below are Jinja expressions reading Salt pillar data;
# presumably the file is rendered by Salt before it is installed, so at
# runtime they are plain strings.

# When this file exists main() exits without checking or resetting anything.
DISABLE_FLAG = "/tmp/.pg_resetup_disable.flag"
# Fallback (seconds) used by get_checkpoint_timeout() for unrecognized values.
DEFAULT_CHECKPOINT_TIMEOUT = 300
# Holds the last seen "latest checkpoint" string; the file mtime is used as
# the moment that checkpoint was first observed.
RECENT_CHECKPOINT_TIME_FILE = '/tmp/.recent_checkpoint'
# Rendered as a string; converted with int() where it is used.
RECOVERY_TIMEOUT = "{{ salt['pillar.get']('data:pgsync:recovery_timeout', '1200') }}"
# Presence of this file makes is_pg_dead() report True; resetup() removes it.
REWIND_FAIL_FLAG = '/tmp/.pgsync_rewind_fail.flag'
# Passed as the duration argument of the `timeout` command wrapping salt-call.
SALT_TIMEOUT = "86370"
# Touched by resetup() when a WAL directory is a mountpoint, before unmounting.
SEPARATE_WAL_FLAG = "/tmp/.pg_resetup_separate_wal.flag"
# Holds the unix timestamp of the last successful check (or resetup start).
STATUS_FILE = '/tmp/.pg_resetup.alive'
# Explicit replication source; the literal string 'None' means "not set" and
# makes resetup() look the master up in zookeeper instead.
STREAM_FROM = "{{ salt['pillar.get']('data:pgsync:replication_source', None) }}"
# Connection string handed to KazooClient(hosts=...).
ZK_HOSTS = "{{ salt['pillar.get']('data:pgsync:zk_hosts', '') }}"
# Path of the zookeeper lock whose first contender is treated as the master.
ZK_LEADER_PATH = "{{ salt['pillar.get']('data:pgsync:zk_lockpath_prefix', '/pgsync/' + salt['pillar.get']('data:dbaas:cluster_id', salt['grains.get']('id').split('.')[0][:-1])) + '/leader' }}"


# Configure the root logger once at import time: everything from DEBUG up is
# emitted with a timestamp and a padded level name.
logging.basicConfig(
    level=logging.DEBUG, format='%(asctime)s %(levelname)-8s: %(message)s'
)
# Root logger shared by all functions below (juggler_check() disables it).
logger = logging.getLogger()


def log_exception(message):
    """
    Log message at ERROR level together with the current traceback

    Newlines of the traceback are replaced with a literal backslash-n
    so the whole record stays on one line.
    """
    trace = traceback.format_exc()
    one_line_trace = trace.replace('\n', '\\n')
    logger.error('%s: %s', message, one_line_trace)


def get_pg_version():
    """
    Get version of the first cluster reported by pg_lsclusters

    Returns the first whitespace-separated field of the first output line
    that is neither blank nor the header (a line starting with 'Ver'),
    or None when there is no such line.

    Raises subprocess.CalledProcessError if pg_lsclusters exits with
    non-zero code and OSError if it can not be started.
    """
    clusters = subprocess.check_output(['pg_lsclusters'], stderr=subprocess.STDOUT)
    # check_output returns bytes on python 3; matching them against a str
    # prefix would raise TypeError, so normalize to text first.
    if not isinstance(clusters, str):
        clusters = clusters.decode('utf-8', 'replace')
    for line in clusters.splitlines():
        fields = line.split()
        # Blank lines have no fields (indexing them would raise IndexError).
        if not fields or line.startswith('Ver'):
            continue
        return fields[0]
    return None


def get_checkpoint_timeout(args):
    """
    Convert args.checkpoint_timeout to a number of seconds

    Accepted values:
      * int - returned as is
      * string of digits - seconds (no unit given)
      * digits followed by one of ms, s, min, h, d
    Surrounding whitespace is ignored. Empty, missing or unparsable
    values give DEFAULT_CHECKPOINT_TIMEOUT.
    """
    checkpoint_timeout = args.checkpoint_timeout
    if isinstance(checkpoint_timeout, int):
        return checkpoint_timeout
    if not checkpoint_timeout:
        return DEFAULT_CHECKPOINT_TIMEOUT

    value = checkpoint_timeout.strip()
    try:
        # A bare number means seconds; previously it was silently
        # replaced with the default.
        if value.isdigit():
            return int(value)
        # 'ms' has to be tested before 's', otherwise '500ms' would be
        # parsed as '500m' seconds.
        if value.endswith('ms'):
            return int(value[:-2]) / 1000
        for suffix, seconds in (('s', 1), ('min', 60), ('h', 60 * 60), ('d', 60 * 60 * 24)):
            if value.endswith(suffix):
                return int(value[:-len(suffix)]) * seconds
    except ValueError:
        # Numeric part is not an integer - treat like an unknown unit.
        pass
    return DEFAULT_CHECKPOINT_TIMEOUT


def get_latest_checkpoint(args):
    """
    Get time of latest checkpoint from pg_controldata

    Runs pg_controldata of the args.version installation against
    /var/lib/postgresql/<version>/data and reformats the value of the
    'Time of latest checkpoint' line.

    Returns the reformatted string or None if there is no such line.
    Raises subprocess.CalledProcessError when pg_controldata fails and
    ValueError when the timestamp does not match the expected format.
    """
    controldata = os.path.join('/usr/lib/postgresql', args.version, 'bin/pg_controldata')
    data_dir = os.path.join('/var/lib/postgresql', args.version, 'data')
    output = subprocess.check_output([controldata, '-D', data_dir], stderr=subprocess.STDOUT)
    # check_output returns bytes on python 3 and str() of bytes is their
    # repr ("b'...'") which never contains real line breaks.
    if not isinstance(output, str):
        output = output.decode('utf-8', 'replace')
    for line in output.splitlines():
        if line.startswith('Time of latest checkpoint'):
            # 26 == len('Time of latest checkpoint:')
            dt = datetime.strptime(line[26:].strip(), '%a %d %b %Y %I:%M:%S %p %Z')
            # NOTE(review): '%H:%M%S' has no colon before seconds. Kept as
            # is because the result is only compared with the value stored
            # in the recent checkpoint file by previous runs.
            return dt.strftime('%Y-%m-%d %H:%M%S')
    return None


def read_recent_checkpoint_mtime(current_latest_checkpoint):
    """
    Get the time when current latest checkpoint was recorded

    Compares the first line of RECENT_CHECKPOINT_TIME_FILE with
    current_latest_checkpoint. Returns mtime of the file if they are
    equal and None otherwise.

    The file must exist, otherwise open() raises.
    """
    # Context manager closes the descriptor right away (it used to be
    # left for the garbage collector).
    with open(RECENT_CHECKPOINT_TIME_FILE, 'r') as recent_file:
        recent_latest_checkpoint = recent_file.readline()
    logger.info('Recent checkpoint: ' + recent_latest_checkpoint)
    logger.info('Latest checkpoint: ' + current_latest_checkpoint)
    if recent_latest_checkpoint == current_latest_checkpoint:
        return os.path.getmtime(RECENT_CHECKPOINT_TIME_FILE)
    return None


def is_waiting_consistency(args):
    """
    Check if postgresql should still be given time to make a checkpoint

    Returns:
      * False if latest checkpoint can not be obtained
      * for a checkpoint equal to the recorded one - True while less than
        3 x checkpoint_timeout passed since it was recorded
      * True for a new checkpoint (it is recorded in
        RECENT_CHECKPOINT_TIME_FILE)
    """
    try:
        current_latest_checkpoint = get_latest_checkpoint(args)
    except Exception:
        # False here lets is_pg_dead() report a dead postgresql, so leave
        # a trace of the reason instead of swallowing it silently.
        log_exception('Unable to get latest checkpoint')
        return False

    if os.path.exists(RECENT_CHECKPOINT_TIME_FILE):
        checkpoint_timeout = get_checkpoint_timeout(args)
        read_recent_checkpoint_time = read_recent_checkpoint_mtime(current_latest_checkpoint)
        if read_recent_checkpoint_time is not None:
            return time.time() - read_recent_checkpoint_time < checkpoint_timeout * 3

    # Checkpoint differs from the recorded one (or nothing is recorded):
    # rewriting the file restarts the countdown via its mtime.
    logger.info('Update recent_checkpoint file')
    with open(RECENT_CHECKPOINT_TIME_FILE, 'w') as f:
        f.write(current_latest_checkpoint)
    return True


def _check_pg_status():
    """
    Check if local postgresql answers a trivial query

    Connects to database postgres on localhost:5432 as user admin
    (10 seconds connect timeout) and runs 'SELECT 1'. On success current
    unix time is written to STATUS_FILE.

    Returns True on success and False otherwise. Errors are logged,
    not raised.
    """
    conn = None
    try:
        conn = psycopg2.connect("host=localhost port=5432 dbname=postgres user=admin connect_timeout=10 sslmode=allow")
        cursor = conn.cursor()
        cursor.execute('SELECT 1')
        if cursor.fetchone()[0] == 1:
            with open(STATUS_FILE, 'w') as status_file:
                status_file.write(str(int(time.time())))
            return True
    except Exception:
        log_exception('PostgreSQL is dead')
    finally:
        # The connection used to be left open until interpreter exit.
        if conn is not None:
            try:
                conn.close()
            except Exception:
                # Best effort: failure to close must not change the result.
                pass
    return False


def read_status_file():
    """
    Read unix timestamp stored in STATUS_FILE

    Returns the timestamp as int or None if the file is missing or its
    first line is not an integer (the error is logged).
    """
    if not os.path.exists(STATUS_FILE):
        # We are in initial setup phase
        logger.info('No alive pg in recorded history')
        return None
    try:
        # Context manager closes the descriptor right away (it used to be
        # left for the garbage collector).
        with open(STATUS_FILE, 'r') as status_file:
            return int(status_file.readline())
    except Exception:
        log_exception('Malformed status file')
        return None


def check_pg_status():
    """
    Check if postgresql may still be considered alive

    True is returned when the live query succeeds, when there is no
    usable status file, or when the timestamp stored in the status file
    is not older than RECOVERY_TIMEOUT x 2 seconds. Otherwise False.
    """
    timeout = int(RECOVERY_TIMEOUT) * 2
    answers_query = _check_pg_status()
    if answers_query:
        return True

    last_alive = read_status_file()
    if last_alive is None:
        return True

    dead_for = time.time() - last_alive
    if dead_for <= timeout:
        logger.info('Pg dead timeout not exceeded (%d), %s remains' % (timeout, timeout - dead_for))
        return True

    logger.info('Pg dead timeout exceeded')
    return False


def get_master():
    """
    Get master from zookeeper

    Returns the first contender of the lock at ZK_LEADER_PATH or None
    when nobody contends for the lock or zookeeper can not be queried
    (the error is logged).
    """
    zk_conn = None
    try:
        zk_conn = KazooClient(hosts=ZK_HOSTS)
        zk_conn.start()
        contenders = zk_conn.Lock(ZK_LEADER_PATH).contenders()
        if contenders:
            return contenders[0]
    except Exception:
        log_exception('Unable to get master from zk')
    finally:
        # Stop the client on every path; it used to stay running when
        # start() or contenders() raised.
        if zk_conn is not None:
            try:
                zk_conn.stop()
                zk_conn.close()
            except Exception:
                log_exception('Unable to close zk connection')
    return None


def resetup(args):
    """
    Drop local postgresql and run highstate with master arg

    Steps:
      * pick the master: STREAM_FROM if set, else the zookeeper leader
      * write current time to STATUS_FILE
      * remove rewind fail flag, stop pgsync, pgbouncer and postgresql
      * unmount pg_xlog/pg_wal if it is a mountpoint (SEPARATE_WAL_FLAG
        is touched in this case)
      * delete everything inside the data directory and remove
        /etc/postgresql/<version>/data
      * run salt-call state.highstate with pg-master and walg-restore
        pillar values

    Raises RuntimeError if master is unknown (nothing is executed in
    this case) and re-raises the error of the first failed command.
    """

    cmds = [
        'rm -f ' + REWIND_FAIL_FLAG,
        'service pgsync stop',
        'service pgbouncer stop',
        'service postgresql stop',
    ]

    pg_data = os.path.join('/var/lib/postgresql', args.version, 'data')

    for wals_dir in ['pg_xlog', 'pg_wal']:
        wals_path = os.path.join(pg_data, wals_dir)
        if os.path.exists(wals_path):
            # '|| true' keeps the command successful when the directory
            # is not a mountpoint.
            cmd = 'mountpoint -q %(path)s && touch %(flag)s && umount %(path)s || true' % {
                'path': wals_path,
                'flag': SEPARATE_WAL_FLAG
            }
            cmds.append(cmd)

    cmds.append('find %s -mindepth 1 -type f -delete' % pg_data)
    cmds.append('find %s -mindepth 1 -type d -delete' % pg_data)
    cmds.append('rm -rf ' + os.path.join('/etc/postgresql', args.version, 'data'))

    if STREAM_FROM != 'None':
        master = STREAM_FROM
    else:
        master = get_master()
        if not master:
            raise RuntimeError('No one holds the leader lock')

    # NOTE(review): presumably refreshed so that following runs do not
    # consider the timeout exceeded while resetup is in progress - confirm.
    with open(STATUS_FILE, 'w') as w:
        w.write(str(int(time.time())))

    for cmd in cmds:
        logger.info('Executing ' + cmd)
        try:
            subprocess.check_output(cmd, stderr=subprocess.STDOUT, preexec_fn=os.setsid, shell=True)
        except subprocess.CalledProcessError as exc:
            # Output of the failed command is carried by the exception.
            # The old handler read a local variable that was unbound when
            # the first command failed (hiding the real error) and stale
            # for every later one.
            output = exc.output or ''
            if not isinstance(output, str):
                output = output.decode('utf-8', 'replace')
            log_exception('%s failed with: %s' % (cmd, output.replace('\n', '\\n')))
            raise
        except Exception:
            # Command could not be started at all - there is no output.
            log_exception('%s failed' % cmd)
            raise

    hs_cmd = [
        'timeout', SALT_TIMEOUT, 'salt-call', 'state.highstate', 'queue=True',
        'pillar={value}'.format(value=json.dumps({'pg-master': master, 'walg-restore': True})),
    ]
    logger.info(hs_cmd)
    subprocess.check_output(hs_cmd, stderr=subprocess.STDOUT, preexec_fn=os.setsid)
    logger.info('Done')


def check_rewind_flag():
    """
    Return True (and log it) if rewind fail flag exists, None otherwise
    """
    flag_missing = not os.path.exists(REWIND_FAIL_FLAG)
    if flag_missing:
        return None
    logger.info('Rewind fail flag found')
    return True


def check_disable_flag():
    """
    Return True (and log it) if resetup disable flag exists, None otherwise
    """
    flag_missing = not os.path.exists(DISABLE_FLAG)
    if flag_missing:
        return None
    logger.info('Resetup disable flag found')
    return True


def is_pg_dead(args):
    """
    Decide if postgresql should be resetuped

    True if rewind fail flag exists. Otherwise True only when
    check_pg_status() reports postgresql as not alive and
    is_waiting_consistency() says there is nothing to wait for.
    """
    if check_rewind_flag():
        return True

    # is_waiting_consistency() is evaluated only if postgresql is not alive
    return not (check_pg_status() or is_waiting_consistency(args))


def juggler_output(status=True, message='OK'):
    """
    Print check result as '<code>;<message>'

    Code is 0 for truthy status and 2 otherwise.
    """
    code = 0 if status else 2
    print('%s;%s' % (code, message))


def juggler_check(args):
    """
    Print postgresql state in juggler format (see juggler_output)

    Logging is disabled so that only the result line is printed.
    Failure (code 2) is reported when:
      * rewind fail flag exists
      * postgresql does not answer and status file has a timestamp
      * latest checkpoint can not be obtained
      * latest checkpoint was recorded more than checkpoint_timeout
        seconds ago and has not changed since
    Everything else is reported as OK.
    """
    logger.disabled = True
    if check_rewind_flag():
        return juggler_output(False, "Rewind flag found")

    if _check_pg_status():
        return juggler_output()

    timeout = int(RECOVERY_TIMEOUT) * 10
    last_status_time = read_status_file()
    if last_status_time is not None:
        return juggler_output(False, "Pg dead timeout, %s remains" % (timeout + last_status_time - time.time()))

    try:
        current_latest_checkpoint = get_latest_checkpoint(args)
    except Exception:
        return juggler_output(False, "No checkpoint")
    if current_latest_checkpoint is None:
        # pg_controldata worked but printed no checkpoint line; going on
        # would raise TypeError instead of printing a result.
        return juggler_output(False, "No checkpoint")

    if not os.path.exists(RECENT_CHECKPOINT_TIME_FILE):
        # Nothing recorded yet: same outcome as a checkpoint that differs
        # from the recorded one (reading the missing file would raise).
        return juggler_output()

    read_recent_checkpoint_time = read_recent_checkpoint_mtime(current_latest_checkpoint)
    checkpoint_timeout = get_checkpoint_timeout(args)
    if read_recent_checkpoint_time is not None:
        if time.time() - read_recent_checkpoint_time > checkpoint_timeout:
            return juggler_output(False, "PG is waiting consistency, %s remains" % (checkpoint_timeout * 2))
    return juggler_output()


def main():
    """
    Parse command line and run juggler check or resetup check

    Options:
      -c             checkpoint_timeout setting ('5min' by default)
      -m, --version  major pg version (detected with get_pg_version()
                     if not given)
      --juggler      only print state in juggler format
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-c',
        dest='checkpoint_timeout',
        type=str,
        default='5min',
        help='Setting checkpoint_timeout',
    )
    parser.add_argument(
        '-m', '--version',
        type=str,
        help='Major pg version',
    )
    parser.add_argument(
        '--juggler',
        action='store_true',
        help='Check for juggler',
    )
    args = parser.parse_args()
    args.version = args.version or get_pg_version()

    if args.juggler:
        return juggler_check(args)

    if check_disable_flag():
        return None

    if not is_pg_dead(args):
        logger.debug('All is fine')
        return None

    resetup(args)


# Run the check only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
