#!/usr/bin/env python

import argparse
import fcntl
import fileinput
import logging
import os
import random
import re
import shutil
import socket
import subprocess
import sys
import tempfile
import time
from multiprocessing import Pool

import vault_client

import psycopg2

# Maps a PostgreSQL major version (as passed via --pg-version) to the version
# component used in the wal-g S3 prefix built by CheckBackup._prepare_walg_envdir.
VERSION_MAP = {
    '9.6': '906',
    '10': '1000',
}
# Name of the file in the user's home directory holding the yav OAuth token.
YAV_TOKEN_FILE_NAME = '.yav_oauth_token'


def setup_logging(shard, level='DEBUG', logfile='/var/log/pg_backupcheck.log'):
    """Configure the root logger to write to both a log file and the console.

    Every record is formatted as tab-separated
    ``<time> <level> <shard> <message>``.

    :param shard: value embedded verbatim into every log line.
    :param level: level applied to the root logger and to both handlers.
    :param logfile: path of the file the file handler appends to.
    """
    line_format = '%(asctime)s\t%(levelname)s\t{}\t%(message)s'.format(shard)
    formatter = logging.Formatter(line_format)

    root = logging.getLogger()
    root.setLevel(level)

    console = logging.StreamHandler()
    to_file = logging.FileHandler(logfile)
    # The file handler is registered first, then the console handler.
    for handler in (to_file, console):
        handler.setFormatter(formatter)
        handler.setLevel(level)
        root.addHandler(handler)

# Open lock files are kept referenced here on purpose: a flock() lock is
# dropped as soon as the owning file object is closed or garbage collected,
# so the handle has to outlive set_lock() for the lock to cover the whole run.
_HELD_LOCKS = []


def set_lock(lockfile):
    """Try to take an exclusive, non-blocking flock on ``lockfile``.

    The file is created when it does not exist. On success the open handle is
    retained for the lifetime of the process so that the lock stays held.

    :param lockfile: path of the lock file.
    :returns: True when the lock was acquired, False when it could not be
        taken (for example because another process holds it).
    """
    # 'a+' creates a missing file without truncating an existing one, which
    # avoids the separate check-then-create step.
    lock = open(lockfile, 'a+')
    try:
        fcntl.flock(lock, fcntl.LOCK_EX | fcntl.LOCK_NB)
    except (IOError, OSError):
        logging.debug('locking failed', exc_info=True)
        lock.close()
        return False
    _HELD_LOCKS.append(lock)
    logging.debug('locking succeed')
    return True

def test_read_pg_table(conninfo, database, table):
    """Check that a table can be read in full by copying it to /dev/null.

    Any failure (including a failed connection) is logged and reported as a
    failed test instead of propagating, and the connection is always closed.

    :param conninfo: libpq connection string without ``dbname``.
    :param database: name of the database holding the table.
    :param table: table as ``schema.name``; only the first dot separates the
        schema, so table names that contain dots are supported.
    :returns: ``((database, schema, table_name), status)`` where ``status`` is
        True when the read succeeded.
    """
    schema, table_name = table.split('.', 1)
    conn_string = '{} dbname={}'.format(conninfo, database)
    # Double any embedded double quote so the quoted identifiers stay valid.
    copy_query = 'COPY "{}"."{}" TO \'/dev/null\''.format(
        schema.replace('"', '""'), table_name.replace('"', '""'))
    conn = None
    try:
        conn = psycopg2.connect(conn_string)
        cur = conn.cursor()
        logging.debug('dbname: %s table: %s Started table read test' % (database, table))
        cur.execute("SET lock_timeout TO 0;")
        cur.execute(copy_query)
    except Exception:
        logging.critical('dbname: %s table: %s Table read test FAILED' % (database, table), exc_info=True)
        status = False
    else:
        logging.debug('dbname: %s table: %s Table read test PASSED' % (database, table))
        status = True
    finally:
        if conn is not None:
            conn.close()
    return (database, schema, table_name), status

