#include "rank_model.h"
#include "factors_config.h"

#include <search/rank/sum_relev/sum_relev.h>

#include <kernel/relevfml/rank_models_factory.h>
#include <kernel/externalrelev/relev.h>

#include <library/cpp/logger/global/global.h>
#include <library/cpp/json/json_value.h>

#include <util/stream/file.h>

#include <utility>

namespace NRTYFactors {
    // Builds a holder around a copy of `polynom`. When `mxnetSource` is given, its matrixnet
    // (as returned by MatrixnetShared()) is reused and its used factors are merged into ours.
    TRankModelHolder::TRankModelHolder(const TStringBuf name, const SRelevanceFormula& polynom, const TConfig* owner, const TRankModelHolder* mxnetSource)
        : Name(name)
        , FactorsConfig(*owner)
    {
        RelevanceFormula.Reset(new SRelevanceFormula(polynom));
        RelevanceFormula->UsedFactors(UsedFactors);

        if (!mxnetSource) {
            // Polynom-only model
            Model.Reset(new TRankModel(TPolynomDescrStatic(RelevanceFormula.Get())));
            return;
        }

        NMatrixnet::TRelevCalcerPtr sharedMatrixnet = mxnetSource->GetRankModel()->MatrixnetShared();
        if (sharedMatrixnet) {
            TUsedFactors matrixnetFactors;
            sharedMatrixnet->UsedFactors(matrixnetFactors);
            UsedFactors.insert(matrixnetFactors.begin(), matrixnetFactors.end());
        }
        // NOTE(review): the model is built even if the source holder had no matrixnet (null pointer) — confirm TRankModel accepts that
        Model.Reset(new TRankModel(sharedMatrixnet, TPolynomDescrStatic(RelevanceFormula.Get())));
    }

