#include "jni_utils.h"
#include "ru_yandex_crypta_graph_engine_score_stats_EngineHelper.h"
#include <crypta/graph/engine/proto/graph.pb.h>
#include <crypta/graph/engine/score/proto/stats.pb.h>
#include <crypta/graph/engine/score/native/lib/stats.h>
#include <library/cpp/protobuf/yql/descriptor.h>


namespace {
    /// Per-call helper bundling the JNI handles needed to exchange
    /// NCrypta::NGraphEngine protobuf messages with Java.
    ///
    /// On construction it resolves the Java class
    /// `ru/yandex/crypta/graph/engine/proto/TStats` and its static
    /// `parseFrom(byte[])` method. The class handle is a JNI *local* reference,
    /// released in the destructor. Like every local reference, it is only valid
    /// inside the native call (and thread) whose JNIEnv created it. The object is
    /// therefore non-copyable and should live on the stack of a single native call.
    class TJniData {
    public:
        /// @param env JNI environment of the calling thread.
        /// @throws (via Y_ENSURE) if the TStats class or its parseFrom method
        ///         cannot be resolved. Any Java exception raised by the lookup is
        ///         reported to stderr and cleared first, so the caller may safely
        ///         raise its own Java exception afterwards.
        explicit TJniData(JNIEnv *env) : Env(env) {
            StatsClass = env->FindClass("ru/yandex/crypta/graph/engine/proto/TStats");
            if (!StatsClass) {
                // FindClass leaves a pending Java exception. Most JNI functions
                // (including the ones needed to throw a new exception) must not be
                // called while one is pending, so report it and clear it before
                // turning this failure into a C++ exception.
                env->ExceptionDescribe();
            }
            Y_ENSURE(StatsClass, "Class not found");

            StatsParseFrom = env->GetStaticMethodID(StatsClass, "parseFrom", "([B)Lru/yandex/crypta/graph/engine/proto/TStats;");
            if (!StatsParseFrom) {
                // Same reasoning as above. In addition, the destructor will not run
                // when the constructor throws, so release the class reference here.
                env->ExceptionDescribe();
                env->DeleteLocalRef(StatsClass);
                StatsClass = nullptr;
            }
            Y_ENSURE(StatsParseFrom, "parseFrom not found");
        }

        // Owns a JNI local reference: copying would delete it twice.
        TJniData(const TJniData&) = delete;
        TJniData& operator=(const TJniData&) = delete;

        ~TJniData() {
            if (StatsClass) {
                Env->DeleteLocalRef(StatsClass);
                StatsClass = nullptr;
            }
        }

        /// Deserializes a Java byte[] into a native TGraph message.
        NCrypta::NGraphEngine::TGraph ParseGraph(const jbyteArray& graphData) {
            NCrypta::NGraphEngine::TGraph graph;
            NJNI::FillProtoFromJByteArray(Env, graphData, graph);
            return graph;
        }

        /// Deserializes a Java byte[] into a native TStatsOptions message.
        NCrypta::NGraphEngine::TStatsOptions ParseOptions(const jbyteArray& statsOptionsData) {
            NCrypta::NGraphEngine::TStatsOptions options;
            NJNI::FillProtoFromJByteArray(Env, statsOptionsData, options);
            return options;
        }

        /// Converts a native TStats message into a Java TStats object.
        /// It serializes the message and calls the Java `TStats.parseFrom(byte[])`.
        ///
        /// @return the Java object. Returns nullptr if the Java call raised an
        ///         exception, which is left pending so it propagates to the Java caller.
        /// @throws (via Y_ENSURE) if the call returned null without raising.
        jobject ConvertToJStats(const NCrypta::NGraphEngine::TStats& stats) {
            auto serialized = stats.SerializeAsString();
            NJNI::TJniByteArrayHolder array{Env, serialized};
            auto result = Env->CallStaticObjectMethod(StatsClass, StatsParseFrom, array.Get());
            if (Env->ExceptionCheck()) {
                // Leave the Java exception pending; the JVM rethrows it on return.
                return nullptr;
            }
            Y_ENSURE(result, "Failed to call method");
            return result;
        }

