#include "antiporno.h"

#include <mail/so/libs/jniwrapper_base/jniwrapper_base.h>

#include <cv/imgclassifiers/danet/external/ext_entry_point/entry_point.h>
#include <cv/library/imgcore/opencvutils/cvmat_img_preparer.h>

#include <library/cpp/json/writer/json.h>
#include <library/cpp/yconf/conf.h>

#include <contrib/libs/opencv/modules/core/include/opencv2/core/mat.hpp>
#include <contrib/libs/opencv/modules/imgcodecs/include/opencv2/imgcodecs.hpp>

#include <util/charset/wide.h>
#include <util/generic/map.h>
#include <util/generic/strbuf.h>
#include <util/generic/string.h>
#include <util/generic/yexception.h>

#include <stdlib.h>
#include <string.h> // strdup

#include <limits>

// Declares the <AntiPorno> config section and the directives it accepts.
// JniWrapperCreateAntiPorno reads NnConfig and NnModel as TString and Workers
// as ui64, and forwards all three to the TAntiPorno constructor.
DEFINE_SECTION(AntiPorno)
    DIRECTIVE(NnConfig)
    DIRECTIVE(NnModel)
    DIRECTIVE(Workers)
END_DEFINE_SECTION

// Config parser type used by JniWrapperCreateAntiPorno; AntiPorno is declared
// as its top section (BEGIN_TOPSECTION).
DECLARE_CONFIG(TAntiPornoConfig)
BEGIN_CONFIG(TAntiPornoConfig)
    BEGIN_TOPSECTION(AntiPorno)
    END_SECTION()
END_CONFIG()

// Builds the Danet inference wrapper.
//
// nnConfig - network configuration forwarded to NNeuralNet::TDanetOps
//            (NOTE(review): presumably a file path - confirm).
// nnModel  - network model forwarded to NNeuralNet::TDanetOps
//            (NOTE(review): presumably a file path - confirm).
// workers  - worker count forwarded to NNeuralNet::TDanetOps; DetectPorno
//            later obtains one through GetFreeWorker().
NAntiPorno::TAntiPorno::TAntiPorno(
    const TString& nnConfig,
    const TString& nnModel,
    ui64 workers)
    : DanetOps(new NNeuralNet::TDanetOps{nnConfig, nnModel, workers})
{
    // NOTE(review): DanetEnv() looks like shared, process-wide state, so these
    // settings are re-applied by every TAntiPorno constructed, and only after
    // TDanetOps above has already been built - confirm TDanetOps does not
    // depend on them at construction time.
    NNeuralNet::DanetEnv().SetLoggingLevel(NNeuralNet::LOGGING_MODE_OFF);
    // One Danet thread; presumably parallelism is meant to come from the
    // `workers` pool instead - confirm.
    NNeuralNet::DanetEnv().SetNumThreads(1);
}

// Defaulted out of line rather than in antiporno.h - presumably so the header
// only needs a forward declaration of NNeuralNet::TDanetOps for the DanetOps
// member; confirm against the header before moving it.
NAntiPorno::TAntiPorno::~TAntiPorno() = default;

// Scores an already-decoded image with the Danet model.
//
// A free worker is taken from DanetOps and TDanetOps::PredictOnImage fills the
// returned map (keyed by the names the model emits - presumably class labels).
TMap<TString, float> NAntiPorno::TAntiPorno::DetectPorno(const TCvMatImgPreparer& image) const {
    TMap<TString, float> scores;
    DanetOps->PredictOnImage(
        DanetOps->GetFreeWorker(),
        image,
        scores);
    return scores;
}

