diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index bc001693..7eabde62 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 02c4421d..6dfb2137 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) - authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) + authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index b4607b5e..0c7c2da7 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -264,11 +264,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) if err != nil { - statusCode := http.StatusBadRequest - c.JSON(statusCode, gin.H{ - "error": infraerrors.Reason(err), - "message": infraerrors.Message(err), - }) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index f6d40f29..a9edb5fa 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -12,6 +12,7 @@ import ( "strings" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -59,6 +60,7 @@ type JWTClaims struct { // AuthService 认证服务 type AuthService struct { + entClient *dbent.Client userRepo UserRepository redeemRepo RedeemCodeRepository refreshTokenCache RefreshTokenCache @@ -77,6 +79,7 @@ type DefaultSubscriptionAssigner interface { // NewAuthService 创建认证服务实例 func NewAuthService( + entClient *dbent.Client, userRepo UserRepository, redeemRepo RedeemCodeRepository, refreshTokenCache RefreshTokenCache, @@ -89,6 +92,7 @@ func NewAuthService( defaultSubAssigner DefaultSubscriptionAssigner, ) *AuthService { return &AuthService{ + entClient: entClient, userRepo: userRepo, redeemRepo: redeemRepo, refreshTokenCache: refreshTokenCache, @@ -597,24 +601,52 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema Status: StatusActive, } - if err := s.userRepo.Create(ctx, newUser); err != nil { - if errors.Is(err, ErrEmailExists) { - user, err = s.userRepo.GetByEmail(ctx, email) - if err != nil { - logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) + if s.entClient != nil && invitationRedeemCode != nil { + tx, err := s.entClient.Tx(ctx) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to begin transaction for oauth registration: %v", err) + return nil, nil, ErrServiceUnavailable + } + defer func() { _ = tx.Rollback() }() + txCtx := dbent.NewTxContext(ctx, tx) + + if err := s.userRepo.Create(txCtx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) return nil, nil, ErrServiceUnavailable } } else { - logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) - return nil, nil, ErrServiceUnavailable + if err := s.redeemRepo.Use(txCtx, invitationRedeemCode.ID, newUser.ID); err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + if err := tx.Commit(); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to commit oauth registration transaction: %v", err) + return nil, nil, ErrServiceUnavailable + } + user = newUser + s.assignDefaultSubscriptions(ctx, user.ID) } } else { - user = newUser - s.assignDefaultSubscriptions(ctx, user.ID) - if invitationRedeemCode != nil { - if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { - logger.LegacyPrintf("service.auth", "[Auth] Failed to mark invitation code as used for oauth user %d: %v", user.ID, err) + if err := s.userRepo.Create(ctx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) + return nil, nil, ErrServiceUnavailable } + } else { + user = newUser + s.assignDefaultSubscriptions(ctx, user.ID) } } } else { @@ -644,9 +676,13 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema // pendingOAuthTokenTTL is the validity period for pending OAuth tokens. const pendingOAuthTokenTTL = 10 * time.Minute +// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens. +const pendingOAuthPurpose = "pending_oauth_registration" + type pendingOAuthClaims struct { Email string `json:"email"` Username string `json:"username"` + Purpose string `json:"purpose"` jwt.RegisteredClaims } @@ -657,6 +693,7 @@ func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, e claims := &pendingOAuthClaims{ Email: email, Username: username, + Purpose: pendingOAuthPurpose, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)), IssuedAt: jwt.NewNumericDate(now), @@ -687,6 +724,9 @@ func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username if !ok || !token.Valid { return "", "", ErrInvalidToken } + if claims.Purpose != pendingOAuthPurpose { + return "", "", ErrInvalidToken + } return claims.Email, claims.Username, nil }