#include <maps/wikimap/mapspro/services/mrc/libs/roadmark_detector/include/roadmarkdetector.h>

#include <maps/wikimap/mapspro/libs/tf_inferencer/faster_rcnn_inferencer.h>

#include <opencv2/opencv.hpp>

#include <vector>
#include <map>
#include <list>

namespace maps {
namespace mrc {
namespace roadmarkdetector {

using namespace wiki::tf_inferencer;
using namespace maps::mrc::traffic_signs;

namespace {

// Resource key of the TensorFlow graph. It is passed to fromResource() both when
// building the Faster R-CNN detector and when reading the class-name table.
const std::string TF_MODEL_RESOURCE{"/maps/mrc/roadmark_detector/models/tf_model.gdef"};

// Maps the class-name strings stored in the model to TrafficSign values
// (Russian lane-direction road markings 1.18.1 - 1.18.6).
const std::map<std::string, TrafficSign> ROAD_MARKS_STRINGS{
    {"1.18.1_Russian_road_marking", TrafficSign::RoadMarkingLaneDirectionF},
    {"1.18.2_Russian_road_marking", TrafficSign::RoadMarkingLaneDirectionR},
    {"1.18.3_Russian_road_marking", TrafficSign::RoadMarkingLaneDirectionL},
    {"1.18.4_Russian_road_marking", TrafficSign::RoadMarkingLaneDirectionFR},
    {"1.18.5_Russian_road_marking", TrafficSign::RoadMarkingLaneDirectionFL},
    {"1.18.6_Russian_road_marking", TrafficSign::RoadMarkingLaneDirectionRL},
};

// Translates a class-name string produced by the model into a TrafficSign.
// Returns TrafficSign::Unknown for any name not listed in ROAD_MARKS_STRINGS.
TrafficSign fromString(const std::string &str) {
    const auto match = ROAD_MARKS_STRINGS.find(str);
    return match == ROAD_MARKS_STRINGS.end() ? TrafficSign::Unknown : match->second;
}

} // namespace

// Loads the Faster R-CNN model from TF_MODEL_RESOURCE, then builds the
// class-index -> TrafficSign table used by detect().
RoadMarkDetector::RoadMarkDetector()
    : tfInferencerFasterRCNN_(new FasterRCNNInferencer(
          FasterRCNNInferencer::fromResource(TF_MODEL_RESOURCE)))
{
    // Must run after the detector is constructed: detect() relies on
    // supportedTypes_ being filled in.
    evalSupportedTypes();
}

// Defined out of line, where FasterRCNNInferencer is a complete type.
// NOTE(review): this assumes tfInferencerFasterRCNN_ is an owning smart pointer
// (declared in the header, not visible here); if it is a raw pointer, the object
// allocated in the constructor is never deleted — confirm.
RoadMarkDetector::~RoadMarkDetector() = default;

void RoadMarkDetector::evalSupportedTypes() {
    static const std::string TF_OUTPUT_CLASS_NAMES_LAYER_NAME = "class_names:0";

    TensorFlowInferencer tfInferencerClasses = TensorFlowInferencer::fromResource(TF_MODEL_RESOURCE);

    std::vector<TString> strTypes = tensorToVector<TString>(
            tfInferencerClasses.inference(TF_OUTPUT_CLASS_NAMES_LAYER_NAME)
        );

    supportedTypes_.resize(strTypes.size());
    for(size_t i = 0; i < strTypes.size(); i++) {
        supportedTypes_[i] = fromString(strTypes[i].c_str());
    }
}

// Runs the Faster R-CNN model on a BGR image and returns every detection whose
// confidence passes SCORE_THRESHOLD, with its bounding box, road-mark type and
// confidence.
//
// A detection whose class ID has no entry in supportedTypes_ is reported as
// TrafficSign::Unknown. Such an ID would otherwise be an out-of-bounds index.
RoadMarkDetectionVector RoadMarkDetector::detect(const cv::Mat &image) const {
    // Minimum confidence passed to the inferencer.
    const float SCORE_THRESHOLD = 0.7f;

    // The input is treated as BGR (OpenCV order) and converted to RGB before
    // inference. NOTE(review): assumes a 3-channel image; cvtColor rejects
    // other layouts.
    cv::Mat imageRGB;
    cv::cvtColor(image, imageRGB, cv::COLOR_BGR2RGB);

    wiki::tf_inferencer::FasterRCNNResults results =
        tfInferencerFasterRCNN_->inference(imageRGB, SCORE_THRESHOLD);

    RoadMarkDetectionVector detected;
    for (const auto &result : results) {
        // classID is 1-based (0 is presumably the background class), while
        // supportedTypes_ is 0-based. Check the range instead of indexing blindly:
        // std::vector::operator[] does no bounds checking.
        const auto classID = result.classID;
        const bool known =
            classID >= 1 && static_cast<size_t>(classID) <= supportedTypes_.size();
        const TrafficSign type =
            known ? supportedTypes_[classID - 1] : TrafficSign::Unknown;
        detected.push_back({result.bbox, type, result.confidence});
    }
    return detected;
}

} // namespace roadmarkdetector
} // namespace mrc
} // namespace maps
