package sqsclient

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"sync"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/client"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
	"github.com/aws/aws-sdk-go/service/sts"
	"github.com/aws/aws-sdk-go/service/sts/stsiface"
	uuid "github.com/gofrs/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	eventbus "code.justin.tv/eventbus/client"
	"code.justin.tv/eventbus/client/internal"
	"code.justin.tv/eventbus/client/internal/testevents"
	"code.justin.tv/eventbus/client/internal/testschema/clock"
	"code.justin.tv/eventbus/client/internal/wire"
	"code.justin.tv/eventbus/client/lowlevel/snsmarshal"
	"code.justin.tv/eventbus/client/publisher"
)

// TestPollerAndShutdown runs a client with one poller against a scripted
// sequence of receives: two batches of two messages, followed by empty
// receives separated by 100ms pauses. It checks three things:
//   - every message reaches the handler, with the retry count reported by
//     RetryCountFromContext;
//   - Shutdown returns no error;
//   - one receipt handle per delivered message was passed to
//     DeleteMessageBatchWithContext.
func TestPollerAndShutdown(t *testing.T) {
	// mu guards handlerCalls and handlerRetryCounts. The handler may run
	// concurrently (DeliverConcurrency is 2 below).
	var mu sync.Mutex
	mux := eventbus.NewMux()
	var handlerCalls []int
	var handlerRetryCounts []int
	clock.RegisterClockUpdateHandler(mux, func(ctx context.Context, header *eventbus.Header, tick *clock.ClockUpdate) error {
		mu.Lock()
		// Deferred so the error return below also releases mu. An early
		// return while holding it would deadlock the assertions that lock mu.
		defer mu.Unlock()
		handlerCalls = append(handlerCalls, int(tick.Time.GetSeconds()))
		retryCount, err := eventbus.RetryCountFromContext(ctx)
		if err != nil {
			return err
		}
		handlerRetryCounts = append(handlerRetryCounts, retryCount)
		return nil
	})

	// Cancelled when the test returns. After Shutdown nobody reads the
	// channel, so without this the filler goroutine would block forever on
	// its next send.
	fillCtx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Fill a channel with a bunch of receive message results
	ch := filler(fillCtx,
		&sqs.ReceiveMessageOutput{
			Messages: []*sqs.Message{
				sqsPayload(testevents.NewClockUpdate(1234)),
				sqsPayload(testevents.NewClockUpdate(1235)),
			},
		},
		&sqs.ReceiveMessageOutput{
			Messages: []*sqs.Message{
				sqsPayload(testevents.NewClockUpdate(1236)),
				sqsPayload(testevents.NewClockUpdate(1237)),
			},
		},
		&sqs.ReceiveMessageOutput{},
		100*time.Millisecond,
		&sqs.ReceiveMessageOutput{},
		100*time.Millisecond,
		&sqs.ReceiveMessageOutput{},
	)
	mock := &mockSQS{ToReceive: ch}

	c, err := New(Config{
		OverrideClient:     mock,
		Dispatcher:         mux.Dispatcher(),
		QueueURL:           "http://localhost/example",
		MinPollers:         1,
		DeliverConcurrency: 2,
	})
	require.NoError(t, err)

	t.Run("Receives some calls", func(t *testing.T) {
		time.Sleep(50 * time.Millisecond)
		mock.mu.Lock()
		assert.True(t, len(mock.ReceiveCalls) > 3)
		mock.mu.Unlock()

		mu.Lock()
		assert.ElementsMatch(t, []int{1234, 1235, 1236, 1237}, handlerCalls)
		assert.ElementsMatch(t, []int{1, 1, 1, 1}, handlerRetryCounts)
		mu.Unlock()
	})

	t.Run("Shutdown without error", func(t *testing.T) {
		err := c.Shutdown()
		assert.NoError(t, err)
	})

	t.Run("All delete batches were called", func(t *testing.T) {
		// Take mock.mu like every other reader of the recorded calls.
		mock.mu.Lock()
		var deleteBatchIds []string
		for _, call := range mock.DeleteBatchCalls {
			for _, entry := range call.Entries {
				deleteBatchIds = append(deleteBatchIds, *entry.ReceiptHandle)
			}
		}
		mock.mu.Unlock()
		assert.Len(t, deleteBatchIds, 4)
	})
}

