#pragma once

#include <infra/netmon/agent/common/io.h>
#include <infra/netmon/agent/common/buf.h>
#include <infra/netmon/agent/common/utils.h>
#include <infra/netmon/agent/common/packet.h>

#include <util/network/address.h>
#include <util/memory/smallobj.h>
#include <util/generic/xrange.h>
#include <util/generic/list.h>

#include <netinet/ip6.h>
#include <netinet/in.h>
#include <netinet/udp.h>

#include <array>

namespace NNetmon {
    namespace {
        /*
            Expected first 32-bit word of the IPv6 header of a netmon link probe
            (host byte order):
                bits 31..28  version       = 6
                bits 27..20  traffic class = 0 (TLinkPacket::Fill ORs the ToS in here)
                bits 19..0   flow label    = 0x07255
            Validation masks with 0xf000ffff, i.e. only the version nibble and the
            low 16 flow-label bits are compared.
        */
        constexpr ui32 NETMON_LINK_MAGIC = 0x60007255u;
    }

    /*
        Link-poller probe frame, built by hand for a raw socket.

        Frame layout written by Fill() and checked by Validate():
            struct ethhdr | struct ip6_hdr | struct udphdr | UDP payload

        UDP payload (TSettings::Get()->GetLinkPollerPacketSize() bytes):
            +0   ui16  PTP marker 0x0002, network byte order
            +2   ui64  probe id, copied from Stats_.ProbeId without byte swapping
            +10  rest, written by TPacket::FillData()
                 (NOTE(review): assumes FillData starts at GetDataOffset() — confirm)
    */
    class TLinkPacket : public TPacket, public TIntrusiveListItem<TLinkPacket>, public TObjectFromPool<TLinkPacket> {
    public:
        // Send/receive timestamp pair. The unit is whatever the producer stores
        // (NOTE(review): not established in this header — confirm against the .cpp).
        struct TTimestamps {
            ui64 SentTime = 0ull;
            ui64 RecvTime = 0ull;
        };

        using TRef = THolder<TLinkPacket>;
        using TListType = TIntrusiveListWithAutoDelete<TLinkPacket, TDelete>;

        using TPacket::TPacket;

        // Allocates a packet from `pool` and forwards `args` to a TLinkPacket constructor.
        template <typename... Args>
        static inline TRef Make(TPool& pool, Args&&... args) {
            return TRef(new (&pool) TLinkPacket(std::forward<Args>(args)...));
        }

        /*
            Creates a probe srcIp:srcPort (srcMac) -> dstIp:dstPort (dstMac).
            dstIp and tos are handed to the TPacket base. The other addresses are
            copied into fixed-size sockaddr_storage slots; each copy is guarded by
            a Y_VERIFY that the address fits.
        */
        TLinkPacket(TOnDemandBuffer::TPool& pool,
                    const NAddr::IRemoteAddr& srcIp,
                    ui16 srcPort,
                    const NAddr::IRemoteAddr& srcMac,
                    const NAddr::IRemoteAddr& dstIp,
                    ui16 dstPort,
                    const NAddr::IRemoteAddr& dstMac,
                    int tos = 0)
            : TPacket(pool, dstIp, tos)
            , SrcPort_(srcPort)
            , DstPort_(dstPort)
        {
            Y_VERIFY(SrcIpCapacity() >= srcIp.Len());
            ::memcpy(SrcIpData(), srcIp.Addr(), srcIp.Len());

            Y_VERIFY(SrcMacCapacity() >= srcMac.Len());
            ::memcpy(SrcMacData(), srcMac.Addr(), srcMac.Len());

            Y_VERIFY(DstMacCapacity() >= dstMac.Len());
            ::memcpy(DstMacData(), dstMac.Addr(), dstMac.Len());
        }

        /*
            Reads the probe id straight from a raw captured frame.
            Returns false, leaving probeId untouched, when `len` is shorter than a
            full probe frame or too short to contain the id itself.
        */
        static bool ExtractProbeIdFromRaw(void* data, size_t len, ui64& probeId) {
            if (len < HeadersSize() + PayloadSize() || len < ProbeIdOffset() + sizeof(probeId)) {
                return false;
            }

            // The id may sit at any alignment inside a capture buffer: copy, don't dereference.
            ::memcpy(&probeId, static_cast<const ui8*>(data) + ProbeIdOffset(), sizeof(probeId));
            return true;
        }