    // Builds a holder from a JSON formula section. Recognized keys:
    //   "polynom"      - encoded polynom, decoded against the factors count of the config
    //   "user_ranking" - name of an IUserRanking implementation (may provide its own relevance)
    //   "matrixnet"    - model file, resolved with owner->GetModelPath(); loaded only when needed
    //   "normalize"    - Options.Normalize
    // Throws yexception on an invalid or inconsistent section.
    TRankModelHolder::TRankModelHolder(const TStringBuf name, const NJson::TJsonValue& config, const TConfig* owner)
        : Name(name)
        , FactorsConfig(*owner)
    {
        if (!config.IsMap())
            ythrow yexception() << "invalid formula config";

        // Whether the matrixnet model has to be loaded: without a polynom it is the formula itself,
        // with a polynom it is needed only if the polynom reads the MatrixNet factor.
        // A user_ranking that provides its own relevance turns it off.
        bool useMatrixNet = true;
        if (config.Has("polynom")) {
            RelevanceFormula.Reset(new SRelevanceFormula);
            ui32 count = FactorsConfig.GetFactorsCountByIndex();
            Decode(RelevanceFormula.Get(), config["polynom"].GetString(), count);
            RelevanceFormula->UsedFactors(UsedFactors);
            useMatrixNet = UsedFactors.find(owner->GetMatrixNetGlobalIndex()) != UsedFactors.end(); // consider using FullMatrixNet and FastMatrixNet as polynom input
        }
        if (config.Has("user_ranking")) {
            UserRanking.Reset(Singleton<typename IUserRanking::TFactory>()->Construct(config["user_ranking"].GetString()));
            if (!UserRanking) {
                ythrow yexception() << "Incorrect user ranking: " << config["user_ranking"].GetString();
            }
            UserRanking->InitConfig(*owner, config);
            TUsedFactors srcFactors;
            const bool hasSrcFactors = UserRanking->GetUsedFactors(srcFactors);
            const bool hasUserFormula = UserRanking->HasRelevance();
            if (hasUserFormula) {
                Y_ENSURE(!hasSrcFactors || !srcFactors.contains(owner->GetMatrixNetGlobalIndex()), "user_ranking formula cannot be dependent on its own output");
                Y_ENSURE(!config.Has("matrixnet"), "use either a matrixnet formula or user_ranking formula");
                useMatrixNet = false;
            }
            if (hasSrcFactors) {
                UsedFactors.insert(srcFactors.begin(), srcFactors.end());
            }
        }
        if (useMatrixNet && !config.Has("matrixnet")) {
            // Report what is actually missing. Without a polynom, useMatrixNet stays true for an empty
            // section and for a user_ranking without its own relevance - those are not MatrixNet-factor errors.
            if (RelevanceFormula) {
                ythrow yexception() << "formula tries to use MatrixNet factor, but does not contain matrixnet";
            }
            if (UserRanking) {
                ythrow yexception() << "user_ranking '" << config["user_ranking"].GetString()
                                    << "' is not a formula (add a polynom or matrixnet to the formula section)";
            }
            ythrow yexception() << "formula section must contain either polynom, matrixnet or user_ranking section";
        }
        if (useMatrixNet) {
            // config.Has("matrixnet") is guaranteed by the check above
            auto pathStr = config["matrixnet"].GetString();
            TFsPath path = owner->GetModelPath(pathStr);
            TUnbufferedFileInput fi(path.c_str());
            THolder<NMatrixnet::TMnSseDynamic> tmp(new NMatrixnet::TMnSseDynamic());
            tmp->Load(&fi);
            MxNetInfo.Reset(tmp.Release());
            TUsedFactors tmpFactors;
            MxNetInfo->UsedFactors(tmpFactors);
            UsedFactors.insert(tmpFactors.begin(), tmpFactors.end());
        }
        if (config.Has("normalize")) {
            Options.Normalize = config["normalize"].GetBooleanRobust();
        }

        NMatrixnet::TRelevCalcerPtr relevCalcer = MxNetInfo;
        if (RelevanceFormula && relevCalcer)
            Model.Reset(new TRankModel(relevCalcer, TPolynomDescrStatic(RelevanceFormula.Get())));
        else if (RelevanceFormula)
            Model.Reset(new TRankModel(TPolynomDescrStatic(RelevanceFormula.Get())));
        else if (relevCalcer)
            Model.Reset(new TRankModel(relevCalcer));
        else if (UserRanking) {
            // Defensive: the checks above already guarantee a user_ranking with its own relevance here
            Y_ENSURE(UserRanking->HasRelevance(), "user_ranking '" << config["user_ranking"].GetString()
                    << "' is not a formula (add a polynom or matrixnet to the formula section)");
            Model.Reset(new TRankModel());
        } else
            ythrow yexception() << "formula section must contain either polynom, matrixnet or user_ranking section";
    }

    // Defined out of line so that member holders are destroyed where their pointee types are complete.
    TRankModelHolder::~TRankModelHolder() = default;

