#include "offroad_saas_serf.h"

#include <library/cpp/testing/unittest/registar.h>
#include <util/stream/buffer.h>
#include <util/generic/bitops.h>
#include <util/generic/buffer.h>
#include <util/string/hex.h>

#include <random>

using namespace NRTYSerf;

class TOffroadSerfTest: public NUnitTest::TTestBase {
    UNIT_TEST_SUITE(TOffroadSerfTest)
    UNIT_TEST(TestEmpty);
    UNIT_TEST(TestOneRow);
    UNIT_TEST(TestNoFat);
    UNIT_TEST(TestSparseRng);
    UNIT_TEST_SUITE_END();

private:
    static TString PrintSimpleRow(const TRowfatSearcher::TSimpleRow& row) {
        TStringStream ss;
        ss << "(" << row.RowId << "," << row.ByteOffset << ",";
        auto mask = row.FactorMask.to_ullong();
        ui32 nBits = !row.FactorMask.any() ? 1 : MostSignificantBit(mask) + 1;
        if (nBits > 32)
            ss << std::bitset<64>(mask).to_string();
        else if (nBits > 16)
            ss << std::bitset<32>(mask).to_string();
        else if (nBits > 8)
            ss << std::bitset<16>(mask).to_string();
        else
            ss << std::bitset<8>(mask).to_string().substr(8 - nBits);
        ss << "b)";
        return ss.Str();
    }

    struct TPrinter {
        TStringStream Ss_;
        bool operator()(ui32 rowid, ui32 offsetdw, ui64 factormask) {
            if (!Ss_.Empty())
                Ss_ << " ";
            Ss_ << PrintSimpleRow(TRowfatSearcher::TSimpleRow(rowid, offsetdw, factormask));
            return true;
        };

        inline TString Detach() {
            TString s = Ss_.Str();
            Ss_.Clear();
            return s;
        }
    };

    static TString HexEncodeN(const void* p, size_t sz, size_t maxChars) {
        if (sz > maxChars) {
            return HexEncode(p, Max<ui32>(3, maxChars) - 3) + "...";
        } else {
            return HexEncode(p, sz);
        }
    }

    bool CheckReasonableSize(const TBuffer& fat, const TBuffer& subfat, size_t nFactorsTotal, size_t nRows) {
        ui32 reasonableSize = sizeof(ui32) * nRows + sizeof(ui16) * nFactorsTotal;
        reasonableSize = Max(16u, reasonableSize);

        // we expect that fat.size() + subfat.size() is "reasonable"
        size_t totalSize = 0;
        Cout << "SubFat blob actual size: " << subfat.size() << ", content: " << HexEncodeN(subfat.data(), subfat.size(), 300) << Endl;
        totalSize += subfat.size();
        if (fat.size() > 0) {
            Cout << "Fat blob actual size: " << fat.size() << ", content: " << HexEncodeN(fat.data(), fat.size(), 300) << Endl;
            totalSize += subfat.size();
        }

        if (fat.size() > 0)
            Cout << "Actual total size: " << totalSize;
        else
            Cout << "Fat blob is empty.";
        Cout << " (we got nValues=" << nFactorsTotal << ", nRows=" << nRows << ")" << Endl;

        Cout << "Expected (total) size: anything less than " << reasonableSize << Endl;
        return reasonableSize > totalSize;
    }

public:
    void TestEmpty() {
        TBufferStream fat, subfat;

        {
            TRowfatWriter wr(&fat, &subfat);
            wr.Finish();
        }

        {
            TRowfatSearcher rd(TBlob::FromBuffer(fat.Buffer()), TBlob::FromBuffer(subfat.Buffer()));
            UNIT_ASSERT_VALUES_EQUAL(0, rd.Size());
            UNIT_ASSERT_VALUES_EQUAL(true, rd.IsEmpty());

            bool found;
            TRowfatSearcher::TSimpleRow row;
            found = rd.ReadRow(Max<ui32>(), row);
            UNIT_ASSERT(!found);
            found = rd.ReadRow(0, row);
            UNIT_ASSERT(!found);
        }
    }

