#include "sender_neh.h"
#include "waiter.h"

#include <library/cpp/balloc/optional/operators.h>
#include <library/cpp/logger/global/global.h>
#include <library/cpp/neh/multiclient.h>
#include <library/cpp/neh/location.h>

#include <util/datetime/base.h>
#include <util/thread/pool.h>
#include <util/system/mutex.h>
#include <util/system/guard.h>

#include <atomic>
#include <utility>

namespace NRTYServer {
    namespace {
        // Bundle handed to request-processing code as its thread-specific resource:
        // a NEH multi-client, references to the owner's receive/send permission flags,
        // and a counter of requests currently attributed to this resource.
        struct TNehRequestsResource {
            NNeh::TMultiClientPtr MC;
            std::atomic<bool>& AllowedRecv;
            std::atomic<bool>& AllowedSend;
            TAtomic RequestsCount = 0;

            // Creates a fresh multi-client and binds the two externally owned flags.
            TNehRequestsResource(std::atomic<bool>& isActive, std::atomic<bool>& allowedAdd)
                : MC(NNeh::CreateMultiClient())
                , AllowedRecv(isActive)
                , AllowedSend(allowedAdd)
            {}

            // True while receiving is still permitted, or while RequestsCount is non-zero.
            inline bool GetAllowedRecv() const {
                if (AllowedRecv) {
                    return true;
                }
                return AtomicGet(RequestsCount) != 0;
            }

            // Interrupts the multi-client once, then polls (sleeping 1 ms between checks)
            // until RequestsCount drops to zero.
            inline void Stop() {
                MC->Interrupt();
                while (AtomicGet(RequestsCount) != 0) {
                    Sleep(TDuration::MilliSeconds(1));
                }
            }
        };

        // State of a single outgoing request: the message, the listener that receives
        // the outcome, attempt bookkeeping and the NEH handle of the current attempt.
        //
        // Lifetime: the object stores an intrusive pointer to itself (SelfPtr) from
        // construction until FinishReply(), so it stays alive while it is queued or
        // awaiting a reply without any external owner. FinishReply() is only reached
        // through Process()/ProcessReply(); an instance that is never processed is
        // therefore never released by this class.
        //
        // Process() and ProcessReply() both run under Mutex, so a reply is not handled
        // before Process() has finished storing the handle of the attempt.
        class THandleFullInfo : public IObjectInQueue, public TAtomicRefCount<THandleFullInfo> {
        public:
            using TPtr = TIntrusivePtr<THandleFullInfo>;
        private:
            NNeh::TMessage Message;
            const IHandleListener* Listener;
            const ui32 AttemptsCount;        // maximum number of Send() attempts

            NNeh::THandleRef Handler;        // handle of the current attempt, empty before the first Send()
            TInstant StartTime;              // start of the current attempt

            TAtomic& InFly;                  // shared counter: +1 in the constructor, -1 in the destructor
            TQueryInfo Info;
            TWaiter& Queue;                  // receives this object again when a resend is scheduled
            TPtr SelfPtr;                    // self-reference, dropped in FinishReply()
            TNehRequestsResource* RResource = nullptr; // set by Process(); null until then
            TMultiRequesterBase::TRequestData::TPtr RequestData;
            TMutex Mutex;
            const TMultiRequesterBase::TResendDurations& ResendDurations;

            // Refreshes host/port/address in Info from Message.Addr and records whether
            // the current handle (if any) reports the message as completely sent.
            inline void RenewInfo() {
                NNeh::TParsedLocation loc(Message.Addr);
                Info.Host = loc.Host;
                Info.Port = loc.GetPort();
                Info.Addr = Message.Addr;
                Info.SendComplete = Handler && Handler->MessageSendedCompletely();
            }

        public:
            // Takes the message by value and moves it into place; increments inFly.
            // listener, queue, inFly and resendDurations are stored as pointer/references
            // and must outlive this object.
            template<class TMsg>
            inline THandleFullInfo(TMsg message, const IHandleListener* listener, ui32 attemptsCount, TAtomic& inFly, TWaiter& queue, TMultiRequesterBase::TRequestData::TPtr requestData, const TMultiRequesterBase::TResendDurations& resendDurations)
                : Message(std::move(message))
                , Listener(listener)
                , AttemptsCount(attemptsCount)
                , InFly(inFly)
                , Queue(queue)
                , SelfPtr(this)
                , RequestData(requestData)
                , ResendDurations(resendDurations)
            {
                AtomicIncrement(InFly);
                RenewInfo();
            }

            inline ~THandleFullInfo() {
                AtomicDecrement(InFly);
            }

            // Returns a copy of the self-reference (empty once FinishReply() has run).
            inline TPtr GetPtr() {
                return SelfPtr;
            }

