#include "http_jni.hpp"
#include "PlatformJNI.hpp"
#include "package.hpp"
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

namespace twitch {
// Slash-separated JNI path prefix for classes in the player's "http" package; used to
// build method signatures that mention http-package types (see initialize()).
// NOTE(review): PlayerPackage comes from package.hpp — presumably something like
// "tv/twitch/android/player/"; confirm it ends with a '/'.
const std::string httpPackage(PlayerPackage + "http/");

// Cached JNI class references (stored as global references) and method IDs.
// All of them are filled in by HttpClientJNI::initialize() and read by the client,
// request and response wrappers in this file.
jclass HttpClientJNI::s_class;
jclass HttpClientJNI::s_requestClass;
jclass HttpClientJNI::s_responseClass;
jclass HttpClientJNI::s_streamReadClass;
jclass HttpClientJNI::s_callbackClass;
jmethodID HttpClientJNI::s_classInit;
jmethodID HttpClientJNI::s_clientExecute;
jmethodID HttpClientJNI::s_clientRelease;
jmethodID HttpClientJNI::s_requestInit;
jmethodID HttpClientJNI::s_requestSetContent;
jmethodID HttpClientJNI::s_requestSetTimeout;
jmethodID HttpClientJNI::s_requestSetHeader;
jmethodID HttpClientJNI::s_requestCancel;
jmethodID HttpClientJNI::s_responseGetHeader;
jmethodID HttpClientJNI::s_responseGetStatus;
jmethodID HttpClientJNI::s_responseRead;
jmethodID HttpClientJNI::s_streamReadInit;
jmethodID HttpClientJNI::s_callbackInit;

/// Resolves and caches every Java class and method ID the HTTP bridge uses.
/// The rest of this file reads these statics, so this must run first.
/// @param env JNI environment of the calling thread.
void HttpClientJNI::initialize(JNIEnv* env)
{
    // Look a class up in the player package and pin it with a global reference so the
    // cached handle remains valid after this call returns.
    const auto pinClass = [env](const char* relativeName) {
        return static_cast<jclass>(env->NewGlobalRef(FindPlayerClass(env, relativeName)));
    };

    s_class = pinClass("http/HttpUrlConnectionClient");
    s_requestClass = pinClass("http/Request");
    s_responseClass = pinClass("http/Response");
    s_streamReadClass = pinClass("http/NativeReadCallback");
    s_callbackClass = pinClass("http/NativeResponseCallback");

    // These two signatures reference http-package types, so they need the full path.
    const std::string executeSignature =
        "(L" + httpPackage + "Request;L" + httpPackage + "ResponseCallback;)V";
    const std::string readContentSignature = "(L" + httpPackage + "ReadCallback;)V";

    // HttpUrlConnectionClient
    s_classInit = env->GetMethodID(s_class, "<init>", "()V");
    s_clientExecute = env->GetMethodID(s_class, "execute", executeSignature.c_str());
    s_clientRelease = env->GetMethodID(s_class, "release", "()V");

    // Request
    s_requestCancel = env->GetMethodID(s_requestClass, "cancel", "()V");
    s_requestInit = env->GetMethodID(s_requestClass, "<init>", "(Ljava/lang/String;Ljava/lang/String;)V");
    s_requestSetContent = env->GetMethodID(s_requestClass, "setContent", "(Ljava/nio/ByteBuffer;)V");
    s_requestSetTimeout = env->GetMethodID(s_requestClass, "setTimeout", "(I)V");
    s_requestSetHeader = env->GetMethodID(s_requestClass, "setHeader", "(Ljava/lang/String;Ljava/lang/String;)V");

    // Response
    s_responseGetHeader = env->GetMethodID(s_responseClass, "getHeader", "(Ljava/lang/String;)Ljava/lang/String;");
    s_responseGetStatus = env->GetMethodID(s_responseClass, "getStatus", "()I");
    s_responseRead = env->GetMethodID(s_responseClass, "readContent", readContentSignature.c_str());

    // Native callback shims: both are constructed from the native object pointer.
    s_streamReadInit = env->GetMethodID(s_streamReadClass, "<init>", "(J)V");
    s_callbackInit = env->GetMethodID(s_callbackClass, "<init>", "(J)V");
}

// Creates the Java HttpUrlConnectionClient instance (s_class / s_classInit) that this
// object wraps and keeps in m_object. Requires HttpClientJNI::initialize() to have run,
// since that is where s_class and s_classInit are set.
// NOTE(review): m_env caches the constructing thread's JNIEnv*; a JNIEnv is only valid
// on its own thread, so confirm m_env is never used from another thread.
// Members are initialized in their declaration order in the header, not the order
// written below.
HttpClientJNI::HttpClientJNI(JNIEnv* env)
    : m_env(env)
    , m_attachThread(jni::getVM())
    , m_object(env, env->NewObject(s_class, s_classInit))
{
}

/// Releases the Java HttpUrlConnectionClient by calling its release() method.
/// Does nothing if no JNIEnv can be obtained for the current thread.
HttpClientJNI::~HttpClientJNI()
{
    jni::AttachThread attachThread(jni::getVM());
    JNIEnv* env = attachThread.getEnv();

    if (!env) {
        return;
    }

    env->CallVoidMethod(m_object.get(), s_clientRelease);

    // A destructor has no way to report failure. Log and clear any exception thrown by
    // release() so later JNI calls on this thread are not made with one pending.
    if (env->ExceptionCheck()) {
        env->ExceptionDescribe();
        env->ExceptionClear();
    }
}

class StreamHttpResponse;

class StreamHttpRequest : public HttpRequest {
public:
    StreamHttpRequest(const std::string& url, HttpMethod method)
        : HttpRequest(url, method)
    {
        jni::AttachThread attachThread(jni::getVM());
        JNIEnv* env = attachThread.getEnv();
        if (env) {
            jni::LocalRef<jstring> urlString(env, env->NewStringUTF(url.c_str()));
            jni::LocalRef<jstring> methodString(env, env->NewStringUTF(getMethodString().c_str()));
            jni::LocalRef<jobject> requestObject(env,
                env->NewObject(HttpClientJNI::s_requestClass,
                    HttpClientJNI::s_requestInit, urlString.get(), methodString.get()));
            m_request = env->NewGlobalRef(requestObject);
        }
    }

