#include "watcher.h"

#include <solomon/services/dataproxy/lib/cluster_map/cluster_map.h>
#include <solomon/services/dataproxy/lib/initialization/enable_initialization.h>

#include <solomon/libs/cpp/actors/config/log_component.pb.h>
#include <solomon/libs/cpp/labels/known_keys.h>
#include <solomon/libs/cpp/logging/logging.h>
#include <solomon/libs/cpp/selfmon/selfmon.h>

#include <library/cpp/actors/core/actor_bootstrapped.h>
#include <library/cpp/actors/core/hfunc.h>

#include <util/generic/hash.h>

using namespace NActors;
using namespace yandex::solomon::metabase;

namespace NSolomon::NDataProxy {
namespace {

struct TShardKeyRef {
    const TString* Project{nullptr};
    const TString* Cluster{nullptr};
    const TString* Service{nullptr};

    explicit operator bool() const noexcept {
        return Project && Cluster && Service;
    }

    static TShardKeyRef FromShardStatus(const TShardStatus& status) {
        TShardKeyRef ref;
        for (const auto& label: status.GetLabels()) {
            if (label.key() == NLabels::LABEL_PROJECT) {
                ref.Project = &label.value();
            } else if (label.key() == NLabels::LABEL_CLUSTER) {
                ref.Cluster = &label.value();
            } else if (label.key() == NLabels::LABEL_SERVICE) {
                ref.Service = &label.value();
            }
        }
        return ref;
    }

    void CopyTo(TShardKey* key) const {
        key->Project = *Project;
        key->SubKey.Cluster = *Cluster;
        key->SubKey.Service = *Service;
    }

    bool operator !=(const TShardKey& key) const noexcept {
        return *Project != key.Project || *Cluster != key.SubKey.Cluster || *Service != key.SubKey.Service;
    }
};

/**
 * Events used only inside this file: between a node watcher and its RPC callback,
 * and between node watchers and the cluster watcher.
 */
struct TLocalEvents: private TPrivateEvents {
    enum {
        RequestStatus = SpaceBegin,
        ReceiveStatus,
        ReceiveError,
        NodeUpdated,
        NodeFailed,
        End,
    };
    static_assert(End < SpaceEnd, "too many event types");

    /** Asks a node watcher to poll its node (sent by the watcher to itself via Schedule). */
    struct TRequestStatus: TEventLocal<TRequestStatus, RequestStatus> {
    };

    /** Successful ServerStatus response delivered from the RPC callback to the node watcher. */
    struct TReceiveStatus: TEventLocal<TReceiveStatus, ReceiveStatus> {
        std::unique_ptr<const TServerStatusResponse> Response;

        explicit TReceiveStatus(std::unique_ptr<const TServerStatusResponse> response) noexcept
            : Response(std::move(response))
        {
        }
    };

    /** Failed ServerStatus request delivered from the RPC callback to the node watcher. */
    struct TReceiveError: TEventLocal<TReceiveError, ReceiveError> {
        TMetabaseError Error;

        explicit TReceiveError(TMetabaseError&& error) noexcept
            : Error(std::move(error))
        {
        }
    };

    /** Changes in the shard set of one node, sent by a node watcher to the cluster watcher. */
    struct TNodeUpdated: TEventLocal<TNodeUpdated, NodeUpdated> {
        const ui32 NodeId;
        const TInstant Time;
        size_t ShardCount = 0;          // shards currently known on the node
        size_t TotalShardCount = 0;     // total_partition_count reported by the node
        std::vector<TShardInfoPtr> Updated;
        std::vector<TShardId> Removed;

        TNodeUpdated(ui32 nodeId, TInstant time) noexcept
            : NodeId(nodeId)
            , Time(time)
        {
        }
    };

    /** Poll failure of one node, sent by a node watcher to the cluster watcher. */
    struct TNodeFailed: TEventLocal<TNodeFailed, NodeFailed> {
        const ui32 NodeId;
        const TInstant Time;
        TMetabaseError Error;

