/* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. */

package cloudauth

import (
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"sync"
	"testing"
	"time"

	"golang.org/x/oauth2"
)

// Response modes for serverBehaviour.response: which answer the mock server
// gives while its state equals serverBehaviour.state (see the fast-fail
// switch at the top of Handler).
const (
	responseValid             = 0 // no injected failure; route the request normally
	responseBogus             = 1 // reply with a plain-text, non-JSON body
	responseInvalidStatusCode = 2 // reply with HTTP 404
	responseIoError           = 3 // reply with a body that disagrees with its Content-Length
)

// Lifecycle states of the mock CloudAuth Auth Server. Handler advances the
// state as the client first fetches the discovery document and then obtains
// an Auth Server token.
const (
	stateInit          = 0 // initial state, before discovery
	stateDiscovered    = 1 // discovery document has been served
	stateAuthenticated = 2 // Auth Server token has been issued
)

// Identity of the mock Resource Server for which the client requests access
// tokens. The authority, URL and scope are derived from the host and name so
// the values cannot drift apart.
const (
	mockServiceName      = "MockService"
	mockServiceHost      = "mockservice.integ.amazon.com"
	mockServiceAuthority = mockServiceHost + ":9443"
	mockServiceURL       = "https://" + mockServiceAuthority
	mockServiceScope     = "https://aaa.amazon.com/scopes/" + mockServiceName + "#Action1"
)

type (
	// serverBehaviour configures failure injection: while the server's state
	// equals state, requests are answered with response (unless it is
	// responseValid) instead of being routed to an endpoint.
	serverBehaviour struct {
		state    int // one of the stateXxx constants
		response int // one of the responseXxx constants
	}

	// mockCloudAuthServer is an in-process stand-in for the CloudAuth Auth
	// Server whose endpoints are served by Handler.
	mockCloudAuthServer struct {
		t                   *testing.T      // presumably the owning test; not read by Handler
		state               int             // current lifecycle state (stateXxx), advanced by Handler
		behaviour           serverBehaviour // failure to inject, and in which state
		introspectionExpiry int64           // ExpiryUnixTime reported by /introspect
		introspectionCount  int             // number of /introspect requests answered
		introspectionScope  string          // scope and audience reported by /introspect
		mutex               sync.Mutex      // guards state and introspectionCount
		introspectError     bool            // when true, /introspect answers HTTP 500
	}
)

// mockDiscoveryMetadata builds the discovery document the mock server
// publishes for host. The issuer is https://<host>, and every endpoint is a
// path directly beneath it.
func mockDiscoveryMetadata(host string) *authDiscoveryMetadata {
	issuer := "https://" + host
	endpoint := func(path string) string { return issuer + "/" + path }

	md := new(authDiscoveryMetadata)
	md.Issuer = issuer
	md.AuthorizationEndpoint = endpoint("authorization_grant")
	md.TokenEndpoint = endpoint("token")
	md.IntrospectionEndpoint = endpoint("introspect")
	md.BootstrapEndpoint = endpoint("bootstrap")
	md.JwksURI = endpoint("jwks")
	return md
}

