#include <mail/so/libs/jniwrapper_base/jniwrapper_base.h>
#include <mail/so/spamstop/tools/text2shingles/lib/text2shingles.h>

#include <kernel/dssm_applier/nn_applier/lib/layers.h>
#include <kernel/dssm_applier/nn_applier/lib/states.h>

#include <library/cpp/json/json_reader.h>

#include <util/generic/strbuf.h>
#include <util/generic/string.h>
#include <util/generic/vector.h>
#include <util/memory/blob.h>
#include <util/string/join.h>
#include <util/string/strip.h>

// Shared empty value pushed as the model input for JSON fields whose value
// is not a string (see TDssmApplier::Apply).
static const TString EmptyString;

class TDssmApplier {
private:
    NNeuralNetApplier::TModel Model;
    const TVector<TString> Outputs;
    const bool PreprocessFields;
    const bool Debug;

private:
    static TStringBuf FirstLine(const TString& value) {
        size_t pos = value.find('\n');
        if (pos == TString::npos) {
            return value;
        } else {
            return {value.Data(), pos};
        }
    }

    static TString PrepareInput(const TString& name, const TString& value) {
        if (name == "subject"sv || name == "body"sv) {
            return JoinSeq(
                TStringBuf(" "),
                NText2Shingles::Text2Shingles(StripString(value), LANG_UNK, true));
        } else if (name == "fromname"sv) {
            return JoinSeq(
                TStringBuf(" "),
                NText2Shingles::Text2Shingles(StripString(value), LANG_UNK, false));
        } else if (name == "fromaddr"sv) {
            return JoinSeq(
                TStringBuf(" "),
                NText2Shingles::Text2Shingles(StripString(FirstLine(value)), LANG_UNK, false));
        } else if (name == "raw_fromaddr"sv) {
            return TString{StripString(FirstLine(value))};
        } else {
            return value;
        }
    }

public:
    TDssmApplier(
        const TString& pathToModel,
        const TString& layer,
        bool preprocessFields,
        bool debug)
        : Outputs(1, layer)
        , PreprocessFields(preprocessFields)
        , Debug(debug)
    {
        Model.Load(TBlob::PrechargedFromFile(pathToModel));
    }

    TVector<float> Apply(const TString& input) const {
        NJson::TJsonValue json;
        ReadJsonTree(input, false, &json, true);
        const NJson::TJsonValue::TMapType& map = json.GetMapSafe();
        size_t size{map.size()};
        TVector<TString> annotations(Reserve(size));
        TVector<TString> inputs(Reserve(size));
        for (const auto& iter: map) {
            annotations.push_back(iter.first);
            if (iter.second.IsString()) {
                if (PreprocessFields) {
                    inputs.push_back(PrepareInput(iter.first, iter.second.GetStringSafe()));
                } else {
                    inputs.push_back(iter.second.GetStringSafe());
                }
            } else {
                inputs.push_back(EmptyString);
            }
        }
        if (Debug) {
            Cerr << "Dssm input:" << Endl;
            for (size_t i = 0; i < annotations.size(); ++i) {
                Cerr << annotations[i] << " = \"" << inputs[i] << "\"" << Endl;
            }
            Cerr << Endl;
        }
        TVector<float> result;
        Model.Apply(
            MakeAtomicShared<NNeuralNetApplier::TSample>(
                annotations,
                inputs),
            Outputs,
            result);
        return result;
    }
};

// JNI entry point: builds a TDssmApplier from the given model path, layer name
// and flags, and hands it back to Java as an opaque jlong handle.
// On any C++ exception the error is re-raised on the Java side and 0 is returned.
extern "C" JNIEXPORT jlong JNICALL
Java_ru_yandex_so_dssm_applier_DssmApplier_createInstance(
    JNIEnv* env,
    jclass,
    jstring pathToModel,
    jstring layer,
    jboolean preprocessFields,
    jboolean debug)
{
    try {
        const auto modelPath = NJniWrapper::JStringToUtf(env, pathToModel);
        const auto layerName = NJniWrapper::JStringToUtf(env, layer);
        // Ownership passes to the Java side; released via destroyInstance.
        TDssmApplier* applier = new TDssmApplier(
            modelPath,
            layerName,
            preprocessFields,
            debug);
        return reinterpret_cast<jlong>(applier);
    } catch (...) {
        NJniWrapper::RethrowAsJavaException(env);
    }
    return 0;
}

// JNI entry point: destroys the TDssmApplier behind the handle that
// createInstance returned. A zero handle is a no-op (delete of a null pointer).
extern "C" JNIEXPORT void JNICALL
Java_ru_yandex_so_dssm_applier_DssmApplier_destroyInstance(
    JNIEnv*,
    jclass,
    jlong instance)
{
    TDssmApplier* applier = reinterpret_cast<TDssmApplier*>(instance);
    delete applier;
}

// JNI entry point: applies the TDssmApplier behind `instance` to the JSON
// document in `input` and returns the resulting values as a new Java float[].
//
// Returns nullptr with a Java exception pending when:
//   - a C++ exception was thrown (re-raised through RethrowAsJavaException);
//   - the result array could not be allocated by the JVM.
//
// NOTE(review): `instance` is not validated here; a zero or stale handle is
// dereferenced as-is, so the Java caller must pass a live handle.
extern "C" JNIEXPORT jfloatArray JNICALL
Java_ru_yandex_so_dssm_applier_DssmApplier_apply(
    JNIEnv* env,
    jclass,
    jlong instance,
    jstring input)
{
    const TDssmApplier* applier =
        reinterpret_cast<const TDssmApplier*>(instance);
    try {
        const TVector<float> embedding =
            applier->Apply(NJniWrapper::JStringToUtf(env, input));
        // JNI array lengths are jsize; make the narrowing from size_t explicit.
        const jsize size = static_cast<jsize>(embedding.size());
        jfloatArray result = env->NewFloatArray(size);
        if (result == nullptr) {
            // Allocation failed and the JVM has already raised an error;
            // passing a null array to SetFloatArrayRegion would be invalid.
            return nullptr;
        }
        env->SetFloatArrayRegion(result, 0, size, embedding.data());
        return result;
    } catch (...) {
        NJniWrapper::RethrowAsJavaException(env);
        return nullptr;
    }
}

