#include "watcher.h"

#include <solomon/libs/cpp/actors/config/log_component.pb.h>
#include <solomon/libs/cpp/logging/logging.h>
#include <solomon/libs/cpp/selfmon/selfmon.h>

#include <library/cpp/actors/core/actor_bootstrapped.h>
#include <library/cpp/actors/core/hfunc.h>
#include <library/cpp/containers/absl_flat_hash/flat_hash_map.h>
#include <library/cpp/containers/absl_flat_hash/flat_hash_set.h>

using namespace NActors;
using namespace NSolomon::NMemStore;
using namespace yandex::monitoring::memstore;

namespace NSolomon::NDataProxy {
namespace {

/**
 * Events which are used only inside this translation unit: for talking between
 * a node watcher, its RPC completion callback and the cluster watcher.
 */
struct TLocalEvents: private TPrivateEvents {
    // Event type ids are allocated sequentially starting from SpaceBegin
    // (presumably provided by TPrivateEvents together with SpaceEnd).
    // End is not an event, it only marks the upper bound of the used range.
    enum {
        RequestStatus = SpaceBegin,
        ReceiveResponse,
        ReceiveError,
        NodeUpdated,
        NodeFailed,
        End,
    };
    static_assert(End < SpaceEnd, "too many event types");

    /**
     * Timer event which a node watcher schedules to itself to start the next poll.
     */
    struct TRequestStatus: public TEventLocal<TRequestStatus, RequestStatus> {
    };

    /**
     * Delivered to a node watcher when its ListShards call has completed successfully.
     */
    struct TReceiveResponse: public TEventLocal<TReceiveResponse, ReceiveResponse> {
        // response message as it was received from the node
        const ListShardsResponse Response;

        explicit TReceiveResponse(ListShardsResponse&& response) noexcept
            : Response{std::move(response)}
        {
        }
    };

    /**
     * Delivered to a node watcher when its ListShards call has failed
     * (either with an error status or with an exception).
     */
    struct TReceiveError: public TEventLocal<TReceiveError, ReceiveError> {
        // not const: the handler moves the status out into a TNodeFailed event
        NGrpc::TGrpcStatus Status;

        explicit TReceiveError(NGrpc::TGrpcStatus&& status) noexcept
            : Status{std::move(status)}
        {
        }
    };

    /**
     * Sent by a node watcher to its parent when the set of shards known on the node has changed.
     */
    struct TNodeUpdated: public TEventLocal<TNodeUpdated, NodeUpdated> {
        // id of the node which the update belongs to
        const ui32 NodeId;
        // time at which the event was created by the node watcher
        const TInstant Time;
        // total number of shards known on the node after this update was applied
        size_t ShardCount{0};
        // shards which were not known on the node before this update
        std::vector<TShardInfoPtr> Updated;
        // ids of shards which are no longer reported by the node
        std::vector<TShardId> Removed;

        TNodeUpdated(ui32 nodeId, TInstant time) noexcept
            : NodeId{nodeId}
            , Time{time}
        {
        }
    };

    /**
     * Sent by a node watcher to its parent when a poll of the node has failed.
     */
    struct TNodeFailed: public TEventLocal<TNodeFailed, NodeFailed> {
        // id of the node which failed to respond
        const ui32 NodeId;
        // time at which the failure was handled by the node watcher
        const TInstant Time;
        // not const: the receiver moves the status out to store it
        NGrpc::TGrpcStatus Error;

        TNodeFailed(ui32 nodeId, TInstant time, NGrpc::TGrpcStatus&& error) noexcept
            : NodeId{nodeId}
            , Time{time}
            , Error{std::move(error)}
        {
        }
    };
};

/**
 * Polls a single MemStore node for the list of shards located on it and reports
 * every detected change (or polling failure) to the parent actor.
 */
class TNodeWatcher: public TActorBootstrapped<TNodeWatcher> {
public:
    TNodeWatcher(ui32 nodeId, TStringBuf address, IMemStoreRpc* rpc, TDuration updateDelay)
        : NodeId_(nodeId)
        , Address_(address)
        , Rpc_(rpc)
        , InitialUpdateDelay_(updateDelay)
        , UpdateDelay_(updateDelay)
    {
    }

