#include <util/generic/yexception.h>
#include <util/string/cast.h>
#include <util/string/printf.h>
#include <util/stream/file.h>
#include <library/cpp/string_utils/base64/base64.h>
#include <quality/ab_testing/usersplit_lib/carrier.h>
#include <quality/ab_testing/usersplit_lib/adapters/protobuf/adapter.h>
#include <quality/ab_testing/usersplit_lib/adapters/protobuf/proto/input.pb.h>
#include <quality/ab_testing/usersplit_lib/adapters/protobuf/proto/output.pb.h>

#include <exception>

#include "carrierholder.h"
#include "common_jni_objects.h"

#include "ru_yandex_common_abt_Abt.h"

// JNI objects cached by JNI_OnLoad. The class and string global references
// are deleted again in JNI_OnUnload; method IDs need no cleanup.

// java.lang.IllegalStateException
static jclass JC_IllegalStateException;

// java.lang.RuntimeException
static jclass JC_RuntimeException;

// java.lang.IllegalArgumentException
static jclass iaeClass;
// java.lang.Error
static jclass eeeClass;
// java.lang.String
static jclass stringClass;
// Global string reference created in JNI_OnLoad; used as the initial element
// of the String[] arrays returned by the classify* entry points.
static jstring emptyString;

// java.nio.ByteBuffer / java.nio.Buffer
static jclass JC_ByteBuffer;
static jclass JC_Buffer;
// NOTE(review): both method IDs are resolved on java.nio.Buffer. Despite its
// name, JMID_ByteBuffer_position is looked up as the getter "limit()I"
// (not "position()I"); JMID_ByteBuffer_limit is the setter limit(int).
static jmethodID JMID_ByteBuffer_position;
static jmethodID JMID_ByteBuffer_limit;

// JNI version passed to GetEnv and reported from JNI_OnLoad; never modified.
#ifdef __ANDROID__
static const jint JNI_VERSION = JNI_VERSION_1_6;
#else
static const jint JNI_VERSION = JNI_VERSION_1_8;
#endif

/**
 * Library load hook: obtains a JNIEnv for JNI_VERSION and caches the global
 * class references, the empty-string global reference and the Buffer method IDs
 * used by the native methods in this file.
 *
 * @return JNI_VERSION on success; JNI_ERR if GetEnv fails or a lookup leaves a
 *         Java exception pending (the exception is then reported by the VM).
 */
jint JNI_OnLoad(JavaVM *jvm, void *) {
    JNIEnv *env = nullptr;
    if (jvm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION) != JNI_OK) {
        return JNI_ERR;
    }

    iaeClass = getClassReference(env, "java/lang/IllegalArgumentException");
    eeeClass = getClassReference(env, "java/lang/Error");
    stringClass = getClassReference(env, "java/lang/String");

    // NewStringUTF expects a valid modified-UTF-8 C string; nullptr is not a
    // permitted argument, so build a real "" and keep only the global reference.
    jstring localEmpty = env->NewStringUTF("");
    emptyString = static_cast<jstring>(env->NewGlobalRef(localEmpty));
    env->DeleteLocalRef(localEmpty);

    JC_IllegalStateException = getClassReference(env, "java/lang/IllegalStateException");
    JC_RuntimeException = getClassReference(env, "java/lang/RuntimeException");

    JC_Buffer = getClassReference(env, "java/nio/Buffer");
    JC_ByteBuffer = getClassReference(env, "java/nio/ByteBuffer");
    // Do not call GetMethodID with an exception pending from a failed lookup.
    if (env->ExceptionCheck()) {
        return JNI_ERR;
    }

    // NOTE(review): despite the variable name this resolves the getter
    // Buffer.limit(), not Buffer.position(); getExperimentsBuf relies on it to
    // read buffer sizes, so the lookup is intentionally left as is.
    JMID_ByteBuffer_position = env->GetMethodID(JC_Buffer, "limit", "()I");
    JMID_ByteBuffer_limit = env->GetMethodID(JC_Buffer, "limit", "(I)Ljava/nio/Buffer;");
    if (env->ExceptionCheck()) {
        return JNI_ERR;
    }

    additional_onLoad(env);

    return JNI_VERSION;
}

/**
 * Library unload hook: deletes the global references created in JNI_OnLoad.
 * Method IDs are not references and need no cleanup.
 */
