#include "test_http_server.h"
#include <solomon/libs/cpp/http/client/curl/client.h>

#include <library/cpp/monlib/metrics/metric_registry.h>
#include <library/cpp/string_utils/url/url.h>
#include <library/cpp/testing/gtest/gtest.h>

#include <util/generic/algorithm.h>
#include <util/generic/size_literals.h>
#include <util/stream/input.h>
#include <util/string/ascii.h>
#include <util/system/event.h>

using namespace NSolomon;
using namespace NMonitoring;

// No-op completion callback for tests that only inspect the request as seen by
// the server. `static` gives it internal linkage so it cannot clash with a
// same-named helper in another test translation unit.
static void EmptyCallback(IHttpClient::TResult) {
}

// TODO: test the different clients against a very slow server
/// Typed fixture: each TFactory (see THttpClientTypes) creates a differently
/// configured IHttpClient, which is exercised against a local TTestHttpServer.
template <typename TFactory>
class THttpClientTest: public ::testing::Test {
    void SetUp() override {
        Registry_ = MakeHolder<TMetricRegistry>();
        Client_ = TFactory::Create(*Registry_);

        // Default "/foo" handler; individual tests re-register "/foo" with their own.
        Server_.AddHandler("/foo", [] {
            THttpResponse r;
            r.SetHttpCode(HTTP_OK);
            return r;
        });

        ASSERT_TRUE(Server_.Start());
    }

    void TearDown() override {
        // Members are destroyed in reverse declaration order (Server_ first), so
        // reset the client explicitly while the server is still alive.
        Client_.Reset();
    }

protected:
    /**
     * Sends the request built by `requestFactory(path, args...)` and blocks until
     * the server has received it AND the client has invoked `fn` for it.
     *
     * @param requestFactory  request builder such as NSolomon::Get / Post / Put
     * @param path            full URL; its path+query part is what the server waits for
     * @param fn              completion callback, invoked before this function returns
     * @param opts            per-request client options
     * @return the request as seen by the server
     *
     * Aborts via Y_VERIFY when either side does not finish in time: the result
     * is returned by value, so gtest ASSERT_* cannot be used here.
     */
    template <typename TFunc, typename... TArgs>
    IRequestPtr Wait(TFunc requestFactory, TString path, IHttpClient::TOnComplete fn, TRequestOpts opts, TArgs&&... args) {
        auto promise = NThreading::NewPromise<void>();
        auto wrapper = [promise, fn = std::move(fn)] (auto result) mutable {
            fn(std::move(result));
            promise.SetValue();
        };

        TString p{GetPathAndQuery(path)};
        auto f = Server_.Wait(p);
        Client_->Request(requestFactory(path, std::forward<TArgs>(args)...), std::move(wrapper), std::move(opts));

        // Bounded wait: if the request never reaches the server (e.g. the client
        // failed to send it), an unbounded ExtractValueSync() would hang forever
        // without any diagnostics. The bound is generous because LargeUpload
        // pushes up to 100MB through here.
        const bool received = f.Wait(TDuration::Seconds(30));
        Y_VERIFY(received, "test server did not receive a request for %s", p.c_str());
        auto srvReq = f.ExtractValueSync();

        // return only after the response was handled by our client
        const bool done = promise.GetFuture().Wait(TDuration::Seconds(10));
        Y_VERIFY(done);

        return srvReq;
    }

    THolder<TMetricRegistry> Registry_;
    IHttpClientPtr Client_;
    TTestHttpServer Server_;
};

/// Builds a curl client with default options, reporting metrics into `registry`.
struct TCurlFetcherFactory {
    static IHttpClientPtr Create(TMetricRegistry& registry) {
        return CreateCurlClient({}, registry);
    }
};

/// Builds a curl client capped at two in-flight requests, so tests that issue
/// more requests than that at once also run against the limit.
struct TCurlInflightLimitFactory {
    static IHttpClientPtr Create(TMetricRegistry& registry) {
        return CreateCurlClient({ .MaxInflight = 2 }, registry);
    }
};

// Every TYPED_TEST below is instantiated once per client factory listed here.
using THttpClientTypes = ::testing::Types<
    TCurlFetcherFactory,
    TCurlInflightLimitFactory
>;
TYPED_TEST_SUITE(THttpClientTest, THttpClientTypes);

