package async

import (
	"bytes"
	"context"
	"encoding/base64"
	"io"
	"strings"
	"time"

	"code.justin.tv/feeds/distconf"
	"code.justin.tv/feeds/errors"
	"code.justin.tv/feeds/graphdb/cmd/graphdb/internal/interngraphdb"
	"code.justin.tv/feeds/graphdb/proto/graphdb"
	"code.justin.tv/feeds/log"
	"code.justin.tv/hygienic/statsdsender"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/cep21/circuit"
	"github.com/golang/protobuf/proto"
	"golang.org/x/sync/errgroup"
)

// QueueConfig holds the distconf-backed settings for the async SQS queue.  The fields are populated by Load and are
// read on every queue operation, so each request sees the value distconf currently reports.
//
// Queue is the URL of the SQS queue.  VisibilityTimeout and WaitTime are passed (in whole seconds) as the visibility
// timeout and wait time of each receive request.
type QueueConfig struct {
	Queue             *distconf.Str
	VisibilityTimeout *distconf.Duration
	WaitTime          *distconf.Duration
}

// maximumSQSSize is the length, in bytes, of the encoded message body at or above which SendMessage tries to split a
// message into smaller ones before sending.  The value leaves some room for the message attributes.
// NOTE(review): presumably derived from the SQS per-message size limit -- confirm the intended headroom.
const maximumSQSSize = 122144

// Load populates the config from distconf.  It returns an error if the queue URL
// (graphdb.async_queue.queue_url) is empty, in which case the remaining fields are left unset.  The visibility
// timeout defaults to 30 minutes and the wait time to 10 seconds.
func (q *QueueConfig) Load(d *distconf.Distconf) error {
	q.Queue = d.Str("graphdb.async_queue.queue_url", "")
	if q.Queue.Get() == "" {
		// Name the exact key that was read so operators know which setting to fix.
		return errors.New("Unable to find queue URL from graphdb.async_queue.queue_url")
	}
	q.VisibilityTimeout = d.Duration("graphdb.async_queue.visibility_timeout", time.Minute*30)
	q.WaitTime = d.Duration("graphdb.async_queue.wait_time", time.Second*10)
	return nil
}

// QueueCircuits holds the hystrix-style circuits for all the queue operations we care about.  Each SQS request is
// executed through its matching circuit: Receive for ReceiveMessages, Delete for DeleteMessage and Send for
// SendMessage.
type QueueCircuits struct {
	Receive *circuit.Circuit
	Delete  *circuit.Circuit
	Send    *circuit.Circuit
}

// Queue abstracts the SQS queue.  We use this to send, receive and delete messages from the queue.
//
// SQS is the client used to build every request.  Stats receives the counters emitted by SendMessage.  QueueConfig
// supplies the queue URL and the receive timeouts.  Log is used to report messages whose body cannot be decoded.
// Circuits wraps each SQS call.  Storage is not referenced by the queue methods in this part of the file.
type Queue struct {
	SQS         *sqs.SQS
	Stats       *statsdsender.ErrorlessStatSender `nilcheck:"nodepth"`
	QueueConfig *QueueConfig
	Log         log.Logger
	Storage     graphdb.GraphDB
	Circuits    QueueCircuits
}

// Message contains both the SQS message we originally received and the protobuf that was decoded from its body.
// Keeping the SQS message around lets DeleteMessage use its receipt handle to remove it from the queue once we are
// done with it.
//
// msg is the raw SQS message; Msg is the decoded payload.  Msg is left nil on the placeholder Message that
// ReceiveMessages builds in order to delete a message whose body could not be decoded.
type Message struct {
	msg *sqs.Message
	Msg *interngraphdb.AsyncRequestQueueMessage
}

// ReceiveMessages issues one receive request against the queue and returns every message whose body could be
// decoded.  A message with an undecodable body is logged, deleted from the queue on a best-effort basis (a failed
// delete is only logged) and left out of the result.  An error is returned only when the receive request itself
// fails.
func (a *Queue) ReceiveMessages(ctx context.Context) ([]*Message, error) {
	cfg := a.QueueConfig
	queueURL := cfg.Queue.Get()
	// Both durations are truncated to whole seconds and then bumped by one second.
	visibilitySecs := int64(cfg.VisibilityTimeout.Get().Seconds()) + 1
	waitSecs := int64(cfg.WaitTime.Get().Seconds()) + 1

	req, resp := a.SQS.ReceiveMessageRequest(&sqs.ReceiveMessageInput{
		QueueUrl:              aws.String(queueURL),
		AttributeNames:        []*string{aws.String("All")},
		MessageAttributeNames: []*string{aws.String("All")},
		VisibilityTimeout:     aws.Int64(visibilitySecs),
		WaitTimeSeconds:       aws.Int64(waitSecs),
	})
	if err := a.Circuits.Receive.Run(ctx, func(runCtx context.Context) error {
		req.SetContext(runCtx)
		return req.Send()
	}); err != nil {
		return nil, err
	}

	var decoded []*Message
	for _, raw := range resp.Messages {
		payload := new(interngraphdb.AsyncRequestQueueMessage)
		decodeErr := decode(strings.NewReader(*raw.Body), payload)
		if decodeErr == nil {
			decoded = append(decoded, &Message{msg: raw, Msg: payload})
			continue
		}

		// The body is garbage: drop it from the queue so it is not redelivered.
		a.Log.Log("err", decodeErr, "invalid sqs message body")
		if delErr := a.DeleteMessage(ctx, &Message{msg: raw}); delErr != nil {
			a.Log.Log("err", delErr, "unable to delete invalid message")
		}
	}
	return decoded, nil
}