    ~StreamHttpRequest()
    {
        cancel();
    }

    void cancel() override
    {
        std::lock_guard<std::recursive_mutex> lock(m_mutex);
        if (m_request) {
            jni::AttachThread attachThread(jni::getVM());
            JNIEnv* env = attachThread.getEnv();

            if (env) {
                env->CallVoidMethod(m_request, HttpClientJNI::s_requestCancel);
                env->DeleteGlobalRef(m_request);
            }
            m_request = nullptr;
        }
    }

    void setHeader(const std::string& key, const std::string& value) override
    {
        jni::AttachThread attachThread(jni::getVM());
        JNIEnv* env = attachThread.getEnv();
        std::lock_guard<std::recursive_mutex> lock(m_mutex);
        if (env && m_request) {
            jni::LocalRef<jstring> keyString(env, env->NewStringUTF(key.c_str()));
            jni::LocalRef<jstring> valueString(env, env->NewStringUTF(value.c_str()));
            env->CallVoidMethod(m_request, HttpClientJNI::s_requestSetHeader, keyString.get(), valueString.get());
        }
    }

    void setContent(const std::vector<uint8_t>& content) override
    {
        jni::AttachThread attachThread(jni::getVM());
        JNIEnv* env = attachThread.getEnv();
        std::lock_guard<std::recursive_mutex> lock(m_mutex);

        const uint8_t* data = content.data();
        size_t size = content.size();
        jni::LocalRef<jobject> buffer(env, env->NewDirectByteBuffer(const_cast<uint8_t*>(data), size));
        env->CallVoidMethod(m_request, HttpClientJNI::s_requestSetContent, buffer.get());
    }

