#include "server.h"

#include <balancer/server/ut/util/common.h>
#include <balancer/server/ut/util/env.h>
#include <balancer/server/ut/util/raw_socket_io.h>

#include <library/cpp/testing/unittest/tests_data.h>

#include <library/cpp/cgiparam/cgiparam.h>

#include <util/random/random.h>

using namespace NBalancerServer;
using namespace NSrvKernel;

Y_UNIT_TEST_SUITE(TestServerStreaming) {
    // Checks that a handler can stream a chunked response whose body is produced
    // incrementally, with and without response compression, over keep-alive and
    // non-keep-alive connections.
    //
    // Client and handler are paired through a per-request channel keyed by the
    // "ruid" CGI parameter: the client pushes a random string into the channel, the
    // handler forwards it as response data, and the client reads it back and compares.
    // An empty string pushed through the channel tells the handler to finish the reply.
    Y_UNIT_TEST(TestStreaming) {
        // State shared between client coroutines and server handlers.
        struct TChannels {
            // Source of per-request ids (ruid).
            TAtomic Counter = 0;
            // Channels registered by the client and not yet claimed by a handler.
            THashMap<size_t, TAtomicSharedPtr<TC2TChannel<TString>>> Channels;
            // Guards Channels.
            TMutex Lock;
        };

        for (size_t threads : {1, 5}) {
            TChannels c;

            // Server side: claim this request's channel and relay every non-empty string
            // received from it as a piece of the chunked response body.
            auto streamReplyCallback = [&c](THttpRequestEnv& env) {
                TQuickCgiParam cgi(env.Cgi());
                const size_t ruid = FromString<size_t>(cgi.Get("ruid"));

                TAtomicSharedPtr<TC2TChannel<TString>> channel;
                with_lock (c.Lock) {
                    // find() instead of operator[]: an unknown ruid must not insert an
                    // empty entry, and one lookup serves both the read and the erase.
                    auto it = c.Channels.find(ruid);
                    if (it != c.Channels.end()) {
                        channel = std::move(it->second);
                        c.Channels.erase(it);
                    }
                }
                // Fail explicitly rather than dereferencing a null channel below.
                UNIT_ASSERT_C(channel, "no channel registered for ruid " << ruid);

                NSrvKernel::TResponse response(200, "Ok");
                response.Props().ChunkedTransfer = true;

                auto reply = env.GetReplyTransport();

                reply->SendHead(std::move(response));

                TString v;
                while (channel->Receive(v, TInstant::Max()) == EChannelStatus::Success) {
                    if (!v) {
                        // Empty string is the client's end-of-stream marker.
                        break;
                    }
                    reply->SendData(v);
                }

                reply->SendEof();

                return TError();
            };

            // Client side: one coroutine opens one connection and issues up to five
            // streaming requests on it (only one when keep-alive was not chosen).
            auto streamRequestCallback = [&c](NTesting::TEnv& env) {
                TContExecutor e(128 * 1024, IPollerFace::Default(), nullptr, nullptr, NCoro::NStack::EGuard::Canary,
                                NCoro::NStack::TPoolAllocatorSettings{});

                const TVector<TString> codecs = {"deflate", "gzip", "br"};

                auto f = [&] (TCont* cont) {
                    NTesting::TRawSocketIO io("localhost", env.Port, cont, Now() + TDuration::Seconds(10));
                    const bool keepAlive = RandomNumber<double>() < 0.5;

                    for (size_t request = 0; request < 5; ++request) {
                        const size_t ruid = AtomicIncrement(c.Counter);
                        try {
                            // Register the channel before sending the request so the
                            // handler can always find it.
                            TAtomicSharedPtr<TC2TChannel<TString>> channel = MakeAtomicShared<TC2TChannel<TString>>(128u);
                            with_lock (c.Lock) {
                                c.Channels[ruid] = channel;
                            }

                            // Ask for a compressed response in roughly 20% of requests.
                            TString codec;
                            if (RandomNumber<double>() < 0.2) {
                                codec = codecs[RandomNumber<size_t>() % codecs.size()];
                            }

                            THttpOutput out(&io);
                            out.EnableKeepAlive(keepAlive);
                            out << "GET /?ruid=" << ruid << " HTTP/1.1\r\n";
                            if (codec) {
                                out << "Accept-Encoding: " << codec << "\r\n";
                            }
                            out << "\r\n";
                            out.Flush();
                            io.Flush();

                            THttpInput httpI(&io);

                            UNIT_ASSERT_VALUES_EQUAL(ParseHttpRetCode(httpI.FirstLine()), 200);

                            // Draw the piece count once: calling RandomNumber in the loop
                            // condition would re-roll the bound on every iteration.
                            const size_t pieces = RandomNumber<size_t>(128);
                            for (size_t piece = 0; piece < pieces; ++piece) {
                                const auto size = RandomNumber<size_t>(10 * 1024) + 1;
                                TString v = NUnitTest::RandomString(size, size);
                                UNIT_ASSERT_EQUAL(channel->Send(v, TInstant::Max(), cont), EChannelStatus::Success);
                                // Read back as many body bytes as were just pushed and compare.
                                TString r;
                                r.resize(v.size());
                                httpI.Load(const_cast<char*>(r.data()), r.size());
                                UNIT_ASSERT_VALUES_EQUAL(r, v);
                            }

                            if (codec) {
                                auto* encoding = httpI.Headers().FindHeader("Content-Encoding");
                                UNIT_ASSERT(encoding);
                                UNIT_ASSERT_VALUES_EQUAL(encoding->Value(), codec);
                            } else {
                                UNIT_ASSERT(!httpI.Headers().FindHeader("Content-Encoding"));
                            }

                            // Empty string makes the handler stop relaying and finish the reply.
                            UNIT_ASSERT_EQUAL(channel->Send("", TInstant::Max(), cont), EChannelStatus::Success);

                            // Nothing may remain after the data already read back.
                            UNIT_ASSERT_VALUES_EQUAL(httpI.ReadAll(), "");

                            if (!keepAlive) {
                                break;
                            }
                        } catch (...) {
                            UNIT_ASSERT_C(false, CurrentExceptionMessage());
                        }
                    }
                };

                e.Create(f, "request");
                e.Execute();
            };

            TOptions options;
            options.Threads = threads;
            options.EnableResponseCompression = true;

            // Presumably starts the server with streamReplyCallback and drives it with
            // streamRequestCallback — confirm against NTesting::RunConcurrent.
            NTesting::RunConcurrent<NTesting::TEnv>(
                streamReplyCallback,
                streamRequestCallback,
                options
            );
        }
    }

    // Checks that a POST body sent with "Transfer-Encoding: chunked" in randomly sized
    // pieces is handed to the echo handler intact: the response body must equal the
    // bytes that were posted.
    Y_UNIT_TEST(TestChunkedRequest) {
        TOptions options;

        for (size_t threads : {0, 1, 5}) {
            options.Threads = threads;

            NTesting::RunConcurrent<NTesting::TEnv>(
                NTesting::GetEchoCallback(),
                [&](NTesting::TEnv& env) {
                    auto f = [&] (TCont* cont) {
                        try {
                            const TString data = NTesting::Repeat("abcdef", RandomNumber<size_t>(100) + 1);

                            NTesting::TRawSocketIO io("localhost", env.Port, cont, Now() + TDuration::Seconds(10));
                            THttpOutput out(&io);
                            out << "POST /yandsearch?xxx=1 HTTP/1.1\r\n";
                            out << "Transfer-Encoding: chunked\r\n";
                            out << "\r\n";

                            // Send the body in pieces of 1..100 bytes, flushing after each
                            // one (presumably so every piece goes out as its own chunk).
                            for (size_t pos = 0; pos < data.size();) {
                                const size_t next = Min(data.size(), pos + 1 + RandomNumber<size_t>(100));
                                out << data.substr(pos, next - pos);
                                out.Flush();
                                io.Flush();
                                pos = next;
                            }
                            out.Finish();
                            io.Flush();

                            THttpInput httpI(&io);

                            UNIT_ASSERT_VALUES_EQUAL(ParseHttpRetCode(httpI.FirstLine()), 200);
                            UNIT_ASSERT_VALUES_EQUAL(httpI.ReadAll(), data);
                        } catch (...) {
                            UNIT_ASSERT_C(false, CurrentExceptionMessage());
                        }
                    };

                    TContExecutor e(128 * 1024, IPollerFace::Default(), nullptr, nullptr, NCoro::NStack::EGuard::Canary,
                                    NCoro::NStack::TPoolAllocatorSettings{});
                    e.Create(f, "request");
                    e.Execute();
                },
                options,
                100 // NOTE(review): presumably the number of client runs — confirm against RunConcurrent.
            );
        }
    }
}