            // Ends the request: removes it from the resource's request counter and drops
            // the self-reference. Requires RResource to be set, i.e. Process() has run.
            inline void FinishReply() noexcept {
                AtomicDecrement(RResource->RequestsCount);
                SelfPtr.Drop();
            }

            inline const IHandleListener* GetListener() const {
                return Listener;
            }

            // Returns Info after updating Duration to the time elapsed since StartTime.
            inline const TQueryInfo& GetInfo() {
                Info.Duration = Now() - StartTime;
                return Info;
            }

            // Schedules another attempt through Queue.
            // Returns false (and schedules nothing) when the attempt limit is reached or
            // the listener's OnResend() declines. The listener may rewrite Message.Addr.
            inline bool Resend() {
                if (Info.Attempts >= AttemptsCount)
                    return false;
                if (!Listener->OnResend(GetInfo(), Message.Addr, Message, Handler ? Handler->Response() : nullptr))
                    return false;
                RenewInfo();
                // Process() increments the counter again when the request is picked up.
                AtomicDecrement(RResource->RequestsCount);
                TDuration waitTime = TDuration::Zero();
                if (!ResendDurations.empty()) {
                    // Delay for the current attempt number, clamped to the last entry.
                    // Assumes Info.Attempts >= 1 (at least one Send() happened) - otherwise
                    // the index would underflow.
                    waitTime = ResendDurations[Min<size_t>(ResendDurations.size(), Info.Attempts) - 1];
                }
                Queue.Add(waitTime.ToDeadLine(), this);
                return true;
            }

            // Cancels the previous attempt's handle (if any), counts a new attempt and,
            // if the attempt limit allows, issues the request through the resource's
            // multi-client with this object as user data.
            // Returns false when the attempt limit is exceeded.
            inline bool Send() {
                if (!!Handler) {
                    Handler->Cancel();
                    Handler.Reset();
                }
                StartTime = Now();
                TInstant recvDeadline = StartTime + Listener->GetRecvTimeout();
                ++Info.Attempts;
                if (AttemptsCount >= Info.Attempts) {
                    NNeh::IMultiClient::TRequest req(Message, recvDeadline, this);
                    Handler = RResource->MC->Request(req);
                    return true;
                } else {
                    return false;
                }
            }

            // Cancels the current handle (if any), reports the reason to the listener
            // and finishes the request.
            inline void Cancel(const TString& reason) {
                if (!!Handler) {
                    Handler->Cancel();
                    Handler.Reset();
                }
                Listener->OnCancel(GetInfo(), reason);
                FinishReply();
            }

            // Hands the response of the current handle to the listener and finishes
            // the request. Requires a non-empty Handler.
            inline void Notify() {
                const auto& info = GetInfo();
                NNeh::TResponseRef response = Handler->Get();
                Listener->OnNotify(info, response);
                FinishReply();
            }

            // True when a response is available and the listener accepts it.
            inline bool ReplyIsOK() {
                return Handler && Handler->Response() && Listener->ReplyIsOk(GetInfo(), *Handler->Response());
            }

            // Queue entry point: binds the request to the given thread-specific resource
            // (a TNehRequestsResource*), notifies the listener and sends. If sending is
            // no longer allowed, Send() refuses, or anything throws, the request is
            // cancelled.
            void Process(void* threadSpecificResource) override {
                // Declared outside the guarded scope so the mutex is released before
                // this (possibly last) reference to the object goes away.
                TPtr ptrG = GetPtr();
                {
                    TGuard<TMutex> g(Mutex);
                    RResource = static_cast<TNehRequestsResource*>(threadSpecificResource);
                    AtomicIncrement(RResource->RequestsCount);
                    try {
                        Listener->OnStart(GetInfo());
                        if (!RResource->AllowedSend || !Send()) {
                            Cancel("Stopping");
                        }
                    }
                    catch (...) {
                        Cancel("Error: " + CurrentExceptionMessage());
                    }
                }
            }

            // Handles a multi-client event for this request: an acceptable reply is
            // delivered; otherwise a resend is attempted, and if that is not possible
            // the outcome is delivered (Response) or cancelled (Timeout).
            inline void ProcessReply(NNeh::IMultiClient::TEvent::TType evType) {
                // Same lifetime guard as in Process().
                TPtr ptrG = GetPtr();
                {
                    TGuard<TMutex> g(Mutex);
                    try {
                        if (ReplyIsOK()) {
                            Notify();
                        } else if (!Resend()) {
                            switch (evType) {
                            case NNeh::IMultiClient::TEvent::Response:
                                Notify();
                                break;
                            case NNeh::IMultiClient::TEvent::Timeout:
                                Cancel("Timeout");
                                break;
                            default:
                                FAIL_LOG("Invalid event type");
                            }
                        }
                    } catch (...) {
                        TString msg = CurrentExceptionMessage();
                        ERROR_LOG << "exeption while process msg: " << msg << Endl;
                        Cancel(msg);
                    }
                }
            }

        };

