#include "apply_segment_dssm_mapper.h"

#include <crypta/lib/native/log/loggers/std_logger.h>
#include <crypta/lookalike/lib/native/normalize.h>
#include <crypta/lookalike/lib/native/pqlib_producer_initializer.h>
#include <crypta/lookalike/proto/error_with_segment_id.pb.h>
#include <crypta/lookalike/proto/mode.pb.h>
#include <crypta/lookalike/proto/yt_node_names.pb.h>
#include <crypta/lookalike/services/lal_manager/commands/lal_cmd.pb.h>

#include <library/cpp/protobuf/json/proto2json.h>

#include <util/random/shuffle.h>
#include <util/system/env.h>

using namespace NCrypta;
using namespace NCrypta::NLookalike;
using namespace NCrypta::NLookalike::NSegmentDssmApplier;

// Constructs the mapper from its configuration and the output table indexes.
// Both arguments are taken by value and moved into the corresponding members.
TApplySegmentDssmMapper::TApplySegmentDssmMapper(TMapperConfig config, TOutputIndexes outputIndexes)
    : Config(std::move(config))
    , OutputIndexes(std::move(outputIndexes))
{
}

// Initialises the mapper state before any rows are processed:
//   * registers the "pqlib" logger writing to stderr at debug level;
//   * constructs the segment embedding model from the DSSM model file and
//     segments dictionary file names taken from TYtNodeNames;
//   * puts the TVM secret from the YT_SECURE_VAULT_TVM_SECRET environment
//     variable into the writer credentials and creates the PQ producer;
//   * caches the configured mode.
//
// A yexception raised during initialisation is written to the errors table
// (without a segment id) and then rethrown, the same way Finish() handles
// failures. Swallowing it here would leave SegmentEmbeddingModel and/or
// Producer unassigned while Do() and Finish() dereference them unconditionally.
void TApplySegmentDssmMapper::Start(TWriter* writer) {
    try {
        using namespace NCrypta::NPQ;

        auto log = NLog::NStdLogger::RegisterLog("pqlib", "stderr", "debug");

        // One TYtNodeNames instance is enough to read both file names.
        const TYtNodeNames nodeNames;
        SegmentEmbeddingModel = MakeHolder<TSegmentEmbeddingModel>(
            nodeNames.GetDssmModelFile(),
            nodeNames.GetSegmentsDictFile());

        Config.MutableWriter()->MutableCredentials()->MutableTvm()->SetClientTvmSecret(GetEnv("YT_SECURE_VAULT_TVM_SECRET"));

        Producer = GetPqlibProducer(Config.GetWriter());
        Mode = Config.GetMode();
    } catch (const yexception& e) {
        WriteError(writer, Nothing(), e.what());
        // Propagate: the mapper is only partially initialised and must not be used.
        throw;
    }
}

// Stops the PQ producer, passing it a 30-second duration. If stopping throws a
// yexception, the message is written to the errors table (without a segment id)
// and the exception is propagated to the caller.
void TApplySegmentDssmMapper::Finish(TWriter* writer) {
    const TDuration stopDuration = TDuration::Seconds(30);

    try {
        Producer->Stop(stopDuration);
    } catch (const yexception& e) {
        WriteError(writer, Nothing(), e.what());
        throw;
    }
}

// Processes input rows one by one. For every row that is not skipped:
//   1. the key is parsed as a ui64 segment id;
//   2. the user data stats from the description are embedded and normalized;
//   3. the embedding and the segment meta are written to their output tables;
//   4. if the segment was marked as new, a command resetting that flag is enqueued.
//
// A row is skipped when the mode is NEW and the segment is not new, or when the
// value has no description. A yexception raised while handling a row is written
// to the errors table (with the segment id if it was already parsed) and
// processing continues with the next row.
void TApplySegmentDssmMapper::Do(TReader* reader, TWriter* writer) {
    for (; reader->IsValid(); reader->Next()) {
        auto row = reader->MoveRow();
        TMaybe<ui64> segmentId;

        try {
            segmentId = FromString<ui64>(row.GetKey());
            auto& value = *row.MutableValue();

            // Read the flag up front: WriteParentMeta() below swaps the meta out of the row.
            const bool isNew = value.GetMeta().GetNew();
            const bool filteredOutByMode = (Mode == ModeValue::NEW) && !isNew;

            if (filteredOutByMode || !value.HasDescription()) {
                continue;
            }

            const ui64 id = *segmentId;

            auto embedding = SegmentEmbeddingModel->Embed(value.GetDescription().GetUserDataStats());
            Normalize(embedding);

            WriteEmbedding(writer, id, embedding);
            WriteParentMeta(writer, id, *value.MutableMeta());

            if (isNew) {
                ChangeNewness(id);
            }
        } catch (const yexception& e) {
            WriteError(writer, segmentId, e.what());
        }
    }
}

