#include "apply_user_dssm_mapper.h"

#include <crypta/lookalike/lib/native/normalize.h>
#include <crypta/lookalike/proto/error.pb.h>
#include <crypta/lookalike/proto/user_embedding.pb.h>
#include <crypta/lookalike/proto/yt_node_names.pb.h>

#include <util/random/shuffle.h>

#include <exception>

using namespace NCrypta;
using namespace NCrypta::NLookalike;
using namespace NCrypta::NLookalike::NUserDssmApplier;

// Creates a mapper whose output tables are addressed through `outputIndexes`
// (typically the value returned by PrepareOutput). The argument is taken by
// value and moved into the member, so callers may pass a temporary cheaply.
TApplyUserDssmMapper::TApplyUserDssmMapper(TOutputIndexes outputIndexes)
    : OutputIndexes(std::move(outputIndexes)) {
}

// Default mapper: UserEmbeddings is written to table 0 and Errors to table 1.
// This is the same order in which PrepareOutput registers the two tables.
// Delegates to the explicit-indexes constructor so member setup lives in one place.
TApplyUserDssmMapper::TApplyUserDssmMapper()
    : TApplyUserDssmMapper(TOutputIndexes({{EOutputTables::UserEmbeddings, 0}, {EOutputTables::Errors, 1}}))
{
}

void NCrypta::NLookalike::NUserDssmApplier::TApplyUserDssmMapper::Start(TWriter* writer) {
    try {
        UserEmbeddingModel = MakeHolder<TUserEmbeddingModel>(
                TYtNodeNames().GetDssmModelFile(),
                TYtNodeNames().GetSegmentsDictFile());
    } catch (const yexception& e) {
        WriteError(writer, e.what());
        throw;
    }
}

// Processes every input row independently:
//   1. computes the user embedding with the model loaded in Start;
//   2. normalizes the embedding in place;
//   3. parses the row's yandexuid as ui64 and writes the result, together with
//      the row's attributes, to the UserEmbeddings table.
// A yexception from any step is recorded in the Errors table and processing
// moves on to the next row; other exception types propagate.
void TApplyUserDssmMapper::Do(TReader* reader, TWriter* writer) {
    while (reader->IsValid()) {
        try {
            const auto& inputRow = reader->GetRow();

            auto userVector = UserEmbeddingModel->Embed(inputRow);
            Normalize(userVector);

            const auto userId = FromString<ui64>(inputRow.GetYandexuid());
            WriteEmbedding(writer, userId, userVector, inputRow.GetAttributes());
        } catch (const yexception& e) {
            WriteError(writer, e.what());
        }
        reader->Next();
    }
}

// Emits one TUserEmbedding row to the UserEmbeddings output table.
//   userId     — numeric user id stored in the row.
//   embedding  — vector values copied element-wise into the repeated field.
//   attributes — user attributes copied verbatim into the row.
void TApplyUserDssmMapper::WriteEmbedding(TWriter* writer, ui64 userId, const TEmbedding& embedding, const NLab::TUserData::TAttributes& attributes) {
    TUserEmbedding outputRow;
    outputRow.SetUserId(userId);
    outputRow.MutableAttributes()->CopyFrom(attributes);

    auto& values = *outputRow.MutableEmbedding();
    values = {embedding.begin(), embedding.end()};

    const auto tableIndex = *OutputIndexes[EOutputTables::UserEmbeddings];
    writer->AddRow(outputRow, tableIndex);
}

// Emits a single TError row carrying `message` to the Errors output table.
void TApplyUserDssmMapper::WriteError(TWriter* writer, const TString& message) {
    TError errorRow;
    errorRow.SetMessage(message);

    const auto errorsTable = *OutputIndexes[EOutputTables::Errors];
    writer->AddRow(errorRow, errorsTable);
}

// Registers the mapper's output tables on `spec` and returns the resulting
// table-index mapping, to be passed to the TApplyUserDssmMapper constructor.
//   embeddings — path of the table that receives TUserEmbedding rows.
//   errors     — path of the table that receives TError rows.
// NOTE(review): tables are registered embeddings-first, mirroring the default
// constructor's {UserEmbeddings: 0, Errors: 1} layout; presumably the builder
// assigns indexes in registration order — keep the two in sync.
TApplyUserDssmMapper::TOutputIndexes TApplyUserDssmMapper::PrepareOutput(NYT::TMapOperationSpec& spec, const TString& embeddings, const TString& errors) {
    TOutputIndexes::TBuilder builder;
    builder.AddOutput<TUserEmbedding>(spec, embeddings, EOutputTables::UserEmbeddings);
    builder.AddOutput<TError>(spec, errors, EOutputTables::Errors);
    return builder.GetIndexes();
}
