#include "writer.h"

#include <saas/library/searchmap/parsers/parser.h>
#include <saas/library/searchmap/searchmap.h>
#include <saas/util/json/json.h>
#include <saas/util/network/http_request.h>

#include <library/cpp/balloc/optional/operators.h>
#include <library/cpp/logger/global/global.h>
#include <library/cpp/json/json_writer.h>

#include <util/string/join.h>
#include <util/generic/guid.h>
#include <util/random/random.h>
#include <util/system/thread.h>
#include <util/stream/file.h>
#include <util/system/thread.h>

namespace NSaas {

/// Telemetry reporter for a TPersQueueWriter: every report contains the writer's
/// JSON configuration ("config") and its current status ("status").
class TPersQueueWriterTelemetry : public ITelemetry {
public:
    TPersQueueWriterTelemetry(
        TPersQueueWriter* writer,
        TTelemetryConfig& config,
        TPQLibPtr pqLib,
        std::shared_ptr<NPersQueue::ICredentialsProvider> credentialsProvider)
        : ITelemetry("pq_writer_lib", config, pqLib, credentialsProvider)
        , Writer(writer)
    {
    }

    NJson::TJsonValue GetTelemetryData() const override {
        const NJson::TJsonValue writerConfig = Writer->GetJsonConfig();
        const NJson::TJsonValue writerStatus = Writer->GetStatus();

        NJson::TJsonValue report;
        report.InsertValue("config", writerConfig);
        report.InsertValue("status", writerStatus);
        return report;
    }

private:
    // Non-owning back-pointer; the writer creates this object in Init() and holds it
    // in its Telemetry member.
    TPersQueueWriter* Writer;
};

TPersQueueWriter::TMessageSerializer::TMessageSerializer(const TCheckMessageSettings& settings, NPersQueueCommon::ECodec codec)
    : TMessageChecker(settings)
    , Codec(codec)
{
    // The size limit is enforced on the *serialized* bytes in CheckAndSerialize, so keep the
    // configured value here and zero it in the base checker's settings (presumably 0 turns
    // off the base checker's own size check -- it does mean "no limit" in CheckAndSerialize).
    const auto configuredLimit = Settings.GetCheckMaxSizeBytes();
    Settings.SetCheckMaxSizeBytes(0);
    MaxSizeBytes = configuredLimit;
}

/// Validates `message` with TMessageChecker::Check, serializes it and, unless the codec is
/// RAW, encodes it with `Codec`.
/// Throws (Y_ENSURE) if the message cannot be serialized or if its serialized size is above
/// MaxSizeBytes (0 means "no limit"). Check() presumably throws on invalid messages too;
/// CreateWriteJob reports any exception from here as "incorrect message".
NPersQueue::TData TPersQueueWriter::TMessageSerializer::CheckAndSerialize(const NRTYServer::TMessage& message) {
    Check(message);

    TString serialized;
    serialized.ReserveAndResize(message.ByteSizeLong());
    // SerializeToArray returns false (e.g. on missing required fields) without filling the
    // buffer; ignoring that would publish a zero-filled, corrupt payload.
    Y_ENSURE(message.SerializeToArray(serialized.begin(), serialized.size()), "Cannot serialize message");

    if (MaxSizeBytes) {
        // Only a size strictly greater than the limit is an error; a message of exactly
        // MaxSizeBytes bytes is accepted.
        Y_ENSURE(serialized.size() <= MaxSizeBytes,
            "Message size " << serialized.size() << " bytes is more than limit: " << MaxSizeBytes << " bytes");
    }
    if (Codec == NPersQueueCommon::ECodec::RAW) {
        return NPersQueue::TData(std::move(serialized));
    }
    return NPersQueue::TData::Encode(std::move(serialized), Codec);
}

/// A started producer bound to one shard. It is ref-counted because it is referenced both by
/// the ProducerPerShard map and by every TWriteState of an in-flight write.
struct TPersQueueWriter::TShardProducer : public TAtomicRefCount<TPersQueueWriter::TShardProducer> {
    const TString Shard;                           // shard name (shard.ToString())
    const THolder<NPersQueue::IProducer> Producer; // owned producer writing to that shard's topic

