#include "dssm_manager.h"

#include <saas/rtyserver/components/dssm/embeddings_storage.fbs.h>
#include <saas/rtyserver/factors/function.h>

#include <kernel/dssm_applier/begemot/production_data.h>
#include <kernel/dssm_applier/utils/utils.h>
#include <library/cpp/string_utils/base64/base64.h>
#include <library/cpp/json/json_writer.h>

#include <util/stream/mem.h>

namespace {
    // JSON keys written by TDssmManager::GetDocumentEmbeddingsJson for every exported embedding object.
    constexpr TStringBuf NAME_ATTRIBUTE{"name"};
    constexpr TStringBuf VALUES_ATTRIBUTE{"values"};

    class DocDssmDecompressionResult: public NRTYFactors::IFactorCalculate {
    public:
        DocDssmDecompressionResult() {
            SetDependsOnDoc();
        }

        explicit DocDssmDecompressionResult(TVector<TVector<float> >&& result, TMaybe<TVector<TString>>&& tags)
            : Value(std::move(result))
            , Tags(std::move(tags)) { }

        TVector<float> DoCalcFloatVector(TCalcFactorsContext& /*ctx*/) const override {
            Y_ENSURE(Value.size() == 1);
            return Value[0];
        }

        TVector<TVector<float>> DoCalcFloatVectors(TCalcFactorsContext& /*ctx*/) const override {
            return Value;
        }
        TVector<TString> DoCalcBlobVector(TCalcFactorsContext& /*ctx*/) const override {
            if (Tags.Defined()) {
                return Tags.GetRef();
            } else {
                return TVector<TString>();
            }
        }

        TStringBuf GetCalculatorName() const override {
            return "doc_dssm_decompression";
        }

    private:
        TVector<TVector<float>> Value;
        TMaybe<TVector<TString>> Tags;
    };

    class TDocDssmDecompressionResultWithVersions: public NRTYFactors::IFactorCalculate {
    public:
        TDocDssmDecompressionResultWithVersions() = default;
        explicit TDocDssmDecompressionResultWithVersions(TVector<TVector<float>>&& embeddings, TVector<TString>&& versions)
            : Embeddings(std::move(embeddings))
            , Versions(std::move(versions))
        {}

        TVector<TVector<float>> DoCalcFloatVectors(TCalcFactorsContext&) const override {
            return Embeddings;
        }

        TVector<TString> DoCalcBlobVector(TCalcFactorsContext&) const override {
            return Versions;
        }
    private:
        TVector<TVector<float>> Embeddings;
        TVector<TString> Versions;
    };

    // Decodes a body produced by TFloat2UI8Compressor (one byte per coordinate) back into floats
    // and normalizes the resulting vector in place before returning it.
    TVector<float> AutoMaxCoordRenormDecompression(const ui8* data, size_t size) {
        const TConstArrayRef<ui8> compressed(data, size);
        TVector<float> restored = NDssmApplier::NUtils::TFloat2UI8Compressor::Decompress(compressed);
        NNeuralNetApplier::NormalizeVector(restored);
        return restored;
    }

    // Reinterprets `data` as a packed array of raw 32-bit floats and copies it into a vector.
    // Throws (via Y_ENSURE) when `size` is not a whole number of floats.
    // NOTE(review): assumes the stored bytes use the host float layout/byte order — confirm with the writer side.
    TVector<float> Float32Decompression(const ui8* data, size_t size) {
        static_assert(sizeof(float) == 4, "Float32 embeddings are stored as 4-byte floats");
        Y_ENSURE(size % sizeof(float) == 0, "found embedding with incorrect length");
        TVector<float> decompressed(size / sizeof(float));
        // For an empty body the destination vector is empty and its data() may be nullptr;
        // memcpy with a null pointer is undefined even for a zero length, so skip the copy.
        if (size > 0) {
            memcpy(decompressed.data(), data, size);
        }
        return decompressed;
    }

