package lambdafunc

import (
	"context"
	"testing"
	"time"

	"github.com/aws/aws-lambda-go/events"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	eventbus "code.justin.tv/eventbus/client"
	"code.justin.tv/eventbus/client/internal/testevents"
	"code.justin.tv/eventbus/client/internal/testschema/clock"
)

// TestHandler exercises the SNS Lambda handler returned by NewSNS. It checks
// that malformed invocations are rejected with an error without reaching the
// registered ClockUpdate handler, and that a valid ClockUpdate message is
// dispatched to it with the expected header and payload.
func TestHandler(t *testing.T) {
	// call captures the arguments of one dispatch to the ClockUpdate handler.
	type call struct {
		header *eventbus.Header
		tick   *clock.ClockUpdate
	}

	var calls []call
	mux := eventbus.NewMux()
	clock.RegisterClockUpdateHandler(mux, func(ctx context.Context, hdr *eventbus.Header, tick *clock.ClockUpdate) error {
		calls = append(calls, call{header: hdr, tick: tick})
		return nil
	})

	handle := NewSNS(mux.Dispatcher())

	// invoke wraps the given records in an SNS event and runs the handler.
	// It first discards dispatches recorded by earlier subtests, so every
	// subtest observes only its own and can be run in isolation or in any
	// order.
	invoke := func(records ...events.SNSEventRecord) error {
		calls = nil
		return handle(context.Background(), events.SNSEvent{Records: records})
	}

	t.Run("Errors on zero records", func(t *testing.T) {
		err := invoke()
		require.EqualError(t, err, "Invalid record count 0 != 1")
		assert.Empty(t, calls, "handler must not be dispatched on error")
	})

	t.Run("Errors on >1 records", func(t *testing.T) {
		err := invoke(events.SNSEventRecord{}, events.SNSEventRecord{})
		require.EqualError(t, err, "Invalid record count 2 != 1")
		assert.Empty(t, calls, "handler must not be dispatched on error")
	})

	t.Run("Errors on Base64 decoding", func(t *testing.T) {
		event := basicEvent()
		event.SNS.Message = "!!abz"

		err := invoke(event)
		require.EqualError(t, err, "illegal base64 data at input byte 0")
		assert.Empty(t, calls, "handler must not be dispatched on error")
	})

	t.Run("Errors when invalid UUID", func(t *testing.T) {
		event := basicEvent()
		event.SNS.Message = testevents.InvalidMessageID.String()

		err := invoke(event)
		require.Error(t, err)
		assert.Empty(t, calls, "handler must not be dispatched on error")
	})

	t.Run("Runs with valid message", func(t *testing.T) {
		event := basicEvent()
		event.SNS.Message = testevents.ClockUpdateApril1.String()

		err := invoke(event)
		require.NoError(t, err)
		// Exactly one dispatch is expected for the single valid record.
		require.Len(t, calls, 1)

		got := calls[0]
		assert.Equal(t, testevents.April1UTCSeconds, got.header.CreatedAt.Unix())
		assert.Equal(t, testevents.UUID1(), got.header.MessageID.Bytes())
		assert.Equal(t, int64(1522540800), got.tick.Time.GetSeconds())
	})
}

// basicEvent builds a template SNS event record populated with placeholder
// metadata, the current time as its timestamp, and an empty message body.
// Tests overwrite SNS.Message with the payload they want to exercise.
func basicEvent() events.SNSEventRecord {
	entity := events.SNSEntity{
		SignatureVersion: "1",
		Timestamp:        time.Now(),
		Signature:        "EXAMPLE",
		SigningCertURL:   "EXAMPLE",
		MessageID:        "95df01b4-ee98-5cb9-9903-4c221d41eb5e",
		Message:          "",
	}

	var record events.SNSEventRecord
	record.EventVersion = "1.0"
	record.EventSubscriptionArn = "arn:aws:blah"
	record.EventSource = "aws:sns"
	record.SNS = entity
	return record
}
