#include "factors_config.h"
#include "rank_model.h"

#include <kernel/relevfml/rank_models_factory.h>

#include <library/cpp/testing/unittest/registar.h>
#include <library/cpp/testing/common/env.h>
#include <util/stream/str.h>
#include <util/system/tempfile.h>
#include <library/cpp/json/json_reader.h>
#include <library/cpp/json/json_writer.h>
#include <util/string/printf.h>
#include <util/stream/file.h>


// Serializes `config` as JSON into a newly created temporary file and returns the handle owning it.
// The output stream is finished explicitly so that a failed flush is reported here, instead of
// depending on the stream's destructor to push the buffered data out.
THolder<TTempFileHandle> WriteConfigFile(const NJson::TJsonValue &config) {
    auto tempFile = MakeHolder<TTempFileHandle>();
    TFixedBufferFileOutput buf(*tempFile);
    NJson::WriteJson(&buf, &config);
    buf.Finish();
    return tempFile;
}

// Resolves `resource` relative to the test working directory returned by GetWorkPath().
TFsPath GetTestResource(const TString& resource) {
    const TFsPath workDir(GetWorkPath());
    return workDir / resource;
}

// Escapes '\\' and '"' so that `raw` can be embedded between the quotes of a JSON string literal.
// Every other character is copied unchanged; intended for file paths, which are assumed not to
// contain control characters.
static TString EscapeForJsonString(TStringBuf raw) {
    TString escaped;
    escaped.reserve(raw.size());
    for (const char c : raw) {
        if (c == '\\' || c == '"') {
            escaped.push_back('\\');
        }
        escaped.push_back(c);
    }
    return escaped;
}

// Returns the JSON text of the factors config used by these tests:
//  - static factors stat1..stat4 (indices 1..4) and dynamic factor MatrixNet (index 0);
//  - nine formulas, each pairing one of four matrixnet files from the test work dir
//    (%1$s default, %2$s fast, %3$s filter, %4$s custom) with an encoded polynom.
// The matrixnet paths are JSON-escaped before substitution, so paths with backslashes
// (e.g. Windows work dirs) or quotes still produce valid JSON that round-trips to the same path.
TString BuildConfig() {
    static const TString default_mxnet = EscapeForJsonString(GetTestResource("mxnets/mxnets/default_mxnet").GetPath());
    static const TString fast_mxnet = EscapeForJsonString(GetTestResource("mxnets/mxnets/fast_mxnet").GetPath());
    static const TString filter_mxnet = EscapeForJsonString(GetTestResource("mxnets/mxnets/filter_mxnet").GetPath());
    static const TString custom_mxnet = EscapeForJsonString(GetTestResource("mxnets/mxnets/custom_mxnet").GetPath());

    /*
     * 10040020000000U7 <-> 1 * [0] * [1]
     * 10060020000000U7 <-> 1 * [0] * [2]
     * 10080020000000U7 <-> 1 * [0] * [3]
     * 100A0020000000U7 <-> 1 * [0] * [4]
     */

    static const auto format_string = R"({
            "static_factors": {
                "stat1": 1,
                "stat2": 2,
                "stat3": 3,
                "stat4": 4
            },
            "dynamic_factors": {
                "MatrixNet": 0
            },
            "formulas": {
                "default": {
                    "matrixnet": "%1$s",
                    "polynom": "10060020000000U7"
                },
                "default_with_custom_poly": {
                    "matrixnet": "%1$s",
                    "polynom": "10040020000000U7"
                },
                "default_with_custom_mxnet": {
                    "matrixnet": "%4$s",
                    "polynom": "10060020000000U7"
                },
                "fast_rank": {
                    "matrixnet": "%2$s",
                    "polynom": "10080020000000U7"
                },
                "fast_with_custom_poly": {
                    "matrixnet": "%2$s",
                    "polynom": "10040020000000U7"
                },
                "fast_with_custom_mxnet": {
                    "matrixnet": "%4$s",
                    "polynom": "10080020000000U7"
                },
                "filter_rank": {
                    "matrixnet": "%3$s",
                    "polynom": "100A0020000000U7"
                },
                "filter_with_custom_poly": {
                    "matrixnet": "%3$s",
                    "polynom": "10040020000000U7"
                },
                "filter_with_custom_mxnet": {
                    "matrixnet": "%4$s",
                    "polynom": "100A0020000000U7"
                }
            }
        })";

    return Sprintf(format_string, default_mxnet.c_str(), fast_mxnet.c_str(), filter_mxnet.c_str(), custom_mxnet.c_str());
}

