#include "handler.h"

#include <infra/netmon/agent/common/settings.h>
#include <infra/netmon/agent/common/metrics.h>
#include <util/system/platform.h>
#include <util/random/random.h>

#include <linux/if_packet.h>
#include <linux/if_ether.h>
#include <linux/net_tstamp.h>
#include <linux/errqueue.h>
#include <linux/sockios.h>
#include <net/if.h>
#include <netinet/ip.h>
#include <netinet/ip6.h>
#include <sys/epoll.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <util/system/error.h>

#ifdef _linux_
#   include <linux/filter.h> // sock_filter and sock_fprog
#endif

#ifndef SOF_TIMESTAMPING_TX_SCHED
#define SOF_TIMESTAMPING_TX_SCHED (1<<8)
#endif

#ifndef SIOCGHWTSTAMP
#define SIOCGHWTSTAMP   0x89b1
#endif

/*
    SOF_TIMESTAMPING_OPT_ID not supported
    on AF_PACKET until very recent kernels
*/

namespace NNetmon {
    /*
        Payload of an SCM_TIMESTAMPING control message as it is consumed in this file:
        the cmsg data is reinterpreted as this struct, ts[0] is read as the
        system (software) timestamp and ts[2] as the hardware one; ts[1] is not
        used here. The layout must therefore stay exactly three timespec values.
    */
    struct scm_timestamping {
        struct timespec ts[3];
    };

    /*
        SO_TIMESTAMPING flag set added on top of the software one when
        hardware timestamps are requested: TX and RX hardware stamps plus
        reporting of the raw hardware clock value.
    */
    const int HardwareTimestamping = SOF_TIMESTAMPING_TX_HARDWARE
                                   | SOF_TIMESTAMPING_RX_HARDWARE
                                   | SOF_TIMESTAMPING_RAW_HARDWARE;

    /*
        Baseline SO_TIMESTAMPING flag set that is always requested: software RX
        stamps, the TX_SCHED transmit stamp and reporting of software stamps.
    */
    const int SoftwareTimestamping = SOF_TIMESTAMPING_TX_SCHED
                                   | SOF_TIMESTAMPING_SOFTWARE
                                   | SOF_TIMESTAMPING_RX_SOFTWARE;

    /* Tells whether the timespec carries any value, i.e. either of its fields is non-zero. */
    static inline bool TimespecNonZero(const struct timespec& ts) {
        if (ts.tv_sec != 0) {
            return true;
        }
        return ts.tv_nsec != 0;
    }

    /*
        Converts a timespec to whole microseconds; the sub-microsecond part of
        tv_nsec is truncated.

        Both fields are widened to ui64 before any arithmetic so the product
        cannot wrap on targets where unsigned long is only 32 bits wide.
    */
    static inline ui64 TimespecToUs(const struct timespec& ts) {
        const ui64 wholeSeconds = static_cast<ui64>(ts.tv_sec);
        const ui64 nanoseconds = static_cast<ui64>(ts.tv_nsec);
        return wholeSeconds * 1000000ULL + nanoseconds / 1000ULL;
    }

    /*
        Raw AF_PACKET socket bound to one link. It sends probe frames and
        collects their software and, when available, hardware TX/RX timestamps.

        The object owns two descriptors: the packet socket (TSocketHolder) and
        a private epoll instance watching it, used by PollDWithoutYield().
        Close() releases both, is idempotent, and is also run by the destructor.
    */
    class TLinkSocketHandler: public TNonCopyable {
    public:
        /*
            Factory: builds a handler for the given link-layer address.
            Throws yexception when the address family is not AF_PACKET;
            constructor failures are propagated unchanged.
        */
        static inline THolder<TLinkSocketHandler> Make(const NAddr::IRemoteAddrRef& addr,
                                                       const NAddr::TIPv6Addr& srcIp) {
            switch (addr->Addr()->sa_family) {
                case AF_PACKET:
                    return MakeHolder<TLinkSocketHandler>(addr, srcIp);
                default:
                    ythrow yexception() << TStringBuf("unknown family given for link poller");
            }
        }

