#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/tests/fixture.h>

#include <maps/wikimap/mapspro/services/mrc/eye/lib/common/include/id.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/detection/include/position_matcher.h>
#include <maps/wikimap/mapspro/services/mrc/eye/lib/location/include/rotation.h>

#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/frame_gateway.h>
#include <maps/wikimap/mapspro/services/mrc/libs/db/include/eye/recognition_gateway.h>

#include <maps/wikimap/mapspro/services/mrc/eye/lib/unit_test/include/frame.h>

#include <library/cpp/testing/unittest/registar.h>
#include <library/cpp/testing/gmock_in_unittest/gmock.h>

namespace maps::mrc::eye::tests {

Y_UNIT_TEST_SUITE_F(position_matcher, Fixture)
{

// Common scene shared by the per-detection-type matching tests in this suite:
// one device, four frames with locations, and a public privacy record per
// frame. Subclasses fill `groups` and `detections` for one detection type.
struct BaseMatchingFixture: public Fixture {
    BaseMatchingFixture() {
        devices = insertx<db::eye::DeviceGateway>(
            db::eye::Devices {
                {db::eye::MrcDeviceAttrs{""}}
            }
        );

        // Four frames from the same device: frames[0] and frames[1] are
        // 1200x800, frames[2] and frames[3] are 3600x2400.
        frames = insertx<db::eye::FrameGateway>(
            db::eye::Frames {
                {devices[0].id(), identical, makeUrlContext(1, "1"), {1200, 800}, time()},
                {devices[0].id(), identical, makeUrlContext(2, "2"), {1200, 800}, time()},
                {devices[0].id(), identical, makeUrlContext(3, "3"), {3600, 2400}, time()},
                {devices[0].id(), identical, makeUrlContext(4, "4"), {3600, 2400}, time()},
            }
        );

        // All frames share heading 90 and differ only in position. Relative
        // to frames[0] at (0, 0): frames[1] is 20 units away, frames[3] is
        // 10 units away, frames[2] is 100 units away.
        // NOTE(review): units are whatever geolib3::Point2 uses here — confirm.
        locations = insertx<db::eye::FrameLocationGateway>(
            db::eye::FrameLocations {
                {frames[0].id(), geolib3::Point2{0, 0}, toRotation(geolib3::Heading(90), identical)},
                {frames[1].id(), geolib3::Point2{20, 0}, toRotation(geolib3::Heading(90), identical)},
                {frames[2].id(), geolib3::Point2{0, 100}, toRotation(geolib3::Heading(90), identical)},
                {frames[3].id(), geolib3::Point2{0, 10}, toRotation(geolib3::Heading(90), identical)},
            }
        );

        // Mark every frame Public. The local vector is filled first, moved
        // into insertx, and then reassigned from insertx's return value.
        for (const auto& frame : frames) {
            privacies.emplace_back(frame.id(), db::FeaturePrivacy::Public);
        }
        privacies = insertx<db::eye::FramePrivacyGateway>(std::move(privacies));
    }

    db::eye::Devices devices;
    db::eye::Frames frames;
    db::eye::FrameLocations locations;
    db::eye::FramePrivacies privacies;
    // Not populated by this constructor; subclasses assign them.
    db::eye::DetectionGroups groups;
    db::eye::Detections detections;
};

// House-number scene on top of BaseMatchingFixture: one HouseNumber detection
// group per frame (groups[i] belongs to frames[i]).
struct HouseNumberMatcingFixture: public BaseMatchingFixture {
    HouseNumberMatcingFixture() {
        groups = insertx<db::eye::DetectionGroupGateway>(
            db::eye::DetectionGroups {
                {frames[0].id(), db::eye::DetectionType::HouseNumber},
                {frames[1].id(), db::eye::DetectionType::HouseNumber},
                {frames[2].id(), db::eye::DetectionType::HouseNumber},
                {frames[3].id(), db::eye::DetectionType::HouseNumber},
            }
        );

        // detections[0] and detections[1]: two boxes on frames[0] with
        // different x offsets.
        // detections[2] and detections[3]: the same box as detections[0], on
        // frames[1] and frames[2].
        // All four read "12". detections[4], on frames[3], reads "11".
        detections = insertx<db::eye::DetectionGateway>(
            db::eye::Detections {
                {groups[0].id(), db::eye::DetectedHouseNumber{{100, 100, 200, 200}, 1.0, "12"}},
                {groups[0].id(), db::eye::DetectedHouseNumber{{400, 100, 500, 200}, 1.0, "12"}},
                {groups[1].id(), db::eye::DetectedHouseNumber{{100, 100, 200, 200}, 1.0, "12"}},
                {groups[2].id(), db::eye::DetectedHouseNumber{{100, 100, 200, 200}, 1.0, "12"}},
                {groups[3].id(), db::eye::DetectedHouseNumber{{100, 100, 200, 200}, 1.0, "11"}},
            }
        );
    }
};

// Runs PositionDetectionMatcher over all pairs of detections[0..3] and expects
// exactly two matches, in this order: detections[0] <-> detections[2] and
// detections[1] <-> detections[2] (frames[0] vs frames[1]), each with
// relevance 1.
Y_UNIT_TEST_F(house_number, HouseNumberMatcingFixture)
{
    const DetectionStore store {
        byId(groups),
        byId(detections),
        byId(frames),
        byId(locations),
        byId(privacies),
        byId(devices)
    };
    const DetectionIdPairSet detectionPairs{
        {detections[0].id(), detections[1].id()},
        {detections[0].id(), detections[2].id()},
        {detections[0].id(), detections[3].id()},
        {detections[1].id(), detections[2].id()},
        {detections[1].id(), detections[3].id()},
        {detections[2].id(), detections[3].id()},
    };

    PositionDetectionMatcher matcher;
    MatchedFrameDetections result = matcher.makeMatches(store, detectionPairs);

    const MatchedFrameDetections expected{
        {
            {frames[0].id(), detections[0].id()},
            {frames[1].id(), detections[2].id()},
            1.f
        },
        {
            {frames[0].id(), detections[1].id()},
            {frames[1].id(), detections[2].id()},
            1.f
        },
    };

    // ASSERT rather than EXPECT: the loop below indexes `expected` with
    // positions taken from `result`, so a size mismatch must abort the test
    // before it can read past the end of `expected`.
    ASSERT_EQ(result.size(), expected.size());

    for (size_t i = 0; i < result.size(); i++) {
        EXPECT_EQ(result[i].id0().frameId, expected[i].id0().frameId);
        EXPECT_EQ(result[i].id0().detectionId, expected[i].id0().detectionId);
        EXPECT_EQ(result[i].id1().frameId, expected[i].id1().frameId);
        EXPECT_EQ(result[i].id1().detectionId, expected[i].id1().detectionId);
        EXPECT_EQ(result[i].relevance(), 1.f);
    }
}

// Traffic-sign scene on top of BaseMatchingFixture: one Sign detection group
// per frame (groups[i] belongs to frames[i]).
struct SignMatcingFixture: public BaseMatchingFixture {
    SignMatcingFixture() {
        groups = insertx<db::eye::DetectionGroupGateway>(
            db::eye::DetectionGroups {
                {frames[0].id(), db::eye::DetectionType::Sign},
                {frames[1].id(), db::eye::DetectionType::Sign},
                {frames[2].id(), db::eye::DetectionType::Sign},
                {frames[3].id(), db::eye::DetectionType::Sign},
            }
        );

        using Type = traffic_signs::TrafficSign;

        // Box layout mirrors the house-number scene:
        // - detections[0] and detections[1] are on frames[0] at different x offsets;
        // - detections[2] and detections[3] are on frames[1] and frames[2].
        // All four are MandatoryTurnLeft. detections[4], on frames[3], is
        // MandatoryTurnRight. The trailing scalar fields are identical for
        // every detection.
        detections = insertx<db::eye::DetectionGateway>(
            db::eye::Detections {
                {
                    groups[0].id(), db::eye::DetectedSign {
                        {100, 100, 200, 200},
                        Type::MandatoryTurnLeft, 1.0,
                        false, 1.0
                    }
                },
                {
                    groups[0].id(), db::eye::DetectedSign {
                        {400, 100, 500, 200},
                        Type::MandatoryTurnLeft, 1.0,
                        false, 1.0
                    }
                },
                {
                    groups[1].id(), db::eye::DetectedSign {
                        {100, 100, 200, 200},
                        Type::MandatoryTurnLeft, 1.0,
                        false, 1.0
                    }
                },
                {
                    groups[2].id(), db::eye::DetectedSign {
                        {100, 100, 200, 200},
                        Type::MandatoryTurnLeft, 1.0,
                        false, 1.0
                    }
                },
                {
                    groups[3].id(), db::eye::DetectedSign {
                        {100, 100, 200, 200},
                        Type::MandatoryTurnRight, 1.0,
                        false, 1.0
                    }
                },
            }
        );
    }
};

// Runs PositionDetectionMatcher over all pairs of detections[0..3] and expects
// exactly one match: detections[0] (frames[0]) <-> detections[2] (frames[1]),
// with relevance 1.
Y_UNIT_TEST_F(sign, SignMatcingFixture)
{
    const DetectionStore store {
        byId(groups),
        byId(detections),
        byId(frames),
        byId(locations),
        byId(privacies),
        byId(devices)
    };

    const DetectionIdPairSet detectionPairs{
        {detections[0].id(), detections[1].id()},
        {detections[0].id(), detections[2].id()},
        {detections[0].id(), detections[3].id()},
        {detections[1].id(), detections[2].id()},
        {detections[1].id(), detections[3].id()},
        {detections[2].id(), detections[3].id()},
    };

    PositionDetectionMatcher matcher;
    MatchedFrameDetections result = matcher.makeMatches(store, detectionPairs);

    const MatchedFrameDetections expected{
        {
            {frames[0].id(), detections[0].id()},
            {frames[1].id(), detections[2].id()},
            1.f
        },
    };

    // ASSERT rather than EXPECT: the loop below indexes `expected` with
    // positions taken from `result`, so a size mismatch must abort the test
    // before it can read past the end of `expected`.
    ASSERT_EQ(result.size(), expected.size());

    for (size_t i = 0; i < result.size(); i++) {
        EXPECT_EQ(result[i].id0().frameId, expected[i].id0().frameId);
        EXPECT_EQ(result[i].id0().detectionId, expected[i].id0().detectionId);
        EXPECT_EQ(result[i].id1().frameId, expected[i].id1().frameId);
        EXPECT_EQ(result[i].id1().detectionId, expected[i].id1().detectionId);
        EXPECT_EQ(result[i].relevance(), 1.f);
    }
}

// Traffic-light scene on top of BaseMatchingFixture: one TrafficLight
// detection group per frame (groups[i] belongs to frames[i]).
struct TrafficLightMatcingFixture: public BaseMatchingFixture {
    TrafficLightMatcingFixture() {
        groups = insertx<db::eye::DetectionGroupGateway>(
            db::eye::DetectionGroups {
                {frames[0].id(), db::eye::DetectionType::TrafficLight},
                {frames[1].id(), db::eye::DetectionType::TrafficLight},
                {frames[2].id(), db::eye::DetectionType::TrafficLight},
                {frames[3].id(), db::eye::DetectionType::TrafficLight},
            }
        );

        // Only four detections here; nothing is placed on frames[3].
        // - detections[0] and detections[1] are on frames[0] at different x offsets;
        // - detections[2] and detections[3] repeat detections[0]'s box on
        //   frames[1] and frames[2].
        // Boxes are 30 wide and 100 tall.
        detections = insertx<db::eye::DetectionGateway>(
            db::eye::Detections {
                {groups[0].id(), db::eye::DetectedTrafficLight{{100, 100, 130, 200}, 1.0}},
                {groups[0].id(), db::eye::DetectedTrafficLight{{400, 100, 430, 200}, 1.0}},
                {groups[1].id(), db::eye::DetectedTrafficLight{{100, 100, 130, 200}, 1.0}},
                {groups[2].id(), db::eye::DetectedTrafficLight{{100, 100, 130, 200}, 1.0}},
            }
        );
    }
};

// Runs PositionDetectionMatcher over all pairs of detections[0..3] and expects
// exactly two matches, in this order: detections[0] <-> detections[2] and
// detections[1] <-> detections[2] (frames[0] vs frames[1]), each with
// relevance 1.
Y_UNIT_TEST_F(traffic_light, TrafficLightMatcingFixture)
{
    const DetectionStore store {
        byId(groups),
        byId(detections),
        byId(frames),
        byId(locations),
        byId(privacies),
        byId(devices)
    };

    const DetectionIdPairSet detectionPairs{
        {detections[0].id(), detections[1].id()},
        {detections[0].id(), detections[2].id()},
        {detections[0].id(), detections[3].id()},
        {detections[1].id(), detections[2].id()},
        {detections[1].id(), detections[3].id()},
        {detections[2].id(), detections[3].id()},
    };

    PositionDetectionMatcher matcher;
    MatchedFrameDetections result = matcher.makeMatches(store, detectionPairs);

    const MatchedFrameDetections expected{
        {
            {frames[0].id(), detections[0].id()},
            {frames[1].id(), detections[2].id()},
            1.f
        },
        {
            {frames[0].id(), detections[1].id()},
            {frames[1].id(), detections[2].id()},
            1.f
        },
    };

    // ASSERT rather than EXPECT: the loop below indexes `expected` with
    // positions taken from `result`, so a size mismatch must abort the test
    // before it can read past the end of `expected`.
    ASSERT_EQ(result.size(), expected.size());

    for (size_t i = 0; i < result.size(); i++) {
        EXPECT_EQ(result[i].id0().frameId, expected[i].id0().frameId);
        EXPECT_EQ(result[i].id0().detectionId, expected[i].id0().detectionId);
        EXPECT_EQ(result[i].id1().frameId, expected[i].id1().frameId);
        EXPECT_EQ(result[i].id1().detectionId, expected[i].id1().detectionId);
        EXPECT_EQ(result[i].relevance(), 1.f);
    }
}

// Road-marking scene on top of BaseMatchingFixture: one RoadMarking detection
// group per frame (groups[i] belongs to frames[i]).
struct RoadMarkingMatchingFixture: public BaseMatchingFixture {
    RoadMarkingMatchingFixture() {
        groups = insertx<db::eye::DetectionGroupGateway>(
            db::eye::DetectionGroups {
                {frames[0].id(), db::eye::DetectionType::RoadMarking},
                {frames[1].id(), db::eye::DetectionType::RoadMarking},
                {frames[2].id(), db::eye::DetectionType::RoadMarking},
                {frames[3].id(), db::eye::DetectionType::RoadMarking},
            }
        );

        using Type = traffic_signs::TrafficSign;

        // Box layout mirrors the house-number scene:
        // - detections[0] and detections[1] are on frames[0] at different x offsets;
        // - detections[2] and detections[3] are on frames[1] and frames[2].
        // All four are RoadMarkingLaneDirectionL. detections[4], on frames[3],
        // is RoadMarkingLaneDirectionR.
        detections = insertx<db::eye::DetectionGateway>(
            db::eye::Detections {
                {
                    groups[0].id(), db::eye::DetectedRoadMarking {
                        {100, 100, 200, 200},
                        Type::RoadMarkingLaneDirectionL, 1.0
                    }
                },
                {
                    groups[0].id(), db::eye::DetectedRoadMarking {
                        {400, 100, 500, 200},
                        Type::RoadMarkingLaneDirectionL, 1.0
                    }
                },
                {
                    groups[1].id(), db::eye::DetectedRoadMarking {
                        {100, 100, 200, 200},
                        Type::RoadMarkingLaneDirectionL, 1.0
                    }
                },
                {
                    groups[2].id(), db::eye::DetectedRoadMarking {
                        {100, 100, 200, 200},
                        Type::RoadMarkingLaneDirectionL, 1.0
                    }
                },
                {
                    groups[3].id(), db::eye::DetectedRoadMarking {
                        {100, 100, 200, 200},
                        Type::RoadMarkingLaneDirectionR, 1.0
                    }
                },
            }
        );
    }
};

// Feeds every pair of detections[0..3] to PositionDetectionMatcher and checks
// the result against an exact, ordered list of two matches between frames[0]
// and frames[1] (detections[0] <-> detections[2], then
// detections[1] <-> detections[2]), each with relevance 1.
Y_UNIT_TEST_F(road_marking, RoadMarkingMatchingFixture)
{
    const DetectionStore store {
        byId(groups),
        byId(detections),
        byId(frames),
        byId(locations),
        byId(privacies),
        byId(devices)
    };

    // Every unordered pair drawn from detections[0..3].
    const DetectionIdPairSet candidatePairs{
        {detections[0].id(), detections[1].id()},
        {detections[0].id(), detections[2].id()},
        {detections[0].id(), detections[3].id()},
        {detections[1].id(), detections[2].id()},
        {detections[1].id(), detections[3].id()},
        {detections[2].id(), detections[3].id()},
    };

    PositionDetectionMatcher positionMatcher;
    MatchedFrameDetections matches = positionMatcher.makeMatches(store, candidatePairs);

    const MatchedFrameDetections expectedMatches{
        {
            {frames[0].id(), detections[0].id()},
            {frames[1].id(), detections[2].id()},
            1.f
        },
        {
            {frames[0].id(), detections[1].id()},
            {frames[1].id(), detections[2].id()},
            1.f
        },
    };

    // Fatal on mismatch, so the element-wise loop never indexes out of range.
    ASSERT_EQ(matches.size(), expectedMatches.size());

    for (size_t idx = 0; idx != expectedMatches.size(); ++idx) {
        const auto& actual = matches[idx];
        const auto& wanted = expectedMatches[idx];

        EXPECT_EQ(actual.id0().frameId, wanted.id0().frameId);
        EXPECT_EQ(actual.id0().detectionId, wanted.id0().detectionId);
        EXPECT_EQ(actual.id1().frameId, wanted.id1().frameId);
        EXPECT_EQ(actual.id1().detectionId, wanted.id1().detectionId);
        EXPECT_EQ(actual.relevance(), 1.f);
    }
}

} // Y_UNIT_TEST_SUITE

} // namespace maps::mrc::eye::tests
