package servicelog

// copied from https://git.xarth.tv/video/warp/blob/26bf231573eeddc6ea69005e909b41816d76257f/internal/logging/cloudwatch.go

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
	"go.uber.org/zap"

	identifier "code.justin.tv/amzn/TwitchProcessIdentifier"
)

const (
	// Limits on a single batch handed to PutLogEvents: the total size in bytes
	// and the number of events. collectBatch stops filling a batch once either
	// limit would be reached.
	cwlBatchBytes = 1048576
	cwlBatchCount = 10000

	// Largest serialized log entry, in bytes, that Send will enqueue. Longer
	// entries are rejected by Send with an error. collectBatch also uses this
	// value as the worst-case size of the next event when deciding whether a
	// batch still has room.
	cwlEventBytes = 4096
)

// CloudwatchLogger sends batches of service logs to cloudwatch.
//
// Entries are queued by Send and shipped by Run, which is expected to be
// running in its own goroutine. Construct instances with New.
type CloudwatchLogger struct {
	// Embedded helper declared elsewhere in this package; New populates its pid.
	logCreator

	// CloudWatch Logs API client (despite the field name, not an AWS session).
	session *cloudwatchlogs.CloudWatchLogs
	// Name of the log group that events are written to.
	group   string
	// Name of the log stream that events are written to.
	stream  string
	// Queue of pending events: Send writes to it, Run's batch collection reads from it.
	buffer  chan *cloudwatchlogs.InputLogEvent
	// Guards the stream setup in Run so that it executes at most once.
	runOnce sync.Once

	// Optional logger for diagnostics; when nil, Run reports nothing.
	stdOutLogger *zap.Logger
	// Count of discarded log entries, updated atomically. Run resets it to
	// zero whenever it reports the count.
	droppedLogs  *uint64
}

// New builds a CloudwatchLogger that writes to the given log group and stream.
//
// The supplied session is copied before use, and the copy gets a Send handler
// that tags every request with the "x-amzn-logs-format: json/emf" header.
// pid is stored on the embedded logCreator. The returned logger does nothing
// until Run is started.
func New(sess *session.Session, group string, stream string, pid identifier.ProcessIdentifier) *CloudwatchLogger {
	// Work on a copy so the extra handler is attached to our own session value.
	emfSession := sess.Copy()
	emfSession.Handlers.Send.PushFront(func(req *request.Request) {
		req.HTTPRequest.Header.Set("x-amzn-logs-format", "json/emf")
	})

	logger := &CloudwatchLogger{
		logCreator:  logCreator{pid: pid},
		session:     cloudwatchlogs.New(emfSession),
		group:       group,
		stream:      stream,
		buffer:      make(chan *cloudwatchlogs.InputLogEvent, 1000),
		droppedLogs: new(uint64),
	}
	return logger
}

// SetStdoutLogger sets the zap logger that Run uses to report dropped entries
// and failed sends. While it is nil (the state after New), Run emits no
// diagnostics.
//
// NOTE(review): Run reads this field without synchronization, so this should
// presumably be called before Run is started — confirm against callers.
func (l *CloudwatchLogger) SetStdoutLogger(z *zap.Logger) {
	l.stdOutLogger = z
}

