#include "session_matcher.h"

#include <drive/backend/abstract/frontend.h>
#include <drive/backend/compiled_riding/manager.h>
#include <drive/backend/data/chargable.h>
#include <drive/backend/database/drive_api.h>
#include <drive/backend/database/history/event.h>
#include <drive/backend/database/history/session.h>
#include <drive/backend/database/history/session_builder.h>
#include <drive/backend/device_snapshot/manager.h>
#include <drive/backend/device_snapshot/snapshot.h>
#include <drive/backend/tags/tags.h>

#include <rtline/library/json/cast.h>

namespace NJson {
    template <>
    TJsonValue ToJson(const NDrive::NSession::TSessionBindingInfo::TBindingAlgoMetaInfo& object) {
        return (NJson::TMapBuilder
            ("earliest_ride_start", NJson::ToJson(object.EarliestRideStart))
            ("latest_ride_start", NJson::ToJson(object.LatestRideStart))
            ("selected_rides_total", object.SelectedRidesTotal)
            ("skipped_total", object.SkippedTotal)
        );
    }

    template <>
    TJsonValue ToJson(const NDrive::NSession::TSessionBindingInfo& object) {
        NJson::TJsonValue result;
        if (object.GetBillingSessionsMatchingMetaInfo()) {
            result["billing_session_match_metainfo"] = NJson::ToJson(object.GetBillingSessionsMatchingMetaInfo());
        }
        if (object.GetCompiledRidesMatchingMetaInfo()) {
            result["compiled_session_match_metainfo"] = NJson::ToJson(object.GetCompiledRidesMatchingMetaInfo());
        }
        if (object.GetSessionId()) {
            result["session_id"] = object.GetSessionId();
        }
        if (object.GetUserId()) {
            result["user_id"] = object.GetUserId();
        }
        result["skipped"] = object.GetSkipped();
        return result;
    }
}

namespace NDrive::NSession {
    IMatchingConstraint::IMatchingConstraint(const NDrive::IServer* server, const TMatchingOptions& options)
        : Server(server)
        , Options(options)
    {
    }

    TMatchingConstraintsGroup::TPtr TMatchingConstraintsGroup::Construct(const NDrive::IServer* server, const TMatchingOptions& options, const TSet<TString>& constraintNames, const bool checkEmpty) {
        auto constraintsGroupPtr = MakeAtomicShared<TMatchingConstraintsGroup>(server, options);

        for (auto&& constraintName: constraintNames) {
            if (IMatchingConstraint::TFactory::Has(constraintName)) {
                if (!constraintsGroupPtr->Has(constraintName)) {
                    constraintsGroupPtr->Add(constraintName, IMatchingConstraint::TFactory::Construct(constraintName, server, options));
                }
            } else {
                ERROR_LOG << "Unknown constraint name " << constraintName << Endl;
                return nullptr;
            }
        }

        if (checkEmpty && constraintsGroupPtr->Empty()) {
            ERROR_LOG << "Empty constraint group" << Endl;
            return nullptr;
        }

        return constraintsGroupPtr;
    }

    TMatchingConstraintsGroup::TMatchingConstraintsGroup(const NDrive::IServer* server, const TMatchingOptions& options)
        : Server(server)
        , Options(options)
    {
    }

    TMatchingConstraintsGroup::TMatchingConstraintsGroup(const NDrive::IServer* server, const TMatchingOptions& options, std::initializer_list<std::pair<const TString, IMatchingConstraint::TPtr>> il)
        : Server(server)
        , Options(options)
        , Constraints(il)
    {
    }

    TOptionalSessionBindingInfo TMatchingConstraintsGroup::MatchSession(const TString& carId, TInstant timestamp) const {
        TSessionBindingInfo result;
        if (!MatchSession(carId, timestamp, result)) {
            return {};
        }
        return result;
    }

    constexpr bool TMatchingConstraintsGroup::Empty() const {
        return Constraints.empty();
    }

    bool TMatchingConstraintsGroup::Has(const TString& name) const {
        return Constraints.contains(name);
    }

    void TMatchingConstraintsGroup::Add(const TString& name, IMatchingConstraint::TPtr constraint) {
        Constraints.emplace(name, constraint);
    }

