package main

import (
	"bufio"
	"bytes"
	"compress/gzip"
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"net/url"
	"strings"

	"github.com/aws/aws-lambda-go/events"
	"github.com/aws/aws-lambda-go/lambda"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/client"
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/sts"
	jwt "github.com/dgrijalva/jwt-go"
	"github.com/lestrrat-go/jwx/jwk"
	"github.com/lestrrat-go/jwx/jws"
)

// Names of the cookies and query parameter used by the Midway SSO flow.
const (
	// tokenParam is the query parameter on which Midway hands back the ID token.
	tokenParam = "id_token"

	// tokenCookie stores an ID token that was accepted from tokenParam.
	tokenCookie = "amzn_sso_token"

	// nonceCookie stores a hex-encoded random value; the hex SHA-256 of its
	// decoded bytes is the OpenID nonce sent to Midway.
	nonceCookie = "amzn_sso_rfp"
)

// main hands control to the AWS Lambda runtime, which calls handle for every
// incoming authorizer event.
func main() { lambda.Start(handle) }

// unauthorized is returned to the Lambda runtime whenever the caller cannot be
// authenticated or authorized. API Gateway's Lambda-authorizer contract maps
// an error whose message is exactly "Unauthorized" to a 401 response, so the
// capitalized message is deliberate and must not be changed.
var unauthorized = fmt.Errorf("Unauthorized")

// handle is the Lambda entry point for an API Gateway custom authorizer.
//
// Authentication (authN.check) runs first:
//   - If it asks for a redirect, a Deny policy is returned. The redirect
//     location, and any cookie to set, travel in the authorizer context under
//     "location" and "set-cookie".
//   - Otherwise the verified token subject is looked up in the POSIX group
//     named by the "authorized_posix_group" stage variable. An Allow policy
//     is returned for members and a Deny policy for everyone else.
//
// Decoding failures return the decode error. Authentication or group-lookup
// failures return the unauthorized error.
func handle(ctx context.Context, input json.RawMessage) (*events.APIGatewayCustomAuthorizerResponse, error) {
	// The same event is decoded twice: the authorizer view supplies MethodArn,
	// the proxy view supplies headers, path, query string and stage variables.
	var authReq events.APIGatewayCustomAuthorizerRequest
	if err := json.Unmarshal(input, &authReq); err != nil {
		return nil, err
	}
	var fullReq events.APIGatewayProxyRequest
	if err := json.Unmarshal(input, &fullReq); err != nil {
		return nil, err
	}

	anResp, err := (&authN{}).check(ctx, newAuthNRequest(&fullReq))
	if err != nil {
		log.Printf("authN.check: %v", err)
		return nil, unauthorized
	}

	resp := &events.APIGatewayCustomAuthorizerResponse{Context: make(map[string]interface{})}

	if anResp.RedirectLocation != "" {
		// Deny this invocation but pass the redirect back via the context.
		// NOTE(review): presumably a gateway response mapping turns these
		// context values into Location / Set-Cookie headers; not visible here.
		resp.PolicyDocument.Version = "2012-10-17"
		resp.PolicyDocument.Statement = append(resp.PolicyDocument.Statement, events.IAMPolicyStatement{
			Action:   []string{"execute-api:Invoke"},
			Effect:   "Deny",
			Resource: []string{authReq.MethodArn},
		})
		resp.Context["location"] = anResp.RedirectLocation
		if anResp.SetCookie != nil {
			resp.Context["set-cookie"] = anResp.SetCookie.String()
		}
		return resp, nil
	}

	if anResp.VerifiedClaims == nil {
		log.Printf("no verified claims")
		return nil, unauthorized
	}

	// Use the invocation context (not context.Background) so the STS and S3
	// calls are cancelled along with the Lambda invocation.
	members, err := (&authZ{}).getSinglePassGroups(ctx)
	if err != nil {
		log.Printf("authZ.getSinglePassGroups: %v", err)
		return nil, unauthorized
	}

	group := fullReq.StageVariables["authorized_posix_group"]

	statement := events.IAMPolicyStatement{
		Action:   []string{"execute-api:Invoke"},
		Effect:   "Deny",
		Resource: []string{authReq.MethodArn},
	}

	// An empty user or an unconfigured group always stays Deny.
	user := anResp.VerifiedClaims.Subject
	if user != "" && group != "" && members.isMember(group, user) {
		statement.Effect = "Allow"
	}

	resp.PolicyDocument.Version = "2012-10-17"
	resp.PolicyDocument.Statement = append(resp.PolicyDocument.Statement, statement)
	resp.PrincipalID = user

	return resp, nil
}