def test_vacuum_pg_table(conninfo, database, table):
    """Check that a table can be vacuumed with page skipping disabled.

    Any failure (including a failed connection) is logged and reported as a
    failed test instead of propagating, and the connection is always closed.

    :param conninfo: libpq connection string without ``dbname``.
    :param database: name of the database holding the table.
    :param table: table as ``schema.name``; only the first dot separates the
        schema, so table names that contain dots are supported.
    :returns: ``((database, schema, table_name), status)`` where ``status`` is
        True when the vacuum succeeded.
    """
    schema, table_name = table.split('.', 1)
    conn_string = '{} dbname={}'.format(conninfo, database)
    # Double any embedded double quote so the quoted identifiers stay valid.
    vacuum_query = 'VACUUM (DISABLE_PAGE_SKIPPING) "{}"."{}"'.format(
        schema.replace('"', '""'), table_name.replace('"', '""'))
    conn = None
    try:
        conn = psycopg2.connect(conn_string)
        # Autocommit is enabled before running VACUUM, as in the original flow.
        conn.set_session(autocommit=True)
        cur = conn.cursor()
        logging.debug('dbname: %s table: %s Started table vacuum test' % (database, table))
        cur.execute(vacuum_query)
    except Exception:
        logging.critical('dbname: %s table: %s Table vacuum test FAILED' % (database, table), exc_info=True)
        status = False
    else:
        logging.debug('dbname: %s table: %s Table vacuum test PASSED' % (database, table))
        status = True
    finally:
        if conn is not None:
            conn.close()
    return (database, schema, table_name), status


def get_dir_size(start_path):
    """Walk ``start_path`` and measure the files found below it.

    Symbolic links are counted as files but contribute nothing to the size.

    :param start_path: directory to walk.
    :returns: ``(total_size_in_bytes, number_of_files)``.
    """
    byte_total = 0
    entry_count = 0
    for parent, _subdirs, names in os.walk(start_path):
        entry_count += len(names)
        paths = (os.path.join(parent, name) for name in names)
        byte_total += sum(os.path.getsize(p) for p in paths if not os.path.islink(p))
    return byte_total, entry_count


class CheckBackupError(Exception):
    """Raised when a step of the backup check (fetch, start, stop, test) fails."""


