#!/usr/bin/python
from __future__ import print_function
from bcc import BPF
import multiprocessing
import argparse
import os
from time import sleep, strftime
from collections import defaultdict, Counter
import re
from heapq import nlargest

# Usage examples appended verbatim to --help output (currently empty).
examples = """examples:

"""
# define BPF program
# Template for the generated eBPF C source: a raw tracepoint on sched_switch
# records, into a single-slot per-CPU array, an id describing the task being
# switched in. $BLOCK is substituted with one of the snippets below.
bpf_text = """
#include <uapi/linux/ptrace.h>
#include <linux/sched.h>
#include <linux/nsproxy.h>
#include <linux/pid_namespace.h>
#include <linux/cgroup.h>

BPF_PERCPU_ARRAY(pcpu_stats, s32, 1);

RAW_TRACEPOINT_PROBE(sched_switch)
{
    struct task_struct *prev = (struct task_struct *)ctx->args[1];
    struct task_struct *next = (struct task_struct *)ctx->args[2];
    u32 zero = 0;

    struct css_set *next_css = (struct css_set *)next->cgroups;
    struct cgroup *next_cgrp = (struct cgroup *)next_css->dfl_cgrp;

    int next_id = next_cgrp->kn->id.ino;

    $BLOCK

    return 0;
}
"""

# Per-cgroup mode: store the kernfs inode number of the incoming task's
# default (unified-hierarchy) cgroup.
bpf_per_cgrp_txt = """
    pcpu_stats.update(&zero, &next_id);
"""

# Per-pid / per-tgid mode: store the incoming task's $PID (pid or tgid),
# but only when it belongs to the cgroup $CGRP; -1 marks "skip" and is
# filtered out in userspace (handle_sample keeps only ids > 0).
bpf_per_pid_txt = """
    s32 id =  next->$PID;
    if (next_id != $CGRP) {
        id = -1; // skip
    }
    pcpu_stats.update(&zero, &id);
"""

# Output table headers; AAT = average active threads.
per_cgroup_header = ("TIME", "ID", "CGROUP", "AAT(%)", "AAT(cores)", "Limit(cores)")
per_tgid_header = ("TIME", "TGID", "COMM", "AAT(%)", "AAT(cores)", "Limit(cores)")
# Mount points of the unified (v2) and cpu (v1) cgroup hierarchies.
un_pth = "/sys/fs/cgroup/unified/"
cpu_pth = "/sys/fs/cgroup/cpu/"
# Aggregation key -> list of per-sample active-thread counts for the interval.
stats = defaultdict(list)
nproc = multiprocessing.cpu_count()
# cgroup inode -> (cgroup name, cpu limit in cores); seeded with the root.
cgroup_cache = {1: ("root", nproc)}
known_cgrps = []  # NOTE(review): appears unused in this file


def un_to_cpu_cgrp_path(s):
    """Translate a unified-hierarchy porto cgroup path to its cpu-hierarchy form.

    Paths not rooted at "porto" are returned unchanged.  For porto paths the
    "/" separators alternate: odd path components are joined with "%", even
    ones with "/" (e.g. "porto/a/b/c" -> "porto%a/b%c").
    """
    parts = s.split("/")  # split once; the original split twice
    if parts[0] != "porto":
        return s
    pieces = []
    for i, p in enumerate(parts):
        if i == 0:
            sep = ""
        elif i % 2 == 1:
            sep = "%"
        else:
            sep = "/"
        pieces.append(sep + p)
    # join instead of repeated += (string concat in a loop is quadratic)
    return "".join(pieces)


def cgrp_path_to_ino(s):
    """Return the inode number of the porto cgroup directory at relative
    path *s* under the unified hierarchy (the kernel-side cgroup id)."""
    return os.stat(os.path.join(un_pth, "porto", s)).st_ino