// authNRequest carries the parts of an incoming API Gateway request that the
// authentication flow needs; newAuthNRequest builds it.
//
//   - NonceCookie: the nonce cookie sent by the browser, or nil if absent.
//   - TokenCookie: the value of the token cookie, or "" if absent.
//   - Hostname: the request's Host header. It is sent to Midway as client_id
//     and must equal the token's audience.
//   - MidwayRedirectTarget: the https URL of the current request with the
//     id_token query parameter removed; Midway is asked to redirect here.
//   - TokenQuery: the id_token query parameter, or "" if absent.
type authNRequest struct {
	NonceCookie          *http.Cookie
	TokenCookie          string
	Hostname             string
	MidwayRedirectTarget string
	TokenQuery           string
}

// authNResponse is the result of authN.check. The code produces one of two
// shapes:
//   - RedirectLocation is set (SetCookie may also be set): the browser must be
//     redirected, either to Midway to log in, or back to the original URL
//     after a token from the query string was accepted and placed in
//     SetCookie.
//   - VerifiedClaims is set: the request carried a valid token cookie.
type authNResponse struct {
	RedirectLocation string
	SetCookie        *http.Cookie

	VerifiedClaims *midwayClaims
}

// midwayClaims are the claims of a Midway ID token: the standard registered
// JWT claims plus the OpenID Connect "nonce" claim. verifyToken compares the
// nonce against the hash of the browser's nonce cookie.
type midwayClaims struct {
	Nonce string `json:"nonce"`
	jwt.StandardClaims
}

// newAuthNRequest extracts from an API Gateway proxy request everything authN
// needs:
//   - the Host header;
//   - the nonce and token cookies;
//   - the id_token query parameter;
//   - the URL Midway should send the browser back to. This is the current URL
//     with id_token removed, so a returning browser does not loop.
func newAuthNRequest(req *events.APIGatewayProxyRequest) *authNRequest {
	// API Gateway passes header names through as the client sent them, so
	// canonicalize before looking anything up.
	hdr := http.Header{}
	for name, value := range req.Headers {
		hdr.Set(name, value)
	}
	host := hdr.Get("Host")

	// Rebuild the query string without the token so the redirect target does
	// not carry it around again (avoids redirect loops).
	query := url.Values{}
	for name, value := range req.QueryStringParameters {
		query.Set(name, value)
	}
	query.Del(tokenParam)

	target := url.URL{
		Scheme:   "https",
		Host:     host,
		Path:     req.Path,
		RawQuery: query.Encode(),
	}

	out := &authNRequest{
		Hostname:             host,
		MidwayRedirectTarget: target.String(),
		TokenQuery:           req.QueryStringParameters[tokenParam],
	}

	// Borrow net/http's cookie parser by wrapping the raw header in a request.
	// TODO: check for and delete duplicate cookies
	parser := &http.Request{Header: http.Header{"Cookie": []string{hdr.Get("Cookie")}}}
	if c, err := parser.Cookie(nonceCookie); err == nil {
		out.NonceCookie = c
	}
	if c, err := parser.Cookie(tokenCookie); err == nil {
		out.TokenCookie = c.Value
	}
	return out
}

// authN implements the browser-facing authentication flow against Midway: an
// OpenID Connect request for an id_token, bound to a per-browser nonce
// cookie. It holds no state.
type authN struct{}

