// +build test

package handlers

import (
	"encoding/base64"
	"regexp"
	"strconv"
	"testing"

	"ting/util"
	. "ting/util/types"

	"code.justin.tv/tshadwell/jwt"

	"github.com/gin-gonic/gin"
)

// Signing material shared by the JWT test helpers in this file.
//
// Both values are set by their initializers rather than in an init function:
// jwtAlgo refers to jwtSecret, so package initialization assigns jwtSecret
// first.
var (
	// jwtSecret is the raw HMAC key used to sign test tokens.
	jwtSecret = []byte("password123")
	// jwtAlgo is the HS256 algorithm built from jwtSecret.
	jwtAlgo jwt.Algorithm = jwt.HS256(jwtSecret)
)

// defaultJWTHeader returns a freshly allocated JOSE header describing an
// HS256-signed JWT. Callers may modify the returned map freely.
func defaultJWTHeader() map[string]string {
	header := make(map[string]string, 2)
	header["typ"] = "JWT"
	header["alg"] = "HS256"
	return header
}

// defaultJWTClaims returns the baseline claim set used by the tests: a viewer
// on channel "123" whose token expires far in the future and who may listen
// and send on every pubsub target ("*").
func defaultJWTClaims() jwtClaims {
	var claims jwtClaims
	claims.Expiration = 32503680000 // Jan 1, 3000
	claims.ChannelID = "123"
	claims.OpaqueID = "U123"
	claims.Role = "viewer"
	claims.TwitchID = "T123"
	claims.PubsubPerms = map[string][]string{
		"listen": {"*"},
		"send":   {"*"},
	}
	return claims
}

// makeJWT returns a signed token string built from the default header and
// claims with headerDiff and claimsDiff overlaid on top.
//
// A non-empty diff is applied by passing it through util.ToJSON and then
// util.FromJSON into the corresponding default value. For convenience, an
// `int` stored under "channel_id" in claimsDiff is converted to its decimal
// string form first; the conversion is made on a copy, so the caller's map is
// never modified.
//
// The test is failed immediately if the token cannot be encoded.
func makeJWT(t *testing.T, headerDiff, claimsDiff StringMap) string {
	t.Helper()

	header := defaultJWTHeader()
	if len(headerDiff) > 0 {
		util.FromJSON(t, util.ToJSON(t, headerDiff), &header)
	}

	claims := defaultJWTClaims()
	if len(claimsDiff) > 0 {
		// For convenience, accept "channel_id" as an `int` as well. A missing
		// key yields a nil value, for which the checked assertion is false.
		if id, ok := claimsDiff["channel_id"].(int); ok {
			// Rewrite the value on a private copy rather than in the
			// caller's map.
			patched := make(StringMap, len(claimsDiff))
			for k, v := range claimsDiff {
				patched[k] = v
			}
			patched["channel_id"] = strconv.Itoa(id)
			claimsDiff = patched
		}
		util.FromJSON(t, util.ToJSON(t, claimsDiff), &claims)
	}

	buf, err := jwt.Encode(header, claims, jwtAlgo)
	if err != nil {
		t.Fatalf("error encoding JWT: %s", err)
	}
	return string(buf)
}

// makeJWTHeader returns HTTP request headers containing a single
// "Authorization" entry whose value is "Bearer " followed by the token that
// makeJWT produces for the given header and claim overrides.
func makeJWTHeader(t *testing.T, headerDiff, claimsDiff StringMap) map[string]string {
	token := makeJWT(t, headerDiff, claimsDiff)
	reqHeaders := make(map[string]string, 1)
	reqHeaders["Authorization"] = "Bearer " + token
	return reqHeaders
}

// initJWTEngine builds a gin engine for the tests. It installs the middleware
// returned by JWTMiddleware, which is given the URL-safe base64 encoding of
// jwtSecret, and registers "/hello" for any HTTP method. That route answers
// 200 with the request context's Keys rendered as JSON.
//
// The test is failed immediately if the middleware cannot be created.
func initJWTEngine(t *testing.T) *gin.Engine {
	engine := gin.New()

	encodedSecret := base64.URLEncoding.EncodeToString(jwtSecret)
	mw, err := JWTMiddleware(encodedSecret)
	if err != nil {
		t.Fatalf("error initializing JWT middleware: %s", err)
	}
	engine.Use(mw)

	echoKeys := func(c *gin.Context) { c.JSON(200, c.Keys) }
	engine.Any("/hello", echoKeys)
	return engine
}

