package message

import (
	"testing"

	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// badMessage is a protocol.Message stub whose Marshal and Unmarshal always
// fail with protocol.ErrInvalidHeader. TestForward wraps it in a Forward to
// check that the inner message's marshal error is reported to the caller.
type badMessage struct{}

// Compile-time check that *badMessage satisfies protocol.Message.
var _ protocol.Message = (*badMessage)(nil)

// OpCode always reports the zero opcode.
func (m *badMessage) OpCode() protocol.OpCode {
	return 0
}

// Marshal never produces bytes; it always fails with protocol.ErrInvalidHeader.
func (m *badMessage) Marshal(_ protocol.Version) ([]byte, error) {
	return nil, protocol.ErrInvalidHeader
}

// Unmarshal ignores its input and always fails with protocol.ErrInvalidHeader.
func (m *badMessage) Unmarshal(_ protocol.Version, _ []byte) error {
	return protocol.ErrInvalidHeader
}

// String returns a fixed placeholder description.
func (m *badMessage) String() string {
	return "<unknown: 0>"
}

// TestForward exercises the Forward message: construction through NewForward,
// its accessors and String form, a Marshal/Unmarshal round trip, and the
// errors reported for missing addresses, missing or failing inner messages,
// unsupported versions, short buffers, and truncated payloads.
func TestForward(t *testing.T) {
	src := NewForwardedAddr("network", "address")
	msg, merr := NewAuthHost(protocol.FirstAckID.Next(), protocol.Current, "hostname", OpaqueBytes("token"))
	require.NoError(t, merr)

	t.Run("should report values correctly", func(t *testing.T) {
		fwd, err := NewForward(true, src, msg)
		require.NoError(t, err)
		assert.Equal(t, protocol.Forward, fwd.OpCode())
		assert.True(t, fwd.IsInit())
		assert.Equal(t, src, fwd.Source())
		assert.Equal(t, msg, fwd.Message())
		assert.Equal(t, "<forward> source=address, msg=(<auth_host> reqID=3, hostname=hostname, version=2, ####)", fwd.String())
	})

	t.Run("should marshal correctly", func(t *testing.T) {
		fwd, err := NewForward(true, src, msg)
		require.NoError(t, err)
		bytes, err := fwd.Marshal(protocol.Current)
		require.NoError(t, err)
		out, err := Unmarshal(bytes)
		require.NoError(t, err)
		assert.Equal(t, fwd, out)
	})

	t.Run("should report illegal addresses", func(t *testing.T) {
		_, err := NewForward(false, nil, msg)
		assert.Equal(t, protocol.ErrInvalidAddress, err)
	})

	t.Run("should report illegal messages", func(t *testing.T) {
		_, err := NewForward(false, src, nil)
		assert.Equal(t, protocol.ErrInvalidHeader, err)
	})

	t.Run("should report version errors", func(t *testing.T) {
		_, err := createBlank(protocol.Forward).Marshal(protocol.Unknown)
		assert.Equal(t, protocol.ErrInvalidVersion, err)
		assert.Equal(t, protocol.ErrInvalidVersion, createBlank(protocol.Forward).Unmarshal(protocol.Unknown, []byte{}))
	})

	t.Run("should report inner message errors", func(t *testing.T) {
		// An inner message that fails to marshal must fail the Forward's
		// Marshal with the same error and no output. The local is named fwd
		// so it does not shadow the outer auth-host msg.
		fwd, err := NewForward(false, src, &badMessage{})
		require.NoError(t, err)
		bytes, err := fwd.Marshal(protocol.Current)
		assert.Empty(t, bytes)
		assert.Equal(t, protocol.ErrInvalidHeader, err)

		// An inner message that marshals but cannot be decoded again must
		// surface its own unmarshal error when the Forward is unmarshaled.
		bind, err := NewBind(protocol.FirstAckID, OpaqueBytes("name"), &badCreds{})
		require.NoError(t, err)
		fwd, err = NewForward(false, src, bind)
		require.NoError(t, err)
		bytes, err = fwd.Marshal(protocol.Current)
		require.NoError(t, err)
		assert.Equal(t, stream.ErrInvalidCredentialFormat, createBlank(protocol.Forward).Unmarshal(protocol.Current, bytes))
	})

	t.Run("should report length errors", func(t *testing.T) {
		assert.Equal(t, protocol.ErrInvalidLength(forwardNetworkOffset, 0),
			createBlank(protocol.Forward).Unmarshal(protocol.Current, []byte{}))
	})

	t.Run("should handle EOF gracefully", func(t *testing.T) {
		// Build the payload field by field, checking the decoder's error
		// after each partial step.
		broken := make([]byte, forwardNetworkOffset)
		injectHeader(broken, protocol.One, protocol.Bind)
		broken = append(broken, "network"...) // network field with no terminator
		assert.Equal(t, protocol.ErrEOF, createBlank(protocol.Forward).Unmarshal(protocol.Current, broken))
		broken = append(broken, "\000addr"...) // terminate network, append unterminated address
		assert.Equal(t, protocol.ErrEOF, createBlank(protocol.Forward).Unmarshal(protocol.Current, broken))
		broken = append(broken, "\000"...) // terminate address; inner message bytes are absent
		assert.Equal(t, protocol.ErrInvalidHeader, createBlank(protocol.Forward).Unmarshal(protocol.Current, broken))
	})
}
