package taggingcircuitsmetrics

import (
	"context"
	"errors"
	"reflect"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"code.justin.tv/hygienic/metricscactusstatsd"
	"github.com/cep21/circuit/metrics/responsetimeslo"
	"github.com/cep21/circuit/v3"
	"github.com/stretchr/testify/require"
)

// rememberStats is an in-memory TaggingSubStatter fake used by these tests.
// For every metric it receives it stores only the dimension map, extended
// with a "metricName" entry; the numeric values are not kept.
type rememberStats struct {
	// mu guards every slice below.
	mu sync.Mutex

	incs    []map[string]string // one entry per IncD call
	gauges  []map[string]string // one entry per GaugeD call
	timings []map[string]string // one entry per TimingDurationD call
}

// IncD records the dimensions of a counter increment, tagged with
// "metricName" set to stat. val is ignored.
func (n *rememberStats) IncD(stat string, dims map[string]string, val int64) {
	n.mu.Lock()
	defer n.mu.Unlock()
	// append on a nil slice allocates, so no explicit initialisation is needed.
	entry := combineDimensions(dims, map[string]string{"metricName": stat})
	n.incs = append(n.incs, entry)
}

// EachGauge calls f with every recorded gauge dimension set, in recording
// order. The lock is held for the whole iteration, so f must not call back
// into n.
func (n *rememberStats) EachGauge(f func(map[string]string)) {
	n.mu.Lock()
	defer n.mu.Unlock()
	for i := range n.gauges {
		f(n.gauges[i])
	}
}

// EachCounter calls f with every recorded counter dimension set, in
// recording order. The lock is held for the whole iteration, so f must not
// call back into n.
func (n *rememberStats) EachCounter(f func(map[string]string)) {
	n.mu.Lock()
	defer n.mu.Unlock()
	for i := range n.incs {
		f(n.incs[i])
	}
}

// NumGauges reports how many GaugeD calls have been recorded so far.
func (n *rememberStats) NumGauges() int {
	n.mu.Lock()
	count := len(n.gauges)
	n.mu.Unlock()
	return count
}

// GaugeD records the dimensions of a gauge update, tagged with "metricName"
// set to stat. val is ignored.
func (n *rememberStats) GaugeD(stat string, dims map[string]string, val int64) {
	n.mu.Lock()
	defer n.mu.Unlock()
	nameDim := map[string]string{"metricName": stat}
	n.gauges = append(n.gauges, combineDimensions(dims, nameDim))
}

// TimingDurationD records the dimensions of a timing sample, tagged with
// "metricName" set to stat. The duration itself is not stored.
func (n *rememberStats) TimingDurationD(stat string, dims map[string]string, val time.Duration) {
	n.mu.Lock()
	defer n.mu.Unlock()
	nameDim := map[string]string{"metricName": stat}
	n.timings = append(n.timings, combineDimensions(dims, nameDim))
}
// NewDimensionalSubStatter returns a sub-statter that adds dims to every
// metric before forwarding it to n.
func (n *rememberStats) NewDimensionalSubStatter(dims map[string]string) metricscactusstatsd.TaggingSubStatter {
	return &rememberDims{parrent: n, dims: dims}
}

// rememberDims is a sub-statter: each metric call combines the caller's dims
// with this value's own dims (via combineDimensions) and forwards the call to
// parrent (sic). Calling NewDimensionalSubStatter on it nests another
// rememberDims, so dimensions accumulate down the chain until they reach the
// root statter.
//
// NOTE: field order matters; values are built with positional composite
// literals ({parent, dims}).
type rememberDims struct {
	parrent metricscactusstatsd.TaggingSubStatter
	dims    map[string]string
}

// IncD forwards a counter increment to the parent with n.dims combined in.
func (n *rememberDims) IncD(stat string, dims map[string]string, val int64) {
	merged := combineDimensions(n.dims, dims)
	n.parrent.IncD(stat, merged, val)
}

// GaugeD forwards a gauge update to the parent with n.dims combined in.
func (n *rememberDims) GaugeD(stat string, dims map[string]string, val int64) {
	merged := combineDimensions(n.dims, dims)
	n.parrent.GaugeD(stat, merged, val)
}

// TimingDurationD forwards a timing sample to the parent with n.dims
// combined in.
func (n *rememberDims) TimingDurationD(stat string, dims map[string]string, val time.Duration) {
	merged := combineDimensions(n.dims, dims)
	n.parrent.TimingDurationD(stat, merged, val)
}

// NewDimensionalSubStatter nests a further sub-statter whose dims are added
// on top of n's before metrics reach the parent.
func (n *rememberDims) NewDimensionalSubStatter(dims map[string]string) metricscactusstatsd.TaggingSubStatter {
	return &rememberDims{parrent: n, dims: dims}
}

// Compile-time check that *rememberStats satisfies TaggingSubStatter.
var _ metricscactusstatsd.TaggingSubStatter = (*rememberStats)(nil)

// func TestAppendStatsdParts(t *testing.T) {
// 	if x := appendStatsdParts(sanitizeStatsd, "hello", "", "world"); x != "hello.world" {
// 		t.Fatalf("expect hello.world, got %s", x)
// 	}
// }

// TestConcurrencyCollector_delay checks that a zero-value collector reports a
// 10s delay and that setting Delay (in nanoseconds) changes it.
func TestConcurrencyCollector_delay(t *testing.T) {
	x := ConcurrencyCollector{}
	// require.Equal takes (expected, actual); keeping that order makes a
	// failure message label the two values correctly.
	require.Equal(t, time.Second*10, x.delay())
	x.Delay.Set(time.Second.Nanoseconds() * 3)
	require.Equal(t, time.Second*3, x.delay())
}

