#include "factors_calcer.h"
#include "rty_features.h"

#include <saas/rtyserver/factors/factors_blocks.h>
#include <saas/rtyserver/factors/rank_model.h>
#include <library/cpp/logger/global/global.h>

#include <saas/rtyserver/components/cs/cs_manager.h>
#include <saas/rtyserver/components/erf/erf_manager.h>
#include <saas/rtyserver/search/gators/zone_gator.h>

#include <search/relevance/query_factors/query_factors.h>
#include <saas/rtyserver/search/external_search/rty_features.h>

namespace {
    // Stores `value` at `storage[index]`. An index of -1 denotes a factor that is not requested
    // (presumably absent from the mask), in which case nothing is written.
    void Y_FORCE_INLINE AddFactor(TFactorStorage& storage, i32 index, float value) {
        if (index == -1) {
            return;
        }
        storage[index] = value;
    }

    // Stores the freshness score 1 / (1 + age / retardation) at `storage[index]`, skipping index -1.
    // `age` is expressed in seconds (InitRTYFactors passes a seconds delta); the score is 1 for a
    // brand-new document and decays towards 0 as age grows relative to `retardation`.
    void Y_FORCE_INLINE AddFreshnessFactor(TFactorStorage& storage, i32 index, float age, const TDuration retardation) {
        if (index == -1) {
            return;
        }
        const float retardationSeconds = static_cast<float>(retardation.Seconds());
        storage[index] = 1.0f / (1 + age / retardationSeconds);
    }

    // True when there are web-production factor chunks to copy, or when the BM25F factor group is
    // requested (the latter is what the zone gator provides).
    bool Y_FORCE_INLINE HasWebProductionFactors(const NRTYFactors::TFactorChunks& chunks, const TCombinedFactorMask& bsFactorMask) {
        if (!chunks.empty()) {
            return true;
        }
        const bool hasZoneGator = bsFactorMask.Group.BM25F();
        return hasZoneGator;
    }
}

