fix claudecode review bug

This commit is contained in:
Elysia
2026-03-09 01:18:49 +08:00
parent c069b3b1e8
commit 106b20cdbf
4 changed files with 55 additions and 19 deletions

View File

@@ -33,7 +33,7 @@ func main() {
}()
userRepo := repository.NewUserRepository(client, sqlDB)
authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

View File

@@ -67,7 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)

View File

@@ -264,11 +264,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
statusCode := http.StatusBadRequest
c.JSON(statusCode, gin.H{
"error": infraerrors.Reason(err),
"message": infraerrors.Message(err),
})
response.ErrorFrom(c, err)
return
}

View File

@@ -12,6 +12,7 @@ import (
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -59,6 +60,7 @@ type JWTClaims struct {
// AuthService 认证服务
type AuthService struct {
entClient *dbent.Client
userRepo UserRepository
redeemRepo RedeemCodeRepository
refreshTokenCache RefreshTokenCache
@@ -77,6 +79,7 @@ type DefaultSubscriptionAssigner interface {
// NewAuthService 创建认证服务实例
func NewAuthService(
entClient *dbent.Client,
userRepo UserRepository,
redeemRepo RedeemCodeRepository,
refreshTokenCache RefreshTokenCache,
@@ -89,6 +92,7 @@ func NewAuthService(
defaultSubAssigner DefaultSubscriptionAssigner,
) *AuthService {
return &AuthService{
entClient: entClient,
userRepo: userRepo,
redeemRepo: redeemRepo,
refreshTokenCache: refreshTokenCache,
@@ -597,24 +601,52 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Status: StatusActive,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
if s.entClient != nil && invitationRedeemCode != nil {
tx, err := s.entClient.Tx(ctx)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to begin transaction for oauth registration: %v", err)
return nil, nil, ErrServiceUnavailable
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := s.userRepo.Create(txCtx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
return nil, nil, ErrServiceUnavailable
if err := s.redeemRepo.Use(txCtx, invitationRedeemCode.ID, newUser.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
if err := tx.Commit(); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to commit oauth registration transaction: %v", err)
return nil, nil, ErrServiceUnavailable
}
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to mark invitation code as used for oauth user %d: %v", user.ID, err)
if err := s.userRepo.Create(ctx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
}
}
} else {
@@ -644,9 +676,13 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
const pendingOAuthTokenTTL = 10 * time.Minute
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
const pendingOAuthPurpose = "pending_oauth_registration"
type pendingOAuthClaims struct {
Email string `json:"email"`
Username string `json:"username"`
Purpose string `json:"purpose"`
jwt.RegisteredClaims
}
@@ -657,6 +693,7 @@ func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, e
claims := &pendingOAuthClaims{
Email: email,
Username: username,
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
IssuedAt: jwt.NewNumericDate(now),
@@ -687,6 +724,9 @@ func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username
if !ok || !token.Valid {
return "", "", ErrInvalidToken
}
if claims.Purpose != pendingOAuthPurpose {
return "", "", ErrInvalidToken
}
return claims.Email, claims.Username, nil
}