package extjwt

import (
	"testing"
	"time"

	"code.justin.tv/common/jwt"
	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// testRequest builds a stream.AuthRequest for clientID whose token is
// produced by encoding the given claims with testToken.
func testRequest(clientID string, claims *claims) stream.AuthRequest {
	token := testToken(claims)
	return NewRequest(clientID, token)
}

// TestResolver exercises the extension-JWT resolver. It checks:
//   - the auth method it reports;
//   - how the listen verb scopes in token claims map onto CanListen
//     permissions for the package's test addresses;
//   - that token-decoding errors, validator errors and unsupported requests
//     are surfaced to the caller.
func TestResolver(t *testing.T) {
	res := NewResolver(NewFakeValidator(100, 0, nil))

	// One shared expiry, one day in the future, for subtests that need a
	// token which has not yet expired.
	tomorrow := time.Now().AddDate(0, 0, 1).Unix()

	t.Run("should report its method correctly", func(t *testing.T) {
		assert.Equal(t, method, res.Method())
	})

	t.Run("should allow '*' to connect to any address with the same clientID", func(t *testing.T) {
		creds, err := res.Resolve(testRequest(clientID, &claims{
			ChannelID: channelID,
			Verbs:     map[string][]string{listenVerb: {"*"}},
			Expires:   tomorrow,
		}))
		require.NoError(t, err)
		assert.False(t, creds.CanListen(ext2Addr))
		assert.True(t, creds.CanListen(anyChannelAddr))
		assert.True(t, creds.CanListen(globalAddr))
		assert.True(t, creds.CanListen(channelAddr))
		assert.True(t, creds.CanListen(channel2Addr))
	})

	t.Run("should allow 'broadcast' to connect to any address with the same clientID and channel", func(t *testing.T) {
		creds, err := res.Resolve(testRequest(clientID, &claims{
			ChannelID: channelID,
			Verbs:     map[string][]string{listenVerb: {"broadcast"}},
			Expires:   tomorrow,
		}))
		require.NoError(t, err)
		assert.False(t, creds.CanListen(ext2Addr))
		assert.False(t, creds.CanListen(anyChannelAddr))
		assert.False(t, creds.CanListen(globalAddr))
		assert.True(t, creds.CanListen(channelAddr))
		assert.False(t, creds.CanListen(channel2Addr))
	})

	t.Run("should allow 'global' to connect to any address with the same clientID and '*' channel", func(t *testing.T) {
		creds, err := res.Resolve(testRequest(clientID, &claims{
			ChannelID: channelID,
			Verbs:     map[string][]string{listenVerb: {"global"}},
			Expires:   tomorrow,
		}))
		require.NoError(t, err)
		assert.False(t, creds.CanListen(ext2Addr))
		assert.False(t, creds.CanListen(anyChannelAddr))
		assert.True(t, creds.CanListen(globalAddr))
		assert.False(t, creds.CanListen(channelAddr))
		assert.False(t, creds.CanListen(channel2Addr))
	})

	t.Run("should allow 'global' + 'broadcast' to connect to the super set of the two", func(t *testing.T) {
		creds, err := res.Resolve(testRequest(clientID, &claims{
			ChannelID: channelID,
			Verbs:     map[string][]string{listenVerb: {"global", "broadcast"}},
			Expires:   tomorrow,
		}))
		require.NoError(t, err)
		assert.False(t, creds.CanListen(ext2Addr))
		assert.False(t, creds.CanListen(anyChannelAddr))
		assert.True(t, creds.CanListen(globalAddr))
		assert.True(t, creds.CanListen(channelAddr))
		assert.False(t, creds.CanListen(channel2Addr))
	})

	t.Run("should ignore unknown verb scopes", func(t *testing.T) {
		// Builds credentials straight from the claims, skipping token
		// encoding and the validator, so only the scope mapping is tested.
		creds := payloadToCredentials(clientID, &claims{
			ChannelID: channelID,
			Verbs:     map[string][]string{listenVerb: {"unknown"}},
		})
		assert.False(t, creds.CanListen(ext2Addr))
		assert.False(t, creds.CanListen(anyChannelAddr))
		assert.False(t, creds.CanListen(globalAddr))
		assert.False(t, creds.CanListen(channelAddr))
		assert.False(t, creds.CanListen(channel2Addr))
	})

	t.Run("should forward the error when incompatible claims are provided", func(t *testing.T) {
		// The claim values below are ints where the claims struct presumably
		// expects other types, so decoding is expected to fail with a jwt.Err.
		token, err := jwt.Encode(
			struct{}{},
			map[string]interface{}{
				"channel_id":   1,
				"pubsub_perms": 0,
				"exp":          123,
			},
			jwt.None,
		)
		require.NoError(t, err)

		_, err = res.Resolve(NewRequest(clientID, token))
		// Fail clearly on a nil error rather than with a type-mismatch message.
		require.Error(t, err)
		assert.IsType(t, jwt.Err{}, err)
	})

	t.Run("should forward validator errors", func(t *testing.T) {
		// This resolver's fake validator is configured to return
		// protocol.ErrInvalidAddress; the resolver must pass it through unchanged.
		badRes := NewResolver(NewFakeValidator(0, 100, protocol.ErrInvalidAddress))
		_, err := badRes.Resolve(testRequest(clientID, &claims{
			ChannelID: channelID,
			Verbs:     map[string][]string{listenVerb: {"*"}},
		}))
		assert.Equal(t, protocol.ErrInvalidAddress, err)
	})

	t.Run("should reject invalid requests", func(t *testing.T) {
		_, err := res.Resolve(nil)
		assert.Equal(t, stream.ErrInvalidAuthMethod, err)
	})
}
