#pragma once

#include <infra/netmon/agent/common/metrics.h>
#include <infra/netmon/agent/common/packet.h>
#include <infra/netmon/agent/common/settings.h>

#include <library/cpp/coroutine/engine/impl.h>
#include <library/cpp/coroutine/engine/network.h>
#include <library/cpp/logger/log.h>

#include <util/generic/maybe.h>
#include <util/system/sanitizers.h>

#if defined(_linux_)
#   include <linux/errqueue.h>
#   include <linux/icmp.h>
#   include <netinet/icmp6.h>
#endif

namespace NNetmon {
    // Outcome of a non-blocking socket operation: Nothing() when the call would block
    // (EAGAIN / EWOULDBLOCK), otherwise the TContIOStatus describing the call.
    // NOTE(review): TPolicyUndefinedFail presumably makes access to an empty value fail
    // loudly instead of being undefined — confirm.
    using TMaybeIOStatus = TMaybe<TContIOStatus, NMaybe::TPolicyUndefinedFail>;

    namespace {
        /// Returns the flags to pass to `recvmsg()`.
        /// @param errorQueue  read from the socket error queue instead of the data queue;
        ///                    honoured only on Linux (MSG_ERRQUEUE), yields 0 elsewhere.
        inline int GetRecvFlags(bool errorQueue=false) noexcept {
#if defined(_linux_)
            constexpr int errorQueueFlag = MSG_ERRQUEUE;
#else
            constexpr int errorQueueFlag = 0;
#endif
            return errorQueue ? errorQueueFlag : 0;
        }

        /// True when an error-queue read maps to a non-zero `recvmsg()` flag,
        /// i.e. when the platform lets us read the socket error queue.
        inline bool IsErrorQueueSupported() {
            const int errorQueueFlags = GetRecvFlags(true);
            return errorQueueFlags != 0;
        }

        /// Receive-side `msghdr` for `recvmsg()`.
        ///
        /// The constructor points the payload iovec at `packet.Data()` (up to
        /// `packet.Capacity()` bytes), the source-address buffer at
        /// `packet.ClientAddrData()` and the ancillary data at `ControlBuffer`.
        /// After the read, `ParseControlMessage()` remembers pointers to the control
        /// messages it recognises and the accessors below decode them.
        ///
        /// The remembered pointers point into `ControlBuffer`, so they are only valid
        /// while this object is alive.
        struct TReadMessageHeader : public TNonCopyable {
            inline TReadMessageHeader(TPacket& packet) noexcept
            {
                ::memset(&Message, 0, sizeof(Message));

                Message.msg_iov = &Entry;
                Message.msg_iovlen = 1;

                Entry.iov_base = packet.Data();
                Entry.iov_len = packet.Capacity();

                Message.msg_name = packet.ClientAddrData();
                Message.msg_namelen = packet.ClientAddrCapacity();

                ::memset(ControlBuffer, 0, sizeof(ControlBuffer));

                Message.msg_control = ControlBuffer;
                Message.msg_controllen = sizeof(ControlBuffer);
            }

            /// Walks the control messages filled in by `recvmsg()` and remembers the
            /// SO_TIMESTAMPNS timestamp, the SO_RXQ_OVFL counter, the IP_RECVERR /
            /// IPV6_RECVERR extended error (only when the message came from the error
            /// queue) and the IPV6_TCLASS value. Entries whose length is too short for
            /// their payload are ignored.
            void ParseControlMessage() noexcept {
                struct cmsghdr* cmsg = nullptr;
                for (cmsg = CMSG_FIRSTHDR(&Message); cmsg; cmsg = CMSG_NXTHDR(&Message, cmsg)) {
                    NSan::Unpoison(cmsg, sizeof(cmsghdr));
                    switch (cmsg->cmsg_level) {
                        case SOL_SOCKET: {
                            switch (cmsg->cmsg_type) {
#ifdef SO_TIMESTAMPNS
                                case SO_TIMESTAMPNS: {
                                    if (cmsg->cmsg_len >= CMSG_LEN(sizeof(struct timespec))) {
                                        Timestamp = reinterpret_cast<struct timespec*>(CMSG_DATA(cmsg));
                                    }
                                    break;
                                }
#endif
#ifdef SO_RXQ_OVFL
                                case SO_RXQ_OVFL: {
                                    if (cmsg->cmsg_len >= CMSG_LEN(sizeof(ui32))) {
                                        Overflow = reinterpret_cast<ui32*>(CMSG_DATA(cmsg));
                                    }
                                    break;
                                }
#endif
                            }
                            break;
                        }
#if defined(_linux_)
                        case SOL_IP: {
                            if ((Message.msg_flags & MSG_ERRQUEUE)
                                    && cmsg->cmsg_type == IP_RECVERR
                                    && cmsg->cmsg_len >= CMSG_LEN(sizeof(sock_extended_err))
                            ) {
                                ExtendedError = reinterpret_cast<sock_extended_err*>(CMSG_DATA(cmsg));
                            }
                            break;
                        }
                        case SOL_IPV6: {
                            if ((Message.msg_flags & MSG_ERRQUEUE)
                                    && cmsg->cmsg_type == IPV6_RECVERR
                                    && cmsg->cmsg_len >= CMSG_LEN(sizeof(sock_extended_err))
                            ) {
                                ExtendedError = reinterpret_cast<sock_extended_err*>(CMSG_DATA(cmsg));
                            }

                            if (cmsg->cmsg_type == IPV6_TCLASS &&
                                cmsg->cmsg_len >= CMSG_LEN(sizeof(i32))) {
                                v6TrafficClass = reinterpret_cast<i32*>(CMSG_DATA(cmsg));
                            }
                            break;
                        }
#endif
                    }
                }
            }

