#include <library/cpp/testing/unittest/registar.h>

#include <rt-research/broadmatching/scripts/cpp-source/bender/Index.h>

#include <irt/common/lib/logger/logger.h>

#include <library/cpp/json/json_reader.h>
#include <library/cpp/testing/unittest/env.h>

#include <util/folder/path.h>
#include <util/generic/yexception.h>
#include <util/stream/file.h>
#include <util/string/cast.h>
#include <util/string/split.h>
#include <util/system/tempfile.h>

using namespace NJson;


void MakeIndex(
    const TJsonValue& banners,
    TString encodedFile,
    TString indexFile
) {
    TMap<TString, TSet<bid_size>> wordsIndex;

    auto encodedOut = TFileOutput(encodedFile);
    auto indexOut = TFileOutput(indexFile);

    bid_size bannerNo = 1;
    for (const TJsonValue& pair : banners.GetArray()) {
        DataItem bannerId = pair[0].GetInteger();
        TString bannerText = pair[1].GetString();

        encodedOut << bannerId << Endl;

        for (const auto& it : StringSplitter(bannerText).Split(' ').SkipEmpty()) {
            TString word = ToString(it.Token());
            wordsIndex[word].insert(bannerNo);
        }

        bannerNo++;
    }

    // надеемся, что TString внути TMap отсортирован как мы ожидаем: в алфавитом порядке русского языка
    for (const auto& pair : wordsIndex) {
        const auto& word = pair.first;
        const auto& bannersForWord = pair.second;
        indexOut << word;
        bool first = true;
        for (auto bannerId : bannersForWord) {
            indexOut << (first ? '\t' : ' ') << bannerId;
            first = false;
        }
        indexOut << Endl;
    }
}

// Converts the text-format index written by MakeIndex(banners, ...) into the
// binary format. It loads encodedFile + indexFile with Index::LoadText, then
// writes the loaded index to binaryFile with Index::SaveBinary.
// Progress is printed to Cout. Throws yexception if either call reports failure.
void MakeBinaryIndex(
    TString encodedFile,
    TString indexFile,
    TString binaryFile
) {
    Index converter;

    Cout << "Loading from " << encodedFile << ", " << indexFile << Endl;
    const auto loaded = converter.LoadText(encodedFile.c_str(), indexFile.c_str());
    if (!loaded) {
        ythrow yexception() << "can't load data";
    }

    Cout << "saving to " << binaryFile << Endl;
    const auto saved = converter.SaveBinary(binaryFile.c_str());
    if (!saved) {
        ythrow yexception() << "can't save binary data to " << binaryFile;
    }
}

// Loads an Index from binaryFile (for example, one produced by MakeBinaryIndex).
// Throws yexception if Index::LoadBinary reports failure.
Index LoadFromBinary(TString binaryFile) {
    Index result;
    const auto loaded = result.LoadBinary(binaryFile.c_str());
    if (!loaded) {
        ythrow yexception() << "can't load binary data from " << binaryFile;
    }
    return result;
}

// Builds an Index for a JSON banner set (an array of [bannerId, bannerText]
// pairs) by running the full on-disk pipeline:
// text index -> binary index -> load.
// The three intermediate paths are held by TTempFile objects that live only
// for the duration of this call (TTempFile is intended to remove the file when
// it is destroyed).
Index MakeIndex(const TJsonValue& bannerSet) {
    TTempFile encoded("encoded_file");
    TTempFile textIndex("index_file");
    TTempFile binary("binary_file");

    MakeIndex(bannerSet, encoded.Name(), textIndex.Name());
    MakeBinaryIndex(encoded.Name(), textIndex.Name(), binary.Name());
    return LoadFromBinary(binary.Name());
}

// Parses the JSON fixture `fname` from this test's source directory
// (rt-research/broadmatching/scripts/cpp-source/bender/ut/ under the source root).
// The trailing `true` passed to ReadJsonTree is presumably throwOnError, so a
// malformed fixture should abort the test with an exception.
TJsonValue ReadJsonTestData(const TString& fname) {
    const auto path = JoinFsPaths(ArcadiaSourceRoot(), "rt-research/broadmatching/scripts/cpp-source/bender/ut/" + fname);
    TFileInput input(path);

    TJsonValue result;
    ReadJsonTree(&input, &result, true);
    return result;
}

// Returns the distinct elements of `data` as an ordered set.
template <typename T>
TSet<T> ToSet(const std::vector<T>& data) {
    return TSet<T>(data.begin(), data.end());
}

template <typename T>
bool IsUniqueSet (const std::vector<T>& data) {
    return data.size() == ToSet(data).size();
}


