#pragma once

#include <util/generic/yexception.h>
#include <util/generic/hash.h>
#include <util/generic/fwd.h>

// A small cache for intermediate results. Each value is computed by a single UserFactorsCalcer expression
// (a `store=` directive) and can then be referenced by name from other expressions.
// Example:
// relev=store=bg_boostingctr_embed:parse_bg_model_embed(base64_decode(get_relev("query_embedding_boosting_ctr")))
// relev=calc=factor_name:dot_product(
//      doc_dssm_decompress("embedding_name","BoostingCtr", bg_model_embed_version("bg_boostingctr_embed")),
//      bg_model_embed_value("bg_boostingctr_embed")
// )

struct TCalcFactorsContext;

namespace NRTYFactors {
    /// Type-erased base for the values produced by IStoreValueCalcer.
    /// Every concrete value type reports a numeric id, which lets readers
    /// verify the dynamic type before downcasting.
    struct IStoredType {
        virtual ~IStoredType() = default;

        /// @return numeric identifier of the concrete stored type.
        virtual ui32 GetTypeId() const = 0;
    };

    /// Helper base that binds a compile-time type id to IStoredType.
    /// The id is exposed as the static `TypeId`; readers compare it with
    /// GetTypeId() before downcasting.
    /// NOTE(review): ids are expected to be unique per concrete type; nothing
    /// here enforces that.
    template<ui32 TypeIdValue>
    struct StoredType : public IStoredType {
        static constexpr ui32 TypeId = TypeIdValue;

        ui32 GetTypeId() const override {
            // Identical to TypeId; taken straight from the template argument.
            return TypeIdValue;
        }
    };

    /// Producer of a value that TQueryValuesStorage keeps under a variable name.
    /// Implementations compute the value in DoGet(); callers go through Get().
    class IStoreValueCalcer {
    private:
        // Becomes true once SetDependsOnDoc() is called; never reset in this header.
        bool DependsOnDoc_ = false;
        // All bits set (the value -1 converts to); presumably a "no document yet"
        // sentinel for the out-of-line Get() — confirm in the .cpp.
        ui32 DocId = static_cast<ui32>(-1);

        // Implementation hook for (re)computing the document dependency.
        // NOTE(review): not called from this header; presumably invoked by Get().
        virtual void UpdateDependsOnDoc() = 0;

    public:
        virtual ~IStoreValueCalcer() = default;

        /// @return true if the value was marked as depending on the current document.
        virtual bool DependsOnDoc() const {
            return DependsOnDoc_;
        }

        /// Marks the value as document-dependent.
        virtual void SetDependsOnDoc() {
            DependsOnDoc_ = true;
        }

        /// Returns the value for `typeId`. Defined out of line.
        /// NOTE(review): presumably wraps DoGet() with caching controlled by
        /// `enableCacheInvalidation` — confirm in the implementation file.
        const IStoredType* Get(TCalcFactorsContext& ctx, ui32 typeId, bool enableCacheInvalidation);

        /// Computes the value. Implementations MUST ignore ctx.DocId.
        virtual const IStoredType* DoGet(TCalcFactorsContext& ctx, ui32 typeId) = 0;

        /// Overridable hook; does nothing by default.
        virtual void Clear() {}
    };

    struct TQueryValuesStorage: public TThrRefBase {
        using TPtr = TIntrusivePtr<TQueryValuesStorage>;

        TQueryValuesStorage(bool enableCacheInvalidation)
            : EnableCacheInvalidation(enableCacheInvalidation) {
        }


        void AddStoreValueCalcer(TStringBuf variable, THolder<IStoreValueCalcer>&& calcer, bool strict) {
            if (Calcers.find(variable) != Calcers.end()) {
                Y_ENSURE(!strict, "duplicate store directive for a variable " << variable);
                Calcers.erase(variable);
            }
            Calcers.emplace(variable, std::move(calcer));
        }

        bool DependsOnDoc(TStringBuf name) {
            const auto it = Calcers.find(name);
            if (it == Calcers.end()) {
                return false;
            }
            return it->second->DependsOnDoc();
        }

        template<typename T>
        const T* Get(TStringBuf name, TCalcFactorsContext& ctx) {
            const auto it = Calcers.find(name);
            if (it == Calcers.end()) {
                return nullptr;
            }
            auto result = it->second->Get(ctx, T::TypeId, EnableCacheInvalidation);
            if (result == nullptr) {
                return nullptr;
            }
            Y_ENSURE(result->GetTypeId() == T::TypeId); // we might have a coredump otherwise
            return reinterpret_cast<const T*>(result);
        }

    private:
        THashMap<TString, THolder<IStoreValueCalcer>> Calcers;
        bool EnableCacheInvalidation;
    };
} // namespace NRTYFactors
