#include <maps/wikimap/mapspro/services/mrc/libs/house_number_sign_detector/include/house_number_sign_detector.h>

#include <maps/wikimap/mapspro/libs/tf_inferencer/faster_rcnn_inferencer.h>
#include <maps/libs/common/include/exception.h>

#include <opencv2/opencv.hpp>
#include <library/cpp/resource/resource.h>

#include <utility>

namespace maps {
namespace mrc {
namespace house_number_sign_detector {

using namespace wiki::tf_inferencer;

namespace {

// Resource keys of the models used by this translation unit.
// NOTE(review): the ".gdef" suffix suggests serialized TensorFlow graph
// definitions embedded into the binary - confirm against the build files.

// Key of the Faster R-CNN detection model. It is passed to
// FasterRCNNInferencer::fromResource() and TensorFlowInferencer::fromResource().
const std::string TF_MODEL_RESOURCE = "/maps/mrc/house_number_sign_detector/models/tf_model_faster_rcnn.gdef";
// Key of the number recognition model. It is passed to the
// number_recognizer::NumberRecognizer constructor.
const std::string TF_MODEL_RECOGNIZER_RESOURCE = "/maps/mrc/house_number_sign_detector/models/tf_model_recognizer.gdef";

} // namespace

// Creates a number recognizer constructed from the TF_MODEL_RECOGNIZER_RESOURCE
// key, i.e. from the same recognizer model resource that FasterRCNNDetector
// uses for its own numberRecognizer_ member.
number_recognizer::NumberRecognizer makeHouseNumberRecognizer() {
    return number_recognizer::NumberRecognizer(TF_MODEL_RECOGNIZER_RESOURCE);
}


// Builds a detector:
//  - creates the Faster R-CNN inferencer from TF_MODEL_RESOURCE;
//  - creates the number recognizer from TF_MODEL_RECOGNIZER_RESOURCE;
//  - resolves the class ID of house number signs via evalHouseNumberSignID().
FasterRCNNDetector::FasterRCNNDetector()
    // -1 is the initial value; it stays in place if evalHouseNumberSignID()
    // does not find the house number sign class among the model's class names.
    : houseNumberSignClassID_(-1)
    , tfInferencerFasterRCNN_(
        new FasterRCNNInferencer(FasterRCNNInferencer::fromResource(TF_MODEL_RESOURCE))
    )
    , numberRecognizer_(TF_MODEL_RECOGNIZER_RESOURCE)
{
    evalHouseNumberSignID();
}

// Defaulted destructor, defined out of line in this translation unit.
// NOTE(review): presumably required because the header only forward-declares
// the inferencer type held by tfInferencerFasterRCNN_ - confirm in the header.
FasterRCNNDetector::~FasterRCNNDetector() = default;

void FasterRCNNDetector::evalHouseNumberSignID() {
    static const std::string TF_OUTPUT_CLASS_NAMES_LAYER_NAME = "class_names:0";
    static const TString     HOUSE_NUMBER_SIGN_CLASS_NAME     = "house_number_sign";

    TensorFlowInferencer tfInferencerClasses = TensorFlowInferencer::fromResource(TF_MODEL_RESOURCE);

    std::vector<TString> strTypes = tensorToVector<TString>(
            tfInferencerClasses.inference(TF_OUTPUT_CLASS_NAMES_LAYER_NAME)
        );

    for(size_t i = 0; i < strTypes.size(); i++) {
        if (HOUSE_NUMBER_SIGN_CLASS_NAME == strTypes[i])
            houseNumberSignClassID_ = i + 1;
    }
}

// Detects house number signs on a single image.
//
// Convenience wrapper over the batch overload: the image is wrapped into a
// one-element batch and the result for that only element is returned.
//
// image           - image to process; the batch overload converts it with
//                   cv::COLOR_BGR2RGB, i.e. it is treated as BGR.
// recognizeNumber - whether the number on every detected sign has to be
//                   recognized as well.
//
// Throws std::out_of_range if the batch overload returns no results at all.
HouseNumberSigns FasterRCNNDetector::detect(const cv::Mat& image, RecognizeNumber recognizeNumber) const
{
    const std::vector<cv::Mat> images(1, image);
    std::vector<HouseNumberSigns> batchDetected = detect(images, recognizeNumber);
    // at() instead of operator[]: an unexpectedly empty batch result becomes an
    // exception rather than undefined behaviour. The element is moved out of
    // the local vector to avoid copying all the detected signs.
    return std::move(batchDetected.at(0));
}

// Detects house number signs on a batch of images.
//
// Every image is converted with cv::COLOR_BGR2RGB and the whole batch is passed
// to the Faster R-CNN inferencer with a score threshold of 0.5. Only results
// whose class ID equals houseNumberSignClassID_ are turned into signs.
//
// If recognizeNumber == RecognizeNumber::Yes, the number recognizer is run on
// the crop of the original (unconverted) image:
//  - the crop is the detected box clipped to the image bounds, so that a box
//    sticking out of the image does not make cv::Mat::operator() throw;
//  - a sign whose clipped crop is empty is dropped;
//  - a sign whose recognition confidence is not strictly greater than 0.85 is
//    dropped.
// Otherwise the number is left as default-constructed and the recognizer
// confidence is 1.
//
// The reported box is always the box returned by the inferencer (not clipped),
// and the total confidence is the product of the detector and the recognizer
// confidences.
//
// Returns one HouseNumberSigns per inferencer result, in the inferencer's order.
// Throws std::out_of_range if the inferencer returns more results than there
// are input images.
std::vector<HouseNumberSigns> FasterRCNNDetector::detect(const std::vector<cv::Mat>& images, RecognizeNumber recognizeNumber) const
{
    constexpr float SCORE_THRESHOLD = 0.5f;
    constexpr float RECOGNIZE_THRESHOLD = 0.85f;

    std::vector<cv::Mat> imagesRGB(images.size());
    for (size_t i = 0; i < images.size(); i++) {
        cv::cvtColor(images[i], imagesRGB[i], cv::COLOR_BGR2RGB);
    }

    const std::vector<FasterRCNNResults> batchResults = tfInferencerFasterRCNN_->inference(imagesRGB, SCORE_THRESHOLD);

    std::vector<HouseNumberSigns> batchDetected(batchResults.size());
    for (size_t imgIdx = 0; imgIdx < batchResults.size(); imgIdx++) {
        // at(): the number of results is defined by the inferencer, so do not
        // index the input batch with it blindly.
        const cv::Mat& image = images.at(imgIdx);
        const cv::Rect imageBounds(0, 0, image.cols, image.rows);

        HouseNumberSigns& detected = batchDetected[imgIdx];
        for (const auto& result : batchResults[imgIdx]) {
            if (houseNumberSignClassID_ != result.classID)
                continue;

            HouseNumberSign sign;
            sign.box = result.bbox;
            sign.confidenceDetector = result.confidence;
            sign.confidenceRecognizer = 1.;
            if (recognizeNumber == RecognizeNumber::Yes) {
                // Intersection with the image rectangle; it is empty (zero
                // width and height) when the box lies outside of the image.
                const cv::Rect roi = cv::Rect(result.bbox) & imageBounds;
                if (roi.width <= 0 || roi.height <= 0)
                    continue;

                std::pair<std::string, float> number = numberRecognizer_.recognize(image(roi));
                if (RECOGNIZE_THRESHOLD >= number.second)
                    continue;
                sign.number = std::move(number.first);
                sign.confidenceRecognizer = number.second;
            }
            sign.confidence = sign.confidenceDetector * sign.confidenceRecognizer;
            detected.emplace_back(std::move(sign));
        }
    }
    return batchDetected;
}



} // namespace house_number_sign_detector
} // namespace mrc
} // namespace maps
