package sqsextender

import (
	"context"
	"sync"
	"time"

	"github.com/aws/aws-sdk-go/service/sqs"
)

// emptyCircuit is a pass-through Circuit: it adds no behavior of its own and
// simply invokes the wrapped function with the context it was given.
type emptyCircuit struct{}

// Run calls f with ctx and returns f's error unchanged.
func (emptyCircuit) Run(ctx context.Context, f func(context.Context) error) error {
	err := f(ctx)
	return err
}

// defaultCircuit is what ExtendedSQSMessage falls back to when no Circuit was
// configured.
var defaultCircuit = emptyCircuit{}

// SQSMessageTimeoutExtender hands out ExtendedSQSMessage values that keep SQS
// messages invisible (via ChangeMessageVisibility) so they can continue to be
// processed past their original visibility timeout.
type SQSMessageTimeoutExtender struct {
	// SQS is the client used to send ChangeMessageVisibility requests.
	SQS          *sqs.SQS
	// QueueURL is sent as QueueUrl on every ChangeMessageVisibility request;
	// it should be the queue the messages were received from.
	QueueURL     string
	// Circuit optionally wraps every ChangeMessageVisibility request; nil
	// means the request is sent directly.
	Circuit      Circuit
	// ExtendBuffer controls how far each extension pushes the visibility
	// timeout and, indirectly, how often extensions are sent. Zero means five
	// minutes.
	ExtendBuffer time.Duration
}

// asInt64 returns a pointer to a freshly allocated copy of v, for AWS SDK input
// fields that take *int64.
func asInt64(v int64) *int64 {
	p := new(int64)
	*p = v
	return p
}

// Circuit wraps the SQS ChangeMessageVisibility request that
// ExtendedSQSMessage.Extend issues on each refresh. Run receives that request as
// a function and is expected to invoke it with a context and return its result
// (or an error of its own).
type Circuit interface {
	Run(ctx context.Context, fn func(ctx context.Context) error) error
}

// Extend takes an SQS message receipt handle and returns an ExtendedSQSMessage
// for it. Call the returned value's Extend method with a context (it blocks, so
// typically in its own goroutine) and call Stop when processing is done.
//
// A zero or negative ExtendBuffer is treated as unset and defaults to five
// minutes.
func (s *SQSMessageTimeoutExtender) Extend(msgHandle *string) *ExtendedSQSMessage {
	extendBuffer := s.ExtendBuffer
	if extendBuffer <= 0 {
		// Treat a negative buffer like the unset zero value rather than
		// requesting a nonsensical (possibly negative) visibility timeout.
		extendBuffer = 5 * time.Minute
	}
	return &ExtendedSQSMessage{
		SQS:          s.SQS,
		queueURL:     s.QueueURL,
		msgHandle:    msgHandle,
		circuit:      s.Circuit,
		extendBuffer: extendBuffer,
		onStop:       make(chan struct{}),
	}
}

// ExtendedSQSMessage is a single message whose visibility timeout is being
// extended. Create it with SQSMessageTimeoutExtender.Extend, which initializes
// the stop channel. Run Extend (blocking) and call Stop when processing is done.
// It must not be copied after first use.
type ExtendedSQSMessage struct {
	SQS          *sqs.SQS
	queueURL     string
	msgHandle    *string
	circuit      Circuit
	extendBuffer time.Duration
	onStop       chan struct{}

	// stopOnce makes Stop idempotent; closing an already-closed channel panics.
	stopOnce sync.Once
}

// Stop ends extension of a message previously started with Extend. It is safe
// to call more than once, including concurrently; only the first call closes
// the stop channel.
func (s *ExtendedSQSMessage) Stop() {
	s.stopOnce.Do(func() { close(s.onStop) })
}

// Done returns a channel that is closed once Stop has been called. It is not
// closed when Extend returns because its context ended or extensions kept
// failing.
func (s *ExtendedSQSMessage) Done() <-chan struct{} {
	return s.onStop
}

// getCircuit returns the configured Circuit, or the pass-through defaultCircuit
// when none was set.
func (s *ExtendedSQSMessage) getCircuit() Circuit {
	if c := s.circuit; c != nil {
		return c
	}
	return defaultCircuit
}

// extendRatio sets how often ExtendedSQSMessage.Extend refreshes the visibility
// timeout: roughly every extendBuffer/extendRatio, plus one second. Extend also
// gives up after extendRatio-3 consecutive failed refreshes.
const extendRatio = 10

// Extend blocks, periodically extending the message's visibility timeout for as
// long as ctx lives; call it in its own goroutine.
//
// Roughly every extendBuffer/extendRatio (plus one second) it sends a
// ChangeMessageVisibility request through the configured Circuit. Each request
// keeps the message hidden for extendBuffer (plus one second) from the time of
// that request.
//
// It returns nil after Stop is called, ctx.Err() once ctx is done, or the last
// request error after extendRatio-3 consecutive failed extensions.
func (s *ExtendedSQSMessage) Extend(ctx context.Context) error {
	// SQS applies VisibilityTimeout starting from the ChangeMessageVisibility
	// call itself, not from when the message was received. The requested value
	// therefore stays constant. Growing it with elapsed time would hide a
	// message from a crashed consumer for that much longer, and it would hit
	// SQS's 12-hour cap after roughly half the processing time.
	visibility := s.extendBuffer + time.Second
	interval := s.extendBuffer/extendRatio + time.Second
	const maxFailuresInARow = extendRatio - 3

	// One reusable timer instead of a fresh time.After per iteration. It is
	// re-armed only after each request completes, so requests never overlap.
	timer := time.NewTimer(interval)
	defer timer.Stop()

	var failuresInARow int
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-s.onStop:
			return nil
		case <-timer.C:
		}

		if err := s.changeVisibility(ctx, visibility); err != nil {
			failuresInARow++
			if failuresInARow >= maxFailuresInARow {
				return err
			}
		} else {
			failuresInARow = 0
		}
		timer.Reset(interval)
	}
}

// changeVisibility sends one ChangeMessageVisibility request, wrapped by the
// configured Circuit. The request sets the message's visibility timeout to
// timeout, truncated to whole seconds.
func (s *ExtendedSQSMessage) changeVisibility(ctx context.Context, timeout time.Duration) error {
	in := &sqs.ChangeMessageVisibilityInput{
		QueueUrl:          &s.queueURL,
		ReceiptHandle:     s.msgHandle,
		VisibilityTimeout: asInt64(int64(timeout.Seconds())),
	}
	// The discarded value is the response output, not an error; failures
	// surface from req.Send.
	req, _ := s.SQS.ChangeMessageVisibilityRequest(in)
	return s.getCircuit().Run(ctx, func(ctx context.Context) error {
		req.SetContext(ctx)
		return req.Send()
	})
}
