package message

import (
	"testing"
	"time"

	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// badCreds is a stream.Credentials stub whose binary form is the single byte
// 'X'. The credential-error test marshals it into a reserve message and
// expects Unmarshal to reject it with stream.ErrInvalidCredentialFormat.
// Only ClientID and MarshalBinary are overridden; any other method falls
// through to the embedded (nil) interface and would panic if called.
type badCreds struct {
	stream.Credentials
}

// ClientID returns a placeholder client identifier.
func (*badCreds) ClientID() string {
	return "_"
}

// MarshalBinary returns a payload that is not a valid credential encoding.
func (*badCreds) MarshalBinary() ([]byte, error) {
	return []byte("X"), nil
}

// badAddr is a stream.Address stub that reports the malformed key "!!bad".
// The address-error test marshals it into a reserve message and expects
// Unmarshal to fail with stream.ErrMissingRequiredVersion. Only Key is
// overridden; other methods fall through to the embedded (nil) interface.
type badAddr struct {
	stream.Address
}

// Key returns the deliberately malformed address key.
func (*badAddr) Key() stream.AddressKey {
	const malformedKey = "!!bad"
	return stream.AddressKey(malformedKey)
}

// TestReserve covers the reserve message: constructor validation, accessors,
// String formatting, a Marshal/Unmarshal round trip, and the Unmarshal error
// paths for bad request IDs, versions, addresses, credentials and short input.
func TestReserve(t *testing.T) {
	address, setupErr := stream.NewAddress(stream.Namespace("test"), stream.Version(1), map[string]string{})
	require.NoError(t, setupErr)

	tomorrow := time.Now().AddDate(0, 0, 1)
	credentials := stream.NewTimedCredentials(
		"client",
		stream.AddressScopes{stream.AnyAddress},
		stream.AddressScopes{},
		tomorrow,
	)

	t.Run("should report values correctly", func(t *testing.T) {
		msg, err := NewReserve(protocol.FirstRequestID, address, credentials)
		require.NoError(t, err)

		assert.Equal(t, protocol.Reserve, msg.OpCode())
		assert.Equal(t, protocol.FirstRequestID, msg.RequestID())
		assert.Equal(t, address, msg.Address())
		assert.Equal(t, credentials, msg.Credentials())
		assert.Equal(t, "<reserve> reqID=1, addr=test@1, creds=client", msg.String())
	})

	t.Run("should marshal correctly", func(t *testing.T) {
		msg, err := NewReserve(protocol.FirstRequestID, address, credentials)
		require.NoError(t, err)

		encoded, err := msg.Marshal(protocol.Current)
		require.NoError(t, err)

		decoded, err := Unmarshal(encoded)
		require.NoError(t, err)
		assert.Equal(t, msg, decoded)
	})

	t.Run("should report request id errors", func(t *testing.T) {
		_, err := NewReserve(protocol.NoRequestID, address, credentials)
		assert.Equal(t, protocol.ErrInvalidRequestID, err)

		// Build the struct directly to skip NewReserve's validation, so the
		// invalid ID reaches the wire and Unmarshal's own check is exercised.
		unchecked := &reserveMessage{protocol.NoRequestID, address, credentials}
		encoded, err := unchecked.Marshal(protocol.Current)
		require.NoError(t, err)

		err = createBlank(protocol.Reserve).Unmarshal(protocol.Current, encoded)
		assert.Equal(t, protocol.ErrInvalidRequestID, err)
	})

	t.Run("should report version errors", func(t *testing.T) {
		_, err := createBlank(protocol.Reserve).Marshal(protocol.Unknown)
		assert.Equal(t, protocol.ErrInvalidVersion, err)

		err = createBlank(protocol.Reserve).Unmarshal(protocol.Unknown, []byte{})
		assert.Equal(t, protocol.ErrInvalidVersion, err)
	})

	t.Run("should report address errors", func(t *testing.T) {
		_, err := NewReserve(protocol.FirstRequestID, nil, credentials)
		assert.Equal(t, protocol.ErrInvalidAddress, err)

		// Marshal a message carrying a malformed address key so the decode
		// side has to reject it.
		unchecked := &reserveMessage{protocol.FirstRequestID, &badAddr{}, credentials}
		encoded, err := unchecked.Marshal(protocol.Current)
		require.NoError(t, err)

		err = createBlank(protocol.Reserve).Unmarshal(protocol.Current, encoded)
		assert.Equal(t, stream.ErrMissingRequiredVersion, err)
	})

	t.Run("should report credential errors", func(t *testing.T) {
		_, err := NewReserve(protocol.FirstRequestID, address, nil)
		assert.Equal(t, protocol.ErrMissingCredentials, err)

		// Marshal a message carrying an undecodable credential payload.
		unchecked := &reserveMessage{protocol.FirstRequestID, address, &badCreds{}}
		encoded, err := unchecked.Marshal(protocol.Current)
		require.NoError(t, err)

		err = createBlank(protocol.Reserve).Unmarshal(protocol.Current, encoded)
		assert.Equal(t, stream.ErrInvalidCredentialFormat, err)
	})

	t.Run("should report length errors", func(t *testing.T) {
		want := protocol.ErrInvalidLength(reserveAddressOffset, 0)
		got := createBlank(protocol.Reserve).Unmarshal(protocol.Current, []byte{})
		assert.Equal(t, want, got)
	})

	t.Run("should handle EOF gracefully", func(t *testing.T) {
		// Header bytes followed by a short stray suffix: long enough to pass the
		// length check, but (presumably) too short to hold a complete address.
		// NOTE(review): header is written with protocol.One while parsing uses
		// protocol.Current — confirm this stays valid if Current changes.
		truncated := make([]byte, reserveAddressOffset)
		injectHeader(truncated, protocol.One, protocol.Reserve)
		truncated = append(truncated, "key"...)

		err := createBlank(protocol.Reserve).Unmarshal(protocol.Current, truncated)
		assert.Equal(t, protocol.ErrEOF, err)
	})
}
