package yasmsinternal

import (
	"errors"
	"fmt"

	"github.com/labstack/echo/v4"

	"a.yandex-team.ru/library/go/yandex/tvm"
	"a.yandex-team.ru/passport/infra/daemons/yasms_internal/internal/errs"
	"a.yandex-team.ru/passport/infra/daemons/yasms_internal/internal/roles"
)

// tvmServiceAuthMiddleware returns middleware that requires a TVM service
// ticket in the X-Ya-Service-Ticket request header.
//
// The ticket is validated with t.tvm.CheckServiceTicket. On success the checked
// ticket is stored in the request context via tvm.WithServiceTicket, so later
// middlewares and handlers can read it with tvm.ContextServiceTicket.
//
// Errors returned to echo:
//   - header missing or empty: *errs.UnauthorizedError.
//   - checker returned a *tvm.TicketError: *errs.UnauthorizedError with the
//     ticket status. It also carries the checked ticket's LogInfo when the
//     checker returned a non-nil ticket alongside the error.
//   - any other checker error: *errs.UnknownError.
func (t *YasmsInternal) tvmServiceAuthMiddleware() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(ctx echo.Context) error {
			ticket := ctx.Request().Header.Get("X-Ya-Service-Ticket")
			if ticket == "" {
				return &errs.UnauthorizedError{
					Status:  errs.Error,
					Message: "Service ticket is missing",

					Component: errs.CommonComponent,
				}
			}

			checked, err := t.tvm.CheckServiceTicket(ctx.Request().Context(), ticket)
			if err != nil {
				// Nothing guarantees that every error from CheckServiceTicket is a
				// *tvm.TicketError (it may also be wrapped). An unchecked type
				// assertion would panic on any other error, so use errors.As.
				var tvmErr *tvm.TicketError
				if !errors.As(err, &tvmErr) {
					// The ticket was not shown to be invalid; the check itself
					// failed. Report that as an internal error, not as 401.
					return &errs.UnknownError{
						Status:  errs.Error,
						Message: fmt.Sprintf("Failed to check service ticket: %s", err),

						Component: errs.CommonComponent,
					}
				}

				e := &errs.UnauthorizedError{
					Status:       errs.Error,
					TicketStatus: tvmErr.Status.String(),
					Message:      tvmErr.Error(),

					Component: errs.CommonComponent,
				}

				if checked != nil {
					e.LoggablePart = checked.LogInfo
				}

				return e
			}

			ctx.SetRequest(ctx.Request().WithContext(
				tvm.WithServiceTicket(ctx.Request().Context(), checked),
			))
			return next(ctx)
		}
	}
}

// tvmUserAuthMiddleware returns middleware that requires a TVM user ticket in
// the X-Ya-User-Ticket request header.
//
// The ticket is validated with t.tvm.CheckUserTicket. On success the checked
// ticket is stored in the request context via tvm.WithUserTicket, so later
// middlewares and handlers can read it with tvm.ContextUserTicket.
//
// Errors returned to echo:
//   - header missing or empty: *errs.UnauthorizedError.
//   - checker returned a *tvm.TicketError: *errs.UnauthorizedError with the
//     ticket status. It also carries the checked ticket's LogInfo when the
//     checker returned a non-nil ticket alongside the error.
//   - any other checker error: *errs.UnknownError.
func (t *YasmsInternal) tvmUserAuthMiddleware() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(ctx echo.Context) error {
			ticket := ctx.Request().Header.Get("X-Ya-User-Ticket")
			if ticket == "" {
				return &errs.UnauthorizedError{
					Status:  errs.Error,
					Message: "User ticket is missing",

					Component: errs.CommonComponent,
				}
			}

			checked, err := t.tvm.CheckUserTicket(ctx.Request().Context(), ticket)
			if err != nil {
				// Nothing guarantees that every error from CheckUserTicket is a
				// *tvm.TicketError (it may also be wrapped). An unchecked type
				// assertion would panic on any other error, so use errors.As.
				var tvmErr *tvm.TicketError
				if !errors.As(err, &tvmErr) {
					// The ticket was not shown to be invalid; the check itself
					// failed. Report that as an internal error, not as 401.
					return &errs.UnknownError{
						Status:  errs.Error,
						Message: fmt.Sprintf("Failed to check user ticket: %s", err),

						Component: errs.CommonComponent,
					}
				}

				e := &errs.UnauthorizedError{
					Status:       errs.Error,
					TicketStatus: tvmErr.Status.String(),
					Message:      tvmErr.Error(),

					Component: errs.CommonComponent,
				}

				if checked != nil {
					e.LoggablePart = checked.LogInfo
				}

				return e
			}

			ctx.SetRequest(ctx.Request().WithContext(
				tvm.WithUserTicket(ctx.Request().Context(), checked),
			))
			return next(ctx)
		}
	}
}

