package ole

import (
	"bytes"
	"errors"
	"io/ioutil"
	"testing"
	"time"

	"code.justin.tv/amzn/TwitchOLE/ole/internal/datakeycache"
	"code.justin.tv/amzn/TwitchOLE/ole/internal/encryptedobject"
	"code.justin.tv/amzn/TwitchOLE/ole/internal/stats"
	telemetry "code.justin.tv/amzn/TwitchTelemetry"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/kms"
	gomock "github.com/golang/mock/gomock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/stretchr/testify/suite"
)

// TestNewKMSOleClient covers construction through NewKMSOleClient: an empty
// config is rejected, a valid config wires up every dependency, and a
// caller-supplied Reporter is used as-is.
func TestNewKMSOleClient(t *testing.T) {
	t.Run("validates config", func(t *testing.T) {
		cfg := KMSOleClientConfig{}
		_, err := NewKMSOleClient(nil, cfg)
		assert.Error(t, err)
	})

	t.Run("valid config", func(t *testing.T) {
		cfg := KMSOleClientConfig{
			CMKArn:     "some-arn",
			DataKeyTTL: time.Minute,
		}

		mockController := gomock.NewController(t)
		mockKMSAPI := NewMockKMSAPI(mockController)
		client, err := NewKMSOleClient(mockKMSAPI, cfg)
		require.NoError(t, err)
		// require (not assert): every following line dereferences client.
		require.NotNil(t, client)
		assert.Equal(t, mockKMSAPI, client.kms)
		// cfg is held by value, so a NotNil check would be vacuous; verify its
		// contents were carried over instead.
		assert.Equal(t, cfg.CMKArn, client.cfg.CMKArn)
		assert.NotNil(t, client.encryptionKeyCache)
		assert.NotNil(t, client.decryptionKeyCache)
		assert.NotNil(t, client.reporter)
	})

	t.Run("user provided sample reporter", func(t *testing.T) {
		reporter := &stats.NoopReporter{}
		cfg := KMSOleClientConfig{
			CMKArn:     "some-arn",
			DataKeyTTL: time.Minute,
			Reporter:   reporter,
		}

		mockController := gomock.NewController(t)
		mockKMSAPI := NewMockKMSAPI(mockController)
		client, err := NewKMSOleClient(mockKMSAPI, cfg)
		require.NoError(t, err)
		// require (not assert): the next line dereferences client.
		require.NotNil(t, client)
		assert.Equal(t, reporter, client.reporter)
	})
}

// TestValidate exercises KMSOleClientConfig.validate: a config lacking a CMK
// ARN is rejected, while one that only sets the ARN is accepted and has its
// Algorithm filled in with AES256GCM.
func TestValidate(t *testing.T) {
	t.Run("missing cmkarn", func(t *testing.T) {
		var emptyCfg KMSOleClientConfig
		assert.Error(t, emptyCfg.validate())
	})

	t.Run("defaults missing algorithm", func(t *testing.T) {
		arnOnly := KMSOleClientConfig{CMKArn: "some-arn"}
		assert.NoError(t, arnOnly.validate())
		// validate is expected to mutate the config in place.
		assert.Equal(t, AES256GCM{}, arnOnly.Algorithm)
	})
}

// TestDecryptMissingEncryptedDataKey checks that decrypt returns an error for
// an EncryptedObject with no fields set. The client is a zero value with no
// KMS or cache dependencies configured.
func TestDecryptMissingEncryptedDataKey(t *testing.T) {
	zeroClient := new(KMSClient)
	_, decryptErr := zeroClient.decrypt(&encryptedobject.EncryptedObject{})
	assert.Error(t, decryptErr)
}

// KMSClientTestSuite exercises KMSClient against gomock doubles. SetupTest
// reassigns every field below before each test, so tests may freely replace
// or mutate them.
type KMSClientTestSuite struct {
	suite.Suite

	// KMSClient is the client under test; SetupTest wires it to the mocks.
	KMSClient *KMSClient

	mockController          *gomock.Controller
	mockKMSAPI              *MockKMSAPI
	mockReporter            *MockSampleReporterAPI
	mockEncryptionKeyCacher *MockDataKeyCacher
	mockDecryptionKeyCacher *MockDataKeyCacher

	// buffer is scratch storage shared by an encryptor/decryptor pair.
	buffer *bytes.Buffer
}

