#include "process.h"

#include <drive/backend/compiled_riding/manager.h>
#include <drive/backend/logging/events.h>
#include <drive/backend/saas/api.h>
#include <drive/backend/tracks/report.h>

#include <drive/library/cpp/threading/future.h>
#include <drive/library/cpp/tracks/client.h>
#include <drive/library/cpp/yt/node/cast.h>

#include <mapreduce/yt/interface/client.h>

#include <rtline/library/json/parse.h>

// Persistent state of the analyzer cache process.
//
// Stores a single value: the id of the last history event that the process has
// already handled. It is (de)serialized as {"last_event_id": <ui64>}.
class TAnalyzerCacheProcessState: public TJsonRTBackgroundProcessState {
public:
    // Default state: no event has been handled yet (LastEventId == 0).
    TAnalyzerCacheProcessState() = default;

    // `explicit` so that a bare integer never silently converts into a state object.
    explicit TAnalyzerCacheProcessState(ui64 lastEventId)
        : LastEventId(lastEventId)
    {
    }

    // Id of the last handled history event; 0 when nothing has been handled.
    ui64 GetLastEventId() const {
        return LastEventId;
    }

    // Type name under which this state is registered (see Registrator).
    static TString GetTypeName() {
        return "analyzer_cache";
    }
    TString GetType() const override {
        return GetTypeName();
    }

protected:
    // Produces a JSON map with the single key "last_event_id".
    NJson::TJsonValue SerializeToJson() const override {
        NJson::TJsonValue result = NJson::JSON_MAP;
        result["last_event_id"] = LastEventId;
        return result;
    }

    // Reads "last_event_id" back; the return value is whatever NJson::TryFromJson reports.
    bool DeserializeFromJson(const NJson::TJsonValue& value) override {
        return NJson::TryFromJson(value["last_event_id"], LastEventId);
    }

private:
    ui64 LastEventId = 0;

private:
    static TFactory::TRegistrator<TAnalyzerCacheProcessState> Registrator;
};

