#include "catboost_applier.h"

#include <crypta/lib/native/proto_serializer/proto_serializer.h>
#include <util/stream/file.h>

namespace {
    // Reads the whole file at `path` and parses its contents as JSON into a
    // THerschelCatboostApplierScore message (used as per-key-type thresholds).
    NCrypta::THerschelCatboostApplierScore ThresholdsFromJson(const TString& path) {
        NCrypta::THerschelCatboostApplierScore thresholds;
        const TString json = TFileInput(path).ReadAll();
        NProtobufJson::Json2Proto(json, thresholds);
        return thresholds;
    }
}

namespace NCrypta {

    /// Builds an applier from an already-parsed set of thresholds.
    /// @param lookup      Geobase lookup handed to the features calculator.
    /// @param modelFile   Path of the CatBoost model, loaded with ReadModel().
    /// @param thresholds  Per-key-type cut-offs that Apply() compares scores against.
    THerschelCatboostApplier::THerschelCatboostApplier(const NGeobase::TLookup& lookup, const TString& modelFile, const THerschelCatboostApplierScore& thresholds)
        : FeaturesCalculator(lookup)
        , Model(ReadModel(modelFile))
        , Thresholds(thresholds) {
    }

    /// Delegating constructor: parses `thresholdsFile` as a JSON-encoded
    /// THerschelCatboostApplierScore and forwards it as the thresholds.
    THerschelCatboostApplier::THerschelCatboostApplier(const NGeobase::TLookup& lookup, const TString& modelFile, const TString& thresholdsFile)
        : THerschelCatboostApplier(lookup, modelFile, ThresholdsFromJson(thresholdsFile)) {
    }

    /// Returns a fresh clone of the model's current evaluator configured to emit
    /// Probability predictions.
    /// The prediction type is changed on the clone rather than on the evaluator
    /// returned by Model.GetCurrentEvaluator().
    /// NOTE(review): a clone is made on every call, i.e. twice per Score(); confirm
    /// whether caching a single configured evaluator would be safe here.
    NCB::NModelEvaluation::TConstModelEvaluatorPtr THerschelCatboostApplier::GetEvaluator() const {
        auto eval = Model.GetCurrentEvaluator()->Clone();
        eval->SetPredictionType(NCB::NModelEvaluation::EPredictionType::Probability);
        return eval;
    }

    /// Evaluates the model on a single document and returns the first prediction value.
    /// NOTE(review): assumes the model produces one output per document (e.g. a binary
    /// classifier); the result buffer has room for exactly one value.
    double THerschelCatboostApplier::ApplyCatboost(const TFeatures& features) const {
        // Hold the evaluator pointer by value instead of relying on temporary lifetime extension.
        const auto eval = GetEvaluator();

        TConstArrayRef<float> floatFeaturesArrayRef = MakeArrayRef(features.FloatFeatures);
        // Calc<TStringBuf> expects string views, so build a contiguous view array over the categorical values.
        const TVector<TStringBuf> categoricalFeaturesArray(features.CategoricalFeatures.begin(), features.CategoricalFeatures.end());
        TConstArrayRef<TStringBuf> categoricalFeaturesArrayRef = MakeArrayRef(categoricalFeaturesArray);

        // The evaluator works on batches: pass this document as a batch of one row.
        TVector<double> probs(1, 0);
        eval->Calc<TStringBuf>(MakeArrayRef(&floatFeaturesArrayRef, 1), MakeArrayRef(&categoricalFeaturesArrayRef, 1), probs);
        return probs[0];
    }

    /// Computes raw model scores for both key types.
    /// Features are calculated once, then the key-type specific features are updated
    /// in place for Ip and scored, and then updated again for IpUseragent and scored.
    /// The second update is applied to the already Ip-updated features, so the order matters.
    THerschelCatboostApplierScore THerschelCatboostApplier::Score(const TStringBuf& ip, const TStringBuf& useragent, const THerschelStats& herschelStats) const {
        auto features = FeaturesCalculator.GetFeatures(ip, useragent, herschelStats);

        THerschelCatboostApplierScore score;

        FeaturesCalculator.UpdateKeyTypeFeatures(features, THerschelFeaturesCalculator::KeyTypeFeaturesByKeyType.Ip);
        score.SetIp(ApplyCatboost(features));

        FeaturesCalculator.UpdateKeyTypeFeatures(features, THerschelFeaturesCalculator::KeyTypeFeaturesByKeyType.IpUseragent);
        score.SetIpUseragent(ApplyCatboost(features));

        return score;
    }

    /// Scores the input and converts each score to a boolean verdict.
    /// A key type is flagged only when its score is strictly greater than its threshold.
    THerschelCatboostApplierResult THerschelCatboostApplier::Apply(const TStringBuf& ip, const TStringBuf& useragent, const THerschelStats& herschelStats) const {
        const auto score = Score(ip, useragent, herschelStats);

        THerschelCatboostApplierResult result;
        result.SetIp(score.GetIp() > Thresholds.GetIp());
        result.SetIpUseragent(score.GetIpUseragent() > Thresholds.GetIpUseragent());

        return result;
    }
}
