package history

import (
	"testing"
	"time"

	"code.justin.tv/devhub/e2ml/libs/metrics"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"code.justin.tv/devhub/e2ml/libs/stream/protocol"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// memoryCount is a trivial in-memory counter handed to NewRateLimited in the
// tests below so they can inspect the total it has been asked to add.
type memoryCount struct {
	value int64 // running sum of every delta passed to Add
}

// Add accumulates delta into the running total held in value.
func (c *memoryCount) Add(delta int64) {
	c.value = c.value + delta
}

// Compile-time assertion that *memoryCount satisfies metrics.Count.
var _ metrics.Count = &memoryCount{}

// TestRateLimited exercises the history returned by NewRateLimited: its
// initial empty state, Write/Next/Last round-tripping, rejection of writes
// once the limit is exhausted (with each rejection added to the supplied
// metrics.Count), replenishment of capacity over time, and clearing of
// state on Close.
func TestRateLimited(t *testing.T) {
	addr, err := stream.NewAddress(stream.Namespace("n"), 1, nil)
	require.NoError(t, err)

	t.Run("should initialize", func(t *testing.T) {
		dummyCount := &memoryCount{}
		h := NewRateLimited(addr, 1, time.Minute, dummyCount)
		assert.Equal(t, addr, h.Address())
		src, pos := h.Last()
		assert.Equal(t, stream.None, src)
		assert.Equal(t, stream.Origin, pos)
		msg, ok := h.Next(stream.None, stream.Origin)
		assert.Nil(t, msg)
		assert.False(t, ok)
		assert.Zero(t, dummyCount.value)
	})

	t.Run("should report last value", func(t *testing.T) {
		dummyCount := &memoryCount{}
		from := stream.SourceID(1)
		seg := stream.Segment{Start: stream.Origin, End: stream.Position(1)}
		msg := stream.NewMessage(addr, from, seg, []byte("hello"))
		h := NewRateLimited(addr, 1, time.Minute, dummyCount)
		assert.NoError(t, h.Write(msg))
		out, ok := h.Next(stream.None, stream.Origin)
		assert.Equal(t, msg, out)
		assert.True(t, ok)
		src, pos := h.Last()
		assert.Equal(t, from, src)
		assert.Equal(t, seg.End, pos)
		assert.Zero(t, dummyCount.value)
	})

	t.Run("should reject new values once over limit", func(t *testing.T) {
		dummyCount := &memoryCount{}
		from := stream.SourceID(1)
		seg := stream.Segment{Start: stream.Origin, End: stream.Position(1)}
		msg := stream.NewMessage(addr, from, seg, []byte("hello"))
		h := NewRateLimited(addr, 5, time.Hour, dummyCount)
		assert.NoError(t, h.Write(msg)) // 4 remain
		assert.NoError(t, h.Write(msg)) // 3 remain
		assert.NoError(t, h.Write(msg)) // 2 remain
		assert.NoError(t, h.Write(msg)) // 1 remain
		assert.NoError(t, h.Write(msg)) // 0 remain
		assert.Equal(t, protocol.ErrRateLimited(5, time.Hour), h.Write(msg))
		assert.Equal(t, protocol.ErrRateLimited(5, time.Hour), h.Write(msg))
		// Make sure we report the limit hits: one per rejected write above.
		// testify's Equal takes (expected, actual) so failure output is labeled correctly.
		assert.Equal(t, int64(2), dummyCount.value)
	})

	t.Run("should allow more uses over time", func(t *testing.T) {
		// check is driven directly with a synthetic clock so the test never
		// sleeps. The comma-ok assertion makes a wrong concrete type fail this
		// subtest instead of panicking and aborting the whole test binary.
		h, isRateLimited := NewRateLimited(addr, 5, 5*time.Millisecond, &memoryCount{}).(*rateLimited)
		require.True(t, isRateLimited, "NewRateLimited should return *rateLimited")
		now := time.Now()
		assert.True(t, h.check(now)) // 4 remain
		assert.True(t, h.check(now)) // 3 remain
		assert.True(t, h.check(now)) // 2 remain

		now = now.Add(3 * time.Millisecond) // +3 milliseconds, refund 0
		assert.True(t, h.check(now))        // 1 remain
		assert.True(t, h.check(now))        // 0 remain
		assert.False(t, h.check(now))

		now = now.Add(3 * time.Millisecond) // +6 total milliseconds, refund 2, 1 already consumed
		assert.True(t, h.check(now))        // 1 remain
		assert.True(t, h.check(now))        // 0 remain
		assert.False(t, h.check(now))
	})

	t.Run("should clear values on close", func(t *testing.T) {
		dummyCount := &memoryCount{}
		from := stream.SourceID(1)
		seg := stream.Segment{Start: stream.Origin, End: stream.Position(1)}
		msg := stream.NewMessage(addr, from, seg, []byte("hello"))
		h := NewRateLimited(addr, 1, time.Minute, dummyCount)
		assert.NoError(t, h.Write(msg))
		assert.NoError(t, h.Close())
		out, ok := h.Next(from, stream.Origin)
		assert.Nil(t, out)
		assert.False(t, ok)
		src, pos := h.Last()
		assert.Equal(t, stream.None, src)
		assert.Equal(t, stream.Origin, pos)
	})
}
