#include "client.h"

#include <drive/backend/abstract/frontend.h>
#include <drive/backend/compiled_riding/manager.h>
#include <drive/backend/data/chargable.h>
#include <drive/backend/database/transaction/assert.h>
#include <drive/backend/history_iterator/history_iterator.h>
#include <drive/backend/saas/api.h>
#include <drive/backend/tracks/yagr.h>

#include <drive/library/cpp/common/status.h>
#include <drive/library/cpp/tracks/client.h>

// Maps a chargeable-session tag name to the car status used for track pieces.
// Unknown tag names fall back to csParking.
NDrive::ECarStatus TSessionTrackClient::GetStatus(TStringBuf tag) {
    struct TTagStatus {
        TStringBuf Tag;
        NDrive::ECarStatus Status;
    };
    const TTagStatus known[] = {
        { TChargableTag::Acceptance, NDrive::ECarStatus::csAcceptance },
        { TChargableTag::Parking, NDrive::ECarStatus::csParking },
        { TChargableTag::Riding, NDrive::ECarStatus::csRide },
        { TChargableTag::Reservation, NDrive::ECarStatus::csReservation },
    };
    for (const auto& entry : known) {
        if (entry.Tag == tag) {
            return entry.Status;
        }
    }
    return NDrive::ECarStatus::csParking;
}

// Wraps `client` (the raw track source) and uses `server` to resolve sessions.
// `client` is a by-value sink parameter, so it is moved into the member rather
// than copied (avoids an extra refcount increment/decrement).
// NOTE(review): `server` is presumably held by reference and must outlive this object.
TSessionTrackClient::TSessionTrackClient(NDrive::TTrackClientPtr client, const NDrive::IServer& server)
    : Client(std::move(client))
    , Server(server)
{
}