    // Fallback decoder: `compressionAlgo` is parsed as an NNeuralNetApplier::EDssmModel name
    // (FromString throws on an unknown name), the model's serializer loads `data`, and the
    // latest embedding it exposes is returned after in-place normalization.
    // NOTE(review): the result of TryGetLatestEmbedding is ignored, so a failed extraction
    // yields whatever the vector holds afterwards (presumably empty) — confirm this is intended.
    TVector<float> DssmModelDecompression(const ui8* data, size_t size, TStringBuf compressionAlgo) {
        const auto model = ::FromString<NNeuralNetApplier::EDssmModel>(compressionAlgo);
        auto dataSerializer = NNeuralNetApplier::GetDssmDataSerializer(model);

        TMemoryInput input(data, size);
        dataSerializer->Load(&input);

        TVector<float> embedding;
        // Reserves `size` elements although `size` is a byte count of the serialized input.
        embedding.reserve(size);
        dataSerializer->TryGetLatestEmbedding(embedding);
        NNeuralNetApplier::NormalizeVector(embedding);
        return embedding;
    }

    // Dispatches on the compression algorithm name:
    //   "AutoMaxCoordRenorm" -> byte-quantized body, renormalized;
    //   "Float32"            -> raw float32 bytes;
    //   anything else        -> treated as a DSSM model name for the model serializer.
    TVector<float> DecompressEmbedding(const ui8* data, size_t size, TStringBuf compressionAlgo) {
        if (compressionAlgo == "AutoMaxCoordRenorm") {
            return AutoMaxCoordRenormDecompression(data, size);
        }
        if (compressionAlgo == "Float32") {
            return Float32Decompression(data, size);
        }
        return DssmModelDecompression(data, size, compressionAlgo);
    }

    // Views the blob as a TDocumentEmbeddings flatbuffer root; nullptr when the blob is null or empty.
    // The buffer is not verified here — it is read as-is.
    const NEmbeddingsScheme::TDocumentEmbeddings* ToDocumentEmbeddings(const TBlob& rawData) {
        const bool hasPayload = !rawData.IsNull() && !rawData.Empty();
        return hasPayload ? NEmbeddingsScheme::GetTDocumentEmbeddings(rawData.Data()) : nullptr;
    }

    // Looks up the embedding stored under `searchKey` in the document blob.
    // Returns nullptr when the blob is empty, when the document has no embeddings vector,
    // or when no entry matches the key.
    const NEmbeddingsScheme::TEmbedding* ExtractEmbedding(const TBlob& rawData, const TString& searchKey) {
        const NEmbeddingsScheme::TDocumentEmbeddings* documentEmbeddings = ToDocumentEmbeddings(rawData);
        if (documentEmbeddings == nullptr) {
            return nullptr;
        }
        // Flatbuffers accessors for non-required vector fields return nullptr when the field
        // is absent, so the vector must be checked before LookupByKey.
        const auto* embeddings = documentEmbeddings->embeddings();
        if (embeddings == nullptr) {
            return nullptr;
        }
        return embeddings->LookupByKey(searchKey.c_str());
    }

