package validation

import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"errors"
	"math/rand"
	"testing"
	"time"

	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCallee/internal/service"
	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCallee/internal/validation/mocks"
	"code.justin.tv/sse/jwt"
	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

//go:generate mockgen -package mocks -destination ./mocks/certiface.go code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCallee/internal/cert/certiface CertificatesAPI

// TestValidationsValidate exercises Validations.Validate. Each subtest signs a
// JWT with a fixed ES384 test key, stubs the certificate lookup for the
// token's x5u header when validation is expected to reach it, and checks
// either the returned Validation or the specific error produced.
func TestValidationsValidate(t *testing.T) {
	const headerX5U = "myX5u"
	const claimedAudience = "claimedaudiencescheme://claimedaudiencehost:8081"
	const serviceDomain = "serviceDomain"
	const serviceName = "serviceName"
	const serviceStage = "serviceStage"
	// exp/nbf are encoded as whole Unix seconds (see validClaims), so the
	// expected times are truncated to the second to compare equal after the
	// round trip through the token.
	claimedExpiration := time.Now().UTC().Truncate(time.Second).Add(2 * time.Hour)
	claimedNotBefore := time.Now().UTC().Truncate(time.Second).Add(-time.Hour)

	ctx := context.Background()

	expectedService := service.Service{
		Domain: serviceDomain,
		Name:   serviceName,
		Stage:  serviceStage,
	}

	// The seeded math/rand source only supplies key material for this test.
	// It is not cryptographically secure and must never be used outside tests.
	privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(69)))
	require.NoError(t, err)

	// x5uHeader is the token header shape Validate reads; its x5u value is
	// the key passed to CertificatesAPI.Get.
	type x5uHeader struct {
		X5U string `json:"x5u"`
	}

	// tokenClaims is the well-formed claim set. Subtests start from
	// validClaims() and override one field to isolate the failure under test.
	type tokenClaims struct {
		Audience   string `json:"aud"`
		Expiration int64  `json:"exp"`
		NotBefore  int64  `json:"nbf"`
	}

	validClaims := func() tokenClaims {
		return tokenClaims{
			Audience:   claimedAudience,
			Expiration: claimedExpiration.Unix(),
			NotBefore:  claimedNotBefore.Unix(),
		}
	}

	signedJWT := func(t *testing.T, header, claims interface{}) []byte {
		t.Helper()
		bs, err := jwt.Encode(header, claims, jwt.ES384(privateKey))
		require.NoError(t, err)
		return bs
	}

	serverOrigins := func(t *testing.T) *Origins {
		t.Helper()
		o, err := NewOrigins(claimedAudience)
		require.NoError(t, err)
		return o
	}

	// expectCertificate stubs a successful lookup of headerX5U that returns
	// expectedService and the public half of the signing key.
	expectCertificate := func(test *validationsTest) {
		test.MockCertificatesAPI.EXPECT().
			Get(gomock.Any(), headerX5U).
			Return(&expectedService, &privateKey.PublicKey, nil)
	}

	// requireClaimError checks that err is a *ClaimValidationError for field
	// whose Message contains msgSubstr. The type is checked before asserting
	// so a nil or foreign error fails the test instead of panicking.
	requireClaimError := func(t *testing.T, err error, field, msgSubstr string) {
		t.Helper()
		require.IsType(t, &ClaimValidationError{}, err)
		claimErr := err.(*ClaimValidationError)
		assert.Contains(t, claimErr.Message, msgSubstr)
		// Clear the free-form message so the remaining fields can be compared
		// exactly.
		claimErr.Message = ""
		assert.Equal(t, &ClaimValidationError{Field: field}, claimErr)
	}

	t.Run("success", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newValidationsTest(ctrl, serverOrigins(t))
		expectCertificate(test)

		res, err := test.Validate(ctx, signedJWT(t, x5uHeader{X5U: headerX5U}, validClaims()))
		require.NoError(t, err)
		assert.Equal(t, &Validation{Subject: expectedService, Expiration: claimedExpiration, NotBefore: claimedNotBefore}, res)
	})

	t.Run("invalid token structure", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newValidationsTest(ctrl, serverOrigins(t))

		_, err := test.Validate(ctx, []byte("ok"))
		require.Error(t, err)
		assert.Contains(t, err.Error(), "invalid section count")
	})

	t.Run("missing x5u header", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newValidationsTest(ctrl, serverOrigins(t))

		_, err := test.Validate(ctx, signedJWT(t, x5uHeader{}, validClaims()))
		assert.Equal(t, ErrMissingX5UHeader, err)
	})

	t.Run("x5u wrong type", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newValidationsTest(ctrl, serverOrigins(t))

		_, err := test.Validate(ctx, signedJWT(t,
			struct {
				X5U int64 `json:"x5u"`
			}{X5U: 240},
			validClaims(),
		))
		require.Error(t, err)
		assert.Contains(t, err.Error(), ".x5u")
	})

	t.Run("get certificate error", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newValidationsTest(ctrl, serverOrigins(t))

		myErr := errors.New("myerror")
		test.MockCertificatesAPI.EXPECT().
			Get(gomock.Any(), headerX5U).
			Return(nil, nil, myErr)

		_, err := test.Validate(ctx, signedJWT(t, x5uHeader{X5U: headerX5U}, validClaims()))
		assert.Equal(t, myErr, err)
	})

	t.Run("exp wrong type", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newValidationsTest(ctrl, serverOrigins(t))
		expectCertificate(test)

		// Only exp differs from validClaims: it is a JSON string instead of a
		// number. The audience stays valid so this case isolates exp.
		_, err := test.Validate(ctx, signedJWT(t,
			x5uHeader{X5U: headerX5U},
			struct {
				Audience   string `json:"aud"`
				Expiration string `json:"exp"`
				NotBefore  int64  `json:"nbf"`
			}{
				Audience:   claimedAudience,
				Expiration: "420",
				NotBefore:  claimedNotBefore.Unix(),
			},
		))
		require.Error(t, err)
		assert.Contains(t, err.Error(), ".exp")
	})

	t.Run("audience unparsable", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newValidationsTest(ctrl, serverOrigins(t))
		expectCertificate(test)

		claims := validClaims()
		claims.Audience = "derp"

		_, err := test.Validate(ctx, signedJWT(t, x5uHeader{X5U: headerX5U}, claims))
		assert.EqualError(t, err, "unparseable web origin: derp")
	})

	t.Run("token expired", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newValidationsTest(ctrl, serverOrigins(t))
		expectCertificate(test)

		claims := validClaims()
		// Expiration a day in the past.
		claims.Expiration = time.Now().UTC().Add(-24 * time.Hour).Unix()

		_, err := test.Validate(ctx, signedJWT(t, x5uHeader{X5U: headerX5U}, claims))
		requireClaimError(t, err, "exp", "after")
	})

	t.Run("token not usable yet", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newValidationsTest(ctrl, serverOrigins(t))
		expectCertificate(test)

		claims := validClaims()
		// Not-before a day in the future.
		claims.NotBefore = time.Now().UTC().Add(24 * time.Hour).Unix()

		_, err := test.Validate(ctx, signedJWT(t, x5uHeader{X5U: headerX5U}, claims))
		requireClaimError(t, err, "nbf", "before")
	})

	t.Run("non-authoritative audience", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newValidationsTest(ctrl, serverOrigins(t))
		expectCertificate(test)

		claims := validClaims()
		// A well-formed origin that the server origins do not contain.
		claims.Audience = "https://something.else"

		_, err := test.Validate(ctx, signedJWT(t, x5uHeader{X5U: headerX5U}, claims))
		assert.Equal(t, &ClaimValidationError{
			Field:   "aud",
			Message: "claimedaudiencescheme://claimedaudiencehost:8081 does not contain https://something.else",
		}, err)
	})
}

// validationsTest bundles the Validations under test with the gomock
// CertificatesAPI mock it was constructed with. Embedding lets subtests call
// Validate directly on the value and reach the mock's EXPECT() through the
// MockCertificatesAPI field.
type validationsTest struct {
	*Validations
	*mocks.MockCertificatesAPI
}

// newValidationsTest creates a CertificatesAPI mock on ctrl and a Validations
// that uses it together with serverOrigins. It returns both bundled so tests
// can set expectations on the same mock the Validations will call.
func newValidationsTest(ctrl *gomock.Controller, serverOrigins *Origins) *validationsTest {
	certs := mocks.NewMockCertificatesAPI(ctrl)

	underTest := &Validations{
		ServerOrigins: serverOrigins,
		Certificates:  certs,
	}

	return &validationsTest{
		MockCertificatesAPI: certs,
		Validations:         underTest,
	}
}
