#include "buffer.h"
#include "http.h"

#include <balancer/serval/contrib/cno/core.h>
#include <balancer/serval/contrib/cone/cone.hh>
#include <balancer/serval/core/unistat.h>

#include <util/generic/deque.h>
#include <util/generic/hash.h>
#include <util/generic/list.h>
#include <util/generic/scope.h>
#include <util/generic/vector.h>

// Builds a cno_message_t that refers to the storage of `head` without copying anything:
// the method, the path and the header array all keep pointing into `head`, so the
// result must not be used after `head` is modified or destroyed.
static inline cno_message_t ConvertHead(const NSv::THead& head) {
    const cno_buffer_t method = {head.Method.data(), head.Method.size()};
    const cno_buffer_t path = {head.PathWithQuery.data(), head.PathWithQuery.size()};
    // NOTE(review): this reinterprets the header storage of THead as an array of
    // cno_header_t, which presumably relies on the two having the same layout -- confirm.
    auto* const headers = (cno_header_t*)head.data();
    return {head.Code, method, path, headers, head.size()};
}

namespace {
    // State of a single stream of a cno connection, exposed through the NSv::IStream
    // interface. The connection callbacks feed it (AddHead, SetTail, Body.Write, Cancel),
    // while the user of the stream reads the incoming message and writes the outgoing one.
    //
    // NOTE(review): many expressions end in `MUN_RETHROW`, a macro defined outside this file;
    // it presumably annotates the current error while the expression is propagated -- confirm
    // before restructuring any of these expressions.
    class TStream: public NSv::IStream {
    public:
        // `conn` is the owning cno connection and `id` the cno stream id; `streamEnd` is an event
        // owned by the connection that WriteHead waits on when cno reports WOULD_BLOCK.
        // The body buffer is constructed from the locally advertised initial window size.
        TStream(
            cno_connection_t* conn,
            size_t id,
            cone::event* streamEnd,
            NSv::IP addr,
            NSv::TLogFrame log,
            NSv::TNumber<ui64>* emptyWritesSignal
        ) noexcept
            : Body(conn->settings[CNO_LOCAL].initial_window_size)
            , Addr(addr)
            , Log_(std::move(log))
            , StreamEnd(streamEnd)
            , Conn(conn)
            , ID(id)
            , EmptyWritesSignal(emptyWritesSignal)
        {}

        // Address passed to the constructor (the connection passes `IO->Peer()`).
        NSv::IP Peer() const noexcept override {
            return Addr;
        }

        // Log frame of this stream.
        NSv::TLogFrame& Log() noexcept override {
            return Log_;
        }

        // Returns the next unread head, blocking on `StateChange` until one is available.
        // Returns nullptr if the stream has been cancelled (the error is set to ECONNRESET)
        // or if the wait itself fails. An informational head is returned only once;
        // any other head is returned again by every subsequent call.
        NSv::THead* Head() noexcept override {
            while (NextHead == Heads.size()) {
                if (State == Cancelled ? mun_error(ECONNRESET, "no head received") : !StateChange.wait() MUN_RETHROW) {
                    return nullptr;
                }
            }
            auto ret = &Heads[NextHead];
            if (ret->IsInformational()) {
                ++NextHead;
            }
            return ret;
        }

        // Returns the received tail, or nullptr if SetTail has not been called yet.
        NSv::THeaderVector* Tail() noexcept override {
            return Tail_ ? &*Tail_ : nullptr;
        }

        // Forwards to the body buffer.
        TMaybe<TStringBuf> Read() noexcept override {
            return Body.Read();
        }

        // Drops `n` bytes from the body buffer and then calls `cno_open_flow` for the same
        // amount on this stream (the connection enables `manual_flow_control`).
        // `Skipped` accumulates the consumed byte count that Replay reports.
        bool Consume(size_t n) noexcept override {
            Skipped += n;
            return Body.Consume(n) && !(cno_open_flow(Conn, ID, n) MUN_RETHROW);
        }

        // Forwards to the body buffer.
        bool AtEnd() const noexcept override {
            return Body.AtEnd();
        }

