#include "cb_calcer.h"

#include <saas/rtyserver/factors/factors_config.h>

#include <kernel/relevfml/rank_models_factory.h>

#include <library/cpp/testing/unittest/registar.h>
#include <library/cpp/testing/unittest/tests_data.h>

#include <library/cpp/json/writer/json.h>

#include <util/stream/file.h>
#include <util/string/join.h>

using namespace NRTYServer;
using namespace NRTYFactors;

class TRtyCatboostCalcerTest: public NUnitTest::TTestBase {
    UNIT_TEST_SUITE(TRtyCatboostCalcerTest)
    UNIT_TEST(TestCatFactor);
    UNIT_TEST(TestRankModelHolder);
    UNIT_TEST_SUITE_END();

private:
    static TFsPath GetTestResource(const TString& resource) {
        return TFsPath(GetWorkPath()) / resource;
    }

public:
    using TTestSet = TVector<std::pair<TVector<float>, bool>>;

    static TTestSet GetTestSet() {
        // see also: training set in ut/train.tsv
        return TTestSet{
            {{0.1, 1}, true},
            {{0.1, 2}, false},
            {{0.1, 3}, true},
            {{0.1, 4}, false},
            {{0.5, 1}, false},
            {{0.5, 2}, false},
            {{0.5, 3}, false},
            {{0.5, 4}, false},
            {{0.9, 1}, false},
            {{0.9, 2}, true},
            {{0.9, 3}, false},
            {{0.9, 4}, true},
        };
    }

    TVector<float*> GetCalcerInputRefs(TTestSet& testSet) {
        TVector<float*> inputs;
        std::transform(testSet.begin(), testSet.end(), std::back_inserter(inputs), [](auto& v) { return v.first.data(); });
        return inputs;
    }

    void CompareOutputs(const TTestSet& testSet, const TVector<float>& outputs) {
        UNIT_ASSERT(testSet.size() == outputs.size());
        for (size_t i = 0; i < testSet.size(); ++i) {
            bool expectedYes = testSet[i].second;
            if (expectedYes) {
                UNIT_ASSERT_GT_C(outputs[i], 0.6, "i = " << i);
            } else {
                UNIT_ASSERT_LT_C(outputs[i], 0.4, "i = " << i);
            }
        }
    }

    void TestCatFactor() {
        TCatboostRelev c;

        TFsPath testModel = GetTestResource("model/rty_test.cbm");

        c.Init(testModel);

        TSet<ui32> usedFactors;
        UNIT_ASSERT(c.GetUsedFactors(usedFactors));
        UNIT_ASSERT_VALUES_EQUAL("0,1", JoinSeq(",", usedFactors));

        TSet<ui32> catFactors;
        c.GetUsedCatFactors(catFactors);
        UNIT_ASSERT_VALUES_EQUAL("1", JoinSeq(",", catFactors));

        TTestSet testSet = GetTestSet();
        TVector<float*> inputs = GetCalcerInputRefs(testSet);
        TVector<float> outputs(inputs.size());
        c.CalcRelevance(inputs.data(), outputs.data(), inputs.size());
        CompareOutputs(testSet, outputs);
    }

    static void WriteFactorsConfig(const TString& filename) {
        using namespace NJson;
        TJsonValue relevConf(EJsonValueType::JSON_MAP);

        TJsonValue erfFactors(EJsonValueType::JSON_MAP);
        erfFactors.InsertValue("f_val", TJsonValue(0));
        erfFactors.InsertValue("f_categ", TJsonValue(1));
        relevConf.InsertValue("static_factors", erfFactors);

        TJsonValue formulas(EJsonValueType::JSON_MAP);

        TJsonValue formula1(EJsonValueType::JSON_MAP);
        formula1.InsertValue("user_ranking", "catboost");
        formula1.InsertValue("model", "model/rty_test.cbm");

        formulas.InsertValue("default", formula1);

        NJsonWriter::TBuf json;
        json
            .BeginObject()
            .WriteKey("static_factors")
            .WriteJsonValue(&erfFactors)
            .WriteKey("formulas")
            .WriteJsonValue(&formulas)
            .EndObject();

        TFileOutput o(filename);
        o.SetFinishPropagateMode(true);
        o << json.Str() << Endl;
    }

    void TestRankModelHolder() {
        WriteFactorsConfig("relev.conf");
        TConfig relevConf("relev.conf");
        const TRankModelHolder* model = relevConf.GetRankModel("default");

        UNIT_ASSERT(model != nullptr);

        TSet<ui32> usedFactors = model->GetUsedFactors();
        UNIT_ASSERT_VALUES_EQUAL("0,1", JoinSeq(",", usedFactors));

        const TRankModel* traditionalModel = model->GetRankModel();
        UNIT_ASSERT(traditionalModel != nullptr);
        UNIT_ASSERT_VALUES_EQUAL(false, traditionalModel->HasMatrixnet());
        UNIT_ASSERT_VALUES_EQUAL(false, traditionalModel->HasPolynom());

        TAtomicSharedPtr<IUserRanking> userRanking = model->GetUserRanking();
        UNIT_ASSERT(!!userRanking);
        UNIT_ASSERT_VALUES_EQUAL(true, userRanking->HasRelevance());
        UNIT_ASSERT_VALUES_EQUAL(false, userRanking->HasFactors());
        TSet<ui32> usedFactors2;
        UNIT_ASSERT(userRanking->GetUsedFactors(usedFactors2));
        UNIT_ASSERT_VALUES_EQUAL("0,1", JoinSeq(",", usedFactors2));

        TCatboostRelev* impl = dynamic_cast<TCatboostRelev*>(userRanking.Get());
        UNIT_ASSERT(!!impl);
        TSet<ui32> catFactors;
        impl->GetUsedCatFactors(catFactors);
        UNIT_ASSERT_VALUES_EQUAL("1", JoinSeq(",", catFactors));

        // now call TRankModelHolder::MultiCalc and check the results
        TTestSet testSet = GetTestSet();
        TVector<float*> inputs = GetCalcerInputRefs(testSet);
        TVector<float> outputs(inputs.size());
        model->MultiCalc(inputs.data(), outputs.data(), outputs.size());
        CompareOutputs(testSet, outputs);
    }
};

// Registers TRtyCatboostCalcerTest with the unittest framework.
UNIT_TEST_SUITE_REGISTRATION(TRtyCatboostCalcerTest);
