#include "object_intersections_positioner.h"

#include "camera_azimuth_calibration.h"
#include "object_rays_intersection.h"
#include "utils.h"

#include <maps/libs/log8/include/log8.h>
#include <maps/libs/geolib/include/distance.h>
#include <maps/wikimap/mapspro/services/mrc/libs/common/include/exif.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace maps::mrc::sensors_feature_positioner {

// Distance in meters below which two second-ray cameras are treated as the
// same camera position when looking for an intersection's nearest neighbour.
constexpr double SIMILAR_CAMERAS_DISTANCE_THRESHOLD = 0.5;

// Index of a group of intersections that are considered to describe one object.
using GroupId = size_t;

double metersBetweenIntersections(const ObjectIntersection& intersection1,
                                  const ObjectIntersection& intersection2)
{
    double mercatorDistance = geolib3::length(
        geolib3::Vector3(
            intersection1.odoMercatorPos.x() - intersection2.odoMercatorPos.x(),
            intersection1.odoMercatorPos.y() - intersection2.odoMercatorPos.y(),
            intersection1.odoMercatorPos.z() - intersection2.odoMercatorPos.z()));
    return geolib3::toMeters(
        mercatorDistance,
        intersection1.ray1.cameraPos.mercatorPosition());
}

// Meters. Two intersections of the same ray whose distances from that ray's
// camera differ by more than this are treated as conflicting candidates.
constexpr double ALLOWED_SIGN_POSITION_ERROR = 10;

// If one rays intersects with many rays, and the distance beetween
// 2 intersections > ALLOWED_SIGN_POSITION_ERROR, then it removes the
// intersection with less score.
std::vector<RayIntersections> filterIntersections(
    std::vector<RayIntersections> allIntersections,
    std::function<double(const ObjectIntersection&)> intersectionScore,
    double minScoreDifference)
{
    std::unordered_set<IntersectionId> intersectionsToDelete;

    for (int ray1Index = allIntersections.size() - 1; ray1Index >= 0; ray1Index--) {
        const auto& curRayIntersections = allIntersections[ray1Index].intersections;
        for (auto intersection1 = curRayIntersections.begin();
             intersection1 != curRayIntersections.end();
             ++intersection1)
        {
            if (intersectionsToDelete.count(intersection1->id)) {
                continue;
            }
            for (auto intersection2 = intersection1 + 1;
                 intersection2 != curRayIntersections.end();
                 ++intersection2)
            {
                if (intersectionsToDelete.count(intersection2->id)) {
                    continue;
                }
                double score1 = intersectionScore(*intersection1);
                double score2 = intersectionScore(*intersection2);

                if (intersection1->ray2.featureId == intersection2->ray2.featureId
                    || (std::abs(intersection2->metersFromCamera1
                                - intersection1->metersFromCamera1)
                        > ALLOWED_SIGN_POSITION_ERROR))
                {
                    if (score1 - score2 > minScoreDifference){
                        intersectionsToDelete.insert(intersection2->id);
                    } else if (score2 - score1 > minScoreDifference) {
                        intersectionsToDelete.insert(intersection1->id);
                        break;
                    }
                }
            }
        }
    }

    for (int rayIndex = allIntersections.size() - 1; rayIndex >= 0; rayIndex--) {
        auto& curRayIntersections = allIntersections[rayIndex].intersections;
        curRayIntersections.erase(
            std::remove_if(
                curRayIntersections.begin(), curRayIntersections.end(),
                [=] (const ObjectIntersection& intersection) {
                    return intersectionsToDelete.count(intersection.id);
                }),
            curRayIntersections.end());
    };

    return allIntersections;
}

