#include "consumer.h"

#include <crypta/lib/native/pqlib/credentials_provider.h>

#include <limits>
#include <utility>

using namespace NCrypta::NPQ;

namespace {
    /// Builds NPersQueue consumer settings from the consumer config.
    ///
    /// @param config               consumer configuration; must list at least one topic
    ///                             and a port that fits into 16 bits.
    /// @param credentialsProvider  stored into the settings as is.
    /// @return settings populated field-by-field from `config`.
    /// @throws yexception if no topics are configured or the port is out of range.
    NPersQueue::TConsumerSettings GetConsumerSettings(const TConsumerConfig& config, const std::shared_ptr<NPersQueue::ICredentialsProvider>& credentialsProvider) {
        Y_ENSURE(!config.GetTopics().empty(), "No topics were provided");

        // The port is narrowed to ui16 below; reject values that would otherwise be
        // silently truncated into a different (wrong) port.
        const auto port = config.GetPort();
        Y_ENSURE(port <= std::numeric_limits<ui16>::max(), "Port is out of range: " << port);

        NPersQueue::TConsumerSettings consumerSettings;
        consumerSettings.Server = NPersQueue::TServerSetting(config.GetServer(), static_cast<ui16>(port));
        consumerSettings.Topics.insert(consumerSettings.Topics.end(), config.GetTopics().begin(), config.GetTopics().end());
        consumerSettings.ClientId = config.GetClientId();
        consumerSettings.MaxInflyRequests = config.GetMaxInflyRequests();
        consumerSettings.MaxMemoryUsage = config.GetMaxMemoryUsage();
        consumerSettings.UseLockSession = config.GetUseLockSession();
        consumerSettings.ReadMirroredPartitions = config.GetReadMirroredPartitions();
        consumerSettings.MaxCount = config.GetMaxCount();
        consumerSettings.MaxSize = config.GetMaxSize();
        consumerSettings.Unpack = config.GetUnpack();
        consumerSettings.MaxUncommittedCount = config.GetMaxUncommittedCount();
        consumerSettings.CredentialsProvider = credentialsProvider;
        consumerSettings.MaxUncommittedSize = config.GetMaxUncommittedSize();

        // A positive SkipOlderThanSec sets ReadTimestampMs to "now minus that many seconds";
        // presumably the server then skips messages written before that moment.
        if (config.GetSkipOlderThanSec() > 0) {
            consumerSettings.ReadTimestampMs = (TInstant::Now() - TDuration::Seconds(config.GetSkipOlderThanSec())).MilliSeconds();
        }

        return consumerSettings;
    }
}

void TConsumer::TReadResult::ClearData() {
    TConsumer::TData tmp;
    tmp.Swap(&Data);
}

// Builds the consumer settings from `consumerConfig`, creates and starts the
// underlying PQ consumer (see CreateConsumer) and immediately requests the first
// message, so a pending read is in flight as soon as construction finishes.
//
// NOTE(review): CreateConsumer() and Consumer->GetNextMessage() are invoked from
// the initializer list, so they rely on PqLib, Logger, ConsumerSettings, Stats,
// CreateConsumerTimeout (and Epoch) being declared before Consumer and Future in
// consumer.h — confirm the member declaration order there.
//
// Propagates exceptions from GetConsumerSettings (e.g. no topics configured) and
// from CreateConsumer (TCreateConsumerTimeoutException if the consumer could not be
// started before the configured timeout).
TConsumer::TConsumer(
    TAtomicSharedPtr<NPersQueue::TPQLib> pqLib,
    TIntrusivePtr<NPersQueue::ILogger> logger,
    const TConsumerConfig& consumerConfig,
    const std::shared_ptr<NPersQueue::ICredentialsProvider>& credentialsProvider,
    TStats& stats)
    : PqLib(std::move(pqLib))
    , Logger(std::move(logger))
    , ConsumerSettings(GetConsumerSettings(consumerConfig, credentialsProvider))
    , Stats(stats)
    , CreateConsumerTimeout(TDuration::Seconds(consumerConfig.GetCreateConsumerTimeoutSec()))
    , Consumer(CreateConsumer())
    , Future(Consumer->GetNextMessage())
{
}

/// Creates and starts a new PQ consumer, retrying until CreateConsumerTimeout expires.
///
/// Epoch is incremented up front, so cookies obtained from any previously created
/// consumer no longer match the current epoch (see Commit).
///
/// @return a consumer whose Start() future resolved without an error.
/// @throws TCreateConsumerTimeoutException if no consumer could be started before the deadline.
THolder<NPersQueue::IConsumer> TConsumer::CreateConsumer() {
    // Pause between failed attempts: a start error reported immediately (e.g. an
    // unreachable server) would otherwise turn this loop into a busy retry that
    // floods both the server and the log until the timeout expires.
    const TDuration retryDelay = TDuration::MilliSeconds(500);

    ++Epoch;

    const TInstant deadline = CreateConsumerTimeout.ToDeadLine();

    while (deadline > TInstant::Now()) {
        Stats.Count->Add("create_consumer");

        auto newConsumer = PqLib->CreateConsumer(ConsumerSettings, Logger);
        auto startFuture = newConsumer->Start();

        if (!startFuture.Wait(deadline)) {
            break;
        }

        const auto& startResponse = startFuture.GetValue().Response;
        if (!startResponse.HasError()) {
            return newConsumer;
        }

        Logger->Log("Failed to start consumer: " + startResponse.GetError().GetDescription(), "", "", TLOG_ERR);

        // Back off before the next attempt, but never sleep past the deadline.
        const TInstant now = TInstant::Now();
        if (now >= deadline) {
            break;
        }
        const TDuration remaining = deadline - now;
        Sleep(remaining < retryDelay ? remaining : retryDelay);
    }

    ythrow TCreateConsumerTimeoutException() << "CreateConsumer timed out";
}

