#include <balancer/kernel/coro/channel.h>
#include <balancer/kernel/coro/waked.h>

#include <library/cpp/testing/benchmark/bench.h>

#include <library/cpp/coroutine/engine/impl.h>
#include <util/thread/pool.h>

#include <type_traits>

using namespace NSrvKernel;

namespace {
    template <template <typename T> class TChannel>
    struct TEnv {
        TEnv(size_t size, size_t count)
            : Count(count)
            , Channel(size)
        {
            ThreadPool.Start(1, 0);
        }

        const size_t Count;
        TChannel<TString> Channel;

        TThreadPool ThreadPool;

        void SendFromThread() {
            Y_VERIFY(ThreadPool.AddFunc([&] {
                Send();
            }));
        }

        size_t ReceiveFromThread() {
            return Receive();
        }

        void SendFromCoroutine() {
            Y_VERIFY(ThreadPool.AddFunc([&] {
                TContExecutor contExecutor(32000);
                const auto send = [&](TCont* cont) {
                    SendFromCoroutine<typename TChannel<TString>::TReceiveEvent>(cont, &contExecutor);
                };
                contExecutor.Create(send, "sender");
                contExecutor.Execute();
            }));
        }

        template<typename TEvent>
        void SendFromCoroutine(TCont* cont, TContExecutor* executor) {
            Y_UNUSED(executor);
            return Send(cont);
        }

        template <>
        void SendFromCoroutine<TWakedEvent>(TCont* cont, TContExecutor* executor) {
            TEventWaker waker(executor);
            return Send(cont, &waker);
        }

        template <>
        void SendFromCoroutine<TUnitedEvent>(TCont* cont, TContExecutor* executor) {
            Y_UNUSED(executor);
            Y_UNUSED(cont);
            return Send();
        }

        size_t ReceiveFromCoroutine() {
            size_t res = 0;
            TContExecutor contExecutor(32000);
            const auto receive = [&](TCont* cont) {
                res = ReceiveFromCoroutine<typename TChannel<TString>::TSendEvent>(cont, &contExecutor);
            };

            contExecutor.Create(receive, "receiver");
            contExecutor.Execute();
            return res;
        }

        template <typename TEvent>
        size_t ReceiveFromCoroutine(TCont* cont, TContExecutor* executor) {
            Y_UNUSED(executor);
            return Receive(cont);
        }

        template <>
        size_t ReceiveFromCoroutine<TWakedEvent>(TCont* cont, TContExecutor* executor) {
            TEventWaker waker(executor);
            return Receive(cont, &waker);
        }

        template <>
        size_t ReceiveFromCoroutine<TUnitedEvent>(TCont* cont, TContExecutor* executor) {
            Y_UNUSED(cont);
            Y_UNUSED(executor);
            return Receive();
        }

        template <typename... Args>
        void Send(Args... waitArgs) {
            TString x = "1";
            for (size_t i = 0; i < Count; ++i) {
                Y_VERIFY(Channel.Send(x, TInstant::Max(), std::forward<Args>(waitArgs)...) == EChannelStatus::Success);
            }
            Y_VERIFY(Channel.Send("", TInstant::Max(), std::forward<Args>(waitArgs)...) == EChannelStatus::Success);
        }

        template <typename... Args>
        size_t Receive(Args... waitArgs) {
            size_t i = 0;
            TString x;
            while (Channel.Receive(x, TInstant::Max(), std::forward<Args>(waitArgs)...) == EChannelStatus::Success && x) {
                ++i;
            }
            return i;
        }
    };
}

// Messages per batch (excluding the end-of-batch marker). Typed as size_t because
// it is passed as TEnv's count and compared with size_t receive counts; with
// `const auto` it was an int and every comparison was signed/unsigned.
constexpr size_t BATCH_SIZE = 1000;

/// Runs the producer and the consumer as two coroutines on one executor in the
/// calling thread, transferring one batch of BATCH_SIZE messages per benchmark
/// iteration and verifying the received count.
///
/// @param channelSize value passed to the channel constructor.
/// @param iface       benchmark interface; Iterations() gives the iteration count.
template <template <typename T> class TChannel, typename IFace>
void Coroutine2CoroutineSingleThreaded(size_t channelSize, IFace& iface) {
    TEnv<TChannel> env(channelSize, BATCH_SIZE);
    TContExecutor contExecutor(32000);

    // Exactly one batch per iteration, matching every other benchmark in this file.
    // The previous version had a second, shadowing loop that ran Iterations()^2 batches.
    for (size_t i = 0; i < iface.Iterations(); ++i) {
        size_t received = 0;
        const auto receive = [&](TCont* cont) {
            received = env.template ReceiveFromCoroutine<typename TChannel<TString>::TSendEvent>(cont, &contExecutor);
        };
        const auto send = [&](TCont* cont) {
            env.template SendFromCoroutine<typename TChannel<TString>::TReceiveEvent>(cont, &contExecutor);
        };
        contExecutor.Create(receive, "receiver");
        contExecutor.Create(send, "sender");
        contExecutor.Execute();
        Y_VERIFY(received == BATCH_SIZE);
    }
}

