package middleware

import (
	"context"
	"errors"
	"net/http"

	"github.com/twitchtv/twirp"

	owl "code.justin.tv/web/owl/client"

	"code.justin.tv/devrel/devsite-rbac/backend/viennauserwhitelist"
	"code.justin.tv/devrel/devsite-rbac/clients/owlcli"
	"code.justin.tv/devrel/devsite-rbac/internal/auth"
	"code.justin.tv/devrel/devsite-rbac/internal/errorutil"
)

// authorizationHeader is the request header whose raw value AuthHeadersMiddleware
// copies into the context.
const authorizationHeader = "Authorization"

// cartmanTokenHeader is the request header whose value AuthHeadersMiddleware
// stores in the context as the Cartman token.
const cartmanTokenHeader = "Twitch-Authorization"

// AuthHeadersMiddleware adds the 'Authorization' and 'Twitch-Authorization' headers into the context,
// so they can be used by Twirp methods and the authorization hook.
func AuthHeadersMiddleware(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		rawToken := r.Header.Get(authorizationHeader)
		cartmanToken := r.Header.Get(cartmanTokenHeader)

		ctx := r.Context()
		ctx = auth.WithRawAuthorizationToken(ctx, rawToken)
		ctx = auth.WithCartmanToken(ctx, cartmanToken)
		h.ServeHTTP(w, r.WithContext(ctx))
	})
}

// OwlHooks returns Twirp server hooks that authenticate the user in the request
// based on the Authorization header value (available in the ctx via
// auth.GetAuthorizationToken). On success the token owner's Twitch ID is stored
// in the context and can be fetched with auth.GetTwitchID(ctx).
//
// Requests without a token continue as anonymous requests; endpoints that
// require authentication must check the context themselves.
//
// Failures are reported as:
//   - twirp.PermissionDenied "invalid client" / "forbidden client" when
//     Validate fails with owl.ErrInvalidClientID / owl.ErrForbiddenClientID
//     (also when those sentinels are wrapped);
//   - the Validate error itself, unchanged, for any other failure;
//   - twirp.PermissionDenied "token invalid" / "token has expired" when the
//     validated token is not valid or has ExpiresIn == 0.
//
// This logic is done in a Twirp hook instead of directly in the HTTP middleware
// so it properly triggers error hooks if the operation fails.
func OwlHooks(owlValidator owlcli.Client) *twirp.ServerHooks {
	hooks := &twirp.ServerHooks{}
	hooks.RequestReceived = func(ctx context.Context) (context.Context, error) {
		oauthToken := auth.GetAuthorizationToken(ctx)
		if oauthToken == "" {
			// No token: continue as an anonymous request.
			return ctx, nil
		}

		authx, err := owlValidator.Validate(ctx, oauthToken)
		// errors.Is (rather than ==) so sentinels wrapped by the owlcli
		// client are still mapped to PermissionDenied.
		switch {
		case errors.Is(err, owl.ErrInvalidClientID):
			return ctx, twirp.NewError(twirp.PermissionDenied, "invalid client")
		case errors.Is(err, owl.ErrForbiddenClientID):
			return ctx, twirp.NewError(twirp.PermissionDenied, "forbidden client")
		case err != nil:
			return ctx, err
		}

		if !authx.Valid {
			return ctx, twirp.NewError(twirp.PermissionDenied, "token invalid")
		}

		if authx.ExpiresIn == 0 {
			return ctx, twirp.NewError(twirp.PermissionDenied, "token has expired")
		}

		return auth.WithTwitchID(ctx, authx.OwnerID), nil
	}

	return hooks
}

// ViennaAuthWhitelistHooks returns Twirp server hooks that restrict requests
// identified by auth.IsRequestFromVienna to whitelisted users. All other
// requests pass through unchanged.
//
// For a Vienna request, the Twitch ID from auth.GetTwitchID(ctx) is looked up
// with viennaAdminBackend.GetWhitelistedUser:
//   - no matching row -> twirp.PermissionDenied ("User is not authorized");
//   - any other lookup error -> returned unchanged;
//   - found -> the user's whitelist role is stored in the context with
//     auth.WithViennaWhitelistUserRole.
//
// NOTE(review): this reads the Twitch ID from the context, which OwlHooks
// sets; presumably these hooks must be chained after OwlHooks — confirm at the
// call site.
func ViennaAuthWhitelistHooks(viennaAdminBackend viennauserwhitelist.UserWhitelist) *twirp.ServerHooks {
	checkWhitelist := func(ctx context.Context) (context.Context, error) {
		if !auth.IsRequestFromVienna(ctx) {
			// The whitelist only governs Vienna requests.
			return ctx, nil
		}

		entry, lookupErr := viennaAdminBackend.GetWhitelistedUser(ctx, auth.GetTwitchID(ctx))
		switch {
		case errorutil.IsErrNoRows(lookupErr):
			return ctx, twirp.NewError(twirp.PermissionDenied, "User is not authorized")
		case lookupErr != nil:
			return ctx, lookupErr
		}

		return auth.WithViennaWhitelistUserRole(ctx, entry.Role), nil
	}

	return &twirp.ServerHooks{RequestReceived: checkWhitelist}
}
