package eventstreamstats

import (
	"context"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/service/cloudwatch"
	"github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface"
	"github.com/pkg/errors"
	"github.com/stretchr/testify/assert"
)

// mockCloudwatchClient is a test double for cloudwatchiface.CloudWatchAPI that
// lets tests control the result of cloudwatch:GetMetricData calls made through
// GetMetricDataWithContext. Any other CloudWatchAPI method falls through to the
// embedded interface, which tests leave nil, so calling one will panic.
type mockCloudwatchClient struct {
	cloudwatchiface.CloudWatchAPI
	throughputVals []float64 // eventThroughput datapoints; nil omits that result
	sizeVals       []float64 // eventSize datapoints; nil omits that result
	shouldError    bool      // when true, GetMetricDataWithContext returns an error
}

// SetThroughputValues sets the datapoints reported under the eventThroughput
// id. A nil slice makes the mock leave that result out of its response.
func (m *mockCloudwatchClient) SetThroughputValues(vals []float64) {
	m.throughputVals = vals
}

// SetSizeValues sets the datapoints reported under the eventSize id. A nil
// slice makes the mock leave that result out of its response.
func (m *mockCloudwatchClient) SetSizeValues(vals []float64) {
	m.sizeVals = vals
}

// ShouldError toggles failure mode: while b is true, GetMetricDataWithContext
// returns an error instead of any metric data.
func (m *mockCloudwatchClient) ShouldError(b bool) {
	m.shouldError = b
}

// GetMetricDataWithContext stubs cloudwatch:GetMetricData. ctx, input and opts
// are ignored. In failure mode it returns a fixed error; otherwise it returns a
// non-nil (possibly empty) MetricDataResults slice holding one result per
// configured series, throughput (eventThroughput) before size (eventSize).
func (m *mockCloudwatchClient) GetMetricDataWithContext(ctx context.Context, input *cloudwatch.GetMetricDataInput, opts ...request.Option) (*cloudwatch.GetMetricDataOutput, error) {
	if m.shouldError {
		return nil, errors.New("ERROR")
	}

	series := []struct {
		id   string
		vals []float64
	}{
		{eventThroughput, m.throughputVals},
		{eventSize, m.sizeVals},
	}

	results := make([]*cloudwatch.MetricDataResult, 0, len(series))
	for _, s := range series {
		// A nil slice means "no data for this metric", so skip it entirely.
		if s.vals == nil {
			continue
		}
		results = append(results, metricDataResponse(s.id, s.vals))
	}
	return &cloudwatch.GetMetricDataOutput{MetricDataResults: results}, nil
}

// metricDataResponse builds a MetricDataResult with the given Id whose Values
// are v, timestamped 5 minutes apart starting 24 hours before now.
func metricDataResponse(id string, v []float64) *cloudwatch.MetricDataResult {
	const step = 5 * time.Minute
	start := time.Now().Add(-24 * time.Hour)

	result := &cloudwatch.MetricDataResult{
		Id:         aws.String(id),
		Timestamps: make([]*time.Time, 0, len(v)),
		Values:     make([]*float64, 0, len(v)),
	}
	for i, val := range v {
		result.Values = append(result.Values, aws.Float64(val))
		result.Timestamps = append(result.Timestamps, aws.Time(start.Add(time.Duration(i)*step)))
	}
	return result
}

// TestCloudwatchClientMath feeds canned datapoints through the mock client and
// checks the statistics Fetch derives from them. A nil expectation means the
// corresponding output field must be nil. The mock spaces datapoints 5 minutes
// apart; the expected per-minute throughput figures in the table are the raw
// inputs divided by 5 (e.g. max 40 -> 8), while size means are unscaled.
func TestCloudwatchClientMath(t *testing.T) {
	// Outputs are computed floats; compare with a tolerance so the test does
	// not depend on the summation/division order used by the implementation.
	const floatTolerance = 1e-9

	arnPresent := "sns:arn:sns:arn:sns:arn"
	arnEmpty := ""
	tableTests := []struct {
		name              string
		snsTopicARN       string
		inThroughputVals  []float64
		inSizeVals        []float64
		outThroughputMax  *float64
		outThroughputMin  *float64
		outThroughputMean *float64
		outSizeMean       *float64
	}{
		{"Basic", arnPresent, []float64{20, 30, 10, 40}, []float64{1, 2, 3}, fP(8), fP(2), fP(5), fP(2)},
		{"NoSizeData", arnPresent, []float64{4, 5, 6}, nil, fP(1.2), fP(.8), fP(1), nil},
		{"NoThroughputData", arnPresent, nil, []float64{4, 5, 6}, nil, nil, nil, fP(5)},
		{"NoData", arnPresent, nil, nil, nil, nil, nil, nil},
		// The mock would return data here, so all-nil output demonstrates that
		// Fetch short-circuits when no topic ARN is given.
		{"NoARN", arnEmpty, []float64{1}, []float64{1}, nil, nil, nil, nil},
	}

	for _, tt := range tableTests {
		t.Run(tt.name, func(t *testing.T) {
			m := &mockCloudwatchClient{
				throughputVals: tt.inThroughputVals,
				sizeVals:       tt.inSizeVals,
			}
			client := NewCloudwatchClient(m)
			out, err := client.Fetch(context.Background(), tt.snsTopicARN)
			// Stop here on error: the output is not meaningful and may be unusable.
			if !assert.NoError(t, err) {
				return
			}

			// outCheck requires actual to be nil when expect is nil; otherwise it
			// requires a non-nil actual within floatTolerance of expect. The
			// NotNil guard turns a missing value into a test failure rather than
			// a nil-dereference panic.
			outCheck := func(field string, expect, actual *float64) {
				if expect == nil {
					assert.Nil(t, actual, field)
					return
				}
				if assert.NotNil(t, actual, field) {
					assert.InDelta(t, f(expect), f(actual), floatTolerance, field)
				}
			}
			outCheck("MaxEventsPerMinute", tt.outThroughputMax, out.MaxEventsPerMinute)
			outCheck("MinEventsPerMinute", tt.outThroughputMin, out.MinEventsPerMinute)
			outCheck("MeanEventsPerMinute", tt.outThroughputMean, out.MeanEventsPerMinute)
			outCheck("MeanEventPayloadSize", tt.outSizeMean, out.MeanEventPayloadSize)
		})
	}
}

// TestCloudwatchClientErorrs checks that Fetch reports an error both when the
// topic ARN is malformed (even though CloudWatch would succeed) and when the
// CloudWatch call itself fails for a well-formed ARN.
func TestCloudwatchClientErorrs(t *testing.T) {
	const (
		goodARN = "arn:sns:arn:sns:arn:sns" // six colon-separated parts
		badARN  = "this is not a valid ARN"
	)

	m := &mockCloudwatchClient{}
	client := NewCloudwatchClient(m)

	cases := []struct {
		name        string
		arn         string
		cloudwatchs bool // whether the mock CloudWatch call should fail
	}{
		{"malformed ARN", badARN, false},
		{"CloudWatch failure", goodARN, true},
	}
	for _, c := range cases {
		m.ShouldError(c.cloudwatchs)
		_, err := client.Fetch(context.Background(), c.arn)
		assert.Error(t, err, c.name)
	}
}

// fP returns a pointer to a fresh copy of v, for writing optional float64
// expectations inline in table tests.
func fP(v float64) *float64 {
	p := new(float64)
	*p = v
	return p
}

// f dereferences p and returns the float64 it points to. It panics if p is nil.
func f(p *float64) float64 {
	v := *p
	return v
}
