package manager

import (
	"context"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"errors"
	"fmt"
	"time"

	"code.justin.tv/systems/sandstorm/internal/envelope"
	"code.justin.tv/systems/sandstorm/internal/stat/statiface"
	"code.justin.tv/systems/sandstorm/logging"
	"code.justin.tv/systems/sandstorm/resource"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/service/cloudwatch"
	"github.com/aws/aws-sdk-go/service/kms"
)

// An Envelope encrypts and decrypts secrets with single-use KMS data keys using
// AES-256-GCM. Every Seal call asks KMS for a new data key.
//
// Fields:
//   - clientLoader supplies the KMS client for a region (clientLoader.KMS(region)).
//   - primaryRegion partitions kmsKeys: keys whose Region equals it are
//     "primary" (used to generate data keys); all others are "secondary".
//   - kmsKeys is every configured KMS key across all regions.
//   - Logger receives warnings, e.g. when a primary key fails to generate a
//     data key during Seal.
//   - statter wraps each KMS call via WithMeasuredResult, tagged with a
//     "KMSKeyARN" CloudWatch dimension.
type Envelope struct {
	clientLoader  *clientLoader
	primaryRegion string
	kmsKeys       []resource.KMSKey
	Logger        logging.Logger
	statter       statiface.API
}

// getPrimaryKMSKeys returns the configured KMS keys located in the primary
// region, preserving their configured order. It returns a nil slice and an
// error when no key is configured for the primary region.
func (e *Envelope) getPrimaryKMSKeys() ([]resource.KMSKey, error) {
	var primary []resource.KMSKey
	for i := range e.kmsKeys {
		if e.kmsKeys[i].Region != e.primaryRegion {
			continue
		}
		primary = append(primary, e.kmsKeys[i])
	}

	if len(primary) == 0 {
		return nil, fmt.Errorf("kms key for %s not configured", e.primaryRegion)
	}
	return primary, nil
}

// getSecondaryKMSKeys returns every configured KMS key outside the primary
// region, preserving configured order. The result is nil when there are none.
func (e *Envelope) getSecondaryKMSKeys() []resource.KMSKey {
	var secondary []resource.KMSKey
	for i := range e.kmsKeys {
		if key := e.kmsKeys[i]; key.Region != e.primaryRegion {
			secondary = append(secondary, key)
		}
	}
	return secondary
}

