package circuitmetrics

import (
	"context"
	"errors"
	"fmt"
	"math"
	"reflect"
	"strings"
	"sync"
	"testing"
	"time"

	identifier "code.justin.tv/amzn/TwitchProcessIdentifier"
	telemetry "code.justin.tv/amzn/TwitchTelemetry"
	"github.com/cep21/circuit/v3"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Test_nameTag checks that nameTag wraps its input in a single "CircuitName"
// dimension, substituting '_' for characters such as ',', '.', '*', '&' and
// '#', and truncating long names (a 65-character input yields 64 characters).
func Test_nameTag(t *testing.T) {
	cases := []struct {
		name  string
		input string
		want  map[string]string
	}{
		{name: "normal", input: "hello", want: map[string]string{"CircuitName": "hello"}},
		{name: "empty", input: "", want: map[string]string{"CircuitName": ""}},
		{name: "badchars", input: "aAzZ09_,", want: map[string]string{"CircuitName": "aAzZ09__"}},
		{name: "long", input: strings.Repeat("a", 65), want: map[string]string{"CircuitName": strings.Repeat("a", 64)}},
		{name: "abc.123.*&#", input: "abc.123.*&#", want: map[string]string{"CircuitName": "abc_123____"}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := nameTag(tc.input)
			if !reflect.DeepEqual(got, tc.want) {
				t.Errorf("nameTag() = %v, want %v", got, tc.want)
			}
		})
	}
}

// TestCommandProperties_Metrics drives two circuits through success, failure,
// bad-request, interrupt and fallback paths and checks the samples reported by
// CommandFactory for each outcome. It also checks that a Duration sample was
// emitted for every producer/circuit pair that was exercised.
func TestCommandProperties_Metrics(t *testing.T) {
	var obs observer
	factory := CommandFactory{
		Reporter: telemetry.SampleReporter{
			SampleBuilder: telemetry.SampleBuilder{ProcessIdentifier: identifier.ProcessIdentifier{
				Service: "TestService",
			}},
			SampleObserver: &obs,
		},
	}
	manager := &circuit.Manager{DefaultCircuitProperties: []circuit.CommandPropertiesConstructor{factory.CommandProperties}}

	c1 := manager.MustCreateCircuit("circuit1")
	c2 := manager.MustCreateCircuit("circuit2")

	// dims returns the dimension set expected on a sample emitted by producer
	// for the circuit called name.
	dims := func(producer, name string) map[string]string {
		return map[string]string{"Service": "TestService", "Producer": producer, "CircuitName": name}
	}
	c1Run := dims("circuit.run", "circuit1")
	c2Run := dims("circuit.run", "circuit2")
	c2Fallback := dims("circuit.fallback", "circuit2")
	ctx := context.Background()

	// circuit1: a success, then a failure, then a bad request.
	require.NoError(t, c1.Run(ctx, func(context.Context) error { return nil }))
	obs.assertSampleFound(t, c1Run, "Success", 1.0)

	require.Error(t, c1.Run(ctx, func(context.Context) error { return errors.New("test error") }))
	obs.assertSampleFound(t, c1Run, "Error_Failure", 1.0)
	obs.assertNumSamplesFound(t, c1Run, "Success", 0, 1)

	require.Error(t, c1.Run(ctx, func(context.Context) error {
		return circuit.SimpleBadRequest{Err: errors.New("bad request")}
	}))
	obs.assertSampleFound(t, c1Run, "Error_BadRequest", 1.0)
	obs.assertNumSamplesFound(t, c1Run, "Success", 0, 2)

	// circuit2: a pre-cancelled context yields an interrupt, and the fallback succeeds.
	cancelled, cancel := context.WithCancel(ctx)
	cancel()
	runErr := func(context.Context) error { return errors.New("cancelled anyway") }
	fallbackOK := func(context.Context, error) error { return nil }
	require.NoError(t, c2.Execute(cancelled, runErr, fallbackOK))
	obs.assertSampleFound(t, c2Run, "Error_Interrupt", 1.0)
	obs.assertNumSamplesFound(t, c2Run, "Success", 0, 1)
	obs.assertSampleFound(t, c2Fallback, "Success", 1.0)

	// circuit2: both the run and the fallback fail.
	forceFallback := func(context.Context) error { return errors.New("force fallback") }
	fallbackErr := func(context.Context, error) error { return errors.New("fallback error") }
	require.Error(t, c2.Execute(ctx, forceFallback, fallbackErr))
	obs.assertSampleFound(t, c2Run, "Error_Failure", 1.0)
	obs.assertNumSamplesFound(t, c2Run, "Success", 0, 2)
	obs.assertSampleFound(t, c2Fallback, "Error_Failure", 1.0)

	// Every exercised producer/circuit pair must have reported a duration.
	for _, d := range []map[string]string{c1Run, c2Run, c2Fallback} {
		require.NotEmpty(t, obs.findSamples(d, "Duration"), "duration was not reported")
	}
}

