#include <maps/wikimap/mapspro/services/mrc/libs/sideview_classifier/include/sideview.h>

#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/framework/tensor.h>
#include <tensorflow/core/common_runtime/dma_helper.h>

#include <opencv2/opencv.hpp>

#include <algorithm>
#include <vector>

namespace tf = tensorflow;
namespace tfi = maps::wiki::tf_inferencer;

namespace maps {
namespace mrc {
namespace sideview {

namespace {

/// Resource path of the frozen TensorFlow graph loaded by SideViewClassifier.
const std::string TF_MODEL_RESOURCE = "/maps/mrc/sideview/models/tf_model.gdef";

/// Converts a class label produced by the model into a SideViewType.
///
/// @param str class label as read from the model's class-name tensor.
/// @return SideViewType::ForwardView for "forward_view",
///         SideViewType::SideView for "side_view".
/// @throws maps::RuntimeError for any other label.
SideViewType stringToSideViewType(const std::string& str) {
    static const std::string FORWARD_VIEW_NAME = "forward_view";
    static const std::string SIDE_VIEW_NAME    = "side_view";

    if (str == FORWARD_VIEW_NAME) {
        return SideViewType::ForwardView;
    }
    if (str == SIDE_VIEW_NAME) {
        return SideViewType::SideView;
    }

    // Neither known label matched: the model and this mapping are out of sync.
    throw ::maps::RuntimeError() << "Unknown class from sideview classifier: " << str;
}

} // namespace

/// Loads the classifier graph from TF_MODEL_RESOURCE and caches its class names.
SideViewClassifier::SideViewClassifier()
    : inferencer(tfi::TensorFlowInferencer::fromResource(TF_MODEL_RESOURCE))
{
    // Fetch the label list once here so inference() can map a score index
    // straight to a class name without querying the graph again.
    this->evalClassesNames();
}

void SideViewClassifier::evalClassesNames() {
    static const std::string TF_OUTPUT_CLASS_NAMES_LAYER_NAME = "class_names:0";

    std::vector<TString> strs = tfi::tensorToVector<TString>(
        inferencer.inference(TF_OUTPUT_CLASS_NAMES_LAYER_NAME)
    );

    classesNames.reserve(strs.size());
    for(const auto& str : strs) {
        classesNames.push_back(str);
    }
}

/// Classifies a pair of images.
///
/// The images are fed to the model's "inference_input1" and "inference_input2"
/// layers. The class with the highest value in the "inference_softmax" output
/// is selected.
///
/// @param image1 image fed to the "inference_input1" layer.
/// @param image2 image fed to the "inference_input2" layer.
/// @return the predicted SideViewType together with its score (the maximum
///         value of the softmax output row).
/// @throws (via REQUIRE / stringToSideViewType) when the model output has an
///         unexpected count, shape or element type, when the winning index has
///         no corresponding class name, or when that name is not a known label.
std::pair<SideViewType, float>
SideViewClassifier::inference(const cv::Mat& image1, const cv::Mat& image2) const {
    static const std::string TF_LAYER_FIRST_IMAGE_NAME  = "inference_input1";
    static const std::string TF_LAYER_SECOND_IMAGE_NAME = "inference_input2";
    static const std::string TF_LAYER_OUTPUT_NAME       = "inference_softmax";

    std::vector< std::pair<std::string, cv::Mat> > inputLayerImages({
        {TF_LAYER_FIRST_IMAGE_NAME,  image1},
        {TF_LAYER_SECOND_IMAGE_NAME, image2},
    });

    std::vector<tf::Tensor> result = inferencer.inference(inputLayerImages, {TF_LAYER_OUTPUT_NAME});
    REQUIRE(1 == result.size(), "Invalid output tensors number");

    const tf::Tensor& scores = result[0];
    REQUIRE((2 == scores.dims()) && (1 == scores.dim_size(0)) && (2 == scores.dim_size(1)),
            "Invalid scores tensor dimension");
    // The raw buffer is reinterpreted as float below; reading a tensor of any
    // other element type through a float pointer would be undefined behavior.
    REQUIRE(tf::DT_FLOAT == scores.dtype(), "Invalid scores tensor type");

    const float* pScores = static_cast<const float*>(tf::DMAHelper::base(&scores));
    const float* pScoresEnd = pScores + scores.dim_size(1);
    // std::max_element keeps the first maximum on ties, like a strict-'<' argmax.
    const float* pMax = std::max_element(pScores, pScoresEnd);
    const size_t classID = static_cast<size_t>(pMax - pScores);

    // classesNames comes from a separate graph tensor; make sure it actually
    // covers the winning score column before indexing into it.
    REQUIRE(classID < classesNames.size(), "Class index is out of range of class names");
    return {stringToSideViewType(classesNames[classID]), *pMax};
}

} // sideview
} // mrc
} // maps
