#include "../lib/column_names.h"
#include "../lib/common.h"
#include "../lib/table_names.h"
#include "../lib/yt_operations.h"

#include <library/cpp/testing/unittest/registar.h>
#include <mapreduce/yt/tests/yt_unittest_lib/yt_unittest_lib.h>
#include <maps/libs/geolib/include/polygon.h>

namespace maps::wiki::route_lost_feedback::tests {

namespace {

/// Builds the set of intermediate ("tmp") table paths used by the tests.
/// Every path is a fixed name placed directly under `//tmp/`.
YtTmpTableNames createTestYtTmpTableNames()
{
    const auto underTmp = [](const std::string& leaf) {
        return "//tmp/" + leaf;
    };

    YtTmpTableNames names;

    // Raw route-lost clusters and their sorted copy.
    names.routeLostClusters = underTmp("lost");
    names.routeLostClustersSorted = underTmp("lost_sorted");

    // Per-segment track statistics and their sorted copy.
    names.trackStat = underTmp("track");
    names.trackStatSorted = underTmp("track_sorted");

    // Join / scoring outputs.
    names.routeLostJoinTrackStat = underTmp("lost_join_track");
    names.routeLostScored = underTmp("lost_scored");
    names.routeLostScoredSorted = underTmp("lost_scored_sorted");

    return names;
}

/// Creates an empty table at `tableName`, creating missing parent nodes.
///
/// Any node already present at that path is replaced (`Force`). All tests in
/// this file reuse the same fixed table paths, and `writeToTable` appends
/// rows. Keeping an existing table (the previous `IgnoreExisting` behaviour)
/// would let rows from an earlier test or run leak into the next one.
///
/// @param client     YT client to issue the `create` request through.
/// @param tableName  Cypress path of the table to (re)create.
void createTable(NYT::IClient& client, const std::string& tableName)
{
    client.Create(
        TString(tableName),
        NYT::NT_TABLE,
        NYT::TCreateOptions()
            .Recursive(true)
            // Force and IgnoreExisting are mutually exclusive; Force is what
            // guarantees the table starts out empty.
            .Force(true)
    );
}

/// Appends `rows` to the table at `tableName` and commits the write.
///
/// The table is expected to exist already (see createTable). The writer is
/// opened in append mode, so existing rows are kept.
///
/// @param client     YT client used to open the table writer.
/// @param tableName  Cypress path of the target table.
/// @param rows       Rows to append, in order.
void writeToTable(
    NYT::IClient& client,
    const std::string& tableName,
    const TVector<NYT::TNode>& rows)
{
    auto target = NYT::TRichYPath(TString(tableName));
    target.Append(true);

    const auto writer = client.CreateTableWriter<NYT::TNode>(target);
    for (const NYT::TNode& node : rows) {
        writer->AddRow(node);
    }

    // Nothing is committed until Finish() succeeds.
    writer->Finish();
}

} // unnamed namespace

Y_UNIT_TEST_SUITE(suite) {

// Two route-lost rows share the (PERSISTENT_ID, SEGMENT_INDEX) key of a single
// track-stat row. The test expects both of them in the joined table, each
// carrying that row's TRACKS_NUMBER.
Y_UNIT_TEST(join_lost_and_track)
{
    auto clientPtr = NYT::NTesting::CreateTestClient();
    auto& client = *clientPtr;

    auto tmpTables = createTestYtTmpTableNames();

    // Prepare route-lost table
    //
    createTable(client, tmpTables.routeLostClusters);

    writeToTable(
        client,
        tmpTables.routeLostClusters,
        {
            NYT::TNode::CreateMap()
                (column_names::KEY, "key1")
                (column_names::PERSISTENT_ID, 1u)
                (column_names::SEGMENT_INDEX, 1u),
            NYT::TNode::CreateMap()
                (column_names::KEY, "key2")
                (column_names::PERSISTENT_ID, 1u)
                (column_names::SEGMENT_INDEX, 1u)
        }
    );

    // NOTE(review): the 'С' in sortByGivenСolumns looks like a Cyrillic
    // letter; it has to match the declaration in ../lib — consider renaming
    // it there.
    sortByGivenСolumns(
        client,
        tmpTables.routeLostClusters,
        tmpTables.routeLostClustersSorted,
        {column_names::PERSISTENT_ID, column_names::SEGMENT_INDEX}
    );

    // Prepare tracks table
    //
    createTable(client, tmpTables.trackStat);

    writeToTable(
        client,
        tmpTables.trackStat,
        {
            NYT::TNode::CreateMap()
                (column_names::PERSISTENT_ID, 1u)
                (column_names::SEGMENT_INDEX, 1u)
                (column_names::TRACKS_NUMBER, 10u)
        }
    );

    sortByGivenСolumns(
        client,
        tmpTables.trackStat,
        tmpTables.trackStatSorted,
        {column_names::PERSISTENT_ID, column_names::SEGMENT_INDEX}
    );

    // Run operation
    //
    YtOperations ytOps(
        clientPtr,
        geolib3::Polygon2(),
        YtInputTableNames(),
        tmpTables
    );

    ytOps.joinRouteLostAndTracks();

    // Check result: exactly two rows, both with TRACKS_NUMBER == 10. Their
    // relative order is not asserted because both have the same sort key.
    //
    auto reader = client.CreateTableReader<NYT::TNode>(
        TString(tmpTables.routeLostJoinTrackStat)
    );

    // Each GetRow() is guarded by IsValid(). An empty or short result then
    // shows up as an assertion failure, not as a read past the end.
    UNIT_ASSERT(reader->IsValid());
    auto row = reader->GetRow();
    UNIT_ASSERT_VALUES_EQUAL(row[column_names::TRACKS_NUMBER].AsUint64(), 10u);

    reader->Next();
    UNIT_ASSERT(reader->IsValid());

    row = reader->GetRow();
    UNIT_ASSERT_VALUES_EQUAL(row[column_names::TRACKS_NUMBER].AsUint64(), 10u);

    reader->Next();
    UNIT_ASSERT(!reader->IsValid());
}

// A single route-lost row is joined against an empty track-stat table.
// The test expects the joined table to come out empty.
Y_UNIT_TEST(join_lost_and_track__no_tracks)
{
    auto tmpTables = createTestYtTmpTableNames();

    auto testClient = NYT::NTesting::CreateTestClient();
    NYT::IClient& yt = *testClient;

    // --- Route-lost input: one row with key (1, 1) ---
    const TVector<NYT::TNode> lostRows{
        NYT::TNode::CreateMap()
            (column_names::KEY, "key")
            (column_names::PERSISTENT_ID, 1u)
            (column_names::SEGMENT_INDEX, 1u)
    };

    createTable(yt, tmpTables.routeLostClusters);
    writeToTable(yt, tmpTables.routeLostClusters, lostRows);

    // NOTE(review): the 'С' in sortByGivenСolumns looks like a Cyrillic
    // letter; it has to match the declaration in ../lib.
    sortByGivenСolumns(
        yt,
        tmpTables.routeLostClusters,
        tmpTables.routeLostClustersSorted,
        {column_names::PERSISTENT_ID, column_names::SEGMENT_INDEX}
    );

    // --- Track-stat input: created but intentionally left without rows ---
    createTable(yt, tmpTables.trackStat);
    sortByGivenСolumns(
        yt,
        tmpTables.trackStat,
        tmpTables.trackStatSorted,
        {column_names::PERSISTENT_ID, column_names::SEGMENT_INDEX}
    );

    // --- Join ---
    YtOperations operations(
        testClient,
        geolib3::Polygon2(),
        YtInputTableNames(),
        tmpTables
    );
    operations.joinRouteLostAndTracks();

    // --- Expect no output rows ---
    auto joined = yt.CreateTableReader<NYT::TNode>(
        TString(tmpTables.routeLostJoinTrackStat)
    );
    UNIT_ASSERT(!joined->IsValid());
}

// Two travel-time rows for the same (PERSISTENT_ID, SEGMENT_INDEX) segment:
// one starts at (50, 60), inside the region polygon built below, and the other
// at (30, 30), outside it. The test expects the statistic to count only one
// track for the segment.
Y_UNIT_TEST(tracks_stat)
{
    auto clientPtr = NYT::NTesting::CreateTestClient();
    auto& client = *clientPtr;

    YtInputTableNames inputTables;
    inputTables.tracks = {"//daily_travel_times"};

    auto tmpTables = createTestYtTmpTableNames();

    // Prepare tracks travel times table
    //
    createTable(client, inputTables.tracks.front());

    writeToTable(
        client,
        inputTables.tracks.front(),
        {
            NYT::TNode::CreateMap()
                (column_names::PERSISTENT_ID, 10u)
                (column_names::SEGMENT_INDEX, 11u)
                (column_names::START_LON, 50.)
                (column_names::START_LAT, 60.),
            NYT::TNode::CreateMap()
                (column_names::PERSISTENT_ID, 10u)
                (column_names::SEGMENT_INDEX, 11u)
                (column_names::START_LON, 30.)
                (column_names::START_LAT, 30.)
        }
    );

    // Run operation
    //
    YtOperations ytOps(
        clientPtr,
        geolib3::Polygon2(
            geolib3::PointsVector({
                {49, 59}, {49, 61}, {51, 61}, {51, 59}
            })
        ),
        inputTables,
        tmpTables
    );

    ytOps.calcTrackStat();
    ytOps.sortTrackStat();

    // Check result: exactly one row for segment (10, 11) with one track.
    //
    auto reader = client.CreateTableReader<NYT::TNode>(
        TString(tmpTables.trackStatSorted)
    );

    // Guard GetRow(): an empty result must fail the test with a clear
    // assertion rather than by reading from an exhausted reader.
    UNIT_ASSERT(reader->IsValid());

    auto row = reader->GetRow();
    UNIT_ASSERT_VALUES_EQUAL(row[column_names::PERSISTENT_ID].AsUint64(), 10u);
    UNIT_ASSERT_VALUES_EQUAL(row[column_names::SEGMENT_INDEX].AsUint64(), 11u);
    UNIT_ASSERT_VALUES_EQUAL(row[column_names::TRACKS_NUMBER].AsUint64(), 1u);

    reader->Next();
    UNIT_ASSERT(!reader->IsValid());
}

}

} // namespace maps::wiki::route_lost_feedback::tests
