/* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. */

package cloudauth

import (
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"
)

// testInvalidErr fails the calling test immediately unless err is a non-nil
// *InvalidError whose message is non-empty. It is marked as a helper so that
// failures are reported at the caller's line.
func testInvalidErr(t *testing.T, err error) {
	t.Helper()

	if err == nil {
		t.Fatal("did not get expected error")
	}

	if err.Error() == "" {
		t.Fatal("empty error message")
	}

	// A plain type assertion is used on purpose: only an unwrapped *InvalidError
	// is accepted.
	if _, isInvalid := err.(*InvalidError); !isInvalid {
		t.Fatalf("un-expected error type: %T", err)
	}
}

// TestNewAuthSession_EmptyAuthServerTokenProvider tests when AuthServerTokenProvider
// is a nil interface.
func TestNewAuthSession_EmptyAuthServerTokenProvider(t *testing.T) {
	_, err := NewAuthSession(nil)

	testInvalidErr(t, err)
}

// TestNewAuthSession_InvalidIssuerURL tests on invalid Issuer URLs.
func TestNewAuthSession_InvalidIssuerURL(t *testing.T) {
	// invalidURLCase pairs a subtest name with an Auth Server URL that
	// NewAuthSession is expected to refuse.
	type invalidURLCase struct {
		name string
		url  string
	}

	cases := []invalidURLCase{
		{name: "Empty URL", url: ""},
		{name: "URL with a non-numerical port number", url: "http://[::1]:namedport"},
		{name: "URL with a negative port number", url: "https://127.0.0.1:-1"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			session, err := NewAuthSession(&mockAuthServerToken{}, withAuthServerURL(tc.url))

			// A rejected URL must not produce a session at all.
			if session != nil {
				t.Error("got non-nil response")
			}

			switch {
			case err == nil:
				t.Fatal("did not get expected error")
			case err.Error() == "":
				t.Error("empty error message")
			}
		})
	}
}

// TestNewAuthSession is a test suite of:
//
// 1. Auth Server discovery.
// 2. Client Identity Assertion.
// 3. Acquiring access token for Auth Server.
//
// TestNewAuthSession exercises session bootstrap against a mock Cloud Auth
// server, covering:
//
//  1. Auth Server discovery.
//  2. Client Identity Assertion.
//  3. Acquiring access token for Auth Server.
//
// Each case configures the mock server to misbehave at a given stage, or not at
// all. The case then checks whether an error surfaces, either from
// NewAuthSession itself or from the first token request triggered via
// NewClient.
func TestNewAuthSession(t *testing.T) {
	for _, tc := range []struct {
		description     string
		authServerToken AuthServerTokenProvider
		expectError     bool
		behaviour       serverBehaviour
	}{
		{
			description:     "Discovery - Bogus Response",
			authServerToken: &mockAuthServerToken{},
			expectError:     true,
			behaviour:       serverBehaviour{state: stateInit, response: responseBogus},
		},
		{
			description:     "Discovery - 404 Response",
			authServerToken: &mockAuthServerToken{},
			expectError:     true,
			behaviour:       serverBehaviour{state: stateInit, response: responseInvalidStatusCode},
		},
		{
			description:     "Discovery - I/O Error",
			authServerToken: &mockAuthServerToken{},
			expectError:     true,
			behaviour:       serverBehaviour{state: stateInit, response: responseIoError},
		},
		{
			description:     "Bogus Client Identity",
			authServerToken: &mockBogusAuthServerToken{},
			expectError:     true,
		},
		{
			description:     "Token Request (Auth Server) - Bogus Response",
			authServerToken: &mockAuthServerToken{},
			expectError:     true,
			behaviour:       serverBehaviour{state: stateDiscovered, response: responseBogus},
		},
		{
			description:     "Token Request (Auth Server) - 404 Response",
			authServerToken: &mockAuthServerToken{},
			expectError:     true,
			behaviour:       serverBehaviour{state: stateDiscovered, response: responseInvalidStatusCode},
		},
		{
			description:     "Token Request (Auth Server) - I/O Error",
			authServerToken: &mockAuthServerToken{},
			expectError:     true,
			behaviour:       serverBehaviour{state: stateDiscovered, response: responseIoError},
		},
		{
			description:     "Token Request (Auth Server) - Valid Response",
			authServerToken: &mockAuthServerToken{},
		},
	} {
		t.Run(tc.description, func(t *testing.T) {
			server := mockCloudAuthServer{t: t, behaviour: tc.behaviour, introspectionScope: mockServiceScope}
			ts := httptest.NewTLSServer(server.Handler())
			defer ts.Close()

			// ts.Client() always returns the same *http.Client, so configuring it
			// here affects the client handed to the session below.
			client := ts.Client()
			client.Timeout = time.Minute

			as, err := NewAuthSession(tc.authServerToken, withAuthServerURL(ts.URL), withAuthSessionClient(client))

			// A case that expects success must yield a usable session immediately.
			if !tc.expectError {
				if err != nil {
					t.Fatalf("got unexpected error %+v", err)
				}
				if as == nil {
					t.Fatalf("got unexpected nil authSession")
				}
			}

			// Any returned session must advertise a token endpoint. Creating a
			// client then triggers the first token request, and its error
			// replaces the one from NewAuthSession.
			if as != nil {
				if as.TokenEndpoint() == "" {
					t.Fatal("got empty Token endpoint")
				}
				_, err = NewClient(
					ResourceParameter{
						URL:   mockServiceURL,
						Scope: mockServiceScope,
						Realm: mockServiceName,
					}, as)
			}

			switch {
			case !tc.expectError:
				if err != nil {
					t.Errorf("got unexpected error %+v", err)
				}
			case err == nil:
				t.Error("did not get expected error")
			case err.Error() == "":
				t.Error("empty error message")
			}
		})
	}
}

