#include <maps/wikimap/mapspro/services/mrc/libs/superpoint/include/superpoint.h>

#include <tensorflow/core/framework/tensor.h>
#include <tensorflow/core/common_runtime/dma_helper.h>

#include <opencv2/opencv.hpp>


namespace tf = tensorflow;
namespace tfi = maps::wiki::tf_inferencer;

namespace maps::mrc::superpoint {

namespace {

tfi::ImagesBatch makeGray(const tfi::ImagesBatch &imagesBatch)
{
    tfi::ImagesBatch grayImagesBatch;
    for (size_t i = 0; i < imagesBatch.size(); i++) {
        const cv::Mat& image = imagesBatch[i];
        REQUIRE(1 == image.channels() || 3 == image.channels(),
            "Invalid input image channels count");
        if (image.channels() == 1) {
            grayImagesBatch.push_back(image);
        } else {
            cv::Mat gray;
            cv::cvtColor(image, gray, cv::COLOR_BGR2GRAY);
            grayImagesBatch.push_back(gray);
        }
    }
    return grayImagesBatch;
}

const std::string TF_MODEL_RESOURCE = "/maps/mrc/keypoint/models/superpoints.gdef";

} // namespace

SuperpointDetector::SuperpointDetector()
    : inferencer_(tfi::TensorFlowInferencer::fromResource(TF_MODEL_RESOURCE))
{}

std::vector<common::Keypoints> SuperpointDetector::detect(const tfi::ImagesBatch &imagesBatch) const
{
    const int batchSize = (int)imagesBatch.size();
    REQUIRE(0 < batchSize, "Batch is empty");

    tfi::ImagesBatch grayImagesBatch = makeGray(imagesBatch);

    const int imageWidth = grayImagesBatch[0].cols;
    const int imageHeight = grayImagesBatch[0].rows;

    constexpr int featureSize = 256;
    const std::string TF_LAYER_IMAGE_NAME       = "inference_images";
    const std::string TF_LAYER_KEY_POINTS_NAME  = "inference_points:0";
    const std::string TF_LAYER_SCORES_NAME      = "inference_scores:0";
    const std::string TF_LAYER_DESCRIPTORS_NAME = "inference_descriptors:0";

    std::vector<tensorflow::Tensor> result = inferencer_.inference(
        TF_LAYER_IMAGE_NAME,
        grayImagesBatch,
        {TF_LAYER_KEY_POINTS_NAME , TF_LAYER_DESCRIPTORS_NAME, TF_LAYER_SCORES_NAME}
    );

    REQUIRE(3 == result.size(), "Invalid output tensors number");
    REQUIRE((2 == result[0].dims()) && (3 == result[0].dim_size(1)),
            "Invalid points tensor dimension");
    REQUIRE((2 == result[1].dims()) &&
            (featureSize == result[1].dim_size(1)),
            "Invalid descriptors tensor dimension");
    REQUIRE((1 == result[2].dims()),
            "Invalid scores tensor dimension");

    const int pointsCount = result[0].dim_size(0);
    REQUIRE((pointsCount == result[1].dim_size(0)) &&
            (pointsCount == result[2].dim_size(0)),
            "Different amount of points, scores and descriptors");

    const int32_t *pPoints    = static_cast<const int32_t*>(tensorflow::DMAHelper::base(&result[0]));
    const float *pDescriptors = static_cast<const float*>(tensorflow::DMAHelper::base(&result[1]));
    const float *pScores      = static_cast<const float*>(tensorflow::DMAHelper::base(&result[2]));

    REQUIRE(0 != pPoints,      "Invalid points data");
    REQUIRE(0 != pDescriptors, "Invalid descriptors data");
    REQUIRE(0 != pScores,      "Invalid scores data");

    /*
        Координаты точек придут в виде трехмерного вектора:
            (batch_idx, y, x)
        при этом почти наверняка, эти точки будут упорядочены по batch_idx,
        но гарантий я нигде таких не нашел в используемой TF операции, поэтому
        такое предположение мы использовать не будем и честно переберем все точки.
    */

    std::vector<common::Keypoints> results(batchSize);

    // заполняем вектор scores и в тоже время считаем кол-во точек для каждой
    // картинки в батче
    for (int i = 0; i < pointsCount; i++) {
        const int batchIdx = pPoints[3 * i];
        REQUIRE(0 <= batchIdx && batchIdx < batchSize, "Invalid batch index: " << batchIdx);
        results[batchIdx].scores.push_back(pScores[i]);
    }

    // создаём матрицы нужных размеров для координат точек и дескрипторов
    // и заполняем размеры изображения
    for (int i = 0; i < batchSize; i++) {
        common::Keypoints& kpts = results[i];

        const int imagePointsCount = (int)kpts.scores.size();
        kpts.imageWidth = imageWidth;
        kpts.imageHeight = imageHeight;
        kpts.coords.create(imagePointsCount, 1, CV_32FC2);
        kpts.descriptors.create(imagePointsCount, featureSize, CV_32FC1);
    }

    std::vector<size_t> lastIdx(batchSize, 0);
    for (int i = 0; i < pointsCount; i++) {
        const int batchIdx = pPoints[3 * i];

        REQUIRE(0 <= batchIdx && batchIdx < batchSize, "Invalid batch index: " << batchIdx);

        common::Keypoints& kpts = results[batchIdx];
        size_t& idx = lastIdx[batchIdx];

        kpts.coords.at<cv::Vec2f>(idx, 0) = cv::Vec2f(pPoints[3 * i + 2], pPoints[3 * i + 1]);
        std::copy_n(&pDescriptors[featureSize * i], featureSize, kpts.descriptors.ptr<float>(idx));

        idx++;
    }
    return results;
}

common::Keypoints SuperpointDetector::detect(const cv::Mat &image) const
{
    tfi::ImagesBatch batch(1, image);
    return detect(batch)[0];
}

} // namespace maps::mrc::superpoint
