package integration_test

import (
	"bytes"
	telemetry "code.justin.tv/amzn/TwitchTelemetry"
	"code.justin.tv/amzn/TwitchTelemetryCloudWatchEMFSender/logger"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/stretchr/testify/assert"
	"os"
	"time"

	"fmt"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/cloudwatch"
	"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
	"testing"
)

// millisecondsInASecond scales Unix seconds to milliseconds; putLogEvents
// multiplies time.Now().Unix() by it to build the log event timestamp.
const millisecondsInASecond = 1000

func generateEMFLog(timestamp time.Time) string {
	var buff bytes.Buffer
	l := logger.New(&buff, time.Minute, "test_namespace")
	l.PutMetric(&telemetry.Sample{
		MetricID: telemetry.MetricID{
			Name:       "latency",
			Dimensions: map[string]string{"env": "production", "region": "na", "app": "test_app"},
		},
		RollupDimensions: [][]string{{"env"}, {"env", "region"}},
		Timestamp:        timestamp,
		Value:            float64(100),
		Unit:             "Seconds",
	})
	l.Flush()
	return buff.String()
}

// getMetric queries CloudWatch for the SampleCount statistic of the "latency"
// metric in namespace "test_namespace" between start and end.
//
// The query uses a 60-second period and the exact dimension set
// env=production, region=na, app=test_app. On error the returned output is
// nil.
func getMetric(svc *cloudwatch.CloudWatch, start time.Time, end time.Time) (*cloudwatch.GetMetricDataOutput, error) {
	dimension := func(name, value string) *cloudwatch.Dimension {
		return &cloudwatch.Dimension{Name: aws.String(name), Value: aws.String(value)}
	}

	query := &cloudwatch.MetricDataQuery{
		Id: aws.String("integration_test"),
		MetricStat: &cloudwatch.MetricStat{
			Metric: &cloudwatch.Metric{
				Namespace:  aws.String("test_namespace"),
				MetricName: aws.String("latency"),
				Dimensions: []*cloudwatch.Dimension{
					dimension("env", "production"),
					dimension("region", "na"),
					dimension("app", "test_app"),
				},
			},
			Period: aws.Int64(60),
			Stat:   aws.String("SampleCount"),
		},
	}

	output, err := svc.GetMetricData(&cloudwatch.GetMetricDataInput{
		MetricDataQueries: []*cloudwatch.MetricDataQuery{query},
		StartTime:         aws.Time(start),
		EndTime:           aws.Time(end),
	})
	if err != nil {
		return nil, err
	}
	return output, nil
}

// putLogEvents writes message as one log event to the given log group and
// stream, and returns the service response.
//
// The HTTP request carries the header "x-amzn-logs-format: json/emf" to mark
// the payload as Embedded Metric Format JSON. The event timestamp is the
// current Unix time in seconds scaled to milliseconds, so it has whole-second
// precision. On error the returned output is nil.
func putLogEvents(svc *cloudwatchlogs.CloudWatchLogs, group *string, stream *string, message *string) (*cloudwatchlogs.PutLogEventsOutput, error) {
	event := &cloudwatchlogs.InputLogEvent{
		Message:   message,
		Timestamp: aws.Int64(time.Now().Unix() * millisecondsInASecond),
	}
	input := &cloudwatchlogs.PutLogEventsInput{
		LogEvents:     []*cloudwatchlogs.InputLogEvent{event},
		LogGroupName:  group,
		LogStreamName: stream,
	}

	req, output := svc.PutLogEventsRequest(input)
	// The header must be set before Send, which builds and signs the request.
	req.HTTPRequest.Header.Set("x-amzn-logs-format", "json/emf")
	if err := req.Send(); err != nil {
		return nil, err
	}
	return output, nil
}

func TestIntegration(t *testing.T) {
	sess := session.Must(session.NewSessionWithOptions(session.Options{
		SharedConfigState: session.SharedConfigEnable,
	}))
	logService := cloudwatchlogs.New(sess)
	metricService := cloudwatch.New(sess)

	logGroupNameEnv := "INTEGRATION_TEST_LOG_GROUP_NAME"
	logGroupName, exists := os.LookupEnv(logGroupNameEnv)
	if !exists {
		fmt.Printf("Failed to read the following environment variable: %s\n", logGroupNameEnv)
		return
	}

	logGroup := aws.String(logGroupName)
	logStream := aws.String(fmt.Sprintf("integration-test-stream-%d", time.Now().Unix()))
	_, err := logService.CreateLogStream(&cloudwatchlogs.CreateLogStreamInput{
		LogGroupName:  logGroup,
		LogStreamName: logStream,
	})
	if err != nil {
		fmt.Printf("Got error creating log stream:\n%s\n", err)
		return
	}

	timestamp := time.Now().Truncate(time.Minute)
	log := generateEMFLog(timestamp)

	_, err = putLogEvents(logService, logGroup, logStream, &log)
	if err != nil {
		fmt.Printf("Got error putting log events:\n%s\n", err)
		return
	}

	time.Sleep(time.Second * 5)
	metricResp, err := getMetric(metricService, timestamp.Add(-1*time.Minute), timestamp.Add(time.Minute))
	if err != nil {
		fmt.Printf("Got error getting metric:\n%s\n", err)
		return
	}

	metricDataResults := metricResp.MetricDataResults
	assert.Equal(t, 1, len(metricDataResults))
	assert.NotNil(t, metricDataResults[0].Values)
	assert.Greater(t, len(metricDataResults[0].Values), 0)
}
