import argparse
import datetime
import json
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np

import travel.rasp.bus.scripts.cache_research.log_parser as log_parser
import travel.rasp.bus.scripts.cache_research.suppliers as suppliers
from travel.rasp.bus.scripts.cache_research.emulator import Cache
from travel.rasp.bus.scripts.cache_research.request_queue import Queue


# Width of one supplier's bar in the cache age distribution plot; also the
# horizontal shift applied between consecutive suppliers' bars.
BAR_WIDTH = 0.1
# Start of the marker-size range used by the warmer plot (sizes are squared
# before being handed to scatter).
MAX_MARKER_SIZE = 12
# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm whether it is still needed.
MAX_TICK_NUMBER = 15
# Exclusive end of the marker-size range used by the warmer plot.
MIN_MARKER_SIZE = 1


def get_rps(statistic, norm=None):
    """Convert per-interval counters of a supplier into per-second rates.

    Every counter is divided by the interval length in seconds
    (``norm`` minutes * 60).

    :param statistic: object exposing the iterables ``requests_numbers``,
        ``cache_hit_numbers`` and ``cache_miss_numbers``.
    :param norm: interval length in minutes. When omitted, the ``--norm``
        command-line value (module-level ``args.norm``) is used, which keeps
        existing one-argument calls working unchanged.
    :return: tuple ``(rps, cache_hit_rps, cache_miss_rps)`` of lists.
    :raises ZeroDivisionError: if the interval length is zero and any counter
        series is non-empty.
    """
    if norm is None:
        norm = args.norm
    # Computed once instead of inside every per-element lambda call.
    interval_seconds = norm * 60

    def per_second(counts):
        return [count / interval_seconds for count in counts]

    return (per_second(statistic.requests_numbers),
            per_second(statistic.cache_hit_numbers),
            per_second(statistic.cache_miss_numbers))


def save_requests_distribution_plot(statistics):
    """Plot request, cache-hit and cache-miss rates per supplier and save the figure.

    Three vertically stacked subplots share the time axis; each supplier
    contributes one line to every subplot. The image is written into
    ``args.path`` with a name built from ``args.ttl`` and the base name of
    ``args.file``.

    :param statistics: mapping ``supplier -> statistic``; each statistic must
        provide ``datetimes`` plus the counters consumed by ``get_rps``.
    """
    fig, axes = plt.subplots(3, 1, figsize=(16, 8), sharex=True)
    try:
        requests_axes, hits_axes, misses_axes = axes
        for supplier, statistic in statistics.items():
            rps, cache_hit_rps, cache_miss_rps = get_rps(statistic)
            requests_axes.plot(statistic.datetimes, rps, marker='o', label=supplier)
            hits_axes.plot(statistic.datetimes, cache_hit_rps, marker='o', label=supplier)
            misses_axes.plot(statistic.datetimes, cache_miss_rps, marker='o', label=supplier)
        requests_axes.set_ylabel('requests (per sec)')
        hits_axes.set_ylabel('cache hits (per sec)')
        misses_axes.set_ylabel('cache misses (per sec)')
        misses_axes.set_xlabel('time')
        # The legend is attached to the bottom subplot, i.e. the axes that
        # pyplot treats as current right after plt.subplots().
        misses_axes.legend(loc="best")
        full_name = os.path.join(
            args.path,
            f"cache_research_suppliers_requests_distribution_ttl={args.ttl}_{os.path.basename(args.file).split('.')[0]}")
        fig.tight_layout()
        fig.savefig(full_name)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


