#include "ru_yandex_solomon_ts_codec_BitStreamNative.h"
#include "heap_bit_buf.h"

#include <solomon/libs/cpp/ts_codec/bit_reader.h>
#include <solomon/libs/cpp/ts_codec/bit_writer.h>

#include <util/generic/array_ref.h>
#include <util/generic/buffer.h>
#include <util/generic/vector.h>
#include <util/string/builder.h>

#include <limits>
#include <type_traits>

using namespace NSolomon;
using namespace NTs;

namespace {

void ThrowRuntimeException(JNIEnv* jenv, const TString& message) {
    jenv->ThrowNew(jenv->FindClass("java/lang/RuntimeException"), message.c_str());
}

template <typename T>
auto ToJavaArray(JNIEnv* jenv, TArrayRef<T> values) {
    if constexpr (std::is_same_v<T, jboolean>) {
        jbooleanArray result =jenv->NewBooleanArray(values.size());
        jenv->SetBooleanArrayRegion(result, 0, values.size(), values.data());
        return result;
    } else if constexpr (std::is_same_v<T, jbyte>) {
        jbyteArray result =jenv->NewByteArray(values.size());
        jenv->SetByteArrayRegion(result, 0, values.size(), values.data());
        return result;
    } else if constexpr (std::is_same_v<T, jint>) {
        jintArray result =jenv->NewIntArray(values.size());
        jenv->SetIntArrayRegion(result, 0, values.size(), values.data());
        return result;
    } else if constexpr (std::is_same_v<T, jlong>) {
        jlongArray result =jenv->NewLongArray(values.size());
        jenv->SetLongArrayRegion(result, 0, values.size(), values.data());
        return result;
    } else if constexpr (std::is_same_v<T, jdouble>) {
        jdoubleArray result =jenv->NewDoubleArray(values.size());
        jenv->SetDoubleArrayRegion(result, 0, values.size(), values.data());
        return result;
    } else {
        static_assert (TDependentFalse<T>, "unsupported type");
    }
}

template <typename TJavaArray, typename TConsumer>
void ForEach(JNIEnv* jenv, TJavaArray array, TConsumer consumer) {
    auto consumeElements = [=](auto* elements) {
        for (jsize i = 0, length = jenv->GetArrayLength(array); i < length; ++i) {
            consumer(elements[i]);
        }
    };

    if constexpr (std::is_same_v<TJavaArray, jbooleanArray>) {
        jboolean* elements = jenv->GetBooleanArrayElements(array, nullptr);
        consumeElements(elements);
        jenv->ReleaseBooleanArrayElements(array, elements, JNI_ABORT);
    } else if constexpr (std::is_same_v<TJavaArray, jbyteArray>) {
        jbyte* elements = jenv->GetByteArrayElements(array, nullptr);
        consumeElements(elements);
        jenv->ReleaseByteArrayElements(array, elements, JNI_ABORT);
    } else if constexpr (std::is_same_v<TJavaArray, jintArray>) {
        jint* elements = jenv->GetIntArrayElements(array, nullptr);
        consumeElements(elements);
        jenv->ReleaseIntArrayElements(array, elements, JNI_ABORT);
    } else if constexpr (std::is_same_v<TJavaArray, jlongArray>) {
        jlong* elements = jenv->GetLongArrayElements(array, nullptr);
        consumeElements(elements);
        jenv->ReleaseLongArrayElements(array, elements, JNI_ABORT);
    } else if constexpr (std::is_same_v<TJavaArray, jdoubleArray>) {
        jdouble* elements = jenv->GetDoubleArrayElements(array, nullptr);
        consumeElements(elements);
        jenv->ReleaseDoubleArrayElements(array, elements, JNI_ABORT);
    } else {
        static_assert (TDependentFalse<TJavaArray>, "unsupported type");
    }
}

void FillOffset(TBitWriter* w, jint offset) {
    for (jint i = 0; i < offset; ++i) {
        w->WriteBit(i % 2 == 0);
    }
}

} // namespace

/**
 * JNI: reads single bits from `bitBuf`, starting at bit position `offset`, until the reader
 * reports no bits left, and returns them as a Java boolean[].
 */
