package aws

import (
	"code.justin.tv/event-engineering/goldengate/pkg/aws/backend/backendfakes"
	loggingfakes "code.justin.tv/event-engineering/goldengate/pkg/logging/backend/backendfakes"
	"fmt"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/service/dynamodb"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/aws/aws-sdk-go/service/ssm"
	"github.com/pkg/errors"
	"github.com/stretchr/testify/assert"
	"net/http"
	"net/url"
	"strings"
	"testing"
)

// TestUploadRecording exercises Client.UploadRecording through three cases:
// the S3 upload fails, presigning the GetObject request fails, and both
// succeed so the presigned URL is returned.
//
// The test reassigns the package-level presignedURLExpiry. It therefore does
// not call t.Parallel (a parallel test writing shared package state races with
// any other test that reads it) and restores the original value on return so
// later tests are unaffected.
func TestUploadRecording(t *testing.T) {
	a := assert.New(t)

	originalExpiry := presignedURLExpiry
	defer func() { presignedURLExpiry = originalExpiry }()

	fakeClient := new(backendfakes.FakeClient)
	fakeLogger := new(loggingfakes.FakeLogger)
	client := New(fakeClient, fakeLogger)

	const (
		bucket = "test-bucket"
		key    = "test-key"
	)

	// deref reads an optional SDK string field without panicking on nil.
	deref := func(s *string) string {
		if s == nil {
			return ""
		}
		return *s
	}

	// Case 1: the upload itself fails.
	fakeClient.S3UploadStub = func(input *s3manager.UploadInput, options ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) {
		return nil, errors.New("something broke")
	}

	recordingURL, err := client.UploadRecording(nil, bucket, key)
	a.Empty(recordingURL)
	// Only inspect the message when there is an error; err.Error() on nil would panic.
	if a.Error(err) {
		a.True(strings.HasPrefix(err.Error(), "Failed to upload recording:"), "error message should start with Failed to upload recording: - got `%v`", err.Error())
	}

	// Case 2: the upload succeeds but the presigned URL cannot be generated.
	fakeClient.S3UploadStub = func(input *s3manager.UploadInput, options ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) {
		return nil, nil
	}

	// The stubbed request carries a fixed HTTP URL; the final case asserts that
	// this URL is what UploadRecording hands back once presigning succeeds.
	fakeClient.S3GetObjectRequestStub = func(input *s3.GetObjectInput) (*request.Request, *s3.GetObjectOutput) {
		return &request.Request{
			Operation: &request.Operation{
				Name:       "GetObject",
				HTTPMethod: "GET",
				HTTPPath:   fmt.Sprintf("/%s/%s", deref(input.Bucket), deref(input.Key)),
			},
			HTTPRequest: &http.Request{
				URL: &url.URL{
					Scheme: "https",
					Host:   "google.com",
					Path:   "/test",
				},
			},
		}, nil
	}

	// A zero expiry makes presigning fail.
	presignedURLExpiry = 0

	recordingURL, err = client.UploadRecording(nil, bucket, key)
	a.Empty(recordingURL)
	if a.Error(err) {
		a.True(strings.HasPrefix(err.Error(), "Failed to generate presigned URL:"), "error message should start with Failed to generate presigned URL: - got `%v`", err.Error())
	}

	// Case 3: a positive expiry lets presigning succeed.
	presignedURLExpiry = 1

	recordingURL, err = client.UploadRecording(nil, bucket, key)
	a.NoError(err)
	a.EqualValues("https://google.com/test", recordingURL)
}

