#include "forbidden_classifier.h"

#include <library/cpp/resource/resource.h>

#include <tensorflow/core/framework/tensor.h>
#include <tensorflow/core/common_runtime/dma_helper.h>

namespace tf = tensorflow;

namespace maps::mrc::classifiers {

namespace {

// Resource key under which the classifier model is looked up; it is handed to
// TensorFlowInferencer::fromResource by the ForbiddenClassifier constructor.
const std::string FORBIDDEN_CLASSIFIER_RESOURCE_KEY{"/maps/mrc/classifiers/models/forbidden_classifier.gdef"};

// Letterboxes `img` into a 224x224 CV_8UC3 canvas for the classifier.
//
// Steps:
//   1. Center-crop the image so that neither side is longer than
//      ASPECT_RATIO_MAX times the other (no-op for "reasonable" images).
//   2. Scale the crop with bicubic interpolation, preserving its aspect ratio,
//      so that it fits the canvas exactly along one dimension.
//   3. Center the scaled image on a black (zero) canvas; the leftover bands
//      along the other dimension stay zero.
//
// Preconditions:
//   - `img` must be non-empty: for an empty image the crop is 0x0 and the
//     scaling arithmetic below would divide by zero.
//   - `img` is expected to be CV_8UC3, since the scaled pixels are written
//     directly into a sub-view of the CV_8UC3 canvas.
cv::Mat ResizeImage(const cv::Mat& img) {
    constexpr int IMAGE_WIDTH = 224;
    constexpr int IMAGE_HEIGHT = 224;
    constexpr int ASPECT_RATIO_MAX = 5;

    // Centered crop limiting the aspect ratio to at most ASPECT_RATIO_MAX:1.
    // When no side needs trimming the offsets are 0 and the ROI is the whole image.
    const int cropRows = std::min(img.rows, img.cols * ASPECT_RATIO_MAX);
    const int cropCols = std::min(img.cols, img.rows * ASPECT_RATIO_MAX);
    const cv::Rect roi((img.cols - cropCols) / 2, (img.rows - cropRows) / 2, cropCols, cropRows);
    const cv::Mat cropped = img(roi);

    // cv::Mat::zeros(rows, cols, type) is rows-first; use an explicit
    // cv::Size(width, height) so the canvas stays correct even if the target
    // width and height ever differ.
    cv::Mat dst = cv::Mat::zeros(cv::Size(IMAGE_WIDTH, IMAGE_HEIGHT), CV_8UC3);
    if (IMAGE_HEIGHT * roi.width < IMAGE_WIDTH * roi.height) {
        // Crop is relatively taller than the canvas: use the full canvas
        // height and center the scaled image horizontally.
        const int outCols = roi.width * IMAGE_HEIGHT / roi.height;
        const int outOffset = (IMAGE_WIDTH - outCols) / 2;
        cv::resize(cropped, dst.colRange(outOffset, outOffset + outCols), cv::Size(outCols, IMAGE_HEIGHT), 0.0, 0.0, cv::INTER_CUBIC);
    } else {
        // Crop is relatively wider (or has the same ratio): use the full
        // canvas width and center the scaled image vertically.
        const int outRows = roi.height * IMAGE_WIDTH / roi.width;
        const int outOffset = (IMAGE_HEIGHT - outRows) / 2;
        cv::resize(cropped, dst.rowRange(outOffset, outOffset + outRows), cv::Size(IMAGE_WIDTH, outRows), 0.0, 0.0, cv::INTER_CUBIC);
    }
    return dst;
}

// Converts a decoded BGR image into the network's input form: letterboxed by
// ResizeImage and then converted from BGR to RGB channel order.
// (The name is a misspelling of "prepareImage"; it is kept as-is because the
// batch scoring code calls it by this name.)
cv::Mat pepareImage(const cv::Mat& image) {
    const cv::Mat letterboxed = ResizeImage(image);
    cv::Mat rgb;
    cv::cvtColor(letterboxed, rgb, cv::COLOR_BGR2RGB);
    return rgb;
}

} // anonymous namespace

// Builds the TensorFlow inferencer from the model resource registered under
// FORBIDDEN_CLASSIFIER_RESOURCE_KEY.
ForbiddenClassifier::ForbiddenClassifier()
    : inferencer_{wiki::tf_inferencer::TensorFlowInferencer::fromResource(FORBIDDEN_CLASSIFIER_RESOURCE_KEY)}
{
}

// Scores every image of `batch` with the forbidden-content network.
//
// Each image is preprocessed with pepareImage (letterbox resize + BGR->RGB),
// the whole batch is fed to the graph in a single inference call, and the
// score of image i is the maximum of column 0 of the "binary_porn" and
// "gruesome" softmax outputs for that image.
//
// Returns one score per input image, in input order. An empty batch yields an
// empty vector without running the model.
// Fails via REQUIRE if any input image is empty, or if the graph outputs do
// not have the expected count, dtype (float) or shape ([N, 2]).
std::vector<double> ForbiddenClassifier::estimateForbiddenProbability(const wiki::tf_inferencer::ImagesBatch &batch) const
{
    const std::string TF_LAYER_IMAGE_NAME       = "inference_input";
    const std::string TF_LAYER_BINARY_PORN_NAME = "binary_porn/output/Softmax:0";
    const std::string TF_LAYER_GRUESOME_NAME    = "gruesome/output/Softmax:0";

    const size_t batchSize = batch.size();
    if (0 == batchSize) {
        // Nothing to score; don't run the graph on an empty input.
        return {};
    }

    wiki::tf_inferencer::ImagesBatch batchPrepared(batchSize);
    for (size_t i = 0; i < batchSize; i++) {
        // An empty image makes the aspect-ratio arithmetic of the resize step
        // divide by zero, so reject it with a proper error instead of crashing.
        REQUIRE(!batch[i].empty(), "Empty image in batch");
        batchPrepared[i] = pepareImage(batch[i]);
    }

    std::vector<tf::Tensor> result =
         inferencer_.inference(TF_LAYER_IMAGE_NAME, batchPrepared,
            {
                TF_LAYER_BINARY_PORN_NAME,
                TF_LAYER_GRUESOME_NAME
            });

    // The graph returns 2 tensors:
    //    1. binary_porn scores with shape [N, 2], which contain the probability of porn / not porn on the image
    //    2. gruesome scores with shape [N, 2], which contain the probability of gruesome / not gruesome on the image
    // Column 0 of each is treated as the "forbidden" probability below.
    // NOTE(review): confirm that column order against the exported model.

    // True for a rank-2 tensor of shape [batchSize, 2]. The && chain
    // short-circuits, so dim_size() is only queried once the rank is known to be 2.
    const auto hasScoresShape = [batchSize](const tf::Tensor& scores) {
        return (2 == scores.dims()) &&
               (batchSize == static_cast<size_t>(scores.dim_size(0))) &&
               (2 == scores.dim_size(1));
    };

    REQUIRE(2 == result.size(), "Invalid output tensors number");
    // The raw buffers are reinterpreted as float below, so the dtype must match.
    REQUIRE(tf::DT_FLOAT == result[0].dtype(), "Invalid binary_porn scores tensor type");
    REQUIRE(hasScoresShape(result[0]), "Invalid binary_porn scores tensor dimension");
    REQUIRE(tf::DT_FLOAT == result[1].dtype(), "Invalid gruesome scores tensor type");
    REQUIRE(hasScoresShape(result[1]), "Invalid gruesome scores tensor dimension");

    const float *pBinPornScores   = static_cast<const float*>(tf::DMAHelper::base(&result[0]));
    const float *pGruesomeScores  = static_cast<const float*>(tf::DMAHelper::base(&result[1]));

    std::vector<double> results(batchSize);
    for (size_t i = 0; i < batchSize; i++) {
        // Row-major [N, 2] buffer: element (i, 0) is at offset 2 * i.
        results[i] = std::max(pBinPornScores[2 * i + 0], pGruesomeScores[2 * i + 0]);
    }
    return results;
}

// Scores a single decoded image by running the batch pipeline on a
// one-element batch and returning its only score.
double ForbiddenClassifier::estimateForbiddenProbability(const cv::Mat& image) const
{
    const wiki::tf_inferencer::ImagesBatch singleImageBatch(1, image);
    const std::vector<double> scores = estimateForbiddenProbability(singleImageBatch);
    return scores.front();
}

// Scores an encoded image held in common::Bytes: decodes it with
// common::decodeImage, then delegates to the decoded-image overload.
double ForbiddenClassifier::estimateForbiddenProbability(const common::Bytes& encodedImage) const
{
    const auto decoded = common::decodeImage(encodedImage);
    return estimateForbiddenProbability(decoded);
}

// Scores an encoded image held in common::Blob: decodes it with
// common::decodeImage, then delegates to the decoded-image overload.
double ForbiddenClassifier::estimateForbiddenProbability(const common::Blob& encodedImage) const
{
    const auto decoded = common::decodeImage(encodedImage);
    return estimateForbiddenProbability(decoded);
}

} // namespace maps::mrc::classifiers