// Parses the config text from BuildConfig() into a JSON tree; parsing is asked to throw on malformed input.
NJson::TJsonValue CreateFactorsConfigTemplate() {
    NJson::TJsonValue parsed;
    const TString body = BuildConfig();
    constexpr bool throwOnError = true;
    NJson::ReadJsonTree(body, &parsed, throwOnError);
    return parsed;
}

// Loads the template config, asks `GetModels` for a (reference, customised) model pair, and checks that
// the customised model carries `expected_polynom` and shares its matrixnet (by GetId()) with the reference.
//   expected_formula - formula name used for the reference model (first argument forwarded to GetModels).
//   expected_polynom - encoded polynom the customised model is expected to end up with.
//   custom_formula, custom_polynom - customisation request forwarded to GetModels.
//   GetModels - callable (config, formula1, formula2, polynom) returning a pair of model holder pointers.
template <typename Func>
void CmpConfigAndCustomFormula(TString expected_formula, TStringBuf expected_polynom, TStringBuf custom_formula, TStringBuf custom_polynom, Func GetModels) {
    auto temp_file = WriteConfigFile(CreateFactorsConfigTemplate());
    NRTYFactors::TConfig config(temp_file->Name().c_str());

    auto&& [basic_model, custom_model] = GetModels(config, expected_formula, custom_formula, custom_polynom);
    // Report a missing model as a test failure rather than crashing on the dereferences below.
    UNIT_ASSERT_C(basic_model, "no model returned for formula '" << expected_formula << "'");
    UNIT_ASSERT_C(custom_model, "no customised model returned for formula '" << custom_formula
                                    << "' and polynom '" << custom_polynom << "'");

    // VALUES_EQUAL includes both polynoms in the failure message.
    TString result_polynom = Encode(*custom_model->GetRankModel()->Polynom()->Descr);
    UNIT_ASSERT_VALUES_EQUAL(expected_polynom, result_polynom);

    auto MxnetId1 = basic_model->GetRankModel()->Matrixnet()->GetId();
    auto MxnetId2 = custom_model->GetRankModel()->Matrixnet()->GetId();

    UNIT_ASSERT_EQUAL(MxnetId1, MxnetId2);
}

using TModels = std::pair<const NRTYFactors::TRankModelHolder*, const NRTYFactors::TRankModelHolder*>;

// Reference: GetRankModel(formula1). Customised: GetRankModel(formula2) with an empty second argument
// and `polynom` as the third. The reference model is requested first.
TModels GetRankModels(NRTYFactors::TConfig& config, TStringBuf formula1, TStringBuf formula2, TStringBuf polynom) {
    auto reference = config.GetRankModel(formula1);
    auto customised = config.GetRankModel(formula2, TStringBuf(), polynom);
    return TModels(reference, customised);
}

// Reference: GetRankModel(formula1). Customised: GetFastRankModel(formula2, polynom).
// The reference model is requested first.
TModels GetFastRankModels(NRTYFactors::TConfig& config, TStringBuf formula1, TStringBuf formula2, TStringBuf polynom) {
    auto reference = config.GetRankModel(formula1);
    auto customised = config.GetFastRankModel(formula2, polynom);
    return TModels(reference, customised);
}

// Reference: GetRankModel(formula1). Customised: GetFilterModel(formula2, polynom).
// The reference model is requested first.
TModels GetFilterModels(NRTYFactors::TConfig& config, TStringBuf formula1, TStringBuf formula2, TStringBuf polynom) {
    auto reference = config.GetRankModel(formula1);
    auto customised = config.GetFilterModel(formula2, polynom);
    return TModels(reference, customised);
}


