#include "hnsw_search_request.h"

#include "config.h"

#include <saas/rtyserver/search/context/rty_search_context.h>
#include <kernel/knn_service/string_constants.h>
#include <kernel/knn_service/float_packing.h>
#include <util/string/cast.h>
#include <library/cpp/string_utils/base64/base64.h>
#include <util/string/split.h>
#include <util/string/type.h>


namespace NRTYServer {

 struct TKnnLegacyParams {
    // Flat (legacy) HNSW search parameters, used when the request does not carry
    // a serialized multi-request. Values are read from the HNSW component's
    // parameters; the query embedding may additionally come from the plain
    // "text" CGI parameter.
    //
    // Also available: cgi-params in full-arc style (see manager.cpp):
    //   key_name
    //   "gta" -> "_Body"
    //   "gta" -> "_Timestamp"
    //   "gta" -> "_UrlOnly"
    //   "gta" -> <attr_only>
    //   "hr" -> IsTrue
    //   "laas_json"
    //   "unpack_json" -> IsTrue
    //   "gta" -> "_AllDocInfos"
    //   "normal_kv_report" -> IsTrue

    TMaybe<TString> Text;                       // last comma-separated embedding text that was parsed (component text param or "text" CGI param)
    TMaybe<TString> GroupingAttr;               // component param: grouping_attr
    TMaybe<TString> EmbedFingerPrint;           // component param: NKnnService::ModelFingerPrintParamCgiName
    TVector<float> ParsedEmbed;                 // query embedding components
    THashSet<TString> Base64ProtectedGtaFields; // component param: NKnnService::Base64ProtectedFieldsParamCgiName (comma-separated)
    ui32 TopSize = 10;                          // component param: top_size
    TMaybe<ui32> SearchSize;                    // component param: search_size
    TMaybe<bool> Base64EncodeFields;            // component param: base64_encode_fields

