#include "key_value_storage.h"

#include <travel/hotels/lib/cpp/util/compress.h>

#include <yt/yt/client/api/public.h>
#include <yt/yt/client/api/client.h>
#include <yt/yt/client/api/rowset.h>
#include <yt/yt/client/table_client/row_buffer.h>
#include <yt/yt/client/table_client/name_table.h>
#include <yt/yt/client/table_client/unversioned_row.h>
#include <yt/yt/core/actions/bind.h>
#include <yt/yt/core/actions/callback.h>

#include <library/cpp/logger/global/global.h>
#include <util/string/builder.h>

#define LOG_PFX "[" << Name_ << "] "

namespace {
    // Verifies that a cell has the expected YT value type.
    // Throws yexception (reporting both types as their integer values) when they differ.
    void CheckType(const NYT::NTableClient::EValueType& expected, const NYT::NTableClient::EValueType& actual) {
        if (expected == actual) {
            return;
        }
        throw yexception() << "Invalid field type: expected " << static_cast<int>(expected) << ", got " << static_cast<int>(actual);
    }

    // True when the cell is Null-typed (EValueType::Null).
    bool IsNull(const NYT::NTableClient::TUnversionedValue& value) {
        const auto cellType = value.Type;
        return cellType == NYT::NTableClient::EValueType::Null;
    }

    // Returns the string payload of a String-typed cell; throws (via CheckType) for any other type.
    TString GetString(const NYT::NTableClient::TUnversionedValue& value) {
        const auto wanted = NYT::NTableClient::EValueType::String;
        CheckType(wanted, value.Type);
        return value.AsString();
    }

    // Interprets a Uint64 cell as milliseconds since the epoch.
    // When `optional` is true, a Null cell yields a default-constructed TInstant instead of an error;
    // otherwise a non-Uint64 cell makes CheckType throw.
    TInstant GetTimestamp(const NYT::NTableClient::TUnversionedValue& value, bool optional) {
        const bool absent = optional && IsNull(value);
        if (!absent) {
            CheckType(NYT::NTableClient::EValueType::Uint64, value.Type);
            return TInstant::MilliSeconds(value.Data.Uint64);
        }
        return TInstant();
    }

    // Returns the unsigned 64-bit payload of a Uint64-typed cell; throws (via CheckType) otherwise.
    ui64 GetUint64(const NYT::NTableClient::TUnversionedValue& value) {
        CheckType(NYT::NTableClient::EValueType::Uint64, value.Type);
        const ui64 payload = value.Data.Uint64;
        return payload;
    }

    // Decodes a stored payload according to its codec.
    // MC_NONE returns the bytes unchanged, MC_ZLIB and MC_ZSTD delegate to the project decompressors;
    // any other codec value throws yexception.
    TString Decompress(NTravelProto::EMessageCodec codec, const TString& bytes) {
        switch (codec) {
            case NTravelProto::MC_NONE:
                return bytes;
            case NTravelProto::MC_ZLIB:
                return NTravel::ZLibDecompress(bytes, ZLib::StreamType::Auto);
            case NTravelProto::MC_ZSTD:
                return NTravel::ZStdDecompress(bytes);
            default:
                throw yexception() << "Unknown MessageCodec: " << static_cast<int>(codec);
        }
    }
}

namespace NTravel {
    // Builds the storage from its config. No YT clients are created here; that happens in Start().
    //
    // MinReadsForNotFound defaults to ceil(N / 2) for N configured cluster names when it is not set
    // explicitly. The read timeout comes from ReadTimeoutMSec.
    // All derived values are read from the `config` parameter rather than from Config_. Members are
    // initialized in header declaration order, so reading Config_ here would only be correct if
    // Config_ happens to be declared before these members.
    TYtKeyValueStorage::TYtKeyValueStorage(const NTravelProto::NAppConfig::TConfigYtKeyValueStorage& config, const TString& name)
        : Config_(config)
        , Name_(name)
        , ClientCreator_(config.GetUser(), config.GetYtTokenPath())
        , TablePath_(config.GetTablePath())
        , MinReadsForSuccess_(config.GetMinReadsForSuccess())
        , MinReadsForNotFound_(config.HasMinReadsForNotFound() ? config.GetMinReadsForNotFound() : (config.GetClusterName().size() + 1) / 2)
        , ReadTimeout_(TDuration::MilliSeconds(config.GetReadTimeoutMSec()))
        , PerClusterCounters_({"YtCluster"})
    {
    }

