package cert

import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/json"
	"errors"
	"io/ioutil"
	"math/big"
	"math/rand"
	"net/http"
	"net/url"
	"strings"
	"testing"
	"time"

	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCallee/internal/cert/mocks"
	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCallee/internal/logutil"
	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCallee/internal/s2s2err"
	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCallee/internal/service"
	"code.justin.tv/video/metrics-middleware/v2/operation"
	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

//go:generate mockgen -package mocks -destination ./mocks/cert.go code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCallee/internal/cert HTTPClient

// TestCertificatesLoadRootCertificatePool drives Certificates.LoadRootCertificatePool
// through a mocked HTTPClient: a well-formed key set, a transport error, a 500
// response, and a key set whose "kid" field has the wrong JSON type.
func TestCertificatesLoadRootCertificatePool(t *testing.T) {
	const keyID = "keyID"
	const serviceDomain = "serviceDomain"
	const serviceName = "serviceName"
	const serviceStage = "serviceStage"

	ctx := context.Background()

	caPrivateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(69)))
	require.NoError(t, err)

	servicePrivateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(70)))
	require.NoError(t, err)

	// keyEntry / keySet describe the JSON body served by the mocked endpoint:
	// {"keys": [{"kid": "...", "x5c": [...]}]}.
	type keyEntry struct {
		KeyID                string   `json:"kid"`
		X509CertificateChain [][]byte `json:"x5c"`
	}
	type keySet struct {
		Keys []keyEntry `json:"keys"`
	}

	// malformedKeyEntry / malformedKeySet encode "kid" as an integer; the
	// "response wrong format" subtest expects this to be rejected with an
	// error mentioning ".keys".
	type malformedKeyEntry struct {
		KeyID int `json:"kid"`
	}
	type malformedKeySet struct {
		Keys []malformedKeyEntry `json:"keys"`
	}

	// response JSON-encodes bodyJSON into an *http.Response with statusCode.
	response := func(t *testing.T, statusCode int, bodyJSON interface{}) *http.Response {
		var body bytes.Buffer
		require.NoError(t, json.NewEncoder(&body).Encode(bodyJSON))
		return &http.Response{
			StatusCode: statusCode,
			Body:       ioutil.NopCloser(&body),
		}
	}

	t.Run("success", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)

		_, serviceX5C := test.ServiceCertificate(t, servicePrivateKey, caPrivateKey, serviceDomain, serviceName, serviceStage)

		test.MockHTTPClient.EXPECT().
			Do(gomock.Any()).
			Return(response(t, http.StatusOK, keySet{
				Keys: []keyEntry{{KeyID: keyID, X509CertificateChain: serviceX5C}},
			}), nil)

		require.NoError(t, test.LoadRootCertificatePool(ctx))
	})

	t.Run("Do error", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)

		myErr := errors.New("myerr")
		test.MockHTTPClient.EXPECT().
			Do(gomock.Any()).
			Return(nil, myErr)

		assert.Equal(t, myErr, test.LoadRootCertificatePool(ctx))
	})

	t.Run("500", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)

		test.MockHTTPClient.EXPECT().
			Do(gomock.Any()).
			Return(response(t, http.StatusInternalServerError, struct{}{}), nil)

		// require.Error first so a nil error fails the test instead of
		// panicking on err.Error().
		err := test.LoadRootCertificatePool(ctx)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "unknown error with status code<500>: {}")
	})

	t.Run("response wrong format", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)

		test.MockHTTPClient.EXPECT().
			Do(gomock.Any()).
			Return(response(t, http.StatusOK, malformedKeySet{
				Keys: []malformedKeyEntry{{KeyID: 69}},
			}), nil)

		err := test.LoadRootCertificatePool(ctx)
		require.Error(t, err)
		assert.Contains(t, err.Error(), ".keys")
	})
}

