#include "http2_frame.h"
#include "http2_settings.h"

#include <balancer/kernel/http2/server/utils/http2_log.h>

#include <util/string/builder.h>
#include <util/stream/format.h>
#include <util/stream/output.h>

namespace NSrvKernel::NHTTP2 {

    namespace {
        const ui32 RFC_GOAWAY_FRAME_SIZE_MIN = 8;
        const ui32 RFC_PING_FRAME_SIZE = 8;
        const ui32 RFC_PADDING_SIZE_MIN = 1;
        const ui32 RFC_RST_STREAM_FRAME_SIZE = 4;
        const ui32 RFC_WINDOW_UPDATE_FRAME_SIZE = 4;

        const std::initializer_list<EFrameType> PADDING_ALLOWED {
            EFrameType::DATA, EFrameType::HEADERS, EFrameType::PUSH_PROMISE
        };

        const std::initializer_list<EFrameType> PRIORITY_ALLOWED {
            EFrameType::HEADERS
        };

        const std::initializer_list<EFrameType> ACK_ALLOWED {
            EFrameType::SETTINGS, EFrameType::PING
        };

        const std::initializer_list<EFrameType> END_HEADERS_ALLOWED {
            EFrameType::HEADERS, EFrameType::PUSH_PROMISE, EFrameType::CONTINUATION
        };

        const std::initializer_list<EFrameType> END_STREAM_ALLOWED {
            EFrameType::DATA, EFrameType::HEADERS
        };

        [[nodiscard]]
        EFrameType ParseFrameType(ui8 rawFrameType) noexcept {
#define Y_HTTP2_CASE_FRAME_TYPE(frameType) case EFrameType::frameType: return EFrameType::frameType
            switch ((EFrameType)rawFrameType) {
            Y_HTTP2_CASE_FRAME_TYPE(DATA);
            Y_HTTP2_CASE_FRAME_TYPE(HEADERS);
            Y_HTTP2_CASE_FRAME_TYPE(PRIORITY);
            Y_HTTP2_CASE_FRAME_TYPE(RST_STREAM);
            Y_HTTP2_CASE_FRAME_TYPE(SETTINGS);
            Y_HTTP2_CASE_FRAME_TYPE(PUSH_PROMISE);
            Y_HTTP2_CASE_FRAME_TYPE(PING);
            Y_HTTP2_CASE_FRAME_TYPE(GOAWAY);
            Y_HTTP2_CASE_FRAME_TYPE(WINDOW_UPDATE);
            Y_HTTP2_CASE_FRAME_TYPE(CONTINUATION);
            Y_HTTP2_CASE_FRAME_TYPE(ALTSVC);
            default:
                return EFrameType::INVALID;
            }
        }
    }


    // TFrameHeading ===================================================================================================

    // Builds a heading for a locally generated frame of the given type. RawType mirrors the
    // enum value; no other field is set here.
    TFrameHeading::TFrameHeading(EFrameType type) noexcept
        : RawType(static_cast<ui32>(type))
        , Type(type)
    {}

    // Fluent setter. The id is stored as given; the reserved bit is masked off only in Write().
    TFrameHeading& TFrameHeading::SetStreamId(ui32 streamId) noexcept {
        this->StreamId = streamId;
        return *this;
    }

    // Frame type predicates. Each compares the parsed Type, so a heading whose raw type was not
    // recognized (Type == EFrameType::INVALID) answers false to every one of them.
    bool TFrameHeading::IsData() const noexcept {
        return Type == EFrameType::DATA;
    }

    bool TFrameHeading::IsHeaders() const noexcept {
        return Type == EFrameType::HEADERS;
    }

    bool TFrameHeading::IsPriority() const noexcept {
        return Type == EFrameType::PRIORITY;
    }

    bool TFrameHeading::IsRstStream() const noexcept {
        return Type == EFrameType::RST_STREAM;
    }

    bool TFrameHeading::IsSettings() const noexcept {
        return Type == EFrameType::SETTINGS;
    }