// Runs one iteration of the analyzer cache process.
//
// Steps:
//   1. Load compiled-ride history events: either the ones following the event id
//      stored in the previous state, or (when no id is stored) the ones selected
//      by the instant `Now() - TimeToLive`.
//   2. For every ride request its tracks and pass them to the tracks linker,
//      keeping a bounded number of requests in flight.
//   3. Convert every linked track to an analyzer report in `Format`.
//   4. Append the reports to YT tables under `YtDirectory`; the table is chosen by
//      GetTableName(ride finish instant). Optionally set tables' expiration time.
//
// Returns a new TAnalyzerCacheProcessState holding the largest history event id
// among the rides that were fully processed, or an unexpected TString error when a
// required service is missing, events cannot be read, or a report cannot be built.
TExpectedState TAnalyzerCacheProcess::DoExecute(TAtomicSharedPtr<IRTBackgroundProcessState> state_, const TExecutionContext& context) const {
    // `state` is null both when there is no previous state and when it has another type;
    // only the latter case is reported.
    auto state = dynamic_cast<const TAnalyzerCacheProcessState*>(state_.Get());
    if (!state && state_) {
        WARNING_LOG << GetRobotId() << ": broken state " << state_->GetType() << ' ' << state_->GetReport().GetStringRobust() << Endl;
    }

    const auto& server = context.GetServerAs<NDrive::IServer>();
    const auto api = Yensured(server.GetDriveAPI());

    // Both external services are mandatory: fail the iteration if either is not configured.
    auto tracksApi = server.GetRTLineAPI(TracksApiName);
    if (!tracksApi) {
        return MakeUnexpected<TString>("tracks service " + TracksApiName + " is missing");
    }
    NDrive::TTracksClient tracksClient(tracksApi->GetSearchClient());

    auto linker = server.GetLinker(LinkerApiName);
    if (!linker) {
        return MakeUnexpected<TString>("linker service " + LinkerApiName + " is missing");
    }
    NDrive::TTracksLinker tracksLinker(linker);

    const auto& compiledRides = api->GetMinimalCompiledRides();
    // 0 is used as "no previous event id".
    const ui64 previousLastEventId = state ? state->GetLastEventId() : 0;
    auto events = TOptionalObjectEvents<TMinimalCompiledRiding>();
    {
        auto session = compiledRides.BuildSession(true);
        if (previousLastEventId) {
            // Continue right after the last processed event.
            events = compiledRides.GetEvents<TMinimalCompiledRiding>(previousLastEventId + 1, session);
        } else {
            // No stored id: query by the instant `Now() - TimeToLive` instead.
            // NOTE(review): presumably this selects events not older than `start` - confirm against GetEvents.
            TInstant start = Now() - TimeToLive;
            events = compiledRides.GetEvents<TMinimalCompiledRiding>({}, start, session);
        }
        if (!events) {
            return MakeUnexpected<TString>("cannot GetEvents: " + session.GetStringReport());
        }
    }
    auto& view = *events;
    // Process events in ascending history event id order and, when BatchSize is non-zero,
    // keep only the first BatchSize of them.
    std::sort(view.begin(), view.end(), [](const auto& left, const auto& right) {
        return left.GetHistoryEventId() < right.GetHistoryEventId();
    });
    if (BatchSize && BatchSize < view.size()) {
        view.resize(BatchSize);
    }

    // Each entry pairs a ride event with the future of its linked tracks.
    TVector<std::pair<
        TObjectEvent<TMinimalCompiledRiding>,
        NThreading::TFuture<NDrive::TTracksLinker::TResults>
    >> linkedRides;
    // Number of leading futures in `linkedRides` that have already been waited for.
    size_t finished = 0;
    for (auto&& ride : view) {
        NDrive::TTrackQuery trackQuery;
        trackQuery.SessionId = ride.GetSessionId();
        trackQuery.DurationThreshold = DurationThreshold;
        if (RidingOnly) {
            trackQuery.Status = NDrive::ECarStatus::csRide;
        }

        auto timeout = TDuration::Seconds(10);
        auto deadline = Now() + timeout;
        auto asyncTracks = tracksClient.GetTracks(trackQuery, timeout);
        auto asyncLinked = tracksLinker.Link(std::move(asyncTracks));

        // The ride is moved out of `view`; `view` elements are not used after this loop.
        linkedRides.emplace_back(std::move(ride), std::move(asyncLinked));
        // Throttle: while more than MaxInFlight futures have not been waited for,
        // wait for the oldest one (the wait result is ignored here and checked later).
        while ((linkedRides.size() > finished) && (linkedRides.size() - finished > MaxInFlight)) {
            linkedRides[finished].second.Wait(deadline);
            finished += 1;
        }
    }

    ui64 currentLastEventId = previousLastEventId;
    // Reports keyed by ride finish instant; the multimap keeps them ordered by that instant.
    std::multimap<TInstant, NJson::TJsonValue> reports;
    for (auto&& [ride, asyncLinked] : linkedRides) {
        auto timeout = TDuration::Seconds(10);
        auto deadline = Now() + timeout;
        // Stop at the first ride that timed out or failed: the remaining rides are not
        // processed in this run and currentLastEventId is not advanced for them.
        if (!asyncLinked.Wait(deadline)) {
            ERROR_LOG << GetRobotId() << ": wait timeout for " << ride.GetSessionId() << Endl;
            break;
        }
        if (!asyncLinked.HasValue()) {
            ERROR_LOG << GetRobotId() << ": exception for " << ride.GetSessionId() << ": " << NThreading::GetExceptionMessage(asyncLinked) << Endl;
            break;
        }

        for (auto&& linked : asyncLinked.GetValue()) {
            auto timestamp = ride.GetFinishInstant();
            auto value = NDrive::GetAnalyzerReport<NJson::TJsonValue>(linked, Format);
            if (!value.IsDefined()) {
                return MakeUnexpected<TString>(TStringBuilder() << "cannot get report " << Format << " for " << linked.Track.Name);
            }
            // A report that is a single-element array is stored as that element itself.
            if (value.GetArray().size() == 1) {
                reports.emplace(timestamp, std::move(value.GetArraySafe()[0]));
            } else {
                reports.emplace(timestamp, std::move(value));
            }
        }

        currentLastEventId = std::max(currentLastEventId, ride.GetHistoryEventId());
    }

    NYT::IClientPtr client = NYT::CreateClient(YtCluster);
    // Tables written in this run, mapped to their rounded timestamp (used for expiration below).
    TMap<NYT::TYPath, TInstant> tables;
    {
        // All rows are written inside a single transaction.
        NYT::ITransactionPtr tx = client->StartTransaction();
        NYT::TRichYPath path;
        NYT::TTableWriterPtr<NYT::TNode> writer;
        // `reports` is ordered by timestamp, so the target table path changes only between
        // consecutive runs of rows; a writer is reopened whenever the path changes.
        for (auto&&[timestamp, report] : reports) {
            auto table = YtDirectory + '/' + GetTableName(timestamp);
            auto p = NYT::TRichYPath(table).Append(true);
            if (path.Path_ != p.Path_) {
                path = p;
                tables.emplace(table, GetRoundedTimestamp(timestamp));
                if (writer) {
                    writer->Finish();
                    writer.Drop();
                }
            }
            if (!writer) {
                writer = tx->CreateTableWriter<NYT::TNode>(path);
            }
            writer->AddRow(NYT::ToNode(report));
        }
        if (writer) {
            writer->Finish();
        }
        tx->Commit();
    }
    // With a non-zero TimeToLive, set @expiration_time of every written table to its
    // rounded timestamp plus TimeToLive. This happens after the transaction is committed.
    if (TimeToLive) {
        for (auto&&[table, timestamp] : tables) {
            auto expirationTime = timestamp + TimeToLive;
            client->Set(table + '/' + "@expiration_time", expirationTime.ToString());
        }
    }

    return MakeAtomicShared<TAnalyzerCacheProcessState>(currentLastEventId);
}

