mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-19 00:27:26 +00:00
292 lines
7.4 KiB
Go
292 lines
7.4 KiB
Go
package middleware
|
||
|
||
import (
|
||
"crypto/rsa"
|
||
"fmt"
|
||
"net/http"
|
||
"one-api/common"
|
||
"one-api/model"
|
||
"one-api/setting/system_setting"
|
||
"one-api/src/oauth"
|
||
"strings"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/golang-jwt/jwt/v5"
|
||
)
|
||
|
||
// OAuthJWTAuth OAuth2 JWT认证中间件
|
||
func OAuthJWTAuth() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 检查OAuth2是否启用
|
||
settings := system_setting.GetOAuth2Settings()
|
||
if !settings.Enabled {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
// 获取Authorization header
|
||
authHeader := c.GetHeader("Authorization")
|
||
if authHeader == "" {
|
||
c.Next() // 没有Authorization header,继续到下一个中间件
|
||
return
|
||
}
|
||
|
||
// 检查是否为Bearer token
|
||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||
c.Next() // 不是Bearer token,继续到下一个中间件
|
||
return
|
||
}
|
||
|
||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||
if tokenString == "" {
|
||
abortWithOAuthError(c, "invalid_token", "Missing token")
|
||
return
|
||
}
|
||
|
||
// 验证JWT token
|
||
claims, err := validateOAuthJWT(tokenString)
|
||
if err != nil {
|
||
abortWithOAuthError(c, "invalid_token", err.Error())
|
||
return
|
||
}
|
||
|
||
// 验证token的有效性
|
||
if err := validateOAuthClaims(claims); err != nil {
|
||
abortWithOAuthError(c, "invalid_token", err.Error())
|
||
return
|
||
}
|
||
|
||
// 设置上下文信息
|
||
setOAuthContext(c, claims)
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// validateOAuthJWT 验证OAuth2 JWT令牌
|
||
func validateOAuthJWT(tokenString string) (jwt.MapClaims, error) {
|
||
// 解析JWT而不验证签名(先获取header中的kid)
|
||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||
// 检查签名方法
|
||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||
}
|
||
|
||
// 获取kid
|
||
kid, ok := token.Header["kid"].(string)
|
||
if !ok {
|
||
return nil, fmt.Errorf("missing kid in token header")
|
||
}
|
||
|
||
// 根据kid获取公钥
|
||
publicKey, err := getPublicKeyByKid(kid)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get public key: %w", err)
|
||
}
|
||
|
||
return publicKey, nil
|
||
})
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||
}
|
||
|
||
if !token.Valid {
|
||
return nil, fmt.Errorf("invalid token")
|
||
}
|
||
|
||
claims, ok := token.Claims.(jwt.MapClaims)
|
||
if !ok {
|
||
return nil, fmt.Errorf("invalid token claims")
|
||
}
|
||
|
||
return claims, nil
|
||
}
|
||
|
||
// getPublicKeyByKid 根据kid获取公钥
|
||
func getPublicKeyByKid(kid string) (*rsa.PublicKey, error) {
|
||
// 这里需要从JWKS获取公钥
|
||
// 在实际实现中,你可能需要从OAuth server获取JWKS
|
||
// 这里先实现一个简单版本
|
||
|
||
// TODO: 实现JWKS缓存和刷新机制
|
||
pub := oauth.GetPublicKeyByKid(kid)
|
||
if pub == nil {
|
||
return nil, fmt.Errorf("unknown kid: %s", kid)
|
||
}
|
||
return pub, nil
|
||
}
|
||
|
||
// validateOAuthClaims 验证OAuth2 claims
|
||
func validateOAuthClaims(claims jwt.MapClaims) error {
|
||
settings := system_setting.GetOAuth2Settings()
|
||
|
||
// 验证issuer(若配置了 Issuer 则强校验,否则仅要求存在)
|
||
if iss, ok := claims["iss"].(string); ok {
|
||
if settings.Issuer != "" && iss != settings.Issuer {
|
||
return fmt.Errorf("invalid issuer")
|
||
}
|
||
} else {
|
||
return fmt.Errorf("missing issuer claim")
|
||
}
|
||
|
||
// 验证audience
|
||
// if aud, ok := claims["aud"].(string); ok {
|
||
// // TODO: 验证audience
|
||
// }
|
||
|
||
// 验证客户端ID
|
||
if clientID, ok := claims["client_id"].(string); ok {
|
||
// 验证客户端是否存在且有效
|
||
client, err := model.GetOAuthClientByID(clientID)
|
||
if err != nil {
|
||
return fmt.Errorf("invalid client")
|
||
}
|
||
if client.Status != common.UserStatusEnabled {
|
||
return fmt.Errorf("client disabled")
|
||
}
|
||
|
||
// 检查是否被撤销
|
||
if jti, ok := claims["jti"].(string); ok && jti != "" {
|
||
revoked, _ := model.IsTokenRevoked(jti)
|
||
if revoked {
|
||
return fmt.Errorf("token revoked")
|
||
}
|
||
}
|
||
} else {
|
||
return fmt.Errorf("missing client_id claim")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// setOAuthContext 设置OAuth上下文信息
|
||
func setOAuthContext(c *gin.Context, claims jwt.MapClaims) {
|
||
c.Set("oauth_claims", claims)
|
||
c.Set("oauth_authenticated", true)
|
||
|
||
// 提取基本信息
|
||
if clientID, ok := claims["client_id"].(string); ok {
|
||
c.Set("oauth_client_id", clientID)
|
||
}
|
||
|
||
if scope, ok := claims["scope"].(string); ok {
|
||
c.Set("oauth_scope", scope)
|
||
}
|
||
|
||
if sub, ok := claims["sub"].(string); ok {
|
||
c.Set("oauth_subject", sub)
|
||
}
|
||
|
||
// 对于client_credentials流程,subject就是client_id
|
||
// 对于authorization_code流程,subject是用户ID
|
||
if grantType, ok := claims["grant_type"].(string); ok {
|
||
c.Set("oauth_grant_type", grantType)
|
||
}
|
||
}
|
||
|
||
// abortWithOAuthError 返回OAuth错误响应
|
||
func abortWithOAuthError(c *gin.Context, errorCode, description string) {
|
||
c.Header("WWW-Authenticate", fmt.Sprintf(`Bearer error="%s", error_description="%s"`, errorCode, description))
|
||
c.JSON(http.StatusUnauthorized, gin.H{
|
||
"error": errorCode,
|
||
"error_description": description,
|
||
})
|
||
c.Abort()
|
||
}
|
||
|
||
// RequireOAuthScope OAuth2 scope验证中间件
|
||
func RequireOAuthScope(requiredScope string) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 检查是否通过OAuth认证
|
||
if !c.GetBool("oauth_authenticated") {
|
||
abortWithOAuthError(c, "insufficient_scope", "OAuth2 authentication required")
|
||
return
|
||
}
|
||
|
||
// 获取token的scope
|
||
scope, exists := c.Get("oauth_scope")
|
||
if !exists {
|
||
abortWithOAuthError(c, "insufficient_scope", "No scope in token")
|
||
return
|
||
}
|
||
|
||
scopeStr, ok := scope.(string)
|
||
if !ok {
|
||
abortWithOAuthError(c, "insufficient_scope", "Invalid scope format")
|
||
return
|
||
}
|
||
|
||
// 检查是否包含所需的scope
|
||
scopes := strings.Split(scopeStr, " ")
|
||
for _, s := range scopes {
|
||
if strings.TrimSpace(s) == requiredScope {
|
||
c.Next()
|
||
return
|
||
}
|
||
}
|
||
|
||
abortWithOAuthError(c, "insufficient_scope", fmt.Sprintf("Required scope: %s", requiredScope))
|
||
}
|
||
}
|
||
|
||
// OptionalOAuthAuth 可选的OAuth认证中间件(不会阻止请求)
|
||
func OptionalOAuthAuth() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 尝试OAuth认证,但不会阻止请求
|
||
authHeader := c.GetHeader("Authorization")
|
||
if authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") {
|
||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||
if claims, err := validateOAuthJWT(tokenString); err == nil {
|
||
if validateOAuthClaims(claims) == nil {
|
||
setOAuthContext(c, claims)
|
||
}
|
||
}
|
||
}
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// RequireOAuthScopeIfPresent enforces scope only when OAuth is present; otherwise no-op
|
||
func RequireOAuthScopeIfPresent(requiredScope string) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
if !c.GetBool("oauth_authenticated") {
|
||
c.Next()
|
||
return
|
||
}
|
||
scope, exists := c.Get("oauth_scope")
|
||
if !exists {
|
||
abortWithOAuthError(c, "insufficient_scope", "No scope in token")
|
||
return
|
||
}
|
||
scopeStr, ok := scope.(string)
|
||
if !ok {
|
||
abortWithOAuthError(c, "insufficient_scope", "Invalid scope format")
|
||
return
|
||
}
|
||
scopes := strings.Split(scopeStr, " ")
|
||
for _, s := range scopes {
|
||
if strings.TrimSpace(s) == requiredScope {
|
||
c.Next()
|
||
return
|
||
}
|
||
}
|
||
abortWithOAuthError(c, "insufficient_scope", fmt.Sprintf("Required scope: %s", requiredScope))
|
||
}
|
||
}
|
||
|
||
// GetOAuthClaims 获取OAuth claims
|
||
func GetOAuthClaims(c *gin.Context) (jwt.MapClaims, bool) {
|
||
claims, exists := c.Get("oauth_claims")
|
||
if !exists {
|
||
return nil, false
|
||
}
|
||
|
||
mapClaims, ok := claims.(jwt.MapClaims)
|
||
return mapClaims, ok
|
||
}
|
||
|
||
// IsOAuthAuthenticated 检查是否通过OAuth认证
|
||
func IsOAuthAuthenticated(c *gin.Context) bool {
|
||
return c.GetBool("oauth_authenticated")
|
||
}
|