#pragma once

#include <saas/rtyserver/factors/feature.h>
#include <saas/rtyserver/factors/rty_features_factory.h>

#include <kernel/dssm_applier/nn_applier/lib/layers.h>

#include <util/generic/hash_set.h>
#include <util/string/printf.h>

class TFactorStorage;

namespace NRTYFeatures {
    // TDssmModelsConfig: protobuf message carrying the DSSM models metadata (forward-declared; defined elsewhere)
    class TDssmModelsConfig;

    //
    // TDssmBundle: holds the DSSM models metadata (available models with their inputs and outputs, and factor-to-model bindings)
    //
    class TDssmBundle: public TThrRefBase {
    public:
        // Flags telling which inputs a model reads (query, title, host/path, region).
        struct TInputTraits {
            bool UsesQuery = false;
            bool UsesTitle = false;
            bool UsesHostPath = false;
            bool UsesRegion = false;

            // Formats the four flags as space-separated integers (0/1) in declaration order, e.g. "1 0 1 0".
            TString DebugString() const {
                return Sprintf("%i %i %i %i",
                               static_cast<i32>(UsesQuery),
                               static_cast<i32>(UsesTitle),
                               static_cast<i32>(UsesHostPath),
                               static_cast<i32>(UsesRegion));
            }
        };

        // Binds a factor to one output variable (VarName) of the model named ModelName.
        // An empty ModelName means the factor is not bound to any model (see ShouldFillCanonicalValue()).
        struct TFactorBinding {
            TString ModelName;
            TString VarName;

            TFactorBinding() = default;

            // Parameters are sinks: they are moved into the members.
            TFactorBinding(TString modelName, TString varName)
                : ModelName(std::move(modelName))
                , VarName(std::move(varName))
            {
                // A variable name without a model is not allowed. Check the members, not the
                // parameters: the parameters have already been moved from.
                Y_ASSERT(!ModelName.empty() || VarName.empty());
            }

            // True when the factor has no model binding; the caller should then leave the factor
            // at its canonical value instead of calculating it.
            inline bool ShouldFillCanonicalValue() const {
                return ModelName.empty();
            }
        };

        // Metadata of one model: its name, the outputs requested from it and its input traits.
        struct TModelBinding {
            TString ModelName;
            TVector<TString> VarNames; // scalar(float) variables that are calculated for each doc
            TVector<TString> QueryCacheOutputs; // vectors(query embeddings), precalculated for the query only

            TInputTraits Inputs;

            TModelBinding() = default;

            // Name and outputs are sinks: they are moved into the members.
            TModelBinding(TString modelName, TVector<TString> varNames, TInputTraits inputs)
                : ModelName(std::move(modelName))
                , VarNames(std::move(varNames))
                , Inputs(inputs)
            {
            }

            // Same as above, and also sets the query-cached outputs (see QueryCache()).
            TModelBinding(TString modelName, TVector<TString> varNames, TInputTraits inputs, TVector<TString> queryEmbs)
                : TModelBinding(std::move(modelName), std::move(varNames), inputs)
            {
                QueryCache(std::move(queryEmbs));
            }

            // Copies the model identity, inputs and query-cached outputs from `meta`, but starts with
            // an empty VarNames list: the actually requested outputs are added later via Combine().
            void AssignFromMetadata(const TModelBinding& meta) {
                ModelName = meta.ModelName;
                VarNames.clear();
                Inputs = meta.Inputs;
                QueryCacheOutputs = meta.QueryCacheOutputs;
            }

            // Replaces the list of outputs that are precalculated per query; returns *this for chaining.
            TModelBinding& QueryCache(TVector<TString> outputVarNames) {
                QueryCacheOutputs = std::move(outputVarNames);
                return *this;
            }


            // Merges one factor into this binding. If no model name is set yet, the factor's model
            // name is adopted; otherwise the names must match (Y_ENSURE throws on mismatch).
            // The factor's VarName is appended unless it is already present.
            void Combine(const TFactorBinding& other) {
                if (ModelName.empty()) {
                    ModelName = other.ModelName;
                } else {
                    Y_ENSURE(ModelName == other.ModelName);
                }

                AddItem(VarNames, other.VarName);
            }

            // Returns the index of `varName` in VarNames; Y_ENSURE throws if it is not there.
            ui32 FindOutputLocalId(const TString& varName) const {
                const auto i = std::find(VarNames.begin(), VarNames.end(), varName);
                Y_ENSURE(i != VarNames.end());
                return static_cast<ui32>(i - VarNames.begin());
            }

