// This program illustrates the sequence of HTTP requests required to obtain an
// access token from TwitchS2SAuth. It is not intended for production use;
// use the library at ./s2s/ instead.
package main

import (
	"bytes"
	"encoding/json"
	"flag"
	"fmt"
	"io/ioutil"
	"math/big"
	"net/http"
	"net/url"
	"regexp"
	"strconv"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
)

var (
	// numbersRegExp matches the run of decimal digits at the very end of a
	// string, e.g. "300" in "public, max-age=300". It matches nothing when
	// the string does not end in a digit.
	numbersRegExp = regexp.MustCompile(`\d+$`)
)

// accessTokenGetter performs each HTTP exchange of the TwitchS2SAuth token
// flow. Client carries the plain, unsigned requests. SigV4Client carries the
// authorization-endpoint request, which the caller is expected to configure
// with an AWS SigV4-signing transport.
type accessTokenGetter struct {
	Client, SigV4Client *http.Client
}

// requireStatus returns nil when res carries the wanted status code.
// Otherwise it reads the whole response body and returns an error of the form
// "status<CODE>: BODY". If the body cannot be read, the returned error still
// reports the status code and wraps the read error.
//
// It does not close res.Body; the caller remains responsible for that.
func (atg accessTokenGetter) requireStatus(res *http.Response, status int) error {
	if res.StatusCode == status {
		return nil
	}
	bs, err := ioutil.ReadAll(res.Body)
	if err != nil {
		// The unexpected status is the primary failure. Keep it in the
		// message instead of reporting only the secondary read error.
		return fmt.Errorf("status<%d>: reading body: %w", res.StatusCode, err)
	}
	return fmt.Errorf("status<%d>: %s", res.StatusCode, bs)
}

// openIDConfigurationOutput holds the fields this script decodes from the
// OpenID Connect discovery document fetched by OpenIDConfiguration. Any other
// keys in the JSON document are ignored by encoding/json.
//
// In this file, AuthorizationEndpoint is used by AuthorizationGrant,
// TokenEndpoint by ClientCredentials and AccessToken, and JSONWebKeySetURI by
// JSONWebKeySet. IntrospectionEndpoint and Issuer are decoded but not read
// here.
type openIDConfigurationOutput struct {
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	IntrospectionEndpoint string `json:"introspection_endpoint"`
	Issuer                string `json:"issuer"`
	JSONWebKeySetURI      string `json:"jwks_uri"`
	TokenEndpoint         string `json:"token_endpoint"`
}

// OpenIDConfiguration GETs the OpenID Connect discovery document at
// oidcDiscoveryEndpoint with the plain Client and decodes it. A status other
// than 200 OK yields an error that includes the response body.
func (atg accessTokenGetter) OpenIDConfiguration(oidcDiscoveryEndpoint string) (*openIDConfigurationOutput, error) {
	resp, err := atg.Client.Get(oidcDiscoveryEndpoint)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if statusErr := atg.requireStatus(resp, http.StatusOK); statusErr != nil {
		return nil, statusErr
	}

	cfg := new(openIDConfigurationOutput)
	if decodeErr := json.NewDecoder(resp.Body).Decode(cfg); decodeErr != nil {
		return nil, decodeErr
	}
	return cfg, nil
}

// authorizationGrantOutput is the decoded JSON response of AuthorizationGrant.
// In this file, AccessToken is sent to the token endpoint as the "assertion"
// form value by ClientCredentials. ExpiresIn is decoded but not read here
// (presumably a lifetime in seconds per OAuth 2.0 convention — not confirmed).
type authorizationGrantOutput struct {
	AccessToken string `json:"access_token"`
	ExpiresIn   int    `json:"expires_in"`
}