// check decides how to answer a request:
//
//  1. If the id_token query parameter holds a valid token, the browser is sent
//     back to the same URL (minus id_token) with that token stored in the
//     token cookie.
//  2. Otherwise, if the token cookie holds a valid token, its claims are
//     returned.
//  3. Otherwise a redirect to Midway is produced via requestAuthentication.
//
// The error is non-nil only when starting a new authentication fails.
//
// NOTE(review): the token cookie is set without a Path, so browsers scope it
// to the request path's directory. Confirm that is intended.
func (an *authN) check(ctx context.Context, req *authNRequest) (*authNResponse, error) {
	if _, queryErr := an.verifyToken(ctx, req.TokenQuery, req.NonceCookie, req.Hostname); queryErr == nil {
		return &authNResponse{
			RedirectLocation: req.MidwayRedirectTarget,
			SetCookie: &http.Cookie{
				Name:     tokenCookie,
				Value:    req.TokenQuery,
				Secure:   true,
				HttpOnly: true,
			},
		}, nil
	}

	claims, cookieErr := an.verifyToken(ctx, req.TokenCookie, req.NonceCookie, req.Hostname)
	if cookieErr != nil {
		log.Printf("bad token: %v", cookieErr)
		return an.requestAuthentication(ctx, req)
	}
	return &authNResponse{VerifiedClaims: claims}, nil
}

// verifyToken parses tokenString as a Midway ID token and validates it.
//
// The signature is checked with the key returned by jwtKeyFunc. The jwt
// Parser also runs the claims' Valid method (exp/iat/nbf for StandardClaims).
// In addition the token must:
//   - be signed with RS256;
//   - carry a nonce equal to the hex SHA-256 of the nonce cookie's decoded
//     value, which ties the token to this browser;
//   - be issued by https://midway-auth.amazon.com;
//   - name hostname as its audience.
//
// It returns the verified claims, or an error describing the first failed
// check.
func (an *authN) verifyToken(ctx context.Context, tokenString string, nonceCookie *http.Cookie, hostname string) (*midwayClaims, error) {
	keyFunc := func(token *jwt.Token) (interface{}, error) { return jwtKeyFunc(ctx, token) }

	var claims midwayClaims
	token, err := (&jwt.Parser{}).ParseWithClaims(tokenString, &claims, keyFunc)
	if err != nil {
		return nil, err
	}

	// jwtKeyFunc already refuses non-RS256 tokens; re-check the header so the
	// invariant holds even if key lookup changes.
	if alg := token.Header["alg"]; alg != "RS256" {
		return nil, fmt.Errorf("unknown algorithm %q", alg)
	}

	if nonceCookie == nil {
		return nil, fmt.Errorf("nonce cookie missing")
	}
	// getHexNonce returns "" for a cookie whose value is not valid hex.
	// Without this guard, such a cookie would "match" any token that has no
	// nonce claim. requestAuthentication replaces malformed cookies, so
	// rejecting here cannot cause a redirect loop.
	wantNonce := an.getHexNonce(nonceCookie)
	if wantNonce == "" {
		return nil, fmt.Errorf("nonce cookie malformed")
	}
	if claims.Nonce != wantNonce {
		return nil, fmt.Errorf("nonce mismatch")
	}
	if claims.Issuer != "https://midway-auth.amazon.com" {
		return nil, fmt.Errorf("issuer mismatch")
	}
	if claims.Audience != hostname {
		return nil, fmt.Errorf("audience mismatch")
	}

	return &claims, nil
}

// requestAuthentication builds a redirect to Midway's login endpoint.
//
// If the browser has no usable nonce cookie (missing, or not valid hex), a
// fresh one is generated. It is returned in SetCookie so the token Midway
// sends back can later be matched against it. An existing valid nonce cookie
// is reused as is.
func (an *authN) requestAuthentication(ctx context.Context, req *authNRequest) (*authNResponse, error) {
	out := &authNResponse{}

	if an.getHexNonce(req.NonceCookie) == "" {
		fresh, err := an.replaceNonce(req)
		if err != nil {
			return nil, err
		}
		req = fresh
		out.SetCookie = fresh.NonceCookie
	}

	out.RedirectLocation = an.midwayRedirect(req).String()
	return out, nil
}

