package scheduler

import (
	"testing"

	"code.justin.tv/devhub/e2ml/libs/stream"
	"code.justin.tv/devhub/e2ml/libs/stream/listener"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// stuckListener is a test listener whose reported position never moves:
// Current always answers (stream.None, stream.Origin) no matter what it has
// been sent. The two flags choose the return values of OnDataLost and
// OnDataReceived, so a test can simulate a listener that rejects either kind
// of notification. Field order matters: tests build it positionally.
type stuckListener struct {
	allowLost, allowReceived bool
}

// Current ignores the address and always reports the initial state.
func (s *stuckListener) Current(stream.Address) (src stream.SourceID, pos stream.Position) {
	src, pos = stream.None, stream.Origin
	return src, pos
}

// OnDataLost accepts or rejects a loss notice according to allowLost.
func (s *stuckListener) OnDataLost(stream.MessageDescription) bool {
	return s.allowLost
}

// OnDataReceived accepts or rejects a message according to allowReceived.
func (s *stuckListener) OnDataReceived(stream.Message) bool { return s.allowReceived }

// OnStreamClosed unconditionally reports true.
func (s *stuckListener) OnStreamClosed(stream.Address, error) bool {
	return true
}

// dummySeg is one fabricated history entry: a segment running from
// seg.Start to seg.End, attributed to the source src.
type dummySeg struct {
	src stream.SourceID
	seg stream.Segment
}

// create builds a dummySeg for source src whose segment spans start..end.
func create(src stream.SourceID, start, end stream.Position) dummySeg {
	span := stream.Segment{Start: start, End: end}
	return dummySeg{src: src, seg: span}
}

// dummyHistory is an in-memory stand-in for the history passed to Update:
// a fixed address plus an ordered list of fabricated segments.
// NOTE(review): presumably it satisfies the history interface Update
// expects; the required method set is not visible here.
type dummyHistory struct {
	addr stream.Address
	segs []dummySeg
}

// newDummyHistory returns a history for addr holding segs in the given order.
// The variadic slice is stored as-is (not copied).
func newDummyHistory(addr stream.Address, segs ...dummySeg) *dummyHistory {
	return &dummyHistory{
		addr: addr,
		segs: segs,
	}
}

// TestUpdate drives Update with fabricated dummyHistory contents and checks
// the resulting listener state, the "written" count Update returns, and the
// error it reports.
//
// NOTE(review): every subtest expects written == 0, even when messages are
// delivered to the listener; presumably "written" counts something other than
// listener deliveries — confirm against Update.
func TestUpdate(t *testing.T) {
	// from is a source ID distinct from stream.None, which is the source an
	// empty history and stuckListener report.
	from := stream.None + 1
	// An empty history: the DevNull listener must stay at (None, Origin) and
	// see no received or lost notifications.
	t.Run("should not change listener state if nothing is available", func(t *testing.T) {
		addr, err := stream.NewAddress(stream.Namespace("test"), 1, nil)
		require.NoError(t, err)
		l := listener.NewDevNull()
		written, err := Update(newDummyHistory(addr), stream.Origin, l)
		assert.Zero(t, written)
		assert.NoError(t, err)
		src, pos := l.Current(addr)
		assert.Equal(t, stream.None, src)
		assert.Equal(t, stream.Origin, pos)
		assert.Zero(t, l.Received())
		assert.Zero(t, l.Lost())
	})

	// Three segments from one source, all starting at stream.Origin: each is
	// delivered, the listener ends at the last segment's end (10), and no
	// loss is reported.
	t.Run("should forward keyframes without missing notices", func(t *testing.T) {
		addr, err := stream.NewAddress(stream.Namespace("test"), 1, nil)
		require.NoError(t, err)
		l := listener.NewDevNull()
		written, err := Update(newDummyHistory(addr,
			create(from, stream.Origin, stream.Position(1)),
			create(from, stream.Origin, stream.Position(3)),
			create(from, stream.Origin, stream.Position(10)),
		), stream.Origin, l)
		assert.Zero(t, written)
		assert.NoError(t, err)
		src, pos := l.Current(addr)
		assert.Equal(t, from, src)
		assert.Equal(t, stream.Position(10), pos)
		assert.Equal(t, 3, l.Received())
		assert.Zero(t, l.Lost())
	})

	// Segments Origin..1, 2..3, 3..4: all three are delivered and exactly one
	// loss is expected (presumably for the gap between positions 1 and 2).
	t.Run("should forward deltas with missing notices", func(t *testing.T) {
		addr, err := stream.NewAddress(stream.Namespace("test"), 1, nil)
		require.NoError(t, err)
		l := listener.NewDevNull()
		written, err := Update(newDummyHistory(addr,
			create(from, stream.Origin, stream.Position(1)),
			create(from, stream.Position(2), stream.Position(3)),
			create(from, stream.Position(3), stream.Position(4)),
		), stream.Origin, l)
		assert.NoError(t, err)
		assert.Zero(t, written)
		src, pos := l.Current(addr)
		assert.Equal(t, from, src)
		assert.Equal(t, stream.Position(4), pos)
		assert.Equal(t, 3, l.Received())
		assert.Equal(t, 1, l.Lost())
	})

	// Same history, but Update is given Position(6), past the history's last
	// end (4): the listener is expected to end at 6 with a second loss.
	t.Run("should append a missing notice when appropriate", func(t *testing.T) {
		addr, err := stream.NewAddress(stream.Namespace("test"), 1, nil)
		require.NoError(t, err)
		l := listener.NewDevNull()
		written, err := Update(newDummyHistory(addr,
			create(from, stream.Origin, stream.Position(1)),
			create(from, stream.Position(2), stream.Position(3)),
			create(from, stream.Position(3), stream.Position(4)),
		), stream.Position(6), l)
		assert.NoError(t, err)
		assert.Zero(t, written)
		src, pos := l.Current(addr)
		assert.Equal(t, from, src)
		assert.Equal(t, stream.Position(6), pos)
		assert.Equal(t, 3, l.Received())
		assert.Equal(t, 2, l.Lost())
	})

	// stuckListener accepts every notification, but its Current never
	// advances, so Update is expected to fail with stream.ErrNoProgress.
	// NOTE(review): assert.Equal requires the exact sentinel; if Update ever
	// wraps the error this would need errors.Is / assert.ErrorIs.
	t.Run("should return immediately if no progress is reported", func(t *testing.T) {
		addr, err := stream.NewAddress(stream.Namespace("test"), 1, nil)
		require.NoError(t, err)
		l := &stuckListener{true, true}
		written, err := Update(newDummyHistory(addr, create(from, stream.Origin, stream.Position(1))), stream.Origin, l)
		assert.Equal(t, stream.ErrNoProgress, err)
		assert.Zero(t, written)
	})

	// The listener rejects loss notices (allowLost=false). The only segment
	// starts at Position(1), presumably leaving a gap from Origin that
	// triggers OnDataLost. Update is expected to stop without an error.
	t.Run("should abort if failure is reported on loss", func(t *testing.T) {
		addr, err := stream.NewAddress(stream.Namespace("test"), 1, nil)
		require.NoError(t, err)
		l := &stuckListener{false, true}
		written, err := Update(newDummyHistory(addr, create(from, stream.Position(1), stream.Position(2))), stream.Origin, l)
		assert.NoError(t, err)
		assert.Zero(t, written)
	})

	// The listener rejects received data (allowReceived=false); Update is
	// expected to stop without an error.
	t.Run("should abort if failure is reported on receive", func(t *testing.T) {
		addr, err := stream.NewAddress(stream.Namespace("test"), 1, nil)
		require.NoError(t, err)
		l := &stuckListener{true, false}
		written, err := Update(newDummyHistory(addr, create(from, stream.Origin, stream.Position(1))), stream.Origin, l)
		assert.NoError(t, err)
		assert.Zero(t, written)
	})
}

// Address reports the address the history was created with.
func (d *dummyHistory) Address() stream.Address {
	return d.addr
}

// Write discards the message and always succeeds.
func (d *dummyHistory) Write(stream.Message) error {
	return nil
}

// Close does nothing and always succeeds.
func (d *dummyHistory) Close() error {
	return nil
}
// Last reports the source and end position of the final segment, or
// (stream.None, stream.Origin) when the history holds no segments.
func (d *dummyHistory) Last() (stream.SourceID, stream.Position) {
	if len(d.segs) == 0 {
		return stream.None, stream.Origin
	}
	final := d.segs[len(d.segs)-1]
	return final.src, final.seg.End
}

// Next scans the segments in insertion order and returns a message for the
// first one that either comes from a source other than src or ends beyond
// pos. It returns (nil, false) when every segment is from src and ends at or
// before pos.
func (d *dummyHistory) Next(src stream.SourceID, pos stream.Position) (stream.Message, bool) {
	for i := range d.segs {
		candidate := d.segs[i]
		sameSource := candidate.src == src
		reached := candidate.seg.End <= pos
		if sameSource && reached {
			continue
		}
		return stream.NewMessage(d.Address(), candidate.src, candidate.seg, nil), true
	}
	return nil, false
}
