package handler

import (
	"context"
	"net/http"
	"time"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/ctxlog"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/gofrs/uuid"
)

// userIP is the unexported, empty type of the context key used for the
// client address.
type userIP struct{}

// userIPKey is the variable whose address (&userIPKey) is used by
// ctxWithUserIP and getUserIP as the context key.
var userIPKey userIP

// GetRequestLogger returns an HTTP middleware that emits one access-log record
// per request through logger and enriches the request context for the wrapped
// handler.
//
// Before calling next, the middleware:
//   - resolves the client address with getUserIPFromRequest,
//   - generates a request ID with uuid.NewV4 (a failure is logged as a warning
//     together with the error; the request is still served),
//   - stores the client address in the context (see ctxWithUserIP) and attaches
//     the "request_id" and "user_ip" log fields to it.
//
// After next returns (or panics — the record is written from a defer), it logs
// method, status, client address, duration in milliseconds, response size,
// host and request URI.
func GetRequestLogger(logger log.Logger) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			start := time.Now()
			// Wrap the writer so the status code and body size can be read back.
			ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
			clientIP := getUserIPFromRequest(r)

			defer func() {
				elapsedMs := time.Since(start).Milliseconds()
				status := ww.Status()
				size := ww.BytesWritten()

				// r is reassigned below before next runs, so r.Context() here is
				// the enriched context created by this middleware.
				// NOTE(review): that context already carries a "user_ip" field, so
				// the one added here may appear twice in the record — confirm
				// against ctxlog.WithFields merge semantics before removing it.
				ctx := ctxlog.WithFields(
					r.Context(),
					log.String("referer", r.Referer()),
					log.Int64("time", elapsedMs),
					log.Int("status", status),
					log.String("user_ip", clientIP),
					log.Int("size", size),
					log.String("host", r.Host),
					log.String("uri", r.RequestURI),
					log.String("method", r.Method),
				)
				ctxlog.Infof(ctx, logger, "%v>%v [%v] %vms %vB %v%v", r.Method, status, clientIP, elapsedMs, size, r.Host, r.RequestURI)
			}()

			requestID, err := uuid.NewV4()
			if err != nil {
				// Best effort: keep serving the request, but record why the ID
				// could not be generated.
				logger.Warn("unable to generate request_id", log.Error(err))
			}
			requestUserField := log.String("user_ip", clientIP)
			requestIDField := log.String("request_id", requestID.String())
			ctxWithIP := ctxWithUserIP(r.Context(), clientIP)
			r = r.WithContext(ctxlog.WithFields(ctxWithIP, requestIDField, requestUserField))

			next.ServeHTTP(ww, r)
		})
	}
}

// ctxWithUserIP returns a child of ctx that carries ip under the
// package-private key &userIPKey; getUserIP reads it back.
func ctxWithUserIP(ctx context.Context, ip string) context.Context {
	key := &userIPKey
	return context.WithValue(ctx, key, ip)
}

// getUserIP returns the client address stored in ctx under &userIPKey.
// It returns "" when ctx is nil, when no value is stored, or when the stored
// value is not a string.
func getUserIP(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	// A failed assertion leaves stored as the zero value, i.e. "".
	stored, _ := ctx.Value(&userIPKey).(string)
	return stored
}

// getUserIPFromRequest picks the client address for r. The first non-empty
// value among the X-Client-Real-IP and X-Real-IP headers (checked in that
// order) wins; otherwise r.RemoteAddr is returned unchanged, which for
// requests served by net/http is in "IP:port" form.
func getUserIPFromRequest(r *http.Request) string {
	for _, name := range []string{"X-Client-Real-IP", "X-Real-IP"} {
		if value := r.Header.Get(name); value != "" {
			return value
		}
	}
	return r.RemoteAddr
}

// GetRequestTimeout returns an HTTP middleware that runs the wrapped handler
// with a request context that is cancelled after timeout, or as soon as the
// handler returns, whichever comes first. The logger argument is accepted but
// not used by this middleware.
func GetRequestTimeout(logger log.Logger, timeout time.Duration) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			deadlineCtx, release := context.WithTimeout(r.Context(), timeout)
			// Release the timer once the handler is done.
			defer release()

			bounded := r.WithContext(deadlineCtx)
			next.ServeHTTP(w, bounded)
		}
		return http.HandlerFunc(serve)
	}
}
