#include "ru_yandex_crypta_lib_experiments_Experiments.h"

#include <wordexp.h>

#include <limits>
#include <memory>
#include <vector>

#include <ads/bsyeti/exp_stats/lib/exp.h>
#include <ads/bsyeti/libs/experiments/builder.h>
#include <ads/bsyeti/libs/experiments/ab_experiments_source.h>
#include <crypta/lib/java/experiments/proto/experiment.pb.h>

//https://nachtimwald.com/2017/07/09/jni-is-not-your-friend/

namespace {
    void Rethrow(JNIEnv* env) {
        env->ExceptionClear();
        try {
            throw;
        } catch (const std::exception& e) {
            jclass jc = env->FindClass("java/lang/RuntimeException");
            if (jc) {
                env->ThrowNew(jc, e.what());
            }
        } catch (...) {
            jclass jc = env->FindClass("java/lang/RuntimeException");
            if (jc) {
                env->ThrowNew(jc, "Unhandled exception");
            }
        }
    }

    NBSYeti::TExperimentSystem* CastHandle(jlong handle) {
        return reinterpret_cast<NBSYeti::TExperimentSystem*>(handle);
    }

    std::shared_ptr<wordexp_t> ParseArgs(const char* args) {
        std::shared_ptr<wordexp_t> result(new wordexp_t(), wordfree);
        Y_ENSURE(!wordexp(args, result.get(), 0), "Failed to parse arguments");
        return result;
    }
}

/// RAII view of a Java string as a (modified UTF-8) `const char*`.
///
/// Acquires the characters with `GetStringUTFChars` and releases them with
/// `ReleaseStringUTFChars` on destruction. A null `jstring`, or a failed
/// acquisition, is never dereferenced: `c_str()` throws instead.
/// Non-copyable, because a copy would release the same buffer twice.
class TJniStringWrapper {
public:
    TJniStringWrapper(JNIEnv* env, jstring string)
        : Env(env)
        , JString(string)
        , String(nullptr)
    {
        // Passing a null jstring to GetStringUTFChars is illegal in JNI;
        // leave String null so that c_str() reports the problem.
        if (JString) {
            String = Env->GetStringUTFChars(JString, nullptr);
        }
    }

    TJniStringWrapper(const TJniStringWrapper&) = delete;
    TJniStringWrapper& operator=(const TJniStringWrapper&) = delete;

    ~TJniStringWrapper() {
        if (String) {
            Env->ReleaseStringUTFChars(JString, String);
        }
    }

    /// @return the string contents, valid for the lifetime of this wrapper.
    /// @throws yexception if the Java string was null or could not be acquired.
    const char* c_str() const {
        Y_ENSURE(String, "Incorrect string");
        return String;
    }

private:
    JNIEnv* Env;

    jstring JString;
    const char* String;
};

/// Native copy of a Java `byte[]`, readable through a `TMemoryInput`.
///
/// The bytes are copied once in the constructor, so the wrapper does not depend
/// on the Java array afterwards. Storage is a `std::vector`, which makes the class
/// safely copyable (the original raw `new[]` buffer would be double-deleted on copy).
class TJniByteArrayWrapper {
public:
    /// @throws yexception if `data` is null (GetArrayLength on null is illegal in JNI).
    TJniByteArrayWrapper(JNIEnv* env, jbyteArray data) {
        Y_ENSURE(data, "Byte array is null");
        const jsize length = env->GetArrayLength(data);
        Data.resize(static_cast<size_t>(length));
        if (length > 0) {
            env->GetByteArrayRegion(data, 0, length, reinterpret_cast<jbyte*>(Data.data()));
        }
    }

    /// @return a non-owning stream over the copied bytes; valid while this wrapper lives.
    TMemoryInput AsMemoryInput() const {
        return {Data.data(), Data.size()};
    }

private:
    std::vector<char> Data;
};

