#include "watcher.h"

#include <solomon/services/dataproxy/lib/cluster_map/cluster_map.h>
#include <solomon/services/dataproxy/lib/initialization/enable_initialization.h>

#include <solomon/libs/cpp/actors/config/log_component.pb.h>
#include <solomon/libs/cpp/logging/logging.h>
#include <solomon/libs/cpp/selfmon/selfmon.h>
#include <solomon/protos/stockpile/stockpile_requests.pb.h>

#include <library/cpp/actors/core/actor_bootstrapped.h>
#include <library/cpp/actors/core/hfunc.h>

#include <util/generic/hash.h>

using namespace NActors;
using namespace yandex::solomon::stockpile;

namespace NSolomon::NDataProxy {
namespace {

/**
 * Actor events exchanged only between the watchers defined in this file.
 */
struct TLocalEvents: private TPrivateEvents {
    enum {
        RequestStatus = SpaceBegin,
        ReceiveStatus,
        ReceiveError,
        NodeUpdated,
        NodeFailed,
        End,
    };
    static_assert(End < SpaceEnd, "too many event types");

    // scheduled by TNodeWatcher to itself: time to poll the node again
    struct TRequestStatus: public TEventLocal<TRequestStatus, RequestStatus> {
    };

    // RPC callback -> TNodeWatcher: the node answered the ServerStatus request
    struct TReceiveStatus: public TEventLocal<TReceiveStatus, ReceiveStatus> {
        TServerStatusResponse Response;
    };

    // RPC callback -> TNodeWatcher: the ServerStatus request failed (or threw)
    struct TReceiveError: public TEventLocal<TReceiveError, ReceiveError> {
        TStockpileError Error;
    };

    // TNodeWatcher -> TClusterWatcher: shards were added or changed readiness
    // (also sent for the node's first successful response, even if nothing changed)
    struct TNodeUpdated: public TEventLocal<TNodeUpdated, NodeUpdated> {
        ui32 WatcherId{0};                         // index of the node in the cluster watcher
        TInstant Time;                             // when the update was produced
        size_t ShardCount{0};                      // number of shards currently known on the node
        size_t TotalShardCount{0};                 // total shard count as reported by the node
        std::vector<TStockpileShardInfo> Updated;  // new shards and shards whose readiness changed
    };

    // TNodeWatcher -> TClusterWatcher: polling the node failed
    struct TNodeFailed: public TEventLocal<TNodeFailed, NodeFailed> {
        ui32 WatcherId{0};                         // index of the node in the cluster watcher
        TInstant Time;                             // when the failure was observed
        TStockpileError Error;
    };
};

/**
 * Polls a single Stockpile node for the status of the shards it hosts.
 *
 * The actor alternates between Sleeping (waiting for the next scheduled poll) and
 * WaitingResponse (a ServerStatus RPC is in flight). New shards and readiness changes
 * are reported to the parent as TNodeUpdated, RPC failures as TNodeFailed. After a
 * failure the polling delay grows by 1.5x plus up to 2s of jitter (capped at 10s);
 * a successful response restores the initial delay.
 */
class TNodeWatcher: public TActorBootstrapped<TNodeWatcher> {
public:
    TNodeWatcher(ui32 watcherId, TStringBuf address, IStockpileRpc* rpc, TDuration updateDelay)
        : WatcherId_(watcherId)
        , Address_(address)
        , Rpc_(rpc)
        , InitialUpdateDelay_(updateDelay)
        , UpdateDelay_(updateDelay)
    {
    }

    // remembers who receives our notifications and starts polling right away
    void Bootstrap(TActorId parentId) {
        ParentId_ = parentId;
        OnRequest();
    }

    // idle between polls: either start the next poll or die immediately
    STATEFN(Sleeping) {
        switch (ev->GetTypeRewrite()) {
            sFunc(TEvents::TEvPoison, OnDie);
            sFunc(TLocalEvents::TRequestStatus, OnRequest);
        }
    }

    // an RPC is in flight: react to its outcome, or to a poison
    STATEFN(WaitingResponse) {
        switch (ev->GetTypeRewrite()) {
            hFunc(TEvents::TEvPoison, OnPoison);
            hFunc(TLocalEvents::TReceiveError, OnError);
            hFunc(TLocalEvents::TReceiveStatus, OnStatus);
        }
    }

