package service_common

import (
	"math/rand"
	"sync"
	"sync/atomic"
	"time"

	"code.justin.tv/feeds/ctxlog/ctxlogaws"
	"code.justin.tv/feeds/distconf"
	"code.justin.tv/feeds/errors"
	"code.justin.tv/feeds/log"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/sts"
	"golang.org/x/net/context"
)

// CreateAWSSession returns an AWS session plus a slice of extra *aws.Config
// values built from distconf. Presumably callers apply the slice on top of the
// session when constructing service clients (e.g. svc.New(sess, cfgs...));
// confirm against callers.
//
// Distconf keys read:
//   - aws.profile: if non-empty, the session's credentials are loaded from
//     that profile of the shared credentials file.
//   - aws.region: if non-empty, a config setting Region is added to the slice.
//   - aws.assume_role: if non-empty, a config whose Credentials come from an
//     STS AssumeRoleProvider for that role ARN is appended to the slice.
func CreateAWSSession(dconf *distconf.Distconf) (*session.Session, []*aws.Config) {
	clientProvider := session.New()
	profileName := dconf.Str("aws.profile", "").Get()
	if profileName != "" {
		clientProvider.Config.Credentials = credentials.NewSharedCredentials("", profileName)
	}
	retConfig := []*aws.Config{}
	regionName := dconf.Str("aws.region", "").Get()
	if regionName != "" {
		retConfig = append(retConfig, &aws.Config{Region: &regionName})
	}

	assumedRole := dconf.Str("aws.assume_role", "").Get()
	if assumedRole != "" {
		// Build the STS client from clientProvider (not a fresh session) so the
		// AssumeRole call is signed with the same base credentials, including
		// the aws.profile ones, and the same region config as everything else.
		stsclient := sts.New(clientProvider, retConfig...)
		arp := &stscreds.AssumeRoleProvider{
			// Refresh the assumed-role credentials shortly before they expire.
			ExpiryWindow: 10 * time.Second,
			RoleARN:      assumedRole,
			Client:       stsclient,
		}
		// Named roleCreds to avoid shadowing the imported credentials package.
		roleCreds := credentials.NewCredentials(arp)
		retConfig = append(retConfig, &aws.Config{
			Credentials: roleCreds,
		})
	}

	return clientProvider, retConfig
}

// ContextSend attaches ctx to the request's underlying HTTP request and sends
// it with ctxlogaws.DoAWSSend.
//
// Any send error is wrapped with the message "unable to issue aws request".
// Throttling errors get no special treatment and no dedicated error type.
// Callers that need to tell throttling apart must inspect the wrapped cause
// themselves. ctx must be non-nil: http.Request.WithContext panics on a nil
// context.
func ContextSend(ctx context.Context, req *request.Request, elevatedLog *log.ElevatedLog) error {
	req.HTTPRequest = req.HTTPRequest.WithContext(ctx)
	if err := ctxlogaws.DoAWSSend(req, elevatedLog); err != nil {
		return errors.Wrap(err, "unable to issue aws request")
	}
	return nil
}

// ThrottledBackoff allows a struct to share throttling across multiple goroutines via atomic operations.
//
// The current backoff is stored in timeToSleep as nanoseconds:
//   - SignalThrottled grows it.
//   - DecreaseBackoff shrinks it by SleepBackoff.
//   - ThrottledSleep sleeps for a random duration in [0, timeToSleep).
//
// Rand must be set to a seeded generator (e.g. *rand.New(rand.NewSource(seed)))
// before ThrottledSleep runs with a non-zero backoff. A zero rand.Rand has no
// source and panics when drawn from. The struct contains a sync.Mutex, so it
// must not be copied after first use; pass *ThrottledBackoff.
type ThrottledBackoff struct {
	// Multiplier scales the backoff on each SignalThrottled call.
	// Should be >= 1.0
	Multiplier float64
	// SleepBackoff is the fixed step added (before multiplying) by
	// SignalThrottled and subtracted by DecreaseBackoff.
	// Should be > 1 (nanosecond), i.e. a positive duration.
	SleepBackoff time.Duration
	// MaxSleepTime caps the backoff computed by SignalThrottled.
	MaxSleepTime time.Duration
	// Rand supplies the sleep jitter; only accessed while holding mu.
	Rand rand.Rand

	// mu only for rand.Rand: a math/rand.Rand is not safe for concurrent use.
	mu sync.Mutex
	// timeToSleep is the current backoff in nanoseconds; access it only via
	// sync/atomic.
	// NOTE(review): on 32-bit platforms 64-bit atomic ops need 8-byte
	// alignment and this field is not first in the struct — confirm alignment
	// or move it to the top.
	timeToSleep int64
}