            /// True when the datagram did not fit into the payload buffer (MSG_TRUNC).
            inline bool MessageTruncated() const noexcept {
                return Message.msg_flags & MSG_TRUNC;
            }

            /// True when the ancillary data did not fit into ControlBuffer (MSG_CTRUNC).
            inline bool ControlMessageTruncated() const noexcept {
                return Message.msg_flags & MSG_CTRUNC;
            }

            /// Receive time in microseconds: the SO_TIMESTAMPNS timestamp when one was
            /// received, otherwise the current time.
            inline ui64 ReceivedTime() const noexcept {
                if (Timestamp) {
                    // compute in 64 bits: `unsigned long` may be 32-bit on some targets
                    return static_cast<ui64>(Timestamp->tv_sec) * 1000000ULL
                        + static_cast<ui64>(Timestamp->tv_nsec) / 1000ULL;
                } else {
                    return TInstant::Now().MicroSeconds();
                }
            }

            /// Value of the SO_RXQ_OVFL counter, or 0 when it was not received.
            inline ui64 DroppedPackets() const noexcept {
                if (Overflow) {
                    return *Overflow;
                } else {
                    return 0;
                }
            }

            /// Received IPV6_TCLASS value, or 0 when it was not received.
            inline i32 TrafficClass() const noexcept {
                if (v6TrafficClass) {
                    return *v6TrafficClass;
                } else {
                    return 0;
                }
            }

            /// Address of the host that reported an ICMP / ICMPv6 "time exceeded"
            /// error for this error-queue message; Nothing() otherwise (and always
            /// on non-Linux platforms).
            inline TMaybe<NAddr::TOpaqueAddr> Offender() const noexcept {
#if defined(_linux_)
                if (ExtendedError && (
                        (ExtendedError->ee_origin == SO_EE_ORIGIN_ICMP && ExtendedError->ee_type == ICMP_TIME_EXCEEDED)
                        || (ExtendedError->ee_origin == SO_EE_ORIGIN_ICMP6 && ExtendedError->ee_type == ICMP6_TIME_EXCEEDED)
                )) {
                    sockaddr* offender = SO_EE_OFFENDER(ExtendedError);
                    return NAddr::TOpaqueAddr(offender);
                }
#endif
                return Nothing();
            }

            struct msghdr Message;
            struct iovec Entry;
            // CMSG_FIRSTHDR() returns the start of this buffer as a cmsghdr*, so it
            // must be aligned for cmsghdr rather than relying on member layout.
            alignas(struct cmsghdr) char ControlBuffer[256];

            struct timespec* Timestamp = nullptr;
            ui32* Overflow = nullptr;
#if defined(_linux_)
            sock_extended_err* ExtendedError = nullptr;
#endif
            i32* v6TrafficClass = nullptr;
        };