// split divides the Requests of a message across two new messages: the first gets the leading half, the second gets
// the rest (including the extra element when the count is odd).  We have to do this because SQS limits how big any
// single message can be.  Only Requests is carried over; the returned slices alias the input's backing array.
func split(in *interngraphdb.AsyncRequestQueueMessage) (*interngraphdb.AsyncRequestQueueMessage, *interngraphdb.AsyncRequestQueueMessage) {
	all := in.Requests
	half := len(all) / 2
	front := &interngraphdb.AsyncRequestQueueMessage{Requests: all[:half]}
	back := &interngraphdb.AsyncRequestQueueMessage{Requests: all[half:]}
	return front, back
}

// encode marshals the protobuf message we are going to send to SQS and writes it to out as URL-safe base64 text.
// We have to base64 encode the message because SQS is picky about special characters in the body.
func encode(in proto.Message, out io.Writer) error {
	raw, err := proto.Marshal(in)
	if err != nil {
		return err
	}
	enc := base64.NewEncoder(base64.URLEncoding, out)
	if _, err := enc.Write(raw); err != nil {
		return err
	}
	// Close flushes any pending output still buffered in the encoder.
	return enc.Close()
}

// decode reverses `encode`: it reads URL-safe base64 text from in, then unmarshals the resulting bytes into out.
func decode(in io.Reader, out proto.Message) error {
	var raw bytes.Buffer
	if _, err := raw.ReadFrom(base64.NewDecoder(base64.URLEncoding, in)); err != nil {
		return err
	}
	return proto.Unmarshal(raw.Bytes(), out)
}

// SendMessage adds the async message to the SQS queue, to be delivered after delay (truncated to whole seconds).
//
// If the encoded message is too big and it carries more than three requests, it is split: the requests are halved
// into two messages and any count retries are moved to a third, and all of them are sent concurrently through
// recursive calls (so a half that is still too big is split again).  The first error from any part is returned.
//
// The node_recounts, edge_recounts and async_requests counters are incremented only after the send succeeds, so
// they reflect work that was actually queued.
func (a *Queue) SendMessage(ctx context.Context, m *interngraphdb.AsyncRequestQueueMessage, delay time.Duration) error {
	buf := bytes.Buffer{}
	if err := encode(m, &buf); err != nil {
		return err
	}
	b := buf.String()

	// Note: We need to update this to split correctly if the Requests field of the protobuf changes.  Probably want
	// to make the contents of the protobuf a oneof.
	if len(b) >= maximumSQSSize && len(m.Requests) > 3 {
		// Message is too big to fit inside SQS.  Split it into parts
		firstHalf, secondHalf := split(m)

		// Send all three at once.  egCtx is canceled as soon as one of the sends fails.
		eg, egCtx := errgroup.WithContext(ctx)
		// TODO: Check panics
		eg.Go(func() error {
			return a.SendMessage(egCtx, firstHalf, delay)
		})
		eg.Go(func() error {
			return a.SendMessage(egCtx, secondHalf, delay)
		})
		eg.Go(func() error {
			// split only carries Requests, so the retries travel in their own message.
			leftOver := &interngraphdb.AsyncRequestQueueMessage{
				CountRetries:     m.CountRetries,
				CountNodeRetries: m.CountNodeRetries,
			}
			if len(leftOver.CountRetries) == 0 && len(leftOver.CountNodeRetries) == 0 {
				return nil
			}
			return a.SendMessage(egCtx, leftOver, delay)
		})
		return eg.Wait()
	}

	// At this point, message should be small enough to add to SQS
	in := &sqs.SendMessageInput{
		QueueUrl:     aws.String(a.QueueConfig.Queue.Get()),
		MessageBody:  aws.String(b),
		DelaySeconds: aws.Int64(int64(delay.Seconds())),
	}
	// TODO: Log out message id somehow (it is available on the request output)
	req, _ := a.SQS.SendMessageRequest(in)
	err := a.Circuits.Send.Run(ctx, func(ctx context.Context) error {
		req.SetContext(ctx)
		return req.Send()
	})
	if err != nil {
		return err
	}

	// Only count work that SQS actually accepted.
	a.Stats.IncC("node_recounts", int64(len(m.GetCountNodeRetries())), .25)
	a.Stats.IncC("edge_recounts", int64(len(m.GetCountRetries())), .25)
	a.Stats.IncC("async_requests", int64(len(m.GetRequests())), .25)
	return nil
}

// DeleteMessage removes a previously received message from the queue, identified by the receipt handle of the SQS
// message it wraps.  The request is executed through the Delete circuit and its error, if any, is returned.
func (a *Queue) DeleteMessage(ctx context.Context, msg *Message) error {
	input := &sqs.DeleteMessageInput{
		QueueUrl:      aws.String(a.QueueConfig.Queue.Get()),
		ReceiptHandle: msg.msg.ReceiptHandle,
	}
	// The request output is not needed; only success or failure matters.
	req, _ := a.SQS.DeleteMessageRequest(input)
	return a.Circuits.Delete.Run(ctx, func(runCtx context.Context) error {
		req.SetContext(runCtx)
		return req.Send()
	})
}