TFactorsCalcer::TFactorsCalcer(
    TZoneGator*& zoneGator,
    const NRTYFactors::TRankModelHolder& formula,
    const TFactorsMaskCache& factorsMaskCache,
    const TRTYRankingPassContext* rankingPass,
    const TRTYIndexData* indexData,
    const IRTYCgiReader* cgi,
    const bool fastFeaturesOnly)
   : Config(*indexData->GetFactorsConfig())
   , ZoneGator(zoneGator)
   , QS(indexData->GetQSManager())
   , CS(indexData->GetCSManager())
   , Erf(indexData->GetErfManager())
   , DocLen(indexData->GetDocLenCalcer())
   , DDK(indexData->GetDDKManager())
   , Ann(indexData->GetAnnManager())
   , Cgi(cgi)
   , RankingPass(*rankingPass)
   , UserCalcer(formula.GetUserRanking())
   , FactorStorage(Config.GetFactorsDomain())
{
    Y_ASSERT(rankingPass);

    const NFactorSlices::TFactorDomain& fdm = Config.GetFactorsDomain();
    TSet<EFactorSlice> slices;
    for (auto it = fdm.Begin(); it != fdm.End(); ++it) {
        slices.insert(it.GetLeaf());
    }
    for (auto slice : slices) {
        if (slice != EFactorSlice::WEB_PRODUCTION) {
            FactorStorage.CreateViewFor(slice).FillCanonicalValues();
        }
    }

    const ECalcFactors calcType = Cgi->CalcAllFactors(fastFeaturesOnly);

    if (calcType == ECalcFactors::All) {
        Config.GetAllFactorsIndexes(UsedFactors);
        FactorMasks = &factorsMaskCache.GetAll();
        if (auto userCalcer = Cgi->GetUserFactorsCalcer()) {
            UserFactorsCalcer.Reset(userCalcer->CreateFormulaCalcer(UsedFactors));
        }
    } else {
        UsedFactors.insert(formula.GetUsedFactors().begin(), formula.GetUsedFactors().end());

        if (!fastFeaturesOnly && Cgi->GetFilterBorder() > 0 && Cgi->GetFilterModel() != nullptr && Cgi->GetFilterModel() != &formula) {
            const auto& factors = Cgi->GetFilterModel()->GetUsedFactors();
            UsedFactors.insert(factors.begin(), factors.end());
        }

        if (fastFeaturesOnly && Cgi->GetFastFilterBorder() > 0 && Cgi->GetFastFilterModel() != nullptr && Cgi->GetFastFilterModel() != &formula) {
            const auto& factors = Cgi->GetFastFilterModel()->GetUsedFactors();
            UsedFactors.insert(factors.begin(), factors.end());
        }

        const NRTYFactors::TFactorSet* factorSet = nullptr;
        NRTYFactors::TUsedFactors allExtraFactors;

        if (auto extraFactors = Cgi->GetExtraFactors(fastFeaturesOnly)) {
            for (const ui32 factorIndex : *extraFactors) {
                if (Config.IsUnusedFactor(factorIndex)) {
                    continue;
                }
                if (allExtraFactors.insert(factorIndex).second) {
                    UsedFactors.insert(factorIndex);
                }
            }
        }
        if (const TString* factorSetName = Cgi->GetNamedFactorSet(fastFeaturesOnly)) {
            factorSet = Config.GetFactorSet(*factorSetName);
            if (factorSet)
                UsedFactors.insert(factorSet->GetUsedFactors().begin(), factorSet->GetUsedFactors().end());
        }

        const bool forceCalcStaticFactors = Cgi->ShouldCalcAllStaticFactors() || (calcType == ECalcFactors::LightAndFormula);
        const bool forceCalсUserFactors = calcType == ECalcFactors::LightAndFormula;

        auto addFactorsFromGroup = [&allExtraFactors, this](const auto& factorGroup) {
            for (const auto& factor : factorGroup) {
                if (allExtraFactors.insert(factor.IndexGlobal).second) {
                    UsedFactors.insert(factor.IndexGlobal);
                }
            }
        };
        if (forceCalсUserFactors) {
            addFactorsFromGroup(Config.UserFactors());
        }
        if (forceCalcStaticFactors) {
            addFactorsFromGroup(Config.StaticFactors());
            addFactorsFromGroup(Config.IgnoredFactors());
        }

        if (auto userCalcer = Cgi->GetUserFactorsCalcer()) {
            UserFactorsCalcer.Reset(userCalcer->CreateFormulaCalcer(UsedFactors));
            if (UserFactorsCalcer) {
                for (auto&& factor : UserFactorsCalcer->GetUsedFactors()) {
                    if (allExtraFactors.insert(factor).second)
                        UsedFactors.insert(factor);
                }
            }
        }

        const NRTYFactors::TRankModelHolder* filterModel = fastFeaturesOnly ? Cgi->GetFastFilterModel() : Cgi->GetFilterModel();

        FactorMasks = &factorsMaskCache.Get(&formula, filterModel, factorSet, &allExtraFactors);
    }

    if (!GetRTYFactorMask().PluginFactors.empty()) {
        DP = NRTYFeatures::CreateCalcer(GetRTYFactorMask().PluginFactors, &Config, fastFeaturesOnly);
    }
}

void TFactorsCalcer::OnBeforePass() {
    // Bind the ranking pass's ANN factor calcer, when there is one, to this calcer's factor storage.
    if (!RankingPass.FactorAnnCalcer) {
        return;
    }
    RankingPass.FactorAnnCalcer->Bind(FactorStorage);
}

void TFactorsCalcer::OnReopenTextMachine() {
    // Basesearch prepares a ranking pass in this order:
    //   1. restart the standard iterators;
    //   2. call OnBeforePass();
    //   3. create a text machine (which opens additional iterators);
    //   4. call OnReopenTextMachine().
    // Binding the custom text-machine calcer therefore has to happen here, after step 3.
    if (!RankingPass.TmCustomCalcer) {
        return;
    }
    RankingPass.TmCustomCalcer->Bind(FactorStorage);
}


template<bool all>
void TFactorsCalcer::InitSpecialFactors(ui32 docId) {
    // Fills document-length based factors in FactorStorage (all of them when `all`, otherwise only
    // those enabled in the dynamic factor mask), then runs the pass's custom calcers for the document.
    const auto& factorMask = GetFactorMask();

    if (all || factorMask.Factor.DocLen()) {
        // Document length divided by 2560, scaled to [0, 255], clamped, and converted via Ui82Float.
        const float scaledLen = static_cast<float>(DocLen->GetDocLen(docId)) / 2560.0f;
        const int byteValue = ClampVal(static_cast<int>(scaledLen * 255.0), 0, 255);
        FactorStorage[FI_DOC_LEN] = Ui82Float(static_cast<ui8>(byteValue));
    }

    if (all || factorMask.Factor.TLen()) {
        const float inverseScale = 1.f / 400;
        FactorStorage[FI_TLEN] = MapFactor(DocLen->GetDocLen(docId), inverseScale);
    }

    if (RankingPass.TmCustomCalcer) {
        RankingPass.TmCustomCalcer->CalcFeatures(docId);
    }

    if (RankingPass.FactorAnnCalcer) {
        RankingPass.FactorAnnCalcer->CalcFeatures(docId);
    }
}

