package twitchs2smigration

import (
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"code.justin.tv/amzn/TwitchS2S2/s2s2"
	"code.justin.tv/amzn/TwitchS2SLegacyMigrationMiddleware/twitchs2smigration/mocks"
	"code.justin.tv/sse/malachai/pkg/jwtvalidation"
	"code.justin.tv/sse/malachai/pkg/s2s/callee"
	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestOptionsMerge pins how (*Options).merge combines a base Options with an
// override: a zero-valued override keeps the base value, and a non-zero
// override replaces it.
func TestOptionsMerge(t *testing.T) {
	cases := []struct {
		name     string
		base     *Options
		override *Options
		want     *Options
	}{
		{
			name:     "no value",
			base:     &Options{},
			override: &Options{},
			want:     &Options{},
		},
		{
			name:     "no override",
			base:     &Options{LogS2Sv0RequestReceivedRateLimit: time.Minute},
			override: &Options{},
			want:     &Options{LogS2Sv0RequestReceivedRateLimit: time.Minute},
		},
		{
			name:     "overridden",
			base:     &Options{LogS2Sv0RequestReceivedRateLimit: time.Minute},
			override: &Options{LogS2Sv0RequestReceivedRateLimit: time.Hour},
			want:     &Options{LogS2Sv0RequestReceivedRateLimit: time.Hour},
		},
	}

	for _, tc := range cases {
		assert.Equal(t, tc.want, tc.base.merge(tc.override), tc.name)
	}
}

// TestOptionsGetLogS2Sv0RequestReceivedRateLimit checks that the getter
// returns one second when the field is unset and the configured value
// otherwise.
func TestOptionsGetLogS2Sv0RequestReceivedRateLimit(t *testing.T) {
	cases := []struct {
		name string
		opts *Options
		want time.Duration
	}{
		{name: "default", opts: &Options{}, want: time.Second},
		{
			name: "overridden",
			opts: &Options{LogS2Sv0RequestReceivedRateLimit: time.Minute},
			want: time.Minute,
		},
	}

	for _, tc := range cases {
		assert.Equal(t, tc.want, tc.opts.getLogS2Sv0RequestReceivedRateLimit(), tc.name)
	}
}

// TestMiddlewareService checks that Middleware.Service looks up the S2Sv0 ID
// for a service name, maps that ID to an S2S2 URI, and returns the lookup
// error as-is when the S2Sv0 lookup fails.
func TestMiddlewareService(t *testing.T) {
	const (
		serviceName = "serviceName"
		serviceID   = "serviceID"
		serviceURI  = "serviceURI"
	)

	t.Run("success", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		mt := newMiddlewareTest(ctrl)

		mt.S2SV0.EXPECT().ServiceID(serviceName).Return(serviceID, nil)
		mt.S2S2.EXPECT().ServiceURIFromID(serviceID).Return(serviceURI)

		got, err := mt.Middleware.Service(serviceName)
		require.NoError(t, err)

		want := service{ID: serviceURI, Name: serviceName}
		assert.Equal(t, want, got)
	})

	t.Run("error", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		lookupErr := errors.New("myerr")
		mt := newMiddlewareTest(ctrl)
		mt.S2SV0.EXPECT().ServiceID(serviceName).Return("", lookupErr)

		_, err := mt.Middleware.Service(serviceName)
		assert.Equal(t, lookupErr, err)
	})
}

// TestMiddlewareDistributedIdentityService checks that
// Middleware.DistributedIdentityService pairs the given service name and
// stage with the URI reported by the S2S2 client.
func TestMiddlewareDistributedIdentityService(t *testing.T) {
	const (
		serviceName  = "serviceName"
		serviceStage = "serviceStage"
		serviceURI   = "serviceURI"
	)

	t.Run("success", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		mt := newMiddlewareTest(ctrl)
		mt.S2S2.EXPECT().
			DistributedIdentitiesServiceURI(serviceName, serviceStage).
			Return(serviceURI)

		got := mt.Middleware.DistributedIdentityService(serviceName, serviceStage)

		want := distributedIdentityService{
			Service: serviceName,
			Stage:   serviceStage,
			ID:      serviceURI,
		}
		assert.Equal(t, want, got)
	})
}