jbooleanArray Java_ru_yandex_solomon_ts_1codec_BitStreamNative_readBits(JNIEnv* jenv, jclass, jobject bitBuf, jint offset) {
    auto source = NJava::FromHeapBitBuffer(jenv, bitBuf);
    TBitReader reader{source};
    reader.SetPos(static_cast<size_t>(offset));

    TVector<jboolean> decoded;
    while (reader.Left() > 0) {
        decoded.emplace_back(static_cast<jboolean>(reader.ReadBit()));
    }

    return ToJavaArray(jenv, TArrayRef<jboolean>(decoded.data(), decoded.size()));
}

/**
 * JNI: reads values with TBitReader::ReadInt8 from `bitBuf`, starting at bit position
 * `offset`, until no bits are left, and returns them as a Java byte[].
 */
jbyteArray Java_ru_yandex_solomon_ts_1codec_BitStreamNative_readInt8(JNIEnv* jenv, jclass, jobject bitBuf, jint offset) {
    auto source = NJava::FromHeapBitBuffer(jenv, bitBuf);
    TBitReader reader{source};
    reader.SetPos(static_cast<size_t>(offset));

    TVector<jbyte> decoded;
    while (reader.Left() > 0) {
        decoded.emplace_back(static_cast<jbyte>(reader.ReadInt8()));
    }

    return ToJavaArray(jenv, TArrayRef<jbyte>(decoded.data(), decoded.size()));
}

/**
 * JNI: reads values with TBitReader::ReadInt32 from `bitBuf`, starting at bit position
 * `offset`, until no bits are left, and returns them as a Java int[].
 */
jintArray Java_ru_yandex_solomon_ts_1codec_BitStreamNative_readInt32(JNIEnv* jenv, jclass, jobject bitBuf, jint offset) {
    auto source = NJava::FromHeapBitBuffer(jenv, bitBuf);
    TBitReader reader{source};
    reader.SetPos(static_cast<size_t>(offset));

    TVector<jint> decoded;
    while (reader.Left() > 0) {
        decoded.emplace_back(static_cast<jint>(reader.ReadInt32()));
    }

    return ToJavaArray(jenv, TArrayRef<jint>(decoded.data(), decoded.size()));
}

/**
 * JNI: reads values with TBitReader::ReadInt64 from `bitBuf`, starting at bit position
 * `offset`, until no bits are left, and returns them as a Java long[].
 */
jlongArray Java_ru_yandex_solomon_ts_1codec_BitStreamNative_readInt64(JNIEnv* jenv, jclass, jobject bitBuf, jint offset) {
    auto source = NJava::FromHeapBitBuffer(jenv, bitBuf);
    TBitReader reader{source};
    reader.SetPos(static_cast<size_t>(offset));

    TVector<jlong> decoded;
    while (reader.Left() > 0) {
        decoded.emplace_back(static_cast<jlong>(reader.ReadInt64()));
    }

    return ToJavaArray(jenv, TArrayRef<jlong>(decoded.data(), decoded.size()));
}

/**
 * JNI: reads values with TBitReader::ReadDouble from `bitBuf`, starting at bit position
 * `offset`, until no bits are left, and returns them as a Java double[].
 */
jdoubleArray Java_ru_yandex_solomon_ts_1codec_BitStreamNative_readDouble(JNIEnv* jenv, jclass, jobject bitBuf, jint offset) {
    auto source = NJava::FromHeapBitBuffer(jenv, bitBuf);
    TBitReader reader{source};
    reader.SetPos(static_cast<size_t>(offset));

    TVector<jdouble> decoded;
    while (reader.Left() > 0) {
        decoded.emplace_back(static_cast<jdouble>(reader.ReadDouble()));
    }

    return ToJavaArray(jenv, TArrayRef<jdouble>(decoded.data(), decoded.size()));
}

/**
 * JNI: reads values with TBitReader::ReadInt32(bits) from `bitBuf`, starting at bit position
 * `offset`, until no bits are left, and returns them as a Java int[].
 */
