#include "manager.h"
#include "helpers.h"
#include "resolver.h"

#include <library/cpp/testing/unittest/registar.h>

#include <util/system/fs.h>
#include <util/generic/scope.h>

using namespace NYP::NServiceDiscovery;

Y_UNIT_TEST_SUITE(TErrorStat) {
    // Checks that exceptions thrown from a provider's update callback are
    // counted in ProviderUpdateErrors, both for the endpoint set and globally.
    Y_UNIT_TEST(TestProviderError) {
        TAtomic requestCount = 0;
        // Each resolve tags its endpoints with the request number ("fqdn1",
        // "fqdn2", ...), so consecutive results always differ.
        auto requester = MakeRequester([&requestCount](const TResolveRequestBatch&, TResolveResultBatch& result, TStatEnv&) {
            const auto rc = AtomicIncrement(requestCount);
            for (auto& r : result) {
                auto& endpoint = AddEndpoint(r);
                endpoint.set_fqdn(TString("fqdn") + ToString(rc));
            }
        });

        TEndpointSetManager manager(requester, TestClientName);

        TEndpointSetKey key("test-cluster", "test-service");
        // The provider accepts its first update and throws on every later one.
        manager.RegisterEndpointSet<TProvider>(key, [counter = 0](const NApi::TEndpointSet&) mutable {
            if (counter > 0) {
                throw yexception();
            }
            ++counter;
            return true;
        });

        manager.Start(TDuration::MilliSeconds(10));

        // Wait until at least a second resolve has started.
        Wait([&requestCount]() {
            return AtomicGet(requestCount) > 1;
        });

        manager.Stop();

        // Read the shared counter through AtomicGet (as the Wait predicate does)
        // instead of a plain volatile read of the TAtomic.
        const TAtomicBase requests = AtomicGet(requestCount);
        const auto& stat = manager.GetStat(key);

        // Every request except the first is expected to fail in the provider.
        UNIT_ASSERT_VALUES_EQUAL(stat.ProviderUpdateErrors + 1, requests);
        UNIT_ASSERT_VALUES_EQUAL(manager.GetStat().ProviderUpdateErrors + 1, requests);
    }

    // Checks that exceptions thrown by the requester are counted in
    // RequesterErrors, both for the endpoint set and globally.
    Y_UNIT_TEST(TestRequesterError) {
        TAtomic requestCount = 0;
        // The first resolve is delegated to DummyRequester; every later one throws.
        auto requester = MakeRequester([&requestCount](const TResolveRequestBatch& request, TResolveResultBatch& result, TStatEnv& stat) {
            if (AtomicIncrement(requestCount) > 1) {
                throw yexception();
            }

            DummyRequester->Resolve(request, result, stat);
        });

        TEndpointSetManager manager(requester, TestClientName);

        TEndpointSetKey key("test-cluster", "test-service");
        manager.RegisterEndpointSet<TProvider>(key);
        manager.Start(TDuration::MilliSeconds(10));
        const auto& stat = manager.GetStat(key);

        // Wait until at least one throwing resolve has been attempted.
        Wait([&requestCount]() {
            return AtomicGet(requestCount) > 1;
        });

        manager.Stop();

        // Read the shared counter through AtomicGet rather than a plain volatile read.
        const TAtomicBase requests = AtomicGet(requestCount);

        // Every request except the first is expected to have failed.
        UNIT_ASSERT_VALUES_EQUAL(stat.RequesterErrors + 1, requests);
        UNIT_ASSERT_VALUES_EQUAL(manager.GetStat().RequesterErrors + 1, requests);
    }

    // Checks that a requester whose initialization fails is initialized exactly
    // once and that the failure is reported in RequesterErrors.
    Y_UNIT_TEST(TestRequesterInitError) {
        TAtomic initAttempts = 0;
        auto requester = MakeRequester(
            // Resolve must never be reached in this scenario.
            [](const TResolveRequestBatch&, TResolveResultBatch&, TStatEnv&) {
                UNIT_ASSERT(false);
            },
            // Initialization records the attempt, then fails.
            [&initAttempts]() {
                AtomicIncrement(initAttempts);
                ythrow yexception() << "error";
            },
            TDuration::Max()
        );

        TEndpointSetManager manager(requester, TestClientName);

        TEndpointSetKey key("test-cluster", "test-service");
        manager.RegisterEndpointSet<TProvider>(key);
        const auto& keyStat = manager.GetStat(key);

        // If Start() throws, the error must name the unavailable endpoint set.
        // NOTE(review): a Start() that does not throw also passes this block.
        try {
            manager.Start(TDuration::MilliSeconds(10));
        } catch (...) {
            const auto message = CurrentExceptionMessage();
            UNIT_ASSERT(message.Contains("endpointset [test-cluster#test-service] unavailable"));
        }
        manager.Stop();

        const TAtomicBase attempts = AtomicGet(initAttempts);
        UNIT_ASSERT_VALUES_EQUAL(attempts, 1);
        UNIT_ASSERT_VALUES_EQUAL(keyStat.RequesterErrors.Get(), attempts);
        UNIT_ASSERT_VALUES_EQUAL(manager.GetStat().RequesterErrors.Get(), attempts);
    }

    // Checks that results flagged with HasError are counted in RemoteErrors
    // only for the affected endpoint set, and not as requester errors.
    Y_UNIT_TEST(TestRemoteError) {
        TEndpointSetKey key1("test-cluster", "test-service1");
        TEndpointSetKey key2("test-cluster", "test-service2");

        TAtomic requestCount = 0;
        // Results come from DummyRequester. Starting with the second request,
        // the entry whose endpoint_set_id matches key2 is flagged as an error.
        auto requester = MakeRequester([&requestCount, key2](const TResolveRequestBatch& request, TResolveResultBatch& result, TStatEnv& stat) {
            DummyRequester->Resolve(request, result, stat);

            if (AtomicIncrement(requestCount) > 1) {
                for (size_t i = 0; i < request.size(); ++i) {
                    if (request[i].endpoint_set_id() == key2.endpoint_set_id()) {
                        result[i].HasError = true;
                    }
                }
            }
        });

        TEndpointSetManager manager(requester, TestClientName);

        manager.RegisterEndpointSet<TProvider>(key1);
        manager.RegisterEndpointSet<TProvider>(key2);
        manager.Start(TDuration::MilliSeconds(10));

        const auto& stat1 = manager.GetStat(key1);
        const auto& stat2 = manager.GetStat(key2);

        // Wait until at least one request with a flagged key2 entry has started.
        Wait([&requestCount]() {
            return AtomicGet(requestCount) > 1;
        });

        manager.Stop();

        // Read the shared counter through AtomicGet rather than a plain volatile read.
        const TAtomicBase requests = AtomicGet(requestCount);

        // Remote errors must not be accounted as requester errors.
        UNIT_ASSERT_VALUES_EQUAL(stat1.RequesterErrors.Get(), 0);
        UNIT_ASSERT_VALUES_EQUAL(stat2.RequesterErrors.Get(), 0);

        // Only key2 was flagged, on every request but the first.
        UNIT_ASSERT_VALUES_EQUAL(stat1.RemoteErrors.Get(), 0);
        UNIT_ASSERT_VALUES_EQUAL(stat2.RemoteErrors + 1, requests);
        UNIT_ASSERT_VALUES_EQUAL(manager.GetStat().RemoteErrors + 1, requests);
    }

    // Checks that failures to write the on-disk cache are counted in
    // CacheStoreErrors and are not reported as CacheLoadErrors.
    Y_UNIT_TEST(CacheStoreError) {
        TEndpointSetKey key("test-cluster", "test-service");

        TAtomic resolveCount = 0;
        // Every resolve yields endpoints named after the request number.
        auto requester = MakeRequester([&resolveCount](const TResolveRequestBatch&, TResolveResultBatch& result, TStatEnv&) {
            const auto seqNo = AtomicIncrement(resolveCount);
            const TString fqdn = TString("fqdn") + ToString(seqNo);
            for (auto& entry : result) {
                AddEndpoint(entry).set_fqdn(fqdn);
            }
        });

        TFsPath cacheDir("./cache");
        NFs::RemoveRecursive(cacheDir.GetPath());
        cacheDir.MkDir();

        TEndpointSetManager manager(cacheDir.GetPath(), requester, TestClientName);
        manager.RegisterEndpointSet<TProvider>(key);
        manager.Start(TDuration::MilliSeconds(10));

        const auto& keyStat = manager.GetStat(key);

        // Phase 1: while the cache directory is writable, nothing fails.
        Wait([&resolveCount]() {
            return AtomicGet(resolveCount) > 0;
        });
        UNIT_ASSERT(keyStat.CacheStoreErrors == 0);

        // Phase 2: restrict the endpoint set's cache directory to owner r-x,
        // which is presumably enough to make subsequent stores fail.
        TFsPath endpointSetDir = cacheDir / EndpointSetDirName(key);
        endpointSetDir.MkDir();

        // Restore permissions on exit so ./cache can be removed by a later run.
        Y_DEFER {
            Chmod(endpointSetDir.GetPath().c_str(), MODE0777);
        };

        Chmod(endpointSetDir.GetPath().c_str(), S_IRUSR | S_IXUSR);

        const auto resolvesBeforeLock = AtomicGet(resolveCount);
        Wait([&resolveCount, resolvesBeforeLock]() {
            return AtomicGet(resolveCount) > resolvesBeforeLock;
        });

        manager.Stop();

        UNIT_ASSERT(keyStat.CacheStoreErrors > 0);
        UNIT_ASSERT_VALUES_EQUAL(manager.GetStat().CacheStoreErrors.Get(), keyStat.CacheStoreErrors.Get());
        UNIT_ASSERT_VALUES_EQUAL(keyStat.CacheLoadErrors.Get(), 0);
        UNIT_ASSERT_VALUES_EQUAL(manager.GetStat().CacheLoadErrors.Get(), keyStat.CacheLoadErrors.Get());
    }

    // Checks cache error accounting when the endpoint set's cache directory is
    // already restricted to owner r-x before the manager starts.
    Y_UNIT_TEST(CacheLoadError) {
        TEndpointSetKey key("test-cluster", "test-service");

        TAtomic requestCount = 0;
        // Only one endpoint set is registered, so result[0] is presumably its entry.
        auto requester = MakeRequester([&requestCount](const TResolveRequestBatch&, TResolveResultBatch& result, TStatEnv&) {
            const auto rc = AtomicIncrement(requestCount);
            auto& endpoint = AddEndpoint(result[0]);
            endpoint.set_fqdn(TString("fqdn") + ToString(rc));
        });

        TFsPath cacheDir("./cache");
        NFs::RemoveRecursive(cacheDir.GetPath());
        cacheDir.MkDir();

        TEndpointSetManager manager(cacheDir.GetPath(), requester, TestClientName);
        manager.RegisterEndpointSet<TProvider>(key);

        TFsPath endpointSetDir = cacheDir / EndpointSetDirName(key);
        endpointSetDir.MkDir();

        // Restore permissions on exit so ./cache can be removed by a later run.
        Y_DEFER {
            Chmod(endpointSetDir.GetPath().c_str(), MODE0777);
        };

        Chmod(endpointSetDir.GetPath().c_str(), S_IRUSR | S_IXUSR);

        manager.Start(TDuration::MilliSeconds(10));
        const auto& stat = manager.GetStat(key);

        Wait([&requestCount]() {
            return AtomicGet(requestCount) > 0;
        });

        manager.Stop();

        // Read the shared counter through AtomicGet rather than a plain volatile read.
        const TAtomicBase requests = AtomicGet(requestCount);

        // NOTE(review): the expected ratios (two store errors and one load error
        // per request) mirror the manager's cache handling; confirm against the
        // manager implementation if it changes.
        UNIT_ASSERT_VALUES_EQUAL(stat.CacheStoreErrors.Get(), 2 * requests);
        UNIT_ASSERT_VALUES_EQUAL(stat.CacheLoadErrors.Get(), requests);

        UNIT_ASSERT_VALUES_EQUAL(manager.GetStat().CacheStoreErrors.Get(), stat.CacheStoreErrors.Get());
        UNIT_ASSERT_VALUES_EQUAL(manager.GetStat().CacheLoadErrors.Get(), stat.CacheLoadErrors.Get());
    }

    // Checks that failures of the assigned log backend are counted in
    // the global LogWriteErrors statistic.
    Y_UNIT_TEST(LogWriteError) {
        // Log backend that fails on every write.
        class TThrowingLogBackend: public TLogBackend {
        private:
            void WriteData(const TLogRecord&) override {
                Y_ENSURE(false);
            }

            void ReopenLog() override {
            }
        };

        // Declared before the manager, so it is destroyed after the manager.
        TLog failingLog;
        TEndpointSetManager manager(DummyRequester, TestClientName);

        failingLog.ResetBackend(MakeHolder<TThrowingLogBackend>());
        manager.AssignLog(&failingLog);

        manager.Start(TDuration::MilliSeconds(10));

        const auto& globalStat = manager.GetStat();
        WaitUpdate(globalStat);
        UNIT_ASSERT(globalStat.LogWriteErrors > 0);
    }

    // Constructing a manager with an empty client name must throw yexception.
    Y_UNIT_TEST(EmptyClientName) {
        auto constructWithEmptyName = [] {
            TEndpointSetManager manager(DummyRequester, "");
        };

        UNIT_ASSERT_EXCEPTION(constructWithEmptyName(), yexception);
    }

    // Checks the ObsoleteTimestamp and InvalidEndpointSetErrors counters.
    // A result whose timestamp drops from 12345 to 0 must bump ObsoleteTimestamp.
    // A resolve that leaves the result untouched must bump
    // InvalidEndpointSetErrors without touching ObsoleteTimestamp.
    Y_UNIT_TEST(InvalidEndpointSet) {
        TEndpointSetKey key("test-cluster", "test-service");

        TAtomic resolveCount = 0;
        // When non-zero, the requester leaves the result untouched.
        TAtomic skipResult = 0;
        // Timestamp written into every populated result.
        TAtomic reportedTimestamp = 12345;
        auto requester = MakeRequester([&skipResult, &reportedTimestamp, &resolveCount](const TResolveRequestBatch&, TResolveResultBatch& result, TStatEnv&) {
            if (!AtomicGet(skipResult)) {
                const auto seqNo = AtomicIncrement(resolveCount);
                auto& entry = result[0];
                auto& endpoint = AddEndpoint(entry);
                endpoint.set_fqdn(TString("fqdn") + ToString(seqNo));
                endpoint.set_ip4_address("127.0.0.1");
                endpoint.set_port(80);

                entry.Result.set_timestamp(AtomicGet(reportedTimestamp));
            }
        });

        TEndpointSetManager manager(requester, TestClientName);
        manager.RegisterEndpointSet<TProvider>(key);
        manager.Start(TDuration::MilliSeconds(10));
        const auto& globalStat = manager.GetStat();

        // A well-formed result is accepted without errors.
        WaitUpdate(globalStat);
        UNIT_ASSERT(globalStat.InvalidEndpointSetErrors == 0);

        // A timestamp lower than the previous one (0 < 12345) is counted as obsolete.
        AtomicSet(reportedTimestamp, 0);
        WaitUpdate(globalStat);
        UNIT_ASSERT(globalStat.ObsoleteTimestamp > 0);

        // Return to the original timestamp before the next scenario.
        AtomicSet(reportedTimestamp, 12345);
        WaitUpdate(globalStat);

        // An untouched result counts as invalid, not as obsolete.
        const TAtomicBase obsoleteBefore = globalStat.ObsoleteTimestamp;
        AtomicSet(skipResult, 1);
        WaitUpdate(globalStat);
        UNIT_ASSERT(globalStat.ObsoleteTimestamp == obsoleteBefore);
        UNIT_ASSERT(globalStat.InvalidEndpointSetErrors > 0);
    }
}