// TestDDBPassthrough checks that each DynamoDB wrapper on the client (Query,
// GetItem, PutItem, DeleteItem) calls the matching backend method exactly
// once, forwards the caller's input (checked via its table name), and returns
// the backend's output.
func TestDDBPassthrough(t *testing.T) {
	t.Parallel()
	a := assert.New(t)

	fakeClient := new(backendfakes.FakeClient)
	fakeLogger := new(loggingfakes.FakeLogger)
	client := New(fakeClient, fakeLogger)

	tableName := "DDB_TABLE"

	// testAttributes builds the attribute map the GetItem/PutItem/DeleteItem stubs return.
	testAttributes := func() map[string]*dynamodb.AttributeValue {
		testVal := "test_value"
		return map[string]*dynamodb.AttributeValue{
			"test_attribute": {S: &testVal},
		}
	}

	// checkTestAttributes asserts that the stub's attribute map reached the
	// caller intact, reporting a failure instead of panicking if it did not.
	checkTestAttributes := func(attrs map[string]*dynamodb.AttributeValue) {
		attr := attrs["test_attribute"]
		if a.NotNil(attr) && a.NotNil(attr.S) {
			a.Equal("test_value", *attr.S)
		}
	}

	// Query
	fakeClient.DDBQueryStub = func(input *dynamodb.QueryInput) (*dynamodb.QueryOutput, error) {
		a.EqualValues("DDB_TABLE", *input.TableName)
		count := int64(1337)
		return &dynamodb.QueryOutput{Count: &count}, nil
	}

	queryOutput, err := client.DDBQuery(&dynamodb.QueryInput{TableName: &tableName})
	if a.NoError(err) && a.NotNil(queryOutput) && a.NotNil(queryOutput.Count) {
		a.EqualValues(1337, *queryOutput.Count)
	}
	a.EqualValues(1, fakeClient.DDBQueryCallCount())

	// GetItem
	fakeClient.DDBGetItemStub = func(input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) {
		a.EqualValues("DDB_TABLE", *input.TableName)
		return &dynamodb.GetItemOutput{Item: testAttributes()}, nil
	}

	getOutput, err := client.DDBGetItem(&dynamodb.GetItemInput{TableName: &tableName})
	if a.NoError(err) && a.NotNil(getOutput) {
		checkTestAttributes(getOutput.Item)
	}
	a.EqualValues(1, fakeClient.DDBGetItemCallCount())

	// PutItem
	fakeClient.DDBPutItemStub = func(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) {
		a.EqualValues("DDB_TABLE", *input.TableName)
		return &dynamodb.PutItemOutput{Attributes: testAttributes()}, nil
	}

	putOutput, err := client.DDBPutItem(&dynamodb.PutItemInput{TableName: &tableName})
	if a.NoError(err) && a.NotNil(putOutput) {
		checkTestAttributes(putOutput.Attributes)
	}
	a.EqualValues(1, fakeClient.DDBPutItemCallCount())

	// DeleteItem
	fakeClient.DDBDeleteItemStub = func(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) {
		a.EqualValues("DDB_TABLE", *input.TableName)
		return &dynamodb.DeleteItemOutput{Attributes: testAttributes()}, nil
	}

	deleteOutput, err := client.DDBDeleteItem(&dynamodb.DeleteItemInput{TableName: &tableName})
	if a.NoError(err) && a.NotNil(deleteOutput) {
		checkTestAttributes(deleteOutput.Attributes)
	}
	a.EqualValues(1, fakeClient.DDBDeleteItemCallCount())
}

// TestSQSPassthrough checks that SendSQSMessage calls the backend exactly once,
// forwards the caller's input (checked via its queue URL), and returns the
// backend's output.
func TestSQSPassthrough(t *testing.T) {
	t.Parallel()
	a := assert.New(t)

	fakeClient := new(backendfakes.FakeClient)
	fakeLogger := new(loggingfakes.FakeLogger)
	client := New(fakeClient, fakeLogger)

	queueURL := "QUEUE_URL"

	fakeClient.SendSQSMessageStub = func(input *sqs.SendMessageInput) (*sqs.SendMessageOutput, error) {
		a.EqualValues("QUEUE_URL", *input.QueueUrl)
		messageID := "message_id"
		return &sqs.SendMessageOutput{MessageId: &messageID}, nil
	}

	output, err := client.SendSQSMessage(&sqs.SendMessageInput{QueueUrl: &queueURL})
	// Guard the dereference so a broken wrapper fails the test rather than panicking.
	if a.NoError(err) && a.NotNil(output) && a.NotNil(output.MessageId) {
		a.EqualValues("message_id", *output.MessageId)
	}
	a.EqualValues(1, fakeClient.SendSQSMessageCallCount())
}

// TestSSMPassthrough checks that SSMGetParameters calls the backend exactly
// once, forwards the requested parameter names, and returns the backend's
// output.
func TestSSMPassthrough(t *testing.T) {
	t.Parallel()
	a := assert.New(t)

	fakeClient := new(backendfakes.FakeClient)
	fakeLogger := new(loggingfakes.FakeLogger)
	client := New(fakeClient, fakeLogger)

	paramName := "PARAM_NAME"

	fakeClient.SSMGetParametersStub = func(input *ssm.GetParametersInput) (*ssm.GetParametersOutput, error) {
		// Check the length first so a dropped name list fails instead of panicking.
		if a.Len(input.Names, 1) {
			a.EqualValues("PARAM_NAME", *input.Names[0])
		}

		name := "param_name"
		value := "param_value"
		return &ssm.GetParametersOutput{
			Parameters: []*ssm.Parameter{
				{Name: &name, Value: &value},
			},
		}, nil
	}

	output, err := client.SSMGetParameters(&ssm.GetParametersInput{Names: []*string{&paramName}})
	if a.NoError(err) && a.NotNil(output) && a.Len(output.Parameters, 1) && a.NotNil(output.Parameters[0].Value) {
		a.EqualValues("param_value", *output.Parameters[0].Value)
	}
	a.EqualValues(1, fakeClient.SSMGetParametersCallCount())
}
