#pragma once

#include <infra/netmon/agent/common/buf.h>
#include <infra/netmon/agent/common/metrics.h>
#include <infra/netmon/agent/common/settings.h>
#include <infra/netmon/agent/common/utils.h>

#include <util/memory/smallobj.h>
#include <util/network/address.h>
#include <util/ysaveload.h>

namespace NNetmon {
    // These constants are declared directly in NNetmon rather than in an
    // anonymous namespace. In a header, an unnamed namespace gives every
    // translation unit its own distinct copy of each name, so the inline members
    // of TPacket that use them (e.g. the CS4 comparison) would refer to different
    // entities in different TUs. The variables stay `constexpr` (internal
    // linkage, exactly like the previous `const`), and the enum is named so its
    // enumerators are the same entities everywhere.

    // NOTE(review): not referenced in this header; presumably the maximum
    // packet size in bytes used by callers — confirm.
    constexpr std::size_t MAX_PACKET_LENGTH = 8192;

    // We use one of these masks to check type of service
    // for response packets, depending on settings.
    // check_full_tos enabled: ignore only ECN bits
    constexpr i32 FULL_TOS_MASK = 0b11111100;
    // check_full_tos disabled: check just 3 first bits
    constexpr i32 SHORT_TOS_MASK = 0b11100000;

    // Values of the two low-order (ECN) TOS bits: ECN_CAPABLE_MASK is OR-ed in
    // by TPacket::NOCifyTypeOfService() for CS4 traffic, and both bits set is
    // what TPacket::CongestionEncountered() reports as congestion.
    constexpr i32 ECN_CAPABLE_MASK    = 0b00000010;
    constexpr i32 ECN_CONGESTION_MASK = 0b00000011;

    // Full TOS-byte values; presumably the DSCP class selectors CS0-CS4
    // shifted left past the two ECN bits — confirm.
    enum ETosClassSelector {
        CS0 = 0,
        CS1 = 32,
        CS2 = 64,
        CS3 = 96,
        CS4 = 128
    };

    class TPacket {
    public:
        virtual ~TPacket() = default;

        // Defines the structure of the packet header.
        struct TStats {
            TStats()
                : Signature(0)
                , ProbeId(0)
                , Seqno(0)
                , SourceSentTime(0)
                , SourceReceivedTime(0)
                , TargetReceivedTime(0)
                , TargetSentTime(0)
                , TargetRespTime(0)
            {
            }

            std::pair<ui64, double> ComputeRoundTrip() const {
                // try to compensate delay on target side
                ui64 targetDelay = (TargetSentTime > TargetReceivedTime) ? TargetSentTime - TargetReceivedTime : 0;
                PushSignal(EPushSignals::TargetDelay, (double)targetDelay);

                ui64 roundTripTime = (SourceReceivedTime > SourceSentTime) ? SourceReceivedTime - SourceSentTime : 0;
                double clockSkewSys = 0;
                if (targetDelay && roundTripTime > targetDelay) {
                    roundTripTime -= targetDelay;

                    /*
                        d - delta between wall clocks on src and dst hosts (current dst time = current src time + delta)
                        TgtRcv = SrcSnt + rtt/2 + d
                        SrcRcv = TgtSnt + rtt/2 - d

                        2 * d = TgtRcv - SrcSnt + TgtSnt - SrcRcv
                    */
                    clockSkewSys = std::fabs(((double)TargetReceivedTime - (double)SourceSentTime + (double)TargetSentTime - (double)SourceReceivedTime) / 2.0);
                }
                return std::make_pair(roundTripTime, clockSkewSys);
            }

            static inline void Copy(const TStats& source, TStats& target) {
                target.TargetReceivedTime = source.TargetReceivedTime;
                target.TargetSentTime = source.TargetSentTime;
                target.SourceReceivedTime = source.SourceReceivedTime;
            }

            bool inline operator==(const TStats& other) const {
                //Only Signature, ProbeId, Seqno, SourceSentTime are initialized on sender
                return Signature == other.Signature &&
                       ProbeId == other.ProbeId &&
                       Seqno == other.Seqno &&
                       SourceSentTime == other.SourceSentTime;
            }

            bool inline operator!=(const TStats& other) const {
                return !((*this) == other);
            }

            // distinguish valid probe from anything else
            ui32 Signature;

            // distinguish parallel probes to same address
            ui32 ProbeId;

            // distinguish probes in series
            ui16 Seqno;

            // the timestamp when the probe was sent (usec)
            ui64 SourceSentTime;

            // the timestamp when the probe was received (usec)
            ui64 SourceReceivedTime;

            // the timestamp when probe was received by target
            ui64 TargetReceivedTime;

            // the timestamp when target sent this probe
            ui64 TargetSentTime;

            // the timestamp when target replied with this probe
            ui64 TargetRespTime;

            Y_SAVELOAD_DEFINE(Signature, ProbeId, Seqno, SourceSentTime, TargetReceivedTime, TargetSentTime);
        };

        TPacket(TOnDemandBuffer::TPool& pool, i32 timeToLive=-1)
            : Buf_(pool)
            , Length_(0)
            , TimeToLive_(timeToLive)
            , Truncated_(false)
            , FromErrorQueue_(false)
        {
            ::memset(Data(), 0, Capacity());
            ::memset(ClientAddrData(), 0, ClientAddrCapacity());
        }

        TPacket(TOnDemandBuffer::TPool& pool, const NAddr::IRemoteAddr& addr, int tos = 0, i32 timeToLive=-1)
            : TPacket(pool, timeToLive)
        {
            Y_VERIFY(ClientAddrCapacity() >= addr.Len());
            ::memcpy(ClientAddrData(), addr.Addr(), addr.Len());
            TypeOfService_ = addr.Addr()->sa_family == AF_INET6 ? tos : 0;
        }

