#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import, print_function, unicode_literals

import datetime
import fcntl
import json
import logging
import os
import re
import socket
import subprocess
import sys
import time
from ConfigParser import SafeConfigParser
from multiprocessing import Pool

import psycopg2

from barman.config import Config
from barman.server import Server


def get_list_of_servers(skipDbs):
    """Return the names of barman servers that are not listed in *skipDbs*.

    The server list is taken from ``/etc/barman.conf`` plus its
    configuration files directory.

    :param skipDbs: container of server names to leave out.
    :return: list of remaining server names, in barman's own order.
    """
    barman_config = Config('/etc/barman.conf')
    barman_config.load_configuration_files_directory()

    names = []
    for srv in barman_config.servers():
        names.append(srv.name)

    return [name for name in names if name not in skipDbs]


def get_last_backup(server):
    """Return the id of the most recent backup of *server* in state DONE.

    :param server: barman server name.
    :return: whatever ``Server.get_last_backup_id`` yields for the
        ``('DONE',)`` status filter.
    """
    barman_config = Config('/etc/barman.conf')
    barman_config.load_configuration_files_directory()
    server_config = barman_config.get_server(server)
    return Server(server_config).get_last_backup_id(status_filter=('DONE',))


def _run_logged(cmd):
    """Log *cmd* (an argv list) at debug level, run it, return its exit code."""
    logging.debug(' '.join(cmd))
    # No shell: every list item reaches the program as exactly one argument.
    return subprocess.call(cmd, stderr=sys.stderr, stdout=sys.stdout)


def _set_barman_option_enabled(server, option, enabled):
    """Uncomment (*enabled*) or comment out with ``;`` *option* in the
    server's file under ``/etc/barman.d``; return the ``sed`` exit code."""
    if enabled:
        script = 's/^;%s/%s/g' % (option, option)
    else:
        script = 's/^%s/;%s/g' % (option, option)
    return _run_logged(['sed', '-i', '/etc/barman.d/' + server + '.conf',
                        '-e', script])


def restore_pgdata(path_base, server, backup):
    """Recover *backup* of *server* into a fresh directory below *path_base*.

    Steps, in order:

    1. Refuse to run if *path_base* already holds an entry whose name
       contains ``recover_<server>`` (a previous attempt was left behind).
    2. Comment out ``incr_rsync_options`` and ``bandwidth_limit`` in the
       server's barman config for the duration of the recover.
    3. Create ``<path_base>/recover_<server>_<backup>``, hand it to the
       barman user and run ``barman recover`` with a target time one hour
       after the backup's end time.
    4. Uncomment both options again (always, even when a step failed).
    5. ``chmod 700`` the result and give it to ``postgres:postgres``.

    :param path_base: directory that receives the recovered data directory.
    :param server: barman server name.
    :param backup: barman backup id.
    :return: path of the recovered data directory, or ``None`` on any
        failure.
    """
    for f in os.listdir(path_base):
        if 'recover_' + server in f:
            logging.error('Prev attempt for %s found: %s' % (server, f))
            return None

    c = Config('/etc/barman.conf')
    c.load_configuration_files_directory()
    s = Server(c.get_server(server))
    b = s.get_backup(backup)

    res_path = os.path.join(path_base, 'recover_' + server + '_' + backup)
    toggled_options = ('incr_rsync_options', 'bandwidth_limit')

    try:
        for option in toggled_options:
            if _set_barman_option_enabled(server, option, False) != 0:
                return None

        # Each step's own exit code is checked (the directory must exist
        # and be writable for the barman user before recovering into it).
        if _run_logged(['mkdir', '-p', res_path]) != 0:
            return None
        if _run_logged(['chown', 'robot-pgbarman:dpt_virtual_robots',
                        res_path]) != 0:
            return None

        cmd = ['barman', 'recover']
        # One --tablespace flag per relocation rule.
        for tablespace in b.tablespaces or ():
            cmd += ['--tablespace',
                    tablespace.name + ':' +
                    os.path.join(res_path, tablespace.name)]
        target_time = b.end_time + datetime.timedelta(hours=1)
        cmd += [server, backup, res_path,
                '--target-time',
                target_time.strftime('%Y-%m-%d %H:%M:%S %z')]
        recover_status = _run_logged(cmd)
    finally:
        # Put the barman config back no matter how far we got, so a failed
        # recover never leaves the options commented out.
        for option in toggled_options:
            _set_barman_option_enabled(server, option, True)

    if recover_status != 0:
        return None

    if _run_logged(['chmod', '700', res_path]) != 0:
        return None
    if _run_logged(['chown', '-R', 'postgres:postgres', res_path]) != 0:
        return None

    return res_path