// TestCertificatesWhitelistService drives Certificates.WhitelistService through
// a mocked HTTPClient: a successful fetch that caches the service certificate,
// a transport error, 404 and 500 responses, a malformed key set, and a chain
// that does not verify against the root pool.
func TestCertificatesWhitelistService(t *testing.T) {
	const keyID = "keyID"
	const serviceDomain = "serviceDomain"
	const serviceName = "serviceName"
	const serviceStage = "serviceStage"

	ctx := context.Background()

	caPrivateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(69)))
	require.NoError(t, err)

	servicePrivateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(70)))
	require.NoError(t, err)

	// keyEntry / keySet describe the JSON body served by the mocked endpoint:
	// {"keys": [{"kid": "...", "x5c": [...]}]}.
	type keyEntry struct {
		KeyID                string   `json:"kid"`
		X509CertificateChain [][]byte `json:"x5c"`
	}
	type keySet struct {
		Keys []keyEntry `json:"keys"`
	}

	// malformedKeyEntry / malformedKeySet encode "kid" as an integer; the
	// "response wrong format" subtest expects this to be rejected with an
	// error mentioning ".keys".
	type malformedKeyEntry struct {
		KeyID int `json:"kid"`
	}
	type malformedKeySet struct {
		Keys []malformedKeyEntry `json:"keys"`
	}

	// response JSON-encodes bodyJSON into an *http.Response with statusCode.
	response := func(t *testing.T, statusCode int, bodyJSON interface{}) *http.Response {
		var body bytes.Buffer
		require.NoError(t, json.NewEncoder(&body).Encode(bodyJSON))
		return &http.Response{
			StatusCode: statusCode,
			Body:       ioutil.NopCloser(&body),
		}
	}

	t.Run("success", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)

		serviceCertificate, serviceX5C := test.ServiceCertificate(t, servicePrivateKey, caPrivateKey, serviceDomain, serviceName, serviceStage)

		test.MockHTTPClient.EXPECT().
			Do(gomock.Any()).
			Return(response(t, http.StatusOK, keySet{
				Keys: []keyEntry{{KeyID: keyID, X509CertificateChain: serviceX5C}},
			}), nil)

		require.NoError(t, test.WhitelistService(ctx, serviceName, serviceStage))

		// The fetched certificate is cached under its "kid".
		assert.Equal(t, serviceCertificate, test.certificates[keyID].Certificate)
	})

	t.Run("Do error", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)

		myErr := errors.New("myerr")
		test.MockHTTPClient.EXPECT().
			Do(gomock.Any()).
			Return(nil, myErr)

		assert.Equal(t, myErr, test.WhitelistService(ctx, serviceName, serviceStage))
	})

	t.Run("404", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)

		test.MockHTTPClient.EXPECT().
			Do(gomock.Any()).
			Return(response(t, http.StatusNotFound, struct{}{}), nil)

		// The URL is path-only, presumably because newCertificatesTest gives
		// New an empty base url.URL.
		assert.Equal(t, s2s2err.NewError(s2s2err.CodeServiceStageNotFound, &ErrServiceStageNotFound{
			Service: serviceName,
			Stage:   serviceStage,
			URL:     "/serviceName/serviceStage.json",
		}), test.WhitelistService(ctx, serviceName, serviceStage))
	})

	t.Run("500", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)

		test.MockHTTPClient.EXPECT().
			Do(gomock.Any()).
			Return(response(t, http.StatusInternalServerError, struct{}{}), nil)

		// require.Error first so a nil error fails the test instead of
		// panicking on err.Error().
		err := test.WhitelistService(ctx, serviceName, serviceStage)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "unknown error with status code<500>: {}")
	})

	t.Run("response wrong format", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)

		test.MockHTTPClient.EXPECT().
			Do(gomock.Any()).
			Return(response(t, http.StatusOK, malformedKeySet{
				Keys: []malformedKeyEntry{{KeyID: 69}},
			}), nil)

		err := test.WhitelistService(ctx, serviceName, serviceStage)
		require.Error(t, err)
		assert.Contains(t, err.Error(), ".keys")
	})

	t.Run("certificate not signed by root", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)

		// The leaf names the CA as issuer but is signed with the service's own
		// key, so it must not verify against the root pool.
		_, serviceX5C := test.ServiceCertificate(t, servicePrivateKey, servicePrivateKey, serviceDomain, serviceName, serviceStage)

		test.MockHTTPClient.EXPECT().
			Do(gomock.Any()).
			Return(response(t, http.StatusOK, keySet{
				Keys: []keyEntry{{KeyID: keyID, X509CertificateChain: serviceX5C}},
			}), nil)

		err := test.WhitelistService(ctx, serviceName, serviceStage)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "certificate signed by unknown authority")
	})
}