// Builds the settings scheme of the process: the base scheme extended with the
// fields parsed by DoDeserializeFromJson. Fields are added in a fixed order.
NDrive::TScheme TAnalyzerCacheProcess::DoGetScheme(const IServerBase& server) const {
    NDrive::TScheme result = TBase::DoGetScheme(server);

    auto&& batchSizeField = result.Add<TFSNumeric>("batch_size", "Max tracks to process in a single run");
    batchSizeField.SetDefault(BatchSize);

    auto&& thresholdField = result.Add<TFSDuration>("duration_threshold", "Duration threshold");
    thresholdField.SetDefault(DurationThreshold);

    auto&& inFlightField = result.Add<TFSNumeric>("max_in_flight", "Max in flight");
    inFlightField.SetDefault(MaxInFlight);

    // Report formats offered for selection.
    auto&& formatField = result.Add<TFSVariants>("format", "Format");
    formatField.SetVariants({
        "legacy",
        "legacy_debug",
        "legacy_with_raw",
    });

    auto&& fractionField = result.Add<TFSDuration>("fraction", "Fraction");
    fractionField.SetDefault(Fraction);

    result.Add<TFSBoolean>("riding_only", "Dump only riding");
    result.Add<TFSDuration>("ttl", "Time to live");

    // The currently configured linker is always offered next to "maps_linker".
    auto&& linkerField = result.Add<TFSVariants>("linker_api", "Linker API to use");
    linkerField.SetVariants({
        LinkerApiName,
        TString{"maps_linker"},
    });

    // Tracks API choices come from the server's list of RTLine APIs.
    auto&& tracksField = result.Add<TFSVariants>("tracks_api", "Tracks API to use");
    tracksField.SetVariants(server.ListRTLineAPIs());

    auto&& clusterField = result.Add<TFSString>("yt_cluster", "YT cluster");
    clusterField.SetDefault("hahn").SetRequired(true);

    auto&& directoryField = result.Add<TFSString>("yt_directory", "YT directory");
    directoryField.SetRequired(true);

    return result;
}

