package middleware

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	identifier "code.justin.tv/amzn/TwitchProcessIdentifier"
	telemetry "code.justin.tv/amzn/TwitchTelemetry"
	"code.justin.tv/video/metrics-middleware/v2/operation"
	. "github.com/smartystreets/goconvey/convey"
)

// testProcessIdentifier is the identity given to the SampleBuilder in these
// tests. The tests assert that Service, Stage, Substage and Region appear as
// sample dimensions; the expected dimension sets do not include Machine,
// LaunchID or Version.
var testProcessIdentifier = identifier.ProcessIdentifier{
	Service:  "TestService",
	Stage:    "TestStage",
	Substage: "TestSubstage",
	Region:   "TestRegion",
	Machine:  "TestMachine",
	LaunchID: "TestLaunchID",
	Version:  "Version",
}

// testDependencyProcessIdentifier describes the downstream service used by
// the client and span operations in these tests. Only its Service field is
// read in this file: it is used as the operation Group, and the tests expect
// it as the prefix of the "Dependency" dimension
// ("DependencyService:OtherOperation").
var testDependencyProcessIdentifier = identifier.ProcessIdentifier{
	Service:  "DependencyService",
	Stage:    "DependencyStage",
	Substage: "DependencySubstage",
	Region:   "DependencyRegion",
}

