#include "area_tags.h"

#include <drive/backend/base/config.h>
#include <drive/backend/base/server.h>
#include <drive/backend/processors/sessions/base.h>
#include <drive/backend/saas/api.h>
#include <drive/backend/ut/library/helper.h>
#include <drive/backend/ut/library/helper2.h>

#include <library/cpp/testing/unittest/registar.h>

// Recorded speed-limit ranges of a real track; used as the fixture of SimpleCheck.
// `static const`: the fixture is only read (copied into a segment) and must not
// leak into other translation units linked into the same test binary.
// Each initializer follows this field layout:
//   { {}, {}, {}, TVector<TGeoCoord>{...}, {}, Length, Time, SpeedLimitDuration, 0, 0, 0, 0, 0, false, "" }
// NOTE(review): judging by the annotated third range, a violation seems to
// correspond to Time < SpeedLimitDuration — confirm against IsSpeedLimitExceeded().
static const TVector<struct TSpeedLimitRange> realRanges = {
    { {}, {}, {}, TVector<TGeoCoord>{
        { 37.895059549999999, 55.71842796},
        { 37.89534785, 55.718395569999998},
        { 37.89534785, 55.718395569999998},
        { 37.900029109999998, 55.717219989999997},
        { 37.900029109999998, 55.717219989999997},
        { 37.900102519999997, 55.71756929}
    }, {}, 639.332275, 54.3692474, 27.0776024, 0, 0, 0, 0, 0, false, "" },

    { {}, {}, {}, TVector<TGeoCoord>{
        { 37.900102519999997, 55.71756929},
        { 37.900130230000002, 55.717698429999999},
        { 37.900130230000002, 55.717698429999999},
        { 37.900764760000001, 55.720388100000001},
        { 37.900764760000001, 55.720388100000001},
        { 37.900764760000001, 55.720388100000001}
    }, {}, 316.384613, 17.3035507, 13.3998184, 0, 0, 0, 0, 0, false, "" },

    // violation: SpeedLimitDuration = 6.70520258, Time = 6.55843115
    { {}, {}, {}, TVector<TGeoCoord>{
        { 37.900764760000001, 55.720388100000001},
        { 37.900764760000001, 55.720388100000001},
        { 37.900767899999998, 55.720395060000001},
        { 37.902021859999998, 55.72149744},
        { 37.902021859999998, 55.72149744},
        { 37.902139570000003, 55.72155746}
    }, {}, 158.317291, 6.55843115, 6.70520258, 0, 0, 0, 0, 0, false, "" },

    { {}, {}, {}, TVector<TGeoCoord>{
        { 37.902139570000003, 55.72155746},
        { 37.902308779999998, 55.721636359999998},
        { 37.902480670000003, 55.72170466},
        { 37.905488650000002, 55.722704620000002},
        { 37.905488650000002, 55.722704620000002},
        { 37.905679300000003, 55.722766290000003}
    }, {}, 259.543823, 21.7926979, 10.9924431, 0, 0, 0, 0, 0, false, "" }
};

// Synthetic speed-limit ranges used by DifferentSpeedCases together with the
// polygon "10 30 60 30 60 60 10 60 10 30" (X in [10, 60], Y in [30, 60]).
// `static const`: the fixture is only read (copied into a segment) and must not
// leak into other translation units linked into the same test binary.
// Each initializer follows this field layout:
//   { {}, {}, {}, TVector<TGeoCoord>{...}, {}, Length, Time, SpeedLimitDuration, 0, 0, 0, 0, 0, false, "" }
static const TVector<struct TSpeedLimitRange> testRanges = {
    // has violation, not in polygon
    { {}, {}, {}, TVector<TGeoCoord>{
        { 10, 10},
        { 20, 20}
    }, {}, 20, 7, 10, 0, 0, 0, 0, 0, false, "" },

    // no violation, not in polygon, must be removed
    { {}, {}, {}, TVector<TGeoCoord>{
        { 20, 20},
        { 25, 25}
    }, {}, 316, 17, 13, 0, 0, 0, 0, 0, false, "" },

    // has violation, partly in polygon, must be removed
    { {}, {}, {}, TVector<TGeoCoord>{
        { 25, 25},
        { 30, 35}
    }, {}, 316, 7, 13, 0, 0, 0, 0, 0, false, "" },

    // has violation, is in polygon, must be removed
    { {}, {}, {}, TVector<TGeoCoord>{
        { 30, 35},
        { 40, 45}
    }, {}, 316, 7, 13, 0, 0, 0, 0, 0, false, "" },

    // has violation, is in polygon, must not be removed
    { {}, {}, {}, TVector<TGeoCoord>{
        { 40, 45},
        { 50, 50}
    }, {}, 700, 7, 130, 0, 0, 0, 0, 0, false, "" },

    // has violation, partly in polygon, must not be removed
    { {}, {}, {}, TVector<TGeoCoord>{
        { 50, 50},
        { 70, 60}
    }, {}, 700, 7, 130, 0, 0, 0, 0, 0, false, "" }
};

