#include "string_compressor.h"

#include <library/cpp/testing/unittest/registar.h>

#include <util/random/random.h>
#include <util/system/sem.h>

using namespace NTravel;

// Returns a string of `len` characters, each independently '0' or '1', drawn via util's RandomNumber.
// The tests call SetRandomSeed(1) before generating strings.
//
// `static`: this helper is local to this test file; internal linkage avoids link-time clashes with a
// same-named helper in another test source compiled into the same unittest binary.
static TString GenerateString(size_t len) {
    TStringBuilder builder;
    builder.reserve(len); // final size is known up front
    for (size_t i = 0; i < len; i++) {
        builder << (RandomNumber<ui32>(2) ? '1' : '0');
    }
    return builder;
}

// Single-threaded compress/decompress scenario against a fresh TStringCompressor.
//
// Phase 1 compresses min(trainSamplesCount, strCnt) random strings of length `strLen`. Each one must
// not be smaller than `strLen` (not compressed yet) and must round-trip through Decompress.
// Phase 2 compresses the remaining strings (indices trainSamplesCount..strCnt-1, if any). When
// `expectCompression && syncCompression`, phase 2 starts only after the train callback fired. When
// `expectCompression && !syncCompression`, the callback is awaited only after phase 2.
// Finally, every string is re-checked. Its allocation must be smaller than `strLen` if
// `expectCompression`, otherwise not smaller. It must always decompress to the original.
//
// `static`: file-local helper; internal linkage avoids clashing with another test source's `DoTest`
// in the same unittest binary.
static void DoTest(size_t trainSamplesCount, size_t strLen, size_t strCnt, bool expectCompression, bool syncCompression) {
    SetRandomSeed(1);

    // Declared before `compressor` so it is destroyed after it. The OnTrainFinish callback stored inside
    // the compressor holds a reference to this semaphore. Objects are destroyed in reverse declaration
    // order, so the old order left the callback pointing at a dead semaphore while the compressor
    // was still alive.
    TFastSemaphore semaphore(0);
    TStringCompressor compressor("test", true, 1, 0.1, trainSamplesCount, {}, {}, {}, TDuration::MilliSeconds(1));
    compressor.Start();
    compressor.OnTrainFinish([&semaphore](bool successful) {
        UNIT_ASSERT_VALUES_EQUAL(successful, true);
        semaphore.Release();
    });

    TVector<TString> strings;
    TVector<TCompressedString> compressedStrings;
    strings.reserve(strCnt);
    compressedStrings.reserve(strCnt);

    // Phase 1: strings added before training; they must still be stored uncompressed.
    const size_t firstPhaseCount = Min(trainSamplesCount, strCnt);
    for (size_t i = 0; i < firstPhaseCount; i++) {
        strings.push_back(GenerateString(strLen));
        compressedStrings.push_back(compressor.Compress(strings.back()));
        UNIT_ASSERT_GE(compressor.GetStringAllocSizeOnlyForTests(compressedStrings.back()), strLen);
        UNIT_ASSERT_VALUES_EQUAL(compressor.Decompress(compressedStrings.back()), strings.back());
    }

    // Synchronous mode: wait for training to report completion before adding more strings.
    if (expectCompression && syncCompression) {
        semaphore.Acquire();
    }

    // Phase 2: remaining strings (the loop is empty when strCnt <= trainSamplesCount).
    for (size_t i = trainSamplesCount; i < strCnt; i++) {
        strings.push_back(GenerateString(strLen));
        compressedStrings.push_back(compressor.Compress(strings.back()));
        UNIT_ASSERT_VALUES_EQUAL(compressor.Decompress(compressedStrings.back()), strings.back());
    }

    // Asynchronous mode: phase 2 did not wait, so wait for the train callback now.
    if (expectCompression && !syncCompression) {
        semaphore.Acquire();
    }

    for (size_t i = 0; i < strCnt; i++) {
        if (expectCompression) {
            UNIT_ASSERT_LT(compressor.GetStringAllocSizeOnlyForTests(compressedStrings.at(i)), strLen);
        } else {
            UNIT_ASSERT_GE(compressor.GetStringAllocSizeOnlyForTests(compressedStrings.at(i)), strLen);
        }
        UNIT_ASSERT_VALUES_EQUAL(compressor.Decompress(compressedStrings.at(i)), strings.at(i));
    }
}

