#include "smartqueue.h"
#include "smartqueue_ng.h"

#include <library/cpp/deprecated/atomic/atomic.h>
#include <library/cpp/logger/global/global.h>

#include <library/cpp/testing/unittest/registar.h>

#include <util/system/mutex.h>
#include <util/system/thread.h>
#include <util/system/event.h>
#include <util/generic/utility.h>
#include <util/random/fast.h>

#include <atomic>

using namespace NUtil;

class TSmartMtpQueueTest : public TTestBase {
    UNIT_TEST_SUITE(TSmartMtpQueueTest);
        UNIT_TEST(RejectAtStop)
        UNIT_TEST(StartAndStop)
        UNIT_TEST(Overflow)
        UNIT_TEST(QueueLimits)
        UNIT_TEST(TsrTest)
        UNIT_TEST(IdleInterval)
        UNIT_TEST(CompareQueues)
        UNIT_TEST(PerformanceTest)
        UNIT_TEST(ThreadsCreatedUnlimited)
        UNIT_TEST(ThreadsCreatedLimited)
    UNIT_TEST_SUITE_END();

protected:
    TMutex Lock;
    ui32 Running;
    i32 Value;
    TMutex StopLock;

    TAtomic ScheduledCounter = 0;

protected:
    struct TTask : public IObjectInQueue {
        TSmartMtpQueueTest *Test;
        i32 Value;

        TTask(TSmartMtpQueueTest *parent, i32 val)
            : Test(parent)
            , Value(val) {

        }
        void Process(void *) override {
            THolder<TTask> Deleter(this);
            {
                TGuard<TMutex> g(Test->Lock);

                Test->Value += Value;
                Test->Running++;
            }
            {
                TGuard<TMutex> stopg(Test->StopLock);
            }
            {
                TGuard<TMutex> g(Test->Lock);
                Test->Running--;
            }
        }
    };

    struct TRandomWaitTask : public IObjectInQueue {
        TReallyFastRng32& Generator;

        TRandomWaitTask(TReallyFastRng32& generator)
            : Generator(generator)
        {
        }

        void Process(void*) override {
            THolder<TRandomWaitTask> deleter(this);
            Sleep(TDuration::MicroSeconds(Generator.Uniform(333)));
        }
    };

    // Thread factory that delegates the real work to SystemThreadFactory() and
    // keeps statistics about the threads it hands out:
    //   * Created  - total number of DoCreate() calls;
    //   * InFly    - wrappers created and not yet destroyed;
    //   * MaxInFly - highest value InFly has reached.
    // All three counters are atomics, so the statistics stay consistent even if
    // threads are created or destroyed from several threads at once.
    class TCountedThreadPool : public IThreadFactory {
    public:
        // Highest number of simultaneously alive thread wrappers seen so far.
        ui64 GetMaxInFly() const {
            return AtomicGet(MaxInFly);
        }

        // Total number of thread wrappers created by this factory.
        ui64 GetCreated() const {
            return AtomicGet(Created);
        }

    private:
        // Wrapper around a thread obtained from SystemThreadFactory(); reports
        // its own destruction back to the owning factory.
        class TThread : public IThreadFactory::IThread {
        public:
            TThread(TCountedThreadPool& owner)
                : Owner(owner)
            {}

            ~TThread() override {
                TGuard<TMutex> g(Mutex);
                Slave.Reset(nullptr);
                Owner.EndThread();
            }

        private:
            void DoRun(IThreadFactory::IThreadAble* func) override {
                TGuard<TMutex> g(Mutex);
                Slave.Reset(SystemThreadFactory()->Run(func).Release());
            }

            void DoJoin() noexcept override {
                TGuard<TMutex> g(Mutex);
                if (Slave)
                    Slave->Join();
            }

            TCountedThreadPool& Owner;
            THolder<IThreadFactory::IThread> Slave;
            TMutex Mutex;
        };