        /// Performs one `recvmsg()` into `packet` during construction and records the
        /// outcome.
        ///
        /// Afterwards `packet` has `FromErrorQueue()` set to `errorQueue`, `Size()` set
        /// on success, `Truncated()` set when the datagram did not fit, and `Offender()` /
        /// `TypeOfService()` refreshed from the control data unless that was truncated.
        ///
        /// Member declaration order matters: `Header_` and `Started_` are initialised
        /// before `ReadLength_`, whose initialiser issues the system call.
        class TReadOperation : public TNonCopyable {
        public:
            inline TReadOperation(SOCKET sock, TPacket& packet, bool errorQueue=false) noexcept
                : Header_(packet)
                , Started_(TInstant::Now().MicroSeconds())
                , ReadLength_(::recvmsg(sock, &Header_.Message, GetRecvFlags(errorQueue)))
                // errno is only meaningful after a failed call (-1). A zero-length
                // datagram returns 0 and must not pick up a stale errno (e.g. EAGAIN
                // left over from an earlier syscall).
                , ErrorCode_(ReadLength_ < 0 ? LastSystemError() : 0)
            {
                packet.FromErrorQueue() = errorQueue;
                if (Success()) {
                    packet.Size() = ReadLength_;
                }
                if (Header_.MessageTruncated()) {
                    // too big datagram, remember it
                    packet.Truncated() = true;
                }
                if (!Header_.ControlMessageTruncated()) {
                    Header_.ParseControlMessage();
                    packet.Offender() = Header_.Offender();
                    packet.TypeOfService() = Header_.TrafficClass();
                }
            }

            /// True when at least one byte was received.
            inline bool Success() const noexcept {
                return ReadLength_ > 0;
            }

            /// True when the read failed only because no data was available yet.
            inline bool Retryable() const noexcept {
                return ErrorCode_ == EAGAIN || ErrorCode_ == EWOULDBLOCK;
            }

            /// Number of bytes received; only valid when Success().
            inline TContIOStatus ToResult() const noexcept {
                Y_ASSERT(Success());
                return TContIOStatus::Success((size_t)ReadLength_);
            }

            /// Error status of the read (0 for an empty datagram); only valid when !Success().
            inline TContIOStatus ToError() const noexcept {
                Y_ASSERT(!Success());
                return TContIOStatus::Error(ErrorCode_);
            }

            /// Receive time in microseconds (kernel timestamp if available).
            inline ui64 ReceivedTime() const noexcept {
                return Header_.ReceivedTime();
            }

            /// Microseconds between starting the read and the receive time; 0 if negative.
            inline ui64 ReceiveLag() const noexcept {
                const auto received(Header_.ReceivedTime());
                return received > Started_ ? received - Started_ : 0;
            }

            /// SO_RXQ_OVFL counter reported with this datagram, or 0.
            inline ui64 DroppedPackets() const noexcept {
                return Header_.DroppedPackets();
            }

        private:
            TReadMessageHeader Header_;
            const ui64 Started_;
            const ssize_t ReadLength_;
            const int ErrorCode_;
        };

        /// Send-side `msghdr` for `sendmsg()`: sends the first
        /// `Min(packetSize, packet.Size())` bytes of `packet` to its client address.
        ///
        /// When the packet carries a non-zero type of service, an IPV6_TCLASS control
        /// message with that value is attached. If the FixReplyTypeOfService setting is
        /// enabled, `packet.NOCifyTypeOfService()` is applied first (this mutates
        /// `packet`).
        /// NOTE(review): the control message is always SOL_IPV6-level, regardless of the
        /// destination address family — confirm the intended behaviour for IPv4 sockets.
        struct TSendMessageHeader : public TNonCopyable {
            inline TSendMessageHeader(TPacket& packet, size_t packetSize) noexcept
            {
                ::memset(&Message, 0, sizeof(Message));
                // Also clears the CMSG_SPACE padding after the payload, so no
                // uninitialised bytes are handed to the kernel.
                ::memset(v6TClassBuffer, 0, sizeof(v6TClassBuffer));

                Message.msg_iov = &Entry;
                Message.msg_iovlen = 1;

                Entry.iov_base = packet.Data();
                Entry.iov_len = Min(packetSize, packet.Size());

                Message.msg_name = packet.ClientAddrData();
                Message.msg_namelen = NAddr::SockAddrLength(packet.ClientAddrData());

                if (packet.TypeOfService()) {
                    Message.msg_control = v6TClassBuffer;
                    Message.msg_controllen = sizeof(v6TClassBuffer);

                    struct cmsghdr* cmsg = CMSG_FIRSTHDR(&Message);
                    cmsg->cmsg_level = SOL_IPV6;
                    cmsg->cmsg_type = IPV6_TCLASS;
                    cmsg->cmsg_len = CMSG_LEN(sizeof(i32));

                    if (TSettings::Get()->GetFixReplyTypeOfService()) {
                        packet.NOCifyTypeOfService();
                    }
                    // Copy the value into the byte buffer instead of storing through an
                    // i32* that aliases a ui8 array.
                    const i32 tclass = packet.TypeOfService();
                    ::memcpy(CMSG_DATA(cmsg), &tclass, sizeof(tclass));
                } else {
                    Message.msg_control = nullptr;
                    Message.msg_controllen = 0;
                }
            }

