#include "persqueue_context.h"

#include <saas/indexerproxy/dispatching/dispatch_queue.h>

#include <library/cpp/logger/global/global.h>

#include <util/generic/guid.h>
#include <util/string/join.h>

/// Callback handed to TPersQueueManager::DoSend: converts the outcome of a
/// PersQueue write into a TSendStatus registered on the replier.
///
/// The handler does not own the replier; whoever creates the handler must keep
/// the replier alive for as long as the handler is used.
class THandlerOnSend {
private:
    /// Reported as the backend name of every registered status.
    const TString Topic;
    /// Non-owning target for statuses. The pointer itself never changes, so it is
    /// `const`; the pointee stays mutable, which is all the const methods need
    /// (the former `mutable` qualifier was unnecessary).
    ISrvDispContext* const Replier;

private:
    /// Wraps @p message (serialized with GetStringRobust) into a PersQueue TSendStatus.
    void RegisterStatus(const ESendStatus status, const NJson::TJsonValue& message) const {
        Replier->RegisterStatus(TSendStatus(ESource::PersQueue, Topic, status, false, message.GetStringRobust()));
    }

public:
    THandlerOnSend(const TString& topic, ISrvDispContext* replier)
        : Topic(topic)
        , Replier(replier)
    {
    }

    /// Registers SS_OK with @p message as the status details.
    void OnSuccess(const NJson::TJsonValue& message) const {
        RegisterStatus(ESendStatus::SS_OK, message);
    }

    /// Registers SS_SEND_FAILED with @p message as the status details.
    void OnError(const NJson::TJsonValue& message) const {
        RegisterStatus(ESendStatus::SS_SEND_FAILED, message);
    }
};

// Configures a fixed number of topics: shard i gets the interval returned by
// TSharding::GetInterval(i, shardsCount), and shards are later named by index.
void TPersQueueManager::TShardsInfo::Init(ui32 shardsCount, NSaas::TShardsDispatcher::TPtr shardsDispatcher) {
    NumeratedShardsNames = true;
    NSaas::TSharding intervalBuilder;
    for (ui32 index = 0; index != shardsCount; ++index) {
        Shards.push_back(intervalBuilder.GetInterval(index, shardsCount));
    }
    ShardsDispatcher = shardsDispatcher;
}

// Returns a name for every configured shard, in Shards order: the shard's index
// when NumeratedShardsNames is set, otherwise ToString() of the shard interval.
TVector<TString> TPersQueueManager::TShardsInfo::GetShardNames() const {
    TVector<TString> result;
    result.reserve(Shards.size());
    ui32 index = 0;
    for (auto&& shard : Shards) {
        result.push_back(NumeratedShardsNames ? ToString(index) : ToString(shard));
        ++index;
    }
    return result;
}

// Returns the names of the shards GetShards() selects for @p message, using the
// same naming scheme as the parameterless overload.
TVector<TString> TPersQueueManager::TShardsInfo::GetShardNames(const NRTYServer::TMessage& message) const {
    auto selectedShards = GetShards(message);
    TVector<TString> result;

    if (!NumeratedShardsNames) {
        for (auto&& shard : selectedShards) {
            result.push_back(ToString(shard));
        }
        return result;
    }

    // Numerated mode: translate each selected interval back to its index in Shards.
    // NOTE(review): an interval that is not present in Shards is skipped silently.
    for (auto&& shard : selectedShards) {
        for (ui32 index = 0; index < Shards.size(); ++index) {
            if (Shards[index] == shard) {
                result.push_back(ToString(index));
                break;
            }
        }
    }
    return result;
}

// Builds the PersQueue topic for a shard as "<Ident>--shard-<shard>".
// The double dash is what the historical Join("-", Ident, "-shard", shard) call
// produced; the result is used as the producer's topic, so changing it would
// rename the topics written to.
TString TPersQueueManager::TServiceWriter::GetTopicName(TStringBuf shard) const {
    return TString::Join(Ident, "--shard-", shard);
}

// Returns a fresh producer source id: "<SessionPrefix>_<random GUID>".
TString TPersQueueManager::TServiceWriter::CreateSourceId() const {
    return Join("_", SessionPrefix, CreateGuidAsString());
}

