diff --git a/controller/oauth.go b/controller/oauth.go index 7d1bde7ea..1552aa10e 100644 --- a/controller/oauth.go +++ b/controller/oauth.go @@ -3,10 +3,16 @@ package controller import ( "encoding/json" "net/http" + "one-api/model" "one-api/setting/system_setting" "one-api/src/oauth" + "time" "github.com/gin-gonic/gin" + jwt "github.com/golang-jwt/jwt/v5" + "one-api/middleware" + "strconv" + "strings" ) // GetJWKS 获取JWKS公钥集 @@ -19,6 +25,9 @@ func GetJWKS(c *gin.Context) { return } + // lazy init if needed + _ = oauth.EnsureInitialized() + jwks := oauth.GetJWKS() if jwks == nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -70,7 +79,7 @@ func OAuthTokenEndpoint(c *gin.Context) { // 只允许application/x-www-form-urlencoded内容类型 contentType := c.GetHeader("Content-Type") - if contentType != "application/x-www-form-urlencoded" { + if contentType == "" || !strings.Contains(strings.ToLower(contentType), "application/x-www-form-urlencoded") { c.JSON(http.StatusBadRequest, gin.H{ "error": "invalid_request", "error_description": "Content-Type must be application/x-www-form-urlencoded", @@ -78,7 +87,11 @@ func OAuthTokenEndpoint(c *gin.Context) { return } - // 委托给OAuth2服务器处理 + // lazy init + if err := oauth.EnsureInitialized(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error", "error_description": err.Error()}) + return + } oauth.HandleTokenRequest(c) } @@ -93,7 +106,10 @@ func OAuthAuthorizeEndpoint(c *gin.Context) { return } - // 委托给OAuth2服务器处理 + if err := oauth.EnsureInitialized(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error", "error_description": err.Error()}) + return + } oauth.HandleAuthorizeRequest(c) } @@ -108,20 +124,68 @@ func OAuthServerInfo(c *gin.Context) { } // 返回OAuth2服务器的基本信息(类似OpenID Connect Discovery) + issuer := settings.Issuer + if issuer == "" { + scheme := "https" + if c.Request.TLS == nil { + if hdr := c.Request.Header.Get("X-Forwarded-Proto"); hdr != "" { + scheme = hdr + } else { + scheme = "http" + } + } + issuer = scheme + "://" + c.Request.Host + } + + base := issuer + "/api" c.JSON(http.StatusOK, gin.H{ - "issuer": settings.Issuer, - "authorization_endpoint": settings.Issuer + "/oauth/authorize", - "token_endpoint": settings.Issuer + "/oauth/token", - "jwks_uri": settings.Issuer + "/.well-known/jwks.json", + "issuer": issuer, + "authorization_endpoint": base + "/oauth/authorize", + "token_endpoint": base + "/oauth/token", + "jwks_uri": base + "/.well-known/jwks.json", "grant_types_supported": settings.AllowedGrantTypes, - "response_types_supported": []string{"code"}, + "response_types_supported": []string{"code", "token"}, "token_endpoint_auth_methods_supported": []string{"client_secret_basic", "client_secret_post"}, "code_challenge_methods_supported": []string{"S256"}, - "scopes_supported": []string{ - "api:read", - "api:write", - "admin", - }, + "scopes_supported": []string{"openid", "profile", "email", "api:read", "api:write", "admin"}, + "default_private_key_path": settings.DefaultPrivateKeyPath, + }) +} + +// OAuthOIDCConfiguration OIDC discovery document +func OAuthOIDCConfiguration(c *gin.Context) { + settings := system_setting.GetOAuth2Settings() + if !settings.Enabled { + c.JSON(http.StatusNotFound, gin.H{"error": "OAuth2 server is disabled"}) + return + } + issuer := settings.Issuer + if issuer == "" { + scheme := "https" + if c.Request.TLS == nil { + if hdr := c.Request.Header.Get("X-Forwarded-Proto"); hdr != "" { + scheme = hdr + } else { + scheme = "http" + } + } + issuer = scheme + "://" + c.Request.Host + } + base := issuer + "/api" + c.JSON(http.StatusOK, gin.H{ + "issuer": issuer, + "authorization_endpoint": base + "/oauth/authorize", + "token_endpoint": base + "/oauth/token", + "userinfo_endpoint": base + "/oauth/userinfo", + "jwks_uri": base + "/.well-known/jwks.json", + "response_types_supported": []string{"code", "token"}, + "grant_types_supported": settings.AllowedGrantTypes, + "subject_types_supported": []string{"public"}, + "id_token_signing_alg_values_supported": []string{"RS256"}, + "scopes_supported": []string{"openid", "profile", "email", "api:read", "api:write", "admin"}, + "token_endpoint_auth_methods_supported": []string{"client_secret_basic", "client_secret_post"}, + "code_challenge_methods_supported": []string{"S256"}, + "default_private_key_path": settings.DefaultPrivateKeyPath, }) } @@ -152,14 +216,50 @@ func OAuthIntrospect(c *gin.Context) { return } - // TODO: 实现令牌内省逻辑 - // 1. 验证调用者的认证信息 - // 2. 解析和验证JWT令牌 - // 3. 返回令牌的元信息 + tokenString := token - c.JSON(http.StatusOK, gin.H{ - "active": false, // 临时返回,需要实现实际的内省逻辑 + // 验证并解析JWT + parsed, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, jwt.ErrTokenSignatureInvalid + } + pub := oauth.GetPublicKeyByKid(func() string { + if v, ok := token.Header["kid"].(string); ok { + return v + } + return "" + }()) + if pub == nil { + return nil, jwt.ErrTokenUnverifiable + } + return pub, nil }) + if err != nil || !parsed.Valid { + c.JSON(http.StatusOK, gin.H{"active": false}) + return + } + + claims, ok := parsed.Claims.(jwt.MapClaims) + if !ok { + c.JSON(http.StatusOK, gin.H{"active": false}) + return + } + + // 检查撤销 + if jti, ok := claims["jti"].(string); ok && jti != "" { + if revoked, _ := model.IsTokenRevoked(jti); revoked { + c.JSON(http.StatusOK, gin.H{"active": false}) + return + } + } + + // 有效 + resp := gin.H{"active": true} + for k, v := range claims { + resp[k] = v + } + resp["token_type"] = "Bearer" + c.JSON(http.StatusOK, resp) } // OAuthRevoke 令牌撤销端点(RFC 7009) @@ -190,11 +290,86 @@ func OAuthRevoke(c *gin.Context) { return } - // TODO: 实现令牌撤销逻辑 - // 1. 验证调用者的认证信息 - // 2. 撤销指定的令牌(加入黑名单或从存储中删除) + token = c.PostForm("token") + if token == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "invalid_request", + "error_description": "Missing token parameter", + }) + return + } - c.JSON(http.StatusOK, gin.H{ - "success": true, + // 尝试解析JWT,若成功则记录jti到撤销表 + parsed, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok { + return nil, jwt.ErrTokenSignatureInvalid + } + pub := oauth.GetRSAPublicKey() + if pub == nil { + return nil, jwt.ErrTokenUnverifiable + } + return pub, nil }) + if err == nil && parsed != nil && parsed.Valid { + if claims, ok := parsed.Claims.(jwt.MapClaims); ok { + var jti string + var exp int64 + if v, ok := claims["jti"].(string); ok { + jti = v + } + if v, ok := claims["exp"].(float64); ok { + exp = int64(v) + } else if v, ok := claims["exp"].(int64); ok { + exp = v + } + if jti != "" { + // 如果没有exp,默认撤销至当前+TTL 10分钟 + if exp == 0 { + exp = time.Now().Add(10 * time.Minute).Unix() + } + _ = model.RevokeToken(jti, exp) + } + } + } + + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +// OAuthUserInfo returns OIDC userinfo based on access token +func OAuthUserInfo(c *gin.Context) { + settings := system_setting.GetOAuth2Settings() + if !settings.Enabled { + c.JSON(http.StatusNotFound, gin.H{"error": "OAuth2 server is disabled"}) + return + } + // 需要 OAuthJWTAuth 中间件注入 claims + claims, ok := middleware.GetOAuthClaims(c) + if !ok { + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_token"}) + return + } + // scope 校验:必须包含 openid + scope, _ := claims["scope"].(string) + if !strings.Contains(" "+scope+" ", " openid ") { + c.JSON(http.StatusForbidden, gin.H{"error": "insufficient_scope"}) + return + } + sub, _ := claims["sub"].(string) + resp := gin.H{"sub": sub} + // 若包含 profile/email scope,补充返回 + if strings.Contains(" "+scope+" ", " profile ") || strings.Contains(" "+scope+" ", " email ") { + if uid, err := strconv.Atoi(sub); err == nil { + if user, err2 := model.GetUserById(uid, false); err2 == nil && user != nil { + if strings.Contains(" "+scope+" ", " profile ") { + resp["name"] = user.DisplayName + resp["preferred_username"] = user.Username + } + if strings.Contains(" "+scope+" ", " email ") { + resp["email"] = user.Email + resp["email_verified"] = true + } + } + } + } + c.JSON(http.StatusOK, resp) } diff --git a/controller/oauth_keys.go b/controller/oauth_keys.go new file mode 100644 index 000000000..9e2397d3a --- /dev/null +++ b/controller/oauth_keys.go @@ -0,0 +1,89 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/logger" + "one-api/src/oauth" +) + +type rotateKeyRequest struct { + Kid string `json:"kid"` +} + +type genKeyFileRequest struct { + Path string `json:"path"` + Kid string `json:"kid"` + Overwrite bool `json:"overwrite"` +} + +type importPemRequest struct { + Pem string `json:"pem"` + Kid string `json:"kid"` +} + +// RotateOAuthSigningKey rotates the OAuth2 JWT signing key (Root only) +func RotateOAuthSigningKey(c *gin.Context) { + var req rotateKeyRequest + _ = c.BindJSON(&req) + kid, err := oauth.RotateSigningKey(req.Kid) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()}) + return + } + logger.LogInfo(c, "oauth signing key rotated: "+kid) + c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid}) +} + +// ListOAuthSigningKeys returns current and historical JWKS signing keys +func ListOAuthSigningKeys(c *gin.Context) { + keys := oauth.ListSigningKeys() + c.JSON(http.StatusOK, gin.H{"success": true, "data": keys}) +} + +// DeleteOAuthSigningKey deletes a non-current key by kid +func DeleteOAuthSigningKey(c *gin.Context) { + kid := c.Param("kid") + if kid == "" { + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "kid required"}) + return + } + if err := oauth.DeleteSigningKey(kid); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) + return + } + logger.LogInfo(c, "oauth signing key deleted: "+kid) + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +// GenerateOAuthSigningKeyFile generates a private key file and rotates current kid +func GenerateOAuthSigningKeyFile(c *gin.Context) { + var req genKeyFileRequest + if err := c.ShouldBindJSON(&req); err != nil || req.Path == "" { + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "path required"}) + return + } + kid, err := oauth.GenerateAndPersistKey(req.Path, req.Kid, req.Overwrite) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()}) + return + } + logger.LogInfo(c, "oauth signing key generated to file: "+req.Path+" kid="+kid) + c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid, "path": req.Path}) +} + +// ImportOAuthSigningKey imports PEM text and rotates current kid +func ImportOAuthSigningKey(c *gin.Context) { + var req importPemRequest + if err := c.ShouldBindJSON(&req); err != nil || req.Pem == "" { + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "pem required"}) + return + } + kid, err := oauth.ImportPEMKey(req.Pem, req.Kid) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) + return + } + logger.LogInfo(c, "oauth signing key imported from PEM, kid="+kid) + c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid}) +} diff --git a/examples/oauth/oauth-demo.html b/examples/oauth/oauth-demo.html new file mode 100644 index 000000000..210e0d254 --- /dev/null +++ b/examples/oauth/oauth-demo.html @@ -0,0 +1,326 @@ + + + + + + OAuth2/OIDC 授权码 + PKCE 前端演示 + + + +
+

OAuth2/OIDC 授权码 + PKCE 前端演示

+ +
+
+
+ + +
+
提示:若未配置 Issuer,可直接填写下方端点。
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ + + + + +
+
+
+ + +
+
+
+
+
说明: +
    +
  • 本页为纯前端演示,适用于公开客户端(不需要 client_secret)。
  • +
  • 如跨域调用 Token/UserInfo,需要服务端正确设置 CORS;建议将此 demo 部署到同源域名下。
  • +
+
+ +
+
+
+ + +
+ + +
+
可将服务端返回的 OIDC Discovery JSON 粘贴到此处,点击“解析并填充端点”。
+
+
+
+ +
+
+
+ +
等待授权...
+
+
+
+
+ + +
+ + +
+
+
+
+ + +
+ +
+

