package cacheauthorization

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"errors"
	"testing"
	"time"

	"code.justin.tv/amzn/TwitchS2S2/internal/authorization"
	"code.justin.tv/amzn/TwitchS2S2/internal/token"
	"code.justin.tv/amzn/TwitchS2SJWTAlgorithms/es256"
	"code.justin.tv/sse/jwt"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestLRU exercises the lru authorization cache: lookups keyed by
// authorization type and token (Fetch), insertion and size-bounded eviction
// (Put), and extraction of a token's claims section (getClaimsFragment).
//
// Errors are compared by message (EqualError) rather than with Equal, so the
// assertions do not depend on the concrete error type the implementation
// returns (e.g. a wrapped error with the same text).
func TestLRU(t *testing.T) {
	const authorizationType = "AUTHORIZATIONTYPE"
	// badAuthToken contains no '.' separators, so it is not a well-formed JWT.
	const badAuthToken = "BADAUTHTOKEN"
	// badAuthTokenErr is the message both Fetch and Put report for badAuthToken.
	const badAuthTokenErr = "jwt: invalid section count: 0"

	authorizationToken := authTokenConfig(t, true, false)

	// baseAuthz builds the Authorization stored in, and expected back from, the
	// cache. A new value is returned on every call.
	baseAuthz := func() *authorization.Authorization {
		return &authorization.Authorization{JWTID: "JWTID"}
	}

	t.Run("Fetch", func(t *testing.T) {
		t.Run("hit", func(t *testing.T) {
			lru := newLru(1)
			err := lru.Put(authorizationType, authorizationToken, baseAuthz())
			require.NoError(t, err)
			res, ok, err := lru.Fetch(authorizationType, authorizationToken)
			require.NoError(t, err)
			require.True(t, ok)
			assert.Equal(t, baseAuthz(), res)
		})

		t.Run("miss", func(t *testing.T) {
			lru := newLru(1)
			_, ok, err := lru.Fetch(authorizationType, authorizationToken)
			require.NoError(t, err)
			require.False(t, ok)
		})

		t.Run("error", func(t *testing.T) {
			lru := newLru(1)
			_, _, err := lru.Fetch(authorizationType, badAuthToken)
			assert.EqualError(t, err, badAuthTokenErr)
		})

		t.Run("hit but different token", func(t *testing.T) {
			lru := newLru(1)
			err := lru.Put(authorizationType, authorizationToken, baseAuthz())
			require.NoError(t, err)
			// A token built with active=false (and signed with a fresh key)
			// must not match the entry stored under authorizationToken.
			_, ok, err := lru.Fetch(authorizationType, authTokenConfig(t, false, false))
			require.NoError(t, err)
			require.False(t, ok)
		})
	})

	t.Run("Put", func(t *testing.T) {
		t.Run("below limit", func(t *testing.T) {
			lru := newLru(2)
			err := lru.Put(authorizationType, authorizationToken, baseAuthz())
			require.NoError(t, err)
			res, ok, err := lru.Fetch(authorizationType, authorizationToken)
			require.NoError(t, err)
			require.True(t, ok)
			assert.Equal(t, baseAuthz(), res)
			assert.Len(t, lru.cache, 1)
			assert.Equal(t, 1, lru.evictList.Len())
		})

		t.Run("above limit", func(t *testing.T) {
			lru := newLru(2)
			err := lru.Put(authorizationType, authorizationToken, baseAuthz())
			require.NoError(t, err)
			err = lru.Put("type1", authorizationToken, nil)
			require.NoError(t, err)
			err = lru.Put("type2", authorizationToken, nil)
			require.NoError(t, err)
			// The third Put into a size-2 cache should have evicted the first
			// (least recently used) entry.
			_, ok, err := lru.Fetch(authorizationType, authorizationToken)
			require.NoError(t, err)
			require.False(t, ok)
			assert.Len(t, lru.cache, 2)
			assert.Equal(t, 2, lru.evictList.Len())
		})

		t.Run("error", func(t *testing.T) {
			lru := newLru(1)
			err := lru.Put(authorizationType, badAuthToken, baseAuthz())
			assert.EqualError(t, err, badAuthTokenErr)
		})
	})

	t.Run("getClaimsFragment", func(t *testing.T) {
		tcs := []struct {
			Name               string
			AuthorizationToken string
			ExpectedResult     string
			ExpectedError      error // nil means success is expected
		}{
			{
				Name:               "valid token",
				AuthorizationToken: "header.claims.signature",
				ExpectedResult:     "claims",
			},
			{
				Name:               "empty string",
				AuthorizationToken: "",
				ExpectedError:      errors.New("jwt: invalid section count: 0"),
			},
			{
				Name:               "no periods",
				AuthorizationToken: "header",
				ExpectedError:      errors.New("jwt: invalid section count: 0"),
			},
			{
				Name:               "no signature",
				AuthorizationToken: "header.claims",
				ExpectedError:      errors.New("jwt: invalid section count: 1"),
			},
			{
				Name:               "too many sections",
				AuthorizationToken: "header.claims.signature.somethingelse",
				ExpectedError:      errors.New("jwt: invalid section count: more than 2"),
			},
		}

		for _, tc := range tcs {
			tc := tc // capture the range variable for the parallel closure (pre-Go 1.22)
			t.Run(tc.Name, func(t *testing.T) {
				t.Parallel()
				// Each parallel subtest gets its own cache instead of sharing one
				// instance across goroutines.
				lru := newLru(1)
				result, err := lru.getClaimsFragment(tc.AuthorizationToken)
				if tc.ExpectedError != nil {
					require.EqualError(t, err, tc.ExpectedError.Error())
				} else {
					require.NoError(t, err)
				}
				assert.Equal(t, tc.ExpectedResult, result)
			})
		}
	})
}

// authTokenConfig returns an ES256-signed JWT for use as a test fixture. A new
// P-256 key is generated on every call, so tokens from separate calls are
// signed with different keys.
//
// When jwtIDInt is true, the claims consist solely of an integer "jti". The
// token is parsable as a JWT but cannot be unmarshalled into an
// authorization.Authorization, whose JWTID is a string. Otherwise, the claims
// are a populated authorization.Authorization. Its Active field is taken from
// active, and its NotBefore/Expiration lie one hour before/after the current
// UTC time.
func authTokenConfig(t *testing.T, active, jwtIDInt bool) string {
	// Report require failures at the caller's line, not inside this helper.
	t.Helper()

	const keyID = "KEYID"
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	signer := &es256.PrivateKey{PrivateKey: key, KeyID: keyID}

	header := struct {
		KeyID     string `json:"kid"`
		Algorithm string `json:"alg"`
		Type      string `json:"typ"`
	}{
		KeyID:     keyID,
		Algorithm: "ES256",
		Type:      "JWT",
	}

	var claims interface{}
	if jwtIDInt {
		// Deliberately mistyped jti: the decoder expects a string for JWTID.
		claims = &struct {
			JWTID int `json:"jti"`
		}{JWTID: 1}
	} else {
		now := time.Now().UTC()
		claims = &authorization.Authorization{
			Subject:    authorization.NewSubject("my-id"),
			Audience:   authorization.NewAudience("aud1", "aud2", "aud3"),
			Scope:      token.NewScope("scope1", "scope2"),
			JWTID:      "JWTID",
			Issuer:     "ISSUER",
			Active:     active,
			NotBefore:  now.Add(-time.Hour),
			IssuedAt:   now,
			Expiration: now.Add(time.Hour),
		}
	}

	encoded, err := jwt.Encode(header, claims, signer)
	require.NoError(t, err)
	return string(encoded)
}