    void send(jobject client, HttpClient::ResponseHandler onResponse, HttpClient::ErrorHandler onError)
    {
        std::lock_guard<std::recursive_mutex> lock(m_mutex);
        if (!m_request) {
            if (onError) {
                onError(-1);
            }
            return;
        }

        m_responseHandler = onResponse;
        m_errorHandler = onError;

        jni::AttachThread attachThread(jni::getVM());
        JNIEnv* env = attachThread.getEnv();

        env->CallVoidMethod(m_request, HttpClientJNI::s_requestSetTimeout, static_cast<jint>(getTimeout().count()));

        jmethodID constructor = HttpClientJNI::s_callbackInit;
        jlong ptr = reinterpret_cast<jlong>(this);
        jni::LocalRef<jobject> callback(env, env->NewObject(HttpClientJNI::s_callbackClass, constructor, ptr));
        env->CallVoidMethod(client, HttpClientJNI::s_clientExecute, m_request, callback.get());

        if (env->ExceptionCheck()) {
            env->ExceptionDescribe();
            env->ExceptionClear();

            if (onError) {
                onError(-1);
            }
        }
    }

    jobject m_request;
    HttpClient::ResponseHandler m_responseHandler;
    HttpClient::ErrorHandler m_errorHandler;
    std::recursive_mutex m_mutex;
};

class StreamHttpResponse : public HttpResponse {
public:
    StreamHttpResponse(JNIEnv* env, int status, jobject responseObject)
        : HttpResponse(status)
        , m_response(env, responseObject)
    {
    }

    ~StreamHttpResponse() = default;

    void read(ContentHandler onBuffer, ErrorHandler onError) override
    {
        m_contentHandler = onBuffer;
        m_errorHandler = onError;
        jmethodID constructor = HttpClientJNI::s_streamReadInit;
        jlong ptr = reinterpret_cast<jlong>(this);
        jni::AttachThread attachThread(jni::getVM());
        JNIEnv* env = attachThread.getEnv();
        jobject callback = env->NewObject(HttpClientJNI::s_streamReadClass, constructor, ptr);

        jmethodID responseRead = HttpClientJNI::s_responseRead;
        env->CallVoidMethod(m_response, responseRead, callback);
        env->DeleteLocalRef(callback);
    }

    std::string getHeader(const std::string& key) const override
    {
        jni::AttachThread attachThread(jni::getVM());
        JNIEnv* env = attachThread.getEnv();
        if (env) {
            jni::LocalRef<jstring> keyString(env, env->NewStringUTF(key.c_str()));
            jni::LocalRef<jobject> result(env, env->CallObjectMethod(m_response, HttpClientJNI::s_responseGetHeader, keyString.get()));
            if (result) {
                jni::StringRef valueString(env, static_cast<jstring>(result.get()));
                return valueString.get();
            }
        }
        return "";
    }