jintArray Java_ru_yandex_solomon_ts_1codec_BitStreamNative_readInt32Bits(JNIEnv* jenv, jclass, jobject bitBuf, jint bits, jint offset) {
    auto source = NJava::FromHeapBitBuffer(jenv, bitBuf);
    TBitReader reader{source};
    reader.SetPos(static_cast<size_t>(offset));

    const auto width = static_cast<size_t>(bits);
    TVector<jint> decoded;
    while (reader.Left() > 0) {
        decoded.emplace_back(static_cast<jint>(reader.ReadInt32(width)));
    }

    return ToJavaArray(jenv, TArrayRef<jint>(decoded.data(), decoded.size()));
}

/**
 * JNI: reads values with TBitReader::ReadInt64(bits) from `bitBuf`, starting at bit position
 * `offset`, until no bits are left, and returns them as a Java long[].
 */
jlongArray Java_ru_yandex_solomon_ts_1codec_BitStreamNative_readInt64Bits(JNIEnv* jenv, jclass, jobject bitBuf, jint bits, jint offset) {
    auto source = NJava::FromHeapBitBuffer(jenv, bitBuf);
    TBitReader reader{source};
    reader.SetPos(static_cast<size_t>(offset));

    const auto width = static_cast<size_t>(bits);
    TVector<jlong> decoded;
    while (reader.Left() > 0) {
        decoded.emplace_back(static_cast<jlong>(reader.ReadInt64(width)));
    }

    return ToJavaArray(jenv, TArrayRef<jlong>(decoded.data(), decoded.size()));
}

/**
 * JNI: reads values with TBitReader::ReadVarInt32 from `bitBuf`, starting at bit position
 * `offset`, until no bits are left, and returns them as a Java int[].
 *
 * If ReadVarInt32 yields no value, a RuntimeException naming the bit position where that value
 * started is left pending and nullptr is returned.
 */
jintArray Java_ru_yandex_solomon_ts_1codec_BitStreamNative_readVarInt32(JNIEnv* jenv, jclass, jobject bitBuf, jint offset) {
    auto buf = NJava::FromHeapBitBuffer(jenv, bitBuf);

    TBitReader r{buf};
    r.SetPos(static_cast<size_t>(offset));

    TVector<jint> values;
    while (r.Left() > 0) {
        // Remember where this value starts so the error message can point at it.
        size_t pos = r.Pos();
        auto value = r.ReadVarInt32();
        if (!value.has_value()) {
            ThrowRuntimeException(jenv, TStringBuilder() << "cannot read value at position " << pos);
            return nullptr;
        }
        values.push_back(static_cast<jint>(*value));
    }

    return ToJavaArray(jenv, TArrayRef<jint>{values});
}

/**
 * JNI: reads values with TBitReader::ReadVarInt64 from `bitBuf`, starting at bit position
 * `offset`, until no bits are left, and returns them as a Java long[].
 *
 * If ReadVarInt64 yields no value, a RuntimeException naming the bit position where that value
 * started is left pending and nullptr is returned.
 */
jlongArray Java_ru_yandex_solomon_ts_1codec_BitStreamNative_readVarInt64(JNIEnv* jenv, jclass, jobject bitBuf, jint offset) {
    auto buf = NJava::FromHeapBitBuffer(jenv, bitBuf);

    TBitReader r{buf};
    r.SetPos(static_cast<size_t>(offset));

    TVector<jlong> values;
    while (r.Left() > 0) {
        // Remember where this value starts so the error message can point at it.
        size_t pos = r.Pos();
        auto value = r.ReadVarInt64();
        if (!value.has_value()) {
            ThrowRuntimeException(jenv, TStringBuilder() << "cannot read value at position " << pos);
            return nullptr;
        }
        values.push_back(static_cast<jlong>(*value));
    }

    return ToJavaArray(jenv, TArrayRef<jlong>{values});
}

/**
 * JNI: reads values with TBitReader::ReadOnes(8) from `bitBuf`, starting at bit position
 * `offset`, until no bits are left, and returns them as a Java int[].
 */
