package twirphooks

import (
	"context"
	"strings"
	"time"

	twirpstatsd "code.justin.tv/chat/twirphookstatsd"
	"code.justin.tv/feeds/graphdb/proto/datastorerpc"
	"code.justin.tv/feeds/graphdb/proto/dynamoevent"
	"code.justin.tv/feeds/graphdb/proto/graphdb"
	"code.justin.tv/hygienic/twirpserviceslohook"
	"code.justin.tv/hygienic/twirpserviceslohook/statsdslo"

	"code.justin.tv/feeds/graphdb/cmd/graphdb/internal/accesslog"
	"code.justin.tv/feeds/graphdb/cmd/graphdb/internal/api/httpapi"
	"code.justin.tv/feeds/graphdb/cmd/graphdb/internal/interngraphdb"
	"code.justin.tv/feeds/graphdb/cmd/graphdb/internal/oldapi/conversion"
	"code.justin.tv/hygienic/statsdsender"
	"github.com/cactus/go-statsd-client/statsd"
	"github.com/twitchtv/twirp"
)

// contextKeys is an unexported key type for values these hooks store on the
// request context; being unexported, it cannot collide with keys defined by
// other packages.
type contextKeys int

// Keys under which the hooks pass per-request data from earlier hook
// callbacks to the ResponseSent callback.
const (
	// graphdbRequestReceivedCtxtKey holds the time.Time recorded when the
	// request was received.
	graphdbRequestReceivedCtxtKey contextKeys = 0

	// graphdbTwirpErrCtxKey holds the Error() text of the twirp error
	// reported for the request, if any.
	graphdbTwirpErrCtxKey contextKeys = 1
)

// twirpTraceID returns the value of the AWS X-Ray "X-Amz-Trace-Id" header from
// the HTTP request headers twirp exposes on ctx. It returns "" when ctx carries
// no headers or the header is not set.
//
// NOTE(review): twirp.HTTPRequestHeaders only returns headers that were put on
// the context with twirp.WithHTTPRequestHeaders; confirm the server's HTTP
// middleware does that, otherwise this always yields "".
func twirpTraceID(ctx context.Context) string {
	hdr, found := twirp.HTTPRequestHeaders(ctx)
	if !found {
		return ""
	}
	return hdr.Get("X-Amz-Trace-Id")
}

// sanitizeStatsd makes s usable as a single dot-separated segment of a statsd
// metric name. Every rune that is not an ASCII letter or digit (including '.',
// '-', '_', non-ASCII letters and invalid UTF-8 bytes) is replaced with '_'.
func sanitizeStatsd(s string) string {
	return strings.Map(sanitizeRune, s)
}

// sanitizeRune is the per-rune mapping behind sanitizeStatsd: ASCII letters
// and digits are returned unchanged, anything else becomes '_'.
func sanitizeRune(r rune) rune {
	isASCIILetter := ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z')
	isASCIIDigit := '0' <= r && r <= '9'
	if isASCIILetter || isASCIIDigit {
		return r
	}
	return '_'
}

func TwirpHooks(hostname string, statsd statsd.Statter, accessLog *accesslog.AccessLog) *twirp.ServerHooks {
	errorlessStats := &statsdsender.ErrorlessStatSender{
		StatSender: statsd.NewSubStatter("twirp.req"),
	}
	fullSLO := twirpserviceslohook.MergeSLO(datastorerpc.SLOWriter, datastorerpc.SLOReader, graphdb.SLOGraphDB, dynamoevent.SLODynamoUpdates)
	tracker := twirpserviceslohook.TwirpSLOTracker{
		SLO:        fullSLO,
		DefaultSLO: time.Second,
		StatTracker: &statsdslo.StatTracker{
			Statter: statsd.NewSubStatter("twirp.slo"),
		},
	}
	return twirp.ChainHooks(
		&twirp.ServerHooks{
			Error: func(ctx context.Context, err twirp.Error) context.Context {
				ctx = context.WithValue(ctx, graphdbTwirpErrCtxKey, err.Error())
				return ctx
			},
			RequestReceived: func(ctx context.Context) (context.Context, error) {
				ctx = context.WithValue(ctx, graphdbRequestReceivedCtxtKey, time.Now())
				ctx = accesslog.WithTrace(ctx)
				return ctx, nil
			},

			ResponseSent: func(ctx context.Context) {
				msg := &interngraphdb.AccessLog{
					Api:      interngraphdb.AccessLog_PROTOBUF,
					MsgTime:  conversion.ToProtoTime(time.Now()),
					TraceId:  twirpTraceID(ctx),
					Hostname: hostname,
				}
				if method, exists := twirp.MethodName(ctx); exists {
					msg.Method = method
				}

				if code, exists := twirp.StatusCode(ctx); exists {
					msg.StatusCode = code
				}

				if headers := httpapi.GetReqHeaders(ctx); len(headers) > 0 {
					msg.Headers = headers
				}

				recv, ok := ctx.Value(graphdbRequestReceivedCtxtKey).(time.Time)
				if ok {
					latency := time.Since(recv)
					msg.ReqReceived = conversion.ToProtoTime(recv)
					msg.LatencyMs = latency.Nanoseconds() / time.Millisecond.Nanoseconds()
					msg.LatencyStr = latency.String()
				}

				errMsg, ok := ctx.Value(graphdbTwirpErrCtxKey).(string)
				if ok {
					msg.Error = errMsg
				}
				accessLog.Event(ctx, msg)
			},
		},
		twirpstatsd.NewStatsdServerHooks(statsd, .1),
		tracker.ServerHook(),
		&twirp.ServerHooks{
			ResponseSent: func(ctx context.Context) {
				method, ok := twirp.MethodName(ctx)
				if !ok {
					return
				}
				source := httpapi.GetReqHeaders(ctx)[httpapi.HeaderXCallerService]
				if source == "" {
					source = "unknown"
				}
				errorlessStats.IncC(sanitizeStatsd(source)+"."+sanitizeStatsd(method), 1, .5)
			},
		},
	)
}
