#include <maps/wikimap/mapspro/services/mrc/eye/lib/detect_traffic_light/tests/common.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detect_traffic_light/tests/fixture.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detect_traffic_light/include/worker.h>

#include <maps/wikimap/mapspro/services/mrc/eye/lib/detect_object/include/test_util.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/unit_test/include/frame.h>

#include <maps/wikimap/mapspro/services/mrc/libs/traffic_light_detector/include/traffic_light_faster_rcnn.h>

#include <library/cpp/testing/unittest/registar.h>
#include <mapreduce/yt/tests/yt_unittest_lib/yt_unittest_lib.h>

#include <deque>

namespace maps::mrc::eye::tests {

std::deque<traffic_light_detector::DetectedTrafficLights> detectionList()
{
    traffic_light_detector::DetectedTrafficLights first {
        {cv::Rect{100, 100, 100,100}, 1.0},
    };

    traffic_light_detector::DetectedTrafficLights second {};

    traffic_light_detector::DetectedTrafficLights third {
        {cv::Rect{100, 100, 50, 50}, 1.0},
        {cv::Rect{400, 100, 100, 100}, 1.0},
        {cv::Rect{600, 100, 100, 100}, 1.0},
    };

    return {first, second, third};
}

std::deque<cv::Mat> maskList()
{
    cv::Mat first = cv::Mat::zeros(1920, 1080, CV_64F);

    cv::Mat third = cv::Mat::zeros(1080, 1920, CV_64F);
    third(cv::Rect{300, 0, 300, 300}) = cv::Mat::ones(300, 300, CV_64F);

    return {first, /* second skiped */ third};
}

// Stub detector for the test: a traffic light is used as the simple object
// type, and the detections are the canned ones from detectionList().
struct Detector: public TestDetector<traffic_light_detector::DetectedTrafficLight> {
    Detector()
        : TestDetector(detectionList())
    {}
};

// Stub segmentator for the test: returns the canned masks from maskList().
struct Segmentator: public TestSegmentator {
    Segmentator()
        : TestSegmentator(maskList())
    {}
};

// Worker under test: the generic detect-then-filter worker wired with the
// stub Detector and Segmentator above and MakeTrafficLightRecognition
// (presumably the factory turning detections into traffic-light recognitions).
using Worker = DetectObjectWithFilteringWorker<Detector, Segmentator, MakeTrafficLightRecognition>;

// Make the job class known to the YT runtime so operations started by the test can run it.
REGISTER_MAPPER(Worker);


Y_UNIT_TEST_SUITE_F(worker, Fixture)
{

// Runs the filtering worker over frames[1..3] with the stub Detector and
// Segmentator and checks the stored recognitions: one per input frame, with
// the detection covered by the mask dropped.
Y_UNIT_TEST(detect_object_with_filtering)
{
    const auto loader = FrameLoader::fromConfig(config());
    const auto yt = NYT::NTesting::CreateTestClient();

    // Named detectConfig so it does not shadow the fixture's config() method.
    // All frames go to one worker with concurrency 1 to keep the order of
    // frames: the checks below match recognitions to frames by position.
    SimpleDetectObjectConfig detectConfig;
    detectConfig.yt.client = yt.Get();
    detectConfig.yt.frameLoader = &loader;
    detectConfig.yt.rootPath = "//tmp/worker/run";
    detectConfig.yt.partitionSize = 10; // all frames on one worker
    detectConfig.yt.concurrency = 1;
    detectConfig.yt.useGpu = false;

    const auto recognitions = simpleDetectObject<Worker>(
        detectConfig, "test", {frames[1], frames[2], frames[3]}
    );
    UNIT_ASSERT_EQUAL(recognitions.size(), 3u);

    const auto list = detectionList();

    // Fields that must be the same for every recognition produced here.
    const auto checkCommonFields = [](const auto& recognition, const auto& frame) {
        UNIT_ASSERT_EQUAL(recognition.orientation(), frame.orientation());
        UNIT_ASSERT_EQUAL(recognition.type(), db::eye::RecognitionType::DetectTrafficLight);
        UNIT_ASSERT_EQUAL(recognition.source(), db::eye::RecognitionSource::Model);
    };

    // Converts a canned detection box with common::revertByImageOrientation,
    // using the frame's original size and orientation, so it can be compared
    // against the stored traffic lights.
    const auto toStoredBox = [](const auto& detection, const auto& frame) -> common::ImageBox {
        return common::revertByImageOrientation(
            detection.box,
            frame.originalSize(),
            frame.orientation()
        );
    };

    {
        // frames[1]: one detection, all-zero mask -> the detection is kept.
        const auto& frame = frames.at(1);
        const auto& recognition = recognitions.at(0);
        const auto& detections = list.at(0);

        checkCommonFields(recognition, frame);

        const auto trafficLights = recognition.value<db::eye::DetectedTrafficLights>();
        UNIT_ASSERT_EQUAL(trafficLights.size(), 1u);
        UNIT_ASSERT(hasTrafficLight(trafficLights, toStoredBox(detections.at(0), frame)));
    }

    {
        // frames[2]: no detections -> an empty recognition is still stored.
        const auto& frame = frames.at(2);
        const auto& recognition = recognitions.at(1);

        checkCommonFields(recognition, frame);

        const auto trafficLights = recognition.value<db::eye::DetectedTrafficLights>();
        UNIT_ASSERT_EQUAL(trafficLights.size(), 0u);
    }

    {
        // frames[3]: three detections. detections[1] (x in [400, 500)) lies
        // inside the masked area Rect{300, 0, 300, 300} from maskList() and
        // must be filtered out; the other two lie outside and are kept.
        const auto& frame = frames.at(3);
        const auto& recognition = recognitions.at(2);
        const auto& detections = list.at(2);

        checkCommonFields(recognition, frame);

        const auto trafficLights = recognition.value<db::eye::DetectedTrafficLights>();
        UNIT_ASSERT_EQUAL(trafficLights.size(), 2u);
        UNIT_ASSERT(hasTrafficLight(trafficLights, toStoredBox(detections.at(0), frame)));
        UNIT_ASSERT(!hasTrafficLight(trafficLights, toStoredBox(detections.at(1), frame)));
        UNIT_ASSERT(hasTrafficLight(trafficLights, toStoredBox(detections.at(2), frame)));
    }
}

} // Y_UNIT_TEST_SUITE

} // namespace maps::mrc::eye::tests
