#pragma once

#include <balancer/kernel/coro/channel.h>
#include <balancer/kernel/coro/waked.h>
#include <balancer/kernel/helpers/errors.h>
#include <balancer/kernel/http/parser/http.h>

class THttpHeaders;

namespace NBalancerServer {

    /// Base class for objects that deliver an HTTP reply (head, body chunks, trailers, EOF)
    /// from request-handling code to an NSrvKernel::IHttpOutput.
    ///
    /// Public Send*/Finish methods validate the producer-side state (SendState_) and delegate
    /// to the private Do* hooks implemented by subclasses. The protected Send*Impl helpers do
    /// the actual writes into an IHttpOutput, track what reached the output (ReplyState_) and
    /// record the first client error.
    ///
    /// The class passes itself as TAtomicRefCount's destroy policy, so Destroy() is what runs
    /// when the intrusive count held by TSendingRef::SenderRefCount_ reaches zero (assumed
    /// TAtomicRefCount semantics). Destroy() only calls Finish(); the memory itself is owned by
    /// the TAtomicSharedPtr stored in TSendingRef::Holder_.
    class IHttpReplyTransport
        : public TAtomicRefCount<IHttpReplyTransport, IHttpReplyTransport>
    {
    public:
        /// Handle that both keeps the transport alive (Holder_) and counts as an active sender
        /// (SenderRefCount_). Dropping the last sender reference triggers Finish() via Destroy().
        struct TSendingRef
            : TPointerBase<TSendingRef, IHttpReplyTransport>
        {
            TSendingRef() = default;

            explicit TSendingRef(TAtomicSharedPtr<IHttpReplyTransport> t)
                : Holder_(std::move(t))
                , SenderRefCount_(&*Holder_)
            {
            }

            IHttpReplyTransport* Get() const {
                return Holder_.Get();
            }

            /// Releases the sender reference first (which may run Finish() on the transport)
            /// while Holder_ still keeps the object alive, then releases ownership.
            void Reset() {
                SenderRefCount_.Reset();
                Holder_.Reset();
            }
        private:
            friend class THttpRequestEnvImpl;
            auto GetHolder() {
                return Holder_;
            }
        private:
            // Keep this declaration order: members are destroyed in reverse order, so
            // SenderRefCount_ is released before Holder_, exactly like Reset().
            TAtomicSharedPtr<IHttpReplyTransport> Holder_;
            TIntrusivePtr<IHttpReplyTransport> SenderRefCount_;
        };
    public:
        virtual ~IHttpReplyTransport();

        /// Defined out of line. Presumably picks a response encoding from `requestHeaders`
        /// and sets up Encoder_ — confirm in the implementation file.
        bool ChooseEncoding(const NSrvKernel::THeaders& requestHeaders);

        /// Defined out of line.
        void SendHead(NSrvKernel::TResponse&& response);

        /// Sends a body chunk. With an encoder configured the bytes are compressed and
        /// flushed immediately; otherwise they are copied into a chunk list and sent as is.
        /// @throws TClientError if a client error has been recorded.
        /// @throws yexception if sending is already finished or the head was not sent yet.
        void SendData(TStringBuf data) {
            if (Encoder_) {
                WriteEncoded(data);
            } else {
                SendData(NSrvKernel::TChunkList(TString(data)));
            }
        }

        /// Same as SendData(TStringBuf), but moves `data` into the chunk list when no
        /// encoder is configured.
        void SendData(TString data) {
            if (Encoder_) {
                WriteEncoded(data);
            } else {
                SendData(NSrvKernel::TChunkList(std::move(data)));
            }
        }

        /// Defined out of line.
        void SendTrailers(NSrvKernel::THeaders&& headers);

        /// Marks the producer side finished and notifies the subclass via DoFinish().
        /// Calling it again (or after the subclass already marked sending finished) is a no-op.
        void Finish() {
            if (SendState_.Finished) {
                return;
            }

            SendState_.Finished = true;
            DoFinish();
        }

        /// Completes the encoder stream (if any) and asks the subclass to send EOF.
        /// @throws TClientError if a client error has been recorded.
        void SendEof() {
            CheckClientError();

            if (SendState_.Finished) {
                return;
            }

            if (Encoder_) {
                // Emits the encoder's trailing bytes before EOF is requested.
                Encoder_->Finish();
            }

            DoSendEof();
        }

        /// Moves out the recorded client error, or returns an empty error if there is none.
        /// NOTE(review): HasClientError_ is not cleared, so a second call relies on the
        /// moved-from ClientError_ still passing Y_VERIFY — take the error at most once.
        NSrvKernel::TError TakeClientError() {
            if (AtomicGet(HasClientError_) == 1) {
                Y_VERIFY(ClientError_);
                return std::move(ClientError_);
            } else {
                return {};
            }
        }

        /// Thrown by sending methods once the client side has failed.
        class TClientError
            : public yexception
        {
        };

        /// True once the producer side finished (Finish(), or a subclass marked it).
        bool IsSendFinished() const {
            return SendState_.Finished;
        }

        /// True once EOF handling reached the output side (SendEofImpl()).
        bool IsReplyFinished() const {
            return ReplyState_.Finished;
        }

        /// Total bytes accepted by the producer-side SendData path.
        size_t SentSize() const {
            return SendState_.SentSize;
        }

        /// Delivers pending reply parts to `httpOutput`; implemented by subclasses.
        virtual NSrvKernel::TError Transfer(TCont* cont, NSrvKernel::TEventWaker* waker, NSrvKernel::IHttpOutput& httpOutput) = 0;
    public:
        /// Destroy policy for TAtomicRefCount: finishes the transport instead of deleting it.
        /// Exceptions are logged and swallowed because this runs while a reference is dropped.
        static void Destroy(IHttpReplyTransport* t) {
            try {
                t->Finish();
            } catch (...) {
                // Terminate the line so consecutive messages do not run together.
                Cdbg << "uncaught exception: " << CurrentExceptionMessage() << '\n';
            }
        }
    protected:
        /// Writes the head into `httpOutput` with no deadline; a failure is stored as the
        /// client error instead of being returned.
        void SendHeadImpl(NSrvKernel::IHttpOutput& httpOutput, NSrvKernel::TResponse&& response) {
            ReplyState_.HeadSent = true;
            if (auto error = httpOutput.SendHead(std::move(response), false, TInstant::Max())) {
                PutClientError(std::move(error));
            }
        }

        /// Writes a body chunk into `httpOutput` with no deadline and adds its size to
        /// ReplyState_.SentSize; a failure is stored as the client error.
        void SendDataImpl(NSrvKernel::IHttpOutput& httpOutput, NSrvKernel::TChunkList&& data) {
            const size_t chunkSize = data.size();
            if (auto error = httpOutput.Send(std::move(data), TInstant::Max())) {
                PutClientError(std::move(error));
            }
            ReplyState_.SentSize += chunkSize;
        }

        /// Writes trailers into `httpOutput` with no deadline; a failure is stored as the
        /// client error.
        void SendTrailersImpl(NSrvKernel::IHttpOutput& httpOutput, NSrvKernel::THeaders&& headers) {
            if (auto error = httpOutput.SendTrailers(std::move(headers), TInstant::Max())) {
                PutClientError(std::move(error));
            }
        }

        /// Marks the reply finished and writes EOF into `httpOutput`. Only the first call has
        /// any effect; EOF is not written if the head was never sent.
        void SendEofImpl(NSrvKernel::IHttpOutput& httpOutput) {
            if (ReplyState_.Finished) {
                return;
            }
            ReplyState_.Finished = true;
            if (!ReplyState_.HeadSent) { // You can't send eof without head bc it breaks HttpOutput state
                return;
            }
            if (auto error = httpOutput.SendEof(TInstant::Max())) {
                PutClientError(std::move(error));
            }
        }

        /// @throws yexception if the producer side is already finished.
        void CheckNotFinished() {
            Y_ENSURE(!SendState_.Finished, "try to send something after finish");
        }

        /// SendState_ tracks what the producer asked for; ReplyState_ tracks what the
        /// Send*Impl helpers actually wrote to an output.
        struct TReplyState {
            bool HeadSent = false;
            bool Finished = false;
            size_t SentSize = 0;
        } SendState_, ReplyState_;
    private:
        virtual void DoSendHead(NSrvKernel::TResponse&& response) = 0;
        virtual void DoSendData(NSrvKernel::TChunkList&& data) = 0;
        virtual void DoSendTrailers(NSrvKernel::THeaders&& headers) = 0;
        virtual void DoSendEof() = 0;
        virtual void DoFinish() = 0;

        /// Common raw-data path: validates state, ignores empty chunks, accounts the size in
        /// SendState_ and hands the chunk to the subclass.
        void SendData(NSrvKernel::TChunkList&& data) {
            CheckClientError();
            CheckNotFinished();
            Y_ENSURE(SendState_.HeadSent, "try to send data before head");

            if (data.Empty()) {
                return;
            }

            SendState_.SentSize += data.size();

            DoSendData(std::move(data));
        }

        /// Feeds `data` through Encoder_ and flushes it so the encoded bytes are emitted now
        /// (presumably via SendDataStream_ into SendData(TChunkList&&) — defined out of line).
        /// The state checks run before the encoder is touched, so a write after a client
        /// error or after finish is rejected without mutating the compressor state.
        void WriteEncoded(TStringBuf data) {
            CheckClientError();
            CheckNotFinished();
            Encoder_->Write(data.data(), data.size());
            Encoder_->Flush();
        }

        /// Records the first client error; later errors are ignored.
        void PutClientError(NSrvKernel::TError&& error) {
            if (AtomicGet(HasClientError_)) {
                return;
            }

            Y_VERIFY(error);

            ClientErrorMessage_ = error->what();
            if (!ClientErrorMessage_) {
                ClientErrorMessage_ = "unknown client error";
            }
            ClientError_ = std::move(error);

            // Publish the flag last: readers check HasClientError_ before reading the
            // message/error (assumes AtomicSet provides the needed ordering).
            AtomicSet(HasClientError_, 1);
        }

        /// @throws TClientError carrying the recorded message once a client error exists.
        void CheckClientError() const {
            if (AtomicGet(HasClientError_) != 1) {
                return;
            }

            ythrow TClientError() << ClientErrorMessage_;
        }
    private:
        TString ClientErrorMessage_;
        NSrvKernel::TError ClientError_;
        TAtomic HasClientError_ = 0;

        TString Encoding_;
        class TSendDataStream;
        THolder<IOutputStream> SendDataStream_;
        THolder<IOutputStream> Encoder_;
    };