    void TRankModelHolder::MultiCalc(float** factors, float* results, const size_t count, TSumRelevParams* srParamsForCache,
                                        TMaybe<size_t> modelMatrixNetIndex, TMaybe<size_t> modelPolynomIndex) const {
        if (Y_UNLIKELY(count == 0))
            return;

        bool useUserRanking = !!UserRanking && UserRanking->HasRelevance();
        bool usePolynom = Model->HasPolynom();
        bool useMatrixNet = Model->HasMatrixnet();
        Y_ASSERT(!useUserRanking || !useMatrixNet);

        if (useMatrixNet) {
            TVector<double> mx_values(count);

            TVector<size_t> unCachedIndices;
            if (srParamsForCache) {
                unCachedIndices.reserve(count);
                for (size_t i = 0; i < count; ++i) {
                    if (srParamsForCache[i].MxValue != 0) {
                        mx_values[i] = srParamsForCache[i].MxValue;
                    } else {
                        unCachedIndices.push_back(i);
                    }
                }
            }
            if (unCachedIndices.size() == count || !srParamsForCache) {
                Model->Matrixnet()->DoCalcRelevs(factors, &mx_values[0], count);
            } else if (unCachedIndices.size() > 0) {
                size_t countRest = unCachedIndices.size();
                TVector<double> mx_values_rest(countRest);
                TVector<float*> factorsRest(countRest);
                for (size_t i=0; i<unCachedIndices.size(); ++i) {
                    factorsRest[i] = factors[unCachedIndices[i]];
                }
                Model->Matrixnet()->DoCalcRelevs(factorsRest.begin(), &mx_values_rest[0], countRest);
                for (size_t i=0; i<unCachedIndices.size(); ++i) {
                    mx_values[unCachedIndices[i]] = mx_values_rest[i];
                }
            }
            if (usePolynom) {
                for (size_t i = 0; i < count; ++i) {
                    if (modelMatrixNetIndex.Defined()) {
                        factors[i][*modelMatrixNetIndex] = static_cast<float>(mx_values[i]);
                    }
                    factors[i][FactorsConfig.GetMatrixNetGlobalIndex()] = static_cast<float>(mx_values[i]);
                    if (srParamsForCache)
                        srParamsForCache[i].MxValue = static_cast<float>(mx_values[i]);
                }
            } else {
                for (size_t i = 0; i < count; ++i) {
                    results[i] = static_cast<float>(mx_values[i]);
                    if (srParamsForCache)
                        srParamsForCache[i].MxValue = results[i];
                    if (modelMatrixNetIndex.Defined()) {
                        factors[i][*modelMatrixNetIndex] = static_cast<float>(mx_values[i]);
                    }
                    if (FactorsConfig.GetMatrixNetGlobalIndex() < Max<ui32>()) {
                        factors[i][FactorsConfig.GetMatrixNetGlobalIndex()] = static_cast<float>(mx_values[i]);
                    }
                }
                return;
            }
        } else if (useUserRanking) {
            UserRanking->CalcRelevance(factors, results, count);
            for (size_t i = 0; i < count; ++i) {
                if (modelMatrixNetIndex.Defined()) {
                    factors[i][*modelMatrixNetIndex] = static_cast<float>(results[i]);
                }
                if (FactorsConfig.GetMatrixNetGlobalIndex() < Max<ui32>()) {
                    factors[i][FactorsConfig.GetMatrixNetGlobalIndex()] = static_cast<float>(results[i]);
                }
            }
        }

        if (usePolynom) {
            Model->Polynom()->MultiCalc(factors, results, count);
            if (modelPolynomIndex.Defined()) {
                for (size_t i = 0; i < count; ++i) {
                    factors[i][*modelPolynomIndex] = static_cast<float>(results[i]);
                }
            }
        }
    }

    // Global index of the MatrixNet factor, or empty when the model has no matrixnet
    // or the index is not configured (values >= Max<ui32>() are treated as "no index").
    TMaybe<size_t> TRankModelHolder::GetMatrixNetIndex() const {
        if (!Model->HasMatrixnet()) {
            return TMaybe<size_t>();
        }
        const auto globalIndex = FactorsConfig.GetMatrixNetGlobalIndex();
        if (globalIndex >= Max<ui32>()) {
            return TMaybe<size_t>();
        }
        return globalIndex;
    }

    // Global index of the MetaMatrixNet factor, or empty when the model has no matrixnet
    // or the index is not configured (values >= Max<ui32>() are treated as "no index").
    TMaybe<size_t> TRankModelHolder::GetMetaMatrixNetIndex() const {
        if (!Model->HasMatrixnet()) {
            return TMaybe<size_t>();
        }
        const auto globalIndex = FactorsConfig.GetMetaMatrixNetGlobalIndex();
        if (globalIndex >= Max<ui32>()) {
            return TMaybe<size_t>();
        }
        return globalIndex;
    }

    // Global index of the FullMatrixNet factor, or empty when the model has no matrixnet
    // or the index is not configured (values >= Max<ui32>() are treated as "no index").
    TMaybe<size_t> TRankModelHolder::GetFullMatrixNetIndex() const {
        if (!Model->HasMatrixnet()) {
            return TMaybe<size_t>();
        }
        const auto globalIndex = FactorsConfig.GetFullMatrixNetGlobalIndex();
        if (globalIndex >= Max<ui32>()) {
            return TMaybe<size_t>();
        }
        return globalIndex;
    }

    // Global index of the FastMatrixNet factor, or empty when the model has no matrixnet
    // or the index is not configured (values >= Max<ui32>() are treated as "no index").
    TMaybe<size_t> TRankModelHolder::GetFastMatrixNetIndex() const {
        if (!Model->HasMatrixnet()) {
            return TMaybe<size_t>();
        }
        const auto globalIndex = FactorsConfig.GetFastMatrixNetGlobalIndex();
        if (globalIndex >= Max<ui32>()) {
            return TMaybe<size_t>();
        }
        return globalIndex;
    }