    // Reads all legacy parameters from `context`. If the HNSW component has no
    // parameters at all, every field keeps its default and ParsedEmbed stays empty.
    // FromString / ParseInto throw when a numeric value cannot be parsed.
    explicit TKnnLegacyParams(const TRTYSearchRequestContext& context) {
        const auto compParams = context.GetParameters(HnswComponentName);
        if (compParams == nullptr) {
            return;
        }

        //NOTE: string constants must be synchronized with kernel/knn_service/component_api.cpp
        if (const auto* topSize = compParams->FindPtr(NKnnService::TopSizeParamCgiName)) {
            TopSize = FromString<ui32>(*topSize);
        }
        if (const auto* searchSize = compParams->FindPtr(NKnnService::SearchSizeParamCgiName)) {
            SearchSize = FromString<ui32>(*searchSize);
        }
        if (const auto* fingerPrint = compParams->FindPtr(NKnnService::ModelFingerPrintParamCgiName)) {
            EmbedFingerPrint = *fingerPrint;
        }
        if (const auto* b64encode = compParams->FindPtr(NKnnService::Base64EncodeFieldsParamCgiName)) {
            Base64EncodeFields = IsTrue(*b64encode);
        }

        // Query embedding: a packed embedding overrides the comma-separated text
        // param; the plain "text" CGI param is only a fallback, used when neither
        // of them produced any components.
        if (const auto* embText = compParams->FindPtr(NKnnService::TextEmbedParamCgiName)) {
            StringSplitter(*embText).Split(',').ParseInto(&ParsedEmbed);
            Text = *embText;
        }
        if (const auto* packedEmb = compParams->FindPtr(NKnnService::PackedEmbedParamCgiName)) {
            // Discr mode "no" selects the float decoder; any other value (or no value) selects the int8 one.
            const auto* discrMode = compParams->FindPtr(NKnnService::DiscrModeCgiName);
            const bool useFloats = discrMode && *discrMode == "no";
            if (useFloats) {
                ParsedEmbed = NKnnService::UnpackEmbedsBase64(*packedEmb);
            } else {
                ParsedEmbed = NKnnService::UnpackEmbedsAsInt8Base64(*packedEmb);
            }
        }
        if (ParsedEmbed.empty()) {
            const auto& cgi = context.CgiParams();
            const auto& text = cgi.Get("text");
            StringSplitter(text).Split(',').ParseInto(&ParsedEmbed);
            Text = text;
        }

        if (const auto* groupingAttr = compParams->FindPtr("grouping_attr")) {
            GroupingAttr = *groupingAttr;
        }
        if (const auto* protectedFields = compParams->FindPtr(NKnnService::Base64ProtectedFieldsParamCgiName)) {
            StringSplitter(*protectedFields).Split(',').AddTo(&Base64ProtectedGtaFields);
        }
    }
};

// Sanity check for a single KNN request: the number of embedding components
// must equal the dimension declared by the request itself. Y_ENSURE throws on
// mismatch; the message carries both numbers so the caller's error log is actionable.
static void BasicValidateRequest(const NKnnService::NProtos::TRequest& req) {
    const auto& embed = req.GetEmbed();
    Y_ENSURE(embed.ComponentsSize() == embed.GetExpectedDim(),
             "embed has " << embed.ComponentsSize() << " components, but expected dim is " << embed.GetExpectedDim());
}

// Builds the HNSW search request from the HNSW component parameters of `ctx`.
//
// Two input formats are accepted:
//  * NKnnService::MultiRequestParamName: a base64-encoded serialized multi-request proto;
//  * otherwise the legacy flat parameters (see TKnnLegacyParams), which are turned
//    into a single request that asks for doc data, doc body and all doc fields.
//
// Every resulting request is checked with BasicValidateRequest. Any yexception
// raised while building the request is logged and rethrown unchanged.
THnswSearchRequest::THnswSearchRequest(const TRTYSearchRequestContext& ctx) {
    const auto compParams = ctx.GetParameters(HnswComponentName);

    try {
        //NOTE: string constants must be synchronized with kernel/knn_service/component_api.cpp
        Y_ENSURE(compParams != nullptr, "no parameters for the HNSW component");

        if (const auto* packedMultiRequest = compParams->FindPtr(NKnnService::MultiRequestParamName)) {
            // A malformed payload must fail loudly instead of silently producing
            // an empty or partially parsed multi-request.
            Y_ENSURE(MultiRequest.ParseFromString(Base64Decode(*packedMultiRequest)),
                     "failed to parse multi-request proto");
        } else {
            // No multi-request given -> build a single request from the legacy params.
            const TKnnLegacyParams legacyParams(ctx);
            Y_ENSURE(!legacyParams.ParsedEmbed.empty(), "legacy request has an empty embedding");

            NKnnService::NProtos::TRequest& request = *MultiRequest.AddRequests();

            auto* embed = request.MutableEmbed();
            for (const float component : legacyParams.ParsedEmbed) {
                embed->MutableComponents()->Add(component);
            }
            embed->SetExpectedDim(legacyParams.ParsedEmbed.size());
            if (legacyParams.EmbedFingerPrint) {
                embed->SetModelFingerPrint(*legacyParams.EmbedFingerPrint);
            }

            for (const TString& field : legacyParams.Base64ProtectedGtaFields) {
                request.AddDocFieldsFilters(field);
            }

            request.SetNeedDocData(true);
            request.SetNeedDocBody(true);
            request.SetFetchAllDocFields(true);
            if (legacyParams.GroupingAttr) {
                request.SetRewriteGroupNameByGtaField(*legacyParams.GroupingAttr);
            }

            if (legacyParams.Base64EncodeFields) {
                request.SetBase64EncodeFields(*legacyParams.Base64EncodeFields);
            }

            request.SetShardTopSize(legacyParams.TopSize);

            if (legacyParams.SearchSize) {
                request.SetOwerwriteSearchSize(*legacyParams.SearchSize);
            }
        }

        for (const auto& req : MultiRequest.GetRequests()) {
            BasicValidateRequest(req);
        }
    } catch (const yexception& e) {
        ERROR_LOG << "Failed to parse request: " << e.what() << Endl;
        throw;
    }
}

}
