package consumer

import (
	"context"
	"encoding/base64"
	"fmt"
	"strconv"
	"sync"
	"time"

	"code.justin.tv/amzn/StarfruitMOLLEConsumer/util"
	serviceapi "code.justin.tv/amzn/StarfruitMOLLELambdaTwirp"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/client"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/sns"
	"github.com/aws/aws-sdk-go/service/sns/snsiface"
	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
	"github.com/gogo/protobuf/proto"
	"github.com/pkg/errors"
)

const (
	// queuePrefix and queueSuffix are the leading and trailing components
	// handed to util.BuildQueueName when deriving a queue name. queueSuffix is
	// the QueueSuffix used by every built-in entry of mollEEnvsForConsumerEnv
	// and is currently empty.
	queuePrefix = "video-weaver-"
	queueSuffix = ""

	// reportErrorTimeout bounds a single reportError call; it is applied with
	// context.WithTimeout on top of the caller's context.
	reportErrorTimeout = 1 * time.Second
)

// MOLLEConsumer is an interface for pulling MOLL-E messages from the service
// (they're stored in SQS).
//
// TryFetchAndHandle makes a single pass over the configured queues and
// returns the first receive or delete error it hits.
//
// StartFetchAndHandle starts background consumption and returns a channel of
// errors raised while fetching and handling messages. handle is invoked with
// each decoded event and a time derived from the message's SQS SentTimestamp
// attribute. isActive, when non-nil, is polled before each fetch; fetching
// pauses while it returns false.
type MOLLEConsumer interface {
	TryFetchAndHandle(ctx context.Context, handle func(serviceapi.WeaverEvent, time.Time) error) error
	StartFetchAndHandle(ctx context.Context, handle func(serviceapi.WeaverEvent, time.Time) error, isActive func() bool) <-chan error
}

// EnvConfig describes one MOLL-E deployment that a consumer reads from.
//
// NOTE: mollEEnvsForConsumerEnv builds these with positional literals, so the
// field order below must not change.
type EnvConfig struct {
	// AccountID is the AWS account that owns the SQS queue; it is also used
	// when building the SNS error-topic ARN.
	AccountID   string
	// Region is the AWS region used for the SQS/SNS clients and in the SNS
	// error-topic ARN.
	Region      string
	// Stage is appended to the SNS error topic name ("WeaverErrors-<Stage>").
	Stage       string
	// QueueSuffix is passed to util.BuildQueueName as the queue name suffix.
	QueueSuffix string
}

// mollEEnvsForConsumerEnv maps a weaver environment name ("production" or
// "staging") to the MOLL-E deployments it consumes from. Each entry is an
// EnvConfig written positionally as {AccountID, Region, Stage, QueueSuffix}.
// NewConsumerClient folds "development" into "staging" and "canary" into
// "production" before looking up this map.
var mollEEnvsForConsumerEnv = map[string][]EnvConfig{
	"production": {
		{"371649652584", "us-west-2", "prod", queueSuffix},
		{"804678302183", "us-east-2", "prod", queueSuffix},
	},
	"staging": {
		{"048085864192", "us-west-2", "beta", queueSuffix},
		{"468186331846", "us-east-2", "beta", queueSuffix},
	},
}

// envConsumer bundles everything needed to talk to a single MOLL-E
// deployment: the SQS client and queue URL that messages are received from
// and deleted on, and the SNS client and topic ARN that handler errors are
// published to.
type envConsumer struct {
	sqs sqsiface.SQSAPI
	sns snsiface.SNSAPI

	sqsQueueURL string
	snsTopicARN string
}

// molleConsumer is the MOLLEConsumer implementation returned by the
// NewConsumerClient* constructors. It holds one envConsumer per configured
// MOLL-E deployment.
type molleConsumer struct {
	consumers []envConsumer
}

// HandleError is an error that carries extra context about the SQS message
// that was being processed when the error occurred, so that whoever reads
// the error channel can log it. Msg is left at its zero value when the
// message could not be decoded.
type HandleError struct { // more info for handler to log
	Err     error
	Receipt string                 // receipt handle of the message in sqs
	Msg     serviceapi.WeaverEvent // msg that was being handled when err occurred
}

