package timbersaw_test

import (
	"context"
	"fmt"
	"sync"
	"testing"
	"time"

	"math/rand"

	"github.com/stretchr/testify/require"

	tm "code.justin.tv/amzn/StarfruitTimbersawClient/timbersaw"
	"code.justin.tv/video/invoker"
)

// Shared fixtures. Each CacheKey names a channel and an origin site; each
// fixture path list below ends at the Origin of the key with the same letter.
var (
	channelA = tm.CacheKey{ChannelARN: "channelA", Origin: "pdx05"}
	channelB = tm.CacheKey{ChannelARN: "channelB", Origin: "cmh01"}
	channelC = tm.CacheKey{ChannelARN: "channelC", Origin: "lhr03"}

	pathsA = []tm.Path{{"sea01", "pdx01", "pdx05"}, {"sea01", "ord03", "pdx05"}}
	pathsB = []tm.Path{{"lhr03", "jfk04", "cmh01"}, {"lhr03", "iad05", "cmh01"}}
	pathsC = []tm.Path{{"lhr05", "qro04", "lhr03"}, {"lhr05", "ord03", "lhr03"}}
)

// Cache capacities used by the tests. These are constants rather than
// package-level variables so no test can mutate them for another.
const (
	// smallCacheSize holds two entries; TestBasicLRUPolicy relies on a third
	// insert forcing an eviction.
	smallCacheSize = 2
	largeCacheSize = 100
)

func getBasicCache() *tm.Cache {
	return tm.NewCache(largeCacheSize, time.Duration(0))
}

// withCacheCleanup builds a cache with the given capacity and threshold and
// hands f and the cache's Run method to an invoker that runs them together.
// When f returns, the shared context is cancelled (presumably what lets
// cache.Run stop). The invoker's Run result is returned to the caller.
//
// NOTE(review): f may run on a goroutine other than the test goroutine,
// depending on how invoker schedules its functions. Avoid t.FailNow-based
// assertions (e.g. require.*) inside f.
func withCacheCleanup(ctx context.Context, cacheSize int, cacheThreshold time.Duration, f func(ctx context.Context, cache *tm.Cache)) error {
	runCtx, stop := context.WithCancel(ctx)
	defer stop()

	cache := tm.NewCache(cacheSize, cacheThreshold)

	group := invoker.New()
	group.Add(func(inner context.Context) error {
		// Cancel the shared context as soon as the caller's work is done.
		defer stop()
		f(inner, cache)
		return nil
	})
	group.Add(cache.Run)

	return group.Run(runCtx)
}

// TestSetNonExistingPairs stores two fresh keys and checks that each one reads
// back exactly the paths it was given.
func TestSetNonExistingPairs(t *testing.T) {
	testCache := getBasicCache()

	entries := []struct {
		key   tm.CacheKey
		paths []tm.Path
	}{
		{channelA, pathsA},
		{channelB, pathsB},
	}

	for _, e := range entries {
		testCache.Set(e.key, e.paths, time.Now())
	}

	for _, e := range entries {
		got, found := testCache.Get(e.key)
		require.True(t, found, "entry should be in the testCache")
		require.Equal(t, e.paths, got)
	}
}

func TestSetExistingPairs(t *testing.T) {
	testCache := getBasicCache()

	testCache.Set(channelA, pathsA, time.Now())
	testCache.Set(channelB, pathsB, time.Now())
	testCache.Set(channelA, pathsB, time.Now())
	testCache.Set(channelB, pathsA, time.Now())

	testPaths, ok := testCache.Get(channelA)
	require.Equal(t, true, ok, "entry should be in the testCache")
	require.Equal(t, pathsB, testPaths)

	testPaths, ok = testCache.Get(channelB)
	require.Equal(t, true, ok, "entry should be in the testCache")
	require.Equal(t, pathsA, testPaths)
}