jintArray Java_ru_yandex_solomon_ts_1codec_BitStreamNative_readOnes(JNIEnv* jenv, jclass, jobject bitBuf, jint offset) {
    // Second argument of ReadOnes; must match the one used by writeOnes.
    // NOTE(review): presumably the maximum run length — confirm in bit_reader.h.
    constexpr int kMaxOnes = 8;

    auto source = NJava::FromHeapBitBuffer(jenv, bitBuf);
    TBitReader reader{source};
    reader.SetPos(static_cast<size_t>(offset));

    TVector<jint> decoded;
    while (reader.Left() > 0) {
        decoded.emplace_back(static_cast<jint>(reader.ReadOnes(kMaxOnes)));
    }

    return ToJavaArray(jenv, TArrayRef<jint>(decoded.data(), decoded.size()));
}

/**
 * JNI: writes `offset` padding bits via FillOffset, then each element of `array` as one bit
 * (set iff the element equals JNI_TRUE). The flushed buffer is returned as a Java heap bit buffer.
 */
jobject Java_ru_yandex_solomon_ts_1codec_BitStreamNative_writeBits(JNIEnv* jenv, jclass, jbooleanArray array, jint offset) {
    TBitBuffer encoded;
    TBitWriter writer{&encoded};
    FillOffset(&writer, offset);

    auto writeElement = [&writer](jboolean bit) {
        writer.WriteBit(bit == JNI_TRUE);
    };
    ForEach(jenv, array, writeElement);
    writer.Flush();

    return NJava::ToHeapBitBuffer(jenv, encoded);
}

/**
 * JNI: writes `offset` padding bits via FillOffset, then each element of `array` with
 * TBitWriter::WriteInt8. The flushed buffer is returned as a Java heap bit buffer.
 */
jobject Java_ru_yandex_solomon_ts_1codec_BitStreamNative_writeInt8(JNIEnv* jenv, jclass, jbyteArray array, jint offset) {
    TBitBuffer encoded;
    TBitWriter writer{&encoded};
    FillOffset(&writer, offset);

    auto writeElement = [&writer](jbyte element) {
        writer.WriteInt8(element);
    };
    ForEach(jenv, array, writeElement);
    writer.Flush();

    return NJava::ToHeapBitBuffer(jenv, encoded);
}

/**
 * JNI: writes `offset` padding bits via FillOffset, then each element of `array` with
 * TBitWriter::WriteInt32. The flushed buffer is returned as a Java heap bit buffer.
 */
jobject Java_ru_yandex_solomon_ts_1codec_BitStreamNative_writeInt32(JNIEnv* jenv, jclass, jintArray array, jint offset) {
    TBitBuffer encoded;
    TBitWriter writer{&encoded};
    FillOffset(&writer, offset);

    auto writeElement = [&writer](jint element) {
        writer.WriteInt32(element);
    };
    ForEach(jenv, array, writeElement);
    writer.Flush();

    return NJava::ToHeapBitBuffer(jenv, encoded);
}

/**
 * JNI: writes `offset` padding bits via FillOffset, then each element of `array` with
 * TBitWriter::WriteInt64. The flushed buffer is returned as a Java heap bit buffer.
 */
jobject Java_ru_yandex_solomon_ts_1codec_BitStreamNative_writeInt64(JNIEnv* jenv, jclass, jlongArray array, jint offset) {
    TBitBuffer encoded;
    TBitWriter writer{&encoded};
    FillOffset(&writer, offset);

    auto writeElement = [&writer](jlong element) {
        writer.WriteInt64(element);
    };
    ForEach(jenv, array, writeElement);
    writer.Flush();

    return NJava::ToHeapBitBuffer(jenv, encoded);
}

/**
 * JNI: writes `offset` padding bits via FillOffset, then each element of `array` with
 * TBitWriter::WriteDouble. The flushed buffer is returned as a Java heap bit buffer.
 */
jobject Java_ru_yandex_solomon_ts_1codec_BitStreamNative_writeDouble(JNIEnv* jenv, jclass, jdoubleArray array, jint offset) {
    TBitBuffer encoded;
    TBitWriter writer{&encoded};
    FillOffset(&writer, offset);

    auto writeElement = [&writer](jdouble element) {
        writer.WriteDouble(element);
    };
    ForEach(jenv, array, writeElement);
    writer.Flush();

    return NJava::ToHeapBitBuffer(jenv, encoded);
}