    bool TMatchingConstraintsGroup::Match(const TObjectEvent<TFullCompiledRiding>& objEvent, const TInstant timestamp) const {
        return AllOf(
            Constraints.cbegin(),
            Constraints.cend(),
            [&objEvent, timestamp](const auto& cPair) {
                return cPair.second->Match(objEvent, timestamp);
            }
        );
    }

    bool TMatchingConstraintsGroup::Match(TCarTagHistoryEventsSessionPtr objEventPtr, const TInstant timestamp) const {
        TBillingEventsCompilation evCompilation;
        if (!objEventPtr->FillCompilation(evCompilation)) {
            return false;
        }
        return AllOf(
            Constraints.cbegin(),
            Constraints.cend(),
            [objEventPtr, &evCompilation, timestamp](const auto& cPair) {
                return cPair.second->Match(objEventPtr, evCompilation, timestamp);
            }
        );
    }

    bool TMatchingConstraintsGroup::MatchSession(const TString& carId, const TInstant timestamp, TSessionBindingInfo& sessionInfo) const {
        return (MatchBillingSession(carId, timestamp, sessionInfo) ||
                MatchCompiledRidesSession(carId, timestamp, sessionInfo));
    }

    bool TMatchingConstraintsGroup::MatchCompiledRidesSession(const TString& carId, const TInstant timestamp, TSessionBindingInfo& sessionInfo) const {
        const auto& compiledRides = Server->GetDriveAPI()->GetMinimalCompiledRides();
        auto tx = compiledRides.BuildSession(true);
        auto ydbTx = Server->GetDriveAPI()->BuildYdbTx<NSQL::ReadOnly>("matching_constraints_group", Server);
        auto optionalCompiledSessions = compiledRides.GetObject<TFullCompiledRiding>({ carId }, tx, ydbTx);
        if (!optionalCompiledSessions) {
            ERROR_LOG << "MatchCompiledRidesSession: cannot get CompiledSession for " << carId << ": " << tx.GetStringReport() << Endl;
            return false;
        }

        const auto& fullCompiledRides = *optionalCompiledSessions;
        return MatchCompiledRidesSession(fullCompiledRides, carId, timestamp, sessionInfo);
    }

    bool TMatchingConstraintsGroup::MatchCompiledRidesSession(const TFullCompiledRideContainer& fullCompiledRides, const TString& /* carId */, const TInstant timestamp, TSessionBindingInfo& sessionInfo) const {
        {
            TSessionBindingInfo::TBindingAlgoMetaInfo meta;
            if (fullCompiledRides) {
                meta.EarliestRideStart = fullCompiledRides.front().GetStartInstant();
                meta.LatestRideStart = fullCompiledRides.back().GetStartInstant();
                meta.SelectedRidesTotal = fullCompiledRides.size();
            }
            sessionInfo.SetCompiledRidesMatchingMetaInfo(std::move(meta));
        }

        auto& meta = sessionInfo.MutableCompiledRidesMatchingMetaInfo();

        for (auto&& objEvent : Reversed(fullCompiledRides)) {
            if (objEvent.GetStartInstant() > timestamp) {  // skip rides started after a violation
                continue;
            }

            if (static_cast<ui32>(meta.SkippedTotal) > Options.RejectedSessionsSkipLimit) {
                break;
            }

            if (Match(objEvent, timestamp)) {
                sessionInfo.SetSessionId(objEvent.GetSessionId());
                sessionInfo.SetUserId(objEvent.GetHistoryUserId());
                sessionInfo.SetSkipped(meta.SkippedTotal);
                return true;
            }

            ++meta.SkippedTotal;
        }

        sessionInfo.SetSkipped(Max(meta.SkippedTotal, sessionInfo.GetSkipped()));
        return false;
    }

    bool TMatchingConstraintsGroup::MatchBillingSession(const TString& carId, const TInstant timestamp, TSessionBindingInfo& sessionInfo) const {
        TCarSessionsBuilderPtr bSessions = Server->GetDriveAPI()->GetTagsManager().GetDeviceTags().GetHistoryManager().GetSessionsBuilder("billing");
        TCarTagHistoryEventsSessionPtrContainer vSessions = bSessions->GetObjectSessions(carId);
        return MatchBillingSession(vSessions, carId, timestamp, sessionInfo);
    }

