package streamkey

import (
	"context"
	"testing"

	"github.com/golang/protobuf/proto"
	"github.com/golang/protobuf/ptypes"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestEncryptDecrypt encrypts the same stream key twice and verifies that the
// two encrypted keys differ from each other while both decrypt back to the
// original private data.
func TestEncryptDecrypt(t *testing.T) {
	const customerID = "test-customer"
	ctx := context.Background()

	// Register a freshly generated secret for the customer.
	secret, err := GenerateSecret()
	require.NoError(t, err)

	secrets := make(MapSecretSource)
	require.NoError(t, secrets.Set(ctx, customerID, secret))

	want := &PrivateData{ExpirationTime: ptypes.TimestampNow()}
	plain := NewV1(customerID, want)

	// Encrypt the identical input twice.
	first, err := Encrypt(ctx, secrets, plain)
	require.NoError(t, err)

	second, err := Encrypt(ctx, secrets, plain)
	require.NoError(t, err)

	t.Logf("keya: %q, keyb: %q", first, second)

	// The same input must never yield the same encrypted key.
	assert.NotEqual(t, first, second)

	// The first key round-trips to the original private data.
	got, err := Decrypt(ctx, secrets, first)
	require.NoError(t, err)
	assert.True(t, proto.Equal(want, got.Priv))

	// So does the second key.
	got, err = Decrypt(ctx, secrets, second)
	require.NoError(t, err)
	assert.True(t, proto.Equal(want, got.Priv))
}

// TestSecretReset verifies that a key encrypted under one customer secret
// decrypts while that secret is in place, and fails to decrypt once the
// customer's secret has been replaced with a newly generated one.
func TestSecretReset(t *testing.T) {
	const customerID = "test-customer"
	ctx := context.Background()

	original, err := GenerateSecret()
	require.NoError(t, err)

	secrets := make(MapSecretSource)
	require.NoError(t, secrets.Set(ctx, customerID, original))

	want := &PrivateData{ExpirationTime: ptypes.TimestampNow()}

	// Encrypt a key under the original secret.
	encrypted, err := Encrypt(ctx, secrets, NewV1(customerID, want))
	require.NoError(t, err)

	// While the original secret is still registered, the key round-trips.
	got, err := Decrypt(ctx, secrets, encrypted)
	require.NoError(t, err)
	assert.True(t, proto.Equal(want, got.Priv))

	// Replace the customer's secret with a new one.
	replacement, err := GenerateSecret()
	require.NoError(t, err)
	require.NoError(t, secrets.Set(ctx, customerID, replacement))

	// The key produced under the old secret must no longer decrypt.
	_, err = Decrypt(ctx, secrets, encrypted)
	assert.Error(t, err)
}

// TestInvalidCustomer verifies that Encrypt returns an error when it is given
// an empty secret source, i.e. one in which no secret was ever set for the
// stream key's customer.
func TestInvalidCustomer(t *testing.T) {
	const customerID = "test-customer"

	// Deliberately left empty: nothing is registered for customerID.
	emptySecrets := make(MapSecretSource)

	key := NewV1(customerID, &PrivateData{ExpirationTime: ptypes.TimestampNow()})

	// Encryption is expected to fail with this secret source.
	_, err := Encrypt(context.Background(), emptySecrets, key)
	require.Error(t, err)
}

func TestAllStreamkeyParams(t *testing.T) {
	customerID := "test-customer"
	ss := make(MapSecretSource)
	secret, _ := GenerateSecret()
	require.NoError(t, ss.Set(context.Background(), customerID, secret))

	priv := &PrivateData{
		ExpirationTime:          ptypes.TimestampNow(),
		CustomerId:              "test-customer",
		ContentId:               "test-content",
		S3Bucket:                "s3-us-west-2.amazonaws.com/test-lvs-vods",
		S3Prefix:                "test-dir/test-prefix",
		SnsNotificationEndpoint: "arn:aws:sns:us-west-2:848744099708:TestLVSStreamNotifications",
		ChannelName:             "test-channel",
	}

	sk := NewV1(customerID, priv)
	key, err := Encrypt(context.Background(), ss, sk)
	require.NoError(t, err)

	// check key matches
	sk, err = Decrypt(context.Background(), ss, key)
	require.NoError(t, err)
	assert.True(t, proto.Equal(priv, sk.Priv))
}
