#include <infra/ebpf-agent/lib/program.h>
#include <infra/ebpf-agent/lib/utils.h>

#include <library/cpp/deprecated/atomic/atomic.h>
#include <library/cpp/getopt/last_getopt.h>

#include <util/generic/hash.h>
#include <util/system/condvar.h>
#include <util/system/interrupt_signals.h>

#include <sys/stat.h>

#include <cerrno>
#include <system_error>

using namespace NEbpfAgent;

// Shutdown signalling shared between SignalHandler() and Loop():
// Stopped is raised by the handler and polled by the loop, while
// WaitMutex/WaitCondVar let the loop sleep for the update interval
// and be woken early when a stop is requested.
TAtomic Stopped = 0;
TMutex WaitMutex;
TCondVar WaitCondVar;

// Interrupt-signal handler (installed in main()): raises Stopped and wakes the
// interval wait in Loop() so it can detach the program and exit.
//
// NOTE(review): TCondVar::Signal() presumably wraps pthread_cond_signal(), which
// POSIX does not list as async-signal-safe. Because the wait in Loop() is bounded
// by the update interval, a lost wakeup only delays shutdown; a self-pipe or
// eventfd would be the robust alternative.
void SignalHandler(int /*signum*/)
{
    AtomicSet(Stopped, 1);
    WaitCondVar.Signal();
}

// Number of possible CPUs; assigned in main() before any per-CPU map is read.
int numCpus = 0;
// Last observed cgroup id for each watched cgroup path (filled in by Loop()).
THashMap<TString, ui64> cgroupIds;

// Command-line options parsed in main() and consumed by Loop() / PrintStats().
// Default member initializers ensure no field is ever read indeterminate, even
// when an instance is default-constructed and only partially filled in.
struct TParams {
    // Sleep between successive cgroup-id checks in Loop().
    TDuration interval;
    // Use the per-datacenter program and maps (net_stat_dc_*) instead of net_stat_*.
    bool perDc = false;
    // Attach to / read the ingress hook instead of egress.
    bool ingress = false;
};

void PrintStats(const TString& cgroup, const TBpfFd& mapFd, const TParams& params)
{
    static TVector<struct cgroup_net_stat> perCpuStats(numCpus);
    static TVector<struct cgroup_net_stat_dc> perCpuDcStats(numCpus);

    struct bpf_cgroup_storage_key key = { cgroupIds[cgroup], params.ingress ? BPF_CGROUP_INET_INGRESS : BPF_CGROUP_INET_EGRESS };
    void* value = params.perDc ? static_cast<void*>(perCpuDcStats.data()) : static_cast<void*>(perCpuStats.data());
    if (bpf_map_lookup_elem(mapFd, &key, value) < 0) {
        return;
    }

    if (params.perDc) {
        struct cgroup_net_stat_dc stat = {};
        for (int i = 0; i < numCpus; ++i) {
            stat = stat + perCpuDcStats[i];
        }
        Cout << ToString(stat) << Endl;
    } else {
        struct cgroup_net_stat stat = {};
        for (int i = 0; i < numCpus; ++i) {
            stat = stat + perCpuStats[i];
        }
        Cout << ToString(stat) << Endl;
    }
}

