#include <infra/netmon/state_index_base.h>
#include <infra/netmon/common_ut.h>
#include <infra/netmon/topology/common_ut.h>

#include <library/cpp/testing/unittest/registar.h>

using namespace NNetmon;

// Test-only state keyed by a (target, source) datacenter pair. Besides the
// key it carries an arbitrary caller-chosen index, so a test can tell which
// concrete instance a merge iterator handed back.
class TFakeState: public TObjectPairState<TFakeState, TDatacenterPairKey> {
    using TFakeStateBase = TObjectPairState<TFakeState, TDatacenterPairKey>;

    // The CRTP base is granted access to this class; presumably it needs to
    // reach non-public members of the derived type (e.g. for Make) — TODO confirm.
    friend TFakeStateBase;

public:
    // Builds the keyed base part from the two datacenters and remembers
    // `index` verbatim.
    TFakeState(const TTopology::TDatacenterRef& target,
               const TTopology::TDatacenterRef& source,
               std::size_t index)
        : TFakeStateBase(target, source)
        , Index(index)
    {
    }

    // Returns the index supplied at construction time.
    std::size_t GetIndex() const {
        return Index;
    }

private:
    const std::size_t Index;
};

// Test-only index holding TFakeState objects. It adds nothing of its own:
// every constructor of the generic pair index is re-exported unchanged.
class TFakeIndex: public TObjectPairIndex<TFakeIndex, TFakeState> {
    using TFakeIndexBase = TObjectPairIndex<TFakeIndex, TFakeState>;

    // Let the CRTP base access this derived type.
    friend TFakeIndexBase;

public:
    // Inherit all base constructors (names the base's injected-class-name).
    using TFakeIndexBase::TObjectPairIndex;
};

class TMergeIteratorTest: public TTestBase {
    UNIT_TEST_SUITE(TMergeIteratorTest);
    UNIT_TEST(TestEmpty)
    UNIT_TEST(TestIterator)
    UNIT_TEST_SUITE_END();

public:
    TMergeIteratorTest()
        : TopologyStorage(TGlobalTopology::GetTopologyStorage())
        , SectionKey{NIL_NETWORK, NIL_PROTOCOL}
        , AggregatorKey(TProbeAggregatorKey::FromSectionKey(SectionKey, TopologyStorage.DefaultExpressionId()))
    {
    }

public:
    inline void TestEmpty() {
        auto left(TFakeIndex::Make(AggregatorKey));
        auto right(TFakeIndex::Make(AggregatorKey));

        std::size_t elements = 0;
        for (auto& pair : TMergeIterator<TFakeIndex, TFakeIndex>(*left, *right)) {
            Y_UNUSED(pair);
            elements++;
        }
        UNIT_ASSERT(!elements);
    }

    inline void TestIterator() {
        UNIT_ASSERT(FindDatacenter("man") < FindDatacenter("sas"));
        UNIT_ASSERT(FindDatacenter("sas") < FindDatacenter("vla"));

        auto left(TFakeIndex::Make(AggregatorKey));
        InsertState(left->GetTree(), FindDatacenter("man"), FindDatacenter("man"), 1);
        InsertState(left->GetTree(), FindDatacenter("sas"), FindDatacenter("vla"), 2);
        InsertState(left->GetTree(), FindDatacenter("vla"), FindDatacenter("sas"), 3);

        auto right(TFakeIndex::Make(AggregatorKey));
        InsertState(right->GetTree(), FindDatacenter("man"), FindDatacenter("man"), 4);
        InsertState(right->GetTree(), FindDatacenter("sas"), FindDatacenter("man"), 5);
        InsertState(right->GetTree(), FindDatacenter("vla"), FindDatacenter("man"), 6);
        InsertState(right->GetTree(), FindDatacenter("vla"), FindDatacenter("vla"), 7);

        TVector<std::size_t> leftIndexes, rightIndexes;
        for (auto& pair : TMergeIterator<TFakeIndex, TFakeIndex>(*left, *right)) {
            leftIndexes.push_back(pair.first ? pair.first->GetIndex() : 0);
            rightIndexes.push_back(pair.second ? pair.second->GetIndex() : 0);
            // let's mutate underlying trees
            if (!leftIndexes.back()) {
                InsertState(left->GetTree(), pair.second->GetTargetRef(), pair.second->GetSourceRef(), pair.second->GetIndex());
            } else if (!rightIndexes.back()) {
                InsertState(right->GetTree(), pair.first->GetTargetRef(), pair.first->GetSourceRef(), pair.first->GetIndex());
            }
        }

        TVector<std::size_t> leftRef{1, 0, 2, 0, 3, 0};
        TVector<std::size_t> rightRef{4, 5, 0, 6, 0, 7};
        UNIT_ASSERT_EQUAL(leftIndexes, leftRef);
        UNIT_ASSERT_EQUAL(rightIndexes, rightRef);

        for (auto& pair : TMergeIterator<TFakeIndex, TFakeIndex>(*left, *right)) {
            UNIT_ASSERT_EQUAL(pair.first->GetKey(), pair.second->GetKey());
        }
    }

    template <typename... Args>
    void InsertState(TFakeIndex::TTree& tree, Args&&... args) {
        auto state(TFakeState::Make(std::forward<Args>(args)...));
        tree.Insert(state.Release());
    }

    TTopologyStorage& TopologyStorage;
    const TProbeSectionKey SectionKey;
    const TProbeAggregatorKey AggregatorKey;
};

// Register TMergeIteratorTest with the unittest framework so the test runner picks it up.
UNIT_TEST_SUITE_REGISTRATION(TMergeIteratorTest);
