package stream

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// testDesc is a minimal MessageDescription stub. The tests hand it to
// Tracker.Accept to supply a fixed source and segment.
type testDesc struct {
	src SourceID // value returned by Source
	at  Segment  // value returned by At
}

// Compile-time guarantee that *testDesc implements MessageDescription.
var _ MessageDescription = (*testDesc)(nil)

// Address returns nil; these tests never inspect the address.
func (d *testDesc) Address() Address {
	return nil
}

// Source returns the stubbed source ID.
func (d *testDesc) Source() SourceID {
	return d.src
}

// At returns the stubbed segment.
func (d *testDesc) At() Segment {
	return d.at
}

// TestTracker covers Tracker's zero value, nil-receiver safety, position
// increments (sequential and concurrent), monotonicity within a single
// source, and acceptance of updates that switch source.
func TestTracker(t *testing.T) {
	t.Run("should have sane initial values", func(t *testing.T) {
		var tracker Tracker
		assertTrackerAt(t, &tracker, None, Origin)

		tracker2 := CreateTracker(None+1, Origin+1)
		src, pos := tracker2.Current()
		assert.Equal(t, None+1, src)
		assert.Equal(t, Origin+1, pos)
	})

	t.Run("should allow nil reads/writes with sane defaults", func(t *testing.T) {
		tracker := (*Tracker)(nil)
		assertTrackerAt(t, tracker, None, Origin)

		src, pos := tracker.Next()
		assert.Equal(t, None, src)
		assert.Equal(t, Origin, pos)

		// Writes through a nil tracker must report failure rather than panic.
		assert.False(t, tracker.Set(src, pos+1))
		assert.False(t, tracker.SetSource(src))

		src, ok := tracker.SetPosition(pos + 1)
		assert.Equal(t, None, src)
		assert.False(t, ok)
	})

	t.Run("should increment position on Next", func(t *testing.T) {
		from := None + 1
		var tracker Tracker
		assert.True(t, tracker.SetSource(from))
		assertTrackerAt(t, &tracker, from, Origin)

		src, pos := tracker.Next()
		assert.Equal(t, from, src)
		assert.Equal(t, Origin+1, pos)

		src, pos = tracker.Next()
		assert.Equal(t, from, src)
		assert.Equal(t, Origin+2, pos)
	})

	t.Run("should not lose increments under concurrent Next calls", func(t *testing.T) {
		const workers = 64
		from := None + 1
		var tracker Tracker
		assert.True(t, tracker.SetSource(from))

		// Each goroutine performs exactly one Next; run with -race to also
		// catch unsynchronized access inside Tracker.
		done := make(chan struct{})
		for i := 0; i < workers; i++ {
			go func() {
				tracker.Next()
				done <- struct{}{}
			}()
		}
		for i := 0; i < workers; i++ {
			<-done
		}

		// If every increment landed, the position advanced exactly `workers` times.
		assertTrackerAt(t, &tracker, from, Origin+workers)
	})

	t.Run("should only accept increasing values for a single source", func(t *testing.T) {
		from := None + 1
		pos1 := Origin + 1
		pos2 := Origin + 2
		pos3 := Origin + 3
		var tracker Tracker

		assert.True(t, tracker.Set(from, pos1))
		assertTrackerAt(t, &tracker, from, pos1)

		assert.True(t, tracker.Set(from, pos2))
		assertTrackerAt(t, &tracker, from, pos2)

		// Moving backwards for the same source is rejected and leaves state intact.
		assert.False(t, tracker.Set(from, pos1))
		assertTrackerAt(t, &tracker, from, pos2)

		// Repeating the current position is rejected as well.
		assert.False(t, tracker.Set(from, pos2))
		assertTrackerAt(t, &tracker, from, pos2)

		src, ok := tracker.SetPosition(pos3)
		assert.True(t, ok)
		assert.Equal(t, from, src)
		assertTrackerAt(t, &tracker, from, pos3)

		// A rejected SetPosition reports None as the source.
		src, ok = tracker.SetPosition(pos3)
		assert.False(t, ok)
		assert.Equal(t, None, src)
		assertTrackerAt(t, &tracker, from, pos3)
	})

	t.Run("should accept any value from changing sources", func(t *testing.T) {
		from1 := None + 1
		from2 := None + 2
		pos1 := Origin + 1
		var tracker Tracker

		assert.True(t, tracker.Set(from1, pos1))
		assertTrackerAt(t, &tracker, from1, pos1)

		// Switching source resets the position counter to Origin.
		assert.True(t, tracker.SetSource(from2))
		assertTrackerAt(t, &tracker, from2, Origin)

		assert.True(t, tracker.Set(from1, pos1))
		assertTrackerAt(t, &tracker, from1, pos1)

		// A source change is accepted even though the position does not advance.
		assert.True(t, tracker.Set(from2, pos1))
		assertTrackerAt(t, &tracker, from2, pos1)

		// Accept with a description from another source is applied as well.
		desc := &testDesc{from1, Segment{Origin, pos1}}
		assert.True(t, tracker.Accept(desc))
		assertTrackerAt(t, &tracker, from1, pos1)
	})
}

// assertTrackerAt checks that tr.Current() reports wantSrc and wantPos.
// The expectations are interface{} so they are boxed exactly as when they are
// passed straight to assert.Equal.
func assertTrackerAt(t *testing.T, tr *Tracker, wantSrc, wantPos interface{}) {
	t.Helper()
	src, pos := tr.Current()
	assert.Equal(t, wantSrc, src)
	assert.Equal(t, wantPos, pos)
}
