#include "DrmSessionJNI.hpp"
#include "package.hpp"

namespace twitch {
// Cached JNI handles shared by every DrmSessionJNI instance. They are filled in
// once by DrmSessionJNI::initialize(JNIEnv*) and stay null until then (the
// explicit initializers match the zero-initialization static storage gets anyway).

// Methods on the Java DrmSession class.
jmethodID DrmSessionJNI::s_createSession = nullptr;
jmethodID DrmSessionJNI::s_initialize = nullptr;
jmethodID DrmSessionJNI::s_getSessionId = nullptr;
jmethodID DrmSessionJNI::s_generateKeyRequest = nullptr;
jmethodID DrmSessionJNI::s_generateProvisionRequest = nullptr;
jmethodID DrmSessionJNI::s_updateKeyResponse = nullptr;
jmethodID DrmSessionJNI::s_updateProvisionResponse = nullptr;
jmethodID DrmSessionJNI::s_release = nullptr;

// Fields on the Java DrmSession$OpaqueRequest class.
jfieldID DrmSessionJNI::s_requestUrl = nullptr;
jfieldID DrmSessionJNI::s_requestData = nullptr;

// Global references to the Java classes themselves.
jclass DrmSessionJNI::s_drmSession = nullptr;
jclass DrmSessionJNI::s_opaqueRequest = nullptr;

/**
 * Resolves and caches the Java classes, method IDs and field IDs used by DrmSessionJNI.
 *
 * The two classes are promoted to global references so the cached jclass values stay
 * valid after this call returns.
 *
 * @param env JNI environment of the calling thread.
 */
void DrmSessionJNI::initialize(JNIEnv* env)
{
    s_drmSession = static_cast<jclass>(env->NewGlobalRef(FindPlayerClass(env, "DrmSession")));
    s_opaqueRequest = static_cast<jclass>(env->NewGlobalRef(FindPlayerClass(env, "DrmSession$OpaqueRequest")));

    // Descriptors that mention player classes need the player package prefix.
    const std::string createSignature = "(Ljava/nio/ByteBuffer;J)L" + PlayerPackage + "DrmSession;";
    const std::string keyRequestSignature = "([B)L" + PlayerPackage + "DrmSession$OpaqueRequest;";
    const std::string provisionRequestSignature = "()L" + PlayerPackage + "DrmSession$OpaqueRequest;";

    s_createSession = env->GetStaticMethodID(s_drmSession, "create", createSignature.c_str());
    s_initialize = env->GetMethodID(s_drmSession, "initialize", "()V");
    s_getSessionId = env->GetMethodID(s_drmSession, "getSessionId", "()[B");
    s_generateKeyRequest = env->GetMethodID(s_drmSession, "generateKeyRequest", keyRequestSignature.c_str());
    s_generateProvisionRequest = env->GetMethodID(s_drmSession, "generateProvisionRequest", provisionRequestSignature.c_str());
    s_updateKeyResponse = env->GetMethodID(s_drmSession, "updateKeyResponse", "([B)V");
    s_updateProvisionResponse = env->GetMethodID(s_drmSession, "updateProvisionResponse", "([B)V");
    s_release = env->GetMethodID(s_drmSession, "release", "()V");

    s_requestUrl = env->GetFieldID(s_opaqueRequest, "url", "Ljava/lang/String;");
    s_requestData = env->GetFieldID(s_opaqueRequest, "data", "[B");
}

/**
 * Calls the Java static DrmSession.create(ByteBuffer, long) and returns its result.
 *
 * The DRM system bytes are exposed to Java through a direct ByteBuffer that aliases
 * `system` without copying. NOTE(review): this assumes the Java side does not keep the
 * buffer beyond the lifetime of `system`; confirm.
 *
 * @param env    JNI environment of the calling thread.
 * @param system DRM system identifier bytes passed to Java.
 * @param ptr    Native session pointer handed to Java as a jlong (used by the DrmListener callbacks).
 * @return Local reference returned by DrmSession.create (null if it threw; the exception is not
 *         checked here).
 */
jobject DrmSessionJNI::create(JNIEnv* env, const std::vector<uint8_t>& system, DrmSessionJNI* ptr)
{
    // NewDirectByteBuffer takes a non-const address even though the memory is only wrapped.
    uint8_t* buffer = const_cast<uint8_t*>(system.data());
    jobject byteBuffer = env->NewDirectByteBuffer(buffer, static_cast<jlong>(system.size()));
    jobject session = env->CallStaticObjectMethod(s_drmSession, s_createSession, byteBuffer, reinterpret_cast<jlong>(ptr));

    // Release our local reference to the wrapper; otherwise it accumulates on threads that
    // stay attached to the VM.
    if (byteBuffer != nullptr) {
        env->DeleteLocalRef(byteBuffer);
    }
    return session;
}

/**
 * Creates the Java DrmSession peer for the given DRM system and binds the listener.
 *
 * `this` is handed to Java (via create()) as the native handle that the DrmListener
 * callbacks pass back. NOTE(review): if Java invokes a listener callback synchronously
 * inside DrmSession.create, m_listener may not be bound yet; confirm the member
 * declaration order and the Java behavior.
 */
DrmSessionJNI::DrmSessionJNI(JNIEnv* env, const std::vector<uint8_t>& system, DrmSession::Listener& listener)
    : m_object(env, DrmSessionJNI::create(env, system, this)), m_listener(listener)
{
}

/**
 * Calls release() on the Java peer, attaching the current thread to the VM if needed.
 */
DrmSessionJNI::~DrmSessionJNI()
{
    jni::AttachThread thread(jni::getVM());
    JNIEnv* const jniEnv = thread.getEnv();

    jniEnv->CallVoidMethod(m_object, s_release);
    // Log and clear anything release() threw so no exception is left pending on this thread.
    checkException(jniEnv);
}

void DrmSessionJNI::initialize()
{
    jni::AttachThread attachThread(jni::getVM());
    JNIEnv* env = attachThread.getEnv();

    env->CallVoidMethod(m_object, s_initialize);
    checkException(env);
}

const std::string& DrmSessionJNI::getSessionId() const
{
    if (m_sessionId.empty()) {
        jni::AttachThread attachThread(jni::getVM());
        JNIEnv* env = attachThread.getEnv();

        jbyteArray sessionId = static_cast<jbyteArray>(env->CallObjectMethod(m_object, s_getSessionId));
        jbyte* buffer = env->GetByteArrayElements(sessionId, JNI_FALSE);
        jsize bufferSize = env->GetArrayLength(sessionId);
        m_sessionId = std::string(reinterpret_cast<const char*>(buffer), static_cast<size_t>(bufferSize));
    }
    return m_sessionId;
}

/**
 * Asks the Java peer to build a license (key) request for the given init data.
 *
 * @param init Initialization data passed to DrmSession.generateKeyRequest(byte[]).
 * @return The request converted to C++, or a default-constructed OpaqueRequest if the array
 *         allocation fails, the Java call throws, or the Java call returns null (any exception
 *         is logged and cleared).
 */
DrmSession::OpaqueRequest DrmSessionJNI::generateKeyRequest(const std::vector<uint8_t>& init)
{
    jni::AttachThread attachThread(jni::getVM());
    JNIEnv* env = attachThread.getEnv();

    jsize dataBytes = static_cast<jsize>(init.size());
    jbyteArray dataArray = env->NewByteArray(dataBytes);
    if (dataArray == nullptr) {
        // Allocation failed; an OutOfMemoryError is pending.
        checkException(env);
        return OpaqueRequest();
    }
    env->SetByteArrayRegion(dataArray, 0, dataBytes, reinterpret_cast<const jbyte*>(init.data()));

    jobject requestObject = env->CallObjectMethod(m_object, s_generateKeyRequest, dataArray);
    env->DeleteLocalRef(dataArray);

    // JNI field access is invalid while an exception is pending or on a null object.
    if (env->ExceptionCheck() || requestObject == nullptr) {
        checkException(env);
        return OpaqueRequest();
    }

    OpaqueRequest request = createRequest(env, requestObject);
    env->DeleteLocalRef(requestObject);
    return request;
}

/**
 * Asks the Java peer to build a device provisioning request.
 *
 * @return The request converted to C++, or a default-constructed OpaqueRequest if the Java call
 *         throws or returns null (any exception is logged and cleared).
 */
DrmSession::OpaqueRequest DrmSessionJNI::generateProvisionRequest()
{
    jni::AttachThread attachThread(jni::getVM());
    JNIEnv* env = attachThread.getEnv();

    jobject requestObject = env->CallObjectMethod(m_object, s_generateProvisionRequest);

    // JNI field access is invalid while an exception is pending or on a null object.
    if (env->ExceptionCheck() || requestObject == nullptr) {
        checkException(env);
        return OpaqueRequest();
    }

    OpaqueRequest request = createRequest(env, requestObject);
    env->DeleteLocalRef(requestObject);
    return request;
}

/**
 * Passes a license (key) response to the Java peer.
 *
 * Any Java exception, including an allocation failure for the byte[] copy, is logged and cleared.
 *
 * @param data Response bytes forwarded to DrmSession.updateKeyResponse(byte[]).
 */
void DrmSessionJNI::updateKeyResponse(const std::vector<uint8_t>& data)
{
    jni::AttachThread attachThread(jni::getVM());
    JNIEnv* env = attachThread.getEnv();

    jsize dataBytes = static_cast<jsize>(data.size());
    jbyteArray dataArray = env->NewByteArray(dataBytes);
    if (dataArray == nullptr) {
        // Allocation failed; an OutOfMemoryError is pending.
        checkException(env);
        return;
    }
    env->SetByteArrayRegion(dataArray, 0, dataBytes, reinterpret_cast<const jbyte*>(data.data()));
    env->CallVoidMethod(m_object, s_updateKeyResponse, dataArray);
    checkException(env);

    // Free the local ref now; it would otherwise accumulate on threads that stay attached.
    env->DeleteLocalRef(dataArray);
}

/**
 * Passes a device provisioning response to the Java peer.
 *
 * Any Java exception, including an allocation failure for the byte[] copy, is logged and cleared.
 *
 * @param data Response bytes forwarded to DrmSession.updateProvisionResponse(byte[]).
 */
void DrmSessionJNI::updateProvisionResponse(const std::vector<uint8_t>& data)
{
    jni::AttachThread attachThread(jni::getVM());
    JNIEnv* env = attachThread.getEnv();

    jsize dataBytes = static_cast<jsize>(data.size());
    jbyteArray dataArray = env->NewByteArray(dataBytes);
    if (dataArray == nullptr) {
        // Allocation failed; an OutOfMemoryError is pending.
        checkException(env);
        return;
    }
    env->SetByteArrayRegion(dataArray, 0, dataBytes, reinterpret_cast<const jbyte*>(data.data()));
    env->CallVoidMethod(m_object, s_updateProvisionResponse, dataArray);
    checkException(env);

    // Free the local ref now; it would otherwise accumulate on threads that stay attached.
    env->DeleteLocalRef(dataArray);
}

/**
 * Converts a Java DrmSession$OpaqueRequest into the equivalent C++ OpaqueRequest.
 *
 * Copies the `url` (String) and `data` (byte[]) fields. A null url or null data leaves the
 * corresponding C++ member default-constructed.
 *
 * @param env           JNI environment of the calling thread.
 * @param requestObject Java request object; may be null.
 * @return The converted request. It is default-constructed if `requestObject` is null or a Java
 *         exception is already pending; a pending exception is logged and cleared first.
 */
DrmSession::OpaqueRequest DrmSessionJNI::createRequest(JNIEnv* env, jobject requestObject)
{
    OpaqueRequest request;

    // Field access is not allowed while an exception is pending or on a null object.
    if (env->ExceptionCheck()) {
        env->ExceptionDescribe();
        env->ExceptionClear();
        return request;
    }
    if (requestObject == nullptr) {
        return request;
    }

    jstring url = static_cast<jstring>(env->GetObjectField(requestObject, s_requestUrl));
    if (url != nullptr) {
        jni::StringRef urlString(env, url);
        request.url = urlString;
    }

    jbyteArray dataArray = static_cast<jbyteArray>(env->GetObjectField(requestObject, s_requestData));
    if (dataArray != nullptr) {
        const jsize dataSize = env->GetArrayLength(dataArray);
        // Skip empty arrays: indexing element 0 of an empty container is undefined behavior.
        if (dataSize > 0) {
            request.data.resize(static_cast<size_t>(dataSize));
            env->GetByteArrayRegion(dataArray, 0, dataSize, reinterpret_cast<jbyte*>(&request.data[0]));
        }
        env->DeleteLocalRef(dataArray);
    }
    return request;
}

/**
 * Logs and clears any Java exception pending on `env`; does nothing if none is pending.
 *
 * Used after calls into the Java DrmSession so that a Java failure does not leave the thread
 * in a state where further JNI calls are invalid.
 *
 * @param env JNI environment of the calling thread.
 */
void DrmSessionJNI::checkException(JNIEnv* env)
{
    if (env->ExceptionCheck()) {
        // ExceptionDescribe prints the throwable and its stack trace. ExceptionClear guarantees
        // nothing stays pending. The unused ExceptionOccurred() call was removed because it
        // created a local reference that was never deleted.
        env->ExceptionDescribe();
        env->ExceptionClear();
    }
}

extern "C" {
/**
 * Java DrmListener.onKeyExpired(long): forwards to the native session's listener.
 *
 * `ptr` is expected to be the DrmSessionJNI* that create() handed to the Java DrmSession;
 * a zero value is ignored.
 */
JNIEXPORT void JNICALL
Java_tv_twitch_android_player_DrmListener_onKeyExpired(JNIEnv* /* env */, jobject /* instance */, jlong ptr)
{
    if (DrmSessionJNI* session = reinterpret_cast<DrmSessionJNI*>(ptr)) {
        session->getListener().onKeyExpired(*session);
    }
}

/**
 * Java DrmListener.onKeyRequested(long): forwards to the native session's listener.
 *
 * `ptr` is expected to be the DrmSessionJNI* that create() handed to the Java DrmSession;
 * a zero value is ignored.
 */
JNIEXPORT void JNICALL
Java_tv_twitch_android_player_DrmListener_onKeyRequested(JNIEnv* /* env */, jobject /* instance */, jlong ptr)
{
    if (DrmSessionJNI* session = reinterpret_cast<DrmSessionJNI*>(ptr)) {
        session->getListener().onKeyRequested(*session);
    }
}

/**
 * Java DrmListener.onProvisionRequested(long): forwards to the listener's
 * onProvisioningRequested.
 *
 * `ptr` is expected to be the DrmSessionJNI* that create() handed to the Java DrmSession;
 * a zero value is ignored.
 */
JNIEXPORT void JNICALL
Java_tv_twitch_android_player_DrmListener_onProvisionRequested(JNIEnv* /* env */, jobject /* instance */, jlong ptr)
{
    if (DrmSessionJNI* session = reinterpret_cast<DrmSessionJNI*>(ptr)) {
        session->getListener().onProvisioningRequested(*session);
    }
}

/**
 * Java DrmListener.onError(long, String): reports the failure to the native session's
 * listener as an Error with source ErrorSource::Decode and result MediaResult::Error.
 *
 * `ptr` is expected to be the DrmSessionJNI* that create() handed to the Java DrmSession;
 * a zero value is ignored.
 */
JNIEXPORT void JNICALL
Java_tv_twitch_android_player_DrmListener_onError(JNIEnv* env, jobject /* instance */, jlong ptr, jstring message)
{
    DrmSessionJNI* const session = reinterpret_cast<DrmSessionJNI*>(ptr);
    if (session == nullptr) {
        return;
    }

    jni::StringRef errorMessage(env, message);
    session->getListener().onError(*session, Error(ErrorSource::Decode, MediaResult::Error, errorMessage.get()));
}
}
}
