#include "search.h"
#include "context.h"
#include "replier.h"

#include <saas/library/metasearch/simple/policy.h>
#include <saas/library/searchserver/exception.h>
#include <saas/library/sharding/sharding.h>

#include <saas/searchproxy/unistat_signals/signals.h>
#include <saas/searchproxy/common/messages.h>
#include <saas/searchproxy/search_meta/extendedcontext.h>
#include <saas/searchproxy/logging/incoming_log.h>
#include <saas/searchproxy/logging/access_log.h>
#include <saas/searchproxy/logging/error_log.h>
#include <saas/searchproxy/logging/reqans_log.h>

#include <search/meta/metasearch.h>
#include <search/meta/scatter/source.h>
#include <search/meta/scatter/sd.h>
#include <search/common_signals/counter_factory.h>

#include <library/cpp/containers/stack_vector/stack_vec.h>
#include <library/cpp/logger/global/global.h>

#include <util/string/split.h>

namespace NProxyMeta {
    struct TShardIntervalEnumerator : NSaas::IShardIntervalCallback {
        TShardIntervalEnumerator(const NSaas::TShardsDispatcher& shardDispatcher,
                                 const NSaas::TShardIntervals& shardIntervals,
                                 const TVector<THolder<NScatter::ISource>>& sources,
                                 TKeysForSource& keysForSource)
            : ShardDispatcher(shardDispatcher)
            , ShardIntervals(shardIntervals)
            , Sources(sources)
            , KeysForSource(keysForSource)
        {
        }

        void FindShard(const TString& key, ui64 kps) {
            CurrentKey = &key;
            FoundSource = false;
            ShardDispatcher.EnumerateIntervals(key, kps, ShardIntervals, *this);
            if (!FoundSource) {
                ythrow TSearchException(HTTP_INTERNAL_SERVER_ERROR) << "No shard found for key " << key << ':' << kps;
            }
        }

        void BroadcastKey(const TString& key) {
            for (auto& source : Sources) {
                AddKeyToSource(source.Get(), key);
            }
        }

        void OnShardInterval(size_t interval) override {
            AddKeyToSource(Sources[interval].Get(), *CurrentKey);
            FoundSource = true;
        }

    private:
        void AddKeyToSource(NScatter::ISource* source, const TString& key) {
            auto& currentSource = KeysForSource[source];
            // Keys are sorted and unique so we can skip the real comparison here
            if (currentSource.empty() || currentSource.back().c_str() != key.c_str()) {
                currentSource.push_back(key);
            }
        }

        const NSaas::TShardsDispatcher& ShardDispatcher;
        const NSaas::TShardIntervals& ShardIntervals;
        const TVector<THolder<NScatter::ISource>>& Sources;
        TKeysForSource& KeysForSource;
        const TString* CurrentKey = nullptr;
        bool FoundSource = false;
    };

    // Best-effort warm-up of a source: every connection yielded by GetSearchConnections(0)
    // gets an NNeh request with body "dummy_ping=da", and we wait up to 30 ms for each one.
    // The replies are not inspected.
    void PrechargeSourceConnections(const NScatter::ISource& source) {
        DEBUG_LOG << "Connections precharge...." << Endl;
        TAutoPtr<IConnIterator> connections = source.GetSearchConnections(0);
        const NScatter::TSourceOptions& sourceOptions = source.GetSourceOptions();
        for (ui16 attempt = 0; const TConnData* conn = connections->Next(attempt); ++attempt) {
            DEBUG_LOG << "Connections precharge: " << conn->SearchScript() << Endl;
            // With a searcher port delta configured, the address is derived via MakeSearchAddress.
            TString address;
            if (sourceOptions.SearcherDeltaPort) {
                address = NScatter::MakeSearchAddress(*conn, sourceOptions).RealAddress;
            } else {
                address = conn->SearchScript();
            }
            NNeh::TMessage ping(address, "dummy_ping=da");
            NNeh::Request(ping)->Wait(TDuration::MilliSeconds(30));
        }
        DEBUG_LOG << "Connections precharge... OK" << Endl;
    }