        TNodeFailed(ui32 nodeId, TInstant time, TMetabaseError&& error) noexcept
            : NodeId(nodeId)
            , Time(time)
            , Error(std::move(error))
        {
        }
    };
};

/**
 * Periodically downloads information about shards located on a particular node and reports
 * changes (new, changed or removed shards) or request failures to the parent actor.
 */
class TNodeWatcher: public TActorBootstrapped<TNodeWatcher> {
public:
    /**
     * @param nodeId       index of the node, echoed back in every event sent to the parent
     * @param address      node address, used in log messages and stored in every created shard info
     * @param rpc          client used to request the node status (the watcher never deletes it)
     * @param updateDelay  delay between successful polls, also the starting point of the error backoff
     */
    TNodeWatcher(ui32 nodeId, TStringBuf address, IMetabaseRpc* rpc, TDuration updateDelay)
        : NodeId_(nodeId)
        , Address_(address)
        , Rpc_(rpc)
        , InitialUpdateDelay_(updateDelay)
        , UpdateDelay_(updateDelay)
    {
    }

    void Bootstrap(TActorId parentId) {
        ParentId_ = parentId;
        OnRequest();
    }

    /**
     * In this state the actor waits until the next polling interval occurs.
     */
    STATEFN(Sleeping) {
        switch (ev->GetTypeRewrite()) {
            sFunc(TLocalEvents::TRequestStatus, OnRequest);
            sFunc(TEvents::TEvPoison, OnDie);
        }
    }

    /**
     * In this state the actor waits for a response (or an error) from the node.
     */
    STATEFN(WaitingResponse) {
        switch (ev->GetTypeRewrite()) {
            hFunc(TLocalEvents::TReceiveStatus, OnStatus);
            hFunc(TLocalEvents::TReceiveError, OnError);
            hFunc(TEvents::TEvPoison, OnPoison);
        }
    }

    /**
     * The actor enters this state after receiving {@code Poison} event while waiting for an
     * incomplete response from node.
     */
    STATEFN(Dying) {
        switch (ev->GetTypeRewrite()) {
            sFunc(TLocalEvents::TReceiveStatus, OnDie);
            sFunc(TLocalEvents::TReceiveError, OnDie);
        }
    }

    void OnRequest() {
        MON_TRACE(MetabaseWatcher, "request status of {" << Address_ << "}");
        Become(&TThis::WaitingResponse);

        TActorId selfId = SelfId();
        auto* actorSystem = TActorContext::ActorSystem();

        TServerStatusRequest req;
        // hash of the last received shard set, so the node can tell whether anything changed
        req.SetShardIdsHash(ShardIdsHash_);
        // give the node twice the current polling interval to answer
        req.SetDeadlineMillis((TActivationContext::Now() + 2 * UpdateDelay_).MilliSeconds());

        // The callback captures only the actor id and the actor system (not `this`) and reports
        // the outcome by sending an event, so the result is processed in the actor's own context.
        Rpc_->ServerStatus(req)
                .Subscribe([selfId, actorSystem](TAsyncMetabaseStatusResponse respFuture) {
                    std::unique_ptr<IEventBase> event;
                    try {
                        auto valueOrError = respFuture.ExtractValue();
                        if (valueOrError.Success()) {
                            event = std::make_unique<TLocalEvents::TReceiveStatus>(valueOrError.Extract());
                        } else {
                            event = std::make_unique<TLocalEvents::TReceiveError>(valueOrError.ExtractError());
                        }
                    } catch (...) {
                        event = std::make_unique<TLocalEvents::TReceiveError>(TMetabaseError{CurrentExceptionMessage()});
                    }
                    actorSystem->Send(selfId, event.release());
                });
    }

    void OnStatus(TLocalEvents::TReceiveStatus::TPtr& ev) {
        MON_TRACE(MetabaseWatcher, "got update from {" << Address_ << '}');
        UpdateState(std::move(ev->Get()->Response));

        // restore update interval and schedule next update
        UpdateDelay_ = InitialUpdateDelay_;
        Schedule(UpdateDelay_, new TLocalEvents::TRequestStatus{});
        Become(&TThis::Sleeping);
    }

