#include "refunds.h"

#include <drive/backend/rt_background/manager/state.h>

#include <drive/backend/billing/manager.h>
#include <drive/backend/data/billing_tags.h>
#include <drive/backend/database/drive_api.h>
#include <drive/backend/tags/tags_manager.h>

// Static registrations of both watchers in IRTRegularBackgroundProcess::TFactory, keyed by
// each class's GetTypeName(). Presumably this lets the factory construct them by type name
// from configuration — verify against the factory implementation.
IRTRegularBackgroundProcess::TFactory::TRegistrator<TRTRefundsWatcher> TRTRefundsWatcher::Registrator(TRTRefundsWatcher::GetTypeName());
IRTRegularBackgroundProcess::TFactory::TRegistrator<TDeferredRefundsWatcher> TDeferredRefundsWatcher::Registrator(TDeferredRefundsWatcher::GetTypeName());

/// Reconciles user TRefundTag tags with the refund tasks stored in the billing refunds DB.
///
/// For every user tag whose registered description has type TRefundTag::TypeName:
///  - a tag with an empty operation list is removed;
///  - otherwise, only operations whose payment's latest (highest-id) refund task is still
///    unfinished are kept (operations for payments with no refund task are dropped too),
///    and the tag data is written back if the list shrank.
/// Each tag is handled in its own write session, which is only committed when there is a change.
///
/// Returns an error if the refund tags cannot be loaded or if writing back a change fails.
/// A failure to read the refund tasks for one tag is logged and only that tag is skipped.
TExpectedState TRTRefundsWatcher::DoExecute(TAtomicSharedPtr<IRTBackgroundProcessState> /*stateExt*/, const TExecutionContext& context) const {
    const NDrive::IServer& server = context.GetServerAs<NDrive::IServer>();
    const TBillingManager& billingManager = server.GetDriveAPI()->GetBillingManager();
    const TUserTagsManager& userTagsManager = server.GetDriveAPI()->GetTagsManager().GetUserTags();

    // Names of all registered tag descriptions whose type is TRefundTag.
    auto registeredTags = server.GetDriveAPI()->GetTagsManager().GetTagsMeta().GetRegisteredTags();
    TVector<TString> refundTagNames;
    for (auto&& tag : registeredTags) {
        if (tag.second->GetType() == TRefundTag::TypeName) {
            refundTagNames.push_back(tag.second->GetName());
        }
    }

    // Load every user tag carrying one of those names (read-only session).
    TDBTags refundTags;
    {
        auto session = userTagsManager.BuildSession(true);
        auto optionalTags = userTagsManager.RestoreTags(TVector<TString>{}, refundTagNames, session);
        if (!optionalTags) {
            return MakeUnexpected("cannot RestoreTags: " + session.GetStringReport());
        }
        refundTags = std::move(*optionalTags);
    }

    // Ascending id order, so that per payment the most recent task is applied last below.
    const auto byId = [](const TRefundTask& left, const TRefundTask& right) -> bool {
        return left.GetId() < right.GetId();
    };

    for (auto&& commonTag : refundTags) {
        if (registeredTags.find(commonTag->GetName()) == registeredTags.end()) {
            continue;
        }

        // The tag data is edited in place and then persisted via UpdateTagsData.
        TRefundTag* tagData = dynamic_cast<TRefundTag*>(commonTag.GetData().Get());
        if (!tagData) {
            continue;
        }

        auto session = userTagsManager.BuildSession(false);
        TVector<TDBTag> toRemove;
        TVector<TDBTag> toUpdate;
        if (tagData->GetOperations().empty()) {
            toRemove.push_back(commonTag);
        } else {
            TVector<TRefundTask> sessionRefunds;
            if (!billingManager.GetPaymentsManager().GetRefundsDB().GetSessionRefunds(tagData->GetSessionId(), sessionRefunds, session)) {
                // Skip only this tag: aborting the loop here would leave every remaining tag
                // unprocessed while the run still reports success.
                ERROR_LOG << "Cannot read refunds for " << tagData->GetSessionId() << ": " << session.GetStringReport() << Endl;
                continue;
            }
            Sort(sessionRefunds.begin(), sessionRefunds.end(), byId);

            // After this loop a payment id is present iff its highest-id task is not finished.
            TSet<TString> notFinishedRefunds;
            for (auto&& task : sessionRefunds) {
                if (!task.GetFinished()) {
                    notFinishedRefunds.insert(task.GetPaymentId());
                } else {
                    notFinishedRefunds.erase(task.GetPaymentId());
                }
            }

            TVector<TRefundTag::TOperation> updates;
            updates.reserve(tagData->GetOperations().size());
            for (auto&& refund : tagData->GetOperations()) {
                if (notFinishedRefunds.contains(refund.GetPaymentId())) {
                    updates.push_back(refund);
                }
            }
            if (updates.size() != tagData->GetOperations().size()) {
                tagData->SetOperations(std::move(updates));
                toUpdate.push_back(commonTag);
            }
        }

        if (toUpdate.empty() && toRemove.empty()) {
            // Nothing changed for this tag: avoid no-op writes and an empty commit.
            continue;
        }

        if (!userTagsManager.UpdateTagsData(toUpdate, GetRobotUserId(), session)
            || !userTagsManager.RemoveTagsSimple(toRemove, GetRobotUserId(), session, false)
            || !session.Commit())
        {
            ERROR_LOG << "Commit fails TRTRefundsWatcher::DoExecute: " << session.GetStringReport() << Endl;
            return MakeUnexpected("cannot update refund tags: " + session.GetStringReport());
        }
    }

    return MakeAtomicShared<IRTBackgroundProcessState>();
}

