package certgen

import (
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"strconv"
	"time"

	"code.justin.tv/event-engineering/acm-ca-go/pkg/aws"
	awsBackend "code.justin.tv/event-engineering/acm-ca-go/pkg/aws/backend"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/acm"
	"github.com/aws/aws-sdk-go/service/acmpca"
	"github.com/youmark/pkcs8"
	"log"
)

// certgen is the ACM-backed implementation of Generator. Existing
// certificates are looked up through ACM, and new ones are requested from
// the ACM Private CA identified by privateCAArn.
type certgen struct {
	// privateCAArn is the ARN of the ACM Private CA that new certificates
	// are requested from.
	privateCAArn string

	// awsClient performs the ACM and ACM PCA API calls made by this type.
	awsClient *aws.Client
}

// Generator hands out TLS certificates issued by a private certificate
// authority and exposes that authority's own certificate(s).
type Generator interface {
	// GetRootCACert returns the certificate authority's certificate,
	// followed by any chain certificates the authority reports.
	GetRootCACert() ([]*x509.Certificate, error)

	// CanHazCert returns a TLS certificate for subject that is meant to
	// also cover every name in subjectAlternateNames.
	CanHazCert(subject string, subjectAlternateNames []string) (tls.Certificate, error)
}

// New returns a Generator that makes its ACM and ACM PCA calls through sess
// and requests new certificates from the ACM Private CA identified by
// privateCAArn.
func New(sess *session.Session, privateCAArn string) Generator {
	// Named "client" rather than reusing "awsBackend" so the imported
	// awsBackend package is not shadowed.
	client := aws.New(awsBackend.New(sess))

	return &certgen{
		awsClient:    client,
		privateCAArn: privateCAArn,
	}
}

// GetRootCACert fetches the private CA's certificate from ACM PCA and returns
// it, followed by every certificate in the CA's chain when ACM PCA returns one
// (no chain is returned for a root CA).
//
// An error is returned if the API call fails, if a returned field does not
// contain PEM data, or if any certificate fails to parse.
func (cg *certgen) GetRootCACert() ([]*x509.Certificate, error) {
	result, err := cg.awsClient.ACMPCAGetCertificateAuthorityCertificate(&acmpca.GetCertificateAuthorityCertificateInput{
		CertificateAuthorityArn: &cg.privateCAArn,
	})
	if err != nil {
		return nil, err
	}

	if result.Certificate == nil {
		return nil, errors.New("CA certificate response contained no certificate")
	}

	// pem.Decode returns a nil block for non-PEM input; check before using it.
	certBlock, _ := pem.Decode([]byte(*result.Certificate))
	if certBlock == nil {
		return nil, errors.New("CA certificate is not PEM encoded")
	}

	cert, err := x509.ParseCertificate(certBlock.Bytes)
	if err != nil {
		log.Println("Failed to parse certificate block")
		return nil, err
	}

	certs := []*x509.Certificate{cert}

	if result.CertificateChain == nil {
		return certs, nil
	}

	// The chain is a concatenation of PEM blocks (one per CA certificate), so
	// walk all of them instead of stopping after the first.
	chainPEM := []byte(*result.CertificateChain)
	rest := chainPEM
	for {
		var chainBlock *pem.Block
		chainBlock, rest = pem.Decode(rest)
		if chainBlock == nil {
			break
		}

		chainCert, err := x509.ParseCertificate(chainBlock.Bytes)
		if err != nil {
			log.Println("Failed to parse certificate chain")
			return nil, err
		}
		certs = append(certs, chainCert)
	}

	// A non-empty chain that yielded no PEM blocks is malformed; don't
	// silently ignore it.
	if len(certs) == 1 && len(chainPEM) > 0 {
		return nil, errors.New("CA certificate chain is not PEM encoded")
	}

	return certs, nil
}

