#!/usr/bin/env python

import sys
import gzip
import json
import csv
import time

import argparse

from multiprocessing import Process, Queue, Lock, Value

import psycopg2

#

# Processed log rows carry full query texts (column 2), which can exceed the
# csv module's default per-field size cap, so lift the cap as far as possible.
# sys.maxsize does not fit in a C long on some platforms (e.g. 64-bit Windows)
# and raises OverflowError there, so halve the value until it is accepted.
_csv_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_field_limit)
        break
    except OverflowError:
        _csv_field_limit //= 2
del _csv_field_limit

def now():
    """Return the current wall-clock time as whole epoch milliseconds (int)."""
    millis = time.time() * 1000
    return int(round(millis))

def ms():
    """Return the current wall-clock time in epoch milliseconds.

    Unlike now(), the result is a float rounded to two decimal places, so it
    keeps sub-millisecond resolution for measuring statement run times.
    """
    current_ms = time.time() * 1000
    return round(current_ms, 2)

def printerror(slot, timestamp, exception, session_id=None, session_line=None, query=None):
    """Count one replay error and, unless --nolog was given, log it to stderr.

    Each logged error is one JSON object per line with the keys timestamp,
    query, exception, slot, session_id and session_line.

    Relies on the module-level ``args``, ``error_lock`` and ``error_count``
    created in the ``__main__`` section.
    NOTE(review): worker processes only see these globals if they are forked;
    under a "spawn" start method they would not exist.

    :param slot: replay slot (worker) index the error belongs to
    :param timestamp: scheduled timestamp of the statement (ms)
    :param exception: error description text
    :param session_id: log session id, if known
    :param session_line: line within the log session, if known
    :param query: statement text, if the error relates to one
    """
    line = None
    if not args.nolog:
        error = {
            "timestamp" : timestamp,
            "query" : query,
            "exception" : exception,
            "slot" : slot,
            "session_id" : session_id,
            "session_line" : session_line
            }
        try:
            line = json.dumps(error)
        except (TypeError, ValueError):
            # Free-text fields that JSON cannot encode (e.g. non-UTF-8 bytes in
            # a query or server message): log their repr() instead of crashing
            # the worker and losing the record.
            error["query"] = repr(query)
            error["exception"] = repr(exception)
            line = json.dumps(error)

    # The "with" guarantees the lock is released even if the write fails;
    # a leaked lock would stall every worker on its next error.
    with error_lock:
        # Count every error, logged or not, so the final "N errors" summary
        # is also correct with --nolog.
        error_count.value += 1
        if line is not None:
            sys.stderr.write(line + "\n")
    

def _close_quietly(conn, context, slot, timestamp, session_id, session_line):
    """Close a DB connection, reporting (never raising) a failure via printerror."""
    try:
        conn.close()
    except Exception as ce:
        errorstr = 'conn.close() failed (%s): %s' % (context, str(ce))
        printerror(slot, timestamp, errorstr, session_id, session_line)


def _wait_until(deadline):
    """Block until now() reaches *deadline* (epoch milliseconds)."""
    remaining = deadline - now()
    while remaining > 0:
        # One long sleep while the deadline is far away, then 0.5 ms polling
        # for the last couple of milliseconds - instead of waking up 2000
        # times a second for the whole wait.
        time.sleep(max(remaining - 2, 0.5) / 1000.0)
        remaining = deadline - now()


