package user

import (
	"context"
	"net"
	"net/http"
	"strings"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/yandex/blackbox"
	"a.yandex-team.ru/library/go/yandex/tvm"
)

// Context keys under which SetUserAndServiceID stores the caller identity.
// The key types are unexported, so code outside this package cannot build
// a colliding key; the *pointers* to these variables are what is used as keys.
type (
	userID    struct{}
	serviceID struct{}
)

var (
	userIDKey    userID
	serviceIDKey serviceID
)

// SetUserAndServiceID returns HTTP middleware that authenticates each request
// and stores the resolved identity in the request context, where downstream
// handlers can read it with GetUserID and GetServiceID.
//
// Credentials are checked in this order. The first one present decides the
// outcome, and a credential that fails validation does not fall through to
// the next one:
//   - the Session_id cookie, validated through blackbox;
//   - an "Authorization: OAuth <token>" header, validated through blackbox;
//   - an X-Ya-Service-Ticket header, validated through TVM, optionally
//     accompanied by an X-Ya-User-Ticket header.
//
// The request reaches next only when at least one non-zero identity was
// resolved. In every other case the middleware replies 401 Unauthorized.
func SetUserAndServiceID(logger log.Structured, bbClient blackbox.Client, tvmClient tvm.Client) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			uid, sid, ok := resolveCaller(r, logger, bbClient, tvmClient)
			if !ok {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			if uid == 0 && sid == 0 {
				// A credential validated but carried no identity. Reject it
				// explicitly instead of finishing with no response written.
				logger.Warn("credentials resolved to an empty identity")
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}

			// next is invoked in straight-line code (not from a defer), so a
			// panic during authentication can never reach the protected handler.
			ctx := r.Context()
			if sid != 0 {
				ctx = ctxWithServiceID(ctx, sid)
			}
			if uid != 0 {
				ctx = ctxWithUserID(ctx, uid)
			}
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// resolveCaller extracts the caller identity from r following the precedence
// documented on SetUserAndServiceID. It returns ok=false after logging the
// reason when a present credential fails validation. It returns ok=false
// without logging when no supported credential is present at all.
func resolveCaller(r *http.Request, logger log.Structured, bbClient blackbox.Client, tvmClient tvm.Client) (tvm.UID, tvm.ClientID, bool) {
	ctx := r.Context()

	// r.Cookie only fails with http.ErrNoCookie, which a nil cookie already signals.
	if cookie, _ := r.Cookie("Session_id"); cookie != nil {
		uid, err := getUserIDBySessionID(ctx, bbClient, cookie.Value, r)
		if err != nil {
			logger.Warn("unable to get uid by session_id", log.Error(err))
			return 0, 0, false
		}
		return uid, 0, true
	}

	const oauthTokenPrefix = "OAuth "
	if authorization := r.Header.Get("Authorization"); strings.HasPrefix(authorization, oauthTokenPrefix) {
		oauthToken := strings.TrimPrefix(authorization, oauthTokenPrefix)
		uid, err := getUserIDByOAuthToken(ctx, bbClient, oauthToken, r)
		if err != nil {
			logger.Warn("unable to get uid by oauth token", log.Error(err))
			return 0, 0, false
		}
		return uid, 0, true
	}

	if serviceTicket := r.Header.Get("X-Ya-Service-Ticket"); serviceTicket != "" {
		sid, err := getServiceIDByServiceTicket(ctx, tvmClient, serviceTicket)
		if err != nil {
			logger.Warn("unable to get service id by service ticket", log.Error(err))
			return 0, 0, false
		}
		// The user ticket is optional; without it the caller is the service alone.
		var uid tvm.UID
		if userTicket := r.Header.Get("X-Ya-User-Ticket"); userTicket != "" {
			if uid, err = getUserIDByUserTicket(ctx, tvmClient, userTicket); err != nil {
				logger.Warn("unable to get user id by user ticket", log.Error(err))
				return 0, 0, false
			}
		}
		return uid, sid, true
	}

	return 0, 0, false
}

// ctxWithUserID derives a child of ctx that carries userID under the
// package-private user ID key. GetUserID reads it back.
func ctxWithUserID(ctx context.Context, userID tvm.UID) context.Context {
	key := &userIDKey
	return context.WithValue(ctx, key, userID)
}

// GetUserID returns the user ID that SetUserAndServiceID stored in ctx. It
// returns 0 when ctx carries none, for example when the request was
// authenticated by a service ticket alone. It panics if a value of any type
// other than tvm.UID was stored under the user ID key.
func GetUserID(ctx context.Context) tvm.UID {
	stored := ctx.Value(&userIDKey)
	if stored == nil {
		return 0
	}
	return stored.(tvm.UID)
}

// ctxWithServiceID derives a child of ctx that carries serviceID under the
// package-private service ID key. GetServiceID reads it back.
func ctxWithServiceID(ctx context.Context, serviceID tvm.ClientID) context.Context {
	key := &serviceIDKey
	return context.WithValue(ctx, key, serviceID)
}

// GetServiceID returns the TVM client ID of the calling service that
// SetUserAndServiceID stored in ctx. It returns 0 when ctx carries none, for
// example when the request was authenticated by a cookie or OAuth token. It
// panics if a value of any type other than tvm.ClientID was stored under the
// service ID key.
func GetServiceID(ctx context.Context) tvm.ClientID {
	stored := ctx.Value(&serviceIDKey)
	if stored == nil {
		return 0
	}
	return stored.(tvm.ClientID)
}

// getUserIP returns the end user's IP address to send along with blackbox
// requests.
//
// It prefers the X-Real-IP header. NOTE(review): this presumes a fronting
// proxy that always overwrites that header; confirm against the deployment.
// When the header is absent, it falls back to the host part of the
// connection's peer address, so blackbox is not sent an empty user IP.
func getUserIP(r *http.Request) string {
	if ip := r.Header.Get("X-Real-IP"); ip != "" {
		return ip
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr has no port (or is empty): pass it through unchanged.
		return r.RemoteAddr
	}
	return host
}

// getUserIDBySessionID validates a Session_id cookie value through blackbox.
// The request's user IP and Host are passed along with it. It returns the UID
// of the session owner, or the blackbox error unchanged.
func getUserIDBySessionID(ctx context.Context, bbClient blackbox.Client, sessionID string, r *http.Request) (tvm.UID, error) {
	resp, err := bbClient.SessionID(ctx, blackbox.SessionIDRequest{
		SessionID: sessionID,
		UserIP:    getUserIP(r),
		Host:      r.Host,
	})
	if err != nil {
		return 0, err
	}
	uid := tvm.UID(resp.User.ID)
	return uid, nil
}

// getUserIDByOAuthToken validates an OAuth token through blackbox, passing the
// request's user IP along with it. It returns the UID of the token owner, or
// the blackbox error unchanged.
func getUserIDByOAuthToken(ctx context.Context, bbClient blackbox.Client, oauthToken string, r *http.Request) (tvm.UID, error) {
	resp, err := bbClient.OAuth(ctx, blackbox.OAuthRequest{
		OAuthToken: oauthToken,
		UserIP:     getUserIP(r),
	})
	if err != nil {
		return 0, err
	}
	uid := tvm.UID(resp.User.ID)
	return uid, nil
}

// getServiceIDByServiceTicket checks a TVM service ticket and returns the
// client ID of the service that issued it (the ticket's SrcID), or the TVM
// error unchanged.
func getServiceIDByServiceTicket(ctx context.Context, tvmClient tvm.Client, serviceTicket string) (tvm.ClientID, error) {
	checked, checkErr := tvmClient.CheckServiceTicket(ctx, serviceTicket)
	if checkErr != nil {
		return 0, checkErr
	}
	src := checked.SrcID
	return src, nil
}

// getUserIDByUserTicket checks a TVM user ticket and returns its DefaultUID,
// or the TVM error unchanged. NOTE(review): DefaultUID may presumably be 0 for
// tickets without a default user; confirm against the tvm package docs.
func getUserIDByUserTicket(ctx context.Context, tvmClient tvm.Client, userTicket string) (tvm.UID, error) {
	checked, checkErr := tvmClient.CheckUserTicket(ctx, userTicket)
	if checkErr != nil {
		return 0, checkErr
	}
	uid := checked.DefaultUID
	return uid, nil
}
