#pragma once

#include <saas/rtyserver/factors/rank_model.h>
#include <kernel/catboost/catboost_calcer.h>

namespace NRTYServer {
    // CatBoost calcer that feeds the model "flat" feature rows: every document is
    // passed as one contiguous float array of FlatInputSize_ elements.
    class TCatboostFlatCalcer final: public NCatboostCalcer::TCatboostCalcer {
    public:
        // NMatrixNet::IRelevCalcer
        // Scores a single document by delegating to the batch entry point with numDocs == 1.
        double DoCalcRelev(const float* features) const override {
            double result = 0.0;
            DoCalcRelevs(&features, &result, 1);
            return result;
        }

        // Defined out of line. NOTE(review): presumably loads the model from `f` and fills
        // FlatInputSize_ / CatFeaturesFlatIndexes_ - confirm against the implementation.
        void LoadForSaas(const TFsPath& f);

        // Scores numDocs documents in one CalcFlat() call.
        //   docsFeatures - numDocs pointers; each one is viewed as FlatInputSize_ floats,
        //                  so it must point to at least that many readable elements.
        //   resultRelev  - output array with room for numDocs values.
        // Does nothing when numDocs == 0.
        void DoCalcRelevs(const float* const* docsFeatures, double* resultRelev, const size_t numDocs) const override {
            if (Y_UNLIKELY(!numDocs))
                return;

            // Wrap raw per-document pointers into array refs without copying the features.
            TVector<TConstArrayRef<float>> flatFeaturesVec(numDocs);
            for (size_t doc = 0; doc < numDocs; ++doc) {
                flatFeaturesVec[doc] = MakeArrayRef(docsFeatures[doc], FlatInputSize_);
            }
            GetModel().CalcFlat(flatFeaturesVec, MakeArrayRef(resultRelev, numDocs));
        }

        // True when at least one categorical feature index is registered.
        bool HasCatFeatures() const {
            return !CatFeaturesFlatIndexes_.empty();
        }

        // Non-owning view over the stored categorical feature indexes;
        // valid only while this object is alive and the vector is not modified.
        TConstArrayRef<ui32> GetCatFeaturesFlatIndexes() const {
            return TConstArrayRef<ui32>(CatFeaturesFlatIndexes_.begin(), CatFeaturesFlatIndexes_.end());
        }

    private:
        // Length of one flat feature row. Zero-initialized so that using the calcer
        // before it has been set never reads an indeterminate value.
        size_t FlatInputSize_ = 0;
        // Indexes of categorical features within the flat feature row.
        TVector<ui32> CatFeaturesFlatIndexes_;
    };

    // User ranking model backed by a CatBoost calcer: it produces relevance values
    // only (HasRelevance() == true) and never computes factors (HasFactors() == false).
    class TCatboostRelev: public NRTYFactors::IUserRanking {
    public:
        void InitConfig(const NRTYFactors::TConfig& relevConf, const NJson::TJsonValue& rankModelConf) override;

        bool HasRelevance() const override {
            return true; // CalcRelevance() should be called
        }

        bool HasFactors() const override {
            return false; // CalcFactors() should not be called
        }

        void CalcFactors(TCalcFactorsContext&) const override {
            Y_ASSERT(0); // should not be called
        }

        bool GetUsedFactors(TSet<ui32>& usedFactors) const override;

        void GetUsedCatFactors(TSet<ui32>& usedCatFactors) const;

        // The three helpers below are defined out of line.
        // NOTE(review): judging by the names and signatures, TransformCats rewrites
        // categorical slots in `factors` in place and saves the originals to `undoBuffer`,
        // and RestoreCats writes them back - confirm against the implementation.
        static void TransformCat(float& storage);

        static void TransformCats(const TCatboostFlatCalcer& f, float** factors, const size_t nDocs, TVector<float>& undoBuffer);

        static void RestoreCats(const TCatboostFlatCalcer& f, float** factors, const size_t nDocs, TVector<float>& undoBuffer);

        // Computes relevance for `count` documents.
        //   factors - `count` pointers to per-document factor rows; when the model has
        //             categorical features the rows are passed through TransformCats()
        //             before the calculation and through RestoreCats() after it.
        //   results - output array with room for `count` floats (model output narrowed
        //             from double).
        // If the model calculation throws, RestoreCats() is still invoked before the
        // exception is propagated.
        void CalcRelevance(float** factors, float* results, const size_t count) const override {
            const bool hasCats = Calcer_.HasCatFeatures();

            TVector<float> undoBuffer;
            if (hasCats)
                TransformCats(Calcer_, factors, count, undoBuffer);

            TVector<double> mxValues(count);
            try {
                Calcer_.DoCalcRelevs(factors, mxValues.data(), count);
            } catch (...) {
                // Do not leave the caller's factor rows in the transformed state on failure.
                if (hasCats)
                    RestoreCats(Calcer_, factors, count, undoBuffer);
                throw;
            }

            for (size_t i = 0; i < count; ++i)
                results[i] = static_cast<float>(mxValues[i]);

            if (hasCats)
                RestoreCats(Calcer_, factors, count, undoBuffer);
        }

    public:
        void Init(const TFsPath& modelFile);

    private:
        TCatboostFlatCalcer Calcer_;

    private:
        // Factory registration hook for this ranking implementation (defined out of line).
        static NRTYFactors::IUserRanking::TFactory::TRegistrator<TCatboostRelev> Registrator;
    };
}
