package ingestqueueconsumer

//go:generate mockgen -destination ../mocks/mock_sqsiface/mock_sqsiface.go github.com/aws/aws-sdk-go/service/sqs/sqsiface SQSAPI
//go:generate mockgen -destination ../mocks/mock_client/mock_client.go github.com/aws/aws-sdk-go/aws/client ConfigProvider

import (
	"context"
	"encoding/base64"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/golang/mock/gomock"
	"github.com/golang/protobuf/proto"
	"github.com/pkg/errors"

	"code.justin.tv/dta/rockpaperscissors/internal/mocks/mock_sqsiface"
	"code.justin.tv/dta/rockpaperscissors/internal/testutil"
	pb "code.justin.tv/dta/rockpaperscissors/proto"
)

// MockBlueprintIngestor is a hand-written test double for the blueprint
// ingestor used by IngestQueueConsumer. It records the most recent request
// and returns a configurable error.
type MockBlueprintIngestor struct {
	// lastRequest is the request passed to the most recent Ingest call.
	lastRequest *pb.IngestBlueprintRequest
	// response is set by tests but is not read by Ingest in this file.
	// NOTE(review): looks like a leftover from an earlier interface — confirm
	// before removing.
	response    *pb.IngestBlueprintResponse
	// err is returned verbatim by Ingest.
	err         error
}

// Ingest remembers req as the last request seen and returns the error the
// test configured on the mock (nil by default). The context is ignored.
func (m *MockBlueprintIngestor) Ingest(ctx context.Context, req *pb.IngestBlueprintRequest) error {
	m.lastRequest = req
	configured := m.err
	return configured
}

// MockGitHubStatsIngestor is a hand-written test double for the GitHub stats
// ingestor used by IngestQueueConsumer. It records the most recent request
// and returns a configurable error.
type MockGitHubStatsIngestor struct {
	// lastRequest is the request passed to the most recent Ingest call.
	lastRequest *pb.IngestGitHubStatsRequest
	// err is returned verbatim by Ingest.
	err         error
}

// Ingest remembers req as the last request seen and returns the error the
// test configured on the mock (nil by default). The context is ignored.
func (m *MockGitHubStatsIngestor) Ingest(ctx context.Context, req *pb.IngestGitHubStatsRequest) error {
	m.lastRequest = req
	configured := m.err
	return configured
}

// newConsumer wires an IngestQueueConsumer to a gomock SQS client and the two
// hand-written ingestor mocks, and returns all four so tests can set
// expectations and tweak mock behaviour.
func newConsumer(mockCtrl *gomock.Controller) (*IngestQueueConsumer, *mock_sqsiface.MockSQSAPI, *MockBlueprintIngestor, *MockGitHubStatsIngestor) {
	// Short sleeps keep the consumer loop fast under test.
	const shortSleep = 100 * time.Millisecond

	var (
		sqsAPI     = mock_sqsiface.NewMockSQSAPI(mockCtrl)
		blueprints = &MockBlueprintIngestor{response: &pb.IngestBlueprintResponse{}}
		stats      = &MockGitHubStatsIngestor{}
	)

	return &IngestQueueConsumer{
		stopping:            make(chan chan interface{}),
		queueName:           "queue-name",
		sqs:                 sqsAPI,
		blueprintIngestor:   blueprints,
		gitHubStatsIngestor: stats,
		pollSleepDuration:   shortSleep,
		loopSleepDuration:   shortSleep,
	}, sqsAPI, blueprints, stats
}

// expectationsFuncType is the callback runTestInLoop invokes before starting
// the consumer. It receives the consumer under test, the SQS mock, the
// blueprint ingestor mock, and a channel on which the callback's final mock
// expectation must send a value to signal that the test may stop the consumer.
type expectationsFuncType func(*IngestQueueConsumer, *mock_sqsiface.MockSQSAPI, *MockBlueprintIngestor, chan interface{})