        /*
            Opens and fully configures the socket:
              1. creates a non-blocking raw packet socket and sizes its buffers;
              2. probes the interface hardware timestamping config and, if
                 requested and supported, turns it on (otherwise falls back to
                 software timestamps and clears HwTimestamps_);
              3. enables SO_TIMESTAMPING, binds to addr, attaches the BPF filter
                 built from srcIp;
              4. registers the socket in a private epoll instance.
            Throws TSystemError on any failing system call.
        */
        inline TLinkSocketHandler(const NAddr::IRemoteAddrRef& addr,
                                  const NAddr::TIPv6Addr& srcIp)
            : Addr_(addr)
            , Socket_(::socket(AF_PACKET, SOCK_RAW, htons(LL_SOCK_PROTO)))
            , EPollFd_(-1)
            , HwTimestamps_(TSettings::Get()->GetUseHwTimestamps())
        {
            if (Socket_ == INVALID_SOCKET) {
                ythrow TSystemError() << TStringBuf("can't create socket for pinger");
            }

            ::memset(&EPollEvent_, 0, sizeof(EPollEvent_));

            SetNonBlock(Socket_);
            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_RCVBUF, MAX_PACKET_LENGTH * 256, "recv buffer");
            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_SNDBUF, MAX_PACKET_LENGTH * 256, "send buffer");

            struct hwtstamp_config hwt;
            struct ifreq ifr;

            ::memset(&hwt, 0, sizeof(hwt));
            ::memset(&ifr, 0, sizeof(ifr));

            /* Resolve the interface name from the index stored in the link address */
            const struct sockaddr_ll* sll = reinterpret_cast<const struct sockaddr_ll*>(Addr_->Addr());
            ifr.ifr_ifindex = sll->sll_ifindex;

            if (ioctl(Socket_, SIOCGIFNAME, &ifr)) {
                ythrow TSystemError(LastSystemError()) << TStringBuf("cannot get ifname");
            }

            int tsFlags = SoftwareTimestamping;
            if (HwTimestamps_) {
                tsFlags |= HardwareTimestamping;
            }

            /* Drops every hardware flag and remembers that only software stamps are used */
            const auto fallbackToSoftware = [&]() {
                tsFlags = SoftwareTimestamping;
                HwTimestamps_ = false;
            };

            /* Pushes a new hw timestamping config to the driver, true on success */
            const auto applyHwConfig = [&](int txType, int rxFilter) -> bool {
                hwt.flags = 0;
                hwt.tx_type = txType;
                hwt.rx_filter = rxFilter;
                ifr.ifr_data = reinterpret_cast<char*>(&hwt);
                return ioctl(Socket_, SIOCSHWTSTAMP, &ifr) == 0;
            };

            ifr.ifr_data = reinterpret_cast<char*>(&hwt);

            /* Probe current hw tx/rx timestamp setting */
            if (!ioctl(Socket_, SIOCGHWTSTAMP, &ifr)) {
                const bool alreadyOff = hwt.tx_type == HWTSTAMP_TX_OFF &&
                                        hwt.rx_filter == HWTSTAMP_FILTER_NONE &&
                                        hwt.flags == 0;

                /* Try to reset hw tx/rx clocks if available and not matching */
                if (!alreadyOff && !applyHwConfig(HWTSTAMP_TX_OFF, HWTSTAMP_FILTER_NONE)) {
                    /* Fallback to pure software */
                    fallbackToSoftware();
                }

                /*
                   Works for 10G Intel 82599 and 1G Intel
                   Enable with care: dropped while rx ring overflow packets
                   cause rx timestamping hang with watchdog reset, see:
                   drivers/net/ethernet/intel/ixgbe/ixgbe_ptp.c:424 for 4.4 kernel
                */
                if (tsFlags & SOF_TIMESTAMPING_RAW_HARDWARE) {
                    if (!applyHwConfig(HWTSTAMP_TX_ON, HWTSTAMP_FILTER_PTP_V2_L4_EVENT)) {
                        fallbackToSoftware();
                    }
                }
            } else {
                /* No hw timestamping? Nothing to enable */
                fallbackToSoftware();
            }

#ifdef SO_TIMESTAMPING // only on Linux
            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_TIMESTAMPING,
                              tsFlags, "timestamping opts");
