#include "antiporno.h"

#include <mail/so/libs/jniwrapper_base/jniwrapper_base.h>

#include <library/cpp/json/json_reader.h>
#include <library/cpp/testing/unittest/env.h>
#include <library/cpp/testing/unittest/registar.h>

#include <util/generic/map.h>
#include <util/generic/ptr.h>
#include <util/generic/string.h>
#include <util/generic/ylimits.h>
#include <util/generic/ymath.h>
#include <util/memory/blob.h>
#include <util/string/cast.h>

#include <cmath>
#include <memory>

// Largest per-key distance (as reported by MaxDiff) that Compare still accepts
// between actual and canonical predictions.
constexpr float EPSILON = 1e-5f;

// Decodes serialized predictions: the JSON text is parsed with
// NJson::ReadJsonFastTree, and every member of the root map becomes one
// entry whose value is read via GetDouble() and narrowed to float.
static TMap<TString, float> ParseSerialized(const TString& json) {
    const NJson::TJsonValue parsed = NJson::ReadJsonFastTree(json);
    TMap<TString, float> scores;
    for (const auto& [key, value] : parsed.GetMap()) {
        scores.emplace(key, value.GetDouble());
    }
    return scores;
}

// Compares two prediction maps key by key and reports the key with the largest
// absolute difference as {key, distance}.
//
// Returns distance Max() (so any tolerance check fails) when:
//   * a key is present in only one of the maps (that key is reported), or
//   * the difference for a key is NaN, e.g. a NaN prediction or expectation.
// When all values match exactly, the key is empty and the distance is 0.
static std::pair<TString, float> MaxDiff(const TMap<TString, float>& lhs, const TMap<TString, float>& rhs) {
    TString name;
    float maxDiff = 0;
    for (const auto& [key, lhsValue] : lhs) {
        const auto rhsIter = rhs.find(key);
        if (rhsIter == rhs.end()) {
            return {key, Max()};
        }
        const float rhsValue = rhsIter->second;
        // Check exact equality first so identical infinities count as a match
        // instead of yielding inf - inf = NaN.
        const float distance = lhsValue == rhsValue ? 0.0f : Abs(lhsValue - rhsValue);
        if (std::isnan(distance)) {
            // NaN never compares greater than anything, so without this check
            // a NaN value would be skipped and the maps would look equal.
            return {key, Max()};
        }
        if (distance > maxDiff) {
            name = key;
            maxDiff = distance;
        }
    }
    // Every lhs key was found in rhs, so a size mismatch means rhs has extra
    // keys; report the first one so the failure message is actionable.
    if (lhs.size() != rhs.size()) {
        for (const auto& entry : rhs) {
            if (lhs.find(entry.first) == lhs.end()) {
                return {entry.first, Max()};
            }
        }
    }
    return {name, maxDiff};
}

// Runs the detector over the image file `filename` and returns its
// predictions. Fails the current unit test if NAntiPorno::LoadImage returns
// an empty result for the file's bytes.
// NOTE(review): "config.cfg" and "model" are relative paths and presumably
// resolve against the test's working directory — confirm against the build's
// data dependencies.
static TMap<TString, float> DetectPorno(const TString& filename) {
    const TBlob encoded = TBlob::FromFile(filename);
    THolder<TCvMatImgPreparer> prepared{NAntiPorno::LoadImage(encoded.Data(), encoded.Size())};
    if (!prepared) {
        UNIT_FAIL(TString::Join("Failed to load image from <", filename, ">"));
    }
    NAntiPorno::TAntiPorno detector("config.cfg", "model", 1);
    return detector.DetectPorno(*prepared);
}

// Returns the raw text of the canonical predictions for `filename`, read from
// "<SRC_("resources")>/<filename>.json".
static TString LoadCanonicalDataFor(const TString& filename) {
    const TString jsonPath = TString::Join(SRC_("resources"), "/", filename, ".json");
    const TBlob contents = TBlob::FromFile(jsonPath);
    return TString(contents.AsCharPtr(), contents.Size());
}

static void Compare(const TMap<TString, float>& actual, const TString& expected) {
    auto diff = MaxDiff(actual, ParseSerialized(expected));
    if (diff.second > EPSILON) {
        UNIT_FAIL(
            TString::Join(
                "Result differs at '",
                diff.first,
                "' with distance ",
                ToString(diff.second),
                ": ",
                expected,
                " != ",
                NAntiPorno::SerializePredictions(actual)));
    }
}

static void TestFile(const TString& filename) {
    TString input{LoadCanonicalDataFor(filename)};
    auto predictions{DetectPorno(filename)};
    Compare(predictions, input);
}

Y_UNIT_TEST_SUITE(TestAntiPorno) {
    Y_UNIT_TEST(ScreenshotTest) {
        TestFile("screenshot.png");
    }
    Y_UNIT_TEST(CatTest) {
        TestFile("sarah.jpg");
    }
    // Exercises the C wrapper entry points: create an instance from a text
    // config, run detection on the raw bytes of spb.webp, and compare the
    // serialized output with the canonical JSON for that file.
    Y_UNIT_TEST(JniWrapperTest) {
        const char* config = "<AntiPorno>\nNnConfig: config.cfg\nNnModel: model\nWorkers: 1\n</AntiPorno>";
        void* instance = nullptr;
        UNIT_ASSERT_VALUES_EQUAL(JniWrapperCreateAntiPorno(config, &instance), 0);
        // A failing UNIT_ASSERT*/Compare leaves the test body early, which used
        // to skip the trailing free/destroy calls. Guards release the wrapper's
        // resources on every exit path (out first, then the instance — the same
        // order as before, since guards are destroyed in reverse order).
        auto destroyInstance = [](void* ptr) { JniWrapperDestroyAntiPorno(ptr); };
        std::unique_ptr<void, decltype(destroyInstance)> instanceGuard{instance, destroyInstance};

        TBlob blob{TBlob::FromFile("spb.webp")};
        char* out = nullptr;
        UNIT_ASSERT_VALUES_EQUAL(
            JniWrapperDetectPorno(instance, nullptr, nullptr, blob.Data(), blob.Size(), &out),
            0);
        auto freeOut = [](char* ptr) { JniWrapperFree(ptr); };
        std::unique_ptr<char, decltype(freeOut)> outGuard{out, freeOut};

        Compare(ParseSerialized(out), LoadCanonicalDataFor("spb.webp"));
    }
}