def worker(input_queue, connection_string, error_lock, error_count):
    """Replay the statements of one slot against PostgreSQL.

    Consumes ``(timestamp, query, slot, session_id, session_line)`` tuples from
    *input_queue* until a query starting with ``'<q'`` arrives.  A query
    starting with ``'<d'`` closes the current connection.  Any other query is
    executed once now() reaches its *timestamp* (epoch ms), on an autocommit
    connection that is (re)opened lazily with *connection_string*.

    Failures are reported through printerror and the statement is skipped.
    With --times, one CSV row ``session_id, session_line, timestamp, run_time``
    (run_time in ms, empty when the statement failed) is written per executed
    statement to ``response_times.<slot>.log``.

    *error_lock* and *error_count* are part of the Process target signature;
    this function does not use them directly (printerror uses the module-level
    objects of the same name).
    """
    conn = None
    respfile = None
    resplog = None

    try:
        while True:
            (timestamp, query, slot, session_id, session_line) = input_queue.get()

            if query[:2] == '<q':
                break

            if query[:2] == '<d':
                # Drop the connection even if close() fails, so the next
                # statement starts on a fresh one instead of a dead one.
                if conn is not None:
                    _close_quietly(conn, '<d case', slot, timestamp, session_id, session_line)
                    conn = None
                continue

            if conn is None:
                try:
                    conn = psycopg2.connect(connection_string)
                    conn.set_session(autocommit=True)
                except Exception as e:
                    if conn is not None:
                        # connect() succeeded but set_session() failed: don't
                        # keep a non-autocommit connection for later statements.
                        _close_quietly(conn, 'connect case', slot, timestamp, session_id, session_line)
                        conn = None
                    errorstr = 'connection failed: %s' % str(e)
                    printerror(slot, timestamp, errorstr, session_id, session_line)
                    continue

            try:
                cur = conn.cursor()
            except Exception as e:
                # A connection that cannot create cursors is unusable: discard
                # it unconditionally so the next statement reconnects.
                _close_quietly(conn, 'cursor case', slot, timestamp, session_id, session_line)
                conn = None
                errorstr = 'could not create cursor: %s' % str(e)
                printerror(slot, timestamp, errorstr, session_id, session_line)
                continue

            # Opened lazily so the file is named after this worker's slot.
            if args.times and respfile is None:
                respfile = open("response_times.%d.log" % slot, 'w')
                resplog = csv.writer(respfile)

            _wait_until(timestamp)

            begin_ts = ms()
            try:
                cur.execute(query)
                run_time = ms() - begin_ts
            except Exception as e:
                run_time = None
                errorstr = 'statement execution failed: %s' % str(e)
                printerror(slot, timestamp, errorstr, session_id, session_line, query)

            if resplog is not None:
                resplog.writerow([session_id, session_line, timestamp, run_time])
    finally:
        # Runs on normal termination and on unexpected errors alike, so the
        # connection is released and logged response times are flushed.
        if conn is not None:
            _close_quietly(conn, 'terminate case', slot, timestamp, session_id, session_line)
        if respfile is not None:
            respfile.close()

    print("Terminating!")

#

# (worker Process, its input Queue) per replay slot, indexed by slot number.
slots = []

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="Replay log files.")

    parser.add_argument('--control-file', default='pyreplay_control.json')
    parser.add_argument('--connect-string', default='dbname=postgres user=postgres')
    parser.add_argument('--preroll', type=int, default=10)
    parser.add_argument('--nolog', help="turn off error logging", action="store_true")
    parser.add_argument('--times', help="log response times to files", action="store_true")
    parser.add_argument('--acceleration', help="acceleration factor", type=float, default=1.0)
    parser.add_argument('--minutes', help="stop after # minutes of log", type=float)

    # Module-level on purpose: printerror() and worker() read the global `args`.
    args = parser.parse_args()

    with open(args.control_file, 'r') as control:
        control_structure = json.load(control)

    zero_point = control_structure['zero_point']

    # Start replaying `preroll` seconds from now.  A negative zero point
    # (presumably the earliest log timestamp - TODO confirm) pushes the start
    # further out.  Statement timestamps are divided by the acceleration
    # factor below, so the zero point must be scaled the same way.
    time_zero = now() + args.preroll * 1000 - int(min(zero_point, 0) / args.acceleration)

    # calculate stop time if given (compared against scaled timestamps, in ms)
    if args.minutes:
        stop_time = args.minutes * 60000
    else:
        stop_time = float("inf")

    error_lock = Lock()
    error_count = Value('i', 0)

    for i in range(control_structure['slots']):
        q = Queue()
        p = Process(target=worker, args=[q, args.connect_string, error_lock, error_count])
        slots.append( (p, q,) )
        p.start()

    filepath = control_structure['processed_logs']
    query_count = 0

    try:
        with gzip.open(filepath) as fz:
            for line in csv.reader(fz):
                if not line:
                    # csv.reader yields [] for blank lines; nothing to replay.
                    continue

                query_count += 1
                if (query_count % 10000) == 0:
                    sys.stdout.write(str(query_count) + " queries dispatched\n")

                timestamp = int(int(line[0])/args.acceleration)
                slot = int(line[1])
                query = line[2]
                session_id = line[3]
                session_line = line[4]

                if timestamp > stop_time:
                    break

                slots[slot][1].put((timestamp + time_zero, query, slot, session_id, session_line))
    finally:
        # Always send the quit sentinel, even if reading the log failed.
        # Otherwise the workers block on get() forever and, being non-daemon
        # processes, keep this script from ever exiting.
        for (p, q,) in slots:
            q.put((0, '<q', 1, None, None))

        for (p, q,) in slots:
            p.join()

    sys.stdout.write(str(error_count.value) + " errors\n")