// Error implements the error interface by returning the message of the
// wrapped error. A HandleError without a wrapped error (e.g. the zero value)
// renders as "<nil>" instead of panicking on a nil dereference.
func (h HandleError) Error() string {
	if h.Err == nil {
		return "<nil>"
	}
	return h.Err.Error()
}

// batchSize is how many messages fetchAndHandleOnce asks SQS for in a single
// receive call (MaxNumberOfMessages); it exists for tuning.
// It should be in [1,10].
const batchSize = 10 // how many messages to fetch from a queue at once, for tuning

// NewConsumerClient builds a MOLL-E consumer for the given weaver cluster and
// environment using the built-in environment table.
//
// "development" is treated as the "staging" environment with the cluster name
// forced to "test"; "canary" is treated as "production". Any other name is
// looked up as-is, and an error is returned when no MOLL-E configuration
// exists for it. Queue URLs are resolved with context.Background().
func NewConsumerClient(clusterName, environment string, sess *session.Session) (MOLLEConsumer, error) {
	resolvedEnv := environment
	switch environment {
	case "development":
		clusterName, resolvedEnv = "test", "staging"
	case "canary":
		resolvedEnv = "production"
	}

	envConfigs, known := mollEEnvsForConsumerEnv[resolvedEnv]
	if !known {
		// Report the name the caller passed in, not the resolved one.
		return nil, errors.Errorf("don't have MOLL-E configuration for environment %v", environment)
	}

	return NewConsumerClientWithMollEEnvs(context.Background(), clusterName, resolvedEnv, envConfigs, sess)
}

// NewConsumerClientWithMollEEnvs builds a MOLL-E consumer for an explicit list
// of MOLL-E deployments, creating real SQS and SNS clients from sess for each
// of them. See NewConsumerClientWithMollEEnvsAndClients for the remaining
// parameters.
func NewConsumerClientWithMollEEnvs(ctx context.Context, clusterName, environment string, mollEEnvs []EnvConfig, sess *session.Session) (MOLLEConsumer, error) {
	return NewConsumerClientWithMollEEnvsAndClients(ctx, clusterName, environment, mollEEnvs, sess,
		func(provider client.ConfigProvider, overrides ...*aws.Config) (sqsiface.SQSAPI, snsiface.SNSAPI) {
			sqsClient := sqs.New(provider, overrides...)
			snsClient := sns.New(provider, overrides...)
			return sqsClient, snsClient
		})
}

// NewConsumerClientWithMollEEnvsAndClients builds a MOLL-E consumer for an
// explicit list of MOLL-E deployments, obtaining the SQS and SNS clients for
// each deployment from makeClients (which makes it possible to inject fakes).
//
// For every entry in mollEEnvs it creates region-specific clients, derives
// the queue name via util.BuildQueueName, resolves the queue URL in the
// owning account, and precomputes the ARN of the "WeaverErrors-<stage>" SNS
// topic. The first queue URL lookup that fails aborts construction and is
// returned as a wrapped error.
func NewConsumerClientWithMollEEnvsAndClients(ctx context.Context, clusterName, environment string, mollEEnvs []EnvConfig, sess *session.Session, makeClients func(client.ConfigProvider, ...*aws.Config) (sqsiface.SQSAPI, snsiface.SNSAPI)) (MOLLEConsumer, error) {
	var built []envConsumer

	for i := range mollEEnvs {
		target := mollEEnvs[i]

		sqsAPI, snsAPI := makeClients(sess, aws.NewConfig().WithRegion(target.Region))

		name := util.BuildQueueName(queuePrefix, environment, clusterName, target.QueueSuffix)
		lookup := &sqs.GetQueueUrlInput{
			QueueName:              aws.String(name),
			QueueOwnerAWSAccountId: aws.String(target.AccountID),
		}
		out, err := sqsAPI.GetQueueUrlWithContext(ctx, lookup)
		if err != nil {
			return nil, errors.Wrapf(err, "failed to get queue url for %v in account %v", name, target.AccountID)
		}

		topicARN := fmt.Sprintf("arn:aws:sns:%v:%v:WeaverErrors-%v",
			target.Region, target.AccountID, target.Stage)

		built = append(built, envConsumer{
			sqs:         sqsAPI,
			sns:         snsAPI,
			sqsQueueURL: aws.StringValue(out.QueueUrl),
			snsTopicARN: topicARN,
		})
	}

	return &molleConsumer{consumers: built}, nil
}

