#include <maps/libs/log8/include/log8.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/cmdline/include/cmdline.h>

#include <maps/wikimap/mapspro/libs/tfrecord_writer/include/tfrecord_writer.h>

#include <mapreduce/yt/interface/client.h>

#include <util/system/fs.h>
#include <util/stream/file.h>
#include <library/cpp/string_utils/base64/base64.h>

#include <algorithm>
#include <cstdint>
#include <fstream>
#include <random>
#include <string>
#include <utility>
#include <vector>

namespace tfw = maps::wiki::tfrecord_writer;

namespace {

// Column names read from every row of the input YT table (see Item).
const TString FEATURE_ID = "feature_id";
const TString IMAGE = "image";
const TString OBJECTS = "objects";

// Keys inside each element of the `objects` list (see readObjects).
const TString BBOX = "bbox";
const TString TYPE = "type";

// The only supported object class and the label id it is given both in the
// records (readObjects) and in the label map (saveLabelMap).
const TString TRAFFIC_LIGHT = "traffic_light";
const size_t TRAFFIC_LIGHT_INDEX = 1;

/// Base64-decodes @p encBase64Image into a raw byte buffer.
///
/// The bytes are presumably an encoded JPEG (hence the name); that is not
/// verified here.
std::vector<uint8_t> readJPEG(const TString& encBase64Image) {
    // Reserve the worst-case decoded size, then trim to what was actually written.
    std::vector<uint8_t> decoded(Base64DecodeBufSize(encBase64Image.length()));
    const size_t decodedLength = Base64Decode(decoded.data(), encBase64Image.begin(), encBase64Image.end());
    decoded.resize(decodedLength);
    return decoded;
}

/// Parses the annotated objects of one feature.
///
/// Every element of @p objectsNode is expected to be a map with a `type`
/// string (only `traffic_light` is accepted) and a `bbox` made of two corner
/// points `[[x1, y1], [x2, y2]]` with integer coordinates.
///
/// Fails via REQUIRE on an unknown type or a malformed bbox. Type errors
/// raised by the TNode accessors propagate to the caller.
tfw::FasterRCNNObjects readObjects(const NYT::TNode& objectsNode) {
    tfw::FasterRCNNObjects objects;
    for (const NYT::TNode& objectNode : objectsNode.AsList()) {
        const TString& type = objectNode[TYPE].AsString();
        REQUIRE(TRAFFIC_LIGHT == type, "Unknown type: " + type);

        // Validate the bbox shape before indexing: positional TNode access is
        // not relied upon to bounds-check, and a malformed row must not make us
        // read past the end of a list.
        const auto& corners = objectNode[BBOX].AsList();
        REQUIRE(corners.size() == 2, "Bbox must consist of exactly two points");
        for (const NYT::TNode& corner : corners) {
            REQUIRE(corner.AsList().size() == 2, "Bbox point must consist of exactly two coordinates");
        }

        tfw::FasterRCNNObject object;
        object.text = TRAFFIC_LIGHT;
        object.label = TRAFFIC_LIGHT_INDEX;
        // The two-point cv::Rect constructor normalizes the corners, so their
        // order in the input does not matter.
        object.bbox = cv::Rect(
            cv::Point(corners[0][0].AsInt64(), corners[0][1].AsInt64()),
            cv::Point(corners[1][0].AsInt64(), corners[1][1].AsInt64())
        );
        objects.push_back(std::move(object));
    }
    return objects;
}

/// One row of the input YT table: the feature id, its decoded image bytes and
/// its annotated objects.
struct Item {
    int64_t featureId;               // value of the FEATURE_ID column
    std::vector<uint8_t> encimage;   // readJPEG() of the IMAGE column
    tfw::FasterRCNNObjects objects;  // readObjects() of the OBJECTS column

    /// Parses @p node. Errors raised by the TNode accessors, readJPEG or
    /// readObjects propagate to the caller.
    explicit Item(const NYT::TNode& node)
        : featureId(node[FEATURE_ID].AsInt64()),
          encimage(readJPEG(node[IMAGE].AsString())),
          objects(readObjects(node[OBJECTS]))
    {}
};

void saveLabelMap(const std::string& labelMapPath) {
    std::ofstream ofs(labelMapPath);
    ofs << "item {" << std::endl;
    ofs << "  id: " << TRAFFIC_LIGHT_INDEX << std::endl;
    ofs << "  name: '" << TRAFFIC_LIGHT << "'" << std::endl;
    ofs << "}" << std::endl << std::endl;
    ofs.close();
}

/// Removes from @p objects every object whose bbox is narrower AND lower than
/// @p minSize. An object is kept as soon as one of its sides reaches minSize.
void eraseSmallObjects(tfw::FasterRCNNObjects& objects, int minSize) {
    const auto isTooSmall = [minSize](const tfw::FasterRCNNObject& candidate) {
        return !(candidate.bbox.width >= minSize || candidate.bbox.height >= minSize);
    };
    const auto keptEnd = std::remove_if(objects.begin(), objects.end(), isTooSmall);
    objects.erase(keptEnd, objects.end());
}

/// Appends @p item to @p writer: its image bytes, its objects and its feature
/// id rendered as a decimal string.
void addItem(tfw::TFRecordWriter<tfw::FasterRCNNObject>& writer, const Item& item) {
    writer.AddRecord(
        item.encimage,
        item.objects,
        std::to_string(item.featureId)
    );
}

} // namespace

