package datakeycache

import (
	"errors"
	"strconv"
	"testing"
	"time"

	"code.justin.tv/amzn/TwitchOLE/ole/internal/keybytes"
	"code.justin.tv/amzn/TwitchOLE/ole/internal/stats"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestCacheConfig exercises CacheConfig.validate both directly and through
// NewCache, checking each validation rule in isolation.
func TestCacheConfig(t *testing.T) {
	t.Run("either KeyExpiration or MaxUses must be set", func(t *testing.T) {
		cfg := CacheConfig{}
		err := cfg.validate()
		assert.Error(t, err)
	})

	t.Run("NewCache validates config", func(t *testing.T) {
		_, err := NewCache(CacheConfig{
			MaxSize: -1,
		})
		assert.Error(t, err)
	})

	t.Run("validate", func(t *testing.T) {
		// newConfig builds a config that passes validation, then lets mutate
		// break a single rule so each case tests exactly one check.
		newConfig := func(mutate func(*CacheConfig)) CacheConfig {
			cc := CacheConfig{
				KeyExpiration:      time.Hour,
				KeyExpirationSplay: time.Minute,
				MaxUses:            1,
				MaxSize:            16,
			}
			mutate(&cc)
			return cc
		}

		tcs := []struct {
			Name string
			Case CacheConfig
			// ExpectedError is the expected error message; empty means the
			// config must validate successfully.
			ExpectedError string
		}{
			{
				Name: "valid",
				Case: newConfig(func(*CacheConfig) {}),
			},
			{
				Name: "no expiration and no max uses",
				Case: newConfig(func(cc *CacheConfig) {
					cc.KeyExpiration = 0
					cc.MaxUses = 0
				}),
				ExpectedError: "OLECache: either KeyExpiration or MaxUses must be set to a nonzero value",
			},
			{
				Name: "negative KeyExpiration",
				Case: newConfig(func(cc *CacheConfig) {
					cc.KeyExpiration = -1
				}),
				ExpectedError: "OLECache: KeyExpiration must be greater than or equal to 0",
			},
			{
				Name: "negative KeyExpirationSplay",
				Case: newConfig(func(cc *CacheConfig) {
					cc.KeyExpirationSplay = -1
				}),
				ExpectedError: "OLECache: KeyExpirationSplay must be greater than or equal to 0",
			},
			{
				Name: "negative MaxUses",
				Case: newConfig(func(cc *CacheConfig) {
					cc.MaxUses = -1
				}),
				ExpectedError: "OLECache: MaxUses must be greater than or equal to 0",
			},
			{
				Name: "negative MaxSize",
				Case: newConfig(func(cc *CacheConfig) {
					cc.MaxSize = -1
				}),
				ExpectedError: "OLECache: MaxSize must be greater than or equal to 0",
			},
		}

		for i, tc := range tcs {
			t.Run(strconv.Itoa(i)+"_"+tc.Name, func(t *testing.T) {
				err := tc.Case.validate()
				if tc.ExpectedError == "" {
					assert.NoError(t, err)
					return
				}
				// Compare messages rather than error values so the test does
				// not depend on the concrete error type validate constructs.
				assert.EqualError(t, err, tc.ExpectedError)
			})
		}
	})
}

// TestClearKeys verifies that clearKeys leaves both the plaintext and the
// encrypted data key empty.
func TestClearKeys(t *testing.T) {
	item := new(CacheItem)
	item.plaintextDataKey = []byte("p")
	item.encryptedDataKey = []byte("e")

	item.clearKeys()

	require.Empty(t, item.plaintextDataKey)
	require.Empty(t, item.encryptedDataKey)
}

func TestGetMissingKey(t *testing.T) {
	cache, err := NewCache(CacheConfig{
		KeyExpiration: time.Hour,
		Reporter:      &stats.NoopReporter{},
	})
	require.NoError(t, err)

	key := EncryptionKeyCacheCompositeKey{
		EncryptionContext: map[string]string{"key": "value"},
		AlgorithmID:       "aes_256",
	}
	item := cache.Get(key)
	assert.NotNil(t, item)
	assert.Nil(t, item.plaintextDataKey)
	assert.Nil(t, item.encryptedDataKey)
	assert.Equal(t, cache, item.conf)
}

// TestGetOrFetchMissingKey covers CacheItem.GetOrFetchKeys: the initial fetch,
// reuse of a still-valid key, ticket (MaxUses) exhaustion, time-based
// expiration, and propagation of fetch errors.
//
// The cache keeps the slices returned by the fetch callback and zeroes them
// when they are retired (see the "cleared cached ..." assertions below), so
// callbacks always hand it copies rather than the shared p/e fixtures.
func TestGetOrFetchMissingKey(t *testing.T) {
	p, e := []byte("p"), []byte("e")
	key := EncryptionKeyCacheCompositeKey{
		EncryptionContext: map[string]string{"key": "value"},
		AlgorithmID:       "aes_256",
	}

	t.Run("fetches keys", func(t *testing.T) {
		cache, err := NewCache(CacheConfig{
			KeyExpiration: time.Hour,
			Reporter:      &stats.NoopReporter{},
		})
		require.NoError(t, err)

		ci := cache.Get(key)

		gp, ge, err := ci.GetOrFetchKeys(func() ([]byte, []byte, error) {
			return keybytes.Copy(p), keybytes.Copy(e), nil
		})
		require.NoError(t, err)
		assert.Equal(t, p, gp)
		assert.Equal(t, e, ge)
	})

	t.Run("doesnt call fetch if cached key is still valid", func(t *testing.T) {
		cache, err := NewCache(CacheConfig{
			KeyExpiration: time.Hour,
			Reporter:      &stats.NoopReporter{},
		})
		require.NoError(t, err)

		ci := cache.Get(key)

		gp, ge, err := ci.GetOrFetchKeys(func() ([]byte, []byte, error) {
			return keybytes.Copy(p), keybytes.Copy(e), nil
		})
		require.NoError(t, err)
		assert.Equal(t, p, gp)
		assert.Equal(t, e, ge)

		gp, ge, err = ci.GetOrFetchKeys(func() ([]byte, []byte, error) {
			// Non-fatal: FailNow must not unwind through the cache's own frames.
			assert.Fail(t, "should not be called")
			return keybytes.Copy(p), keybytes.Copy(e), nil
		})
		require.NoError(t, err)
		assert.Equal(t, p, gp)
		assert.Equal(t, e, ge)
	})

	t.Run("decrements tickets and fetches new key when all tickets are used", func(t *testing.T) {
		cache, err := NewCache(CacheConfig{
			KeyExpiration: time.Hour,
			MaxUses:       2,
			Reporter:      &stats.NoopReporter{},
		})
		require.NoError(t, err)

		ci := cache.Get(key)

		op, oe := keybytes.Copy(p), keybytes.Copy(e)
		gp, ge, err := ci.GetOrFetchKeys(func() ([]byte, []byte, error) {
			return op, oe, nil
		})
		require.NoError(t, err)
		assert.Equal(t, p, gp)
		assert.Equal(t, e, ge)
		assert.Equal(t, int64(1), ci.tickets, "decremented ticket on first fetch")

		gp, ge, err = ci.GetOrFetchKeys(func() ([]byte, []byte, error) {
			return op, oe, nil
		})
		require.NoError(t, err)
		assert.Equal(t, p, gp)
		assert.Equal(t, e, ge)
		assert.Equal(t, int64(0), ci.tickets, "decremented ticket on fetch")

		ap, ae := []byte("ap"), []byte("ae")
		ggp, gge, err := ci.GetOrFetchKeys(func() ([]byte, []byte, error) {
			return keybytes.Copy(ap), keybytes.Copy(ae), nil
		})
		require.NoError(t, err)
		assert.Equal(t, ap, ggp, "fetched new plaintext")
		assert.Equal(t, ae, gge, "fetched new encrypted")
		assert.Equal(t, p, gp, "didnt clear copied plaintext")
		assert.Equal(t, e, ge, "didnt clear copied encrypted")
		assert.Equal(t, []byte{0}, op, "cleared cached plaintext")
		assert.Equal(t, []byte{0}, oe, "cleared cached encrypted")
		assert.Equal(t, int64(1), ci.tickets, "reset tickets")
	})

	t.Run("key expiration", func(t *testing.T) {
		// A short expiration keeps the test fast; sleeping well past it (rather
		// than exactly its length) makes the key unambiguously expired.
		const keyExpiration = 100 * time.Millisecond
		cache, err := NewCache(CacheConfig{
			KeyExpiration: keyExpiration,
			Reporter:      &stats.NoopReporter{},
		})
		require.NoError(t, err)
		ci := cache.Get(key)
		op, oe := keybytes.Copy(p), keybytes.Copy(e)
		gp, ge, err := ci.GetOrFetchKeys(func() ([]byte, []byte, error) {
			return op, oe, nil
		})
		require.NoError(t, err)
		assert.Equal(t, op, gp)
		assert.Equal(t, oe, ge)

		time.Sleep(2 * keyExpiration)

		ap, ae := []byte("ap"), []byte("ae")
		ggp, gge, err := ci.GetOrFetchKeys(func() ([]byte, []byte, error) {
			return keybytes.Copy(ap), keybytes.Copy(ae), nil
		})
		require.NoError(t, err)
		assert.Equal(t, []byte{0}, op, "cleared expired plaintext")
		assert.Equal(t, []byte{0}, oe, "cleared expired encrypted")
		assert.Equal(t, ap, ggp, "fetched new plaintext")
		assert.Equal(t, ae, gge, "fetched new encrypted")
		assert.Equal(t, p, gp, "didnt clear copied plaintext")
		assert.Equal(t, e, ge, "didnt clear copied encrypted")
	})

	t.Run("surfaces fetch error", func(t *testing.T) {
		cache, err := NewCache(CacheConfig{
			KeyExpiration: time.Hour,
			Reporter:      &stats.NoopReporter{},
		})
		require.NoError(t, err)

		ci := cache.Get(key)
		someError := errors.New("some error")
		gp, ge, err := ci.GetOrFetchKeys(func() ([]byte, []byte, error) {
			return nil, nil, someError
		})
		require.Equal(t, someError, err)
		assert.Empty(t, gp)
		assert.Empty(t, ge)
	})
}

// TestInvalidCacheItem plants a value of the wrong type in the underlying
// store under a composite key, then checks that Get does not panic and hands
// back an item that can store and later serve keys.
func TestInvalidCacheItem(t *testing.T) {
	c, err := NewCache(CacheConfig{KeyExpiration: time.Hour})
	require.NoError(t, err)

	key := &EncryptionKeyCacheCompositeKey{
		EncryptionContext: map[string]string{"a": "b"},
		AlgorithmID:       "algo",
	}

	// Seed the store with something that is not a valid cache item.
	c.(*cache).cache.Add(key.Key(), struct{}{})

	// Must not panic; the bad entry should be overwritten with a valid item.
	item := c.Get(key)
	require.NotNil(t, item)

	p, e := []byte("a"), []byte("b")
	_, _, err = item.GetOrFetchKeys(func() ([]byte, []byte, error) {
		return p, e, nil
	})
	require.NoError(t, err)

	// A fresh lookup must serve the keys stored above. This fetcher yields
	// nil keys, so if it were invoked the equality checks would fail.
	pp, ee, err := c.Get(key).GetOrFetchKeys(func() ([]byte, []byte, error) {
		return nil, nil, nil
	})
	assert.NoError(t, err)
	assert.Equal(t, p, pp)
	assert.Equal(t, e, ee)
}

// TestCalculateExpiry checks the timestamps (compared against UnixNano values)
// produced by calculateExpiry with and without splay, and that a cache with
// no keyExpiration yields 0.
func TestCalculateExpiry(t *testing.T) {
	t.Run("with expiration", func(t *testing.T) {
		const (
			keyExpiration = time.Minute
			slack         = 5 * time.Second
		)

		t.Run("with splay should use splay", func(t *testing.T) {
			cc := &cache{
				keyExpiration:      keyExpiration,
				keyExpirationSplay: time.Hour,
			}

			// Sampled many times since the splay is presumably randomized.
			for i := 0; i < 1000; i++ {
				want := time.Now().Add(keyExpiration + time.Hour).UnixNano()
				assert.InDelta(t, want, cc.calculateExpiry(), float64(time.Hour+slack))
			}
		})

		t.Run("with no splay should not use splay", func(t *testing.T) {
			cc := &cache{keyExpiration: keyExpiration}

			want := time.Now().Add(keyExpiration).UnixNano()
			assert.InDelta(t, want, cc.calculateExpiry(), float64(slack))
		})
	})

	t.Run("with no expiration should return 0", func(t *testing.T) {
		cc := new(cache)
		assert.Equal(t, int64(0), cc.calculateExpiry())
	})
}
