package ldap

import (
	"encoding/base64"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
)

// createHeaders builds a header set for tests. A non-empty user is stored
// verbatim under HTTPHeaderUser. A non-nil groups slice is JSON-encoded,
// then base64-encoded with the standard alphabet, and stored under
// HTTPHeaderGroups. An empty user or a nil groups slice leaves the
// corresponding header unset.
func createHeaders(user string, groups []string) http.Header {
	headers := http.Header{}
	if len(user) > 0 {
		headers.Set(HTTPHeaderUser, user)
	}
	if groups == nil {
		return headers
	}
	// Marshalling a []string cannot fail, so discarding the error is safe.
	raw, _ := json.Marshal(groups)
	headers.Set(HTTPHeaderGroups, base64.StdEncoding.EncodeToString(raw))
	return headers
}

// TestMiddleware drives Middleware with a table of user/groups header
// combinations. Each case records whether the wrapped handler was reached.
// Only cases with expectError == false should reach it, and when the handler
// runs, the request context must carry a non-empty user and groups value.
// Each case runs as a named subtest so a failure identifies its row.
func TestMiddleware(t *testing.T) {
	// groupsB64 is base64 (standard, padded) of `["qwe", "asd", "zxc"]`
	// followed by a newline.
	const groupsB64 = "WyJxd2UiLCAiYXNkIiwgInp4YyJdCg=="

	tests := []struct {
		name         string
		expectError  bool
		userHeader   string
		groupsHeader string
	}{
		{"HappyPath", false, "cooluser", groupsB64},
		{"NoUser", true, "", groupsB64},
		{"UnpaddedBase64", true, "amazingperson", "WyJxd2UiLCAiYXNkIiwgInp4YyJdCg"}, // groupsB64 without "=="
		{"NoGroups", true, "splendidhuman", ""},
		{"InvalidBase64", true, "exquisitebeing", "!!!!!!!"},
		{"InvalidJSON", true, "exceptionalmeatsack", "deadbeefdeadbeefdeadbeef"}, // valid base64, not JSON
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// httptest.NewRequest yields a fully-formed request (method, URL,
			// host) rather than a bare struct with nil fields.
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.Header.Set(HTTPHeaderUser, tc.userHeader)
			req.Header.Set(HTTPHeaderGroups, tc.groupsHeader)

			called := false
			h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				called = true
				assert.NotEmpty(t, User(r.Context()))
				assert.NotEmpty(t, Groups(r.Context()))
			})
			Middleware(h).ServeHTTP(httptest.NewRecorder(), req)

			assert.Equal(t, tc.expectError, !called, "handler reached: %v", called)
		})
	}
}

func TestExtractUser(t *testing.T) {
	t.Run("HappyPath", func(t *testing.T) {
		user, err := extractUser(createHeaders("cooluser", []string{"group1", "group2", "group3"}))
		assert.NoError(t, err)
		assert.Equal(t, "cooluser", user)
	})
	t.Run("NoUserHeader", func(t *testing.T) {
		_, err := extractUser(createHeaders("", nil)) // wont set user header
		assert.Error(t, err)
		assert.Equal(t, ErrUnknownUser, err)
	})
	t.Run("EmptyUserHeader", func(t *testing.T) {
		h := make(http.Header)
		h.Set(HTTPHeaderUser, "")
		_, err := extractUser(h)
		assert.Error(t, err)
		assert.Equal(t, ErrUnknownUser, err)
	})

}

// TestExtractGroup checks extractGroups in four situations:
//   - It decodes the base64-encoded JSON array carried in the groups header.
//   - It reports exactly ErrUnknownGroups when the header is missing or empty.
//   - It reports exactly ErrMalformattedGroups when the value is not valid
//     base64.
//   - It reports exactly ErrMalformattedGroups when the value decodes to
//     invalid JSON.
func TestExtractGroup(t *testing.T) {
	t.Run("HappyPath", func(t *testing.T) {
		want := []string{"group1", "group2", "group3"}
		groups, err := extractGroups(createHeaders("cooluser", want))
		assert.NoError(t, err)
		// Guard the indexing below: a short result must fail the test
		// cleanly rather than panic with an index out of range.
		if assert.Len(t, groups, len(want)) {
			for i, g := range want {
				assert.Equal(t, g, groups[i])
			}
		}
	})
	t.Run("NoGroupsHeader", func(t *testing.T) {
		_, err := extractGroups(createHeaders("cooluser", nil)) // nil groups leaves the header unset
		assert.Error(t, err)
		assert.Equal(t, ErrUnknownGroups, err)
	})
	t.Run("EmptyGroupsHeader", func(t *testing.T) {
		h := make(http.Header)
		h.Set(HTTPHeaderGroups, "")
		_, err := extractGroups(h)
		assert.Error(t, err)
		assert.Equal(t, ErrUnknownGroups, err)
	})
	t.Run("MalformedGroups", func(t *testing.T) {
		t.Run("InvalidBase64", func(t *testing.T) {
			h := make(http.Header)
			h.Set(HTTPHeaderGroups, "!!!!!")
			_, err := extractGroups(h)
			assert.Error(t, err)
			assert.Equal(t, ErrMalformattedGroups, err)
		})
		t.Run("InvalidJSON", func(t *testing.T) {
			// Valid base64 wrapping a JSON array that is missing its closing bracket.
			badJSON := []byte("[\"almost\",\"valid\",\"json\"")
			h := make(http.Header)
			h.Set(HTTPHeaderGroups, base64.StdEncoding.EncodeToString(badJSON))
			_, err := extractGroups(h)
			assert.Error(t, err)
			assert.Equal(t, ErrMalformattedGroups, err)
		})
	})
}