            // Defined out of line. Its result is used as TModelDescription::SubmodelCacheKey.
            TString FormatCacheKey() const;

        private:
            // Appends `name` to `names` unless it is already present (first-insertion order is kept).
            static void AddItem(TVector<TString>& names, TString name) {
                if (std::find(names.begin(), names.end(), name) == names.end()) {
                    names.push_back(std::move(name));
                }
            }
        };

        // Reference-counted model binding (used as the value type of AvailableModels).
        struct TFactorsBlock: public TThrRefBase, public TModelBinding {
        };

        // A factor binding plus a pointer to the metadata of its model (the pointee is not freed by TFactor).
        struct TFactor: public TFactorBinding {
        public:
            const TModelBinding* ModelMeta = nullptr;
        public:
            TFactor() = default;

            TFactor(TString modelName, TString varName, const TModelBinding* modelPtr)
                : TFactorBinding(std::move(modelName), std::move(varName))
                , ModelMeta(modelPtr)
            {
                Y_ASSERT(modelPtr);
            }
        };

    protected:
        // extra accessors (for unit tests)
        void DbgAddModel(const TString modelName, TInputTraits inputs, TVector<TString>&& queryEmbs);
        void DbgAddFactor(const TString& factorName, const TFactorBinding& factor);

    private:
        THashMap<TString, TString> ModelByName;
        THashMap<TString, TFactor> FactorByName;
        THashMap<TString, TIntrusivePtr<TFactorsBlock>> AvailableModels;
        bool StaticDisable = false;

    public:
        TDssmBundle();

        // Fills the bundle from the models config `proto`.
        // NOTE(review): `fsPath` presumably locates the bundle/model files — confirm in the .cpp.
        void SetBundle(const TDssmModelsConfig& proto, const TFsPath& fsPath);

        // Returns the binding of `factorName`; an empty ModelName means "fill the canonical value".
        // NOTE(review): the effect of `fastFeaturesOnly` is defined in the .cpp — confirm there.
        TFactorBinding GetFactor(TStringBuf factorName, bool fastFeaturesOnly) const;

        // Returns the metadata of the model named `modelname` (defined out of line).
        const TModelBinding& GetModelMeta(TStringBuf modelname) const;

        // Returns the file name of model `modelName` (defined out of line).
        TString GetModelFileName(TStringBuf modelName) const;

        inline bool GetStaticDisable() const {
            return StaticDisable;
        }

        inline void SetStaticDisable(bool v) {
            StaticDisable = v;
        }

    public:
        using TPtr = TIntrusivePtr<TDssmBundle>;
    };

    //
    // TDssmFeatureCalcerPlan: a configured list of factors with their Dssm bindings
    //
    class TDssmFeatureCalcerPlan {
    public:
        // One model used by the plan. SumBinding accumulates every output requested from it.
        struct TModelDescription {
            TString ModelAlias;
            TDssmBundle::TModelBinding SumBinding;
            const NNeuralNetApplier::TModel* Model = nullptr;

            // if QueryEmbeddings are precalculated, we will extract Submodels
            TString SubmodelCacheKey;
            NNeuralNetApplier::TModelPtr QuerySubmodel;  // (query, query_reg) -> QueryEmbeddings
            NNeuralNetApplier::TModelPtr DocSubmodel; // (QueryEmbeddings, url, title...) -> Factor

            // True only when both the query part and the document part are present.
            inline bool IsSubmodelsCreated() const {
                if (!QuerySubmodel) {
                    return false;
                }
                return static_cast<bool>(DocSubmodel);
            }
        };

        // One requested factor: where it goes (FactorLocalId) and which model output feeds it (OutputLocalId).
        struct TFactorDescription {
            ui32 FactorLocalId = Max();
            ui32 OutputLocalId = Max();
            TString FactorName;
            TString VarName;
            const TModelDescription* ModelDescr = nullptr;
        };

    private:
        using TFactors = TVector<THolder<TFactorDescription>>;
        using TModels = TVector<THolder<TModelDescription>>;
        TFactors Factors;
        TModels Models;
        bool Finalized = false; // set by PrepareModels(); afterwards the plan may not be extended

    public:
        const TFactors& GetFactors() const {
            return Factors;
        }
        const TModels& GetModels() const {
            return Models;
        }

