package aws_test

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"code.justin.tv/video/clock"

	vault "code.justin.tv/amzn/StarfruitVault"
	"code.justin.tv/amzn/StarfruitVault/key"
	"code.justin.tv/amzn/StarfruitVault/provider"
	. "code.justin.tv/amzn/StarfruitVault/provider/aws"
	"code.justin.tv/amzn/StarfruitVault/provider/aws/mock"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/secretsmanager"
	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/assert"
)

// getFakeSecretsmanager builds a gomock SecretsManagerAPI that behaves like a
// tiny in-memory Secrets Manager keyed by secret ID.
//
// PutSecretValueWithContext stores the input's SecretString under its SecretId,
// replacing any earlier value. GetSecretValueWithContext returns the stored
// string, or a *secretsmanager.ResourceNotFoundException when nothing has been
// put under that ID. Both expectations accept any number of calls (including
// zero). The backing map is unsynchronized, so the fake must not be used from
// several goroutines at once.
func getFakeSecretsmanager(ctrl *gomock.Controller) *mock.MockSecretsManagerAPI {
	store := make(map[string]*string)

	getValue := func(_ context.Context, in *secretsmanager.GetSecretValueInput, _ ...request.Option) (*secretsmanager.GetSecretValueOutput, error) {
		if value, found := store[*in.SecretId]; found {
			return &secretsmanager.GetSecretValueOutput{SecretString: value}, nil
		}
		return nil, &secretsmanager.ResourceNotFoundException{}
	}

	putValue := func(_ context.Context, in *secretsmanager.PutSecretValueInput, _ ...request.Option) (*secretsmanager.PutSecretValueOutput, error) {
		store[*in.SecretId] = in.SecretString
		return &secretsmanager.PutSecretValueOutput{}, nil
	}

	fake := mock.NewMockSecretsManagerAPI(ctrl)
	fake.EXPECT().
		GetSecretValueWithContext(gomock.Any(), gomock.Any(), gomock.Any()).
		DoAndReturn(getValue).
		AnyTimes()
	fake.EXPECT().
		PutSecretValueWithContext(gomock.Any(), gomock.Any(), gomock.Any()).
		DoAndReturn(putValue).
		AnyTimes()
	return fake
}

// TestRotate checks that a Rotator configured with only an AES-GCM encryption
// key type can rotate keys against a private secret that starts out as an
// empty JSON object.
func TestRotate(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t)
	// Finish checks that every gomock expectation registered on ctrl was met.
	defer ctrl.Finish()

	sm := getFakeSecretsmanager(ctrl)
	_, err := sm.PutSecretValueWithContext(ctx, &secretsmanager.PutSecretValueInput{
		SecretId:     aws.String("private"),
		SecretString: aws.String("{}"),
	})
	require.NoError(t, err)

	rotator, err := NewRotator(&RotatorConfig{
		SecretsManager:    sm,
		KeyExpirationTTL:  0,
		EncryptionKeyType: key.KeyTypeSharedAESGCM,
		PrivateID:         "private",
		PublicID:          "public",
	})
	// require rather than assert: rotator is used next and is not meaningful
	// if construction failed.
	require.NoError(t, err)
	assert.NoError(t, rotator.Rotate(ctx))
}

// TestRotateEncrypt runs one rotation with both an AES-GCM encryption key type
// and an Ed25519 signature key type. It then loads the resulting secrets into
// an encode vault (via the "private" secret) and a decode vault (via the
// "public" secret, which the rotation is expected to populate). Finally it
// checks that a payload encoded under alias "sjc02" decodes back to the same
// alias and plaintext.
func TestRotateEncrypt(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	sm := getFakeSecretsmanager(ctrl)
	_, err := sm.PutSecretValueWithContext(ctx, &secretsmanager.PutSecretValueInput{
		SecretId:     aws.String("private"),
		SecretString: aws.String("{}"),
	})
	require.NoError(t, err)

	// Setup steps use require: each later step depends on the value the
	// previous one produced, so continuing after a failure would only panic
	// or produce cascading, misleading failures.
	vizimaRotator, err := NewRotator(&RotatorConfig{
		SecretsManager:    sm,
		EncryptionKeyType: key.KeyTypeSharedAESGCM,
		SignatureKeyType:  key.KeyTypePrivateEd25519,
		PrivateID:         "private",
		PublicID:          "public",
	})
	require.NoError(t, err)
	require.NoError(t, vizimaRotator.Rotate(ctx))

	// Encode side: fed from the private secret.
	encodeVault, err := vault.NewEncodeVault(&vault.EncodeVaultConfig{})
	require.NoError(t, err)

	encodeHooks := provider.NewHooks()
	encodeHooks.RegisterOnFetch(encodeVault.OnKeys)
	encodeProvider, err := NewAWSProvider(&Config{
		Session:        session.Must(session.NewSession()),
		SecretIDs:      map[string]string{"sjc02": "private"},
		Hooks:          encodeHooks,
		SecretsManager: sm,
	})
	require.NoError(t, err)
	require.NoError(t, encodeProvider.Fetch(ctx))

	// Decode side: fed from the public secret.
	decodeVault, err := vault.NewDecodeVault(&vault.DecodeVaultConfig{})
	require.NoError(t, err)

	decodeHooks := provider.NewHooks()
	decodeHooks.RegisterOnFetch(decodeVault.OnKeys)
	decodeProvider, err := NewAWSProvider(&Config{
		Session:        session.Must(session.NewSession()),
		SecretIDs:      map[string]string{"sjc02": "public"},
		Hooks:          decodeHooks,
		SecretsManager: sm,
	})
	require.NoError(t, err)
	require.NoError(t, decodeProvider.Fetch(ctx))

	secret := []byte("omg i am super secure")
	encoded, err := encodeVault.Encode("sjc02", secret)
	require.NoError(t, err)

	id, decoded, err := decodeVault.Decode(encoded)
	assert.NoError(t, err)
	assert.Equal(t, "sjc02", id)
	assert.Equal(t, secret, decoded)
}