// Starts a producer for the shard's topic and stores a TWriter for it.
// On any failure (exception, start timeout, error response) the problem is logged
// and WritersPerShard is left untouched, so a later call can retry.
void TPersQueueManager::TServiceWriter::CreateWriter(TStringBuf shard) {
    NPersQueue::TProducerSettings settings;
    settings.Server = NPersQueue::TServerSetting{ Server };
    settings.Topic = GetTopicName(shard);
    settings.SourceId = CreateSourceId();
    settings.ReconnectOnFailure = true;
    settings.StartSessionTimeout = TDuration::Max();
    if (CredentialsProvider) {
        settings.CredentialsProvider = CredentialsProvider;
    }

    THolder<NPersQueue::IProducer> producer;
    NPersQueue::TProducerSeqNo maxSeqNo;
    try {
        producer = PQLib->CreateProducer(settings, Logger, true);
        auto startFuture = producer->Start();
        if (!startFuture.Wait(ConnectionTimeout)) {
            ythrow yexception() << "connection timeout";
        }
        const auto& startResponse = startFuture.GetValue().Response;
        if (startResponse.HasError()) {
            ythrow yexception() << startResponse;
        }
        // The writer continues numbering after the last sequence number the server saw.
        maxSeqNo = startResponse.GetInit().GetMaxSeqNo();
    } catch (...) {
        ERROR_LOG << "Cannot create producer for"
            << " topic=" << settings.Topic
            << " server=" << settings.Server.Address
            << " sourceId=" << settings.SourceId
            << " error: " << CurrentExceptionMessage() << Endl;
        return;
    }
    WritersPerShard[shard].Reset(new TWriter(producer, maxSeqNo));
}

// Returns the shard's writer, creating it first if it is missing. May return an
// empty pointer when creation fails. Creation is serialized by Lock.
TPersQueueManager::TServiceWriter::TWriterPtr TPersQueueManager::TServiceWriter::GetWriter(TStringBuf shard) {
    TGuard<TMutex> guard(Lock);
    auto& writer = WritersPerShard[shard];
    if (!writer) {
        // CreateWriter resets this same map entry; the key already exists, so the reference stays valid.
        CreateWriter(shard);
    }
    return writer;
}

// Increments the shard's in-flight counter and reports it to the service signals.
void TPersQueueManager::TServiceWriter::IncrementInFlight(TStringBuf shard) {
    auto& inFlight = InFlightPerShard[shard];
    inFlight.Inc();
    Signals->UpdateInFlight(shard, inFlight);
}

// Decrements the shard's in-flight counter and reports it to the service signals.
void TPersQueueManager::TServiceWriter::DecrementInFlight(TStringBuf shard) {
    auto& inFlight = InFlightPerShard[shard];
    inFlight.Dec();
    Signals->UpdateInFlight(shard, inFlight);
}

// Performs one write attempt of @p data to @p shard.
// Returns a JSON status with "shard", "result" (bool) and, on failure, "error".
// All exceptions from the write are caught and recorded in "error".
NJson::TJsonValue TPersQueueManager::TServiceWriter::DoSend(TStringBuf shard, const NPersQueue::TData& data) {
    NJson::TJsonValue attemptStatus;
    auto writer = GetWriter(shard);
    attemptStatus.InsertValue("shard", shard);

    bool succeeded = false;
    if (writer) {
        // The in-flight gauge covers exactly the time spent waiting on the write.
        IncrementInFlight(shard);
        try {
            auto writeFuture = writer->Write(data);
            if (!writeFuture.Wait(SendTimeout)) {
                ythrow yexception() << "send timeout";
            }
            const auto& writeResponse = writeFuture.GetValue().Response;
            if (writeResponse.HasError()) {
                ythrow yexception() << writeResponse;
            }
            succeeded = true;
        } catch (...) {
            attemptStatus.InsertValue("error", CurrentExceptionMessage());
        }
        DecrementInFlight(shard);
    } else {
        attemptStatus.InsertValue("error", "Cannot get writer");
    }

    attemptStatus.InsertValue("result", succeeded);
    return attemptStatus;
}