// TestNotStaleAndNotRefreshable uses a one-hour threshold and an entry
// timestamped an hour in the future (presumably its next refresh time). It
// checks that the entry is still returned after the background Run loop has
// been active for a few seconds.
func TestNotStaleAndNotRefreshable(t *testing.T) {
	testCache := tm.NewCache(largeCacheSize, time.Hour)

	// Deferred so the background loop is stopped even when a require assertion
	// fails and t.FailNow unwinds the test goroutine before reaching the end.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// begin GC
	go testCache.Run(ctx)

	testCache.Set(channelA, pathsA, time.Now().Add(time.Hour))
	// Give the background loop time to act on the entry.
	time.Sleep(3 * time.Second)

	testPaths, ok := testCache.Get(channelA)
	require.True(t, ok)
	require.Equal(t, pathsA, testPaths)
}

// TestNotStaleAndRefreshable uses a one-hour threshold and an entry
// timestamped one second in the future. That time passes during the sleep, so
// the entry becomes refreshable. The test checks that the entry is still
// returned with its original paths.
func TestNotStaleAndRefreshable(t *testing.T) {
	testCache := tm.NewCache(largeCacheSize, time.Hour)

	// Deferred so the background loop is stopped even when a require assertion
	// fails and t.FailNow unwinds the test goroutine before reaching the end.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// begin GC
	go testCache.Run(ctx)

	testCache.Set(channelA, pathsA, time.Now().Add(time.Second))
	// Give the background loop time to act on the entry.
	time.Sleep(3 * time.Second)

	testPaths, ok := testCache.Get(channelA)
	require.True(t, ok)
	require.Equal(t, pathsA, testPaths)
}

// TestStaleAndRefreshable uses a one-nanosecond threshold and an entry
// timestamped one second in the future. It checks that the entry is gone
// (miss and nil paths) after the background Run loop has been active for a few
// seconds.
func TestStaleAndRefreshable(t *testing.T) {
	testCache := tm.NewCache(largeCacheSize, time.Nanosecond)

	// Deferred so the background loop is stopped even when a require assertion
	// fails and t.FailNow unwinds the test goroutine before reaching the end.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// begin GC
	go testCache.Run(ctx)

	testCache.Set(channelA, pathsA, time.Now().Add(time.Second))
	// Give the background loop time to act on the entry.
	time.Sleep(3 * time.Second)

	testPaths, ok := testCache.Get(channelA)
	require.False(t, ok)
	require.Nil(t, testPaths)
}

// TestStaleAndNotRefreshable uses a one-nanosecond threshold and an entry
// timestamped an hour in the future. It checks that the entry is gone (miss
// and nil paths) after the background Run loop has been active for a few
// seconds.
func TestStaleAndNotRefreshable(t *testing.T) {
	testCache := tm.NewCache(largeCacheSize, time.Nanosecond)

	// Deferred so the background loop is stopped even when a require assertion
	// fails and t.FailNow unwinds the test goroutine before reaching the end.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// begin GC
	go testCache.Run(ctx)

	testCache.Set(channelA, pathsA, time.Now().Add(time.Hour))
	// Give the background loop time to act on the entry.
	time.Sleep(3 * time.Second)

	testPaths, ok := testCache.Get(channelA)
	require.False(t, ok)
	require.Nil(t, testPaths)
}