    private:
        jclass StatsClass = nullptr;
        jmethodID StatsParseFrom = nullptr;
        JNIEnv *Env = nullptr;
    };

}

/// Shared implementation of the collectProdStats / collectExpStats JNI entry points.
///
/// @param env               JNI environment of the calling thread.
/// @param graphData         Java byte[] parsed into NCrypta::NGraphEngine::TGraph.
/// @param statsOptionsData  Java byte[] parsed into NCrypta::NGraphEngine::TStatsOptions.
/// @param isExp             true selects CollectExpStats, false selects CollectProdStats.
/// @return a Java TStats object, or nullptr when the call failed. On failure the
///         caller is expected to find a pending Java exception: either one raised
///         by the JVM during conversion, or one produced by NJNI::Rethrow from a
///         C++ exception (NOTE(review): assumed from Rethrow's name; confirm in jni_utils).
jobject CollectStats(JNIEnv * env, jbyteArray graphData, jbyteArray statsOptionsData, bool isExp) {
    try {
        TJniData jni(env);
        NCrypta::NGraphEngine::TGraph graph = jni.ParseGraph(graphData);
        NCrypta::NGraphEngine::TStatsOptions statsOptions = jni.ParseOptions(statsOptionsData);
        const NCrypta::NGraphEngine::TStats stats = isExp
            ? CollectExpStats(graph, statsOptions)
            : CollectProdStats(graph, statsOptions);

        return jni.ConvertToJStats(stats);
    } catch (...) {
        // C++ exceptions must never unwind into the JVM; hand them to Java instead.
        NJNI::Rethrow(env);
    }
    return nullptr;
}

/// Native implementation of the static Java method
/// EngineHelper.collectProdStats(byte[] graph, byte[] statsOptions).
/// Delegates to CollectStats with the production (non-experimental) algorithm.
JNIEXPORT jobject JNICALL Java_ru_yandex_crypta_graph_engine_score_stats_EngineHelper_collectProdStats(
        JNIEnv * env, jclass /*myclass*/, jbyteArray graphData, jbyteArray statsOptionsData) {
    constexpr bool useExperimentalStats = false;
    return CollectStats(env, graphData, statsOptionsData, useExperimentalStats);
}

/// Native implementation of the static Java method
/// EngineHelper.collectExpStats(byte[] graph, byte[] statsOptions).
/// Delegates to CollectStats with the experimental algorithm.
JNIEXPORT jobject JNICALL Java_ru_yandex_crypta_graph_engine_score_stats_EngineHelper_collectExpStats(
        JNIEnv * env, jclass /*myclass*/, jbyteArray graphData, jbyteArray statsOptionsData) {
    constexpr bool useExperimentalStats = true;
    return CollectStats(env, graphData, statsOptionsData, useExperimentalStats);
}

/// Native implementation of the static Java method
/// EngineHelper.getYQlProtoFieldForStatsRecord().
///
/// @return a Java string with the YQL protobuf type config generated for
///         NCrypta::NGraphEngine::TStats. Returns nullptr on failure, in which
///         case a Java exception is expected to be pending.
JNIEXPORT jstring JNICALL Java_ru_yandex_crypta_graph_engine_score_stats_EngineHelper_getYQlProtoFieldForStatsRecord(JNIEnv* env, jclass /*myclass*/) {
    try {
        // Keep the generated config in a named variable so c_str() stays valid
        // for as long as NewStringUTF reads it.
        const auto config = GenerateProtobufTypeConfig<NCrypta::NGraphEngine::TStats>();
        return env->NewStringUTF(config.c_str());
    } catch (...) {
        // A C++ exception escaping a JNI entry point is undefined behaviour;
        // convert it into a Java exception like CollectStats does.
        NJNI::Rethrow(env);
    }
    return nullptr;
}

