package main

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"testing"
	"time"

	"code.justin.tv/video/lvsapi/internal/auth"
	"code.justin.tv/video/lvsapi/internal/logging"
	"code.justin.tv/video/lvsapi/rpc/lvs"
	"github.com/pkg/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/twitchtv/twirp"
)

// TestHttpServer starts the plain-HTTP server with strict auth disabled and
// checks that the health endpoint carries the security headers and that the
// LVS Twirp service accepts the force-auth header.
func TestHttpServer(t *testing.T) {
	require.NoError(t, flag.Set("strict-auth", "false"))
	require.NoError(t, flag.Set("twitch-env", "dev"))
	handler := configureHandler(logging.Silent(), nil)
	handler = auth.AddSecurityHeader(handler)
	srv := configureServer(handler)
	ln, err := net.Listen("tcp", ":0")
	require.NoError(t, err)

	done := make(chan error, 1)
	go func() {
		done <- srv.Serve(ln)
	}()
	// Deferred so the listener and the Serve goroutine are torn down even when
	// a require assertion below aborts the test early.
	defer func() {
		_ = srv.Close()
		<-done
	}()

	port := ln.Addr().(*net.TCPAddr).Port
	// query health endpoint
	resp, err := http.Get(fmt.Sprintf("http://localhost:%d/debug/health", port))
	require.NoError(t, err)

	// Drain and close the body first so a failing header assertion below
	// cannot leave it open.
	_, err = io.Copy(ioutil.Discard, resp.Body)
	require.NoError(t, err)
	require.NoError(t, resp.Body.Close())

	// verify security headers
	require.Equal(t, "default-src: 'none'", resp.Header.Get("Content-Security-Policy"))
	require.Equal(t, "DENY", resp.Header.Get("X-Frame-Options"))

	// make an auth check query against the lvs api to make sure it's installed
	lvsClient := lvs.NewLiveVideoServiceProtobufClient(fmt.Sprintf("http://localhost:%d", port), http.DefaultClient)

	headers := make(http.Header)
	headers.Set(auth.ForceAuthHeader, "test-customer")
	ctx, err := twirp.WithHTTPRequestHeaders(context.Background(), headers)
	require.NoError(t, err)

	twResp, err := lvsClient.CheckAuth(ctx, &lvs.CheckAuthRequest{})
	require.NoError(t, err)
	// testify convention: expected value first, actual second.
	assert.Equal(t, "test-customer", twResp.LvsCustomerId)
}

func TestHttpStrictAuth(t *testing.T) {
	// set strict auth to its default value, which should cause CheckAuth to error
	flag.Set("strict-auth", flag.Lookup("strict-auth").DefValue)
	handler := configureHandler(logging.Silent(), nil)
	handler = auth.AddSecurityHeader(handler)
	srv := configureServer(handler)
	ln, err := net.Listen("tcp", ":0")
	require.NoError(t, err)

	done := make(chan error, 1)
	go func() {
		done <- srv.Serve(ln)
	}()

	port := ln.Addr().(*net.TCPAddr).Port

	// query health endpoint
	resp, err := http.Get(fmt.Sprintf("http://localhost:%d/debug/health", port))
	require.NoError(t, err)

	// verify security headers
	require.Equal(t, "default-src: 'none'", resp.Header.Get("Content-Security-Policy"))
	require.Equal(t, "DENY", resp.Header.Get("X-Frame-Options"))

	_, err = io.Copy(ioutil.Discard, resp.Body)
	require.NoError(t, err)
	err = resp.Body.Close()
	require.NoError(t, err)

	// make an auth check query against the lvs api to make sure it's installed
	lvsClient := lvs.NewLiveVideoServiceProtobufClient(fmt.Sprintf("http://localhost:%d", port), http.DefaultClient)

	headers := make(http.Header)
	headers.Set(auth.ForceAuthHeader, "test-customer")
	ctx, err := twirp.WithHTTPRequestHeaders(context.Background(), headers)
	require.NoError(t, err)

	_, err = lvsClient.CheckAuth(ctx, &lvs.CheckAuthRequest{})
	assert.Error(t, err)

	_ = srv.Close()
	<-done
}

// TestHttpsServer exercises the mutual-TLS server end to end: a client that
// presents the testdata client certificate must reach the health endpoint and
// receive the security headers, after which the server shuts down gracefully.
// It is currently skipped because the checked-in certificates have expired.
func TestHttpsServer(t *testing.T) {
	t.Skip("Skipping, as the test certificates have expired. TODO: Regenerate them")

	// Point the TLS configuration at the checked-in fixtures.
	fixtures := [][2]string{
		{"client-cert", "testdata/client-cert.pem"},
		{"server-cert", "testdata/server-cert.pem"},
		{"server-key", "testdata/server-key.pem"},
	}
	for _, kv := range fixtures {
		require.NoError(t, flag.Set(kv[0], kv[1]))
	}

	srv, err := configureTLSServer(auth.AddSecurityHeader(configureHandler(logging.Silent(), nil)))
	require.NoError(t, err)

	rawLn, err := net.Listen("tcp", ":0")
	require.NoError(t, err)
	secureLn := tls.NewListener(rawLn, srv.TLSConfig)

	serveErr := make(chan error, 1)
	go func() { serveErr <- srv.Serve(secureLn) }()

	mtlsClient, err := httpsClient("testdata/client-cert.pem", "testdata/client-key.pem", "testdata/server-cert.pem")
	require.NoError(t, err)

	healthURL := fmt.Sprintf("https://localhost:%d/debug/health", rawLn.Addr().(*net.TCPAddr).Port)
	resp, err := mtlsClient.Get(healthURL)
	require.NoError(t, err)

	// The security middleware must decorate the response.
	require.Equal(t, "default-src: 'none'", resp.Header.Get("Content-Security-Policy"))
	require.Equal(t, "DENY", resp.Header.Get("X-Frame-Options"))

	_, err = io.Copy(ioutil.Discard, resp.Body)
	require.NoError(t, err)
	require.NoError(t, resp.Body.Close())

	shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	require.NoError(t, srv.Shutdown(shutdownCtx))

	<-serveErr
}

