package service_common

import (
	"math"
	"sync"
	"time"

	"strings"

	"code.justin.tv/feeds/ctxlog"
	"code.justin.tv/feeds/ctxlog/ctxlogaws"
	"code.justin.tv/feeds/ctxlog/ctxlogaws/ctxlogsqs"
	"code.justin.tv/feeds/distconf"
	"code.justin.tv/feeds/errors"
	"code.justin.tv/feeds/log"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/sqs"
	"golang.org/x/net/context"
)

// SQSQueueProcessorConfig configures the queue that masonry reads from.
// Verify fills every field from distconf keys named
// "<prefix>.sqssource.<setting>"; the values are re-read through Get, so
// changes made in distconf are picked up by later reads.
type SQSQueueProcessorConfig struct {
	// WaitTime is the long-poll wait for each ReceiveMessage call; it is
	// rounded up to whole seconds before being sent to SQS.
	WaitTime          *distconf.Duration
	// DrainingThreads is the number of worker goroutines Start launches.
	DrainingThreads   *distconf.Int
	// VisibilityTimeout is requested on every ReceiveMessage call (rounded up
	// to whole seconds) and is also the deadline for processing one received batch.
	VisibilityTimeout *distconf.Duration
	// QueueURL is the URL of the queue that messages are received from and deleted from.
	QueueURL          *distconf.Str
}

// Verify registers this SQS configuration with dconf and stores the
// resulting handles in s. It should be called inside Load().
//
// Settings live under "<prefix>.sqssource.": wait_time, draining_threads,
// visibility_timeout and queue_url. The waitTime, drainingThreads and
// visTimeout arguments are their defaults; queue_url has no default.
//
// It returns an error if prefix is empty or ends with ".", or if queue_url
// currently has no value.
func (s *SQSQueueProcessorConfig) Verify(dconf *distconf.Distconf, prefix string, waitTime time.Duration, drainingThreads int64, visTimeout time.Duration) error {
	if prefix == "" || strings.HasSuffix(prefix, ".") {
		return errors.New("prefix should not end with dot and should be non empty")
	}
	keyBase := prefix + ".sqssource."
	s.WaitTime = dconf.Duration(keyBase+"wait_time", waitTime)
	s.DrainingThreads = dconf.Int(keyBase+"draining_threads", drainingThreads)
	// Keep this at least as long as processing a single message takes: it is
	// both the SQS visibility timeout and the per-batch processing deadline.
	s.VisibilityTimeout = dconf.Duration(keyBase+"visibility_timeout", visTimeout)
	s.QueueURL = dconf.Str(keyBase+"queue_url", "")
	if queueURL := s.QueueURL.Get(); queueURL == "" {
		return errors.Errorf("variable %s.sqssource.queue_url must be set", prefix)
	}
	return nil
}

// SQSQueueProcessor does common SQS processing in multiple goroutines.
//
// Lifecycle as implemented by the methods in this file: call Setup, then run
// Start (it blocks until every worker has returned), and call Close from
// another goroutine to stop the workers and wait for Start to finish.
type SQSQueueProcessor struct {
	// Log receives lifecycle messages and receive/process/delete errors.
	Log             *log.ElevatedLog
	// Ch is passed to ctxlogsqs.ExtractContext when building each message's context.
	Ch              *ctxlog.Ctxlog
	// Sqs is the client used for ReceiveMessage and DeleteMessage requests.
	Sqs             *sqs.SQS
	// Stats receives counters and timings for receives, processing and deletes.
	Stats           *StatSender
	// StopWaitChannel is closed by Start once all workers have returned; Close
	// waits on it. Setup creates it only if it is nil.
	StopWaitChannel chan struct{}
	// ProcessMessage handles one received message. A nil error causes the
	// message to be deleted from the queue; a non-nil error leaves it there.
	ProcessMessage  func(ctx context.Context, msg *sqs.Message) error `nilcheck:"ignore"` // Ignored because people usually set a private function here
	// Conf holds the queue settings; they are re-read on every poll.
	Conf            *SQSQueueProcessorConfig
	// stopFlag is created by Setup and closed by Close to tell workers to stop.
	stopFlag        chan struct{}
}

// Setup prepares the channels used to coordinate shutdown and must be called
// before Start or Close. It creates StopWaitChannel only when the caller has
// not already supplied one, and always creates a fresh internal stop channel.
// It always returns nil.
func (s *SQSQueueProcessor) Setup() error {
	if s.StopWaitChannel == nil {
		s.StopWaitChannel = make(chan struct{})
	}
	s.stopFlag = make(chan struct{})
	return nil
}

// Close stops the processor. It signals every worker to exit, then blocks
// until Start closes StopWaitChannel. Because of that wait, Close never
// returns if Start was never called. Calling Close twice panics, since the
// internal stop channel would be closed a second time. It always returns nil.
func (s *SQSQueueProcessor) Close() error {
	s.Log.Debug("Close called")
	// Workers poll stopFlag between receives; closing it asks all of them to finish.
	close(s.stopFlag)
	finished := s.StopWaitChannel
	<-finished
	return nil
}