    bool TFrameHeading::IsPing() const noexcept {
        return Type == EFrameType::PING;
    }

    bool TFrameHeading::IsGoAway() const noexcept {
        return Type == EFrameType::GOAWAY;
    }

    bool TFrameHeading::IsWindowUpdate() const noexcept {
        return Type == EFrameType::WINDOW_UPDATE;
    }

    bool TFrameHeading::IsContinuation() const noexcept {
        return Type == EFrameType::CONTINUATION;
    }

    // Factories for locally generated frame headings. Stream-level frames take the target stream
    // id; SETTINGS, PING and GOAWAY headings do not set a stream id at all.
    TFrameHeading TFrameHeading::NewData(ui32 streamId) noexcept {
        TFrameHeading heading(EFrameType::DATA);
        heading.SetStreamId(streamId);
        return heading;
    }

    TFrameHeading TFrameHeading::NewDataEndStream(ui32 streamId) noexcept {
        TFrameHeading heading = NewData(streamId);
        heading.SetFlagEndStream(true);
        return heading;
    }

    TFrameHeading TFrameHeading::NewHeaders(ui32 streamId) noexcept {
        TFrameHeading heading(EFrameType::HEADERS);
        heading.SetStreamId(streamId);
        return heading;
    }

    TFrameHeading TFrameHeading::NewRstStream(ui32 streamId) noexcept {
        TFrameHeading heading(EFrameType::RST_STREAM);
        heading.SetStreamId(streamId);
        return heading;
    }

    TFrameHeading TFrameHeading::NewSettings() noexcept {
        TFrameHeading heading(EFrameType::SETTINGS);
        return heading;
    }

    TFrameHeading TFrameHeading::NewSettingsAck() noexcept {
        TFrameHeading heading = NewSettings();
        heading.SetFlagAck(true);
        return heading;
    }

    TFrameHeading TFrameHeading::NewPing() noexcept {
        TFrameHeading heading(EFrameType::PING);
        return heading;
    }

    TFrameHeading TFrameHeading::NewPingAck() noexcept {
        TFrameHeading heading = NewPing();
        heading.SetFlagAck(true);
        return heading;
    }

    TFrameHeading TFrameHeading::NewGoAway() noexcept {
        TFrameHeading heading(EFrameType::GOAWAY);
        return heading;
    }

    // A connection-level window update is simply one addressed to stream 0.
    TFrameHeading TFrameHeading::NewConnectionWindowUpdate() noexcept {
        return NewStreamWindowUpdate(0);
    }

    TFrameHeading TFrameHeading::NewStreamWindowUpdate(ui32 streamId) noexcept {
        TFrameHeading heading(EFrameType::WINDOW_UPDATE);
        heading.SetStreamId(streamId);
        return heading;
    }

    TFrameHeading TFrameHeading::NewContinuation(ui32 streamId) noexcept {
        TFrameHeading heading(EFrameType::CONTINUATION);
        heading.SetStreamId(streamId);
        return heading;
    }

    // Flag accessors. Getters report false when the flag is not defined for this frame type;
    // setters Y_VERIFY that it is (both delegate to DoGetFlag/DoSetFlag).
    bool TFrameHeading::HasFlagPadded() const noexcept {
        const bool padded = DoGetFlag(PADDING_ALLOWED, FLAG_PADDED);
        return padded;
    }

    TFrameHeading& TFrameHeading::SetFlagPadded(bool padded) noexcept {
        TFrameHeading& self = DoSetFlag(PADDING_ALLOWED, FLAG_PADDED, padded);
        return self;
    }

    bool TFrameHeading::HasFlagPriority() const noexcept {
        const bool priority = DoGetFlag(PRIORITY_ALLOWED, FLAG_PRIORITY);
        return priority;
    }

    TFrameHeading& TFrameHeading::SetFlagPriority(bool priority) noexcept {
        TFrameHeading& self = DoSetFlag(PRIORITY_ALLOWED, FLAG_PRIORITY, priority);
        return self;
    }