        IThread* DoCreate() override {
            const auto inFly = AtomicIncrement(InFly);

            // Atomic "store maximum": retry until either another thread has
            // already published a value >= inFly or our CAS succeeds.
            auto seenMax = AtomicGet(MaxInFly);
            while (seenMax < inFly && !AtomicCas(&MaxInFly, inFly, seenMax)) {
                seenMax = AtomicGet(MaxInFly);
            }

            AtomicIncrement(Created);
            return new TThread(*this);
        }

        // Called from ~TThread when a wrapper goes away.
        void EndThread() {
            AtomicDecrement(InFly);
        }

        TAtomic InFly = 0;
        TAtomic Created = 0;
        TAtomic MaxInFly = 0;

    };

public:
    // Sets the initial values of the shared counters and calls
    // InitGlobalLog2Console(). The base class and both mutexes are
    // default-constructed implicitly.
    TSmartMtpQueueTest()
        : Running(0)
        , Value(1234)
    {
        InitGlobalLog2Console();
    }

    void RejectAtStop() {
        TSmartMtpQueue q;
        TTask* t = new TTask(this, 0);
        UNIT_ASSERT(!q.Add(t));
        q.Start(0);
        UNIT_ASSERT(q.Add(t));
    }

    // Adds `num` TRandomWaitTask jobs to `q` (each must be accepted) and then
    // checks that the queue's thread count lies within
    // [minThreadCount, maxThreadCount].
    void AddWaitTaskToQueue(TSmartMtpQueue& q, ui32 num, TReallyFastRng32& generator, size_t minThreadCount, size_t maxThreadCount) {
        for (ui32 added = 0; added < num; ++added) {
            UNIT_ASSERT(q.Add(new TRandomWaitTask(generator)));
        }

        const size_t threadCount = q.ThreadCount();
        UNIT_ASSERT(minThreadCount <= threadCount);
        UNIT_ASSERT(maxThreadCount >= threadCount);
    }

    void QueueLimits() {
        TReallyFastRng32 rand(Now().NanoSeconds());
        const size_t minThreadCount = 10;
        const size_t maxThreadCount = 50;
        TSmartMtpQueue q(NUtil::TSmartMtpQueue::TOptions(minThreadCount, TDuration::Zero()));
        q.SetMaxIdleTime(TDuration::MilliSeconds(1));
        q.Start(maxThreadCount);

        for (ui32 i = 0; i < 100; i++) {
            size_t toadd = rand.Uniform(maxThreadCount);
            AddWaitTaskToQueue(q, toadd, rand, minThreadCount, maxThreadCount);
        }
        q.Stop();
    }

    // Adds `num` TTask jobs carrying `value` to `q`; every Add() must succeed.
    void AddTasksToQueue(TSmartMtpQueue& q, ui32 num, i32 value = 1) {
        for (ui32 added = 0; added < num; ++added) {
            UNIT_ASSERT(q.Add(new TTask(this, value)));
        }
    }

    void Overflow() {
        Value = 0;
        TSmartMtpQueue q(NUtil::TSmartMtpQueue::TOptions(10, TDuration::Zero()));

        q.SetMaxIdleTime(TDuration::MicroSeconds(1));
        q.Start(20, 50);
        UNIT_ASSERT_EQUAL(q.ThreadCount(), 10);
        {
            TGuard<TMutex> g(StopLock);
            AddTasksToQueue(q, 10, 1);
            UNIT_ASSERT(q.ThreadCount() >= 10);
            Sleep(TDuration::MilliSeconds(100));
            UNIT_ASSERT_EQUAL(q.ThreadCount(), 10);
            AddTasksToQueue(q, 60, 1);
            UNIT_ASSERT_EQUAL(q.ThreadCount(), 20);
            UNIT_ASSERT_EQUAL(Value, i32(Running));
            Sleep(TDuration::MilliSeconds(20));
            UNIT_ASSERT(!q.Add(new TTask(this, 0)));
        }
        Sleep(TDuration::MilliSeconds(100));
        UNIT_ASSERT_EQUAL(q.ThreadCount(), 10);
        q.Stop();
        UNIT_ASSERT_EQUAL(Running, 0);
        UNIT_ASSERT_EQUAL(Value, 70);
    }