    TShardProducer(const TString& shard, THolder<NPersQueue::IProducer>&& producer)
        : Shard(shard)
        , Producer(std::move(producer))
    {
    }
};

// Outcome of the latest attempt to write one job to one shard.
TPersQueueWriter::TWriteResult::TAttemtResult::TAttemtResult(
    ui16 attempt,          // zero-based attempt number
    bool written,          // true if the producer acknowledged without error
    TStringBuf shard,      // shard name
    TStringBuf comment,    // producer response text, or "send timeout"
    ui32 size,             // source (uncompressed) payload size in bytes
    ui32 compressedSize)   // encoded payload size in bytes (== size when not encoded)
    : Attempt(attempt)
    , Written(written)
    , Shard(shard)
    , Comment(comment)
    , Size(size)
    , CompressedSize(compressedSize)
{
}

/// Serializes the attempt outcome as a flat JSON object.
NJson::TJsonValue TPersQueueWriter::TWriteResult::TAttemtResult::ToJson() const {
    NJson::TJsonValue attemptJson;
    attemptJson.InsertValue("attempt", Attempt);
    attemptJson.InsertValue("written", Written);
    attemptJson.InsertValue("shard", Shard);
    attemptJson.InsertValue("comment", Comment);
    attemptJson.InsertValue("size", Size);
    attemptJson.InsertValue("compressed_size", CompressedSize);
    return attemptJson;
}

// Final result of one write job: overall success, error classification, a free-form comment
// and the per-shard attempt outcomes.
TPersQueueWriter::TWriteResult::TWriteResult(
    bool written, bool userError, bool exceededInFlight,
    TStringBuf comment, const TVector<TAttemtResult>& attempts /*= {}*/)
    : Written(written)
    , UserError(userError)
    , ExceededInFlight(exceededInFlight)
    , Comment(comment)
    , Attempts(attempts)
{
}

/// Serializes the write result. "user_error" and "exceeded_in_flight" are emitted only when
/// set. "attempts" holds one object per attempt; with no attempts it stays an undefined
/// (null) value rather than an empty array.
NJson::TJsonValue TPersQueueWriter::TWriteResult::ToJson() const {
    NJson::TJsonValue attemptsJson;
    for (const auto& attemptResult : Attempts) {
        attemptsJson.AppendValue(attemptResult.ToJson());
    }

    NJson::TJsonValue resultJson;
    resultJson.InsertValue("written", Written);
    resultJson.InsertValue("comment", Comment);
    if (UserError) {
        resultJson.InsertValue("user_error", UserError);
    }
    if (ExceededInFlight) {
        resultJson.InsertValue("exceeded_in_flight", ExceededInFlight);
    }
    resultJson.InsertValue("attempts", attemptsJson);
    return resultJson;
}

/// Renders ToJson() as a compact JSON string (HTML characters are not escaped).
TString TPersQueueWriter::TWriteResult::ToString() const {
    auto resultJson = ToJson();
    TString rendered;
    TStringOutput sink(rendered);
    NJsonWriter::TBuf writer(NJsonWriter::HEM_DONT_ESCAPE_HTML, &sink);
    writer.WriteJsonValue(&resultJson);
    return rendered;
}

/// Progress of one job on one shard: the shard's producer, how many attempts have been
/// started (0 = none yet) and the future of the most recent attempt.
struct TPersQueueWriter::TWriteState {
    TShardProducerPtr ShardProducer;
    ui32 Attempt = 0;
    TWriteFuture Future;

