package metrics

import (
	"fmt"
	"testing"
	"time"

	"code.justin.tv/video/lvsapi/internal/metrics/mwsfakes"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/stretchr/testify/assert"
)

// Identity values passed to New by createNewTestMetricsClient; assetCommon
// checks that every captured MWS request carries them.
const (
	testService  = "lvsapi-test"
	testHostname = "hostname"
	testEnv      = "test-env"

	// Dummy AWS key pair used to build testCreds below.
	testAWSID      = "AWS_FAKE_ID"
	testAWSKSecret = "AWS_FAKE_SECRET"
)

var (
	// testCreds is the static credential set the client under test is built
	// with; assetCommon verifies the fake MWS client received this exact value.
	testCreds = credentials.NewStaticCredentials(testAWSID, testAWSKSecret, "")

	// tickerChan stands in for the stats ticker's channel so a test can fire
	// a tick on demand. createNewTestMetricsClient recreates it on every call.
	tickerChan chan time.Time
)

// createNewTestMetricsClient builds a Client whose collaborators are test
// doubles: both MWS clients are replaced with fresh fakes, and the stats
// ticker handed to the client has its channel swapped for tickerChan so a
// test can trigger stats reporting by sending on tickerChan.
func createNewTestMetricsClient() *Client {
	// Fresh zero-value fakes, so calls recorded by earlier tests are not seen.
	mwsfakes.FakeMWSClientInternal = mwsfakes.FakeIAmazonMWSGoClient{}
	defaultMWSClient = mwsfakes.FakeMWSClient{}

	// Keep a real *time.Ticker (that is what defaultStatsTicker returns), but
	// point its C at a channel the test controls.
	// NOTE(review): this ticker is not stopped here; confirm Client.Stop does it.
	tickerChan = make(chan time.Time)
	statsTicker := time.NewTicker(time.Minute)
	statsTicker.C = tickerChan
	defaultStatsTicker = func() *time.Ticker { return statsTicker }

	return New(testService, testHostname, testEnv, testCreds)
}

// assetCommon checks the fields that every PutMetricsForAggregation call
// recorded by the fake MWS client must share: the credentials built into
// testCreds, plus the host, environment and service name that
// createNewTestMetricsClient passes to New. Only the first report and first
// metric of each request are inspected.
//
// NOTE: this needs to be called after Stop(); it only reads whatever calls
// the fake has recorded at the moment it runs.
func assetCommon(t *testing.T) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	callCount := mwsfakes.FakeMWSClientInternal.PutMetricsForAggregationCallCount()
	for i := 0; i < callCount; i++ {
		req, _, credsUsed := mwsfakes.FakeMWSClientInternal.PutMetricsForAggregationArgsForCall(i)
		if credsUsed != testCreds {
			t.Fatalf("Metrics recording, expected credentials: %v, actual credentials: %v", testCreds, credsUsed)
		}

		// Fail with a clear message instead of panicking on an empty request.
		if len(req.MetricReports) == 0 || len(req.MetricReports[0].Metrics) == 0 {
			t.Fatalf("PutMetricsForAggregation call %d carried no metric report or metric", i)
		}

		report := req.MetricReports[0]
		assert.Equal(t, testHostname, report.Metadata.Host)
		assert.Equal(t, testEnv, report.Metadata.Environment)
		assert.Equal(t, testService, report.Metrics[0].Dimensions.ServiceName)
	}
}

// TestMetricsDoesNotBlockOnFullBuffer queues bufferSize metrics on a client
// that was never started, then checks that further Record calls return an
// error rather than succeeding (or hanging the test).
func TestMetricsDoesNotBlockOnFullBuffer(t *testing.T) {
	client := createNewTestMetricsClient()

	// Every call records the same metric; only the queue state matters here.
	record := func() error {
		return client.Record("TestClient", "TestMethod", "TestTime", UnitMillisecond, 0.1)
	}

	// Fill every slot; Start is never called in this test.
	for n := 0; n < bufferSize; n++ {
		if err := record(); err != nil {
			t.Fatalf("Metric should have been successfully queued up but got an error: %v", err)
		}
	}

	// The buffer is full, we should return an error and not block
	for attempt := 0; attempt < 2; attempt++ {
		if record() == nil {
			t.Fatalf("The metric should have failed to queue up")
		}
	}

	client.Stop()
	assetCommon(t)
}