// runTestInLoop is the shared harness for tests that exercise the consumer's
// polling loop. It builds a consumer with mocks, lets expectationsFunc register
// its mock expectations, starts the consumer, waits for the final expectation
// to signal on the stop channel, and then stops the consumer.
//
// The wait is bounded: if the signal never arrives (for example because the
// consumer never makes the expected call) the test fails with a clear message
// instead of hanging until the global `go test` timeout.
func runTestInLoop(t *testing.T, expectationsFunc expectationsFuncType) {
	// Generous upper bound; the mocked calls are expected to fire almost
	// immediately because the consumer's sleep durations are 100ms in tests.
	const stopSignalTimeout = 10 * time.Second

	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	consumer, mockSQS, mockBlueprintIngestor, _ := newConsumer(mockCtrl)
	stopChan := make(chan interface{})
	expectationsFunc(consumer, mockSQS, mockBlueprintIngestor, stopChan)

	err := consumer.Start()
	testutil.AssertNil(t, "Start() should return nil error", err)
	if err != nil {
		// The loop was never started, so no stop signal will ever arrive and
		// there is nothing to stop. Bail out rather than block forever.
		return
	}

	select {
	case <-stopChan:
		// Our last mock expectation has run.
	case <-time.After(stopSignalTimeout):
		// Do not call Stop() here: the loop may be wedged, and blocking on it
		// would defeat the purpose of the timeout.
		t.Fatalf("timed out after %v waiting for the final mock expectation", stopSignalTimeout)
	}
	consumer.Stop()
}

// TestGetQueueUrlError verifies that Start() reports an error when the
// GetQueueUrl call fails.
func TestGetQueueUrlError(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	consumer, sqsAPI, _, _ := newConsumer(ctrl)

	// The only SQS call expected is the queue URL lookup, and it fails.
	lookupInput := &sqs.GetQueueUrlInput{QueueName: aws.String(consumer.queueName)}
	sqsAPI.EXPECT().
		GetQueueUrl(lookupInput).
		Return(nil, errors.New("An Error"))

	// Start() should surface the failed GetQueueUrl() as an error.
	startErr := consumer.Start()
	testutil.AssertNotNil(t, "Start() should return error", startErr)
}

// Test that errors from ReceiveMessage are handled.
func TestHandleReceiveMessageError(t *testing.T) {
	runTestInLoop(t,
		func(consumer *IngestQueueConsumer, mockSQS *mock_sqsiface.MockSQSAPI, mockIngestor *MockBlueprintIngestor, stopChan chan interface{}) {
			gomock.InOrder(
				mockSQS.EXPECT().GetQueueUrl(&sqs.GetQueueUrlInput{
					QueueName: aws.String(consumer.queueName),
				}).Return(&sqs.GetQueueUrlOutput{QueueUrl: aws.String("queue-url")}, nil),

				// Loop has started.

				mockSQS.EXPECT().ReceiveMessage(&sqs.ReceiveMessageInput{
					QueueUrl:            aws.String("queue-url"),
					WaitTimeSeconds:     aws.Int64(consumer.waitTimeSeconds),
					MaxNumberOfMessages: aws.Int64(consumer.messagesReceivedPerLoop),
				}).Return(
					nil, errors.New("An Error")).Do(
					func(*sqs.ReceiveMessageInput) { stopChan <- true }),
			)
		})
}

// Test that a nil ReceiveMessage is handled.
func TestHandleReceiveMessageNilOutput(t *testing.T) {
	runTestInLoop(t,
		func(consumer *IngestQueueConsumer, mockSQS *mock_sqsiface.MockSQSAPI, mockIngestor *MockBlueprintIngestor, stopChan chan interface{}) {
			gomock.InOrder(
				mockSQS.EXPECT().GetQueueUrl(&sqs.GetQueueUrlInput{
					QueueName: aws.String(consumer.queueName),
				}).Return(&sqs.GetQueueUrlOutput{QueueUrl: aws.String("queue-url")}, nil),

				// Loop has started.

				mockSQS.EXPECT().ReceiveMessage(&sqs.ReceiveMessageInput{
					QueueUrl:            aws.String("queue-url"),
					WaitTimeSeconds:     aws.Int64(consumer.waitTimeSeconds),
					MaxNumberOfMessages: aws.Int64(consumer.messagesReceivedPerLoop),
				}).Return(
					nil, nil).Do(
					func(*sqs.ReceiveMessageInput) { stopChan <- true }),
			)
		})
}

