package admapi

import (
	"context"
	"errors"
	"net"
	"net/http"
	"strings"

	"a.yandex-team.ru/library/go/yandex/blackbox"
	"a.yandex-team.ru/library/go/yandex/tvm"
	"a.yandex-team.ru/security/skotty/service/internal/oauth"
)

// contextKey is the unexported type of this package's request-context keys,
// so values stored here cannot collide with keys set by other packages.
type contextKey string

const (
	// contextUserAuthKey holds the *UserAuth of a request authenticated with
	// an OAuth token (set by authByAuthHeader).
	contextUserAuthKey = contextKey("user")
	// contextTVMAuthKey holds the *TVMAuth of a request authenticated with a
	// TVM service ticket (set by authByTVMTicket).
	contextTVMAuthKey = contextKey("tvm")
)

type (
	// UserAuth identifies a user authenticated with an OAuth token checked by
	// blackbox. authByAuthHeader stores a *UserAuth in the request context.
	UserAuth struct {
		UID        blackbox.ID // blackbox user ID
		Login      string      // user login as reported by blackbox
		UserTicket string      // TVM user ticket returned by blackbox for this user
	}

	// TVMAuth identifies a service authenticated with a TVM service ticket.
	// authByTVMTicket stores a *TVMAuth in the request context.
	TVMAuth struct {
		Ticket *tvm.CheckedServiceTicket // service ticket validated by the TVM client
	}
)

// authByAuthHeader authenticates a request by the value of its Authorization
// header, which must have the form "OAuth <token>".
//
// The token is verified with the blackbox OAuth method. The request asks for a
// user ticket and the scopes in oauth.RequiredScopes, and ip is passed as the
// user IP. On success the returned context is a child of ctx that carries a
// *UserAuth under contextUserAuthKey. On any failure ctx is returned unchanged
// together with the error.
//
// The scheme name is matched case-insensitively, as HTTP authentication scheme
// names are case-insensitive (RFC 7235, section 2.1). A header without a token
// is rejected locally, without a blackbox round trip.
func authByAuthHeader(ctx context.Context, bbClient blackbox.Client, token, ip string) (context.Context, error) {
	const scheme = "OAuth "

	var oauthToken string
	if len(token) > len(scheme) && strings.EqualFold(token[:len(scheme)], scheme) {
		oauthToken = strings.TrimSpace(token[len(scheme):])
	}
	if oauthToken == "" {
		return ctx, errors.New("invalid OAuth token format: expecting 'OAuth AQAD...'")
	}

	rsp, err := bbClient.OAuth(ctx, blackbox.OAuthRequest{
		OAuthToken:    oauthToken,
		UserIP:        ip,
		GetUserTicket: true,
		Scopes:        oauth.RequiredScopes,
	})
	if err != nil {
		return ctx, err
	}

	return context.WithValue(
		ctx,
		contextUserAuthKey,
		&UserAuth{
			UID:        rsp.User.UID.ID,
			Login:      rsp.User.Login,
			UserTicket: rsp.UserTicket,
		},
	), nil
}

// authByTVMTicket validates a TVM service ticket with tvmClient. In this file
// the ticket comes from the X-Ya-Service-Ticket header.
//
// On success it returns a child of ctx that carries a *TVMAuth under
// contextTVMAuthKey. On failure it returns ctx unchanged together with the
// validation error.
func authByTVMTicket(ctx context.Context, tvmClient tvm.Client, ticket string) (context.Context, error) {
	checked, err := tvmClient.CheckServiceTicket(ctx, ticket)
	if err != nil {
		return ctx, err
	}

	auth := &TVMAuth{Ticket: checked}
	return context.WithValue(ctx, contextTVMAuthKey, auth), nil
}

// authMiddleware returns middleware that authenticates every request before
// handing it to next. Credentials are looked up in this order:
//
//   - "Authorization" header: an OAuth token verified via blackbox
//     (authByAuthHeader), with the client's IP as the user IP;
//   - "X-Ya-Service-Ticket" header: a TVM service ticket verified via
//     tvmClient (authByTVMTicket).
//
// When the Authorization header is present it is the only credential tried,
// even if a service ticket is also sent. A request with neither header, or
// with a credential that fails verification, gets 401 Unauthorized. Otherwise
// next receives the request with the auth info stored in its context.
func authMiddleware(bbClient blackbox.Client, tvmClient tvm.Client) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			auth := func(ctx context.Context) (context.Context, error) {
				if v := r.Header.Get("Authorization"); v != "" {
					// The net/http server sets RemoteAddr to "IP:port", but UserIP
					// should be the bare client IP. Keep the raw value when it has
					// no port (e.g. it was already rewritten to a plain IP by a
					// real-IP middleware).
					userIP := r.RemoteAddr
					if host, _, err := net.SplitHostPort(userIP); err == nil {
						userIP = host
					}
					return authByAuthHeader(ctx, bbClient, v, userIP)
				}

				if v := r.Header.Get("X-Ya-Service-Ticket"); v != "" {
					return authByTVMTicket(ctx, tvmClient, v)
				}

				return ctx, errors.New("no 'Authorization' or 'X-Ya-Service-Ticket' header found")
			}

			ctx, err := auth(r.Context())
			if err != nil {
				RespErrorf(w, http.StatusUnauthorized, "auth fail: %v", err)
				return
			}

			next.ServeHTTP(w, r.WithContext(ctx))
		}
		return http.HandlerFunc(fn)
	}
}

// checkACLMiddleware returns middleware that lets a request reach next only if
// the caller stored in the request context holds the TVM role `role`.
//
// A *UserAuth (OAuth user) is checked first, by validating its user ticket and
// checking the user's role. Otherwise a *TVMAuth (service) is checked by its
// service ticket. A request carrying neither is denied, so this middleware is
// expected to run after authMiddleware.
//
// Every denial gets exactly one 403 Forbidden response. The message is one of
// "can't get TVM roles", "blackbox returned invalid user ticket",
// "can't check user role" or "access denied".
func checkACLMiddleware(tvmClient tvm.Client, role string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			roles, err := tvmClient.GetRoles(ctx)
			if err != nil {
				RespErrorf(w, http.StatusForbidden, "can't get TVM roles")
				return
			}

			// checkAuth reports whether the caller holds role. It never writes
			// to w. When access is not granted it returns the reason instead,
			// so the response is written in exactly one place below.
			checkAuth := func() (granted bool, denyReason string) {
				if u, ok := ctx.Value(contextUserAuthKey).(*UserAuth); ok && u != nil {
					ticket, err := tvmClient.CheckUserTicket(ctx, u.UserTicket)
					if err != nil {
						return false, "blackbox returned invalid user ticket"
					}

					hasRole, err := roles.CheckUserRole(ticket, role, nil)
					if err != nil {
						return false, "can't check user role"
					}
					return hasRole, "access denied"
				}

				if t, ok := ctx.Value(contextTVMAuthKey).(*TVMAuth); ok && t != nil {
					return roles.CheckServiceRole(t.Ticket, role, nil), "access denied"
				}

				return false, "access denied"
			}

			if granted, reason := checkAuth(); !granted {
				RespErrorf(w, http.StatusForbidden, "%s", reason)
				return
			}

			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(fn)
	}
}
