package lambda_test

import (
	"context"
	"errors"
	"net/http"
	"testing"

	"code.justin.tv/amzn/TwirpGoLangAWSTransports/internal/stalkmarket"
	lambda_transport "code.justin.tv/amzn/TwirpGoLangAWSTransports/lambda"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/service/lambda"
	"github.com/stretchr/testify/assert"
	"github.com/twitchtv/twirp"
)

// testLambdaFunctionArn is the function ARN the in-file test helpers hand to
// lambda_transport.NewClient; some tests also use it as the Twirp base URL.
const (
	testLambdaFunctionArn = "arn:aws:lambda:us-west-2:1234567890:function:StalkMarketLambda-LambdaFunction-1234567890/live"
)

// fakeStalkMarketServer is a stateless StalkMarket implementation whose
// methods always succeed and echo back request data.
type fakeStalkMarketServer struct{}

// GetPrice always quotes a price of 3 for whichever vegetable was requested.
func (f *fakeStalkMarketServer) GetPrice(c context.Context, request *stalkmarket.GetPriceRequest) (*stalkmarket.Price, error) {
	quote := new(stalkmarket.Price)
	quote.Price = 3
	quote.Vegetable = request.Vegetable
	return quote, nil
}

// UpdatePrice echoes the requested price and vegetable back as the new Price.
func (f *fakeStalkMarketServer) UpdatePrice(c context.Context, request *stalkmarket.UpdatePriceRequest) (*stalkmarket.Price, error) {
	updated := new(stalkmarket.Price)
	updated.Price = request.Price
	updated.Vegetable = request.Vegetable
	return updated, nil
}

// erroringStalkMarketServer is a StalkMarket implementation whose methods
// always fail with a plain (non-Twirp) error.
type erroringStalkMarketServer struct{}

// GetPrice unconditionally fails with the error text "error".
func (f *erroringStalkMarketServer) GetPrice(c context.Context, request *stalkmarket.GetPriceRequest) (*stalkmarket.Price, error) {
	failure := errors.New("error")
	return nil, failure
}

// UpdatePrice unconditionally fails with the error text "error".
func (f *erroringStalkMarketServer) UpdatePrice(c context.Context, request *stalkmarket.UpdatePriceRequest) (*stalkmarket.Price, error) {
	failure := errors.New("error")
	return nil, failure
}

// newLambdaHTTPClientWithService wraps svc in a Twirp server hosted by a
// lambda_transport.TestServer, and returns a transport Client pointed at that
// test server (named testLambdaFunctionArn) together with the server itself,
// so tests can inspect what it received.
func newLambdaHTTPClientWithService(svc stalkmarket.StalkMarket) (*lambda_transport.Client, *lambda_transport.TestServer) {
	twirpHandler := stalkmarket.NewStalkMarketServer(svc, nil)
	testServer := &lambda_transport.TestServer{Server: twirpHandler}
	transport := lambda_transport.NewClient(testServer, testLambdaFunctionArn)
	return transport, testServer
}

func TestLambdaClientJSONRequestSuccess(t *testing.T) {
	sm := &fakeStalkMarketServer{}

	cl, _ := newLambdaHTTPClientWithService(sm)
	smc := stalkmarket.NewStalkMarketJSONClient("", cl)

	resp, err := smc.UpdatePrice(context.Background(), &stalkmarket.UpdatePriceRequest{
		Price:     32,
		Vegetable: "tamater",
	})

	assert.NoError(t, err)
	assert.Equal(t, int32(32), resp.Price)
	assert.Equal(t, "tamater", resp.Vegetable)
}

func TestLambdaClientJSONRequestSuccessWithLambdaEndpoint(t *testing.T) {
	sm := &fakeStalkMarketServer{}

	cl, _ := newLambdaHTTPClientWithService(sm)
	smc := stalkmarket.NewStalkMarketJSONClient(testLambdaFunctionArn, cl)

	resp, err := smc.UpdatePrice(context.Background(), &stalkmarket.UpdatePriceRequest{
		Price:     32,
		Vegetable: "tamater",
	})

	assert.NoError(t, err)
	assert.Equal(t, int32(32), resp.Price)
	assert.Equal(t, "tamater", resp.Vegetable)
}

func TestLambdaClientJSONRequestFailureWithDifferentEndpoint(t *testing.T) {
	sm := &fakeStalkMarketServer{}

	cl, _ := newLambdaHTTPClientWithService(sm)
	smc := stalkmarket.NewStalkMarketJSONClient("foo"+testLambdaFunctionArn+"/bar", cl)

	_, err := smc.UpdatePrice(context.Background(), &stalkmarket.UpdatePriceRequest{
		Price:     32,
		Vegetable: "tamater",
	})

	assert.Error(t, err)
	assert.Contains(t, err.Error(), "unable to extract a path from URL")
}

