package backoff

import (
	"math"
	"math/rand"
	"sync"
	"time"
)

const (
	// DefaultExponentialMultiplier is the factor by which the delay grows on
	// each attempt when WithExponentialMultiplier is not supplied.
	DefaultExponentialMultiplier = 1.6

	// DefaultExponentialJitter is the relative jitter (a fraction of the
	// computed delay) used when WithExponentialJitter is not supplied.
	DefaultExponentialJitter = 0.2

	// DefaultExponentialDelayMin is the lower bound used by DefaultExponential.
	DefaultExponentialDelayMin = time.Second / 10

	// DefaultExponentialDelayMax is the upper bound used by DefaultExponential.
	DefaultExponentialDelayMax = 3000 * time.Millisecond
)

// ExponentialStrategy produces exponentially growing retry delays with random
// jitter. Construct it with NewExponential: the zero value has no random
// source, and Delay dereferences that source.
type ExponentialStrategy struct {
	multiplier float64 // growth factor applied to the delay once per attempt
	jitter     float64 // relative jitter applied to the capped delay
	min        float64 // lower bound, the float64 form of a time.Duration
	max        float64 // upper bound, the float64 form of a time.Duration

	rndLock sync.Mutex // serializes access to rnd
	rnd     *rand.Rand // source of randomness for jitter
}

// ExponentialStrategyOption customizes an ExponentialStrategy while
// NewExponential builds it.
type ExponentialStrategyOption func(*ExponentialStrategy)

// DefaultExponential is a shared, ready-to-use strategy bounded by
// DefaultExponentialDelayMin and DefaultExponentialDelayMax. No options are
// applied, so it uses the default multiplier and jitter.
var DefaultExponential = NewExponential(DefaultExponentialDelayMin, DefaultExponentialDelayMax)

// NewExponential returns a backoff strategy that performs exponential backoff
// with jitter. The delay for an attempt is derived from the attempt number.
// It starts at min and is capped at max before jitter is applied (see Delay).
//
// opts are applied in order, and nil options are skipped. If no option
// supplies a random source, a private one seeded from the current time is
// created. min and max are not validated; callers are expected to pass
// min <= max.
func NewExponential(min, max time.Duration, opts ...ExponentialStrategyOption) *ExponentialStrategy {
	strategy := &ExponentialStrategy{
		multiplier: DefaultExponentialMultiplier,
		jitter:     DefaultExponentialJitter,
		min:        float64(min),
		max:        float64(max),
	}

	for _, opt := range opts {
		if opt != nil {
			opt(strategy)
		}
	}

	// Create the default source only when no option provided one. This
	// avoids a wasted allocation when WithExponentialRand is used. It also
	// guarantees Delay never dereferences a nil *rand.Rand.
	if strategy.rnd == nil {
		strategy.rnd = rand.New(rand.NewSource(time.Now().UnixNano()))
	}

	return strategy
}

// WithExponentialMultiplier returns an option that replaces the growth factor
// the delay is multiplied by on each successive attempt.
//
// Without this option, NewExponential uses DefaultExponentialMultiplier.
func WithExponentialMultiplier(multiplier float64) ExponentialStrategyOption {
	setMultiplier := func(strategy *ExponentialStrategy) {
		strategy.multiplier = multiplier
	}
	return setMultiplier
}

// WithExponentialJitter returns an option that replaces the relative jitter.
// For non-zero attempts, Delay scales the capped delay by a random factor in
// [1-jitter, 1+jitter).
//
// Without this option, NewExponential uses DefaultExponentialJitter.
func WithExponentialJitter(jitter float64) ExponentialStrategyOption {
	setJitter := func(strategy *ExponentialStrategy) {
		strategy.jitter = jitter
	}
	return setJitter
}

// WithExponentialRand returns an option that sets a custom random source used
// to compute jitter. A nil rnd is ignored, so the strategy keeps its default
// source instead of panicking later in Delay.
//
// The strategy serializes its own calls on rnd with an internal mutex. That
// lock does not protect other users of the same *rand.Rand. A Rand built on
// rand.NewSource is not safe for concurrent use, so do not share rnd with code
// outside the strategy.
func WithExponentialRand(rnd *rand.Rand) ExponentialStrategyOption {
	return func(s *ExponentialStrategy) {
		if rnd != nil {
			s.rnd = rnd
		}
	}
}

// Delay returns how long to wait before retry number attempt.
//
// Attempt zero, and any negative attempt, returns the configured minimum
// exactly, with no jitter. For later attempts:
//   - The base delay starts at the minimum.
//   - It is multiplied by the strategy's multiplier once per attempt, stopping
//     early as soon as it reaches the maximum.
//   - It is then clamped to the maximum.
//   - Finally it is scaled by a random factor in [1-jitter, 1+jitter).
//
// Because of that last step, the result may fall slightly outside [min, max].
// Negative results are reported as 0. Results too large for time.Duration
// saturate at its maximum value.
//
// err is not consulted. Access to the random source is serialized by an
// internal mutex.
func (s *ExponentialStrategy) Delay(err error, attempt int) time.Duration {
	if attempt <= 0 {
		return time.Duration(s.min)
	}

	delay := s.min
	for delay < s.max && attempt > 0 {
		next := delay * s.multiplier
		if next == delay {
			// Fixed point (zero minimum, multiplier of 1, or underflow to
			// zero). Further iterations cannot change the result, so stop
			// instead of spinning for up to attempt iterations.
			break
		}
		delay = next
		attempt--
	}

	if delay > s.max {
		delay = s.max
	}

	s.rndLock.Lock()
	rnd := s.rnd.Float64()
	s.rndLock.Unlock()

	// rnd is in [0, 1), so the factor below is in [1-jitter, 1+jitter).
	delay *= 1 + s.jitter*(rnd*2-1)
	switch {
	case delay < 0:
		return 0
	case delay >= float64(math.MaxInt64):
		// Converting an out-of-range float64 to an integer type is
		// implementation-defined in Go, so saturate instead.
		return time.Duration(math.MaxInt64)
	}

	return time.Duration(delay)
}
