diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index bc001693..7eabde62 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 7c817e12..24cd93a2 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) - authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) + authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 0ccf47e4..0c7c2da7 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -211,8 +211,22 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { email = linuxDoSyntheticEmail(subject) } - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username) + // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { + if errors.Is(err, service.ErrOAuthInvitationRequired) { + pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username) + if tokenErr != nil { + redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "") + return + } + fragment := url.Values{} + fragment.Set("error", "invitation_required") + fragment.Set("pending_oauth_token", pendingToken) + fragment.Set("redirect", redirectTo) + redirectWithFragment(c, frontendCallback, fragment) + return + } // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) return @@ -227,6 +241,41 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { redirectWithFragment(c, frontendCallback, fragment) } +type completeLinuxDoOAuthRequest struct { + PendingOAuthToken string `json:"pending_oauth_token" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` +} + +// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating +// the invitation code and creating the user account. +// POST /api/v1/auth/oauth/linuxdo/complete-registration +func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { + var req completeLinuxDoOAuthRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()}) + return + } + + email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"}) + return + } + + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + if err != nil { + response.ErrorFrom(c, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) { if h != nil && h.settingSvc != nil { return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx) diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index 033a5b77..138663c4 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} - authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) admin := &service.User{ ID: 1, diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index f8839cfe..ad9c1b5b 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer cfg.JWT.AccessTokenExpireMinutes = 60 userRepo := &stubJWTUserRepo{users: users} - authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil) mw := NewJWTAuthMiddleware(authSvc, userSvc) diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index c168820c..0efc9560 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -61,6 +61,12 @@ func RegisterAuthRoutes( }), h.Auth.ResetPassword) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) + auth.POST("/oauth/linuxdo/complete-registration", + rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CompleteLinuxDoOAuthRegistration, + ) } // 公开设置(无需认证) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 6a17c83f..28607e9f 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -12,6 +12,7 @@ import ( "strings" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -21,24 +22,25 @@ import ( ) var ( - ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") - ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") - ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") - ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") - ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") - ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") - ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired") - ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") - ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") - ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token") - ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired") - ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused") - ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") - ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed") - ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") - ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") - ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required") - ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code") + ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") + ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") + ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") + ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") + ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") + ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") + ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired") + ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") + ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") + ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token") + ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired") + ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused") + ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") + ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed") + ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") + ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") + ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required") + ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code") + ErrOAuthInvitationRequired = infraerrors.Forbidden("OAUTH_INVITATION_REQUIRED", "invitation code required to complete oauth registration") ) // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 @@ -58,6 +60,7 @@ type JWTClaims struct { // AuthService 认证服务 type AuthService struct { + entClient *dbent.Client userRepo UserRepository redeemRepo RedeemCodeRepository refreshTokenCache RefreshTokenCache @@ -76,6 +79,7 @@ type DefaultSubscriptionAssigner interface { // NewAuthService 创建认证服务实例 func NewAuthService( + entClient *dbent.Client, userRepo UserRepository, redeemRepo RedeemCodeRepository, refreshTokenCache RefreshTokenCache, @@ -88,6 +92,7 @@ func NewAuthService( defaultSubAssigner DefaultSubscriptionAssigner, ) *AuthService { return &AuthService{ + entClient: entClient, userRepo: userRepo, redeemRepo: redeemRepo, refreshTokenCache: refreshTokenCache, @@ -523,9 +528,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username return token, user, nil } -// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair -// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token -func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) { +// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。 +// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。 +// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。 +func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) { // 检查 refreshTokenCache 是否可用 if s.refreshTokenCache == nil { return nil, nil, errors.New("refresh token cache not configured") @@ -552,6 +558,22 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, ErrRegDisabled } + // 检查是否需要邀请码 + var invitationRedeemCode *RedeemCode + if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) { + if invitationCode == "" { + return nil, nil, ErrOAuthInvitationRequired + } + redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) + if err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { + return nil, nil, ErrInvitationCodeInvalid + } + invitationRedeemCode = redeemCode + } + randomPassword, err := randomHexString(32) if err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err) @@ -579,20 +601,58 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema Status: StatusActive, } - if err := s.userRepo.Create(ctx, newUser); err != nil { - if errors.Is(err, ErrEmailExists) { - user, err = s.userRepo.GetByEmail(ctx, email) - if err != nil { - logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) + if s.entClient != nil && invitationRedeemCode != nil { + tx, err := s.entClient.Tx(ctx) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to begin transaction for oauth registration: %v", err) + return nil, nil, ErrServiceUnavailable + } + defer func() { _ = tx.Rollback() }() + txCtx := dbent.NewTxContext(ctx, tx) + + if err := s.userRepo.Create(txCtx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) return nil, nil, ErrServiceUnavailable } } else { - logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) - return nil, nil, ErrServiceUnavailable + if err := s.redeemRepo.Use(txCtx, invitationRedeemCode.ID, newUser.ID); err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + if err := tx.Commit(); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to commit oauth registration transaction: %v", err) + return nil, nil, ErrServiceUnavailable + } + user = newUser + s.assignDefaultSubscriptions(ctx, user.ID) } } else { - user = newUser - s.assignDefaultSubscriptions(ctx, user.ID) + if err := s.userRepo.Create(ctx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + user = newUser + s.assignDefaultSubscriptions(ctx, user.ID) + if invitationRedeemCode != nil { + if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + } + } } } else { logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) @@ -618,6 +678,63 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return tokenPair, user, nil } +// pendingOAuthTokenTTL is the validity period for pending OAuth tokens. +const pendingOAuthTokenTTL = 10 * time.Minute + +// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens. +const pendingOAuthPurpose = "pending_oauth_registration" + +type pendingOAuthClaims struct { + Email string `json:"email"` + Username string `json:"username"` + Purpose string `json:"purpose"` + jwt.RegisteredClaims +} + +// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity +// while waiting for the user to supply an invitation code. +func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) { + now := time.Now() + claims := &pendingOAuthClaims{ + Email: email, + Username: username, + Purpose: pendingOAuthPurpose, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(s.cfg.JWT.Secret)) +} + +// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity. +// Returns ErrInvalidToken when the token is invalid or expired. +func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) { + if len(tokenStr) > maxTokenLength { + return "", "", ErrInvalidToken + } + parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return []byte(s.cfg.JWT.Secret), nil + }) + if parseErr != nil { + return "", "", ErrInvalidToken + } + claims, ok := token.Claims.(*pendingOAuthClaims) + if !ok || !token.Valid { + return "", "", ErrInvalidToken + } + if claims.Purpose != pendingOAuthPurpose { + return "", "", ErrInvalidToken + } + return claims.Email, claims.Username, nil +} + func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { return diff --git a/backend/internal/service/auth_service_pending_oauth_test.go b/backend/internal/service/auth_service_pending_oauth_test.go new file mode 100644 index 00000000..0472e06c --- /dev/null +++ b/backend/internal/service/auth_service_pending_oauth_test.go @@ -0,0 +1,146 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" +) + +func newAuthServiceForPendingOAuthTest() *AuthService { + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret-pending-oauth", + ExpireHour: 1, + }, + } + return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) +} + +// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。 +func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + token, err := svc.CreatePendingOAuthToken("user@example.com", "alice") + require.NoError(t, err) + require.NotEmpty(t, token) + + email, username, err := svc.VerifyPendingOAuthToken(token) + require.NoError(t, err) + require.Equal(t, "user@example.com", email) + require.Equal(t, "alice", username) +} + +// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + // 签发一个普通 access token(JWTClaims,无 Purpose 字段) + accessToken, err := svc.GenerateToken(&User{ + ID: 1, + Email: "user@example.com", + Role: RoleUser, + }) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(accessToken) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + now := time.Now() + claims := &pendingOAuthClaims{ + Email: "user@example.com", + Username: "alice", + Purpose: "some_other_purpose", + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(tokenStr) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + now := time.Now() + claims := &pendingOAuthClaims{ + Email: "user@example.com", + Username: "alice", + Purpose: "", // 旧 token 无此字段,反序列化后为零值 + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(tokenStr) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + past := time.Now().Add(-1 * time.Hour) + claims := &pendingOAuthClaims{ + Email: "user@example.com", + Username: "alice", + Purpose: pendingOAuthPurpose, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(past), + IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)), + NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(tokenStr) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) { + other := NewAuthService(nil, nil, nil, nil, &config.Config{ + JWT: config.JWTConfig{Secret: "other-secret"}, + }, nil, nil, nil, nil, nil, nil) + + token, err := other.CreatePendingOAuthToken("user@example.com", "alice") + require.NoError(t, err) + + svc := newAuthServiceForPendingOAuthTest() + _, _, err = svc.VerifyPendingOAuthToken(token) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_TooLong(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + giant := make([]byte, maxTokenLength+1) + for i := range giant { + giant[i] = 'a' + } + _, _, err := svc.VerifyPendingOAuthToken(string(giant)) + require.ErrorIs(t, err, ErrInvalidToken) +} diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index b139fdcd..7b50e90d 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -130,6 +130,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E } return NewAuthService( + nil, // entClient repo, nil, // redeemRepo nil, // refreshTokenCache diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go index 36cb1e06..477ba1b2 100644 --- a/backend/internal/service/auth_service_turnstile_register_test.go +++ b/backend/internal/service/auth_service_turnstile_register_test.go @@ -43,6 +43,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier turnstileService := NewTurnstileService(settingService, verifier) return NewAuthService( + nil, // entClient &userRepoStub{}, nil, // redeemRepo nil, // refreshTokenCache diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index e196e234..c5e1f35d 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -335,6 +335,28 @@ export async function resetPassword(request: ResetPasswordRequest): Promise { + const { data } = await apiClient.post<{ + access_token: string + refresh_token: string + expires_in: number + token_type: string + }>('/auth/oauth/linuxdo/complete-registration', { + pending_oauth_token: pendingOAuthToken, + invitation_code: invitationCode + }) + return data +} + export const authAPI = { login, login2FA, @@ -357,7 +379,8 @@ export const authAPI = { forgotPassword, resetPassword, refreshToken, - revokeAllSessions + revokeAllSessions, + completeLinuxDoOAuthRegistration } export default authAPI diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index be6aff35..9832ed85 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -434,7 +434,12 @@ export default { callbackProcessing: 'Completing login, please wait...', callbackHint: 'If you are not redirected automatically, go back to the login page and try again.', callbackMissingToken: 'Missing login token, please try again.', - backToLogin: 'Back to Login' + backToLogin: 'Back to Login', + invitationRequired: 'This Linux.do account is not yet registered. The site requires an invitation code — please enter one to complete registration.', + invalidPendingToken: 'The registration token has expired. Please sign in with Linux.do again.', + completeRegistration: 'Complete Registration', + completing: 'Completing registration…', + completeRegistrationFailed: 'Registration failed. Please check your invitation code and try again.' }, oauth: { code: 'Code', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 949d51ea..7ad89848 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -433,7 +433,12 @@ export default { callbackProcessing: '正在验证登录信息,请稍候...', callbackHint: '如果页面未自动跳转,请返回登录页重试。', callbackMissingToken: '登录信息缺失,请返回重试。', - backToLogin: '返回登录' + backToLogin: '返回登录', + invitationRequired: '该 Linux.do 账号尚未注册,站点已开启邀请码注册,请输入邀请码以完成注册。', + invalidPendingToken: '注册凭证已失效,请重新使用 Linux.do 登录。', + completeRegistration: '完成注册', + completing: '正在完成注册...', + completeRegistrationFailed: '注册失败,请检查邀请码后重试。' }, oauth: { code: '授权码', diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue index 4dbca1df..af48959b 100644 --- a/frontend/src/views/auth/LinuxDoCallbackView.vue +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -10,6 +10,36 @@

+ +
+

+ {{ t('auth.linuxdo.invitationRequired') }} +

+
+ +
+ +

+ {{ invitationError }} +

+
+ +
+
+
{ const params = parseFragmentParams() @@ -80,6 +147,19 @@ onMounted(async () => { const errorDesc = params.get('error_description') || params.get('error_message') || '' if (error) { + if (error === 'invitation_required') { + pendingOAuthToken.value = params.get('pending_oauth_token') || '' + redirectTo.value = sanitizeRedirectPath(params.get('redirect')) + if (!pendingOAuthToken.value) { + errorMessage.value = t('auth.linuxdo.invalidPendingToken') + appStore.showError(errorMessage.value) + isProcessing.value = false + return + } + needsInvitation.value = true + isProcessing.value = false + return + } errorMessage.value = errorDesc || error appStore.showError(errorMessage.value) isProcessing.value = false