#include "main.h"

#include <balancer/production/x/okula/apps/common/pcap.h>
#include <balancer/production/x/okula/apps/common/raw_packets.h>

#include <library/cpp/digest/murmur/murmur.h>

#include <util/system/thread.h>

#include <cstring>
#include <cerrno>
#include <signal.h>
#include <thread>
#include <vector>

#if defined(_linux_)
#include <linux/if_packet.h>
#endif

// Process-wide run flag. The capture workers and the stats observer loop while
// it is true; it is cleared by the SIGINT/SIGTERM handler or when a capture
// worker fails, which asks every loop to wind down.
std::atomic<bool> CAPTURE_ENABLED{true};

static void ExitHandler(int signum);
static void SetExitHandler();

void RunMainCounter(const TCounterOptions& options) {
    SetExitHandler();

    TStatisticsHolder statsHolder(options.prettyPrint);

    std::vector<std::thread> workers;
    auto clientWorker = std::make_unique<TClientWorker>(options, &statsHolder);

    if (options.displayTimeout > 0) {
        workers.emplace_back([&, options] {
            clientWorker->StatsObserver(options.displayTimeout);
        });
    }

    for (size_t i = 0; i < options.threadCount; ++i) {
        workers.emplace_back([&, i] {
            try {
                clientWorker->Start(i);
            } catch (const yexception& e) {
                Cerr << e.what() << "\n";
                CAPTURE_ENABLED.store(false, std::memory_order_relaxed);
            } catch (...) {
                CAPTURE_ENABLED.store(false, std::memory_order_relaxed);
            }
        });
    }

    for (auto& worker : workers) {
        worker.join();
    }

    Cout << statsHolder.JsonDump() << "\n";
}

void TClientWorker::Start(size_t threadId) {
    TStringStream threadName;
    threadName << "Capture worker thread " << threadId;
    TThread::SetCurrentThreadName(threadName.Str().c_str());
    int ret;

#if defined(_linux_)
    cpu_set_t cpuset;
    pthread_t thread = pthread_self();
    CPU_ZERO(&cpuset);
    CPU_SET(threadId, &cpuset);

    ret = pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset);
    if (ret != 0) {
        Cerr << "Could not set CPU affinity for thread: " << threadId
             << "\n";
    }
#endif
    TPcap pcap(Iface_, BufferSize_, SnapSize_);
    TPcapFilter pcapFilter(pcap.Session(), Filter_);

#if defined(_linux_)
    int val = (getpid() & 0xffff) | (PACKET_FANOUT_LB << 16);
    if (setsockopt(pcap_fileno(pcap.Session()), SOL_PACKET, PACKET_FANOUT, &val, sizeof(val)) < 0) {
        Cerr << "Failed to set fanout: "
             << strerror(errno)
             << "\n";
        return;
    }
#endif

    TPacketHeader* pktHdr;
    const unsigned char* pkt;

    while ((ret = pcap.NextEx(&pktHdr, &pkt)) >= 0 and CAPTURE_ENABLED.load()) {
        /* ret == 0 is timeout and pktHdr, pkt invalid pointers */
        if (ret == 0)
            usleep(100000);
        else
            ProcessPacket(pktHdr, pkt, pcap.Offset());
    }

    if (ret == -1) {
        ythrow yexception() << "Failed to capture packets: "
                            << pcap.ErrorString() << "\n";
    } else if (ret == -2) {
        Cout << "Capture loop is ended\n";
    }
}

/* Prints a JSON dump of the current statistics roughly every `timeout` seconds
 * for as long as CAPTURE_ENABLED stays set.
 *
 * The wait is split into one-second sleeps that re-check the flag. Shutdown is
 * therefore noticed within about a second, instead of after up to `timeout`
 * seconds; a single long sleep() is only cut short if a signal happens to be
 * delivered to this particular thread.
 */
void TClientWorker::StatsObserver(ui32 timeout) {
    while (CAPTURE_ENABLED.load()) {
        Cout << StatsHolder_->JsonDump() << "\n";
        for (ui32 waited = 0; waited < timeout && CAPTURE_ENABLED.load(); ++waited) {
            sleep(1);
        }
    }
}