void JNI_OnUnload(JavaVM *vm, void *) {
    JNIEnv *env = nullptr;
    // GetEnv can fail (e.g. JNI_EDETACHED / JNI_EVERSION); without this check
    // an uninitialized `env` would be dereferenced below.
    if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION) != JNI_OK || env == nullptr) {
        return;
    }

    const jobject globals[] = {
        iaeClass,
        eeeClass,
        stringClass,
        emptyString,
        JC_IllegalStateException,
        JC_RuntimeException,
        JC_Buffer,
        JC_ByteBuffer,
    };
    for (jobject ref : globals) {
        env->DeleteGlobalRef(ref);
    }
}

/**
 * JNI: Abt.create(String path) -> long
 *
 * Constructs a TCarrierHolder from `path` and returns its address as an opaque
 * handle, to be released with Abt.delete. On failure a Java exception is
 * raised and 0 is returned.
 */
jlong Java_ru_yandex_common_abt_Abt_create(JNIEnv *env, jclass, jstring jpath) {
    if (jpath == nullptr) {
        env->ThrowNew(iaeClass, "path must not be null");
        return 0;
    }
    try {
        const char *chars = env->GetStringUTFChars(jpath, nullptr);
        if (chars == nullptr) {
            return 0; // GetStringUTFChars failed; the JVM has raised an exception.
        }
        // Copy and release the UTF buffer before constructing the holder, so a
        // throwing constructor cannot leak it.
        TString path(chars);
        env->ReleaseStringUTFChars(jpath, chars);
        TCarrierHolder *obj = new TCarrierHolder(path);
        // Cast to jlong, not `long`: `long` is 32-bit on LLP64 targets and would truncate.
        return reinterpret_cast<jlong>(obj);
    } catch (const std::exception &e) {
        // Covers yexception and e.g. std::bad_alloc: no C++ exception may cross JNI.
        env->ThrowNew(JC_RuntimeException, e.what());
        return 0;
    }
}

/**
 * JNI: Abt.delete(long handle)
 *
 * Destroys the TCarrierHolder behind a handle returned by Abt.create.
 * A zero handle is harmless: deleting a null pointer is a no-op.
 */
void Java_ru_yandex_common_abt_Abt_delete(JNIEnv *, jclass, jlong jptr) {
    delete reinterpret_cast<TCarrierHolder *>(jptr);
}

/**
 * JNI: Abt.reload(long handle, String path) -> boolean
 *
 * Opens the file at `path` and passes it to TCarrierHolder::Reload on the
 * holder behind `jptr`. Returns JNI_TRUE on success; otherwise raises a Java
 * exception and returns JNI_FALSE.
 */
jboolean Java_ru_yandex_common_abt_Abt_reload(JNIEnv *env, jclass, jlong jptr, jstring jpath) {
    TCarrierHolder *obj = reinterpret_cast<TCarrierHolder *>(jptr);
    if (obj == nullptr) {
        env->ThrowNew(JC_IllegalStateException, "Abt handle is not initialized");
        return JNI_FALSE;
    }
    if (jpath == nullptr) {
        env->ThrowNew(iaeClass, "path must not be null");
        return JNI_FALSE;
    }
    try {
        const char *chars = env->GetStringUTFChars(jpath, nullptr);
        if (chars == nullptr) {
            return JNI_FALSE; // GetStringUTFChars failed; the JVM has raised an exception.
        }
        const TString path(chars);
        env->ReleaseStringUTFChars(jpath, chars);
        TFileInput input(path);
        obj->Reload(&input);
        return JNI_TRUE;
    } catch (const std::exception &e) {
        // Covers yexception and e.g. std::bad_alloc: no C++ exception may cross JNI.
        env->ThrowNew(JC_RuntimeException, e.what());
        return JNI_FALSE;
    }
}

/**
 * JNI: Abt.reloads(long handle, String config) -> boolean
 *
 * Like Abt.reload, but the configuration is given in memory as `config`
 * instead of being read from a file. Returns JNI_TRUE on success; otherwise
 * raises a Java exception and returns JNI_FALSE.
 */