        // Restores the received head and tail from the copies saved on arrival, and returns the
        // number of body bytes already consumed. Returns -1 (converted to size_t) on a client
        // connection or when the stream is no longer `Idle`, i.e. once a head has been written
        // or the stream has been cancelled.
        size_t Replay() noexcept override {
            if (Conn->client || State != Idle) {
                return -1;
            }
            if (Heads) {
                // Only need to reset the last head because requests never have more than one.
                Heads.back() = LastHead;
            }
            if (Tail_) {
                Tail_.ConstructInPlace(OriginalTail);
            }
            return Skipped;
        }

        // Writes a head to the peer. A `connection: close` header triggers `cno_shutdown` and
        // is erased from `h` before sending. On CNO_ERRNO_WOULD_BLOCK the write is retried after
        // `StreamEnd` is signalled; any other error fails the call. On success the state moves
        // from `Idle` to `Writing` and `StateChange` waiters are woken.
        bool WriteHead(NSv::THead& h) noexcept override {
            Log_.Push<NSv::NEv::TSentResponse>(h.Code);
            // FIXME remove this, it's h1 style while we're always pretending to use h2.
            if (auto ch = h.find("connection"); ch != h.end() && ch->second == "close") {
                if (cno_shutdown(Conn) MUN_RETHROW) {
                    return false;
                }
                h.erase(ch);
            }
            auto m = ConvertHead(h);
            while (cno_write_head(Conn, ID, &m, 0)) {
                if (mun_errno != CNO_ERRNO_WOULD_BLOCK || !StreamEnd->wait() MUN_RETHROW) {
                    return false;
                }
            }
            if (State == Idle) {
                State = Writing;
            }
            StateChange.wake();
            return true;
        }

        // Writes the whole chunk. Whenever `cno_write_data` accepts fewer bytes than offered,
        // waits on `FlowOpen` (woken by the connection's flow-increase callback and by Cancel)
        // before retrying with the remainder.
        bool Write(TStringBuf chunk) noexcept override {
            for (ssize_t written; chunk; chunk.Skip(written)) {
                if ((written = cno_write_data(Conn, ID, chunk.data(), chunk.size(), 0)) < 0 MUN_RETHROW) {
                    return false;
                }
                if (static_cast<size_t>(written) < chunk.size() && !FlowOpen.wait() MUN_RETHROW) {
                    return false;
                }
            }
            return true;
        }

        // Writes the tail. On a client connection this returns true; on a server it
        // deliberately reports the EREQDONE error, which the handler coroutine treats as
        // normal completion (presumably `mun_error` returns non-zero, making the result false).
        bool Close(NSv::THeaderVector& tail) noexcept override {
            const cno_tail_t converted = {(cno_header_t*)tail.data(), tail.size()};
            if (cno_write_tail(Conn, ID, &converted) MUN_RETHROW) {
                return false;
            }
            return Conn->client || !mun_error(EREQDONE, "abort request handling");
        }

        // Allocates one arena big enough for every string of `m` (method and path for a
        // cno_message_t, plus all header names and values) and returns a callable that copies
        // a cno_buffer_t into the next free part of that arena, yielding a TStringBuf over the
        // copy. A buffer with null data yields a TStringBuf with a null pointer.
        template <typename T>
        auto AddArenaFor(const T& m) {
            size_t size = 0;
            if constexpr (std::is_same<T, cno_message_t>::value) {
                size += m.method.size + m.path.size;
            }
            for (auto h = m.headers, end = h + m.headers_len; h != end; h++) {
                size += h->name.size + h->value.size;
            }
            Arenas.emplace_back(new char[size]);
            return [&, next = Arenas.back().get()](cno_buffer_t buf) mutable {
                next += buf.size;
                return TStringBuf{buf.data ? (char*)memcpy(next - buf.size, buf.data, buf.size) : nullptr, buf.size};
            };
        }

        // Stores a copy of a received head, keeps a second copy in `LastHead` for Replay,
        // logs it as a response (non-zero code) or a request, and wakes Head() waiters.
        void AddHead(const cno_message_t& m) {
            // All cno strings will be destroyed once on_message_head returns, but the coroutine
            // will live on; have to copy them, preferably into a memory pool.
            auto allocate = AddArenaFor(m);
            Heads.emplace_back(allocate(m.method), NSv::NormalizePathInPlace(allocate(m.path)));
            Heads.back().Code = m.code;
            for (auto h = m.headers, end = m.headers + m.headers_len; h != end; h++) {
                Heads.back().emplace(allocate(h->name), allocate(h->value));
            }
            LastHead = Heads.back();
            if (m.code) {
                Log_.Push<NSv::NEv::TRecvResponse>(m.code);
            } else {
                Log_.Push<NSv::NEv::TRecvRequest>(TString{m.method.data, m.method.size}, TString{m.path.data, m.path.size});
            }
            StateChange.wake();
        }

