#include <mail/so/libs/jniwrapper_base/jniwrapper_base.h>

#include <catboost/libs/model/model.h>
#include <library/cpp/json/json_reader.h>

#include <util/stream/file.h>
#include <catboost/libs/cat_feature/cat_feature.h>
#include <catboost/libs/column_description/cd_parser.h>
#include <jni.h>


// Token-paste a method name into the JNI symbol name expected for native
// methods of ru.yandex.jni.catboost.JniCatboostModel and
// ru.yandex.jni.catboost.JniCatboostFeatures respectively
// (JNI convention: Java_<package>_<class>_<method>).
#define JniCatboostMethod(methodName) Java_ru_yandex_jni_catboost_JniCatboostModel_ ## methodName
#define JniFeaturesMethod(methodName) Java_ru_yandex_jni_catboost_JniCatboostFeatures_ ## methodName

/// Named feature values for a single object, filled from Java and passed to
/// TJniCatboost::Calc. Names are matched against column ids of the CD file.
struct TJniFeatures {
    // Feature name -> numeric value.
    THashMap<TString, float> NumericFeatures;
    // Feature name -> categorical value already hashed (setCategoricalFeature
    // stores the result of CalcCatFeatureHash here).
    THashMap<TString, int> CategoricalFeatures;
};

/// A loaded CatBoost model plus the mapping from feature names (column ids of
/// the column description file) to per-kind feature indices.
class TJniCatboost {
private:
	TFullModel Model;
	// Column id -> index among Num columns, assigned in CD order.
	THashMap<TString, size_t> NumericDict;
	// Column id -> index among Categ columns, assigned in CD order.
	THashMap<TString, size_t> CategoricalDict;

public:
	/// @param modelStream serialized model, read with TFullModel::Load.
	/// @param cdStream column description, parsed with ReadCD.
	/// Aborts via Y_VERIFY if the CD declares more categorical or numeric
	/// columns than the model reports.
	TJniCatboost(IInputStream& modelStream, IInputStream& cdStream) {
		Model.Load(&modelStream);

		const TVector<TColumn> columns = ReadCD(&cdStream);

		size_t numSize = 0;
		size_t catSize = 0;
		for (const TColumn& column : columns) {
			// Indices are counted separately per kind, so they must also be
			// looked up separately; a shared map allowed a categorical index to
			// be applied to the numeric vector (and vice versa).
			if (column.Type == EColumn::Categ) {
				CategoricalDict.emplace(column.Id, catSize++);
			} else if (column.Type == EColumn::Num) {
				NumericDict.emplace(column.Id, numSize++);
			}
			// Other column types (label, auxiliary, ...) are not model inputs here.
		}

		// Guarantees every stored index is in range of the vectors built in Calc.
		Y_VERIFY(catSize <= Model.GetNumCatFeatures());
		Y_VERIFY(numSize <= Model.GetNumFloatFeatures());
	}

	/// Evaluates the model on one object.
	/// Unknown feature names are ignored; features that are not provided stay
	/// at 0 (numeric value 0.0, categorical hash 0).
	/// @return raw model output, Model.GetDimensionsCount() values.
	[[nodiscard]] TVector<double> Calc(const TJniFeatures& featuresMap) const {
		TVector<float> numericFeatures(Model.GetNumFloatFeatures(), 0.f);
		TVector<int> categoricalFeatures(Model.GetNumCatFeatures(), 0);

		for (const auto& [feature, value] : featuresMap.NumericFeatures) {
			if (const size_t* index = MapFindPtr(NumericDict, feature)) {
				numericFeatures[*index] = value;
			}
		}

		for (const auto& [feature, value] : featuresMap.CategoricalFeatures) {
			if (const size_t* index = MapFindPtr(CategoricalDict, feature)) {
				categoricalFeatures[*index] = value;
			}
		}

		TVector<double> results(Model.GetDimensionsCount(), 0.);
		Model.Calc(numericFeatures, categoricalFeatures, results);

		return results;
	}
};


extern "C" JNIEXPORT jlong JNICALL
JniCatboostMethod(createInstance)(
    JNIEnv* env,
    jclass,
    jstring pathToModel,
    jstring pathToDict) {
    // Opens the model file and the column description file, builds a
    // TJniCatboost and hands its address to Java as an opaque handle
    // (freed by destroyInstance). Any C++ exception is passed to
    // NJniWrapper::RethrowAsJavaException and 0 is returned.
    try {
        TIFStream modelInput(NJniWrapper::JStringToUtf(env, pathToModel));
        TIFStream cdInput(NJniWrapper::JStringToUtf(env, pathToDict));
        auto holder = MakeHolder<TJniCatboost>(modelInput, cdInput);
        const jlong handle = reinterpret_cast<jlong>(holder.Release());
        return handle;
    } catch (...) {
        NJniWrapper::RethrowAsJavaException(env);
    }
    return 0;
}

