#include "yt_dumper.h"
#include "file_dumper.h"

#include <balancer/production/x/okula/apps/common/raw_packets.h>

#include <cerrno>
#include <cstring>
#include <csignal>

/* Process-wide capture switch: starts as true, is cleared by ExitHandler and is
 * polled by the capture loop in TClientWorker<T>::Start. */
std::atomic_bool CAPTURE_ENABLED{true};

/* Signal handling helpers; both are defined at the bottom of this file. */
static void ExitHandler(int signum);
static void SetExitHandler();

/*
 * Reducer that glues the "data" column of every row it receives into a single
 * string and emits exactly one output row carrying that string under "data".
 */
class TYtPacketsReducer: public NYT::IReducer<NYT::TTableReader<NYT::TNode>, NYT::TTableWriter<NYT::TNode>> {
public:
    void Do(TReader* reader, TWriter* writer) override {
        TStringStream merged;

        /* Append the chunks in the order the reader yields them. */
        while (reader->IsValid()) {
            auto& current = reader->GetRow();
            merged << current["data"].AsString();
            reader->Next();
        }

        /* One output row is written even when no input rows were seen. */
        NYT::TNode outRow;
        outRow["data"] = merged.Str();
        writer->AddRow(outRow);
    }
};
REGISTER_REDUCER(TYtPacketsReducer);

/*
 * Factory for a worker that reads packets from `filePath`.
 *
 * Opens the capture, attaches `filter` to its session, copies the capture's
 * file header, link type and snap size into `options`, and only then builds
 * the dumper of type T from those options.
 *
 * Returns the fully assembled worker; anything thrown by the constructors
 * involved propagates to the caller.
 */
template <typename T>
std::unique_ptr<TClientWorker<T>>
TClientWorker<T>::CreateWorker(const TString& filePath, const TString& filter, IPcapOptions& options) {
    auto capture = std::make_unique<TPcap>(filePath);
    auto packetFilter = std::make_unique<TPcapFilter>(capture->Session(), filter);

    /* The dumper reads these from `options`, so fill them in before creating it. */
    options.SetHeader(capture->FileHeader());
    options.SetLinkType(capture->LinkType());
    options.SetSnapSize(capture->SnapSize());

    auto dumper = std::make_unique<T>(&options);

    /* NOTE(review): plain `new` is kept on purpose — presumably the constructor
     * is not accessible to std::make_unique; confirm against the header. */
    return std::unique_ptr<TClientWorker<T>>(
        new TClientWorker<T>(std::move(capture), std::move(packetFilter), std::move(dumper)));
}

/*
 * Runs the capture loop until NextEx reports a negative status or
 * CAPTURE_ENABLED is cleared, then finishes the dumper.
 *
 * Throws yexception when NextEx returns -1; in that case the dumper is not
 * finished. A status of -2 is only reported on Cout.
 */
template <typename T>
void TClientWorker<T>::Start() {
    TPcapPacketHeader* header = nullptr;
    const unsigned char* payload = nullptr;

    int status = 0;
    for (;;) {
        status = Pcap_->NextEx(&header, &payload);

        /* The flag is consulted only after a successful read, as before. */
        if (status < 0 || !CAPTURE_ENABLED.load()) {
            break;
        }

        if (status != 0) {
            ProcessPacket(header, payload);
        } else {
            /* status == 0 is a timeout: header and payload are not valid, back off. */
            usleep(100000);
        }
    }

    if (status == -2) {
        Cout << "Capture loop is ended\n";
    } else if (status == -1) {
        ythrow yexception() << "Failed to capture packets: "
                            << Pcap_->ErrorString() << "\n";
    }

    DumperInstance_->Finish();
}

/*
 * Parses the IP header of one captured packet and, when the transport
 * protocol is TCP, passes the packet together with its addresses and ports to
 * the dumper.
 *
 * pktHdr - capture header of the packet, forwarded to the dump message as is.
 * packet - raw packet bytes; the IP header starts at Pcap_->Offset().
 *
 * Packets that are neither IPv4 nor IPv6, that declare an IPv4 header shorter
 * than 20 bytes, that carry an over-long IPv6 extension-header chain, or whose
 * payload is not TCP are skipped without any output.
 *
 * NOTE(review): offsets are still not checked against the captured length of
 * the packet, so a truncated capture can be read past its end. Confirm which
 * field of TPcapPacketHeader holds that length and add the bound.
 */