        // Stores a copy of the received tail (also saved in `OriginalTail` for Replay)
        // and closes the body buffer.
        void SetTail(const cno_tail_t& t) noexcept {
            auto allocate = AddArenaFor(t);
            Tail_.ConstructInPlace();
            Body.Close();
            for (auto h = t.headers, end = t.headers + t.headers_len; h != end; h++) {
                Tail_->insert({allocate(h->name), allocate(h->value)});
            }
            OriginalTail = *Tail_;
            Log_.Push<NSv::NEv::TRecvTail>();
        }

        // Marks the stream as cancelled: cancels the handler coroutine if there is one,
        // aborts the body buffer and wakes everything that may be waiting on this stream.
        void Cancel() noexcept {
            if (Handler) {
                // Only happens in the reader coroutine, which means the handler
                // is currently sleeping or scheduled.
                Handler->cancel();
            }
            Body.Abort();
            State = Cancelled;
            StateChange.wake();
            FlowOpen.wake(); // Unblock the writer so that it receives an invalid stream error.
        }

    public:
        // Backing storage for the strings referenced by `Heads`, `Tail_` and their saved copies.
        TDeque<std::unique_ptr<char[]>> Arenas;
        // Received heads in arrival order; `NextHead` indexes the first one not yet consumed.
        TDeque<NSv::THead> Heads;
        TMaybe<NSv::THeaderVector> Tail_;
        // Copies of the head and tail as received, used by Replay.
        NSv::THead LastHead;
        NSv::THeaderVector OriginalTail;
        // Received payload.
        NSv::TBuffer Body;
        NSv::IP Addr;
        NSv::TLogFrame Log_;
        // Woken when a head arrives, when a head is written, and on cancellation.
        cone::event StateChange;
        // Woken by the connection's flow-increase callback and on cancellation.
        cone::event FlowOpen;
        // Owned by the connection; woken whenever any stream of the connection ends.
        cone::event* StreamEnd;
        // Handler coroutine serving this stream, if any; set and reset by the connection.
        struct cone* Handler = nullptr;
        cno_connection_t* Conn;
        ui32 ID;
        // Total number of body bytes passed to Consume.
        size_t Skipped = 0;
        // Client streams start in `Writing`; server streams are `Idle` until a head is written.
        enum { Idle, Writing, Cancelled } State = Conn->client ? Writing : Idle;
        unsigned NextHead = 0;
        // Stored but not used by any method of this class.
        NSv::TNumber<ui64>* EmptyWritesSignal;
    };

    // Wraps the cno connection state so that `cno_fini` runs after every member of the derived
    // connection has been destroyed (base-class destructors run last).
    // Copying and moving are disabled: a copy would share the same cno state and
    // finalise it a second time from its own destructor.
    struct TConnectionBase: cno_connection_t {
        TConnectionBase() = default;
        TConnectionBase(const TConnectionBase&) = delete;
        TConnectionBase& operator=(const TConnectionBase&) = delete;

        ~TConnectionBase() {
            cno_fini(this);
        }
    };

