#include <balancer/kernel/http/parser/http.h>

#include <balancer/kernel/http/parser/tests/util/custom_chunks_emitter_io.h>

#include <library/cpp/testing/unittest/registar.h>

#include <util/generic/xrange.h>

using namespace NSrvKernel;

Y_UNIT_TEST_SUITE(THttpRequestTest) {

    void DoValidateRequestLine(const TRequestLine& requestLine, EMethod method, TString path, TString cgi, int maj, int min) {
        UNIT_ASSERT_VALUES_EQUAL(requestLine.Method, method);
        UNIT_ASSERT_VALUES_EQUAL(requestLine.Path.AsStringBuf(), path);
        UNIT_ASSERT_VALUES_EQUAL(requestLine.CGI.AsStringBuf(), cgi);
        UNIT_ASSERT_VALUES_EQUAL(requestLine.MajorVersion, maj);
        UNIT_ASSERT_VALUES_EQUAL(requestLine.MinorVersion, min);
    }

    // Parses `line` and expects success, then checks the resulting request line.
    // Callers already end `line` with CRLF; one more CRLF is appended here so the
    // input also contains the empty line that ends the (empty) header block.
    void DoTestAcceptRequestLine(TString line, EMethod method, TString path, TString cgi, int maj, int min) {
        line.append("\r\n");

        TRequest parsedRequest;
        TError error = parsedRequest.Parse(line);
        UNIT_ASSERT_C(!error, line.Quote());

        const auto& parsedLine = parsedRequest.RequestLine();
        DoValidateRequestLine(parsedLine, method, path, cgi, maj, min);
    }

    // Parses `line` (with an extra CRLF appended, as in DoTestAcceptRequestLine)
    // and expects it to be rejected with a THttpError carrying `status`.
    void DoTestRejectRequestLine(TString line, int status) {
        TRequest request;
        line += "\r\n";
        TError error = request.Parse(line);
        UNIT_ASSERT_C(error, line.Quote());

        // GetAs returns null when the error is of a different type. Check before
        // dereferencing so that case is reported as a test failure, not a crash.
        auto* httpError = error.GetAs<THttpError>();
        UNIT_ASSERT_C(httpError, line.Quote());
        UNIT_ASSERT_VALUES_EQUAL_C(static_cast<int>(httpError->Code()), status, line.Quote());
    }

    Y_UNIT_TEST(TestMethods) {
        // Methods that must be members of HttpBasicMethods. Every other method in
        // the table below must be a member of HttpWebDavMethods instead.
        const auto httpBasicMethods = {
            EMethod::OPTIONS,
            EMethod::GET,
            EMethod::POST,
            EMethod::PUT,
            EMethod::PATCH,
            EMethod::DELETE,
            EMethod::TRACE,
            EMethod::CONNECT
        };

        // Expands to {"NAME", EMethod::NAME}: the on-the-wire token equals the enum name.
#define Y_BALANCER_HTTP_UT_METH(m) { #m, EMethod::m }

        // Bind by const reference: the table is read-only, so there is no need to
        // copy each (TString, EMethod) pair on every iteration.
        for (const auto& [s, m] : std::initializer_list<std::pair<TString, EMethod>> {
            Y_BALANCER_HTTP_UT_METH(OPTIONS),
            Y_BALANCER_HTTP_UT_METH(GET),
            Y_BALANCER_HTTP_UT_METH(POST),
            Y_BALANCER_HTTP_UT_METH(PUT),
            Y_BALANCER_HTTP_UT_METH(PATCH),
            Y_BALANCER_HTTP_UT_METH(DELETE),
            Y_BALANCER_HTTP_UT_METH(TRACE),
            Y_BALANCER_HTTP_UT_METH(CONNECT),
            Y_BALANCER_HTTP_UT_METH(ACL),
            // Hyphenated tokens (this one and VERSION-CONTROL) are not valid C++
            // identifiers, so they are spelled out instead of using the macro.
            {"BASELINE-CONTROL", EMethod::BASELINE_CONTROL},
            Y_BALANCER_HTTP_UT_METH(BIND),
            Y_BALANCER_HTTP_UT_METH(CHECKIN),
            Y_BALANCER_HTTP_UT_METH(CHECKOUT),
            Y_BALANCER_HTTP_UT_METH(COPY),
            Y_BALANCER_HTTP_UT_METH(LABEL),
            Y_BALANCER_HTTP_UT_METH(LINK),
            Y_BALANCER_HTTP_UT_METH(LOCK),
            Y_BALANCER_HTTP_UT_METH(MERGE),
            Y_BALANCER_HTTP_UT_METH(MKACTIVITY),
            Y_BALANCER_HTTP_UT_METH(MKCALENDAR),
            Y_BALANCER_HTTP_UT_METH(MKCOL),
            Y_BALANCER_HTTP_UT_METH(MKREDIRECTREF),
            Y_BALANCER_HTTP_UT_METH(MKWORKSPACE),
            Y_BALANCER_HTTP_UT_METH(MOVE),
            Y_BALANCER_HTTP_UT_METH(ORDERPATCH),
            Y_BALANCER_HTTP_UT_METH(PROPFIND),
            Y_BALANCER_HTTP_UT_METH(PROPPATCH),
            Y_BALANCER_HTTP_UT_METH(REBIND),
            Y_BALANCER_HTTP_UT_METH(REPORT),
            Y_BALANCER_HTTP_UT_METH(SEARCH),
            Y_BALANCER_HTTP_UT_METH(UNBIND),
            Y_BALANCER_HTTP_UT_METH(UNCHECKOUT),
            Y_BALANCER_HTTP_UT_METH(UNLINK),
            Y_BALANCER_HTTP_UT_METH(UNLOCK),
            Y_BALANCER_HTTP_UT_METH(UPDATE),
            Y_BALANCER_HTTP_UT_METH(UPDATEREDIRECTREF),
            {"VERSION-CONTROL", EMethod::VERSION_CONTROL},
        }) {
            DoTestAcceptRequestLine(s + " / HTTP/1.1\r\n", m, "/", "", 1, 1);
            // Each method belongs to exactly one of the two method sets.
            if (IsIn(httpBasicMethods, m)) {
                UNIT_ASSERT_C((HttpBasicMethods & m), m);
                UNIT_ASSERT_C(!(HttpWebDavMethods & m), m);
            } else {
                UNIT_ASSERT_C(!(HttpBasicMethods & m), m);
                UNIT_ASSERT_C((HttpWebDavMethods & m), m);
            }
        }

        // The helper macro is only meaningful for the table above; don't let it
        // leak into the rest of the translation unit.
#undef Y_BALANCER_HTTP_UT_METH

        // Lower-case and unknown tokens are rejected, as is the HTTP/2
        // connection preface request line.
        DoTestRejectRequestLine("get / HTTP/1.1\r\n", 400);
        DoTestRejectRequestLine("hit / HTTP/1.1\r\n", 400);
        DoTestRejectRequestLine("GETOLOLO / HTTP/1.1\r\n", 400);
        DoTestRejectRequestLine("PRI * HTTP/2.0\r\n", 400);
    }

    Y_UNIT_TEST(TestVersions) {
        // Accepted request lines, each paired with the minor version the parser must report.
        const std::pair<TString, int> accepted[] = {
            {"GET / HTTP/1.0\r\n", 0},
            {"GET / HTTP/1.1\r\n", 1},
        };
        for (const auto& [line, minor] : accepted) {
            DoTestAcceptRequestLine(line, EMethod::GET, "/", "", 1, minor);
        }

        // HTTP/2.0, 1.2, 0.9 and a malformed version token are all rejected with 400.
        for (const char* line : {
            "GET / HTTP/2.0\r\n",
            "GET / HTTP/1.2\r\n",
            "GET / HTTP/0.9\r\n",
            "GET / HTTP/1000\r\n",
        }) {
            DoTestRejectRequestLine(line, 400);
        }
    }

    Y_UNIT_TEST(TestPath) {
        // Request line -> expected path. A target without a leading '/' is also
        // accepted and reported verbatim.
        const std::pair<TString, TString> accepted[] = {
            {"GET /abc HTTP/1.1\r\n", "/abc"},
            {"GET /OLOLO HTTP/1.1\r\n", "/OLOLO"},
            {"GET OLOLO HTTP/1.1\r\n", "OLOLO"},
        };
        for (const auto& [line, path] : accepted) {
            DoTestAcceptRequestLine(line, EMethod::GET, path, "", 1, 1);
        }

        // Tab, CR or LF inside the target, or a target that starts with '?',
        // are rejected with 400.
        for (const char* line : {
            "GET /\t\t HTTP/1.1\r\n",
            "GET /\r HTTP/1.1\r\n",
            "GET \n HTTP/1.1\r\n",
            "GET ?sadfasdf HTTP/1.1\r\n",
        }) {
            DoTestRejectRequestLine(line, 400);
        }
    }

    Y_UNIT_TEST(TestCgi) {
        // Request line -> expected CGI part. Everything from the first '?' onward,
        // including any further '?', belongs to the CGI; the path stays "/".
        const std::pair<TString, TString> accepted[] = {
            {"GET /?test HTTP/1.1\r\n", "?test"},
            {"GET /?abc?cde?fde HTTP/1.1\r\n", "?abc?cde?fde"},
            {"GET /?k=v&u=t&c=d HTTP/1.1\r\n", "?k=v&u=t&c=d"},
        };
        for (const auto& [line, cgi] : accepted) {
            DoTestAcceptRequestLine(line, EMethod::GET, "/", cgi, 1, 1);
        }

        // Tab, CR or LF inside the query string is rejected with 400.
        for (const char* line : {
            "GET /?k=\t HTTP/1.1\r\n",
            "GET /?k=\r HTTP/1.1\r\n",
            "GET /?k=\n HTTP/1.1\r\n",
        }) {
            DoTestRejectRequestLine(line, 400);
        }
    }

    Y_UNIT_TEST(TestSetURL) {
        TRequest request;
        {
            TString line = "GET http://yandex.ru HTTP/1.1\r\n\r\n";
            auto error = request.Parse(line);
            UNIT_ASSERT_C(!error, line.Quote());
        }

        TString url = "https://yandex.ru/search?abc?cde?fde";
        auto error = request.RequestLine().SetURL(url);
        UNIT_ASSERT_C(!error, url.Quote());
        DoValidateRequestLine(request.RequestLine(), EMethod::GET, "https://yandex.ru/search", "?abc?cde?fde", 1, 1);
    }

    Y_UNIT_TEST(TestSetInvalidURL) {
        TRequest request;
        {
            TString line = "GET http://yandex.ru HTTP/1.1\r\n\r\n";
            auto error = request.Parse(line);
            UNIT_ASSERT_C(!error, line.Quote());
        }

        TString url = "??";
        auto error = request.RequestLine().SetURL(url);
        UNIT_ASSERT_C(error, url.Quote());
    }

    Y_UNIT_TEST(TestRequest) {
        TRequest request;
        TChunkList unparsed;
        // Header lines mix CRLF and bare-LF terminators. Content-Length and
        // Connection are expected to be absorbed into Props(), leaving only
        // X-Forwarded-For in Headers(). The Connection value uses mixed case and
        // padded commas, and both tokens must still be recognized.
        TString line = "GET /?test HTTP/1.1\r\nContent-Length: 3\nX-Forwarded-For: host\nConnection: upGrade   ,   Keep-Alive\r\n\r\n";

        TCustomChunksEmitterIo stringEmitter(std::move(line));

        TError error = request.Read(&stringEmitter, unparsed, TInstant::Max());
        UNIT_ASSERT_C(!error, error->what());

        DoValidateRequestLine(request.RequestLine(), EMethod::GET, "/", "?test", 1, 1);

        const auto& headers = request.Headers();
        UNIT_ASSERT_VALUES_EQUAL(headers.Size(), 1);
        UNIT_ASSERT_VALUES_EQUAL(headers.GetFirstValue("X-Forwarded-For"), "host");

        const auto& props = request.Props();
        // Assert presence before dereferencing, so a missing Content-Length fails
        // with a clear assertion instead of dereferencing an empty value.
        UNIT_ASSERT(props.ContentLength);
        UNIT_ASSERT_VALUES_EQUAL(*props.ContentLength, 3);
        UNIT_ASSERT(props.KeepAlive);
        UNIT_ASSERT(props.ExplicitKeepAliveHeader);
        UNIT_ASSERT(props.ExplicitConnectionHeader);
        UNIT_ASSERT(props.UpgradeRequested);
    }

    Y_UNIT_TEST(TestRequestExceedsHeaderCount) {
        // MaxHeaderCount + 1 headers: one more than the limit the parser enforces.
        const size_t headerCount = MaxHeaderCount + 1;

        TString rawRequest = "GET /?test HTTP/1.1\r\n";
        for ([[maybe_unused]] size_t idx : xrange(headerCount)) {
            rawRequest += "a: b\r\n";
        }
        rawRequest += "\r\n";

        TRequest request;
        TChunkList unparsed;
        TCustomChunksEmitterIo stringEmitter(std::move(rawRequest));
        TError error = request.Read(&stringEmitter, unparsed, TInstant::Max());
        UNIT_ASSERT(error);

        // The failure must be a parse error whose payload names the exceeded limit.
        THttpParseError* specificError = error.GetAs<THttpParseError>();
        UNIT_ASSERT(specificError);
        auto chunks = specificError->Chunks().Copy();
        UNIT_ASSERT(TString{StrInplace(chunks)}.StartsWith("too many headers to parse"));
    }
};
