package message

import (
	"encoding/json"
	"testing"

	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
	"code.justin.tv/devhub/e2ml/libs/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// illegalError is a test double whose JSON encoding always fails.
// MarshalJSON returns the receiver as its error. Details places the receiver
// inside the details map, so JSON-encoding those details also fails.
type illegalError struct{}

// Compile-time checks that illegalError satisfies the interfaces the tests need.
var (
	_ errors.DetailsError = (*illegalError)(nil)
	_ json.Marshaler      = (*illegalError)(nil)
)

// Error implements the error interface with a fixed message.
func (e *illegalError) Error() string {
	return "boom"
}

// Details returns a single-entry map whose value is the receiver itself.
func (e *illegalError) Details() errors.Details {
	return errors.Details{"k": e}
}

// MarshalJSON always fails. It returns no bytes and the receiver as the error.
func (e *illegalError) MarshalJSON() ([]byte, error) {
	return nil, e
}

// TestError exercises the Error message. It covers the accessors and String
// form, a marshal/unmarshal round trip, and the failure paths for JSON
// encoding, JSON decoding, an unsupported protocol version, and short input.
func TestError(t *testing.T) {
	t.Run("should report values correctly", func(t *testing.T) {
		msg, err := NewError(protocol.FirstAckID, protocol.ErrInvalidHeader)
		require.NoError(t, err)
		assert.Equal(t, protocol.Error, msg.OpCode())
		assert.Equal(t, protocol.FirstAckID, msg.ForAckID())
		assert.EqualError(t, msg.Unwrap(), protocol.ErrInvalidHeader.Error())
		assert.Equal(t, "<error> reqID=1, code=invalid_header, err=Invalid header", msg.String())
	})

	t.Run("should marshal correctly", func(t *testing.T) {
		msg, err := NewError(protocol.FirstAckID, nil)
		require.NoError(t, err)
		data, err := msg.Marshal(protocol.Current)
		require.NoError(t, err)
		out, err := Unmarshal(data)
		require.NoError(t, err)
		assert.Equal(t, msg, out)
	})

	t.Run("should report marshaling errors", func(t *testing.T) {
		// illegalError's MarshalJSON always fails. encoding/json reports a
		// failing Marshaler as a *json.MarshalerError.
		msg, err := NewError(protocol.FirstAckID, &illegalError{})
		require.NoError(t, err)
		out, err := msg.Marshal(protocol.Current)
		assert.Nil(t, out)
		assert.IsType(t, &json.MarshalerError{}, err)
	})

	t.Run("should report unmarshaling errors", func(t *testing.T) {
		msg, err := NewError(protocol.FirstAckID, nil)
		require.NoError(t, err)
		data, err := msg.Marshal(protocol.Current)
		require.NoError(t, err)
		// A stray '{' after the encoded payload should make decoding fail
		// with a JSON syntax error.
		data = append(data, '{')
		out, err := Unmarshal(data)
		assert.Nil(t, out)
		assert.IsType(t, &json.SyntaxError{}, err)
	})

	t.Run("should report version errors", func(t *testing.T) {
		_, err := createBlank(protocol.Error).Marshal(protocol.Unknown)
		assert.Equal(t, protocol.ErrInvalidVersion, err)
		assert.Equal(t, protocol.ErrInvalidVersion, createBlank(protocol.Error).Unmarshal(protocol.Unknown, []byte{}))
	})

	t.Run("should report length errors", func(t *testing.T) {
		// An empty payload is expected to be rejected with a length error
		// that reports errorMessageOffset as the required size and 0 as the
		// actual size.
		assert.Equal(t, protocol.ErrInvalidLength(errorMessageOffset, 0),
			createBlank(protocol.Error).Unmarshal(protocol.Current, []byte{}))
	})
}