TYPED_TEST(THttpClientTest, GetRequest) {
    // A GET for the bare server address must arrive as "GET /" without a body.
    const auto rootUrl = this->Server_.Address();
    auto seen = this->Wait(NSolomon::Get, rootUrl, EmptyCallback, {}, nullptr);

    ASSERT_EQ(seen->Method(), EHttpMethod::Get);
    ASSERT_EQ(seen->Url(), "/");
    ASSERT_TRUE(seen->Data().empty());
}

TYPED_TEST(THttpClientTest, PostRequest) {
    constexpr auto PAYLOAD = "payload";
    const auto url = this->Server_.Address() + "/foo";

    // The server must observe the method, the path and the body exactly as sent.
    auto seen = this->Wait(Post, url, EmptyCallback, {}, PAYLOAD, nullptr);

    ASSERT_EQ(seen->Method(), EHttpMethod::Post);
    ASSERT_EQ(seen->Url(), "/foo");
    ASSERT_EQ(seen->Data(), PAYLOAD);
}

TYPED_TEST(THttpClientTest, PutRequest) {
    constexpr auto PAYLOAD = "payload";
    const auto url = this->Server_.Address() + "/foo";

    // The server must observe the method, the path and the body exactly as sent.
    auto seen = this->Wait(Put, url, EmptyCallback, {}, PAYLOAD, nullptr);

    ASSERT_EQ(seen->Method(), EHttpMethod::Put);
    ASSERT_EQ(seen->Url(), "/foo");
    ASSERT_EQ(seen->Data(), PAYLOAD);
}

TYPED_TEST(THttpClientTest, Headers) {
    constexpr auto PAYLOAD = "payload";
    constexpr auto H_NAME = "My";
    constexpr auto H_VALUE = "Header";

    // Send the header through the same constants the assertions use, so the
    // request and the expectations cannot drift apart.
    auto srvReq = this->Wait(
        Post,
        this->Server_.Address() + "/foo",
        EmptyCallback,
        {},
        PAYLOAD,
        NSolomon::Headers({{H_NAME, H_VALUE}})
    );

    auto value = srvReq->Headers().Find(H_NAME);
    ASSERT_TRUE(value);
    ASSERT_EQ(*value, H_VALUE);
}

TYPED_TEST(THttpClientTest, CustomUserAgent) {
    constexpr auto H_NAME = "User-agent";
    constexpr auto H_VALUE = "test";

    // An explicitly supplied User-agent header must reach the server as given.
    const auto url = this->Server_.Address() + "/foo";
    auto seen = this->Wait(NSolomon::Get, url, EmptyCallback, {}, NSolomon::Headers({{H_NAME, H_VALUE}}));

    auto received = seen->Headers().Find(H_NAME);
    ASSERT_TRUE(received);
    ASSERT_EQ(*received, H_VALUE);
}

TYPED_TEST(THttpClientTest, LargeUpload) {
    constexpr auto SIZE = NSan::PlainOrUnderSanitizer(100_MB, 10_MB);

    // Fill the body with a repeating sequence of byte values, so that truncation
    // or reordering on the way to the server shows up in the comparison below.
    TString payload;
    payload.resize(SIZE);
    char next {0};
    for (char& byte : payload) {
        byte = next++;
    }

    const auto url = this->Server_.Address() + "/foo";
    auto seen = this->Wait(Post, url, EmptyCallback, {}, payload, nullptr);

    ASSERT_EQ(seen->Method(), EHttpMethod::Post);
    ASSERT_EQ(seen->Url(), "/foo");
    ASSERT_EQ(seen->Data(), payload);
}

TYPED_TEST(THttpClientTest, AllRequestsAreDelivered) {
    constexpr auto REQ_COUNT = NSan::PlainOrUnderSanitizer(500, 50);

    std::atomic<i32> actualCount{0};
    TVector<NThreading::TFuture<void>> futures;

    // Fire all requests without waiting; each completion resolves its own promise,
    // with an exception carrying the error message when the request failed.
    for (int i = 0; i < REQ_COUNT; ++i) {
        auto promise = NThreading::NewPromise<void>();
        auto onComplete = [&actualCount, promise] (auto result) mutable {
            if (!result.Success()) {
                promise.SetException(result.Error().Message());
                return;
            }

            ++actualCount;
            promise.SetValue();
        };

        this->Client_->Request(Get(this->Server_.Address()), onComplete);
        futures.push_back(promise.GetFuture());
    }

    // just wait all requests are done
    const bool allDone = NThreading::WaitAll(futures).Wait(TDuration::Seconds(5));
    ASSERT_TRUE(allDone);
    ASSERT_EQ(actualCount.load(), REQ_COUNT);
}