// Returns tracks matching `query`.
//
// Queries by RideId are rejected (Y_ENSURE). When the query filters by device,
// session, user or status, tracks are assembled per session:
//   1. sessions are resolved through THistoryRidesContext;
//   2. for every session, raw tracks of its device over
//      [session start, session last timestamp] are requested from the wrapped client;
//   3. raw coordinates are cut into pieces at the events of the session's compiled
//      riding; each piece gets the session/user/device ids and a status computed by
//      GetStatus from the tag of the event returned by GetLastEventAt
//      (csPost when there is no such event);
//   4. if `query.Status` is set, pieces with another status are dropped.
// Per-session results are concatenated in session order. Any other query is
// forwarded to the wrapped client unchanged.
//
// Throws if the history context cannot be initialized. A per-session future fails
// (via the Y_ENSUREs in its continuation) if the session has no compiled riding, the
// riding has no events, a raw track belongs to another device, or a coordinate
// precedes the first event.
NThreading::TFuture<NDrive::TTracks> TSessionTrackClient::GetTracks(const NDrive::TTrackQuery& query, TDuration timeout) const {
    Y_ENSURE(Client);
    Y_ENSURE(!query.RideId);
    if (query.DeviceId || query.SessionId || query.UserId || query.Status) {
        THistoryRidesContext context(Server, query.Since);
        // Read-only transaction; `timeout` is passed for both BuildTx arguments.
        auto tx = Server.GetDriveDatabase().GetCompiledSessionManager().BuildTx<NSQL::ReadOnly | NSQL::Deferred>(timeout, timeout);
        // TInstant::Max() means "no upper bound" and is passed to Initialize as Nothing().
        auto until = query.Until != TInstant::Max() ? MakeMaybe(query.Until) : Nothing();
        auto ydbTx = Server.GetDriveAPI()->BuildYdbTx<NSQL::ReadOnly>("session_track_client", &Server);
        R_ENSURE(
            context.Initialize(query.SessionId, query.UserId, query.DeviceId, tx, ydbTx, until, query.NumDoc),
            {},
            "cannot Initialize HistoryRidesContext",
            tx
        );

        auto allTracks = NThreading::TFutures<NDrive::TTracks>();
        auto sessions = context.GetSessions(query.Until, query.NumDoc);
        for (auto&& session : sessions) {
            // Loaded synchronously here (tx is not captured); the continuation only
            // reads the resulting compiled riding.
            auto compiledRiding = session.GetFullCompiledRiding(tx);
            auto deviceId = session.GetObjectId();
            auto sessionId = session.GetSessionId();
            auto userId = session.GetUserId();

            // Raw tracks of the session's device limited to the session time window.
            NDrive::TTrackQuery sessionQuery;
            sessionQuery.DeviceId = deviceId;
            sessionQuery.Since = session.GetStartTS();
            sessionQuery.Until = session.GetLastTS();
            auto sessionRawTracks = Client->GetTracks(sessionQuery, timeout);
            // Everything the continuation needs is captured by value: it may run
            // after this function has returned.
            auto sessionTracks = sessionRawTracks.Apply([
                compiledRiding,
                deviceId,
                sessionId,
                userId,
                status = query.Status
            ](const NThreading::TFuture<NDrive::TTracks>& t) {
                const auto& tracks = t.GetValue();
                auto result = NDrive::TTracks();

                Y_ENSURE(compiledRiding);
                auto current = compiledRiding->GetLastEventAt(TInstant::Zero());
                // `next` is the first event boundary; the first coordinate must reach it,
                // otherwise no piece is open and Y_ENSURE(!result.empty()) below fails.
                auto next = compiledRiding->GetNextEventAt(TInstant::Zero());
                Y_ENSURE(next);

                for (auto&& track : tracks) {
                    Y_ENSURE(track.DeviceId == deviceId);
                    for (auto&& coordinate : track.Coordinates) {
                        // Crossing an event boundary opens a new piece. Once `next` is null
                        // (no further events) all remaining coordinates join the last piece.
                        if (next && next->GetInstant() <= coordinate.Timestamp) {
                            current = compiledRiding->GetLastEventAt(coordinate.Timestamp);
                            next = compiledRiding->GetNextEventAt(coordinate.Timestamp);
                            result.emplace_back();
                            result.back().DeviceId = deviceId;
                            result.back().SessionId = sessionId;
                            result.back().UserId = userId;
                            if (current) {
                                // Name is "<event unix seconds>-<session id>".
                                result.back().Name = TStringBuilder() << current->GetInstant().Seconds() << '-' << sessionId;
                                result.back().Status = GetStatus(current->GetTagName());
                            } else {
                                result.back().Name = TStringBuilder() << "post" << '-' << sessionId;
                                result.back().Status = NDrive::ECarStatus::csPost;
                            }
                        }
                        Y_ENSURE(!result.empty());
                        result.back().Coordinates.push_back(coordinate);
                    }
                }
                // Optional status filter from the original query.
                if (status) {
                    erase_if(result, [status](const NDrive::TTrack& track) {
                        return status != track.Status;
                    });
                }
                return result;
            });
            allTracks.push_back(std::move(sessionTracks));
        }
        // Fast paths: nothing to merge.
        if (allTracks.empty()) {
            return NThreading::MakeFuture<NDrive::TTracks>();
        }
        if (allTracks.size() == 1) {
            return std::move(allTracks[0]);
        }
        // Concatenate per-session results in session order once all are ready;
        // ExtractValue rethrows the exception of a failed session future.
        auto waiter = NThreading::WaitAll(allTracks);
        auto merged = waiter.Apply([allTracks = std::move(allTracks)](const NThreading::TFuture<void>& /*w*/) mutable {
            NDrive::TTracks result;
            for (auto&& tracks : allTracks) {
                for (auto&& track : tracks.ExtractValue()) {
                    result.push_back(std::move(track));
                }
            }
            return result;
        });
        return merged;
    }
    return Client->GetTracks(query, timeout);
}

// Creates a track client by name:
//   "yagr"     - TYAGRTrackClient, only if setting `track_client.enable_yagr` is true (default false);
//   "internal" - a TTrackClientRef around the server's own track client;
//   otherwise  - a TTracksClient over the SaaS RTLine API registered under `name`.
// Throws (R_ENSURE, HTTP 500) when the selected backend is unavailable or disabled.
THolder<NDrive::ITrackClient> CreateTrackClient(const TString& name, const NDrive::IServer& server) {
    if (name == "yagr") {
        const auto yagrEnabled = server.GetSettings().GetValueDef<bool>("track_client.enable_yagr", false);
        R_ENSURE(yagrEnabled, HTTP_INTERNAL_SERVER_ERROR, "TYAGRTrackClient is disabled");
        return MakeHolder<TYAGRTrackClient>(server);
    }

    if (name == "internal") {
        auto serverClient = server.GetTrackClient();
        R_ENSURE(serverClient, HTTP_INTERNAL_SERVER_ERROR, "cannot get " << name);
        return MakeHolder<NDrive::TTrackClientRef>(serverClient);
    }

    auto rtLineApi = server.GetRTLineAPI(name);
    R_ENSURE(rtLineApi, HTTP_INTERNAL_SERVER_ERROR, "cannot get SaaS api " << name);
    return MakeHolder<NDrive::TTracksClient>(rtLineApi->GetSearchClient());
}

