package message

import (
	"testing"

	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestAuth exercises the AuthHost message. It covers the constructor and
// accessors, a Marshal/Unmarshal round trip, and the errors reported for an
// invalid ack ID, an invalid version, an empty buffer and a truncated payload.
func TestAuth(t *testing.T) {
	hostname := "wss://fakename.com"
	token := OpaqueBytes("Admit One")

	t.Run("should report values correctly", func(t *testing.T) {
		auth, err := NewAuthHost(protocol.FirstAckID, protocol.Current, hostname, token)
		require.NoError(t, err)
		assert.Equal(t, protocol.AuthHost, auth.OpCode())
		assert.Equal(t, protocol.FirstAckID, auth.AckID())
		assert.Equal(t, protocol.Current, auth.Version())
		assert.Equal(t, hostname, auth.Hostname())
		assert.Equal(t, token, auth.Token())
		// The token text ("Admit One") must not appear in String(); "####" is
		// expected in its place.
		// NOTE(review): reqID=1 and version=2 are hard-coded; presumably they
		// mirror protocol.FirstAckID and protocol.Current — update if those change.
		assert.Equal(t, "<auth_host> reqID=1, hostname=wss://fakename.com, version=2, ####", auth.String())
		assert.True(t, auth.Equals(auth))
	})

	t.Run("should marshal correctly", func(t *testing.T) {
		auth, err := NewAuthHost(protocol.FirstAckID, protocol.Current, hostname, token)
		require.NoError(t, err)
		buf, err := auth.Marshal(protocol.Current)
		require.NoError(t, err)
		out, err := Unmarshal(buf)
		require.NoError(t, err)
		assert.Equal(t, auth, out)

		// Equals must be symmetric between the original and the decoded value
		// once the latter is asserted back to AuthHost.
		cast, ok := out.(AuthHost)
		require.True(t, ok)
		assert.True(t, auth.Equals(cast))
		assert.True(t, cast.Equals(auth))
	})

	t.Run("should report ack id errors", func(t *testing.T) {
		_, err := NewAuthHost(protocol.NoAckID, protocol.Current, hostname, token)
		assert.Equal(t, protocol.ErrInvalidAckID, err)
		// A blank AuthHost marshals without error, but decoding those bytes must
		// reject its ack ID (presumably left unset — confirm in createBlank).
		buf, err := createBlank(protocol.AuthHost).Marshal(protocol.Current)
		require.NoError(t, err)
		assert.Equal(t, protocol.ErrInvalidAckID, createBlank(protocol.AuthHost).Unmarshal(protocol.Current, buf))
	})

	t.Run("should report version errors", func(t *testing.T) {
		// protocol.Unknown is rejected by the constructor, by Marshal and by
		// Unmarshal. For Unmarshal the version check wins over the empty
		// payload, which would otherwise be a length error (see below).
		_, err := NewAuthHost(protocol.FirstAckID, protocol.Unknown, hostname, token)
		assert.Equal(t, protocol.ErrInvalidVersion, err)
		_, err = createBlank(protocol.AuthHost).Marshal(protocol.Unknown)
		assert.Equal(t, protocol.ErrInvalidVersion, err)
		assert.Equal(t, protocol.ErrInvalidVersion, createBlank(protocol.AuthHost).Unmarshal(protocol.Unknown, []byte{}))
	})

	t.Run("should report length errors", func(t *testing.T) {
		// An empty payload is shorter than authHostnameOffset bytes (presumably
		// the fixed-size prefix preceding the hostname), so Unmarshal reports
		// the expected minimum against the actual length of 0.
		assert.Equal(t, protocol.ErrInvalidLength(authHostnameOffset, 0),
			createBlank(protocol.AuthHost).Unmarshal(protocol.Current, []byte{}))
	})

	t.Run("should handle EOF gracefully", func(t *testing.T) {
		// A valid header padded to authHostnameOffset, followed by only three
		// bytes ("key"). This passes the minimum-length check but presumably
		// runs out of data while decoding the hostname/token fields.
		broken := make([]byte, authHostnameOffset)
		injectHeader(broken, protocol.Current, protocol.AuthHost)
		broken = append(broken, "key"...)
		assert.Equal(t, protocol.ErrEOF, createBlank(protocol.AuthHost).Unmarshal(protocol.Current, broken))
	})
}
