#include "channel.h"
#include "waked.h"

#include <balancer/kernel/testing/cont_executor.h>

#include <library/cpp/testing/unittest/registar.h>
#include <library/cpp/coroutine/engine/impl.h>
#include <util/random/random.h>
#include <util/thread/pool.h>

#include <functional>
#include <type_traits>
#include <utility>

using namespace NSrvKernel;
using namespace NSrvKernel::NTesting;

Y_UNIT_TEST_SUITE(Channel) {
    // Deadline used to check that Send on a full channel and Receive on an empty one time out.
    const TDuration timeout{TDuration::MilliSeconds(100)};

    void Sleep(TDuration duration, TCont* cont, TEventWaker* waker = nullptr) {
        Y_UNUSED(waker);
        cont->SleepT(duration);
    }

    /// Occasionally pauses the caller to shake up the interleaving of sender and receiver.
    ///
    /// On every 1000th iteration, with probability 0.1, sleeps for a random fraction of 5 ms.
    /// `waitArgs` are the channel wait arguments of the caller and are passed to Sleep unchanged:
    /// with (cont[, waker]) the coroutine Sleep of this suite is chosen; with no arguments the
    /// suite-local overload is not viable and the call resolves through ADL on TDuration to the
    /// global Sleep(TDuration) (presumably util's thread-blocking sleep).
    ///
    /// The pack is taken by value, so it is expanded as-is: std::forward is only meaningful for
    /// forwarding references.
    template <typename... Args>
    void RandomDelay(size_t iteration, Args... waitArgs) {
        if ((0 == iteration % 1000) && (RandomNumber<double>() < 0.1)) {
            Sleep(TDuration::MilliSeconds(5) * RandomNumber<double>(), waitArgs...);
        }
    }

    template<template <typename> typename TChannel>
    struct TEnv {
        explicit TEnv(size_t channelSize, EChannelOverflowMode overflowMode)
            : Size(channelSize)
            , Count(Max<size_t>(10000, channelSize * 1000))
            , Channel(channelSize, overflowMode)
        {
        }

        const size_t Size;
        const size_t Count;

        TChannel<int> Channel;

        TAutoEvent StartReceive;

        TThreadPool ThreadPool;

        TAtomic Received = 0;

        void SendFromThread() {
            ThreadPool.Start(1, 0);

            UNIT_ASSERT(ThreadPool.AddFunc([&] {
                Send([&]() {
                    StartReceive.Signal();
                });
            }));
        }

        void ReceiveFromThread() {
            StartReceive.WaitI();
            Receive();
        }

        void SendFromCoroutine() {
            ThreadPool.Start(1, 0);

            UNIT_ASSERT(ThreadPool.AddFunc([&] {
                TTestContExecutor contExecutor;
                auto send = [&](TCont* cont) {
                    SendFromCoroutine<typename TChannel<int>::TReceiveEvent>([&]() {
                        StartReceive.Signal();
                    }, cont, &contExecutor);
                };
                contExecutor.Create(send, "sender");
                contExecutor.Execute();
            }));
        }

        void ReceiveFromCoroutine() {
            TTestContExecutor contExecutor;
            auto receive = [&](TCont* cont) {
                ReceiveFromCoroutine<typename TChannel<int>::TSendEvent>(cont, &contExecutor);
            };

            contExecutor.Create(receive, "receiver");

            StartReceive.WaitI();
            contExecutor.Execute();
        }

        template <typename TEvent>
        void SendFromCoroutine(std::function<void()> startReceive, TCont* cont, TContExecutor* executor) {
            Y_UNUSED(executor);
            return Send(startReceive, cont);
        }

        template <>
        void SendFromCoroutine<TWakedEvent>(std::function<void()> startReceive, TCont* cont, TContExecutor* executor) {
            TEventWaker waker(executor);
            return Send(startReceive, cont, &waker);
        }

        template <>
        void SendFromCoroutine<TUnitedEvent>(std::function<void()> startReceive, TCont* cont, TContExecutor* executor) {
            Y_UNUSED(cont);
            Y_UNUSED(executor);
            return Send(startReceive);
        }

        template <typename TEvent>
        void ReceiveFromCoroutine(TCont* cont, TContExecutor* executor) {
            Y_UNUSED(executor);
            return Receive(cont);
        }

        template <>
        void ReceiveFromCoroutine<TWakedEvent>(TCont* cont, TContExecutor* executor) {
            TEventWaker waker(executor);
            return Receive(cont, &waker);
        }

        template <>
        void ReceiveFromCoroutine<TUnitedEvent>(TCont* cont, TContExecutor* executor) {
            Y_UNUSED(cont);
            Y_UNUSED(executor);
            return Receive();
        }

        template <typename F, typename... Args>
        void Send(F startReceive, Args... waitArgs) {
            for (size_t i = 0; i < Size; ++i) {
                UNIT_ASSERT_EQUAL(Channel.Send(i + 1, TInstant::Max(), std::forward<Args>(waitArgs)...), EChannelStatus::Success);
            }

            UNIT_ASSERT(!Channel.TrySend(1));
            UNIT_ASSERT_VALUES_EQUAL(Channel.InFly(), Size);

            auto now = Now();
            UNIT_ASSERT(Channel.Send(1, timeout.ToDeadLine(), std::forward<Args>(waitArgs)...) == EChannelStatus::TimedOut);
            UNIT_ASSERT(Now() - now > timeout);

            startReceive();

            for (size_t i = Size; i < Count; ++i) {
                UNIT_ASSERT_EQUAL(Channel.Send(i + 1, TInstant::Max(), std::forward<Args>(waitArgs)...), EChannelStatus::Success);
                UNIT_ASSERT(Channel.InFly() <= Size);
                RandomDelay(i, std::forward<Args>(waitArgs)...);
            }

            UNIT_ASSERT_EQUAL(Channel.Send(0, TInstant::Max(), std::forward<Args>(waitArgs)...), EChannelStatus::Success);
        }

        template <typename... Args>
        void Receive(Args... waitArgs) {
            size_t s = 0;

            int prev = 0;
            int x = 0;
            size_t i = 0;
            while (Channel.Receive(x, TInstant::Max(), std::forward<Args>(waitArgs)...) == EChannelStatus::Success && x > 0) {
                AtomicIncrement(Received);
                UNIT_ASSERT_VALUES_EQUAL(prev + 1, x);
                prev = x;
                s += x;

                ++i;
                RandomDelay(i, std::forward<Args>(waitArgs)...);
            }

            UNIT_ASSERT_VALUES_EQUAL(s, ((1 + Count) * Count) / 2);
            UNIT_ASSERT(!Channel.TryReceive(x));

            auto now = Now();
            UNIT_ASSERT(Channel.Receive(x, timeout.ToDeadLine(), std::forward<Args>(waitArgs)...) == EChannelStatus::TimedOut);
            UNIT_ASSERT(Now() - now >= timeout);
        }
    };

    auto TestCases() {
        TVector<std::pair<size_t, EChannelOverflowMode>> res;
        for (size_t size : {1, 100, 1000}) {
            for (EChannelOverflowMode mode : {EChannelOverflowMode::Poll, EChannelOverflowMode::Signal}) {
                res.emplace_back(size, mode);
            }
        }

        Y_VERIFY(res.size() == 6);

        return res;
    }

    Y_UNIT_TEST(Thread2Thread) {
        // Thread sender -> thread receiver, for every (capacity, overflow mode) case.
        const auto cases = TestCases();
        for (const auto& testCase : cases) {
            TEnv<TT2TChannel> env(testCase.first, testCase.second);
            env.SendFromThread();
            env.ReceiveFromThread();
            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(env.Received), env.Count);
        }
    }

    Y_UNIT_TEST(Thread2Coroutine) {
        // Thread sender -> coroutine receiver, for every (capacity, overflow mode) case.
        const auto cases = TestCases();
        for (const auto& testCase : cases) {
            TEnv<TT2CChannel> env(testCase.first, testCase.second);
            env.SendFromThread();
            env.ReceiveFromCoroutine();
            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(env.Received), env.Count);
        }
    }

    Y_UNIT_TEST(Coroutine2Thread) {
        // Coroutine sender -> thread receiver, for every (capacity, overflow mode) case.
        const auto cases = TestCases();
        for (const auto& testCase : cases) {
            TEnv<TC2TChannel> env(testCase.first, testCase.second);
            env.SendFromCoroutine();
            env.ReceiveFromThread();
            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(env.Received), env.Count);
        }
    }

    Y_UNIT_TEST(Coroutine2Coroutine) {
        // Coroutine sender -> coroutine receiver (separate executors), for every case.
        const auto cases = TestCases();
        for (const auto& testCase : cases) {
            TEnv<TC2CChannel> env(testCase.first, testCase.second);
            env.SendFromCoroutine();
            env.ReceiveFromCoroutine();
            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(env.Received), env.Count);
        }
    }

    Y_UNIT_TEST(Thread2Waked) {
        // Thread sender -> waked coroutine receiver, for every (capacity, overflow mode) case.
        const auto cases = TestCases();
        for (const auto& testCase : cases) {
            TEnv<TT2WChannel> env(testCase.first, testCase.second);
            env.SendFromThread();
            env.ReceiveFromCoroutine();
            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(env.Received), env.Count);
        }
    }

    Y_UNIT_TEST(Waked2Thread) {
        // Waked coroutine sender -> thread receiver, for every (capacity, overflow mode) case.
        const auto cases = TestCases();
        for (const auto& testCase : cases) {
            TEnv<TW2TChannel> env(testCase.first, testCase.second);
            env.SendFromCoroutine();
            env.ReceiveFromThread();
            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(env.Received), env.Count);
        }
    }

    Y_UNIT_TEST(Waked2Waked) {
        // Waked coroutine sender -> waked coroutine receiver (separate executors), for every case.
        const auto cases = TestCases();
        for (const auto& testCase : cases) {
            TEnv<TW2WChannel> env(testCase.first, testCase.second);
            env.SendFromCoroutine();
            env.ReceiveFromCoroutine();
            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(env.Received), env.Count);
        }
    }

    Y_UNIT_TEST(Thread2ThreadUnited) {
        // United channel driven from a thread sender and a thread receiver, for every case.
        const auto cases = TestCases();
        for (const auto& testCase : cases) {
            TEnv<TU2UChannel> env(testCase.first, testCase.second);
            env.SendFromThread();
            env.ReceiveFromThread();
            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(env.Received), env.Count);
        }
    }

    Y_UNIT_TEST(Thread2CoroutineUnited) {
        // United channel driven from a thread sender and a coroutine receiver, for every case.
        const auto cases = TestCases();
        for (const auto& testCase : cases) {
            TEnv<TU2UChannel> env(testCase.first, testCase.second);
            env.SendFromThread();
            env.ReceiveFromCoroutine();
            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(env.Received), env.Count);
        }
    }

    Y_UNIT_TEST(Coroutine2ThreadUnited) {
        // United channel driven from a coroutine sender and a thread receiver, for every case.
        const auto cases = TestCases();
        for (const auto& testCase : cases) {
            TEnv<TU2UChannel> env(testCase.first, testCase.second);
            env.SendFromCoroutine();
            env.ReceiveFromThread();
            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(env.Received), env.Count);
        }
    }

    Y_UNIT_TEST(Coroutine2CoroutineUnited) {
        // United channel driven from a coroutine sender and a coroutine receiver, for every case.
        const auto cases = TestCases();
        for (const auto& testCase : cases) {
            TEnv<TU2UChannel> env(testCase.first, testCase.second);
            env.SendFromCoroutine();
            env.ReceiveFromCoroutine();
            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(env.Received), env.Count);
        }
    }

    Y_UNIT_TEST(Coroutine2CoroutineSingleThread) {
        // Sender and receiver are two coroutines of one executor on the test thread; the sender
        // spawns the receiver coroutine from its startReceive callback.
        const auto cases = TestCases();
        for (const auto& testCase : cases) {
            TEnv<TC2CChannel> env(testCase.first, testCase.second);
            TTestContExecutor contExecutor;

            auto receiver = [&env](TCont* cont) {
                env.Receive(cont);
            };

            auto sender = [&](TCont* cont) {
                auto spawnReceiver = [&]() {
                    contExecutor.Create(receiver, "receiver");
                };
                env.Send(spawnReceiver, cont);
            };

            contExecutor.Create(sender, "sender");
            contExecutor.Execute();

            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(env.Received), env.Count);
        }
    }

    Y_UNIT_TEST(Waked2WakedSingleThread) {
        // Waked sender and waked receiver as two coroutines of one executor on the test thread;
        // the sender spawns the receiver coroutine from its startReceive callback.
        const auto cases = TestCases();
        for (const auto& testCase : cases) {
            TEnv<TW2WChannel> env(testCase.first, testCase.second);
            TTestContExecutor contExecutor;

            auto receiver = [&](TCont* cont) {
                env.ReceiveFromCoroutine<TWakedEvent>(cont, &contExecutor);
            };

            auto sender = [&](TCont* cont) {
                auto spawnReceiver = [&]() {
                    contExecutor.Create(receiver, "receiver");
                };
                env.SendFromCoroutine<TWakedEvent>(spawnReceiver, cont, &contExecutor);
            };

            contExecutor.Create(sender, "sender");
            contExecutor.Execute();

            UNIT_ASSERT_VALUES_EQUAL(AtomicGet(env.Received), env.Count);
        }
    }

    // Counters bumped by the special members of the probe type in CountCopyAndMove;
    // the test resets them before each step it measures.
    size_t copyCounter = 0;
    size_t moveCounter = 0;
    Y_UNIT_TEST(CountCopyAndMove) {
        // Probe type: every copy/move construction or assignment bumps the matching counter.
        // NOTE(review): the move operations are not noexcept. The expected counts below may depend
        // on that (e.g. if the channel uses move_if_noexcept); re-check them before changing it.
        struct T {
            T() = default;

            T(T&&) {
                ++moveCounter;
            }

            T(const T&) {
                ++copyCounter;
            }

            T& operator=(T&&) {
                ++moveCounter;
                return *this;
            }

            T& operator=(const T&) {
                ++copyCounter;
                return *this;
            }
        };

        TT2TChannel<T> c(1);
        T t;

        // Sending an lvalue: exactly one copy and no moves.
        copyCounter = 0;
        moveCounter = 0;
        UNIT_ASSERT(c.TrySend(t));
        UNIT_ASSERT_VALUES_EQUAL(copyCounter, 1);
        UNIT_ASSERT_VALUES_EQUAL(moveCounter, 0);
        // Receiving into the existing `t` adds one copy and one move.
        UNIT_ASSERT(c.TryReceive(t));
        UNIT_ASSERT_VALUES_EQUAL(copyCounter, 2);
        UNIT_ASSERT_VALUES_EQUAL(moveCounter, 1);
        // Sending an rvalue: exactly one move and no copies.
        copyCounter = 0;
        moveCounter = 0;
        UNIT_ASSERT(c.TrySend(std::move(t)));
        UNIT_ASSERT_VALUES_EQUAL(copyCounter, 0);
        UNIT_ASSERT_VALUES_EQUAL(moveCounter, 1);
        // Receiving again adds one copy and one move, as in the lvalue case.
        UNIT_ASSERT(c.TryReceive(t));
        UNIT_ASSERT_VALUES_EQUAL(copyCounter, 1);
        UNIT_ASSERT_VALUES_EQUAL(moveCounter, 2);
    }

    void StressTest(size_t size, EChannelOverflowMode overflowMode) {
        TT2WChannel<int> channel(size, overflowMode);
        TAtomic finished = 0;

        TThreadPool threadPool;
        threadPool.Start(10, 0);

        UNIT_ASSERT(threadPool.AddFunc([&] {
            while (!AtomicGet(finished)) {
                Y_UNUSED(channel.Send(1, TDuration::MilliSeconds(1).ToDeadLine()));
                Sleep(TDuration::MicroSeconds(100));
            }
        }));

        TTestContExecutor executor;

        auto receive = [&](TCont* cont) {
            size_t received = 0;
            size_t errors = 0;
            TEventWaker waker(&executor);
            while (received < 50000 || errors < 5000) {
                int x = 0;
                if (channel.Receive(x, TDuration::MicroSeconds(1).ToDeadLine(), cont, &waker) == EChannelStatus::Success) {
                    ++received;
                    UNIT_ASSERT_VALUES_EQUAL(x, 1);
                } else {
                    ++errors;
                }
            }

            errors = received = 0;

            while (received < 50000 || errors < 5000) {
                TEventWaker waker(&executor);
                int x = 0;
                if (channel.Receive(x, TDuration::MicroSeconds(1).ToDeadLine(), cont, &waker) == EChannelStatus::Success) {
                    ++received;
                    UNIT_ASSERT(x == 1);
                } else {
                    ++errors;
                }
            }
        };

        executor.Create(receive, "receive");
        executor.Execute();

        AtomicSet(finished, 1);
    }

    Y_UNIT_TEST(ThreadToWakedStressTestSignalMode) {
        // Small capacity so the channel frequently runs full under the stress sender.
        const size_t channelSize = 10;
        StressTest(channelSize, EChannelOverflowMode::Signal);
    }

    Y_UNIT_TEST(ThreadToWakedStressTestPollMode) {
        // Same scenario as the Signal-mode variant, with the Poll overflow mode.
        const size_t channelSize = 10;
        StressTest(channelSize, EChannelOverflowMode::Poll);
    }
}