// assertJWTError issues GET /hello against e and fails the test unless the
// request is rejected in the expected way.
//
// When token is non-empty it is sent as "Authorization: Bearer <token>";
// otherwise no Authorization header is sent. The response must:
//   - have status 401,
//   - carry the header `WWW-Authenticate: Bearer token_type="JWT"`,
//   - have a JSON object body with exactly one key, "error", whose value is a
//     string matching msgRegex.
//
// dumpResponseIfFailed is deferred with the response so it runs on the way
// out, including after a failed check.
func assertJWTError(t *testing.T, e *gin.Engine, token, msgRegex string) {
	t.Helper()

	reqHeaders := make(map[string]string, 1)
	if token != "" {
		reqHeaders["Authorization"] = "Bearer " + token
	}
	code, body, headers := request(t, e, "GET", "/hello", "", reqHeaders)
	defer dumpResponseIfFailed(t, code, body, headers)

	// Each check below ends the test on failure, so later checks only run
	// when all earlier ones passed.
	if code != 401 {
		t.Fatalf("wrong HTTP response code: %d; expected 401", code)
	}

	challenge := headers.Get("WWW-Authenticate")
	if challenge == "" {
		t.Fatal(`response headers do not include "WWW-Authenticate"`)
	}
	if challenge != `Bearer token_type="JWT"` {
		t.Fatalf(`wrong value for "WWW-Authenticate" header: %q`, challenge)
	}

	payload := util.MapFromJSON(t, body)
	if len(payload) != 1 {
		t.Fatal("expected body to be singleton JSON object")
	}

	rawMsg, present := payload["error"]
	if !present {
		t.Fatal(`response JSON does not have "error" key`)
	}

	errMsg, isString := rawMsg.(string)
	if !isString {
		t.Fatalf("response error message is not a string: %T(%#v)", rawMsg, rawMsg)
	}

	if pattern := regexp.MustCompile(msgRegex); !pattern.MatchString(errMsg) {
		t.Fatalf("wrong error message: %q; expected to match %q", errMsg, msgRegex)
	}
}

// TestJWTMiddleware drives the engine from initJWTEngine through a series of
// requests that must be rejected, one authenticated request that must
// succeed, and an OPTIONS request sent without credentials.
func TestJWTMiddleware(t *testing.T) {
	e := initJWTEngine(t)

	// Requests that must be rejected. Each case builds the bearer token to
	// send (an empty token means no Authorization header) and names a regex
	// that the returned error message must match.
	rejected := []struct {
		name     string
		token    func(t *testing.T) string
		msgRegex string
	}{
		{
			name:     "NoHeader",
			token:    func(*testing.T) string { return "" },
			msgRegex: "missing Authorization header",
		},
		{
			name:     "NonJWT",
			token:    func(*testing.T) string { return "lolwut" },
			msgRegex: "invalid section count",
		},
		{
			name: "AlmostJWT",
			token: func(t *testing.T) string {
				full := makeJWT(t, nil, nil)
				// Base64 decode accepts missing "=" and "==" padding, so
				// three characters are chopped off instead of one or two.
				return full[:len(full)-3]
			},
			msgRegex: "invalid base64 value",
		},
		{
			name: "BadJWTHeader",
			token: func(t *testing.T) string {
				return makeJWT(t, StringMap{"typ": "foo"}, nil)
			},
			msgRegex: "invalid header",
		},
		{
			name: "WrongAlgo",
			token: func(t *testing.T) string {
				return makeJWT(t, StringMap{"alg": "HS512"}, nil)
			},
			msgRegex: "invalid header",
		},
		{
			name: "BadJWTSignature",
			token: func(t *testing.T) string {
				valid := makeJWT(t, nil, nil)
				// Swap the final character for a different one ('A', or
				// 'B' when it already is 'A').
				// NOTE(review): only the last base64 character is altered;
				// if the decoder ignores unused trailing bits this may not
				// change the decoded signature for every token — confirm.
				end := len(valid) - 1
				swapped := byte('A')
				if valid[end] == 'A' {
					swapped = 'B'
				}
				return valid[:end] + string(swapped)
			},
			msgRegex: "invalid signature",
		},
		{
			name: "Expired",
			token: func(t *testing.T) string {
				return makeJWT(t, nil, StringMap{"exp": 1})
			},
			msgRegex: "token expired",
		},
	}
	for _, tc := range rejected {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			assertJWTError(t, e, tc.token(t), tc.msgRegex)
		})
	}

	t.Run("OK", func(t *testing.T) {
		auth := map[string]string{"Authorization": "Bearer " + makeJWT(t, nil, nil)}
		status, respBody, respHeaders := request(t, e, "GET", "/hello", "", auth)
		defer dumpResponseIfFailed(t, status, respBody, respHeaders)
		if status != 200 {
			t.Fatalf("unexpected failure: HTTP %d\nbody: %s", status, respBody)
		}

		defaults := defaultJWTClaims()
		// The expected `channel_id` is a `float64`: the actual values come
		// from decoding JSON into a `map[string]interface{}`, which always
		// parses numbers as `float64`.
		want := map[string]interface{}{
			"jwt":        true,
			"channel_id": util.AtofT(t, defaults.ChannelID),
			"opaque_id":  defaults.OpaqueID,
			"role":       defaults.Role,
		}
		got := util.MapFromJSON(t, respBody)
		util.AssertEqual(t, got, want)
	})

	// OPTIONS is sent with no Authorization header and must still succeed.
	t.Run("OPTIONS", func(t *testing.T) {
		status, respBody, _ := request(t, e, "OPTIONS", "/hello", "", nil)
		if status != 200 {
			t.Fatalf("unexpected failure: HTTP %d\nbody: %s", status, respBody)
		}
		if respBody != "null" {
			t.Fatalf("unexpected response: %s", respBody)
		}
	})
}
