#include "common.h"
#include "support_ai.h"

#include <drive/backend/data/notifications_tags.h>
#include <drive/backend/database/drive_api.h>
#include <drive/backend/device_snapshot/snapshots/tag.h>
#include <drive/backend/logging/events.h>
#include <drive/backend/notifications/support_ai/call.h>
#include <drive/backend/support_center/manager.h>
#include <drive/backend/tags/tag_description.h>
#include <drive/backend/tags/tags_manager.h>

#include <rtline/library/json/builder.h>

#include <util/generic/utility.h>


namespace {

    /// Exception raised while processing a single support AI call.
    ///
    /// Besides the usual yexception message it carries the external call id and a
    /// short error type (e.g. "get_info_error", "upload_track_error"), which
    /// ProcessCalls copies into the "call_id" / "error_type" fields of its error report.
    class TCallException: public yexception {
    public:
        TCallException(const TString& callId, const TString& errorType)
            : CallId(callId)
            , ErrorType(errorType)
        {
        }

        /// External id of the call that failed.
        TString GetCallId() const {
            return CallId;
        }

        /// Machine-readable kind of the failure.
        TString GetErrorType() const {
            return ErrorType;
        }

    private:
        const TString CallId;
        const TString ErrorType;
    };

    /// Continuation for an S3 PutKey future.
    ///
    /// Turns either a failed request (the future holds an exception) or a reply
    /// that is not a success into a TCallException tagged with the call id and the
    /// error type given at construction; otherwise it completes silently.
    class TFileUploader {
    private:
        TString CallId;
        TString ErrorType;

    public:
        TFileUploader(const TString& callId, const TString& error)
            : CallId(callId)
            , ErrorType(error)
        {
        }

        void operator()(const NThreading::TFuture<NUtil::THttpReply>& reply) const {
            // Transport-level failure: the future carries an exception instead of a reply.
            if (!reply.HasValue()) {
                ythrow TCallException(CallId, ErrorType) << NThreading::GetExceptionInfo(reply);
            }
            // A reply arrived, but the storage did not accept the upload.
            if (const auto& report = reply.GetValue(); !report.IsSuccessReply()) {
                ythrow TCallException(CallId, ErrorType) << report.Serialize();
            }
        }
    };

    /// Continuation for a TSupportAIClient::GetCallResult future of one call.
    ///
    /// Parses the raw call info into a TSupportAICall, uploads the raw info to the
    /// bucket as "<callId>.json" and, when the call status is Servised, loads the
    /// track and uploads it as "<callId>.wav". The returned future yields the parsed
    /// call once the uploads have finished, or fails with the TCallException of a
    /// failed step.
    ///
    /// NOTE(review): Client and Bucket are held by reference and are also used by
    /// continuations that may still run after ProcessCalls stops waiting; confirm
    /// both outlive every in-flight request.
    class TCallProcessor {
        const NDrive::TSupportAIClient& Client;
        const NS3::TBucket& Bucket;
        const TString CallId;

    public:
        TCallProcessor(const NDrive::TSupportAIClient& client, const NS3::TBucket& bucket, const TString& callId)
            : Client(client)
            , Bucket(bucket)
            , CallId(callId)
        {
        }

