#include "sfm_positioning.h"

#include <maps/libs/geolib/include/conversion.h>
#include <maps/libs/geolib/include/distance.h>
#include <maps/libs/log8/include/log8.h>

#include <opencv2/features2d.hpp>

#include <algorithm>
#include <cstdint>
#include <map>
#include <set>
#include <utility>
#include <vector>

namespace maps::mrc::experiment_sfm_positioning {

namespace {

// Lowe's ratio test: the best descriptor match is accepted only when its
// distance is below this fraction of the second-best match distance.
constexpr float NEAREST_NEIGHBOUR_MATCH_RATIO = 0.75f;
// RANSAC threshold for the fundamental matrix estimation: the maximum
// distance (in pixels) from a point to its epipolar line for an inlier.
constexpr float INLIER_DISTANCE_PIXELS = 3.0f;
// Maximum mean relative epipolar error for two objects on different images
// to be considered the same object.
constexpr float OBJECTS_MATCHING_RELATIVE_ERROR_THRESHOLD = 0.25f;

// Placeholder validation of the 2nd camera's relative pose.
// It currently accepts every pose; the intended checks are listed below.
bool cameraTurnSanityCheck(const cv::Mat& /* cam2LocalPose */)
{
    // TODO: reject the pose when
    //   1. the move is negative (the camera went backwards);
    //   2. the camera turned around the Z axis too much;
    //   3. the camera turned around the X axis too much;
    //   4. the camera turned around the Y axis too much.
    constexpr bool everyPoseIsAccepted = true;
    return everyPoseIsAccepted;
}

// Corresponding feature points of two images: `first[i]` (1st image) is
// matched to `second[i]` (2nd image).
using MatchedPoints = std::pair<
    std::vector<cv::Point2f>,
    std::vector<cv::Point2f>>;

// Detects BRISK keypoints on both images (restricted to the given masks) and
// returns cross-checked correspondences: a pair (p1, p2) is kept only when p2
// is p1's ratio-test-approved nearest neighbour in the 2nd image AND p1 is
// p2's ratio-test-approved nearest neighbour in the 1st image.
// `result.first[i]` is matched to `result.second[i]`.
MatchedPoints matchBRISKPoints(const cv::Mat& img1,
                               const cv::Mat& img1Mask,
                               const cv::Mat& img2,
                               const cv::Mat& img2Mask)
{
    const auto detector = cv::BRISK::create();
    std::vector<cv::KeyPoint> kPoints1, kPoints2;
    cv::Mat descriptors1, descriptors2;
    detector->detectAndCompute(img1, img1Mask, kPoints1, descriptors1);
    detector->detectAndCompute(img2, img2Mask, kPoints2, descriptors2);

    // Nothing to match when either image produced no descriptors; return
    // early instead of depending on how knnMatch treats an empty set.
    if (descriptors1.empty() || descriptors2.empty()) {
        return {};
    }

    cv::BFMatcher bfMatcher{cv::NORM_HAMMING};
    std::vector<std::vector<cv::DMatch>> nearMatches1, nearMatches2;
    bfMatcher.knnMatch(descriptors1, descriptors2, nearMatches1, 2);
    bfMatcher.knnMatch(descriptors2, descriptors1, nearMatches2, 2);

    // Apply ratio test as explained by D.Lowe in his SIFT paper.
    // https://docs.opencv.org/3.2.0/db/d70/tutorial_akaze_matching.html
    // Yields: query point index -> its best accepted match.
    const auto ratioTestedPoints = [](
        const std::vector<std::vector<cv::DMatch>>& nearMatches) {
        std::map<int, cv::DMatch> matches;
        for (const auto& couple : nearMatches) {
            // knnMatch returns fewer than k = 2 neighbours when a query has
            // fewer possible matches (e.g. the other image has a single
            // descriptor). The ratio test needs both the best and the
            // second-best match, so skip such rows rather than indexing
            // past the end of the vector.
            if (couple.size() < 2) {
                continue;
            }
            const auto& best = couple[0];
            const auto& secondBest = couple[1];
            if (!(best.distance
                  < NEAREST_NEIGHBOUR_MATCH_RATIO * secondBest.distance)) {
                continue;
            }
            auto [it, inserted] = matches.try_emplace(best.queryIdx, best);
            if (!inserted && best.distance < it->second.distance) {
                it->second = best;
            }
        }
        return matches;
    };

    const auto matches1 = ratioTestedPoints(nearMatches1);
    const auto matches2 = ratioTestedPoints(nearMatches2);

    // Do the cross check. I.e. choose points that are matched in both
    // directions.
    MatchedPoints result;
    for (const auto& [idx1, match1] : matches1) {
        const auto idx2 = match1.trainIdx;
        const auto reverse = matches2.find(idx2);
        if (reverse != matches2.end() && reverse->second.trainIdx == idx1) {
            result.first.push_back(kPoints1[idx1].pt);
            result.second.push_back(kPoints2[idx2].pt);
        }
    }

    return result;
}

// Keeps only the correspondences whose byte in `mask` is non-zero.
// `mask` must be a single-channel byte matrix with one row per
// correspondence (e.g. an inliers mask of the fundamental matrix estimation).
MatchedPoints extractInliers(const MatchedPoints& points,
                             const cv::Mat& mask)
{
    REQUIRE(mask.type() == CV_8UC1, "The mask element type must be a byte");
    REQUIRE(points.first.size() == static_cast<std::size_t>(mask.rows),
        "The mask must have same number of rows as points count");

    MatchedPoints inliers;
    const int count = mask.rows;
    for (int row = 0; row < count; ++row) {
        const bool isInlier = mask.at<std::uint8_t>(row) != 0;
        if (!isInlier) {
            continue;
        }
        inliers.first.push_back(points.first.at(row));
        inliers.second.push_back(points.second.at(row));
    }
    return inliers;
}

// Returns the bbox center as a homogeneous 2D point (x, y, 1).
cv::Point3f makeCentroid3f(const mrc::common::ImageBox& bbox)
{
    const cv::Rect2f box = static_cast<cv::Rect>(bbox);
    const float centerX = box.x + box.width / 2.0f;
    const float centerY = box.y + box.height / 2.0f;
    return cv::Point3f(centerX, centerY, 1.0f);
}

// Returns the bbox center point (x, y).
cv::Point2f makeCentroid2f(const mrc::common::ImageBox& bbox)
{
    const cv::Rect2f box = static_cast<cv::Rect>(bbox);
    const float centerX = box.x + box.width / 2.0f;
    const float centerY = box.y + box.height / 2.0f;
    return cv::Point2f(centerX, centerY);
}

// Length of the bbox's longer side (max of its width and height).
float bboxSize(const mrc::common::ImageBox& bbox)
{
    const float extentX = bbox.maxX() - bbox.minX();
    const float extentY = bbox.maxY() - bbox.minY();
    return extentX < extentY ? extentY : extentX;
}

// A pair of objects, one from each image, that matchObjects considers to be
// the same physical object.
// Note: the members are references into the caller's ObjectsInfo
// containers, which must outlive this struct. Instances are built with
// positional aggregate initialization, so keep the member order.
struct MatchedObject {
    const ObjectInfo& object1; // object on the 1st image
    const ObjectInfo& object2; // object on the 2nd image
    // bbox centroids of object1 (first) and object2 (second), the image
    // points that get triangulated
    std::pair<cv::Point2f, cv::Point2f> points;
};
using MatchedObjects = std::vector<MatchedObject>;

// Measures how consistent a pair of objects (one per image) is with the
// epipolar geometry described by `fundMat`.
//
// For each object, take the epipolar line induced in its own image by the
// other object's bbox center, and compute the distance from the object's own
// bbox center to that line. Each distance is divided by the object's
// bboxSize() to make it relative. The mean of the two relative errors is
// returned. A zero-size bbox yields an infinite or NaN error, which compares
// false against any finite `<=` threshold.
float meanRelativeError(
    const cv::Mat& fundMat, const ObjectInfo& obj1, const ObjectInfo& obj2)
{
    const auto center1 = makeCentroid3f(obj1.bbox);
    const auto center2 = makeCentroid3f(obj2.bbox);

    // computeCorrespondEpilines with image num 1 returns l2 = F * p1, a line
    // in the 2nd image. With image num 2 it returns l1 = F^T * p2, a line in
    // the 1st image. The line coefficients (a, b, c) are normalized so that
    // a^2 + b^2 = 1. The point-to-line distance
    // |a*x + b*y + c| / sqrt(a^2 + b^2) therefore reduces to
    // |(a, b, c) . (x, y, 1)|.
    //
    // cv::Mat(point) (parentheses, not braces) builds a 3x1 column, and
    // reshape(3) views it as one 3-channel homogeneous point.
    cv::Mat epilineInImg2, epilineInImg1;
    cv::computeCorrespondEpilines(
        cv::Mat(center1).reshape(3), 1 /* image num */, fundMat,
        epilineInImg2);
    cv::computeCorrespondEpilines(
        cv::Mat(center2).reshape(3), 2 /* image num */, fundMat,
        epilineInImg1);

    // Distance of center1 to center2's epiline, both in the 1st image.
    const float dist1
        = cv::abs(epilineInImg1.at<cv::Point3f>(0).dot(center1));
    const float error1 = dist1 / bboxSize(obj1.bbox);
    // Distance of center2 to center1's epiline, both in the 2nd image.
    const float dist2
        = cv::abs(epilineInImg2.at<cv::Point3f>(0).dot(center2));
    const float error2 = dist2 / bboxSize(obj2.bbox);

    const float meanError = (error1 + error2) / 2.0f;
    INFO() << "Matching relative error: mean(" << error1 << ", " << error2
           << ") = " << meanError;
    return meanError;
}

// Matches objects of the 1st image to objects of the 2nd image using the
// epipolar constraint. A pair is a candidate when its meanRelativeError() is
// at most OBJECTS_MATCHING_RELATIVE_ERROR_THRESHOLD. Candidates are then
// accepted greedily from the smallest error up, so every object takes part
// in at most one match. Each match carries the two bbox centroids.
MatchedObjects matchObjects(
    const cv::Mat& fundMat,
    const ObjectsInfo& objects1,
    const ObjectsInfo& objects2)
{
    // There is an alternative way to match objects: add a Cartesian product
    // of their centroids to the set of points used to find the fundamental
    // matrix. But their quantity is expected to be small enough; otherwise
    // this may break the whole process.

    struct ObjectMatch {
        const ObjectInfo* object1;
        const ObjectInfo* object2;
        float meanError;
    };
    std::vector<ObjectMatch> objectMatches;

    for (const auto& object1 : objects1) {
        for (const auto& object2 : objects2) {
            const auto error = meanRelativeError(fundMat, object1, object2);
            if (error <= OBJECTS_MATCHING_RELATIVE_ERROR_THRESHOLD) {
                objectMatches.push_back(
                    ObjectMatch{&object1, &object2, error});
            }
        }
    }

    // stable_sort: candidates with equal errors keep their enumeration
    // order, so the greedy choice below does not depend on how a particular
    // std::sort implementation orders ties.
    std::stable_sort(objectMatches.begin(), objectMatches.end(),
        [](const ObjectMatch& lhs, const ObjectMatch& rhs) {
            return lhs.meanError < rhs.meanError;
        });

    // Choose greedily a set of matches. The choice might be suboptimal but
    // it's OK for now (it is expected to be quite small).
    // Used objects are tracked per image, so an object already matched on one
    // side never blocks an object on the other side. A single shared set
    // could do that if both containers shared elements.
    std::set<const ObjectInfo*> chosen1;
    std::set<const ObjectInfo*> chosen2;
    MatchedObjects result;
    for (const auto& match : objectMatches) {
        if (chosen1.count(match.object1) || chosen2.count(match.object2)) {
            continue;
        }
        chosen1.insert(match.object1);
        chosen2.insert(match.object2);
        result.push_back(MatchedObject{*match.object1, *match.object2,
            std::make_pair(makeCentroid2f(match.object1->bbox),
                           makeCentroid2f(match.object2->bbox))});
    }

    return result;
}

// Recovers the pose of the 2nd camera relative to the 1st one and returns it
// as a 3x4 [R | t] matrix in the 1st camera's coordinate system.
//
// Both images are assumed to share the camera matrix K, so the essential
// matrix is E = K^T * F * K. cv::recoverPose then decomposes E and picks the
// (R, t) that passes its cheirality check.
//
// Note: cv::Mat copies share data and cv::recoverPose treats its mask as an
// input/output argument. Writes it makes to `inliersMask` may therefore be
// visible through the caller's matrix.
cv::Mat recover2ndCamPose(cv::Mat camMat,
                          cv::Mat fundMat,
                          const MatchedPoints& points,
                          cv::Mat inliersMask)
{
    const cv::Mat essentialMat = camMat.t() * fundMat * camMat;

    cv::Mat rotation;
    cv::Mat translation;
    cv::recoverPose(essentialMat, points.first, points.second, camMat,
        rotation, translation, inliersMask);

    cv::Mat pose;
    cv::hconcat(rotation, translation, pose);
    return pose;
}

// A matched pair of objects (one per image) together with its triangulated
// 3D position.
// Note: the members are references into the caller's ObjectsInfo
// containers, which must outlive this struct. Instances are built with
// positional aggregate initialization, so keep the member order.
struct TriangulatedObject {
    const ObjectInfo& object1; // object on the 1st image
    const ObjectInfo& object2; // object on the 2nd image
    // 3D point in the 1st camera's coordinate system; presumably up to the
    // reconstruction scale (computeObjectPositions multiplies it by `scale`)
    cv::Point3f position;
};
using TriangulatedObjects = std::vector<TriangulatedObject>;

// Returns a list of matched objects with 3d points on them in the 1st
// camera's coordinate system.
//
// camMat   - the camera matrix K shared by both images
// cam2Pose - 3x4 [R | t] of the 2nd camera relative to the 1st one
// The bbox centroids stored in each match are triangulated with the
// projection matrices K * [I | 0] and K * cam2Pose.
TriangulatedObjects triangulateObjects(cv::Mat camMat,
                                       cv::Mat cam2Pose,
                                       const MatchedObjects& matchedObjects)
{
    TriangulatedObjects result;
    if (matchedObjects.empty()) {
        return result;
    }

    // The first camera Rt matrix defines the coordinate system origin: the
    // rotation part is the identity and the translation part is zero.
    const cv::Mat cam1Pose = cv::Mat::eye(3, 4, CV_64F);

    const cv::Mat projMat1 = camMat * cam1Pose;
    const cv::Mat projMat2 = camMat * cam2Pose;

    std::vector<cv::Point2f> points1, points2;
    points1.reserve(matchedObjects.size());
    points2.reserve(matchedObjects.size());
    for (const auto& match : matchedObjects) {
        points1.push_back(match.points.first);
        points2.push_back(match.points.second);
    }

    // triangulatePoints allocates its 4xN output itself (one homogeneous
    // point per column). Brace-initializing it as {4, N, CV_64F} would pick
    // cv::Mat's initializer_list<int> constructor instead of preallocating.
    cv::Mat points4d;
    cv::triangulatePoints(projMat1, projMat2, points1, points2, points4d);

    // Transpose to Nx4 (a fresh continuous matrix) and view each row as one
    // 4-channel point so the homogeneous coordinate can be divided out.
    const cv::Mat points4dRows = points4d.t();
    cv::Mat points3d;
    cv::convertPointsFromHomogeneous(points4dRows.reshape(4), points3d);
    // Guarantee float depth so the elements can be read as cv::Point3f
    // whatever depth triangulatePoints chose for its output.
    points3d.convertTo(points3d, CV_32F);

    result.reserve(matchedObjects.size());
    for (std::size_t idx = 0; idx < matchedObjects.size(); ++idx) {
        const auto& pos = points3d.at<cv::Point3f>(static_cast<int>(idx));
        result.push_back(TriangulatedObject{
            matchedObjects[idx].object1, matchedObjects[idx].object2, pos});
    }

    return result;
}

// Converts triangulated positions (1st camera's coordinate system) into
// mercator positions:
// - scale each position by `scale`;
// - rotate it by cam1Pose.rotation;
// - drop the vertical (index 1) component;
// - offset it by cam1Pose.mercatorPos.
// Objects farther than MAX_DISTANCE_TO_OBJECT from the 1st camera are
// skipped.
PositionedObjects computeObjectPositions(
    const TriangulatedObjects& triangulatedObjects,
    const CameraPose& cam1Pose,
    MercatorUnits scale)
{
    // https://st.yandex-team.ru/MAPSMRC-511#5a7c678c141a11001a0f1092
    constexpr Meters MAX_DISTANCE_TO_OBJECT{40};

    // Note: the scale factor is taken from the distance between two GPS
    //       coordinates which are quite inaccurate.

    PositionedObjects result;
    for (const auto& triangulatedObject : triangulatedObjects) {
        // Parentheses, not braces. With braces, OpenCV versions that provide
        // the std::initializer_list<_Tp> constructor build a 1x1 3-channel
        // matrix, which cannot be multiplied by a 3x3 rotation. This form
        // gives a 3x1 CV_32F column.
        cv::Mat position3d
            = cv::Mat(triangulatedObject.position) * scale.value();
        // Align (rotate) position to the global coordinate system.
        // NOTE(review): assumes cam1Pose.rotation is CV_32F, since elements
        // are read back as float below.
        position3d = cam1Pose.rotation * position3d;
        // See https://docs.opencv.org/3.2.0/d9/d0c/group__calib3d.html
        position3d.at<float>(1) = 0; // project the point to the earth plane

        const Meters distance = toMeters(cam1Pose.mercatorPos,
            MercatorUnits{static_cast<float>(cv::norm(position3d))});
        if (distance > MAX_DISTANCE_TO_OBJECT) {
            continue;
        }

        const geolib3::Point2 mercatorPos{
            cam1Pose.mercatorPos.x() + position3d.at<float>(0),
            cam1Pose.mercatorPos.y() + position3d.at<float>(2)};

        result.push_back(PositionedObject{triangulatedObject.object1,
            triangulatedObject.object2, mercatorPos});

        // TODO: compute an object plane normal
    }

    return result;
}

} // anonymous namespace

PositionedObjects computeObjectPositionsWith2Images(
    cv::Mat camMat,              // a camera matrix
    cv::Mat img1,                // the 1st image
    cv::Mat img1Mask,            // the 1st image features mask
    const CameraPose& cam1Pose,  // the 1st camera geo position
    const ObjectsInfo& objects1, // the 1st image objects
    cv::Mat img2,                // the 2nd image
    cv::Mat img2Mask,            // the 2nd image features mask
    const CameraPose& cam2Pose,  // the 2nd camera geo position
    const ObjectsInfo& objects2, // the 2nd image objects
    std::size_t featuresLimit)   // min number of matched feature points
{
    REQUIRE(camMat.data && img1.data && img2.data, "Invalid input data");
    REQUIRE(img1.type() == img2.type() && img1.type() == CV_8UC1,
        "Images must be in grayscale color space");
    REQUIRE(8 <= featuresLimit,
        "May not position anything with less then eight matched points");

    // The distance between the two cameras is used as the scale of the
    // reconstruction (the recovered translation is only known up to scale).
    // https://st.yandex-team.ru/MAPSMRC-511#5a7c678c141a11001a0f1092
    const MercatorUnits distance{
        geolib3::distance(cam1Pose.mercatorPos, cam2Pose.mercatorPos)};

    const auto matchedPoints
        = matchBRISKPoints(img1, img1Mask, img2, img2Mask);

    // Compute fundamental matrix. The RANSAC approach requires at least 8
    // points to compute a fundamental matrix.
    if (matchedPoints.first.size() < featuresLimit) {
        return {};
    }

    INFO() << "Positioning objects with " << matchedPoints.first.size()
           << " feature points";
    cv::Mat inliersMask;
    const cv::Mat fundMat = cv::findFundamentalMat(
        matchedPoints.first,
        matchedPoints.second,
        cv::FM_RANSAC,          /* C++ enum; CV_FM_RANSAC is the legacy C
                                   macro */
        INLIER_DISTANCE_PIXELS, /* distance from an inlier match to its
                                   epipolar line */
        0.999f,                 /* RANSAC probability */
        inliersMask);
    // findFundamentalMat returns an empty matrix when estimation fails. The
    // matrix products in recover2ndCamPose cannot work with it.
    if (fundMat.empty()) {
        INFO() << "Failed to estimate a fundamental matrix on this pair";
        return {};
    }

    // Note: recover2ndCamPose hands inliersMask (sharing its data) to
    // cv::recoverPose, whose mask argument is input/output.
    const cv::Mat cam2LocalPose
        = recover2ndCamPose(camMat, fundMat, matchedPoints, inliersMask);
    if (!cameraTurnSanityCheck(cam2LocalPose)) {
        return {};
    }

    // TODO: Check cosine distance between inliers' responses. But first find
    //       out which value to use as a threshold to decide if images are
    //       similar enough.
    [[maybe_unused]] const auto inliers
        = extractInliers(matchedPoints, inliersMask);
    const auto matchedObjects = matchObjects(fundMat, objects1, objects2);

    if (matchedObjects.empty()) {
        INFO() << "Failed to match any of objects on this pair";
        return {};
    }

    const auto triangulatedObjects
        = triangulateObjects(camMat, cam2LocalPose, matchedObjects);

    return computeObjectPositions(triangulatedObjects, cam1Pose, distance);
}

bool isCameraMoveInAllowedRange(
    const CameraPose& pose1, const CameraPose& pose2)
{
    // TODO: check if it needs to do something smarter then just checking for
    //       thresholds of camera move and rotation angle.
    constexpr Meters CAMERA_MOVE_THRESHOLD_MAX{30};
    constexpr Meters CAMERA_MOVE_THRESHOLD_MIN{3};
    constexpr double CAMERA_THRESHOLD_15_DEG_COSINE = 0.965925;

    const MercatorUnits mercatorDistance{
        geolib3::distance(pose1.mercatorPos, pose2.mercatorPos)};

    const Meters distance = toMeters(pose1.mercatorPos, mercatorDistance);
    if (distance <= CAMERA_MOVE_THRESHOLD_MIN
        || CAMERA_MOVE_THRESHOLD_MAX <= distance) {
        return false;
    }

    const cv::Mat z{cv::Point3f{0, 0, 1}};
    cv::Mat z1 = pose1.rotation * z;
    cv::Mat z2 = pose2.rotation * z;
    const auto cos = z1.dot(z2);
    return CAMERA_THRESHOLD_15_DEG_COSINE < cos;
}

} // namespace maps::mrc::experiment_sfm_positioning