    TWriteState(TShardProducerPtr shardProducer)
        : ShardProducer(std::move(shardProducer))
    {
    }
};

/// Summarizes the latest attempt recorded in `status`. If that attempt's future has no value
/// yet (not started, or its deadline was reached first), it is reported as an unwritten
/// "send timeout" with zero sizes.
TPersQueueWriter::TWriteResult::TAttemtResult TPersQueueWriter::GetAttemptInfo(const TWriteState& status) const {
    const auto& future = status.Future;
    if (!future.HasValue()) {
        return TWriteResult::TAttemtResult(status.Attempt - 1, false, status.ShardProducer->Shard, "send timeout", 0, 0);
    }

    const auto& response = future.GetValue();
    const bool written = !response.Response.HasError();
    const TString comment = ToString(response.Response);
    const ui32 sourceSize = response.Data.GetSourceData().size();
    // For non-encoded data the "compressed" size is simply the source size.
    const ui32 encodedSize = response.Data.IsEncoded() ? response.Data.GetEncodedData().size() : sourceSize;
    return TWriteResult::TAttemtResult(status.Attempt - 1, written, status.ShardProducer->Shard, comment, sourceSize, encodedSize);
}

// Topic path for a shard: "<normalized directory>/shard-<shardName>".
TString TPersQueueWriter::GetTopicName(TStringBuf directoryWithTopics, TStringBuf shardName) {
    const auto directory = TFsPath(directoryWithTopics).Fix();
    return Join("", directory, "/shard-", shardName);
}

// Convenience overload: names the topic after the interval's string form.
TString TPersQueueWriter::GetTopicName(
    TStringBuf directoryWithTopics,
    NSearchMapParser::TShardsInterval shard
) {
    const TString shardName = shard.ToString();
    return GetTopicName(directoryWithTopics, shardName);
}

void* TPersQueueWriter::TThreadNamer::CreateThreadSpecificResource() const {
    TThread::SetCurrentThreadName(Name.c_str());
    return nullptr;
}

void TPersQueueWriter::TThreadNamer::DestroyThreadSpecificResource(void* /*resource*/) const {
    // Nothing to release: CreateThreadSpecificResource always returns nullptr.
}

// Stores the thread name. Some platforms limit thread names to 15 characters, so longer
// names are shortened to "<first 7>.<last 7>". That keeps both the prefix and the
// distinguishing suffix.
void TPersQueueWriter::TThreadNamer::SetName(const TString& name) {
    constexpr size_t maxNameLength = 15;
    constexpr size_t keptPerSide = 7;
    if (name.length() <= maxNameLength) {
        Name = name;
        return;
    }
    Name = name.substr(0, keptPerSide) + '.' + name.substr(name.length() - keptPerSide);
}

// Name set by SetName (possibly shortened to 15 characters).
const TString& TPersQueueWriter::TThreadNamer::GetName() const
{
    return Name;
}

// Unique producer source id: "<SourceIdPrefix>_<random GUID>".
TString TPersQueueWriter::CreateSourceId() const {
    return Join("_", Settings.SourceIdPrefix, CreateGuidAsString());
}

// Returns the next value of a process-wide atomic counter (function-local static, so it is
// shared by every writer instance). The first id returned is 1.
ui64 TPersQueueWriter::GenerateMessageId() const {
    static TAtomic lastMessageId = 0;
    return AtomicIncrement(lastMessageId);
}

// Fills in fields the caller left unset: receive timestamp and message id. For a message
// with a document, it also fills the modification and version timestamps. All timestamps
// are the current time in seconds.
void TPersQueueWriter::PrepareMessage(NRTYServer::TMessage& message) const {
    const auto nowSeconds = TInstant::Now().Seconds();
    if (!message.HasReceiveTimestamp()) {
        message.SetReceiveTimestamp(nowSeconds);
    }
    if (!message.HasMessageId()) {
        message.SetMessageId(GenerateMessageId());
    }
    if (!message.HasDocument()) {
        return;
    }
    auto& document = *message.MutableDocument();
    if (!document.HasModificationTimestamp()) {
        document.SetModificationTimestamp(nowSeconds);
    }
    if (!document.HasVersionTimestamp()) {
        document.SetVersionTimestamp(nowSeconds);
    }
}

// True when PrepareMessages is enabled and PrepareMessage would fill at least one missing
// field of `message`.
bool TPersQueueWriter::MessageNeedsPrepare(const NRTYServer::TMessage& message) const {
    if (!Settings.PrepareMessages) {
        return false;
    }
    if (!message.HasReceiveTimestamp() || !message.HasMessageId()) {
        return true;
    }
    if (!message.HasDocument()) {
        return false;
    }
    const auto& document = message.GetDocument();
    return !document.HasModificationTimestamp() || !document.HasVersionTimestamp();
}

// Serializes the message, first completing missing fields on a private copy when required.
// The caller's message is never modified.
NPersQueue::TData TPersQueueWriter::GetPreparedData(const NRTYServer::TMessage& originalMessage) const {
    if (!MessageNeedsPrepare(originalMessage)) {
        return MessageSerializer->CheckAndSerialize(originalMessage);
    }
    NRTYServer::TMessage preparedMessage(originalMessage);
    PrepareMessage(preparedMessage);
    return MessageSerializer->CheckAndSerialize(preparedMessage);
}

// Builds the write queue and starts it with Settings.ThreadsCount threads:
// - NoWriteQueue: a TFakeThreadPool;
// - ThreadsCount != 0: a fixed TThreadPool;
// - otherwise: a TAdaptiveThreadPool whose max idle time is set to 100 s after start.
// Pool threads are named via WriteThreadNamer.
THolder<IThreadPool> TPersQueueWriter::CreateThreadPool() {
    THolder<IThreadPool> pool;
    TThreadPoolBinder<TAdaptiveThreadPool, TThreadNamer>* adaptivePool = nullptr;

    if (Settings.NoWriteQueue) {
        pool = MakeHolder<TFakeThreadPool>();
    } else if (Settings.ThreadsCount) {
        pool = MakeHolder<TThreadPoolBinder<TThreadPool, TThreadNamer>>(WriteThreadNamer);
    } else {
        auto adaptive = MakeHolder<TThreadPoolBinder<TAdaptiveThreadPool, TThreadNamer>>(WriteThreadNamer);
        adaptivePool = adaptive.Get();  // non-owning; kept to configure it after Start()
        pool = std::move(adaptive);
    }
    pool->Start(Settings.ThreadsCount, 0);

    if (adaptivePool != nullptr) {
        adaptivePool->SetMaxIdleTime(TDuration::Seconds(100));
    }
    return pool;
}

void TPersQueueWriter::BeforeProcessWrite(
    const NRTYServer::TMessage& /*msg*/,
    const NPersQueue::TData& /*data*/,
    const TVector<NSearchMapParser::TShardsInterval>& /*shards*/) const
{
    // No-op. CreateWriteJob calls this under the shards read lock once the target shards
    // are resolved (presumably an extension hook for subclasses).
}

void TPersQueueWriter::InitShards(const TVector<NSearchMapParser::TShardsInterval>& /*shards*/) const
{
    // No-op. UpdateShards calls this under the shards write lock after the shard set is
    // replaced (presumably an extension hook for subclasses).
}

// Construction does nothing; the writer becomes usable only after Init().
TPersQueueWriter::TPersQueueWriter() = default;

// Stop() waits for in-flight writes before the members are destroyed.
TPersQueueWriter::~TPersQueueWriter()
{
    Stop();
}

void TPersQueueWriter::Init(const TPersQueueWriterSettings& settings) {
    Y_ENSURE(!Initilized, "writer is already initilized");
    {
        TString settingsError;
        if (!settings.IsCorrectlyFilled(&settingsError)) {
            ythrow yexception() << settingsError;
        }
    }

    INFO_LOG << "PersQueueWriter initializing..." << Endl;
    Settings = settings;

    if (!Settings.DryRun) {
        if (!Settings.PQLib) {
            PQLib = MakeHolder<NPersQueue::TPQLib>(Settings.PQLibSettings);
        } else {
            PQLib = Settings.PQLib;
        }
        InitThreads();
    }

    ProducerSettingsTemplate.Server = NPersQueue::TServerSetting{ Settings.Server };
    ProducerSettingsTemplate.ReconnectOnFailure = true;
    ProducerSettingsTemplate.CredentialsProvider = Settings.CredentialsProvider;
    ProducerSettingsTemplate.Server.UseLogbrokerCDS = Settings.UseLogbrokerCDS;
    if (Settings.Codec) {
        ProducerSettingsTemplate.Codec = Settings.Codec.value();
    }

    MessageSerializer = MakeHolder<TMessageSerializer>(
        Settings.CheckMessageSettings,
        ProducerSettingsTemplate.Codec
    );

    TServiceShards serviceShards;
    if (Settings.ServiceShards) {
        serviceShards = Settings.ServiceShards.value();
    } else {
        auto& searchmapSettings = Settings.SearchMapSettings.value();
        serviceShards.Init(searchmapSettings, Settings.ServiceName);

        if (searchmapSettings.UpdatePeriod != TDuration::Zero()) {
            TThread::TParams searchmapUpdaterThreadParams(&SearchMapUpdaterProc, this);
            searchmapUpdaterThreadParams.SetName(UpdateSearchmapThreadNamer.GetName());
            SearchMapUpdater = MakeHolder<TThread>(searchmapUpdaterThreadParams);
            SearchMapUpdater->Start();
        }
    }
    UpdateShards(serviceShards);

    TTelemetryConfig telemetryConfig;
    if (Settings.TelemetryConfig) {
        telemetryConfig = Settings.TelemetryConfig.value();
    } else if (Settings.TelemetryInterval) {
        telemetryConfig.SetInterval(Settings.TelemetryInterval.value());
        if (Settings.Server.StartsWith("lbkxt.") || Settings.Server.StartsWith("logbroker-prestable.")) {
            telemetryConfig.SetServer("lbkxt.logbroker.yandex.net");
        } else {
            telemetryConfig.SetServer("lbkx.logbroker.yandex.net");
        }
    } else {
        telemetryConfig.SetInterval(TDuration::Zero());
    }
    Telemetry = MakeHolder<TPersQueueWriterTelemetry>(this, telemetryConfig, PQLib, Settings.CredentialsProvider);

    Initilized = true;
    INFO_LOG << "PersQueueWriter initialized" << Endl;
}

// Stops an initialized writer; does nothing otherwise. New writes are refused immediately.
// The call then blocks (polling every 100 ms) until no job is in flight, wakes all waiters
// on Stopped, and tears down the write queue, the deadline queue and the searchmap updater.
void TPersQueueWriter::Stop() {
    if (!Initilized) {
        return;
    }

    INFO_LOG << "Writer stopping... In flight messages count: " << AtomicGet(InFlight) << Endl;
    Initilized = false;
    for (;;) {
        if (AtomicGet(InFlight) <= 0) {
            break;
        }
        Sleep(TDuration::MilliSeconds(100));
    }
    Stopped.Signal();
    WriteQueue.Reset();
    WaitQueue.Reset();
    SearchMapUpdater.Reset();

    INFO_LOG << "Writer stopped." << Endl;
}

// Number of write jobs that have started sending but whose result is not yet known.
ui32 TPersQueueWriter::GetInFlight() const {
    const auto inFlightJobs = AtomicGet(InFlight);
    return inFlightJobs;
}

// Status snapshot with these keys:
// - "ShardsUpdateTime": seconds of the last shard update;
// - "Shards": list of shard names;
// - "InFlight": jobs in flight.
// The shard list and the update time are read together under the shards lock.
NJson::TJsonValue TPersQueueWriter::GetStatus() const {
    NJson::TJsonValue shardNames;
    TInstant lastUpdate;
    {
        TReadGuard guard(ShardsMutex);
        for (const auto& shard : ServiceShards.GetShards()) {
            shardNames.AppendValue(shard.ToString());
        }
        lastUpdate = ShardsUpdateTime;
    }

    NJson::TJsonValue status;
    status.InsertValue("ShardsUpdateTime", lastUpdate.Seconds());
    status.InsertValue("Shards", shardNames);
    status.InsertValue("InFlight", GetInFlight());
    return status;
}

// The writer's settings rendered as JSON (delegates to the settings object).
NJson::TJsonValue TPersQueueWriter::GetJsonConfig() const
{
    return Settings.GetJsonConfig();
}

// Names the worker threads and starts the pools: the write queue (see CreateThreadPool) and a
// single-thread queue that waits for per-attempt send deadlines. Namers are set first
// because the pools are built from them.
void TPersQueueWriter::InitThreads() {
    const auto& suffix = Settings.ThreadsName;
    WriteThreadNamer.SetName("SPQWw_" + suffix);
    WaitThreadNamer.SetName("SPQWt_" + suffix);
    UpdateSearchmapThreadNamer.SetName("SPQWs_" + suffix);

    WriteQueue = CreateThreadPool();

    WaitQueue = MakeHolder<TThreadPoolBinder<TThreadPool, TThreadNamer>>(WaitThreadNamer);
    WaitQueue->Start(1);
}

// Builds a failed (not written) result without any attempts.
TPersQueueWriter::TWriteResult TPersQueueWriter::GetWriteErrorInfo(TStringBuf comment, bool userError, bool exceededInFlight) {
    constexpr bool written = false;
    return TWriteResult(written, userError, exceededInFlight, comment);
}

/// Shared state of one message being written to all of its target shards. It is ref-counted:
/// the queued tasks and the future callbacks that refer to it keep it alive.
struct TPersQueueWriter::TWriteJob : public TAtomicRefCount<TWriteJob> {
    TWriteJob()
        : Promise(NThreading::NewPromise<TWriteResult>())
    {
    }

