#include <maps/wikimap/mapspro/services/mrc/libs/signdetect/include/signdetect_faster_rcnn.h>

#include <maps/wikimap/mapspro/libs/tf_inferencer/tf_inferencer.h>
#include <maps/wikimap/mapspro/libs/tf_inferencer/faster_rcnn_inferencer.h>
#include <maps/libs/common/include/exception.h>

#include <opencv2/opencv.hpp>
#include <library/cpp/resource/resource.h>

#include <utility>

namespace tf = tensorflow;

namespace maps {
namespace mrc {
namespace signdetect {

namespace {

// Resource key of the model that classifies a sign as temporary or not temporary.
const std::string TEMPORARY_SIGN_TF_MODEL_RESOURCE{"/maps/mrc/signdetect/models/tf_model.gdef"};

// Resource key of the Faster R-CNN traffic-sign model.
const std::string TRAFFIC_SIGN_FRCNN_TF_MODEL_RESOURCE{"/maps/mrc/signdetect/models/tf_model_faster_rcnn.gdef"};

// Resource key of the model that recognizes numbers on signs.
const std::string NUMBER_RECOGNIZER_TF_MODEL_RESOURCE{"/maps/mrc/signdetect/models/tf_model_number_recognizer.gdef"};

// Sentinel meaning "the index of the temporary type has not been resolved yet".
constexpr long TEMPORARY_SIGN_IDX_UNDEFINED = -1;

} // namespace


// Builds the detector.
//
// Creates the temporary-sign classifier, the Faster R-CNN sign detector and
// the number recognizer from the resource keys declared at the top of this
// file, marks the temporary type index as undefined, and then queries the
// models for the supported sign classes and the temporary type index.
FasterRCNNDetector::FasterRCNNDetector()
    // Classifier that tells temporary signs from non-temporary ones.
    : tfInferencer_(
        new wiki::tf_inferencer::TensorFlowInferencer(
            wiki::tf_inferencer::TensorFlowInferencer::fromResource(TEMPORARY_SIGN_TF_MODEL_RESOURCE)
        )
    )
    // Faster R-CNN model used by detect() to locate and classify signs.
    , tfInferencerFasterRCNN_(
        new wiki::tf_inferencer::FasterRCNNInferencer(
            wiki::tf_inferencer::FasterRCNNInferencer::fromResource(TRAFFIC_SIGN_FRCNN_TF_MODEL_RESOURCE)
        )
    )
    , numberRecognizer_(NUMBER_RECOGNIZER_TF_MODEL_RESOURCE)
    // Stays undefined until evalTemporarySignIdx() resolves it below.
    , temporarySignIdx_(TEMPORARY_SIGN_IDX_UNDEFINED)
{
    evalSupportedSigns();
    evalTemporarySignIdx();
}

// Defaulted destructor, defined out of line in this translation unit.
// NOTE(review): presumably kept here so the header can hold the inferencers
// through pointers to forward-declared types — confirm against the header.
FasterRCNNDetector::~FasterRCNNDetector() = default;

void FasterRCNNDetector::evalSupportedSigns()
{
    static const std::string TF_OUTPUT_CLASS_NAMES_LAYER_NAME = "class_names:0";
    wiki::tf_inferencer::TensorFlowInferencer tfInferencerClasses =
         wiki::tf_inferencer::TensorFlowInferencer::fromResource(TRAFFIC_SIGN_FRCNN_TF_MODEL_RESOURCE);

    auto strSigns = maps::wiki::tf_inferencer::tensorToVector<TString>(
            tfInferencerClasses.inference(TF_OUTPUT_CLASS_NAMES_LAYER_NAME)
        );

    supportedSigns_.reserve(strSigns.size());
    for(const auto& strSign : strSigns) {
        supportedSigns_.push_back(traffic_signs::stringToTrafficSign(strSign));
    }
}

void FasterRCNNDetector::evalTemporarySignIdx()
{
    static const TString TEMPORARY_TYPE = "temporary";
    static const std::string TF_OUTPUT_TEMP_NAMES_LAYER_NAME = "type_names:0";

    tensorflow::Tensor tensor =
            tfInferencer_->inference(TF_OUTPUT_TEMP_NAMES_LAYER_NAME);

    tf::TTypes<TString>::Vec names = tensor.vec<TString>();
    for (long i = 0; i < names.size(); i++)
    {
        if (names(i) == TEMPORARY_TYPE) {
            temporarySignIdx_ = i;
            break;
        }
    }
    REQUIRE(TEMPORARY_SIGN_IDX_UNDEFINED != temporarySignIdx_,
            "Unable to found temporary sign name");
}

// Returns the classifier's softmax score of the "temporary" type for a crop.
//
// image: sign crop fed to the "inference_input:0" node; detect() passes a
//        crop resized to 32x32. NOTE(review): confirm the input format the
//        model expects.
// Returns element (0, temporarySignIdx_) of the "inference_type_softmax:0"
// output. Fails the REQUIRE check if the temporary type index was never
// resolved.
float FasterRCNNDetector::calcTemporarySignConfidence(const cv::Mat &image) const
{
    static const std::string TF_INPUT_LAYER_NAME = "inference_input:0";
    static const std::string TF_OUTPUT_LAYER_NAME = "inference_type_softmax:0";

    REQUIRE(TEMPORARY_SIGN_IDX_UNDEFINED != temporarySignIdx_,
            "Unable to found temporary sign name");
    tensorflow::Tensor tensor =
            tfInferencer_->inference(TF_INPUT_LAYER_NAME, image, TF_OUTPUT_LAYER_NAME);
    // Row 0 is read because one image is evaluated per call; the column picks
    // the "temporary" type.
    return tensor.matrix<float>()(0, temporarySignIdx_);
}


// Detects traffic signs on an image.
//
// image: input picture; it is converted with COLOR_BGR2RGB before being given
//        to the Faster R-CNN model, i.e. it is treated as BGR.
// Returns one entry per detection reported by the model at score threshold
// 0.7, holding: bounding box, sign class, detection confidence, temporary
// flag with its confidence, and recognized number with its confidence.
//
// Crops for the temporary classifier and the number recognizer are taken from
// the original `image`, not from the RGB copy.
DetectedSigns FasterRCNNDetector::detect(const cv::Mat& image) const
{
    constexpr size_t DESIRED_IMAGE_SIZE = 32;
    constexpr int MIN_IMAGE_SIZE_FOR_NUMBER = 20;
    constexpr float SCORE_THRESHOLD = 0.7f;
    constexpr float TEMPORARY_THRESHOLD = 0.5f;

    cv::Mat rgbImage;
    cv::cvtColor(image, rgbImage, cv::COLOR_BGR2RGB);
    wiki::tf_inferencer::FasterRCNNResults candidates =
        tfInferencerFasterRCNN_->inference(rgbImage, SCORE_THRESHOLD);

    // Buffer for the classifier input, reused across detections.
    cv::Mat patch(DESIRED_IMAGE_SIZE, DESIRED_IMAGE_SIZE, CV_8UC3);
    DetectedSigns signs;

    for (const auto& candidate : candidates) {
        // classID - 1 is used as the index into supportedSigns_.
        traffic_signs::TrafficSign sign = supportedSigns_.at(candidate.classID - 1);

        // Signs that cannot be temporary are "not temporary" with full confidence.
        traffic_signs::TemporarySign temporary = traffic_signs::TemporarySign::No;
        float temporaryConfidence = 1.f;
        if (traffic_signs::canBeTemporary(sign)) {
            cv::resize(image(candidate.bbox), patch, cv::Size(DESIRED_IMAGE_SIZE, DESIRED_IMAGE_SIZE));
            const float temporaryScore = calcTemporarySignConfidence(patch);
            if (temporaryScore < TEMPORARY_THRESHOLD) {
                // Report the confidence of the chosen answer ("No").
                temporaryConfidence = 1.f - temporaryScore;
            } else {
                temporary = traffic_signs::TemporarySign::Yes;
                temporaryConfidence = temporaryScore;
            }
        }

        // The number stays empty with zero confidence unless it is recognized below.
        std::string number;
        float numberConfidence = 0.f;
        if (traffic_signs::needRecognizeNumber(sign) &&
            candidate.bbox.width >= MIN_IMAGE_SIZE_FOR_NUMBER &&
            candidate.bbox.height >= MIN_IMAGE_SIZE_FOR_NUMBER)
        {
            std::tie(number, numberConfidence) = numberRecognizer_.recognize(image(candidate.bbox));
        }

        signs.push_back({
            candidate.bbox,
            sign,
            candidate.confidence,
            temporary, temporaryConfidence,
            number, numberConfidence
        });
    }

    return signs;
}

const std::vector<traffic_signs::TrafficSign>&
FasterRCNNDetector::supportedSigns() const
{
    return supportedSigns_;
}

} // namespace signdetect
} // namespace mrc
} // namespace maps
