#include <infra/netmon/library/helpers.h>
#include <infra/netmon/library/scheduler.h>

#include <library/cpp/testing/unittest/registar.h>

#include <util/generic/xrange.h>
#include <util/generic/list.h>

using namespace NNetmon;

namespace {

    // Scheduled task that increments an atomic counter every time Run() is
    // executed; the tests read the counter to check that Run() was invoked.
    class TCounterTask: public TScheduledTask {
    public:
        // Forwards `interval` (and the flag `true`) to TScheduledTask.
        // NOTE(review): the meaning of the boolean flag is defined by
        // TScheduledTask and is not visible in this file — confirm.
        // `explicit`: a bare TDuration should not silently convert into a task.
        explicit TCounterTask(const TDuration& interval)
            : TScheduledTask(interval, true)
            , Counter(0)
        {
        }

        // Submits the increment to the shared NNetmon thread pool and returns
        // the future produced by the pool for that job.
        // NOTE(review): the closure captures `this`, so the task must outlive
        // the returned future; the tests assume the scheduler/guard waits for
        // it before the task is destroyed — confirm in TScheduledTask.
        NNetmon::TThreadPool::TFuture Run() override {
            return NNetmon::TThreadPool::Get()->Add([this]() {
                Counter.Inc();
            });
        }

        // Number of increments performed by Run(); starts at 0.
        TAtomicCounter Counter;
    };

    // Probe task: records the wall-clock instant of its first Run() into a
    // promise so the test can compare it against interval boundaries.
    class TRoundToIntervalTask : public TScheduledTask {
    public:
        // Uses the TRoundToIntervalTag constructor of TScheduledTask with the
        // given interval and a default-constructed (zero) TDuration as the
        // third argument.
        explicit TRoundToIntervalTask(const TDuration& interval)
            : TScheduledTask(TScheduledTask::TRoundToIntervalTag{}, interval, TDuration())
            , Promise(NThreading::NewPromise<TInstant>())
        {
        }

        // Captures the current instant and returns an already-ready future.
        // The task is periodic, so Run() may be called more than once while a
        // guard is alive; TrySetValue keeps the first timestamp and turns later
        // calls into no-ops instead of throwing because the promise is already set.
        NNetmon::TThreadPool::TFuture Run() override {
            Promise.TrySetValue(TInstant::Now());
            return NThreading::MakeFuture();
        }

        // Future that becomes ready with the instant of the first Run().
        NThreading::TFuture<TInstant> GetFuture() const {
            return Promise.GetFuture();
        }

    private:
        NThreading::TPromise<TInstant> Promise;
    };
}

class TSchedulerTest: public TTestBase {
    UNIT_TEST_SUITE(TSchedulerTest);
    UNIT_TEST(TestGuard)
    UNIT_TEST(TestStress)
    UNIT_TEST(TestRoundToInterval)
    UNIT_TEST_SUITE_END();

private:
    inline void TestGuard() {
        TCounterTask task(TDuration::MilliSeconds(1));
        auto guard(task.Schedule());
        task.SpinAndWait().Wait();
        UNIT_ASSERT(task.Counter.Val() >= 1);
    }

    inline void TestStress() {
        TList<TCounterTask> tasks;
        TList<TCounterTask::TTaskGuard> guards;
        for (auto index : xrange(100)) {
            tasks.emplace_back(TDuration::MilliSeconds(10));
            guards.emplace_back(tasks.back().Schedule());
            Y_UNUSED(index);
        }
        TScheduler::Get()->CreateWaiter().Wait();
        guards.clear();
        for (const auto& task : tasks) {
            UNIT_ASSERT(task.Counter.Val() >= 1);
        }
    }

    inline void TestRoundToInterval() {
        auto interval = TDuration::Seconds(1);

        for (int index : xrange(3)) {
            Y_UNUSED(index);
            TRoundToIntervalTask task(interval);

            {
                auto guard = task.Schedule();
                task.GetFuture().Wait();
            }

            auto executedAt = task.GetFuture().GetValue();
            TInstant intervalEndpoint = RoundInstant(executedAt, interval);
            auto diff = Min(executedAt - intervalEndpoint,
                            intervalEndpoint + interval - executedAt);
            UNIT_ASSERT(diff < TDuration::MilliSeconds(250));
            Sleep(TDuration::MilliSeconds(250));
        }
    }
};

// Registers the TSchedulerTest suite with the unittest framework.
UNIT_TEST_SUITE_REGISTRATION(TSchedulerTest);