// TryFetchAndHandle does one quick pass of the message processing pipeline
// over every configured queue, to check that all the pieces can work.
//
// For each queue it asks SQS for at most one message without waiting. If a
// message comes back it is decoded, passed to handle and, on success, deleted
// from the queue. Decode and handler failures are deliberately ignored: the
// message is simply left on the queue. Only a failure to reach SQS (receive
// or delete) is returned, and it stops the pass immediately.
func (m *molleConsumer) TryFetchAndHandle(ctx context.Context, handle func(serviceapi.WeaverEvent, time.Time) error) error {
	for i := range m.consumers {
		consumer := m.consumers[i]

		out, recvErr := consumer.sqs.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{
			QueueUrl: aws.String(consumer.sqsQueueURL),
			// Only probing whether SQS is reachable: smallest batch, no long poll.
			MaxNumberOfMessages: aws.Int64(1),
			WaitTimeSeconds:     aws.Int64(0),
			AttributeNames:      []*string{aws.String(sqs.MessageSystemAttributeNameSentTimestamp)},
		})
		if recvErr != nil {
			return errors.Wrap(recvErr, "could not receive messages from sqs")
		}
		if len(out.Messages) < 1 {
			continue
		}

		event, sentAt, receipt, decodeErr := decodeMsg(out.Messages[0])
		if decodeErr != nil || handle(event, sentAt) != nil {
			// Undecodable or rejected by the handler: leave it on the queue.
			continue
		}

		// A failed delete is worth surfacing to the caller.
		if delErr := m.deleteMessages(ctx, consumer, []string{receipt}); delErr != nil {
			return delErr
		}
	}
	return nil
}

// StartFetchAndHandle starts background consumption of every queue that
// weaver is pulling messages from. Per queue it spawns a fetchAndHandle
// goroutine, which pulls messages and handles them by synchronously calling
// the provided `handle` function, plus a forwarding goroutine that relays
// that worker's errors to the caller.
//
// The returned channel contains errors (non-fatal and possibly fatal) that
// arise during message pulling/processing. The channel is closed once every
// worker has exited, which happens after ctx is cancelled or after a worker
// terminates (which cancels all the others). In that case the MOLL-E
// procedure will need to be restarted.
//
// isActive, when non-nil, is polled by the workers; fetching is paused while
// it returns false.
//
// The consumer will block on sending errors, so callers must ensure they are
// reading from the returned channel for the consumer to make progress.
func (m *molleConsumer) StartFetchAndHandle(ctx context.Context, handle func(serviceapi.WeaverEvent, time.Time) error, isActive func() bool) <-chan error {
	ctx, cancel := context.WithCancel(ctx)
	outChan := make(chan error, 10) // sizing?
	wg := &sync.WaitGroup{}
	for _, c := range m.consumers {
		wg.Add(1)
		go func(c envConsumer) {
			defer wg.Done()
			errChan := make(chan error)
			go m.fetchAndHandle(ctx, c, handle, isActive, errChan)
			for {
				select {
				case err, ok := <-errChan:
					if !ok { // worker exited: cancel context to signal other goroutines to stop
						cancel()
						return
					}
					if err != nil { // report error out to caller
						outChan <- err
					}
				case <-ctx.Done():
					// errChan is unbuffered, so the worker may be blocked
					// sending on it. Keep receiving (and discarding) until the
					// worker closes the channel; otherwise it could hang
					// forever and outChan would be closed while it is still
					// running.
					for range errChan {
					}
					return
				}
			}
		}(c)
	}

	go func() {
		defer close(outChan) // once all workers have exited, close outer errchan
		defer cancel()       // make sure the context is canceled even if no worker canceled it
		wg.Wait()
	}()
	return outChan
}

