#include <maps/wikimap/mapspro/services/mrc/libs/signdetect/include/signdetect_complex.h>

#include <maps/wikimap/mapspro/libs/tf_inferencer/tf_inferencer.h>
#include <maps/wikimap/mapspro/libs/tf_inferencer/faster_rcnn_inferencer.h>
#include <maps/libs/common/include/exception.h>

#include <opencv2/opencv.hpp>
#include <library/cpp/resource/resource.h>

#include <exception>
#include <optional>
#include <set>
#include <string>
#include <utility>

namespace maps::mrc::signdetect {

namespace {
// Speed-limit sign types that carry a concrete value, listed in ascending
// order of that value. The generic ProhibitoryMaxSpeed (no value) is
// intentionally absent: it is handled separately by isSpeedLimit().
// (No `static` needed: the anonymous namespace already gives internal linkage.)
const std::set<traffic_signs::TrafficSign> SPEED_LIMITS_WITH_VALUE = {
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed5,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed10,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed15,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed20,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed25,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed30,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed35,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed40,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed45,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed50,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed55,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed60,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed70,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed80,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed90,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed100,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed110,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed120,
    traffic_signs::TrafficSign::ProhibitoryMaxSpeed130
};

// True for the generic ProhibitoryMaxSpeed sign as well as for every
// value-specific speed-limit type listed in SPEED_LIMITS_WITH_VALUE.
bool isSpeedLimit(traffic_signs::TrafficSign ts) {
    if (ts == traffic_signs::TrafficSign::ProhibitoryMaxSpeed) {
        return true;
    }
    return SPEED_LIMITS_WITH_VALUE.find(ts) != SPEED_LIMITS_WITH_VALUE.end();
}

void changeSpeedLimitByNumber(DetectedSigns& signs) {
    for (int i = 0; i < (int)signs.size(); i++) {
        DetectedSign& sign = signs[i];
        if (!isSpeedLimit(sign.sign)) {
            continue;
        }
        if (sign.sign != traffic_signs::TrafficSign::ProhibitoryMaxSpeed &&
            sign.confidence > sign.numberConfidence) {
            continue;
        }
        int number = 0;
        try {
            number = std::stoi(sign.number);
        } catch (...) {
        }
        switch (number) {
        case 5:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed5;
            break;
        case 10:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed10;
            break;
        case 15:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed15;
            break;
        case 20:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed20;
            break;
        case 25:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed25;
            break;
        case 30:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed30;
            break;
        case 35:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed35;
            break;
        case 40:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed40;
            break;
        case 45:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed45;
            break;
        case 50:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed50;
            break;
        case 55:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed55;
            break;
        case 60:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed60;
            break;
        case 70:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed70;
            break;
        case 80:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed80;
            break;
        case 90:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed90;
            break;
        case 100:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed100;
            break;
        case 110:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed110;
            break;
        case 120:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed120;
            break;
        case 130:
            sign.sign = traffic_signs::TrafficSign::ProhibitoryMaxSpeed130;
            break;
        default:
            signs.erase(signs.begin() + i);
            i--;
            break;
        }
    }
}

} // namespace

SignDetectorComplex::SignDetectorComplex() {
    // Build the supported-sign list once so supportedSigns() can simply
    // return the cached member by reference.
    evalSupportedSigns();
}

// Defaulted out of line rather than in the header; presumably so that member
// types which are only forward-declared in the header can be destroyed here —
// TODO confirm against signdetect_complex.h.
SignDetectorComplex::~SignDetectorComplex() = default;


void SignDetectorComplex::evalSupportedSigns() {
    const std::vector<traffic_signs::TrafficSign>& internalDetectorSigns = fasterRCNNDetector_.supportedSigns();

    std::set<traffic_signs::TrafficSign> signsSet = {internalDetectorSigns.begin(), internalDetectorSigns.end()};
    signsSet.erase(traffic_signs::TrafficSign::ProhibitoryMaxSpeed);
    signsSet.insert(SPEED_LIMITS_WITH_VALUE.begin(), SPEED_LIMITS_WITH_VALUE.end());
    supportedSigns_ = {signsSet.begin(), signsSet.end()};
}

// Runs the wrapped fasterRCNNDetector_ on the image, then post-processes the
// speed-limit detections using their recognized numbers.
DetectedSigns SignDetectorComplex::detect(const cv::Mat& image) const {
    DetectedSigns result = fasterRCNNDetector_.detect(image);
    changeSpeedLimitByNumber(result);
    return result;
}

// Returns the list cached by evalSupportedSigns() — a reference to the member,
// not a copy.
const std::vector<traffic_signs::TrafficSign>& SignDetectorComplex::supportedSigns() const {
    return supportedSigns_;
}

} // namespace maps::mrc::signdetect
