package middleware

import (
	"context"
	"fmt"
	"time"

	"github.com/twitchtv/twirp"
)

// reqStartCtxKey is the context key under which RequestReceived stores the
// request start time. The key is the address of a private int allocation held
// in an unexported variable, so keys set by other packages cannot equal it.
var (
	reqStartCtxKey = new(int)
)

// Statter is the minimal metrics sink that NewStatsdServerHooks reports to.
// The hooks in this package always pass a rate of 1.0 and discard any error
// returned by either method.
type Statter interface {
	// Inc is expected to bump the counter named stat by value; rate is
	// presumably a statsd-style sample rate — confirm against implementations.
	Inc(stat string, value int64, rate float32) error
	// TimingDuration is expected to record elapsed as a timing sample for stat.
	TimingDuration(stat string, elapsed time.Duration, rate float32) error
}

// NewStatsdServerHooks returns twirp server hooks that, for every response
// sent, increment a counter by 1 and record the request latency on stats.
// Both samples share the metric name
//
//	twirpservice.<service>.<method>.<status code>
//
// The start time is stored in the context by RequestReceived and read back in
// ResponseSent. If it is missing, "_nostart" is appended to the method segment
// and a 1ns duration is reported, so the anomaly is visible in the metric name.
func NewStatsdServerHooks(stats Statter) *twirp.ServerHooks {
	hooks := &twirp.ServerHooks{}

	// Before the handler runs: stamp the start time onto the context.
	hooks.RequestReceived = func(ctx context.Context) (context.Context, error) {
		return context.WithValue(ctx, reqStartCtxKey, time.Now()), nil
	}

	// After the response is written: emit a count and a timing sample.
	hooks.ResponseSent = func(ctx context.Context) {
		svc, meth, code := getServiceName(ctx), getMethodName(ctx), getStatusCode(ctx)

		finished := time.Now()
		began, found := ctx.Value(reqStartCtxKey).(time.Time)
		if !found {
			// Should never happen; tag the metric rather than hide the gap.
			meth += "_nostart"
			began = finished.Add(-1)
		}

		name := fmt.Sprintf("twirpservice.%s.%s.%s", svc, meth, code)
		// Metrics are best-effort: a ResponseSent hook returns nothing, so
		// there is no caller to hand an error to and failures are dropped.
		_ = stats.Inc(name, 1, 1.0)
		_ = stats.TimingDuration(name, finished.Sub(began), 1.0)
	}

	return hooks
}

// getServiceName returns the twirp service name carried by ctx. When the name
// is absent or empty (which should never happen), it returns the sentinel
// "hooksfail" so the problem shows up in the metric name.
func getServiceName(ctx context.Context) string {
	name, found := twirp.ServiceName(ctx)
	if found && name != "" {
		return name
	}
	return "hooksfail"
}

// getMethodName returns the twirp method name carried by ctx. When the name
// is absent or empty (which should never happen), it returns the sentinel
// "hooksfail" so the problem shows up in the metric name.
func getMethodName(ctx context.Context) string {
	if name, found := twirp.MethodName(ctx); found && name != "" {
		return name
	}
	return "hooksfail"
}

// getStatusCode returns the response status code that twirp recorded in ctx,
// as a string. When it is absent or empty (which should never happen), it
// returns "0" so the problem shows up in the metric name.
func getStatusCode(ctx context.Context) string {
	code, found := twirp.StatusCode(ctx)
	if found {
		if code != "" {
			return code
		}
	}
	return "0"
}