    bool TFrameHeading::HasFlagAck() const noexcept {
        const bool ack = DoGetFlag(ACK_ALLOWED, FLAG_ACK);
        return ack;
    }

    TFrameHeading& TFrameHeading::SetFlagAck(bool ack) noexcept {
        TFrameHeading& self = DoSetFlag(ACK_ALLOWED, FLAG_ACK, ack);
        return self;
    }

    bool TFrameHeading::HasFlagEndHeaders() const noexcept {
        const bool endHeaders = DoGetFlag(END_HEADERS_ALLOWED, FLAG_END_HEADERS);
        return endHeaders;
    }

    TFrameHeading& TFrameHeading::SetFlagEndHeaders(bool endHeaders) noexcept {
        TFrameHeading& self = DoSetFlag(END_HEADERS_ALLOWED, FLAG_END_HEADERS, endHeaders);
        return self;
    }

    bool TFrameHeading::HasFlagEndStream() const noexcept {
        const bool endStream = DoGetFlag(END_STREAM_ALLOWED, FLAG_END_STREAM);
        return endStream;
    }

    TFrameHeading& TFrameHeading::SetFlagEndStream(bool endStream) noexcept {
        TFrameHeading& self = DoSetFlag(END_STREAM_ALLOWED, FLAG_END_STREAM, endStream);
        return self;
    }

    // Reads a flag bit, but only for frame types listed in `allowed`. For any other type the bit
    // is reported as unset, whatever its value in Flags.
    bool TFrameHeading::DoGetFlag(std::initializer_list<EFrameType> allowed, ui8 flag) const noexcept {
        if (!IsIn(allowed, Type)) {
            return false;
        }
        return (Flags & flag) != 0;
    }

    // Sets or clears a flag bit. Touching a flag that is not defined for this frame type
    // (Type not in `allowed`) is a programming error and fails the Y_VERIFY.
    TFrameHeading& TFrameHeading::DoSetFlag(std::initializer_list<EFrameType> allowed, ui8 flag, bool value) noexcept {
        Y_VERIFY(IsIn(allowed, Type));

        if (!value) {
            Flags &= (ui8)~flag;
            return *this;
        }

        Flags |= flag;
        return *this;
    }

    // Serializes the 9-octet frame heading: 24-bit length, type octet, flags octet, then the
    // stream id masked with RFC_STREAM_ID_MASK. A heading with Type == INVALID fails the Y_VERIFY.
    TChunkPtr TFrameHeading::Write() const noexcept {
        Y_VERIFY(EFrameType::INVALID != Type);

        TStackRegion<9, TOutputRegion> buffer;
        auto& out = buffer.GetRegion();
        out.WriteUIntUnsafe<3>(Length);
        out.WriteUIntUnsafe<1>(static_cast<ui8>(Type));
        out.WriteUIntUnsafe<1>(Flags);
        out.WriteUIntUnsafe<4>(StreamId & RFC_STREAM_ID_MASK);
        return buffer.CopyToChunk();
    }

    // Decodes a frame heading from the first RFC_FRAME_HEADING_SIZE octets of `data`. The raw
    // type octet is kept in RawType; unrecognized types yield Type == EFrameType::INVALID.
    TFrameHeading TFrameHeading::Parse(TStringBuf data) noexcept {
        // Handing over an unfinished or truncated heading is a caller bug.
        Y_VERIFY(data.size() >= RFC_FRAME_HEADING_SIZE);

        TInputRegion in{data};
        TFrameHeading heading;
        heading.Length = in.ReadUIntUnsafe<3, ui32>();
        heading.RawType = in.ReadUIntUnsafe<1, ui8>();
        heading.Flags = in.ReadUIntUnsafe<1, ui8>();
        heading.StreamId = in.ReadUIntUnsafe<4, ui32>() & RFC_STREAM_ID_MASK;
        heading.Type = ParseFrameType(heading.RawType);
        return heading;
    }