        // IPv6 pseudo-header that prefixes the UDP checksum input (RFC 8200, section 8.1).
        struct PseudoHeader6 {
            ui8 src[16];
            ui8 dst[16];
            ui32 len;
            ui8 zeroes[3];
            ui8 next;
        } __attribute__((packed));

        static_assert(sizeof(PseudoHeader6) == 40, "IPv6 pseudo-header must be exactly 40 bytes");

        /*
            Computes the UDP checksum over the pseudo-header, the UDP header and the
            payload that follows `udp`. The addresses come from SrcIpData() and
            ClientAddrData(). Expects udp->check to be zero at call time.
        */
        inline ui16 Checksum(const struct udphdr* udp) {
            const size_t udpLength = sizeof(struct udphdr) + PayloadSize();
            PseudoHeader6 psh;

            const struct sockaddr_in6* sin6_src = reinterpret_cast<const struct sockaddr_in6*>(SrcIpData());
            const struct sockaddr_in6* sin6_dst = reinterpret_cast<const struct sockaddr_in6*>(ClientAddrData());
            ::memcpy(&psh.src, &sin6_src->sin6_addr, sizeof(struct in6_addr));
            ::memcpy(&psh.dst, &sin6_dst->sin6_addr, sizeof(struct in6_addr));

            psh.len = HostToInet(ui32(udpLength));
            ::memset(&psh.zeroes, 0, sizeof(psh.zeroes));
            psh.next = IPPROTO_UDP;

            return TPacket::HdrBodyChecksum(reinterpret_cast<ui16*>(&psh), sizeof(psh),
                                            reinterpret_cast<const ui16*>(udp), udpLength,
                                            0);
        }

        // Offset of the first byte after the PTP marker and the probe id.
        virtual size_t GetDataOffset() const override {
            return ProbeIdOffset() + sizeof(ui64);
        }

        /*
            Serializes the whole frame into Data() and sets Length_.
            The IPv6 hop limit starts at 2; Validate() expects the switch to have
            decremented it to 1. The ToS is folded into the IPv6 traffic-class bits.
        */
        inline void Fill() noexcept {
            const size_t udpLength = sizeof(struct udphdr) + PayloadSize();
            Length_ = sizeof(struct ethhdr) + sizeof(struct ip6_hdr) + udpLength;

            ui8* ptr = reinterpret_cast<ui8*>(Data());

            // Ethernet header.
            struct ethhdr* const eth = reinterpret_cast<struct ethhdr*>(ptr);
            const struct sockaddr_ll* ll_src = reinterpret_cast<const struct sockaddr_ll*>(SrcMacData());
            const struct sockaddr_ll* ll_dst = reinterpret_cast<const struct sockaddr_ll*>(DstMacData());
            ::memcpy(eth->h_source, ll_src->sll_addr, ETH_ALEN);
            ::memcpy(eth->h_dest, ll_dst->sll_addr, ETH_ALEN);
            eth->h_proto = htons(ETH_P_IPV6);
            ptr += sizeof(struct ethhdr);

            // IPv6 header.
            struct ip6_hdr* const ip6 = reinterpret_cast<struct ip6_hdr*>(ptr);
            ::memset(ip6, 0, sizeof(struct ip6_hdr));
            ip6->ip6_ctlun.ip6_un1.ip6_un1_nxt = IPPROTO_UDP;
            ip6->ip6_ctlun.ip6_un1.ip6_un1_hlim = 2;
            ip6->ip6_ctlun.ip6_un1.ip6_un1_flow = htonl(NETMON_LINK_MAGIC | ((TypeOfService_ << 20) & 0x0ff00000));
            ip6->ip6_ctlun.ip6_un1.ip6_un1_plen = htons(static_cast<ui16>(udpLength));

            const struct sockaddr_in6* sin6_src = reinterpret_cast<const struct sockaddr_in6*>(SrcIpData());
            const struct sockaddr_in6* sin6_dst = reinterpret_cast<const struct sockaddr_in6*>(ClientAddrData());
            ::memcpy(&ip6->ip6_src, &sin6_src->sin6_addr, sizeof(struct in6_addr));
            ::memcpy(&ip6->ip6_dst, &sin6_dst->sin6_addr, sizeof(struct in6_addr));
            ptr += sizeof(struct ip6_hdr);

            // UDP header; the checksum is computed last, once the payload is in place.
            struct udphdr* const udp = reinterpret_cast<struct udphdr*>(ptr);
            udp->source = htons(SrcPort_);
            udp->dest = htons(DstPort_);
            udp->check = 0;
            udp->len = htons(static_cast<ui16>(udpLength));
            ptr += sizeof(struct udphdr);

            /*
                Magic constant that lets ixgbe match packet
                (along with UDP L4 proto on 319 port)
                as PTPv2 datagram and report hw timestamp
            */
            const ui16 ptpMarker = htons(0x0002);
            ::memcpy(ptr, &ptpMarker, sizeof(ptpMarker));
            ptr += sizeof(ptpMarker);

            // Stored via memcpy: the offset is only aligned if Data() happens to be.
            const ui64 probeId = Stats_.ProbeId;
            ::memcpy(ptr, &probeId, sizeof(probeId));

            FillData();

            udp->check = Checksum(udp);
        }