        inline void* Data() noexcept {
            return Buf_.Data();
        }

        inline std::size_t Capacity() const noexcept {
            return Buf_.Capacity();
        }

        inline void DetachBuffer() noexcept {
            Buf_.Reset();
        }

        inline std::size_t& Size() noexcept {
            return Length_;
        }
        inline const std::size_t& Size() const noexcept {
            return Length_;
        }

        inline sockaddr* ClientAddrData() noexcept {
            return (sockaddr*)&ClientAddr_;
        }
        inline const sockaddr* ClientAddrData() const noexcept {
            return (sockaddr*)&ClientAddr_;
        }

        inline std::size_t ClientAddrCapacity() noexcept {
            return sizeof(ClientAddr_);
        }

        inline i32 TimeToLive() const noexcept {
            return TimeToLive_;
        }

        inline i32& TypeOfService() noexcept {
            return TypeOfService_;
        }

        inline i32 TypeOfService() const noexcept {
            return TypeOfService_;
        }

        inline bool& Truncated() noexcept {
            return Truncated_;
        }
        inline bool Truncated() const noexcept {
            return Truncated_;
        }

        inline bool& FromErrorQueue() noexcept {
            return FromErrorQueue_;
        }
        inline bool FromErrorQueue() const noexcept {
            return FromErrorQueue_;
        }

        inline TMaybe<NAddr::TOpaqueAddr>& Offender() noexcept {
            return Offender_;
        }
        inline const TMaybe<NAddr::TOpaqueAddr>& Offender() const noexcept {
            return Offender_;
        }

        inline TStats& Stats() noexcept {
            return Stats_;
        }

        inline const TStats& Stats() const noexcept {
            return Stats_;
        }

        static inline ui16 HdrBodyChecksum(const ui16 *hdr, int hdrLen, const ui16 *body, int len, ui32 csum) {
            int nleft = hdrLen;
            const ui16 *w = hdr;
            ui32 sum = csum;
            ui16 answer;

            while (nleft > 1)  {
                sum += *w++;
                nleft -= 2;
            }

            if (nleft == 1) {
                sum += HostToInet(*(u_char *)w << 8);
            }

            w = body;
            nleft = len;

            while (nleft > 1)  {
                sum += *w++;
                nleft -= 2;
            }

            if (nleft == 1) {
                sum += HostToInet(*(u_char *)w << 8);
            }

            sum = (sum >> 16) + (sum & 0xffff); /* add hi 16 to low 16 */
            sum += (sum >> 16);         /* add carry */
            answer = ~sum;              /* truncate to 16 bits */
            return (answer);
        }

        static inline bool SameTypeOfService(i32 lhs, i32 rhs) {
            static const auto mask = TSettings::Get()->GetCheckFullTypeOfService()
                                   ? FULL_TOS_MASK
                                   : SHORT_TOS_MASK;
            return (lhs & mask) == (rhs & mask);
        }

        inline bool CongestionEncountered() const {
            return (TypeOfService_ & ECN_CONGESTION_MASK) == ECN_CONGESTION_MASK;
        }

        inline void NOCifyTypeOfService() {
            static const auto mask = TSettings::Get()->GetCheckFullTypeOfService()
                                   ? FULL_TOS_MASK
                                   : SHORT_TOS_MASK;
            TypeOfService_ &= mask;
            if (TypeOfService_ == CS4) {
                TypeOfService_ |= ECN_CAPABLE_MASK;
            }
        }

        virtual std::size_t GetDataOffset() const {
            static constexpr std::size_t PACKET_DATA_OFFSET = sizeof(TStats::Signature) + sizeof(TStats::ProbeId) + sizeof(TStats::Seqno) +
                sizeof(TStats::SourceSentTime) + sizeof(TStats::TargetReceivedTime) +sizeof(TStats::TargetSentTime);

            return PACKET_DATA_OFFSET;
        }

        inline bool ValidateData(bool compatibility = false) const {
            ui64 filler = (static_cast<ui64>(Stats_.ProbeId) << 32) | Stats_.Seqno;
            ui64 *data = reinterpret_cast<ui64*>(Buf_.Data() + GetDataOffset());
            const ui64 *end = reinterpret_cast<ui64*>(Buf_.Data() + Size());

            if (compatibility && end - data >= 4) {
                // old agent compatibility
                if (UnalignedLoad64(data + 0) == filler &&
                    UnalignedLoad64(data + 1) == filler &&
                    (UnalignedLoad64(data + 2) & 0x0302010000000000ull) == 0x0302010000000000ull &&
                    UnalignedLoad64(data + 3) == 0x0b0a090807060504ull)
                {
                    return true;
                }
            }

            for (; data + 1 <= end; ++data) {
                if (UnalignedLoad64(data) != filler) {
                    return false;
                }
            }

            return true;
        }

        void FillData() {
            ui64 filler = (static_cast<ui64>(Stats_.ProbeId) << 32) | Stats_.Seqno;
            ui64 *data = reinterpret_cast<ui64*>(Buf_.Data() + GetDataOffset());
            const ui64 *end = reinterpret_cast<ui64*>(Buf_.Data() + Size());

            for (; data + 1 <= end; ++data) {
                UnalignedStore64(data, filler);
            }
        }

    protected:
        mutable TOnDemandBuffer Buf_;
        std::size_t Length_;
        struct sockaddr_storage ClientAddr_;
        i32 TimeToLive_;
        i32 TypeOfService_;
        bool Truncated_;
        bool FromErrorQueue_;
        TMaybe<NAddr::TOpaqueAddr> Offender_;
        TStats Stats_;
    };
}