    void FillIntervalsAndSources(const TSearchSources& searchSources, NSaas::TShardIntervals& shardIntervals,
        TVector<THolder<NScatter::ISource>>& sources, const NScatter::TSourceOptions& opts, const bool precharge,
        NYP::NServiceDiscovery::TSDClient* sdClient)
    {
        ui32 indexSource = 0;
        using TSourcePair = std::pair<NSearchMapParser::TShardsInterval,THolder<NScatter::ISource>>;
        TVector<TSourcePair> intervalSourcePairs;
        NScatter::TSourceOptions optsForSd(opts);
        optsForSd.SearcherProtocol = "tcp2";
        optsForSd.SearcherDeltaPort = 1;
        NYP::NServiceDiscovery::TEndpointSetOptions epsOpts;
        epsOpts.SetUseUnistat(false);
        for (auto&& i : searchSources) {
            if (!i.ProtoConfig_.GetServerGroup()) {
                continue;
            }
            NSearchMapParser::TShardsInterval interval;
            const bool parsed = NSearchMapParser::TShardsInterval::Parse(i.ProtoConfig_.GetServerGroup(), interval);
            if (!parsed) {
                INFO_LOG << "Group " << i.ProtoConfig_.GetServerGroup() << " for source " << i.ProtoConfig_.GetServerDescr() << " is not recognized as an interval" << Endl;
            }
            bool isSd = i.ProtoConfig_.GetSearchScript().StartsWith("sd://");
            //todo: NScatter::TMutableSource anyway
            THolder<NScatter::ISource> source = isSd
                ? THolder<NScatter::ISource>(new NScatter::TMutableSource(ToString(indexSource), indexSource, i.ProtoConfig_.GetSearchScript(), optsForSd, i.ProtoConfig_.GetServerGroup(), sdClient, epsOpts))
                : NScatter::CreateSimpleSource(ToString(indexSource), i.ProtoConfig_.GetSearchScript(), opts);
            ++indexSource;
            if (precharge && parsed) {
                PrechargeSourceConnections(*(source.Get()));
            }
            intervalSourcePairs.emplace_back(interval, std::move(source));
        }
        SortBy(intervalSourcePairs, [](const TSourcePair& a) { return a.first; });
        shardIntervals.reserve(intervalSourcePairs.size());
        sources.reserve(intervalSourcePairs.size());
        for (auto&& intervalSourcePair: intervalSourcePairs) {
            shardIntervals.emplace_back(intervalSourcePair.first);
            sources.emplace_back(std::move(intervalSourcePair.second));
        }
    }

    void ShardByKeysAndSgkps(const TCgiParameters& cgi,
                             const TSet<TString>& keys,
                             const NSaas::TShardsDispatcher& shardDispatcher,
                             const NSaas::TShardIntervals& shardIntervals,
                             const TVector<THolder<NScatter::ISource>>& sources,
                             TKeysForSource& keysForSource) {
        TShardIntervalEnumerator shardEnumerator(shardDispatcher, shardIntervals, sources, keysForSource);
        if (cgi.Has("key_name")) {
            for (const auto& key : keys) {
                shardEnumerator.BroadcastKey(key);
            }
            return;
        }
        TStackVec<NSaas::TKeyPrefix, 8> kpss;
        for (const auto& it : StringSplitter(cgi.Get("sgkps")).Split(',').SkipEmpty()) {
            kpss.push_back(FromString<NSaas::TKeyPrefix>(it.Token()));
        }
        if (Y_UNLIKELY(kpss.size() > 1)) {
            SortUnique(kpss);
        } else if (kpss.empty()) {
            kpss.reserve(1);
            kpss.push_back(0);
        }
        for (const auto& key : keys) {
            for (const auto& kps : kpss) {
                shardEnumerator.FindShard(key, kps);
            }
        }
    }

