#include <maps/wikimap/mapspro/services/mrc/libs/keypoints_matcher/include/mutual_nn_matcher.h>

#include <maps/wikimap/mapspro/services/mrc/libs/common/include/keypoints.h>

#include <maps/libs/common/include/exception.h>

#include <utility>
#include <vector>

namespace maps::mrc::keypoints_matcher {

MutualNNMatcher::MutualNNMatcher(double threshold)
    : threshold_(threshold)
{}

MatchedPairs MutualNNMatcher::match(const common::Keypoints &kpts0, const common::Keypoints &kpts1) const {
    REQUIRE(kpts0.descriptors.cols == kpts1.descriptors.cols, "Unable match keypoints with different descriptors dims");

    const cv::Mat& desc0 = kpts0.descriptors;
    const cv::Mat& desc1 = kpts1.descriptors;

    const int N = desc0.rows;
    const int M = desc1.rows;

    cv::Mat dist;
    // поскольку принимаем на вход вектора нормализованные к 1
    // то по идеи их скалярное произведение должно по модулю
    // быть меньше 1, дабы не получить проблем с флотовой арифметикой
    // на всякий случай дополнительно делам cv::min
    cv::sqrt(2.f - 2.f * cv::min(desc0 * desc1.t(), 1.0), dist);

    const int rows = dist.rows;
    const int cols = dist.cols;
    REQUIRE(N == rows && M == cols, "Invalid size of distance matrix");

    std::vector<int> rowMinIdx(rows, -1); // индекс (колонка) минимума для строки
    std::vector<int> colMinIdx(cols, -1); // индекс (строка) минимума для колонки
    std::vector<float> colMin(cols, FLT_MAX);
    for (int row = 0; row < rows; row++) {
        float* ptr = dist.ptr<float>(row);
        float rowMin = FLT_MAX;
        for (int col = 0; col < cols; col++) {
            if (ptr[col] > threshold_) {
                continue;
            }
            if (ptr[col] < rowMin) {
                rowMinIdx[row] = col;
                rowMin = ptr[col];
            }
            if (ptr[col] < colMin[col]) {
                colMinIdx[col] = row;
                colMin[col] = ptr[col];
            }
        }
    }

    MatchedPairs result;
    for (size_t idx0 = 0; idx0 < rowMinIdx.size(); idx0++) {
        if (-1 == rowMinIdx[idx0]) {
            continue;
        }
        if (colMinIdx[rowMinIdx[idx0]] == (int)idx0) {
            result.push_back({idx0, rowMinIdx[idx0]});
        }
    }
    return result;
}

} // maps::mrc::keypoints_matcher
