package statsd

import (
	"fmt"
	"testing"
	"time"

	identifier "code.justin.tv/amzn/TwitchProcessIdentifier"
	telemetry "code.justin.tv/amzn/TwitchTelemetry"

	"github.com/cactus/go-statsd-client/statsd"
	. "github.com/smartystreets/goconvey/convey"
)

// mockStatter is a test double for statsd.Statter. It only counts how many
// times each metric name was reported as a counter (Inc) or as a timer
// (Timing, TimingDuration); every other Statter method is a no-op. The maps
// are unguarded, so it is not safe for concurrent use.
type mockStatter struct {
	timerMap   map[string]int32 // metric name -> Timing/TimingDuration calls
	counterMap map[string]int32 // metric name -> Inc calls
}

// NewMockStatter returns an empty mockStatter typed as a statsd.Statter.
// Tests type-assert the result back to *mockStatter to read the hit counts.
func NewMockStatter() statsd.Statter {
	return &mockStatter{
		timerMap:   make(map[string]int32),
		counterMap: make(map[string]int32),
	}
}

// Inc records one counter hit for s; the value and sample rate are ignored.
func (ms *mockStatter) Inc(s string, v int64, r float32) error {
	incrementOrInit(ms.counterMap, s)
	return nil
}

// Timing records one timer hit for s; the value and sample rate are ignored.
func (ms *mockStatter) Timing(s string, v int64, r float32) error {
	incrementOrInit(ms.timerMap, s)
	return nil
}

// TimingDuration records one timer hit for s; the duration and sample rate
// are ignored.
func (ms *mockStatter) TimingDuration(s string, t time.Duration, r float32) error {
	incrementOrInit(ms.timerMap, s)
	return nil
}

// The remaining statsd.Statter methods are deliberate no-ops.
func (ms *mockStatter) Dec(string, int64, float32) error         { return nil }
func (ms *mockStatter) Gauge(string, int64, float32) error       { return nil }
func (ms *mockStatter) GaugeDelta(string, int64, float32) error  { return nil }
func (ms *mockStatter) Set(string, string, float32) error        { return nil }
func (ms *mockStatter) SetInt(string, int64, float32) error      { return nil }
func (ms *mockStatter) Raw(string, string, float32) error        { return nil }
func (ms *mockStatter) NewSubStatter(s string) statsd.SubStatter { return nil }
func (ms *mockStatter) SetPrefix(s string)                       {}
func (ms *mockStatter) Close() error                             { return nil }

// CounterHits reports how many times Inc was called with key (0 if never).
func (ms *mockStatter) CounterHits(key string) int32 {
	return ms.counterMap[key]
}

// TimerHits reports how many times Timing or TimingDuration was called with
// key (0 if never).
func (ms *mockStatter) TimerHits(key string) int32 {
	return ms.timerMap[key]
}

// incrementOrInit adds one to m[k]. A missing key reads as the zero value, so
// the first call for a given k stores 1 and later calls keep counting up.
// m must be non-nil (writing to a nil map panics).
func incrementOrInit(m map[string]int32, k string) {
	m[k]++
}

// TestStatsdClient drives the buffered Client with a mockStatter swapped in
// and checks, after an explicit Flush, which metric names were emitted as
// counters and which as timers.
func TestStatsdClient(t *testing.T) {
	Convey("Given a statsd metrics sender client with a mock statter", t, func() {
		pid := &identifier.ProcessIdentifier{Service: "Foobar"}
		c, err := New(pid, "localhost:8125", nil)
		So(err, ShouldBeNil)
		ms := NewMockStatter()
		mock := ms.(*mockStatter)
		client := c.(*Client)

		// Halt automated flushes and statter generation so the only flush is the
		// explicit one in each leaf, then install the mock. goconvey re-runs this
		// outer setup for every leaf, so each leaf gets a fresh client and mock.
		client.Stop()
		client.withStatter(ms)

		Convey("And a basic sample with unit 'Count'", func() {
			dims := telemetry.DimensionSet{
				"Service":  "Foobar",
				"Stage":    "prod",
				"Substage": "primary",
				"Region":   "us-west-2",
			}
			sample := &telemetry.Sample{
				MetricID: telemetry.MetricID{
					Name:       "CoolMetric",
					Dimensions: dims,
				},
				Value: 0.5,
				Unit:  telemetry.UnitCount,
			}

			Convey("When the sample is submitted", func() {
				c.ObserveSample(sample)
				client.Flush()
				const metricName = "Service/Foobar.Stage/prod.Substage/primary.Region/us-west-2.CoolMetric"
				So(mock.CounterHits(metricName), ShouldEqual, 1)
				So(mock.TimerHits(metricName), ShouldEqual, 0)
			})
		})

		Convey("And a basic sample with unit 'Seconds' and with rollups", func() {
			dims := telemetry.DimensionSet{
				"Service": "Foobar",
				"Stage":   "prod",
				"Region":  "us-west-2",
				"MyKey":   "MyValue",
			}
			sample := &telemetry.Sample{
				MetricID: telemetry.MetricID{
					Name:       "RadMetric",
					Dimensions: dims,
				},
				Value: 0.5,
				Unit:  telemetry.UnitSeconds,
				RollupDimensions: [][]string{
					{"MyKey", "Region"},
					{"MyKey"},
				},
			}

			Convey("When the sample is submitted", func() {
				c.ObserveSample(sample)
				client.Flush()

				Convey("The stat and the rollups should be recorded as timers", func() {
					expected := []string{
						// the sample itself
						"Service/Foobar.Stage/prod.Substage/none.Region/us-west-2.MyKey/MyValue.RadMetric",
						// rollup over {MyKey, Region}
						"Service/Foobar.Stage/prod.Substage/none.Region/all.MyKey/all.RadMetric",
						// rollup over {MyKey}
						"Service/Foobar.Stage/prod.Substage/none.Region/us-west-2.MyKey/all.RadMetric",
					}
					for _, name := range expected {
						So(mock.TimerHits(name), ShouldEqual, 1)
						So(mock.CounterHits(name), ShouldEqual, 0)
					}
				})
			})
		})
	})
}