// TestOperationMonitor checks the samples an OperationMonitor hands to its
// SampleObserver for server, client and span operations, both standalone and
// nested under a server operation.
//
// goconvey re-runs the enclosing scopes once for every leaf Convey, so each
// "Then" block starts from a fresh sampleObserver and freshly started
// operations. The assertions rely on the order in which samples are
// observed: index 0 is the duration metric, followed by the success, client
// error and server error counters.
func TestOperationMonitor(t *testing.T) {
	Convey("Given an operation starter with an operation monitor", t, func() {
		capturedSamples := &sampleObserver{}

		metricsOpMonitor := &OperationMonitor{
			SampleReporter: telemetry.SampleReporter{
				SampleBuilder:  telemetry.SampleBuilder{ProcessIdentifier: testProcessIdentifier},
				SampleObserver: capturedSamples,
			},
		}

		operationStarter := &operation.Starter{OpMonitors: []operation.OpMonitor{metricsOpMonitor}}

		Convey("And a server operation", func() {
			serverOpName := operation.Name{Kind: operation.KindServer, Group: testProcessIdentifier.Service, Method: "MyOperation"}
			serverCtx, serverOp := operationStarter.StartOp(context.Background(), serverOpName)

			Convey("When successful", func() {
				serverOp.SetStatus(operation.Status{Code: 0})
				serverOp.End()

				Convey("Then it should build 4 samples", func() {
					So(len(capturedSamples.Samples), ShouldEqual, 4)
					for _, sample := range capturedSamples.Samples {
						So(sample.MetricID.Dimensions, ShouldResemble, telemetry.DimensionSet{
							"Service":   "TestService",
							"Stage":     "TestStage",
							"Substage":  "TestSubstage",
							"Region":    "TestRegion",
							"Operation": "MyOperation",
						})
						So(sample.RollupDimensions, ShouldNotBeEmpty)
						// The timestamp only needs to be roughly "now".
						So(time.Since(sample.Timestamp).Seconds(), ShouldBeBetween, -10, 10)
					}
				})

				Convey("Then it should build a duration metric", func() {
					So(capturedSamples.Samples[0].MetricID.Name, ShouldEqual, telemetry.MetricDuration)
					So(capturedSamples.Samples[0].Value, ShouldBeBetween, 0, 1)
					So(capturedSamples.Samples[0].Unit, ShouldEqual, "Seconds")
				})

				Convey("Then it should build success availability metrics", func() {
					So(capturedSamples.Samples[1].MetricID.Name, ShouldEqual, telemetry.MetricSuccess)
					So(capturedSamples.Samples[1].Value, ShouldEqual, 1)
					So(capturedSamples.Samples[1].Unit, ShouldEqual, "Count")

					So(capturedSamples.Samples[2].MetricID.Name, ShouldEqual, telemetry.MetricClientError)
					So(capturedSamples.Samples[2].Value, ShouldEqual, 0)
					So(capturedSamples.Samples[2].Unit, ShouldEqual, "Count")

					So(capturedSamples.Samples[3].MetricID.Name, ShouldEqual, telemetry.MetricServerError)
					So(capturedSamples.Samples[3].Value, ShouldEqual, 0)
					So(capturedSamples.Samples[3].Unit, ShouldEqual, "Count")
				})

				Convey("Then it should not call flush", func() {
					// AutoFlush is left at its zero value in this scope.
					So(capturedSamples.FlushCallCount, ShouldEqual, 0)
				})
			})

			Convey("When it fails with an invalid argument", func() {
				serverOp.SetStatus(operation.Status{Code: 3})
				serverOp.End()

				Convey("Then it should build client error availability metrics", func() {
					So(len(capturedSamples.Samples), ShouldEqual, 4)
					So(capturedSamples.Samples[1].MetricID.Name, ShouldEqual, telemetry.MetricSuccess)
					So(capturedSamples.Samples[1].Value, ShouldEqual, 0)

					So(capturedSamples.Samples[2].MetricID.Name, ShouldEqual, telemetry.MetricClientError)
					So(capturedSamples.Samples[2].Value, ShouldEqual, 1)

					So(capturedSamples.Samples[3].MetricID.Name, ShouldEqual, telemetry.MetricServerError)
					So(capturedSamples.Samples[3].Value, ShouldEqual, 0)
				})
			})

			Convey("When it fails with an internal error", func() {
				serverOp.SetStatus(operation.Status{Code: 13})
				serverOp.End()

				Convey("Then it should build internal error availability metrics", func() {
					So(len(capturedSamples.Samples), ShouldEqual, 4)
					So(capturedSamples.Samples[1].MetricID.Name, ShouldEqual, telemetry.MetricSuccess)
					So(capturedSamples.Samples[1].Value, ShouldEqual, 0)

					So(capturedSamples.Samples[2].MetricID.Name, ShouldEqual, telemetry.MetricClientError)
					So(capturedSamples.Samples[2].Value, ShouldEqual, 0)

					So(capturedSamples.Samples[3].MetricID.Name, ShouldEqual, telemetry.MetricServerError)
					So(capturedSamples.Samples[3].Value, ShouldEqual, 1)
				})
			})

			Convey("When it fails with an unknown error", func() {
				// An unrecognized status code is expected to be reported
				// as a server error.
				serverOp.SetStatus(operation.Status{Code: 123456789})
				serverOp.End()

				Convey("Then it should build internal error availability metrics", func() {
					So(len(capturedSamples.Samples), ShouldEqual, 4)
					So(capturedSamples.Samples[1].MetricID.Name, ShouldEqual, telemetry.MetricSuccess)
					So(capturedSamples.Samples[1].Value, ShouldEqual, 0)

					So(capturedSamples.Samples[2].MetricID.Name, ShouldEqual, telemetry.MetricClientError)
					So(capturedSamples.Samples[2].Value, ShouldEqual, 0)

					So(capturedSamples.Samples[3].MetricID.Name, ShouldEqual, telemetry.MetricServerError)
					So(capturedSamples.Samples[3].Value, ShouldEqual, 1)
				})
			})

			Convey("And a span operation", func() {
				spanOp := operation.Name{Kind: operation.KindSpan, Group: testDependencyProcessIdentifier.Service, Method: "TestSpan"}
				_, internalSpan := operationStarter.StartOp(serverCtx, spanOp)

				Convey("Then the span should record a timer metric", func() {
					time.Sleep(time.Millisecond)
					internalSpan.End()

					// The enclosing server operation is never ended in this
					// branch, so the span's sample is the only one expected.
					So(len(capturedSamples.Samples), ShouldEqual, 1)
					sample := capturedSamples.Samples[0]
					So(sample.MetricID, ShouldResemble, telemetry.MetricID{
						Name: "TestSpan",
						Dimensions: telemetry.DimensionSet{
							"Service":   "TestService",
							"Stage":     "TestStage",
							"Substage":  "TestSubstage",
							"Region":    "TestRegion",
							"Operation": "MyOperation",
						},
					})
					So(sample.RollupDimensions, ShouldNotBeEmpty)
					So(time.Since(sample.Timestamp).Seconds(), ShouldBeBetween, -10, 10)
					So(sample.Value, ShouldBeBetween, 0, 10)
					So(sample.Unit, ShouldEqual, "Seconds")
				})
			})

			Convey("And a client operation", func() {
				clientOp := operation.Name{Kind: operation.KindClient, Group: testDependencyProcessIdentifier.Service, Method: "OtherOperation"}
				_, clientSpan := operationStarter.StartOp(serverCtx, clientOp)

				Convey("When successful", func() {
					clientSpan.SetStatus(operation.Status{Code: 0})
					clientSpan.End()

					Convey("Then it should build 4 samples", func() {
						So(len(capturedSamples.Samples), ShouldEqual, 4)
						for _, sample := range capturedSamples.Samples {
							So(sample.MetricID.Dimensions, ShouldResemble, telemetry.DimensionSet{
								"Service":    "TestService",
								"Stage":      "TestStage",
								"Substage":   "TestSubstage",
								"Region":     "TestRegion",
								"Operation":  "MyOperation",
								"Dependency": "DependencyService:OtherOperation",
							})
							So(sample.RollupDimensions, ShouldNotBeEmpty)
							So(time.Since(sample.Timestamp).Seconds(), ShouldBeBetween, -10, 10)
						}
					})

					Convey("Then it should build a duration metric", func() {
						So(capturedSamples.Samples[0].MetricID.Name, ShouldEqual, telemetry.MetricDependencyDuration)
						So(capturedSamples.Samples[0].Value, ShouldBeBetween, 0, 1)
						So(capturedSamples.Samples[0].Unit, ShouldEqual, "Seconds")
					})

					Convey("Then it should build success dependency availability metrics", func() {
						So(capturedSamples.Samples[1].MetricID.Name, ShouldEqual, telemetry.MetricDependencySuccess)
						So(capturedSamples.Samples[1].Value, ShouldEqual, 1)
						So(capturedSamples.Samples[1].Unit, ShouldEqual, "Count")

						So(capturedSamples.Samples[2].MetricID.Name, ShouldEqual, telemetry.MetricDependencyClientError)
						So(capturedSamples.Samples[2].Value, ShouldEqual, 0)
						So(capturedSamples.Samples[2].Unit, ShouldEqual, "Count")

						So(capturedSamples.Samples[3].MetricID.Name, ShouldEqual, telemetry.MetricDependencyServerError)
						So(capturedSamples.Samples[3].Value, ShouldEqual, 0)
						So(capturedSamples.Samples[3].Unit, ShouldEqual, "Count")
					})
				})

				Convey("When it fails with an invalid argument", func() {
					clientSpan.SetStatus(operation.Status{Code: 3})
					clientSpan.End()

					Convey("Then it should build client error dependency availability metrics", func() {
						So(len(capturedSamples.Samples), ShouldEqual, 4)
						So(capturedSamples.Samples[1].MetricID.Name, ShouldEqual, telemetry.MetricDependencySuccess)
						So(capturedSamples.Samples[1].Value, ShouldEqual, 0)

						So(capturedSamples.Samples[2].MetricID.Name, ShouldEqual, telemetry.MetricDependencyClientError)
						So(capturedSamples.Samples[2].Value, ShouldEqual, 1)

						So(capturedSamples.Samples[3].MetricID.Name, ShouldEqual, telemetry.MetricDependencyServerError)
						So(capturedSamples.Samples[3].Value, ShouldEqual, 0)
					})
				})

				Convey("When it fails with an internal error", func() {
					clientSpan.SetStatus(operation.Status{Code: 13})
					clientSpan.End()

					Convey("Then it should build internal error dependency availability metrics", func() {
						So(len(capturedSamples.Samples), ShouldEqual, 4)
						So(capturedSamples.Samples[1].MetricID.Name, ShouldEqual, telemetry.MetricDependencySuccess)
						So(capturedSamples.Samples[1].Value, ShouldEqual, 0)

						So(capturedSamples.Samples[2].MetricID.Name, ShouldEqual, telemetry.MetricDependencyClientError)
						So(capturedSamples.Samples[2].Value, ShouldEqual, 0)

						So(capturedSamples.Samples[3].MetricID.Name, ShouldEqual, telemetry.MetricDependencyServerError)
						So(capturedSamples.Samples[3].Value, ShouldEqual, 1)
					})
				})
			})
		})

		// The remaining span and client scopes start from context.Background(),
		// i.e. without an enclosing server operation, and expect the
		// "Operation" dimension to be telemetry.MetricValueUnknownOperation.
		Convey("And a span operation", func() {
			spanOp := operation.Name{Kind: operation.KindSpan, Group: testDependencyProcessIdentifier.Service, Method: "TestSpan"}
			_, internalSpan := operationStarter.StartOp(context.Background(), spanOp)

			Convey("Then the span should record a timer metric", func() {
				time.Sleep(time.Millisecond)
				internalSpan.End()

				So(len(capturedSamples.Samples), ShouldEqual, 1)
				sample := capturedSamples.Samples[0]
				So(sample.MetricID, ShouldResemble, telemetry.MetricID{
					Name: "TestSpan",
					Dimensions: telemetry.DimensionSet{
						"Service":   "TestService",
						"Stage":     "TestStage",
						"Substage":  "TestSubstage",
						"Region":    "TestRegion",
						"Operation": telemetry.MetricValueUnknownOperation,
					},
				})
				So(sample.RollupDimensions, ShouldNotBeEmpty)
				So(time.Since(sample.Timestamp).Seconds(), ShouldBeBetween, -10, 10)
				So(sample.Value, ShouldBeBetween, 0, 10)
				So(sample.Unit, ShouldEqual, "Seconds")
			})
		})

		Convey("And a client operation", func() {
			clientOp := operation.Name{Kind: operation.KindClient, Group: testDependencyProcessIdentifier.Service, Method: "OtherOperation"}
			_, clientSpan := operationStarter.StartOp(context.Background(), clientOp)

			Convey("When successful", func() {
				clientSpan.SetStatus(operation.Status{Code: 0})
				clientSpan.End()

				Convey("Then it should build 4 samples", func() {
					So(len(capturedSamples.Samples), ShouldEqual, 4)
					for _, sample := range capturedSamples.Samples {
						So(sample.MetricID.Dimensions, ShouldResemble, telemetry.DimensionSet{
							"Service":    "TestService",
							"Stage":      "TestStage",
							"Substage":   "TestSubstage",
							"Region":     "TestRegion",
							"Operation":  telemetry.MetricValueUnknownOperation,
							"Dependency": "DependencyService:OtherOperation",
						})
						So(sample.RollupDimensions, ShouldNotBeEmpty)
						So(time.Since(sample.Timestamp).Seconds(), ShouldBeBetween, -10, 10)
					}
				})

				Convey("Then it should build a duration metric", func() {
					So(capturedSamples.Samples[0].MetricID.Name, ShouldEqual, telemetry.MetricDependencyDuration)
					So(capturedSamples.Samples[0].Value, ShouldBeBetween, 0, 1)
					So(capturedSamples.Samples[0].Unit, ShouldEqual, "Seconds")
				})

				Convey("Then it should build success dependency availability metrics", func() {
					So(capturedSamples.Samples[1].MetricID.Name, ShouldEqual, telemetry.MetricDependencySuccess)
					So(capturedSamples.Samples[1].Value, ShouldEqual, 1)
					So(capturedSamples.Samples[1].Unit, ShouldEqual, "Count")

					So(capturedSamples.Samples[2].MetricID.Name, ShouldEqual, telemetry.MetricDependencyClientError)
					So(capturedSamples.Samples[2].Value, ShouldEqual, 0)
					So(capturedSamples.Samples[2].Unit, ShouldEqual, "Count")

					So(capturedSamples.Samples[3].MetricID.Name, ShouldEqual, telemetry.MetricDependencyServerError)
					So(capturedSamples.Samples[3].Value, ShouldEqual, 0)
					So(capturedSamples.Samples[3].Unit, ShouldEqual, "Count")
				})
			})
		})

		Convey("And an unknown server operation", func() {
			// The Kind value is deliberately not one of the named kinds used
			// elsewhere in this file.
			serverOpName := operation.Name{Kind: operation.Kind(123456789), Group: testProcessIdentifier.Service, Method: "MyOperation"}
			_, serverOp := operationStarter.StartOp(context.Background(), serverOpName)
			serverOp.End()

			Convey("Then it should not build samples", func() {
				So(capturedSamples.Samples, ShouldBeEmpty)
			})
		})

		Convey("With AutoFlush", func() {
			metricsOpMonitor.AutoFlush = true

			Convey("And a server operation", func() {
				serverOpName := operation.Name{Kind: operation.KindServer, Group: testProcessIdentifier.Service, Method: "MyOperation"}
				_, serverOp := operationStarter.StartOp(context.Background(), serverOpName)

				Convey("When successful", func() {
					serverOp.SetStatus(operation.Status{Code: 0})
					serverOp.End()

					Convey("Then it should call flush", func() {
						So(capturedSamples.FlushCallCount, ShouldEqual, 1)
					})
				})
			})
		})
	})
}