// TestCertificatesGet covers Certificates.Get against a pre-seeded cache:
// - a fresh hit;
// - hits old enough to need a foreground refresh, including refreshes whose HTTP calls fail;
// - a cache miss;
// - entries outside their validity window.
func TestCertificatesGet(t *testing.T) {
	const keyID = "keyID"
	const x5u = "x5u"
	const serviceDomain = "serviceDomain"
	const serviceName = "serviceName"
	const serviceStage = "serviceStage"

	// staleDelay is slept before Get in the "foreground refresh" subtests.
	// newCertificatesTest passes 1ms for both duration arguments of New, so
	// this is presumably long enough for the cached entry to be due for a
	// refresh. NOTE(review): confirm against New's parameter semantics.
	const staleDelay = 100 * time.Millisecond

	ctx := context.Background()

	expectedService := service.Service{
		Domain: serviceDomain,
		Name:   serviceName,
		Stage:  serviceStage,
	}

	caPrivateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(69)))
	require.NoError(t, err)

	servicePrivateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(70)))
	require.NoError(t, err)

	// keyEntry / keySet describe the JSON body served by the mocked endpoint:
	// {"keys": [{"kid": "...", "x5c": [...]}]}.
	type keyEntry struct {
		KeyID                string   `json:"kid"`
		X509CertificateChain [][]byte `json:"x5c"`
	}
	type keySet struct {
		Keys []keyEntry `json:"keys"`
	}

	// response JSON-encodes bodyJSON into an *http.Response with statusCode.
	response := func(t *testing.T, statusCode int, bodyJSON interface{}) *http.Response {
		var body bytes.Buffer
		require.NoError(t, json.NewEncoder(&body).Encode(bodyJSON))
		return &http.Response{
			StatusCode: statusCode,
			Body:       ioutil.NopCloser(&body),
		}
	}

	// expectDoAndReturnCert queues one successful HTTP round trip. The
	// response is a key set with a single entry (kid=keyID) whose chain is
	// cert. It takes the subtest's t so an encoding failure stops that
	// subtest rather than calling FailNow on the parent test.
	expectDoAndReturnCert := func(t *testing.T, test *certificatesTest, cert [][]byte) *gomock.Call {
		return test.MockHTTPClient.EXPECT().
			Do(gomock.Any()).
			Return(response(t, http.StatusOK, keySet{
				Keys: []keyEntry{{KeyID: keyID, X509CertificateChain: cert}},
			}), nil)
	}

	// seedCache stores an entry for x5u directly in test's cache, bypassing
	// HTTP. It uses the given validity window, the CA public key and
	// expectedService.
	seedCache := func(test *certificatesTest, notBefore, notAfter time.Time) {
		test.certificates[x5u] = &certificate{
			Certificate: &x509.Certificate{
				NotBefore: notBefore,
				NotAfter:  notAfter,
			},
			PublicKey: &caPrivateKey.PublicKey,
			Service:   &expectedService,
		}
	}

	t.Run("success", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)
		seedCache(test, time.Now().Add(-24*time.Hour), time.Now().Add(24*time.Hour))

		svc, pubkey, err := test.Get(ctx, x5u)
		require.NoError(t, err)
		assert.Equal(t, caPrivateKey.Public(), pubkey)
		assert.Equal(t, &expectedService, svc)
	})

	t.Run("foreground refresh required", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)
		seedCache(test, time.Now().Add(-24*time.Hour), time.Now().Add(24*time.Hour))

		_, serviceX5C := test.ServiceCertificate(t, servicePrivateKey, caPrivateKey, serviceDomain, serviceName, serviceStage)

		// Queued as: root pool reload, then the service's key set.
		expectDoAndReturnCert(t, test, [][]byte{test.AuthorityCertificate.Raw})
		expectDoAndReturnCert(t, test, serviceX5C)

		time.Sleep(staleDelay)

		svc, pubkey, err := test.Get(ctx, x5u)
		require.NoError(t, err)
		assert.Equal(t, caPrivateKey.Public(), pubkey)
		assert.Equal(t, &expectedService, svc)
	})

	t.Run("foreground refresh required but ca refresh fail", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)
		seedCache(test, time.Now().Add(-24*time.Hour), time.Now().Add(24*time.Hour))

		test.MockHTTPClient.EXPECT().
			Do(gomock.Any()).
			Return(nil, errors.New("myerr"))

		time.Sleep(staleDelay)

		// A failed refresh must still serve the cached entry.
		svc, pubkey, err := test.Get(ctx, x5u)
		require.NoError(t, err)
		assert.Equal(t, caPrivateKey.Public(), pubkey)
		assert.Equal(t, &expectedService, svc)
	})

	t.Run("foreground refresh required but service cert refresh fail", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)
		seedCache(test, time.Now().Add(-24*time.Hour), time.Now().Add(24*time.Hour))

		expectDoAndReturnCert(t, test, [][]byte{test.AuthorityCertificate.Raw})

		test.MockHTTPClient.EXPECT().
			Do(gomock.Any()).
			Return(nil, errors.New("myerr"))

		time.Sleep(staleDelay)

		// A failed refresh must still serve the cached entry.
		svc, pubkey, err := test.Get(ctx, x5u)
		require.NoError(t, err)
		assert.Equal(t, caPrivateKey.Public(), pubkey)
		assert.Equal(t, &expectedService, svc)
	})

	t.Run("not in cache", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)

		_, _, err := test.Get(ctx, x5u)
		assert.Equal(t, s2s2err.NewError(s2s2err.CodeX5UNotInCache, &ErrX5UNotInCache{X5U: x5u}), err)
	})

	t.Run("certificate expired", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)
		seedCache(test, time.Now().Add(-24*time.Hour), time.Now().Add(-1*time.Hour))

		// require.Error first so a nil error fails the test instead of
		// panicking on err.Error().
		_, _, err := test.Get(ctx, x5u)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "is only valid before")
	})

	t.Run("certificate not valid yet", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		test := newCertificatesTest(t, ctrl, caPrivateKey)
		seedCache(test, time.Now().Add(2*time.Hour), time.Now().Add(24*time.Hour))

		_, _, err := test.Get(ctx, x5u)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "is only valid after")
	})
}

