#pragma once

#include <infra/netmon/state_keys.h>
#include <infra/netmon/library/memory_pool.h>
#include <infra/netmon/topology/types.h>

#include <infra/netmon/idl/probe.fbs.h>

#include <library/cpp/containers/intrusive_rb_tree/rb_tree.h>

#include <util/digest/multi.h>

#include <functional>

namespace NNetmon {
    // Counter of TProbe objects: incremented when a TProbe is created and
    // decremented in ~TProbe. Defined in a .cpp elsewhere; presumably exported as
    // a memory-pool statistic (name-based guess — confirm with the definition).
    extern TAtomic MemPoolStatsProbeCount;

    // Comparators for TProbe's intrusive red-black tree indices.
    //
    // This is a header, and TProbe (a class with external linkage) embeds
    // TIndexItem<TSwitchCompare> / TIndexItem<TGeneratedCompare> members. If these
    // comparators lived in an anonymous namespace, every translation unit would see
    // a different TProbe member type, which is an ODR hazard (GCC reports it as
    // -Wsubobject-linkage). A uniquely named inline namespace gives the types
    // external linkage. It also keeps unqualified lookup from NNetmon (e.g.
    // `TSwitchCompare`) working exactly as before.
    inline namespace NProbeIndexCompare {
        // Orders index items by their switch pair key. Items with equal keys are
        // ordered by address, so two distinct items never compare equal.
        struct TSwitchCompare {
            template <class T>
            static inline bool Compare(const T& lhs, const T& rhs) noexcept {
                const auto& lhsKey = lhs->GetSwitchKey();
                const auto& rhsKey = rhs->GetSwitchKey();
                if (lhsKey < rhsKey) {
                    return true;
                } else if (lhsKey == rhsKey) {
                    // Built-in '<' on pointers to unrelated objects is unspecified;
                    // std::less is guaranteed to yield a strict total order.
                    return std::less<const T*>()(&lhs, &rhs);
                } else {
                    return false;
                }
            }

            // Heterogeneous comparison: item vs. switch pair key (used for lookups by key).
            template <class TItem, class TKey>
            static inline bool Compare(const TItem& lhs, const TKey& rhs) noexcept {
                return lhs->GetSwitchKey() < rhs;
            }
        };

        // Orders index items by generation time, so the oldest items come first
        // and can be deleted first.
        struct TGeneratedCompare {
            template <class T>
            static inline bool Compare(const T& lhs, const T& rhs) noexcept {
                if (lhs->GetGenerated() < rhs->GetGenerated()) {
                    return true;
                } else if (lhs->GetGenerated() == rhs->GetGenerated()) {
                    // netmon uses this index to distinguish unique probes, so ties
                    // are broken by the (target, source) interface comparators.
                    // NOTE(review): presumably IsLess compares the pairs
                    // lexicographically; probes with equal time and interfaces
                    // compare equal here — confirm that is intended.
                    return IsLess(
                        lhs->GetTargetIface()->GetComparator(),
                        rhs->GetTargetIface()->GetComparator(),
                        lhs->GetSourceIface()->GetComparator(),
                        rhs->GetSourceIface()->GetComparator()
                    );
                } else {
                    return false;
                }
            }

            // Heterogeneous comparison: item vs. timestamp (used for lookups by time).
            template <class T>
            static inline bool Compare(const T& lhs, const TInstant& rhs) noexcept {
                return lhs->GetGenerated() < rhs;
            }
        };
    }

    // Represents a single probe in memory and defines all related index types.
    // Millions of probes can exist at once, so allocations are kept to a minimum:
    // objects come from a memory pool (TSafeObjectFromPool), and the nodes of both
    // intrusive indices (by switch pair and by generation time) are embedded
    // directly in the probe.
    class TProbe: public TNonCopyable, public TAtomicRefCount<TProbe>, public TSafeObjectFromPool<TProbe> {
    public:
        using TRef = TIntrusivePtr<TProbe>;
        using TRefVector = TVector<TRef>;

        // Intrusive red-black tree node embedded in a TProbe. TCompare defines the
        // ordering of the tree (TTree) this node can be linked into. The node only
        // holds a reference back to its owning probe.
        template <class TCompare>
        class TIndexItem: public TRbTreeItem<TIndexItem<TCompare>, TCompare> {
        public:
            using TType = TIndexItem<TCompare>;
            using TTree = TRbTree<TType, TCompare>;

            inline TIndexItem(TProbe& probe)
                : Probe(probe)
            {
            }

            inline TProbe* Get() noexcept {
                return &Probe;
            }

            // Returns a mutable probe even from a const item, because the item
            // only stores a reference to the probe.
            inline TProbe* Get() const noexcept {
                return &Probe;
            }

            inline TProbe* operator->() const noexcept {
                return &Probe;
            }

            inline TProbe& operator*() const noexcept {
                return Probe;
            }

            // True if `tree` contains an item that TCompare considers equal to this
            // one. Depending on TCompare, that may be a different probe's item.
            inline bool Exists(const TTree& tree) const noexcept {
                return tree.Find(*this) != nullptr;
            }

