#include "handler.h"
#include "headers.h"

#include <infra/netmon/agent/common/metrics.h>
#include <infra/netmon/agent/common/probe.h>

#include <util/generic/xrange.h>
#include <util/system/platform.h>
#include <util/random/random.h>

#include <netinet/icmp6.h>
#ifdef _linux_
#   include <linux/icmp.h> // ICMP_FILTER, instead of netinet/ip_icmp.h
#   include <linux/filter.h> // sock_filter and sock_fprog
#else
#   include <netinet/ip_icmp.h>
#endif

namespace NNetmon {
    namespace {
        // The 16-bit ICMP echo identifier is split in two parts:
        //   * bits covered by IDENT_MASK (upper 12 bits) identify one socket adapter;
        //   * bits covered by MAX_IDENT (lower 4 bits) carry the probe id.
        constexpr ui16 MAX_IDENT = 0xf;
        constexpr ui16 IDENT_MASK = 0xfff0;

        // Returns a random, non-zero identifier with the MAX_IDENT bits cleared, so a
        // probe id in [0, MAX_IDENT] can later be added to it without carrying.
        inline ui16 GenerateIdent() noexcept {
            ui16 ident;
            do {
                ident = static_cast<ui16>(RandomNumber<ui16>() & IDENT_MASK);
            } while (ident == 0);
            return ident;
        }
    }

    class TSocketAdapter4;
    class TSocketAdapter6;

    class TSocketAdapter : public TNonCopyable {
    public:
        static THolder<TSocketAdapter> Make(const NAddr::IRemoteAddrRef& addr);

