package api

import (
	"context"
	"net/http"
	"time"

	"github.com/rs/zerolog"
	"github.com/rs/zerolog/hlog"
)

// ctxKey is an unexported type for this package's context keys, so values
// stored under them cannot collide with keys defined by other packages.
type ctxKey int

// requestLogSkipCtxKey is the context key under which LoggingMiddleware stores
// the per-request *requestLogSkip flag that SkipRequestLog sets.
const requestLogSkipCtxKey ctxKey = 1

// LoggingMiddleware returns middleware that writes one "Request" access-log
// line per request and reports a request count and duration to stats, both
// keyed by status class (e.g. "Request.2xx").
//
// Handlers further down the chain can enrich the log line with
// AddRequestLogKeyStr / AddRequestLogKeyErr, or suppress both the log line and
// the stats with SkipRequestLog. Nothing is logged if next panics.
func LoggingMiddleware(log zerolog.Logger, stats Statter) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Create a copy of the logger (including internal context) so that
			// UpdateContext calls made for this request cannot race with other requests.
			l := log.With().Logger()

			ctx := r.Context()
			ctx = l.WithContext(ctx) // read back by AddRequestLogKeyStr/Err via hlog.FromRequest(r)
			skipper := &requestLogSkip{}
			ctx = context.WithValue(ctx, requestLogSkipCtxKey, skipper) // flipped by SkipRequestLog(r)
			r = r.WithContext(ctx)

			start := time.Now()
			wr := &writterRecorder{ResponseWriter: w}
			next.ServeHTTP(wr, r)

			if skipper.skip { // SkipRequestLog(r) was called in the handler
				return
			}
			reqDur := time.Since(start)

			// A handler that never calls Write or WriteHeader still gets an
			// implicit 200 OK from net/http; report that rather than 0/"unk".
			status := wr.status
			if !wr.wroteHeader {
				status = http.StatusOK
			}

			// Re-fetch the logger from the request context instead of using l:
			// zerolog's WithContext stored its own copy, and that copy is the one
			// AddRequestLogKeyStr/Err have been updating.
			reqLog := hlog.FromRequest(r)
			reqLog.Info().
				Str("route", r.Method+" "+r.URL.String()). // e.g. "GET /foobar" (query string included, if any)
				Int("status", status).
				Int("bytes", wr.bytesWritten).
				Dur("durMs", reqDur). // request duration in milliseconds (zerolog's default duration unit)
				Msg("Request")

			statusXxx := statusGroupXxx(status)
			stats.Inc("Request."+statusXxx, 1)
			// NOTE(review): "RequestDurarion" is misspelled; kept as-is because
			// renaming the metric would break existing dashboards/alerts.
			stats.Duration("RequestDurarion."+statusXxx, reqDur)
		})
	}
}

// AddRequestLogKeyStr attaches key=val as a string field to the logger stored
// in r's context, so it appears on the access-log line LoggingMiddleware writes
// for r (and on any later line logged through hlog.FromRequest(r)).
//
// zerolog documents UpdateContext as not concurrency safe, so avoid calling
// this concurrently for the same request.
func AddRequestLogKeyStr(r *http.Request, key, val string) {
	reqLog := hlog.FromRequest(r)
	withStr := func(zc zerolog.Context) zerolog.Context {
		return zc.Str(key, val)
	}
	reqLog.UpdateContext(withStr)
}

// AddRequestLogKeyErr attaches err (under zerolog's error field, "error" by
// default) to the logger stored in r's context, so it appears on the access-log
// line LoggingMiddleware writes for r.
//
// zerolog documents UpdateContext as not concurrency safe, so avoid calling
// this concurrently for the same request.
func AddRequestLogKeyErr(r *http.Request, err error) {
	reqLog := hlog.FromRequest(r)
	withErr := func(zc zerolog.Context) zerolog.Context {
		return zc.Err(err)
	}
	reqLog.UpdateContext(withErr)
}

// SkipRequestLog can be called from a route to suppress the access-log line
// and request stats that LoggingMiddleware would otherwise emit for r.
// Useful to keep logs clean of health checks and requests for static assets.
// It does nothing when r's context carries no skip flag (i.e. r did not pass
// through LoggingMiddleware).
func SkipRequestLog(r *http.Request) {
	if flag, found := r.Context().Value(requestLogSkipCtxKey).(*requestLogSkip); found {
		flag.skip = true
	}
}

// writterRecorder wraps an http.ResponseWriter and records the final response
// status and the number of body bytes written through it, for access logging.
type writterRecorder struct {
	http.ResponseWriter
	wroteHeader  bool // a final (non-1xx) status has been sent
	status       int  // final status code; 0 until wroteHeader is set
	bytesWritten int  // body bytes accepted by the underlying writer
}

// Write sends buf to the underlying writer, first committing a 200 OK status
// if no final status has been written yet (mirroring net/http's behavior).
func (wr *writterRecorder) Write(buf []byte) (int, error) {
	wr.WriteHeader(http.StatusOK) // no-op if a final status was already written
	n, err := wr.ResponseWriter.Write(buf)
	wr.bytesWritten += n
	return n, err
}

// WriteHeader forwards status to the underlying writer and records it.
//
// net/http allows any number of informational 1xx headers (other than
// 101 Switching Protocols) before the single final status. Those are forwarded
// without being recorded, so a later final status such as 404 is neither lost
// nor misreported. Only the first final status is recorded and forwarded;
// subsequent calls are dropped.
func (wr *writterRecorder) WriteHeader(status int) {
	if status >= 100 && status <= 199 && status != http.StatusSwitchingProtocols {
		if !wr.wroteHeader {
			wr.ResponseWriter.WriteHeader(status)
		}
		return
	}
	if wr.wroteHeader {
		return
	}
	wr.status = status
	wr.wroteHeader = true
	wr.ResponseWriter.WriteHeader(status)
}

// Unwrap returns the wrapped ResponseWriter so http.ResponseController can
// reach optional interfaces (Flusher, Hijacker, ...) that this wrapper does
// not implement itself.
func (wr *writterRecorder) Unwrap() http.ResponseWriter {
	return wr.ResponseWriter
}

// requestLogSkip is a per-request flag shared through the request context:
// LoggingMiddleware allocates it before calling the handler and checks it
// afterwards, and SkipRequestLog sets it to true.
type requestLogSkip struct{ skip bool }

// statusGroupXxx maps an HTTP status code to its class label ("2xx", "3xx",
// "4xx" or "5xx"). Anything outside 200-599, including 1xx codes, yields "unk".
func statusGroupXxx(status int) string {
	switch {
	case status < 200 || status > 599:
		return "unk"
	case status < 300:
		return "2xx"
	case status < 400:
		return "3xx"
	case status < 500:
		return "4xx"
	default:
		return "5xx"
	}
}