bool HasViolation(NDrive::TTracksLinker::TResults& trackValue) {
    for (const auto& result : trackValue) {
        for (const auto& segment : result.Segments) {
            for (const auto& range : segment.Processed) {
                if (range.IsSpeedLimitExceeded()) {
                    return true;
                }
            }
        }
    }
    return false;
}

// Counts the processed speed-limit ranges, across all segments of all linked
// results, for which IsSpeedLimitExceeded() is true.
// Takes the tracks by const reference: they are only inspected, and this lets
// callers pass const or temporary track collections.
i32 CountViolations(const NDrive::TTracksLinker::TResults& trackValue) {
    i32 count = 0;
    for (const auto& result : trackValue) {
        for (const auto& segment : result.Segments) {
            for (const auto& range : segment.Processed) {
                if (range.IsSpeedLimitExceeded()) {
                    ++count;
                }
            }
        }
    }
    return count;
}

bool CoordinatesEqual(const TVector<TGeoCoord>& coords1, const TVector<TGeoCoord>& coords2) {
    if (coords1.size() != coords2.size()) {
        return false;
    }
    for (size_t i = 0; i < coords1.size(); ++i) {
        if (coords1[i].X != coords2[i].X || coords1[i].Y != coords2[i].Y) {
            return false;
        }
    }
    return true;
}

// Fetches the raw tracks of `sessionId` through the "drive_graph" RTLine API and
// links them with the "default-linker" linker.
//
// Each stage (fetch, link) is waited on for at most 10 seconds, and the linked
// future is returned as-is, so it may still be unresolved: callers must check
// HasValue() before reading it.
// Throws (via R_ENSURE) if the API or the linker is not configured.
//
// NOTE(review): `tracksClient` and `tracksLinker` are destroyed on return. If
// linking is still in flight after the timeout and the pending work refers to
// them, that would be a use-after-free — TODO confirm TTracksLinker::Link
// does not depend on the linker object outliving the call.
NThreading::TFuture<TVector<NDrive::TTracksLinker::TResult>> GetTracks(const TString& sessionId, TTestEnvironment& env) {
    const auto tracksApiName = "drive_graph";
    auto api = env.GetServer()->GetRTLineAPI(tracksApiName);
    R_ENSURE(api, HTTP_INTERNAL_SERVER_ERROR, "cannot get SaaS api " << tracksApiName);
    // No emptiness check needed: MakeHolder either allocates or throws.
    auto tracksClient = MakeHolder<NDrive::TTracksClient>(api->GetSearchClient());

    const auto linkerApiName = "default-linker";
    auto linker = env.GetServer()->GetLinker(linkerApiName);
    // Checked like the API above, so a missing linker fails with a clear message
    // instead of being handed to TTracksLinker.
    R_ENSURE(linker, HTTP_INTERNAL_SERVER_ERROR, "cannot get linker " << linkerApiName);

    NDrive::TTracksLinker::TOptions tracksLinkerOptions;
    auto tracksLinker = MakeHolder<NDrive::TTracksLinker>(linker, tracksLinkerOptions);

    NDrive::TTrackQuery query;
    query.SessionId = sessionId;
    const auto timeout = TDuration::Seconds(10);
    auto tracks = tracksClient->GetTracks(query, timeout);
    tracks.Wait(timeout);
    auto linked = tracksLinker->Link(std::move(tracks));
    linked.Wait(timeout);
    return linked;
}