// TestRequireAuthentication exercises Middleware.RequireAuthentication for
// callers authenticated via S2S2, callers authenticated via S2Sv0, requests
// with no authentication, and S2Sv0 requests whose caller is unidentified.
// Each subtest serves the wrapped handler through a real httptest server.
//
// The inner handlers run on the server's goroutine, not the test goroutine,
// so they report problems with t.Error / assert.* rather than t.Fatal:
// testing.T.FailNow must only be called from the goroutine running the test.
func TestRequireAuthentication(t *testing.T) {
	const serviceID = "serviceID"
	const serviceName = "serviceName"
	const serviceURI = "serviceURI"

	// expectS2S2 makes the S2S2 mock's RequireAuthentication return a handler
	// whose PassthroughIfAuthorizationNotPresented forwards every request to
	// the wrapped handler. When subj is non-nil it is attached to the request
	// context as the S2S2 request subject, simulating an S2S2-authenticated
	// caller; when nil the request is forwarded untouched.
	expectS2S2 := func(ctrl *gomock.Controller, test *middlewareTest, subj *mocks.MockAuthorizedSubject) {
		test.S2S2.EXPECT().RequireAuthentication(gomock.Any()).
			DoAndReturn(func(h http.Handler) s2s2.Handler {
				wrapper := mocks.NewMockHandler(ctrl)
				wrapped := mocks.NewMockHandler(ctrl)
				wrapper.EXPECT().PassthroughIfAuthorizationNotPresented().Return(wrapped)
				wrapped.EXPECT().ServeHTTP(gomock.Any(), gomock.Any()).
					Do(func(w http.ResponseWriter, r *http.Request) {
						if subj != nil {
							r = r.WithContext(s2s2.SetRequestSubject(r.Context(), subj))
						}
						h.ServeHTTP(w, r)
					})
				return wrapper
			})
	}

	// expectS2SV0 makes the S2Sv0 mock's RequestValidatorPassthroughMiddleware
	// forward every request to the wrapped handler. When caller is non-nil it
	// is stored as the request's S2Sv0 caller ID via callee.SetCallerID; when
	// nil the wrapped handler is returned unchanged.
	expectS2SV0 := func(test *middlewareTest, caller *jwtvalidation.SigningEntity) {
		test.S2SV0.EXPECT().RequestValidatorPassthroughMiddleware(gomock.Any()).
			DoAndReturn(func(h http.Handler) http.Handler {
				if caller == nil {
					return h
				}
				return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					h.ServeHTTP(w, r.WithContext(callee.SetCallerID(r.Context(), caller)))
				})
			})
	}

	// unreachable is the inner handler for cases the middleware must reject.
	// It uses t.Error (not t.Fatal) because it runs on the server goroutine;
	// the subsequent status-code assertion then fails on the test goroutine.
	unreachable := func(t *testing.T) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			t.Error("we expect the request to be rejected before reaching the handler")
		})
	}

	t.Run("authenticated via s2s2", func(t *testing.T) {
		t.Run("authorized", func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			requestSubject := mocks.NewMockAuthorizedSubject(ctrl)
			requestSubject.EXPECT().ID().Return(serviceID)
			test := newMiddlewareTest(ctrl)
			expectS2S2(ctrl, test, requestSubject)
			expectS2SV0(test, nil)

			h := test.Middleware.RequireAuthentication(
				http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					assert.Equal(t, subject(serviceID), RequestSubject(r.Context()))
				}),
				&RequireAuthenticationOptions{
					AuthorizedCallers: []Service{
						service{ID: serviceID},
					},
				})

			srv := httptest.NewServer(h)
			defer srv.Close()

			res, err := srv.Client().Get(srv.URL)
			require.NoError(t, err)
			defer res.Body.Close()
			assert.Equal(t, http.StatusOK, res.StatusCode)
		})

		t.Run("not authorized", func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			requestSubject := mocks.NewMockAuthorizedSubject(ctrl)
			requestSubject.EXPECT().ID().Return(serviceID)
			test := newMiddlewareTest(ctrl)
			expectS2S2(ctrl, test, requestSubject)
			expectS2SV0(test, nil)

			// The caller is not in AuthorizedCallers, so the inner handler must
			// never run.
			h := test.Middleware.RequireAuthentication(
				unreachable(t),
				&RequireAuthenticationOptions{
					AuthorizedCallers: []Service{
						service{ID: "somethingElse"},
					},
				})

			srv := httptest.NewServer(h)
			defer srv.Close()

			res, err := srv.Client().Get(srv.URL)
			require.NoError(t, err)
			defer res.Body.Close()
			test.AssertErrorCode(t, res, http.StatusForbidden, "permission_denied")
		})
	})

	t.Run("authenticated via s2sv0", func(t *testing.T) {
		t.Run("authorized and logged", func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			test := newMiddlewareTest(ctrl)
			expectS2S2(ctrl, test, nil)
			test.S2S2.EXPECT().ServiceURIFromID(serviceID).Return(serviceURI)
			expectS2SV0(test, &jwtvalidation.SigningEntity{Caller: serviceID})

			// Pre-load one tick so the rate limiter lets exactly one log through.
			logS2Sv0RequestReceivedRateLimiter := make(chan time.Time, 1)
			logS2Sv0RequestReceivedRateLimiter <- time.Now()
			test.Middleware.logger.LogS2Sv0RequestReceivedRateLimiter = logS2Sv0RequestReceivedRateLimiter
			test.Logger.EXPECT().Log("LegacyS2Sv0CallerDetected", "ServiceName", serviceName)

			h := test.Middleware.RequireAuthentication(
				http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					assert.Equal(t, s2sv0Subject{subject: subject(serviceURI)}, RequestSubject(r.Context()))
				}),
				&RequireAuthenticationOptions{
					AuthorizedCallers: []Service{
						service{ID: serviceURI, Name: serviceName},
					},
				})

			srv := httptest.NewServer(h)
			defer srv.Close()

			res, err := srv.Client().Get(srv.URL)
			require.NoError(t, err)
			defer res.Body.Close()
			assert.Equal(t, http.StatusOK, res.StatusCode)
		})

		t.Run("authorized and not logged", func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			test := newMiddlewareTest(ctrl)
			expectS2S2(ctrl, test, nil)
			test.S2S2.EXPECT().ServiceURIFromID(serviceID).Return(serviceURI)
			expectS2SV0(test, &jwtvalidation.SigningEntity{Caller: serviceID})

			// No Logger expectation is set: any Log call fails the mock.
			h := test.Middleware.RequireAuthentication(
				http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					assert.Equal(t, s2sv0Subject{subject: subject(serviceURI)}, RequestSubject(r.Context()))
				}),
				&RequireAuthenticationOptions{
					AuthorizedCallers: []Service{
						service{ID: serviceURI, Name: serviceName},
					},
				})

			srv := httptest.NewServer(h)
			defer srv.Close()

			res, err := srv.Client().Get(srv.URL)
			require.NoError(t, err)
			defer res.Body.Close()
			assert.Equal(t, http.StatusOK, res.StatusCode)
		})

		t.Run("not authorized", func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			test := newMiddlewareTest(ctrl)
			expectS2S2(ctrl, test, nil)
			test.S2S2.EXPECT().ServiceURIFromID(serviceID).Return(serviceURI)
			expectS2SV0(test, &jwtvalidation.SigningEntity{Caller: serviceID})

			h := test.Middleware.RequireAuthentication(
				unreachable(t),
				&RequireAuthenticationOptions{
					AuthorizedCallers: []Service{
						service{ID: "somethingElse"},
					},
				})

			srv := httptest.NewServer(h)
			defer srv.Close()

			res, err := srv.Client().Get(srv.URL)
			require.NoError(t, err)
			defer res.Body.Close()
			test.AssertErrorCode(t, res, http.StatusForbidden, "permission_denied")
		})
	})

	t.Run("no authentication", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newMiddlewareTest(ctrl)
		expectS2S2(ctrl, test, nil)
		expectS2SV0(test, nil)

		h := test.Middleware.RequireAuthentication(
			unreachable(t),
			&RequireAuthenticationOptions{
				AuthorizedCallers: []Service{
					service{ID: serviceID},
				},
			})

		srv := httptest.NewServer(h)
		defer srv.Close()

		res, err := srv.Client().Get(srv.URL)
		require.NoError(t, err)
		defer res.Body.Close()
		test.AssertErrorCode(t, res, http.StatusUnauthorized, "unauthenticated")
	})

	t.Run("s2sv0 is unidentified", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newMiddlewareTest(ctrl)
		expectS2S2(ctrl, test, nil)
		expectS2SV0(test, &jwtvalidation.SigningEntity{Caller: callee.UnidentifiedCallerID})

		h := test.Middleware.RequireAuthentication(
			unreachable(t),
			&RequireAuthenticationOptions{
				AuthorizedCallers: []Service{
					service{ID: serviceID},
				},
			})

		srv := httptest.NewServer(h)
		defer srv.Close()

		res, err := srv.Client().Get(srv.URL)
		require.NoError(t, err)
		defer res.Body.Close()
		test.AssertErrorCode(t, res, http.StatusUnauthorized, "unauthenticated")
	})
}