jboolean Java_ru_yandex_common_abt_Abt_reloads(JNIEnv *env, jclass, jlong jptr, jstring jconfig) {
    TCarrierHolder *obj = reinterpret_cast<TCarrierHolder *>(jptr);
    if (obj == nullptr) {
        env->ThrowNew(JC_IllegalStateException, "Abt handle is not initialized");
        return JNI_FALSE;
    }
    if (jconfig == nullptr) {
        env->ThrowNew(iaeClass, "config must not be null");
        return JNI_FALSE;
    }
    try {
        const char *chars = env->GetStringUTFChars(jconfig, nullptr);
        if (chars == nullptr) {
            return JNI_FALSE; // GetStringUTFChars failed; the JVM has raised an exception.
        }
        const TString config(chars);
        env->ReleaseStringUTFChars(jconfig, chars);
        TStringInput input(config);
        obj->Reload(&input);
        return JNI_TRUE;
    } catch (const std::exception &e) {
        // Covers yexception and e.g. std::bad_alloc: no C++ exception may cross JNI.
        env->ThrowNew(JC_RuntimeException, e.what());
        return JNI_FALSE;
    }
}

/**
 * JNI: Abt.version(long handle) -> String
 *
 * Returns TCarrierHolder::Version() of the holder behind `jptr` as a Java
 * string. On failure raises a Java exception and returns null.
 */
jstring Java_ru_yandex_common_abt_Abt_version(JNIEnv *env, jclass, jlong jptr) {
    TCarrierHolder *obj = reinterpret_cast<TCarrierHolder *>(jptr);
    if (obj == nullptr) {
        env->ThrowNew(JC_IllegalStateException, "Abt handle is not initialized");
        return nullptr;
    }
    try {
        const TString result = obj->Version();
        return env->NewStringUTF(result.data());
    } catch (const std::exception &e) {
        // Covers yexception and e.g. std::bad_alloc: no C++ exception may cross JNI.
        env->ThrowNew(JC_RuntimeException, e.what());
        return nullptr;
    }
}

