#include <maps/wikimap/mapspro/libs/tfrecord_writer/include/tfrecord_writer.h>

#include <maps/wikimap/mapspro/libs/tfrecord_writer/impl/example.pb.h>
#include <maps/wikimap/mapspro/libs/tfrecord_writer/impl/feature.pb.h>

#include <tensorflow/core/platform/posix/posix_file_system.h>
#include <tensorflow/core/platform/file_system.h>

#include "opencv2/opencv.hpp"

#include <maps/libs/common/include/exception.h>

#include <iostream>
#include <memory>
#include <utility>

namespace maps {
namespace wiki {
namespace tfrecord_writer {

namespace {

// Wraps a single int64 value into a Feature whose payload is an Int64List
// containing exactly that one value.
Feature Int64Feature(int64_t value) {
    Feature feature;
    // mutable_int64_list() selects the int64_list payload and lets the
    // Feature own the list, so no manual allocation is needed.
    feature.mutable_int64_list()->add_value(value);
    return feature;
}

// Wraps every element of `values`, in iteration order, into a Feature whose
// payload is an Int64List. An empty input still yields an (empty) Int64List.
Feature Int64ListFeature(const std::list<int64_t> &values) {
    Feature feature;
    Int64List *int64List = feature.mutable_int64_list();
    for (const int64_t item : values) {
        int64List->add_value(item);
    }
    return feature;
}

// Wraps a single float value into a Feature whose payload is a FloatList
// containing exactly that one value.
Feature FloatFeature(float value) {
    Feature feature;
    feature.mutable_float_list()->add_value(value);
    return feature;
}

// Wraps every element of `values`, in iteration order, into a Feature whose
// payload is a FloatList. An empty input still yields an (empty) FloatList.
Feature FloatListFeature(const std::list<float> &values) {
    Feature feature;
    FloatList *floatList = feature.mutable_float_list();
    for (const float item : values) {
        floatList->add_value(item);
    }
    return feature;
}

// Wraps a single string into a Feature whose payload is a BytesList
// containing exactly that one value.
// The bytes are copied by (pointer, length) rather than through c_str(), so
// strings with embedded NUL characters are stored in full instead of being
// cut at the first NUL (and no strlen() pass is needed).
Feature StringFeature(const std::string &value) {
    Feature feature;
    feature.mutable_bytes_list()->add_value(value.data(), value.size());
    return feature;
}

// Wraps every string of `values`, in iteration order, into a Feature whose
// payload is a BytesList. An empty input still yields an (empty) BytesList.
// Each string is copied by (pointer, length) rather than through c_str(), so
// embedded NUL characters are preserved.
Feature StringListFeature(const std::list<std::string> &values) {
    Feature feature;
    BytesList *bytesList = feature.mutable_bytes_list();
    for (const std::string &value : values) {
        bytesList->add_value(value.data(), value.size());
    }
    return feature;
}

// Wraps one raw byte buffer (e.g. an encoded image) into a Feature whose
// payload is a BytesList holding that buffer as a single entry.
Feature BytesFeature(const std::vector<uchar> &value) {
    Feature feature;
    feature.mutable_bytes_list()->add_value(value.data(), value.size());
    return feature;
}

// Wraps each raw byte buffer of `values`, in iteration order, into a Feature
// whose payload is a BytesList with one entry per buffer.
Feature BytesListFeature(const std::list<std::vector<uchar>> &values) {
    Feature feature;
    BytesList *bytesList = feature.mutable_bytes_list();
    for (const std::vector<uchar> &blob : values) {
        bytesList->add_value(blob.data(), blob.size());
    }
    return feature;
}

// Writes the image-level features into `features`: height/width in pixels,
// `filename` (used both as image/filename and image/source_id), the encoding
// format name, and the image itself re-encoded as JPEG.
//
// Throws (via REQUIRE) if `image` is empty or if JPEG encoding fails, instead
// of silently storing an empty image/encoded payload.
void EncodeImage(Features *features, const cv::Mat &image, const std::string &filename) {
    static const std::string ENCODE_FORMAT = "jpeg";

    static const TString FEATURE_NAME_IMAGE_HEIGHT      = "image/height";
    static const TString FEATURE_NAME_IMAGE_WIDTH       = "image/width";
    static const TString FEATURE_NAME_IMAGE_FILENAME    = "image/filename";
    static const TString FEATURE_NAME_IMAGE_SOURCE_ID   = "image/source_id";
    static const TString FEATURE_NAME_IMAGE_FORMAT      = "image/format";
    static const TString FEATURE_NAME_IMAGE_ENCODED     = "image/encoded";

    REQUIRE(!image.empty(), "Unable to encode empty image " + filename);

    std::vector<uchar> encodedImageData;
    // Derive the extension from ENCODE_FORMAT so the stored format name and
    // the actual encoding can never disagree.
    REQUIRE(cv::imencode("." + ENCODE_FORMAT, image, encodedImageData),
            "Unable to encode image " + filename);

    auto &featureMap = *features->mutable_feature();
    featureMap[FEATURE_NAME_IMAGE_HEIGHT]    = Int64Feature(image.rows);
    featureMap[FEATURE_NAME_IMAGE_WIDTH]     = Int64Feature(image.cols);
    featureMap[FEATURE_NAME_IMAGE_FILENAME]  = StringFeature(filename);
    featureMap[FEATURE_NAME_IMAGE_SOURCE_ID] = StringFeature(filename);
    featureMap[FEATURE_NAME_IMAGE_FORMAT]    = StringFeature(ENCODE_FORMAT);
    featureMap[FEATURE_NAME_IMAGE_ENCODED]   = BytesFeature(encodedImageData);
}

// Writes per-object box and class features for `objects` into `features`.
// Box edges are divided by the image width/height. The far edge uses
// x + width - 1 (resp. y + height - 1), i.e. an inclusive pixel coordinate.
// Each axis's pair is min/max-ordered before storing. Every output list has
// one entry per object, in the order the objects are iterated.
template <typename Object>
void EncodeBBoxes(Features *features, const cv::Mat &image, const std::list<Object> &objects) {
    static const TString FEATURE_NAME_OBJECT_XMIN       = "image/object/bbox/xmin";
    static const TString FEATURE_NAME_OBJECT_XMAX       = "image/object/bbox/xmax";
    static const TString FEATURE_NAME_OBJECT_YMIN       = "image/object/bbox/ymin";
    static const TString FEATURE_NAME_OBJECT_YMAX       = "image/object/bbox/ymax";
    static const TString FEATURE_NAME_OBJECT_TEXT       = "image/object/class/text";
    static const TString FEATURE_NAME_OBJECT_LABEL      = "image/object/class/label";

    const float imageWidth = static_cast<float>(image.cols);
    const float imageHeight = static_cast<float>(image.rows);

    std::list<int64_t> labels;
    std::list<std::string> texts;
    std::list<float> xMinList;
    std::list<float> xMaxList;
    std::list<float> yMinList;
    std::list<float> yMaxList;

    for (const Object &object : objects) {
        labels.emplace_back(object.label);
        texts.emplace_back(object.text);

        const float left = static_cast<float>(object.bbox.x) / imageWidth;
        const float right = static_cast<float>(object.bbox.x + object.bbox.width - 1) / imageWidth;
        xMinList.emplace_back(std::min(left, right));
        xMaxList.emplace_back(std::max(left, right));

        const float top = static_cast<float>(object.bbox.y) / imageHeight;
        const float bottom = static_cast<float>(object.bbox.y + object.bbox.height - 1) / imageHeight;
        yMinList.emplace_back(std::min(top, bottom));
        yMaxList.emplace_back(std::max(top, bottom));
    }

    auto &featureMap = *features->mutable_feature();
    featureMap[FEATURE_NAME_OBJECT_XMIN]  = FloatListFeature(xMinList);
    featureMap[FEATURE_NAME_OBJECT_XMAX]  = FloatListFeature(xMaxList);
    featureMap[FEATURE_NAME_OBJECT_YMIN]  = FloatListFeature(yMinList);
    featureMap[FEATURE_NAME_OBJECT_YMAX]  = FloatListFeature(yMaxList);
    featureMap[FEATURE_NAME_OBJECT_TEXT]  = StringListFeature(texts);
    featureMap[FEATURE_NAME_OBJECT_LABEL] = Int64ListFeature(labels);
}

// Writes the single box, text and label sequence of `object` into `features`.
// Coordinates use the same normalization as EncodeBBoxes: divide by the image
// width/height, with the far edge at x + width - 1 (resp. y + height - 1).
// The edges are then min/max-ordered per axis, as EncodeBBoxes does, so a
// degenerate box (width or height 0) never produces xmin > xmax or ymin > ymax.
void EncodeMultiLabelsObject(Features *features, const cv::Mat &image, const MultiLabelsObject& object) {
    static const TString FEATURE_NAME_OBJECT_XMIN       = "image/object/bbox/xmin";
    static const TString FEATURE_NAME_OBJECT_XMAX       = "image/object/bbox/xmax";
    static const TString FEATURE_NAME_OBJECT_YMIN       = "image/object/bbox/ymin";
    static const TString FEATURE_NAME_OBJECT_YMAX       = "image/object/bbox/ymax";
    static const TString FEATURE_NAME_OBJECT_TEXT       = "image/object/class/text";
    static const TString FEATURE_NAME_OBJECT_LABEL_SEQ  = "image/object/class/label_sequence";

    const float imageWidth = static_cast<float>(image.cols);
    const float imageHeight = static_cast<float>(image.rows);

    const float x1 = static_cast<float>(object.bbox.x) / imageWidth;
    const float x2 = static_cast<float>(object.bbox.x + object.bbox.width - 1) / imageWidth;
    const float y1 = static_cast<float>(object.bbox.y) / imageHeight;
    const float y2 = static_cast<float>(object.bbox.y + object.bbox.height - 1) / imageHeight;

    auto &featureMap = *features->mutable_feature();
    featureMap[FEATURE_NAME_OBJECT_XMIN]      = FloatFeature(std::min(x1, x2));
    featureMap[FEATURE_NAME_OBJECT_XMAX]      = FloatFeature(std::max(x1, x2));
    featureMap[FEATURE_NAME_OBJECT_YMIN]      = FloatFeature(std::min(y1, y2));
    featureMap[FEATURE_NAME_OBJECT_YMAX]      = FloatFeature(std::max(y1, y2));
    featureMap[FEATURE_NAME_OBJECT_TEXT]      = StringFeature(object.text);
    featureMap[FEATURE_NAME_OBJECT_LABEL_SEQ] = Int64ListFeature(object.labels);
}

// Checks every object's mask before encoding and throws (via REQUIRE) on the
// first violation. A valid mask has type CV_8UC1 and the same size as `image`.
// Its minimum value must be <= 1 and its maximum exactly 1, so an all-zero
// mask is rejected as well.
void ValidateMasks(const cv::Mat &image, const MaskRCNNObjects &objects) {
    const cv::Size imageSize = image.size();
    for (const auto &object : objects) {
        const auto &mask = object.mask;
        REQUIRE(mask.type() == CV_8UC1, "Mask must have CV_8UC1 type");
        REQUIRE(mask.size() == imageSize, "Mask and image must have equal sizes");

        double lowest = 0.0;
        double highest = 0.0;
        cv::minMaxLoc(mask, &lowest, &highest);
        const bool isBinary = static_cast<int>(lowest) <= 1 && static_cast<int>(highest) == 1;
        REQUIRE(isBinary, "Object mask must be binary: 0 - background, 1 - object");
    }
}

// PNG-encodes each object's mask and stores all of them, in object order, as
// one BytesList under "image/object/mask".
// Throws (via REQUIRE) if any mask fails to encode, rather than silently
// storing an empty entry.
void EncodeMasks(Features *features, const MaskRCNNObjects &objects) {
    static const TString FEATURE_NAME_MASK_ENCODED     = "image/object/mask";

    std::list<std::vector<uchar>> encodedMasksData;
    for (const MaskRCNNObject &object : objects) {
        std::vector<uchar> encodedMaskData;
        REQUIRE(cv::imencode(".png", object.mask, encodedMaskData),
                "Unable to encode object mask");
        // Move the buffer in; it is not used again in this iteration.
        encodedMasksData.push_back(std::move(encodedMaskData));
    }

    (*features->mutable_feature())[FEATURE_NAME_MASK_ENCODED] = BytesListFeature(encodedMasksData);
}

// Builds a tf.Example containing the image features and per-object boxes and
// class info, then serializes it into `*output`.
//
// The features are filled in place through mutable_features(), so the Example
// owns them from the start. Nothing leaks if an encoder throws, which a
// separately new-ed Features object would. A serialization failure is
// reported via REQUIRE instead of being ignored.
void CreateTFExample(const cv::Mat &image, const std::string &filename, const FasterRCNNObjects &objects, TString *output) {
    Example example;
    Features *features = example.mutable_features();
    EncodeImage(features, image, filename);
    EncodeBBoxes(features, image, objects);

    REQUIRE(example.SerializeToString(output), "Unable to serialize tf.Example");
}

// Builds a tf.Example containing the image features, per-object boxes/class
// info and PNG-encoded instance masks, then serializes it into `*output`.
// The masks are validated first (ValidateMasks throws on a bad mask).
//
// The features are filled in place through mutable_features(), so the Example
// owns them from the start and nothing leaks if an encoder throws. A
// serialization failure is reported via REQUIRE instead of being ignored.
void CreateTFExample(const cv::Mat &image, const std::string &filename, const MaskRCNNObjects &objects, TString *output) {
    ValidateMasks(image, objects);

    Example example;
    Features *features = example.mutable_features();
    EncodeImage(features, image, filename);
    EncodeBBoxes(features, image, objects);
    EncodeMasks(features, objects);

    REQUIRE(example.SerializeToString(output), "Unable to serialize tf.Example");
}

// Builds a tf.Example for an image carrying exactly one multi-label object
// (REQUIRE enforces objects.size() == 1) and serializes it into `*output`.
//
// The features are filled in place through mutable_features(), so the Example
// owns them from the start and nothing leaks if an encoder throws. A
// serialization failure is reported via REQUIRE instead of being ignored.
void CreateTFExample(const cv::Mat &image, const std::string &filename, const MultiLabelsObjects &objects, TString *output) {
    REQUIRE(objects.size() == 1, "Must be single multilabels object for one image");

    Example example;
    Features *features = example.mutable_features();
    EncodeImage(features, image, filename);
    EncodeMultiLabelsObject(features, image, objects.front());

    REQUIRE(example.SerializeToString(output), "Unable to serialize tf.Example");
}

// Adapts an IOutputStream to the tensorflow::WritableFile interface so that
// tensorflow::io::RecordWriter can write TFRecords into it.
//
// Takes ownership of `stream`: it is destroyed together with this adapter.
// Close() calls Finish() on the stream; Flush() and Sync() both call Flush().
// Every method reports OK. Errors raised by the stream itself are not turned
// into a Status here; they propagate as whatever the stream throws.
class WritableFileImpl final
    : public tensorflow::WritableFile {
public:
    explicit WritableFileImpl(IOutputStream *stream)
        : stream_(stream) {}

    tensorflow::Status Append(const tensorflow::StringPiece& data) override {
        stream_->Write(data.data(), data.size());
        return tensorflow::Status::OK();
    }

    tensorflow::Status Close() override {
        stream_->Finish();
        return tensorflow::Status::OK();
    }

    tensorflow::Status Flush() override {
        stream_->Flush();
        return tensorflow::Status::OK();
    }

    tensorflow::Status Sync() override {
        // There is no separate durability step on an IOutputStream; a flush is
        // the strongest operation available.
        stream_->Flush();
        return tensorflow::Status::OK();
    }
private:
    // Owning handle; replaces the manual `delete` in a user-written destructor.
    std::unique_ptr<IOutputStream> stream_;
};

} // namespace

// Creates a writer that emits TFRecord-framed tf.Example records into
// `stream`, through a WritableFileImpl adapter (which deletes `stream` when
// the adapter itself is destroyed). Both counters start at zero.
//
// NOTE(review): nothing in this file deletes the WritableFileImpl allocated
// here. tensorflow::io::RecordWriter is not expected to take ownership of an
// uncompressed destination, so the adapter (and the stream it would delete)
// appears never to be released or Finish()-ed. Confirm against the TF version
// in use and the header's members before relying on stream lifetime.
template <typename Object>
TFRecordWriter<Object>::TFRecordWriter(IOutputStream *stream)
    : recordsCount_(0)
    , objectsCount_(0)
{
    tensorflow::WritableFile *destination = new WritableFileImpl(stream);
    writer_ = std::make_shared<tensorflow::io::RecordWriter>(destination);
}

// Serializes `image` plus `objects` into one tf.Example and appends it as a
// single TFRecord. Throws (via REQUIRE) if the record writer reports failure.
// On success the record counter grows by one and the object counter by
// objects.size().
template <typename Object>
void TFRecordWriter<Object>::AddRecord(const cv::Mat &image, const std::list<Object> &objects, const std::string &filename) {
    TString serialized;
    CreateTFExample(image, filename, objects, &serialized);

    const tensorflow::Status status =
        writer_->WriteRecord(tensorflow::StringPiece(serialized.c_str(), serialized.length()));
    REQUIRE(status.ok(), "Unable to write record");

    ++recordsCount_;
    objectsCount_ += objects.size();
}

// Loads the image at `imagepath` and appends it as a record. Only the path's
// basename (the part after the last '/' or '\\') is stored as its filename.
//
// cv::imread reports an unreadable or missing file by returning an empty Mat.
// That case is rejected here with the offending path, instead of failing
// later with an unrelated encoding error or an empty image.
template <typename Object>
void TFRecordWriter<Object>::AddRecord(const std::string &imagepath, const std::list<Object> &objects) {
    cv::Mat image = cv::imread(imagepath);
    REQUIRE(!image.empty(), "Unable to read image " + imagepath);

    const size_t separatorPos = imagepath.find_last_of("/\\");
    const std::string filename = (separatorPos == std::string::npos)
        ? imagepath
        : imagepath.substr(separatorPos + 1);
    AddRecord(image, objects, filename);
}

// Decodes an in-memory encoded image (as a 3-channel color image) and appends
// it as a record under `filename`.
//
// cv::imdecode reports undecodable data by returning an empty Mat; that case
// is rejected here so a corrupt buffer cannot become an empty record.
template <typename Object>
void TFRecordWriter<Object>::AddRecord(const std::vector<std::uint8_t> &encimage, const std::list<Object> &objects, const std::string &filename) {
    cv::Mat image = cv::imdecode(encimage, cv::IMREAD_COLOR);
    REQUIRE(!image.empty(), "Unable to decode image " + filename);
    AddRecord(image, objects, filename);
}

// Flushes the underlying record writer. Throws (via REQUIRE) if it reports a
// non-OK status.
template <typename Object>
void TFRecordWriter<Object>::Flush() {
    const tensorflow::Status status = writer_->Flush();
    REQUIRE(status.ok(), "TFRecordWriter unable to flush file");
}

// Explicitly instantiate TFRecordWriter for each supported object type, so
// that the member definitions in this file are emitted for callers.
template class TFRecordWriter<FasterRCNNObject>;
template class TFRecordWriter<MaskRCNNObject>;
template class TFRecordWriter<MultiLabelsObject>;

} // namespace tfrecord_writer
} // namespace wiki
} // namespace maps