// Decodes an encoded image buffer (any format cv::imdecode recognises) into a
// colour (IMREAD_COLOR, i.e. 3-channel BGR) matrix wrapped in a
// TCvMatImgPreparer.
//
// data - encoded image bytes; may be nullptr only when the result should be
//        "not loadable".
// size - number of bytes at `data`.
// Returns nullptr when the buffer is missing or empty, when it is longer than
// OpenCV's int-sized buffer length allows, or when decoding fails.
THolder<TCvMatImgPreparer> NAntiPorno::LoadImage(const void* data, size_t size) {
    // cv::imdecode CV_Asserts a non-empty buffer (throwing cv::Exception), and
    // _InputArray(const T*, int) takes an int length, so reject such inputs up
    // front instead of throwing or silently truncating `size`.
    if (data == nullptr || size == 0 ||
        size > static_cast<size_t>(std::numeric_limits<int>::max()))
    {
        return nullptr;
    }
    auto mat = cv::imdecode(
        cv::_InputArray(static_cast<const ui8*>(data), static_cast<int>(size)),
        cv::IMREAD_COLOR);
    if (mat.empty()) {
        // imdecode signals an undecodable buffer with an empty matrix.
        return nullptr;
    }
    return MakeHolder<TCvMatImgPreparer>(mat);
}

// Renders the prediction scores as one flat JSON object with a
// "<key>": <score> member per map entry, in the map's iteration order.
// HEM_UNSAFE selects NJsonWriter's escaping mode without HTML-safe escaping
// (NOTE(review): presumably acceptable because the output is not embedded in
// HTML - confirm).
TString NAntiPorno::SerializePredictions(const TMap<TString, float>& predictions) {
    NJsonWriter::TBuf json(NJsonWriter::HEM_UNSAFE);
    json.BeginObject();
    for (const auto& entry : predictions) {
        const TString& key = entry.first;
        const float score = entry.second;
        json.WriteKey(key);
        json.WriteFloat(score);
    }
    json.EndObject();
    return json.Str();
}

extern "C"
int JniWrapperCreateAntiPorno(const char* config, void** out) noexcept {
    try {
        TAntiPornoConfig configParser;
        if (!configParser.ParseMemory(config)) {
            TString message{TString::Join("Failed to parse config:\n", config, "\n")};
            configParser.PrintErrors(message);
            *out = strdup(message.c_str());
            return -1;
        }
        TYandexConfig::Section* section = configParser.GetFirstChild("AntiPorno");
        if (!section) {
            TString message{TString::Join("<AntiPorno> section not found in config:\n", config)};
            *out = strdup(message.c_str());
            return -1;
        }
        const TYandexConfig::Directives& directives = section->GetDirectives();
        TString nnConfig;
        if (!directives.GetValue("NnConfig", nnConfig)) {
            TString mesage{TString::Join("'NnConfig' not found in config:\n", config)};
            *out = strdup(mesage.c_str());
            return -1;
        }
        TString nnModel;
        if (!directives.GetValue("NnModel", nnModel)) {
            TString mesage{TString::Join("'NnModel' not found in config:\n", config)};
            *out = strdup(mesage.c_str());
            return -1;
        }
        ui64 workers;
        if (!directives.GetValue("Workers", workers)) {
            TString mesage{TString::Join("'Workers' not found in config:\n", config)};
            *out = strdup(mesage.c_str());
            return -1;
        }
        *out = new NAntiPorno::TAntiPorno(nnConfig, nnModel, workers);
    } catch (...) {
        return NJniWrapper::ProcessJniWrapperException((char**) out);
    }
    return 0;
}

extern "C"
void JniWrapperDestroyAntiPorno(void* instance) noexcept {
    delete static_cast<NAntiPorno::TAntiPorno*>(instance);
}

extern "C"
int JniWrapperDetectPorno(
    void* instance,
    const char* uri Y_DECLARE_UNUSED,
    const char* metainfo Y_DECLARE_UNUSED,
    const void* data,
    size_t size,
    char** out) noexcept
{
    const NAntiPorno::TAntiPorno* antiPorno =
        static_cast<const NAntiPorno::TAntiPorno*>(instance);
    try {
        THolder<TCvMatImgPreparer> image{NAntiPorno::LoadImage(data, size)};
        if (!image) {
            *out = strdup("Image loading failed");
            if (*out) {
                return -2;
            } else {
                return -1;
            }
        }
        TMap<TString, float> predictions{antiPorno->DetectPorno(*image)};
        *out = strdup(NAntiPorno::SerializePredictions(predictions).c_str());
        if (*out) {
            return 0;
        } else {
            return -1;
        }
    } catch (...) {
        return NJniWrapper::ProcessJniWrapperException(out);
    }
}