    // An HTTP connection driven by cno. Two coroutines run for the lifetime of the object:
    // `Reader` feeds bytes from `IO` into cno, which invokes the On* callbacks below, and
    // `Writer` flushes `WriteBuffer` (filled by OnWriteV) into `IO`. On a server, every received
    // message head additionally spawns a handler coroutine in `Tasks`.
    //
    // NOTE(review): `MUN_RETHROW` and `mun_errno` come from outside this file; several return
    // values below depend on the error left behind by the last failed call, so statement order
    // matters throughout.
    struct TConnection: NSv::IConnection, TConnectionBase {
    public:
        // Initializes the cno state for the given `kind`, installs the callback table and
        // applies the options: manual flow control is enabled, and both the local initial
        // window size and the local max frame size are set to `opts.StreamWindow`.
        // `h` may be empty, in which case no handler coroutines are spawned.
        TConnection(NSv::IO& io, enum CNO_CONNECTION_KIND kind, NSv::THandler h, NSv::TConnectionOptions opts, NSv::TLogFrame log, NSv::TNumber<ui64>* emptyWritesSignal)
            : WriteBuffer(opts.WriteBuffer)
            , IO(&io)
            , Handler(std::move(h))
            , Log(std::move(log))
            , EmptyWritesSignal(emptyWritesSignal)
        {
// Adapts a member function to a cno callback: the first argument is `cb_data`, which is
// set to `this` below.
#define METHOD(f) [](void* p, auto... args) { return ((TConnection*)p)->f(args...); }
            static const cno_vtable_t vtable = {
                .on_writev        = METHOD(OnWriteV),
                .on_close         = METHOD(OnClose),
                .on_stream_start  = METHOD(OnStreamStart),
                .on_stream_end    = METHOD(OnStreamEnd),
                .on_flow_increase = METHOD(OnFlowIncrease),
                .on_message_head  = METHOD(OnMessageHead),
                .on_message_data  = METHOD(OnMessageData),
                .on_message_tail  = METHOD(OnMessageTail),
                .on_upgrade       = METHOD(OnUpgrade),
            };
#undef METHOD
            cno_init(this, kind);
            cb_code = &vtable;
            cb_data = this;
            manual_flow_control = 1;
            settings[CNO_LOCAL].initial_window_size = opts.StreamWindow;
            settings[CNO_LOCAL].max_frame_size = opts.StreamWindow;
        }

        // True when there are no live streams and the write buffer is still open.
        bool IsIdle() const noexcept override {
            return !Streams && WriteBuffer.IsOpen();
        }

        // Waits for the writer coroutine to finish and returns its result.
        bool Wait() noexcept override {
            return Writer->wait(cone::norethrow);
        }

        // Asks cno to shut the connection down.
        bool Shutdown() noexcept override {
            return !(cno_shutdown(this) MUN_RETHROW);
        }

        // Client only: sends a request head and returns the stream for it, or nullptr on
        // failure. `payload` tells whether a body will follow (otherwise the head is written as
        // the final frame). `log` replaces the log frame of the newly created stream.
        NSv::IStreamPtr Request(const NSv::THead& head, bool payload, NSv::TLogFrame log) noexcept override {
            if (!client) {
                return mun_error(EINVAL, "this is a server, it cannot send requests"), nullptr;
            }
            auto m = ConvertHead(head);
            do {
                // May fail in h2 mode if reached the server-imposed concurrent request limit.
                if (ui32 id = cno_next_stream(this); !cno_write_head(this, id, &m, !payload)) {
                    // NOTE(review): assumes a successful cno_write_head has already invoked
                    // OnStreamStart for `id`; the lookup result is not checked -- confirm.
                    auto& stream = Streams.find(id)->second;
                    stream->Log_ = std::move(log);
                    if (mode != CNO_HTTP2) {
                        return stream;
                    }
                    // Cancel the stream when the returned reference is destroyed. The closure
                    // over the original shared_ptr is used to also destroy the object.
                    return std::shared_ptr<TStream>(stream.get(), [this, stream](TStream* s) {
                        cno_write_reset(this, s->ID, CNO_RST_CANCEL);
                    });
                }
            } while (mun_errno == CNO_ERRNO_WOULD_BLOCK && StreamEnd.wait());
            return nullptr;
        }

        bool IsH2() const noexcept override {
            return mode == CNO_HTTP2;
        }

        // The connection counts as open for as long as its write buffer is.
        bool IsOpen() const noexcept override {
            return WriteBuffer.IsOpen();
        }

        // Stops both I/O coroutines (without being interruptible itself) and, if the connection
        // is still open, assigns fresh read/write loops to `Reader` and `Writer`.
        // NOTE(review): presumably used to move the connection to the calling coroutine's
        // event loop -- confirm against callers.
        bool Migrate() noexcept override {
            cone::uninterruptible([this]() {
                Writer->cancel();
                Reader->cancel();
                Writer->wait(cone::norethrow);
                Reader->wait(cone::norethrow);
            });
            if (!IsOpen()) {
                return false;
            }
            Reader = [this]() {
                return ReadLoop();
            };
            Writer = [this]() {
                return WriteLoop();
            };
            return true;
        }

