// Files
// new-api/middleware/oauth_jwt.go
// 2025-09-16 17:10:01 +08:00
//
// 292 lines
// 7.4 KiB
// Go
// Raw Blame History
//
// This file contains ambiguous Unicode characters
// This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"crypto/rsa"
"fmt"
"net/http"
"one-api/common"
"one-api/model"
"one-api/setting/system_setting"
"one-api/src/oauth"
"strings"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
)
// OAuthJWTAuth returns a Gin middleware that authenticates requests carrying
// an OAuth2 JWT bearer token.
//
// The request is passed to the next handler untouched when OAuth2 is disabled
// in the system settings, or when the Authorization header is absent or does
// not start with "Bearer ". Otherwise the token must pass validateOAuthJWT and
// then validateOAuthClaims; a failure aborts the request through
// abortWithOAuthError with the "invalid_token" code, and success publishes the
// claims on the context via setOAuthContext before continuing.
//
// NOTE(review): any non-JWT credential sent with the Bearer scheme is rejected
// here rather than passed on — confirm that is intended for the routes this
// middleware is mounted on.
func OAuthJWTAuth() gin.HandlerFunc {
	const bearerPrefix = "Bearer "

	return func(c *gin.Context) {
		// Nothing to enforce while the OAuth2 feature is switched off.
		if !system_setting.GetOAuth2Settings().Enabled {
			c.Next()
			return
		}

		// An empty header fails the prefix test too, so this single check lets
		// both "no credentials" and "some other scheme" fall through to the
		// next middleware.
		header := c.GetHeader("Authorization")
		if !strings.HasPrefix(header, bearerPrefix) {
			c.Next()
			return
		}

		rawToken := header[len(bearerPrefix):]
		if rawToken == "" {
			abortWithOAuthError(c, "invalid_token", "Missing token")
			return
		}

		// Token-level validation first, claim-level validation second; the
		// first failure of either step is reported to the client.
		claims, err := validateOAuthJWT(rawToken)
		if err == nil {
			err = validateOAuthClaims(claims)
		}
		if err != nil {
			abortWithOAuthError(c, "invalid_token", err.Error())
			return
		}

		setOAuthContext(c, claims)
		c.Next()
	}
}
// validateOAuthJWT parses tokenString with jwt.Parse and returns its claims.
//
// The key-lookup callback accepts only RSA signing methods, requires a string
// "kid" header, and resolves the key through getPublicKeyByKid. An error is
// returned when parsing fails, when the parsed token is not marked valid, or
// when its claims are not a jwt.MapClaims.
func validateOAuthJWT(tokenString string) (jwt.MapClaims, error) {
	// keyFor supplies the key jwt.Parse uses for this particular token.
	keyFor := func(t *jwt.Token) (interface{}, error) {
		// Refuse anything that is not an RSA-family algorithm.
		if _, isRSA := t.Method.(*jwt.SigningMethodRSA); !isRSA {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}

		keyID, hasKid := t.Header["kid"].(string)
		if !hasKid {
			return nil, fmt.Errorf("missing kid in token header")
		}

		key, keyErr := getPublicKeyByKid(keyID)
		if keyErr != nil {
			return nil, fmt.Errorf("failed to get public key: %w", keyErr)
		}
		return key, nil
	}

	parsed, err := jwt.Parse(tokenString, keyFor)
	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to parse token: %w", err)
	case !parsed.Valid:
		return nil, fmt.Errorf("invalid token")
	}

	if mapClaims, isMap := parsed.Claims.(jwt.MapClaims); isMap {
		return mapClaims, nil
	}
	return nil, fmt.Errorf("invalid token claims")
}
// getPublicKeyByKid returns the RSA public key registered under kid by
// delegating to oauth.GetPublicKeyByKid. A nil result from that lookup is
// reported as an "unknown kid" error.
//
// Author's note: this is a deliberately simple first version; a complete
// implementation may need to obtain the keys from the OAuth server's JWKS.
// TODO: implement JWKS caching and refresh.
func getPublicKeyByKid(kid string) (*rsa.PublicKey, error) {
	if key := oauth.GetPublicKeyByKid(kid); key != nil {
		return key, nil
	}
	return nil, fmt.Errorf("unknown kid: %s", kid)
}
// validateOAuthClaims applies the claim-level checks to an already parsed
// token and returns nil only when all of them pass.
//
// Checks, in order:
//   - "iss" must be a string; when an Issuer is configured in the OAuth2
//     settings it must match exactly.
//   - "client_id" must be a string that model.GetOAuthClientByID resolves to
//     a client whose Status equals common.UserStatusEnabled.
//   - when a non-empty string "jti" is present, the token must not be
//     reported as revoked by model.IsTokenRevoked.
//
// The returned error messages are sent to the client by the calling
// middleware, so they are kept generic and never wrap internal errors.
func validateOAuthClaims(claims jwt.MapClaims) error {
	settings := system_setting.GetOAuth2Settings()

	// Issuer: always required, strictly compared only when one is configured.
	issuer, hasIssuer := claims["iss"].(string)
	if !hasIssuer {
		return fmt.Errorf("missing issuer claim")
	}
	if settings.Issuer != "" && issuer != settings.Issuer {
		return fmt.Errorf("invalid issuer")
	}

	// TODO: validate the "aud" (audience) claim.

	// Client: must exist and be enabled.
	clientID, hasClient := claims["client_id"].(string)
	if !hasClient {
		return fmt.Errorf("missing client_id claim")
	}
	client, err := model.GetOAuthClientByID(clientID)
	if err != nil {
		return fmt.Errorf("invalid client")
	}
	if client.Status != common.UserStatusEnabled {
		return fmt.Errorf("client disabled")
	}

	// Revocation: only checkable when the token carries an identifier.
	if jti, hasJTI := claims["jti"].(string); hasJTI && jti != "" {
		revoked, checkErr := model.IsTokenRevoked(jti)
		if checkErr != nil {
			// Fail closed: if the lookup itself fails we cannot tell whether
			// the token was revoked, so it must not be accepted. The cause is
			// not wrapped because this message reaches the client.
			return fmt.Errorf("token revocation check failed")
		}
		if revoked {
			return fmt.Errorf("token revoked")
		}
	}

	return nil
}
// setOAuthContext publishes the outcome of a successful OAuth2 validation on
// the Gin context.
//
// It always stores the full claim set under "oauth_claims" and sets
// "oauth_authenticated" to true. Each of the claims listed below is copied to
// its own context key only when it is present as a string.
func setOAuthContext(c *gin.Context, claims jwt.MapClaims) {
	c.Set("oauth_claims", claims)
	c.Set("oauth_authenticated", true)

	// Claim name -> context key. Per the original author's note, "sub" is the
	// client ID for the client_credentials flow and the user ID for the
	// authorization_code flow.
	stringClaims := [...]struct{ claim, contextKey string }{
		{"client_id", "oauth_client_id"},
		{"scope", "oauth_scope"},
		{"sub", "oauth_subject"},
		{"grant_type", "oauth_grant_type"},
	}
	for _, entry := range stringClaims {
		if value, isString := claims[entry.claim].(string); isString {
			c.Set(entry.contextKey, value)
		}
	}
}
// abortWithOAuthError aborts the request with an OAuth2 bearer-token error.
//
// It sets a WWW-Authenticate challenge carrying errorCode and description,
// writes a 401 JSON body with the "error" and "error_description" fields, and
// calls c.Abort so that no later handler runs.
//
// description can embed text taken from the client's token (validation errors
// echo the "alg" and "kid" header values), so the copies placed inside the
// quoted header parameters are sanitized first. The JSON body keeps the
// original strings; the JSON encoder escapes them on its own.
//
// NOTE(review): RFC 6750 section 3.1 pairs "insufficient_scope" with 403, not
// 401. The status is left unchanged here because existing clients may depend
// on it — confirm before changing.
func abortWithOAuthError(c *gin.Context, errorCode, description string) {
	challenge := fmt.Sprintf(
		`Bearer error="%s", error_description="%s"`,
		sanitizeOAuthHeaderParam(errorCode),
		sanitizeOAuthHeaderParam(description),
	)
	c.Header("WWW-Authenticate", challenge)
	c.JSON(http.StatusUnauthorized, gin.H{
		"error":             errorCode,
		"error_description": description,
	})
	c.Abort()
}