// Returns a TYAGRTrackClientV2 when setting `signalq.signalq_tracks.enable` is true;
// a missing or false setting yields nullptr (no camera track client).
THolder<NDrive::ITrackClient> CreateCameraTrackClient(const NDrive::IServer& server) {
    const bool cameraTracksEnabled = server.GetSettings().GetValue<bool>("signalq.signalq_tracks.enable").GetOrElse(false);
    if (!cameraTracksEnabled) {
        return nullptr;
    }
    return MakeHolder<TYAGRTrackClientV2>(server);
}

TYDBTracksClient::TYDBTracksClient(const NDrive::IServer& server, const NDrive::TTracksLinker::TOptions& options)
    : Server(server)
    , Options(options)
{
}

// Reads matched track segments for `trackQuery` from the YDB table "tracks".
//
// Only data up to `ydb_tracks.time_shift` (default 1 day) before now is considered:
// that cutoff is passed to GetYDBRange to bound the "groupstamp" range.
// Returns an empty result when the range is empty, when the YDB transaction cannot be
// built or obtained, or when the query result cannot be parsed; the failure cases are
// logged to the event log.
NDrive::TTracksLinker::TResults TYDBTracksClient::GetYDBTracks(const NDrive::TTrackQuery& trackQuery) {
    auto shift = Server.GetSettings().GetValueDef<TDuration>("ydb_tracks.time_shift", TDuration::Days(1));
    auto time = TInstant::Now() - shift;
    auto ydbRange = GetYDBRange(trackQuery, time);

    // Nothing to read, e.g. when the query starts well after the cutoff.
    if (ydbRange.From > ydbRange.To) {
        return {};
    }

    auto ydbTx = Server.GetDriveAPI()->BuildYdbTx<NSQL::ReadOnly | NSQL::Deferred>("ydb_tracks_client", &Server);
    if (!ydbTx) {
        // Log like the other failure paths instead of returning silently.
        NDrive::TEventLog::Log("TYDBTracksClient::GetYDBTracks. Error", NJson::TMapBuilder
            ("error", "cannot build ydb tx")
        );
        return {};
    }

    // Filters are ANDed. Secondary index / ordering: later branches override earlier
    // ones, so a device filter wins over a user filter, a session filter wins over a
    // user filter when no device is given, and device + session uses no secondary index.
    auto queryOptions = NSQL::TQueryOptions().SetGenericCondition("groupstamp", ydbRange);
    if (trackQuery.UserId) {
        queryOptions.AddGenericCondition("user_id",  trackQuery.UserId);
        queryOptions.SetOrderBy({ "user_id", "groupstamp" });
        queryOptions.SetSecondaryIndex("tracks_user_id_groupstamp_index");
    }
    if (trackQuery.DeviceId) {
        queryOptions.AddGenericCondition("object_id",  trackQuery.DeviceId);
        queryOptions.SetOrderBy({ "object_id", "groupstamp" });
        queryOptions.SetSecondaryIndex("tracks_object_id_groupstamp_index");
    }
    if (trackQuery.SessionId) {
        queryOptions.AddGenericCondition("session_id", trackQuery.SessionId);
        if (!trackQuery.DeviceId) {
            queryOptions.SetOrderBy({ "session_id", "groupstamp" });
            queryOptions.SetSecondaryIndex("tracks_session_id_groupstamp_index");
        } else {
            queryOptions.SetSecondaryIndex("");
            queryOptions.SetOrderBy({"object_id", "groupstamp", "session_id" });
        }
    }
    auto transaction = ydbTx.GetTransaction();
    if (!transaction) {
        NDrive::TEventLog::Log("TYDBTracksClient::GetYDBTracks. Error", NJson::TMapBuilder
            ("error", "cannot create ydb tx")
        );
        return {};
    }

    auto query = queryOptions.PrintQuery(*transaction, "tracks");

    TRecordsSet records;
    auto queryResult = transaction->Exec(query, &records);
    if (!ParseQueryResult(queryResult, ydbTx)) {
        NDrive::TEventLog::Log("TYDBTracksClient::GetYDBTracks. Error", NJson::TMapBuilder
            ("query", query)
            ("error", "cannot parse query result")
        );
        return {};
    }
    // Hour-granular groupstamp rows may contain points outside the requested
    // window; ParseYdbTracks trims them to [Since, Until].
    return ParseYdbTracks(records, trackQuery.Since, trackQuery.Until);
}