def hack_configs(path):
    """Rewrite the recovered cluster's config files so it can start locally.

    Runs a fixed sequence of in-place ``sed`` edits through the shell,
    stopping at the first one that exits non-zero:

    * ``<path>/conf.d/postgresql.conf``: empty ``shared_preload_libraries``,
      turn ``#fsync = on`` into ``fsync = off``, set ``shared_buffers`` to
      2GB, point ``stats_temp_directory`` at ``pg_stat_tmp``, switch
      ``logging_collector`` off, drop every line starting with ``ssl`` and
      repoint ``hba_file`` / ``data_directory`` at *path*.
    * ``<path>/recovery.conf``: replace ``cp `` with ``mv ``.
    * ``<path>/conf.d/pg_hba.conf``: widen the barman ``postgres`` host
      rule to ``all``.

    Finally ``<path>/postgresql.conf`` is overwritten with a single
    ``include`` of the conf.d file.

    :param path: recovered data directory.
    :return: 0 on success, otherwise the exit code of the failing ``sed``.
    """
    pg_conf = "sed -i %s/conf.d/postgresql.conf" % path
    recovery_conf = "sed -i %s/recovery.conf" % path
    hba_conf = "sed -i %s/conf.d/pg_hba.conf" % path

    # '/' is the sed delimiter, so it has to be escaped inside replacements.
    sed_safe_path = path.replace('/', r'\/')
    sed_safe_hba = sed_safe_path + r'\/conf.d\/pg_hba.conf'

    commands = (
        pg_conf + r""" -e "/^shared_preload_libraries/s/'.*'/''/" """,
        pg_conf + r""" -e "s/^#fsync\ =\ on/fsync\ =\ off/g" """,
        pg_conf + r""" -e "/^shared_buffers/s/=\ .*$/=\ 2GB/" """,
        pg_conf + (r""" -e "/^stats_temp_directory/s/\ =\ """
                   r""".*$/\ =\ 'pg_stat_tmp'/" """),
        pg_conf + (r""" -e 's/logging_collector\ =\ on/"""
                   r"""logging_collector\ =\ off/g'"""),
        pg_conf + r""" -e "/^ssl/d" """,
        pg_conf + r""" -e "/^hba_file/s/'.*'/'%s'/" """ % sed_safe_hba,
        pg_conf + r""" -e "/^data_directory/s/'.*'/'%s'/" """ % sed_safe_path,
        recovery_conf + r""" -e 's/cp\ /mv\ /g'""",
        hba_conf + (r""" -e 's/host\s*postgres\s*barman/"""
                    r"""host\ all\ barman/g' """),
    )

    for command in commands:
        logging.debug(command)
        exit_code = subprocess.call(command, shell=True,
                                    stderr=sys.stderr,
                                    stdout=sys.stdout)
        if exit_code != 0:
            return exit_code

    include_target = path + '/conf.d/postgresql.conf'
    with open('%s/postgresql.conf' % path, 'w') as top_level_conf:
        top_level_conf.write("include = '%s'" % include_target)

    return 0


