#include <maps/libs/log8/include/log8.h>
#include <maps/libs/common/include/exception.h>
#include <maps/libs/cmdline/include/cmdline.h>

#include <maps/libs/common/include/file_utils.h>

#include <maps/wikimap/mapspro/libs/tfrecord_writer/include/tfrecord_writer.h>

#include <maps/wikimap/mapspro/services/autocart/libs/geometry/include/hex_wkb.h>

#include <maps/libs/geolib/include/polygon.h>
#include <maps/libs/geolib/include/linear_ring.h>
#include <maps/libs/geolib/include/bounding_box.h>

#include <mapreduce/yt/util/temp_table.h>
#include <mapreduce/yt/interface/client.h>

#include <library/cpp/string_utils/base64/base64.h>

#include <util/string/split.h>

#include <cmath>
#include <fstream>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>

using namespace maps::common;
using namespace maps::geolib3;
using namespace maps::wiki::autocart;
using namespace maps::wiki::tfrecord_writer;

namespace {

// Keys used to read input data: IMAGE and OBJECTS are looked up in each row
// of the source YT table; SHAPE is looked up in each element of the OBJECTS
// list. (Namespace-scope const objects in an anonymous namespace already have
// internal linkage, so no `static` is needed.)
const TString IMAGE = "image";
const TString SHAPE = "shape";
const TString OBJECTS = "objects";

// Read YT table paths from a text file, one path per line.
// Leading and trailing whitespace (including a '\r' left over from CRLF line
// endings) is stripped, and blank / whitespace-only lines are skipped.
// File example:
//   //home/table1
//   //home/table2
//   //home/table3
//
// Throws std::runtime_error if the file cannot be opened or an I/O error
// occurs while reading it.
TVector<NYT::TYPath> readYTTablePathsFromFile(const std::string& filepath) {
    std::ifstream ifs(filepath);
    if (!ifs.is_open()) {
        throw std::runtime_error("Failed to open file: " + filepath);
    }

    constexpr const char* SPACE_CHARS = " \t\r\n\f\v";

    TVector<NYT::TYPath> paths;
    std::string line;
    // Loop on getline() rather than on eof(): a stream that fails without
    // reaching end-of-file would otherwise never leave the loop.
    while (std::getline(ifs, line)) {
        const size_t first = line.find_first_not_of(SPACE_CHARS);
        if (first == std::string::npos) {
            continue; // blank line
        }
        const size_t last = line.find_last_not_of(SPACE_CHARS);
        paths.emplace_back(line.substr(first, last - first + 1));
    }
    if (ifs.bad()) {
        throw std::runtime_error("Failed to read file: " + filepath);
    }
    return paths;
}

// Verify that every table in `ytTablePaths` exists.
// Fails via REQUIRE on the first table that does not exist; the failure
// message names that table.
void checkAllYTTablesExist(
    NYT::IClientBasePtr client,
    const TVector<NYT::TYPath>& ytTablePaths)
{
    for (size_t idx = 0; idx < ytTablePaths.size(); ++idx) {
        const NYT::TYPath& path = ytTablePaths[idx];
        REQUIRE(client->Exists(path), "Table " << path << " does not exist");
    }
}

// Return true iff exactly one of the two command-line options is defined,
// i.e. the logical XOR of lhs.defined() and rhs.defined():
// | lhs   | rhs   | value |
// -------------------------
// | true  | true  | false |
// | true  | false | true  |
// | false | true  | true  |
// | false | false | false |
//
template <typename T, typename U>
bool isOnlyOneOptionDefined(
    const maps::cmdline::Option<T>& lhs,
    const maps::cmdline::Option<U>& rhs)
{
    const bool lhsDefined = lhs.defined();
    const bool rhsDefined = rhs.defined();
    return lhsDefined != rhsDefined;
}

// Concatenate several tables into one new table and return the new table's
// path. The destination is a fresh TTempTable created through `client`;
// Release() is called on it, presumably so the table is not removed when the
// local TTempTable object goes out of scope — its lifetime is then governed by
// `client` (e.g. the enclosing transaction). The destination is overwritten,
// not appended to.
NYT::TYPath mergeYTTables(
    NYT::IClientBasePtr client,
    const TVector<NYT::TYPath>& ytTablePaths)
{
    NYT::TTempTable resultTable(client);
    resultTable.Release();
    const NYT::TYPath resultPath = resultTable.Name();

    NYT::TConcatenateOptions options;
    options.Append(false);
    client->Concatenate(ytTablePaths, resultPath, options);

    return resultPath;
}

// Map job that sends every input row to exactly one of three output tables:
//   output #0 - train dataset, #1 - validation dataset, #2 - test dataset
// (makeTrainValidTestSplit adds the outputs in this order). The choice is
// driven by a uniform random value in [0, 1) drawn from an engine seeded with
// `randomSeed`.
// NOTE(review): the engine is re-seeded with the same seed on every Do() call
// (presumably once per job), so the random sequence repeats across jobs —
// confirm this correlation is acceptable.
class DatasetSplitMapper
    : public NYT::IMapper<NYT::TTableReader<NYT::TNode>,
                          NYT::TTableWriter<NYT::TNode>>
{
public:
    // Presumably used by the YT job runtime, which then restores the fields
    // from the state serialized by Y_SAVELOAD_JOB. The in-class initializers
    // guarantee the fields are never indeterminate.
    DatasetSplitMapper() = default;

    // validRatio / testRatio: fraction of rows routed to the validation /
    // test output. An undefined option disables that split; its ratio stays 0.
    DatasetSplitMapper(
        const maps::cmdline::Option<double>& validRatio,
        const maps::cmdline::Option<double>& testRatio,
        size_t randomSeed)
        : splitValidDataset_(validRatio.defined())
        , splitTestDataset_(testRatio.defined())
        , randomSeed_(randomSeed)
    {
        if (splitValidDataset_) {
            validRatio_ = validRatio;
        }
        if (splitTestDataset_) {
            testRatio_ = testRatio;
        }
    }

    Y_SAVELOAD_JOB(
        splitValidDataset_, validRatio_,
        splitTestDataset_, testRatio_,
        randomSeed_
    );

    void Do(NYT::TTableReader<NYT::TNode>* reader,
            NYT::TTableWriter<NYT::TNode>* writer) override
    {
        std::default_random_engine rndGen(randomSeed_);
        std::uniform_real_distribution<double> rndUniformDistr(0.0, 1.0);
        for (; reader->IsValid(); reader->Next()) {
            // Second argument of AddRow is the output table index.
            writer->AddRow(
                reader->GetRow(),
                getDatasetIndex(rndUniformDistr(rndGen))
            );
        }
    }

private:
    // Map a uniform random value p in [0, 1) to an output table index by laying
    // the enabled splits out consecutively:
    //   p in [0, valid)                 - validation dataset (index 1)
    //   p in [valid, valid + test)      - test dataset       (index 2)
    //   p in [valid + test, 1)          - train dataset      (index 0)
    // A disabled split contributes an empty interval, so the same rule covers
    // "train + valid", "train + test" and "train + valid + test".
    size_t getDatasetIndex(double p) const {
        constexpr size_t trainIndex = 0;
        constexpr size_t validIndex = 1;
        constexpr size_t testIndex = 2;

        const double validRatio = splitValidDataset_ ? validRatio_ : 0.0;
        const double testRatio = splitTestDataset_ ? testRatio_ : 0.0;

        if (p < validRatio) {
            return validIndex;
        }
        if (p < validRatio + testRatio) {
            return testIndex;
        }
        return trainIndex;
    }

    bool splitValidDataset_ = false;
    double validRatio_ = 0.0;
    bool splitTestDataset_ = false;
    double testRatio_ = 0.0;
    size_t randomSeed_ = 0;
};

REGISTER_MAPPER(DatasetSplitMapper);

// Run a map operation over all `trainYTTablePaths` that randomly distributes
// their rows among three output tables using DatasetSplitMapper.
// The outputs are registered in the order train, valid, test; this order must
// match the table indices the mapper writes to (0, 1, 2).
void makeTrainValidTestSplit(
    NYT::IClientBasePtr client,
    const TVector<NYT::TYPath>& trainYTTablePaths,
    const NYT::TYPath& trainYTTablePath,
    const maps::cmdline::Option<double>& validRatio,
    const NYT::TYPath& validYTTablePath,
    const maps::cmdline::Option<double>& testRatio,
    const NYT::TYPath& testYTTablePath,
    size_t randomSeed)
{
    NYT::TMapOperationSpec spec;
    for (size_t inputIdx = 0; inputIdx < trainYTTablePaths.size(); ++inputIdx) {
        spec.AddInput<NYT::TNode>(trainYTTablePaths[inputIdx]);
    }

    spec.AddOutput<NYT::TNode>(trainYTTablePath); // index 0
    spec.AddOutput<NYT::TNode>(validYTTablePath); // index 1
    spec.AddOutput<NYT::TNode>(testYTTablePath);  // index 2

    client->Map(spec, new DatasetSplitMapper(validRatio, testRatio, randomSeed));
}

// Decode an image from `node`, which holds the base64-encoded bytes of an
// encoded image file. The bytes are decoded with cv::imdecode as a color
// image (cv::IMREAD_COLOR).
// Throws std::runtime_error if the bytes cannot be decoded as an image
// (cv::imdecode signals this by returning an empty matrix).
cv::Mat readImage(const NYT::TNode& node) {
    // Bind by const reference to avoid taking a copy of the encoded string.
    const TString& encimageStr = node.AsString();
    std::vector<std::uint8_t> encimage(Base64DecodeBufSize(encimageStr.size()));
    const size_t encimageSize = Base64Decode(
        encimage.data(),
        encimageStr.data(),
        encimageStr.data() + encimageStr.size()
    );
    encimage.resize(encimageSize);

    cv::Mat image = cv::imdecode(encimage, cv::IMREAD_COLOR);
    if (image.empty()) {
        throw std::runtime_error("Failed to decode image");
    }
    return image;
}

// Convert a geolib bounding box to an integer cv::Rect that covers it:
// the min corner is rounded down and the max corner is rounded up.
cv::Rect toCVRect(const BoundingBox& bbox) {
    const cv::Point minCorner(
        static_cast<int>(std::floor(bbox.minX())),
        static_cast<int>(std::floor(bbox.minY()))
    );
    const cv::Point maxCorner(
        static_cast<int>(std::ceil(bbox.maxX())),
        static_cast<int>(std::ceil(bbox.maxY()))
    );
    return cv::Rect(minCorner, maxCorner);
}

// Convert a geolib point to an integer OpenCV point.
// NOTE: this is a plain floating-point -> int conversion, which truncates
// toward zero rather than rounding to the nearest pixel.
cv::Point toCVPoint(const Point2& point) {
    const int x = static_cast<int>(point.x());
    const int y = static_cast<int>(point.y());
    return cv::Point(x, y);
}

std::vector<cv::Point> toCVPoints(const LinearRing2& ring) {
    std::vector<cv::Point> cvPoints;
    for (size_t i = 0; i < ring.pointsNumber(); i++) {
        cvPoints.push_back(toCVPoint(ring.pointAt(i)));
    }
    return cvPoints;
}

// Rasterize a polygon with holes into `mask`.
// Pixels inside the exterior ring are set to 1, then pixels inside any
// interior ring (hole) are reset to 0. Pixels outside the exterior ring are
// not touched, so the caller is expected to pass a zero-filled mask.
void drawPolygon(cv::Mat& mask, const Polygon2& shape) {
    const cv::Scalar backgroundColor(0);
    const cv::Scalar objectColor(1);

    std::vector<cv::Point> outerRing = toCVPoints(shape.exteriorRing());
    cv::fillPoly(mask, {outerRing}, objectColor);

    std::vector<std::vector<cv::Point>> holes;
    const size_t holesNumber = shape.interiorRingsNumber();
    for (size_t holeIdx = 0; holeIdx < holesNumber; ++holeIdx) {
        holes.push_back(toCVPoints(shape.interiorRingAt(holeIdx)));
    }
    cv::fillPoly(mask, holes, backgroundColor);
}

// Build Mask R-CNN annotations from `node`, a YT list whose elements each hold
// a hex-WKB polygon under the SHAPE key. Every object gets label BLD_LABEL
// ("bld"), a CV_8UC1 mask of size `imageSize` with the rasterized polygon
// (1 inside, 0 outside and in holes), and the polygon's integer bounding box.
// NOTE(review): the mask is implicitly clipped to the image by rasterization,
// but the bounding box is not clipped — confirm the TFRecord consumer accepts
// boxes extending past the image borders.
MaskRCNNObjects readObjects(const NYT::TNode& node, const cv::Size& imageSize) {
    constexpr int BLD_LABEL = 1;

    MaskRCNNObjects objects;
    for (const NYT::TNode& item : node.AsList()) {
        cv::Mat mask(imageSize, CV_8UC1, cv::Scalar::all(0));
        const Polygon2 shape = hexWKBToPolygon(item[SHAPE].AsString());
        drawPolygon(mask, shape);
        const cv::Rect bbox = toCVRect(shape.boundingBox());
        objects.push_back({BLD_LABEL, "bld", bbox, mask});
    }
    return objects;
}

// Convert every row of the YT table `ytTablePath` into one record of a new
// TFRecord file at `tfRecordPath`. Each row must contain an IMAGE column
// (base64-encoded image, see readImage) and an OBJECTS column (list of
// shapes, see readObjects). Progress is logged every LOG_STEP rows, and the
// total is logged at the end.
void createTFRecord(
    NYT::IClientBasePtr client,
    const NYT::TYPath& ytTablePath,
    const std::string& tfRecordPath)
{
    constexpr size_t LOG_STEP = 1000;

    TFileOutput file(tfRecordPath.c_str());
    TFRecordWriter<MaskRCNNObject> tfrecordWriter(&file);
    NYT::TTableReaderPtr<NYT::TNode> reader
        = client->CreateTableReader<NYT::TNode>(ytTablePath);
    size_t processed = 0;
    for (; reader->IsValid(); reader->Next()) {
        const NYT::TNode& row = reader->GetRow();
        cv::Mat image = readImage(row[IMAGE]);
        MaskRCNNObjects objects = readObjects(row[OBJECTS], image.size());
        tfrecordWriter.AddRecord(image, objects);
        // Count after the record is written so the log reports rows done.
        ++processed;
        if (processed % LOG_STEP == 0) {
            INFO() << "Processed " << processed << " items";
        }
    }
    INFO() << "Finished, processed " << processed << " items in total";
}

} // namespace