    TVector<TWriteState> States;                  // one entry per target shard
    NThreading::TPromise<TWriteResult> Promise;   // completes the future returned to the caller
    NPersQueue::TData Data;                       // serialized payload sent to every shard
    ui32 StatesInProcess = 0;                     // shards whose writing has not finished yet
    TMutex Mutex;                                 // held while States/StatesInProcess are updated
};

// Starts the next attempt for shard `stateIndex` of `job`, unless the previous attempt was
// written or Settings.MaxAttempts is exhausted. Returns true if an attempt was started; the
// response and the send deadline are both routed to ProcessProducerResponse. Expects
// job->Mutex to be held by the caller (ProcessProducerResponse holds it).
bool TPersQueueWriter::Write(TWriteJobPtr job, ui32 stateIndex) {
    auto& state = job->States[stateIndex];
    ++state.Attempt;
    const auto previousAttempt = GetAttemptInfo(state);
    const bool shouldSend = !previousAttempt.Written && state.Attempt <= Settings.MaxAttempts;
    if (!shouldSend) {
        return false;
    }

    const TInstant deadline = Now() + Settings.SendTimeout;
    const ui32 attempt = state.Attempt;
    state.Future = state.ShardProducer->Producer->Write(job->Data);
    state.Future.Subscribe([this, job, stateIndex, attempt](TWriteFuture) {
        ProcessProducerResponse(job, stateIndex, attempt);
    });
    WaitQueue->SafeAddFunc([this, job, stateIndex, attempt, deadline]() {
        ProcessDeadline(job, stateIndex, attempt, deadline);
    });
    return true;
}

/// Starts writing `job` to all of its shards (runs on the write queue). InFlight is
/// incremented here. ProcessProducerResponse decrements it and fulfils the job's promise once
/// the last shard finishes.
void TPersQueueWriter::Write(TPersQueueWriter::TWriteJobPtr job) {
    if (job->States.empty()) {
        // No shard would ever complete this job: fail it now instead of leaving the caller's
        // future unresolved and InFlight incremented forever (which would hang Stop()).
        job->Promise.SetValue(GetWriteErrorInfo("no shards found for message"));
        return;
    }

    AtomicIncrement(InFlight);
    {
        TGuard<TMutex> g(job->Mutex);
        // Count every shard as in process *before* starting any of them. A producer may
        // complete a write synchronously, and its ProcessProducerResponse would otherwise
        // decrement the counter before it was incremented (underflow), so the job would never
        // be completed.
        job->StatesInProcess = static_cast<ui32>(job->States.size());
    }
    for (ui32 i = 0; i < job->States.size(); ++i) {
        // Attempt 0 means "not started". This starts the first attempt, or, if no attempt is
        // allowed (MaxAttempts == 0), finishes the shard as failed instead of leaving it
        // counted as in process forever.
        ProcessProducerResponse(job, i, 0);
    }
}

// Handles a producer response or an expired deadline for attempt `attempt` of shard
// `stateIndex`. Stale notifications (a newer attempt exists) are ignored. Otherwise it either
// starts a retry, or marks the shard finished. When the last shard finishes, it fulfils the
// job's promise with the per-shard outcomes and decrements InFlight.
void TPersQueueWriter::ProcessProducerResponse(TPersQueueWriter::TWriteJobPtr job, ui32 stateIndex, ui32 attempt) {
    TGuard<TMutex> guard(job->Mutex);
    const bool stale = job->States[stateIndex].Attempt != attempt;
    if (stale || Write(job, stateIndex)) {
        return;
    }

    // This shard is done (written or out of attempts).
    if (--job->StatesInProcess != 0) {
        return;
    }

    TVector<TWriteResult::TAttemtResult> attempts;
    attempts.reserve(job->States.size());
    bool allWritten = true;
    for (const auto& state : job->States) {
        attempts.emplace_back(GetAttemptInfo(state));
        allWritten = allWritten && attempts.back().Written;
    }
    job->Promise.SetValue(TWriteResult(allWritten, false, false, "", attempts));
    AtomicDecrement(InFlight);
}

// Waits until the attempt's deadline, or less if Stopped is signalled, then re-evaluates the
// attempt. This does nothing if the attempt has already been superseded.
void TPersQueueWriter::ProcessDeadline(TWriteJobPtr job, ui32 stateIndex, ui32 attempt, TInstant deadline) {
    Stopped.WaitD(deadline);
    ProcessProducerResponse(std::move(job), stateIndex, attempt);
}

/// Serializes `message`, resolves its target shards and queues the write on the write queue.
/// The following are reported through the returned future rather than thrown:
/// - uninitialized writer;
/// - invalid message;
/// - missing producer;
/// - in-flight limit exceeded (non-blocking mode);
/// - queueing failure.
/// In DryRun mode nothing is sent and success is reported.
TPersQueueWriter::TWriteResultFuture TPersQueueWriter::CreateWriteJob(const NRTYServer::TMessage& message) {
    TWriteJobPtr job = MakeIntrusive<TWriteJob>();
    if (!Initilized) {
        job->Promise.SetValue(GetWriteErrorInfo("not initilized"));
        return job->Promise.GetFuture();
    }
    try {
        job->Data = GetPreparedData(message);
    } catch (...) {
        job->Promise.SetValue(GetWriteErrorInfo("incorrect message: " + CurrentExceptionMessage(), true));
        return job->Promise.GetFuture();
    }

    bool hasShardWithoutProducer = false;
    TString shardWithoutProducer;
    {
        TReadGuard rg(ShardsMutex);
        auto shards = ServiceShards.GetShards(message);
        job->States.reserve(shards.size());
        for (const auto& shard : shards) {
            // find(), not operator[]: operator[] inserts into ProducerPerShard while only the
            // shared (read) lock is held, which races with other concurrent writers.
            const TString shardName = shard.ToString();
            const auto it = ProducerPerShard.find(shardName);
            if (it != ProducerPerShard.end()) {
                job->States.emplace_back(it->second);
            } else {
                // Expected in DryRun (no producers are created); an error otherwise.
                if (!hasShardWithoutProducer) {
                    hasShardWithoutProducer = true;
                    shardWithoutProducer = shardName;
                }
                job->States.emplace_back(TShardProducerPtr());
            }
        }
        BeforeProcessWrite(message, job->Data, shards);
    }
    if (Settings.DryRun) {
        job->Promise.SetValue(TWriteResult(true, false, false, "Data was not really send becourse of DryRun"));
        return job->Promise.GetFuture();
    }
    if (hasShardWithoutProducer) {
        // Sending would dereference a null producer.
        job->Promise.SetValue(GetWriteErrorInfo("no producer for shard " + shardWithoutProducer));
        return job->Promise.GetFuture();
    }
    try {
        if (Settings.ThreadsBlockingMode) {
            while (Settings.MaxInFlightCount && AtomicGet(InFlight) >= (i64)Settings.MaxInFlightCount) {
                Sleep(TDuration::MilliSeconds(10));
            }
            WriteQueue->SafeAddFunc([this, job]() {Write(job);});
        } else {
            if (Settings.MaxInFlightCount && AtomicGet(InFlight) >= (i64)Settings.MaxInFlightCount) {
                job->Promise.SetValue(GetWriteErrorInfo("exceeded the number of messages in flight", false, true));
            } else {
                WriteQueue->SafeAddFunc([this, job]() {Write(job); });
            }
        }
    } catch(...) {
        job->Promise.SetValue(GetWriteErrorInfo("cannot add message to queue: " + CurrentExceptionMessage()));
        return job->Promise.GetFuture();
    }
    return job->Promise.GetFuture();
}

// Creates and starts one producer per shard, each with its own topic and a fresh source id.
// Throws if a producer does not connect within Settings.ConnectionTimeout or its start
// response carries an error. In DryRun the map is returned empty.
THashMap<TString, TPersQueueWriter::TShardProducerPtr> TPersQueueWriter::CreateProducers(const TVector<NSearchMapParser::TShardsInterval>& shards) const {
    THashMap<TString, TShardProducerPtr> producers;
    if (Settings.DryRun) {
        return producers;
    }

    NPersQueue::TProducerSettings producerSettings = ProducerSettingsTemplate;
    for (const auto& shard : shards) {
        auto shardName = shard.ToString();
        producerSettings.Topic = GetTopicName(Settings.DirectoryWithTopics, shardName);
        producerSettings.SourceId = CreateSourceId();

        auto producer = MakeIntrusive<TShardProducer>(
            shardName,
            PQLib->CreateProducer(producerSettings, Settings.Logger, true)
        );
        auto startFuture = producer->Producer->Start();
        if (!startFuture.Wait(Settings.ConnectionTimeout)) {
            ythrow yexception() << "producer: connection timeout expired";
        }
        const auto& startResult = startFuture.GetValue();
        if (startResult.Response.HasError()) {
            ythrow yexception() << "producer: " << startResult.Response;
        }
        producers[shardName] = std::move(producer);
    }
    return producers;
}

// Asynchronously writes a protobuf message; all outcomes are delivered through the future.
TPersQueueWriter::TWriteResultFuture TPersQueueWriter::Write(const NRTYServer::TMessage& message)
{
    return CreateWriteJob(message);
}

// Asynchronously writes an action by converting it to its protobuf message first.
TPersQueueWriter::TWriteResultFuture TPersQueueWriter::Write(const NSaas::TAction& action) {
    const auto protoMessage = action.ToProtobuf();
    return Write(protoMessage);
}

// Asynchronously writes a JSON action. If parsing fails, the returned future is already
// completed with a user error ("incorrect message: ...").
TPersQueueWriter::TWriteResultFuture TPersQueueWriter::Write(const NJson::TJsonValue& document) {
    NSaas::TAction parsedAction;
    try {
        parsedAction.ParseFromJson(document);
    } catch (...) {
        auto failed = NThreading::NewPromise<TWriteResult>();
        failed.SetValue(GetWriteErrorInfo("incorrect message: " + CurrentExceptionMessage(), true));
        return failed.GetFuture();
    }
    return Write(parsedAction);
}

void TPersQueueWriter::UpdateShards(const TServiceShards& serviceShards) {
    Y_ENSURE(serviceShards.Initilized(), "received TServiceShards not initilized");
    auto shards = serviceShards.GetShards();
    THashMap<TString, TShardProducerPtr> producerPerShard = CreateProducers(shards);

    INFO_LOG << "Replace shards..." << Endl;
    TWriteGuard wg(ShardsMutex);
    ServiceShards = serviceShards;
    ProducerPerShard = std::move(producerPerShard);
    InitShards(shards);
    ShardsUpdateTime = TInstant::Now();
    INFO_LOG << "Shards update finished." << Endl;
}

// Rebuilds the shards of Settings.ServiceName from an explicit searchmap. Throws if the
// service name is not set or the shard update fails.
void TPersQueueWriter::UpdateSearchMap(const NSearchMapParser::TSearchMap& searchMap) {
    Y_ENSURE(Settings.ServiceName, "service name not set");
    TServiceShards freshShards;
    freshShards.Init(searchMap, Settings.ServiceName);
    UpdateShards(freshShards);
}

void TPersQueueWriter::UpdateSearchMap() noexcept {
    try {
        Y_ENSURE(Settings.SearchMapSettings, "SearchMap settings not set - incorrect usage");
        Y_ENSURE(Settings.ServiceName, "service name not set");
        TServiceShards serviceShards;
        serviceShards.Init(Settings.SearchMapSettings.value(), Settings.ServiceName);
        UpdateShards(serviceShards);
    } catch (...) {
        ERROR_LOG << "Cannot update SearchMap: " << CurrentExceptionMessage() << Endl;
    }
}

void* TPersQueueWriter::SearchMapUpdaterProc(void* object) {
    ThreadDisableBalloc();
    auto this_ = reinterpret_cast<TPersQueueWriter*>(object);
    TDuration updatePeriod = this_->Settings.SearchMapSettings.value().UpdatePeriod;
    while (true) {
        if (this_->Stopped.WaitT(updatePeriod)) {
            break;
        }
        INFO_LOG << "Shards update started." << Endl;
        this_->UpdateSearchMap();
    }
    return nullptr;
}

}