// Run sets up the log stream and then loops: collect a batch, report how many
// entries were dropped, send the batch. It blocks until something fails.
//
// It returns an error when the log stream cannot be prepared, when Run has
// already been invoked on this logger, or when ctx is canceled (in which case
// the error is ctx.Err()). A failed send is only logged to the stdout logger,
// if one is set; it does not stop the loop.
func (l *CloudwatchLogger) Run(ctx context.Context) error {
	var (
		input    *cloudwatchlogs.PutLogEventsInput
		setupErr error
	)

	// Stream setup runs once per logger. On any later call the closure is
	// skipped and input stays nil, which is how a repeated Run is detected.
	l.runOnce.Do(func() {
		var token *string
		token, setupErr = l.createStream(ctx)

		input = &cloudwatchlogs.PutLogEventsInput{
			LogGroupName:  aws.String(l.group),
			LogStreamName: aws.String(l.stream),
			SequenceToken: token,
		}
	})

	switch {
	case setupErr != nil:
		return fmt.Errorf("failed to create log stream: %w", setupErr)
	case input == nil:
		return fmt.Errorf("Run() executed more than once")
	}

	for {
		// Non-blocking cancellation check before starting another batch.
		select {
		case <-ctx.Done():
			// TODO: try final flush before stopping
			return ctx.Err()
		default:
		}

		collectErr := l.collectBatch(ctx, input)

		if l.stdOutLogger != nil {
			if dropped := atomic.SwapUint64(l.droppedLogs, 0); dropped > 0 {
				l.stdOutLogger.Warn("dropped service logs since last send", zap.Uint64("num", dropped))
			}
		}

		// Whatever was collected is still sent, even when collection ended
		// with an error such as a canceled context.
		if len(input.LogEvents) > 0 {
			// sendBatch deliberately ignores ctx so that this request is not
			// canceled during shutdown.
			output, sendErr := l.sendBatch(input)
			if sendErr != nil && l.stdOutLogger != nil {
				l.stdOutLogger.Error("sendBatch failed", zap.Error(sendErr))
			}

			// Truncate in place so the backing array is reused by the next batch.
			input.LogEvents = input.LogEvents[:0]
			if output != nil {
				input.SequenceToken = output.NextSequenceToken
			}
		}

		if collectErr != nil {
			return collectErr
		}
	}
}

// collectBatch appends buffered events to input.LogEvents until the batch is
// full, 5 seconds have passed since the first event arrived, or ctx is done.
//
// It blocks until at least one event is available or ctx is canceled. Events
// gathered before a cancellation stay in input.LogEvents so the caller can
// still send them; in that case ctx.Err() is returned. Every other exit
// returns nil.
//
// An event whose timestamp is older than the last event already in the batch
// is discarded, because a batch has to be in chronological order. Each such
// discard is added to the dropped-log counter so the loss is visible.
func (l *CloudwatchLogger) collectBatch(ctx context.Context, input *cloudwatchlogs.PutLogEventsInput) error {
	// PutLogEvents sizes a batch as the message bytes plus 26 bytes per event.
	const eventOverheadBytes = 26

	batchBytes := 0

	// Wait for the first entry; the batch timer only starts once there is
	// something to send.
	select {
	case <-ctx.Done():
		return ctx.Err()
	case event := <-l.buffer:
		input.LogEvents = append(input.LogEvents, event)
		batchBytes += len(*event.Message) + eventOverheadBytes
	}

	timer := time.NewTimer(5 * time.Second)
	defer timer.Stop()

	for {
		// Keep going only while a maximum-size event, including its per-event
		// overhead, is still guaranteed to fit. Leaving the overhead out of
		// this check would let a batch exceed the byte limit by a few bytes.
		if batchBytes+cwlEventBytes+eventOverheadBytes > cwlBatchBytes || len(input.LogEvents) >= cwlBatchCount {
			// Batch is full
			break
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case event := <-l.buffer:
			last := input.LogEvents[len(input.LogEvents)-1]
			if *event.Timestamp < *last.Timestamp {
				// Out of order relative to the batch: drop it, but count the
				// drop. The nil check keeps a logger that was built without
				// New from panicking here.
				if l.droppedLogs != nil {
					atomic.AddUint64(l.droppedLogs, 1)
				}
				continue
			}

			input.LogEvents = append(input.LogEvents, event)
			batchBytes += len(*event.Message) + eventOverheadBytes
		case <-timer.C:
			return nil
		}
	}

	return nil
}

