#include <passport/infra/libs/cpp/logbroker/processing/reader_pool.h>

#include <library/cpp/testing/unittest/registar.h>

using namespace NPassport;
using namespace NPassport::NLb;

Y_UNIT_TEST_SUITE(ReaderPool) {
    // Test-only subclass of TReaderPool<TString>.
    // It re-exports (presumably protected) base members as public, so the tests
    // below can drive the pool step by step and inspect its internal state.
    class TPool: public TReaderPool<TString> {
        using TBase = TReaderPool<TString>;

    public:
        // Inherit all base constructors (e.g. from TReaderPoolSettings).
        using TBase::TBase;

        // Processing entry points exercised directly by the tests.
        using TBase::CreateRaw;
        using TBase::Parse;
        using TBase::Worker;

        // Input/output queues and the counters tracking their sizes.
        using TBase::QueueIn_;
        using TBase::QueueInSize_;
        using TBase::QueueOut_;
        using TBase::QueueOutSize_;
    };

    // Parser used by the tests.
    // It walks every item of the data set, every message of that item and every
    // chunk of that message, and concatenates each chunk's payload followed by ';'.
    // Example: chunks "foo", "bar" -> "foo;bar;". An empty set yields "".
    static TString Parser(const TDataSet<>& s) {
        TStringStream out;
        for (const auto& item : s) {
            for (const auto& topic : item.Data.Messages) {
                for (const auto& chunk : topic.Data) {
                    out << chunk.Data << ';';
                }
            }
        }
        return out.Str();
    }

    // Builds a raw pool item carrying exactly one message with one chunk whose
    // payload is `s`. The remaining TPool::CreateRaw arguments are passed as an
    // empty `{}` and 0; their meaning is defined by TReaderPool::CreateRaw.
    static TPool::TRawDataPtr CreateRaw(const TString& s) {
        TData data;
        data.Messages.emplace_back();

        auto& chunks = data.Messages.back().Data;
        chunks.push_back(TChunk{.Data = s});

        return TPool::CreateRaw(std::move(data), {}, 0);
    }

    Y_UNIT_TEST(parse) {
        TPool pool(TReaderPoolSettings{.JoinMessages = 2});

        // Runs Parse over `batch`, then drains and returns everything it pushed
        // to the output queue.
        const auto parseAndDrain = [&pool](TSmallVec<TPool::TRawDataPtr>& batch) {
            pool.Parse(Parser, batch);
            std::vector<TPool::TParsedDataPtr> drained;
            pool.QueueOut_.DequeueAll(&drained);
            return drained;
        };

        TSmallVec<TPool::TRawDataPtr> batch;

        // An empty batch still yields exactly one parsed item, with empty data.
        {
            const auto emptyBatchOut = parseAndDrain(batch);
            UNIT_ASSERT_VALUES_EQUAL(1, emptyBatchOut.size());
            UNIT_ASSERT_VALUES_EQUAL("", emptyBatchOut.at(0)->Data);
        }

        // Every raw item of a single Parse call ends up in one parsed item,
        // even though JoinMessages (2) is smaller than the batch size (3).
        batch = {
            CreateRaw("foo"),
            CreateRaw("bar"),
            CreateRaw("kek"),
        };
        const auto batchOut = parseAndDrain(batch);
        UNIT_ASSERT_VALUES_EQUAL(1, batchOut.size());
        UNIT_ASSERT_VALUES_EQUAL("foo;bar;kek;", batchOut.at(0)->Data);
    }

    Y_UNIT_TEST(workerWithJoinOne) {
        TPool pool(TReaderPoolSettings{.QueueSizeLimit = 2, .JoinMessages = 1});

        // Pushes raw items straight into the input queue and bumps the input
        // counter by the same amount. Worker() is driven by hand here, so the
        // test has to keep the queue and its counter in sync itself.
        const auto feed = [&pool](const std::vector<TPool::TRawDataPtr>& items) {
            pool.QueueIn_.EnqueueAll(items);
            pool.QueueInSize_ += static_cast<int>(items.size());
        };

        UNIT_ASSERT_VALUES_EQUAL(0, pool.QueueInSize_.GetValue());
        UNIT_ASSERT_VALUES_EQUAL(0, pool.QueueOutSize_.GetValue());

        // Nothing queued: Worker() reports no progress and counters stay at zero.
        UNIT_ASSERT(!pool.Worker(Parser));
        UNIT_ASSERT_VALUES_EQUAL(0, pool.QueueInSize_.GetValue());
        UNIT_ASSERT_VALUES_EQUAL(0, pool.QueueOutSize_.GetValue());

        // With JoinMessages = 1 the test expects each Worker() call to move
        // exactly one item from input to output.
        feed({CreateRaw("foo"), CreateRaw("bar")});

        UNIT_ASSERT(pool.Worker(Parser));
        UNIT_ASSERT_VALUES_EQUAL(1, pool.QueueInSize_.GetValue());
        UNIT_ASSERT_VALUES_EQUAL(1, pool.QueueOutSize_.GetValue());

        UNIT_ASSERT(pool.Worker(Parser));
        UNIT_ASSERT_VALUES_EQUAL(0, pool.QueueInSize_.GetValue());
        UNIT_ASSERT_VALUES_EQUAL(2, pool.QueueOutSize_.GetValue());

        // Two parsed items are now pending (== QueueSizeLimit). The test expects
        // Worker() to return false and leave the new input untouched.
        feed({CreateRaw("foo2"), CreateRaw("bar2"), CreateRaw("kek2")});

        UNIT_ASSERT(!pool.Worker(Parser));
        UNIT_ASSERT_VALUES_EQUAL(3, pool.QueueInSize_.GetValue());
        UNIT_ASSERT_VALUES_EQUAL(2, pool.QueueOutSize_.GetValue());

        std::vector<TPool::TParsedDataPtr> results;
        pool.QueueOut_.DequeueAll(&results);
        UNIT_ASSERT_VALUES_EQUAL(2, results.size());
    }

    Y_UNIT_TEST(workerWithJoinTwo) {
        TPool pool(TReaderPoolSettings{.QueueSizeLimit = 2, .JoinMessages = 2});

        // Pushes raw items straight into the input queue and bumps the input
        // counter by exactly the number of items enqueued. Worker() is driven by
        // hand here, so the queue and its counter must be kept in sync by the test.
        const auto feed = [&pool](const std::vector<TPool::TRawDataPtr>& items) {
            pool.QueueIn_.EnqueueAll(items);
            pool.QueueInSize_ += static_cast<int>(items.size());
        };

        UNIT_ASSERT_VALUES_EQUAL(0, pool.QueueInSize_.GetValue());
        UNIT_ASSERT_VALUES_EQUAL(0, pool.QueueOutSize_.GetValue());

        // Nothing queued: Worker() reports no progress and counters stay at zero.
        UNIT_ASSERT(!pool.Worker(Parser));
        UNIT_ASSERT_VALUES_EQUAL(0, pool.QueueInSize_.GetValue());
        UNIT_ASSERT_VALUES_EQUAL(0, pool.QueueOutSize_.GetValue());

        // With JoinMessages = 2 the test expects the first Worker() call to
        // consume two of the three items and emit one parsed item. The second
        // call picks up the single remaining item.
        feed({CreateRaw("foo"), CreateRaw("bar"), CreateRaw("kek")});

        UNIT_ASSERT(pool.Worker(Parser));
        UNIT_ASSERT_VALUES_EQUAL(1, pool.QueueInSize_.GetValue());
        UNIT_ASSERT_VALUES_EQUAL(1, pool.QueueOutSize_.GetValue());

        UNIT_ASSERT(pool.Worker(Parser));
        UNIT_ASSERT_VALUES_EQUAL(0, pool.QueueInSize_.GetValue());
        UNIT_ASSERT_VALUES_EQUAL(2, pool.QueueOutSize_.GetValue());

        // Two parsed items are now pending (== QueueSizeLimit). The test expects
        // Worker() to return false and leave the new input untouched.
        feed({CreateRaw("foo2"), CreateRaw("bar2"), CreateRaw("kek2")});

        UNIT_ASSERT(!pool.Worker(Parser));
        UNIT_ASSERT_VALUES_EQUAL(3, pool.QueueInSize_.GetValue());
        UNIT_ASSERT_VALUES_EQUAL(2, pool.QueueOutSize_.GetValue());

        std::vector<TPool::TParsedDataPtr> results;
        pool.QueueOut_.DequeueAll(&results);
        UNIT_ASSERT_VALUES_EQUAL(2, results.size());
    }
}