// Handler returns an http.Handler emulating the CloudAuth Auth Server
// endpoints exercised by the client under test. Requests are routed by
// r.RequestURI:
//
//   - "/.well-known/openid-configuration" moves the server to
//     stateDiscovered and returns mockDiscoveryMetadata for the request Host.
//   - "/token" with an Authorization header starting with "Bearer" validates
//     a Resource Server token request and returns "resource_server_token".
//   - "/token" without it validates a JWT-bearer Auth Server token request,
//     moves the server to stateAuthenticated and returns "auth_server_token".
//   - "/introspect" validates the token and increments introspectionCount.
//     It then replies HTTP 500 when introspectError is set. Otherwise it
//     returns an IntrospectResponse built from introspectionScope and
//     introspectionExpiry.
//   - Anything else gets a plain-text "bogus" body.
//
// Before routing, if the server's state equals behaviour.state, the
// configured failure is returned instead: a bogus body, a 404, or a body
// longer than its declared Content-Length. A request whose parameters differ
// from what the mock expects causes a panic.
//
// Shared fields are read and written under m.mutex. State transitions and the
// introspection counter are updated before the response is written, so they
// are already in effect by the time the client can see the reply.
func (m *mockCloudAuthServer) Handler() http.Handler {
	// paramCheck pairs a value received from the client with the value the
	// mock expects for it.
	type paramCheck struct {
		param string
		got   string
		want  string
	}

	// requireParams panics on the first mismatching parameter. A malformed
	// client request then fails loudly instead of looking like an injected
	// failure response.
	requireParams := func(checks ...paramCheck) {
		for _, chk := range checks {
			if chk.got != chk.want {
				panic(fmt.Sprintf("internal error: got %s %q, want %q", chk.param, chk.got, chk.want))
			}
		}
	}

	// writeJSON sends v as a JSON response body.
	writeJSON := func(w http.ResponseWriter, v interface{}) {
		body, err := json.Marshal(v)
		if err != nil {
			panic(fmt.Sprintf("internal error: encoding mock response: %+v", err))
		}
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprintln(w, string(body))
	}

	// setState records a server state transition.
	setState := func(s int) {
		m.mutex.Lock()
		m.state = s
		m.mutex.Unlock()
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// m.state is written by other, possibly concurrent, requests through
		// setState. Read it under the same mutex to avoid a data race.
		m.mutex.Lock()
		state, behaviour := m.state, m.behaviour
		m.mutex.Unlock()

		// Handle fast-fail responses first.
		if state == behaviour.state {
			switch behaviour.response {
			case responseBogus:
				fmt.Fprintln(w, "bogus response")
				return
			case responseInvalidStatusCode:
				http.Error(w, "404", http.StatusNotFound)
				return
			case responseIoError:
				// Declare a 1-byte body but write more, so that what the
				// client receives disagrees with the declared length.
				w.Header().Set("Content-Length", "1")
				fmt.Fprintln(w, "io err")
				return
			}
		}

		switch r.RequestURI {
		case "/.well-known/openid-configuration":
			setState(stateDiscovered)
			writeJSON(w, mockDiscoveryMetadata(r.Host))

		case "/token":
			if err := r.ParseForm(); err != nil {
				panic(fmt.Sprintf("Parsing error in token request: %+v", err))
			}

			// A Bearer Authorization header means the client is exchanging
			// its Auth Server token for a Resource Server access token.
			if authz := r.Header.Get("Authorization"); strings.HasPrefix(authz, "Bearer") {
				requireParams(
					paramCheck{param: "Authorization", got: authz, want: "Bearer auth_server_token"},
					paramCheck{param: "grant_type", got: r.PostFormValue("grant_type"), want: "client_credentials"},
					paramCheck{param: "realm", got: r.PostFormValue("realm"), want: mockServiceName},
					paramCheck{param: "scope", got: r.PostFormValue("scope"), want: mockServiceScope},
					paramCheck{param: "host", got: r.PostFormValue("host"), want: mockServiceHost},
				)
				writeJSON(w, &oauth2.Token{
					AccessToken: "resource_server_token",
					TokenType:   "Bearer",
					Expiry:      time.Now().Add(1 * time.Hour),
				})
				return
			}

			// Otherwise the client is authenticating to the Auth Server with
			// its identity token (JWT-bearer grant).
			requireParams(
				paramCheck{param: "assertion", got: r.PostFormValue("assertion"), want: "client_identity_token"},
				paramCheck{param: "grant_type", got: r.PostFormValue("grant_type"), want: "urn:ietf:params:oauth:grant-type:jwt-bearer"},
				paramCheck{param: "realm", got: r.PostFormValue("realm"), want: defaultRealm},
				paramCheck{param: "scope", got: r.PostFormValue("scope"), want: defaultScopes},
			)
			setState(stateAuthenticated)
			writeJSON(w, &oauth2.Token{
				AccessToken: "auth_server_token",
				TokenType:   "Bearer",
				Expiry:      time.Now().Add(1 * time.Hour),
			})

		case "/introspect":
			if err := r.ParseForm(); err != nil {
				panic(fmt.Sprintf("Parsing error in introspect: %+v", err))
			}
			requireParams(paramCheck{param: "token", got: r.PostFormValue("token"), want: "resource_server_token"})

			// Count the request and snapshot the configuration before
			// replying. Injected 500 replies are counted too.
			m.mutex.Lock()
			m.introspectionCount++
			fail := m.introspectError
			scope, expiry := m.introspectionScope, m.introspectionExpiry
			m.mutex.Unlock()

			if fail {
				// Return here so the error body is not followed by a
				// success payload.
				http.Error(w, "500", http.StatusInternalServerError)
				return
			}

			writeJSON(w, &IntrospectResponse{
				Active:         true,
				Audience:       []string{scope},
				Subject:        "urn:cdo:MockService",
				Scope:          scope,
				ExpiryUnixTime: expiry,
			})

		default:
			fmt.Fprintln(w, "bogus")
		}
	})
}

// mockAuthServerToken is an identity source that always succeeds. It yields
// the "client_identity_token" assertion that the mock Auth Server's JWT-bearer
// /token check expects.
type mockAuthServerToken struct{}

// Token ignores url and returns a fresh identity token that expires one hour
// from now.
func (s *mockAuthServerToken) Token(url string) (*oauth2.Token, error) {
	tok := new(oauth2.Token)
	tok.AccessToken = "client_identity_token"
	tok.TokenType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
	tok.Expiry = time.Now().Add(time.Hour)
	return tok, nil
}

// mockBogusAuthServerToken is an identity source that always fails. It is used
// to exercise the client's handling of identity errors.
type mockBogusAuthServerToken struct{}

// Token ignores url and always reports a "bogus identity" error with a nil
// token.
func (s *mockBogusAuthServerToken) Token(url string) (*oauth2.Token, error) {
	var noToken *oauth2.Token
	return noToken, errors.New("bogus identity")
}