// TestOperationMonitorConcurrency drives an AutoFlush-enabled OperationMonitor
// from several goroutines for a fixed duration and then checks that the
// observer saw one flush per request and four samples per operation (one
// server operation plus its client calls).
func TestOperationMonitorConcurrency(t *testing.T) {
	Convey("Given an operation starter with an operation monitor and AutoFlush", t, func() {
		counts := &threadSafeSampleObserver{}

		starter := &operation.Starter{
			OpMonitors: []operation.OpMonitor{
				&OperationMonitor{
					SampleReporter: telemetry.SampleReporter{
						SampleBuilder:  telemetry.SampleBuilder{ProcessIdentifier: testProcessIdentifier},
						SampleObserver: counts,
					},
					AutoFlush: true,
				},
			},
		}

		// Operation names are plain values, so they can be shared by every
		// request instead of being rebuilt on each call.
		serverName := operation.Name{Kind: operation.KindServer, Group: testProcessIdentifier.Service, Method: "MyOperation"}
		clientName := operation.Name{Kind: operation.KindClient, Group: testDependencyProcessIdentifier.Service, Method: "OtherOperation"}

		Convey("When requests are run concurrently", func() {
			const callsPerRequest = 5

			// Each request performs its client calls one after another.
			handleRequest := func([]interface{}) {
				ctx, op := starter.StartOp(context.Background(), serverName)

				for remaining := callsPerRequest; remaining > 0; remaining-- {
					_, span := starter.StartOp(ctx, clientName)
					span.SetStatus(operation.Status{Code: 0})
					span.End()
				}

				op.SetStatus(operation.Status{Code: 0})
				op.End()
			}

			load := &loadRunner{Concurrency: 4, TestDuration: 2 * time.Second, Task: handleRequest}
			load.Run()

			total := load.GetInvocations()
			fmt.Printf("(%v invocations, %.2f per second)", total, float64(total)/load.TestDuration.Seconds())

			Convey("Then it should succeed and report samples", func() {
				So(total, ShouldBeGreaterThan, 50)
				So(counts.FlushCount(), ShouldEqual, total)

				// Every server operation and every client call yields four samples.
				want := 4 * total * int64(1+callsPerRequest)
				So(counts.SampleCount(), ShouldEqual, want)
			})
		})

		Convey("When requests and clients are run concurrently", func() {
			const callsPerRequest = 5

			// Each request fans its client calls out to separate goroutines
			// and waits for all of them before ending the server operation.
			handleRequest := func([]interface{}) {
				ctx, op := starter.StartOp(context.Background(), serverName)

				var pending sync.WaitGroup
				pending.Add(callsPerRequest)
				for remaining := callsPerRequest; remaining > 0; remaining-- {
					go func() {
						defer pending.Done()

						_, span := starter.StartOp(ctx, clientName)
						span.SetStatus(operation.Status{Code: 0})
						span.End()
					}()
				}
				pending.Wait()

				op.SetStatus(operation.Status{Code: 0})
				op.End()
			}

			load := &loadRunner{Concurrency: 2, TestDuration: 2 * time.Second, Task: handleRequest}
			load.Run()

			total := load.GetInvocations()
			fmt.Printf("(%v invocations, %.2f per second)", total, float64(total)/load.TestDuration.Seconds())

			Convey("Then it should succeed and report samples", func() {
				So(total, ShouldBeGreaterThan, 50)
				So(counts.FlushCount(), ShouldEqual, total)

				// Every server operation and every client call yields four samples.
				want := 4 * total * int64(1+callsPerRequest)
				So(counts.SampleCount(), ShouldEqual, want)
			})
		})
	})
}

