import argparse
import os
import os.path
import re
from dateutil import parser as timeparser

# Matches a whole log line emitted by a "...TaskHandler" source.  The leading
# YYYY-MM-DD date is matched but NOT captured.  Groups:
#   1. time of day (HH:MM:SS.mmm)
#   2. the word directly before "TaskHandler" in the "{r.y.t.s.p.<Word>TaskHandler"
#      token (presumably an abbreviated logger/class name -- verify against the logs)
#   3. the message text after the colon
main_line_pattern = re.compile(
    r"^\d{4}-\d{2}-\d{2}\s(\d\d:\d\d:\d\d\.\d\d\d).*\{r\.y\.t\.s\.p\.(\w+)TaskHandler\s*:\s+(.*)$")
# The patterns below are applied to group 3 (the message) of main_line_pattern.
# "Request <id> will handle tasks <list>" -> (request id, raw task list text).
request_start_pattern = re.compile(r"^Request\s([\w\d-]+)\swill\shandle\stasks\s(.*)$")
# "HTTP Request <id>: Sending HTTP request:" -> (request id,).
request_send_pattern = re.compile(r"HTTP\sRequest\s([\w\d-]+):\sSending\sHTTP\srequest:")
# "HTTP Request <id>: Request completed; Code <n>" -> (request id, status code digits).
request_completed_pattern = re.compile(r"HTTP\sRequest\s([\w\d-]+):\sRequest\scompleted;\sCode\s(\d+)")
# "Task <id>: handling completed successfully with <n> offers" -> (task id, offer count digits).
task_success_pattern = re.compile(r"Task\s([\w\d-]+):\shandling\scompleted\ssuccessfully\swith\s(\d+)\soffers")
# "Task <id>: handling completed exceptionally" -> (task id,).
task_exception_pattern = re.compile(r"Task\s([\w\d-]+):\shandling\scompleted\sexceptionally")


def parse(args):
    """Extract per-request timings from task-handler logs into a TSV file.

    Reads every file named in ``args.input`` (or, when that list is empty,
    every entry of ``./input``) line by line and correlates four kinds of log
    messages: request start, HTTP request sent, HTTP request completed, and
    task success / task exception.  Once every task of a request has
    succeeded, one tab-separated row is written to ``args.output``:

        partner, task count, total offers, HTTP duration in ms, repeats

    ``partner`` is taken from the log line of the final successful task, and
    ``repeats`` is how many requests with the same (sorted) task-id set have
    been started so far.  Requests whose HTTP call completed with a code
    other than 200, or that had a task complete exceptionally, are dropped.

    NOTE: main_line_pattern captures only the time of day, not the date, so a
    request whose send and completion lines fall on different log days gets a
    wrong (negative) duration.

    :param args: namespace with ``input`` (list of paths) and ``output`` (path).
    """
    input_files = args.input
    output_file = args.output
    if not input_files:
        input_files = [os.path.join('./input', f) for f in os.listdir("./input")]
    # Hard-coded line-count estimate; only used for the progress percentage.
    total_lines = 122765211
    # dateutil fills the date missing from a time-only string from `default`.
    # Without an explicit default it uses the current day at call time, so a
    # run that crosses midnight would mix two dates.  Pin one base day.
    base_day = timeparser.parse("00:00:00")
    line_num = 0
    started_requests = {}       # request id -> state of a request in flight
    repeated = {}               # sorted task-id tuple -> times requested
    task_request_mapping = {}   # task id -> id of the latest request for it
    with open(output_file, 'w') as out_file:
        for fn in input_files:
            with open(fn) as in_file:
                for line in in_file:
                    line_num += 1
                    if line_num % 100000 == 0:
                        print("Processed {} of {} lines ({:.2f} %)".format(line_num, total_lines, 100 * (
                            line_num / float(total_lines))))
                    match = main_line_pattern.match(line)
                    if not match:
                        continue
                    time_of_day, partner, message = match.groups()
                    if partner.endswith("Partner"):
                        partner = partner[:-len("Partner")]

                    # A new request announces the tasks it is going to handle.
                    start_match = request_start_pattern.match(message)
                    if start_match:
                        request_id, tasks = start_match.groups()
                        # Sorted tuple so the same task set is one dict key.
                        tasks = tuple(sorted(tasks.split(', ')))
                        started_requests[request_id] = {
                            'partner': partner,
                            'completed_tasks': 0,
                            'total_offers': 0,
                            'total_tasks': len(tasks),
                            'task_ids': tasks
                        }
                        repeated[tasks] = repeated.get(tasks, 0) + 1
                        for task in tasks:
                            task_request_mapping[task] = request_id
                        continue

                    # The timestamp is parsed only where it is stored, which
                    # avoids a costly dateutil call on every other line.
                    send_match = request_send_pattern.search(message)
                    if send_match:
                        request = started_requests.get(send_match.group(1))
                        if request is not None:
                            request['started'] = timeparser.parse(time_of_day, default=base_day)
                        continue

                    completed_match = request_completed_pattern.search(message)
                    if completed_match:
                        request_id, code = completed_match.groups()
                        if code != "200":
                            # Failed HTTP call: the request yields no timing.
                            started_requests.pop(request_id, None)
                        elif request_id in started_requests:
                            started_requests[request_id]['completed'] = timeparser.parse(
                                time_of_day, default=base_day)
                        continue

                    task_success_match = task_success_pattern.search(message)
                    if task_success_match:
                        task_id, num_offers = task_success_match.groups()
                        request_id = task_request_mapping.pop(task_id, None)
                        request = started_requests.get(request_id)
                        if request is None:
                            continue
                        request['completed_tasks'] += 1
                        request['total_offers'] += int(num_offers)
                        if request['completed_tasks'] != request['total_tasks']:
                            continue
                        # Every task is done.  Both timestamps are required;
                        # a missing send line used to raise KeyError here.
                        if 'started' in request and 'completed' in request:
                            elapsed = request['completed'] - request['started']
                            out_file.write("{}\t{}\t{}\t{}\t{}\n".format(
                                partner,
                                request['total_tasks'],
                                request['total_offers'],
                                elapsed.total_seconds() * 1000.0,
                                repeated[request['task_ids']]
                            ))
                        else:
                            print("Completed task for incomplete request ({}: {})".format(partner, request_id))
                        del started_requests[request_id]
                        continue

                    task_exception_match = task_exception_pattern.search(message)
                    if task_exception_match:
                        request_id = task_request_mapping.pop(task_exception_match.group(1), None)
                        # A failed task means the request can never complete;
                        # drop it.  The original popped from the task mapping
                        # instead, which left the request behind forever.
                        started_requests.pop(request_id, None)