// fetchAndHandle fetches and processes (handles) messages from the queue of
// consumer c until ctx is done. Errors are reported via errChan, which is
// closed when the function returns.
//
// While isActive is non-nil and returns false, fetching is paused and
// isActive is re-checked once per second.
//
// Handler errors are additionally published back to MOLL-E by a helper
// goroutine fed through a buffered channel, so that slow error reporting does
// not stall message consumption. That goroutine is always waited for before
// errChan is closed, which guarantees nothing sends on a closed channel.
func (m *molleConsumer) fetchAndHandle(ctx context.Context, c envConsumer, handle func(serviceapi.WeaverEvent, time.Time) error, isActive func() bool, errChan chan<- error) {
	// Deferred calls run last-in-first-out: wg.Wait (registered below) runs
	// before this close, so the reporter goroutine is gone by the time errChan
	// is closed.
	defer close(errChan)

	upstreamErrChan := make(chan HandleError, 100)

	t := time.NewTicker(time.Second)
	defer t.Stop()

	// report errors to SNS without blocking the consumer goroutine
	wg := &sync.WaitGroup{}
	defer wg.Wait()
	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			select {
			case <-ctx.Done():
				return
			case handleErr := <-upstreamErrChan:
				err := m.reportError(ctx, handleErr.Msg, c, handleErr.Err)
				if err == nil {
					continue
				}
				reportFailure := HandleError{
					Err:     errors.Wrap(err, "couldn't report previous error back to MOLL-E"),
					Receipt: handleErr.Receipt,
					Msg:     handleErr.Msg,
				}
				// Nobody may be reading errChan any more once ctx is done, so
				// never block on it unconditionally.
				select {
				case errChan <- reportFailure:
				case <-ctx.Done():
					return
				}
			}
		}
	}()

	wait := newBackoffWaiter()

	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		// Idle (checking once per tick) while the caller says we are inactive.
		for isActive != nil && !isActive() {
			select {
			case <-ctx.Done():
				return
			case <-t.C:
			}
		}

		m.fetchAndHandleOnce(ctx, c, handle, errChan, upstreamErrChan, wait)
	}
}

