#include "client.h"

#include <saas/library/rtyt/lib/operation/factory.h>
#include <saas/library/rtyt/lib/io/ut_row.pb.h>
#include <library/cpp/testing/unittest/registar.h>
#include <mapreduce/yt/interface/operation.h>
#include <mapreduce/yt/interface/io.h>

namespace NTestMultiple {

    using ::NRTYT::NTesting::TRow;
    using ::NRTYT::NTesting::TRow2;
    /// Single-schema mapper: for every input TRow it overwrites Fixed64 with the
    /// product Int32 * Fixed64 and emits the row; all other fields pass through.
    class TProtoMultiplyingMapper
        : public NYT::IMapper<NYT::TTableReader<NRTYT::NTesting::TRow>, NYT::TTableWriter<NRTYT::NTesting::TRow>>
    {
    public:
        TProtoMultiplyingMapper() = default;

        void Do(TReader* input, TWriter* output) override {
            while (input->IsValid()) {
                // Work on a copy of the current record so it can be modified.
                NRTYT::NTesting::TRow current = input->GetRow();
                const i64 product = current.GetInt32() * current.GetFixed64();
                current.SetFixed64(product);
                output->AddRow(current);
                input->Next();
            }
        }
    };

    REGISTER_RTYT_MAPPER(TProtoMultiplyingMapper);

    /// Mapper over two inputs with different schemas, emitting TRow:
    ///  - table 0 (TRow):  Fixed64 := Int32 * Fixed64, Int32 := 0, other fields copied;
    ///  - table 1 (TRow2): Fixed64 := length of StringField, Int32 := 1;
    ///  - any other table index is logged as an error and skipped.
    class TProtoMultipleTablesMapper
        : public NYT::IMapper<NYT::TTableReader<::google::protobuf::Message>, NYT::TTableWriter<NRTYT::NTesting::TRow>>
    {
    public:
        TProtoMultipleTablesMapper() = default;

        void Do(TReader* input, TWriter* output) override {
            for (; input->IsValid(); input->Next()) {
                const auto tableIndex = input->GetTableIndex();
                if (tableIndex == 0) {
                    DEBUG_LOG << "first table" << Endl;
                    NRTYT::NTesting::TRow row = input->GetRow<NRTYT::NTesting::TRow>();
                    const i64 result = row.GetInt32() * row.GetFixed64();
                    row.SetFixed64(result);
                    row.SetInt32(0);
                    output->AddRow(row);
                } else if (tableIndex == 1) {
                    const auto& row2 = input->GetRow<NRTYT::NTesting::TRow2>();
                    // Build a fresh output row for every record: reusing a TRow from a
                    // previous table-0 iteration would leak its remaining fields
                    // (e.g. String) into this output.
                    NRTYT::NTesting::TRow row;
                    row.SetFixed64(row2.GetStringField().size());
                    row.SetInt32(1);
                    output->AddRow(row);
                    DEBUG_LOG << "second table" << Endl;
                } else {
                    ERROR_LOG << "unknown table index" << Endl;
                }
            }
        }
    };

    REGISTER_RTYT_MAPPER(TProtoMultipleTablesMapper);

    /// Reducer joining a TRow table (index 0) with a TRow2 table (any other index)
    /// on Fixed64. For each key group it emits one TRow with
    ///   Int32   = sum(Modifier * TRow.Int32) + sum(100 * length(TRow2.StringField)),
    ///   Fixed64 = the group's key.
    class TMultipleTableReducer
        : public NYT::IReducer<NYT::TTableReader<::google::protobuf::Message>, NYT::TTableWriter<NRTYT::NTesting::TRow>>
    {
        // Multiplier applied to TRow.Int32; shipped to the job via Y_SAVELOAD_JOB.
        int Modifier = 1;
    public:
        TMultipleTableReducer() = default;
        explicit TMultipleTableReducer(int mod)
            : Modifier(mod)
        {
        }

        void Do(TReader* input, TWriter* output) override {
            int result = 0;
            // Fixed64 is an unsigned 64-bit column; keep its full width instead of
            // truncating it to int.
            ui64 key = 0;
            for (; input->IsValid(); input->Next()) {
                if (input->GetTableIndex() == 0) {
                    const auto& row = input->GetRow<TRow>();
                    result += Modifier * row.GetInt32();
                    key = row.GetFixed64();
                } else {
                    const auto& row = input->GetRow<TRow2>();
                    result += static_cast<int>(row.GetStringField().size()) * 100;
                    key = row.GetFixed64();
                }
            }
            NRTYT::NTesting::TRow row;
            row.SetInt32(result);
            row.SetFixed64(key);
            output->AddRow(row);
        }

        Y_SAVELOAD_JOB(Modifier);
    };
    REGISTER_RTYT_REDUCER(TMultipleTableReducer);