class CheckBackup:
    def __init__(self, cluster, workdir='/u0/backupcheck', work_subdir=None, pg_version='9.6'):
        self.cluster = cluster
        self.workdir = workdir
        self.walg_template_dir = os.path.join(workdir, 'wal-g')
        self.pg_version = pg_version
        logging.debug('cluster: %s, workdir: %s, walg_template_dir: %s' % (self.cluster, self.workdir, self.walg_template_dir))
        if not work_subdir:
            logging.debug('work_subdir arg not set, creating')
            self.work_subdir = self.create_work_subdir(cluster, os.path.join(workdir, 'data'))
        else:
            self.work_subdir = work_subdir
        logging.debug('work_subdir: %s' % self.work_subdir)
        self.pg_data_dir = os.path.join(self.work_subdir, 'data')
        self.walg_conf_dir = os.path.join(self.work_subdir, 'wal-g')
        self.walg_files_size = 0
        self.walg_files_count = 0

        self.logs_dir = os.path.join(self.work_subdir, 'logs')
        logging.debug('Trying to create subdir for logs %s' % self.logs_dir)
        try:
            os.mkdir(self.logs_dir)
        except OSError as err:
            logging.debug('Skipping: %s' % err)

        self.pg_conn_string = 'user=postgres connect_timeout=5 host={}'.format(self.work_subdir)
        self.pg_socket = os.path.join(self.work_subdir, '.s.PGSQL.5432')
        self.pg_ctl = '/usr/lib/postgresql/{}/bin/pg_ctl'.format(self.pg_version)
        self.stat = {}
        logging.info('Starting backup test for cluster {}. PG version {}. PG data dir {}'.format(self.cluster, self.pg_version, self.pg_data_dir))


    @staticmethod
    def create_work_subdir(cluster, workdir):
        suffix = '.backupcheck'
        prefix = '{}.'.format(cluster)
        work_subdir = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=workdir)
        return work_subdir

    def _fetch_gpg_key(self):
        pgp_key = ''
        keyfile = os.path.join(self.workdir, 'keys', self.cluster)
        try:
            pgp_key = open(keyfile, 'r').read()
        except IOError:
            logging.error('Can not open file with gpg secret, proceeding with yav')
        if not pgp_key:
            yav_token = os.path.join(os.path.expanduser("~"), YAV_TOKEN_FILE_NAME)
            with open(yav_token) as f:
                yav_oauth_token = f.read().strip()
            yav = vault_client.instances.Production(authorization=yav_oauth_token)
            yav_sec_id = yav.list_secrets(query='pg_gpg.' + self.cluster, page_size=500)[0]['uuid']
            pgp_key = yav.get_version(yav_sec_id, packed_value=True)['value']['armored.gpg']
        return pgp_key

    def _prepare_walg_envdir(self):
        """Create the wal-g configuration directory for this run.

        Copies the wal-g template directory to ``self.walg_conf_dir``, writes
        the per-cluster envdir variables and stores the GPG key in a file that
        only the current user can access.

        :raises CheckBackupError: when ``self.pg_version`` has no entry in
            VERSION_MAP.
        """
        logging.debug('Preparing wal-g envdir')
        # Validate the version before fetching secrets or copying anything.
        try:
            version_dir = VERSION_MAP[self.pg_version]
        except KeyError:
            logging.error('Unsupported PG version %s, expected one of: %s'
                          % (self.pg_version, ', '.join(sorted(VERSION_MAP))))
            raise CheckBackupError
        pgp_keypath = os.path.join(self.walg_conf_dir, 'PGP_KEY')
        pgp_key = self._fetch_gpg_key()
        walg_vars = {'WALG_PGP_KEY_PATH': pgp_keypath,
                     'WALE_S3_PREFIX': 's3://disk-backup-pg/wal-e/{}/{}'.format(self.cluster, version_dir)}
        shutil.copytree(self.walg_template_dir, self.walg_conf_dir)
        for var, value in walg_vars.items():
            var_file = os.path.join(self.walg_conf_dir, 'envdir', var)
            with open(var_file, 'w') as config:
                config.write(value + '\n')
        # The key is a secret: create the file with mode 0600 and enforce that
        # mode even if a file with this name came along with the template.
        key_fd = os.open(pgp_keypath, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(key_fd, 'w') as key_file:
            os.fchmod(key_fd, 0o600)
            key_file.write(pgp_key)

    def _prepare_pg_configs(self):
        logging.info('Patching postgresql config files')
        # Create minimal postgresql.conf and include other config files.
        main_postgresql_conf_file = os.path.join(self.pg_data_dir, 'postgresql.conf')
        logging.debug('Creating postgresql.conf in data dir %s'  % main_postgresql_conf_file)
        main_postgresql_conf_contents = {'hba_file'       : os.path.join(self.pg_data_dir, 'conf.d/pg_hba.conf'),
                                         'data_directory' : self.pg_data_dir,
                                         'include'        : os.path.join(self.pg_data_dir, 'conf.d/postgresql.conf')}
        with open(main_postgresql_conf_file, 'w') as f:
            for option, value in main_postgresql_conf_contents.items():
                f.write('{} = \'{}\'\n'.format(option, value))

        # Populate recovery.conf
        recovery_conf = os.path.join(self.pg_data_dir, 'recovery.conf')
        logging.debug('Creating %s' % recovery_conf)
        recovery_conf_contents = {'recovery_target'         : 'immediate',
                                  'recovery_target_timeline': 'latest',
                                  'restore_command'         : '/usr/bin/envdir {} /usr/bin/wal-g wal-fetch "%f" "%p"'.format(os.path.join(self.walg_conf_dir, 'envdir')),
                                  'recovery_target_action'  : 'promote'}
        logging.debug('recovery.conf: %s' % repr(recovery_conf_contents))
        with open(recovery_conf, 'w') as f:
            for option, value in recovery_conf_contents.items():
                f.write('{} = \'{}\'\n'.format(option, value))

        # Patch lines in conf.d/postgresql.conf
        postgresql_conf = fileinput.input(os.path.join(self.pg_data_dir, 'conf.d/postgresql.conf'), inplace=True, backup='.bak')
        logging.debug('Patching postgresql.conf in conf.d')
        for line in postgresql_conf:
            line = re.sub(r'^ssl', r'#ssl', line)
            line = re.sub(r'^archive_command = .*', 'archive_command = \'/bin/true\'', line)
            line = re.sub(r'^archive_mode = .*', 'archive_mode = off', line)
            line = re.sub(r'^wal_level = .*', 'wal_level = minimal', line)
            line = re.sub(r'^max_wal_senders = .*', 'max_wal_senders = 0', line)
            line = re.sub(r'^listen_addresses = .*', 'listen_addresses = \'\'', line)
            line = re.sub(r'/var/lib/postgresql/9.6/data', self.pg_data_dir, line)
            line = re.sub(r'/var/lib/postgresql/10/data', self.pg_data_dir, line)
            line = re.sub(r'^hba_file', '#hba_file', line)
            line = re.sub(r'^external_pid_file', '#external_pid_file', line)
            line = re.sub(r'^log_directory = .*', 'log_directory = \'{}\''.format(self.logs_dir), line)
            line = re.sub(r'^shared_preload_libraries =', '#shared_preload_libraries =', line)
            line = re.sub(r'^#unix_socket_directories =.*', 'unix_socket_directories = \'{}\''.format(self.work_subdir), line)
            # does not actually print line, but writes it to file.
            print(line.rstrip())

    def _clean(self):
        self.stop_pg()
        logging.info('Removing temp data dir %s' % self.work_subdir)
        shutil.rmtree(self.work_subdir)

    def _pg_is_ready(self):
        """Report whether postgres accepts connections and answers a query.

        Connects with ``self.pg_conn_string`` and runs ``SELECT 10``. The
        connection is closed on every path, including query failures.

        :returns: True when the query returned 10, False when it returned
            something else or a psycopg2 error occurred.
        """
        logging.debug('Checking if postgres is ready to accept connections')
        status = False
        try:
            conn = psycopg2.connect(self.pg_conn_string)
            try:
                cur = conn.cursor()
                cur.execute("SELECT 10;")
                status = cur.fetchone()[0] == 10
            finally:
                conn.close()
        except psycopg2.Error:
            logging.debug('Postgres is not ready', exc_info=True)
            status = False
        logging.debug('Postgres is ready: {}'.format(status))
        return status

    def _pg_is_rw(self):
        """Report whether postgres has left recovery and accepts writes.

        Connects with ``self.pg_conn_string`` and runs
        ``SELECT pg_is_in_recovery()``. The connection is closed on every
        path, including query failures.

        :returns: True only when the server reports it is not in recovery;
            False otherwise or when a psycopg2 error occurred.
        """
        logging.debug('Checking if postgres is ready to accept write requests')
        status = False
        try:
            conn = psycopg2.connect(self.pg_conn_string)
            try:
                cur = conn.cursor()
                cur.execute('SELECT pg_is_in_recovery()')
                pg_is_ro = cur.fetchone()[0]
                status = pg_is_ro is False
            finally:
                conn.close()
        except psycopg2.Error:
            logging.debug('Postgres is not ready', exc_info=True)
            status = False
        logging.debug('Postgres is in read-write mode: {}'.format(status))
        return status

    def _wait_for_pg(self, tries=600):
        """Wait until postgres is up and accepts write requests.

        Each try checks readiness; when postgres is not ready yet, the unix
        socket is probed after a 10 second pause to make sure the server is
        still alive, followed by a further 30 second pause.

        :param tries: maximum number of readiness checks.
        :returns: True once postgres is ready and read-write, False when all
            tries were used up without success.
        :raises CheckBackupError: when the unix socket cannot be connected to.
        """
        logging.debug('Waiting for database')
        for i in range(tries):
            logging.debug('Try # %s' % i)
            if self._pg_is_ready() and self._pg_is_rw():
                return True
            logging.debug('Postgres is not ready yet, checking if it is up')
            time.sleep(10)
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            try:
                sock.connect(self.pg_socket)
            except Exception:
                logging.error('Can not establish unix socket connection.', exc_info=True)
                raise CheckBackupError
            finally:
                # The probe socket is only a liveness check; never keep it.
                sock.close()
            logging.debug('UNIX socket connection established, still waiting...')
            time.sleep(30)
        logging.error('Postgres did not become ready after %s tries' % tries)
        return False

    def list_databases(self):
        """Return the names of all non-template databases.

        The connection is closed even when the query fails.

        :returns: list of database names.
        """
        conn = psycopg2.connect(self.pg_conn_string)
        try:
            cur = conn.cursor()
            cur.execute("SELECT datname FROM pg_database WHERE datistemplate = false;")
            return [row[0] for row in cur.fetchall()]
        finally:
            conn.close()

    def list_tables(self, database):
        """Return all base tables of ``database`` as ``schema.table`` strings.

        The connection is closed even when the query fails.

        :param database: name of the database to inspect.
        :returns: list of ``'<schema>.<table>'`` strings.
        """
        conn_string = '{} dbname={}'.format(self.pg_conn_string, database)
        conn = psycopg2.connect(conn_string)
        try:
            cur = conn.cursor()
            cur.execute("SELECT table_schema, table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE';")
            return ['{}.{}'.format(schema, name) for schema, name in cur.fetchall()]
        finally:
            conn.close()

    def backup_fetch(self, backup_id='LATEST'):
        """Restore a base backup into ``self.pg_data_dir`` with wal-g.

        wal-g is run through envdir from the ``/`` directory; its output goes
        to ``wal-g.stdout.log`` / ``wal-g.stderr.log`` in the logs directory.
        The elapsed time is stored in ``self.stat['backup_fetch']`` and the
        restored size and file count in ``self.walg_files_size`` and
        ``self.walg_files_count``.

        :param backup_id: wal-g backup name to fetch.
        :raises CheckBackupError: when wal-g exits with a non-zero status or
            cannot be executed.
        """
        logging.info('Fetching database backup with wal-g')
        time_start = time.time()
        self._prepare_walg_envdir()
        walg_env_dir = os.path.join(self.walg_conf_dir, 'envdir')
        # Passed as an argument list (no shell), so paths and backup ids with
        # spaces or shell metacharacters are handled safely.
        cmd = ['/usr/bin/envdir', walg_env_dir, '/usr/bin/wal-g', 'backup-fetch', self.pg_data_dir, backup_id]
        logging.debug(' '.join(cmd))
        stdout_path = os.path.join(self.logs_dir, 'wal-g.stdout.log')
        stderr_path = os.path.join(self.logs_dir, 'wal-g.stderr.log')
        with open(stdout_path, 'a') as walg_stdout_log, open(stderr_path, 'a') as walg_stderr_log:
            try:
                res = subprocess.call(cmd, cwd='/', stderr=walg_stderr_log, stdout=walg_stdout_log)
            except OSError:
                logging.error('Can not execute %s' % cmd[0], exc_info=True)
                # Same status a shell reports for a command it cannot find,
                # so the failure is handled by the common path below.
                res = 127
        logging.debug('wal-g exited with status: %s' % res)
        time_total = time.time() - time_start
        self.stat['backup_fetch'] = time_total
        self.walg_files_size, self.walg_files_count = get_dir_size(self.pg_data_dir)
        logging.info('Data directory size after restore: %s bytes' % self.walg_files_size)
        if res != 0:
            logging.error('wal-g failed to fetch backup. Exit status: %s' % res)
            raise CheckBackupError

    def start_pg(self):
        """Prepare the config files and start postgres with ``pg_ctl start``.

        pg_ctl is run from the ``/`` directory without waiting for startup
        (``-W``); its output is appended to ``pg_ctl.stdout.log`` /
        ``pg_ctl.stderr.log`` in the logs directory.

        :raises CheckBackupError: when pg_ctl exits with a non-zero status or
            cannot be executed.
        """
        self._prepare_pg_configs()
        # Passed as an argument list (no shell), so a data dir containing
        # spaces or shell metacharacters is handled safely.
        cmd = [self.pg_ctl, 'start', '-W', '-D', self.pg_data_dir]
        logging.info('Starting postgres with cmd: %s' % ' '.join(cmd))
        stdout_path = os.path.join(self.logs_dir, 'pg_ctl.stdout.log')
        stderr_path = os.path.join(self.logs_dir, 'pg_ctl.stderr.log')
        with open(stdout_path, 'a') as pg_ctl_stdout_log, open(stderr_path, 'a') as pg_ctl_stderr_log:
            try:
                res = subprocess.call(cmd, cwd='/', stderr=pg_ctl_stderr_log, stdout=pg_ctl_stdout_log)
            except OSError:
                logging.error('Can not execute %s' % cmd[0], exc_info=True)
                # Same status a shell reports for a command it cannot find.
                res = 127
        logging.debug('pg_ctl exit status: %s' % res)
        if res != 0:
            logging.error('pg_ctl failed to start postgres. Exit status: %s' % res)
            raise CheckBackupError

    def stop_pg(self):
        """Stop postgres with ``pg_ctl stop``, waiting up to 300 seconds.

        pg_ctl is run from the ``/`` directory; its output is appended to
        ``pg_ctl.stdout.log`` / ``pg_ctl.stderr.log`` in the logs directory.

        :raises CheckBackupError: when pg_ctl exits with a non-zero status or
            cannot be executed.
        """
        logging.info('Stoppping postgres')
        # Passed as an argument list (no shell), so a data dir containing
        # spaces or shell metacharacters is handled safely.
        cmd = [self.pg_ctl, 'stop', '-w', '-t', '300', '-D', self.pg_data_dir]
        logging.info('Sttoping postgres with cmd: %s' % ' '.join(cmd))
        stdout_path = os.path.join(self.logs_dir, 'pg_ctl.stdout.log')
        stderr_path = os.path.join(self.logs_dir, 'pg_ctl.stderr.log')
        with open(stdout_path, 'a') as pg_ctl_stdout_log, open(stderr_path, 'a') as pg_ctl_stderr_log:
            try:
                res = subprocess.call(cmd, cwd='/', stderr=pg_ctl_stderr_log, stdout=pg_ctl_stdout_log)
            except OSError:
                logging.error('Can not execute %s' % cmd[0], exc_info=True)
                # Same status a shell reports for a command it cannot find.
                res = 127
        logging.debug('pg_ctl exit status: %s' % res)
        if res != 0:
            logging.error('pg_ctl failed to stop postgres. Exit status: %s' % res)
            raise CheckBackupError

    def check_data(self, workers=16):
        """Run the read and vacuum tests on every table of every database.

        Waits for postgres to become writable, collects all base tables and
        runs ``test_read_pg_table`` and ``test_vacuum_pg_table`` for each of
        them in a pool of worker processes. Timings are stored in
        ``self.stat['pg_start']`` and ``self.stat['check_data']``.

        :param workers: number of worker processes.
        :raises CheckBackupError: when postgres never became ready or when at
            least one table test failed.
        """
        logging.info('Startting data tests')
        time_start = time.time()
        pg_ready = self._wait_for_pg()
        time_pg_started = time.time() - time_start
        self.stat['pg_start'] = time_pg_started
        if not pg_ready:
            # Do not test a server that never finished recovery.
            logging.error('Postgres is not ready for data tests')
            raise CheckBackupError
        dbs = self.list_databases()
        logging.info('Database list: %s' % ','.join(dbs))
        tables = []
        for db in dbs:
            for table in self.list_tables(db):
                tables.append((self.pg_conn_string, db, table))
        logging.debug('Spawning %s workers for running data test' % workers)
        pool = Pool(processes=workers)
        pending = []
        try:
            for each in tables:
                for test_func in (test_read_pg_table, test_vacuum_pg_table):
                    pending.append((each, pool.apply_async(test_func, each)))
            pool.close()
            pool.join()
        finally:
            # Harmless after a normal join; reaps the workers if scheduling
            # or waiting was interrupted.
            pool.terminate()
        success = []
        failures = []
        for (_conninfo, db, table), result in pending:
            try:
                tested, status = result.get()
            except Exception:
                # A crashed worker counts as a failed test for its table
                # instead of aborting the collection of the other results.
                logging.critical('dbname: %s table: %s Test crashed in worker' % (db, table), exc_info=True)
                tested, status = (db, table), False
            if status is True:
                success.append(tested)
            else:
                failures.append(tested)
        time_check_data = time.time() - time_start
        self.stat['check_data'] = time_check_data
        logging.info('Data check test finished. Passed: %s Failed: %s' % (len(success), len(failures)))
        if len(failures) > 0:
            logging.error('For %s table(s) tests failed' % len(failures))
            logging.error('Failed tables: %s' % ', '.join('.'.join(str(part) for part in item) for item in failures))
            raise CheckBackupError

    def check_indexes(self):
        """Placeholder: no index checks are implemented; returns None."""
        return None

    def full_test(self, delete_data=True):
        """Run the whole check: fetch the backup, start postgres, test data.

        The elapsed time is stored in ``self.stat['full_test']`` whether the
        run succeeds or fails, so that status reports of failed runs carry the
        real duration.

        :param delete_data: when true, stop postgres and remove the working
            directory after a successful run. Nothing is cleaned up when a
            step fails.
        """
        time_start = time.time()
        try:
            self.backup_fetch()
            self.start_pg()
            self.check_data()
            if delete_data:
                self._clean()
        finally:
            self.stat['full_test'] = time.time() - time_start

    def test_only(self):
        self.start_pg()
        self.check_data()
        self.stop_pg()

    def restore_only(self):
        """Fetch the backup and start postgres, without running data tests."""
        for step in (self.backup_fetch, self.start_pg):
            step()


if __name__ == '__main__':
    # Command-line entry point. Modes (first match wins):
    #   --tests-only    start postgres on an existing --work-subdir and test it
    #   --restore-only  fetch the backup and start postgres, no tests
    #   --batch-mode    repeatedly test random shards listed in --batch-file
    #   (default)       full restore + test of --shard
    lockfile = '/var/lock/pg_backupcheck.lock'
    workdir = '/u0/backupcheck'
    status_log_dir = os.path.join(workdir, 'status')
    status_file = '/u0/backupcheck/LASTSTATUS'
    status_format = '{}: {}. Data dir after wal-g backup fetch {} bytes, {} files. Took {} seconds.'

    parser = argparse.ArgumentParser()
    parser.add_argument('--shard')
    parser.add_argument('--pg-version', default='9.6')
    parser.add_argument('--work-subdir')
    parser.add_argument('--tests-only', action='store_true', default=False)
    parser.add_argument('--restore-only', action='store_true', default=False)
    parser.add_argument('--batch-mode', action='store_true', default=False)
    parser.add_argument('--batch-file')
    args = parser.parse_args()
    shard = args.shard

    setup_logging(shard, 'DEBUG')

    if not set_lock(lockfile):
        logging.error('Instance of script is already running. Please check lockfile %s' % lockfile)
        sys.exit(1)

    if args.tests_only:
        if args.work_subdir is None:
            parser.error(' --work-subdir is required for --tests-only')
        logging.info('Running in test only mode in pg dir %s' % args.work_subdir)
        runner = CheckBackup(shard, workdir=workdir, work_subdir=args.work_subdir, pg_version=args.pg_version)
        runner.test_only()

    elif args.restore_only:
        if shard is None:
            parser.error('Please specify shard name --shard')
        logging.info('Running in restore only mode for shard %s' % shard)
        runner = CheckBackup(shard, workdir=workdir, pg_version=args.pg_version)
        runner.restore_only()

    elif args.batch_mode:
        if args.batch_file is None:
            parser.error(' --batch-file is required for --batch-mode')
        logging.info('Running in batch mode.')
        # Loops until a test fails; the batch file is re-read on every round
        # so that edits to it are picked up without a restart.
        while True:
            with open(args.batch_file) as batch_file:
                shards = [line.strip() for line in batch_file if line.strip()]
            if not shards:
                logging.error('Batch file %s contains no shards' % args.batch_file)
                sys.exit(1)
            line = random.choice(shards)
            # Each line is "<shard>,<pg version>"; tolerate spaces around fields.
            shard, version = [field.strip() for field in line.split(',')]
            runner = CheckBackup(shard, workdir=workdir, pg_version=version)
            try:
                runner.full_test()
            except Exception:
                # Keep the traceback: it is the only record of why the run
                # failed. KeyboardInterrupt/SystemExit are left to propagate.
                logging.error('Backup test for shard %s failed' % shard, exc_info=True)
                success = False
            else:
                success = True

            test_result = status_format.format(shard, str(success), runner.walg_files_size, runner.walg_files_count, runner.stat.get('full_test', 0))
            logging.info(test_result)

            with open(status_file, 'w') as f:
                f.write(test_result + '\n')

            status_log = os.path.join(status_log_dir, time.strftime('backupcheck-%Y-%m-%d.log'))
            with open(status_log, 'a') as status_log_file:
                status_log_file.write(test_result + '\n')

            status_csv = ';'.join([str(item) for item in (shard, success, runner.walg_files_size,
                                                          runner.stat.get('backup_fetch', '0'),
                                                          runner.stat.get('pg_start', '0'),
                                                          runner.stat.get('check_data', '0'),
                                                          runner.stat.get('full_test', '0'))])
            status_csv_log = os.path.join(status_log_dir, time.strftime('backupcheck-%Y-%m-%d.csv'))
            with open(status_csv_log, 'a') as status_csv_file:
                status_csv_file.write(status_csv + '\n')

            if not success:
                logging.info('Test failed exiting...')
                sys.exit(1)

    else:
        if shard is None:
            parser.error('Please specify shard name --shard')
        logging.info('Full run for shard %s' % shard)
        runner = CheckBackup(shard, workdir=workdir, pg_version=args.pg_version)
        runner.full_test()