def get_cpu_limit(p):
    """Return the CPU limit, in cores, for cgroup-v1 cpu-hierarchy path p.

    Reads cpu.cfs_quota_us / cpu.cfs_period_us; when the quota is unset
    (negative) or either file is unreadable, falls back to the parent
    cgroup, bottoming out at the whole machine (nproc) for the root.
    """
    p = p.lstrip("/")
    if not p:
        # Root of the hierarchy: no quota, the whole machine is the limit.
        return float(nproc)
    try:
        # p is already stripped above; the original re-stripped it per join.
        with open(os.path.join(cpu_pth, p, "cpu.cfs_quota_us")) as f:
            quota = float(f.read())
        if quota < 0:
            # -1 means "no quota set"; treat like a missing file and recurse.
            raise IOError
        with open(os.path.join(cpu_pth, p, "cpu.cfs_period_us")) as f:
            period = float(f.read())
        if period < 0:
            raise IOError
        return quota / period
    except IOError:
        # Inherit the limit from the parent cgroup.
        return get_cpu_limit("/".join(p.split("/")[:-1]))


def get_cgrp_info(i):
    """Resolve cgroup inode *i* to (name, cpu_limit).

    Serves from cgroup_cache when possible; otherwise walks the unified
    hierarchy, caching every newly seen cgroup along the way.  Returns
    (str(i), -1) when no directory with that inode is found.
    """
    hit = cgroup_cache.get(i)
    if hit:
        return hit
    for root, dirs, _ in os.walk(un_pth):
        for name in dirs:
            path = os.path.join(root, name)
            ino = os.stat(path).st_ino
            if ino in cgroup_cache:
                continue
            rel_name = path.replace(un_pth, "")
            limit = get_cpu_limit(un_to_cpu_cgrp_path(rel_name))
            cgroup_cache[ino] = (rel_name, limit)
            if ino == i:
                return rel_name, limit
    return str(i), -1


def get_comm_by_pid(pid):
    """Return the command name of *pid* from /proc, or str(pid) if the
    process is gone or /proc is unreadable."""
    try:
        with open('/proc/'+ str(pid) +'/comm') as f:
            return f.readline().rstrip()
    # Narrowed from a bare `except:`, which also swallowed
    # KeyboardInterrupt/SystemExit; only filesystem errors are expected here.
    except (IOError, OSError):
        return str(pid)


def get_tgid_by_pid(pid):
    """Return the thread-group id of *pid* as a string, read from
    /proc/<pid>/status.  Returns "exited" when the Tgid line is missing
    and "except" when the status file cannot be read."""
    try:
        with open('/proc/'+ str(pid) +'/status') as f:
            for row in f:
                if row.startswith('Tgid:'):
                    return row.split()[1]
        return "exited"
    except IOError:
        return "except"


def percentile(p, lst):
    """Return the p-th percentile of lst: the smallest of the top
    (100-p)% (+1) elements, so p=100 gives the max and p=0 the min."""
    # TODO: replace with faster quick select
    keep = len(lst) * (100 - p) // 100 + 1
    return nlargest(keep, lst)[-1]


def columns(xss):
    """Pad every cell of the row tuples in *xss* with spaces to its column's
    maximum width.  Returns a lazy generator of row generators."""
    widths = [0] * len(xss[0])
    for row in xss:
        for i, cell in enumerate(row):
            if len(cell) > widths[i]:
                widths[i] = len(cell)
    return (
        (cell.ljust(widths[i]) for i, cell in enumerate(row))
        for row in xss
    )


def handle_sample(cpu_entries, key=lambda x: x):
    """Aggregate one per-CPU sample into a Counter keyed by key(id).

    Non-positive ids are dropped: the BPF side stores -1 (and the map
    default 0) to mean "skip this CPU".  Removed the unused
    `tmp = defaultdict(int)` local from the original.
    """
    return Counter(key(k) for k in map(int, cpu_entries) if k > 0)


def parse_args():
    """Define the command-line interface and parse sys.argv."""
    # (flag, add_argument keyword arguments) in display order
    specs = [
        ("-d", dict(action="store_true", help="print ebpf and exit")),
        ("-G", dict(action="store_true", help="per TGID statistics")),
        ("-P", dict(action="store_true", help="per PID statistics")),
        ("-F", dict(help="filter by COMM/CGROUP field (rgxp)")),
        ("-t", dict(default=0, type=float,
                    help="filter AAT by threshold (cores)")),
        ("-r", dict(default=-1.0, type=float,
                    help="filter AAT by threshold (perc)")),
        ("-i", dict(default=1, type=int,
                    help="print interval (seconds)")),
        ("-p", dict(default=100, type=int,
                    help="take pth percentile of aat sample")),
    ]
    ap = argparse.ArgumentParser(
        description="",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=examples)
    for flag, kwargs in specs:
        ap.add_argument(flag, **kwargs)
    return ap.parse_args()