int main(int argc, const char** argv) try {
    NYT::Initialize(argc, argv);

    maps::cmdline::Parser parser("Download traffic light dataset from YT table to tfrecord");

    maps::cmdline::Option<std::string> ytPath = parser.string("yt_path")
        .required()
        .help("Path to YT table");

    maps::cmdline::Option<std::string> trainTFRecordPath = parser.string("train_tfrecord")
        .required()
        .help("Path to output file for train data");

    maps::cmdline::Option<std::string> testTFRecordPath = parser.string("test_tfrecord")
        .help("Path to output file for test data");

    maps::cmdline::Option<std::string> labelMapPath = parser.string("labelmap")
        .required()
        .help("Path to output label_map file with classes");

    maps::cmdline::Option<int> minSize = parser.num("min_size")
        .defaultValue(0)
        .help("Minimal value of size (width or height) of traffic light (default: 0)");

     maps::cmdline::Option<double> trainPart = parser.real("train_part")
        .help("Part of dataset for building train.tfrecord (from 0. to 1.0) other part will use for test.tfrecord");

    auto randomSeed = parser.num("seed")
        .defaultValue(42)
        .help("Seed of random generator to split train and test records (default: 42)");

    parser.parse(argc, const_cast<char**>(argv));

    REQUIRE(
        (testTFRecordPath.defined() && trainPart.defined())
        ||
        (!testTFRecordPath.defined() && !trainPart.defined()),
        "Path to test TFRecord and train part must be defined or not defined at same time"
    );

    INFO() << "Connecting to yt::hahn";
    NYT::IClientPtr client = NYT::CreateClient("hahn");

    if (testTFRecordPath.defined() && trainPart.defined()) {
        INFO() << "Creating train TFRecord writer: " << trainTFRecordPath;
        TFileOutput trainFile{TString(trainTFRecordPath)};
        tfw::TFRecordWriter<tfw::FasterRCNNObject> trainWriter(&trainFile);

        INFO() << "Creating test TFRecord writer: " << testTFRecordPath;
        TFileOutput testFile{TString(testTFRecordPath)};
        tfw::TFRecordWriter<tfw::FasterRCNNObject> testWriter(&testFile);

        INFO() << "Saving train and test TFRecords";
        std::mt19937 rndGen(randomSeed);
        std::uniform_real_distribution<double> testProb(0.0, 1.0);

        NYT::TTableReaderPtr<NYT::TNode> reader
            = client->CreateTableReader<NYT::TNode>(NYT::TYPath(ytPath));
        for (; reader->IsValid(); reader->Next()) {
            const NYT::TNode& row = reader->GetRow();
            Item item(row);
            eraseSmallObjects(item.objects, minSize);
            if (item.objects.empty()) {
                INFO() << "Feature with id = " << item.featureId
                       << " has no suitable objects";
                continue;
            }
            if (testProb(rndGen) <= trainPart) {
                INFO() << "Add feature with id = " << item.featureId
                       << " into train dataset";
                addItem(trainWriter, item);
            } else {
                INFO() << "Add feature with id = " << item.featureId
                       << " into test dataset";
                addItem(testWriter, item);
            }
        }

        INFO() << "Train images count: " << trainWriter.GetRecordsCount();
        INFO() << "Train objects count: " << trainWriter.GetObjectsCount();

        INFO() << "Test images count: " << testWriter.GetRecordsCount();
        INFO() << "Test objects count: " << testWriter.GetObjectsCount();
    } else {
        INFO() << "Creating TFRecord writer: " << trainTFRecordPath;
        TFileOutput trainFile{TString(trainTFRecordPath)};
        tfw::TFRecordWriter<tfw::FasterRCNNObject> trainWriter(&trainFile);

        INFO() << "Saving TFRecord";
        NYT::TTableReaderPtr<NYT::TNode> reader
            = client->CreateTableReader<NYT::TNode>(NYT::TYPath(ytPath));
        for (; reader->IsValid(); reader->Next()) {
            const NYT::TNode& row = reader->GetRow();
            Item item(row);
            eraseSmallObjects(item.objects, minSize);
            if (item.objects.empty()) {
                INFO() << "Feature with id = " << item.featureId
                       << " has no suitable objects";
                continue;
            }
            INFO() << "Add feature with id = " << item.featureId << " into dataset";
            addItem(trainWriter, item);
        }

        INFO() << "Images count: " << trainWriter.GetRecordsCount();
        INFO() << "Objects count: " << trainWriter.GetObjectsCount();
    }

    INFO() << "Saving label map: " << labelMapPath;
    saveLabelMap(labelMapPath);

    return 0;
}
catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
}
