//nolint:bodyclose
package s2s2

import (
	"bytes"
	"errors"
	"io"
	"io/ioutil"
	"net/http"
	"testing"
	"time"

	"code.justin.tv/amzn/TwitchS2S2/internal/token"
	"code.justin.tv/amzn/TwitchS2S2/s2s2/mocks"
	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// TestAuthenticatingHTTPClient exercises authenticatingHTTPClient.Do against
// mocked token (AccessTokens), transport (Inner) and logger dependencies.
// Every scenario runs once per entry in tcs, so both the LambdaName path and
// the plain-host path are covered.
func TestAuthenticatingHTTPClient(t *testing.T) {
	t.Run("Do", func(t *testing.T) {
		const host = "http://targethost"
		const testBody = "testbody"

		// testBodyReader returns a fresh reader over testBody for each request.
		testBodyReader := func() io.ReadCloser {
			return ioutil.NopCloser(bytes.NewReader([]byte(testBody)))
		}

		// assertRequestBody reads the whole body of res and checks it equals
		// expected. It consumes the body, so it can only be called once per
		// request value.
		assertRequestBody := func(t *testing.T, res *http.Request, expected string) {
			bs, err := ioutil.ReadAll(res.Body)
			require.NoError(t, err)
			assert.Equal(t, expected, string(bs))
		}

		// ExpectedHost is the host the client must put into the token Options:
		// the LambdaName when one is configured, otherwise the `host` constant
		// that every request below is sent to.
		tcs := []struct {
			LambdaName   string
			ExpectedHost string
		}{
			{
				LambdaName:   "my-lambda-arn",
				ExpectedHost: "my-lambda-arn",
			},
			{
				ExpectedHost: host,
			},
		}

		for _, tc := range tcs {
			// onToken registers the expectation for the client's first Token
			// call: options carrying only the expected host, matched once.
			onToken := func(rtt *authenticatingHTTPClientTest) *mock.Call {
				return rtt.AccessTokens.On("Token", mock.Anything, token.NewOptions().WithHost(tc.ExpectedHost)).Once()
			}

			t.Run("LambdaName: "+tc.LambdaName, func(t *testing.T) {
				// Happy path: Inner receives the original body plus a Bearer
				// header built from the token, and its response is returned
				// unchanged.
				t.Run("success", func(t *testing.T) {
					ctrl := gomock.NewController(t)
					defer ctrl.Finish()
					rtt := newAuthenticatingHTTPClientTest(tc.LambdaName, ctrl)
					defer rtt.Teardown(t)

					authorizationToken := []byte("authorizationtoken")
					accessToken := &token.Token{
						AccessToken: string(authorizationToken),
					}

					onToken(rtt).Return(accessToken, nil)

					req, err := http.NewRequest("POST", host, testBodyReader())
					require.NoError(t, err)

					innerRes := &http.Response{StatusCode: http.StatusOK}

					rtt.Inner.
						On("Do", mock.Anything).
						Run(func(args mock.Arguments) {
							req := args.Get(0).(*http.Request)
							assert.Equal(t, "Bearer "+string(authorizationToken), req.Header.Get("Authorization"))
							assertRequestBody(t, req, testBody)
						}).
						Return(innerRes, nil)

					// Each scenario expects exactly one Log call.
					rtt.Logger.EXPECT().Log(gomock.Any()).Return().Times(1)

					res, err := rtt.AuthenticatingHTTPClient.Do(req)
					require.NoError(t, err)
					assert.Equal(t, innerRes, res)
				})

				// Token fetch fails: Do must return that exact error. No Inner
				// expectation is registered, so any call to Inner fails the test.
				t.Run("authorization failure", func(t *testing.T) {
					ctrl := gomock.NewController(t)
					defer ctrl.Finish()
					rtt := newAuthenticatingHTTPClientTest(tc.LambdaName, ctrl)
					defer rtt.Teardown(t)

					authorizationErr := errors.New("authorization error")
					onToken(rtt).Return(nil, authorizationErr)

					req, err := http.NewRequest("POST", host, testBodyReader())
					require.NoError(t, err)

					rtt.Logger.EXPECT().Log(gomock.Any()).Return().Times(1)

					_, err = rtt.AuthenticatingHTTPClient.Do(req)
					assert.Equal(t, authorizationErr, err)
				})

				// Inner.Do fails after authentication succeeded: its error must be
				// propagated unchanged.
				t.Run("inner round trip failure", func(t *testing.T) {
					ctrl := gomock.NewController(t)
					defer ctrl.Finish()
					rtt := newAuthenticatingHTTPClientTest(tc.LambdaName, ctrl)
					defer rtt.Teardown(t)

					authorizationToken := []byte("authorizationtoken")

					accessToken := &token.Token{AccessToken: string(authorizationToken)}

					onToken(rtt).Return(accessToken, nil)

					req, err := http.NewRequest("POST", host, testBodyReader())
					require.NoError(t, err)

					innerErr := errors.New("innererr")

					rtt.Inner.
						On("Do", mock.Anything).
						Run(func(args mock.Arguments) {
							req := args.Get(0).(*http.Request)
							assert.Equal(t, "Bearer "+string(authorizationToken), req.Header.Get("Authorization"))
							assertRequestBody(t, req, testBody)
						}).
						Return(nil, innerErr)

					rtt.Logger.EXPECT().Log(gomock.Any()).Return().Times(1)

					_, err = rtt.AuthenticatingHTTPClient.Do(req)
					assert.Equal(t, innerErr, err)
				})

				// Inner answers 401 with a WWW-Authenticate challenge naming
				// scope1 and scope2. The client must request a token covering
				// those scopes; that request fails, and its error must be
				// returned.
				t.Run("401 - error on requesting additional scopes", func(t *testing.T) {
					ctrl := gomock.NewController(t)
					defer ctrl.Finish()
					rtt := newAuthenticatingHTTPClientTest(tc.LambdaName, ctrl)
					defer rtt.Teardown(t)

					authorizationToken := []byte("authorizationtoken")

					accessToken := &token.Token{
						AccessToken: string(authorizationToken),
						Scope:       make(token.Scope),
					}
					onToken(rtt).Return(accessToken, nil)

					req, err := http.NewRequest("POST", host, testBodyReader())
					require.NoError(t, err)

					initialResponse := &http.Response{StatusCode: http.StatusUnauthorized, Header: make(http.Header)}
					initialResponse.Header.Set("WWW-Authenticate", "Bearer error=\"invalid_token\", scope=\"scope1 scope2\"")

					rtt.Inner.
						On("Do", mock.Anything).
						Run(func(args mock.Arguments) {
							req := args.Get(0).(*http.Request)
							assert.Equal(t, "Bearer "+string(authorizationToken), req.Header.Get("Authorization"))
							assertRequestBody(t, req, testBody)
						}).
						Return(initialResponse, nil).
						Once()

					addScopesErr := errors.New("add scopes error")
					// NOTE(review): this matcher only checks the requested scopes;
					// unlike the "then success" case below it does not pin the host
					// or NoCache on the retry's token options.
					rtt.AccessTokens.On("Token", mock.Anything, mock.MatchedBy(func(o *token.Options) bool {
						return o.Scope().Contains(token.Scope{
							"scope1": nil,
							"scope2": nil,
						})
					})).Return(nil, addScopesErr).Once()

					rtt.Logger.EXPECT().Log(gomock.Any()).Return().Times(1)

					_, err = rtt.AuthenticatingHTTPClient.Do(req)
					assert.Equal(t, addScopesErr, err)
				})

				// Inner answers 401 challenging scope1. The client must fetch a
				// second token with NoCache, the expected host and scope1, then
				// retry with the new bearer token and the same body. The retry
				// therefore requires the body to be replayable.
				t.Run("401 - then success", func(t *testing.T) {
					ctrl := gomock.NewController(t)
					defer ctrl.Finish()
					rtt := newAuthenticatingHTTPClientTest(tc.LambdaName, ctrl)
					defer rtt.Teardown(t)

					authorizationToken := []byte("authorizationtoken")

					accessToken := &token.Token{
						AccessToken: string(authorizationToken),
						Scope:       make(token.Scope),
					}

					// NOTE(review): onToken already applies Once(); the second
					// Once() here is redundant but harmless.
					onToken(rtt).Return(accessToken, nil).Once()

					// NOTE(review): a GET carrying a body is unusual. It is kept
					// here so the retry's body replay is still asserted.
					req, err := http.NewRequest("GET", host, testBodyReader())
					require.NoError(t, err)

					initialResponse := &http.Response{StatusCode: http.StatusUnauthorized, Header: make(http.Header)}
					initialResponse.Header.Set("WWW-Authenticate", "Bearer scope=\"scope1\", error=\"invalid_token\"")

					// Both Inner "Do" expectations below match any request. Each
					// is Once(), so testify serves them in registration order:
					// first the 401, then the 200. Do not reorder them.
					rtt.Inner.
						On("Do", mock.Anything).
						Run(func(args mock.Arguments) {
							req := args.Get(0).(*http.Request)
							assert.Equal(t, "Bearer "+string(authorizationToken), req.Header.Get("Authorization"))
							assertRequestBody(t, req, testBody)
						}).
						Return(initialResponse, nil).
						Once()

					authorizationToken2 := []byte("authorizationtoken2")
					accessToken2 := &token.Token{
						AccessToken: string(authorizationToken2),
						Scope:       make(token.Scope),
					}

					options := &token.Options{
						NoCache: true,
					}
					rtt.AccessTokens.On("Token", mock.Anything, options.
						WithHost(tc.ExpectedHost).WithScope(token.ParseScope("scope1"))).
						Return(accessToken2, nil).Once()

					finalResponse := &http.Response{StatusCode: http.StatusOK}
					rtt.Inner.
						On("Do", mock.Anything).
						Run(func(args mock.Arguments) {
							req := args.Get(0).(*http.Request)
							assert.Equal(t, "Bearer "+string(authorizationToken2), req.Header.Get("Authorization"))
							assertRequestBody(t, req, testBody)
						}).
						Return(finalResponse, nil).
						Once()

					rtt.Logger.EXPECT().Log(gomock.Any()).Return().Times(1)

					res, err := rtt.AuthenticatingHTTPClient.Do(req)
					require.NoError(t, err)
					assert.Equal(t, finalResponse, res)
				})
			})
		}
	})
}

// newAuthenticatingHTTPClientTest builds an authenticatingHTTPClient wired to
// fresh mocks so that each subtest can register its own expectations.
//
// lambdaName is copied into the client's LambdaName field. ctrl owns the
// gomock-generated Logger mock and verifies it. The testify mocks (Inner,
// AccessTokens) are verified by Teardown.
//
// The client is given a one-hour ticker as its rate limiter, so the ticker will
// not fire during a test. Callers must call Teardown, which also stops that
// ticker.
func newAuthenticatingHTTPClientTest(lambdaName string, ctrl *gomock.Controller) *authenticatingHTTPClientTest {
	inner := new(mocks.HTTPClient)
	accessTokens := new(mocks.Tokens)
	logger := mocks.NewMockLogger(ctrl)
	rateLimiter := time.NewTicker(time.Hour)

	return &authenticatingHTTPClientTest{
		AuthenticatingHTTPClient: &authenticatingHTTPClient{
			Inner:        inner,
			AccessTokens: accessTokens,
			LambdaName:   lambdaName,
			logger:       logger,
			rateLimiter:  rateLimiter,
		},
		AccessTokens: accessTokens,
		Inner:        inner,
		Logger:       logger,
		rateLimiter:  rateLimiter,
	}
}

// authenticatingHTTPClientTest bundles the client under test with the mocks it
// was built from, so tests can set expectations on them and verify them.
type authenticatingHTTPClientTest struct {
	AuthenticatingHTTPClient *authenticatingHTTPClient
	AccessTokens             *mocks.Tokens
	Inner                    *mocks.HTTPClient
	Logger                   *mocks.MockLogger

	// rateLimiter is the same ticker handed to the client. It is kept here so
	// Teardown can stop it; an unstopped ticker is never released before Go 1.23.
	rateLimiter *time.Ticker
}

// Teardown checks that every testify expectation registered on AccessTokens
// and Inner was met, then stops the client's rate-limit ticker. The gomock
// Logger is verified separately by its gomock.Controller.
func (rtt *authenticatingHTTPClientTest) Teardown(t *testing.T) {
	t.Helper()
	// Deferred so the ticker is stopped even when an assertion below fails.
	defer func() {
		if rtt.rateLimiter != nil {
			rtt.rateLimiter.Stop()
		}
	}()
	rtt.AccessTokens.AssertExpectations(t)
	rtt.Inner.AssertExpectations(t)
}
