#include <maps/wikimap/mapspro/services/mrc/libs/traffic_light_detector/include/traffic_light_faster_rcnn.h>

#include <maps/wikimap/mapspro/libs/tf_inferencer/faster_rcnn_inferencer.h>

#include <opencv2/opencv.hpp>

#include <memory>

namespace maps::mrc::traffic_light_detector {

namespace {

// Resource key handed to FasterRCNNInferencer::fromResource() when a
// FasterRCNNDetector is constructed.
// A constexpr character array (rather than a namespace-scope std::string) is
// constant-initialized, so it is valid even if a detector is constructed during
// static initialization of another translation unit.
constexpr char TRAFFIC_LIGHT_FRCNN_TF_MODEL_RESOURCE[] =
    "/maps/mrc/traffic_light_detector/models/tf_model_faster_rcnn.gdef";

} // namespace

// Builds the detector by loading the Faster R-CNN inferencer from the model
// resource above and storing it in tfInferencerFasterRCNN_.
// NOTE(review): assumes tfInferencerFasterRCNN_ is a std::unique_ptr (or another
// smart pointer constructible from one) as declared in the header; the
// out-of-line defaulted destructor below is consistent with that.
FasterRCNNDetector::FasterRCNNDetector()
    : tfInferencerFasterRCNN_(
        std::make_unique<wiki::tf_inferencer::FasterRCNNInferencer>(
            wiki::tf_inferencer::FasterRCNNInferencer::fromResource(
                TRAFFIC_LIGHT_FRCNN_TF_MODEL_RESOURCE)
        )
    )
{
}

// Defined here rather than in the header, presumably so the inferencer type only
// has to be complete in this file (its header is included above) — confirm
// against the public header.
FasterRCNNDetector::~FasterRCNNDetector() = default;

// Runs the Faster R-CNN model on `image` and returns one entry per detection,
// each holding the detection's bounding box and confidence.
//
// `image` is treated as BGR (OpenCV's default channel order): it is converted to
// RGB before being passed to the inferencer. The input itself is not modified.
DetectedTrafficLights FasterRCNNDetector::detect(const cv::Mat& image) const {
    // Score threshold forwarded to the inferencer; presumably detections with
    // a lower score are dropped there — confirm in FasterRCNNInferencer.
    constexpr float SCORE_THRESHOLD = 0.7f;

    cv::Mat imageRGB;
    cv::cvtColor(image, imageRGB, cv::COLOR_BGR2RGB);

    const wiki::tf_inferencer::FasterRCNNResults results =
        tfInferencerFasterRCNN_->inference(imageRGB, SCORE_THRESHOLD);

    // Repackage inferencer results into this library's result type.
    DetectedTrafficLights detected;
    for (const auto& result : results) {
        detected.push_back({result.bbox, result.confidence});
    }

    return detected;
}

} // namespace maps::mrc::traffic_light_detector