def save_cache_age_distribution_plot(statistics):
    """Draw a grouped bar chart of cache-entry ages at read time and save it.

    The age axis is split into buckets of ``args.norm`` minutes covering
    ``[0, args.ttl)`` minutes; each supplier gets one bar per bucket. The image
    is written into ``args.path`` with a name built from ``args.ttl`` and the
    base name of ``args.file``.

    :param statistics: mapping ``supplier -> statistic``; each statistic must
        provide ``cache_entries_age_distribution``, a mapping from bucket start
        (``datetime.timedelta``) to a count. The mapping is only read, never
        modified.
    """
    fig, axes = plt.subplots(figsize=(20, 10))
    try:
        bucket = datetime.timedelta(minutes=args.norm)
        ttl = datetime.timedelta(minutes=args.ttl)
        # Ceiling division: a trailing partial bucket still gets its own bar.
        ages_number = args.ttl // args.norm + int(args.ttl % args.norm != 0)
        ages = [datetime.timedelta(minutes=args.norm * k) for k in range(ages_number)]
        alteration = np.arange(ages_number)
        for i, (supplier, statistic) in enumerate(statistics.items()):
            distribution = statistic.cache_entries_age_distribution
            if distribution:
                # Look the buckets up instead of padding the caller's mapping:
                # the statistic stays untouched and the result always has one
                # value per bucket. Keys outside the expected buckets are not
                # plotted.
                times = [distribution.get(age, 0) for age in ages]
            else:
                times = [0] * ages_number
            # Shift every supplier's bars so they sit side by side in a bucket.
            axes.bar(alteration + BAR_WIDTH * i, times, width=BAR_WIDTH, label=supplier)
        axes.set_ylabel('times')
        axes.set_xlabel('age at getting (sec)')
        axes.set_xticks(alteration)
        # Label every bucket "start-end", clamping the last end to the TTL.
        ages_intervals = [f"{age}-{min(age + bucket, ttl)}" for age in ages]
        axes.set_xticklabels(ages_intervals)
        axes.grid(True)
        axes.legend(loc="best")
        full_name = os.path.join(
            args.path,
            f"cache_research_suppliers_cache_age_distribution_ttl={args.ttl}_{os.path.basename(args.file).split('.')[0]}")
        fig.tight_layout()
        fig.savefig(full_name)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