        NThreading::TFuture<TSupportAICall> operator()(const NThreading::TFuture<NJson::TJsonValue>& call) {
            if (!call.HasValue()) {
                ythrow TCallException(CallId, "get_info_error") << NThreading::GetExceptionInfo(call);
            }
            const auto& rawData = call.GetValue();

            TSupportAICall callData;
            if (!callData.DeserializeFromJson(rawData)) {
                ythrow TCallException(CallId, "parse_info_error") << "fail to parse: " << rawData;
            }
            callData.SetExternalId(CallId);

            // Every pending upload for this call; the raw info is always stored.
            TVector<NThreading::TFuture<void>> uploads;
            auto infoUpload = Bucket.PutKey(CallId + ".json", rawData.GetStringRobust());
            uploads.emplace_back(infoUpload.Apply(TFileUploader(CallId, "upload_call_info_error")));

            if (callData.GetStatus() == TSupportAICall::EStatus::Servised) {
                // Only serviced calls have a track: fetch it, then store it next to the info.
                auto uploadTrack = [&bucket = Bucket, callId = CallId](const NThreading::TFuture<TString>& track) {
                    if (!track.HasValue()) {
                        ythrow TCallException(callId, "load_track_error") << NThreading::GetExceptionInfo(track);
                    }
                    return bucket.PutKey(callId + ".wav", track.GetValue()).Apply(TFileUploader(callId, "upload_track_error"));
                };
                uploads.emplace_back(Client.LoadTrack(CallId).Apply(std::move(uploadTrack)));
            }

            // Build the waiter before `uploads` is moved into the continuation below.
            auto allUploaded = NThreading::WaitExceptionOrAll(uploads);
            return allUploaded.Apply([
                call = std::move(callData),
                uploads = std::move(uploads)
            ](const NThreading::TFuture<void>& /* r */) mutable -> TSupportAICall
            {
                if (uploads.empty()) {
                    ythrow TCallException(call.GetExternalId(), "process_call_error") << "empty_result";
                }
                // Re-raise the original TCallException of a failed upload, if any.
                for (auto&& upload : uploads) {
                    upload.TryRethrow();
                }
                return std::move(call);
            });
        }
    };

    // Collection of per-call futures filled by ProcessCalls via EmplaceResult(callId, future)
    // and read back by index through GetResultPtr(i) / GetCallId(i).
    using TFutureResults = NDrive::TFutureCallResults<TSupportAICall>;
    // Outcome of one call: the processed TSupportAICall, or a JSON error object
    // with "call_id", "error_type" and "error_data" fields (built in ProcessCalls).
    using TExpectedCall = TExpected<TSupportAICall, NJson::TJsonValue>;
    using TExpectedCalls = TVector<TExpectedCall>;

    /// Processes a batch of calls concurrently and waits for the results.
    ///
    /// For every id in `callIds` the call info is requested from `client` and handed
    /// to TCallProcessor, which uploads the call artifacts to `bucket`. Each call
    /// yields exactly one entry in the result: the processed TSupportAICall, or a JSON
    /// error {call_id, error_type, error_data}. Calls that are still in flight when
    /// the deadline passes are reported with error_data "timeouted".
    TExpectedCalls ProcessCalls(const NDrive::TSupportAIClient& client, const NS3::TBucket& bucket, const TSet<TString>& callIds) {
        TFutureResults futures;
        for (auto&& callId : callIds) {
            futures.EmplaceResult(callId, client.GetCallResult(callId).Apply(TCallProcessor(client, bucket, callId)));
        }

        // Longest sequential chain of one call: GetCallResult (client), then LoadTrack
        // (client) followed by PutKey of the track (bucket); the info upload runs in
        // parallel with the track branch and is never longer than it.
        const TDuration clientTimeout = client.GetConfig().GetRequestTimeout();
        const TDuration bucketTimeout = bucket.GetConfig().GetRequestTimeout();
        const TDuration deadline = clientTimeout + clientTimeout + bucketTimeout;

        auto waiter = NThreading::WaitAll(futures.GetResults());
        const bool allReady = waiter.Wait(deadline);

        TExpectedCalls result;
        result.reserve(futures.GetResults().size());
        for (ui32 i = 0; i < futures.GetResults().size(); ++i) {
            auto futureResult = futures.GetResultPtr(i);
            if (!futureResult) {
                return result;
            }
            // A future that is still pending cannot be read; report it once as timed out
            // instead of duplicating entries for calls that did finish in time.
            if (!allReady && !futureResult->HasValue() && !futureResult->HasException()) {
                result.emplace_back(MakeUnexpected<NJson::TJsonValue>(NJson::TMapBuilder
                    ("call_id", futures.GetCallId(i))
                    ("error_type", "process_call_error")
                    ("error_data", "timeouted")
                ));
                continue;
            }
            try {
                result.emplace_back(futureResult->ExtractValue());
            } catch (const TCallException& e) {
                result.emplace_back(MakeUnexpected<NJson::TJsonValue>(NJson::TMapBuilder
                    ("call_id", e.GetCallId())
                    ("error_type", e.GetErrorType())
                    ("error_data", e.what())
                ));
            } catch (...) {
                result.emplace_back(MakeUnexpected<NJson::TJsonValue>(NJson::TMapBuilder
                    ("call_id", futures.GetCallId(i))
                    ("error_type", "process_call_error")
                    ("error_data", NThreading::GetExceptionInfo(*futureResult))
                ));
            }
        }
        return result;
    }
}