    bool TMatchingConstraintsGroup::MatchBillingSession(const TCarTagHistoryEventsSessionPtrContainer& vSessions, const TString& /* carId */, const TInstant timestamp, TSessionBindingInfo& sessionInfo) const {
        {
            TSessionBindingInfo::TBindingAlgoMetaInfo meta;
            if (vSessions) {
                meta.EarliestRideStart = vSessions.front()->GetStartTS();
                meta.LatestRideStart = vSessions.back()->GetStartTS();
                meta.SelectedRidesTotal = vSessions.size();
            }
            sessionInfo.SetBillingSessionsMatchingMetaInfo(std::move(meta));
        }

        auto& meta = sessionInfo.MutableBillingSessionsMatchingMetaInfo();

        for (auto&& objEventPtr : Reversed(vSessions)) {
            if (objEventPtr->GetStartTS() > timestamp) {  // skip rides started after a violation
                continue;
            }

            if (static_cast<ui32>(meta.SkippedTotal) > Options.RejectedSessionsSkipLimit) {
                break;
            }

            if (Match(objEventPtr, timestamp)) {
                sessionInfo.SetSessionId(objEventPtr->GetSessionId());
                sessionInfo.SetUserId(objEventPtr->GetUserId());
                sessionInfo.SetSkipped(meta.SkippedTotal);
                return true;
            }

            ++meta.SkippedTotal;
        }

        sessionInfo.SetSkipped(Max(meta.SkippedTotal, sessionInfo.GetSkipped()));
        return false;
    }

    TString TViolationDuringSessionMatchingConstraint::GetTypeName() {
        return "violation_during_session";
    }

    TViolationDuringSessionMatchingConstraint::TFactory::TRegistrator<TViolationDuringSessionMatchingConstraint> TViolationDuringSessionMatchingConstraint::Registrator;

    bool TViolationDuringSessionMatchingConstraint::Match(const TObjectEvent<TFullCompiledRiding>& objEvent, const TInstant timestamp) const {
        return objEvent.GetStartInstant().Get() <= timestamp && timestamp <= objEvent.GetFinishInstant().Get();
    }

    bool TViolationDuringSessionMatchingConstraint::Match(TCarTagHistoryEventsSessionPtr /* objEventPtr */, const TBillingEventsCompilation& eventsCompilation, const TInstant timestamp) const {
        return !!eventsCompilation.GetEvents() && eventsCompilation.GetEvents().front().GetInstant() <= timestamp && timestamp <= eventsCompilation.GetEvents().back().GetInstant();
    }

    TString TViolationDuringOrAfterSessionMatchingConstraint::GetTypeName() {
        return "violation_during_or_after_session";
    }

    TViolationDuringOrAfterSessionMatchingConstraint::TFactory::TRegistrator<TViolationDuringOrAfterSessionMatchingConstraint> TViolationDuringOrAfterSessionMatchingConstraint::Registrator;

    bool TViolationDuringOrAfterSessionMatchingConstraint::Match(const TObjectEvent<TFullCompiledRiding>& objEvent, const TInstant timestamp) const {
        return objEvent.GetStartInstant() <= timestamp;
    }

    bool TViolationDuringOrAfterSessionMatchingConstraint::Match(TCarTagHistoryEventsSessionPtr /* objEventPtr */, const TBillingEventsCompilation& eventsCompilation, const TInstant timestamp) const {
        return !!eventsCompilation.GetEvents() && eventsCompilation.GetEvents().front().GetInstant() <= timestamp;
    }

    TString TViolationNotDuringServicingMatchingConstraint::GetTypeName() {
        return "violation_not_during_servicing";
    }

    TViolationNotDuringServicingMatchingConstraint::TFactory::TRegistrator<TViolationNotDuringServicingMatchingConstraint> TViolationNotDuringServicingMatchingConstraint::Registrator;

    bool TViolationNotDuringServicingMatchingConstraint::Match(const TObjectEvent<TFullCompiledRiding>& objEvent, const TInstant timestamp) const {
        auto optionalLastEvent = objEvent.GetLastEventAt(timestamp);
        if (!optionalLastEvent) {
            return false;
        }
        return optionalLastEvent->GetTagName() != TChargableTag::Servicing;
    }