// Each ray can have many intersections with other rays. If all the
// intersection points are located in a small area, we consider that
// all the intersection correspond to one real object. If the distance
// between intersections is big, we consider that some
// intersections are accidental, and there are no real objects at these
// points.
//
// Filter 1:
// Lets take one ray and handle all its intersections. If there are
// two close intersections and a separate third intersection, we
// delete the third intersection.
// Filter2:
// Lets take one ray and handle all its intersections. If one
// intersection has good accuracy and a second remote intersection has
// bad accuracy, we delete the second intersection. (The accuracy is
// the distance between rays + other properties)
//
// Returns filtered intersections.
std::vector<RayIntersections> filterGoodIntersections(
    std::vector<RayIntersections> allIntersections)
{
    // distance to the nearest intersection for each intersection
    std::unordered_map<IntersectionId, double> distanceToNearestIntersection;

    for (int ray1Index = allIntersections.size() - 1; ray1Index >= 0; ray1Index--) {
        const auto& curRayIntersections = allIntersections[ray1Index].intersections;
        for (const auto& intersection1 : curRayIntersections) {
            if (!distanceToNearestIntersection.count(intersection1.id)) {
                distanceToNearestIntersection[intersection1.id]
                    = std::numeric_limits<double>::max();
            }
            for (const auto& intersection2 : curRayIntersections) {
                if ((metersBetweenCameras(intersection1.ray2,
                                          intersection2.ray2)
                     < SIMILAR_CAMERAS_DISTANCE_THRESHOLD)) {
                    continue;
                }
                double distanceBetweenIntersections
                    = metersBetweenIntersections(intersection1, intersection2);

                distanceToNearestIntersection[intersection1.id] = std::min(
                    distanceToNearestIntersection[intersection1.id],
                    distanceBetweenIntersections);
            }
        }
    }

    allIntersections = filterIntersections(
        allIntersections,
        [&] (const ObjectIntersection& intersection) {
            return -distanceToNearestIntersection[intersection.id];
        },
        1);

    allIntersections = filterIntersections(
        allIntersections,
        [&] (const ObjectIntersection& intersection) {
            return -(intersection.errorRadians.value() * 40
                     + intersection.errorMeters * 4
                     + std::abs(intersection.metersFromCamera1
                                - intersection.ray1.metersToObject));
        },
        0);

    return allIntersections;
}

// Groups close intersections and returns groups.
// Each group describes a single object.
std::unordered_map<IntersectionId, GroupId> getGroupedObjects(
    const std::vector<RayIntersections>& allIntersections) {

    std::unordered_map<IntersectionId, std::vector<IntersectionId>> connectedIntersections;

    // if one ray has several intersections, we consider that all
    // these intersections describe one object.
    for (int ray1Index = allIntersections.size() - 1; ray1Index >= 0; ray1Index--) {
        const auto& curRayIntersections = allIntersections[ray1Index].intersections;
        for (size_t i = 0; i < curRayIntersections.size(); i++) {
            for (size_t j = i + 1; j < curRayIntersections.size(); j++) {
                connectedIntersections[curRayIntersections[i].id].push_back(
                    curRayIntersections[j].id);
                connectedIntersections[curRayIntersections[j].id].push_back(
                    curRayIntersections[i].id);
            }
        }
    }

    std::unordered_map<IntersectionId, GroupId> intersectionIdToGroupId;

    // add the intersection to the group and also recursively adds all the connected
    // intersections to the group
    std::function<void(IntersectionId, size_t)> addIntersectionToGroup
        = [&] (IntersectionId intersectionId, size_t groupId) {
            if (intersectionIdToGroupId.count(intersectionId)) {
                return;
            }
            intersectionIdToGroupId[intersectionId] = groupId;
            for (size_t connectedIntersectionId : connectedIntersections[intersectionId]) {
                addIntersectionToGroup(connectedIntersectionId, groupId);
            }
        };

    size_t nextGroupId = 0;

    // group objects into connected components
    for (int ray1Index = allIntersections.size() - 1; ray1Index >= 0; ray1Index--) {
        const auto& curRayIntersections = allIntersections[ray1Index].intersections;
        for (const auto& intersection : curRayIntersections) {
            if (!intersectionIdToGroupId.count(intersection.id)) {
                addIntersectionToGroup(intersection.id, nextGroupId);
                nextGroupId++;
            }
        }
    }

    return intersectionIdToGroupId;
}

std::vector<ObjectIntersections> getGroupsAsVectors(
    const std::unordered_map<IntersectionId, GroupId>& intersectionIdToGroupId,
    const std::vector<RayIntersections>& allIntersections)
{
    std::unordered_set<IntersectionId> addedIntersections;
    std::vector<ObjectIntersections> objectGroups;
    for (int ray1Index = allIntersections.size() - 1; ray1Index >= 0; ray1Index--) {
        const auto& curRayIntersections = allIntersections[ray1Index].intersections;
        for (const auto& intersection : curRayIntersections) {
            if (addedIntersections.insert(intersection.id).second) {
                size_t groupId = intersectionIdToGroupId.at(intersection.id);
                if (objectGroups.size() < groupId + 1) {
                    objectGroups.resize(groupId + 1);
                }
                objectGroups[groupId].push_back(intersection);
            }
        }
    }

    return objectGroups;

}