    // Writes a human-readable dump of the heading to `out`; Flags is formatted via Bin().
    // NOTE(review): the Y_HTTP2_PRINT_OBJ arguments are kept verbatim in case the macro
    // stringifies them as field labels.
    void TFrameHeading::PrintTo(IOutputStream& out) const {
        Y_HTTP2_PRINT_OBJ(out, Type, RawType, Bin(Flags), StreamId, Length);
    }


    // DATA ============================================================================================================

    // Prints the payload size (not the payload bytes) and the end-of-stream marker.
    // NOTE(review): arguments kept verbatim in case Y_HTTP2_PRINT_OBJ stringifies them.
    void TData::PrintTo(IOutputStream& out) const {
        Y_HTTP2_PRINT_OBJ(out, Data.size(), EndOfStream);
    }

    // Validates an incoming DATA frame using only its heading.
    TError ValidateDataHeading(const TFrameHeading& heading, const TSettings& serverSettings) noexcept {
        // The frame must fit into the server's MaxFrameSize setting.
        Y_REQUIRE(serverSettings.MaxFrameSize >= heading.Length,
            TConnectionError(EErrorCode::FRAME_SIZE_ERROR, TUnspecifiedReason()));

        // DATA always belongs to a stream; stream 0 is the connection itself.
        Y_REQUIRE(heading.StreamId != 0,
            TConnectionError(EErrorCode::PROTOCOL_ERROR, EConnProtocolError::InvalidFrame));

        // A padded frame must at least carry its Pad Length octet.
        Y_REQUIRE(!heading.HasFlagPadded() || heading.Length >= RFC_PADDING_SIZE_MIN,
            TConnectionError(EErrorCode::FRAME_SIZE_ERROR, TUnspecifiedReason()));
        return {};
    }

    // Removes the Pad Length octet and the trailing padding from a PADDED frame in place,
    // clearing the PADDED flag and shrinking Heading.Length to match. Frames without the
    // PADDED flag are left untouched.
    //
    // Errors:
    //  - FRAME_SIZE_ERROR when the payload is too short to even hold the Pad Length octet;
    //  - PROTOCOL_ERROR when the declared padding is as long as the payload or longer
    //    (RFC 7540 §6.1 for DATA, §6.2 for HEADERS).
    TError StripPadding(TFrame& frame) noexcept {
        if (!frame.Heading.HasFlagPadded()) {
            return {};
        }

        Y_REQUIRE(frame.Payload->Length() >= RFC_PADDING_SIZE_MIN,
            TConnectionError(EErrorCode::FRAME_SIZE_ERROR, TUnspecifiedReason()));

        TInputRegion input(*frame.Payload);
        const auto paddingLen = input.ReadUIntUnsafe<1, ui8>();

        // The payload must hold the Pad Length octet itself plus paddingLen padding octets.
        // The RFC mandates PROTOCOL_ERROR (not FRAME_SIZE_ERROR) for padding that is too long.
        const bool paddingLenValid = paddingLen + 1 <= frame.Payload->Length();
        Y_REQUIRE(paddingLenValid,
            TConnectionError(EErrorCode::PROTOCOL_ERROR, EConnProtocolError::InvalidFrame));

        frame.Payload->Chop(paddingLen);
        frame.Payload->Skip(1);
        frame.Heading.SetFlagPadded(false);
        frame.Heading.Length -= paddingLen + 1;
        return {};
    }

    // HEADERS =========================================================================================================

    // Validates an incoming HEADERS frame using only its heading.
    TError ValidateHeadersHeading(const TFrameHeading& heading, const TSettings& serverSettings) noexcept {
        // Client-initiated stream ids are odd and server-initiated ones are even, so an even id
        // (including 0) cannot carry client HEADERS.
        const bool clientInitiated = (heading.StreamId & 1) != 0;
        Y_REQUIRE(clientInitiated,
            TConnectionError(EErrorCode::PROTOCOL_ERROR, EConnProtocolError::InvalidFrame));

        Y_REQUIRE(heading.Length <= serverSettings.MaxFrameSize,
            TConnectionError(EErrorCode::FRAME_SIZE_ERROR, TUnspecifiedReason()));

        // The optional Pad Length octet and priority block must fit into the declared length.
        ui32 minLength = heading.HasFlagPadded() ? RFC_PADDING_SIZE_MIN : 0;
        if (heading.HasFlagPriority()) {
            minLength += RFC_PRIORITY_SIZE;
        }
        Y_REQUIRE(heading.Length >= minLength,
            TConnectionError(EErrorCode::FRAME_SIZE_ERROR, TUnspecifiedReason()));
        return {};
    }