    bool TViolationNotDuringServicingMatchingConstraint::Match(TCarTagHistoryEventsSessionPtr objEventPtr, const TBillingEventsCompilation& eventsCompilation, const TInstant timestamp) const {
        Y_UNUSED(eventsCompilation);
        auto optionalLastEvent = objEventPtr ? objEventPtr->GetLastEventAt(timestamp) : Nothing();
        if (!optionalLastEvent) {
            return false;
        }
        const auto& lastEvent = *optionalLastEvent;
        return lastEvent->GetName() != TChargableTag::Servicing;
    }

    bool IHasEventDuringSessionMatchingConstraint::MatchImpl(const TObjectEvent<TFullCompiledRiding>& objEvent, const TInstant /* timestamp */, const TSet<TString>& needleEvents) const {
        for (const auto& localEvent : objEvent.GetLocalEvents()) {
            if (needleEvents.contains(localEvent.GetTagName())) {
                return true;
            }
        }
        return false;
    }

    bool IHasEventDuringSessionMatchingConstraint::MatchImpl(TCarTagHistoryEventsSessionPtr /* objEventPtr */, const TBillingEventsCompilation& eventsCompilation, const TInstant /* timestamp */, const TSet<TString>& needleEvents) const {
        for (const auto& event : eventsCompilation.GetEvents()) {
            if (needleEvents.contains(event.GetName())) {
                return true;
            }
        }
        return false;
    }

    TString THasRidingEventDuringSessionMatchingConstraint::GetTypeName() {
        return "has_riding_during_session";
    }

    THasRidingEventDuringSessionMatchingConstraint::TFactory::TRegistrator<THasRidingEventDuringSessionMatchingConstraint> THasRidingEventDuringSessionMatchingConstraint::Registrator;

    bool THasRidingEventDuringSessionMatchingConstraint::Match(const TObjectEvent<TFullCompiledRiding>& objEvent, const TInstant timestamp) const {
        return MatchImpl(objEvent, timestamp, { ::ToString(ESessionState::Riding) });
    }

    bool THasRidingEventDuringSessionMatchingConstraint::Match(TCarTagHistoryEventsSessionPtr objEventPtr, const TBillingEventsCompilation& eventsCompilation, const TInstant timestamp) const {
        return MatchImpl(objEventPtr, eventsCompilation, timestamp, { ::ToString(ESessionState::Riding) });
    }

    bool IHasEventBeforeViolationMatchingConstraint::MatchImpl(const TObjectEvent<TFullCompiledRiding>& objEvent, const TInstant timestamp, const TSet<TString>& needleEvents) const {
        for (const auto& localEvent : objEvent.GetLocalEvents()) {
            if (needleEvents.contains(localEvent.GetTagName()) && localEvent.GetInstant() <= timestamp) {
                return true;
            }
        }
        return false;
    }

    bool IHasEventBeforeViolationMatchingConstraint::MatchImpl(TCarTagHistoryEventsSessionPtr /* objEventPtr */, const TBillingEventsCompilation& eventsCompilation, const TInstant timestamp, const TSet<TString>& needleEvents) const {
        for (const auto& event : eventsCompilation.GetEvents()) {
            if (needleEvents.contains(event.GetName()) && event.GetInstant() <= timestamp) {
                return true;
            }
        }
        return false;
    }

    TString THasRidingEventBeforeViolationMatchingConstraint::GetTypeName() {
        return "has_riding_before_violation";
    }

    THasRidingEventBeforeViolationMatchingConstraint::TFactory::TRegistrator<THasRidingEventBeforeViolationMatchingConstraint> THasRidingEventBeforeViolationMatchingConstraint::Registrator;

    bool THasRidingEventBeforeViolationMatchingConstraint::Match(const TObjectEvent<TFullCompiledRiding>& objEvent, const TInstant timestamp) const {
        return MatchImpl(objEvent, timestamp, { ::ToString(ESessionState::Riding) });
    }

    bool THasRidingEventBeforeViolationMatchingConstraint::Match(TCarTagHistoryEventsSessionPtr objEventPtr, const TBillingEventsCompilation& eventsCompilation, const TInstant timestamp) const {
        return MatchImpl(objEventPtr, eventsCompilation, timestamp, { ::ToString(ESessionState::Riding) });
    }

    TString THasLocationChangeMatchingConstraint::GetTypeName() {
        return "has_location_change";
    }