extern "C" JNIEXPORT void JNICALL
JniCatboostMethod(destroyInstance)(
    JNIEnv*,
    jclass,
    jlong instance)
{
    // Releases a handle returned by createInstance; a 0 handle deletes
    // nullptr, which is a no-op.
    TJniCatboost* const model = reinterpret_cast<TJniCatboost*>(instance);
    delete model;
}

extern "C" JNIEXPORT jlong JNICALL
JniFeaturesMethod(createInstance)(
		JNIEnv* env,
		jclass) {
	// Allocates an empty feature set and returns its address as an opaque
	// handle for Java (freed by JniFeaturesMethod(destroyInstance)).
	// On allocation failure the exception goes to RethrowAsJavaException
	// and 0 is returned.
	try {
		return reinterpret_cast<jlong>(new TJniFeatures);
	} catch (...) {
		NJniWrapper::RethrowAsJavaException(env);
	}
	return 0;
}

extern "C" JNIEXPORT void JNICALL
JniFeaturesMethod(destroyInstance)(
		JNIEnv*,
		jclass,
		jlong instance)
{
	// Frees a feature-set handle; a 0 handle deletes nullptr (no-op).
	TJniFeatures* const features = reinterpret_cast<TJniFeatures*>(instance);
	delete features;
}

extern "C" JNIEXPORT void JNICALL
JniFeaturesMethod(setNumericFeature)(
		JNIEnv* env,
		jclass,
		jlong instance,
		jstring feature,
		jfloat value)
{
	// Sets (or replaces) the numeric value of the named feature in the
	// feature set identified by `instance`.
	auto* features = reinterpret_cast<TJniFeatures*>(instance);
	try {
		// operator[] overwrites an existing entry. The previous emplace kept
		// the first value ever set for a name, so updating a feature on a
		// reused feature set silently had no effect.
		features->NumericFeatures[NJniWrapper::JStringToUtf(env, feature)] = value;
	} catch (...) {
		NJniWrapper::RethrowAsJavaException(env);
	}
}

extern "C" JNIEXPORT void JNICALL
JniFeaturesMethod(setCategoricalFeature)(
        JNIEnv* env,
        jclass,
        jlong instance,
        jstring feature,
        jstring value)
{
    // Sets (or replaces) the categorical value of the named feature. The
    // string value is stored as its CalcCatFeatureHash, the form passed to
    // the model.
    auto* features = reinterpret_cast<TJniFeatures*>(instance);
    try {
        const int hash = static_cast<int>(CalcCatFeatureHash(NJniWrapper::JStringToUtf(env, value)));
        // operator[] overwrites an existing entry. emplace kept the first
        // value set for a name, so re-setting a feature was silently ignored.
        features->CategoricalFeatures[NJniWrapper::JStringToUtf(env, feature)] = hash;
    } catch (...) {
        NJniWrapper::RethrowAsJavaException(env);
    }
}

extern "C" JNIEXPORT jdoubleArray JNICALL
JniCatboostMethod(calc)(
    JNIEnv* env,
    jclass,
    jlong modelInstance,
    jlong featuresInstance)
{
    // Evaluates the model handle on the feature-set handle and returns the
    // raw model output as a Java double[]. Returns null if a Java exception
    // is pending: either array allocation failed or a C++ exception was
    // rethrown via RethrowAsJavaException.
    const auto* model = reinterpret_cast<const TJniCatboost*>(modelInstance);
    const auto* features = reinterpret_cast<const TJniFeatures*>(featuresInstance);

    try {
        const TVector<double> result = model->Calc(*features);
        const jsize size = static_cast<jsize>(result.size());
        jdoubleArray jResult = env->NewDoubleArray(size);
        if (jResult == nullptr) {
            // Per the JNI spec, NewDoubleArray returns NULL with an
            // OutOfMemoryError pending; writing into it would be undefined.
            return nullptr;
        }
        env->SetDoubleArrayRegion(jResult, 0, size, result.data());
        return jResult;
    } catch (...) {
        NJniWrapper::RethrowAsJavaException(env);
        return nullptr;
    }
}