/// Creates the writer for one PersQueue-indexed service.
///
/// Shards come from `serviceConfig->NumTopics` when it is non-zero, and from the
/// search map otherwise. Initialization of the shards is verified (aborting on
/// failure). A producer is then started eagerly for every shard; a shard whose
/// producer fails to start keeps an empty writer and is retried lazily by GetWriter.
///
/// @param serviceConfig  must be non-null (dereferenced unconditionally).
TPersQueueManager::TServiceWriter::TServiceWriter(
    TPQLibPtr pqLib,
    std::shared_ptr<NPersQueue::ICredentialsProvider> credentialsProvider,
    const TString& service,
    const NSearchMapParser::TServiceSpecificOptions* serviceConfig,
    const TProxyServiceConfig& writeConfig,
    const NSearchMapParser::TSearchMap& searchMap,
    TIntrusivePtr<NPersQueue::ILogger> logger
)
    : PQLib(pqLib)
    , CredentialsProvider(credentialsProvider)
    , Service(service)
    , Server(serviceConfig->GetServers())
    , Ident(serviceConfig->Stream)
    , SessionPrefix(writeConfig.SessionPrefix)
    , ConnectionTimeout(writeConfig.ConnectionTimeoutDuration)
    , SendTimeout(writeConfig.InteractionTimeoutDuration)
    , SendAttemptsCount(writeConfig.SendAttemptsCount)
    , Logger(logger)
{
    if (serviceConfig->NumTopics) {
        ShardsInfo.Init(serviceConfig->NumTopics, serviceConfig->ShardsDispatcher);
    } else {
        ShardsInfo.Init(searchMap, service);
    }
    VERIFY_WITH_LOG(ShardsInfo.Initilized(), "Shards are not initialized for service: %s", service.data());

    // Hold the names by value; GetShardNames() returns a fresh vector.
    const TVector<TString> shards = ShardsInfo.GetShardNames();

    Signals.Reset(new TPersqueueServiceSignals(service, shards));

    // Pre-populate both maps with every shard so later lookups never insert new keys.
    for (auto&& shard : shards) {
        WritersPerShard[shard] = nullptr;
        InFlightPerShard[shard] = 0;
        CreateWriter(shard);
    }
}

// Writes @p data to @p shard, making at most SendAttemptsCount attempts and stopping
// at the first success. Every attempt's status is appended to @p attempts. The total
// time and outcome are reported to Signals. Returns true if some attempt succeeded.
bool TPersQueueManager::TServiceWriter::DoSend(const NPersQueue::TData& data, TStringBuf shard, NJson::TJsonValue& attempts) {
    const auto startedAt = Now();
    bool written = false;
    for (ui32 attempt = 0; !written && attempt < SendAttemptsCount; ++attempt) {
        auto attemptStatus = DoSend(shard, data);
        attempts.AppendValue(attemptStatus);
        written = attemptStatus["result"].GetBoolean();
    }
    const ui64 elapsedMs = (Now() - startedAt).MilliSeconds();
    Signals->Processed(written, shard, elapsedMs);
    return written;
}

/// Serializes @p message once and writes it to every shard selected for it.
/// The aggregated outcome is reported through @p handler.
///
/// Exactly one of handler->OnSuccess / handler->OnError is called. The reported
/// JSON has these keys:
/// - "result": overall outcome;
/// - "shards_count": number of selected shards;
/// - "attempts": per-attempt statuses of the shards tried;
/// - "process_time": milliseconds spent writing;
/// - "error": set when the message could not be sent at all.
///
/// Shards are written in order. Processing stops at the first shard whose write
/// still fails after all retries, and the remaining shards are not attempted.
void TPersQueueManager::TServiceWriter::DoSend(const NRTYServer::TMessage& message, const THandlerOnSend* handler) {
    const TVector<TString> shards = ShardsInfo.GetShardNames(message);

    // SerializeToString reports failure through its return value; never send
    // data it may have left incomplete.
    TString serializedMessage;
    const bool serialized = message.SerializeToString(&serializedMessage);

    NJson::TJsonValue attemptStatuses;
    NJson::TJsonValue status;

    const auto startProcessTime = Now();

    bool written = false;
    if (!serialized) {
        status.InsertValue("error", "Cannot serialize message");
    } else if (shards.empty()) {
        // With no selected shard nothing would be written; report that instead of
        // a vacuous success, which would silently drop the message.
        status.InsertValue("error", "No shards selected for message");
    } else {
        const NPersQueue::TData data(serializedMessage);
        written = true;
        for (auto&& shard : shards) {
            if (!DoSend(data, shard, attemptStatuses)) {
                written = false;
                break;
            }
        }
    }

    const ui64 processTime = (Now() - startProcessTime).MilliSeconds();
    status.InsertValue("result", written);
    status.InsertValue("shards_count", shards.size());
    status.InsertValue("attempts", attemptStatuses);
    status.InsertValue("process_time", processTime);

    if (written) {
        handler->OnSuccess(status);
    } else {
        handler->OnError(status);
    }
}

