package api

import (
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/httputil/headers"
	"a.yandex-team.ru/library/go/yandex/blackbox"
	"a.yandex-team.ru/mail/payments-sdk-backend/internal/api/httputil"
	"a.yandex-team.ru/mail/payments-sdk-backend/internal/utils/ctxutil"
	"a.yandex-team.ru/mail/payments-sdk-backend/internal/utils/stats"
	"a.yandex-team.ru/mail/payments-sdk-backend/internal/utils/tracing"
	"fmt"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/opentracing/opentracing-go"
	"github.com/opentracing/opentracing-go/ext"
	"net/http"
	"runtime/debug"
	"strconv"
	"time"
)

// Request header names and token prefixes understood by the middlewares in
// this package.
const (
	// SdkVersionHeader carries the client SDK version string.
	SdkVersionHeader = "X-SDK-Version"

	// AuthHeaderKeyXUid carries the numeric user id of the caller.
	AuthHeaderKeyXUid = "X-Uid"

	// AuthHeaderKeyXServiceToken carries the calling service's token.
	AuthHeaderKeyXServiceToken = "X-Service-Token"

	// OAuthTokenPrefix is the literal "OAuth " authorization-scheme prefix.
	OAuthTokenPrefix = "OAuth "

	// UserIPHeader carries the end user's IP address; it is forwarded to the
	// OAuth/uid pair validation.
	UserIPHeader = "X-Real-IP"
)

// Middleware decorates an http.Handler with extra behavior and returns the
// wrapped handler. The middleware constructors in this file return this type.
type Middleware func(http.Handler) http.Handler

// jsonResponseFormat wraps next so that every response starts out with a
// "Content-Type: application/json" header. The header is set before next runs,
// so handlers further down the chain can still override it.
func jsonResponseFormat(next http.Handler) http.Handler {
	markJSON := func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		hdr.Set("Content-Type", "application/json")
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(markJSON)
}

// injectContext returns a middleware that enriches the request context before
// calling the next handler. It always stores the logger, the remote address
// (r.RemoteAddr) and the URL path. It also stores the following when they are
// present:
//   - the chi request ID found under middleware.RequestIDKey;
//   - the User-Agent header;
//   - the X-SDK-Version header.
//
// An SDK version that ctxutil cannot parse as semver is only logged as a
// warning; the request still proceeds.
func injectContext(logger log.Logger) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			reqCtx := ctxutil.WithLogger(r.Context(), logger)

			// The request ID is only available if chi's RequestID middleware
			// stored one in the context upstream.
			if id, found := reqCtx.Value(middleware.RequestIDKey).(string); found {
				reqCtx = ctxutil.WithRequestID(reqCtx, id)
			}

			if ua := r.Header.Get(headers.UserAgentKey); ua != "" {
				reqCtx = ctxutil.WithUserAgent(reqCtx, ua)
			}

			if version := r.Header.Get(SdkVersionHeader); version != "" {
				reqCtx = ctxutil.WithSdkVersion(reqCtx, version)
				// Parsing happens from the stored value, so it must follow WithSdkVersion.
				if _, parseErr := ctxutil.GetSdkSemverValue(reqCtx); parseErr != nil {
					logger.Warn(
						fmt.Sprintf("Failed to parse sdk version header - %s, error - %s", version, parseErr),
					)
				}
			}

			reqCtx = ctxutil.WithRemoteIP(reqCtx, r.RemoteAddr)
			reqCtx = ctxutil.WithPath(reqCtx, r.URL.Path)

			next.ServeHTTP(w, r.WithContext(reqCtx))
		})
	}
}