// BenchmarkRequestWithOneClientCall measures one server operation that wraps
// a single successful client operation, with samples sent to a no-op
// observer so only the monitor's own work is measured.
func BenchmarkRequestWithOneClientCall(b *testing.B) {
	metricsOpMonitor := &OperationMonitor{
		SampleReporter: telemetry.SampleReporter{
			SampleBuilder:  telemetry.SampleBuilder{ProcessIdentifier: testProcessIdentifier},
			SampleObserver: &nopSampleObserver{},
		},
	}

	operationStarter := &operation.Starter{OpMonitors: []operation.OpMonitor{metricsOpMonitor}}

	// Report allocations per iteration and keep the setup above out of the
	// timed region.
	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		serverOpName := operation.Name{Kind: operation.KindServer, Group: testProcessIdentifier.Service, Method: "MyOperation"}
		serverCtx, serverOp := operationStarter.StartOp(context.Background(), serverOpName)

		clientOp := operation.Name{Kind: operation.KindClient, Group: testDependencyProcessIdentifier.Service, Method: "OtherOperation"}
		_, clientSpan := operationStarter.StartOp(serverCtx, clientOp)

		clientSpan.SetStatus(operation.Status{Code: 0})
		clientSpan.End()

		serverOp.SetStatus(operation.Status{Code: 0})
		serverOp.End()
	}
}

