#!/usr/bin/python3
from datetime import datetime
from functools import partial
import argparse
import itertools
import json
import os
import os.path
import re
import socket
import sys
import time


# Name of the machine running this script; stamped into every JSON record ('h').
HOSTNAME = socket.gethostname()
# Nanoseconds per second (10**9).
NS = 1000 ** 3


def read_int(path):
    """Return the whole contents of the file at *path* parsed as an ``int``.

    ``int()`` tolerates surrounding whitespace such as a trailing newline;
    any other non-numeric content raises ``ValueError``.
    """
    with open(path) as handle:
        contents = handle.read()
    return int(contents)


def read_field(path, k):
    """Return the value for key *k* from a file of ``key value`` lines.

    Each line is split on whitespace. The value comes from the first line
    that has exactly two tokens and whose first token equals *k*, and it is
    returned as a string. Matching is on the whole key, so ``k="burst"``
    does not pick up a ``burst_load`` line. Lines with any other shape are
    skipped.

    Raises:
        LookupError: if no line carries key *k*.
    """
    with open(path) as f:
        for line in f:
            fields = line.split()
            if len(fields) == 2 and fields[0] == k:
                return fields[1]
    raise LookupError("{!r} not found in {}".format(k, path))


def columns(xss):
    """Left-justify a table of strings so each column has a uniform width.

    *xss* is any iterable of rows, and each row is an iterable of strings.
    Every cell is padded with spaces to the width of the widest cell in its
    column. Rows may have different lengths; a column's width is taken from
    the rows that reach it. Empty input produces no rows.

    Returns:
        A lazy iterator yielding, per input row, an iterator of padded cells.
    """
    # Materialize first: the rows are walked twice (widths, then padding),
    # which a one-shot iterator would not survive.
    rows = [list(xs) for xs in xss]
    width = []
    for xs in rows:
        for i, x in enumerate(xs):
            if i == len(width):
                # First row to reach column i (possible with ragged rows).
                width.append(len(x))
            else:
                width[i] = max(width[i], len(x))
    return ((x.ljust(width[i]) for i, x in enumerate(xs)) for xs in rows)


def exists(path):
    """Report whether ``os.stat`` can find *path*.

    Only ``FileNotFoundError`` is mapped to ``False``. Any other ``OSError``
    raised by ``os.stat``, such as a permission error, propagates.
    """
    try:
        os.stat(path)
    except FileNotFoundError:
        return False
    return True


def _msr_cgroup(name, state):
    """Sample one cgroup's CPU counters and diff them against the previous sample.

    Reads ``cpuacct.usage_user``, ``cpuacct.usage_sys`` and ``cpuacct.wait``,
    plus the ``burst_usage``, ``burst_load`` and ``h_throttled_time`` entries
    of ``cpu.stat``, from the cgroup directory *name*. It also takes a
    CLOCK_MONOTONIC timestamp in nanoseconds.
    NOTE(review): ``cpuacct.wait`` and these ``cpu.stat`` keys do not look
    like mainline cgroup-v1 files; presumably a vendor kernel, so confirm.

    Args:
        name: path of the cgroup directory.
        state: the dict returned by the previous call for this cgroup, or an
            empty dict on the first call. It is updated in place.

    Returns:
        ``(retval, state)``. ``retval`` is ``None`` when *state* was empty,
        because there is no previous sample to diff against. Otherwise it is
        ``(name, du, ds, dbu, uwait, cwait, throttled, dt)``:

        - du, ds: change in the user / sys usage counters;
        - dbu: change in ``burst_usage``;
        - uwait: change in ``cpuacct.wait`` minus cwait and throttled,
          clamped at 0;
        - cwait: ``burst_load`` delta minus ``burst_usage`` delta, clamped at 0;
        - throttled: ``h_throttled_time`` delta, clamped at 0;
        - dt: elapsed monotonic time in nanoseconds.

        The counters are presumably also in nanoseconds, since ``print_table``
        divides them by ``dt``; confirm against the kernel.

    Raises:
        OSError: a counter file is missing or unreadable.
        ValueError: a counter file does not parse.
        KeyError: ``cpu.stat`` lacks one of the required entries.
    """
    pjo = partial(os.path.join, name)

    # Take every reading before touching *state*, so that a failed read
    # (e.g. the cgroup vanished mid-sample) leaves the previous sample intact.
    t = int(NS * time.clock_gettime(time.CLOCK_MONOTONIC))
    u = read_int(pjo("cpuacct.usage_user"))
    s = read_int(pjo("cpuacct.usage_sys"))
    w = read_int(pjo("cpuacct.wait"))
    with open(pjo("cpu.stat")) as f:
        cpu_stat = {}
        for x in f:
            k, v = x.split()
            cpu_stat[k] = v
    bu = int(cpu_stat["burst_usage"])
    bl = int(cpu_stat["burst_load"])
    r = int(cpu_stat["h_throttled_time"])

    if state:
        du = u - state['u']
        ds = s - state['s']
        dw = w - state['w']

        dbu = bu - state['bu']
        dbl = bl - state['bl']

        dr = max(0, r - state['r'])

        # Burst load not covered by burst usage is reported as "cwait".
        dcw = max(0, dbl - dbu)
        # The rest of the wait time, after cwait and throttling are removed.
        duw = max(0, dw - dcw - dr)

        retval = (name, du, ds, dbu, duw, dcw, dr, t - state['t'])
    else:
        retval = None

    state['u'] = u
    state['s'] = s
    state['bu'] = bu
    state['bl'] = bl
    state['w'] = w
    state['r'] = r
    state['t'] = t
    return (retval, state)