// Registers TRTSupportAIProcessState in the TRTHistoryWatcherState factory under the
// watcher's type name (presumably so a stored state can be recreated for this process).
TRTHistoryWatcherState::TFactory::TRegistrator<TRTSupportAIProcessState> TRTSupportAIProcessState::Registrator(TRTSupportAIProcess::GetTypeName());


// The state reports the same type name as the process it belongs to.
TString TRTSupportAIProcessState::GetType() const {
    return TRTSupportAIProcess::GetTypeName();
}

// Registers TRTSupportAIProcess in its TFactory.
TRTSupportAIProcess::TFactory::TRegistrator<TRTSupportAIProcess> TRTSupportAIProcess::Registrator;

// Factory name of the tag-history watcher that creates support AI call records.
TString TRTSupportAIProcess::GetTypeName() {
    return "support_ai_call_watcher";
}

TString TRTSupportAIProcess::GetType() const {
    return GetTypeName();
}

/// Describes the watcher settings in the admin scheme.
///
/// "pack_size" is the maximum number of calls processed at once; "tags" is a
/// multi-select over every registered tag of TSupportAITag type. The tag list
/// stays empty when the server is not an NDrive::IServer or has no drive API,
/// instead of dereferencing a null API pointer.
NDrive::TScheme TRTSupportAIProcess::DoGetScheme(const IServerBase& server) const {
    NDrive::TScheme scheme = TBase::DoGetScheme(server);
    scheme.Add<TFSNumeric>("pack_size", "Максимальное количество звонков в обработке за раз").SetDefault(PackSize);
    TSet<TString> tagNames;
    // DoExecute treats GetDriveAPI() as nullable (Yensured), so check it here too.
    if (auto impl = server.GetAs<NDrive::IServer>(); impl && impl->GetDriveAPI()) {
        tagNames = impl->GetDriveAPI()->GetTagsManager().GetTagsMeta().GetRegisteredTagNames({ TSupportAITag::TypeName });
    }
    scheme.Add<TFSVariants>("tags", "Обрабатываемые теги").SetVariants(tagNames).SetMultiSelect(true);
    return scheme;
}

/// Serializes the watcher settings on top of the base process fields.
/// "tags" is written via InsertNonNull; "pack_size" is always written.
NJson::TJsonValue TRTSupportAIProcess::DoSerializeToJson() const {
    NJson::TJsonValue serialized = TBase::DoSerializeToJson();
    NJson::InsertField(serialized, "pack_size", PackSize);
    NJson::InsertNonNull(serialized, "tags", Tags);
    return serialized;
}

/// Restores the watcher settings: base fields first, then "tags", then "pack_size".
/// Stops at the first field that fails to parse and returns false.
bool TRTSupportAIProcess::DoDeserializeFromJson(const NJson::TJsonValue& json) {
    if (!TBase::DoDeserializeFromJson(json)) {
        return false;
    }
    if (!NJson::ParseField(json["tags"], Tags)) {
        return false;
    }
    return NJson::ParseField(json["pack_size"], PackSize);
}

