#include "hash_to_docid.h"

#include <library/cpp/digest/md5/md5.h>
#include <library/cpp/testing/unittest/gtest.h>
#include <library/cpp/testing/unittest/registar.h>

#include <util/stream/file.h>
#include <util/system/tempfile.h>

#include <random>

using namespace NSaas;

namespace {
    // Path of the index file used by the tests in this file (see DoCreateTest and DoRebuildTest).
    const TString TestIndexPath = "./test_index";

    struct TTestDocHashIterator : IDocHashIterator {
        TTestDocHashIterator(const TVector<TDocHash>& hashes, const TVector<ui32>& ids, size_t size)
            : Hashes(hashes)
            , Ids(ids)
            , Size(size)
        {
            Y_ENSURE(hashes.size() >= size);
            Y_ENSURE(ids.size() >= size);
        }

        bool IsValid() const override {
            return Current < Size;
        }
        void Next() override {
            ++Current;
        }
        std::pair<TDocHash, ui32> GetHashWithId() const override {
            return std::make_pair(Hashes[Current], Ids[Current]);
        }

    private:
        const TVector<TDocHash>& Hashes;
        const TVector<ui32>& Ids;
        size_t Current = 0;
        size_t Size = 0;
    };

    // Feeds the sizeof(T) bytes located at the address of `t` into an MD5
    // hasher and returns the finalized digest, written through TDocHash::Ptr().
    template<typename T>
    TDocHash GetMD5(const T& t) {
        TDocHash digest;
        MD5 md5;
        md5.Update(&t, sizeof(t));
        md5.Final(digest.Ptr());
        return digest;
    }

    // Test-only IDocHashStorage filled with reproducible pseudo-random data:
    // `maxDocId` pairwise-distinct hashes and a shuffled permutation of the ids
    // 0 .. maxDocId-1. Only the first `sampleSize` (hash, id) pairs are exposed
    // through the storage interface; the remaining hashes are available via
    // GetRemovedHashes().
    struct TTestDataGenerator : IDocHashStorage {
        // seed       - seed of the pseudo-random generator
        // sampleSize - number of leading (hash, id) pairs exposed by the storage
        // maxDocId   - total number of generated hashes and ids
        // A Y_ENSURE check fails if sampleSize > maxDocId or if two generated
        // hashes turn out to be equal.
        TTestDataGenerator(ui64 seed, ui32 sampleSize, ui32 maxDocId)
            : SampleSize(sampleSize)
        {
            Y_ENSURE(sampleSize <= maxDocId);

            std::mt19937_64 rng(seed);

            // Ids: a random permutation of [0, maxDocId). The initial value is
            // typed as ui32 so the counter has the same type as the elements.
            Ids.resize(maxDocId);
            std::iota(Ids.begin(), Ids.end(), ui32{0});
            std::shuffle(Ids.begin(), Ids.end(), rng);

            // Hashes: MD5 of consecutive generator outputs (drawn after the shuffle).
            Hashes.resize(maxDocId);
            for (ui32 i = 0; i < maxDocId; ++i) {
                Hashes[i] = GetMD5(rng());
            }

            // Verify that all hashes are distinct: in a sorted copy any
            // duplicates would be adjacent.
            TVector<TDocHash> sortedHashes(Hashes);
            std::sort(sortedHashes.begin(), sortedHashes.end());
            for (ui32 i = 1; i < maxDocId; ++i) {
                Y_ENSURE(sortedHashes[i - 1] != sortedHashes[i]);
            }
        }

        // Number of (hash, id) pairs exposed by the storage.
        ui64 GetEstimatedSize() const override {
            return SampleSize;
        }

        // Creates an iterator over the first SampleSize (hash, id) pairs.
        // The iterator refers to this object's vectors and must not outlive it.
        THolder<IDocHashIterator> CreateDocHashIterator() const override {
            return MakeHolder<TTestDocHashIterator>(Hashes, Ids, SampleSize);
        }

        // Returns the [begin, end) iterator pair over the generated hashes that
        // lie beyond the exposed sample, i.e. those never reported by the iterator.
        auto GetRemovedHashes() const {
            return std::make_pair(Hashes.begin() + SampleSize, Hashes.end());
        }

    private:
        TVector<TDocHash> Hashes;
        TVector<ui32> Ids;
        ui64 SampleSize = 0;
    };

