package api

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/twitchtv/twirp"

	pubsub "code.justin.tv/extensions/smart/services/smart/clients/pubsub/mocks"
	validator "code.justin.tv/extensions/smart/services/smart/clients/validator/mocks"
	"code.justin.tv/extensions/smart/services/smart/rpc"
	store "code.justin.tv/extensions/smart/services/smart/store/mocks"
)

// TestPostMessage exercises API.PostMessage with each dependency configured to
// fail in turn (missing auth, validator, store, pubsub), with an oversized
// message, and on the happy path.
//
// Every subtest builds its own request via newRequest. This keeps subtests
// isolated: mutating the request in one subtest cannot change the input seen
// by another, regardless of which subtests `go test -run` selects.
func TestPostMessage(t *testing.T) {
	newRequest := func() *smartrpc.PostMessageRequest {
		return &smartrpc.PostMessageRequest{
			Auth: &smartrpc.Token{ExtAuthToken: "test"},
			Topics: &smartrpc.Topics{
				ChannelId:   "test",
				ExtensionId: "test",
				Targets:     []string{"test"},
			},
			Message: &smartrpc.UnsequencedMessage{
				Content:     []string{"test"},
				ContentType: "test",
			},
		}
	}

	t.Run("when no auth is provided", func(t *testing.T) {
		a := &API{Store: &store.AlwaysFail{}}
		req := newRequest()
		req.Auth = nil
		_, err := a.PostMessage(context.Background(), req)
		assert.Error(t, err)
		assert.Equal(t, smartrpc.ErrNoAuthProvided, getErrorType(err))
	})
	t.Run("when auth fails", func(t *testing.T) {
		a := &API{Store: &store.AlwaysFail{}, Validator: validator.AlwaysFail{}}
		_, err := a.PostMessage(context.Background(), newRequest())
		assert.Error(t, err)
		assert.Equal(t, smartrpc.ErrNotAuthorized, getErrorType(err))
	})
	t.Run("when the store is bad", func(t *testing.T) {
		a := &API{Store: &store.AlwaysFail{}, Validator: validator.AlwaysPass{}}
		_, err := a.PostMessage(context.Background(), newRequest())
		assert.Error(t, err)
		assert.Equal(t, smartrpc.ErrMemcacheFail, getErrorType(err))
	})
	t.Run("when the pubsub is bad", func(t *testing.T) {
		a := &API{Store: &store.AlwaysPass{}, Validator: validator.AlwaysPass{}, PubSub: pubsub.AlwaysFail{}}
		_, err := a.PostMessage(context.Background(), newRequest())
		assert.Error(t, err)
		assert.Equal(t, smartrpc.ErrPubSubFail, getErrorType(err))
	})
	t.Run("when everything is good", func(t *testing.T) {
		a := &API{Store: &store.AlwaysPass{}, Validator: validator.AlwaysPass{}, PubSub: pubsub.AlwaysPass{}}
		topics, err := a.PostMessage(context.Background(), newRequest())
		// Only inspect the result when the call succeeded. On error the result
		// may be nil, and dereferencing it would panic instead of reporting a
		// clean test failure.
		if assert.NoError(t, err) {
			assert.Len(t, topics.Targets, 1)
		}
	})
	t.Run("when the message is too long", func(t *testing.T) {
		a := &API{Store: &store.AlwaysPass{}, Validator: validator.AlwaysPass{}, PubSub: pubsub.AlwaysPass{}}
		req := newRequest()
		req.Message.Content = []string{*makeLongString(MessageSizeCap + 1)}
		// The returned value is not asserted on: it carries no meaning when an
		// error is returned, and it may be nil.
		_, err := a.PostMessage(context.Background(), req)
		assert.Error(t, err)
		assert.Equal(t, smartrpc.ErrMessageTooLong, getErrorType(err))
	})
}

// TestCheckSize is a table-driven test of checkSize around its size
// boundaries.
//
// The cases assert that an extension ID absent from the override map is
// accepted up to exactly 5 KiB, and that an ID present in the map is accepted
// up to exactly 10 KiB. One byte over the applicable limit must produce an
// error. The reported size must always equal the content length, whether or
// not an error is returned.
func TestCheckSize(t *testing.T) {
	const (
		kib        = 1024
		regular    = "not overridden"
		overridden = "overridden"
	)
	overrides := map[string]bool{overridden: true}

	cases := []struct {
		name    string
		length  int
		extID   string
		wantErr bool
	}{
		{"When <5k and not on override list", 5*kib - 1, regular, false},
		{"When =5k and not on override list", 5 * kib, regular, false},
		{"When >5k and not on override list", 5*kib + 1, regular, true},
		{"When <5k and on override list", 5*kib - 1, overridden, false},
		{"When >5kb <10k and on override list", 5*kib + 1, overridden, false},
		{"When =10k and on override list", 10 * kib, overridden, false},
		{"When >10k and on override list", 10*kib + 1, overridden, true},
	}

	for _, tc := range cases {
		content := makeLongString(tc.length)
		t.Run(tc.name, func(t *testing.T) {
			size, err := checkSize([]string{*content}, tc.extID, overrides)
			if tc.wantErr {
				assert.NotNil(t, err)
			} else {
				assert.Nil(t, err)
			}
			assert.Equal(t, tc.length, size)
		})
	}
}
// makeLongString returns a pointer to a new string of exactly n NUL (0x00)
// bytes. Each NUL is a single-byte UTF-8 code point, so the string's byte
// length and rune count are both n. This makes it a convenient payload for
// exercising byte-size limits. It panics if n is negative.
func makeLongString(n int) *string {
	// A zeroed []byte converts directly to n NUL bytes. A []rune of the same
	// length would yield the identical string but allocate four times as much.
	s := string(make([]byte, n))
	return &s
}

// getErrorType returns the "error_key" metadata value of err when err
// implements twirp.Error. Any other error, including nil, yields "".
func getErrorType(err error) string {
	switch twErr := err.(type) {
	case twirp.Error:
		return twErr.Meta("error_key")
	default:
		return ""
	}
}
// getMessage returns err's Msg() when err implements twirp.Error. Any other
// error, including nil, yields "".
func getMessage(err error) string {
	var msg string
	if twErr, isTwirp := err.(twirp.Error); isTwirp {
		msg = twErr.Msg()
	}
	return msg
}