    // Makes sure the liveness thread is stopped and joined before the object goes away.
    TYtKeyValueStorage::~TYtKeyValueStorage() {
        this->Stop();
    }

    // Publishes the storage-wide counters under Name_ and the per-cluster counters under
    // Name_ + "PerCluster".
    void TYtKeyValueStorage::RegisterCounters(NMonitor::TCounterSource& source) {
        const TString perClusterSourceName = Name_ + "PerCluster";
        source.RegisterSource(&Counters_, Name_);
        source.RegisterSource(&PerClusterCounters_, perClusterSourceName);
    }

    // Reports whether the liveness checker has seen enough alive clusters.
    // Within this file the flag is only ever set (never cleared).
    bool TYtKeyValueStorage::IsReady() const {
        const bool ready = IsReady_;
        return ready;
    }

    void TYtKeyValueStorage::Start() {
        for (const TString& clusterName: Config_.GetClusterName()) {
            TCluster cluster;
            cluster.Name = clusterName;
            try {
                INFO_LOG << LOG_PFX << "Creating client for cluster " << cluster.Name << Endl;
                cluster.YtClient = ClientCreator_.CreateClient(cluster.Name);
            } catch (...) {
                ERROR_LOG << LOG_PFX << "Failed to create client for cluster " << cluster.Name << ", Error: " << CurrentExceptionMessage() << Endl;
            }
            Clusters_.emplace(clusterName, cluster);
        }

        Y_ENSURE(MinReadsForSuccess_ <= Clusters_.size(), "MinReadsForSuccess is more than clusters count");
        Y_ENSURE(MinReadsForNotFound_ <= Clusters_.size(), "MinReadsForNotFound is more than clusters count");

        LivenessCheckThread_ = SystemThreadFactory()->Run([this]() { RunLivenessChecking(); });
    }

    // Idempotent shutdown. The first call wakes the liveness loop via StopEvent_ and joins its thread;
    // later calls do nothing.
    void TYtKeyValueStorage::Stop() {
        const bool firstCall = Stopping_.TrySet();
        if (firstCall) {
            StopEvent_.Signal();
            if (LivenessCheckThread_) {
                LivenessCheckThread_->Join();
                LivenessCheckThread_ = nullptr;
            }
        }
    }

    // Body of the liveness-check thread started by Start().
    //
    // Every CheckLivenessEachPeriod_ (or until Stop() signals StopEvent_) it probes each cluster with
    // CheckClusterLiveness, bounded by LivenessCheckTimeout_. It updates the per-cluster IsReady
    // counter and, once at least MinReadsForSuccess_ and MinReadsForNotFound_ clusters are alive,
    // sets the sticky IsReady_ flag.
    // A cluster whose client could not be created in Start() has an empty YtClient. It is reported
    // as not alive instead of being dereferenced, which would crash this thread.
    void TYtKeyValueStorage::RunLivenessChecking() {
        while (!Stopping_) {
            size_t okCnt = 0;
            for (auto&[name, cluster]: Clusters_) {
                if (Stopping_) {
                    break;
                }
                bool alive = false;
                if (!cluster.YtClient) {
                    WARNING_LOG << LOG_PFX << "Cluster " << name << " is not alive: no YT client" << Endl;
                } else {
                    auto opts = NYT::NApi::TCheckClusterLivenessOptions{};
                    opts.CheckCypressRoot = true;
                    auto result = cluster.YtClient->CheckClusterLiveness(opts).WithTimeout(LivenessCheckTimeout_).Get();
                    alive = result.IsOK();
                    if (!alive) {
                        WARNING_LOG << LOG_PFX << "Cluster " << name << " is not alive: " << ToString(result) << Endl;
                    }
                }
                if (alive) {
                    okCnt++;
                }
                PerClusterCounters_.GetOrCreate({name})->IsReady = alive;
            }
            if (okCnt >= MinReadsForSuccess_ && okCnt >= MinReadsForNotFound_ && IsReady_.TrySet()) {
                INFO_LOG << LOG_PFX << "Enough clusters are alive, storage is ready" << Endl;
                Counters_.IsReady = 1; // It's sticky and doesn't become 0 after getting to 1
            }
            StopEvent_.WaitT(CheckLivenessEachPeriod_);
        }
    }