def start_postgres(server, backup, path):
    """Start PostgreSQL on the recovered data directory *path*.

    ``pg_ctl`` is taken from ``/usr/pgsql-<version>/bin`` where the version
    comes from ``get_pg_version(server, backup)``, and is run as the
    ``postgres`` user through ``sudo``.

    :return: exit code of the ``pg_ctl start`` command.
    """
    pg_ctl = '/usr/pgsql-%s/bin/pg_ctl' % get_pg_version(server, backup)
    cmd = 'sudo -u postgres %s start -D %s' % (pg_ctl, path)
    logging.debug(cmd)
    return subprocess.call(cmd, shell=True,
                           stderr=sys.stderr,
                           stdout=sys.stdout)


def get_pg_version(server, backup):
    """Return the PostgreSQL major version string of *backup*.

    The backup's numeric ``version`` is decoded the way PostgreSQL encodes
    its version number: ``90603`` gives ``'9.6'``. From PostgreSQL 10 on the
    major version is a single number, so ``100005`` gives ``'10'`` (not
    ``'10.0'``); callers use the result to build ``/usr/pgsql-<version>``.

    :param server: barman server name.
    :param backup: barman backup id.
    :return: ``'X.Y'`` for versions below 10, ``'X'`` otherwise.
    """
    c = Config('/etc/barman.conf')
    c.load_configuration_files_directory()
    s = Server(c.get_server(server))
    full_version = s.get_backup(backup).version

    major = full_version // 10000
    if major >= 10:
        return '%d' % major
    return '%d.%d' % (major, full_version // 100 % 100)


def get_conn_string(server):
    """Return the server's barman conninfo rewritten for the local restore.

    Every ``host=<value>`` keyword is replaced with ``host=localhost`` and
    every ``sslmode=<value>`` with ``sslmode=prefer``, wherever they appear
    in the string. Keywords that are absent are not added.

    :param server: barman server name.
    :return: the adjusted conninfo string.
    """
    c = Config('/etc/barman.conf')
    c.load_configuration_files_directory()
    conninfo = c.get_server(server).conninfo

    # Anchor on start-of-string or whitespace so that e.g. "hostaddr=" or a
    # value that merely contains "host=" is left alone; \S* takes the whole
    # (unquoted) value regardless of letter case.
    conninfo = re.sub(r'(^|\s)host=\S*', r'\1host=localhost', conninfo)
    conninfo = re.sub(r'(^|\s)sslmode=\S*', r'\1sslmode=prefer', conninfo)
    return conninfo


def copy_table(conninfo, table):
    """Read *table* completely by COPYing it to ``/dev/null``.

    Connecting is attempted up to six times. Any failure (connect or COPY)
    is logged and reported through the return value instead of raised.

    :param conninfo: libpq connection string.
    :param table: ``schema.table``; split at the first dot only, so a dot
        inside the table name is kept. A name without a dot is used as is.
    :return: ``(table, True)`` if the whole table could be read,
        ``(table, False)`` otherwise.
    """
    conn = None
    try:
        last_error = None
        for _ in range(6):
            try:
                conn = psycopg2.connect(conninfo)
                break
            except psycopg2.Error as e:
                last_error = e
        if conn is None:
            # Surface the real connection error rather than failing later
            # on a None connection.
            raise last_error

        cur = conn.cursor()

        logging.debug('Reading table %s: starting' % table)

        quoted_name = '.'.join('"' + part.replace('"', '""') + '"'
                               for part in table.split('.', 1))

        with open('/dev/null', 'w') as f:
            cur.execute("SET lock_timeout TO 0;")
            # copy_expert takes the statement verbatim, so the explicit
            # quoting above is not quoted a second time by the driver.
            cur.copy_expert('COPY %s TO STDOUT' % quoted_name, f)

        logging.debug('Reading table %s: looks good' % table)
        return table, True
    except Exception as e:
        logging.error('Reading table %s: %s' % (table, e))
        return table, False
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception as e:
                logging.debug('Closing connection for %s: %s' % (table, e))


def _connect_with_retries(conninfo, attempts=6):
    """Return a psycopg2 connection, retrying up to *attempts* times;
    re-raise the last connection error when every attempt fails."""
    last_error = None
    for _ in range(attempts):
        try:
            return psycopg2.connect(conninfo)
        except psycopg2.Error as e:
            last_error = e
    raise last_error


def _pause_replay_and_list_dbs(conninfo):
    """Try to pause WAL replay, then return the non-template database names."""
    conn = psycopg2.connect(conninfo)
    try:
        cur = conn.cursor()
        # The function was renamed from pg_xlog_* to pg_wal_* in
        # PostgreSQL 10; try the old name first, then the new one.
        for pause_function in ('pg_xlog_replay_pause', 'pg_wal_replay_pause'):
            try:
                cur.execute('SELECT %s();' % pause_function)
                cur.fetchall()
                break
            except psycopg2.Error:
                # Leave the aborted transaction so the session stays usable.
                conn.rollback()
        else:
            logging.info('Replay pause failed. Continue anyway.')

        cur.execute('SELECT datname FROM pg_database ' +
                    'WHERE datistemplate = false;')
        return [row[0] for row in cur.fetchall()]
    finally:
        conn.close()


def _list_base_tables(db_conninfo):
    """Return ``schema.table`` for every BASE TABLE in information_schema."""
    conn = _connect_with_retries(db_conninfo)
    try:
        cur = conn.cursor()
        cur.execute("SELECT table_schema, table_name " +
                    "FROM information_schema.tables " +
                    "WHERE table_type = 'BASE TABLE';")
        return [schema + '.' + name for schema, name in cur.fetchall()]
    finally:
        conn.close()


def check_consistency_of_all_dbs(server, backup, path):
    """Read every base table of every non-template database once.

    For each database the tables are handed to ``copy_table`` in a pool of
    16 worker processes. The PostgreSQL instance at *path* is stopped via
    ``stop_postgres`` before returning, whatever the outcome.

    :param server: barman server name (used to derive the conninfo).
    :param backup: barman backup id.
    :param path: data directory of the running restored instance.
    :return: ``(0, path)`` when all tables could be read, ``(4, None)`` when
        a table failed or any exception occurred.
    """
    try:
        conninfo = get_conn_string(server)

        for db in _pause_replay_and_list_dbs(conninfo):
            logging.info('Checking consistency of %s' % db)

            # Quote the value libpq-style so names with spaces, quotes or
            # backslashes survive; a callable replacement keeps re.sub from
            # interpreting backslashes in the name.
            quoted_db = db.replace('\\', '\\\\').replace("'", "\\'")
            db_conninfo = re.sub(r'dbname=.*?(\s|$)',
                                 lambda match: "dbname='%s' " % quoted_db,
                                 conninfo)

            # All connections of this process are closed before the pool
            # forks, so workers never inherit a live libpq socket.
            tables = _list_base_tables(db_conninfo)
            logging.info('Dumping %d tables' % len(tables))

            pool = Pool(processes=16)
            try:
                pending = [pool.apply_async(copy_table, (db_conninfo, table))
                           for table in tables]
            finally:
                pool.close()
                pool.join()

            for result in pending:
                table, success = result.get()
                if not success:
                    logging.error("%s seems inconsistent" % table)
                    return 4, None
                logging.debug("%s ok" % table)

        return 0, path
    except Exception as e:
        logging.error(e)
        return 4, None
    finally:
        stop_postgres(server, backup, path)


def _postgres_port_is_open():
    """Return True if a TCP connection to 127.0.0.1:5432 can be opened."""
    try:
        probe = socket.create_connection(('127.0.0.1', 5432), timeout=5)
    except socket.error:
        return False
    probe.close()
    return True


def check_consistency_of_one_backup(path_base, server, sizes, backup):
    """Restore *backup* of *server*, start it and check all its databases.

    Return codes (first tuple element):

    * 0 - every table could be read (second element is the data directory)
    * 1 - not enough free space on *path_base*, or the recover failed
    * 2 - the config files could not be adjusted
    * 3 - PostgreSQL could not be started
    * 4 - the instance never accepted queries / stopped listening, or the
      consistency check itself failed
    * 5 - *backup* is empty/None, nothing to check

    :param path_base: directory to restore into.
    :param server: barman server name.
    :param sizes: mapping server name -> size in bytes measured on a
        previous run; used to estimate the space the restore will take.
    :param backup: barman backup id, or a false value when there is none.
    :return: ``(code, path_or_None)`` as described above.
    """
    if not backup:
        logging.error('Lask OK backup not today. Skipping server %s.' % server)
        return 5, None

    vfs = os.statvfs(path_base)
    increase = sizes.get(server, 0) // vfs.f_bsize
    # Require at least 20% of the filesystem to stay free after the restore.
    if vfs.f_bavail - increase < 0.2 * vfs.f_blocks:
        logging.error('No space left on device')
        return 1, None

    path = restore_pgdata(path_base, server, backup)
    if path is None:
        logging.error('Could not recover %s.' % server)
        return 1, None
    res = hack_configs(path)
    if res != 0:
        logging.error('Could not hack configs for %s.' % server)
        return 2, None
    res = start_postgres(server, backup, path)
    if res != 0:
        logging.error('Could not start PostgreSQL for %s.' % server)
        return 3, None

    conninfo = get_conn_string(server)
    for _ in range(36 * 60):  # one probe per minute for 36 hours
        time.sleep(60)
        accepting_queries = False
        conn = None
        try:
            conn = psycopg2.connect(conninfo)
            cur = conn.cursor()
            cur.execute('SELECT 42;')
            accepting_queries = cur.fetchone()[0] == 42
        except Exception as err:
            # Compare against the message text; "starting up" just means
            # recovery is still running and we keep waiting.
            if 'the database system is starting up' not in str(err):
                if not _postgres_port_is_open():
                    logging.error('Unable to start database')
                    stop_postgres(server, backup, path)
                    return 4, None
        finally:
            if conn is not None:
                conn.close()

        if accepting_queries:
            return check_consistency_of_all_dbs(server, backup, path)

    stop_postgres(server, backup, path)
    logging.error('No consistent state for %s after 36 hours.' % server)
    return 4, None


def stop_postgres(server, backup, path):
    """Stop the PostgreSQL instance on *path* with ``pg_ctl -m immediate``.

    ``pg_ctl`` is taken from ``/usr/pgsql-<version>/bin`` (version from
    ``get_pg_version``) and run as ``postgres`` through ``sudo``.

    :return: the non-zero exit code when ``pg_ctl`` fails; ``None`` (not 0)
        when it succeeds.
    """
    pg_version = get_pg_version(server, backup)
    cmd = ('sudo -u postgres /usr/pgsql-%s/bin/pg_ctl stop -m immediate -D %s'
           % (pg_version, path))
    logging.debug(cmd)
    exit_code = subprocess.call(cmd, shell=True,
                                stderr=sys.stderr,
                                stdout=sys.stdout)
    return exit_code if exit_code != 0 else None


def drop_deployed_backup(server, backup, path):
    """Stop the restored instance, measure its size and delete it.

    :param server: barman server name.
    :param backup: barman backup id.
    :param path: recovered data directory to remove.
    :return: sum of the sizes (bytes) of all files that could be stat'ed
        below *path* before removal.
    """
    stop_postgres(server, backup, path)

    time.sleep(5)

    size = 0
    for dir_path, _, file_names in os.walk(path):
        for filename in file_names:
            try:
                # stat() needs no file descriptor, so nothing can leak and
                # the file never has to be opened.
                size += os.stat(os.path.join(dir_path, filename)).st_size
            except (OSError, UnicodeError) as e:
                logging.warning('Get backup size issue: ' + str(e))

    cmd = ['rm', '-rf', path]
    logging.debug(' '.join(cmd))
    # argv list, no shell: the path is passed as one literal argument.
    subprocess.call(cmd, stderr=sys.stderr, stdout=sys.stdout)
    return size


def init_logging(level):
    """Make the root logger write tab-separated records to one stream handler.

    Any handlers already attached to the root logger are replaced.

    :param level: name of a level constant of the ``logging`` module,
        e.g. ``'DEBUG'``; applied to both the root logger and the handler.
    """
    numeric_level = getattr(logging, level)

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(numeric_level)
    stream_handler.setFormatter(
        logging.Formatter("%(levelname)s\t%(asctime)s\t\t%(message)s"))

    root_logger = logging.getLogger()
    root_logger.setLevel(numeric_level)
    root_logger.handlers = [stream_handler]


if __name__ == '__main__':
    # Single-instance guard: take a non-blocking exclusive lock and leave
    # quietly when it cannot be obtained.
    if not os.path.isfile('/var/lock/check_backup_consistency.lock'):
        with open('/var/lock/check_backup_consistency.lock', 'w') as f:
            pass

    lock = open('/var/lock/check_backup_consistency.lock', 'r+')
    try:
        fcntl.flock(lock, fcntl.LOCK_EX | fcntl.LOCK_NB)
    except IOError:
        sys.exit(0)

    config = SafeConfigParser()
    config.read('/etc/barman-consistency-check.conf')

    try:
        path_base = config.get('main', 'path_base')
    except Exception:
        path_base = '/u0/barman'

    skip = config.get('main', 'skip_dbs').split()

    init_logging(config.get('main', 'log_level').upper())

    # Sizes measured on the previous run; used for the free-space estimate.
    # Read after logging is set up so a problem can be reported.
    sizes = {}
    try:
        with open('/tmp/check_backup_consistency.sizes') as f:
            sizes = json.loads(f.read())
    except (IOError, ValueError) as e:
        logging.debug('No usable sizes file: %s' % e)

    status_file_path = '/tmp/check_backup_consistency.status'
    if os.path.exists(status_file_path):
        with open(status_file_path, 'r') as status_file:
            # Format is "<timestamp>;<status>;<message>"; only the timestamp
            # is needed here.
            fields = status_file.read().rstrip().split(';', 2)
        try:
            last = datetime.datetime.fromtimestamp(float(fields[0]))
        except ValueError:
            logging.warning('Ignoring malformed status file %s.'
                            % status_file_path)
            last = None
        day_start = datetime.datetime.combine(datetime.date.today(),
                                              datetime.time.min)
        if last is not None and last > day_start:
            logging.info('Backups have already been checked today.')
            sys.exit(0)

    problems = []
    for server in get_list_of_servers(skip):
        backup = get_last_backup(server)
        res, path = check_consistency_of_one_backup(path_base, server,
                                                    sizes, backup)
        if res != 0:
            problems.append(server)
        else:
            sizes[server] = drop_deployed_backup(server, backup, path)

    status = 0
    msg = 'All backups are consistent. Good boy!'
    if problems:
        problems.sort()
        status = 1
        msg = 'Clusters with failed backups are %s.' % ', '.join(problems)

    logging.info(msg)
    with open(status_file_path, 'w') as status_file:
        status_file.write('%d;%d;%s\n' % (int(time.time()), status, msg))

    try:
        with open('/tmp/check_backup_consistency.sizes', 'w') as f:
            f.write(json.dumps(sizes))
    except (IOError, OSError, TypeError, ValueError) as e:
        # Best effort: the sizes are only a hint for the next run.
        logging.warning('Could not save sizes file: %s' % e)

    fcntl.flock(lock, fcntl.LOCK_UN)
    lock.close()