// TestHeaderPassing makes sure HTTP headers reach the Lambda. It checks both
// custom headers attached with twirp.WithHTTPRequestHeaders and the
// Content-Type / Accept / Twirp-Version headers the Twirp client sets itself.
// It inspects the single request recorded by the TestServer.
func TestHeaderPassing(t *testing.T) {
	sm := &fakeStalkMarketServer{}

	cl, srv := newLambdaHTTPClientWithService(sm)
	smc := stalkmarket.NewStalkMarketJSONClient("", cl)

	// Some test headers
	header := make(http.Header)
	header.Set("foo", "bar")
	header.Set("fizz", "buzz")
	ctx, err := twirp.WithHTTPRequestHeaders(context.Background(), header)
	if err != nil {
		t.Fatalf("attaching request headers to context: %v", err)
	}

	_, err = smc.UpdatePrice(ctx, &stalkmarket.UpdatePriceRequest{
		Price:     32,
		Vegetable: "tamater",
	})
	assert.NoError(t, err)

	// Stop before indexing if no request was recorded, so the failure is
	// reported as an assertion rather than an index-out-of-range panic.
	if !assert.Len(t, srv.Requests, 1) {
		return
	}
	request := srv.Requests[0]
	assert.Equal(t, "application/json", request.Header.Get("Content-Type"))
	assert.Equal(t, "application/json", request.Header.Get("Accept"))
	// Header.Get returns "" (never nil) for a missing key, so NotNil would be
	// vacuous; require a non-empty value instead.
	assert.NotEmpty(t, request.Header.Get("Twirp-Version"))
	assert.Equal(t, "bar", request.Header.Get("foo"))
	assert.Equal(t, "buzz", request.Header.Get("fizz"))
}

// TestLambdaClientProtobufRequestSuccess sends UpdatePrice through the Lambda
// transport with the Protobuf Twirp client and expects the fake server's
// echoed price and vegetable back.
//
// It previously built a JSON client, which duplicated
// TestLambdaClientJSONRequestSuccess and left the protobuf path untested.
func TestLambdaClientProtobufRequestSuccess(t *testing.T) {
	sm := &fakeStalkMarketServer{}

	cl, _ := newLambdaHTTPClientWithService(sm)
	smc := stalkmarket.NewStalkMarketProtobufClient("", cl)

	resp, err := smc.UpdatePrice(context.Background(), &stalkmarket.UpdatePriceRequest{
		Price:     32,
		Vegetable: "tamater",
	})

	assert.NoError(t, err)
	assert.Equal(t, int32(32), resp.Price)
	assert.Equal(t, "tamater", resp.Vegetable)
}

// TestLambdaClientProtoOnError sends a request through the Protobuf Twirp
// client to a service whose handler returns a plain (non-Twirp) error, and
// expects the caller to see twirp.Internal.
//
// It previously built a JSON client despite its name, so the protobuf error
// path was never exercised.
func TestLambdaClientProtoOnError(t *testing.T) {
	sm := &erroringStalkMarketServer{}

	cl, _ := newLambdaHTTPClientWithService(sm)
	smc := stalkmarket.NewStalkMarketProtobufClient("", cl)

	resp, err := smc.UpdatePrice(context.Background(), &stalkmarket.UpdatePriceRequest{
		Price:     32,
		Vegetable: "tamater",
	})
	assertTwirpError(t, err, twirp.Internal)
	// Matches the other protobuf-client error tests in this file, which also
	// expect a non-nil response value alongside the error.
	assert.NotNil(t, resp)
}

// failingLambdaClient is an empty placeholder type.
// NOTE(review): nothing in this file references it; remove it if no other
// file in package lambda_test does either.
type failingLambdaClient struct{}

// TestLambdaClientOnLambdaError injects an InvalidRequestContentException via
// LambdaInvokeErrorClient.SetError and expects the Protobuf Twirp client to
// report twirp.Internal.
func TestLambdaClientOnLambdaError(t *testing.T) {
	invoker := &lambda_transport.LambdaInvokeErrorClient{}
	transport := lambda_transport.NewClient(invoker, "arn")
	protoClient := stalkmarket.NewStalkMarketProtobufClient("", transport)
	invokeErr := awserr.New(lambda.ErrCodeInvalidRequestContentException, "test invoke error", nil)
	invoker.SetError(invokeErr)

	req := &stalkmarket.UpdatePriceRequest{Vegetable: "tamater", Price: 32}
	got, err := protoClient.UpdatePrice(context.Background(), req)

	assert.NotNil(t, got)
	assertTwirpError(t, err, twirp.Internal)
}

