#include <library/cpp/testing/common/env.h>
#include <library/cpp/testing/gtest/gtest.h>
#include <maps/wikimap/mapspro/services/mrc/libs/graph_matcher_adapter/include/feature_positioner.h>
#include <maps/libs/geolib/include/direction.h>
#include <maps/libs/geolib/include/distance.h>

#include <random>

namespace maps {
namespace mrc {
namespace adapters {
namespace tests {
namespace {

// Source (device) id used for every synthetic track point and feature in these tests.
const std::string SOURCE_ID = "iPhone100500";
// Test graph data location (resolved with BinaryPath); loaded by CompactGraphMatcherAdapter.
const std::string TEST_GRAPH_PATH = BinaryPath("maps/data/test/graph3");

/**
 * Builds a three-point track for SOURCE_ID with points one second apart.
 *
 * The input GPS signal has one-second (time_t) timestamp resolution, so
 * consecutive samples are spaced by exactly one second.
 * @see
 * https://a.yandex-team.ru/arc/trunk/arcadia/maps/doc/proto/analyzer/gpssignal.proto
 */
db::TrackPoints makeTrackPoints()
{
    struct Sample {
        const char* time;
        double lon;
        double lat;
    };
    static constexpr Sample SAMPLES[] = {
        {"2017-05-17 11:03:16+03", 37.6698835, 55.7281365},
        {"2017-05-17 11:03:17+03", 37.670128, 55.728180},
        {"2017-05-17 11:03:18+03", 37.670373, 55.7282245},
    };

    db::TrackPoints track;
    for (const auto& sample : SAMPLES) {
        track.emplace_back()
            .setSourceId(SOURCE_ID)
            .setTimestamp(chrono::parseSqlDateTime(sample.time))
            .setGeodeticPos(geolib3::Point2(sample.lon, sample.lat));
    }
    return track;
}

std::vector<const Matcher*> makeMatchers()
{
    static const CompactGraphMatcherAdapter fbMatcher(TEST_GRAPH_PATH);
    return {&fbMatcher};
}

/**
 * Appends one synthetic straight-line ride from POINT_A towards POINT_B.
 *
 * The ride has int(DISTANCE / metersPerSecond) steps. Step i appends a track
 * point `metersPerSecond * i` meters from POINT_A, 500 ms after the previous
 * step, followed by a feature 500 ms after that point. Every feature therefore
 * falls between two consecutive track points, except the last one, which lies
 * after the ride's final track point. That feature is dropped so only
 * features inside the ride's time range remain.
 *
 * @param timestamp        ride start; the first track point is 500 ms later
 * @param metersPerSecond  ride speed; must be positive
 * @param trackPoints      the ride's track points are appended here
 * @param features         the ride's features are appended here
 *
 * If the ride is too short to produce a single step, or the speed is not
 * positive, nothing is appended and nothing is removed. In particular, a
 * feature from an earlier ride is never popped.
 */
void addRide(chrono::TimePoint timestamp,
             double metersPerSecond,
             db::TrackPoints& trackPoints,
             db::Features& features)
{
    using namespace std::chrono_literals;

    static const auto POINT_A = geolib3::Point2{37.606339, 55.690217};
    static const auto POINT_B = geolib3::Point2{37.606824, 55.691011};
    static const auto VECTOR = (POINT_B - POINT_A).unit();
    static const auto HEADING = geolib3::Direction2(VECTOR).heading();
    static const auto DISTANCE = geolib3::fastGeoDistance(POINT_A, POINT_B);

    // A zero, negative or NaN speed would make the int conversion below
    // meaningless (or UB on division by zero).
    if (!(metersPerSecond > 0)) {
        return;
    }
    const auto seconds = static_cast<int>(DISTANCE / metersPerSecond);
    if (seconds <= 0) {
        // No steps were added, so there is no trailing feature of ours to drop.
        return;
    }

    for (int second = 0; second < seconds; ++second) {
        trackPoints.emplace_back()
            .setSourceId(SOURCE_ID)
            .setTimestamp(timestamp += 500ms)
            .setGeodeticPos(
                fastGeoShift(POINT_A, VECTOR * (metersPerSecond * second)))
            .setHeading(HEADING);
        features.push_back(
            sql_chemistry::GatewayAccess<db::Feature>::construct()
                .setSourceId(SOURCE_ID)
                .setSize(1080, 920)
                .setTimestamp(timestamp += 500ms));
    }
    // The last feature lies after the last track point; keep inner ones only.
    features.pop_back();
}

// Comparator for std::min_element / std::max_element: true when `first` has a
// strictly earlier timestamp() than `second`. Generic, so it accepts any type
// exposing timestamp().
auto byTime = [](const auto& first, const auto& second) {
    return second.timestamp() > first.timestamp();
};

auto speedClassifier = [](db::TrackPoints trackPoints, const GraphTypeToMatcherMap& /*graphTypeToMatcherMap*/){
    track_classifier::TrackInterval result;
    auto min = std::min_element(trackPoints.begin(), trackPoints.end(), byTime);
    auto max = std::max_element(trackPoints.begin(), trackPoints.end(), byTime);
    auto meters =
        geolib3::fastGeoDistance(max->geodeticPos(), min->geodeticPos());
    auto seconds = std::chrono::duration_cast<std::chrono::seconds>(
                       max->timestamp() - min->timestamp())
                       .count();
    result.begin = min->timestamp();
    result.end = max->timestamp();
    result.type = seconds == 0
                      ? track_classifier::TrackType::Undefined
                      : meters / seconds < 3
                            ? track_classifier::TrackType::Pedestrian
                            : meters / seconds < 30
                                  ? track_classifier::TrackType::Vehicle
                                  : track_classifier::TrackType::Undefined;
    return std::vector<track_classifier::TrackInterval>{{result}};
};

/// Per-graph-type feature tallies. Every counter starts at zero, even under
/// plain default-initialization (`GraphTypeCounter c;`).
struct GraphTypeCounter {
    int pedestrianNumbers = 0;
    int roadNumbers = 0;
    int undefinedNumbers = 0; ///< features without a graph assigned
};

/**
 * Counts features by assigned graph.
 * Features with no graph count as undefined. Pedestrian and Road graphs are
 * counted under their own fields.
 */
GraphTypeCounter countGraphType(const db::Features& features)
{
    GraphTypeCounter counter{};
    for (const auto& feature : features) {
        if (!feature.hasGraph()) {
            ++counter.undefinedNumbers;
            continue;
        }
        switch (feature.graph()) {
            case db::GraphType::Road:
                ++counter.roadNumbers;
                break;
            case db::GraphType::Pedestrian:
                ++counter.pedestrianNumbers;
                break;
        }
    }
    return counter;
}

} // anonymous namespace

// Every matcher must return a non-empty path for the synthetic track. The
// path's time span must start at the first track point and end at the last.
TEST(tests, test_matcher)
{
    auto points = makeTrackPoints();
    const auto matchers = makeMatchers();
    for (const auto* matcher : matchers) {
        auto matchedPath = matcher->match(points);
        ASSERT_FALSE(matchedPath.empty());

        const auto& firstPoint = points.front();
        const auto& lastPoint = points.back();
        EXPECT_EQ(matchedPath.front().startTime, firstPoint.timestamp());
        EXPECT_EQ(matchedPath.back().endTime, lastPoint.timestamp());
    }
}

/**
 * Checks FeaturePositioner on the three-point track from makeTrackPoints(),
 * using the same matcher for the Road and Pedestrian graphs:
 *  - features timestamped outside the track's time range get no position;
 *  - features at a track point's timestamp are placed near that point;
 *  - features halfway (in time) between two consecutive track points are
 *    placed near the midpoint of the segment between them.
 */
TEST(tests, test_positioner)
{
    auto trackPoints = makeTrackPoints();
    for (const Matcher* matcher : makeMatchers()) {
        FeaturePositioner positioner(
            {{db::GraphType::Road, matcher}, {db::GraphType::Pedestrian, matcher}},
            [&](auto&&...) { return trackPoints; });

        // out of the interval (before and after)
        for (const auto& timestamp : {trackPoints.front().timestamp()
                                          - std::chrono::milliseconds(500),
                                      trackPoints.back().timestamp()
                                          + std::chrono::milliseconds(500)}) {
            db::Features features {
                sql_chemistry::GatewayAccess<db::Feature>::construct()
                    .setSourceId(SOURCE_ID)
                    .setTimestamp(timestamp)
            };

            positioner(features);
            auto& feature = features.front();
            EXPECT_FALSE(feature.hasPos());
            EXPECT_FALSE(feature.hasHeading());
        }

        // nodes (track points)
        for (const auto& trackPoint : trackPoints) {
            db::Features features {
                sql_chemistry::GatewayAccess<db::Feature>::construct()
                    .setSourceId(SOURCE_ID)
                    .setTimestamp(trackPoint.timestamp())
            };

            positioner(features);
            auto& feature = features.front();
            EXPECT_TRUE(feature.hasPos());
            EXPECT_TRUE(feature.hasHeading());
            EXPECT_NEAR(feature.geodeticPos().x(), trackPoint.geodeticPos().x(), 0.001);
            EXPECT_NEAR(feature.geodeticPos().y(), trackPoint.geodeticPos().y(), 0.001);
            EXPECT_NEAR(feature.heading().value(), 72, 1);
        }

        // edges (inside segments): probe the time midpoint of each pair of
        // consecutive track points and expect the geometric segment midpoint.
        // (Previously the probe time was lhs + (rhs - lhs) == rhs, i.e. a node.)
        for (std::size_t i = 1; i < trackPoints.size(); ++i) {
            const auto& lhs = trackPoints[i - 1];
            const auto& rhs = trackPoints[i];
            db::Features features {
                sql_chemistry::GatewayAccess<db::Feature>::construct()
                    .setSourceId(SOURCE_ID)
                    .setTimestamp(lhs.timestamp() + (rhs.timestamp() - lhs.timestamp()) / 2)
            };

            positioner(features);
            auto& feature = features.front();
            EXPECT_TRUE(feature.hasPos());
            EXPECT_TRUE(feature.hasHeading());
            const auto midpoint = geolib3::Segment2{lhs.geodeticPos(), rhs.geodeticPos()}.midpoint();
            EXPECT_NEAR(feature.geodeticPos().x(), midpoint.x(), 0.001);
            EXPECT_NEAR(feature.geodeticPos().y(), midpoint.y(), 0.001);
            EXPECT_NEAR(feature.heading().value(), 72, 1);
        }
    }
}

// The middle track point is marked augmented and given an explicit heading of
// 42. A feature taken at its timestamp is expected to receive exactly that
// point's position and heading.
TEST(tests, test_augmented_track_points)
{
    auto trackPoints = makeTrackPoints();
    auto& augmented = trackPoints[trackPoints.size() / 2];
    augmented.setHeading(geolib3::Heading{42});
    augmented.setIsAugmented(true);

    for (const Matcher* matcher : makeMatchers()) {
        FeaturePositioner positioner(
            {{db::GraphType::Road, matcher}, {db::GraphType::Pedestrian, matcher}},
            [&](auto&&...) { return trackPoints; });

        db::Features features{
            sql_chemistry::GatewayAccess<db::Feature>::construct()
                .setSourceId(augmented.sourceId())
                .setTimestamp(augmented.timestamp())};
        positioner(features);

        const auto& positioned = features.front();
        EXPECT_TRUE(positioned.hasPos());
        EXPECT_TRUE(positioned.hasHeading());
        EXPECT_EQ(positioned.geodeticPos(), augmented.geodeticPos());
        EXPECT_EQ(positioned.heading(), augmented.heading().value());
    }
}

/**
 * Builds three rides over the same route at 2, 4 and 40 m/s. With
 * speedClassifier, the test expects their features to be assigned the
 * Pedestrian graph, the Road graph, and no graph respectively. It then
 * re-runs the positioner with classifyAs(Undefined) and expects every
 * feature to end up without a graph.
 */
TEST(tests, test_track_classifier)
{
    CompactGraphMatcherAdapter matcher(TEST_GRAPH_PATH);

    db::TrackPoints trackPoints;
    db::Features features;
    GraphTypeCounter expect{};
    // Each ride appends to `features`, so a ride's feature count is the
    // growth of features.size() across its addRide call.
    addRide(chrono::parseSqlDateTime("2017-05-17 11:00:00+03"),
            2 /*metersPerSecond*/,
            trackPoints,
            features);
    expect.pedestrianNumbers = static_cast<int>(features.size());
    addRide(chrono::parseSqlDateTime("2017-05-17 11:10:00+03"),
            4 /*metersPerSecond*/,
            trackPoints,
            features);
    expect.roadNumbers =
        static_cast<int>(features.size()) - expect.pedestrianNumbers;
    addRide(chrono::parseSqlDateTime("2017-05-17 11:20:00+03"),
            40 /*metersPerSecond*/,
            trackPoints,
            features);
    expect.undefinedNumbers = static_cast<int>(features.size())
                              - expect.pedestrianNumbers - expect.roadNumbers;

    // Shuffle the features into an arbitrary order. The seed is random but is
    // recorded in every failure message so a failing order can be reproduced.
    const auto seed = std::random_device{}();
    SCOPED_TRACE(::testing::Message() << "shuffle seed: " << seed);
    std::mt19937 randomNumberEngine(seed);
    std::shuffle(features.begin(), features.end(), randomNumberEngine);

    EXPECT_GT(expect.pedestrianNumbers, 0);
    EXPECT_GT(expect.roadNumbers, 0);
    EXPECT_GT(expect.undefinedNumbers, 0);

    auto trackPointProvider = [&](const std::string& sourceId,
                                  chrono::TimePoint startTime,
                                  chrono::TimePoint endTime) {
        db::TrackPoints result;
        for (const auto& trackPoint : trackPoints) {
            if (trackPoint.sourceId() == sourceId &&
                trackPoint.timestamp() >= startTime &&
                trackPoint.timestamp() <= endTime) {
                result.push_back(trackPoint);
            }
        }
        return result;
    };

    FeaturePositioner({{db::GraphType::Pedestrian, &matcher},
                       {db::GraphType::Road, &matcher}},
                      trackPointProvider,
                      speedClassifier)(features);
    auto counter = countGraphType(features);

    EXPECT_EQ(counter.pedestrianNumbers, expect.pedestrianNumbers);
    EXPECT_EQ(counter.roadNumbers, expect.roadNumbers);
    EXPECT_EQ(counter.undefinedNumbers, expect.undefinedNumbers);

    FeaturePositioner(
        {{db::GraphType::Pedestrian, &matcher},
         {db::GraphType::Road, &matcher}},
        trackPointProvider,
        classifyAs(track_classifier::TrackType::Undefined))(features);
    counter = countGraphType(features);

    EXPECT_EQ(counter.pedestrianNumbers, 0);
    EXPECT_EQ(counter.roadNumbers, 0);
    EXPECT_EQ(counter.undefinedNumbers, static_cast<int>(features.size()));
}

} // namespace tests
} // namespace adapters
} // namespace mrc
} // namespace maps
