#!/usr/bin/env python3

import argparse
import json
import multiprocessing as mp
import os
import re
import shlex
import sys
import tempfile

import yandex

class PerfTool(object):
    """Thin wrapper around the ``perf`` command-line tool.

    The executable comes from the ``PERF`` environment variable, split with
    shell-like quoting rules.  When that variable is unset or empty,
    ``ya tool perf`` is used instead.  The constructor runs
    ``perf --version`` once, with stdout discarded.
    """

    def __init__(self):
        command = shlex.split(os.environ.get('PERF', ''))
        self.exe = command if command else ['ya', 'tool', 'perf']
        # One up-front invocation of the tool; its output is thrown away.
        with open(os.devnull, 'w') as sink:
            yandex.subprocess_call(True, self.exe + ['--version'], stdout=sink)

    def run(self, args, **kwargs):
        """Run perf with *args* appended and return the runner's result.

        Extra keyword arguments (e.g. ``cwd``) are forwarded unchanged.
        """
        command = self.exe + args
        return yandex.subprocess_run(True, command, **kwargs)

    def record(self, tmpdir, duration, vendor):
        """Run ``perf record -a`` for *duration* seconds with *tmpdir* as cwd.

        The ``cycles`` event is always requested.  For *vendor* ``'intel'``,
        the three ``core_power.lvlN_turbo_license`` events are requested too.
        """
        events = ['"cycles"']
        if vendor == 'intel':
            events.extend([
                'core_power.lvl0_turbo_license',
                'core_power.lvl1_turbo_license',
                'core_power.lvl2_turbo_license',
            ])
        command = ['record', '-a']
        for event in events:
            command += ['-e', event]
        command += ['--', 'sleep', str(duration)]
        self.run(command, cwd=tmpdir)

    def to_json(self, tmpdir):
        """Convert the recording in *tmpdir* to ``perf.json`` (overwriting)."""
        self.run(['data', 'convert', '--force', '--to-json', 'perf.json'], cwd=tmpdir)

    def support_turbo_lvls(self):
        """Return True if ``perf list`` has a ``core_power.lvl0_turbo_license`` line."""
        raw = yandex.subprocess_check_output(True, self.exe + ['--no-pager', 'list'])
        listing = raw.strip().decode().split('\n')
        return any(entry.strip() == 'core_power.lvl0_turbo_license' for entry in listing)


class Task(object):
    """A sampled task: identity, sample time, symbol and a report flag."""

    def __init__(self, pid, tid, comm, timestamp, symbol):
        self.pid, self.tid, self.comm = pid, tid, comm
        self.timestamp, self.symbol = timestamp, symbol
        # Starts unflagged; only mark_for_report() turns it on.
        self.need_report = False

    def mark_for_report(self):
        """Flag this task as one that should be reported."""
        self.need_report = True


class Cpu(object):
    """Per-logical-CPU state: the current task and turbo license levels.

    Tracks the task currently considered running on this CPU.  It also keeps
    the most recent turbo license level set via ``set_current_license_lvl``
    and the highest level seen since the last ``reset_max_turbo_license``.
    When the current task is replaced and had been flagged for reporting,
    one report line is printed for it.
    """

    def __init__(self, cpu, vendor):
        self.cpu = cpu              # logical CPU id
        self.vendor = vendor        # vendor string; 'intel' adds the turbo column
        self.freq = 0.0             # not used by the visible code
        self.license_lvl = 0        # last turbo license level set
        self.license_lvl_max = 0    # highest level since the last reset
        self.current = None         # Task currently on this CPU, or None

    def set_current_license_lvl(self, lvl):
        """Record *lvl* as the current license level and update the maximum."""
        self.license_lvl = lvl
        self.license_lvl_max = max(self.license_lvl, self.license_lvl_max)

    def reset_max_turbo_license(self):
        """Restart max tracking from the current level (clamped at 0)."""
        self.license_lvl_max = max(0, self.license_lvl)

    def _format_report(self, task):
        """Return the report line for *task*; on Intel it includes ``L<max level>``."""
        if self.vendor == 'intel':
            return '%-16s %-10u %-10u %-20.9f %-3u L%-4u %s' % (
                task.comm, task.pid, task.tid, task.timestamp, self.cpu,
                self.license_lvl_max, task.symbol)
        return '%-16s %-10u %-10u %-20.9f %-3u %s' % (
            task.comm, task.pid, task.tid, task.timestamp, self.cpu, task.symbol)

    def set_current_task(self, task):
        """Install *task* as current, first printing the outgoing task if flagged.

        Passing ``None`` prints a flagged outgoing task without installing a
        new one.
        """
        outgoing = self.current
        if outgoing and outgoing.need_report:
            print(self._format_report(outgoing))
        self.current = task

    def get_cpu_id(self):
        """Return this CPU's logical id."""
        return self.cpu

    def get_smt_cpu_id(self, ncpus=None):
        """Return the id of the CPU assumed to be this CPU's SMT sibling.

        This assumes that logical CPUs ``i`` and ``i + ncpus // 2`` share a
        core.  The real topology is NOT consulted.  *ncpus* defaults to
        ``mp.cpu_count()``.  With fewer than two CPUs, the CPU's own id is
        returned.
        """
        if ncpus is None:
            ncpus = mp.cpu_count()
        half = ncpus // 2
        return self.cpu + half if self.cpu < half else self.cpu - half