#endif

            if (bind(Socket_, Addr_->Addr(), Addr_->Len()) < 0) {
                ythrow TSystemError() << TStringBuf("bind to packet socket failed");
            }

            AttachSockFilter(srcIp);

            EPollFd_ = epoll_create(1);
            if (EPollFd_ < 0) {
                ythrow TSystemError(LastSystemError()) <<
                    TStringBuf("failed to create epoll for link poller");
            }

            EPollEvent_.data.fd = Socket_;
            EPollEvent_.events = EPOLLIN | EPOLLERR;

            if (epoll_ctl(EPollFd_, EPOLL_CTL_ADD, Socket_, &EPollEvent_) < 0) {
                /*
                    The destructor does not run when the constructor throws, so the
                    epoll descriptor has to be released here. Save errno first:
                    close() may overwrite it.
                */
                const int err = LastSystemError();
                ::close(EPollFd_);
                EPollFd_ = -1;
                ythrow TSystemError(err) << TStringBuf("failed to add packet socket to epoll");
            }
        }

        /* Releases whatever Close() has not released yet. */
        virtual ~TLinkSocketHandler() {
            Close();
        }

        /* Link-layer address this socket is bound to and sends to. */
        inline const NAddr::IRemoteAddrRef& Addr() const noexcept {
            return Addr_;
        }

        /*
            Closes the packet socket and the epoll descriptor.
            Safe to call more than once: the epoll descriptor is invalidated
            after the first call so it is never closed twice.
        */
        void Close() noexcept {
            Socket_.Close();
            if (EPollFd_ >= 0) {
                ::close(EPollFd_);
                EPollFd_ = -1;
            }
        }

        inline operator SOCKET() const noexcept {
            return Socket_;
        }

        /*
            Sends dataLen bytes to the bound link address.
            Returns Success(bytes sent) when sendto() reports a positive count,
            otherwise Error(errno). A failed sendto() returns -1, which must
            not be mistaken for a byte count.
        */
        TMaybeIOStatus Write(void *data, std::size_t dataLen) {
            const ssize_t sent = ::sendto(Socket_, data, dataLen, 0, Addr_->Addr(), Addr_->Len());
            if (sent > 0) {
                return TContIOStatus::Success(static_cast<std::size_t>(sent));
            }

            return TContIOStatus::Error(LastSystemError());
        }

        /*
            Drains the socket error queue looking for TX timestamps of the probe
            with the given id.

            txHwTime / txSysTime are zeroed first and then filled from ts[2] / ts[0]
            of a matching SCM_TIMESTAMPING message when those values are non-zero.

            Returns Success(0) if at least one timestamp was found, Nothing() if
            the queue ran dry (EAGAIN / EWOULDBLOCK) without a match, and
            Error(errno) otherwise.
        */
        TMaybeIOStatus GetTxTs(ui64 probeId, struct timespec& txHwTime, struct timespec& txSysTime) {
            struct msghdr msg;
            struct iovec iov;
            struct sockaddr_ll sll;
            char data[MAX_PACKET_LENGTH];
            /* Control buffer has to be aligned for struct cmsghdr access */
            alignas(struct cmsghdr) char ctrl[256];

            ::memset(&txHwTime, 0, sizeof(txHwTime));
            ::memset(&txSysTime, 0, sizeof(txSysTime));

            iov.iov_base = data;
            iov.iov_len = sizeof(data);

            for (;;) {
                /*
                    recvmsg() overwrites msg_namelen, msg_controllen and msg_flags,
                    so the header is rebuilt before every call; otherwise later
                    iterations would run with a shrunken control buffer.
                */
                ::memset(ctrl, 0, sizeof(ctrl));
                ::memset(&msg, 0, sizeof(msg));
                msg.msg_iov = &iov;
                msg.msg_iovlen = 1;
                msg.msg_name = &sll;
                msg.msg_namelen = sizeof(sll);
                msg.msg_control = ctrl;
                msg.msg_controllen = sizeof(ctrl);

                const ssize_t res = ::recvmsg(Socket_, &msg, MSG_ERRQUEUE);
                if (res <= 0) {
                    break;
                }

                const struct scm_timestamping* tss = FindTimestamps(msg);
                ui64 seenProbeId = 0lu;

                if (tss && TLinkPacket::ExtractProbeIdFromRaw(data, res, seenProbeId)) {
                    if (probeId == seenProbeId) {
                        if (TimespecNonZero(tss->ts[2])) {
                            txHwTime = tss->ts[2];
                        }
                        if (TimespecNonZero(tss->ts[0])) {
                            txSysTime = tss->ts[0];
                        }
                    }
                }

                /* Note that we may catch many packets in case of TX delays over 100 ms */
            }

            /* Taken right after the failing recvmsg() so nothing can clobber it */
            const int err = LastSystemError();

            if (TimespecNonZero(txHwTime) || TimespecNonZero(txSysTime)) {
                return TContIOStatus::Success(0);
            }

            if (err == EAGAIN || err == EWOULDBLOCK) {
                return Nothing();
            }

            return TContIOStatus::Error(err);
        }

        /*
            Receives one packet into data (at most dataLen bytes).

            rxHwTime / rxSysTime are zeroed first and then set from ts[2] / ts[0]
            of the SCM_TIMESTAMPING message, if one came with the packet.

            Returns Success(bytes received) for a non-empty packet, Nothing() on
            EAGAIN / EWOULDBLOCK and Error(errno) otherwise.
        */
        TMaybeIOStatus Read(void* data, std::size_t dataLen, struct timespec& rxHwTime, struct timespec& rxSysTime) {
            struct msghdr msg;
            struct iovec iov;
            struct sockaddr_ll sll;
            /* Control buffer has to be aligned for struct cmsghdr access */
            alignas(struct cmsghdr) char ctrl[256];

            ::memset(&rxHwTime, 0, sizeof(rxHwTime));
            ::memset(&rxSysTime, 0, sizeof(rxSysTime));
            ::memset(ctrl, 0, sizeof(ctrl));
            ::memset(&msg, 0, sizeof(msg));

            iov.iov_base = data;
            iov.iov_len = dataLen;
            msg.msg_iov = &iov;
            msg.msg_iovlen = 1;
            msg.msg_name = &sll;
            msg.msg_namelen = sizeof(sll);
            msg.msg_control = ctrl;
            msg.msg_controllen = sizeof(ctrl);

            const ssize_t res = ::recvmsg(Socket_, &msg, 0);
            if (res > 0) {
                if (const struct scm_timestamping* tss = FindTimestamps(msg)) {
                    rxHwTime = tss->ts[2];
                    rxSysTime = tss->ts[0];
                }

                return TContIOStatus::Success(static_cast<std::size_t>(res));
            }

            const int err = LastSystemError();

            if (err == EAGAIN || err == EWOULDBLOCK) {
                return Nothing();
            }

            return TContIOStatus::Error(err);
        }

        /*
            Waits on the socket through NCoro::PollD().
            Results 0, EIO, ETIMEDOUT and ECANCELED are tolerated; any other
            result is thrown as TSystemError.
        */
        void PollD(TCont* cont, ui16 opFilter, const TInstant& deadline) {
            auto res = NCoro::PollD(cont, Socket_, opFilter, deadline);
            if (res != 0 && res != EIO && res != ETIMEDOUT && res != ECANCELED) {
                ythrow TSystemError(res) << TStringBuf("io error");
            }
        }

        /*
            Blocks in epoll_wait() on the private epoll descriptor for at most
            min(timeout, time left until deadline). Returns immediately when the
            deadline has already passed. Throws TSystemError if epoll_wait() fails.
        */
        void PollDWithoutYield(const TDuration& timeout, const TInstant& deadline) {
            auto now = TInstant::Now();
            if (deadline <= now) {
                return;
            }

            TDuration timeo = Min(deadline - now, timeout);
            struct epoll_event ev;

            if (epoll_wait(EPollFd_, &ev, 1, timeo.MilliSeconds()) < 0) {
                ythrow TSystemError(LastSystemError()) << TStringBuf("epoll wait error");
            }
        }

        /* True when hardware timestamping was requested and successfully enabled. */
        bool HwTimestamps() const noexcept {
            return HwTimestamps_;
        }

    private:
        /*
            Scans the control messages of a received msghdr and returns the payload
            of the last SOL_SOCKET / SCM_TIMESTAMPING one that is large enough,
            or nullptr when there is none. The result points into the caller's
            control buffer.
        */
        static const struct scm_timestamping* FindTimestamps(struct msghdr& msg) noexcept {
            const struct scm_timestamping* found = nullptr;

            for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
                 cmsg;
                 cmsg = CMSG_NXTHDR(&msg, cmsg)) {
                if (cmsg->cmsg_level == SOL_SOCKET &&
                    cmsg->cmsg_type == SCM_TIMESTAMPING &&
                    cmsg->cmsg_len >= CMSG_LEN(sizeof(struct scm_timestamping))) {
                    found = reinterpret_cast<const struct scm_timestamping*>(CMSG_DATA(cmsg));
                }
            }

            return found;
        }

        /*
            Attaches a classic BPF program to the socket. It accepts only frames
            that pass every comparison below (ethertype, the two 16-bit header
            words and the four 32-bit words taken from srcIp) and drops all others.
        */
        void AttachSockFilter(const NAddr::TIPv6Addr& srcIp) {
#if defined(_linux_)
            /* Command to generate:
               $ sudo tcpdump -i eth0 'ip6[2:2] == 0x7255 and ip6[6:2] == 0x1101 and \
                                       ip6[8:4] == 0xaaaaaaaa and \
                                       ip6[12:4] == 0xbbbbbbbb and \
                                       ip6[16:4] == 0xcccccccc and \
                                       ip6[20:4] == 0xdddddddd' -dd */

            const struct sockaddr_in6* sin6 = reinterpret_cast<const struct sockaddr_in6*>(srcIp.Addr());
            const ui32* addr = sin6->sin6_addr.s6_addr32;

            /* Every failed comparison jumps to the final "return 0" (drop) instruction */
            struct sock_filter code[] = {
                { 0x28, 0, 0, 0x0000000c },
                { 0x15, 0, 13, 0x000086dd }, /* ETH_P_IPV6 */
                { 0x28, 0, 0, 0x00000010 },
                { 0x15, 0, 11, 0x00007255 }, /* NETMON_LINK_MAGIC */
                { 0x28, 0, 0, 0x00000014 },
                { 0x15, 0, 9, 0x00001101 }, /* IPPROTO_UDP and TTL == 1 */
                { 0x20, 0, 0, 0x00000016 },
                { 0x15, 0, 7, ntohl(addr[0]) },
                { 0x20, 0, 0, 0x0000001a },
                { 0x15, 0, 5, ntohl(addr[1]) },
                { 0x20, 0, 0, 0x0000001e },
                { 0x15, 0, 3, ntohl(addr[2]) },
                { 0x20, 0, 0, 0x00000022 },
                { 0x15, 0, 1, ntohl(addr[3]) },
                { 0x6, 0, 0, 0x00040000 }, /* accept */
                { 0x6, 0, 0, 0x00000000 }, /* drop */
            };

            struct sock_fprog bpf = {
                .len = sizeof(code) / sizeof(code[0]),
                .filter = code,
            };

            CheckedSetSockOpt(Socket_, SOL_SOCKET, SO_ATTACH_FILTER, bpf, "attach BPF");
#else
#   error NO BPF filtering support, any sane OS does.
#endif
        }

    protected:
        NAddr::IRemoteAddrRef Addr_;
        TSocketHolder Socket_;
        int EPollFd_; /* private epoll instance, -1 when not open */
        struct epoll_event EPollEvent_;
        bool HwTimestamps_;
    };

    /*
        Private implementation of TLinkIOHandler: owns the TLinkSocketHandler
        and translates between TLinkPacket and the raw socket calls, copying
        kernel / hardware timestamps into the packet.
    */
    class TLinkIOHandler::TImpl : public TNonCopyable {
    public:
        /*
            Opens the link socket for addr / srcIp. Logs a warning when hardware
            timestamps were requested in settings but could not be enabled.
            Exceptions of TLinkSocketHandler::Make() are propagated.
        */
        TImpl(TLog& logger,
              const NAddr::IRemoteAddrRef& addr,
              const NAddr::TIPv6Addr& srcIp)
            : Logger_(logger)
            , LinkSocketHandler_(TLinkSocketHandler::Make(addr, srcIp))
            , Handler_(Logger_, *LinkSocketHandler_)
        {
            if (TSettings::Get()->GetUseHwTimestamps() != LinkSocketHandler_->HwTimestamps()) {
                logger << TLOG_WARNING << "Cannot enable hardware timestamping, falling back to software" << Endl;
            }
        }

        ~TImpl() {
            if (LinkSocketHandler_) {
                LinkSocketHandler_->Close();
            }
        }

        /*
            Replaces the current socket with a freshly opened one.

            The new handler is created first and the old one is closed only after
            that succeeded. If Make() throws, the previous handler is left intact,
            so the holder never keeps an already closed handler that a later
            Reopen() or the destructor would close a second time.

            NOTE(review): Handler_ was constructed with a reference to the handler
            that existed at construction time - confirm TIOHandler does not keep
            using it after a reopen.
        */
        void Reopen(const NAddr::IRemoteAddrRef& addr, const NAddr::TIPv6Addr& srcIp) {
            THolder<TLinkSocketHandler> replaced = TLinkSocketHandler::Make(addr, srcIp);

            /* After the swap `replaced` holds the previous handler (possibly none) */
            LinkSocketHandler_.Swap(replaced);

            if (replaced) {
                replaced->Close();
            }
        }

        /* Address of the currently open link socket. */
        inline const NAddr::IRemoteAddrRef& Addr() const noexcept {
            return LinkSocketHandler_->Addr();
        }

        /*
            Reads one packet into `packet`.
            On a successful non-empty read the packet size and the user-space
            receive time are set, and the hardware / system receive times are set
            when the socket reported non-zero values. In every other case the
            packet size is reset to 0. The socket status is returned unchanged.
        */
        inline TMaybeIOStatus Read(TLinkPacket& packet) noexcept {
            struct timespec rxHwTs;
            struct timespec rxSysTs;

            auto res = LinkSocketHandler_->Read(packet.Data(), packet.Capacity(), rxHwTs, rxSysTs);
            if (res.Defined() && res->Processed() && !res->Status()) {
                packet.Size() = res->Processed();
                packet.UserTimestamps().RecvTime = TInstant::Now().MicroSeconds();
                if (TimespecNonZero(rxHwTs)) {
                    packet.HwTimestamps().RecvTime = TimespecToUs(rxHwTs);
                }
                if (TimespecNonZero(rxSysTs)) {
                    packet.SysTimestamps().RecvTime = TimespecToUs(rxSysTs);
                }
            } else {
                packet.Size() = 0;
            }

            return res;
        }

        /*
            Fetches TX timestamps of the probe with the given id and stores the
            non-zero ones into the packet's hardware / system SentTime.
            The socket status is returned unchanged.
        */
        inline TMaybeIOStatus GetTxTs(ui64 probeId, TLinkPacket& packet) {
            struct timespec txHwTs;
            struct timespec txSysTs;

            auto res = LinkSocketHandler_->GetTxTs(probeId, txHwTs, txSysTs);
            if (res.Defined() && !res->Status()) {
                if (TimespecNonZero(txHwTs)) {
                    packet.HwTimestamps().SentTime = TimespecToUs(txHwTs);
                }
                if (TimespecNonZero(txSysTs)) {
                    packet.SysTimestamps().SentTime = TimespecToUs(txSysTs);
                }
            }

            return res;
        }

        /* Fills the packet, stamps the user-space send time and sends it. */
        inline TMaybeIOStatus Write(TLinkPacket& packet) noexcept {
            packet.Fill();

            packet.UserTimestamps().SentTime = TInstant::Now().MicroSeconds();
            return LinkSocketHandler_->Write(packet.Data(), packet.Size());
        }

        /* Forwards to the socket handler; no-op when there is no socket. */
        inline void PollD(TCont* cont, ui16 opFilter, const TInstant& deadline) {
            if (LinkSocketHandler_) {
                LinkSocketHandler_->PollD(cont, opFilter, deadline);
            }
        }

        /* Forwards to the socket handler; no-op when there is no socket. */
        inline void PollDWithoutYield(const TDuration& timeout, const TInstant& deadline) {
            if (LinkSocketHandler_) {
                LinkSocketHandler_->PollDWithoutYield(timeout, deadline);
            }
        }

    private:
        TLog& Logger_;
        THolder<TLinkSocketHandler> LinkSocketHandler_;
        TIOHandler Handler_;
    };

    TLinkIOHandler::TLinkIOHandler(TLog& logger,
                                   const NAddr::IRemoteAddrRef& addr,
                                   const NAddr::TIPv6Addr& srcIp)
        : Impl(MakeHolder<TImpl>(logger, addr, srcIp))
    {
    }

    /* Defined in this file, where TImpl is a complete type. */
    TLinkIOHandler::~TLinkIOHandler() = default;

    /* Address of the open link socket. Fails the Y_VERIFY check after Close(). */
    const NAddr::IRemoteAddrRef& TLinkIOHandler::Addr() const noexcept {
        Y_VERIFY(Impl);
        const TImpl& impl = *Impl;
        return impl.Addr();
    }


    /* Re-creates the link socket for addr / srcIp. Fails the Y_VERIFY check after Close(). */
    void TLinkIOHandler::Reopen(const NAddr::IRemoteAddrRef& addr,
                                const NAddr::TIPv6Addr& srcIp) {
        Y_VERIFY(Impl);
        TImpl& impl = *Impl;
        impl.Reopen(addr, srcIp);
    }

    /* Fetches TX timestamps for probeId into packet. Fails the Y_VERIFY check after Close(). */
    TMaybeIOStatus TLinkIOHandler::GetTxTs(ui64 probeId, TLinkPacket& packet) noexcept {
        Y_VERIFY(Impl);
        TImpl& impl = *Impl;
        return impl.GetTxTs(probeId, packet);
    }

    /* Reads one packet from the link socket. Fails the Y_VERIFY check after Close(). */
    TMaybeIOStatus TLinkIOHandler::Read(TLinkPacket& packet) noexcept {
        Y_VERIFY(Impl);
        TImpl& impl = *Impl;
        return impl.Read(packet);
    }

    /* Sends the packet over the link socket. Fails the Y_VERIFY check after Close(). */
    TMaybeIOStatus TLinkIOHandler::Write(TLinkPacket& packet) noexcept {
        Y_VERIFY(Impl);
        TImpl& impl = *Impl;
        return impl.Write(packet);
    }

    /* Coroutine-aware wait on the link socket. Fails the Y_VERIFY check after Close(). */
    void TLinkIOHandler::PollD(TCont* cont, ui16 opFilter, const TInstant& deadline) {
        Y_VERIFY(Impl);
        TImpl& impl = *Impl;
        impl.PollD(cont, opFilter, deadline);
    }

    void TLinkIOHandler::PollDWithoutYield(const TDuration& timeout, const TInstant& deadline) {
        Y_VERIFY(Impl);
        Impl->PollDWithoutYield(timeout, deadline);
    }

    /*
        Destroys the implementation object (its destructor closes the socket).
        The handler must still be open: a second Close() fails the Y_VERIFY check.
    */
    void TLinkIOHandler::Close() noexcept {
        Y_VERIFY(Impl);
        Impl.Reset();
    }
}