// Start runs the processor. It reads Conf.DrainingThreads once, launches that
// many worker goroutines that poll the queue, and blocks until all of them
// have returned (they return after Close is called). On the way out it closes
// StopWaitChannel so a pending Close can complete. It always returns nil, and
// panics if called twice because StopWaitChannel would be closed twice.
func (s *SQSQueueProcessor) Start() error {
	defer close(s.StopWaitChannel)
	workers := int(s.Conf.DrainingThreads.Get())
	ctx := context.Background()
	var running sync.WaitGroup
	// Register every worker up front, before any of them can call Done.
	running.Add(workers)
	for remaining := workers; remaining > 0; remaining-- {
		go func() {
			defer running.Done()
			s.process(ctx)
		}()
	}
	running.Wait()
	return nil
}

// shutdownCalled reports, without blocking, whether the stop channel has been
// closed (i.e. whether Close has been called).
func (s *SQSQueueProcessor) shutdownCalled() bool {
	stopped := false
	select {
	case <-s.stopFlag:
		stopped = true
	default:
	}
	return stopped
}

// cancelWhen blocks until either ctx is done or a receive from whenToCancel
// succeeds (for example because the channel was closed). It calls cancelFunc
// only in the second case; if ctx finishes first it returns without calling it.
func cancelWhen(ctx context.Context, cancelFunc func(), whenToCancel <-chan struct{}) {
	done := ctx.Done()
	select {
	case <-done:
		return
	case <-whenToCancel:
	}
	cancelFunc()
}

// process is the body of one worker goroutine. Until Close is called it
// repeatedly long-polls the queue, hands each received message to
// ProcessMessage, and deletes the messages that were processed without error.
// Messages whose processing or deletion fails are left on the queue.
//
// After a failed ReceiveMessage call the worker pauses for
// receiveErrorBackoff (or until Close is called) before polling again, so a
// persistently failing queue does not turn the loop into a busy spin.
func (s *SQSQueueProcessor) process(rootCtx context.Context) {
	// receiveErrorBackoff is how long a worker waits after a failed receive
	// before trying again.
	const receiveErrorBackoff = time.Second

	s.Log.Log("processing SQS")
	defer s.Log.Log("sqs process done")
	for !s.shutdownCalled() {
		receiveMessageInput := &sqs.ReceiveMessageInput{
			QueueUrl:              aws.String(s.Conf.QueueURL.Get()),
			WaitTimeSeconds:       aws.Int64(int64(math.Ceil(s.Conf.WaitTime.Get().Seconds()))),
			VisibilityTimeout:     aws.Int64(int64(math.Ceil(s.Conf.VisibilityTimeout.Get().Seconds()))),
			MessageAttributeNames: []*string{aws.String("All")},
		}
		startTime := time.Now()
		// Abort the (possibly long-polling) receive as soon as Close is called.
		recvCtx, onCan := context.WithCancel(rootCtx)
		go cancelWhen(recvCtx, onCan, s.stopFlag)
		req, msgOut := s.Sqs.ReceiveMessageRequest(receiveMessageInput)
		req.HTTPRequest = req.HTTPRequest.WithContext(recvCtx)
		err := req.Send()
		// Cancelling recvCtx also lets the cancelWhen goroutine above return.
		onCan()
		if s.shutdownCalled() {
			continue
		}

		if err != nil {
			s.Stats.IncC("sqs.ReceiveMessage.err", 1, 1)
			s.Log.LogCtx(recvCtx, "sqs_url", s.Conf.QueueURL.Get(), "err", err, "cannot receive messages")
			// If Send keeps failing quickly, retrying immediately would spin this
			// goroutine and flood logs and stats. Pause, but wake early on Close.
			backoff := time.NewTimer(receiveErrorBackoff)
			select {
			case <-s.stopFlag:
			case <-backoff.C:
			}
			backoff.Stop()
			continue
		}
		s.Stats.TimingDurationC("sqs.ReceiveMessage", time.Since(startTime), 1)
		s.Stats.IncC("msg_input", int64(len(msgOut.Messages)), 1)
		s.Log.DebugCtx(recvCtx, "msglen", len(msgOut.Messages), "Got some messages!")
		// The whole batch shares one deadline equal to the visibility timeout
		// that was requested for these messages. cancel is called after the loop.
		batchContext, cancel := context.WithTimeout(rootCtx, s.Conf.VisibilityTimeout.Get())
		for _, m := range msgOut.Messages {
			processStart := time.Now()
			processCtx := ctxlogsqs.ExtractContext(batchContext, m.MessageAttributes, s.Ch)
			if procErr := s.ProcessMessage(processCtx, m); procErr != nil {
				// Not deleted: the message stays on the queue.
				s.Stats.IncC("ProcessMessage.err", 1, 1)
				s.Log.LogCtx(processCtx, log.Err, procErr, "Unable to process SQS message")
				continue
			}
			s.Stats.TimingDurationC("ProcessMessage.time", time.Since(processStart), 1)

			deleteMsgInput := &sqs.DeleteMessageInput{
				QueueUrl:      aws.String(s.Conf.QueueURL.Get()),
				ReceiptHandle: m.ReceiptHandle,
			}
			startDelete := time.Now()
			// The DeleteMessage output is not needed; only the send error matters.
			deleteReq, _ := s.Sqs.DeleteMessageRequest(deleteMsgInput)
			if delErr := ctxlogaws.DoAWSSend(deleteReq, s.Log); delErr != nil {
				s.Stats.IncC("sqs.DeleteMessage.err", 1, 1)
				s.Log.LogCtx(processCtx, "err", delErr, "unable to remove a message from the queue")
				continue
			}
			s.Stats.TimingDurationC("sqs.DeleteMessage", time.Since(startDelete), 1)
		}
		cancel()
	}
}