// TestIntrospect_EmptyToken tests when the access token is empty.
func TestIntrospect_EmptyToken(t *testing.T) {
	// A zero-valued session is enough here: an empty token must be rejected
	// with an *InvalidError.
	_, err := new(authSession).Introspect("")
	testInvalidErr(t, err)
}

// TestAuthSession_Introspect tests different scenarios for token introspection.
// TestAuthSession_Introspect tests different scenarios for token introspection.
//
// Each case configures the mock auth server with a response behaviour and an
// introspection expiry, then proceeds in three steps:
//
//  1. One warm-up Introspect call.
//  2. goRoutineCount concurrent Introspect calls.
//  3. A check of how many introspection requests reached the auth server.
//
// The final count shows whether the session served results from its cache.
func TestAuthSession_Introspect(t *testing.T) {
	const goRoutineCount = 10

	tests := []struct {
		name         string
		cleaningRate time.Duration // passed to WithCleaningRate
		behaviour    serverBehaviour
		active       bool  // whether every Introspect result must report Active
		expiry       int64 // Unix expiry time returned by the mock server
		success      bool  // whether every Introspect call must return a nil error
		// Number of introspection requests the mock server must observe in total.
		expectedCallsToAuthServer int
	}{
		{name: "Bogus response",
			behaviour: serverBehaviour{state: stateAuthenticated, response: responseBogus},
			success:   false},
		{name: "404 response",
			behaviour: serverBehaviour{state: stateAuthenticated, response: responseInvalidStatusCode},
			success:   false},
		{name: "I/O Error",
			behaviour: serverBehaviour{state: stateAuthenticated, response: responseIoError},
			success:   false},
		// With an already-expired token, every call (the warm-up plus each
		// goroutine) is expected to reach the auth server.
		{name: "Valid response expired token",
			expiry:                    time.Now().Unix() - 1,
			cleaningRate:              1 * time.Nanosecond,
			expectedCallsToAuthServer: goRoutineCount + 1,
			success:                   false},
		// With an active token, only the warm-up call is expected to reach the
		// auth server.
		{name: "Valid response active token",
			expiry:                    time.Now().Unix() + 60,
			cleaningRate:              1 * time.Minute,
			expectedCallsToAuthServer: 1,
			active:                    true,
			success:                   true},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			authServer := mockCloudAuthServer{t: t, behaviour: test.behaviour, introspectionExpiry: test.expiry, introspectionScope: mockServiceScope}
			ts := httptest.NewTLSServer(authServer.Handler())
			defer ts.Close()
			ts.Client().Timeout = 1 * time.Minute

			as, err := NewAuthSession(&mockAuthServerToken{}, withAuthServerURL(ts.URL), withAuthSessionClient(ts.Client()), WithCleaningRate(test.cleaningRate))
			if err != nil || as == nil {
				t.Fatalf("failed to create a new auth session. Got auth session %v and err %+v", as, err)
			}

			// This test checks, among other things, that introspection results
			// are properly cached. Without warming the cache before starting the
			// goroutines, the test would race and might never hit the cache. This
			// first call warms the cache where applicable.
			resp, err := as.Introspect("resource_server_token")
			if test.success != (err == nil) {
				t.Fatalf("Expected success %t on the first call to introspect, but got error %+v", test.success, err)
			}
			if test.active && (resp == nil || !resp.Active) {
				t.Fatal("Expected the token to be active, but it was not.")
			}

			// Each goroutine sends exactly one value (nil on success). The channel
			// is buffered to goRoutineCount so no sender blocks.
			errsChan := make(chan error, goRoutineCount)
			for c := 0; c < goRoutineCount; c++ {
				go func(as AuthSessionProvider, success bool, active bool, errsChan chan error) {
					resp, err := as.Introspect("resource_server_token")
					switch {
					case success != (err == nil):
						errsChan <- fmt.Errorf("Expected success %t but got error %+v", success, err)
					case active && (resp == nil || !resp.Active):
						errsChan <- errors.New("Expected the token to stay active, but it did not.")
					default:
						errsChan <- nil
					}
				}(as, test.success, test.active, errsChan)
			}

			var errs []error
			for i := 0; i < goRoutineCount; i++ {
				if err := <-errsChan; err != nil {
					errs = append(errs, err)
				}
			}
			if len(errs) != 0 {
				t.Fatalf("got %d errors %v", len(errs), errs)
			}

			if test.expectedCallsToAuthServer != authServer.introspectionCount {
				t.Fatalf("Expected %d calls to the auth server for introspection, but got %d", test.expectedCallsToAuthServer, authServer.introspectionCount)
			}
		})
	}
}

