#include <solomon/agent/misc/countdown_event.h>

#include <solomon/libs/cpp/threading/pool/pool.h>

#include <library/cpp/testing/gtest/gtest.h>

#include <util/datetime/base.h>
#include <util/generic/scope.h>
#include <util/system/spinlock.h>

using namespace NSolomon::NAgent;

// Append-only vector of worker ids that several pool threads write into.
// Every append is serialized by VectorLock_; `vec` is read directly by the test
// once all writers have finished.
struct TConcurrentVector {
    TAdaptiveLock VectorLock_;
    TVector<size_t> vec;

    // Appends `num` to `vec` while holding VectorLock_.
    void Add(size_t num) {
        with_lock (VectorLock_) {
            vec.emplace_back(num);
        }
    }
};

class TDelayPolicy {
public:
    static std::atomic<ui64> SleepTime_;

    static void Init() {
        SleepTime_ = 500;
    }

    // Taken from https://a.yandex-team.ru/arc/trunk/arcadia/util/system/spin_wait.cpp?rev=4351651#L11
    static inline TDuration RandomizeSleepTime() noexcept {
        static std::atomic<ui64> counter = 0;
        const ui64 rndNum = IntHash(counter.fetch_add(1) + 1);

        // max value: 24000 microseconds
        ui64 sleepTimeValue = SleepTime_;
        ui64 value = (sleepTimeValue * 4ull + (rndNum % sleepTimeValue) * 2ull) / 5ull;
        return TDuration::MicroSeconds(value);
    }

    static void RandomSleep() {
        Sleep(RandomizeSleepTime());
        ui64 sleepTimeValue = SleepTime_;
        SleepTime_ = Min<ui64>((sleepTimeValue * 3) / 2, 20000);
    }
};

std::atomic<ui64> TDelayPolicy::SleepTime_ = 0;