def print_json(xss):
    """Write one JSON object per sample row to stdout, one per line, then flush.

    Each row ``(cgroup, du, ds, db, dw, dc, dr, dt)`` becomes an object with
    those values under short keys. Each object also gets the host name
    (``h``) and a wall-clock UNIX timestamp (``t``). The timestamp is taken
    once, so all rows of a call share it.
    """
    stamp = datetime.now().timestamp()
    out = sys.stdout
    for row in xss:
        cgrp, du, ds, db, dw, dc, dr, dt = row
        record = {
            'cg': cgrp,
            'h': HOSTNAME,
            'du': du,
            'ds': ds,
            'db': db,
            'dw': dw,
            'dc': dc,
            'dr': dr,
            'dt': dt,
            't': stamp,
        }
        out.write(json.dumps(record) + '\n')
    out.flush()


def print_table(xss):
    """Print sample rows as an aligned text table; print nothing if *xss* is empty.

    Each row ``(cgroup, du, ds, db, dw, dc, dr, dt)`` is shown as its cgroup
    followed by ``value / dt`` for each of the six deltas. Values use two
    decimals and a ``c`` suffix, which presumably means CPUs. The table is
    preceded by an 80-character ``=`` rule and the current local time in
    ISO format.
    """
    rows = [('cgroup', 'user', 'sys', 'busage', 'uwait', 'cwait', 'throt')]
    for cgrp, du, ds, db, dw, dc, dr, dt in xss:
        rates = tuple('{:.2f}c'.format(v / dt) for v in (du, ds, db, dw, dc, dr))
        rows.append((cgrp,) + rates)

    # Only the header row: nothing was sampled, so stay silent.
    if len(rows) < 2:
        return

    write = sys.stdout.write
    write('=' * 80 + '\n')
    write(datetime.now().isoformat() + '\n')
    for cells in columns(rows):
        write('  '.join(cells) + '\n')
    sys.stdout.flush()


def _regex(pattern):
    """argparse ``type=`` callback: compile *pattern*, turning a bad regex into a usage error."""
    try:
        return re.compile(pattern)
    except re.error as err:
        # re.error is not a ValueError, so argparse would otherwise let it
        # escape as a raw traceback instead of printing usage.
        raise argparse.ArgumentTypeError(
            "invalid regex {!r}: {}".format(pattern, err)) from err


def _interval(text):
    """argparse ``type=`` callback: parse a non-negative float (rejects NaN)."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "invalid float value: {!r}".format(text)) from None
    # A negative interval would only fail later, inside time.sleep().
    if not value >= 0:
        raise argparse.ArgumentTypeError(
            "must be a non-negative number, got {!r}".format(text))
    return value


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: argument list to parse (without the program name). ``None``
            means ``sys.argv[1:]``, as with ``ArgumentParser.parse_args``.

    Returns:
        An ``argparse.Namespace`` with these attributes:

        - ``cgrp``: root cgroup path;
        - ``i``: float interval in seconds, always >= 0;
        - ``n``: compiled regex;
        - ``f``: ``'json'`` or ``'table'``;
        - ``t``: float timeout in seconds, or ``None``.

    Exits with status 2 (via argparse) on invalid arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("cgrp", default="/sys/fs/cgroup/cpu/", nargs='?', help="path to root cpu cgroup")
    parser.add_argument("-i", default=1.0, type=_interval, help="print interval (seconds)")
    parser.add_argument("-n", default='.*', type=_regex, help="cgroup path regex filter")
    parser.add_argument("-f", default='table', type=str, choices=['json', 'table'], help="output format")
    parser.add_argument("-t", default=None, type=float, help="timeout")
    return parser.parse_args(argv)


def scdir(path):
    """Iterate over *path* itself, then each immediate subdirectory of it.

    Only one level is listed; deeper descendants are not visited. The
    ``os.scandir`` call happens when this function is called, not on first
    iteration, so a missing *path* raises immediately.
    """
    children = (entry.path for entry in os.scandir(path) if entry.is_dir())
    return itertools.chain((path,), children)


def main():
    """Sample matching cgroups once per interval and print the deltas.

    Each round walks the root cgroup and its immediate subdirectories and
    keeps those whose path matches ``-n``. It prints one row for every
    cgroup that has a previous sample to diff against, in the ``-f`` format.
    The loop ends once ``-t`` seconds have elapsed; without ``-t`` it runs
    forever.
    """
    args = parse_args()

    interval = args.i
    printer = print_table if args.f == "table" else print_json

    # Compare against None so that "-t 0" means "no time at all" instead of
    # being mistaken for "no timeout".
    deadline = None
    if args.t is not None:
        deadline = time.clock_gettime(time.CLOCK_MONOTONIC) + args.t

    root = os.path.realpath(args.cgrp)

    states = {}
    while deadline is None or time.clock_gettime(time.CLOCK_MONOTONIC) < deadline:
        nstates = {}
        lines = []
        for cgrp in scdir(root):
            if not args.n.search(cgrp):
                continue
            try:
                retval, state = _msr_cgroup(cgrp, states.get(cgrp, {}))
            except (OSError, ValueError, LookupError):
                # Best effort: cgroups can disappear between listing and
                # reading, and directories lacking the expected counter files
                # are skipped. Other exceptions are real bugs and propagate.
                continue
            # Only successfully sampled cgroups carry state into the next round.
            nstates[cgrp] = state
            if retval is not None:
                lines.append(retval)
        states = nstates
        printer(lines)
        time.sleep(interval)


# Run the monitor only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