    // PRIORITY ========================================================================================================

    // Prints the priority fields. RawWeight is widened to ui32, presumably so that a ui8
    // prints as a number rather than as a character.
    // NOTE(review): arguments kept verbatim in case Y_HTTP2_PRINT_OBJ stringifies them.
    void TPriority::PrintTo(IOutputStream& out) const {
        Y_HTTP2_PRINT_OBJ(out, Exclusive, StreamDependency, (ui32)RawWeight, Origin);
    }

    // Decodes a 5-octet priority block: a 32-bit field whose bit outside RFC_STREAM_ID_MASK is
    // the exclusive flag and whose remaining bits are the stream dependency, then a weight octet.
    // Reads are unchecked (ReadUIntUnsafe), so callers presumably guarantee RFC_PRIORITY_SIZE octets.
    // Origin is set to PriorityFrame; TPriority::Strip overrides it for HEADERS frames.
    TPriority TPriority::Parse(TStringBuf data) noexcept {
        TInputRegion in{data};
        const auto dependencyField = in.ReadUIntUnsafe<4, ui32>();
        const auto weight = in.ReadUIntUnsafe<1, ui8>();

        TPriority result;
        result.Exclusive = dependencyField & (~RFC_STREAM_ID_MASK);
        result.StreamDependency = dependencyField & RFC_STREAM_ID_MASK;
        result.RawWeight = weight;
        result.Origin = EOrigin::PriorityFrame;
        return result;
    }

    // Strips padding and then, if the PRIORITY flag is set, the leading priority block from a
    // frame in place, updating Heading.Length and clearing the flags. Returns the extracted
    // priority (with HeadersFrame origin), Nothing() if the frame carries none, or an error.
    TErrorOr<TMaybe<TPriority>> TPriority::Strip(TFrame& frame) noexcept {
        Y_PROPAGATE_ERROR(StripPadding(frame));

        if (!frame.Heading.HasFlagPriority()) {
            return Nothing();
        }

        Y_REQUIRE(frame.Payload->Length() >= RFC_PRIORITY_SIZE,
            TConnectionError(EErrorCode::FRAME_SIZE_ERROR, TUnspecifiedReason()));

        TPriority priority = TPriority::Parse(frame.Payload->AsStringBuf());
        priority.Origin = EOrigin::HeadersFrame;

        frame.Payload->Skip(RFC_PRIORITY_SIZE);
        frame.Heading.Length -= RFC_PRIORITY_SIZE;
        frame.Heading.SetFlagPriority(false);
        return TMaybe<TPriority>(priority);
    }

    // Validates an incoming PRIORITY frame heading: the payload must be exactly one priority block.
    // NOTE(review): RFC 7540 §6.3 classifies a wrong-length PRIORITY frame as a *stream* error of
    // type FRAME_SIZE_ERROR; here it is reported as a connection error.
    TError TPriority::ValidateHeading(const TFrameHeading& heading) noexcept {
        Y_REQUIRE(RFC_PRIORITY_SIZE == heading.Length,
            TConnectionError(EErrorCode::FRAME_SIZE_ERROR, TUnspecifiedReason()));

        // RFC 7540 requires PRIORITY frames on stream 0 to be rejected, but Firefox occasionally
        // sends them, so we deliberately accept (and later ignore) them instead. See BALANCER-1278.
        // The check we intentionally skip would be:
        // Y_REQUIRE(heading.StreamId,
        //     TConnectionError(EErrorCode::PROTOCOL_ERROR, EConnProtocolError::InvalidFrame));
        return {};
    }


