#include "http.h"
#include "storage.h"

#include <balancer/serval/contrib/cno/hpack.h>
#include <balancer/serval/contrib/cone/cone.hh>
#include <balancer/serval/contrib/cone/cold.h>

#include <util/generic/list.h>
#include <util/generic/scope.h>

namespace {
    // Wire format of a single datagram:
    //
    // 0    6   8                    8+N
    // | ID | N | HPACK-encoded head | payload ...
    //
    // HPACK does not use a dynamic table. HTTP2 pseudo-header coding applies, i.e.
    // requests have :method and :path, responses have :status.
    //
    // Encode() produces everything up to (not including) the payload: the 6-byte
    // big-endian ID, the 2-byte big-endian head length N, and the HPACK-encoded head.
    // `request` selects which pseudo-headers are emitted (:method/:path vs :status).
    //
    // Returns an empty string on failure: either a buffer/HPACK error, or a head
    // that is too long to be described by the 16-bit N field (reported as EMSGSIZE).
    TString Encode(ui64 id, const NSv::THead& head, bool request) {
        cno_buffer_dyn_t buf = {};
        cno_hpack_t hp;
        cno_hpack_init(&hp, 0);
        Y_DEFER {
            cno_hpack_clear(&hp);
            cno_buffer_dyn_clear(&buf);
        };
        // The last two bytes are a placeholder for N; they are patched once the head is encoded.
        ui8 prefix[8] = {(ui8)(id >> 40), (ui8)(id >> 32), (ui8)(id >> 24),
                         (ui8)(id >> 16), (ui8)(id >> 8),  (ui8)id, 0, 0};
        TString code = request ? "" : ToString(head.Code); // TODO remove this allocation
        struct cno_header_t extra[] = {
            {CNO_BUFFER_STRING(":status"), {code.data(), code.size()}, 0},
            {CNO_BUFFER_STRING(":method"), {head.Method.data(), head.Method.size()}, 0},
            {CNO_BUFFER_STRING(":path"),   {head.PathWithQuery.data(), head.PathWithQuery.size()}, 0},
        };
        // Requests use extra[1..2] (:method, :path); responses use extra[0] (:status).
        if (cno_buffer_dyn_concat(&buf, {(char*)prefix, 8})
         || cno_hpack_encode(&hp, &buf, request ? extra + 1 : extra, request ? 2 : 1)
         || cno_hpack_encode(&hp, &buf, (cno_header_t*)head.data(), head.size()) MUN_RETHROW)
            return {};
        // N is only 16 bits wide; writing a longer head would silently store its length
        // modulo 2^16 and make the receiver split head and payload at the wrong offset.
        const size_t headSize = buf.size - 8;
        if (headSize > 0xFFFF) {
            mun_error(EMSGSIZE, "encoded head does not fit into a 16-bit length");
            return {};
        }
        buf.data[6] = (ui8)(headSize >> 8);
        buf.data[7] = (ui8)headSize;
        return TString(buf.data, buf.size); // TODO remove this allocation
    }

    // Reads the 48-bit big-endian message ID stored in bytes 0..5 of `buf`.
    // The caller must make sure `buf` holds at least 6 bytes.
    ui64 DecodeID(TStringBuf buf) {
        ui64 id = 0;
        for (size_t pos = 0; pos < 6; pos++) {
            id = (id << 8) | static_cast<ui64>(static_cast<ui8>(buf[pos]));
        }
        return id;
    }

    // Reads the 16-bit big-endian head length stored in bytes 6..7 of `buf`.
    // The caller must make sure `buf` holds at least 8 bytes.
    ui32 DecodeHeadLength(TStringBuf buf) {
        const ui32 high = static_cast<ui8>(buf[6]);
        const ui32 low = static_cast<ui8>(buf[7]);
        return (high << 8) | low;
    }

    // Cheap structural check: the 8-byte prefix is present and the declared head
    // length does not run past the end of the datagram. The head itself is not parsed.
    bool MayBeValid(TStringBuf buf) {
        if (buf.size() < 8) {
            return false;
        }
        const auto afterPrefix = buf.size() - 8;
        return DecodeHeadLength(buf) <= afterPrefix;
    }