// AuthorizationGrant GETs an assertion from the authorization endpoint that
// names clientServiceURI as the S2S client and issuer as the host. The request
// goes through atg.SigV4Client, which is expected to sign it with AWS SigV4.
// A status other than 200 OK yields an error that includes the response body.
func (atg accessTokenGetter) AuthorizationGrant(
	openIDConfiguration *openIDConfigurationOutput,
	issuer string,
	clientServiceURI string,
) (*authorizationGrantOutput, error) {
	query := url.Values{}
	query.Set("host", issuer)
	query.Set("response_type", "assertion")
	query.Set("twitch_s2s_client", clientServiceURI)

	req, err := http.NewRequest(http.MethodGet, openIDConfiguration.AuthorizationEndpoint, nil)
	if err != nil {
		return nil, err
	}
	req.URL.RawQuery = query.Encode()

	resp, err := atg.SigV4Client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if statusErr := atg.requireStatus(resp, http.StatusOK); statusErr != nil {
		return nil, statusErr
	}

	grant := new(authorizationGrantOutput)
	if decodeErr := json.NewDecoder(resp.Body).Decode(grant); decodeErr != nil {
		return nil, decodeErr
	}
	return grant, nil
}

// tokenOutput is the decoded JSON response of the token endpoint, returned by
// both ClientCredentials and AccessToken. AccessToken is the token string;
// ExpiresIn and Scope are decoded but not read in this file (ExpiresIn is
// presumably a lifetime in seconds per OAuth 2.0 convention — not confirmed).
type tokenOutput struct {
	AccessToken string `json:"access_token"`
	ExpiresIn   int    `json:"expires_in"`
	Scope       string `json:"scope"`
}

