#include <balancer/client/ut/util/env.h>

#include <balancer/kernel/client_request/backend.h>
#include <balancer/kernel/client_request/backend_config.h>

#include <balancer/modules/balancer/module.h>
#include <library/cpp/testing/unittest/registar.h>

using namespace NSrvKernel;
using namespace NBalancerClient;
using namespace NBalancerClient::NTesting;

using NThreading::TFuture;
using NThreading::TPromise;


Y_UNIT_TEST_SUITE(TestClientRequest) {
    // Builds the balancer config shared by every test in this suite:
    // an "rr" balancer with two identical proxy backends (localhost:12345,
    // cached ip 127.0.0.1, 1s backend timeout) and two attempts.
    // When `protocol` is non-empty it is appended to each proxy's "protocols" list.
    static NJson::TJsonValue MakeBalancerConfig(TStringBuf protocol = {}) {
        NJson::TJsonValue config;
        config["maxlen"] = 65536;
        config["maxreq"] = 65536;
        auto& rr = config["balancer"]["rr"];

        NJson::TJsonValue proxy;
        proxy["cached_ip"] = "127.0.0.1";
        proxy["host"] = "localhost";
        proxy["port"] = 12345;
        proxy["backend_timeout"] = "1s";
        if (!protocol.empty()) {
            proxy["protocols"].AppendValue(protocol);
        }

        NJson::TJsonValue item;
        item["proxy"] = proxy;
        rr.AppendValue(item);
        rr.AppendValue(item);

        config["balancer"]["attempts"] = 2;
        return config;
    }

    // A registered custom protocol implementation is handed to DoRun, and a
    // TBackendError on the first attempt is retried: DoRun must run twice and
    // the second (successful) attempt must produce the result.
    Y_UNIT_TEST(Simple) {
        TEnv env;
        env.Start();

        class TFakeProtocolBackend
            : public TBackendProtocolImpl<TFakeProtocolBackend>
        {
        public:
            TFakeProtocolBackend(const TBackendConfig& config)
                : Config(config)
            {
            }

            static TStringBuf ProtocolName() {
                return "fake";
            }

            const TBackendConfig& Config;
        };

        env.MainTask().GetProtocolFactory().RegisterProtocolImpl<TFakeProtocolBackend>();

        auto source = env.CreateSource(MakeBalancerConfig(TFakeProtocolBackend::ProtocolName()));

        struct TClientRequest
            : public IClientRequest
        {
            size_t Counter = 0;
            bool HasResult = false;

            TError DoRun(const TConnDescr& /*connDescr*/, const TBackendProtocols& backend, const NSrvKernel::THostInfo& /*hostInfo*/) noexcept override {
                ++Counter;

                auto* impl = backend.GetImpl<TFakeProtocolBackend>();
                UNIT_ASSERT(impl);

                UNIT_ASSERT_VALUES_EQUAL(impl->Config.host(), "localhost");
                UNIT_ASSERT_VALUES_EQUAL(impl->Config.port(), 12345);

                if (Counter > 1) {
                    HasResult = true;
                    return {};
                } else {
                    // Fail the first attempt with a backend error so the balancer retries.
                    return Y_MAKE_ERROR(TBackendError(Y_MAKE_ERROR(yexception())));
                }
            }
        };

        auto request = MakeIntrusive<TClientRequest>();
        auto context = source->SendRequest(request, MakeHolder<NSrvKernel::TAttemptsHolderBase>(2, 0), nullptr);

        TBalancerClientResponse response;
        // Rethrow instead of UNIT_ASSERT(false) so a failure reports the actual error.
        TryRethrowError(context->GetFuture().ExtractValueSync().GetResponse().AssignTo(response));
        UNIT_ASSERT_VALUES_EQUAL(request->Counter, 2);
        UNIT_ASSERT(request->HasResult);
    }

    // Addresses returned by the attempts holder's SrcrwrAddrs() override the
    // configured backend: DoRun must see host "ya.ru", ip 127.0.0.2, port 54321
    // and IsSrcRwr set.
    Y_UNIT_TEST(Srcrwr) {
        TEnv env;
        env.Start();

        class TFakeProtocolBackend: public TBackendProtocolImpl<TFakeProtocolBackend>
        {
          public:
            TFakeProtocolBackend(const TBackendConfig& /*config*/)
            {
            }

            static TStringBuf ProtocolName() {
                return "fake";
            }
        };

        env.MainTask().GetProtocolFactory().RegisterProtocolImpl<TFakeProtocolBackend>();

        auto source = env.CreateSource(MakeBalancerConfig(TFakeProtocolBackend::ProtocolName()));

        class TAttemptsHolder: public NSrvKernel::TAttemptsHolderBase {
          public:
            TAttemptsHolder(): TAttemptsHolderBase(1, 0) {
                SrcrwrAddrs_.SetAddrs({{"ya.ru" ,"127.0.0.2", 54321}});
            }

            NSrcrwr::TAddrs *SrcrwrAddrs() override {
                return &SrcrwrAddrs_;
            }

          private:
            NSrcrwr::TAddrs SrcrwrAddrs_;
        };

        struct TClientRequest: public IClientRequest
        {
            // Counts DoRun invocations so the test fails if the host checks never ran.
            size_t Counter = 0;

            TError DoRun(const TConnDescr& /*connDescr*/, const TBackendProtocols& /*backend*/, const NSrvKernel::THostInfo& hostInfo) noexcept override {
                ++Counter;
                UNIT_ASSERT_VALUES_EQUAL(hostInfo.Host, "ya.ru");
                UNIT_ASSERT_VALUES_EQUAL(hostInfo.CachedIp, "127.0.0.2");
                UNIT_ASSERT_VALUES_EQUAL(hostInfo.Port, 54321);
                UNIT_ASSERT(hostInfo.IsSrcRwr);
                return {};
            }
        };

        auto request = MakeIntrusive<TClientRequest>();
        auto context = source->SendRequest(request, MakeHolder<TAttemptsHolder>(), nullptr);

        TBalancerClientResponse response;
        TryRethrowError(context->GetFuture().ExtractValueSync().GetResponse().AssignTo(response));
        // DoRun succeeds on the first attempt, so it must have been called exactly once.
        UNIT_ASSERT_VALUES_EQUAL(request->Counter, 1);
    }

    // Trailers sent through connDescr.Output inside DoRun are delivered in the
    // client response.
    Y_UNIT_TEST(Trailers) {
        TEnv env;
        env.Start();

        auto source = env.CreateSource(MakeBalancerConfig());

        struct TClientRequest: public IClientRequest
        {
            TError DoRun(const TConnDescr& connDescr, const TBackendProtocols&, const NSrvKernel::THostInfo& /*hostInfo*/) noexcept override {
                THeaders trailers;
                trailers.Add("Some-Trailer", "some value");
                Y_PROPAGATE_ERROR(connDescr.Output->SendTrailers(std::move(trailers), TInstant::Now()));
                return {};
            }
        };

        auto request = MakeIntrusive<TClientRequest>();
        auto context = source->SendRequest(request, MakeHolder<NSrvKernel::TAttemptsHolderBase>(2, 0), nullptr);

        TBalancerClientResponse response;
        TryRethrowError(context->GetFuture().ExtractValueSync().GetResponse().AssignTo(response));
        UNIT_ASSERT_VALUES_EQUAL(response.Trailers.GetFirstValue("Some-Trailer"), "some value");
    }

    // Body data sent through connDescr.Output inside DoRun is streamed into the
    // caller-supplied output channel.
    Y_UNIT_TEST(OutputStreaming) {
        TEnv env;
        env.Start();

        auto source = env.CreateSource(MakeBalancerConfig());

        struct TClientRequest: public IClientRequest
        {
            TError DoRun(const TConnDescr& connDescr, const TBackendProtocols&, const NSrvKernel::THostInfo& /*hostInfo*/) noexcept override {
                TChunkList chunkList;
                chunkList.Push(NewChunk("test_data"));
                Y_PROPAGATE_ERROR(connDescr.Output->Send(std::move(chunkList), TInstant::Now()));
                return {};
            }
        };

        auto request = MakeIntrusive<TClientRequest>();
        auto outputChannel = MakeAtomicShared<NSrvKernel::TW2UChannel<TString>>(Max<size_t>());
        auto context = source->SendRequest({}, {}, request, MakeHolder<NSrvKernel::TAttemptsHolderBase>(2, 0), nullptr, outputChannel);
        TString data;
        UNIT_ASSERT_EQUAL(outputChannel->Receive(data, TInstant::Max()), NSrvKernel::EChannelStatus::Success);
        // VALUES_EQUAL prints both strings on mismatch.
        UNIT_ASSERT_VALUES_EQUAL(data, "test_data");

        context->GetFuture().GetValueSync();
    }
}
