package tls

import (
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"io/ioutil"
	"math/rand"
	"path/filepath"
	"sync"
	"time"

	"GoLog/log"
	"aaa/internal/env"

	"github.com/pkg/errors"
)

// ioutilReadFile is the function this file uses to read certificate and key
// files from disk. It is a package-level variable (initialised to
// ioutil.ReadFile) rather than a direct call so that it can be substituted
// for testing.
var ioutilReadFile = ioutil.ReadFile

// NewConfig builds the *tls.Config used for AAA-TLS connections.
//
// The root directory is obtained from env.GetRoot and the CA certificate is
// loaded from beneath it; the resulting pool is installed as both RootCAs and
// ClientCAs. ClientAuth is set to tls.RequireAndVerifyClientCert. The local
// certificate is not stored in the config directly: GetCertificate and
// GetClientCertificate are wired to an expiringCertPool, which loads the
// certificate from disk on demand.
//
// An error is returned when the root cannot be determined or the CA
// certificate cannot be loaded. A failure to preload the local certificate is
// not fatal (see below).
func NewConfig() (*tls.Config, error) {
	root, err := env.GetRoot()
	if err != nil {
		return nil, errors.WithMessage(err, "create AAA-TLS config")
	}
	log.Debug("Using root", root)

	certPool, err := loadCaCert(root)
	if err != nil {
		return nil, errors.WithMessage(err, "create AAA-TLS config")
	}
	if certPool == nil {
		// errors.WithMessage returns nil when given a nil error, so a nil pool
		// without an accompanying error must be reported explicitly; otherwise
		// the caller would receive (nil, nil).
		return nil, errors.New("create AAA-TLS config: CA certificate pool is nil")
	}
	log.Debug("Loaded CA Cert")

	cp := &expiringCertPool{
		certRoot: root,
	}
	// Preload the certificate so the first handshake does not have to read it
	// from disk. This is deliberately best-effort: getValidCertificate logs
	// its own failure, and the certificate is loaded again on demand by the
	// callbacks installed below.
	_, _ = cp.getValidCertificate()

	config := tls.Config{
		GetClientCertificate: cp.GetClientCertificate,
		GetCertificate:       cp.GetCertificate,
		ClientAuth:           tls.RequireAndVerifyClientCert,
		RootCAs:              certPool,
		ClientCAs:            certPool,
	}

	return &config, nil
}

// loadCaCert reads <root>/var/state/aaatls/aaaca.pem and returns a new
// x509.CertPool containing the certificates parsed from it.
//
// It returns an error if the file cannot be read, or if no certificate from
// the file could be added to the pool.
func loadCaCert(root string) (*x509.CertPool, error) {
	caPath := filepath.Join(root, "var", "state", "aaatls", "aaaca.pem")
	log.Debug("Reading certificate at", caPath)

	pemBytes, readErr := ioutilReadFile(caPath)
	if readErr != nil {
		return nil, errors.Wrapf(readErr, "load CA certificate %v", caPath)
	}

	pool := x509.NewCertPool()
	if !pool.AppendCertsFromPEM(pemBytes) {
		return nil, errors.Errorf("load CA certificate %v: cannot add certificate to pool", caPath)
	}
	return pool, nil
}

// expiringCert pairs a loaded TLS certificate with the time after which this
// in-memory copy is no longer handed out and must be reloaded.
type expiringCert struct {
	// tlsCert is the certificate returned to the TLS stack.
	tlsCert *tls.Certificate
	// expirationTime is the cache deadline assigned when the certificate was
	// loaded (see newExpiringCert); it is compared against time.Now, not
	// against the certificate's own validity period.
	expirationTime time.Time
}
// expiringCertPool is a stack of cached certificates loaded from beneath
// certRoot. The embedded mutex guards pool.
type expiringCertPool struct {
	sync.Mutex
	// certRoot is the root directory under which var/state/aaatls is located.
	certRoot string
	// pool holds the cached certificates; get pops from the end and put
	// appends to the end.
	pool []expiringCert
}

