package logging

import (
	"context"
	"fmt"
	"net/http"
	"sync"
	"time"

	"code.justin.tv/common/chitin"
	"code.justin.tv/common/twirp"

	netctx "golang.org/x/net/context"
)

// responseTracker wraps an http.ResponseWriter and remembers the first status
// code that passes through it, together with the time it was observed.
type responseTracker struct {
	http.ResponseWriter

	// status is the first status code seen; zero means nothing was recorded yet.
	status int
	// headersTime is the moment at which status was recorded.
	headersTime time.Time
}

// markStatus stores code and the current time, but only if no status has
// been recorded so far.
func (rt *responseTracker) markStatus(code int) {
	if rt.status != 0 {
		return
	}
	rt.status = code
	rt.headersTime = time.Now()
}

// Write records a 200 status if none has been recorded yet and then forwards
// buf to the wrapped ResponseWriter.
func (rt *responseTracker) Write(buf []byte) (int, error) {
	rt.markStatus(http.StatusOK)
	return rt.ResponseWriter.Write(buf)
}

// WriteHeader records status if it is the first one seen and then forwards
// the call to the wrapped ResponseWriter.
func (rt *responseTracker) WriteHeader(status int) {
	rt.markStatus(status)
	rt.ResponseWriter.WriteHeader(status)
}

// reqErrsKey is the context key under which Middleware stores the per-request
// error-collecting function that HandlerError looks up.
type reqErrsKey struct{}

// Middleware returns a middleware constructor (usable with tools like goji.io)
// that wraps a handler and emits one log entry per request through
// globalLogger. The entry carries the method, escaped path, status, total
// duration and time until the first header/body write. Errors reported via
// HandlerError during the request are attached to the entry, and a panic in
// the wrapped handler is logged and then re-raised.
func Middleware(globalLogger Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			// All per-request state is local to this function so nothing can
			// reach it except through the closures created below.
			var (
				mu       sync.Mutex
				recorded []error
			)
			entry := globalLogger

			// Capture the path up front: routers further down the chain
			// commonly rewrite it.
			requestPath := r.URL.EscapedPath()
			tracker := &responseTracker{ResponseWriter: w}
			began := time.Now()

			// collect is what HandlerError retrieves from the context. It
			// appends to recorded under mu so that stray concurrent calls
			// cannot corrupt the slice.
			collect := func(err error) {
				mu.Lock()
				defer mu.Unlock()
				recorded = append(recorded, err)
			}
			reqCtx := context.WithValue(r.Context(), reqErrsKey{}, collect)

			defer func() {
				// recover must be called directly in this deferred function.
				recovered := recover()

				finished := time.Now()
				if tracker.status == 0 {
					// Nothing was written by the handler; report 200 and use
					// the end of the request as the headers time.
					tracker.status = 200
					tracker.headersTime = finished
				}

				total := finished.Sub(began).Seconds()
				untilHeaders := tracker.headersTime.Sub(began).Seconds()
				entry = entry.WithFields(Fields{
					"method":           r.Method,
					"path":             requestPath,
					"status":           tracker.status,
					"duration":         fmt.Sprintf("%.3fs", total),
					"headers-duration": fmt.Sprintf("%.3fs", untilHeaders),
				})

				if tid := chitin.GetTraceID(r.Context()); tid != nil {
					entry = entry.WithField("trace-id", tid.String())
				}

				// Hold the mutex for the remainder of this function: reporting
				// errors after the handler has returned is invalid anyway. The
				// unlock is deferred so that the re-panic below still releases
				// it and a late, erroneous collect call cannot deadlock.
				mu.Lock()
				defer mu.Unlock()

				combined := flattenErrors(recorded)
				if combined != nil {
					entry = entry.WithError(combined)
				}

				switch {
				case recovered != nil:
					entry.WithField("panic", recovered).Error()
					// Propagate the panic to whatever sits above this middleware.
					panic(recovered)
				case combined != nil:
					entry.Error()
				default:
					entry.Info()
				}
			}()

			next.ServeHTTP(tracker, r.WithContext(reqCtx))
		}
		return http.HandlerFunc(serve)
	}
}

// flattenErrors collapses errs into a single error value. Nil entries are
// skipped. It returns nil when no non-nil error remains, the error itself when
// exactly one remains, and otherwise a new error whose message lists the
// messages of all remaining errors in order.
func flattenErrors(errs []error) error {
	// Drop nil entries first; calling Error() on a nil error would panic.
	present := make([]error, 0, len(errs))
	for _, e := range errs {
		if e != nil {
			present = append(present, e)
		}
	}

	switch len(present) {
	case 0:
		return nil
	case 1:
		// Return the original value so callers can still inspect its type.
		return present[0]
	default:
		errStrs := make([]string, 0, len(present))
		for _, e := range present {
			errStrs = append(errStrs, e.Error())
		}
		return fmt.Errorf("multiple HandlerError calls %+v", errStrs)
	}
}

// HandlerError records err so that the logging middleware installed by
// Middleware() attaches it to the log entry for the current request. The
// call has no effect when ctx does not carry the middleware's collector
// (for example when called from outside of a handler), and nil errors are
// ignored.
func HandlerError(ctx context.Context, err error) {
	// A nil error carries nothing to log; recording it would only put a nil
	// entry into the per-request error list.
	if err == nil {
		return
	}
	if f, ok := ctx.Value(reqErrsKey{}).(func(error)); ok {
		f(err)
	}
}

// TwirpHooks returns server hooks whose Error hook automatically forwards
// every non-nil error it receives to HandlerError, so that it is picked up by
// the logging middleware. The hook returns the context it was given.
func TwirpHooks() *twirp.ServerHooks {
	hooks := twirp.NewServerHooks()
	hooks.Error = func(ctx netctx.Context, twerr twirp.Error) netctx.Context {
		if twerr == nil {
			return ctx
		}
		HandlerError(ctx, twerr)
		return ctx
	}
	return hooks
}
