#include "tf_inferencer.h"
#include "faster_rcnn_inferencer.h"

#include <maps/libs/common/include/exception.h>

#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/framework/tensor.h>
#include <tensorflow/core/common_runtime/dma_helper.h>

#include <library/cpp/resource/resource.h>
#include <opencv2/opencv.hpp>

namespace tf = tensorflow;

namespace maps {
namespace wiki {
namespace tf_inferencer {

FasterRCNNInferencer::FasterRCNNInferencer(const tensorflow::GraphDef& graphDef)
    : inferencer_(graphDef)
{}
FasterRCNNInferencer::FasterRCNNInferencer(const std::string& path)
    : inferencer_(path)
{}
FasterRCNNInferencer::FasterRCNNInferencer(TensorFlowInferencer&& inferencer)
    : inferencer_(std::move(inferencer))
{}

/// Creates a detector whose graph is taken from the named embedded resource
/// (resource lookup is delegated to TensorFlowInferencer::fromResource).
FasterRCNNInferencer FasterRCNNInferencer::fromResource(const std::string& resourceName)
{
    TensorFlowInferencer loaded = TensorFlowInferencer::fromResource(resourceName);
    return FasterRCNNInferencer(std::move(loaded));
}

/**
 * Convenience overload: runs detection on a single image.
 *
 * @param image           image to process
 * @param scoreThreshold  detections scoring below this value are dropped
 * @return detections found on @p image (same format as the batch overload)
 */
FasterRCNNResults FasterRCNNInferencer::inference(const cv::Mat &image, float scoreThreshold) const
{
    std::vector<FasterRCNNResults> batchResults = inference(ImagesBatch(1, image), scoreThreshold);
    // The batch overload returns exactly one entry per input image, so the
    // single element exists; move it out instead of copying it from the
    // temporary vector (operator[] on a temporary still yields an lvalue).
    return std::move(batchResults.front());
}

std::vector<FasterRCNNResults> FasterRCNNInferencer::inference(const ImagesBatch &batch, float scoreThreshold) const
{
    const std::string TF_LAYER_IMAGE_NAME      = "image_tensor";
    const std::string TF_LAYER_BOXES_NAME      = "detection_boxes:0";
    const std::string TF_LAYER_SCORES_NAME     = "detection_scores:0";
    const std::string TF_LAYER_CLASSID_NAME    = "detection_classes:0";

    std::vector<tf::Tensor> result =
         inferencer_.inference(TF_LAYER_IMAGE_NAME, batch,
            {
                TF_LAYER_BOXES_NAME,
                TF_LAYER_SCORES_NAME,
                TF_LAYER_CLASSID_NAME
            });

    // return 3 tensors
    //    1. boxes with shape [?, N, 4], which contains bounding boxes of detected objects
    //       in format [[[ymin1, xmin1, ymax1, xmax1], [ymin2, xmin2, ymax2, xmax2], ..., [yminN, xminN, ymaxN, xmaxN]]]
    //    2. scores with shape [?, N], which contains confidences of detected objects
    //    3. classesID with shape [?, N], which contains ID of classes (1...T). Zero - background.

    REQUIRE(3 == result.size(), "Invalid output tensors number");
    REQUIRE((3 == result[0].dims()) &&
            ((int)batch.size() == result[0].dim_size(0)) &&
            (4 == result[0].dim_size(2)),
            "Invalid boxes tensor dimension");
    REQUIRE((2 == result[1].dims()) && ((int)batch.size() == result[1].dim_size(0)),
            "Invalid scores tensor dimension");
    REQUIRE((2 == result[2].dims()) &&
            ((int)batch.size() == result[2].dim_size(0)),
            "Invalid classes ids tensor dimension");

    const size_t objCnt = result[0].dim_size(1);

    REQUIRE((objCnt == (size_t)result[1].dim_size(1)) &&
            (objCnt == (size_t)result[2].dim_size(1)),
            "Different amount of boxes, scores and classes");

    const float *pBoxes   = static_cast<const float*>(tf::DMAHelper::base(&result[0]));
    const float *pScores  = static_cast<const float*>(tf::DMAHelper::base(&result[1]));
    const float *pClasses = static_cast<const float*>(tf::DMAHelper::base(&result[2]));

    REQUIRE(0 != pBoxes,  "Invalid boxes data");
    REQUIRE(0 != pScores, "Invalid scores data");
    REQUIRE(0 != pClasses, "Invalid classes data");

    const int imgCols = batch[0].cols;
    const int imgRows = batch[0].rows;
    std::vector<FasterRCNNResults> batchObjects(batch.size());
    for (size_t imgIdx = 0; imgIdx < batch.size(); imgIdx++) {
        FasterRCNNResults &objects = batchObjects[imgIdx];
        for (size_t i = 0; i < objCnt; i++){
            if (scoreThreshold > pScores[i]) {
                continue;
            }

            const int yMin = std::min(std::max((int)(pBoxes[i * 4 + 0] * imgRows), 0), imgRows - 1);
            const int xMin = std::min(std::max((int)(pBoxes[i * 4 + 1] * imgCols), 0), imgCols - 1);
            const int yMax = std::min(std::max((int)(pBoxes[i * 4 + 2] * imgRows), 0), imgRows - 1);
            const int xMax = std::min(std::max((int)(pBoxes[i * 4 + 3] * imgCols), 0), imgCols - 1);

            const cv::Rect bbox(cv::Point(xMin, yMin), cv::Point(xMax, yMax));

            if (std::min(bbox.height, bbox.width) > 0) {
                objects.emplace_back(bbox, (int)pClasses[i], pScores[i]);
            }
        }
        pBoxes += 4 * objCnt;
        pScores += objCnt;
        pClasses += objCnt;
    }

    return batchObjects;
}

} //namespace tf_inferencer
} //namespace wiki
} //namespace maps
