#pragma once

#include <infra/netmon/probe.h>
#include <infra/netmon/library/memory_pool.h>

namespace NNetmon {
    // Global counter of live TObjectPairState objects, shared by every
    // instantiation of that template: it is bumped when a state is created
    // and dropped again in ~TObjectPairState. Only declared here (extern);
    // the definition lives in another translation unit.
    extern TAtomic MemPoolStatsPairStateCount;

    // Ordering policy for tree items: a state is ordered by whatever its
    // GetKey() returns, using operator< on the keys.
    //
    // Three forms are provided so that lookups can be done either with a
    // full state object or with a bare key on either side of the comparison.
    template <class TState>
    struct TCompareState {
        // state < state: compares the keys of both states.
        static bool Compare(const TState& lhs, const TState& rhs) noexcept {
            const auto& leftKey = lhs.GetKey();
            const auto& rightKey = rhs.GetKey();
            return leftKey < rightKey;
        }

        // key < state: the left operand is already a key.
        template <class TKey>
        static bool Compare(const TKey& lhs, const TState& rhs) noexcept {
            const auto& rightKey = rhs.GetKey();
            return lhs < rightKey;
        }

        // state < key: the right operand is already a key.
        template <class TKey>
        static bool Compare(const TState& lhs, const TKey& rhs) noexcept {
            const auto& leftKey = lhs.GetKey();
            return leftKey < rhs;
        }
    };

    // Base class for a state object identified by a (source, target) key.
    //
    // T is the concrete derived type: it is the type stored in the tree
    // (TRbTreeItem<T, ...>), the type constructed by Make(), and the type
    // owned by TRef. States are ordered in the tree by GetKey() through
    // TCompareState<T>.
    //
    // Every constructed instance is counted in MemPoolStatsPairStateCount.
    // The increment lives in the constructor (not in Make()) so that it is
    // always paired with the decrement in the destructor: a failed pool
    // allocation or a throwing key constructor no longer leaves the counter
    // permanently too high, and an object built by any other path is still
    // counted before it is un-counted.
    template <class T, class TKeyType>
    class TObjectPairState: public TNonCopyable,
                            public TRbTreeItem<T, TCompareState<T>>,
                            public TSafeObjectFromPool<T> {
    public:
        using TKey = TKeyType;
        using TRef = THolder<T>;
        using TTree = TRbTree<T, TCompareState<T>>;

        // Allocates a T from the singleton pool of TSafeObjectFromPool<T>,
        // forwarding `args` to T's constructor, and returns it wrapped in an
        // owning TRef.
        template <typename... Args>
        static inline TRef Make(Args&&... args) {
            return TRef(new (Singleton<typename TSafeObjectFromPool<T>::TPool>()) T(std::forward<Args>(args)...));
        }

        // Un-counts this instance; pairs with the increment in the constructor.
        virtual ~TObjectPairState() {
            AtomicDecrement(MemPoolStatsPairStateCount);
        }

        // Dereferenced source / target of the key. The reference held by the
        // key is dereferenced without a null check.
        inline const typename TKey::TType& GetSource() const noexcept {
            return *Key.GetSource();
        }
        inline const typename TKey::TType& GetTarget() const noexcept {
            return *Key.GetTarget();
        }

        // Source / target exactly as stored in the key (the TTypeRef handle).
        inline const typename TKey::TTypeRef& GetSourceRef() const noexcept {
            return Key.GetSource();
        }
        inline const typename TKey::TTypeRef& GetTargetRef() const noexcept {
            return Key.GetTarget();
        }

        // Key this state is ordered by inside the tree.
        inline const TKey& GetKey() const noexcept {
            return Key;
        }

    protected:
        // Builds the key from (target, source), passed to TKey in that order,
        // and counts the new instance. The body runs only after Key has been
        // constructed successfully, so a throwing TKey constructor leaves the
        // counter untouched.
        TObjectPairState(const typename TKey::TTypeRef& target,
                         const typename TKey::TTypeRef& source)
            : Key(target, source)
        {
            AtomicIncrement(MemPoolStatsPairStateCount);
        }

        // Immutable: the position of the item in the tree depends on it.
        const TKey Key;
    };

    // Base class for an index that owns a tree of TStateType items and is
    // identified by a TProbeAggregatorKey.
    //
    // T is the concrete derived type: it is what Make() constructs and what
    // TRef (TIntrusivePtr<T>) points to.
    template <class T, class TStateType>
    class TObjectPairIndex: public TNonCopyable,
                            public TAtomicRefCount<T>,
                            public TSafeObjectFromPool<T> {
    public:
        using TRef = TIntrusivePtr<T>;
        using TTree = typename TStateType::TTree;
        using TState = TStateType;

        // Allocates a T from the singleton pool of TSafeObjectFromPool<T>,
        // forwarding `args` to T's constructor.
        template <typename... Args>
        static TRef Make(Args&&... args) {
            return new (Singleton<typename TSafeObjectFromPool<T>::TPool>()) T(std::forward<Args>(args)...);
        }

        // Drains the tree: the first item is repeatedly handed to an owning
        // TState::TRef that goes out of scope at the end of the iteration.
        // The loop relies on the destruction of an item removing it from the
        // tree (presumably done by the tree item base class) - otherwise
        // Tree.Empty() would never become true.
        virtual ~TObjectPairIndex() {
            while (!Tree.Empty()) {
                auto& front = *Tree.Begin();
                typename TState::TRef doomed(&front);
            }
        }

        // Mutable and read-only access to the tree of states.
        TTree& GetTree() noexcept {
            return Tree;
        }
        const TTree& GetTree() const noexcept {
            return Tree;
        }

        // Key this index was created with.
        const TProbeAggregatorKey& GetKey() const noexcept {
            return Key;
        }

    protected:
        // Stores a copy of `key`; the tree starts out default-constructed.
        TObjectPairIndex(const TProbeAggregatorKey& key)
            : Key(key)
        {
        }

        const TProbeAggregatorKey Key;
        TTree Tree;
    };