        // Returns the description of the model named modelMeta.ModelName, creating it on first use.
        // Creation is only allowed before PrepareModels(); a new entry starts with no outputs.
        TModelDescription* GetModel(const TDssmBundle::TModelBinding& modelMeta) {
            const TString& modelAlias = modelMeta.ModelName;
            Y_ENSURE(!modelAlias.empty());

            for (const THolder<TModelDescription>& known : Models) {
                if (known->ModelAlias == modelAlias) {
                    return known.Get();
                }
            }

            Y_ENSURE(!Finalized);
            auto created = MakeHolder<TModelDescription>();
            TModelDescription* modelDescr = created.Get();
            modelDescr->ModelAlias = modelAlias;
            modelDescr->SumBinding.AssignFromMetadata(modelMeta);
            Y_ASSERT(modelDescr->SumBinding.VarNames.empty()); // actual outputs is added at Combine()

            Models.emplace_back(std::move(created));
            return modelDescr;
        }

        // Looks up (or creates) the model of `factor` and adds the factor's output to its SumBinding.
        const TModelDescription* GetUpdatedModel(const TDssmBundle::TModelBinding& modelMeta, const TDssmBundle::TFactorBinding& factor) {
            Y_ENSURE(!Finalized);
            TModelDescription* const target = GetModel(modelMeta);
            target->SumBinding.Combine(factor);
            return target;
        }

        // Registers `factorName` (stored at `factorLocalId`). Factors without a model binding are
        // skipped, so they keep their canonical value.
        void Add(int factorLocalId, const TString& factorName, const TDssmBundle& modelsMeta, bool fastFeaturesOnly) {
            Y_ENSURE(!Finalized);
            const TDssmBundle::TFactorBinding factor = modelsMeta.GetFactor(factorName, fastFeaturesOnly);
            if (Y_UNLIKELY(factor.ShouldFillCanonicalValue())) {
                return;
            }

            const TModelDescription* const owner = GetUpdatedModel(modelsMeta.GetModelMeta(factor.ModelName), factor);

            auto entry = MakeHolder<TFactorDescription>();
            entry->FactorLocalId = factorLocalId;
            entry->OutputLocalId = owner->SumBinding.FindOutputLocalId(factor.VarName);
            entry->FactorName = factorName;
            entry->VarName = factor.VarName;
            entry->ModelDescr = owner;
            Factors.emplace_back(std::move(entry));
        }

        // Finalizes the plan. For each model with query-cached outputs, computes SubmodelCacheKey and
        // obtains the query/doc submodels: from `cacheObj` if they are cached there, otherwise by
        // extraction (then, when `addToCache` is set, the extracted pair is stored in `cacheObj`).
        // Every model must already have Model set (Y_ENSURE throws otherwise).
        void PrepareModels(const NRTYFactors::TConfig* cacheObj, bool addToCache) {
            Y_ENSURE(!Finalized);
            for (const auto& modelDescrPtr: Models) {
                Y_ASSERT(modelDescrPtr->Model);
                Y_ENSURE(modelDescrPtr->Model != nullptr); // the field is set by LoadModels()

                // Submodels are only needed when some outputs are precalculated per query.
                if (modelDescrPtr->SumBinding.QueryCacheOutputs.empty()) {
                    continue;
                }

                modelDescrPtr->SubmodelCacheKey = modelDescrPtr->SumBinding.FormatCacheKey();

                const bool reusedFromCache = cacheObj && FindCachedSubmodels(*cacheObj, *modelDescrPtr);
                if (reusedFromCache) {
                    Y_ASSERT(modelDescrPtr->IsSubmodelsCreated());
                    continue;
                }

                DoExtractSubmodels(*modelDescrPtr);

                // addToCache==true happens once, during BaseSearch.Open. Hence, only the 'maximal' (relev=all_factors) submodel is cached here.
                if (cacheObj && addToCache && modelDescrPtr->IsSubmodelsCreated()) {
                    AddSubmodelsToCache(*cacheObj, *modelDescrPtr);
                }
            }
            Finalized = true;
        }

        void AddSubmodelsToCache(const NRTYFactors::TConfig& cacheObj, TModelDescription& modelDescr);

    private:
        bool FindCachedSubmodels(const NRTYFactors::TConfig& cacheObj, TModelDescription& modelDescr);
        void DoExtractSubmodels(TModelDescription& modelDescr);
    };

    //
    // TDssmFeatureCalcerBase: applies one or more models to calculate factors
    //
    class TDssmFeatureCalcerBase: public IFeatureCalcer {
    private:
        THolder<TDssmFeatureCalcerPlan> Plan;
        const TDynMapping& Factors;
        bool IsFirstDoc;