    void StartAndStop() {
        Value = 0;
        TSmartMtpQueue q(NUtil::TSmartMtpQueue::TOptions(10, TDuration::Zero()));

        TReallyFastRng32 randstart(Now().NanoSeconds());
        TReallyFastRng32 rand = randstart;

        for (ui32 i = 0; i < 10; ++i) {
            Value += rand.Uniform(1000);
        }

        q.Start(100);
        Sleep(TDuration::MilliSeconds(20));
        {
            TGuard<TMutex> g(StopLock);

            rand = randstart;
            for (ui32 i = 0; i < 10; ++i) {
                UNIT_ASSERT(q.Add(new TTask(this, -(i32)rand.Uniform(1000))));
            }

            Sleep(TDuration::MilliSeconds(20));
            UNIT_ASSERT_EQUAL(10, Running);

        }
        q.Stop();

        UNIT_ASSERT_EQUAL(0, Running);
        UNIT_ASSERT_EQUAL(0, Value);
    }

    // Scheduling benchmark for an already started `queue`.
    //
    // Ten producer threads share one TSchedulingTask and keep adding the same
    // TFakeTask to the queue until the fixture-wide ScheduledCounter reaches
    // `limit`. Afterwards the elapsed time, the scheduling rate and the
    // scheduled / failed / executed counters are printed to Cout, and the queue
    // is stopped.
    void TestQueue(IThreadPool& queue) {
        // Job executed by the queue: only counts how many times it has run.
        class TFakeTask : public IObjectInQueue {
            TAtomic& Counter;
        public:
            TFakeTask(TAtomic& cntr)
                : Counter(cntr)
            {
            }

            void Process(void* /*ThreadSpecificResource*/) override {
                AtomicIncrement(Counter);
            }
        };

        // Producer body shared by all producer threads.
        class TSchedulingTask {
        protected:
            TAtomic& Counter;
            TAtomic& Errors;
            IThreadPool& Queue;
            IObjectInQueue* Task;
            const ui32 Limit;
        public:
            TSchedulingTask(TAtomic& counter, TAtomic& errors, IThreadPool& queue, IObjectInQueue* task, ui64 limit)
                : Counter(counter)
                , Errors(errors)
                , Queue(queue)
                , Task(task)
                , Limit(limit)
            {}

            // Adds Task until the shared counter reaches Limit. The check and
            // the increment are separate steps, so with several producers the
            // final count may end up slightly above Limit.
            void Do() {
                while (Counter < Limit) {
                    if (!Queue.Add(Task))
                        AtomicIncrement(Errors);
                    AtomicIncrement(Counter);
                }
            }

            // Thread entry point: `pThis` is the TSchedulingTask to run.
            static void* DoStatic(void* pThis) {
                reinterpret_cast<TSchedulingTask*>(pThis)->Do();
                return nullptr;
            }
        };

        const ui32 testThreads = 10;
        const ui32 limit = 10000000;
        TAtomic ExecutionCounter = TAtomic();
        TAtomic ErrorsCounter = TAtomic();
        TFakeTask Execution(ExecutionCounter);
        TSchedulingTask Scheduling(ScheduledCounter, ErrorsCounter, queue, &Execution, limit);

        TVector<TSimpleSharedPtr<TThread>> threads;

        // Reset the shared counter before any producer thread starts.
        ScheduledCounter = 0;

        TInstant start = Now();
        for (size_t i = 0; i < testThreads; ++i) {
            threads.push_back(new TThread(TSchedulingTask::DoStatic, &Scheduling));
            threads.back()->Start();
        }

        for (auto&& t : threads) {
            t->Join();
        }
        TInstant finish = Now();

        // Only the producers are timed; the queue itself has not been stopped
        // yet, so ExecutionCounter may not include jobs that are still pending.
        const TDuration duration = finish - start;
        Cout << "Duration: " << duration << Endl;
        Cout << "Scheduling rate: " << ScheduledCounter / duration.SecondsFloat() << " per second" << Endl;
        Cout << "Total ops: scheduled " << ScheduledCounter << " (" << ErrorsCounter << " fails), executed " << ExecutionCounter << Endl;

        // Stop the queue before `Execution` (referenced by queued jobs) goes
        // out of scope.
        queue.Stop();
    }