/// Creates Initialized TSupportAICall records for calls referenced by removed tags.
///
/// Reads up to PackSize "Remove" events of the configured tags (or of every
/// TSupportAITag-type tag when Tags is empty) following the last event id stored
/// in `state`. For every event whose tag description has SaveResult set, the call
/// ids listed under "calls" in the tag snapshot are added to the call manager,
/// except ids that already exist there or were already queued in this run.
/// Returns nullptr when the support center manager is missing, an error string on
/// database failures, and otherwise a new state holding the last event id seen.
TExpectedState TRTSupportAIProcess::DoExecute(TAtomicSharedPtr<IRTBackgroundProcessState> state, const TExecutionContext& context) const {
    const auto& server = context.GetServerAs<NDrive::IServer>();
    if (!server.GetSupportCenterManager()) {
        ERROR_LOG << "Incorrect manager for support_ai" << Endl;
        return nullptr;
    }

    // Resume right after the last event handled by the previous run.
    ui64 lastEventId = 0;
    if (const auto currentState = dynamic_cast<const TRTHistoryWatcherState*>(state.Get())) {
        lastEventId = currentState->GetLastEventId();
    }

    const auto& tagsManager = Yensured(server.GetDriveAPI())->GetTagsManager();
    const auto& userTagManager = tagsManager.GetUserTags();
    TOptionalTagHistoryEvents optionalEvents;
    {
        auto readTx = userTagManager.BuildSession(/* readOnly = */ true);
        IEntityTagsManager::TQueryOptions options(PackSize);
        options.SetActions({ EObjectHistoryAction::Remove });
        if (Tags.empty()) {
            options.SetTags(tagsManager.GetTagsMeta().GetRegisteredTagNames({ TSupportAITag::TypeName }));
        } else {
            options.SetTags(Tags);
        }
        optionalEvents = userTagManager.GetEvents(lastEventId + 1, Since, readTx, options);
        if (!optionalEvents) {
            return MakeUnexpected<TString>("cannot restore user tags: " + readTx.GetStringReport());
        }
    }

    // Ids that must not be added: those already stored, plus (below) those queued in this run.
    TSet<TString> skipCalls;
    const auto& callsManager = server.GetSupportCenterManager()->GetSupportAICallManager();
    {
        TSet<TString> selectCallIds;
        for (auto&& event : *optionalEvents) {
            TSet<TString> callIds;
            if (auto tagSnapshot = dynamic_cast<TTagSnapshot*>(event->GetObjectSnapshot().Get()); tagSnapshot
                && NJson::ParseField(tagSnapshot->GetMeta(), "calls", callIds))
            {
                selectCallIds.insert(callIds.begin(), callIds.end());
            }
        }
        if (!selectCallIds.empty()) {
            auto callTx = callsManager.GetHistoryManager().BuildTx<NSQL::ETransactionTraits::ReadOnly>();
            if (auto calls = callsManager.GetObjects(callTx, NSQL::TQueryOptions().SetGenericCondition("external_id", selectCallIds))) {
                Transform(calls->begin(), calls->end(), std::inserter(skipCalls, skipCalls.begin()), [](const auto& call) { return call.GetExternalId(); });
            } else {
                return MakeUnexpected("Fail to get calls: " + callTx.GetStringReport());
            }
        }
    }

    auto descriptions = tagsManager.GetTagsMeta().GetRegisteredTags(NEntityTagsManager::EEntityType::User, { TSupportAITag::TypeName });
    TVector<TSupportAICall> calls;
    for (auto&& event : *optionalEvents) {
        // Every event counts as handled, even if it produces no calls.
        lastEventId = event.GetHistoryEventId();
        auto descriptionsPtr = descriptions.FindPtr(event->GetName());
        if (!descriptionsPtr || !descriptionsPtr->Get()) {
            continue;
        }
        if (auto description = descriptionsPtr->Get()->GetAs<TSupportAITag::TDescription>(); description && description->GetSaveResult()) {
            auto snapshot = event->GetObjectSnapshot();
            if (!snapshot) {
                NDrive::TEventLog::Log("SupportAISkip", NJson::TMapBuilder
                    ("error", "fail to tag snapshot")
                    ("tag_id", event.GetTagId())
                );
                continue;
            }
            if (auto tagSnapshot = dynamic_cast<TTagSnapshot*>(snapshot.Get())) {
                TSet<TString> callIds;
                if (!NJson::ParseField(tagSnapshot->GetMeta(), "calls", callIds)) {
                    NDrive::TEventLog::Log("SupportAISkip", NJson::TMapBuilder
                        ("error", "fail to parse tag snapshot")
                        ("tag_id", event.GetTagId())
                    );
                    continue;
                }
                for (auto&& callId : callIds) {
                    // insert() fails for ids already stored or already queued from an earlier
                    // event of this pack, so each call is added exactly once.
                    if (!skipCalls.insert(callId).second) {
                        continue;
                    }
                    TSupportAICall result;
                    result
                        .SetUserId(event.GetObjectId())
                        .SetStatus(TSupportAICall::EStatus::Initialized)
                        .SetStartTS(event.GetHistoryTimestamp())
                        .SetExternalId(callId);
                    calls.emplace_back(std::move(result));
                }
            }
        }
    }

    auto callTx = callsManager.GetHistoryManager().BuildTx<NSQL::Writable | NSQL::Deferred>();
    if (!callsManager.AddObjects(calls, GetRobotUserId(), callTx)) {
        return MakeUnexpected("Fail to add call: " + callTx.GetStringReport());
    }
    if (!callTx.Commit()) {
        return MakeUnexpected("Fail to commit calls: " + callTx.GetStringReport());
    }

    auto result = MakeAtomicShared<TRTSupportAIProcessState>();
    result->SetLastEventId(lastEventId);
    return result;
}