// TestConcurrencyCollector_Metrics checks the concurrency collector's gauges.
// While one command is parked inside circuit1 (MaxConcurrentRequests: 1), it
// checks that the collector reports IsOpen/Concurrent/ConcurrentPct for both
// registered circuits. It also checks that the configured "TestRollup" rollup
// is emitted with value 0 once that command succeeds.
func TestConcurrencyCollector_Metrics(t *testing.T) {
	var obs observer
	metrics := CommandFactory{
		Reporter: telemetry.SampleReporter{
			SampleBuilder: telemetry.SampleBuilder{ProcessIdentifier: identifier.ProcessIdentifier{
				Service: "TestService",
			}},
			SampleObserver: &obs,
		},
		AdditionalMetricRollups: []AdditionalMetricRollup{
			{
				RollupName: "TestRollup",
				ReportOnes: []Outcome{
					Failure,
					Timeout,
					ShortCircuit,
					ConcurrencyLimitReject,
				},
				ReportZeroes: []Outcome{
					Success,
				},
			},
		},
	}
	cm := &circuit.Manager{
		DefaultCircuitProperties: []circuit.CommandPropertiesConstructor{
			metrics.CommandProperties,
		},
	}
	cc := metrics.ConcurrencyCollector(cm)

	circuit1 := cm.MustCreateCircuit("circuit1", circuit.Config{
		Execution: circuit.ExecutionConfig{
			MaxConcurrentRequests: 1,
		},
	})
	// circuit2 is never run; it only needs to exist so the collector reports
	// idle gauges for it.
	cm.MustCreateCircuit("circuit2")

	var wg sync.WaitGroup
	// c is an unbuffered handshake. The command's first send tells the test it
	// is running; its second send blocks it until the test has collected
	// metrics.
	c := make(chan struct{})
	wg.Add(1)
	go func() {
		defer wg.Done()
		err := circuit1.Run(context.Background(), func(ctx context.Context) error {
			c <- struct{}{}
			c <- struct{}{}
			return nil
		})
		// Use assert, not require: require calls t.FailNow, which must only be
		// called from the goroutine running the test.
		assert.NoError(t, err)
	}()

	<-c // wait for command to park

	wg.Add(1)
	go func() {
		defer wg.Done()
		cc.Start()
	}()

	// wait for some samples to arrive
	require.Eventually(t, func() bool { return obs.sampleCount() > 0 }, time.Second, 10*time.Millisecond, "collector did not report any samples within a second")
	require.NoError(t, cc.Close())

	<-c // unblock the command
	wg.Wait()

	require.Equal(t, 9, obs.sampleCount(), "observed samples:\n%s", &obs)

	obs.assertSampleFound(t, map[string]string{"Service": "TestService", "Producer": "circuit", "CircuitName": "circuit1"}, "IsOpen", 0.0)
	obs.assertSampleFound(t, map[string]string{"Service": "TestService", "Producer": "circuit", "CircuitName": "circuit1"}, "Concurrent", 1.0)
	obs.assertSampleFound(t, map[string]string{"Service": "TestService", "Producer": "circuit", "CircuitName": "circuit1"}, "ConcurrentPct", 100.0)

	obs.assertSampleFound(t, map[string]string{"Service": "TestService", "Producer": "circuit", "CircuitName": "circuit2"}, "IsOpen", 0.0)
	obs.assertSampleFound(t, map[string]string{"Service": "TestService", "Producer": "circuit", "CircuitName": "circuit2"}, "Concurrent", 0.0)
	obs.assertSampleFound(t, map[string]string{"Service": "TestService", "Producer": "circuit", "CircuitName": "circuit2"}, "ConcurrentPct", 0.0)

	obs.assertSampleFound(t, map[string]string{"Service": "TestService", "Producer": "circuit.run", "CircuitName": "circuit1"}, "Rollup_TestRollup", 0.0)
}

