import argparse
import subprocess as subproc
import sys
import time

import py
import simplejson
import requests

# Result: {
#   v: <VERSION> (1),
#   stats: [
#       <time passed>, <labels dict>, <value>
#   ]
# }
#
# <time passed> should be None if the stat is an absolute value (e.g. current memory usage, copier
# shared resources, etc.).
#
# For derived stats <time passed> should be the time in seconds between the previous report and the current one.
# For example, skycore start_cnt shows how many times skycore has been started in the last <time passed> seconds.
#
# <value> could be integer or float, other values will be ignored on skystat aggregator.
#
# To compute derived stats this plugin uses a state file with previous values and timestamps. Absolute values are
# reported as-is, without being stored in the state file.
#
# DO NOT USE THESE LABELS: project, service and cluster. They are not allowed in solomon. That's why we report the
# skycore service as "component" here.
#
# Final report looks like this (please update this comment with any new stats!):
# {
#   v: 1,
#   stats: [
#       (skycore basic stats)
#       42, {skycore: start_cnt}, 12
#
#       (skycore rusage)
#     None, {rusage: rss, proc: skycore}, 12
#       42, {rusage: cpu, proc: skycore}, 12
#
#       (skycore services stats)
#       42, {namespace: skynet, component: <EACH_SKYCORE_SERVICE>, start_cnt: service_main}, 123
#       42, {namespace: skynet, component: <EACH_SKYCORE_SERVICE>, start_cnt: service_check}, 123
#       42, {namespace: skynet, component: <EACH_SKYCORE_SERVICE>, start_cnt: service_stop}, 123
#
#       (skycore services rusage)
#       42, {namespace: skynet, component: <EACH_SKYCORE_SERVICE>, rusage: cpu, proc: service_<EACH_PROCESS_KIND>}, 123  # noqa
#     None, {namespace: skynet, component: <EACH_SKYCORE_SERVICE>, rusage: rss, proc: service_main}, 123
#
#       (skybone stats)
#       42, {namespace: skynet, component: skybone, tracker: pkt_sent}, 123             (not working atm)
#       42, {namespace: skynet, component: skybone, tracker: pkt_recv}, 321             (not working atm)
#       42, {namespace: skynet, component: skybone, tracker: pkt_lost}, 321             (not working atm)
#     None, {namespace: skynet, component: skybone, storage: file_cnt}, 321
#     None, {namespace: skynet, component: skybone, storage: file_size}, 321
#     None, {namespace: skynet, component: skybone, storage: data_size}, 321
#     None, {namespace: skynet, component: skybone, storage: resource_cnt}, 321
#       42, {namespace: skynet, component: skybone, proto: skybit, net_bytes: in}, 321
#       42, {namespace: skynet, component: skybone, proto: skybit, net_bytes: ou}, 321
#       42, {namespace: skynet, component: skybone, proto: skybit, connects: in}, 321
#       42, {namespace: skynet, component: skybone, proto: skybit, connects: ou}, 321
#       42, {namespace: skynet, component: skybone, rusage: cpu, proc: self}, 321
#       42, {namespace: skynet, component: skybone, rusage: cpu, proc: chld}, 321
#   ]
# }


