package shared

import (
	"testing"

	"code.justin.tv/extensions/configuration/services/main/data/model"
	"code.justin.tv/extensions/configuration/services/main/data/model/bad"
	"code.justin.tv/extensions/configuration/services/main/protocol"
	"github.com/stretchr/testify/assert"
)

// mockCommonPromise is a hand-rolled stand-in for the promise handed to
// CommonPromise in these tests. Every accessor reports the canned fields
// verbatim, so each case controls blocking, payload and error independently.
type mockCommonPromise struct {
	block bool          // reported by WouldBlock
	data  *model.Common // payload returned by Get and Common
	err   error         // error returned by Get and Common
}

// WouldBlock reports the configured block flag.
func (m *mockCommonPromise) WouldBlock() bool {
	return m.block
}

// Get returns the configured payload and error. The payload is converted to
// interface{}, so a nil data yields an interface holding a typed nil pointer.
func (m *mockCommonPromise) Get() (interface{}, error) {
	return m.Common()
}

// Common returns the configured payload and error unchanged.
func (m *mockCommonPromise) Common() (*model.Common, error) {
	return m.data, m.err
}

// TestCommonPromise checks that CommonPromise reports the blocking state of
// the wrapped promise and surfaces its payload and error through both Get
// and Common.
func TestCommonPromise(t *testing.T) {
	t.Run("should handle unset correctly", func(t *testing.T) {
		pending := CommonPromise(&mockCommonPromise{block: true})
		assert.True(t, pending.WouldBlock())
		assert.IsType(t, &recordPromise{}, pending.ForSegment(protocol.GlobalType))
	})

	expected := model.NewCommon("env", "extID")

	// Each resolved case: a nil want means "expect a nil payload", and a nil
	// wantErr means "expect no error".
	cases := []struct {
		name    string
		source  *mockCommonPromise
		want    *model.Common
		wantErr error
	}{
		{
			name:   "should handle set correctly",
			source: &mockCommonPromise{data: expected},
			want:   expected,
		},
		{
			name:   "should handle empty correctly",
			source: &mockCommonPromise{},
		},
		{
			name:    "should handle errored correctly",
			source:  &mockCommonPromise{err: bad.ErrExpected},
			wantErr: bad.ErrExpected,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			promise := CommonPromise(tc.source)
			assert.False(t, promise.WouldBlock())

			// verify applies the case's expectations to one accessor's result.
			verify := func(got interface{}, err error) {
				if tc.want == nil {
					assert.Nil(t, got)
				} else {
					assert.Equal(t, tc.want, got)
				}

				if tc.wantErr == nil {
					assert.NoError(t, err)
				} else {
					assert.Equal(t, tc.wantErr, err)
				}
			}

			verify(promise.Get())

			common, err := promise.Common()
			verify(common, err)
		})
	}
}

// TestCommonPromise_Find checks segment lookups on a CommonPromise: the
// global segment yields the extension ID, a non-global segment is rejected,
// and an upstream error is propagated.
func TestCommonPromise_Find(t *testing.T) {
	baseline := model.NewCommon("env", "extID")

	// expectNoRecord asserts that a Record() lookup produced neither value
	// and failed with exactly want. Everything is taken as interface{} so the
	// comparisons match calling the assertions directly on the results.
	expectNoRecord := func(t *testing.T, value, record, err, want interface{}) {
		t.Helper()
		assert.Nil(t, value)
		assert.Nil(t, record)
		assert.Equal(t, want, err)
	}

	t.Run("global value", func(t *testing.T) {
		promise := CommonPromise(&mockCommonPromise{data: baseline})
		segment := protocol.GlobalType
		testValue(t, baseline.ExtensionID, "", segment, promise.ForSegment(segment))
	})

	t.Run("invalid segment", func(t *testing.T) {
		promise := CommonPromise(&mockCommonPromise{data: baseline})
		value, record, err := promise.ForSegment(protocol.DeveloperType).Record()
		expectNoRecord(t, value, record, err, protocol.ErrIllegalSegmentChannel(protocol.DeveloperType, ""))
	})

	t.Run("error", func(t *testing.T) {
		promise := CommonPromise(&mockCommonPromise{err: bad.ErrExpected})
		value, record, err := promise.ForSegment(protocol.GlobalType).Record()
		expectNoRecord(t, value, record, err, bad.ErrExpected)
	})
}