// middlewareTest bundles a Middleware under test with the gomock doubles that
// newMiddlewareTest injects into it. Middleware is the instance under test.
// S2SV0 and S2S2 are the mocks wired into its s2sv0 and s2s2 dependencies.
// Logger is the mock wrapped by its logger, so tests can set expectations on
// each dependency directly.
type middlewareTest struct {
	Middleware *Middleware
	S2SV0      *mocks.MockClientAPI
	S2S2       *mocks.MockS2S2API
	Logger     *mocks.MockLogger
}

// newMiddlewareTest creates fresh S2Sv0, S2S2 and Logger mocks on ctrl and
// returns them together with a Middleware wired to use them.
func newMiddlewareTest(ctrl *gomock.Controller) *middlewareTest {
	mt := &middlewareTest{
		S2SV0:  mocks.NewMockClientAPI(ctrl),
		S2S2:   mocks.NewMockS2S2API(ctrl),
		Logger: mocks.NewMockLogger(ctrl),
	}
	mt.Middleware = &Middleware{
		s2sv0:  mt.S2SV0,
		s2s2:   mt.S2S2,
		logger: &logger{Logger: mt.Logger},
	}
	return mt
}

// AssertErrorCode asserts that r has status httpCode and that its body decodes
// into a jsonError whose Code equals code. A status mismatch or an undecodable
// body stops the test immediately (require); a code mismatch is recorded but
// lets the test continue (assert).
func (middlewareTest) AssertErrorCode(t *testing.T, r *http.Response, httpCode int, code string) {
	t.Helper()

	require.Equal(t, httpCode, r.StatusCode, "unexpected HTTP status")

	// Named body rather than jsonError so the local does not shadow the type.
	var body jsonError
	require.NoError(t, json.NewDecoder(r.Body).Decode(&body), "decoding error response body")
	assert.Equal(t, code, body.Code, "unexpected error code")
}
