#include <mail/so/libs/fast_text/fasttext.h>
#include <mail/so/libs/jniwrapper_base/jniwrapper_base.h>
#include <mail/so/libs/wmd/wmd.h>

#include <library/cpp/stopwords/stopwords.h>

#include <util/generic/ptr.h>
#include <util/string/split.h>

#include <contrib/libs/intel/mkl/include/mkl_cblas.h>

#include <jni.h>

#include <cctype>
#include <limits>
#include <stdexcept>

// Native state behind a Java JniFastText handle: the loaded fastText model
// and the stop-word filter that createDoc applies to input text.
// Instances are heap-allocated by createInstance, handed to Java as a jlong,
// and freed by destroyInstance.
struct TFastTextWrapper {
    fasttext::FastText FastText;
    // Initialized from a stop-word list only when createInstance receives a
    // non-null stopWordList argument.
    TWordFilter WordFilter;
};

/// Creates the native fastText wrapper.
///
/// Loads the model named by `pathToModel` and, if `stopWordList` is not null,
/// initializes the stop-word filter from it. Returns the wrapper's address as
/// an opaque handle that must later be passed to destroyInstance.
/// If anything throws, the exception is handed to
/// NJniWrapper::RethrowAsJavaException and 0 is returned.
extern "C" JNIEXPORT jlong JNICALL
Java_ru_yandex_jni_fasttext_JniFastText_createInstance(
    JNIEnv* env,
    jclass,
    jstring pathToModel,
    jstring stopWordList)
{
    try {
        // The holder frees the wrapper automatically if any step below throws.
        auto holder = MakeHolder<TFastTextWrapper>();
        holder->FastText.loadModel(NJniWrapper::JStringToUtf(env, pathToModel));

        if (stopWordList != nullptr) {
            const auto stopWordListUtf = NJniWrapper::JStringToUtf(env, stopWordList);
            holder->WordFilter.InitStopWordsList(stopWordListUtf.c_str());
        }

        // Ownership passes to the Java side from here on.
        TFastTextWrapper* const handle = holder.Release();
        return reinterpret_cast<jlong>(handle);
    } catch (...) {
        NJniWrapper::RethrowAsJavaException(env);
        return 0;
    }
}

/// Frees a wrapper previously returned by createInstance.
/// Passing 0 is harmless (deleting a null pointer is a no-op).
extern "C" JNIEXPORT void JNICALL
Java_ru_yandex_jni_fasttext_JniFastText_destroyInstance(
    JNIEnv*,
    jclass,
    jlong instance)
{
    TFastTextWrapper* const wrapper = reinterpret_cast<TFastTextWrapper*>(instance);
    delete wrapper;
}

/// Returns the word-vector dimension of the loaded model, which is also the
/// row width of the flat arrays produced by createDoc.
extern "C" JNIEXPORT jint JNICALL
Java_ru_yandex_jni_fasttext_JniFastText_getDimension(
    JNIEnv*,
    jclass,
    jlong instance)
{
    const TFastTextWrapper& wrapper =
        *reinterpret_cast<const TFastTextWrapper*>(instance);
    return wrapper.FastText.getDimension();
}

/// Converts `text` into a document for the WMD functions.
///
/// Steps:
///  1. Split the text on whitespace (std::isspace) and drop empty tokens.
///  2. Drop words rejected by the stop-word filter.
///  3. L2-normalize the fastText vector of each remaining word.
///  4. Append the vectors to a flat float array.
///
/// Word i occupies elements [i * dimension, (i + 1) * dimension).
/// A word whose vector has zero norm is stored as an all-zero row rather than
/// being divided by zero, which would yield NaNs.
///
/// Returns null if the Java array cannot be allocated. If anything throws, the
/// exception is handed to NJniWrapper::RethrowAsJavaException and null is
/// returned.
extern "C" JNIEXPORT jfloatArray JNICALL
Java_ru_yandex_jni_fasttext_JniFastText_createDoc(
    JNIEnv* env,
    jclass,
    jlong instance,
    jstring text)
{
    try {
        const auto wrapper =
            reinterpret_cast<const TFastTextWrapper*>(instance);
        const auto& model = wrapper->FastText;
        const auto& wordFilter = wrapper->WordFilter;
        const size_t dimension = model.getDimension();
        TString str = NJniWrapper::JStringToUtf(env, text);

        // std::isspace is undefined for negative arguments. UTF-8 bytes >= 0x80
        // are negative when char is signed, so go through unsigned char.
        const auto isSpace = [](char c) {
            return std::isspace(static_cast<unsigned char>(c)) != 0;
        };
        TVector<TStringBuf> allWords =
            StringSplitter(str).SplitByFunc(isSpace).SkipEmpty();
        TVector<TStringBuf> words(Reserve(allWords.size()));
        for (const auto& word: allWords) {
            if (!wordFilter.IsStopWord(word.Data(), word.Size())) {
                words.emplace_back(word);
            }
        }

        const size_t size = words.size();
        // Java array lengths and offsets are jsize (int); refuse documents whose
        // flattened length would overflow it instead of silently truncating.
        if (dimension != 0
            && size > static_cast<size_t>(std::numeric_limits<jsize>::max()) / dimension)
        {
            throw std::length_error(
                "createDoc: document is too long for a Java float array");
        }

        const jfloatArray result =
            env->NewFloatArray(static_cast<jsize>(dimension * size));
        if (result == nullptr) {
            // Allocation failed; per the JNI spec the JVM has an exception pending.
            return nullptr;
        }

        fasttext::Vector vector(dimension);
        for (size_t i = 0; i < size; ++i) {
            model.getWordVector(vector, std::string(words[i]));
            const auto norm = vector.norm();
            // A zero vector cannot be normalized; multiplying it by 1/0 = +inf
            // would turn every component into NaN.
            if (norm > 0) {
                vector.mul(1.f / norm);
            }
            env->SetFloatArrayRegion(
                result,
                static_cast<jsize>(i * dimension),
                static_cast<jsize>(dimension),
                vector.data());
        }
        return result;
    } catch (...) {
        NJniWrapper::RethrowAsJavaException(env);
        return nullptr;
    }
}