    // Decompresses every value of `result` with `compressionAlgo` and wraps them into a
    // DocDssmDecompressionResult. When `extractTags` is set, one tag is collected per value,
    // with "<default_tag>" substituted for values stored without a tag.
    // Precondition (checked only by Y_ASSERT): `result` is non-null and has a non-empty values() list.
    THolder<NRTYFactors::IFactorCalculate> ConstructDssmDecompressionResult(const NEmbeddingsScheme::TEmbedding* result, TStringBuf compressionAlgo, bool extractTags) {
        const auto embedValues = result->values();
        Y_ASSERT(embedValues != nullptr && embedValues->size() >= 1);

        TVector<TVector<float>> value(Reserve(embedValues->size()));
        TMaybe<TVector<TString>> tags;
        if (extractTags) {
            tags = TVector<TString>(Reserve(embedValues->size()));
        }
        for (const auto* embedValue : *embedValues) {
            const auto* compressedVec = embedValue->body();
            value.push_back(DecompressEmbedding(reinterpret_cast<const ui8*>(compressedVec->data()), compressedVec->size(), compressionAlgo));
            if (extractTags) {
                // Copy by (data, size) instead of c_str() so tags containing NUL bytes are kept whole.
                if (const auto* tag = embedValue->tag()) {
                    tags->emplace_back(tag->data(), tag->size());
                } else {
                    tags->emplace_back("<default_tag>");
                }
            }
        }
        return MakeHolder<DocDssmDecompressionResult>(std::move(value), std::move(tags));
    }

} // empty namespace

namespace NRTYServer {
    // Builds the lookup key of an embedding: "<name>" for an empty version, "<name>_v_<version>" otherwise.
    // This is a deliberate copy of
    // https://a.yandex-team.ru/arc/trunk/arcadia/kernel/dssm_applier/utils/utils.cpp?rev=6559682#L146
    // Do not call the upstream function instead: if its format ever changes, we may lose the ability to search.
    TString ComposeSearchKey(const TString& name, const TString& version) {
        return version.empty() ? name : name + "_v_" + version;
    }

    // Dumps all embeddings stored for `docId` as a compact JSON array of objects of the form
    // {"name": <embedding name>, "values": [<base64 of each compressed body>, ...]}.
    // Bodies are emitted still compressed (base64 of the stored bytes), not decoded.
    // Returns an empty string when the document has no embeddings blob.
    TString TDssmManager::GetDocumentEmbeddingsJson(ui32 docId) const {
        const TBlob rawData = FetchRawL2Data(docId);
        const NEmbeddingsScheme::TDocumentEmbeddings* documentEmbeddings = ToDocumentEmbeddings(rawData);
        if (documentEmbeddings == nullptr) {
            return {};
        }

        NJson::TJsonValue resultJson(NJson::JSON_ARRAY);
        // `embeddings` and `name` are flatbuffers fields whose accessors return nullptr when the
        // field is absent; guard them instead of dereferencing unconditionally.
        if (const auto* embeddings = documentEmbeddings->embeddings()) {
            for (const NEmbeddingsScheme::TEmbedding* embedding : *embeddings) {
                NJson::TJsonValue item(NJson::JSON_MAP);

                if (const auto* name = embedding->name()) {
                    item.InsertValue(NAME_ATTRIBUTE, TStringBuf{name->data(), name->size()});
                }

                if (embedding->values() != nullptr) {
                    NJson::TJsonValue values(NJson::JSON_ARRAY);
                    for (const auto& value : *embedding->values()) {
                        const auto& compressedVec = value->body();
                        values.AppendValue(Base64Encode(TStringBuf{reinterpret_cast<const TStringBuf::char_type*>(compressedVec->data()), compressedVec->size()}));
                    }
                    item.InsertValue(VALUES_ATTRIBUTE, std::move(values));
                }

                resultJson.AppendValue(std::move(item));
            }
        }
        return NJson::WriteJson(resultJson, false /*formatOutput*/, false /*sortkeys*/, false /*validateUtf8*/);
    }