/// Returns the base process scheme as-is; this watcher adds nothing to it here.
NDrive::TScheme TRTRefundsWatcher::DoGetScheme(const IServerBase& server) const {
    NDrive::TScheme scheme = TBase::DoGetScheme(server);
    return scheme;
}

/// Delegates JSON deserialization entirely to the base process; returns its result.
bool TRTRefundsWatcher::DoDeserializeFromJson(const NJson::TJsonValue& jsonInfo) {
    const bool parsed = TBase::DoDeserializeFromJson(jsonInfo);
    return parsed;
}

/// Delegates JSON serialization entirely to the base process; no extra fields are added here.
NJson::TJsonValue TRTRefundsWatcher::DoSerializeToJson() const {
    NJson::TJsonValue serialized = TBase::DoSerializeToJson();
    return serialized;
}

/// Applies pending TDeferredSessionRefundTag trace tags.
///
/// Loads all trace tags of that type in a read-only session, then for each one calls
/// TDeferredSessionRefundTag::Exercise in its own session and commits it. Per-tag failures
/// are logged and skipped; the loop stops once `Limit` tags have been committed successfully.
/// Returns an error (carrying the session report) only if the initial tag load fails.
TExpectedState TDeferredRefundsWatcher::DoExecute(TAtomicSharedPtr<IRTBackgroundProcessState> /*stateExt*/, const TExecutionContext& context) const {
    const NDrive::IServer& server = context.GetServerAs<NDrive::IServer>();
    const TTraceTagsManager& traceTagsManager = server.GetDriveAPI()->GetTagsManager().GetTraceTags();

    TVector<TDBTag> tags;
    {
        auto session = traceTagsManager.BuildSession(/*readOnly=*/true);
        if (!traceTagsManager.RestoreTags({}, { TDeferredSessionRefundTag::Type() }, tags, session)) {
            ERROR_LOG << GetRobotId() << ": cannot get trace tags: " << session.GetStringReport() << Endl;
            // Propagate the reason instead of an empty error string.
            return MakeUnexpected("cannot RestoreTags: " + session.GetStringReport());
        }
    }

    ui64 processed = 0;
    for (auto&& tag : tags) {
        // Trace tags are keyed by the ride session id.
        const auto& sessionId = tag.GetObjectId();
        auto session = traceTagsManager.BuildSession();
        if (!TDeferredSessionRefundTag::Exercise(tag, GetRobotUserId(), server, session)) {
            ERROR_LOG << GetRobotId() << ": cannot execute DeferredSessionRefundTag for " << sessionId << ": " << session.GetStringReport() << Endl;
            continue;
        }
        if (!session.Commit()) {
            ERROR_LOG << GetRobotId() << ": cannot commit transaction from " << sessionId << ": " << session.GetStringReport() << Endl;
            continue;
        }
        // Only successful commits count towards Limit.
        ++processed;
        if (processed >= Limit) {
            break;
        }
    }

    return MakeAtomicShared<IRTBackgroundProcessState>();
}