void TPersQueueManager::CreateServiceWriter(
    const TString& service,
    std::shared_ptr<NPersQueue::ICredentialsProvider> credentialsProvider,
    const TProxyConfig& globalConfig
) {
    TServiceWriterPtr writer(new TServiceWriter(
        PQLib,
        credentialsProvider,
        service,
        globalConfig.GetServiceInfo(service),
        globalConfig.GetServicesConfig().GetConfig(service),
        globalConfig.GetSearchMap(),
        Logger)
    );
    Services[service] = writer;
}

void TPersQueueManager::CreateWriters(const TList<TString> services, const TProxyConfig& config) {
    TList<std::pair<TString, NSaas::TTvmSettings>> servicesWithTvm;
    THashMap<NTvmAuth::TTvmId, NSaas::TTvmSettings> tvmSettingsByClient;
    for (auto&& service : services) {
        const auto tvmConfig = config.GetServicesConfig().GetConfig(service).TvmConfig;
        if (tvmConfig) {
            auto tvmSettings = tvmConfig->GetSettings();
            servicesWithTvm.emplace_back(service, tvmSettings);
            auto it = tvmSettingsByClient.find(tvmSettings.ClientId);
            if (it != tvmSettingsByClient.end()) {
                it->second.Merge(tvmSettings);
            } else {
                tvmSettingsByClient[tvmSettings.ClientId] = tvmSettings;
            }
        } else {
            WARNING_LOG << "Service without authorization in PQ: " << service << Endl;
            CreateServiceWriter(service, nullptr, config);
        }
    }

    THashMap<NTvmAuth::TTvmId, THashMap<TString, std::shared_ptr<NPersQueue::ICredentialsProvider>>> credentialsProviders;
    for (auto&& tvmInfo : tvmSettingsByClient) {
        auto& tvmSettings = tvmInfo.second;
        auto tvmClient = CreateTvmClient(tvmSettings);

        THashMap<TString, std::shared_ptr<NPersQueue::ICredentialsProvider>> credentialsProviderByAlias;
        for (auto dstClient : tvmSettings.DestinationClients) {
            credentialsProviderByAlias[dstClient.first] = CreateTVMCredentialsProvider(tvmClient, Logger, dstClient.first);
        }
        credentialsProviders[tvmSettings.ClientId] = credentialsProviderByAlias;
    }

    for (const auto& serviceInfo : servicesWithTvm) {
        auto& service = serviceInfo.first;
        auto& tvmSettings = serviceInfo.second;

        auto credentialsProvider = credentialsProviders[tvmSettings.ClientId][tvmSettings.DestinationClients.begin()->first];
        CreateServiceWriter(service, credentialsProvider, config);
    }
}

// Collects the services whose indexing target is PersQueue and creates writers for them.
// The PQ library and its logger are instantiated only when at least one such service exists.
void TPersQueueManager::Init(const TProxyConfig& config) {
    TList<TString> pqServices;
    for (const auto& serviceEntry : config.GetServiceMap()) {
        const auto* serviceInfo = config.GetServiceInfo(serviceEntry.first);
        const bool usesPersQueue = serviceInfo && serviceInfo->IndexingTarget == NSearchMapParser::PersQueue;
        if (usesPersQueue) {
            pqServices.push_back(serviceEntry.first);
        }
    }

    const bool needsPersQueue = !pqServices.empty();
    if (needsPersQueue) {
        PQLib.Reset(new NPersQueue::TPQLib());
        Logger.Reset(new NSaas::TPersQueueLogger<NPersQueue::ILogger>(config.GetDispConfig().GetLog()));
    }

    CreateWriters(pqServices, config);
}

