#include <library/cpp/testing/gtest/gtest.h>
#include <library/cpp/testing/common/env.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/common/include/file_utils.h>
#include <maps/wikimap/mapspro/services/mrc/libs/signdetect/include/signdetect_complex.h>
#include <maps/wikimap/mapspro/services/mrc/libs/signdetect/include/signdetect_faster_rcnn.h>
#include <maps/wikimap/mapspro/services/mrc/libs/signdetect/include/sign_relations.h>

#include <opencv2/imgcodecs/imgcodecs_c.h>
#include <opencv2/opencv.hpp>

#include <fstream>
#include <iostream>
#include <sstream>
#include <unordered_set>
#include <utility>
#include <vector>

using namespace testing;

namespace maps {
namespace mrc {
namespace signdetect {

namespace tests {

namespace {

// A traffic sign together with the file name of a test image (as accepted by
// loadImage) on which that sign is expected to be detected.
using SignStringPair = std::pair<traffic_signs::TrafficSign, std::string>;

// Reference images used by the "detect supported signs" tests. The same sign
// may be listed several times with different images (e.g. ProhibitoryMaxSpeed40
// has both a regular and a "_temp" image).
const std::vector<SignStringPair> REFERENCE_SIGN_IMAGES = {
    {traffic_signs::TrafficSign::ProhibitoryMaxSpeed10, "speed_limit_10.jpg"},
    {traffic_signs::TrafficSign::ProhibitoryMaxSpeed20, "speed_limit_20.jpg"},
    {traffic_signs::TrafficSign::ProhibitoryMaxSpeed40, "speed_limit_40.jpg"},
    {traffic_signs::TrafficSign::ProhibitoryMaxSpeed40, "speed_limit_40_temp.jpg"},
    {traffic_signs::TrafficSign::ProhibitoryMaxSpeed60, "speed_limit_60.jpg"},
    {traffic_signs::TrafficSign::InformationParking, "parking_info.jpg"},
    {traffic_signs::TrafficSign::ProhibitoryNoParkingOrStopping,
     "no_parking_or_stopping.jpg"},
    {traffic_signs::TrafficSign::ProhibitoryNoParking, "no_parking.jpg"},
    {traffic_signs::TrafficSign::ProhibitoryNoEntry, "no_entry.jpg"},
    {traffic_signs::TrafficSign::PrescriptionOneWayRoad, "one_way.jpg"},
    {traffic_signs::TrafficSign::PrescriptionEofOneWayRoad, "eof_one_way.jpg"},
    {traffic_signs::TrafficSign::PrescriptionEntryToOneWayRoadOnTheLeft,
     "one_way_to_left.jpg"},
    {traffic_signs::TrafficSign::PrescriptionLaneDirectionF, "lanes.jpg"},
};

// Loads a test image from the signdetect tests/images/ directory.
//
// @param name  file name inside tests/images/ (the directory path is resolved
//              through BinaryPath).
// @return      the decoded image, loaded with cv::IMREAD_COLOR; never empty.
// Fails via REQUIRE (presumably throws maps::Exception — confirm) when the file
// cannot be read or decoded.
cv::Mat loadImage(const std::string& name) {
    static const std::string IMAGES_DIR =
        "maps/wikimap/mapspro/services/mrc/libs/signdetect/tests/images/";
    auto imagePath = static_cast<std::string>(BinaryPath(IMAGES_DIR + name));
    // cv::IMREAD_COLOR is the C++ API equivalent of the legacy C constant
    // CV_LOAD_IMAGE_COLOR, which is not available in OpenCV 4.
    cv::Mat image = cv::imread(imagePath, cv::IMREAD_COLOR);
    // imread does not throw on failure; it returns an empty matrix instead.
    REQUIRE(!image.empty(), "Can't load image " << name);
    return image;
}

// Loads `imageName` with loadImage, runs `detector.detect` on it, and records a
// non-fatal failure unless some detection has `DetectedSign::sign == cat`.
// A macro (rather than a function) is used so that gtest reports the failure at
// the caller's line.
// - The do/while(false) wrapper makes the expansion a single statement that
//   needs a trailing semicolon and is safe inside an unbraced if/else.
// - Each argument is evaluated exactly once.
// - The internal names carry a trailing underscore so they cannot capture
//   identifiers from the calling scope.
#define CHECK_SIGN_ON_IMAGE(detector, cat, imageName)                        \
do {                                                                         \
    const auto checkedSign_ = (cat);                                         \
    const std::string checkedImageName_ = (imageName);                       \
    auto checkedImage_ = loadImage(checkedImageName_);                       \
    auto checkedFeatures_ = (detector).detect(checkedImage_);                \
    EXPECT_THAT(checkedFeatures_,                                            \
                Contains(Field(&DetectedSign::sign, checkedSign_)))          \
        << "sign " << checkedSign_ << " not detected on image "              \
        << checkedImageName_;                                                \
} while (false)

} // namespace

TEST(basic_tests, fasterrcnn_detect_supported_signs_on_reference_images)
{
    // For every reference image whose sign FasterRCNNDetector reports as
    // supported, that sign must appear among the detections. Images of
    // unsupported signs are skipped.
    FasterRCNNDetector fasterRcnn;
    const auto& supported = fasterRcnn.supportedSigns();
    const std::unordered_set<traffic_signs::TrafficSign> supportedLookup(
        supported.begin(), supported.end());

    for (const SignStringPair& reference : REFERENCE_SIGN_IMAGES) {
        const bool isSupported =
            supportedLookup.find(reference.first) != supportedLookup.end();
        if (isSupported) {
            CHECK_SIGN_ON_IMAGE(fasterRcnn, reference.first, reference.second);
        }
    }
}

TEST(basic_tests, fasterrcnn_list_supported_signs)
{
    // The Faster R-CNN detector must report more than 40 supported signs.
    FasterRCNNDetector detector;
    const auto supportedCount = detector.supportedSigns().size();
    EXPECT_THAT(supportedCount, Gt(40u));
}

TEST(basic_tests, fasterrcnn_detect_temp_sign)
{
    // At least one detection on the temporary speed-limit image must be
    // flagged with TemporarySign::Yes.
    FasterRCNNDetector detector;
    auto temporaryImage = loadImage("speed_limit_40_temp.jpg");
    const auto isTemporary =
        Field(&DetectedSign::temporarySign, traffic_signs::TemporarySign::Yes);
    EXPECT_THAT(detector.detect(temporaryImage), Contains(isTemporary));
}

TEST(basic_tests, number_recognizer)
{
    // Checks the number read from a detected sign:
    //  - on speed_limit_40.jpg the ProhibitoryMaxSpeed40 detection carries
    //    number "40" with confidence above 0.5;
    //  - on no_parking.jpg the ProhibitoryNoParking detection carries an empty
    //    number with zero confidence.
    // Only the first matching detection is inspected. Each case also requires
    // the sign to be detected at all; otherwise a detection regression would
    // make the number checks silently pass.
    FasterRCNNDetector detector;
    {
        cv::Mat image = loadImage("speed_limit_40.jpg");
        const DetectedSigns signs = detector.detect(image);
        bool found = false;
        for (const auto& detected : signs) {
            if (detected.sign == traffic_signs::TrafficSign::ProhibitoryMaxSpeed40) {
                EXPECT_THAT(detected.number, Eq("40"));
                EXPECT_THAT(detected.numberConfidence, Gt(0.5f));
                found = true;
                break;
            }
        }
        EXPECT_TRUE(found)
            << "ProhibitoryMaxSpeed40 not detected on image speed_limit_40.jpg";
    }
    {
        cv::Mat image = loadImage("no_parking.jpg");
        const DetectedSigns signs = detector.detect(image);
        bool found = false;
        for (const auto& detected : signs) {
            if (detected.sign == traffic_signs::TrafficSign::ProhibitoryNoParking) {
                EXPECT_THAT(detected.number, Eq(""));
                EXPECT_THAT(detected.numberConfidence, Eq(0.0f));
                found = true;
                break;
            }
        }
        EXPECT_TRUE(found)
            << "ProhibitoryNoParking not detected on image no_parking.jpg";
    }
}

TEST(basic_tests, complex_detect_supported_signs_on_reference_images)
{
    // For every reference image whose sign SignDetectorComplex reports as
    // supported, that sign must appear among the detections. Images of
    // unsupported signs are skipped.
    SignDetectorComplex complexDetector;
    const auto& supported = complexDetector.supportedSigns();
    const std::unordered_set<traffic_signs::TrafficSign> supportedLookup(
        supported.begin(), supported.end());

    for (const SignStringPair& reference : REFERENCE_SIGN_IMAGES) {
        if (supportedLookup.find(reference.first) != supportedLookup.end()) {
            CHECK_SIGN_ON_IMAGE(complexDetector, reference.first, reference.second);
        }
    }
}

TEST(basic_tests, complex_list_supported_signs)
{
    // SignDetectorComplex must report more than 40 supported signs. The generic
    // ProhibitoryMaxSpeed sign must not be among them, while the concrete
    // ProhibitoryMaxSpeed40 must be.
    SignDetectorComplex detector;
    const auto& signs = detector.supportedSigns();
    const std::unordered_set<traffic_signs::TrafficSign> lookup(signs.begin(), signs.end());
    const auto countOf = [&lookup](traffic_signs::TrafficSign sign) {
        return lookup.count(sign);
    };

    EXPECT_THAT(signs.size(), Gt(40u));
    EXPECT_THAT(countOf(traffic_signs::TrafficSign::ProhibitoryMaxSpeed), Eq(0u));
    EXPECT_THAT(countOf(traffic_signs::TrafficSign::ProhibitoryMaxSpeed40), Gt(0u));
}

TEST(basic_tests, signs_relations)
{
    // Two nearby detections (an InformationInZone sign and a
    // ProhibitoryNoParkingOrStopping sign) must produce exactly one relation,
    // the index pair (0, 1).
    const DetectedSigns detectedSigns{
        {{1197, 737, 18, 31}, traffic_signs::TrafficSign::InformationInZone},
        {{1190, 705, 30, 30}, traffic_signs::TrafficSign::ProhibitoryNoParkingOrStopping}
    };

    std::vector<std::pair<size_t, size_t>> relations = foundRelations(detectedSigns);
    // Fatal check: the element access below would be out of bounds if the
    // size check failed but the test kept running.
    ASSERT_THAT(relations.size(), Eq(1u));
    EXPECT_THAT(relations[0], Eq(std::make_pair(size_t{0}, size_t{1})));
}

} // namespace tests

} // namespace signdetect
} // namespace mrc
} // namespace maps