// sampleObserver is a test double that keeps every observed sample and
// counts Flush calls. It does no locking, so it is only suitable for
// single-goroutine tests.
type sampleObserver struct {
	// Samples holds the observed samples in the order they arrived.
	Samples []*telemetry.Sample

	// FlushCallCount is the number of times Flush has been called.
	FlushCallCount int
}

// ObserveSample appends sample to Samples.
func (o *sampleObserver) ObserveSample(sample *telemetry.Sample) {
	o.Samples = append(o.Samples, sample)
}

// Flush only records that it was called; nothing is actually flushed.
func (o *sampleObserver) Flush() {
	o.FlushCallCount++
}

// Stop does nothing.
func (o *sampleObserver) Stop() {}

// nopSampleObserver discards everything it is given. The benchmark uses it
// so that observer work does not show up in the measurements.
type nopSampleObserver struct{}

// ObserveSample ignores the sample.
func (*nopSampleObserver) ObserveSample(*telemetry.Sample) {}

// Flush does nothing.
func (*nopSampleObserver) Flush() {}

// Stop does nothing.
func (*nopSampleObserver) Stop() {}

// threadSafeSampleObserver counts observed samples and flushes without
// retaining the samples. Both counters are only touched through sync/atomic,
// so the observer can be shared between goroutines.
type threadSafeSampleObserver struct {
	sampleCount int32
	flushCount  int32
}