    // Builds the lookup key range: one single-column row per key, holding the key as a String value
    // with column id `keyField`, in the same order as `keys`.
    // Rows are captured into a TRowBuffer that the returned shared range keeps alive.
    // Presumably CaptureRow also copies the string payloads (TODO confirm), so the range does not
    // depend on `keys` outliving it.
    NYT::TSharedRange<NYT::NTableClient::TLegacyKey> TYtKeyValueStorage::PrepareRange(const TVector<TString>& keys, i32 keyField) const {
        auto buffer = NYT::New<NYT::NTableClient::TRowBuffer>();
        TVector<NYT::NTableClient::TUnversionedRow> capturedRows;
        capturedRows.reserve(keys.size());
        for (size_t idx = 0; idx < keys.size(); ++idx) {
            NYT::NTableClient::TUnversionedRowBuilder rowBuilder;
            rowBuilder.AddValue(NYT::NTableClient::MakeUnversionedStringValue(keys[idx], keyField));
            capturedRows.push_back(buffer->CaptureRow(rowBuilder.GetRow()));
        }
        return NYT::MakeSharedRange(std::move(capturedRows), std::move(buffer));
    }

    // Starts an asynchronous LookupRows of `keys` in TablePath_ on every cluster in Clusters_.
    //
    // Returns one future per cluster, in Clusters_ iteration order. Each future is bounded by
    // ReadTimeout_ and records the per-cluster read time. Errors other than cancellation are counted
    // and logged, and the future resolves to the lookup error.
    // A cluster whose client could not be created in Start() gets an already-failed future, so the
    // number of futures always equals the cluster count the quorums were validated against.
    TVector<NYT::TFuture<NYT::NApi::IUnversionedRowsetPtr>> TYtKeyValueStorage::SendYtLookupRequests(const NYT::NTableClient::TNameTablePtr& nameTable,
                                                                                                     const NYT::TSharedRange<NYT::NTableClient::TLegacyKey>& keys) const {
        NYT::NApi::TLookupRowsOptions options;
        // Keep row i aligned with key i: a missing key comes back as a null row instead of being dropped.
        options.KeepMissingRows = true;
        options.Timestamp = NYT::NTransactionClient::SyncLastCommittedTimestamp;
        options.Timeout = ReadTimeout_;

        TVector<NYT::TFuture<NYT::NApi::IUnversionedRowsetPtr>> readFutures;
        readFutures.reserve(Clusters_.size());
        with_lock (ClustersMutex_) {
            for (auto&[name, cluster]: Clusters_) {
                if (!cluster.YtClient) {
                    PerClusterCounters_.GetOrCreate({name})->NReadError.Inc();
                    readFutures.emplace_back(NYT::MakeFuture<NYT::NApi::IUnversionedRowsetPtr>(
                        NYT::TError("No YT client for cluster %v", name)));
                    continue;
                }
                TProfileTimer started;
                // The cluster name is captured by value: the callback may run after this call has returned.
                NYT::TCallback callback = BIND([this, clusterName = name, started](const NYT::TErrorOr<NYT::NApi::IUnversionedRowsetPtr>& result) {
                    PerClusterCounters_.GetOrCreate({clusterName})->ReadTimeMs.Update(started.Get().MilliSeconds());
                    if (!result.IsOK() && result.GetCode() != NYT::EErrorCode::Canceled) {
                        PerClusterCounters_.GetOrCreate({clusterName})->NReadError.Inc();
                        WARNING_LOG << LOG_PFX << "Failed to read from: " << clusterName << ": " << ToString(result) << Endl;
                    }
                    return result.ValueOrThrow();
                });
                readFutures.emplace_back(cluster.YtClient->LookupRows(TablePath_, nameTable, keys, options)
                                             .WithTimeout(ReadTimeout_)
                                             .Apply(callback));
            }
        }
        return readFutures;
    }