    // RST_STREAM ======================================================================================================

    // Serializes the 4-octet RST_STREAM payload, which is just the error code.
    TChunkPtr WriteRstStream(EErrorCode errorCode) noexcept {
        TStackRegion<4, TOutputRegion> payload;
        auto& out = payload.GetRegion();
        out.WriteUIntUnsafe<4>(static_cast<ui32>(errorCode));
        return payload.CopyToChunk();
    }

    // Decodes the RST_STREAM error code. How values unknown to EErrorCode are mapped is up to
    // GetErrorCode (not visible here).
    EErrorCode ParseRstStream(TStringBuf data) noexcept {
        TInputRegion in{data};
        const auto rawCode = in.ReadUInt<4, ui32>();
        return GetErrorCode<EErrorCode>(rawCode);
    }

    // Validates an incoming RST_STREAM frame using only its heading.
    TError ValidateRstStreamHeading(const TFrameHeading& heading) noexcept {
        // The payload is exactly one 32-bit error code.
        Y_REQUIRE(RFC_RST_STREAM_FRAME_SIZE == heading.Length,
            TConnectionError(EErrorCode::FRAME_SIZE_ERROR, TUnspecifiedReason()));

        // RST_STREAM targets a specific stream, so stream 0 is not allowed.
        Y_REQUIRE(heading.StreamId != 0,
            TConnectionError(EErrorCode::PROTOCOL_ERROR, EConnProtocolError::InvalidFrame));
        return {};
    }


    // PING ============================================================================================================

    // Creates a ping stamped with its send time. Recv starts equal to Send, so GetRTT() reports
    // zero until Recv is updated.
    TPing::TPing(EPingType pingType, TInstant send) noexcept
        : Send(send)
        // Initialize from the parameter rather than the Send member, so correctness does not
        // depend on the order in which the members are declared in the header.
        , Recv(send)
        , PingType(pingType)
    {}

    // Round-trip time of this ping, clamped to [0, maxRTT]. Both timestamps are reduced to whole
    // milliseconds before subtracting, and a Recv earlier than Send yields zero instead of
    // wrapping around.
    TDuration TPing::GetRTT(TDuration maxRTT) const noexcept {
        const ui64 sendMs = Send.MilliSeconds();
        const ui64 recvMs = Recv.MilliSeconds();
        const ui64 capMs = maxRTT.MilliSeconds();

        const ui64 elapsedMs = recvMs > sendMs ? recvMs - sendMs : 0;
        return TDuration::MilliSeconds(elapsedMs < capMs ? elapsedMs : capMs);
    }

    // Prints the ping type, send/receive times in milliseconds and the RTT computed by GetRTT()
    // with its default cap (declared in the header).
    // NOTE(review): arguments kept verbatim in case Y_HTTP2_PRINT_OBJ stringifies them.
    void TPing::PrintTo(IOutputStream& out) const {
        Y_HTTP2_PRINT_OBJ(out, PingType, Send.MilliSeconds(),
                          Recv.MilliSeconds(), GetRTT().MilliSeconds());
    }

    // Serializes the 8-octet PING payload: one octet of ping type followed by the send time in
    // milliseconds squeezed into the remaining 7 octets.
    TChunkPtr TPing::Write() const noexcept {
        TStackRegion<8, TOutputRegion> payload;
        auto& out = payload.GetRegion();
        out.WriteUIntUnsafe<1, ui8>(static_cast<ui8>(PingType));
        // WriteUIntUnsafe<7, ui64> keeps only the last 7 octets of the big-endian representation,
        // i.e. it drops the most significant octet. Millisecond timestamps will fit into 6 octets
        // for a long time yet, so nothing is lost.
        out.WriteUIntUnsafe<7, ui64>(Send.MilliSeconds());
        return payload.CopyToChunk();
    }

