#include <maps/wikimap/mapspro/services/mrc/libs/common/include/exif.h>
#include <maps/libs/common/include/file_utils.h>

#include <maps/libs/cmdline/include/cmdline.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/log8/include/log8.h>

#include <opencv2/opencv.hpp>

#include <tensorflow/core/public/session.h>
#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/framework/tensor.h>

#include <util/generic/string.h>
#include <util/system/env.h>

namespace tf = tensorflow;

namespace {

// Number of colour channels in the model input tensor (BGR pixel data).
constexpr size_t CHANNELS_NUM = 3;

// Graph node that receives the uint8 image tensor.
const TString INPUT_LAYER_NAME("inference_input");
// Graph node that produces per-class confidences.
const TString OUTPUT_SOFTMAX_LAYER_NAME("inference_softmax");
// Graph node that produces the class identifiers.
const TString OUTPUT_CLASS_NAMES_LAYER_NAME("class_names");

/**
 * Copies the pixel bytes of `image` into a flat, row-major buffer with no
 * padding between rows.
 *
 * Channel order is kept exactly as stored in the matrix. A matrix decoded by
 * cv::imread with default flags is 8-bit BGR, so the result is BGR bytes.
 *
 * @param image source matrix (may be a ROI view of a larger matrix)
 * @return pixel bytes of `image` only; empty for an empty matrix
 */
std::vector<uint8_t> toBGR(const cv::Mat& image)
{
    std::vector<uint8_t> result;
    if (image.isContinuous()) {
        // `data` points at this view's first pixel. `datastart`/`dataend`
        // describe the parent allocation, which for a continuous ROI is
        // larger than (and starts before) the view itself.
        result.assign(image.data,
                      image.data + image.total() * image.elemSize());
    }
    else {
        // Rows are padded (or the matrix is a ROI): copy row by row,
        // skipping the stride gap at the end of each row.
        const size_t rowBytes
            = static_cast<size_t>(image.cols) * image.elemSize();
        result.reserve(rowBytes * static_cast<size_t>(image.rows));
        for (int row = 0; row < image.rows; ++row) {
            const uint8_t* rowBegin = image.ptr<uint8_t>(row);
            result.insert(result.end(), rowBegin, rowBegin + rowBytes);
        }
    }
    return result;
}

class RotationClassifier {
public:
    /**
    * @param path to tensorflow model in protobuf format
    * Tensorflow model should contain input layer 'inference_input'
    * for image pixel data and output layers:
    *  'inference_softmax' -- with objects probabilities
    *  'class_names' -- with objects identifiers
    */
    RotationClassifier(const std::string& path)
    {
        tf::GraphDef graph_def;
        auto load_graph_status
            = tf::ReadBinaryProto(tf::Env::Default(), TString(path), &graph_def);
        REQUIRE(load_graph_status.ok(), "Failed to load graph at '" << path
                                                                    << "'");
        tf::Session* session = nullptr;
        auto operationStatus = tf::NewSession(tf::SessionOptions(), &session);
        REQUIRE(operationStatus.ok(),
                "Failed to create session: " << operationStatus);
        REQUIRE(session != nullptr, "Failed to create session");
        session_.reset(session);
        operationStatus = session_->Create(graph_def);
        REQUIRE(operationStatus.ok(),
                "Failed to initialize session: " << operationStatus);
    }

    /**
        * @brief Classifies input image
        *
        * @param data 8-bit image data in BGR format
        * @param width of image
        * @param height of image
        *
        * @return pair of classified object identifier and correspondent
        *  confidence (in range [0, 1])
        */
    std::pair<std::string, float> classify(const std::vector<uint8_t>& data,
                                           size_t width,
                                           size_t height) const
    {
        tf::Tensor input(tf::DataType::DT_UINT8,
                         tf::TensorShape({1,
                                          static_cast<int>(height),
                                          static_cast<int>(width),
                                          CHANNELS_NUM}));

        auto dst = input.flat<uint8_t>().data();
        std::copy_n(data.begin(), CHANNELS_NUM * height * width, dst);

        std::vector<tf::Tensor> outputs;
        auto runStatus = session_->Run(
            {{INPUT_LAYER_NAME, input}},
            {OUTPUT_SOFTMAX_LAYER_NAME, OUTPUT_CLASS_NAMES_LAYER_NAME}, {},
            &outputs);
        REQUIRE(runStatus.ok(), "Running model failed: " << runStatus);

        auto confidences = outputs[0].matrix<float>();
        int64_t argmax = 0;
        float valmax = confidences(0, argmax);
        for (int64_t i = 0; i < confidences.dimension(1); ++i) {
            auto confidence = confidences(0, i);
            INFO() << i << " " << confidence;
            if (confidence > valmax) {
                valmax = confidence;
                argmax = i;
            }
        }

        std::string sign = outputs[1].vec<TString>()(argmax);
        return std::make_pair(sign, valmax);
    }

    std::vector<std::string> classNames() const
    {
        std::vector<tf::Tensor> outputs;
        auto runStatus = session_->Run(
            {}, {OUTPUT_CLASS_NAMES_LAYER_NAME}, {}, &outputs);
        REQUIRE(runStatus.ok(), "Running model failed: " << runStatus);

        auto classes = outputs[0].vec<TString>();
        std::vector<std::string> result;
        size_t classesNum = classes.dimension(0);
        result.reserve(classesNum);

        for (size_t i = 0; i < classesNum; ++i) {
            result.push_back(classes(i));
        }

        return result;
    }

private:
    std::unique_ptr<tf::Session> session_;
};

} // namespace

int main(int argc, char** argv) try {

    if (GetEnv("TF_CPP_MIN_VLOG_LEVEL").empty() &&
        GetEnv("TF_CPP_MIN_LOG_LEVEL").empty())
        SetEnv("TF_CPP_MIN_LOG_LEVEL", "99"); // Silence!


    if (GetEnv("TF_DISABLE_MKLDNN").empty())
        SetEnv("TF_DISABLE_MKLDNN", "1");

    maps::cmdline::Parser parser;
    auto modelPath
        = parser.string("model").required().help("path to trained model");

    auto imagePath
        = parser.string("image").required().help("path to trained model");

    parser.parse(argc, argv);

    INFO() << "Parsing image orientation from exif";
    auto imageData = maps::common::readFileToVector(imagePath);
    auto orient = maps::mrc::common::parseImageOrientationFromExif(imageData);
    if (orient) {
        INFO() << "horizontalFlip=" << orient->horizontalFlip()
        << " rotationDegrees=" << orient->rotation();
    } else {
        INFO() << "no orientation";
    }


    INFO() << "loading file " << imagePath;
    auto image = cv::imread(imagePath);

    INFO() << "loading model " << modelPath;
    RotationClassifier classifier(modelPath);

    INFO() << "supported classes:";
    for (const auto& name : classifier.classNames()) {
        INFO() << "\t" << name;
    }

    INFO() << "classifying image";
    auto result = classifier.classify(
        toBGR(image),
        image.cols,
        image.rows
    );

    INFO() << result.first << "\t" << result.second;

    return EXIT_SUCCESS;
}
catch (const maps::Exception& e) {
    INFO() << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    INFO() << e.what();
    return EXIT_FAILURE;
}
catch (...) {
    INFO() << "Caught unknown exception";
    return EXIT_FAILURE;
}
