#include <maps/wikimap/mapspro/services/mrc/libs/carsegm/include/carsegm.h>

#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/framework/tensor.h>
#include <tensorflow/core/common_runtime/dma_helper.h>

#include <opencv2/opencv.hpp>

#include <vector>

namespace tf = tensorflow;
namespace tfi = maps::wiki::tf_inferencer;

namespace maps {
namespace mrc {
namespace carsegm {

namespace {

// Resource path of the serialized TensorFlow graph. The constructor loads the
// Mask R-CNN inferencer from it via MaskRCNNInferencer::fromResource.
const std::string TF_MODEL_RESOURCE{"/maps/mrc/carsegm/models/tf_model.gdef"};

} // namespace

CarSegmentator::CarSegmentator()
    : maskRCNNinferencer_(tfi::MaskRCNNInferencer::fromResource(TF_MODEL_RESOURCE))
{}

/// Segments a single image.
///
/// Convenience wrapper: wraps `image` into a one-element batch, forwards it to
/// the batch overload of segment(), and returns the result produced for it.
///
/// @param image input image; the batch overload converts it from BGR to RGB
///              before inference.
/// @return the first (and presumably only) element of the batch result.
/// @throws std::out_of_range if the batch overload returns an empty vector.
cv::Mat CarSegmentator::segment(const cv::Mat &image) const
{
    // Copying a cv::Mat only shares its pixel buffer, so this batch is cheap.
    const std::vector<cv::Mat> imagesBatch{image};

    // at() instead of operator[]: an empty result must fail loudly rather
    // than read past the end of the returned vector.
    return segment(imagesBatch).at(0);
}

/// Segments a batch of images with the Mask R-CNN inferencer.
///
/// Each input image is converted with cv::COLOR_BGR2RGB, so inputs are
/// treated as BGR (OpenCV's default channel order). The converted batch is
/// then passed to the inferencer together with fixed score and mask thresholds.
///
/// @param imagesBatch input images, assumed BGR-ordered 3-channel
///                    (TODO confirm callers never pass grayscale/BGRA).
/// @return whatever MaskRCNNInferencer::segment returns for the batch
///         (presumably one mask per input image; verify against the inferencer).
std::vector<cv::Mat> CarSegmentator::segment(const std::vector<cv::Mat>& imagesBatch) const
{
    // Cut-offs forwarded to the inferencer; presumably the minimum detection
    // score and the per-pixel mask binarization level.
    constexpr float SCORE_THRESHOLD = 0.5f;
    constexpr float MASK_THRESHOLD = 0.5f;

    std::vector<cv::Mat> imagesRgbBatch;
    imagesRgbBatch.reserve(imagesBatch.size());
    for (const cv::Mat& bgrImage : imagesBatch) {
        cv::Mat rgbImage;
        cv::cvtColor(bgrImage, rgbImage, cv::COLOR_BGR2RGB);
        imagesRgbBatch.push_back(rgbImage);
    }

    return maskRCNNinferencer_.segment(imagesRgbBatch, SCORE_THRESHOLD, MASK_THRESHOLD);
}

} // carsegm
} // mrc
} // maps