/// Owning wrapper for the `jclass` local reference returned by `FindClass`.
///
/// `Get()` returns null when the lookup failed; FindClass then leaves a Java
/// exception pending. The local reference is released on destruction.
/// Non-copyable, because a copy would delete the same local reference twice.
class TJniClass {
public:
    TJniClass(JNIEnv* env, const char* name)
        : Env(env)
        , Class(Env->FindClass(name))
    {
    }

    TJniClass(const TJniClass&) = delete;
    TJniClass& operator=(const TJniClass&) = delete;

    ~TJniClass() {
        if (Class) {
            Env->DeleteLocalRef(Class);
            Class = nullptr;
        }
    }

    jclass Get() const {
        return Class;
    }

private:
    JNIEnv* Env;
    jclass Class;
};

/// Owning wrapper for a freshly allocated Java `byte[]` filled from `data`.
///
/// The local reference is released on destruction. Non-copyable, because a copy
/// would delete the same local reference twice.
class TJniByteArrayHolder {
public:
    /// @throws yexception if `data` does not fit in a Java array or the allocation fails.
    TJniByteArrayHolder(JNIEnv* env, const TString& data)
        : Env(env)
        , Array(nullptr)
    {
        // Java array lengths are jsize (a signed 32-bit int); reject sizes that would narrow.
        Y_ENSURE(data.size() <= static_cast<size_t>(std::numeric_limits<jsize>::max()),
                 "Data is too large for a Java byte array");
        const auto length = static_cast<jsize>(data.size());

        Array = Env->NewByteArray(length);
        // On failure NewByteArray returns null with an OutOfMemoryError pending;
        // SetByteArrayRegion must not be called on a null array.
        Y_ENSURE(Array, "Failed to allocate Java byte array");
        if (length > 0) {
            Env->SetByteArrayRegion(Array, 0, length, reinterpret_cast<const jbyte*>(data.c_str()));
        }
    }

    TJniByteArrayHolder(const TJniByteArrayHolder&) = delete;
    TJniByteArrayHolder& operator=(const TJniByteArrayHolder&) = delete;

    ~TJniByteArrayHolder() {
        if (Array) {
            Env->DeleteLocalRef(Array);
            Array = nullptr;
        }
    }

    jbyteArray Get() const {
        return Array;
    }

private:
    JNIEnv* Env;
    jbyteArray Array;
};

/// JNI: builds a TExperimentSystem from the serialized AB experiment config in `abData`.
///
/// @param behData ignored (kept for ABI compatibility).
/// @return an opaque handle (the system pointer cast to jlong). It is presumably released
///         via `Experiments.destroy`, which deletes the pointer. Returns 0 with a pending
///         Java exception on failure.
JNIEXPORT jlong JNICALL Java_ru_yandex_crypta_lib_experiments_Experiments_load(JNIEnv* env, jclass, jbyteArray behData, jbyteArray abData) {
    Y_UNUSED(behData); // TODO: remove completely because beh experiments are disabled
    try {
        TJniByteArrayWrapper jniAbData{env, abData};
        auto abMemoryInput = jniAbData.AsMemoryInput();
        auto abSource = NBSYeti::CreateAbExperimentsSource(abMemoryInput, {.LongConfig = true, .RequireDefault = false});

        // Keep ownership in C++ until configuration succeeds, so that a throwing
        // AddExternalExperimentSource does not leak the system.
        auto system = std::make_unique<NBSYeti::TExperimentSystem>();
        system->AddExternalExperimentSource(abSource);

        // From here on the Java side owns the object through the handle.
        return reinterpret_cast<jlong>(system.release());
    } catch (...) {
        Rethrow(env);
    }
    return 0;
}

/// JNI: deletes the TExperimentSystem behind `handle`. A zero handle is ignored.
JNIEXPORT void JNICALL Java_ru_yandex_crypta_lib_experiments_Experiments_destroy(JNIEnv* env, jclass, jlong handle) {
    try {
        NBSYeti::TExperimentSystem* const system = CastHandle(handle);
        if (system == nullptr) {
            return;
        }
        delete system;
    } catch (...) {
        Rethrow(env);
    }
}

