package webauth

import (
	"bytes"
	"crypto/hmac"
	"crypto/sha256"
	"errors"
	"fmt"
	"time"

	"github.com/golang-jwt/jwt/v4"

	"a.yandex-team.ru/security/skotty/service/internal/models"
)

// Issuer is the value stored in the Issuer field of the standard claims of
// every token this package signs.
const Issuer = "skotty"

// WebAuth signs and verifies the JWTs and authorization signatures used by
// the web authentication flow. Build it with NewWebAuth.
type WebAuth struct {
	// key is the shared secret used both as the HS256 JWT signing key and as
	// the HMAC-SHA256 key for authorization signatures.
	key []byte
}

// AuthInfo is the payload carried by enrollment/renewal auth tokens
// (see newAuthToken and ParseAuthToken).
type AuthInfo struct {
	// AuthKind is the purpose the token was issued for; on parse it is read
	// back from the token's claims.
	AuthKind models.AuthKind
	// AuthID is stored in the token's standard-claims ID; User identifies the owner.
	AuthID, User string
	// TokenType describes the hardware/software token being enrolled.
	TokenType models.TokenType
	// Descriptive token and host attributes, copied verbatim into the claims.
	TokenSerial, TokenName, Hostname, EnrollmentID string
}

// RenewInfo is the payload carried by long-lived renew tokens
// (see RenewToken and ParseRenew).
type RenewInfo struct {
	// AuthID is stored in the token's standard-claims ID; User and
	// TokenSerial identify the owner and the token.
	AuthID, User, TokenSerial string
	// TokenType describes the token the renewal applies to.
	TokenType models.TokenType
	// EnrollmentID links the renewal to its enrollment.
	EnrollmentID string
}

// NewWebAuth returns a WebAuth that signs and verifies with key.
//
// The key must be at least 10 bytes long (len counts bytes, not runes);
// shorter keys are rejected with an error and a nil *WebAuth.
func NewWebAuth(key string) (*WebAuth, error) {
	if len(key) < 10 {
		return nil, errors.New("secret key length must be at least 10 bytes")
	}

	return &WebAuth{
		// Converting the string yields a fresh slice, so the key cannot be
		// mutated through the caller's data.
		key: []byte(key),
	}, nil
}

// EnrollAuthTokens issues the host and user enrollment tokens for info.
// Both are signed by newAuthToken from the same info and differ only in
// AuthKind (AuthKindEnrollHost vs AuthKindEnrollUser). If either signature
// fails, both returned tokens are empty and the error is wrapped.
func (a *WebAuth) EnrollAuthTokens(info AuthInfo) (hostToken, userToken string, err error) {
	if hostToken, err = a.newAuthToken(info, models.AuthKindEnrollHost); err != nil {
		return "", "", fmt.Errorf("can't sign host token: %w", err)
	}

	if userToken, err = a.newAuthToken(info, models.AuthKindEnrollUser); err != nil {
		return "", "", fmt.Errorf("can't sign user token: %w", err)
	}

	return hostToken, userToken, nil
}

// AuthRenewTokens issues the host and user renewal auth tokens for info.
// Both are signed by newAuthToken from the same info and differ only in
// AuthKind (AuthKindRenewHost vs AuthKindRenewUser). If either signature
// fails, both returned tokens are empty and the error is wrapped.
func (a *WebAuth) AuthRenewTokens(info AuthInfo) (hostToken, userToken string, err error) {
	hostToken, err = a.newAuthToken(info, models.AuthKindRenewHost)
	if err != nil {
		return "", "", fmt.Errorf("can't sign host token: %w", err)
	}

	userToken, err = a.newAuthToken(info, models.AuthKindRenewUser)
	if err != nil {
		return "", "", fmt.Errorf("can't sign user token: %w", err)
	}

	return hostToken, userToken, nil
}

// SignAuthorization returns the lowercase hex HMAC-SHA256 (keyed with the
// WebAuth secret) of "ID:User", or of "ID:User:UserTicket" when UserTicket is
// non-empty.
//
// NOTE(review): the fields are joined with ':' without escaping, so values
// that contain ':' can map to the same message. This cannot change without
// invalidating existing signatures.
func (a *WebAuth) SignAuthorization(authInfo *models.Authorization) (string, error) {
	parts := [][]byte{[]byte(authInfo.ID), []byte(authInfo.User)}
	if authInfo.UserTicket != "" {
		// The ticket segment is only added when present; the original code
		// marks this as being for backward compatibility.
		parts = append(parts, []byte(authInfo.UserTicket))
	}
	payload := bytes.Join(parts, []byte{':'})

	mac := hmac.New(sha256.New, a.key)
	_, err := mac.Write(payload)
	digest := mac.Sum(nil)
	return fmt.Sprintf("%x", digest), err
}

// CheckAuthorizationSign reports whether authInfo.Sign matches the signature
// SignAuthorization computes for authInfo. An error from SignAuthorization is
// returned unchanged together with false.
func (a *WebAuth) CheckAuthorizationSign(authInfo *models.Authorization) (bool, error) {
	want, err := a.SignAuthorization(authInfo)
	if err != nil {
		return false, err
	}

	got := []byte(authInfo.Sign)
	// hmac.Equal compares in constant time, so timing does not reveal how
	// much of a forged signature was correct.
	return hmac.Equal([]byte(want), got), nil
}

