package cwlogevent

import (
	"context"
	"time"

	"crypto/rand"
	"encoding/base64"
	"fmt"
	"os"
	"regexp"

	"strings"

	"sort"

	"code.justin.tv/feeds/distconf"
	"code.justin.tv/feeds/errors"
	"code.justin.tv/hygienic/messagebatch"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
	"github.com/cep21/circuit"
)

// Config configures log events to cloudwatch.
//
// Set Prefix, then call Load to read the remaining fields from distconf.
type Config struct {
	// Prefix is prepended to every distconf key read by Load, and Load also
	// copies it into the embedded messagebatch.Config's Prefix.
	Prefix string
	// Config is the embedded batching configuration (presumably batch size /
	// flush settings — see messagebatch). Load delegates to its Load method and
	// CloudwatchLogBatcher.Setup hands its address to the embedded Batcher.
	messagebatch.Config
	// LogGroupName is read from "<Prefix>log_group_name"; Load returns an error
	// when it is empty.
	LogGroupName      *distconf.Str
	// CloudwatchTimeout is read from "<Prefix>cloudwatch_timeout" (default 3s).
	// It bounds the CloudWatch calls made by Setup and by each sendEvents batch.
	CloudwatchTimeout *distconf.Duration
}

// Load reads the configuration from distconf, using c.Prefix as the prefix of
// every key.
//
// It first propagates c.Prefix to the embedded messagebatch.Config. It then
// reads "<prefix>log_group_name", which is required: an error is returned when
// it is empty, and CloudwatchTimeout is left unset. Next it reads
// "<prefix>cloudwatch_timeout" (default 3s). Finally it returns whatever the
// embedded Config's Load returns.
func (c *Config) Load(d *distconf.Distconf) error {
	keyFor := func(suffix string) string {
		return c.Prefix + suffix
	}

	c.Config.Prefix = c.Prefix

	c.LogGroupName = d.Str(keyFor("log_group_name"), "")
	if group := c.LogGroupName.Get(); group == "" {
		return errors.New("expected a key for log group name")
	}

	c.CloudwatchTimeout = d.Duration(keyFor("cloudwatch_timeout"), 3*time.Second)
	return c.Config.Load(d)
}

// CloudwatchLogBatcher adds cloudwatch events in batch.
//
// Callers populate Config (non-nil, with LogGroupName and CloudwatchTimeout
// set, e.g. via Config.Load), Client and Circuit. LogStreamName is optional.
// Call Setup before Start(). Event hands messages to the embedded Batcher.
// Setup installs sendEvents as the Batcher's SendEvents hook, which
// presumably is how the Batcher flushes batches to PutLogEvents.
type CloudwatchLogBatcher struct {
	// Batcher performs the batching; Setup points it at Config.Config and
	// installs sendEvents as its SendEvents hook.
	messagebatch.Batcher
	// Config supplies the log group name, the per-call CloudWatch timeout and
	// the embedded batching configuration.
	Config        *Config
	// LogStreamName is the stream events are written to. When empty, Setup
	// fills it with CreateLogStreamName(""). Setup always issues CreateLogStream
	// for it.
	LogStreamName string
	// Client issues the CreateLogStream, PutLogEvents and DescribeLogStreams
	// requests.
	Client        *cloudwatchlogs.CloudWatchLogs
	// Circuit wraps every CloudWatch request this batcher sends.
	Circuit       *circuit.Circuit
	// sequenceToken is sent with the next PutLogEvents call. It is the
	// NextSequenceToken of the last successful upload, or the stream's
	// UploadSequenceToken after a resync in sendEvents. No lock guards it.
	// NOTE(review): presumably the Batcher calls sendEvents serially — confirm.
	sequenceToken *string
}

// event is the value stored in the Batcher for each call to Event: the message
// text and its timestamp. Event substitutes time.Now() for a zero time, so when
// is never zero here.
type event struct {
	msg  string
	when time.Time
}

// Event sends a string event to cloudwatch logs by handing it to the embedded
// Batcher, timestamped with when. A zero when is replaced by the current time.
func (a *CloudwatchLogBatcher) Event(msg string, when time.Time) {
	stamp := when
	if stamp.IsZero() {
		stamp = time.Now()
	}
	a.Batcher.Event(event{when: stamp, msg: msg})
}

