#include "tf_inferencer.h"
#include "maskrcnn_inferencer.h"

#include <maps/libs/common/include/exception.h>

#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/framework/tensor.h>
#include <tensorflow/core/common_runtime/dma_helper.h>

#include <library/cpp/resource/resource.h>
#include <opencv2/opencv.hpp>

#include <library/cpp/resource/resource.h>

namespace tf = tensorflow;

namespace maps {
namespace wiki {
namespace tf_inferencer {

// Builds the inferencer from an in-memory TensorFlow graph definition;
// the graph is handed unchanged to the wrapped TensorFlowInferencer.
MaskRCNNInferencer::MaskRCNNInferencer(const tensorflow::GraphDef& graphDef) : inferencer_(graphDef)
{
}
// Builds the inferencer from `path`, which is forwarded to the TensorFlowInferencer
// constructor (presumably the location of a serialized graph — see that class).
MaskRCNNInferencer::MaskRCNNInferencer(const std::string& path) : inferencer_(path)
{
}
// Takes ownership of an already constructed TensorFlowInferencer (moved in).
MaskRCNNInferencer::MaskRCNNInferencer(TensorFlowInferencer&& inferencer) : inferencer_(std::move(inferencer))
{
}

// Factory: loads the graph through TensorFlowInferencer::fromResource(resourceName)
// and wraps the resulting inferencer.
MaskRCNNInferencer MaskRCNNInferencer::fromResource(const std::string& resourceName)
{
    auto loaded = TensorFlowInferencer::fromResource(resourceName);
    return MaskRCNNInferencer(std::move(loaded));
}

// Single-image convenience overload: wraps `image` into a one-element batch,
// runs the batch overload and returns the detections of that only element.
MaskRCNNResults MaskRCNNInferencer::inference(const cv::Mat &image, float scoreThreshold) const
{
    const ImagesBatch singleImage(1, image);
    const std::vector<MaskRCNNResults> perImage = inference(singleImage, scoreThreshold);
    return perImage.front();
}

// Runs the Mask R-CNN graph on a batch of images and converts its raw outputs into
// per-image lists of detections.
//
// Parameters:
//   batch          - images to process; element i of the result belongs to batch[i].
//   scoreThreshold - detections whose score is below this value are dropped.
// Returns one MaskRCNNResults per image. Each detection is built from its float mask
// resized to the bounding box, the bounding box in pixel coordinates of its own image,
// and the class id.
// The layout of every output tensor is validated with REQUIRE before it is read.
std::vector<MaskRCNNResults> MaskRCNNInferencer::inference(const ImagesBatch& batch, float scoreThreshold) const {
    const std::string TF_LAYER_IMAGE_NAME      = "image_tensor";
    const std::string TF_LAYER_BOXES_NAME      = "detection_boxes:0";
    const std::string TF_LAYER_SCORES_NAME     = "detection_scores:0";
    const std::string TF_LAYER_MASKS_NAME      = "detection_masks:0";
    const std::string TF_LAYER_CLASSID_NAME    = "detection_classes:0";

    std::vector<tf::Tensor> result =
         inferencer_.inference(TF_LAYER_IMAGE_NAME, batch,
            {
                TF_LAYER_BOXES_NAME,
                TF_LAYER_SCORES_NAME,
                TF_LAYER_MASKS_NAME,
                TF_LAYER_CLASSID_NAME
            });

    // Expected output: 4 tensors, in the order requested above (N = detections per image):
    //    1. boxes with shape [batch, N, 4], each row being [ymin, xmin, ymax, xmax];
    //       the values are multiplied by the image size below, so they are expected
    //       to be normalized coordinates
    //    2. scores with shape [batch, N], confidences of detected objects
    //    3. masks with shape [batch, N, hm, wm], one hm x wm mask per detected object
    //    4. classes with shape [batch, N], class ids (1...T); zero is background
    // All four buffers are read as float.

    const int batchSize = batch.size();

    REQUIRE(4 == result.size(), "Invalid output tensors number");
    REQUIRE((3 == result[0].dims()) &&
            (batchSize == result[0].dim_size(0)) &&
            (4 == result[0].dim_size(2)),
            "Invalid boxes tensor dimension");
    REQUIRE((2 == result[1].dims()) && (batchSize == result[1].dim_size(0)),
            "Invalid scores tensor dimension");
    REQUIRE((4 == result[2].dims()) &&
            (batchSize == result[2].dim_size(0)),
            "Invalid masks tensor dimension");
    REQUIRE((2 == result[3].dims()) &&
            (batchSize == result[3].dim_size(0)),
            "Invalid classes ids tensor dimension");

    const size_t objCnt = result[0].dim_size(1);

    REQUIRE((objCnt == (size_t)result[1].dim_size(1)) &&
            (objCnt == (size_t)result[2].dim_size(1)),
            "Different amount of boxes, scores and masks");
    // The classes buffer is indexed with the same per-image stride as the boxes,
    // so its detection count must match as well.
    REQUIRE(objCnt == (size_t)result[3].dim_size(1),
            "Different amount of boxes and classes");

    const int maskRows = result[2].dim_size(2);
    const int maskCols = result[2].dim_size(3);
    const int maskShift = maskRows * maskCols;

    const float *pBoxes   = static_cast<const float*>(tf::DMAHelper::base(&result[0]));
    const float *pScores  = static_cast<const float*>(tf::DMAHelper::base(&result[1]));
    const float *pMasks   = static_cast<const float*>(tf::DMAHelper::base(&result[2]));
    const float *pClasses = static_cast<const float*>(tf::DMAHelper::base(&result[3]));

    REQUIRE(nullptr != pBoxes,   "Invalid boxes data");
    REQUIRE(nullptr != pScores,  "Invalid scores data");
    REQUIRE(nullptr != pMasks,   "Invalid masks data");
    REQUIRE(nullptr != pClasses, "Invalid classes data");

    // Scales a normalized coordinate by `extent` and clamps it to [0, extent].
    // The upper bound is `extent`, not `extent - 1`: cv::Rect(tl, br) excludes the br
    // corner, so a box reaching the image border must still cover the last row/column.
    const auto toPixel = [](float normalized, int extent) {
        return std::min(std::max(static_cast<int>(normalized * extent), 0), extent);
    };

    std::vector<MaskRCNNResults> batchObjects(batchSize);
    for (int imgIdx = 0; imgIdx < batchSize; imgIdx++) {
        // Scale by this image's own size (batch[0] would be out of range for an
        // empty batch and wrong if the images differ in size).
        const int imgCols = batch[imgIdx].cols;
        const int imgRows = batch[imgIdx].rows;
        MaskRCNNResults& objects = batchObjects[imgIdx];
        for (size_t i = 0; i < objCnt; i++) {
            if (scoreThreshold > pScores[i])
                continue;

            const int yMin = toPixel(pBoxes[i * 4 + 0], imgRows);
            const int xMin = toPixel(pBoxes[i * 4 + 1], imgCols);
            const int yMax = toPixel(pBoxes[i * 4 + 2], imgRows);
            const int xMax = toPixel(pBoxes[i * 4 + 3], imgCols);

            const cv::Rect bbox(cv::Point(xMin, yMin), cv::Point(xMax, yMax));

            if (std::min(bbox.height, bbox.width) <= 0) {
                continue;
            }

            // Wrap the tensor memory without copying: cv::resize only reads from it and
            // writes into a freshly allocated matrix, so rszMask never aliases `result`.
            const cv::Mat rawMask(maskRows, maskCols, CV_32FC1,
                                  const_cast<float*>(pMasks + i * maskShift));

            cv::Mat rszMask;
            cv::resize(rawMask, rszMask, bbox.size());

            objects.emplace_back(rszMask, bbox, static_cast<int>(pClasses[i]));
        }

        // Advance every buffer to the next image of the batch.
        pBoxes += 4 * objCnt;
        pScores += objCnt;
        pClasses += objCnt;
        pMasks += maskShift * objCnt;
    }
    return batchObjects;
}

// Single-image convenience overload: segments `inputImage` as a one-element batch
// and returns that image's binary mask.
cv::Mat MaskRCNNInferencer::segment(const cv::Mat &inputImage, float scoreThreshold, float maskThreshold) const {
    const ImagesBatch singleImage(1, inputImage);
    const std::vector<cv::Mat> masks = segment(singleImage, scoreThreshold, maskThreshold);
    return masks.front();
}

std::vector<cv::Mat> MaskRCNNInferencer::segment(
    const ImagesBatch& inputImagesBatch,
    float scoreThreshold,
    float maskThreshold) const
{
    std::vector<MaskRCNNResults> batchResults = inference(inputImagesBatch, scoreThreshold);

    std::vector<cv::Mat> batchMasks(batchResults.size());
    for (size_t i = 0; i < batchResults.size(); i++) {
        cv::Mat& mask = batchMasks[i];

        mask = cv::Mat::zeros(inputImagesBatch[i].size(), CV_8UC1);
        for (const auto& object : batchResults[i]) {
            mask(object.bbox) = cv::max(mask(object.bbox), object.mask > maskThreshold);
        }
    }

    return batchMasks;
}

} //namespace tf_inferencer
} //namespace wiki
} //namespace maps
