#include <solomon/services/fetcher/lib/data_sink/limiter.h>

#include <library/cpp/testing/gtest/gtest.h>
#include <library/cpp/monlib/metrics/metric_registry.h>

#include <util/random/random.h>

#include <thread>

using namespace testing;
using namespace NMonitoring;
using namespace NSolomon::NFetcher;

// Metric registry shared by every test in this file. It exists only so that the
// limiters can be handed a gauge; none of the tests read metric values back.
TMetricRegistry DUMMY_REGISTRY;
// Gauge passed to every limiter constructed below. It is obtained from
// DUMMY_REGISTRY, so it must stay declared after the registry: namespace-scope
// variables in one translation unit are initialized in definition order.
TIntGauge* DUMMY_COUNTER = DUMMY_REGISTRY.IntGauge(TLabels{{}});

// Exercises one shard limiter (constructed with 10) on top of a global limiter
// (constructed with 20) and checks which writes are accepted or rejected.
TEST(TLimiterTest, SingleShard) {
    // Both scenarios start from an identical, freshly created global limiter.
    const auto makeGlobal = [] {
        return TIntrusivePtr<TGlobalLimiter>{new TGlobalLimiter{20, DUMMY_COUNTER}};
    };

    // Scenario 1: two writes of 4 are accepted, the third one is rejected,
    // and after freeing 4 a write of 4 is accepted again.
    {
        TIntrusivePtr<TGlobalLimiter> shared = makeGlobal();
        TShardLimiter shard{10, DUMMY_COUNTER, shared};

        ASSERT_TRUE(shard.OnWrite(4));
        ASSERT_TRUE(shard.OnWrite(4));
        ASSERT_FALSE(shard.OnWrite(4));

        shard.OnFree(4);
        ASSERT_TRUE(shard.OnWrite(4));
    }

    // Scenario 2: a single write of 20 on a fresh shard limiter is rejected.
    {
        TIntrusivePtr<TGlobalLimiter> shared = makeGlobal();
        TShardLimiter shard{10, DUMMY_COUNTER, shared};

        ASSERT_FALSE(shard.OnWrite(20));
    }
}

// Two shard limiters (each constructed with 10) share one global limiter
// (constructed with 12); the test tracks the global Available() value while
// the shards write and free.
TEST(TLimiterTest, TwoShards) {
    TIntrusivePtr<TGlobalLimiter> shared{new TGlobalLimiter{12, DUMMY_COUNTER}};
    TShardLimiter shardA{10, DUMMY_COUNTER, shared};
    TShardLimiter shardB{10, DUMMY_COUNTER, shared};

    // The first shard takes 8, leaving 4 available globally.
    ASSERT_TRUE(shardA.OnWrite(8));
    ASSERT_THAT(shared->Available(), 4);

    // The second shard's write of 8 is rejected.
    ASSERT_FALSE(shardB.OnWrite(8));

    // Freeing on the shard whose write was rejected must not change the
    // global availability, and its next write of 8 is still rejected.
    shardB.OnFree(8);
    ASSERT_THAT(shared->Available(), 4);
    ASSERT_FALSE(shardB.OnWrite(8));

    // Once the first shard frees its 8, the full 12 is available again
    // and the second shard's write of 8 is accepted.
    shardA.OnFree(8);
    ASSERT_THAT(shared->Available(), 12);
    ASSERT_TRUE(shardB.OnWrite(8));
}

// Stress test: several threads, each with its own shard limiter, randomly
// write to and free from one shared global limiter. The global limiter is
// wrapped so that every call verifies Current_ <= Limit_.
TEST(TLimiterTest, MultipleShardsThreaded) {
    static constexpr int THREAD_COUNT = 10;
    static constexpr int ITERATIONS_PER_THREAD = 10000;

    // Global limiter (constructed with 12) that re-checks the
    // Current_ <= Limit_ condition after each delegated operation.
    struct TInvariantCheckingLimiter: TGlobalLimiter {
        TInvariantCheckingLimiter()
            : TGlobalLimiter{12, DUMMY_COUNTER}
        {
        }

        bool OnWrite(ui64 bytes) override {
            const bool accepted = TGlobalLimiter::OnWrite(bytes);
            Y_VERIFY(Current_ <= Limit_);
            return accepted;
        }

        void OnFree(ui64 bytes) override {
            TGlobalLimiter::OnFree(bytes);
            Y_VERIFY(Current_ <= Limit_);
        }
    };

    TIntrusivePtr<TGlobalLimiter> checked{new TInvariantCheckingLimiter};

    // Per-thread workload: the pointer to the global limiter is captured by
    // copy, and each thread drives its own shard limiter (constructed with 20).
    auto worker = [checked] {
        TShardLimiter shard{20, DUMMY_COUNTER, checked};

        for (int step = 0; step != ITERATIONS_PER_THREAD; ++step) {
            // RandomNumber<ui32>(2) selects between the two operations;
            // results of the calls are intentionally not inspected here.
            const bool doWrite = (RandomNumber<ui32>(2) & 1) != 0;
            if (doWrite) {
                shard.OnWrite(1);
            } else {
                shard.OnFree(1);
            }
        }
    };

    TVector<std::thread> pool;
    for (int n = 0; n != THREAD_COUNT; ++n) {
        pool.emplace_back(worker);
    }

    // Wait for every worker to finish before the test returns.
    for (auto& t: pool) {
        t.join();
    }
}