// SetupTest builds a fresh gomock controller, fresh mocks, an empty buffer,
// and a KMSClient (CMK "cmk-arn", AES256GCM) wired to those mocks.
func (suite *KMSClientTestSuite) SetupTest() {
	ctrl := gomock.NewController(suite.T())
	suite.mockController = ctrl

	suite.mockKMSAPI = NewMockKMSAPI(ctrl)
	suite.mockReporter = NewMockSampleReporterAPI(ctrl)
	suite.mockEncryptionKeyCacher = NewMockDataKeyCacher(ctrl)
	suite.mockDecryptionKeyCacher = NewMockDataKeyCacher(ctrl)
	suite.buffer = new(bytes.Buffer)

	suite.KMSClient = &KMSClient{
		kms:                suite.mockKMSAPI,
		encryptionKeyCache: suite.mockEncryptionKeyCacher,
		decryptionKeyCache: suite.mockDecryptionKeyCacher,
		reporter:           suite.mockReporter,
		cfg: KMSOleClientConfig{
			Algorithm: AES256GCM{},
			CMKArn:    "cmk-arn",
		},
	}
}

// TestGetOrGenerateDataKeyOnlyCallsKMSOnce issues concurrent requests for the
// same encryption context against a real (non-mock) encryption-key cache and
// checks that KMS GenerateDataKey is invoked exactly once.
func (suite *KMSClientTestSuite) TestGetOrGenerateDataKeyOnlyCallsKMSOnce() {
	defer suite.mockController.Finish()
	// The client reports through a no-op reporter here, so no reporter
	// expectations are needed.
	suite.KMSClient.reporter = &stats.NoopReporter{}

	// The sleep keeps the first GenerateDataKey call in flight while the other
	// goroutines arrive. Times(1) fails the test if any of them also reached
	// KMS, or if none did.
	suite.mockKMSAPI.EXPECT().GenerateDataKey(gomock.Any()).Return(&kms.GenerateDataKeyOutput{
		Plaintext:      []byte("plaintext"),
		CiphertextBlob: []byte("encrypted"),
	}, nil).Do(func(*kms.GenerateDataKeyInput) {
		time.Sleep(100 * time.Millisecond)
	}).Times(1)

	// use a real encryption cache
	var err error
	suite.KMSClient.encryptionKeyCache, err = datakeycache.NewCache(datakeycache.CacheConfig{
		KeyExpiration: time.Hour,
	})
	suite.Require().NoError(err)

	const callers = 10
	// Buffered so no goroutine blocks on send, even if the receiver stops early.
	errs := make(chan error, callers)
	for n := 0; n < callers; n++ {
		go func() {
			_, _, err := suite.KMSClient.getOrGenerateEncryptionDataKey(map[string]string{
				"a": "b",
			})
			// require/FailNow may only run on the test goroutine, so hand the
			// result back instead of asserting here.
			errs <- err
		}()
	}

	// Assert (not Require) so every result is drained before the test returns.
	for n := 0; n < callers; n++ {
		suite.Assert().NoError(<-errs)
	}
}