// TestStatsdUnbufferedClient drives the Unbuffered sender with a mockStatter
// swapped in and checks which metric names were emitted as counters and which
// as timers.
func TestStatsdUnbufferedClient(t *testing.T) {
	Convey("Given a statsd metrics sender client with a mock statter", t, func() {
		st := &identifier.ProcessIdentifier{
			Service: "Foobar",
		}
		c, err := NewUnbuffered(st, "localhost:8125", nil)
		So(err, ShouldBeNil)
		ms := NewMockStatter()

		// submit pushes one sample through the Unbuffered sender. Normally a
		// BufferedAggregator calls the Unbuffered sender; since this test targets
		// the statsd sender itself, it builds the distribution with a plain
		// Aggregator and calls FlushWithoutBuffering directly.
		submit := func(sample *telemetry.Sample) {
			aggregator := telemetry.NewAggregator(2 * time.Second)
			aggregator.AggregateSample(sample)
			distribution := aggregator.Flush()
			c.FlushWithoutBuffering(distribution)
		}

		Convey("And a basic sample with unit 'Count'", func() {
			c.(*Unbuffered).withStatter(ms)
			dimensions := telemetry.DimensionSet{
				"Service":  "Foobar",
				"Stage":    "prod",
				"Substage": "primary",
				"Region":   "us-west-2",
			}

			sample := &telemetry.Sample{
				MetricID: telemetry.MetricID{
					Name:       "CoolMetric",
					Dimensions: dimensions,
				},
				Value: float64(0.5),
				Unit:  telemetry.UnitCount,
			}

			Convey("When the sample is submitted", func() {
				submit(sample)
				metricName := "Service/Foobar.Stage/prod.Substage/primary.Region/us-west-2.CoolMetric"
				So(ms.(*mockStatter).CounterHits(metricName), ShouldEqual, 1)
				So(ms.(*mockStatter).TimerHits(metricName), ShouldEqual, 0)
			})
		})

		Convey("And a basic sample with unit 'Seconds' and with rollups", func() {
			c.(*Unbuffered).withStatter(ms)
			dimensions := telemetry.DimensionSet{
				"Service": "Foobar",
				"Stage":   "prod",
				"Region":  "us-west-2",
				"MyKey":   "MyValue",
			}

			sample := &telemetry.Sample{
				MetricID: telemetry.MetricID{
					Name:       "RadMetric",
					Dimensions: dimensions,
				},
				Value: float64(0.5),
				Unit:  telemetry.UnitSeconds,
				RollupDimensions: [][]string{
					{"MyKey", "Region"},
					{"MyKey"},
				},
			}

			Convey("When the sample is submitted", func() {
				submit(sample)
				Convey("The stat and the rollups should be recorded as timers", func() {
					metric1 := "Service/Foobar.Stage/prod.Substage/none.Region/us-west-2.MyKey/MyValue.RadMetric"
					metric2 := "Service/Foobar.Stage/prod.Substage/none.Region/all.MyKey/all.RadMetric"
					metric3 := "Service/Foobar.Stage/prod.Substage/none.Region/us-west-2.MyKey/all.RadMetric"
					So(ms.(*mockStatter).TimerHits(metric1), ShouldEqual, 1)
					So(ms.(*mockStatter).CounterHits(metric1), ShouldEqual, 0)
					So(ms.(*mockStatter).TimerHits(metric2), ShouldEqual, 1)
					So(ms.(*mockStatter).CounterHits(metric2), ShouldEqual, 0)
					So(ms.(*mockStatter).TimerHits(metric3), ShouldEqual, 1)
					So(ms.(*mockStatter).CounterHits(metric3), ShouldEqual, 0)
				})
			})
		})
	})
}