    void PerformanceTest() {
        const ui32 queueThreadCount = 64;

        {
            TSmartMtpQueue smartQueue;
            smartQueue.Start(queueThreadCount);

            Cerr << "Smart Mtp queue performance:" << Endl;
            TestQueue(smartQueue);
        }
        {
            TThreadPool regularQueue;
            regularQueue.Start(queueThreadCount);

            Cerr << "Regular Mtp queue performance:" << Endl;
            TestQueue(regularQueue);
        }
    }

    // Sample thread-specific resource: holds two fields with known values so
    // that a job can verify it received an intact object.
    struct TSampleTsr {
    private:
        ui64 SampleNumber = 42;
        TString SampleString = "Sample string";

    public:
        // Asserts that both fields still hold their initial values.
        void Check() {
            UNIT_ASSERT(SampleNumber == 42);
            UNIT_ASSERT(SampleString == "Sample string");
        }
    };

    class TTsrCreator : public TSmartMtpQueue {
        TAtomic CreatedResources = 0;

    public:
        ~TTsrCreator() {
            UNIT_ASSERT_EQUAL(CreatedResources, 0);
        }

    private:
        void* CreateThreadSpecificResource () override {
            AtomicIncrement(CreatedResources);
            return new TSampleTsr;
        }

        void DestroyThreadSpecificResource(void* arg) override {
            delete (TSampleTsr*)arg;
            Sleep(TDuration::MilliSeconds(100));
            AtomicDecrement(CreatedResources);
        }
    };

    class TTsrCaller : public IObjectInQueue {
        void Process(void* tsr) override {
            ((TSampleTsr*)tsr)->Check();
        }
    };

    void TsrTest() {
        TTsrCreator q;
        q.Start(0);
        UNIT_ASSERT(q.AddAndOwn(MakeHolder<TTsrCaller>()));
        q.Stop();
    }

    class TSimpleTask : public IObjectInQueue {
        void Process(void*) override {
            {
                TGuard<TMutex> g(Mutex);
                ThreadCount++;
                MaxThreadCount = Max(ThreadCount, MaxThreadCount);
            }

            Sleep(TDuration::MilliSeconds(10));
            AtomicIncrement(Counter);

            {
                TGuard<TMutex> g(Mutex);
                ThreadCount--;
            }
        }

    public:
        TAtomic Counter = 0;

        TMutex Mutex;
        int ThreadCount = 0;
        int MaxThreadCount = 0;
    };

    // Adds the same TSimpleTask ten times to a TQueueType started with
    // Start(maxThreads, queueSize) and checks how many executions happened and
    // how many ran concurrently:
    //   * maxThreads == 0: every Add() must succeed, all 10 run, and at least
    //     one execution was observed (MaxThreadCount > 0);
    //   * queueSize == 0 (maxThreads > 0): all 10 run on at most maxThreads
    //     threads;
    //   * otherwise: fewer than 10 run, on at most maxThreads threads.
    template<class TQueueType>
    void TestAddBehaviour(size_t maxThreads, size_t queueSize) {
        // The task is declared before the queue on purpose: locals are
        // destroyed in reverse order, so if an assertion unwinds this function
        // before q.Stop(), the queue is destroyed first and workers can never
        // touch a dead task.
        TSimpleTask task;
        TQueueType q;
        q.Start(maxThreads, queueSize);
        for (size_t i = 0; i < 10; i++) {
            bool ok = q.Add(&task);
            UNIT_ASSERT(maxThreads > 0 || ok);
        }
        q.Stop();

        if (maxThreads == 0) {
            UNIT_ASSERT_EQUAL(task.Counter, 10);
            UNIT_ASSERT(static_cast<size_t>(task.MaxThreadCount) > maxThreads);
        }
        else if (queueSize == 0) {
            UNIT_ASSERT_EQUAL(task.Counter, 10);
            UNIT_ASSERT(static_cast<size_t>(task.MaxThreadCount) <= maxThreads);
        }
        else {
            UNIT_ASSERT(task.Counter < 10);
            UNIT_ASSERT(static_cast<size_t>(task.MaxThreadCount) <= maxThreads);
        }
    }

