#include "cluster.h"
#include "shard.h"
#include "watcher.h"

#include <solomon/services/dataproxy/lib/shard/shards_map.h>

#include <solomon/libs/cpp/logging/logging.h>
#include <solomon/libs/cpp/selfmon/selfmon.h>

#include <library/cpp/actors/core/actor_bootstrapped.h>
#include <library/cpp/actors/core/event.h>
#include <library/cpp/actors/core/hfunc.h>

#include <set>

using namespace NActors;

namespace NSolomon::NDataProxy {
namespace {

/**
 * Actor that manages MemStore shard actors of a single cluster.
 *
 * - Registers a cluster watcher on bootstrap and subscribes to its shard location updates.
 * - Answers TFindShardsReq by looking shards up in ShardLocations_, lazily registering one
 *   shard actor per shard id on first access.
 * - Every CleanupDelay_ poisons and forgets shard actors not accessed for at least ShardTtl_.
 * - On TEvPoison poisons the watcher and all known shard actors, waits until each of them
 *   replies with TEvPoisonTaken, then acknowledges the poisoner and passes away.
 */
class TMemStoreCluster: public TActorBootstrapped<TMemStoreCluster> {
    struct TShardState {
        TActorId ActorId;      // registered shard actor; empty until the shard is first requested
        TInstant LastAccessAt; // last time the shard was returned by a find request (actor-system time)
    };

public:
    explicit TMemStoreCluster(TMemStoreClusterOptions opts) noexcept
        : ClusterId_{opts.ClusterId}
        , Rpc_{std::move(opts.Rpc)}
        // Rpc_ is declared before Watcher_, so it is already initialized here
        , Watcher_{MemStoreClusterWatcher(Rpc_, std::move(opts.Addresses), opts.WatchDelay)}
        , CleanupDelay_{opts.CleanupDelay}
        , ShardTtl_{opts.ShardTtl}
    {
    }

    void Bootstrap() {
        // replace the not-yet-registered watcher actor with the id it gets after registration
        Watcher_ = Register(std::get<std::unique_ptr<IActor>>(Watcher_).release());
        Send(std::get<TActorId>(Watcher_), new TMemStoreWatcherEvents::TSubscribe);
        Become(&TThis::Normal);
        Schedule(CleanupDelay_, new TEvents::TEvWakeup);
    }

    STFUNC(Normal) {
        switch (ev->GetTypeRewrite()) {
            hFunc(TMemStoreWatcherEvents::TStateChanged, OnStateChanged);
            hFunc(TMemStoreClusterEvents::TFindShardsReq, OnFindShards);
            hFunc(TEvents::TEvWakeup, OnWakeup);
            HFunc(NSelfMon::TEvPageDataReq, OnSelfMon);
            hFunc(TEvents::TEvPoison, OnPoison);
        }
    }

    STATEFN(Dying) {
        switch (ev->GetTypeRewrite()) {
            hFunc(TEvents::TEvPoisonTaken, OnPoisonTaken);
        }
    }

private:
    /// Applies a watcher update: forgets removed shards (poisoning their actors) and
    /// propagates new addresses of updated shards to already existing shard actors.
    void OnStateChanged(TMemStoreWatcherEvents::TStateChanged::TPtr& ev) {
        // removals must be processed first
        for (TShardId id: ev->Get()->Removed) {
            ShardLocations_.Remove(id);

            if (auto it = Shards_.find(id); it != Shards_.end()) {
                Send(it->second.ActorId, new TEvents::TEvPoison);
                Shards_.erase(it);
            }
        }

        for (const auto& shard: ev->Get()->Updated) {
            ShardLocations_.Update(shard);

            if (auto it = Shards_.find(shard->Id); it != Shards_.end()) {
                Send(it->second.ActorId, new TMemStoreShardEvents::TUpdateLocation{shard->Address});
            }
        }
    }

    /// Returns the actor of the given shard, registering it on first access, and refreshes
    /// its last access time.
    TActorId GetShardActorId(TShardId shardId, TString address) {
        auto& shardState = Shards_[shardId];
        // use the same clock as OnWakeup, otherwise the TTL comparison mixes wall-clock
        // and actor-system time
        shardState.LastAccessAt = TActivationContext::Now();
        if (!shardState.ActorId) {
            auto actor = MemStoreShard(ClusterId_, shardId, std::move(address), Rpc_);
            shardState.ActorId = Register(actor.release());
        }
        return shardState.ActorId;
    }