def histogram(args):
    """Write evenly spaced samples of each partner's sorted timings.

    Reads the tab-separated file ``args.input``, whose rows have five
    columns: partner, two count columns (unused here), timing, and attempts.
    For every partner one line is written to ``args.output``: the lower-cased
    partner name, ``args.num_buckets`` values picked from the sorted timings
    at indices ``(len // num_buckets) * i`` (so the first is the minimum), and
    finally the mean of the attempts column minus one.

    Works on both Python 2 and Python 3; blank input lines are skipped.

    :param args: namespace with ``input``, ``output`` and ``num_buckets``.
    """
    raw = {}            # partner -> list of timings
    sum_attempts = {}   # partner -> sum of the attempts column
    with open(args.input) as f:
        for line in f:
            if not line.strip():
                # A blank line would otherwise break the 5-way unpack below.
                continue
            partner, _num_hotels, _num_offers, timing, num_attempts = line.split("\t")
            raw.setdefault(partner, []).append(float(timing))
            sum_attempts[partner] = sum_attempts.get(partner, 0) + int(num_attempts)

    num_points = args.num_buckets
    with open(args.output, 'w') as f:
        for partner, timing_list in raw.items():
            timing_list.sort()
            f.write(partner.lower() + '\t')
            for i in range(num_points):
                # Floor division keeps the index an int on Python 3 as well.
                index = (len(timing_list) // num_points) * i
                f.write(str(timing_list[index]))
                if i < num_points - 1:
                    f.write('\t')
                else:
                    f.write('\t{}\n'.format(sum_attempts[partner] / float(len(timing_list)) - 1))


if __name__ == '__main__':
    # Command-line entry point with two sub-commands:
    #   parse      -- turn raw log files into a TSV of per-request timings
    #   histogram  -- summarise that TSV per partner
    # Each sub-parser stores its handler in `func`, which is called below.
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()
    parse_log_parser = subparsers.add_parser("parse", help="extract request timings from log files")
    parse_log_parser.set_defaults(func=parse)
    parse_log_parser.add_argument("--input", default=[], action='append',
                                  help="log file to read; may be given several times "
                                       "(default: every file in ./input)")
    parse_log_parser.add_argument("--output", default="output.tsv",
                                  help="TSV file to write (default: output.tsv)")
    histogram_parser = subparsers.add_parser("histogram", help="summarise a timings TSV per partner")
    histogram_parser.set_defaults(func=histogram)
    histogram_parser.add_argument("--input", default="output.tsv",
                                  help="TSV file to read (default: output.tsv)")
    histogram_parser.add_argument("--output", default="histogram.tsv",
                                  help="file to write (default: histogram.tsv)")
    histogram_parser.add_argument("--num_buckets", default=100, type=int,
                                  help="number of timing samples per partner (default: 100)")

    args = parser.parse_args()
    # On Python 3 sub-commands are optional by default, so a bare invocation
    # leaves `func` unset; report a usage error instead of an AttributeError.
    if not hasattr(args, 'func'):
        parser.error("a sub-command is required: parse or histogram")
    args.func(args)