    // Global index of the FilterMatrixNet factor, or empty when the model has no matrixnet
    // or the index is not configured (values >= Max<ui32>() are treated as "no index").
    TMaybe<size_t> TRankModelHolder::GetFilterMatrixNetIndex() const {
        if (!Model->HasMatrixnet()) {
            return TMaybe<size_t>();
        }
        const auto globalIndex = FactorsConfig.GetFilterMatrixNetGlobalIndex();
        if (globalIndex >= Max<ui32>()) {
            return TMaybe<size_t>();
        }
        return globalIndex;
    }

    // Global index of the FastFilterMatrixNet factor, or empty when the model has no matrixnet
    // or the index is not configured (values >= Max<ui32>() are treated as "no index").
    TMaybe<size_t> TRankModelHolder::GetFastFilterMatrixNetIndex() const {
        if (!Model->HasMatrixnet()) {
            return TMaybe<size_t>();
        }
        const auto globalIndex = FactorsConfig.GetFastFilterMatrixNetGlobalIndex();
        if (globalIndex >= Max<ui32>()) {
            return TMaybe<size_t>();
        }
        return globalIndex;
    }

    // Global index of the MetaPolynom factor; empty when not configured (>= Max<ui32>()).
    TMaybe<size_t> TRankModelHolder::GetMetaPolynomIndex() const {
        const auto globalIndex = FactorsConfig.GetMetaPolynomGlobalIndex();
        return globalIndex < Max<ui32>() ? TMaybe<size_t>(globalIndex) : TMaybe<size_t>();
    }

    // Global index of the FullPolynom factor; empty when not configured (>= Max<ui32>()).
    TMaybe<size_t> TRankModelHolder::GetFullPolynomIndex() const {
        const auto globalIndex = FactorsConfig.GetFullPolynomGlobalIndex();
        return globalIndex < Max<ui32>() ? TMaybe<size_t>(globalIndex) : TMaybe<size_t>();
    }

    // Global index of the FastPolynom factor; empty when not configured (>= Max<ui32>()).
    TMaybe<size_t> TRankModelHolder::GetFastPolynomIndex() const {
        const auto globalIndex = FactorsConfig.GetFastPolynomGlobalIndex();
        return globalIndex < Max<ui32>() ? TMaybe<size_t>(globalIndex) : TMaybe<size_t>();
    }

    // Global index of the FilterPolynom factor; empty when not configured (>= Max<ui32>()).
    TMaybe<size_t> TRankModelHolder::GetFilterPolynomIndex() const {
        const auto globalIndex = FactorsConfig.GetFilterPolynomGlobalIndex();
        return globalIndex < Max<ui32>() ? TMaybe<size_t>(globalIndex) : TMaybe<size_t>();
    }

    // Global index of the FastFilterPolynom factor; empty when not configured (>= Max<ui32>()).
    TMaybe<size_t> TRankModelHolder::GetFastFilterPolynomIndex() const {
        const auto globalIndex = FactorsConfig.GetFastFilterPolynomGlobalIndex();
        return globalIndex < Max<ui32>() ? TMaybe<size_t>(globalIndex) : TMaybe<size_t>();
    }

    //
    // TFactorSet
    //
    // Creates a named factor set that takes ownership of an already built list of global factor indices.
    TFactorSet::TFactorSet(const TString& name, NRTYFactors::TUsedFactors&& usedFactors)
        : Name(name)
        , UsedFactors(std::move(usedFactors)) // a named rvalue reference is an lvalue: std::move avoids a copy
    {
        Y_ASSERT(name);
    }