// Routes @p message to the writer registered for @p serviceName.
// An unknown service is logged and reported through handler->OnError.
void TPersQueueManager::DoSend(const TString& serviceName, const NRTYServer::TMessage& message, const THandlerOnSend* handler) {
    auto writerIt = Services.find(serviceName);
    if (writerIt != Services.end()) {
        writerIt->second->DoSend(message, handler);
        return;
    }
    ERROR_LOG << "No service '" << serviceName << "' in PersQueueManager" << Endl;
    handler->OnError("Incorrect stream config");
}

// Replier for PersQueue-indexed services; all state is handled by the TReplier base.
TPersQueueReplier::TPersQueueReplier(
    const TString& serviceName,
    ISenderTask& task,
    const NRTYServer::TMessage& message,
    const TDispatcherConfig* config,
    TDeferredMessagesQueue& storage
)
    : TReplier(serviceName, task, message, config, storage)
{
}

bool TPersQueueReplier::DoStoreStatus(const TString& backend, const TString& status) {
    VERIFY_WITH_LOG(GetServiceName() == backend, "%s != %s", GetServiceName().data(), backend.data());
    if (!!status)
        ReplyOverride = NRTYServer::TReply::dsCANT_STORE_IN_QUEUE;
    return !status;
}

// Completes a synchronous request (asserted). If no reply status was obtained,
// the message is handed to StoreInQueue.
void TPersQueueReplier::DoFinish() {
    CHECK_WITH_LOG(!IsAsync);
    if (GetDispStatus() != NRTYServer::TReply::dsNO_REPLY) {
        return;
    }
    DEBUG_LOG << "PersQueue's message storing in queue... " << GetServiceName() << "/" << Message.GetDocument().GetUrl() << Endl;
    StoreInQueue(GetServiceName(), Message);
    DEBUG_LOG << "PersQueue's message stored in queue... OK " << GetServiceName() << "/" << Message.GetDocument().GetUrl() << Endl;
}

// Maps the first registered send status to a reply code:
// - SS_OK becomes dsOK;
// - SS_SEND_FAILED becomes dsINTERNAL_ERROR;
// - any other status becomes dsUNDEFINED_REPLY_CODE;
// - no defined status yet becomes dsNO_REPLY.
NRTYServer::TReply::TDispStatus TPersQueueReplier::DoGetDispStatus() const {
    if (Status.empty() || !Status.begin()->second.Defined()) {
        return NRTYServer::TReply::dsNO_REPLY;
    }
    const TSendStatus& sendStatus = Status.begin()->second.GetRef();
    switch (sendStatus.GetDispStatus()) {
    case ESendStatus::SS_OK:
        return NRTYServer::TReply::dsOK;
    case ESendStatus::SS_SEND_FAILED:
        return NRTYServer::TReply::dsINTERNAL_ERROR;
    default:
        return NRTYServer::TReply::dsUNDEFINED_REPLY_CODE;
    }
}

// Rejects broadcast and async messages with a user error; everything else is accepted.
bool TPersQueueReplier::DoVerify(const TProxyConfig& /*config*/, const TRTYMessageBehaviour& bh) const {
    const bool supported = !bh.IsBroadcastMessage && !IsAsyncMessage(GetMessage());
    if (supported) {
        return true;
    }
    ForceMessageStatus(NRTYServer::TReply::dsUSER_ERROR, "broadcast and async messages are not supported in PersQueue");
    return false;
}

bool TPersQueueReplier::DoSend(const TProxyConfig& config, TDispatchServiceSelector& selector) {
    const auto* it = config.GetServiceInfo(GetServiceName());
    VERIFY_WITH_LOG(it, "incorrect service %s", GetServiceName().data());
    CHECK_WITH_LOG(it->IndexingTarget == NSearchMapParser::PersQueue);
    VERIFY_WITH_LOG(!!it->Stream, "empty ident");
    VERIFY_WITH_LOG(!!it->GetServers(), "incorrect server");

    SetBackendCounter(1);

    const THandlerOnSend Handler(it->Stream, this);

    selector.GetPersQueueManager().DoSend(GetServiceName(), GetMessage(), &Handler);
    return true;
}
