#include <maps/libs/common/include/exception.h>
#include <maps/libs/log8/include/log8.h>
#include <maps/libs/cmdline/include/cmdline.h>

#include <mapreduce/yt/interface/client.h>

#include <cstdlib>
#include <random>
#include <stdexcept>
#include <string>

using namespace NYT;

/// Mapper that randomly distributes input rows between two output tables.
///
/// For every input row a sample is drawn from a uniform distribution on [0, 1).
/// The row is written unchanged to output table 0 when the sample is less than
/// TrainPart_, and to output table 1 otherwise.
class TSplitMapper
    : public IMapper<TTableReader<TNode>, TTableWriter<TNode>>
{
public:
    // Registers the mapper parameters with the job save/load macro.
    Y_SAVELOAD_JOB(Seed_, TrainPart_);

    TSplitMapper() = default;

    /// @param seed       value used to seed the random engine in Do()
    /// @param trainPart  threshold compared against each uniform sample;
    ///                   rows whose sample is below it go to output table 0
    TSplitMapper(int seed, double trainPart)
        : Seed_(seed)
        , TrainPart_(trainPart)
    { }

    /// Reads all rows from @p reader and forwards each of them to exactly one
    /// of the two output tables of @p writer.
    void Do(TReader* reader, TWriter* writer) override
    {
        constexpr size_t TRAIN_TABLE_INDEX = 0;
        constexpr size_t TEST_TABLE_INDEX = 1;

        // The engine is created anew on every Do() call, so each call replays
        // the same pseudo-random sequence for a given Seed_.
        std::default_random_engine rndGen(Seed_);
        std::uniform_real_distribution<double> rndUniformDistr(0.0, 1.0);
        for (; reader->IsValid(); reader->Next()) {
            const bool toTrain = rndUniformDistr(rndGen) < TrainPart_;
            writer->AddRow(
                reader->GetRow(),
                toTrain ? TRAIN_TABLE_INDEX : TEST_TABLE_INDEX);
        }
    }
private:
    // In-class initializers keep a default-constructed mapper in a
    // well-defined state instead of holding indeterminate values.
    int Seed_ = 0;
    double TrainPart_ = 0.0;
};
REGISTER_MAPPER(TSplitMapper);

int main(int argc, const char** argv) try {
    Initialize(argc, argv);

    maps::cmdline::Parser parser("Split traffic signs or house numbers dataset YT table to train and test tables");

    maps::cmdline::Option<std::string> inputTable = parser.string("input")
        .required()
        .help("Input YT table name");

    maps::cmdline::Option<std::string> outputTrainTable = parser.string("outtrain")
        .required()
        .help("Output YT table for train data");

    maps::cmdline::Option<std::string> outputTestTable = parser.string("outtest")
        .required()
        .help("Output YT table for test data");

    maps::cmdline::Option<double> trainPart = parser.real("trainpart")
        .defaultValue(.9)
        .help("Part of dataset for training (from 0. to 1.0) other part will use for test  (default: 0.9)");

    maps::cmdline::Option<int> randomSeed = parser.num("seed")
        .defaultValue(42)
        .help("Seed of random generator to split train and test records (default: 42)");

    parser.parse(argc, const_cast<char**>(argv));

    INFO() << "Connecting to yt::hahn";
    IClientPtr client = CreateClient("hahn");

    static const TString TABLE_ATTR_NAME_SCHEMA = "schema";
    TString attr = inputTable.c_str();
    attr += "/@" + TABLE_ATTR_NAME_SCHEMA;
    TNode schema = client->Get(attr);

    client->Create(outputTrainTable.c_str(),
                   NYT::NT_TABLE,
                   NYT::TCreateOptions().Attributes(NYT::TNode()("schema", schema)));
    client->Create(outputTestTable.c_str(),
                   NYT::NT_TABLE,
                   NYT::TCreateOptions().Attributes(NYT::TNode()("schema", schema)));

    client->Map(
        TMapOperationSpec()
        .Ordered(true)
        .AddInput<TNode>(inputTable.c_str())
        .AddOutput<TNode>(outputTrainTable.c_str())
        .AddOutput<TNode>(outputTestTable.c_str()),
        new TSplitMapper(randomSeed, trainPart));

    return 0;
}
catch (const maps::Exception& e) {
    FATAL() << "Worker failed: " << e;
    return EXIT_FAILURE;
}
catch (const std::exception& e) {
    FATAL() << "Worker failed: " << e.what();
    return EXIT_FAILURE;
}
