#include "predict.h"

#include "archive.h"
#include "dataset.h"

// Builds the predictor from the "/model.bin" archive entry: the entry is
// fetched into a blob, deserialized into CatboostFormulaEvaluator, and a clone
// of that model's current evaluator is switched to the Probability prediction
// type and kept in ProbabilityEvaluator.
TClfPredictor::TClfPredictor() {
    using EPredictionType = NCB::NModelEvaluation::EPredictionType;

    // Fetch the serialized model bytes for the archive key.
    TBlob modelBlob;
    LoadDataFromArchive("/model.bin", modelBlob);

    // Deserialize the model straight from the in-memory bytes.
    TMemoryInput modelInput(modelBlob.Data(), modelBlob.Length());
    Load(&modelInput, CatboostFormulaEvaluator);

    // Work on a clone so the prediction type is changed on a separate
    // evaluator object rather than on the one returned by the model.
    ProbabilityEvaluator = CatboostFormulaEvaluator.GetCurrentEvaluator()->Clone();
    ProbabilityEvaluator->SetPredictionType(EPredictionType::Probability);
}

// Batch prediction.
//
// For every hostname a feature vector is requested from ClfDataset, then all
// vectors are evaluated in a single CalcFlat call on ProbabilityEvaluator.
// The returned vector has exactly hostnames.size() elements; element i holds
// the evaluator output for hostnames[i].
TVector<double> TClfPredictor::Predict(const TVector<TString> &hostnames) const {
    const size_t hostCount = hostnames.size();

    // One (initially empty) feature row per hostname, filled in input order.
    TVector<TVector<float>> featureRows(hostCount);
    size_t row = 0;
    for (const TString& host : hostnames) {
        ClfDataset.GetFeatures(host, featureRows[row]);
        ++row;
    }

    // Output buffer is pre-sized: one value per input hostname.
    TVector<double> predictions(hostCount);
    ProbabilityEvaluator->CalcFlat(featureRows, predictions);
    return predictions;
}

double TClfPredictor::Predict(const TString &hostname) const {
    TVector<TString> hostnames;
    hostnames.push_back(hostname);
    TVector<double> probability = Predict(hostnames);
    return probability[0];
}

// Maps a probability to a class label: returns 1 when the probability is at
// least 0.9 and 0 otherwise (a NaN input compares false and therefore gives 0).
int TClfPredictor::GetClass(double probability) const {
    constexpr double positiveThreshold = 0.9;
    if (probability >= positiveThreshold) {
        return 1;
    }
    return 0;
}
