package auth

import (
	"context"
	"fmt"
	"net"

	grpcAuth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
	"github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"

	"a.yandex-team.ru/library/go/yandex/blackbox"
	"a.yandex-team.ru/library/go/yandex/tvm"
	"a.yandex-team.ru/security/xray/internal/servers/grpc/infra"
)

// Incoming gRPC metadata keys inspected by the authenticator.
const (
	// XYaTokenHeader carries an OAuth token; when present it is checked first.
	XYaTokenHeader = "X-Ya-Token"
	// XYaServiceTicketHeader carries a TVM service ticket of the calling service.
	XYaServiceTicketHeader = "X-Ya-Service-Ticket"
	// XYaUserTicketHeader carries a TVM user ticket accompanying a service ticket.
	XYaUserTicketHeader = "X-Ya-User-Ticket"
)

// acceptableTVMScopes lists the scopes of which a TVM user ticket presented
// alongside a service ticket must carry at least one.
var acceptableTVMScopes = []string{
	"bb:sessionid",
	"yp:api",
}

// Auth authenticates incoming gRPC calls using the OAuth (BlackBox) and TVM
// clients available on the shared infrastructure object.
type Auth struct {
	// i provides the BlackBox and TVM clients plus the auth-related config.
	i *infra.Infra
}

// NewGrpcAuth returns a go-grpc-middleware AuthFunc that authenticates
// incoming calls using the BlackBox/TVM clients and config held by i.
func NewGrpcAuth(i *infra.Infra) grpcAuth.AuthFunc {
	authenticator := &Auth{i: i}
	return authenticator.authCb
}

// authCb picks an authentication scheme based on which header is present in
// the incoming metadata.
//
// An OAuth token (X-Ya-Token) wins over a TVM service ticket
// (X-Ya-Service-Ticket) when both are sent. With neither header, the call is
// rejected with codes.Unauthenticated.
func (a *Auth) authCb(ctx context.Context) (context.Context, error) {
	md := metautils.ExtractIncoming(ctx)
	oauthToken := md.Get(XYaTokenHeader)
	serviceTicket := md.Get(XYaServiceTicketHeader)

	switch {
	case oauthToken != "":
		return a.processOAuthAuth(ctx, md, oauthToken)
	case serviceTicket != "":
		return a.processTVMAuth(ctx, md, serviceTicket)
	default:
		return nil, status.Error(codes.Unauthenticated, "no auth-header provided")
	}
}

// processOAuthAuth authenticates a call carrying an OAuth token.
//
// The flow is:
//   - Validate the token through BlackBox, asking it to also issue a TVM user ticket.
//   - Check that user ticket with TVM.
//   - Resolve the user's admin/reader flags from TVM roles.
//
// The metadata argument is unused. Every failure after the peer lookup is
// reported as codes.Unauthenticated.
func (a *Auth) processOAuthAuth(ctx context.Context, _ metautils.NiceMD, oauthToken string) (context.Context, error) {
	// NOTE(review): sent to BlackBox as the user IP whenever the peer address
	// is not a TCP address; the origin/meaning of this specific address is not
	// documented here — confirm with the owners.
	const fallbackUserIP = "2a02:6b8:c1d:2b96:0:47fd:1234:1234"

	p, ok := peer.FromContext(ctx)
	if !ok {
		return nil, status.Error(codes.Internal, "failed to get user-ip")
	}

	userIP := fallbackUserIP
	if tcpAddr, isTCP := p.Addr.(*net.TCPAddr); isTCP {
		userIP = tcpAddr.IP.String()
	}

	// TODO(buglloc): scopes needed
	bbRsp, err := a.i.BlackBox.OAuth(ctx, blackbox.OAuthRequest{
		OAuthToken:    oauthToken,
		UserIP:        userIP,
		GetUserTicket: true,
	})
	if err != nil {
		return nil, status.Errorf(codes.Unauthenticated, "failed to check oauth token: %s", err)
	}

	checkedTicket, err := a.i.TVM.CheckUserTicket(ctx, bbRsp.UserTicket)
	if err != nil {
		return nil, status.Errorf(codes.Unauthenticated, "failed to check TVM user ticket from oauth token: %s", err)
	}

	info := &Info{
		UserID:    bbRsp.User.ID,
		UserLogin: bbRsp.User.Login,
	}
	if err := a.checkUserRoles(ctx, checkedTicket, info); err != nil {
		return nil, status.Errorf(codes.Unauthenticated, "failed to check user roles: %s", err)
	}

	return WithAuthInfo(ctx, info), nil
}