// Locates objects seen by several rays and returns one vector of intersections
// per object.
//
// Rays are sorted by camera timestamp and azimuth-calibrated. For each ray i,
// the later rays j > i are intersected with it until the first one whose
// camera is more than MAX_METERS_BETWEEN_CAMERAS away. Each successful
// intersection is stored twice: as is in the list of ray i, and with ray1/ray2
// and metersFromCamera1/2 swapped in the list of ray j. In both lists, ray1 is
// the list's own ray, and both copies keep the same id. The lists are then
// filtered and grouped into objects.
std::vector<ObjectIntersections> locateObjectsUsingIntersections(Rays rays)
{
    std::sort(rays.begin(), rays.end(),
              [](const Ray& lhs, const Ray& rhs) {
                  return lhs.cameraPos.timestamp() < rhs.cameraPos.timestamp();
              });
    rays = calibrateRaysAzimuth(std::move(rays));

    std::vector<RayIntersections> intersections(rays.size());
    IdGenerator idGenerator;

    for (size_t i = 0; i < rays.size(); i++) {
        intersections[i].rayId = rays[i].rayId;
        for (size_t j = i + 1; j < rays.size(); j++) {
            // NOTE(review): assumes that, with rays sorted by time, cameras after
            // the first too-distant one are also too distant.
            if (metersBetweenCameras(rays[i], rays[j])
                > MAX_METERS_BETWEEN_CAMERAS)
            {
                break;
            }
            auto intersection = createObjectFrom2Photos(rays[i], rays[j],
                                                        idGenerator);
            if (!intersection) {
                continue;
            }

            // intersections[i].intersections[X].ray1 == rays[i];
            intersections[i].intersections.push_back(*intersection);
            // Mirror the intersection for ray j (the id is kept unchanged).
            std::swap(intersection->metersFromCamera1,
                      intersection->metersFromCamera2);
            std::swap(intersection->ray1,
                      intersection->ray2);
            // intersections[j].intersections[X].ray1 == rays[j];
            intersections[j].intersections.push_back(std::move(*intersection));
        }
    }

    // Move rather than copy the whole nested vector into the filter.
    intersections = filterGoodIntersections(std::move(intersections));
    const auto groups = getGroupedObjects(intersections);
    return getGroupsAsVectors(groups, intersections);
}

// Returns the weighted mean of the mercator positions of an object's
// intersections. Each intersection is weighted by the squared angle (radians)
// between its two rays' directions to the object, so intersections of nearly
// parallel rays contribute little.
// Fails via REQUIRE when the total weight is not positive, e.g. for an empty
// group or when all ray pairs are parallel.
geolib3::Point2 aggregatedMercatorPosition(const ObjectIntersections& objectIntersections) {
    double x = 0;
    double y = 0;
    double weight = 0;
    for (const auto& intersection : objectIntersections) {
        // Assumes directionToObject vectors are unit length (TODO confirm).
        // Rounding can push their inner product slightly outside [-1, 1], where
        // std::acos returns NaN, and a single NaN weight would poison the
        // whole average. This is most likely for the nearly parallel rays
        // discussed below, so clamp the value.
        const double cosAngle = geolib3::innerProduct(
            intersection.ray1.directionToObject,
            intersection.ray2.directionToObject);
        geolib3::Radians angleBetweenCameras(
            std::acos(std::clamp(cosAngle, -1.0, 1.0)));
        // when rays are almost parallel, small position or
        // direction error produces big intersection position error
        const double angle = angleBetweenCameras.value();
        const double curWeight = angle * angle;
        const auto mercatorPos = intersection.mercatorPos();
        x += mercatorPos.x() * curWeight;
        y += mercatorPos.y() * curWeight;
        weight += curWeight;
    }
    REQUIRE(weight > 0, "object was found using bad intersections");
    return {x / weight, y / weight};
}

} // namespace maps::mrc::sensors_feature_positioner
