#include "component.h"

#include "builder.h"
#include "config.h"
#include "manager.h"
#include "hnsw_search_request.h"
#include "signals.h"

#include <quality/relev_tools/knn/lib/hnsw_graph/layered_graph_factory.h>
#include <quality/relev_tools/knn/lib/hnsw_graph/simple_layered_graph_impl.h>
#include <quality/relev_tools/knn/lib/document_index_reader.h>
#include <quality/relev_tools/knn/lib/hnsw_index_holder.h>
#include <quality/relev_tools/knn/knn_ops/lib/shards_meta_info_table.h>
#include <library/cpp/containers/ext_priority_queue/ext_priority_queue.h>
#include <library/cpp/threading/future/future.h>

namespace NRTYServer {

    // Shard-level metadata. CreateManager reads it from
    // `<index dir> / NKnnOps::HnswIndexMetaInfoFileName` and parses it through NNodeIo.
    struct THnswShardMetaInfo {
        // Passed to TPackedEmbedsStorage::ReadHtxt/ReadBin together with the embeds blob;
        // presumably used there to validate the embeddings format — TODO confirm.
        TMaybe<TString> EmbedsFingerPrint;
        // Shard name handed to THnswIndexManager; CreateManager substitutes "unnamed" when absent.
        TMaybe<TString> ShardName;

        // Presumably generates NNodeIo (de)serialization for the listed fields — verify.
        Y_NODE_IO_AUTO(
            EmbedsFingerPrint,
            ShardName
        )
    };

    // The component is enabled only when HnswComponentName is listed in the server's ComponentsSet.
    // Besides the config, it sets up the search signals and a pool that runs per-segment searches.
    THnswComponent::THnswComponent(const TRTYServerConfig& config)
        : IIndexComponent(config.ComponentsSet.contains(HnswComponentName))
        , Config(config)
    {
        ComponentConfig = Config.ComponentsConfig.Get<THnswComponentConfig>(HnswComponentName);
        Signals.Reset(new THnswComponentSearchSignals);

        // The pool is non-blocking: SearchCustom runs a task inline when AddFunc rejects it.
        // Catching is off; the search tasks catch their own exceptions.
        IThreadPool::TParams poolParams;
        poolParams
            .SetBlocking(false)
            .SetCatching(false)
            .SetThreadName("HnswComponent");

        SegmentsTreadPool = CreateThreadPool(
            ComponentConfig->SegmentSearchersThreadsNum,
            ComponentConfig->SegmentSearchersQueueLimit,
            poolParams
        );
    }

    // The name the component is registered under; managers are looked up by it as well.
    TString THnswComponent::GetName() const {
        return TString(HnswComponentName);
    }

    // Always false. Search for this component presumably goes through SearchCustom rather than
    // the common search path — verify against IIndexComponent.
    bool THnswComponent::IsCommonSearch() const {
        constexpr bool isCommonSearch = false;
        return isCommonSearch;
    }

    // No component-specific priority: the base-class value is returned unchanged.
    IIndexComponent::TPriorityInfo THnswComponent::GetPriority() const {
        return IIndexComponent::GetPriority();
    }

    // Returns a THnswIndexFakeBuilder named after this component. The construction context is
    // ignored; the HNSW index is presumably built offline — TODO confirm.
    THolder<IIndexComponentBuilder> THnswComponent::CreateBuilder(const TBuilderConstructionContext& /*context*/) const {
        THolder<IIndexComponentBuilder> builder = MakeHolder<THnswIndexFakeBuilder>(GetName());
        return builder;
    }