// TestHttpsServerNoCertificate connects to the TLS server without presenting a
// client certificate. The health endpoint is still expected to answer with
// 200 and the security headers, but the LVS Twirp API call must fail.
func TestHttpsServerNoCertificate(t *testing.T) {
	require.NoError(t, flag.Set("client-cert", "testdata/client-cert.pem"))
	require.NoError(t, flag.Set("server-cert", "testdata/server-cert.pem"))
	require.NoError(t, flag.Set("server-key", "testdata/server-key.pem"))

	handler := configureHandler(logging.Silent(), nil)
	handler = auth.AddSecurityHeader(handler)
	srv, err := configureTLSServer(handler)
	require.NoError(t, err)

	ln, err := net.Listen("tcp", ":0")
	require.NoError(t, err)

	tlsLn := tls.NewListener(ln, srv.TLSConfig)

	done := make(chan error, 1)
	go func() {
		done <- srv.Serve(tlsLn)
	}()
	// Deferred so the server is closed even if an assertion aborts early.
	defer func() {
		assert.NoError(t, srv.Close())
		<-done
	}()

	port := ln.Addr().(*net.TCPAddr).Port

	// InsecureSkipVerify is acceptable here: this test-only client talks to
	// the local test server and deliberately presents no client certificate.
	client := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: true,
			},
		},
	}
	// query health endpoint
	resp, err := client.Get(fmt.Sprintf("https://localhost:%d/debug/health", port))
	require.NoError(t, err)

	// Drain and close the body so the connection is not leaked.
	_, err = io.Copy(ioutil.Discard, resp.Body)
	require.NoError(t, err)
	require.NoError(t, resp.Body.Close())

	assert.Equal(t, http.StatusOK, resp.StatusCode)

	// verify security headers
	require.Equal(t, "default-src: 'none'", resp.Header.Get("Content-Security-Policy"))
	require.Equal(t, "DENY", resp.Header.Get("X-Frame-Options"))

	// Without a client certificate the LVS API call must be rejected.
	lvsClient := lvs.NewLiveVideoServiceProtobufClient(fmt.Sprintf("https://localhost:%d/", port), client)
	_, err = lvsClient.CheckAuth(context.Background(), &lvs.CheckAuthRequest{})
	require.Error(t, err)
}

// TestHttpsServerBadCertificate checks that a client presenting the invalid
// testdata certificate cannot complete a request against the mutual-TLS
// server. It is currently skipped because the checked-in certificates have
// expired.
func TestHttpsServerBadCertificate(t *testing.T) {
	t.Skip("Skipping, as the test certificates have expired. TODO: Regenerate them")

	// Point the TLS configuration at the checked-in fixtures.
	fixtures := [][2]string{
		{"client-cert", "testdata/client-cert.pem"},
		{"server-cert", "testdata/server-cert.pem"},
		{"server-key", "testdata/server-key.pem"},
	}
	for _, kv := range fixtures {
		require.NoError(t, flag.Set(kv[0], kv[1]))
	}

	srv, err := configureTLSServer(auth.AddSecurityHeader(configureHandler(logging.Silent(), nil)))
	require.NoError(t, err)

	rawLn, err := net.Listen("tcp", ":0")
	require.NoError(t, err)
	secureLn := tls.NewListener(rawLn, srv.TLSConfig)

	serveErr := make(chan error, 1)
	go func() { serveErr <- srv.Serve(secureLn) }()

	badClient, err := httpsClient("testdata/invalid-client-cert.pem", "testdata/invalid-client-key.pem", "testdata/server-cert.pem")
	require.NoError(t, err)

	// The handshake or request must fail for an untrusted client certificate.
	healthURL := fmt.Sprintf("https://localhost:%d/debug/health", rawLn.Addr().(*net.TCPAddr).Port)
	_, err = badClient.Get(healthURL)
	require.Error(t, err)

	shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	require.NoError(t, srv.Shutdown(shutdownCtx))

	<-serveErr
}

// httpsClient builds an HTTP client for mutual-TLS tests.
//
// clientCert and clientKey are PEM files holding the key pair the client
// presents. serverCert is a PEM file whose certificates become the client's
// only trusted roots.
//
// It returns an error if the key pair cannot be loaded, or if the server
// certificate file cannot be read or contains no PEM certificates.
func httpsClient(clientCert, clientKey, serverCert string) (*http.Client, error) {
	// Load client cert
	cert, err := tls.LoadX509KeyPair(clientCert, clientKey)
	if err != nil {
		return nil, errors.Wrap(err, "failed to load client keypair")
	}

	// Load the server certificate and use it as the sole trusted root.
	caCert, err := ioutil.ReadFile(serverCert)
	if err != nil {
		return nil, errors.Wrap(err, "failed to load server cert")
	}
	caCertPool := x509.NewCertPool()
	if !caCertPool.AppendCertsFromPEM(caCert) {
		// Without this check, a malformed file yields an empty pool. The
		// failure would then only surface later as a less obvious
		// certificate-verification error during the handshake.
		return nil, errors.Errorf("failed to parse server cert %q: no PEM certificates found", serverCert)
	}

	// Set up the HTTPS client. The deprecated BuildNameToCertificate call is
	// omitted: NameToCertificate is only consulted on the server side.
	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{cert},
		RootCAs:      caCertPool,
	}

	return &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: tlsConfig,
		},
	}, nil
}
