#include "utils.h"

#include <library/cpp/testing/gtest/gtest.h>
#include <library/cpp/testing/common/env.h>
#include <maps/wikimap/mapspro/services/mrc/libs/sensors_feature_positioner/object_positioner/object_intersections_positioner.h>

namespace maps::mrc::sensors_feature_positioner::tests {

using namespace maps::mrc::pos_improvment;

using maps::geolib3::Point2;
using maps::geolib3::Vector2;
using maps::geolib3::Vector3;

namespace {

/// Returns a copy of `objects` rearranged into a canonical order so that tests
/// can compare results independently of the order the positioner emits them:
///   1. inside every intersection the ray with the smaller rayId becomes ray1
///      (metersFromCamera1/2 are swapped together with the rays);
///   2. intersections of one object are ordered by (ray1.rayId, ray2.rayId);
///   3. objects are ordered by the ray ids of their first intersection.
/// REQUIRE fires if an empty object takes part in the object-level comparison.
std::vector<ObjectIntersections> sortByRayId(std::vector<ObjectIntersections> objects) {
    // Lexicographic "less" on the (ray1.rayId, ray2.rayId) pair.
    const auto lessByRayIds = [](const ObjectIntersection& a, const ObjectIntersection& b) {
        return a.ray1.rayId == b.ray1.rayId
            ? a.ray2.rayId < b.ray2.rayId
            : a.ray1.rayId < b.ray1.rayId;
    };

    for (auto& group : objects) {
        // Normalize each pair first: the sort below relies on ray1 <= ray2.
        for (auto& item : group) {
            if (item.ray1.rayId > item.ray2.rayId) {
                std::swap(item.ray1, item.ray2);
                std::swap(item.metersFromCamera1, item.metersFromCamera2);
            }
        }
        std::sort(group.begin(), group.end(), lessByRayIds);
    }

    // Objects are compared by their (already sorted) first intersection.
    const auto lessByFirstIntersection =
        [&lessByRayIds](const ObjectIntersections& a, const ObjectIntersections& b) {
            REQUIRE(a.size(), "empty ObjectPositions array");
            REQUIRE(b.size(), "empty ObjectPositions array");
            return lessByRayIds(a[0], b[0]);
        };
    std::sort(objects.begin(), objects.end(), lessByFirstIntersection);

    return objects;
}

} // anonymous namespace

// Base camera position shared by all tests below: the geo point {40, 60}
// converted to mercator coordinates (presumably lon = 40, lat = 60 — confirm
// against the geolib3 axis convention).
// At this point 2 mercator meters = 1 real meter, which is why the tests
// divide mercator distances by 2 when checking metersFromCamera values.
const Point2 P0 = geolib3::geoPoint2Mercator(Point2{40, 60});

// A lone ray has no partner to intersect with, so no object may be located.
TEST(rays_intersection_positioner_tests, test_one_ray)
{
    auto noParkingSign = traffic_signs::TrafficSign::ProhibitoryNoParking;

    Point2 cameraPos = P0;
    Point2 cameraOdoPos = P0 + Vector2{5, 15};
    pos_improvment::UnitVector3 northward(Vector3(0, 1, 0)); // to North
    double distanceMeters = 20;
    RayId singleRayId = 45;
    db::TId singleFeatureId = 37;

    Ray singleRay = createRay(
        noParkingSign,
        cameraPos, cameraOdoPos,
        northward, distanceMeters,
        singleRayId, singleFeatureId);

    Rays rays{singleRay};
    const auto located = locateObjectsUsingIntersections(rays);

    ASSERT_EQ(located.size(), 0u);
}

// Two rays of the same sign type that point in the same direction (North) and
// whose odometer origins differ only by 0.1 along x must not yield an object.
TEST(rays_intersection_positioner_tests, test_two_parallel_rays)
{
    auto noParkingSign = traffic_signs::TrafficSign::ProhibitoryNoParking;

    Point2 cameraPos = P0;
    Point2 cameraOdoPos = P0 + Vector2{5, 15};
    pos_improvment::UnitVector3 northward(Vector3(0, 1, 0)); // to North
    double distanceMeters = 20;
    RayId baseRayId = 45;
    db::TId baseFeatureId = 37;

    Ray firstRay = createRay(
        noParkingSign,
        cameraPos, cameraOdoPos,
        northward, distanceMeters,
        baseRayId, baseFeatureId);

    // Same mercator origin and direction; only the odometer origin is shifted.
    Ray shiftedRay = createRay(
        noParkingSign,
        cameraPos, cameraOdoPos + Vector2(0.1, 0),
        northward, distanceMeters,
        baseRayId + 1, baseFeatureId + 1);

    Rays rays{firstRay, shiftedRay};
    const auto located = locateObjectsUsingIntersections(rays);

    ASSERT_EQ(located.size(), 0u);
}

// Step-by-step scenario for locateObjectsUsingIntersections():
//   1. two crossing rays of the same sign type give one object with one intersection;
//   2. rays of different object types are not intersected;
//   3. a third ray through the same point gives one object with three pairwise intersections;
//   4. a fourth ray that misses the others does not change the result;
//   5. a fifth ray that crosses only the fourth one adds a second object.
// Results are passed through sortByRayId() so the checks do not depend on output order.
TEST(rays_intersection_positioner_tests, test_rays_intersection)
{
    auto signType = traffic_signs::TrafficSign::ProhibitoryNoParking;

    geolib3::Point2 cameraMercPos = P0;
    geolib3::Point2 cameraOdoMercPos = P0 + Vector2{5, 15};
    pos_improvment::UnitVector3 directionToObject(Vector3(0, 1, 0)); // to North
    double metersToObject = 20;
    RayId rayId = 45;
    db::TId featureId = 37;

    Ray ray1 = createRay(signType,
                         cameraMercPos,
                         cameraOdoMercPos,
                         directionToObject,
                         metersToObject,
                         rayId,
                         featureId);

    pos_improvment::UnitVector3 directionToObject2(Vector3(3, 4, 0)); // to North-East

    // ray2 starts 30 mercator units to the west of ray1. Going along (3, 4) it
    // reaches ray1's line after 30 units in x and 40 in y, i.e. 50 units along
    // the ray (3-4-5 triangle), 40 units north of camera 1.
    Ray ray2 = createRay(signType,
                         cameraMercPos - Vector2(30, 0),
                         cameraOdoMercPos - Vector2(30, 0),
                         directionToObject2,
                         metersToObject,
                         rayId + 1,
                         featureId + 1);

    Rays rays{ray1, ray2};
    auto objects = locateObjectsUsingIntersections(rays);
    objects = sortByRayId(objects);

    ASSERT_EQ(objects.size(), 1u);
    ASSERT_EQ(objects[0].size(), 1u);
    EXPECT_EQ(objects[0][0].ray1.rayId, ray1.rayId);
    EXPECT_EQ(objects[0][0].ray2.rayId, ray2.rayId);
    // Mercator distances are halved: 2 mercator meters = 1 real meter at P0.
    EXPECT_NEAR(objects[0][0].metersFromCamera1, 40.0 / 2, 0.001);
    EXPECT_NEAR(objects[0][0].metersFromCamera2, 50.0 / 2, 0.001);
    // The intersection is reported relative to the odometer camera position.
    EXPECT_NEAR(objects[0][0].odoMercatorPos.x(), cameraOdoMercPos.x(), 0.0001);
    EXPECT_NEAR(objects[0][0].odoMercatorPos.y(), cameraOdoMercPos.y() + 40, 0.0001);
    EXPECT_NEAR(objects[0][0].odoMercatorPos.z(), 0, 0.0001);

    // check that objects of different types don't intersect
    rays[1].objectTypeId = static_cast<size_t>(traffic_signs::TrafficSign::ProhibitoryNoEntry);
    objects = locateObjectsUsingIntersections(rays);
    ASSERT_EQ(objects.size(), 0u);

    // ray3 mirrors ray2: it starts 30 mercator units to the east and heads
    // North-West. Its small z component (0.1) is why the checks that involve
    // ray3 below use the looser 0.1 tolerance.
    pos_improvment::UnitVector3 directionToObject3(Vector3(-3, 4, 0.1)); // to North-West
    Ray ray3 = createRay(signType,
                         cameraMercPos + Vector2(30, 0),
                         cameraOdoMercPos + Vector2(30, 0),
                         directionToObject3,
                         metersToObject,
                         rayId + 2,
                         featureId + 2);
    // `rays` is rebuilt from the original ray1/ray2, so the object type changed
    // above on rays[1] does not carry over.
    rays = Rays{ray1, ray2, ray3};
    objects = locateObjectsUsingIntersections(rays);
    objects = sortByRayId(objects);

    // Three rays through one point: a single object with 3 pairwise intersections.
    ASSERT_EQ(objects.size(), 1u);
    ASSERT_EQ(objects[0].size(), 3u);

    EXPECT_EQ(objects[0][0].ray1.rayId, ray1.rayId);
    EXPECT_EQ(objects[0][0].ray2.rayId, ray2.rayId);
    EXPECT_NEAR(objects[0][0].metersFromCamera1, 40.0 / 2, 0.001);
    EXPECT_NEAR(objects[0][0].metersFromCamera2, 50.0 / 2, 0.001);
    EXPECT_NEAR(objects[0][0].odoMercatorPos.x(), cameraOdoMercPos.x(), 0.0001);
    EXPECT_NEAR(objects[0][0].odoMercatorPos.y(), cameraOdoMercPos.y() + 40, 0.0001);
    EXPECT_NEAR(objects[0][0].odoMercatorPos.z(), 0, 0.0001);

    EXPECT_EQ(objects[0][1].ray1.rayId, ray1.rayId);
    EXPECT_EQ(objects[0][1].ray2.rayId, ray3.rayId);
    EXPECT_NEAR(objects[0][1].metersFromCamera1, 40.0 / 2, 0.1);
    EXPECT_NEAR(objects[0][1].metersFromCamera2, 50.0 / 2, 0.1);

    EXPECT_EQ(objects[0][2].ray1.rayId, ray2.rayId);
    EXPECT_EQ(objects[0][2].ray2.rayId, ray3.rayId);

    // The aggregated position is compared with the mercator (not odometer)
    // camera position: 40 mercator units north of camera 1.
    Point2 resultObjectPos = aggregatedMercatorPosition(objects[0]);
    EXPECT_NEAR(resultObjectPos.x(), cameraMercPos.x(), 0.1);
    EXPECT_NEAR(resultObjectPos.y(), cameraMercPos.y() + 40, 0.1);

    pos_improvment::UnitVector3 directionToObject4(Vector3(0, 40, 10)); // to North, with a positive z component
    geolib3::Point2 cameraMercPos4 = cameraMercPos + Vector2{3, 0};
    geolib3::Point2 cameraOdoMercPos4 = cameraOdoMercPos + Vector2{3, 0};
    Ray ray4 = createRay(signType,
                         cameraMercPos4,
                         cameraOdoMercPos4,
                         directionToObject4,
                         metersToObject,
                         rayId + 3,
                         featureId + 3);
    // ray4 doesn't intersect with previous rays
    // (presumably because of its z component, since rays 1-3 are flat or
    // nearly flat in z — confirm against the positioner's intersection criteria)

    rays = Rays{ray1, ray2, ray3, ray4};
    objects = locateObjectsUsingIntersections(rays);
    objects = sortByRayId(objects);
    ASSERT_EQ(objects.size(), 1u);
    ASSERT_EQ(objects[0].size(), 3u); // still 1 intersection of 3 rays

    // ray5 starts 30 mercator units east of camera 4; its direction
    // (-30, 40, 10) brings it to the same offset (0, 40, 10) from camera 4
    // that ray4's direction (0, 40, 10) points to.
    pos_improvment::UnitVector3 directionToObject5(Vector3(-30, 40, 10)); // to North-West, with a positive z component
    Ray ray5 = createRay(signType,
                         cameraMercPos4 + Vector2(30, 0),
                         cameraOdoMercPos4 + Vector2(30, 0),
                         directionToObject5,
                         metersToObject,
                         rayId + 4,
                         // NOTE(review): +5 breaks the +1, +2, +3 sequence used
                         // above; +4 may have been intended — confirm.
                         featureId + 5);
    // ray5 intersects with ray4

    rays = Rays{ray1, ray2, ray3, ray4, ray5};
    objects = locateObjectsUsingIntersections(rays);
    objects = sortByRayId(objects);

    // Two independent objects: {ray1, ray2, ray3} and {ray4, ray5}.
    ASSERT_EQ(objects.size(), 2u);

    ASSERT_EQ(objects[0].size(), 3u);
    EXPECT_EQ(objects[0][0].ray1.rayId, ray1.rayId);
    EXPECT_EQ(objects[0][0].ray2.rayId, ray2.rayId);
    EXPECT_EQ(objects[0][1].ray1.rayId, ray1.rayId);
    EXPECT_EQ(objects[0][1].ray2.rayId, ray3.rayId);
    EXPECT_EQ(objects[0][2].ray1.rayId, ray2.rayId);
    EXPECT_EQ(objects[0][2].ray2.rayId, ray3.rayId);

    ASSERT_EQ(objects[1].size(), 1u);
    EXPECT_EQ(objects[1][0].ray1.rayId, ray4.rayId);
    EXPECT_EQ(objects[1][0].ray2.rayId, ray5.rayId);
}

} // namespace maps::mrc::sensors_feature_positioner::tests