void Loop(const TVector<TString>& cgroups, const TParams& params)
{
    TStringBuf mapName;
    if (params.perDc) {
        mapName = params.ingress ? "net_stat_dc_rx_map" : "net_stat_dc_tx_map";
    } else {
        mapName = params.ingress ? "net_stat_rx_map" : "net_stat_tx_map";
    }
    auto pinPath = TString::Join("/sys/fs/bpf/", mapName);
    struct stat sb;
    auto ret = ::stat(pinPath.c_str(), &sb);
    if (ret == 0) {
        auto mapFd = TBpfFd::Get(pinPath.c_str());
        Y_VERIFY(mapFd >= 0);

        for (const auto& cgroup: cgroups) {
            cgroupIds[cgroup] = GetCgroupId(cgroup);
            PrintStats(cgroup, mapFd, params);
        }

        return;
    } else if (errno != ENOENT) {
        return;
    }

    SetMlockLimit(300 * 1024 * 1024);

    const TCgroupBpfProgram* prog;
    if (params.perDc) {
        prog = params.ingress ? &TNetStatDcBpfProgram::Rx(true, true) : &TNetStatDcBpfProgram::Tx(true, true);
    } else {
        prog = params.ingress ? &TNetStatBpfProgram::Rx(true, true) : &TNetStatBpfProgram::Tx(true, true);
    }
    auto mapFd = prog->GetMap(mapName);
    Y_VERIFY(mapFd >= 0);
    mapFd.Pin(pinPath);

    for (const auto& cgroup: cgroups) {
        cgroupIds[cgroup] = GetCgroupId(cgroup);
        TCgroupBpfProgram::Attach(*prog, cgroup, params.ingress ? BPF_CGROUP_INET_INGRESS : BPF_CGROUP_INET_EGRESS);
    }

    while (!AtomicGet(Stopped)) {
        for (const auto& cgroup: cgroups) {
            auto cgroupId = GetCgroupId(cgroup);
            if (cgroupId != cgroupIds[cgroup]) {
                TCgroupBpfProgram::Attach(*prog, cgroup, params.ingress ? BPF_CGROUP_INET_INGRESS : BPF_CGROUP_INET_EGRESS);
                cgroupIds[cgroup] = cgroupId;
            }

            //PrintStats(cgroup, mapFd, params);
        }

        with_lock (WaitMutex) {
            while (!Stopped && WaitCondVar.WaitD(WaitMutex, params.interval.ToDeadLine()));
        }
    }

    for (const auto& cgroup: cgroups) {
        TCgroupBpfProgram::Detach(*prog, cgroup, params.ingress ? BPF_CGROUP_INET_INGRESS : BPF_CGROUP_INET_EGRESS);
    }
    ::unlink(pinPath.c_str());
}

// Entry point. Parses the command line and runs Loop() for the given cgroups,
// or for the cgroup root when none are given. It prints the error on Cerr and
// returns -1 if anything throws or the interval is invalid; otherwise it
// returns 0.
int main(int argc, char* argv[])
{
    try {
        numCpus = libbpf_num_possible_cpus();
        Y_VERIFY(numCpus > 0);

        using namespace NLastGetopt;

        TOpts opts(TOpts::Default());
        opts.SetFreeArgTitle(0, "cgroup", "Path to cgroup");
        opts.AddHelpOption('h');
        opts.AddVersionOption('v');

        TOpt& intervalOpt = opts.AddLongOption('n', "interval", "Update interval in seconds").DefaultValue(1);
        TOpt& perDcOpt = opts.AddLongOption("per-dc", "Show per datacenter statistics"); perDcOpt.Optional().HasArg(NO_ARGUMENT);
        TOpt& ingressOpt = opts.AddLongOption("ingress", "Show ingress statistics instead of egress"); ingressOpt.Optional().HasArg(NO_ARGUMENT);

        TOptsParseResult parsedOpts(&opts, argc, argv);

        auto cgroups = parsedOpts.GetFreeArgs();
        if (cgroups.empty()) {
            cgroups.emplace_back(FindCgroupRoot());
        }

        const auto intervalSeconds = FromString<unsigned>(parsedOpts.Get(&intervalOpt));
        if (intervalSeconds == 0) {
            // A zero interval makes every wait in Loop() expire immediately,
            // turning the cgroup-id polling into a busy loop.
            Cerr << "--interval must be at least 1 second" << Endl;
            return -1;
        }

        SetInterruptSignalsHandler(SignalHandler);

        TParams params;
        params.interval = TDuration::Seconds(intervalSeconds);
        params.perDc = parsedOpts.Has(&perDcOpt);
        params.ingress = parsedOpts.Has(&ingressOpt);
        Loop(cgroups, params);
    } catch (...) {
        Cerr << CurrentExceptionMessage() << Endl;
        return -1;
    }

    return 0;
}
