#include "manager.h"

#include "config.h"
#include "hnsw_search_request.h"

#include <saas/library/proto_helper/proto_helper.h>
#include <saas/rtyserver/config/config.h>
#include <saas/rtyserver/components/fullarchive/manager.h>
#include <saas/rtyserver/model/index.h>
#include <library/cpp/protobuf/json/proto2json.h>
#include <kernel/knn_service/component_api.h>
#include <kernel/knn_service/string_constants.h>
#include <util/thread/pool.h>

#include <utility>

namespace NRTYServer {

    // Opening is a no-op for this manager: its HNSW graph and document index are handed in
    // through the constructor, so there is nothing to acquire here and opening always succeeds.
    bool THnswIndexManager::DoOpen()
    {
        const bool opened = true;
        return opened;
    }

    // Closing is a no-op for this manager: nothing is released here and closing always succeeds.
    bool THnswIndexManager::DoClose()
    {
        const bool closed = true;
        return closed;
    }

    // Does not report a real document count; always returns Max<ui32>().
    // NOTE(review): presumably a sentinel meaning "unknown / not tracked" — confirm with callers.
    ui32 THnswIndexManager::GetDocumentsCount() const
    {
        const ui32 unknownCount = Max<ui32>();
        return unknownCount;
    }

    // Sets up no interactions with other index managers; the storage argument is unused.
    void THnswIndexManager::InitInteractions(const IIndexManagersStorage& /*storage*/)
    {
        // Intentionally empty.
    }

    // Ignores the given docids and returns 0.
    // NOTE(review): the return value is presumably the number of removed documents — confirm
    // against the IIndexComponentManager contract.
    ui32 THnswIndexManager::RemoveDocids(const TVector<ui32>& /*docIds*/)
    {
        const ui32 removedCount = 0;
        return removedCount;
    }

    // The generic RTY search entry point is not supported by the HNSW manager:
    // every call throws yexception with the message "unimplemented".
    ERTYSearchResult THnswIndexManager::DoSearch(const TRTYSearchRequestContext& /*context*/,
                                                 ICustomReportBuilder& /*report*/,
                                                 const IIndexController& /*controller*/) const
    {
        // Same exception Y_ENSURE(false, "unimplemented") raises, thrown unconditionally so it is
        // evident that control never falls off the end of this non-void function.
        ythrow yexception() << "unimplemented";
    }

    // Builds one NMetaProtocol::TDocument for a single search hit and appends it to the report.
    //
    // The document always carries the docid (stringified) and the distance from the query
    // (via NKnnService::TResponseTraits::SetDistanceFromQuery). Depending on request flags it
    // additionally gets, as GTA-related attributes / fields:
    //   - the shard id (C-escaped manager Name), unless SkipShardId is set;
    //   - the request id (the `requestId` argument), unless SkipRequestId is set;
    //   - when NeedDocData is set, data from the document index:
    //       * title and url (both filled from the stored document's URL) unless skipped,
    //       * the body, if NeedDocBody is set,
    //       * stored document properties, restricted to DocFieldsFilters when that list is
    //         non-empty (property names are C-escaped).
    //     Body and property values are Base64-encoded when Base64EncodeFields is set.
    //
    // The document is added to grouping TargetGroupingName under group TargetGroupName.
    // If RewriteGroupNameByGtaField is set and a stored property with that name exists, the
    // property's value becomes the group name instead (the last matching property wins).
    // Group values are only rewritten when NeedDocData is set, since properties are read only then.
    //
    // Throws (Y_ENSURE) if NeedDocData is set and `docId` is missing from the document index.
    void THnswIndexManager::AddDocToReport(
        size_t requestId,
        const NKnnService::NProtos::TRequest& request,
        ICustomReportBuilder& reportBuilder,
        const TCgiParameters& /*parameters*/,
        ui32 docId,
        float dotProductDist) const
    {
        NMetaProtocol::TDocument doc;
        TString groupVal = request.GetTargetGroupName();

        // Appends one GTA-related attribute. MutableArchiveInfo() is called lazily so that
        // ArchiveInfo stays unset when no attribute is ever added (e.g. every Skip* flag is on).
        auto addGtaAttr = [&doc](const TString& key, const TString& value) {
            auto* attr = doc.MutableArchiveInfo()->AddGtaRelatedAttribute();
            attr->SetKey(key);
            attr->SetValue(value);
        };

        doc.SetDocId(ToString(docId));
        if (!request.GetSkipShardId()) {
            addGtaAttr(TString{NKnnService::ShardIdDocAttrName}, EscapeC(this->Name));
        }
        if (!request.GetSkipRequestId()) {
            addGtaAttr(TString{NKnnService::RequestIdInDocAttrName}, ToString(requestId));
        }

        if (request.GetNeedDocData()) {
            TMaybe<NKnnDocumentIndex::TParsedDoc> storedPb = DocIndex->GetSaasDocumentData(docId);
            Y_ENSURE(!!storedPb, "docid " << docId << " not exists in doc-index");
            const auto& storedDoc = storedPb->GetDocument();

            const bool base64 = request.GetBase64EncodeFields();
            auto encode = [base64](const TString& value) -> TString {
                return base64 ? Base64Encode(value) : value;
            };

            if (!request.GetSkipTitle()) {
                // NOTE(review): the title is filled from the URL, not from a title field —
                // confirm this is intended.
                doc.MutableArchiveInfo()->SetTitle(storedDoc.GetUrl());
            }
            if (!request.GetSkipUrl()) {
                doc.SetUrl(storedDoc.GetUrl());
            }

            if (request.GetNeedDocBody()) {
                addGtaAttr(TString{NKnnService::BodyGtaAttrName}, encode(storedDoc.GetBody()));
            }

            // Bind filters by reference: the previous by-value loop copied every string.
            THashSet<TString> filters;
            for (const auto& filter : request.GetDocFieldsFilters()) {
                filters.insert(filter);
            }

            const bool rewriteGroup = request.HasRewriteGroupNameByGtaField();
            for (const auto& prop : storedDoc.GetDocumentProperties()) {
                if (filters.empty() || filters.contains(prop.GetName())) {
                    addGtaAttr(EscapeC(prop.GetName()), encode(prop.GetValue()));
                }
                if (rewriteGroup && prop.GetName() == request.GetRewriteGroupNameByGtaField()) {
                    groupVal = prop.GetValue();
                }
            }
        }
        NKnnService::TResponseTraits::SetDistanceFromQuery(doc, dotProductDist);

        reportBuilder.AddDocumentToGroup(
            doc, EscapeC(request.GetTargetGroupingName()), EscapeC(groupVal));
    }