/**
 * JNI: writes `offset` padding bits via FillOffset, then each element of `array` with
 * TBitWriter::WriteInt32(value, bits). The flushed buffer is returned as a Java heap bit buffer.
 */
jobject Java_ru_yandex_solomon_ts_1codec_BitStreamNative_writeInt32Bits(JNIEnv* jenv, jclass, jintArray array, jint bits, jint offset) {
    TBitBuffer encoded;
    TBitWriter writer{&encoded};
    FillOffset(&writer, offset);

    const auto width = static_cast<size_t>(bits);
    auto writeElement = [&writer, width](jint element) {
        writer.WriteInt32(element, width);
    };
    ForEach(jenv, array, writeElement);
    writer.Flush();

    return NJava::ToHeapBitBuffer(jenv, encoded);
}

/**
 * JNI: writes `offset` padding bits via FillOffset, then each element of `array` with
 * TBitWriter::WriteInt64(value, bits). The flushed buffer is returned as a Java heap bit buffer.
 */
jobject Java_ru_yandex_solomon_ts_1codec_BitStreamNative_writeInt64Bits(JNIEnv* jenv, jclass, jlongArray array, jint bits, jint offset) {
    TBitBuffer encoded;
    TBitWriter writer{&encoded};
    FillOffset(&writer, offset);

    const auto width = static_cast<size_t>(bits);
    auto writeElement = [&writer, width](jlong element) {
        writer.WriteInt64(element, width);
    };
    ForEach(jenv, array, writeElement);
    writer.Flush();

    return NJava::ToHeapBitBuffer(jenv, encoded);
}

/**
 * JNI: writes `offset` padding bits via FillOffset, then each element of `array` with
 * TBitWriter::WriteVarInt32. The flushed buffer is returned as a Java heap bit buffer.
 */
jobject Java_ru_yandex_solomon_ts_1codec_BitStreamNative_writeVarInt32(JNIEnv* jenv, jclass, jintArray array, jint offset) {
    TBitBuffer encoded;
    TBitWriter writer{&encoded};
    FillOffset(&writer, offset);

    auto writeElement = [&writer](jint element) {
        writer.WriteVarInt32(element);
    };
    ForEach(jenv, array, writeElement);
    writer.Flush();

    return NJava::ToHeapBitBuffer(jenv, encoded);
}

/**
 * JNI: writes `offset` padding bits via FillOffset, then each element of `array` with
 * TBitWriter::WriteVarInt64. The flushed buffer is returned as a Java heap bit buffer.
 */
jobject Java_ru_yandex_solomon_ts_1codec_BitStreamNative_writeVarInt64(JNIEnv* jenv, jclass, jlongArray array, jint offset) {
    TBitBuffer encoded;
    TBitWriter writer{&encoded};
    FillOffset(&writer, offset);

    auto writeElement = [&writer](jlong element) {
        writer.WriteVarInt64(element);
    };
    ForEach(jenv, array, writeElement);
    writer.Flush();

    return NJava::ToHeapBitBuffer(jenv, encoded);
}

/**
 * JNI: writes `offset` padding bits via FillOffset, then each element of `array` with
 * TBitWriter::WriteOnes(value, 8). The flushed buffer is returned as a Java heap bit buffer.
 */
jobject Java_ru_yandex_solomon_ts_1codec_BitStreamNative_writeOnes(JNIEnv* jenv, jclass, jintArray array, jint offset) {
    // Second argument of WriteOnes; must match the one used by readOnes.
    // NOTE(review): presumably the maximum run length — confirm in bit_writer.h.
    constexpr int kMaxOnes = 8;

    TBitBuffer encoded;
    TBitWriter writer{&encoded};
    FillOffset(&writer, offset);

    auto writeElement = [&writer](jint element) {
        writer.WriteOnes(element, kMaxOnes);
    };
    ForEach(jenv, array, writeElement);
    writer.Flush();

    return NJava::ToHeapBitBuffer(jenv, encoded);
}
