#!/usr/bin/env python
{% from "components/postgres/pg.jinja" import pg with context %}

import psycopg2
import sys
import os
import json
import datetime
import argparse
import subprocess

# Command-line thresholds.  The same two numbers are compared against the
# count of WAL files not yet archived and against the number of seconds
# between the archiver's last failure and its last successful archive.
parser = argparse.ArgumentParser()
parser.add_argument('-w', '--warn', type=int, default=300, help='Warning limit')
parser.add_argument('-c', '--crit', type=int, default=900, help='Critical limit')
args = parser.parse_args()

def die(code=0, comment="OK"):
    """Print a ``<code>;<comment>`` status line and exit the process.

    The process exit status is always 0; the severity is carried only by
    the printed ``code``.  When ``code`` is 0 the ``comment`` is ignored
    and ``0;OK`` is printed.

    Raises SystemExit (via sys.exit), so nothing after a call runs.
    """
    # Function-call form of print keeps the script valid on Python 2 and 3.
    if code == 0:
        print('0;OK')
    else:
        print('%d;%s' % (code, comment))
    sys.exit(0)


def count_files(path):
    """Return how many entries ``ls -1`` prints for *path*, run as postgres.

    The listing goes through the shell via ``sudo -u postgres`` and its
    stderr is discarded, so an unreadable or missing path yields 0.
    """
    command = "sudo -u postgres ls -1 %s 2>/dev/null | wc -l" % path
    listing = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE)
    output = listing.communicate()[0]
    return int(output.rstrip())


def count_xlogs(cur):
    wals_path = "{{ salt['pillar.get']('data:backup:archive:walsdir', pg.data + '/wals') }}",
    xlog_path = "{{ pg.wal_dir_path }}"
{% if salt['pillar.get']('data:use_wale', False) or salt['pillar.get']('data:use_walg', True) %}
    cur.execute("select setting::int from pg_settings where name = 'server_version_num'")
    (version, ) = cur.fetchone()
    max_wals_query = None
    if version < 90500:
        # See https://www.postgresql.org/docs/9.4/static/wal-configuration.html
        max_wals_query = """
            select greatest(
                (2 + current_setting('checkpoint_completion_target')::float) * current_setting('checkpoint_segments')::int + 1,
                current_setting('checkpoint_segments')::int + current_setting('wal_keep_segments')::int + 1
            )
        """
    else:
        max_wals_query = "select setting::int from pg_settings where name = 'max_wal_size'"
    cur.execute(max_wals_query)
    (max_wals, ) = cur.fetchone()
    xlogs = max(0, count_files(xlog_path) - max_wals)
{% else %}
    xlogs = 0
{% endif %}
    wals = count_files(wals_path)
    return xlogs + wals

conn = None
cur = None
try:
    conn = psycopg2.connect('dbname=postgres user=monitor connect_timeout=1 host=localhost')
    cur = conn.cursor()

    # Check 1: too many WAL files waiting to be archived.
    wals_count = count_xlogs(cur)
    if wals_count >= args.crit:
        die(2, '%d not archived WALs.' % wals_count)
    elif wals_count >= args.warn:
        die(1, '%d not archived WALs.' % wals_count)

    # Read-only servers (presumably standbys) skip the archiver check.
    cur.execute("show transaction_read_only;")
    if 'on' in str(cur.fetchone()[0]):
        die(0, "OK")

    # Check 2: archiver stuck -- the last successful archive time has not
    # moved since the previous run while failures keep happening after it.
    prev_path = '/tmp/pg_xlog.prev'
    prev_last = 0
    try:
        with open(prev_path, 'r') as f:
            prev_last = int(json.load(f)['last'])
    except (IOError, OSError, ValueError, KeyError, TypeError):
        # Missing or unreadable state: treat as "nothing recorded yet".
        prev_last = 0

    # Epoch seconds are computed server-side; NULL (never archived / never
    # failed) maps to 0, i.e. 1970-01-01.
    cur.execute("select last_archived_time, "
                "coalesce(floor(extract(epoch from last_archived_time)), 0)::bigint, "
                "coalesce(floor(extract(epoch from last_failed_time)), 0)::bigint "
                "from pg_stat_archiver;")
    last_time, last_epoch, fail_epoch = cur.fetchone()
    if last_time is None:
        last_time = datetime.datetime(1970, 1, 1, 0)
    current = {'last': int(last_epoch), 'fail': int(fail_epoch)}

    # Persist state *before* reporting: die() calls sys.exit, so anything
    # placed after the verdict below would never run.
    try:
        with open(prev_path, 'w') as f:
            json.dump(current, f)
    except (IOError, OSError):
        pass  # best effort; this run's verdict does not depend on it

    if current['last'] == prev_last:
        stuck_at = last_time.strftime("%Y-%m-%d %H:%M:%S")
        if current['fail'] > current['last'] + args.crit:
            die(2, "Archiver stuck at " + stuck_at)
        elif current['fail'] > current['last'] + args.warn:
            die(1, "Archiver stuck at " + stuck_at)
    die(0, "Archiver ok")
except Exception:
    die(1, "Could not get info about not archived xlogs")
finally:
    # Close whatever was opened; each handle independently.
    for handle in (cur, conn):
        if handle is not None:
            try:
                handle.close()
            except Exception:
                pass