    void OnError(TLocalEvents::TReceiveError::TPtr& ev) {
        MON_WARN(MetabaseWatcher, "cannot get update from {" << Address_ << "}, error: " << ev->Get()->Error);

        // notify parent about error
        auto event = std::make_unique<TLocalEvents::TNodeFailed>(
                NodeId_,
                TActivationContext::Now(),
                std::move(ev->Get()->Error));
        Send(ParentId_, event.release());

        // slow down exponentially, add up to 2s of jitter, and cap the delay at 10s
        TDuration jitter = TDuration::MilliSeconds(RandomNumber(2'000ull));
        UpdateDelay_ = Min(1.5 * UpdateDelay_ + jitter, TDuration::Seconds(10));

        // schedule next update
        MON_TRACE(MetabaseWatcher, "will retry request status of {" << Address_ << "} after " << UpdateDelay_);
        Schedule(UpdateDelay_, new TLocalEvents::TRequestStatus{});
        Become(&TThis::Sleeping);
    }

    void OnPoison(TEvents::TEvPoison::TPtr&) {
        // there is a not yet completed request, so we have to wait for it
        Become(&TThis::Dying);
    }

    void OnDie() {
        Send(ParentId_, new TEvents::TEvPoisonTaken{});
        PassAway();
    }

    /**
     * Applies a successful status response to the local shard map and, if anything changed
     * (or this is the first successful response), sends TNodeUpdated to the parent.
     */
    void UpdateState(std::unique_ptr<const TServerStatusResponse>&& respPtr) {
        const auto& resp = *respPtr;

        // An unchanged hash means there is nothing new to report. The first successful response is
        // the exception: the parent must always receive it, even when the node's hash happens to
        // equal the initial (zero) one, otherwise this node would never be reported as available.
        if (!Initial_ && ShardIdsHash_ == resp.GetShardIdsHash()) {
            return;
        }
        ShardIdsHash_ = resp.GetShardIdsHash();

        auto event = std::make_unique<TLocalEvents::TNodeUpdated>(NodeId_, TActivationContext::Now());

        // ids of all shards present in the response, including ones skipped below because of an
        // incomplete key (for such shards the previously stored info, if any, is kept)
        std::unordered_set<TShardId> newIds;
        newIds.reserve(resp.PartitionStatusSize());

        // (1) process new and updated shards
        for (const TShardStatus& status: resp.GetPartitionStatus()) {
            ui32 shardId = status.GetNumId();
            newIds.insert(shardId);

            auto shardKey = TShardKeyRef::FromShardStatus(status);
            if (!shardKey) {
                MON_WARN(MetabaseWatcher, "got incomplete shard key from {" << Address_ << "}, shardId=" << shardId);
                continue;
            }

            auto it = Shards_.find(shardId);
            if (it == Shards_.end() || status.GetReady() != it->second->Ready || shardKey != it->second->Key) {
                // if it's a new shard on this node or shard configuration was changed, then
                // create a new shard info and replace previously stored one
                auto shard = CreateShardInfo(status, shardKey);
                Shards_[shardId] = shard;
                event->Updated.push_back(std::move(shard));
            }
        }

        // (2) process deleted shards: drop every known shard that is absent from the response.
        // Comparing map and response sizes is not enough to detect this, because shards skipped
        // for an incomplete key are counted in the response but may be missing from Shards_.
        for (auto it = Shards_.begin(); it != Shards_.end(); ) {
            if (newIds.count(it->first) == 0) {
                event->Removed.push_back(it->first);
                it = Shards_.erase(it);
            } else {
                ++it;
            }
        }

        event->TotalShardCount = resp.total_partition_count();

        // always send the initial response, so that the cluster watcher
        // learns that this node is available
        if (Initial_ || !event->Updated.empty() || !event->Removed.empty()) {
            Initial_ = false;
            event->ShardCount = Shards_.size();
            Send(ParentId_, event.release());
        }
    }

    /** Builds a new shard info located at this node from a status with a complete key. */
    TShardInfoPtr CreateShardInfo(const TShardStatus& status, const TShardKeyRef& keyRef) const {
        TShardKey key;
        keyRef.CopyTo(&key);

        auto shard = std::make_shared<TShardInfo>(
                status.GetNumId(),
                Address_,
                std::move(key),
                status.GetReady());

        return shard;
    }

private:
    const ui32 NodeId_;
    const TString Address_;
    IMetabaseRpc* Rpc_;
    const TDuration InitialUpdateDelay_;
    TDuration UpdateDelay_;                  // current polling delay, grows on errors
    TActorId ParentId_;
    ui64 ShardIdsHash_{0};                   // hash from the last applied response
    std::unordered_map<TShardId, TShardInfoPtr> Shards_;
    bool Initial_ = true;                    // no TNodeUpdated has been sent to the parent yet
};

/**
 * Tells whether the known shards cover at least 90% of the total shard count.
 * An unknown (zero) total is never considered enough.
 */
bool IsEnoughKnownShards(size_t totalShardCount, size_t knownShardsSum) {
    if (totalShardCount == 0) {
        return false;
    }
    const double knownRatio = static_cast<double>(knownShardsSum) / static_cast<double>(totalShardCount);
    return knownRatio >= 0.9;
}

/**
 * Creates a node watcher for each Metabase node in the cluster and aggregates information about
 * shard locations. Also handles update subscriptions and allows resolving shard locations on demand.
 *
 * Initialization is finished once every node has reported at least once (successfully or with an
 * error) and the shards from the nodes' first successful updates cover enough of the total shard
 * count (see IsEnoughKnownShards). The condition is re-checked on later events until it holds.
 */
class TClusterWatcher: public TActorBootstrapped<TClusterWatcher>, public TEnableInitialization<TClusterWatcher> {
    struct TNodeState {
        TString Address;
        TActorId Watcher;
        TInstant UpdatedAt;                       // last successful update, zero until the first one
        size_t ShardCount{0};                     // shard count from the last successful update
        TInstant LastErrorAt;                     // last error time, zero if none (or cleared)
        std::optional<TMetabaseError> LastError;
    };

    struct TShardState {
        TShardInfoPtr Info;
        ui32 NodeId;                              // node that most recently reported this shard
    };

public:
    TClusterWatcher(IMetabaseClusterRpcPtr rpc, std::vector<TString> addresses, TDuration updateDelay)
        : Rpc_(std::move(rpc))
        , Addresses_(std::move(addresses))
        , Nodes_(Addresses_.size())
        , UpdateDelay_(updateDelay)
    {
    }

    void Bootstrap() {
        for (ui32 id = 0; id < Addresses_.size(); id++) {
            IMetabaseRpc* rpc = Rpc_->Get(Addresses_[id]);
            Nodes_[id].Address = Addresses_[id];
            Nodes_[id].Watcher = Register(new TNodeWatcher(id, Addresses_[id], rpc, UpdateDelay_));
        }

        Become(&TThis::Normal);
    }

    /**
     * In normal state the actor processes incoming updates or errors from node watchers
     * and handles subscription/resolving requests as usual.
     */
    STATEFN(Normal) {
        switch (ev->GetTypeRewrite()) {
            hFunc(TMetabaseWatcherEvents::TSubscribe, OnSubscribe);
            hFunc(TMetabaseWatcherEvents::TResolve, OnResolve);
            hFunc(TLocalEvents::TNodeUpdated, OnNodeUpdated);
            hFunc(TLocalEvents::TNodeFailed, OnNodeFailed);
            hFunc(TInitializationEvents::TSubscribe, OnInitializationSubscribe);
            hFunc(NSelfMon::TEvPageDataReq, OnSelfMon);
            hFunc(TEvents::TEvPoison, OnPoison);
        }
    }

    /**
     * In dying state the actor ignores all updates and requests, and waits until all its
     * child node watchers have taken the poison.
     */
    STATEFN(Dying) {
        switch (ev->GetTypeRewrite()) {
            hFunc(TEvents::TEvPoisonTaken, OnPoisonTaken);
        }
    }

    void OnSubscribe(TMetabaseWatcherEvents::TSubscribe::TPtr& ev) {
        bool inserted = Subscribers_.insert(ev->Sender).second;
        if (inserted && !Shards_.empty()) {
            // a new subscriber receives the full current state at once
            auto event = std::make_unique<TMetabaseWatcherEvents::TStateChanged>();
            for (auto& [id, s]: Shards_) {
                event->Updated.push_back(s.Info);
            }
            Send(ev->Sender, event.release(), 0, ev->Cookie);
        }
    }

    void OnResolve(TMetabaseWatcherEvents::TResolve::TPtr& ev) {
        // unknown shard ids are silently omitted from the result
        auto event = std::make_unique<TMetabaseWatcherEvents::TResolveResult>();
        for (TShardId id: ev->Get()->Ids) {
            if (auto it = Shards_.find(id); it != Shards_.end()) {
                event->Locations.push_back(TShardLocation{id, it->second.Info->Address});
            }
        }
        Send(ev->Sender, event.release(), 0, ev->Cookie);
    }

    void OnNodeUpdated(TLocalEvents::TNodeUpdated::TPtr& ev) {
        ui32 nodeId = ev->Get()->NodeId;
        TNodeState& node = Nodes_[nodeId];

        if (!node.UpdatedAt && !node.LastErrorAt) {
            // the very first report (of any kind) from this node
            ++InitializedNodes_;
        }
        if (!node.UpdatedAt) {
            // The first successful update from this node: count its shards even if the node has
            // failed before, otherwise a node whose first poll failed would never contribute.
            ShardCountSum_ += ev->Get()->Updated.size();
            TotalShardCount_ = std::max(ev->Get()->TotalShardCount, TotalShardCount_);
        }
        MaybeFinishInitialization();

        node.UpdatedAt = ev->Get()->Time;
        node.ShardCount = ev->Get()->ShardCount;

        if (node.LastErrorAt && (TActivationContext::Now() - node.LastErrorAt) > TDuration::Hours(1)) {
            // clear the last error if more than an hour has passed since its appearance
            node.LastErrorAt = TInstant::Zero();
            node.LastError.reset();
        }

        auto& updated = ev->Get()->Updated;
        for (const auto& s: updated) {
            TShardState& shardState = Shards_[s->Id];
            shardState.Info = s;
            shardState.NodeId = nodeId;
        }

        std::vector<TShardId> actuallyRemoved;
        for (TShardId id: ev->Get()->Removed) {
            if (auto it = Shards_.find(id); it != Shards_.end() && it->second.NodeId == nodeId) {
                // erase shard from the map only if it is not moved yet
                Shards_.erase(it);
                actuallyRemoved.push_back(id);
            }
        }

        if (!updated.empty() || !actuallyRemoved.empty()) {
            MON_TRACE(MetabaseWatcher,
                    "updated " << updated.size() << " and removed " << actuallyRemoved.size()
                    << " shards on {" << node.Address << '}');

            for (auto it = Subscribers_.begin(), end = Subscribers_.end(); it != end; ) {
                auto event = std::make_unique<TMetabaseWatcherEvents::TStateChanged>();
                const TActorId& subscriber = *it++;
                if (it == end) {
                    // do not copy vectors for the last subscriber
                    event->Updated = std::move(updated); // NOLINT(bugprone-use-after-move)
                    event->Removed = std::move(actuallyRemoved); // NOLINT(bugprone-use-after-move)
                } else {
                    event->Updated = updated;
                    event->Removed = actuallyRemoved;
                }
                Send(subscriber, event.release());
            }
        }
    }

    void OnNodeFailed(TLocalEvents::TNodeFailed::TPtr& ev) {
        TNodeState& node = Nodes_[ev->Get()->NodeId];

        if (!node.UpdatedAt && !node.LastErrorAt) {
            // the very first report (of any kind) from this node
            ++InitializedNodes_;
            MaybeFinishInitialization();
        }

        node.LastErrorAt = ev->Get()->Time;
        node.LastError = std::move(ev->Get()->Error);
    }

    void OnSelfMon(const NSelfMon::TEvPageDataReq::TPtr& ev) {
        using namespace yandex::monitoring::selfmon;

        Page page;
        auto* table = page.mutable_component()->mutable_table();
        table->set_numbered(true);

        auto* addressColumn = table->add_columns();
        addressColumn->set_title("Address");
        auto* addressValues = addressColumn->mutable_string();

        auto* lastUpdateColumn = table->add_columns();
        lastUpdateColumn->set_title("Last Update (ago)");
        auto* lastUpdateValues = lastUpdateColumn->mutable_duration();

        auto* shardCountColumn = table->add_columns();
        shardCountColumn->set_title("Shard Count");
        auto* shardCountValues = shardCountColumn->mutable_uint64();

        auto* lastErrorColumn = table->add_columns();
        lastErrorColumn->set_title("Last Error Message");
        auto* lastErrorValues = lastErrorColumn->mutable_string();

        auto* lastErrorTimeColumn = table->add_columns();
        lastErrorTimeColumn->set_title("Last Error Time (ago)");
        auto* lastErrorTimeValues = lastErrorTimeColumn->mutable_duration();

        // node timestamps are taken from TActivationContext::Now(), so use the same clock here
        auto now = TActivationContext::Now();
        for (const auto& node: Nodes_) {
            addressValues->add_values(node.Address);
            shardCountValues->add_values(node.ShardCount);

            if (node.UpdatedAt) {
                lastUpdateValues->add_values((now - node.UpdatedAt).GetValue());
            } else {
                lastUpdateValues->add_values(TDuration::Max().GetValue());
            }

            if (node.LastError.has_value()) {
                lastErrorValues->add_values(node.LastError->Message);
                lastErrorTimeValues->add_values((now - node.LastErrorAt).GetValue());
            } else {
                lastErrorValues->add_values();
                lastErrorTimeValues->add_values(TDuration::Max().GetValue());
            }
        }

        Send(ev->Sender, new NSelfMon::TEvPageDataResp{std::move(page)});
    }

    void OnPoison(TEvents::TEvPoison::TPtr& ev) {
        Become(&TThis::Dying);
        Poisoner_ = ev->Sender;

        if (Nodes_.empty()) {
            // there are no node watchers, so no TEvPoisonTaken would ever arrive
            Send(Poisoner_, new TEvents::TEvPoisonTaken{});
            PassAway();
            return;
        }

        for (auto& node: Nodes_) {
            Send(node.Watcher, new TEvents::TEvPoison{});
        }
    }

    void OnPoisonTaken(TEvents::TEvPoisonTaken::TPtr& ev) {
        if (++PoisonedWatchers_ == Nodes_.size()) {
            Send(Poisoner_, ev->Release().Release());
            PassAway();
        }
    }

private:
    /**
     * Calls FinishInitialization() exactly once, as soon as every node has reported at least once
     * and enough shards are known.
     */
    void MaybeFinishInitialization() {
        if (InitFinished_) {
            return;
        }
        if (InitializedNodes_ == Nodes_.size() && IsEnoughKnownShards(TotalShardCount_, ShardCountSum_)) {
            InitFinished_ = true;
            FinishInitialization();
        }
    }

private:
    size_t InitializedNodes_ = 0;     // nodes that reported at least once (update or error)
    size_t TotalShardCount_ = 0;      // max total shard count seen in the nodes' first successful updates
    size_t ShardCountSum_ = 0;        // sum of shards from the nodes' first successful updates
    bool InitFinished_ = false;       // FinishInitialization() has already been called

    std::set<TActorId> Subscribers_;
    IMetabaseClusterRpcPtr Rpc_;
    std::vector<TString> Addresses_;
    std::vector<TNodeState> Nodes_;
    std::unordered_map<TShardId, TShardState> Shards_;
    TDuration UpdateDelay_;

    TActorId Poisoner_;
    ui32 PoisonedWatchers_{0};
};

} // namespace

/**
 * Creates the actor that watches all given Metabase nodes and aggregates shard locations.
 *
 * @param rpc          cluster RPC used to obtain a per-node client for each address
 * @param addresses    node addresses; one node watcher is created per address
 * @param updateDelay  polling interval of each node watcher
 */
std::unique_ptr<IActor> MetabaseClusterWatcher(
        IMetabaseClusterRpcPtr rpc,
        std::vector<TString> addresses,
        TDuration updateDelay)
{
    std::unique_ptr<IActor> watcher = std::make_unique<TClusterWatcher>(
            std::move(rpc),
            std::move(addresses),
            updateDelay);
    return watcher;
}

} // namespace NSolomon::NDataProxy
