#include <library/cpp/testing/gtest/gtest.h>
#include <library/cpp/testing/common/env.h>
#include <maps/libs/common/include/exception.h>
#include <maps/wikimap/mapspro/services/mrc/libs/traffic_light_detector/include/traffic_light_faster_rcnn.h>

#include <opencv2/opencv.hpp>

#include <limits>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>

using namespace testing;

namespace maps::mrc::traffic_light_detector {

namespace tests {

namespace {

using TestDataPair = std::pair<std::string, DetectedTrafficLights>;

const std::vector<TestDataPair> TEST_DATA{
    {"traffic_light000.jpg", {{{802, 577, 21, 68}, 0.95}}},
    {"traffic_light001.jpg", {{{878, 459, 13, 42}, 0.95}}}
};

// Loads a reference image from the test images directory (resolved through
// BinaryPath()) as a 3-channel color image.
// Fails via REQUIRE when the file cannot be read or decoded.
cv::Mat loadImage(const std::string& name) {
    static const std::string IMAGES_DIR
        = "maps/wikimap/mapspro/services/mrc/libs/traffic_light_detector/tests/images/";
    const auto imagePath = static_cast<std::string>(BinaryPath(IMAGES_DIR + name));
    // cv::IMREAD_COLOR is the C++ API flag; the legacy CV_LOAD_IMAGE_COLOR
    // macro belongs to the old C API and is gone in OpenCV 4.
    cv::Mat image = cv::imread(imagePath, cv::IMREAD_COLOR);
    // cv::imread returns an empty Mat on failure.
    REQUIRE(!image.empty(), "Can't load image " << name << " (" << imagePath << ")");
    return image;
}

// Intersection-over-union of two axis-aligned rectangles, in [0, 1] for
// well-formed rectangles.
// Returns 0 when the union is empty (both rectangles have zero area). A plain
// division would yield NaN there, and NaN passes `iou < threshold` checks
// because every comparison with NaN is false.
double iou(const cv::Rect& lhs, const cv::Rect& rhs) {
    const double intersectionArea = (lhs & rhs).area();
    const double unionArea =
        static_cast<double>(lhs.area()) + rhs.area() - intersectionArea;
    if (unionArea <= 0.0) {
        return 0.0;
    }
    return intersectionArea / unionArea;
}

} // namespace

// Runs the detector on every reference image and checks that:
//  * the number of detections equals the number of ground-truth objects;
//  * every detection can be greedily matched one-to-one to a not yet matched
//    ground-truth object with IoU >= IOU_THRESHOLD, and its confidence is at
//    least the confidence stored for that ground-truth object.
TEST(basic_tests, traffic_light_detector_on_reference_images)
{
    constexpr double IOU_THRESHOLD = 0.5;

    FasterRCNNDetector detector;

    for (const auto& [path, gtTrafficLights] : TEST_DATA) {
        // Attach the image name to every failure reported in this iteration.
        SCOPED_TRACE(path);

        cv::Mat image = loadImage(path);
        const DetectedTrafficLights testTrafficLights = detector.detect(image);

        EXPECT_EQ(gtTrafficLights.size(), testTrafficLights.size());

        // Indices of gt objects already matched to some detection; a gt object
        // may be matched at most once.
        std::set<size_t> usedObjectIndices;
        for (const DetectedTrafficLight& testTrafficLight : testTrafficLights) {
            std::optional<size_t> gtIndex;
            double maxIoU = std::numeric_limits<double>::lowest();
            for (size_t i = 0; i < gtTrafficLights.size(); i++) {
                if (usedObjectIndices.count(i) > 0) {
                    // this gt object corresponds to another test object
                    continue;
                }
                const double iouValue = iou(
                    testTrafficLight.box,
                    gtTrafficLights[i].box
                );
                // The gt confidence acts as the minimal acceptable score.
                if (testTrafficLight.confidence < gtTrafficLights[i].confidence ||
                    iouValue < IOU_THRESHOLD)
                {
                    continue;
                }
                // Keep the candidate with the highest IoU.
                if (!gtIndex.has_value() || iouValue > maxIoU) {
                    gtIndex = i;
                    maxIoU = iouValue;
                }
            }
            EXPECT_TRUE(gtIndex.has_value())
                << "Detection {" << testTrafficLight.box.x << ", "
                << testTrafficLight.box.y << ", "
                << testTrafficLight.box.width << ", "
                << testTrafficLight.box.height << "} with confidence "
                << testTrafficLight.confidence
                << " does not match any unused ground-truth object";
            if (gtIndex.has_value()) {
                usedObjectIndices.insert(*gtIndex);
            }
        }
    }
}

} // namespace tests

} // namespace maps::mrc::traffic_light_detector
