package authorization

import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"encoding/json"
	"errors"
	"math/big"
	"testing"
	"time"

	"code.justin.tv/amzn/TwitchS2S2/internal/authorization/mocks"
	"code.justin.tv/amzn/TwitchS2S2/internal/oidc"
	"code.justin.tv/amzn/TwitchS2S2/internal/token"
	"code.justin.tv/amzn/TwitchS2SJWTAlgorithms/es256"
	"code.justin.tv/sse/jwt"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// TestAuthorizations exercises Authorizations.Validate against ES256-signed
// JWTs built inside the test. It covers a fully valid token, an unparsable
// token, an unparsable header, a failed validation-key lookup, a signature
// made with a different key, undecodable claims, and claims that fail the
// active/nbf/exp/iss checks.
func TestAuthorizations(t *testing.T) {
	ctx := context.Background()
	// Whole seconds only, so the decoded Authorization can compare equal to
	// the one that was encoded. NOTE(review): this assumes the time claims are
	// serialised as integer seconds (the "claims wrong format" case uses int
	// nbf/iat/exp) — confirm against Authorization's JSON encoding.
	now := time.Now().UTC().Truncate(time.Second)

	t.Run("Validate", func(t *testing.T) {
		const keyID = "KEYID"
		const issuerID = "ISSUERID"
		key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
		require.NoError(t, err)

		jwtAlgorithm := &es256.PrivateKey{PrivateKey: key, KeyID: keyID}

		jwtHeader := struct {
			KeyID     string `json:"kid"`
			Algorithm string `json:"alg"`
			Type      string `json:"typ"`
		}{
			KeyID:     keyID,
			Algorithm: "ES256",
			Type:      "JWT",
		}

		// signedJWT signs claims under header with the test private key and
		// returns the encoded token.
		signedJWT := func(t *testing.T, header, claims interface{}) string {
			res, err := jwt.Encode(header, claims, jwtAlgorithm)
			require.NoError(t, err)
			return string(res)
		}

		// marshalledJWT signs claims under the well-formed ES256 header.
		marshalledJWT := func(t *testing.T, claims interface{}) string {
			return signedJWT(t, jwtHeader, claims)
		}

		// baseAuthorization returns a fresh Authorization that Validate
		// accepts (see the "success" case). Each call allocates a new value,
		// so callers may mutate the result freely.
		baseAuthorization := func() *Authorization {
			return &Authorization{
				Subject:    &subject{id: "my-id"},
				Audience:   NewAudience("aud1", "aud2", "aud3"),
				Scope:      token.NewScope("scope1", "scope2"),
				JWTID:      "JWTID",
				Issuer:     issuerID,
				Active:     true,
				NotBefore:  now.Add(-time.Hour),
				IssuedAt:   now,
				Expiration: now.Add(2 * time.Hour),
			}
		}

		// expectValidationKey registers exactly one ValidationKey lookup for
		// keyID. The lookup returns the public half of the signing key.
		expectValidationKey := func(at *authorizationsTest) {
			at.OIDC.On("ValidationKey", mock.Anything, keyID).
				Return(&es256.PublicKey{PublicKey: &key.PublicKey, KeyID: keyID}, nil).
				Once()
		}

		authz := baseAuthorization()

		t.Run("success", func(t *testing.T) {
			at := newAuthorizationsTest()
			defer at.Teardown(t)

			at.OIDC.On("Configuration").
				Return(&oidc.Configuration{IssuerID: issuerID})
			expectValidationKey(at)

			res, err := at.Authorizations.Validate(ctx, "", marshalledJWT(t, authz))
			require.NoError(t, err)
			assert.Equal(t, authz, res)
		})

		t.Run("unparsable jwt", func(t *testing.T) {
			at := newAuthorizationsTest()
			defer at.Teardown(t)

			_, err := at.Authorizations.Validate(ctx, "", "===")
			assert.Equal(t, jwt.ErrTooShort, err)
		})

		t.Run("unparsable header", func(t *testing.T) {
			at := newAuthorizationsTest()
			defer at.Teardown(t)

			// A JSON array cannot be decoded into the header struct.
			_, err := at.Authorizations.Validate(ctx, "", signedJWT(t, []string{"a", "b", "c"}, authz))
			assert.IsType(t, &json.UnmarshalTypeError{}, err)
		})

		t.Run("validation key doesn't exist", func(t *testing.T) {
			at := newAuthorizationsTest()
			defer at.Teardown(t)

			at.OIDC.On("ValidationKey", mock.Anything, keyID).
				Return(nil, errors.New("myerr")).
				Once()

			_, err := at.Authorizations.Validate(ctx, "", marshalledJWT(t, authz))
			assert.Equal(t, &ErrInvalidToken{Field: "kid", Reason: "myerr"}, err)
		})

		t.Run("differing public key for validation", func(t *testing.T) {
			at := newAuthorizationsTest()
			defer at.Teardown(t)

			// An arbitrary public key that does not match the signing key,
			// so the signature check must fail.
			at.OIDC.On("ValidationKey", mock.Anything, keyID).
				Return(&es256.PublicKey{
					PublicKey: &ecdsa.PublicKey{
						Curve: elliptic.P256(),
						X:     big.NewInt(11),
						Y:     big.NewInt(13),
					},
					KeyID: keyID,
				}, nil).
				Once()

			_, err := at.Authorizations.Validate(ctx, "", marshalledJWT(t, authz))
			assert.Equal(t, &ErrInvalidToken{Field: "kid", Reason: "Invalid Signature"}, err)
		})

		t.Run("claims wrong format", func(t *testing.T) {
			at := newAuthorizationsTest()
			defer at.Teardown(t)

			expectValidationKey(at)

			_, err := at.Authorizations.Validate(ctx, "", marshalledJWT(t, struct {
				JWTID      []string `json:"jti"` // wrong type
				Issuer     string   `json:"iss"`
				Subject    string   `json:"sub"`
				Audience   string   `json:"aud"`
				NotBefore  int      `json:"nbf"`
				IssuedAt   int      `json:"iat"`
				Expiration int      `json:"exp"`
				Scope      string   `json:"scope"`
				Active     bool     `json:"active"`
			}{JWTID: []string{"abc"}}))
			// NOTE(review): the Reason text comes from encoding/json and may
			// change across Go releases.
			assert.Equal(t, &ErrInvalidToken{
				Field:  "claims",
				Reason: "json: cannot unmarshal array into Go struct field .jti of type string",
			}, err)
		})

		t.Run("claims error", func(t *testing.T) {
			// withClaims applies mutate to a fresh baseAuthorization, so each
			// case differs from a valid token only in the claims it touches.
			withClaims := func(mutate func(*Authorization)) *Authorization {
				a := baseAuthorization()
				mutate(a)
				return a
			}

			// invalidField matches an *ErrInvalidToken for field. Every field
			// of the error is compared except Reason, which is copied from the
			// actual error. Presumably that text embeds a timestamp and cannot
			// be predicted.
			invalidField := func(field string) func(*testing.T, error) {
				return func(t *testing.T, err error) {
					require.IsType(t, &ErrInvalidToken{}, err)
					assert.Equal(t, &ErrInvalidToken{
						Field:  field,
						Reason: err.(*ErrInvalidToken).Reason,
					}, err)
				}
			}

			tcs := []struct {
				Name         string
				Value        *Authorization
				ErrorMatcher func(*testing.T, error)
			}{
				{
					Name:  "not active",
					Value: withClaims(func(a *Authorization) { a.Active = false }),
					ErrorMatcher: func(t *testing.T, err error) {
						assert.Equal(t, &ErrInvalidToken{
							Field:  "active",
							Reason: "Access token is not active",
						}, err)
					},
				},
				{
					Name:         "now before nbf",
					Value:        withClaims(func(a *Authorization) { a.NotBefore = now.Add(time.Hour) }),
					ErrorMatcher: invalidField("nbf"),
				},
				{
					Name:         "now after exp",
					Value:        withClaims(func(a *Authorization) { a.Expiration = now.Add(-time.Hour) }),
					ErrorMatcher: invalidField("exp"),
				},
				{
					Name:  "issuer mismatch",
					Value: withClaims(func(a *Authorization) { a.Issuer = "badissuerid" }),
					ErrorMatcher: func(t *testing.T, err error) {
						assert.Equal(t, &ErrInvalidToken{
							Field:  "iss",
							Reason: "Token has an invalid issuer<badissuerid> - we expect " + issuerID,
						}, err)
					},
				},
			}

			for _, tc := range tcs {
				t.Run(tc.Name, func(t *testing.T) {
					at := newAuthorizationsTest()
					defer at.Teardown(t)

					// Configuration is registered as optional: depending on
					// the claim, Validate may reject the token before it
					// needs the issuer.
					at.OIDC.On("Configuration").
						Return(&oidc.Configuration{IssuerID: issuerID}).
						Maybe()
					expectValidationKey(at)

					_, err := at.Authorizations.Validate(ctx, "", marshalledJWT(t, tc.Value))
					tc.ErrorMatcher(t, err)
				})
			}
		})
	})
}

// newAuthorizationsTest builds an Authorizations backed by a fresh OIDC mock
// and no-op OnTokenAuthenticated/OnTokenRejected hooks. It returns that mock
// alongside, so tests can register expectations on it.
func newAuthorizationsTest() *authorizationsTest {
	// Named oidcMock rather than oidc so it does not shadow the imported oidc
	// package.
	oidcMock := new(mocks.OIDCAPI)
	return &authorizationsTest{
		Authorizations: &Authorizations{
			OIDC:                 oidcMock,
			OnTokenAuthenticated: func(context.Context, *Authorization) {},
			OnTokenRejected:      func(context.Context, *ErrInvalidToken) {},
		},
		OIDC: oidcMock,
	}
}

// authorizationsTest bundles the Authorizations under test with the OIDC mock
// it was constructed with.
type authorizationsTest struct {
	Authorizations *Authorizations
	OIDC           *mocks.OIDCAPI
}

// Teardown fails t if any expectation registered on the OIDC mock was not met.
// Defer it right after newAuthorizationsTest.
func (at *authorizationsTest) Teardown(t *testing.T) {
	// Mark as a helper so reported failures point at the calling test rather
	// than at this function.
	t.Helper()
	at.OIDC.AssertExpectations(t)
}
