package sqsclient

import (
	"context"
	"encoding/json"
	"testing"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/pkg/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	eventbus "code.justin.tv/eventbus/client"
	"code.justin.tv/eventbus/client/internal/sqspoller"
	"code.justin.tv/eventbus/client/internal/testschema/clock"
	"code.justin.tv/eventbus/client/internal/wire"
	"code.justin.tv/eventbus/client/lowlevel/snsmarshal"
	"code.justin.tv/eventbus/client/publisher"
)

// TestHookForwarding checks that a callback wrapped by wrapCallback receives
// this package's mirror HookEvent type when it is invoked with an
// sqspoller hook event, and that the event's data survives the conversion.
func TestHookForwarding(t *testing.T) {
	// received holds the most recent event delivered to the wrapped callback.
	var received HookEvent

	callback := wrapCallback(func(event HookEvent) {
		received = event
	})

	t.Run("Forwardable type works", func(t *testing.T) {
		callback(&sqspoller.HookEventAckError{
			Err:            errors.New("an error"),
			ReceiptHandles: []string{"abc", "def"},
		})

		require.NotNil(t, received)
		assert.NotNil(t, received.Error())
		assert.NotEmpty(t, received.String())

		t.Run("type can be converted to mirror type", func(t *testing.T) {
			converted, ok := received.(*HookEventAckError)
			require.True(t, ok, "expected *HookEventAckError, got %T", received)
			// Compare the whole slice instead of indexing into it, so a dropped
			// or truncated handle list fails the test rather than panicking.
			assert.Equal(t, []string{"abc", "def"}, converted.ReceiptHandles)
		})
	})
}

// TestDecodeHooks verifies how SQSClient.doDeliver uses the hook callback:
//   - a well-formed message is dispatched without emitting any hook event;
//   - malformed SNS envelopes, bad base64, and undecodable payloads each
//     emit a *HookEventDecodeError;
//   - a failing handler emits a *HookEventDispatchError.
//
// Each subtest resets the shared state first, so it can run on its own
// (for example `go test -run 'TestDecodeHooks/Dispatch_error'`) without
// depending on the subtests that precede it.
func TestDecodeHooks(t *testing.T) {
	// received collects every hook event the client under test emits.
	var received []HookEvent

	callback := func(event HookEvent) {
		received = append(received, event)
	}

	// handlerErr is returned by the registered clock.Update handler; subtests
	// set it to simulate a failing handler.
	var handlerErr error

	mux := eventbus.NewMux()
	clock.RegisterUpdateHandler(mux, func(ctx context.Context, h *eventbus.Header, u *clock.Update) error {
		return handlerErr
	})

	client := SQSClient{
		callback:    callback,
		dispatcher:  mux.Dispatcher(),
		poller:      &sqspoller.SQSPoller{},
		ctxRequests: context.Background(),
	}

	// reset clears captured events and handler behaviour so every subtest
	// starts from the same state.
	reset := func() {
		received = nil
		handlerErr = nil
	}

	// constructSNS JSON-encodes message into an snsMessage envelope.
	// It panics rather than calling require because it is invoked from
	// subtests, where the outer t must not be used to FailNow.
	constructSNS := func(message string) string {
		buf, err := json.Marshal(snsMessage{Message: message})
		if err != nil {
			panic(err)
		}
		return string(buf)
	}

	// newMessage builds an SQS message with the given ID (nil for none) and
	// body, marked as being on its first receive.
	newMessage := func(id *string, body string) *sqs.Message {
		return &sqs.Message{
			MessageId: id,
			Body:      aws.String(body),
			Attributes: map[string]*string{
				"ApproximateReceiveCount": aws.String("1"),
			},
		}
	}

	clockbytes, err := wire.DefaultsEncode(&clock.ClockUpdate{}, string(publisher.EnvStaging))
	require.NoError(t, err)

	t.Run("correct message dispatch works", func(t *testing.T) {
		reset()
		err := client.doDeliver(newMessage(nil, constructSNS(snsmarshal.EncodeToString(clockbytes))))
		require.NoError(t, err)
		assert.Empty(t, received)
	})

	t.Run("Bad SNS payload", func(t *testing.T) {
		reset()
		err := client.doDeliver(newMessage(aws.String("abcd-id-f"), `{notjson}`))
		require.Error(t, err)
		require.Len(t, received, 1)
		event, ok := received[0].(*HookEventDecodeError)
		require.True(t, ok, "expected *HookEventDecodeError, got %T", received[0])
		assert.Equal(t, `Decoding message "abcd-id-f": could not parse SNS message body from SQS message: invalid character 'n' looking for beginning of object key string`, event.String())
	})

	t.Run("Bad base64", func(t *testing.T) {
		reset()
		err := client.doDeliver(newMessage(aws.String("abcd-id-g"), constructSNS("z!b")))
		require.Error(t, err)
		require.Len(t, received, 1)
		event, ok := received[0].(*HookEventDecodeError)
		require.True(t, ok, "expected *HookEventDecodeError, got %T", received[0])
		assert.Equal(t, `Decoding message "abcd-id-g": could not decode event bus payload in SNS message: illegal base64 data at input byte 1`, event.String())
	})

	t.Run("Bad Eventbus Payload", func(t *testing.T) {
		reset()
		err := client.doDeliver(newMessage(aws.String("abcd-id-h"), constructSNS(snsmarshal.EncodeToString([]byte("notprotobuf")))))
		require.Error(t, err)
		require.Len(t, received, 1)
		event, ok := received[0].(*HookEventDecodeError)
		require.True(t, ok, "expected *HookEventDecodeError, got %T", received[0])
		assert.Equal(t, `Decoding message "abcd-id-h": Decode Error: Unknown message version`, event.String())

		_, ok = event.Err.(eventbus.DecodeError)
		assert.True(t, ok, "original error was not a DecodeError")
	})

	t.Run("Dispatch error", func(t *testing.T) {
		reset()
		handlerErr = errors.New("Handler Error")
		err := client.doDeliver(newMessage(aws.String("abcd-id-i"), constructSNS(snsmarshal.EncodeToString(clockbytes))))
		require.Error(t, err)
		require.Len(t, received, 1)
		event, ok := received[0].(*HookEventDispatchError)
		require.True(t, ok, "expected *HookEventDispatchError, got %T", received[0])
		assert.Equal(t, `Dispatching message "abcd-id-i": could not dispatch: Handler Error`, event.String())
	})
}

// causer matches any error that exposes an underlying cause through a
// Cause() method. This is the same method shape github.com/pkg/errors
// (imported by this file) uses for errors.Cause.
//
// NOTE(review): not referenced in this section; presumably used by other
// tests in the package. Remove it if nothing uses it.
type causer interface {
	// Cause returns the error this error wraps.
	Cause() error
}