        class TNehRequestRepliesWatcher : public IObjectInQueue {
        private:
            TNehRequestsResource& RResource;
        public:
            TNehRequestRepliesWatcher(TNehRequestsResource& res)
                : RResource(res)
            {}

            void CheckReplies() {
                try {
                    NNeh::IMultiClient::TEvent ev;
                    if (RResource.MC->Wait(ev)) {
                        THandleFullInfo* result = (THandleFullInfo*)ev.UserData;
                        VERIFY_WITH_LOG(result, "Incorrect event.UserData");
                        result->ProcessReply(ev.Type);
                    }
                } catch (...) {
                    FAIL_LOG("Exception: %s", CurrentExceptionMessage().data());
                }
            }

            void Process(void* /*ThreadSpecificResource*/) override {
                ThreadDisableBalloc();
                DEBUG_LOG << "Start replies watcher" << Endl;
                NNeh::TResponseRef resp;
                while (RResource.GetAllowedRecv()) {
                    CheckReplies();
                }
                DEBUG_LOG << "Replies watcher finished" << Endl;
            }
        };
    }

    // Default resend policy: never resend. All arguments are ignored.
    bool IHandleListener::OnResend(
        const TQueryInfo& /*info*/,
        TString& /*newAddr*/,
        const NNeh::TMessage& /*msg*/,
        const NNeh::TResponse* /*resp*/) const
    {
        return false;
    }

    // Default acceptance check: a reply is OK exactly when the response is not an error.
    bool IHandleListener::ReplyIsOk(const TQueryInfo& /*info*/, const NNeh::TResponse& rep) const {
        const bool failed = rep.IsError();
        return !failed;
    }

    // Implementation behind TMultiRequesterBase.
    //
    // Two pools cooperate: WatcherNewRequests runs THandleFullInfo::Process() to issue
    // requests, WatcherReplies runs one TNehRequestRepliesWatcher per created
    // TNehRequestsResource to receive replies. ResendWaiter re-queues requests into
    // WatcherNewRequests after a delay.
    class TMultiRequesterBase::TImpl {
    private:
        std::atomic<bool> AllowedRecv;   // shared with every TNehRequestsResource
        std::atomic<bool> AllowedSend;   // shared with every TNehRequestsResource
        std::atomic<bool> AllowAdd;      // gate for accepting new requests in Send()
        // Bound to this object; presumably the binder calls
        // Create/DestroyThreadSpecificResource() for its worker threads - confirm
        // against TThreadPoolBinder.
        TThreadPoolBinder<TSimpleThreadPool, TMultiRequesterBase::TImpl> WatcherNewRequests;
        TSimpleThreadPool WatcherReplies;
        ui32 AttemptsCount;
        TAtomic InFly = 0;               // number of live THandleFullInfo objects
        TMutex RequestResourceMutex;     // guards RequestResource
        TVector<THolder<TNehRequestsResource>> RequestResource;
        TResendDurations ResendDurations;
        TWaiter ResendWaiter;

    public:
        TImpl(ui32 attemptsCount, const TResendDurations& resendDurations, const TString& threadName)
            : AllowedRecv(true)
            , AllowedSend(true)
            , AllowAdd(true)
            , WatcherNewRequests(this, threadName + "Req")
            , WatcherReplies(threadName + "Rep")
            , AttemptsCount(attemptsCount)
            , ResendDurations(resendDurations)
            , ResendWaiter(WatcherNewRequests)
        {}

        // Creates a new requests resource, starts a replies watcher for it in
        // WatcherReplies and records the resource in RequestResource, which owns it.
        // Returns the resource as an opaque pointer.
        void* CreateThreadSpecificResource() {
            auto result = new TNehRequestsResource(AllowedRecv, AllowedSend);
            WatcherReplies.SafeAddAndOwn(THolder(new TNehRequestRepliesWatcher(*result)));
            {
                TGuard<TMutex> guard(RequestResourceMutex);
                RequestResource.emplace_back(result);
            }
            return result;
        }

        // Intentionally empty: resources are owned by RequestResource and released in Stop().
        void DestroyThreadSpecificResource(void*) {
        }

        // Number of live requests (queued, in flight or waiting for a resend).
        ui32 QueueSize() const {
            return AtomicGet(InFly);
        }

        // Number of jobs currently queued in the request-sending pool.
        ui32 SendQueueSize() const {
            return WatcherNewRequests.Size();
        }