    // poisoned while an RPC was in flight: die once it completes, whatever its outcome
    STATEFN(Dying) {
        switch (ev->GetTypeRewrite()) {
            sFunc(TLocalEvents::TReceiveError, OnDie);
            sFunc(TLocalEvents::TReceiveStatus, OnDie);
        }
    }

    /**
     * Issues the ServerStatus RPC. The completion callback only captures the actor id
     * and the actor system, and turns the outcome into a TReceiveStatus or TReceiveError
     * event sent back to this actor.
     */
    void OnRequest() {
        MON_TRACE(StockpileWatcher, "request stockpile node status {" << Address_ << "}");
        Become(&TThis::WaitingResponse);

        const TActorId replyTo = SelfId();
        auto* system = TActorContext::ActorSystem();

        // give the node twice the current polling delay to answer
        auto deadline = TActivationContext::Now() + 2 * UpdateDelay_;
        TServerStatusRequest request;
        request.SetDeadline(deadline.MilliSeconds());

        Rpc_->ServerStatus(request).Subscribe([replyTo, system](TAsyncStockpileStatusResponse done) {
            try {
                auto result = done.ExtractValue();
                if (!result.Success()) {
                    auto failure = std::make_unique<TLocalEvents::TReceiveError>();
                    failure->Error = result.Error();
                    system->Send(replyTo, failure.release());
                    return;
                }

                auto success = std::make_unique<TLocalEvents::TReceiveStatus>();
                success->Response = result.Extract();
                system->Send(replyTo, success.release());
            } catch (...) {
                auto failure = std::make_unique<TLocalEvents::TReceiveError>();
                failure->Error.RpcCode = grpc::StatusCode::OK;
                failure->Error.StockpileCode = EStockpileStatusCode::UNKNOWN;
                failure->Error.Message = CurrentExceptionMessage();
                system->Send(replyTo, failure.release());
            }
        });
    }

    // successful poll: merge the response and poll again after the initial delay
    void OnStatus(const TLocalEvents::TReceiveStatus::TPtr& ev) {
        MON_TRACE(StockpileWatcher, "got stockpile node update {" << Address_ << '}');
        UpdateState(ev->Get()->Response);

        // a successful poll cancels any backoff accumulated by previous failures
        UpdateDelay_ = InitialUpdateDelay_;
        ScheduleNextRequest();
    }