// Setup should be called before Start().
//
// It does the following, in order:
//   - points the embedded Batcher at a.Config.Config;
//   - generates a log stream name when LogStreamName is empty;
//   - creates that stream in the configured log group, through a.Circuit,
//     bounded by Config.CloudwatchTimeout;
//   - installs sendEvents as the Batcher's SendEvents hook;
//   - delegates to Batcher.Setup.
//
// A ResourceAlreadyExistsException from CreateLogStream is treated as success.
// It occurs when a caller supplies a LogStreamName that already exists (e.g. a
// stable name reused across restarts), or when an earlier Setup created the
// stream but gave up before seeing the response. The stream is usable in both
// cases. If the cached (nil) sequence token turns out to be out of sync,
// sendEvents resynchronizes it on the first failed upload.
func (a *CloudwatchLogBatcher) Setup() error {
	a.Batcher.Config = &a.Config.Config
	if a.LogStreamName == "" {
		a.LogStreamName = CreateLogStreamName("")
	}
	req, _ := a.Client.CreateLogStreamRequest(&cloudwatchlogs.CreateLogStreamInput{
		LogGroupName:  aws.String(a.Config.LogGroupName.Get()),
		LogStreamName: &a.LogStreamName,
	})
	ctx, cancel := context.WithTimeout(context.Background(), a.Config.CloudwatchTimeout.Get())
	defer cancel()
	err := a.Circuit.Run(ctx, func(ctx context.Context) error {
		req.SetContext(ctx)
		sendErr := req.Send()
		// An existing stream is not a failure. Swallow it here so the circuit
		// does not count it toward opening.
		if sendErr != nil && strings.Contains(sendErr.Error(), "ResourceAlreadyExistsException") {
			return nil
		}
		return sendErr
	})
	if err != nil {
		return err
	}
	a.SendEvents = a.sendEvents
	return a.Batcher.Setup()
}

// filterRegex matches each run of characters outside [a-zA-Z0-9_]. Such runs
// are collapsed to a single underscore when building log stream names.
var filterRegex = regexp.MustCompile("[^a-zA-Z0-9_]+")

// CreateLogStreamName helps create a random log stream name.
//
// The name is built from these pieces, in order:
//   - prefix;
//   - "_" + hostname, when os.Hostname succeeds;
//   - the current Unix time in nanoseconds;
//   - "_" + base64 of 8 random bytes, when crypto/rand succeeds.
//
// Every run of characters outside [a-zA-Z0-9_] is then replaced with "_", and
// the result is truncated to 500 bytes.
//
// CloudWatch requires log stream names to be unique within the log group and
// 1-512 characters long, and forbids ':' (colon) and '*' (asterisk); the
// sanitizing and truncation above keep the result within those rules.
func CreateLogStreamName(prefix string) string {
	const maxLen = 500

	name := prefix
	if host, hostErr := os.Hostname(); hostErr == nil {
		name = name + "_" + host
	}
	name = name + fmt.Sprint(time.Now().UnixNano())

	nonce := make([]byte, 8)
	if _, randErr := rand.Read(nonce); randErr == nil {
		name = name + "_" + base64.StdEncoding.EncodeToString(nonce)
	}

	sanitized := filterRegex.ReplaceAllString(name, "_")
	if len(sanitized) <= maxLen {
		return sanitized
	}
	return sanitized[:maxLen]
}

// intVal dereferences a, treating a nil pointer as 0.
func intVal(a *int64) int64 {
	if a != nil {
		return *a
	}
	return 0
}

// logResponse logs, through the Batcher's logger, any events that CloudWatch
// rejected in a PutLogEvents response. It does nothing when resp or its
// RejectedLogEventsInfo is nil, or when none of the rejection indices is set.
//
// Presence (non-nil) of an index is what signals a rejection, not its value.
// TooNewLogEventStartIndex == 0, for example, means every event in the batch
// was too new. The previous "sum of indices == 0" test silently dropped that
// case.
// NOTE(review): assumes CloudWatch only populates the indices that apply —
// confirm against the PutLogEvents API reference.
//
// Each log key is paired with the index of the same name. They were
// previously attached to the wrong values.
func (a *CloudwatchLogBatcher) logResponse(resp *cloudwatchlogs.PutLogEventsOutput) {
	if resp == nil || resp.RejectedLogEventsInfo == nil {
		return
	}
	info := resp.RejectedLogEventsInfo
	if info.ExpiredLogEventEndIndex == nil &&
		info.TooNewLogEventStartIndex == nil &&
		info.TooOldLogEventEndIndex == nil {
		return
	}
	a.Log.Log("err", errors.New("some log events lost"),
		"expired_end", intVal(info.ExpiredLogEventEndIndex),
		"too_new_start", intVal(info.TooNewLogEventStartIndex),
		"too_old_event", intVal(info.TooOldLogEventEndIndex))
}