            struct msghdr Message;
            struct iovec Entry;
            // CMSG_FIRSTHDR() returns the start of this buffer as a cmsghdr*, so it
            // must be aligned for cmsghdr rather than relying on member layout.
            alignas(struct cmsghdr) ui8 v6TClassBuffer[CMSG_SPACE(sizeof(i32))];
        };

        /// Performs one `sendmsg()` of `packet` (at most `packetSize` bytes) during
        /// construction and records the outcome.
        ///
        /// Member declaration order matters: `Header_` is initialised before
        /// `SendLength_`, whose initialiser issues the system call.
        class TSendOperation : public TNonCopyable {
        public:
            inline TSendOperation(SOCKET sock, TPacket& packet, size_t packetSize) noexcept
                : Header_(packet, packetSize)
                , SendLength_(::sendmsg(sock, &Header_.Message, 0))
                // errno is only meaningful after a failed call (-1). A zero-length send
                // returns 0 and must not pick up a stale errno; a stale EAGAIN would make
                // the caller retry forever, a stale EMSGSIZE would shrink the packet.
                , ErrorCode_(SendLength_ < 0 ? LastSystemError() : 0)
            {
            }

            /// True when at least one byte was sent.
            inline bool Success() const noexcept {
                return SendLength_ > 0;
            }

            /// True when the send failed only because the socket buffer was full.
            inline bool Retryable() const noexcept {
                return ErrorCode_ == EAGAIN || ErrorCode_ == EWOULDBLOCK;
            }

            /// errno of the failed call, or 0 when the call did not fail.
            inline int ErrorCode() const noexcept {
                return ErrorCode_;
            }

            /// Number of bytes sent; only valid when Success().
            inline TContIOStatus ToResult() const noexcept {
                Y_ASSERT(Success());
                return TContIOStatus::Success((size_t)SendLength_);
            }

            /// Error status of the send (0 for an empty send); only valid when !Success().
            inline TContIOStatus ToError() const noexcept {
                Y_ASSERT(!Success());
                return TContIOStatus::Error(ErrorCode_);
            }

        private:
            TSendMessageHeader Header_;
            const ssize_t SendLength_;
            const int ErrorCode_;
        };
    }

    class TIOHandler : public TNonCopyable {
    public:
        inline TIOHandler(TLog& logger, SOCKET socket)
            : Logger_(logger)
            , Socket_(socket)
            , LastRcvdOverflow_(0)
            , RcvdTime_(0)
            , SentTime_(0)
        {
        }

        inline TMaybeIOStatus Read(TPacket& packet) noexcept {
            const TReadOperation operation(Socket_, packet);
            if (!operation.Success()) {
                if (IsErrorQueueSupported()) {
                    const TReadOperation errorOperation(Socket_, packet, true);
                    if (errorOperation.Retryable()) {
                        return Nothing();
                    } else if (errorOperation.Success()) {
                        return errorOperation.ToResult();
                    } else {
                        return operation.ToError();
                    }
                } else {
                    // error queue not supported
                    if (operation.Retryable()) {
                        return Nothing();
                    } else {
                        return operation.ToError();
                    }
                }
            }

            RcvdTime_ = operation.ReceivedTime();

            const auto receiveLag(operation.ReceiveLag());
            if (receiveLag) {
                TUnistat::Instance().PushSignalUnsafe(ESignals::ProbeReceiveLag, receiveLag);
            }

            const auto overflow(operation.DroppedPackets());
            if (overflow && LastRcvdOverflow_ != overflow) {
                ui32 overflowDelta = overflow - LastRcvdOverflow_;
                Logger_ << TLOG_WARNING << "overflow detected: " << overflowDelta << " packets dropped";
                LastRcvdOverflow_ = overflow;
                PushSignal(EPushSignals::ProbeOverflows, overflowDelta);
            }

            return operation.ToResult();
        }

