#include "manager.h"
#include "helpers.h"
#include "resolver.h"

#include <library/cpp/testing/unittest/registar.h>
#include <library/cpp/testing/unittest/env.h>

#include <util/random/random.h>
#include <util/thread/pool.h>

#include <util/system/fs.h>

using namespace NYP::NServiceDiscovery;

Y_UNIT_TEST_SUITE(Subscribe) {
    // Checks basic subscription bookkeeping: the subscriber counter follows the
    // number of live subscription references, equal keys share one stat object
    // and a different key gets a separate one.
    Y_UNIT_TEST(TestSimple) {
        TEndpointSetManager manager(DummyRequester, TestClientName);
        manager.Start(TDuration::MilliSeconds(10));

        // The first two keys are deliberately identical; the third one differs.
        TEndpointSetKey sameKeyA("test-cluster", "test-service");
        TEndpointSetKey sameKeyB("test-cluster", "test-service");
        TEndpointSetKey otherKey("test-cluster", "test-service3");

        for (size_t round = 0; round < 5; ++round) {
            // The references created in the previous round are out of scope by now,
            // so the counter is expected to be back at zero.
            UNIT_ASSERT_VALUES_EQUAL(manager.SubscriberCount(), 0);

            // First subscriber.
            TProvider firstProvider;
            auto firstRef = manager.Subscribe(sameKeyA, firstProvider);
            UNIT_ASSERT(firstProvider.GetEndpointSet().endpoints_size() > 0);
            UNIT_ASSERT_VALUES_EQUAL(manager.SubscriberCount(), 1);

            // Second subscriber on an equal key: same stat object as the first one.
            TProvider secondProvider;
            auto secondRef = manager.Subscribe(sameKeyB, secondProvider);
            UNIT_ASSERT_VALUES_EQUAL(manager.SubscriberCount(), 2);
            UNIT_ASSERT(&manager.GetStat(sameKeyA) == &manager.GetStat(sameKeyB));
            UNIT_ASSERT(secondProvider.GetEndpointSet().endpoints_size() > 0);

            // Third subscriber on a different key: its stat object must be distinct.
            TProvider thirdProvider;
            auto thirdRef = manager.Subscribe(otherKey, thirdProvider);
            UNIT_ASSERT_VALUES_EQUAL(manager.SubscriberCount(), 3);
            UNIT_ASSERT(thirdProvider.GetEndpointSet().endpoints_size() > 0);
            UNIT_ASSERT(&manager.GetStat(sameKeyA) != &manager.GetStat(otherKey));
        }
    }

    struct TRequester
        : public IRemoteRequester
    {
        TRequester() {
            AtomicSet(Readiness_, 1);
            AtomicSet(EndpointCount_, 1);
        }

        void DoResolve(const TResolveRequestBatch& request, TResolveResultBatch& result, TStatEnv&) override {
            if (Disabled) {
                return;
            }

            bool ready = AtomicGet(Readiness_) == 1;
            size_t currentGeneration = AtomicGet(CurrentGeneration);
            size_t endpointCount = AtomicGet(EndpointCount_);

            for (size_t i = 0; i < request.size(); ++i) {
                for (size_t j = 0; j < endpointCount; ++j) {
                    auto& endpoint = AddEndpoint(result[i]);
                    endpoint.set_fqdn(TStringBuilder{} << "fqdn" << endpointCount - i);
                    endpoint.set_port(currentGeneration);
                    endpoint.set_ready(ready);
                }
            }
        }

        void CheckProvider(TProvider& p) {
            size_t minGeneration = AtomicGet(MinGeneration);
            auto endpointSet = p.GetEndpointSet();
            size_t maxGeneration  = AtomicGet(CurrentGeneration);

            UNIT_ASSERT(endpointSet.endpoints_size() > 0);

            size_t endpointSetGeneration = endpointSet.endpoints(0).port();
            UNIT_ASSERT(minGeneration <= endpointSetGeneration && endpointSetGeneration <= maxGeneration);
        }

        void ChangeReadiness(bool ready) {
            AtomicSet(Readiness_, ready ? 1 : 0);
        }

        void SetEndpointCount(size_t count) {
            AtomicSet(EndpointCount_, count);
        }

        void DoInit() override {}

    public:
        TAtomic MinGeneration = 1;
        TAtomic CurrentGeneration = 1;
        bool Disabled = false;

    private:
        TAtomic Readiness_;
        TAtomic EndpointCount_;
    };

    // Verifies that a change of endpoint readiness alone is delivered through the
    // readiness callback, while the full-update callback fires only once, and that
    // a late subscriber still starts with a full update.
    Y_UNIT_TEST(TestUpdateReadiness) {
        auto requester = MakeAtomicShared<TRequester>();
        requester->ChangeReadiness(true);
        requester->SetEndpointCount(10);

        TEndpointSetManager manager{requester, TestClientName};
        manager.Start(TDuration::MilliSeconds(10));

        const TEndpointSetKey key{ "test-cluster", "test-service1" };

        // Everything the first provider receives is recorded for the final checks.
        TVector<NApi::TEndpointSet> updateSeq;
        TVector<NApi::TEndpointSet> readinessSeq;
        TAtomic updateCount = 0;
        TAtomic updateReadinessCount = 0;

        auto onUpdate = [&updateSeq, &updateCount](const NApi::TEndpointSet& eps) {
            updateSeq.emplace_back(eps);
            AtomicIncrement(updateCount);
        };
        auto onReadiness = [&readinessSeq, &updateReadinessCount](const NApi::TEndpointSet& eps) {
            readinessSeq.emplace_back(eps);
            AtomicIncrement(updateReadinessCount);
        };

        TProvider p{onUpdate, onReadiness};
        auto ref = manager.Subscribe(key, p);

        // Wait for the first full update.
        Wait([&updateCount]() {
            return AtomicGet(updateCount) > 0;
        });

        // Ready -> false: first readiness notification.
        requester->ChangeReadiness(false);
        Wait([&updateReadinessCount]() {
            return AtomicGet(updateReadinessCount) > 0;
        });

        // Ready -> true: second readiness notification.
        requester->ChangeReadiness(true);
        Wait([&updateReadinessCount]() {
            return AtomicGet(updateReadinessCount) > 1;
        });

        {
            // Test that new subscribers receive a full update even if only readiness changed in the last update.
            TAtomic lateUpdateCount = 0;
            TAtomic lateReadinessCount = 0;

            auto onLateUpdate = [&lateUpdateCount](const NApi::TEndpointSet&) {
                AtomicIncrement(lateUpdateCount);
            };
            auto onLateReadiness = [&lateReadinessCount](const NApi::TEndpointSet&) {
                AtomicIncrement(lateReadinessCount);
            };

            TProvider lateProvider{onLateUpdate, onLateReadiness};
            auto lateRef = manager.Subscribe(key, lateProvider);

            Wait([&lateUpdateCount]() {
                return AtomicGet(lateUpdateCount) > 0;
            });

            // The manager is stopped before the counters and recorded sequences are inspected.
            manager.Stop();

            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(lateUpdateCount), 1);
            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(lateReadinessCount), 0);
        }

        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(updateCount), 1);
        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(updateReadinessCount), 2);

        const auto& fullUpdate = updateSeq[0];
        const auto& becameNotReady = readinessSeq[0];
        UNIT_ASSERT_VALUES_EQUAL(fullUpdate.endpoints().size(), becameNotReady.endpoints().size());
        UNIT_ASSERT_VALUES_EQUAL(becameNotReady.endpoints().size(), 10);

        // Full update: ready; first readiness update: not ready; second: ready again.
        for (int idx = 0; idx < fullUpdate.endpoints().size(); ++idx) {
            UNIT_ASSERT(fullUpdate.endpoints(idx).ready());
            UNIT_ASSERT(!becameNotReady.endpoints(idx).ready());
            UNIT_ASSERT(readinessSeq[1].endpoints(idx).ready());
        }
    }

    // Shared body of the concurrent tests: worker threads keep (re)subscribing to
    // random keys and validating what they see, while the calling thread advances
    // the requester generation step by step.
    void TestConcurrent(TEndpointSetManager& manager, TRequester& r, size_t keyCount) {
        TVector<TEndpointSetKey> keys;
        for (size_t keyIdx = 0; keyIdx < keyCount; ++keyIdx) {
            keys.emplace_back("test-cluster", "test-service" + ToString(keyIdx));
        }

        const size_t minGeneration = r.CurrentGeneration;
        const size_t maxGeneration = minGeneration + 25;

        // Generation most recently observed by any worker; workers stop once it reaches maxGeneration.
        TAtomic lastSeenGeneration = 0;

        const size_t nThreads = 8;
        TThreadPool pool;
        pool.Start(nThreads);

        for (size_t worker = 0; worker < nThreads; ++worker) {
            Y_VERIFY(
                pool.AddFunc([&manager, &keys, &lastSeenGeneration, &r, maxGeneration]() {
                    // Ring of provider/reference slots; reusing a slot replaces the
                    // subscription reference previously stored there.
                    TVector<std::pair<TProvider, IEndpointSetSubscriberRef>> slots(25);
                    size_t nextSlot = 0;

                    // Subscribes the next ring slot to the given key, validates its provider and returns the slot.
                    auto subscribeNext = [&](const TEndpointSetKey& key) -> std::pair<TProvider, IEndpointSetSubscriberRef>& {
                        auto& slot = slots[(nextSlot++) % slots.size()];
                        slot.second = manager.Subscribe(key, slot.first);
                        r.CheckProvider(slot.first);
                        return slot;
                    };

                    while (static_cast<size_t>(AtomicGet(lastSeenGeneration)) < maxGeneration) {
                        const TEndpointSetKey& firstKey = keys[RandomNumber<size_t>() % keys.size()];
                        const TEndpointSetKey& secondKey = keys[RandomNumber<size_t>() % keys.size()];

                        // Two subscriptions to the first key, then one to the second key.
                        subscribeNext(firstKey);
                        subscribeNext(firstKey);
                        auto& latest = subscribeNext(secondKey);

                        // Publish the generation seen by the latest subscription.
                        auto eps = latest.first.GetEndpointSet();
                        AtomicSet(lastSeenGeneration, eps.endpoints(0).port());

                        // Re-validate every slot that currently holds a subscription.
                        for (auto& slot : slots) {
                            if (slot.second) {
                                r.CheckProvider(slot.first);
                            }
                        }
                    }
                })
            );
        }

        for (size_t generation = 2; generation <= maxGeneration; ++generation) {
            AtomicSet(r.CurrentGeneration, generation);

            // The lower bound is raised only after the update loop counter has
            // advanced by more than one since the new generation was published.
            const TAtomic loopsBefore = manager.GetStat().UpdateLoopCounter;
            Wait([&manager, loopsBefore]() {
                return manager.GetStat().UpdateLoopCounter > loopsBefore + 1;
            });

            AtomicSet(r.MinGeneration, generation);
        }

        pool.Stop();

        UNIT_ASSERT_VALUES_EQUAL(lastSeenGeneration, maxGeneration);
    }

    // Runs the concurrent scenario with 100 keys against a manager constructed
    // without a cache directory.
    Y_UNIT_TEST(TestConcurrent) {
        auto requester = MakeAtomicShared<TRequester>();

        TEndpointSetManager manager(requester, TestClientName);
        manager.Start(TDuration::MilliSeconds(10));

        TestConcurrent(manager, *requester, 100);
    }

    // Runs the concurrent scenario with 25 keys against a manager that is given
    // a cache directory.
    Y_UNIT_TEST(TestConcurrentWithPersistentCache) {
        // Remove whatever is left in the cache directory before the manager is created.
        TFsPath cacheDir("./cache");
        NFs::RemoveRecursive(cacheDir.GetPath());

        auto requester = MakeAtomicShared<TRequester>();

        TEndpointSetManager manager(cacheDir.GetPath(), requester, TestClientName);
        manager.Start(TDuration::MilliSeconds(10));

        TestConcurrent(manager, *requester, 25);
    }


    // Verifies that the update callback of a provider fires only while the
    // provider's AllowUpdate() returns true and the generation has changed.
    Y_UNIT_TEST(TestAllowUpdate) {
        // Provider whose AllowUpdate() answer can be switched by the test.
        class TProviderEx
            : public TProvider
        {
        public:
            using TProvider::TProvider;

            bool AllowUpdate() override {
                return AtomicGet(AllowUpdate_);
            }

            void SetAllowUpdate(bool v) {
                AtomicSet(AllowUpdate_, v);
            }
        private:
            // Updates are disallowed until the test enables them.
            TAtomic AllowUpdate_ = 0;
        };

        auto r = MakeAtomicShared<TRequester>();
        TEndpointSetManager manager(r, TestClientName);

        TAtomic updateCounter = 0;
        TProviderEx p([&updateCounter](const NApi::TEndpointSet&) {
            AtomicIncrement(updateCounter);
        });
        p.UpdateActiveEndpointSetInfo = true;

        TEndpointSetKey key("test-cluster", "test-service");
        auto ref = manager.Subscribe(key, p);
        p.SetEndpointSet(ref->GetEndpointSet());

        // Asserts that the endpoint set is non-empty and carries the requester's current generation.
        auto expectCurrentGeneration = [&](const TEndpointSetEx& eps) {
            UNIT_ASSERT(eps.endpoints_size() > 0);
            UNIT_ASSERT_VALUES_EQUAL(eps.endpoints(0).port(), r->CurrentGeneration);
        };

        // Blocks until the update loop counter has advanced by more than one.
        auto waitForUpdateLoop = [&]() {
            const TAtomic loopsBefore = manager.GetStat().UpdateLoopCounter;

            Wait([&manager, loopsBefore]() {
                return manager.GetStat().UpdateLoopCounter > loopsBefore + 1;
            });
        };

        // Before the manager is started: data is present, no callback yet.
        expectCurrentGeneration(ref->GetEndpointSet());
        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(updateCounter), 0);

        manager.Start(TDuration::MilliSeconds(10));

        // Updates disallowed, generation unchanged.
        waitForUpdateLoop();
        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(updateCounter), 0);
        expectCurrentGeneration(ref->GetEndpointSet());

        // Updates allowed, generation still unchanged: the callback must stay silent.
        p.SetAllowUpdate(true);
        waitForUpdateLoop();
        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(updateCounter), 0);
        expectCurrentGeneration(ref->GetEndpointSet());
        expectCurrentGeneration(p.GetEndpointSet());

        // New generation while updates are allowed: exactly one callback, and no repeat afterwards.
        AtomicIncrement(r->CurrentGeneration);
        waitForUpdateLoop();
        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(updateCounter), 1);
        waitForUpdateLoop();
        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(updateCounter), 1);

        // New generation while updates are disallowed: no callback, but the reference reports the new generation.
        p.SetAllowUpdate(false);
        AtomicIncrement(r->CurrentGeneration);
        waitForUpdateLoop();
        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(updateCounter), 1);
        expectCurrentGeneration(ref->GetEndpointSet());

        // Allowing updates again delivers the pending change.
        p.SetAllowUpdate(true);
        waitForUpdateLoop();
        UNIT_ASSERT_VALUES_EQUAL(AtomicGet(updateCounter), 2);

        expectCurrentGeneration(ref->GetEndpointSet());
        expectCurrentGeneration(p.GetEndpointSet());
    }

    // With a requester that produces no endpoints, Subscribe must fail with an
    // "unavailable" error that names the endpoint set.
    Y_UNIT_TEST(TestUnavailable) {
        auto requester = MakeAtomicShared<TRequester>();
        requester->Disabled = true;

        TEndpointSetManager manager(requester, TestClientName);

        // Stays false unless Subscribe returns normally.
        bool subscribed = false;
        try {
            TEndpointSetKey key("test-cluster", "test-service");
            TProvider provider;
            auto ref = manager.Subscribe(key, provider);
            subscribed = true;
        } catch (...) {
            const auto message = CurrentExceptionMessage();
            UNIT_ASSERT(message.Contains("endpointset [test-cluster#test-service] unavailable"));
        }

        UNIT_ASSERT(!subscribed);
    }

    // The manager goes out of scope while subscription references and their
    // providers are still alive; the test passes if this sequence completes
    // without a failure.
    Y_UNIT_TEST(TestDestroyManagerFirst) {
        auto requester = MakeAtomicShared<TRequester>();

        // Declared outside the inner scope so that they outlive the manager.
        IEndpointSetSubscriberRef beforeStartRef, afterStartRef;
        TProvider beforeStartProvider, afterStartProvider;

        {
            TEndpointSetManager manager(requester, TestClientName);
            TEndpointSetKey key("test-cluster", "test-service");

            // One subscription is taken before the manager is started...
            beforeStartRef = manager.Subscribe(key, beforeStartProvider);
            UNIT_ASSERT(beforeStartRef->GetEndpointSet().endpoints_size() > 0);

            manager.Start(TDuration::MilliSeconds(10));

            // ...and one after.
            afterStartRef = manager.Subscribe(key, afterStartProvider);
            UNIT_ASSERT(afterStartRef->GetEndpointSet().endpoints_size() > 0);
        }
    }
}