    // The KNN-request search entry point is not implemented by this manager:
    // every call throws yexception with the message "unimplemented".
    ui32 THnswIndexManager::DoSearch(const NKnnService::NProtos::TRequest& /*request*/,
                                     const TCgiParameters& /*parameters*/,
                                     ICustomReportBuilder& /*reportBuilder*/) const
    {
        // Equivalent to Y_ENSURE(false, "unimplemented"), written as an unconditional throw.
        ythrow yexception() << "unimplemented";
    }

    // Runs a nearest-neighbour query against this manager's HNSW graph and records the outcome in `dst`.
    //
    // `dst.SearchStartTime` and `dst.Self` are always filled. If fingerprint checking is enabled
    // (Config.IgnoreFingerPrints is false), the request embed carries a non-empty model fingerprint,
    // and Config.AllowedFingerprintMappings does not consider it compatible with the graph's stored
    // fingerprint, the search is skipped: a "skippedByEmbedsMissmatch" log line is recorded and the
    // remaining `dst` fields are left untouched.
    //
    // Otherwise the active-nodes size comes from Config.SearchSizeChooser for the requested
    // ShardTopSize, unless the request sets OwerwriteSearchSize. The size actually used goes to
    // `dst.UsedSearchSize`, graph statistics to `dst.HnswStats`, and the hits are copied into
    // `dst.Nearest`.
    void THnswIndexManager::FindNearest(
        const NKnnService::NProtos::TRequest& request,
        TSearchResult& dst) const
    {
        dst.SearchStartTime = Now();
        dst.Self = this;

        const auto& embed = request.GetEmbed();
        const auto& queryComponents = embed.GetComponents();
        const auto& requestFingerPrint = embed.GetModelFingerPrint();

        if (!Config.IgnoreFingerPrints && requestFingerPrint) {
            // Held by value: GetOrElse("") may hand back a reference to the temporary default.
            const auto graphFingerPrint = HnswGraph->Storage.FingerPrint.GetOrElse("");
            if (!Config.AllowedFingerprintMappings.IsCompitable(requestFingerPrint, graphFingerPrint)) {
                dst.LogLines.insert({
                    "skippedByEmbedsMissmatch",
                    requestFingerPrint + "!=" + graphFingerPrint});
                return;
            }
        }

        const auto topSize = request.GetShardTopSize();
        size_t activeNodes = Config.SearchSizeChooser.GetRecommendedSearchSize(topSize);
        if (request.HasOwerwriteSearchSize()) {
            // An explicit request value always wins over the configured recommendation.
            activeNodes = request.GetOwerwriteSearchSize();
        }
        dst.UsedSearchSize = activeNodes;

        NHnswGraph::TNearestSearchOpts opts = {
            .ResultTopSize = topSize,
            .ActiveNodesSize = activeNodes,

            .MaxEntryChangesAtLevel = Max(),
            .MaxCheckedCandidatesOnMainStage = Max(),

            .WriteHigherPrioritiesFirst = true,
        };
        auto hits = HnswGraph->FindNearest(queryComponents, opts, &dst.HnswStats, Nothing());
        dst.Nearest.assign(hits.begin(), hits.end());
    }

    // Creates the manager, registering it under HnswComponentName, from its shard name, component
    // config, HNSW graph holder and document-index reader.
    // The by-value sink parameters are moved into the members. This avoids an extra string copy and
    // the extra atomic refcount inc/dec that copying a TAtomicSharedPtr costs.
    // NOTE(review): `config` is taken by const reference; if the Config member is itself a
    // reference, the caller must keep `config` alive as long as this manager — confirm.
    THnswIndexManager::THnswIndexManager(TString name, const THnswComponentConfig & config, TAtomicSharedPtr<NKnnOps::THnswIndexHolder> graph, TAtomicSharedPtr<NKnnOps::TDocumentIndexReader> docIndex)
        : IIndexComponentManager(HnswComponentName)
        , HnswGraph(std::move(graph))
        , DocIndex(std::move(docIndex))
        , Config(config)
        , Name(std::move(name))
    {
    }
}