// TestCertificatesRefresh checks that Certificates.Refresh reloads the root
// pool and the whitelisted service. It also checks that Refresh returns the
// error from whichever of those HTTP fetches fails.
func TestCertificatesRefresh(t *testing.T) {
	const (
		keyID         = "keyID"
		serviceDomain = "serviceDomain"
		serviceName   = "serviceName"
		serviceStage  = "serviceStage"
	)

	ctx := context.Background()

	caPrivateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(69)))
	require.NoError(t, err)
	servicePrivateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(70)))
	require.NoError(t, err)

	// keyEntry / keySet describe the JSON body served by the mocked endpoint:
	// {"keys": [{"kid": "...", "x5c": [...]}]}.
	type keyEntry struct {
		KeyID                string   `json:"kid"`
		X509CertificateChain [][]byte `json:"x5c"`
	}
	type keySet struct {
		Keys []keyEntry `json:"keys"`
	}

	// expectKeySet queues one HTTP 200 response. Its body is a single-entry
	// key set (kid=keyID) whose chain is chain.
	expectKeySet := func(t *testing.T, test *certificatesTest, chain [][]byte) {
		var body bytes.Buffer
		require.NoError(t, json.NewEncoder(&body).Encode(keySet{
			Keys: []keyEntry{{KeyID: keyID, X509CertificateChain: chain}},
		}))
		test.MockHTTPClient.EXPECT().
			Do(gomock.Any()).
			Return(&http.Response{StatusCode: http.StatusOK, Body: ioutil.NopCloser(&body)}, nil)
	}

	// expectTransportError queues one HTTP round trip that fails with failure.
	expectTransportError := func(test *certificatesTest, failure error) {
		test.MockHTTPClient.EXPECT().Do(gomock.Any()).Return(nil, failure)
	}

	// primeRootPool loads the test CA into the root pool via the HTTP path.
	primeRootPool := func(t *testing.T, test *certificatesTest) {
		expectKeySet(t, test, [][]byte{test.AuthorityCertificate.Raw})
		require.NoError(t, test.Certificates.LoadRootCertificatePool(ctx))
	}

	t.Run("success", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		test := newCertificatesTest(t, ctrl, caPrivateKey)

		primeRootPool(t, test)

		serviceCertificate, serviceX5C := test.ServiceCertificate(t, servicePrivateKey, caPrivateKey, serviceDomain, serviceName, serviceStage)
		expectKeySet(t, test, serviceX5C)
		require.NoError(t, test.WhitelistService(ctx, serviceName, serviceStage))
		assert.Equal(t, serviceCertificate, test.certificates[keyID].Certificate)

		// Both expectations match any request. The test relies on them being
		// served in the order queued: root pool first, then the service.
		expectKeySet(t, test, [][]byte{test.AuthorityCertificate.Raw})
		expectKeySet(t, test, serviceX5C)
		require.NoError(t, test.Certificates.Refresh(ctx))
		assert.Equal(t, serviceCertificate, test.certificates[keyID].Certificate)
	})

	t.Run("cert pool error", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		test := newCertificatesTest(t, ctrl, caPrivateKey)

		primeRootPool(t, test)

		myErr := errors.New("myerr")
		expectTransportError(test, myErr)

		assert.Equal(t, myErr, test.Certificates.Refresh(ctx))
	})

	t.Run("load certs error", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		test := newCertificatesTest(t, ctrl, caPrivateKey)

		primeRootPool(t, test)

		serviceCertificate, serviceX5C := test.ServiceCertificate(t, servicePrivateKey, caPrivateKey, serviceDomain, serviceName, serviceStage)
		expectKeySet(t, test, serviceX5C)
		require.NoError(t, test.WhitelistService(ctx, serviceName, serviceStage))
		assert.Equal(t, serviceCertificate, test.certificates[keyID].Certificate)

		// Root pool reload succeeds; the per-service fetch then fails.
		expectKeySet(t, test, [][]byte{test.AuthorityCertificate.Raw})
		myErr := errors.New("myerr")
		expectTransportError(test, myErr)

		assert.Equal(t, myErr, test.Certificates.Refresh(ctx))
	})
}