// TestRotateEncryptExpiration drives the rotator with a mock clock across two
// full DefaultKeyExpirationTTL windows, advancing the clock by
// DefaultKeyRotationTTL and rotating once per step. After each rotation both
// providers re-fetch and a fresh payload is encoded. The test then checks that:
//   - the payloads from the most recent
//     DefaultKeyDeactivationTTL/DefaultKeyRotationTTL steps still decode to the
//     original alias and plaintext, and
//   - every older payload fails to decode.
//
// Finally it checks how many keys the decode side received on its last fetch.
func TestRotateEncryptExpiration(t *testing.T) {
	ctx := context.Background()
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	sm := getFakeSecretsmanager(ctrl)
	_, err := sm.PutSecretValueWithContext(ctx, &secretsmanager.PutSecretValueInput{
		SecretId:     aws.String("private"),
		SecretString: aws.String("{}"),
	})
	require.NoError(t, err)

	mockClock := clock.NewMock()
	mockClock.Set(time.Now())

	// Setup steps use require: later steps depend on these values, so continuing
	// after a failure would only panic or cascade.
	vizimaRotator, err := NewRotator(&RotatorConfig{
		SecretsManager:    sm,
		EncryptionKeyType: key.KeyTypeSharedAESGCM,
		SignatureKeyType:  key.KeyTypePrivateEd25519,
		PrivateID:         "private",
		PublicID:          "public",
		Clock:             mockClock,
	})
	require.NoError(t, err)
	require.NoError(t, vizimaRotator.Rotate(ctx))

	encodeVault, err := vault.NewEncodeVault(&vault.EncodeVaultConfig{})
	require.NoError(t, err)

	encodeHooks := provider.NewHooks()
	encodeHooks.RegisterOnFetch(encodeVault.OnKeys)
	encodeProvider, err := NewAWSProvider(&Config{
		Session:        session.Must(session.NewSession()),
		SecretIDs:      map[string]string{"sjc02": "private"},
		Hooks:          encodeHooks,
		SecretsManager: sm,
	})
	require.NoError(t, err)
	require.NoError(t, encodeProvider.Fetch(ctx))

	decodeVault, err := vault.NewDecodeVault(&vault.DecodeVaultConfig{})
	require.NoError(t, err)

	decodeHooks := provider.NewHooks()
	decodeHooks.RegisterOnFetch(decodeVault.OnKeys)

	// Record the size of the most recent key set delivered to the decode side.
	var keyCount int
	decodeHooks.RegisterOnFetch(func(_ string, keys *key.KeySet) {
		keyCount = len(keys.Keys)
	})

	decodeProvider, err := NewAWSProvider(&Config{
		Session:        session.Must(session.NewSession()),
		SecretIDs:      map[string]string{"sjc02": "public"},
		Hooks:          decodeHooks,
		SecretsManager: sm,
	})
	require.NoError(t, err)

	secret := []byte("omg i am super secure")
	// Number of most recent payloads expected to remain decodable.
	maxDecodable := int(DefaultKeyDeactivationTTL / DefaultKeyRotationTTL)

	var successSet, failSet [][]byte
	for elapsed := time.Duration(0); elapsed < 2*DefaultKeyExpirationTTL; elapsed += DefaultKeyRotationTTL {
		mockClock.Add(DefaultKeyRotationTTL)

		require.NoError(t, vizimaRotator.Rotate(ctx), "rotate at %v", elapsed)
		require.NoError(t, encodeProvider.Fetch(ctx), "encode fetch at %v", elapsed)
		require.NoError(t, decodeProvider.Fetch(ctx), "decode fetch at %v", elapsed)

		encoded, encodeErr := encodeVault.Encode("sjc02", secret)
		require.NoError(t, encodeErr, "encode at %v", elapsed)

		// Slide the window: once more than maxDecodable payloads are held, the
		// oldest one moves to the set that must no longer decode.
		successSet = append(successSet, encoded)
		if len(successSet) > maxDecodable {
			failSet = append(failSet, successSet[0])
			successSet = successSet[1:]
		}

		for _, payload := range successSet {
			id, decoded, decodeErr := decodeVault.Decode(payload)
			require.NoError(t, decodeErr, "decode of live payload at %v", elapsed)
			assert.Equal(t, "sjc02", id)
			assert.Equal(t, secret, decoded)
		}

		for _, payload := range failSet {
			_, _, decodeErr := decodeVault.Decode(payload)
			require.Error(t, decodeErr, "decode of deactivated payload at %v", elapsed)
		}
	}

	// Expect one key per rotation period within DefaultKeyExpirationTTL, plus one
	// extra: a new key is always provisioned once an old one is evicted.
	assert.Equal(t, int(DefaultKeyExpirationTTL/DefaultKeyRotationTTL)+1, keyCount)
}
