package s2s2

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strconv"
	"testing"
	"time"

	"github.com/aws/aws-lambda-go/lambdacontext"
	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"

	"code.justin.tv/amzn/TwitchS2S2/c7s"
	"code.justin.tv/amzn/TwitchS2S2/internal/authorization"
	"code.justin.tv/amzn/TwitchS2S2/internal/httpwrap"
	"code.justin.tv/amzn/TwitchS2S2/internal/logutil"
	"code.justin.tv/amzn/TwitchS2S2/internal/s2s2err"
	"code.justin.tv/amzn/TwitchS2S2/internal/token"
	"code.justin.tv/amzn/TwitchS2S2/s2s2/mocks"
	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCallee/s2s2dicallee"
)

// TestS2S2 exercises the S2S2 client wrappers (HTTPClient, LambdaHTTPClient,
// RoundTripper, RoundTripperWrapper) and, for both a plain HTTP transport and a
// simulated Lambda transport, the RequireAuthentication, RequireScopes,
// logRequest and HardRefreshCache behaviour.
func TestS2S2(t *testing.T) {
	t.Run("CapabilityScope", func(t *testing.T) {
		s2s := &S2S2{clientServiceURI: "https://service.uri"}
		assert.Equal(t, "https://service.uri#createBooks", s2s.CapabilityScope("createBooks"))
	})

	t.Run("HTTPClient", func(t *testing.T) {
		t.Run("HTTP", func(t *testing.T) {
			inner := new(mocks.HTTPClient)
			accessTokens := new(mocks.CachedTokens)
			s2s := &S2S2{accessTokens: accessTokens}
			assert.Equal(t,
				&authenticatingHTTPClient{
					AccessTokens: accessTokens,
					Inner:        inner,
					logger:       s2s.logger,
				},
				s2s.HTTPClient(inner))
		})

		t.Run("lambda", func(t *testing.T) {
			const lambdaName = "LAMBDANAME"
			inner := new(mocks.LambdaTransport)
			defer inner.AssertExpectations(t)

			accessTokens := new(mocks.CachedTokens)
			inner.On("LambdaName").Return(lambdaName).Once()
			s2s := &S2S2{accessTokens: accessTokens, logger: logutil.NoopLogger}
			assert.Equal(t,
				&authenticatingHTTPClient{
					AccessTokens: accessTokens,
					Inner:        inner,
					LambdaName:   lambdaName,
					logger:       s2s.logger,
				},
				s2s.HTTPClient(inner))
		})
	})

	t.Run("LambdaHTTPClient", func(t *testing.T) {
		const lambdaName = "LAMBDANAME"
		inner := new(mocks.HTTPClient)
		accessTokens := new(mocks.CachedTokens)
		s2s := &S2S2{accessTokens: accessTokens}
		assert.Equal(t,
			&authenticatingHTTPClient{
				AccessTokens: accessTokens,
				Inner:        inner,
				LambdaName:   lambdaName,
				logger:       s2s.logger,
			},
			s2s.LambdaHTTPClient(inner, lambdaName))
	})

	t.Run("RoundTripper", func(t *testing.T) {
		inner := new(mocks.HTTPClient)
		accessTokens := new(mocks.CachedTokens)
		s2s := &S2S2{accessTokens: accessTokens}
		assert.Equal(t, httpwrap.RoundTripperFromHTTPClient(&authenticatingHTTPClient{
			AccessTokens: accessTokens,
			Inner:        inner,
			logger:       s2s.logger,
		}),
			s2s.RoundTripper(inner))
	})

	t.Run("RoundTripperWrapper", func(t *testing.T) {
		inner := new(mocks.RoundTripper)
		accessTokens := new(mocks.CachedTokens)
		s2s := &S2S2{accessTokens: accessTokens}
		assert.Equal(t,
			httpwrap.RoundTripperFromHTTPClient(&authenticatingHTTPClient{
				AccessTokens: accessTokens,
				Inner:        httpwrap.HTTPClientFromRoundTripper(inner),
				logger:       s2s.logger,
			}),
			s2s.RoundTripperWrapper(inner))
	})

	// Each case runs the endpoint tests below once; a non-empty LambdaName makes
	// newS2S2Test inject a Lambda context into every request.
	tcs := []struct {
		Name       string
		LambdaName string
	}{
		{
			Name: "http transport",
		},
		{
			Name:       "lambda transport",
			LambdaName: "myLambdaName",
		},
	}

	for _, tc := range tcs {
		t.Run(tc.Name, func(t *testing.T) {
			t.Run("RequireAuthentication", func(t *testing.T) {
				newEndpointTest := func() *s2sTest {
					return newS2S2Test(func(s2s *S2S2, inner *mocks.Handler) http.Handler {
						return s2s.RequireAuthentication(inner)
					}, tc.LambdaName)
				}

				t.Run("success", func(t *testing.T) {
					st := newEndpointTest()
					defer st.Teardown(t)

					sub := new(mocks.Subject)
					aud := new(mocks.Audience)
					defer mock.AssertExpectationsForObjects(t, sub, aud)

					sub.On("ID").Return("anID").Once()
					aud.On("Contains", st.ServerHost(t)).Return(true).Once()

					st.diCallee.
						On("ValidateAuthentication", mock.Anything).
						Return(nil, s2s2dicallee.ErrMissingX5UHeader).
						Once()

					st.Authorizations.
						On("Validate", mock.Anything, "Bearer", "my-authorization").
						Return(&authorization.Authorization{
							Audience: aud,
							Subject:  sub,
							Scope:    token.Scope{},
						}, nil).
						Once()

					st.Handler.
						On("ServeHTTP", mock.Anything, mock.Anything).
						Run(func(args mock.Arguments) {
							req := args.Get(1).(*http.Request)
							assert.Equal(t, &authorizedSubject{
								Subject: sub,
								scope:   token.Scope{},
							}, RequestSubject(req.Context()))
						}).
						Return().
						Once()

					req, err := http.NewRequest("GET", st.Server.URL, nil)
					require.NoError(t, err)

					req.Header.Set("Authorization", "Bearer my-authorization")

					res, err := st.Client.Do(req)
					require.NoError(t, err)
					defer res.Body.Close()
					assert.Equal(t, http.StatusOK, res.StatusCode)

					// require (not assert): the lines below index data[0].
					require.Equal(t, 1, len(st.Logger.messages))
					assert.Equal(t, "", st.Logger.data[0]["S2s2TokenId"])
					assert.Equal(t, "Go-http-client/1.1", st.Logger.data[0]["UserAgent"])
				})

				t.Run("success with di header", func(t *testing.T) {
					st := newEndpointTest()
					defer st.Teardown(t)

					ctrl := gomock.NewController(t)
					defer ctrl.Finish()

					authSubject := mocks.NewMockAuthenticatedSubject(ctrl)
					authSubject.EXPECT().Service().Return("test-service")
					authSubject.EXPECT().Stage().Return("testing")
					authSubject.EXPECT().Domain().Return("twitch")

					st.diCallee.
						On("ValidateAuthentication", mock.Anything).
						Return(authSubject, nil).
						Once()

					st.Handler.
						On("ServeHTTP", mock.Anything, mock.Anything).
						Run(func(args mock.Arguments) {
							req := args.Get(1).(*http.Request)
							assert.Equal(t,
								&distributedIdentitiesAuthorizedSubject{
									AuthenticatedSubject: authSubject,
									cfg:                  st.S2S2.config,
								},
								RequestSubject(req.Context()),
							)
						}).
						Return().
						Once()

					req, err := http.NewRequest("GET", st.Server.URL, nil)
					require.NoError(t, err)

					req.Header.Set("Authorization", "Bearer my-authorization")

					res, err := st.Client.Do(req)
					require.NoError(t, err)
					defer res.Body.Close()
					assert.Equal(t, http.StatusOK, res.StatusCode)

					// require (not assert): the lines below index data[0].
					require.Equal(t, 1, len(st.Logger.messages))
					assert.Equal(t, "", st.Logger.data[0]["S2s2TokenId"])
					assert.Equal(t, "Go-http-client/1.1", st.Logger.data[0]["UserAgent"])
				})

				t.Run("missing authorization header", func(t *testing.T) {
					st := newEndpointTest()
					defer st.Teardown(t)

					st.diCallee.
						On("ValidateAuthentication", mock.Anything).
						Return(nil, s2s2dicallee.ErrMissingX5UHeader).
						Once()

					req, err := http.NewRequest("GET", st.Server.URL, nil)
					require.NoError(t, err)

					res, err := st.Client.Do(req)
					require.NoError(t, err)
					defer res.Body.Close()
					assert.Equal(t, http.StatusUnauthorized, res.StatusCode)

					st.AssertAuthChallenge(t, res, &wwwAuthenticateChallenge{
						challengeType: "Bearer",
						parameters: map[wwwAuthenticateField]string{
							wwwAuthenticateFieldRealm:            st.S2S2.config.CalleeRealm,
							wwwAuthenticateFieldIssuer:           st.S2S2.config.Issuer,
							wwwAuthenticateFieldError:            "authentication_required",
							wwwAuthenticateFieldErrorDescription: s2s2err.S2S2ErrorString("Authorization header is required for this request.", "Authorization"),
						},
					})
					expected := st.twirpError(errorCodeString(401), s2s2err.S2S2ErrorString("Authorization header is required for this request.", "Authorization"))
					st.AssertJSONBody(t, res, expected)
				})

				t.Run("invalid authorization header", func(t *testing.T) {
					st := newEndpointTest()
					defer st.Teardown(t)

					st.diCallee.
						On("ValidateAuthentication", mock.Anything).
						Return(nil, s2s2dicallee.ErrMissingX5UHeader).
						Once()

					req, err := http.NewRequest("GET", st.Server.URL, nil)
					require.NoError(t, err)

					req.Header.Set("Authorization", "Bearer")

					res, err := st.Client.Do(req)
					require.NoError(t, err)
					defer res.Body.Close()
					assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
					s2s2Error := st.twirpError(errorCodeString(401), s2s2err.S2S2ErrorString("authorization header must have the format '<authorizationType> <token>'", "Authorization"))
					st.AssertJSONBody(t, res, s2s2Error)
				})

				t.Run("validation error", func(t *testing.T) {
					st := newEndpointTest()
					defer st.Teardown(t)

					validateErr := errors.New("validateErr")

					st.diCallee.
						On("ValidateAuthentication", mock.Anything).
						Return(nil, s2s2dicallee.ErrMissingX5UHeader).
						Once()

					st.Authorizations.
						On("Validate", mock.Anything, "Bearer", "my-authorization").
						Return(nil, validateErr).
						Once()

					req, err := http.NewRequest("GET", st.Server.URL, nil)
					require.NoError(t, err)

					req.Header.Set("Authorization", "Bearer my-authorization")

					res, err := st.Client.Do(req)
					require.NoError(t, err)
					defer res.Body.Close()
					assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
					expected := st.twirpError(errorCodeString(500), "validateErr")
					st.AssertJSONBody(t, res, expected)
				})

				t.Run("parsable but expired token", func(t *testing.T) {
					st := newEndpointTest()
					defer st.Teardown(t)

					validateErr := &authorization.ErrInvalidToken{
						Field:  "exp",
						Reason: "token is expired",
					}

					st.diCallee.
						On("ValidateAuthentication", mock.Anything).
						Return(nil, s2s2dicallee.ErrMissingX5UHeader).
						Once()

					st.Authorizations.
						On("Validate", mock.Anything, "Bearer", "my-authorization").
						Return(nil, validateErr).
						Once()

					req, err := http.NewRequest("GET", st.Server.URL, nil)
					require.NoError(t, err)

					req.Header.Set("Authorization", "Bearer my-authorization")

					res, err := st.Client.Do(req)
					require.NoError(t, err)
					defer res.Body.Close()
					assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
					s2s2Error := st.twirpError(errorCodeString(401), s2s2err.S2S2ErrorString("Issue with access token field exp: token is expired", "Authorization"))
					st.AssertJSONBody(t, res, s2s2Error)
				})

				t.Run("wrong audience", func(t *testing.T) {
					st := newEndpointTest()
					defer st.Teardown(t)

					aud := new(mocks.Audience)
					defer mock.AssertExpectationsForObjects(t, aud)

					aud.On("Contains", st.ServerHost(t)).Return(false).Once()

					st.diCallee.
						On("ValidateAuthentication", mock.Anything).
						Return(nil, s2s2dicallee.ErrMissingX5UHeader).
						Once()

					st.Authorizations.
						On("Validate", mock.Anything, "Bearer", "my-authorization").
						Return(&authorization.Authorization{
							Audience: aud,
						}, nil).
						Once()

					aud.On("All").Return([]string{"aud.string"})

					req, err := http.NewRequest("GET", st.Server.URL, nil)
					require.NoError(t, err)

					req.Header.Set("Authorization", "Bearer my-authorization")

					res, err := st.Client.Do(req)
					require.NoError(t, err)
					defer res.Body.Close()
					assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
					s2s2Error := st.twirpError(errorCodeString(401), s2s2err.S2S2ErrorString("Token with audience<aud.string> not intended for '"+st.ServerHost(t)+"'", "Authorization"))
					st.AssertJSONBody(t, res, s2s2Error)
				})

				if tc.LambdaName == "" {
					// test case is not applicable to lambda
					t.Run("unparsable audience", func(t *testing.T) {
						st := newEndpointTest()
						defer st.Teardown(t)

						aud := new(mocks.Audience)
						defer mock.AssertExpectationsForObjects(t, aud)

						st.diCallee.
							On("ValidateAuthentication", mock.Anything).
							Return(nil, s2s2dicallee.ErrMissingX5UHeader).
							Once()

						st.Authorizations.
							On("Validate", mock.Anything, "Bearer", "my-authorization").
							Return(&authorization.Authorization{
								Audience: aud,
							}, nil).
							Once()

						req, err := http.NewRequest("GET", st.Server.URL, nil)
						require.NoError(t, err)

						req.Header.Set("Authorization", "Bearer my-authorization")
						req.Host = ":443"

						res, err := st.Client.Do(req)
						require.NoError(t, err)
						defer res.Body.Close()
						assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
						expected := st.twirpError(errorCodeString(500), "could not determine web origin from Host<:443>")
						st.AssertJSONBody(t, res, expected)
					})
				}
			})

			t.Run("RequireScopes", func(t *testing.T) {
				newEndpointTest := func(scopes ...string) *s2sTest {
					return newS2S2Test(func(s2s *S2S2, inner *mocks.Handler) http.Handler {
						return s2s.RequireScopes(inner, scopes...)
					}, tc.LambdaName)
				}

				t.Run("success with scopes", func(t *testing.T) {
					st := newEndpointTest("testScope")
					defer st.Teardown(t)

					sub := new(mocks.Subject)
					aud := new(mocks.Audience)
					defer mock.AssertExpectationsForObjects(t, sub, aud)

					sub.On("ID").Return("aCallerID").Once()
					aud.On("Contains", st.ServerHost(t)).Return(true).Once()

					st.Authorizations.
						On("Validate", mock.Anything, "Bearer", "my-authorization").
						Return(&authorization.Authorization{
							Audience: aud,
							Subject:  sub,
							Scope:    token.NewScope("testScope"),
						}, nil).
						Once()

					st.Handler.
						On("ServeHTTP", mock.Anything, mock.Anything).
						Run(func(args mock.Arguments) {
							req := args.Get(1).(*http.Request)
							assert.Equal(t, &authorizedSubject{
								Subject: sub,
								scope:   token.NewScope("testScope"),
							}, RequestSubject(req.Context()))
						}).
						Return().
						Once()

					req, err := http.NewRequest("GET", st.Server.URL, nil)
					require.NoError(t, err)

					req.Header.Set("Authorization", "Bearer my-authorization")

					res, err := st.Client.Do(req)
					require.NoError(t, err)
					defer res.Body.Close()
					assert.Equal(t, http.StatusOK, res.StatusCode)

					require.Equal(t, 1, len(st.Logger.messages))
					assert.Equal(t, "", st.Logger.data[0]["S2s2TokenId"])
					assert.Equal(t, "Go-http-client/1.1", st.Logger.data[0]["UserAgent"])
				})

				t.Run("success with no scopes", func(t *testing.T) {
					st := newEndpointTest()
					defer st.Teardown(t)

					sub := new(mocks.Subject)
					aud := new(mocks.Audience)
					defer mock.AssertExpectationsForObjects(t, sub, aud)

					sub.On("ID").Return("aCallerID").Once()
					aud.On("Contains", st.ServerHost(t)).Return(true).Once()

					st.Authorizations.
						On("Validate", mock.Anything, "Bearer", "my-authorization").
						Return(&authorization.Authorization{
							Audience: aud,
							Subject:  sub,
							Scope:    token.NewScope("https://myauthority/myclientserviceuri"),
						}, nil).
						Once()

					st.Handler.
						On("ServeHTTP", mock.Anything, mock.Anything).
						Run(func(args mock.Arguments) {
							req := args.Get(1).(*http.Request)
							assert.Equal(t, &authorizedSubject{
								Subject: sub,
								scope:   token.NewScope("https://myauthority/myclientserviceuri"),
							}, RequestSubject(req.Context()))
						}).
						Return().
						Once()

					req, err := http.NewRequest("GET", st.Server.URL, nil)
					require.NoError(t, err)

					req.Header.Set("Authorization", "Bearer my-authorization")

					res, err := st.Client.Do(req)
					require.NoError(t, err)
					defer res.Body.Close()
					assert.Equal(t, http.StatusOK, res.StatusCode)

					require.Equal(t, 1, len(st.Logger.messages))
					assert.Equal(t, "", st.Logger.data[0]["S2s2TokenId"])
					assert.Equal(t, "Go-http-client/1.1", st.Logger.data[0]["UserAgent"])
				})

				t.Run("error with no scopes", func(t *testing.T) {
					st := newEndpointTest()
					defer st.Teardown(t)

					sub := new(mocks.Subject)
					aud := new(mocks.Audience)
					defer mock.AssertExpectationsForObjects(t, sub, aud)

					aud.On("Contains", st.ServerHost(t)).Return(true).Once()

					st.Authorizations.
						On("Validate", mock.Anything, "Bearer", "my-authorization").
						Return(&authorization.Authorization{
							Audience: aud,
						}, nil).
						Once()

					req, err := http.NewRequest("GET", st.Server.URL, nil)
					require.NoError(t, err)

					req.Header.Set("Authorization", "Bearer my-authorization")

					res, err := st.Client.Do(req)
					require.NoError(t, err)
					defer res.Body.Close()
					assert.Equal(t, http.StatusForbidden, res.StatusCode)

					st.AssertAuthChallenge(t, res, &wwwAuthenticateChallenge{
						challengeType: "Bearer",
						parameters: map[wwwAuthenticateField]string{
							wwwAuthenticateFieldScope:            st.S2S2.clientServiceURI,
							wwwAuthenticateFieldRealm:            st.S2S2.config.CalleeRealm,
							wwwAuthenticateFieldIssuer:           st.S2S2.config.Issuer,
							wwwAuthenticateFieldError:            "insufficient_scope",
							wwwAuthenticateFieldErrorDescription: s2s2err.S2S2ErrorString(fmt.Sprintf("More scopes are required [%s] for this request in addition to []", st.S2S2.clientServiceURI), "Authorization"),
						},
					})
					expected := st.twirpError(errorCodeString(403), s2s2err.S2S2ErrorString(fmt.Sprintf("More scopes are required [%s] for this request in addition to []", st.S2S2.clientServiceURI), "Authorization"))
					st.AssertJSONBody(t, res, expected)
				})

				t.Run("missing scope", func(t *testing.T) {
					st := newEndpointTest("testScopeA", "testScopeB")
					defer st.Teardown(t)

					sub := new(mocks.Subject)
					aud := new(mocks.Audience)
					defer mock.AssertExpectationsForObjects(t, sub, aud)

					aud.On("Contains", st.ServerHost(t)).Return(true).Once()

					st.Authorizations.
						On("Validate", mock.Anything, "Bearer", "my-authorization").
						Return(&authorization.Authorization{
							Scope:    token.NewScope("testScopeA"),
							Audience: aud,
						}, nil).
						Once()

					req, err := http.NewRequest("GET", st.Server.URL, nil)
					require.NoError(t, err)

					req.Header.Set("Authorization", "Bearer my-authorization")

					res, err := st.Client.Do(req)
					require.NoError(t, err)
					defer res.Body.Close()

					assert.Equal(t, http.StatusForbidden, res.StatusCode)

					st.AssertAuthChallenge(t, res, &wwwAuthenticateChallenge{
						challengeType: "Bearer",
						parameters: map[wwwAuthenticateField]string{
							wwwAuthenticateFieldScope:            "testScopeA testScopeB",
							wwwAuthenticateFieldRealm:            st.S2S2.config.CalleeRealm,
							wwwAuthenticateFieldIssuer:           st.S2S2.config.Issuer,
							wwwAuthenticateFieldError:            "insufficient_scope",
							wwwAuthenticateFieldErrorDescription: s2s2err.S2S2ErrorString("More scopes are required [testScopeA testScopeB] for this request in addition to [testScopeA]", "Authorization"),
						},
					})
					expected := st.twirpError(errorCodeString(403), s2s2err.S2S2ErrorString("More scopes are required [testScopeA testScopeB] for this request in addition to [testScopeA]", "Authorization"))
					st.AssertJSONBody(t, res, expected)
				})
			})

			t.Run("DistributedIdentitiesCaller", func(t *testing.T) {
				t.Run("RequireAuthentication", func(t *testing.T) {
					newEndpointTest := func() *s2sTest {
						return newS2S2Test(func(s2s *S2S2, inner *mocks.Handler) http.Handler {
							return s2s.RequireAuthentication(inner)
						}, tc.LambdaName)
					}

					t.Run("success", func(t *testing.T) {
						st := newEndpointTest()
						defer st.Teardown(t)

						ctrl := gomock.NewController(t)
						defer ctrl.Finish()
						authSubject := mocks.NewMockAuthenticatedSubject(ctrl)
						authSubject.EXPECT().Service().Return("test-service")
						authSubject.EXPECT().Stage().Return("testing")
						authSubject.EXPECT().Domain().Return("twitch")

						st.diCallee.On("ValidateAuthentication", mock.Anything).Return(authSubject, nil).Once()

						st.Handler.
							On("ServeHTTP", mock.Anything, mock.Anything).
							Run(func(args mock.Arguments) {
								req := args.Get(1).(*http.Request)
								assert.Equal(t,
									&distributedIdentitiesAuthorizedSubject{
										AuthenticatedSubject: authSubject,
										cfg:                  st.S2S2.config,
									},
									RequestSubject(req.Context()),
								)
							}).
							Return().
							Once()

						req, err := http.NewRequest("GET", st.Server.URL, nil)
						require.NoError(t, err)

						req.Header.Set("Authorization", "Bearer my-authorization")

						res, err := st.Client.Do(req)
						require.NoError(t, err)
						defer res.Body.Close()
						assert.Equal(t, http.StatusOK, res.StatusCode)

						require.Equal(t, 1, len(st.Logger.messages))
						assert.Equal(t, "", st.Logger.data[0]["S2s2TokenId"])
						assert.Equal(t, "Go-http-client/1.1", st.Logger.data[0]["UserAgent"])
					})
				})
			})

			t.Run("logRequest", func(t *testing.T) {
				newEndpointTest := func() *s2sTest {
					return newS2S2Test(func(s2s *S2S2, inner *mocks.Handler) http.Handler {
						return s2s.logRequest(inner)
					}, tc.LambdaName)
				}

				t.Run("success", func(t *testing.T) {
					st := newEndpointTest()
					defer st.Teardown(t)

					st.Handler.
						On("ServeHTTP", mock.Anything, mock.Anything).
						Run(func(args mock.Arguments) {
							// Make the measured processing time strictly positive.
							time.Sleep(time.Millisecond * 5)
						}).
						Return().Once()

					req, err := http.NewRequest("GET", st.Server.URL+"/test?a=1&b=2", nil)
					require.NoError(t, err)

					sub := new(mocks.Subject)
					sub.On("ID").Return("aCaller").Once()
					defer sub.AssertExpectations(t)

					as := &authorizedSubject{
						Subject: sub,
						tokenID: "token",
					}

					req.Header.Set("User-Agent", "aUserAgent")

					// Invoke the middleware directly (no network round trip) with an
					// already-authenticated subject in the request context.
					st.S2S2.logRequest(st.Handler).ServeHTTP(nil, req.WithContext(SetRequestSubject(req.Context(), as)))

					// require (not assert): the lines below index data[0].
					require.Equal(t, 1, len(st.Logger.messages))
					assert.Equal(t, "GET", st.Logger.data[0]["MethodName"])
					assert.Equal(t, st.Server.URL+"/test?a=1&b=2", st.Logger.data[0]["Url"])
					assert.Equal(t, "aUserAgent", st.Logger.data[0]["UserAgent"])

					_, err = time.Parse(time.RFC3339, st.Logger.data[0]["Timestamp"])
					assert.NoError(t, err)

					duration, err := strconv.Atoi(st.Logger.data[0]["RequestProcessingTime"])
					assert.NoError(t, err)
					assert.True(t, duration > 0)

					assert.Equal(t, "aCaller", st.Logger.data[0]["AuthenticatedS2S2CallerURI"])
					assert.Equal(t, "token", st.Logger.data[0]["S2s2TokenId"])
				})
			})

			t.Run("HardRefreshCache", func(t *testing.T) {
				ctx := context.Background()

				t.Run("success", func(t *testing.T) {
					ctx, cancel := context.WithCancel(ctx)
					defer cancel()

					accessTokens := new(mocks.CachedTokens)
					accessTokens.On("HardRefreshCache", ctx).Return(nil)
					defer accessTokens.AssertExpectations(t)

					// Stop the ticker when the subtest ends so it does not keep firing
					// for the rest of the test binary's lifetime.
					ticker := time.NewTicker(time.Millisecond)
					defer ticker.Stop()

					s2s := &S2S2{
						accessTokens:           accessTokens,
						hardRefreshCacheTicker: ticker,
					}
					assert.NoError(t, s2s.HardRefreshCache(ctx))
				})

				t.Run("ctx done", func(t *testing.T) {
					ctx, cancel := context.WithCancel(ctx)
					cancel()

					ticker := time.NewTicker(time.Millisecond)
					defer ticker.Stop()

					s2s := &S2S2{
						hardRefreshCacheTicker: ticker,
					}

					assert.Equal(t, ctx.Err(), s2s.HardRefreshCache(ctx))
				})
			})
		})
	}

	t.Run("ServiceURIFromID", func(t *testing.T) {
		assert.Equal(t, "https://my.origin/my-id", (&S2S2{
			config: &c7s.Config{
				ServiceByIDAuthorityURI: "https://my.origin",
			},
		}).ServiceURIFromID("my-id"))
	})

	t.Run("DistributedIdentitiesServiceURI", func(t *testing.T) {
		assert.Equal(t, "https://my.origin/mydomain/myservice/mystage.json", (&S2S2{
			config: &c7s.Config{
				IdentityOrigin: "https://my.origin",
				ServiceDomain:  "mydomain",
			},
		}).DistributedIdentitiesServiceURI("myservice", "mystage"))
	})
}