    // Shared implementation of the doc_dssm_* user functions.
    // args[0] -> embedding name, args[1] -> compression algorithm, args[2] (optional) -> version.
    // Throws TFactorCalculateDefaultValueException when no non-empty embedding exists for the key,
    // or, if `checkSingle` is set, when more than one value is stored under it.
    THolder<NRTYFactors::IFactorCalculate> TDssmManager::GetDecompressedEmbeddingInternal(
        TCalcFactorsContext& ctx, TRTYSearchStatistics& searchStatistics, TArrayRef<const NRTYFactors::IFactorCalculate*> args, bool checkSingle, bool extractTags) const
    {
        searchStatistics.ReportDocumentDssmRequested();

        const TBlob blob = FetchRawL2Data(ctx.DocId);
        const TString name = args[0]->CalcBlob(ctx);
        const TString version = args.size() > 2 ? args[2]->CalcBlob(ctx) : TString();
        const TString searchKey = ComposeSearchKey(name, version);

        const NEmbeddingsScheme::TEmbedding* found = ExtractEmbedding(blob, searchKey);
        const auto* storedValues = found ? found->values() : nullptr;
        const size_t valueCount = storedValues ? storedValues->size() : 0;

        if (valueCount == 0) {
            searchStatistics.ReportDocumentDssmNotFound();
            DEBUG_LOG << "No embedding '" << searchKey << "' is found" << Endl;
            ythrow NRTYFactors::TFactorCalculateDefaultValueException() << "No embedding '" << searchKey << "' is found";
        }
        if (checkSingle && valueCount > 1) {
            searchStatistics.ReportMultipleDocumentDssmFound();
            DEBUG_LOG << "Multiple embeddings '" << searchKey << "' are found" << Endl;
            ythrow NRTYFactors::TFactorCalculateDefaultValueException() << "Multiple embeddings '" << searchKey << "' are found";
        }

        const TString algo = args[1]->CalcBlob(ctx);
        return ConstructDssmDecompressionResult(found, algo, extractTags);
    }

    // doc_dssm_decompress: exactly one embedding is required; no tags are collected.
    THolder<NRTYFactors::IFactorCalculate> TDssmManager::GetDecompressedEmbedding(TCalcFactorsContext& ctx, const TRTYFunctionCtx& uctx, TArrayRef<const NRTYFactors::IFactorCalculate*> args) const {
        Y_ASSERT(uctx.SearchStatistics);
        constexpr bool checkSingle = true;
        constexpr bool extractTags = false;
        return GetDecompressedEmbeddingInternal(ctx, *uctx.SearchStatistics, args, checkSingle, extractTags);
    }

    // doc_dssm_array_decompress: any number of embeddings is accepted; no tags are collected.
    THolder<NRTYFactors::IFactorCalculate> TDssmManager::GetDecompressedEmbeddings(TCalcFactorsContext& ctx, const TRTYFunctionCtx& uctx, TArrayRef<const NRTYFactors::IFactorCalculate*> args) const {
        Y_ASSERT(uctx.SearchStatistics);
        constexpr bool checkSingle = false;
        constexpr bool extractTags = false;
        return GetDecompressedEmbeddingInternal(ctx, *uctx.SearchStatistics, args, checkSingle, extractTags);
    }

    // doc_dssm_array_decompress_with_tags: any number of embeddings, each paired with its tag.
    THolder<NRTYFactors::IFactorCalculate> TDssmManager::GetDecompressedEmbeddingsWithTags(TCalcFactorsContext& ctx, const TRTYFunctionCtx& uctx, TArrayRef<const NRTYFactors::IFactorCalculate*> args) const {
        Y_ASSERT(uctx.SearchStatistics);
        constexpr bool checkSingle = false;
        constexpr bool extractTags = true;
        return GetDecompressedEmbeddingInternal(ctx, *uctx.SearchStatistics, args, checkSingle, extractTags);
    }