// Registers TRTSupportAISyncProcess in its TFactory.
TRTSupportAISyncProcess::TFactory::TRegistrator<TRTSupportAISyncProcess> TRTSupportAISyncProcess::Registrator;

// Factory name of the process that loads results of Initialized support AI calls.
TString TRTSupportAISyncProcess::GetTypeName() {
    return "support_ai_call_loader";
}

TString TRTSupportAISyncProcess::GetType() const {
    return GetTypeName();
}

// Describes the loader settings in the admin scheme; field order here is the order
// of the fields in the scheme.
NDrive::TScheme TRTSupportAISyncProcess::DoGetScheme(const IServerBase& server) const {
    NDrive::TScheme scheme = TBase::DoGetScheme(server);
    // Maximum number of calls selected for processing per run.
    scheme.Add<TFSNumeric>("max_calls_count", "Максимальное количество звонков в обработке за раз").SetDefault(MaxCallsCount);
    // Optional time range applied to the call start timestamp.
    scheme.Add<TFSStructure>("event_time_filter", "Диапазон времени для обработки").SetStructure(TDateTimeFilterConfig::GetScheme());
    // Target MDS bucket for call info and tracks (required).
    scheme.Add<TFSString>("bucket_name", "В какой бакет лить").SetRequired(true);
    // Number of calls loaded per ProcessCalls batch.
    scheme.Add<TFSNumeric>("max_calls_in_request", "Загружать за раз").SetDefault(MaxCallsInRequest);
    return scheme;
}

/// Serializes the loader settings on top of the base process fields.
/// "event_time_filter" is written via InsertNonNull; the other fields are always written.
NJson::TJsonValue TRTSupportAISyncProcess::DoSerializeToJson() const {
    NJson::TJsonValue serialized = TBase::DoSerializeToJson();
    NJson::InsertField(serialized, "max_calls_count", MaxCallsCount);
    NJson::InsertNonNull(serialized, "event_time_filter", TimeFilterConfig);
    NJson::InsertField(serialized, "max_calls_in_request", MaxCallsInRequest);
    NJson::InsertField(serialized, "bucket_name", BucketName);
    return serialized;
}

/// Restores the loader settings in a fixed order and stops at the first field that
/// fails to parse; "bucket_name" is the only field parsed as required.
bool TRTSupportAISyncProcess::DoDeserializeFromJson(const NJson::TJsonValue& jsonInfo) {
    if (!TBase::DoDeserializeFromJson(jsonInfo)) {
        return false;
    }
    if (!NJson::ParseField(jsonInfo["max_calls_count"], MaxCallsCount)) {
        return false;
    }
    if (!NJson::ParseField(jsonInfo["max_calls_in_request"], MaxCallsInRequest)) {
        return false;
    }
    if (!NJson::ParseField(jsonInfo["event_time_filter"], TimeFilterConfig)) {
        return false;
    }
    return NJson::ParseField(jsonInfo["bucket_name"], BucketName, /* required = */ true);
}