// newS2S2Test builds a test fixture around an S2S2 wired to mock
// authorizations, a mock distributed-identities callee and a recording logger.
//
// h receives the S2S2 and a mock inner handler and returns the handler under
// test (for example s2s.RequireAuthentication(inner)); that handler is served
// by a fresh httptest.Server. When lambdaName is non-empty, every incoming
// request's context is additionally given a lambdacontext.LambdaContext whose
// InvokedFunctionArn is lambdaName, simulating a Lambda invocation.
//
// Callers must call Teardown on the returned fixture.
func newS2S2Test(h func(s2s *S2S2, inner *mocks.Handler) http.Handler, lambdaName string) *s2sTest {
	authorizations := new(mocks.AuthorizationsAPI)
	logger := &mockLogger{}
	callee := new(mocks.CalleeAPI)

	s2s := &S2S2{
		clientServiceURI: "https://myauthority/myclientserviceuri",
		authorizations:   authorizations,
		config: &c7s.Config{
			CalleeRealm:         "s2scalleerealm",
			Issuer:              "s2sissuer",
			EnableAccessLogging: true,
		},
		logger:   logger,
		diCallee: callee,
	}

	mockHandler := new(mocks.Handler)
	handler := h(s2s, mockHandler)
	if lambdaName != "" {
		innerHandler := handler
		handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			innerHandler.ServeHTTP(w, r.WithContext(lambdacontext.NewContext(
				r.Context(),
				&lambdacontext.LambdaContext{InvokedFunctionArn: lambdaName},
			)))
		})
	}

	return &s2sTest{
		S2S2:   s2s,
		Server: httptest.NewServer(handler),
		// Bound each request so a hung handler fails the test quickly instead of
		// blocking until the global `go test` timeout.
		Client:         &http.Client{Timeout: 10 * time.Second},
		LambdaName:     lambdaName,
		Authorizations: authorizations,
		Handler:        mockHandler,
		Logger:         logger,
		diCallee:       callee,
	}
}