// Test that errors from DeleteMessage are handled.
func TestHandleDeleteMessageErr(t *testing.T) {
	runTestInLoop(t,
		func(consumer *IngestQueueConsumer, mockSQS *mock_sqsiface.MockSQSAPI, mockIngestor *MockBlueprintIngestor, stopChan chan interface{}) {
			gomock.InOrder(
				mockSQS.EXPECT().GetQueueUrl(&sqs.GetQueueUrlInput{
					QueueName: aws.String(consumer.queueName),
				}).Return(&sqs.GetQueueUrlOutput{QueueUrl: aws.String("queue-url")}, nil),

				// Loop has started.

				mockSQS.EXPECT().ReceiveMessage(&sqs.ReceiveMessageInput{
					QueueUrl:            aws.String("queue-url"),
					WaitTimeSeconds:     aws.Int64(consumer.waitTimeSeconds),
					MaxNumberOfMessages: aws.Int64(consumer.messagesReceivedPerLoop),
				}).Return(
					&sqs.ReceiveMessageOutput{Messages: []*sqs.Message{
						{
							MessageId:     aws.String("message-id"),
							ReceiptHandle: aws.String("receipt-handle"),
							// empty string is a validly encoded request protobuf
							Body: aws.String(""),
						},
					}}, nil),
				mockSQS.EXPECT().DeleteMessage(&sqs.DeleteMessageInput{
					QueueUrl:      aws.String("queue-url"),
					ReceiptHandle: aws.String("receipt-handle"),
				}).Return(nil, errors.New("An Error")).Do(
					func(*sqs.DeleteMessageInput) { stopChan <- true }),
			)
		})
}

// Test that requests with bad base64 encoding are handled.
func TestHandleBadBase64(t *testing.T) {
	runTestInLoop(t,
		func(consumer *IngestQueueConsumer, mockSQS *mock_sqsiface.MockSQSAPI, mockIngestor *MockBlueprintIngestor, stopChan chan interface{}) {
			gomock.InOrder(
				mockSQS.EXPECT().GetQueueUrl(&sqs.GetQueueUrlInput{
					QueueName: aws.String(consumer.queueName),
				}).Return(&sqs.GetQueueUrlOutput{QueueUrl: aws.String("queue-url")}, nil),

				// Loop has started.

				mockSQS.EXPECT().ReceiveMessage(&sqs.ReceiveMessageInput{
					QueueUrl:            aws.String("queue-url"),
					WaitTimeSeconds:     aws.Int64(consumer.waitTimeSeconds),
					MaxNumberOfMessages: aws.Int64(consumer.messagesReceivedPerLoop),
				}).Return(
					&sqs.ReceiveMessageOutput{Messages: []*sqs.Message{
						{
							MessageId:     aws.String("message-id"),
							ReceiptHandle: aws.String("receipt-handle"),
							Body:          aws.String("blah blah"),
						},
					}}, nil).Do(
					func(*sqs.ReceiveMessageInput) { stopChan <- true }),
			)
		})
}

// Test that requests with bad proto marshaling are handled.
func TestHandleBadProtobuf(t *testing.T) {
	runTestInLoop(t,
		func(consumer *IngestQueueConsumer, mockSQS *mock_sqsiface.MockSQSAPI, mockIngestor *MockBlueprintIngestor, stopChan chan interface{}) {
			gomock.InOrder(
				mockSQS.EXPECT().GetQueueUrl(&sqs.GetQueueUrlInput{
					QueueName: aws.String(consumer.queueName),
				}).Return(&sqs.GetQueueUrlOutput{QueueUrl: aws.String("queue-url")}, nil),

				// Loop has started.

				mockSQS.EXPECT().ReceiveMessage(&sqs.ReceiveMessageInput{
					QueueUrl:            aws.String("queue-url"),
					WaitTimeSeconds:     aws.Int64(consumer.waitTimeSeconds),
					MaxNumberOfMessages: aws.Int64(consumer.messagesReceivedPerLoop),
				}).Return(
					&sqs.ReceiveMessageOutput{Messages: []*sqs.Message{
						{
							MessageId:     aws.String("message-id"),
							ReceiptHandle: aws.String("receipt-handle"),
							Body: aws.String(
								base64.StdEncoding.EncodeToString([]byte("blah blah"))),
						},
					}}, nil).Do(
					func(*sqs.ReceiveMessageInput) { stopChan <- true }),
			)
		})
}