template<bool all>
void TFactorsCalcer::InitRTYFactors(TFactorStorage& factorStorage, ui32 docId) {
    // Fills RTY-specific time factors of `docId` into `factorStorage` (all of them when `all`,
    // otherwise only those present in the RTY mask) and resets refine factors to 0.
    const auto& rtyFactorsMask = GetRTYFactorMask();
    if (all || rtyFactorsMask.DynamicFactors) {
        // Document age in seconds since its live-start time. A live-start time in the future
        // (clock skew, bad input) used to wrap around in unsigned arithmetic to ~4e9 seconds;
        // such documents are now treated as brand new (age 0).
        const ui64 nowSeconds = Seconds();
        const ui64 liveStartSeconds = DDK->GetTimeLiveStart(docId);
        const ui32 delta = nowSeconds > liveStartSeconds ? static_cast<ui32>(nowSeconds - liveStartSeconds) : 0;
        AddFactor(factorStorage, rtyFactorsMask.IndexLiveTime, delta);
        AddFactor(factorStorage, rtyFactorsMask.IndexInvLiveTime, (delta >= 1) ? 1.0f / delta : 1);
        AddFreshnessFactor(factorStorage, rtyFactorsMask.IndexFreshnessDay, delta, TDuration::Days(1));
        AddFreshnessFactor(factorStorage, rtyFactorsMask.IndexFreshnessWeek, delta, TDuration::Days(7));
        AddFreshnessFactor(factorStorage, rtyFactorsMask.IndexFreshnessMonth, delta, TDuration::Days(30));
    }
    if (all || !rtyFactorsMask.TimeFactors.empty()) {
        // Take "now" once so every time factor of this document uses the same reference point.
        const float nowHours = static_cast<float>(Now().Hours());
        for (size_t i = 0; i < rtyFactorsMask.TimeFactors.size(); ++i) {
            const NRTYFactors::TTimeFactorData& tfd = rtyFactorsMask.TimeFactors[i];
            // Age in hours relative to the timestamp stored in the base factor. Converting a
            // negative (or NaN) float to an unsigned integer is undefined behaviour, so a base
            // timestamp in the future yields 0 instead.
            const float ageHours = nowHours - factorStorage[tfd.BaseIndex];
            const ui32 delta = ageHours > 0 ? static_cast<ui32>(ageHours) : 0;
            AddFactor(factorStorage, tfd.DeltaIndex, delta);
            AddFactor(factorStorage, tfd.InvDeltaIndex, (delta >= 1) ? 1.0f / delta : 1);
        }
    }
    for (size_t i = 0; i < rtyFactorsMask.RefineFactors.size(); ++i)
        factorStorage[rtyFactorsMask.RefineFactors[i].BaseIndex] = 0;
}

void TFactorsCalcer::CalcRefineFactor(TCalcFactorsContext& ctx) {
    // Refine factors need the helper together with both text and link hits; otherwise nothing is done.
    if (!ctx.Helper || !ctx.TextHits || !ctx.LinkHits) {
        return;
    }
    // TODO(yrum): [refactor] in the current basesearch implementation both calls (with true and false)
    // do the same thing - remove the parameter.
    ctx.Helper->CalcRefineFactor(*ctx.Factors, *ctx.TextHits, *ctx.LinkHits, true);
    ctx.Helper->CalcRefineFactor(*ctx.Factors, *ctx.TextHits, *ctx.LinkHits, false);
}
void TFactorsCalcer::CalcFactors(TCalcFactorsContext& ctx) {
    // Full calculation for one document: every built-in/RTY factor first, then the user factors
    // calcer supplied by the request (if any), so that it runs on top of the already-filled storage.
    CalcNoUserFactors(ctx);
    CalcUserFactors(ctx);
}