    // Builds a proxy search service from its config:
    //  - translates ProxyConfig into NScatter::TSourceOptions (an invalid hedged-timeouts
    //    range is logged and ignored);
    //  - optionally enables connection statistics and the service-discovery client;
    //  - creates the scatter sources and their shard intervals, then starts the SD client;
    //  - resolves the route hash function (CHECK_WITH_LOG fails if none can be constructed).
    TSearch::TSearch(const TServiceConfig& serviceConfig, const TSearchSources& searchSources
            , const NSaas::TShardsDispatcher::TPtr sharding, ICgiCorrectorPtr cc
            , const TServiceDiscoveryOptions& sdOpts, const TString& serviceName
            , TEventLogPtr eventlog)
        : NSimpleMeta::TBaseProxySearch(serviceConfig, *serviceConfig.GetProxyMetaConfig(), sharding, cc, eventlog)
        , RearrangeFactory(serviceConfig.GetCustomRearranges())
    {
        NScatter::TSourceOptions opts;
        opts.EnableIpV6 = true;
        opts.MaxAttempts = ProxyConfig.GetMaxAttempts();
        opts.AllowDynamicWeights = ProxyConfig.GetAllowDynamicWeights();
        opts.ConnectTimeouts = {ProxyConfig.GetConnectTimeout()};
        opts.SendingTimeouts = {ProxyConfig.GetSendingTimeout()};
        if (ProxyConfig.GetHedgedRequestTimeoutMs() > 0) {
            opts.HedgedRequestTimeouts = {TDuration::MilliSeconds(ProxyConfig.GetHedgedRequestTimeoutMs())};
        }
        if (ProxyConfig.GetHedgedRequestTimeoutsRange()) {
            try {
                opts.ParseTimeoutsRange(ProxyConfig.GetHedgedRequestTimeoutsRange(), '_');
            } catch (const yexception& e) {
                ERROR_LOG << "Incorrect hedged request timeouts range for " << serviceConfig.GetName() << ": " << e.what() << Endl;
            }
        }
        opts.HedgedRequestRateThreshold = ProxyConfig.GetHedgedRequestRateThreshold();
        opts.HedgedRequestRateSmooth = ProxyConfig.GetHedgedRequestRateSmooth();
        // Named distinctly so it does not shadow the source options `opts` above.
        if (auto connStatOpts = ProxyConfig.GetConnStatOptions()) {
            INFO_LOG << "Enable ConnStat for " << serviceConfig.GetName() << " with opts: " << connStatOpts->FailThreshold << ", " << connStatOpts->CheckTimeout << ", " << connStatOpts->CheckInterval << Endl;
            ConnStatMap.Reset(NScatter::CreateConnStatMap(*connStatOpts));
        }
        if (sdOpts.Enabled) {
            SDClient.Reset(new NYP::NServiceDiscovery::TSDClient(sdOpts
                , serviceName ? NSearch::MakeCounterFactory(serviceName) : NYP::NServiceDiscovery::StandaloneCounterFactory));
        }

        FillIntervalsAndSources(searchSources, ShardIntervals, Sources, opts, ProxyConfig.GetPrechargeSourceConnections(), SDClient.Get());
        if (SDClient) {
            SDClient->Start();
        }

        // NOTE(review): `hf` is destroyed at the end of this `if`; this assumes the value
        // returned by hf->Get() does not depend on `hf` staying alive. Confirm.
        if (TAutoPtr<NCgiHash::IHashFunction> hf = NCgiHash::THashFunctionFactory::Construct(ProxyConfig.GetRouteHashType(), NCgiHash::DefaultHashName)) {
            HashFunction = hf->Get();
        }
        CHECK_WITH_LOG(HashFunction);
    }

    // Default key routing: collects the keys from the `text` CGI parameter and shards them
    // with ShardByKeysAndSgkps.
    //  - without `noTextSplit`, the value of cgi.Get("text") is split on ',' (empty pieces
    //    skipped); when nothing remains, a single empty key is used;
    //  - with `noTextSplit`, every distinct `text` value is one key, and at least one is
    //    required (TSearchException with HTTP_BAD_REQUEST otherwise).
    void TSearch::DefaultSharding(IReplyContext& context, bool noTextSplit, TKeysForSource& keysForSource) const {
        const auto& cgi = context.GetRequestData().CgiParam;
        TSet<TString> keys;
        if (!noTextSplit) {
            for (const auto& piece : StringSplitter(cgi.Get("text")).Split(',').SkipEmpty()) {
                keys.emplace(piece.Token());
            }
            if (keys.empty()) {
                keys.emplace();
            }
        } else {
            for (const auto& text : cgi.Range("text")) {
                keys.insert(text);
            }
            if (keys.empty()) {
                ythrow TSearchException(HTTP_BAD_REQUEST) << "Query does not have |text| parameter";
            }
        }
        ShardByKeysAndSgkps(cgi, keys, *ShardsDispatcher, ShardIntervals, Sources, keysForSource);
    }