// TestDeleteBatching feeds 500 single-message receives, each followed by a
// 1ms pause, through the client. Its handler rejects every tick divisible by
// 20. The test then checks that exactly the accepted messages were deleted,
// and that the deletes were grouped into batches: at least 48 and fewer than
// 100 DeleteMessageBatch calls.
func TestDeleteBatching(t *testing.T) {
	const messageCount = 500
	// rejected reports whether the handler fails a tick: every multiple of 20,
	// i.e. a 5% rejection rate.
	rejected := func(seconds int64) bool { return seconds%20 == 0 }

	mux := eventbus.NewMux()
	clock.RegisterClockUpdateHandler(mux, func(ctx context.Context, header *eventbus.Header, tick *clock.ClockUpdate) error {
		if rejected(tick.Time.GetSeconds()) {
			return errors.New("Multiple of 20")
		}
		return nil
	})

	// Script one receive per message, each followed by a 1ms pause. Record
	// the receipt handles the client is expected to delete.
	var wantHandles []string
	var script []interface{}
	for sec := int64(1); sec <= messageCount; sec++ {
		msg := sqsPayload(testevents.NewClockUpdate(sec))
		if !rejected(sec) {
			wantHandles = append(wantHandles, *msg.ReceiptHandle)
		}
		script = append(script,
			&sqs.ReceiveMessageOutput{Messages: []*sqs.Message{msg}},
			1*time.Millisecond,
		)
	}
	// Trail off with ten empty receives, each preceded by a 50ms pause.
	for n := 0; n < 10; n++ {
		script = append(script, 50*time.Millisecond, &sqs.ReceiveMessageOutput{})
	}

	fillCtx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mock := &mockSQS{ToReceive: filler(fillCtx, script...)}

	// deletedHandles snapshots every receipt handle passed to the mock's
	// DeleteMessageBatchWithContext so far.
	deletedHandles := func() []string {
		mock.mu.Lock()
		defer mock.mu.Unlock()
		var out []string
		for _, batch := range mock.DeleteBatchCalls {
			for _, entry := range batch.Entries {
				out = append(out, *entry.ReceiptHandle)
			}
		}
		return out
	}

	c, err := New(Config{
		OverrideClient:     mock,
		Dispatcher:         mux.Dispatcher(),
		QueueURL:           "http://localhost/example",
		MinPollers:         1,
		DeliverConcurrency: 2,
	})
	require.NoError(t, err)

	// Wait for the expected deletes. This typically settles within a couple
	// of polls; the extra attempts allow for clock jitter and busy machines.
	time.Sleep(500 * time.Millisecond)
	for attempt := 0; attempt < 10; attempt++ {
		time.Sleep(50 * time.Millisecond)
		if len(deletedHandles()) >= len(wantHandles) {
			break
		}
	}

	t.Run("Shutdown without error", func(t *testing.T) {
		assert.NoError(t, c.Shutdown())
	})

	t.Run("Received all deletes in fairly small batches", func(t *testing.T) {
		mock.mu.Lock()
		batches := len(mock.DeleteBatchCalls)
		mock.mu.Unlock()
		assert.True(t, batches >= 48)
		assert.True(t, batches < 100)

		// ElementsMatch ignores ordering, so batch boundaries don't matter.
		assert.ElementsMatch(t, wantHandles, deletedHandles())
	})
}

// TestConstruction checks that getSQSClient consults STS only when the CFN
// version checker reports false.
func TestConstruction(t *testing.T) {
	sess := session.Must(session.NewSession())
	stsMock := &mockSTSClient{}
	versionMock := &mockCFNVersionChecker{}

	// The zero-value checker reports (false, nil), so STS should be called once.
	svc, err := getSQSClient(sess, stsMock, versionMock)
	require.NoError(t, err)
	require.NotNil(t, svc)
	assert.Equal(t, 1, stsMock.calls)

	// Once the checker reports true, STS should not be called at all.
	versionMock.shouldReturn(true, nil)
	stsMock.calls = 0
	svc, err = getSQSClient(sess, stsMock, versionMock)
	require.NoError(t, err)
	require.NotNil(t, svc)
	assert.Equal(t, 0, stsMock.calls)
}

//////////////////////////////////////////
// Testing helpers and mocks
//////////////////////////////////////////

// sqsPayload builds an SQS message carrying input, encoded the way these
// tests expect the client to receive it:
//  1. wire-encode input for the staging environment;
//  2. snsmarshal-encode the result;
//  3. place it in the Message field of an snsMessage JSON envelope;
//  4. use that JSON as the SQS body.
//
// A freshly generated UUID is used as both MessageId and ReceiptHandle, and
// ApproximateReceiveCount is set to "1". It panics if any encoding step fails.
func sqsPayload(input internal.Message) *sqs.Message {
	raw, err := wire.DefaultsEncode(input, string(publisher.EnvStaging))
	if err != nil {
		panic(err)
	}

	envelope := snsMessage{Message: snsmarshal.EncodeToString(raw)}
	body, err := json.Marshal(envelope)
	if err != nil {
		panic(err)
	}

	id := uuid.Must(uuid.NewV4()).String()
	msg := &sqs.Message{
		Attributes: map[string]*string{
			"ApproximateReceiveCount": aws.String("1"),
		},
	}
	msg.MessageId = aws.String(id)
	msg.Body = aws.String(string(body))
	msg.ReceiptHandle = aws.String(id)
	return msg
}

// receiveOut is one scripted result for mockSQS.ReceiveMessageWithContext.
// msg and err are returned to the caller as-is. filler builds these from
// bare *sqs.ReceiveMessageOutput or error values.
type receiveOut struct {
	msg *sqs.ReceiveMessageOutput
	err error
}

