#include "dssm_model_ops.h"

#include <util/generic/yexception.h>
#include <util/string/join.h>

namespace NRTYFeatures {
    //
    // TDssmHelper: helper ops for DSSM models
    //
    // Splits the first TFieldExtractorLayer found in `model` into two extractor
    // layers: one carrying the annotations whose key is contained in QueryFields,
    // the other carrying all remaining annotations. Both new layers read the
    // first input of the original extractor. The layer with the remaining
    // (non-query) annotations ends up at the original position, and the
    // query-side layer directly after it.
    //
    // The model is left unchanged when it contains no field extractor, or when
    // the annotations of the first extractor fall entirely on one side.
    // Only the first extractor layer is ever examined.
    void TDssmHelper::SplitFieldExtractor(NNeuralNetApplier::TModel& model) {
        using namespace NNeuralNetApplier;
        //TODO(yrum): this routine is copypasted from kernel/dssm_applier/nn_applier/main.cpp:SplitExtractor

        // Index-based scan on purpose: model.Layers is modified below, and an
        // insert would invalidate the iterators of a range-based for loop.
        for (size_t layerIdx = 0; layerIdx < model.Layers.size(); ++layerIdx) {
            auto* extractor = dynamic_cast<TFieldExtractorLayer*>(model.Layers[layerIdx].Get());
            if (extractor == nullptr) {
                continue;
            }

            // Partition the annotations by whether their key is a query field.
            THashMap<TString, TString> queryFieldsMap;
            THashMap<TString, TString> docFieldsMap;
            for (const auto& item : extractor->GetAnnotations()) {
                auto& target = QueryFields.contains(item.first) ? queryFieldsMap : docFieldsMap;
                target[item.first] = item.second;
            }

            if (!queryFieldsMap.empty() && !docFieldsMap.empty()) {
                // Take a copy of the input name before overwriting the slot:
                // the assignment below may destroy the layer `extractor` points to.
                TString inputVar = extractor->GetInputs()[0];
                model.Layers[layerIdx] = new TFieldExtractorLayer(inputVar, queryFieldsMap);
                model.Layers.insert(model.Layers.begin() + layerIdx,
                                    new TFieldExtractorLayer(inputVar, docFieldsMap));
            }

            // Stop after the first extractor, whether or not it was split.
            return;
        }
    }

    // Returns the list of field names for `model`: starts from a copy of
    // model.Inputs, then for every TFieldExtractorLayer removes each of the
    // layer's inputs from the list (first occurrence only, if present) and
    // appends the keys of the layer's annotations.
    TVector<TString> TDssmHelper::GetFieldNames(const NNeuralNetApplier::TModel& model) {
        using namespace NNeuralNetApplier;

        TVector<TString> result(model.Inputs);
        for (const auto& layerPtr : model.Layers) {
            if (auto* extractor = dynamic_cast<TFieldExtractorLayer*>(layerPtr.Get())) {
                // The extractor's own inputs are replaced by the fields it extracts.
                for (const TString& extractorInput : extractor->GetInputs()) {
                    auto found = std::find(result.begin(), result.end(), extractorInput);
                    if (found != result.end()) {
                        result.erase(found);
                    }
                }

                for (const auto& annotation : extractor->GetAnnotations()) {
                    result.push_back(annotation.first);
                }
            }
        }
        return result;
    }

    // Extracts from `fullModel` the submodel obtained via GetSubmodel(embeddingNames)
    // and verifies that its field names (see GetFieldNames) pass HasOnlyQueryInputs.
    // On a failed check Y_ENSURE raises an error whose message includes `diagName`,
    // the requested names and the offending field list.
    NNeuralNetApplier::TModelPtr TDssmHelper::GetQuerySubmodel(const NNeuralNetApplier::TModel& fullModel, const TString& diagName, const TSet<TString>& embeddingNames) {
        NNeuralNetApplier::TModelPtr submodel = fullModel.GetSubmodel(embeddingNames);

        // Sanity check of the submodel's fields before handing it out.
        TVector<TString> submodelFields = GetFieldNames(*submodel);
        Y_ENSURE(HasOnlyQueryInputs(submodelFields),
                 "incorrect inputs in querySubmodel for " << diagName
                     << " with outputsNames=" << JoinSeq(",", embeddingNames)
                     << " : " << JoinSeq(",", submodelFields));

        return submodel;
    }

    // Extracts from `fullModel` the submodel obtained via
    // GetSubmodel(resultNames, embeddingNames) and verifies that its field names
    // (see GetFieldNames) pass HasNoQueryInputs.
    // NOTE(review): presumably `embeddingNames` marks where the query part is cut
    // off from the doc part — confirm against TModel::GetSubmodel.
    // On a failed check Y_ENSURE raises an error whose message includes `diagName`,
    // the requested result names and the offending field list.
    NNeuralNetApplier::TModelPtr TDssmHelper::GetDocSubmodel(const NNeuralNetApplier::TModel& fullModel, const TString& diagName,
                                                             const TSet<TString>& resultNames, const TSet<TString>& embeddingNames) {
        NNeuralNetApplier::TModelPtr submodel = fullModel.GetSubmodel(resultNames, embeddingNames);

        // Sanity check of the submodel's fields before handing it out.
        TVector<TString> submodelFields = GetFieldNames(*submodel);
        Y_ENSURE(HasNoQueryInputs(submodelFields),
                 "incorrect inputs in docSubmodel for " << diagName
                     << " with outputsNames=" << JoinSeq(",", resultNames)
                     << " : " << JoinSeq(",", submodelFields));

        return submodel;
    }
}