    // Runs the templated TestAddBehaviour check with identical limits against
    // both TSimpleThreadPool and TSmartMtpQueue.
    void TestAddBehaviour(size_t maxThreads, size_t queueSize) {
        TestAddBehaviour<TSimpleThreadPool>(maxThreads, queueSize);
        TestAddBehaviour<TSmartMtpQueue>(maxThreads, queueSize);
    }

    void CompareQueues() {
        for (size_t maxThreads = 0; maxThreads <= 9; maxThreads++)
            for (size_t queueSize = 0; queueSize <= 9 - maxThreads; queueSize++)
                TestAddBehaviour(maxThreads, queueSize);
    }

    // Reusable job that sleeps for `Time` and then increments `Counter`.
    class EmptyTask : public IObjectInQueue {
    public:
        // `time` is how long each execution sleeps; zero by default.
        EmptyTask(TDuration time = TDuration())
            : Time(time)
        {
        }

    public:
        TAtomic Counter = 0;
        TDuration Time;

    private:
        void Process(void*) override {
            Sleep(Time);
            AtomicIncrement(Counter);
        }
    };

    template<class TQueueType>
    void TestIdleInterval() {
        EmptyTask longTask(TDuration::Seconds(1));
        EmptyTask emptyTask;
        TQueueType q;
        q.Start(2);
        q.SetMaxIdleTime(TDuration::Seconds(1));
        UNIT_ASSERT(q.Add(&longTask));
        UNIT_ASSERT(q.Add(&emptyTask));
        Sleep(TDuration::MilliSeconds(100));
        UNIT_ASSERT(q.Add(&emptyTask));
        Sleep(TDuration::MilliSeconds(100));
        UNIT_ASSERT_EQUAL(emptyTask.Counter, 2);
        q.Stop();
    }

    // Runs the TestIdleInterval scenario against TAdaptiveThreadPool and
    // TSmartMtpQueue.
    void IdleInterval() {
        TestIdleInterval<TAdaptiveThreadPool>();
        TestIdleInterval<TSmartMtpQueue>();
    }

    // Submits five instant jobs, 500 ms apart, to a queue backed by a counting
    // thread factory and expects that exactly one thread was ever created and
    // that no more than one was alive at a time.
    void TestThreadsCreated(size_t maxThreads) {
        EmptyTask instantTask;
        TCountedThreadPool factory;

        TSmartMtpQueue queue(&factory);
        queue.Start(maxThreads);
        queue.SetMaxIdleTime(TDuration::Seconds(1));

        for (ui32 attempt = 0; attempt < 5; ++attempt) {
            UNIT_ASSERT(queue.Add(&instantTask));
            Sleep(TDuration::MilliSeconds(500));
        }
        queue.Stop();

        UNIT_ASSERT_EQUAL(factory.GetMaxInFly(), 1);
        UNIT_ASSERT_EQUAL(factory.GetCreated(), 1);
    }

    // Submits a 100 ms job and a 500 ms job together, then an instant job
    // 300 ms later, and expects the counting factory to report exactly two
    // threads created and at most two alive at once.
    void TestThreadsCreatedAdvanced(size_t maxThreads) {
        TCountedThreadPool factory;

        TSmartMtpQueue queue(&factory);
        queue.Start(maxThreads);
        queue.SetMaxIdleTime(TDuration::Seconds(1));

        UNIT_ASSERT(queue.AddAndOwn(MakeHolder<EmptyTask>(TDuration::MilliSeconds(100))));
        UNIT_ASSERT(queue.AddAndOwn(MakeHolder<EmptyTask>(TDuration::MilliSeconds(500))));

        Sleep(TDuration::MilliSeconds(300));
        UNIT_ASSERT(queue.AddAndOwn(MakeHolder<EmptyTask>()));

        queue.Stop();

        UNIT_ASSERT_EQUAL(factory.GetMaxInFly(), 2);
        UNIT_ASSERT_EQUAL(factory.GetCreated(), 2);
    }