def cpu_vendor(cpuinfo_path='/proc/cpuinfo'):
    """Return ``'intel'`` or ``'amd'`` for the CPU vendor of this machine.

    The ``vendor_id`` field of *cpuinfo_path* is read first.  It is a
    kernel-provided key and is not affected by the locale, unlike the
    human-readable labels printed by ``lscpu``.  If that file cannot be read,
    or has no ``vendor_id`` line, the original ``lscpu`` parsing is used as a
    fallback.

    Raises:
        RuntimeError: if the vendor cannot be determined or is neither
            GenuineIntel nor AuthenticAMD.
    """
    vendor_id = None
    try:
        with open(cpuinfo_path, 'r') as cpuinfo:
            for line in cpuinfo:
                key, sep, value = line.partition(':')
                if sep and key.strip() == 'vendor_id':
                    vendor_id = value.strip()
                    break
    except OSError:
        # Deliberate: fall back to lscpu below rather than failing here.
        vendor_id = None

    if vendor_id is None:
        # Fallback: parse lscpu output.  Its 'Vendor ID:' label may be
        # translated in non-English locales, hence only a fallback.
        lscpu = yandex.subprocess_check_output(False, ["lscpu"], shell=True).strip().decode().split('\n')
        for line in lscpu:
            if line.startswith('Vendor ID:'):
                vendor_id = line[len('Vendor ID:'):].strip()
                break

    if vendor_id is not None:
        if 'GenuineIntel' in vendor_id:
            return 'intel'
        if 'AuthenticAMD' in vendor_id:
            return 'amd'
    raise RuntimeError('Unknown vendor ID')


def avx_symbol(symbol):
    """Return True if *symbol* is a name containing ``avx`` in any letter case.

    ``None`` (no symbol resolved) is never an AVX symbol.
    """
    if symbol is None:
        return False
    return 'avx' in symbol.lower()


def _print_header(vendor):
    """Print the column header matching the report format for *vendor*."""
    if vendor == 'intel':
        print('%-16s %-10s %-10s %-20s %-3s %-5s %s' % \
                ('COMM', 'PID', 'TID', 'TIME', 'CPU', 'TURBO', 'SYMBOL'))
    else:
        print('%-16s %-10s %-10s %-20s %-3s %s' % \
                ('COMM', 'PID', 'TID', 'TIME', 'CPU', 'SYMBOL'))


def _sample_symbol(sample):
    """Return the symbol of the first callchain entry of *sample*, or None.

    A missing or empty callchain yields None instead of raising.
    """
    callchain = sample.get('callchain') or [{}]
    return callchain[0].get('symbol')


def main(args, tmpdir):
    """Record with perf and report AVX tasks that ran next to the victim PID.

    First ``args.time`` seconds of samples are recorded and converted to
    JSON in *tmpdir*.  The samples are then replayed per CPU.  A task is
    flagged and reported when its symbol contains "avx" and it starts
    running on one CPU while ``args.pid`` is current on that CPU's assumed
    SMT sibling, or vice versa.  Flagged tasks are printed when they are
    replaced on their CPU, or when the trace ends.

    Args:
        args: parsed command line; ``args.pid`` and ``args.time`` are read.
        tmpdir: scratch directory that receives ``perf.json``.
    """
    perf = PerfTool()
    vendor = cpu_vendor()

    # Intel, but perf lacks the turbo-license events: record/print without them.
    if vendor == 'intel' and not perf.support_turbo_lvls():
        vendor = 'intel_old'

    perf.record(tmpdir, args.time, vendor)
    perf.to_json(tmpdir)

    cpus = [Cpu(i, vendor) for i in range(mp.cpu_count())]
    _print_header(vendor)

    with open(os.path.join(tmpdir, 'perf.json'), 'r') as json_file:
        data = json.load(json_file)

    for sample in data['samples']:
        timestamp = sample['timestamp'] / 1000000000.0
        cpuid = sample['cpu']
        pid = sample['pid']
        tid = sample['tid']
        comm = sample['comm']
        event = sample['event']
        symbol = _sample_symbol(sample)

        this_cpu = cpus[cpuid]
        smt_cpu = cpus[this_cpu.get_smt_cpu_id()]

        # core_power.lvlN_* events: the first number in the name is the level.
        if event.startswith('core_power'):
            lvl = int(re.findall(r'\d+', event)[0])
            this_cpu.set_current_license_lvl(lvl)

        # Same process still current on this CPU: keep its existing Task.
        if this_cpu.current is not None and this_cpu.current.pid == pid:
            continue

        task = Task(pid, tid, comm, timestamp, symbol)
        this_cpu.set_current_task(task)
        this_cpu.reset_max_turbo_license()

        if pid != args.pid and avx_symbol(symbol):
            # An AVX task starts while the victim is on the sibling CPU.
            if smt_cpu.current is not None and smt_cpu.current.pid == args.pid:
                task.mark_for_report()
        elif pid == args.pid and smt_cpu.current is not None:
            # The victim starts while an AVX task is on the sibling CPU.
            if avx_symbol(smt_cpu.current.symbol):
                smt_cpu.current.mark_for_report()

    # Tasks still current when the trace ends are never replaced, so
    # set_current_task() would never print them; flush every CPU explicitly.
    for cpu in cpus:
        cpu.set_current_task(None)


if __name__ == '__main__':
    # Script entry point: parse options, then run main() in a scratch
    # directory that is removed when it finishes.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-p', '--pid',
        metavar='PID', required=True, type=int,
        help='victim task pid',
    )
    cli.add_argument(
        '-t', '--time',
        metavar='TIME', type=float, default=5.0,
        help='perf record duration time, default: 5.0',
    )
    args = cli.parse_args()
    with tempfile.TemporaryDirectory(prefix='avxspot-') as workdir:
        main(args, workdir)
