package metrics

import (
	"reflect"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/cloudwatch"
	"github.com/prometheus/prometheus/prompb"
	"github.com/stretchr/testify/assert"
)

// TestDeterminePeriod checks the CloudWatch query period chosen for ranges of
// varying length and age. Each case places its range `ago` before a single
// shared reference instant and lets it run for `span`.
func TestDeterminePeriod(t *testing.T) {
	const day = 24 * time.Hour
	ref := time.Now()

	cases := []struct {
		name string
		ago  time.Duration // distance from ref back to the range start
		span time.Duration // length of the range (end = start + span)
		want int64
	}{
		{name: "5m ending now", ago: 5 * time.Minute, span: 5 * time.Minute, want: 60},
		{name: "8d ending now", ago: 8 * day, span: 8 * day, want: 540},
		{name: "1d starting 20d ago", ago: 20 * day, span: day, want: 300},
		{name: "7d starting 30d ago", ago: 30 * day, span: 7 * day, want: 600},
		{name: "16d starting 60d ago", ago: 60 * day, span: 16 * day, want: 1200},
		{name: "1d starting 100d ago", ago: 100 * day, span: day, want: 3600},
		{name: "100d starting 200d ago", ago: 200 * day, span: 100 * day, want: 7200},
	}

	for _, tc := range cases {
		// Subtests run synchronously here, so capturing tc is safe.
		t.Run(tc.name, func(t *testing.T) {
			from := ref.Add(-tc.ago)
			to := from.Add(tc.span)
			assert.Equal(t, tc.want, determinePeriod(from, to))
		})
	}
}

// TestCreateTimeseries checks that createTimeSeries converts CloudWatch
// datapoints into a Prometheus time series. The series should have one sample
// per datapoint carrying the requested statistic. Its labels should come from
// the metric name (as __name__) and the metric's dimensions.
func TestCreateTimeseries(t *testing.T) {
	m := &cloudwatch.Metric{
		Namespace:  aws.String("garply"),
		MetricName: aws.String("foobar_requests"),
		Dimensions: []*cloudwatch.Dimension{
			{Name: aws.String("floop"), Value: aws.String("doop")},
			{Name: aws.String("herp"), Value: aws.String("derp")},
		},
	}
	start := time.Now().Add(-5 * time.Hour)

	// Typical use case: every datapoint carries the requested statistic.
	ts := createTimeSeries(m, generateDatapoints(start, 60, 50, "Average"), "Average")
	assert.Len(t, ts.Samples, 50)
	labelsMap := mapLabels(ts.Labels)
	assert.Equal(t, "foobar_requests", labelsMap["__name__"])
	assert.Equal(t, "doop", labelsMap["floop"])
	assert.Equal(t, "derp", labelsMap["herp"])

	// The remaining statistics should behave the same way.
	for _, stat := range []string{"Sum", "Maximum", "Minimum"} {
		ts = createTimeSeries(m, generateDatapoints(start, 60, 50, stat), stat)
		assert.Len(t, ts.Samples, 50, "statistic %s", stat)
	}

	// Requesting a statistic the datapoints do not carry must not panic; the
	// resulting series should simply contain no samples.
	ts = createTimeSeries(m, generateDatapoints(start, 60, 50, "Average"), "Sum")
	assert.Empty(t, ts.Samples)
}