    class THttpReplyTransport
        : public IHttpReplyTransport {
    public:
        explicit THttpReplyTransport(TCont* serverCont, NSrvKernel::IHttpOutput& output, size_t channelSize = 32, TDuration sendTimeout = TDuration::Seconds(10))
            : ServerCont_(serverCont)
            , Output_(output)
            , ReplyChannel_(channelSize)
            , Timeout_(sendTimeout)
        {
        }

        NSrvKernel::TError Transfer(TCont* cont, NSrvKernel::TEventWaker* waker, NSrvKernel::IHttpOutput& httpOutput) override {
            if (ReplyState_.Finished || !DoTransfer_) {
                return {};
            }

            TAtomicSharedPtr<TFuncMessageBase> f;

            while (ReplyChannel_.Receive(f, TInstant::Max(), cont, waker) == NSrvKernel::EChannelStatus::Success) {
                Y_VERIFY(f.Get());
                NSrvKernel::TErrorOr<bool> ret = f->Run(httpOutput);

                if (NSrvKernel::TError clientError = TakeClientError()) {
                    while (ReplyChannel_.TryReceive(f)) {}
                    return clientError;
                }

                bool finished = false;
                if (NSrvKernel::TError backendError = ret.AssignTo(finished)) {
                    return backendError;
                }
                if (finished) {
                    return {};
                }
            }

            return Y_MAKE_ERROR(yexception() << "http reply transport receive timeout");
        }

        void SendError(NSrvKernel::TError e) {
            CheckNotFinished();

            SendState_.Finished = true;

            auto f = MakeFuncMessage([error = std::move(e)](NSrvKernel::IHttpOutput&) mutable -> NSrvKernel::TErrorOr<bool> {
                return NSrvKernel::TError(std::move(error));
            });

            Send(f);
        }

    private:
        void DoSendHead(NSrvKernel::TResponse&& response) override {
            if (ServerCont_ == RunningCont()) {
                SendHeadImpl(Output_, std::move(response));
            } else {
                auto f = MakeFuncMessage([this, r = std::move(response)](NSrvKernel::IHttpOutput& httpOutput) mutable -> NSrvKernel::TErrorOr<bool> {
                    SendHeadImpl(httpOutput, std::move(r));
                    return false;
                });

                Send(f);
            }
        }

        void DoSendData(NSrvKernel::TChunkList&& data) override {
            if (ServerCont_ == RunningCont()) {
                SendDataImpl(Output_, std::move(data));
            } else {
                auto f = MakeFuncMessage([this, d = std::move(data)](NSrvKernel::IHttpOutput& httpOutput) mutable -> NSrvKernel::TErrorOr<bool> {
                    SendDataImpl(httpOutput, std::move(d));
                    return false;
                });

                Send(f);
            }
        }

        void DoSendTrailers(NSrvKernel::THeaders&& headers) override {
            if (ServerCont_ == RunningCont()) {
                SendTrailersImpl(Output_, std::move(headers));
            } else {
                auto f = MakeFuncMessage([this, h = std::move(headers)](NSrvKernel::IHttpOutput& httpOutput) mutable -> NSrvKernel::TErrorOr<bool> {
                    SendTrailersImpl(httpOutput, std::move(h));
                    return false;
                });

                Send(f);
            }
        }

        void DoSendEof() override {
            if (ServerCont_ == RunningCont()) {
                SendEofImpl(Output_);
            } else {
                SendState_.Finished = true;

                auto f = MakeFuncMessage([this](NSrvKernel::IHttpOutput& httpOutput) mutable -> NSrvKernel::TErrorOr<bool> {
                    SendEofImpl(httpOutput);
                    return true;
                });

                Send(f);
            }
        }

        void DoFinish() override {
            if (ServerCont_ == RunningCont()) {
                DoTransfer_ = false;
            } else {
                auto f = MakeFuncMessage([](NSrvKernel::IHttpOutput&) mutable -> NSrvKernel::TErrorOr<bool> {
                    return true;
                });

                Send(f);
            }
        }

    private:
        class TFuncMessageBase {
        public:
            virtual NSrvKernel::TErrorOr<bool> Run(NSrvKernel::IHttpOutput&) = 0;
            virtual ~TFuncMessageBase() {}
        };

        template <typename F>
        class TFuncMessage
            : public TFuncMessageBase
        {
        public:
            TFuncMessage(F f)
                : F_(std::move(f))
            {
            }
        private:
            F F_;
            NSrvKernel::TErrorOr<bool> Run(NSrvKernel::IHttpOutput& output) override {
                return F_(output);
            }
        };

        template <typename F>
        TAtomicSharedPtr<TFuncMessageBase> MakeFuncMessage(F f) {
            return MakeAtomicShared<TFuncMessage<F>>(std::move(f));
        }

        void Send(TAtomicSharedPtr<TFuncMessageBase>& f) {
            if (ReplyChannel_.Send(f, Timeout_.ToDeadLine()) != NSrvKernel::EChannelStatus::Success) {
                ythrow yexception() << "http reply transport send failed";
            }
        }

    private:
        TCont* const ServerCont_;

        NSrvKernel::IHttpOutput& Output_;

        NSrvKernel::TU2WChannel<TAtomicSharedPtr<TFuncMessageBase>> ReplyChannel_;
        const TDuration Timeout_;
        bool DoTransfer_ = true;
    };
}
