#include "parser.h"
#include <linux/if_ether.h>
#include <netinet/ip.h>
#include <linux/ipv6.h>
#include <linux/tcp.h>

#include <iostream>
using namespace std;

namespace {
    // Walks the TCP options area and collects every option that carries a
    // length byte.
    //
    // opt      - pointer to the first option byte (directly after the fixed
    //            TCP header).
    // optsLen  - number of option bytes available; values <= 0 yield an empty
    //            result.
    //
    // Returns a TTcpOptions in which result[kind] holds the raw payload bytes
    // of the option (the kind and length bytes themselves are not stored).
    // Parsing stops at END-OF-OPTION-LIST, when the declared bytes are used
    // up, or at the first malformed option; options collected up to that
    // point are still returned.
    TTcpOptions GetTcpOptions(ui8* opt, int optsLen) {
        using ETcpOptions = TTcpOptions::ETcpOptions;
        TTcpOptions result;

        while (optsLen > 0) {
            const ui8 optionId{*opt++};
            if (optionId == ETcpOptions::ENDOFOPTIONLIST) {
                break;
            }
            if (optionId == ETcpOptions::NOOPERATION) {
                // NOP is a single padding byte without a length field.
                --optsLen;
                continue;
            }

            // Every other option is at least <kind, length>. Without this
            // guard a trailing lone kind byte would make the length read
            // below step one byte past the options area.
            if (optsLen < 2) {
                break;
            }

            const ui8 currLen{*opt++};
            // currLen counts the kind and length bytes too, so it can be
            // neither below 2 nor larger than what is left.
            if (optsLen < currLen || currLen < 2) {
                break;
            }

            const int payloadLen = currLen - 2;
            result[optionId] = TString((const char*)opt, payloadLen);
            optsLen -= currLen;
            opt += payloadLen;
        }

        return result;
    }

    // Fills the TCP-level fields of `ui` from an initial SYN segment.
    //
    // packet     - the captured packet that contains tcpHeader; used only for
    //              bounds checking.
    // tcpHeader  - pointer to the TCP header inside packet.Data. The caller
    //              must already have verified that the fixed sizeof(tcphdr)
    //              bytes lie within the packet.
    // ui         - partially filled user info (taken by value) that is
    //              completed and returned.
    //
    // Returns std::nullopt when the segment is not a pure SYN (SYN set, ACK
    // clear), when the data-offset field is smaller than the fixed header, or
    // when the declared options do not fit into the captured bytes.
    std::optional<TUserInfo> ParseTcp(const TRawPacket& packet, tcphdr* tcpHeader, TUserInfo ui) {
        if (!tcpHeader->syn || tcpHeader->ack) {
            return std::nullopt;
        }

        // doff is the header length in 32-bit words and includes the fixed
        // header. A smaller value is malformed; without this check the
        // subtraction below would go through unsigned arithmetic and wrap.
        constexpr int fixedHeaderLen = sizeof(tcphdr);
        const int tcpHeaderLen = tcpHeader->doff * 4;
        if (tcpHeaderLen < fixedHeaderLen) {
            return std::nullopt;
        }

        // Bytes captured after the fixed TCP header.
        int bytesLeft = packet.Size - (((ui8*)tcpHeader - packet.Data) + sizeof(tcphdr));
        const int optsLen = tcpHeaderLen - fixedHeaderLen;
        if (optsLen > bytesLeft) {
            return std::nullopt;
        }

        // Ports are stored in network byte order on the wire.
        ui.SrcPort = ntohs(tcpHeader->source);
        ui.DstPort = ntohs(tcpHeader->dest);
        ui.TcpOptions = GetTcpOptions((ui8*)tcpHeader + sizeof(tcphdr), optsLen);
        ui.Reserved = tcpHeader->res1;

        return ui;
    }
}

