#!/usr/bin/env python

import argparse
import gzip
import logging
import msgpack
import pandas as pa
import sys
import json

from packaging import version

# Keys kept from every decoded msgpack map (see object_hook); all others are dropped.
FIELDS = (
    "switch",
    "invnum",
    "fqdn",
    "queue",
    "walle_project",
    "walle_tags",
    "dc",
    "Service",
    "planner_id",
    "Inv",
    "kernel",
)
# Lowest kernel considered "ready". Host kernels are compared against it
# after every '-' has been rewritten to '.' (see label_kernel).
MINIMAL_KERNEL = version.parse("4.19.119.30.1")

# Emit every record (DEBUG and up) to stdout as the bare message text.
# Note: getLogger() with no name is the *root* logger, configured at import time.
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter("%(message)s"))
handler.setLevel(logging.DEBUG)

log = logging.getLogger()
log.setLevel(logging.DEBUG)
log.addHandler(handler)


def object_hook(kwargs, fields=FIELDS):
    """Reduce a decoded msgpack map to the keys listed in ``fields``.

    Passed as ``object_hook`` to ``pa.read_msgpack`` so only the columns this
    script uses survive decoding.

    :param kwargs: mapping produced by the msgpack decoder.
    :param fields: collection of keys to keep (defaults to FIELDS).
    :returns: a new dict with only the whitelisted keys.
    """
    return {name: value for name, value in kwargs.items() if name in fields}


def label_kernel(row):
    """Return True when the row's kernel version is at least MINIMAL_KERNEL.

    ``row["kernel"]`` is stringified and every '-' is replaced by '.', so a
    release such as ``4.19.119-30.1`` compares as ``4.19.119.30.1``.

    A value that is not a valid version is treated as "not ready". This
    covers, for example, NaN for hosts missing from the kernel file after the
    left merge in mark_data, which stringifies to "nan". packaging < 22
    parsed such strings into a LegacyVersion that sorted below every real
    version, so the result was already False. packaging >= 22 raises
    InvalidVersion instead, which would abort the whole report.

    :param row: row-like object (Series or mapping) with a "kernel" entry.
    :returns: bool
    """
    kernel = str(row["kernel"]).replace('-', '.')
    try:
        parsed = version.parse(kernel)
    except version.InvalidVersion:
        return False
    return parsed >= MINIMAL_KERNEL


def label_switch_type(data):
    """Classify a switch from the readiness flags of its hosts.

    :param data: iterable of per-host flags (in this script, the
        ``is_kernel_ready`` values of one switch group).
    :returns: "ready" if every flag is truthy (an empty input counts as
        ready), "mixed" if only some are, otherwise "not_ready".
    """
    if not all(data):
        return "mixed" if any(data) else "not_ready"
    return "ready"


# Queues whose hosts are attributed to NOC datacenter 'sas1'; every other
# 'sas' queue is attributed to 'sas2'.
_SAS1_QUEUES = (
    'sas1.1.1', 'sas1.1.2', 'sas1.1.3', 'sas1.1.4',
    'sas1.2.1', 'sas1.2.2',
    'sas1.3.1', 'sas1.3.2', 'sas1.3.3', 'sas1.3.4',
    'sas1.4.1', 'sas1.4.2', 'sas1.4.3', 'sas1.4.4',
    'sas2.3.1', 'sas2.3.2', 'sas2.3.3', 'sas2.3.4',
    'sas2.4.1', 'sas2.4.2', 'sas2.4.3', 'sas2.4.4',
)


def get_dc(short_datacenter_name, short_queue_name):
    """Map a (datacenter, queue) pair to the NOC datacenter name.

    Only 'sas' is split further, into 'sas1' or 'sas2' depending on the
    queue. Any other datacenter name is returned unchanged.
    """
    if short_datacenter_name != 'sas':
        return short_datacenter_name
    return 'sas1' if short_queue_name in _SAS1_QUEUES else 'sas2'


def mark_data(netmon_data_raw, kernel_info_raw):
    """Join host inventory with kernel info and label kernel readiness.

    :param netmon_data_raw: data accepted by ``pa.DataFrame``; rows whose
        ``invnum`` is the empty string are dropped.
    :param kernel_info_raw: data accepted by ``pa.DataFrame`` with an
        ``fqdn`` column (and presumably ``kernel``; confirm against the
        kernel file producer).
    :returns: the merged DataFrame, with ``dc`` normalised through get_dc and
        two added columns: ``is_kernel_ready`` (per host, via label_kernel)
        and ``switch_type`` (per switch, via label_switch_type).
    """
    kernel_info = pa.DataFrame(data=kernel_info_raw)
    netmon_data = pa.DataFrame(data=netmon_data_raw)
    netmon_data = netmon_data[netmon_data.invnum != ""]

    # Left join: hosts absent from the kernel file are kept, and their
    # kernel-file columns are NaN.
    aggregated_data = netmon_data.merge(kernel_info, on="fqdn", how="left")
    aggregated_data["dc"] = aggregated_data.apply(lambda row: get_dc(row["dc"], row["queue"]), axis=1)
    aggregated_data["is_kernel_ready"] = aggregated_data.apply(label_kernel, axis=1)
    # transform() broadcasts the single per-switch label back onto each host row.
    aggregated_data["switch_type"] = aggregated_data.groupby("switch")["is_kernel_ready"].transform(label_switch_type)
    return aggregated_data