    /**
     * Remembers the parent (the receiver of all notifications) and issues the first request right away.
     */
    void Bootstrap(TActorId parentId) {
        ParentId_ = parentId;
        OnRequest();
    }

    /**
     * Idle state between two polls: the actor waits for the scheduled
     * {@code TRequestStatus} event. Poison terminates the actor immediately.
     */
    STATEFN(Sleeping) {
        switch (ev->GetTypeRewrite()) {
            sFunc(TEvents::TEvPoison, OnDie);
            sFunc(TLocalEvents::TRequestStatus, OnRequest);
        }
    }

    /**
     * A request is in flight: the actor waits for its outcome (a response or an error).
     * Poison only switches the actor into the {@code Dying} state.
     */
    STATEFN(WaitingResponse) {
        switch (ev->GetTypeRewrite()) {
            hFunc(TEvents::TEvPoison, OnPoison);
            hFunc(TLocalEvents::TReceiveError, OnError);
            hFunc(TLocalEvents::TReceiveResponse, OnResponse);
        }
    }

    /**
     * Poison was received while a request was in flight: whatever outcome
     * of that request arrives, the actor terminates without processing it.
     */
    STATEFN(Dying) {
        switch (ev->GetTypeRewrite()) {
            sFunc(TLocalEvents::TReceiveError, OnDie);
            sFunc(TLocalEvents::TReceiveResponse, OnDie);
        }
    }

private:
    /**
     * Starts a ListShards call carrying the last seen generation and
     * switches to the {@code WaitingResponse} state.
     */
    void OnRequest() {
        MON_TRACE(MemStoreWatcher, "request status of {" << Address_ << "}");
        Become(&TThis::WaitingResponse);

        ListShardsRequest request;
        request.set_generation(Generation_);

        // the callback does not touch the actor's state: it only converts
        // the outcome of the call into an event and sends it to this actor
        auto onComplete = [self = SelfId(), system = TActorContext::ActorSystem()](TListShardsAsyncResponse future) {
            std::unique_ptr<IEventBase> outcome;
            try {
                auto result = future.ExtractValue();
                if (!result.Success()) {
                    outcome = std::make_unique<TLocalEvents::TReceiveError>(result.ExtractError());
                } else {
                    outcome = std::make_unique<TLocalEvents::TReceiveResponse>(result.Extract());
                }
            } catch (...) {
                // an exception is reported in the same way as a failed call
                outcome = std::make_unique<TLocalEvents::TReceiveError>(
                    NGrpc::TGrpcStatus::Internal(CurrentExceptionMessage()));
            }
            system->Send(self, outcome.release());
        };

        Rpc_->ListShards(request).Subscribe(std::move(onComplete));
    }

    /**
     * Applies a response which carries a new generation (if it is well-formed),
     * then goes to sleep for the initial (not slowed down) delay.
     */
    void OnResponse(TLocalEvents::TReceiveResponse::TPtr& ev) {
        MON_TRACE(MemStoreWatcher, "got update from {" << Address_ << '}');

        const auto& response = ev->Get()->Response;
        const bool generationChanged = (Generation_ != response.generation());
        if (generationChanged) {
            // ids and keys are parallel arrays, so they must be of the same length
            const bool wellFormed = (response.num_ids_size() == response.shard_keys_size());
            if (wellFormed) {
                UpdateState(response);
            } else {
                MON_ERROR(MemStoreWatcher, "received invalid response from " << Address_
                        << " num_ids_size(" << response.num_ids_size() << ") !="
                        << " shard_key_size(" << response.shard_keys_size() << ')');
            }
        }

        // the node is alive, so forget about any previous slowdown
        UpdateDelay_ = InitialUpdateDelay_;
        Schedule(UpdateDelay_, new TLocalEvents::TRequestStatus{});
        Become(&TThis::Sleeping);
    }