    // Decodes the HPACK head of a datagram into an NSv::THead.
    //
    // `buf` must already have passed MayBeValid(). The decoded header entries are kept
    // in `Data` for the lifetime of the decoder and released in the destructor.
    class THeadDecoder : TNonCopyable {
    public:
        THeadDecoder(TStringBuf buf) {
            cno_hpack_t hpack;
            cno_hpack_init(&hpack, 0);
            Y_DEFER { cno_hpack_clear(&hpack); };
            // On a decoding error nothing is considered decoded, so Head stays default-constructed.
            if (cno_hpack_decode(&hpack, {buf.data() + 8, DecodeHeadLength(buf)}, Data, &Size) != 0) {
                Size = 0;
            }
            for (size_t i = 0; i < Size; i++) {
                TStringBuf name{Data[i].name.data, Data[i].name.size};
                TStringBuf value{Data[i].value.data, Data[i].value.size};
                // Pseudo-headers go into dedicated fields, everything else into the header list.
                if (name == ":status") {
                    Head.Code = FromStringWithDefault(value, Head.Code);
                    continue;
                }
                if (name == ":method") {
                    Head.Method = value;
                    continue;
                }
                if (name == ":path") {
                    Head.PathWithQuery = value;
                    continue;
                }
                Head.emplace(name, value);
            }
        }

        ~THeadDecoder() {
            for (auto entry = Data; entry != Data + Size; entry++) {
                cno_hpack_free_header(entry);
            }
        }

    public:
        NSv::THead Head;
        cno_header_t Data[CNO_MAX_HEADERS];
        // In: capacity of Data; out (after decoding): number of valid entries.
        size_t Size = CNO_MAX_HEADERS;
    };

    // CRTP base that moves datagrams between a socket and the derived class.
    //
    // Incoming datagrams are handed to `static_cast<T*>(this)->OnMessage(address, TString)`;
    // outgoing ones are queued with Send() and flushed by the Writer member.
    // The file object is held by reference and must outlive this object.
    template <typename T>
    class TUDPConnection : TNonCopyable {
    private:
        // Node of the intrusive singly-linked outgoing queue (also used as a receive buffer).
        // The payload lives in the flexible array right after the fixed fields.
        struct TMessage {
            TMessage* Next = nullptr;
            NSv::IP Address;
            // NOTE(review): 16 bits only, so a payload longer than 65535 bytes would be
            // truncated when assigned in Send() — confirm callers never pass that much.
            ui16 Size : 16;
            char Data[];

            // `new (payloadSize, 0) TMessage` allocates sizeof(TMessage) + payloadSize bytes;
            // the third argument is unused.
            static void* operator new(size_t size, size_t overhead, size_t) {
                return ::operator new(size + overhead);
            }

            static void operator delete(void* ptr, size_t ) {
                return ::operator delete(ptr);
            }
        };

        NSv::TFile& F_;
        // Signalled by Send() when the queue goes from empty to non-empty.
        cone::event More_;
        // Outgoing queue: Head_ is the next message to send; Tail_ is only meaningful
        // while Head_ is non-null.
        TMessage* Head_ = nullptr;
        TMessage* Tail_ = nullptr;

    public:
        TUDPConnection(NSv::TFile& f) noexcept
            : F_(f)
        {
        }

        // Frees whatever is still queued for sending.
        ~TUDPConnection() {
            while (TMessage* it = Head_) {
                Head_ = it->Next;
                delete it;
            }
        }

        // Copies `contents` into a new queue node addressed to `address` and appends it
        // to the outgoing queue. The writer is woken only if the queue was empty.
        void Send(NSv::IP address, TStringBuf contents) {
            THolder<TMessage> m(new (contents.size(), 0) TMessage);
            m->Address = std::move(address);
            m->Size = contents.size();
            memcpy(m->Data, contents.data(), contents.size());
            if (Head_ == nullptr) {
                Head_ = m.Get();
                More_.wake();
            } else {
                Tail_->Next = m.Get();
            }
            // Ownership passes to the queue; nodes are deleted by Writer or the destructor.
            Tail_ = m.Release();
        }

    private:
        // Number of datagrams transferred per system call: batched on Linux,
        // one at a time elsewhere (where a local mmsghdr stand-in is declared).
#ifdef _linux_
        static constexpr size_t TransferPerCall = 64;
#else
        static constexpr size_t TransferPerCall = 1;

        struct mmsghdr {
            struct msghdr msg_hdr;
            unsigned msg_len;
        };
#endif