// Computes the "groupstamp" range (unix seconds, whole hours) to scan in YDB.
// From: `Since` rounded down to the hour.
// To:   the hour of min(maxTime, Until), plus 2 hours.
// NOTE(review): the +2h slack presumably covers groupstamp buckets that hold points
// later than their stamp - TODO confirm against the writer.
TRange<ui64> TYDBTracksClient::GetYDBRange(const NDrive::TTrackQuery& trackQuery, TInstant maxTime) const {
    constexpr ui64 secondsPerHour = 3600;
    const TInstant upperBound = std::min(maxTime, trackQuery.Until);

    TRange<ui64> ydbRange = MakeRange<ui64>(0, Max<ui64>());
    ydbRange.From = trackQuery.Since.Hours() * secondsPerHour;
    ydbRange.To = (upperBound.Hours() + 2) * secondsPerHour;
    return ydbRange;
}

// Converts YDB "tracks" rows into linker results.
//
// Every row's "data" column is parsed as TTrackPartData; rows that fail to parse are
// skipped. Each segment of a row becomes one result carrying the row's session/user
// ids and the segment status. Only coordinates that deserialize and fall in
// [since, until] are kept; segments left without coordinates are dropped. The kept
// coordinates are split by SegmentCoordinates using Options.SegmentSize /
// Options.SplitThreshold and also stored as FilteredCoordinates.
NDrive::TTracksLinker::TResults TYDBTracksClient::ParseYdbTracks(const TRecordsSet& records, TInstant since, TInstant until) const {
    NDrive::TTracksLinker::TResults results;
    // Iterate by reference: the records, segments and coordinates were previously
    // copied wholesale on every pass.
    for (const auto& record : records.GetRecords()) {
        NDrive::NProto::TTrackPartData partData;
        if (!partData.ParseFromString(record.Get("data"))) {
            continue;
        }
        for (const auto& seg : partData.GetSegments()) {
            NDrive::TTracksLinker::TResult result;
            result.Track.SessionId = record.Get("session_id");
            result.Track.UserId = record.Get("user_id");
            // NOTE(review): Track.DeviceId is not filled from "object_id" here - confirm
            // callers do not rely on it.
            result.Track.Status = NDrive::GetStatus(seg.GetStatus());
            // TODO: finish violation calculation eyurkovsk
            //TVector<NDrive::NProto::TMatchedCoordinate> matcherCoordinates;
            for (const auto& coord : seg.GetMatchedCoords()) {
                TGeoCoord c;
                if (!c.Deserialize(coord.GetCoord())) {
                    continue;
                }
                // Stored speed is divided by 3.6 (presumably km/h -> m/s).
                auto coordinate = NGraph::TTimedGeoCoordinate(c, TInstant::Seconds(coord.GetTimestamp()), coord.GetSpeed() / 3.6);
                if (coordinate.Timestamp <= until && coordinate.Timestamp >= since) {
                    result.Track.Coordinates.push_back(coordinate);
                    // TODO: finish violation calculation eyurkovsk
                    // matcherCoordinates.push_back(coordinate);
                }
            }
            if (result.Track.Coordinates.empty()) {
                continue;
            }
            // TODO: finish violation calculation eyurkovsk
            // auto ranges = GetSpeedLimitRanges(matcherCoordinates);
            // segment.Processed = AnalyzeSpeedLimitRange(ranges);
            // Named distinctly from the proto segments iterated above (was shadowing them).
            auto coordinateSegments = SegmentCoordinates(result.Track.Coordinates, Options.SegmentSize, Options.SplitThreshold);
            for (auto&& coordinates : coordinateSegments) {
                result.Segments.emplace_back().Coordinates = std::move(coordinates);
            }
            result.FilteredCoordinates = result.Track.Coordinates;
            results.push_back(std::move(result));
        }
    }
    return results;
}