// Builds a pool task that sleeps for one second and then records `thisThreadNum`
// in `data`. Y_SCOPE_EXIT decrements `cntEvent` when the task's scope is left.
// `data` and `cntEvent` are captured by reference and must outlive the task.
std::function<void()> CreateThreadFunc(
        TConcurrentVector& data,
        TCountdownEventImpl<TDelayPolicy>& cntEvent,
        size_t thisThreadNum)
{
    auto task = [&data, &cntEvent, id = thisThreadNum]() {
        Y_SCOPE_EXIT(&cntEvent) {
            cntEvent.Dec();
        };

        Sleep(TDuration::Seconds(1));
        data.Add(id);
    };
    return task;
}

    // A fresh counter has nothing outstanding, so once it is stopped Await()
    // must succeed rather than block.
    TEST(TCountdownEventTest, DoNotBlockOnZero) {
        TDelayPolicy::Init();

        TCountdownEventImpl<TDelayPolicy> event;
        event.Stop();

        const bool completed = event.Await();
        ASSERT_TRUE(completed);
    }

    // Await(timeout) reports false while a task still holds the counter and true
    // once that task has released it within the allowed time.
    TEST(TCountdownEventTest, Timeout) {
        TDelayPolicy::Init();

        TCountdownEvent counter;
        auto pool = NSolomon::CreateThreadPool(/*threads*/2, 0);

        ASSERT_TRUE(counter.TryInc());
        const bool scheduled = pool->AddFunc([&counter]() {
            Y_SCOPE_EXIT(&counter) {
                counter.Dec();
            };
            // Longer than the first Await() below, shorter than both combined.
            Sleep(TDuration::Seconds(2));
        });
        ASSERT_TRUE(scheduled);

        counter.Stop();

        ASSERT_FALSE(counter.Await(TDuration::Seconds(1)));
        ASSERT_TRUE(counter.Await(TDuration::Seconds(2)));
    }

    // After Stop(), TryInc() and every TryAdd() (including TryAdd(0)) must fail.
    TEST(TCountdownEventTest, IncOnStoppedCntFails) {
        TDelayPolicy::Init();

        TCountdownEvent cnt;
        // Tracks the checking task itself. `cnt` never holds a pending count, so
        // cnt.Await() below returns without waiting for the task. Without this,
        // nothing in the test waits for the task's assertions, and they would
        // depend on how the pool behaves on destruction.
        TCountdownEvent taskDone;
        auto p = NSolomon::CreateThreadPool(/*threads*/2, 0);

        ASSERT_TRUE(taskDone.TryInc());
        bool added = p->AddFunc([&cnt, &taskDone]() {
            Y_SCOPE_EXIT(&taskDone) {
                taskDone.Dec();
            };

            // Give the main thread time to call cnt.Stop() first.
            Sleep(TDuration::Seconds(1));

            ASSERT_FALSE(cnt.TryInc());
            ASSERT_FALSE(cnt.TryAdd(0));
            ASSERT_FALSE(cnt.TryAdd(1));
            ASSERT_FALSE(cnt.TryAdd(2));
        });
        ASSERT_TRUE(added);

        cnt.Stop();
        ASSERT_TRUE(cnt.Await(TDuration::Seconds(2)));

        // Wait until the task's assertions have actually run.
        taskDone.Stop();
        ASSERT_TRUE(taskDone.Await(TDuration::Seconds(5)));
    }

    // Dec() on a counter that was never incremented must throw yexception.
    TEST(TCountdownEventTest, BelowZeroException) {
        TDelayPolicy::Init();

        TCountdownEvent emptyCounter;
        ASSERT_THROW(emptyCounter.Dec(), yexception);
    }

    // Await() must return only after every worker has decremented the counter.
    // Each worker appends its id (1..numOfThreads) before its Dec(), and the main
    // thread appends its marker after Await(). The marker must therefore be the
    // last element, and no worker slot may contain it.
    TEST(TCountdownEventTest, WaitAllThreads) {
        TDelayPolicy::Init();

        TCountdownEventImpl<TDelayPolicy> cntEvent;
        TConcurrentVector data;

        // Worker ids start at 1, so 0 never collides with a worker entry.
        const size_t thisThreadNum = 0;
        const uint64_t numOfThreads = 16;
        auto p = NSolomon::CreateThreadPool(numOfThreads, 0);

        for (size_t i = 0; i != numOfThreads; ++i) {
            ASSERT_TRUE(cntEvent.TryInc());

            bool added = p->AddFunc(CreateThreadFunc(data, cntEvent, i + 1));
            ASSERT_TRUE(added);
        }

        cntEvent.Stop();
        ASSERT_TRUE(cntEvent.Await());
        data.Add(thisThreadNum);

        // Reading `vec` without the lock assumes Await() synchronizes with the
        // workers' Dec() calls. TODO confirm against TCountdownEvent's implementation.
        // Check the size first so the indexing below cannot go out of bounds.
        ASSERT_EQ(data.vec.size(), numOfThreads + 1);
        for (size_t i = 0; i != numOfThreads; ++i) {
            ASSERT_NE(data.vec[i], thisThreadNum);
        }
        ASSERT_EQ(data.vec[numOfThreads], thisThreadNum);
    }

    // Checks that TCountdownEvent can be move-constructed.
    // - The moved-to cnt2 keeps the pending count: 2 from TryAdd, +1 from TryInc,
    //   released by Sub(3).
    // - The moved-from cnt1 is still used afterwards: it is incremented, stopped
    //   and awaited.
    // NOTE(review): the results of TryAdd/TryInc/Await are not checked here; the test
    // relies on them succeeding. Consider asserting them.
    // NOTE(review): `p` is declared before the counters, so it is destroyed after them.
    // If an assertion fails while a task is still running, the task may touch a
    // destroyed counter.
    TEST(TCountdownEventTest, CntIsMovable) {
        TDelayPolicy::Init();

        auto p = NSolomon::CreateThreadPool(/*threads*/2, 0);
        TCountdownEvent cnt1;

        ASSERT_FALSE(cnt1.IsStopped());
        // Pending count becomes 2 (result not checked).
        cnt1.TryAdd(2);

        TCountdownEvent cnt2 = std::move(cnt1);

        ASSERT_FALSE(cnt2.IsStopped());

        // Deliberate use of the moved-from cnt1. Presumably the move leaves it as an
        // empty, non-stopped counter, so this TryInc succeeds. TODO confirm. If it
        // failed, the Dec() in the task below would throw, as BelowZeroException shows.
        cnt1.TryInc();
        // cnt2 now holds 2 + 1 = 3, matched by Sub(3) in the second task.
        cnt2.TryInc();

        cnt1.Stop();
        ASSERT_TRUE(p->AddFunc([&cnt1]() {
            Y_SCOPE_EXIT(&cnt1) {
                cnt1.Dec();
            };

            Sleep(TDuration::Seconds(1));
        }));
        cnt1.Await();

        cnt2.Stop();

        ASSERT_TRUE(cnt1.IsStopped());
        // Stop() alone does not make IsStopped() true while cnt2 still has a pending
        // count. Presumably IsStopped() means "stopped and drained to zero"; confirm.
        ASSERT_FALSE(cnt2.IsStopped());

        ASSERT_TRUE(p->AddFunc([&cnt2]() {
            Y_SCOPE_EXIT(&cnt2) {
                cnt2.Sub(3);
            };

            Sleep(TDuration::Seconds(1));
        }));
        cnt2.Await();


        ASSERT_TRUE(cnt1.IsStopped());
        ASSERT_TRUE(cnt2.IsStopped());
    }