    void TestOneRow() {
        TBufferStream fat, subfat;
        {
            TRowfatWriter wr(&fat, &subfat);
            wr.WriteKey(TRowfatKey{/*rowId=*/1232, 42});
            wr.WriteHit(std::bitset<64>("1110").to_ullong());
            wr.Finish();
        }

        UNIT_ASSERT(CheckReasonableSize(fat.Buffer(), subfat.Buffer(), 1, 1));
        {
            TRowfatSearcher rd(TBlob::FromBuffer(fat.Buffer()), TBlob::FromBuffer(subfat.Buffer()));
            UNIT_ASSERT_VALUES_EQUAL(1, rd.Size());
            UNIT_ASSERT_VALUES_EQUAL(false, rd.IsEmpty());

            bool found;
            TRowfatSearcher::TSimpleRow row;
            found = rd.ReadRow(Max<ui32>(), row);
            UNIT_ASSERT(!found);
            found = rd.ReadRow(0, row);
            UNIT_ASSERT(!found);

            found = rd.ReadRow(1232, row);
            UNIT_ASSERT(found);
            UNIT_ASSERT_VALUES_EQUAL(1232, row.RowId);
            UNIT_ASSERT_VALUES_EQUAL(42 * sizeof(ui32), row.ByteOffset);
            UNIT_ASSERT_VALUES_EQUAL(0xE, row.FactorMask.to_ullong());

            // read all
            TPrinter printer;
            rd.ReadAll(/*maxRow=*/Max<ui32>(), printer);
            UNIT_ASSERT_VALUES_EQUAL("(1232,168,1110b)", printer.Detach());
            rd.ReadAll(/*maxRow=*/1232, printer);
            UNIT_ASSERT_VALUES_EQUAL("(1232,168,1110b)", printer.Detach());
            rd.ReadAll(/*maxRow=*/1231, printer);
            UNIT_ASSERT_VALUES_EQUAL("", printer.Detach());
        }
    }

    void TestNoFat() {
        // When there is less than 16 rows, we use only the linear search
        TBufferStream fat, subfat;
        {
            TRowfatWriter wr(&fat, &subfat);
            wr.WriteKey(TRowfatKey{/*rowId=*/0, 0});
            wr.WriteHit(std::bitset<64>("111").to_ullong());
            wr.WriteKey(TRowfatKey{/*rowId=*/1, 8});
            wr.WriteHit(std::bitset<64>("111").to_ullong());
            wr.WriteKey(TRowfatKey{/*rowId=*/2, 16});
            wr.WriteHit(std::bitset<64>("111").to_ullong());
            wr.WriteKey(TRowfatKey{/*rowId=*/5, 20});
            wr.WriteHit(std::bitset<64>("101").to_ullong());
            wr.Finish();
        }

        // we expect to have zero records in Fat_
        UNIT_ASSERT_VALUES_EQUAL(0, fat.Buffer().Size());

        // we expect that the size of Subfat_ is "reasonable"
        UNIT_ASSERT(CheckReasonableSize(fat.Buffer(), subfat.Buffer(), 11, 4));

        {
            TRowfatSearcher rd(TBlob::FromBuffer(fat.Buffer()), TBlob::FromBuffer(subfat.Buffer()));
            UNIT_ASSERT_VALUES_EQUAL(false, rd.IsEmpty());
            UNIT_ASSERT_VALUES_EQUAL(4, rd.Size());

            bool found;
            TRowfatSearcher::TSimpleRow row;
            found = rd.ReadRow(Max<ui32>(), row);
            UNIT_ASSERT(!found);
            found = rd.ReadRow(0, row);
            UNIT_ASSERT(found);
            found = rd.ReadRow(3, row);
            UNIT_ASSERT(!found);
            found = rd.ReadRow(5, row);
            UNIT_ASSERT(found);
            UNIT_ASSERT_VALUES_EQUAL(0x5, row.FactorMask.to_ullong());

            found = rd.ReadRow(2, row);
            UNIT_ASSERT(found);
            UNIT_ASSERT_VALUES_EQUAL(2, row.RowId);
            UNIT_ASSERT_VALUES_EQUAL(16 * sizeof(ui32), row.ByteOffset);
            UNIT_ASSERT_VALUES_EQUAL(0x7, row.FactorMask.to_ullong());

            TPrinter printer;

            // read all
            rd.ReadAll(/*maxRow=*/Max<ui32>(), printer);
            UNIT_ASSERT_VALUES_EQUAL("(0,0,111b) (1,32,111b) (2,64,111b) (5,80,101b)", printer.Detach());

            // read first N
            rd.ReadAll(/*maxRow=*/4, printer);
            UNIT_ASSERT_VALUES_EQUAL("(0,0,111b) (1,32,111b) (2,64,111b)", printer.Detach());
            rd.ReadAll(/*maxRow=*/1, printer);
            UNIT_ASSERT_VALUES_EQUAL("(0,0,111b) (1,32,111b)", printer.Detach());
            rd.ReadAll(/*maxRow=*/0, printer);
            UNIT_ASSERT_VALUES_EQUAL("(0,0,111b)", printer.Detach());
        }
    }