// TestLambdaInvokeErrors checks the Twirp error code produced for each AWS
// Lambda Invoke error code injected through LambdaInvokeErrorClient. A code
// absent from the transport's mapping ("unmapped lambda error") is expected to
// become twirp.Internal. Each mapping runs as its own subtest, so a failure
// names the offending Lambda code.
func TestLambdaInvokeErrors(t *testing.T) {
	invokeErrors := map[string]twirp.ErrorCode{
		lambda.ErrCodeEC2AccessDeniedException:             twirp.Unavailable,
		lambda.ErrCodeEC2ThrottledException:                twirp.Unavailable,
		lambda.ErrCodeEC2UnexpectedException:               twirp.Unavailable,
		lambda.ErrCodeENILimitReachedException:             twirp.Unavailable,
		lambda.ErrCodeInvalidParameterValueException:       twirp.Internal,
		lambda.ErrCodeInvalidRequestContentException:       twirp.Internal,
		lambda.ErrCodeInvalidRuntimeException:              twirp.Unavailable,
		lambda.ErrCodeInvalidSecurityGroupIDException:      twirp.Unavailable,
		lambda.ErrCodeInvalidSubnetIDException:             twirp.Unavailable,
		lambda.ErrCodeInvalidZipFileException:              twirp.Unavailable,
		lambda.ErrCodeKMSAccessDeniedException:             twirp.Unavailable,
		lambda.ErrCodeKMSDisabledException:                 twirp.Unavailable,
		lambda.ErrCodeKMSInvalidStateException:             twirp.Unavailable,
		lambda.ErrCodeKMSNotFoundException:                 twirp.Unavailable,
		lambda.ErrCodeRequestTooLargeException:             twirp.Internal,
		lambda.ErrCodeResourceNotFoundException:            twirp.BadRoute,
		lambda.ErrCodeServiceException:                     twirp.Unavailable,
		lambda.ErrCodeSubnetIPAddressLimitReachedException: twirp.Unavailable,
		lambda.ErrCodeTooManyRequestsException:             twirp.Unavailable,
		lambda.ErrCodeUnsupportedMediaTypeException:        twirp.Internal,
		"unmapped lambda error":                            twirp.Internal,
	}
	ec := &lambda_transport.LambdaInvokeErrorClient{}
	client := lambda_transport.NewClient(ec, "arn")
	smc := stalkmarket.NewStalkMarketProtobufClient("", client)

	// The subtests never call t.Parallel, so t.Run blocks until each one
	// finishes. That makes sharing ec safe (each subtest sets its own error
	// before invoking), and the closure sees the current loop values.
	for lambdaCode, twirpCode := range invokeErrors {
		t.Run(lambdaCode, func(t *testing.T) {
			ec.SetError(awserr.New(lambdaCode, "test invoke error", nil))
			resp, err := smc.UpdatePrice(context.Background(), &stalkmarket.UpdatePriceRequest{
				Price:     32,
				Vegetable: "tamater",
			})
			assert.NotNil(t, resp)
			assertTwirpError(t, err, twirpCode)
		})
	}
}

// TestLambdaFunctionError clears the invoke error and instead injects a
// function error payload via SetFunctionError. It expects twirp.Internal with
// a generic message, not the function's own "Some terrible error occurred"
// text.
func TestLambdaFunctionError(t *testing.T) {
	invoker := &lambda_transport.LambdaInvokeErrorClient{}
	transport := lambda_transport.NewClient(invoker, "arn")
	protoClient := stalkmarket.NewStalkMarketProtobufClient("", transport)

	invoker.SetError(nil)
	fnErr := &lambda_transport.LambdaErrorJSON{
		Message:    "Some terrible error occurred",
		Type:       "string_error",
		StackTrace: []string{"a", "b", "c"},
	}
	invoker.SetFunctionError(fnErr)

	req := &stalkmarket.UpdatePriceRequest{Vegetable: "tamater", Price: 32}
	got, err := protoClient.UpdatePrice(context.Background(), req)

	assert.NotNil(t, got)
	assertTwirpError(t, err, twirp.Internal)
	assertTwirpErrorMessage(t, err, "There was an error executing the Lambda function.")
}

// TestLambdaFunctionName checks that Client.LambdaName reports the name given
// to NewClient.
func TestLambdaFunctionName(t *testing.T) {
	const want = "LAMBDANAME"
	client := lambda_transport.NewClient(nil, want)
	got := client.LambdaName()
	assert.Equal(t, want, got)
}

// assertTwirpError asserts that err is a non-nil twirp.Error carrying
// twirpCode.
//
// It returns early when err is nil or is not a twirp.Error. Otherwise terr
// would be a nil interface and terr.Code() would panic, hiding the real
// assertion failure.
func assertTwirpError(t *testing.T, err error, twirpCode twirp.ErrorCode) {
	t.Helper()
	if !assert.Error(t, err) {
		return
	}
	terr, ok := err.(twirp.Error)
	if !assert.True(t, ok, "expected a twirp.Error, got %T: %v", err, err) {
		return
	}
	assert.Equal(t, twirpCode, terr.Code())
}

// assertTwirpErrorMessage asserts that err is a non-nil twirp.Error whose
// message equals msg.
//
// It returns early when err is nil or is not a twirp.Error. Otherwise terr
// would be a nil interface and terr.Msg() would panic, hiding the real
// assertion failure.
func assertTwirpErrorMessage(t *testing.T, err error, msg string) {
	t.Helper()
	if !assert.Error(t, err) {
		return
	}
	terr, ok := err.(twirp.Error)
	if !assert.True(t, ok, "expected a twirp.Error, got %T: %v", err, err) {
		return
	}
	assert.Equal(t, msg, terr.Msg())
}