def save_warmer_plot(cached_days, direction_counter):
    """Scatter-plot, per supplier, the cached-days value of every direction and save it.

    Directions are laid out along the x axis in descending order of
    ``direction_counter``; every supplier listed in the module-level
    ``segments`` gets its own series. A supplier with no data, or with no
    value for a direction, is drawn at ``-1`` for it.

    :param cached_days: mapping ``supplier -> {direction: days}``.
    :param direction_counter: mapping ``direction -> request count`` used to
        order the directions.
    """
    def get_marker_sizes(n):
        # Squared sizes, largest first, one per supplier — presumably so that
        # overlapping markers of different suppliers remain distinguishable.
        if n == 0:
            return []
        step = (MAX_MARKER_SIZE - MIN_MARKER_SIZE) / n
        return [size ** 2 for size in np.arange(MAX_MARKER_SIZE, MIN_MARKER_SIZE, -step)]

    directions = sorted({direction for directions in cached_days.values() for direction in directions},
                        key=lambda x: direction_counter[x], reverse=True)
    fig, axes = plt.subplots(figsize=(min(320, max(20, len(directions) // 8)), 5))
    try:
        supplier_names = list(segments)
        # Sized by the collection that is actually iterated below.
        sizes = get_marker_sizes(len(supplier_names))
        for size, supplier in zip(sizes, supplier_names):
            supplier_days = cached_days.get(supplier) or {}
            # Read the values in the exact x-axis order. Sorting each
            # supplier's items separately can order directions with equal
            # counters differently from `directions` and put points under
            # the wrong label.
            days_number = [supplier_days.get(direction, -1) for direction in directions]
            axes.scatter(range(len(directions)), days_number, label=supplier, s=size)
        axes.set_xticks(range(len(directions)))
        directions_labels = [d.from_id + '-' + d.to_id for d in directions]
        axes.set_xticklabels(directions_labels, fontsize=7)
        axes.grid(True)
        axes.set_ylabel('days')
        axes.set_xlabel('directions')
        axes.tick_params(axis='x', labelrotation=90)
        axes.legend(loc="best")
        full_name = os.path.join(args.path, f"cache_research_cached_days_"
                                            f"{os.path.basename(args.warmer).split('.')[0]}_"
                                            f"{os.path.basename(args.file).split('.')[0]}")
        fig.tight_layout()
        fig.savefig(full_name)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


def save_suppliers_plots(supplier_statistics, direction_counter, cached_days):
    """Create the output directory (``args.path``) and write all three plots into it.

    :param supplier_statistics: per-supplier statistics passed to the request
        distribution and cache age distribution plots.
    :param direction_counter: per-direction counters passed to the warmer plot.
    :param cached_days: per-supplier cached-days data passed to the warmer plot.
    """
    os.makedirs(args.path, exist_ok=True)
    # The first two plots consume the same per-supplier statistics.
    for save_statistics_plot in (save_requests_distribution_plot,
                                 save_cache_age_distribution_plot):
        save_statistics_plot(supplier_statistics)
    save_warmer_plot(cached_days=cached_days, direction_counter=direction_counter)


def find_suppliers_for_direction(direction):
    """Return the set of suppliers whose segments contain the given direction.

    A supplier matches when ``direction.from_id`` is a key of its entry in the
    module-level ``segments`` and ``direction.to_id`` is contained in the value
    stored under that key.
    """
    return {
        supplier
        for supplier, reachable in segments.items()
        if direction.from_id in reachable and direction.to_id in reachable[direction.from_id]
    }


def is_cache_not_rotten(request, supplier, current_time, ttl):
    """Tell whether ``Cache.DATE`` holds a still-fresh record for the pair.

    :return: ``False`` when ``(request, supplier)`` is absent from
        ``Cache.DATE``; otherwise whether the difference between
        ``current_time`` and the recorded value is strictly below ``ttl``.
    """
    cache_key = (request, supplier)
    if cache_key not in Cache.DATE:
        return False
    recorded = Cache.DATE[cache_key]
    return current_time - recorded < ttl


def get_cached_days(current_time, requests):
    """Count, per supplier and direction, the requests whose cache entry is still fresh.

    For every request and every supplier in the module-level ``segments``:
    if the supplier does not serve the request's direction, the direction is
    marked with ``-1``; otherwise its counter is incremented when the cache
    entry is younger than ``args.ttl`` minutes at ``current_time`` (and merely
    created with ``0`` when it is not).

    :param current_time: moment against which cache freshness is evaluated.
    :param requests: iterable of requests exposing a ``direction`` attribute.
    :return: dict ``supplier -> defaultdict(direction -> count or -1)``.
    """
    ttl = datetime.timedelta(minutes=args.ttl)
    cached_days = {supplier: defaultdict(int) for supplier in segments}
    for request in requests:
        direction = request.direction
        serving_suppliers = find_suppliers_for_direction(direction)
        for supplier, per_direction in cached_days.items():
            if supplier not in serving_suppliers:
                per_direction[direction] = -1
                continue
            fresh = is_cache_not_rotten(request, supplier, current_time, ttl)
            # Adding 0 still creates the key, so a served direction always appears.
            per_direction[direction] += 1 if fresh else 0
    return cached_days


# Script entry point: replay a JSON-lines request log through the cache
# emulator and save the resulting plots.
# `args` and `segments` are deliberately left as module-level names because the
# functions in this file read them as globals.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--ttl", default=60, type=int)
    parser.add_argument("-n", "--norm", default=15, type=int)
    parser.add_argument("-f", "--file", required=True)
    parser.add_argument("-p", "--path", default="./")
    parser.add_argument("-w", "--warmer", required=True)

    args = parser.parse_args()
    # Fail before the (long) replay rather than with a ZeroDivisionError when
    # the plots divide by the interval length at the very end.
    if args.norm <= 0:
        parser.error("--norm must be a positive number of minutes")

    # Stream the log line by line; blank lines are skipped instead of
    # crashing json.loads.
    with open(args.file, 'r', encoding='utf-8') as f:
        raw_data = [json.loads(line) for line in f if line.strip()]

    log_entries = log_parser.build_log_entries(raw_data)

    print('start getting segments')
    segments = suppliers.get_segments()
    print('got segments')

    queue = Queue(
        segments=segments,
        ttl=args.ttl,
        norm=args.norm,
        log_entries=log_entries,
        warmer_config=args.warmer)

    print("start queue processing")
    # publish_entry() signals the end of the log by raising StopIteration.
    while True:
        try:
            queue.publish_entry()
        except StopIteration:
            break
    queue.warmer.process_remaining_requests()

    cached_days = get_cached_days(queue.cache_emulator.current_time, list(queue.warmer.warmer_requests))
    print("saving plots")
    save_suppliers_plots(queue.cache_emulator.suppliers_statistics,
                         queue.warmer.total_direction_counter,
                         cached_days)