    /**
     * Reports the failure to the parent, increases the polling delay and goes to sleep.
     */
    void OnError(TLocalEvents::TReceiveError::TPtr& ev) {
        auto* failure = ev->Get();
        MON_WARN(MemStoreWatcher, "cannot get update from {" << Address_ << "}, error: " << failure->Status.Msg);

        // the status is not needed here anymore, hand it over to the parent
        Send(ParentId_, new TLocalEvents::TNodeFailed(
            NodeId_,
            TActivationContext::Now(),
            std::move(failure->Status)));

        // grow the delay by half plus up to 2 seconds of random jitter, but not above 10 seconds
        // NOTE(review): jitter is added before clamping, so no jitter remains once the cap is reached
        const TDuration jitter = TDuration::MilliSeconds(RandomNumber(2'000ull));
        const TDuration slowedDown = 1.5 * UpdateDelay_ + jitter;
        UpdateDelay_ = Min(slowedDown, TDuration::Seconds(10));

        MON_DEBUG(MemStoreWatcher, "will retry request status of {" << Address_ << "} after " << UpdateDelay_);
        Schedule(UpdateDelay_, new TLocalEvents::TRequestStatus{});
        Become(&TThis::Sleeping);
    }

    /**
     * Poison has arrived while a request is in flight: postpone termination
     * until the outcome of that request is delivered.
     */
    void OnPoison(TEvents::TEvPoison::TPtr&) {
        Become(&TThis::Dying);
    }

    /**
     * Acknowledges the poison to the parent and terminates the actor.
     */
    void OnDie() {
        Send(ParentId_, new TEvents::TEvPoisonTaken{});
        PassAway();
    }

    /**
     * Synchronizes the local set of shards with the given response and notifies
     * the parent if at least one shard was added or removed.
     */
    void UpdateState(const ListShardsResponse& resp) {
        Generation_ = resp.generation();
        auto update = std::make_unique<TLocalEvents::TNodeUpdated>(NodeId_, TActivationContext::Now());

        // (1) register shards which were not known on this node before
        const int keyCount = resp.shard_keys_size();
        for (int idx = 0; idx < keyCount; ++idx) {
            TShardId numId = resp.num_ids(idx);
            if (Shards_.contains(numId)) {
                // already known shards are left untouched
                continue;
            }

            const auto& protoKey = resp.shard_keys(idx);
            TShardKey key{protoKey.project(), {protoKey.cluster(), protoKey.service()}};

            auto info = std::make_shared<TShardInfo>(numId, Address_, key, true);
            Shards_.emplace(numId, info);
            update->Updated.push_back(std::move(info));
        }

        // (2) drop shards which are not reported anymore; after step (1) every reported
        //     id is present locally, so a scan is needed only if the sizes differ
        const auto reportedCount = static_cast<size_t>(resp.num_ids_size());
        if (Shards_.size() != reportedCount) {
            const absl::flat_hash_set<TShardId> reported(resp.num_ids().begin(), resp.num_ids().end());

            auto it = Shards_.begin();
            while (it != Shards_.end()) {
                if (reported.contains(it->first)) {
                    ++it;
                    continue;
                }
                update->Removed.push_back(it->first);
                // advance before erasing, the erased iterator must not be used afterwards
                Shards_.erase(it++);
            }
        }

        const bool nothingChanged = update->Updated.empty() && update->Removed.empty();
        if (!nothingChanged) {
            update->ShardCount = Shards_.size();
            Send(ParentId_, update.release());
        }
    }

private:
    const ui32 NodeId_;
    const TString Address_;
    IMemStoreRpc* Rpc_;
    // delay between polls while the node responds successfully
    const TDuration InitialUpdateDelay_;
    // current delay, grows after each failure and is restored after a success
    TDuration UpdateDelay_;
    TActorId ParentId_;
    // generation from the last applied response, sent with every request
    ui64 Generation_{0};
    // shards currently known to be located on the node
    absl::flat_hash_map<TShardId, TShardInfoPtr> Shards_;
};

/**
 * Creates a node watcher for each MemStore node in the cluster and aggregates information about shard locations.
 * Also handles update subscriptions and allows forced shard resolving.
 */
class TClusterWatcher: public TActorBootstrapped<TClusterWatcher> {
    /**
     * Everything that is known about one node; the index in {@code Nodes_} is the node id.
     */
    struct TNodeState {
        TString Address;
        TActorId Watcher;
        size_t ShardCount{0};
        TInstant UpdatedAt;
        TInstant LastErrorAt;
        std::optional<NGrpc::TGrpcStatus> LastError;
    };

    /**
     * Shard description together with the id of the node which reported it last.
     */
    struct TShardState {
        TShardInfoPtr Info;
        ui32 NodeId;
    };

public:
    TClusterWatcher(std::shared_ptr<IMemStoreClusterRpc> rpc, std::vector<TString> addresses, TDuration updateDelay)
        : Rpc_(std::move(rpc))
        , Nodes_(addresses.size())
        , UpdateDelay_(updateDelay)
    {
        for (size_t i = 0, size = addresses.size(); i < size; i++) {
            // `addresses` is our own copy (taken by value), so its elements can be moved out
            Nodes_[i].Address = std::move(addresses[i]);
        }
    }

    /**
     * Registers one node watcher per configured address and switches to the normal state.
     */
    void Bootstrap() {
        for (ui32 id = 0; id < Nodes_.size(); id++) {
            const TString& address = Nodes_[id].Address;
            IMemStoreRpc* rpc = Rpc_->Get(address);
            Y_VERIFY(rpc, "unknown address %s in MemStore RPC", address.c_str());
            Nodes_[id].Watcher = Register(new TNodeWatcher(id, address, rpc, UpdateDelay_), TMailboxType::Simple);
        }
        Become(&TThis::Normal);
    }

    /**
     * In the normal state the actor processes incoming updates or errors from node watchers
     * and handles subscription/resolving requests as usual.
     */
    STATEFN(Normal) {
        switch (ev->GetTypeRewrite()) {
            hFunc(TMemStoreWatcherEvents::TSubscribe, OnSubscribe);
            hFunc(TMemStoreWatcherEvents::TResolve, OnResolve);
            hFunc(TLocalEvents::TNodeUpdated, OnNodeUpdated);
            hFunc(TLocalEvents::TNodeFailed, OnNodeFailed);
            hFunc(NSelfMon::TEvPageDataReq, OnSelfMon);
            hFunc(TEvents::TEvPoison, OnPoison);
        }
    }

    /**
     * In the dying state the actor ignores all updates and requests and only waits
     * until all of its child node watchers have taken the poison.
     */
    STATEFN(Dying) {
        switch (ev->GetTypeRewrite()) {
            hFunc(TEvents::TEvPoisonTaken, OnPoisonTaken);
        }
    }

    /**
     * Registers the sender as a subscriber. A newly added subscriber immediately
     * receives all currently known shards (if there are any).
     */
    void OnSubscribe(TMemStoreWatcherEvents::TSubscribe::TPtr& ev) {
        bool inserted = Subscribers_.insert(ev->Sender).second;
        if (inserted && !Shards_.empty()) {
            std::vector<TShardInfoPtr> shards;
            shards.reserve(Shards_.size());
            for (const auto& [id, s]: Shards_) {
                shards.emplace_back(s.Info);
            }
            Send(ev->Sender, new TMemStoreWatcherEvents::TStateChanged{std::move(shards), {}}, 0, ev->Cookie);
        }
    }

    /**
     * Replies with locations of the requested shards; unknown shard ids are
     * silently omitted from the result.
     */
    void OnResolve(TMemStoreWatcherEvents::TResolve::TPtr& ev) {
        std::vector<TShardLocation> locations;
        locations.reserve(ev->Get()->Ids.size());
        for (TShardId id: ev->Get()->Ids) {
            if (auto it = Shards_.find(id); it != Shards_.end()) {
                locations.emplace_back(id, it->second.Info->Address);
            }
        }
        // the vector is not used after this point, so hand it over instead of copying
        Send(ev->Sender, new TMemStoreWatcherEvents::TResolveResult{std::move(locations)}, 0, ev->Cookie);
    }

    /**
     * Applies changes reported by a node watcher to the aggregated shard map
     * and forwards the effective changes to all subscribers.
     */
    void OnNodeUpdated(TLocalEvents::TNodeUpdated::TPtr& ev) {
        ui32 nodeId = ev->Get()->NodeId;
        TNodeState& node = Nodes_[nodeId];
        node.UpdatedAt = ev->Get()->Time;
        node.ShardCount = ev->Get()->ShardCount;

        if (node.LastErrorAt && (TActivationContext::Now() - node.LastErrorAt) > TDuration::Hours(1)) {
            // clear the last error if more than an hour has passed since its appearance
            node.LastErrorAt = TInstant::Zero();
            node.LastError.reset();
        }

        // the latest report wins: a shard reported by this node is now attributed to it
        auto& updated = ev->Get()->Updated;
        for (const auto& s: updated) {
            TShardState& shardState = Shards_[s->Id];
            shardState.Info = s;
            shardState.NodeId = nodeId;
        }

        std::vector<TShardId> actuallyRemoved;
        for (TShardId id: ev->Get()->Removed) {
            if (auto it = Shards_.find(id); it != Shards_.end() && it->second.NodeId == nodeId) {
                // erase shard from the map only if it is not moved yet
                Shards_.erase(it);
                actuallyRemoved.push_back(id);
            }
        }

        if (!updated.empty() || !actuallyRemoved.empty()) {
            MON_TRACE(MemStoreWatcher,
                     "updated " << updated.size() << " and removed " << actuallyRemoved.size()
                                << " shards on {" << node.Address << '}');

            // send copies for all subscribers except the first one
            if (Subscribers_.size() > 1) {
                for (auto it = std::next(Subscribers_.begin()), end = Subscribers_.end(); it != end; ++it) {
                    Send(*it, new TMemStoreWatcherEvents::TStateChanged{updated, actuallyRemoved});
                }
            }

            // do not copy vectors for the first subscriber for small optimization;
            // this must stay the last use of both vectors, because they are moved out here
            if (auto it = Subscribers_.begin(); it != Subscribers_.end()) {
                Send(*it, new TMemStoreWatcherEvents::TStateChanged{std::move(updated), std::move(actuallyRemoved)});
            }
        }
    }

    /**
     * Remembers the last error (and its time) reported by a node watcher.
     */
    void OnNodeFailed(TLocalEvents::TNodeFailed::TPtr& ev) {
        TNodeState& node = Nodes_[ev->Get()->NodeId];
        node.LastErrorAt = ev->Get()->Time;
        node.LastError = std::move(ev->Get()->Error);
    }

    /**
     * Renders the self-monitoring page: a table with one row per node.
     */
    void OnSelfMon(const NSelfMon::TEvPageDataReq::TPtr& ev) {
        using namespace yandex::monitoring::selfmon;

        Page page;
        auto* table = page.mutable_component()->mutable_table();
        table->set_numbered(true);

        auto* addressColumn = table->add_columns();
        addressColumn->set_title("Address");
        auto* addressValues = addressColumn->mutable_string();

        auto* lastUpdateColumn = table->add_columns();
        lastUpdateColumn->set_title("Last Update (ago)");
        auto* lastUpdateValues = lastUpdateColumn->mutable_duration();

        auto* shardCountColumn = table->add_columns();
        shardCountColumn->set_title("Shard Count");
        auto* shardCountValues = shardCountColumn->mutable_uint64();

        auto* lastErrorColumn = table->add_columns();
        lastErrorColumn->set_title("Last Error Message");
        auto* lastErrorValues = lastErrorColumn->mutable_string();

        auto* lastErrorTimeColumn = table->add_columns();
        lastErrorTimeColumn->set_title("Last Error Time (ago)");
        auto* lastErrorTimeValues = lastErrorTimeColumn->mutable_duration();

        // use the same clock which produced UpdatedAt/LastErrorAt, so that
        // the "ago" durations are computed consistently
        auto now = TActivationContext::Now();
        for (const auto& node: Nodes_) {
            addressValues->add_values(node.Address);
            shardCountValues->add_values(node.ShardCount);

            if (node.UpdatedAt) {
                lastUpdateValues->add_values((now - node.UpdatedAt).GetValue());
            } else {
                // no update has been received from the node yet
                lastUpdateValues->add_values(TDuration::Max().GetValue());
            }

            if (node.LastError.has_value()) {
                lastErrorValues->add_values(node.LastError->Msg);
                lastErrorTimeValues->add_values((now - node.LastErrorAt).GetValue());
            } else {
                // keep columns aligned: add an empty message and "never" as the error time
                lastErrorValues->add_values();
                lastErrorTimeValues->add_values(TDuration::Max().GetValue());
            }
        }

        Send(ev->Sender, new NSelfMon::TEvPageDataResp{std::move(page)});
    }

    /**
     * Starts termination: poisons all child node watchers and waits for their
     * acknowledgements in the {@code Dying} state.
     */
    void OnPoison(TEvents::TEvPoison::TPtr& ev) {
        Poisoner_ = ev->Sender;

        if (Nodes_.empty()) {
            // there are no child watchers, hence no TEvPoisonTaken would ever arrive:
            // acknowledge right away instead of staying in the Dying state forever
            Send(Poisoner_, new TEvents::TEvPoisonTaken{});
            PassAway();
            return;
        }

        Become(&TThis::Dying);
        for (auto& node: Nodes_) {
            Send(node.Watcher, new TEvents::TEvPoison{});
        }
    }

    /**
     * Counts acknowledgements from child watchers; after the last one forwards
     * the acknowledgement to the poisoner and terminates the actor.
     */
    void OnPoisonTaken(TEvents::TEvPoisonTaken::TPtr& ev) {
        if (++PoisonedWatchers_ == Nodes_.size()) {
            Send(Poisoner_, ev->Release().Release());
            PassAway();
        }
    }

private:
    std::set<TActorId> Subscribers_;
    std::shared_ptr<IMemStoreClusterRpc> Rpc_;
    std::vector<TNodeState> Nodes_;
    absl::flat_hash_map<TShardId, TShardState> Shards_;
    TDuration UpdateDelay_;
    // the actor which has sent the poison and waits for TEvPoisonTaken
    TActorId Poisoner_;
    // number of child watchers which have already acknowledged the poison
    ui32 PoisonedWatchers_{0};
};

} // namespace

/**
 * Creates (but does not register) an actor which watches shard locations on the
 * MemStore nodes with the given addresses, polling each node with the given delay.
 */
std::unique_ptr<IActor> MemStoreClusterWatcher(
    std::shared_ptr<IMemStoreClusterRpc> rpc,
    std::vector<TString> addresses,
    TDuration updateDelay)
{
    std::unique_ptr<IActor> watcher = std::make_unique<TClusterWatcher>(
        std::move(rpc),
        std::move(addresses),
        updateDelay);
    return watcher;
}

} // namespace NSolomon::NDataProxy