// toCloudwatch converts batched events into CloudWatch InputLogEvents sorted by
// timestamp.
//
// Every element of events must be an event value; the type assertion panics
// otherwise. Timestamps are milliseconds since the Unix epoch. The sort is
// stable, so events that share a millisecond keep their original batch order.
func toCloudwatch(events []interface{}) []*cloudwatchlogs.InputLogEvent {
	out := make([]*cloudwatchlogs.InputLogEvent, 0, len(events))
	for _, raw := range events {
		ev := raw.(event)
		millis := ev.when.UnixNano() / int64(time.Millisecond)
		out = append(out, &cloudwatchlogs.InputLogEvent{
			Message:   aws.String(ev.msg),
			Timestamp: aws.Int64(millis),
		})
	}
	// Messages must be sent to CloudWatch in chronological order.
	earlier := func(i, j int) bool {
		return *out[i].Timestamp < *out[j].Timestamp
	}
	sort.SliceStable(out, earlier)
	return out
}

// attempt sends a single PutLogEvents call for events to the configured log
// group and stream, using the cached sequence token, through a.Circuit with
// ctx.
//
// On success it replaces the cached token with the response's
// NextSequenceToken and reports any rejected events via logResponse. On
// failure it returns the error and leaves the cached token unchanged.
func (a *CloudwatchLogBatcher) attempt(ctx context.Context, events []interface{}) error {
	input := &cloudwatchlogs.PutLogEventsInput{
		LogGroupName:  aws.String(a.Config.LogGroupName.Get()),
		LogStreamName: &a.LogStreamName,
		SequenceToken: a.sequenceToken,
		LogEvents:     toCloudwatch(events),
	}
	req, out := a.Client.PutLogEventsRequest(input)
	send := func(runCtx context.Context) error {
		req.SetContext(runCtx)
		return req.Send()
	}
	if err := a.Circuit.Run(ctx, send); err != nil {
		return err
	}
	a.sequenceToken = out.NextSequenceToken
	a.logResponse(out)
	return nil
}

// sendEvents uploads one batch of events to CloudWatch Logs. Setup installs it
// as the embedded Batcher's SendEvents hook.
//
// All CloudWatch calls for the batch share one context bounded by
// Config.CloudwatchTimeout. If the first PutLogEvents attempt fails, the batch
// is retried exactly once.
//
// Before that retry, if the error says the cached sequence token is out of sync
// (InvalidSequenceTokenException or DataAlreadyAcceptedException), the stream's
// current UploadSequenceToken is fetched with DescribeLogStreams:
//   - if the lookup itself fails, its error is returned;
//   - if the stream is not in the result, the original upload error is returned.
//
// An empty batch returns nil without contacting CloudWatch.
func (a *CloudwatchLogBatcher) sendEvents(events []interface{}) error {
	if len(events) == 0 {
		// PutLogEvents requires at least one event. An empty batch would only
		// produce a validation error, twice because of the retry below.
		return nil
	}
	ctx, cancel := context.WithTimeout(context.Background(), a.Config.CloudwatchTimeout.Get())
	defer cancel()
	err := a.attempt(ctx, events)
	if err == nil {
		return nil
	}
	msg := err.Error()
	if strings.Contains(msg, "InvalidSequenceTokenException") || strings.Contains(msg, "DataAlreadyAcceptedException") {
		// The cached token no longer matches the stream's; resync it once.
		// NOTE(review): after DataAlreadyAcceptedException CloudWatch already
		// holds this batch, so the retry below may store it twice — confirm
		// this is acceptable.
		req, resp := a.Client.DescribeLogStreamsRequest(&cloudwatchlogs.DescribeLogStreamsInput{
			LogGroupName:        aws.String(a.Config.LogGroupName.Get()),
			LogStreamNamePrefix: &a.LogStreamName,
		})
		if err2 := a.Circuit.Run(ctx, func(ctx context.Context) error {
			req.SetContext(ctx)
			return req.Send()
		}); err2 != nil {
			return err2
		}
		// A prefix query also returns streams whose names merely start with
		// ours. Pick the exact name rather than assuming a single result.
		// NOTE(review): assumes the default LogStreamName ordering puts the
		// exact name on the first page — confirm against DescribeLogStreams.
		var stream *cloudwatchlogs.LogStream
		for _, s := range resp.LogStreams {
			if s != nil && aws.StringValue(s.LogStreamName) == a.LogStreamName {
				stream = s
				break
			}
		}
		if stream == nil {
			return err
		}
		a.sequenceToken = stream.UploadSequenceToken
	}
	// Single retry. For errors other than a stale token this is a plain
	// retry with the unchanged token.
	return a.attempt(ctx, events)
}