// TestNextRefresh inserts ten entries timestamped "now". It checks that
// NextRefreshKeys eventually reports all of them, that the reported keys are
// distinct, and that each reported key still maps to its own paths.
//
// All cache interaction happens inside the withCacheCleanup callback, which
// the invoker may run on a separate goroutine. Every require assertion
// therefore runs afterwards on the test goroutine, because t.FailNow (used by
// require) must only be called from the goroutine running the test.
func TestNextRefresh(t *testing.T) {
	const numEntries = 10

	// Filled in by the callback and read only after withCacheCleanup returns.
	// NOTE(review): assumes invoker.Run waits for every added function before
	// returning. The previous version of this test relied on the same thing.
	var (
		refreshKeys []tm.CacheKey
		lookupOK    []bool
		firstValues []string
	)

	// The error is intentionally not asserted. What is reported on a
	// cancellation-driven shutdown depends on invoker and Cache.Run, which are
	// not visible from this file.
	_ = withCacheCleanup(context.Background(), largeCacheSize, time.Hour, func(ctx context.Context, testCache *tm.Cache) {
		for i := 0; i < numEntries; i++ {
			key := tm.CacheKey{
				ChannelARN: fmt.Sprintf("channel-%d", i),
				Origin:     "ads",
			}
			paths := []tm.Path{{fmt.Sprintf("upstream-%d", i)}}

			testCache.Set(key, paths, time.Now())
		}

		// Poll for up to 3s because the keys may not be reported immediately.
		// The loop variable is named start so it does not shadow t.
		for start := time.Now(); time.Since(start) < 3*time.Second; time.Sleep(100 * time.Millisecond) {
			refreshKeys, _ = testCache.NextRefreshKeys(numEntries)
			if len(refreshKeys) != 0 {
				break
			}
		}

		// Record each lookup, plus the first element of the first path on a hit.
		for _, k := range refreshKeys {
			paths, ok := testCache.Get(k)
			lookupOK = append(lookupOK, ok)
			if ok {
				firstValues = append(firstValues, paths[0][0])
			}
		}
	})

	require.Len(t, refreshKeys, numEntries, "Should have gotten 10 entries to refresh")

	// Note: NextRefreshKeys does not guarantee returned keys are unique
	// nor does it guarantee the paths behind returned keys are unique
	// These tests rely on the initialization above to guarantee uniqueness
	uniqueKeys := make(map[tm.CacheKey]struct{}, len(refreshKeys))
	for _, k := range refreshKeys {
		uniqueKeys[k] = struct{}{}
	}
	require.Len(t, uniqueKeys, numEntries, "Should get 10 unique keys back.")

	for _, ok := range lookupOK {
		require.True(t, ok, "Should have found the cache entry.")
	}

	uniqueValues := make(map[string]struct{}, len(firstValues))
	for _, v := range firstValues {
		uniqueValues[v] = struct{}{}
	}
	require.Len(t, uniqueValues, numEntries, "Should get 10 unique values back.")
}

// TestBasicLRUPolicy fills a two-slot cache with three keys. It checks that
// the least recently used key is evicted each time capacity is exceeded. No
// background Run loop is started here.
func TestBasicLRUPolicy(t *testing.T) {
	testCache := tm.NewCache(smallCacheSize, time.Hour)

	anHourFromNow := func() time.Time { return time.Now().Add(time.Hour) }

	expectEvicted := func(key tm.CacheKey) {
		t.Helper()
		got, found := testCache.Get(key)
		require.False(t, found)
		require.Nil(t, got)
	}
	expectCached := func(key tm.CacheKey, want []tm.Path) {
		t.Helper()
		got, found := testCache.Get(key)
		require.True(t, found)
		require.Equal(t, want, got)
	}

	testCache.Set(channelA, pathsA, anHourFromNow())
	testCache.Set(channelB, pathsB, anHourFromNow())
	testCache.Set(channelC, pathsC, anHourFromNow())

	// A was inserted first and never read, so the third insert pushes it out.
	expectEvicted(channelA)

	// Reading C before B leaves C as the least recently used entry...
	expectCached(channelC, pathsC)
	expectCached(channelB, pathsB)

	// ...so re-inserting A must evict C.
	testCache.Set(channelA, pathsA, anHourFromNow())
	expectEvicted(channelC)
}

// TestConcurrentSetAndGet runs 1000 rounds of paired Set and Get calls, each
// on its own goroutine, against a two-slot cache with its background Run loop
// active. It makes no value assertions; it exists to surface panics and data
// races, most usefully under `go test -race`.
func TestConcurrentSetAndGet(t *testing.T) {
	testCache := tm.NewCache(smallCacheSize, time.Hour)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// begin GC
	go testCache.Run(ctx)

	keys := []tm.CacheKey{channelA, channelB, channelC}
	values := [][]tm.Path{pathsA, pathsB, pathsC}

	var wg sync.WaitGroup
	for round := 0; round < 1000; round++ {
		wg.Add(2)
		keyIdx := rand.Int() % 3
		valIdx := rand.Int() % 3

		go func() {
			defer wg.Done()
			testCache.Set(keys[keyIdx], values[valIdx], time.Now().Add(time.Second))
		}()
		go func() {
			defer wg.Done()
			testCache.Get(keys[keyIdx])
		}()
	}
	wg.Wait()
}
