package goretry_test

import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"os"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/request"
	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"
	"github.com/twitchtv/twirp"

	. "code.justin.tv/amzn/GoRetry"
	loggers "code.justin.tv/amzn/TwitchLoggingCommonLoggers"
)

// TestSuite is the `go test` entry point for this package: it routes Gomega
// assertion failures to Ginkgo's Fail and then runs every registered spec.
func TestSuite(t *testing.T) {
	const suiteDescription = "Test GoRetry"

	RegisterFailHandler(Fail)
	RunSpecs(t, suiteDescription)
}

// Both entry points are covered by the same spec tree; the boolean passed to
// retryTest selects RetryFuncWithContext (true) or RetryFunc (false).
var (
	_ = Describe("RetryFuncWithContext", func() { retryTest(true) })

	_ = Describe("RetryFunc", func() { retryTest(false) })
)

func retryTest(withContext bool) {
	var (
		err           error
		function      interface{} // Should be either RetryFunc or RetryFuncWithContext
		options       Options
		ctx           context.Context
		totalDuration time.Duration
	)

	JustBeforeEach(func() {
		start := time.Now()
		if withContext {
			if function == nil {
				err = RetryFuncWithContext(ctx, nil, options)
			} else {
				err = RetryFuncWithContext(ctx, function.(RetryableFuncWithContext), options)
			}
		} else {
			if function == nil {
				err = RetryFunc(nil, options)
			} else {
				err = RetryFunc(function.(RetryableFunc), options)
			}
		}
		end := time.Now()
		totalDuration = end.Sub(start)
	})

	Context("when context is nil", func() {
		It("should return an error", func() {
			Expect(err).ToNot(BeNil())
			Expect(err.(twirp.Error).Code()).To(Equal(twirp.InvalidArgument))
		})
	})

	Context("when context is not nil", func() {
		BeforeEach(func() {
			ctx = context.Background()
		})

		Context("when options is non-empty", func() {
			BeforeEach(func() {
				options = Options{
					Timeout:           10 * time.Second,
					MaxAttempts:       3,
					InitialRetryDelay: time.Microsecond,
				}
				if withContext {
					function = func(ctx context.Context) (err error) {
						return nil
					}
				} else {
					function = func() (err error) {
						return nil
					}
				}
			})

			It("should not return an error", func() {
				Expect(err).To(BeNil())
			})

			Context("when Timeout is negative", func() {
				BeforeEach(func() {
					options.Timeout = -1
				})

				It("should return an error", func() {
					Expect(err).ToNot(BeNil())
					Expect(err.(twirp.Error).Code()).To(Equal(twirp.InvalidArgument))
				})
			})

			Context("when InitialRetryDelay is negative", func() {
				BeforeEach(func() {
					options.InitialRetryDelay = -1
				})

				It("should return an error", func() {
					Expect(err).ToNot(BeNil())
					Expect(err.(twirp.Error).Code()).To(Equal(twirp.InvalidArgument))
				})
			})

			Context("when TimeoutMax is less than Timeout", func() {
				BeforeEach(func() {
					options.TimeoutMax = time.Second
				})

				It("should return an error", func() {
					Expect(err).ToNot(BeNil())
					Expect(err.(twirp.Error).Code()).To(Equal(twirp.InvalidArgument))
				})
			})

			Context("when TimeoutMax is negative", func() {
				BeforeEach(func() {
					options.TimeoutMax = -1
				})

				It("should return an error", func() {
					Expect(err).ToNot(BeNil())
					Expect(err.(twirp.Error).Code()).To(Equal(twirp.InvalidArgument))
				})
			})
		})

		Context("when options is empty (except logger)", func() {
			BeforeEach(func() {
				options = Options{
					Logger: &loggers.JSONLogger{Dest: json.NewEncoder(os.Stdout), OnError: nil},
				}
				function = nil
			})

			Context("when retryFunc is nil", func() {
				It("should return an error", func() {
					Expect(err).ToNot(BeNil())
					Expect(err.(twirp.Error).Code()).To(Equal(twirp.InvalidArgument))
				})
			})

			Context("when retryFunc is not nil", func() {
				var (
					innerFunction RetryableFuncWithContext
					counter       int
					errText       string
				)

				BeforeEach(func() {
					counter = 0
					if withContext {
						function = func(ctx context.Context) (err error) {
							counter += 1
							return innerFunction(ctx)
						}
					} else {
						function = func() (err error) {
							counter += 1
							return innerFunction(context.Background())
						}
					}
				})

				Context("when retryFunc fails - retryable AWS 400", func() {
					BeforeEach(func() {
						errText = "ProvisionedThroughputExceededException" // DDB-specific retryable 4xx code
						innerFunction = func(ctx context.Context) (err error) {
							return awserr.NewRequestFailure(
								awserr.New(errText, "", nil),
								http.StatusBadRequest,
								"")
						}
					})

					It("should return an error", func() {
						Expect(err).ToNot(BeNil())
						Expect(err.(awserr.RequestFailure).StatusCode()).To(Equal(http.StatusBadRequest))
						Expect(err.(awserr.RequestFailure).Code()).To(Equal(errText))
						Expect(counter).To(Equal(3))
					})
				})

				Context("when retryFunc fails - retryable AWS 429", func() {
					BeforeEach(func() {
						errText = "TooManyRequestsException"
						innerFunction = func(ctx context.Context) (err error) {
							return awserr.NewRequestFailure(
								awserr.New(errText, "", nil),
								http.StatusTooManyRequests,
								"")
						}
					})

					It("should return an error", func() {
						Expect(err).ToNot(BeNil())
						Expect(err.(awserr.RequestFailure).StatusCode()).To(Equal(http.StatusTooManyRequests))
						Expect(err.(awserr.RequestFailure).Code()).To(Equal(errText))
						Expect(counter).To(Equal(3))
					})
				})

				Context("when retryFunc fails - retryable AWS 500", func() {
					BeforeEach(func() {
						errText = "SomeError"
						innerFunction = func(ctx context.Context) (err error) {
							return awserr.NewRequestFailure(
								awserr.New(errText, "", nil),
								http.StatusInternalServerError,
								"")
						}
					})

					It("should return an error", func() {
						Expect(err).ToNot(BeNil())
						Expect(err.(awserr.RequestFailure).StatusCode()).To(Equal(http.StatusInternalServerError))
						Expect(err.(awserr.RequestFailure).Code()).To(Equal(errText))
						Expect(counter).To(Equal(3))
					})
				})

				Context("when retryFunc fails - non-retryable AWS 400", func() {
					BeforeEach(func() {
						errText = "SomeError"
						innerFunction = func(ctx context.Context) (err error) {
							return awserr.NewRequestFailure(
								awserr.New(errText, "", nil),
								http.StatusBadRequest,
								"")
						}
					})

					It("should return an error", func() {
						Expect(err).ToNot(BeNil())
						Expect(err.(awserr.RequestFailure).StatusCode()).To(Equal(http.StatusBadRequest))
						Expect(err.(awserr.RequestFailure).Code()).To(Equal(errText))
						Expect(counter).To(Equal(1))
					})
				})

				Context("when retryFunc fails - non-retryable AWS non-request", func() {
					BeforeEach(func() {
						errText = "SomeError"
						innerFunction = func(ctx context.Context) (err error) {
							return awserr.New(
								request.InvalidParameterErrCode,
								errText,
								nil)
						}
					})

					It("should return an error", func() {
						Expect(err).ToNot(BeNil())
						Expect(err.(awserr.Error).Code()).To(Equal(request.InvalidParameterErrCode))
						Expect(err.(awserr.Error).Message()).To(Equal(errText))
						Expect(counter).To(Equal(1))
					})
				})

				Context("when retryFunc fails - retryable twirp", func() {
					BeforeEach(func() {
						errText = "SomeError"
						innerFunction = func(ctx context.Context) (err error) {
							return twirp.InternalError(errText)
						}
					})

					It("should return an error", func() {
						Expect(err).ToNot(BeNil())
						Expect(err.(twirp.Error).Code()).To(Equal(twirp.Internal))
						Expect(err.(twirp.Error).Msg()).To(Equal(errText))
						Expect(counter).To(Equal(3))
					})
				})

				Context("when retryFunc fails - non-retryable twirp", func() {
					BeforeEach(func() {
						errText = "SomeError"
						innerFunction = func(ctx context.Context) (err error) {
							return twirp.NotFoundError(errText)
						}
					})

					It("should return an error", func() {
						Expect(err).ToNot(BeNil())
						Expect(err.(twirp.Error).Code()).To(Equal(twirp.NotFound))
						Expect(err.(twirp.Error).Msg()).To(Equal(errText))
						Expect(counter).To(Equal(1))
					})
				})

				Context("when retryFunc fails - unrecoverable error", func() {
					BeforeEach(func() {
						errText = "SomeError"
						innerFunction = func(ctx context.Context) (err error) {
							return Unrecoverable(errors.New(errText))
						}
					})

					It("should return an error", func() {
						Expect(err).ToNot(BeNil())
						Expect(err.Error()).To(Equal(errText))
						Expect(counter).To(Equal(1))
					})
				})

				Context("when retryFunc fails - unhandled error", func() {
					BeforeEach(func() {
						errText = "SomeError"
						innerFunction = func(ctx context.Context) (err error) {
							return errors.New(errText)
						}
					})

					It("should return an error", func() {
						Expect(err).ToNot(BeNil())
						Expect(err.Error()).To(Equal(errText))
						Expect(counter).To(Equal(3))
					})
				})

				Context("when retryFunc takes too long", func() {
					var maxAttempts uint

					BeforeEach(func() {
						maxAttempts = 5
						errText = "unit test failed"
						options.Timeout = 100 * time.Millisecond
						options.InitialRetryDelay = time.Microsecond // negligible
						options.MaxAttempts = maxAttempts
						innerFunction = func(ctx context.Context) (err error) {
							select {
							case <-time.After(10 * time.Second):
								// Timed out internally - fail the test
								return twirp.NotFoundError(errText)
							case <-ctx.Done():
								// Caused by outer context cancellation - return InternalError to pass the test
								return twirp.InternalError("context deadline exceeded")
							}
						}
					})

					It("should return a context timeout error", func() {
						Expect(err).ToNot(BeNil())
						Expect(err.(twirp.Error).Code()).To(Equal(twirp.Internal))
						Expect(err.(twirp.Error).Msg()).ToNot(Equal(errText))
						Expect(counter).To(Equal(int(maxAttempts)))
					})

					It("should cancel context using outside timeout", func() {
						targetLatency := (100 + 100 + 100 + 100 + 100) * time.Millisecond
						targetLatency += 100 * time.Millisecond // Overhead
						Expect(targetLatency.Milliseconds()).To(
							BeNumerically("~", totalDuration.Milliseconds(), 100))
					})

					Context("exponential timeout", func() {
						BeforeEach(func() {
							options.TimeoutType = Exponential
							options.TimeoutMax = 500 * time.Millisecond
						})

						It("should cancel context using outside timeout", func() {
							targetLatency := (100 + 200 + 400 + 500 + 500) * time.Millisecond
							targetLatency += 100 * time.Millisecond // Overhead
							Expect(targetLatency.Milliseconds()).To(
								BeNumerically("~", totalDuration.Milliseconds(), 100))
						})
					})
				})

				Context("when retryFunc runs successfully", func() {
					BeforeEach(func() {
						innerFunction = func(ctx context.Context) (err error) {
							return nil
						}
					})

					It("should not return an error", func() {
						Expect(err).To(BeNil())
						Expect(counter).To(Equal(1))
					})
				})
			})
		})
	})
}