// NOTE: the suite name keeps its historical spelling so existing test filters
// that reference it keep working.
Y_UNIT_TEST_SUITE(SpeedlLimitCorrectorSuite) {
    // Builds a track collection made of one result holding one segment whose
    // processed speed-limit ranges are a copy of `ranges`.
    TVector<NDrive::TTracksLinker::TResult> MakeSingleSegmentTracks(const TVector<TSpeedLimitRange>& ranges) {
        NDrive::TTracksLinker::TSegment segment;
        segment.Processed = ranges;

        NDrive::TTracksLinker::TResult result;
        result.Segments = TVector<NDrive::TTracksLinker::TSegment>{segment};
        return TVector<NDrive::TTracksLinker::TResult>{result};
    }

    // Sets up a speed-limit correction zone in `env`:
    //  * upserts the "custom_speed_limit_zone" area with the polygon `polygon`
    //    (a space-separated coordinate list accepted by TGeoCoord::DeserializeVector);
    //  * registers the "speed_limit_correction_area_tag" tag description;
    //  * attaches a TSpeedLimitCorrectionAreaTag with `customSpeedLimit` to the area;
    //  * broadcasts a cache refresh so the server picks the changes up.
    // Every step is asserted, so a failed setup fails the calling test.
    void AddSpeedLimitCorrectionArea(TTestEnvironment& env, const TString& polygon, i32 customSpeedLimit) {
        TVector<TGeoCoord> coords;
        UNIT_ASSERT(TGeoCoord::DeserializeVector(polygon, coords));
        UNIT_ASSERT(env.GetConfigGenerator().UpsertArea("custom_speed_limit_zone", USER_ROOT_DEFAULT, coords, {}));

        TTagDescription description;
        description.SetName("speed_limit_correction_area_tag").SetType("speed_limit_correction_area_tag");
        auto session = env.GetServer()->GetDriveAPI()->template BuildTx<NSQL::Writable>();
        // A unit-test assertion (rather than R_ENSURE, which throws an HTTP error)
        // reports the failure through the test framework with the session report.
        UNIT_ASSERT_C(env.GetServer()->GetDriveAPI()->GetTagsManager().GetTagsMeta().RegisterTag(new TTagDescription(description), USER_ID_DEFAULT, session), "can't register tag: " << session.GetStringReport());
        UNIT_ASSERT(session.Commit());

        auto tag = MakeHolder<TSpeedLimitCorrectionAreaTag>(TSpeedLimitCorrectionAreaTag::TypeName);
        tag->SetCustomSpeedLimit(customSpeedLimit);
        tag->SetName(TSpeedLimitCorrectionAreaTag::TypeName);

        // Ownership of the tag is handed to AddTag.
        UNIT_ASSERT(env.GetConfigGenerator().AddTag(tag.Release(), "custom_speed_limit_zone", USER_ROOT_DEFAULT, NEntityTagsManager::EEntityType::Area));
        SendGlobalMessage<NDrive::TCacheRefreshMessage>();
    }

    // A zone with custom speed limit 120 covering the whole recorded track must
    // clear every violation found in it.
    Y_UNIT_TEST(SimpleCheck) {
        TTestEnvironment env;
        env.Execute(NDrive::NTest::TBuildEnv());

        // Set to true to run against a real session fetched via GetTracks()
        // instead of the recorded `realRanges` fixture.
        const bool getTrack = false;
        TVector<NDrive::TTracksLinker::TResult> sessionTracks;
        if (getTrack) {
            const TString sessionId = "35d48e30-90142c50-1b8e579f-6c98342a";
            auto linkedTracks = GetTracks(sessionId, env);
            UNIT_ASSERT(linkedTracks.HasValue());
            sessionTracks = linkedTracks.GetValue();
        } else {
            sessionTracks = MakeSingleSegmentTracks(realRanges);
        }
        UNIT_ASSERT(!sessionTracks.empty());
        UNIT_ASSERT(HasViolation(sessionTracks));

        AddSpeedLimitCorrectionArea(env, "0 0 90 0 90 90 0 90 0 0", 120);

        auto correctedTracks = sessionTracks;
        TSpeedLimitCorrectionAreaTag::CorrectSpeedLimitRanges(correctedTracks, *env.GetServer());
        UNIT_ASSERT(!HasViolation(correctedTracks));
    }

    // Ranges inside, partly inside and outside a zone with custom speed limit 60:
    // all six ranges must remain, and exactly three of them must still be violations.
    Y_UNIT_TEST(DifferentSpeedCases) {
        TTestEnvironment env;
        env.Execute(NDrive::NTest::TBuildEnv());

        auto sessionTracks = MakeSingleSegmentTracks(testRanges);
        UNIT_ASSERT(HasViolation(sessionTracks));

        AddSpeedLimitCorrectionArea(env, "10 30 60 30 60 60 10 60 10 30", 60);

        auto correctedTracks = sessionTracks;
        TSpeedLimitCorrectionAreaTag::CorrectSpeedLimitRanges(correctedTracks, *env.GetServer());

        const auto& corrected = correctedTracks[0].Segments[0].Processed;
        UNIT_ASSERT_VALUES_EQUAL(corrected.size(), 6);
        UNIT_ASSERT(CoordinatesEqual(corrected[0].Points, TVector<TGeoCoord>{ TGeoCoord(10, 10), TGeoCoord(20, 20) }));
        UNIT_ASSERT_VALUES_EQUAL(CountViolations(correctedTracks), 3);
    }
}