        // Shuts the requester down. New requests are refused first; with wait == true
        // the call then blocks until every live request has been destroyed. Afterwards
        // sending is disabled, the resend waiter and the sending pool are stopped,
        // receiving is disabled, each resource is drained and the replies pool stopped.
        void Stop(bool wait) {
            AllowAdd = false;
            if (wait) {
                while (AtomicGet(InFly)) {
                    Sleep(TDuration::MilliSeconds(10));
                }
            }
            AllowedSend = false;
            ResendWaiter.Stop(wait);
            WatcherNewRequests.Stop();
            AllowedRecv = false;

            TGuard<TMutex> guard(RequestResourceMutex);
            for (const auto& resource : RequestResource) {
                resource->Stop();
            }
            WatcherReplies.Stop();
            RequestResource.clear();
        }

        // (Re)enables all gates and starts both pools with threadsNum threads each,
        // then starts the resend waiter.
        void Start(ui32 threadsNum) {
            AllowedRecv = true;
            AllowedSend = true;
            AllowAdd = true;
            WatcherReplies.Start(threadsNum);
            WatcherNewRequests.Start(threadsNum);
            ResendWaiter.Start();
        }

        // Queues an already constructed request. Nothing is queued when adding is
        // disabled; a failed Add() is logged. In both cases the request object is not
        // released here (it holds a reference to itself), so callers should avoid
        // constructing one when adding is disabled.
        void Send(THandleFullInfo* fi) {
            if (AllowAdd && !WatcherNewRequests.Add(fi))
                ERROR_LOG << "Cannot add request to queue" << Endl;
        }

        // Wraps the message into a request object and queues it.
        // The message is silently dropped when adding is disabled (after Stop()).
        template<class Message>
        void Send(Message message, const IHandleListener* handleCallback, TRequestData::TPtr requestData) {
            // Check before constructing: THandleFullInfo increments InFly and keeps
            // itself alive until processed. Creating one that is never queued would
            // leak it and leave InFly permanently non-zero, making Stop(true) spin
            // forever. A narrow window remains if AllowAdd flips between this check
            // and the one in Send(THandleFullInfo*).
            if (!AllowAdd) {
                return;
            }
            Send(new THandleFullInfo(std::move(message), handleCallback, AttemptsCount, InFly, ResendWaiter, requestData, ResendDurations));
        }
    };

    // Returns the implementation's QueueSize() value.
    ui32 TMultiRequesterBase::QueueSize() const {
        const ui32 inFly = Impl_->QueueSize();
        return inFly;
    }

    // Returns the implementation's SendQueueSize() value.
    ui32 TMultiRequesterBase::SendQueueSize() const {
        const ui32 queued = Impl_->SendQueueSize();
        return queued;
    }

    // Builds the implementation object; all arguments are forwarded to TImpl unchanged.
    TMultiRequesterBase::TMultiRequesterBase(
        ui32 attemptsCount,
        const TResendDurations& resendDurations,
        const TString& threadName)
        : Impl_(MakeHolder<TImpl>(attemptsCount, resendDurations, threadName))
    {
    }

    // Defined in this file, where TImpl is a complete type, so Impl_ can be destroyed.
    TMultiRequesterBase::~TMultiRequesterBase() = default;

    void TMultiRequesterBase::Stop(bool wait) {
        Impl_->Stop(wait);
    }

    // Starts the requester; `threadsNum` is passed through to the implementation's Start().
    void TMultiRequesterBase::Start(ui32 threadsNum) {
        TImpl& impl = *Impl_;
        impl.Start(threadsNum);
    }

    // Queues a copy of `message`; the outcome is reported to `handleCallback`.
    void TMultiRequesterBase::Send(
        const NNeh::TMessage& message,
        const IHandleListener* handleCallback,
        TRequestData::TPtr requestData)
    {
        Impl_->Send(message, handleCallback, requestData);
    }

    // Rvalue overload: queues `message`, taking over its contents instead of copying
    // them; the outcome is reported to `handleCallback`.
    void TMultiRequesterBase::Send(NNeh::TMessage&& message, const IHandleListener* handleCallback, TRequestData::TPtr requestData) {
        // A named rvalue reference is an lvalue; without std::move the message would
        // be copied and this overload would bring no benefit over the const& one.
        Impl_->Send(std::move(message), handleCallback, requestData);
    }

    // Remembers the scheme of `protocol` for building addresses in Send().
    // `protocol` is dereferenced unconditionally and must not be null.
    TMultiRequester::TMultiRequester(
        ui32 attemptsCount,
        NNeh::IProtocol* protocol,
        const TResendDurations& resendDurations,
        const TString& threadName)
        : TMultiRequesterBase(attemptsCount, resendDurations, threadName)
        , ProtocolScheme(protocol->Scheme())
    {
    }

    void TMultiRequester::Send(const TString& host, ui16 port, NNeh::TMessage& message, const IHandleListener* handleCallback) {
        message.Addr = TString::Join(ProtocolScheme, "://", host, ':',  ToString(port), '/');
        TMultiRequesterBase::Send(message, handleCallback, nullptr);
    }

};