// s2sTest is the fixture returned by newS2S2Test: an S2S2 under test, an
// httptest server serving the handler built from it, a client to call that
// server, and the mocks the S2S2 was wired with.
type s2sTest struct {
	S2S2       *S2S2
	Server     *httptest.Server
	Client     *http.Client
	LambdaName string // non-empty when requests carry a simulated Lambda context

	Logger         *mockLogger             // records everything the S2S2 logs
	Authorizations *mocks.AuthorizationsAPI // bearer-token validation mock
	Handler        *mocks.Handler           // inner handler wrapped by the code under test
	diCallee       *mocks.CalleeAPI         // distributed-identities callee mock
}

// Teardown verifies every mock expectation registered on the fixture and then
// shuts down the test server. Call it with defer right after creating the
// fixture.
func (st *s2sTest) Teardown(t *testing.T) {
	t.Helper()
	defer st.Server.Close()

	st.Authorizations.AssertExpectations(t)
	st.Handler.AssertExpectations(t)
	// Tests register ValidateAuthentication expectations with Once(); verify
	// them as well so an un-hit distributed-identities call is reported.
	st.diCallee.AssertExpectations(t)
}

// twirpError returns the decoded JSON error body the tests expect: code r
// stored under "code" and msg stored under "msg".
func (st *s2sTest) twirpError(r, msg string) map[string]string {
	body := make(map[string]string, 2)
	body["code"] = r
	body["msg"] = msg
	return body
}

