package message

import (
	"testing"

	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestScopes exercises the Scopes message using the current protocol
// version. It covers the accessors, String formatting, marshal/unmarshal
// round trips, and the error paths for invalid ack IDs, malformed address
// payloads, unknown versions and empty input.
func TestScopes(t *testing.T) {
	anyOnly := stream.AddressSourceMap{stream.AnyAddress.Key(): 0}

	// roundTrip builds a Scopes message from elements and encodes it with
	// protocol.Current. It decodes the bytes through the generic Unmarshal
	// entry point and requires the result to equal the original message.
	roundTrip := func(t *testing.T, elements stream.AddressSourceMap, remove bool) {
		t.Helper()
		msg, err := NewScopes(protocol.FirstAckID, elements, remove)
		require.NoError(t, err)
		encoded, err := msg.Marshal(protocol.Current)
		require.NoError(t, err)
		decoded, err := Unmarshal(encoded)
		require.NoError(t, err)
		assert.Equal(t, msg, decoded)
	}

	// expectBadAddress encodes an empty Scopes message and appends suffix as
	// raw trailing bytes. It then checks that decoding the result reports
	// stream.ErrMissingRequiredVersion.
	expectBadAddress := func(t *testing.T, remove bool, suffix string) {
		t.Helper()
		msg, err := NewScopes(protocol.FirstAckID, nil, remove)
		require.NoError(t, err)
		encoded, err := msg.Marshal(protocol.Current)
		require.NoError(t, err)
		encoded = append(encoded, []byte(suffix)...)
		assert.Equal(t, stream.ErrMissingRequiredVersion, createBlank(protocol.Scopes).Unmarshal(protocol.Current, encoded))
	}

	t.Run("should report values correctly", func(t *testing.T) {
		msg, err := NewScopes(protocol.FirstAckID, anyOnly, true)
		require.NoError(t, err)
		assert.Equal(t, protocol.Scopes, msg.OpCode())
		assert.Equal(t, protocol.FirstAckID, msg.AckID())
		assert.Equal(t, anyOnly, msg.Scopes())
		assert.True(t, msg.Remove())
		assert.Equal(t, "<scopes> reqID=1, remove=true, scopes=map[*:0]", msg.String())
	})

	t.Run("should marshal correctly", func(t *testing.T) {
		roundTrip(t, anyOnly, false)
	})

	t.Run("should compress scopes correctly", func(t *testing.T) {
		first, err := stream.ParseAddress("n@1")
		require.NoError(t, err)
		second, err := stream.ParseAddress("x@1?a=b")
		require.NoError(t, err)
		roundTrip(t, stream.AddressSourceMap{stream.AnyAddress.Key(): 0, first.Key(): 1, second.Key(): 2}, true)
	})

	t.Run("should report ack id errors", func(t *testing.T) {
		_, err := NewScopes(protocol.NoAckID, anyOnly, true)
		assert.Equal(t, protocol.ErrInvalidAckID, err)
		// A blank message encodes fine, but decoding rejects its missing ack ID.
		encoded, err := createBlank(protocol.Scopes).Marshal(protocol.Current)
		require.NoError(t, err)
		assert.Equal(t, protocol.ErrInvalidAckID, createBlank(protocol.Scopes).Unmarshal(protocol.Current, encoded))
	})

	t.Run("should report illegal address scopes (last address)", func(t *testing.T) {
		expectBadAddress(t, true, "x\000n@1")
	})

	t.Run("should report illegal address scopes (first address)", func(t *testing.T) {
		expectBadAddress(t, false, "n@1\000x")
	})

	t.Run("should report version errors", func(t *testing.T) {
		_, err := createBlank(protocol.Scopes).Marshal(protocol.Unknown)
		assert.Equal(t, protocol.ErrInvalidVersion, err)
		assert.Equal(t, protocol.ErrInvalidVersion, createBlank(protocol.Scopes).Unmarshal(protocol.Unknown, []byte{}))
	})

	t.Run("should report length errors", func(t *testing.T) {
		want := protocol.ErrInvalidLength(scopesContentOffset, 0)
		assert.Equal(t, want, createBlank(protocol.Scopes).Unmarshal(protocol.Current, []byte{}))
	})
}

// TestScopesV1 exercises the Scopes message using protocol version One.
// The round-trip subtests check that a Marshal(protocol.One) followed by
// Unmarshal yields a message whose address keys match the input but whose
// source values are all stream.None. The remaining subtests cover the V1
// error paths.
//
// Every subtest that builds bytes for the V1 decoder also encodes them with
// protocol.One. That keeps the input in the wire format the decoder under
// test is meant to read.
func TestScopesV1(t *testing.T) {
	addr1, err := stream.ParseAddress("n@1")
	require.NoError(t, err)
	addr2, err := stream.ParseAddress("x@1?a=b")
	require.NoError(t, err)
	elements := stream.AddressSourceMap{stream.AnyAddress.Key(): stream.None, addr1.Key(): 1, addr2.Key(): 2}
	// revised is what a V1 round trip of elements is expected to produce:
	// the same keys, with every source replaced by stream.None.
	revised := stream.AddressSourceMap{stream.AnyAddress.Key(): stream.None, addr1.Key(): stream.None, addr2.Key(): stream.None}

	t.Run("should marshal correctly", func(t *testing.T) {
		scopes, err := NewScopes(protocol.FirstAckID, elements, false)
		require.NoError(t, err)
		bytes, err := scopes.Marshal(protocol.One)
		require.NoError(t, err)
		out, err := Unmarshal(bytes)
		require.NoError(t, err)
		// this process has stripped out sources
		assert.NotEqual(t, scopes, out)
		noSources, err := NewScopes(protocol.FirstAckID, revised, false)
		require.NoError(t, err)
		assert.Equal(t, noSources, out)
	})

	t.Run("should marshal remove correctly", func(t *testing.T) {
		scopes, err := NewScopes(protocol.FirstAckID, elements, true)
		require.NoError(t, err)
		bytes, err := scopes.Marshal(protocol.One)
		require.NoError(t, err)
		out, err := Unmarshal(bytes)
		require.NoError(t, err)
		// this process has stripped out sources
		assert.NotEqual(t, scopes, out)
		noSources, err := NewScopes(protocol.FirstAckID, revised, true)
		require.NoError(t, err)
		assert.Equal(t, noSources, out)
	})

	t.Run("should report ack id errors", func(t *testing.T) {
		_, err := NewScopes(protocol.NoAckID, elements, true)
		assert.Equal(t, protocol.ErrInvalidAckID, err)
		// Encode with the same version the decoder below is asked to read.
		bytes, err := createBlank(protocol.Scopes).Marshal(protocol.One)
		require.NoError(t, err)
		assert.Equal(t, protocol.ErrInvalidAckID, createBlank(protocol.Scopes).Unmarshal(protocol.One, bytes))
	})

	t.Run("should report illegal address scopes (last address)", func(t *testing.T) {
		scopes, err := NewScopes(protocol.FirstAckID, nil, true)
		require.NoError(t, err)
		// Encode with the same version the decoder below is asked to read.
		bytes, err := scopes.Marshal(protocol.One)
		require.NoError(t, err)
		bytes = append(bytes, []byte("x\000n@1")...)
		assert.Equal(t, stream.ErrMissingRequiredVersion, createBlank(protocol.Scopes).Unmarshal(protocol.One, bytes))
	})

	t.Run("should report illegal address scopes (first address)", func(t *testing.T) {
		scopes, err := NewScopes(protocol.FirstAckID, nil, false)
		require.NoError(t, err)
		bytes, err := scopes.Marshal(protocol.One)
		require.NoError(t, err)
		bytes = append(bytes, []byte("n@1\000x")...)
		assert.Equal(t, stream.ErrMissingRequiredVersion, createBlank(protocol.Scopes).Unmarshal(protocol.One, bytes))
	})

	t.Run("should reject bad addresses", func(t *testing.T) {
		// NewScopes accepts the unparsed key; the V1 encoder is what rejects it.
		scopes, err := NewScopes(protocol.FirstAckID, stream.AddressSourceMap{stream.AddressKey("n@q"): 0}, false)
		require.NoError(t, err)
		bytes, err := scopes.Marshal(protocol.One)
		require.Equal(t, stream.ErrVersionSyntax, err)
		assert.Nil(t, bytes)
	})

	t.Run("should report length errors", func(t *testing.T) {
		assert.Equal(t, protocol.ErrInvalidLength(scopesContentOffset, 0),
			createBlank(protocol.Scopes).Unmarshal(protocol.One, []byte{}))
	})
}
