#include "multiclass.h"

#include <rtline/util/algorithm/container.h>

// Runs the underlying model on one object and returns its raw outputs.
//
// floatFeatures / catFeatures describe a single object; each is wrapped with
// NContainer::Scalar before being handed to Model.Calc (presumably to present
// the object as a one-element batch -- confirm against NContainer::Scalar).
//
// The returned vector has Model.GetDimensionsCount() elements.
TVector<double> NCatboostCalcer::TMulticlassPredictor::CalcRaw(
    TConstArrayRef<float> floatFeatures,
    const TVector<TStringBuf>& catFeatures
) const {
    // Zero-filled buffer, one slot per model dimension; passed to Model.Calc
    // as its third argument and returned afterwards.
    TVector<double> rawValues(Model.GetDimensionsCount());
    Model.Calc(
        NContainer::Scalar(floatFeatures),
        NContainer::Scalar(catFeatures),
        rawValues
    );
    return rawValues;
}

// Picks the class with the largest raw value produced by CalcRaw.
//
// Returns an empty TMaybe when CalcRaw yields no values. Otherwise the result
// holds the index of the winning value (ClassId) and the value itself
// (RawValue). Ties are resolved in favour of the lowest index, because a later
// value replaces the current best only when it is strictly greater.
TMaybe<NCatboostCalcer::TMulticlassPredictor::TResult> NCatboostCalcer::TMulticlassPredictor::Predict(
    TConstArrayRef<float> floatFeatures,
    const TVector<TStringBuf>& catFeatures
) const {
    const TVector<double> scores = CalcRaw(floatFeatures, catFeatures);

    TMaybe<TResult> best;
    if (scores.empty()) {
        return best;
    }

    // Seed the running best with the first class, then scan the remainder.
    size_t classIdx = 0;
    best.ConstructInPlace();
    best->ClassId = classIdx;
    best->RawValue = scores[classIdx];

    for (++classIdx; classIdx < scores.size(); ++classIdx) {
        // Compare against the stored RawValue (not the source vector) so the
        // comparison sees exactly what the result field holds.
        if (best->RawValue < scores[classIdx]) {
            best->RawValue = scores[classIdx];
            best->ClassId = classIdx;
        }
    }
    return best;
}

// Serializes the predictor to `out` by delegating to Model.Save; the wrapped
// model is the only thing written here.
void NCatboostCalcer::TMulticlassPredictor::Save(IOutputStream *out) const {
    Model.Save(out);
}

// Restores the predictor from `in` by delegating to Model.Load.
// NOTE(review): presumably `in` must contain data previously produced by
// Save -- confirm against the model's serialization contract.
void NCatboostCalcer::TMulticlassPredictor::Load(IInputStream* in) {
    Model.Load(in);
}
