package twirptagginghook

import (
	"context"
	"strings"
	"time"

	"code.justin.tv/hygienic/metricscactusstatsd"
	"github.com/twitchtv/twirp"
)

// reqStartTimestampKey is the context key under which the request start time
// is stored. It is a freshly allocated pointer, so it compares equal only to
// itself and cannot collide with context keys defined by other packages.
var reqStartTimestampKey = new(int)

// markReqStart returns a child of ctx that records the current time as the
// request start, retrievable later with getReqStart.
func markReqStart(ctx context.Context) context.Context {
	now := time.Now()
	return context.WithValue(ctx, reqStartTimestampKey, now)
}

// getReqStart returns the start time recorded by markReqStart. When ctx
// carries no start time, it returns the zero time.Time and false.
func getReqStart(ctx context.Context) (time.Time, bool) {
	if start, found := ctx.Value(reqStartTimestampKey).(time.Time); found {
		return start, true
	}
	return time.Time{}, false
}

// NewTaggingServerHooks returns a *twirp.ServerHooks that records request and
// response metrics on a tagging (dimensional) statter derived from stats.
//
// Every metric is emitted through a sub-statter created with the dimension
// producer=twirp. The hooks emit:
//   - "requests" counter: tagged method=all when a request is received, and
//     method=<name> once the request has been routed to a method.
//   - "response" counter and timing: tagged method=all, plus method=<name>
//     when the method is known.
//   - "<status>" counter and timing (the response status code used as the
//     metric name), with the same method tags, when the status is known.
//
// Timings are recorded only when the start time stamped by RequestReceived
// is present. Method names and status codes are passed through sanitize
// before being used as metric names or tag values.
//
// If stats is a metricscactusstatsd.TelemetryStatsdShim, the hooks it
// provides are chained with these via twirp.ChainHooks.
func NewTaggingServerHooks(stats metricscactusstatsd.TaggingSubStatter) *twirp.ServerHooks {
	// Use the returned sub-statter for every metric so they all carry the
	// producer=twirp dimension.
	tagged := stats.NewDimensionalSubStatter(map[string]string{"producer": "twirp"})

	hooks := &twirp.ServerHooks{}

	hooks.RequestReceived = func(ctx context.Context) (context.Context, error) {
		// Stash the start time so ResponseSent can compute the latency.
		ctx = markReqStart(ctx)
		tagged.IncD("requests", map[string]string{"method": "all"}, 1)
		return ctx, nil
	}

	hooks.RequestRouted = func(ctx context.Context) (context.Context, error) {
		method, ok := twirp.MethodName(ctx)
		if !ok {
			return ctx, nil
		}
		tagged.IncD("requests", map[string]string{"method": sanitize(method)}, 1)
		return ctx, nil
	}

	hooks.ResponseSent = func(ctx context.Context) {
		// None of the start time, method name, or status code is guaranteed
		// to be present, so record every combination that is known.
		start, haveStart := getReqStart(ctx)
		method, haveMethod := twirp.MethodName(ctx)
		status, haveStatus := twirp.StatusCode(ctx)

		// Metric names: always "response", plus the status code if known.
		names := []string{"response"}
		if haveStatus {
			names = append(names, sanitize(status))
		}
		// Method tag values: always the "all" rollup, plus the method if known.
		methods := []string{"all"}
		if haveMethod {
			methods = append(methods, sanitize(method))
		}

		for _, name := range names {
			for _, m := range methods {
				tagged.IncD(name, map[string]string{"method": m}, 1)
			}
		}

		if !haveStart {
			return
		}
		dur := time.Since(start)
		for _, name := range names {
			for _, m := range methods {
				tagged.TimingDurationD(name, map[string]string{"method": m}, dur)
			}
		}
	}

	// Detect the shim on the caller-supplied statter; the derived sub-statter
	// may be a different concrete type.
	if t, ok := stats.(metricscactusstatsd.TelemetryStatsdShim); ok {
		return twirp.ChainHooks(hooks, t.ServerHooks())
	}

	return hooks
}

// sanitize makes s safe for use as a metric name or tag value. It replaces
// every rune outside [A-Za-z0-9] with '_'. Each byte of an invalid UTF-8
// sequence is likewise replaced by a single '_'.
func sanitize(s string) string {
	return strings.Map(sanitizeRune, s)
}

// sanitizeRune returns r unchanged if it is an ASCII letter or digit, and
// '_' for anything else (including all non-ASCII runes).
func sanitizeRune(r rune) rune {
	isLower := r >= 'a' && r <= 'z'
	isUpper := r >= 'A' && r <= 'Z'
	isDigit := r >= '0' && r <= '9'
	if isLower || isUpper || isDigit {
		return r
	}
	return '_'
}
