#pragma once

#include "catboost_features_calculator.h"

#include <catboost/libs/model/model.h>

#include <util/generic/maybe.h>
#include <util/generic/string.h>
#include <util/generic/vector.h>


namespace NCrypta {
    /// Result of TBigbCatboostApplier::Apply.
    ///
    /// Field order matters: Apply fills this struct with designated initializers.
    struct TBigbCatboostApplierResult {
        /// One integer weight per model output, derived from the predicted probabilities.
        TVector<ui64> Weights{};
        /// Index of the selected output, or empty when no output passed its threshold.
        TMaybe<ui64> Class{};
    };

    /// Applies CatBoost models to features derived from a single profile.
    ///
    /// Features are computed lazily on the first Apply() call and cached, so one
    /// applier instance can evaluate several models against the same profile.
    ///
    /// The applier keeps references to the features calculator and the profile;
    /// both must outlive it. The feature cache is filled without any
    /// synchronization.
    template <typename TProfile>
    class TBigbCatboostApplier {
    public:
        /// Model inputs computed from the profile.
        struct TFeatures {
            TVector<float> FloatFeatures;
            TString QueriesText;
        };

        TBigbCatboostApplier(const TCatboostFeaturesCalculator& catboostFeaturesCalculator, const TProfile& profile)
            : CatboostFeaturesCalculator(catboostFeaturesCalculator)
            , Profile(profile) {
        }

        /// Evaluates `model` on the (cached) profile features.
        ///
        /// @param model       Model to apply. Taken by const reference so the whole
        ///                    TFullModel is not copied on every call.
        /// @param thresholds  Per-output thresholds; its size determines how many
        ///                    probabilities are requested from the evaluator.
        ///                    NOTE(review): assumed to match the model's output count
        ///                    and to contain positive values — confirm with callers.
        /// @return Weights (probability * 1e6, truncated) for every output, plus the
        ///         index of the output with the largest prob/threshold ratio when
        ///         that ratio is at least 1.
        template <typename TThresholds>
        TBigbCatboostApplierResult Apply(const TFullModel& model, const TThresholds& thresholds) {
            const auto& features = GetFeatures();

            // A single document: one row of float features, no categorical
            // features, one row of text features.
            TConstArrayRef<float> floatFeaturesArray = MakeArrayRef(features.FloatFeatures);
            const TVector<TStringBuf> textFeatures{features.QueriesText};
            TConstArrayRef<TStringBuf> textFeaturesArray = MakeArrayRef(textFeatures);

            TVector<double> probs(thresholds.size(), 0.0);
            const auto customEval = GetCustomEvaluator(model);
            customEval->Calc(MakeArrayRef(&floatFeaturesArray, 1), {}, MakeArrayRef(&textFeaturesArray, 1), probs);

            return TBigbCatboostApplierResult{
                .Weights = ConvertProbsToWeights(probs),
                .Class = GetClass(probs, thresholds),
            };
        }

    private:
        /// Returns the profile features, computing them on first use.
        const TFeatures& GetFeatures() {
            if (!Features.Defined()) {
                Features = TFeatures{
                    .FloatFeatures = CatboostFeaturesCalculator.PrepareFloatFeatures(Profile),
                    .QueriesText = CatboostFeaturesCalculator.PrepareTextFeatures(Profile),
                };
            }

            return *Features;
        }

        /// Returns a clone of the model's current evaluator switched to
        /// Probability output. Working on a clone leaves the prediction type of the
        /// model's own evaluator untouched.
        static NCB::NModelEvaluation::TConstModelEvaluatorPtr GetCustomEvaluator(const TFullModel& model) {
            auto customEval = model.GetCurrentEvaluator()->Clone();
            customEval->SetPredictionType(NCB::NModelEvaluation::EPredictionType::Probability);
            return customEval;
        }

        /// Picks the output whose probability exceeds its threshold by the largest factor.
        ///
        /// @return Index of the maximal prob/threshold ratio when that ratio is >= 1,
        ///         Nothing() otherwise (including when there are no outputs).
        template <typename TThresholds>
        static TMaybe<ui64> GetClass(const TVector<double>& probs, const TThresholds& thresholds) {
            TVector<double> ratios;
            ratios.reserve(probs.size());
            for (const auto& [prob, threshold] : Zip(probs, thresholds)) {
                // NOTE(review): a zero threshold yields inf/NaN here; thresholds are
                // assumed to be positive.
                ratios.push_back(prob / threshold);
            }

            // Dereferencing the max element of an empty range would be undefined behavior.
            if (ratios.empty()) {
                return Nothing();
            }

            const auto maxElem = MaxElement(ratios.begin(), ratios.end());
            const auto argmax = static_cast<ui64>(std::distance(ratios.begin(), maxElem));
            return (*maxElem >= 1.0) ? TMaybe<ui64>(argmax) : Nothing();
        }

        /// Converts probabilities to integer weights: prob * 1e6, truncated toward zero.
        static TVector<ui64> ConvertProbsToWeights(const TVector<double>& probs) {
            // Fixed-point scale: weight 1'000'000 corresponds to probability 1.0.
            constexpr double probToWeightScale = 1000000.0;

            TVector<ui64> weights;
            weights.reserve(probs.size());
            for (const auto prob : probs) {
                // Explicit conversion with the same truncation as an implicit one.
                // NOTE(review): assumes prob is in [0, 1]; values <= -1 after scaling
                // would be undefined behavior for an unsigned target.
                weights.push_back(static_cast<ui64>(prob * probToWeightScale));
            }
            return weights;
        }

        // Non-owning references: both referents must outlive the applier.
        const TCatboostFeaturesCalculator& CatboostFeaturesCalculator;
        const TProfile& Profile;
        // Lazily filled by GetFeatures().
        TMaybe<TFeatures> Features;
    };
}
