package metrics

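// This file translates between the Prometheus remote read/write API structs
// (prompb) and the CloudWatch API: it reads CloudWatch data to answer remote
// read queries and writes remote-write samples into CloudWatch.
//
// It should be the only place in the code that needs to know about
// CloudWatch and its API.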

import (
	"bytes"
	"fmt"
	"sort"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/cloudwatch"
	"github.com/prometheus/prometheus/prompb"
	log "github.com/sirupsen/logrus"
)

// cw is the package-wide CloudWatch client shared by every read and write in
// this file. It is assigned once, in init.
var cw *cloudwatch.CloudWatch

// init builds the CloudWatch client for the us-west-2 region. session.Must
// panics if the AWS session cannot be created, so a broken AWS configuration
// fails at program start-up rather than on the first request.
// NOTE(review): the region is hard-coded; consider making it configurable.
func init() {
	region := aws.String("us-west-2")
	sess := session.Must(session.NewSession(&aws.Config{Region: region}))
	cw = cloudwatch.New(sess)
}

// func to execute a prometheus query against cloudwatch
// takes in a Prometheus ReadRequest Query, returns
// many metrics
func GetCloudwatchData(q *prompb.Query) (*prompb.QueryResult, error) {

	// validate query and find the metrics to fetch
	metricName, dimensions, statistic, err := parseLabelMatchers(q.Matchers)
	if err != nil {
		return nil, err
	}
	metrics, err := listMetrics("TwitchMetrics", metricName, dimensions) // TODO: configure
	if err != nil {
		return nil, err
	}

	// determines if the request needs multiple calls to GetMetricStatistics
	// as well as the period of the request
	start := time.Unix(0, q.StartTimestampMs*int64(time.Millisecond))
	end := time.Unix(0, q.EndTimestampMs*int64(time.Millisecond))
	period := determinePeriod(start, end)

	// fetch data for each metric and create Timeseries
	timeSeriesSet := []*prompb.TimeSeries{}
	for _, metric := range metrics {
		data, err := getMetricData("TwitchMetrics", metricName, metric.Dimensions, period, start, end, statistic)
		if err != nil {
			return nil, err
		}
		timeSeriesSet = append(timeSeriesSet, createTimeSeries(metric, data, statistic))
	}
	// Create a QueryResult to return from the timeseries collected
	return &prompb.QueryResult{
		Timeseries: timeSeriesSet,
	}, nil

}

// createTimeSeries converts the datapoints fetched for one CloudWatch metric
// into a Prometheus TimeSeries.
//
// The "__name__" label is set to the CloudWatch metric name. Every CloudWatch
// dimension becomes a label with the same name and value; nil strings become "".
// For each datapoint, the requested statistic is read via extractStatisticFunc.
// Datapoints that lack that statistic, or that have no timestamp, are skipped.
// Sample timestamps are milliseconds since the Unix epoch, with sub-second
// precision preserved.
func createTimeSeries(metric *cloudwatch.Metric, data []*cloudwatch.Datapoint, statistic string) *prompb.TimeSeries {
	// One label for __name__ plus one per dimension.
	labels := make([]*prompb.Label, 0, len(metric.Dimensions)+1)
	labels = append(labels, &prompb.Label{
		Name:  "__name__",
		Value: aws.StringValue(metric.MetricName),
	})
	for _, dimension := range metric.Dimensions {
		labels = append(labels, &prompb.Label{
			Name:  aws.StringValue(dimension.Name),
			Value: aws.StringValue(dimension.Value),
		})
	}

	extractStatistic := extractStatisticFunc(statistic)
	samples := make([]*prompb.Sample, 0, len(data))
	for _, d := range data {
		// Without a timestamp there is no way to place the sample.
		if d.Timestamp == nil {
			continue
		}
		value, err := extractStatistic(d)
		if err != nil {
			// The datapoint does not carry the requested statistic.
			continue
		}
		samples = append(samples, &prompb.Sample{
			Value:     value,
			Timestamp: d.Timestamp.UnixNano() / int64(time.Millisecond),
		})
	}
	return &prompb.TimeSeries{
		Labels:  labels,
		Samples: samples,
	}
}

// sumFunc reads the Sum statistic from a CloudWatch datapoint. It returns an
// error when the datapoint has no Sum value.
//
// TODO: make this smarter, using reflection
func sumFunc(d *cloudwatch.Datapoint) (float64, error) {
	if d.Sum == nil {
		return 0, fmt.Errorf("no Sum statistic in datapoint")
	}
	return *d.Sum, nil
}
// avgFunc reads the Average statistic from a CloudWatch datapoint. It returns
// an error when the datapoint has no Average value.
func avgFunc(d *cloudwatch.Datapoint) (float64, error) {
	if d.Average == nil {
		return 0, fmt.Errorf("no Average statistic in datapoint")
	}
	return *d.Average, nil
}
// minFunc reads the Minimum statistic from a CloudWatch datapoint. It returns
// an error when the datapoint has no Minimum value.
func minFunc(d *cloudwatch.Datapoint) (float64, error) {
	if d.Minimum == nil {
		return 0, fmt.Errorf("no Minimum statistic in datapoint")
	}
	return *d.Minimum, nil
}
// maxFunc reads the Maximum statistic from a CloudWatch datapoint. It returns
// an error when the datapoint has no Maximum value.
func maxFunc(d *cloudwatch.Datapoint) (float64, error) {
	if d.Maximum == nil {
		return 0, fmt.Errorf("no Maximum statistic in datapoint")
	}
	return *d.Maximum, nil
}

// extractStatisticFunc returns the function that reads the named statistic
// ("Sum", "Average", "Minimum", "Maximum" or "SampleCount") from a CloudWatch
// datapoint. Any other name falls back to reading the Average.
//
// TODO: support percentile extended statistics
func extractStatisticFunc(statistic string) func(d *cloudwatch.Datapoint) (float64, error) {
	switch statistic {
	case "Sum":
		return sumFunc
	case "Minimum":
		return minFunc
	case "Maximum":
		return maxFunc
	case "SampleCount":
		// SampleCount is a valid CloudWatch statistic. Without this case, the
		// default would read Average, which a datapoint fetched for SampleCount
		// only does not carry, so every datapoint would be skipped.
		return func(d *cloudwatch.Datapoint) (float64, error) {
			if d.SampleCount == nil {
				return 0, fmt.Errorf("no SampleCount statistic in datapoint")
			}
			return *d.SampleCount, nil
		}
	default: // "Average" and anything unrecognised
		return avgFunc
	}
}

func listMetrics(namespace string, name string, dimensions []*cloudwatch.DimensionFilter) ([]*cloudwatch.Metric, error) {
	// Create a ListMetricsInput
	lmi := &cloudwatch.ListMetricsInput{
		Namespace:  aws.String(namespace),
		MetricName: aws.String(name),
		Dimensions: dimensions,
	}
	dimString := dimensionFiltersToString(dimensions)
	log.Infof("ListMetrics: namespace=%s, name=%s, dimensions=%s", namespace, name, dimString)

	lmo, err := cw.ListMetrics(lmi)
	if err != nil {
		return nil, err
	}
	for _, m := range lmo.Metrics {
		dimString := dimensionsToString(m.Dimensions)
		log.Infof("Found metric: namespace=%s, name=%s, dimensions=%s", *m.Namespace, *m.MetricName, dimString)
	}
	// TODO: configure
	if len(lmo.Metrics) > 25 {
		return nil, fmt.Errorf("too many metrics matching query: %d", len(lmo.Metrics))
	}
	return lmo.Metrics, nil
}

// getMetricData fetches a single statistic for one CloudWatch metric between
// start and end, at the given period in seconds. The metric is identified by
// its namespace, its name and its exact dimensions. The datapoints are returned
// sorted by ascending timestamp, because CloudWatch makes no ordering promise.
func getMetricData(namespace, metricName string, dimensions []*cloudwatch.Dimension, period int64, start, end time.Time, statistic string) ([]*cloudwatch.Datapoint, error) {
	log.Infof("GetMetricStatistics: namespace=%s, metricName=%s, dimensions=%s, period=%d, start=%s, end=%s",
		namespace, metricName, dimensionsToString(dimensions), period, start, end)

	out, err := cw.GetMetricStatistics(&cloudwatch.GetMetricStatisticsInput{
		Namespace:  aws.String(namespace),
		MetricName: aws.String(metricName),
		Dimensions: dimensions,
		Period:     aws.Int64(period),
		StartTime:  aws.Time(start),
		EndTime:    aws.Time(end),
		Statistics: aws.StringSlice([]string{statistic}),
	})
	if err != nil {
		return nil, err
	}

	points := out.Datapoints
	log.Infof("GetMetricStatistics success, found %d datapoints", len(points))
	sort.Sort(DataSet(points))
	return points, nil
}

// parseLabelMatchers validates the label matchers of a Prometheus remote-read
// query and turns them into the inputs of a CloudWatch lookup.
//
// Recognised labels:
//   - __name__ (required): the CloudWatch metric name.
//   - statistic (optional): the CloudWatch statistic to fetch. It must be
//     exactly one of "SampleCount", "Average", "Sum", "Minimum" or "Maximum".
//     An absent or empty value defaults to "Average".
//
// Every other matcher becomes a CloudWatch dimension filter. A label such as
// "namespace" is therefore treated as a dimension, not as the CloudWatch
// namespace. Only exact-equality (=) matchers are supported; regex and
// not-equal matchers are rejected.
//
// It returns the metric name, the dimension filters and the statistic.
func parseLabelMatchers(matchers []*prompb.LabelMatcher) (string, []*cloudwatch.DimensionFilter, string, error) {
	var metricName, statistic string
	foundName := false
	dimensions := make([]*cloudwatch.DimensionFilter, 0, len(matchers))
	for _, lm := range matchers {
		if lm.GetType() != prompb.LabelMatcher_EQ {
			return "", nil, "", fmt.Errorf("query must contain only exact equality matchers (=)")
		}
		switch lm.Name {
		case "__name__":
			foundName = true
			metricName = lm.Value
		case "statistic":
			statistic = lm.Value
		default:
			dimensions = append(dimensions, &cloudwatch.DimensionFilter{
				Name:  aws.String(lm.Name),
				Value: aws.String(lm.Value),
			})
		}
	}
	if !foundName {
		return "", nil, "", fmt.Errorf("a metric name must be specified (e.g. my_metric_name{statistic=\"Sum\"})")
	}
	// In Prometheus, statistic="" selects series without the label, so treat
	// it the same as an absent statistic matcher.
	if statistic == "" {
		statistic = "Average"
	}
	// Reject unsupported names up front. Otherwise they either fail inside
	// GetMetricStatistics or come back as an empty series, because no
	// extractor reads them from the datapoints.
	switch statistic {
	case "SampleCount", "Average", "Sum", "Minimum", "Maximum":
		return metricName, dimensions, statistic, nil
	default:
		return "", nil, "", fmt.Errorf("unsupported statistic %q: must be one of SampleCount, Average, Sum, Minimum, Maximum", statistic)
	}
}

// determinePeriod picks the GetMetricStatistics period, in seconds, for a
// query covering start to end.
//
// The base period depends on how far back start lies: 3600s when start is
// more than 63 days ago, 300s when it is more than 15 days ago, and 60s
// otherwise. The period is then raised in steps of that base until the range
// yields at most 1440 datapoints, the per-request maximum.
func determinePeriod(start, end time.Time) int64 {
	const day = 24 * time.Hour
	const maxDatapoints = 1440

	now := time.Now()
	var step int64
	switch {
	case start.Before(now.Add(-63 * day)):
		step = 3600
	case start.Before(now.Add(-15 * day)):
		step = 300
	default:
		step = 60
	}

	span := int64(end.Sub(start).Seconds())
	period := step
	for span/period+1 > maxDatapoints {
		period += step
	}
	return period
}

// PutCloudwatchData writes Prometheus remote-write timeseries into CloudWatch.
//
// Each sample of each timeseries becomes one CloudWatch MetricDatum. The
// "__name__" label becomes the metric name and the remaining labels become
// dimensions. Datums are sent to PutMetricData in batches of 20, since
// CloudWatch only supports 20 metrics per call.
//
// Writing is best-effort. A timeseries whose labels cannot be converted is
// skipped, and a failed batch does not stop later batches. All such errors
// are collected and returned together as a PutErrors; the result is nil when
// every write succeeded.
//
// Samples whose value is NaN or ±Inf are dropped without an error. CloudWatch
// does not accept these values, so a single one would make the PutMetricData
// call carrying it fail. Prometheus routinely sends NaN as its staleness marker.
func PutCloudwatchData(tss []*prompb.TimeSeries) error {
	const batchSize = 20
	const namespace = "TwitchMetrics" // TODO: configure

	// isFinite reports whether v is neither NaN nor ±Inf. For every finite v,
	// v-v is exactly 0, while Inf-Inf and NaN-NaN are both NaN. This is
	// equivalent to !math.IsNaN(v) && !math.IsInf(v, 0).
	isFinite := func(v float64) bool { return v-v == 0 }

	var putErrs PutErrors
	metricData := make([]*cloudwatch.MetricDatum, 0)
	skipped := 0
	for _, timeseries := range tss {
		metricName, dimensions, err := parseLabels(timeseries.Labels)
		if err != nil {
			// Record the error and carry on with the next timeseries.
			putErrs = append(putErrs, err)
			continue
		}
		for _, sample := range timeseries.Samples {
			if !isFinite(sample.Value) {
				skipped++
				continue
			}
			// Prometheus timestamps are milliseconds since the Unix epoch.
			ts := time.Unix(0, sample.Timestamp*int64(time.Millisecond))
			metricData = append(metricData, &cloudwatch.MetricDatum{
				MetricName: aws.String(metricName),
				Dimensions: dimensions,
				Timestamp:  aws.Time(ts),
				Value:      aws.Float64(sample.Value),
			})
		}
	}
	if skipped > 0 {
		log.Debugf("PutCloudwatchData: skipped %d NaN/Inf samples", skipped)
	}

	for lo := 0; lo < len(metricData); lo += batchSize {
		hi := lo + batchSize
		if hi > len(metricData) {
			hi = len(metricData)
		}
		_, err := cw.PutMetricData(&cloudwatch.PutMetricDataInput{
			Namespace:  aws.String(namespace),
			MetricData: metricData[lo:hi],
		})
		// Keep going; the error is reported with the others at the end.
		if err != nil {
			putErrs = append(putErrs, err)
		}
	}

	if len(putErrs) != 0 {
		return putErrs
	}
	return nil
}

// parseLabels splits a Prometheus label set into a CloudWatch metric name and
// dimensions. The name comes from the "__name__" label, and every other label
// becomes a dimension with the same name and value.
//
// It returns an error when the set has more than 11 labels, because CloudWatch
// only supports 10 dimensions on top of the __name__ label. It also returns an
// error when no __name__ label is present. An empty label set is reported as
// a missing __name__ instead of panicking.
func parseLabels(labels []*prompb.Label) (string, []*cloudwatch.Dimension, error) {
	if len(labels) > 11 {
		return "", nil, fmt.Errorf("too many dimensions to store data in cloudwatch: %s", labelsToString(labels))
	}

	var metricName string
	foundName := false
	// Capacity len(labels), not len(labels)-1: with an empty label set the
	// latter is -1, which makes make() panic.
	dimensions := make([]*cloudwatch.Dimension, 0, len(labels))
	for _, l := range labels {
		if l.Name == "__name__" {
			foundName = true
			metricName = l.Value
			continue
		}
		dimensions = append(dimensions, &cloudwatch.Dimension{
			Name:  aws.String(l.Name),
			Value: aws.String(l.Value),
		})
	}
	if !foundName {
		return "", nil, fmt.Errorf("'__name__' label not found for ingested sample")
	}
	return metricName, dimensions, nil
}

// labelsToString renders a Prometheus label set for log and error messages in
// the form {name:'value',...}. Every entry, including the last, ends with a
// comma.
func labelsToString(labels []*prompb.Label) string {
	var out bytes.Buffer
	out.WriteString("{")
	for _, l := range labels {
		fmt.Fprintf(&out, "%s:'%s',", l.Name, l.Value)
	}
	out.WriteString("}")
	return out.String()
}
// dimensionFiltersToString renders CloudWatch dimension filters for log
// messages in the form {name:'value',...}. Every entry, including the last,
// ends with a comma. A filter's Value is optional in the CloudWatch API (a
// name-only filter), so a nil Name or Value is rendered as an empty string
// rather than dereferenced.
func dimensionFiltersToString(filters []*cloudwatch.DimensionFilter) string {
	var buffer bytes.Buffer
	buffer.WriteString("{")
	for _, filter := range filters {
		fmt.Fprintf(&buffer, "%s:'%s',", aws.StringValue(filter.Name), aws.StringValue(filter.Value))
	}
	buffer.WriteString("}")
	return buffer.String()
}
// dimensionsToString renders CloudWatch dimensions for log messages in the
// form {name:'value',...}. Every entry, including the last, ends with a comma.
// Name and Value are dereferenced directly, so both must be non-nil.
func dimensionsToString(dimensions []*cloudwatch.Dimension) string {
	var out bytes.Buffer
	out.WriteString("{")
	for _, dim := range dimensions {
		fmt.Fprintf(&out, "%s:'%s',", *dim.Name, *dim.Value)
	}
	out.WriteString("}")
	return out.String()
}

// HELPERS

// PutErrors collects several errors encountered while writing to CloudWatch
// so they can be returned together as a single error.
type PutErrors []error

// Error returns the lone error's message when there is exactly one error.
// Otherwise it returns "multiple put errors:" followed by each message on its
// own line.
func (e PutErrors) Error() string {
	if len(e) != 1 {
		msg := []byte("multiple put errors:")
		for _, err := range e {
			msg = append(msg, '\n')
			msg = append(msg, err.Error()...)
		}
		return string(msg)
	}
	return e[0].Error()
}

// DataSet implements sort.Interface over CloudWatch datapoints, ordering them
// by ascending Timestamp. Every datapoint's Timestamp must be non-nil.
type DataSet []*cloudwatch.Datapoint

// Len returns the number of datapoints.
func (d DataSet) Len() int {
	return len(d)
}

// Swap exchanges the datapoints at positions i and j.
func (d DataSet) Swap(i, j int) {
	d[i], d[j] = d[j], d[i]
}

// Less reports whether datapoint i is strictly earlier than datapoint j.
func (d DataSet) Less(i, j int) bool {
	return d[i].Timestamp.Before(*d[j].Timestamp)
}
