#!/usr/bin/env python
#
# Measure latency from write endpoint to read endpoint.
#
# example:
#  ./measure-latency -w 'user=subs password=mypassword port=5432 dbname=subs host=master.us-west2.justin.tv' \
#    -r 'user=subs password=mypassword port=5432 dbname=subs host=subs-rds-ro.justin.tv' \
#    -u 'update tickets set access_end = %(value)s where id = 1234' \
#    -s 'select access_end from tickets where id = 1234' \
#    -t date
#
# Note that this example is not secure. Please use a .pgpass file for storing host/user/etc passwords.
# https://www.postgresql.org/docs/10/libpq-pgpass.html

import datetime
import time

def parse_args(argv=None):
    """Parse the command line for the latency measurement script.

    argv is the list of argument strings to parse; when None (the default)
    argparse falls back to sys.argv[1:], which is what the script entry point
    relies on.

    Returns the argparse namespace with the attributes writer, reader, update,
    select, type, sleep, max, pause and samples.  Invalid values make argparse
    print a usage error and exit (SystemExit) before any database connection
    is attempted.
    """
    import argparse

    def positive_int(text):
        "Parse text as an int and reject anything below 1."
        value = int(text)
        if value < 1:
            raise argparse.ArgumentTypeError("must be at least 1, got {}".format(text))
        return value

    def non_negative_float(text):
        "Parse text as a float and reject negative values and NaN."
        value = float(text)
        # 'not value >= 0' (instead of 'value < 0') also rejects NaN.
        if not value >= 0:
            raise argparse.ArgumentTypeError("must be zero or greater, got {}".format(text))
        return value

    parser = argparse.ArgumentParser(description="Estimate latency between write and read endpoints")
    parser.add_argument('--writer', '-w', help="writer dsn", required=True)
    parser.add_argument('--reader', '-r', help="reader dsn", required=True)
    parser.add_argument('--update', '-u', help="The write query", required=True)
    parser.add_argument('--select', '-s', help="The read query", required=True)
    # The choices must stay in sync with the keys of next_value_fn; checking
    # here turns a late KeyError into an immediate usage error.
    parser.add_argument('--type', '-t', default='int', choices=('int', 'date'),
                        help="The type of the value")
    parser.add_argument('--sleep', type=non_negative_float, default=0.05,
                        help="How long to sleep between measurement reads.")
    parser.add_argument('--max', type=non_negative_float, default=30,
                        help="How long to let a measurement go before expiring.")
    parser.add_argument('--pause', type=non_negative_float, default=5,
                        help="How long to sleep between measurements.")
    # Zero samples would make the average undefined, so require at least one.
    parser.add_argument('--samples', type=positive_int, default=10, help="How many samples to take")
    return parser.parse_args(argv)

def _measure(wc, wq, rc, rq, type, sleep, max):
    """Return how long it took to read a value from rc after writing it to wc.

    wc, rc -- writer and reader connections (anything with a DB-API cursor()).
    wq     -- write query; executed with the parameter mapping {'value': ...}.
    rq     -- read query; the first column of its first row is the value.
    type   -- key into next_value_fn selecting how the next value is derived.
    sleep  -- seconds to wait between polls of the reader.
    max    -- seconds after which the measurement is abandoned.

    The current value is read from the writer, the next value is written to
    the writer, and the reader is polled until it returns that value.  The
    return value is the elapsed wall-clock seconds from just after the write
    until polling stopped; on timeout a "Gave up" line is printed and the
    elapsed time (roughly max) is still returned.

    Raises ValueError if the read query returns no row on the writer.  Both
    cursors are closed even when a query raises.
    """
    # NOTE: the parameter names 'type' and 'max' shadow builtins; they are
    # kept because they are part of the function's signature.
    wcc = wc.cursor()
    try:
        wcc.execute(rq)
        #print(rq)
        row = wcc.fetchone()
        if row is None:
            raise ValueError("read query returned no rows on the writer: {}".format(rq))
        value = next_value_fn[type](row[0])
        #print("Found {} and writing {}".format(row[0], value))
        wcc.execute(wq, {'value': value})
        rcc = rc.cursor()
        try:
            start = time.time()
            timeout = start + max
            # Track success explicitly instead of inferring a timeout from the
            # clock afterwards, which can misreport a match near the deadline.
            replicated = False
            while time.time() < timeout:
                rcc.execute(rq)
                row = rcc.fetchone()
                # No row on the reader is treated as "not replicated yet".
                if row is not None and row[0] == value:
                    replicated = True
                    break
                time.sleep(sleep)
            finish = time.time()
        finally:
            rcc.close()
    finally:
        wcc.close()
    if not replicated:
        print("Gave up after {}s".format(max))
    return finish - start