TYPED_TEST(THttpClientTest, ResponseContent) {
    constexpr auto CONTENT = "hello, world!";

    // Non-default code, a body and a custom header: all three must reach the client.
    auto handler = [&] {
        THttpResponse r;
        r.SetHttpCode(HTTP_CREATED);
        r.SetContent(CONTENT);
        r.AddHeader("my", "header");
        return r;
    };

    this->Server_.AddHandler("/foo", handler);

    auto check = [&] (IHttpClient::TResult result) {
        ASSERT_TRUE(result.Success());
        auto response = result.Extract();

        ASSERT_EQ(response->Code(), HTTP_CREATED);
        ASSERT_EQ(response->Data(), CONTENT);
        auto h = response->Headers().Find("my");
        ASSERT_TRUE(h.Defined());
        // compare the contained value, as the other header checks in this file do
        ASSERT_EQ(*h, "header");
    };

    // Only the client-side check matters here; the server-side request is not inspected.
    this->Wait(NSolomon::Get, this->Server_.Address() + "/foo", check, {}, nullptr);
}

TYPED_TEST(THttpClientTest, ReadTimeout) {
    // The server answers only after 1s, well past the client's 200ms read timeout.
    auto handler = [] {
        THttpResponse r;
        Sleep(TDuration::Seconds(1));
        r.SetHttpCode(HTTP_NOT_FOUND);
        return r;
    };

    this->Server_.AddHandler("/foo", handler);

    // The callback holds its own copy of the promise (the same pattern as the
    // fixture's Wait). If the wait below times out and this body returns, a late
    // callback therefore cannot touch a destroyed local. A stack event captured
    // by reference could be touched after destruction.
    auto called = NThreading::NewPromise<void>();
    auto check = [called] (IHttpClient::TResult result) mutable {
        ASSERT_TRUE(!result.Success());
        ASSERT_EQ(result.Error().Type(), TRequestError::EType::ReadTimeout);
        called.SetValue();
    };

    this->Client_->Request(Get(this->Server_.Address() + "/foo"), check, {
        .ReadTimeout = TDuration::MilliSeconds(200),
        .Retries = Max<ui8>(),
    });

    const auto ok = called.GetFuture().Wait(TDuration::MilliSeconds(500));
    ASSERT_TRUE(ok);
}

TYPED_TEST(THttpClientTest, ResponseSizeLimit) {
    // 500-byte zero-filled body: larger than the 100-byte ResponseLimitBytes below.
    static const TString CONTENT = [] {
        TString s;
        s.resize(500);
        FillN(s.begin(), s.size(), '\0');
        return s;
    }();

    auto handler = [&] {
        THttpResponse r;
        r.SetHttpCode(HTTP_CREATED);
        r.SetContent(CONTENT);
        return r;
    };

    this->Server_.AddHandler("/foo", handler);
    auto check = [&] (IHttpClient::TResult result) {
        ASSERT_TRUE(!result.Success());
        ASSERT_EQ(result.Error().Type(), TRequestError::EType::ResponseTooLarge);
    };

    this->Wait(NSolomon::Get, this->Server_.Address() + "/foo", check, TRequestOpts{
        .ResponseLimitBytes = 100,
    }, nullptr);
}

TYPED_TEST(THttpClientTest, Retries) {
    // The first three requests fail with 500, every later one succeeds.
    std::atomic<ui32> reqNum{0};
    auto handler = [&] {
        THttpResponse r;
        if (reqNum.fetch_add(1) < 3) {
            r.SetHttpCode(HTTP_INTERNAL_SERVER_ERROR);
            return r;
        }

        r.SetHttpCode(HTTP_OK);
        r.SetContent("OK");
        return r;
    };

    this->Server_.AddHandler("/foo", handler);

    TAutoEvent done;

    // Both callbacks signal only after their checks have run. Signalling first
    // lets the test body return (and TearDown start) while the assertions are
    // still executing, so their failures can be lost or misattributed. A failed
    // check returns early without signalling, which the timed waits below report.
    auto checkFail = [&] (IHttpClient::TResult result) {
        ASSERT_TRUE(result.Success());
        ASSERT_EQ(result.Value()->Code(), HTTP_INTERNAL_SERVER_ERROR);
        done.Signal();
    };

    auto checkOk = [&] (IHttpClient::TResult result) {
        ASSERT_TRUE(result.Success());
        ASSERT_EQ(result.Value()->Code(), HTTP_OK);
        done.Signal();
    };

    {
        // enough retries to get past the three failing responses
        this->Client_->Request(Get(this->Server_.Address() + "/foo"), checkOk, {
            .Retries = 3,
            .BackoffMin = TDuration::MilliSeconds(10),
        });

        auto signaled = done.WaitT(TDuration::Seconds(1));
        ASSERT_TRUE(signaled);
    }

    reqNum = 0;
    done.TSystemEvent::Reset();

    {
        // too few retries: the last 500 response is handed to the callback
        this->Client_->Request(Get(this->Server_.Address() + "/foo"), checkFail, {
            .Retries = 1,
            .BackoffMin = TDuration::MilliSeconds(10),
        });

        auto signaled = done.WaitT(TDuration::Seconds(1));
        ASSERT_TRUE(signaled);
    }
}

