package auth

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	goji "goji.io"
	"goji.io/pat"
)

// newTLSTestServer starts an httptest server that serves handler over TLS.
//
// The server presents the key pair in testdata/server-cert.pem and
// testdata/server-key.pem. It requires every client to present a certificate
// that verifies against testdata/client-root-cert.pem.
//
// Any failure to load these fixtures aborts the test binary via log.Fatal.
func newTLSTestServer(handler http.Handler) *httptest.Server {
	server := httptest.NewUnstartedServer(handler)

	serverCert, err := tls.LoadX509KeyPair("testdata/server-cert.pem", "testdata/server-key.pem")
	if err != nil {
		log.Fatal(err)
	}

	rootPEM, err := ioutil.ReadFile("testdata/client-root-cert.pem")
	if err != nil {
		log.Fatal(err)
	}

	trustedClientRoots := x509.NewCertPool()
	if ok := trustedClientRoots.AppendCertsFromPEM(rootPEM); !ok {
		log.Fatal("Invalid client certificate")
	}

	server.TLS = &tls.Config{
		Certificates: []tls.Certificate{serverCert},
		ClientCAs:    trustedClientRoots,
		ClientAuth:   tls.RequireAndVerifyClientCert,
	}
	server.StartTLS()

	return server
}

// newTLSTestClient returns an HTTPS client whose only trusted root is the
// certificate in testdata/server-cert.pem.
//
// When cert or key is non-empty, the key pair at those paths is loaded and
// presented as the client certificate. Passing only one of the two makes
// tls.LoadX509KeyPair fail.
//
// Any failure to load the fixtures aborts the test binary via log.Fatal.
func newTLSTestClient(cert, key string) *http.Client {
	var certificates []tls.Certificate

	if cert != "" || key != "" {
		// Load the client key pair. It gets its own name so the cert path
		// parameter is not shadowed.
		clientCert, err := tls.LoadX509KeyPair(cert, key)
		if err != nil {
			log.Fatal("Failed to load client key pair: ", err)
		}
		certificates = append(certificates, clientCert)
	}

	// Trust the test server's certificate itself as the root CA.
	caCert, err := ioutil.ReadFile("testdata/server-cert.pem")
	if err != nil {
		log.Fatal("Failed to read server cert: ", err)
	}
	caCertPool := x509.NewCertPool()
	if !caCertPool.AppendCertsFromPEM(caCert) {
		log.Fatal("Invalid server certificate")
	}

	// BuildNameToCertificate is intentionally not called. It is deprecated,
	// and the NameToCertificate map it fills is only consulted by servers.
	tlsConfig := &tls.Config{
		Certificates: certificates,
		RootCAs:      caCertPool,
	}
	transport := &http.Transport{TLSClientConfig: tlsConfig}
	return &http.Client{Transport: transport}
}

// newValidTLSTestClient returns a TLS test client that presents the
// lvs-customer-a key pair from testdata. That certificate is expected to carry
// the CN "lvs-customer-a".
func newValidTLSTestClient() *http.Client {
	const (
		certFile = "testdata/lvs-customer-a-cert.pem"
		keyFile  = "testdata/lvs-customer-a-key.pem"
	)
	return newTLSTestClient(certFile, keyFile)
}

// TestFromCertificate checks that the FromCertificate middleware derives the
// customer ID from a verified client certificate and exposes it through
// LvsCustomerID.
//
// The client presents the lvs-customer-a fixture, whose CN is expected to be
// "lvs-customer-a".
func TestFromCertificate(t *testing.T) {
	t.Skip("Skipping, as the test certificates have expired. TODO: Regenerate them")
	mux := goji.NewMux()
	mux.Use(FromCertificate)
	mux.Handle(pat.Get("/test"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		customerID, ok := LvsCustomerID(r.Context())
		if !ok {
			http.NotFound(w, r)
			return
		}

		_, _ = w.Write([]byte(customerID))
	}))

	server := newTLSTestServer(mux)
	// Close the server so its listener and goroutines do not outlive the test.
	defer server.Close()
	client := newValidTLSTestClient()

	// This dials "localhost" rather than server.URL (which uses 127.0.0.1).
	// Presumably the server certificate is issued for localhost; TODO confirm.
	port := server.Listener.Addr().(*net.TCPAddr).Port
	resp, err := client.Get(fmt.Sprintf("https://localhost:%d/test", port))
	require.NoError(t, err)
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	assert.NoError(t, err)

	// testify takes (expected, actual); keep that order so failures are
	// reported with the right labels.
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Equal(t, "lvs-customer-a", string(respBody))
}