def main():
    """Entry point: specialize the BPF template from CLI flags, load it,
    then sample the per-CPU map every 5ms and print aggregated stats once
    per interval."""
    global bpf_text
    args = parse_args()

    # The BPF program uses RAW_TRACEPOINT_PROBE; bail out if unsupported.
    if not BPF.support_raw_tracepoint():
        exit(1)

    filterf = lambda x: True  # row filter applied to the COMM/CGROUP column
    keyf = lambda x: x        # maps a raw per-CPU map value to an aggregation key

    if args.P or args.G:
        # Per-task modes require a concrete cgroup path to filter on.
        if not args.F:
            print("-F is required")
            exit(1)

        cgrp_ino = cgrp_path_to_ino(args.F)
        _, cgrp_cpu_limit = get_cgrp_info(cgrp_ino)

        bpf_text = bpf_text.replace('$BLOCK', bpf_per_pid_txt)
        if args.G:
            # Aggregate by thread group id; resolve comm at print time.
            bpf_text = bpf_text.replace('$PID', 'tgid')
            expand_key = lambda tgid: (tgid, get_comm_by_pid(tgid), cgrp_cpu_limit)
        else:
            # Aggregate individual pids under their (tgid, comm) pair.
            bpf_text = bpf_text.replace('$PID', 'pid')
            keyf = lambda pid: (get_tgid_by_pid(pid), get_comm_by_pid(pid))
            expand_key = lambda tgid_comm: tgid_comm + (cgrp_cpu_limit,)
        bpf_text = bpf_text.replace('$CGRP', str(cgrp_ino))
    else:
        # Per-cgroup mode: -F (if given) is a regex over the cgroup name.
        if args.F:
            filterf = re.compile(args.F).search
        expand_key = lambda cgrp_ino: (cgrp_ino,) + get_cgrp_info(cgrp_ino)
        bpf_text = bpf_text.replace('$BLOCK', bpf_per_cgrp_txt)

    # -d: dump the generated BPF C source and exit without loading it.
    if args.d:
        print(bpf_text)
        exit()

    if args.G or args.P:
        header = per_tgid_header
    else:
        header = per_cgroup_header

    # load BPF program
    b = BPF(text=bpf_text)

    # output
    print("Tracing average active threads... Hit Ctrl-C to end.")
    i = 0  # samples taken since the last print

    while 1:
        try:
            # ~200 samples/sec; each per-CPU slot holds the id recorded by
            # the most recent sched_switch on that CPU.
            sleep(0.005)
            items = next(iter(b["pcpu_stats"].values()), [])
            for k, c in handle_sample(items, key=keyf).items():
                stats[k].append(c)
            i += 1
            if i < args.i*200:
                continue
            # Interval elapsed: reduce samples and render one table.
            ts = strftime("%H:%M:%S")
            lines = [header]
            # Collapse each key's sample list to the requested percentile (-p).
            istats = ((k, percentile(args.p, v)) for k, v in stats.items())
            for k, aat in sorted(istats, key=lambda kv: kv[1]):
                if aat < args.t:
                    continue  # below the -t cores threshold
                k, entity, cpu_limit = expand_key(k)
                if not filterf(entity):
                    continue
                # Abbreviate very long cgroup/comm names for display.
                if len(entity) > 100:
                    entity = entity[:50]+'....'+entity[-50:]
                # -1 marks "limit unknown" (get_cgrp_info miss).
                if cpu_limit <= 0:
                    aat_perc = -1
                else:
                    aat_perc = round(100.0*aat/cpu_limit)
                if aat_perc < args.r:
                    continue  # below the -r percent threshold
                line = (ts, k, entity, aat_perc, aat, cpu_limit)
                lines.append(tuple(map(str, line)))
            print()
            for x in columns(lines):
                print("  ".join(x))
            # Reset for the next interval.
            i = 0
            stats.clear()
        except KeyboardInterrupt:
            # Zero the shared map before exiting.
            b["pcpu_stats"].clear()
            exit()

# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