/// JNI: returns the united experiment parameters for `uniqId` at `timestamp`.
///
/// The parameters are merged into a single C++ proto, serialized, and handed to the
/// generated Java `parseFrom(byte[])`. On failure, returns null with a Java exception pending.
JNIEXPORT jobject JNICALL Java_ru_yandex_crypta_lib_experiments_Experiments_getExperimentParameters(JNIEnv* env, jclass, jlong handle, jlong uniqId, jlong timestamp) {
    try {
        auto* const system = CastHandle(handle);
        Y_ENSURE(system, "Invalid handle");

        auto parameters = system->GetParameters(uniqId, timestamp);
        auto united = parameters.UniteAll<NBSYeti::NExperimentParameters::TExperimentParameters>();
        TJniByteArrayHolder bytes{env, united->SerializeAsString()};

        TJniClass protoClass{env, "NBSYeti/NExperimentParameters/Experiment$TExperimentParameters"};
        Y_ENSURE(protoClass.Get(), "Class not found");

        const jmethodID parseFromMethod = env->GetStaticMethodID(
            protoClass.Get(),
            "parseFrom",
            "([B)LNBSYeti/NExperimentParameters/Experiment$TExperimentParameters;");
        Y_ENSURE(parseFromMethod, "parseFrom not found");

        jobject parsed = env->CallStaticObjectMethod(protoClass.Get(), parseFromMethod, bytes.Get());
        // Leave any exception thrown by parseFrom pending for the Java caller.
        if (env->ExceptionCheck()) {
            return nullptr;
        }
        Y_ENSURE(parsed, "Failed to call method");
        return parsed;
    } catch (...) {
        Rethrow(env);
    }
    return nullptr;
}

/// JNI: returns a `java.util.ArrayList` of `NExperiments.Experiment.TExperiment`,
/// with one element per experiment reported active for `uniqId` at `timestamp`.
///
/// Each element is built as a C++ proto (id, bin id, united parameters), serialized,
/// and re-parsed with the generated Java `parseFrom(byte[])`. On failure, returns null
/// with a Java exception pending.
JNIEXPORT jobject JNICALL Java_ru_yandex_crypta_lib_experiments_Experiments_getActiveExperiments(JNIEnv* env, jclass, jlong handle, jlong uniqId, jlong timestamp) {
    try {
        auto experimentSystem = CastHandle(handle);
        Y_ENSURE(experimentSystem, "Invalid handle");

        auto experiments = experimentSystem->GetActiveExperiments(uniqId, timestamp);

        TJniClass arrayListClass{env, "java/util/ArrayList"};
        // GetMethodID must not be called with a null class.
        Y_ENSURE(arrayListClass.Get(), "Class not found");
        auto arrayListConstructor = env->GetMethodID(arrayListClass.Get(), "<init>", "()V");
        Y_ENSURE(arrayListConstructor, "Constructor not found");

        TJniClass expClass{env, "NExperiments/Experiment$TExperiment"};
        Y_ENSURE(expClass.Get(), "Class not found");

        TJniClass listClass{env, "java/util/List"};
        Y_ENSURE(listClass.Get(), "Class not found");

        auto addToList = env->GetMethodID(listClass.Get(), "add", "(Ljava/lang/Object;)Z");
        Y_ENSURE(addToList, "add not found");

        auto parseFrom = env->GetStaticMethodID(expClass.Get(), "parseFrom", "([B)LNExperiments/Experiment$TExperiment;");
        Y_ENSURE(parseFrom, "parseFrom not found");

        auto result = env->NewObject(arrayListClass.Get(), arrayListConstructor);
        if (env->ExceptionCheck()) {
            return nullptr;
        }
        Y_ENSURE(result, "Failed to create List");

        for (const auto& experiment : experiments) {
            auto parameters = experiment.Parameters.UniteAll<NBSYeti::NExperimentParameters::TExperimentParameters>();
            NExperiments::TExperiment current;
            current.SetExperimentID(experiment.ExperimentId);
            current.SetBinID(experiment.BinId);
            current.MutableParameters()->MergeFrom(*parameters);

            auto serialized = current.SerializeAsString();
            TJniByteArrayHolder array{env, serialized};

            auto parsed = env->CallStaticObjectMethod(expClass.Get(), parseFrom, array.Get());
            if (env->ExceptionCheck()) {
                return nullptr;
            }
            Y_ENSURE(parsed, "Failed to parse");

            auto added = env->CallBooleanMethod(result, addToList, parsed);
            // The list holds its own reference now. Drop ours so that the local
            // reference table does not grow with the number of experiments.
            // (DeleteLocalRef is allowed even if add() left an exception pending.)
            env->DeleteLocalRef(parsed);
            if (env->ExceptionCheck()) {
                return nullptr;
            }
            Y_ENSURE(added, "Failed to add");
        }

        return result;
    } catch (...) {
        Rethrow(env);
    }
    return nullptr;
}

