#include "MediaDecoderJNI.hpp"
#include "ATrace.hpp"
#include "package.hpp"
#include "playercore/SecureSampleBuffer.hpp"
#include <jni.h>
#include <string>

namespace twitch {
// Cached JNI handles shared by all MediaDecoderJNI instances. Every one of them is
// assigned in MediaDecoderJNI::initialize(); until then they hold their zero-initialized value.
// Instance methods looked up on the Java "MediaDecoder" player class.
jmethodID MediaDecoderJNI::s_configure;
jmethodID MediaDecoderJNI::s_decode;
jmethodID MediaDecoderJNI::s_flush;
jmethodID MediaDecoderJNI::s_release;
jmethodID MediaDecoderJNI::s_reset;
jmethodID MediaDecoderJNI::s_hasInput;
jmethodID MediaDecoderJNI::s_hasOutput;
// "handleDecoderException", looked up on the Java "Platform" player class (not on MediaDecoder).
jmethodID MediaDecoderJNI::s_handleException;
// "getOutputTime", looked up on the Java "MediaDecoder" player class.
jmethodID MediaDecoderJNI::s_getOutputTime;
// Static factory methods looked up on the Java "Formats" player class.
jmethodID MediaDecoderJNI::s_createAudioFormat;
jmethodID MediaDecoderJNI::s_createVideoFormat;
jmethodID MediaDecoderJNI::s_createTextFormat;
// The two "MediaSample" constructors: clear samples and samples carrying crypto data.
jmethodID MediaDecoderJNI::s_createSample;
jmethodID MediaDecoderJNI::s_createSecureSample;
// Global class references for the "Formats" and "MediaSample" player classes and java.nio.ByteBuffer.
jclass MediaDecoderJNI::s_mediaFormats;
jclass MediaDecoderJNI::s_mediaSample;
jclass MediaDecoderJNI::s_byteBuffer;

// Resolves and caches every class reference and method ID used by this file.
// The other functions here read these static handles, so this has to run before any of them.
//   env - JNI environment used for the lookups
// Lookup results are stored as-is; no failure checking is done here.
void MediaDecoderJNI::initialize(JNIEnv* env)
{
    // signatures that are either assembled at runtime or shared by more than one lookup
    const auto decodeSignature = "(L" + PlayerPackage + "MediaSample;)V";
    const char* const codecFormatSignature
        = "(Ljava/lang/String;[Ljava/nio/ByteBuffer;II[Ljava/nio/ByteBuffer;)Landroid/media/MediaFormat;";

    // instance methods of the Java decoder
    jclass decoderClass = FindPlayerClass(env, "MediaDecoder");
    s_configure = env->GetMethodID(decoderClass, "configure", "(Landroid/media/MediaFormat;)V");
    s_decode = env->GetMethodID(decoderClass, "decode", decodeSignature.c_str());
    s_release = env->GetMethodID(decoderClass, "release", "()V");
    s_reset = env->GetMethodID(decoderClass, "reset", "()V");
    s_flush = env->GetMethodID(decoderClass, "flush", "()V");
    s_hasInput = env->GetMethodID(decoderClass, "hasInput", "()Z");
    s_hasOutput = env->GetMethodID(decoderClass, "hasOutput", "()Z");
    s_getOutputTime = env->GetMethodID(decoderClass, "getOutputTime", "()J");

    // exception callback lives on the platform object rather than on the decoder
    jclass handlerClass = FindPlayerClass(env, "Platform");
    s_handleException = env->GetMethodID(handlerClass, "handleDecoderException", "(Ljava/lang/Throwable;)V");

    // classes that are used after this function returns are pinned with global references
    s_mediaFormats = static_cast<jclass>(env->NewGlobalRef(FindPlayerClass(env, "Formats")));
    s_mediaSample = static_cast<jclass>(env->NewGlobalRef(FindPlayerClass(env, "MediaSample")));
    s_byteBuffer = static_cast<jclass>(env->NewGlobalRef(env->FindClass("java/nio/ByteBuffer")));

    // static MediaFormat factories; the video and audio variants share one signature
    s_createVideoFormat = env->GetStaticMethodID(s_mediaFormats, "createVideoFormat", codecFormatSignature);
    s_createAudioFormat = env->GetStaticMethodID(s_mediaFormats, "createAudioFormat", codecFormatSignature);
    s_createTextFormat = env->GetStaticMethodID(s_mediaFormats, "createTextFormat",
        "(Ljava/lang/String;)Landroid/media/MediaFormat;");

    // MediaSample constructors: plain and with crypto metadata
    s_createSample = env->GetMethodID(s_mediaSample, "<init>", "(Ljava/nio/ByteBuffer;IJJJZZ)V");
    s_createSecureSample = env->GetMethodID(s_mediaSample, "<init>", "(Ljava/nio/ByteBuffer;IJJJZZ[B[B[I[III)V");
}

// Builds the Java android.media.MediaFormat that corresponds to a native MediaFormat by calling
// the static factory methods cached in initialize().
//   env    - JNI environment used for all calls
//   format - native format description (type, dimensions / audio parameters, codec and
//            protection data)
// Returns a local reference to the created MediaFormat, or nullptr when the media type is not
// one of the handled types or when a Java exception was raised (the exception is described and
// cleared here, it is not left pending).
// The direct ByteBuffers created here wrap memory owned by `format` without copying it.
// NOTE(review): whether the Java factories copy those bytes before returning is not visible
// from this file - confirm before relying on the buffers outliving `format`.
jobject MediaDecoderJNI::createMediaFormat(JNIEnv* env, const MediaFormat& format)
{
    ATrace("MediaDecoderJNI::createMediaFormat");
    const auto& mediaType = format.getType();
    // codec initialization data, stays null for types that carry none
    jobjectArray buffers = nullptr;
    jobject jformat = nullptr;
    jstring type = env->NewStringUTF(mediaType.name.c_str());

    // create protection data if any
    jobjectArray protectionData = nullptr;
    if (!format.getProtectionData().empty()) {
        protectionData = env->NewObjectArray(static_cast<jsize>(format.getProtectionData().size()), s_byteBuffer, nullptr);

        jint index = 0;
        for (const auto& entry : format.getProtectionData()) {
            // NewDirectByteBuffer takes a non-const pointer, hence the const_cast
            uint8_t* bufferPtr = const_cast<uint8_t*>(entry.data());
            jobject bb = env->NewDirectByteBuffer(bufferPtr, entry.size());
            env->SetObjectArrayElement(protectionData, index++, bb);
            // the array holds its own reference, release ours right away
            env->DeleteLocalRef(bb);
        }
    }

    if (mediaType.matches(MediaType::Video_AVC)) {
        // up to two entries: SPS followed by PPS
        buffers = env->NewObjectArray(2, s_byteBuffer, nullptr);

        jint w = format.getInt(MediaFormat::Video_Width);
        jint h = format.getInt(MediaFormat::Video_Height);

        jint index = 0;

        if (format.hasCodecData(MediaFormat::Video_AVC_SPS)) {
            auto& buffer = format.getCodecData(MediaFormat::Video_AVC_SPS);
            // an empty parameter set is stored as a null array element
            jobject bb = nullptr;

            if (buffer.size() > 0) {
                uint8_t* bufferPtr = const_cast<uint8_t*>(buffer.data());
                bb = env->NewDirectByteBuffer(bufferPtr, buffer.size());
            }

            env->SetObjectArrayElement(buffers, index++, bb);
            env->DeleteLocalRef(bb);
        }

        if (format.hasCodecData(MediaFormat::Video_AVC_PPS)) {
            auto& buffer = format.getCodecData(MediaFormat::Video_AVC_PPS);
            jobject bb = nullptr;

            if (buffer.size() > 0) {
                uint8_t* bufferPtr = const_cast<uint8_t*>(buffer.data());
                bb = env->NewDirectByteBuffer(bufferPtr, buffer.size());
            }

            env->SetObjectArrayElement(buffers, index++, bb);
            env->DeleteLocalRef(bb);
        }

        jformat = env->CallStaticObjectMethod(s_mediaFormats, s_createVideoFormat, type, buffers, w, h, protectionData);
        env->DeleteLocalRef(buffers);

    } else if (mediaType.matches(MediaType::Video_VP9)) {
        // no codec data is passed for VP9, `buffers` is still null here
        jint w = format.getInt(MediaFormat::Video_Width);
        jint h = format.getInt(MediaFormat::Video_Height);
        jformat = env->CallStaticObjectMethod(s_mediaFormats, s_createVideoFormat, type, buffers, w, h, protectionData);

    } else if (mediaType.matches(MediaType::Audio_AAC)) {
        buffers = env->NewObjectArray(1, s_byteBuffer, nullptr);
        // jint sampleSize = format.getInt(MediaFormat::Audio_SampleSize);

        // sample rate and channel count are optional and default to 0
        jint sampleRate = 0;

        if (format.hasInt(MediaFormat::Audio_SampleRate)) {
            sampleRate = format.getInt(MediaFormat::Audio_SampleRate);
        }

        jint channels = 0;

        if (format.hasInt(MediaFormat::Audio_ChannelCount)) {
            channels = format.getInt(MediaFormat::Audio_ChannelCount);
        }

        if (format.hasCodecData(MediaFormat::Audio_AAC_ESDS)) {
            auto& buffer = format.getCodecData(MediaFormat::Audio_AAC_ESDS);
            uint8_t* buffer_ptr = const_cast<uint8_t*>(buffer.data());
            jobject bb = env->NewDirectByteBuffer(buffer_ptr, buffer.size());
            env->SetObjectArrayElement(buffers, 0, bb);
            env->DeleteLocalRef(bb);
        }

        jformat = env->CallStaticObjectMethod(s_mediaFormats, s_createAudioFormat, type, buffers, sampleRate, channels, protectionData);
        env->DeleteLocalRef(buffers);
    } else if (mediaType.matches(MediaType::Text_Json)
        || mediaType.matches(MediaType::Text_ID3)
        || mediaType.matches(MediaType::Text_VTT)) {
        jformat = env->CallStaticObjectMethod(s_mediaFormats, s_createTextFormat, type);
    }

    env->DeleteLocalRef(protectionData);
    env->DeleteLocalRef(type);

    // ExceptionCheck() rather than ExceptionOccurred(): the latter returns a new local reference
    // to the throwable, which would be leaked when it is only used as a boolean
    if (env->ExceptionCheck()) {
        env->ExceptionDescribe();
        env->ExceptionClear();
        return nullptr;
    }

    return jformat;
}

// Creates the Java MediaSample object for a native sample buffer.
//   env    - JNI environment used for all calls
//   sample - native sample; when its type is EncryptedMemoryBuffer it is treated as a
//            SecureSampleBuffer and the crypto metadata is passed along as well
// Returns a local reference to the new MediaSample (the caller releases it). Exceptions raised
// by the Java constructor are not handled here and are left pending for the caller.
// The payload is handed over as a direct ByteBuffer that wraps sample.buffer without copying.
jobject MediaDecoderJNI::createMediaSample(JNIEnv* env, const MediaSampleBuffer& sample)
{
    ATrace("MediaDecoderJNI::createMediaSample");

    const std::vector<uint8_t>& buffer = sample.buffer;
    // data() is well defined for an empty vector, indexing element 0 is not
    uint8_t* data = const_cast<uint8_t*>(buffer.data());
    jlong decodeTimeUs = sample.decodeTime.microseconds().count();
    jlong presentationTimeUs = sample.presentationTime.microseconds().count();
    jlong durationUs = sample.duration.microseconds().count();
    jint size = static_cast<jint>(buffer.size());
    jobject byteBuffer = env->NewDirectByteBuffer(data, size);
    jobject jsample = nullptr;

    if (sample.type == MediaSample::Type::EncryptedMemoryBuffer) {
        const SecureSampleBuffer& secureSample = static_cast<const SecureSampleBuffer&>(sample);

        // initialization vector
        jsize initializationSize = static_cast<jsize>(secureSample.initializationVector.size());
        jbyteArray initializationVector = env->NewByteArray(initializationSize);
        env->SetByteArrayRegion(initializationVector, 0,
            initializationSize, reinterpret_cast<const jbyte*>(secureSample.initializationVector.data()));

        // key id
        jsize keyIdBytes = static_cast<jsize>(secureSample.keyId.size());
        jbyteArray keyId = env->NewByteArray(keyIdBytes);
        env->SetByteArrayRegion(keyId, 0, keyIdBytes, reinterpret_cast<const jbyte*>(secureSample.keyId.data()));

        // subsample layout: each range is (clear byte count, protected byte count)
        jint subsampleCount = static_cast<jint>(secureSample.subsampleRange.size());
        jintArray clearData = env->NewIntArray(subsampleCount);
        jintArray protectedData = env->NewIntArray(subsampleCount);

        if (subsampleCount > 0) {
            // gather the values natively and copy each array across with a single JNI call
            // instead of crossing the JNI boundary twice per subsample
            std::vector<jint> clearBytes(subsampleCount);
            std::vector<jint> protectedBytes(subsampleCount);
            for (int i = 0; i < subsampleCount; i++) {
                clearBytes[i] = static_cast<jint>(secureSample.subsampleRange[i].first);
                protectedBytes[i] = static_cast<jint>(secureSample.subsampleRange[i].second);
            }
            env->SetIntArrayRegion(clearData, 0, subsampleCount, clearBytes.data());
            env->SetIntArrayRegion(protectedData, 0, subsampleCount, protectedBytes.data());
        }

        // see constants from
        // https://developer.android.com/reference/android/media/MediaCodec.html#CRYPTO_MODE_AES_CBC
        jint cipherMode = 0; // CRYPTO_MODE_UNENCRYPTED
        switch (secureSample.mode) {
        case SecureSampleBuffer::CipherMode::CBC:
            cipherMode = 2; // CRYPTO_MODE_AES_CBC;
            break;
        case SecureSampleBuffer::CipherMode::CTR:
            cipherMode = 1; // CRYPTO_MODE_AES_CTR
            break;
        }

        // create secure sample
        jsample = env->NewObject(s_mediaSample, s_createSecureSample,
            byteBuffer, size, decodeTimeUs, presentationTimeUs, durationUs,
            secureSample.isSyncSample,
            secureSample.isDecodeOnly,
            // crypto data
            initializationVector, keyId, clearData, protectedData, subsampleCount, cipherMode);

        // clear local refs
        env->DeleteLocalRef(initializationVector);
        env->DeleteLocalRef(keyId);
        env->DeleteLocalRef(clearData);
        env->DeleteLocalRef(protectedData);
    } else {
        jsample = env->NewObject(s_mediaSample, s_createSample,
            byteBuffer, size, decodeTimeUs, presentationTimeUs, durationUs,
            sample.isSyncSample, sample.isDecodeOnly);
    }

    // since this is called frequently, manually delete the local reference
    env->DeleteLocalRef(byteBuffer);
    return jsample;
}

// Binds this wrapper to a Java decoder object.
//   env      - JNI environment, stored in m_env and used for every JNI call this object makes
//   platform - Java object that checkException() invokes handleDecoderException on
//   decoder  - Java decoder object that the member functions forward to
// m_platform and m_object are constructed from (env, object).
// NOTE(review): presumably these members hold their own JNI references to the two objects -
// confirm in the header.
// NOTE(review): a JNIEnv is only valid on the thread it belongs to, so this object presumably
// has to be used on the thread that supplied `env` - confirm against the callers.
MediaDecoderJNI::MediaDecoderJNI(JNIEnv* env, jobject platform, jobject decoder)
    : m_env(env)
    , m_platform(env, platform)
    , m_object(env, decoder)
{
}

// Calls release() on the Java decoder. A Java exception raised by release() is cleared and
// reported through checkException() so that it is not left pending on the thread.
MediaDecoderJNI::~MediaDecoderJNI()
{
    // same guard as the other member functions: without a decoder object or a JNI environment
    // there is nothing to release and the call below would not be valid
    if (!m_object || !m_env) {
        return;
    }

    m_env->CallVoidMethod(m_object, s_release);
    checkException(); // clear any release exception
}

// Converts a pending Java exception into a MediaResult.
//   result - value to return when no exception is pending
// When an exception is pending it is cleared, passed to handleDecoderException on the platform
// object and MediaResult::Error is returned; otherwise `result` is returned unchanged.
// Expects m_env to be non-null.
MediaResult MediaDecoderJNI::checkException(MediaResult result)
{
    if (m_env->ExceptionCheck()) {
        jthrowable throwable = m_env->ExceptionOccurred();
        // the exception has to be cleared before another Java method can be called
        m_env->ExceptionClear();
        m_env->CallVoidMethod(m_platform.get(), s_handleException, throwable);
        // ExceptionOccurred() returns a new local reference; release it so that repeated
        // failures on a long-lived native thread do not fill the local reference table
        m_env->DeleteLocalRef(throwable);
        return MediaResult::Error;
    }

    return result;
}

// Configures the Java decoder for the given input format.
//   input  - format of the samples that will be decoded
//   output - receives the type, values and codec data of `input`
// Returns ErrorInvalidState without a decoder object or JNI environment, Error when no Java
// MediaFormat could be created for `input`, otherwise the outcome of checkException() after
// the Java configure() call.
MediaResult MediaDecoderJNI::configure(const MediaFormat& input, MediaFormat& output)
{
    if (!m_object || !m_env) {
        return MediaResult::ErrorInvalidState;
    }

    // decoding and rendering both happen on the java side, so the output format simply
    // mirrors the input format
    output.setType(input.getType());
    output.setValues(input.getValues());
    output.setCodecData(input.getCodecData());

    jobject javaFormat = createMediaFormat(m_env, input);
    if (javaFormat == nullptr) {
        return MediaResult::Error;
    }

    m_env->CallVoidMethod(m_object, s_configure, javaFormat);
    // release the local reference before looking at the call's outcome
    m_env->DeleteLocalRef(javaFormat);
    return checkException();
}

// Passes one sample to the Java decoder.
//   input - sample to decode
// Returns ErrorInvalidState without a decoder object or JNI environment, ErrorTimeout when the
// Java decoder reports that it has no input available, Error when a Java exception was raised
// or no sample object could be created, otherwise the outcome of the Java decode() call.
MediaResult MediaDecoderJNI::decode(const MediaSampleBuffer& input)
{
    ATrace("MediaDecoderJNI::decode");
    if (!m_object || !m_env) {
        return MediaResult::ErrorInvalidState;
    }

    // The value returned by a JNI call is meaningless while an exception is pending, so the
    // exception check has to come before the value is acted on. Otherwise a throwing
    // hasInput() would be reported as a timeout and the exception would stay pending.
    const jboolean hasInput = m_env->CallBooleanMethod(m_object, s_hasInput);
    MediaResult result = checkException();
    if (result != MediaResult::Ok) {
        return result;
    }

    if (!hasInput) {
        return MediaResult::ErrorTimeout;
    }

    jobject sample = createMediaSample(m_env, input);
    result = checkException();
    if (result != MediaResult::Ok) {
        // do not leak the sample reference if one was created despite the exception
        if (sample != nullptr) {
            m_env->DeleteLocalRef(sample);
        }
        return result;
    }

    if (sample == nullptr) {
        return MediaResult::Error;
    }

    m_env->CallVoidMethod(m_object, s_decode, sample);
    result = checkException();
    m_env->DeleteLocalRef(sample);
    return result;
}

// Asks the Java decoder whether output is available.
//   hasOutput - receives the value returned by the Java hasOutput() call
// Returns ErrorInvalidState (leaving `hasOutput` untouched) without a decoder object or JNI
// environment, otherwise the outcome of checkException() after the call.
MediaResult MediaDecoderJNI::hasOutput(bool& hasOutput)
{
    if (!m_object || !m_env) {
        return MediaResult::ErrorInvalidState;
    }

    const jboolean available = m_env->CallBooleanMethod(m_object, s_hasOutput);
    hasOutput = (available != JNI_FALSE);
    return checkException();
}

// Fetches the time of the current output from the Java decoder.
//   output - sample whose decodeTime is set to the value of getOutputTime(), interpreted as
//            microseconds; must not be null
// Only the time is filled in; the decoded content stays on the java side, which also does the
// rendering.
// NOTE(review): the value is described as a presentation time but it is stored in decodeTime -
// confirm that this is what the consumers of the sample expect.
// Returns ErrorInvalidState without a decoder object, JNI environment or output sample,
// otherwise the outcome of checkException(); `output` is only modified on Ok.
MediaResult MediaDecoderJNI::getOutput(std::shared_ptr<MediaSample>& output)
{
    if (!m_object || !m_env || !output) {
        return MediaResult::ErrorInvalidState;
    }

    const jlong outputTimeUs = m_env->CallLongMethod(m_object, s_getOutputTime);

    MediaResult status = checkException();
    if (status != MediaResult::Ok) {
        return status;
    }

    output->decodeTime = MediaTime(outputTimeUs, std::micro::den);
    return status;
}

// Forwards flush() to the Java decoder.
// Returns ErrorInvalidState without a decoder object or JNI environment, otherwise the outcome
// of checkException() after the call.
MediaResult MediaDecoderJNI::flush()
{
    const bool unusable = !m_object || !m_env;
    if (unusable) {
        return MediaResult::ErrorInvalidState;
    }

    m_env->CallVoidMethod(m_object, s_flush);

    // a Java exception raised by flush() is reported as an error result
    return checkException();
}

// Forwards reset() to the Java decoder.
// Returns ErrorInvalidState without a decoder object or JNI environment, otherwise the outcome
// of checkException() after the call.
MediaResult MediaDecoderJNI::reset()
{
    const bool unusable = !m_object || !m_env;
    if (unusable) {
        return MediaResult::ErrorInvalidState;
    }

    m_env->CallVoidMethod(m_object, s_reset);

    // a Java exception raised by reset() is reported as an error result
    return checkException();
}
}
