package eventstreamstats

import (
	"context"
	"math"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/cloudwatch"
	"github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface"
	"github.com/pkg/errors"
)

// GetMetricData query IDs. Each one tags a query built by query() and is
// used by Fetch to pick the matching entry out of MetricDataResults.
const (
	// eventThroughput is the ID of the NumberOfMessagesPublished query.
	eventThroughput = "epm"
	// eventSize is the ID of the PublishSize query.
	eventSize = "size"
)

// CloudwatchClient wraps a CloudWatch API client and adds a method (Fetch)
// for retrieving EventStream statistics.
type CloudwatchClient struct {
	// client is the underlying CloudWatch API used to issue metric queries.
	client cloudwatchiface.CloudWatchAPI
}

// NewCloudwatchClient builds a CloudwatchClient that sends all of its
// CloudWatch requests through c.
func NewCloudwatchClient(c cloudwatchiface.CloudWatchAPI) *CloudwatchClient {
	cw := new(CloudwatchClient)
	cw.client = c
	return cw
}

// Fetch retrieves event stream statistics for the last 24 hours from
// CloudWatch, for the SNS topic identified by snsTopicARN.
//
// An empty snsTopicARN is not an error: the topic may simply not have been
// created yet, so an empty *EventStreamStatistics is returned. A malformed
// ARN yields a wrapped error. A failed GetMetricData call returns the AWS
// error unwrapped, so callers can still inspect it.
//
// Each statistic is left nil when CloudWatch returned no datapoints for the
// corresponding metric. Callers can therefore tell "no data" apart from a
// genuine zero value.
func (c *CloudwatchClient) Fetch(ctx context.Context, snsTopicARN string) (*EventStreamStatistics, error) {
	if snsTopicARN == "" {
		return &EventStreamStatistics{}, nil
	}

	topicName, err := topicNameFromARN(snsTopicARN)
	if err != nil {
		return nil, errors.Wrap(err, "could not fetch event stream stats")
	}

	end := time.Now()
	start := end.Add(-24 * time.Hour)

	res, err := c.client.GetMetricDataWithContext(ctx, query(topicName, start, end))
	if err != nil {
		return nil, err
	}

	// Match each result back to the query that produced it by ID.
	var throughputData, sizeData *cloudwatch.MetricDataResult
	for _, metricData := range res.MetricDataResults {
		switch aws.StringValue(metricData.Id) {
		case eventThroughput:
			throughputData = metricData
		case eventSize:
			sizeData = metricData
		}
	}

	stats := &EventStreamStatistics{}

	// CloudWatch may return a result entry for a query even when the metric
	// has no datapoints. Require actual values, not just a non-nil result.
	// Otherwise "no data" would be reported as 0 instead of nil.
	if throughputData != nil && len(throughputData.Values) > 0 {
		lo, hi, avg := minMaxAvg(throughputData)
		// The throughput query sums messages over 300-second (5-minute)
		// periods (see query). Divide by the period length in minutes to
		// convert "per 5 minutes" into "per minute".
		const periodMinutes = 5
		lo /= periodMinutes
		hi /= periodMinutes
		avg /= periodMinutes
		stats.MinEventsPerMinute = &lo
		stats.MaxEventsPerMinute = &hi
		stats.MeanEventsPerMinute = &avg
	}

	if sizeData != nil && len(sizeData.Values) > 0 {
		_, _, avg := minMaxAvg(sizeData)
		stats.MeanEventPayloadSize = &avg
	}

	return stats, nil
}

// topicNameFromARN extracts the topic name from an SNS topic ARN of the form
// arn:<partition>:sns:<region>:<account-id>:<topic-name>.
//
// It returns an error in any of these cases:
//   - arn does not split into exactly six colon-separated fields
//   - arn does not start with "arn"
//   - arn does not name the "sns" service
//   - the topic name is empty
//
// An empty name would otherwise produce a meaningless TopicName dimension in
// the CloudWatch query.
func topicNameFromARN(arn string) (string, error) {
	parts := strings.Split(arn, ":")
	if len(parts) != 6 || parts[0] != "arn" || parts[2] != "sns" || parts[5] == "" {
		return "", errors.New("invalid sns topic arn")
	}
	return parts[5], nil
}

// Given some cloudwatch data, returns the min, max, and average of the set of datapoints
func minMaxAvg(data *cloudwatch.MetricDataResult) (float64, float64, float64) {
	// This should never happen, but let's not risk a divide-by-zero error
	if len(data.Values) == 0 {
		return 0, 0, 0
	}
	sum := float64(0)
	min := math.MaxFloat64
	max := float64(0)
	var val float64
	for _, valPtr := range data.Values {
		val = aws.Float64Value(valPtr)
		sum += val
		if val > max {
			max = val
		}
		if val < min {
			min = val
		}
	}
	avg := sum / float64(len(data.Values))
	return min, max, avg
}

// query builds the GetMetricData request that Fetch uses for the SNS topic
// named topicName over the window [start, end]. It requests two AWS/SNS
// metrics, both filtered on the TopicName dimension:
//   - eventThroughput: NumberOfMessagesPublished, Sum over 300-second periods
//   - eventSize: PublishSize, Average over 3600-second periods
func query(topicName string, start, end time.Time) *cloudwatch.GetMetricDataInput {
	snsMetric := func(id, metricName string, periodSeconds int64, stat string) *cloudwatch.MetricDataQuery {
		return &cloudwatch.MetricDataQuery{
			Id: aws.String(id),
			MetricStat: &cloudwatch.MetricStat{
				Metric: &cloudwatch.Metric{
					Namespace:  aws.String("AWS/SNS"),
					MetricName: aws.String(metricName),
					Dimensions: []*cloudwatch.Dimension{{
						Name:  aws.String("TopicName"),
						Value: aws.String(topicName),
					}},
				},
				Period: aws.Int64(periodSeconds),
				Stat:   aws.String(stat),
			},
		}
	}

	return &cloudwatch.GetMetricDataInput{
		StartTime: aws.Time(start),
		EndTime:   aws.Time(end),
		MetricDataQueries: []*cloudwatch.MetricDataQuery{
			snsMetric(eventThroughput, "NumberOfMessagesPublished", 300, "Sum"),
			snsMetric(eventSize, "PublishSize", 3600, "Average"),
		},
	}
}
