diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index 033a5b77..138663c4 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} - authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) admin := &service.User{ ID: 1, diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index f8839cfe..ad9c1b5b 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer cfg.JWT.AccessTokenExpireMinutes = 60 userRepo := &stubJWTUserRepo{users: users} - authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil) mw := NewJWTAuthMiddleware(authSvc, userSvc) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index a9edb5fa..28607e9f 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -647,6 +647,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } else { user = newUser s.assignDefaultSubscriptions(ctx, user.ID) + if invitationRedeemCode != nil { + if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + } } } } else { diff --git a/backend/internal/service/auth_service_pending_oauth_test.go b/backend/internal/service/auth_service_pending_oauth_test.go new file mode 100644 index 00000000..0472e06c --- /dev/null +++ b/backend/internal/service/auth_service_pending_oauth_test.go @@ -0,0 +1,146 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" +) + +func newAuthServiceForPendingOAuthTest() *AuthService { + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret-pending-oauth", + ExpireHour: 1, + }, + } + return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) +} + +// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。 +func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + token, err := svc.CreatePendingOAuthToken("user@example.com", "alice") + require.NoError(t, err) + require.NotEmpty(t, token) + + email, username, err := svc.VerifyPendingOAuthToken(token) + require.NoError(t, err) + require.Equal(t, "user@example.com", email) + require.Equal(t, "alice", username) +} + +// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + // 签发一个普通 access token(JWTClaims,无 Purpose 字段) + accessToken, err := svc.GenerateToken(&User{ + ID: 1, + Email: "user@example.com", + Role: RoleUser, + }) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(accessToken) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + now := time.Now() + claims := &pendingOAuthClaims{ + Email: "user@example.com", + Username: "alice", + Purpose: "some_other_purpose", + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(tokenStr) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + now := time.Now() + claims := &pendingOAuthClaims{ + Email: "user@example.com", + Username: "alice", + Purpose: "", // 旧 token 无此字段,反序列化后为零值 + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(tokenStr) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + + past := time.Now().Add(-1 * time.Hour) + claims := &pendingOAuthClaims{ + Email: "user@example.com", + Username: "alice", + Purpose: pendingOAuthPurpose, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(past), + IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)), + NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) + require.NoError(t, err) + + _, _, err = svc.VerifyPendingOAuthToken(tokenStr) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) { + other := NewAuthService(nil, nil, nil, nil, &config.Config{ + JWT: config.JWTConfig{Secret: "other-secret"}, + }, nil, nil, nil, nil, nil, nil) + + token, err := other.CreatePendingOAuthToken("user@example.com", "alice") + require.NoError(t, err) + + svc := newAuthServiceForPendingOAuthTest() + _, _, err = svc.VerifyPendingOAuthToken(token) + require.ErrorIs(t, err, ErrInvalidToken) +} + +// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。 +func TestVerifyPendingOAuthToken_TooLong(t *testing.T) { + svc := newAuthServiceForPendingOAuthTest() + giant := make([]byte, maxTokenLength+1) + for i := range giant { + giant[i] = 'a' + } + _, _, err := svc.VerifyPendingOAuthToken(string(giant)) + require.ErrorIs(t, err, ErrInvalidToken) +} diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index b139fdcd..7b50e90d 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -130,6 +130,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E } return NewAuthService( + nil, // entClient repo, nil, // redeemRepo nil, // refreshTokenCache diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go index 36cb1e06..477ba1b2 100644 --- a/backend/internal/service/auth_service_turnstile_register_test.go +++ b/backend/internal/service/auth_service_turnstile_register_test.go @@ -43,6 +43,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier turnstileService := NewTurnstileService(settingService, verifier) return NewAuthService( + nil, // entClient &userRepoStub{}, nil, // redeemRepo nil, // refreshTokenCache