#include "dssm_factors_ut.h"
#include "dssm_factors.h"
#include "dssm_model_ops.h"

#include <saas/rtyserver/search/features/protos/dssm_models.pb.h>
#include <kernel/factor_storage/factor_storage.h>
#include <library/cpp/testing/unittest/registar.h>
#include <util/folder/dirut.h>
#include <util/stream/str.h>
#include <util/string/join.h>
#include <util/string/vector.h>
#include <util/system/tempfile.h>

using namespace NRTYFeatures;

Y_UNIT_TEST_SUITE(TRtyDssmFeature) {
    // Checks TInputTraits::DebugString() and the way TModelBinding::Combine()
    // accumulates factor bindings that refer to the same model.
    //
    // UNIT_ASSERT_VALUES_EQUAL is used throughout (as in the other tests of this
    // suite) so that a failure reports both the expected and the actual value.
    Y_UNIT_TEST(TestDssmBundle) {
        // TInputTraits combining
        using TInputTraits = TDssmBundle::TInputTraits;
        using TFactorBinding = TDssmBundle::TFactorBinding;
        using TModelBinding = TDssmBundle::TModelBinding;
        {
            // A default-constructed TInputTraits has every flag cleared.
            TInputTraits s;
            UNIT_ASSERT_VALUES_EQUAL("0 0 0 0", s.DebugString());
            // NOTE(review): the flags assigned below are never asserted on; the
            // check that follows inspects the QueryUrlTitle() preset instead.
            // Confirm whether `s` was meant to be compared with that preset.
            s.UsesRegion = true;
            s.UsesQuery = true;
            s.UsesTitle = true;
            UNIT_ASSERT_VALUES_EQUAL("1 1 0 0", TDssmBundleDebug::QueryUrlTitle().DebugString());
        }
        {
            // Two different outputs of "model1" plus a duplicate of the first one.
            TFactorBinding f;
            f.ModelName = "model1";
            f.VarName = "outputA";

            TFactorBinding f2;
            f2.ModelName = "model1";
            f2.VarName = "outputB";

            TFactorBinding f3 = f;

            // A fresh binding carries no model name and no outputs, only the inputs set here.
            TModelBinding m;
            m.Inputs.UsesRegion = true;
            m.Inputs.UsesQuery = true;
            UNIT_ASSERT_VALUES_EQUAL("", m.ModelName);
            UNIT_ASSERT_VALUES_EQUAL("", JoinSeq(" ", m.VarNames));
            UNIT_ASSERT_VALUES_EQUAL("1 0 0 1", m.Inputs.DebugString());

            // The first Combine() sets the model name and registers the output.
            m.Combine(f);
            UNIT_ASSERT_VALUES_EQUAL("model1", m.ModelName);
            UNIT_ASSERT_VALUES_EQUAL("outputA", JoinSeq(" ", m.VarNames));
            UNIT_ASSERT_VALUES_EQUAL("1 0 0 1", m.Inputs.DebugString());

            // A second output of the same model is appended.
            m.Combine(f2);
            UNIT_ASSERT_VALUES_EQUAL("model1", m.ModelName);
            UNIT_ASSERT_VALUES_EQUAL("outputA outputB", JoinSeq(" ", m.VarNames));
            UNIT_ASSERT_VALUES_EQUAL("1 0 0 1", m.Inputs.DebugString());

            // Combining an already known output changes nothing.
            m.Combine(f3);
            UNIT_ASSERT_VALUES_EQUAL("model1", m.ModelName);
            UNIT_ASSERT_VALUES_EQUAL("outputA outputB", JoinSeq(" ", m.VarNames));
            UNIT_ASSERT_VALUES_EQUAL("1 0 0 1", m.Inputs.DebugString());

            // Local ids follow the registration order; unknown names (the lookup
            // is case-sensitive) are reported with an exception.
            UNIT_ASSERT_VALUES_EQUAL(0, m.FindOutputLocalId("outputA"));
            UNIT_ASSERT_VALUES_EQUAL(1, m.FindOutputLocalId("outputB"));
            UNIT_ASSERT_EXCEPTION(m.FindOutputLocalId(""), yexception);
            UNIT_ASSERT_EXCEPTION(m.FindOutputLocalId("outputa"), yexception);
        }
    }

    // Populates `p` with two models ("query_word_title" and "reformulations").
    // Each model gets one factors block listing its input fields, one query
    // embedding and a single factor.
    static void FillProtoConfigForTests(TDssmModelsConfig& p) {
        // Registers a model file under the given alias.
        const auto registerModel = [&p](const TString& alias, const TString& fileName) {
            TDssmModelFile& entry = *p.AddDssmModel();
            entry.SetAlias(alias);
            entry.SetFileName(fileName);
        };

        // Appends a factors block bound to `modelName` and returns it, so that
        // factors can be attached to it afterwards.
        const auto appendBlock = [&p](const TString& modelName, const TVector<TString>& inputs, const TVector<TString>& queryEmbs) -> TDssmFactorsBlock* {
            TDssmFactorsBlock* block = p.AddDssmFactors();
            block->SetModelAlias(modelName);
            for (const TString& input : inputs) {
                block->AddInputField(input);
            }
            for (const TString& embedding : queryEmbs) {
                block->AddQueryEmbedding(embedding);
            }
            return block;
        };

        // Declares a factor `name` that reads the model output `outputVar`.
        const auto declareFactor = [](TDssmFactorsBlock* block, const TString& name, const TString& outputVar) {
            TDssmFactorsBlock::TFactorDescription& description = *block->AddFactor();
            description.SetFactorName(name);
            description.SetOutputVar(outputVar);
        };

        // Model 1: query_word_title
        registerModel("query_word_title", "joint_afterdot_query_word_title_embedding.dssm");
        TDssmFactorsBlock* const titleBlock = appendBlock("query_word_title", {"doc_title", "doc_url", "query"}, {"query_embedding_query_word_title_embedding"});
        declareFactor(titleBlock, "DssmQueryTitle", "joint_afterdot_query_word_title_embedding");

        // Model 2: reformulations
        registerModel("reformulations", "reformulations.dssm");
        TDssmFactorsBlock* const reformulationsBlock = appendBlock("reformulations", {"doc_title", "doc_url", "query"}, {"query_embedding_reformulations_longest_click_log_dt"});
        declareFactor(reformulationsBlock, "ReformulationsLongestClickLogDt", "joint_output_reformulations_longest_click_log_dt");
    }

    // Writes a one-factor bundle description to "dssm_bundle.pb.txt" in the
    // system temp directory and checks that TDssmFeature::InitStaticInfo()
    // reports exactly that factor through the callback.
    Y_UNIT_TEST(TestDssmFactorsExtractopn) {
        const TStringBuf DssmBundleBody = R"(
            CheckFilesExistence: true
            DssmFactors {
                ModelAlias: "reformulations"
                InputField: "doc_title"
                InputField: "doc_url"
                InputField: "query"
                QueryEmbedding: "query_embedding_reformulations_longest_click_log_dt"
                Factor {
                    FactorName: "ReformulationsLongestClickLogDt"
                    OutputVar: "joint_output_reformulations_longest_click_log_dt"
                }
            }
        )";

        // The bundle file lives directly in the directory handed to InitStaticInfo().
        TString modelsPath = GetSystemTempDir();
        TTempFileHandle bundleFile{modelsPath + "/dssm_bundle.pb.txt"};
        bundleFile.Write(DssmBundleBody.data(), DssmBundleBody.size());

        // Collect the names of all factors announced by the feature.
        TVector<TString> announcedNames;
        auto onFactor = [&announcedNames](const NRTYFeatures::TDynFactorInfo& info) {
            announcedNames.push_back(info.Name);
        };

        TDssmFeature feature;
        feature.InitStaticInfo(modelsPath, onFactor);

        UNIT_ASSERT_VALUES_EQUAL(1, announcedNames.size());
        UNIT_ASSERT_VALUES_EQUAL("ReformulationsLongestClickLogDt", announcedNames[0]);
    }

    // Loads the two-model test config into a TDssmBundle and checks factor and
    // model lookups, in both the regular and the fastrank mode of GetFactor().
    Y_UNIT_TEST(TestDssmModelsConfig) {
        TDssmModelsConfig config;
        FillProtoConfigForTests(config);
        // The model files named in the config are not created by this test.
        config.SetCheckFilesExistence(false);

        TDssmBundle::TPtr bundle = MakeIntrusive<TDssmBundle>();
        bundle->SetBundle(config, TFsPath());

        // An unknown factor is an error in the regular mode.
        UNIT_ASSERT_EXCEPTION(bundle->GetFactor("a", false), yexception);

        // First model: the factor binding and the model file name.
        TDssmBundle::TFactorBinding titleFactor = bundle->GetFactor("DssmQueryTitle", false);
        UNIT_ASSERT_VALUES_EQUAL("query_word_title", titleFactor.ModelName);
        UNIT_ASSERT_VALUES_EQUAL("joint_afterdot_query_word_title_embedding", titleFactor.VarName);
        UNIT_ASSERT_VALUES_EQUAL("joint_afterdot_query_word_title_embedding.dssm", bundle->GetModelFileName(titleFactor.ModelName));
        UNIT_ASSERT_VALUES_EQUAL(false, titleFactor.ShouldFillCanonicalValue());

        {
            // GetModelMeta() exposes the input traits derived from the config.
            const TDssmBundle::TModelBinding& titleModel = bundle->GetModelMeta(titleFactor.ModelName);
            UNIT_ASSERT_VALUES_EQUAL("1 1 0 0", titleModel.Inputs.DebugString());
        }
        {
            // Second model.
            TDssmBundle::TFactorBinding reformulationsFactor = bundle->GetFactor("ReformulationsLongestClickLogDt", false);
            UNIT_ASSERT_VALUES_EQUAL("reformulations", reformulationsFactor.ModelName);
            UNIT_ASSERT_VALUES_EQUAL("joint_output_reformulations_longest_click_log_dt", reformulationsFactor.VarName);
            UNIT_ASSERT_VALUES_EQUAL(false, reformulationsFactor.ShouldFillCanonicalValue());
        }
        {
            // Fastrank mode: a known factor is to be filled with the canonical value.
            TDssmBundle::TFactorBinding fastKnown = bundle->GetFactor("DssmQueryTitle", true);
            UNIT_ASSERT_VALUES_EQUAL(true, fastKnown.ShouldFillCanonicalValue());
        }
        {
            // Fastrank mode: an unknown factor does not throw and is handled the same way.
            TDssmBundle::TFactorBinding fastUnknown = bundle->GetFactor("a", true);
            UNIT_ASSERT_VALUES_EQUAL(true, fastUnknown.ShouldFillCanonicalValue());
        }
    }

    // Writes a model description as "Model{alias,modelName,[outputs],'inputs'}".
    static void Dump(TStringStream& o, const TDssmFeatureCalcerPlan::TModelDescription& md) {
        const TDssmBundle::TModelBinding& binding = md.SumBinding;
        o << "Model{" << md.ModelAlias << "," << binding.ModelName;
        o << ",[" << JoinSeq(",", binding.VarNames) << "]";
        o << ",'" << binding.Inputs.DebugString() << "'}";
    }

    // Writes a factor description as "Factor{localId,name,modelAlias,outputLocalId,varName}".
    // A factor without a model is written as "Factor{localId,name,(null)}".
    static void Dump(TStringStream& o, const TDssmFeatureCalcerPlan::TFactorDescription& fd) {
        o << "Factor{" << fd.FactorLocalId << "," << fd.FactorName << ",";
        if (fd.ModelDescr == nullptr) {
            o << "(null)";
        } else {
            o << fd.ModelDescr->ModelAlias;
            o << "," << fd.OutputLocalId << "," << fd.VarName;
        }
        o << "}";
    }

    // Renders a plan as "Models[<models>] Factors[<factors>]" with items separated
    // by single spaces. While dumping, every factor is sanity-checked: a factor
    // bound to a model must point at one of the plan's own model descriptions,
    // and a factor without a model must have an empty VarName and the
    // Max<ui32>() output id. A violated check throws via Y_ENSURE.
    static TString DumpPlan(const TDssmFeatureCalcerPlan& plan) {
        using TModelDescription = TDssmFeatureCalcerPlan::TModelDescription;
        using TFactorDescription = TDssmFeatureCalcerPlan::TFactorDescription;

        TStringStream result;

        result << "Models[";
        bool modelSeparator = false;
        for (const auto& item : plan.GetModels()) {
            if (modelSeparator) {
                result << " ";
            }
            modelSeparator = true;

            const TModelDescription& md = *item;
            Dump(result, md);
        }

        result << "] Factors[";
        bool factorSeparator = false;
        for (const auto& item : plan.GetFactors()) {
            if (factorSeparator) {
                result << " ";
            }
            factorSeparator = true;

            const TFactorDescription& fd = *item;
            if (fd.ModelDescr == nullptr) {
                // a filling
                Y_ENSURE(fd.VarName.empty());
                Y_ENSURE(fd.OutputLocalId == Max<ui32>());
            } else {
                // the referenced model must be owned by this very plan
                const TModelDescription* const wanted = fd.ModelDescr;
                const auto mdIter = std::find_if(plan.GetModels().cbegin(), plan.GetModels().cend(), [wanted](const THolder<TModelDescription>& holder) {
                    return holder.Get() == wanted;
                });
                Y_ENSURE(mdIter != plan.GetModels().cend());
            }
            Dump(result, fd);
        }
        result << "]";

        return result.Str();
    }

    // Builds plans from a debug bundle and compares their textual dumps:
    // a regular plan groups factors by model, a fastrank plan stays empty.
    Y_UNIT_TEST(TestDssmCalcerPlan) {
        TDssmBundleDebug bundle;
        bundle.AddModel("model1", TDssmBundleDebug::QueryHostPathTitle(), TVector<TString>());
        bundle.AddModel("model2", TDssmBundleDebug::QueryUrlTitle(), TVector<TString>());
        bundle.AddModel("model3", TDssmBundleDebug::QueryUrlTitle(), {"something_unused"});
        bundle.AddFactor("f1", TDssmBundle::TFactorBinding("model1", "outputA"));
        bundle.AddFactor("f2", TDssmBundle::TFactorBinding("model2", "outputX"));
        bundle.AddFactor("f3", TDssmBundle::TFactorBinding("model1", "outputB"));
        bundle.AddFactor("f4", TDssmBundle::TFactorBinding("model3", "outputC"));

        {
            // Regular mode: 3 factors, two models (f4 / model3 are not requested).
            TDssmFeatureCalcerPlan regularPlan;
            regularPlan.Add(101, "f1", bundle, /*fastFeaturesOnly=*/false);
            regularPlan.Add(105, "f2", bundle, /*fastFeaturesOnly=*/false);
            regularPlan.Add(102, "f3", bundle, /*fastFeaturesOnly=*/false);

            UNIT_ASSERT_VALUES_EQUAL(
                "Models["
                    "Model{model1,model1,[outputA,outputB],'1 1 1 0'} "
                    "Model{model2,model2,[outputX],'1 1 0 0'}] "
                "Factors["
                    "Factor{101,f1,model1,0,outputA} "
                    "Factor{105,f2,model2,0,outputX} "
                    "Factor{102,f3,model1,1,outputB}]",
                DumpPlan(regularPlan));
        }
        {
            // FastRank: the same 3 factors are added, the plan is expected to stay empty.
            TDssmFeatureCalcerPlan fastPlan;
            fastPlan.Add(101, "f1", bundle, /*fastFeaturesOnly=*/true);
            fastPlan.Add(105, "f2", bundle, /*fastFeaturesOnly=*/true);
            fastPlan.Add(102, "f3", bundle, /*fastFeaturesOnly=*/true);

            UNIT_ASSERT_VALUES_EQUAL(
                "Models[] "
                "Factors[]",
                DumpPlan(fastPlan));
        }
    }

    // Test double for TDssmFeatureCalcerBase. Instead of applying real models it
    // appends a textual record of every call to Log; GetResult() returns
    // 0.1, 0.2, 0.3, ... on successive calls.
    class TMockFeatureCalcer: public TDssmFeatureCalcerBase {
    public:
        using TDssmFeatureCalcerBase::TDssmFeatureCalcerBase;

        // Everything recorded so far.
        TString GetLog() const {
            return Log.Str();
        }

    public:
        // Both members are mutable because GetResult() is const but has to record its calls.
        mutable TStringStream Log;
        mutable ui32 GetResultsCount = 0;

    protected:
        // Records "ApplyModel <model> <docId>; ". The feature context is ignored.
        void ApplyModel(const TDssmFeatureCalcerPlan::TModelDescription* m, const TRTYDynamicFeatureContext&, ui32 docId) override {
            Y_ENSURE(m != nullptr);
            Log << "ApplyModel ";
            Dump(Log, *m);
            Log << " " << docId;
            Log << "; ";
        }

        // Records "ApplyToQuery <model> [<query cache outputs>]; ". The feature context is ignored.
        void ApplyModelToQuery(const TDssmFeatureCalcerPlan::TModelDescription* m, const TRTYDynamicFeatureContext&) override {
            Y_ENSURE(m != nullptr);
            Log << "ApplyToQuery ";
            Dump(Log, *m);
            Log << " [" << JoinSeq(",", m->SumBinding.QueryCacheOutputs) << "]";
            Log << "; ";
        }

        // Returns (number of GetResult() calls so far) / 10 and records the call.
        float GetResult(const ui32 outputLocalId) const override {
            ++GetResultsCount;
            const float value = static_cast<float>(GetResultsCount) / 10;
            Log << "GetResult " << outputLocalId << " " << value << "; ";
            return value;
        }
    };

    // Minimal description of a known factor, consumed by MockCreatePlan().
    // Kept as a plain aggregate so that tests can write {101, "f1"}.
    struct TMockStaticInfo {
        // Not read by MockCreatePlan(); the plan uses its own sequential local ids.
        ui32 FactorGlobalId;
        // Matched against the SourceFactor of the requested mapping items.
        TString FactorName;
    };

    // Mimics TDssmFeature::CreatePlan, but does not use the global bindings
    // (like TDssmBundle): the bundle is handed in by the caller.
    //
    // For every item of `enabledFactors` whose SourceFactor is listed in
    // `allFactors` a plan entry is added; local ids are assigned sequentially
    // from 0 in the order of `enabledFactors`. Unknown factors are skipped
    // without consuming an id. The resulting plan is also printed to Cout.
    static THolder<TDssmFeatureCalcerPlan> MockCreatePlan(const TVector<TMockStaticInfo>& allFactors, const TDynMapping& enabledFactors, const TDssmBundle& modelsInfo, bool fastFeaturesOnly = false) {
        auto plan = MakeHolder<TDssmFeatureCalcerPlan>();

        ui32 nextLocalId = 0;
        for (const auto& mapping : enabledFactors) {
            const auto known = std::find_if(allFactors.cbegin(), allFactors.cend(), [&mapping](const TMockStaticInfo& info) {
                return info.FactorName == mapping.SourceFactor;
            });
            if (known != allFactors.cend()) {
                plan->Add(nextLocalId, known->FactorName, modelsInfo, fastFeaturesOnly);
                ++nextLocalId;
            }
        }

        Cout << DumpPlan(*plan) << Endl;
        return plan;
    }

    // Drives TDssmFeatureCalcerBase::Calc() through the mock calcer and checks,
    // via the call log, which models are applied, in which order results are
    // fetched, and where the values land in the factor storage.
    Y_UNIT_TEST(TestDssmFeatureCalcerBase) {
        TDssmBundleDebug bundle;
        bundle.AddModel("model1", TDssmBundleDebug::QueryHostPathTitle(), {});
        bundle.AddModel("model2", TDssmBundleDebug::QueryUrlTitle(), {});
        bundle.AddModel("model3", TDssmBundleDebug::QueryRegHostPathTitle(), {});
        bundle.AddFactor("f1", TDssmBundle::TFactorBinding("model1", "outputA"));
        bundle.AddFactor("f2", TDssmBundle::TFactorBinding("model2", "outputX"));
        bundle.AddFactor("f3", TDssmBundle::TFactorBinding("model1", "outputB"));
        bundle.AddFactor("f4", TDssmBundle::TFactorBinding("model3", "outputC"));

        // NOTE(review): `*context` below dereferences a null pointer. The mock
        // overrides ignore the context argument, but this still relies on
        // Calc() itself never reading it - confirm.
        const TRTYDynamicFeatureContext* const context = nullptr;
        const TVector<TMockStaticInfo> knownFactors{
            {101, "f1"},
            {105, "f2"},
            {103, "f3"},
            {104, "f4"}};

        {
            // Case 1. No factors are used: nothing is applied, nothing is fetched.
            TDynMapping requested{};
            TFactorStorage storage(30);

            TMockFeatureCalcer calcer(MockCreatePlan(knownFactors, requested, bundle), requested);
            calcer.Calc(storage, *context, 33);

            UNIT_ASSERT_VALUES_EQUAL("", calcer.GetLog());
        }
        {
            // Case 2. One model is used (calc f1 f3 + extra feature)
            TDynMapping requested{
                TDynMappingItem("f1", 20),
                TDynMappingItem("f3", 21),
            };
            TFactorStorage storage(30);

            // The plan is built from f1 and f3 only; f4 is appended to the mapping
            // afterwards, so the calcer sees one mapping item that the plan lacks.
            auto plan = MockCreatePlan(knownFactors, requested, bundle);
            requested.emplace_back("f4", 21);

            TMockFeatureCalcer calcer(std::move(plan), requested);
            calcer.Calc(storage, *context, 33);

            UNIT_ASSERT_VALUES_EQUAL(
                "ApplyModel Model{model1,model1,[outputA,outputB],'1 1 1 0'} 33; " // model1 is applied only once, yields two results
                "GetResult 0 0.1; "
                "GetResult 1 0.2; ",
                calcer.GetLog());
            UNIT_ASSERT_VALUES_EQUAL(0.1f, storage[20]);
            UNIT_ASSERT_VALUES_EQUAL(0.2f, storage[21]);
        }
        {
            // Case 3. Two models are used (calc f1 f2); only outputA of model1 is requested
            TDynMapping requested{
                TDynMappingItem("f1", 20),
                TDynMappingItem("f2", 22),
            };
            TFactorStorage storage(30);

            TMockFeatureCalcer calcer(MockCreatePlan(knownFactors, requested, bundle), requested);
            calcer.Calc(storage, *context, 33);

            UNIT_ASSERT_VALUES_EQUAL(
                "ApplyModel Model{model1,model1,[outputA],'1 1 1 0'} 33; "
                "GetResult 0 0.1; " // f1
                "ApplyModel Model{model2,model2,[outputX],'1 1 0 0'} 33; "
                "GetResult 0 0.2; ", // f2
                calcer.GetLog());
            UNIT_ASSERT_VALUES_EQUAL(0.1f, storage[20]);
            UNIT_ASSERT_VALUES_EQUAL(0.2f, storage[22]);
        }
        {
            // Case 4. Three models are used, model1 is reused, factor f2 is used twice
            TDynMapping requested{
                TDynMappingItem("f1", 1), //f1, model1
                TDynMappingItem("f4", 2), //f4, model3
                TDynMappingItem("f3", 3), //f3, model1 (model 1 should be reused again)
                TDynMappingItem("f2", 4), //f2, model2
                TDynMappingItem("f2", 5), //duplicate of f2, model2
            };
            TFactorStorage storage(6);

            TMockFeatureCalcer calcer(MockCreatePlan(knownFactors, requested, bundle), requested);
            calcer.Calc(storage, *context, 33);

            // Results are fetched model by model, hence f3 (0.2) precedes f4 (0.3)
            // although f4 comes first in the mapping.
            UNIT_ASSERT_VALUES_EQUAL(
                "ApplyModel Model{model1,model1,[outputA,outputB],'1 1 1 0'} 33; " // model1 is applied only once, yields two results
                "GetResult 0 0.1; " // f1
                "GetResult 1 0.2; " // f3
                "ApplyModel Model{model3,model3,[outputC],'1 1 1 1'} 33; "
                "GetResult 0 0.3; " // f4
                "ApplyModel Model{model2,model2,[outputX],'1 1 0 0'} 33; "
                "GetResult 0 0.4; "  // f2
                "GetResult 0 0.5; ", // duplicate of f2
                calcer.GetLog());
            UNIT_ASSERT_VALUES_EQUAL(0.1f, storage[1]);
            UNIT_ASSERT_VALUES_EQUAL(0.3f, storage[2]);
            UNIT_ASSERT_VALUES_EQUAL(0.2f, storage[3]);
            UNIT_ASSERT_VALUES_EQUAL(0.4f, storage[4]);
            UNIT_ASSERT_VALUES_EQUAL(0.5f, storage[5]);
        }
    }

    // The query part of a model is applied once per calcer and then reused:
    // two Calc() calls produce a single ApplyToQuery record.
    Y_UNIT_TEST(TestDssmFeatureCalcerQueryCache) {
        TDssmBundleDebug bundle;
        bundle.AddModel("model1", TDssmBundleDebug::QueryHostPathTitle(), {"MyQueryEmb"});
        bundle.AddFactor("f1", TDssmBundle::TFactorBinding("model1", "outputA"));
        bundle.AddFactor("f2", TDssmBundle::TFactorBinding("model1", "outputB"));

        // NOTE(review): `*context` below dereferences a null pointer; this relies
        // on the context never being read on the tested path - confirm.
        const TRTYDynamicFeatureContext* const context = nullptr;
        const TVector<TMockStaticInfo> knownFactors{
            {101, "f1"},
            {102, "f2"}
        };

        // Query cache (calc f1 f2)
        TDynMapping requested{
            TDynMappingItem("f1", 20),
            TDynMappingItem("f2", 21),
        };
        TFactorStorage storage(30);

        TMockFeatureCalcer calcer(MockCreatePlan(knownFactors, requested, bundle), requested);
        calcer.Calc(storage, *context, 33);
        calcer.Calc(storage, *context, 34);

        UNIT_ASSERT_VALUES_EQUAL(
            "ApplyToQuery Model{model1,model1,[outputA,outputB],'1 1 1 0'} [MyQueryEmb]; "
            "ApplyModel Model{model1,model1,[outputA,outputB],'1 1 1 0'} 33; "
            "GetResult 0 0.1; "
            "GetResult 1 0.2; "
            "ApplyModel Model{model1,model1,[outputA,outputB],'1 1 1 0'} 34; "
            "GetResult 0 0.3; "
            "GetResult 1 0.4; ",
            calcer.GetLog());

        // The storage holds the values of the last calculated document (34).
        UNIT_ASSERT_VALUES_EQUAL(0.3f, storage[20]);
        UNIT_ASSERT_VALUES_EQUAL(0.4f, storage[21]);
    }


    // Calculates 3 query embeddings from 2 models and uses them for 2 documents:
    // each model's query part is applied once, before the first document.
    Y_UNIT_TEST(TestDssmFeatureCalcerQueryCache2) {
        TDssmBundleDebug bundle;
        bundle.AddModel("model1", TDssmBundleDebug::QueryHostPathTitle(), {"QueryEmbA"});
        bundle.AddModel("model2", TDssmBundleDebug::QueryUrlTitle(), {"QueryEmbB", "QueryEmbC"});
        bundle.AddFactor("f1", TDssmBundle::TFactorBinding("model1", "outputA"));
        bundle.AddFactor("f2", TDssmBundle::TFactorBinding("model1", "outputB"));
        bundle.AddFactor("f3", TDssmBundle::TFactorBinding("model2", "outputX"));

        // NOTE(review): `*context` below dereferences a null pointer; this relies
        // on the context never being read on the tested path - confirm.
        const TRTYDynamicFeatureContext* const context = nullptr;
        const TVector<TMockStaticInfo> knownFactors{
            {101, "f1"},
            {102, "f2"},
            {103, "f3"}
        };

        TDynMapping requested{
            TDynMappingItem("f1", 20),
            TDynMappingItem("f2", 21),
            TDynMappingItem("f3", 22),
        };
        TFactorStorage storage(30);

        TMockFeatureCalcer calcer(MockCreatePlan(knownFactors, requested, bundle), requested);
        calcer.Calc(storage, *context, 33);
        calcer.Calc(storage, *context, 34);

        UNIT_ASSERT_VALUES_EQUAL(
            "ApplyToQuery Model{model1,model1,[outputA,outputB],'1 1 1 0'} [QueryEmbA]; "
            "ApplyToQuery Model{model2,model2,[outputX],'1 1 0 0'} [QueryEmbB,QueryEmbC]; "
            "ApplyModel Model{model1,model1,[outputA,outputB],'1 1 1 0'} 33; "
            "GetResult 0 0.1; "
            "GetResult 1 0.2; "
            "ApplyModel Model{model2,model2,[outputX],'1 1 0 0'} 33; "
            "GetResult 0 0.3; "
            "ApplyModel Model{model1,model1,[outputA,outputB],'1 1 1 0'} 34; "
            "GetResult 0 0.4; "
            "GetResult 1 0.5; "
            "ApplyModel Model{model2,model2,[outputX],'1 1 0 0'} 34; "
            "GetResult 0 0.6; ",
            calcer.GetLog());

        // The storage holds the values of the last calculated document (34).
        UNIT_ASSERT_VALUES_EQUAL(0.4f, storage[20]);
        UNIT_ASSERT_VALUES_EQUAL(0.5f, storage[21]);
        UNIT_ASSERT_VALUES_EQUAL(0.6f, storage[22]);
    }

    // Builds a small synthetic model for submodel-extraction tests.
    //
    // The implementation under test depends on the ability to extract a
    // submodel, so the fake model has the following layer structure:
    //   queryInputs --concat--> "_query_inputs" --TRlu--> each of queryEmbs
    //   docInputs   --concat--> "_doc_inputs"
    //   "_doc_inputs" + all queryEmbs --concat--> "_outputs" --TRlu--> each of outputs
    //
    // queryInputs / docInputs  names of the query-side and document-side inputs
    // queryEmbs                names of the query embedding variables
    // outputs                  names of the output variables
    // useFieldExtractor        when true, the model has one composite input
    //                          "input" and a TFieldExtractorLayer maps every
    //                          field to a variable of the same name; when false,
    //                          every field is registered as a separate model
    //                          input that shares one empty matrix
    //
    // The name lists are only read, so they are taken by const reference.
    // Returns the model after Init() has been called on it.
    static NNeuralNetApplier::TModelPtr CreateFakeModel(const TVector<TString>& queryInputs, const TVector<TString>& docInputs, const TVector<TString>& queryEmbs, const TVector<TString>& outputs, bool useFieldExtractor) {
        using namespace NNeuralNetApplier;
        TModelPtr fakeModelPtr = MakeIntrusive<TModel>();
        TModel& fakeModel = *fakeModelPtr;
        TMatrixPtr input = new TMatrix();

        if (!useFieldExtractor) {
            // no TFieldExtractorLayer in model
            for (const TString& qi: queryInputs) {
                fakeModel.Parameters[qi] = input;
                fakeModel.Inputs.push_back(qi);
            }
            for (const TString& di: docInputs) {
                fakeModel.Parameters[di] = input;
                fakeModel.Inputs.push_back(di);
            }
        } else {
            // model uses a composite Input called 'input'
            fakeModel.Inputs.push_back("input");
            THashMap<TString,TString> annotationToVariable;
            for (const TString& qi: queryInputs) {
                annotationToVariable[qi] = qi;
            }
            for (const TString& di: docInputs) {
                annotationToVariable[di] = di;
            }
            fakeModel.Layers.push_back(new TFieldExtractorLayer("input", annotationToVariable));
        }
        fakeModel.Layers.push_back(new TConcatLayer(queryInputs, "_query_inputs"));
        fakeModel.Layers.push_back(new TConcatLayer(docInputs, "_doc_inputs"));

        for (const TString& qe: queryEmbs) {
            fakeModel.Layers.push_back(new TElementwiseTransform<TRlu>("_query_inputs", qe));
        }

        // for simplicity, we make all the outputs depend on all queryEmbs in our FakeModel. It is not like that in the real world.
        TVector<TString> diAndQe;
        diAndQe.push_back("_doc_inputs");
        diAndQe.insert(diAndQe.end(), queryEmbs.cbegin(), queryEmbs.cend());

        fakeModel.Layers.push_back(new TConcatLayer(diAndQe, "_outputs"));
        for (const TString& o: outputs) {
            fakeModel.Layers.push_back(new TElementwiseTransform<TRlu>("_outputs", o));
        }
        fakeModel.Init();
        return fakeModelPtr;
    }

    // Returns the model's field names, as reported by TDssmHelper::GetFieldNames(), in sorted order.
    static TVector<TString> GetInputs(const NNeuralNetApplier::TModel& model) {
        TVector<TString> fieldNames = TDssmHelper().GetFieldNames(model);
        Sort(fieldNames.begin(), fieldNames.end());
        return fieldNames;
    }

    // Returns the sorted names of the model's variables, leaving out the helper
    // variables whose names start with "_" and the composite variable "input".
    static TVector<TString> GetVariables(const NNeuralNetApplier::TModel& model) {
        TVector<TString> visible;
        for (const TString& name : model.AllVariables()) {
            const bool hidden = name.StartsWith("_") || name == "input";
            if (!hidden) {
                visible.push_back(name);
            }
        }
        Sort(visible.begin(), visible.end());
        return visible;
    }

    //
    // TestSplitModelPrecondition hopes to catch regressions in dssm_applier that interfere with the way we use TModel.
    // Here we ensure the "precondition" - that TModel::GetSubmodel works when called directly. TestSplitModel does virtually the same.
    //
    Y_UNIT_TEST(TestSplitModelPrecondition) {
        TDssmHelper helper;
        for (bool withFieldExtractorLayer: {true, false}) {
            // The fake model reports all of its query and document fields.
            NNeuralNetApplier::TModelPtr model = CreateFakeModel({"query", "query_reg"}, {"host", "title"}, {"queryEmbX", "queryEmbY"}, {"factorX", "factorY"}, withFieldExtractorLayer);
            UNIT_ASSERT_VALUES_EQUAL("host,query,query_reg,title", JoinStrings(GetInputs(*model), ","));

            // Apply SplitFieldExtractor to resolve a common issue with DSSM models. The set of fields should remain the same.
            helper.SplitFieldExtractor(*model);
            UNIT_ASSERT_VALUES_EQUAL("host,query,query_reg,title", JoinStrings(GetInputs(*model), ","));

            // Query submodel: there must be no document-related inputs left.
            NNeuralNetApplier::TModelPtr queryPart = model->GetSubmodel("queryEmbX");
            UNIT_ASSERT_VALUES_EQUAL("query,query_reg", JoinStrings(GetInputs(*queryPart), ","));
            UNIT_ASSERT_VALUES_EQUAL("query,queryEmbX,query_reg", JoinStrings(GetVariables(*queryPart), ","));

            // Document submodel with both embeddings supplied: there must be no query-related inputs left.
            NNeuralNetApplier::TModelPtr docPart = model->GetSubmodel(TSet<TString>{"factorY"}, {"queryEmbX", "queryEmbY"});
            UNIT_ASSERT_VALUES_EQUAL("host,title", JoinStrings(GetInputs(*docPart), ","));
            UNIT_ASSERT_VALUES_EQUAL("factorY,host,queryEmbX,queryEmbY,title", JoinStrings(GetVariables(*docPart), ","));

            // Extraction "fails" if not all embeddings are given: the resulting inputs are not the
            // pure document ones (this behavior is used to check whether the configuration is correct).
            NNeuralNetApplier::TModelPtr oneEmbeddingMissing = model->GetSubmodel(TSet<TString>{"factorY"}, {"queryEmbX"});
            UNIT_ASSERT_VALUES_UNEQUAL("host,title", JoinStrings(GetInputs(*oneEmbeddingMissing), ","));
            NNeuralNetApplier::TModelPtr allEmbeddingsMissing = model->GetSubmodel("factorY", {});
            UNIT_ASSERT_VALUES_UNEQUAL("host,title", JoinStrings(GetInputs(*allEmbeddingsMissing), ","));
        }
    }

    // Attaches a fake model to a plan and checks that
    // TDssmFeatureCalcerPlan::PrepareModels() splits it into a query submodel
    // (query inputs only) and a document submodel (document inputs only, with
    // the query embeddings among its variables).
    Y_UNIT_TEST(TestSplitModel) {
        TDssmBundleDebug bundle1;
        bundle1.AddModel("model1", TDssmBundleDebug::QueryHostPathTitle(), {"embeddingX", "embeddingY"});
        bundle1.AddModel("model2", TDssmBundleDebug::QueryUrlTitle(), {"something"});
        bundle1.AddFactor("f1", TDssmBundle::TFactorBinding("model1", "factorX"));
        bundle1.AddFactor("f2", TDssmBundle::TFactorBinding("model1", "factorY"));

        // Only model1 is requested; model2 is declared in the bundle but stays out of the plan.
        TDssmFeatureCalcerPlan plan1;
        plan1.Add(101, "f1", bundle1, false);
        plan1.Add(102, "f2", bundle1, false);

        NNeuralNetApplier::TModelPtr fakeModel = CreateFakeModel(
                {"query"},
                {"doc_url", "doc_host", "doc_path", "doc_title"}, //matches TInputTraits for model1
                {"embeddingX", "embeddingY"},
                {"factorX", "factorY", "factorZ"},
                /*useFieldExtractor=*/true);

        // imitate LoadModels routine
        {
            // VALUES_EQUAL reports the actual number of models if the expectation breaks.
            UNIT_ASSERT_VALUES_EQUAL(1, plan1.GetModels().size());
            TDssmFeatureCalcerPlan::TModelDescription* const m = plan1.GetModels().front().Get();
            m->Model = fakeModel.Get();
            plan1.PrepareModels(nullptr, false);
        }

        // check the plan, and the outcome of PrepareModels
        {
            // factorZ exists in the model but is not requested, so it is absent from the plan.
            UNIT_ASSERT_VALUES_EQUAL(
                "Models["
                    "Model{model1,model1,[factorX,factorY],'1 1 1 0'}] "
                "Factors["
                    "Factor{101,f1,model1,0,factorX} "
                    "Factor{102,f2,model1,1,factorY}]",
                DumpPlan(plan1));
            TDssmFeatureCalcerPlan::TModelDescription* const m = plan1.GetModels().front().Get();

            UNIT_ASSERT(m->QuerySubmodel);
            UNIT_ASSERT_VALUES_EQUAL("query", JoinStrings(GetInputs(*m->QuerySubmodel), ","));

            UNIT_ASSERT(m->DocSubmodel);
            UNIT_ASSERT_VALUES_EQUAL("doc_host,doc_path,doc_title,doc_url", JoinStrings(GetInputs(*m->DocSubmodel), ","));
            UNIT_ASSERT_VALUES_EQUAL("doc_host,doc_path,doc_title,doc_url,embeddingX,embeddingY,factorX,factorY", JoinStrings(GetVariables(*m->DocSubmodel), ","));
        }
    }
}