// Builds the pairwise distance matrix between the word vectors of two
// documents. Each document is a flat row-major array of `cols`-wide rows, as
// produced by createDoc.
//
// Entry (i, j) is 0.5 - 0.5 * dot(doc1[i], doc2[j]). For unit-length vectors
// this equals (1 - cos) / 2.
//
// Shape: one row per doc1 word, one column per doc2 word. This is the only
// shape consistent with cblas_sgemm(RowMajor, NoTrans, Trans) when A = doc1
// and B = doc2: A must be M x K and B must be N x K. For equal-size documents
// the result is identical to the previous (doc2Rows x doc1Rows) layout; for
// unequal sizes that layout read past the end of one of the arrays.
//
// Throws std::invalid_argument in these cases:
//  - `cols` is 0 or larger than INT_MAX (a negative jint from Java wraps to a
//    huge size_t);
//  - an array length is not a multiple of `cols`.
static NWmd::TMatrix MakeDistances(
    const NJniWrapper::TCriticalArrayElements<float>& doc1,
    const NJniWrapper::TCriticalArrayElements<float>& doc2,
    size_t cols)
{
    if (cols == 0 || cols > static_cast<size_t>(std::numeric_limits<int>::max())) {
        throw std::invalid_argument("MakeDistances: vector dimension is out of range");
    }
    if (doc1.GetSize() % cols != 0 || doc2.GetSize() % cols != 0) {
        throw std::invalid_argument(
            "MakeDistances: array length is not a multiple of the vector dimension");
    }

    const size_t doc1Rows = doc1.GetSize() / cols;
    const size_t doc2Rows = doc2.GetSize() / cols;
    NWmd::TMatrix distances(doc1Rows, doc2Rows, 0.5f);
    if (doc1Rows == 0 || doc2Rows == 0) {
        // Nothing to multiply. This also avoids passing ldc == 0, which BLAS
        // rejects as an invalid parameter.
        return distances;
    }

    // distances = -0.5 * doc1 * doc2^T + 1.0 * distances (pre-filled with 0.5).
    cblas_sgemm(
        CblasRowMajor,
        CblasNoTrans,
        CblasTrans,
        distances.Rows,
        distances.Cols,
        cols,
        -0.5f,
        doc1.GetData(),
        cols,
        doc2.GetData(),
        cols,
        1.f,
        distances.Data,
        distances.Cols);
    return distances;
}

/// Relaxed Word Mover's Distance between two documents built by createDoc.
///
/// `f1` and `f2` are flat arrays of `cols`-wide word vectors.
/// If anything throws, the exception is handed to
/// NJniWrapper::RethrowAsJavaException and 0 is returned.
extern "C" JNIEXPORT jfloat JNICALL
Java_ru_yandex_jni_fasttext_JniFastText_doRelaxedWmd(
    JNIEnv *env,
    jclass,
    jfloatArray f1,
    jfloatArray f2,
    jint cols)
{
    using TFloatElements = NJniWrapper::TCriticalArrayElements<float>;
    try {
        // The element views are temporaries, so their destructors run at the
        // end of this statement, before the distance computation starts.
        // JNI_ABORT presumably selects the no-write-back release mode.
        NWmd::TMatrix distances = MakeDistances(
            TFloatElements{env, f1, JNI_ABORT},
            TFloatElements{env, f2, JNI_ABORT},
            cols);
        return NWmd::RelaxedWmd(distances);
    } catch (...) {
        NJniWrapper::RethrowAsJavaException(env);
        return 0.f;
    }
}

/// Greedy Word Mover's Distance between two documents built by createDoc.
///
/// `f1` and `f2` are flat arrays of `cols`-wide word vectors.
/// If anything throws, the exception is handed to
/// NJniWrapper::RethrowAsJavaException and 0 is returned.
extern "C" JNIEXPORT jfloat JNICALL
Java_ru_yandex_jni_fasttext_JniFastText_doGreedWmd(
    JNIEnv *env,
    jclass,
    jfloatArray f1,
    jfloatArray f2,
    jint cols)
{
    using TFloatElements = NJniWrapper::TCriticalArrayElements<float>;
    try {
        // Keep the array views as temporaries: they are released as soon as
        // the matrix is built, not held during the WMD computation.
        NWmd::TMatrix distances = MakeDistances(
            TFloatElements{env, f1, JNI_ABORT},
            TFloatElements{env, f2, JNI_ABORT},
            cols);
        return NWmd::GreedWmd(distances);
    } catch (...) {
        NJniWrapper::RethrowAsJavaException(env);
        return 0.f;
    }
}