    // Builds the read-only HNSW index manager; only FINAL indexes get one (nullptr otherwise).
    //
    // Index files are read from ComponentConfig->IndexPath when it is set, otherwise from the
    // segment directory in `context`. The load order is:
    //   1. shard meta info;
    //   2. packed embeddings;
    //   3. layered neighbour graph;
    //   4. document index.
    // For (2) and (3) the *Htxt file is used when it exists, otherwise the *Bin file.
    // Exceptions from the file readers are not caught here and propagate to the caller.
    THolder<IIndexComponentManager> THnswComponent::CreateManager(const TManagerConstructionContext& context) const {
        if (context.IndexType != IIndexController::FINAL) {
            return nullptr;
        }

        TFsPath path = ComponentConfig->IndexPath;
        if (!path) {
            path = context.Dir.PathName();
        }

        THnswShardMetaInfo meta = NNodeIo::FromNode(NNodeIo::NodeFromString(
            TUnbufferedFileInput(path / NKnnOps::HnswIndexMetaInfoFileName).ReadAll()));

        // Maps `file` with locked pages when TryUseLockedMapingsForAllFiles is set. It falls back
        // to a precharged mapping when locking is disabled, throws, or yields an empty blob.
        auto tryGetLockedBlob = [this] (const TFsPath& file) -> TBlob {
            if (ComponentConfig->TryUseLockedMapingsForAllFiles) {
                try {
                    TBlob locked = TBlob::LockedFromFile(file);
                    if (locked.Size() != 0) {
                        return locked;
                    }
                } catch (...) {
                    // Best-effort fallback, but keep the failure reason so it can be diagnosed.
                    ERROR_LOG << "tryed to get locked-file for " << file.GetPath()
                              << " but failed, fall back to anonymous memory: "
                              << CurrentExceptionMessage() << "\n";
                }
            }
            return TBlob::PrechargedFromFile(file);
        };

        NKnnOps::TPackedEmbedsStorage embedsStorage;
        if (NFs::Exists(path / NKnnOps::HnswEmbedsBlobIndexFileNameHtxt)) {
            embedsStorage = NKnnOps::TPackedEmbedsStorage::ReadHtxt(
                tryGetLockedBlob(path / NKnnOps::HnswEmbedsBlobIndexFileNameHtxt), meta.EmbedsFingerPrint
            );
        } else {
            embedsStorage = NKnnOps::TPackedEmbedsStorage::ReadBin(
                tryGetLockedBlob(path / NKnnOps::HnswEmbedsBlobIndexFileNameBin), meta.EmbedsFingerPrint
            );
        }

        THolder<NHnswGraph::ILayeredGraph> layeredGraph;
        if (NFs::Exists(path / NKnnOps::HnswNearedGraphIndexFileNameHtxt)) {
            TBlob graphBlob = tryGetLockedBlob(path / NKnnOps::HnswNearedGraphIndexFileNameHtxt);
            layeredGraph = NHnswGraph::TLayeredGraphFactory::ReadLayeredGraphFromHtxt(graphBlob);
        } else {
            TBlob graphBlob = tryGetLockedBlob(path / NKnnOps::HnswNearedGraphIndexFileNameBin);
            layeredGraph = NHnswGraph::TSimpleLayeredGraphImpl::ReadLibHnswIndexFormat(graphBlob);
        }

        // embedsStorage is not used after this point, so hand it over instead of copying it.
        TAtomicSharedPtr<NKnnOps::THnswIndexHolder> hnswGraph = new NKnnOps::THnswIndexHolder(
            std::move(embedsStorage),
            std::move(layeredGraph)
        );

        // Direct mode and locked mapping are mutually exclusive, as are locking and precharging the FAT.
        TAtomicSharedPtr<NKnnOps::TDocumentIndexReader> docIndex = new NKnnOps::TDocumentIndexReader(
            path,
            {
                .UseDirectMode = ComponentConfig->DirectModeForKvData,
                .TryUseLockedFile = !ComponentConfig->DirectModeForKvData,
                .LockFat = ComponentConfig->LockDocumentIndexFat,
                .PrechargeFat = !ComponentConfig->LockDocumentIndexFat
            }
        );

        return MakeHolder<THnswIndexManager>(
            meta.ShardName.GetOrElse("unnamed"), *ComponentConfig, std::move(hnswGraph), std::move(docIndex));
    }

    // Parser used for incoming documents: the HNSW component takes nothing from them,
    // so Parse deliberately does nothing.
    class TDummyComponentParser: public TComponentParser {
    public:
        void Parse(TParsingContext& /*context*/) const override {}
    };

    // Returns a fresh no-op parser (see TDummyComponentParser).
    IComponentParser::TPtr THnswComponent::BuildParser() const {
        return new TDummyComponentParser();
    }

    // Placeholder parsed entity. It adds no state of its own and only forwards
    // construction to the base TParsedEntity.
    class TDummyParsedEntity: public TParsedDocument::TParsedEntity {
    public:
        // explicit: construction parameters should never silently convert into an entity.
        explicit TDummyParsedEntity(TConstructParams& params)
            : TParsedDocument::TParsedEntity(params)
        {
        }
    };

    // Wraps `params` into a TDummyParsedEntity; nothing component-specific is stored.
    IParsedEntity::TPtr THnswComponent::BuildParsedEntity(IParsedEntity::TConstructParams& params) const {
        IParsedEntity::TPtr entity = new TDummyParsedEntity(params);
        return entity;
    }

    // Normalization check: there is nothing to validate, so the index is always reported as fine.
    bool THnswComponent::DoAllRight(const TNormalizerContext&) const {
        const bool allRight = true;
        return allRight;
    }

    // Normalization fix-up: intentionally empty, there is nothing to repair.
    void THnswComponent::CheckAndFix(const TNormalizerContext&) const {
    }

    // Exposes IndexFiles, which is declared outside this file (its contents are not set here).
    const IIndexComponent::TIndexFiles& THnswComponent::GetIndexFiles() const {
        const IIndexComponent::TIndexFiles& files = IndexFiles;
        return files;
    }