        // Receive loop: keeps TransferPerCall preallocated 2048-byte buffers, receives
        // into them and passes each datagram (copied into a TString) to T::OnMessage.
        // Leaves the loop (returning false) on the first receive error.
        // NOTE(review): msg_flags is never inspected, so a datagram longer than 2048
        // bytes appears to be delivered truncated — confirm this is acceptable.
        cone::guard Reader = [this]() {
            THolder<TMessage> ptr[TransferPerCall];
            struct iovec iov[TransferPerCall];
            struct mmsghdr msg[TransferPerCall];
            for (size_t i = 0; i < TransferPerCall; i++) {
                ptr[i] = THolder<TMessage>(new (2048, 0) TMessage);
                iov[i] = {ptr[i]->Data, 2048};
                msg[i] = mmsghdr{
                    .msg_hdr = {
                        .msg_name = &ptr[i]->Address.Data,
                        .msg_namelen = sizeof(ptr[i]->Address.Data),
                        .msg_iov = &iov[i],
                        .msg_iovlen = 1,
                    },
                };
            }
            while (true) {
#ifdef _linux_
                int n = cold_recvmmsg(F_, msg, TransferPerCall, MSG_DONTWAIT, nullptr);
#else
                ssize_t m = cold_recvmsg(F_, &msg[0].msg_hdr, MSG_DONTWAIT);
                int n = m < 0 ? -1 : 1;
                msg[0].msg_len = m;
#endif
                if (n < 0 MUN_RETHROW_OS) {
                    // TODO ???
                    return false;
                }
                for (int i = 0; i < n; i++) {
                    static_cast<T*>(this)->OnMessage(ptr[i]->Address, TString(ptr[i]->Data, msg[i].msg_len));
                    // msg_namelen is an in/out field: restore the full capacity before reusing the slot.
                    msg[i].msg_hdr.msg_namelen = sizeof(ptr[i]->Address.Data);
                }
            }
        };

        // Send loop: while the queue is non-empty (or after More_ is signalled), sends up to
        // TransferPerCall queued messages per call and deletes the ones reported as sent.
        // Leaves the loop (returning false) on the first send error or when More_.wait() fails.
        cone::guard Writer = [this]() {
            while (Head_ || More_.wait()) {
                unsigned n = 0;
                struct iovec iov[TransferPerCall];
                struct mmsghdr msg[TransferPerCall];
                for (TMessage* it = Head_; it && n < TransferPerCall; it = it->Next, n++) {
                    iov[n] = {it->Data, it->Size};
                    msg[n] = mmsghdr{
                        .msg_hdr = {
                            .msg_name = (struct sockaddr*)&it->Address.Data,
                            .msg_namelen = sizeof(it->Address.Data),
                            .msg_iov = &iov[n],
                            .msg_iovlen = 1,
                        },
                        .msg_len = static_cast<unsigned>(it->Size),
                    };
                }
#ifdef _linux_
                int m = cold_sendmmsg(F_, msg, n, MSG_DONTWAIT);
#else
                int m = cold_sendmsg(F_, &msg[0].msg_hdr, MSG_DONTWAIT) < 0 ? -1 : 1;
#endif
                if (m < 0 MUN_RETHROW_OS) {
                    // TODO ???
                    return false;
                }
                // Drop the first m messages; anything left stays queued for the next iteration.
                for (TMessage* it = Head_; m--; it = Head_) {
                    Head_ = it->Next;
                    delete it;
                }
            }
            return false;
        };
    };

    // Client side of the UDP transport: owns the socket, sends encoded requests and
    // routes incoming datagrams to the stream that is waiting for the matching ID.
    class TClient : private NSv::TFile, public TUDPConnection<TClient> {
    public:
        // One request/response exchange. The request (head + payload) is accumulated in
        // `Data` and sent as a single datagram on Close(); afterwards `Data` is reused
        // to hold the response datagram.
        class TStream : public NSv::IStream, ::TNonCopyable {
        public:
            // `data` is the already encoded request head. If the request has no payload,
            // the stream is closed (i.e. the datagram is sent) right away.
            TStream(TClient& conn, ui64 id, NSv::IP addr, NSv::TLogFrame log, TString data, bool payload)
                : Conn(conn)
                , Addr(addr)
                , Log_(std::move(log))
                , Data(std::move(data))
                , ID(id)
            {
                if (!payload) {
                    NSv::IStream::Close();
                }
            }

            ~TStream() {
                // Still registered only if the request went out and no response arrived.
                if (Sent && !Recv) {
                    Conn.Waiting.erase(ID);
                }
            }

            // Called by TClient::OnMessage with a datagram that passed MayBeValid() and
            // carries this stream's ID. Decodes the head, stores the datagram, unregisters
            // the stream and wakes whoever waits for the response.
            void SetHead(TString buf) noexcept {
                Recv.ConstructInPlace(buf);
                Skip = 8 + DecodeHeadLength(buf);
                Data = std::move(buf);
                Conn.Waiting.erase(ID);
                Log_.Push<NSv::NEv::TRecvResponse>(Recv->Head.Code);
                Log_.Push<NSv::NEv::TRecvTail>();
                HaveResponse.wake();
            }