// getHexNonce derives the OpenID nonce from the nonce cookie. The cookie
// value is hex-decoded and hashed with SHA-256, and the digest is returned
// hex-encoded. It returns "" if the cookie is nil or its value is not valid
// hex.
func (an *authN) getHexNonce(cookie *http.Cookie) string {
	if cookie != nil {
		if raw, err := hex.DecodeString(cookie.Value); err == nil {
			digest := sha256.Sum256(raw)
			return hex.EncodeToString(digest[:])
		}
	}
	return ""
}

// replaceNonce returns a shallow copy of req whose NonceCookie is a freshly
// generated cookie. The cookie holds 32 bytes from crypto/rand, hex-encoded,
// and is marked Secure and HttpOnly. req itself is left untouched.
func (an *authN) replaceNonce(req *authNRequest) (*authNRequest, error) {
	var seed [32]byte
	if _, err := io.ReadFull(rand.Reader, seed[:]); err != nil {
		return nil, err
	}

	updated := *req
	updated.NonceCookie = &http.Cookie{
		Name:     nonceCookie,
		Value:    hex.EncodeToString(seed[:]),
		Secure:   true,
		HttpOnly: true,
	}
	return &updated, nil
}

// midwayRedirect builds the URL of Midway's login endpoint (/SSO/redirect).
// The URL requests an id_token with scope openid for client req.Hostname,
// using the nonce derived from req.NonceCookie. The browser is returned to
// req.MidwayRedirectTarget.
func (an *authN) midwayRedirect(req *authNRequest) *url.URL {
	params := url.Values{
		"client_id":     {req.Hostname},
		"nonce":         {an.getHexNonce(req.NonceCookie)},
		"redirect_uri":  {req.MidwayRedirectTarget},
		"response_type": {"id_token"},
		"scope":         {"openid"},
	}
	return &url.URL{
		Scheme:   "https",
		Host:     "midway-auth.amazon.com",
		Path:     "/SSO/redirect",
		RawQuery: params.Encode(),
	}
}

// OpenIDConfig holds the fields this authorizer reads from an OpenID Connect
// discovery document.
type OpenIDConfig struct {
	// JWKS is the jwks_uri: the URL of the issuer's signing key set.
	JWKS string `json:"jwks_uri"`
}

// getOpenIDConfig fetches https://<host>/.well-known/openid-configuration and
// decodes the discovery document.
//
// Any non-200 response is returned as an error instead of being decoded, so
// an error page whose body happens to be JSON is never mistaken for a valid
// (possibly empty) configuration.
func getOpenIDConfig(ctx context.Context, host string) (*OpenIDConfig, error) {
	target := (&url.URL{
		Scheme: "https",
		Host:   host,
		Path:   "/.well-known/openid-configuration",
	}).String()

	req, err := http.NewRequest("GET", target, nil)
	if err != nil {
		return nil, err
	}
	req = req.WithContext(ctx)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fetching %s: unexpected status %s", target, resp.Status)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	var config OpenIDConfig
	if err := json.Unmarshal(body, &config); err != nil {
		return nil, err
	}

	return &config, nil
}