// certificatesTest bundles a Certificates under test with the mock HTTP
// client it was constructed with. Both are embedded, so their methods and
// fields (e.g. test.Get, test.certificates) are reachable directly on the
// test value.
type certificatesTest struct {
	*Certificates
	*mocks.MockHTTPClient

	// AuthorityCertificate is the parsed self-signed CA certificate that
	// newCertificatesTest installs as the only root in the Certificates'
	// root pool.
	AuthorityCertificate *x509.Certificate
}

// ServiceCertificate issues a leaf certificate for key. Both its DNS name and
// common name are "<stage>.<service>.<domain>", and its issuer is
// test.AuthorityCertificate. The certificate is signed with authorityKey.
// Passing a key other than the CA's private key yields a certificate that
// names the CA as issuer without being signed by it.
//
// It returns the parsed certificate and its DER encoding wrapped as a
// one-element chain, suitable for an "x5c" field.
func (test *certificatesTest) ServiceCertificate(
	t *testing.T,
	key *ecdsa.PrivateKey,
	authorityKey *ecdsa.PrivateKey,
	domain, service, stage string,
) (*x509.Certificate, [][]byte) {
	fqdn := strings.Join([]string{stage, service, domain}, ".")

	template := &x509.Certificate{
		SerialNumber: big.NewInt(2),
		Subject:      pkix.Name{CommonName: fqdn},
		DNSNames:     []string{fqdn},
		SubjectKeyId: []byte{2},
		KeyUsage:     x509.KeyUsageDigitalSignature,
		NotBefore:    time.Now(),
		NotAfter:     time.Now().Add(365 * time.Hour),
	}

	der, err := x509.CreateCertificate(rand.New(rand.NewSource(182)), template, test.AuthorityCertificate, key.Public(), authorityKey)
	require.NoError(t, err)

	parsed, err := x509.ParseCertificate(der)
	require.NoError(t, err)

	return parsed, [][]byte{der}
}

// newCertificatesTest builds a Certificates wired to a fresh MockHTTPClient.
// It creates a self-signed CA certificate for authorityKey and installs it as
// the only root in the Certificates' root pool. It returns the Certificates,
// the mock, and the parsed CA so subtests can issue leaf certificates and
// queue HTTP expectations.
func newCertificatesTest(
	t *testing.T,
	ctrl *gomock.Controller,
	authorityKey *ecdsa.PrivateKey,
) *certificatesTest {
	// One template serves as both the certificate and its parent. This is how
	// crypto/x509 expresses a self-signed certificate.
	caTemplate := &x509.Certificate{
		DNSNames:     []string{"authority"},
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		SerialNumber: big.NewInt(1),
		IsCA:         true,
		NotBefore:    time.Now(),
		NotAfter:     time.Now().Add(365 * time.Hour),
		Subject: pkix.Name{
			CommonName: "a.b.c",
		},
		SubjectKeyId:          []byte{1},
		BasicConstraintsValid: true,
	}

	caDER, err := x509.CreateCertificate(
		rand.New(rand.NewSource(420)),
		caTemplate,
		caTemplate, // self-signed
		authorityKey.Public(),
		authorityKey,
	)
	require.NoError(t, err)

	caCertParsed, err := x509.ParseCertificate(caDER)
	require.NoError(t, err)

	rootCertificates := x509.NewCertPool()
	rootCertificates.AddCert(caCertParsed)

	mockHTTPClient := mocks.NewMockHTTPClient(ctrl)

	// Empty base URL and 1ms durations. NOTE(review): the meaning of the two
	// duration arguments is defined by New, which is not visible here.
	certs := New(mockHTTPClient, url.URL{}, "", &operation.Starter{}, logutil.NoopLogger, time.Millisecond, time.Millisecond)
	// Install the test CA as the sole trusted root directly, without going
	// through LoadRootCertificatePool.
	certs.rootCertificates = rootCertificates

	return &certificatesTest{
		Certificates:         certs,
		MockHTTPClient:       mockHTTPClient,
		AuthorityCertificate: caCertParsed,
	}
}
