#include <mail/template_master/lib/router/node.h>
#include <mail/template_master/lib/router/retry.h>
#include <mail/template_master/ut/mock/http_client_mock.h>
#include <mail/template_master/ut/mock/node_mock.h>
#include <mail/template_master/ut/environment.h>
#include <mail/template_master/ut/utils.h>

#include <library/cpp/testing/unittest/registar.h>

#include <chrono>

namespace NTemplateMaster::NTests {

using ::testing::Return;
using ::testing::ReturnRef;
using ::testing::ReturnRefOfCopy;
using ::testing::InSequence;
using ::testing::InvokeArgument;
using ::testing::_;
using namespace std::chrono_literals;
template<typename T>
using TRequestWithRetries = NTemplateMaster::NRouter::TRequestWithRetries<T>;

class TRetryRequestTest : public TTestBase, public TWithSpawn {
    UNIT_TEST_SUITE(TRetryRequestTest)
        UNIT_TEST(SuccessRequestWithoutRetries)
        UNIT_TEST(SuccessRequestWithRetries)
        UNIT_TEST(FailAfterAllRetries)
    UNIT_TEST_SUITE_END();
public:
    void SetUp() override {
        TTestsEnvironment::SetUp();
        Context = GetContext();
        Io = std::make_unique<boost::asio::io_context>();
    }

    void SuccessRequestWithoutRetries() {
        auto node = GetNodeMock();
        THttpOptions opts;
        size_t maxRetries = 3;
        auto nodeGenerator = [=]() {
            return node;
        };
        std::string schema = "http";
        size_t port = 99;
        std::string body(R"({"body": "qwerty"})");
        TDuration banDuration(1s);
        auto httpClient = GetHttpClientMock();
        auto retry = std::make_shared<TRequestWithRetries<decltype(nodeGenerator)>>
                (opts, maxRetries, nodeGenerator, schema, port, body, banDuration, httpClient);
        Spawn([=](TYield yield) {
            const InSequence s;
            const std::string host = "localhost";
            boost::system::error_code ec;
            EXPECT_CALL(*node, GetAddress()).WillOnce(ReturnRef(host));
            EXPECT_CALL(*httpClient, async_run(_, _, _, _))
                .WillOnce(InvokeArgument<3>(ec, THttpResponse{200, {}, "body", ""}));
            EXPECT_CALL(*node, ResetBan()).WillOnce(Return(1));
            auto result = retry->Run(Context, yield);
            ASSERT_TRUE(result);
            ASSERT_EQ(result.value().status, 200);
            ASSERT_EQ(result.value().body, "body");
        });
    }

    void SuccessRequestWithRetries() {
        auto node = GetNodeMock();
        THttpOptions opts;
        size_t maxRetries = 2;
        auto nodeGenerator = [=]() {
            return node;
        };
        std::string schema = "http";
        size_t port = 99;
        std::string body(R"({"body": "qwerty"})");
        TDuration banDuration(1s);
        auto httpClient = GetHttpClientMock();
        auto retry = std::make_shared<TRequestWithRetries<decltype(nodeGenerator)>>
                (opts, maxRetries, nodeGenerator, schema, port, body, banDuration, httpClient);
        Spawn([=](TYield yield) {
            const InSequence s;
            std::string host = "localhost";
            EXPECT_CALL(*node, GetAddress()).WillOnce(ReturnRef(host));
            EXPECT_CALL(*httpClient, async_run(_, _, _, _))
                    .WillOnce(InvokeArgument<3>(boost::system::error_code{}, THttpResponse{500, {}, "", ""}));
            EXPECT_CALL(*node, Ban(banDuration)).WillOnce(Return(1s));
            EXPECT_CALL(*node, GetAddress()).WillOnce(ReturnRef(host));
            EXPECT_CALL(*httpClient, async_run(_, _, _, _))
                    .WillOnce(InvokeArgument<3>(boost::system::error_code{}, THttpResponse{200, {}, "123", ""}));
            EXPECT_CALL(*node, ResetBan()).WillOnce(Return(1));
            auto result = retry->Run(Context, yield);
            ASSERT_TRUE(result);
            ASSERT_EQ(result.value().status, 200);
            ASSERT_EQ(result.value().body, "123");
        });
    }

    void FailAfterAllRetries() {
        auto node = GetNodeMock();
        THttpOptions opts;
        size_t maxRetries = 1;
        auto nodeGenerator = [=]() {
            return node;
        };
        std::string schema = "http";
        size_t port = 99;
        std::string body(R"({"body": "qwerty"})");
        TDuration banDuration(1s);
        auto httpClient = GetHttpClientMock();
        auto retry = std::make_shared<TRequestWithRetries<decltype(nodeGenerator)>>
                (opts, maxRetries, nodeGenerator, schema, port, body, banDuration, httpClient);
        Spawn([=](TYield yield) {
            const InSequence s;
            std::string host = "localhost";
            EXPECT_CALL(*node, GetAddress()).WillOnce(ReturnRef(host));
            EXPECT_CALL(*httpClient, async_run(_, _, _, _))
                    .WillOnce(InvokeArgument<3>(boost::system::error_code{}, THttpResponse{500, {}, "", ""}));
            EXPECT_CALL(*node, Ban(banDuration)).WillOnce(Return(1s));

            auto ec = boost::system::errc::make_error_code(boost::system::errc::not_supported);
            EXPECT_CALL(*node, GetAddress()).WillOnce(ReturnRef(host));
            EXPECT_CALL(*httpClient, async_run(_, _, _, _))
                    .WillOnce(InvokeArgument<3>(ec, THttpResponse{0, {}, "", ""}));
            EXPECT_CALL(*node, Ban(banDuration)).WillOnce(Return(1s));
            EXPECT_CALL(*node, GetAddress()).WillOnce(ReturnRef(host));

            auto result = retry->Run(Context, yield);
            ASSERT_FALSE(result);
        });
    }
private:
    NTemplateMaster::TContextPtr Context;
};

}

// Registers TRetryRequestTest with the unittest framework so its UNIT_TEST cases run.
UNIT_TEST_SUITE_REGISTRATION(NTemplateMaster::NTests::TRetryRequestTest)