// recoveryLogging returns a middleware that recovers panics raised by the
// downstream handlers. For each recovered panic it logs the panic value and
// the stack trace together with the request's stored log fields, then replies
// with an Internal Server Error.
//
// http.ErrAbortHandler is deliberately re-panicked instead of being swallowed.
// net/http uses that sentinel to abort the response, so the client sees an
// interrupted response, and net/http does not log it. Recovering it here
// would instead finish the response normally and make a truncated response
// look complete.
func recoveryLogging(logger log.Logger) Middleware {
	return func(next http.Handler) http.Handler {
		handler := func(w http.ResponseWriter, r *http.Request) {
			defer func() {
				rvr := recover()
				if rvr == nil {
					return
				}
				if rvr == http.ErrAbortHandler {
					// Let net/http abort the connection; this is not an error to log.
					panic(rvr)
				}

				ctx := r.Context()
				logger.Error(
					fmt.Sprintf("Internal server error: %+v\n %+v", rvr, string(debug.Stack())),
					ctxutil.GetStoredFields(ctx)...,
				)

				httputil.ResponseWithInternalServerError(ctx, w, httputil.ServiceErrorStatusCodeInfra,
					"Internal Server Error")
			}()
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(handler)
	}
}

// jaegerMiddleware returns a middleware that starts an OpenTracing server span
// for each request whose path is not excluded by tracing.SkipMonitoring.
//
// The span is named after the request path. It is a child of any span context
// extracted from the incoming HTTP headers. It is tagged with the HTTP method,
// the path, the URL and, when the X-Uid header is present, the uid. The
// context handed to next carries the span. If next panics, the span is tagged
// with an error message and finished, and the panic is re-raised unchanged.
func jaegerMiddleware() Middleware {
	return func(next http.Handler) http.Handler {
		handler := func(w http.ResponseWriter, r *http.Request) {
			requestPath := r.URL.Path
			if tracing.SkipMonitoring(requestPath) {
				next.ServeHTTP(w, r)
				return
			}

			// A missing or malformed propagation header is not an error here:
			// ext.RPCServerOption ignores a nil span context.
			spanCtx, _ := opentracing.GlobalTracer().Extract(
				opentracing.HTTPHeaders,
				opentracing.HTTPHeadersCarrier(r.Header),
			)

			span, withSpan := opentracing.StartSpanFromContext(
				r.Context(),
				requestPath,
				ext.RPCServerOption(spanCtx),
				opentracing.Tag{Key: "http.method", Value: r.Method},
				opentracing.Tag{Key: "http.path", Value: requestPath},
				opentracing.Tag{Key: "http.url", Value: r.URL.String()},
			)
			if uid := r.Header.Get(AuthHeaderKeyXUid); len(uid) > 0 {
				span.SetTag("uid", uid)
			}

			defer func() {
				if rvr := recover(); rvr != nil {
					// Use withSpan, which carries the span started above. The
					// incoming request context does not hold this span.
					tracing.TagErrorMessage(withSpan, "Unexpected error")
					span.Finish()
					panic(rvr)
				}
				span.Finish()
			}()

			next.ServeHTTP(w, r.WithContext(withSpan))
		}
		return http.HandlerFunc(handler)
	}
}

// authMiddleware returns a middleware that reads the caller identity headers
// (X-Uid, X-Service-Token and Authorization). It optionally checks them and
// stores the results in the request context that is passed to the next handler.
//
// When checkAuth is true:
//   - a uid without an OAuth token (empty or "-"), or an OAuth token without a
//     uid, is rejected with 403 Forbidden;
//   - a uid together with an OAuth token is checked by
//     validateOAuthAndUIDPair (using the blackbox client and the X-Real-IP
//     header). A failed check is rejected with 403 Forbidden; on success its
//     returned context replaces the request context.
//
// A missing X-Uid is logged at info level and a missing X-Service-Token at
// warn level; neither is rejected. A uid that is not a valid unsigned 64-bit
// integer is logged and not stored in the context.
func authMiddleware(bb blackbox.Client, checkAuth bool) Middleware {
	return func(next http.Handler) http.Handler {
		handler := func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			uid := r.Header.Get(AuthHeaderKeyXUid)
			serviceToken := r.Header.Get(AuthHeaderKeyXServiceToken)
			authToken := r.Header.Get(headers.AuthorizationKey)

			reqID := ctxutil.GetRequestID(ctx)
			logger := ctxutil.GetLogger(ctx)

			if checkAuth {
				// A uid must come with a usable OAuth token ("-" counts as absent),
				// and an OAuth token must come with a uid.
				if (len(uid) > 0 && (len(authToken) == 0 || authToken == "-")) ||
					(len(uid) == 0 && len(authToken) > 0) {
					logger.Error("uid is empty or OAuth is empty", ctxutil.GetStoredFields(ctx)...)
					httputil.ResponseWithError(ctx, w, http.StatusForbidden, httputil.ServiceErrorStatusCodeIncorrectFormat,
						"authorization failed")
					return
				}

				if len(uid) > 0 && len(authToken) > 0 {
					validateCtx, err := validateOAuthAndUIDPair(ctx, bb, uid, authToken, r.Header.Get(UserIPHeader))
					if err != nil {
						logger.Error(err.Error(), append(ctxutil.GetStoredFields(ctx), log.String(AuthHeaderKeyXUid, uid))...)
						tracing.TagErrorWithMessage(ctx, "authorization failed", err)
						httputil.ResponseWithError(ctx, w, http.StatusForbidden, httputil.ServiceErrorStatusCodeIncorrectFormat,
							"authorization failed")
						return
					}
					ctx = validateCtx
				}
			}

			if len(uid) == 0 {
				logger.Info(
					fmt.Sprintf("Request %s doesn't have %s Header", reqID, AuthHeaderKeyXUid),
					ctxutil.GetStoredFields(ctx)...,
				)
			} else if parsedUID, err := strconv.ParseUint(uid, 10, 64); err != nil {
				// Report a malformed or negative uid instead of coercing it to a
				// number; the uid is left unset in the context.
				logger.Warn(
					fmt.Sprintf("Request %s has malformed %s Header: %v", reqID, AuthHeaderKeyXUid, err),
					ctxutil.GetStoredFields(ctx)...,
				)
			} else {
				ctx = ctxutil.WithUID(ctx, parsedUID)
			}

			if len(serviceToken) == 0 {
				logger.Warn(
					fmt.Sprintf("Request %s doesn't have %s Header", reqID, AuthHeaderKeyXServiceToken),
					ctxutil.GetStoredFields(ctx)...,
				)
			} else {
				ctx = ctxutil.WithServiceToken(ctx, serviceToken)
			}

			// Pass the enriched context (validation result, uid, service token)
			// down the chain; without WithContext all of it would be lost.
			next.ServeHTTP(w, r.WithContext(ctx))
		}
		return http.HandlerFunc(handler)
	}
}

// requestsStatMiddleware returns a middleware that records, for every request
// not excluded by tracing.SkipMonitoring, the response status code and the
// elapsed handling time in seconds. The elapsed time is measured from the
// measurable writer's start time.
func requestsStatMiddleware(metrics *stats.Metrics) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if tracing.SkipMonitoring(r.URL.Path) {
				next.ServeHTTP(w, r)
				return
			}

			recorder := NewMeasurableResponseWriter(w)
			// Deferred so the metrics are recorded even when next panics.
			defer func() {
				metrics.CountStatusCode(recorder.StatusCode())
				elapsed := time.Since(recorder.startTime)
				metrics.UpdateRequestTime(elapsed.Seconds())
			}()

			next.ServeHTTP(recorder, r)
		})
	}
}