    // failed poll: tell the parent, back off and retry later
    void OnError(const TLocalEvents::TReceiveError::TPtr& ev) {
        auto& error = ev->Get()->Error;
        MON_WARN(StockpileWatcher, "cannot get stockpile node update {" << Address_ << "}, error: " << error);

        auto failed = std::make_unique<TLocalEvents::TNodeFailed>();
        failed->WatcherId = WatcherId_;
        failed->Time = TActivationContext::Now();
        failed->Error = std::move(error);
        Send(ParentId_, failed.release());

        // exponential slowdown with up to 2s of jitter, never longer than 10s
        const TDuration jitter = TDuration::MilliSeconds(RandomNumber(2'000ull));
        const TDuration grown = 1.5 * UpdateDelay_ + jitter;
        UpdateDelay_ = Min(grown, TDuration::Seconds(10));

        MON_TRACE(StockpileWatcher, "will retry after " << UpdateDelay_);
        ScheduleNextRequest();
    }

    // the RPC is still running, so its completion must be awaited before dying
    void OnPoison(const TEvents::TEvPoison::TPtr&) {
        Become(&TThis::Dying);
    }

    // acknowledges the poison to the parent and stops the actor
    void OnDie() {
        Send(ParentId_, new TEvents::TEvPoisonTaken{});
        PassAway();
    }

    /**
     * Merges a ServerStatus response into Shards_ and notifies the parent.
     *
     * Shards seen for the first time, or whose readiness flipped, are sent to the parent
     * in a TNodeUpdated event. Shards missing from the response are dropped from Shards_
     * without a notification. The first call always notifies the parent, even when
     * nothing changed.
     */
    void UpdateState(const TServerStatusResponse& resp) {
        // TODO: copy format versions
//        resp.GetOlderSupportBinaryVersion();
//        resp.GetLatestSupportBinaryVersion();

        std::vector<TStockpileShardInfo> changed;
        for (const TShardStatus& status: resp.GetShardStatus()) {
            const ui32 shardId = status.GetShardId();
            const bool ready = status.GetReady();

            auto& known = Shards_[shardId];
            // Id == 0 means operator[] has just default-inserted this entry
            // (assumes real shard ids are never 0)
            const bool isNew = known.Id == 0;
            if (!isNew && known.Ready == ready) {
                continue;
            }

            if (isNew) {
                known.Id = shardId;
                known.Location = Address_;
            }
            known.Ready = ready;
            changed.push_back(known);
        }

        // every id from the response is in Shards_ now, so a larger map means stale entries
        if (Shards_.size() > resp.ShardStatusSize()) {
            std::unordered_set<TStockpileShardId> alive;
            alive.reserve(resp.ShardStatusSize());
            for (const TShardStatus& status: resp.GetShardStatus()) {
                alive.insert(status.GetShardId());
            }

            auto it = Shards_.begin();
            while (it != Shards_.end()) {
                if (alive.count(it->first) != 0) {
                    ++it;
                } else {
                    it = Shards_.erase(it);
                }
            }
        }

        // always report the first response, in order to let
        // the cluster watcher know that this node is available
        const bool mustNotify = Initial_ || !changed.empty();
        if (!mustNotify) {
            return;
        }
        Initial_ = false;

        auto event = std::make_unique<TLocalEvents::TNodeUpdated>();
        event->WatcherId = WatcherId_;
        event->Time = TActivationContext::Now();
        event->Updated = std::move(changed);
        event->ShardCount = Shards_.size();
        event->TotalShardCount = resp.GetTotalShardCount();
        Send(ParentId_, event.release());
    }

private:
    // arms the timer for the next poll and goes idle until it fires
    void ScheduleNextRequest() {
        Schedule(UpdateDelay_, new TLocalEvents::TRequestStatus{});
        Become(&TThis::Sleeping);
    }

private:
    const ui32 WatcherId_;
    const TString Address_;
    IStockpileRpc* Rpc_;
    const TDuration InitialUpdateDelay_;
    TDuration UpdateDelay_;
    TActorId ParentId_;
    TStockpileShardsMap Shards_;
    bool Initial_ = true;
};

/**
 * Creates a node watcher for each Stockpile node in the cluster and aggregates information
 * about shard locations. Also handles update subscriptions and shard resolving requests.
 *
 * Initialization is finished (once) when every node has answered at least once, either
 * successfully or with an error, and the shard counts reported by the nodes add up to
 * the total shard count.
 */
class TClusterWatcher: public TActorBootstrapped<TClusterWatcher>, public TEnableInitialization<TClusterWatcher> {
    struct TNodeState {
        TString Address;
        TActorId Watcher;
        size_t ShardCount{0};
        TInstant UpdatedAt;
        TInstant LastErrorAt;
        std::optional<TStockpileError> LastError;
        // the node has answered at least once (successfully or with an error)
        bool Responded{false};
        // the node's shards have already been added to ShardCountSum_
        bool ShardsCounted{false};
    };

public:
    TClusterWatcher(IStockpileClusterRpcPtr rpc, std::vector<TString> addresses, TDuration updateDelay)
        : Rpc_(std::move(rpc))
        , Addresses_(std::move(addresses))
        , Nodes_(Addresses_.size())
        , UpdateDelay_(updateDelay)
    {
    }

    void Bootstrap() {
        for (ui32 id = 0; id < Addresses_.size(); id++) {
            IStockpileRpc* rpc = Rpc_->Get(Addresses_[id]);
            Nodes_[id].Address = Addresses_[id];
            Nodes_[id].Watcher = Register(new TNodeWatcher(id, Addresses_[id], rpc, UpdateDelay_));
        }

        Become(&TThis::Normal);

        // with an empty node list there is nothing to wait for
        MaybeFinishInitialization();
    }

    /**
     * In normal state the actor processes incoming updates or errors from node watchers
     * and handles subscription/resolving requests as usual.
     */
    STATEFN(Normal) {
        switch (ev->GetTypeRewrite()) {
            hFunc(TLocalEvents::TNodeUpdated, OnNodeUpdated);
            hFunc(TLocalEvents::TNodeFailed, OnNodeFailed);
            hFunc(TStockpileWatcherEvents::TSubscribe, OnSubscribe);
            hFunc(TStockpileWatcherEvents::TResolve, OnResolve);
            hFunc(TInitializationEvents::TSubscribe, OnInitializationSubscribe);
            hFunc(NSelfMon::TEvPageDataReq, OnSelfMon);
            hFunc(TEvents::TEvPoison, OnPoison);
        }
    }

    /**
     * In dying state the actor ignores all updates and requests, and waits until
     * all of its child node watchers have taken the poison.
     */
    STATEFN(Dying) {
        switch (ev->GetTypeRewrite()) {
            hFunc(TEvents::TEvPoisonTaken, OnPoisonTaken);
        }
    }

private:
    // marks the node as having answered at least once
    void MarkResponded(TNodeState& node) {
        if (!node.Responded) {
            node.Responded = true;
            ++InitializedNodes_;
        }
    }

    /**
     * Finishes initialization exactly once, as soon as every node has answered and the
     * reported shard counts add up to the total. Re-evaluated on every node event, so a
     * node whose first request failed can still complete initialization once it recovers.
     */
    void MaybeFinishInitialization() {
        if (InitializationFinished_) {
            return;
        }
        if (InitializedNodes_ == Nodes_.size() && TotalShardCount_ == ShardCountSum_) {
            InitializationFinished_ = true;
            FinishInitialization();
        }
    }

    void OnNodeUpdated(TLocalEvents::TNodeUpdated::TPtr& ev) {
        TNodeState& node = Nodes_[ev->Get()->WatcherId];

        MarkResponded(node);

        // the first successful update contributes the node's shards to the
        // initialization counters, even if earlier requests to this node failed
        if (!node.ShardsCounted) {
            node.ShardsCounted = true;
            ShardCountSum_ += ev->Get()->ShardCount;
            TotalShardCount_ = std::max(ev->Get()->TotalShardCount, TotalShardCount_);
        }

        MaybeFinishInitialization();

        node.UpdatedAt = ev->Get()->Time;
        node.ShardCount = ev->Get()->ShardCount;

        // clear the last error info if more than an hour has passed since its appearance
        if (node.LastErrorAt && (TActivationContext::Now() - node.LastErrorAt) > TDuration::Hours(1)) {
            node.LastErrorAt = TInstant::Zero();
            node.LastError.reset();
        }

        auto& updated = ev->Get()->Updated;
        MON_TRACE(StockpileWatcher, "updated " << updated.size() << " shards on {" << node.Address << '}');

        for (TStockpileShardInfo& s: updated) {
            Shards_[s.Id] = s;
        }

        if (auto it = Subscribers_.begin(); it != Subscribers_.end()) {
            // send copy to all subscribers except the last one
            for (auto end = std::prev(Subscribers_.end()); it != end; ++it) {
                SendUpdate(*it, updated);
            }
            SendUpdate(*it, std::move(updated));
        }
    }

    template <typename T>
    void SendUpdate(TActorId subscriber, T&& updated) {
        auto event = std::make_unique<TStockpileWatcherEvents::TUpdate>();
        event->Shards = std::forward<T>(updated);
        Send(subscriber, event.release());
    }

    void OnNodeFailed(TLocalEvents::TNodeFailed::TPtr& ev) {
        TNodeState& node = Nodes_[ev->Get()->WatcherId];

        MarkResponded(node);
        MaybeFinishInitialization();

        node.LastErrorAt = ev->Get()->Time;
        node.LastError = std::move(ev->Get()->Error);
    }

    void OnSubscribe(TStockpileWatcherEvents::TSubscribe::TPtr& ev) {
        bool inserted = Subscribers_.insert(ev->Sender).second;
        if (inserted && !Shards_.empty()) {
            auto event = std::make_unique<TStockpileWatcherEvents::TUpdate>();
            for (auto& [id, s]: Shards_) {
                event->Shards.push_back(s);
            }
            Send(ev->Sender, event.release(), 0, ev->Cookie);
        }
    }

    void OnResolve(TStockpileWatcherEvents::TResolve::TPtr& ev) {
        auto event = std::make_unique<TStockpileWatcherEvents::TResolveResult>();
        for (TStockpileShardId id: ev->Get()->Ids) {
            if (auto it = Shards_.find(id); it != Shards_.end()) {
                event->Shards.push_back(it->second);
            } else {
                // TODO: force update shards
                event->NotFound.push_back(id);
            }
        }
        Send(ev->Sender, event.release(), 0, ev->Cookie);
    }

    void OnSelfMon(const NSelfMon::TEvPageDataReq::TPtr& ev) {
        using namespace yandex::monitoring::selfmon;

        Page page;
        auto* table = page.mutable_component()->mutable_table();
        table->set_numbered(true);

        auto* addressColumn = table->add_columns();
        addressColumn->set_title("Address");
        auto* addressValues = addressColumn->mutable_string();

        auto* lastUpdateColumn = table->add_columns();
        lastUpdateColumn->set_title("Last Update (ago)");
        auto* lastUpdateValues = lastUpdateColumn->mutable_duration();

        auto* shardCountColumn = table->add_columns();
        shardCountColumn->set_title("Shard Count");
        auto* shardCountValues = shardCountColumn->mutable_uint64();

        auto* lastErrorColumn = table->add_columns();
        lastErrorColumn->set_title("Last Error Message");
        auto* lastErrorValues = lastErrorColumn->mutable_string();

        auto* lastErrorTimeColumn = table->add_columns();
        lastErrorTimeColumn->set_title("Last Error Time (ago)");
        auto* lastErrorTimeValues = lastErrorTimeColumn->mutable_duration();

        // use the actor-system clock: UpdatedAt/LastErrorAt are stamped with it as well
        auto now = TActivationContext::Now();
        for (const auto& node: Nodes_) {
            addressValues->add_values(node.Address);
            shardCountValues->add_values(node.ShardCount);

            if (node.UpdatedAt) {
                lastUpdateValues->add_values((now - node.UpdatedAt).GetValue());
            } else {
                lastUpdateValues->add_values(TDuration::Max().GetValue());
            }

            if (node.LastError.has_value()) {
                lastErrorValues->add_values(node.LastError->Message);
                lastErrorTimeValues->add_values((now - node.LastErrorAt).GetValue());
            } else {
                lastErrorValues->add_values();
                lastErrorTimeValues->add_values(TDuration::Max().GetValue());
            }
        }

        Send(ev->Sender, new NSelfMon::TEvPageDataResp{std::move(page)});
    }

    void OnPoison(TEvents::TEvPoison::TPtr& ev) {
        Poisoner_ = ev->Sender;

        // no child watchers means no TEvPoisonTaken will ever arrive, so answer right away
        if (Nodes_.empty()) {
            Send(Poisoner_, new TEvents::TEvPoisonTaken{});
            PassAway();
            return;
        }

        Become(&TThis::Dying);
        for (auto& node: Nodes_) {
            Send(node.Watcher, new TEvents::TEvPoison{});
        }
    }

    void OnPoisonTaken(TEvents::TEvPoisonTaken::TPtr& ev) {
        if (++PoisonedWatchers_ == Nodes_.size()) {
            Send(Poisoner_, ev->Release().Release());
            PassAway();
        }
    }

private:
    size_t InitializedNodes_ = 0;
    size_t TotalShardCount_ = 0;
    size_t ShardCountSum_ = 0;
    bool InitializationFinished_ = false;

    std::set<TActorId> Subscribers_;
    IStockpileClusterRpcPtr Rpc_;
    std::vector<TString> Addresses_;
    std::vector<TNodeState> Nodes_;
    TStockpileShardsMap Shards_;
    TDuration UpdateDelay_;

    TActorId Poisoner_;
    ui32 PoisonedWatchers_{0};
};

} // namespace

/**
 * Creates the actor that watches every given Stockpile node and aggregates shard locations.
 *
 * @param rpc          cluster RPC used to obtain a per-node client for each address
 * @param addresses    addresses of the Stockpile nodes to watch
 * @param updateDelay  initial polling interval of every node watcher
 * @return the (not yet registered) cluster watcher actor
 */
std::unique_ptr<IActor> StockpileClusterWatcher(
        IStockpileClusterRpcPtr rpc,
        std::vector<TString> addresses,
        TDuration updateDelay)
{
    std::unique_ptr<IActor> watcher = std::make_unique<TClusterWatcher>(
            std::move(rpc),
            std::move(addresses),
            updateDelay);
    return watcher;
}

} // namespace NSolomon::NDataProxy
