#include <saas/library/sender_neh/sender_neh.h>
#include <saas/util/network/neh_server.h>
#include <library/cpp/logger/global/global.h>
#include <library/cpp/testing/unittest/registar.h>
#include <util/thread/pool.h>

using namespace NRTYServer;

namespace {
    const TString NEH_PROTOCOL = "tcp2";

    class THandleListener : public IHandleListener {
    public:
        typedef TAtomicSharedPtr<IHandleListener> TPtr;
        THandleListener(TAtomic& counter, TManualEvent& started, ui32 id)
            : Counter(counter)
            , Started(started)
            , Id(id)
        {
            AtomicIncrement(Counter);
        }

        virtual ~THandleListener() {
            AtomicDecrement(Counter);
        }

        void OnStart(const TQueryInfo& /*info*/) const override {
            DEBUG_LOG << Id << ": OnStart" << Endl;
            AtomicIncrement(StartCount);
            Started.Signal();
        }

        void OnNotify(const TQueryInfo& /*info*/, NNeh::TResponseRef& /*ref*/) const override {
            DEBUG_LOG << Id << ": OnNotify" << Endl;
            AtomicIncrement(OkCount);
            delete this;
        }

        void OnCancel(const TQueryInfo& /*info*/, const TString& reason) const override {
            AtomicIncrement(FailCount);
            DEBUG_LOG << Id << ": OnCancel " << reason << Endl;
            delete this;
        }

        TDuration GetRecvTimeout() const override {
            return TDuration::Seconds(5);
        }

        mutable TAtomic StartCount = 0;
        mutable TAtomic OkCount = 0;
        mutable TAtomic FailCount = 0;

    protected:
        TAtomic& Counter;
        TManualEvent& Started;
        ui32 Id;
    };

    class TResendListener : public THandleListener {
    public:
        TResendListener(TAtomic& counter, TManualEvent& started, ui32 id)
            : THandleListener(counter, started, id)
        {}

        bool OnResend(const TQueryInfo& info, TString& newAdrr, const NNeh::TMessage& /*msg*/, const NNeh::TResponse* /*resp*/) const override {
            DEBUG_LOG << Id << ": OnResend, duration: " << info.Duration.ToString() << Endl;
            newAdrr = NEH_PROTOCOL + "://localhost:17000/abc";
            return true;
        }
    };

    class TServer : public NUtil::TAbstractNehServer {
        struct TClient : public IObjectInQueue {
            TClient(NNeh::IRequestRef request, TDuration responseTime)
                : Request(request)
                , ResponseTime(responseTime)
            {}

            void Process(void* /*ThreadSpecificResource*/) override {
                THolder<TClient> suicide(this);
                DEBUG_LOG << "Begin process request, duration=" << ResponseTime.ToString() << Endl;
                Sleep(ResponseTime);
                NNeh::TDataSaver resp;
                resp << "Hello!";
                Request->SendReply(resp);
                DEBUG_LOG << "SendReply" << Endl;
            }
            NNeh::IRequestRef Request;
            TDuration ResponseTime;
        };
    public:
        TServer(TDuration respTime = TDuration::Seconds(10))
            : TAbstractNehServer([]() {
                TOptions opt(THttpServerOptions(), NEH_PROTOCOL);
                opt.SetThreads(1);
                return opt;
             }())
            , ResponseTime(respTime)
        {
            Start();
        }
        TAutoPtr<IObjectInQueue> DoCreateClientRequest(ui64 /*id*/, NNeh::IRequestRef req) final {
            return new TClient(req, ResponseTime);
        }

    private:
        TDuration ResponseTime;
    };

}

Y_UNIT_TEST_SUITE(NehSenderSuite) {
    Y_UNIT_TEST(StopTestNoInFly) {
        InitGlobalLog2Console(8);
        TAtomic counter(0);
        TManualEvent started;
        TMultiRequester requester(1, NNeh::ProtocolFactory()->Protocol(NEH_PROTOCOL));
        requester.Start(0);
        NNeh::TMessage nehMessage("", "data");
        for (ui32 i = 0; i < 5; ++i)
            requester.Send("localhost", 1234, nehMessage, new THandleListener(counter, started, i));
        requester.Stop();
        UNIT_ASSERT(!counter);
    }

    Y_UNIT_TEST(StopTestOneInFly) {
        InitGlobalLog2Console(8);
        TServer server;
        TAtomic counter(0);
        TManualEvent started;
        TMultiRequester requester(1, NNeh::ProtocolFactory()->Protocol(NEH_PROTOCOL));
        requester.Start(1);
        NNeh::TMessage nehMessage("", "data");
        for (ui32 i = 0; i < 5; ++i)
            requester.Send("localhost", 17000, nehMessage, new THandleListener(counter, started, i));
        started.Wait();
        requester.Stop();
        server.Stop();
        UNIT_ASSERT(!counter);
    }

    Y_UNIT_TEST(Resend) {
        InitGlobalLog2Console(8);
        TServer server(TDuration::Seconds(1));
        TAtomic counter(0);
        TManualEvent started;
        TMultiRequester requester(2, NNeh::ProtocolFactory()->Protocol(NEH_PROTOCOL));
        requester.Start(1);
        NNeh::TMessage nehMessage("", "data");
        TResendListener listener(counter, started, 1);
        requester.Send("localhost", 1234, nehMessage, &listener);
        Sleep(TDuration::Seconds(10));
        requester.Stop();
        server.Stop();
        UNIT_ASSERT_EQUAL(2, listener.StartCount);
        UNIT_ASSERT_EQUAL(listener.OkCount, 1);
        UNIT_ASSERT_EQUAL(listener.FailCount, 0);
    }

    Y_UNIT_TEST(Send) {
        InitGlobalLog2Console(8);
        TServer server(TDuration::Zero());
        TAtomic counter(0);
        TManualEvent started;
        TMultiRequester requester(2, NNeh::ProtocolFactory()->Protocol(NEH_PROTOCOL));
        requester.Start(10);
        NNeh::TMessage nehMessage("", "data");
        THandleListener listener(counter, started, 1);
        for (int i = 0; i < 100; ++i) {
            requester.Send("localhost", 17000, nehMessage, &listener);
        }
        while (AtomicGet(listener.OkCount) + AtomicGet(listener.FailCount) < 100)
            Sleep(TDuration::MilliSeconds(100));
        requester.Stop();
        server.Stop();
        UNIT_ASSERT_EQUAL(100, listener.StartCount);
        UNIT_ASSERT_EQUAL(listener.OkCount + listener.FailCount, 100);
    }
}