    THasLocationChangeMatchingConstraint::TFactory::TRegistrator<THasLocationChangeMatchingConstraint> THasLocationChangeMatchingConstraint::Registrator;

    bool THasLocationChangeMatchingConstraint::Match(const TObjectEvent<TFullCompiledRiding>& objEvent, const TInstant /* timestamp */) const {
        if (!objEvent.HasSnapshotsDiff() || !objEvent.GetSnapshotsDiffUnsafe().HasMileage()) {  // skip without mileage info
            return false;
        }
        const double minRideDistanceKm = Options.MinRideDistanceMeters * 0.001;
        return objEvent.GetSnapshotsDiffUnsafe().GetMileageUnsafe() >= minRideDistanceKm;
    }

    bool THasLocationChangeMatchingConstraint::Match(TCarTagHistoryEventsSessionPtr objEventPtr, const TBillingEventsCompilation& eventsCompilation, const TInstant /* timestamp */) const {
        const double minRideDistanceKm = Options.MinRideDistanceMeters * 0.001;

        double startMileage = eventsCompilation.GetMileageOnStart();
        double finishMileage = 0.0;

        if (objEventPtr->GetClosed()) {
            finishMileage = eventsCompilation.GetMileageMax();
        } else {
            auto carId = objEventPtr->GetObjectId();
            auto snapshotPtr = Server->GetSnapshotsManager().GetSnapshotPtr(carId);
            auto deviceSnapshotPtr = std::dynamic_pointer_cast<TRTDeviceSnapshot>(snapshotPtr);
            if (!deviceSnapshotPtr || !deviceSnapshotPtr->GetMileage(finishMileage)) {
                return false;
            }
        }

        return (finishMileage - startMileage) >= minRideDistanceKm;
    }

    bool IViolationInStateMatchingConstraint::MatchImpl(const TObjectEvent<TFullCompiledRiding>& objEvent, const TInstant timestamp, const TSet<TString>& needleSessionStates) const {
        TMaybe<TInstant> targetStateStartTimestamp;

        for (const auto& localEvent : objEvent.GetLocalEvents()) {
            if (!!targetStateStartTimestamp) {
                if (targetStateStartTimestamp <= timestamp && timestamp <= localEvent.GetInstant()) {
                    return true;
                }
                targetStateStartTimestamp.Clear();
            }

            if (needleSessionStates.contains(localEvent.GetTagName())) {
                targetStateStartTimestamp = localEvent.GetInstant();
            }
        }

        return (!!targetStateStartTimestamp) && targetStateStartTimestamp <= timestamp;  // timestamp is supposed to be not greater than now
    }

    bool IViolationInStateMatchingConstraint::MatchImpl(TCarTagHistoryEventsSessionPtr /* objEventPtr */, const TBillingEventsCompilation& eventsCompilation, const TInstant timestamp, const TSet<TString>& needleSessionStates) const {
        TMaybe<TInstant> targetStateStartTimestamp;

        for (const auto& event : eventsCompilation.GetEvents()) {
            if (!!targetStateStartTimestamp) {
                if (targetStateStartTimestamp <= timestamp && timestamp <= event.GetInstant()) {
                    return true;
                }
                targetStateStartTimestamp.Clear();
            }

            if (needleSessionStates.contains(event.GetName())) {
                targetStateStartTimestamp = event.GetInstant();
            }
        }

        return (!!targetStateStartTimestamp) && targetStateStartTimestamp <= timestamp;  // timestamp is supposed to be not greater than now
    }

    TString TViolationInRidingMatchingConstraint::GetTypeName() {
        return "violation_in_riding";
    }

    TViolationInRidingMatchingConstraint::TFactory::TRegistrator<TViolationInRidingMatchingConstraint> TViolationInRidingMatchingConstraint::Registrator;

    bool TViolationInRidingMatchingConstraint::Match(const TObjectEvent<TFullCompiledRiding>& objEvent, const TInstant timestamp) const {
        return MatchImpl(objEvent, timestamp, { ::ToString(ESessionState::Riding) });
    }

    bool TViolationInRidingMatchingConstraint::Match(TCarTagHistoryEventsSessionPtr objEventPtr, const TBillingEventsCompilation& eventsCompilation, const TInstant timestamp) const {
        return MatchImpl(objEventPtr, eventsCompilation, timestamp, { ::ToString(ESessionState::Riding) });
    }
}
