#include "dssm_core.h"

#include <saas/protos/rtyserver.pb.h>
#include <saas/rtyserver/components/dssm/embeddings_storage.fbs.h>
#include <saas/rtyserver/components/l2/l2_memory_manager.h>
#include <contrib/libs/flatbuffers/include/flatbuffers/flatbuffers.h>

namespace {
    /// Serializes a document's embeddings into a NEmbeddingsScheme::TDocumentEmbeddings
    /// flatbuffer and writes the finished buffer to `s`.
    ///
    /// Embeddings are grouped by the search key built with NRTYServer::ComposeSearchKey
    /// from their name and version. Each group becomes one TEmbedding table (named by the
    /// search key) holding one TValue (raw value bytes + tag) per member, in the order the
    /// members appear in `embeddings`. Embeddings without a tag get an empty tag string.
    /// The top-level vector is built with CreateVectorOfSortedTables, i.e. sorted by the
    /// TEmbedding schema's key field (presumably the name/search key — verify in the .fbs).
    ///
    /// @param s          destination stream; receives the serialized flatbuffer bytes.
    /// @param embeddings embeddings of a single document.
    void Save(IOutputStream& s, const google::protobuf::RepeatedPtrField<NRTYServer::TMessage_TEmbedding>& embeddings) {
        flatbuffers::FlatBufferBuilder fb;

        // Group embedding indices by search key; input order is preserved within each group.
        THashMap<TString, TVector<int>> searchKeyToEmbeddingIdx;
        for (int i = 0; i < embeddings.size(); ++i) {
            const auto& embedding = embeddings[i];
            searchKeyToEmbeddingIdx[NRTYServer::ComposeSearchKey(embedding.GetName(), embedding.GetVersion())].push_back(i);
        }

        std::vector<flatbuffers::Offset<NEmbeddingsScheme::TEmbedding>> embeddingOffsets;
        embeddingOffsets.reserve(searchKeyToEmbeddingIdx.size());

        for (const auto& [searchKey, embeddingIndices] : searchKeyToEmbeddingIdx) {
            // Explicit length: avoids a strlen and does not truncate keys containing '\0'.
            auto nameOffset = fb.CreateString(searchKey.data(), searchKey.size());

            std::vector<flatbuffers::Offset<NEmbeddingsScheme::TValue>> valuesOffsets;
            valuesOffsets.reserve(embeddingIndices.size());
            for (const int embedIdx : embeddingIndices) {
                const auto& embedding = embeddings[embedIdx];
                const auto& val = embedding.GetValue();

                // FlatBuffers forbids nested construction: the value vector and the tag
                // string must be finished before the TValue table is started.
                auto valueOffset = fb.CreateVector(reinterpret_cast<const i8*>(val.data()), val.size());
                flatbuffers::Offset<flatbuffers::String> tagOffset;
                if (embedding.HasTag()) {
                    const auto& tag = embedding.GetTag();
                    tagOffset = fb.CreateString(tag.data(), tag.size());
                } else {
                    tagOffset = fb.CreateString("");
                }
                valuesOffsets.push_back(NEmbeddingsScheme::CreateTValue(fb, valueOffset, tagOffset));
            }

            auto valuesVectorOffset = fb.CreateVector(valuesOffsets);
            embeddingOffsets.push_back(NEmbeddingsScheme::CreateTEmbedding(fb, nameOffset, valuesVectorOffset));
        }

        auto documentEmbeddings = NEmbeddingsScheme::CreateTDocumentEmbeddings(fb, fb.CreateVectorOfSortedTables(&embeddingOffsets));
        fb.Finish(documentEmbeddings);

        flatbuffers::DetachedBuffer buffer = fb.Release();
        s.Write(buffer.data(), buffer.size());
    }
}

namespace NRTYServer {
    // Serializes the parsed document's embeddings with Save() and hands the resulting
    // blob to WriteRawData() for this component's entity.
    // Returns without writing anything when the document has no embeddings, or when
    // GetComponentEntity() yields a null entity for this context.
    void TDssmParser::Parse(TParsingContext& context) const {
        const TMessage::TDocument& doc = context.Document;
        const auto& docEmbeddings = doc.GetEmbeddings();
        if (docEmbeddings.empty()) {
            return;
        }

        auto componentEntity = GetComponentEntity(context);
        if (componentEntity == nullptr) {
            return;
        }

        TBufferOutput serialized;
        Save(serialized, docEmbeddings);
        WriteRawData(componentEntity, context, TBlob::FromBuffer(serialized.Buffer()));
    }
}