/// JNI: runs exp_stats metrics and returns them as `NExperiments.Experiment.TExperimentStats`.
///
/// `cmd` is split shell-style and parsed into TExpStatsArgs; `ytToken` is then set
/// as the YT token. Each result row becomes a proto row mapping column name to its
/// string representation, plus the optional double attributes Value/Deviation/Delta.
/// On failure, returns null with a Java exception pending.
JNIEXPORT jobject JNICALL Java_ru_yandex_crypta_lib_experiments_Experiments_getExperimentStatsNative(JNIEnv* env, jclass, jstring cmd, jstring ytToken) {
    try {
        InitGlobalLog2Null();

        NExpStats::TExpStatsArgs args{};
        {
            TJniStringWrapper commandLine(env, cmd);
            const auto words = ParseArgs(commandLine.c_str());

            TString parseError{};
            Y_ENSURE(
                NGetoptPb::GetoptPb(
                    words->we_wordc, const_cast<const char**>(words->we_wordv), args, parseError, { .DumpConfig = false }
                ),
                parseError
            );
        }
        {
            TJniStringWrapper token(env, ytToken);
            args.MutableYt()->SetToken(token.c_str());
        }

        TVector<TString> printFields{};
        const TVector<NYT::TNode> rows = NExpStats::RunMetrics(args, &printFields);

        TJniClass statsClass{env, "NExperiments/Experiment$TExperimentStats"};
        Y_ENSURE(statsClass.Get(), "Class not found");

        const auto parseFromMethod = env->GetStaticMethodID(statsClass.Get(), "parseFrom", "([B)LNExperiments/Experiment$TExperimentStats;");
        Y_ENSURE(parseFromMethod, "parseFrom not found");

        NExperiments::TExperimentStats statsProto{};
        for (const NYT::TNode& row : rows) {
            auto* const protoRow = statsProto.AddRows();
            for (const auto& [column, cell] : row.AsMap()) {
                auto& protoValue = (*protoRow->MutableValues())[column];
                protoValue.SetRepr(cell.ConvertTo<TString>());

                // Copy a numeric attribute of the cell only when it is present as a double.
                const auto& attributes = cell.GetAttributes();
                const auto copyIfDouble = [&attributes](const char* name, auto&& assign) {
                    const auto& attribute = attributes[name];
                    if (attribute.IsDouble()) {
                        assign(attribute.AsDouble());
                    }
                };
                copyIfDouble("Value", [&protoValue](double v) { protoValue.SetValue(v); });
                copyIfDouble("Deviation", [&protoValue](double v) { protoValue.SetDeviation(v); });
                copyIfDouble("Delta", [&protoValue](double v) { protoValue.SetDelta(v); });
            }
        }

        TJniByteArrayHolder bytes{env, statsProto.SerializeAsString()};

        jobject statsObject = env->CallStaticObjectMethod(statsClass.Get(), parseFromMethod, bytes.Get());
        // Leave any exception thrown by parseFrom pending for the Java caller.
        if (env->ExceptionCheck()) {
            return nullptr;
        }
        Y_ENSURE(statsObject, "Failed to parse");
        return statsObject;
    } catch (...) {
        Rethrow(env);
    }
    return nullptr;
}