// Build train / validation / test TFRecord files for Mask R-CNN from YT tables.
//
// Train data comes from the tables listed in --train_tables. Validation and
// test data each come either from their own list of tables (--valid_tables /
// --test_tables) or from a random fraction of the train data (--valid_ratio /
// --test_ratio); exactly one option of each pair must be given.
// All intermediate YT tables are created inside a single transaction that is
// aborted at the end, so they are discarded.
int main(int argc, char** argv)
try {
    NYT::Initialize(argc, argv);

    maps::cmdline::Parser parser("Prepare datasets for train and test Mask R-CNN");

    maps::cmdline::Option<std::string> trainYTTablesFilePath = parser.string("train_tables")
        .required()
        .help("List of tables with data for training (each in new line)");

    maps::cmdline::Option<std::string> trainTFRecordPath = parser.string("train_tfrecord")
        .required()
        .help("Path to tfrecord with training dataset");

    maps::cmdline::Option<double> validRatio = parser.real("valid_ratio")
        .help("Use N * valid_ratio random samples from training dataset for validation");

    maps::cmdline::Option<std::string> validYTTablesFilePath = parser.string("valid_tables")
        .help("List of tables with data for validation (each in new line)");

    maps::cmdline::Option<std::string> validTFRecordPath = parser.string("valid_tfrecord")
        .required()
        .help("Path to tfrecord with validation dataset");

    maps::cmdline::Option<double> testRatio = parser.real("test_ratio")
        .help("Use N * test_ratio random samples from training dataset for testing");

    maps::cmdline::Option<std::string> testYTTablesFilePath = parser.string("test_tables")
        .help("List of tables with data for testing (each in new line)");

    maps::cmdline::Option<std::string> testTFRecordPath = parser.string("test_tfrecord")
        .required()
        .help("Path to tfrecord with testing dataset");

    maps::cmdline::Option<size_t> randomSeed = parser.size_t("seed")
        .defaultValue(42)
        .help("Seed for random numbers generator");

    parser.parse(argc, argv);

    INFO() << "Creating YT client: yt::hahn";
    NYT::IClientPtr client = NYT::CreateClient("hahn");

    INFO() << "Creating YT transaction";
    NYT::ITransactionPtr txn = client->StartTransaction();

    INFO() << "Reading train YT tables list: " << trainYTTablesFilePath;
    TVector<NYT::TYPath> trainYTTablePaths = readYTTablePathsFromFile(trainYTTablesFilePath);
    REQUIRE(
        !trainYTTablePaths.empty(),
        "List of train tables is empty"
    );
    checkAllYTTablesExist(txn, trainYTTablePaths);

    INFO() << "Checking validation dataset parameters";
    TVector<NYT::TYPath> validYTTablePaths;
    REQUIRE(
        isOnlyOneOptionDefined(validRatio, validYTTablesFilePath),
        "Exactly one option must be defined: valid_ratio or valid_tables"
    );
    if (validRatio.defined()) {
        REQUIRE(
            0 < validRatio && validRatio < 1,
            "Validation dataset ratio should be in range (0, 1)"
        );
    } else {
        validYTTablePaths = readYTTablePathsFromFile(validYTTablesFilePath);
        REQUIRE(
            !validYTTablePaths.empty(),
            "List of validation tables is empty"
        );
        checkAllYTTablesExist(txn, validYTTablePaths);
    }
    INFO() << "Validation dataset parameters are valid";

    INFO() << "Checking test dataset parameters";
    TVector<NYT::TYPath> testYTTablePaths;
    REQUIRE(
        isOnlyOneOptionDefined(testRatio, testYTTablesFilePath),
        "Exactly one option must be defined: test_ratio or test_tables"
    );
    if (testRatio.defined()) {
        REQUIRE(
            0 < testRatio && testRatio < 1,
            "Test dataset ratio should be in range (0, 1)"
        );
    } else {
        testYTTablePaths = readYTTablePathsFromFile(testYTTablesFilePath);
        REQUIRE(
            !testYTTablePaths.empty(),
            "List of test tables is empty"
        );
        checkAllYTTablesExist(txn, testYTTablePaths);
    }
    INFO() << "Test dataset parameters are valid";

    if (validRatio.defined() && testRatio.defined()) {
        // Both fractions are carved out of the same train data; together they
        // must leave a non-empty share for training.
        REQUIRE(
            static_cast<double>(validRatio) + static_cast<double>(testRatio) < 1,
            "Sum of valid_ratio and test_ratio should be less than 1"
        );
    }

    TString trainYTTablePath;
    TString validYTTablePath;
    TString testYTTablePath;

    INFO() << "Preparing YT tables with datasets";
    if (validRatio.defined() || testRatio.defined()) {
        INFO() << "Splitting train dataset";
        NYT::TTempTable tmpTrainYTTable(txn);
        tmpTrainYTTable.Release();
        NYT::TTempTable tmpValidYTTable(txn);
        tmpValidYTTable.Release();
        NYT::TTempTable tmpTestYTTable(txn);
        tmpTestYTTable.Release();

        makeTrainValidTestSplit(
            txn,
            trainYTTablePaths,
            tmpTrainYTTable.Name(),
            validRatio,
            tmpValidYTTable.Name(),
            testRatio,
            tmpTestYTTable.Name(),
            randomSeed
        );

        trainYTTablePath = tmpTrainYTTable.Name();
        if (validRatio.defined()) {
            validYTTablePath = tmpValidYTTable.Name();
        }
        if (testRatio.defined()) {
            testYTTablePath = tmpTestYTTable.Name();
        }
    } else {
        INFO() << "Concatenating train YT tables";
        trainYTTablePath = mergeYTTables(txn, trainYTTablePaths);
    }

    if (!validRatio.defined()) {
        INFO() << "Concatenating valid YT tables";
        validYTTablePath = mergeYTTables(txn, validYTTablePaths);
    }
    if (!testRatio.defined()) {
        INFO() << "Concatenating test YT tables";
        testYTTablePath = mergeYTTables(txn, testYTTablePaths);
    }

    INFO() << "Creating train tfrecord: " << trainTFRecordPath;
    createTFRecord(txn, trainYTTablePath, trainTFRecordPath);
    INFO() << "Creating valid tfrecord: " << validTFRecordPath;
    createTFRecord(txn, validYTTablePath, validTFRecordPath);
    INFO() << "Creating test tfrecord: " << testTFRecordPath;
    createTFRecord(txn, testYTTablePath, testTFRecordPath);

    INFO() << "Finishing";
    // Every YT table written above is an intermediate artifact; abort the
    // transaction so none of them is kept.
    txn->Abort();

    return EXIT_SUCCESS;
}
catch (const maps::Exception& e) {
    INFO() << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    INFO() << e.what();
    return EXIT_FAILURE;
}
catch (...) {
    INFO() << "Caught unknown exception";
    return EXIT_FAILURE;
}