// createStream makes sure the configured log stream exists in the log group.
//
// When a stream with exactly the configured name is already present, its
// upload sequence token is returned. Otherwise the stream is created and a
// nil token is returned. Errors from either API call are wrapped and returned.
func (l *CloudwatchLogger) createStream(ctx context.Context) (*string, error) {
	// The lookup is by prefix, so ask for a single result and verify the name below.
	describeInput := &cloudwatchlogs.DescribeLogStreamsInput{
		LogGroupName:        aws.String(l.group),
		LogStreamNamePrefix: aws.String(l.stream),
		Limit:               aws.Int64(1),
	}
	found, err := l.session.DescribeLogStreamsWithContext(aws.Context(ctx), describeInput)
	if err != nil {
		return nil, fmt.Errorf("failed to describe log streams: %w", err)
	}

	// A prefix match is not enough: only an exact name match counts as existing.
	if streams := found.LogStreams; len(streams) > 0 {
		if first := streams[0]; *first.LogStreamName == l.stream {
			return first.UploadSequenceToken, nil
		}
	}

	// No exact match, so create the stream under the configured name.
	createInput := &cloudwatchlogs.CreateLogStreamInput{
		LogGroupName:  aws.String(l.group),
		LogStreamName: aws.String(l.stream),
	}
	if _, err = l.session.CreateLogStreamWithContext(aws.Context(ctx), createInput); err != nil {
		return nil, fmt.Errorf("failed to create log stream: %w", err)
	}

	return nil, nil
}

// sendBatch uploads input with PutLogEvents using its own 10-second timeout.
// It builds that timeout on context.Background() rather than a caller context,
// so the request is not tied to the lifetime of the context passed to Run.
//
// On success it returns the API output. It returns a nil output and an error
// when the call fails or when the service reports that part of the batch was
// rejected (expired, too new, or too old events).
//
// input.SequenceToken is updated on the failure paths where a usable token is
// known: to the expected token when the service answers with an invalid
// sequence token error, and to the next token when the call succeeded but
// some events were rejected. The caller gets no output in those cases and
// would otherwise keep using a stale token.
func (l *CloudwatchLogger) sendBatch(input *cloudwatchlogs.PutLogEventsInput) (output *cloudwatchlogs.PutLogEventsOutput, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	output, err = l.session.PutLogEventsWithContext(ctx, input)
	if err != nil {
		var tokenError *cloudwatchlogs.InvalidSequenceTokenException
		if errors.As(err, &tokenError) {
			// Remember the token the service expects so the next batch can succeed.
			input.SequenceToken = tokenError.ExpectedSequenceToken
		}

		return nil, fmt.Errorf("failed to put logs: %w", err)
	}

	if rejected := output.RejectedLogEventsInfo; rejected != nil {
		var reason string
		switch {
		case rejected.ExpiredLogEventEndIndex != nil:
			reason = "rejected expired events"
		case rejected.TooNewLogEventStartIndex != nil:
			reason = "rejected too new events"
		case rejected.TooOldLogEventEndIndex != nil:
			reason = "rejected too old events"
		}

		if reason != "" {
			// The request itself was accepted, so the returned token is the
			// one to use for the next call.
			input.SequenceToken = output.NextSequenceToken
			return nil, errors.New(reason)
		}
	}

	return output, nil
}

// Send serializes entry to JSON and queues it for the next batch.
//
// It returns an error when parseMetrics rejects the entry, when JSON encoding
// fails, or when the encoded entry is empty or longer than cwlEventBytes.
// Send never blocks: if the queue is full the entry is discarded, the dropped
// counter is incremented, and nil is returned.
func (l *CloudwatchLogger) Send(entry MetricLogger) error {
	if err := parseMetrics(entry); err != nil {
		return err
	}

	payload, err := json.Marshal(entry)
	if err != nil {
		return err
	}

	switch size := len(payload); {
	case size == 0:
		return fmt.Errorf("empty write")
	case size > cwlEventBytes:
		return fmt.Errorf("log line too long: %d", size)
	}

	// The event is stamped with the current time in Unix milliseconds.
	event := &cloudwatchlogs.InputLogEvent{
		Message:   aws.String(string(payload)),
		Timestamp: aws.Int64(aws.TimeUnixMilli(time.Now())),
	}

	// Non-blocking enqueue: a full buffer means the entry is dropped, not delayed.
	select {
	case l.buffer <- event:
	default:
		atomic.AddUint64(l.droppedLogs, 1)
	}
	return nil
}