            NSv::IP Peer() const noexcept override {
                return Addr;
            }

            NSv::TLogFrame& Log() noexcept override {
                return Log_;
            }

            // Waits for the response if it has not arrived yet; nullptr if the wait fails.
            NSv::THead* Head() noexcept override {
                if (!Recv && !HaveResponse.wait() MUN_RETHROW) {
                    return nullptr;
                }
                return &Recv->Head;
            }

            // The tail is always empty; it is only available once the response is in.
            NSv::THeaderVector* Tail() noexcept override {
                return Recv ? &Tail_ : nullptr;
            }

            // Returns the not yet consumed part of the response payload, waiting for the
            // response first if necessary.
            TMaybe<TStringBuf> Read() noexcept override {
                if (!Recv && !HaveResponse.wait() MUN_RETHROW) {
                    return {};
                }
                return TStringBuf(Data).Tail(Skip);
            }

            bool Consume(size_t n) noexcept override {
                Skip += n;
                return true;
            }

            bool AtEnd() const noexcept override {
                return Recv && Skip == Data.size();
            }

            // The head is encoded at construction, so it can never be written again.
            bool WriteHead(NSv::THead&) noexcept override {
                return !mun_error(EPIPE, "already sent the request");
            }

            // Appends to the request payload; fails once the request has been sent.
            bool Write(TStringBuf part) noexcept override {
                if (Sent) {
                    return !mun_error(EPIPE, "payload already finished");
                }
                Data += part;
                return true;
            }

            // Registers the stream as waiting for a response and queues the datagram.
            bool Close(NSv::THeaderVector&) noexcept override {
                if (Sent) {
                    return !mun_error(EPIPE, "payload already finished");
                }
                Sent = true;
                Conn.Waiting[ID] = this;
                Conn.Send(Addr, Data);
                return true;
            }

        private:
            cone::event HaveResponse;
            TClient& Conn;
            NSv::IP Addr;
            NSv::TLogFrame Log_;
            TString Data;
            TMaybe<THeadDecoder> Recv;
            NSv::THeaderVector Tail_;
            ui64 ID;
            // Offset in Data of the first unread payload byte (valid once Recv is set).
            ui64 Skip = 0;
            bool Sent = false;
        };

    public:
        TClient(NSv::TFile f)
            : TFile(std::move(f))
            , TUDPConnection<TClient>(static_cast<NSv::TFile&>(*this))
        {
        }

        // Invoked for every received datagram. Malformed datagrams and datagrams whose
        // ID nobody is waiting for are dropped.
        void OnMessage(const NSv::IP&, TString buf) {
            if (!MayBeValid(buf)) {
                return;
            }
            auto ev = Waiting.find(DecodeID(buf));
            if (ev == Waiting.end()) {
                return;
            }
            ev->second->SetHead(std::move(buf));
        }

        // Creates a stream for a new request to `addr`. Returns nullptr if the head
        // could not be encoded.
        NSv::IStreamPtr Request(NSv::IP addr, NSv::TLogFrame log, const NSv::THead& head, bool payload) noexcept {
            // IDs are limited to 40 bits, which fits into the 6-byte ID field of the datagram.
            ui64 id = NextId++ & 0xFFFFFFFFFFull;
            TString data = Encode(id, head, true);
            return data ? std::make_shared<TStream>(*this, id, addr, std::move(log), std::move(data), payload) : nullptr;
        }

    private:
        // Streams that have sent a request and are waiting for the response, by ID.
        // The key must be as wide as the IDs themselves: with a 32-bit key, IDs above
        // 2^32 would be narrowed on insertion and then missed or confused on lookup/erase.
        THashMap<ui64, TStream*> Waiting;
        ui64 NextId = 0;
    };

    // IConnection facade that binds a fixed destination address to the per-thread
    // TClient shared by all connections of the same address family.
    struct TBoundClient : public NSv::IConnection {
    public:
        TBoundClient(const NSv::IP& addr) : Addr(addr) {}

        // There is no per-connection state to track, so all of these trivially succeed.
        bool IsIdle() const noexcept override {
            return true;
        }

        bool Wait() noexcept override {
            return true;
        }

        bool Shutdown() noexcept override {
            return true;
        }

        bool Migrate() noexcept override {
            return true;
        }

        // Sends the request through this thread's client, creating the client (and its
        // socket) on first use. Returns an empty pointer if the socket cannot be created.
        NSv::IStreamPtr Request(const NSv::THead& head, bool payload, NSv::TLogFrame log) noexcept override {
            if (auto existing = Client->Get()) {
                return existing->Request(Addr, std::move(log), head, payload);
            }
            const auto family = Addr.Data.Base.sa_family;
            NSv::TFile sock = socket(family, SOCK_DGRAM, 0);
            if (!sock MUN_RETHROW_OS) {
                return {};
            }
            auto&& created = Client->GetOrCreate(std::move(sock));
            return created.Request(Addr, std::move(log), head, payload);
        }

