#include "handler.h"

#include <infra/netmon/agent/common/metrics.h>
#include <infra/netmon/agent/common/utils.h>

#include <util/system/platform.h>
#include <util/random/random.h>
#include <util/generic/list.h>

#include <netinet/tcp.h>
#include <netinet/ip.h>
#include <netinet/ip6.h>

#ifdef _linux_
#   include <linux/filter.h> // sock_filter and sock_fprog
#endif

namespace NNetmon {
    class TTcpSocketAdapter4;
    class TTcpSocketAdapter6;

    /// Common part of the raw-socket TCP probe transport.
    ///
    /// Owns a raw IPPROTO_TCP socket bound to `addr` (exposed through operator SOCKET())
    /// and one listening SOCK_STREAM socket per port of the inclusive range
    /// [minPort, maxPort] on the same address. Family-specific subclasses attach the
    /// BPF filter and implement received-packet validation and checksum filling.
    class TTcpSocketAdapter : public TNonCopyable {
    public:
        /// Creates the adapter matching addr's family (AF_INET or AF_INET6).
        /// @throws yexception for any other address family.
        static THolder<TTcpSocketAdapter> Make(const NAddr::IRemoteAddrRef& addr, ui16 minPort, ui16 maxPort);

        /// @param addr     local address for the raw socket and for every listening socket
        /// @param domain   domain passed to socket(2) (AF_INET / AF_INET6)
        /// @param minPort  first listening port, inclusive
        /// @param maxPort  last listening port, inclusive (no listening sockets if minPort > maxPort)
        /// @throws TSystemError if a socket cannot be created, bound, or put into listening state
        inline TTcpSocketAdapter(const NAddr::IRemoteAddrRef& addr, int domain, ui16 minPort, ui16 maxPort)
            : Addr_(addr)
            , Socket_(::socket(domain, SOCK_RAW, IPPROTO_TCP))
            , MinPort_(minPort)
            , MaxPort_(maxPort)
        {
            if (Socket_ == INVALID_SOCKET) {
                ythrow TSystemError() << TStringBuf("can't create socket for pinger");
            }

            SetNonBlock(Socket_);

#ifdef SO_RXQ_OVFL // since Linux 2.6.33
            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_RXQ_OVFL, 1, "overflow detection");
#endif

#ifdef SO_TIMESTAMPNS // only on Linux
            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_TIMESTAMPNS, 1, "kernel timestamps");
#endif

            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_RCVBUF, MAX_PACKET_LENGTH * 256, "recv buffer");
            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_SNDBUF, MAX_PACKET_LENGTH * 256, "send buffer");

            EnableErrorQueue(Addr_->Addr()->sa_family, Socket_);

            if (bind(Socket_, Addr_->Addr(), Addr_->Len()) < 0) {
                ythrow TSystemError() << TStringBuf("bind to ") << NAddr::PrintHost(*Addr_) << TStringBuf(" failed");
            }

            // NOTE(review): why these listening sockets are needed is not visible here;
            // presumably they reserve the probe ports on this address. TODO confirm.
            //
            // The counter is deliberately wider than ui16. With a ui16 counter and
            // maxPort == 65535, the increment wraps to 0, `port <= maxPort` stays true,
            // and the loop keeps opening sockets on ports 0, 1, ... until some call fails.
            for (int port = minPort; port <= maxPort; ++port) {
                BindedSockets_.emplace_back(::socket(domain, SOCK_STREAM, IPPROTO_TCP));
                auto& sk = BindedSockets_.back();
                if (sk == INVALID_SOCKET) {
                    ythrow TSystemError() << TStringBuf("can't create tcp listen sockets");
                }

                SetReusePort(sk, true);

                auto withPort = SetPort(*Addr_, static_cast<ui16>(port));

                if (bind(sk, withPort->Addr(), withPort->Len())) {
                    ythrow TSystemError() << TStringBuf("can't bind listen tcp sockets");
                }

                if (::listen(sk, 0)) {
                    ythrow TSystemError() << TStringBuf("can't listen on tcp sockets");
                }
            }
        }

        virtual ~TTcpSocketAdapter() = default;

        /// Local address the sockets are bound to.
        inline const NAddr::IRemoteAddrRef& Addr() const noexcept {
            return Addr_;
        }

        /// Closes the raw socket and every listening socket.
        void Close() noexcept {
            Socket_.Close();

            for (auto& sk : BindedSockets_) {
                sk.Close();
            }
        }

        /// Parses and checks a received packet in place (type, ports, stats, client port).
        /// @return false if the packet must be ignored.
        virtual bool ValidatePacket(TTcpPacket& packet) noexcept = 0;

        /// The raw TCP socket.
        inline operator SOCKET() const noexcept {
            return Socket_;
        }

        /// Fills the TCP checksum of an outgoing packet (a no-op where the kernel does it).
        virtual void FillChecksum(TTcpPacket& packet) const = 0;

    protected:
        NAddr::IRemoteAddrRef Addr_;         // local bind address
        TSocketHolder Socket_;               // raw IPPROTO_TCP socket
        TList<TSocketHolder> BindedSockets_; // one listening socket per port of [MinPort_, MaxPort_]
        ui16 MinPort_;                       // inclusive lower bound of the probe port range
        ui16 MaxPort_;                       // inclusive upper bound of the probe port range
    };


    class TTcpSocketAdapter6 : public TTcpSocketAdapter {
    public:
        inline TTcpSocketAdapter6(const NAddr::IRemoteAddrRef& addr, ui16 minPort, ui16 maxPort)
            : TTcpSocketAdapter(addr, AF_INET6, minPort, maxPort)
        {
#if defined(_linux_)
            /* Command to generate:
               $ tcpdump -i lo '(ether[4:4] == 0x12345678) and \
                                (ether[2:2] >= 1000 and ether[2:2] <= 2000) and \
                                (ether[0xd:1] == 0x14 or ether[0xd:1] == 0x06)' -dd */

            struct sock_filter code[] = {
                { 0x20, 0, 0, 0x00000004 }, // shift to tcp.seq
                { 0x15, 0, 7, NETMON_TCP_MAGIC }, // tcp.seq == NETMON_TCP_MAGIC
                { 0x28, 0, 0, 0x00000002 }, // shift tp tcp.dport
                { 0x35, 0, 5, minPort }, // tcp.dport >= minPort
                { 0x25, 4, 0, maxPort }, // and tcp.dport <= maxPort
                { 0x30, 0, 0, 0x0000000d }, // shift to tcp.flags
                { 0x15, 1, 0, 0x00000014 }, // tcp.flags == ACK, RST
                { 0x15, 0, 1, 0x00000006 }, // or tcp.flags == SYN, RST
                { 0x6, 0, 0, 0x00040000 },
                { 0x6, 0, 0, 0x00000000 },
            };
            struct sock_fprog bpf = {
                .len = sizeof(code) / sizeof(code[0]),
                .filter = code,
            };

            // Attach created code to our socket
            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_ATTACH_FILTER, bpf, "attach BPF");
#else
#   error NO BPF filtering support, any sane OS does.
#endif

            const int cksumOffset = offsetof(struct tcphdr, check);
#if defined SOL_RAW && defined IPV6_CHECKSUM // @Linux
            CheckedSetSockOpt(Socket_, SOL_RAW, IPV6_CHECKSUM, cksumOffset, "enable checksum");
#elif defined IPPROTO_IPV6 && defined IPV6_CHECKSUM
            CheckedSetSockOpt(Socket_, IPPROTO_IPV6, IPV6_CHECKSUM, cksumOffset, "enable checksum");
#else
#   error NO IPv6 checksum offset support
#endif

            DoNotFragment(AF_INET6, Socket_);
            CheckedSetSockOpt(Socket_, IPPROTO_IPV6, IPV6_RECVTCLASS, 1, "ipv6 recv tos");
        }


        void FillChecksum(TTcpPacket &packet) const override {
            //Luckily, we have IPV6_CHECKSUM working for us
            Y_UNUSED(packet);
        }

        bool ValidatePacket(TTcpPacket& packet) noexcept override {
            if (packet.Size() < sizeof(tcphdr)) {
                return false;
            }

            const tcphdr* tcp = reinterpret_cast<const tcphdr*>(packet.Data());

            ui16 dport = InetToHost(tcp->dest);

            if (dport < MinPort_ || dport > MaxPort_)
                return false;

            if (tcp->rst && tcp->syn && !tcp->ack) {
                packet.Type() = TTcpPacket::EType::Request;
            } else if (tcp->rst && !tcp->syn && tcp->ack){
                packet.Type() = TTcpPacket::EType::Reply;
            } else {
                return false;
            }

            packet.DstPort() = InetToHost(tcp->dest);
            packet.SrcPort() = InetToHost(tcp->source);

            TMemoryInput stream((ui8 *)(packet.Data()) + sizeof(tcphdr),
                                packet.Size() - sizeof(tcphdr));

            packet.Stats().Load(&stream);

            SetPort(packet.ClientAddrData(), packet.SrcPort());

            return !packet.Truncated();
        }
    };

    class TTcpSocketAdapter4 : public TTcpSocketAdapter {
    public:
        inline TTcpSocketAdapter4(const NAddr::IRemoteAddrRef& addr, ui16 minPort, ui16 maxPort)
            : TTcpSocketAdapter(addr, AF_INET, minPort, maxPort)
        {
#if defined(_linux_)
            /* Same as SocketAdapter6, but offsets include ip header.
               See above description for BPF command annotation.
               Command to generate:

               $ tcpdump -i lo '(ether[24:4] == 0x12345678) and \
                                (ether[22:2] >= 1000 and ether[22:2] <= 2000) and \
                                (ether[33:1] == 0x14 or ether[33:1] == 0x06)' -dd */

            struct sock_filter code[] = {
                { 0x20, 0, 0, 0x00000018 },
                { 0x15, 0, 7, NETMON_TCP_MAGIC },
                { 0x28, 0, 0, 0x00000016 },
                { 0x35, 0, 5, minPort },
                { 0x25, 4, 0, maxPort },
                { 0x30, 0, 0, 0x00000021 },
                { 0x15, 1, 0, 0x00000014 },
                { 0x15, 0, 1, 0x00000006 },
                { 0x6, 0, 0, 0x00040000 },
                { 0x6, 0, 0, 0x00000000 },

            };
            struct sock_fprog bpf = {
                .len = sizeof(code) / sizeof(code[0]),
                .filter = code,
            };

            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_ATTACH_FILTER, bpf, "attach BPF");
            DoNotFragment(AF_INET, Socket_);
#else
#   error NO BPF filtering support, any sane OS does.
#endif
        }

        struct PseudoHeader4 {
            ui32 src;
            ui32 dst;
            ui8 pad;
            ui8 proto;
            ui16 len;
        } __attribute__((packed));

        void FillChecksum(TTcpPacket &packet) const override {
            // This time, we need TCP/IPv4 pseudo-header...
            PseudoHeader4 psh;

            psh.src = TIpAddress(*reinterpret_cast<const sockaddr_in*>(
                                 Addr_->Addr())).Host();
            psh.dst = TIpAddress(*reinterpret_cast<const sockaddr_in*>(
                                 packet.ClientAddrData())).Host();
            psh.pad = 0x0;
            psh.proto = 0x6;
            psh.len = HostToInet(ui16(packet.Size()));

            ui16 csum = 0;

            csum = TPacket::HdrBodyChecksum(reinterpret_cast<ui16*>(&psh), sizeof(psh),
                                            reinterpret_cast<ui16*>(packet.Data()), packet.Size(),
                                            0);
            packet.FillChecksum(csum);
        }

        bool ValidatePacket(TTcpPacket& packet) noexcept override {
            std::size_t offset = sizeof(tcphdr) + sizeof(iphdr);
            if (packet.Size() < offset) {
                return false;
            }

            const tcphdr* tcp = reinterpret_cast<const tcphdr*>((ui8*)packet.Data() + sizeof(iphdr));

            ui16 dport = InetToHost(tcp->dest);

            if (dport < MinPort_ || dport > MaxPort_)
                return false;

            if (tcp->rst && tcp->syn && !tcp->ack) {
                packet.Type() = TTcpPacket::EType::Request;
            } else if (tcp->rst && !tcp->syn && tcp->ack){
                packet.Type() = TTcpPacket::EType::Reply;
            } else {
                return false;
            }

            packet.DstPort() = InetToHost(tcp->dest);
            packet.SrcPort() = InetToHost(tcp->source);

            TMemoryInput stream((ui8 *)(packet.Data()) + offset,
                                packet.Size() - offset);

            packet.Stats().Load(&stream);

            SetPort(packet.ClientAddrData(), packet.SrcPort());

            return !packet.Truncated();
        }
    };

    /// Picks the adapter implementation by the address family of `addr`.
    /// @throws yexception if the family is neither AF_INET nor AF_INET6.
    THolder<TTcpSocketAdapter> TTcpSocketAdapter::Make(const NAddr::IRemoteAddrRef& addr, ui16 minPort, ui16 maxPort) {
        const auto family = addr->Addr()->sa_family;

        if (family == AF_INET) {
            return MakeHolder<TTcpSocketAdapter4>(addr, minPort, maxPort);
        }
        if (family == AF_INET6) {
            return MakeHolder<TTcpSocketAdapter6>(addr, minPort, maxPort);
        }

        ythrow yexception() << TStringBuf("unknown family given");
    }

    /// Implementation behind TTcpIOHandler: a family-specific TTcpSocketAdapter
    /// (sockets, BPF filter, packet validation, checksums) plus a TIOHandler built
    /// on top of it that performs the actual reads, writes and polling.
    class TTcpIOHandler::TImpl : public TNonCopyable {
    public:
        /// Propagates whatever TTcpSocketAdapter::Make throws (unsupported family, socket errors).
        TImpl(TLog& logger, const NAddr::IRemoteAddrRef& addr, ui16 minPort, ui16 maxPort)
            : Logger_(logger)
            , Adapter_(TTcpSocketAdapter::Make(addr, minPort, maxPort))
            , Handler_(Logger_, *Adapter_) // Adapter_ is declared, hence built, before Handler_
        {
        }

        ~TImpl() {
            // The body runs before members are destroyed, so the sockets are
            // closed while Handler_ still exists.
            if (TTcpSocketAdapter* const adapter = Adapter_.Get()) {
                adapter->Close();
            }
        }

        inline const NAddr::IRemoteAddrRef& Addr() const noexcept {
            return Adapter_->Addr();
        }

        /// Reads through the handler. Non-empty data that fails the adapter's validation
        /// is reported as Nothing(); for a processed read the handler's receive time is
        /// recorded in the packet stats.
        inline TMaybeIOStatus Read(TTcpPacket& packet) noexcept {
            auto status = Handler_.Read(packet);

            if (packet.Size()) {
                if (!Adapter_->ValidatePacket(packet)) {
                    return Nothing();
                }
            }

            const bool processed = status.Defined() && status->Processed();
            if (processed) {
                packet.Stats().SourceReceivedTime = Handler_.RcvdTime();
            }
            return status;
        }

        /// Applies the packet's TTL only when it differs from the last one applied,
        /// serializes the packet, lets the adapter fill the checksum and sends it.
        inline TMaybeIOStatus Write(TTcpPacket& packet) noexcept {
            const auto ttl = packet.TimeToLive();
            if (ttl != LastTimeToLive) {
                SetTimeToLive(Addr()->Addr()->sa_family, *Adapter_, ttl);
                LastTimeToLive = ttl;
            }

            packet.Fill();
            Adapter_->FillChecksum(packet);
            return Handler_.Write(packet);
        }

        inline void PollD(TCont* cont, ui16 opFilter, const TInstant& deadline) {
            Handler_.PollD(cont, opFilter, deadline);
        }

        inline void PollT(TCont* cont, ui16 opFilter, const TDuration& duration) {
            Handler_.PollT(cont, opFilter, duration);
        }

    private:
        TLog& Logger_;
        THolder<TTcpSocketAdapter> Adapter_; // must stay declared before Handler_ (see constructor)
        TIOHandler Handler_;
        i32 LastTimeToLive = -1;             // TTL last applied to the socket; -1 until the first Write
    };

    /// Builds the implementation (adapter + I/O handler) and keeps the inclusive port range.
    TTcpIOHandler::TTcpIOHandler(TLog& logger, const NAddr::IRemoteAddrRef& addr, ui16 minPort, ui16 maxPort)
        : Impl(MakeHolder<TImpl>(logger, addr, minPort, maxPort))
        , MinPort_(minPort)
        , MaxPort_(maxPort)
    {
    }

    // Defined here, presumably so THolder<TImpl> is destroyed where TImpl is complete.
    TTcpIOHandler::~TTcpIOHandler() = default;

    // Each forwarder below fails Y_VERIFY(Impl) once Close() has destroyed the implementation.

    const NAddr::IRemoteAddrRef& TTcpIOHandler::Addr() const noexcept {
        Y_VERIFY(Impl);
        const auto& impl = *Impl;
        return impl.Addr();
    }

    TMaybeIOStatus TTcpIOHandler::Read(TTcpPacket& packet) noexcept {
        Y_VERIFY(Impl);
        auto& impl = *Impl;
        return impl.Read(packet);
    }

    TMaybeIOStatus TTcpIOHandler::Write(TTcpPacket& packet) noexcept {
        Y_VERIFY(Impl);
        auto& impl = *Impl;
        return impl.Write(packet);
    }

    void TTcpIOHandler::PollD(TCont* cont, ui16 opFilter, const TInstant& deadline) {
        Y_VERIFY(Impl);
        auto& impl = *Impl;
        impl.PollD(cont, opFilter, deadline);
    }

    void TTcpIOHandler::PollT(TCont* cont, ui16 opFilter, const TDuration& duration) {
        Y_VERIFY(Impl);
        auto& impl = *Impl;
        impl.PollT(cont, opFilter, duration);
    }

    ui32 TTcpIOHandler::MaxProbeId() noexcept {
        // Upper bound for probe ids, the product of:
        //  - at most 65536 <dst ip>:<port> target pairs (by default one random port per IP);
        //  - 32 source ports (the default);
        //  - 8 possible DSCP classes.
        constexpr ui32 targetPairs = 65536;
        constexpr ui32 sourcePorts = 32;
        constexpr ui32 dscpClasses = 8;
        return sourcePorts * targetPairs * dscpClasses; // == 2^24
    }

    void TTcpIOHandler::Close() noexcept {
        // Destroying the implementation closes the adapter's sockets (~TImpl calls
        // Adapter_->Close()). Any later call on this handler, including a second
        // Close(), fails Y_VERIFY(Impl).
        Y_VERIFY(Impl);
        Impl.Destroy();
    }
}