// GET_STR_ARG(name) declares a local `TString name` initialised from the
// jstring parameter `jname` (the `j` prefix is token-pasted). A null jstring,
// or a GetStringUTFChars call that returns nullptr, leaves the string empty.
// Requires `env` in scope; use it as a statement: `GET_STR_ARG(yuid);`.
#define GET_STR_ARG(arg) \
    TString arg; \
    do { \
        if (Y_LIKELY(j##arg)) { \
            const char* c##arg = env->GetStringUTFChars(j##arg, nullptr); \
            if (Y_LIKELY(c##arg)) { \
                arg = TString(c##arg); \
                env->ReleaseStringUTFChars(j##arg, c##arg); \
            } \
        } \
    } while (false)

/**
 * Shared implementation of the classify* JNI entry points: converts the Java
 * arguments into NUserSplit::TRequestParams and runs TCarrierHolder::Classify,
 * writing the result into `info`.
 *
 * Arguments marked Y_UNUSED are accepted for API compatibility only. Null
 * jstring arguments become empty strings (see GET_STR_ARG).
 *
 * Error handling: for a null handle, negative timestamp or negative region a
 * Java exception is raised and the function returns without classifying; if
 * classification throws, a RuntimeException is raised and `info` may be
 * partially filled. Callers should check env->ExceptionCheck() afterwards.
 */
void getInfo(JNIEnv *env, jlong jptr, jstring jfuid, jstring jyuid, jstring jicookie,
             jstring jtext, jstring jreqid, jlong jtimestamp, jstring jip,
             jint jregion, jstring jservice, jstring jdevice, jstring jtestid,
             jstring jconfig, jstring juuid, jstring jplatform, jstring jbrowser,
             jstring jhalting, jstring jloginHash, jboolean jstaff, jstring jdeviceId, jstring jpuid,
             NUserSplit::TExperimentsInfo<> &info) {
    Y_UNUSED(jfuid);
    Y_UNUSED(jtext);
    Y_UNUSED(jreqid);
    Y_UNUSED(jip);
    Y_UNUSED(jtestid);
    Y_UNUSED(jhalting);
    Y_UNUSED(jloginHash);

    TCarrierHolder *obj = reinterpret_cast<TCarrierHolder *>(jptr);
    if (Y_UNLIKELY(obj == nullptr)) {
        env->ThrowNew(JC_IllegalStateException, "Abt handle is not initialized");
        return;
    }
    // ThrowNew only marks an exception as pending; native execution continues,
    // so return right away instead of classifying with invalid values.
    if (Y_UNLIKELY(jtimestamp < 0)) {
        env->ThrowNew(iaeClass, Sprintf("Invalid timestamp %lld", static_cast<long long>(jtimestamp)).data());
        return;
    }
    if (Y_UNLIKELY(jregion < 0)) {
        env->ThrowNew(iaeClass, Sprintf("Invalid regionId %d", jregion).data());
        return;
    }

    try {
        GET_STR_ARG(yuid);
        GET_STR_ARG(icookie);
        GET_STR_ARG(service);
        GET_STR_ARG(device);
        GET_STR_ARG(config);
        GET_STR_ARG(uuid);
        GET_STR_ARG(platform);
        GET_STR_ARG(browser);
        GET_STR_ARG(deviceId);
        GET_STR_ARG(puid);

        NUserSplit::TRequestParams params;
        params.Yuid = yuid;
        params.ICookie = icookie;
        params.Timestamp = jtimestamp;
        params.RegionId = static_cast<ui32>(jregion); // non-negative, checked above
        params.Service = service;
        params.DeviceType = device;
        params.Uuid = uuid;
        params.MobilePlatform = platform;
        params.BrowserName = browser;
        params.IsStaff = jstaff;         // see https://st.yandex-team.ru/USEREXP-6541
        params.Puid = puid;
        params.DeviceId = deviceId;
        obj->Classify(params, info);
    } catch (const std::exception &e) {
        // Covers yexception and e.g. std::bad_alloc: no C++ exception may cross JNI.
        env->ThrowNew(JC_RuntimeException, e.what());
    }
}

/**
 * JNI: Abt.classify(...) -> String[]
 *
 * Classifies the request via getInfo and returns the ids of the matched
 * experiments in info.ExperimentsList order. This overload has no deviceId or
 * puid arguments, so nullptr is passed for both (read as empty strings).
 * Returns null with a Java exception pending on failure.
 */
jobjectArray
Java_ru_yandex_common_abt_Abt_classify(JNIEnv *env, jobject, jlong jptr, jstring jfuid, jstring jyuid, jstring jicookie,
                                       jstring jtext, jstring jreqid, jlong jtimestamp, jstring jip,
                                       jint jregion, jstring jservice, jstring jdevice, jstring jtestid,
                                       jstring jconfig, jstring juuid, jstring jplatform, jstring jbrowser,
                                       jstring jhalting, jstring jlogin, jboolean jstaff) {
    try {
        NUserSplit::TExperimentsInfo<> info;
        getInfo(env, jptr, jfuid, jyuid, jicookie, jtext, jreqid, jtimestamp, jip, jregion, jservice, jdevice, jtestid,
                jconfig, juuid, jplatform, jbrowser, jhalting, jlogin, jstaff, nullptr, nullptr, info);
        // getInfo reports invalid input via a pending Java exception; further
        // array-building JNI calls are not allowed in that state.
        if (env->ExceptionCheck()) {
            return nullptr;
        }

        const size_t count = info.ExperimentsList.size();
        jobjectArray result = env->NewObjectArray(static_cast<jsize>(count), stringClass, emptyString);
        if (result == nullptr) {
            return nullptr; // allocation failed; the JVM has raised an exception
        }
        for (size_t i = 0; i < count; ++i) {
            jstring id = env->NewStringUTF(info.ExperimentsList[i].Id.data());
            if (id == nullptr) {
                return nullptr;
            }
            env->SetObjectArrayElement(result, static_cast<jsize>(i), id);
            // Release per-element local references; the local reference table is bounded.
            env->DeleteLocalRef(id);
        }
        return result;
    } catch (const std::exception &e) {
        // No C++ exception may cross the JNI boundary.
        env->ThrowNew(JC_RuntimeException, e.what());
        return nullptr;
    }
}

/**
 * JNI: Abt.classifyAndGetParams(...) -> String[]
 *
 * Classifies the request via getInfo and returns a flat array of 3 * N
 * strings for the N matched experiments:
 *   [0, N)    experiment ids,
 *   [N, 2N)   experiment params (info.ExperimentsParams),
 *   [2N, 3N)  buckets rendered with ToString<NUserSplit::TBucket>.
 * Returns null with a Java exception pending on failure.
 */
jobjectArray
Java_ru_yandex_common_abt_Abt_classifyAndGetParams(JNIEnv *env, jobject, jlong jptr, jstring jfuid, jstring jyuid,
                                                   jstring jicookie,
                                                   jstring jtext, jstring jreqid, jlong jtimestamp, jstring jip,
                                                   jint jregion, jstring jservice, jstring jdevice, jstring jtestid,
                                                   jstring jconfig, jstring juuid, jstring jplatform, jstring jbrowser,
                                                   jstring jhalting, jstring jlogin, jboolean jstaff, jstring jdeviceId,
                                                   jstring jpuid) {
    try {
        NUserSplit::TExperimentsInfo<> info;
        getInfo(env, jptr, jfuid, jyuid, jicookie, jtext, jreqid, jtimestamp, jip, jregion, jservice, jdevice, jtestid,
                jconfig, juuid, jplatform, jbrowser, jhalting, jlogin, jstaff, jdeviceId, jpuid, info);
        // getInfo reports invalid input via a pending Java exception; further
        // array-building JNI calls are not allowed in that state.
        if (env->ExceptionCheck()) {
            return nullptr;
        }

        const size_t len = info.ExperimentsList.size();
        jobjectArray result = env->NewObjectArray(static_cast<jsize>(len * 3), stringClass, emptyString);
        if (result == nullptr) {
            return nullptr; // allocation failed; the JVM has raised an exception
        }

        // Stores one Java string at `index`, releasing its local reference.
        // Returns false if NewStringUTF failed (an exception is then pending).
        auto store = [env, result](size_t index, const char *utf) {
            jstring value = env->NewStringUTF(utf);
            if (value == nullptr) {
                return false;
            }
            env->SetObjectArrayElement(result, static_cast<jsize>(index), value);
            env->DeleteLocalRef(value);
            return true;
        };

        for (size_t i = 0; i < len; ++i) {
            // NOTE(review): assumes ExperimentsParams is index-aligned with
            // ExperimentsList (at least `len` entries) -- confirm.
            const TString bucket = ToString<NUserSplit::TBucket>(info.ExperimentsList[i].Bucket);
            if (!store(i, info.ExperimentsList[i].Id.data())
                || !store(len + i, info.ExperimentsParams[i].data())
                || !store(len * 2 + i, bucket.data())) {
                return nullptr;
            }
        }
        return result;
    } catch (const std::exception &e) {
        // No C++ exception may cross the JNI boundary.
        env->ThrowNew(JC_RuntimeException, e.what());
        return nullptr;
    }
}

/**
 * Shared implementation of the protobuf JNI entry points.
 *
 * Parses a TUsersplitRequestParams message from `inBufferSize` bytes at
 * `inBufferPtr`, classifies it with the holder behind `jptr`, and serializes a
 * TUsersplitResult into `outBufferPtr` (capacity `outBufferSize` bytes).
 *
 * @param haltingPointsSetPtr optional; when non-null it is stored in
 *        TRequestParams::HaltingPointsSet before the proto fields are transferred.
 * @return the number of bytes written, or 0 with a Java exception pending on
 *         failure. If the output buffer is too small, an IllegalStateException
 *         whose message is the required size is raised.
 */
int classifyWithProtobuf(JNIEnv * env, jlong jptr, const void * inBufferPtr, int inBufferSize, void * outBufferPtr, int outBufferSize, THashSet<TString>* haltingPointsSetPtr) {
    TCarrierHolder *obj = reinterpret_cast<TCarrierHolder *>(jptr);
    if (obj == nullptr) {
        env->ThrowNew(JC_IllegalStateException, "Abt handle is not initialized");
        return 0;
    }
    // Reject negative sizes and missing memory before handing them to protobuf.
    const bool badInput = inBufferSize < 0 || (inBufferSize > 0 && inBufferPtr == nullptr);
    const bool badOutput = outBufferSize < 0 || (outBufferSize > 0 && outBufferPtr == nullptr);
    if (badInput || badOutput) {
        env->ThrowNew(iaeClass, "Invalid protobuf buffer");
        return 0;
    }

    try {
        TUsersplitRequestParams input;
        if (!input.ParseFromArray(inBufferPtr, inBufferSize)) {
            env->ThrowNew(JC_RuntimeException, "Error while parsing protobuf");
            return 0;
        }

        NUserSplit::TRequestParams requestParams;
        if (haltingPointsSetPtr != nullptr) {
            requestParams.HaltingPointsSet = haltingPointsSetPtr;
        }

        NUserSplitAdapters::TransferProtoToRequestParams(input, requestParams);

        NUserSplit::TExperimentsInfo<> info;
        obj->Classify(requestParams, info);

        TUsersplitResult response;
        // NOTE(review): the conversion below is disabled, so `info` is not used
        // here and `response` is always serialized empty -- confirm this is intended.
        //NUserSplitAdapters::ExperimentsInfoToProto(info, response);

        const int responseSize = response.ByteSize();
        // A message that exactly fills the buffer fits; only a larger one does not.
        if (responseSize > outBufferSize) {
            env->ThrowNew(JC_IllegalStateException, ::ToString(responseSize).c_str());
            return 0;
        }
        if (!response.SerializeToArray(outBufferPtr, outBufferSize)) {
            env->ThrowNew(JC_RuntimeException, "Error while serializing protobuf");
            return 0;
        }
        return responseSize;
    } catch (const std::exception &e) {
        // Covers yexception and e.g. std::bad_alloc: no C++ exception may cross JNI.
        env->ThrowNew(JC_RuntimeException, e.what());
        return 0;
    }
}

/**
 * JNI: Abt.getExperiments, raw-address variant of the protobuf API.
 *
 * The input/output buffers and the optional halting-points set arrive as native
 * addresses packed into jlong values and are forwarded to classifyWithProtobuf.
 * The number of bytes written is not reported back; failures surface as
 * pending Java exceptions.
 */
JNIEXPORT void JNICALL Java_ru_yandex_common_abt_Abt_getExperiments
        (JNIEnv * env, jobject, jlong jptr, jlong inBufferPtr, jint inBufferSize, jlong outBufferPtr, jint outBufferSize, jlong haltingPointsSetPtr) {
    const void *input = reinterpret_cast<const void *>(inBufferPtr);
    void *output = reinterpret_cast<void *>(outBufferPtr);
    THashSet<TString> *haltingPoints = reinterpret_cast<THashSet<TString> *>(haltingPointsSetPtr);
    (void) classifyWithProtobuf(env, jptr, input, inBufferSize, output, outBufferSize, haltingPoints);
}

/**
 * JNI: Abt.getExperimentsBuf, direct-ByteBuffer variant of the protobuf API.
 *
 * Both buffers must be non-null direct ByteBuffers. The input data size and
 * the output capacity are read through JMID_ByteBuffer_position, which
 * JNI_OnLoad resolves to Buffer.limit(). On success the output buffer's limit
 * is set to the number of bytes written; on failure a Java exception is left
 * pending and the limit is not touched.
 */
JNIEXPORT void JNICALL Java_ru_yandex_common_abt_Abt_getExperimentsBuf
        (JNIEnv * env, jobject, jlong jptr, jobject inBuffer, jobject outBuffer, jlong haltingPointsSetPtr) {
    if (inBuffer == nullptr || outBuffer == nullptr) {
        env->ThrowNew(iaeClass, "Buffers must not be null");
        return;
    }

    // GetDirectBufferAddress returns nullptr for buffers that are not direct.
    void *inBufferPtr = env->GetDirectBufferAddress(inBuffer);
    void *outBufferPtr = env->GetDirectBufferAddress(outBuffer);
    if (inBufferPtr == nullptr || outBufferPtr == nullptr) {
        env->ThrowNew(iaeClass, "Buffers must be direct ByteBuffers");
        return;
    }

    // Input: data size; output: available size (both read via Buffer.limit()).
    const jint inBufferSize = env->CallIntMethod(inBuffer, JMID_ByteBuffer_position);
    const jint outBufferSize = env->CallIntMethod(outBuffer, JMID_ByteBuffer_position);
    if (env->ExceptionCheck()) {
        return;
    }

    const int size = classifyWithProtobuf(env, jptr, inBufferPtr, inBufferSize, outBufferPtr,
                                          outBufferSize, reinterpret_cast<THashSet<TString> *>(haltingPointsSetPtr));
    // JNI forbids calling Java methods while an exception is pending.
    if (env->ExceptionCheck()) {
        return;
    }
    // Expose exactly the serialized bytes to the Java side.
    env->CallObjectMethod(outBuffer, JMID_ByteBuffer_limit, size);
}

// GET_STR_ARG is a helper for the JNI wrappers above only; undefine it so the
// macro does not remain visible to anything that follows in this translation unit.
#undef GET_STR_ARG