// checkGrantsMiddleware returns middleware that enforces IDM roles for the
// given access type and handler.
//
// It loads the current roles with t.tvm.GetRoles, then checks the roles of the
// calling service with roles.CheckServiceRoles. The service is identified by
// the service ticket in the request context. The auth type passed to that
// check is roles.AuthWithUserTicket when a user ticket is in the context, and
// roles.AuthWithoutUserTicket otherwise.
//
// When checkUserRoles is true and a user ticket is present, the user's roles
// are also checked with roles.CheckUserRoles. Without a user ticket, only the
// service check runs.
//
// The tickets are read from the request context, where tvmServiceAuthMiddleware
// and tvmUserAuthMiddleware store them. This middleware should therefore be
// mounted after those.
// NOTE(review): if no service ticket is in the context, st is nil here. The
// outcome then depends on how r.GetRolesForService handles nil; confirm.
//
// Errors returned to echo:
//   - roles could not be loaded: *errs.UnknownError.
//   - service lacks required roles: *errs.AccessDeniedError.
//   - user roles could not be resolved: *errs.UnauthorizedError.
//   - user lacks required roles: *errs.AccessDeniedError.
func (t *YasmsInternal) checkGrantsMiddleware(
	accessType roles.AccessTypeSlug,
	handle roles.HandlerSlug,
	checkUserRoles bool,
) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(ctx echo.Context) error {
			reqCtx := ctx.Request().Context()

			r, err := t.tvm.GetRoles(reqCtx)
			if err != nil {
				// Keep the underlying cause (as the user-roles branch below does);
				// dropping it makes roles-loading failures impossible to diagnose.
				return &errs.UnknownError{
					Status:  errs.Error,
					Message: fmt.Sprintf("Failed to check roles: %s", err),

					Component: errs.CommonComponent,
				}
			}

			st := tvm.ContextServiceTicket(reqCtx)
			ut := tvm.ContextUserTicket(reqCtx)

			authType := roles.AuthWithoutUserTicket
			if ut != nil {
				authType = roles.AuthWithUserTicket
			}

			serviceRoles := r.GetRolesForService(st)
			if missingRoles := roles.CheckServiceRoles(serviceRoles, authType, accessType, handle); missingRoles != "" {
				return &errs.AccessDeniedError{
					Status:  errs.Error,
					Message: fmt.Sprintf("Service missing required IDM roles: %s", missingRoles),

					Component: errs.CommonComponent,
				}
			}

			// User roles are only checked when requested and a user ticket exists.
			if ut == nil || !checkUserRoles {
				return next(ctx)
			}

			userRoles, err := r.GetRolesForUser(ut, nil)
			if err != nil {
				return &errs.UnauthorizedError{
					Status:  errs.Error,
					Message: fmt.Sprintf("Failed to check user roles: %s", err),

					Component: errs.CommonComponent,
				}
			}
			if missingRoles := roles.CheckUserRoles(userRoles, accessType, handle); missingRoles != "" {
				return &errs.AccessDeniedError{
					Status:  errs.Error,
					Message: fmt.Sprintf("User missing required IDM roles: %s", missingRoles),

					Component: errs.CommonComponent,
				}
			}

			return next(ctx)
		}
	}
}

// checkServiceGrantsMiddleware is checkGrantsMiddleware restricted to the
// calling service. Only the service's IDM roles are checked; user roles are
// never checked, even when a user ticket is present in the request context.
func (t *YasmsInternal) checkServiceGrantsMiddleware(
	accessType roles.AccessTypeSlug,
	handle roles.HandlerSlug,
) echo.MiddlewareFunc {
	const checkUserRoles = false
	return t.checkGrantsMiddleware(accessType, handle, checkUserRoles)
}
