import glob
import json
import os.path
import logging
import numpy
import pandas

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import seaborn

from ...common.interfaces import AbstractPlugin, MonitoringDataListener, AggregateResultListener


# Module-level logger named after this module.
_LOGGER = logging.getLogger(__name__)
# Marker used in _CHARTSETS in place of a suffix set: every signal suffix is accepted
# (it is compared by identity, not by value).
_ALL_ = "All"
# Layout of monitoring charts: signal-name prefix -> {chart title: accepted signal suffixes}.
# A signal name is split at its first "_" into prefix and suffix; a signal whose prefix
# starts with one of these keys is charted only when its suffix is accepted by one of
# the titles, and signals with a prefix not listed here get a chart of their own.
_CHARTSETS = {
    "cpu-cpu-": {"CPU": _ALL_},
    "net-": {"Network": {"bytes_sent", "bytes_recv"}},
    "diskio-": {"Disk IO": {"read_bytes", "write_bytes"}, "Disk latency": {"read_time", "write_time"}},
    "kernel": {"Kernel": {"vmstat_pgmajfault"}},
}
# Optional prefix of monitoring signal names; it is stripped before chartset lookup.
_CUSTOM_PREFIX = "custom:"
# Column names of the tab-separated dump file, in file order (the file has no header row).
# NOTE(review): the size_out / size_in descriptions below may be swapped — TODO confirm
# against the dump format documentation.
_PHOUT_HEADER = [
    "time",            # timestamp
    "tag",             # request tags separated by '|'
    "interval_real",   # full time of request processing in microseconds
    "connect_time",    # connection time in microseconds
    "send_time",       # data sending time in microseconds
    "latency",         # time spent inside server in microseconds
    "receive_time",    # data receiving time in microseconds
    "interval_event",  # time spent in load generator in microseconds
    "size_out",        # response size in bytes
    "size_in",         # request size in bytes
    "net_code",        # libc errno
    "proto_code",      # HTTP protocol code
]