    // Walks the trees of two indices in key order and yields
    // (left state*, right state*) pairs:
    //   - both pointers set  : the same key is present on both sides;
    //   - only `first` set   : the key is present on the left side only;
    //   - only `second` set  : the key is present on the right side only.
    //
    // The walk is driven by the right side. When both trees are non-empty the
    // left cursor starts at the lower bound of the first right key, and once
    // the right tree is exhausted the walk ends - left-only items outside the
    // key range of the right tree are therefore not yielded.
    //
    // The object serves both as the range (begin()/end()) and as the iterator.
    // The (nullptr, nullptr) pair is the end marker, and iterators are compared
    // by their current pair only.
    template <class TLeftIndex, class TRightIndex>
    class TMergeIterator {
    public:
        using TLeftState = typename TLeftIndex::TState;
        using TRightState = typename TRightIndex::TState;
        using TPair = std::pair<TLeftState*, TRightState*>;

        // Captures the [Begin, End) cursors of both trees. No pair is current
        // until begin() (or operator++) performs the first step.
        TMergeIterator(const TLeftIndex& left, const TRightIndex& right)
            : CurrentPair(nullptr, nullptr)
            , LeftIterator(left.GetTree().Begin())
            , LeftEnd(left.GetTree().End())
            , RightIterator(right.GetTree().Begin())
            , RightEnd(right.GetTree().End())
        {
            const bool bothNonEmpty = LeftIterator != LeftEnd && RightIterator != RightEnd;
            if (bothNonEmpty) {
                // Skip left items ordered before the first right key.
                LeftIterator = left.GetTree().LowerBound(RightIterator->GetKey());
            }
        }

        // Member-wise copy of the current pair and all four cursors.
        TMergeIterator(const TMergeIterator& other)
            : CurrentPair(other.CurrentPair)
            , LeftIterator(other.LeftIterator)
            , LeftEnd(other.LeftEnd)
            , RightIterator(other.RightIterator)
            , RightEnd(other.RightEnd)
        {
        }

        // Copy of this object advanced by one step, i.e. positioned on the
        // first pair when called on a freshly constructed object.
        TMergeIterator begin() noexcept {
            TMergeIterator first(*this);
            first.Inc();
            return first;
        }

        // Copy of this object whose current pair is the end marker.
        TMergeIterator end() noexcept {
            TMergeIterator sentinel(*this);
            sentinel.CurrentPair = TPair(nullptr, nullptr);
            return sentinel;
        }

        const TPair& operator*() const noexcept {
            return CurrentPair;
        }

        const TPair* operator->() const noexcept {
            return &CurrentPair;
        }

        TMergeIterator& operator++() noexcept {
            Inc();
            return *this;
        }

        // Two iterators are equal when they point at the same pair of states;
        // the cursors themselves do not take part in the comparison.
        bool operator==(const TMergeIterator& other) const noexcept {
            return CurrentPair.first == other.CurrentPair.first &&
                   CurrentPair.second == other.CurrentPair.second;
        }

        bool operator!=(const TMergeIterator& other) const noexcept {
            return !(*this == other);
        }

    private:
        // Produces the next pair, or the (nullptr, nullptr) end marker.
        void Inc() {
            CurrentPair = TPair(nullptr, nullptr);

            // Right side exhausted: the walk is over, whatever is left on the
            // left side is not reported.
            if (!(RightIterator != RightEnd)) {
                return;
            }

            // Left side exhausted: drain the right side one item at a time.
            if (!(LeftIterator != LeftEnd)) {
                TakeRight();
                return;
            }

            const auto& leftKey = LeftIterator->GetKey();
            const auto& rightKey = RightIterator->GetKey();
            const bool leftFirst = leftKey < rightKey;
            const bool rightFirst = !leftFirst && rightKey < leftKey;

            // Neither flag set means the keys are equivalent: take both sides.
            if (!rightFirst) {
                TakeLeft();
            }
            if (!leftFirst) {
                TakeRight();
            }
        }

        // Publishes the item under the left cursor and moves the cursor on.
        void TakeLeft() {
            CurrentPair.first = &(*LeftIterator);
            ++LeftIterator;
        }

        // Publishes the item under the right cursor and moves the cursor on.
        void TakeRight() {
            CurrentPair.second = &(*RightIterator);
            ++RightIterator;
        }

        TPair CurrentPair;

        typename TLeftIndex::TTree::TIterator LeftIterator;
        const typename TLeftIndex::TTree::TIterator LeftEnd;

        typename TRightIndex::TTree::TIterator RightIterator;
        const typename TRightIndex::TTree::TIterator RightEnd;
    };
}