// Iterates over all packets of one captured block and hands each of them to
// `callback` as a TRawPacket.
//
// block     - block whose Data points at the block descriptor; the packet
//             count and the offset of the first tpacket3_hdr are read from
//             its h1 header.
// callback  - invoked once per packet, in the order the packets are chained
//             through tp_next_offset.
//
// Each TRawPacket carries the link-layer protocol (converted to host byte
// order), the packet type, a pointer to the network-layer header and the
// number of captured bytes starting at that header.
//
// Returns a TBlockInfo initialised with the number of packets in the block.
TBlockInfo ParseBlock(const TCurrentBlock& block, std::function<void(const TRawPacket&)> callback) {
    auto pbd = block.Data;
    ui32 numPkts = pbd->h1.num_pkts;
    auto ppd = (tpacket3_hdr*)((ui8*)pbd + pbd->h1.offset_to_first_pkt);

    for (ui32 i = 0; i < numPkts; ++i) {
        // The sockaddr_ll is read from directly behind the aligned packet header.
        sockaddr_ll* ll = (sockaddr_ll*)((ui8*)ppd + TPACKET_ALIGN(sizeof(tpacket3_hdr)));

        // tp_snaplen is reduced by the distance between tp_mac and tp_net to
        // get the bytes available from the network header on. Clamp to zero
        // so that a capture shorter than that distance cannot wrap around to
        // a huge unsigned size.
        const ui32 linkHeaderLen = ppd->tp_net - ppd->tp_mac;
        const ui32 capturedLen = ppd->tp_snaplen;
        const ui32 netLen = capturedLen > linkHeaderLen ? capturedLen - linkHeaderLen : 0;

        callback(TRawPacket{
            ntohs(ll->sll_protocol),
            ll->sll_pkttype,
            (ui8*)ppd + ppd->tp_net,
            netLen});

        ppd = (tpacket3_hdr*)((ui8*)ppd + ppd->tp_next_offset);
    }

    return {numPkts};
}

// Parses an IPv4 packet and, if it carries an initial TCP SYN, returns the
// addresses, ports and TCP options found in it.
//
// packet - raw packet whose Data points at the IPv4 header and whose Size is
//          the number of bytes available from there.
//
// Returns std::nullopt when the packet is too short, is not TCP, has a
// malformed header length, is a non-first fragment, or is rejected by
// ParseTcp.
std::optional<TUserInfo> ParseV4(const TRawPacket& packet) {
    // The fixed header must be fully present before any of its fields is read.
    if (packet.Size < sizeof(iphdr)) {
        return std::nullopt;
    }

    auto ipHeader = (iphdr*)packet.Data;
    if (ipHeader->protocol != IPPROTO_TCP) {
        return std::nullopt;
    }

    // ihl is the header length in 32-bit words; anything below the fixed
    // header size is malformed.
    const int ipHeaderLen = ipHeader->ihl * 4;
    if (ipHeaderLen < static_cast<int>(sizeof(iphdr))) {
        return std::nullopt;
    }
    if (ipHeaderLen + sizeof(tcphdr) > packet.Size) {
        return std::nullopt;
    }

    // Only the first fragment carries the TCP header; for later fragments the
    // bytes after the IP header are payload and must not be parsed as TCP.
    constexpr unsigned fragmentOffsetMask = 0x1FFF;
    if ((ntohs(ipHeader->frag_off) & fragmentOffsetMask) != 0) {
        return std::nullopt;
    }

    TUserInfo ui;
    ui.AddressFamily = AF_INET;
    ui.Src4.s_addr = ipHeader->saddr;
    ui.Dst4.s_addr = ipHeader->daddr;

    auto tcpHeader = (tcphdr*)(packet.Data + ipHeaderLen);
    return ParseTcp(packet, tcpHeader, std::move(ui));
}

// Parses an IPv6 packet and, if TCP follows the fixed IPv6 header directly
// and the segment is an initial SYN, returns the addresses, ports and TCP
// options found in it.
//
// packet - raw packet whose Data points at the IPv6 header and whose Size is
//          the number of bytes available from there.
//
// Returns std::nullopt when the packet is too short, when the next-header
// field is not TCP (extension headers are not followed), or when ParseTcp
// rejects the segment.
std::optional<TUserInfo> ParseV6(const TRawPacket& packet) {
    // Validate the length first so that no header field is read from a
    // truncated packet.
    if (sizeof(ipv6hdr) + sizeof(tcphdr) > packet.Size) {
        return std::nullopt;
    }

    auto ipHeader = (ipv6hdr*)packet.Data;
    if (ipHeader->nexthdr != IPPROTO_TCP) {
        return std::nullopt;
    }

    TUserInfo ui;
    ui.AddressFamily = AF_INET6;
    ui.Src6 = ipHeader->saddr;
    ui.Dst6 = ipHeader->daddr;

    auto tcpHeader = (tcphdr*)(packet.Data + sizeof(ipv6hdr));
    return ParseTcp(packet, tcpHeader, std::move(ui));
}