// TestEnsureRecordIsSubmitted records a single metric on a started client,
// stops it, and verifies that exactly one PutMetricsForAggregation call was
// made whose first report/metric carries the recorded client ID, method,
// metric name, unit and value.
func TestEnsureRecordIsSubmitted(t *testing.T) {
	var (
		clientID = "test-client-id"
		method   = "test-method"
		metric   = "test-metric"
		unit     = UnitMillisecond
		value    = 0.1
	)

	client := createNewTestMetricsClient()
	client.Start()
	if err := client.Record(clientID, method, metric, unit, value); err != nil {
		t.Fatalf("Metric should have been queued successfully: %v", err)
	}
	client.Stop()
	assetCommon(t)

	// Pointer to the package-level fake, so it is not copied.
	fake := &mwsfakes.FakeMWSClientInternal
	if n := fake.PutMetricsForAggregationCallCount(); n != 1 {
		t.Fatalf("PutMetricsForAggregation() should have been called exactly once but called '%d' times", n)
	}

	req, _, _ := fake.PutMetricsForAggregationArgsForCall(0)
	report := req.MetricReports[0]
	got := report.Metrics[0]

	assert.Equal(t, clientID, report.Metadata.CustomerID)
	assert.Equal(t, method, got.Dimensions.MethodName)
	assert.Equal(t, metric, got.MetricName)
	assert.Equal(t, unit, got.Unit)
	assert.Equal(t, value, got.Values[0])
}

// TestStatsRecording fills the queue before the client is started so exactly
// one of bufferSize+1 Record calls is dropped. It then starts the client,
// fires the stats ticker once and stops the client. Finally it checks that
// the reported QueueCallCount is bufferSize+1 and QueueDropCount is 1.
func TestStatsRecording(t *testing.T) {
	clientID := "test-client-id"
	method := "test-method"
	metric := "test-metric"
	unit := UnitMillisecond
	value := 0.1

	c := createNewTestMetricsClient()

	// Make sure we drop one metric (since the MWS client is not started)
	for i := 0; i < bufferSize; i++ {
		err := c.Record(clientID, method, fmt.Sprintf("%s_%d", metric, i), unit, value)
		if err != nil {
			t.Fatalf("Metric should have been queued successfully: %v", err)
		}
	}
	err := c.Record(clientID, method, metric+"_fail", unit, value)
	if err == nil {
		t.Fatalf("Metric should have been dropped")
	}
	c.Start()

	// Trigger the stats ticker and cause it to report stats about how many metrics
	// have been reported.
	tickerChan <- time.Now()

	c.Stop()
	assetCommon(t)

	// Ensure we have recorded our metrics as well as the 4 internal stats we report
	callCount := mwsfakes.FakeMWSClientInternal.PutMetricsForAggregationCallCount()
	if callCount != bufferSize+4 {
		t.Fatalf("PutMetricsForAggregation() should have been called bufferSize + 4 times, but was called %d times", callCount)
	}

	// Walk every recorded call. Bounding the loop by callCount rather than a
	// fixed number keeps it correct if bufferSize changes.
	var sawCallCount, sawDropCount bool
	for i := 0; i < callCount; i++ {
		req, _, _ := mwsfakes.FakeMWSClientInternal.PutMetricsForAggregationArgsForCall(i)
		m := req.MetricReports[0].Metrics[0]

		switch m.MetricName {
		case "QueueCallCount":
			sawCallCount = true
			// Every Record call is counted, including the dropped one.
			assert.Equal(t, float64(bufferSize+1), m.Values[0])
		case "QueueDropCount":
			sawDropCount = true
			assert.Equal(t, float64(1), m.Values[0])
		}
	}

	// Without these checks the assertions above pass vacuously when the
	// stats are never reported.
	if !sawCallCount {
		t.Errorf("QueueCallCount stat was never reported")
	}
	if !sawDropCount {
		t.Errorf("QueueDropCount stat was never reported")
	}
}