Y_UNIT_TEST_SUITE(TestStringCompressor) {
    Y_UNIT_TEST(CheckSimple) {
        // 3000 strings with a training threshold of 1000. The last 2000 are added without waiting for
        // the train callback, and afterwards all 3000 are expected to be compressed.
        const size_t trainSamples = 1000;
        const size_t length = 1000;
        const size_t total = 3000;
        DoTest(trainSamples, length, total, /* expectCompression = */ true, /* syncCompression = */ false);
    }

    Y_UNIT_TEST(CheckNoCompression) {
        // Only 3000 strings against a training threshold of 5000, so no string is expected to end up
        // compressed.
        const size_t trainSamples = 5000;
        const size_t length = 1000;
        const size_t total = 3000;
        DoTest(trainSamples, length, total, /* expectCompression = */ false, /* syncCompression = */ false);
    }

    Y_UNIT_TEST(CheckAddAfterTrain) {
        // Like CheckSimple, but the remaining 2000 strings are added only after the train callback
        // fired, so they go through the already-trained compressor.
        const size_t trainSamples = 1000;
        const size_t length = 1000;
        const size_t total = 3000;
        DoTest(trainSamples, length, total, /* expectCompression = */ true, /* syncCompression = */ true);
    }

    Y_UNIT_TEST(CheckMultithread) {
        // Each of `threadCount` threads owns its own pair of vectors (strings / compressed handles) and
        // shares one compressor. Phase 1 fills the training set concurrently. Phase 2, after the train
        // callback, adds more strings and checks that all of a thread's strings are compressed and
        // round-trip.
        size_t trainSamplesCount = 1000;
        size_t strLen = 1000;
        size_t strCnt = 3000;
        size_t threadCount = 8;

        SetRandomSeed(1);

        // Declared before `compressor` so it is destroyed after it. The OnTrainFinish callback stored
        // inside the compressor holds a reference to this semaphore.
        TFastSemaphore semaphore(0);
        // The training threshold equals the total number of phase-1 strings across all threads.
        TStringCompressor compressor("test", true, 1, 0.1, trainSamplesCount * threadCount, {}, {}, {}, TDuration::MilliSeconds(1));
        compressor.Start();
        compressor.OnTrainFinish([&semaphore](bool successful) {
            UNIT_ASSERT_VALUES_EQUAL(successful, true);
            semaphore.Release();
        });
        TVector<TVector<TString>> strings(threadCount);
        TVector<TVector<TCompressedString>> compressedStrings(threadCount);
        TVector<THolder<IThreadFactory::IThread>> threads;

        // NOTE(review): UNIT_ASSERT failures inside these worker threads are raised off the main test
        // thread. Confirm the unittest framework reports them rather than terminating the process.

        // Phase 1: concurrent inserts before training; strings must still be stored uncompressed.
        for (size_t th = 0; th < threadCount; th++) {
            threads.push_back(SystemThreadFactory()->Run([&compressor, &strings = strings.at(th), &compressedStrings = compressedStrings.at(th), trainSamplesCount, strLen, strCnt]() {
                for (size_t i = 0; i < Min(trainSamplesCount, strCnt); i++) {
                    strings.push_back(GenerateString(strLen));
                    compressedStrings.push_back(compressor.Compress(strings.back()));
                    UNIT_ASSERT_GE(compressor.GetStringAllocSizeOnlyForTests(compressedStrings.back()), strLen);
                    UNIT_ASSERT_VALUES_EQUAL(compressor.Decompress(compressedStrings.back()), strings.back());
                }
            }));
        }

        semaphore.Acquire();

        for (auto& th : threads) {
            th->Join();
        }
        threads.clear();

        // Phase 2: concurrent inserts after training, then each thread verifies all of its own strings.
        for (size_t th = 0; th < threadCount; th++) {
            threads.push_back(SystemThreadFactory()->Run([&compressor, &strings = strings.at(th), &compressedStrings = compressedStrings.at(th), trainSamplesCount, strLen, strCnt]() {
                for (size_t i = trainSamplesCount; i < strCnt; i++) {
                    strings.push_back(GenerateString(strLen));
                    compressedStrings.push_back(compressor.Compress(strings.back()));
                    UNIT_ASSERT_VALUES_EQUAL(compressor.Decompress(compressedStrings.back()), strings.back());
                }

                for (size_t i = 0; i < strCnt; i++) {
                    UNIT_ASSERT_LT(compressor.GetStringAllocSizeOnlyForTests(compressedStrings.at(i)), strLen);
                    UNIT_ASSERT_VALUES_EQUAL(compressor.Decompress(compressedStrings.at(i)), strings.at(i));
                }
            }));
        }

        for (auto& th : threads) {
            th->Join();
        }
    }

    Y_UNIT_TEST(CheckDeletions) {
        // Strings whose handles are all destroyed must no longer be counted once training has finished.
        // Only the strings that are still held must remain, compressed and round-tripping.
        size_t trainSamplesCount = 1000;
        size_t firstPartSize = trainSamplesCount - 1;
        size_t secondPartSize = 10;
        size_t strLen = 1000;

        SetRandomSeed(1);

        // Declared before `compressor` so it is destroyed after it. The OnTrainFinish callback stored
        // inside the compressor holds a reference to this semaphore.
        TFastSemaphore semaphore(0);
        TStringCompressor compressor("test", true, 1, 0.1, trainSamplesCount, {}, {}, {}, TDuration::MilliSeconds(1));
        compressor.Start();
        compressor.OnTrainFinish([&semaphore](bool successful) {
            UNIT_ASSERT_VALUES_EQUAL(successful, true);
            semaphore.Release();
        });

        // First part: one string short of the training threshold. The handles are released when this
        // scope ends, so these strings are deleted from the compressor.
        {
            TVector<TCompressedString> droppedStrings;
            for (size_t i = 0; i < firstPartSize; i++) {
                droppedStrings.push_back(compressor.Compress(GenerateString(strLen)));
            }
        }

        // Second part: kept alive until the end of the test.
        TVector<TString> strings;
        TVector<TCompressedString> compressedStrings;
        for (size_t i = 0; i < secondPartSize; i++) {
            strings.push_back(GenerateString(strLen));
            compressedStrings.push_back(compressor.Compress(strings.back()));
        }
        semaphore.Acquire();

        UNIT_ASSERT_VALUES_EQUAL(compressor.GetStringCount(), secondPartSize);

        for (size_t i = 0; i < secondPartSize; i++) {
            UNIT_ASSERT_LT(compressor.GetStringAllocSizeOnlyForTests(compressedStrings.at(i)), strLen);
            UNIT_ASSERT_VALUES_EQUAL(compressor.Decompress(compressedStrings.at(i)), strings.at(i));
        }
    }

    Y_UNIT_TEST(CheckSchedule) {
        // After the initial training, every batch of `rebuildSamplesCount` new strings must produce
        // another train callback. All strings ever added must stay counted, compressed and round-tripping.
        size_t trainSamplesCount = 1000;
        size_t rebuildSamplesCount = 1000;
        size_t strLen = 1000;
        size_t epochs = 5;
        const size_t totalCount = trainSamplesCount + epochs * rebuildSamplesCount;

        SetRandomSeed(1);

        // Declared before `compressor` so it is destroyed after it. The OnTrainFinish callback stored
        // inside the compressor holds a reference to this semaphore.
        TFastSemaphore semaphore(0);
        TStringCompressor compressor("test", true, 1, 2 /* for sure 100% sampling*/, trainSamplesCount, {}, rebuildSamplesCount, {}, TDuration::MilliSeconds(1));
        compressor.Start();
        compressor.OnTrainFinish([&semaphore](bool successful) {
            UNIT_ASSERT_VALUES_EQUAL(successful, true);
            semaphore.Release();
        });

        TVector<TString> strings;
        TVector<TCompressedString> compressedStrings;
        for (size_t i = 0; i < trainSamplesCount; i++) {
            strings.push_back(GenerateString(strLen));
            compressedStrings.push_back(compressor.Compress(strings.back()));
            UNIT_ASSERT_VALUES_EQUAL(compressor.Decompress(compressedStrings.back()), strings.back());
        }

        semaphore.Acquire();
        UNIT_ASSERT_VALUES_EQUAL(compressor.GetStringCount(), trainSamplesCount);

        // Each epoch adds one rebuild batch and waits for the corresponding train callback.
        for (size_t epoch = 0; epoch < epochs; epoch++) {
            for (size_t i = 0; i < rebuildSamplesCount; i++) {
                strings.push_back(GenerateString(strLen));
                compressedStrings.push_back(compressor.Compress(strings.back()));
                UNIT_ASSERT_VALUES_EQUAL(compressor.Decompress(compressedStrings.back()), strings.back());
            }
            semaphore.Acquire();
        }

        UNIT_ASSERT_VALUES_EQUAL(compressor.GetStringCount(), totalCount);

        for (size_t i = 0; i < totalCount; i++) {
            UNIT_ASSERT_LT(compressor.GetStringAllocSizeOnlyForTests(compressedStrings.at(i)), strLen);
            UNIT_ASSERT_VALUES_EQUAL(compressor.Decompress(compressedStrings.at(i)), strings.at(i));
        }
    }

    Y_UNIT_TEST(CheckDuplicatesAfterTrain) {
        // After training, compressing the same value twice must add only one string and yield handles
        // that share the same inner data pointer.
        size_t trainSamplesCount = 1000;
        size_t rebuildSamplesCount = 1000;
        size_t strLen = 1000;

        SetRandomSeed(1);

        // Declared before `compressor` so it is destroyed after it. The OnTrainFinish callback stored
        // inside the compressor holds a reference to this semaphore.
        TFastSemaphore semaphore(0);
        TStringCompressor compressor("test", true, 1, 2 /* for sure 100% sampling*/, trainSamplesCount, {}, rebuildSamplesCount, {}, TDuration::MilliSeconds(1));
        compressor.Start();
        compressor.OnTrainFinish([&semaphore](bool successful) {
            UNIT_ASSERT_VALUES_EQUAL(successful, true);
            semaphore.Release();
        });

        TVector<TCompressedString> compressedStrings;
        for (size_t i = 0; i < trainSamplesCount; i++) {
            compressedStrings.push_back(compressor.Compress(GenerateString(strLen)));
        }

        semaphore.Acquire();

        auto first = compressor.Compress("abc");
        auto second = compressor.Compress("abc");

        UNIT_ASSERT_VALUES_EQUAL(compressor.GetStringCount(), trainSamplesCount + 1);

        UNIT_ASSERT_VALUES_EQUAL(static_cast<const void*>(first.DangerouslyGetInnerDataOnlyForTests()),
                                 static_cast<const void*>(second.DangerouslyGetInnerDataOnlyForTests()));
    }

    Y_UNIT_TEST(CheckDuplicatesSimple) {
        // Duplicate values must be stored once and share inner data, both before training and after a
        // (deliberately unsuccessful) training attempt on a tiny sample set.
        size_t trainSamplesCount = 2;
        size_t rebuildSamplesCount = 1000;

        SetRandomSeed(1);

        // Declared before `compressor` so it is destroyed after it. The OnTrainFinish callback stored
        // inside the compressor holds a reference to this semaphore.
        TFastSemaphore semaphore(0);
        TStringCompressor compressor("test", true, 1, 2 /* for sure 100% sampling*/, trainSamplesCount, {}, rebuildSamplesCount, {}, TDuration::MilliSeconds(1));
        compressor.Start();
        compressor.OnTrainFinish([&semaphore](bool successful) {
            UNIT_ASSERT_VALUES_EQUAL(successful, false);
            semaphore.Release();
        });

        auto tmp1 = compressor.Compress("abc");
        auto tmp2 = compressor.Compress("abc");

        UNIT_ASSERT_VALUES_EQUAL(compressor.GetStringCount(), 1);

        // Both handles of the duplicated value point to the same inner data. Deletion of duplicates is
        // covered by CheckDuplicatesDeletion.
        UNIT_ASSERT_VALUES_EQUAL(static_cast<const void*>(tmp1.DangerouslyGetInnerDataOnlyForTests()),
                                 static_cast<const void*>(tmp2.DangerouslyGetInnerDataOnlyForTests()));

        auto tmp3 = compressor.Compress("test");

        semaphore.Acquire();

        auto first = compressor.Compress("xyz");
        auto second = compressor.Compress("xyz");

        UNIT_ASSERT_VALUES_EQUAL(compressor.GetStringCount(), 3);

        UNIT_ASSERT_VALUES_EQUAL(static_cast<const void*>(first.DangerouslyGetInnerDataOnlyForTests()),
                                 static_cast<const void*>(second.DangerouslyGetInnerDataOnlyForTests()));
    }

    Y_UNIT_TEST(CheckDuplicatesDeletion) {
        // Destroying one handle of a duplicated value must keep the value alive for the remaining
        // handle.
        size_t trainSamplesCount = 2;
        size_t rebuildSamplesCount = 1000;

        SetRandomSeed(1);

        // Declared before `compressor` so it is destroyed after it. The OnTrainFinish callback stored
        // inside the compressor holds a reference to this semaphore.
        TFastSemaphore semaphore(0);
        TStringCompressor compressor("test", true, 1, 2 /* for sure 100% sampling*/, trainSamplesCount, {}, rebuildSamplesCount, {}, TDuration::MilliSeconds(1));
        compressor.Start();
        compressor.OnTrainFinish([&semaphore](bool successful) {
            UNIT_ASSERT_VALUES_EQUAL(successful, false);
            semaphore.Release();
        });

        TCompressedString second;
        {
            // `first` is destroyed at the end of this scope; `second` keeps "abc" referenced.
            auto first = compressor.Compress("abc");
            second = compressor.Compress("abc");
        }

        auto tmp = compressor.Compress("test");
        semaphore.Acquire();

        UNIT_ASSERT_VALUES_EQUAL(compressor.GetStringCount(), 2);

        UNIT_ASSERT_VALUES_EQUAL(compressor.Decompress(second), "abc");
    }

    Y_UNIT_TEST(CheckDisabled) {
        // The compressor is built with its second constructor argument set to false, which presumably
        // disables compression. Every stored string must stay at least as large as its raw length and
        // must decompress unchanged, both right after insertion and after all insertions are done.
        const size_t count = 1000;
        const size_t length = 1000;

        SetRandomSeed(1);

        TStringCompressor compressor("test", false, 1, 0.1, 10, {}, {}, {}, TDuration::MilliSeconds(1));
        compressor.Start();

        TVector<TString> originals;
        TVector<TCompressedString> handles;
        while (originals.size() < count) {
            originals.push_back(GenerateString(length));
            const TString& justAdded = originals.back();
            handles.push_back(compressor.Compress(justAdded));
            UNIT_ASSERT_GE(compressor.GetStringAllocSizeOnlyForTests(handles.back()), length);
            UNIT_ASSERT_VALUES_EQUAL(compressor.Decompress(handles.back()), justAdded);
        }

        // Second pass over everything stored.
        for (size_t idx = 0; idx < count; ++idx) {
            UNIT_ASSERT_GE(compressor.GetStringAllocSizeOnlyForTests(handles.at(idx)), length);
            UNIT_ASSERT_VALUES_EQUAL(compressor.Decompress(handles.at(idx)), originals.at(idx));
        }
    }

    Y_UNIT_TEST(CheckEquality) {
        // Handles of equal values compare equal and hash equally. Handles of different values compare
        // unequal and, for these inputs, hash differently.
        size_t trainSamplesCount = 2;
        size_t rebuildSamplesCount = 1000;

        SetRandomSeed(1);

        // Declared before `compressor` so it is destroyed after it. The OnTrainFinish callback stored
        // inside the compressor holds a reference to this semaphore.
        TFastSemaphore semaphore(0);
        TStringCompressor compressor("test", true, 1, 2 /* for sure 100% sampling*/, trainSamplesCount, {}, rebuildSamplesCount, {}, TDuration::MilliSeconds(1));
        compressor.Start();
        compressor.OnTrainFinish([&semaphore](bool successful) {
            UNIT_ASSERT_VALUES_EQUAL(successful, false);
            semaphore.Release();
        });

        auto first = compressor.Compress("abc");
        auto second = compressor.Compress("abc");
        auto other = compressor.Compress("test");

        semaphore.Acquire();

        UNIT_ASSERT_VALUES_EQUAL(compressor.GetStringCount(), 2);

        UNIT_ASSERT_VALUES_EQUAL(first == second, true);  // strange construction to be sure that == is used
        UNIT_ASSERT_VALUES_EQUAL(first != second, false); // strange construction to be sure that != is used
        UNIT_ASSERT_VALUES_EQUAL(first == other, false);
        UNIT_ASSERT_VALUES_EQUAL(first != other, true);
        UNIT_ASSERT_VALUES_EQUAL(THash<TCompressedString>()(first), THash<TCompressedString>()(second));
        UNIT_ASSERT_VALUES_UNEQUAL(THash<TCompressedString>()(first), THash<TCompressedString>()(other));
    }
}