// fetchAndHandleOnce performs a single receive/handle/delete cycle against
// the queue of consumer c.
//
// It long-polls SQS for up to batchSize messages, decodes each message and
// passes it to handle, and finally deletes every message that must not be
// redelivered: those handled successfully, those that cannot be decoded, and
// those whose handler error is classified as non-retryable by isNonRetryable.
//
// All errors are sent to errChan. Handler errors are also queued on
// upstreamErrChan so they can be reported back to the sender; if that channel
// is full the report is dropped and a note is sent to errChan instead. Sends
// on errChan are abandoned once ctx is done, because the reader may already
// have gone away by then.
//
// wait is used to back off after a failed receive and is reset after a
// successful one.
func (m *molleConsumer) fetchAndHandleOnce(ctx context.Context, c envConsumer, handle func(serviceapi.WeaverEvent, time.Time) error, errChan chan<- error, upstreamErrChan chan<- HandleError, wait *backoffWaiter) {
	const receiveTimeoutSeconds = 20

	// sendErr delivers err to errChan but gives up when the consumer is being
	// shut down, so this goroutine cannot block forever on an abandoned
	// channel.
	sendErr := func(err error) {
		select {
		case errChan <- err:
		case <-ctx.Done():
		}
	}

	// Allow one second on top of the SQS long-poll wait time.
	recvCtx, cancel := context.WithTimeout(ctx, (receiveTimeoutSeconds+1)*time.Second)
	defer cancel()
	req := sqs.ReceiveMessageInput{
		MaxNumberOfMessages: aws.Int64(batchSize),
		QueueUrl:            aws.String(c.sqsQueueURL),
		WaitTimeSeconds:     aws.Int64(receiveTimeoutSeconds),
		AttributeNames:      []*string{aws.String(sqs.MessageSystemAttributeNameSentTimestamp)},
	}

	resp, err := c.sqs.ReceiveMessageWithContext(recvCtx, &req)
	if err != nil {
		sendErr(errors.Wrapf(err, "couldn't fetch MOLL-E messages from %v", c.sqsQueueURL))

		// if we're misconfigured and somehow the Receive call
		// instantly fails, we don't want to busy-loop here and
		// ddos SQS.
		wait.Wait(recvCtx)
		return
	}
	wait.Reset()

	if len(resp.Messages) == 0 {
		return
	}

	consumed := make([]string, 0, len(resp.Messages))
	for _, raw := range resp.Messages {
		msg, sentTime, receipt, err := decodeMsg(raw)
		if err != nil {
			// The receipt handle does not depend on the body being decodable;
			// read it from the raw message so the poison message can really
			// be deleted.
			if receipt == "" && raw != nil {
				receipt = aws.StringValue(raw.ReceiptHandle)
			}
			sendErr(HandleError{
				Err:     errors.Wrap(err, "couldn't decode MOLL-E message"),
				Receipt: receipt,
			})
			// don't want to retry this, won't be able to decode it next
			// time either.
			if receipt != "" {
				consumed = append(consumed, receipt)
			}
			continue
		}

		err = handle(msg, sentTime)
		if err != nil {
			sendErr(HandleError{
				Err:     errors.Wrap(err, "MOLL-E handler returned error"),
				Receipt: receipt,
				Msg:     msg,
			})

			// let the user code opt out of retrying
			if isNonRetryable(err) {
				consumed = append(consumed, receipt)
			}

			// Since this is an application-level error, it's
			// most likely to be interesting to the sender of
			// the message, so we report it back upstream.
			report := HandleError{
				Err:     err,
				Receipt: receipt,
				Msg:     msg,
			}
			select {
			case upstreamErrChan <- report:
			default:
				sendErr(HandleError{
					Err:     errors.Wrap(err, "upstream error reporting channel full, dropping error"),
					Receipt: receipt,
					Msg:     msg,
				})
			}
			continue
		}
		consumed = append(consumed, receipt)
	}

	// NOTE(review): the delete shares the deadline of the receive call, so a
	// slow batch of handlers can leave it with little or no time — confirm
	// whether that is intended.
	err = m.deleteMessages(recvCtx, c, consumed)
	if err != nil {
		sendErr(errors.Wrapf(err, "failed to report successful handling of %d MOLL-E messages", len(consumed)))
	}
}

// deleteMessages marks a list of messages (by receipt handle) in the queue of
// consumer c as complete by deleting them with a single batch request.
//
// It returns nil when handles is empty (no request is made). An error is
// returned when the request itself fails, and also when SQS accepts the
// request but reports that individual deletions failed, since those messages
// will be redelivered.
//
// NOTE(review): SQS caps a delete batch at 10 entries; callers here pass at
// most batchSize handles — confirm if batchSize is ever raised.
func (m *molleConsumer) deleteMessages(ctx context.Context, c envConsumer, handles []string) error {
	if len(handles) == 0 {
		// that's unfortunate. Avoid making an empty request to sqs.
		return nil
	}

	entries := make([]*sqs.DeleteMessageBatchRequestEntry, 0, len(handles))
	for i, receipt := range handles {
		entries = append(entries, &sqs.DeleteMessageBatchRequestEntry{
			// Ids only need to be unique within the batch.
			Id:            aws.String(strconv.Itoa(i)),
			ReceiptHandle: aws.String(receipt),
		})
	}

	resp, err := c.sqs.DeleteMessageBatchWithContext(ctx, &sqs.DeleteMessageBatchInput{
		QueueUrl: aws.String(c.sqsQueueURL),
		Entries:  entries,
	})
	if err != nil {
		return err
	}

	// A batch call can succeed as a whole while individual entries fail.
	if resp != nil && len(resp.Failed) > 0 {
		if first := resp.Failed[0]; first != nil {
			return errors.Errorf("sqs failed to delete %d of %d messages (first failure: id %v, code %v: %v)",
				len(resp.Failed), len(handles),
				aws.StringValue(first.Id), aws.StringValue(first.Code), aws.StringValue(first.Message))
		}
		return errors.Errorf("sqs failed to delete %d of %d messages", len(resp.Failed), len(handles))
	}
	return nil
}