    std::pair<TYtKeyValueStorage::EReadStatus, TVector<NYT::TFuture<NYT::NApi::IUnversionedRowsetPtr>> /* futuresToWait */> TYtKeyValueStorage::ProcessReadFutures(
        const TVector<NYT::TFuture<NYT::NApi::IUnversionedRowsetPtr>>& readFutures,
        size_t rowsCnt) const {

        size_t cntFailedFutures = 0;

        struct TReadStatus {
            size_t CntFound = 0;
            size_t CntNotFound = 0;
        };

        TVector<TReadStatus> statuses(rowsCnt);

        TVector<NYT::TFuture<NYT::NApi::IUnversionedRowsetPtr>> futuresToWait;
        futuresToWait.reserve(readFutures.size());

        for (const auto& readFuture: readFutures) {
            if (readFuture.IsSet()) {
                if (readFuture.Get().IsOK()) {
                    const auto& rows = readFuture.Get().Value()->GetRows();
                    for (size_t i = 0; i < rowsCnt; i++) {
                        if (rows[i]) {
                            statuses[i].CntFound++;
                        } else {
                            statuses[i].CntNotFound++;
                        }
                    }
                } else {
                    cntFailedFutures++;
                }
            } else {
                futuresToWait.push_back(readFuture);
            }
        }

        bool allFound = true;
        bool allNotFoundOrFound = true;
        for (const auto& status: statuses) {
            allFound &= status.CntFound >= MinReadsForSuccess_;
            allNotFoundOrFound &= status.CntFound + status.CntNotFound >= MinReadsForNotFound_;
        }
        if (allFound) {
            return {EReadStatus::Found, {}};
        }
        if (allNotFoundOrFound) {
            return {EReadStatus::NotFound, {}};
        }

        if (cntFailedFutures + Min(MinReadsForSuccess_, MinReadsForNotFound_) > readFutures.size()) {
            return {EReadStatus::Failed, {}};
        }

        for (const auto& status: statuses) {
            if (MinReadsForSuccess_ + cntFailedFutures + status.CntNotFound > readFutures.size() + status.CntFound && // MinReadsForSuccess_ - status.CntFound > readFutures.size() - cntFailedFutures - status.CntNotFound
                MinReadsForNotFound_ + cntFailedFutures > readFutures.size() + status.CntFound + status.CntNotFound) { // MinReadsForNotFound_ - status.CntFound - status.CntNotFound > readFutures.size() - cntFailedFutures
                return {EReadStatus::Failed, {}};
            }
        }

        return {EReadStatus::Unknown, futuresToWait};
    }

    // Reads `keys` from every configured cluster in parallel and reconciles the per-cluster answers.
    //
    // Returns one entry per key, in the order of `keys`. Entries are filled only when the batch-wide
    // status computed by ProcessReadFutures is Found; when it is NotFound, every entry is empty.
    // NOTE(review): with several keys, a key present on enough clusters is still reported as missing
    // if another key in the same batch is not. Confirm this batch-level semantics is intended.
    // Throws yexception in three cases:
    //   * the quorum cannot be reached (status Failed; NReadError is incremented);
    //   * the status stays undecided after maxTries rounds;
    //   * a returned row has an unexpected type or message id.
    TVector<TMaybe<TYtKeyValueStorage::TRecord>> TYtKeyValueStorage::DoRead(const TVector<TString>& keys) const {
        TProfileTimer started;
        Counters_.NRequests.Inc();
        Counters_.NTotalKeysRequested += keys.size();

        if (keys.empty()) {
            // Nothing to look up. Skip the round trip to every cluster, and never evaluate keys[0] below.
            Counters_.ReadTimeMs.Update(started.Get().MilliSeconds());
            return {};
        }

        auto nameTable = NYT::New<NYT::NTableClient::TNameTable>();
        // Order is important: ids are assigned in registration order
        // (presumably they must follow the table schema's column order; confirm).
        int messageIdField = nameTable->GetIdOrRegisterName("MessageId");
        int timestampField = nameTable->GetIdOrRegisterName("Timestamp");
        int expireTimestampField = nameTable->GetIdOrRegisterName("ExpireTimestamp");
        int messageTypeField = nameTable->GetIdOrRegisterName("MessageType");
        int codecField = nameTable->GetIdOrRegisterName("Codec");
        int bytesField = nameTable->GetIdOrRegisterName("Bytes");

        const auto range = PrepareRange(keys, messageIdField);
        auto readFutures = SendYtLookupRequests(nameTable, range);

        // Each round blocks until at least one more pending future completes, then re-evaluates the
        // quorums. maxTries is a safety net against looping forever.
        auto readStatus = EReadStatus::Unknown;
        TVector<NYT::TFuture<NYT::NApi::IUnversionedRowsetPtr>> futuresToWait = readFutures;
        auto maxTries = readFutures.size() * 2;
        for (size_t i = 0; readStatus == EReadStatus::Unknown && i < maxTries; i++) {
            NYT::TFutureCombinerOptions opts;
            opts.CancelInputOnShortcut = false; // the other lookups must keep running
            NYT::AnySet(futuresToWait, opts).Get();
            std::tie(readStatus, futuresToWait) = ProcessReadFutures(readFutures, keys.size());
        }
        Y_ENSURE(readStatus != EReadStatus::Unknown, "Read status is unknown after " + ToString(maxTries) + " tries. key[0] = " + keys[0]);

        if (readStatus == EReadStatus::Failed) {
            TStringBuilder builder;
            builder << "Failed to read. Errors: \n";
            int ind = 0;
            for (const auto& future: readFutures) {
                if (future.IsSet() && !future.Get().IsOK()) {
                    // Terminate each error with a newline so consecutive errors do not run together.
                    builder << "Error #" << ind << "\n" << ToString(future.Get()) << "\n";
                    ind++;
                }
            }
            Counters_.NReadError.Inc();
            throw yexception() << TString(builder);
        }

        // Only successful answers contribute rows. With KeepMissingRows, every rowset has one row per key.
        TVector<NYT::NApi::IUnversionedRowsetPtr> perClusterRows;
        perClusterRows.reserve(readFutures.size());
        for (const auto& readFuture: readFutures) {
            if (readFuture.IsSet() && readFuture.Get().IsOK()) {
                perClusterRows.push_back(readFuture.Get().ValueOrThrow());
            }
        }

        TVector<TMaybe<TRecord>> resultRecords;
        resultRecords.reserve(keys.size());

        for (size_t i = 0; i < keys.size(); i++) {
            TMaybe<TRecord> record;
            if (readStatus == EReadStatus::Found) {
                // Take the record from the first cluster that has a row for this key.
                for (const auto& rowSet: perClusterRows) {
                    const auto rows = rowSet->GetRows();
                    const auto row = rows[i];
                    if (!row) {
                        continue;
                    }
                    auto messageId = GetString(row[messageIdField]);
                    auto timestamp = GetTimestamp(row[timestampField], false);
                    auto expireTimestamp = GetTimestamp(row[expireTimestampField], true);
                    auto messageType = GetString(row[messageTypeField]);
                    auto codec = static_cast<NTravelProto::EMessageCodec>(GetUint64(row[codecField]));
                    auto bytes = GetString(row[bytesField]);

                    Y_ENSURE(messageId == keys[i], "Unexpected message id (expected: " + keys[i] + ", got: " + messageId + ")");
                    record = TRecord{timestamp, expireTimestamp, Decompress(codec, bytes), messageId, messageType};
                    break;
                }
            }
            if (record.Empty()) {
                Counters_.NNotFound.Inc();
            } else {
                Counters_.NFound.Inc();
            }
            resultRecords.push_back(record);
        }

        Counters_.ReadTimeMs.Update(started.Get().MilliSeconds());

        return resultRecords;
    }