    // Merge hook: performs no work and always reports success; the context is ignored.
    bool THnswComponent::DoMerge(const TMergeContext&) const {
        const bool merged = true;
        return merged;
    }

    // One merged search hit in SearchCustom. The hit's distance is not stored here; it is
    // the priority the hit is pushed into the top-N heap with.
    struct THnswComponentFoundDocId {
        // Segment manager that found the document and is later asked to report it.
        const THnswIndexManager* ControllerPtr{nullptr};
        // Document id inside that segment; defaults to Max(), presumably an "unset" marker.
        ui32 InPartId{Max()};
    };

    void THnswComponent::SearchCustom(const TVector<IIndexController::TPtr>& controllers,
                                      ICustomReportBuilder& report,
                                      const TRTYSearchRequestContext& context) const {
        Y_IF_DEBUG(DEBUG_LOG << "run THnswComponent::SearchCustom " << Endl;)
        THnswSearchRequest multiRequest(context);

        THnswComponentSearchStats fullStats;
        //NOTE: cpu-optimization can be performed here: local resulsts are sorted, so we can use merge-sort instead heap
        for(size_t reqId : xrange(multiRequest.MultiRequest.RequestsSize())) {
            const NKnnService::NProtos::TRequest& request = multiRequest.MultiRequest.GetRequests(reqId);
            Y_ENSURE(request.GetShardTopSize() < ComponentConfig->MaxAllowedTopSize);
            Y_ENSURE(request.GetOwerwriteSearchSize() < ComponentConfig->MaxAllowedSearchSize);

            using TValue = THolder<THnswIndexManager::TSearchResult>;
            TVector<NThreading::TFuture<
                TValue
            >> segmentsResults(controllers.size());

            TInstant tasksStart = Now();

            for (auto taskId : xrange(controllers.size())) {
                auto promise = NThreading::NewPromise<TValue>();
                segmentsResults[taskId] = promise.GetFuture();
                auto function = [&, taskId, promise] () mutable {
                    try {
                        auto& controller = controllers[taskId];
                        auto hnsw = dynamic_cast<const THnswIndexManager*>(controller->GetManager(HnswComponentName));
                        VERIFY_WITH_LOG(hnsw, "failed dyn cast");
                        TValue res = THolder(new THnswIndexManager::TSearchResult);
                        hnsw->FindNearest(request, *res);
                        promise.SetValue(std::move(res));
                    } catch (...) {
                        promise.SetException(std::exception_ptr());
                    }
                };

                if (!SegmentsTreadPool->AddFunc(function)) {
                    function();
                }
            }
            NThreading::WaitExceptionOrAll(segmentsResults).GetValueSync();

            TPriorityTopN<THnswComponentFoundDocId, float> resultsHeap(request.GetShardTopSize());
            size_t usedSearchSize = 0;
            THashSet<std::pair<TString, TString>> logLines;
            for(auto& future : segmentsResults) { //merging
                auto& r = *future.GetValueSync();
                for (const auto localResult : r.Nearest) {
                    resultsHeap.push(THnswComponentFoundDocId{
                        .ControllerPtr = r.Self,
                        .InPartId = localResult.DocId
                    }, localResult.DotProductDist);
                }
                usedSearchSize = r.UsedSearchSize;
                for(auto& l : r.LogLines) {
                    logLines.insert(l);
                }

                fullStats.UsedDotProductsNum +=
                    r.HnswStats.CheckedCandidatesOnEntryChoose + r.HnswStats.CheckedCandidatesOnMainStage;
                fullStats.InSegmentsQueueSumWaitTime += r.SearchStartTime - tasksStart;
            }

            for (auto& l : logLines) {
                report.AddReportProperty(l.first, l.second);
            }
            if (usedSearchSize) {
                report.AddReportProperty("UsedSearchSize", usedSearchSize);
            }

            {
                auto& container = resultsHeap.Container();
                Sort(container);

                if (!container.empty()) {
                    Y_ASSERT(container.rbegin()->Pri >= container.rend()->Pri);
                }

                for (auto it = container.begin(); it != container.end(); ++it) {
                    const float clampedDist = ComponentConfig->ResultDistanceClampInterval.Clamp(it->Pri);
                    if (clampedDist != it->Pri) {
                        fullStats.ClampedValsNum += 1;
                    }
                    it->Data.ControllerPtr->AddDocToReport(
                        reqId, request,
                        report, context.CgiParams(),
                        it->Data.InPartId, clampedDist
                    );
                }
            }
        }


        Signals->PushSearchStats(fullStats);
        Y_IF_DEBUG(DEBUG_LOG << "done THnswComponent::SearchCustom " << Endl;)
    }
}