// TestConcurrencyCollector_SkipZeroValueMetrics checks the collector with
// Opts.SkipZeroValueMetrics enabled. Only the non-zero gauges of the busy
// circuit should be reported: Concurrent=1 and ConcurrentPct=100 for circuit1.
// The all-zero gauges of circuit1 and the idle circuit2 are omitted.
func TestConcurrencyCollector_SkipZeroValueMetrics(t *testing.T) {
	var obs observer
	metrics := CommandFactory{
		Reporter: telemetry.SampleReporter{
			SampleBuilder: telemetry.SampleBuilder{ProcessIdentifier: identifier.ProcessIdentifier{
				Service: "TestService",
			}},
			SampleObserver: &obs,
		},
	}
	cm := &circuit.Manager{}
	cc := metrics.ConcurrencyCollector(cm)
	cc.Opts.SkipZeroValueMetrics = true

	circuit1 := cm.MustCreateCircuit("circuit1", circuit.Config{
		Execution: circuit.ExecutionConfig{
			MaxConcurrentRequests: 1,
		},
	})
	// circuit2 is never run; it exists so the test shows its zero gauges are skipped.
	cm.MustCreateCircuit("circuit2")

	var wg sync.WaitGroup
	// c is an unbuffered handshake. The command's first send tells the test it
	// is running; its second send blocks it until the test has collected
	// metrics.
	c := make(chan struct{})
	wg.Add(1)
	go func() {
		defer wg.Done()
		err := circuit1.Run(context.Background(), func(ctx context.Context) error {
			c <- struct{}{}
			c <- struct{}{}
			return nil
		})
		// Use assert, not require: require calls t.FailNow, which must only be
		// called from the goroutine running the test.
		assert.NoError(t, err)
	}()

	<-c // wait for command to park

	wg.Add(1)
	go func() {
		defer wg.Done()
		cc.Start()
	}()

	// wait for some samples to arrive
	require.Eventually(t, func() bool { return obs.sampleCount() > 0 }, time.Second, 10*time.Millisecond, "collector did not report any samples within a second")
	require.NoError(t, cc.Close())

	<-c // unblock the command
	wg.Wait()

	require.Equal(t, 2, obs.sampleCount(), "observed samples:\n%s", &obs)

	obs.assertSampleFound(t, map[string]string{"Service": "TestService", "Producer": "circuit", "CircuitName": "circuit1"}, "Concurrent", 1.0)
	obs.assertSampleFound(t, map[string]string{"Service": "TestService", "Producer": "circuit", "CircuitName": "circuit1"}, "ConcurrentPct", 100.0)
}

// TestConcurrencyCollector_StartStop checks that Close succeeds while Start is
// being invoked concurrently on a collector whose manager has no circuits.
func TestConcurrencyCollector_StartStop(t *testing.T) {
	var obs observer
	factory := CommandFactory{
		Reporter: telemetry.SampleReporter{
			SampleObserver: &obs,
			SampleBuilder: telemetry.SampleBuilder{ProcessIdentifier: identifier.ProcessIdentifier{
				Service: "TestService",
			}},
		},
	}
	collector := factory.ConcurrencyCollector(&circuit.Manager{})

	// NOTE(review): the Start goroutine is not joined. Whether Start returns
	// when Close runs before it begins cannot be determined from this file.
	go collector.Start()

	require.NoError(t, collector.Close())
}