// TestStatisticExtraction exercises each per-statistic extractor twice.
// First it runs against a datapoint that carries that statistic, which must
// yield the stored value and no error. Then it runs against a datapoint that
// only carries a different statistic, which must yield 0 and an error.
func TestStatisticExtraction(t *testing.T) {
	stamp := time.Now()
	const want = float64(42)

	cases := []struct {
		name    string
		extract func(*cloudwatch.Datapoint) (float64, error)
		present string // statistic set on the datapoint that should succeed
		absent  string // other statistic set on the datapoint that should fail
	}{
		{name: "Sum", extract: sumFunc, present: "Sum", absent: "Average"},
		{name: "Average", extract: avgFunc, present: "Average", absent: "Maximum"},
		{name: "Maximum", extract: maxFunc, present: "Maximum", absent: "Minimum"},
		{name: "Minimum", extract: minFunc, present: "Minimum", absent: "Sum"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			good := generateDatapoint(stamp, want, tc.present)
			bad := generateDatapoint(stamp, want, tc.absent)

			got, err := tc.extract(good)
			assert.Equal(t, want, got)
			assert.Nil(t, err)

			got, err = tc.extract(bad)
			assert.Equal(t, float64(0), got)
			assert.NotNil(t, err)
		})
	}
}

func TestParseLabelMatchers(t *testing.T) {
	//func parseLabelMatchers(matchers []*prompb.LabelMatcher) (string, []*cloudwatch.DimensionFilter, string, error) {
	var name string
	var filters []*cloudwatch.DimensionFilter
	var statistic string
	var err error
	var labelMatchers []*prompb.LabelMatcher

	// Valid label matchers
	labelMatchers = []*prompb.LabelMatcher{
		&prompb.LabelMatcher{
			Type:  0,
			Name:  "__name__",
			Value: "foobar_requests",
		},
		&prompb.LabelMatcher{
			Type:  0,
			Name:  "statistic",
			Value: "Maximum",
		},
		&prompb.LabelMatcher{
			Type:  0,
			Name:  "environment",
			Value: "staging",
		},
	}
	name, filters, statistic, err = parseLabelMatchers(labelMatchers)
	assert.Equal(t, "foobar_requests", name)
	assert.Equal(t, "Maximum", statistic)
	assert.Equal(t, "environment", *filters[0].Name)
	assert.Equal(t, "staging", *filters[0].Value)
	assert.Nil(t, err)

	// valid label matchers -- use default statistic
	labelMatchers = []*prompb.LabelMatcher{
		&prompb.LabelMatcher{
			Type:  0,
			Name:  "__name__",
			Value: "foobar_requests",
		},
		&prompb.LabelMatcher{
			Type:  0,
			Name:  "environment",
			Value: "staging",
		},
	}
	name, filters, statistic, err = parseLabelMatchers(labelMatchers)
	assert.Equal(t, "Average", statistic)

	// Invalid label matchers -- no name specified
	labelMatchers = []*prompb.LabelMatcher{
		&prompb.LabelMatcher{
			Type:  0,
			Name:  "statistic",
			Value: "Average",
		},
		&prompb.LabelMatcher{
			Type:  0,
			Name:  "environment",
			Value: "staging",
		},
	}
	name, filters, statistic, err = parseLabelMatchers(labelMatchers)
	assert.NotNil(t, err)

	// Invalid label matchers -- non-equality matcher used
	labelMatchers = []*prompb.LabelMatcher{
		&prompb.LabelMatcher{
			Type:  0,
			Name:  "__name__",
			Value: "foobar_requests",
		},
		&prompb.LabelMatcher{
			Type:  0,
			Name:  "statistic",
			Value: "Maximum",
		},
		&prompb.LabelMatcher{
			Type:  2,
			Name:  "environment",
			Value: "staging",
		},
	}
	name, filters, statistic, err = parseLabelMatchers(labelMatchers)
	assert.NotNil(t, err)

}