    // Builds a named factor set from a JSON array. Each item is one of:
    //   - a string (factor name) or an unsigned integer (formula index): a single factor; unknown factors throw;
    //   - {"type":"range", "from":..., "to":...}: all factors in the inclusive range (the "type" key may be
    //     omitted; missing "from"/"to" mean an open bound; bounds may be names or arbitrary integers);
    //   - {"type":"rty_group", "name":...}: all factors of an RTY group known to owner->GetFactorsByRtyType().
    // Items of an unknown type (or an unknown rty_group) are logged and ignored.
    TFactorSet::TFactorSet(const TString& name, const NJson::TJsonValue& config, const TConfig* owner)
        : Name(name)
    {
        Y_VERIFY(name && owner);
        Y_ENSURE(config.IsArray(), "factor_set should be a Json array");

        auto getGlobalIndex = [owner, &name](const NJson::TJsonValue& item) -> size_t {
            if (item.IsString()) {
                const TFactor* factor = owner->GetFactorByName(item.GetString());
                if (!factor)
                    ythrow yexception() << "unknown factor '" << item.GetString() << "' in factor_set '" << name << "'";
                return factor->IndexGlobal;
            } else if (item.IsUInteger()) {
                const TFactor* factor = owner->GetFactorByFormulaIndex(item.GetUIntegerSafe(Max<ui32>())); // throws
                return factor->IndexGlobal;
            } else {
                ythrow yexception() << "expected a string name or an integer index in factor_set '" << name << "', got: " << item.GetStringRobust();
            }
        };
        auto onUnknownType = [](const NJson::TJsonValue& item) {
            ERROR_LOG << "Cannot handle a Json item in factor_set, will ignore: " << item.GetStringRobust() << Endl;
        };

        for (const NJson::TJsonValue& item: config.GetArray()) {
            if (Y_LIKELY(!item.IsMap())) {
                // POD entities like "TRDocQuorum", 0, 42 are interpreted as a factor name (factor index)
                // Unknown factors are not allowed (they result in FailedConfig)
                const size_t globalIndex = getGlobalIndex(item);
                UsedFactors.insert(globalIndex);
            } else {
                const NJson::TJsonValue& vType = item["type"];
                // Held by value: when "type" is absent, GetStringSafe returns a reference to its default-argument
                // temporary, which would dangle if bound to a const reference here.
                const TString type = vType.GetStringSafe("range");

                if (type == "range") {
                    // Include all factors in range.
                    // Examples: {}, {"from":0, "to":20}, {"to":500}, {"type":"range", "from":"PR", "to":"TRDocQuorum"}, {"type":"range", "from":100}
                    const NJson::TJsonValue& vFrom = item["from"];
                    const NJson::TJsonValue& vTo = item["to"];
                    if (!vType.IsDefined()) {
                        // if "type" is omitted, we assume "range", but do not allow unknown fields
                        Y_ENSURE(vFrom.IsDefined() || vTo.IsDefined() || item.GetMap().empty(), "incorrect item: " << item.GetStringRobust());
                    }
                    // "from" and "to" values may be either factor names/indexes, or arbitrary integer boundaries
                    // (e.g: {"from":"PR", "to":10000} , where 10000 is not an index of some factor, is considered a correct input)
                    size_t from = !vFrom.IsDefined() ? Min<size_t>() : vFrom.IsUInteger() ? vFrom.GetUInteger() : getGlobalIndex(vFrom);
                    size_t to = !vTo.IsDefined() ? Max<size_t>() : vTo.IsUInteger() ? vTo.GetUInteger() : getGlobalIndex(vTo);
                    if (from > to)
                        std::swap(from, to);
                    if (to != Max<size_t>())
                        ++to; //ListFactorsByRange expects stl-style half-interval
                    for (const TFactor* factor : owner->ListFactorsByRange(from, to)) {
                        Y_ASSERT(factor != nullptr);
                        UsedFactors.insert(factor->IndexGlobal);
                    }
                } else if (type == "rty_group") {
                    // Include all factors of some type. Example: {"type":"rty_group", "name":"static_factors"}
                    // See GetFactorsByRtyType implementation for possible "name" values (not everything is available)
                    const TString& groupName = item["name"].GetString();
                    const TFactorsList* factorList = owner->GetFactorsByRtyType(groupName);
                    if (factorList == nullptr) {
                        onUnknownType(item);
                        continue;
                    }
                    for (const TFactor& factor: *factorList) {
                        UsedFactors.insert(factor.IndexGlobal);
                    }
                } else {
                    // There is a type that we do not know. Ignore it.
                    onUnknownType(item);
                }
            }
        }
    }
}