    // Serves one request end to end:
    //  1. logs the incoming request and lets the CGI corrector and rearrange engine adjust
    //     the CGI;
    //  2. routes keys to sources (rearrange-specific sharding first, DefaultSharding as the
    //     fallback);
    //  3. schedules one or more sub-requests per source, waits according to the wait policy;
    //  4. builds the answer and records statistics/logs.
    // Exceptions raised while scheduling/answering are turned into builder->AnswerFail.
    // Exceptions raised before the answer builder exists (e.g. by sharding) propagate to
    // the caller. Always returns true.
    bool TSearch::DoSearchContext(IReplyContext::TPtr context) const {
        auto processingStartTime = TInstant::Now();
        if (LoggingEnabled()) {
            NSearchProxy::NLogging::TIncomingLog::Log(context->GetRequestId(), *context);
        }

        CgiCorrector->FormCgi(context->MutableRequestData().CgiParam, &context->GetRequestData());

        TSelfFlushLogFramePtr eventLogFrame;
        TEventLogger eventLogger;
        const bool addEventLogToReport = context->GetRequestData().CgiParam.Has("dump", "eventlog");
        // Evaluated once so that "create a frame" and "which frame" are decided consistently.
        const bool serviceEventLogEnabled = !!EventLog.Get() && IsEventLogEnabled();
        if (addEventLogToReport || serviceEventLogEnabled) {
            if (serviceEventLogEnabled) {
                eventLogFrame = MakeIntrusive<TSelfFlushLogFrame>(*EventLog.Get());
            } else {
                eventLogFrame = MakeIntrusive<TSelfFlushLogFrame>();
            }
            eventLogFrame->ForceDump();
            eventLogger.AssignLog(eventLogFrame);
        }
        if (eventLogFrame) {
            eventLogger.LogEvent<NEvClass::TContextCreated>(TString{context->GetRequestData().Query()}, context->GetRequestData().RP->UserReqId, 0);
            eventLogger.LogEvent<NEvClass::TReqId>(context->GetRequestData().CgiParam.Get(NSearchProxyCgi::queryid));
        }

        TRearrangeEngine rearrangeEngine = RearrangeFactory.CreateEngine(context->MutableRequestData().CgiParam, eventLogFrame.Get());
        rearrangeEngine.FormCgiParams(context->MutableRequestData().CgiParam);

        const bool noTextSplit = context->GetRequestData().CgiParam.Has("saas_no_text_split");
        TKeysForSource keysForSource;

        if (!rearrangeEngine.ProcessSharding(*context, *ShardsDispatcher, ShardIntervals, Sources, keysForSource)) {
            DefaultSharding(*context, noTextSplit, keysForSource);
        }

        THolder<TAnswerBuilder> builder = MakeHolder<TAnswerBuilder>(*context, eventLogFrame, addEventLogToReport, &rearrangeEngine, HttpStatusManagerConfig);
        try {
            TExtendedReplyContext erc(context.Get());
            erc.SetMetaSearchType(PROXY_TYPE_PROXY);
            erc.SetProcessingStartTime(processingStartTime);
            NScatter::TAsyncTaskRunner taskRunner;
            if (eventLogFrame) {
                taskRunner.EventLogger = &eventLogger;
            }
            if (ConnStatMap) {
                taskRunner.ConnStatMap = ConnStatMap.Get();
            }
            // A configured limit of 0 would make the batching loop below schedule empty
            // batches forever (the batch iterator never advances), so treat it as 1.
            const ui32 maxKeysPerRequest = std::max<ui32>(ProxyConfig.GetMaxKeysPerRequest(), 1);
            for (auto&& src : keysForSource) {
                Y_ASSERT(src.first);
                NScatter::ISource& source = *src.first;
                const auto& keys = src.second;
                if (keysForSource.size() == 1 && keys.size() <= maxKeysPerRequest) {
                    // Single source within the limit: forward the request without a text override.
                    NScatter::ITaskRef task = new TSearchContext(context.Get(), builder.Get(), /*textOverride=*/{}, source, HashFunction, ServiceConfig);
                    taskRunner.Schedule(task);
                } else {
                    if (keys.empty()) {
                        if (eventLogFrame) {
                            eventLogger.LogEvent<NEvClass::TSubSourceSkip>(
                                source.Num,
                                /*attempt=*/0,
                                /*taskNum=*/0,
                                "no keys"
                            );
                        }
                        continue;
                    }

                    // Send the keys in batches of at most maxKeysPerRequest. With
                    // saas_no_text_split every key is a separate text value; otherwise the
                    // batch is joined with ',' into a single text value.
                    auto batchStart = keys.cbegin();
                    auto batchEnd = batchStart;
                    ui32 pos = 0;
                    do {
                        const ui32 batchSize = std::min<ui32>(keys.size() - pos, maxKeysPerRequest);
                        batchEnd = batchStart + batchSize;

                        TVector<TString> text;
                        if (noTextSplit) {
                            text.reserve(batchSize);
                            text.assign(batchStart, batchEnd);
                        } else {
                            text.reserve(1);
                            text.emplace_back(JoinRange(",", batchStart, batchEnd));
                        }
                        NScatter::ITaskRef task = new TSearchContext(context.Get(), builder.Get(), std::move(text), source, HashFunction, ServiceConfig);
                        taskRunner.Schedule(task);

                        batchStart = batchEnd;
                        pos += batchSize;

                    } while (batchEnd != keys.cend());
                }
            }
            NScatter::IWaitPolicyRef waitPolicy = NSimpleMeta::CreateWaitPolicy(context->GetRequestData().CgiParam, ProxyConfig);
            taskRunner.Wait("proxy", waitPolicy);

            erc.SetReportStartTime(Now());
            const TString& proxyType = context->GetRequestData().CgiParam.Has("sp_meta_search") ? context->GetRequestData().CgiParam.Get("sp_meta_search") : ServiceConfig.GetDefaultMetaSearch();
            builder->Answer(proxyType);

            erc.MarkTimeouted(context->GetRequestDeadline() < Now());
            erc.SetAnswerIsComplete(builder->GetAnswerIsComplete());
            erc.SetFailedAttempts(builder->GetFailedAttempts());
            erc.SetReportDocsCount(builder->GetDocumentsCount());
            erc.SetReportByteSize(builder->GetReportByteSize());
            erc.SetHttpStatus(builder->GetAnsweredHttpCode());
            TSaasSearchProxySignals::DoUnistatRecord(erc);
            if (LoggingEnabled() || erc.ForceLogging()) {
                NSearchProxy::NLogging::TInfoLog::Log(context->GetRequestId(), erc);
            }
            if (ProxyConfig.GetReqAnsEnabled()) {
                NSearchProxy::NLogging::TReqAnsLog::Log(context->GetRequestId(), builder->GetReport());
            }
        } catch (...) {
            builder->AnswerFail("cannot build response: " + CurrentExceptionMessage());
            TSaasSearchProxySignals::DoUnistatExceptionRecord(ServiceConfig.GetName());
            NSearchProxy::NLogging::TErrorLog::Log(context->GetRequestId(), context->GetRequestData(), CurrentExceptionMessage());
        }

        return true;
    }