    // Checks a THashToDocId index located at `indexPath` against `testData`.
    // If the index cannot be opened as is, it is rebuilt from `testData` and
    // opened again. Afterwards the index must report the same estimated size
    // as the data, resolve every sampled hash to its id, and answer
    // THashToDocId::NotFound for every hash outside the sample.
    void DoTest(const TTestDataGenerator& testData, TStringBuf indexPath) {
        THashToDocId hashToId;
        if (!hashToId.Open(indexPath)) {
            // Missing or unusable file: build the index and retry.
            THashToDocId::Rebuild(testData, indexPath, Max<ui32>());
            EXPECT_TRUE(hashToId.Open(indexPath));
        }

        EXPECT_EQ(testData.GetEstimatedSize(), hashToId.GetEstimatedSize());

        // Hashes that belong to the sample must resolve to their ids.
        auto sampled = testData.CreateDocHashIterator();
        while (sampled->IsValid()) {
            auto docInfo = sampled->GetHashWithId();
            EXPECT_EQ(hashToId.GetIdByHash(docInfo.first), docInfo.second);
            sampled->Next();
        }

        // Hashes left out of the sample must not be found.
        auto leftOut = testData.GetRemovedHashes();
        auto it = leftOut.first;
        while (it != leftOut.second) {
            EXPECT_EQ(hashToId.GetIdByHash(*it), THashToDocId::NotFound);
            ++it;
        }
    }

    // Generates test data from (seed, sampleSize, maxDocId) and runs DoTest on
    // the index at TestIndexPath. The TTempFile guard is expected to delete
    // that file when the function exits.
    void DoCreateTest(ui64 seed, ui32 sampleSize, ui32 maxDocId) {
        TTempFile indexFileGuard(TestIndexPath);
        TTestDataGenerator generated(seed, sampleSize, maxDocId);
        DoTest(generated, TestIndexPath);
    }

    // Smallest case: one generated hash, all of it in the sample.
    TEST(THashToDocId, CreateNew1) {
        const ui64 seed = 0;
        const ui32 sampleSize = 1;
        const ui32 maxDocId = 1;
        DoCreateTest(seed, sampleSize, maxDocId);
    }

    // Small case: 10 sampled hashes out of 12 generated.
    TEST(THashToDocId, CreateNew2) {
        const ui64 seed = 1;
        const ui32 sampleSize = 10;
        const ui32 maxDocId = 12;
        DoCreateTest(seed, sampleSize, maxDocId);
    }

    // Larger case: 10000 sampled hashes out of 11000 generated.
    TEST(THashToDocId, CreateNew3) {
        const ui64 seed = 2;
        const ui32 sampleSize = 10000;
        const ui32 maxDocId = 11000;
        DoCreateTest(seed, sampleSize, maxDocId);
    }

    // Larger case where half of the generated hashes stay outside the sample:
    // 11000 sampled out of 22000 generated.
    TEST(THashToDocId, CreateNew4) {
        const ui64 seed = 3;
        const ui32 sampleSize = 11000;
        const ui32 maxDocId = 22000;
        DoCreateTest(seed, sampleSize, maxDocId);
    }

    // Pre-populates the file at TestIndexPath with whatever `initCallback`
    // writes, then runs DoTest on it with a fixed data set (seed 4, 40 sampled
    // hashes out of 42).
    //
    // initCallback - callable invoked once with an output stream over the file
    //                at TestIndexPath (the call passes a TOFStream lvalue).
    template<typename TCallback>
    void DoRebuildTest(TCallback&& initCallback) {
        // The guard is created before the file is written, so the file is
        // still cleaned up if the stream or initCallback throws. This also
        // matches the order used in DoCreateTest.
        TTempFile tempFile(TestIndexPath);
        {
            // Inner scope: the stream is destroyed before the index is opened.
            TOFStream ofs(TestIndexPath);
            initCallback(ofs);
        }
        TTestDataGenerator data(4, 40, 42);
        DoTest(data, TestIndexPath);
    }

    // Existing file holds a single character: the version cannot be read.
    TEST(THashToDocId, Rebuild1) {
        const auto writeContent = [](IOutputStream& out) {
            out << "_";
        };
        DoRebuildTest(writeContent);
    }

    // Existing file starts with bytes that form a wrong version.
    TEST(THashToDocId, Rebuild2) {
        const auto writeContent = [](IOutputStream& out) {
            out << "12345";
        };
        DoRebuildTest(writeContent);
    }

    // Existing file has a version field (4, presumably the accepted one -
    // confirm against THashToDocId) followed by a single character, so the
    // number of hashes cannot be read.
    TEST(THashToDocId, Rebuild3) {
        const auto writeContent = [](IOutputStream& out) {
            const ui32 version = 4;
            out.Write(&version, sizeof(version));
            out << "_";
        };
        DoRebuildTest(writeContent);
    }

    // Existing file has a version field (4, presumably the accepted one -
    // confirm against THashToDocId) followed by arbitrary bytes, so the file
    // size does not match.
    TEST(THashToDocId, Rebuild4) {
        const auto writeContent = [](IOutputStream& out) {
            const ui32 version = 4;
            out.Write(&version, sizeof(version));
            out << "Some garbage";
        };
        DoRebuildTest(writeContent);
    }
}