    /// Resolves a shard selector into shard actors and replies to the requester
    /// (an empty list when nothing matches).
    void OnFindShards(TMemStoreClusterEvents::TFindShardsReq::TPtr& ev) {
        const TShardSelector& selector = ev->Get()->Selector;

        if (!selector.IsExact()) {
            auto shards = ShardLocations_.FindInfo(selector);
            std::vector<TMemStoreShard> result;
            result.reserve(shards.size());
            for (auto&& shard: shards) {
                auto actorId = GetShardActorId(shard->Id, shard->Address);
                result.emplace_back(TMemStoreShard{actorId, shard->Key.SubKey, shard->Id});
            }

            MON_DEBUG(MemStoreClient, "found " << shards.size() << " shards by " << selector << " in " << ClusterId_);
            Send(ev->Sender, new TMemStoreClusterEvents::TFindShardsResp{ClusterId_, std::move(result)},
                    0, 0, std::move(ev->TraceId));
            return;
        }

        if (auto shard = ShardLocations_.FindExactInfo(selector)) {
            auto actorId = GetShardActorId(shard->Id, shard->Address);

            MON_DEBUG(MemStoreClient, "found 1 shard by " << selector << " in " << ClusterId_);
            Send(ev->Sender, new TMemStoreClusterEvents::TFindShardsResp{ClusterId_, {
                TMemStoreShard{actorId, shard->Key.SubKey, shard->Id}}
            }, 0, 0, std::move(ev->TraceId));
            return;
        }

        // shard was not found
        MON_DEBUG(MemStoreClient, "shard was not found by " << selector << " in " << ClusterId_);
        Send(ev->Sender, new TMemStoreClusterEvents::TFindShardsResp{ClusterId_, {}},
                0, 0, std::move(ev->TraceId));
    }

    /// Periodic cleanup: poisons shard actors idle for at least ShardTtl_ and reschedules itself.
    void OnWakeup(TEvents::TEvWakeup::TPtr& ev) {
        auto now = TActivationContext::Now();
        for (auto it = Shards_.begin(), end = Shards_.end(); it != end; ) {
            if ((now - it->second.LastAccessAt) >= ShardTtl_) {
                // poison and remove shard if there were no reads from it more than {ShardTtl} period
                Send(it->second.ActorId, new TEvents::TEvPoison);
                auto toRemove = it++;
                Shards_.erase(toRemove);
            } else {
                ++it;
            }
        }
        Schedule(CleanupDelay_, ev->Release().Release()); // send the same event
    }

    /// Self-monitoring page requests are served by the watcher.
    void OnSelfMon(const NSelfMon::TEvPageDataReq::TPtr& ev, const NActors::TActorContext& ctx) {
        ctx.Send(ev->Forward(std::get<TActorId>(Watcher_)));
    }

    /// Starts shutdown: poisons the watcher and all known shard actors and remembers exactly
    /// which actors must acknowledge before this actor may pass away.
    void OnPoison(TEvents::TEvPoison::TPtr& ev) {
        PoisonerId_ = ev->Sender;

        const TActorId watcherId = std::get<TActorId>(Watcher_);
        Send(watcherId, new TEvents::TEvPoison);
        AwaitingPoisonTaken_.insert(watcherId);

        for (const auto& entry: Shards_) {
            const TActorId& shardActorId = entry.second.ActorId;
            Send(shardActorId, new TEvents::TEvPoison);
            AwaitingPoisonTaken_.insert(shardActorId);
        }

        Shards_.clear();
        Become(&TThis::Dying);
    }

    /// Collects shutdown acknowledgements; once every awaited actor has replied, forwards
    /// the acknowledgement to the poisoner and passes away.
    void OnPoisonTaken(TEvents::TEvPoisonTaken::TPtr& ev) {
        // Shards poisoned earlier (TTL cleanup or removal) may still reply while we are dying;
        // such acknowledgements must not count towards our own shutdown.
        if (AwaitingPoisonTaken_.erase(ev->Sender) == 0) {
            return;
        }

        if (AwaitingPoisonTaken_.empty()) {
            Send(PoisonerId_, ev->Release().Release());
            PassAway();
        }
    }

private:
    TClusterId ClusterId_;
    std::shared_ptr<NMemStore::IMemStoreClusterRpc> Rpc_;
    std::variant<TActorId, std::unique_ptr<IActor>> Watcher_; // id of registered actor or actor itself
    TDuration CleanupDelay_;
    TDuration ShardTtl_;
    TShardsMap ShardLocations_;
    absl::flat_hash_map<TShardId, TShardState> Shards_;
    TActorId PoisonerId_;
    std::set<TActorId> AwaitingPoisonTaken_; // actors poisoned on shutdown that have not replied yet
};

} // namespace

/// Creates (but does not register) the actor managing MemStore shards of the cluster described by @p opts.
std::unique_ptr<IActor> MemStoreCluster(TMemStoreClusterOptions opts) {
    auto cluster = std::make_unique<TMemStoreCluster>(std::move(opts));
    return cluster;
}

} // namespace NSolomon::NDataProxy