    jni::GlobalRef<jobject> m_response;
    ContentHandler m_contentHandler;
    ErrorHandler m_errorHandler;
};

std::shared_ptr<HttpRequest> HttpClientJNI::createRequest(const std::string& url, HttpMethod method)
{
    return std::make_shared<StreamHttpRequest>(url, method);
}

void HttpClientJNI::send(std::shared_ptr<HttpRequest> request, ResponseHandler onResponse, ErrorHandler onError)
{
    std::shared_ptr<StreamHttpRequest> streamRequest = std::static_pointer_cast<StreamHttpRequest>(request);
    streamRequest->send(m_object, onResponse, onError);
}

extern "C" {
// Called from Java NativeResponseCallback when response headers are available.
// @param ptr      the StreamHttpRequest* passed to the callback's constructor in send().
// @param response the Java Response; it is wrapped in a StreamHttpResponse and handed
//                 to the request's response handler.
JNIEXPORT void JNICALL
Java_tv_twitch_android_player_http_NativeResponseCallback_onResponse(JNIEnv* env,
    jobject instance,
    jlong ptr,
    jobject response)
{
    StreamHttpRequest* request = reinterpret_cast<StreamHttpRequest*>(ptr);

    if (!request || !request->m_responseHandler) {
        return;
    }

    auto reportError = [request] {
        if (request->m_errorHandler) {
            request->m_errorHandler(-1);
        }
    };

    // Calling getStatus on a null object would crash, so report it as an error instead.
    if (!response) {
        reportError();
        return;
    }

    jint status = env->CallIntMethod(response, HttpClientJNI::s_responseGetStatus);
    // Creating the response's global reference with an exception pending is illegal JNI,
    // so a failing getStatus() is surfaced through the error handler instead.
    if (env->ExceptionCheck()) {
        env->ExceptionDescribe();
        env->ExceptionClear();
        reportError();
        return;
    }

    auto httpResponse = std::make_shared<StreamHttpResponse>(env, status, response);
    request->m_responseHandler(httpResponse);
}

// Called from Java NativeResponseCallback when the request fails. The Java throwable is
// not inspected; every failure is forwarded as the generic code -1.
// @param ptr the StreamHttpRequest* passed to the callback's constructor in send().
JNIEXPORT void JNICALL
Java_tv_twitch_android_player_http_NativeResponseCallback_onError(JNIEnv* env,
    jobject instance,
    jlong ptr,
    jthrowable exception)
{
    auto* owner = reinterpret_cast<StreamHttpRequest*>(ptr);
    if (owner == nullptr || !owner->m_errorHandler) {
        return;
    }
    owner->m_errorHandler(-1);
}

// Called from Java NativeReadCallback with the next chunk of the response body.
// @param ptr         the StreamHttpResponse* passed to the callback's constructor in read().
// @param buffer      Java ByteBuffer holding the chunk; its native address is read via
//                    GetDirectBufferAddress, so it must be a direct buffer.
// @param size        number of bytes in the chunk.
// @param endOfStream forwarded to the content handler unchanged.
JNIEXPORT void JNICALL
Java_tv_twitch_android_player_http_NativeReadCallback_onBuffer(JNIEnv* env,
    jobject instance,
    jlong ptr,
    jobject buffer,
    jint size,
    jboolean endOfStream)
{
    StreamHttpResponse* response = reinterpret_cast<StreamHttpResponse*>(ptr);

    if (!response || !response->m_contentHandler) {
        return;
    }

    // A negative size (e.g. an InputStream-style -1 at end of stream — NOTE(review):
    // confirm what the Java side sends) would otherwise wrap to a huge size_t.
    const size_t length = size > 0 ? static_cast<size_t>(size) : 0;
    uint8_t* data = nullptr;

    if (length > 0) {
        if (buffer) {
            data = static_cast<uint8_t*>(env->GetDirectBufferAddress(buffer));
        }
        // GetDirectBufferCapacity returns -1 for a non-direct buffer.
        const jlong capacity = buffer ? env->GetDirectBufferCapacity(buffer) : -1;
        // There is no native memory to hand out (null or non-direct buffer), or the
        // claimed size exceeds the buffer. Report an error rather than passing a null
        // pointer or an out-of-bounds length to the handler.
        if (!data || capacity < static_cast<jlong>(length)) {
            if (response->m_errorHandler) {
                response->m_errorHandler(-1);
            }
            return;
        }
    }

    response->m_contentHandler(data, length, endOfStream);
}

// Called from Java NativeReadCallback when reading the body fails. The Java throwable
// is not inspected; every failure is forwarded as the generic code -1.
// @param ptr the StreamHttpResponse* passed to the callback's constructor in read().
JNIEXPORT void JNICALL
Java_tv_twitch_android_player_http_NativeReadCallback_onError(JNIEnv* env,
    jobject instance,
    jlong ptr,
    jthrowable exception)
{
    auto* owner = reinterpret_cast<StreamHttpResponse*>(ptr);
    if (owner == nullptr || !owner->m_errorHandler) {
        return;
    }
    owner->m_errorHandler(-1);
}

// Called from Java NativeReadCallback to get the response's read timeout.
// Returns getReadTimeout().count() narrowed to jint, or 0 when there is no native response.
// @param ptr the StreamHttpResponse* passed to the callback's constructor in read().
JNIEXPORT jint JNICALL
Java_tv_twitch_android_player_http_NativeReadCallback_getTimeout(JNIEnv* env, jobject instance, jlong ptr)
{
    auto* owner = reinterpret_cast<StreamHttpResponse*>(ptr);
    return owner != nullptr ? static_cast<jint>(owner->getReadTimeout().count()) : jint{ 0 };
}
}
}
