#include <passport/infra/libs/cpp/yt/batch.h>

#include <library/cpp/testing/unittest/env.h>
#include <library/cpp/testing/unittest/registar.h>

using namespace NPassport::NYt;

Y_UNIT_TEST_SUITE(Batch) {
    // Stub YT client used by the tests in this suite: what Write() returns
    // depends solely on the Size field of the incoming query.
    class TYtClientTest: public TYtClient {
    public:
        // Size == 0: hands back the future of a freshly created promise on which
        //            this stub never sets a value.
        // Otherwise: hands back NThreading::MakeFuture<TWriteResult>()
        //            (presumably an already completed future - confirm against NThreading).
        NThreading::TFuture<TWriteResult> Write(TWriteQuery req) override {
            if (req.Size == 0) {
                return NThreading::NewPromise<TWriteResult>().GetFuture();
            }
            return NThreading::MakeFuture<TWriteResult>();
        }
    };

    // Test double for IQueryConverter. It reports five rows and verifies that
    // Convert() is asked for exactly the (offset, count) slices queued in Calls,
    // in that order.
    class TRequestConverterTest: public IQueryConverter {
    public:
        // success == true  -> Convert() produces queries with Size == 1.
        // success == false -> Convert() produces queries with Size == 0.
        // The constructor is explicit so that a bare bool can never be converted
        // into a converter implicitly.
        explicit TRequestConverterTest(bool success = true)
            : Calls{{0, 3}, {3, 2}}
            , Success(success)
        {
        }

        // Total number of rows this converter claims to hold.
        size_t size() const override {
            return 5;
        }

        // Checks the requested slice against the next expected (offset, count)
        // pair, consumes that expectation and returns a query whose Size encodes
        // the Success flag. The table name argument is ignored.
        TWriteQuery Convert(
            const TString&,
            size_t offset,
            size_t count) const override {
            UNIT_ASSERT(!Calls.empty());
            const auto& expected = Calls.front();
            UNIT_ASSERT_VALUES_EQUAL(offset, expected.first);
            UNIT_ASSERT_VALUES_EQUAL(count, expected.second);
            // `expected` is not touched after this point: erase() invalidates it.
            Calls.erase(Calls.begin());

            TWriteQuery res;
            res.Size = Success ? 1 : 0;
            return res;
        }

        TString DebugString() const override {
            return "";
        }

        // Not needed by these tests: always throws.
        void Merge(TQueryConverterPtr) override {
            ythrow yexception() << "not implemented";
        }

        // Not needed by these tests: always throws.
        std::type_index GetType() const override {
            ythrow yexception() << "not implemented";
        }

        // Remaining expected (offset, count) arguments of Convert(); mutable
        // because Convert() is const but has to consume expectations.
        mutable std::vector<std::pair<size_t, size_t>> Calls;
        // Selects the Size written into queries produced by Convert().
        bool Success;
    };

    // Constructing a TBatch with Limit == 0 must throw a yexception carrying
    // the message checked below.
    Y_UNIT_TEST(badSettings) {
        TYtClientTest client;
        const TBatchSettings zeroLimit{.Limit = 0};
        UNIT_ASSERT_EXCEPTION_CONTAINS(
            TBatch(zeroLimit, client),
            yexception,
            "limit == 0 is illegal size for batch");
    }

    // Happy path: the converter reports five rows and Limit is 3, so Send() is
    // expected to produce two queries (3 rows, then 2 rows) and WaitResult() is
    // expected to report them.
    Y_UNIT_TEST(success) {
        TBatchSettings settings{
            .Limit = 3,
            .Timeout = TDuration::MilliSeconds(1),
        };
        TYtClientTest client;
        TBatch batch(settings, client);

        TRequestConverterTest builder;

        // Row counts the Send() callback has to observe, in this order.
        const std::vector<size_t> expectedRows = {3, 2};
        size_t seen = 0;
        auto onQuery = [&](TBatch::TQueryResult& r) {
            UNIT_ASSERT(seen < expectedRows.size());
            UNIT_ASSERT_VALUES_EQUAL(r.Rows, expectedRows[seen]);
            ++seen;
        };
        batch.Send("table#1", builder, onQuery);
        // Every expected chunk was reported and every expected Convert() call happened.
        UNIT_ASSERT_VALUES_EQUAL(seen, expectedRows.size());
        UNIT_ASSERT(builder.Calls.empty());

        // A callback that always returns true is invoked for both queries
        // and WaitResult() returns true.
        size_t accepted = 0;
        auto acceptAll = [&](TBatch::TQueryResult&) {
            ++accepted;
            return true;
        };
        UNIT_ASSERT(batch.WaitResult(acceptAll));
        UNIT_ASSERT_VALUES_EQUAL(accepted, 2);

        // A callback that returns false is invoked once and WaitResult() returns false.
        size_t rejected = 0;
        auto rejectFirst = [&](TBatch::TQueryResult&) {
            ++rejected;
            return false;
        };
        UNIT_ASSERT(!batch.WaitResult(rejectFirst));
        UNIT_ASSERT_VALUES_EQUAL(rejected, 1);
    }

    // Failure path: the converter emits queries with Size == 0, for which the
    // stub client returns a future that it never completes. Send() must still
    // split the rows as 3 + 2, but WaitResult() is expected to throw without
    // invoking its callback (presumably because of the 1 ms Timeout - confirm
    // against TBatch).
    Y_UNIT_TEST(fail) {
        TBatchSettings settings{
            .Limit = 3,
            .Timeout = TDuration::MilliSeconds(1),
        };
        TYtClientTest client;
        TBatch batch(settings, client);

        TRequestConverterTest builder(false);

        // Row counts the Send() callback has to observe, in this order.
        const std::vector<size_t> expectedRows = {3, 2};
        size_t seen = 0;
        auto onQuery = [&](TBatch::TQueryResult& r) {
            UNIT_ASSERT(seen < expectedRows.size());
            UNIT_ASSERT_VALUES_EQUAL(r.Rows, expectedRows[seen]);
            ++seen;
        };
        batch.Send("table#1", builder, onQuery);
        // Every expected chunk was reported and every expected Convert() call happened.
        UNIT_ASSERT_VALUES_EQUAL(seen, expectedRows.size());
        UNIT_ASSERT(builder.Calls.empty());

        // The result callback must never run; WaitResult() has to throw instead.
        size_t calls = 0;
        auto countCalls = [&](TBatch::TQueryResult&) {
            ++calls;
            return true;
        };
        UNIT_ASSERT_EXCEPTION(batch.WaitResult(countCalls), yexception);
        UNIT_ASSERT_VALUES_EQUAL(calls, 0);
    }
}
