#! /usr/bin/env python3
import contextlib
import logging
import re
import subprocess
import sys
from collections import namedtuple
from multiprocessing.pool import ThreadPool

import psycopg2
import yaml

# Settings for init_loadtest: the SQL files to execute, and the query
# parameters handed to cursor.execute for each of them.
InitConfig = namedtuple('InitConfig', 'files params')

# Settings for generate_test_data: number of worker threads, the timeout
# (seconds) given to AsyncResult.get, path of the data-generation SQL file,
# and the query parameters (must include 'num_users').
GenerateConfig = namedtuple('GenerateConfig', 'threads timeout sql params')

def validate_conninfo(conninfo):
    """Abort if the connection string looks like it targets production.

    The guard matches the substrings 'production' or 'corp' anywhere in
    conninfo, ignoring case (host names are case-insensitive, so
    'Production-DB' must be rejected as well as 'production-db').

    Args:
        conninfo: libpq connection string.

    Raises:
        ValueError: if conninfo contains one of the guarded substrings.
            ValueError is an Exception subclass, so existing
            ``except Exception`` handlers still catch it.
    """
    lowered = conninfo.lower()
    if 'production' in lowered or 'corp' in lowered:
        raise ValueError('conninfo ' + conninfo + ' looks like production, aborting')

def init_loadtest(conninfo, config):
    """Execute the helper SQL files from an InitConfig against the database.

    All files are read before connecting, so a missing file fails before any
    SQL runs. Each file's contents is executed with config.params as the
    query parameters, on a single autocommit connection.

    Args:
        conninfo: libpq connection string.
        config: InitConfig with ``files`` (paths) and ``params``.
    """
    logging.info('initializing loadtest helpers')
    sqls = []
    for sql_file in config.files:
        with open(sql_file) as f:
            sqls.append(f.read())
    # psycopg2's own connection context manager only ends the transaction;
    # closing() guarantees the connection itself is released.
    with contextlib.closing(psycopg2.connect(conninfo)) as connection:
        connection.autocommit = True
        with connection.cursor() as cursor:
            for sql in sqls:
                cursor.execute(sql, config.params)

def generate_thread_run(packed_args):
    """Generate test data for one uid chunk; run as a ThreadPool worker.

    Args:
        packed_args: tuple ``(conninfo, sql, uid_start, uid_end, params)``,
            packed into one value because ``map_async`` passes a single item
            to the worker per call.

    The SQL is executed with a copy of ``params`` extended with
    ``uid_start`` and ``uid_end``. The caller's ``params`` is not modified.
    """
    conninfo, sql, uid_start, uid_end, params = packed_args
    logging.info('generating test data for uids %s %s', uid_start, uid_end)
    # Private copy: the same params dict may be shared by every worker
    # thread, so writing the chunk bounds into it directly would race.
    query_params = dict(params)
    query_params['uid_start'] = uid_start
    query_params['uid_end'] = uid_end
    # closing() ensures the per-thread connection is actually released;
    # psycopg2's connection context manager does not close it.
    with contextlib.closing(psycopg2.connect(conninfo)) as connection:
        connection.autocommit = True
        with connection.cursor() as cursor:
            cursor.execute(sql, query_params)

def generate_test_data(conninfo, config):
    """Populate the database by running the generation SQL across threads.

    The uid range ``[0, num_users)`` is split into one contiguous chunk per
    worker thread. Each chunk runs ``generate_thread_run`` with inclusive
    ``uid_start``/``uid_end`` bounds. The last chunk absorbs the remainder
    so that every uid is covered.

    Args:
        conninfo: libpq connection string passed to every worker.
        config: GenerateConfig; ``config.params`` must contain 'num_users'.

    Raises:
        multiprocessing.TimeoutError: if the workers do not finish within
            ``config.timeout`` seconds (raised by AsyncResult.get).
        Any exception raised inside a worker is re-raised by AsyncResult.get.
    """
    logging.info('generating test data')
    with open(config.sql) as f:
        sql = f.read()
    num_users = int(config.params['num_users'])
    threads = config.threads
    # Fewer users than threads, or a non-positive thread count: one worker.
    if threads < 1 or num_users < threads:
        threads = 1
    uids_per_thread = num_users // threads
    chunks = []
    for c in range(threads):
        uid_start = c * uids_per_thread
        if c == threads - 1:
            # Last chunk takes the remainder of num_users / threads.
            uid_end = num_users - 1
        else:
            uid_end = uid_start + uids_per_thread - 1
        # Give each worker its own params dict; the worker may add
        # uid_start/uid_end to the dict it receives.
        chunks.append((conninfo, sql, uid_start, uid_end, dict(config.params)))
    # The pool's context manager terminates it even if get() times out.
    with ThreadPool(processes=threads) as pool:
        result = pool.map_async(generate_thread_run, chunks, 1)
        pool.close()
        result.get(config.timeout)
        pool.join()