// mockSQS is a test double for sqsiface.SQSAPI.
//
// This file implements ReceiveMessageWithContext and
// DeleteMessageBatchWithContext on it. Other methods fall through to the
// embedded SQSAPI, which the tests here leave nil, so calling one would panic.
//
// Fields:
//   - mu guards ReceiveCalls and DeleteBatchCalls.
//   - ReceiveCalls and DeleteBatchCalls record copies of each request input,
//     in call order.
//   - ToReceive supplies the scripted receive results (see filler).
type mockSQS struct {
	sqsiface.SQSAPI

	mu               sync.Mutex
	ReceiveCalls     []sqs.ReceiveMessageInput
	DeleteBatchCalls []sqs.DeleteMessageBatchInput
	ToReceive        chan receiveOut
}

// ReceiveMessageWithContext records a copy of input, then waits for the next
// scripted result on ToReceive and returns it.
//
// If ToReceive has been closed, it returns an error. If ctx is done first,
// it returns ctx.Err().
func (s *mockSQS) ReceiveMessageWithContext(ctx aws.Context, input *sqs.ReceiveMessageInput, opts ...request.Option) (*sqs.ReceiveMessageOutput, error) {
	func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.ReceiveCalls = append(s.ReceiveCalls, *input)
	}()

	var next receiveOut
	var open bool
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case next, open = <-s.ToReceive:
	}

	if !open {
		return nil, errors.New("Mock SQS finished")
	}
	return next.msg, next.err
}

// DeleteMessageBatchWithContext records a copy of input and always reports
// success with an empty output.
func (s *mockSQS) DeleteMessageBatchWithContext(ctx aws.Context, input *sqs.DeleteMessageBatchInput, opts ...request.Option) (*sqs.DeleteMessageBatchOutput, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.DeleteBatchCalls = append(s.DeleteBatchCalls, *input)
	return new(sqs.DeleteMessageBatchOutput), nil
}

// filler starts a goroutine that plays back a script of SQS receive results
// on the returned unbuffered channel. The channel is closed when the script
// is exhausted or ctx is cancelled, whichever happens first.
//
// Each script element must be one of:
//   - *sqs.ReceiveMessageOutput: sent as a successful receive result
//   - error: sent as a failed receive result
//   - receiveOut: sent as-is
//   - time.Duration: pause for that long before the next element
//
// The script is validated before the goroutine starts. An element of any
// other type panics in the caller's goroutine. The caller's slice is never
// modified.
func filler(ctx context.Context, input ...interface{}) chan receiveOut {
	// Normalize into a private slice. Writing back into input would mutate
	// the backing array of a slice the caller passed as `xs...`.
	steps := make([]interface{}, 0, len(input))
	for _, v := range input {
		switch v := v.(type) {
		case *sqs.ReceiveMessageOutput:
			steps = append(steps, receiveOut{msg: v})
		case error:
			steps = append(steps, receiveOut{err: v})
		case receiveOut, time.Duration:
			steps = append(steps, v)
		default:
			// Fail synchronously, so the panic surfaces in the calling test
			// rather than in the background goroutine.
			panic(fmt.Sprintf("filler: unsupported script element of type %T: %v", v, v))
		}
	}

	ch := make(chan receiveOut)
	go func() {
		defer close(ch)
		for _, step := range steps {
			switch step := step.(type) {
			case time.Duration:
				timer := time.NewTimer(step)
				select {
				case <-ctx.Done():
					timer.Stop()
					return
				case <-timer.C:
				}
			case receiveOut:
				select {
				case ch <- step:
				case <-ctx.Done():
					return
				}
			}
		}
	}()
	return ch
}

// mockSTSClient is a test double for stsiface.STSAPI. This file implements
// only GetCallerIdentity on it. Other methods fall through to the embedded
// STSAPI, which the tests here leave nil.
//
// calls counts GetCallerIdentity invocations. It is not synchronized.
type mockSTSClient struct {
	stsiface.STSAPI
	calls int
}

// GetCallerIdentity increments m.calls and returns a fixed, fabricated
// identity for account 123456789012. It never returns an error.
func (m *mockSTSClient) GetCallerIdentity(input *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) {
	m.calls++

	identity := new(sts.GetCallerIdentityOutput)
	identity.Account = aws.String("123456789012")
	identity.Arn = aws.String("arn:aws:iam::123456789012:resource")
	identity.UserId = aws.String("123abc")
	return identity, nil
}

// mockCFNVersionChecker is a stub version checker whose Require method
// returns retBool and retErr verbatim. Set them with shouldReturn. The zero
// value makes Require return (false, nil).
type mockCFNVersionChecker struct {
	retBool bool
	retErr  error
}

// shouldReturn sets the values that later calls to Require will return.
func (m *mockCFNVersionChecker) shouldReturn(b bool, err error) {
	m.retBool, m.retErr = b, err
}

// Require ignores its arguments and returns the values configured via
// shouldReturn.
func (m *mockCFNVersionChecker) Require(ctx context.Context, sess client.ConfigProvider, version string) (bool, error) {
	ok, err := m.retBool, m.retErr
	return ok, err
}
