#include <maps/wikimap/mapspro/services/mrc/libs/keypoints_matcher/include/superglue_matcher.h>

#include <maps/wikimap/mapspro/services/mrc/libs/common/include/keypoints.h>

#include <maps/libs/common/include/exception.h>

#include <tensorflow/core/common_runtime/dma_helper.h>

#include <utility>
#include <vector>

namespace tf = tensorflow;
namespace tfi = maps::wiki::tf_inferencer;

namespace maps::mrc::keypoints_matcher {
namespace {

// Resource path of the SuperGlue TensorFlow model graph; it is loaded by the
// SuperglueMatcher constructor through TensorFlowInferencer::fromResource.
const std::string TF_MODEL_RESOURCE{"/maps/mrc/superglue_matcher/models/superglue.gdef"};

} // namespace


// Keeps the confidence threshold that match() feeds to the graph and loads
// the SuperGlue model from the TF_MODEL_RESOURCE resource.
SuperglueMatcher::SuperglueMatcher(double confThreshold)
    : confThreshold_(confThreshold),
      inferencer_(tfi::TensorFlowInferencer::fromResource(TF_MODEL_RESOURCE))
{
}

MatchedPairs SuperglueMatcher::match(const common::Keypoints &kpts0, const common::Keypoints &kpts1) const {
    REQUIRE(kpts0.descriptors.cols == kpts1.descriptors.cols, "Unable match keypoints with different descriptors dims");

    const std::string TF_LAYER_POINTS_0_NAME  = "eval_points0";
    const std::string TF_LAYER_POINTS_1_NAME  = "eval_points1";

    const std::string TF_LAYER_SCORES_0_NAME  = "eval_scores0";
    const std::string TF_LAYER_SCORES_1_NAME  = "eval_scores1";

    const std::string TF_LAYER_DESCS_0_NAME   = "eval_descs0";
    const std::string TF_LAYER_DESCS_1_NAME   = "eval_descs1";

    const std::string TF_LAYER_IMG_SZ_0_NAME  = "eval_img_sz0";
    const std::string TF_LAYER_IMG_SZ_1_NAME  = "eval_img_sz1";

    const std::string TF_LAYER_THR_NAME       = "eval_thr";

    const std::string TF_LAYER_INDICES_0_NAME = "eval_indices0:0";

    tf::Tensor thresholdT(tf::DT_FLOAT, tf::TensorShape());
    *thresholdT.flat<float>().data() = confThreshold_;

    tf::Tensor imageSize0(tf::DT_FLOAT, tf::TensorShape({1, 2}));
    float* ptr = imageSize0.flat<float>().data();
    ptr[0] = (float)kpts0.imageWidth;
    ptr[1] = (float)kpts0.imageHeight;

    tf::Tensor imageSize1(tf::DT_FLOAT, tf::TensorShape({1, 2}));
    ptr = imageSize1.flat<float>().data();
    ptr[0] = (float)kpts1.imageWidth;
    ptr[1] = (float)kpts1.imageHeight;

    std::vector<tf::Tensor> result = inferencer_.inference(
        {
          {TF_LAYER_THR_NAME, thresholdT},
          {TF_LAYER_POINTS_0_NAME, maps::wiki::tf_inferencer::cvMatToTensor(kpts0.coords, true, false, false)},
          {TF_LAYER_SCORES_0_NAME, maps::wiki::tf_inferencer::cvMatToTensor(cv::Mat(kpts0.scores), true, false, false)},
          {TF_LAYER_DESCS_0_NAME, maps::wiki::tf_inferencer::cvMatToTensor(kpts0.descriptors, true, false, false)},
          {TF_LAYER_IMG_SZ_0_NAME, imageSize0},

          {TF_LAYER_POINTS_1_NAME, maps::wiki::tf_inferencer::cvMatToTensor(kpts1.coords, true, false, false)},
          {TF_LAYER_SCORES_1_NAME, maps::wiki::tf_inferencer::cvMatToTensor(cv::Mat(kpts1.scores), true, false, false)},
          {TF_LAYER_DESCS_1_NAME, maps::wiki::tf_inferencer::cvMatToTensor(kpts1.descriptors, true, false, false)},
          {TF_LAYER_IMG_SZ_1_NAME, imageSize1}
        },
        {TF_LAYER_INDICES_0_NAME}
    );
    REQUIRE(1 == result.size(), "Invalid output tensors number");
    REQUIRE((2 == result[0].dims()) && (1 == result[0].dim_size(0)),
            "Invalid indices0 tensor dimension");

    size_t pairsCount = result[0].dim_size(1);
    const int32_t *pIndices0 = static_cast<const int32_t*>(tf::DMAHelper::base(&result[0]));
    MatchedPairs pairs;
    for (size_t i = 0; i < pairsCount; i++) {
        if (pIndices0[i] == -1) {
            continue;
        }
        pairs.emplace_back(i, pIndices0[i]);
    }
    return pairs;
}

} // maps::mrc::keypoints_matcher