/// Waits for the next data message, handling service messages along the way.
///
/// Partition lock/release and commit messages are counted in Stats and skipped
/// (a lock is acknowledged via ReadyToRead). On an error message the error is
/// logged and the underlying consumer is recreated, which bumps Epoch.
///
/// @param deadline absolute time after which waiting stops.
/// @return the payloads and epoch-tagged cookie of the next data message, or
///         Nothing() if no data message arrived before `deadline`.
/// @throws TCreateConsumerTimeoutException if recreating the consumer times out.
/// @throws yexception on an unknown message type.
TMaybe<TConsumer::TReadResult> TConsumer::GetNextData(TInstant deadline) {
    do {
        if (!Future.Wait(deadline)) {
            return Nothing();
        }
        auto value = Future.ExtractValue();
        // Future is re-armed before `value` is handled, so it still refers to a
        // pending read even if the handling below throws.
        Future = Consumer->GetNextMessage();

        if (value.Type == NPersQueue::EMT_ERROR) {
            Stats.Count->Add("msg.error");
            // Log the reason before the failed consumer is replaced; otherwise it is lost.
            Logger->Log("Consumer error: " + value.Response.GetError().GetDescription(), "", "", TLOG_ERR);
            Consumer = CreateConsumer();
            Future = Consumer->GetNextMessage();
        } else if (value.Type == NPersQueue::EMT_LOCK) {
            Stats.Count->Add("msg.lock_partition");
            value.ReadyToRead.SetValue({});
        } else if (value.Type == NPersQueue::EMT_RELEASE) {
            Stats.Count->Add("msg.release_partition");
        } else if (value.Type == NPersQueue::EMT_DATA) {
            Stats.Count->Add("msg.data");

            auto& data = *value.Response.MutableData();

            TReadResult result;
            // Tag the cookie with the current epoch so Commit can drop cookies
            // that belong to a consumer which has since been recreated.
            result.EpochCookie.Cookie = data.GetCookie();
            result.EpochCookie.Epoch = Epoch;

            Stats.Count->Add("msg.data.total_batches", data.MessageBatchSize());
            Stats.Percentile->Add("msg.data.batches_per_message", data.MessageBatchSize());

            ui64 chunksPerMessage = 0;

            for (auto& batch : *data.MutableMessageBatch()) {
                auto chunks = batch.GetMessage().size();
                Stats.Count->Add("msg.data.total_chunks", chunks);
                Stats.Percentile->Add("msg.data.chunks_per_batch", chunks);

                chunksPerMessage += chunks;

                for (auto& chunk : *batch.MutableMessage()) {
                    auto chunkSize = chunk.GetData().size();
                    Stats.Percentile->Add("msg.data.chunk_size", chunkSize);
                    Stats.Count->Add("msg.data.total_uncompressed_bytes", chunkSize);

                    // Transfer ownership of the payload instead of copying it.
                    result.Data.AddAllocated(chunk.release_data());
                }
            }

            Stats.Percentile->Add("msg.data.chunks_per_message", chunksPerMessage);
            return result;
        } else if (value.Type == NPersQueue::EMT_COMMIT) {
            Stats.Count->Add("msg.commit");
        } else {
            ythrow yexception() << "Unknown msg type in consumer: " << static_cast<int>(value.Type);
        }
    } while (TInstant::Now() < deadline);

    return Nothing();
}

/// Relative-timeout overload: waits at most `timeout` from now for the next data message.
TMaybe<TConsumer::TReadResult> TConsumer::GetNextData(TDuration timeout) {
    const TInstant absoluteDeadline = timeout.ToDeadLine();
    return GetNextData(absoluteDeadline);
}

/// Commits read cookies on the current consumer.
///
/// Cookies tagged with an epoch other than the current one (i.e. obtained before
/// the consumer was last recreated) are dropped. Nothing is sent to the consumer
/// when no cookie belongs to the current epoch.
///
/// @param epochCookies cookies returned in TReadResult::EpochCookie.
void TConsumer::Commit(const TVector<TEpochCookie>& epochCookies) {
    TVector<ui64> cookies;
    cookies.reserve(epochCookies.size());

    for (const auto& epochCookie : epochCookies) {
        if (epochCookie.Epoch == Epoch) {
            cookies.emplace_back(epochCookie.Cookie);
        }
    }

    // Check after filtering: every cookie may be stale after a consumer recreation,
    // and an empty commit must not be sent in that case either.
    if (cookies.empty()) {
        return;
    }

    Consumer->Commit(cookies);
}
