package clients

import (
	"context"
	"math/rand"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	log "github.com/sirupsen/logrus"
)

// maxRetryCount is how many retries (after the initial attempt)
// BackoffWithJitterRetryPolicy permits.
const maxRetryCount = 3

// baseBackoff is the starting sleep for BackoffWithJitterRetryPolicy, which
// doubles it after every GetSleepDuration call.
const baseBackoff = time.Millisecond * 100

// Retry calls fn until it succeeds, asks to stop, or the retry budget from
// retryPolicy is exhausted. fn is invoked at most retryPolicy.GetRetryCount()+1
// times, and always at least once. Between attempts Retry sleeps for
// retryPolicy.GetSleepDuration(), aborting early if ctx is done. It does not
// sleep after the final attempt.
//
// Return value:
//   - If fn returns stop == true, retrying is cancelled and fn's error is
//     returned as-is (which may be nil).
//   - If fn returns a nil error, Retry returns nil.
//   - If aws.SleepWithContext fails while waiting between attempts (context
//     cancelled, deadline exceeded, ...), that error is returned.
//   - Otherwise the error from the final attempt is returned.
//
// retryPolicy must be non-nil. GetSleepDuration may be stateful (for example,
// BackoffWithJitterRetryPolicy doubles its backoff on each call), so callers
// should generally pass a fresh policy to each Retry call.
func Retry(ctx context.Context, retryPolicy RetryPolicy, fn func() (stop bool, err error)) error {
	retries := retryPolicy.GetRetryCount()
	if retries < 0 {
		// A negative count would otherwise skip fn entirely and report success.
		retries = 0
	}

	var lastErr error
	for attempt := 0; attempt <= retries; attempt++ {
		stop, err := fn()
		if stop || err == nil {
			return err
		}
		lastErr = err

		log.WithField("retryCount", attempt+1).WithError(err).Warn("Error calling dependency during retry.")

		if attempt == retries {
			// The final attempt failed; sleeping now would only delay the error.
			break
		}

		if err := aws.SleepWithContext(ctx, retryPolicy.GetSleepDuration()); err != nil {
			// Context cancelled, deadline exceeded, etc.
			return err
		}
	}

	return lastErr
}

// RetryPolicy tells Retry how many times a failing call may be retried and how
// long to wait between attempts.
type RetryPolicy interface {
	// GetSleepDuration returns how long to wait before the next attempt.
	// Implementations may be stateful and return a different value per call.
	GetSleepDuration() time.Duration

	// GetRetryCount returns the maximum number of retries after the initial
	// attempt.
	GetRetryCount() int
}

// BackoffWithJitterRetryPolicy is a RetryPolicy with exponential backoff: each
// GetSleepDuration call returns the current backoff (100ms, 200ms, 400ms, ...
// when created by NewBackoffWithJitterRetryPolicy) plus a random jitter below
// 20% of it, then doubles the stored backoff.
//
// Because GetSleepDuration mutates the policy without any locking, an instance
// should be used by one Retry call at a time. Create a new instance per call.
type BackoffWithJitterRetryPolicy struct {
	// backoff is the un-jittered sleep returned by the next GetSleepDuration call.
	backoff time.Duration
}

// NewBackoffWithJitterRetryPolicy returns a new BackoffWithJitterRetryPolicy
// whose first sleep is baseBackoff plus jitter. The returned policy advances its
// backoff as it is used, so create one per Retry call.
func NewBackoffWithJitterRetryPolicy() RetryPolicy {
	policy := new(BackoffWithJitterRetryPolicy)
	policy.backoff = baseBackoff
	return policy
}

// GetRetryCount reports maxRetryCount as the number of retries allowed after
// the initial attempt; it does not depend on the policy's backoff state.
func (*BackoffWithJitterRetryPolicy) GetRetryCount() int {
	return maxRetryCount
}

// GetSleepDuration returns the current backoff plus a random jitter in
// [0, backoff/5), then doubles the stored backoff for the next call.
//
// A non-positive stored backoff (the zero value of the struct) is treated as
// baseBackoff. Doubling saturates well below the time.Duration limit, so a
// long-lived or reused policy never wraps around to a negative duration.
func (p *BackoffWithJitterRetryPolicy) GetSleepDuration() time.Duration {
	// Keeping the stored backoff at or below this bound guarantees that both the
	// doubling and the added jitter (< 20%) stay within int64 nanoseconds.
	const maxStoredBackoff = time.Duration(1<<63-1) / 4

	if p.backoff <= 0 {
		p.backoff = baseBackoff
	}

	backoff := p.backoff
	// Add random jitter up to 20% of the backoff. rand.Int63n panics for a
	// non-positive argument, so backoffs under 5ns get no jitter.
	if jitterMax := int64(p.backoff) / 5; jitterMax > 0 {
		backoff += time.Duration(rand.Int63n(jitterMax)) // #nosec
	}

	if p.backoff <= maxStoredBackoff {
		p.backoff *= 2
	}

	return backoff
}
