#include "fake_node.h"

#include <solomon/services/dataproxy/lib/memstore/watcher.h>

#include <solomon/libs/cpp/actors/test_runtime/actor_runtime.h>
#include <solomon/libs/cpp/clients/memstore/stub.h>

#include <library/cpp/testing/gtest/gtest.h>

using namespace NSolomon;
using namespace NDataProxy;

using yandex::monitoring::memstore::ListShardsRequest;
using yandex::monitoring::memstore::ListShardsResponse;

namespace {

template <typename T>
auto FindShard(const std::vector<T>& shards, TShardId id) {
    return std::find_if(shards.begin(), shards.end(), [id](const auto& s) {
        if constexpr (requires(const T& t) { t.Id; }) {
            return s.Id == id;
        } else {
            return s->Id == id;
        }
    });
}

} // namespace

class TMemStoreWatcherTest: public ::testing::Test {
protected:
    void SetUp() override {
        ActorRuntime = TTestActorRuntime::CreateInited(1, false, true);
        ActorRuntime->WaitForBootstrap();
        EdgeId = ActorRuntime->AllocateEdgeActor();

        NodeA.AddShard(101, "solomon", "sts1", "ingestor");
        NodeA.AddShard(102, "solomon", "sts1", "memstore");
        NodeB.AddShard(201, "solomon", "sts2", "fetcher");

        auto rpc = MakeFakeRpc({&NodeA, &NodeB});
        auto watcher = MemStoreClusterWatcher(rpc, rpc->Addresses(), TDuration::Seconds(1));
        WatcherId = ActorRuntime->Register(watcher.release());

        ActorRuntime->WaitForBootstrap();
    }

    void TearDown() override {
        ActorRuntime.Reset();
    }

    template <typename TEvent>
    typename TEvent::TPtr ReceiveResponse() {
        return ActorRuntime->GrabEdgeEvent<TEvent>(EdgeId, TDuration::Seconds(3));
    }

public:
    THolder<TTestActorRuntime> ActorRuntime;
    NActors::TActorId EdgeId;
    NActors::TActorId WatcherId;

    TFakeMemStoreNode NodeA{"node-a"};
    TFakeMemStoreNode NodeB{"node-b"};
};

// After subscribing, the first TStateChanged event lists every shard hosted by
// the fake nodes (with its key and node address) and removes nothing.
TEST_F(TMemStoreWatcherTest, GetChangesAfterSubscription) {
    ActorRuntime->Send(WatcherId, EdgeId, MakeHolder<TMemStoreWatcherEvents::TSubscribe>());

    auto event = ReceiveResponse<TMemStoreWatcherEvents::TStateChanged>();
    ASSERT_TRUE(event);

    const auto& updated = event->Get()->Updated;
    ASSERT_EQ(updated.size(), 3u);

    // Each section bails out of the test via FAIL() when its shard is missing.
    {
        const auto it = FindShard(updated, 101);
        if (it == updated.end()) {
            FAIL() << "shard with id=101 not found";
        }
        const auto& shard = *it;
        EXPECT_EQ(shard->Id, 101u);
        EXPECT_EQ(shard->Key, TShardKey("solomon", "sts1", "ingestor"));
        EXPECT_EQ(shard->Address, NodeA.Address());
    }

    {
        const auto it = FindShard(updated, 102);
        if (it == updated.end()) {
            FAIL() << "shard with id=102 not found";
        }
        const auto& shard = *it;
        EXPECT_EQ(shard->Id, 102u);
        EXPECT_EQ(shard->Key, TShardKey("solomon", "sts1", "memstore"));
        EXPECT_EQ(shard->Address, NodeA.Address());
    }

    {
        const auto it = FindShard(updated, 201);
        if (it == updated.end()) {
            FAIL() << "shard with id=201 not found";
        }
        const auto& shard = *it;
        EXPECT_EQ(shard->Id, 201u);
        EXPECT_EQ(shard->Key, TShardKey("solomon", "sts2", "fetcher"));
        EXPECT_EQ(shard->Address, NodeB.Address());
    }

    const auto& removed = event->Get()->Removed;
    ASSERT_EQ(removed.size(), 0u);
}

