#include "tf_inferencer.h"

#include <maps/libs/common/include/exception.h>

#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/framework/tensor.h>
#include <tensorflow/core/common_runtime/dma_helper.h>

#include <library/cpp/resource/resource.h>
#include <opencv2/opencv.hpp>

#include <library/cpp/resource/resource.h>

namespace tf = tensorflow;

namespace maps {
namespace wiki {
namespace tf_inferencer {

namespace {

/// Packs a batch of images into a single DT_UINT8 tensor shaped
/// (batch size, rows, cols, channels).
///
/// REQUIRE fails when the batch is empty, when an image is neither CV_8UC1
/// nor CV_8UC3, or when an image differs from the first one in size or
/// channel count.
tf::Tensor batchToTensor(const ImagesBatch& batch) {
    REQUIRE(0 < batch.size(),
            "Unable to convert empty batch to tensor");

    // The first image defines the layout every other image has to match.
    const cv::Mat& reference = batch[0];
    const int cols = reference.cols;
    const int rows = reference.rows;
    const int channels = reference.channels();

    // Validate the whole batch before the tensor is allocated.
    for (const cv::Mat& image : batch) {
        REQUIRE(image.type() == CV_8UC1 || image.type() == CV_8UC3,
                "Unable to convert image with types except CV_8UC1 and CV_8UC3");
        REQUIRE(image.cols == cols && image.rows == rows &&
                image.channels() == channels,
                "Unable to convert batch with different images sizes or amount of channels");
    }

    tf::Tensor output(tf::DataType::DT_UINT8,
                      tf::TensorShape({static_cast<int>(batch.size()), rows, cols, channels}));

    const int rowLength = cols * channels;
    const int imageLength = cols * rows * channels;

    auto destination = output.flat<uint8_t>().data();
    for (const cv::Mat& image : batch) {
        if (image.isContinuous()) {
            // All pixels sit in one contiguous buffer: copy it in one go.
            std::copy_n(image.data, imageLength, destination);
        } else {
            // Rows are not adjacent in memory: copy them one by one.
            auto rowDestination = destination;
            for (int row = 0; row < rows; ++row, rowDestination += rowLength) {
                std::copy_n(image.ptr<uint8_t>(row), rowLength, rowDestination);
            }
        }
        destination += imageLength;
    }
    return output;
}

/// Parses a serialized GraphDef out of the resource with the given name.
///
/// REQUIRE fails with "failed to load model" when the resource content
/// cannot be parsed.
tf::GraphDef loadGdefFromResource(const std::string& resourceName) {
    const auto serializedGraph = NResource::Find(resourceName);

    tf::GraphDef result;
    const bool parsed = result.ParseFromString(serializedGraph);
    REQUIRE(parsed, "failed to load model");
    return result;
}

/// Reads a binary-serialized GraphDef from the file at `path`.
///
/// REQUIRE fails when reading does not succeed; the message contains both
/// the path and the status returned by tf::ReadBinaryProto, so the actual
/// reason (missing file, malformed proto, ...) is not lost.
tf::GraphDef loadGdefFromFile(const std::string& path) {
    tf::GraphDef graphDef;
    const auto loadGraphStatus
        = tf::ReadBinaryProto(tf::Env::Default(), TString(path), &graphDef);
    REQUIRE(loadGraphStatus.ok(),
            "Failed to load graph at '" << path << "': " << loadGraphStatus);
    return graphDef;
}

/// Copies every element of `mat` into the flat buffer of `tensor`, treating
/// each channel value as a `T`.
///
/// REQUIRE fails when sizeof(T) times the channel count does not equal the
/// matrix element size. The tensor is expected to be already allocated with
/// room for rows * cols * channels values of type T; this is not checked here.
template<class T>
void copyMatDataToTensor(const cv::Mat &mat, tf::Tensor &tensor) {
    const int channels = mat.channels();
    REQUIRE(sizeof(T) * channels == mat.elemSize(), "Size of matrix elements " << mat.elemSize() << " is different from tensor data " << sizeof(T) * channels);

    const int rowLength = mat.cols * channels;
    T* destination = tensor.flat<T>().data();

    if (!mat.isContinuous()) {
        // Rows are not adjacent in memory: copy them one by one.
        for (int row = 0; row < mat.rows; ++row, destination += rowLength) {
            std::copy_n(mat.ptr<T>(row), rowLength, destination);
        }
        return;
    }

    // Contiguous storage: a single copy covers the whole matrix.
    std::copy_n(reinterpret_cast<const T*>(mat.data), mat.cols * mat.rows * channels, destination);
}

/// Builds the tensor shape that corresponds to a 2D matrix.
///
/// Dimensions are appended in the order: batch (size 1, only when
/// `addBatchDimension` is set), rows, columns, channels. A column or channel
/// dimension of size 1 is omitted unless `singleColumnAsDim` or
/// `singleChannelAsDim` respectively asks to keep it.
tf::TensorShape makeTensorShape(const cv::Mat& mat, bool addBatchDimension = true, bool singleColumnAsDim = false, bool singleChannelAsDim = false) {
    const bool keepColumns = singleColumnAsDim || mat.cols != 1;
    const bool keepChannels = singleChannelAsDim || mat.channels() != 1;

    tf::TensorShape result;
    if (addBatchDimension) {
        result.AddDim(1);
    }
    result.AddDim(mat.rows);
    if (keepColumns) {
        result.AddDim(mat.cols);
    }
    if (keepChannels) {
        result.AddDim(mat.channels());
    }
    return result;
}

} // anonymous namespace

/// Copies tensor data into a newly allocated OpenCV matrix.
///
/// The tensor must have between 1 and 4 dimensions. Dimension 1 is used as
/// the row count, dimension 2 as the column count and dimension 3 as the
/// channel count; each missing dimension defaults to 1. The size of
/// dimension 0 is not read, so only the first rows * cols * channels elements
/// of the tensor buffer end up in the result.
///
/// REQUIRE fails for an unsupported dimensions count, for a channel count
/// that cannot be encoded in an OpenCV matrix type, and for an unsupported
/// tensor data type.
cv::Mat tensorToImage(const tf::Tensor& tensor) {
    REQUIRE(1 <= tensor.dims() && tensor.dims() <= 4,
            "Tensor has unacceptable dimensions count: " << tensor.dims());
    const int rows = tensor.dims() > 1 ? tensor.dim_size(1) : 1;
    const int cols = tensor.dims() > 2 ? tensor.dim_size(2) : 1;
    const int cn   = tensor.dims() > 3 ? tensor.dim_size(3) : 1;

    // CV_MAKETYPE can only encode 1..CV_CN_MAX channels. Anything outside of
    // that range would produce a matrix type with a different channel count
    // (or shift a negative value for cn == 0), so reject it explicitly.
    REQUIRE(1 <= cn && cn <= CV_CN_MAX,
            "Unable to convert tensor with channels count: " << cn);

    int depth = 0;
    switch (tensor.dtype()) {
    case tf::DT_FLOAT:
        depth = CV_32F;
        break;
    case tf::DT_DOUBLE:
        depth = CV_64F;
        break;
    case tf::DT_INT32:
    //case tf::DT_UINT32:
        depth = CV_32S;
        break;
    case tf::DT_UINT16:
        depth = CV_16U;
        break;
    case tf::DT_UINT8:
        depth = CV_8U;
        break;
    case tf::DT_INT16:
        depth = CV_16S;
        break;
    case tf::DT_INT8:
        depth = CV_8S;
        break;
    case tf::DT_INT64:
        // NOTE(review): the 64-bit integers are stored under the CV_64F depth
        // without any value conversion (the raw bytes are copied below), so
        // reading the result as doubles does not give the integer values.
        // Presumably callers reinterpret the data as int64 - confirm.
        depth = CV_64F;
        break;
    default:
        REQUIRE(false, "Unable to convert tensor with data type: " << tensor.dtype());
    }
    const int type = CV_MAKETYPE(depth, cn);

    // Wrap the tensor buffer without copying, then clone so that the returned
    // matrix owns its data and does not depend on the tensor lifetime.
    void *tensorData = const_cast<void*>(tf::DMAHelper::base(&tensor));
    return cv::Mat(rows, cols, type, tensorData).clone();
}

namespace {

/// Allocates a tensor with the given data type and shape and fills it with
/// the elements of `mat` read as values of type T.
template<class T>
tf::Tensor filledTensor(tf::DataType dataType, const tf::TensorShape& shape, const cv::Mat& mat) {
    tf::Tensor tensor(dataType, shape);
    copyMatDataToTensor<T>(mat, tensor);
    return tensor;
}

} // anonymous namespace

/// Converts a 2D OpenCV matrix into a tensor with the matching element type.
///
/// The tensor shape is produced by makeTensorShape() from the matrix and the
/// three flags. Supported depths are CV_8U, CV_8S, CV_16U, CV_16S, CV_32S,
/// CV_32F and CV_64F.
///
/// REQUIRE fails when the matrix is not two-dimensional or has any other
/// depth.
tf::Tensor cvMatToTensor(const cv::Mat& mat, bool addBatchDimension, bool singleColumnAsDim, bool singleChannelAsDim) {
    REQUIRE(2 == mat.dims,
            "Supported matrix with 2 dimensions only not with: " << mat.dims);

    const tf::TensorShape shape = makeTensorShape(mat, addBatchDimension, singleColumnAsDim, singleChannelAsDim);

    switch (mat.depth()) {
    case CV_8U:
        return filledTensor<uint8_t>(tf::DT_UINT8, shape, mat);
    case CV_8S:
        return filledTensor<int8_t>(tf::DT_INT8, shape, mat);
    case CV_16U:
        return filledTensor<uint16_t>(tf::DT_UINT16, shape, mat);
    case CV_16S:
        return filledTensor<int16_t>(tf::DT_INT16, shape, mat);
    case CV_32S:
        return filledTensor<int32_t>(tf::DT_INT32, shape, mat);
    case CV_32F:
        return filledTensor<float>(tf::DT_FLOAT, shape, mat);
    case CV_64F:
        return filledTensor<double>(tf::DT_DOUBLE, shape, mat);
    default:
        REQUIRE(false, "Unable to convert cvMat with data type: " << mat.type());
    }
    // Only reached if the REQUIRE above returns; mirrors an untouched
    // default-constructed tensor.
    return tf::Tensor();
}

/// Creates a TensorFlow session with default options and loads `graphDef`
/// into it.
///
/// REQUIRE fails when the session cannot be created or when the graph cannot
/// be loaded into the session.
TensorFlowInferencer::TensorFlowInferencer(const tensorflow::GraphDef& graphDef) {
    tf::Session* rawSession = nullptr;
    const auto createStatus = tf::NewSession(tf::SessionOptions(), &rawSession);
    REQUIRE(createStatus.ok(),
            "Failed to create session: " << createStatus);
    REQUIRE(rawSession != nullptr, "Failed to create session");

    // From here on the member owns the session.
    session_.reset(rawSession);

    const auto initStatus = session_->Create(graphDef);
    REQUIRE(initStatus.ok(),
            "Failed to initialize session: " << initStatus);
}

/// Creates an inferencer from the graph definition read from the file at
/// `path` (see loadGdefFromFile), delegating to the GraphDef constructor.
TensorFlowInferencer::TensorFlowInferencer(const std::string& path)
    : TensorFlowInferencer(loadGdefFromFile(path))
{
}

/// Creates an inferencer from the graph definition stored in the resource
/// with the given name (see loadGdefFromResource).
TensorFlowInferencer TensorFlowInferencer::fromResource(const std::string& resourceName)
{
    const tf::GraphDef graphDef = loadGdefFromResource(resourceName);
    return TensorFlowInferencer(graphDef);
}

/// Runs the session on a single image and returns the single requested
/// output tensor.
///
/// The image is converted with batchToTensor() as a batch of one and passed
/// to the session under `inputLayerName`; `outputLayerName` is the only
/// requested output.
///
/// REQUIRE fails when a layer name is empty, the image is empty, the session
/// run reports a non-ok status, or the number of returned tensors is not 1.
tf::Tensor TensorFlowInferencer::inference(const std::string &inputLayerName,
                                        const cv::Mat &inputImage,
                                        const std::string &outputLayerName) const {
    REQUIRE(!inputLayerName.empty(), "Name of the input layer undefined");
    REQUIRE(!inputImage.empty(), "Input image is undefined");
    REQUIRE(!outputLayerName.empty(), "Name of the output layer undefined");

    std::vector<std::pair<TString, tf::Tensor>> feeds;
    feeds.emplace_back(TString(inputLayerName), batchToTensor({inputImage}));

    const std::vector<TString> fetches(1, TString(outputLayerName));

    std::vector<tf::Tensor> outputs;
    const auto status = session_->Run(feeds, fetches, {}, &outputs);
    REQUIRE(status.ok(), "Running model failed: " << status);

    REQUIRE(1 == outputs.size(), "Wrong amount of outputs: " << outputs.size());
    return outputs.front();
}

/// Runs the session on an already prepared input tensor and returns one
/// tensor per requested output layer.
///
/// REQUIRE fails when the input layer name is empty, the list of output
/// layers is empty or contains an empty name, the session run reports a
/// non-ok status, or the number of returned tensors differs from the number
/// of requested layers.
std::vector<tensorflow::Tensor>
TensorFlowInferencer::inference(
        const std::string &inputLayerName,
        const tf::Tensor &input,
        const std::vector<std::string> &outputLayerNames) const
{
    REQUIRE(!inputLayerName.empty(), "Name of the input layer undefined");
    REQUIRE(!outputLayerNames.empty(), "Output layers undefined");

    // The session works with TString names, so convert them while validating.
    std::vector<TString> fetches;
    fetches.reserve(outputLayerNames.size());
    for (size_t i = 0; i < outputLayerNames.size(); ++i) {
        const std::string& name = outputLayerNames[i];
        REQUIRE(!name.empty(), "Output layer name is empty");
        fetches.emplace_back(name);
    }

    std::vector<std::pair<TString, tf::Tensor>> feeds;
    feeds.emplace_back(TString(inputLayerName), input);

    std::vector<tf::Tensor> outputs;
    const auto status = session_->Run(feeds, fetches, {}, &outputs);
    REQUIRE(status.ok(), "Running model failed: " << status);

    REQUIRE(outputLayerNames.size() == outputs.size(),
         "Wrong amount of outputs: " << outputs.size()
         << ", expected " << outputLayerNames.size());
    return outputs;
}

/// Runs the session on a single image and returns one tensor per requested
/// output layer.
///
/// The image is converted with batchToTensor() as a batch of one and the
/// call is forwarded to the tensor-based overload, which validates the layer
/// names. REQUIRE fails here when the image is empty.
std::vector<tf::Tensor>
TensorFlowInferencer::inference(
        const std::string &inputLayerName,
        const cv::Mat &inputImage,
        const std::vector<std::string> &outputLayerNames) const
{
    REQUIRE(!inputImage.empty(), "Input image is undefined");
    return inference(inputLayerName, batchToTensor({inputImage}), outputLayerNames);
}

std::vector<tensorflow::Tensor>
TensorFlowInferencer::inference(
        const std::string &inputLayerName,
        const ImagesBatch &inputBatch,
        const std::vector<std::string> &outputLayerNames) const
{
    REQUIRE(0 < inputBatch.size(), "Input batch is empty");
    tf::Tensor input = batchToTensor(inputBatch);
    return inference(inputLayerName, input, outputLayerNames);
}


/// Runs the session with several named image inputs and returns one tensor
/// per requested output layer.
///
/// Every image is converted with batchToTensor() as a batch of one and
/// passed to the session under its layer name.
///
/// REQUIRE fails when either list is empty, a layer name is empty, an image
/// is empty, the session run reports a non-ok status, or the number of
/// returned tensors differs from the number of requested layers.
std::vector<tensorflow::Tensor>
TensorFlowInferencer::inference(const std::vector<std::pair<std::string, cv::Mat>> &inputLayerImages,
                                const std::vector<std::string> &outputLayerNames) const
{
    REQUIRE(!inputLayerImages.empty(), "Input layers and images undefined");
    REQUIRE(!outputLayerNames.empty(), "Output layers undefined");

    std::vector<std::pair<TString, tf::Tensor>> feeds;
    feeds.reserve(inputLayerImages.size());
    for (auto it = inputLayerImages.begin(); it != inputLayerImages.end(); ++it) {
        const std::string& layerName = it->first;
        const cv::Mat& image = it->second;
        REQUIRE(!layerName.empty(), "Input layer name is empty");
        REQUIRE(!image.empty(), "Input image is empty");
        feeds.emplace_back(TString(layerName), batchToTensor({image}));
    }

    std::vector<TString> fetches;
    fetches.reserve(outputLayerNames.size());
    for (auto it = outputLayerNames.begin(); it != outputLayerNames.end(); ++it) {
        REQUIRE(!it->empty(), "Output layer name is empty");
        fetches.emplace_back(*it);
    }

    std::vector<tf::Tensor> outputs;
    const auto status = session_->Run(feeds, fetches, {}, &outputs);
    REQUIRE(status.ok(), "Running model failed: " << status);

    REQUIRE(outputLayerNames.size() == outputs.size(),
         "Wrong amount of outputs: " << outputs.size()
         << ", expected " << outputLayerNames.size());
    return outputs;
}

/// Runs the session with several named tensor inputs and returns one tensor
/// per requested output layer.
///
/// REQUIRE fails when either list is empty, a layer name is empty, the
/// session run reports a non-ok status, or the number of returned tensors
/// differs from the number of requested layers.
std::vector<tensorflow::Tensor>
TensorFlowInferencer::inference(const std::vector<std::pair<std::string, tensorflow::Tensor>> &inputLayerTensors,
                                const std::vector<std::string> &outputLayerNames) const
{
    REQUIRE(!inputLayerTensors.empty(), "Input layers and images undefined");
    REQUIRE(!outputLayerNames.empty(), "Output layers undefined");

    // The session works with TString names, so re-key the inputs.
    std::vector<std::pair<TString, tf::Tensor>> feeds;
    feeds.reserve(inputLayerTensors.size());
    for (auto it = inputLayerTensors.begin(); it != inputLayerTensors.end(); ++it) {
        const std::string& layerName = it->first;
        REQUIRE(!layerName.empty(), "Input layer name is empty");
        feeds.emplace_back(TString(layerName), it->second);
    }

    std::vector<TString> fetches;
    fetches.reserve(outputLayerNames.size());
    for (auto it = outputLayerNames.begin(); it != outputLayerNames.end(); ++it) {
        REQUIRE(!it->empty(), "Output layer name is empty");
        fetches.emplace_back(*it);
    }

    std::vector<tf::Tensor> outputs;
    const auto status = session_->Run(feeds, fetches, {}, &outputs);
    REQUIRE(status.ok(), "Running model failed: " << status);

    REQUIRE(outputLayerNames.size() == outputs.size(),
         "Wrong amount of outputs: " << outputs.size()
         << ", expected " << outputLayerNames.size());
    return outputs;
}

/// Runs the session without any inputs and returns the single requested
/// output tensor.
///
/// REQUIRE fails when the layer name is empty, the session run reports a
/// non-ok status, or the number of returned tensors is not 1.
tf::Tensor
TensorFlowInferencer::inference(const std::string &outputLayerName) const
{
    REQUIRE(!outputLayerName.empty(), "Output layer name is empty");

    const std::vector<TString> fetches(1, TString(outputLayerName));

    std::vector<tf::Tensor> outputs;
    const auto status = session_->Run({}, fetches, {}, &outputs);
    REQUIRE(status.ok(), "Running model failed: " << status);

    REQUIRE(1 == outputs.size(), "Wrong amount of outputs: " << outputs.size());
    return outputs.front();
}

} //namespace tf_inferencer
} //namespace wiki
} //namespace maps