// waitForGauge blocks until ss has recorded a gauge whose dimension map
// (including "metricName") is reflect.DeepEqual to dims. It polls in real
// time and never gives up, so callers must bound the wait themselves.
// clk is currently unused.
func waitForGauge(dims map[string]string, ss *rememberStats, clk *mockClock) {
	pollUntilDimsRecorded(dims, ss.EachGauge)
}

// waitForCounter blocks until ss has recorded a counter whose dimension map
// (including "metricName") is reflect.DeepEqual to dims. It polls in real
// time and never gives up, so callers must bound the wait themselves.
// clk is currently unused.
func waitForCounter(dims map[string]string, ss *rememberStats, clk *mockClock) {
	pollUntilDimsRecorded(dims, ss.EachCounter)
}

// pollUntilDimsRecorded sleeps 1ms and then scans every entry yielded by
// each. It repeats until one entry deeply equals dims.
func pollUntilDimsRecorded(dims map[string]string, each func(func(map[string]string))) {
	for {
		time.Sleep(time.Millisecond)
		found := false
		each(func(d map[string]string) {
			if !found && reflect.DeepEqual(dims, d) {
				found = true
			}
		})
		if found {
			return
		}
	}
}

// TestConcurrencyCollector_Start wires a CommandFactory (plus an SLO
// collector) into a circuit manager and runs the circuit once with an error
// and once successfully. It then ticks a mock clock until every expected gauge
// and counter has been reported to an in-memory rememberStats.
func TestConcurrencyCollector_Start(t *testing.T) {
	clk := mockClock{}
	now := time.Now()
	clk.Set(now)
	ss := rememberStats{}
	c := CommandFactory{
		StatSender: &ss,
	}
	tf := responsetimeslo.Factory{
		CollectorConstructors: []func(circuitName string) responsetimeslo.Collector{
			c.SLOCollector,
		},
	}
	m := &circuit.Manager{
		DefaultCircuitProperties: []circuit.CommandPropertiesConstructor{
			c.CommandProperties,
			tf.CommandProperties,
		},
	}
	exampleCircuit := m.MustCreateCircuit("example")
	require.NotNil(t, exampleCircuit)
	cc := c.ConcurrencyCollector(m)
	cc.timeAfter = clk.After
	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		cc.Start()
	}()
	expectedGauges := []map[string]string{
		{"Producer": "circuit", "CircuitName": "example", "metricName": "is_open"},
		{"Producer": "circuit", "CircuitName": "example", "metricName": "concurrent"},
		{"Producer": "circuit", "CircuitName": "example", "metricName": "max_concurrent"},
	}
	expectedCounters := []map[string]string{
		{"Producer": "circuit.run", "CircuitName": "example", "metricName": "success"},
		{"Producer": "circuit.run", "CircuitName": "example", "metricName": "err_failure"},
		{"Producer": "circuit.slo", "CircuitName": "example", "metricName": "failed"},
		{"Producer": "circuit.slo", "CircuitName": "example", "metricName": "passed"},
	}
	expectedTotal := int64(len(expectedCounters) + len(expectedGauges))

	// One waiter goroutine per expected metric; each bumps foundExpected once
	// its metric has been seen.
	foundExpected := int64(0)
	for _, e := range expectedGauges {
		wg.Add(1)
		e := e
		go func() {
			defer wg.Done()
			defer atomic.AddInt64(&foundExpected, 1)
			waitForGauge(e, &ss, &clk)
		}()
	}
	for _, e := range expectedCounters {
		wg.Add(1)
		e := e
		go func() {
			defer wg.Done()
			defer atomic.AddInt64(&foundExpected, 1)
			waitForCounter(e, &ss, &clk)
		}()
	}

	// Do 1 of most types to make sure they trigger
	require.Error(t, exampleCircuit.Run(context.Background(), func(ctx context.Context) error {
		return errors.New("bad")
	}))
	require.NoError(t, exampleCircuit.Run(context.Background(), func(ctx context.Context) error {
		return nil
	}))

	// Keep ticking until every expected gauge and counter has been seen.
	TickUntil(&clk, func() bool {
		return atomic.LoadInt64(&foundExpected) == expectedTotal
	}, time.Millisecond, time.Second)
	require.NoError(t, cc.Close())
	// The waiter goroutines poll until their metric appears and have no other
	// exit. If any metric is still missing after TickUntil returns,
	// wg.Wait() would block forever. Fail the test instead of hanging.
	require.Equal(t, expectedTotal, atomic.LoadInt64(&foundExpected), "not all expected metrics were reported")
	wg.Wait()
}

// Test_nameTag checks the "CircuitName" tag produced for several circuit
// names. The cases cover plain and empty input, characters replaced by '_',
// and truncation of a 65-character name to 64 characters.
func Test_nameTag(t *testing.T) {
	cases := []struct {
		name  string
		input string
		want  map[string]string
	}{
		{name: "normal", input: "hello", want: map[string]string{"CircuitName": "hello"}},
		{name: "empty", input: "", want: map[string]string{"CircuitName": ""}},
		{name: "badchars", input: "aAzZ09_,", want: map[string]string{"CircuitName": "aAzZ09__"}},
		{name: "long", input: strings.Repeat("a", 65), want: map[string]string{"CircuitName": strings.Repeat("a", 64)}},
		{name: "abc.123.*&#", input: "abc.123.*&#", want: map[string]string{"CircuitName": "abc_123____"}},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			got := nameTag(tc.input)
			if !reflect.DeepEqual(got, tc.want) {
				t.Errorf("nameTag() = %v, want %v", got, tc.want)
			}
		})
	}
}