// observer is an in-memory telemetry.SampleObserver for these tests. It records
// every sample it is given so tests can query them afterwards. The embedded
// Mutex guards Samples; every method that reads or writes Samples holds it.
type observer struct {
	sync.Mutex
	// Samples holds every sample passed to ObserveSample, in arrival order.
	Samples []*telemetry.Sample
}

// sampleCount reports how many samples have been observed so far.
func (s *observer) sampleCount() int {
	s.Lock()
	n := len(s.Samples)
	s.Unlock()
	return n
}

// assertSampleFound asserts that exactly one observed sample is named name,
// has exactly the dimensions dims, and has the given value.
func (s *observer) assertSampleFound(t *testing.T, dims map[string]string, name string, value float64) {
	// Mark as a helper so failures are reported at the caller's line.
	t.Helper()
	s.assertNumSamplesFound(t, dims, name, value, 1)
}

// assertNumSamplesFound asserts that exactly num observed samples are named
// name, have exactly the dimensions dims, and have the given value. Passing
// anyValue matches samples of any value. On failure it reports the actual match
// count and dumps every observed sample.
//
// The count check is the whole assertion: findSamplesWithValue already filters
// by value. A separate "contains value" check would be redundant for num > 0.
// It would also fail spuriously for num == 0 (an empty result never contains
// value) and for value == anyValue.
func (s *observer) assertNumSamplesFound(t *testing.T, dims map[string]string, name string, value float64, num int) {
	t.Helper()
	values := s.findSamplesWithValue(dims, name, value)
	if len(values) != num {
		assert.Fail(t, fmt.Sprintf("metric %q with %v equal to %v found %d time(s), want %d, in:\n%s",
			name, dims, value, len(values), num, s))
	}
}

// String renders each observed sample on its own line as
// `"<name>" with <dimensions> == <value>`, with the value rounded to a whole
// number. It is used in assertion failure messages.
func (s *observer) String() string {
	s.Lock()
	defer s.Unlock()

	var b strings.Builder
	for i, sample := range s.Samples {
		if i > 0 {
			b.WriteByte('\n')
		}
		fmt.Fprintf(&b, "%q with %v == %2.0f", sample.MetricID.Name, sample.MetricID.Dimensions, sample.Value)
	}
	return b.String()
}

// anyValue is a sentinel for findSamplesWithValue meaning "match regardless of
// value". Because of it, searching for a sample whose value is literally
// math.MaxFloat64 matches every sample instead.
const anyValue = math.MaxFloat64

// findSamples returns the values of every observed sample named name whose
// dimensions are exactly dims, whatever those values are.
func (s *observer) findSamples(dims map[string]string, name string) []float64 {
	matches := s.findSamplesWithValue(dims, name, anyValue)
	return matches
}

// findSamplesWithValue returns the values of observed samples that meet all of
// these conditions:
//   - the sample is named name;
//   - its dimensions equal dims exactly (compared with reflect.DeepEqual);
//   - its value equals value, or value is anyValue (which matches any value).
func (s *observer) findSamplesWithValue(dims map[string]string, name string, value float64) (values []float64) {
	want := telemetry.DimensionSet(dims)
	matchAnyValue := value == anyValue

	s.Lock()
	defer s.Unlock()

	for _, sample := range s.Samples {
		if sample.MetricID.Name != name {
			continue
		}
		if !reflect.DeepEqual(sample.MetricID.Dimensions, want) {
			continue
		}
		if matchAnyValue || sample.Value == value {
			values = append(values, sample.Value)
		}
	}
	return values
}

// ObserveSample appends sample to Samples. It holds the observer's lock while
// doing so, so concurrent callers are serialized.
func (s *observer) ObserveSample(sample *telemetry.Sample) {
	s.Lock()
	defer s.Unlock()
	s.Samples = append(s.Samples, sample)
}

// Flush is a no-op; the observer only keeps samples in memory.
func (*observer) Flush() {}

// Stop is a no-op; the observer has nothing to shut down.
func (*observer) Stop() {}

// Compile-time check that *observer satisfies telemetry.SampleObserver.
var _ telemetry.SampleObserver = (*observer)(nil)