// Test that responses with no requests is handled.
func TestHandleNoMessages(t *testing.T) {
	runTestInLoop(t,
		func(consumer *IngestQueueConsumer, mockSQS *mock_sqsiface.MockSQSAPI, mockIngestor *MockBlueprintIngestor, stopChan chan interface{}) {
			gomock.InOrder(
				mockSQS.EXPECT().GetQueueUrl(&sqs.GetQueueUrlInput{
					QueueName: aws.String(consumer.queueName),
				}).Return(&sqs.GetQueueUrlOutput{QueueUrl: aws.String("queue-url")}, nil),

				// Loop has started.

				mockSQS.EXPECT().ReceiveMessage(&sqs.ReceiveMessageInput{
					QueueUrl:            aws.String("queue-url"),
					WaitTimeSeconds:     aws.Int64(consumer.waitTimeSeconds),
					MaxNumberOfMessages: aws.Int64(consumer.messagesReceivedPerLoop),
				}).Return(
					&sqs.ReceiveMessageOutput{}, nil).Do(
					func(*sqs.ReceiveMessageInput) { stopChan <- true }),
			)
		})
}

// TestHandleIngestBlueprintErr verifies that the consumer loop copes with the
// blueprint ingestor returning an error. No DeleteMessage call is expected.
func TestHandleIngestBlueprintErr(t *testing.T) {
	runTestInLoop(t,
		func(consumer *IngestQueueConsumer, mockSQS *mock_sqsiface.MockSQSAPI, mockIngestor *MockBlueprintIngestor, stopChan chan interface{}) {
			// Wrap the blueprint request in the IngestRequest envelope, exactly
			// as TestHandleIngestBlueprintSuccess does. A bare
			// IngestBlueprintRequest would presumably not decode into a
			// blueprint request on the consumer side, so the ingestor (and the
			// error configured below) would never be reached.
			// NOTE(review): assumes the consumer makes no further SQS call
			// after an ingest error — confirm against the consumer
			// implementation.
			req, err := proto.Marshal(&pb.IngestRequest{
				Request: &pb.IngestRequest_IngestBlueprintRequest{
					IngestBlueprintRequest: &pb.IngestBlueprintRequest{
						Source: &pb.SourceFilePath{
							Path: proto.String("my.blueprint"),
						},
					},
				},
			})
			if err != nil {
				t.Fatalf("Failed to marshal protobuf: %v", err)
			}
			body := base64.StdEncoding.EncodeToString(req)

			// Make the ingestor fail for every request.
			mockIngestor.response = nil
			mockIngestor.err = errors.New("An Error")

			gomock.InOrder(
				mockSQS.EXPECT().GetQueueUrl(&sqs.GetQueueUrlInput{
					QueueName: aws.String(consumer.queueName),
				}).Return(&sqs.GetQueueUrlOutput{QueueUrl: aws.String("queue-url")}, nil),

				// Loop has started.

				mockSQS.EXPECT().ReceiveMessage(&sqs.ReceiveMessageInput{
					QueueUrl:            aws.String("queue-url"),
					WaitTimeSeconds:     aws.Int64(consumer.waitTimeSeconds),
					MaxNumberOfMessages: aws.Int64(consumer.messagesReceivedPerLoop),
				}).Return(&sqs.ReceiveMessageOutput{Messages: []*sqs.Message{
					{
						MessageId:     aws.String("message-id"),
						ReceiptHandle: aws.String("receipt-handle"),
						Body:          aws.String(body),
					},
				}}, nil).Do(
					// Making the poll call is what releases the harness.
					func(*sqs.ReceiveMessageInput) { stopChan <- true }),
			)
		})
}

