// based on https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
// and https://gist.github.com/pavel-odintsov/15b7435e484134650f20

#include <util/datetime/base.h>

#include <crypta/graph/fast_catcher/lib/capturer.h>
#include <crypta/graph/fast_catcher/lib/parser.h>
#include <crypta/graph/fast_catcher/lib/sender.h>

#include <crypta/graph/fast_catcher/proto/options.pb.h>
#include <crypta/graph/fast_catcher/proto/message.pb.h>
#include <crypta/lib/native/cmd_args/parse_pb_options.h>
#include <crypta/lib/native/proto_serializer/proto_serializer.h>
#include <crypta/lib/native/resolver/resolver.h>
#include <library/cpp/protobuf/json/proto2json.h>

#include <atomic>
#include <chrono>
#include <functional>
#include <iomanip>
#include <iostream>
#include <mutex>
#include <optional>
#include <sstream>
#include <thread>

#include <net/ethernet.h>

using std::cout;
using std::endl;

namespace {
    // Builds a THitMessage from parsed packet info and returns it serialized
    // with NCrypta::NProtoSerializer::ToJson.
    //
    // Source/Destination are set only when AddressFamily is AF_INET (Src4/Dst4)
    // or AF_INET6 (Src6/Dst6); for any other family they are left unset.
    // Ports are always copied, and Timestamp is TInstant::Now() in seconds.
    TString MakeMessage(const TUserInfo& ui) {
        THitMessage hit;

        if (ui.AddressFamily == AF_INET) {
            hit.SetSource(ToString(ui.Src4));
            hit.SetDestination(ToString(ui.Dst4));
        } else if (ui.AddressFamily == AF_INET6) {
            hit.SetSource(ToString(ui.Src6));
            hit.SetDestination(ToString(ui.Dst6));
        }

        hit.SetSrcPort(ui.SrcPort);
        hit.SetDstPort(ui.DstPort);

        const auto nowSeconds = TInstant::Now().Seconds();
        hit.SetTimestamp(nowSeconds);

        return NCrypta::NProtoSerializer::ToJson(hit);
    }
}

class TTrafficAnalyzer {
public:
    using TProtocolsMap = THashMap<ui32, ui64>;
    struct TTrafficStats {
        ui64 TotalBlocks = 0;
        ui64 TotalPackets = 0;
        ui64 TotalBytes = 0;
    };

    TTrafficAnalyzer() : TTrafficAnalyzer(TCaptureOptions{}) {}
    TTrafficAnalyzer(const TCaptureOptions& options)
    : Options(options)
    , Sender({
        options.GetDestination(),
        static_cast<uint16_t>(options.GetPort()),
        options.GetWorkers(),
        options.GetQueueSize()
    })
    {
    }

    TProtocolsMap GetProtocols() const {
        std::lock_guard<std::mutex> mapLock(MapMutex);
        return ByProtocols;
    }
    TBlockInfo ParseBlock(const TCurrentBlock& block) {
        ++TotalBlocks;
        return ::ParseBlock(block, [this](const TRawPacket& packet) {
            ParsePacket(packet);
        });
    }
    void ParsePacket(const TRawPacket& packet) {
        std::optional<TUserInfo> userInfo;

        ++TotalPackets;
        TotalBytes += packet.Size;

        {
            std::lock_guard<std::mutex> mapLock(MapMutex);
            ++ByProtocols[packet.EtherProto];
        }

        if (packet.EtherProto == ETHERTYPE_IP) {
            userInfo = ParseV4(packet);
        }

        if (packet.EtherProto == ETHERTYPE_IPV6) {
            userInfo = ParseV6(packet);
        }

        if (userInfo) {
            Sender.Enqueue(MakeMessage(*userInfo));

            bool first = true;
            cout << userInfo->SrcPort << " - " << userInfo->DstPort << " // ";
            for (const auto& [k, v] : userInfo->TcpOptions) {
                cout << (first ? "" : ", "); first = false;
                cout << std::hex << static_cast<int>(k) << " : " << std::setfill('0') << std::setw(2);
                for (ui8 c : v) {
                    cout << static_cast<int>(c);
                }
            }
            cout << std::dec << endl;
        }
    }
    TTrafficStats GetStats() const {
        return {TotalBlocks, TotalPackets, TotalBytes};
    }
    ui32 GetFramesPerBlock() const {
        return Options.GetBlockSize() / Options.GetFrameSize();
    }

private:
    TCaptureOptions Options;
    mutable std::mutex MapMutex;
    TProtocolsMap ByProtocols;
    std::atomic<ui64> TotalBlocks = 0;
    std::atomic<ui64> TotalPackets = 0;
    std::atomic<ui64> TotalBytes = 0;
    TSender Sender;
};

void Capture(const TCaptureOptions& config, TTrafficAnalyzer& analyzer) {
    auto capturer = TCapturer(config);

    while (true) {
        auto currentBlock = capturer.GetCurrentBlock();

        if (!currentBlock.IsWithData()) {
            capturer.WaitForData();
            continue;
        }

        analyzer.ParseBlock(currentBlock);
        capturer.FlushCurrentBlock();
    }
}

// Periodic stats printer, meant to run on its own thread.
//
// When beVerbose is set, every 10 seconds it prints the analyzer's
// block/packet/byte totals, the average packets per block, and one
// "ethertype(hex) : count" line per EtherType, and it never returns.
// When beVerbose is false there is nothing to print, so it returns immediately.
void Reporter(const TTrafficAnalyzer& analyzer, bool beVerbose) {
    if (!beVerbose) {
        return;
    }

    while (true) {
        const auto [blocks, packets, bytes] = analyzer.GetStats();
        // Before the first block arrives blocks == 0; report 0 instead of nan.
        const double avgPacketsPerBlock =
            blocks != 0 ? static_cast<double>(packets) / blocks : 0.0;

        // Formatted in a local stream so the numbers are always decimal,
        // regardless of any format state other threads leave on std::cout.
        // The stream also lets the report be written in one piece.
        std::ostringstream report;
        report << "total blocks: " << blocks << ", packets: " << packets << ", bytes: " << bytes;
        report << ", avg pkts per block: " << avgPacketsPerBlock << '\n';
        for (const auto& [proto, count] : analyzer.GetProtocols()) {
            report << std::hex << proto << " : " << std::dec << count << '\n';
        }
        std::cout << report.str() << std::flush;

        std::this_thread::sleep_for(std::chrono::seconds(10));
    }
}

int main(int argc, const char** argv) {
    auto config = NCrypta::ParsePbOptionsExtended<TCaptureOptions>(argc, argv);

    TTrafficAnalyzer analyzer(config);
    auto unused = std::thread(Reporter, std::ref(analyzer), config.GetVerbose());
    Capture(config, analyzer);
}