    private:
        // Body of the reader coroutine: reads from `IO` and feeds the data to cno until EOF
        // or an error. On exit all streams are cancelled and dropped, and the write buffer is
        // closed unless the loop itself was cancelled. Returns true for the termination reasons
        // that are considered normal (clean EOF, EPIPE, protocol error, ECONNRESET on read).
        bool ReadLoop() noexcept {
            Y_DEFER {
                if (mun_errno != ECANCELED) {
                    // Close, not Abort; should still flush the buffer (in h2 mode, it may
                    // contain GOAWAY frames on protocol errors.)
                    WriteBuffer.Close();
                }
                for (auto& s : Streams) {
                    s.second->Cancel();
                }
                Streams.clear();
            };
            char buf[8192];
            while (auto rd = IO->ReadInto(TStringBuf(buf, sizeof(buf)))) {
                if (!*rd) {
                    // This connection is either closed, or half-closed. Unfortunately, the cases are
                    // indistinguishable until `Writer` attempts to write, so assume clients don't close
                    // the connection until reading the response (Other servers, like Go's built-in one,
                    // assume that as well, so maybe supporting half-closed is pointless anyway).
                    return !cno_eof(this) || EqualToOneOf(mun_errno, EPIPE, CNO_ERRNO_PROTOCOL);
                }
                // The loop body only runs when `cno_consume` fails; on a retry the size is reset
                // to zero, so no new data is passed in and cno continues with what it already has.
                for (; cno_consume(this, buf, *rd); *rd = 0) {
                    if (mun_errno == CNO_ERRNO_PROTOCOL) {
                        Log.Push<NSv::NEv::THxProtocolError>(mun_last_error()->text);
                    }
                    // `cno_consume` returns an error when there are too many pipelined requests;
                    // we can retry the read later if that's the case. (This never happens in h2.)
                    if (mun_errno != CNO_ERRNO_WOULD_BLOCK || !StreamEnd.wait()) {
                        return EqualToOneOf(mun_errno, EPIPE, CNO_ERRNO_PROTOCOL);
                    }
                }
            }
            return mun_errno == ECONNRESET;
        }

        // Number of zero-length writes tolerated by WriteLoop before it gives up.
        static constexpr int OP_LIMIT = 3;

        // Body of the writer coroutine: flushes `WriteBuffer` into `IO` until the buffer is
        // closed or a write fails. The reader is cancelled whenever this loop exits.
        bool WriteLoop() noexcept {
            Y_DEFER { Reader->cancel(); };
            // NOTE(review): this counter is never reset after a successful write, so it counts
            // all zero-length writes of the connection rather than consecutive ones; and when the
            // limit is hit, the result below depends on whatever `mun_errno` was left from
            // before -- confirm this is intended.
            int zero_writes = 0;
            while (auto chunk = WriteBuffer.Read()) {
                if (!*chunk) {
                    return true; // buffer closed by graceful shutdown
                }
                auto written = IO->Write(*chunk);
                if (written && !*written) {
                    ++(*EmptyWritesSignal);
                    ++zero_writes;
                    if (zero_writes > OP_LIMIT) {
                        break;
                    }
                }
                if (!written || !WriteBuffer.Consume(*written) MUN_RETHROW) {
                    WriteBuffer.Abort();
                    // Even though the reader will do the same, it will be too late
                    // to avoid EPIPEs from `Write` and `WriteHead`.
                    for (auto& s : Streams) {
                        s.second->Cancel();
                    }
                    Streams.clear();
                    break;
                }
            }
            return EqualToOneOf(mun_errno, EPIPE, ECONNRESET);
        }

        // cno callback: appends the given buffers to `WriteBuffer` in corked mode while holding
        // `WriteLock`, so output of concurrent writers is not interleaved. The buffer is uncorked
        // on exit when `unlock` returns false (presumably meaning nobody else is waiting for the
        // lock -- confirm).
        int OnWriteV(const cno_buffer_t* iov, size_t iovcnt) noexcept {
            WriteLock.lock();
            Y_DEFER {
                if (!WriteLock.unlock(cone::mutex::fair)) {
                    WriteBuffer.Uncork();
                }
            };
            // Can't cancel a write in the middle of a frame; also, this might be a compressed
            // header block, so can't cancel a write at all.
            return cone::uninterruptible([&]{
                for (; iovcnt--; iov++) {
                    if (!WriteBuffer.Write(TStringBuf(iov->data, iov->size), /*corked=*/true) MUN_RETHROW) {
                        return -1;
                    }
                }
                return 0;
            }) ? cone::yield(), -1 : 0; // consume cancellation on error
        }