/// Loads results of Initialized support AI calls and stores them.
///
/// Selects up to MaxCallsCount calls in the Initialized status (optionally limited
/// by TimeFilterConfig on "start_ts"), processes them in batches of at most
/// MaxCallsInRequest ids via ProcessCalls (which uploads call info and tracks to
/// the BucketName bucket), and upserts every successfully processed call, keeping
/// the stored row id and user id. Failed calls are only logged; they are not
/// upserted, so they remain Initialized and are selected again by later runs.
/// Returns nullptr when the manager, client, MDS client or bucket is missing, and
/// an error string on database failures.
TExpectedState TRTSupportAISyncProcess::DoExecute(TAtomicSharedPtr<IRTBackgroundProcessState> /* state */, const TExecutionContext& context) const {
    const auto& server = context.GetServerAs<NDrive::IServer>();
    if (!server.GetSupportCenterManager()) {
        ERROR_LOG << "Incorrect manager for support_ai" << Endl;
        return nullptr;
    }
    const auto& callsManager = server.GetSupportCenterManager()->GetSupportAICallManager();
    const auto client = server.GetSupportAIClient();
    if (!client) {
        ERROR_LOG << "Incorrect client for support_ai" << Endl;
        return nullptr;
    }
    if (!server.GetSupportCenterManager()->GetMDSClient()) {
        ERROR_LOG << GetRobotId() << ": Support mds client is unknown" << Endl;
        return nullptr;
    }
    auto bucket = server.GetSupportCenterManager()->GetMDSClient()->GetBucket(BucketName);
    if (!bucket) {
        ERROR_LOG << GetRobotId() << ": Support mds bucket is unknown " << BucketName << Endl;
        return nullptr;
    }

    // Initialized calls keyed by external id.
    TMap<TString, TSupportAICall> callIds;
    {
        NSQL::TQueryOptions options(MaxCallsCount);
        options.AddGenericCondition("status", ToString(TSupportAICall::EStatus::Initialized));
        if (TimeFilterConfig) {
            TRange<ui64> timeParam;
            auto filter = TimeFilterConfig.GetFilter();
            if (filter.From) {
                timeParam.From = filter.From->Seconds();
            }
            if (filter.To) {
                timeParam.To = filter.To->Seconds();
            }
            // NOTE(review): confirm SetGenericCondition keeps the "status" condition added above.
            options.SetGenericCondition("start_ts", timeParam);
        }
        auto callTx = callsManager.GetHistoryManager().BuildTx<NSQL::ETransactionTraits::ReadOnly>();
        auto calls = callsManager.GetObjects(callTx, options);
        if (!calls) {
            return MakeUnexpected("Fail to get calls: " + callTx.GetStringReport());
        }
        Transform(calls->begin(), calls->end(), std::inserter(callIds, callIds.begin()), [](auto& call) {
            // Copy the id before moving the call into the pair.
            TString id = call.GetExternalId();
            return std::make_pair(id, std::move(call));
        });
    }

    TVector<TSupportAICall> succeedCalls;
    {
        TSet<TString> externalIds;
        auto processCalls = [&]() {
            auto results = ProcessCalls(*client, *bucket, externalIds);
            for (auto&& result : results) {
                if (result) {
                    NDrive::TEventLog::Log("SupportAISuccess", NJson::TMapBuilder
                        ("call_id", result.GetValue().GetExternalId())
                    );
                    succeedCalls.emplace_back(std::move(result.GetValue()));
                } else {
                    NDrive::TEventLog::Log("SupportAIError", result.GetError());
                }
            }
        };
        for (auto&& [externalId, _] : callIds) {
            externalIds.insert(externalId);
            // Flush once the batch is full. Flushing after the insert never sends an
            // empty batch (a limit of 0 behaves like 1).
            if (externalIds.size() >= MaxCallsInRequest) {
                processCalls();
                externalIds.clear();
            }
        }
        if (!externalIds.empty()) {
            processCalls();
        }
    }

    auto callTx = callsManager.GetHistoryManager().BuildTx<NSQL::Writable | NSQL::Deferred>();
    for (auto&& succeedCall : succeedCalls) {
        auto call = callIds.FindPtr(succeedCall.GetExternalId());
        if (!call) {
            NDrive::TEventLog::Log("SupportAIError", NJson::TMapBuilder
                ("call_id", succeedCall.GetExternalId())
                ("error_type", "process_call_error")
                ("error_data", "unknown_call")
            );
            // Without the stored record there is no row id / user id to upsert with.
            continue;
        }
        succeedCall.SetId(call->GetId());
        succeedCall.SetUserId(call->GetUserId());
        if (!callsManager.UpsertObject(succeedCall, GetRobotUserId(), callTx)) {
            return MakeUnexpected("Fail to add call info: " + callTx.GetStringReport());
        }
    }
    if (!callTx.Commit()) {
        return MakeUnexpected("Fail to add calls info: " + callTx.GetStringReport());
    }

    return new IRTBackgroundProcessState();
}
