#include "server.h"

#include <balancer/server/ut/util/common.h>
#include <balancer/server/ut/util/env.h>
#include <balancer/server/ut/util/raw_socket_io.h>

#include <util/random/random.h>
#include <util/stream/file.h>
#include <util/stream/zlib.h>

using namespace NBalancerServer;
using namespace NSrvKernel;
using namespace NTesting;

Y_UNIT_TEST_SUITE(TestEncoding) {
    // Writes an SSL config containing a single "default" context to `out`.
    // The certificate and private key paths are built from ArcadiaSourceRoot()
    // plus fixed paths under balancer/test/plugin/certs/data.
    void MakeSslConfig(IOutputStream& out) {
        const TString sourceRoot = ArcadiaSourceRoot();
        const TString certFile = sourceRoot + "/balancer/test/plugin/certs/data/default.crt";
        const TString keyFile = sourceRoot + "/balancer/test/plugin/certs/data/default.key";

        out << "contexts = {\n"
            << "  default = {\n"
            << "     cert = " << certFile.Quote() << ";\n"
            << "     priv = " << keyFile.Quote() << ";\n"
            << "     ciphers = \"kEECDH:kRSA+AES128:kRSA:+3DES:RC4:!aNULL:!eNULL:!MD5:!EXPORT:!LOW:!SEED:!CAMELLIA:!IDEA:!PSK:!SRP:!SSLv2\";\n"
            << "  };\n"
            << "};\n";
    }

    // Sends POST requests with random bodies over HTTPS, with the SSL config
    // loaded from a file, and expects every response body to equal the request
    // body. The scenario is repeated for Threads = 0, 1 and 5.
    Y_UNIT_TEST(TestSSL) {
        {
            // Scoped so the file stream is destroyed before the path is used.
            TFileOutput cfgFile("./ssl.cfg");
            MakeSslConfig(cfgFile);
        }

        TOptions options;
        options.SslConfigPath = "./ssl.cfg";

        for (const size_t threadCount : {0, 1, 5}) {
            options.Threads = threadCount;

            auto scenario = [&](NTesting::TEnv& env) {
                auto request = [&](TCont* cont) {
                    try {
                        // Body length is in [1, 10 KiB]; the length doubles as the RandomString seed.
                        const size_t bodySize = RandomNumber<size_t>(10 * 1024) + 1;
                        const TString body = NUnitTest::RandomString(bodySize, bodySize);

                        const TString echoed = SendHttpsRequest(
                            cont, env.SslPort, "POST /yandsearch?xxx=1 HTTP/1.1", body).Data;

                        UNIT_ASSERT_VALUES_EQUAL(echoed, body);
                    } catch (...) {
                        // Report any exception from the request as a test failure.
                        UNIT_ASSERT_C(false, CurrentExceptionMessage());
                    }
                };

                TContExecutor executor(128 * 1024, IPollerFace::Default(), nullptr, nullptr,
                                       NCoro::NStack::EGuard::Canary,
                                       NCoro::NStack::TPoolAllocatorSettings{});
                executor.Create(request, "request");
                executor.Execute();
            };

            NTesting::RunConcurrent<NTesting::TEnv>(NTesting::GetEchoCallback(), scenario, options, 100);
        }
    }

    // Checks that the server starts and serves one HTTPS request when the SSL
    // config is supplied in three ways:
    //   1. as a file path (SslConfigPath),
    //   2. as inline text (SslConfig),
    //   3. as inline text produced by NConfig::TConfig::FromLua(...).DumpLua(...).
    Y_UNIT_TEST(TestSSLConfigs) {
        // Starts an environment with `options` and performs a single HTTPS POST
        // with an empty body from a coroutine.
        auto check = [](const TOptions& options) {
            NTesting::TEnv env;
            env.Start(NTesting::GetEchoCallback(), options);

            auto f = [&] (TCont* cont) {
                try {
                    SendHttpsRequest(cont, env.SslPort, "POST /yandsearch?xxx=1 HTTP/1.1", "");
                } catch (...) {
                    // Same guard as the other coroutines in this suite: record a
                    // test failure instead of letting the exception leave the
                    // coroutine entry function.
                    UNIT_ASSERT_C(false, CurrentExceptionMessage());
                }
            };

            TContExecutor e(128 * 1024, IPollerFace::Default(), nullptr, nullptr, NCoro::NStack::EGuard::Canary,
                            NCoro::NStack::TPoolAllocatorSettings{});
            e.Create(f, "request");
            e.Execute();
        };

        {// config passed as a path to a file
            TOptions options;

            {
                // Scoped so the file stream is destroyed before the path is used.
                TFileOutput fo("./ssl.cfg");
                MakeSslConfig(fo);
            }

            options.SslConfigPath = "./ssl.cfg";

            check(options);
        }

        {// config passed as inline text
            TOptions options;
            TStringStream ss;
            MakeSslConfig(ss);
            options.SslConfig = ss.Str();
            check(options);
        }

        {//check NConfig::TConfig::DumpLua
            TOptions options;

            // Wrap the config into an "instance" table, parse it and dump it
            // back; the dumped text is what gets passed as SslConfig.
            TStringStream ss;
            ss << "instance = {\n";
            MakeSslConfig(ss);
            ss << "};\n";
            TStringInput si(ss.Str());

            TStringStream ss2;
            NConfig::TConfig::FromLua(si).DumpLua(ss2);
            options.SslConfig = ss2.Str();
            check(options);
        }
    }

    // Alternates between the plain HTTP port and the SSL port (chosen by the
    // parity of a shared request counter) and expects every response body to
    // equal the random request body. Request body decompression is disabled.
    // Repeated for Threads = 0, 1 and 5.
    Y_UNIT_TEST(TestSSLTwoPorts) {
        TStringStream sslConfig;
        MakeSslConfig(sslConfig);

        TOptions options;
        options.SslConfig = sslConfig.Str();
        options.DecompressRequestBody = false;

        TAtomic reqCounter = 0;
        for (const size_t threadCount : {0, 1, 5}) {
            options.Threads = threadCount;

            auto scenario = [&](NTesting::TEnv& env) {
                auto request = [&](TCont* cont) {
                    try {
                        // Body length is in [1, 10 KiB]; the length doubles as the RandomString seed.
                        const size_t bodySize = RandomNumber<size_t>(10 * 1024) + 1;
                        const TString body = NUnitTest::RandomString(bodySize, bodySize);

                        // Odd counter values go over plain HTTP, even ones over HTTPS.
                        const bool overPlainHttp = (AtomicIncrement(reqCounter) % 2) != 0;
                        const TString echoed = overPlainHttp
                            ? SendHttpRequest(cont, env.Port,
                                  "POST /yandsearch?xxx=1 HTTP/1.1",
                                  body,
                                  TString(),
                                  TString()
                              ).Data
                            : SendHttpsRequest(cont, env.SslPort,
                                  "POST /yandsearch?xxx=1 HTTP/1.1",
                                  body,
                                  TString(),
                                  TString()
                              ).Data;

                        UNIT_ASSERT_VALUES_EQUAL(echoed, body);
                    } catch (...) {
                        // Report any exception from the request as a test failure.
                        UNIT_ASSERT_C(false, CurrentExceptionMessage());
                    }
                };

                TContExecutor executor(128 * 1024, IPollerFace::Default(), nullptr, nullptr,
                                       NCoro::NStack::EGuard::Canary,
                                       NCoro::NStack::TPoolAllocatorSettings{});
                executor.Create(request, "request");
                executor.Execute();
            };

            NTesting::RunConcurrent<NTesting::TEnv>(
                NTesting::GetEchoCallback(),
                scenario,
                options,
                100,
                true //useTwoPorts
            );
        }
    }

    // Sends random bodies using one of the "deflate", "gzip" or "br" request
    // codecs, over either the plain HTTP port or the SSL port, and expects an
    // un-encoded response whose body equals the original (uncompressed) data.
    // Repeated for Threads = 0, 1 and 5.
    Y_UNIT_TEST(TestRequestDecompression) {
        TOptions options;

        TStringStream ss;
        MakeSslConfig(ss);
        options.SslConfig = ss.Str();

        TVector<TString> codecs = {"deflate", "gzip", "br"};

        TAtomic reqCounter = 0;
        for (size_t threads : {0, 1, 5}) {
            options.Threads = threads;

            NTesting::RunConcurrent<NTesting::TEnv>(
                NTesting::GetEchoCallback(),
                [&](NTesting::TEnv& env) {
                    auto f = [&] (TCont* cont) {
                        try {
                            // Body length is in [1, 10 KiB]; the length doubles as the RandomString seed.
                            const auto size = RandomNumber<size_t>(10 * 1024) + 1;
                            TString data = NUnitTest::RandomString(size, size);

                            // Take ONE request number and derive both the codec and the
                            // transport from it. A separate AtomicGet + AtomicIncrement
                            // pair can interleave with other requests, so the two
                            // decisions would not belong to the same request number.
                            // Consecutive numbers walk through all codec/transport
                            // combinations (3 codecs x 2 transports).
                            const size_t reqNo = static_cast<size_t>(AtomicIncrement(reqCounter));
                            const TString& compression = codecs[reqNo % codecs.size()];
                            TRequestResult result;
                            if (reqNo % 2) {
                                result = SendHttpRequest(cont, env.Port,
                                    "POST /yandsearch?xxx=1 HTTP/1.1",
                                    data,
                                    compression,
                                    "gzip,br,deflate"
                                );
                            } else {
                                result = SendHttpsRequest(cont, env.SslPort,
                                    "POST /yandsearch?xxx=1 HTTP/1.1",
                                    data,
                                    compression,
                                    "gzip,br,deflate"
                                );
                            }

                            UNIT_ASSERT(!result.WasEncoded);
                            UNIT_ASSERT_VALUES_EQUAL(result.Data, data);
                        } catch (...) {
                            // Report any exception from the request as a test failure.
                            UNIT_ASSERT_C(false, CurrentExceptionMessage());
                        }
                    };

                    TContExecutor e(128 * 1024, IPollerFace::Default(), nullptr, nullptr, NCoro::NStack::EGuard::Canary,
                                    NCoro::NStack::TPoolAllocatorSettings{});
                    e.Create(f, "request");
                    e.Execute();
                },
                options,
                100,
                true //useTwoPorts
            );
        }
    }

    // With EnableResponseCompression set, sends random uncompressed bodies while
    // advertising one of "deflate", "gzip" or "br" as the accepted encoding,
    // over either the plain HTTP port or the SSL port. Expects the decoded
    // response body to equal the request body and the response to be encoded.
    // Repeated for Threads = 0, 1 and 5.
    Y_UNIT_TEST(TestResponseCompression) {
        TOptions options;
        options.EnableResponseCompression = true;

        TStringStream ss;
        MakeSslConfig(ss);
        options.SslConfig = ss.Str();

        TVector<TString> codecs = {"deflate", "gzip", "br"};

        TAtomic reqCounter = 0;
        for (size_t threads : {0, 1, 5}) {
            options.Threads = threads;

            NTesting::RunConcurrent<NTesting::TEnv>(
                NTesting::GetEchoCallback(),
                [&](NTesting::TEnv& env) {
                    auto f = [&] (TCont* cont) {
                        try {
                            // Body length is in [1, 10 KiB]; the length doubles as the RandomString seed.
                            const auto size = RandomNumber<size_t>(10 * 1024) + 1;
                            TString data = NUnitTest::RandomString(size, size);

                            // Take ONE request number and derive both the accepted
                            // encoding and the transport from it. A separate
                            // AtomicGet + AtomicIncrement pair can interleave with
                            // other requests, so the two decisions would not belong
                            // to the same request number.
                            const size_t reqNo = static_cast<size_t>(AtomicIncrement(reqCounter));
                            const TString& accepted = codecs[reqNo % codecs.size()];
                            TRequestResult result;
                            if (reqNo % 2) {
                                result = SendHttpRequest(cont, env.Port,
                                    "POST /yandsearch?xxx=1 HTTP/1.1",
                                    data,
                                    TString(),
                                    accepted
                                );
                            } else {
                                result = SendHttpsRequest(cont, env.SslPort,
                                    "POST /yandsearch?xxx=1 HTTP/1.1",
                                    data,
                                    TString(),
                                    accepted
                                );
                            }

                            UNIT_ASSERT_VALUES_EQUAL(result.Data, data);
                            // An encoded response is required whenever an encoding was offered.
                            UNIT_ASSERT(result.WasEncoded || accepted.empty());
                        } catch (...) {
                            // Report any exception from the request as a test failure.
                            UNIT_ASSERT_C(false, CurrentExceptionMessage());
                        }
                    };

                    TContExecutor e(128 * 1024, IPollerFace::Default(), nullptr, nullptr, NCoro::NStack::EGuard::Canary,
                                    NCoro::NStack::TPoolAllocatorSettings{});
                    e.Create(f, "request");
                    e.Execute();
                },
                options,
                100,
                true //useTwoPorts
            );
        }
    }

    // Returns `v` compressed with TZLibCompress in ZLib::GZip mode.
    TString Gzip(const TString& v) {
        TStringStream res;
        TZLibCompress c(&res, ZLib::GZip);
        c << v;
        // The compressor is finished before the accumulated output is read back.
        c.Finish();
        return res.Str();
    }

    // Checks the MaxRequestBodySize option in three configurations:
    //   * limit 1: a 1-byte body is answered with 200, a 2-byte body with 413;
    //   * limit 0: a 1-byte body is answered with 413;
    //   * gzip-encoded bodies with the limit set one byte above the first
    //     uncompressed payload: the first is answered with 200 and returned
    //     uncompressed, the second (longer uncompressed payload) with 413.
    Y_UNIT_TEST(TestMaxRequestBodySize) {
        // Runs `body` as a single coroutine named "request" on a fresh executor.
        const auto runInCoroutine = [](auto& body) {
            TContExecutor executor(128 * 1024, IPollerFace::Default(), nullptr, nullptr,
                                   NCoro::NStack::EGuard::Canary,
                                   NCoro::NStack::TPoolAllocatorSettings{});
            executor.Create(body, "request");
            executor.Execute();
        };

        {// limit of one byte
            NTesting::TEnv env;
            TOptions options;
            options.MaxRequestBodySize = 1;
            env.Start(NTesting::GetEchoCallback(), options);

            auto scenario = [&](TCont* cont) {
                try {
                    {// body size equal to the limit
                        NTesting::TRawSocketIO sock("localhost", env.Port, cont, Now() + TDuration::Seconds(10));
                        THttpOutput request(&sock);

                        request << "POST /yandsearch?xxx=1 HTTP/1.1\r\n\r\n";
                        request << "x";
                        request.Finish();

                        THttpInput response(&sock);

                        UNIT_ASSERT_VALUES_EQUAL(ParseHttpRetCode(response.FirstLine()), 200);
                    }
                    {// body one byte over the limit
                        NTesting::TRawSocketIO sock("localhost", env.Port, cont, Now() + TDuration::Seconds(10));
                        THttpOutput request(&sock);

                        request << "POST /yandsearch?xxx=1 HTTP/1.1\r\n\r\n";
                        request << "xx";
                        request.Finish();

                        THttpInput response(&sock);

                        UNIT_ASSERT_VALUES_EQUAL(ParseHttpRetCode(response.FirstLine()), 413);
                    }
                } catch (...) {
                    // Report any exception from the requests as a test failure.
                    UNIT_ASSERT_C(false, CurrentExceptionMessage());
                }
            };

            runInCoroutine(scenario);
        }

        {// limit of zero bytes
            NTesting::TEnv env;
            TOptions options;
            options.MaxRequestBodySize = 0;
            env.Start(NTesting::GetEchoCallback(), options);

            auto scenario = [&](TCont* cont) {
                try {
                    NTesting::TRawSocketIO sock("localhost", env.Port, cont, Now() + TDuration::Seconds(10));
                    THttpOutput request(&sock);

                    request << "POST /yandsearch?xxx=1 HTTP/1.1\r\n\r\n";
                    request << "x";
                    request.Finish();

                    THttpInput response(&sock);

                    UNIT_ASSERT_VALUES_EQUAL(ParseHttpRetCode(response.FirstLine()), 413);
                } catch (...) {
                    // Report any exception from the request as a test failure.
                    UNIT_ASSERT_C(false, CurrentExceptionMessage());
                }
            };

            runInCoroutine(scenario);
        }

        {// gzip-encoded request bodies
            const TString plainFits = NTesting::Repeat("abcde", 50);
            const TString plainTooBig = NTesting::Repeat("xxxxx", 51);

            const TString gzFits = Gzip(plainFits);
            const TString gzTooBig = Gzip(plainTooBig);

            NTesting::TEnv env;
            TOptions options;
            // The limit sits between the two uncompressed sizes.
            // NOTE(review): presumably the server applies the limit to the
            // decompressed size; the compressed sizes are not asserted here.
            options.MaxRequestBodySize = plainFits.Size() + 1;
            UNIT_ASSERT(plainTooBig.Size() > options.MaxRequestBodySize);

            env.Start(NTesting::GetEchoCallback(), options);

            auto scenario = [&](TCont* cont) {
                try {
                    {// uncompressed size within the limit
                        NTesting::TRawSocketIO sock("localhost", env.Port, cont, Now() + TDuration::Seconds(10));
                        sock << "POST /yandsearch?xxx=1 HTTP/1.1\r\n";
                        sock << "Content-Encoding: gzip\r\n";
                        sock << "Content-Length: " << gzFits.Size() << "\r\n\r\n";
                        sock << gzFits << "\r\n\r\n";
                        sock.Finish();

                        THttpInput response(&sock);

                        UNIT_ASSERT_VALUES_EQUAL(ParseHttpRetCode(response.FirstLine()), 200);
                        const TString echoed = response.ReadAll();
                        UNIT_ASSERT_VALUES_EQUAL(echoed, plainFits);
                    }
                    {// uncompressed size above the limit
                        NTesting::TRawSocketIO sock("localhost", env.Port, cont, Now() + TDuration::Seconds(10));
                        sock << "POST /yandsearch?xxx=1 HTTP/1.1\r\n";
                        sock << "Content-Encoding: gzip\r\n";
                        sock << "Content-Length: " << gzTooBig.Size() << "\r\n\r\n";
                        sock << gzTooBig << "\r\n\r\n";
                        sock.Finish();

                        THttpInput response(&sock);

                        UNIT_ASSERT_VALUES_EQUAL(ParseHttpRetCode(response.FirstLine()), 413);
                    }
                } catch (...) {
                    // Report any exception from the requests as a test failure.
                    UNIT_ASSERT_C(false, CurrentExceptionMessage());
                }
            };

            runInCoroutine(scenario);
        }
    }

    // The handler itself replies with a gzip-compressed body and a
    // "content-encoding: gzip" header. The test expects the client to see a
    // gzip-encoded 200 response carrying the original data for every
    // Accept-Encoding value tried, both with EnableResponseCompression on and off.
    Y_UNIT_TEST(TestOverrideEncoding) {
        // Handler: asserts the method is POST, then sends a chunked 200 reply
        // whose body is the gzip-compressed request body.
        auto gzipCallback = [](THttpRequestEnv& env) {
            NSrvKernel::TResponse head(200, "Ok");
            head.Props().ChunkedTransfer = true;

            UNIT_ASSERT_VALUES_EQUAL(env.Method(), NSrvKernel::EMethod::POST);

            head.Headers().Add("content-encoding", "gzip");

            auto transport = env.GetReplyTransport();
            transport->SendHead(std::move(head));
            transport->SendData(Gzip(env.Body()));
            transport->SendEof();

            return TError();
        };

        // Posts "abcdef" with the given Accept-Encoding value and verifies the
        // status code, the decoded body and the Content-Encoding header.
        auto check = [](NTesting::TEnv& env, TCont* cont, const TString& accepted) {
            const TString payload = "abcdef";
            NTesting::TRawSocketIO sock("localhost", env.Port, cont, Now() + TDuration::Seconds(10));
            THttpOutput request(&sock);

            request << "POST /yandsearch?xxx=1 HTTP/1.1\r\n";
            request << "Accept-Encoding: " << accepted << "\r\n\r\n";
            request << payload;
            request.Finish();

            THttpInput response(&sock);

            UNIT_ASSERT_VALUES_EQUAL(ParseHttpRetCode(response.FirstLine()), 200);
            UNIT_ASSERT_VALUES_EQUAL(response.ReadAll(), payload);
            UNIT_ASSERT(response.ContentEncoded());
            auto* encodingHeader = response.Headers().FindHeader("Content-Encoding");
            UNIT_ASSERT(encodingHeader);
            UNIT_ASSERT_VALUES_EQUAL(encodingHeader->Value(), "gzip");
        };

        // Same scenario with server-side response compression enabled, then disabled.
        for (const bool compressionEnabled : {true, false}) {
            NTesting::TEnv env;
            TOptions options;
            options.EnableResponseCompression = compressionEnabled;
            env.Start(gzipCallback, options);

            auto scenario = [&](TCont* cont) {
                try {
                    for (const char* accepted : {"", "gzip", "zzz", "deflate"}) {
                        check(env, cont, accepted);
                    }
                } catch (...) {
                    // Report any exception from the requests as a test failure.
                    UNIT_ASSERT_C(false, CurrentExceptionMessage());
                }
            };

            TContExecutor executor(128 * 1024, IPollerFace::Default(), nullptr, nullptr,
                                   NCoro::NStack::EGuard::Canary,
                                   NCoro::NStack::TPoolAllocatorSettings{});
            executor.Create(scenario, "request");
            executor.Execute();
        }
    }

    // With EnableResponseCompression set and "Accept-Encoding: gzip" sent, a GET
    // served by NTesting::GetOkCallback() is expected to return 200 with body
    // "OK", without a Transfer-Encoding header and without content encoding.
    Y_UNIT_TEST(TestNotChunked) {
        TOptions options;
        options.EnableResponseCompression = true;

        NTesting::TEnv env;
        env.Start(NTesting::GetOkCallback(), options);

        auto scenario = [&](TCont* cont) {
            try {
                const TInstant deadline = Now() + TDuration::Seconds(10);
                NTesting::TRawSocketIO sock("localhost", env.Port, cont, deadline);
                THttpOutput request(&sock);

                request << "GET /yandsearch?xxx=1 HTTP/1.1\r\n";
                request << "Accept-Encoding: gzip\r\n\r\n";
                request.Finish();

                THttpInput response(&sock);

                UNIT_ASSERT_VALUES_EQUAL(ParseHttpRetCode(response.FirstLine()), 200);
                UNIT_ASSERT_VALUES_EQUAL(response.ReadAll(), "OK");

                // The response must not be chunked...
                auto* transferHeader = response.Headers().FindHeader("Transfer-Encoding");
                UNIT_ASSERT(!transferHeader);

                // ...and must not be content-encoded.
                UNIT_ASSERT(!response.ContentEncoded());
                auto* encodingHeader = response.Headers().FindHeader("Content-Encoding");
                UNIT_ASSERT(!encodingHeader);
            } catch (...) {
                // Report any exception from the request as a test failure.
                UNIT_ASSERT_C(false, CurrentExceptionMessage());
            }
        };

        TContExecutor executor(128 * 1024, IPollerFace::Default(), nullptr, nullptr,
                               NCoro::NStack::EGuard::Canary,
                               NCoro::NStack::TPoolAllocatorSettings{});
        executor.Create(scenario, "request");
        executor.Execute();
    }
}