// TestFromHeaders checks that the FromHeader middleware reads the customer ID
// from ForceAuthHeader. It also checks that the handler sees, via
// LvsCustomerID, exactly the value that was sent.
func TestFromHeaders(t *testing.T) {
	testCustomerID := "this-is-a-test"

	mux := goji.NewMux()
	mux.Use(FromHeader)
	mux.Handle(pat.Get("/test"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		customerID, ok := LvsCustomerID(r.Context())
		if !ok {
			http.NotFound(w, r)
			return
		}

		_, _ = w.Write([]byte(customerID))
	}))
	server := httptest.NewServer(mux)
	defer server.Close()

	req, err := http.NewRequest(http.MethodGet, server.URL+"/test", nil)
	require.NoError(t, err)

	req.Header.Set(ForceAuthHeader, testCustomerID)

	resp, err := http.DefaultClient.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	assert.NoError(t, err)

	// testify takes (expected, actual).
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Equal(t, testCustomerID, string(respBody))
}

// TestRequireValid exercises the FromHeader followed by RequireValid chain.
// A request carrying a valid customer ID in ForceAuthHeader must reach the
// handler and echo that ID. Requests with an empty ID, or with no
// ForceAuthHeader at all, must be answered with 404 Not Found.
//
// NOTE(review): the handler itself also returns 404 when no customer ID is in
// the context. This test therefore does not distinguish a rejection by
// RequireValid from one by the handler.
func TestRequireValid(t *testing.T) {
	validCustomerID := "valid-id"
	invalidCustomerID := ""

	mux := goji.NewMux()
	mux.Use(FromHeader)
	mux.Use(RequireValid)
	mux.Handle(pat.Get("/test"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		customerID, ok := LvsCustomerID(r.Context())
		if !ok {
			http.NotFound(w, r)
			return
		}

		_, _ = w.Write([]byte(customerID))
	}))
	server := httptest.NewServer(mux)
	defer server.Close()

	// first test a valid id
	req, err := http.NewRequest(http.MethodGet, server.URL+"/test", nil)
	require.NoError(t, err)

	req.Header.Set(ForceAuthHeader, validCustomerID)

	resp, err := http.DefaultClient.Do(req)
	require.NoError(t, err)

	respBody, err := ioutil.ReadAll(resp.Body)
	assert.NoError(t, err)
	assert.NoError(t, resp.Body.Close())

	// testify takes (expected, actual).
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Equal(t, validCustomerID, string(respBody))

	// test an invalid id
	req, err = http.NewRequest(http.MethodGet, server.URL+"/test", nil)
	require.NoError(t, err)

	req.Header.Set(ForceAuthHeader, invalidCustomerID)

	resp, err = http.DefaultClient.Do(req)
	require.NoError(t, err)
	assert.Equal(t, http.StatusNotFound, resp.StatusCode)
	assert.NoError(t, resp.Body.Close())

	// test no id
	resp, err = http.Get(server.URL + "/test")
	require.NoError(t, err)
	assert.Equal(t, http.StatusNotFound, resp.StatusCode)
	assert.NoError(t, resp.Body.Close())
}

// TestIsValidCustomerID checks IsValidCustomerID against a fixed set of IDs
// expected to be accepted and a set expected to be rejected. Each assertion
// names the offending ID so a failure pinpoints which input was misclassified.
func TestIsValidCustomerID(t *testing.T) {
	valid := []string{
		"this-is_valid",
		"0fa0f2ef09v",
		"0123456789012345678901234567890123456789",
	}

	invalid := []string{
		"?fasldf?",
		"/asdfl/asdf/",
	}

	for _, id := range valid {
		assert.True(t, IsValidCustomerID(id), "expected %q to be a valid customer ID", id)
	}

	for _, id := range invalid {
		assert.False(t, IsValidCustomerID(id), "expected %q to be an invalid customer ID", id)
	}
}