    // Storage-wide counters. ReadTimeMs receives the end-to-end DoRead latency in milliseconds.
    // The list below presumably gives its histogram bucket bounds, in ms (TODO confirm against the
    // counter type).
    TYtKeyValueStorage::TCounters::TCounters()
        : ReadTimeMs({10, 30, 50, 100, 200, 300, 500, 1000, 2000, 3000})
    {
    }

    // Exports the storage-wide counters into `ct`: readiness, request/key totals, found/not-found and
    // error counts, then the ReadTime (Ms) latency series.
    void TYtKeyValueStorage::TCounters::QueryCounters(NMonitor::TCounterTable* ct) const {
        NMonitor::TCounterTable& table = *ct;
        table.insert(MAKE_COUNTER_PAIR(IsReady));
        table.insert(MAKE_COUNTER_PAIR(NRequests));
        table.insert(MAKE_COUNTER_PAIR(NTotalKeysRequested));
        table.insert(MAKE_COUNTER_PAIR(NNotFound));
        table.insert(MAKE_COUNTER_PAIR(NFound));
        table.insert(MAKE_COUNTER_PAIR(NReadError));
        ReadTimeMs.QueryCounters("ReadTime", "Ms", &table);
    }

    // Counters kept separately for every YT cluster (labelled "YtCluster"). ReadTimeMs receives the
    // per-cluster lookup latency in milliseconds. The list below presumably gives its histogram bucket
    // bounds, in ms (TODO confirm against the counter type).
    TYtKeyValueStorage::TPerClusterCounters::TPerClusterCounters()
        : ReadTimeMs({10, 30, 50, 100, 200, 300, 500, 1000, 2000, 3000})
    {
    }

    // Exports one cluster's counters into `ct`: liveness, read-error count, then the ReadTime (Ms)
    // latency series.
    void TYtKeyValueStorage::TPerClusterCounters::QueryCounters(NMonitor::TCounterTable* ct) const {
        NMonitor::TCounterTable& table = *ct;
        table.insert(MAKE_COUNTER_PAIR(IsReady));
        table.insert(MAKE_COUNTER_PAIR(NReadError));
        ReadTimeMs.QueryCounters("ReadTime", "Ms", &table);
    }
}