    void TestSparseByRowIdRng(ui32 rowMin, ui32 rowMax, float percUsedRowIds, std::random_device::result_type rngSeed) {
        Cout << Endl;
        TVector<ui64> factorMasks;
        TVector<ui32> indexes;
        indexes.resize(rowMax - rowMin + 1);
        std::iota(indexes.begin(), indexes.end(), rowMin);

        std::mt19937 rand(rngSeed);
        std::shuffle(indexes.begin(), indexes.end(), rand);

        ui32 numUsed = Min<ui32>(indexes.size() * percUsedRowIds, indexes.size());

        TVector<ui32> expectedToBePresent;
        TVector<ui32> expectedToBeMissing;
        expectedToBePresent.resize(Min<ui32>(10, numUsed));
        expectedToBeMissing.resize(Min<ui32>(10, indexes.size() - numUsed));

        std::copy(indexes.begin(), indexes.begin() + expectedToBePresent.size(), expectedToBePresent.begin());
        std::copy(indexes.begin() + numUsed, indexes.begin() + numUsed + expectedToBeMissing.size(), expectedToBeMissing.begin());
        std::sort(indexes.begin(), indexes.begin() + numUsed);

        TBufferStream fat, subfat;
        ui32 numFactorsTotal = 0;
        ui32 dataOffsetDw = 0;
        {
            TRowfatWriter wr(&fat, &subfat);
            ui64 factorMask1 = std::bitset<64>("1010101").to_ullong();
            ui64 factorMask2 = std::bitset<64>("1010001").to_ullong();
            for (ui32 i = 0; i < numUsed; ++i) {
                ui64 factorMask;
                ui32 numFactors;
                if (i % 3) {
                    std::tie(factorMask, numFactors) = std::make_pair(factorMask1, 4);
                } else {
                    std::tie(factorMask, numFactors) = std::make_pair(factorMask2, 3);
                }
                const ui32 rowId = indexes[i];
                wr.WriteKey(TRowfatKey{rowId, dataOffsetDw});
                wr.WriteHit(factorMask);
                dataOffsetDw += numFactors;
                numFactorsTotal += numFactors;
            }
            wr.Finish();
        }

        // we expect that the size of Subfat_ is "reasonable"
        UNIT_ASSERT(CheckReasonableSize(fat.Buffer(), subfat.Buffer(), numFactorsTotal, numUsed));

        {
            TRowfatSearcher rd(TBlob::FromBuffer(fat.Buffer()), TBlob::FromBuffer(subfat.Buffer()));
            UNIT_ASSERT_VALUES_EQUAL(false, rd.IsEmpty());
            UNIT_ASSERT_VALUES_EQUAL(numUsed, rd.Size());

            const auto endUsed = indexes.begin() + numUsed;
            for (const ui32 rowId : expectedToBePresent) {
                Y_ASSERT(std::find(indexes.begin(), endUsed, rowId) != endUsed);

                bool found;
                TRowfatSearcher::TSimpleRow row;
                found = rd.ReadRow(rowId, row);
                UNIT_ASSERT_C(found, "RowId missing: " << rowId);
                UNIT_ASSERT(row.RowId == rowId);
            }
            for (const ui32 rowId : expectedToBeMissing) {
                Y_ASSERT(std::find(indexes.begin(), endUsed, rowId) == endUsed);

                bool found;
                TRowfatSearcher::TSimpleRow row;
                found = rd.ReadRow(rowId, row);
                UNIT_ASSERT_C(!found, "RowId was expected to be missing, but is found: " << rowId);
            }
        }
    }

    void TestSparseRng() {
        std::random_device::result_type rngSeed = std::random_device()();
        Cout << "TestSparseRng, rngSeed=" << Hex(rngSeed) << Endl;
        TestSparseByRowIdRng(0, 9, 0.7, rngSeed);
        TestSparseByRowIdRng(0, 200, 0.3, rngSeed);
        TestSparseByRowIdRng(10, 34, 0.2, rngSeed);
        TestSparseByRowIdRng(1, 50000, 0.03, rngSeed);
        TestSparseByRowIdRng(0, 70, 1.0, rngSeed);
    }
};

// Register the TOffroadSerfTest suite with the unittest framework.
UNIT_TEST_SUITE_REGISTRATION(TOffroadSerfTest);