class Plugin(AbstractPlugin, AggregateResultListener, MonitoringDataListener):
    """
        Plugin generates and writes to file a simple shooting report
    """

    SECTION = 'offline'

    _latency_quantiles = (0.5, 0.95, 0.99)
    _size_quantiles = (0.5, 0.95, 1)

    def __init__(self, core, cfg, cfg_updater):
        """Pass all constructor arguments unchanged to the next class in the MRO."""
        super(Plugin, self).__init__(core, cfg, cfg_updater)

    def get_available_options(self):
        return ["warmup_time", "shutdown_time", "phout_path"]

    def configure(self):
        """
            Read plugin options and subscribe to the job's data

            ``warmup_time`` and ``shutdown_time`` are converted to integers
            (default 0), ``phout_path`` is stored as given (default is an
            empty string). Both data buffers start out empty.
        """
        warmup_option = self.get_option("warmup_time", "0")
        self.__warmup_time = int(warmup_option)

        shutdown_option = self.get_option("shutdown_time", "0")
        self.__shutdown_time = int(shutdown_option)

        self.__phout_path = self.get_option("phout_path", "")

        # buffers filled by on_aggregated_data() and monitoring_data()
        self.__shooting_data, self.__monitoring_data = [], []

        self.core.job.subscribe_plugin(self)

    def on_aggregated_data(self, data, stats):
        if data:
            self.__shooting_data.append(data)

    def monitoring_data(self, data_list):
        if data_list:
            self.__monitoring_data.extend(data_list)

    def __make_unique_path(self, suffix, prefix):
        """Get a path from ``core.mkstemp``, register it as an artifact file and return it."""
        artifact_path = self.core.mkstemp(suffix, prefix)
        self.core.add_artifact_file(artifact_path)
        return artifact_path

    def post_process(self, retcode):
        """
            Build the offline stats file and the svg report

            Collected shooting items are sorted by ``ts`` and trimmed by the
            configured warmup/shutdown intervals; monitoring items are sorted
            by ``timestamp``. The stats computed from them are dumped as JSON
            into a unique ``offline_stats_*.json`` artifact, then the charts
            are written to ``report.svg`` in the artifacts directory.

            Returns ``retcode`` unchanged.

            Raises ``Exception`` when no shooting data was collected or when
            nothing is left of it after trimming.
        """
        _LOGGER.info("Starting offline postprocessing, retcode={}".format(retcode))

        self.__shooting_data.sort(key=lambda item: item["ts"])
        warmup_index = self.__get_shooting_warmup_index()
        shutdown_index = self.__get_shooting_shutdown_index()
        short_data = self.__shooting_data[warmup_index:shutdown_index]
        _LOGGER.info("Shooting data ready ({} -> {} items)".format(len(self.__shooting_data), len(short_data)))

        if not short_data:
            # The warmup and shutdown cuts overlap, so every item was dropped.
            # Fail here with a clear message, before the stats file is created,
            # instead of hitting an IndexError deep inside the stats code.
            raise Exception(
                "No shooting data left after cutting warmup_time={} and shutdown_time={}".format(
                    self.__warmup_time, self.__shutdown_time
                )
            )

        self.__monitoring_data.sort(key=lambda item: item["timestamp"])
        _LOGGER.info("Monitoring data ready ({} items)".format(len(self.__monitoring_data)))

        with open(self.__make_unique_path(".json", "offline_stats_"), "w") as stats_file:
            json.dump(
                {
                    "dumper": self.__generate_dumper_stats(short_data),
                    "old_dumper": self.__generate_old_dumper_stats(),
                    "shooting": self.__generate_shooting_stats(short_data),
                    "monitoring": self.__generate_monitoring_stats(),
                },
                stats_file
            )

        self.__generate_report(os.path.join(self.core.artifacts_dir, "report.svg"), warmup_index, shutdown_index)

        return retcode

    def __generate_old_dumper_stats(self):
        """
            Compute RPS directly from the raw dump file

            When no ``phout_path`` is known, the first file matching
            ``dolbilo_dump_*.log`` under ``core.artifacts_base_dir`` is used
            and remembered. The dump is read as a headerless tab-separated
            table with the ``_PHOUT_HEADER`` columns; RPS is the number of
            rows with ``net_code == 0`` divided by the span of ``time``.

            Raises ``Exception`` when no dump file can be found.
        """
        _LOGGER.info("Generating old style dumper stats")

        if not self.__phout_path:
            dump_pattern = "{}/{}".format(self.core.artifacts_base_dir, "dolbilo_dump_*.log")
            candidates = glob.glob(dump_pattern)
            if not candidates:
                self.__phout_path = ""
                raise Exception("Failed to find dolbilo dump at '{}'".format(dump_pattern))
            self.__phout_path = candidates[0]

        frame = pandas.read_csv(self.__phout_path, sep='\t', names=_PHOUT_HEADER, header=None)
        succeeded = len(frame[frame.net_code == 0])
        time_span = frame.time.max() - frame.time.min()
        return {
            "rps": float(succeeded) / time_span,
        }

    def __generate_dumper_stats(self, shooting_data):
        """
            Count requests by protocol code and compute the 2xx request rate

            ``shooting_data`` is a list of aggregated items ordered by ``ts``.

            The returned dict contains:
              * ``1xx_requests`` .. ``5xx_requests`` - requests per code group;
              * ``302_requests``, ``404_requests``, ``503_requests`` - requests
                with these exact codes;
              * ``total_requests`` - sum of ``interval_real.len`` over all items;
              * ``other_requests`` - requests outside of the 1xx-5xx groups;
              * ``net_errors`` - requests with a non-zero net code;
              * ``rps`` - 2xx requests divided by the ``ts`` span between the
                first and the last item. It is reported as 0.0 when the span
                is not positive (no items or a single item), because the rate
                is undefined then.
        """
        _LOGGER.info("Generating dumper stats")

        result = {}

        # groups of codes
        xxx_requests = 0
        for n in (1, 2, 3, 4, 5):
            count = sum(self.__get_shooting_rps(shooting_data, n * 100, (n + 1) * 100))
            result["{}xx_requests".format(n)] = count
            xxx_requests += count

        # some special codes
        for code in (302, 404, 503):
            result["{}_requests".format(code)] = sum(
                data['overall']['proto_code']['count'].get(str(code), 0)
                for data in shooting_data
            )

        total_requests = sum(data['overall']['interval_real'].get('len', 0) for data in shooting_data)

        # Guard the division: an empty or single-item list has no time span.
        duration = shooting_data[-1]['ts'] - shooting_data[0]['ts'] if shooting_data else 0
        if duration > 0:
            rps = float(result["2xx_requests"]) / duration
        else:
            _LOGGER.warning("Shooting data spans no time, dumper rps is reported as 0")
            rps = 0.0

        result.update({
            "rps": rps,
            "total_requests": total_requests,
            "other_requests": total_requests - xxx_requests,
            "net_errors": sum(self.__get_shooting_net_errors(shooting_data)),
        })

        return result

    def __generate_shooting_stats(self, shooting_data):
        """
            Compute error count, RPS statistics and latency/size quantiles

            The returned dict contains:
              * ``errors`` - total requests minus 2xx requests;
              * ``rps_0.5``, ``rps_avg``, ``rps_stddev`` - median, mean and
                standard deviation of the per-item 2xx counts;
              * ``latency_<q>`` for every ``_latency_quantiles`` value, taken
                over the ``interval_real`` histograms of all items;
              * ``response_size_<q>`` for every ``_size_quantiles`` value,
                taken over the per-item ``size_out`` maximums.
        """
        _LOGGER.info("Generating shooting stats")

        rps = list(self.__get_shooting_rps(shooting_data))
        total_requests = sum(
            item['overall']['interval_real'].get('len', 0) for item in shooting_data
        )

        result = {
            "errors": total_requests - sum(rps),
            "rps_0.5": numpy.median(rps),
            "rps_avg": numpy.average(rps),
            "rps_stddev": numpy.std(rps),
        }

        # expand every histogram into one latency sample per counted request
        latency = []
        for item in shooting_data:
            hist = item['overall']['interval_real']['hist']
            for bucket, count in zip(hist['bins'], hist['data']):
                latency += [bucket] * count

        latency_percents = [fraction * 100.0 for fraction in self._latency_quantiles]
        latency_values = numpy.percentile(latency, latency_percents)
        for fraction, value in zip(self._latency_quantiles, latency_values):
            result["latency_{}".format(fraction)] = float(value)

        response_size = [item['overall']['size_out']['max'] for item in shooting_data]
        size_percents = [fraction * 100.0 for fraction in self._size_quantiles]
        size_values = numpy.percentile(response_size, size_percents)
        for fraction, value in zip(self._size_quantiles, size_values):
            result["response_size_{}".format(fraction)] = float(value)

        return result

    def __generate_monitoring_stats(self):
        """
            Compute the median of the user CPU usage signal

            For every monitoring item the ``cpu-cpu-total_usage_user`` metric
            of ``localhost`` is taken; when it is absent, the same name with
            the custom prefix is tried. Items with neither are skipped.
        """
        _LOGGER.info("Generating monitoring stats")

        plain_name = 'cpu-cpu-total_usage_user'
        # the plain name goes first, so it wins when both are present
        candidate_names = (plain_name, _CUSTOM_PREFIX + plain_name)

        cpu_user = []
        for point in self.__monitoring_data:
            metrics = point['data']['localhost']['metrics']
            for name in candidate_names:
                if name in metrics:
                    cpu_user.append(metrics[name])
                    break

        return {
            "cpu_user": numpy.median(cpu_user),
        }

    def __generate_report(self, output_path, warmup_index, shutdown_index):
        """
            Draw the shooting and monitoring charts and save them to a file

            The first subplot shows per-code request counts over time, with
            dotted vertical lines at the warmup and shutdown cut points (when
            they are inside the data). It is followed by one subplot per
            monitoring chartset, in title order, sharing the same x axis.
            The x axis is time relative to the first shooting item and is
            limited to the span of the shooting data.

            ``warmup_index`` and ``shutdown_index`` are indexes into the
            sorted shooting data. The figure is saved to ``output_path`` and
            closed afterwards, even if drawing fails.
        """
        _LOGGER.info("Generating svg report")

        monitoring_chartsets = self.__get_monitoring_chartsets()
        min_x = self.__shooting_data[0]["ts"]  # sync start of shooting and start of monitoring
        max_x = self.__shooting_data[-1]["ts"]  # remove extra monitoring points after shooting

        seaborn.set(style="whitegrid", palette="Set2")
        # NOTE(review): despine() runs before the report figure exists —
        # TODO confirm it has the intended effect on the charts below.
        seaborn.despine()

        plot_count = len(monitoring_chartsets) + 1
        figure = plt.figure(figsize=(16, 3 * plot_count))
        try:
            # testing
            testing_plot = plt.subplot(plot_count, 1, 1)
            plt.title("RPS")
            plt.xlim(0, max_x - min_x)
            x, y = self.__get_shooting_coords("proto_code", min_x)
            for variant in x:
                plt.plot(x[variant], y[variant], label=variant)
            if warmup_index > 0:
                plt.axvline(x=self.__shooting_data[warmup_index]["ts"] - min_x, linestyle='dotted')
            if shutdown_index < len(self.__shooting_data):
                plt.axvline(x=self.__shooting_data[shutdown_index]["ts"] - min_x, linestyle='dotted')
            plt.legend(fontsize="x-small", loc="upper right")

            # monitoring
            # items() works on both Python 2 and 3, unlike iteritems()
            for plot_num, chartset_data in enumerate(sorted(monitoring_chartsets.items()), 1):
                chartset_title, signals = chartset_data

                plt.subplot(plot_count, 1, plot_num + 1, sharex=testing_plot)
                plt.title(chartset_title)
                plt.xlim(0, max_x - min_x)

                # signals is a set: sort it to keep line and legend order stable between runs
                for signal_name, signal_suffix in sorted(signals):
                    x, y = self.__get_monitoring_coords(signal_name, min_x, max_x)
                    plt.plot(x, y, label=signal_suffix)
                plt.legend(fontsize="x-small", loc="upper right")

            plt.tight_layout()
            plt.suptitle('Shooting results', fontsize=16, fontweight='bold')
            plt.subplots_adjust(top=0.96)
            plt.savefig(output_path)
        finally:
            # pyplot keeps every figure alive until it is closed explicitly
            plt.close(figure)

    def __find_monitoring_chartset(self, signal_prefix, signal_suffix):
        """
            Tune chartset content

            Some of signals should be skipped, and other should be distributed between
            two chartsets

            The custom prefix, if present, is removed from ``signal_prefix``
            first. Then:
              * for a prefix listed in ``_CHARTSETS`` the result is
                ``"<chart title> <rest of the prefix>"`` when one of the
                charts accepts ``signal_suffix``, and ``None`` (skip the
                signal) otherwise;
              * for any other prefix the prefix itself is returned, so the
                signal goes to a chartset of its own.
        """

        if signal_prefix.startswith(_CUSTOM_PREFIX):
            signal_prefix = signal_prefix[len(_CUSTOM_PREFIX):]

        # items() works on both Python 2 and 3, unlike iteritems()
        for chartset_prefix, chartset_data in _CHARTSETS.items():
            if not signal_prefix.startswith(chartset_prefix):
                continue

            for chartset_title, chartset_signals in chartset_data.items():
                # _ALL_ is a marker object, hence the identity check
                if chartset_signals is _ALL_ or signal_suffix in chartset_signals:
                    return "{} {}".format(chartset_title, signal_prefix[len(chartset_prefix):])

            # known prefix, but no chart accepts this suffix
            return None

        # unknown prefix
        return signal_prefix

    def __get_monitoring_chartsets(self):
        """
            Analyze monitoring signals and organize chartsets

            Returns a dict mapping chartset title to a set of
            ``(signal_name, signal_suffix)`` pairs. A signal name is split at
            its first underscore into prefix and suffix; a name without an
            underscore is treated as a prefix with an empty suffix. Signals
            with a falsy value and signals without a chartset are skipped.
        """

        chartsets = {}
        for p in self.__monitoring_data:
            metrics = p['data']['localhost']['metrics']
            # items() works on both Python 2 and 3, unlike iteritems()
            for signal_name, signal_value in metrics.items():
                if not signal_value:
                    continue

                # partition() never fails to unpack, unlike split("_", 1)
                # on a name that has no underscore
                signal_prefix, _, signal_suffix = signal_name.partition("_")
                chartset_title = self.__find_monitoring_chartset(signal_prefix, signal_suffix)
                if not chartset_title:
                    continue

                chartsets.setdefault(chartset_title, set()).add((signal_name, signal_suffix))

        return chartsets

    def __get_monitoring_coords(self, signal_name, min_x, max_x):
        x, y = [], []
        for p in self.__monitoring_data:
            metrics = p['data']['localhost']['metrics']
            timestamp = p['timestamp']
            if signal_name in metrics and min_x <= timestamp and timestamp <= max_x:
                x.append(timestamp - min_x)
                y.append(metrics[signal_name])
        return x, y

    def __get_shooting_coords(self, signal_name, min_x):
        x = {}
        y = {}
        for data in self.__shooting_data:
            timestamp = data["ts"]
            for variant, count in data["overall"][signal_name]["count"].iteritems():
                x.setdefault(variant, []).append(timestamp - min_x)
                y.setdefault(variant, []).append(count)
        return x, y

    def __get_shooting_warmup_index(self):
        if len(self.__shooting_data) == 0:
            raise Exception("__shooting_data is empty")
        start_time = self.__shooting_data[0]["ts"] + self.__warmup_time
        if self.__shooting_data[-1]["ts"] < start_time:
            raise Exception("Failed to find warmup_index")

        i, last_i = 0, len(self.__shooting_data) - 1
        while i <= last_i and self.__shooting_data[i]["ts"] < start_time:
            i += 1
        return i

    def __get_shooting_shutdown_index(self):
        stop_time = self.__shooting_data[-1]["ts"] - self.__shutdown_time
        if self.__shooting_data[0]["ts"] > stop_time:
            raise Exception("Unable to find shutdown index")

        i, last_i = len(self.__shooting_data) - 1, 0
        while i >= last_i and self.__shooting_data[i]["ts"] > stop_time:
            i -= 1
        return i + 1

    def __get_shooting_rps(self, shooting_data, code1=200, code2=300):
        for data in shooting_data:
            counts = data['overall']['proto_code']['count']
            yield sum(num for code, num in counts.iteritems() if code1 <= int(code) < code2)

    def __get_shooting_net_errors(self, shooting_data, code1=200, code2=300):
        for data in shooting_data:
            counts = data['overall']['net_code']['count']
            yield sum(num for code, num in counts.iteritems() if int(code) != 0)