// Common expansion of the CHECK_* macros: runs CmpConfigAndCustomFormula with the given model getter.
// The expansion already ends with ';', which is why call sites write the macros without one.
#define CHECK_CUSTOMISED_MODEL(models_getter, expected_formula, expected_polynom, custom_formula, custom_polynom) \
    CmpConfigAndCustomFormula((expected_formula), (expected_polynom), (custom_formula), (custom_polynom), models_getter);

// Customised model obtained through GetRankModels.
#define CHECK_RANK_MODEL(expected_formula, expected_polynom, custom_formula, custom_polynom) \
    CHECK_CUSTOMISED_MODEL(GetRankModels, expected_formula, expected_polynom, custom_formula, custom_polynom)

// Customised model obtained through GetFastRankModels.
#define CHECK_FAST_RANK(expected_formula, expected_polynom, custom_formula, custom_polynom) \
    CHECK_CUSTOMISED_MODEL(GetFastRankModels, expected_formula, expected_polynom, custom_formula, custom_polynom)

// Customised model obtained through GetFilterModels.
#define CHECK_FILTER_RANK(expected_formula, expected_polynom, custom_formula, custom_polynom) \
    CHECK_CUSTOMISED_MODEL(GetFilterModels, expected_formula, expected_polynom, custom_formula, custom_polynom)


// Customisation of the regular rank model (TConfig::GetRankModel) against the template config.
Y_UNIT_TEST_SUITE(TestRankModelCustomisation) {
    // Encoded polynoms that appear in the template config built by BuildConfig.
    constexpr char DefaultPolynom[] = "10060020000000U7";
    constexpr char CustomPolynom[] = "10040020000000U7";

    // The template config must load without throwing.
    Y_UNIT_TEST(Correctness) {
        bool no_errors = true;
        auto temp_file = WriteConfigFile(CreateFactorsConfigTemplate());
        try {
            NRTYFactors::TConfig config(temp_file->Name().c_str());
        } catch (const std::exception& e) {
            // Catch the std::exception base (by const reference) so that exception types not derived
            // from std::runtime_error are also printed here before the assertion fails.
            Cerr << "ERROR: " << e.what() << Endl;
            no_errors = false;
        }
        UNIT_ASSERT(no_errors);
    }

    Y_UNIT_TEST(EmptyFormulaEmptyPoly) {
        CHECK_RANK_MODEL("default", DefaultPolynom, "", "")
    }

    Y_UNIT_TEST(ConfigFormulaEmptyPoly) {
        CHECK_RANK_MODEL("default_with_custom_poly", CustomPolynom, "default_with_custom_poly", "")
    }

    Y_UNIT_TEST(NonconfigFormulaEmptyPoly) {
        CHECK_RANK_MODEL("default", DefaultPolynom, "wrong_formula", "")
    }

    Y_UNIT_TEST(EmptyFormulaCustomPoly) {
        CHECK_RANK_MODEL("default_with_custom_poly", CustomPolynom, "", CustomPolynom)
    }

    Y_UNIT_TEST(ConfigFormulaCustomPoly) {
        CHECK_RANK_MODEL("default", DefaultPolynom, "default_with_custom_poly", DefaultPolynom)
    }

    Y_UNIT_TEST(NonconfigFormulaCustomPoly) {
        CHECK_RANK_MODEL("default_with_custom_poly", CustomPolynom, "wrong_formula", CustomPolynom)
    }

    Y_UNIT_TEST(IncorrectPolynom) {
        CHECK_RANK_MODEL("default", DefaultPolynom, "default_with_custom_poly", "wrong_polynom")
    }

    Y_UNIT_TEST(NonConfigCustomModel) {
        CHECK_RANK_MODEL("default_with_custom_mxnet", DefaultPolynom, "default_with_custom_mxnet", DefaultPolynom)
    }
}

