#include "iter_heap.h"
#include <library/cpp/testing/unittest/registar.h>

namespace NFusion {
    /// Test double for a lemma iterator. It walks a fixed list of lemmas
    /// (every list used in this file is sorted ascending) and records how many
    /// times Next() was called, so a test can verify how often the heap
    /// advanced the underlying iterators.
    class TLemmaIteratorMock {
    public:
        TVector<TString> MockData; // lemmas returned by GetLemma(), in order
        size_t Pos = 0;            // index into MockData of the current lemma

        size_t CallsNext = 0; // number of Next() invocations so far

    public:
        TLemmaIteratorMock() = default;

        // Takes ownership of the lemma list. A named rvalue reference is an
        // lvalue, so it has to be std::move'd explicitly; otherwise every
        // string would be copied into MockData.
        explicit TLemmaIteratorMock(TVector<TString>&& mockData)
            : MockData(std::move(mockData))
        {
        }

        /// Returns the current lemma. Precondition: Valid().
        TStringBuf GetLemma() const {
            Y_ASSERT(Valid());
            return MockData[Pos];
        }

        /// True while Pos still indexes an element of MockData.
        bool Valid() const {
            return Pos < MockData.size();
        }

        /// Advances to the next lemma and counts the call.
        void Next() {
            ++CallsNext;
            ++Pos;
        }
    };

    /// Ordering predicate over iterator pointers, passed to TIteratorHeap.
    /// Returns true when `lhs` is positioned on a lexicographically greater
    /// lemma than `rhs`. Presumably the heap keeps the element this predicate
    /// ranks lowest on top, so the smallest lemma is yielded first (the tests
    /// expect ascending output); confirm against iter_heap.h.
    template <typename TIteratorPtr>
    struct TLemmaComparerMock {
        bool operator()(const TIteratorPtr& lhs, const TIteratorPtr& rhs) const {
            // Both iterators must exist and currently point at a lemma.
            Y_ASSERT(lhs && rhs && lhs->Valid() && rhs->Valid());
            const auto& lhsLemma = lhs->GetLemma();
            const auto& rhsLemma = rhs->GetLemma();
            return lhsLemma > rhsLemma;
        }
    };

    class TIteratorHeapUnitTest: public NUnitTest::TTestBase {
        UNIT_TEST_SUITE(TIteratorHeapUnitTest)
            UNIT_TEST(TestHeapBasics);
        UNIT_TEST_SUITE_END();

        using TMockHeap = TIteratorHeap<TLemmaIteratorMock, TLemmaComparerMock<TLemmaIteratorMock*>>;

        void TestHeapBasics() {
            using TInputs = TVector<TLemmaIteratorMock>;
            auto createMock = [](TInputs& collection, TVector<TVector<TString>> mockLemmas) {
                THolder<TMockHeap> result = MakeHolder<TMockHeap>();
                collection.clear();
                for (TVector<TString>& m : mockLemmas) {
                    collection.push_back(TLemmaIteratorMock(std::move(m)));
                }

                for (TLemmaIteratorMock& i : collection) {
                    result->Add(&i);
                }
                return result;
            };

            auto readAll = [](TMockHeap& heap) {
                TString result;
                TStringOutput ss(result);
                for (; heap.Valid(); heap.Next()) {
                    ss << heap.GetHead()->GetLemma() << " ";
                }
                ss.Finish();
                return result;
            };

            auto mockCallsNext = [](const TInputs& coll) {
                size_t nCalls = 0;
                for(const auto& i : coll) {
                    nCalls += i.CallsNext;
                }
                return nCalls;
            };

            // test basic heap behavior
            TInputs inputs;
            THolder<TMockHeap> state;
            state = createMock(inputs, {});
            UNIT_ASSERT_EQUAL("", readAll(*state));
            UNIT_ASSERT_EQUAL(0, mockCallsNext(inputs));
            state = createMock(inputs, {{}});
            UNIT_ASSERT_EQUAL("", readAll(*state));
            UNIT_ASSERT_EQUAL(0, mockCallsNext(inputs));
            state = createMock(inputs, {{""}});
            UNIT_ASSERT_EQUAL(" ", readAll(*state));
            UNIT_ASSERT_EQUAL(1, mockCallsNext(inputs));
            state = createMock(inputs, {{"a", "b"}});
            UNIT_ASSERT_EQUAL("a b ", readAll(*state));
            UNIT_ASSERT_EQUAL(2, mockCallsNext(inputs));
            state = createMock(inputs, {{"a", "b"}});
            UNIT_ASSERT_EQUAL("a b ", readAll(*state));
            UNIT_ASSERT_EQUAL(2, mockCallsNext(inputs));
            state = createMock(inputs, {{"a"}, {"b"}});
            UNIT_ASSERT_EQUAL("a b ", readAll(*state));
            UNIT_ASSERT_EQUAL(2, mockCallsNext(inputs));
            state = createMock(inputs, {{"b"}, {"a"}});
            UNIT_ASSERT_EQUAL("a b ", readAll(*state));
            UNIT_ASSERT_EQUAL(2, mockCallsNext(inputs));
            state = createMock(inputs, {{"a", "b"}, {"c"}});
            UNIT_ASSERT_EQUAL("a b c ", readAll(*state));
            UNIT_ASSERT_EQUAL(3, mockCallsNext(inputs));
            state = createMock(inputs, {{"a", "c"}, {"b"}});
            UNIT_ASSERT_EQUAL("a b c ", readAll(*state));
            UNIT_ASSERT_EQUAL(3, mockCallsNext(inputs));
            state = createMock(inputs, {{"b", "c"}, {"a"}});
            UNIT_ASSERT_EQUAL("a b c ", readAll(*state));
            UNIT_ASSERT_EQUAL(3, mockCallsNext(inputs));
            state = createMock(inputs, {{"a", "b", "c"}, {"a"}});
            UNIT_ASSERT_EQUAL("a a b c ", readAll(*state));
            UNIT_ASSERT_EQUAL(4, mockCallsNext(inputs));
            state = createMock(inputs, {{"aa"}, {"aa", "ab", "ac"}});
            UNIT_ASSERT_EQUAL("aa aa ab ac ", readAll(*state));
            UNIT_ASSERT_EQUAL(4, mockCallsNext(inputs));
            state = createMock(inputs, {{"a", "b"}, {"c", "d"}});
            UNIT_ASSERT_EQUAL("a b c d ", readAll(*state));
            UNIT_ASSERT_EQUAL(4, mockCallsNext(inputs));
            state = createMock(inputs, {{"0", "2"}, {}, {"1", "3", "5"}, {"5"}, {"0"}, {""}, {"2"}});
            UNIT_ASSERT_EQUAL(" 0 0 1 2 2 3 5 5 ", readAll(*state));
            UNIT_ASSERT_EQUAL(9, mockCallsNext(inputs));
        }
    };

}

// Registers NFusion::TIteratorHeapUnitTest with the unittest framework
// (macro from library/cpp/testing/unittest/registar.h).
UNIT_TEST_SUITE_REGISTRATION(NFusion::TIteratorHeapUnitTest);