// loadCert reads the local key pair from <certRoot>/var/state/aaatls and
// returns it as a tls.Certificate whose chain additionally carries the first
// PEM block of aaaca.pem.
//
// The combined file aaaic_and_key.pem is tried first and, when readable, is
// used as both the certificate and the key input. If it cannot be read, the
// certificate is taken from aaaic.pem and the private key from the file whose
// path is stored in key.uuid.
//
// Any failure is reported with the message "unable to load key pair".
func (cp *expiringCertPool) loadCert() (*tls.Certificate, error) {
	const wrapMsg = "unable to load key pair"

	stateDir := filepath.Join(cp.certRoot, "var", "state", "aaatls")

	combinedPath := filepath.Join(stateDir, "aaaic_and_key.pem")
	log.Debug("Reading key pair at", combinedPath)
	combined, readErr := ioutilReadFile(combinedPath)

	// By default the combined file supplies both halves of the pair.
	certPem, keyPem := combined, combined
	if readErr != nil {
		// AAAWorkspaceSupport does not currently provide aaaic_and_key.pem, so
		// fall back to assembling the pair from aaaic.pem and the key file
		// referenced by key.uuid.
		log.Info("Attempting to manually create key pair for AAAWorkspaceSupport:", combinedPath, readErr)

		certPath := filepath.Join(stateDir, "aaaic.pem")
		log.Debug("Reading cert at", certPath)
		if certPem, readErr = ioutilReadFile(certPath); readErr != nil {
			return nil, errors.Wrap(readErr, wrapMsg)
		}

		keyRefPath := filepath.Join(stateDir, "key.uuid")
		log.Debug("Using key file at", keyRefPath)
		keyRef, err := ioutilReadFile(keyRefPath)
		if err != nil {
			return nil, errors.Wrap(err, wrapMsg)
		}

		// NOTE(review): the contents of key.uuid are used verbatim as a path;
		// a trailing newline in that file would become part of the path —
		// confirm the file format.
		keyPath := string(keyRef)
		log.Debug("Reading private key at", keyPath)
		if keyPem, err = ioutilReadFile(keyPath); err != nil {
			return nil, errors.Wrap(err, wrapMsg)
		}
	}

	pair, err := tls.X509KeyPair(certPem, keyPem)
	if err != nil {
		return nil, errors.Wrap(err, wrapMsg)
	}

	caPath := filepath.Join(stateDir, "aaaca.pem")
	log.Debug("Reading certificate at", caPath)
	caPem, err := ioutilReadFile(caPath)
	if err != nil {
		return nil, errors.Wrap(err, wrapMsg)
	}

	// Only the first PEM block of the CA file is appended to the chain.
	caBlock, _ := pem.Decode(caPem)
	if caBlock == nil {
		return nil, errors.New("unable to load key pair: failed to parse certificate PEM")
	}
	pair.Certificate = append(pair.Certificate, caBlock.Bytes)

	return &pair, nil
}

// newExpiringCert loads the certificate from disk via loadCert and stamps it
// with a cache deadline. It returns loadCert's error unchanged on failure.
func (cp *expiringCertPool) newExpiringCert() (*expiringCert, error) {
	loaded, err := cp.loadCert()
	if err != nil {
		return nil, err
	}

	// Subtract a random 0-29 minutes from 24h so the deadline falls between
	// 23h31m and 24h from now and cached copies do not all expire together.
	jitter := time.Duration(rand.Intn(30)) * time.Minute
	return &expiringCert{
		tlsCert:        loaded,
		expirationTime: time.Now().Add(24*time.Hour - jitter),
	}, nil
}

// get removes and returns the most recently added certificate from the pool.
// When the pool is empty it releases the lock and loads a fresh certificate
// via newExpiringCert instead, so disk I/O never happens while the lock is
// held. The returned certificate may already be past its expirationTime;
// callers are expected to check.
func (cp *expiringCertPool) get() (*expiringCert, error) {
	cp.Lock()
	n := len(cp.pool)
	if n == 0 {
		cp.Unlock()
		return cp.newExpiringCert()
	}

	// Pop the top of the stack.
	top := cp.pool[n-1]
	cp.pool = cp.pool[:n-1]
	cp.Unlock()

	return &top, nil
}

// put appends a copy of *cert to the end of the pool while holding the lock.
func (cp *expiringCertPool) put(cert *expiringCert) {
	cp.Lock()
	defer cp.Unlock()

	cp.pool = append(cp.pool, *cert)
}

// GetCertificate has the signature of the tls.Config.GetCertificate callback.
// The ClientHelloInfo is ignored; the result is whatever
// getValidCertificate returns.
func (cp *expiringCertPool) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	log.Debug("Getting certificate")
	cert, err := cp.getValidCertificate()
	return cert, err
}

// GetClientCertificate has the signature of the
// tls.Config.GetClientCertificate callback. The CertificateRequestInfo is
// ignored; the result is whatever getValidCertificate returns.
func (cp *expiringCertPool) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
	log.Debug("Getting client certificate")
	cert, err := cp.getValidCertificate()
	return cert, err
}

// getValidCertificate returns a certificate whose cache deadline has not yet
// passed, and puts it back into the pool for reuse.
//
// Certificates taken from the pool that have passed their expirationTime are
// discarded and another is fetched. Up to three failures from get are
// tolerated; after the third, the last error is logged and returned.
func (cp *expiringCertPool) getValidCertificate() (*tls.Certificate, error) {
	const maxFailures = 3

	var lastErr error
	for failures := 0; failures < maxFailures; {
		candidate, err := cp.get()
		if err != nil {
			// Only the most recent error is reported.
			lastErr = err
			failures++
			continue
		}

		if time.Now().Before(candidate.expirationTime) {
			// Still valid: return it to the pool so later callers can reuse it.
			cp.put(candidate)
			return candidate.tlsCert, nil
		}
		// Expired: drop it (it has already been removed from the pool) and
		// try again without counting this as a failure.
	}

	// Logging here because the error is (usually) returned to the standard
	// library, which does not surface it.
	log.Error(errors.WithMessage(lastErr, "tried 3 times to get valid cert"))

	return nil, lastErr
}