    // Uses a queue created with TOptions(0) and a one-second max idle time.
    // Jobs are submitted at 0 ms, 500 ms and 2000 ms; the asserts expect two
    // threads to have been created overall but never more than one alive at
    // once (the 1500 ms pause exceeds the idle time).
    void TestThreadsDestroed(size_t maxThreads) {
        EmptyTask instantTask;
        TCountedThreadPool factory;

        TSmartMtpQueue queue(&factory, TSmartMtpQueue::TOptions(0));
        queue.Start(maxThreads);
        queue.SetMaxIdleTime(TDuration::Seconds(1));

        UNIT_ASSERT(queue.Add(&instantTask));
        Sleep(TDuration::MilliSeconds(500));

        UNIT_ASSERT(queue.Add(&instantTask));
        Sleep(TDuration::MilliSeconds(1500));

        UNIT_ASSERT(queue.Add(&instantTask));
        queue.Stop();

        UNIT_ASSERT_EQUAL(factory.GetMaxInFly(), 1);
        UNIT_ASSERT_EQUAL(factory.GetCreated(), 2);
    }

    // Runs the three thread-accounting scenarios with Start(0).
    void ThreadsCreatedUnlimited() {
        TestThreadsCreated(0);
        TestThreadsCreatedAdvanced(0);
        TestThreadsDestroed(0);
    }

    // Runs the three thread-accounting scenarios with explicit thread limits
    // (1, 5 and 1 respectively).
    void ThreadsCreatedLimited() {
        TestThreadsCreated(1);
        TestThreadsCreatedAdvanced(5);
        TestThreadsDestroed(1);
    }

};
UNIT_TEST_SUITE_REGISTRATION(TSmartMtpQueueTest)


class TSmartMtpQueueNgTest : public TSmartMtpQueueTest {
    UNIT_TEST_SUITE(TSmartMtpQueueNgTest);
        UNIT_TEST(PerformanceTest)
        UNIT_TEST(LoadTest)
    UNIT_TEST_SUITE_END();
public:
    // Stat handler that dumps queue statistics to Cerr. Only every
    // PrintFreq-th call is printed (every call with the default PrintFreq of 1).
    class TStatPrinter : public TSmartMtpQueueNg::IStatHandler {
    public:
        TMutex Mutex;
        int Counter = 0;
        int PrintFreq = 1;

