package goretry

import (
	"context"
	"fmt"
	"net/http"
	"time"

	"github.com/avast/retry-go"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/twitchtv/twirp"

	logging "code.justin.tv/amzn/TwitchLogging"
)

// Package defaults applied by validateParams when the matching Options field
// is left at its zero value.
const (
	// maxAttempts is the default for Options.MaxAttempts.
	maxAttempts = 3
	// timeout is the default per-attempt timeout (Options.Timeout).
	timeout = time.Second * 10
	// initialRetryDelay is the default for Options.InitialRetryDelay.
	initialRetryDelay = time.Millisecond * 10
)

// TimeoutType selects how the per-attempt timeout changes between retries.
type TimeoutType int

const (
	// Fixed uses the same per-attempt timeout (Options.Timeout) every time.
	Fixed TimeoutType = iota
	// Exponential doubles the per-attempt timeout after each failed attempt,
	// capped at Options.TimeoutMax when that is non-zero.
	Exponential
)

// Options configures RetryFunc and RetryFuncWithContext. The zero value is
// usable: zero Timeout, MaxAttempts and InitialRetryDelay fall back to the
// package defaults during validation.
type Options struct {
	// Timeout bounds a single attempt. Zero selects the package default
	// (10s); negative values are rejected.
	Timeout           time.Duration
	// TimeoutMax caps how far the per-attempt timeout may grow. Zero means
	// no cap; otherwise it must not be negative or smaller than Timeout.
	TimeoutMax        time.Duration
	// TimeoutType controls growth of the per-attempt timeout after a failed
	// attempt: Fixed keeps it constant, Exponential doubles it. Any other
	// value currently behaves like Fixed.
	TimeoutType       TimeoutType
	// MaxAttempts is passed to retry.Attempts. Zero selects the package
	// default (3).
	MaxAttempts       uint
	// InitialRetryDelay is passed to retry.Delay. Zero selects the package
	// default (10ms); negative values are rejected.
	InitialRetryDelay time.Duration
	// Logger, when non-nil, is used to log failed attempts from the retry
	// hook, with the failure under the "error" key.
	Logger            logging.Logger
	// currentTimeout is the timeout applied to the next attempt. It is seeded
	// from Timeout during validation and updated after each failed attempt.
	currentTimeout    time.Duration
}

type (
	// RetryableFunc is an operation retried by RetryFunc. It receives no
	// context, so RetryFunc enforces the per-attempt timeout by no longer
	// waiting for it.
	RetryableFunc = func() error

	// RetryableFuncWithContext is an operation retried by
	// RetryFuncWithContext. Its context carries the per-attempt deadline.
	RetryableFuncWithContext = func(ctx context.Context) error
)

// RetryFuncWithContext calls retryFunc until it succeeds, fails with a
// non-retryable error, or options.MaxAttempts attempts have been made.
//
// Each attempt receives a child of ctx whose deadline is the current
// per-attempt timeout. That child is cancelled as soon as the attempt returns,
// so retryFunc is expected to honour its context.
//
// NOTE(review): ctx is only used to derive per-attempt contexts and is not
// handed to the retry loop. Cancelling ctx therefore does not stop further
// attempts or the delays between them. Remaining attempts just start with an
// already-cancelled context.
//
// It returns a twirp InvalidArgument error if ctx or retryFunc is nil or the
// options are invalid. Otherwise it returns the last attempt's error, or nil
// on success.
func RetryFuncWithContext(ctx context.Context, retryFunc RetryableFuncWithContext, options Options) error {
	if err := validateParams(ctx, retryFunc, &options); err != nil {
		return err
	}

	attempt := func() error {
		// options.currentTimeout is re-read on every attempt because retryDo
		// updates it through &options after each failure.
		attemptCtx, cancel := context.WithTimeout(ctx, options.currentTimeout)
		defer cancel()
		return retryFunc(attemptCtx)
	}

	return retryDo(attempt, &options)
}

// RetryFunc calls retryFunc until it succeeds, fails with a non-retryable
// error, or options.MaxAttempts attempts have been made.
//
// retryFunc takes no context, so each call runs on its own goroutine and the
// attempt stops waiting once the current per-attempt timeout elapses. In that
// case the attempt fails with a twirp internal error. The abandoned call keeps
// running in the background until retryFunc returns by itself.
//
// NOTE(review): because retryFunc runs on a separate goroutine, a panic
// inside it cannot be recovered by RetryFunc's caller.
//
// It returns a twirp InvalidArgument error if retryFunc is nil or the options
// are invalid. Otherwise it returns the last attempt's error, or nil on
// success.
func RetryFunc(retryFunc RetryableFunc, options Options) error {
	background := context.Background()
	if err := validateParams(background, retryFunc, &options); err != nil {
		return err
	}

	attempt := func() error {
		deadlineCtx, cancel := context.WithTimeout(background, options.currentTimeout)
		defer cancel()

		// Capacity 1 lets the worker deliver its result and exit even after
		// this attempt has stopped waiting for it.
		result := make(chan error, 1)
		go func() {
			result <- retryFunc()
		}()

		select {
		case err := <-result:
			return err
		case <-deadlineCtx.Done():
			return twirp.InternalError("context deadline exceeded")
		}
	}

	return retryDo(attempt, &options)
}

