package sqsbatcher

import (
	"fmt"
	"testing"

	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)

// mockSQSClient is a testify-backed stand-in for the SQS client that
// SQSBatcher sends through. Tests register expected SendMessageBatch calls
// with On(...).Return(...) and verify them with AssertExpectations.
type mockSQSClient struct {
	mock.Mock
}

// SendMessageBatch records the call with its input and returns the output and
// error configured on the matching expectation.
//
// A nil first return value (e.g. Return(nil, someErr)) is allowed and yields a
// nil *sqs.SendMessageBatchOutput; a plain type assertion would panic on the
// nil interface. A non-nil value of the wrong type still panics, so a
// misconfigured expectation is reported loudly rather than hidden.
func (c *mockSQSClient) SendMessageBatch(input *sqs.SendMessageBatchInput) (*sqs.SendMessageBatchOutput, error) {
	args := c.Called(input)
	var out *sqs.SendMessageBatchOutput
	if v := args.Get(0); v != nil {
		out = v.(*sqs.SendMessageBatchOutput)
	}
	return out, args.Error(1)
}

// helper needed because we can't take the address of string literals
// strAddr returns a pointer to a fresh copy of s. Go does not allow taking the
// address of a string literal (&"x" does not compile), and the SQS request and
// response structs built in these tests take *string fields.
func strAddr(s string) *string {
	p := new(string)
	*p = s
	return p
}

// TestSqsBatcherSingleString checks two things. A single message sent below
// MaxLen is buffered without calling SendMessageBatch. An explicit Flush then
// delivers it as a one-entry batch with Id "0".
func TestSqsBatcherSingleString(t *testing.T) {
	client := &mockSQSClient{}
	batcher := SQSBatcher{SQS: client, MaxLen: 10}

	// One message against MaxLen 10: nothing should reach the client yet.
	assert.NoError(t, batcher.SendString("hello"), "single SendString failed")
	client.AssertNotCalled(t, "SendMessageBatch")

	want := &sqs.SendMessageBatchInput{
		Entries: []*sqs.SendMessageBatchRequestEntry{
			{Id: strAddr("0"), MessageBody: strAddr("hello")},
		},
	}
	client.On("SendMessageBatch", want).Return(&sqs.SendMessageBatchOutput{}, nil).Once()

	assert.NoError(t, batcher.Flush(), "Flush failed after single SendString")
	client.AssertExpectations(t)
}

// TestSqsBatcherMultiString checks automatic flushing at MaxLen. With MaxLen 3,
// five SendString calls send "hello 0".."hello 2" as one batch without an
// explicit Flush. "hello 3" and "hello 4" stay buffered until Flush is called.
// Entry Ids are expected to restart at "0" for each batch.
func TestSqsBatcherMultiString(t *testing.T) {
	m := &mockSQSClient{}
	b := SQSBatcher{
		MaxLen: 3,
		SQS:    m,
	}

	m.On("SendMessageBatch", &sqs.SendMessageBatchInput{
		Entries: []*sqs.SendMessageBatchRequestEntry{
			{Id: strAddr("0"), MessageBody: strAddr("hello 0")},
			{Id: strAddr("1"), MessageBody: strAddr("hello 1")},
			{Id: strAddr("2"), MessageBody: strAddr("hello 2")},
		},
	}).Return(&sqs.SendMessageBatchOutput{}, nil).Once()

	for i := 0; i < 5; i++ {
		msgbody := fmt.Sprintf("hello %d", i)
		// SendString can trigger a send once the buffer fills, so its error
		// must be checked; AssertExpectations only proves the call happened.
		assert.NoError(t, b.SendString(msgbody), "SendString(%q) failed", msgbody)
	}

	// The first full batch must already have been sent, before any Flush.
	m.AssertExpectations(t)

	m.On("SendMessageBatch", &sqs.SendMessageBatchInput{
		Entries: []*sqs.SendMessageBatchRequestEntry{
			{Id: strAddr("0"), MessageBody: strAddr("hello 3")},
			{Id: strAddr("1"), MessageBody: strAddr("hello 4")},
		},
	}).Return(&sqs.SendMessageBatchOutput{}, nil).Once()

	assert.NoError(t, b.Flush(), "Flush failed for remaining buffered messages")

	m.AssertExpectations(t)
}

// TestSqsBatcherFailed checks that Flush returns an error when the batch call
// itself succeeds (nil error) but its output lists a failed entry.
func TestSqsBatcherFailed(t *testing.T) {
	m := &mockSQSClient{}
	b := SQSBatcher{
		MaxLen: 10,
		SQS:    m,
	}
	// Both messages fit under MaxLen and should only be buffered. If buffering
	// itself errored, the Flush assertion below could pass for the wrong reason.
	assert.NoError(t, b.SendString("hello"), "SendString(\"hello\") failed")
	assert.NoError(t, b.SendString("thiswillfail"), "SendString(\"thiswillfail\") failed")

	m.On("SendMessageBatch", &sqs.SendMessageBatchInput{
		Entries: []*sqs.SendMessageBatchRequestEntry{
			{Id: strAddr("0"), MessageBody: strAddr("hello")},
			{Id: strAddr("1"), MessageBody: strAddr("thiswillfail")},
		},
	}).Return(&sqs.SendMessageBatchOutput{
		Failed: []*sqs.BatchResultErrorEntry{{Id: strAddr("1"), Code: strAddr("WAT"), Message: strAddr("uh oh")}},
	}, nil).Once()

	err := b.Flush()
	assert.Error(t, err, "Flush didn't fail when batch send returned failed entry")
	m.AssertExpectations(t)
}

// TestSqsBatcherErr checks that Flush propagates an error returned by the
// SendMessageBatch call itself.
func TestSqsBatcherErr(t *testing.T) {
	m := &mockSQSClient{}
	b := SQSBatcher{
		MaxLen: 10,
		SQS:    m,
	}
	// Both messages fit under MaxLen and should only be buffered. If buffering
	// itself errored, the Flush assertion below could pass for the wrong reason.
	assert.NoError(t, b.SendString("hello"), "SendString(\"hello\") failed")
	assert.NoError(t, b.SendString("hello2"), "SendString(\"hello2\") failed")

	m.On("SendMessageBatch", &sqs.SendMessageBatchInput{
		Entries: []*sqs.SendMessageBatchRequestEntry{
			{Id: strAddr("0"), MessageBody: strAddr("hello")},
			{Id: strAddr("1"), MessageBody: strAddr("hello2")},
		},
	}).Return(&sqs.SendMessageBatchOutput{}, fmt.Errorf("failed")).Once()

	err := b.Flush()
	assert.Error(t, err, "Flush didn't fail when batch send returned error")
	m.AssertExpectations(t)
}
