#include <maps/wikimap/mapspro/services/mrc/libs/privacy_detector/include/privacy_detector_faster_rcnn.h>

#include <maps/wikimap/mapspro/libs/tf_inferencer/tf_inferencer.h>
#include <maps/wikimap/mapspro/libs/tf_inferencer/faster_rcnn_inferencer.h>
#include <maps/libs/common/include/exception.h>

#include <opencv2/opencv.hpp>
#include <library/cpp/resource/resource.h>

#include <utility>

namespace maps {
namespace mrc {
namespace privacy_detector {

using namespace wiki::tf_inferencer;

namespace {

// Key of the embedded resource holding the TensorFlow Faster R-CNN model graph.
// Loaded both by the constructor (detector) and by evalSupportedTypes() (class names).
const std::string FRCNN_TF_MODEL_RESOURCE = "/maps/mrc/privacy_detector/models/tf_model_faster_rcnn.gdef";

} // namespace


FasterRCNNDetector::FasterRCNNDetector()
    : tfInferencerFasterRCNN_(
        new FasterRCNNInferencer(FasterRCNNInferencer::fromResource(FRCNN_TF_MODEL_RESOURCE))
    ) {
    evalSupportedTypes();
}

// Defaulted here rather than in the header; presumably so the inferencer member's
// type only needs to be complete in this translation unit — confirm against the header.
FasterRCNNDetector::~FasterRCNNDetector() = default;

void FasterRCNNDetector::evalSupportedTypes() {
    static const std::string TF_OUTPUT_CLASS_NAMES_LAYER_NAME = "class_names:0";

    TensorFlowInferencer tfInferencerClasses = TensorFlowInferencer::fromResource(FRCNN_TF_MODEL_RESOURCE);

    std::vector<TString> strTypes = tensorToVector<TString>(
            tfInferencerClasses.inference(TF_OUTPUT_CLASS_NAMES_LAYER_NAME)
        );

    supportedTypes_.resize(strTypes.size());
    for(size_t i = 0; i < strTypes.size(); i++) {
        fromString(strTypes[i].c_str(), supportedTypes_[i]);
    }
}

// Single-image convenience overload: runs the batch path on a batch of one
// and returns that batch's only entry.
PrivacyImageBoxes FasterRCNNDetector::detect(const cv::Mat& image) const
{
    std::vector<PrivacyImageBoxes> singleBatch = detect(std::vector<cv::Mat>{image});
    return std::move(singleBatch.front());
}

// Runs the Faster R-CNN model on a batch of images.
//
// images: input images in OpenCV BGR channel order; they are converted to RGB
//         before inference.
// Returns one PrivacyImageBoxes per entry returned by the inferencer, each holding
// the boxes, mapped types and confidences of that image's detections.
// Throws std::out_of_range if the model reports a class id that has no entry in
// supportedTypes_ (previously undefined behaviour).
std::vector<PrivacyImageBoxes> FasterRCNNDetector::detect(const std::vector<cv::Mat>& images) const
{
    // Passed to the inferencer as its score threshold.
    constexpr float SCORE_THRESHOLD = 0.5f;

    std::vector<cv::Mat> imagesRGB(images.size());
    for (size_t i = 0; i < images.size(); i++) {
        cv::cvtColor(images[i], imagesRGB[i], cv::COLOR_BGR2RGB);
    }

    const std::vector<FasterRCNNResults> batchResults =
        tfInferencerFasterRCNN_->inference(imagesRGB, SCORE_THRESHOLD);

    std::vector<PrivacyImageBoxes> batchDetected(batchResults.size());
    for (size_t imgIdx = 0; imgIdx < batchResults.size(); imgIdx++) {
        PrivacyImageBoxes& detected = batchDetected[imgIdx];
        for (const auto& result : batchResults[imgIdx]) {
            // Model class ids are 1-based while supportedTypes_ is 0-based (id 0 is
            // presumably background — confirm). at() rejects ids outside the table
            // instead of reading past the vector.
            detected.push_back({result.bbox, supportedTypes_.at(result.classID - 1), result.confidence});
        }
    }
    return batchDetected;
}

} // namespace privacy_detector
} // namespace mrc
} // namespace maps