        /*
            Checks that the received frame in Data() is the expected probe and, if so,
            stores the probe id it carries into Stats().ProbeId.
            Returns true (and sets Valid()) only when the frame matches and its id
            equals `probeId`. Valid() is reset to false on every call.
        */
        bool Validate(ui64 probeId) noexcept {
            Valid_ = false;

            if (Length_ != HeadersSize() + PayloadSize() || Length_ < ProbeIdOffset() + sizeof(ui64)) {
                return false;
            }

            const struct sockaddr_in6* sin6_src = reinterpret_cast<const struct sockaddr_in6*>(SrcIpData());
            const struct sockaddr_in6* sin6_dst = reinterpret_cast<const struct sockaddr_in6*>(ClientAddrData());

            const ui8* ptr = reinterpret_cast<const ui8*>(Data());
            const struct ethhdr* eth = reinterpret_cast<const struct ethhdr*>(ptr);
            const struct ip6_hdr* ip6 = reinterpret_cast<const struct ip6_hdr*>(ptr + sizeof(struct ethhdr));

            /*
                Incoming packet validation and parsing (along with socket eBPF):
                1) eth protocol is ETH_P_IPV6
                2) src and dst ipv6 addrs are correct
                3) IPv6 Next Header contains IPPROTO_UDP
                4) TTL == 1 (Switch should decrease initial value 2 by one)
                5) IPv6 version and flowid matches magic.
                   We should keep in mind possible tclass reassignment on switch,
                   so just mask its value
            */
            const bool headersMatch =
                eth->h_proto == htons(ETH_P_IPV6) &&
                !::memcmp(&ip6->ip6_src, &sin6_src->sin6_addr, sizeof(struct in6_addr)) &&
                !::memcmp(&ip6->ip6_dst, &sin6_dst->sin6_addr, sizeof(struct in6_addr)) &&
                ip6->ip6_ctlun.ip6_un1.ip6_un1_nxt == IPPROTO_UDP &&
                ip6->ip6_ctlun.ip6_un1.ip6_un1_hlim == 1 &&
                (ntohl(ip6->ip6_ctlun.ip6_un1.ip6_un1_flow) & 0xf000ffff) == NETMON_LINK_MAGIC;

            if (!headersMatch) {
                return false;
            }

            ui64 receivedId = 0;
            ::memcpy(&receivedId, ptr + ProbeIdOffset(), sizeof(receivedId));
            Stats().ProbeId = receivedId;
            Valid_ = Stats().ProbeId == probeId;

            return Valid_;
        }

        // Source IPv6 address slot (filled from `srcIp` in the constructor).
        inline sockaddr* SrcIpData() noexcept {
            return reinterpret_cast<sockaddr*>(&SrcIp_);
        }
        inline const sockaddr* SrcIpData() const noexcept {
            return reinterpret_cast<const sockaddr*>(&SrcIp_);
        }

        inline std::size_t SrcIpCapacity() const noexcept {
            return sizeof(SrcIp_);
        }

        // Source link-layer address slot, read as sockaddr_ll by Fill().
        inline sockaddr* SrcMacData() noexcept {
            return reinterpret_cast<sockaddr*>(&SrcMac_);
        }
        inline const sockaddr* SrcMacData() const noexcept {
            return reinterpret_cast<const sockaddr*>(&SrcMac_);
        }

        inline std::size_t SrcMacCapacity() const noexcept {
            return sizeof(SrcMac_);
        }

        // Destination link-layer address slot, read as sockaddr_ll by Fill().
        inline sockaddr* DstMacData() noexcept {
            return reinterpret_cast<sockaddr*>(&DstMac_);
        }
        inline const sockaddr* DstMacData() const noexcept {
            return reinterpret_cast<const sockaddr*>(&DstMac_);
        }