// DecreaseBackoff signals that throttling may no longer be needed and lowers
// the shared backoff by one SleepBackoff step, never going below zero.
//
// It does nothing when the backoff is already zero. The update is a
// compare-and-swap retried until it lands, so a concurrent change between the
// load and the swap causes a recomputation rather than a lost update.
func (t *ThrottledBackoff) DecreaseBackoff() {
	step := t.SleepBackoff.Nanoseconds()
	for swapped := false; !swapped; {
		observed := atomic.LoadInt64(&t.timeToSleep)
		if observed == 0 {
			// Not throttling; nothing to lower.
			return
		}
		lowered := observed - step
		if lowered < 0 {
			lowered = 0
		}
		// Fails (and loops) if another goroutine changed timeToSleep since the load.
		swapped = atomic.CompareAndSwapInt64(&t.timeToSleep, observed, lowered)
	}
}

// int63n returns a pseudo-random number in [0, n) drawn from t.Rand.
//
// It holds t.mu because a math/rand.Rand is not safe for concurrent use. Like
// rand.Rand.Int63n, it panics if n <= 0. The unlock is deferred so such a
// panic cannot leave mu locked and deadlock every later caller.
func (t *ThrottledBackoff) int63n(n int64) int64 {
	t.mu.Lock()
	defer t.mu.Unlock()
	return t.Rand.Int63n(n)
}

// ThrottledSleep sleeps while the backoff is active, returning early if ctx
// ends first.
//
// When the shared backoff is zero (or negative), it returns nil immediately.
// Otherwise it sleeps for a random duration in [0, backoff). It returns nil
// when the sleep completes, or ctx.Err() if ctx is done first.
func (t *ThrottledBackoff) ThrottledSleep(ctx context.Context) error {
	loadedSleepTime := atomic.LoadInt64(&t.timeToSleep)
	// A non-positive value means no throttling. The guard also keeps int63n
	// from panicking: rand.Int63n panics for arguments <= 0.
	if loadedSleepTime <= 0 {
		return nil
	}
	timer := time.NewTimer(time.Duration(t.int63n(loadedSleepTime)))
	// Release the timer on the ctx path too, instead of leaving a possibly
	// long-lived timer pending as time.After would.
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}

// SignalThrottled signals that a throttle happened and grows the shared
// backoff to (current + SleepBackoff) * Multiplier + 1ns. The result is capped
// at MaxSleepTime and floored at zero.
//
// The update is a compare-and-swap loop, like DecreaseBackoff, so concurrent
// signals are not lost. The growth is computed in float64 and capped before
// converting back to int64, because converting an out-of-range float64 to an
// integer is implementation-specific in Go. It can wrap to a negative value
// that would slip past the cap and later make rand.Int63n panic.
func (t *ThrottledBackoff) SignalThrottled() {
	maxSleep := t.MaxSleepTime.Nanoseconds()
	for {
		current := atomic.LoadInt64(&t.timeToSleep)
		grown := (float64(current) + float64(t.SleepBackoff.Nanoseconds())) * t.Multiplier

		// Default to the cap. This also covers grown >= maxSleep and NaN.
		next := maxSleep
		switch {
		case grown < 0:
			// e.g. a negative Multiplier or SleepBackoff: never store a negative sleep.
			next = 0
		case grown < float64(maxSleep):
			// grown is within int64 range here, so the conversion is well-defined.
			next = int64(grown) + 1
			if next > maxSleep {
				next = maxSleep
			}
		}
		if next < 0 {
			// Only possible with a negative MaxSleepTime.
			next = 0
		}

		if atomic.CompareAndSwapInt64(&t.timeToSleep, current, next) {
			return
		}
	}
}