    private:
        NSv::IP Addr;
        // Keyed by address family; must be declared after Addr because its initializer reads it.
        std::shared_ptr<NSv::TThreadLocal<TClient>> Client = NSv::StaticData(Addr.Data.Base.sa_family, []{
            return NSv::TThreadLocal<TClient>{}; });
    };

    // Server side of the UDP transport: every structurally valid datagram is treated
    // as one complete request and processed by the handler in a separate task.
    class TServer : public TUDPConnection<TServer>, public NSv::IConnection {
    public:
        TServer(NSv::TFile& f, NSv::THandler h, NSv::TLogFrame log)
            : TUDPConnection<TServer>(f)
            , H_(std::move(h))
            , L_(std::move(log))
        {
            Y_ASSERT(H_);
        }

        // A server cannot originate requests.
        NSv::IStreamPtr Request(const NSv::THead&, bool, NSv::TLogFrame) noexcept override {
            mun_error(EINVAL, "this is a server");
            return nullptr;
        }

        // Idle when no request task is in flight.
        bool IsIdle() const noexcept override {
            return Tasks_.empty();
        }

        // Waits on an event nobody else can reach, so nothing in this file ever wakes it.
        bool Wait() noexcept override {
            cone::event neverSignalled;
            return neverSignalled.wait();
        }

        bool Shutdown() noexcept override {
            return !mun_error(ENOSYS, "not implemented");
        }

        bool Migrate() noexcept override {
            return !mun_error(ENOSYS, "not implemented");
        }

        // Invoked for every received datagram. Malformed datagrams are dropped; for the
        // rest a task is started that decodes the request, runs the handler and sends
        // the response (head + payload) back to `source` as a single datagram.
        void OnMessage(const NSv::IP& source, TString buf) {
            if (!MayBeValid(buf)) {
                return;
            }
            Tasks_.add([this, source, buf = std::move(buf)] {
                // Response datagram under construction: set by the head callback,
                // extended by the payload callback, sent by the close callback.
                TString reply;
                THeadDecoder request(buf);
                auto connLog = L_.Fork<NSv::NEv::TAcceptConnection>(source.FormatFull());
                auto streamLog = connLog.Fork<NSv::NEv::TStreamStart>(-1);
                streamLog.Push<NSv::NEv::TRecvRequest>(TString{request.Head.Method}, TString{request.Head.PathWithQuery});
                streamLog.Push<NSv::NEv::TRecvTail>();
                auto stream = NSv::ConstRequestStream(
                    std::move(request.Head),
                    TStringBuf(buf).Skip(8 + DecodeHeadLength(buf)),
                    [&](NSv::THead& responseHead) {
                        // The response carries the ID of the request it answers.
                        reply = Encode(DecodeID(buf), responseHead, false);
                        return true;
                    },
                    [&](TStringBuf chunk) {
                        reply += chunk;
                        return true;
                    },
                    [&](NSv::THeaderVector&) {
                        Send(source, reply);
                        return true;
                    },
                    std::move(streamLog));
                // A handler that reports success is flagged with EINVAL here; ECANCELED
                // and EREQDONE are the only error codes not treated as a failure below.
                if (H_(stream)) {
                    mun_error(EINVAL, "handler finished without closing the stream");
                }
                if (!EqualToOneOf(mun_errno, ECANCELED, EREQDONE) MUN_RETHROW) {
                    stream->Log().Push<NSv::NEv::THxHandlerError>(mun_errno, mun_last_error()->text);
                    return false;
                }
                return true;
            });
        }

    private:
        NSv::THandler H_;
        NSv::TLogFrame L_;
        // One task per request being handled.
        cone::mguard Tasks_;
    };
}

// Creates a UDP server that reads requests from `f` and passes them to `h`.
// `f` is held by reference and must outlive the returned connection.
THolder<NSv::IConnection> NSv::TestUDPServer(NSv::TFile& f, NSv::THandler h, NSv::TLogFrame log) {
    THolder<NSv::IConnection> server = MakeHolder<TServer>(f, std::move(h), std::move(log));
    return server;
}

// Creates a client connection whose requests are all sent to `addr` over UDP.
THolder<NSv::IConnection> NSv::TestUDPClient(const NSv::IP& addr) {
    THolder<NSv::IConnection> client = MakeHolder<TBoundClient>(addr);
    return client;
}