    /// Same join/aggregation as TMultipleTableReducer (TRow table at index 0, TRow2
    /// at any other index, grouped by Fixed64), but the result is routed by key parity:
    ///  - even key -> output 0 as TRow(Int32 = aggregate, Fixed64 = key);
    ///  - odd key  -> output 1 as TRow2 with the aggregate stored in Fixed64
    ///    (the key itself is not written there).
    class TMultipleWritersReducer
        : public NYT::IReducer<NYT::TTableReader<::google::protobuf::Message>, NYT::TTableWriter<::google::protobuf::Message>>
    {
        // Multiplier applied to TRow.Int32; shipped to the job via Y_SAVELOAD_JOB.
        int Modifier = 1;
    public:
        TMultipleWritersReducer() = default;
        explicit TMultipleWritersReducer(int mod)
            : Modifier(mod)
        {
        }

        void Do(TReader* input, TWriter* output) override {
            int result = 0;
            // Fixed64 is an unsigned 64-bit column; keep its full width instead of
            // truncating it to int (the parity check below must see the real key).
            ui64 key = 0;
            for (; input->IsValid(); input->Next()) {
                if (input->GetTableIndex() == 0) {
                    const auto& row = input->GetRow<TRow>();
                    result += Modifier * row.GetInt32();
                    key = row.GetFixed64();
                } else {
                    const auto& row = input->GetRow<TRow2>();
                    result += static_cast<int>(row.GetStringField().size()) * 100;
                    key = row.GetFixed64();
                }
            }
            if (key % 2 == 0) {
                NRTYT::NTesting::TRow row;
                row.SetInt32(result);
                row.SetFixed64(key);
                output->AddRow<TRow>(row, 0);
            } else {
                NRTYT::NTesting::TRow2 row2;
                row2.SetFixed64(result);
                output->AddRow<TRow2>(row2, 1);
            }
        }

        Y_SAVELOAD_JOB(Modifier);
    };
    REGISTER_RTYT_REDUCER(TMultipleWritersReducer);

    /// Reducer joining a TRow table (index 0) with a TRow2 table (any other index)
    /// on Fixed64. For each group it concatenates TRow.String / TRow2.StringField in
    /// input order and emits one TRow with that String and Fixed64 = the group key
    /// (Int32 is left unset).
    class TMultiConcatReducer
        : public NYT::IReducer<NYT::TTableReader<::google::protobuf::Message>, NYT::TTableWriter<NRTYT::NTesting::TRow>>
    {
    public:
        TMultiConcatReducer() = default;

        void Do(TReader* input, TWriter* output) override {
            TString result;
            // Fixed64 is an unsigned 64-bit column; keep its full width instead of
            // truncating it to int.
            ui64 key = 0;
            for (; input->IsValid(); input->Next()) {
                if (input->GetTableIndex() == 0) {
                    const auto& row = input->GetRow<TRow>();
                    result += row.GetString();
                    key = row.GetFixed64();
                } else {
                    const auto& row = input->GetRow<TRow2>();
                    result += row.GetStringField();
                    key = row.GetFixed64();
                }
            }
            NRTYT::NTesting::TRow row;
            row.SetString(result);
            row.SetFixed64(key);
            output->AddRow(row);
        }
    };
    REGISTER_RTYT_REDUCER(TMultiConcatReducer);
}

