package shared

import (
	"testing"

	"code.justin.tv/extensions/configuration/services/main/data/model"
	"code.justin.tv/extensions/configuration/services/main/data/model/bad"
	"code.justin.tv/extensions/configuration/services/main/protocol"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// mockBaseChannel is a canned base passed to ChannelPromise in these tests.
// Each accessor just reports the fixed values stored in its fields.
type mockBaseChannel struct {
	block bool           // reported by WouldBlock
	data  *model.Channel // returned by Get and Channel
	err   error          // returned by Get and Channel
}

// WouldBlock reports the configured block flag.
func (m *mockBaseChannel) WouldBlock() bool {
	return m.block
}

// Get returns the configured channel (boxed as interface{}) and error.
func (m *mockBaseChannel) Get() (interface{}, error) {
	return m.data, m.err
}

// Channel returns the configured channel and error.
func (m *mockBaseChannel) Channel() (*model.Channel, error) {
	return m.data, m.err
}

// testValue asserts that p has already resolved (WouldBlock is false) to a
// record addressed by ext/ch/seg whose record value is nil. It also asserts
// that Get returns the matching single-entry RecordMap. Neither call may
// return an error.
func testValue(t *testing.T, ext, ch string, seg protocol.SegmentType, p model.RecordPromise) {
	t.Helper() // attribute failures to the calling subtest, not this helper
	require.False(t, p.WouldBlock())
	addr := protocol.Address{
		ChannelID:   ch,
		ExtensionID: ext,
		SegmentType: seg,
	}

	a, r, err := p.Record()
	assert.Equal(t, &addr, a)
	assert.Nil(t, r)
	assert.NoError(t, err)

	m, err := p.Get()
	assert.Equal(t, protocol.RecordMap{addr: r}, m)
	// NoError (not Nil) so a typed-nil error is also reported as a failure.
	assert.NoError(t, err)
}

// TestChannelPromise exercises WouldBlock, ForSegment, Get and Channel of a
// ChannelPromise wrapped around a mock base in four states: blocked,
// populated, empty and failed.
func TestChannelPromise(t *testing.T) {
	t.Run("should handle unset correctly", func(t *testing.T) {
		promise := ChannelPromise(&mockBaseChannel{block: true})
		assert.True(t, promise.WouldBlock())

		segment := promise.ForSegment(protocol.GlobalType)
		assert.IsType(t, &recordPromise{}, segment)
	})

	t.Run("should handle set correctly", func(t *testing.T) {
		want := model.NewChannel("env", "extID", "chID")
		promise := ChannelPromise(&mockBaseChannel{data: want})
		assert.False(t, promise.WouldBlock())

		got, err := promise.Get()
		assert.Equal(t, want, got)
		assert.NoError(t, err)

		channel, err := promise.Channel()
		assert.Equal(t, want, channel)
		assert.NoError(t, err)
	})

	t.Run("should handle empty correctly", func(t *testing.T) {
		promise := ChannelPromise(&mockBaseChannel{})
		assert.False(t, promise.WouldBlock())

		got, err := promise.Get()
		assert.Nil(t, got)
		assert.NoError(t, err)

		channel, err := promise.Channel()
		assert.Nil(t, channel)
		assert.NoError(t, err)
	})

	t.Run("should handle errored correctly", func(t *testing.T) {
		promise := ChannelPromise(&mockBaseChannel{err: bad.ErrExpected})
		assert.False(t, promise.WouldBlock())

		got, err := promise.Get()
		assert.Nil(t, got)
		assert.Equal(t, bad.ErrExpected, err)

		channel, err := promise.Channel()
		assert.Nil(t, channel)
		assert.Equal(t, bad.ErrExpected, err)
	})
}

// TestChannelPromise_Find checks how ChannelPromise(...).ForSegment resolves
// records for each segment type.
//
// For broadcaster and developer segments, a channel with an ID resolves to a
// record. For those segments an empty channel ID yields
// ErrIllegalSegmentChannel. The global segment (with an empty channel ID)
// also yields ErrIllegalSegmentChannel. A failing base propagates its error.
//
// Each subtest gets its own copy of the segment. Previously a single `seg`
// variable was reassigned between t.Run calls and captured by every closure.
// That was correct only while subtests ran synchronously; with t.Parallel()
// they would all have observed the last value.
func TestChannelPromise_Find(t *testing.T) {
	// expectIllegal asserts that a channel with an empty ID cannot resolve seg.
	expectIllegal := func(t *testing.T, seg protocol.SegmentType) {
		t.Helper()
		ch := model.NewChannel("env", "extID", "")
		p := ChannelPromise(&mockBaseChannel{data: ch})
		a, r, err := p.ForSegment(seg).Record()
		assert.Nil(t, a)
		assert.Nil(t, r)
		expected := protocol.ErrIllegalSegmentChannel(seg, ch.ChannelID)
		assert.Equal(t, expected, err)
	}

	scoped := []struct {
		validName   string
		invalidName string
		seg         protocol.SegmentType
	}{
		{"broadcaster value", "broadcaster (invalid)", protocol.BroadcasterType},
		{"developer values", "developer (invalid)", protocol.DeveloperType},
	}
	for _, tc := range scoped {
		seg := tc.seg // per-iteration copy; safe even before Go 1.22 loop semantics
		t.Run(tc.validName, func(t *testing.T) {
			ch := model.NewChannel("env", "extID", "chID")
			p := ChannelPromise(&mockBaseChannel{data: ch})
			testValue(t, ch.ExtensionID, ch.ChannelID, seg, p.ForSegment(seg))
		})
		t.Run(tc.invalidName, func(t *testing.T) {
			expectIllegal(t, seg)
		})
	}

	t.Run("invalid segment", func(t *testing.T) {
		expectIllegal(t, protocol.GlobalType)
	})

	t.Run("error", func(t *testing.T) {
		p := ChannelPromise(&mockBaseChannel{err: bad.ErrExpected})
		a, r, err := p.ForSegment(protocol.GlobalType).Record()
		assert.Nil(t, a)
		assert.Nil(t, r)
		assert.Equal(t, bad.ErrExpected, err)
	})
}