def get_filter(filter_type, filter_args):
    """Return a switch-level filter function by name.

    A filter selects the switches that have at least one matching host and
    returns *all* rows of those switches, so switch statistics are computed
    over complete switches.

    :param filter_type: "by_walle_projects" or "by_walle_tags".
    :param filter_args: for "by_walle_projects", project names to match. For
        "by_walle_tags", entries of the form "tag" or "tag@dc", where the dc
        "sas" expands to both "sas1" and "sas2". ``None`` or an empty list
        matches nothing.
    :returns: a callable ``f(DataFrame) -> DataFrame``, or None for an
        unknown ``filter_type``.
    """
    if filter_args is None:
        filter_args = []

    def by_walle_projects(data):
        # Switches with at least one host in one of the requested projects.
        switches = data[data.walle_project.isin(filter_args)]["switch"].unique()
        return data[data.switch.isin(switches)]

    def by_walle_tags(data):
        # Start from "nothing matches" so an empty filter_args list yields an
        # empty frame instead of failing on data[None].
        mask = pa.Series(False, index=data.index)
        for filter_arg in filter_args:
            tokens = filter_arg.split('@')
            if len(tokens) == 2:
                tag, dc = tokens
                # "sas" is split into two NOC datacenters (see get_dc).
                dcs = ["sas1", "sas2"] if dc == "sas" else [dc]
                # apply() runs the lambda immediately, so using the loop
                # variable `tag` here is safe.
                cond = data.walle_tags.apply(lambda tags: tag in tags) & data.dc.isin(dcs)
            else:
                cond = data.walle_tags.apply(lambda tags: filter_arg in tags)
            # astype(bool): apply() on an empty column keeps object dtype,
            # which would not be treated as a boolean mask.
            mask = mask | cond.astype(bool)
        switches = data[mask]["switch"].unique()
        return data[data.switch.isin(switches)]

    filters = {"by_walle_projects": by_walle_projects,
               "by_walle_tags": by_walle_tags}

    return filters.get(filter_type)


def make_stat(aggregated_data, filter_f, args):
    """Filter the labelled data and group switch names by readiness.

    :param aggregated_data: DataFrame as returned by mark_data (needs
        ``switch`` and ``switch_type`` columns).
    :param filter_f: callable taking and returning a DataFrame.
    :param args: parsed CLI namespace. If ``dump_switch`` is not None, that
        switch's rows are printed. If ``dump_not_ready_switches`` is true,
        every "mixed" and "not_ready" switch is printed with its rows,
        followed by a blank line.
    :returns: dict mapping "ready", "mixed" and "not_ready" to lists of
        unique switch names, in order of first appearance.
    """
    data = filter_f(aggregated_data)
    if args.dump_switch is not None:
        print(data[data.switch.isin([args.dump_switch])])
    stat = {switch_type: data[data.switch_type == switch_type]["switch"].unique().tolist()
            for switch_type in ("ready", "mixed", "not_ready")}
    if args.dump_not_ready_switches:
        for sw in stat["mixed"] + stat["not_ready"]:
            print(sw)
            print(data[data.switch.isin([sw])])
            # Blank separator line. A bare `print` is a no-op expression on
            # Python 3, and `print()` would print "()" on Python 2.
            print("")
    return stat


def check_data(current_data):
    """Exit with status 2 if the report contains no switches at all.

    :param current_data: mapping of category name to a sized collection of
        switch names.
    """
    total = 0
    for switches in current_data.values():
        total += len(switches)
    if total == 0:
        log.critical("Null current data")
        sys.exit(2)


