#include <maps/wikimap/mapspro/services/mrc/libs/keypoints_matcher/include/mutual_nn_matcher.h>
#include <maps/wikimap/mapspro/services/mrc/libs/keypoints_matcher/include/superglue_matcher.h>

#include <library/cpp/testing/gtest/gtest.h>
#include <maps/libs/common/include/exception.h>

#include <string>

namespace maps::mrc::keypoints_matcher::tests {

namespace {
    cv::Mat fromDegreeAngles(const std::vector<float>& phi) {
        cv::Mat desc(phi.size(), 2, CV_32FC1);
        for (size_t i = 0; i < phi.size(); i++) {
            desc.at<float>(i, 0) = cos(phi[i] * M_PI / 180.);
            desc.at<float>(i, 1) = sin(phi[i] * M_PI / 180.);
        }
        return desc;
    }

    void l2NormalizeRows(cv::Mat& mat) {
        const int rows = mat.rows;
        for (int row = 0; row < rows; row++) {
            cv::normalize(mat.rowRange(row, row + 1), mat.rowRange(row, row + 1), 1., 0., cv::NORM_L2);
        }
    }

    common::Keypoints generateKeypoints(int imageWidth, int imageHeight, int featureDims) {
        constexpr int gridSize = 8;
        const int ptsCnt = (imageWidth / gridSize) * (imageHeight / gridSize);

        maps::mrc::common::Keypoints kpts;
        kpts.imageWidth = (size_t) imageWidth;
        kpts.imageHeight= (size_t) imageHeight;
        kpts.coords.create(ptsCnt, 1, CV_32FC2); // N x 2
        kpts.descriptors.create(ptsCnt, featureDims, CV_32FC1); // N x 256
        kpts.scores.resize(ptsCnt); // N
        for (int i = 0; i < ptsCnt; i++) {
            kpts.coords.at<float>(i, 1) = i / (imageWidth / gridSize) * gridSize + gridSize / 2;
            kpts.coords.at<float>(i, 0) = i % (imageWidth / gridSize) * gridSize + gridSize / 2;
        }
        cv::RNG rng(42);
        rng.fill(kpts.descriptors, cv::RNG::UNIFORM, 0, 1);
        l2NormalizeRows(kpts.descriptors);
        rng.fill(kpts.scores, cv::RNG::UNIFORM, 0, 1);
        return kpts;
    }
}

TEST(keypoints_matcher, basic)
{
    MutualNNMatcher matcher(1.0);
    common::Keypoints kpts0;
    kpts0.descriptors.create(10, 1, CV_32FC1);
    common::Keypoints kpts1;
    kpts1.descriptors.create(10, 2, CV_32FC1);
    // дескрипторы должны быть векторами одинаковой размерности
    // для точек которые матчим
    EXPECT_THROW(matcher.match(kpts0, kpts1), maps::Exception);
}

TEST(keypoints_matcher, mutual_nn_matcher)
{
    /*
        Дескрипторы размерности 2. Посколько мы работаем только с
        нормализованными дескрипторами, то это точки на единичной
        окружности и определяются углом
            1.  0; 20; 60; 90
            2. 45; 50; 120
        I. Порог 0.2. Должны сматчиться:
            <2, 1>
        II. Порог 0.6. Должны сматчиться:
            <2, 1>; <3, 2>
    */
    std::vector<float> phi0 = {0, 20, 60, 90};
    std::vector<float> phi1 = {45, 50, 120};

    common::Keypoints kpts0;
    kpts0.descriptors = fromDegreeAngles(phi0);
    common::Keypoints kpts1;
    kpts1.descriptors = fromDegreeAngles(phi1);

    MutualNNMatcher matcher1(0.1);
    MatchedPairs pairs = matcher1.match(kpts0, kpts1);
    EXPECT_EQ(pairs.size(), 0u);

    MutualNNMatcher matcher2(0.2);
    pairs = matcher2.match(kpts0, kpts1);
    EXPECT_EQ(pairs.size(), 1u);
    EXPECT_THAT(pairs, testing::UnorderedElementsAreArray({std::pair(2, 1)}));

    MutualNNMatcher matcher3(0.6);
    pairs = matcher3.match(kpts0, kpts1);
    EXPECT_EQ(pairs.size(), 2u);
    EXPECT_THAT(pairs, testing::UnorderedElementsAreArray({std::pair(2, 1), std::pair(3, 2)}));
}

TEST(keypoints_matcher, superglue_basic)
{
    SuperglueMatcher matcher(0.0);
    maps::mrc::common::Keypoints kpts0 = generateKeypoints(320, 240, 256);
    maps::mrc::common::Keypoints kpts1 = kpts0;
    EXPECT_GT(matcher.match(kpts0, kpts1).size(), 0u);
}

} // namespace maps::mrc::keypoints_matcher::tests