void TClientWorker::ProcessPacket(const TPacketHeader* pktHdr, const unsigned char* packet, ptrdiff_t etherOffset) {
    Y_UNUSED(pktHdr);

    ui8 ipVersion = (*(packet + etherOffset)) >> 4;

    unsigned char srcIp[16]{0};
    unsigned char dstIp[16]{0};

    ui8 proto = 0x00;
    ui8 headerSize = 0x00;

    if (ipVersion == 0x04) {
        auto ipLayer = (TIp4Frame*)(packet + etherOffset);

        /* convert address to IPv6 80bit == 0, 16 bits == 1, 32bit - IPv4 */
        srcIp[10] = 0xff;
        srcIp[11] = 0xff;
        dstIp[10] = 0xff;
        dstIp[11] = 0xff;
        std::memcpy(((unsigned char*)&srcIp) + 12, &ipLayer->srcIp, 4);
        std::memcpy(((unsigned char*)&dstIp) + 12, &ipLayer->dstIp, 4);
        proto = ipLayer->proto;

        /* We can't cast 4 fields in С/C++. first 4bits of ipVersion is version, second is IHL */
        headerSize = (ipLayer->ipVersion & 0x0f) * 4;
    } else if (ipVersion == 0x06) {
        auto ipLayer = (TIp6Frame*)(packet + etherOffset);
        std::memcpy(&srcIp, &ipLayer->srcIp, 16);
        std::memcpy(&dstIp, &ipLayer->dstIp, 16);
        proto = ipLayer->proto;

        /* Fixed size of header. All optional headers stored in proto field */
        headerSize += 40;

        if (proto != IP6OPT_TCP) {
            bool nxt = true;

            while (nxt) {
                auto ipLayerNxt = (TIp6FrameNxt*)(packet + etherOffset + headerSize);
                proto = ipLayerNxt->proto;
                /* 96 bits for optional field */
                headerSize += 12;
                /* Match only basic fields */
                switch (proto) {
                    case IP6OPT_TCP:
                    case IP6OPT_HOP:
                    case IP6OPT_ROUTING:
                    case IP6OPT_FRAG:
                    case IP6OPT_AH:
                    case IP6OPT_DST:
                    case IP6OPT_MIPV6:
                    case IP6OPT_SHIM6:
                        break;
                    default:
                        nxt = false;
                        break;
                }
            }
        }
    } else {
        return;
    }

    /* Invalid ip header */
    if ((ipVersion == 0x04 && headerSize < 20) || (ipVersion == 0x06 && headerSize < 40))
        return;

    // Match only TCP
    if (proto == IP6OPT_TCP) {
        auto tcpLayer = (TTcpFrame*)(packet + etherOffset + headerSize);

        ui64 hash = 0L;
        ui64 hashInv = 0L;

        TMurmurHash2A<ui64> hasher;
        hasher.Update(&srcIp, 16);
        hasher.Update(&tcpLayer->srcPort, 2);
        hasher.Update(&dstIp, 16);
        hasher.Update(&tcpLayer->dstPort, 2);
        hash = hasher.Value();

        TMurmurHash2A<ui64> hasherInv;
        hasherInv.Update(&dstIp, 16);
        hasherInv.Update(&tcpLayer->dstPort, 2);
        hasherInv.Update(&srcIp, 16);
        hasherInv.Update(&tcpLayer->srcPort, 2);
        hashInv = hasherInv.Value();

        if (!StatsHolder_->HasKey(hash)) {
            if (!StatsHolder_->HasKey(hashInv)) {
                StatsHolder_->FillData(hash, &srcIp[0], tcpLayer->srcPort, &dstIp[0], tcpLayer->dstPort);
            } else {
                hash = hashInv;
            }
        }

        if (tcpLayer->flags & TCP_ACK) {
            StatsHolder_->IncAck(hash);
        }
        if (tcpLayer->flags & TCP_FIN) {
            StatsHolder_->IncFin(hash);
            StatsHolder_->SetTimeEnd(hash);
        }
        if (tcpLayer->flags & TCP_RST) {
            StatsHolder_->IncRst(hash);
            StatsHolder_->SetTimeEnd(hash);
        }
        if (tcpLayer->flags & TCP_SYN) {
            StatsHolder_->SetTimeStart(hash);
            StatsHolder_->IncSyn(hash);
        }
    }
}

/* SIGINT/SIGTERM handler: requests shutdown by clearing the global run flag.
 * It does nothing beyond a single atomic store. */
static void ExitHandler(int /* signum */) {
    CAPTURE_ENABLED.store(false, std::memory_order_relaxed);
}

/* Installs ExitHandler for both SIGINT and SIGTERM. While the handler runs, both
 * signals are blocked. sa_flags is 0, so SA_RESTART is not requested. Failures of
 * sigaction() are not checked. */
static void SetExitHandler() {
    struct sigaction action {};

    sigemptyset(&action.sa_mask);
    sigaddset(&action.sa_mask, SIGINT);
    sigaddset(&action.sa_mask, SIGTERM);
    action.sa_flags = 0;
    action.sa_handler = ExitHandler;

    for (int signo : {SIGINT, SIGTERM}) {
        sigaction(signo, &action, nullptr);
    }
}