def check_diff(previous_data, current_data, thresholds=None):
    """Exit with status 2 if any category changed more than its threshold allows.

    For every category ``k`` in ``previous_data`` the symmetric difference
    between the previous and current switch sets is computed. The category
    is flagged when ``len(previous) < len(diff) / threshold``, i.e. when more
    than ``threshold`` times the previous number of switches changed.
    Categories without a non-zero threshold are only logged, never flagged.

    :param previous_data: mapping category -> iterable of switch names.
    :param current_data: same shape. A category missing here is treated as
        empty, so every previous switch counts as changed.
    :param thresholds: optional mapping category -> number; the sign is
        ignored.
    """
    if thresholds is None:
        thresholds = {}
    excesses = {}
    for k in previous_data:
        threshold = abs(thresholds.get(k, 0))
        p = set(previous_data[k])
        c = set(current_data.get(k, ()))
        log.debug("Previous '%s': %s", k, p)
        log.debug("Current '%s': %s", k, c)
        diff = p.symmetric_difference(c)

        if diff:
            log.debug("'%s' symmetric diff: %s", k, diff)
        else:
            log.debug("no diff in '%s'", k)

        # A zero threshold means "not checked" (and avoids dividing by zero).
        if threshold and len(p) < len(diff) / threshold:
            excess = len(diff) / threshold - len(p)
            log.debug("'%s' threshold(%s) exceeded: %s", k, threshold, excess)
            excesses[k] = diff

    for k, diff in excesses.items():
        log.critical("Threshold exceeded for %s: %s", k, diff)

    if excesses:
        sys.exit(2)


def short_report(args):
    """Build the "short" report: switch names grouped by kernel readiness.

    Reads the gzip-compressed msgpack netmon and kernel files, labels them
    with mark_data, applies the filter chosen by ``args.filter_function``,
    and writes the result twice: as JSON to ./result.json and as
    gzip-compressed msgpack to ``args.output_file``.

    NOTE(review): ``pa.read_msgpack`` was removed in pandas 1.0, so this
    needs an older pandas.
    """
    log.info("Creating short report")
    filter_func_name = args.filter_function.replace("-", "_")
    filter_func = get_filter(filter_func_name, args.filter_args)
    netmon_path, kernel_path, output_path = args.netmon_file, args.kernel_file, args.output_file

    log.info("Filter data with %s filter with args: %s " % (filter_func_name, args.filter_args))

    sources = (
        ("Reading netmon data from %s", netmon_path),
        ("Reading kernel version info from %s", kernel_path),
    )
    loaded = []
    for message, path in sources:
        log.info(message % path)
        with gzip.open(path, "rb") as stream:
            loaded.append(pa.read_msgpack(stream, object_hook=object_hook))
    netmon_data_raw, kernel_info_raw = loaded

    log.info("Preparing report")
    report = make_stat(mark_data(netmon_data_raw, kernel_info_raw), filter_func, args)

    # Plain-JSON copy, written to the current working directory.
    with open('result.json', 'w') as json_file:
        json.dump(report, json_file)

    with gzip.open(output_path, "wb") as packed_file:
        msgpack.pack(report, packed_file)
    log.info("Short report prepared %s" % output_path)


def diff_report(args):
    """Compare two "short" reports, exiting with status 2 on empty or over-threshold data.

    ``args.thresholds`` is a list of "category=value" strings such as
    ``["ready=0.05"]``; each value is parsed with float().
    """
    log.info("Check diff")
    previous_data_path, current_data_path = args.previous, args.current

    thresholds = {}
    for spec in args.thresholds:
        parts = spec.split('=')
        thresholds[parts[0]] = float(parts[1])
    log.debug("thresholds: %s" % thresholds)

    loaded = []
    for path in (previous_data_path, current_data_path):
        with gzip.GzipFile(path, "rb") as stream:
            loaded.append(pa.read_msgpack(stream))
    previous_data, current_data = loaded

    check_data(current_data)
    check_diff(previous_data, current_data, thresholds)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Require a sub-command. Otherwise Python 3 argparse leaves args.func
    # unset and the script dies with AttributeError instead of a usage error.
    # dest is needed for the "required" error message.
    subparsers = parser.add_subparsers(help="mode", dest="mode")
    subparsers.required = True

    short_report_parser = subparsers.add_parser("short", help="short report")
    short_report_parser.add_argument("-f", "--filter-by", choices=["by-walle-tags", "by-walle-projects"], dest="filter_function", required=True)
    # The filters need at least one argument; without one they fail at runtime.
    short_report_parser.add_argument("-fa", "--filter-args", nargs="+", dest="filter_args", required=True)
    # These files are always opened, so they are mandatory.
    short_report_parser.add_argument("-nf", "--netmon-file", required=True)
    short_report_parser.add_argument("-kf", "--kernel-file", required=True)
    short_report_parser.add_argument("-of", "--output-file", required=True)
    short_report_parser.add_argument("--dump-switch", default=None)
    short_report_parser.add_argument("--dump-not-ready-switches", action='store_true')
    short_report_parser.set_defaults(func=short_report)

    diff_report_parser = subparsers.add_parser("diff", help="diff report")
    diff_report_parser.add_argument("-p", "--previous", required=True)
    diff_report_parser.add_argument("-c", "--current", required=True)
    diff_report_parser.add_argument(
        "--thresholds", nargs="+", dest="thresholds", default=[],
        help="space-separated key=value pairs, where value is the allowed churn "
             "as a fraction of the previous size, e.g. 'ready=0.05 mixed=0.1'")
    diff_report_parser.set_defaults(func=diff_report)

    args = parser.parse_args()
    args.func(args)