// RenewToken issues an HS256-signed renew token for info. The token has
// AuthKind AuthKindTokenRenew, carries info.AuthID as its standard-claims ID,
// and expires one calendar year after issuance.
func (a *WebAuth) RenewToken(info RenewInfo) (string, error) {
	issuedAt := time.Now()
	validUntil := issuedAt.AddDate(1, 0, 0)

	claims := RenewClaims{
		User:         info.User,
		TokenSerial:  info.TokenSerial,
		TokenType:    info.TokenType,
		EnrollmentID: info.EnrollmentID,
		StandardClaims: StandardClaims{
			Issuer:    Issuer,
			IssuedAt:  issuedAt.Unix(),
			ExpiresAt: validUntil.Unix(),
			ID:        info.AuthID,
			AuthKind:  models.AuthKindTokenRenew,
		},
	}

	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(a.key)
	if err != nil {
		return "", fmt.Errorf("can't sign token: %w", err)
	}
	return signed, nil
}

// ParseAuthToken verifies tokenString with the WebAuth secret and returns the
// AuthInfo it carries.
//
// Only HMAC signing methods are accepted. The claims are validated, and when
// authKind is not models.AuthKindNone the token's AuthKind must equal it.
// On any failure a zero AuthInfo and a wrapped error are returned.
func (a *WebAuth) ParseAuthToken(tokenString string, authKind models.AuthKind) (AuthInfo, error) {
	// Reject non-HMAC algorithms before handing out the key, so a token
	// cannot pick a verification method the key was never meant for.
	keyFunc := func(token *jwt.Token) (interface{}, error) {
		if _, isHMAC := token.Method.(*jwt.SigningMethodHMAC); !isHMAC {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return a.key, nil
	}

	var claims AuthClaims
	if _, err := jwt.ParseWithClaims(tokenString, &claims, keyFunc); err != nil {
		return AuthInfo{}, fmt.Errorf("failed to parse token: %w", err)
	}

	if err := claims.Valid(); err != nil {
		return AuthInfo{}, fmt.Errorf("invalid token: %w", err)
	}

	kindMismatch := authKind != models.AuthKindNone && claims.AuthKind != authKind
	if kindMismatch {
		return AuthInfo{}, fmt.Errorf("invalid token aud: %s != %s", claims.AuthKind, authKind)
	}

	var info AuthInfo
	info.AuthKind = claims.AuthKind
	info.AuthID = claims.ID
	info.User = claims.User
	info.TokenType = claims.TokenType
	info.TokenSerial = claims.TokenSerial
	info.TokenName = claims.TokenName
	info.Hostname = claims.Hostname
	info.EnrollmentID = claims.EnrollmentID
	return info, nil
}

// ParseRenew verifies a renew token produced by RenewToken and returns the
// RenewInfo it carries, including AuthID (the token's standard-claims ID).
//
// Only HMAC signing methods are accepted, the claims are validated, and the
// token's AuthKind must be models.AuthKindTokenRenew. On any failure a zero
// RenewInfo and a wrapped error are returned.
func (a *WebAuth) ParseRenew(tokenString string) (RenewInfo, error) {
	var claims RenewClaims
	_, err := jwt.ParseWithClaims(tokenString, &claims, func(token *jwt.Token) (interface{}, error) {
		// Refuse non-HMAC algorithms before exposing the shared key.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}

		return a.key, nil
	})
	if err != nil {
		return RenewInfo{}, fmt.Errorf("failed to parse token: %w", err)
	}

	if err := claims.Valid(); err != nil {
		return RenewInfo{}, fmt.Errorf("invalid token: %w", err)
	}

	if claims.AuthKind != models.AuthKindTokenRenew {
		return RenewInfo{}, fmt.Errorf("invalid token aud: %s != %s", claims.AuthKind, models.AuthKindTokenRenew)
	}

	// AuthID is restored from the ID claim that RenewToken writes from
	// info.AuthID, so the round trip through a token keeps the ID, as
	// ParseAuthToken already does.
	out := RenewInfo{
		AuthID:       claims.ID,
		User:         claims.User,
		TokenType:    claims.TokenType,
		TokenSerial:  claims.TokenSerial,
		EnrollmentID: claims.EnrollmentID,
	}
	return out, nil
}

// newAuthToken signs an HS256 JWT carrying info's user, token, and host
// attributes. The token is valid for one hour from now. Its AuthKind comes
// from the authKind argument, not from info.AuthKind.
func (a *WebAuth) newAuthToken(info AuthInfo, authKind models.AuthKind) (string, error) {
	issuedAt := time.Now()
	expiresAt := issuedAt.Add(time.Hour)

	claims := AuthClaims{
		User:         info.User,
		TokenType:    info.TokenType,
		TokenSerial:  info.TokenSerial,
		TokenName:    info.TokenName,
		EnrollmentID: info.EnrollmentID,
		Hostname:     info.Hostname,
		StandardClaims: StandardClaims{
			Issuer:    Issuer,
			IssuedAt:  issuedAt.Unix(),
			ExpiresAt: expiresAt.Unix(),
			ID:        info.AuthID,
			AuthKind:  authKind,
		},
	}

	return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(a.key)
}