Y_UNIT_TEST_SUITE(RTYT_MULTIPLE_TABLES) {
    using namespace NTestMultiple;
    // Maps two TRow tables through TProtoMultiplyingMapper (Fixed64 := Int32 * Fixed64)
    // and checks that the output holds exactly the products {1, 0, 2, 4}, each once.
    Y_UNIT_TEST(MAP_MULTIPLE_TABLES) {
        using NRTYT::NTesting::TRow;
        NRTYT::TClientBase client(".");
        client.Create("//test1", NYT::NT_MAP, NYT::TCreateOptions());
        TRow sample;
        {
            auto writer = client.CreateTableWriter("//test1/test2", *sample.GetDescriptor());
            TRow first;
            first.SetString("abcabc");
            first.SetInt32(1);
            first.SetFixed64(1);
            writer->AddRow(first, 0);

            TRow second;
            second.SetString("");
            second.SetInt32(0);
            second.SetFixed64(2);
            writer->AddRow(second, 0);
        }
        {
            auto writer = client.CreateTableWriter("//test1/test3", *sample.GetDescriptor());
            TRow first;
            first.SetString("abcabc");
            first.SetInt32(1);
            first.SetFixed64(2);
            writer->AddRow(first, 0);

            TRow second;
            second.SetString("");
            second.SetInt32(2);
            second.SetFixed64(2);
            writer->AddRow(second, 0);
        }

        NYT::TMapOperationSpec spec;
        spec.AddInput<NRTYT::NTesting::TRow>(NYT::TRichYPath("//test1/test2"))
            .AddInput<NRTYT::NTesting::TRow>(NYT::TRichYPath("//test1/test3"))
            .AddOutput<NRTYT::NTesting::TRow>(NYT::TRichYPath("//test1/output"));
        client.Map(spec, new TProtoMultiplyingMapper(), NYT::TOperationOptions());

        auto reader = client.CreateTableReader<NRTYT::NTesting::TRow>(NYT::TRichYPath("//test1/output"));
        TRow row;
        // rows[v] records whether an output row with Fixed64 == v has been seen.
        TVector<bool> rows(5, false);
        for (int i = 0; i < 4; i++) {
            row = reader->GetRow();
            // Bounds-check before indexing: an unexpected product must fail the
            // test instead of indexing past the end of `rows`.
            UNIT_ASSERT_C(row.GetFixed64() < rows.size(), "unexpected Fixed64 value in output");
            UNIT_ASSERT_C(!rows[row.GetFixed64()], "We already read this row");
            rows[row.GetFixed64()] = true;
            reader->Next();
        }
        UNIT_CHECK_GENERATED_EXCEPTION_C(row = reader->GetRow(), yexception, "end of table");
        UNIT_ASSERT_C(rows[0], "we should've read this row");
        UNIT_ASSERT_C(rows[1], "we should've read this row");
        UNIT_ASSERT_C(rows[2], "we should've read this row");
        UNIT_ASSERT_C(!rows[3], "Only this row we shouldn't read");
        UNIT_ASSERT_C(rows[4], "we should've read this row");
    }

    // Same map as MAP_MULTIPLE_TABLES, but the second input (//test1/test3) is never
    // written; only the two products {1, 0} from //test1/test2 are expected.
    Y_UNIT_TEST(MAP_MULTIPLE_TABLES_WITH_NONEXISTENT) {
        using NRTYT::NTesting::TRow;
        NRTYT::TClientBase client(".");
        client.Create("//test1", NYT::NT_MAP, NYT::TCreateOptions());
        TRow sample;
        {
            auto writer = client.CreateTableWriter("//test1/test2", *sample.GetDescriptor());
            TRow first;
            first.SetString("abcabc");
            first.SetInt32(1);
            first.SetFixed64(1);
            writer->AddRow(first, 0);

            TRow second;
            second.SetString("");
            second.SetInt32(0);
            second.SetFixed64(2);
            writer->AddRow(second, 0);
        }

        NYT::TMapOperationSpec spec;
        spec.AddInput<NRTYT::NTesting::TRow>(NYT::TRichYPath("//test1/test2"))
            .AddInput<NRTYT::NTesting::TRow>(NYT::TRichYPath("//test1/test3"))
            .AddOutput<NRTYT::NTesting::TRow>(NYT::TRichYPath("//test1/output"));
        client.Map(spec, new TProtoMultiplyingMapper(), NYT::TOperationOptions());

        auto reader = client.CreateTableReader<NRTYT::NTesting::TRow>(NYT::TRichYPath("//test1/output"));
        TRow row;
        // rows[v] records whether an output row with Fixed64 == v has been seen.
        TVector<bool> rows(5, false);
        for (int i = 0; i < 2; i++) {
            row = reader->GetRow();
            // Bounds-check before indexing: an unexpected product must fail the
            // test instead of indexing past the end of `rows`.
            UNIT_ASSERT_C(row.GetFixed64() < rows.size(), "unexpected Fixed64 value in output");
            UNIT_ASSERT_C(!rows[row.GetFixed64()], "We already read this row");
            rows[row.GetFixed64()] = true;
            reader->Next();
        }
        UNIT_CHECK_GENERATED_EXCEPTION_C(row = reader->GetRow(), yexception, "end of table");
        UNIT_ASSERT_C(rows[0], "we should've read this row");
        UNIT_ASSERT_C(rows[1], "we should've read this row");
        UNIT_ASSERT_C(!rows[2], "we shouldn't read this row");
        UNIT_ASSERT_C(!rows[3], "we shouldn't read this row");
        UNIT_ASSERT_C(!rows[4], "we shouldn't read this row");
    }

    // Maps a TRow table and a TRow2 table through TProtoMultipleTablesMapper and
    // checks the four expected (Fixed64, Int32) output pairs, each read exactly once.
    Y_UNIT_TEST(MAP_TABLES_WITH_DIFFERENT_PROTO) {
        using NRTYT::NTesting::TRow;
        using NRTYT::NTesting::TRow2;
        NRTYT::TClientBase client(".");
        client.Create("//test1", NYT::NT_MAP, NYT::TCreateOptions());
        TRow sample;
        {
            auto writer = client.CreateTableWriter("//test1/test2", *sample.GetDescriptor());
            TRow first;
            first.SetString("abcabc");
            first.SetInt32(1);
            first.SetFixed64(1);
            writer->AddRow(first, 0);

            TRow second;
            second.SetString("");
            second.SetInt32(0);
            second.SetFixed64(2);
            writer->AddRow(second, 0);
        }
        {
            TRow2 sample2;
            auto writer = client.CreateTableWriter("//test1/test3", *sample2.GetDescriptor());
            TRow2 first;
            first.SetStringField("abcabc");
            first.SetFixed64(2);
            writer->AddRow(first, 0);

            TRow2 second;
            second.SetStringField("length = 11");
            second.SetFixed64(2);
            writer->AddRow(second, 0);
        }

        NYT::TMapOperationSpec spec;
        spec.AddInput<NRTYT::NTesting::TRow>(NYT::TRichYPath("//test1/test2"))
            .AddInput<NRTYT::NTesting::TRow2>(NYT::TRichYPath("//test1/test3"))
            .AddOutput<NRTYT::NTesting::TRow>(NYT::TRichYPath("//test1/output"));
        client.Map(spec, new TProtoMultipleTablesMapper(), NYT::TOperationOptions());

        auto reader = client.CreateTableReader<NRTYT::NTesting::TRow>(NYT::TRichYPath("//test1/output"));
        // (Fixed64, Int32) of an output row. The alias keeps the template comma out
        // of the UNIT_ASSERT_C macro arguments and avoids std::make_pair with
        // explicit template arguments.
        using TRowKey = std::pair<int, int>;
        TRow row;
        TMap<TRowKey, bool> readRows;
        for (int i = 0; i < 4; i++) {
            row = reader->GetRow();
            const TRowKey key(static_cast<int>(row.GetFixed64()), row.GetInt32());
            UNIT_ASSERT_C(!readRows[key], "We already read this row");
            readRows[key] = true;
            reader->Next();
        }
        UNIT_CHECK_GENERATED_EXCEPTION_C(row = reader->GetRow(), yexception, "end of table");
        UNIT_ASSERT_C(readRows[TRowKey(1, 0)], "we should've read this row (first table, 1 * 1)");
        UNIT_ASSERT_C(readRows[TRowKey(0, 0)], "we should've read this row (first table, 0 * 2)");
        UNIT_ASSERT_C(readRows[TRowKey(6, 1)], "we should've read this row (second table, length(abcabc))");
        UNIT_ASSERT_C(readRows[TRowKey(11, 1)], "we should've read this row (second table, length(length = 11))");
    }
    // Sorts a TRow table and a TRow2 table, reduces them by Fixed64 with
    // TMultipleTableReducer(5), and checks that exactly one aggregate row per key is produced.
    Y_UNIT_TEST(REDUCE_MULTIPLE_TABLES) {
        using NRTYT::NTesting::TRow;
        TRow sample;
        NRTYT::TClientBase client(".");
        client.Create("//test1", NYT::NT_MAP, NYT::TCreateOptions());
        client.Create("//test1/test2", NYT::NT_TABLE);
        {
            auto writer = client.CreateTableWriter("//test1/test2", *sample.GetDescriptor());
            TRow first;
            first.SetFixed64(0);
            first.SetInt32(1);
            writer->AddRow(first, 0);

            TRow second;
            second.SetFixed64(0);
            second.SetInt32(2);
            writer->AddRow(second, 0);

            second.SetFixed64(1);
            second.SetInt32(2);
            writer->AddRow(second, 0);
        }
        {
            TRow2 sample2;
            auto writer = client.CreateTableWriter("//test1/test3", *sample2.GetDescriptor());
            TRow2 first;
            first.SetFixed64(0);
            first.SetStringField("1234567");
            writer->AddRow(first, 0);

            first.SetFixed64(1);
            first.SetStringField("1234");
            writer->AddRow(first, 0);

            first.SetFixed64(1);
            first.SetStringField("12345");
            writer->AddRow(first, 0);
        }

        {
            NYT::TSortOperationSpec spec;
            spec.AddInput(NYT::TRichYPath("//test1/test2"))
                .Output(NYT::TRichYPath("//test1/output1"));
            spec.SortBy({"Fixed64", "Int32"});
            client.Sort(spec);
        }
        Cerr << "Sort#1 finished" << Endl;

        {
            NYT::TSortOperationSpec spec;
            spec.AddInput(NYT::TRichYPath("//test1/test3"))
                .Output(NYT::TRichYPath("//test1/output2"));
            spec.SortBy({"Fixed64"});
            client.Sort(spec);
        }
        Cerr << "Sort#2 finished" << Endl;

        NYT::TReduceOperationSpec spec;
        spec.AddInput<TRow>(NYT::TRichYPath("//test1/output1"))
            .AddInput<TRow2>(NYT::TRichYPath("//test1/output2"))
            .AddOutput<TRow>(NYT::TRichYPath("//test1/output")).ReduceBy({"Fixed64"});
        client.Reduce(spec, new TMultipleTableReducer(5), NYT::TOperationOptions());

        auto reader = client.CreateTableReader<NRTYT::NTesting::TRow>(NYT::TRichYPath("//test1/output"));
        TRow row;
        row = reader->GetRow();
        UNIT_ASSERT_VALUES_EQUAL_C(row.GetInt32(), 715, "first range, three rows, length('1234567') * 100 + 1 * 5 + 2 * 5");
        UNIT_ASSERT_VALUES_EQUAL(row.GetFixed64(), 0);
        reader->Next();
        row = reader->GetRow();
        UNIT_ASSERT_VALUES_EQUAL_C(row.GetInt32(), 910, "second range, three rows, length('1234') * 100 + length('12345') * 100 + 2 * 5");
        UNIT_ASSERT_VALUES_EQUAL(row.GetFixed64(), 1);
        // Two distinct keys -> exactly two output rows; nothing may follow.
        reader->Next();
        UNIT_CHECK_GENERATED_EXCEPTION_C(row = reader->GetRow(), yexception, "end of table");
    }
    // Reduces a TRow table and a TRow2 table with TMultipleWritersReducer(5), which
    // writes the even key (0) to the first output and the odd key (1) to the second.
    // Each output must contain exactly one row.
    Y_UNIT_TEST(REDUCE_MTABLES_MWRITERS) {
        using NRTYT::NTesting::TRow;
        TRow sample;
        NRTYT::TClientBase client(".");
        client.Create("//test1", NYT::NT_MAP, NYT::TCreateOptions());
        client.Create("//test1/test2", NYT::NT_TABLE);
        {
            auto writer = client.CreateTableWriter("//test1/test2", *sample.GetDescriptor());
            TRow first;
            first.SetFixed64(0);
            first.SetInt32(1);
            writer->AddRow(first, 0);

            TRow second;
            second.SetFixed64(0);
            second.SetInt32(2);
            writer->AddRow(second, 0);

            second.SetFixed64(1);
            second.SetInt32(2);
            writer->AddRow(second, 0);
        }
        {
            TRow2 sample2;
            auto writer = client.CreateTableWriter("//test1/test3", *sample2.GetDescriptor());
            TRow2 first;
            first.SetFixed64(0);
            first.SetStringField("1234567");
            writer->AddRow(first, 0);

            first.SetFixed64(1);
            first.SetStringField("1234");
            writer->AddRow(first, 0);

            first.SetFixed64(1);
            first.SetStringField("12345");
            writer->AddRow(first, 0);
        }

        {
            NYT::TSortOperationSpec spec;
            spec.AddInput(NYT::TRichYPath("//test1/test2"))
                .Output(NYT::TRichYPath("//test1/sort_out1"));
            spec.SortBy({"Fixed64", "Int32"});
            client.Sort(spec);
        }
        Cerr << "Sort#1 finished" << Endl;

        {
            NYT::TSortOperationSpec spec;
            spec.AddInput(NYT::TRichYPath("//test1/test3"))
                .Output(NYT::TRichYPath("//test1/sort_out2"));
            spec.SortBy({"Fixed64"});
            client.Sort(spec);
        }
        Cerr << "Sort#2 finished" << Endl;

        NYT::TReduceOperationSpec spec;
        spec.AddInput<TRow>(NYT::TRichYPath("//test1/sort_out1"))
            .AddInput<TRow2>(NYT::TRichYPath("//test1/sort_out2"))
            .AddOutput<TRow>(NYT::TRichYPath("//test1/mwriters-output1"))
            .AddOutput<TRow2>(NYT::TRichYPath("//test1/mwriters-output2")).ReduceBy({"Fixed64"});
        client.Reduce(spec, new TMultipleWritersReducer(5), NYT::TOperationOptions());

        {
            auto reader = client.CreateTableReader<NRTYT::NTesting::TRow>(NYT::TRichYPath("//test1/mwriters-output1"));
            TRow row;
            row = reader->GetRow();
            UNIT_ASSERT_VALUES_EQUAL_C(row.GetInt32(), 715, "first range, three rows, length('1234567') * 100 + 1 * 5 + 2 * 5");
            UNIT_ASSERT_VALUES_EQUAL(row.GetFixed64(), 0);
            // Only the even key goes to the first output.
            reader->Next();
            UNIT_CHECK_GENERATED_EXCEPTION_C(row = reader->GetRow(), yexception, "end of table");
            Cerr << "first table is ok" << Endl;
        }
        {
            auto reader = client.CreateTableReader<NRTYT::NTesting::TRow2>(NYT::TRichYPath("//test1/mwriters-output2"));
            TRow2 row = reader->GetRow();
            UNIT_ASSERT_VALUES_EQUAL_C(row.GetFixed64(), 910, "second range, three rows, length('1234') * 100 + length('12345') * 100 + 2 * 5");
            // Only the odd key goes to the second output.
            reader->Next();
            UNIT_CHECK_GENERATED_EXCEPTION_C(row = reader->GetRow(), yexception, "end of table");
            Cerr << "second table is ok" << Endl;
        }
    }
    // Reduces by Fixed64 while sorting by (Fixed64, Int32), so rows from the two
    // input tables interleave by Int32 inside each group; TMultiConcatReducer then
    // concatenates their strings in that order.
    Y_UNIT_TEST(SORTED_REDUCE_MULTIPLE_TABLES) {
        using NRTYT::NTesting::TRow;
        TRow sample;
        NRTYT::TClientBase client(".");
        client.Create("//test1", NYT::NT_MAP, NYT::TCreateOptions());
        client.Create("//test1/test2", NYT::NT_TABLE);
        {
            auto writer = client.CreateTableWriter("//test1/test2", *sample.GetDescriptor());
            TRow first;
            first.SetFixed64(0);
            first.SetInt32(3);
            first.SetString("012");
            writer->AddRow(first, 0);

            TRow second;
            second.SetFixed64(0);
            second.SetInt32(5);
            second.SetString("678");
            writer->AddRow(second, 0);

            second.SetFixed64(1);
            second.SetInt32(1);
            second.SetString("456");
            writer->AddRow(second, 0);
        }
        {
            TRow2 sample2;
            auto writer = client.CreateTableWriter("//test1/test3", *sample2.GetDescriptor());
            TRow2 first;
            first.SetFixed64(0);
            first.SetStringField("345");
            first.SetInt32(4);
            writer->AddRow(first, 0);

            first.SetFixed64(1);
            first.SetStringField("123");
            first.SetInt32(0);
            writer->AddRow(first, 0);

            first.SetFixed64(1);
            first.SetStringField("789");
            first.SetInt32(2);
            writer->AddRow(first, 0);
        }
        /*
         * Input tables, listed in (Fixed64, Int32) order; the two Sort operations
         * below write them in that order to output1 / output2:
         *
         * test2 -> output1
         * Fixed64, Int32, String
         * 0        3      012
         * 0        5      678
         * 1        1      456
         *
         * test3 -> output2
         * Fixed64, Int32, StringField
         * 0        4      345
         * 1        0      123
         * 1        2      789
         *
         * The reducer's grouped input should look like this:
         * Group1 (Fixed64 = 0)
         * Row(Fixed64 = 0, Int32 = 3, String = 012)
         * Row2(Fixed64 = 0, Int32 = 4, StringField = 345)
         * Row(Fixed64 = 0, Int32 = 5, String = 678)
         *
         * Group2 (Fixed64 = 1)
         * Row2(Fixed64 = 1, Int32 = 0, StringField = 123)
         * Row(Fixed64 = 1, Int32 = 1, String = 456)
         * Row2(Fixed64 = 1, Int32 = 2, StringField = 789)
         *
         * TMultiConcatReducer sets only String and Fixed64 (Int32 stays 0), so the
         * expected output is:
         * Fixed64, Int32, String
         * 0        0      012345678
         * 1        0      123456789
         */

        {
            NYT::TSortOperationSpec spec;
            spec.AddInput(NYT::TRichYPath("//test1/test2"))
                .Output(NYT::TRichYPath("//test1/output1"));
            spec.SortBy({"Fixed64", "Int32"});
            client.Sort(spec);
        }
        Cerr << "Sort#1 finished" << Endl;

        {
            NYT::TSortOperationSpec spec;
            spec.AddInput(NYT::TRichYPath("//test1/test3"))
                .Output(NYT::TRichYPath("//test1/output2"));
            spec.SortBy({"Fixed64", "Int32"});
            client.Sort(spec);
        }
        Cerr << "Sort#2 finished" << Endl;

        NYT::TReduceOperationSpec spec;
        spec.AddInput<TRow>(NYT::TRichYPath("//test1/output1"))
            .AddInput<TRow2>(NYT::TRichYPath("//test1/output2"))
            .AddOutput<TRow>(NYT::TRichYPath("//test1/output"))
            .ReduceBy({"Fixed64"})
            .SortBy({"Fixed64", "Int32"});
        client.Reduce(spec, new TMultiConcatReducer(), NYT::TOperationOptions());

        auto reader = client.CreateTableReader<NRTYT::NTesting::TRow>(NYT::TRichYPath("//test1/output"));
        TRow row;
        row = reader->GetRow();
        UNIT_ASSERT_VALUES_EQUAL_C(row.GetString(), "012345678", "first range incorrect");
        UNIT_ASSERT_VALUES_EQUAL(row.GetFixed64(), 0);
        reader->Next();
        row = reader->GetRow();
        UNIT_ASSERT_VALUES_EQUAL_C(row.GetString(), "123456789", "second range incorrect");
        UNIT_ASSERT_VALUES_EQUAL(row.GetFixed64(), 1);
        // Two distinct keys -> exactly two output rows; nothing may follow.
        reader->Next();
        UNIT_CHECK_GENERATED_EXCEPTION_C(row = reader->GetRow(), yexception, "end of table");
    }
}