        inline TMaybeIOStatus Write(TPacket& packet) noexcept {
            bool skipOutOfBandError = true;
            size_t packetSize = packet.Size();
            do {
                const TSendOperation operation(Socket_, packet, packetSize);
                if (operation.Success()) {
                    SentTime_ = TInstant::Now().MicroSeconds();
                    return operation.ToResult();
                } else if (operation.Retryable()) {
                    return Nothing();
                } else if (operation.ErrorCode() == EMSGSIZE) {
                    // reduce size of message that is too big to send
                    packetSize /= 2;
                    if (!packetSize) {
                        return operation.ToError();
                    }
                } else if (
                        skipOutOfBandError
                        && (operation.ErrorCode() == EHOSTUNREACH || operation.ErrorCode() == ENETUNREACH)
                ) {
                    // this block can be triggered by out-of-band (e.g. icmp packets) errors, skip it for the first time
                    skipOutOfBandError = false;
                } else {
                    return operation.ToError();
                }
            } while (true);
        }

        void PollD(TCont* cont, ui16 opFilter, const TInstant& deadline) {
            auto res = NCoro::PollD(cont, Socket_, opFilter, deadline);
            if (res != 0 && res != EIO && res != ETIMEDOUT && res != ECANCELED) {
                ythrow TSystemError(res) << TStringBuf("io error");
            }
        }

        void PollT(TCont* cont, ui16 opFilter, const TDuration& duration) {
            PollD(cont, opFilter, duration.ToDeadLine());
        }

        inline ui64 RcvdTime() const noexcept {
            return RcvdTime_;
        }

        inline ui64 SentTime() const noexcept {
            return SentTime_;
        }

    private:
        TLog& Logger_;
        const SOCKET Socket_;

        ui64 LastRcvdOverflow_;
        ui64 RcvdTime_;
        ui64 SentTime_;
    };

    /// Asks the kernel not to fragment outgoing datagrams on `sock`.
    /// Implemented for IPv6 on Linux and Darwin and for IPv4 on Linux; a no-op for
    /// other families and platforms. Throws via CheckedSetSockOpt on failure.
    inline void DoNotFragment(int family, SOCKET sock) {
        if (family == AF_INET6) {
#if defined(_linux_)
            CheckedSetSockOpt(sock, IPPROTO_IPV6, IPV6_MTU_DISCOVER, IPV6_PMTUDISC_DO, "don't fragment");
#elif defined(_darwin_)
            // NOTE(review): 62 is assumed to be Darwin's IPV6_DONTFRAG option number
            // (spelled numerically here) — confirm against the SDK headers.
            constexpr int darwinIpv6DontFrag = 62;
            CheckedSetSockOpt(sock, IPPROTO_IPV6, darwinIpv6DontFrag, 1, "don't fragment");
#endif
        } else if (family == AF_INET) {
#if defined(_linux_)
            CheckedSetSockOpt(sock, SOL_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO, "don't fragment");
#endif
        }
    }

    inline void SetTimeToLive(int family, SOCKET sock, i32 ttl) {
        ttl = Max(Min(ttl, 255), -1);
        switch (family) {
            case AF_INET6: {
                CheckedSetSockOpt(sock, IPPROTO_IPV6, IPV6_UNICAST_HOPS, ttl, "time-to-live");
                return;
            }
            case AF_INET: {
                CheckedSetSockOpt(sock, IPPROTO_IP, IP_TTL, ttl, "time-to-live");
                return;
            }
            default: {
                return;
            }
        }
    }

    /// Enables the socket error queue (IPV6_RECVERR / IP_RECVERR) on `sock` when the
    /// platform defines the option for `family`; otherwise does nothing.
    /// Throws via CheckedSetSockOpt on failure.
    inline void EnableErrorQueue(int family, SOCKET sock) {
#ifdef IPV6_RECVERR
        if (family == AF_INET6) {
            CheckedSetSockOpt(sock, SOL_IPV6, IPV6_RECVERR, 1, "error queue");
            return;
        }
#endif
#ifdef IP_RECVERR
        if (family == AF_INET) {
            CheckedSetSockOpt(sock, SOL_IP, IP_RECVERR, 1, "error queue");
            return;
        }
#endif
        Y_UNUSED(family);
        Y_UNUSED(sock);
    }
}