    // doc_dssm_array_decompress_with_versions:
    // args[0] -> embedding name, args[1] -> compression algorithm, args[2..] -> versions to look up.
    // Versions with no stored (or an empty) value list are skipped. A version holding more than one
    // value makes the whole call throw TFactorCalculateDefaultValueException.
    // The result pairs each found embedding with the version it came from.
    THolder<NRTYFactors::IFactorCalculate> TDssmManager::GetDecompressedEmbeddingsByVersion(TCalcFactorsContext& ctx, const TRTYFunctionCtx& uctx, TArrayRef<const NRTYFactors::IFactorCalculate*> args) const {
        Y_ASSERT(uctx.SearchStatistics);
        TRTYSearchStatistics& searchStatistics = *uctx.SearchStatistics;

        searchStatistics.ReportDocumentDssmRequested();
        const TBlob rawData = FetchRawL2Data(ctx.DocId);
        const NEmbeddingsScheme::TDocumentEmbeddings* documentEmbeddings = ToDocumentEmbeddings(rawData);
        // The flatbuffers accessor for the embeddings vector returns nullptr when the field is absent.
        const auto* storedEmbeddings = documentEmbeddings ? documentEmbeddings->embeddings() : nullptr;

        TVector<TString> versions;
        TVector<TVector<float>> embeddings;
        if (storedEmbeddings == nullptr) {
            searchStatistics.ReportDocumentDssmNotFound();
            return MakeHolder<TDocDssmDecompressionResultWithVersions>(std::move(embeddings), std::move(versions));
        }

        const TString embeddingName = args[0]->CalcBlob(ctx);
        const TString compressionAlgo = args[1]->CalcBlob(ctx);
        const size_t requestedVersions = args.size() > 2 ? args.size() - 2 : 0;
        versions.reserve(requestedVersions);
        embeddings.reserve(requestedVersions);

        for (size_t i = 2; i < args.size(); ++i) {
            TString version = args[i]->CalcBlob(ctx);
            const TString searchKey = ComposeSearchKey(embeddingName, version);

            const NEmbeddingsScheme::TEmbedding* embedding = storedEmbeddings->LookupByKey(searchKey.c_str());
            const auto* values = embedding ? embedding->values() : nullptr;
            // An empty value list means nothing is stored for this version: skip it like a missing
            // key instead of misreporting it as "multiple embeddings".
            if (values == nullptr || values->size() == 0) {
                continue;
            }
            if (values->size() > 1) {
                searchStatistics.ReportMultipleDocumentDssmFound();
                ythrow NRTYFactors::TFactorCalculateDefaultValueException() << "Multiple embeddings '" << searchKey << "' are found";
            }

            const auto* embeddingValue = (*values)[0]->body();
            embeddings.push_back(DecompressEmbedding(
                reinterpret_cast<const ui8*>(embeddingValue->data()),
                embeddingValue->size(),
                compressionAlgo
            ));
            versions.push_back(std::move(version));
        }
        return MakeHolder<TDocDssmDecompressionResultWithVersions>(std::move(embeddings), std::move(versions));
    }

    // Registers this manager's user functions with the builder.
    // NOTE(review): the trailing integer pair passed to Add<> looks like the min/max argument
    // count of each function — confirm against TImportedFunctionsBuilder.
    void TDssmManager::GetExportedFunctions(NRTYFeatures::TImportedFunctionsBuilder& exports) const {
        using TGenericFunc = NRTYFeatures::TFactorCalcerGenericUserFunc;

        // Per-document dump of all stored embeddings as JSON.
        exports.AddGta(&TDssmManager::GetDocumentEmbeddingsJson, this, "EmbeddingsJson");

        // (name, compression algo[, version]) based lookups.
        exports.Add<TGenericFunc>(&TDssmManager::GetDecompressedEmbedding, this, "doc_dssm_decompress", 2, 3);
        exports.Add<TGenericFunc>(&TDssmManager::GetDecompressedEmbeddings, this, "doc_dssm_array_decompress", 2, 3);
        exports.Add<TGenericFunc>(&TDssmManager::GetDecompressedEmbeddingsWithTags, this, "doc_dssm_array_decompress_with_tags", 2, 3);

        // (name, compression algo, version, version, ...) lookup.
        exports.Add<TGenericFunc>(&TDssmManager::GetDecompressedEmbeddingsByVersion, this, "doc_dssm_array_decompress_with_versions", 3, 12);
    }
} // namespace NRTYServer