// Every test reads a JSON fixture that is an array of suites. Each suite has:
//   "banners" - an array of [bannerId, bannerText] pairs, indexed via MakeIndex;
//   "queries" - an array of test cases whose keys depend on the test below.
Y_UNIT_TEST_SUITE(Index) {
    // Collects a JSON array of banner ids into a set.
    static TSet<DataItem> ReadIdSet(const TJsonValue& ids) {
        TSet<DataItem> result;
        for (const TJsonValue& it : ids.GetArray()) {
            result.insert(static_cast<DataItem>(it.GetInteger()));
        }
        return result;
    }

    // Renders ids as "{a, b, c}" for assertion messages.
    static TString FormatIds(const TSet<DataItem>& ids) {
        TString text = "{";
        for (const DataItem& id : ids) {
            if (text.size() > 1) {
                text += ", ";
            }
            text += ToString(id);
        }
        text += "}";
        return text;
    }

    // Each case has "query" and "bids". Index::Search must return exactly the
    // ids listed in "bids", with no duplicates.
    Y_UNIT_TEST(Search) {
        NIRT::InitLog<NIRT::TTabLoggerFormatter>({.ConsoleOnly=true});
        TJsonValue json = ReadJsonTestData("search.json");

        for (const TJsonValue& suite : json.GetArray()) {
            auto index = MakeIndex(suite["banners"]);
            for (const TJsonValue& test : suite["queries"].GetArray()) {
                const TString& query = test["query"].GetString();
                const TSet<DataItem> expectedResultSet = ReadIdSet(test["bids"]);

                std::vector<DataItem> actualResult;
                index.Search(query.c_str(), 0, actualResult);
                const TSet<DataItem> actualResultSet = ToSet(actualResult);

                // List both sets: sizes alone cannot explain a mismatch between
                // two sets of equal size.
                const TString comment = "For query '" + query + "' got " + FormatIds(actualResultSet)
                    + ", expected " + FormatIds(expectedResultSet);
                UNIT_ASSERT_C(expectedResultSet == actualResultSet, comment);
                UNIT_ASSERT_C(IsUniqueSet(actualResult), "no duplicates in results");
            }
        }
    }

    // Each case has "sample_size", "result_size" and "expect_all".
    // Index::RandomBanners(sample_size) must return exactly result_size
    // distinct ids. When "expect_all" is true, those ids must be every banner
    // id of the suite.
    Y_UNIT_TEST(RandomBanners) {
        NIRT::InitLog<NIRT::TTabLoggerFormatter>({.ConsoleOnly=true});
        TJsonValue json = ReadJsonTestData("random_banners.json");

        for (const TJsonValue& suite : json.GetArray()) {
            TSet<DataItem> allBanners;
            for (const TJsonValue& it : suite["banners"].GetArray()) {
                allBanners.insert(static_cast<DataItem>(it.GetArray()[0].GetInteger()));
            }

            auto index = MakeIndex(suite["banners"]);
            for (const TJsonValue& test : suite["queries"].GetArray()) {
                const unsigned sampleSize = test["sample_size"].GetInteger();
                const size_t expectedSize = static_cast<size_t>(test["result_size"].GetInteger());

                std::vector<DataItem> result;
                index.RandomBanners(sampleSize, result);
                // VALUES_EQUAL prints both sizes on failure.
                UNIT_ASSERT_VALUES_EQUAL_C(result.size(), expectedSize, "correct result size");
                UNIT_ASSERT_C(IsUniqueSet(result), "no duplicates in results");

                if (test["expect_all"].GetBoolean()) {
                    UNIT_ASSERT_C(ToSet(result) == allBanners, "all banners present in results");
                }
            }
        }
    }

    // Each case has "query", "sample_size", "result_size" and, when defined,
    // "bids". Index::RandomBannersQuery must return exactly result_size
    // distinct ids. When "bids" is defined, the returned ids must equal it.
    Y_UNIT_TEST(RandomBannersQuery) {
        NIRT::InitLog<NIRT::TTabLoggerFormatter>({.ConsoleOnly=true});
        TJsonValue json = ReadJsonTestData("random_banners_query.json");

        for (const TJsonValue& suite : json.GetArray()) {
            auto index = MakeIndex(suite["banners"]);
            for (const TJsonValue& test : suite["queries"].GetArray()) {
                const TString& query = test["query"].GetString();
                const unsigned sampleSize = test["sample_size"].GetInteger();
                const size_t expectedSize = static_cast<size_t>(test["result_size"].GetInteger());

                std::vector<DataItem> result;
                index.RandomBannersQuery(query.c_str(), sampleSize, result);

                UNIT_ASSERT_VALUES_EQUAL_C(result.size(), expectedSize, "correct result size");
                UNIT_ASSERT_C(IsUniqueSet(result), "no duplicates in results");

                if (test["bids"].IsDefined()) {
                    UNIT_ASSERT_C(ToSet(result) == ReadIdSet(test["bids"]), "result as expected");
                }
            }
        }
    }

}