// Loads the process settings from JSON: the base class fields first, then the
// fields of this process in a fixed order. Stops and returns false at the first
// field that fails to parse.
bool TAnalyzerCacheProcess::DoDeserializeFromJson(const NJson::TJsonValue& value) {
    if (!TBase::DoDeserializeFromJson(value)) {
        return false;
    }
    if (!NJson::ParseField(value["batch_size"], BatchSize)) {
        return false;
    }
    if (!NJson::ParseField(value["duration_threshold"], DurationThreshold)) {
        return false;
    }
    if (!NJson::ParseField(value["max_in_flight"], MaxInFlight)) {
        return false;
    }
    if (!NJson::ParseField(value["format"], Format)) {
        return false;
    }
    if (!NJson::ParseField(value["fraction"], Fraction)) {
        return false;
    }
    if (!NJson::ParseField(value["riding_only"], RidingOnly)) {
        return false;
    }
    if (!NJson::ParseField(value["linker_api"], LinkerApiName)) {
        return false;
    }
    if (!NJson::ParseField(value["tracks_api"], TracksApiName)) {
        return false;
    }
    if (!NJson::ParseField(value["ttl"], TimeToLive)) {
        return false;
    }
    // NOTE(review): the extra `true` argument presumably marks the field as required - confirm against NJson::ParseField.
    if (!NJson::ParseField(value["yt_cluster"], YtCluster, true)) {
        return false;
    }
    return NJson::ParseField(value["yt_directory"], YtDirectory, true);
}

// Dumps the process settings to JSON on top of the base class output, using the
// same keys that DoDeserializeFromJson reads. Durations are written through NJson::Hr.
NJson::TJsonValue TAnalyzerCacheProcess::DoSerializeToJson() const {
    // Render durations up front so the assignments below stay uniform.
    NJson::TJsonValue durationThreshold = NJson::ToJson(NJson::Hr(DurationThreshold));
    NJson::TJsonValue fraction = NJson::ToJson(NJson::Hr(Fraction));
    NJson::TJsonValue timeToLive = NJson::ToJson(NJson::Hr(TimeToLive));

    NJson::TJsonValue json = TBase::DoSerializeToJson();
    json["batch_size"] = BatchSize;
    json["duration_threshold"] = std::move(durationThreshold);
    json["max_in_flight"] = MaxInFlight;
    json["format"] = Format;
    json["fraction"] = std::move(fraction);
    json["riding_only"] = RidingOnly;
    json["linker_api"] = LinkerApiName;
    json["tracks_api"] = TracksApiName;
    json["ttl"] = std::move(timeToLive);
    json["yt_cluster"] = YtCluster;
    json["yt_directory"] = YtDirectory;
    return json;
}

// Rounds `timestamp` down to a multiple of Fraction, working in whole minutes.
// The Y_ENSURE check fails when Fraction is shorter than one minute.
TInstant TAnalyzerCacheProcess::GetRoundedTimestamp(TInstant timestamp) const {
    const auto minutes = Fraction.Minutes();
    Y_ENSURE(minutes > 0);
    const auto total = timestamp.Minutes();
    // Dropping the remainder is the same as (total / minutes) * minutes.
    return TInstant::Minutes(total - total % minutes);
}

// Returns the table name for `timestamp`: the rounded timestamp formatted in local
// time, date-only when Fraction is at least one day and date plus time otherwise.
TString TAnalyzerCacheProcess::GetTableName(TInstant timestamp) const {
    const TInstant bucket = GetRoundedTimestamp(timestamp);
    const bool daily = Fraction >= TDuration::Days(1);
    const char* const pattern = daily ? "%Y-%m-%d" : "%Y-%m-%dT%H:%M:%S";
    return bucket.FormatLocalTime(pattern);
}

// Definitions of the static factory registrators. Each one is constructed with the
// type name returned by the corresponding GetTypeName(); presumably this is what makes
// the process and its state constructible by name through their TFactory.
TAnalyzerCacheProcess::TFactory::TRegistrator<TAnalyzerCacheProcess> TAnalyzerCacheProcess::Registrator(TAnalyzerCacheProcess::GetTypeName());
TAnalyzerCacheProcessState::TFactory::TRegistrator<TAnalyzerCacheProcessState> TAnalyzerCacheProcessState::Registrator(TAnalyzerCacheProcessState::GetTypeName());