// Writes one row with the segment id and a copy of its embedding values to the
// SegmentEmbeddings output table.
void TApplySegmentDssmMapper::WriteEmbedding(TWriter* writer, ui64 segmentId, const TEmbedding& embedding) {
    TSegmentEmbedding outputRow;
    outputRow.SetSegmentId(segmentId);

    auto& values = *outputRow.MutableEmbedding();
    values = {embedding.begin(), embedding.end()};

    const auto tableIndex = *OutputIndexes[EOutputTables::SegmentEmbeddings];
    writer->AddRow(outputRow, tableIndex);
}

// Writes one row with the segment id and its meta to the SegmentParentMeta
// output table. The meta is swapped into the output row, so `meta` holds the
// row's previous (freshly constructed) contents afterwards.
void TApplySegmentDssmMapper::WriteParentMeta(TWriter* writer, ui64 segmentId, TSegmentMeta& meta) {
    TSegmentMetaEntry outputRow;
    outputRow.SetSegmentId(segmentId);

    auto* rowMeta = outputRow.MutableMeta();
    rowMeta->Swap(&meta);

    const auto tableIndex = *OutputIndexes[EOutputTables::SegmentParentMeta];
    writer->AddRow(outputRow, tableIndex);
}

// Writes an error row to the Errors output table. The message is always set;
// the segment id is set only when `segmentId` holds a value.
void TApplySegmentDssmMapper::WriteError(TWriter* writer, TMaybe<ui64> segmentId, const TString& message) {
    TErrorWithSegmentId errorRow;
    errorRow.SetMessage(message);

    if (segmentId.Defined()) {
        errorRow.SetSegmentId(segmentId.GetRef());
    }

    const auto tableIndex = *OutputIndexes[EOutputTables::Errors];
    writer->AddRow(errorRow, tableIndex);
}

// Adds the three output tables of the mapper to the operation spec and returns
// the indexes collected by the builder.
//
// spec       - map operation spec the outputs are added to;
// embeddings - path for TSegmentEmbedding rows;
// metas      - path for TSegmentMetaEntry rows;
// errors     - path for TErrorWithSegmentId rows.
TApplySegmentDssmMapper::TOutputIndexes TApplySegmentDssmMapper::PrepareOutput(
    NYT::TMapOperationSpec& spec,
    const TString& embeddings,
    const TString& metas,
    const TString& errors)
{
    TOutputIndexes::TBuilder builder;

    // Outputs are added in a fixed order: embeddings, metas, errors.
    builder.AddOutput<TSegmentEmbedding>(spec, embeddings, EOutputTables::SegmentEmbeddings);
    builder.AddOutput<TSegmentMetaEntry>(spec, metas, EOutputTables::SegmentParentMeta);
    builder.AddOutput<TErrorWithSegmentId>(spec, errors, EOutputTables::Errors);

    return builder.GetIndexes();
}

// Builds a TLalCmd with a ChangeNewness command that sets the newness state of
// `lalId` to false, serializes it to JSON and enqueues it into the producer.
// Throws (via Y_ENSURE) when TryEnqueue() reports failure.
void TApplySegmentDssmMapper::ChangeNewness(ui64 lalId) {
    TLalCmd cmd;

    auto& changeNewness = *cmd.MutableChangeNewnessCmd();
    changeNewness.SetLalId(lalId);
    changeNewness.SetNewnessState(false);

    // The condition is kept as a single expression: Y_ENSURE reports its text on failure.
    Y_ENSURE(Producer->TryEnqueue(NProtobufJson::Proto2Json(cmd)));
}

// Registers TApplySegmentDssmMapper through the REGISTER_MAPPER macro.
REGISTER_MAPPER(TApplySegmentDssmMapper);
