package sfxresink

import (
	"bytes"
	"sort"

	"context"

	"github.com/signalfx/golib/datapoint"
	"github.com/signalfx/golib/datapoint/dpsink"
	"github.com/signalfx/golib/event"
	"github.com/signalfx/golib/sfxclient"
)

// Resink is a sink that removes a configured set of dimensions from incoming
// datapoints before handing them on to Fallthrough. Events pass through
// untouched.
//
// When a datapoint.Counter actually loses one or more dimensions, it is
// re-emitted as a datapoint.Count under "<metric>.per_host", with its value
// derived from the previous raw value recorded for the same reduced series
// (see convertValue).
type Resink struct {
	// Fallthrough receives every event and every (rewritten) datapoint.
	Fallthrough       dpsink.Sink
	// RemovedDimensions is the set of dimension names stripped from every
	// datapoint; only the map keys are consulted.
	RemovedDimensions map[string]struct{}
	// previousCC maps hashKey of a rewritten counter series to the last raw
	// value received for it. It is allocated lazily by AddDatapoints.
	// NOTE(review): read and written without any locking — confirm callers
	// never invoke AddDatapoints concurrently.
	previousCC        map[string]datapoint.Value
}

// Compile-time assertions that *Resink implements both sink interfaces it is
// used as; these declarations produce no runtime behavior.
var (
	_ sfxclient.Sink = (*Resink)(nil)
	_ dpsink.Sink    = (*Resink)(nil)
)

// hashKey builds the map key under which AddDatapoints remembers the previous
// raw value of a rewritten counter series. A series is identified by its metric
// name plus every dimension name/value pair. Pairs are visited in sorted name
// order, so the key does not depend on Go's randomized map iteration.
//
// Both the dimension name and its value are included. Hashing only the values
// would make, for example, {host: "a"} and {region: "a"} on the same metric
// produce the same key, so the two series would share and corrupt each other's
// delta state. Fields are separated by NUL bytes, which keeps the key
// unambiguous as long as metric names, dimension names and values contain no
// NUL byte themselves.
func hashKey(dp *datapoint.Datapoint) string {
	keys := make([]string, 0, len(dp.Dimensions))
	for k := range dp.Dimensions {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var buf bytes.Buffer
	buf.WriteString(dp.Metric)
	for _, k := range keys {
		buf.WriteByte(0)
		buf.WriteString(k)
		buf.WriteByte(0)
		buf.WriteString(dp.Dimensions[k])
	}
	return buf.String()
}

// convertValue turns a cumulative counter reading into the increase since the
// previous reading of the same series.
//
// prevVal is the last raw value recorded for the series, or nil on the first
// observation; in that case thisVal is returned unchanged. A difference is only
// computed when both values are ints or both are floats. Any other combination
// returns thisVal unchanged.
//
// A datapoint.Counter may wrap or reset, for example when the reporting
// process restarts. A reading lower than the previous one is therefore treated
// as a reset: the counter is assumed to have restarted from zero, so thisVal
// itself is the increase. Without this check the result would be a negative
// Count.
func convertValue(prevVal datapoint.Value, thisVal datapoint.Value) datapoint.Value {
	if prevVal == nil {
		return thisVal
	}
	switch cur := thisVal.(type) {
	case datapoint.IntValue:
		if prev, ok := prevVal.(datapoint.IntValue); ok {
			if cur.Int() < prev.Int() {
				// Counter reset: report the post-reset reading as the increase.
				return thisVal
			}
			return datapoint.NewIntValue(cur.Int() - prev.Int())
		}
	case datapoint.FloatValue:
		if prev, ok := prevVal.(datapoint.FloatValue); ok {
			if cur.Float() < prev.Float() {
				// Counter reset: report the post-reset reading as the increase.
				return thisVal
			}
			return datapoint.NewFloatValue(cur.Float() - prev.Float())
		}
	}
	return thisVal
}

// AddEvents hands events straight to Fallthrough without modification; only
// datapoints are rewritten by this sink. The downstream error is returned as-is.
func (r *Resink) AddEvents(ctx context.Context, events []*event.Event) error {
	next := r.Fallthrough
	return next.AddEvents(ctx, events)
}

// AddDatapoints strips RemovedDimensions from every datapoint and forwards the
// rewritten copies to Fallthrough, returning its error. The caller's datapoints
// are not modified: each one is shallow-copied and given a freshly built
// Dimensions map.
//
// A datapoint.Counter that actually lost at least one dimension is renamed
// with a ".per_host" suffix and re-typed as datapoint.Count. Its value becomes
// convertValue(previous raw value, current raw value), where the previous raw
// value is looked up in r.previousCC by hashKey. The current raw value is then
// stored there for the next call.
//
// NOTE(review): r.previousCC is read and written without locking, so this
// method is only safe if callers serialize it — confirm how the sink is wired.
func (r *Resink) AddDatapoints(ctx context.Context, points []*datapoint.Datapoint) error {
	out := make([]*datapoint.Datapoint, 0, len(points))
	for _, orig := range points {
		dp := *orig
		before := len(dp.Dimensions)
		dp.Dimensions = r.copyWithoutDims(dp.Dimensions)

		if dp.MetricType == datapoint.Counter && len(dp.Dimensions) != before {
			dp.Metric += ".per_host"
			key := hashKey(&dp)
			if r.previousCC == nil {
				r.previousCC = map[string]datapoint.Value{}
			}
			raw := dp.Value
			dp.Value = convertValue(r.previousCC[key], raw)
			r.previousCC[key] = raw
			dp.MetricType = datapoint.Count
		}

		out = append(out, &dp)
	}
	return r.Fallthrough.AddDatapoints(ctx, out)
}

// copyWithoutDims returns a newly allocated copy of dims that omits every name
// present in r.RemovedDimensions. dims itself is never modified. The result is
// non-nil even when dims is nil or empty.
func (r *Resink) copyWithoutDims(dims map[string]string) map[string]string {
	kept := make(map[string]string, len(dims))
	for name, val := range dims {
		if _, drop := r.RemovedDimensions[name]; drop {
			continue
		}
		kept[name] = val
	}
	return kept
}