    // Decodes a PING payload produced by TPing::Write and stamps Recv with the current time.
    TPing TPing::Parse(TStringBuf data) noexcept {
        TInputRegion in{data};
        const auto rawType = in.ReadUIntUnsafe<1, ui8>();
        // The 7-octet send time comfortably holds millisecond timestamps.
        const auto sendMs = in.ReadUInt<7, ui64>();

        TPing result;
        result.PingType = static_cast<EPingType>(rawType);
        result.Send = TInstant::MilliSeconds(sendMs);
        result.Recv = TInstant::Now();
        return result;
    }

    // Validates an incoming PING frame using only its heading.
    TError TPing::Validate(const TFrameHeading& heading) noexcept {
        // PING always carries exactly 8 octets of opaque data.
        Y_REQUIRE(RFC_PING_FRAME_SIZE == heading.Length,
            TConnectionError(EErrorCode::FRAME_SIZE_ERROR, TUnspecifiedReason()));

        // PING is connection-level and must arrive on stream 0.
        Y_REQUIRE(heading.StreamId == 0,
            TConnectionError(EErrorCode::PROTOCOL_ERROR, EConnProtocolError::InvalidFrame));
        return {};
    }


    // GOAWAY ==========================================================================================================

    // Re-targets an existing GOAWAY at a different last stream id, keeping its error code and
    // debug data. `other` is taken by value, so its DebugData is moved out of that copy.
    TGoAway::TGoAway(ui32 lastStreamId, TGoAway other) noexcept
        : DebugData(std::move(other.DebugData))
        , LastStreamId(lastStreamId)
        , ErrorCode(other.ErrorCode)
    {}


    // Builds a GOAWAY from its parts. debugData is copied via NewChunkForceCopy, presumably so
    // that the caller's buffer need not outlive this object.
    TGoAway::TGoAway(ui32 lastStreamId, EErrorCode errorCode, TStringBuf debugData) noexcept
        : DebugData(NewChunkForceCopy(debugData))
        , LastStreamId(lastStreamId)
        , ErrorCode(errorCode)
    {}

    // Prints the last stream id, error code and the debug data C-escaped via EscC.
    // NOTE(review): arguments kept verbatim in case Y_HTTP2_PRINT_OBJ stringifies them.
    void TGoAway::PrintTo(IOutputStream& out) const {
        Y_HTTP2_PRINT_OBJ(out, LastStreamId, ErrorCode, EscC(DebugData));
    }

    // Serializes the fixed 8-octet part of the GOAWAY payload: Last-Stream-ID followed by the
    // error code. RFC 7540 §6.8 requires the reserved high bit of Last-Stream-ID to be unset
    // when sending, so it is masked off here (matching TFrameHeading::Write and TGoAway::Parse).
    // NOTE(review): DebugData is not written here; presumably the caller appends it — confirm.
    TChunkPtr TGoAway::Write() const noexcept {
        TStackRegion<8, TOutputRegion> frameOut;
        frameOut.GetRegion().WriteUIntUnsafe<4>(static_cast<ui32>(LastStreamId & RFC_STREAM_ID_MASK));
        frameOut.GetRegion().WriteUIntUnsafe<4>((ui32) ErrorCode);
        return frameOut.CopyToChunk();
    }

    // Decodes a GOAWAY payload: Last-Stream-ID (reserved bit masked off), error code, and all
    // remaining octets as debug data. Input shorter than the fixed part is a caller bug.
    TGoAway TGoAway::Parse(TStringBuf data) noexcept {
        Y_VERIFY(data.size() >= RFC_GOAWAY_FRAME_SIZE_MIN);

        TInputRegion in{data};
        const auto lastStreamId = in.ReadUInt<4, ui32>() & RFC_STREAM_ID_MASK;
        const auto rawErrorCode = in.ReadUInt<4, ui32>();

        TGoAway result;
        result.LastStreamId = lastStreamId;
        result.ErrorCode = GetErrorCode<EErrorCode>(rawErrorCode);
        // Whatever follows the fixed 8 octets is opaque debug data.
        result.DebugData = TChunkList(NewChunk(
            in.SizeAvailable(),
            NewChunkData(in.AsStringBuf())
        ));
        return result;
    }