        // cno callback: closes the write buffer; WriteLoop returns once it reads the end of it.
        int OnClose() noexcept {
            WriteBuffer.Close();
            return 0;
        }

        // cno callback: registers a new stream with a log frame forked from the connection's.
        int OnStreamStart(ui32 id) noexcept {
            Streams.emplace(id, std::make_shared<TStream>(this, id, &StreamEnd, IO->Peer(), Log.Fork<NSv::NEv::TStreamStart>(id), EmptyWritesSignal));
            return 0;
        }

        // cno callback: logs a stream error for any code other than CNO_RST_NO_ERROR, cancels
        // and forgets the stream, then wakes everyone waiting for a stream to end.
        int OnStreamEnd(ui32 id, ui32 code, enum CNO_PEER_KIND side) noexcept {
            auto it = Streams.find(id);
            if (it != Streams.end()) {
                if (code != CNO_RST_NO_ERROR) {
                    it->second->Log_.Push<NSv::NEv::THxStreamError>(code, side == CNO_REMOTE);
                }
                it->second->Cancel();
                Streams.erase(it);
            }
            // Can consume one more pipelined h1 request.
            StreamEnd.wake();
            return 0;
        }

        // cno callback: stores the head in the stream and, when a handler is configured,
        // spawns a coroutine that runs it on the stream.
        int OnMessageHead(ui32 id, const cno_message_t* msg) noexcept {
            auto it = Streams.find(id);
            if (it != Streams.end()) {
                it->second->AddHead(*msg);
                if (Handler) {
                    // The coroutine keeps its own reference to the stream, so the stream outlives
                    // its removal from `Streams`.
                    it->second->Handler = Tasks.add([this, s = it->second]() mutable {
                        Y_DEFER {
                            s->Body.Discard();
                            s->Handler = nullptr;
                        };
                        // Stream might have been cancelled before this coroutine even started.
                        // Every other cancellation can only happen at a preemption point.
                        if (s->State == TStream::Cancelled) {
                            return true;
                        }
                        // Handler should never succeed; on success, it should "fail" with EREQDONE.
                        if (cone::try_mun([&] {
                            return Handler(s);
                        })) {
                            mun_error(EINVAL, "handler finished without closing the stream");
                        }
                        // Any error other than cancellation or EREQDONE is a handler failure:
                        // close the connection's write buffer and log the error.
                        if (!EqualToOneOf(mun_errno, ECANCELED, EREQDONE) MUN_RETHROW) {
                            WriteBuffer.Close();
                            s->Log().Push<NSv::NEv::THxHandlerError>(mun_errno, mun_last_error()->text);
                            return false;
                        }
                        // After a completed request in h2 mode, reset the stream with CANCEL.
                        return mun_errno != ECANCELED && (!IsH2() || !cno_write_reset(this, s->ID, CNO_RST_CANCEL));
                    });
                }
            }
            return 0;
        }

        // cno callback: appends payload to the stream's body buffer. A temporary shared_ptr
        // copy keeps the stream alive for the duration of the write, even if it is removed
        // from `Streams` meanwhile.
        int OnMessageData(ui32 id, const char* data, size_t size) noexcept {
            auto it = Streams.find(id);
            if (it != Streams.end()) {
                return !std::shared_ptr<TStream>{it->second}->Body.Write(TStringBuf(data, size)) MUN_RETHROW;
            }
            return 0;
        }

        // cno callback: stores the tail; a null `tail` is treated as an empty one.
        int OnMessageTail(ui32 id, const cno_tail_t* tail) noexcept {
            auto it = Streams.find(id);
            if (it != Streams.end()) {
                it->second->SetTail(tail ? *tail : Default<cno_tail_t>());
            }
            return 0;
        }

