#include "tf_inferencer.h"
#include "sem_segm_inferencer.h"

#include <algorithm>
#include <cstdint>

#include <maps/libs/common/include/exception.h>

#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/framework/tensor.h>
#include <tensorflow/core/common_runtime/dma_helper.h>

#include <library/cpp/resource/resource.h>
#include <opencv2/opencv.hpp>

namespace tf = tensorflow;

namespace maps::wiki::tf_inferencer {

SemSegmInferencer::SemSegmInferencer(const tensorflow::GraphDef& graphDef)
    : inferencer_(graphDef)
{}
SemSegmInferencer::SemSegmInferencer(const std::string& path)
    : inferencer_(path)
{}
SemSegmInferencer::SemSegmInferencer(TensorFlowInferencer&& inferencer)
    : inferencer_(std::move(inferencer))
{}

/// Creates a segmenter whose graph is loaded via TensorFlowInferencer::fromResource.
/// @param resourceName name of the resource, passed straight through to TensorFlowInferencer.
/// @return a SemSegmInferencer owning the loaded inferencer.
SemSegmInferencer SemSegmInferencer::fromResource(const std::string& resourceName)
{
    auto loadedInferencer = TensorFlowInferencer::fromResource(resourceName);
    return SemSegmInferencer(std::move(loadedInferencer));
}

/// Segments a single image by running it as a one-element batch.
/// @param image input image; it is placed (as a cv::Mat header copy) into a batch of size one.
/// @return the class-index mask produced for that image by the batch overload.
cv::Mat SemSegmInferencer::inference(const cv::Mat& image) const
{
    const ImagesBatch singleImageBatch(1, image);
    const ImagesBatch masks = inference(singleImageBatch);
    return masks[0];
}

/// Runs semantic segmentation on a batch of images.
///
/// The batch is fed to the graph input "inference_image:0", and the single
/// output "inference_logits:0" is read back. That output must be a rank-4
/// float tensor whose first dimension equals the batch size. The loop below
/// walks it as [batch, rows, cols, classes], with each pixel's class scores
/// stored contiguously.
///
/// @param batch images to segment. An empty batch yields an empty result
///        without running the graph.
/// @return one CV_8UC1 mask per input image. Each pixel holds the index of
///         the highest-scoring class; ties resolve to the lowest index. The
///         mask size comes from the network output and may differ from the
///         input image size.
/// REQUIRE fails if the graph returns an unexpected number of tensors, a
/// wrong rank, batch size, dtype, or class count, or a null data buffer.
ImagesBatch SemSegmInferencer::inference(const ImagesBatch &batch) const
{
    // Nothing to segment: don't feed an empty batch to the graph.
    if (batch.empty()) {
        return ImagesBatch();
    }

    const std::string TF_LAYER_IMAGE_NAME  = "inference_image:0";
    const std::string TF_LAYER_LOGITS_NAME = "inference_logits:0";

    std::vector<tf::Tensor> outputs =
        inferencer_.inference(TF_LAYER_IMAGE_NAME, batch, {TF_LAYER_LOGITS_NAME});

    REQUIRE(1 == outputs.size(), "Invalid output tensors number");
    const tf::Tensor& logits = outputs[0];

    REQUIRE((4 == logits.dims()) &&
            (static_cast<int64_t>(batch.size()) == logits.dim_size(0)),
            "Invalid output tensor dimension");
    // The raw buffer is reinterpreted as float below, so the dtype must match.
    REQUIRE(tf::DT_FLOAT == logits.dtype(), "Invalid output tensor type");

    // The class index is stored in a uchar, hence the upper bound.
    REQUIRE(255 >= logits.dim_size(3), "Too many classes in semantic segmentation");
    // With zero classes there would be no score to pick the argmax from.
    REQUIRE(0 < logits.dim_size(3), "No classes in semantic segmentation output");

    const float* pScores = static_cast<const float*>(tf::DMAHelper::base(&logits));
    REQUIRE(nullptr != pScores, "Invalid scores data");

    // The segmentation may have a different size than the source image.
    const int segmRows    = static_cast<int>(logits.dim_size(1));
    const int segmCols    = static_cast<int>(logits.dim_size(2));
    const int segmClasses = static_cast<int>(logits.dim_size(3));

    ImagesBatch masks(batch.size());
    for (cv::Mat& mask : masks) {
        mask.create(segmRows, segmCols, CV_8UC1);
        for (int segmRow = 0; segmRow < segmRows; ++segmRow) {
            uchar* maskRow = mask.ptr<uchar>(segmRow);
            for (int segmCol = 0; segmCol < segmCols; ++segmCol) {
                // std::max_element returns the FIRST maximum, so ties go to the lowest class index.
                const float* best = std::max_element(pScores, pScores + segmClasses);
                maskRow[segmCol] = static_cast<uchar>(best - pScores);
                // Scores are consumed sequentially: advance to the next pixel's class scores.
                pScores += segmClasses;
            }
        }
    }

    return masks;
}

} //namespace maps::wiki::tf_inferencer