//FIXME(SAAS-5629): [refactor] make CalcFactors() calls here reuse the 'saasView'
// Computes every factor of ctx.DocId except those of the request-supplied user factors calcer,
// writing into ctx.Factors. The steps run in a fixed order because later ones read what earlier
// ones wrote (e.g. RTY time factors read their base factor from ctx.Factors).
void TFactorsCalcer::CalcNoUserFactors(TCalcFactorsContext& ctx) {
    const auto& rtyMask = GetRTYFactorMask();
    const auto& factorMask = GetFactorMask();
    const auto& factorChunks = GetFactorChunks();

    // Static factors: read by the Erf manager directly into the output storage.
    if (rtyMask.StaticFactors) {
        TFactorView saasView(static_cast<TBasicFactorStorage&>(*ctx.Factors));
        Erf->Read(saasView, ctx.DocId);
    }

    // Point the zone gator at this document's output storage.
    if (!!ZoneGator) {
        ZoneGator->SetFactorStorage(ctx.Factors);
    }

    // Web-production factors: computed into the internal FactorStorage, then copied chunk-wise
    // into ctx.Factors.
    if (HasWebProductionFactors(factorChunks, factorMask)) {
        ctx.Helper->CalcFastFeatures(FactorStorage);
        ctx.Helper->CalcAllFeatures(FactorStorage); // we ignore ctx.Fast here
        ctx.DocRelevPerQueryParams->CacheFactors.Apply(FactorStorage.RawData());
        if (factorMask.AllFactors) {
            InitSpecialFactors<true>(ctx.DocId);
        } else {
            InitSpecialFactors<false>(ctx.DocId);
        }
        NRTYFactors::TConfig::CopyFactors(factorChunks, FactorStorage, *ctx.Factors);
    }

    // QS factors: the iterators pull is built lazily on first use and reused afterwards.
    if (rtyMask.QSFactorsInfo.IsUsed() && QS) {
        if (!QSItPull) {
            QSItPull = QS->BuildIteratorsPull(Cgi, &rtyMask);
        }
        QSItPull->CalcFactors(ctx);
    }

    // CS factors.
    if (CS && rtyMask.CSFactorsInfo.IsUsed()) {
        CS->CalcFactors(ctx, &rtyMask);
    }

    // Plugin factors: DP is only created when the RTY mask contains plugin factors.
    if (DP) {
        Y_ASSERT(!rtyMask.PluginFactors.empty());
        TRTYDynamicFeatureContext dynamicContext;
        Cgi->FillDynamicContext(dynamicContext);
        DP->Calc(*ctx.Factors, dynamicContext, ctx.DocId);
    }

    // RTY time factors (they read values written by the steps above).
    if (factorMask.AllFactors) {
        InitRTYFactors<true>(*ctx.Factors, ctx.DocId);
    } else {
        InitRTYFactors<false>(*ctx.Factors, ctx.DocId);
    }

    CalcRefineFactor(ctx);

    if (!!UserCalcer && UserCalcer->HasFactors()) {
        UserCalcer->CalcFactors(ctx);
    }
}

void TFactorsCalcer::CalcUserFactors(TCalcFactorsContext& ctx) {
    // UserFactorsCalcer is only set when the request supplied a user factors calcer.
    if (!UserFactorsCalcer) {
        return;
    }
    UserFactorsCalcer->Calc(ctx);
}

size_t TFactorsCalcer::GetFactorCount() const {
    // Factor count as reported by the factors config; this is not restricted to UsedFactors.
    const size_t configuredCount = Config.GetFactorCount();
    return configuredCount;
}

bool TFactorsCalcer::GetFactorIndex(const char* name, size_t* index) const {
    // Resolves `name` to its global index and writes it to `*index` unconditionally. Returns true
    // only for a known factor that this calcer actually uses.
    const size_t globalIndex = Config.GetFactorGlobalNum(name);
    *index = globalIndex;
    if (globalIndex == NRTYFactors::NOT_FACTOR) {
        return false;
    }
    const auto found = UsedFactors.find(globalIndex);
    return found != UsedFactors.end();
}

const char* TFactorsCalcer::GetFactorName(size_t index) const {
    // Factors not used by this calcer are reported with an empty name.
    if (UsedFactors.find(index) == UsedFactors.end()) {
        return "";
    }
    return Config.GetFactorByFormulaIndex(index)->Name.data();
}

const char* const* TFactorsCalcer::GetFactorNames() const {
    // No contiguous name table is provided; per-index names are available through GetFactorName().
    const char* const* noNameTable = nullptr;
    return noNameTable;
}

bool TFactorsCalcer::UsesWebProductionCalcer() const {
    // Same predicate CalcNoUserFactors uses to decide whether the web-production calcer runs.
    const auto& chunks = GetFactorChunks();
    const auto& mask = GetFactorMask();
    return HasWebProductionFactors(chunks, mask);
}

bool TFactorsCalcer::IsUnusedFactor(size_t index) const {
    // A factor is "unused" when it is absent from the set computed in the constructor.
    const auto position = UsedFactors.find(index);
    const bool isUsed = position != UsedFactors.end();
    return !isUsed;
}