        inline std::size_t DstMacCapacity() const noexcept {
            return sizeof(DstMac_);
        }

        // Result of the most recent Validate() call (false before any call).
        inline bool Valid() const noexcept {
            return Valid_;
        }

        // NOTE(review): Hw/Sys/User presumably mean NIC, kernel and user-space
        // timestamps; they are only stored here — confirm with the IO handler.
        inline TTimestamps& HwTimestamps() noexcept {
            return HwTimestamps_;
        }
        inline const TTimestamps& HwTimestamps() const noexcept {
            return HwTimestamps_;
        }

        inline TTimestamps& SysTimestamps() noexcept {
            return SysTimestamps_;
        }
        inline const TTimestamps& SysTimestamps() const noexcept {
            return SysTimestamps_;
        }

        inline TTimestamps& UserTimestamps() noexcept {
            return UserTimestamps_;
        }
        inline const TTimestamps& UserTimestamps() const noexcept {
            return UserTimestamps_;
        }

    private:
        // Size of the Ethernet + IPv6 + UDP headers in front of the UDP payload.
        static constexpr size_t HeadersSize() noexcept {
            return sizeof(struct ethhdr) + sizeof(struct ip6_hdr) + sizeof(struct udphdr);
        }

        // Frame offset of the ui64 probe id: headers plus the 2-byte PTP marker.
        static constexpr size_t ProbeIdOffset() noexcept {
            return HeadersSize() + sizeof(ui16);
        }

        // Configured UDP payload size for link probes.
        static inline size_t PayloadSize() {
            return TSettings::Get()->GetLinkPollerPacketSize();
        }

        // Zero-initialized so bytes beyond the copied address length are defined.
        struct sockaddr_storage SrcIp_{};
        struct sockaddr_storage SrcMac_{};
        struct sockaddr_storage DstMac_{};

        // Initializers keep the inherited TPacket constructors well-formed.
        const ui16 SrcPort_ = 0;
        const ui16 DstPort_ = 0;

        bool Valid_ = false;

        TTimestamps HwTimestamps_;
        TTimestamps SysTimestamps_;
        TTimestamps UserTimestamps_;
    };

    /*
        I/O endpoint for TLinkPacket probes. This header only declares the
        interface; all behaviour lives in the out-of-line TImpl (pimpl), which is
        not visible here. Non-copyable via TNonCopyable.
    */
    class TLinkIOHandler : public TNonCopyable {
    public:
        // Packet type handled by this endpoint. NOTE(review): presumably lets generic
        // poller code name the packet type as Handler::TPacket — confirm at call sites.
        using TPacket = TLinkPacket;

        // NOTE(review): srcIp is presumably the local IPv6 address used as probe
        // source and addr the peer/interface address — confirm in the .cpp.
        TLinkIOHandler(TLog& logger,
                       const NAddr::IRemoteAddrRef& addr,
                       const NAddr::TIPv6Addr& srcIp);
        // Declared out of line so THolder<TImpl> is destroyed where TImpl is complete.
        ~TLinkIOHandler();

        // Address this handler was constructed or last reopened with (presumably).
        const NAddr::IRemoteAddrRef& Addr() const noexcept;

        // NOTE(review): GetTxTs presumably fetches the transmit timestamp of the
        // probe `probeId` into `packet`; Read/Write receive/send one packet.
        // All three report the outcome through TMaybeIOStatus.
        TMaybeIOStatus GetTxTs(ui64 probeId, TLinkPacket& packet) noexcept;
        TMaybeIOStatus Read(TLinkPacket& packet) noexcept;
        TMaybeIOStatus Write(TLinkPacket& packet) noexcept;

        // Waits for readiness until `deadline`. PollD takes a coroutine (TCont*) and an
        // operation filter; PollDWithoutYield takes a timeout instead of a coroutine
        // (NOTE(review): exact semantics are in the .cpp — confirm).
        void PollD(TCont* cont, ui16 opFilter, const TInstant& deadline);
        void PollDWithoutYield(const TDuration& timeout, const TInstant& deadline);

        void Close() noexcept;

        // Re-initializes the handler with a new address pair (same parameters as the constructor).
        void Reopen(const NAddr::IRemoteAddrRef& addr, const NAddr::TIPv6Addr& srcIp);

    private:
        class TImpl;
        THolder<TImpl> Impl;
    };
}