// TestAuthSession_AuthorizeRequest tests different scenarios for authorizing requests securely
// TestAuthSession_AuthorizeRequest tests different scenarios for authorizing
// requests securely.
//
// Plain-text requests are expected to be rejected. The one exception is the
// accepting case below: it comes from the loopback address and its last
// Forwarded header declares proto=https.
//
// Secure requests are expected to be allowed, challenged, or denied. The outcome
// depends on their Authorization header and on the scope and expiry returned by
// the mock introspection endpoint. Any non-Allow result must carry a Bearer
// challenge.
func TestAuthSession_AuthorizeRequest(t *testing.T) {
	const (
		plainTextURL = "http://example.com/foo"
		secureURL    = "https://example.com/foo"
		loopbackAddr = "127.0.0.1:9999"
		bearerToken  = "Bearer resource_server_token"
		action1Scope = "https://aaa.amazon.com/scopes/MockService#Action1"
		action2Scope = "https://aaa.amazon.com/scopes/MockService#Action2"
	)

	activeExpiry := time.Now().Unix() + 60
	expiredExpiry := time.Now().Unix() - 1

	// newLoopbackRequest builds a plain-text GET request that appears to come
	// from the loopback address. It adds each Forwarded value in the given order
	// (the order matters to the cases below). When withToken is true it also
	// sets the resource server bearer token.
	newLoopbackRequest := func(withToken bool, forwarded ...string) *http.Request {
		r := httptest.NewRequest(http.MethodGet, plainTextURL, nil)
		r.RemoteAddr = loopbackAddr
		for _, f := range forwarded {
			r.Header.Add("Forwarded", f)
		}
		if withToken {
			r.Header.Set("Authorization", bearerToken)
		}
		return r
	}

	plainTextRequestWithForwardedHeader := newLoopbackRequest(true, "for=192.168.1.1", "proto=https")

	plainTextRequestWithForwardedHeaderThatContainsHTTP := newLoopbackRequest(true, "for=192.168.1.1", "proto=http", "proto=https")

	// The Forwarded header alone would look secure. The insecure
	// X-Forwarded-Proto value must still cause a rejection.
	plainTextRequestWithXForwardedProtoHeaderThatContainsHTTP := newLoopbackRequest(true, "for=192.168.1.1", "proto=https")
	plainTextRequestWithXForwardedProtoHeaderThatContainsHTTP.Header.Add("X-Forwarded-Proto", "http")

	plainTextRequestWithLastForwardedHeaderDoesNotContainHTTPS := newLoopbackRequest(true, "proto=https", "for=192.168.1.1")

	plainTextRequestWithForwardedHeaderWithoutProto := newLoopbackRequest(false, "for=192.168.1.1")

	plainTextRequestWithoutForwardedHeader := newLoopbackRequest(false)

	plainTextRequestWithMalformedForwardedHeader := newLoopbackRequest(false, "njdsagjadsgdasgn")

	plainTextRequestWithForwardedHeaderWithoutHTTPS := newLoopbackRequest(true, "proto=http")

	secureRequest := httptest.NewRequest(http.MethodGet, secureURL, nil)
	secureRequest.Header.Set("Authorization", bearerToken)

	secureRequestMalformedAuthorizationHeader := httptest.NewRequest(http.MethodGet, secureURL, nil)
	secureRequestMalformedAuthorizationHeader.Header.Set("Authorization", "malformed")

	tests := []struct {
		name                string
		request             *http.Request
		service             string
		operation           string
		introspectionScope  string
		introspectionExpiry int64
		success             bool
		result              Result
		introspectError     bool
	}{
		{name: "Rejects request over plain text",
			request: httptest.NewRequest(http.MethodGet, plainTextURL, nil),
			success: false,
		},
		{name: "Rejects request over plain text with forwarded header over http",
			request: plainTextRequestWithForwardedHeaderWithoutHTTPS,
			success: false,
		},
		{name: "Rejects request over plain text with x-forwarded-proto header over http",
			request: plainTextRequestWithXForwardedProtoHeaderThatContainsHTTP,
			success: false,
		},
		{name: "Rejects request over plain text with forwarded header that contains http",
			request: plainTextRequestWithForwardedHeaderThatContainsHTTP,
			success: false,
		},
		{name: "Rejects request over plain text when last forwarded header does not contain https proto property",
			request: plainTextRequestWithLastForwardedHeaderDoesNotContainHTTPS,
			success: false,
		},
		{name: "Rejects request without forwarded header",
			request: plainTextRequestWithoutForwardedHeader,
			success: false,
		},
		{name: "Rejects request with malformed forwarded header",
			request: plainTextRequestWithMalformedForwardedHeader,
			success: false,
		},
		{name: "Rejects request with forwarded header without proto element",
			request: plainTextRequestWithForwardedHeaderWithoutProto,
			success: false,
		},
		{name: "Accepts request over plain text with last forwarded Header secure from loopback address",
			request:             plainTextRequestWithForwardedHeader,
			service:             "MockService",
			operation:           "Action1",
			introspectionScope:  action1Scope,
			introspectionExpiry: activeExpiry,
			success:             true,
			result:              ResultAllow,
		},
		{name: "Authorizes secure request with correct auth token",
			request:             secureRequest,
			service:             "MockService",
			operation:           "Action1",
			introspectionScope:  action1Scope,
			introspectionExpiry: activeExpiry,
			success:             true,
			result:              ResultAllow,
		},
		{name: "Challenges secure request without auth token",
			request:             httptest.NewRequest(http.MethodGet, secureURL, nil),
			service:             "MockService",
			operation:           "Action1",
			introspectionScope:  action1Scope,
			introspectionExpiry: activeExpiry,
			success:             true,
			result:              ResultChallenge,
		},
		{name: "Challenges secure request with unknown authorization header",
			request:             secureRequestMalformedAuthorizationHeader,
			service:             "MockService",
			operation:           "Action1",
			introspectionScope:  action1Scope,
			introspectionExpiry: activeExpiry,
			success:             true,
			result:              ResultChallenge,
		},
		{name: "Challenges secure request when InvalidError from introspection",
			request:             secureRequest,
			result:              ResultChallenge,
			service:             "MockService",
			operation:           "Action1",
			introspectionScope:  action1Scope,
			introspectionExpiry: expiredExpiry,
			success:             true,
		},
		{name: "Denies secure request which doesn't have required scope",
			request:             secureRequest,
			service:             "MockService",
			operation:           "Action1",
			introspectionScope:  action2Scope,
			introspectionExpiry: activeExpiry,
			success:             true,
			result:              ResultDeny,
		},
		{name: "Returns error when introspect errors due to internal server error",
			request:             plainTextRequestWithForwardedHeader,
			service:             "MockService",
			operation:           "Action1",
			introspectionScope:  action1Scope,
			introspectionExpiry: activeExpiry,
			success:             false,
			introspectError:     true,
		},
	}

	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			authServer := mockCloudAuthServer{t: t, introspectionExpiry: test.introspectionExpiry, introspectionScope: test.introspectionScope, introspectError: test.introspectError}
			ts := httptest.NewTLSServer(authServer.Handler())
			defer ts.Close()
			ts.Client().Timeout = 1 * time.Minute

			as, err := NewAuthSession(&mockAuthServerToken{}, withAuthServerURL(ts.URL), withAuthSessionClient(ts.Client()))
			if err != nil || as == nil {
				t.Fatalf("failed to create a new auth session. Got auth session %v and err %+v", as, err)
			}

			auth, err := as.AuthorizeRequest(test.request, test.service, test.operation)
			if (err == nil) != test.success {
				t.Fatalf("Expecting success %v but got error %+v", test.success, err)
			}

			// The authorization result is only meaningful when no error occurred.
			if !test.success {
				return
			}
			if test.result != auth.Result {
				t.Fatalf("Expecting result %v but got result %v", test.result, auth.Result)
			}
			if test.result != ResultAllow && auth.BearerChallenge == "" {
				t.Fatal("Expecting Bearer Challenge to not be empty")
			}
		})
	}
}
