package objectcache

import (
	"context"
	"sync"
	"testing"
	"time"

	"code.justin.tv/hygienic/statsdsender"
	"github.com/cactus/go-statsd-client/statsd"
	"github.com/stretchr/testify/require"
)

// memoryClientPool is an in-memory cache fake used by the tests in this
// file. It records how many Get calls found a key (hits) and how many did
// not (misses) so tests can assert on the lookup traffic.
type memoryClientPool struct {
	// mu guards vals, hits and misses.
	mu sync.Mutex
	// vals holds the stored bytes by key; it is created lazily by Set.
	vals map[string][]byte
	// hits counts Get calls that found their key.
	hits int
	// misses counts Get calls that did not find their key.
	misses int
}

// Delete removes key from the in-memory store while holding the mutex.
// ctx is unused and the returned error is always nil.
func (m *memoryClientPool) Delete(ctx context.Context, key string) error {
	m.mu.Lock()
	// delete on a nil map is a no-op, so no initialisation check is needed.
	delete(m.vals, key)
	m.mu.Unlock()
	return nil
}

// Get looks up key in the in-memory store and updates the hit/miss counters.
// A missing key is reported as (nil, nil) rather than as an error. ctx is
// unused.
func (m *memoryClientPool) Get(ctx context.Context, key string) ([]byte, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	value, found := m.vals[key]
	if !found {
		m.misses++
		return nil, nil
	}
	m.hits++
	return value, nil
}

// Return is a no-op for the in-memory fake.
func (m *memoryClientPool) Return() {}

// Set stores value under key while holding the mutex, creating the backing
// map on first use. The slice is stored as-is (not copied). ctx and ttl are
// ignored, and the returned error is always nil.
func (m *memoryClientPool) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.vals == nil {
		// First write: build the map already holding this entry.
		m.vals = map[string][]byte{key: value}
		return nil
	}
	m.vals[key] = value
	return nil
}

// GetClient hands back the pool itself as the CacheClient; ctx is unused.
func (m *memoryClientPool) GetClient(ctx context.Context) CacheClient {
	var client CacheClient = m
	return client
}

// Compile-time check that *memoryClientPool satisfies CacheClientPool.
var _ CacheClientPool = (*memoryClientPool)(nil)

// gz returns the bytes produced by GZipDecoder().Marshal(s) as a string.
// It panics if Marshal reports an error.
func gz(s string) string {
	encoded, err := GZipDecoder().Marshal(s)
	if err == nil {
		return string(encoded)
	}
	panic(err)
}

// requireHitsMisses fails the test immediately when the hit or miss counter
// recorded on c differs from the expected values. The hit counter is checked
// first. Failure messages print the expected value followed by the actual one.
func requireHitsMisses(t *testing.T, c *memoryClientPool, hits int, misses int) {
	t.Helper()
	switch {
	case c.hits != hits:
		t.Fatalf("Hits mismatch %d vs %d", hits, c.hits)
	case c.misses != misses:
		t.Fatalf("misses mismatch %d vs %d", misses, c.misses)
	}
}

// TestMigration_Cached exercises Migration.Cached with two ObjectCaches, each
// backed by its own in-memory pool. It asserts on returned values, on the raw
// bytes stored in each pool, and on each pool's hit/miss counters after every
// step. The counter assertions are cumulative, so the order of the steps
// matters.
func TestMigration_Cached(t *testing.T) {
	c1 := &memoryClientPool{}
	c2 := &memoryClientPool{}
	stats := &statsdsender.ErrorlessStatSender{
		StatSender: &statsd.NoopClient{},
	}

	// "From" cache: prefix "1:", no ValueDecoder set.
	objC1 := &ObjectCache{
		ClientPool: c1,
		Stats:      stats,
		KeyPrefix:  "1:",
	}
	// "To" cache: prefix "2:", values encoded with GZipDecoder.
	objC2 := &ObjectCache{
		ClientPool:   c2,
		Stats:        stats,
		KeyPrefix:    "2:",
		ValueDecoder: GZipDecoder(),
	}
	mig := Migration{
		From: objC1,
		To:   objC2,
	}
	ctx := context.Background()
	var res string
	// Step 1: populate "hello2" through the From cache only. Expect one miss
	// on c1 and the value stored under the prefixed key "1:hello2".
	requireHitsMisses(t, c1, 0, 0)
	require.NoError(t, objC1.Cached(ctx, "hello2", func() (i interface{}, e error) {
		return "world2", nil
	}, &res))
	require.Equal(t, `"world2"`, string(c1.vals["1:hello2"]))
	require.Equal(t, "world2", res)
	requireHitsMisses(t, c1, 0, 1)

	// Step 2: a key present in neither cache, fetched through the migration.
	// Expect one more miss on each pool, the callback's value returned, and
	// the value written to both pools (gzip-encoded in c2).
	require.NoError(t, mig.Cached(ctx, "hello", func() (i interface{}, e error) {
		return "world", nil
	}, &res))
	requireHitsMisses(t, c1, 0, 2)
	requireHitsMisses(t, c2, 0, 1)
	require.Equal(t, "world", res)
	require.Equal(t, `"world"`, string(c1.vals["1:hello"]))
	require.Equal(t, gz("world"), string(c2.vals["2:hello"]))

	// Step 3: the same key again. The callback must not run; expect a hit on
	// c2 and c1's counters unchanged.
	var res3 string
	require.NoError(t, mig.Cached(ctx, "hello", func() (i interface{}, e error) {
		panic("Unreachable")
	}, &res3))
	require.Equal(t, "world", res3)
	requireHitsMisses(t, c1, 0, 2)
	requireHitsMisses(t, c2, 1, 1)

	// Step 4: "hello2" is only in c1 (from step 1). The callback must not
	// run; expect a miss on c2, a hit on c1, and the value copied into c2 in
	// gzip-encoded form.
	var res4 string
	require.NotEqual(t, gz("world2"), string(c2.vals["2:hello2"]))
	require.NoError(t, mig.Cached(ctx, "hello2", func() (i interface{}, e error) {
		panic("Unreachable")
	}, &res4))
	require.Equal(t, "world2", res4)
	require.Equal(t, gz("world2"), string(c2.vals["2:hello2"]))
	requireHitsMisses(t, c1, 1, 2)
	requireHitsMisses(t, c2, 1, 2)
}