        // cno callback: wakes writers blocked on flow control. Id 0 wakes the writers of
        // every stream; any other id wakes only that stream.
        int OnFlowIncrease(ui32 id) noexcept {
            if (!id) {
                for (auto& s : Streams) {
                    s.second->FlowOpen.wake();
                }
                return 0;
            }
            auto it = Streams.find(id);
            if (it != Streams.end()) {
                it->second->FlowOpen.wake();
            }
            return 0;
        }

        // cno callback: if no head has been written on the stream yet, blocks until the
        // stream's state changes. As in OnMessageData, a temporary shared_ptr copy keeps the
        // stream alive while waiting.
        int OnUpgrade(ui32 id) noexcept {
            auto it = Streams.find(id);
            if (it != Streams.end() && it->second->State == TStream::Idle) {
                // Will always change to `Writing`, since the reader is blocked.
                return !std::shared_ptr<TStream>{it->second}->StateChange.wait() MUN_RETHROW;
            }
            return 0;
        }

    private:
        // Woken whenever a stream ends; waited on when cno reports WOULD_BLOCK.
        cone::event StreamEnd;
        // Serializes OnWriteV calls.
        cone::mutex WriteLock;
        // Outgoing bytes produced by cno and not yet written to `IO`.
        NSv::TBuffer WriteBuffer;
        NSv::IO* IO = nullptr;
        // Request handler; empty on client connections created by H2Client.
        NSv::THandler Handler;
        NSv::TLogFrame Log;
        // Live streams by cno stream id.
        THashMap<ui32, std::shared_ptr<TStream>> Streams;
        // Handler coroutines spawned by OnMessageHead.
        cone::mguard Tasks;
        cone::guard Reader = [this]() {
            return ReadLoop();
        };
        cone::guard Writer = [this]() {
            return WriteLoop();
        };
        // Incremented by WriteLoop on every zero-length write.
        NSv::TNumber<ui64>* EmptyWritesSignal;
    };
}

// Creates a server-side connection over `io` that passes every incoming request to `h`.
// HTTP/2 is used when `opts.ForceH2` is set or when the protocol selected on `io` is "h2";
// HTTP/1 otherwise. The connection-level flow window (stream 0) is opened by five stream
// windows. Returns an empty holder if cno fails to start.
THolder<NSv::IConnection> NSv::H2Server(NSv::IO& io, NSv::THandler h, NSv::TNumber<ui64>* emptyWritesSignal, TConnectionOptions opts, NSv::TLogFrame log) {
    // Read everything needed from `opts` up front: it is moved into the connection below
    // and must not be inspected afterwards.
    const bool forceH2 = opts.ForceH2;
    const auto connectionWindow = opts.StreamWindow * 5;
    auto c = MakeHolder<TConnection>(io, CNO_SERVER, std::move(h), std::move(opts), std::move(log), emptyWritesSignal);
    if (cno_begin(c.Get(), forceH2 || io.SelectedProtocol() == "h2" ? CNO_HTTP2 : CNO_HTTP1) MUN_RETHROW) {
        return {};
    }
    if (cno_open_flow(c.Get(), 0, connectionWindow) MUN_RETHROW) {
        return {};
    }
    return c;
}

// Creates a client-side connection over `io`; it has no request handler and an empty log
// frame. HTTP/2 is used when `opts.ForceH2` is set or when the protocol selected on `io` is
// "h2"; HTTP/1 otherwise. The connection-level flow window (stream 0) is opened by five
// stream windows. Returns an empty holder if cno fails to start.
THolder<NSv::IConnection> NSv::H2Client(NSv::IO& io, NSv::TNumber<ui64>* emptyWritesSignal, TConnectionOptions opts) {
    // Read everything needed from `opts` up front: it is moved into the connection below
    // and must not be inspected afterwards.
    const bool forceH2 = opts.ForceH2;
    const auto connectionWindow = opts.StreamWindow * 5;
    auto c = MakeHolder<TConnection>(io, CNO_CLIENT, NSv::THandler{}, std::move(opts), NSv::TLogFrame{}, emptyWritesSignal);
    if (cno_begin(c.Get(), forceH2 || io.SelectedProtocol() == "h2" ? CNO_HTTP2 : CNO_HTTP1) MUN_RETHROW) {
        return {};
    }
    if (cno_open_flow(c.Get(), 0, connectionWindow) MUN_RETHROW) {
        return {};
    }
    return c;
}