        void OnStats(const TSmartMtpQueueNg::IStatHandler::TStats& s) override {
            // Serializes both the call counter and the output.
            TGuard<TMutex> guard(Mutex);

            const int callIndex = Counter++;
            if (callIndex % PrintFreq != 0) {
                return;
            }

            Cerr << "threads=" << s.NumThreads
                 << ", busy=" << s.NumBusyThreads
                 << ", adviced=" << s.NumAdviced
                 << ", inFly=" << s.InFly
                 << ", trgFly=" << s.TargetInFly
                 << ", period=" << s.AdvicePeriod
                 << ", queueSz=" << s.InQueue
                 << ", nOp=" << s.nOperations << Endl;
        }
    };
public:
    // Runs the TestQueue() scheduling benchmark against three queue
    // implementations in turn - TSmartMtpQueueNg, TSmartMtpQueue and
    // TThreadPool - each limited to 16 threads.
    void PerformanceTest() {
        const ui32 queueThreadCount = 16;
        {
            // Options are chosen so that MinThreads + MaxThreadsAdd equals the
            // thread limit; the Y_ASSERT below checks that relation.
            TSmartMtpQueueNg::TOptions opts;
            opts.MinThreads = 1;
            opts.MaxThreadsAdd = queueThreadCount - opts.MinThreads;
            opts.JobsPerWorker = 10;
            opts.BurstJobsPerWorker = 100;
            Y_ASSERT(opts.MaxThreads() == queueThreadCount);
            TSmartMtpQueueNg smartQueue(opts);

            // Statistics reported by the queue are printed to Cerr.
            TSmartMtpQueueNg::TStatHandlerRef printer = MakeIntrusive<TStatPrinter>();
            smartQueue.SetStatHandler(printer);
            smartQueue.Start(queueThreadCount, queueThreadCount * 100000);

            Cerr << "TSmartMtpQueueNg performance:" << Endl;
            TestQueue(smartQueue);
            Cerr << Endl;
        }
        {
            TSmartMtpQueue smartQueue;
            smartQueue.Start(queueThreadCount);

            Cerr << "Smart Mtp queue performance:" << Endl;
            TestQueue(smartQueue);
            Cerr << Endl;
        }
        {
            TThreadPool regularQueue;
            regularQueue.Start(queueThreadCount);

            Cerr << "Regular Mtp queue performance:" << Endl;
            TestQueue(regularQueue);
            Cerr << Endl;
        }
    }

#define FU_COEFF 2  // 8 for 100% cpu load in LoadTest
    // Runs the TestQueueLoad() CPU-load scenario against TSmartMtpQueueNg,
    // TSmartMtpQueue and TThreadPool in turn. The thread limit is 16 when
    // FU_COEFF <= 4 and 64 otherwise.
    void LoadTest() {
        const ui32 queueThreadCount = (FU_COEFF <= 4) ? 16 : 64;
        {
            // Options are chosen so that MinThreads + MaxThreadsAdd equals the
            // thread limit; the Y_ASSERT below checks that relation.
            TSmartMtpQueueNg::TOptions opts;
            opts.MinThreads = 2;
            opts.MaxThreadsAdd = queueThreadCount - opts.MinThreads;
            opts.JobsPerWorker = 1.5f;
            opts.BurstJobsPerWorker = 4.0f;
            Y_ASSERT(opts.MaxThreads() == queueThreadCount);
            TSmartMtpQueueNg smartQueue(opts);
            // Statistics reported by the queue are printed to Cerr.
            TSmartMtpQueueNg::TStatHandlerRef printer = MakeIntrusive<TStatPrinter>();
            smartQueue.SetStatHandler(printer);
            smartQueue.Start(queueThreadCount, queueThreadCount * 10);

            Cerr << "TSmartMtpQueueNg performance:" << Endl;
            TestQueueLoad(smartQueue);
            Cerr << Endl;
        }
        {
            TSmartMtpQueue::TOptions opts(2);
            TSmartMtpQueue smartQueue(opts);
            smartQueue.Start(queueThreadCount);

            Cerr << "Smart Mtp queue performance:" << Endl;
            TestQueueLoad(smartQueue);
            Cerr << Endl;
        }
        {
            TThreadPool regularQueue;
            regularQueue.Start(queueThreadCount);

            Cerr << "Regular Mtp queue performance:" << Endl;
            TestQueueLoad(regularQueue);
            Cerr << Endl;
        }
    }

    // Busy-work helper: performs n * 1000 rounds of a small multiply-and-modulo
    // recurrence seeded from the current time and returns the final value.
    static int TestLoadIter(ui32 n) {
        int state = Now().MilliSeconds() % 256;
        const ui32 rounds = n * 1000;
        for (ui32 round = 0; round < rounds; ++round) {
            state = (state * 119) % 113;
        }
        return state;
    }