        inline TSocketAdapter(const NAddr::IRemoteAddrRef& addr, int domain, int type, int protocol)
            : Addr_(addr)
            , Socket_(::socket(domain, type, protocol))
            , Ident_(GenerateIdent())
        {
            if (Socket_ == INVALID_SOCKET) {
                ythrow TSystemError() << TStringBuf("can't create socket for pinger");
            }

            SetNonBlock(Socket_);

#ifdef SO_RXQ_OVFL // since Linux 2.6.33
            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_RXQ_OVFL, 1, "overflow detection");
#endif

#ifdef SO_TIMESTAMPNS // only on Linux
            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_TIMESTAMPNS, 1, "kernel timestamps");
#endif

            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_RCVBUF, MAX_PACKET_LENGTH * 256, "recv buffer");
            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_SNDBUF, MAX_PACKET_LENGTH * 256, "send buffer");

            EnableErrorQueue(Addr_->Addr()->sa_family, Socket_);

            if (bind(Socket_, Addr_->Addr(), Addr_->Len()) < 0) {
                ythrow TSystemError() << TStringBuf("bind to ") << NAddr::PrintHost(*Addr_) << TStringBuf(" failed");
            }
        }

        virtual ~TSocketAdapter() = default;

        inline const NAddr::IRemoteAddrRef& Addr() const noexcept {
            return Addr_;
        }

        void Close() noexcept {
            try {
                Socket_.Close();
            } catch(...) {
            }
        }

        virtual void FillPacket(TIcmpPacket& packet) = 0;
        virtual bool ValidatePacket(TIcmpPacket& packet) = 0;

        inline operator SOCKET() const noexcept {
            return Socket_;
        }

    protected:
        NAddr::IRemoteAddrRef Addr_;
        TSocketHolder Socket_;
        const ui16 Ident_;
    };

    /// IPv4 echo adapter over a raw IPPROTO_ICMP socket.
    ///
    /// Outgoing requests carry echo id = Ident_ + probe id (probe id in the MAX_IDENT bits),
    /// stored in host byte order. Received datagrams are parsed starting at the IPv4 header.
    class TSocketAdapter4 final : public TSocketAdapter {
    public:
        /// Opens and binds the socket (see TSocketAdapter), attaches an echo-reply BPF
        /// filter on Linux and calls DoNotFragment() for AF_INET.
        inline explicit TSocketAdapter4(const NAddr::IRemoteAddrRef& addr)
            : TSocketAdapter(addr, AF_INET, SOCK_RAW, IPPROTO_ICMP)
        {
#if defined(_linux_)
            // Use following command to generate BPF assembler:
            // tcpdump -i tmp-tun "ip and icmp and icmp[icmptype] == icmp-echoreply and (icmp[4:2] & 0xf0ff) == 0xDDBF" -dd
            // use https://paste.yandex-team.ru/198309 to create interface
            //
            // BPF should process packet containing IP and ICMP headers. There is no need to allow icmp errors in BPF
            // filters because kernel maintains separate skb queue for error messages
            //
            // Byte order of the echo id: FillPacket() stores it in host order, while BPF `ldh`
            // interprets the two bytes in network order. Both the mask and the expected value
            // therefore go through HostToInet(). On little-endian hosts this reproduces the
            // tcpdump constants above (mask 0xf0ff). On big-endian hosts a hard-coded 0xf0ff
            // would keep the probe-id nibble and drop ident bits, rejecting valid replies.
            struct sock_filter code[] = {
                { 0x30, 0, 0, 0x00000000 },             // ldb  [0]            version/IHL byte
                { 0x54, 0, 0, 0x000000f0 },             // and  #0xf0
                { 0x15, 0, 11, 0x00000040 },            // jeq  #0x40 (IPv4)   else drop
                { 0x30, 0, 0, 0x00000009 },             // ldb  [9]            IP protocol
                { 0x15, 0, 9, 0x00000001 },             // jeq  #1 (ICMP)      else drop
                { 0x28, 0, 0, 0x00000006 },             // ldh  [6]            flags/fragment offset
                { 0x45, 7, 0, 0x00001fff },             // jset #0x1fff        non-first fragment: drop
                { 0xb1, 0, 0, 0x00000000 },             // ldxb 4*([0]&0xf)    X = IP header length
                { 0x50, 0, 0, 0x00000000 },             // ldb  [x+0]          ICMP type
                { 0x15, 0, 4, 0x00000000 },             // jeq  #0 (echo reply) else drop
                { 0x48, 0, 0, 0x00000004 },             // ldh  [x+4]          ICMP echo id
                { 0x54, 0, 0, HostToInet(IDENT_MASK) }, // and  ident mask     (strip probe id)
                { 0x15, 0, 1, HostToInet(Ident_) },     // jeq  our ident      else drop
                { 0x6, 0, 0, 0x00040000 },              // ret  #0x40000       accept
                { 0x6, 0, 0, 0x00000000 }               // ret  #0             drop
            };
            struct sock_fprog bpf = {
                .len = sizeof(code) / sizeof(code[0]),
                .filter = code,
            };
            // Attach created code to our socket
            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_ATTACH_FILTER, bpf, "attach BPF");
#endif

            DoNotFragment(AF_INET, Socket_);
        }

        /// Builds an ICMP echo request in `packet`: header fields, then SaveStats() and
        /// FillData(), then the checksum computed by TPacket::HdrBodyChecksum over
        /// packet.Size() bytes starting at the ICMP header.
        void FillPacket(TIcmpPacket& packet) noexcept override {
            TIcmpHeaders::ICMP& pkt = *reinterpret_cast<TIcmpHeaders::ICMP*>(packet.Data());

            TIcmpHeaders::icmp_type(pkt) = ICMP_ECHO;
            TIcmpHeaders::icmp_code(pkt) = 0;
            // Must be zero while the checksum is being computed below.
            TIcmpHeaders::icmp_checksum(pkt) = 0;
            // Host byte order on purpose: the BPF filter and ValidatePacket() both expect it.
            TIcmpHeaders::echo_id(pkt) = Ident_ + (packet.Stats().ProbeId & MAX_IDENT);
            TIcmpHeaders::echo_seq(pkt) = HostToInet(packet.Stats().Seqno);

            packet.SaveStats();
            packet.FillData();

            TIcmpHeaders::icmp_checksum(pkt) = TPacket::HdrBodyChecksum(nullptr, 0, reinterpret_cast<ui16*>(packet.Data()), packet.Size(), 0);
        }

        /// Accepts `packet` only if it is a well-formed IPv4 datagram carrying ICMP whose
        /// echo id matches Ident_ (and, unless it came from the error queue, whose type is
        /// echo reply). On success restores stats and fills ProbeId/Seqno.
        bool ValidatePacket(TIcmpPacket& packet) noexcept override {
            if (packet.Size() < sizeof(TIcmpHeaders::IP)) {
                return false;
            }

            const TIcmpHeaders::IP& pkt = *reinterpret_cast<const TIcmpHeaders::IP*>(packet.Data());

            // FIXME: linux kernel gives me reassembled packet, so packet size may be
            // over 9000. Hey, jumbograms! So, is ip->tot_len length of the last/single
            // fragment or of the full packet in case of fragmentation?
            //
            // NOTE(review): this parser always expects a leading IPv4 header, including for
            // FromErrorQueue() packets. Confirm that TIOHandler delivers error-queue data in
            // that shape; the kernel may queue only the original ICMP message there.

            const std::size_t ipHeaderLength = TIcmpHeaders::ip_hdrlen(pkt) * 4;
            // May wrap around if packet.Size() < ipHeaderLength; that case is rejected by the
            // short-circuiting condition below before ipPayloadLength is ever compared.
            const std::size_t ipPayloadLength = packet.Size() - ipHeaderLength;
            if (
                TIcmpHeaders::ip_version(pkt) != 4 ||
#if defined(_freebsd_) || defined(_darwin_)
                // BSD raw sockets report ip_len without the header and in host order.
                TIcmpHeaders::ip_pktlen(pkt) + ipHeaderLength != packet.Size() ||
#else
                InetToHost(TIcmpHeaders::ip_pktlen(pkt)) != packet.Size() ||
#endif
                packet.Size() < ipHeaderLength || ipHeaderLength < sizeof(TIcmpHeaders::IP) ||
                TIcmpHeaders::ip_proto(pkt) != IPPROTO_ICMP ||
                ipPayloadLength < sizeof(TIcmpHeaders::ICMP)
            ) {
                return false;
            }

            const auto* body = reinterpret_cast<const char*>(packet.Data());
            const auto& icmp = *reinterpret_cast<const TIcmpHeaders::ICMP*>(body + ipHeaderLength);

            if (!packet.FromErrorQueue() && TIcmpHeaders::icmp_type(icmp) != ICMP_ECHOREPLY) {
                return false;
            }
            if ((TIcmpHeaders::echo_id(icmp) & IDENT_MASK) != Ident_) {
                return false;
            }

            packet.LoadStats();
            packet.Stats().ProbeId = TIcmpHeaders::echo_id(icmp) & MAX_IDENT;
            packet.Stats().Seqno = InetToHost(TIcmpHeaders::echo_seq(icmp));
            return true;
        }
    };

    /// IPv6 echo adapter over a raw IPPROTO_ICMPV6 socket.
    ///
    /// Outgoing requests carry icmp6_id = Ident_ + probe id (probe id in the MAX_IDENT bits),
    /// stored in host byte order. Received data starts at the ICMPv6 header (no IPv6 header).
    class TSocketAdapter6 final : public TSocketAdapter {
    public:
        /// Opens and binds the socket (see TSocketAdapter), installs an echo-reply filter
        /// (BPF on Linux, ICMP6_FILTER on FreeBSD/Darwin), enables kernel checksumming,
        /// calls DoNotFragment() and enables IPV6_RECVTCLASS.
        inline explicit TSocketAdapter6(const NAddr::IRemoteAddrRef& addr)
            : TSocketAdapter(addr, AF_INET6, SOCK_RAW, IPPROTO_ICMPV6)
        {
#if defined(_linux_)
            // Use following command to generate BPF assembler:
            // tcpdump -i tmp-tun "ether[0] == 129 and (ether[4:2] & 0xf0ff) == 0xDDBF" -dd
            //
            // Offsets:
            //   packet[0] - icmp type, 129 in case of icmpv6 reply
            //   packet[4:2] - icmp ident
            //
            // Byte order of the echo id: FillPacket() stores it in host order, while BPF `ldh`
            // interprets the two bytes in network order. Both the mask and the expected value
            // therefore go through HostToInet(). On little-endian hosts this reproduces the
            // tcpdump constants above (mask 0xf0ff). On big-endian hosts a hard-coded 0xf0ff
            // would keep the probe-id nibble and drop ident bits, rejecting valid replies.
            struct sock_filter code[] = {
                { 0x30, 0, 0, 0x00000000 },             // ldb  [0]             ICMPv6 type
                { 0x15, 0, 4, 0x00000081 },             // jeq  #129 (echo reply) else drop
                { 0x28, 0, 0, 0x00000004 },             // ldh  [4]             echo id
                { 0x54, 0, 0, HostToInet(IDENT_MASK) }, // and  ident mask      (strip probe id)
                { 0x15, 0, 1, HostToInet(Ident_) },     // jeq  our ident       else drop
                { 0x6, 0, 0, 0x00040000 },              // ret  #0x40000        accept
                { 0x6, 0, 0, 0x00000000 }               // ret  #0              drop
            };
            struct sock_fprog bpf = {
                .len = sizeof(code) / sizeof(code[0]),
                .filter = code,
            };
            // Attach created code to our socket
            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_ATTACH_FILTER, bpf, "attach BPF");
#elif defined(_freebsd_) || defined(_darwin_)
            icmp6_filter filter;
            ICMP6_FILTER_SETBLOCKALL(&filter);
            ICMP6_FILTER_SETPASS(ICMP6_ECHO_REPLY, &filter);
            CheckedSetSockOpt(Socket_, IPPROTO_ICMPV6, ICMP6_FILTER, filter, "attach echo filter");
#else
#   error NO ICMPv6 filtering support, any sane OS does.
#endif

            // FillPacket() leaves icmp6_cksum zero; ask the kernel to fill it at this offset.
            const int cksumOffset = offsetof(icmp6_hdr, icmp6_cksum);
#if defined SOL_RAW && defined IPV6_CHECKSUM // @Linux
            CheckedSetSockOpt(Socket_, SOL_RAW, IPV6_CHECKSUM, cksumOffset, "enable checksum");
#elif defined IPPROTO_IPV6 && defined IPV6_CHECKSUM
            CheckedSetSockOpt(Socket_, IPPROTO_IPV6, IPV6_CHECKSUM, cksumOffset, "enable checksum");
#else
#   error NO IPv6 checksum offset support
#endif

            DoNotFragment(AF_INET6, Socket_);
            // TSocketAdapter::Make() only builds this class for AF_INET6 addresses; the check
            // is kept as a guard.
            if (Addr_->Addr()->sa_family == AF_INET6) {
                CheckedSetSockOpt(Socket_, IPPROTO_IPV6, IPV6_RECVTCLASS, 1, "ipv6 recv tos");
            }
        }

        /// Builds an ICMPv6 echo request in `packet`: header fields, then SaveStats() and
        /// FillData(). The checksum is left to the kernel (see IPV6_CHECKSUM above).
        void FillPacket(TIcmpPacket& packet) noexcept override {
            icmp6_hdr* const icmp = reinterpret_cast<icmp6_hdr*>(packet.Data());

            icmp->icmp6_type = ICMP6_ECHO_REQUEST;
            icmp->icmp6_code = 0;
            icmp->icmp6_cksum = 0; // will be calculated by the mighty kernel
            // Host byte order on purpose: the BPF filter compares against HostToInet(Ident_)
            // and ValidatePacket() masks icmp6_id without conversion.
            icmp->icmp6_id = Ident_ + (packet.Stats().ProbeId & MAX_IDENT);
            icmp->icmp6_seq = HostToInet(packet.Stats().Seqno);

            packet.SaveStats();
            packet.FillData();
        }

        /// Accepts `packet` only if it holds an ICMPv6 header whose id matches Ident_ (and,
        /// unless it came from the error queue, whose type is echo reply). On success
        /// restores stats and fills ProbeId/Seqno.
        bool ValidatePacket(TIcmpPacket& packet) noexcept override {
            // IPv4 and IPv6 difference: IPv4 replies come with IP header and IPv6 without it.
            if (packet.Size() < sizeof(icmp6_hdr)) {
                return false;
            }

            const icmp6_hdr* icmp = reinterpret_cast<const icmp6_hdr*>(packet.Data());
            if (!packet.FromErrorQueue() && icmp->icmp6_type != ICMP6_ECHO_REPLY) {
                return false;
            }
            if ((icmp->icmp6_id & IDENT_MASK) != Ident_) {
                return false;
            }

            packet.LoadStats();
            packet.Stats().ProbeId = icmp->icmp6_id & MAX_IDENT;
            packet.Stats().Seqno = InetToHost(icmp->icmp6_seq);
            return true;
        }
    };

    /// Factory: picks the adapter implementation from the family of the local address.
    /// Throws yexception for families other than AF_INET / AF_INET6; adapter constructors
    /// may throw as well (socket creation, options, bind).
    THolder<TSocketAdapter> TSocketAdapter::Make(const NAddr::IRemoteAddrRef& addr) {
        const auto family = addr->Addr()->sa_family;
        if (family == AF_INET) {
            return MakeHolder<TSocketAdapter4>(addr);
        }
        if (family == AF_INET6) {
            return MakeHolder<TSocketAdapter6>(addr);
        }
        ythrow yexception() << TStringBuf("unknown family given");
    }

    /// Private implementation of TIcmpIOHandler.
    ///
    /// Pairs a family-specific TSocketAdapter (packet encoding and reply validation) with a
    /// TIOHandler that performs reads, writes and polling on the adapter's socket.
    class TIcmpIOHandler::TImpl : public TNonCopyable {
    public:
        /// Propagates any exception thrown by TSocketAdapter::Make (unsupported family,
        /// socket setup or bind failures).
        TImpl(TLog& logger, const NAddr::IRemoteAddrRef& addr)
            : Logger_(logger)
            , Adapter_(TSocketAdapter::Make(addr))
            , Handler_(Logger_, *Adapter_)
        {
        }

        ~TImpl() {
            // Close the socket explicitly; members (Handler_ first) are destroyed afterwards.
            if (!Adapter_) {
                return;
            }
            Adapter_->Close();
        }

        inline const NAddr::IRemoteAddrRef& Addr() const noexcept {
            return Adapter_->Addr();
        }

        /// Reads one packet through the I/O handler. A non-empty packet the adapter rejects
        /// is reported as Nothing(); for processed reads SourceReceivedTime is set from the
        /// handler's receive time.
        inline TMaybeIOStatus Read(TIcmpPacket& packet) noexcept {
            auto status(Handler_.Read(packet));

            const bool rejected = packet.Size() && !Adapter_->ValidatePacket(packet);
            if (rejected) {
                return Nothing();
            }

            const bool processed = status.Defined() && status->Processed();
            if (processed) {
                packet.Stats().SourceReceivedTime = Handler_.RcvdTime();
            }
            return status;
        }

        /// Encodes `packet` via the adapter and sends it. The socket TTL is only updated
        /// when the packet requests a value different from the last one applied.
        inline TMaybeIOStatus Write(TIcmpPacket& packet) noexcept {
            const auto requestedTtl(packet.TimeToLive());
            if (requestedTtl != LastTimeToLive) {
                SetTimeToLive(Addr()->Addr()->sa_family, *Adapter_, requestedTtl);
                LastTimeToLive = requestedTtl;
            }

            Adapter_->FillPacket(packet);
            return Handler_.Write(packet);
        }

        inline void PollD(TCont* cont, ui16 opFilter, const TInstant& deadline) {
            Handler_.PollD(cont, opFilter, deadline);
        }

        inline void PollT(TCont* cont, ui16 opFilter, const TDuration& duration) {
            Handler_.PollT(cont, opFilter, duration);
        }

    private:
        TLog& Logger_;
        // Declared before Handler_, which is constructed from *Adapter_.
        THolder<TSocketAdapter> Adapter_;
        TIOHandler Handler_;
        // Last TTL pushed to the socket by Write(); starts at -1.
        i32 LastTimeToLive = -1;
    };

    TIcmpIOHandler::TIcmpIOHandler(TLog& logger, const NAddr::IRemoteAddrRef& addr)
        : Impl(MakeHolder<TImpl>(logger, addr))
    {
    }

    TIcmpIOHandler::~TIcmpIOHandler() {
    }

    const NAddr::IRemoteAddrRef& TIcmpIOHandler::Addr() const noexcept {
        Y_VERIFY(Impl);
        return Impl->Addr();
    }

    TMaybeIOStatus TIcmpIOHandler::Read(TIcmpPacket& packet) noexcept {
        Y_VERIFY(Impl);
        return Impl->Read(packet);
    }

    TMaybeIOStatus TIcmpIOHandler::Write(TIcmpPacket& packet) noexcept {
        Y_VERIFY(Impl);
        return Impl->Write(packet);
    }

    void TIcmpIOHandler::PollD(TCont* cont, ui16 opFilter, const TInstant& deadline) {
        Y_VERIFY(Impl);
        Impl->PollD(cont, opFilter, deadline);
    }

    void TIcmpIOHandler::PollT(TCont* cont, ui16 opFilter, const TDuration& duration) {
        Y_VERIFY(Impl);
        Impl->PollT(cont, opFilter, duration);
    }

    ui16 TIcmpIOHandler::MaxProbeId() noexcept {
        return MAX_IDENT;
    }

    void TIcmpIOHandler::Close() noexcept {
        Y_VERIFY(Impl);
        Impl.Destroy();
    }
}