// Moving a shard between nodes must end with a TStateChanged event that reports
// the shard at its new address. Depending on the order in which the nodes answer,
// the watcher may first report the shard as removed in a separate event.
TEST_F(TMemStoreWatcherTest, GetChangesAfterShardMovement) {
    ActorRuntime->Send(WatcherId, EdgeId, MakeHolder<TMemStoreWatcherEvents::TSubscribe>());

    auto change = ReceiveResponse<TMemStoreWatcherEvents::TStateChanged>();
    ASSERT_TRUE(change);

    // move shard 101 from node A to node B
    auto [id, key] = NodeA.RemoveShard(101);
    NodeB.AddShard(id, key.Project, key.SubKey.Cluster, key.SubKey.Service);
    ActorRuntime->AdvanceCurrentTime(TDuration::Seconds(10));

    change = ReceiveResponse<TMemStoreWatcherEvents::TStateChanged>();
    ASSERT_TRUE(change);

    // If the watcher receives the response from the node the shard was removed from
    // before the response from the node it was added to, it sends two TStateChanged
    // events: the first one reports only the removal.
    bool removalReportedSeparately = false;
    {
        // Scoped on purpose: `removed` points into the event held by `change`, and
        // `change` is reassigned below. The reference must not outlive that event.
        const auto& removed = change->Get()->Removed;
        if (!removed.empty()) {
            EXPECT_EQ(removed.size(), 1u);
            EXPECT_EQ(removed[0], 101u);
            removalReportedSeparately = true;
        }
    }

    if (removalReportedSeparately) {
        // The second event must carry the shard at its new location and no removals.
        change = ReceiveResponse<TMemStoreWatcherEvents::TStateChanged>();
        ASSERT_TRUE(change);
        ASSERT_TRUE(change->Get()->Removed.empty());
    }

    // Otherwise the watcher merges the add and remove observations from the nodes
    // and sends only one TStateChanged event.
    const auto& updated = change->Get()->Updated;
    ASSERT_EQ(updated.size(), 1u);

    EXPECT_EQ(updated[0]->Id, 101u);
    EXPECT_EQ(updated[0]->Key, key);
    EXPECT_EQ(updated[0]->Address, NodeB.Address()); // moved to node B
}

// Removing a shard from its node must be reported as exactly that shard removed,
// with no updates in the same TStateChanged event.
TEST_F(TMemStoreWatcherTest, GetChangesAfterShardRemove) {
    ActorRuntime->Send(WatcherId, EdgeId, MakeHolder<TMemStoreWatcherEvents::TSubscribe>());

    auto change = ReceiveResponse<TMemStoreWatcherEvents::TStateChanged>();
    ASSERT_TRUE(change);

    // remove shard 201 from node B
    NodeB.RemoveShard(201);
    ActorRuntime->AdvanceCurrentTime(TDuration::Seconds(100));

    change = ReceiveResponse<TMemStoreWatcherEvents::TStateChanged>();
    ASSERT_TRUE(change);

    EXPECT_EQ(change->Get()->Updated.size(), 0u);

    // Bind by reference (no copy of the vector). Check the size with a fatal
    // assertion: a non-fatal EXPECT would let the test index an empty vector below.
    const auto& shardIds = change->Get()->Removed;
    ASSERT_EQ(shardIds.size(), 1u);
    EXPECT_EQ(shardIds[0], 201u);
}

// TResolve answers with the node address of every requested shard that is known;
// ids that no node hosts are simply absent from the result.
TEST_F(TMemStoreWatcherTest, Resolve) {
    // all requested shard ids are known
    {
        ActorRuntime->Send(WatcherId, EdgeId, THolder(new TMemStoreWatcherEvents::TResolve({101, 201})));
        auto response = ReceiveResponse<TMemStoreWatcherEvents::TResolveResult>();
        ASSERT_TRUE(response);

        const auto& resolved = response->Get()->Locations;
        EXPECT_EQ(resolved.size(), 2u);

        const auto onNodeA = FindShard(resolved, 101u);
        ASSERT_FALSE(onNodeA == resolved.end());
        EXPECT_EQ(onNodeA->Id, 101u);
        EXPECT_EQ(onNodeA->Address, "node-a");

        const auto onNodeB = FindShard(resolved, 201u);
        ASSERT_FALSE(onNodeB == resolved.end());
        EXPECT_EQ(onNodeB->Id, 201u);
        EXPECT_EQ(onNodeB->Address, "node-b");
    }

    // one missing shard id
    {
        ActorRuntime->Send(WatcherId, EdgeId, THolder(new TMemStoreWatcherEvents::TResolve({102, 301})));
        auto response = ReceiveResponse<TMemStoreWatcherEvents::TResolveResult>();
        ASSERT_TRUE(response);

        const auto& resolved = response->Get()->Locations;
        EXPECT_EQ(resolved.size(), 1u);

        const auto known = FindShard(resolved, 102u);
        ASSERT_FALSE(known == resolved.end());
        EXPECT_EQ(known->Id, 102u);
        EXPECT_EQ(known->Address, "node-a");
    }
}