// Test successful code path.
func TestHandleIngestBlueprintSuccess(t *testing.T) {
	runTestInLoop(t,
		func(consumer *IngestQueueConsumer, mockSQS *mock_sqsiface.MockSQSAPI, mockIngestor *MockBlueprintIngestor, stopChan chan interface{}) {
			req, err := proto.Marshal(&pb.IngestRequest{
				Request: &pb.IngestRequest_IngestBlueprintRequest{
					IngestBlueprintRequest: &pb.IngestBlueprintRequest{
						Source: &pb.SourceFilePath{
							Path: proto.String("my.blueprint"),
						},
					},
				},
			})
			if err != nil {
				t.Fatalf("Failed to marshal protobuf: %v", err)
			}
			body := base64.StdEncoding.EncodeToString(req)

			gomock.InOrder(
				mockSQS.EXPECT().GetQueueUrl(&sqs.GetQueueUrlInput{
					QueueName: aws.String(consumer.queueName),
				}).Return(&sqs.GetQueueUrlOutput{QueueUrl: aws.String("queue-url")}, nil),

				// Loop has started.

				mockSQS.EXPECT().ReceiveMessage(&sqs.ReceiveMessageInput{
					QueueUrl:            aws.String("queue-url"),
					WaitTimeSeconds:     aws.Int64(consumer.waitTimeSeconds),
					MaxNumberOfMessages: aws.Int64(consumer.messagesReceivedPerLoop),
				}).Return(&sqs.ReceiveMessageOutput{Messages: []*sqs.Message{
					{
						MessageId:     aws.String("message-id"),
						ReceiptHandle: aws.String("receipt-handle"),
						Body:          aws.String(body),
					},
				}}, nil),
				mockSQS.EXPECT().DeleteMessage(&sqs.DeleteMessageInput{
					QueueUrl:      aws.String("queue-url"),
					ReceiptHandle: aws.String("receipt-handle"),
				}).Return(&sqs.DeleteMessageOutput{}, nil).Do(
					func(*sqs.DeleteMessageInput) { stopChan <- true }),
			)
		})
}

// Test successful code path.
func TestHandleIngestGithubStatsSuccess(t *testing.T) {
	runTestInLoop(t,
		func(consumer *IngestQueueConsumer, mockSQS *mock_sqsiface.MockSQSAPI, mockIngestor *MockBlueprintIngestor, stopChan chan interface{}) {
			req, err := proto.Marshal(&pb.IngestRequest{
				Request: &pb.IngestRequest_IngestGithubStatsRequest{
					IngestGithubStatsRequest: &pb.IngestGitHubStatsRequest{},
				},
			})
			if err != nil {
				t.Fatalf("Failed to marshal protobuf: %v", err)
			}
			body := base64.StdEncoding.EncodeToString(req)

			gomock.InOrder(
				mockSQS.EXPECT().GetQueueUrl(&sqs.GetQueueUrlInput{
					QueueName: aws.String(consumer.queueName),
				}).Return(&sqs.GetQueueUrlOutput{QueueUrl: aws.String("queue-url")}, nil),

				// Loop has started.

				mockSQS.EXPECT().ReceiveMessage(&sqs.ReceiveMessageInput{
					QueueUrl:            aws.String("queue-url"),
					WaitTimeSeconds:     aws.Int64(consumer.waitTimeSeconds),
					MaxNumberOfMessages: aws.Int64(consumer.messagesReceivedPerLoop),
				}).Return(&sqs.ReceiveMessageOutput{Messages: []*sqs.Message{
					{
						MessageId:     aws.String("message-id"),
						ReceiptHandle: aws.String("receipt-handle"),
						Body:          aws.String(body),
					},
				}}, nil),
				mockSQS.EXPECT().DeleteMessage(&sqs.DeleteMessageInput{
					QueueUrl:      aws.String("queue-url"),
					ReceiptHandle: aws.String("receipt-handle"),
				}).Return(&sqs.DeleteMessageOutput{}, nil).Do(
					func(*sqs.DeleteMessageInput) { stopChan <- true }),
			)
		})
}