// Seal generates a fresh 256-bit data key with KMS and encrypts plaintext with
// it using AES-256-GCM and a random nonce. The returned ciphertext is the GCM
// nonce followed by the sealed data.
//
// The data key comes from the first primary-region KMS key, in configured
// order, whose GenerateDataKey call succeeds. That key's configured ARN is
// bound to the ciphertext as additional authenticated data.
//
// encryptedDataKeys starts with the KMS-encrypted data key from the primary
// region. It is followed by the same data key encrypted under each primary-region
// key that comes after the generating one and under every secondary-region
// key, in completion order. These extra encryptions run concurrently under a
// shared 5-second timeout. If any of them fails, the first failure received is
// returned as err, together with whatever data keys and ciphertext were
// produced.
func (e *Envelope) Seal(ctxt map[string]string, plaintext []byte) (encryptedDataKeys []envelope.EncryptedDataKey, ciphertext []byte, err error) {
	primaryKMSKeys, err := e.getPrimaryKMSKeys()
	if err != nil {
		return
	}

	// keyDimensions builds the CloudWatch dimensions attached to each measured
	// KMS call.
	keyDimensions := func(keyARN string) []*cloudwatch.Dimension {
		return []*cloudwatch.Dimension{{
			Name:  aws.String("KMSKeyARN"),
			Value: aws.String(keyARN),
		}}
	}

	var plaintextDataKey []byte
	// Wipe the plaintext data key on every return path. The closure matters: a
	// bare `defer Zero(plaintextDataKey)` evaluates its argument immediately,
	// captures the still-nil slice, and never wipes the key.
	defer func() { Zero(plaintextDataKey) }()

	var (
		primaryKMSKey resource.KMSKey
		generated     bool
		generateErr   error
	)
	for len(primaryKMSKeys) > 0 {
		primaryKMSKey, primaryKMSKeys = primaryKMSKeys[0], primaryKMSKeys[1:]

		var dataKeyOutput *kms.GenerateDataKeyOutput
		generateErr = e.statter.WithMeasuredResult("GenerateDataKey", keyDimensions(primaryKMSKey.KeyARN), func() error {
			var callErr error
			dataKeyOutput, callErr = e.clientLoader.KMS(primaryKMSKey.Region).GenerateDataKey(&kms.GenerateDataKeyInput{
				EncryptionContext: e.context(ctxt),
				KeySpec:           aws.String("AES_256"),
				KeyId:             aws.String(primaryKMSKey.KeyARN),
			})
			return callErr
		})
		if generateErr != nil {
			e.Logger.Warnf("could not generate data key using, %s. err: %s", primaryKMSKey.KeyARN, generateErr.Error())
			continue
		}

		plaintextDataKey = dataKeyOutput.Plaintext
		encryptedDataKeys = append(encryptedDataKeys, envelope.EncryptedDataKey{
			Value:  dataKeyOutput.CiphertextBlob,
			Region: primaryKMSKey.Region,
			KeyARN: aws.StringValue(dataKeyOutput.KeyId),
		})
		generated = true
		break
	}

	// Success is tracked explicitly. primaryKMSKey still holds the last key
	// tried even when every attempt failed, so its ARN cannot signal failure.
	if !generated {
		err = fmt.Errorf("failed to generate data key in primary region: %v", generateErr)
		return
	}

	ciphertext, err = encrypt(plaintextDataKey, plaintext, []byte(primaryKMSKey.KeyARN))
	if err != nil {
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	type encryptResult struct {
		Output *kms.EncryptOutput
		Region string
		Err    error
	}

	// Wrap the same data key under the primary-region keys after the one used
	// above, plus every secondary-region key.
	otherKeys := append(primaryKMSKeys, e.getSecondaryKMSKeys()...)
	// Buffered so every sender can finish without waiting on the receiver.
	results := make(chan encryptResult, len(otherKeys))

	for _, key := range otherKeys {
		key := key // per-goroutine copy; loop variables are shared before Go 1.22
		go func() {
			var out *kms.EncryptOutput
			// encryptErr must stay local to this goroutine. Writing the shared
			// named result err from every goroutine is a data race. It can also
			// pair a nil error with a nil Output, which the receiver would
			// then dereference.
			encryptErr := e.statter.WithMeasuredResult("Encrypt", keyDimensions(key.KeyARN), func() error {
				var callErr error
				out, callErr = e.clientLoader.KMS(key.Region).EncryptWithContext(ctx, &kms.EncryptInput{
					EncryptionContext: e.context(ctxt),
					KeyId:             aws.String(key.KeyARN),
					Plaintext:         plaintextDataKey,
				})
				return callErr
			})
			results <- encryptResult{Output: out, Region: key.Region, Err: encryptErr}
		}()
	}

	// Receive exactly one result per goroutine. After this loop no goroutine
	// still reads plaintextDataKey, so the deferred Zero is safe.
	var firstEncryptErr error
	for range otherKeys {
		res := <-results
		if res.Err != nil {
			if firstEncryptErr == nil {
				firstEncryptErr = res.Err
			}
			continue
		}
		encryptedDataKeys = append(encryptedDataKeys, envelope.EncryptedDataKey{
			Region: res.Region,
			KeyARN: aws.StringValue(res.Output.KeyId),
			Value:  res.Output.CiphertextBlob,
		})
	}

	err = firstEncryptErr
	return
}

// Open takes the output of Seal and decrypts it. If any part of the ciphertext
// or context is modified, Open returns an error instead of the decrypted data.
//
// The encrypted data keys are tried in order. Each one is decrypted by KMS in
// its own region, using ctxt as the encryption context, and the resulting data
// key is used to open ciphertext. Seal binds the ARN of whichever
// primary-region key generated the data key as additional authenticated data,
// so every configured primary-region key ARN is tried, starting with the
// first. The first successful decryption is returned. Otherwise err is the
// last error seen, except that a KMS InvalidCiphertextException is replaced
// with a generic message.
func (e *Envelope) Open(encryptedDataKeys []envelope.EncryptedDataKey, ciphertext []byte, ctxt map[string]string) (plaintext []byte, err error) {
	// plaintext is the caller's result and is intentionally not zeroed here.
	// Only the intermediate decrypted data keys are wiped below.
	if len(encryptedDataKeys) == 0 {
		err = errors.New("no data keys were provided")
		return
	}

	primaryKMSKeys, err := e.getPrimaryKMSKeys()
	if err != nil {
		return
	}

	for _, dataKey := range encryptedDataKeys {
		var decryptOutput *kms.DecryptOutput

		err = e.statter.WithMeasuredResult("Decrypt", []*cloudwatch.Dimension{{
			Name:  aws.String("KMSKeyARN"),
			Value: aws.String(dataKey.KeyARN),
		}}, func() (decryptErr error) {
			decryptOutput, decryptErr = e.clientLoader.KMS(dataKey.Region).Decrypt(&kms.DecryptInput{
				CiphertextBlob:    dataKey.Value,
				EncryptionContext: e.context(ctxt),
			})
			return
		})
		if err != nil {
			continue
		}

		dataKeyPlaintext := decryptOutput.Plaintext
		// primaryKMSKeys is non-empty (getPrimaryKMSKeys errors otherwise), so
		// err always reflects at least one decrypt attempt.
		for _, primaryKMSKey := range primaryKMSKeys {
			plaintext, err = decrypt(dataKeyPlaintext, ciphertext, []byte(primaryKMSKey.KeyARN))
			if err == nil {
				break
			}
		}
		// The plaintext data key is no longer needed, whether or not it worked.
		Zero(dataKeyPlaintext)
		if err == nil {
			return
		}
	}

	if apiErr, ok := err.(awserr.Error); ok && apiErr.Code() == "InvalidCiphertextException" {
		err = errors.New("unable to decrypt data key")
	}
	return
}

// context converts a plain string map into the map[string]*string shape the
// AWS SDK expects for a KMS encryption context. It always returns a non-nil
// map, even for nil input.
func (e *Envelope) context(c map[string]string) map[string]*string {
	converted := make(map[string]*string, len(c))
	for key, value := range c {
		converted[key] = aws.String(value)
	}
	return converted
}

// encrypt seals plaintext under key with AES-GCM, authenticating data as
// additional data. The result is a freshly generated random nonce followed by
// the GCM output (ciphertext plus authentication tag). key must be a valid AES
// key length, otherwise aes.NewCipher's error is returned.
func encrypt(key, plaintext, data []byte) ([]byte, error) {
	var (
		blk  cipher.Block
		aead cipher.AEAD
		err  error
	)
	if blk, err = aes.NewCipher(key); err != nil {
		return nil, err
	}
	if aead, err = cipher.NewGCM(blk); err != nil {
		return nil, err
	}

	nonce := make([]byte, aead.NonceSize())
	if _, err = rand.Read(nonce); err != nil {
		return nil, err
	}

	// Passing nonce as dst makes Seal append its output after the nonce, so
	// the nonce travels as a prefix of the returned ciphertext.
	return aead.Seal(nonce, nonce, plaintext, data), nil
}

// decrypt reverses encrypt. It splits ciphertext into the leading GCM nonce
// and the sealed payload, then opens the payload under key, verifying data as
// additional authenticated data. If ciphertext is too short to contain a
// nonce, it returns an error instead of panicking. If authentication fails, it
// returns the error from gcm.Open.
func decrypt(key, ciphertext, data []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}

	nonceSize := gcm.NonceSize()
	// Guard the slice below: without this, short or empty input panics.
	if len(ciphertext) < nonceSize {
		return nil, errors.New("ciphertext too short")
	}
	nonce, sealed := ciphertext[:nonceSize], ciphertext[nonceSize:]

	return gcm.Open(nil, nonce, sealed, data)
}

// Zero overwrites every byte of b with 0, in place. Callers in this file use it
// to wipe plaintext data keys. A nil or empty slice is left untouched.
func Zero(b []byte) {
	for n := len(b); n > 0; n-- {
		b[n-1] = 0
	}
}