TYPED_TEST(THttpClientTest, PostRetries) {
    // The first three requests fail with 500, every later one succeeds.
    std::atomic<ui32> reqNum{0};
    auto handler = [&] {
        THttpResponse r;
        if (reqNum.fetch_add(1) < 3) {
            r.SetHttpCode(HTTP_INTERNAL_SERVER_ERROR);
            return r;
        }

        r.SetHttpCode(HTTP_OK);
        r.SetContent("OK");
        return r;
    };

    this->Server_.AddHandler("/foo", handler);

    TAutoEvent done;

    // Both callbacks signal only after their checks have run. Signalling first
    // lets the test body return (and TearDown start) while the assertions are
    // still executing, so their failures can be lost or misattributed. A failed
    // check returns early without signalling, which the timed waits below report.
    auto checkFail = [&] (IHttpClient::TResult result) {
        ASSERT_TRUE(result.Success());
        ASSERT_EQ(result.Value()->Code(), HTTP_INTERNAL_SERVER_ERROR);
        done.Signal();
    };

    auto checkOk = [&] (IHttpClient::TResult result) {
        ASSERT_TRUE(result.Success());
        ASSERT_EQ(result.Value()->Code(), HTTP_OK);
        done.Signal();
    };

    {
        // enough retries to get past the three failing responses
        this->Client_->Request(Post(this->Server_.Address() + "/foo"), checkOk, {
            .Retries = 3,
            .BackoffMin = TDuration::MilliSeconds(10),
        });

        auto signaled = done.WaitT(TDuration::Seconds(1));
        ASSERT_TRUE(signaled);
    }

    reqNum = 0;
    done.TSystemEvent::Reset();

    {
        // too few retries: the last 500 response is handed to the callback
        this->Client_->Request(Post(this->Server_.Address() + "/foo"), checkFail, {
            .Retries = 1,
            .BackoffMin = TDuration::MilliSeconds(10),
        });

        auto signaled = done.WaitT(TDuration::Seconds(1));
        ASSERT_TRUE(signaled);
    }
}

TYPED_TEST(THttpClientTest, StopDuringRequest) {
    // The server stalls for a second, so the request below is still in flight
    // when the test body ends and TearDown destroys the client.
    this->Server_.AddHandler("/foo", [] {
        THttpResponse response;
        Sleep(TDuration::Seconds(1));
        response.SetHttpCode(HTTP_NOT_FOUND);
        return response;
    });

    const auto url = this->Server_.Address() + "/foo";
    this->Client_->Request(Get(url), EmptyCallback, {
        .Retries = Max<ui8>(),
        .BackoffMin = TDuration::MilliSeconds(10),
    });

    // Tearing the client down mid-request must not leak (checked under ASan).
}

TYPED_TEST(THttpClientTest, CallbackGetsCalledOnStop) {
    static constexpr auto REQ_COUNT = 5;

    // A slow server keeps every request pending until the client is destroyed.
    this->Server_.AddHandler("/foo", [] {
        THttpResponse response;
        Sleep(TDuration::Seconds(1));
        response.SetHttpCode(HTTP_NOT_FOUND);
        return response;
    });

    std::atomic<int> callbacks{0};
    auto onComplete = [&callbacks] (IHttpClient::TResult result) {
        ASSERT_EQ(result.Error().Type(), TRequestError::EType::RequestInitializationFailed);
        ++callbacks;
    };

    const auto url = this->Server_.Address() + "/foo";
    for (int i = 0; i < REQ_COUNT; ++i) {
        this->Client_->Request(Get(url), onComplete, {
            .Retries = Max<ui8>(),
            .BackoffMin = TDuration::MilliSeconds(10),
        });
    }

    // By the time Reset() returns, every pending request must have been
    // completed through its callback with RequestInitializationFailed.
    this->Client_.Reset();
    ASSERT_EQ(callbacks.load(), REQ_COUNT);
}