// AssertJSONBody asserts that res declares an "application/json" Content-Type
// and that its body decodes to a flat JSON object of strings equal to
// expected. The response body is consumed by this call.
func (st *s2sTest) AssertJSONBody(t *testing.T, res *http.Response, expected map[string]string) {
	t.Helper()

	// Checked before decoding so a wrong Content-Type is still reported even
	// when the body cannot be decoded.
	assert.Equal(t, "application/json", res.Header.Get("Content-Type"))

	var body map[string]string
	require.NoError(t, json.NewDecoder(res.Body).Decode(&body))
	assert.Equal(t, expected, body)
}

// AssertAuthChallenge parses the authentication challenge from res with
// parseWWWAuthenticateChallenge and asserts that it equals expected.
func (st *s2sTest) AssertAuthChallenge(t *testing.T, res *http.Response, expected *wwwAuthenticateChallenge) {
	t.Helper()

	authChallenge, err := parseWWWAuthenticateChallenge(res)
	require.NoError(t, err)
	assert.Equal(t, expected, authChallenge)
}

// ServerHost returns the value the tests expect the handler to check the
// token audience against. That is the Lambda name when the fixture simulates
// a Lambda transport, and the httptest server URL otherwise.
func (st *s2sTest) ServerHost(t *testing.T) string {
	if lambda := st.LambdaName; lambda != "" {
		return lambda
	}
	return st.Server.URL
}

// mockLogger is an in-memory logger that records every Log call so tests can
// inspect what the code under test logged. Each call appends one entry to
// messages and one to data, so messages[i] and data[i] describe the same call.
type mockLogger struct {
	// messages holds the msg argument of each Log call, in call order.
	messages []string

	// data holds the key/value pairs of each Log call, rendered as strings.
	data []map[string]string
}

// Log records msg and its alternating key/value pairs. Keys and values are
// rendered with fmt.Sprint, so non-string arguments are captured instead of
// panicking inside the handler goroutine; string arguments are stored
// unchanged. A trailing key without a value is recorded with a placeholder.
func (logger *mockLogger) Log(msg string, keyvals ...interface{}) {
	logger.messages = append(logger.messages, msg)

	parameters := make(map[string]string, (len(keyvals)+1)/2)
	for i := 0; i < len(keyvals); i += 2 {
		key := fmt.Sprint(keyvals[i])
		if i+1 < len(keyvals) {
			parameters[key] = fmt.Sprint(keyvals[i+1])
		} else {
			parameters[key] = "There was an odd number of parameters."
		}
	}
	logger.data = append(logger.data, parameters)
}