// TODO: finish violation calculation eyurkovsk
// Splits matched coordinates into consecutive ranges of homogeneous driving.
//
// A new range starts at the first coordinate, whenever the "speed limit exceeded"
// state flips, and (only if ViolationOptions.SpeedThresholds is non-empty) whenever
// the functional class (FC) changes. Speeds and limits are divided by 3.6
// (presumably km/h -> m/s). A coordinate exceeds when its speed is at least 0.1 above
// max(limit, SpeedLimitLow) + GetSpeedThreshold(FC).
//
// Per range: Length is summed, SpeedAverage is the mean point speed, SpeedMax is the
// maximum, Time = Length / SpeedAverage, SpeedLimitDuration = Length / SpeedLimit, and
// IsHard marks SpeedMax >= SpeedLimit + HardSpeedThreshold.
// Empty input yields no ranges. Throws (Y_ENSURE) if a coordinate cannot be deserialized.
TSpeedLimitRanges TYDBTracksClient::GetSpeedLimitRanges(const TVector<NDrive::NProto::TMatchedCoordinate>& coordinates) const {
    TSpeedLimitRanges ranges;
    // The previous draft dereferenced coordinates.front() unconditionally.
    if (coordinates.empty()) {
        return ranges;
    }

    const auto& violationOptions = Options.ViolationOptions;
    bool isSpeedLimitExceeded = false;
    double summarySpeed = 0;
    for (const auto& coordinate : coordinates) {
        const double speed = coordinate.GetSpeed() / 3.6;

        // Each point uses its own position (the draft always read coordinates.back()).
        TGeoCoord position;
        Y_ENSURE(position.Deserialize(coordinate.GetCoord()));

        // The same exceedance rule applies to every coordinate, including the first
        // (the draft mixed km/h and m/s for the first one and skipped the threshold).
        const float limit = std::max(static_cast<float>(coordinate.GetSpeedLimit() / 3.6), violationOptions.SpeedLimitLow);
        const float threshold = limit + GetSpeedThreshold(coordinate.GetFC(), violationOptions);
        const bool currentSpeedLimitExceeded = speed - threshold >= 0.1;

        // Short-circuit keeps ranges.back() from being evaluated while ranges is empty.
        const bool startNewRange = ranges.empty()
            || (!violationOptions.SpeedThresholds.empty() && ranges.back().FC != coordinate.GetFC())
            || currentSpeedLimitExceeded != isSpeedLimitExceeded;
        if (startNewRange) {
            ranges.emplace_back();
            ranges.back().SpeedLimit = coordinate.GetSpeedLimit() / 3.6;
            ranges.back().FC = coordinate.GetFC();
            summarySpeed = 0;
            isSpeedLimitExceeded = currentSpeedLimitExceeded;
        }

        auto& range = ranges.back();
        summarySpeed += speed;
        // The speed (not the point object) is what gets divided by 3.6.
        range.Points.emplace_back(NGraph::TTimedGeoCoordinate(position, TInstant::Seconds(coordinate.GetTimestamp()), speed));
        range.Length += coordinate.GetLength();
        range.SpeedAverage = summarySpeed / range.Points.size();
        range.Time = range.Length / range.SpeedAverage;
        range.SpeedLimitDuration = range.Length / range.SpeedLimit;
        range.SpeedMax = std::max(static_cast<float>(speed), range.SpeedMax);
        range.IsHard = range.SpeedMax >= range.SpeedLimit + violationOptions.HardSpeedThreshold;
    }

    return ranges;
}

// Fills placeholder linkage data on every range (mutating `ranges` in place) and
// returns copies of the ranges that exceed the speed limit over more than
// ViolationOptions.CriticalLength.
// NOTE(review): PointIndices = 1.9 and Indices = 1 for every point look like WIP
// placeholders (see the "finish violation calculation" TODOs) - confirm before use.
TSpeedLimitRanges TYDBTracksClient::AnalyzeSpeedLimitRange(TSpeedLimitRanges& ranges) const {
    TSpeedLimitRanges violations;
    const auto& violationOptions = Options.ViolationOptions;
    for (auto& range : ranges) {
        const auto pointCount = range.Points.size();
        range.PointIndices = TVector(pointCount, 1.9);
        range.Indices = TVector(pointCount, static_cast<ui64>(1));
        range.Projection = range.Points;
        range.Trace = range.Points;

        const bool isViolation = range.IsSpeedLimitExceeded() && range.Length > violationOptions.CriticalLength;
        if (!isViolation) {
            continue;
        }
        violations.push_back(range);
    }
    return violations;
}
