#include "make_segments_mapper.h"

#include <crypta/lookalike/proto/error.pb.h>
#include <crypta/lookalike/proto/user_segments.pb.h>
#include <crypta/lookalike/proto/yt_node_names.pb.h>

#include <util/random/shuffle.h>

using namespace NCrypta;
using namespace NCrypta::NLookalike;
using namespace NCrypta::NLookalike::NSegmentator;

/// Builds the mapper from the output table indexes and its configuration.
/// Both arguments are sink parameters: they are taken by value and moved
/// into the corresponding members.
TMakeSegmentsMapper::TMakeSegmentsMapper(TOutputIndexes outputIndexes, TMakeSegmentsMapperConfig config)
    : OutputIndexes(std::move(outputIndexes)), Config(std::move(config)) {}

void NCrypta::NLookalike::NSegmentator::TMakeSegmentsMapper::Start(TWriter* writer) {
    try {
        HnswIndexModel = MakeHolder<THnswIndexModel>(
                Config.GetDimension(),
                TYtNodeNames().GetIndexFile(),
                TYtNodeNames().GetDataFile(),
                TYtNodeNames().GetLabelsFile(),
                Config.GetTopSize(),
                Config.GetTopSize());
    } catch (const yexception& e) {
        WriteError(writer, e.what());
        throw;
    }
}

/// Produces one Segments row per input row.
///
/// Each row's embedding is looked up in the HNSW index built by Start(), and the
/// retrieved neighbors are written as that user's segments. If handling a row
/// throws a yexception, the message goes to the Errors table and the loop moves
/// on to the next row; other exception types propagate.
void TMakeSegmentsMapper::Do(TReader* reader, TWriter* writer) {
    while (reader->IsValid()) {
        try {
            const auto& input = reader->GetRow();
            auto found = HnswIndexModel->Retrieve(input.GetEmbedding());
            WriteSegments(writer, input.GetUserId(), found);
        } catch (const yexception& e) {
            WriteError(writer, e.what());
        }
        // Advancing is deliberately outside the try: a failure here is not per-row.
        reader->Next();
    }
}

/// Emits a single TUserSegments row to the Segments table.
///
/// The row holds `userId` and, in the order the neighbors are iterated, each
/// neighbor's Id as a segment alongside its Dist as the score at the same position.
void TMakeSegmentsMapper::WriteSegments(TWriter* writer, ui64 userId, const THnswIndexModel::TNeighbors& neighbors) {
    TUserSegments result;
    result.SetUserId(userId);

    for (const auto& item : neighbors) {
        // Segments and Scores are parallel lists: index i in one matches index i in the other.
        result.AddSegments(item.Id);
        result.AddScores(item.Dist);
    }

    writer->AddRow(result, *OutputIndexes[EOutputTables::Segments]);
}

/// Emits one TError row with the given message to the Errors table.
void TMakeSegmentsMapper::WriteError(TWriter* writer, const TString& message) {
    TError row;
    row.SetMessage(message);
    writer->AddRow(row, *OutputIndexes[EOutputTables::Errors]);
}

/// Adds the Segments (TUserSegments) and Errors (TError) output tables to `spec`.
///
/// @param spec      Map operation spec to which the outputs are added.
/// @param segments  Path of the Segments output table.
/// @param errors    Path of the Errors output table.
/// @return The EOutputTables -> table index mapping collected by the builder
///         (presumably what callers hand to the constructor — confirm at call site).
TMakeSegmentsMapper::TOutputIndexes TMakeSegmentsMapper::PrepareOutput(NYT::TMapOperationSpec& spec, const TString& segments, const TString& errors) {
    TOutputIndexes::TBuilder builder;
    builder.AddOutput<TUserSegments>(spec, segments, EOutputTables::Segments);
    builder.AddOutput<TError>(spec, errors, EOutputTables::Errors);
    return builder.GetIndexes();
}

// Registers TMakeSegmentsMapper with the YT C++ API job registry
// (presumably so the job binary can instantiate it by name — see YT docs).
REGISTER_MAPPER(TMakeSegmentsMapper);
