package lib

import (
	"context"
	"fmt"
	"sync"
)

// traceBody travels alongside a context and collects arbitrary key/value pairs, grouped by value type, so that they
// can later be written out in the access logs. The zero value is ready to use: every map is allocated by the first
// write that needs it.
type traceBody struct {
	// mu guards the three maps below.
	mu sync.Mutex

	strings map[string]string
	ints    map[string]int64
	floats  map[string]float64
}

// addString stores value under key in the string map, replacing any value already recorded for that key.
// It is a no-op on a nil receiver, and it creates the map on first use.
func (t *traceBody) addString(key string, value string) {
	if t == nil {
		return
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	if t.strings != nil {
		t.strings[key] = value
		return
	}
	// First string written to this trace: build the map around it.
	t.strings = map[string]string{key: value}
}

// incInt adds value to the integer recorded under key. A key that has not been seen before counts from zero.
// It is a no-op on a nil receiver, and it creates the map on first use.
func (t *traceBody) incInt(key string, value int64) {
	if t == nil {
		return
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	if t.ints == nil {
		// Nothing recorded yet, so the increment starts from zero and the result is just value.
		t.ints = map[string]int64{key: value}
		return
	}
	t.ints[key] += value
}

// incFloat adds value to the float recorded under key. A key that has not been seen before counts from zero.
// It is a no-op on a nil receiver, and it creates the map on first use.
func (t *traceBody) incFloat(key string, value float64) {
	if t == nil {
		return
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	if t.floats == nil {
		t.floats = map[string]float64{}
	}
	// Always go through the addition (even for a fresh key) so the stored result is exactly 0 + value.
	t.floats[key] += value
}

// setInt records value under key in the integer map, overwriting whatever was stored (or accumulated) there.
// It is a no-op on a nil receiver, and it creates the map on first use.
func (t *traceBody) setInt(key string, value int64) {
	if t == nil {
		return
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	if t.ints != nil {
		t.ints[key] = value
		return
	}
	// First integer written to this trace: build the map around it.
	t.ints = map[string]int64{key: value}
}

// setFloat records value under key in the float map, overwriting whatever was stored (or accumulated) there.
// It is a no-op on a nil receiver, and it creates the map on first use.
func (t *traceBody) setFloat(key string, value float64) {
	if t == nil {
		return
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	if t.floats != nil {
		t.floats[key] = value
		return
	}
	// First float written to this trace: build the map around it.
	t.floats = map[string]float64{key: value}
}

// copyTraces returns independent copies of the string, integer and float maps, taken while holding the lock, so the
// caller can read them without further synchronization. A nil receiver yields three nil maps; otherwise every
// returned map is non-nil, even when nothing of that type has been recorded.
func (t *traceBody) copyTraces() (map[string]string, map[string]int64, map[string]float64) {
	if t == nil {
		return nil, nil, nil
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	var (
		strs   = make(map[string]string, len(t.strings))
		ints   = make(map[string]int64, len(t.ints))
		floats = make(map[string]float64, len(t.floats))
	)
	for key, val := range t.strings {
		strs[key] = val
	}
	for key, val := range t.ints {
		ints[key] = val
	}
	for key, val := range t.floats {
		floats[key] = val
	}
	return strs, ints, floats
}

// logTypes is the unexported type of the keys this file stores in a context. Using a dedicated type keeps the keys
// from colliding with context keys defined anywhere else.
type logTypes int

// traceKey is the context key under which WithTrace stores the *traceBody.
const traceKey logTypes = iota

// WithTrace initializes a context with a traceBody so we can inject key/value pairs into the context later
func WithTrace(ctx context.Context) context.Context {
	return context.WithValue(ctx, traceKey, &traceBody{})
}

// traceValues returns copies of the string, integer and float pairs recorded on ctx so they can be added to the
// access log. The final result reports whether ctx carries a trace at all; when it is false the maps are nil.
func traceValues(ctx context.Context) (map[string]string, map[string]int64, map[string]float64, bool) {
	if t := getTrace(ctx); t != nil {
		strs, ints, floats := t.copyTraces()
		return strs, ints, floats, true
	}
	return nil, nil, nil, false
}

// getTrace returns the traceBody that WithTrace attached to ctx, so key/value pairs can be added to it.
//
// It returns nil when ctx is nil, when ctx carries no trace, or when the value stored under traceKey is not a
// *traceBody. The traceBody methods are no-ops on a nil receiver, so callers may chain on the result without
// checking it.
func getTrace(ctx context.Context) *traceBody {
	// Calling Value on a nil Context would panic. SetRequestBody and SetResponseBody already tolerate a nil ctx;
	// guarding here gives every other helper that goes through getTrace the same behavior.
	if ctx == nil {
		return nil
	}
	// The comma-ok assertion leaves ret as a nil *traceBody when the key is absent or holds another type.
	ret, _ := ctx.Value(traceKey).(*traceBody)
	return ret
}

// SetTraceString records a string key/value pair on the trace carried by ctx, replacing any earlier value for key.
// It does nothing when ctx carries no trace.
func SetTraceString(ctx context.Context, key string, value string) {
	trace := getTrace(ctx)
	trace.addString(key, value)
}

// IncTraceInt adds toAdd to the integer stored under key on the trace carried by ctx. A key that has not been set
// yet counts from zero. It does nothing when ctx carries no trace.
func IncTraceInt(ctx context.Context, key string, toAdd int64) {
	trace := getTrace(ctx)
	trace.incInt(key, toAdd)
}

// SetTraceInt records an integer key/value pair on the trace carried by ctx, overwriting any earlier value for key.
// It does nothing when ctx carries no trace.
func SetTraceInt(ctx context.Context, key string, value int64) {
	trace := getTrace(ctx)
	trace.setInt(key, value)
}

// IncTraceFloat adds value to the floating-point number stored under key on the trace carried by ctx. A key that has
// not been set yet counts from zero. It does nothing when ctx carries no trace.
func IncTraceFloat(ctx context.Context, key string, value float64) {
	trace := getTrace(ctx)
	trace.incFloat(key, value)
}

// SetTraceFloat records a floating-point key/value pair on the trace carried by ctx, overwriting any earlier value
// for key. It does nothing when ctx carries no trace.
func SetTraceFloat(ctx context.Context, key string, value float64) {
	trace := getTrace(ctx)
	trace.setFloat(key, value)
}

// SetRequestBody records the string form of req, which should be the body of the twirp request, on the trace carried
// by ctx so it can be pulled out later for the access logs.
//
// The body length in bytes is always stored under "req_body_size". The body text itself is stored under "req_body"
// only when it is non-empty and shorter than 1024 bytes. A nil ctx or nil req is ignored.
func SetRequestBody(ctx context.Context, req fmt.Stringer) {
	if ctx == nil || req == nil {
		return
	}

	body := req.String()
	size := len(body)
	SetTraceInt(ctx, "req_body_size", int64(size))

	// Empty bodies carry no information and large ones would bloat the log line; record only the size for those.
	if size == 0 || size >= 1024 {
		return
	}
	SetTraceString(ctx, "req_body", body)
}

// SetResponseBody records the string form of response, which should be the body of the twirp response, on the trace
// carried by ctx so it can be pulled out later for the access logs.
//
// The body length in bytes is always stored under "resp_body_size". The body text itself is stored under "resp_body"
// only when it is non-empty and shorter than 1024 bytes. A nil ctx or nil response is ignored.
func SetResponseBody(ctx context.Context, response fmt.Stringer) {
	if ctx == nil || response == nil {
		return
	}

	body := response.String()
	size := len(body)
	SetTraceInt(ctx, "resp_body_size", int64(size))

	// Empty bodies carry no information and large ones would bloat the log line; record only the size for those.
	if size == 0 || size >= 1024 {
		return
	}
	SetTraceString(ctx, "resp_body", body)
}
