package auth

import (
	"context"
	"errors"
	"fmt"
	"net/http"

	"github.com/labstack/echo/v4"

	"a.yandex-team.ru/library/go/yandex/blackbox"
	"a.yandex-team.ru/security/xray/internal/servers/humanizer/infra"
	"a.yandex-team.ru/security/xray/pkg/xray"
)

const (
	// UserTicketKey is the echo.Context key under which the auth middleware
	// stores the user ticket and from which WithXrayCredentials reads it.
	UserTicketKey = "X-Request-User-Ticket"
)

var (
	// errNoAuthProvided is returned when the request has no Session_id cookie,
	// or when its Authorization header does not carry a non-empty token with
	// the "OAuth " prefix.
	errNoAuthProvided = errors.New("no Authorization header or Session_id cookie provided")
)

// NewAuthMiddleware builds an echo middleware that resolves a user ticket for
// every request via getUserTicket.
//
// On success the ticket is stored on the echo context under UserTicketKey and
// the next handler is invoked. On failure the chain is not continued and an
// *echo.HTTPError with status 401 is returned; its message is the text of the
// underlying error, which is also attached as the internal error.
func NewAuthMiddleware(i *infra.Infra) echo.MiddlewareFunc {
	// unauthorized wraps an authentication failure into a 401 HTTP error.
	unauthorized := func(cause error) *echo.HTTPError {
		return &echo.HTTPError{
			Code:     http.StatusUnauthorized,
			Message:  cause.Error(),
			Internal: cause,
		}
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			userTicket, authErr := getUserTicket(i, c)
			if authErr != nil {
				return unauthorized(authErr)
			}

			// Expose the ticket to downstream handlers.
			c.Set(UserTicketKey, userTicket)
			return next(c)
		}
	}
}

// WithXrayCredentials returns the request context extended (through
// xray.WithCredentials) with user-ticket credentials composed of the service's
// i.XRayCredentials and the user ticket stored on the echo context.
//
// The ticket is read from UserTicketKey with an unchecked type assertion, so
// this function panics if no string value was stored under that key — i.e. it
// must only be called from handlers running behind the auth middleware.
func WithXrayCredentials(i *infra.Infra, e echo.Context) context.Context {
	reqCtx := e.Request().Context()

	creds := &xray.UserTicketCredentials{
		TVMCredentials: i.XRayCredentials,
		UserTicket:     e.Get(UserTicketKey).(string),
	}

	return xray.WithCredentials(reqCtx, creds)
}

// getUserTicket obtains a user ticket for the request.
//
// When the Authorization header is non-empty the OAuth path is used
// exclusively (there is no fallback to cookies if it fails); otherwise the
// Session_id cookie path is used. Any failure is wrapped with an
// "auth failed" prefix and an empty ticket is returned.
func getUserTicket(i *infra.Infra, e echo.Context) (string, error) {
	// Cookies are the default source; an Authorization header overrides it.
	extract := fromCookies
	if e.Request().Header.Get("Authorization") != "" {
		extract = fromOAuth
	}

	userTicket, err := extract(i, e)
	if err != nil {
		return "", fmt.Errorf("auth failed: %w", err)
	}

	return userTicket, nil
}

// fromCookies exchanges the request's Session_id cookie for a user ticket by
// calling BlackBox.SessionID with the client IP (see getIP) and the request
// host.
//
// It returns errNoAuthProvided when the cookie cannot be read, and a wrapped
// error when the blackbox call fails.
func fromCookies(i *infra.Infra, e echo.Context) (string, error) {
	cookie, err := e.Cookie("Session_id")
	if err != nil {
		// The lookup error is intentionally replaced by the sentinel.
		return "", errNoAuthProvided
	}

	req := e.Request()
	bbReq := blackbox.SessionIDRequest{
		SessionID:     cookie.Value,
		UserIP:        getIP(e),
		Host:          req.Host,
		GetUserTicket: true,
	}

	rsp, err := i.BlackBox.SessionID(req.Context(), bbReq)
	if err != nil {
		return "", fmt.Errorf("failed to check session cookies: %w", err)
	}

	return rsp.UserTicket, nil
}

// fromOAuth exchanges the token from the Authorization header for a user
// ticket by calling BlackBox.OAuth with the client IP (see getIP).
//
// The header must have the exact, case-sensitive form "OAuth <token>" with a
// non-empty token; otherwise errNoAuthProvided is returned. A failed blackbox
// call is returned as a wrapped error.
func fromOAuth(i *infra.Infra, e echo.Context) (string, error) {
	const scheme = "OAuth "

	header := e.Request().Header.Get("Authorization")
	// Require the scheme prefix followed by at least one token character.
	if len(header) <= len(scheme) || header[:len(scheme)] != scheme {
		return "", errNoAuthProvided
	}

	bbReq := blackbox.OAuthRequest{
		OAuthToken:    header[len(scheme):],
		UserIP:        getIP(e),
		GetUserTicket: true,
	}

	rsp, err := i.BlackBox.OAuth(e.Request().Context(), bbReq)
	if err != nil {
		return "", fmt.Errorf("failed to check OAuth token: %w", err)
	}

	return rsp.UserTicket, nil
}

// getIP returns the client IP to report to blackbox: the value of the
// X-Forwarded-For-Y request header when it is non-empty, otherwise whatever
// echo's RealIP resolves for the request.
// NOTE(review): the header value is used verbatim — presumably it is set by a
// trusted L7 balancer; confirm it cannot be supplied by clients directly.
func getIP(e echo.Context) string {
	if forwarded := e.Request().Header.Get("X-Forwarded-For-Y"); forwarded != "" {
		return forwarded
	}
	return e.RealIP()
}