// retryDo runs function under retry.Do using the settings in options. The
// options must already have been normalised by validateParams.
//
// Its OnRetry hook, which retry-go runs after a failed attempt:
//   - updates options.currentTimeout. For Exponential the timeout is doubled,
//     saturating at the largest representable time.Duration instead of
//     overflowing. For Fixed (or any other TimeoutType) it is left unchanged.
//     The result is then capped at TimeoutMax when TimeoutMax is non-zero.
//   - logs the failure through options.Logger when one is set.
//
// The attempt closures built by RetryFunc and RetryFuncWithContext read
// options.currentTimeout, so an updated value applies to the next attempt.
//
// An error is considered retryable when it is:
//   - an AWS error the SDK classifies as retryable or throttling, a cancelled
//     request, or a request failure with a 5xx or 429 status;
//   - a twirp error whose code maps to a 5xx HTTP status;
//   - any other error that was not wrapped with Unrecoverable.
//
// Only the last attempt's error is returned.
func retryDo(function retry.RetryableFunc, options *Options) error {
	// maxTimeout is the largest value a time.Duration can hold.
	const maxTimeout = time.Duration(1<<63 - 1)

	return retry.Do(
		function,
		retry.OnRetry(func(n uint, err error) {
			if options.TimeoutType == Exponential {
				// A plain shift overflows into a negative Duration after
				// enough attempts. A negative value slips past the TimeoutMax
				// clamp below and makes every later attempt time out
				// immediately, so saturate instead.
				if options.currentTimeout > maxTimeout/2 {
					options.currentTimeout = maxTimeout
				} else {
					options.currentTimeout <<= 1
				}
			}
			if options.TimeoutMax != 0 && options.currentTimeout > options.TimeoutMax {
				options.currentTimeout = options.TimeoutMax
			}
			if options.Logger != nil {
				errorMessage := "Retrying request"
				if n+1 == options.MaxAttempts {
					errorMessage = "Retry attempts exceeded"
				}
				options.Logger.Log(fmt.Sprintf("%s. Failed attempts: %d", errorMessage, n+1), "error", err)
			}
		}),
		retry.Attempts(options.MaxAttempts),
		retry.Delay(options.InitialRetryDelay),
		retry.RetryIf(func(err error) bool {
			// AWS SDK errors: trust the SDK's own classification first.
			if awsErr, ok := err.(awserr.Error); ok {
				if request.IsErrorRetryable(awsErr) || request.IsErrorThrottle(awsErr) ||
					awsErr.Code() == request.CanceledErrorCode { // Retry since we requested cancellation
					return true
				}

				if reqErr, ok := err.(awserr.RequestFailure); ok {
					return reqErr.StatusCode() >= 500 || reqErr.StatusCode() == http.StatusTooManyRequests
				}

				return false // Non-retryable AWS error
			}

			// Twirp errors: 5xx codes are treated as transient.
			if twErr, ok := err.(twirp.Error); ok {
				return twirp.ServerHTTPStatusFromErrorCode(twErr.Code()) >= 500
			}

			// Unknown error: retry unless it was wrapped with Unrecoverable.
			return retry.IsRecoverable(err)
		}),
		retry.LastErrorOnly(true),
	)
}

// Unrecoverable wraps err so that the retry predicate used by RetryFunc and
// RetryFuncWithContext treats it as non-retryable. The retry loop then ends
// after the attempt that returned it.
func Unrecoverable(err error) error {
	wrapped := retry.Unrecoverable(err)
	return wrapped
}

// validateParams checks the arguments shared by RetryFunc and
// RetryFuncWithContext and fills in defaults on options in place.
//
// It rejects:
//   - a nil ctx;
//   - a nil retryFunc, or one that is neither a RetryableFunc nor a
//     RetryableFuncWithContext;
//   - negative Timeout, TimeoutMax or InitialRetryDelay values;
//   - a non-zero TimeoutMax smaller than Timeout.
//
// Zero-valued Timeout, MaxAttempts and InitialRetryDelay are replaced with
// the package defaults, and currentTimeout is seeded from Timeout. Every
// failure is reported as a twirp InvalidArgument error.
func validateParams(ctx context.Context, retryFunc interface{}, options *Options) error {
	if ctx == nil {
		return twirp.InvalidArgumentError("ctx", "Cannot be nil")
	}

	var missing bool
	switch fn := retryFunc.(type) {
	case RetryableFunc:
		missing = fn == nil
	case RetryableFuncWithContext:
		missing = fn == nil
	default:
		// Reached both for an untyped nil and for unsupported function types.
		missing = true
	}
	if missing {
		return twirp.InvalidArgumentError("retryFunc", "Cannot be nil")
	}

	switch {
	case options.Timeout < 0:
		return twirp.InvalidArgumentError("options.Timeout", "cannot be negative")
	case options.Timeout == 0:
		options.Timeout = timeout
	}
	options.currentTimeout = options.Timeout

	switch {
	case options.TimeoutMax < 0:
		return twirp.InvalidArgumentError("options.TimeoutMax", "cannot be negative")
	case options.TimeoutMax > 0 && options.TimeoutMax < options.Timeout:
		return twirp.InvalidArgumentError("options.TimeoutMax", "cannot be lower than options.Timeout")
	}

	if options.MaxAttempts == 0 {
		options.MaxAttempts = maxAttempts
	}

	switch {
	case options.InitialRetryDelay < 0:
		return twirp.InvalidArgumentError("options.InitialRetryDelay", "cannot be negative")
	case options.InitialRetryDelay == 0:
		options.InitialRetryDelay = initialRetryDelay
	}

	return nil
}