// processTVMAuth authenticates a call carrying a TVM service ticket.
//
// The calling service is admitted in either of two ways:
//   - it is listed in Config.AllowedTVMClients;
//   - TVM roles grant it the configured admin or reader role.
//
// After that, a TVM user ticket (X-Ya-User-Ticket) is required. The exception
// is a service that itself holds the admin role: it may call without a user
// ticket, and its own Info is attached to the context.
//
// For user tickets, the call succeeds only if all of these hold:
//   - the ticket is valid and carries one of acceptableTVMScopes;
//   - BlackBox resolves it to at least one user;
//   - the user's roles can be looked up.
// The first user returned by BlackBox is attached to the context.
//
// All failures are reported as codes.Unauthenticated.
func (a *Auth) processTVMAuth(ctx context.Context, md metautils.NiceMD, serviceTicket string) (context.Context, error) {
	serviceInfo, err := a.i.TVM.CheckServiceTicket(ctx, serviceTicket)
	if err != nil {
		return nil, status.Errorf(codes.Unauthenticated, "failed to check TVM service ticket: %s", err)
	}

	allowed := false
	for _, id := range a.i.Config.AllowedTVMClients {
		if id == serviceInfo.SrcID {
			allowed = true
			break
		}
	}

	// Clients admitted via the static allow-list keep an empty Info, so they
	// never get the admin bypass below.
	serviceAuthInfo := &Info{}

	if !allowed {
		// checkServiceRoles returns nil whenever the roles snapshot can be
		// fetched, even if this service holds no role at all. A successful
		// lookup alone must therefore not admit the client; it has to be
		// granted one of the known roles. A roles-fetch failure is treated
		// as "not allowed".
		if err = a.checkServiceRoles(ctx, serviceInfo, serviceAuthInfo); err == nil {
			allowed = serviceAuthInfo.IsAdmin || serviceAuthInfo.IsReader
		}
	}

	if !allowed {
		return nil, status.Errorf(codes.Unauthenticated, "TVM client %d is not allowed", serviceInfo.SrcID)
	}

	userTicket := md.Get(XYaUserTicketHeader)
	if userTicket == "" {
		if serviceAuthInfo.IsAdmin {
			// Admin TVM client allowed to use API w/o user ticket
			return WithAuthInfo(ctx, serviceAuthInfo), nil
		}

		return nil, status.Error(codes.Unauthenticated, "no user-ticket provided")
	}

	// TODO(buglloc): add custom oauth scope too
	ticketInfo, err := a.i.TVM.CheckUserTicket(ctx, userTicket)
	if err != nil {
		return nil, status.Errorf(codes.Unauthenticated, "failed to check TVM user ticket: %s", err)
	}

	if err = ticketInfo.CheckScopesAny(acceptableTVMScopes...); err != nil {
		return nil, status.Errorf(codes.Unauthenticated, "failed to check TVM user ticket scopes: %s", err)
	}

	userInfo, err := a.i.BlackBox.UserTicket(ctx, blackbox.UserTicketRequest{
		UserTicket: userTicket,
	})
	if err != nil {
		return nil, status.Errorf(codes.Unauthenticated, "failed to check TVM user ticket: %s", err)
	}

	if len(userInfo.Users) == 0 {
		return nil, status.Error(codes.Unauthenticated, "no users in TVM user ticket")
	}

	// Only the first user in the BlackBox response is used.
	user := userInfo.Users[0]

	authInfo := &Info{
		UserLogin: user.Login,
		UserID:    user.ID,
	}

	err = a.checkUserRoles(ctx, ticketInfo, authInfo)
	if err != nil {
		return nil, status.Errorf(codes.Unauthenticated, "failed to check user roles: %s", err)
	}

	return WithAuthInfo(ctx, authInfo), nil
}

// checkUserRoles sets the role flags on authInfo for the user in ticket.
//
// It fetches the current roles snapshot from TVM and looks up the roles bound
// to that user. The UID argument is nil, which presumably selects the ticket's
// default user (confirm against the tvm package). The admin/reader flags are
// then written to authInfo.
//
// An error is returned only if the snapshot cannot be fetched or the per-user
// lookup fails. A user without any role is not an error; both flags simply
// end up false.
func (a *Auth) checkUserRoles(ctx context.Context, ticket *tvm.CheckedUserTicket, authInfo *Info) error {
	snapshot, err := a.i.TVM.GetRoles(ctx)
	if err != nil {
		return fmt.Errorf("failed to get roles: %w", err)
	}

	granted, err := snapshot.GetRolesForUser(ticket, nil)
	if err == nil {
		a.setRoleAttrs(granted, authInfo)
		return nil
	}

	return fmt.Errorf("failed to get user roles: %w", err)
}

// checkServiceRoles sets the role flags on authInfo for the service that
// issued ticket, using the current TVM roles snapshot.
//
// The only error source is fetching the snapshot. A service with no roles
// still yields nil, with both flags set to false, so callers must inspect
// authInfo rather than the error to decide on access.
func (a *Auth) checkServiceRoles(ctx context.Context, ticket *tvm.CheckedServiceTicket, authInfo *Info) error {
	snapshot, err := a.i.TVM.GetRoles(ctx)
	if err != nil {
		return fmt.Errorf("failed to get roles: %w", err)
	}

	a.setRoleAttrs(snapshot.GetRolesForService(ticket), authInfo)
	return nil
}

// setRoleAttrs overwrites the IsAdmin and IsReader flags of authInfo.
//
// Each flag is set according to whether consumerRoles contains the
// corresponding role name configured in Config.Roles. Both flags are always
// assigned, so any previous values are discarded.
func (a *Auth) setRoleAttrs(consumerRoles *tvm.ConsumerRoles, authInfo *Info) {
	authInfo.IsAdmin = consumerRoles.HasRole(a.i.Config.Roles.Admin)
	authInfo.IsReader = consumerRoles.HasRole(a.i.Config.Roles.Reader)
}