    public:
        TDssmFeatureCalcerBase(THolder<TDssmFeatureCalcerPlan>&& plan, const TDynMapping& factorsInRequest)
            : Plan(std::move(plan))
            , Factors(factorsInRequest)
            , IsFirstDoc(true)
        {
            Y_VERIFY(Plan && Plan->GetFactors().size() <= factorsInRequest.size());
        }

        virtual void Calc(TFactorStorage& storage, const TRTYDynamicFeatureContext& ctx, ui32 docId) override;

    protected:
        virtual void ApplyModelToQuery(const TDssmFeatureCalcerPlan::TModelDescription* m, const TRTYDynamicFeatureContext& ctx) = 0;

        virtual void ApplyModel(const TDssmFeatureCalcerPlan::TModelDescription* m, const TRTYDynamicFeatureContext& ctx, ui32 docId) = 0;

        virtual float GetResult(const ui32 outputLocalId) const = 0;
    };

    //
    // TDssmFeature : DSSM-calculating feature (plugin)
    //
    class TDssmFeature final: public IFeature {
    private:
        using TFactorsInfo = THashSet<TString>;

    private:
        // Note: keep the fields immutable, because the object is a singleton
        TFactorsInfo FactorsStaticInfo;

    public:
        // Resolves every model of `plan` through GetOrLoadModel (may throw), then prepares the
        // submodels with addToCache=true.
        static void LoadModels(TDssmFeatureCalcerPlan& plan, const NRTYFactors::TConfig* relevCfg, const TString& modelsPath) {
            const auto& models = plan.GetModels();
            for (const auto& m : models) {
                m->Model = GetOrLoadModel(relevCfg, modelsPath, m->ModelAlias); // throws
                Y_ASSERT(m->Model);
            }
            plan.PrepareModels(relevCfg, /*addToCache=*/true);
        }

        // Does nothing unless the feature is enabled for `factorsInConfig`. Otherwise, builds a full
        // (not fast-only) plan from the DSSM bundle, prints it and loads its models.
        void InitModels(const TDynMapping& factorsInConfig, const NRTYFactors::TConfig* relevCfg, const TString& modelsPath) override {
            if (!IsFeatureEnabledFor(factorsInConfig)) {
                return;
            }

            const TDssmBundle& bundle = GetOrLoadDssmBundle(relevCfg, modelsPath);
            auto plan = CreatePlan(FactorsStaticInfo, bundle, factorsInConfig, /*fastFeaturesOnly=*/false);
            PrintDebugInfo(*plan);
            LoadModels(*plan, relevCfg, modelsPath);
        }

        // Binds every model of `plan` via GetModel (no loading here; presumably the models were
        // loaded earlier by InitModels — confirm), then prepares submodels without caching them.
        static void BindPreloadedModels(TDssmFeatureCalcerPlan& plan, const NRTYFactors::TConfig* relevCfg) {
            for (const auto& model : plan.GetModels()) {
                model->Model = GetModel(relevCfg, model->ModelAlias);
            }
            plan.PrepareModels(relevCfg, /*addToCache=*/false);
        }

        void AddFactor(const TString& factorName, IFeature::TAddFactorFunc adder);

        void InitStaticInfo(TStringBuf modelsPath, IFeature::TAddFactorFunc adder) override;

        IFeatureCalcer::TPtr CreateCalcer(const TDynMapping& factorsInRequest, const NRTYFactors::TConfig* relevCfg, bool fastFeaturesOnly) override;

    private:
        bool IsFeatureEnabledFor(const TDynMapping& usedFactors) const;

    private:

        static THolder<TDssmFeatureCalcerPlan> CreatePlan(const TFactorsInfo& factorsInfo, const TDssmBundle& modelsInfo, const TDynMapping& factorsList, bool fastFeaturesOnly);

        static void PrintDebugInfo(const TDssmFeatureCalcerPlan& plan);

        static const NNeuralNetApplier::TModel* GetModel(const NRTYFactors::TConfig* relevCfg, const TString& modelName);

        static const NNeuralNetApplier::TModel* GetOrLoadModel(const NRTYFactors::TConfig* relevCfg, const TString& modelsPath, const TString& modelName);

        static TFsPath PathToDssmBundle(TStringBuf modelsPath);

        static bool IsStaticDisable(TStringBuf modelsPath);

        static TVector<TString> GetFactors(TStringBuf modelsPath);

        static const TDssmBundle& GetDssmBundle(const NRTYFactors::TConfig* holder);

        static const TDssmBundle& GetOrLoadDssmBundle(const NRTYFactors::TConfig* relevCfg, const TString& modelsPath);

        static TRTYFeaturesFactory::TRegistrator<TDssmFeature> Registrator;
    };
}