// CanHazCert returns a TLS certificate for subject that also covers every name
// in subjectAlternateNames.
//
// It pages through all ISSUED certificates in ACM. For each one whose domain
// name equals subject, it tries to export it. A candidate is skipped if the
// export fails (for example when it is too close to expiry) or if its leaf
// does not list every requested SAN. If no existing certificate is usable, a
// new one is requested from the private CA.
func (cg *certgen) CanHazCert(subject string, subjectAlternateNames []string) (tls.Certificate, error) {
	statusFilter := acm.CertificateStatusIssued
	var maxItems int64 = 20
	var nextToken *string

	var results []*acm.CertificateSummary

	for {
		page, err := cg.awsClient.ACMListCertificates(&acm.ListCertificatesInput{
			CertificateStatuses: []*string{&statusFilter},
			MaxItems:            &maxItems,
			NextToken:           nextToken,
		})
		if err != nil {
			log.Println(err)
			return tls.Certificate{}, err
		}

		results = append(results, page.CertificateSummaryList...)

		// Treat an empty token like a missing one; re-sending "" would never
		// make progress.
		if page.NextToken == nil || *page.NextToken == "" {
			break
		}
		nextToken = page.NextToken
	}

	for _, summary := range results {
		// Summaries come from an external API; skip incomplete entries instead
		// of dereferencing nil pointers.
		if summary == nil || summary.DomainName == nil || summary.CertificateArn == nil {
			continue
		}
		if *summary.DomainName != subject {
			continue
		}

		log.Println("Attempting to export cert ", *summary.CertificateArn)
		tlsCert, err := cg.exportCertificate(*summary.CertificateArn)
		if err != nil {
			log.Println(err)
			continue
		}

		// Matching on the domain name alone could hand back a certificate
		// that lacks SANs the caller asked for, so only reuse it if it
		// covers all of them.
		if !certCoversNames(tlsCert, subjectAlternateNames) {
			log.Println("Cert", *summary.CertificateArn, "does not cover all requested subject alternative names")
			continue
		}

		return tlsCert, nil
	}

	// If we're here then we didn't find a certificate we can use, so lets generate one
	return cg.requestCertificate(subject, subjectAlternateNames)
}

// certCoversNames reports whether the leaf (first) certificate in cert lists
// every entry of names among its DNS subject alternative names. Names are
// compared exactly, the same way CanHazCert compares the subject.
func certCoversNames(cert tls.Certificate, names []string) bool {
	if len(names) == 0 {
		return true
	}
	if len(cert.Certificate) == 0 {
		return false
	}

	leaf, err := x509.ParseCertificate(cert.Certificate[0])
	if err != nil {
		return false
	}

	have := make(map[string]struct{}, len(leaf.DNSNames))
	for _, name := range leaf.DNSNames {
		have[name] = struct{}{}
	}
	for _, name := range names {
		if _, ok := have[name]; !ok {
			return false
		}
	}
	return true
}

// requestCertificate asks ACM to issue a new certificate for subject (plus any
// subjectAlternateNames) from the private CA. It then polls DescribeCertificate
// once per second until one of the following happens:
//   - the certificate is ISSUED, and it is exported and returned;
//   - the certificate is reported FAILED;
//   - 15 seconds elapse.
func (cg *certgen) requestCertificate(subject string, subjectAlternateNames []string) (tls.Certificate, error) {
	const (
		pollInterval = 1 * time.Second
		pollTimeout  = 15 * time.Second
	)

	token := requestIdempotencyToken(subject, subjectAlternateNames, time.Now())
	transparency := acm.CertificateTransparencyLoggingPreferenceDisabled

	reqCertInput := &acm.RequestCertificateInput{
		CertificateAuthorityArn: &cg.privateCAArn,
		IdempotencyToken:        &token,
		DomainName:              &subject,
		Options: &acm.CertificateOptions{
			CertificateTransparencyLoggingPreference: &transparency,
		},
	}

	if len(subjectAlternateNames) > 0 {
		sans := make([]*string, 0, len(subjectAlternateNames))
		for _, san := range subjectAlternateNames {
			// Copy before taking the address. Before Go 1.22 the range
			// variable is shared across iterations, so &san would make every
			// entry point at the last SAN.
			san := san
			sans = append(sans, &san)
		}
		reqCertInput.SubjectAlternativeNames = sans
	}

	certReq, err := cg.awsClient.ACMRequestCertificate(reqCertInput)
	if err != nil {
		log.Println(err)
		return tls.Certificate{}, err
	}
	if certReq.CertificateArn == nil {
		return tls.Certificate{}, errors.New("Certificate request returned no certificate ARN")
	}

	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()
	timeout := time.NewTimer(pollTimeout)
	defer timeout.Stop()

	for {
		select {
		case <-ticker.C:
			described, err := cg.awsClient.ACMDescribeCertificate(&acm.DescribeCertificateInput{
				CertificateArn: certReq.CertificateArn,
			})
			if err != nil {
				// Treated as transient: keep polling until the timeout fires.
				log.Println(err)
				continue
			}

			detail := described.Certificate
			if detail != nil && detail.Status != nil && *detail.Status == acm.CertificateStatusFailed {
				msg := "Certificate creation failed"
				if detail.FailureReason != nil {
					msg += ": " + *detail.FailureReason
				}
				return tls.Certificate{}, errors.New(msg)
			}

			if detail == nil || detail.Status == nil || *detail.Status != acm.CertificateStatusIssued {
				log.Println("Cert not yet issued")
				continue
			}

			tlsCert, err := cg.exportCertificate(*certReq.CertificateArn)
			if err != nil {
				log.Println(err)
				return tls.Certificate{}, err
			}
			return tlsCert, nil

		case <-timeout.C:
			log.Println("Certificate creation timed out")
			return tls.Certificate{}, errors.New("Certificate creation timed out")
		}
	}
}