// jwtKeyFunc resolves the verification key for a Midway ID token.
//
// It accepts only RS256 tokens that carry a "kid" header. It discovers the
// JWKS URL from Midway's OpenID configuration, downloads the key set, and
// returns the first materialized key whose ID matches kid. Non-200 responses
// and an empty jwks_uri are reported as errors.
//
// NOTE(review): the discovery document and key set are fetched on every call,
// and verifyToken may call this more than once per request. Consider caching
// them if latency matters.
func jwtKeyFunc(ctx context.Context, token *jwt.Token) (interface{}, error) {
	if token.Method.Alg() != "RS256" {
		return nil, fmt.Errorf("unknown alg: %q", token.Method.Alg())
	}

	// Validate the header before doing any network I/O.
	kid, ok := token.Header["kid"].(string)
	if !ok {
		return nil, fmt.Errorf("missing kid")
	}

	config, err := getOpenIDConfig(ctx, "midway-auth.amazon.com")
	if err != nil {
		return nil, err
	}
	if config.JWKS == "" {
		return nil, fmt.Errorf("openid configuration has no jwks_uri")
	}

	req, err := http.NewRequest("GET", config.JWKS, nil)
	if err != nil {
		return nil, err
	}
	req = req.WithContext(ctx)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fetching %s: unexpected status %s", config.JWKS, resp.Status)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	{
		// Midway's keys don't conform to RFC 7515's requirement of
		// stripping the base64 padding.
		// NOTE(review): RFC 7517 defines x5c as standard (padded) base64, so
		// trimming its padding may be unnecessary. Confirm against the jwk
		// parser before changing it.
		var jwks struct {
			Keys []struct {
				Alg string   `json:"alg"`
				E   string   `json:"e"`
				Kid string   `json:"kid"`
				Kty string   `json:"kty"`
				N   string   `json:"n"`
				X5c []string `json:"x5c"`
			} `json:"keys"`
		}
		err = json.Unmarshal(body, &jwks)
		if err != nil {
			return nil, err
		}
		for i := range jwks.Keys {
			jwks.Keys[i].N = strings.TrimRight(jwks.Keys[i].N, "=")
			jwks.Keys[i].E = strings.TrimRight(jwks.Keys[i].E, "=")
			for j := range jwks.Keys[i].X5c {
				jwks.Keys[i].X5c[j] = strings.TrimRight(jwks.Keys[i].X5c[j], "=")
			}
		}
		body, err = json.Marshal(&jwks)
		if err != nil {
			return nil, err
		}
	}

	keys, err := jwk.Parse(body)
	if err != nil {
		return nil, err
	}

	for _, key := range keys.LookupKeyID(kid) {
		k, err := key.Materialize()
		if err == nil {
			return k, nil
		}
	}

	return nil, fmt.Errorf("key not available")
}

// authZ implements authorization. It loads POSIX group membership from the
// account's singlepass-<account>-client-data S3 bucket. It holds no state.
type authZ struct{}

// getAccountId returns the AWS account ID that owns the credentials in sess,
// as reported by STS GetCallerIdentity.
func (az *authZ) getAccountId(ctx context.Context, sess client.ConfigProvider) (string, error) {
	svc := sts.New(sess)
	out, err := svc.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{})
	if err != nil {
		return "", err
	}
	return aws.StringValue(out.Account), nil
}

// listHostclasses returns the names of the "directories" directly under
// hostclass/ in spieBucket. For example, the common prefix hostclass/foo/
// yields "foo". Every result page is walked.
func (az *authZ) listHostclasses(ctx context.Context, sess client.ConfigProvider, spieBucket string) ([]string, error) {
	const prefix = "hostclass/"

	input := &s3.ListObjectsInput{
		Bucket:    aws.String(spieBucket),
		Delimiter: aws.String("/"),
		Prefix:    aws.String(prefix),
	}

	var names []string
	collect := func(page *s3.ListObjectsOutput, _ bool) bool {
		for _, cp := range page.CommonPrefixes {
			name := aws.StringValue(cp.Prefix)
			name = strings.TrimPrefix(name, prefix)
			name = strings.TrimSuffix(name, "/")
			names = append(names, name)
		}
		return true // keep paging
	}

	if err := s3.New(sess).ListObjectsPagesWithContext(ctx, input, collect); err != nil {
		return nil, err
	}
	return names, nil
}