// ObserveSample increments the sample counter; the sample itself is dropped.
func (o *threadSafeSampleObserver) ObserveSample(sample *telemetry.Sample) {
	atomic.AddInt32(&o.sampleCount, 1)
}

// SampleCount returns how many samples have been observed so far.
func (o *threadSafeSampleObserver) SampleCount() int32 {
	return atomic.LoadInt32(&o.sampleCount)
}

// Flush increments the flush counter.
func (o *threadSafeSampleObserver) Flush() {
	atomic.AddInt32(&o.flushCount, 1)
}

// FlushCount returns how many times Flush has been called so far.
func (o *threadSafeSampleObserver) FlushCount() int32 {
	return atomic.LoadInt32(&o.flushCount)
}

// Stop does nothing.
func (o *threadSafeSampleObserver) Stop() {}

// loadRunner runs tasks concurrently with optional setup and cleanup for each thread.
type loadRunner struct {
	// Invocations counts completed Task calls and is updated with
	// sync/atomic. It is kept as the first field because 64-bit atomic
	// operations need 64-bit alignment on 32-bit platforms, and only the
	// first word of an allocated struct is guaranteed to have it.
	Invocations int64

	// Concurrency is the number of worker goroutines Run starts.
	Concurrency int

	// TestDuration is how long workers keep starting new Task calls.
	TestDuration time.Duration

	// Task is the unit of work. It receives the value returned by
	// ThreadContextSetup for its worker (nil when no setup is configured).
	Task func([]interface{})

	// ThreadContextSetup, if non-nil, runs once per worker before its first
	// Task call; index is the worker's zero-based number.
	ThreadContextSetup func(index int) []interface{}

	// ThreadContextCleanup, if non-nil, runs once per worker after its last
	// Task call and receives the same parameters the worker's tasks were given.
	ThreadContextCleanup func(index int, parameters []interface{})
}

// Run runs the tasks. It starts Concurrency workers and blocks until every
// worker has passed the deadline (now + TestDuration) and finished cleanup.
// A Task call that is already running when the deadline passes is allowed to
// complete.
func (r *loadRunner) Run() {
	deadline := time.Now().Add(r.TestDuration)

	var workers sync.WaitGroup
	for i := 0; i < r.Concurrency; i++ {
		workers.Add(1)
		go func(index int) {
			defer workers.Done()
			r.runWorker(index, deadline)
		}(i)
	}
	workers.Wait()
}

// runWorker executes one worker's setup, its task loop until deadline, and
// its cleanup.
func (r *loadRunner) runWorker(index int, deadline time.Time) {
	var parameters []interface{}
	if r.ThreadContextSetup != nil {
		parameters = r.ThreadContextSetup(index)
	}

	for time.Now().Before(deadline) {
		r.Task(parameters)
		atomic.AddInt64(&r.Invocations, 1)
	}

	if r.ThreadContextCleanup != nil {
		r.ThreadContextCleanup(index, parameters)
	}
}

// GetInvocations returns the number of times the task was invoked.
func (r *loadRunner) GetInvocations() int64 {
	return atomic.LoadInt64(&r.Invocations)
}