// sanitizeOAuthHeaderParam makes value safe to embed in a quoted
// WWW-Authenticate parameter. RFC 6750 section 3 limits the error and
// error_description values to %x20-21 / %x23-5B / %x5D-7E, so a double quote
// becomes a single quote and every other character outside that set
// (backslash, control characters, non-ASCII) becomes '?'.
func sanitizeOAuthHeaderParam(value string) string {
	return strings.Map(func(r rune) rune {
		switch {
		case r == '"':
			return '\''
		case r == '\\', r < 0x20, r > 0x7e:
			return '?'
		}
		return r
	}, value)
}
// RequireOAuthScope returns a Gin middleware that lets a request through only
// when it was OAuth-authenticated and its token grants requiredScope.
//
// It reads the "oauth_authenticated" flag and the "oauth_scope" value from the
// context. The scope value is split on single spaces and each piece is
// trimmed before comparison. Every rejection goes through abortWithOAuthError
// with the "insufficient_scope" code.
func RequireOAuthScope(requiredScope string) gin.HandlerFunc {
	return func(c *gin.Context) {
		// All failure paths share the same error code; only the text varies.
		reject := func(reason string) {
			abortWithOAuthError(c, "insufficient_scope", reason)
		}

		if authenticated := c.GetBool("oauth_authenticated"); !authenticated {
			reject("OAuth2 authentication required")
			return
		}

		stored, found := c.Get("oauth_scope")
		if !found {
			reject("No scope in token")
			return
		}
		granted, isString := stored.(string)
		if !isString {
			reject("Invalid scope format")
			return
		}

		// Look for an exact match among the space-separated scopes.
		matched := false
		for _, candidate := range strings.Split(granted, " ") {
			if strings.TrimSpace(candidate) == requiredScope {
				matched = true
				break
			}
		}
		if !matched {
			reject(fmt.Sprintf("Required scope: %s", requiredScope))
			return
		}

		c.Next()
	}
}
// OptionalOAuthAuth returns a Gin middleware that attempts OAuth2 JWT
// authentication but never blocks the request.
//
// When OAuth2 is enabled and the Authorization header carries a Bearer token
// that passes both validateOAuthJWT and validateOAuthClaims, the claims are
// published on the context via setOAuthContext. In every other case the
// context is left untouched. The next handler always runs.
//
// The Enabled gate mirrors OAuthJWTAuth so that switching OAuth2 off disables
// token-based authentication on optional routes as well.
func OptionalOAuthAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		if system_setting.GetOAuth2Settings().Enabled {
			authHeader := c.GetHeader("Authorization")
			// HasPrefix is false for an empty header, so no separate
			// emptiness test is needed.
			if strings.HasPrefix(authHeader, "Bearer ") {
				tokenString := strings.TrimPrefix(authHeader, "Bearer ")
				// Validation errors are intentionally dropped: this
				// middleware is best-effort and must not reject the request.
				if claims, err := validateOAuthJWT(tokenString); err == nil && validateOAuthClaims(claims) == nil {
					setOAuthContext(c, claims)
				}
			}
		}
		c.Next()
	}
}
// RequireOAuthScopeIfPresent returns a Gin middleware that enforces
// requiredScope only for requests that were OAuth-authenticated; any other
// request passes straight through.
//
// For authenticated requests the check is delegated to RequireOAuthScope, so
// the scope parsing and the "insufficient_scope" error responses are defined
// in exactly one place instead of being duplicated here.
func RequireOAuthScopeIfPresent(requiredScope string) gin.HandlerFunc {
	// Built once per route registration, not once per request.
	enforce := RequireOAuthScope(requiredScope)

	return func(c *gin.Context) {
		if !c.GetBool("oauth_authenticated") {
			c.Next()
			return
		}
		// enforce either calls c.Next or aborts with an OAuth error.
		enforce(c)
	}
}
// GetOAuthClaims returns the claims stored on the context under
// "oauth_claims". The boolean is false, and the claims nil, when nothing is
// stored under that key or the stored value is not a jwt.MapClaims.
func GetOAuthClaims(c *gin.Context) (jwt.MapClaims, bool) {
	// A missing key yields a nil value, and a type assertion on nil reports
	// (nil, false) — the same result as an explicit existence check.
	stored, _ := c.Get("oauth_claims")
	typed, isMapClaims := stored.(jwt.MapClaims)
	return typed, isMapClaims
}
// IsOAuthAuthenticated reports whether the "oauth_authenticated" flag on the
// context is set to true.
func IsOAuthAuthenticated(c *gin.Context) bool {
	const flagKey = "oauth_authenticated"
	return c.GetBool(flagKey)
}