def conninfo_args(conninfo):
    """Translate a libpq key=value conninfo string into pgbench arguments.

    Recognizes the host, port, user and dbname keywords (whitespace is
    allowed around '='). When a keyword appears more than once, the last
    occurrence wins. A keyword that is absent contributes no argument, so
    pgbench falls back to its own default for it.

    Args:
        conninfo: libpq connection string.

    Returns:
        A list ``['-h', host, '-p', port, '-U', user, dbname]``, with the
        entries for any missing keyword left out.
    """
    def lookup(key, charset):
        matches = re.findall(r'\b' + key + r'\s*=\s*(' + charset + r'+)', conninfo)
        return matches[-1] if matches else None

    args = []
    for flag, key, charset in (('-h', 'host', r'[.a-zA-Z0-9\-]'),
                               ('-p', 'port', r'[0-9]'),
                               ('-U', 'user', r'[a-zA-Z0-9_\-]')):
        value = lookup(key, charset)
        if value is not None:
            args += [flag, value]
    dbname = lookup('dbname', r'[a-zA-Z0-9_\-]')
    if dbname is not None:
        args.append(dbname)
    return args

def file_args(step_config):
    """Build pgbench script arguments for a benchmark step.

    Returns ``['-f', path, ...]`` with one pair per entry in
    ``step_config['files']``, or an empty list when that key is absent.
    """
    if 'files' in step_config:
        return [token for script in step_config['files'] for token in ('-f', script)]
    return []

def define_args(step_config):
    """Build pgbench ``--define=NAME=VALUE`` arguments for a benchmark step.

    One argument is produced per item of the ``step_config['define']``
    mapping, in its iteration order. Returns an empty list when the key is
    absent.
    """
    if 'define' in step_config:
        return ['--define={}={}'.format(name, value)
                for name, value in step_config['define'].items()]
    return []

def rate_args(step_config):
    """Build pgbench load-shape options from a benchmark step.

    Each of client, jobs, rate, transactions and time that is present in
    step_config becomes ``['--<name>', str(value)]``. The options are
    emitted in that fixed order.
    """
    result = []
    for name in ('client', 'jobs', 'rate', 'transactions', 'time'):
        if name in step_config:
            result.extend(['--' + name, str(step_config[name])])
    return result

def other_args(step_config):
    """Split the free-form ``step_config['args']`` string on whitespace.

    Returns an empty list when the step has no 'args' key.
    """
    if 'args' not in step_config:
        return []
    return step_config['args'].split()

def run_subprocess(args, timeout):
    """Run a command and log its combined stdout/stderr.

    A non-zero exit or a timeout is logged as an error rather than raised,
    so the caller continues with its next command. On timeout, any output
    the child produced before it was killed is logged as well.

    Args:
        args: argv list; ``args[0]`` names the program in error messages.
        timeout: seconds before the child is killed, or None for no limit.
    """
    logging.info('run %s', args)
    try:
        output = subprocess.check_output(args, stderr=subprocess.STDOUT,
                                         universal_newlines=True, timeout=timeout)
    except subprocess.CalledProcessError as e:
        logging.error('%s failed: %s', args[0], e.output)
    except subprocess.TimeoutExpired as e:
        logging.error('%s failed: timed out', args[0])
        # The partial output attached to TimeoutExpired may be raw bytes
        # even in text mode, or None if nothing was read.
        partial = e.output
        if isinstance(partial, bytes):
            partial = partial.decode(errors='replace')
        if partial:
            logging.error('%s partial output: %s', args[0], partial)
    else:
        logging.info('%s', output)

def run_step(conninfo, config):
    """Run pgbench once for a single benchmark step.

    The command line is the executable (``config['path']`` if given,
    otherwise 'pgbench') followed by the step's files, define, rate and
    free-form args, then the connection arguments. If the step sets 'time',
    the subprocess timeout is that duration plus a 30-second grace period.
    Otherwise there is no timeout.
    """
    grace_period = 30.0
    command = [config['path'] if 'path' in config else 'pgbench']
    for builder in (file_args, define_args, rate_args, other_args):
        command.extend(builder(config))
    command.extend(conninfo_args(conninfo))
    timeout = None
    if 'time' in config:
        timeout = float(config['time']) + grace_period
    run_subprocess(command, timeout)

def run_benchmark(conninfo, config):
    """Run each step listed under ``config['serial']``, one after another."""
    for step in config['serial']:
        run_step(conninfo, step)

def main():
    """Load the YAML config named on the command line and run the load test.

    The steps run in this order: the conninfo safety check, helper
    initialization, test-data generation, and then the pgbench steps.

    Returns:
        1 on a usage error, otherwise None.
    """
    if len(sys.argv) != 2:
        print('usage: run.py config.yml', file=sys.stderr)
        return 1
    log_format = '%(asctime)s %(levelname)s %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)
    config_path = sys.argv[1]
    # safe_load: plain yaml.load without an explicit Loader can build
    # arbitrary Python objects and is a TypeError on PyYAML >= 6.
    with open(config_path) as f:
        config_doc = yaml.safe_load(f)
    conninfo = config_doc['conninfo']
    validate_conninfo(conninfo)
    init_loadtest(conninfo, InitConfig(**config_doc['init']))
    generate_test_data(conninfo, GenerateConfig(**config_doc['generate']))
    run_benchmark(conninfo, config_doc['benchmark'])

if __name__ == '__main__':
    # Propagate main()'s return value (1 on a usage error) as the exit status.
    sys.exit(main())