    // Validates an incoming GOAWAY frame using only its heading.
    TError TGoAway::Validate(const TFrameHeading& heading, const TSettings& serverSettings) noexcept {
        // Bounded above by the server's MaxFrameSize setting...
        Y_REQUIRE(serverSettings.MaxFrameSize >= heading.Length,
            TConnectionError(EErrorCode::FRAME_SIZE_ERROR, TUnspecifiedReason()));

        // ...and below by the fixed Last-Stream-ID + Error Code part.
        Y_REQUIRE(RFC_GOAWAY_FRAME_SIZE_MIN <= heading.Length,
            TConnectionError(EErrorCode::FRAME_SIZE_ERROR, TUnspecifiedReason()));

        // GOAWAY applies to the whole connection, so it must arrive on stream 0.
        Y_REQUIRE(heading.StreamId == 0,
            TConnectionError(EErrorCode::PROTOCOL_ERROR, EConnProtocolError::InvalidFrame));
        return {};
    }


    // WINDOW_UPDATE ===================================================================================================

    // Serializes the 4-octet WINDOW_UPDATE payload. A zero increment or one above
    // RFC_WINDOW_SIZE_MAX is a caller bug and fails the Y_VERIFY.
    TChunkPtr WriteWindowUpdate(ui32 windowSizeIncrement) noexcept {
        Y_VERIFY(windowSizeIncrement > 0 && windowSizeIncrement <= RFC_WINDOW_SIZE_MAX);

        TStackRegion<4, TOutputRegion> payload;
        auto& out = payload.GetRegion();
        out.WriteUIntUnsafe<4>(windowSizeIncrement);
        return payload.CopyToChunk();
    }

    // Decodes the window size increment, keeping only the bits in RFC_WINDOW_SIZE_MASK (RFC 7540
    // §6.9 defines the top bit as reserved). A zero increment is not rejected here.
    ui32 ParseWindowUpdate(TStringBuf data) noexcept {
        TInputRegion in{data};
        const auto rawIncrement = in.ReadUInt<4, ui32>();
        return rawIncrement & RFC_WINDOW_SIZE_MASK;
    }

    // Validates an incoming WINDOW_UPDATE heading: the payload is exactly one 32-bit increment.
    // Both stream 0 (connection window) and non-zero streams are accepted.
    TError ValidateWindowUpdateHeading(const TFrameHeading& heading) noexcept {
        Y_REQUIRE(RFC_WINDOW_UPDATE_FRAME_SIZE == heading.Length,
            TConnectionError(EErrorCode::FRAME_SIZE_ERROR, TUnspecifiedReason()));
        return {};
    }


    // CONTINUATION ====================================================================================================

    // Validates an incoming CONTINUATION frame using only its heading. Whether it actually
    // continues the header block of the expected stream is left to the caller.
    TError ValidateContinuationHeading(const TFrameHeading& heading, const TSettings& serverSettings) noexcept {
        Y_REQUIRE(heading.Length <= serverSettings.MaxFrameSize,
            TConnectionError(EErrorCode::FRAME_SIZE_ERROR, TUnspecifiedReason()));

        // RFC 7540 §6.10: CONTINUATION frames MUST be associated with a stream; stream id 0 is a
        // connection error of type PROTOCOL_ERROR (same check as for DATA and RST_STREAM).
        Y_REQUIRE(heading.StreamId,
            TConnectionError(EErrorCode::PROTOCOL_ERROR, EConnProtocolError::InvalidFrame));
        return {};
    }
}

// Printing glue for every type with a PrintTo() in this file.
// NOTE(review): presumably generates the stream output operator from PrintTo(); see the macro
// definition in http2_log.h to confirm.
Y_HTTP2_GEN_PRINT(TFrameHeading);
Y_HTTP2_GEN_PRINT(TData);
Y_HTTP2_GEN_PRINT(TPriority);
Y_HTTP2_GEN_PRINT(TPing);
Y_HTTP2_GEN_PRINT(TGoAway);
