package protocol

import (
	"encoding/json"
	"testing"

	"code.justin.tv/gds/gds/golibs/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestSegment checks the Type and ChannelID accessors, first on the zero
// value of Segment and then on a segmentData built with explicit fields.
func TestSegment(t *testing.T) {
	t.Run("should default to global", func(t *testing.T) {
		var zero Segment
		gotType, gotChannel := zero.Type(), zero.ChannelID()
		assert.Equal(t, GlobalType, gotType)
		assert.Equal(t, "", gotChannel)
	})

	t.Run("should reflect set values", func(t *testing.T) {
		explicit := &segmentData{channelID: "chID", segmentType: BroadcasterType}
		gotType, gotChannel := explicit.Type(), explicit.ChannelID()
		assert.Equal(t, BroadcasterType, gotType)
		assert.Equal(t, "chID", gotChannel)
	})
}

// TestSegment_Address verifies that Address carries the segment's type and
// channel ID into the returned address, together with the extension ID
// passed in.
func TestSegment_Address(t *testing.T) {
	t.Run("should create a global address by default", func(t *testing.T) {
		var zero Segment
		got := zero.Address("extID")
		assert.Equal(t, "extID", got.ExtensionID)
		assert.Equal(t, GlobalType, got.SegmentType)
		assert.Equal(t, "", got.ChannelID)
	})

	t.Run("should pass set values", func(t *testing.T) {
		explicit := &segmentData{channelID: "chID", segmentType: BroadcasterType}
		got := explicit.Address("extID")
		assert.Equal(t, "extID", got.ExtensionID)
		assert.Equal(t, BroadcasterType, got.SegmentType)
		assert.Equal(t, "chID", got.ChannelID)
	})
}

func TestSegment_String(t *testing.T) {
	seg := Global()
	assert.Equal(t, "global:", seg.String())

	seg, _ = Broadcaster("test")
	assert.Equal(t, "broadcaster:test", seg.String())

	seg, _ = Developer("test")
	assert.Equal(t, "developer:test", seg.String())
}

func TestSegment_MarshalJSON(t *testing.T) {
	var innerSeg segmentData
	seg := &innerSeg
	originalSeg := Global()
	bytes, err := originalSeg.MarshalJSON()
	assert.Nil(t, err)
	assert.Equal(t, `{"segment_type":"global"}`, string(bytes))
	assert.Nil(t, seg.UnmarshalJSON([]byte(`{"segment_type":"global"}`)))
	assert.Equal(t, originalSeg, seg)

	originalSeg, _ = Broadcaster("test")
	bytes, err = originalSeg.MarshalJSON()
	assert.Nil(t, err)
	assert.Equal(t, `{"segment_type":"broadcaster","channel_id":"test"}`, string(bytes))
	assert.Nil(t, seg.UnmarshalJSON([]byte(`{"segment_type":"broadcaster","channel_id":"test"}`)))
	assert.Equal(t, originalSeg, seg)

	originalSeg, _ = Developer("test")
	bytes, err = originalSeg.MarshalJSON()
	assert.Nil(t, err)
	assert.Equal(t, `{"segment_type":"developer","channel_id":"test"}`, string(bytes))
	assert.Nil(t, seg.UnmarshalJSON([]byte(`{"segment_type":"developer","channel_id":"test"}`)))
	assert.Equal(t, originalSeg, seg)

	// verify errors are forwarded correctly
	assert.IsType(t, &json.SyntaxError{}, seg.UnmarshalJSON([]byte("{")))
}

// TestSegment_SegmentsForChannel checks the ordering and contents of the
// per-channel segment list, both with and without the common (global)
// segment. It also checks that an empty channel ID is rejected with
// ErrIllegalSegmentChannelCode.
func TestSegment_SegmentsForChannel(t *testing.T) {
	t.Run("should include all per-channel segments", func(t *testing.T) {
		segs, err := SegmentsForChannel("chID", false)
		require.NoError(t, err)
		require.Len(t, segs, 2)
		assert.Equal(t, &segmentData{segmentType: "developer", channelID: "chID"}, segs[0])
		assert.Equal(t, &segmentData{segmentType: "broadcaster", channelID: "chID"}, segs[1])
	})

	t.Run("should include common segments when asked", func(t *testing.T) {
		segs, err := SegmentsForChannel("chID", true)
		require.NoError(t, err)
		require.Len(t, segs, 3)
		assert.Equal(t, &segmentData{segmentType: "developer", channelID: "chID"}, segs[0])
		assert.Equal(t, &segmentData{segmentType: "broadcaster", channelID: "chID"}, segs[1])
		assert.Equal(t, &segmentData{segmentType: "global"}, segs[2])
	})

	t.Run("should error when appropriate", func(t *testing.T) {
		segs, err := SegmentsForChannel("", false)
		// Fail explicitly on a nil error instead of only through an error-code
		// mismatch, which would produce a less obvious failure message.
		require.Error(t, err)
		assert.Empty(t, segs)
		assert.Equal(t, ErrIllegalSegmentChannelCode, errors.GetErrorCode(err))
	})
}