// Defines the full benchmark matrix for one channel size `channelSize` (the value
// passed to the channel constructor). Each benchmark moves BATCH_SIZE messages per
// iteration and verifies that the consumer counted exactly BATCH_SIZE of them.
// Names read <Sender>2<Receiver>. The channel type is chosen to match:
// Thread/Coroutine/Waked pairs use TT2T, TT2C, TC2T, TC2C, TT2W, TW2T and TW2W, and
// all "United" variants use TU2UChannel.
// Keep comments out of the macro body: a `//` comment ending in a line
// continuation would swallow the next line.
#define BENCHMARK(channelSize) \
Y_CPU_BENCHMARK(Thread2Thread##channelSize, iface) { \
    TEnv<TT2TChannel> env(channelSize, BATCH_SIZE); \
    for (size_t i = 0; i < iface.Iterations(); ++i) { \
        env.SendFromThread(); \
        Y_VERIFY(BATCH_SIZE == env.ReceiveFromThread()); \
    } \
} \
Y_CPU_BENCHMARK(Thread2Coroutine##channelSize, iface) { \
    TEnv<TT2CChannel> env(channelSize, BATCH_SIZE); \
    for (size_t i = 0; i < iface.Iterations(); ++i) { \
        env.SendFromThread(); \
        Y_VERIFY(BATCH_SIZE == env.ReceiveFromCoroutine()); \
    } \
} \
Y_CPU_BENCHMARK(Coroutine2Thread##channelSize, iface) { \
    TEnv<TC2TChannel> env(channelSize, BATCH_SIZE); \
    for (size_t i = 0; i < iface.Iterations(); ++i) { \
        env.SendFromCoroutine(); \
        Y_VERIFY(BATCH_SIZE == env.ReceiveFromThread()); \
    } \
} \
Y_CPU_BENCHMARK(Coroutine2Coroutine##channelSize, iface) { \
    TEnv<TC2CChannel> env(channelSize, BATCH_SIZE); \
    for (size_t i = 0; i < iface.Iterations(); ++i) { \
        env.SendFromCoroutine(); \
        Y_VERIFY(BATCH_SIZE == env.ReceiveFromCoroutine()); \
    } \
} \
Y_CPU_BENCHMARK(Thread2Waked##channelSize, iface) { \
    TEnv<TT2WChannel> env(channelSize, BATCH_SIZE); \
    for (size_t i = 0; i < iface.Iterations(); ++i) { \
        env.SendFromThread(); \
        Y_VERIFY(BATCH_SIZE == env.ReceiveFromCoroutine()); \
    } \
} \
Y_CPU_BENCHMARK(Waked2Thread##channelSize, iface) { \
    TEnv<TW2TChannel> env(channelSize, BATCH_SIZE); \
    for (size_t i = 0; i < iface.Iterations(); ++i) { \
        env.SendFromCoroutine(); \
        Y_VERIFY(BATCH_SIZE == env.ReceiveFromThread()); \
    } \
} \
Y_CPU_BENCHMARK(Waked2Waked##channelSize, iface) { \
    TEnv<TW2WChannel> env(channelSize, BATCH_SIZE); \
    for (size_t i = 0; i < iface.Iterations(); ++i) { \
        env.SendFromCoroutine(); \
        Y_VERIFY(BATCH_SIZE == env.ReceiveFromCoroutine()); \
    } \
} \
Y_CPU_BENCHMARK(Thread2ThreadUnited##channelSize, iface) { \
    TEnv<TU2UChannel> env(channelSize, BATCH_SIZE); \
    for (size_t i = 0; i < iface.Iterations(); ++i) { \
        env.SendFromThread(); \
        Y_VERIFY(BATCH_SIZE == env.ReceiveFromThread()); \
    } \
} \
Y_CPU_BENCHMARK(Thread2CoroutineUnited##channelSize, iface) { \
    TEnv<TU2UChannel> env(channelSize, BATCH_SIZE); \
    for (size_t i = 0; i < iface.Iterations(); ++i) { \
        env.SendFromThread(); \
        Y_VERIFY(BATCH_SIZE == env.ReceiveFromCoroutine()); \
    } \
} \
Y_CPU_BENCHMARK(Coroutine2ThreadUnited##channelSize, iface) { \
    TEnv<TU2UChannel> env(channelSize, BATCH_SIZE); \
    for (size_t i = 0; i < iface.Iterations(); ++i) { \
        env.SendFromCoroutine(); \
        Y_VERIFY(BATCH_SIZE == env.ReceiveFromThread()); \
    } \
} \
Y_CPU_BENCHMARK(Coroutine2CoroutineUnited##channelSize, iface) { \
    TEnv<TU2UChannel> env(channelSize, BATCH_SIZE); \
    for (size_t i = 0; i < iface.Iterations(); ++i) { \
        env.SendFromCoroutine(); \
        Y_VERIFY(BATCH_SIZE == env.ReceiveFromCoroutine()); \
    } \
} \
Y_CPU_BENCHMARK(Coroutine2CoroutineSingleThreaded##channelSize, iface) { \
    Coroutine2CoroutineSingleThreaded<TC2CChannel>(channelSize, iface); \
} \
Y_CPU_BENCHMARK(Coroutine2CoroutineSingleThreadedWaked##channelSize, iface) { \
    Coroutine2CoroutineSingleThreaded<TW2WChannel>(channelSize, iface); \
} \
Y_CPU_BENCHMARK(Coroutine2CoroutineSingleThreadedUnited##channelSize, iface) { \
    Coroutine2CoroutineSingleThreaded<TU2UChannel>(channelSize, iface); \
}

// Instantiate the benchmark matrix for channel sizes 1, 10, 100 and 1000
// (the value each TEnv passes to its channel constructor).
BENCHMARK(1)
BENCHMARK(10)
BENCHMARK(100)
BENCHMARK(1000)
