#include "common.h"
#include "segment_embedding_model.h"

#include <crypta/lib/proto/user_data/user_data_stats.pb.h>

#include <util/generic/map.h>
#include <util/generic/string.h>
#include <util/generic/vector.h>

#include <functional>

using namespace NCrypta::NLookalike;

// Constructs the segment embedding model.
//
// modelFileName            - forwarded to the TEmbeddingModel base constructor.
// featuresMappingFileName  - passed to GetFeaturesMapping(); its result configures
//                            the segment features calculator used by Embed().
TSegmentEmbeddingModel::TSegmentEmbeddingModel(const TString& modelFileName, const TString& featuresMappingFileName)
    : TEmbeddingModel(modelFileName)
    , SegmentFeaturesCalculator(GetFeaturesMapping(featuresMappingFileName))
{
}

// Computes the embedding of a segment from its user-data statistics.
//
// The "doc_embedding" submodel of the DSSM model is applied to a sample built from:
//   * "segment_float_features" - always present, produced by PrepareFloatFeatures();
//   * each optional field from the table below - included only when the submodel
//     declares a "$fields$<field name>" variable, so submodels that do not declare
//     such an input are not fed it.
//
// segment - statistics of the segment to embed.
// Returns the "doc_embedding" output layer values, as filled in by Apply().
TEmbedding TSegmentEmbeddingModel::Embed(const NLab::TUserDataStats& segment) const {
    static const TString outputLayerName = "doc_embedding";

    // Optional model inputs and the calculators that produce them. The table never
    // changes, so it is built once instead of re-allocating TMap nodes, TString keys
    // and std::function wrappers on every call. Fields are appended in TMap key order.
    static const TMap<TString, std::function<TString(const NLab::TUserDataStats&)>> headerToCalculatorMapping = {
        {"segment_affinitive_sites_ids", &TSegmentFeaturesCalculator::PrepareAffinitiveSitesIds},
        {"segment_affinitive_apps", &TSegmentFeaturesCalculator::PrepareAffinitiveApps}
    };

    const auto segmentDssmModel = DssmModel.GetSubmodel(outputLayerName);

    const auto segmentFloatFeatures = SegmentFeaturesCalculator.PrepareFloatFeatures(segment);

    // segmentAnnotations[i] names the model field that receives segmentFeatures[i].
    TVector<TString> segmentAnnotations{"segment_float_features"};
    TVector<TString> segmentFeatures{segmentFloatFeatures};
    for (const auto& entry : headerToCalculatorMapping) {
        const TString& fieldName = entry.first;
        const auto& calculator = entry.second;
        if (segmentDssmModel->HasVariable("$fields$" + fieldName)) {
            segmentAnnotations.push_back(fieldName);
            segmentFeatures.push_back(calculator(segment));
        }
    }
    const auto segmentSample = MakeAtomicShared<NNeuralNetApplier::TSample>(segmentAnnotations, segmentFeatures);

    TEmbedding segmentEmbedding;
    segmentDssmModel->Apply(segmentSample, {outputLayerName}, segmentEmbedding);

    return segmentEmbedding;
}