func TestParseLabels(t *testing.T) {
	//func parseLabels(labels []*prompb.Label) (string, []*cloudwatch.Dimension, error) {
	var labels []*prompb.Label
	var name string
	var dimensions []*cloudwatch.Dimension
	var err error

	// Valid labels
	labels = []*prompb.Label{
		&prompb.Label{
			Name:  "__name__",
			Value: "cool_service_latency",
		},
		&prompb.Label{
			Name:  "environment",
			Value: "darklaunch",
		},
		&prompb.Label{
			Name:  "method",
			Value: "GET",
		},
	}
	name, dimensions, err = parseLabels(labels)
	assert.Equal(t, "cool_service_latency", name)
	assert.Equal(t, 2, len(dimensions))
	assert.Nil(t, err)

	// Invalid labels -- no __name__ provided
	labels = []*prompb.Label{
		&prompb.Label{
			Name:  "environment",
			Value: "darklaunch",
		},
		&prompb.Label{
			Name:  "method",
			Value: "GET",
		},
		&prompb.Label{
			Name:  "status_code",
			Value: "5XX",
		},
	}
	name, dimensions, err = parseLabels(labels)
	assert.NotNil(t, err)

	// Valid labels -- max supported labels
	labels = []*prompb.Label{
		&prompb.Label{
			Name:  "__name__",
			Value: "cool_service_latency",
		},
		&prompb.Label{
			Name:  "environment",
			Value: "darklaunch",
		},
		&prompb.Label{
			Name:  "method",
			Value: "GET",
		},
		&prompb.Label{
			Name:  "status_code",
			Value: "5XX",
		},
		&prompb.Label{
			Name:  "another_label1",
			Value: "val",
		},
		&prompb.Label{
			Name:  "another_label2",
			Value: "val",
		},
		&prompb.Label{
			Name:  "another_label3",
			Value: "val",
		},
		&prompb.Label{
			Name:  "another_label4",
			Value: "val",
		},
		&prompb.Label{
			Name:  "another_label5",
			Value: "val",
		},
		&prompb.Label{
			Name:  "another_label6",
			Value: "val",
		},
		&prompb.Label{
			Name:  "another_label7",
			Value: "val",
		},
	}
	name, dimensions, err = parseLabels(labels)
	assert.Equal(t, 10, len(dimensions))
	assert.Nil(t, err)

	// Invalid labels -- too many labels
	// just add one to the last label set
	labels = append(labels, &prompb.Label{Name: "another_label8", Value: "val"})

}

// Helper to turn labels into an addressable map
// mapLabels indexes a label set by label name so tests can look values up
// directly. If a name occurs more than once, the last occurrence wins.
func mapLabels(labels []*prompb.Label) map[string]string {
	byName := make(map[string]string, len(labels))
	for i := range labels {
		byName[labels[i].Name] = labels[i].Value
	}
	return byName
}

// Helper to make a bunch of datapoints of the right statistic type
// generateDatapoints returns n datapoints spaced period seconds apart,
// beginning at start. Each has the given statistic set to 5. The result is a
// non-nil, empty slice when n <= 0.
func generateDatapoints(start time.Time, period int64, n int, statistic string) []*cloudwatch.Datapoint {
	points := []*cloudwatch.Datapoint{}
	step := time.Duration(period) * time.Second
	for i := 0; i < n; i++ {
		at := start.Add(time.Duration(i) * step)
		points = append(points, generateDatapoint(at, 5, statistic))
	}
	return points
}

// generateDatapoint builds a datapoint stamped at timestamp whose field named
// by statistic (e.g. "Sum", "Average", "Maximum", "Minimum") holds val.
//
// The field is located by reflection, so statistic must exactly match the name
// of a settable *float64 field on cloudwatch.Datapoint. Any other value panics
// with a message that names the offending statistic.
func generateDatapoint(timestamp time.Time, val float64, statistic string) *cloudwatch.Datapoint {
	d := &cloudwatch.Datapoint{
		Timestamp: aws.Time(timestamp),
	}
	ptr := reflect.ValueOf(aws.Float64(val))
	field := reflect.ValueOf(d).Elem().FieldByName(statistic)
	// Without this check, a typo'd statistic surfaces as reflect's generic
	// "Set on zero Value" or type-mismatch panic, which hides the bad input.
	if !field.IsValid() || !field.CanSet() || field.Type() != ptr.Type() {
		panic("generateDatapoint: cloudwatch.Datapoint has no settable *float64 field named " + statistic)
	}
	field.Set(ptr)
	return d
}