            // True while this item is linked into some tree.
            inline bool Linked() const noexcept {
                return TType::ParentTree() != nullptr;
            }

        private:
            TProbe& Probe;
        };

        using TGeneratedItem = TIndexItem<TGeneratedCompare>;
        using TGeneratedTree = TGeneratedItem::TTree;

        using TSwitchItem = TIndexItem<TSwitchCompare>;
        using TSwitchTree = TSwitchItem::TTree;

        // Allocates a probe from the object pool and returns it wrapped in a TRef.
        // Arguments are forwarded to the private constructor.
        template <typename... Args>
        static inline TRef Make(Args&&... args) {
            return new (Singleton<typename TSafeObjectFromPool<TProbe>::TPool>()) TProbe(std::forward<Args>(args)...);
        }

        inline const TTopology::THostInterfaceRef& GetSourceIface() const noexcept {
            return SourceIface;
        }
        inline const TTopology::THostInterfaceRef& GetTargetIface() const noexcept {
            return TargetIface;
        }

        inline const TIpAddress& GetSourceAddress() const {
            return SourceAddress;
        }
        inline const TIpAddress& GetTargetAddress() const {
            return TargetAddress;
        }
        inline ui16 GetSourcePort() const {
            return SourcePort;
        }
        inline ui16 GetTargetPort() const {
            return TargetPort;
        }

        // The `generated` value passed at construction; orders TGeneratedTree.
        inline const TInstant& GetGenerated() const noexcept {
            return Generated;
        }
        inline double GetScore() const noexcept {
            return Score;
        }
        inline double GetRoundTripTime() const noexcept {
            return RoundTripTime;
        }

        // A probe counts as successful only when its score is exactly 1.0.
        inline bool IsSuccessful() const noexcept {
            return GetScore() == 1.0;
        }

        // An unsuccessful probe is "dead" when either endpoint host is in
        // `terminatedHosts` or `deadHosts`. Outside NOC_SLA_BUILD it is also dead
        // when `seenHosts` is non-empty and does not contain the target host.
        // Successful probes are never dead.
        inline bool IsDead(const TTopologyStorage::THostSet& seenHosts,
                           const TTopologyStorage::THostSet& terminatedHosts,
                           const TTopologyStorage::THostSet& deadHosts) const noexcept
        {
            // seenHosts is not referenced at all in NOC_SLA_BUILD.
            Y_UNUSED(seenHosts);
            if (!IsSuccessful()) {
                const TTopology::THostRef sourceHost(GetSourceIface()->GetHost());
                const TTopology::THostRef targetHost(GetTargetIface()->GetHost());

                if (!terminatedHosts.find(sourceHost).IsEnd() || !terminatedHosts.find(targetHost).IsEnd()) {
                    return true;
#ifndef NOC_SLA_BUILD
                // In nocsla builds this check is disabled because it drops all unsuccessful cross-dc probes.
                // It's also useless because we only send probes to interested hosts.
                } else if (!seenHosts.empty() && seenHosts.find(targetHost).IsEnd()) {
                    return true;
#endif
                } else if (!deadHosts.find(sourceHost).IsEnd() || !deadHosts.find(targetHost).IsEnd()) {
                    return true;
                }
            }
            return false;
        }

        inline const TSwitchPairKey& GetSwitchKey() const noexcept {
            return SwitchKey;
        }

        inline TGeneratedItem& GetGeneratedItem() noexcept {
            return GeneratedItem;
        }
        inline TSwitchItem& GetSwitchItem() noexcept {
            return SwitchItem;
        }

        bool HasExpressionId(const TTopologySelector& selector, TExpressionId expressionId) const;
        void IntersectExpressionIds(const TTopologySelector& selector, TExpressionIdList& expressionIds) const;

        flatbuffers::Offset<NProbe::TProbe> ToProto(flatbuffers::FlatBufferBuilder& builder) const;

        void Out(IOutputStream& stream) const;

        ~TProbe() {
            AtomicDecrement(MemPoolStatsProbeCount);
        }

    private:
        explicit TProbe(
            const TTopology::THostInterfaceRef sourceIface,
            const TTopology::THostInterfaceRef targetIface,
            const TIpAddress& sourceAddress,
            const TIpAddress& targetAddress,
            ui16 sourcePort,
            ui16 targetPort,
            TInstant generated,
            double score,
            double roundTripTime)
            : SourceIface(sourceIface)
            , TargetIface(targetIface)
            // MakeSwitchKey verifies both interfaces before calling GetSwitch() on them.
            , SwitchKey(MakeSwitchKey(SourceIface, TargetIface))
            , SwitchItem(*this)
            , Generated(generated)
            , GeneratedItem(*this)
            , Score(score)
            , RoundTripTime(roundTripTime)
            , SourceAddress(sourceAddress)
            , TargetAddress(targetAddress)
            , SourcePort(sourcePort)
            , TargetPort(targetPort)
        {
            // Count the probe only after construction has succeeded, so the counter
            // stays balanced with the decrement in ~TProbe even if a member
            // initializer fails part-way.
            AtomicIncrement(MemPoolStatsProbeCount);
        }