template <typename T>
void TClientWorker<T>::ProcessPacket(
    const TPcapPacketHeader* pktHdr,
    const unsigned char* packet) {
    const unsigned char* ipStart = packet + Pcap_->Offset();
    const ui8 ipVersion = (*ipStart) >> 4;

    unsigned char srcIp[16]{0};
    unsigned char dstIp[16]{0};

    /* Wide enough that a long extension-header chain cannot wrap around. */
    size_t headerSize = 0;

    TMultiFrame frame;

    if (ipVersion == IPV4) {
        const auto* ipLayer = reinterpret_cast<const TIp4Frame*>(ipStart);

        std::memcpy(srcIp, &ipLayer->srcIp, 4);
        std::memcpy(dstIp, &ipLayer->dstIp, 4);

        frame.Proto(ipLayer->proto);
        frame.AFInet(AF_INET);

        /* The ipVersion byte packs the version (high nibble) and the IHL
         * (low nibble, counted in 32-bit words). */
        headerSize = static_cast<size_t>(ipLayer->ipVersion & 0x0f) * 4;

        /* Invalid ip header */
        if (headerSize < 20) {
            return;
        }
    } else if (ipVersion == IPV6) {
        const auto* ipLayer = reinterpret_cast<const TIp6Frame*>(ipStart);

        std::memcpy(srcIp, &ipLayer->srcIp, 16);
        std::memcpy(dstIp, &ipLayer->dstIp, 16);

        frame.Proto(ipLayer->proto);
        frame.AFInet(AF_INET6);

        /* Fixed size of the base header; extension headers follow it. */
        headerSize = 40;

        /* Upper bound on the chain so garbage input cannot keep us walking. */
        constexpr size_t maxExtHeaders = 16;

        /* Walk the chain only while the CURRENT protocol is a known extension
         * header. TCP or any other payload ends the walk, so payload bytes are
         * never parsed as an extension header. */
        for (size_t walked = 0;; ++walked) {
            const unsigned char* ext = ipStart + headerSize;
            size_t extSize = 0;

            switch (frame.Proto()) {
                case IP6OPT_HOP:
                case IP6OPT_ROUTING:
                case IP6OPT_DST:
                case IP6OPT_MIPV6:
                case IP6OPT_SHIM6:
                    /* Second octet: length in 8-octet units, first 8 octets excluded. */
                    extSize = (static_cast<size_t>(ext[1]) + 1) * 8;
                    break;
                case IP6OPT_FRAG:
                    /* The fragment header has a fixed size. */
                    extSize = 8;
                    break;
                case IP6OPT_AH:
                    /* Second octet: length in 4-octet units, minus 2. */
                    extSize = (static_cast<size_t>(ext[1]) + 2) * 4;
                    break;
                default:
                    break;
            }

            if (extSize == 0) {
                break;
            }
            if (walked == maxExtHeaders) {
                return;
            }

            frame.Proto(reinterpret_cast<const TIp6FrameNxt*>(ext)->proto);
            headerSize += extSize;
        }
    } else {
        return;
    }

    // Match only TCP
    if (frame.Proto() != IP6OPT_TCP) {
        return;
    }

    const auto* tcpLayer = reinterpret_cast<const TTcpFrame*>(ipStart + headerSize);

    frame.SrcAdr(srcIp);
    frame.DstAdr(dstIp);
    frame.SrcPort(tcpLayer->srcPort);
    frame.DstPort(tcpLayer->dstPort);

    TDumpMessage message(pktHdr, packet, frame);
    DumperInstance_->Write(message);
}

void RunFileSplitter(const TString& filePath,
                     const TString& outputDir,
                     int depth) {
    SetExitHandler();

    TFilePcapOptions options(outputDir, depth);

    std::unique_ptr<TClientWorker<TFilePcapDumper>> clientWorker(
        TClientWorker<TFilePcapDumper>::CreateWorker(filePath, "", options));
    if (clientWorker) {
        clientWorker->Start();
    }
}

void RunYtSplitter(const TString& filePath, const TString& clusterName, const TString& tableName) {
    SetExitHandler();

    auto ytClient = NYT::CreateClient(clusterName);

    if (ytClient->Exists(tableName)) {
        ythrow yexception() << "Table [" << tableName << "] already exists\n";
    }

    auto ytTrx = ytClient->StartTransaction();

    TYtPcapOptions options(ytTrx, tableName);

    std::unique_ptr<TClientWorker<TYtPcapDumper>> clientWorker(
        TClientWorker<TYtPcapDumper>::CreateWorker(filePath, "", options));
    if (clientWorker) {
        clientWorker->Start();
        ytTrx->Sort(
            NYT::TSortOperationSpec()
                .SortBy({"hash", "order"})
                .AddInput(tableName)
                .Output(tableName));

        ytTrx->Commit();
    }
}

/*
 * Signal handler: clears CAPTURE_ENABLED so the capture loop stops on its
 * next check. The signal number itself is not used.
 */
static void ExitHandler(int signum) {
    static_cast<void>(signum);

    CAPTURE_ENABLED.store(false, std::memory_order_relaxed);
}

/*
 * Installs ExitHandler for SIGINT and SIGTERM.
 *
 * Both signals are placed in the handler's mask and sa_flags is left at 0.
 * The results of the sigaction calls are not checked.
 */
static void SetExitHandler() {
    const int exitSignals[] = {SIGINT, SIGTERM};

    struct sigaction action {};

    sigemptyset(&action.sa_mask);
    for (const int sig : exitSignals) {
        sigaddset(&action.sa_mask, sig);
    }
    action.sa_flags = 0;
    action.sa_handler = ExitHandler;

    for (const int sig : exitSignals) {
        sigaction(sig, &action, nullptr);
    }
}