// ClientCredentials exchanges the assertion returned by AuthorizationGrant for
// a token scoped to getTokenScope. It POSTs a form-encoded
// "twitch_s2s_service_credentials" grant to the token endpoint, naming issuer
// in the x-host header. A status other than 200 OK yields an error that
// includes the response body.
func (atg accessTokenGetter) ClientCredentials(
	openIDConfiguration *openIDConfigurationOutput,
	authorizationGrant *authorizationGrantOutput,
	getTokenScope string,
	issuer string,
) (*tokenOutput, error) {
	endpoint := openIDConfiguration.TokenEndpoint
	form := url.Values{}
	form.Set("grant_type", "twitch_s2s_service_credentials")
	form.Set("assertion", authorizationGrant.AccessToken)
	form.Set("scope", getTokenScope)

	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBufferString(form.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("x-host", issuer)

	resp, err := atg.Client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if statusErr := atg.requireStatus(resp, http.StatusOK); statusErr != nil {
		return nil, statusErr
	}

	token := new(tokenOutput)
	if decodeErr := json.NewDecoder(resp.Body).Decode(token); decodeErr != nil {
		return nil, decodeErr
	}
	return token, nil
}

// AccessToken requests the final access token for scopeURI, intended for
// intendedHost, which is sent in the x-host header. It POSTs a form-encoded
// "client_credentials" grant to the token endpoint and authenticates with the
// token from ClientCredentials as a Bearer token. A status other than 200 OK
// yields an error that includes the response body.
func (atg accessTokenGetter) AccessToken(
	openIDConfiguration *openIDConfigurationOutput,
	clientCredentials *tokenOutput,
	intendedHost string,
	scopeURI string,
) (*tokenOutput, error) {
	req, err := http.NewRequest(
		http.MethodPost,
		openIDConfiguration.TokenEndpoint,
		bytes.NewBufferString(url.Values{
			"grant_type": []string{"client_credentials"},
			"scope":      []string{scopeURI},
		}.Encode()),
	)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Authorization", "Bearer "+clientCredentials.AccessToken)
	req.Header.Set("x-host", intendedHost)

	res, err := atg.Client.Do(req)
	if err != nil {
		return nil, err
	}
	// Close the body on every return path below, including the non-200 and
	// decode-error paths, so the underlying connection is released.
	defer res.Body.Close()

	if err := atg.requireStatus(res, http.StatusOK); err != nil {
		return nil, err
	}

	var output tokenOutput
	if err := json.NewDecoder(res.Body).Decode(&output); err != nil {
		return nil, err
	}

	return &output, nil
}

// jsonWebKeySetOutput holds the parts of the JSON Web Key Set document that
// this script reads: each key's "kid" plus its "x" and "y" values. In this
// file, main interprets X and Y as big-endian unsigned integers (presumably
// elliptic-curve public-key coordinates — not confirmed from here).
//
// NOTE(review): X and Y are []byte, so encoding/json decodes them as standard
// padded base64 (alphabet with '+' and '/'). RFC 7518 encodes JWK "x"/"y" as
// unpadded base64url. If this server follows the RFC, decoding may fail —
// confirm against the server's actual output.
type jsonWebKeySetOutput struct {
	Keys []struct {
		KeyID string `json:"kid"`
		X     []byte `json:"x"`
		Y     []byte `json:"y"`
	} `json:"keys"`
}

// maxAgeRegExp captures the value of the max-age directive in a Cache-Control
// header. Directive names are case-insensitive, and the directive may appear
// anywhere in the comma-separated list. Anchoring on the start of the header
// or a comma keeps it from matching inside another directive's name.
var maxAgeRegExp = regexp.MustCompile(`(?i)(?:^|,)\s*max-age\s*=\s*"?(\d+)`)

// JSONWebKeySet GETs and decodes the key set published at the discovery
// document's jwks_uri. It also returns how long the key set may be cached,
// taken from the max-age directive of the response's Cache-Control header.
// It returns an error for a status other than 200 OK, for an undecodable body,
// or when Cache-Control carries no max-age directive.
func (atg accessTokenGetter) JSONWebKeySet(
	openIDConfiguration *openIDConfigurationOutput,
) (*jsonWebKeySetOutput, time.Duration, error) {
	res, err := atg.Client.Get(openIDConfiguration.JSONWebKeySetURI)
	if err != nil {
		return nil, 0, err
	}
	defer res.Body.Close()

	if err := atg.requireStatus(res, http.StatusOK); err != nil {
		return nil, 0, err
	}

	var output jsonWebKeySetOutput
	if err := json.NewDecoder(res.Body).Decode(&output); err != nil {
		return nil, 0, err
	}

	// Look up the max-age directive by name rather than taking whatever digits
	// end the header. With "max-age=60, stale-while-revalidate=30" the trailing
	// digits belong to the wrong directive, and "max-age=300, public" ends in
	// no digits at all.
	cacheControl := res.Header.Get("Cache-Control")
	match := maxAgeRegExp.FindStringSubmatch(cacheControl)
	if match == nil {
		return nil, 0, fmt.Errorf("no max-age directive in Cache-Control header %q", cacheControl)
	}
	maxAge, err := strconv.Atoi(match[1])
	if err != nil {
		return nil, 0, err
	}

	return &output, time.Duration(maxAge) * time.Second, nil
}

// sigV4RoundTripper is an http.RoundTripper that signs outgoing requests with
// Signer before delegating them to the embedded RoundTripper. AWSService and
// AWSRegion are the service name and region passed to the signer.
type sigV4RoundTripper struct {
	http.RoundTripper

	AWSRegion, AWSService string
	Signer                *v4.Signer
}

// RoundTrip signs req with AWS Signature Version 4 and forwards it to the
// embedded RoundTripper.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the signature headers are written to a shallow copy that owns a private
// header map. A non-empty request body is buffered so that its payload hash is
// covered by the signature. Passing a nil body to the signer would sign an
// empty payload and, per the aws-sdk-go v4.Signer documentation, set the
// request's Body to nil.
func (rt *sigV4RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	signed := new(http.Request)
	*signed = *req
	signed.Header = make(http.Header, len(req.Header))
	for name, values := range req.Header {
		signed.Header[name] = append([]string(nil), values...)
	}

	if req.Body == nil || req.Body == http.NoBody {
		if _, err := rt.Signer.Sign(signed, nil, rt.AWSService, rt.AWSRegion, time.Now()); err != nil {
			return nil, err
		}
		return rt.RoundTripper.RoundTrip(signed)
	}

	// RoundTrip is responsible for closing the request body, including on error.
	payload, err := ioutil.ReadAll(req.Body)
	req.Body.Close()
	if err != nil {
		return nil, err
	}
	// Attach the buffered payload explicitly so the copy carries it whether or
	// not the signer replaces Body. The signer does not update ContentLength,
	// so set it here.
	signed.Body = ioutil.NopCloser(bytes.NewReader(payload))
	signed.ContentLength = int64(len(payload))
	if _, err := rt.Signer.Sign(signed, bytes.NewReader(payload), rt.AWSService, rt.AWSRegion, time.Now()); err != nil {
		return nil, err
	}
	return rt.RoundTripper.RoundTrip(signed)
}

// main runs the TwitchS2SAuth flow end to end, in four steps:
//  1. OIDC discovery.
//  2. An authorization grant sent through the SigV4-signing client.
//  3. Client credentials.
//  4. An access token for the intended host.
//
// It prints the access token, then the server's JSON Web Key Set with its
// cache lifetime. Any failure panics, which is acceptable for an illustrative
// script.
func main() {
	oidcDiscoveryEndpoint := flag.String("discovery", "https://gateway.us-west-2.prod.s2s.s.twitch.a2z.com/.well-known/openid-configuration", "OIDC discovery endpoint to use")
	getTokenScope := flag.String("get-token-scope", "https://auth.prod.services.s2s.twitch.a2z.com#GetToken", "Scope to use for client credentials")
	serviceURI := flag.String("service-uri", "https://prod.services.s2s.twitch.a2z.com/5d56171b-4549-4e4e-908f-dbab8ccad582", "Service to get a token for")
	intendedHost := flag.String("host", "https://my.target.service", "Host token is intended for. Only this host will validate this token.")
	issuerURI := flag.String("issuer", "https://gateway.us-west-2.prod.s2s.s.twitch.a2z.com", "Issuer URI")
	scopeURI := flag.String("scope", "https://prod.services.s2s.twitch.a2z.com/1f6c07cd-5e18-4026-b6ae-41d21e271245#AuthorizedMethod", "Scope to request")
	awsRegion := flag.String("aws-region", "us-west-2", "AWS region used to SigV4-sign the authorization request")
	timeout := flag.Duration("timeout", 30*time.Second, "Time limit for each HTTP request, including reading the response body")
	flag.Parse()

	// http.Client has no timeout by default, so without one a stalled server
	// would hang the script indefinitely.
	atg := accessTokenGetter{
		Client: &http.Client{Timeout: *timeout},
		SigV4Client: &http.Client{
			Timeout: *timeout,
			Transport: &sigV4RoundTripper{
				RoundTripper: http.DefaultTransport,
				AWSRegion:    *awsRegion,
				AWSService:   "execute-api",
				Signer:       v4.NewSigner(session.Must(session.NewSession(&aws.Config{})).Config.Credentials),
			},
		},
	}

	openIDConfiguration, err := atg.OpenIDConfiguration(*oidcDiscoveryEndpoint)
	if err != nil {
		panic(err)
	}

	authorizationGrant, err := atg.AuthorizationGrant(
		openIDConfiguration,
		*issuerURI,
		*serviceURI,
	)
	if err != nil {
		panic(err)
	}

	clientCredentials, err := atg.ClientCredentials(
		openIDConfiguration,
		authorizationGrant,
		*getTokenScope,
		*issuerURI,
	)
	if err != nil {
		panic(err)
	}

	accessToken, err := atg.AccessToken(
		openIDConfiguration,
		clientCredentials,
		*intendedHost,
		*scopeURI,
	)
	if err != nil {
		panic(err)
	}

	fmt.Println("Access Token")
	fmt.Println(accessToken.AccessToken)

	jsonWebKeySet, maxAge, err := atg.JSONWebKeySet(openIDConfiguration)
	if err != nil {
		panic(err)
	}

	fmt.Println("")
	fmt.Println("S2S Auth Server Validation Keys")
	fmt.Printf("Cachable for %s\n", maxAge)
	for _, key := range jsonWebKeySet.Keys {
		// X and Y are printed as big-endian unsigned integers.
		x := new(big.Int).SetBytes(key.X)
		y := new(big.Int).SetBytes(key.Y)
		fmt.Printf("%s: X=%d, Y=%d\n", key.KeyID, x, y)
	}
}
