package message

import (
	"testing"

	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestAccepted exercises the accepted message end to end. It covers:
//   - the constructor, the accessors and the String form;
//   - a Marshal/Unmarshal round trip;
//   - the errors reported for an invalid suggestion ID, version, address or
//     hostname, and for empty or truncated payloads.
func TestAccepted(t *testing.T) {
	suggestion := protocol.FirstSuggestion("source")
	hostname := "wss://localhost:8080"
	address, setupErr := stream.NewAddress(stream.Namespace("test"), stream.Version(1), map[string]string{})
	require.NoError(t, setupErr)

	t.Run("should report values correctly", func(t *testing.T) {
		msg, err := NewAccepted(address, suggestion, hostname)
		require.NoError(t, err)
		assert.Equal(t, protocol.Accepted, msg.OpCode())
		assert.Equal(t, suggestion, msg.ID())
		assert.Equal(t, address, msg.Address())
		assert.Equal(t, hostname, msg.Hostname())
		assert.Equal(t, "<accepted> suggestion=source#0, addr=test@1, hostname=wss://localhost:8080", msg.String())
	})

	t.Run("should marshal correctly", func(t *testing.T) {
		msg, err := NewAccepted(address, suggestion, hostname)
		require.NoError(t, err)
		encoded, err := msg.Marshal(protocol.Current)
		require.NoError(t, err)
		decoded, err := Unmarshal(encoded)
		require.NoError(t, err)
		// A round trip must reproduce an equal message.
		assert.Equal(t, msg, decoded)
	})

	t.Run("should report suggestion id errors", func(t *testing.T) {
		_, ctorErr := NewAccepted(address, protocol.NoSuggestion, hostname)
		assert.Equal(t, protocol.ErrInvalidSuggestionID, ctorErr)
		// Build the struct directly so Marshal's own validation is exercised,
		// not just the constructor's.
		_, marshalErr := (&acceptedMessage{address, protocol.NoSuggestion, hostname}).Marshal(protocol.Current)
		assert.Equal(t, protocol.ErrInvalidSuggestionID, marshalErr)
	})

	t.Run("should report version errors", func(t *testing.T) {
		_, marshalErr := createBlank(protocol.Accepted).Marshal(protocol.Unknown)
		assert.Equal(t, protocol.ErrInvalidVersion, marshalErr)
		unmarshalErr := createBlank(protocol.Accepted).Unmarshal(protocol.Unknown, []byte{})
		assert.Equal(t, protocol.ErrInvalidVersion, unmarshalErr)
	})

	t.Run("should report address errors", func(t *testing.T) {
		_, ctorErr := NewAccepted(nil, suggestion, hostname)
		assert.Equal(t, protocol.ErrInvalidAddress, ctorErr)
		// badAddr is expected to encode without error; decoding the result
		// must then surface stream.ErrMissingRequiredVersion.
		encoded, err := (&acceptedMessage{&badAddr{}, suggestion, hostname}).Marshal(protocol.Current)
		require.NoError(t, err)
		decodeErr := createBlank(protocol.Accepted).Unmarshal(protocol.Current, encoded)
		assert.Equal(t, stream.ErrMissingRequiredVersion, decodeErr)
	})

	// Both bad-hostname cases follow the same pattern. The constructor must
	// reject the hostname. Marshal skips that check, so the bytes are produced
	// anyway, and Unmarshal must reject them.
	hostCases := []struct {
		name     string
		hostname string
	}{
		{name: "should report empty hostname", hostname: ""},
		{name: "should report invalid hostname", hostname: "bad!url:"},
	}
	for _, hc := range hostCases {
		badHost := hc.hostname
		t.Run(hc.name, func(t *testing.T) {
			_, ctorErr := NewAccepted(address, suggestion, badHost)
			assert.Equal(t, protocol.ErrInvalidHost, ctorErr)
			encoded, err := (&acceptedMessage{address, suggestion, badHost}).Marshal(protocol.Current)
			require.NoError(t, err)
			decodeErr := createBlank(protocol.Accepted).Unmarshal(protocol.Current, encoded)
			assert.Equal(t, protocol.ErrInvalidHost, decodeErr)
		})
	}

	t.Run("should report length errors", func(t *testing.T) {
		want := protocol.ErrInvalidLength(acceptedAddressOffset, 0)
		got := createBlank(protocol.Accepted).Unmarshal(protocol.Current, []byte{})
		assert.Equal(t, want, got)
	})

	t.Run("should handle EOF gracefully", func(t *testing.T) {
		// Start from a header-only buffer, then grow the payload one chunk at
		// a time. After each append the payload is still incomplete and must
		// decode to ErrEOF.
		truncated := make([]byte, acceptedAddressOffset)
		injectHeader(truncated, protocol.One, protocol.Accepted)
		for _, chunk := range []string{"key", "\000source"} {
			truncated = append(truncated, chunk...)
			decodeErr := createBlank(protocol.Accepted).Unmarshal(protocol.Current, truncated)
			assert.Equal(t, protocol.ErrEOF, decodeErr)
		}
	})
}