// TestEncryptAndDecrypt round-trips plaintext through NewEncryptor and
// NewDecryptor using real data-key caches. It expects one encryption-key miss
// and one decryption-key miss to be reported, one KMS GenerateDataKey call and
// one KMS Decrypt call, and the decrypted bytes to equal the input.
func (suite *KMSClientTestSuite) TestEncryptAndDecrypt() {
	defer suite.mockController.Finish()

	var (
		message = []byte("plaintextData")
		dataKey = []byte("plaintextKey11111111111111111111") // 32 bytes, matching the AES_256 key spec
		wrapped = []byte("encryptedKey")
		encCtx  = map[string]string{"k": "v"}
	)

	encCache, err := datakeycache.NewCache(datakeycache.CacheConfig{
		KeyExpiration: time.Hour,
		KeyType:       datakeycache.KeyTypeEncryption,
		Reporter:      suite.mockReporter,
	})
	suite.Require().NoError(err)

	decCache, err := datakeycache.NewCache(datakeycache.CacheConfig{
		KeyExpiration: time.Hour,
		KeyType:       datakeycache.KeyTypeDecryption,
		Reporter:      suite.mockReporter,
	})
	suite.Require().NoError(err)

	suite.KMSClient.encryptionKeyCache = encCache
	suite.KMSClient.decryptionKeyCache = decCache

	// Both caches start empty, so each side records exactly one miss.
	suite.mockReporter.EXPECT().Report("OLECacheEncryptionKeyMiss", 1.0, telemetry.UnitCount)
	suite.mockReporter.EXPECT().Report("OLECacheDecryptionKeyMiss", 1.0, telemetry.UnitCount)

	expectedGenerate := &kms.GenerateDataKeyInput{
		EncryptionContext: toAWSEncryptionContext(encCtx),
		KeyId:             aws.String(suite.KMSClient.cfg.CMKArn),
		KeySpec:           aws.String("AES_256"),
	}
	suite.mockKMSAPI.EXPECT().GenerateDataKey(gomock.Eq(expectedGenerate)).Return(&kms.GenerateDataKeyOutput{
		Plaintext:      dataKey,
		CiphertextBlob: wrapped,
	}, nil)

	expectedDecrypt := &kms.DecryptInput{
		CiphertextBlob:    wrapped,
		EncryptionContext: toAWSEncryptionContext(encCtx),
	}
	suite.mockKMSAPI.EXPECT().Decrypt(gomock.Eq(expectedDecrypt)).Return(&kms.DecryptOutput{Plaintext: dataKey}, nil)

	// Encrypt into the shared buffer...
	encryptor := suite.KMSClient.NewEncryptor(encCtx, suite.buffer)
	_, err = encryptor.Write(message)
	suite.Require().NoError(err)

	// ...then read it back through a decryptor.
	roundTripped, err := ioutil.ReadAll(suite.KMSClient.NewDecryptor(suite.buffer))
	suite.Require().NoError(err)
	suite.Assert().Equal(message, roundTripped)
}

// TestGenerateDataKeyError verifies that when the encryption-key cache yields
// no usable key and KMS GenerateDataKey fails, the encryptor's Write surfaces
// that same KMS error.
func (suite *KMSClientTestSuite) TestGenerateDataKeyError() {
	// Deferred, like the other suite tests, so mock expectations are verified
	// even if the body panics.
	defer suite.mockController.Finish()

	plaintextData := []byte("plaintextData")
	encryptionContext := map[string]string{"k": "v"}

	encryptionCacheKey := datakeycache.EncryptionKeyCacheCompositeKey{
		EncryptionContext: encryptionContext,
		AlgorithmID:       AES256GCM{}.ID(),
	}
	// An empty CacheItem is returned; since a miss report is also expected,
	// the client presumably treats this as a cache miss.
	suite.mockEncryptionKeyCacher.EXPECT().Get(gomock.Eq(encryptionCacheKey)).Return(&datakeycache.CacheItem{})
	suite.mockReporter.EXPECT().Report("OLECacheEncryptionKeyMiss", 1.0, telemetry.UnitCount)

	gdkErr := errors.New("my-gdk-error")
	suite.mockKMSAPI.EXPECT().GenerateDataKey(gomock.Eq(&kms.GenerateDataKeyInput{
		EncryptionContext: toAWSEncryptionContext(encryptionContext),
		KeyId:             aws.String(suite.KMSClient.cfg.CMKArn),
		KeySpec:           aws.String("AES_256"),
	})).Return(nil, gdkErr)

	// encrypt and write
	w := suite.KMSClient.NewEncryptor(encryptionContext, suite.buffer)
	_, err := w.Write(plaintextData)
	suite.Assert().Equal(gdkErr, err)
}

// TestKMSClientTestSuite is the `go test` entry point for KMSClientTestSuite;
// testify runs each of the suite's Test* methods with SetupTest beforehand.
func TestKMSClientTestSuite(t *testing.T) {
	suite.Run(t, new(KMSClientTestSuite))
}