+        
+
+
+
+ + +
+ +
+
+
+ + +
+
+
+
+ + + + diff --git a/examples/oauth/oauth2_test_client.go b/examples/oauth/oauth2_test_client.go new file mode 100644 index 000000000..d8e6dd239 --- /dev/null +++ b/examples/oauth/oauth2_test_client.go @@ -0,0 +1,181 @@ +package main + +import ( + "context" + "fmt" + "io" + "log" + "net/http" + "net/url" + "path" + "strings" + "time" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" +) + +func main() { + // 测试 Client Credentials 流程 + //testClientCredentials() + + // 测试 Authorization Code + PKCE 流程(需要浏览器交互) + testAuthorizationCode() +} + +// testClientCredentials 测试服务对服务认证 +func testClientCredentials() { + fmt.Println("=== Testing Client Credentials Flow ===") + + cfg := clientcredentials.Config{ + ClientID: "client_dsFyyoyNZWjhbNa2", // 需要先创建客户端 + ClientSecret: "hLLdn2Ia4UM7hcsJaSuUFDV0Px9BrkNq", + TokenURL: "http://localhost:3000/api/oauth/token", + Scopes: []string{"api:read", "api:write"}, + EndpointParams: map[string][]string{ + "audience": {"api://new-api"}, + }, + } + + // 创建HTTP客户端 + httpClient := cfg.Client(context.Background()) + + // 调用受保护的API + resp, err := httpClient.Get("http://localhost:3000/api/status") + if err != nil { + log.Printf("Request failed: %v", err) + return + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Printf("Failed to read response: %v", err) + return + } + + fmt.Printf("Status: %s\n", resp.Status) + fmt.Printf("Response: %s\n", string(body)) +} + +// testAuthorizationCode 测试授权码流程 +func testAuthorizationCode() { + fmt.Println("=== Testing Authorization Code + PKCE Flow ===") + + conf := oauth2.Config{ + ClientID: "client_dsFyyoyNZWjhbNa2", // 需要先创建客户端 + ClientSecret: "JHiugKf89OMmTLuZMZyA2sgZnO0Ioae3", + RedirectURL: "http://localhost:9999/callback", + // 包含 openid/profile/email 以便调用 UserInfo + Scopes: []string{"openid", "profile", "email", "api:read"}, + Endpoint: oauth2.Endpoint{ + AuthURL: "http://localhost:3000/api/oauth/authorize", + TokenURL: "http://localhost:3000/api/oauth/token", + }, + } + + // 生成PKCE参数 + codeVerifier := oauth2.GenerateVerifier() + state := fmt.Sprintf("state-%d", time.Now().Unix()) + + // 构建授权URL + url := conf.AuthCodeURL( + state, + oauth2.S256ChallengeOption(codeVerifier), + //oauth2.SetAuthURLParam("audience", "api://new-api"), + ) + + fmt.Printf("Visit this URL to authorize:\n%s\n\n", url) + fmt.Printf("A local server will listen on http://localhost:9999/callback to receive the code...\n") + + // 启动回调本地服务器,自动接收授权码 + codeCh := make(chan string, 1) + srv := &http.Server{Addr: ":9999"} + http.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + if errParam := q.Get("error"); errParam != "" { + fmt.Fprintf(w, "Authorization failed: %s", errParam) + return + } + gotState := q.Get("state") + if gotState != state { + http.Error(w, "state mismatch", http.StatusBadRequest) + return + } + code := q.Get("code") + if code == "" { + http.Error(w, "missing code", http.StatusBadRequest) + return + } + fmt.Fprintln(w, "Authorization received. You may close this window.") + select { + case codeCh <- code: + default: + } + go func() { + // 稍后关闭服务 + _ = srv.Shutdown(context.Background()) + }() + }) + go func() { + _ = srv.ListenAndServe() + }() + + // 等待授权码 + var code string + select { + case code = <-codeCh: + case <-time.After(5 * time.Minute): + log.Println("Timeout waiting for authorization code") + _ = srv.Shutdown(context.Background()) + return + } + + // 交换令牌 + token, err := conf.Exchange( + context.Background(), + code, + oauth2.VerifierOption(codeVerifier), + ) + if err != nil { + log.Printf("Token exchange failed: %v", err) + return + } + + fmt.Printf("Access Token: %s\n", token.AccessToken) + fmt.Printf("Token Type: %s\n", token.TokenType) + fmt.Printf("Expires In: %v\n", token.Expiry) + + // 使用令牌调用 UserInfo + client := conf.Client(context.Background(), token) + userInfoURL := buildUserInfoFromAuth(conf.Endpoint.AuthURL) + resp, err := client.Get(userInfoURL) + if err != nil { + log.Printf("UserInfo request failed: %v", err) + return + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Printf("Failed to read UserInfo response: %v", err) + return + } + fmt.Printf("UserInfo: %s\n", string(body)) +} + +// buildUserInfoFromAuth 将授权端点URL转换为UserInfo端点URL +func buildUserInfoFromAuth(auth string) string { + u, err := url.Parse(auth) + if err != nil { + return "" + } + // 将最后一个路径段 authorize 替换为 userinfo + dir := path.Dir(u.Path) + if strings.HasSuffix(u.Path, "/authorize") { + u.Path = path.Join(dir, "userinfo") + } else { + // 回退:追加默认 /oauth/userinfo + u.Path = path.Join(dir, "userinfo") + } + return u.String() +} diff --git a/examples/oauth2_test_client.go b/examples/oauth2_test_client.go deleted file mode 100644 index 30a6ac233..000000000 --- a/examples/oauth2_test_client.go +++ /dev/null @@ -1,125 +0,0 @@ -package main - -import ( - "context" - "fmt" - "io" - "log" - "net/http" - - "golang.org/x/oauth2" - "golang.org/x/oauth2/clientcredentials" -) - -func main() { - // 测试 Client Credentials 流程 - testClientCredentials() - - // 测试 Authorization Code + PKCE 流程(需要浏览器交互) - // testAuthorizationCode() -} - -// testClientCredentials 测试服务对服务认证 -func testClientCredentials() { - fmt.Println("=== Testing Client Credentials Flow ===") - - cfg := clientcredentials.Config{ - ClientID: "client_demo123456789", // 需要先创建客户端 - ClientSecret: "demo_secret_32_chars_long_123456", - TokenURL: "http://127.0.0.1:8080/api/oauth/token", - Scopes: []string{"api:read", "api:write"}, - EndpointParams: map[string][]string{ - "audience": {"api://new-api"}, - }, - } - - // 创建HTTP客户端 - httpClient := cfg.Client(context.Background()) - - // 调用受保护的API - resp, err := httpClient.Get("http://127.0.0.1:8080/api/status") - if err != nil { - log.Printf("Request failed: %v", err) - return - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - log.Printf("Failed to read response: %v", err) - return - } - - fmt.Printf("Status: %s\n", resp.Status) - fmt.Printf("Response: %s\n", string(body)) -} - -// testAuthorizationCode 测试授权码流程 -func testAuthorizationCode() { - fmt.Println("=== Testing Authorization Code + PKCE Flow ===") - - conf := oauth2.Config{ - ClientID: "client_web123456789", // Web客户端 - ClientSecret: "web_secret_32_chars_long_123456", - RedirectURL: "http://localhost:9999/callback", - Scopes: []string{"api:read", "api:write"}, - Endpoint: oauth2.Endpoint{ - AuthURL: "http://127.0.0.1:8080/api/oauth/authorize", - TokenURL: "http://127.0.0.1:8080/api/oauth/token", - }, - } - - // 生成PKCE参数 - codeVerifier := oauth2.GenerateVerifier() - - // 构建授权URL - url := conf.AuthCodeURL( - "random-state-string", - oauth2.S256ChallengeOption(codeVerifier), - oauth2.SetAuthURLParam("audience", "api://new-api"), - ) - - fmt.Printf("Visit this URL to authorize:\n%s\n\n", url) - fmt.Printf("After authorization, you'll get a code. Use it to exchange for tokens.\n") - - // 在实际应用中,这里需要启动一个HTTP服务器来接收回调 - // 或者手动输入从回调URL中获取的授权码 - - fmt.Print("Enter the authorization code: ") - var code string - fmt.Scanln(&code) - - if code != "" { - // 交换令牌 - token, err := conf.Exchange( - context.Background(), - code, - oauth2.VerifierOption(codeVerifier), - ) - if err != nil { - log.Printf("Token exchange failed: %v", err) - return - } - - fmt.Printf("Access Token: %s\n", token.AccessToken) - fmt.Printf("Token Type: %s\n", token.TokenType) - fmt.Printf("Expires In: %v\n", token.Expiry) - - // 使用令牌调用API - client := conf.Client(context.Background(), token) - resp, err := client.Get("http://127.0.0.1:8080/api/status") - if err != nil { - log.Printf("API request failed: %v", err) - return - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - log.Printf("Failed to read response: %v", err) - return - } - - fmt.Printf("API Response: %s\n", string(body)) - } -} diff --git a/middleware/auth.go b/middleware/auth.go index 25caf50d9..eaf73998c 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -8,11 +8,14 @@ import ( "one-api/model" "one-api/setting" "one-api/setting/ratio_setting" + "one-api/setting/system_setting" + "one-api/src/oauth" "strconv" "strings" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + jwt "github.com/golang-jwt/jwt/v5" ) func validUserInfo(username string, role int) bool { @@ -177,6 +180,7 @@ func WssAuth(c *gin.Context) { func TokenAuth() func(c *gin.Context) { return func(c *gin.Context) { + rawAuth := c.Request.Header.Get("Authorization") // 先检测是否为ws if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" { // Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1 @@ -235,6 +239,11 @@ func TokenAuth() func(c *gin.Context) { } } if err != nil { + // OAuth Bearer fallback + if tryOAuthBearer(c, rawAuth) { + c.Next() + return + } abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) return } @@ -288,6 +297,74 @@ func TokenAuth() func(c *gin.Context) { } } +// tryOAuthBearer validates an OAuth JWT access token and sets minimal context for relay +func tryOAuthBearer(c *gin.Context, rawAuth string) bool { + if rawAuth == "" || !strings.HasPrefix(rawAuth, "Bearer ") { + return false + } + tokenString := strings.TrimSpace(strings.TrimPrefix(rawAuth, "Bearer ")) + if tokenString == "" { + return false + } + settings := system_setting.GetOAuth2Settings() + // Parse & verify + parsed, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok { + return nil, jwt.ErrTokenSignatureInvalid + } + if kid, ok := t.Header["kid"].(string); ok { + if settings.JWTKeyID != "" && kid != settings.JWTKeyID { + return nil, jwt.ErrTokenSignatureInvalid + } + } + pub := oauth.GetRSAPublicKey() + if pub == nil { + return nil, jwt.ErrTokenUnverifiable + } + return pub, nil + }) + if err != nil || parsed == nil || !parsed.Valid { + return false + } + claims, ok := parsed.Claims.(jwt.MapClaims) + if !ok { + return false + } + // issuer check when configured + if iss, ok2 := claims["iss"].(string); !ok2 || (settings.Issuer != "" && iss != settings.Issuer) { + return false + } + // revoke check + if jti, ok2 := claims["jti"].(string); ok2 && jti != "" { + if revoked, _ := model.IsTokenRevoked(jti); revoked { + return false + } + } + // scope check: must contain api:read or api:write or admin + scope, _ := claims["scope"].(string) + scopePadded := " " + scope + " " + if !(strings.Contains(scopePadded, " api:read ") || strings.Contains(scopePadded, " api:write ") || strings.Contains(scopePadded, " admin ")) { + return false + } + // subject must be user id to support quota logic + sub, _ := claims["sub"].(string) + uid, err := strconv.Atoi(sub) + if err != nil || uid <= 0 { + return false + } + // load user cache & set context + userCache, err := model.GetUserCache(uid) + if err != nil || userCache == nil || userCache.Status != common.UserStatusEnabled { + return false + } + c.Set("id", uid) + c.Set("group", userCache.Group) + c.Set("user_group", userCache.Group) + // set UsingGroup + common.SetContextKey(c, constant.ContextKeyUsingGroup, userCache.Group) + return true +} + func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error { if token == nil { return fmt.Errorf("token is nil") diff --git a/middleware/oauth_jwt.go b/middleware/oauth_jwt.go index 3e8fe0c69..38cbb3fe1 100644 --- a/middleware/oauth_jwt.go +++ b/middleware/oauth_jwt.go @@ -7,6 +7,7 @@ import ( "one-api/common" "one-api/model" "one-api/setting/system_setting" + "one-api/src/oauth" "strings" "github.com/gin-gonic/gin" @@ -108,23 +109,20 @@ func getPublicKeyByKid(kid string) (*rsa.PublicKey, error) { // 这里先实现一个简单版本 // TODO: 实现JWKS缓存和刷新机制 - settings := system_setting.GetOAuth2Settings() - if settings.JWTKeyID == kid { - // 从OAuth server模块获取公钥 - // 这需要在OAuth server初始化后才能使用 - return nil, fmt.Errorf("JWKS functionality not yet implemented") + pub := oauth.GetPublicKeyByKid(kid) + if pub == nil { + return nil, fmt.Errorf("unknown kid: %s", kid) } - - return nil, fmt.Errorf("unknown kid: %s", kid) + return pub, nil } // validateOAuthClaims 验证OAuth2 claims func validateOAuthClaims(claims jwt.MapClaims) error { settings := system_setting.GetOAuth2Settings() - // 验证issuer + // 验证issuer(若配置了 Issuer 则强校验,否则仅要求存在) if iss, ok := claims["iss"].(string); ok { - if iss != settings.Issuer { + if settings.Issuer != "" && iss != settings.Issuer { return fmt.Errorf("invalid issuer") } } else { @@ -146,6 +144,14 @@ func validateOAuthClaims(claims jwt.MapClaims) error { if client.Status != common.UserStatusEnabled { return fmt.Errorf("client disabled") } + + // 检查是否被撤销 + if jti, ok := claims["jti"].(string); ok && jti != "" { + revoked, _ := model.IsTokenRevoked(jti) + if revoked { + return fmt.Errorf("token revoked") + } + } } else { return fmt.Errorf("missing client_id claim") } @@ -240,6 +246,34 @@ func OptionalOAuthAuth() gin.HandlerFunc { } } +// RequireOAuthScopeIfPresent enforces scope only when OAuth is present; otherwise no-op +func RequireOAuthScopeIfPresent(requiredScope string) gin.HandlerFunc { + return func(c *gin.Context) { + if !c.GetBool("oauth_authenticated") { + c.Next() + return + } + scope, exists := c.Get("oauth_scope") + if !exists { + abortWithOAuthError(c, "insufficient_scope", "No scope in token") + return + } + scopeStr, ok := scope.(string) + if !ok { + abortWithOAuthError(c, "insufficient_scope", "Invalid scope format") + return + } + scopes := strings.Split(scopeStr, " ") + for _, s := range scopes { + if strings.TrimSpace(s) == requiredScope { + c.Next() + return + } + } + abortWithOAuthError(c, "insufficient_scope", fmt.Sprintf("Required scope: %s", requiredScope)) + } +} + // GetOAuthClaims 获取OAuth claims func GetOAuthClaims(c *gin.Context) (jwt.MapClaims, bool) { claims, exists := c.Get("oauth_claims") diff --git a/model/oauth_revoked_token.go b/model/oauth_revoked_token.go new file mode 100644 index 000000000..35e7d4b08 --- /dev/null +++ b/model/oauth_revoked_token.go @@ -0,0 +1,57 @@ +package model + +import ( + "fmt" + "one-api/common" + "sync" + "time" +) + +var revokedMem sync.Map // jti -> exp(unix) + +func RevokeToken(jti string, exp int64) error { + if jti == "" { + return nil + } + // Prefer Redis, else in-memory + if common.RedisEnabled { + ttl := time.Duration(0) + if exp > 0 { + ttl = time.Until(time.Unix(exp, 0)) + } + if ttl <= 0 { + ttl = time.Minute + } + key := fmt.Sprintf("oauth:revoked:%s", jti) + return common.RedisSet(key, "1", ttl) + } + if exp <= 0 { + exp = time.Now().Add(time.Minute).Unix() + } + revokedMem.Store(jti, exp) + return nil +} + +func IsTokenRevoked(jti string) (bool, error) { + if jti == "" { + return false, nil + } + if common.RedisEnabled { + key := fmt.Sprintf("oauth:revoked:%s", jti) + if _, err := common.RedisGet(key); err == nil { + return true, nil + } else { + // Not found or error; treat as not revoked on error to avoid hard failures + return false, nil + } + } + // In-memory check + if v, ok := revokedMem.Load(jti); ok { + exp, _ := v.(int64) + if exp == 0 || time.Now().Unix() <= exp { + return true, nil + } + revokedMem.Delete(jti) + } + return false, nil +} diff --git a/router/api-router.go b/router/api-router.go index 9b4f61a65..9b43c6cf8 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -34,11 +34,13 @@ func SetApiRouter(router *gin.Engine) { // OAuth2 Server endpoints apiRouter.GET("/.well-known/jwks.json", controller.GetJWKS) + apiRouter.GET("/.well-known/openid-configuration", controller.OAuthOIDCConfiguration) apiRouter.GET("/.well-known/oauth-authorization-server", controller.OAuthServerInfo) apiRouter.POST("/oauth/token", middleware.CriticalRateLimit(), controller.OAuthTokenEndpoint) apiRouter.GET("/oauth/authorize", controller.OAuthAuthorizeEndpoint) apiRouter.POST("/oauth/introspect", middleware.AdminAuth(), controller.OAuthIntrospect) apiRouter.POST("/oauth/revoke", middleware.CriticalRateLimit(), controller.OAuthRevoke) + apiRouter.GET("/oauth/userinfo", middleware.OAuthJWTAuth(), controller.OAuthUserInfo) // OAuth2 管理API (前端使用) apiRouter.GET("/oauth/jwks", controller.GetJWKS) @@ -53,6 +55,17 @@ func SetApiRouter(router *gin.Engine) { apiRouter.POST("/stripe/webhook", controller.StripeWebhook) + // OAuth2 admin operations + oauthAdmin := apiRouter.Group("/oauth") + oauthAdmin.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.RootAuth()) + { + oauthAdmin.POST("/keys/rotate", controller.RotateOAuthSigningKey) + oauthAdmin.GET("/keys", controller.ListOAuthSigningKeys) + oauthAdmin.DELETE("/keys/:kid", controller.DeleteOAuthSigningKey) + oauthAdmin.POST("/keys/generate_file", controller.GenerateOAuthSigningKeyFile) + oauthAdmin.POST("/keys/import_pem", controller.ImportOAuthSigningKey) + } + userRoute := apiRouter.Group("/user") { userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register) @@ -91,7 +104,7 @@ func SetApiRouter(router *gin.Engine) { } adminRoute := userRoute.Group("/") - adminRoute.Use(middleware.AdminAuth()) + adminRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth()) { adminRoute.GET("/", controller.GetAllUsers) adminRoute.GET("/search", controller.SearchUsers) @@ -107,7 +120,7 @@ func SetApiRouter(router *gin.Engine) { } } optionRoute := apiRouter.Group("/option") - optionRoute.Use(middleware.RootAuth()) + optionRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.RootAuth()) { optionRoute.GET("/", controller.GetOptions) optionRoute.PUT("/", controller.UpdateOption) @@ -121,7 +134,7 @@ func SetApiRouter(router *gin.Engine) { ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios) } channelRoute := apiRouter.Group("/channel") - channelRoute.Use(middleware.AdminAuth()) + channelRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth()) { channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/search", controller.SearchChannels) @@ -172,7 +185,7 @@ func SetApiRouter(router *gin.Engine) { } redemptionRoute := apiRouter.Group("/redemption") - redemptionRoute.Use(middleware.AdminAuth()) + redemptionRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth()) { redemptionRoute.GET("/", controller.GetAllRedemptions) redemptionRoute.GET("/search", controller.SearchRedemptions) @@ -200,13 +213,13 @@ func SetApiRouter(router *gin.Engine) { logRoute.GET("/token", controller.GetLogByKey) } groupRoute := apiRouter.Group("/group") - groupRoute.Use(middleware.AdminAuth()) + groupRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth()) { groupRoute.GET("/", controller.GetGroups) } prefillGroupRoute := apiRouter.Group("/prefill_group") - prefillGroupRoute.Use(middleware.AdminAuth()) + prefillGroupRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth()) { prefillGroupRoute.GET("/", controller.GetPrefillGroups) prefillGroupRoute.POST("/", controller.CreatePrefillGroup) diff --git a/setting/system_setting/oauth2.go b/setting/system_setting/oauth2.go index 078fe69e9..8bcd73d77 100644 --- a/setting/system_setting/oauth2.go +++ b/setting/system_setting/oauth2.go @@ -3,19 +3,21 @@ package system_setting import "one-api/setting/config" type OAuth2Settings struct { - Enabled bool `json:"enabled"` - Issuer string `json:"issuer"` - AccessTokenTTL int `json:"access_token_ttl"` // in minutes - RefreshTokenTTL int `json:"refresh_token_ttl"` // in minutes - AllowedGrantTypes []string `json:"allowed_grant_types"` // client_credentials, authorization_code, refresh_token - RequirePKCE bool `json:"require_pkce"` // force PKCE for authorization code flow - JWTSigningAlgorithm string `json:"jwt_signing_algorithm"` - JWTKeyID string `json:"jwt_key_id"` - JWTPrivateKeyFile string `json:"jwt_private_key_file"` - AutoCreateUser bool `json:"auto_create_user"` // auto create user on first OAuth2 login - DefaultUserRole int `json:"default_user_role"` // default role for auto-created users - DefaultUserGroup string `json:"default_user_group"` // default group for auto-created users - ScopeMappings map[string][]string `json:"scope_mappings"` // scope to permissions mapping + Enabled bool `json:"enabled"` + Issuer string `json:"issuer"` + AccessTokenTTL int `json:"access_token_ttl"` // in minutes + RefreshTokenTTL int `json:"refresh_token_ttl"` // in minutes + AllowedGrantTypes []string `json:"allowed_grant_types"` // client_credentials, authorization_code, refresh_token + RequirePKCE bool `json:"require_pkce"` // force PKCE for authorization code flow + JWTSigningAlgorithm string `json:"jwt_signing_algorithm"` + JWTKeyID string `json:"jwt_key_id"` + JWTPrivateKeyFile string `json:"jwt_private_key_file"` + AutoCreateUser bool `json:"auto_create_user"` // auto create user on first OAuth2 login + DefaultUserRole int `json:"default_user_role"` // default role for auto-created users + DefaultUserGroup string `json:"default_user_group"` // default group for auto-created users + ScopeMappings map[string][]string `json:"scope_mappings"` // scope to permissions mapping + MaxJWKSKeys int `json:"max_jwks_keys"` // maximum number of JWKS signing keys to retain + DefaultPrivateKeyPath string `json:"default_private_key_path"` // suggested private key file path } // 默认配置 @@ -35,6 +37,8 @@ var defaultOAuth2Settings = OAuth2Settings{ "api:write": {"write"}, "admin": {"admin"}, }, + MaxJWKSKeys: 3, + DefaultPrivateKeyPath: "/etc/new-api/oauth2-private.pem", } func init() { diff --git a/src/oauth/server.go b/src/oauth/server.go index 65f53b121..792311463 100644 --- a/src/oauth/server.go +++ b/src/oauth/server.go @@ -3,17 +3,34 @@ package oauth import ( "crypto/rand" "crypto/rsa" + "errors" "fmt" + "net/http" + "net/url" + "sort" + "strings" + "time" + "one-api/common" + "one-api/logger" + "one-api/model" "one-api/setting/system_setting" + "crypto/x509" + "encoding/pem" + "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + jwt "github.com/golang-jwt/jwt/v5" "github.com/lestrrat-go/jwx/v2/jwk" + "os" + "strconv" ) var ( - simplePrivateKey *rsa.PrivateKey - simpleJWKSSet jwk.Set + signingKeys = map[string]*rsa.PrivateKey{} + currentKeyID string + simpleJWKSSet jwk.Set + keyMeta = map[string]int64{} // kid -> created_at (unix) ) // InitOAuthServer 简化版OAuth2服务器初始化 @@ -24,15 +41,21 @@ func InitOAuthServer() error { return nil } - // 生成RSA私钥(简化版本) + // 生成RSA私钥,并设置当前 kid var err error - simplePrivateKey, err = rsa.GenerateKey(rand.Reader, 2048) + if settings.JWTKeyID == "" { + settings.JWTKeyID = "oauth2-key-1" + } + currentKeyID = settings.JWTKeyID + key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return fmt.Errorf("failed to generate RSA key: %w", err) } + signingKeys[currentKeyID] = key + keyMeta[currentKeyID] = time.Now().Unix() - // 创建JWKS - simpleJWKSSet, err = createSimpleJWKS(simplePrivateKey, settings.JWTKeyID) + // 创建JWKS,加入当前公钥 + simpleJWKSSet, err = createSimpleJWKS(key, currentKeyID) if err != nil { return fmt.Errorf("failed to create JWKS: %w", err) } @@ -41,6 +64,35 @@ func InitOAuthServer() error { return nil } +// EnsureInitialized lazily initializes signing keys and JWKS if OAuth2 is enabled but not yet ready +func EnsureInitialized() error { + settings := system_setting.GetOAuth2Settings() + if !settings.Enabled { + return nil + } + if len(signingKeys) > 0 && simpleJWKSSet != nil && currentKeyID != "" { + return nil + } + // generate one key and JWKS on demand + if settings.JWTKeyID == "" { + settings.JWTKeyID = fmt.Sprintf("oauth2-key-%d", time.Now().Unix()) + } + currentKeyID = settings.JWTKeyID + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + signingKeys[currentKeyID] = key + keyMeta[currentKeyID] = time.Now().Unix() + jwks, err := createSimpleJWKS(key, currentKeyID) + if err != nil { + return err + } + simpleJWKSSet = jwks + common.SysLog("OAuth2 lazy-initialized: signing key and JWKS ready") + return nil +} + // createSimpleJWKS 创建简单的JWKS func createSimpleJWKS(privateKey *rsa.PrivateKey, keyID string) (jwk.Set, error) { pubJWK, err := jwk.FromRaw(&privateKey.PublicKey) @@ -63,18 +115,938 @@ func GetJWKS() jwk.Set { return simpleJWKSSet } -// HandleTokenRequest 简化的令牌处理(临时实现) +// GetRSAPublicKey 返回当前用于签发的RSA公钥 +func GetRSAPublicKey() *rsa.PublicKey { + if k, ok := signingKeys[currentKeyID]; ok && k != nil { + return &k.PublicKey + } + return nil +} + +// GetPublicKeyByKid returns public key by kid if exists +func GetPublicKeyByKid(kid string) *rsa.PublicKey { + if k, ok := signingKeys[kid]; ok && k != nil { + return &k.PublicKey + } + return nil +} + +// RotateSigningKey generates a new RSA key, updates current kid, and adds to JWKS +func RotateSigningKey(newKid string) (string, error) { + if newKid == "" { + newKid = fmt.Sprintf("oauth2-key-%d", time.Now().Unix()) + } + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", err + } + signingKeys[newKid] = key + keyMeta[newKid] = time.Now().Unix() + // add to jwks set + pubJWK, err := jwk.FromRaw(&key.PublicKey) + if err == nil { + _ = pubJWK.Set(jwk.KeyIDKey, newKid) + _ = pubJWK.Set(jwk.AlgorithmKey, "RS256") + _ = pubJWK.Set(jwk.KeyUsageKey, "sig") + _ = simpleJWKSSet.AddKey(pubJWK) + } + currentKeyID = newKid + enforceKeyRetention() + return newKid, nil +} + +// GenerateAndPersistKey generates a new RSA key, writes to a server file, and rotates current kid +func GenerateAndPersistKey(path string, kid string, overwrite bool) (string, error) { + if kid == "" { + kid = fmt.Sprintf("oauth2-key-%d", time.Now().Unix()) + } + if _, err := os.Stat(path); err == nil && !overwrite { + return "", fmt.Errorf("file exists") + } else if err != nil && !errors.Is(err, os.ErrNotExist) { + return "", err + } + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", err + } + // write PKCS1 PEM with 0600 perms + der := x509.MarshalPKCS1PrivateKey(key) + blk := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der} + pemBytes := pem.EncodeToMemory(blk) + if err := os.WriteFile(path, pemBytes, 0600); err != nil { + return "", err + } + // rotate in memory + signingKeys[kid] = key + keyMeta[kid] = time.Now().Unix() + // add to jwks + pubJWK, err := jwk.FromRaw(&key.PublicKey) + if err == nil { + _ = pubJWK.Set(jwk.KeyIDKey, kid) + _ = pubJWK.Set(jwk.AlgorithmKey, "RS256") + _ = pubJWK.Set(jwk.KeyUsageKey, "sig") + _ = simpleJWKSSet.AddKey(pubJWK) + } + currentKeyID = kid + enforceKeyRetention() + return kid, nil +} + +// ListSigningKeys returns metadata of keys +type KeyInfo struct { + Kid string `json:"kid"` + CreatedAt int64 `json:"created_at"` + Current bool `json:"current"` +} + +func ListSigningKeys() []KeyInfo { + out := make([]KeyInfo, 0, len(signingKeys)) + for kid := range signingKeys { + out = append(out, KeyInfo{Kid: kid, CreatedAt: keyMeta[kid], Current: kid == currentKeyID}) + } + // sort by CreatedAt asc + sort.Slice(out, func(i, j int) bool { return out[i].CreatedAt < out[j].CreatedAt }) + return out +} + +// DeleteSigningKey removes a non-current key +func DeleteSigningKey(kid string) error { + if kid == "" { + return fmt.Errorf("kid required") + } + if kid == currentKeyID { + return fmt.Errorf("cannot delete current signing key") + } + if _, ok := signingKeys[kid]; !ok { + return fmt.Errorf("unknown kid") + } + delete(signingKeys, kid) + delete(keyMeta, kid) + rebuildJWKS() + return nil +} + +func rebuildJWKS() { + jwks := jwk.NewSet() + for kid, k := range signingKeys { + pub, err := jwk.FromRaw(&k.PublicKey) + if err == nil { + _ = pub.Set(jwk.KeyIDKey, kid) + _ = pub.Set(jwk.AlgorithmKey, "RS256") + _ = pub.Set(jwk.KeyUsageKey, "sig") + _ = jwks.AddKey(pub) + } + } + simpleJWKSSet = jwks +} + +func enforceKeyRetention() { + max := system_setting.GetOAuth2Settings().MaxJWKSKeys + if max <= 0 { + max = 1 + } + // retain max most recent keys + infos := ListSigningKeys() + if len(infos) <= max { + return + } + // delete oldest first, skipping current + toDelete := len(infos) - max + for _, ki := range infos { + if toDelete == 0 { + break + } + if ki.Kid == currentKeyID { + continue + } + _ = DeleteSigningKey(ki.Kid) + toDelete-- + } +} + +// ImportPEMKey imports an RSA private key from PEM text and rotates current kid +func ImportPEMKey(pemText string, kid string) (string, error) { + if kid == "" { + kid = fmt.Sprintf("oauth2-key-%d", time.Now().Unix()) + } + // decode PEM + var block *pem.Block + var rest = []byte(pemText) + for { + block, rest = pem.Decode(rest) + if block == nil { + break + } + if block.Type == "RSA PRIVATE KEY" || strings.Contains(block.Type, "PRIVATE KEY") { + var key *rsa.PrivateKey + var err error + if block.Type == "RSA PRIVATE KEY" { + key, err = x509.ParsePKCS1PrivateKey(block.Bytes) + } else { + // try PKCS#8 + priv, err2 := x509.ParsePKCS8PrivateKey(block.Bytes) + if err2 != nil { + return "", err2 + } + var ok bool + key, ok = priv.(*rsa.PrivateKey) + if !ok { + return "", fmt.Errorf("not an RSA private key") + } + } + if err != nil { + return "", err + } + signingKeys[kid] = key + keyMeta[kid] = time.Now().Unix() + pubJWK, err := jwk.FromRaw(&key.PublicKey) + if err == nil { + _ = pubJWK.Set(jwk.KeyIDKey, kid) + _ = pubJWK.Set(jwk.AlgorithmKey, "RS256") + _ = pubJWK.Set(jwk.KeyUsageKey, "sig") + _ = simpleJWKSSet.AddKey(pubJWK) + } + currentKeyID = kid + enforceKeyRetention() + return kid, nil + } + if len(rest) == 0 { + break + } + } + return "", fmt.Errorf("no private key found in PEM") +} + +// HandleTokenRequest 实现最小可用的令牌签发(client_credentials) func HandleTokenRequest(c *gin.Context) { - c.JSON(501, map[string]string{ - "error": "not_implemented", - "error_description": "OAuth2 token endpoint not fully implemented yet", + settings := system_setting.GetOAuth2Settings() + + grantType := strings.TrimSpace(c.PostForm("grant_type")) + if grantType == "" { + writeOAuthError(c, http.StatusBadRequest, "invalid_request", "missing grant_type") + return + } + if !settings.ValidateGrantType(grantType) { + writeOAuthError(c, http.StatusBadRequest, "unsupported_grant_type", "grant_type not allowed") + return + } + + switch grantType { + case "client_credentials": + handleClientCredentials(c, settings) + case "refresh_token": + handleRefreshToken(c, settings) + case "authorization_code": + handleAuthorizationCodeExchange(c, settings) + default: + writeOAuthError(c, http.StatusBadRequest, "unsupported_grant_type", "unsupported grant_type") + } +} + +func handleClientCredentials(c *gin.Context, settings *system_setting.OAuth2Settings) { + clientID, clientSecret := getFormOrBasicAuth(c) + if clientID == "" || clientSecret == "" { + writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "missing client credentials") + return + } + + client, err := model.GetOAuthClientByID(clientID) + if err != nil { + writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "unknown client") + return + } + if client.Secret != clientSecret { + writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "invalid client secret") + return + } + // client type can be confidential or public; client_credentials only for confidential + if client.ClientType == "public" { + writeOAuthError(c, http.StatusBadRequest, "unauthorized_client", "public client cannot use client_credentials") + return + } + if !client.ValidateGrantType("client_credentials") { + writeOAuthError(c, http.StatusBadRequest, "unauthorized_client", "grant_type not enabled for client") + return + } + + scope := strings.TrimSpace(c.PostForm("scope")) + if scope == "" { + // default to client's first scope or api:read + allowed := client.GetScopes() + if len(allowed) == 0 { + scope = "api:read" + } else { + scope = strings.Join(allowed, " ") + } + } + if !client.ValidateScope(scope) { + writeOAuthError(c, http.StatusBadRequest, "invalid_scope", "requested scope not allowed") + return + } + + // issue JWT access token + accessTTL := time.Duration(settings.AccessTokenTTL) * time.Minute + tokenStr, exp, jti, err := signAccessToken(settings, clientID, "", scope, "client_credentials", accessTTL, c) + if err != nil { + writeOAuthError(c, http.StatusInternalServerError, "server_error", "failed to issue token") + return + } + + // update client usage + _ = client.UpdateLastUsedTime() + + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenStr, + "token_type": "Bearer", + "expires_in": int64(exp.Sub(time.Now()).Seconds()), + "scope": scope, + "jti": jti, }) } +// handleAuthorizationCodeExchange 处理授权码换取令牌 +func handleAuthorizationCodeExchange(c *gin.Context, settings *system_setting.OAuth2Settings) { + // Redis not required; fallback to in-memory store + clientID, clientSecret := getFormOrBasicAuth(c) + code := strings.TrimSpace(c.PostForm("code")) + redirectURI := strings.TrimSpace(c.PostForm("redirect_uri")) + codeVerifier := strings.TrimSpace(c.PostForm("code_verifier")) + + if clientID == "" { + writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "missing client_id") + return + } + client, err := model.GetOAuthClientByID(clientID) + if err != nil { + writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "unknown client") + return + } + if client.ClientType == "confidential" { + if clientSecret == "" || client.Secret != clientSecret { + writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "invalid client secret") + return + } + } + if !client.ValidateGrantType("authorization_code") { + writeOAuthError(c, http.StatusBadRequest, "unauthorized_client", "authorization_code not enabled for client") + return + } + if redirectURI == "" || !client.ValidateRedirectURI(redirectURI) { + writeOAuthError(c, http.StatusBadRequest, "invalid_request", "redirect_uri mismatch or missing") + return + } + if code == "" { + writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "missing code") + return + } + + // 从Redis获取授权码数据 + key := fmt.Sprintf("oauth:code:%s", code) + raw, ok := storeGet(key) + if !ok || raw == "" { + writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "invalid or expired code") + return + } + + // 解析:clientID|redirectURI|scope|userID|codeChallenge|codeChallengeMethod|exp[|nonce] + parts := strings.Split(raw, "|") + if len(parts) < 7 { + writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "malformed code payload") + return + } + payloadClientID := parts[0] + payloadRedirectURI := parts[1] + payloadScope := parts[2] + payloadUserIDStr := parts[3] + payloadCodeChallenge := parts[4] + payloadCodeChallengeMethod := parts[5] + // parts[6] = exp (unused here) + var payloadNonce string + if len(parts) >= 8 { + payloadNonce = parts[7] + } + // 单次使用:删除授权码 + _ = storeDel(key) + + if payloadClientID != clientID { + writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "client_id mismatch") + return + } + if payloadRedirectURI != redirectURI { + writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "redirect_uri mismatch") + return + } + // PKCE 校验 + requirePKCE := settings.RequirePKCE || client.RequirePKCE + if requirePKCE || payloadCodeChallenge != "" { + if codeVerifier == "" { + writeOAuthError(c, http.StatusBadRequest, "invalid_request", "missing code_verifier") + return + } + method := strings.ToUpper(payloadCodeChallengeMethod) + if method == "" { + method = "S256" + } + switch method { + case "S256": + if s256Base64URL(codeVerifier) != payloadCodeChallenge { + writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "code_verifier mismatch") + return + } + default: + writeOAuthError(c, http.StatusBadRequest, "invalid_request", "unsupported code_challenge_method") + return + } + } + + // 颁发令牌 + scope := payloadScope + userIDStr := payloadUserIDStr + accessTTL := time.Duration(settings.AccessTokenTTL) * time.Minute + tokenStr, exp, jti, err := signAccessToken(settings, clientID, userIDStr, scope, "authorization_code", accessTTL, c) + if err != nil { + writeOAuthError(c, http.StatusInternalServerError, "server_error", "failed to issue token") + return + } + + // 可选:签发刷新令牌(仅当允许) + resp := gin.H{ + "access_token": tokenStr, + "token_type": "Bearer", + "expires_in": int64(exp.Sub(time.Now()).Seconds()), + "scope": scope, + "jti": jti, + } + // OIDC: 当 scope 包含 openid 时,签发 id_token + if strings.Contains(" "+scope+" ", " openid ") { + idt, err := signIDToken(settings, clientID, payloadUserIDStr, payloadNonce, c) + if err == nil { + resp["id_token"] = idt + } + } + if settings.ValidateGrantType("refresh_token") && client.ValidateGrantType("refresh_token") { + rt, err := genCode(32) + if err == nil { + ttl := time.Duration(settings.RefreshTokenTTL) * time.Minute + rtKey := fmt.Sprintf("oauth:rt:%s", rt) + // 存储 clientID|userID|scope|nonce(便于刷新时维持 openid/nonce) + val := fmt.Sprintf("%s|%s|%s|%s", clientID, userIDStr, scope, payloadNonce) + _ = storeSet(rtKey, val, ttl) + resp["refresh_token"] = rt + } + } + + _ = client.UpdateLastUsedTime() + writeNoStore(c) + c.JSON(http.StatusOK, resp) +} + +// handleRefreshToken 刷新令牌 +func handleRefreshToken(c *gin.Context, settings *system_setting.OAuth2Settings) { + // Redis not required; fallback to in-memory store + clientID, clientSecret := getFormOrBasicAuth(c) + refreshToken := strings.TrimSpace(c.PostForm("refresh_token")) + if clientID == "" { + writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "missing client_id") + return + } + client, err := model.GetOAuthClientByID(clientID) + if err != nil { + writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "unknown client") + return + } + if client.ClientType == "confidential" { + if clientSecret == "" || client.Secret != clientSecret { + writeOAuthError(c, http.StatusUnauthorized, "invalid_client", "invalid client secret") + return + } + } + if !client.ValidateGrantType("refresh_token") { + writeOAuthError(c, http.StatusBadRequest, "unauthorized_client", "refresh_token not enabled for client") + return + } + if refreshToken == "" { + writeOAuthError(c, http.StatusBadRequest, "invalid_request", "missing refresh_token") + return + } + key := fmt.Sprintf("oauth:rt:%s", refreshToken) + raw, ok := storeGet(key) + if !ok || raw == "" { + writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "invalid refresh_token") + return + } + // 解析值:clientID|userID|scope|nonce + parts := strings.Split(raw, "|") + if len(parts) < 3 { + writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "malformed refresh token") + return + } + storedClientID := parts[0] + userIDStr := parts[1] + scope := parts[2] + var nonce string + if len(parts) >= 4 { + nonce = parts[3] + } + if storedClientID != clientID { + writeOAuthError(c, http.StatusBadRequest, "invalid_grant", "client_id mismatch") + return + } + + // 旋转refresh_token:删除旧的,签发新的 + _ = storeDel(key) + newRT, err := genCode(32) + if err == nil { + ttl := time.Duration(settings.RefreshTokenTTL) * time.Minute + newKey := fmt.Sprintf("oauth:rt:%s", newRT) + _ = storeSet(newKey, raw, ttl) + } + + // 颁发新的访问令牌 + accessTTL := time.Duration(settings.AccessTokenTTL) * time.Minute + tokenStr, exp, jti, err := signAccessToken(settings, clientID, userIDStr, scope, "refresh_token", accessTTL, c) + if err != nil { + writeOAuthError(c, http.StatusInternalServerError, "server_error", "failed to issue token") + return + } + resp := gin.H{ + "access_token": tokenStr, + "token_type": "Bearer", + "expires_in": int64(exp.Sub(time.Now()).Seconds()), + "scope": scope, + "jti": jti, + } + if strings.Contains(" "+scope+" ", " openid ") { + if idt, err := signIDToken(settings, clientID, userIDStr, nonce, c); err == nil { + resp["id_token"] = idt + } + } + if newRT != "" { + resp["refresh_token"] = newRT + } + writeNoStore(c) + c.JSON(http.StatusOK, resp) +} + +// signAccessToken 使用内置RSA私钥签发JWT访问令牌 +func signAccessToken(settings *system_setting.OAuth2Settings, clientID string, subject string, scope string, grantType string, ttl time.Duration, c *gin.Context) (string, time.Time, string, error) { + now := time.Now() + exp := now.Add(ttl) + jti := common.GetUUID() + iss := settings.Issuer + if iss == "" { + // derive from requestd + scheme := "https" + if c != nil && c.Request != nil { + if c.Request.TLS == nil { + if hdr := c.Request.Header.Get("X-Forwarded-Proto"); hdr != "" { + scheme = hdr + } else { + scheme = "http" + } + } + host := c.Request.Host + if host != "" { + iss = fmt.Sprintf("%s://%s", scheme, host) + } + } + } + + claims := jwt.MapClaims{ + "iss": iss, + "sub": func() string { + if subject != "" { + return subject + } + return clientID + }(), + "aud": "one-api", + "iat": now.Unix(), + "nbf": now.Unix(), + "exp": exp.Unix(), + "scope": scope, + "client_id": clientID, + "grant_type": grantType, + "jti": jti, + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + // set kid + kid := currentKeyID + if kid != "" { + token.Header["kid"] = kid + } + k := signingKeys[kid] + if k == nil { + return "", time.Time{}, "", errors.New("signing key missing") + } + signed, err := token.SignedString(k) + if err != nil { + return "", time.Time{}, "", err + } + return signed, exp, jti, nil +} + +// signIDToken 生成 OIDC id_token +func signIDToken(settings *system_setting.OAuth2Settings, clientID string, subject string, nonce string, c *gin.Context) (string, error) { + k := signingKeys[currentKeyID] + if k == nil { + return "", errors.New("oauth private key not initialized") + } + // derive issuer similar to access token + iss := settings.Issuer + if iss == "" && c != nil && c.Request != nil { + scheme := "https" + if c.Request.TLS == nil { + if hdr := c.Request.Header.Get("X-Forwarded-Proto"); hdr != "" { + scheme = hdr + } else { + scheme = "http" + } + } + host := c.Request.Host + if host != "" { + iss = fmt.Sprintf("%s://%s", scheme, host) + } + } + now := time.Now() + exp := now.Add(10 * time.Minute) // id_token 短时有效 + + claims := jwt.MapClaims{ + "iss": iss, + "sub": subject, + "aud": clientID, + "iat": now.Unix(), + "exp": exp.Unix(), + } + if nonce != "" { + claims["nonce"] = nonce + } + + // 可选:附加 profile / email claims 由上层根据 scope 决定 + if uid, err := strconv.Atoi(subject); err == nil { + if user, err2 := model.GetUserById(uid, false); err2 == nil && user != nil { + if user.Username != "" { + claims["preferred_username"] = user.Username + claims["name"] = user.DisplayName + } + if user.Email != "" { + claims["email"] = user.Email + claims["email_verified"] = true + } + } + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + if currentKeyID != "" { + token.Header["kid"] = currentKeyID + } + return token.SignedString(k) +} + // HandleAuthorizeRequest 简化的授权处理(临时实现) func HandleAuthorizeRequest(c *gin.Context) { - c.JSON(501, map[string]string{ - "error": "not_implemented", - "error_description": "OAuth2 authorize endpoint not fully implemented yet", + settings := system_setting.GetOAuth2Settings() + // Redis not required; fallback to in-memory store + + // 解析参数 + responseType := c.Query("response_type") + clientID := c.Query("client_id") + redirectURI := c.Query("redirect_uri") + scope := strings.TrimSpace(c.Query("scope")) + state := c.Query("state") + codeChallenge := c.Query("code_challenge") + codeChallengeMethod := strings.ToUpper(c.Query("code_challenge_method")) + nonce := c.Query("nonce") + + if responseType == "" { + responseType = "code" + } + if responseType != "code" && responseType != "token" { + c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_response_type"}) + return + } + if clientID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request", "error_description": "missing client_id"}) + return + } + client, err := model.GetOAuthClientByID(clientID) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_client"}) + return + } + // 对于 implicit (response_type=token),允许客户端拥有 authorization_code 或 implicit 任一权限 + if responseType == "code" { + if !client.ValidateGrantType("authorization_code") { + c.JSON(http.StatusBadRequest, gin.H{"error": "unauthorized_client"}) + return + } + } else { + if !(client.ValidateGrantType("implicit") || client.ValidateGrantType("authorization_code")) { + c.JSON(http.StatusBadRequest, gin.H{"error": "unauthorized_client"}) + return + } + } + // 严格匹配或本地回环地址宽松匹配(忽略端口,遵循 RFC 8252) + validRedirect := client.ValidateRedirectURI(redirectURI) + if !validRedirect { + if isLoopbackRedirectAllowed(redirectURI, client.GetRedirectURIs()) { + validRedirect = true + } + } + if redirectURI == "" || !validRedirect { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request", "error_description": "redirect_uri mismatch or missing"}) + return + } + + // 支持前端预取信息 + mode := c.Query("mode") // mode=prepare 返回JSON供前端展示 + + // 校验scope + if scope == "" { + scope = strings.Join(client.GetScopes(), " ") + } else if !client.ValidateScope(scope) { + writeOAuthRedirectError(c, redirectURI, "invalid_scope", "requested scope not allowed", state) + return + } + + // PKCE 要求 + if responseType == "code" && (settings.RequirePKCE || client.RequirePKCE) { + if codeChallenge == "" { + writeOAuthRedirectError(c, redirectURI, "invalid_request", "code_challenge required", state) + return + } + if codeChallengeMethod == "" { + codeChallengeMethod = "S256" + } + if codeChallengeMethod != "S256" { + writeOAuthRedirectError(c, redirectURI, "invalid_request", "unsupported code_challenge_method", state) + return + } + } + + // 检查用户会话(要求已登录) + sess := sessions.Default(c) + uidVal := sess.Get("id") + if uidVal == nil { + if mode == "prepare" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "login_required"}) + return + } + // 重定向到前端登录后回到同意页 + consentPath := "/oauth/consent?" + c.Request.URL.RawQuery + loginPath := "/login?next=" + url.QueryEscape(consentPath) + writeNoStore(c) + c.Redirect(http.StatusFound, loginPath) + return + } + userID, _ := uidVal.(int) + if userID == 0 { + // 某些 session 库会将数字解码为 int64 + if v64, ok := uidVal.(int64); ok { + userID = int(v64) + } + } + if userID == 0 { + writeOAuthRedirectError(c, redirectURI, "login_required", "user not logged in", state) + return + } + + // prepare 模式:返回前端展示信息 + if mode == "prepare" { + // 解析重定向域名 + rHost := "" + if u, err := url.Parse(redirectURI); err == nil { + rHost = u.Hostname() + } + verified := false + if client.Domain != "" && rHost != "" { + verified = strings.EqualFold(client.Domain, rHost) + } + // scope 明细 + scopeNames := strings.Fields(scope) + type scopeItem struct{ Name, Description string } + var scopeInfo []scopeItem + for _, s := range scopeNames { + d := "" + switch s { + case "openid": + d = "访问你的基础身份 (sub)" + case "profile": + d = "读取你的公开资料 (昵称/用户名)" + case "email": + d = "读取你的邮箱地址" + case "api:read": + d = "读取 API 资源" + case "api:write": + d = "写入/修改 API 资源" + case "admin": + d = "管理权限 (高危)" + default: + d = "" + } + scopeInfo = append(scopeInfo, scopeItem{Name: s, Description: d}) + } + // 当前用户信息(用于展示) + var userName, userEmail string + if user, err := model.GetUserById(userID, false); err == nil && user != nil { + userName = user.DisplayName + if userName == "" { + userName = user.Username + } + userEmail = user.Email + } + c.JSON(http.StatusOK, gin.H{ + "client": gin.H{ + "id": client.ID, + "name": client.Name, + "type": client.ClientType, + "desc": client.Description, + "domain": client.Domain, + }, + "scope": scope, + "scope_list": scopeNames, + "scope_info": scopeInfo, + "redirect_uri": redirectURI, + "redirect_host": rHost, + "verified": verified, + "state": state, + "response_type": responseType, + "require_pkce": (responseType == "code") && (settings.RequirePKCE || client.RequirePKCE), + "user": gin.H{ + "id": userID, + "name": userName, + "email": userEmail, + }, + }) + return + } + + // 拒绝授权:返回错误给回调地址 + if c.Query("deny") == "1" || strings.EqualFold(c.Query("decision"), "deny") { + logger.LogInfo(c, fmt.Sprintf("oauth consent denied: user=%v client=%s scope=%s redirect=%s", sess.Get("id"), clientID, scope, redirectURI)) + writeOAuthRedirectError(c, redirectURI, "access_denied", "user denied the request", state) + return + } + + // 未明确选择,跳转前端同意页 + if !(c.Query("approve") == "1" || strings.EqualFold(c.Query("decision"), "approve")) { + consentPath := "/oauth/consent?" + c.Request.URL.RawQuery + writeNoStore(c) + c.Redirect(http.StatusFound, consentPath) + return + } + + // 根据响应类型返回 + if responseType == "code" { + // 生成授权码,写入 存储(短TTL) + code, err := genCode(32) + if err != nil { + writeOAuthRedirectError(c, redirectURI, "server_error", "failed to generate code", state) + return + } + ttl := 2 * time.Minute + exp := time.Now().Add(ttl).Unix() + // 存储 clientID|redirectURI|scope|userID|codeChallenge|codeChallengeMethod|exp|nonce + val := fmt.Sprintf("%s|%s|%s|%d|%s|%s|%d|%s", clientID, redirectURI, scope, userID, codeChallenge, codeChallengeMethod, exp, nonce) + key := fmt.Sprintf("oauth:code:%s", code) + if err := storeSet(key, val, ttl); err != nil { + writeOAuthRedirectError(c, redirectURI, "server_error", "failed to store code", state) + return + } + logger.LogInfo(c, fmt.Sprintf("oauth consent approved (code): user=%d client=%s scope=%s redirect=%s", userID, clientID, scope, redirectURI)) + + // 成功,重定向(查询参数) + u, _ := url.Parse(redirectURI) + q := u.Query() + q.Set("code", code) + if state != "" { + q.Set("state", state) + } + u.RawQuery = q.Encode() + writeNoStore(c) + c.Redirect(http.StatusFound, u.String()) + return + } + + // response_type=token (implicit) + // 直接签发 Access Token(不下发 Refresh Token) + accessTTL := time.Duration(settings.AccessTokenTTL) * time.Minute + userIDStr := fmt.Sprintf("%d", userID) + tokenStr, expTime, jti, err := signAccessToken(settings, clientID, userIDStr, scope, "implicit", accessTTL, c) + if err != nil { + writeOAuthRedirectError(c, redirectURI, "server_error", "failed to issue token", state) + return + } + _ = client.UpdateLastUsedTime() + logger.LogInfo(c, fmt.Sprintf("oauth consent approved (token): user=%d client=%s scope=%s redirect=%s jti=%s", userID, clientID, scope, redirectURI, jti)) + + // 使用 fragment 传递(#access_token=...) + u, _ := url.Parse(redirectURI) + frag := url.Values{} + frag.Set("access_token", tokenStr) + frag.Set("token_type", "Bearer") + frag.Set("expires_in", fmt.Sprintf("%d", int64(expTime.Sub(time.Now()).Seconds()))) + if scope != "" { + frag.Set("scope", scope) + } + if state != "" { + frag.Set("state", state) + } + u.Fragment = frag.Encode() + writeNoStore(c) + c.Redirect(http.StatusFound, u.String()) +} + +func writeOAuthError(c *gin.Context, status int, code, description string) { + c.Header("Cache-Control", "no-store") + c.Header("Pragma", "no-cache") + c.JSON(status, gin.H{ + "error": code, + "error_description": description, }) } + +// isLoopback returns true if hostname represents a local loopback host +func isLoopback(host string) bool { + if host == "" { + return false + } + h := strings.ToLower(host) + if h == "localhost" || h == "::1" { + return true + } + if strings.HasPrefix(h, "127.") { + return true + } + return false +} + +// isLoopbackRedirectAllowed allows redirect URIs on loopback hosts to match ignoring port +// This follows OAuth 2.0 for Native Apps (RFC 8252) guidance to use loopback interface with dynamic port. +func isLoopbackRedirectAllowed(requested string, allowed []string) bool { + if requested == "" || len(allowed) == 0 { + return false + } + ru, err := url.Parse(requested) + if err != nil { + return false + } + if !isLoopback(ru.Hostname()) { + return false + } + for _, a := range allowed { + au, err := url.Parse(a) + if err != nil { + continue + } + if !isLoopback(au.Hostname()) { + continue + } + // require same scheme and path; ignore port and host variant among loopback + if strings.EqualFold(ru.Scheme, au.Scheme) && ru.Path == au.Path { + return true + } + } + return false +} diff --git a/src/oauth/store.go b/src/oauth/store.go new file mode 100644 index 000000000..5d560af31 --- /dev/null +++ b/src/oauth/store.go @@ -0,0 +1,82 @@ +package oauth + +import ( + "one-api/common" + "sync" + "time" +) + +// KVStore is a minimal TTL key-value abstraction used by OAuth flows. +type KVStore interface { + Set(key, value string, ttl time.Duration) error + Get(key string) (string, bool) + Del(key string) error +} + +type redisStore struct{} + +func (r *redisStore) Set(key, value string, ttl time.Duration) error { + return common.RedisSet(key, value, ttl) +} +func (r *redisStore) Get(key string) (string, bool) { + v, err := common.RedisGet(key) + if err != nil || v == "" { + return "", false + } + return v, true +} +func (r *redisStore) Del(key string) error { + return common.RedisDel(key) +} + +type memEntry struct { + val string + exp int64 // unix seconds, 0 means no expiry +} + +type memoryStore struct { + m sync.Map // key -> memEntry +} + +func (m *memoryStore) Set(key, value string, ttl time.Duration) error { + var exp int64 + if ttl > 0 { + exp = time.Now().Add(ttl).Unix() + } + m.m.Store(key, memEntry{val: value, exp: exp}) + return nil +} + +func (m *memoryStore) Get(key string) (string, bool) { + v, ok := m.m.Load(key) + if !ok { + return "", false + } + e := v.(memEntry) + if e.exp > 0 && time.Now().Unix() > e.exp { + m.m.Delete(key) + return "", false + } + return e.val, true +} + +func (m *memoryStore) Del(key string) error { + m.m.Delete(key) + return nil +} + +var ( + memStore = &memoryStore{} + rdsStore = &redisStore{} +) + +func getStore() KVStore { + if common.RedisEnabled { + return rdsStore + } + return memStore +} + +func storeSet(key, val string, ttl time.Duration) error { return getStore().Set(key, val, ttl) } +func storeGet(key string) (string, bool) { return getStore().Get(key) } +func storeDel(key string) error { return getStore().Del(key) } diff --git a/src/oauth/util.go b/src/oauth/util.go new file mode 100644 index 000000000..01e2aacdc --- /dev/null +++ b/src/oauth/util.go @@ -0,0 +1,59 @@ +package oauth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "net/http" + "net/url" + "strings" + + "github.com/gin-gonic/gin" +) + +// getFormOrBasicAuth extracts client_id/client_secret from Basic Auth first, then form +func getFormOrBasicAuth(c *gin.Context) (clientID, clientSecret string) { + id, secret, ok := c.Request.BasicAuth() + if ok { + return strings.TrimSpace(id), strings.TrimSpace(secret) + } + return strings.TrimSpace(c.PostForm("client_id")), strings.TrimSpace(c.PostForm("client_secret")) +} + +// genCode generates URL-safe random string based on nBytes of entropy +func genCode(nBytes int) (string, error) { + b := make([]byte, nBytes) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// s256Base64URL computes base64url-encoded SHA256 digest +func s256Base64URL(verifier string) string { + sum := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(sum[:]) +} + +// writeNoStore sets no-store cache headers for OAuth responses +func writeNoStore(c *gin.Context) { + c.Header("Cache-Control", "no-store") + c.Header("Pragma", "no-cache") +} + +// writeOAuthRedirectError builds an error redirect to redirect_uri as RFC6749 +func writeOAuthRedirectError(c *gin.Context, redirectURI, errCode, description, state string) { + writeNoStore(c) + q := "error=" + url.QueryEscape(errCode) + if description != "" { + q += "&error_description=" + url.QueryEscape(description) + } + if state != "" { + q += "&state=" + url.QueryEscape(state) + } + sep := "?" + if strings.Contains(redirectURI, "?") { + sep = "&" + } + c.Redirect(http.StatusFound, redirectURI+sep+q) +} diff --git a/web/public/oauth-demo.html b/web/public/oauth-demo.html new file mode 100644 index 000000000..ba5821d32 --- /dev/null +++ b/web/public/oauth-demo.html @@ -0,0 +1,167 @@ + + + + + + + OAuth2/OIDC 授权码 + PKCE 前端演示 + + + +
+

OAuth2/OIDC 授权码 + PKCE 前端演示

+
+
+
+ + +
+
提示:若未配置 Issuer,可直接填写下方端点。
+
+
+
+
+ + +
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ + + + + +
+
+
+ + +
+
+
+
+
说明: +
    +
  • 本页为纯前端演示,适用于公开客户端(不需要 client_secret)。
  • +
  • 如跨域调用 Token/UserInfo,需要服务端正确设置 CORS;建议将此 demo 部署到同源域名下。
  • +
+
+
+
+
+ + +
+ + +
+
可将服务端返回的 OIDC Discovery JSON 粘贴到此处,点击“解析并填充端点”。
+
+
+
+
+
等待授权...
+
+
+ + +
+
+
+
+ + +
+

+        
+
+
+
+ + +
+
+
+ + +
+
+
+
+ + + diff --git a/web/src/App.jsx b/web/src/App.jsx index 635742f91..4baf5c42b 100644 --- a/web/src/App.jsx +++ b/web/src/App.jsx @@ -44,6 +44,7 @@ import Task from './pages/Task'; import ModelPage from './pages/Model'; import Playground from './pages/Playground'; import OAuth2Callback from './components/auth/OAuth2Callback'; +import OAuthConsent from './pages/OAuth/Consent'; import PersonalSetting from './components/settings/PersonalSetting'; import Setup from './pages/Setup'; import SetupCheck from './components/layout/SetupCheck'; @@ -198,6 +199,14 @@ function App() { } /> + }> + + + } + /> . For commercial licensing, please contact support@quantumnous.com */ -import React, { useState } from 'react'; +import React, { useEffect, useMemo, useState } from 'react'; import { Modal, Form, @@ -40,17 +40,128 @@ const { Option } = Select; const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => { const [formApi, setFormApi] = useState(null); const [loading, setLoading] = useState(false); - const [redirectUris, setRedirectUris] = useState(['']); + const [redirectUris, setRedirectUris] = useState([]); const [clientType, setClientType] = useState('confidential'); const [grantTypes, setGrantTypes] = useState(['client_credentials']); + const [allowedGrantTypes, setAllowedGrantTypes] = useState([ + 'client_credentials', + 'authorization_code', + 'refresh_token', + ]); + + // 加载后端允许的授权类型(用于限制和默认值) + useEffect(() => { + let mounted = true; + (async () => { + try { + const res = await API.get('/api/option/'); + const { success, data } = res.data || {}; + if (!success || !Array.isArray(data)) return; + const found = data.find((i) => i.key === 'oauth2.allowed_grant_types'); + if (!found) return; + let parsed = []; + try { + parsed = JSON.parse(found.value || '[]'); + } catch (_) {} + if (mounted && Array.isArray(parsed) && parsed.length) { + setAllowedGrantTypes(parsed); + } + } catch (_) { + // 忽略错误,使用默认allowedGrantTypes + } + })(); + return () => { + mounted = false; + }; + }, []); + + const computeDefaultGrantTypes = (type, allowed) => { + const cand = + type === 'public' + ? ['authorization_code', 'refresh_token'] + : ['client_credentials', 'authorization_code', 'refresh_token']; + const subset = cand.filter((g) => allowed.includes(g)); + return subset.length ? subset : [allowed[0]].filter(Boolean); + }; + + // 当允许的类型或客户端类型变化时,自动设置更合理的默认值 + useEffect(() => { + setGrantTypes((prev) => { + const normalizedPrev = Array.isArray(prev) ? prev : []; + // 移除不被允许或与客户端类型冲突的类型 + let next = normalizedPrev.filter((g) => allowedGrantTypes.includes(g)); + if (clientType === 'public') { + next = next.filter((g) => g !== 'client_credentials'); + } + // 如果为空,则使用计算的默认 + if (!next.length) { + next = computeDefaultGrantTypes(clientType, allowedGrantTypes); + } + return next; + }); + }, [clientType, allowedGrantTypes]); + + const isGrantTypeDisabled = (value) => { + if (!allowedGrantTypes.includes(value)) return true; + if (clientType === 'public' && value === 'client_credentials') return true; + return false; + }; + + // URL校验:允许 http(s),本地开发可 http + const isValidRedirectUri = (uri) => { + if (!uri || !uri.trim()) return false; + try { + const u = new URL(uri.trim()); + if (u.protocol !== 'https:' && u.protocol !== 'http:') return false; + if (u.protocol === 'http:') { + // 仅允许本地开发时使用 http + const host = u.hostname; + const isLocal = + host === 'localhost' || host === '127.0.0.1' || host.endsWith('.local'); + if (!isLocal) return false; + } + return true; + } catch (e) { + return false; + } + }; // 处理提交 const handleSubmit = async (values) => { setLoading(true); try { // 过滤空的重定向URI - const validRedirectUris = redirectUris.filter(uri => uri.trim()); - + const validRedirectUris = redirectUris + .map((u) => (u || '').trim()) + .filter((u) => u.length > 0); + + // 业务校验 + if (!grantTypes.length) { + showError('请至少选择一种授权类型'); + return; + } + // 校验是否包含不被允许的授权类型 + const invalids = grantTypes.filter((g) => !allowedGrantTypes.includes(g)); + if (invalids.length) { + showError(`不被允许的授权类型: ${invalids.join(', ')}`); + return; + } + if (clientType === 'public' && grantTypes.includes('client_credentials')) { + showError('公开客户端不允许使用client_credentials授权类型'); + return; + } + if (grantTypes.includes('authorization_code')) { + if (!validRedirectUris.length) { + showError('选择授权码授权类型时,必须填写至少一个重定向URI'); + return; + } + const allValid = validRedirectUris.every(isValidRedirectUri); + if (!allValid) { + showError('重定向URI格式不合法:仅支持https,或本地开发使用http'); + return; + } + } + const payload = { ...values, client_type: clientType, @@ -118,8 +229,8 @@ const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => { formApi.reset(); } setClientType('confidential'); - setGrantTypes(['client_credentials']); - setRedirectUris(['']); + setGrantTypes(computeDefaultGrantTypes('confidential', allowedGrantTypes)); + setRedirectUris([]); }; // 处理取消 @@ -149,9 +260,13 @@ const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => { const handleGrantTypesChange = (values) => { setGrantTypes(values); // 如果包含authorization_code但没有重定向URI,则添加一个 - if (values.includes('authorization_code') && redirectUris.length === 1 && !redirectUris[0]) { + if (values.includes('authorization_code') && redirectUris.length === 0) { setRedirectUris(['']); } + // 公开客户端不允许client_credentials + if (clientType === 'public' && values.includes('client_credentials')) { + setGrantTypes(values.filter((v) => v !== 'client_credentials')); + } }; return ( @@ -159,7 +274,7 @@ const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => { title="创建OAuth2客户端" visible={visible} onCancel={handleCancel} - onOk={() => formApi?.submit()} + onOk={() => formApi?.submitForm()} okText="创建" cancelText="取消" confirmLoading={loading} @@ -168,6 +283,12 @@ const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => { >
setFormApi(api)} + initValues={{ + // 表单默认值优化:预置 OIDC 常用 scope + scopes: ['openid', 'profile', 'email', 'api:read'], + require_pkce: true, + grant_types: grantTypes, + }} onSubmit={handleSubmit} labelPosition="top" > @@ -237,9 +358,15 @@ const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => { onChange={handleGrantTypesChange} rules={[{ required: true, message: '请选择至少一种授权类型' }]} > - - - + + + {/* Scope */} @@ -247,9 +374,11 @@ const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => { field="scopes" label="允许的权限范围(Scope)" multiple - defaultValue={['api:read']} rules={[{ required: true, message: '请选择至少一个权限范围' }]} > + + + @@ -259,20 +388,19 @@ const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => { PKCE(Proof Key for Code Exchange)可提高授权码流程的安全性。 {/* 重定向URI */} - {grantTypes.includes('authorization_code') && ( + {(grantTypes.includes('authorization_code') || redirectUris.length > 0) && ( <> 重定向URI配置
重定向URI - 用于授权码流程,用户授权后将重定向到这些URI。必须使用HTTPS(本地开发可使用HTTP)。 + 用于授权码流程,用户授权后将重定向到这些URI。必须使用HTTPS(本地开发可使用HTTP,仅限localhost/127.0.0.1)。 @@ -315,4 +443,4 @@ const CreateOAuth2ClientModal = ({ visible, onCancel, onSuccess }) => { ); }; -export default CreateOAuth2ClientModal; \ No newline at end of file +export default CreateOAuth2ClientModal; diff --git a/web/src/components/modals/oauth2/EditOAuth2ClientModal.jsx b/web/src/components/modals/oauth2/EditOAuth2ClientModal.jsx index 39729bba9..7eec45e3d 100644 --- a/web/src/components/modals/oauth2/EditOAuth2ClientModal.jsx +++ b/web/src/components/modals/oauth2/EditOAuth2ClientModal.jsx @@ -39,8 +39,39 @@ const { Option } = Select; const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => { const [formApi, setFormApi] = useState(null); const [loading, setLoading] = useState(false); - const [redirectUris, setRedirectUris] = useState(['']); + const [redirectUris, setRedirectUris] = useState([]); const [grantTypes, setGrantTypes] = useState(['client_credentials']); + const [allowedGrantTypes, setAllowedGrantTypes] = useState([ + 'client_credentials', + 'authorization_code', + 'refresh_token', + ]); + + // 加载后端允许的授权类型 + useEffect(() => { + let mounted = true; + (async () => { + try { + const res = await API.get('/api/option/'); + const { success, data } = res.data || {}; + if (!success || !Array.isArray(data)) return; + const found = data.find((i) => i.key === 'oauth2.allowed_grant_types'); + if (!found) return; + let parsed = []; + try { + parsed = JSON.parse(found.value || '[]'); + } catch (_) {} + if (mounted && Array.isArray(parsed) && parsed.length) { + setAllowedGrantTypes(parsed); + } + } catch (_) { + // 忽略错误 + } + })(); + return () => { + mounted = false; + }; + }, []); // 初始化表单数据 useEffect(() => { @@ -60,9 +91,12 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => { } else if (Array.isArray(client.scopes)) { parsedScopes = client.scopes; } + if (!parsedScopes || parsedScopes.length === 0) { + parsedScopes = ['openid', 'profile', 'email', 'api:read']; + } // 解析重定向URI - let parsedRedirectUris = ['']; + let parsedRedirectUris = []; if (client.redirect_uris) { try { const parsed = typeof client.redirect_uris === 'string' @@ -76,8 +110,20 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => { } } - setGrantTypes(parsedGrantTypes); - setRedirectUris(parsedRedirectUris); + // 过滤不被允许或不兼容的授权类型 + const filteredGrantTypes = (parsedGrantTypes || []).filter((g) => + allowedGrantTypes.includes(g), + ); + const finalGrantTypes = client.client_type === 'public' + ? filteredGrantTypes.filter((g) => g !== 'client_credentials') + : filteredGrantTypes; + + setGrantTypes(finalGrantTypes); + if (finalGrantTypes.includes('authorization_code') && parsedRedirectUris.length === 0) { + setRedirectUris(['']); + } else { + setRedirectUris(parsedRedirectUris); + } // 设置表单值 const formValues = { @@ -87,7 +133,7 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => { client_type: client.client_type, grant_types: parsedGrantTypes, scopes: parsedScopes, - require_pkce: client.require_pkce, + require_pkce: !!client.require_pkce, status: client.status, }; if (formApi) { @@ -101,7 +147,57 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => { setLoading(true); try { // 过滤空的重定向URI - const validRedirectUris = redirectUris.filter(uri => uri.trim()); + const validRedirectUris = redirectUris + .map((u) => (u || '').trim()) + .filter((u) => u.length > 0); + + // 校验授权类型 + if (!grantTypes.length) { + showError('请至少选择一种授权类型'); + setLoading(false); + return; + } + const invalids = grantTypes.filter((g) => !allowedGrantTypes.includes(g)); + if (invalids.length) { + showError(`不被允许的授权类型: ${invalids.join(', ')}`); + setLoading(false); + return; + } + if (client?.client_type === 'public' && grantTypes.includes('client_credentials')) { + showError('公开客户端不允许使用client_credentials授权类型'); + setLoading(false); + return; + } + // 授权码需要有效重定向URI + const isValidRedirectUri = (uri) => { + if (!uri || !uri.trim()) return false; + try { + const u = new URL(uri.trim()); + if (u.protocol !== 'https:' && u.protocol !== 'http:') return false; + if (u.protocol === 'http:') { + const host = u.hostname; + const isLocal = + host === 'localhost' || host === '127.0.0.1' || host.endsWith('.local'); + if (!isLocal) return false; + } + return true; + } catch (e) { + return false; + } + }; + if (grantTypes.includes('authorization_code')) { + if (!validRedirectUris.length) { + showError('选择授权码授权类型时,必须填写至少一个重定向URI'); + setLoading(false); + return; + } + const allValid = validRedirectUris.every(isValidRedirectUri); + if (!allValid) { + showError('重定向URI格式不合法:仅支持https,或本地开发使用http'); + setLoading(false); + return; + } + } const payload = { ...values, @@ -146,9 +242,13 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => { const handleGrantTypesChange = (values) => { setGrantTypes(values); // 如果包含authorization_code但没有重定向URI,则添加一个 - if (values.includes('authorization_code') && redirectUris.length === 1 && !redirectUris[0]) { + if (values.includes('authorization_code') && redirectUris.length === 0) { setRedirectUris(['']); } + // 公开客户端不允许client_credentials + if (client?.client_type === 'public' && values.includes('client_credentials')) { + setGrantTypes(values.filter((v) => v !== 'client_credentials')); + } }; if (!client) return null; @@ -158,7 +258,7 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => { title={`编辑OAuth2客户端 - ${client.name}`} visible={visible} onCancel={onCancel} - onOk={() => formApi?.submit()} + onOk={() => formApi?.submitForm()} okText="保存" cancelText="取消" confirmLoading={loading} @@ -217,9 +317,17 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => { onChange={handleGrantTypesChange} rules={[{ required: true, message: '请选择至少一种授权类型' }]} > - - - + + + {/* Scope */} @@ -229,6 +337,9 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => { multiple rules={[{ required: true, message: '请选择至少一个权限范围' }]} > + + + @@ -254,13 +365,13 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => { {/* 重定向URI */} - {grantTypes.includes('authorization_code') && ( + {(grantTypes.includes('authorization_code') || redirectUris.length > 0) && ( <> 重定向URI配置
重定向URI - 用于授权码流程,用户授权后将重定向到这些URI。必须使用HTTPS(本地开发可使用HTTP)。 + 用于授权码流程,用户授权后将重定向到这些URI。必须使用HTTPS(本地开发可使用HTTP,仅限localhost/127.0.0.1)。 @@ -303,4 +414,4 @@ const EditOAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => { ); }; -export default EditOAuth2ClientModal; \ No newline at end of file +export default EditOAuth2ClientModal; diff --git a/web/src/components/modals/oauth2/JWKSManagerModal.jsx b/web/src/components/modals/oauth2/JWKSManagerModal.jsx new file mode 100644 index 000000000..ef5d3c5c6 --- /dev/null +++ b/web/src/components/modals/oauth2/JWKSManagerModal.jsx @@ -0,0 +1,148 @@ +import React, { useEffect, useState } from 'react'; +import { Modal, Table, Button, Space, Tag, Typography, Popconfirm, Toast, Form, TextArea, Divider, Input } from '@douyinfe/semi-ui'; +import { IconRefresh, IconDelete, IconPlay } from '@douyinfe/semi-icons'; +import { API, showError, showSuccess } from '../../../helpers'; + +const { Text } = Typography; + +export default function JWKSManagerModal({ visible, onClose }) { + const [loading, setLoading] = useState(false); + const [keys, setKeys] = useState([]); + + const load = async () => { + setLoading(true); + try { + const res = await API.get('/api/oauth/keys'); + if (res?.data?.success) setKeys(res.data.data || []); + else showError(res?.data?.message || '获取密钥列表失败'); + } catch { showError('获取密钥列表失败'); } finally { setLoading(false); } + }; + + const rotate = async () => { + setLoading(true); + try { + const res = await API.post('/api/oauth/keys/rotate', {}); + if (res?.data?.success) { showSuccess('签名密钥已轮换:' + res.data.kid); await load(); } + else showError(res?.data?.message || '密钥轮换失败'); + } catch { showError('密钥轮换失败'); } finally { setLoading(false); } + }; + + const del = async (kid) => { + setLoading(true); + try { + const res = await API.delete(`/api/oauth/keys/${kid}`); + if (res?.data?.success) { Toast.success('已删除:' + kid); await load(); } + else showError(res?.data?.message || '删除失败'); + } catch { showError('删除失败'); } finally { setLoading(false); } + }; + + useEffect(() => { if (visible) load(); }, [visible]); + useEffect(() => { + if (!visible) return; + (async ()=>{ + try{ + const res = await API.get('/api/oauth/server-info'); + const p = res?.data?.default_private_key_path; + if (p) setGenPath(p); + }catch{} + })(); + }, [visible]); + + // Import PEM state + const [showImport, setShowImport] = useState(false); + const [pem, setPem] = useState(''); + const [customKid, setCustomKid] = useState(''); + const importPem = async () => { + if (!pem.trim()) return Toast.warning('请粘贴 PEM 私钥'); + setLoading(true); + try { + const res = await API.post('/api/oauth/keys/import_pem', { pem, kid: customKid.trim() }); + if (res?.data?.success) { + Toast.success('已导入私钥并切换到 kid=' + res.data.kid); + setPem(''); setCustomKid(''); setShowImport(false); + await load(); + } else { + Toast.error(res?.data?.message || '导入失败'); + } + } catch { Toast.error('导入失败'); } finally { setLoading(false); } + }; + + // Generate PEM file state + const [showGenerate, setShowGenerate] = useState(false); + const [genPath, setGenPath] = useState('/etc/new-api/oauth2-private.pem'); + const [genKid, setGenKid] = useState(''); + const generatePemFile = async () => { + if (!genPath.trim()) return Toast.warning('请填写保存路径'); + setLoading(true); + try { + const res = await API.post('/api/oauth/keys/generate_file', { path: genPath.trim(), kid: genKid.trim() }); + if (res?.data?.success) { + Toast.success('已生成并生效:' + res.data.path); + await load(); + } else { + Toast.error(res?.data?.message || '生成失败'); + } + } catch { Toast.error('生成失败'); } finally { setLoading(false); } + }; + + const columns = [ + { title: 'KID', dataIndex: 'kid', render: (kid) => {kid} }, + { title: '创建时间', dataIndex: 'created_at', render: (ts) => (ts ? new Date(ts * 1000).toLocaleString() : '-') }, + { title: '状态', dataIndex: 'current', render: (cur) => (cur ? 当前 : 历史) }, + { title: '操作', render: (_, r) => ( + + {!r.current && ( + del(r.kid)}> + + + )} + + ) }, + ]; + + return ( + + + + + + + + + {showGenerate && ( +
+ + + + +
+ +
+ + 建议:仅在合规要求下使用文件私钥。请确保目录权限安全(建议 0600),并妥善备份。 +
+ )} + {showImport && ( +
+
+ + + +
+ +
+ + 建议:优先使用内存签名密钥与 JWKS 轮换;仅在有合规要求时导入外部私钥。 +
+ )} + 暂无密钥} /> + + ); +} diff --git a/web/src/components/modals/oauth2/OAuth2QuickStartModal.jsx b/web/src/components/modals/oauth2/OAuth2QuickStartModal.jsx new file mode 100644 index 000000000..91559c580 --- /dev/null +++ b/web/src/components/modals/oauth2/OAuth2QuickStartModal.jsx @@ -0,0 +1,230 @@ +import React, { useEffect, useMemo, useState } from 'react'; +import { Modal, Steps, Form, Input, Select, Switch, Typography, Space, Button, Tag, Toast } from '@douyinfe/semi-ui'; +import { API, showError, showSuccess } from '../../../helpers'; + +const { Text } = Typography; + +export default function OAuth2QuickStartModal({ visible, onClose, onDone }) { + const origin = useMemo(() => window.location.origin, []); + const [step, setStep] = useState(0); + const [loading, setLoading] = useState(false); + + // Step state + const [enableOAuth, setEnableOAuth] = useState(true); + const [issuer, setIssuer] = useState(origin); + + const [clientType, setClientType] = useState('public'); + const [redirect1, setRedirect1] = useState(origin + '/oauth/oidc'); + const [redirect2, setRedirect2] = useState(''); + const [scopes, setScopes] = useState(['openid', 'profile', 'email', 'api:read']); + + // Results + const [createdClient, setCreatedClient] = useState(null); + + useEffect(() => { + if (!visible) { + setStep(0); + setLoading(false); + setEnableOAuth(true); + setIssuer(origin); + setClientType('public'); + setRedirect1(origin + '/oauth/oidc'); + setRedirect2(''); + setScopes(['openid', 'profile', 'email', 'api:read']); + setCreatedClient(null); + } + }, [visible, origin]); + + // 打开时读取现有配置作为默认值 + useEffect(() => { + if (!visible) return; + (async () => { + try { + const res = await API.get('/api/option/'); + const { success, data } = res.data || {}; + if (!success || !Array.isArray(data)) return; + const map = Object.fromEntries(data.map(i => [i.key, i.value])); + if (typeof map['oauth2.enabled'] !== 'undefined') { + setEnableOAuth(String(map['oauth2.enabled']).toLowerCase() === 'true'); + } + if (map['oauth2.issuer']) { + setIssuer(map['oauth2.issuer']); + } + } catch (_) {} + })(); + }, [visible]); + + const applyRecommended = async () => { + setLoading(true); + try { + const ops = [ + { key: 'oauth2.enabled', value: String(enableOAuth) }, + { key: 'oauth2.issuer', value: issuer || '' }, + { key: 'oauth2.allowed_grant_types', value: JSON.stringify(['authorization_code', 'refresh_token', 'client_credentials']) }, + { key: 'oauth2.require_pkce', value: 'true' }, + { key: 'oauth2.jwt_signing_algorithm', value: 'RS256' }, + ]; + for (const op of ops) { + await API.put('/api/option/', op); + } + showSuccess('已应用推荐配置'); + setStep(1); + onDone && onDone(); + } catch (e) { + showError('应用推荐配置失败'); + } finally { + setLoading(false); + } + }; + + const rotateKey = async () => { + setLoading(true); + try { + const res = await API.post('/api/oauth/keys/rotate', {}); + if (res?.data?.success) { + showSuccess('签名密钥已准备:' + res.data.kid); + } else { + showError(res?.data?.message || '签名密钥操作失败'); + return; + } + setStep(2); + } catch (e) { + showError('签名密钥操作失败'); + } finally { + setLoading(false); + } + }; + + const createClient = async () => { + setLoading(true); + try { + const grant_types = clientType === 'public' ? ['authorization_code', 'refresh_token'] : ['authorization_code', 'refresh_token', 'client_credentials']; + const payload = { + name: 'Default OIDC Client', + client_type: clientType, + grant_types, + redirect_uris: [redirect1, redirect2].filter(Boolean), + scopes, + require_pkce: true, + }; + const res = await API.post('/api/oauth_clients/', payload); + if (res?.data?.success) { + setCreatedClient({ id: res.data.client_id, secret: res.data.client_secret }); + showSuccess('客户端已创建'); + setStep(3); + } else { + showError(res?.data?.message || '创建失败'); + } + onDone && onDone(); + } catch (e) { + showError('创建失败'); + } finally { + setLoading(false); + } + }; + + const steps = [ + { + title: '应用推荐配置', + content: ( +
+
+
+
+ + + +
+
+ 说明 +
+ grant_types: auth_code / refresh_token / client_credentials + PKCE: S256 + 算法: RS256 +
+
+
+
+ +
+
+ ) + }, + { + title: '准备签名密钥', + content: ( +
+ 若无密钥则初始化;如已存在建议立即轮换以生成新的 kid 并发布到 JWKS。 +
+ +
+
+ ) + }, + { + title: '创建默认 OIDC 客户端', + content: ( +
+
+ + 公开客户端(SPA/移动端) + 机密客户端(服务端) + + + + + openid + profile + email + api:read + api:write + admin + + +
+ +
+
+ ) + }, + { + title: '完成', + content: ( +
+ {createdClient ? ( +
+ 客户端已创建: +
+ Client ID: {createdClient.id} +
+ {createdClient.secret && ( +
+ Client Secret(仅此一次展示): {createdClient.secret} +
+ )} +
+ ) : 已完成初始化。} +
+ ) + } + ]; + + return ( + + + {steps.map((s, idx) => )} + +
+ {steps[step].content} +
+
+ ); +} diff --git a/web/src/components/modals/oauth2/OAuth2ToolsModal.jsx b/web/src/components/modals/oauth2/OAuth2ToolsModal.jsx new file mode 100644 index 000000000..515954bc9 --- /dev/null +++ b/web/src/components/modals/oauth2/OAuth2ToolsModal.jsx @@ -0,0 +1,324 @@ +import React, { useEffect, useMemo, useState } from 'react'; +import { Modal, Form, Input, Button, Space, Select, Typography, Divider, Toast, TextArea } from '@douyinfe/semi-ui'; +import { API } from '../../../helpers'; + +const { Text } = Typography; + +async function sha256Base64Url(input) { + const enc = new TextEncoder(); + const data = enc.encode(input); + const hash = await crypto.subtle.digest('SHA-256', data); + const bytes = new Uint8Array(hash); + let binary = ''; + for (let i = 0; i < bytes.byteLength; i++) binary += String.fromCharCode(bytes[i]); + return btoa(binary).replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/, ''); +} + +function randomString(len = 43) { + const charset = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~'; + let res = ''; + const array = new Uint32Array(len); + crypto.getRandomValues(array); + for (let i = 0; i < len; i++) res += charset[array[i] % charset.length]; + return res; +} + +export default function OAuth2ToolsModal({ visible, onClose }) { + const [server, setServer] = useState({}); + const [authURL, setAuthURL] = useState(''); + const [issuer, setIssuer] = useState(''); + const [confJSON, setConfJSON] = useState(''); + const [userinfoEndpoint, setUserinfoEndpoint] = useState(''); + const [code, setCode] = useState(''); + const [accessToken, setAccessToken] = useState(''); + const [idToken, setIdToken] = useState(''); + const [refreshToken, setRefreshToken] = useState(''); + const [tokenRaw, setTokenRaw] = useState(''); + const [jwtClaims, setJwtClaims] = useState(''); + const [userinfoOut, setUserinfoOut] = useState(''); + const [values, setValues] = useState({ + authorization_endpoint: '', + token_endpoint: '', + client_id: '', + client_secret: '', + redirect_uri: window.location.origin + '/oauth/oidc', + scope: 'openid profile email', + response_type: 'code', + code_verifier: '', + code_challenge: '', + code_challenge_method: 'S256', + state: '', + nonce: '', + }); + + useEffect(() => { + if (!visible) return; + (async () => { + try { + const res = await API.get('/api/oauth/server-info'); + if (res?.data) { + const d = res.data; + setServer(d); + setValues((v) => ({ + ...v, + authorization_endpoint: d.authorization_endpoint, + token_endpoint: d.token_endpoint, + })); + setIssuer(d.issuer || ''); + setUserinfoEndpoint(d.userinfo_endpoint || ''); + } + } catch {} + })(); + }, [visible]); + + const buildAuthorizeURL = () => { + const u = new URL(values.authorization_endpoint || (server.issuer + '/api/oauth/authorize')); + const rt = values.response_type || 'code'; + u.searchParams.set('response_type', rt); + u.searchParams.set('client_id', values.client_id); + u.searchParams.set('redirect_uri', values.redirect_uri); + u.searchParams.set('scope', values.scope); + if (values.state) u.searchParams.set('state', values.state); + if (values.nonce) u.searchParams.set('nonce', values.nonce); + if (rt === 'code' && values.code_challenge) { + u.searchParams.set('code_challenge', values.code_challenge); + u.searchParams.set('code_challenge_method', values.code_challenge_method || 'S256'); + } + return u.toString(); + }; + + const copy = async (text, tip = '已复制') => { + try { await navigator.clipboard.writeText(text); Toast.success(tip); } catch {} + }; + + const genVerifier = async () => { + const v = randomString(64); + const c = await sha256Base64Url(v); + setValues((val) => ({ ...val, code_verifier: v, code_challenge: c })); + }; + + const discover = async () => { + const iss = (issuer || '').trim(); + if (!iss) { Toast.warning('请填写 Issuer'); return; } + try { + const url = iss.replace(/\/$/, '') + '/api/.well-known/openid-configuration'; + const res = await fetch(url); + const d = await res.json(); + setValues((v)=>({ + ...v, + authorization_endpoint: d.authorization_endpoint || v.authorization_endpoint, + token_endpoint: d.token_endpoint || v.token_endpoint, + })); + setUserinfoEndpoint(d.userinfo_endpoint || ''); + setIssuer(d.issuer || iss); + setConfJSON(JSON.stringify(d, null, 2)); + Toast.success('已从发现文档加载端点'); + } catch (e) { + Toast.error('自动发现失败'); + } + }; + + const parseConf = () => { + try { + const d = JSON.parse(confJSON || '{}'); + if (d.issuer) setIssuer(d.issuer); + if (d.authorization_endpoint) setValues((v)=>({...v, authorization_endpoint: d.authorization_endpoint})); + if (d.token_endpoint) setValues((v)=>({...v, token_endpoint: d.token_endpoint})); + if (d.userinfo_endpoint) setUserinfoEndpoint(d.userinfo_endpoint); + Toast.success('已解析配置并填充端点'); + } catch (e) { + Toast.error('解析失败:' + e.message); + } + }; + + const genConf = () => { + const d = { + issuer: issuer || undefined, + authorization_endpoint: values.authorization_endpoint || undefined, + token_endpoint: values.token_endpoint || undefined, + userinfo_endpoint: userinfoEndpoint || undefined, + }; + setConfJSON(JSON.stringify(d, null, 2)); + }; + + async function postForm(url, data, basicAuth) { + const body = Object.entries(data) + .filter(([_, v]) => v !== undefined && v !== null) + .map(([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(String(v))}`) + .join('&'); + const headers = { 'Content-Type': 'application/x-www-form-urlencoded' }; + if (basicAuth) headers['Authorization'] = 'Basic ' + btoa(`${basicAuth.id}:${basicAuth.secret}`); + const res = await fetch(url, { method: 'POST', headers, body }); + if (!res.ok) { + const t = await res.text(); + throw new Error(`HTTP ${res.status} ${t}`); + } + return res.json(); + } + + const exchangeCode = async () => { + try { + const basic = values.client_secret ? { id: values.client_id, secret: values.client_secret } : undefined; + const data = await postForm(values.token_endpoint, { + grant_type: 'authorization_code', + code: code.trim(), + client_id: values.client_id, + redirect_uri: values.redirect_uri, + code_verifier: values.code_verifier, + }, basic); + setAccessToken(data.access_token || ''); + setIdToken(data.id_token || ''); + setRefreshToken(data.refresh_token || ''); + setTokenRaw(JSON.stringify(data, null, 2)); + Toast.success('已获取令牌'); + } catch (e) { + Toast.error('兑换失败:' + e.message); + } + }; + + const decodeIdToken = () => { + const t = (idToken || '').trim(); + if (!t) { setJwtClaims('(空)'); return; } + const parts = t.split('.'); + if (parts.length < 2) { setJwtClaims('格式错误'); return; } + try { + const json = JSON.parse(atob(parts[1].replace(/-/g,'+').replace(/_/g,'/'))); + setJwtClaims(JSON.stringify(json, null, 2)); + } catch (e) { + setJwtClaims('解码失败:' + e); + } + }; + + const callUserInfo = async () => { + if (!accessToken || !userinfoEndpoint) { Toast.warning('缺少 AccessToken 或 UserInfo 端点'); return; } + try { + const res = await fetch(userinfoEndpoint, { headers: { Authorization: 'Bearer ' + accessToken } }); + const data = await res.json(); + setUserinfoOut(JSON.stringify(data, null, 2)); + } catch (e) { + setUserinfoOut('调用失败:' + e); + } + }; + + const doRefresh = async () => { + if (!refreshToken) { Toast.warning('没有刷新令牌'); return; } + try { + const basic = values.client_secret ? { id: values.client_id, secret: values.client_secret } : undefined; + const data = await postForm(values.token_endpoint, { + grant_type: 'refresh_token', + refresh_token: refreshToken, + client_id: values.client_id, + }, basic); + setAccessToken(data.access_token || ''); + setIdToken(data.id_token || ''); + setRefreshToken(data.refresh_token || ''); + setTokenRaw(JSON.stringify(data, null, 2)); + Toast.success('刷新成功'); + } catch (e) { + Toast.error('刷新失败:' + e.message); + } + }; + + return ( + 关闭} + width={720} + style={{ top: 48 }} + > + {/* Discovery */} + OIDC 发现 +
+ + + + + + + +