#pragma once

#include "dssm_factors.h"

#include <kernel/dssm_applier/nn_applier/lib/layers.h>

#include <util/generic/set.h>
#include <util/generic/string.h>
#include <util/generic/vector.h>

#include <algorithm>

namespace NRTYFeatures {
    /// Helper for working with DSSM applier models: classifies model input
    /// variables as query-side or not, and extracts query/document submodels
    /// from a full NNeuralNetApplier model.
    ///
    /// The set of query-side input names is fixed at construction to
    /// {"query", "query_reg"}; any other input name is not a query input.
    class TDssmHelper {
    private:
        /// Names of model input variables that are treated as query inputs.
        TSet<TString> QueryFields{"query", "query_reg"};

    public:
        TDssmHelper() = default;

        /// Modifies @p model in place (it is taken by non-const reference).
        /// NOTE(review): presumably splits the model's field-extractor layer so
        /// query and document fields can be handled separately — TODO confirm
        /// against the implementation.
        void SplitFieldExtractor(NNeuralNetApplier::TModel& model);

        /// Returns a list of field names derived from @p model.
        /// Static: does not depend on this helper's query-field set.
        /// NOTE(review): presumably the model's input field names — TODO confirm.
        static TVector<TString> GetFieldNames(const NNeuralNetApplier::TModel& model);

        /// Builds a submodel of @p fullModel for the query side.
        /// @param fullModel       source model (not modified).
        /// @param diagName        presumably a name used for diagnostics — TODO confirm.
        /// @param embeddingNames  presumably the embedding outputs the submodel must produce — TODO confirm.
        /// @return pointer to the newly built submodel.
        NNeuralNetApplier::TModelPtr GetQuerySubmodel(const NNeuralNetApplier::TModel& fullModel, const TString& diagName, const TSet<TString>& embeddingNames);

        /// Builds a submodel of @p fullModel for the document side.
        /// @param fullModel       source model (not modified).
        /// @param diagName        presumably a name used for diagnostics — TODO confirm.
        /// @param resultNames     presumably the result outputs the submodel must produce — TODO confirm.
        /// @param embeddingNames  presumably the embedding outputs involved — TODO confirm.
        /// @return pointer to the newly built submodel.
        NNeuralNetApplier::TModelPtr GetDocSubmodel(const NNeuralNetApplier::TModel& fullModel, const TString& diagName, const TSet<TString>& resultNames, const TSet<TString>& embeddingNames);

        /// Returns true iff @p v is one of the query-side input names.
        Y_FORCE_INLINE bool IsQueryInputVar(const TString& v) const {
            return QueryFields.contains(v);
        }

        /// Returns true iff every name in @p v is a query input.
        /// Vacuously true for an empty vector.
        Y_FORCE_INLINE bool HasOnlyQueryInputs(const TVector<TString>& v) const {
            return std::all_of(v.cbegin(), v.cend(), [this](const TString& name) {
                return IsQueryInputVar(name);
            });
        }

        /// Returns true iff no name in @p v is a query input.
        /// Also true for an empty vector, so an empty list satisfies both
        /// this predicate and HasOnlyQueryInputs().
        Y_FORCE_INLINE bool HasNoQueryInputs(const TVector<TString>& v) const {
            return std::none_of(v.cbegin(), v.cend(), [this](const TString& name) {
                return IsQueryInputVar(name);
            });
        }
    };
}