// Customisation of the fast rank model (TConfig::GetFastRankModel) against the template config.
Y_UNIT_TEST_SUITE(TestFastRankModelCustomisation) {
    // Encoded polynoms that appear in the template config built by BuildConfig.
    constexpr char DefaultPolynom[] = "10060020000000U7";
    constexpr char CustomPolynom[] = "10040020000000U7";
    constexpr char FastPolynom[] = "10080020000000U7";

    Y_UNIT_TEST(EmptyFormulaEmptyPoly) {
        CHECK_FAST_RANK("fast_rank", FastPolynom, "", "")
    }

    Y_UNIT_TEST(ConfigFormulaEmptyPoly) {
        CHECK_FAST_RANK("default_with_custom_poly", CustomPolynom, "default_with_custom_poly", "")
    }

    Y_UNIT_TEST(NonconfigFormulaEmptyPoly) {
        CHECK_FAST_RANK("fast_rank", FastPolynom, "wrong_formula", "")
    }

    Y_UNIT_TEST(EmptyFormulaCustomPoly) {
        CHECK_FAST_RANK("fast_with_custom_poly", CustomPolynom, "", CustomPolynom)
    }

    Y_UNIT_TEST(ConfigFormulaCustomPoly) {
        CHECK_FAST_RANK("fast_rank", FastPolynom, "fast_with_custom_poly", FastPolynom)
    }

    Y_UNIT_TEST(NonconfigFormulaCustomPoly) {
        CHECK_FAST_RANK("fast_with_custom_poly", CustomPolynom, "wrong_formula", CustomPolynom)
    }

    Y_UNIT_TEST(IncorrectPolynom) {
        CHECK_FAST_RANK("fast_rank", FastPolynom, "fast_with_custom_poly", "wrong_polynom")
    }

    Y_UNIT_TEST(NonConfigCustomModel) {
        CHECK_FAST_RANK("fast_with_custom_mxnet", DefaultPolynom, "fast_with_custom_mxnet", DefaultPolynom)
    }
}

// Customisation of the filter model (TConfig::GetFilterModel) against the template config.
Y_UNIT_TEST_SUITE(TestFilterModelCustomisation) {
    // Encoded polynoms that appear in the template config built by BuildConfig.
    constexpr char DefaultPolynom[] = "10060020000000U7";
    constexpr char CustomPolynom[] = "10040020000000U7";
    constexpr char FilterPolynom[] = "100A0020000000U7";

    Y_UNIT_TEST(EmptyFormulaEmptyPoly) {
        CHECK_FILTER_RANK("filter_rank", FilterPolynom, "", "")
    }

    Y_UNIT_TEST(ConfigFormulaEmptyPoly) {
        CHECK_FILTER_RANK("default_with_custom_poly", CustomPolynom, "default_with_custom_poly", "")
    }

    Y_UNIT_TEST(NonconfigFormulaEmptyPoly) {
        CHECK_FILTER_RANK("filter_rank", FilterPolynom, "wrong_formula", "")
    }

    Y_UNIT_TEST(EmptyFormulaCustomPoly) {
        CHECK_FILTER_RANK("filter_with_custom_poly", CustomPolynom, "", CustomPolynom)
    }

    Y_UNIT_TEST(ConfigFormulaCustomPoly) {
        CHECK_FILTER_RANK("filter_rank", FilterPolynom, "filter_with_custom_poly", FilterPolynom)
    }

    Y_UNIT_TEST(NonconfigFormulaCustomPoly) {
        CHECK_FILTER_RANK("filter_with_custom_poly", CustomPolynom, "wrong_formula", CustomPolynom)
    }

    Y_UNIT_TEST(IncorrectPolynom) {
        CHECK_FILTER_RANK("filter_rank", FilterPolynom, "filter_with_custom_poly", "wrong_polynom")
    }

    Y_UNIT_TEST(NonConfigCustomModel) {
        CHECK_FILTER_RANK("filter_with_custom_mxnet", DefaultPolynom, "filter_with_custom_mxnet", DefaultPolynom)
    }
}

