#include "cb_calcer.h"

#include <saas/rtyserver/factors/factors_config.h>

#include <catboost/libs/cat_feature/cat_feature.h>

#include <util/stream/file.h>

using namespace NRTYServer;

void TCatboostFlatCalcer::LoadForSaas(const TFsPath& f) {
    f.CheckExists();
    {
        TFileInput fi(f);
        Load(&fi);
    }

    Y_ENSURE(GetSlices().empty(), "SaaS catboost model must have no slices");

    const TModelTrees* model = GetModel().ModelTrees.Get();
    Y_VERIFY(model);

    Y_ENSURE(model->GetTextFeatures().empty(), "SaaS catboost models supports only two types of features: float and categorical");

    // enumerate Cat input indexes in the "flat" vector
    CategoricalFeatureFlatIndexes(CatFeaturesFlatIndexes_);

    // cache the expected size of the "flat" input vector (float factors + cats)
    FlatInputSize_ = model->GetFlatFeatureVectorExpectedSize();
}

/// Configures the ranker from its JSON section.
///
/// @param owner          factors config; resolves the model file name into a full path.
/// @param rankModelConf  JSON object; must contain a string property "model" (path to the model file).
/// @throws if "model" is missing or is not a string, or if loading the model fails.
void TCatboostRelev::InitConfig(const NRTYFactors::TConfig& owner, const NJson::TJsonValue& rankModelConf) {
    Y_ENSURE(rankModelConf.Has("model"), "'model' property must be set"); // path to file
    const NJson::TJsonValue& modelProp = rankModelConf["model"];
    // GetString() quietly yields an empty string for non-string values, which would be
    // resolved into a bogus model path and fail later with a confusing error.
    Y_ENSURE(modelProp.IsString(), "'model' property must be a string");
    const TFsPath modelPath = owner.GetModelPath(modelProp.GetString());
    Init(modelPath);
}

/// Adds the flat-vector index of every categorical feature of the loaded model to @p factors.
/// Aborts (Y_VERIFY) if no model trees are loaded.
void TCatboostRelev::GetUsedCatFactors(TSet<ui32>& factors) const {
    const TModelTrees* const trees = Calcer_.GetModel().ModelTrees.Get();
    Y_VERIFY(trees);

    const TConstArrayRef<TCatFeature> cats = trees->GetCatFeatures();
    for (const TCatFeature& cat : cats) {
        factors.insert(cat.Position.FlatIndex);
    }
}

/// Adds the flat-vector indexes of all model inputs (categorical first, then float) to @p factors.
/// Aborts (Y_VERIFY) if no model trees are loaded.
/// @return always true.
bool TCatboostRelev::GetUsedFactors(TSet<ui32>& factors) const {
    GetUsedCatFactors(factors);

    const TModelTrees* const trees = Calcer_.GetModel().ModelTrees.Get();
    Y_VERIFY(trees);

    const TConstArrayRef<TFloatFeature> floats = trees->GetFloatFeatures();
    for (const TFloatFeature& floatFeature : floats) {
        factors.insert(floatFeature.Position.FlatIndex);
    }
    return true;
}

void TCatboostRelev::Init(const TFsPath& modelFile) {
    Calcer_.LoadForSaas(modelFile);
}

void TCatboostRelev::TransformCat(float& storage) {
    ui32 oldVl = (ui32)storage;
    if ((float)oldVl == storage) {
        //TODO(SAASSUP-1202): this is not a good solution. It consumes CPU. Could we invent something better?
        TString sCat = "c" + ToString(oldVl);            // Adds "c" prefix, because catboost requires a non-decimal here
        ui32 hashVal = CalcCatFeatureHash(sCat);         // calculate CityHash
        storage = ConvertCatFeatureHashToFloat(hashVal); // CatBoost does reinterpret_cast here
    } else {
        ythrow yexception() << "incorrect cat factor value";
    }
}

/// Hashes, in place, every categorical factor of @p nDocs documents (see TransformCat).
///
/// The original values are saved into @p undoBuffer (document-major, in the order of
/// f.GetCatFeaturesFlatIndexes()) so that RestoreCats can undo the transformation.
///
/// @param f           calcer providing the flat indexes of the categorical inputs.
/// @param factors     array of @p nDocs per-document factor rows, modified in place.
/// @param nDocs       number of documents in @p factors.
/// @param undoBuffer  cleared, then filled with nCats * nDocs original values.
/// @throws whatever TransformCat throws on an invalid categorical value.
void TCatboostRelev::TransformCats(const TCatboostFlatCalcer& f, float** factors, const size_t nDocs, TVector<float>& undoBuffer) {
    // Bind by reference: if the getter returns a reference this avoids copying the index
    // container on every call; if it returns by value, the temporary's lifetime is extended.
    const auto& catIndexes = f.GetCatFeaturesFlatIndexes();
    const size_t nCats = catIndexes.size();
    undoBuffer.clear();
    undoBuffer.reserve(nCats * nDocs);
    // Save everything first, so undoBuffer is complete even if a TransformCat call below throws.
    for (size_t i = 0; i < nDocs; ++i) {
        for (ui32 idx : catIndexes) {
            undoBuffer.push_back(factors[i][idx]);
        }
    }
    for (size_t i = 0; i < nDocs; ++i) {
        for (ui32 idx : catIndexes) {
            TransformCat(factors[i][idx]);
        }
    }
}

/// Writes the original categorical factor values saved by TransformCats back into @p factors.
///
/// @param f           the same calcer passed to TransformCats (defines the index order).
/// @param factors     array of @p nDocs per-document factor rows, restored in place.
/// @param nDocs       number of documents; must match the TransformCats call.
/// @param undoBuffer  values saved by TransformCats; its size must equal nCats * nDocs
///                    (checked only by Y_ASSERT).
void TCatboostRelev::RestoreCats(const TCatboostFlatCalcer& f, float** factors, const size_t nDocs, TVector<float>& undoBuffer) {
    // Bind by reference to avoid copying the index container on every call.
    const auto& catIndexes = f.GetCatFeaturesFlatIndexes();
    const size_t nCats = catIndexes.size();
    Y_ASSERT(undoBuffer.size() == nCats * nDocs);
    auto pUndo = undoBuffer.cbegin();
    // Same document-major / index order as the save loop in TransformCats.
    for (size_t i = 0; i < nDocs; ++i) {
        for (ui32 idx : catIndexes) {
            factors[i][idx] = *pUndo++;
        }
    }
}

// Registers TCatboostRelev in the IUserRanking factory under the key "catboost"
// (presumably the name used to select this ranker in configuration — verify against the factory).
NRTYFactors::IUserRanking::TFactory::TRegistrator<TCatboostRelev> TCatboostRelev::Registrator("catboost");