    // Static registrator: makes TProxy available from TProxy::TFactory under the key "proxy".
    TProxy::TFactory::TRegistrator<TProxy> TProxy::Registrator("proxy");

    // Returns a replier for `context` bound to the searcher registered for `service`,
    // or nullptr when this proxy has no such service.
    ISearchReplier::TPtr TProxy::BuildReplier(const TString& service, IReplyContext::TPtr context) {
        auto searcher = Searchers.find(service);
        if (searcher != Searchers.end()) {
            return new TReplier(context, searcher->second.Get());
        }
        return nullptr;
    }

    // Handles control messages. Only NSearchProxy::TControlEventLogMessage is understood:
    // when its target service exists, the service's event log is switched to TargetState,
    // and TargetServiceFound / ResultState are filled in.
    // Returns true for event-log control messages (whether or not the service exists),
    // false for any other message.
    bool TProxy::Process(IMessage* message) {
        auto eventLogControl = message->As<NSearchProxy::TControlEventLogMessage>();
        if (!eventLogControl) {
            return false;
        }

        auto target = Searchers.find(eventLogControl->Service);
        if (target == Searchers.end()) {
            return true;
        }

        auto& searcher = target->second;
        eventLogControl->TargetServiceFound = true;
        if (eventLogControl->TargetState) {
            searcher->EnableEventLog();
        } else {
            searcher->DisableEventLog();
        }
        eventLogControl->ResultState = searcher->IsEventLogEnabled();
        return true;
    }

}
