#include "queue_writer.h"

#include <travel/hotels/lib/cpp/util/compress.h>

#include <yt/yt/client/api/client.h>
#include <yt/yt/client/api/rowset.h>
#include <yt/yt/client/api/transaction.h>
#include <yt/yt/client/table_client/unversioned_row.h>
#include <yt/yt/client/table_client/name_table.h>
#include <yt/yt/client/table_client/row_buffer.h>

#include <library/cpp/logger/global/global.h>
#include <util/generic/guid.h>
#include <util/string/cast.h>

#include <numeric>
#include <utility>

// Log prefix "[<writer name>] ". It expands to a stream expression that reads the
// Name_ member, so it can only be used inside TYtQueueWriter member functions.
#define LOG_PFX "[" << Name_ << "] "

// NOTE(review): presumably needed for unqualified proto-generated names used below
// (e.g. EMessageCodec / MC_*) — confirm against the header.
using namespace NTravelProto;

namespace NTravel {

// Exposes the writer's counters to monitoring by inserting one entry per counter
// (NAliveClusters, NRecords, NBytes, NWriteError) into the supplied counter table.
// NOTE(review): the shape of each entry is whatever MAKE_COUNTER_PAIR expands to,
// which is not visible in this file.
void TYtQueueWriter::TCounters::QueryCounters(NMonitor::TCounterTable* ct) const {
    ct->insert(MAKE_COUNTER_PAIR(NAliveClusters));
    ct->insert(MAKE_COUNTER_PAIR(NRecords));
    ct->insert(MAKE_COUNTER_PAIR(NBytes));
    ct->insert(MAKE_COUNTER_PAIR(NWriteError));
}

// Stores the configuration and the writer name (used as log prefix and counter source
// name) and initializes the YT client factory with the user and token path from `config`.
// The constructor body is empty: clusters and worker threads are only created in Start().
TYtQueueWriter::TYtQueueWriter(const NTravelProto::NAppConfig::TConfigYtQueueWriter& config, const TString& name)
    : Config_(config)
    , Name_(name)
    , ClientCreator_(config.GetUser(), config.GetYtTokenPath())
{
}

// Makes sure the worker threads are stopped and joined before the object goes away.
// Stop() returns immediately if it has already been called.
TYtQueueWriter::~TYtQueueWriter() {
    Stop();
}

// Registers this writer's counter set in `source` under the writer's name.
void TYtQueueWriter::RegisterCounters(NMonitor::TCounterSource& source) {
    source.RegisterSource(&Counters_, Name_);
}

void TYtQueueWriter::Start() {
    for (const TString& clusterName: Config_.GetClusterName()) {
        TClusterRef c = new TCluster;
        c->Name = clusterName;
        c->PrintName = "[" + clusterName + "]";
        Clusters_[clusterName] = c;
    }
    for (auto it = Clusters_.begin(); it != Clusters_.end(); ++it) {
        it->second->Thread = SystemThreadFactory()->Run([this, it]() {
            ThreadFunc(it->second);
        });
    }
}

// Requests shutdown and waits for all worker threads to finish.
// Does nothing if the stop flag is already set. Otherwise sets the flag, signals every
// cluster's WakeUp event so waiting workers notice the flag promptly, and then joins
// each worker thread that was actually started.
void TYtQueueWriter::Stop() {
    if (!StopFlag_) {
        StopFlag_.Set();

        // Wake everyone first so the threads wind down in parallel ...
        for (auto& entry : Clusters_) {
            entry.second->WakeUp.Signal();
        }
        // ... and only then block on each of them.
        for (auto& entry : Clusters_) {
            auto& worker = entry.second->Thread;
            if (worker) {
                worker->Join();
            }
        }
    }
}

// Convenience overload: enqueues `proto` stamped with the current time and a freshly
// generated GUID as the message id. `lifetime` is forwarded unchanged.
void TYtQueueWriter::Write(const NProtoBuf::Message& proto, TDuration lifetime) {
    const TInstant now = ::Now();
    const TString generatedId = CreateGuidAsString();
    Write(proto, now, generatedId, lifetime);
}

// Packs `proto` into a queue message and appends a copy of it to the pending batches of
// every cluster. Nothing is sent here; the per-cluster worker threads pick the batches up.
//
//   proto     - message to serialize; its descriptor's full name is stored as the type.
//   timestamp - message timestamp.
//   messageId - identifier stored with the message.
//   lifetime  - if non-zero, ExpireTimestamp is set to timestamp + lifetime; otherwise
//               ExpireTimestamp is left as default-constructed.
//
// The payload is compressed with the configured codec and level before it is queued.
void TYtQueueWriter::Write(const NProtoBuf::Message& proto, TInstant timestamp, TString messageId, TDuration lifetime) {
    const auto codec = Config_.GetMessageCodec();

    TYtQueueMessagePacked packed;
    packed.Timestamp = timestamp;
    if (lifetime) {
        packed.ExpireTimestamp = packed.Timestamp + lifetime;
    }
    packed.MessageType = proto.GetDescriptor()->full_name();
    packed.Codec = codec;
    packed.BytesPacked = Compress(codec, Config_.GetMessageCompressionLevel(), proto.SerializeAsString());
    packed.MessageId = messageId;

    for (const auto& entry : Clusters_) {
        const TClusterRef& cluster = entry.second;
        with_lock (cluster->Lock) {
            auto& packs = cluster->MessagePacks;
            // Open a new batch when there is none yet or the current one is full.
            const bool needNewPack = packs.empty() || packs.back().size() >= Config_.GetBatchMaxRecords();
            if (needNewPack) {
                packs.emplace_back();
            }
            packs.back().push_back(packed);
            cluster->MessageCountToWrite += 1;
        }
    }
}

// Returns true when no cluster has messages that are still queued or being written,
// i.e. every cluster's MessageCountToWrite is zero. Each counter is read under that
// cluster's lock; the answer is a snapshot and may be stale as soon as it is returned.
bool TYtQueueWriter::IsFlushed() const {
    for (const auto& entry : Clusters_) {
        const TClusterRef& cluster = entry.second;
        bool hasPending = false;
        with_lock (cluster->Lock) {
            hasPending = cluster->MessageCountToWrite > 0;
        }
        if (hasPending) {
            return false;
        }
    }
    return true;
}

// Worker loop for a single cluster.
//
// Phase 1: keep calling EnsureAlive until it succeeds or a stop is requested, logging the
//          error and waiting up to one second on the WakeUp event between attempts.
// Phase 2: once per batch period, take all pending batches from the cluster (under its
//          lock) and write them one by one. A batch that fails to write is logged and
//          dropped: MessageCountToWrite is decreased whether or not the write succeeded.
//
// If WaitT returns true (the event was signalled — Stop() does that) the iteration is
// skipped and the stop flag is re-checked. Batches still queued when the loop exits are
// not written.
void TYtQueueWriter::ThreadFunc(TClusterRef cluster) {
    bool ready = false;
    while (!ready && !StopFlag_) {
        try {
            EnsureAlive(cluster);
            ready = true;
        } catch (...) {
            ERROR_LOG << LOG_PFX << CurrentExceptionMessage() << Endl;
            cluster->WakeUp.WaitT(TDuration::Seconds(1));
        }
    }

    const TDuration batchPeriod = TDuration::MilliSeconds(Config_.GetBatchPeriodMSec());
    while (!StopFlag_) {
        const bool woken = cluster->WakeUp.WaitT(batchPeriod);
        if (woken) {
            continue;
        }

        // Grab everything queued so far; producers keep appending to a fresh container.
        TVector<TVector<TYtQueueMessagePacked>> pending;
        with_lock (cluster->Lock) {
            cluster->MessagePacks.swap(pending);
        }

        for (const auto& batch : pending) {
            const size_t batchSize = batch.size();
            DEBUG_LOG << LOG_PFX << "Going to write " << batchSize << " messages to cluster " << cluster->PrintName << Endl;
            try {
                WriteMessages(cluster, batch);
                DEBUG_LOG << LOG_PFX << "Written " << batchSize << " messages to cluster " << cluster->PrintName << Endl;
                // TODO add retries
            } catch (...) {
                ERROR_LOG << LOG_PFX << "Failed to write " << batchSize << " messages to cluster " << cluster->PrintName
                          << ", cause: " << CurrentExceptionMessage() << Endl;
            }
            with_lock (cluster->Lock) {
                cluster->MessageCountToWrite -= batchSize;
            }
        }
    }
}

// Guarantees the cluster is marked alive or throws.
// A cluster already marked alive is accepted without any network call; otherwise it is
// pinged once, and a yexception is thrown if it is still not alive afterwards.
void TYtQueueWriter::EnsureAlive(TClusterRef cluster) {
    if (cluster->IsAlive) {
        return;
    }
    PingCluster(cluster);
    if (cluster->IsAlive) {
        return;
    }
    throw yexception() << "Cluster is not alive";
}

// Probes the cluster and records the outcome through ChangeClusterAlive.
//
// The YT client is created lazily on the first call and kept in cluster->YtClient for
// later calls. The probe is a GetTabletInfos request for the configured table, bounded
// by the GetRowsOpTimeoutMSec config value; the cluster counts as alive when the request
// succeeds and returns exactly one entry.
//
// Failures of client creation or of the probe are caught here, logged, and turned into
// ChangeClusterAlive(cluster, false) instead of being propagated to the caller.
void TYtQueueWriter::PingCluster(TClusterRef cluster) {
    if (!cluster->YtClient) {
        try {
            INFO_LOG << LOG_PFX << "Creating client for cluster " << cluster->PrintName << Endl;
            cluster->YtClient = ClientCreator_.CreateClient(cluster->Name);
        } catch (...) {
            ERROR_LOG << LOG_PFX << "Failed to Create client for cluster " << cluster->PrintName << ", Error: " << CurrentExceptionMessage() << Endl;
            ChangeClusterAlive(cluster, false);
            return;
        }
    }
    try {
        TDuration timeout = TDuration::MilliSeconds(Config_.GetGetRowsOpTimeoutMSec());
        // {0}: presumably "tablet index 0" — one requested tablet, hence the size check below.
        auto res1 = cluster->YtClient->GetTabletInfos(Config_.GetTablePath(), {0}).WithTimeout(timeout).Get();
        if (!res1.IsOK()) {
            // Name the operation that actually failed so the log points at the right call.
            throw yexception() << "Cannot do GetTabletInfos: " << ToString(res1);
        }
        if (res1.Value().size() != 1) {
            throw yexception() << "Invalid result from GetTabletInfos: count " << res1.Value().size() << " != 1";
        }
        ChangeClusterAlive(cluster, true);
    } catch (...) {
        ERROR_LOG << LOG_PFX << "Failed to ping cluster " << cluster->PrintName << ", Error: " << CurrentExceptionMessage() << Endl;
        ChangeClusterAlive(cluster, false);
    }
}

// Records the latest liveness observation for a cluster.
// When the state does not change only a debug line is written. On a transition the
// cluster's IsAlive flag is updated, the NAliveClusters counter is incremented (became
// alive) or decremented (became dead), and the change is logged at INFO / ERROR level.
void TYtQueueWriter::ChangeClusterAlive(TClusterRef cluster, bool isAlive) {
    const bool unchanged = (cluster->IsAlive == isAlive);
    if (unchanged) {
        DEBUG_LOG << LOG_PFX << "Cluster " << cluster->PrintName << (isAlive ? " is still alive" : " is still DEAD") << Endl;
        return;
    }

    cluster->IsAlive = isAlive;
    if (!isAlive) {
        Counters_.NAliveClusters.Dec();
        ERROR_LOG << LOG_PFX << "Cluster " << cluster->PrintName << " is now DEAD" << Endl;
        return;
    }
    Counters_.NAliveClusters.Inc();
    INFO_LOG << LOG_PFX << "Cluster " << cluster->PrintName << " is now alive!" << Endl;
}

// Writes one batch of packed messages to the configured table of `cluster` inside a
// single tablet transaction.
//
//   cluster  - target cluster; must be (or become, via one ping) alive.
//   messages - batch to write; each message becomes one row with the columns
//              Timestamp, MessageType, Codec, Bytes, MessageId, ExpireTimestamp
//              (timestamps in milliseconds; ExpireTimestamp is null when unset).
//
// Throws if the cluster is not alive, if the transaction cannot be started, or if the
// commit fails or does not finish within the WriteTimeoutMSec config value.
// NWriteError is incremented on commit failure only; NRecords / NBytes are updated only
// after a successful commit.
void TYtQueueWriter::WriteMessages(TClusterRef cluster, const TVector<TYtQueueMessagePacked>& messages) {
    EnsureAlive(cluster);
    auto nameTable = NYT::New<NYT::NTableClient::TNameTable>();
    auto cIdTimestamp = nameTable->RegisterName("Timestamp");
    auto cIdMessageType = nameTable->RegisterName("MessageType");
    auto cIdCodec = nameTable->RegisterName("Codec");
    auto cIdBytes = nameTable->RegisterName("Bytes");
    auto cIdMessageId = nameTable->RegisterName("MessageId");
    auto cIdExpireTimestamp = nameTable->RegisterName("ExpireTimestamp");

    // writeRows holds non-owning row views; writeRowsStorage owns the row data they refer to.
    std::vector<NYT::NTableClient::TUnversionedRow> writeRows;
    TVector<NYT::NTableClient::TUnversionedOwningRow> writeRowsStorage;
    // Sizes are known up front: reserve once so neither vector reallocates while the
    // views into the owning rows are being collected.
    writeRows.reserve(messages.size());
    writeRowsStorage.reserve(messages.size());
    for (const auto& msg: messages) {
        NYT::NTableClient::TUnversionedOwningRowBuilder builder;
        builder.AddValue(NYT::NTableClient::MakeUnversionedUint64Value(msg.Timestamp.MilliSeconds(), cIdTimestamp));
        builder.AddValue(NYT::NTableClient::MakeUnversionedStringValue(msg.MessageType, cIdMessageType));
        builder.AddValue(NYT::NTableClient::MakeUnversionedUint64Value(msg.Codec, cIdCodec));
        builder.AddValue(NYT::NTableClient::MakeUnversionedStringValue(msg.BytesPacked, cIdBytes));
        builder.AddValue(NYT::NTableClient::MakeUnversionedStringValue(msg.MessageId, cIdMessageId));
        if (msg.ExpireTimestamp) {
            builder.AddValue(NYT::NTableClient::MakeUnversionedUint64Value(msg.ExpireTimestamp.MilliSeconds(), cIdExpireTimestamp));
        } else {
            builder.AddValue(NYT::NTableClient::MakeUnversionedNullValue(cIdExpireTimestamp));
        }
        writeRowsStorage.push_back(builder.FinishRow());
        writeRows.push_back(writeRowsStorage.back().Get());
    }

    NYT::NApi::ITransactionPtr transaction = cluster->YtClient->StartTransaction(NYT::NTransactionClient::ETransactionType::Tablet).Get().ValueOrThrow();
    NYT::NApi::TModifyRowsOptions opts;
    // Hand the views over together with the storage that backs them. The shared range then
    // keeps the row data alive for as long as the transaction holds the range, so leaving
    // this function early (e.g. on a commit timeout below) cannot leave it pointing at
    // destroyed rows. Moving also avoids copying the view vector.
    transaction->WriteRows(
        Config_.GetTablePath(),
        nameTable,
        NYT::MakeSharedRange(std::move(writeRows), std::move(writeRowsStorage)),
        opts);
    TDuration timeout = TDuration::MilliSeconds(Config_.GetWriteTimeoutMSec());
    auto result = transaction->Commit().WithTimeout(timeout).Get();
    if (!result.IsOK()) {
        Counters_.NWriteError.Inc();
        throw yexception() << "Failed to write to " << cluster->PrintName << ": " << ToString(result);
    }
    Counters_.NRecords += messages.size();
    // NBytes counts the packed (post-compression) payload sizes.
    Counters_.NBytes += std::accumulate(messages.begin(), messages.end(), NMonitor::TDerivCounter::TRawValue(),
        [] (NMonitor::TDerivCounter::TRawValue accum, const TYtQueueMessagePacked& message) {
            return accum + message.BytesPacked.size();
        });
}

// Compresses `bytes` with the requested codec.
//   MC_NONE - returns the input unchanged (compressionLevel is ignored);
//   MC_ZLIB - ZLibCompress with ZLib::StreamType::Auto and the given level;
//   MC_ZSTD - ZStdCompress with the given level.
// Any other codec value results in a yexception carrying the numeric codec value.
TString TYtQueueWriter::Compress(EMessageCodec codec, size_t compressionLevel, const TString& bytes) {
    if (codec == MC_NONE) {
        return bytes;
    }
    if (codec == MC_ZLIB) {
        return ZLibCompress(bytes, ZLib::StreamType::Auto, compressionLevel);
    }
    if (codec == MC_ZSTD) {
        return ZStdCompress(bytes, compressionLevel);
    }
    throw yexception() << "Unknown MessageCodec: " << static_cast<int>(codec);
}

}// namespace NTravel