def parse_args():
    """Parse the plugin's command line.

    Options:
        -f / --format   output format, ``pprint`` (default) or ``msgpack``
        --workdir       directory that holds the state file (required)
        --skyctl        path to the skyctl binary
        --skybonectl    path to the skybone-ctl binary
        --show-saved    flag, false unless given

    Returns the ``argparse.Namespace`` built from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()

    cli.add_argument('-f', '--format', default='pprint', choices=('pprint', 'msgpack'))
    cli.add_argument('--workdir', required=True)

    # Both tool locations are plain string options that differ only in their default.
    for flag, fallback in (
        ('--skyctl', '/skynet/startup/skyctl'),
        ('--skybonectl', '/skynet/tools/skybone-ctl'),
    ):
        cli.add_argument(flag, default=fallback)

    cli.add_argument('--show-saved', action='store_true')

    return cli.parse_args()


def handle_skycore_stats(now, data):
    """Flatten decoded skycore statistics into a list of stat tuples.

    :param now: timestamp attached to every derived (counter-like) stat.
    :param data: mapping that may contain a ``'main'`` dict with skycore's own
        counters and a ``'svcs'`` dict of ``{namespace: {service: {raw_key: value}}}``.
    :return: list of ``(ts, labels, value)`` tuples; ``ts`` is ``now`` for derived
        stats and ``None`` for absolute ones (rss). Keys missing from ``data`` are
        silently skipped.
    """
    stats = []

    if 'main' in data:
        main_stats = data['main']

        for ts, raw_key, labels, mult in (
            (now, 'skycore_start_count', {'skycore': 'start_cnt'}, 1),

            # We report 100 x seconds, thus on server side (value / time_passed) will be percent cpu used
            (now, 'skycore_cpu_usage', {'rusage': 'cpu', 'proc': 'skycore'}, 100),

            (None, 'skycore_rss', {'rusage': 'rss', 'proc': 'skycore'}, 1)
        ):
            if raw_key in main_stats:
                stats.append((ts, labels, main_stats[raw_key] * mult))

    def _add_stat(ns, svc, svc_stats, labels, raw_key, derive=True, mult=1):
        # Record one per-service stat if the raw counter is present.
        if raw_key not in svc_stats:
            return

        labels.update({'namespace': ns, 'component': svc})
        stats.append((now if derive else None, labels, svc_stats[raw_key] * mult))

    # .items() instead of .iteritems(): the latter exists only on Python 2, while
    # .items() iterates the same pairs on both Python 2 and Python 3.
    for ns, services in data.get('svcs', {}).items():
        for svc, svc_stats in services.items():
            for name in ('main', 'stop', 'check'):
                _add_stat(
                    ns, svc, svc_stats,
                    {'start_cnt': 'service_%s' % (name, )},
                    'proc_%s_count' % (name, )
                )

            for name in ('main', 'stop', 'check', 'install', 'uninstall', 'notify_config'):
                proc = 'service_%s' % (name, )

                # We report 100 x seconds, thus on server side (value / time_passed) will be percent cpu used
                _add_stat(
                    ns, svc, svc_stats,
                    {'rusage': 'cpu', 'proc': proc}, '%s_%s' % (name, 'cpu_usage'), mult=100
                )

                # rss is an absolute value, so it is reported without a timestamp
                _add_stat(
                    ns, svc, svc_stats,
                    {'rusage': 'rss', 'proc': proc}, '%s_%s' % (name, 'rss'), False
                )

    return stats


def handle_skybone_stats(now, data):
    """Flatten decoded skybone counters into a list of stat tuples.

    :param now: timestamp attached to every derived (counter-like) stat.
    :param data: mapping with ``'proxy'``, ``'rusage'`` and ``'file_cache'`` dicts;
        a missing group or counter raises ``KeyError``.
    :return: list of ``(ts, labels, value)`` tuples, derived stats first (``ts`` is
        ``now``), then absolute storage stats (``ts`` is ``None``). Every label dict
        carries ``namespace=skynet`` and ``component=skybone``.
    """
    def _tagged(extra):
        # Every skybone stat shares the same namespace/component labels.
        tags = dict(extra)
        tags.update({'namespace': 'skynet', 'component': 'skybone'})
        return tags

    proxy = data['proxy']
    derived = [
        (proxy['connects_ou_skbt'], {'proto': 'skybit', 'connects': 'ou'}),
        (proxy['connects_in_skbt'], {'proto': 'skybit', 'connects': 'in'}),
        (proxy['bytes_in_skbt'], {'proto': 'skybit', 'net_bytes': 'in'}),
        (proxy['bytes_ou_skbt'], {'proto': 'skybit', 'net_bytes': 'ou'}),
    ]

    rusage = data['rusage']
    for who in ('self', 'chld'):
        # We report 100 x seconds, thus on server side (value / time_passed) will be percent cpu used
        cpu = rusage[who + '_user'] * 100 + rusage[who + '_system'] * 100
        derived.append((cpu, {'rusage': 'cpu', 'proc': who}))

    cache = data['file_cache']
    absolute = [
        (cache['data_size'], {'storage': 'data_size'}),
        (cache['file_size'], {'storage': 'file_size'}),
        (cache['file_count'], {'storage': 'file_cnt'}),
        (cache['resource_count'], {'storage': 'resource_cnt'}),
    ]

    stats = [(now, _tagged(extra), amount) for amount, extra in derived]
    stats.extend((None, _tagged(extra), amount) for amount, extra in absolute)

    return stats


def get_skycore_stats(skyctl, now):
    """Run ``<skyctl> stats --format json`` and convert its output to stat tuples.

    This is best-effort: if the tool cannot be started, exits with a non-zero
    code or prints something that is not valid JSON, an empty list is returned.

    :param skyctl: path to the skyctl executable.
    :param now: timestamp passed through to ``handle_skycore_stats``.
    :return: list of ``(ts, labels, value)`` tuples, possibly empty.
    """
    try:
        proc = subproc.Popen([skyctl, 'stats', '--format', 'json'], stdout=subproc.PIPE)
    except Exception:
        # Missing or non-executable binary. ``Exception`` (not a bare except) so that
        # KeyboardInterrupt / SystemExit still propagate.
        return []

    # communicate() drains stdout, waits for the process and closes the pipe.
    output, _ = proc.communicate()

    if proc.returncode != 0:
        return []

    try:
        data = simplejson.loads(output)
    except ValueError:
        # simplejson's decode error is a ValueError subclass; treat garbage output
        # like any other tool failure instead of crashing the whole plugin.
        return []

    return handle_skycore_stats(now, data)


def get_skybone_stats(skybonectl, now):
    """Run ``<skybonectl> -f json counters proxy file_cache rusage`` and convert the output.

    This is best-effort: if the tool cannot be started, exits with a non-zero
    code or prints something that is not valid JSON, an empty list is returned.

    :param skybonectl: path to the skybone-ctl executable.
    :param now: timestamp passed through to ``handle_skybone_stats``.
    :return: list of ``(ts, labels, value)`` tuples, possibly empty.
    """
    try:
        proc = subproc.Popen(
            [skybonectl, '-f', 'json', 'counters', 'proxy', 'file_cache', 'rusage'],
            stdout=subproc.PIPE
        )
    except Exception:
        # Missing or non-executable binary. ``Exception`` (not a bare except) so that
        # KeyboardInterrupt / SystemExit still propagate.
        return []

    # communicate() drains stdout, waits for the process and closes the pipe.
    output, _ = proc.communicate()

    if proc.returncode != 0:
        return []

    try:
        data = simplejson.loads(output)
    except ValueError:
        # simplejson's decode error is a ValueError subclass; treat garbage output
        # like any other tool failure instead of crashing the whole plugin.
        return []

    return handle_skybone_stats(now, data)


def get_stats_difference(old, new):
    """Turn freshly collected stats into reportable deltas against the saved state.

    :param old: previously saved stats (as loaded from the state file); expected to
        be an iterable of ``[ts, labels, value]`` triples, but may contain garbage.
    :param new: iterable of ``(ts, labels, value)`` tuples collected now.
    :return: list of ``(time_passed, labels, value)`` tuples:

        * absolute stats (``ts is None``) are passed through unchanged;
        * a derived stat is reported only if an old entry with equal labels exists
          and a positive amount of time has passed; its value is the increase since
          the old entry, clamped at 0 (a counter that went down reports 0);
        * derived stats with no usable old entry are dropped.
    """
    try:
        old_entries = list(old)
    except Exception:
        # State is not iterable at all -- nothing to diff against.
        old_entries = []

    cleaned_stats = []

    for ts, labels, value in new:
        if ts is None:
            cleaned_stats.append((None, labels, value))
            continue

        for old_raw in old_entries:
            # Validate each saved entry on its own, so that one malformed record
            # does not prevent finding the matching one further down the list.
            try:
                if len(old_raw) != 3:
                    continue

                old_ts, old_labels, old_value = old_raw
                if old_labels != labels:
                    continue

                delta = max(0, value - old_value)
                time_passed = ts - old_ts
            except Exception:
                continue

            if time_passed > 0:
                cleaned_stats.append((time_passed, labels, delta))

            break

    return cleaned_stats


def merge_stats(a, b):
    """Build the new contents of the state file from old and new stats.

    :param a: old stats, as loaded from the state file (may contain garbage).
    :param b: new stats generated in this run, ``(ts, labels, value)`` tuples.
    :return: list of ``(ts, labels, value)`` tuples:

        * for every valid old entry, in old order: the new stat with equal labels
          if there is one, otherwise the old entry itself unless it is 24 hours
          old or older;
        * followed by new stats whose labels match no valid old entry.

    Absolute stats (``ts is None``) are never stored. An old entry is valid when
    its timestamp is an ``int``, its labels are a ``dict`` and its value is an
    ``int`` or ``float``; invalid entries are skipped one by one, so a single
    broken record does not discard the rest of the saved state.
    """
    # Only derived (timestamped) new stats are worth persisting.
    fresh = [(b_ts, b_labels, b_value) for b_ts, b_labels, b_value in b if b_ts is not None]

    try:
        candidates = list(a)
    except Exception:
        # State is not iterable at all -- start over from the new stats.
        candidates = []

    saved = []
    for entry in candidates:
        try:
            a_ts, a_labels, a_value = entry
        except (TypeError, ValueError):
            continue

        # Ignore weird invalid values (this also covers absolute stats with TS = None)
        if not isinstance(a_ts, int):
            continue
        if not isinstance(a_labels, dict):
            continue
        if not isinstance(a_value, (int, float)):
            continue

        saved.append((a_ts, a_labels, a_value))

    merged = []
    current_time = time.time()

    for a_ts, a_labels, a_value in saved:
        for candidate in fresh:
            if candidate[1] == a_labels:
                merged.append(candidate)
                break
        else:
            # Value was not found in new stats -- use old, but only if it is not too old (24hrs)
            if current_time - a_ts < 86400:
                merged.append((a_ts, a_labels, a_value))

    for candidate in fresh:
        # Stats matching a saved entry were already added by the loop above
        if not any(candidate[1] == a_labels for _, a_labels, _ in saved):
            merged.append(candidate)

    return merged


def yasm_forward(stats):
    """Forward skybone network throughput to ``http://localhost:11005``.

    Scans ``stats`` (``(time_passed, labels, value)`` tuples) for the skybone
    ``net_bytes`` ``in`` and ``ou`` entries, converts each to ``value / time_passed``
    and, only when both directions were found, POSTs them as a JSON document.

    Never raises for ordinary errors: any ``Exception`` is reported on stderr.
    """
    rate = {}
    window = {}

    try:
        for elapsed, tags, amount in stats:
            from_skybone = tags.get('namespace') == 'skynet' and tags.get('component') == 'skybone'

            if from_skybone and 'net_bytes' in tags:
                direction = tags['net_bytes']
                if direction in ('in', 'ou'):
                    # A later entry for the same direction overwrites an earlier one
                    rate[direction] = amount / elapsed
                    window[direction] = elapsed

            # Stop scanning as soon as both directions are known
            if len(rate) == 2:
                break

        if len(rate) == 2:
            avg_passed = (window['in'] + window['ou']) / 2.

            yasm_data = [{
                'tags': {
                    'ctype': 'prod',
                    'prj': 'skybone'
                },
                'ttl': avg_passed * 2,
                'values': [
                    {
                        'name': 'skybone-bytes_ou_vhhh',
                        'val': rate['ou'],
                    },
                    {
                        'name': 'skybone-bytes_in_vhhh',
                        'val': rate['in']
                    }
                ]
            }]

            payload = simplejson.dumps(yasm_data)
            requests.post('http://localhost:11005', payload, timeout=15)
    except Exception as ex:
        sys.stderr.write('Unable to forward to yasm: %s\n' % (str(ex), ))

def main():
    """Collect, persist and report skynet stats.

    Normal mode: gather skycore and skybone stats, write the merged state to
    ``<workdir>/stats.json``, compute the difference against the previous state,
    forward skybone throughput to yasm and output the difference.

    ``--show-saved`` mode: output the contents of the state file only. Nothing is
    collected, written or forwarded.

    Output is human-readable text for ``--format pprint`` and a msgpack document
    ``{'v': 1, 'stats': [...]}`` otherwise.
    """
    args = parse_args()

    wd = py.path.local(args.workdir)
    wd.ensure(dir=1)

    now = int(time.time())

    old_stats_fn = wd.join('stats.json')
    old_stats = []
    if old_stats_fn.check(file=1, exists=1):
        try:
            with old_stats_fn.open(mode='rb') as state_fp:
                old_stats = simplejson.load(state_fp)
        except Exception:
            # Unreadable or corrupted state file -- start from scratch
            old_stats = []

    if not args.show_saved:
        stats = []

        stats.extend(get_skycore_stats(args.skyctl, now))
        stats.extend(get_skybone_stats(args.skybonectl, now))

        # Merge before opening the file for writing, so a failure in merge_stats
        # cannot leave a truncated state file behind.
        merged = merge_stats(old_stats, stats)
        with old_stats_fn.open(mode='wb') as state_fp:
            simplejson.dump(merged, state_fp)

        stats = get_stats_difference(old_stats, stats)

        # Forward only real deltas. Saved entries carry absolute timestamps and
        # cumulative counters, which yasm_forward would misread as
        # (time_passed, increment) and turn into bogus rates.
        yasm_forward(stats)
    else:
        stats = old_stats

    if args.format == 'pprint':
        def _sorter(labels):
            # skycore's own stats first, then skycore rusage, then services by namespace/component
            if 'skycore' in labels:
                return 0, labels.get('skycore'), labels
            elif 'proc' in labels and labels['proc'] == 'skycore':
                return 1, labels.get('proc'), labels
            else:
                return 2, labels.get('namespace', 'unknown'), labels.get('component', 'unknown'), labels

        for timeframe, labels, value in sorted(stats, key=lambda x: _sorter(x[1])):
            if args.show_saved:
                # Saved entries hold the collection timestamp; show it relative to now
                timeframe = timeframe - now

            if 'skycore' in labels:
                ns, sv = 'skynet', 'skycore'
            elif 'proc' in labels and labels['proc'] == 'skycore':
                ns, sv = 'skynet', 'skycore'
            else:
                # namespace/component are shown in their own column, so drop them from the label list
                ns, sv = labels.pop('namespace', 'unknown'), labels.pop('component', 'unknown')

            if isinstance(value, int):
                value = '%13d     ' % (value, )
            else:
                value = '%18.4f' % (value, )

            print(
                '%s [%s]  [%s:%-20s]  %s' % (
                    value,
                    '%3ds' % (timeframe, ) if timeframe is not None else ' -- ',
                    ns, sv,
                    ' '.join('%s=%s' % (key, val) for key, val in sorted(labels.items()))
                )
            )
    else:
        import msgpack
        sys.stdout.write(msgpack.dumps({'v': 1, 'stats': stats}))

# Run the collector only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
