#include "yavision_classifier.h"

#include <maps/libs/json/include/value.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/algorithm/retry.h>

namespace maps::mrc::classifiers {

namespace {

/// Invokes `action` under this file's retry policy and returns its result.
///
/// Any std::exception escaping `action` makes common::retryOnException try
/// again. The policy is built with an initial timeout of 1 second, a maximum
/// of 4 attempts and a timeout backoff of 2; see common::RetryPolicy for the
/// exact meaning of each setting.
template <typename Action>
auto retry(Action&& action) -> decltype(action())
{
    constexpr auto initialTimeout = std::chrono::seconds(1);
    constexpr int maxAttempts = 4;
    constexpr int timeoutBackoff = 2;

    return common::retryOnException<std::exception>(
        common::RetryPolicy()
            .setInitialTimeout(initialTimeout)
            .setMaxAttempts(maxAttempts)
            .setTimeoutBackoff(timeoutBackoff),
        std::forward<Action>(action));
}

/// Performs `request` (through retry()) and returns the response body.
///
/// Every attempt that throws is logged with WARN() and rethrown so that
/// retry() can make another attempt. A 5xx status is converted into an
/// exception inside the attempt, so server-side failures are retried too
/// instead of failing on the first try.
///
/// @throws maps::Exception if the final status is not 200. A non-5xx status
///         fails without retrying; a 5xx status fails once retries run out.
///         The message is "Unexpected status: <code> from URL: <url>".
/// @throws whatever request.perform() throws once retries are exhausted.
std::string responseBodyOf(http::Request& request)
{
    auto response = retry([&request] {
        try {
            auto attempt = request.perform();
            // 5xx means the server failed; the condition is usually transient,
            // so turn it into an exception to let retry() try again.
            if (attempt.status() >= 500) {
                throw maps::Exception() << "Unexpected status: " << attempt.status()
                                        << " from URL: " << request.url();
            }
            return attempt;
        }
        catch (const std::exception& e) {
            WARN() << e.what();
            throw;
        }
    });
    if (response.status() != 200) {
        throw maps::Exception() << "Unexpected status: " << response.status()
                                << " from URL: " << request.url();
    }
    return response.readBody();
}

} // anonymous namespace

/// Sends `encodedImage` to the Yavision service at `yavisionUrl_` and reports
/// whether any monitored class is rated at or above the forbidden threshold.
///
/// The response is parsed as JSON. For each monitored class, the integer
/// rating is read from `NeuralNetClasses.<class>`. The classes are checked in
/// order, and the first rating at or above the threshold returns true
/// immediately.
///
/// @throws if the HTTP request fails (see responseBodyOf), or if a rating
///         checked before a match lies outside [0, 255].
bool YavisionClassifier::hasForbiddenContent(const common::Blob& encodedImage) const
{
    // A class rated at or above this value marks the image as forbidden.
    constexpr int FORBIDDEN_RATING = 215;
    static constexpr const char* MONITORED_CLASSES[] = {"binary_porn", "gruesome"};

    // A single client is shared by all calls of this method.
    static http::Client httpClient;

    http::Request imageRequest{httpClient, http::POST, yavisionUrl_};
    imageRequest.setContent(encodedImage);

    // Keep the raw body in a named variable; the parsed value is built from it.
    const std::string body = responseBodyOf(imageRequest);
    auto val = json::Value::fromString(body);

    for (const auto& className : MONITORED_CLASSES) {
        const auto cl = val["NeuralNetClasses"][className].as<int>();
        REQUIRE(cl >= 0 && cl <= 255, "Unexpected yavision rating");
        if (cl >= FORBIDDEN_RATING) {
            return true;
        }
    }
    return false;
}

/// Returns 1.0 when hasForbiddenContent() flags the image and 0.0 otherwise.
/// The underlying classifier gives a yes/no verdict, so no intermediate
/// probabilities are produced.
double YavisionClassifier::estimateForbiddenProbability(const common::Blob& encodedImage) const
{
    if (hasForbiddenContent(encodedImage)) {
        return 1.0;
    }
    return 0.0;
}

/// Bytes overload: converts the buffer with common::toBlob and then
/// delegates to the Blob overload.
double YavisionClassifier::estimateForbiddenProbability(const common::Bytes& encodedImage) const
{
    const auto blob = common::toBlob(encodedImage);
    return estimateForbiddenProbability(blob);
}

/// cv::Mat overload: encodes the image with common::encodeImage and then
/// delegates to the overload that accepts the encoded result.
double YavisionClassifier::estimateForbiddenProbability(const cv::Mat& image) const
{
    const auto encoded = common::encodeImage(image);
    return estimateForbiddenProbability(encoded);
}

} // namespace maps::mrc::classifiers