// TestDurationFromFloat64 checks that a fractional number of seconds converts
// to a time.Duration that prints as exactly the same number of seconds.
func TestDurationFromFloat64(t *testing.T) {
	Convey("Given a number of seconds as a float64", t, func() {
		const seconds = 1.2345
		Convey("When it is converted to a duration", func() {
			got := durationFromFloat64(seconds)
			Convey("Then it still represents the same amount of time", func() {
				So(got.String(), ShouldEqual, "1.2345s")
			})
		})
	})
}

// TestHandleNewStatterErrorUnbuffered stubs statterMaker to succeed on its
// first call and fail on every later one. After a flush, the Unbuffered
// sender's statter is expected to be nil, and a second flush must then
// complete without panicking.
func TestHandleNewStatterErrorUnbuffered(t *testing.T) {
	// Restore whatever statterMaker held before this test instead of assuming
	// it was statsd.NewBufferedClient.
	origStatterMaker := statterMaker
	defer func() {
		statterMaker = origStatterMaker
	}()

	var calledOnce bool
	statterMaker = func(addr, prefix string, flushInterval time.Duration, flushBytes int) (statsd.Statter, error) {
		if calledOnce {
			return nil, fmt.Errorf("lol, nope")
		}
		calledOnce = true
		return NewMockStatter(), nil
	}

	dimensions := telemetry.DimensionSet{
		"Service":  "Foobar",
		"Stage":    "prod",
		"Substage": "primary",
		"Region":   "us-west-2",
	}

	sample := &telemetry.Sample{
		MetricID: telemetry.MetricID{
			Name:       "CoolMetric",
			Dimensions: dimensions,
		},
		Value: float64(0.5),
		Unit:  telemetry.UnitCount,
	}

	st := &identifier.ProcessIdentifier{
		Service: "Foobar",
	}
	unbufferedClient, err := NewUnbuffered(st, "localhost:8125", nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	aggregator := telemetry.NewAggregator(2 * time.Second)
	aggregator.AggregateSample(sample)
	distribution := aggregator.Flush()

	unbufferedClient.FlushWithoutBuffering(distribution)
	if unbufferedClient.(*Unbuffered).statter != nil {
		t.Fatal("expected the client's statter to be nil")
	}
	// Flushing again with no statter available must not panic.
	unbufferedClient.FlushWithoutBuffering(distribution)
}

// TestHandleNewStatterErrorBuffered stubs statterMaker to succeed on its first
// call and fail on every later one. After a Flush, the buffered Client's
// statter is expected to be nil, and a further ObserveSample/Flush round must
// then complete without panicking.
func TestHandleNewStatterErrorBuffered(t *testing.T) {
	// Restore whatever statterMaker held before this test instead of assuming
	// it was statsd.NewBufferedClient.
	origStatterMaker := statterMaker
	defer func() {
		statterMaker = origStatterMaker
	}()

	var calledOnce bool
	statterMaker = func(addr, prefix string, flushInterval time.Duration, flushBytes int) (statsd.Statter, error) {
		if calledOnce {
			return nil, fmt.Errorf("lol, nope")
		}
		calledOnce = true
		return NewMockStatter(), nil
	}

	dimensions := telemetry.DimensionSet{
		"Service":  "Foobar",
		"Stage":    "prod",
		"Substage": "primary",
		"Region":   "us-west-2",
	}

	sample := &telemetry.Sample{
		MetricID: telemetry.MetricID{
			Name:       "CoolMetric",
			Dimensions: dimensions,
		},
		Value: float64(0.5),
		Unit:  telemetry.UnitCount,
	}

	st := &identifier.ProcessIdentifier{
		Service: "Foobar",
	}
	bufferedClient, err := New(st, "localhost:8125", nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	client := bufferedClient.(*Client)
	// Stop automated flushes and statter generation when the test ends. This
	// defer is registered after the statterMaker restore, and defers run LIFO,
	// so the client is stopped before statterMaker is reassigned.
	defer client.Stop()

	bufferedClient.ObserveSample(sample)
	client.Flush()
	if client.statter != nil {
		t.Fatal("expected the client's statter to be nil")
	}
	// Observing and flushing again with no statter available must not panic.
	bufferedClient.ObserveSample(sample)
	client.Flush()
}
