#pragma once
#include "tf_inferencer.h"

#include <maps/libs/common/include/exception.h>

#include <opencv2/opencv.hpp>
#include <tensorflow/core/public/session.h>

#include <list>
#include <string>
#include <vector>

namespace maps {
namespace wiki {
namespace tf_inferencer {

/// A single object reported by MaskRCNNInferencer::inference().
struct MaskRCNNResult {
    cv::Mat  mask;    ///< CV_32FC1 mask, pixel value will be in range [0.0, 1.0]
                      ///< NOTE(review): whether it covers the bbox or the whole image is not visible here — confirm.
    cv::Rect bbox;    ///< Bounding rectangle of the object (presumably in input-image pixel coordinates — confirm).
    int      classID; ///< Identifier of the object's class.

    /// @param _mask     value stored in `mask` (copied; cv::Mat copies share pixel data)
    /// @param _bbox     value stored in `bbox`
    /// @param _classID  value stored in `classID`
    MaskRCNNResult(const cv::Mat& _mask, const cv::Rect& _bbox, int _classID)
        : mask(_mask), bbox(_bbox), classID(_classID)
    {
    }
};

/// All objects detected in one image, as returned by MaskRCNNInferencer::inference().
/// The container stays std::list for API compatibility; <list> is included explicitly
/// by this header rather than relied upon transitively.
using MaskRCNNResults = std::list<MaskRCNNResult>;

/// Runs a Mask R-CNN network, held in a TensorFlowInferencer, on images.
class MaskRCNNInferencer {
public:
    // ---- construction -------------------------------------------------------

    /// @brief Load the network from a file.
    /// @param path  path to the TensorFlow model in protobuf format
    MaskRCNNInferencer(const std::string& path);

    /// @brief Build the network from a TensorFlow graph definition.
    /// @param graphDef  graph describing the network
    MaskRCNNInferencer(const tensorflow::GraphDef& graphDef);

    /// @brief Wrap an existing TensorFlowInferencer.
    /// @param inferencer  inferencer to take over (rvalue; presumably moved into
    ///                    the wrapped member — confirm in the implementation)
    MaskRCNNInferencer(TensorFlowInferencer&& inferencer);

    /// @brief Create an inferencer from a model identified by a resource name.
    /// NOTE(review): how the resource is located is defined in the implementation — confirm.
    /// @param resourceName  name of the resource holding the model
    static MaskRCNNInferencer fromResource(const std::string& resourceName);

    // ---- detection ----------------------------------------------------------

    /// @brief Run the network on one image and collect the detected objects.
    /// @param inputImage      image to process
    /// @param scoreThreshold  minimal score of a detected object to add to results
    ///                        (presumably the same meaning as in segment())
    /// @return detected objects, each with its mask, bounding box and class ID
    MaskRCNNResults inference(const cv::Mat& inputImage, float scoreThreshold) const;

    /// @brief Batched variant of the single-image inference().
    /// @param inputImagesBatch  images to process
    /// @param scoreThreshold    minimal score of a detected object to add to results
    /// @return a vector of per-image results; presumably one entry per image of
    ///         inputImagesBatch, in the same order — confirm
    std::vector<MaskRCNNResults>
    inference(const ImagesBatch& inputImagesBatch, float scoreThreshold) const;

    // ---- segmentation -------------------------------------------------------

    /**
    * @brief Launch inference of the network loaded in the constructor
    *
    * @param inputImage      image for segmentation
    * @param scoreThreshold  minimal score of a detected object to add to results
    * @param maskThreshold   minimal mask value a pixel needs to be added to the
    *                        result mask
    *
    * @return               mask of the input image: a pixel is zero if it does not
    *                       belong to any object and non-zero otherwise
    */
    cv::Mat segment(const cv::Mat& inputImage, float scoreThreshold, float maskThreshold) const;

    /// @brief Batched variant of the single-image segment(); same parameter meanings.
    /// @return a vector of masks; presumably one per image of inputImagesBatch,
    ///         in the same order — confirm
    std::vector<cv::Mat>
    segment(const ImagesBatch& inputImagesBatch, float scoreThreshold, float maskThreshold) const;

private:
    TensorFlowInferencer inferencer_; ///< underlying TensorFlow session wrapper
};


} //namespace tf_inferencer
} //namespace wiki
} //namespace maps