def _dsn_host(dsn):
    """Return the value of the host keyword in a 'key=value key=value' dsn.

    Elements are separated by any run of whitespace, and each element is split
    at its first '=' only, so repeated spaces and values containing '=' are
    handled.  Elements without '=' are ignored.  If host appears more than
    once the last occurrence wins.

    Raises KeyError if the dsn has no host element.  Quoted values containing
    spaces are not understood by this simple parser.
    """
    values = {}
    for element in dsn.split():
        key, separator, value = element.partition('=')
        if separator:
            values[key] = value
    return values['host']

def _percentile(ordered, percent):
    "Return the nearest-rank percentile (integer percent, 1..100) of a non-empty ascending list."
    # Integer ceiling of percent * n / 100 avoids floating point rounding.
    rank = (percent * len(ordered) + 99) // 100
    return ordered[max(rank, 1) - 1]


def main():
    """Run the measurement loop and print average, P50, P90 and P100 latency.

    Parses the command line, connects to the writer and reader with
    autocommit enabled, takes args.samples measurements with _measure
    (sleeping args.pause seconds before each one) and prints a summary.
    Both connections are closed before the summary is printed, also when a
    measurement raises.
    """
    import psycopg2
    args = parse_args()
    # Resolve the host labels up front so a dsn without host= fails before
    # any time is spent measuring rather than after the last sample.
    reader_host = _dsn_host(args.reader)
    writer_host = _dsn_host(args.writer)

    measurements = []
    wc = psycopg2.connect(args.writer)
    try:
        wc.autocommit = True
        rc = psycopg2.connect(args.reader)
        try:
            rc.autocommit = True
            # range() yields nothing for zero or negative counts, so a bad
            # sample count can no longer loop forever.
            for _ in range(args.samples):
                time.sleep(args.pause)
                latency = _measure(wc, args.update, rc, args.select, args.type, args.sleep, args.max)
                measurements.append(latency)
                #print("{}s".format(latency))
        finally:
            rc.close()
    finally:
        wc.close()

    if not measurements:
        print("No samples taken; nothing to report.")
        return

    measurements.sort()
    latency = sum(measurements) / len(measurements)
    print("Approximately {0}s to replicate row from {2} written to {1}.".format(latency, reader_host,
                                                                               writer_host))
    print(" P50 of {0}s".format(_percentile(measurements, 50)))
    print(" P90 of {0}s".format(_percentile(measurements, 90)))
    print("P100 of {0}s".format(measurements[-1]))

# Functions to generate the next value

def _next_date(current):
    """Return a value strictly later than current for date-like columns.

    A datetime.datetime is advanced by one second.  A plain datetime.date is
    advanced by one day, because adding a timedelta to a date ignores the
    seconds component: a one-second step would return the same date and the
    reader would appear to have replicated the write instantly.
    """
    is_plain_date = (isinstance(current, datetime.date)
                     and not isinstance(current, datetime.datetime))
    if is_plain_date:
        return current + datetime.timedelta(days=1)
    return current + datetime.timedelta(seconds=1)

def _next_int(current):
    "Give back the value that follows current, i.e. current incremented by one"
    successor = current + 1
    return successor

# Maps the --type argument to the function that derives the next value to write.
next_value_fn = dict(
    date=_next_date,
    int=_next_int,
)

# Only run the measurement when executed as a script, not when imported.
if __name__ == "__main__":
    main()