// getOneGroup downloads hostclass/<hostclass>/group from bucket and parses it
// into a membership set.
//
// The object is parsed as a JWS. Its payload is gzip-compressed JSON with a
// "data" string, and data holds group lines of the form
// name:field:field:user1,user2,... Lines that do not split into exactly four
// colon-separated fields are skipped, as are empty user names.
//
// NOTE(review): jws.Parse only parses the message and does NOT verify its
// signature, so the payload is trusted on the strength of S3/IAM access
// control alone. Confirm that is intended.
func (az *authZ) getOneGroup(ctx context.Context, sess client.ConfigProvider,
	bucket, hostclass string) (*membership, error) {

	key := fmt.Sprintf("hostclass/%s/group", hostclass)
	resp, err := s3.New(sess).GetObjectWithContext(ctx, &s3.GetObjectInput{
		Bucket: aws.String(bucket),
		Key:    aws.String(key),
	})
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	m, err := jws.Parse(bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	gr, err := gzip.NewReader(bytes.NewReader(m.Payload()))
	if err != nil {
		return nil, err
	}
	defer gr.Close()

	var contents struct {
		Data string `json:"data"`
	}
	if err := json.NewDecoder(gr).Decode(&contents); err != nil {
		return nil, err
	}

	members := new(membership)

	sc := bufio.NewScanner(strings.NewReader(contents.Data))
	// The default 64 KiB line limit is too small for a large group and would
	// fail the whole authorization. The data is already in memory, so allow a
	// line as long as the entire document.
	sc.Buffer(make([]byte, 0, 4096), len(contents.Data)+1)
	for sc.Scan() {
		parts := strings.SplitN(sc.Text(), ":", 5)
		if len(parts) != 4 {
			continue
		}

		group := parts[0]
		for _, user := range strings.Split(parts[3], ",") {
			// A group with no members has an empty last field; don't record
			// "" as a member.
			if user == "" {
				continue
			}
			members.add(group, user)
		}
	}
	if err := sc.Err(); err != nil {
		return nil, err
	}

	return members, nil
}

// getSinglePassGroups loads the merged group membership of every hostclass in
// this account's SinglePass bucket.
//
// Steps:
//  1. Find the account ID using the default session credentials.
//  2. Assume arn:aws:iam::<account>:role/singlepass-sync-reader.
//  3. List the hostclasses in singlepass-<account>-client-data.
//  4. Merge each hostclass's group file into one set.
//
// Any single failure aborts the whole lookup.
func (az *authZ) getSinglePassGroups(ctx context.Context) (*membership, error) {
	base, err := session.NewSession()
	if err != nil {
		return nil, err
	}

	account, err := az.getAccountId(ctx, base)
	if err != nil {
		return nil, err
	}

	role := fmt.Sprintf("arn:aws:iam::%s:role/singlepass-sync-reader", account)
	reader := base.Copy(&aws.Config{Credentials: stscreds.NewCredentials(base, role)})
	bucket := fmt.Sprintf("singlepass-%s-client-data", account)

	hostclasses, err := az.listHostclasses(ctx, reader, bucket)
	if err != nil {
		return nil, err
	}

	merged := &membership{}
	for _, hc := range hostclasses {
		group, err := az.getOneGroup(ctx, reader, bucket, hc)
		if err != nil {
			return nil, err
		}
		merged.addAll(group)
	}
	return merged, nil
}

// membership is a set of (group, user) pairs. The zero value is an empty set
// that is ready to use.
type membership struct {
	groups map[string]map[string]struct{}
}

// add records user as a member of group.
func (m *membership) add(group, user string) {
	if m.groups == nil {
		m.groups = map[string]map[string]struct{}{}
	}
	users, ok := m.groups[group]
	if !ok {
		users = map[string]struct{}{}
		m.groups[group] = users
	}
	users[user] = struct{}{}
}

// addAll merges every (group, user) pair from other into m.
func (m *membership) addAll(other *membership) {
	for g, users := range other.groups {
		for u := range users {
			m.add(g, u)
		}
	}
}

// isMember reports whether user belongs to group.
func (m *membership) isMember(group, user string) bool {
	if users, ok := m.groups[group]; ok {
		_, found := users[user]
		return found
	}
	return false
}