// reportError publishes payload, the error raised while handling msg, to the
// SNS error topic of consumer c. The report is a serviceapi.WeaverError
// carrying the request id of msg and the error text, protobuf-marshalled and
// then base64-encoded. The whole call is bounded by reportErrorTimeout.
//
// It returns a wrapped error if marshalling or publishing fails.
func (m *molleConsumer) reportError(ctx context.Context, msg serviceapi.WeaverEvent, c envConsumer, payload error) error {
	// TODO: If errors arrive faster than they can be published (or time
	// out), reporting may hold up whoever is waiting on it and so get in
	// the way of application logic.
	//
	// Something like a fixed-size pool of workers submitting reports, with
	// errors discarded while all workers are busy and their channel is
	// full, would avoid that.

	publishCtx, stop := context.WithTimeout(ctx, reportErrorTimeout)
	defer stop()

	report := serviceapi.WeaverError{
		RequestId: msg.RequestId,
		Error:     payload.Error(),
	}
	raw, marshalErr := proto.Marshal(&report)
	if marshalErr != nil {
		return errors.Wrap(marshalErr, "failed to marshal error report")
	}

	body := base64.StdEncoding.EncodeToString(raw)

	publish := &sns.PublishInput{
		TopicArn: aws.String(c.snsTopicARN),
		Message:  aws.String(body),
	}
	if _, pubErr := c.sns.PublishWithContext(publishCtx, publish); pubErr != nil {
		return errors.Wrapf(pubErr, "failed publish to sns topic %v (payload size: %v)", c.snsTopicARN, len(body))
	}

	return nil
}

// decodeMsg takes an sqs message and loads it into a WeaverEvent, basically
// undoing the process by which it was put into sqs: the body is base64
// decoded and then unmarshalled as a protobuf WeaverEvent.
//
// It returns, in order:
//   - the decoded event (zero value on error);
//   - the time derived from the message's SentTimestamp attribute, which is
//     parsed as milliseconds since the Unix epoch; when the attribute is
//     missing or unparsable this is the Unix epoch itself;
//   - the receipt handle of the message, for marking it as complete later.
//     It is returned even when decoding fails so that callers can still
//     delete a message that will never decode;
//   - an error if the message is nil or its body cannot be decoded.
func decodeMsg(input *sqs.Message) (serviceapi.WeaverEvent, time.Time, string, error) {
	if input == nil {
		return serviceapi.WeaverEvent{}, time.Time{}, "", errors.New("could not decode nil sqs message")
	}
	receipt := aws.StringValue(input.ReceiptHandle)

	decodedEvent, err := base64.StdEncoding.DecodeString(aws.StringValue(input.Body))
	if err != nil {
		return serviceapi.WeaverEvent{}, time.Time{}, receipt, errors.Wrap(err, "could not decode message from base64")
	}
	var msg serviceapi.WeaverEvent
	err = proto.Unmarshal(decodedEvent, &msg)
	if err != nil {
		return serviceapi.WeaverEvent{}, time.Time{}, receipt, errors.Wrap(err, "could not unmarshal message into WeaverEvent")
	}

	sentTimestampString := input.Attributes[sqs.MessageSystemAttributeNameSentTimestamp]
	sentTimestamp, err := strconv.ParseUint(aws.StringValue(sentTimestampString), 10, 64)
	if err != nil { // probably empty
		sentTimestamp = 0
	}
	// time.Unix takes seconds and *nanoseconds*, so the millisecond remainder
	// has to be scaled up.
	sentMillis := int64(sentTimestamp)
	sentTime := time.Unix(sentMillis/1000, (sentMillis%1000)*int64(time.Millisecond))

	return msg, sentTime, receipt, nil
}