    // CPU-load scenario for an already started `queue`.
    //
    // 15 * FU_COEFF producer threads share one TSchedulingTask. Each producer
    // repeatedly adds the same THashLoadTask to the queue, sleeps for a random
    // time below 300 ms (gen.Uniform(300)) and burns some CPU itself, until
    // either the shared ScheduledCounter reaches the limit or the stop flag is
    // raised. The main thread waits up to 40 seconds for the first producer to
    // finish, raises the stop flag, joins all producers, prints statistics to
    // Cout and stops the queue.
    void TestQueueLoad(IThreadPool& queue) {
        // Job executed by the queue: burns CPU in TestLoadIter and counts runs.
        class THashLoadTask : public IObjectInQueue {
            TAtomic& Counter;
        public:
            THashLoadTask(TAtomic& cntr)
                : Counter(cntr)
            {
            }

            void Process(void* /*ThreadSpecificResource*/) override {
                // The result is stored into a static atomic; presumably this
                // keeps the computation from being optimized away.
                static TAtomic fakeOutput;
                AtomicSet(fakeOutput, TestLoadIter(4800));
                AtomicIncrement(Counter);
            }
        };

        // Producer body shared by all producer threads.
        class TSchedulingTask {
        protected:
            std::atomic<bool>* StopFlag;
            TManualEvent& Completed;
            TAtomic& Counter;
            TAtomic& Errors;
            IThreadPool& Queue;
            IObjectInQueue* Task;
            const ui32 Limit;
        public:
            TSchedulingTask(std::atomic<bool>* stopFlag, TManualEvent& ev, TAtomic& counter, TAtomic& errors, IThreadPool& queue, IObjectInQueue* task, ui64 limit)
                : StopFlag(stopFlag)
                , Completed(ev)
                , Counter(counter)
                , Errors(errors)
                , Queue(queue)
                , Task(task)
                , Limit(limit)
            {}

            // Adds Task until the shared counter reaches Limit or the stop flag
            // is set, then signals Completed.
            void Do() {
                TReallyFastRng32 gen(Now().NanoSeconds());
                while (Counter < Limit && !*StopFlag) {
                    if (!Queue.Add(Task))
                        AtomicIncrement(Errors);
                    AtomicIncrement(Counter);
                    Sleep(TDuration::MilliSeconds(gen.Uniform(300)));
                    TestLoadIter(gen.Uniform(800));
                }
                Completed.Signal();
            }

            // Thread entry point: `pThis` is the TSchedulingTask to run.
            static void* DoStatic(void* pThis) {
                reinterpret_cast<TSchedulingTask*>(pThis)->Do();
                return nullptr;
            }
        };

        const ui32 testThreads = 15 * FU_COEFF;
        const ui32 limit = 2000 * FU_COEFF;
        TAtomic ExecutionCounter = TAtomic();
        TAtomic ErrorsCounter = TAtomic();
        TManualEvent Completed;
        std::atomic<bool> stopFlag = false;
        THashLoadTask Execution(ExecutionCounter);
        TSchedulingTask Scheduling(&stopFlag, Completed, ScheduledCounter, ErrorsCounter, queue, &Execution, limit);

        TVector<TSimpleSharedPtr<TThread>> threads;

        // Reset the shared counter before any producer thread starts.
        ScheduledCounter = 0;

        TInstant start = Now();
        for (size_t i = 0; i < testThreads; ++i) {
            threads.push_back(new TThread(TSchedulingTask::DoStatic, &Scheduling));
            threads.back()->Start();
        }

        // Wait until some producer signals completion, or for 40 seconds at
        // most, then ask every producer to stop.
        Completed.WaitT(TDuration::Seconds(40));

        stopFlag = true;

        for (auto&& t : threads) {
            t->Join();
        }
        TInstant finish = Now();

        // The queue has not been stopped yet, so ExecutionCounter may not
        // include jobs that are still pending.
        const TDuration duration = finish - start;
        Cout << "Duration: " << duration << Endl;
        Cout << "Scheduling rate: " << ScheduledCounter / duration.SecondsFloat() << " per second" << Endl;
        Cout << "Total ops: scheduled " << ScheduledCounter << " (" << ErrorsCounter << " fails), executed " << ExecutionCounter << Endl;

        // Stop the queue before `Execution` (referenced by queued jobs) goes
        // out of scope.
        queue.Stop();
    }

};

UNIT_TEST_SUITE_REGISTRATION(TSmartMtpQueueNgTest)