        // Builds the (target switch, source switch) key. The null check runs here,
        // before GetSwitch() is reached, so a null interface trips Y_VERIFY.
        static TSwitchPairKey MakeSwitchKey(const TTopology::THostInterfaceRef& sourceIface,
                                            const TTopology::THostInterfaceRef& targetIface) {
            Y_VERIFY(sourceIface && targetIface);
            return TSwitchPairKey(targetIface.GetSwitch(), sourceIface.GetSwitch());
        }

        // Declaration order matters: SwitchKey is built from the two interfaces
        // above it, so they must be initialized first.
        const TTopology::THostInterfaceRef SourceIface;
        const TTopology::THostInterfaceRef TargetIface;

        TSwitchPairKey SwitchKey;
        TSwitchItem SwitchItem;

        const TInstant Generated;
        TGeneratedItem GeneratedItem;

        const double Score;
        const double RoundTripTime;

        const TIpAddress SourceAddress;
        const TIpAddress TargetAddress;

        const ui16 SourcePort;
        const ui16 TargetPort;
    };

    // Key of a probe section: a (network, protocol) tuple.
    class TProbeSectionKey: public std::tuple<ENetworkType, EProtocolType> {
    public:
        using tuple::tuple;

        // Tuple element 0.
        inline ENetworkType GetNetwork() const {
            const tuple& fields = *this;
            return std::get<0>(fields);
        }

        // Tuple element 1.
        inline EProtocolType GetProtocol() const {
            const tuple& fields = *this;
            return std::get<1>(fields);
        }

        void Out(IOutputStream& stream) const;
    };

    // Key of a probe slice: a (network, protocol, shard index) tuple.
    class TProbeSliceKey: public std::tuple<ENetworkType, EProtocolType, ui64> {
    public:
        using tuple::tuple;

        // Tuple element 0.
        inline ENetworkType GetNetwork() const {
            const tuple& fields = *this;
            return std::get<0>(fields);
        }

        // Tuple element 1.
        inline EProtocolType GetProtocol() const {
            const tuple& fields = *this;
            return std::get<1>(fields);
        }

        // Tuple element 2.
        inline ui64 GetShardIndex() const {
            const tuple& fields = *this;
            return std::get<2>(fields);
        }

        // The section this slice belongs to: the same key without the shard index.
        inline TProbeSectionKey GetSectionKey() const {
            const ENetworkType network = GetNetwork();
            const EProtocolType protocol = GetProtocol();
            return TProbeSectionKey(network, protocol);
        }

        void Out(IOutputStream& stream) const;
    };

    // expressionId, network, protocol
    class TProbeAggregatorKey: public std::tuple<TExpressionId, ENetworkType, EProtocolType> {
    public:
        TProbeAggregatorKey(const NCommon::TAggregatorKey& key);
        TProbeAggregatorKey(TExpressionId expressionId, ENetworkType networkType, EProtocolType protocolType);

        static inline TProbeAggregatorKey FromSectionKey(const TProbeSectionKey& key, TExpressionId expressionId) {
            return {expressionId, key.GetNetwork(), key.GetProtocol()};
        }

        static inline TProbeAggregatorKey FromSliceKey(const TProbeSliceKey& key, TExpressionId expressionId) {
            return {expressionId, key.GetNetwork(), key.GetProtocol()};
        }

        const NCommon::TAggregatorKey& ToProto() const;

        inline TExpressionId GetExpressionId() const {
            return std::get<0>(*this);
        }
        inline ENetworkType GetNetwork() const {
            return std::get<1>(*this);
        }
        inline EProtocolType GetProtocol() const {
            return std::get<2>(*this);
        }

        inline TProbeSectionKey GetSectionKey() const {
            return TProbeSectionKey(GetNetwork(), GetProtocol());
        }

        void Out(IOutputStream& stream) const;

    private:
        const NCommon::TAggregatorKey Flat;
    };

    using TProbeAggregatorKeySet = TSet<TProbeAggregatorKey>;
}

// Hash of a section key: MultiHash over (network, protocol).
template <>
struct THash<NNetmon::TProbeSectionKey> {
    size_t operator()(const NNetmon::TProbeSectionKey& key) const {
        const auto network = key.GetNetwork();
        const auto protocol = key.GetProtocol();
        return MultiHash(network, protocol);
    }
};

// Hash of a slice key: MultiHash over (network, protocol, shard index).
template <>
struct THash<NNetmon::TProbeSliceKey> {
    size_t operator()(const NNetmon::TProbeSliceKey& key) const {
        const auto network = key.GetNetwork();
        const auto protocol = key.GetProtocol();
        const auto shard = key.GetShardIndex();
        return MultiHash(network, protocol, shard);
    }
};

// Hash of an aggregator key: MultiHash over (expression id, network, protocol).
// Only the tuple fields are hashed; the stored flatbuffers key is not.
template <>
struct THash<NNetmon::TProbeAggregatorKey> {
    size_t operator()(const NNetmon::TProbeAggregatorKey& key) const {
        const auto expressionId = key.GetExpressionId();
        const auto network = key.GetNetwork();
        const auto protocol = key.GetProtocol();
        return MultiHash(expressionId, network, protocol);
    }
};