// requestIdempotencyToken builds the ACM idempotency token for a certificate
// request. ACM treats calls that reuse a token within an hour as one request.
// The token therefore combines the request time (second granularity, as
// before) with a hash of the requested names. Identical requests in the same
// second still collapse into one certificate, while requests for different
// names no longer collide. The result contains only [0-9a-z_] and is at most
// 25 characters long.
func requestIdempotencyToken(subject string, subjectAlternateNames []string, now time.Time) string {
	// 64-bit FNV-1a parameters.
	const (
		fnvOffset64 uint64 = 14695981039346656037
		fnvPrime64  uint64 = 1099511628211
	)

	h := fnvOffset64
	mixByte := func(b byte) {
		h ^= uint64(b)
		h *= fnvPrime64
	}
	mixString := func(s string) {
		for i := 0; i < len(s); i++ {
			mixByte(s[i])
		}
		// NUL-terminate each name so ("ab", ["c"]) and ("a", ["bc"]) differ.
		mixByte(0)
	}

	mixString(subject)
	for _, san := range subjectAlternateNames {
		mixString(san)
	}

	return strconv.FormatInt(now.Unix(), 10) + "_" + strconv.FormatUint(h, 36)
}

// exportCertificate exports the ACM certificate identified by arn. It returns
// a tls.Certificate containing:
//   - the leaf certificate;
//   - every certificate from the exported chain;
//   - the decrypted private key;
//   - Leaf set to the parsed leaf certificate.
//
// It returns an error in any of these cases:
//   - the export fails;
//   - a PEM field is missing or malformed;
//   - the leaf expires within expiryGracePeriod.
//
// CanHazCert treats any error as "not reusable" and moves on, so a
// near-expiry certificate leads to a fresh one being issued in good time.
func (cg *certgen) exportCertificate(arn string) (tls.Certificate, error) {
	// Give us some grace time (30 days) to generate a new cert.
	const expiryGracePeriod = 30 * 24 * time.Hour

	// We're going to decrypt this before returning it, but ACM won't export it without a passphrase
	passphrase := []byte("Would you like a carrot?")

	export, err := cg.awsClient.ACMExportCertificate(&acm.ExportCertificateInput{
		CertificateArn: &arn,
		Passphrase:     passphrase,
	})
	if err != nil {
		return tls.Certificate{}, err
	}
	if export.Certificate == nil || export.PrivateKey == nil {
		return tls.Certificate{}, errors.New("Exported certificate is missing its certificate or private key")
	}

	// pem.Decode returns a nil block for non-PEM input; check before using it.
	certBlock, _ := pem.Decode([]byte(*export.Certificate))
	if certBlock == nil {
		return tls.Certificate{}, errors.New("Exported certificate is not PEM encoded")
	}

	cert, err := x509.ParseCertificate(certBlock.Bytes)
	if err != nil {
		log.Println("Failed to parse certificate block")
		return tls.Certificate{}, err
	}

	if time.Now().Add(expiryGracePeriod).After(cert.NotAfter) {
		return tls.Certificate{}, errors.New("Cert was too close to expiry")
	}

	// Leaf first, then every certificate from the exported chain, in order.
	der := [][]byte{cert.Raw}
	if export.CertificateChain != nil {
		chainPEM := []byte(*export.CertificateChain)
		rest := chainPEM
		for {
			var chainBlock *pem.Block
			chainBlock, rest = pem.Decode(rest)
			if chainBlock == nil {
				break
			}

			chainCert, err := x509.ParseCertificate(chainBlock.Bytes)
			if err != nil {
				log.Println("Failed to parse certificate chain")
				return tls.Certificate{}, err
			}
			der = append(der, chainCert.Raw)
		}

		// A non-empty chain that yielded no PEM blocks is malformed.
		if len(der) == 1 && len(chainPEM) > 0 {
			return tls.Certificate{}, errors.New("Exported certificate chain is not PEM encoded")
		}
	}

	privKeyBlock, _ := pem.Decode([]byte(*export.PrivateKey))
	if privKeyBlock == nil {
		return tls.Certificate{}, errors.New("Exported private key is not PEM encoded")
	}

	privateKey, err := pkcs8.ParsePKCS8PrivateKey(privKeyBlock.Bytes, passphrase)
	if err != nil {
		log.Println("Failed to decrypt private key")
		return tls.Certificate{}, err
	}

	return tls.Certificate{
		Certificate: der,
		PrivateKey:  privateKey,
		// Already parsed above; saves the TLS stack from re-parsing it.
		Leaf: cert,
	}, nil
}
