package extjwt

import (
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"
	"time"

	"code.justin.tv/devhub/e2ml/libs/errors"
	httpx "code.justin.tv/devhub/e2ml/libs/http"
	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/metrics/devnull"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"code.justin.tv/devhub/e2ml/libs/stream/protocol"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// testHandler is a stub HTTP endpoint used in place of the remote validation
// service. Its responses are fully determined by the fields below.
type testHandler struct {
	// healthCheck is the status written for any request whose path contains
	// "debug/running"; statusCode is the status written for every other request.
	healthCheck, statusCode int
	// wait is how long non-health-check requests sleep after the status is
	// written and before the body is written.
	wait time.Duration
}

// ServeHTTP answers health checks (any path containing "debug/running") with
// healthCheck. Every other request is answered with statusCode, followed after
// a pause of wait by a body. The body is the marshalled generic error for
// statusCode when httpx knows one and it marshals cleanly; otherwise it is "{}".
func (h *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if strings.Contains(r.URL.Path, "debug/running") {
		w.WriteHeader(h.healthCheck)
		return
	}

	w.WriteHeader(h.statusCode)
	time.Sleep(h.wait)

	// Default to an empty JSON object; replace it only when a generic error
	// exists for this status and can be encoded.
	payload := []byte("{}")
	if genericErr, found := httpx.GenericErrorForHTTPStatus(h.statusCode); found {
		if encoded, marshalErr := errors.MarshalAny(genericErr); marshalErr == nil {
			payload = encoded
		}
	}
	w.Write(payload)
}

// TestRemoteValidator exercises the validator returned by NewRemoteValidator
// against local httptest servers standing in for the remote endpoint. Each
// server is configured with the status it returns for its health check
// ("debug/running") and for validation requests.
func TestRemoteValidator(t *testing.T) {
	// Track every test server so all of them are shut down when the test
	// returns; otherwise their listeners and serve goroutines leak for the rest
	// of the test binary. Subtests run serially, so this defer fires after all
	// of them complete.
	var servers []*httptest.Server
	defer func() {
		for _, srv := range servers {
			srv.Close()
		}
	}()

	// serve starts a test server backed by h and returns its parsed base URL.
	serve := func(h *testHandler) *url.URL {
		srv := httptest.NewServer(h)
		servers = append(servers, srv)
		u, err := url.Parse(srv.URL)
		require.NoError(t, err)
		return u
	}

	okURL := serve(&testHandler{healthCheck: http.StatusOK, statusCode: http.StatusOK})
	unhealthyURL := serve(&testHandler{healthCheck: http.StatusServiceUnavailable, statusCode: http.StatusOK})
	noContentURL := serve(&testHandler{healthCheck: http.StatusOK, statusCode: http.StatusNoContent})
	forbiddenURL := serve(&testHandler{healthCheck: http.StatusOK, statusCode: http.StatusForbidden})
	slowURL := serve(&testHandler{healthCheck: http.StatusOK, statusCode: http.StatusOK, wait: time.Second})

	okExp := time.Now().Add(time.Hour).Unix()
	emptyClaims := &claims{}
	invalidChannelClaims := &claims{ChannelID: "suspect"}
	noVerbClaims := &claims{ChannelID: channelID, Expires: okExp}
	// Expires is deliberately left at its zero value.
	expiredClaims := &claims{
		ChannelID: channelID,
		Verbs:     map[string][]string{listenVerb: {"global"}},
	}
	okClaims := &claims{
		ChannelID: channelID,
		Verbs:     map[string][]string{listenVerb: {"global"}},
		Expires:   okExp,
	}
	okToken := testToken(okClaims)

	t.Run("should return error on missing claims", func(t *testing.T) {
		valid, err := NewRemoteValidator(okURL, http.Client{}, devnull.NewTracker(), logging.Noop)
		require.NoError(t, err)
		ok, err := valid(clientID, nil, nil)
		assert.False(t, ok)
		assert.Equal(t, stream.ErrMissingJWTClaims, err)
	})

	t.Run("should return unauthorized on missing client_id", func(t *testing.T) {
		valid, err := NewRemoteValidator(okURL, http.Client{}, devnull.NewTracker(), logging.Noop)
		require.NoError(t, err)
		ok, err := valid("", okToken, okClaims)
		assert.False(t, ok)
		assert.Equal(t, stream.ErrMissingClientID, err)
	})

	t.Run("should return unauthorized on invalid client_id (length)", func(t *testing.T) {
		valid, err := NewRemoteValidator(okURL, http.Client{}, devnull.NewTracker(), logging.Noop)
		require.NoError(t, err)
		ok, err := valid("this0is0a0very0long0client0id0that0is0obviously0not0valid", okToken, okClaims)
		assert.False(t, ok)
		assert.Equal(t, stream.ErrInvalidClientID, err)
	})

	t.Run("should return unauthorized on invalid client_id (content)", func(t *testing.T) {
		valid, err := NewRemoteValidator(okURL, http.Client{}, devnull.NewTracker(), logging.Noop)
		require.NoError(t, err)
		ok, err := valid("illegal?", okToken, okClaims)
		assert.False(t, ok)
		assert.Equal(t, stream.ErrInvalidClientID, err)
	})

	t.Run("should return unauthorized on missing channel", func(t *testing.T) {
		valid, err := NewRemoteValidator(okURL, http.Client{}, devnull.NewTracker(), logging.Noop)
		require.NoError(t, err)
		ok, err := valid(clientID, okToken, emptyClaims)
		assert.False(t, ok)
		assert.Equal(t, stream.ErrMissingChannelID, err)
	})

	t.Run("should return unauthorized on invalid channel", func(t *testing.T) {
		valid, err := NewRemoteValidator(okURL, http.Client{}, devnull.NewTracker(), logging.Noop)
		require.NoError(t, err)
		ok, err := valid(clientID, okToken, invalidChannelClaims)
		assert.False(t, ok)
		assert.Equal(t, stream.ErrInvalidChannelID, err)
	})

	t.Run("should return unauthorized on empty verbs", func(t *testing.T) {
		valid, err := NewRemoteValidator(okURL, http.Client{}, devnull.NewTracker(), logging.Noop)
		require.NoError(t, err)
		ok, err := valid(clientID, okToken, noVerbClaims)
		assert.False(t, ok)
		assert.Equal(t, stream.ErrMissingListenPermissions, err)
	})

	t.Run("should return unauthorized on expired token", func(t *testing.T) {
		valid, err := NewRemoteValidator(okURL, http.Client{}, devnull.NewTracker(), logging.Noop)
		require.NoError(t, err)
		ok, err := valid(clientID, okToken, expiredClaims)
		assert.False(t, ok)
		assert.Equal(t, stream.ErrAuthExpired, err)
	})

	t.Run("should return error on illegal URL", func(t *testing.T) {
		_, err := NewRemoteValidator(&url.URL{Host: "::"}, http.Client{}, devnull.NewTracker(), logging.Noop)
		assert.Error(t, err)
	})

	t.Run("should return error on failed healthcheck", func(t *testing.T) {
		_, err := NewRemoteValidator(unhealthyURL, http.Client{}, devnull.NewTracker(), logging.Noop)
		assert.Error(t, err)
	})

	t.Run("should return true on 200", func(t *testing.T) {
		valid, err := NewRemoteValidator(okURL, http.Client{}, devnull.NewTracker(), logging.Noop)
		require.NoError(t, err)
		ok, err := valid(clientID, okToken, okClaims)
		assert.True(t, ok)
		assert.NoError(t, err)
	})

	t.Run("should return true on 204", func(t *testing.T) {
		valid, err := NewRemoteValidator(noContentURL, http.Client{}, devnull.NewTracker(), logging.Noop)
		require.NoError(t, err)
		ok, err := valid(clientID, okToken, okClaims)
		assert.True(t, ok)
		assert.NoError(t, err)
	})

	t.Run("should return false on error response code", func(t *testing.T) {
		valid, err := NewRemoteValidator(forbiddenURL, http.Client{}, devnull.NewTracker(), logging.Noop)
		require.NoError(t, err)
		ok, err := valid(clientID, okToken, okClaims)
		assert.False(t, ok)
		assert.Equal(t, httpx.ErrForbidden, err) // mock server sending standard HTTP status code errors
	})

	t.Run("should return an error if unable to connect", func(t *testing.T) {
		// A 1ns client timeout makes the construction-time health check fail.
		valid, err := NewRemoteValidator(slowURL, http.Client{Timeout: time.Nanosecond}, devnull.NewTracker(), logging.Noop)
		assert.Nil(t, valid)
		assert.Error(t, err)
	})

	t.Run("should return false on timeout after connect", func(t *testing.T) {
		// Build the validator directly to bypass the construction-time health
		// check, so the timeout is hit by the validation request itself.
		client := http.Client{Timeout: time.Nanosecond}
		valid := remoteValidator{client, slowURL, devnull.NewTracker(), logging.Noop}.validate
		ok, err := valid(clientID, okToken, okClaims)
		assert.False(t, ok)
		// assert.Error(t, sentinel, err) only checked that the sentinel was
		// non-nil; compare the returned error against the sentinel instead.
		assert.Equal(t, protocol.ErrServiceUnavailable, err)
	})
}
