diff --git a/controller/oauth.go b/controller/oauth.go
index 7d1bde7ea..1552aa10e 100644
--- a/controller/oauth.go
+++ b/controller/oauth.go
@@ -3,10 +3,16 @@ package controller
import (
"encoding/json"
"net/http"
+ "one-api/model"
"one-api/setting/system_setting"
"one-api/src/oauth"
+ "time"
"github.com/gin-gonic/gin"
+ jwt "github.com/golang-jwt/jwt/v5"
+ "one-api/middleware"
+ "strconv"
+ "strings"
)
// GetJWKS 获取JWKS公钥集
@@ -19,6 +25,9 @@ func GetJWKS(c *gin.Context) {
return
}
+ // lazy init if needed
+ _ = oauth.EnsureInitialized()
+
jwks := oauth.GetJWKS()
if jwks == nil {
c.JSON(http.StatusInternalServerError, gin.H{
@@ -70,7 +79,7 @@ func OAuthTokenEndpoint(c *gin.Context) {
// 只允许application/x-www-form-urlencoded内容类型
contentType := c.GetHeader("Content-Type")
- if contentType != "application/x-www-form-urlencoded" {
+ if contentType == "" || !strings.Contains(strings.ToLower(contentType), "application/x-www-form-urlencoded") {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"error_description": "Content-Type must be application/x-www-form-urlencoded",
@@ -78,7 +87,11 @@ func OAuthTokenEndpoint(c *gin.Context) {
return
}
- // 委托给OAuth2服务器处理
+ // lazy init
+ if err := oauth.EnsureInitialized(); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error", "error_description": err.Error()})
+ return
+ }
oauth.HandleTokenRequest(c)
}
@@ -93,7 +106,10 @@ func OAuthAuthorizeEndpoint(c *gin.Context) {
return
}
- // 委托给OAuth2服务器处理
+ if err := oauth.EnsureInitialized(); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error", "error_description": err.Error()})
+ return
+ }
oauth.HandleAuthorizeRequest(c)
}
@@ -108,20 +124,68 @@ func OAuthServerInfo(c *gin.Context) {
}
// 返回OAuth2服务器的基本信息(类似OpenID Connect Discovery)
+ issuer := settings.Issuer
+ if issuer == "" {
+ scheme := "https"
+ if c.Request.TLS == nil {
+ if hdr := c.Request.Header.Get("X-Forwarded-Proto"); hdr != "" {
+ scheme = hdr
+ } else {
+ scheme = "http"
+ }
+ }
+ issuer = scheme + "://" + c.Request.Host
+ }
+
+ base := issuer + "/api"
c.JSON(http.StatusOK, gin.H{
- "issuer": settings.Issuer,
- "authorization_endpoint": settings.Issuer + "/oauth/authorize",
- "token_endpoint": settings.Issuer + "/oauth/token",
- "jwks_uri": settings.Issuer + "/.well-known/jwks.json",
+ "issuer": issuer,
+ "authorization_endpoint": base + "/oauth/authorize",
+ "token_endpoint": base + "/oauth/token",
+ "jwks_uri": base + "/.well-known/jwks.json",
"grant_types_supported": settings.AllowedGrantTypes,
- "response_types_supported": []string{"code"},
+ "response_types_supported": []string{"code", "token"},
"token_endpoint_auth_methods_supported": []string{"client_secret_basic", "client_secret_post"},
"code_challenge_methods_supported": []string{"S256"},
- "scopes_supported": []string{
- "api:read",
- "api:write",
- "admin",
- },
+ "scopes_supported": []string{"openid", "profile", "email", "api:read", "api:write", "admin"},
+ "default_private_key_path": settings.DefaultPrivateKeyPath,
+ })
+}
+
+// OAuthOIDCConfiguration OIDC discovery document
+func OAuthOIDCConfiguration(c *gin.Context) {
+ settings := system_setting.GetOAuth2Settings()
+ if !settings.Enabled {
+ c.JSON(http.StatusNotFound, gin.H{"error": "OAuth2 server is disabled"})
+ return
+ }
+ issuer := settings.Issuer
+ if issuer == "" {
+ scheme := "https"
+ if c.Request.TLS == nil {
+ if hdr := c.Request.Header.Get("X-Forwarded-Proto"); hdr != "" {
+ scheme = hdr
+ } else {
+ scheme = "http"
+ }
+ }
+ issuer = scheme + "://" + c.Request.Host
+ }
+ base := issuer + "/api"
+ c.JSON(http.StatusOK, gin.H{
+ "issuer": issuer,
+ "authorization_endpoint": base + "/oauth/authorize",
+ "token_endpoint": base + "/oauth/token",
+ "userinfo_endpoint": base + "/oauth/userinfo",
+ "jwks_uri": base + "/.well-known/jwks.json",
+ "response_types_supported": []string{"code", "token"},
+ "grant_types_supported": settings.AllowedGrantTypes,
+ "subject_types_supported": []string{"public"},
+ "id_token_signing_alg_values_supported": []string{"RS256"},
+ "scopes_supported": []string{"openid", "profile", "email", "api:read", "api:write", "admin"},
+ "token_endpoint_auth_methods_supported": []string{"client_secret_basic", "client_secret_post"},
+ "code_challenge_methods_supported": []string{"S256"},
+ "default_private_key_path": settings.DefaultPrivateKeyPath,
})
}
@@ -152,14 +216,50 @@ func OAuthIntrospect(c *gin.Context) {
return
}
- // TODO: 实现令牌内省逻辑
- // 1. 验证调用者的认证信息
- // 2. 解析和验证JWT令牌
- // 3. 返回令牌的元信息
+ tokenString := token
- c.JSON(http.StatusOK, gin.H{
- "active": false, // 临时返回,需要实现实际的内省逻辑
+ // 验证并解析JWT
+ parsed, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
+ if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
+ return nil, jwt.ErrTokenSignatureInvalid
+ }
+ pub := oauth.GetPublicKeyByKid(func() string {
+ if v, ok := token.Header["kid"].(string); ok {
+ return v
+ }
+ return ""
+ }())
+ if pub == nil {
+ return nil, jwt.ErrTokenUnverifiable
+ }
+ return pub, nil
})
+ if err != nil || !parsed.Valid {
+ c.JSON(http.StatusOK, gin.H{"active": false})
+ return
+ }
+
+ claims, ok := parsed.Claims.(jwt.MapClaims)
+ if !ok {
+ c.JSON(http.StatusOK, gin.H{"active": false})
+ return
+ }
+
+ // 检查撤销
+ if jti, ok := claims["jti"].(string); ok && jti != "" {
+ if revoked, _ := model.IsTokenRevoked(jti); revoked {
+ c.JSON(http.StatusOK, gin.H{"active": false})
+ return
+ }
+ }
+
+ // 有效
+ resp := gin.H{"active": true}
+ for k, v := range claims {
+ resp[k] = v
+ }
+ resp["token_type"] = "Bearer"
+ c.JSON(http.StatusOK, resp)
}
// OAuthRevoke 令牌撤销端点(RFC 7009)
@@ -190,11 +290,86 @@ func OAuthRevoke(c *gin.Context) {
return
}
- // TODO: 实现令牌撤销逻辑
- // 1. 验证调用者的认证信息
- // 2. 撤销指定的令牌(加入黑名单或从存储中删除)
+ token = c.PostForm("token")
+ if token == "" {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": "invalid_request",
+ "error_description": "Missing token parameter",
+ })
+ return
+ }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
+ // 尝试解析JWT,若成功则记录jti到撤销表
+ parsed, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
+ if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
+ return nil, jwt.ErrTokenSignatureInvalid
+ }
+ pub := oauth.GetRSAPublicKey()
+ if pub == nil {
+ return nil, jwt.ErrTokenUnverifiable
+ }
+ return pub, nil
})
+ if err == nil && parsed != nil && parsed.Valid {
+ if claims, ok := parsed.Claims.(jwt.MapClaims); ok {
+ var jti string
+ var exp int64
+ if v, ok := claims["jti"].(string); ok {
+ jti = v
+ }
+ if v, ok := claims["exp"].(float64); ok {
+ exp = int64(v)
+ } else if v, ok := claims["exp"].(int64); ok {
+ exp = v
+ }
+ if jti != "" {
+ // 如果没有exp,默认撤销至当前+TTL 10分钟
+ if exp == 0 {
+ exp = time.Now().Add(10 * time.Minute).Unix()
+ }
+ _ = model.RevokeToken(jti, exp)
+ }
+ }
+ }
+
+ c.JSON(http.StatusOK, gin.H{"success": true})
+}
+
+// OAuthUserInfo returns OIDC userinfo based on access token
+func OAuthUserInfo(c *gin.Context) {
+ settings := system_setting.GetOAuth2Settings()
+ if !settings.Enabled {
+ c.JSON(http.StatusNotFound, gin.H{"error": "OAuth2 server is disabled"})
+ return
+ }
+ // 需要 OAuthJWTAuth 中间件注入 claims
+ claims, ok := middleware.GetOAuthClaims(c)
+ if !ok {
+ c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_token"})
+ return
+ }
+ // scope 校验:必须包含 openid
+ scope, _ := claims["scope"].(string)
+ if !strings.Contains(" "+scope+" ", " openid ") {
+ c.JSON(http.StatusForbidden, gin.H{"error": "insufficient_scope"})
+ return
+ }
+ sub, _ := claims["sub"].(string)
+ resp := gin.H{"sub": sub}
+ // 若包含 profile/email scope,补充返回
+ if strings.Contains(" "+scope+" ", " profile ") || strings.Contains(" "+scope+" ", " email ") {
+ if uid, err := strconv.Atoi(sub); err == nil {
+ if user, err2 := model.GetUserById(uid, false); err2 == nil && user != nil {
+ if strings.Contains(" "+scope+" ", " profile ") {
+ resp["name"] = user.DisplayName
+ resp["preferred_username"] = user.Username
+ }
+ if strings.Contains(" "+scope+" ", " email ") {
+ resp["email"] = user.Email
+ resp["email_verified"] = true
+ }
+ }
+ }
+ }
+ c.JSON(http.StatusOK, resp)
}
diff --git a/controller/oauth_keys.go b/controller/oauth_keys.go
new file mode 100644
index 000000000..9e2397d3a
--- /dev/null
+++ b/controller/oauth_keys.go
@@ -0,0 +1,89 @@
+package controller
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/logger"
+ "one-api/src/oauth"
+)
+
+type rotateKeyRequest struct {
+ Kid string `json:"kid"`
+}
+
+type genKeyFileRequest struct {
+ Path string `json:"path"`
+ Kid string `json:"kid"`
+ Overwrite bool `json:"overwrite"`
+}
+
+type importPemRequest struct {
+ Pem string `json:"pem"`
+ Kid string `json:"kid"`
+}
+
+// RotateOAuthSigningKey rotates the OAuth2 JWT signing key (Root only)
+func RotateOAuthSigningKey(c *gin.Context) {
+ var req rotateKeyRequest
+ _ = c.BindJSON(&req)
+ kid, err := oauth.RotateSigningKey(req.Kid)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ logger.LogInfo(c, "oauth signing key rotated: "+kid)
+ c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid})
+}
+
+// ListOAuthSigningKeys returns current and historical JWKS signing keys
+func ListOAuthSigningKeys(c *gin.Context) {
+ keys := oauth.ListSigningKeys()
+ c.JSON(http.StatusOK, gin.H{"success": true, "data": keys})
+}
+
+// DeleteOAuthSigningKey deletes a non-current key by kid
+func DeleteOAuthSigningKey(c *gin.Context) {
+ kid := c.Param("kid")
+ if kid == "" {
+ c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "kid required"})
+ return
+ }
+ if err := oauth.DeleteSigningKey(kid); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ logger.LogInfo(c, "oauth signing key deleted: "+kid)
+ c.JSON(http.StatusOK, gin.H{"success": true})
+}
+
+// GenerateOAuthSigningKeyFile generates a private key file and rotates current kid
+func GenerateOAuthSigningKeyFile(c *gin.Context) {
+ var req genKeyFileRequest
+ if err := c.ShouldBindJSON(&req); err != nil || req.Path == "" {
+ c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "path required"})
+ return
+ }
+ kid, err := oauth.GenerateAndPersistKey(req.Path, req.Kid, req.Overwrite)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ logger.LogInfo(c, "oauth signing key generated to file: "+req.Path+" kid="+kid)
+ c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid, "path": req.Path})
+}
+
+// ImportOAuthSigningKey imports PEM text and rotates current kid
+func ImportOAuthSigningKey(c *gin.Context) {
+ var req importPemRequest
+ if err := c.ShouldBindJSON(&req); err != nil || req.Pem == "" {
+ c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "pem required"})
+ return
+ }
+ kid, err := oauth.ImportPEMKey(req.Pem, req.Kid)
+ if err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ logger.LogInfo(c, "oauth signing key imported from PEM, kid="+kid)
+ c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid})
+}
diff --git a/examples/oauth/oauth-demo.html b/examples/oauth/oauth-demo.html
new file mode 100644
index 000000000..210e0d254
--- /dev/null
+++ b/examples/oauth/oauth-demo.html
@@ -0,0 +1,326 @@
+
+
+
+
+
+ OAuth2/OIDC 授权码 + PKCE 前端演示
+
+
+
+
+
OAuth2/OIDC 授权码 + PKCE 前端演示
+
+
+
+
+
+
+
+
提示:若未配置 Issuer,可直接填写下方端点。
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
说明:
+
+ - 本页为纯前端演示,适用于公开客户端(不需要 client_secret)。
+ - 如跨域调用 Token/UserInfo,需要服务端正确设置 CORS;建议将此 demo 部署到同源域名下。
+
+
+
+
+
+
+
+
+
+
+
+
+
可将服务端返回的 OIDC Discovery JSON 粘贴到此处,点击“解析并填充端点”。
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/oauth/oauth2_test_client.go b/examples/oauth/oauth2_test_client.go
new file mode 100644
index 000000000..d8e6dd239
--- /dev/null
+++ b/examples/oauth/oauth2_test_client.go
@@ -0,0 +1,181 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/url"
+ "path"
+ "strings"
+ "time"
+
+ "golang.org/x/oauth2"
+ "golang.org/x/oauth2/clientcredentials"
+)
+
+func main() {
+ // 测试 Client Credentials 流程
+ //testClientCredentials()
+
+ // 测试 Authorization Code + PKCE 流程(需要浏览器交互)
+ testAuthorizationCode()
+}
+
+// testClientCredentials 测试服务对服务认证
+func testClientCredentials() {
+ fmt.Println("=== Testing Client Credentials Flow ===")
+
+ cfg := clientcredentials.Config{
+ ClientID: "client_dsFyyoyNZWjhbNa2", // 需要先创建客户端
+ ClientSecret: "hLLdn2Ia4UM7hcsJaSuUFDV0Px9BrkNq",
+ TokenURL: "http://localhost:3000/api/oauth/token",
+ Scopes: []string{"api:read", "api:write"},
+ EndpointParams: map[string][]string{
+ "audience": {"api://new-api"},
+ },
+ }
+
+ // 创建HTTP客户端
+ httpClient := cfg.Client(context.Background())
+
+ // 调用受保护的API
+ resp, err := httpClient.Get("http://localhost:3000/api/status")
+ if err != nil {
+ log.Printf("Request failed: %v", err)
+ return
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ log.Printf("Failed to read response: %v", err)
+ return
+ }
+
+ fmt.Printf("Status: %s\n", resp.Status)
+ fmt.Printf("Response: %s\n", string(body))
+}
+
+// testAuthorizationCode 测试授权码流程
+func testAuthorizationCode() {
+ fmt.Println("=== Testing Authorization Code + PKCE Flow ===")
+
+ conf := oauth2.Config{
+ ClientID: "client_dsFyyoyNZWjhbNa2", // 需要先创建客户端
+ ClientSecret: "JHiugKf89OMmTLuZMZyA2sgZnO0Ioae3",
+ RedirectURL: "http://localhost:9999/callback",
+ // 包含 openid/profile/email 以便调用 UserInfo
+ Scopes: []string{"openid", "profile", "email", "api:read"},
+ Endpoint: oauth2.Endpoint{
+ AuthURL: "http://localhost:3000/api/oauth/authorize",
+ TokenURL: "http://localhost:3000/api/oauth/token",
+ },
+ }
+
+ // 生成PKCE参数
+ codeVerifier := oauth2.GenerateVerifier()
+ state := fmt.Sprintf("state-%d", time.Now().Unix())
+
+ // 构建授权URL
+ url := conf.AuthCodeURL(
+ state,
+ oauth2.S256ChallengeOption(codeVerifier),
+ //oauth2.SetAuthURLParam("audience", "api://new-api"),
+ )
+
+ fmt.Printf("Visit this URL to authorize:\n%s\n\n", url)
+ fmt.Printf("A local server will listen on http://localhost:9999/callback to receive the code...\n")
+
+ // 启动回调本地服务器,自动接收授权码
+ codeCh := make(chan string, 1)
+ srv := &http.Server{Addr: ":9999"}
+ http.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
+ q := r.URL.Query()
+ if errParam := q.Get("error"); errParam != "" {
+ fmt.Fprintf(w, "Authorization failed: %s", errParam)
+ return
+ }
+ gotState := q.Get("state")
+ if gotState != state {
+ http.Error(w, "state mismatch", http.StatusBadRequest)
+ return
+ }
+ code := q.Get("code")
+ if code == "" {
+ http.Error(w, "missing code", http.StatusBadRequest)
+ return
+ }
+ fmt.Fprintln(w, "Authorization received. You may close this window.")
+ select {
+ case codeCh <- code:
+ default:
+ }
+ go func() {
+ // 稍后关闭服务
+ _ = srv.Shutdown(context.Background())
+ }()
+ })
+ go func() {
+ _ = srv.ListenAndServe()
+ }()
+
+ // 等待授权码
+ var code string
+ select {
+ case code = <-codeCh:
+ case <-time.After(5 * time.Minute):
+ log.Println("Timeout waiting for authorization code")
+ _ = srv.Shutdown(context.Background())
+ return
+ }
+
+ // 交换令牌
+ token, err := conf.Exchange(
+ context.Background(),
+ code,
+ oauth2.VerifierOption(codeVerifier),
+ )
+ if err != nil {
+ log.Printf("Token exchange failed: %v", err)
+ return
+ }
+
+ fmt.Printf("Access Token: %s\n", token.AccessToken)
+ fmt.Printf("Token Type: %s\n", token.TokenType)
+ fmt.Printf("Expires In: %v\n", token.Expiry)
+
+ // 使用令牌调用 UserInfo
+ client := conf.Client(context.Background(), token)
+ userInfoURL := buildUserInfoFromAuth(conf.Endpoint.AuthURL)
+ resp, err := client.Get(userInfoURL)
+ if err != nil {
+ log.Printf("UserInfo request failed: %v", err)
+ return
+ }
+ defer resp.Body.Close()
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ log.Printf("Failed to read UserInfo response: %v", err)
+ return
+ }
+ fmt.Printf("UserInfo: %s\n", string(body))
+}
+
+// buildUserInfoFromAuth 将授权端点URL转换为UserInfo端点URL
+func buildUserInfoFromAuth(auth string) string {
+ u, err := url.Parse(auth)
+ if err != nil {
+ return ""
+ }
+ // 将最后一个路径段 authorize 替换为 userinfo
+ dir := path.Dir(u.Path)
+ if strings.HasSuffix(u.Path, "/authorize") {
+ u.Path = path.Join(dir, "userinfo")
+ } else {
+ // 回退:追加默认 /oauth/userinfo
+ u.Path = path.Join(dir, "userinfo")
+ }
+ return u.String()
+}
diff --git a/examples/oauth2_test_client.go b/examples/oauth2_test_client.go
deleted file mode 100644
index 30a6ac233..000000000
--- a/examples/oauth2_test_client.go
+++ /dev/null
@@ -1,125 +0,0 @@
-package main
-
-import (
- "context"
- "fmt"
- "io"
- "log"
- "net/http"
-
- "golang.org/x/oauth2"
- "golang.org/x/oauth2/clientcredentials"
-)
-
-func main() {
- // 测试 Client Credentials 流程
- testClientCredentials()
-
- // 测试 Authorization Code + PKCE 流程(需要浏览器交互)
- // testAuthorizationCode()
-}
-
-// testClientCredentials 测试服务对服务认证
-func testClientCredentials() {
- fmt.Println("=== Testing Client Credentials Flow ===")
-
- cfg := clientcredentials.Config{
- ClientID: "client_demo123456789", // 需要先创建客户端
- ClientSecret: "demo_secret_32_chars_long_123456",
- TokenURL: "http://127.0.0.1:8080/api/oauth/token",
- Scopes: []string{"api:read", "api:write"},
- EndpointParams: map[string][]string{
- "audience": {"api://new-api"},
- },
- }
-
- // 创建HTTP客户端
- httpClient := cfg.Client(context.Background())
-
- // 调用受保护的API
- resp, err := httpClient.Get("http://127.0.0.1:8080/api/status")
- if err != nil {
- log.Printf("Request failed: %v", err)
- return
- }
- defer resp.Body.Close()
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- log.Printf("Failed to read response: %v", err)
- return
- }
-
- fmt.Printf("Status: %s\n", resp.Status)
- fmt.Printf("Response: %s\n", string(body))
-}
-
-// testAuthorizationCode 测试授权码流程
-func testAuthorizationCode() {
- fmt.Println("=== Testing Authorization Code + PKCE Flow ===")
-
- conf := oauth2.Config{
- ClientID: "client_web123456789", // Web客户端
- ClientSecret: "web_secret_32_chars_long_123456",
- RedirectURL: "http://localhost:9999/callback",
- Scopes: []string{"api:read", "api:write"},
- Endpoint: oauth2.Endpoint{
- AuthURL: "http://127.0.0.1:8080/api/oauth/authorize",
- TokenURL: "http://127.0.0.1:8080/api/oauth/token",
- },
- }
-
- // 生成PKCE参数
- codeVerifier := oauth2.GenerateVerifier()
-
- // 构建授权URL
- url := conf.AuthCodeURL(
- "random-state-string",
- oauth2.S256ChallengeOption(codeVerifier),
- oauth2.SetAuthURLParam("audience", "api://new-api"),
- )
-
- fmt.Printf("Visit this URL to authorize:\n%s\n\n", url)
- fmt.Printf("After authorization, you'll get a code. Use it to exchange for tokens.\n")
-
- // 在实际应用中,这里需要启动一个HTTP服务器来接收回调
- // 或者手动输入从回调URL中获取的授权码
-
- fmt.Print("Enter the authorization code: ")
- var code string
- fmt.Scanln(&code)
-
- if code != "" {
- // 交换令牌
- token, err := conf.Exchange(
- context.Background(),
- code,
- oauth2.VerifierOption(codeVerifier),
- )
- if err != nil {
- log.Printf("Token exchange failed: %v", err)
- return
- }
-
- fmt.Printf("Access Token: %s\n", token.AccessToken)
- fmt.Printf("Token Type: %s\n", token.TokenType)
- fmt.Printf("Expires In: %v\n", token.Expiry)
-
- // 使用令牌调用API
- client := conf.Client(context.Background(), token)
- resp, err := client.Get("http://127.0.0.1:8080/api/status")
- if err != nil {
- log.Printf("API request failed: %v", err)
- return
- }
- defer resp.Body.Close()
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- log.Printf("Failed to read response: %v", err)
- return
- }
-
- fmt.Printf("API Response: %s\n", string(body))
- }
-}
diff --git a/middleware/auth.go b/middleware/auth.go
index 25caf50d9..eaf73998c 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -8,11 +8,14 @@ import (
"one-api/model"
"one-api/setting"
"one-api/setting/ratio_setting"
+ "one-api/setting/system_setting"
+ "one-api/src/oauth"
"strconv"
"strings"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
+ jwt "github.com/golang-jwt/jwt/v5"
)
func validUserInfo(username string, role int) bool {
@@ -177,6 +180,7 @@ func WssAuth(c *gin.Context) {
func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) {
+ rawAuth := c.Request.Header.Get("Authorization")
// 先检测是否为ws
if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
// Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
@@ -235,6 +239,11 @@ func TokenAuth() func(c *gin.Context) {
}
}
if err != nil {
+ // OAuth Bearer fallback
+ if tryOAuthBearer(c, rawAuth) {
+ c.Next()
+ return
+ }
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
@@ -288,6 +297,74 @@ func TokenAuth() func(c *gin.Context) {
}
}
+// tryOAuthBearer validates an OAuth JWT access token and sets minimal context for relay
+func tryOAuthBearer(c *gin.Context, rawAuth string) bool {
+ if rawAuth == "" || !strings.HasPrefix(rawAuth, "Bearer ") {
+ return false
+ }
+ tokenString := strings.TrimSpace(strings.TrimPrefix(rawAuth, "Bearer "))
+ if tokenString == "" {
+ return false
+ }
+ settings := system_setting.GetOAuth2Settings()
+ // Parse & verify
+ parsed, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
+ if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
+ return nil, jwt.ErrTokenSignatureInvalid
+ }
+ if kid, ok := t.Header["kid"].(string); ok {
+ if settings.JWTKeyID != "" && kid != settings.JWTKeyID {
+ return nil, jwt.ErrTokenSignatureInvalid
+ }
+ }
+ pub := oauth.GetRSAPublicKey()
+ if pub == nil {
+ return nil, jwt.ErrTokenUnverifiable
+ }
+ return pub, nil
+ })
+ if err != nil || parsed == nil || !parsed.Valid {
+ return false
+ }
+ claims, ok := parsed.Claims.(jwt.MapClaims)
+ if !ok {
+ return false
+ }
+ // issuer check when configured
+ if iss, ok2 := claims["iss"].(string); !ok2 || (settings.Issuer != "" && iss != settings.Issuer) {
+ return false
+ }
+ // revoke check
+ if jti, ok2 := claims["jti"].(string); ok2 && jti != "" {
+ if revoked, _ := model.IsTokenRevoked(jti); revoked {
+ return false
+ }
+ }
+ // scope check: must contain api:read or api:write or admin
+ scope, _ := claims["scope"].(string)
+ scopePadded := " " + scope + " "
+ if !(strings.Contains(scopePadded, " api:read ") || strings.Contains(scopePadded, " api:write ") || strings.Contains(scopePadded, " admin ")) {
+ return false
+ }
+ // subject must be user id to support quota logic
+ sub, _ := claims["sub"].(string)
+ uid, err := strconv.Atoi(sub)
+ if err != nil || uid <= 0 {
+ return false
+ }
+ // load user cache & set context
+ userCache, err := model.GetUserCache(uid)
+ if err != nil || userCache == nil || userCache.Status != common.UserStatusEnabled {
+ return false
+ }
+ c.Set("id", uid)
+ c.Set("group", userCache.Group)
+ c.Set("user_group", userCache.Group)
+ // set UsingGroup
+ common.SetContextKey(c, constant.ContextKeyUsingGroup, userCache.Group)
+ return true
+}
+
func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
if token == nil {
return fmt.Errorf("token is nil")
diff --git a/middleware/oauth_jwt.go b/middleware/oauth_jwt.go
index 3e8fe0c69..38cbb3fe1 100644
--- a/middleware/oauth_jwt.go
+++ b/middleware/oauth_jwt.go
@@ -7,6 +7,7 @@ import (
"one-api/common"
"one-api/model"
"one-api/setting/system_setting"
+ "one-api/src/oauth"
"strings"
"github.com/gin-gonic/gin"
@@ -108,23 +109,20 @@ func getPublicKeyByKid(kid string) (*rsa.PublicKey, error) {
// 这里先实现一个简单版本
// TODO: 实现JWKS缓存和刷新机制
- settings := system_setting.GetOAuth2Settings()
- if settings.JWTKeyID == kid {
- // 从OAuth server模块获取公钥
- // 这需要在OAuth server初始化后才能使用
- return nil, fmt.Errorf("JWKS functionality not yet implemented")
+ pub := oauth.GetPublicKeyByKid(kid)
+ if pub == nil {
+ return nil, fmt.Errorf("unknown kid: %s", kid)
}
-
- return nil, fmt.Errorf("unknown kid: %s", kid)
+ return pub, nil
}
// validateOAuthClaims 验证OAuth2 claims
func validateOAuthClaims(claims jwt.MapClaims) error {
settings := system_setting.GetOAuth2Settings()
- // 验证issuer
+ // 验证issuer(若配置了 Issuer 则强校验,否则仅要求存在)
if iss, ok := claims["iss"].(string); ok {
- if iss != settings.Issuer {
+ if settings.Issuer != "" && iss != settings.Issuer {
return fmt.Errorf("invalid issuer")
}
} else {
@@ -146,6 +144,14 @@ func validateOAuthClaims(claims jwt.MapClaims) error {
if client.Status != common.UserStatusEnabled {
return fmt.Errorf("client disabled")
}
+
+ // 检查是否被撤销
+ if jti, ok := claims["jti"].(string); ok && jti != "" {
+ revoked, _ := model.IsTokenRevoked(jti)
+ if revoked {
+ return fmt.Errorf("token revoked")
+ }
+ }
} else {
return fmt.Errorf("missing client_id claim")
}
@@ -240,6 +246,34 @@ func OptionalOAuthAuth() gin.HandlerFunc {
}
}
+// RequireOAuthScopeIfPresent enforces scope only when OAuth is present; otherwise no-op
+func RequireOAuthScopeIfPresent(requiredScope string) gin.HandlerFunc {
+ return func(c *gin.Context) {
+ if !c.GetBool("oauth_authenticated") {
+ c.Next()
+ return
+ }
+ scope, exists := c.Get("oauth_scope")
+ if !exists {
+ abortWithOAuthError(c, "insufficient_scope", "No scope in token")
+ return
+ }
+ scopeStr, ok := scope.(string)
+ if !ok {
+ abortWithOAuthError(c, "insufficient_scope", "Invalid scope format")
+ return
+ }
+ scopes := strings.Split(scopeStr, " ")
+ for _, s := range scopes {
+ if strings.TrimSpace(s) == requiredScope {
+ c.Next()
+ return
+ }
+ }
+ abortWithOAuthError(c, "insufficient_scope", fmt.Sprintf("Required scope: %s", requiredScope))
+ }
+}
+
// GetOAuthClaims 获取OAuth claims
func GetOAuthClaims(c *gin.Context) (jwt.MapClaims, bool) {
claims, exists := c.Get("oauth_claims")
diff --git a/model/oauth_revoked_token.go b/model/oauth_revoked_token.go
new file mode 100644
index 000000000..35e7d4b08
--- /dev/null
+++ b/model/oauth_revoked_token.go
@@ -0,0 +1,57 @@
+package model
+
+import (
+ "fmt"
+ "one-api/common"
+ "sync"
+ "time"
+)
+
+var revokedMem sync.Map // jti -> exp(unix)
+
+func RevokeToken(jti string, exp int64) error {
+ if jti == "" {
+ return nil
+ }
+ // Prefer Redis, else in-memory
+ if common.RedisEnabled {
+ ttl := time.Duration(0)
+ if exp > 0 {
+ ttl = time.Until(time.Unix(exp, 0))
+ }
+ if ttl <= 0 {
+ ttl = time.Minute
+ }
+ key := fmt.Sprintf("oauth:revoked:%s", jti)
+ return common.RedisSet(key, "1", ttl)
+ }
+ if exp <= 0 {
+ exp = time.Now().Add(time.Minute).Unix()
+ }
+ revokedMem.Store(jti, exp)
+ return nil
+}
+
+func IsTokenRevoked(jti string) (bool, error) {
+ if jti == "" {
+ return false, nil
+ }
+ if common.RedisEnabled {
+ key := fmt.Sprintf("oauth:revoked:%s", jti)
+ if _, err := common.RedisGet(key); err == nil {
+ return true, nil
+ } else {
+ // Not found or error; treat as not revoked on error to avoid hard failures
+ return false, nil
+ }
+ }
+ // In-memory check
+ if v, ok := revokedMem.Load(jti); ok {
+ exp, _ := v.(int64)
+ if exp == 0 || time.Now().Unix() <= exp {
+ return true, nil
+ }
+ revokedMem.Delete(jti)
+ }
+ return false, nil
+}
diff --git a/router/api-router.go b/router/api-router.go
index 9b4f61a65..9b43c6cf8 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -34,11 +34,13 @@ func SetApiRouter(router *gin.Engine) {
// OAuth2 Server endpoints
apiRouter.GET("/.well-known/jwks.json", controller.GetJWKS)
+ apiRouter.GET("/.well-known/openid-configuration", controller.OAuthOIDCConfiguration)
apiRouter.GET("/.well-known/oauth-authorization-server", controller.OAuthServerInfo)
apiRouter.POST("/oauth/token", middleware.CriticalRateLimit(), controller.OAuthTokenEndpoint)
apiRouter.GET("/oauth/authorize", controller.OAuthAuthorizeEndpoint)
apiRouter.POST("/oauth/introspect", middleware.AdminAuth(), controller.OAuthIntrospect)
apiRouter.POST("/oauth/revoke", middleware.CriticalRateLimit(), controller.OAuthRevoke)
+ apiRouter.GET("/oauth/userinfo", middleware.OAuthJWTAuth(), controller.OAuthUserInfo)
// OAuth2 管理API (前端使用)
apiRouter.GET("/oauth/jwks", controller.GetJWKS)
@@ -53,6 +55,17 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
+ // OAuth2 admin operations
+ oauthAdmin := apiRouter.Group("/oauth")
+ oauthAdmin.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.RootAuth())
+ {
+ oauthAdmin.POST("/keys/rotate", controller.RotateOAuthSigningKey)
+ oauthAdmin.GET("/keys", controller.ListOAuthSigningKeys)
+ oauthAdmin.DELETE("/keys/:kid", controller.DeleteOAuthSigningKey)
+ oauthAdmin.POST("/keys/generate_file", controller.GenerateOAuthSigningKeyFile)
+ oauthAdmin.POST("/keys/import_pem", controller.ImportOAuthSigningKey)
+ }
+
userRoute := apiRouter.Group("/user")
{
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
@@ -91,7 +104,7 @@ func SetApiRouter(router *gin.Engine) {
}
adminRoute := userRoute.Group("/")
- adminRoute.Use(middleware.AdminAuth())
+ adminRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
{
adminRoute.GET("/", controller.GetAllUsers)
adminRoute.GET("/search", controller.SearchUsers)
@@ -107,7 +120,7 @@ func SetApiRouter(router *gin.Engine) {
}
}
optionRoute := apiRouter.Group("/option")
- optionRoute.Use(middleware.RootAuth())
+ optionRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.RootAuth())
{
optionRoute.GET("/", controller.GetOptions)
optionRoute.PUT("/", controller.UpdateOption)
@@ -121,7 +134,7 @@ func SetApiRouter(router *gin.Engine) {
ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios)
}
channelRoute := apiRouter.Group("/channel")
- channelRoute.Use(middleware.AdminAuth())
+ channelRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
{
channelRoute.GET("/", controller.GetAllChannels)
channelRoute.GET("/search", controller.SearchChannels)
@@ -172,7 +185,7 @@ func SetApiRouter(router *gin.Engine) {
}
redemptionRoute := apiRouter.Group("/redemption")
- redemptionRoute.Use(middleware.AdminAuth())
+ redemptionRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
{
redemptionRoute.GET("/", controller.GetAllRedemptions)
redemptionRoute.GET("/search", controller.SearchRedemptions)
@@ -200,13 +213,13 @@ func SetApiRouter(router *gin.Engine) {
logRoute.GET("/token", controller.GetLogByKey)
}
groupRoute := apiRouter.Group("/group")
- groupRoute.Use(middleware.AdminAuth())
+ groupRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
{
groupRoute.GET("/", controller.GetGroups)
}
prefillGroupRoute := apiRouter.Group("/prefill_group")
- prefillGroupRoute.Use(middleware.AdminAuth())
+ prefillGroupRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
{
prefillGroupRoute.GET("/", controller.GetPrefillGroups)
prefillGroupRoute.POST("/", controller.CreatePrefillGroup)
diff --git a/setting/system_setting/oauth2.go b/setting/system_setting/oauth2.go
index 078fe69e9..8bcd73d77 100644
--- a/setting/system_setting/oauth2.go
+++ b/setting/system_setting/oauth2.go
@@ -3,19 +3,21 @@ package system_setting
import "one-api/setting/config"
type OAuth2Settings struct {
- Enabled bool `json:"enabled"`
- Issuer string `json:"issuer"`
- AccessTokenTTL int `json:"access_token_ttl"` // in minutes
- RefreshTokenTTL int `json:"refresh_token_ttl"` // in minutes
- AllowedGrantTypes []string `json:"allowed_grant_types"` // client_credentials, authorization_code, refresh_token
- RequirePKCE bool `json:"require_pkce"` // force PKCE for authorization code flow
- JWTSigningAlgorithm string `json:"jwt_signing_algorithm"`
- JWTKeyID string `json:"jwt_key_id"`
- JWTPrivateKeyFile string `json:"jwt_private_key_file"`
- AutoCreateUser bool `json:"auto_create_user"` // auto create user on first OAuth2 login
- DefaultUserRole int `json:"default_user_role"` // default role for auto-created users
- DefaultUserGroup string `json:"default_user_group"` // default group for auto-created users
- ScopeMappings map[string][]string `json:"scope_mappings"` // scope to permissions mapping
+ Enabled bool `json:"enabled"`
+ Issuer string `json:"issuer"`
+ AccessTokenTTL int `json:"access_token_ttl"` // in minutes
+ RefreshTokenTTL int `json:"refresh_token_ttl"` // in minutes
+ AllowedGrantTypes []string `json:"allowed_grant_types"` // client_credentials, authorization_code, refresh_token
+ RequirePKCE bool `json:"require_pkce"` // force PKCE for authorization code flow
+ JWTSigningAlgorithm string `json:"jwt_signing_algorithm"`
+ JWTKeyID string `json:"jwt_key_id"`
+ JWTPrivateKeyFile string `json:"jwt_private_key_file"`
+ AutoCreateUser bool `json:"auto_create_user"` // auto create user on first OAuth2 login
+ DefaultUserRole int `json:"default_user_role"` // default role for auto-created users
+ DefaultUserGroup string `json:"default_user_group"` // default group for auto-created users
+ ScopeMappings map[string][]string `json:"scope_mappings"` // scope to permissions mapping
+ MaxJWKSKeys int `json:"max_jwks_keys"` // maximum number of JWKS signing keys to retain
+ DefaultPrivateKeyPath string `json:"default_private_key_path"` // suggested private key file path
}
// 默认配置
@@ -35,6 +37,8 @@ var defaultOAuth2Settings = OAuth2Settings{
"api:write": {"write"},
"admin": {"admin"},
},
+ MaxJWKSKeys: 3,
+ DefaultPrivateKeyPath: "/etc/new-api/oauth2-private.pem",
}
func init() {
diff --git a/src/oauth/server.go b/src/oauth/server.go
index 65f53b121..792311463 100644
--- a/src/oauth/server.go
+++ b/src/oauth/server.go
@@ -3,17 +3,34 @@ package oauth
import (
"crypto/rand"
"crypto/rsa"
+ "errors"
"fmt"
+ "net/http"
+ "net/url"
+ "sort"
+ "strings"
+ "time"
+
"one-api/common"
+ "one-api/logger"
+ "one-api/model"
"one-api/setting/system_setting"
+ "crypto/x509"
+ "encoding/pem"
+ "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
+ jwt "github.com/golang-jwt/jwt/v5"
"github.com/lestrrat-go/jwx/v2/jwk"
+ "os"
+ "strconv"
)
var (
- simplePrivateKey *rsa.PrivateKey
- simpleJWKSSet jwk.Set
+ signingKeys = map[string]*rsa.PrivateKey{}
+ currentKeyID string
+ simpleJWKSSet jwk.Set
+ keyMeta = map[string]int64{} // kid -> created_at (unix)
)
// InitOAuthServer 简化版OAuth2服务器初始化
@@ -24,15 +41,21 @@ func InitOAuthServer() error {
return nil
}
- // 生成RSA私钥(简化版本)
+ // 生成RSA私钥,并设置当前 kid
var err error
- simplePrivateKey, err = rsa.GenerateKey(rand.Reader, 2048)
+ if settings.JWTKeyID == "" {
+ settings.JWTKeyID = "oauth2-key-1"
+ }
+ currentKeyID = settings.JWTKeyID
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return fmt.Errorf("failed to generate RSA key: %w", err)
}
+ signingKeys[currentKeyID] = key
+ keyMeta[currentKeyID] = time.Now().Unix()
- // 创建JWKS
- simpleJWKSSet, err = createSimpleJWKS(simplePrivateKey, settings.JWTKeyID)
+ // 创建JWKS,加入当前公钥
+ simpleJWKSSet, err = createSimpleJWKS(key, currentKeyID)
if err != nil {
return fmt.Errorf("failed to create JWKS: %w", err)
}
@@ -41,6 +64,35 @@ func InitOAuthServer() error {
return nil
}
+// EnsureInitialized lazily initializes signing keys and JWKS if OAuth2 is enabled but not yet ready
+func EnsureInitialized() error {
+ settings := system_setting.GetOAuth2Settings()
+ if !settings.Enabled {
+ return nil
+ }
+ if len(signingKeys) > 0 && simpleJWKSSet != nil && currentKeyID != "" {
+ return nil
+ }
+ // generate one key and JWKS on demand
+ if settings.JWTKeyID == "" {
+ settings.JWTKeyID = fmt.Sprintf("oauth2-key-%d", time.Now().Unix())
+ }
+ currentKeyID = settings.JWTKeyID
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ return err
+ }
+ signingKeys[currentKeyID] = key
+ keyMeta[currentKeyID] = time.Now().Unix()
+ jwks, err := createSimpleJWKS(key, currentKeyID)
+ if err != nil {
+ return err
+ }
+ simpleJWKSSet = jwks
+ common.SysLog("OAuth2 lazy-initialized: signing key and JWKS ready")
+ return nil
+}
+
// createSimpleJWKS 创建简单的JWKS
func createSimpleJWKS(privateKey *rsa.PrivateKey, keyID string) (jwk.Set, error) {
pubJWK, err := jwk.FromRaw(&privateKey.PublicKey)
@@ -63,18 +115,938 @@ func GetJWKS() jwk.Set {
return simpleJWKSSet
}
-// HandleTokenRequest 简化的令牌处理(临时实现)
+// GetRSAPublicKey 返回当前用于签发的RSA公钥
+func GetRSAPublicKey() *rsa.PublicKey {
+ if k, ok := signingKeys[currentKeyID]; ok && k != nil {
+ return &k.PublicKey
+ }
+ return nil
+}
+
+// GetPublicKeyByKid returns public key by kid if exists
+func GetPublicKeyByKid(kid string) *rsa.PublicKey {
+ if k, ok := signingKeys[kid]; ok && k != nil {
+ return &k.PublicKey
+ }
+ return nil
+}
+
+// RotateSigningKey generates a new RSA key, updates current kid, and adds to JWKS
+func RotateSigningKey(newKid string) (string, error) {
+ if newKid == "" {
+ newKid = fmt.Sprintf("oauth2-key-%d", time.Now().Unix())
+ }
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ return "", err
+ }
+ signingKeys[newKid] = key
+ keyMeta[newKid] = time.Now().Unix()
+ // add to jwks set
+ pubJWK, err := jwk.FromRaw(&key.PublicKey)
+ if err == nil {
+ _ = pubJWK.Set(jwk.KeyIDKey, newKid)
+ _ = pubJWK.Set(jwk.AlgorithmKey, "RS256")
+ _ = pubJWK.Set(jwk.KeyUsageKey, "sig")
+ _ = simpleJWKSSet.AddKey(pubJWK)
+ }
+ currentKeyID = newKid
+ enforceKeyRetention()
+ return newKid, nil
+}
+
+// GenerateAndPersistKey generates a new RSA key, writes to a server file, and rotates current kid
+func GenerateAndPersistKey(path string, kid string, overwrite bool) (string, error) {
+ if kid == "" {
+ kid = fmt.Sprintf("oauth2-key-%d", time.Now().Unix())
+ }
+ if _, err := os.Stat(path); err == nil && !overwrite {
+ return "", fmt.Errorf("file exists")
+ } else if err != nil && !errors.Is(err, os.ErrNotExist) {
+ return "", err
+ }
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ return "", err
+ }
+ // write PKCS1 PEM with 0600 perms
+ der := x509.MarshalPKCS1PrivateKey(key)
+ blk := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}
+ pemBytes := pem.EncodeToMemory(blk)
+ if err := os.WriteFile(path, pemBytes, 0600); err != nil {
+ return "", err
+ }
+ // rotate in memory
+ signingKeys[kid] = key
+ keyMeta[kid] = time.Now().Unix()
+ // add to jwks
+ pubJWK, err := jwk.FromRaw(&key.PublicKey)
+ if err == nil {
+ _ = pubJWK.Set(jwk.KeyIDKey, kid)
+ _ = pubJWK.Set(jwk.AlgorithmKey, "RS256")
+ _ = pubJWK.Set(jwk.KeyUsageKey, "sig")
+ _ = simpleJWKSSet.AddKey(pubJWK)
+ }
+ currentKeyID = kid
+ enforceKeyRetention()
+ return kid, nil
+}
+
+// ListSigningKeys returns metadata of keys
+type KeyInfo struct {
+ Kid string `json:"kid"`
+ CreatedAt int64 `json:"created_at"`
+ Current bool `json:"current"`
+}
+
+func ListSigningKeys() []KeyInfo {
+ out := make([]KeyInfo, 0, len(signingKeys))
+ for kid := range signingKeys {
+ out = append(out, KeyInfo{Kid: kid, CreatedAt: keyMeta[kid], Current: kid == currentKeyID})
+ }
+ // sort by CreatedAt asc
+ sort.Slice(out, func(i, j int) bool { return out[i].CreatedAt < out[j].CreatedAt })
+ return out
+}
+
+// DeleteSigningKey removes a non-current key
+func DeleteSigningKey(kid string) error {
+ if kid == "" {
+ return fmt.Errorf("kid required")
+ }
+ if kid == currentKeyID {
+ return fmt.Errorf("cannot delete current signing key")
+ }
+ if _, ok := signingKeys[kid]; !ok {
+ return fmt.Errorf("unknown kid")
+ }
+ delete(signingKeys, kid)
+ delete(keyMeta, kid)
+ rebuildJWKS()
+ return nil
+}
+
+func rebuildJWKS() {
+ jwks := jwk.NewSet()
+ for kid, k := range signingKeys {
+ pub, err := jwk.FromRaw(&k.PublicKey)
+ if err == nil {
+ _ = pub.Set(jwk.KeyIDKey, kid)
+ _ = pub.Set(jwk.AlgorithmKey, "RS256")
+ _ = pub.Set(jwk.KeyUsageKey, "sig")
+ _ = jwks.AddKey(pub)
+ }
+ }
+ simpleJWKSSet = jwks
+}
+
+func enforceKeyRetention() {
+ max := system_setting.GetOAuth2Settings().MaxJWKSKeys
+ if max <= 0 {
+ max = 1
+ }
+ // retain max most recent keys
+ infos := ListSigningKeys()
+ if len(infos) <= max {
+ return
+ }
+ // delete oldest first, skipping current
+ toDelete := len(infos) - max
+ for _, ki := range infos {
+ if toDelete == 0 {
+ break
+ }
+ if ki.Kid == currentKeyID {
+ continue
+ }
+ _ = DeleteSigningKey(ki.Kid)
+ toDelete--
+ }
+}
+
+// ImportPEMKey imports an RSA private key from PEM text and rotates current kid
+func ImportPEMKey(pemText string, kid string) (string, error) {
+ if kid == "" {
+ kid = fmt.Sprintf("oauth2-key-%d", time.Now().Unix())
+ }
+ // decode PEM
+ var block *pem.Block
+ var rest = []byte(pemText)
+ for {
+ block, rest = pem.Decode(rest)
+ if block == nil {
+ break
+ }
+ if block.Type == "RSA PRIVATE KEY" || strings.Contains(block.Type, "PRIVATE KEY") {
+ var key *rsa.PrivateKey
+ var err error
+ if block.Type == "RSA PRIVATE KEY" {
+ key, err = x509.ParsePKCS1PrivateKey(block.Bytes)
+ } else {
+ // try PKCS#8
+ priv, err2 := x509.ParsePKCS8PrivateKey(block.Bytes)
+ if err2 != nil {
+ return "", err2
+ }
+ var ok bool
+ key, ok = priv.(*rsa.PrivateKey)
+ if !ok {
+ return "", fmt.Errorf("not an RSA private key")
+ }
+ }
+ if err != nil {
+ return "", err
+ }
+ signingKeys[kid] = key
+ keyMeta[kid] = time.Now().Unix()
+ pubJWK, err := jwk.FromRaw(&key.PublicKey)
+ if err == nil {
+ _ = pubJWK.Set(jwk.KeyIDKey, kid)
+ _ = pubJWK.Set(jwk.AlgorithmKey, "RS256")
+ _ = pubJWK.Set(jwk.KeyUsageKey, "sig")
+ _ = simpleJWKSSet.AddKey(pubJWK)
+ }
+ currentKeyID = kid
+ enforceKeyRetention()
+ return kid, nil
+ }
+ if len(rest) == 0 {
+ break
+ }
+ }
+ return "", fmt.Errorf("no private key found in PEM")
+}
+
+// HandleTokenRequest 实现最小可用的令牌签发(client_credentials)
func HandleTokenRequest(c *gin.Context) {
- c.JSON(501, map[string]string{
- "error": "not_implemented",
- "error_description": "OAuth2 token endpoint not fully implemented yet",
+ settings := system_setting.GetOAuth2Settings()
+
+ grantType := strings.TrimSpace(c.PostForm("grant_type"))
+ if grantType == "" {
+ writeOAuthError(c, http.StatusBadRequest, "invalid_request", "missing grant_type")
+ return
+ }
+ if !settings.ValidateGrantType(grantType) {
+ writeOAuthError(c, http.StatusBadRequest, "unsupported_grant_type", "grant_type not allowed")
+ return
+ }
+
+ switch grantType {
+ case "client_credentials":
+ handleClientCredentials(c, settings)
+ case "refresh_token":
+ handleRefreshToken(c, settings)
+ case "authorization_code":
+ handleAuthorizationCodeExchange(c, settings)
+ default:
+ writeOAuthError(c, http.StatusBadRequest, "unsupported_grant_type", "unsupported grant_type")
+ }
+}
+
+func handleClientCredentials(c *gin.Context, settings *system_setting.OAuth2Settings) {
+ clientID, clientSecret := getFormOrBasicAuth(c)
+ if clientID == "" || clientSecret == "" {
+ writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "missing client credentials")
+ return
+ }
+
+ client, err := model.GetOAuthClientByID(clientID)
+ if err != nil {
+ writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "unknown client")
+ return
+ }
+ if client.Secret != clientSecret {
+ writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "invalid client secret")
+ return
+ }
+ // client type can be confidential or public; client_credentials only for confidential
+ if client.ClientType == "public" {
+ writeOAuthError(c, http.StatusBadRequest, "unauthorized_client", "public client cannot use client_credentials")
+ return
+ }
+ if !client.ValidateGrantType("client_credentials") {
+ writeOAuthError(c, http.StatusBadRequest, "unauthorized_client", "grant_type not enabled for client")
+ return
+ }
+
+ scope := strings.TrimSpace(c.PostForm("scope"))
+ if scope == "" {
+ // default to client's first scope or api:read
+ allowed := client.GetScopes()
+ if len(allowed) == 0 {
+ scope = "api:read"
+ } else {
+ scope = strings.Join(allowed, " ")
+ }
+ }
+ if !client.ValidateScope(scope) {
+ writeOAuthError(c, http.StatusBadRequest, "invalid_scope", "requested scope not allowed")
+ return
+ }
+
+ // issue JWT access token
+ accessTTL := time.Duration(settings.AccessTokenTTL) * time.Minute
+ tokenStr, exp, jti, err := signAccessToken(settings, clientID, "", scope, "client_credentials", accessTTL, c)
+ if err != nil {
+ writeOAuthError(c, http.StatusInternalServerError, "server_error", "failed to issue token")
+ return
+ }
+
+ // update client usage
+ _ = client.UpdateLastUsedTime()
+
+ c.JSON(http.StatusOK, gin.H{
+ "access_token": tokenStr,
+ "token_type": "Bearer",
+ "expires_in": int64(exp.Sub(time.Now()).Seconds()),
+ "scope": scope,
+ "jti": jti,
})
}
+// handleAuthorizationCodeExchange 处理授权码换取令牌
+func handleAuthorizationCodeExchange(c *gin.Context, settings *system_setting.OAuth2Settings) {
+ // Redis not required; fallback to in-memory store
+ clientID, clientSecret := getFormOrBasicAuth(c)
+ code := strings.TrimSpace(c.PostForm("code"))
+ redirectURI := strings.TrimSpace(c.PostForm("redirect_uri"))
+ codeVerifier := strings.TrimSpace(c.PostForm("code_verifier"))
+
+ if clientID == "" {
+ writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "missing client_id")
+ return
+ }
+ client, err := model.GetOAuthClientByID(clientID)
+ if err != nil {
+ writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "unknown client")
+ return
+ }
+ if client.ClientType == "confidential" {
+ if clientSecret == "" || client.Secret != clientSecret {
+ writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "invalid client secret")
+ return
+ }
+ }
+ if !client.ValidateGrantType("authorization_code") {
+ writeOAuthError(c, http.StatusBadRequest, "unauthorized_client", "authorization_code not enabled for client")
+ return
+ }
+ if redirectURI == "" || !client.ValidateRedirectURI(redirectURI) {
+ writeOAuthError(c, http.StatusBadRequest, "invalid_request", "redirect_uri mismatch or missing")
+ return
+ }
+ if code == "" {
+ writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "missing code")
+ return
+ }
+
+ // 从Redis获取授权码数据
+ key := fmt.Sprintf("oauth:code:%s", code)
+ raw, ok := storeGet(key)
+ if !ok || raw == "" {
+ writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "invalid or expired code")
+ return
+ }
+
+ // 解析:clientID|redirectURI|scope|userID|codeChallenge|codeChallengeMethod|exp[|nonce]
+ parts := strings.Split(raw, "|")
+ if len(parts) < 7 {
+ writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "malformed code payload")
+ return
+ }
+ payloadClientID := parts[0]
+ payloadRedirectURI := parts[1]
+ payloadScope := parts[2]
+ payloadUserIDStr := parts[3]
+ payloadCodeChallenge := parts[4]
+ payloadCodeChallengeMethod := parts[5]
+ // parts[6] = exp (unused here)
+ var payloadNonce string
+ if len(parts) >= 8 {
+ payloadNonce = parts[7]
+ }
+ // 单次使用:删除授权码
+ _ = storeDel(key)
+
+ if payloadClientID != clientID {
+ writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "client_id mismatch")
+ return
+ }
+ if payloadRedirectURI != redirectURI {
+ writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "redirect_uri mismatch")
+ return
+ }
+ // PKCE 校验
+ requirePKCE := settings.RequirePKCE || client.RequirePKCE
+ if requirePKCE || payloadCodeChallenge != "" {
+ if codeVerifier == "" {
+ writeOAuthError(c, http.StatusBadRequest, "invalid_request", "missing code_verifier")
+ return
+ }
+ method := strings.ToUpper(payloadCodeChallengeMethod)
+ if method == "" {
+ method = "S256"
+ }
+ switch method {
+ case "S256":
+ if s256Base64URL(codeVerifier) != payloadCodeChallenge {
+ writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "code_verifier mismatch")
+ return
+ }
+ default:
+ writeOAuthError(c, http.StatusBadRequest, "invalid_request", "unsupported code_challenge_method")
+ return
+ }
+ }
+
+ // 颁发令牌
+ scope := payloadScope
+ userIDStr := payloadUserIDStr
+ accessTTL := time.Duration(settings.AccessTokenTTL) * time.Minute
+ tokenStr, exp, jti, err := signAccessToken(settings, clientID, userIDStr, scope, "authorization_code", accessTTL, c)
+ if err != nil {
+ writeOAuthError(c, http.StatusInternalServerError, "server_error", "failed to issue token")
+ return
+ }
+
+ // 可选:签发刷新令牌(仅当允许)
+ resp := gin.H{
+ "access_token": tokenStr,
+ "token_type": "Bearer",
+ "expires_in": int64(exp.Sub(time.Now()).Seconds()),
+ "scope": scope,
+ "jti": jti,
+ }
+ // OIDC: 当 scope 包含 openid 时,签发 id_token
+ if strings.Contains(" "+scope+" ", " openid ") {
+ idt, err := signIDToken(settings, clientID, payloadUserIDStr, payloadNonce, c)
+ if err == nil {
+ resp["id_token"] = idt
+ }
+ }
+ if settings.ValidateGrantType("refresh_token") && client.ValidateGrantType("refresh_token") {
+ rt, err := genCode(32)
+ if err == nil {
+ ttl := time.Duration(settings.RefreshTokenTTL) * time.Minute
+ rtKey := fmt.Sprintf("oauth:rt:%s", rt)
+ // 存储 clientID|userID|scope|nonce(便于刷新时维持 openid/nonce)
+ val := fmt.Sprintf("%s|%s|%s|%s", clientID, userIDStr, scope, payloadNonce)
+ _ = storeSet(rtKey, val, ttl)
+ resp["refresh_token"] = rt
+ }
+ }
+
+ _ = client.UpdateLastUsedTime()
+ writeNoStore(c)
+ c.JSON(http.StatusOK, resp)
+}
+
+// handleRefreshToken 刷新令牌
+func handleRefreshToken(c *gin.Context, settings *system_setting.OAuth2Settings) {
+ // Redis not required; fallback to in-memory store
+ clientID, clientSecret := getFormOrBasicAuth(c)
+ refreshToken := strings.TrimSpace(c.PostForm("refresh_token"))
+ if clientID == "" {
+ writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "missing client_id")
+ return
+ }
+ client, err := model.GetOAuthClientByID(clientID)
+ if err != nil {
+ writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "unknown client")
+ return
+ }
+ if client.ClientType == "confidential" {
+ if clientSecret == "" || client.Secret != clientSecret {
+ writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "invalid client secret")
+ return
+ }
+ }
+ if !client.ValidateGrantType("refresh_token") {
+ writeOAuthError(c, http.StatusBadRequest, "unauthorized_client", "refresh_token not enabled for client")
+ return
+ }
+ if refreshToken == "" {
+ writeOAuthError(c, http.StatusBadRequest, "invalid_request", "missing refresh_token")
+ return
+ }
+ key := fmt.Sprintf("oauth:rt:%s", refreshToken)
+ raw, ok := storeGet(key)
+ if !ok || raw == "" {
+ writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "invalid refresh_token")
+ return
+ }
+ // 解析值:clientID|userID|scope|nonce
+ parts := strings.Split(raw, "|")
+ if len(parts) < 3 {
+ writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "malformed refresh token")
+ return
+ }
+ storedClientID := parts[0]
+ userIDStr := parts[1]
+ scope := parts[2]
+ var nonce string
+ if len(parts) >= 4 {
+ nonce = parts[3]
+ }
+ if storedClientID != clientID {
+ writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "client_id mismatch")
+ return
+ }
+
+ // 旋转refresh_token:删除旧的,签发新的
+ _ = storeDel(key)
+ newRT, err := genCode(32)
+ if err == nil {
+ ttl := time.Duration(settings.RefreshTokenTTL) * time.Minute
+ newKey := fmt.Sprintf("oauth:rt:%s", newRT)
+ _ = storeSet(newKey, raw, ttl)
+ }
+
+ // 颁发新的访问令牌
+ accessTTL := time.Duration(settings.AccessTokenTTL) * time.Minute
+ tokenStr, exp, jti, err := signAccessToken(settings, clientID, userIDStr, scope, "refresh_token", accessTTL, c)
+ if err != nil {
+ writeOAuthError(c, http.StatusInternalServerError, "server_error", "failed to issue token")
+ return
+ }
+ resp := gin.H{
+ "access_token": tokenStr,
+ "token_type": "Bearer",
+ "expires_in": int64(exp.Sub(time.Now()).Seconds()),
+ "scope": scope,
+ "jti": jti,
+ }
+ if strings.Contains(" "+scope+" ", " openid ") {
+ if idt, err := signIDToken(settings, clientID, userIDStr, nonce, c); err == nil {
+ resp["id_token"] = idt
+ }
+ }
+ if newRT != "" {
+ resp["refresh_token"] = newRT
+ }
+ writeNoStore(c)
+ c.JSON(http.StatusOK, resp)
+}
+
+// signAccessToken 使用内置RSA私钥签发JWT访问令牌
+func signAccessToken(settings *system_setting.OAuth2Settings, clientID string, subject string, scope string, grantType string, ttl time.Duration, c *gin.Context) (string, time.Time, string, error) {
+ now := time.Now()
+ exp := now.Add(ttl)
+ jti := common.GetUUID()
+ iss := settings.Issuer
+ if iss == "" {
+ // derive from requestd
+ scheme := "https"
+ if c != nil && c.Request != nil {
+ if c.Request.TLS == nil {
+ if hdr := c.Request.Header.Get("X-Forwarded-Proto"); hdr != "" {
+ scheme = hdr
+ } else {
+ scheme = "http"
+ }
+ }
+ host := c.Request.Host
+ if host != "" {
+ iss = fmt.Sprintf("%s://%s", scheme, host)
+ }
+ }
+ }
+
+ claims := jwt.MapClaims{
+ "iss": iss,
+ "sub": func() string {
+ if subject != "" {
+ return subject
+ }
+ return clientID
+ }(),
+ "aud": "one-api",
+ "iat": now.Unix(),
+ "nbf": now.Unix(),
+ "exp": exp.Unix(),
+ "scope": scope,
+ "client_id": clientID,
+ "grant_type": grantType,
+ "jti": jti,
+ }
+
+ token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
+ // set kid
+ kid := currentKeyID
+ if kid != "" {
+ token.Header["kid"] = kid
+ }
+ k := signingKeys[kid]
+ if k == nil {
+ return "", time.Time{}, "", errors.New("signing key missing")
+ }
+ signed, err := token.SignedString(k)
+ if err != nil {
+ return "", time.Time{}, "", err
+ }
+ return signed, exp, jti, nil
+}
+
+// signIDToken 生成 OIDC id_token
+func signIDToken(settings *system_setting.OAuth2Settings, clientID string, subject string, nonce string, c *gin.Context) (string, error) {
+ k := signingKeys[currentKeyID]
+ if k == nil {
+ return "", errors.New("oauth private key not initialized")
+ }
+ // derive issuer similar to access token
+ iss := settings.Issuer
+ if iss == "" && c != nil && c.Request != nil {
+ scheme := "https"
+ if c.Request.TLS == nil {
+ if hdr := c.Request.Header.Get("X-Forwarded-Proto"); hdr != "" {
+ scheme = hdr
+ } else {
+ scheme = "http"
+ }
+ }
+ host := c.Request.Host
+ if host != "" {
+ iss = fmt.Sprintf("%s://%s", scheme, host)
+ }
+ }
+ now := time.Now()
+ exp := now.Add(10 * time.Minute) // id_token 短时有效
+
+ claims := jwt.MapClaims{
+ "iss": iss,
+ "sub": subject,
+ "aud": clientID,
+ "iat": now.Unix(),
+ "exp": exp.Unix(),
+ }
+ if nonce != "" {
+ claims["nonce"] = nonce
+ }
+
+ // 可选:附加 profile / email claims 由上层根据 scope 决定
+ if uid, err := strconv.Atoi(subject); err == nil {
+ if user, err2 := model.GetUserById(uid, false); err2 == nil && user != nil {
+ if user.Username != "" {
+ claims["preferred_username"] = user.Username
+ claims["name"] = user.DisplayName
+ }
+ if user.Email != "" {
+ claims["email"] = user.Email
+ claims["email_verified"] = true
+ }
+ }
+ }
+
+ token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
+ if currentKeyID != "" {
+ token.Header["kid"] = currentKeyID
+ }
+ return token.SignedString(k)
+}
+
// HandleAuthorizeRequest 简化的授权处理(临时实现)
func HandleAuthorizeRequest(c *gin.Context) {
- c.JSON(501, map[string]string{
- "error": "not_implemented",
- "error_description": "OAuth2 authorize endpoint not fully implemented yet",
+ settings := system_setting.GetOAuth2Settings()
+ // Redis not required; fallback to in-memory store
+
+ // 解析参数
+ responseType := c.Query("response_type")
+ clientID := c.Query("client_id")
+ redirectURI := c.Query("redirect_uri")
+ scope := strings.TrimSpace(c.Query("scope"))
+ state := c.Query("state")
+ codeChallenge := c.Query("code_challenge")
+ codeChallengeMethod := strings.ToUpper(c.Query("code_challenge_method"))
+ nonce := c.Query("nonce")
+
+ if responseType == "" {
+ responseType = "code"
+ }
+ if responseType != "code" && responseType != "token" {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_response_type"})
+ return
+ }
+ if clientID == "" {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request", "error_description": "missing client_id"})
+ return
+ }
+ client, err := model.GetOAuthClientByID(clientID)
+ if err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_client"})
+ return
+ }
+ // 对于 implicit (response_type=token),允许客户端拥有 authorization_code 或 implicit 任一权限
+ if responseType == "code" {
+ if !client.ValidateGrantType("authorization_code") {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "unauthorized_client"})
+ return
+ }
+ } else {
+ if !(client.ValidateGrantType("implicit") || client.ValidateGrantType("authorization_code")) {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "unauthorized_client"})
+ return
+ }
+ }
+ // 严格匹配或本地回环地址宽松匹配(忽略端口,遵循 RFC 8252)
+ validRedirect := client.ValidateRedirectURI(redirectURI)
+ if !validRedirect {
+ if isLoopbackRedirectAllowed(redirectURI, client.GetRedirectURIs()) {
+ validRedirect = true
+ }
+ }
+ if redirectURI == "" || !validRedirect {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request", "error_description": "redirect_uri mismatch or missing"})
+ return
+ }
+
+ // 支持前端预取信息
+ mode := c.Query("mode") // mode=prepare 返回JSON供前端展示
+
+ // 校验scope
+ if scope == "" {
+ scope = strings.Join(client.GetScopes(), " ")
+ } else if !client.ValidateScope(scope) {
+ writeOAuthRedirectError(c, redirectURI, "invalid_scope", "requested scope not allowed", state)
+ return
+ }
+
+ // PKCE 要求
+ if responseType == "code" && (settings.RequirePKCE || client.RequirePKCE) {
+ if codeChallenge == "" {
+ writeOAuthRedirectError(c, redirectURI, "invalid_request", "code_challenge required", state)
+ return
+ }
+ if codeChallengeMethod == "" {
+ codeChallengeMethod = "S256"
+ }
+ if codeChallengeMethod != "S256" {
+ writeOAuthRedirectError(c, redirectURI, "invalid_request", "unsupported code_challenge_method", state)
+ return
+ }
+ }
+
+ // 检查用户会话(要求已登录)
+ sess := sessions.Default(c)
+ uidVal := sess.Get("id")
+ if uidVal == nil {
+ if mode == "prepare" {
+ c.JSON(http.StatusUnauthorized, gin.H{"error": "login_required"})
+ return
+ }
+ // 重定向到前端登录后回到同意页
+ consentPath := "/oauth/consent?" + c.Request.URL.RawQuery
+ loginPath := "/login?next=" + url.QueryEscape(consentPath)
+ writeNoStore(c)
+ c.Redirect(http.StatusFound, loginPath)
+ return
+ }
+ userID, _ := uidVal.(int)
+ if userID == 0 {
+ // 某些 session 库会将数字解码为 int64
+ if v64, ok := uidVal.(int64); ok {
+ userID = int(v64)
+ }
+ }
+ if userID == 0 {
+ writeOAuthRedirectError(c, redirectURI, "login_required", "user not logged in", state)
+ return
+ }
+
+ // prepare 模式:返回前端展示信息
+ if mode == "prepare" {
+ // 解析重定向域名
+ rHost := ""
+ if u, err := url.Parse(redirectURI); err == nil {
+ rHost = u.Hostname()
+ }
+ verified := false
+ if client.Domain != "" && rHost != "" {
+ verified = strings.EqualFold(client.Domain, rHost)
+ }
+ // scope 明细
+ scopeNames := strings.Fields(scope)
+ type scopeItem struct{ Name, Description string }
+ var scopeInfo []scopeItem
+ for _, s := range scopeNames {
+ d := ""
+ switch s {
+ case "openid":
+ d = "访问你的基础身份 (sub)"
+ case "profile":
+ d = "读取你的公开资料 (昵称/用户名)"
+ case "email":
+ d = "读取你的邮箱地址"
+ case "api:read":
+ d = "读取 API 资源"
+ case "api:write":
+ d = "写入/修改 API 资源"
+ case "admin":
+ d = "管理权限 (高危)"
+ default:
+ d = ""
+ }
+ scopeInfo = append(scopeInfo, scopeItem{Name: s, Description: d})
+ }
+ // 当前用户信息(用于展示)
+ var userName, userEmail string
+ if user, err := model.GetUserById(userID, false); err == nil && user != nil {
+ userName = user.DisplayName
+ if userName == "" {
+ userName = user.Username
+ }
+ userEmail = user.Email
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "client": gin.H{
+ "id": client.ID,
+ "name": client.Name,
+ "type": client.ClientType,
+ "desc": client.Description,
+ "domain": client.Domain,
+ },
+ "scope": scope,
+ "scope_list": scopeNames,
+ "scope_info": scopeInfo,
+ "redirect_uri": redirectURI,
+ "redirect_host": rHost,
+ "verified": verified,
+ "state": state,
+ "response_type": responseType,
+ "require_pkce": (responseType == "code") && (settings.RequirePKCE || client.RequirePKCE),
+ "user": gin.H{
+ "id": userID,
+ "name": userName,
+ "email": userEmail,
+ },
+ })
+ return
+ }
+
+ // 拒绝授权:返回错误给回调地址
+ if c.Query("deny") == "1" || strings.EqualFold(c.Query("decision"), "deny") {
+ logger.LogInfo(c, fmt.Sprintf("oauth consent denied: user=%v client=%s scope=%s redirect=%s", sess.Get("id"), clientID, scope, redirectURI))
+ writeOAuthRedirectError(c, redirectURI, "access_denied", "user denied the request", state)
+ return
+ }
+
+ // 未明确选择,跳转前端同意页
+ if !(c.Query("approve") == "1" || strings.EqualFold(c.Query("decision"), "approve")) {
+ consentPath := "/oauth/consent?" + c.Request.URL.RawQuery
+ writeNoStore(c)
+ c.Redirect(http.StatusFound, consentPath)
+ return
+ }
+
+ // 根据响应类型返回
+ if responseType == "code" {
+ // 生成授权码,写入 存储(短TTL)
+ code, err := genCode(32)
+ if err != nil {
+ writeOAuthRedirectError(c, redirectURI, "server_error", "failed to generate code", state)
+ return
+ }
+ ttl := 2 * time.Minute
+ exp := time.Now().Add(ttl).Unix()
+ // 存储 clientID|redirectURI|scope|userID|codeChallenge|codeChallengeMethod|exp|nonce
+ val := fmt.Sprintf("%s|%s|%s|%d|%s|%s|%d|%s", clientID, redirectURI, scope, userID, codeChallenge, codeChallengeMethod, exp, nonce)
+ key := fmt.Sprintf("oauth:code:%s", code)
+ if err := storeSet(key, val, ttl); err != nil {
+ writeOAuthRedirectError(c, redirectURI, "server_error", "failed to store code", state)
+ return
+ }
+ logger.LogInfo(c, fmt.Sprintf("oauth consent approved (code): user=%d client=%s scope=%s redirect=%s", userID, clientID, scope, redirectURI))
+
+ // 成功,重定向(查询参数)
+ u, _ := url.Parse(redirectURI)
+ q := u.Query()
+ q.Set("code", code)
+ if state != "" {
+ q.Set("state", state)
+ }
+ u.RawQuery = q.Encode()
+ writeNoStore(c)
+ c.Redirect(http.StatusFound, u.String())
+ return
+ }
+
+ // response_type=token (implicit)
+ // 直接签发 Access Token(不下发 Refresh Token)
+ accessTTL := time.Duration(settings.AccessTokenTTL) * time.Minute
+ userIDStr := fmt.Sprintf("%d", userID)
+ tokenStr, expTime, jti, err := signAccessToken(settings, clientID, userIDStr, scope, "implicit", accessTTL, c)
+ if err != nil {
+ writeOAuthRedirectError(c, redirectURI, "server_error", "failed to issue token", state)
+ return
+ }
+ _ = client.UpdateLastUsedTime()
+ logger.LogInfo(c, fmt.Sprintf("oauth consent approved (token): user=%d client=%s scope=%s redirect=%s jti=%s", userID, clientID, scope, redirectURI, jti))
+
+ // 使用 fragment 传递(#access_token=...)
+ u, _ := url.Parse(redirectURI)
+ frag := url.Values{}
+ frag.Set("access_token", tokenStr)
+ frag.Set("token_type", "Bearer")
+ frag.Set("expires_in", fmt.Sprintf("%d", int64(expTime.Sub(time.Now()).Seconds())))
+ if scope != "" {
+ frag.Set("scope", scope)
+ }
+ if state != "" {
+ frag.Set("state", state)
+ }
+ u.Fragment = frag.Encode()
+ writeNoStore(c)
+ c.Redirect(http.StatusFound, u.String())
+}
+
+func writeOAuthError(c *gin.Context, status int, code, description string) {
+ c.Header("Cache-Control", "no-store")
+ c.Header("Pragma", "no-cache")
+ c.JSON(status, gin.H{
+ "error": code,
+ "error_description": description,
})
}
+
+// isLoopback returns true if hostname represents a local loopback host
+func isLoopback(host string) bool {
+ if host == "" {
+ return false
+ }
+ h := strings.ToLower(host)
+ if h == "localhost" || h == "::1" {
+ return true
+ }
+ if strings.HasPrefix(h, "127.") {
+ return true
+ }
+ return false
+}
+
+// isLoopbackRedirectAllowed allows redirect URIs on loopback hosts to match ignoring port
+// This follows OAuth 2.0 for Native Apps (RFC 8252) guidance to use loopback interface with dynamic port.
+func isLoopbackRedirectAllowed(requested string, allowed []string) bool {
+ if requested == "" || len(allowed) == 0 {
+ return false
+ }
+ ru, err := url.Parse(requested)
+ if err != nil {
+ return false
+ }
+ if !isLoopback(ru.Hostname()) {
+ return false
+ }
+ for _, a := range allowed {
+ au, err := url.Parse(a)
+ if err != nil {
+ continue
+ }
+ if !isLoopback(au.Hostname()) {
+ continue
+ }
+ // require same scheme and path; ignore port and host variant among loopback
+ if strings.EqualFold(ru.Scheme, au.Scheme) && ru.Path == au.Path {
+ return true
+ }
+ }
+ return false
+}
diff --git a/src/oauth/store.go b/src/oauth/store.go
new file mode 100644
index 000000000..5d560af31
--- /dev/null
+++ b/src/oauth/store.go
@@ -0,0 +1,82 @@
+package oauth
+
+import (
+ "one-api/common"
+ "sync"
+ "time"
+)
+
+// KVStore is a minimal TTL key-value abstraction used by OAuth flows.
+type KVStore interface {
+ Set(key, value string, ttl time.Duration) error
+ Get(key string) (string, bool)
+ Del(key string) error
+}
+
+type redisStore struct{}
+
+func (r *redisStore) Set(key, value string, ttl time.Duration) error {
+ return common.RedisSet(key, value, ttl)
+}
+func (r *redisStore) Get(key string) (string, bool) {
+ v, err := common.RedisGet(key)
+ if err != nil || v == "" {
+ return "", false
+ }
+ return v, true
+}
+func (r *redisStore) Del(key string) error {
+ return common.RedisDel(key)
+}
+
+type memEntry struct {
+ val string
+ exp int64 // unix seconds, 0 means no expiry
+}
+
+type memoryStore struct {
+ m sync.Map // key -> memEntry
+}
+
+func (m *memoryStore) Set(key, value string, ttl time.Duration) error {
+ var exp int64
+ if ttl > 0 {
+ exp = time.Now().Add(ttl).Unix()
+ }
+ m.m.Store(key, memEntry{val: value, exp: exp})
+ return nil
+}
+
+func (m *memoryStore) Get(key string) (string, bool) {
+ v, ok := m.m.Load(key)
+ if !ok {
+ return "", false
+ }
+ e := v.(memEntry)
+ if e.exp > 0 && time.Now().Unix() > e.exp {
+ m.m.Delete(key)
+ return "", false
+ }
+ return e.val, true
+}
+
+func (m *memoryStore) Del(key string) error {
+ m.m.Delete(key)
+ return nil
+}
+
+var (
+ memStore = &memoryStore{}
+ rdsStore = &redisStore{}
+)
+
+func getStore() KVStore {
+ if common.RedisEnabled {
+ return rdsStore
+ }
+ return memStore
+}
+
+func storeSet(key, val string, ttl time.Duration) error { return getStore().Set(key, val, ttl) }
+func storeGet(key string) (string, bool) { return getStore().Get(key) }
+func storeDel(key string) error { return getStore().Del(key) }
diff --git a/src/oauth/util.go b/src/oauth/util.go
new file mode 100644
index 000000000..01e2aacdc
--- /dev/null
+++ b/src/oauth/util.go
@@ -0,0 +1,59 @@
+package oauth
+
+import (
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/base64"
+ "net/http"
+ "net/url"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+// getFormOrBasicAuth extracts client_id/client_secret from Basic Auth first, then form
+func getFormOrBasicAuth(c *gin.Context) (clientID, clientSecret string) {
+ id, secret, ok := c.Request.BasicAuth()
+ if ok {
+ return strings.TrimSpace(id), strings.TrimSpace(secret)
+ }
+ return strings.TrimSpace(c.PostForm("client_id")), strings.TrimSpace(c.PostForm("client_secret"))
+}
+
+// genCode generates URL-safe random string based on nBytes of entropy
+func genCode(nBytes int) (string, error) {
+ b := make([]byte, nBytes)
+ if _, err := rand.Read(b); err != nil {
+ return "", err
+ }
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// s256Base64URL computes base64url-encoded SHA256 digest
+func s256Base64URL(verifier string) string {
+ sum := sha256.Sum256([]byte(verifier))
+ return base64.RawURLEncoding.EncodeToString(sum[:])
+}
+
+// writeNoStore sets no-store cache headers for OAuth responses
+func writeNoStore(c *gin.Context) {
+ c.Header("Cache-Control", "no-store")
+ c.Header("Pragma", "no-cache")
+}
+
+// writeOAuthRedirectError builds an error redirect to redirect_uri as RFC6749
+func writeOAuthRedirectError(c *gin.Context, redirectURI, errCode, description, state string) {
+ writeNoStore(c)
+ q := "error=" + url.QueryEscape(errCode)
+ if description != "" {
+ q += "&error_description=" + url.QueryEscape(description)
+ }
+ if state != "" {
+ q += "&state=" + url.QueryEscape(state)
+ }
+ sep := "?"
+ if strings.Contains(redirectURI, "?") {
+ sep = "&"
+ }
+ c.Redirect(http.StatusFound, redirectURI+sep+q)
+}
diff --git a/web/public/oauth-demo.html b/web/public/oauth-demo.html
new file mode 100644
index 000000000..ba5821d32
--- /dev/null
+++ b/web/public/oauth-demo.html
@@ -0,0 +1,167 @@
+
+
+
+
+
+
+ OAuth2/OIDC 授权码 + PKCE 前端演示
+
+
+
+
+
OAuth2/OIDC 授权码 + PKCE 前端演示
+
+
+
+
+
+
+
提示:若未配置 Issuer,可直接填写下方端点。
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
说明:
+
+ - 本页为纯前端演示,适用于公开客户端(不需要 client_secret)。
+ - 如跨域调用 Token/UserInfo,需要服务端正确设置 CORS;建议将此 demo 部署到同源域名下。
+
+
+
+
+
+
+
+
+
+
+
+
可将服务端返回的 OIDC Discovery JSON 粘贴到此处,点击“解析并填充端点”。
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/web/src/App.jsx b/web/src/App.jsx
index 635742f91..4baf5c42b 100644
--- a/web/src/App.jsx
+++ b/web/src/App.jsx
@@ -44,6 +44,7 @@ import Task from './pages/Task';
import ModelPage from './pages/Model';
import Playground from './pages/Playground';
import OAuth2Callback from './components/auth/OAuth2Callback';
+import OAuthConsent from './pages/OAuth/Consent';
import PersonalSetting from './components/settings/PersonalSetting';
import Setup from './pages/Setup';
import SetupCheck from './components/layout/SetupCheck';
@@ -198,6 +199,14 @@ function App() {
}
/>
+ }>
+
+
+ }
+ />
.
For commercial licensing, please contact support@quantumnous.com
*/
-import React, { useState } from 'react';
+import React, { useEffect, useMemo, useState } from 'react';
import {
Modal,
Form,
@@ -40,17 +40,128 @@ const { Option } = Select;
const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => {
const [formApi, setFormApi] = useState(null);
const [loading, setLoading] = useState(false);
- const [redirectUris, setRedirectUris] = useState(['']);
+ const [redirectUris, setRedirectUris] = useState([]);
const [clientType, setClientType] = useState('confidential');
const [grantTypes, setGrantTypes] = useState(['client_credentials']);
+ const [allowedGrantTypes, setAllowedGrantTypes] = useState([
+ 'client_credentials',
+ 'authorization_code',
+ 'refresh_token',
+ ]);
+
+ // 加载后端允许的授权类型(用于限制和默认值)
+ useEffect(() => {
+ let mounted = true;
+ (async () => {
+ try {
+ const res = await API.get('/api/option/');
+ const { success, data } = res.data || {};
+ if (!success || !Array.isArray(data)) return;
+ const found = data.find((i) => i.key === 'oauth2.allowed_grant_types');
+ if (!found) return;
+ let parsed = [];
+ try {
+ parsed = JSON.parse(found.value || '[]');
+ } catch (_) {}
+ if (mounted && Array.isArray(parsed) && parsed.length) {
+ setAllowedGrantTypes(parsed);
+ }
+ } catch (_) {
+ // 忽略错误,使用默认allowedGrantTypes
+ }
+ })();
+ return () => {
+ mounted = false;
+ };
+ }, []);
+
+ const computeDefaultGrantTypes = (type, allowed) => {
+ const cand =
+ type === 'public'
+ ? ['authorization_code', 'refresh_token']
+ : ['client_credentials', 'authorization_code', 'refresh_token'];
+ const subset = cand.filter((g) => allowed.includes(g));
+ return subset.length ? subset : [allowed[0]].filter(Boolean);
+ };
+
+ // 当允许的类型或客户端类型变化时,自动设置更合理的默认值
+ useEffect(() => {
+ setGrantTypes((prev) => {
+ const normalizedPrev = Array.isArray(prev) ? prev : [];
+ // 移除不被允许或与客户端类型冲突的类型
+ let next = normalizedPrev.filter((g) => allowedGrantTypes.includes(g));
+ if (clientType === 'public') {
+ next = next.filter((g) => g !== 'client_credentials');
+ }
+ // 如果为空,则使用计算的默认
+ if (!next.length) {
+ next = computeDefaultGrantTypes(clientType, allowedGrantTypes);
+ }
+ return next;
+ });
+ }, [clientType, allowedGrantTypes]);
+
+ const isGrantTypeDisabled = (value) => {
+ if (!allowedGrantTypes.includes(value)) return true;
+ if (clientType === 'public' && value === 'client_credentials') return true;
+ return false;
+ };
+
+ // URL校验:允许 http(s),本地开发可 http
+ const isValidRedirectUri = (uri) => {
+ if (!uri || !uri.trim()) return false;
+ try {
+ const u = new URL(uri.trim());
+ if (u.protocol !== 'https:' && u.protocol !== 'http:') return false;
+ if (u.protocol === 'http:') {
+ // 仅允许本地开发时使用 http
+ const host = u.hostname;
+ const isLocal =
+ host === 'localhost' || host === '127.0.0.1' || host.endsWith('.local');
+ if (!isLocal) return false;
+ }
+ return true;
+ } catch (e) {
+ return false;
+ }
+ };
// 处理提交
const handleSubmit = async (values) => {
setLoading(true);
try {
// 过滤空的重定向URI
- const validRedirectUris = redirectUris.filter(uri => uri.trim());
-
+ const validRedirectUris = redirectUris
+ .map((u) => (u || '').trim())
+ .filter((u) => u.length > 0);
+
+ // 业务校验
+ if (!grantTypes.length) {
+ showError('请至少选择一种授权类型');
+ return;
+ }
+ // 校验是否包含不被允许的授权类型
+ const invalids = grantTypes.filter((g) => !allowedGrantTypes.includes(g));
+ if (invalids.length) {
+ showError(`不被允许的授权类型: ${invalids.join(', ')}`);
+ return;
+ }
+ if (clientType === 'public' && grantTypes.includes('client_credentials')) {
+ showError('公开客户端不允许使用client_credentials授权类型');
+ return;
+ }
+ if (grantTypes.includes('authorization_code')) {
+ if (!validRedirectUris.length) {
+ showError('选择授权码授权类型时,必须填写至少一个重定向URI');
+ return;
+ }
+ const allValid = validRedirectUris.every(isValidRedirectUri);
+ if (!allValid) {
+ showError('重定向URI格式不合法:仅支持https,或本地开发使用http');
+ return;
+ }
+ }
+
const payload = {
...values,
client_type: clientType,
@@ -118,8 +229,8 @@ const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => {
formApi.reset();
}
setClientType('confidential');
- setGrantTypes(['client_credentials']);
- setRedirectUris(['']);
+ setGrantTypes(computeDefaultGrantTypes('confidential', allowedGrantTypes));
+ setRedirectUris([]);
};
// 处理取消
@@ -149,9 +260,13 @@ const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => {
const handleGrantTypesChange = (values) => {
setGrantTypes(values);
// 如果包含authorization_code但没有重定向URI,则添加一个
- if (values.includes('authorization_code') && redirectUris.length === 1 && !redirectUris[0]) {
+ if (values.includes('authorization_code') && redirectUris.length === 0) {
setRedirectUris(['']);
}
+ // 公开客户端不允许client_credentials
+ if (clientType === 'public' && values.includes('client_credentials')) {
+ setGrantTypes(values.filter((v) => v !== 'client_credentials'));
+ }
};
return (
@@ -159,7 +274,7 @@ const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => {
title="创建OAuth2客户端"
visible={visible}
onCancel={handleCancel}
- onOk={() => formApi?.submit()}
+ onOk={() => formApi?.submitForm()}
okText="创建"
cancelText="取消"
confirmLoading={loading}
@@ -168,6 +283,12 @@ const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => {
>
PKCE(Proof Key for Code Exchange)可提高授权码流程的安全性。
{/* 重定向URI */}
- {grantTypes.includes('authorization_code') && (
+ {(grantTypes.includes('authorization_code') || redirectUris.length > 0) && (
<>
重定向URI配置
重定向URI
- 用于授权码流程,用户授权后将重定向到这些URI。必须使用HTTPS(本地开发可使用HTTP)。
+ 用于授权码流程,用户授权后将重定向到这些URI。必须使用HTTPS(本地开发可使用HTTP,仅限localhost/127.0.0.1)。
@@ -315,4 +443,4 @@ const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => {
);
};
-export default CreateOAuth2ClientModal;
\ No newline at end of file
+export default CreateOAuth2ClientModal;
diff --git a/web/src/components/modals/oauth2/EditOAuth2ClientModal.jsx b/web/src/components/modals/oauth2/EditOAuth2ClientModal.jsx
index 39729bba9..7eec45e3d 100644
--- a/web/src/components/modals/oauth2/EditOAuth2ClientModal.jsx
+++ b/web/src/components/modals/oauth2/EditOAuth2ClientModal.jsx
@@ -39,8 +39,39 @@ const { Option } = Select;
const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => {
const [formApi, setFormApi] = useState(null);
const [loading, setLoading] = useState(false);
- const [redirectUris, setRedirectUris] = useState(['']);
+ const [redirectUris, setRedirectUris] = useState([]);
const [grantTypes, setGrantTypes] = useState(['client_credentials']);
+ const [allowedGrantTypes, setAllowedGrantTypes] = useState([
+ 'client_credentials',
+ 'authorization_code',
+ 'refresh_token',
+ ]);
+
+ // 加载后端允许的授权类型
+ useEffect(() => {
+ let mounted = true;
+ (async () => {
+ try {
+ const res = await API.get('/api/option/');
+ const { success, data } = res.data || {};
+ if (!success || !Array.isArray(data)) return;
+ const found = data.find((i) => i.key === 'oauth2.allowed_grant_types');
+ if (!found) return;
+ let parsed = [];
+ try {
+ parsed = JSON.parse(found.value || '[]');
+ } catch (_) {}
+ if (mounted && Array.isArray(parsed) && parsed.length) {
+ setAllowedGrantTypes(parsed);
+ }
+ } catch (_) {
+ // 忽略错误
+ }
+ })();
+ return () => {
+ mounted = false;
+ };
+ }, []);
// 初始化表单数据
useEffect(() => {
@@ -60,9 +91,12 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => {
} else if (Array.isArray(client.scopes)) {
parsedScopes = client.scopes;
}
+ if (!parsedScopes || parsedScopes.length === 0) {
+ parsedScopes = ['openid', 'profile', 'email', 'api:read'];
+ }
// 解析重定向URI
- let parsedRedirectUris = [''];
+ let parsedRedirectUris = [];
if (client.redirect_uris) {
try {
const parsed = typeof client.redirect_uris === 'string'
@@ -76,8 +110,20 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => {
}
}
- setGrantTypes(parsedGrantTypes);
- setRedirectUris(parsedRedirectUris);
+ // 过滤不被允许或不兼容的授权类型
+ const filteredGrantTypes = (parsedGrantTypes || []).filter((g) =>
+ allowedGrantTypes.includes(g),
+ );
+ const finalGrantTypes = client.client_type === 'public'
+ ? filteredGrantTypes.filter((g) => g !== 'client_credentials')
+ : filteredGrantTypes;
+
+ setGrantTypes(finalGrantTypes);
+ if (finalGrantTypes.includes('authorization_code') && parsedRedirectUris.length === 0) {
+ setRedirectUris(['']);
+ } else {
+ setRedirectUris(parsedRedirectUris);
+ }
// 设置表单值
const formValues = {
@@ -87,7 +133,7 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => {
client_type: client.client_type,
grant_types: parsedGrantTypes,
scopes: parsedScopes,
- require_pkce: client.require_pkce,
+ require_pkce: !!client.require_pkce,
status: client.status,
};
if (formApi) {
@@ -101,7 +147,57 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => {
setLoading(true);
try {
// 过滤空的重定向URI
- const validRedirectUris = redirectUris.filter(uri => uri.trim());
+ const validRedirectUris = redirectUris
+ .map((u) => (u || '').trim())
+ .filter((u) => u.length > 0);
+
+ // 校验授权类型
+ if (!grantTypes.length) {
+ showError('请至少选择一种授权类型');
+ setLoading(false);
+ return;
+ }
+ const invalids = grantTypes.filter((g) => !allowedGrantTypes.includes(g));
+ if (invalids.length) {
+ showError(`不被允许的授权类型: ${invalids.join(', ')}`);
+ setLoading(false);
+ return;
+ }
+ if (client?.client_type === 'public' && grantTypes.includes('client_credentials')) {
+ showError('公开客户端不允许使用client_credentials授权类型');
+ setLoading(false);
+ return;
+ }
+ // 授权码需要有效重定向URI
+ const isValidRedirectUri = (uri) => {
+ if (!uri || !uri.trim()) return false;
+ try {
+ const u = new URL(uri.trim());
+ if (u.protocol !== 'https:' && u.protocol !== 'http:') return false;
+ if (u.protocol === 'http:') {
+ const host = u.hostname;
+ const isLocal =
+ host === 'localhost' || host === '127.0.0.1' || host.endsWith('.local');
+ if (!isLocal) return false;
+ }
+ return true;
+ } catch (e) {
+ return false;
+ }
+ };
+ if (grantTypes.includes('authorization_code')) {
+ if (!validRedirectUris.length) {
+ showError('选择授权码授权类型时,必须填写至少一个重定向URI');
+ setLoading(false);
+ return;
+ }
+ const allValid = validRedirectUris.every(isValidRedirectUri);
+ if (!allValid) {
+ showError('重定向URI格式不合法:仅支持https,或本地开发使用http');
+ setLoading(false);
+ return;
+ }
+ }
const payload = {
...values,
@@ -146,9 +242,13 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => {
const handleGrantTypesChange = (values) => {
setGrantTypes(values);
// 如果包含authorization_code但没有重定向URI,则添加一个
- if (values.includes('authorization_code') && redirectUris.length === 1 && !redirectUris[0]) {
+ if (values.includes('authorization_code') && redirectUris.length === 0) {
setRedirectUris(['']);
}
+ // 公开客户端不允许client_credentials
+ if (client?.client_type === 'public' && values.includes('client_credentials')) {
+ setGrantTypes(values.filter((v) => v !== 'client_credentials'));
+ }
};
if (!client) return null;
@@ -158,7 +258,7 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => {
title={`编辑OAuth2客户端 - ${client.name}`}
visible={visible}
onCancel={onCancel}
- onOk={() => formApi?.submit()}
+ onOk={() => formApi?.submitForm()}
okText="保存"
cancelText="取消"
confirmLoading={loading}
@@ -217,9 +317,17 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => {
onChange={handleGrantTypesChange}
rules={[{ required: true, message: '请选择至少一种授权类型' }]}
>
-
-
-
+
+
+
{/* Scope */}
@@ -229,6 +337,9 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => {
multiple
rules={[{ required: true, message: '请选择至少一个权限范围' }]}
>
+
+
+
@@ -254,13 +365,13 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => {
{/* 重定向URI */}
- {grantTypes.includes('authorization_code') && (
+ {(grantTypes.includes('authorization_code') || redirectUris.length > 0) && (
<>
重定向URI配置
重定向URI
- 用于授权码流程,用户授权后将重定向到这些URI。必须使用HTTPS(本地开发可使用HTTP)。
+ 用于授权码流程,用户授权后将重定向到这些URI。必须使用HTTPS(本地开发可使用HTTP,仅限localhost/127.0.0.1)。
@@ -303,4 +414,4 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => {
);
};
-export default EditOAuth2ClientModal;
\ No newline at end of file
+export default EditOAuth2ClientModal;
diff --git a/web/src/components/modals/oauth2/JWKSManagerModal.jsx b/web/src/components/modals/oauth2/JWKSManagerModal.jsx
new file mode 100644
index 000000000..ef5d3c5c6
--- /dev/null
+++ b/web/src/components/modals/oauth2/JWKSManagerModal.jsx
@@ -0,0 +1,148 @@
+import React, { useEffect, useState } from 'react';
+import { Modal, Table, Button, Space, Tag, Typography, Popconfirm, Toast, Form, TextArea, Divider, Input } from '@douyinfe/semi-ui';
+import { IconRefresh, IconDelete, IconPlay } from '@douyinfe/semi-icons';
+import { API, showError, showSuccess } from '../../../helpers';
+
+const { Text } = Typography;
+
+export default function JWKSManagerModal({ visible, onClose }) {
+ const [loading, setLoading] = useState(false);
+ const [keys, setKeys] = useState([]);
+
+ const load = async () => {
+ setLoading(true);
+ try {
+ const res = await API.get('/api/oauth/keys');
+ if (res?.data?.success) setKeys(res.data.data || []);
+ else showError(res?.data?.message || '获取密钥列表失败');
+ } catch { showError('获取密钥列表失败'); } finally { setLoading(false); }
+ };
+
+ const rotate = async () => {
+ setLoading(true);
+ try {
+ const res = await API.post('/api/oauth/keys/rotate', {});
+ if (res?.data?.success) { showSuccess('签名密钥已轮换:' + res.data.kid); await load(); }
+ else showError(res?.data?.message || '密钥轮换失败');
+ } catch { showError('密钥轮换失败'); } finally { setLoading(false); }
+ };
+
+ const del = async (kid) => {
+ setLoading(true);
+ try {
+ const res = await API.delete(`/api/oauth/keys/${kid}`);
+ if (res?.data?.success) { Toast.success('已删除:' + kid); await load(); }
+ else showError(res?.data?.message || '删除失败');
+ } catch { showError('删除失败'); } finally { setLoading(false); }
+ };
+
+ useEffect(() => { if (visible) load(); }, [visible]);
+ useEffect(() => {
+ if (!visible) return;
+ (async ()=>{
+ try{
+ const res = await API.get('/api/oauth/server-info');
+ const p = res?.data?.default_private_key_path;
+ if (p) setGenPath(p);
+ }catch{}
+ })();
+ }, [visible]);
+
+ // Import PEM state
+ const [showImport, setShowImport] = useState(false);
+ const [pem, setPem] = useState('');
+ const [customKid, setCustomKid] = useState('');
+ const importPem = async () => {
+ if (!pem.trim()) return Toast.warning('请粘贴 PEM 私钥');
+ setLoading(true);
+ try {
+ const res = await API.post('/api/oauth/keys/import_pem', { pem, kid: customKid.trim() });
+ if (res?.data?.success) {
+ Toast.success('已导入私钥并切换到 kid=' + res.data.kid);
+ setPem(''); setCustomKid(''); setShowImport(false);
+ await load();
+ } else {
+ Toast.error(res?.data?.message || '导入失败');
+ }
+ } catch { Toast.error('导入失败'); } finally { setLoading(false); }
+ };
+
+ // Generate PEM file state
+ const [showGenerate, setShowGenerate] = useState(false);
+ const [genPath, setGenPath] = useState('/etc/new-api/oauth2-private.pem');
+ const [genKid, setGenKid] = useState('');
+ const generatePemFile = async () => {
+ if (!genPath.trim()) return Toast.warning('请填写保存路径');
+ setLoading(true);
+ try {
+ const res = await API.post('/api/oauth/keys/generate_file', { path: genPath.trim(), kid: genKid.trim() });
+ if (res?.data?.success) {
+ Toast.success('已生成并生效:' + res.data.path);
+ await load();
+ } else {
+ Toast.error(res?.data?.message || '生成失败');
+ }
+ } catch { Toast.error('生成失败'); } finally { setLoading(false); }
+ };
+
+ const columns = [
+ { title: 'KID', dataIndex: 'kid', render: (kid) => {kid} },
+ { title: '创建时间', dataIndex: 'created_at', render: (ts) => (ts ? new Date(ts * 1000).toLocaleString() : '-') },
+ { title: '状态', dataIndex: 'current', render: (cur) => (cur ? 当前 : 历史) },
+ { title: '操作', render: (_, r) => (
+
+ {!r.current && (
+ del(r.kid)}>
+ } size='small' theme='borderless'>删除
+
+ )}
+
+ ) },
+ ];
+
+ return (
+
+
+ } onClick={load} loading={loading}>刷新
+ } type='primary' onClick={rotate} loading={loading}>轮换密钥
+
+
+
+
+ {showGenerate && (
+
+
+
+
+
+
+
+
+
建议:仅在合规要求下使用文件私钥。请确保目录权限安全(建议 0600),并妥善备份。
+
+ )}
+ {showImport && (
+
+
+
+
+
+
+
+
+
建议:优先使用内存签名密钥与 JWKS 轮换;仅在有合规要求时导入外部私钥。
+
+ )}
+ 暂无密钥} />
+
+ );
+}
diff --git a/web/src/components/modals/oauth2/OAuth2QuickStartModal.jsx b/web/src/components/modals/oauth2/OAuth2QuickStartModal.jsx
new file mode 100644
index 000000000..91559c580
--- /dev/null
+++ b/web/src/components/modals/oauth2/OAuth2QuickStartModal.jsx
@@ -0,0 +1,230 @@
+import React, { useEffect, useMemo, useState } from 'react';
+import { Modal, Steps, Form, Input, Select, Switch, Typography, Space, Button, Tag, Toast } from '@douyinfe/semi-ui';
+import { API, showError, showSuccess } from '../../../helpers';
+
+const { Text } = Typography;
+
+export default function OAuth2QuickStartModal({ visible, onClose, onDone }) {
+ const origin = useMemo(() => window.location.origin, []);
+ const [step, setStep] = useState(0);
+ const [loading, setLoading] = useState(false);
+
+ // Step state
+ const [enableOAuth, setEnableOAuth] = useState(true);
+ const [issuer, setIssuer] = useState(origin);
+
+ const [clientType, setClientType] = useState('public');
+ const [redirect1, setRedirect1] = useState(origin + '/oauth/oidc');
+ const [redirect2, setRedirect2] = useState('');
+ const [scopes, setScopes] = useState(['openid', 'profile', 'email', 'api:read']);
+
+ // Results
+ const [createdClient, setCreatedClient] = useState(null);
+
+ useEffect(() => {
+ if (!visible) {
+ setStep(0);
+ setLoading(false);
+ setEnableOAuth(true);
+ setIssuer(origin);
+ setClientType('public');
+ setRedirect1(origin + '/oauth/oidc');
+ setRedirect2('');
+ setScopes(['openid', 'profile', 'email', 'api:read']);
+ setCreatedClient(null);
+ }
+ }, [visible, origin]);
+
+ // 打开时读取现有配置作为默认值
+ useEffect(() => {
+ if (!visible) return;
+ (async () => {
+ try {
+ const res = await API.get('/api/option/');
+ const { success, data } = res.data || {};
+ if (!success || !Array.isArray(data)) return;
+ const map = Object.fromEntries(data.map(i => [i.key, i.value]));
+ if (typeof map['oauth2.enabled'] !== 'undefined') {
+ setEnableOAuth(String(map['oauth2.enabled']).toLowerCase() === 'true');
+ }
+ if (map['oauth2.issuer']) {
+ setIssuer(map['oauth2.issuer']);
+ }
+ } catch (_) {}
+ })();
+ }, [visible]);
+
+ const applyRecommended = async () => {
+ setLoading(true);
+ try {
+ const ops = [
+ { key: 'oauth2.enabled', value: String(enableOAuth) },
+ { key: 'oauth2.issuer', value: issuer || '' },
+ { key: 'oauth2.allowed_grant_types', value: JSON.stringify(['authorization_code', 'refresh_token', 'client_credentials']) },
+ { key: 'oauth2.require_pkce', value: 'true' },
+ { key: 'oauth2.jwt_signing_algorithm', value: 'RS256' },
+ ];
+ for (const op of ops) {
+ await API.put('/api/option/', op);
+ }
+ showSuccess('已应用推荐配置');
+ setStep(1);
+ onDone && onDone();
+ } catch (e) {
+ showError('应用推荐配置失败');
+ } finally {
+ setLoading(false);
+ }
+ };
+
+ const rotateKey = async () => {
+ setLoading(true);
+ try {
+ const res = await API.post('/api/oauth/keys/rotate', {});
+ if (res?.data?.success) {
+ showSuccess('签名密钥已准备:' + res.data.kid);
+ } else {
+ showError(res?.data?.message || '签名密钥操作失败');
+ return;
+ }
+ setStep(2);
+ } catch (e) {
+ showError('签名密钥操作失败');
+ } finally {
+ setLoading(false);
+ }
+ };
+
+ const createClient = async () => {
+ setLoading(true);
+ try {
+ const grant_types = clientType === 'public' ? ['authorization_code', 'refresh_token'] : ['authorization_code', 'refresh_token', 'client_credentials'];
+ const payload = {
+ name: 'Default OIDC Client',
+ client_type: clientType,
+ grant_types,
+ redirect_uris: [redirect1, redirect2].filter(Boolean),
+ scopes,
+ require_pkce: true,
+ };
+ const res = await API.post('/api/oauth_clients/', payload);
+ if (res?.data?.success) {
+ setCreatedClient({ id: res.data.client_id, secret: res.data.client_secret });
+ showSuccess('客户端已创建');
+ setStep(3);
+ } else {
+ showError(res?.data?.message || '创建失败');
+ }
+ onDone && onDone();
+ } catch (e) {
+ showError('创建失败');
+ } finally {
+ setLoading(false);
+ }
+ };
+
+ const steps = [
+ {
+ title: '应用推荐配置',
+ content: (
+
+
+
+
+
+
+
+
+
说明
+
+ grant_types: auth_code / refresh_token / client_credentials
+ PKCE: S256
+ 算法: RS256
+
+
+
+
+
+
+
+ )
+ },
+ {
+ title: '准备签名密钥',
+ content: (
+
+
若无密钥则初始化;如已存在建议立即轮换以生成新的 kid 并发布到 JWKS。
+
+
+
+
+ )
+ },
+ {
+ title: '创建默认 OIDC 客户端',
+ content: (
+
+
+ 公开客户端(SPA/移动端)
+ 机密客户端(服务端)
+
+
+
+
+ openid
+ profile
+ email
+ api:read
+ api:write
+ admin
+
+
+
+
+
+
+ )
+ },
+ {
+ title: '完成',
+ content: (
+
+ {createdClient ? (
+
+
客户端已创建:
+
+ Client ID: {createdClient.id}
+
+ {createdClient.secret && (
+
+ Client Secret(仅此一次展示): {createdClient.secret}
+
+ )}
+
+ ) :
已完成初始化。}
+
+ )
+ }
+ ];
+
+ return (
+
+
+ {steps.map((s, idx) => )}
+
+
+ {steps[step].content}
+
+
+ );
+}
diff --git a/web/src/components/modals/oauth2/OAuth2ToolsModal.jsx b/web/src/components/modals/oauth2/OAuth2ToolsModal.jsx
new file mode 100644
index 000000000..515954bc9
--- /dev/null
+++ b/web/src/components/modals/oauth2/OAuth2ToolsModal.jsx
@@ -0,0 +1,324 @@
+import React, { useEffect, useMemo, useState } from 'react';
+import { Modal, Form, Input, Button, Space, Select, Typography, Divider, Toast, TextArea } from '@douyinfe/semi-ui';
+import { API } from '../../../helpers';
+
+const { Text } = Typography;
+
+async function sha256Base64Url(input) {
+ const enc = new TextEncoder();
+ const data = enc.encode(input);
+ const hash = await crypto.subtle.digest('SHA-256', data);
+ const bytes = new Uint8Array(hash);
+ let binary = '';
+ for (let i = 0; i < bytes.byteLength; i++) binary += String.fromCharCode(bytes[i]);
+ return btoa(binary).replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/, '');
+}
+
+function randomString(len = 43) {
+ const charset = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~';
+ let res = '';
+ const array = new Uint32Array(len);
+ crypto.getRandomValues(array);
+ for (let i = 0; i < len; i++) res += charset[array[i] % charset.length];
+ return res;
+}
+
+export default function OAuth2ToolsModal({ visible, onClose }) {
+ const [server, setServer] = useState({});
+ const [authURL, setAuthURL] = useState('');
+ const [issuer, setIssuer] = useState('');
+ const [confJSON, setConfJSON] = useState('');
+ const [userinfoEndpoint, setUserinfoEndpoint] = useState('');
+ const [code, setCode] = useState('');
+ const [accessToken, setAccessToken] = useState('');
+ const [idToken, setIdToken] = useState('');
+ const [refreshToken, setRefreshToken] = useState('');
+ const [tokenRaw, setTokenRaw] = useState('');
+ const [jwtClaims, setJwtClaims] = useState('');
+ const [userinfoOut, setUserinfoOut] = useState('');
+ const [values, setValues] = useState({
+ authorization_endpoint: '',
+ token_endpoint: '',
+ client_id: '',
+ client_secret: '',
+ redirect_uri: window.location.origin + '/oauth/oidc',
+ scope: 'openid profile email',
+ response_type: 'code',
+ code_verifier: '',
+ code_challenge: '',
+ code_challenge_method: 'S256',
+ state: '',
+ nonce: '',
+ });
+
+ useEffect(() => {
+ if (!visible) return;
+ (async () => {
+ try {
+ const res = await API.get('/api/oauth/server-info');
+ if (res?.data) {
+ const d = res.data;
+ setServer(d);
+ setValues((v) => ({
+ ...v,
+ authorization_endpoint: d.authorization_endpoint,
+ token_endpoint: d.token_endpoint,
+ }));
+ setIssuer(d.issuer || '');
+ setUserinfoEndpoint(d.userinfo_endpoint || '');
+ }
+ } catch {}
+ })();
+ }, [visible]);
+
+ const buildAuthorizeURL = () => {
+ const u = new URL(values.authorization_endpoint || (server.issuer + '/api/oauth/authorize'));
+ const rt = values.response_type || 'code';
+ u.searchParams.set('response_type', rt);
+ u.searchParams.set('client_id', values.client_id);
+ u.searchParams.set('redirect_uri', values.redirect_uri);
+ u.searchParams.set('scope', values.scope);
+ if (values.state) u.searchParams.set('state', values.state);
+ if (values.nonce) u.searchParams.set('nonce', values.nonce);
+ if (rt === 'code' && values.code_challenge) {
+ u.searchParams.set('code_challenge', values.code_challenge);
+ u.searchParams.set('code_challenge_method', values.code_challenge_method || 'S256');
+ }
+ return u.toString();
+ };
+
+ const copy = async (text, tip = '已复制') => {
+ try { await navigator.clipboard.writeText(text); Toast.success(tip); } catch {}
+ };
+
+ const genVerifier = async () => {
+ const v = randomString(64);
+ const c = await sha256Base64Url(v);
+ setValues((val) => ({ ...val, code_verifier: v, code_challenge: c }));
+ };
+
+ const discover = async () => {
+ const iss = (issuer || '').trim();
+ if (!iss) { Toast.warning('请填写 Issuer'); return; }
+ try {
+ const url = iss.replace(/\/$/, '') + '/api/.well-known/openid-configuration';
+ const res = await fetch(url);
+ const d = await res.json();
+ setValues((v)=>({
+ ...v,
+ authorization_endpoint: d.authorization_endpoint || v.authorization_endpoint,
+ token_endpoint: d.token_endpoint || v.token_endpoint,
+ }));
+ setUserinfoEndpoint(d.userinfo_endpoint || '');
+ setIssuer(d.issuer || iss);
+ setConfJSON(JSON.stringify(d, null, 2));
+ Toast.success('已从发现文档加载端点');
+ } catch (e) {
+ Toast.error('自动发现失败');
+ }
+ };
+
+ const parseConf = () => {
+ try {
+ const d = JSON.parse(confJSON || '{}');
+ if (d.issuer) setIssuer(d.issuer);
+ if (d.authorization_endpoint) setValues((v)=>({...v, authorization_endpoint: d.authorization_endpoint}));
+ if (d.token_endpoint) setValues((v)=>({...v, token_endpoint: d.token_endpoint}));
+ if (d.userinfo_endpoint) setUserinfoEndpoint(d.userinfo_endpoint);
+ Toast.success('已解析配置并填充端点');
+ } catch (e) {
+ Toast.error('解析失败:' + e.message);
+ }
+ };
+
+ const genConf = () => {
+ const d = {
+ issuer: issuer || undefined,
+ authorization_endpoint: values.authorization_endpoint || undefined,
+ token_endpoint: values.token_endpoint || undefined,
+ userinfo_endpoint: userinfoEndpoint || undefined,
+ };
+ setConfJSON(JSON.stringify(d, null, 2));
+ };
+
+ async function postForm(url, data, basicAuth) {
+ const body = Object.entries(data)
+ .filter(([_, v]) => v !== undefined && v !== null)
+ .map(([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(String(v))}`)
+ .join('&');
+ const headers = { 'Content-Type': 'application/x-www-form-urlencoded' };
+ if (basicAuth) headers['Authorization'] = 'Basic ' + btoa(`${basicAuth.id}:${basicAuth.secret}`);
+ const res = await fetch(url, { method: 'POST', headers, body });
+ if (!res.ok) {
+ const t = await res.text();
+ throw new Error(`HTTP ${res.status} ${t}`);
+ }
+ return res.json();
+ }
+
+ const exchangeCode = async () => {
+ try {
+ const basic = values.client_secret ? { id: values.client_id, secret: values.client_secret } : undefined;
+ const data = await postForm(values.token_endpoint, {
+ grant_type: 'authorization_code',
+ code: code.trim(),
+ client_id: values.client_id,
+ redirect_uri: values.redirect_uri,
+ code_verifier: values.code_verifier,
+ }, basic);
+ setAccessToken(data.access_token || '');
+ setIdToken(data.id_token || '');
+ setRefreshToken(data.refresh_token || '');
+ setTokenRaw(JSON.stringify(data, null, 2));
+ Toast.success('已获取令牌');
+ } catch (e) {
+ Toast.error('兑换失败:' + e.message);
+ }
+ };
+
+ const decodeIdToken = () => {
+ const t = (idToken || '').trim();
+ if (!t) { setJwtClaims('(空)'); return; }
+ const parts = t.split('.');
+ if (parts.length < 2) { setJwtClaims('格式错误'); return; }
+ try {
+ const json = JSON.parse(atob(parts[1].replace(/-/g,'+').replace(/_/g,'/')));
+ setJwtClaims(JSON.stringify(json, null, 2));
+ } catch (e) {
+ setJwtClaims('解码失败:' + e);
+ }
+ };
+
+ const callUserInfo = async () => {
+ if (!accessToken || !userinfoEndpoint) { Toast.warning('缺少 AccessToken 或 UserInfo 端点'); return; }
+ try {
+ const res = await fetch(userinfoEndpoint, { headers: { Authorization: 'Bearer ' + accessToken } });
+ const data = await res.json();
+ setUserinfoOut(JSON.stringify(data, null, 2));
+ } catch (e) {
+ setUserinfoOut('调用失败:' + e);
+ }
+ };
+
+ const doRefresh = async () => {
+ if (!refreshToken) { Toast.warning('没有刷新令牌'); return; }
+ try {
+ const basic = values.client_secret ? { id: values.client_id, secret: values.client_secret } : undefined;
+ const data = await postForm(values.token_endpoint, {
+ grant_type: 'refresh_token',
+ refresh_token: refreshToken,
+ client_id: values.client_id,
+ }, basic);
+ setAccessToken(data.access_token || '');
+ setIdToken(data.id_token || '');
+ setRefreshToken(data.refresh_token || '');
+ setTokenRaw(JSON.stringify(data, null, 2));
+ Toast.success('刷新成功');
+ } catch (e) {
+ Toast.error('刷新失败:' + e.message);
+ }
+ };
+
+ return (
+ 关闭}
+ width={720}
+ style={{ top: 48 }}
+ >
+ {/* Discovery */}
+ OIDC 发现
+
+
+
+
+
+
+
+
+
+
+ {/* Authorization URL & PKCE */}
+ 授权参数
+ setValues({...values, response_type: v})}>
+ code
+ token
+
+ setValues({...values, authorization_endpoint: v})} />
+ setValues({...values, token_endpoint: v})} />
+ setValues({...values, client_id: v})} />
+ setValues({...values, client_secret: v})} />
+ setValues({...values, redirect_uri: v})} />
+ setValues({...values, scope: v})} />
+ setValues({...values, code_challenge_method: v})}>
+ S256
+
+ setValues({...values, code_verifier: v})} suffix={} />
+ setValues({...values, code_challenge: v})} />
+ setValues({...values, state: v})} suffix={} />
+ setValues({...values, nonce: v})} suffix={} />
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 提示:将上述参数粘贴到 oauthdebugger.com,或直接打开授权URL完成授权后回调。
+
+
+
+ {/* Token exchange */}
+ 令牌操作
+
+
+
+
+
+
+
+ copy(accessToken,'AccessToken已复制')}>复制} />
+ 解码} />
+
+
+
+
+
+
+ UserInfo
+
+
+
+
+
+
+
+ );
+}
diff --git a/web/src/components/settings/OAuth2Setting.jsx b/web/src/components/settings/OAuth2Setting.jsx
index 7b2bac6ae..d03d633b8 100644
--- a/web/src/components/settings/OAuth2Setting.jsx
+++ b/web/src/components/settings/OAuth2Setting.jsx
@@ -18,26 +18,18 @@ For commercial licensing, please contact support@quantumnous.com
*/
import React, { useEffect, useState } from 'react';
-import { Card, Spin } from '@douyinfe/semi-ui';
-import { API, showError, toBoolean } from '../../helpers';
+import { Card, Spin, Space, Button } from '@douyinfe/semi-ui';
+import { API, showError } from '../../helpers';
import OAuth2ServerSettings from '../../pages/Setting/OAuth2/OAuth2ServerSettings';
import OAuth2ClientSettings from '../../pages/Setting/OAuth2/OAuth2ClientSettings';
+// import OAuth2Tools from '../../pages/Setting/OAuth2/OAuth2Tools';
+import OAuth2ToolsModal from '../../components/modals/oauth2/OAuth2ToolsModal';
+import OAuth2QuickStartModal from '../../components/modals/oauth2/OAuth2QuickStartModal';
+import JWKSManagerModal from '../../components/modals/oauth2/JWKSManagerModal';
const OAuth2Setting = () => {
- const [inputs, setInputs] = useState({
- 'oauth2.enabled': false,
- 'oauth2.issuer': '',
- 'oauth2.access_token_ttl': 10,
- 'oauth2.refresh_token_ttl': 720,
- 'oauth2.jwt_signing_algorithm': 'RS256',
- 'oauth2.jwt_key_id': 'oauth2-key-1',
- 'oauth2.jwt_private_key_file': '',
- 'oauth2.allowed_grant_types': ['client_credentials', 'authorization_code'],
- 'oauth2.require_pkce': true,
- 'oauth2.auto_create_user': false,
- 'oauth2.default_user_role': 1,
- 'oauth2.default_user_group': 'default',
- });
+ // 原样保存后端 Option 键值(字符串),避免类型转换造成子组件解析错误
+ const [options, setOptions] = useState({});
const [loading, setLoading] = useState(false);
const getOptions = async () => {
@@ -46,25 +38,11 @@ const OAuth2Setting = () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
- let newInputs = {};
- data.forEach((item) => {
- if (Object.keys(inputs).includes(item.key)) {
- if (item.key === 'oauth2.allowed_grant_types') {
- try {
- newInputs[item.key] = JSON.parse(item.value || '["client_credentials","authorization_code"]');
- } catch {
- newInputs[item.key] = ['client_credentials', 'authorization_code'];
- }
- } else if (typeof inputs[item.key] === 'boolean') {
- newInputs[item.key] = toBoolean(item.value);
- } else if (typeof inputs[item.key] === 'number') {
- newInputs[item.key] = parseInt(item.value) || inputs[item.key];
- } else {
- newInputs[item.key] = item.value;
- }
- }
- });
- setInputs({...inputs, ...newInputs});
+ const map = {};
+ for (const item of data) {
+ map[item.key] = item.value;
+ }
+ setOptions(map);
} else {
showError(message);
}
@@ -83,6 +61,10 @@ const OAuth2Setting = () => {
getOptions();
}, []);
+ const [qsVisible, setQsVisible] = useState(false);
+ const [jwksVisible, setJwksVisible] = useState(false);
+ const [toolsVisible, setToolsVisible] = useState(false);
+
return (
{
marginTop: '10px',
}}
>
-
+
+
+
+
+
+
+
+
+ setQsVisible(false)} onDone={refresh} />
+ setJwksVisible(false)} />
+ setToolsVisible(false)} />
+ setJwksVisible(true)} />
);
};
-export default OAuth2Setting;
\ No newline at end of file
+export default OAuth2Setting;
diff --git a/web/src/pages/OAuth/Consent.jsx b/web/src/pages/OAuth/Consent.jsx
new file mode 100644
index 000000000..e98c525f6
--- /dev/null
+++ b/web/src/pages/OAuth/Consent.jsx
@@ -0,0 +1,199 @@
+import React, { useEffect, useMemo, useState } from 'react';
+import { Card, Button, Typography, Tag, Space, Divider, Spin, Banner, Descriptions, Avatar, Tooltip } from '@douyinfe/semi-ui';
+import { IconShield, IconTickCircle, IconClose } from '@douyinfe/semi-icons';
+import { useLocation } from 'react-router-dom';
+import { API, showError } from '../../helpers';
+
+const { Title, Text, Paragraph } = Typography;
+
+function useQuery() {
+ const { search } = useLocation();
+ return useMemo(() => new URLSearchParams(search), [search]);
+}
+
+export default function OAuthConsent() {
+ const query = useQuery();
+ const [loading, setLoading] = useState(true);
+ const [info, setInfo] = useState(null);
+ const [error, setError] = useState('');
+
+ const params = useMemo(() => {
+ const allowed = [
+ 'response_type',
+ 'client_id',
+ 'redirect_uri',
+ 'scope',
+ 'state',
+ 'code_challenge',
+ 'code_challenge_method',
+ 'nonce',
+ ];
+ const obj = {};
+ allowed.forEach((k) => {
+ const v = query.get(k);
+ if (v) obj[k] = v;
+ });
+ if (!obj.response_type) obj.response_type = 'code';
+ return obj;
+ }, [query]);
+
+ useEffect(() => {
+ (async () => {
+ setLoading(true);
+ try {
+ const res = await API.get('/api/oauth/authorize', {
+ params: { ...params, mode: 'prepare' },
+ // skip error toast, we'll handle gracefully
+ skipErrorHandler: true,
+ });
+ setInfo(res.data);
+ setError('');
+ } catch (e) {
+ // 401 login required or other error
+ setError(e?.response?.data?.error || 'failed');
+ } finally {
+ setLoading(false);
+ }
+ })();
+ }, [params]);
+
+ const onApprove = () => {
+ const u = new URL(window.location.origin + '/api/oauth/authorize');
+ Object.entries(params).forEach(([k, v]) => u.searchParams.set(k, v));
+ u.searchParams.set('approve', '1');
+ window.location.href = u.toString();
+ };
+ const onDeny = () => {
+ const u = new URL(window.location.origin + '/api/oauth/authorize');
+ Object.entries(params).forEach(([k, v]) => u.searchParams.set(k, v));
+ u.searchParams.set('deny', '1');
+ window.location.href = u.toString();
+ };
+
+ const renderScope = () => {
+ if (!info?.scope_info?.length) return (
+
+ {info?.scope_list?.map((s) => (
+ {s}
+ ))}
+
+ );
+ return (
+
+ {info.scope_info.map((s) => (
+
+ {s.Name}
+
+ ))}
+
+ );
+ };
+
+ const displayClient = () => (
+
+
+
+ {String(info?.client?.name || info?.client?.id || 'A').slice(0, 1).toUpperCase()}
+
+ {info?.client?.name || info?.client?.id}
+ {info?.verified && 已验证}
+ {info?.client?.type === 'public' && 公开客户端}
+ {info?.client?.type === 'confidential' && 机密客户端}
+
+ {info?.client?.desc && (
+
{info.client.desc}
+ )}
+
+
+ );
+
+ const displayUser = () => (
+
+ {String(info?.user?.name || 'U').slice(0,1).toUpperCase()}
+ {info?.user?.name || '当前用户'}
+ {info?.user?.email && ({info.user.email})}
+
+
+ );
+
+ return (
+
+
+
+
+
+
应用请求访问你的账户
+
请确认是否授权下列权限给第三方应用。
+
+
+
+ {loading ? (
+
+
+
+ ) : error ? (
+
+ ) : (
+ info && (
+
+
+
+
+ {displayClient()}
+ {displayUser()}
+
+ 请求的权限范围
+ {renderScope()}
+
+
+
回调地址
+
{info?.redirect_uri}
+
+
+
+
+
安全提示
+
+ - 仅在信任的网络环境中授权。
+ - 确认回调域名与申请方一致{info?.verified ? '(已验证)' : '(未验证)'}。
+ - 你可以随时在账户设置中撤销授权。
+
+
+
+
+
+
+
+
+
+
+ } onClick={onDeny} theme='borderless'>
+ 拒绝
+
+ } type='primary' onClick={onApprove}>
+ 授权
+
+
+
+ )
+ )}
+
+
+ );
+}
diff --git a/web/src/pages/Setting/OAuth2/JWKSManager.jsx b/web/src/pages/Setting/OAuth2/JWKSManager.jsx
new file mode 100644
index 000000000..0c4c75b43
--- /dev/null
+++ b/web/src/pages/Setting/OAuth2/JWKSManager.jsx
@@ -0,0 +1,123 @@
+import React, { useEffect, useState } from 'react';
+import { Card, Table, Button, Space, Tag, Typography, Popconfirm, Toast } from '@douyinfe/semi-ui';
+import { IconRefresh, IconDelete, IconPlay } from '@douyinfe/semi-icons';
+import { API, showError, showSuccess } from '../../../helpers';
+
+const { Text } = Typography;
+
+export default function JWKSManager() {
+ const [loading, setLoading] = useState(false);
+ const [keys, setKeys] = useState([]);
+
+ const load = async () => {
+ setLoading(true);
+ try {
+ const res = await API.get('/api/oauth/keys');
+ if (res?.data?.success) {
+ setKeys(res.data.data || []);
+ } else {
+ showError(res?.data?.message || '获取密钥列表失败');
+ }
+ } catch (e) {
+ showError('获取密钥列表失败');
+ } finally {
+ setLoading(false);
+ }
+ };
+
+ const rotate = async () => {
+ setLoading(true);
+ try {
+ const res = await API.post('/api/oauth/keys/rotate', {});
+ if (res?.data?.success) {
+ showSuccess('签名密钥已轮换:' + res.data.kid);
+ await load();
+ } else {
+ showError(res?.data?.message || '密钥轮换失败');
+ }
+ } catch (e) {
+ showError('密钥轮换失败');
+ } finally {
+ setLoading(false);
+ }
+ };
+
+ const del = async (kid) => {
+ setLoading(true);
+ try {
+ const res = await API.delete(`/api/oauth/keys/${kid}`);
+ if (res?.data?.success) {
+ Toast.success('已删除:' + kid);
+ await load();
+ } else {
+ showError(res?.data?.message || '删除失败');
+ }
+ } catch (e) {
+ showError('删除失败');
+ } finally {
+ setLoading(false);
+ }
+ };
+
+ useEffect(() => {
+ load();
+ }, []);
+
+ const columns = [
+ {
+ title: 'KID',
+ dataIndex: 'kid',
+ render: (kid) => {kid},
+ },
+ {
+ title: '创建时间',
+ dataIndex: 'created_at',
+ render: (ts) => (ts ? new Date(ts * 1000).toLocaleString() : '-'),
+ },
+ {
+ title: '状态',
+ dataIndex: 'current',
+ render: (cur) => (cur ? 当前 : 历史),
+ },
+ {
+ title: '操作',
+ render: (_, r) => (
+
+ {!r.current && (
+ del(r.kid)}
+ >
+ } size='small' theme='borderless'>删除
+
+ )}
+
+ ),
+ },
+ ];
+
+ return (
+
+ } onClick={load} loading={loading}>刷新
+ } type='primary' onClick={rotate} loading={loading}>轮换密钥
+
+ }
+ style={{ marginTop: 10 }}
+ >
+ 暂无密钥}
+ />
+
+ );
+}
+
diff --git a/web/src/pages/Setting/OAuth2/OAuth2ClientSettings.jsx b/web/src/pages/Setting/OAuth2/OAuth2ClientSettings.jsx
index 4afc6478e..01dab707d 100644
--- a/web/src/pages/Setting/OAuth2/OAuth2ClientSettings.jsx
+++ b/web/src/pages/Setting/OAuth2/OAuth2ClientSettings.jsx
@@ -193,20 +193,37 @@ export default function OAuth2ClientSettings() {
编辑
{record.client_type === 'confidential' && (
-
+
+
)}
+ 客户端:{record.name}
+ 删除后无法恢复,相关 API 调用将立即失效。
+
+ }
onConfirm={() => handleDelete(record)}
- okText="确定"
+ okText="确定删除"
cancelText="取消"
>