mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-31 21:09:39 +00:00
Compare commits
59 Commits
v0.9.0
...
feature/ss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
380e1b7d56 | ||
|
|
63828349de | ||
|
|
5706f0ee9f | ||
|
|
e9e1dbff5e | ||
|
|
315eabc1e7 | ||
|
|
359dbc9d94 | ||
|
|
e157ea6ba2 | ||
|
|
dc3dba0665 | ||
|
|
81272da9ac | ||
|
|
926cad87b3 | ||
|
|
418ce449b7 | ||
|
|
4a02ab23ce | ||
|
|
984097c60b | ||
|
|
5550ec017e | ||
|
|
9e6752e0ee | ||
|
|
18a385f817 | ||
|
|
8e95d338b5 | ||
|
|
f236785ed5 | ||
|
|
f3e220b196 | ||
|
|
33bf267ce8 | ||
|
|
4f760a8d40 | ||
|
|
8563eafc57 | ||
|
|
7d71f467d9 | ||
|
|
da6f24a3d4 | ||
|
|
28ed42130c | ||
|
|
96215c9fd5 | ||
|
|
6628fd9181 | ||
|
|
a3b8a1998a | ||
|
|
6a34d365ec | ||
|
|
406a3e4dca | ||
|
|
c1d7ecdeec | ||
|
|
6451158680 | ||
|
|
0bd4b34046 | ||
|
|
f14b06ec3a | ||
|
|
6ed775be8f | ||
|
|
b712279b2a | ||
|
|
1bffe3081d | ||
|
|
cfebe80822 | ||
|
|
17e697af8f | ||
|
|
01b35bb667 | ||
|
|
d8410d2f11 | ||
|
|
e68eed3d40 | ||
|
|
04cc668430 | ||
|
|
5d76e16324 | ||
|
|
b6c547ae98 | ||
|
|
93adcd57d7 | ||
|
|
e813da59cc | ||
|
|
b25ac0bfb6 | ||
|
|
70c27bc662 | ||
|
|
db6a788e0d | ||
|
|
e3bc40f11b | ||
|
|
684caa3673 | ||
|
|
47aaa695b2 | ||
|
|
cda73a2ec5 | ||
|
|
27a0a447d0 | ||
|
|
fcdfd027cd | ||
|
|
3f9698bb47 | ||
|
|
91a0eb7031 | ||
|
|
81e29aaa3d |
@@ -12,4 +12,4 @@ var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
|
||||
var UsingMySQL = false
|
||||
var UsingClickHouse = false
|
||||
|
||||
var SQLitePath = "one-api.db?_busy_timeout=30000"
|
||||
var SQLitePath = "one-api.db?_busy_timeout=30000"
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -342,7 +342,7 @@ func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
||||
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
||||
}
|
||||
availableBalanceCny := response.Data.AvailableBalance
|
||||
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
|
||||
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(operation_setting.Price)).InexactFloat64()
|
||||
channel.UpdateBalance(availableBalanceUsd)
|
||||
return availableBalanceUsd, nil
|
||||
}
|
||||
|
||||
@@ -235,7 +235,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(httpResp, true)
|
||||
err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -560,7 +561,7 @@ func AddChannel(c *gin.Context) {
|
||||
case "multi_to_single":
|
||||
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
|
||||
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -585,7 +586,7 @@ func AddChannel(c *gin.Context) {
|
||||
}
|
||||
keys = []string{addChannelRequest.Channel.Key}
|
||||
case "batch":
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
// multi json
|
||||
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||
if err != nil {
|
||||
@@ -840,7 +841,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 处理 Vertex AI 的特殊情况
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
// 尝试解析新密钥为JSON数组
|
||||
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
|
||||
array, err := getVertexArrayKeys(channel.Key)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -259,7 +260,7 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range items {
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
@@ -284,7 +285,7 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range items {
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,11 +58,7 @@ func GetStatus(c *gin.Context) {
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"stripe_unit_price": setting.StripeUnitPrice,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"server_address": system_setting.ServerAddress,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
@@ -75,15 +71,15 @@ func GetStatus(c *gin.Context) {
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
"usd_exchange_rate": setting.USDExchangeRate,
|
||||
|
||||
"usd_exchange_rate": operation_setting.USDExchangeRate,
|
||||
"price": operation_setting.Price,
|
||||
"stripe_unit_price": setting.StripeUnitPrice,
|
||||
|
||||
// 面板启用开关
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
@@ -253,7 +249,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
||||
}
|
||||
code := common.GenerateVerificationCode(0)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", system_setting.ServerAddress, email, code)
|
||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||
|
||||
375
controller/oauth.go
Normal file
375
controller/oauth.go
Normal file
@@ -0,0 +1,375 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"one-api/setting/system_setting"
|
||||
"one-api/src/oauth"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
"one-api/middleware"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GetJWKS 获取JWKS公钥集
|
||||
func GetJWKS(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "OAuth2 server is disabled",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// lazy init if needed
|
||||
_ = oauth.EnsureInitialized()
|
||||
|
||||
jwks := oauth.GetJWKS()
|
||||
if jwks == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "JWKS not available",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置CORS headers
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET")
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type")
|
||||
c.Header("Cache-Control", "public, max-age=3600") // 缓存1小时
|
||||
|
||||
// 返回JWKS
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
// 将JWKS转换为JSON字符串
|
||||
jsonData, err := json.Marshal(jwks)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Failed to marshal JWKS",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.String(http.StatusOK, string(jsonData))
|
||||
}
|
||||
|
||||
// OAuthTokenEndpoint OAuth2 令牌端点
|
||||
func OAuthTokenEndpoint(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "unsupported_grant_type",
|
||||
"error_description": "OAuth2 server is disabled",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 只允许POST请求
|
||||
if c.Request.Method != "POST" {
|
||||
c.JSON(http.StatusMethodNotAllowed, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Only POST method is allowed",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 只允许application/x-www-form-urlencoded内容类型
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
if contentType == "" || !strings.Contains(strings.ToLower(contentType), "application/x-www-form-urlencoded") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Content-Type must be application/x-www-form-urlencoded",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// lazy init
|
||||
if err := oauth.EnsureInitialized(); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error", "error_description": err.Error()})
|
||||
return
|
||||
}
|
||||
oauth.HandleTokenRequest(c)
|
||||
}
|
||||
|
||||
// OAuthAuthorizeEndpoint OAuth2 授权端点
|
||||
func OAuthAuthorizeEndpoint(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "server_error",
|
||||
"error_description": "OAuth2 server is disabled",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := oauth.EnsureInitialized(); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error", "error_description": err.Error()})
|
||||
return
|
||||
}
|
||||
oauth.HandleAuthorizeRequest(c)
|
||||
}
|
||||
|
||||
// OAuthServerInfo 获取OAuth2服务器信息
|
||||
func OAuthServerInfo(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "OAuth2 server is disabled",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回OAuth2服务器的基本信息(类似OpenID Connect Discovery)
|
||||
issuer := settings.Issuer
|
||||
if issuer == "" {
|
||||
scheme := "https"
|
||||
if c.Request.TLS == nil {
|
||||
if hdr := c.Request.Header.Get("X-Forwarded-Proto"); hdr != "" {
|
||||
scheme = hdr
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
issuer = scheme + "://" + c.Request.Host
|
||||
}
|
||||
|
||||
base := issuer + "/api"
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"issuer": issuer,
|
||||
"authorization_endpoint": base + "/oauth/authorize",
|
||||
"token_endpoint": base + "/oauth/token",
|
||||
"jwks_uri": base + "/.well-known/jwks.json",
|
||||
"grant_types_supported": settings.AllowedGrantTypes,
|
||||
"response_types_supported": []string{"code", "token"},
|
||||
"token_endpoint_auth_methods_supported": []string{"client_secret_basic", "client_secret_post"},
|
||||
"code_challenge_methods_supported": []string{"S256"},
|
||||
"scopes_supported": []string{"openid", "profile", "email", "api:read", "api:write", "admin"},
|
||||
"default_private_key_path": settings.DefaultPrivateKeyPath,
|
||||
})
|
||||
}
|
||||
|
||||
// OAuthOIDCConfiguration OIDC discovery document
|
||||
func OAuthOIDCConfiguration(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "OAuth2 server is disabled"})
|
||||
return
|
||||
}
|
||||
issuer := settings.Issuer
|
||||
if issuer == "" {
|
||||
scheme := "https"
|
||||
if c.Request.TLS == nil {
|
||||
if hdr := c.Request.Header.Get("X-Forwarded-Proto"); hdr != "" {
|
||||
scheme = hdr
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
issuer = scheme + "://" + c.Request.Host
|
||||
}
|
||||
base := issuer + "/api"
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"issuer": issuer,
|
||||
"authorization_endpoint": base + "/oauth/authorize",
|
||||
"token_endpoint": base + "/oauth/token",
|
||||
"userinfo_endpoint": base + "/oauth/userinfo",
|
||||
"jwks_uri": base + "/.well-known/jwks.json",
|
||||
"response_types_supported": []string{"code", "token"},
|
||||
"grant_types_supported": settings.AllowedGrantTypes,
|
||||
"subject_types_supported": []string{"public"},
|
||||
"id_token_signing_alg_values_supported": []string{"RS256"},
|
||||
"scopes_supported": []string{"openid", "profile", "email", "api:read", "api:write", "admin"},
|
||||
"token_endpoint_auth_methods_supported": []string{"client_secret_basic", "client_secret_post"},
|
||||
"code_challenge_methods_supported": []string{"S256"},
|
||||
"default_private_key_path": settings.DefaultPrivateKeyPath,
|
||||
})
|
||||
}
|
||||
|
||||
// OAuthIntrospect 令牌内省端点(RFC 7662)
|
||||
func OAuthIntrospect(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "OAuth2 server is disabled",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 只允许POST请求
|
||||
if c.Request.Method != "POST" {
|
||||
c.JSON(http.StatusMethodNotAllowed, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Only POST method is allowed",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token := c.PostForm("token")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"active": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := token
|
||||
|
||||
// 验证并解析JWT
|
||||
parsed, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, jwt.ErrTokenSignatureInvalid
|
||||
}
|
||||
pub := oauth.GetPublicKeyByKid(func() string {
|
||||
if v, ok := token.Header["kid"].(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}())
|
||||
if pub == nil {
|
||||
return nil, jwt.ErrTokenUnverifiable
|
||||
}
|
||||
return pub, nil
|
||||
})
|
||||
if err != nil || !parsed.Valid {
|
||||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := parsed.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查撤销
|
||||
if jti, ok := claims["jti"].(string); ok && jti != "" {
|
||||
if revoked, _ := model.IsTokenRevoked(jti); revoked {
|
||||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 有效
|
||||
resp := gin.H{"active": true}
|
||||
for k, v := range claims {
|
||||
resp[k] = v
|
||||
}
|
||||
resp["token_type"] = "Bearer"
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// OAuthRevoke 令牌撤销端点(RFC 7009)
|
||||
func OAuthRevoke(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "OAuth2 server is disabled",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 只允许POST请求
|
||||
if c.Request.Method != "POST" {
|
||||
c.JSON(http.StatusMethodNotAllowed, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Only POST method is allowed",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token := c.PostForm("token")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Missing token parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token = c.PostForm("token")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Missing token parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试解析JWT,若成功则记录jti到撤销表
|
||||
parsed, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, jwt.ErrTokenSignatureInvalid
|
||||
}
|
||||
pub := oauth.GetRSAPublicKey()
|
||||
if pub == nil {
|
||||
return nil, jwt.ErrTokenUnverifiable
|
||||
}
|
||||
return pub, nil
|
||||
})
|
||||
if err == nil && parsed != nil && parsed.Valid {
|
||||
if claims, ok := parsed.Claims.(jwt.MapClaims); ok {
|
||||
var jti string
|
||||
var exp int64
|
||||
if v, ok := claims["jti"].(string); ok {
|
||||
jti = v
|
||||
}
|
||||
if v, ok := claims["exp"].(float64); ok {
|
||||
exp = int64(v)
|
||||
} else if v, ok := claims["exp"].(int64); ok {
|
||||
exp = v
|
||||
}
|
||||
if jti != "" {
|
||||
// 如果没有exp,默认撤销至当前+TTL 10分钟
|
||||
if exp == 0 {
|
||||
exp = time.Now().Add(10 * time.Minute).Unix()
|
||||
}
|
||||
_ = model.RevokeToken(jti, exp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// OAuthUserInfo returns OIDC userinfo based on access token
|
||||
func OAuthUserInfo(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "OAuth2 server is disabled"})
|
||||
return
|
||||
}
|
||||
// 需要 OAuthJWTAuth 中间件注入 claims
|
||||
claims, ok := middleware.GetOAuthClaims(c)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_token"})
|
||||
return
|
||||
}
|
||||
// scope 校验:必须包含 openid
|
||||
scope, _ := claims["scope"].(string)
|
||||
if !strings.Contains(" "+scope+" ", " openid ") {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "insufficient_scope"})
|
||||
return
|
||||
}
|
||||
sub, _ := claims["sub"].(string)
|
||||
resp := gin.H{"sub": sub}
|
||||
// 若包含 profile/email scope,补充返回
|
||||
if strings.Contains(" "+scope+" ", " profile ") || strings.Contains(" "+scope+" ", " email ") {
|
||||
if uid, err := strconv.Atoi(sub); err == nil {
|
||||
if user, err2 := model.GetUserById(uid, false); err2 == nil && user != nil {
|
||||
if strings.Contains(" "+scope+" ", " profile ") {
|
||||
resp["name"] = user.DisplayName
|
||||
resp["preferred_username"] = user.Username
|
||||
}
|
||||
if strings.Contains(" "+scope+" ", " email ") {
|
||||
resp["email"] = user.Email
|
||||
resp["email_verified"] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
374
controller/oauth_client.go
Normal file
374
controller/oauth_client.go
Normal file
@@ -0,0 +1,374 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/thanhpk/randstr"
|
||||
)
|
||||
|
||||
// CreateOAuthClientRequest 创建OAuth客户端请求
|
||||
type CreateOAuthClientRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
ClientType string `json:"client_type" binding:"required,oneof=confidential public"`
|
||||
GrantTypes []string `json:"grant_types" binding:"required"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
Scopes []string `json:"scopes" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
RequirePKCE bool `json:"require_pkce"`
|
||||
}
|
||||
|
||||
// UpdateOAuthClientRequest 更新OAuth客户端请求
|
||||
type UpdateOAuthClientRequest struct {
|
||||
ID string `json:"id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
ClientType string `json:"client_type" binding:"required,oneof=confidential public"`
|
||||
GrantTypes []string `json:"grant_types" binding:"required"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
Scopes []string `json:"scopes" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
RequirePKCE bool `json:"require_pkce"`
|
||||
Status int `json:"status" binding:"required,oneof=1 2"`
|
||||
}
|
||||
|
||||
// GetAllOAuthClients 获取所有OAuth客户端
|
||||
func GetAllOAuthClients(c *gin.Context) {
|
||||
page, _ := strconv.Atoi(c.Query("page"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
perPage, _ := strconv.Atoi(c.Query("per_page"))
|
||||
if perPage < 1 || perPage > 100 {
|
||||
perPage = 20
|
||||
}
|
||||
|
||||
startIdx := (page - 1) * perPage
|
||||
clients, err := model.GetAllOAuthClients(startIdx, perPage)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 清理敏感信息
|
||||
for _, client := range clients {
|
||||
client.Secret = maskSecret(client.Secret)
|
||||
}
|
||||
|
||||
total, _ := model.CountOAuthClients()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": clients,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": perPage,
|
||||
})
|
||||
}
|
||||
|
||||
// SearchOAuthClients 搜索OAuth客户端
|
||||
func SearchOAuthClients(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
if keyword == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "关键词不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
clients, err := model.SearchOAuthClients(keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 清理敏感信息
|
||||
for _, client := range clients {
|
||||
client.Secret = maskSecret(client.Secret)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": clients,
|
||||
})
|
||||
}
|
||||
|
||||
// GetOAuthClient 获取单个OAuth客户端
|
||||
func GetOAuthClient(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "ID不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
client, err := model.GetOAuthClientByID(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"success": false,
|
||||
"message": "客户端不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 清理敏感信息
|
||||
client.Secret = maskSecret(client.Secret)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": client,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateOAuthClient 创建OAuth客户端
|
||||
func CreateOAuthClient(c *gin.Context) {
|
||||
var req CreateOAuthClientRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "请求参数错误: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证授权类型
|
||||
validGrantTypes := []string{"client_credentials", "authorization_code", "refresh_token"}
|
||||
for _, grantType := range req.GrantTypes {
|
||||
if !contains(validGrantTypes, grantType) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的授权类型: " + grantType,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 如果包含authorization_code,则必须提供redirect_uris
|
||||
if contains(req.GrantTypes, "authorization_code") && len(req.RedirectURIs) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "授权码模式需要提供重定向URI",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成客户端ID和密钥
|
||||
clientID := generateClientID()
|
||||
clientSecret := ""
|
||||
if req.ClientType == "confidential" {
|
||||
clientSecret = generateClientSecret()
|
||||
}
|
||||
|
||||
// 获取创建者ID
|
||||
createdBy := c.GetInt("id")
|
||||
|
||||
// 创建客户端
|
||||
client := &model.OAuthClient{
|
||||
ID: clientID,
|
||||
Secret: clientSecret,
|
||||
Name: req.Name,
|
||||
ClientType: req.ClientType,
|
||||
RequirePKCE: req.RequirePKCE,
|
||||
Status: common.UserStatusEnabled,
|
||||
CreatedBy: createdBy,
|
||||
Description: req.Description,
|
||||
}
|
||||
|
||||
client.SetGrantTypes(req.GrantTypes)
|
||||
client.SetRedirectURIs(req.RedirectURIs)
|
||||
client.SetScopes(req.Scopes)
|
||||
|
||||
err := model.CreateOAuthClient(client)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": "创建客户端失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回结果(包含完整的客户端密钥,仅此一次)
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"success": true,
|
||||
"message": "客户端创建成功",
|
||||
"client_id": client.ID,
|
||||
"client_secret": client.Secret, // 仅在创建时返回完整密钥
|
||||
"data": client,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateOAuthClient 更新OAuth客户端
|
||||
func UpdateOAuthClient(c *gin.Context) {
|
||||
var req UpdateOAuthClientRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "请求参数错误: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取现有客户端
|
||||
client, err := model.GetOAuthClientByID(req.ID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"success": false,
|
||||
"message": "客户端不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证授权类型
|
||||
validGrantTypes := []string{"client_credentials", "authorization_code", "refresh_token"}
|
||||
for _, grantType := range req.GrantTypes {
|
||||
if !contains(validGrantTypes, grantType) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的授权类型: " + grantType,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 更新客户端信息
|
||||
client.Name = req.Name
|
||||
client.ClientType = req.ClientType
|
||||
client.RequirePKCE = req.RequirePKCE
|
||||
client.Status = req.Status
|
||||
client.Description = req.Description
|
||||
client.SetGrantTypes(req.GrantTypes)
|
||||
client.SetRedirectURIs(req.RedirectURIs)
|
||||
client.SetScopes(req.Scopes)
|
||||
|
||||
err = model.UpdateOAuthClient(client)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": "更新客户端失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 清理敏感信息
|
||||
client.Secret = maskSecret(client.Secret)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "客户端更新成功",
|
||||
"data": client,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteOAuthClient 删除OAuth客户端
|
||||
func DeleteOAuthClient(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "ID不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err := model.DeleteOAuthClient(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": "删除客户端失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "客户端删除成功",
|
||||
})
|
||||
}
|
||||
|
||||
// RegenerateOAuthClientSecret 重新生成客户端密钥
|
||||
func RegenerateOAuthClientSecret(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "ID不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
client, err := model.GetOAuthClientByID(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"success": false,
|
||||
"message": "客户端不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 只有机密客户端才能重新生成密钥
|
||||
if client.ClientType != "confidential" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "只有机密客户端才能重新生成密钥",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成新密钥
|
||||
client.Secret = generateClientSecret()
|
||||
|
||||
err = model.UpdateOAuthClient(client)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": "重新生成密钥失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "客户端密钥重新生成成功",
|
||||
"client_secret": client.Secret, // 返回新生成的密钥
|
||||
})
|
||||
}
|
||||
|
||||
// generateClientID 生成客户端ID
|
||||
func generateClientID() string {
|
||||
return "client_" + randstr.String(16)
|
||||
}
|
||||
|
||||
// generateClientSecret 生成客户端密钥
|
||||
func generateClientSecret() string {
|
||||
return randstr.String(32)
|
||||
}
|
||||
|
||||
// maskSecret 掩码密钥显示
|
||||
func maskSecret(secret string) string {
|
||||
if len(secret) <= 6 {
|
||||
return strings.Repeat("*", len(secret))
|
||||
}
|
||||
return secret[:3] + strings.Repeat("*", len(secret)-6) + secret[len(secret)-3:]
|
||||
}
|
||||
|
||||
// contains 检查字符串切片是否包含指定值
|
||||
func contains(slice []string, item string) bool {
|
||||
for _, s := range slice {
|
||||
if s == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
89
controller/oauth_keys.go
Normal file
89
controller/oauth_keys.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/logger"
|
||||
"one-api/src/oauth"
|
||||
)
|
||||
|
||||
type rotateKeyRequest struct {
|
||||
Kid string `json:"kid"`
|
||||
}
|
||||
|
||||
type genKeyFileRequest struct {
|
||||
Path string `json:"path"`
|
||||
Kid string `json:"kid"`
|
||||
Overwrite bool `json:"overwrite"`
|
||||
}
|
||||
|
||||
type importPemRequest struct {
|
||||
Pem string `json:"pem"`
|
||||
Kid string `json:"kid"`
|
||||
}
|
||||
|
||||
// RotateOAuthSigningKey rotates the OAuth2 JWT signing key (Root only)
|
||||
func RotateOAuthSigningKey(c *gin.Context) {
|
||||
var req rotateKeyRequest
|
||||
_ = c.BindJSON(&req)
|
||||
kid, err := oauth.RotateSigningKey(req.Kid)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, "oauth signing key rotated: "+kid)
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid})
|
||||
}
|
||||
|
||||
// ListOAuthSigningKeys returns current and historical JWKS signing keys
|
||||
func ListOAuthSigningKeys(c *gin.Context) {
|
||||
keys := oauth.ListSigningKeys()
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "data": keys})
|
||||
}
|
||||
|
||||
// DeleteOAuthSigningKey deletes a non-current key by kid
|
||||
func DeleteOAuthSigningKey(c *gin.Context) {
|
||||
kid := c.Param("kid")
|
||||
if kid == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "kid required"})
|
||||
return
|
||||
}
|
||||
if err := oauth.DeleteSigningKey(kid); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, "oauth signing key deleted: "+kid)
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// GenerateOAuthSigningKeyFile generates a private key file and rotates current kid
|
||||
func GenerateOAuthSigningKeyFile(c *gin.Context) {
|
||||
var req genKeyFileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.Path == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "path required"})
|
||||
return
|
||||
}
|
||||
kid, err := oauth.GenerateAndPersistKey(req.Path, req.Kid, req.Overwrite)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, "oauth signing key generated to file: "+req.Path+" kid="+kid)
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid, "path": req.Path})
|
||||
}
|
||||
|
||||
// ImportOAuthSigningKey imports PEM text and rotates current kid
|
||||
func ImportOAuthSigningKey(c *gin.Context) {
|
||||
var req importPemRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.Pem == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "pem required"})
|
||||
return
|
||||
}
|
||||
kid, err := oauth.ImportPEMKey(req.Pem, req.Kid)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, "oauth signing key imported from PEM, kid="+kid)
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid})
|
||||
}
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -45,7 +44,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
|
||||
if err != nil {
|
||||
|
||||
@@ -139,15 +139,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
|
||||
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
|
||||
|
||||
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// Only return quota if downstream failed and quota was actually pre-consumed
|
||||
if newAPIError != nil && preConsumedQuota != 0 {
|
||||
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
|
||||
if newAPIError != nil && relayInfo.FinalPreConsumedQuota != 0 {
|
||||
service.ReturnPreConsumedQuota(c, relayInfo)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -277,14 +277,13 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
|
||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
|
||||
gopool.Go(func() {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
gopool.Go(func() {
|
||||
service.DisableChannel(channelError, err.Error())
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
|
||||
// 保存错误日志到mysql中
|
||||
|
||||
@@ -178,4 +178,4 @@ func boolToString(b bool) string {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
} else {
|
||||
task.Data = responseBody
|
||||
task.Data = redactVideoResponseBody(responseBody)
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
@@ -117,7 +117,9 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Url
|
||||
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
|
||||
task.FailReason = taskResult.Url
|
||||
}
|
||||
case model.TaskStatusFailure:
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
@@ -146,3 +148,37 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func redactVideoResponseBody(body []byte) []byte {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return body
|
||||
}
|
||||
resp, _ := m["response"].(map[string]any)
|
||||
if resp != nil {
|
||||
delete(resp, "bytesBase64Encoded")
|
||||
if v, ok := resp["video"].(string); ok {
|
||||
resp["video"] = truncateBase64(v)
|
||||
}
|
||||
if vs, ok := resp["videos"].([]any); ok {
|
||||
for i := range vs {
|
||||
if vm, ok := vs[i].(map[string]any); ok {
|
||||
delete(vm, "bytesBase64Encoded")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func truncateBase64(s string) string {
|
||||
const maxKeep = 256
|
||||
if len(s) <= maxKeep {
|
||||
return s
|
||||
}
|
||||
return s[:maxKeep] + "..."
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -19,6 +21,44 @@ import (
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
func GetTopUpInfo(c *gin.Context) {
|
||||
// 获取支付方式
|
||||
payMethods := operation_setting.PayMethods
|
||||
|
||||
// 如果启用了 Stripe 支付,添加到支付方法列表
|
||||
if setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "" {
|
||||
// 检查是否已经包含 Stripe
|
||||
hasStripe := false
|
||||
for _, method := range payMethods {
|
||||
if method["type"] == "stripe" {
|
||||
hasStripe = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasStripe {
|
||||
stripeMethod := map[string]string{
|
||||
"name": "Stripe",
|
||||
"type": "stripe",
|
||||
"color": "rgba(var(--semi-purple-5), 1)",
|
||||
"min_topup": strconv.Itoa(setting.StripeMinTopUp),
|
||||
}
|
||||
payMethods = append(payMethods, stripeMethod)
|
||||
}
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"pay_methods": payMethods,
|
||||
"min_topup": operation_setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
|
||||
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
|
||||
}
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
|
||||
type EpayRequest struct {
|
||||
Amount int64 `json:"amount"`
|
||||
PaymentMethod string `json:"payment_method"`
|
||||
@@ -31,13 +71,13 @@ type AmountRequest struct {
|
||||
}
|
||||
|
||||
func GetEpayClient() *epay.Client {
|
||||
if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" {
|
||||
if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
|
||||
return nil
|
||||
}
|
||||
withUrl, err := epay.NewClient(&epay.Config{
|
||||
PartnerID: setting.EpayId,
|
||||
Key: setting.EpayKey,
|
||||
}, setting.PayAddress)
|
||||
PartnerID: operation_setting.EpayId,
|
||||
Key: operation_setting.EpayKey,
|
||||
}, operation_setting.PayAddress)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -58,15 +98,23 @@ func getPayMoney(amount int64, group string) float64 {
|
||||
}
|
||||
|
||||
dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
|
||||
dPrice := decimal.NewFromFloat(setting.Price)
|
||||
dPrice := decimal.NewFromFloat(operation_setting.Price)
|
||||
// apply optional preset discount by the original request amount (if configured), default 1.0
|
||||
discount := 1.0
|
||||
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(amount)]; ok {
|
||||
if ds > 0 {
|
||||
discount = ds
|
||||
}
|
||||
}
|
||||
dDiscount := decimal.NewFromFloat(discount)
|
||||
|
||||
payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio)
|
||||
payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio).Mul(dDiscount)
|
||||
|
||||
return payMoney.InexactFloat64()
|
||||
}
|
||||
|
||||
func getMinTopup() int64 {
|
||||
minTopup := setting.MinTopUp
|
||||
minTopup := operation_setting.MinTopUp
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
dMinTopup := decimal.NewFromInt(int64(minTopup))
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
@@ -99,13 +147,13 @@ func RequestEpay(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !setting.ContainsPayMethod(req.PaymentMethod) {
|
||||
if !operation_setting.ContainsPayMethod(req.PaymentMethod) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
||||
returnUrl, _ := url.Parse(system_setting.ServerAddress + "/console/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -215,8 +217,8 @@ func genStripeLink(referenceId string, customerId string, email string, amount i
|
||||
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
ClientReferenceID: stripe.String(referenceId),
|
||||
SuccessURL: stripe.String(setting.ServerAddress + "/log"),
|
||||
CancelURL: stripe.String(setting.ServerAddress + "/topup"),
|
||||
SuccessURL: stripe.String(system_setting.ServerAddress + "/console/log"),
|
||||
CancelURL: stripe.String(system_setting.ServerAddress + "/topup"),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(setting.StripePriceId),
|
||||
@@ -254,6 +256,7 @@ func GetChargedAmount(count float64, user model.User) float64 {
|
||||
}
|
||||
|
||||
func getStripePayMoney(amount float64, group string) float64 {
|
||||
originalAmount := amount
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / common.QuotaPerUnit
|
||||
}
|
||||
@@ -262,7 +265,14 @@ func getStripePayMoney(amount float64, group string) float64 {
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio
|
||||
// apply optional preset discount by the original request amount (if configured), default 1.0
|
||||
discount := 1.0
|
||||
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(originalAmount)]; ok {
|
||||
if ds > 0 {
|
||||
discount = ds
|
||||
}
|
||||
}
|
||||
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio * discount
|
||||
return payMoney
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,14 @@ type ChannelSettings struct {
|
||||
SystemPromptOverride bool `json:"system_prompt_override,omitempty"`
|
||||
}
|
||||
|
||||
type VertexKeyType string
|
||||
|
||||
const (
|
||||
VertexKeyTypeJSON VertexKeyType = "json"
|
||||
VertexKeyTypeAPIKey VertexKeyType = "api_key"
|
||||
)
|
||||
|
||||
type ChannelOtherSettings struct {
|
||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
|
||||
}
|
||||
|
||||
@@ -2,12 +2,11 @@ package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
@@ -269,15 +268,14 @@ type GeminiChatResponse struct {
|
||||
}
|
||||
|
||||
type GeminiUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
||||
PromptTokensDetails []GeminiModalityTokenCount `json:"promptTokensDetails"`
|
||||
CandidatesTokensDetails []GeminiModalityTokenCount `json:"candidatesTokensDetails"`
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
||||
PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
|
||||
}
|
||||
|
||||
type GeminiModalityTokenCount struct {
|
||||
type GeminiPromptTokensDetails struct {
|
||||
Modality string `json:"modality"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
}
|
||||
|
||||
@@ -59,6 +59,31 @@ func (i *ImageRequest) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 序列化时需要重新把字段平铺
|
||||
func (r ImageRequest) MarshalJSON() ([]byte, error) {
|
||||
// 将已定义字段转为 map
|
||||
type Alias ImageRequest
|
||||
alias := Alias(r)
|
||||
base, err := common.Marshal(alias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var baseMap map[string]json.RawMessage
|
||||
if err := common.Unmarshal(base, &baseMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 合并 ExtraFields
|
||||
for k, v := range r.Extra {
|
||||
if _, exists := baseMap[k]; !exists {
|
||||
baseMap[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(baseMap)
|
||||
}
|
||||
|
||||
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
|
||||
fields := make(map[string]struct{})
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
|
||||
326
examples/oauth/oauth-demo.html
Normal file
326
examples/oauth/oauth-demo.html
Normal file
@@ -0,0 +1,326 @@
|
||||
<!doctype html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>OAuth2/OIDC 授权码 + PKCE 前端演示</title>
|
||||
<style>
|
||||
:root { --bg:#0b0c10; --panel:#111317; --muted:#aab2bf; --accent:#3b82f6; --ok:#16a34a; --warn:#f59e0b; --err:#ef4444; --border:#1f2430; }
|
||||
body { margin:0; font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial; background: var(--bg); color:#e5e7eb; }
|
||||
.wrap { max-width: 980px; margin: 32px auto; padding: 0 16px; }
|
||||
h1 { font-size: 22px; margin:0 0 16px; }
|
||||
.card { background: var(--panel); border:1px solid var(--border); border-radius: 10px; padding: 16px; margin: 12px 0; }
|
||||
.row { display:flex; gap:12px; flex-wrap:wrap; }
|
||||
.col { flex: 1 1 280px; display:flex; flex-direction:column; }
|
||||
label { font-size: 12px; color: var(--muted); margin-bottom: 6px; }
|
||||
input, textarea, select { background:#0f1115; color:#e5e7eb; border:1px solid var(--border); padding:10px 12px; border-radius:8px; outline:none; }
|
||||
textarea { min-height: 100px; resize: vertical; }
|
||||
.btns { display:flex; gap:8px; flex-wrap:wrap; margin-top: 8px; }
|
||||
button { background:#1a1f2b; color:#e5e7eb; border:1px solid var(--border); padding:8px 12px; border-radius:8px; cursor:pointer; }
|
||||
button.primary { background: var(--accent); border-color: var(--accent); color:white; }
|
||||
button.ok { background: var(--ok); border-color: var(--ok); color:white; }
|
||||
button.warn { background: var(--warn); border-color: var(--warn); color:black; }
|
||||
button.ghost { background: transparent; }
|
||||
.muted { color: var(--muted); font-size: 12px; }
|
||||
.mono { font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; }
|
||||
.grid2 { display:grid; grid-template-columns: 1fr 1fr; gap: 12px; }
|
||||
@media (max-width: 880px){ .grid2 { grid-template-columns: 1fr; } }
|
||||
.pill { padding: 3px 8px; border-radius:999px; font-size: 12px; border:1px solid var(--border); background:#0f1115; }
|
||||
.ok { color: #10b981; }
|
||||
.err { color: #ef4444; }
|
||||
.sep { height:1px; background: var(--border); margin: 12px 0; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="wrap">
|
||||
<h1>OAuth2/OIDC 授权码 + PKCE 前端演示</h1>
|
||||
|
||||
<div class="card">
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>Issuer(可选,用于自动发现 /.well-known/openid-configuration)</label>
|
||||
<input id="issuer" placeholder="https://your-domain" />
|
||||
<div class="btns"><button class="" id="btnDiscover">自动发现端点</button></div>
|
||||
<div class="muted">提示:若未配置 Issuer,可直接填写下方端点。</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col"><label>Authorization Endpoint</label><input id="authorization_endpoint" placeholder="https://domain/api/oauth/authorize" /></div>
|
||||
<div class="col"><label>Token Endpoint</label><input id="token_endpoint" placeholder="https://domain/api/oauth/token" /></div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col"><label>UserInfo Endpoint(可选)</label><input id="userinfo_endpoint" placeholder="https://domain/api/oauth/userinfo" /></div>
|
||||
<div class="col"><label>Client ID</label><input id="client_id" placeholder="your-public-client-id" /></div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col"><label>Redirect URI(当前页地址或你的回调)</label><input id="redirect_uri" /></div>
|
||||
<div class="col"><label>Scope</label><input id="scope" value="openid profile email" /></div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col"><label>State</label><input id="state" /></div>
|
||||
<div class="col"><label>Nonce</label><input id="nonce" /></div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col"><label>Code Verifier(自动生成,不会上送)</label><input id="code_verifier" class="mono" readonly /></div>
|
||||
<div class="col"><label>Code Challenge(S256)</label><input id="code_challenge" class="mono" readonly /></div>
|
||||
</div>
|
||||
<div class="btns">
|
||||
<button id="btnGenPkce">生成 PKCE</button>
|
||||
<button id="btnRandomState">随机 State</button>
|
||||
<button id="btnRandomNonce">随机 Nonce</button>
|
||||
<button id="btnMakeAuthURL">生成授权链接</button>
|
||||
<button id="btnAuthorize" class="primary">跳转授权</button>
|
||||
</div>
|
||||
<div class="row" style="margin-top:8px;">
|
||||
<div class="col">
|
||||
<label>授权链接(只生成不跳转)</label>
|
||||
<textarea id="authorize_url" class="mono" placeholder="(空)"></textarea>
|
||||
<div class="btns"><button id="btnCopyAuthURL">复制链接</button></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="sep"></div>
|
||||
<div class="muted">说明:
|
||||
<ul>
|
||||
<li>本页为纯前端演示,适用于公开客户端(不需要 client_secret)。</li>
|
||||
<li>如跨域调用 Token/UserInfo,需要服务端正确设置 CORS;建议将此 demo 部署到同源域名下。</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="sep"></div>
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>粘贴 OIDC Discovery JSON(/.well-known/openid-configuration)</label>
|
||||
<textarea id="conf_json" class="mono" placeholder='{"issuer":"https://...","authorization_endpoint":"...","token_endpoint":"...","userinfo_endpoint":"..."}'></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnParseConf">解析并填充端点</button>
|
||||
<button id="btnGenConf">用当前端点生成 JSON</button>
|
||||
</div>
|
||||
<div class="muted">可将服务端返回的 OIDC Discovery JSON 粘贴到此处,点击“解析并填充端点”。</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="card">
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>授权结果</label>
|
||||
<div id="authResult" class="muted">等待授权...</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="grid2" style="margin-top:12px;">
|
||||
<div>
|
||||
<label>Access Token</label>
|
||||
<textarea id="access_token" class="mono" placeholder="(空)"></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnCopyAT">复制</button>
|
||||
<button id="btnCallUserInfo" class="ok">调用 UserInfo</button>
|
||||
</div>
|
||||
<div id="userinfoOut" class="muted" style="margin-top:6px;"></div>
|
||||
</div>
|
||||
<div>
|
||||
<label>ID Token(JWT)</label>
|
||||
<textarea id="id_token" class="mono" placeholder="(空)"></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnDecodeJWT">解码显示 Claims</button>
|
||||
</div>
|
||||
<pre id="jwtClaims" class="mono" style="white-space:pre-wrap; word-break:break-all; margin-top:6px;"></pre>
|
||||
</div>
|
||||
</div>
|
||||
<div class="grid2" style="margin-top:12px;">
|
||||
<div>
|
||||
<label>Refresh Token</label>
|
||||
<textarea id="refresh_token" class="mono" placeholder="(空)"></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnRefreshToken">使用 Refresh Token 刷新</button>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<label>原始 Token 响应</label>
|
||||
<textarea id="token_raw" class="mono" placeholder="(空)"></textarea>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const $ = (id) => document.getElementById(id);
|
||||
const toB64Url = (buf) => btoa(String.fromCharCode(...new Uint8Array(buf))).replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/, '');
|
||||
async function sha256B64Url(str){
|
||||
const data = new TextEncoder().encode(str);
|
||||
const digest = await crypto.subtle.digest('SHA-256', data);
|
||||
return toB64Url(digest);
|
||||
}
|
||||
function randStr(len=64){
|
||||
const chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~';
|
||||
const arr = new Uint8Array(len); crypto.getRandomValues(arr);
|
||||
return Array.from(arr, v => chars[v % chars.length]).join('');
|
||||
}
|
||||
function setAuthInfo(msg, ok=true){
|
||||
const el = $('authResult');
|
||||
el.textContent = msg;
|
||||
el.className = ok ? 'ok' : 'err';
|
||||
}
|
||||
function qs(name){ const u=new URL(location.href); return u.searchParams.get(name); }
|
||||
|
||||
function persist(name, val){ sessionStorage.setItem('demo_'+name, val); }
|
||||
function load(name){ return sessionStorage.getItem('demo_'+name) || ''; }
|
||||
|
||||
// init defaults
|
||||
(function init(){
|
||||
$('redirect_uri').value = window.location.origin + window.location.pathname;
|
||||
// try load from discovery if issuer saved previously
|
||||
const iss = load('issuer'); if(iss) $('issuer').value = iss;
|
||||
const cid = load('client_id'); if(cid) $('client_id').value = cid;
|
||||
const scp = load('scope'); if(scp) $('scope').value = scp;
|
||||
})();
|
||||
|
||||
$('btnDiscover').onclick = async () => {
|
||||
const iss = $('issuer').value.trim(); if(!iss){ alert('请填写 Issuer'); return; }
|
||||
try{
|
||||
persist('issuer', iss);
|
||||
const res = await fetch(iss.replace(/\/$/,'') + '/api/.well-known/openid-configuration');
|
||||
const d = await res.json();
|
||||
$('authorization_endpoint').value = d.authorization_endpoint || '';
|
||||
$('token_endpoint').value = d.token_endpoint || '';
|
||||
$('userinfo_endpoint').value = d.userinfo_endpoint || '';
|
||||
if (d.issuer) { $('issuer').value = d.issuer; persist('issuer', d.issuer); }
|
||||
$('conf_json').value = JSON.stringify(d, null, 2);
|
||||
setAuthInfo('已从发现文档加载端点', true);
|
||||
}catch(e){ setAuthInfo('自动发现失败:'+e, false); }
|
||||
};
|
||||
|
||||
$('btnGenPkce').onclick = async () => {
|
||||
const v = randStr(64); const c = await sha256B64Url(v);
|
||||
$('code_verifier').value = v; $('code_challenge').value = c;
|
||||
persist('code_verifier', v); persist('code_challenge', c);
|
||||
setAuthInfo('已生成 PKCE 参数', true);
|
||||
};
|
||||
$('btnRandomState').onclick = () => { $('state').value = randStr(16); persist('state', $('state').value); };
|
||||
$('btnRandomNonce').onclick = () => { $('nonce').value = randStr(16); persist('nonce', $('nonce').value); };
|
||||
|
||||
function buildAuthorizeURLFromFields() {
|
||||
const auth = $('authorization_endpoint').value.trim();
|
||||
const token = $('token_endpoint').value.trim(); // just validate
|
||||
const cid = $('client_id').value.trim();
|
||||
const red = $('redirect_uri').value.trim();
|
||||
const scp = $('scope').value.trim() || 'openid profile email';
|
||||
const st = $('state').value.trim() || randStr(16);
|
||||
const no = $('nonce').value.trim() || randStr(16);
|
||||
const cc = $('code_challenge').value.trim();
|
||||
const cv = $('code_verifier').value.trim();
|
||||
if(!auth || !token || !cid || !red){ throw new Error('请先完善端点/ClientID/RedirectURI'); }
|
||||
if(!cc || !cv){ throw new Error('请先生成 PKCE'); }
|
||||
persist('authorization_endpoint', auth); persist('token_endpoint', token);
|
||||
persist('client_id', cid); persist('redirect_uri', red); persist('scope', scp);
|
||||
persist('state', st); persist('nonce', no); persist('code_verifier', cv);
|
||||
const u = new URL(auth);
|
||||
u.searchParams.set('response_type', 'code');
|
||||
u.searchParams.set('client_id', cid);
|
||||
u.searchParams.set('redirect_uri', red);
|
||||
u.searchParams.set('scope', scp);
|
||||
u.searchParams.set('state', st);
|
||||
u.searchParams.set('nonce', no);
|
||||
u.searchParams.set('code_challenge', cc);
|
||||
u.searchParams.set('code_challenge_method', 'S256');
|
||||
return u.toString();
|
||||
}
|
||||
$('btnMakeAuthURL').onclick = () => {
|
||||
try {
|
||||
const url = buildAuthorizeURLFromFields();
|
||||
$('authorize_url').value = url;
|
||||
setAuthInfo('已生成授权链接', true);
|
||||
} catch(e){ setAuthInfo(e.message, false); }
|
||||
};
|
||||
$('btnAuthorize').onclick = () => {
|
||||
try { const url = buildAuthorizeURLFromFields(); location.href = url; }
|
||||
catch(e){ setAuthInfo(e.message, false); }
|
||||
};
|
||||
$('btnCopyAuthURL').onclick = async () => { try{ await navigator.clipboard.writeText($('authorize_url').value); }catch{} };
|
||||
|
||||
// Parse OIDC discovery JSON pasted by user
|
||||
$('btnParseConf').onclick = () => {
|
||||
const txt = $('conf_json').value.trim(); if(!txt){ alert('请先粘贴 JSON'); return; }
|
||||
try{
|
||||
const d = JSON.parse(txt);
|
||||
if (d.issuer) { $('issuer').value = d.issuer; persist('issuer', d.issuer); }
|
||||
if (d.authorization_endpoint) $('authorization_endpoint').value = d.authorization_endpoint;
|
||||
if (d.token_endpoint) $('token_endpoint').value = d.token_endpoint;
|
||||
if (d.userinfo_endpoint) $('userinfo_endpoint').value = d.userinfo_endpoint;
|
||||
setAuthInfo('已解析配置并填充端点', true);
|
||||
}catch(e){ setAuthInfo('解析失败:'+e, false); }
|
||||
};
|
||||
// Generate a minimal discovery JSON from current fields
|
||||
$('btnGenConf').onclick = () => {
|
||||
const d = {
|
||||
issuer: $('issuer').value.trim() || undefined,
|
||||
authorization_endpoint: $('authorization_endpoint').value.trim() || undefined,
|
||||
token_endpoint: $('token_endpoint').value.trim() || undefined,
|
||||
userinfo_endpoint: $('userinfo_endpoint').value.trim() || undefined,
|
||||
};
|
||||
$('conf_json').value = JSON.stringify(d, null, 2);
|
||||
};
|
||||
|
||||
async function postForm(url, data){
|
||||
const body = Object.entries(data).map(([k,v])=> `${encodeURIComponent(k)}=${encodeURIComponent(v)}`).join('&');
|
||||
const res = await fetch(url, { method:'POST', headers:{ 'Content-Type':'application/x-www-form-urlencoded' }, body });
|
||||
if(!res.ok){ const t = await res.text(); throw new Error(`HTTP ${res.status} ${t}`); }
|
||||
return res.json();
|
||||
}
|
||||
|
||||
async function handleCallback(){
|
||||
const code = qs('code'); const err = qs('error');
|
||||
const state = qs('state');
|
||||
if(err){ setAuthInfo('授权失败:'+err, false); return; }
|
||||
if(!code){ setAuthInfo('等待授权...', true); return; }
|
||||
// state check
|
||||
if(state && load('state') && state !== load('state')){ setAuthInfo('state 不匹配,已拒绝', false); return; }
|
||||
try{
|
||||
const tokenEp = load('token_endpoint');
|
||||
const data = await postForm(tokenEp, {
|
||||
grant_type:'authorization_code',
|
||||
code,
|
||||
client_id: load('client_id'),
|
||||
redirect_uri: load('redirect_uri'),
|
||||
code_verifier: load('code_verifier')
|
||||
});
|
||||
$('access_token').value = data.access_token || '';
|
||||
$('id_token').value = data.id_token || '';
|
||||
$('refresh_token').value = data.refresh_token || '';
|
||||
$('token_raw').value = JSON.stringify(data, null, 2);
|
||||
setAuthInfo('授权成功,已获取令牌', true);
|
||||
}catch(e){ setAuthInfo('交换令牌失败:'+e.message, false); }
|
||||
}
|
||||
handleCallback();
|
||||
|
||||
$('btnCopyAT').onclick = async () => { try{ await navigator.clipboard.writeText($('access_token').value); }catch{} };
|
||||
$('btnDecodeJWT').onclick = () => {
|
||||
const t = $('id_token').value.trim(); if(!t){ $('jwtClaims').textContent='(空)'; return; }
|
||||
const parts = t.split('.'); if(parts.length<2){ $('jwtClaims').textContent='格式错误'; return; }
|
||||
try{ const json = JSON.parse(atob(parts[1].replace(/-/g,'+').replace(/_/g,'/'))); $('jwtClaims').textContent = JSON.stringify(json, null, 2);}catch(e){ $('jwtClaims').textContent='解码失败:'+e; }
|
||||
};
|
||||
$('btnCallUserInfo').onclick = async () => {
|
||||
const at = $('access_token').value.trim(); const ep = $('userinfo_endpoint').value.trim(); if(!at||!ep){ alert('请填写UserInfo端点并获取AccessToken'); return; }
|
||||
try{
|
||||
const res = await fetch(ep, { headers:{ Authorization: 'Bearer '+at } });
|
||||
const data = await res.json(); $('userinfoOut').textContent = JSON.stringify(data, null, 2);
|
||||
}catch(e){ $('userinfoOut').textContent = '调用失败:'+e; }
|
||||
};
|
||||
$('btnRefreshToken').onclick = async () => {
|
||||
const rt = $('refresh_token').value.trim(); if(!rt){ alert('没有刷新令牌'); return; }
|
||||
try{
|
||||
const tokenEp = load('token_endpoint');
|
||||
const data = await postForm(tokenEp, {
|
||||
grant_type:'refresh_token',
|
||||
refresh_token: rt,
|
||||
client_id: load('client_id')
|
||||
});
|
||||
$('access_token').value = data.access_token || '';
|
||||
$('id_token').value = data.id_token || '';
|
||||
$('refresh_token').value = data.refresh_token || '';
|
||||
$('token_raw').value = JSON.stringify(data, null, 2);
|
||||
setAuthInfo('刷新成功', true);
|
||||
}catch(e){ setAuthInfo('刷新失败:'+e.message, false); }
|
||||
};
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
181
examples/oauth/oauth2_test_client.go
Normal file
181
examples/oauth/oauth2_test_client.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/clientcredentials"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 测试 Client Credentials 流程
|
||||
//testClientCredentials()
|
||||
|
||||
// 测试 Authorization Code + PKCE 流程(需要浏览器交互)
|
||||
testAuthorizationCode()
|
||||
}
|
||||
|
||||
// testClientCredentials 测试服务对服务认证
|
||||
func testClientCredentials() {
|
||||
fmt.Println("=== Testing Client Credentials Flow ===")
|
||||
|
||||
cfg := clientcredentials.Config{
|
||||
ClientID: "client_dsFyyoyNZWjhbNa2", // 需要先创建客户端
|
||||
ClientSecret: "hLLdn2Ia4UM7hcsJaSuUFDV0Px9BrkNq",
|
||||
TokenURL: "http://localhost:3000/api/oauth/token",
|
||||
Scopes: []string{"api:read", "api:write"},
|
||||
EndpointParams: map[string][]string{
|
||||
"audience": {"api://new-api"},
|
||||
},
|
||||
}
|
||||
|
||||
// 创建HTTP客户端
|
||||
httpClient := cfg.Client(context.Background())
|
||||
|
||||
// 调用受保护的API
|
||||
resp, err := httpClient.Get("http://localhost:3000/api/status")
|
||||
if err != nil {
|
||||
log.Printf("Request failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Status: %s\n", resp.Status)
|
||||
fmt.Printf("Response: %s\n", string(body))
|
||||
}
|
||||
|
||||
// testAuthorizationCode 测试授权码流程
|
||||
func testAuthorizationCode() {
|
||||
fmt.Println("=== Testing Authorization Code + PKCE Flow ===")
|
||||
|
||||
conf := oauth2.Config{
|
||||
ClientID: "client_dsFyyoyNZWjhbNa2", // 需要先创建客户端
|
||||
ClientSecret: "JHiugKf89OMmTLuZMZyA2sgZnO0Ioae3",
|
||||
RedirectURL: "http://localhost:9999/callback",
|
||||
// 包含 openid/profile/email 以便调用 UserInfo
|
||||
Scopes: []string{"openid", "profile", "email", "api:read"},
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "http://localhost:3000/api/oauth/authorize",
|
||||
TokenURL: "http://localhost:3000/api/oauth/token",
|
||||
},
|
||||
}
|
||||
|
||||
// 生成PKCE参数
|
||||
codeVerifier := oauth2.GenerateVerifier()
|
||||
state := fmt.Sprintf("state-%d", time.Now().Unix())
|
||||
|
||||
// 构建授权URL
|
||||
url := conf.AuthCodeURL(
|
||||
state,
|
||||
oauth2.S256ChallengeOption(codeVerifier),
|
||||
//oauth2.SetAuthURLParam("audience", "api://new-api"),
|
||||
)
|
||||
|
||||
fmt.Printf("Visit this URL to authorize:\n%s\n\n", url)
|
||||
fmt.Printf("A local server will listen on http://localhost:9999/callback to receive the code...\n")
|
||||
|
||||
// 启动回调本地服务器,自动接收授权码
|
||||
codeCh := make(chan string, 1)
|
||||
srv := &http.Server{Addr: ":9999"}
|
||||
http.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
if errParam := q.Get("error"); errParam != "" {
|
||||
fmt.Fprintf(w, "Authorization failed: %s", errParam)
|
||||
return
|
||||
}
|
||||
gotState := q.Get("state")
|
||||
if gotState != state {
|
||||
http.Error(w, "state mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
code := q.Get("code")
|
||||
if code == "" {
|
||||
http.Error(w, "missing code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
fmt.Fprintln(w, "Authorization received. You may close this window.")
|
||||
select {
|
||||
case codeCh <- code:
|
||||
default:
|
||||
}
|
||||
go func() {
|
||||
// 稍后关闭服务
|
||||
_ = srv.Shutdown(context.Background())
|
||||
}()
|
||||
})
|
||||
go func() {
|
||||
_ = srv.ListenAndServe()
|
||||
}()
|
||||
|
||||
// 等待授权码
|
||||
var code string
|
||||
select {
|
||||
case code = <-codeCh:
|
||||
case <-time.After(5 * time.Minute):
|
||||
log.Println("Timeout waiting for authorization code")
|
||||
_ = srv.Shutdown(context.Background())
|
||||
return
|
||||
}
|
||||
|
||||
// 交换令牌
|
||||
token, err := conf.Exchange(
|
||||
context.Background(),
|
||||
code,
|
||||
oauth2.VerifierOption(codeVerifier),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Token exchange failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Access Token: %s\n", token.AccessToken)
|
||||
fmt.Printf("Token Type: %s\n", token.TokenType)
|
||||
fmt.Printf("Expires In: %v\n", token.Expiry)
|
||||
|
||||
// 使用令牌调用 UserInfo
|
||||
client := conf.Client(context.Background(), token)
|
||||
userInfoURL := buildUserInfoFromAuth(conf.Endpoint.AuthURL)
|
||||
resp, err := client.Get(userInfoURL)
|
||||
if err != nil {
|
||||
log.Printf("UserInfo request failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read UserInfo response: %v", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("UserInfo: %s\n", string(body))
|
||||
}
|
||||
|
||||
// buildUserInfoFromAuth 将授权端点URL转换为UserInfo端点URL
|
||||
func buildUserInfoFromAuth(auth string) string {
|
||||
u, err := url.Parse(auth)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
// 将最后一个路径段 authorize 替换为 userinfo
|
||||
dir := path.Dir(u.Path)
|
||||
if strings.HasSuffix(u.Path, "/authorize") {
|
||||
u.Path = path.Join(dir, "userinfo")
|
||||
} else {
|
||||
// 回退:追加默认 /oauth/userinfo
|
||||
u.Path = path.Join(dir, "userinfo")
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
23
go.mod
23
go.mod
@@ -11,20 +11,24 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0
|
||||
github.com/aws/smithy-go v1.22.5
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
||||
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6
|
||||
github.com/gin-contrib/cors v1.7.2
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
github.com/gin-contrib/static v0.0.1
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/glebarez/sqlite v1.9.0
|
||||
github.com/go-oauth2/gin-server v1.1.0
|
||||
github.com/go-oauth2/oauth2/v4 v4.5.4
|
||||
github.com/go-playground/validator/v10 v10.20.0
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/jinzhu/copier v0.4.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/lestrrat-go/jwx/v2 v2.1.6
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pquerna/otp v1.5.0
|
||||
github.com/samber/lo v1.39.0
|
||||
@@ -38,6 +42,7 @@ require (
|
||||
golang.org/x/crypto v0.35.0
|
||||
golang.org/x/image v0.23.0
|
||||
golang.org/x/net v0.35.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.11.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
@@ -55,6 +60,7 @@ require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
@@ -65,7 +71,7 @@ require (
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/goccy/go-json v0.10.3 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/gorilla/context v1.1.1 // indirect
|
||||
github.com/gorilla/securecookie v1.1.1 // indirect
|
||||
@@ -79,14 +85,25 @@ require (
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lestrrat-go/blackmagic v1.0.3 // indirect
|
||||
github.com/lestrrat-go/httpcc v1.0.1 // indirect
|
||||
github.com/lestrrat-go/httprc v1.0.6 // indirect
|
||||
github.com/lestrrat-go/iter v1.0.2 // indirect
|
||||
github.com/lestrrat-go/option v1.0.1 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/segmentio/asm v1.2.0 // indirect
|
||||
github.com/tidwall/btree v0.0.0-20191029221954-400434d76274 // indirect
|
||||
github.com/tidwall/buntdb v1.1.2 // indirect
|
||||
github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e // indirect
|
||||
github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
@@ -94,7 +111,7 @@ require (
|
||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||
golang.org/x/arch v0.12.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.22.0 // indirect
|
||||
google.golang.org/protobuf v1.34.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
||||
94
go.sum
94
go.sum
@@ -1,5 +1,7 @@
|
||||
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
|
||||
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
|
||||
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
|
||||
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
|
||||
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
||||
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
|
||||
@@ -23,8 +25,8 @@ github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo=
|
||||
github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
|
||||
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 h1:FCLDGi1EmB7JzjVVYNZiqc/zAJj2BQ5M0lfkVOxbfs8=
|
||||
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6/go.mod h1:5FoAH5xUHHCMDvQPy1rnj8moqLkLHFaDVBjHhcFwEi0=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
@@ -39,16 +41,22 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc=
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
|
||||
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo=
|
||||
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gavv/httpexpect v2.0.0+incompatible h1:1X9kcRshkSKEjNJJxX9Y9mQ5BRfbxU5kORdjhlA1yX8=
|
||||
github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc=
|
||||
github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw=
|
||||
github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E=
|
||||
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
|
||||
@@ -67,6 +75,10 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g
|
||||
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
||||
github.com/glebarez/sqlite v1.9.0 h1:Aj6bPA12ZEx5GbSF6XADmCkYXlljPNUY+Zf1EQxynXs=
|
||||
github.com/glebarez/sqlite v1.9.0/go.mod h1:YBYCoyupOao60lzp1MVBLEjZfgkq0tdB1voAQ09K9zw=
|
||||
github.com/go-oauth2/gin-server v1.1.0 h1:+7AyIfrcKaThZxxABRYECysxAfTccgpFdAqY1enuzBk=
|
||||
github.com/go-oauth2/gin-server v1.1.0/go.mod h1:f08F3l5/Pbayb4pjnv5PpUdQLFejgGfHrTjA6IZb0eM=
|
||||
github.com/go-oauth2/oauth2/v4 v4.5.4 h1:YjI0tmGW8oxVhn9QSBIxlr641QugWrJY5UWa6XmLcW0=
|
||||
github.com/go-oauth2/oauth2/v4 v4.5.4/go.mod h1:BXiOY+QZtZy2ewbsGk2B5P8TWmtz/Rf7ES5ZttQFxfQ=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
@@ -90,20 +102,26 @@ github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB
|
||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
|
||||
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=
|
||||
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0=
|
||||
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
||||
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
||||
@@ -112,6 +130,8 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg
|
||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk=
|
||||
github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -132,6 +152,10 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA
|
||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
|
||||
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
|
||||
github.com/klauspost/compress v1.15.0 h1:xqfchp4whNFxn5A4XFyyYtitiWI8Hy5EW59jEwcyL6U=
|
||||
github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY=
|
||||
github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
|
||||
@@ -148,6 +172,18 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx
|
||||
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lestrrat-go/blackmagic v1.0.3 h1:94HXkVLxkZO9vJI/w2u1T0DAoprShFd13xtnSINtDWs=
|
||||
github.com/lestrrat-go/blackmagic v1.0.3/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw=
|
||||
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
|
||||
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
|
||||
github.com/lestrrat-go/httprc v1.0.6 h1:qgmgIRhpvBqexMJjA/PmwSvhNk679oqD1RbovdCGW8k=
|
||||
github.com/lestrrat-go/httprc v1.0.6/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
|
||||
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
|
||||
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
|
||||
github.com/lestrrat-go/jwx/v2 v2.1.6 h1:hxM1gfDILk/l5ylers6BX/Eq1m/pnxe9NBwW6lVfecA=
|
||||
github.com/lestrrat-go/jwx/v2 v2.1.6/go.mod h1:Y722kU5r/8mV7fYDifjug0r8FK8mZdw0K0GpJw/l8pU=
|
||||
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
|
||||
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
|
||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
@@ -160,6 +196,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs=
|
||||
github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
@@ -184,10 +222,18 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
|
||||
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
|
||||
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
|
||||
github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0=
|
||||
github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/smartystreets/assertions v1.1.0 h1:MkTeG1DMwsrdH7QtLXy5W+fUxWq+vmb6cLmyJ7aRtF0=
|
||||
github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
|
||||
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
|
||||
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
@@ -200,21 +246,35 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
|
||||
github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
|
||||
github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
|
||||
github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
|
||||
github.com/tidwall/btree v0.0.0-20191029221954-400434d76274 h1:G6Z6HvJuPjG6XfNGi/feOATzeJrfgTNJY+rGrHbA04E=
|
||||
github.com/tidwall/btree v0.0.0-20191029221954-400434d76274/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8=
|
||||
github.com/tidwall/buntdb v1.1.2 h1:noCrqQXL9EKMtcdwJcmuVKSEjqu1ua99RHHgbLTEHRo=
|
||||
github.com/tidwall/buntdb v1.1.2/go.mod h1:xAzi36Hir4FarpSHyfuZ6JzPJdjRZ8QlLZSntE2mqlI=
|
||||
github.com/tidwall/gjson v1.3.4/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb h1:5NSYaAdrnblKByzd7XByQEJVT8+9v0W/tIY0Oo4OwrE=
|
||||
github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb/go.mod h1:lKYYLFIr9OIgdgrtgkZ9zgRxRdvPYsExnYBsEAd8W5M=
|
||||
github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e h1:+NL1GDIUOKxVfbp2KoJQD9cTQ6dyP2co9q4yzmT9FZo=
|
||||
github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e/go.mod h1:/h+UnNGt0IhNNJLkGikcdcJqm66zGD/uJGMRxK/9+Ao=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563 h1:Otn9S136ELckZ3KKDyCkxapfufrqDqwmGjcHfAyXRrE=
|
||||
github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563/go.mod h1:mLqSmt7Dv/CNneF2wfcChfN1rvapyQr01LGKnKex0DQ=
|
||||
github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
|
||||
github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
@@ -229,8 +289,24 @@ github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLY
|
||||
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.34.0 h1:d3AAQJ2DRcxJYHm7OXNXtXt2as1vMDfxeIcFvhmGGm4=
|
||||
github.com/valyala/fasthttp v1.34.0/go.mod h1:epZA5N+7pY6ZaEKRmstzOuYJx9HI8DI1oaCGZpdH4h0=
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c=
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
|
||||
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0=
|
||||
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
|
||||
github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
|
||||
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 h1:6fRhSjgLCkTD3JnJxvaJ4Sj+TYblw757bqYgZaOq5ZY=
|
||||
github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI=
|
||||
github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA=
|
||||
github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg=
|
||||
github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3IfnEUduWvb9is428/nNb5L3U01M=
|
||||
github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM=
|
||||
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
|
||||
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
@@ -247,6 +323,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
@@ -257,12 +335,12 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
|
||||
11
main.go
11
main.go
@@ -14,6 +14,7 @@ import (
|
||||
"one-api/router"
|
||||
"one-api/service"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/src/oauth"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
@@ -203,5 +204,13 @@ func InitResources() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize OAuth2 server
|
||||
err = oauth.InitOAuthServer()
|
||||
if err != nil {
|
||||
common.SysLog("Warning: Failed to initialize OAuth2 server: " + err.Error())
|
||||
// OAuth2 失败不应该阻止系统启动
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -8,11 +8,14 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"one-api/src/oauth"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
func validUserInfo(username string, role int) bool {
|
||||
@@ -177,6 +180,7 @@ func WssAuth(c *gin.Context) {
|
||||
|
||||
func TokenAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
rawAuth := c.Request.Header.Get("Authorization")
|
||||
// 先检测是否为ws
|
||||
if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
|
||||
// Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
|
||||
@@ -235,6 +239,11 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// OAuth Bearer fallback
|
||||
if tryOAuthBearer(c, rawAuth) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -288,6 +297,74 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// tryOAuthBearer validates an OAuth JWT access token and sets minimal context for relay
|
||||
func tryOAuthBearer(c *gin.Context, rawAuth string) bool {
|
||||
if rawAuth == "" || !strings.HasPrefix(rawAuth, "Bearer ") {
|
||||
return false
|
||||
}
|
||||
tokenString := strings.TrimSpace(strings.TrimPrefix(rawAuth, "Bearer "))
|
||||
if tokenString == "" {
|
||||
return false
|
||||
}
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
// Parse & verify
|
||||
parsed, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, jwt.ErrTokenSignatureInvalid
|
||||
}
|
||||
if kid, ok := t.Header["kid"].(string); ok {
|
||||
if settings.JWTKeyID != "" && kid != settings.JWTKeyID {
|
||||
return nil, jwt.ErrTokenSignatureInvalid
|
||||
}
|
||||
}
|
||||
pub := oauth.GetRSAPublicKey()
|
||||
if pub == nil {
|
||||
return nil, jwt.ErrTokenUnverifiable
|
||||
}
|
||||
return pub, nil
|
||||
})
|
||||
if err != nil || parsed == nil || !parsed.Valid {
|
||||
return false
|
||||
}
|
||||
claims, ok := parsed.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// issuer check when configured
|
||||
if iss, ok2 := claims["iss"].(string); !ok2 || (settings.Issuer != "" && iss != settings.Issuer) {
|
||||
return false
|
||||
}
|
||||
// revoke check
|
||||
if jti, ok2 := claims["jti"].(string); ok2 && jti != "" {
|
||||
if revoked, _ := model.IsTokenRevoked(jti); revoked {
|
||||
return false
|
||||
}
|
||||
}
|
||||
// scope check: must contain api:read or api:write or admin
|
||||
scope, _ := claims["scope"].(string)
|
||||
scopePadded := " " + scope + " "
|
||||
if !(strings.Contains(scopePadded, " api:read ") || strings.Contains(scopePadded, " api:write ") || strings.Contains(scopePadded, " admin ")) {
|
||||
return false
|
||||
}
|
||||
// subject must be user id to support quota logic
|
||||
sub, _ := claims["sub"].(string)
|
||||
uid, err := strconv.Atoi(sub)
|
||||
if err != nil || uid <= 0 {
|
||||
return false
|
||||
}
|
||||
// load user cache & set context
|
||||
userCache, err := model.GetUserCache(uid)
|
||||
if err != nil || userCache == nil || userCache.Status != common.UserStatusEnabled {
|
||||
return false
|
||||
}
|
||||
c.Set("id", uid)
|
||||
c.Set("group", userCache.Group)
|
||||
c.Set("user_group", userCache.Group)
|
||||
// set UsingGroup
|
||||
common.SetContextKey(c, constant.ContextKeyUsingGroup, userCache.Group)
|
||||
return true
|
||||
}
|
||||
|
||||
func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
|
||||
if token == nil {
|
||||
return fmt.Errorf("token is nil")
|
||||
|
||||
@@ -166,9 +166,9 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
relayMode := relayconstant.RelayModeUnknown
|
||||
if c.Request.Method == http.MethodPost {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
relayMode = relayconstant.RelayModeVideoSubmit
|
||||
} else if c.Request.Method == http.MethodGet {
|
||||
relayMode = relayconstant.RelayModeVideoFetchByID
|
||||
|
||||
291
middleware/oauth_jwt.go
Normal file
291
middleware/oauth_jwt.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting/system_setting"
|
||||
"one-api/src/oauth"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// OAuthJWTAuth OAuth2 JWT认证中间件
|
||||
func OAuthJWTAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查OAuth2是否启用
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 获取Authorization header
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.Next() // 没有Authorization header,继续到下一个中间件
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为Bearer token
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
c.Next() // 不是Bearer token,继续到下一个中间件
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenString == "" {
|
||||
abortWithOAuthError(c, "invalid_token", "Missing token")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证JWT token
|
||||
claims, err := validateOAuthJWT(tokenString)
|
||||
if err != nil {
|
||||
abortWithOAuthError(c, "invalid_token", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token的有效性
|
||||
if err := validateOAuthClaims(claims); err != nil {
|
||||
abortWithOAuthError(c, "invalid_token", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 设置上下文信息
|
||||
setOAuthContext(c, claims)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// validateOAuthJWT 验证OAuth2 JWT令牌
|
||||
func validateOAuthJWT(tokenString string) (jwt.MapClaims, error) {
|
||||
// 解析JWT而不验证签名(先获取header中的kid)
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// 检查签名方法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
// 获取kid
|
||||
kid, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing kid in token header")
|
||||
}
|
||||
|
||||
// 根据kid获取公钥
|
||||
publicKey, err := getPublicKeyByKid(kid)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
return publicKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// getPublicKeyByKid 根据kid获取公钥
|
||||
func getPublicKeyByKid(kid string) (*rsa.PublicKey, error) {
|
||||
// 这里需要从JWKS获取公钥
|
||||
// 在实际实现中,你可能需要从OAuth server获取JWKS
|
||||
// 这里先实现一个简单版本
|
||||
|
||||
// TODO: 实现JWKS缓存和刷新机制
|
||||
pub := oauth.GetPublicKeyByKid(kid)
|
||||
if pub == nil {
|
||||
return nil, fmt.Errorf("unknown kid: %s", kid)
|
||||
}
|
||||
return pub, nil
|
||||
}
|
||||
|
||||
// validateOAuthClaims 验证OAuth2 claims
|
||||
func validateOAuthClaims(claims jwt.MapClaims) error {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
|
||||
// 验证issuer(若配置了 Issuer 则强校验,否则仅要求存在)
|
||||
if iss, ok := claims["iss"].(string); ok {
|
||||
if settings.Issuer != "" && iss != settings.Issuer {
|
||||
return fmt.Errorf("invalid issuer")
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("missing issuer claim")
|
||||
}
|
||||
|
||||
// 验证audience
|
||||
// if aud, ok := claims["aud"].(string); ok {
|
||||
// // TODO: 验证audience
|
||||
// }
|
||||
|
||||
// 验证客户端ID
|
||||
if clientID, ok := claims["client_id"].(string); ok {
|
||||
// 验证客户端是否存在且有效
|
||||
client, err := model.GetOAuthClientByID(clientID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid client")
|
||||
}
|
||||
if client.Status != common.UserStatusEnabled {
|
||||
return fmt.Errorf("client disabled")
|
||||
}
|
||||
|
||||
// 检查是否被撤销
|
||||
if jti, ok := claims["jti"].(string); ok && jti != "" {
|
||||
revoked, _ := model.IsTokenRevoked(jti)
|
||||
if revoked {
|
||||
return fmt.Errorf("token revoked")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("missing client_id claim")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setOAuthContext 设置OAuth上下文信息
|
||||
func setOAuthContext(c *gin.Context, claims jwt.MapClaims) {
|
||||
c.Set("oauth_claims", claims)
|
||||
c.Set("oauth_authenticated", true)
|
||||
|
||||
// 提取基本信息
|
||||
if clientID, ok := claims["client_id"].(string); ok {
|
||||
c.Set("oauth_client_id", clientID)
|
||||
}
|
||||
|
||||
if scope, ok := claims["scope"].(string); ok {
|
||||
c.Set("oauth_scope", scope)
|
||||
}
|
||||
|
||||
if sub, ok := claims["sub"].(string); ok {
|
||||
c.Set("oauth_subject", sub)
|
||||
}
|
||||
|
||||
// 对于client_credentials流程,subject就是client_id
|
||||
// 对于authorization_code流程,subject是用户ID
|
||||
if grantType, ok := claims["grant_type"].(string); ok {
|
||||
c.Set("oauth_grant_type", grantType)
|
||||
}
|
||||
}
|
||||
|
||||
// abortWithOAuthError 返回OAuth错误响应
|
||||
func abortWithOAuthError(c *gin.Context, errorCode, description string) {
|
||||
c.Header("WWW-Authenticate", fmt.Sprintf(`Bearer error="%s", error_description="%s"`, errorCode, description))
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": errorCode,
|
||||
"error_description": description,
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// RequireOAuthScope OAuth2 scope验证中间件
|
||||
func RequireOAuthScope(requiredScope string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查是否通过OAuth认证
|
||||
if !c.GetBool("oauth_authenticated") {
|
||||
abortWithOAuthError(c, "insufficient_scope", "OAuth2 authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取token的scope
|
||||
scope, exists := c.Get("oauth_scope")
|
||||
if !exists {
|
||||
abortWithOAuthError(c, "insufficient_scope", "No scope in token")
|
||||
return
|
||||
}
|
||||
|
||||
scopeStr, ok := scope.(string)
|
||||
if !ok {
|
||||
abortWithOAuthError(c, "insufficient_scope", "Invalid scope format")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否包含所需的scope
|
||||
scopes := strings.Split(scopeStr, " ")
|
||||
for _, s := range scopes {
|
||||
if strings.TrimSpace(s) == requiredScope {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
abortWithOAuthError(c, "insufficient_scope", fmt.Sprintf("Required scope: %s", requiredScope))
|
||||
}
|
||||
}
|
||||
|
||||
// OptionalOAuthAuth 可选的OAuth认证中间件(不会阻止请求)
|
||||
func OptionalOAuthAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 尝试OAuth认证,但不会阻止请求
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") {
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if claims, err := validateOAuthJWT(tokenString); err == nil {
|
||||
if validateOAuthClaims(claims) == nil {
|
||||
setOAuthContext(c, claims)
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireOAuthScopeIfPresent enforces scope only when OAuth is present; otherwise no-op
|
||||
func RequireOAuthScopeIfPresent(requiredScope string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !c.GetBool("oauth_authenticated") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
scope, exists := c.Get("oauth_scope")
|
||||
if !exists {
|
||||
abortWithOAuthError(c, "insufficient_scope", "No scope in token")
|
||||
return
|
||||
}
|
||||
scopeStr, ok := scope.(string)
|
||||
if !ok {
|
||||
abortWithOAuthError(c, "insufficient_scope", "Invalid scope format")
|
||||
return
|
||||
}
|
||||
scopes := strings.Split(scopeStr, " ")
|
||||
for _, s := range scopes {
|
||||
if strings.TrimSpace(s) == requiredScope {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
abortWithOAuthError(c, "insufficient_scope", fmt.Sprintf("Required scope: %s", requiredScope))
|
||||
}
|
||||
}
|
||||
|
||||
// GetOAuthClaims 获取OAuth claims
|
||||
func GetOAuthClaims(c *gin.Context) (jwt.MapClaims, bool) {
|
||||
claims, exists := c.Get("oauth_claims")
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
mapClaims, ok := claims.(jwt.MapClaims)
|
||||
return mapClaims, ok
|
||||
}
|
||||
|
||||
// IsOAuthAuthenticated 检查是否通过OAuth认证
|
||||
func IsOAuthAuthenticated(c *gin.Context) bool {
|
||||
return c.GetBool("oauth_authenticated")
|
||||
}
|
||||
@@ -42,7 +42,6 @@ type Channel struct {
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||
OtherInfo string `json:"other_info"`
|
||||
OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||
@@ -51,6 +50,8 @@ type Channel struct {
|
||||
// add after v0.8.5
|
||||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||
|
||||
OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置,存储azure版本等不需要检索的信息,详见dto.ChannelOtherSettings
|
||||
|
||||
// cache info
|
||||
Keys []string `json:"-" gorm:"-"`
|
||||
}
|
||||
|
||||
@@ -265,6 +265,7 @@ func migrateDB() error {
|
||||
&Setup{},
|
||||
&TwoFA{},
|
||||
&TwoFABackupCode{},
|
||||
&OAuthClient{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
183
model/oauth_client.go
Normal file
183
model/oauth_client.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// OAuthClient OAuth2 客户端模型
|
||||
type OAuthClient struct {
|
||||
ID string `json:"id" gorm:"type:varchar(64);primaryKey"`
|
||||
Secret string `json:"secret" gorm:"type:varchar(128);not null"`
|
||||
Name string `json:"name" gorm:"type:varchar(255);not null"`
|
||||
Domain string `json:"domain" gorm:"type:varchar(255)"` // 允许的重定向域名
|
||||
RedirectURIs string `json:"redirect_uris" gorm:"type:text"` // JSON array of redirect URIs
|
||||
GrantTypes string `json:"grant_types" gorm:"type:varchar(255);default:'client_credentials'"`
|
||||
Scopes string `json:"scopes" gorm:"type:varchar(255);default:'api:read'"`
|
||||
RequirePKCE bool `json:"require_pkce" gorm:"default:true"`
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // 1: enabled, 2: disabled
|
||||
CreatedBy int `json:"created_by" gorm:"type:int;not null"` // 创建者用户ID
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
LastUsedTime int64 `json:"last_used_time" gorm:"bigint;default:0"`
|
||||
TokenCount int `json:"token_count" gorm:"type:int;default:0"` // 已签发的token数量
|
||||
Description string `json:"description" gorm:"type:text"`
|
||||
ClientType string `json:"client_type" gorm:"type:varchar(32);default:'confidential'"` // confidential, public
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
// GetRedirectURIs 获取重定向URI列表
|
||||
func (c *OAuthClient) GetRedirectURIs() []string {
|
||||
if c.RedirectURIs == "" {
|
||||
return []string{}
|
||||
}
|
||||
var uris []string
|
||||
err := json.Unmarshal([]byte(c.RedirectURIs), &uris)
|
||||
if err != nil {
|
||||
common.SysLog("failed to unmarshal redirect URIs: " + err.Error())
|
||||
return []string{}
|
||||
}
|
||||
return uris
|
||||
}
|
||||
|
||||
// SetRedirectURIs 设置重定向URI列表
|
||||
func (c *OAuthClient) SetRedirectURIs(uris []string) {
|
||||
data, err := json.Marshal(uris)
|
||||
if err != nil {
|
||||
common.SysLog("failed to marshal redirect URIs: " + err.Error())
|
||||
return
|
||||
}
|
||||
c.RedirectURIs = string(data)
|
||||
}
|
||||
|
||||
// GetGrantTypes 获取允许的授权类型列表
|
||||
func (c *OAuthClient) GetGrantTypes() []string {
|
||||
if c.GrantTypes == "" {
|
||||
return []string{"client_credentials"}
|
||||
}
|
||||
return strings.Split(c.GrantTypes, ",")
|
||||
}
|
||||
|
||||
// SetGrantTypes 设置允许的授权类型列表
|
||||
func (c *OAuthClient) SetGrantTypes(types []string) {
|
||||
c.GrantTypes = strings.Join(types, ",")
|
||||
}
|
||||
|
||||
// GetScopes 获取允许的scope列表
|
||||
func (c *OAuthClient) GetScopes() []string {
|
||||
if c.Scopes == "" {
|
||||
return []string{"api:read"}
|
||||
}
|
||||
return strings.Split(c.Scopes, ",")
|
||||
}
|
||||
|
||||
// SetScopes 设置允许的scope列表
|
||||
func (c *OAuthClient) SetScopes(scopes []string) {
|
||||
c.Scopes = strings.Join(scopes, ",")
|
||||
}
|
||||
|
||||
// ValidateRedirectURI 验证重定向URI是否有效
|
||||
func (c *OAuthClient) ValidateRedirectURI(uri string) bool {
|
||||
allowedURIs := c.GetRedirectURIs()
|
||||
for _, allowedURI := range allowedURIs {
|
||||
if allowedURI == uri {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateGrantType 验证授权类型是否被允许
|
||||
func (c *OAuthClient) ValidateGrantType(grantType string) bool {
|
||||
allowedTypes := c.GetGrantTypes()
|
||||
for _, allowedType := range allowedTypes {
|
||||
if allowedType == grantType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateScope 验证scope是否被允许
|
||||
func (c *OAuthClient) ValidateScope(scope string) bool {
|
||||
allowedScopes := c.GetScopes()
|
||||
requestedScopes := strings.Split(scope, " ")
|
||||
|
||||
for _, requestedScope := range requestedScopes {
|
||||
requestedScope = strings.TrimSpace(requestedScope)
|
||||
if requestedScope == "" {
|
||||
continue
|
||||
}
|
||||
found := false
|
||||
for _, allowedScope := range allowedScopes {
|
||||
if allowedScope == requestedScope {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// BeforeCreate GORM hook - 在创建前设置时间
|
||||
func (c *OAuthClient) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
c.CreatedTime = time.Now().Unix()
|
||||
return
|
||||
}
|
||||
|
||||
// UpdateLastUsedTime 更新最后使用时间
|
||||
func (c *OAuthClient) UpdateLastUsedTime() error {
|
||||
c.LastUsedTime = time.Now().Unix()
|
||||
c.TokenCount++
|
||||
return DB.Model(c).Select("last_used_time", "token_count").Updates(c).Error
|
||||
}
|
||||
|
||||
// GetOAuthClientByID 根据ID获取OAuth客户端
|
||||
func GetOAuthClientByID(id string) (*OAuthClient, error) {
|
||||
var client OAuthClient
|
||||
err := DB.Where("id = ? AND status = ?", id, common.UserStatusEnabled).First(&client).Error
|
||||
return &client, err
|
||||
}
|
||||
|
||||
// GetAllOAuthClients 获取所有OAuth客户端
|
||||
func GetAllOAuthClients(startIdx int, num int) ([]*OAuthClient, error) {
|
||||
var clients []*OAuthClient
|
||||
err := DB.Order("created_time desc").Limit(num).Offset(startIdx).Find(&clients).Error
|
||||
return clients, err
|
||||
}
|
||||
|
||||
// SearchOAuthClients 搜索OAuth客户端
|
||||
func SearchOAuthClients(keyword string) ([]*OAuthClient, error) {
|
||||
var clients []*OAuthClient
|
||||
err := DB.Where("name LIKE ? OR id LIKE ? OR description LIKE ?",
|
||||
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%").Find(&clients).Error
|
||||
return clients, err
|
||||
}
|
||||
|
||||
// CreateOAuthClient 创建OAuth客户端
|
||||
func CreateOAuthClient(client *OAuthClient) error {
|
||||
return DB.Create(client).Error
|
||||
}
|
||||
|
||||
// UpdateOAuthClient 更新OAuth客户端
|
||||
func UpdateOAuthClient(client *OAuthClient) error {
|
||||
return DB.Save(client).Error
|
||||
}
|
||||
|
||||
// DeleteOAuthClient 删除OAuth客户端
|
||||
func DeleteOAuthClient(id string) error {
|
||||
return DB.Where("id = ?", id).Delete(&OAuthClient{}).Error
|
||||
}
|
||||
|
||||
// CountOAuthClients 统计OAuth客户端数量
|
||||
func CountOAuthClients() (int64, error) {
|
||||
var count int64
|
||||
err := DB.Model(&OAuthClient{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
57
model/oauth_revoked_token.go
Normal file
57
model/oauth_revoked_token.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var revokedMem sync.Map // jti -> exp(unix)
|
||||
|
||||
func RevokeToken(jti string, exp int64) error {
|
||||
if jti == "" {
|
||||
return nil
|
||||
}
|
||||
// Prefer Redis, else in-memory
|
||||
if common.RedisEnabled {
|
||||
ttl := time.Duration(0)
|
||||
if exp > 0 {
|
||||
ttl = time.Until(time.Unix(exp, 0))
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = time.Minute
|
||||
}
|
||||
key := fmt.Sprintf("oauth:revoked:%s", jti)
|
||||
return common.RedisSet(key, "1", ttl)
|
||||
}
|
||||
if exp <= 0 {
|
||||
exp = time.Now().Add(time.Minute).Unix()
|
||||
}
|
||||
revokedMem.Store(jti, exp)
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsTokenRevoked(jti string) (bool, error) {
|
||||
if jti == "" {
|
||||
return false, nil
|
||||
}
|
||||
if common.RedisEnabled {
|
||||
key := fmt.Sprintf("oauth:revoked:%s", jti)
|
||||
if _, err := common.RedisGet(key); err == nil {
|
||||
return true, nil
|
||||
} else {
|
||||
// Not found or error; treat as not revoked on error to avoid hard failures
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
// In-memory check
|
||||
if v, ok := revokedMem.Load(jti); ok {
|
||||
exp, _ := v.(int64)
|
||||
if exp == 0 || time.Now().Unix() <= exp {
|
||||
return true, nil
|
||||
}
|
||||
revokedMem.Delete(jti)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"one-api/setting/config"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -66,16 +67,16 @@ func InitOptionMap() {
|
||||
common.OptionMap["SystemName"] = common.SystemName
|
||||
common.OptionMap["Logo"] = common.Logo
|
||||
common.OptionMap["ServerAddress"] = ""
|
||||
common.OptionMap["WorkerUrl"] = setting.WorkerUrl
|
||||
common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
|
||||
common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(setting.WorkerAllowHttpImageRequestEnabled)
|
||||
common.OptionMap["WorkerUrl"] = system_setting.WorkerUrl
|
||||
common.OptionMap["WorkerValidKey"] = system_setting.WorkerValidKey
|
||||
common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(system_setting.WorkerAllowHttpImageRequestEnabled)
|
||||
common.OptionMap["PayAddress"] = ""
|
||||
common.OptionMap["CustomCallbackAddress"] = ""
|
||||
common.OptionMap["EpayId"] = ""
|
||||
common.OptionMap["EpayKey"] = ""
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
|
||||
common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(setting.USDExchangeRate, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(operation_setting.Price, 'f', -1, 64)
|
||||
common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(operation_setting.USDExchangeRate, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(operation_setting.MinTopUp)
|
||||
common.OptionMap["StripeMinTopUp"] = strconv.Itoa(setting.StripeMinTopUp)
|
||||
common.OptionMap["StripeApiSecret"] = setting.StripeApiSecret
|
||||
common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
|
||||
@@ -85,7 +86,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
|
||||
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
|
||||
common.OptionMap["PayMethods"] = operation_setting.PayMethods2JsonString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
@@ -271,7 +272,7 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "SMTPSSLEnabled":
|
||||
common.SMTPSSLEnabled = boolValue
|
||||
case "WorkerAllowHttpImageRequestEnabled":
|
||||
setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
||||
system_setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
||||
case "DefaultUseAutoGroup":
|
||||
setting.DefaultUseAutoGroup = boolValue
|
||||
case "ExposeRatioEnabled":
|
||||
@@ -293,29 +294,29 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "SMTPToken":
|
||||
common.SMTPToken = value
|
||||
case "ServerAddress":
|
||||
setting.ServerAddress = value
|
||||
system_setting.ServerAddress = value
|
||||
case "WorkerUrl":
|
||||
setting.WorkerUrl = value
|
||||
system_setting.WorkerUrl = value
|
||||
case "WorkerValidKey":
|
||||
setting.WorkerValidKey = value
|
||||
system_setting.WorkerValidKey = value
|
||||
case "PayAddress":
|
||||
setting.PayAddress = value
|
||||
operation_setting.PayAddress = value
|
||||
case "Chats":
|
||||
err = setting.UpdateChatsByJsonString(value)
|
||||
case "AutoGroups":
|
||||
err = setting.UpdateAutoGroupsByJsonString(value)
|
||||
case "CustomCallbackAddress":
|
||||
setting.CustomCallbackAddress = value
|
||||
operation_setting.CustomCallbackAddress = value
|
||||
case "EpayId":
|
||||
setting.EpayId = value
|
||||
operation_setting.EpayId = value
|
||||
case "EpayKey":
|
||||
setting.EpayKey = value
|
||||
operation_setting.EpayKey = value
|
||||
case "Price":
|
||||
setting.Price, _ = strconv.ParseFloat(value, 64)
|
||||
operation_setting.Price, _ = strconv.ParseFloat(value, 64)
|
||||
case "USDExchangeRate":
|
||||
setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
|
||||
operation_setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
|
||||
case "MinTopUp":
|
||||
setting.MinTopUp, _ = strconv.Atoi(value)
|
||||
operation_setting.MinTopUp, _ = strconv.Atoi(value)
|
||||
case "StripeApiSecret":
|
||||
setting.StripeApiSecret = value
|
||||
case "StripeWebhookSecret":
|
||||
@@ -413,7 +414,7 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "StreamCacheQueueLength":
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
case "PayMethods":
|
||||
err = setting.UpdatePayMethodsByJsonString(value)
|
||||
err = operation_setting.UpdatePayMethodsByJsonString(value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -264,9 +264,8 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, errors.New("resp is nil")
|
||||
|
||||
@@ -60,7 +60,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
// 检查是否为Nova模型
|
||||
if isNovaModel(request.Model) {
|
||||
novaReq := convertToNovaRequest(request)
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", novaReq)
|
||||
c.Set("is_nova_model", true)
|
||||
return novaReq, nil
|
||||
}
|
||||
|
||||
// 原有的Claude模型处理逻辑
|
||||
var claudeReq *dto.ClaudeRequest
|
||||
var err error
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||
@@ -69,6 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
c.Set("request_model", claudeReq.Model)
|
||||
c.Set("converted_request", claudeReq)
|
||||
c.Set("is_nova_model", false)
|
||||
return claudeReq, err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package aws
|
||||
|
||||
import "strings"
|
||||
|
||||
var awsModelIDMap = map[string]string{
|
||||
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
||||
"claude-2.0": "anthropic.claude-v2",
|
||||
@@ -14,6 +16,11 @@ var awsModelIDMap = map[string]string{
|
||||
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
|
||||
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
// Nova models
|
||||
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
|
||||
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
|
||||
"nova-pro-v1:0": "amazon.nova-pro-v1:0",
|
||||
"nova-premier-v1:0": "amazon.nova-premier-v1:0",
|
||||
}
|
||||
|
||||
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
@@ -58,7 +65,27 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
}
|
||||
// Nova models - all support three major regions
|
||||
"amazon.nova-micro-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-lite-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-pro-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-premier-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
}}
|
||||
|
||||
var awsRegionCrossModelPrefixMap = map[string]string{
|
||||
"us": "us",
|
||||
@@ -67,3 +94,8 @@ var awsRegionCrossModelPrefixMap = map[string]string{
|
||||
}
|
||||
|
||||
var ChannelName = "aws"
|
||||
|
||||
// 判断是否为Nova模型
|
||||
func isNovaModel(modelId string) bool {
|
||||
return strings.HasPrefix(modelId, "nova-")
|
||||
}
|
||||
|
||||
@@ -34,3 +34,92 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
|
||||
Thinking: req.Thinking,
|
||||
}
|
||||
}
|
||||
|
||||
// NovaMessage Nova模型使用messages-v1格式
|
||||
type NovaMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []NovaContent `json:"content"`
|
||||
}
|
||||
|
||||
type NovaContent struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type NovaRequest struct {
|
||||
SchemaVersion string `json:"schemaVersion"` // 请求版本,例如 "1.0"
|
||||
Messages []NovaMessage `json:"messages"` // 对话消息列表
|
||||
InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"` // 推理配置,可选
|
||||
}
|
||||
|
||||
type NovaInferenceConfig struct {
|
||||
MaxTokens int `json:"maxTokens,omitempty"` // 最大生成的 token 数
|
||||
Temperature float64 `json:"temperature,omitempty"` // 随机性 (默认 0.7, 范围 0-1)
|
||||
TopP float64 `json:"topP,omitempty"` // nucleus sampling (默认 0.9, 范围 0-1)
|
||||
TopK int `json:"topK,omitempty"` // 限制候选 token 数 (默认 50, 范围 0-128)
|
||||
StopSequences []string `json:"stopSequences,omitempty"` // 停止生成的序列
|
||||
}
|
||||
|
||||
// 转换OpenAI请求为Nova格式
|
||||
func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
|
||||
novaMessages := make([]NovaMessage, len(req.Messages))
|
||||
for i, msg := range req.Messages {
|
||||
novaMessages[i] = NovaMessage{
|
||||
Role: msg.Role,
|
||||
Content: []NovaContent{{Text: msg.StringContent()}},
|
||||
}
|
||||
}
|
||||
|
||||
novaReq := &NovaRequest{
|
||||
SchemaVersion: "messages-v1",
|
||||
Messages: novaMessages,
|
||||
}
|
||||
|
||||
// 设置推理配置
|
||||
if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 || req.TopK != 0 || req.Stop != nil {
|
||||
novaReq.InferenceConfig = &NovaInferenceConfig{}
|
||||
if req.MaxTokens != 0 {
|
||||
novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens)
|
||||
}
|
||||
if req.Temperature != nil && *req.Temperature != 0 {
|
||||
novaReq.InferenceConfig.Temperature = *req.Temperature
|
||||
}
|
||||
if req.TopP != 0 {
|
||||
novaReq.InferenceConfig.TopP = req.TopP
|
||||
}
|
||||
if req.TopK != 0 {
|
||||
novaReq.InferenceConfig.TopK = req.TopK
|
||||
}
|
||||
if req.Stop != nil {
|
||||
if stopSequences := parseStopSequences(req.Stop); len(stopSequences) > 0 {
|
||||
novaReq.InferenceConfig.StopSequences = stopSequences
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return novaReq
|
||||
}
|
||||
|
||||
// parseStopSequences 解析停止序列,支持字符串或字符串数组
|
||||
func parseStopSequences(stop any) []string {
|
||||
if stop == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := stop.(type) {
|
||||
case string:
|
||||
if v != "" {
|
||||
return []string{v}
|
||||
}
|
||||
case []string:
|
||||
return v
|
||||
case []interface{}:
|
||||
var sequences []string
|
||||
for _, item := range v {
|
||||
if str, ok := item.(string); ok && str != "" {
|
||||
sequences = append(sequences, str)
|
||||
}
|
||||
}
|
||||
return sequences
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -93,7 +94,19 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
}
|
||||
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
// 检查是否为Nova模型
|
||||
isNova, _ := c.Get("is_nova_model")
|
||||
if isNova == true {
|
||||
// Nova模型也支持跨区域
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
return handleNovaRequest(c, awsCli, info, awsModelId)
|
||||
}
|
||||
|
||||
// 原有的Claude处理逻辑
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
@@ -209,3 +222,74 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
// Nova模型处理函数
|
||||
func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
|
||||
novaReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
novaReq := novaReq_.(*NovaRequest)
|
||||
|
||||
// 使用InvokeModel API,但使用Nova格式的请求体
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(novaReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
awsReq.Body = reqBody
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
// 解析Nova响应
|
||||
var novaResp struct {
|
||||
Output struct {
|
||||
Message struct {
|
||||
Content []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"output"`
|
||||
Usage struct {
|
||||
InputTokens int `json:"inputTokens"`
|
||||
OutputTokens int `json:"outputTokens"`
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil {
|
||||
return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
// 构造OpenAI格式响应
|
||||
response := dto.OpenAITextResponse{
|
||||
Id: helper.GetResponseID(c),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
Choices: []dto.OpenAITextResponseChoice{{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: novaResp.Output.Message.Content[0].Text,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}},
|
||||
Usage: dto.Usage{
|
||||
PromptTokens: novaResp.Usage.InputTokens,
|
||||
CompletionTokens: novaResp.Usage.OutputTokens,
|
||||
TotalTokens: novaResp.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
return nil, &response.Usage
|
||||
}
|
||||
|
||||
@@ -46,32 +46,6 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
imageOutputCounts := 0
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && strings.HasPrefix(part.InlineData.MimeType, "image/") {
|
||||
imageOutputCounts++
|
||||
}
|
||||
}
|
||||
}
|
||||
if imageOutputCounts != 0 {
|
||||
usage.CompletionTokens = usage.CompletionTokens - imageOutputCounts*1290
|
||||
usage.TotalTokens = usage.TotalTokens - imageOutputCounts*1290
|
||||
c.Set("gemini_image_tokens", imageOutputCounts*1290)
|
||||
}
|
||||
}
|
||||
|
||||
// if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
// for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails {
|
||||
// if detail.Modality == "IMAGE" {
|
||||
// usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount
|
||||
// usage.TotalTokens = usage.TotalTokens - detail.TokenCount
|
||||
// c.Set("gemini_image_tokens", detail.TokenCount)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
@@ -162,16 +136,6 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails {
|
||||
if detail.Modality == "IMAGE" {
|
||||
usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount
|
||||
usage.TotalTokens = usage.TotalTokens - detail.TokenCount
|
||||
c.Set("gemini_image_tokens", detail.TokenCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 直接发送 GeminiChatResponse 响应
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
@@ -37,6 +36,7 @@ type requestPayload struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Seed int64 `json:"seed"`
|
||||
AspectRatio string `json:"aspect_ratio"`
|
||||
Frames int `json:"frames,omitempty"`
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
@@ -89,22 +89,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
|
||||
req := relaycommon.TaskSubmitReq{}
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store into context for later usage
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
@@ -327,18 +312,23 @@ func hmacSHA256(key []byte, data []byte) []byte {
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
ReqKey: "jimeng_vgfm_i2v_l20",
|
||||
Prompt: req.Prompt,
|
||||
AspectRatio: "16:9", // Default aspect ratio
|
||||
Seed: -1, // Default to random
|
||||
ReqKey: req.Model,
|
||||
Prompt: req.Prompt,
|
||||
}
|
||||
|
||||
switch req.Duration {
|
||||
case 10:
|
||||
r.Frames = 241 // 24*10+1 = 241
|
||||
default:
|
||||
r.Frames = 121 // 24*5+1 = 121
|
||||
}
|
||||
|
||||
// Handle one-of image_urls or binary_data_base64
|
||||
if req.Image != "" {
|
||||
if strings.HasPrefix(req.Image, "http") {
|
||||
r.ImageUrls = []string{req.Image}
|
||||
if req.HasImage() {
|
||||
if strings.HasPrefix(req.Images[0], "http") {
|
||||
r.ImageUrls = req.Images
|
||||
} else {
|
||||
r.BinaryDataBase64 = []string{req.Image}
|
||||
r.BinaryDataBase64 = req.Images
|
||||
}
|
||||
}
|
||||
metadata := req.Metadata
|
||||
@@ -350,6 +340,22 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
|
||||
// 即梦视频3.0 ReqKey转换
|
||||
// https://www.volcengine.com/docs/85621/1792707
|
||||
if strings.Contains(r.ReqKey, "jimeng_v30") {
|
||||
if len(r.ImageUrls) > 1 {
|
||||
// 多张图片:首尾帧生成
|
||||
r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1)
|
||||
} else if len(r.ImageUrls) == 1 {
|
||||
// 单张图片:图生视频
|
||||
r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1)
|
||||
} else {
|
||||
// 无图片:文生视频
|
||||
r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_t2v_v30", 1)
|
||||
}
|
||||
}
|
||||
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
@@ -28,16 +27,6 @@ import (
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type SubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type TrajectoryPoint struct {
|
||||
X int `json:"x"`
|
||||
Y int `json:"y"`
|
||||
@@ -121,23 +110,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
|
||||
var req SubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store into context for later usage
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
// Use the standard validation method for TaskSubmitReq
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
@@ -166,7 +140,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(SubmitReq)
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
@@ -255,7 +229,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
Prompt: req.Prompt,
|
||||
Image: req.Image,
|
||||
|
||||
355
relay/channel/task/vertex/adaptor.go
Normal file
355
relay/channel/task/vertex/adaptor.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
vertexcore "one-api/relay/channel/vertex"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type requestPayload struct {
|
||||
Instances []map[string]any `json:"instances"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type submitResponse struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type operationVideo struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
Encoding string `json:"encoding"`
|
||||
}
|
||||
|
||||
type operationResponse struct {
|
||||
Name string `json:"name"`
|
||||
Done bool `json:"done"`
|
||||
Response struct {
|
||||
Type string `json:"@type"`
|
||||
RaiMediaFilteredCount int `json:"raiMediaFilteredCount"`
|
||||
Videos []operationVideo `json:"videos"`
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
Encoding string `json:"encoding"`
|
||||
Video string `json:"video"`
|
||||
} `json:"response"`
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.ChannelBaseUrl
|
||||
a.apiKey = info.ApiKey
|
||||
}
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Use the standard validation method for TaskSubmitReq
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate)
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
modelName := info.OriginModelName
|
||||
if modelName == "" {
|
||||
modelName = "veo-3.0-generate-001"
|
||||
}
|
||||
|
||||
region := vertexcore.GetModelRegion(info.ApiVersion, modelName)
|
||||
if strings.TrimSpace(region) == "" {
|
||||
region = "global"
|
||||
}
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning",
|
||||
adc.ProjectID,
|
||||
modelName,
|
||||
), nil
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
modelName,
|
||||
), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to acquire access token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("x-goog-user-project", adc.ProjectID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Vertex specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, ok := c.Get("task_request")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body := requestPayload{
|
||||
Instances: []map[string]any{{"prompt": req.Prompt}},
|
||||
Parameters: map[string]any{},
|
||||
}
|
||||
if req.Metadata != nil {
|
||||
if v, ok := req.Metadata["storageUri"]; ok {
|
||||
body.Parameters["storageUri"] = v
|
||||
}
|
||||
if v, ok := req.Metadata["sampleCount"]; ok {
|
||||
body.Parameters["sampleCount"] = v
|
||||
}
|
||||
}
|
||||
if _, ok := body.Parameters["sampleCount"]; !ok {
|
||||
body.Parameters["sampleCount"] = 1
|
||||
}
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var s submitResponse
|
||||
if err := json.Unmarshal(responseBody, &s); err != nil {
|
||||
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if strings.TrimSpace(s.Name) == "" {
|
||||
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
|
||||
}
|
||||
localID := encodeLocalTaskID(s.Name)
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": localID})
|
||||
return localID, responseBody, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} }
|
||||
func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
upstreamName, err := decodeLocalTaskID(taskID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode task_id failed: %w", err)
|
||||
}
|
||||
region := extractRegionFromOperationName(upstreamName)
|
||||
if region == "" {
|
||||
region = "us-central1"
|
||||
}
|
||||
project := extractProjectFromOperationName(upstreamName)
|
||||
modelName := extractModelFromOperationName(upstreamName)
|
||||
if project == "" || modelName == "" {
|
||||
return nil, fmt.Errorf("cannot extract project/model from operation name")
|
||||
}
|
||||
var url string
|
||||
if region == "global" {
|
||||
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName)
|
||||
} else {
|
||||
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
|
||||
}
|
||||
payload := map[string]string{"operationName": upstreamName}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(key), adc); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to acquire access token: %w", err)
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("x-goog-user-project", adc.ProjectID)
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
var op operationResponse
|
||||
if err := json.Unmarshal(respBody, &op); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
|
||||
}
|
||||
ti := &relaycommon.TaskInfo{}
|
||||
if op.Error.Message != "" {
|
||||
ti.Status = model.TaskStatusFailure
|
||||
ti.Reason = op.Error.Message
|
||||
ti.Progress = "100%"
|
||||
return ti, nil
|
||||
}
|
||||
if !op.Done {
|
||||
ti.Status = model.TaskStatusInProgress
|
||||
ti.Progress = "50%"
|
||||
return ti, nil
|
||||
}
|
||||
ti.Status = model.TaskStatusSuccess
|
||||
ti.Progress = "100%"
|
||||
if len(op.Response.Videos) > 0 {
|
||||
v0 := op.Response.Videos[0]
|
||||
if v0.BytesBase64Encoded != "" {
|
||||
mime := strings.TrimSpace(v0.MimeType)
|
||||
if mime == "" {
|
||||
enc := strings.TrimSpace(v0.Encoding)
|
||||
if enc == "" {
|
||||
enc = "mp4"
|
||||
}
|
||||
if strings.Contains(enc, "/") {
|
||||
mime = enc
|
||||
} else {
|
||||
mime = "video/" + enc
|
||||
}
|
||||
}
|
||||
ti.Url = "data:" + mime + ";base64," + v0.BytesBase64Encoded
|
||||
return ti, nil
|
||||
}
|
||||
}
|
||||
if op.Response.BytesBase64Encoded != "" {
|
||||
enc := strings.TrimSpace(op.Response.Encoding)
|
||||
if enc == "" {
|
||||
enc = "mp4"
|
||||
}
|
||||
mime := enc
|
||||
if !strings.Contains(enc, "/") {
|
||||
mime = "video/" + enc
|
||||
}
|
||||
ti.Url = "data:" + mime + ";base64," + op.Response.BytesBase64Encoded
|
||||
return ti, nil
|
||||
}
|
||||
if op.Response.Video != "" { // some variants use `video` as base64
|
||||
enc := strings.TrimSpace(op.Response.Encoding)
|
||||
if enc == "" {
|
||||
enc = "mp4"
|
||||
}
|
||||
mime := enc
|
||||
if !strings.Contains(enc, "/") {
|
||||
mime = "video/" + enc
|
||||
}
|
||||
ti.Url = "data:" + mime + ";base64," + op.Response.Video
|
||||
return ti, nil
|
||||
}
|
||||
return ti, nil
|
||||
}
|
||||
|
||||
// ============================
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func encodeLocalTaskID(name string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(name))
|
||||
}
|
||||
|
||||
func decodeLocalTaskID(local string) (string, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(local)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
|
||||
|
||||
func extractRegionFromOperationName(name string) string {
|
||||
m := regionRe.FindStringSubmatch(name)
|
||||
if len(m) == 2 {
|
||||
return m[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
|
||||
|
||||
func extractModelFromOperationName(name string) string {
|
||||
m := modelRe.FindStringSubmatch(name)
|
||||
if len(m) == 2 {
|
||||
return m[1]
|
||||
}
|
||||
idx := strings.Index(name, "models/")
|
||||
if idx >= 0 {
|
||||
s := name[idx+len("models/"):]
|
||||
if p := strings.Index(s, "/operations/"); p > 0 {
|
||||
return s[:p]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var projectRe = regexp.MustCompile(`projects/([^/]+)/locations/`)
|
||||
|
||||
func extractProjectFromOperationName(name string) string {
|
||||
m := projectRe.FindStringSubmatch(name)
|
||||
if len(m) == 2 {
|
||||
return m[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -23,16 +23,6 @@ import (
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type SubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type requestPayload struct {
|
||||
Model string `json:"model"`
|
||||
Images []string `json:"images"`
|
||||
@@ -90,23 +80,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
||||
var req SubmitReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if req.Image != "" {
|
||||
info.Action = constant.TaskActionGenerate
|
||||
} else {
|
||||
info.Action = constant.TaskActionTextGenerate
|
||||
}
|
||||
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
// Use the unified validation method for TaskSubmitReq with image-based action determination
|
||||
return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
@@ -114,7 +89,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(SubmitReq)
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
@@ -211,7 +186,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
var images []string
|
||||
if req.Image != "" {
|
||||
images = []string{req.Image}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
@@ -80,16 +81,64 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
adc := &Credentials{}
|
||||
if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials file: %w", err)
|
||||
}
|
||||
func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix string) (string, error) {
|
||||
region := GetModelRegion(info.ApiVersion, info.OriginModelName)
|
||||
a.AccountCredentials = *adc
|
||||
if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
adc := &Credentials{}
|
||||
if err := common.Unmarshal([]byte(info.ApiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials file: %w", err)
|
||||
}
|
||||
a.AccountCredentials = *adc
|
||||
|
||||
if a.RequestMode == RequestModeLlama {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
), nil
|
||||
}
|
||||
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
} else {
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
|
||||
modelName,
|
||||
suffix,
|
||||
info.ApiKey,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
|
||||
region,
|
||||
modelName,
|
||||
suffix,
|
||||
info.ApiKey,
|
||||
), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
suffix := ""
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||
@@ -111,24 +160,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
suffix = "predict"
|
||||
}
|
||||
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
return a.getRequestUrl(info, info.UpstreamModelName, suffix)
|
||||
} else if a.RequestMode == RequestModeClaude {
|
||||
if info.IsStream {
|
||||
suffix = "streamRawPredict?alt=sse"
|
||||
@@ -139,41 +171,25 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
model = v
|
||||
}
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
return a.getRequestUrl(info, model, suffix)
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
), nil
|
||||
return a.getRequestUrl(info, "", "")
|
||||
}
|
||||
return "", errors.New("unsupported request mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
accessToken, err := getAccessToken(a, info)
|
||||
if err != nil {
|
||||
return err
|
||||
if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
accessToken, err := getAccessToken(a, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+accessToken)
|
||||
}
|
||||
if a.AccountCredentials.ProjectID != "" {
|
||||
req.Set("x-goog-user-project", a.AccountCredentials.ProjectID)
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+accessToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,10 @@ func GetModelRegion(other string, localModelName string) string {
|
||||
if m[localModelName] != nil {
|
||||
return m[localModelName].(string)
|
||||
} else {
|
||||
return m["default"].(string)
|
||||
if v, ok := m["default"]; ok {
|
||||
return v.(string)
|
||||
}
|
||||
return "global"
|
||||
}
|
||||
}
|
||||
return other
|
||||
|
||||
@@ -6,14 +6,15 @@ import (
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"github.com/bytedance/gopkg/cache/asynccache"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/gopkg/cache/asynccache"
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
@@ -137,3 +138,45 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s
|
||||
|
||||
return "", fmt.Errorf("failed to get access token: %v", result)
|
||||
}
|
||||
|
||||
func AcquireAccessToken(creds Credentials, proxy string) (string, error) {
|
||||
signedJWT, err := createSignedJWT(creds.ClientEmail, creds.PrivateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create signed JWT: %w", err)
|
||||
}
|
||||
return exchangeJwtForAccessTokenWithProxy(signedJWT, proxy)
|
||||
}
|
||||
|
||||
func exchangeJwtForAccessTokenWithProxy(signedJWT string, proxy string) (string, error) {
|
||||
authURL := "https://www.googleapis.com/oauth2/v4/token"
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
|
||||
data.Set("assertion", signedJWT)
|
||||
|
||||
var client *http.Client
|
||||
var err error
|
||||
if proxy != "" {
|
||||
client, err = service.NewProxyHttpClient(proxy)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
client = service.GetHttpClient()
|
||||
}
|
||||
|
||||
resp, err := client.PostForm(authURL, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if accessToken, ok := result["access_token"].(string); ok {
|
||||
return accessToken, nil
|
||||
}
|
||||
return "", fmt.Errorf("failed to get access token: %v", result)
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
httpResp = resp.(*http.Response)
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -481,11 +481,20 @@ type TaskSubmitReq struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func (t TaskSubmitReq) GetPrompt() string {
|
||||
return t.Prompt
|
||||
}
|
||||
|
||||
func (t TaskSubmitReq) HasImage() bool {
|
||||
return len(t.Images) > 0
|
||||
}
|
||||
|
||||
type TaskInfo struct {
|
||||
Code int `json:"code"`
|
||||
TaskID string `json:"task_id"`
|
||||
|
||||
@@ -2,12 +2,23 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type HasPrompt interface {
|
||||
GetPrompt() string
|
||||
}
|
||||
|
||||
type HasImage interface {
|
||||
HasImage() bool
|
||||
}
|
||||
|
||||
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
|
||||
@@ -30,3 +41,72 @@ func GetAPIVersion(c *gin.Context) string {
|
||||
}
|
||||
return apiVersion
|
||||
}
|
||||
|
||||
func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
|
||||
return &dto.TaskError{
|
||||
Code: code,
|
||||
Message: err.Error(),
|
||||
StatusCode: statusCode,
|
||||
LocalError: localError,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
|
||||
info.Action = action
|
||||
c.Set("task_request", requestObj)
|
||||
}
|
||||
|
||||
func validatePrompt(prompt string) *dto.TaskError {
|
||||
if strings.TrimSpace(prompt) == "" {
|
||||
return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
|
||||
var req TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
|
||||
}
|
||||
|
||||
if taskErr := validatePrompt(req.Prompt); taskErr != nil {
|
||||
return taskErr
|
||||
}
|
||||
|
||||
if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
|
||||
// 兼容单图上传
|
||||
req.Images = []string{req.Image}
|
||||
}
|
||||
|
||||
storeTaskRequest(c, info, action, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
|
||||
hasPrompt, ok := requestObj.(HasPrompt)
|
||||
if !ok {
|
||||
return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
|
||||
}
|
||||
|
||||
if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
|
||||
return taskErr
|
||||
}
|
||||
|
||||
action := constant.TaskActionTextGenerate
|
||||
if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() {
|
||||
action = constant.TaskActionGenerate
|
||||
}
|
||||
|
||||
storeTaskRequest(c, info, action, requestObj)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
|
||||
var req TaskSubmitReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
|
||||
}
|
||||
|
||||
return ValidateTaskRequestWithImage(c, info, req)
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
httpResp = resp.(*http.Response)
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newApiErr := service.RelayErrorHandler(httpResp, false)
|
||||
newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||||
return newApiErr
|
||||
@@ -195,6 +195,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
imageTokens := usage.PromptTokensDetails.ImageTokens
|
||||
audioTokens := usage.PromptTokensDetails.AudioTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
|
||||
modelName := relayInfo.OriginModelName
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
@@ -204,6 +206,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
modelRatio := relayInfo.PriceData.ModelRatio
|
||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := relayInfo.PriceData.ModelPrice
|
||||
cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio
|
||||
|
||||
// Convert values to decimal for precise calculation
|
||||
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
||||
@@ -211,12 +214,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
dImageTokens := decimal.NewFromInt(int64(imageTokens))
|
||||
dAudioTokens := decimal.NewFromInt(int64(audioTokens))
|
||||
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
|
||||
dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens))
|
||||
dCompletionRatio := decimal.NewFromFloat(completionRatio)
|
||||
dCacheRatio := decimal.NewFromFloat(cacheRatio)
|
||||
dImageRatio := decimal.NewFromFloat(imageRatio)
|
||||
dModelRatio := decimal.NewFromFloat(modelRatio)
|
||||
dGroupRatio := decimal.NewFromFloat(groupRatio)
|
||||
dModelPrice := decimal.NewFromFloat(modelPrice)
|
||||
dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio)
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
|
||||
ratio := dModelRatio.Mul(dGroupRatio)
|
||||
@@ -284,6 +289,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
baseTokens = baseTokens.Sub(dCacheTokens)
|
||||
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
|
||||
}
|
||||
var dCachedCreationTokensWithRatio decimal.Decimal
|
||||
if !dCachedCreationTokens.IsZero() {
|
||||
baseTokens = baseTokens.Sub(dCachedCreationTokens)
|
||||
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
|
||||
}
|
||||
|
||||
// 减去 image tokens
|
||||
var imageTokensWithRatio decimal.Decimal
|
||||
@@ -302,7 +312,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
|
||||
}
|
||||
}
|
||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio)
|
||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).
|
||||
Add(imageTokensWithRatio).
|
||||
Add(dCachedCreationTokensWithRatio)
|
||||
|
||||
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
||||
|
||||
@@ -314,22 +326,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
} else {
|
||||
quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
|
||||
}
|
||||
var dGeminiImageOutputQuota decimal.Decimal
|
||||
var imageOutputPrice float64
|
||||
if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") {
|
||||
imageOutputPrice = operation_setting.GetGeminiImageOutputPricePerMillionTokens(modelName)
|
||||
if imageOutputPrice > 0 {
|
||||
dImageOutputTokens := decimal.NewFromInt(int64(ctx.GetInt("gemini_image_tokens")))
|
||||
dGeminiImageOutputQuota = decimal.NewFromFloat(imageOutputPrice).Div(decimal.NewFromInt(1000000)).Mul(dImageOutputTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
}
|
||||
// 添加 responses tools call 调用的配额
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
// 添加 audio input 独立计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
// 添加 Gemini image output 计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dGeminiImageOutputQuota)
|
||||
|
||||
quota := int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
totalTokens := promptTokens + completionTokens
|
||||
@@ -395,6 +396,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
other["image_ratio"] = imageRatio
|
||||
other["image_output"] = imageTokens
|
||||
}
|
||||
if cachedCreationTokens != 0 {
|
||||
other["cache_creation_tokens"] = cachedCreationTokens
|
||||
other["cache_creation_ratio"] = cachedCreationRatio
|
||||
}
|
||||
if !dWebSearchQuota.IsZero() {
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
|
||||
@@ -424,10 +429,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
other["audio_input_token_count"] = audioTokens
|
||||
other["audio_input_price"] = audioInputPrice
|
||||
}
|
||||
if !dGeminiImageOutputQuota.IsZero() {
|
||||
other["image_output_token_count"] = ctx.GetInt("gemini_image_tokens")
|
||||
other["image_output_price"] = imageOutputPrice
|
||||
}
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: promptTokens,
|
||||
|
||||
@@ -58,7 +58,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -152,7 +152,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
httpResp = resp.(*http.Response)
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
@@ -249,7 +249,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
httpResp = resp.(*http.Response)
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
@@ -120,7 +120,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
var logContent string
|
||||
|
||||
if len(request.Size) > 0 {
|
||||
logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality)
|
||||
logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N)
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -131,7 +132,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
||||
midjourneyTask.FinishTime = originTask.FinishTime
|
||||
midjourneyTask.ImageUrl = ""
|
||||
if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled {
|
||||
midjourneyTask.ImageUrl = setting.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
midjourneyTask.ImageUrl = system_setting.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
if originTask.Status != "SUCCESS" {
|
||||
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/constant"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ali"
|
||||
@@ -28,6 +27,7 @@ import (
|
||||
taskjimeng "one-api/relay/channel/task/jimeng"
|
||||
"one-api/relay/channel/task/kling"
|
||||
"one-api/relay/channel/task/suno"
|
||||
taskvertex "one-api/relay/channel/task/vertex"
|
||||
taskVidu "one-api/relay/channel/task/vidu"
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/vertex"
|
||||
@@ -37,6 +37,8 @@ import (
|
||||
"one-api/relay/channel/zhipu"
|
||||
"one-api/relay/channel/zhipu_4v"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAdaptor(apiType int) channel.Adaptor {
|
||||
@@ -126,6 +128,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
|
||||
return &kling.TaskAdaptor{}
|
||||
case constant.ChannelTypeJimeng:
|
||||
return &taskjimeng.TaskAdaptor{}
|
||||
case constant.ChannelTypeVertexAi:
|
||||
return &taskvertex.TaskAdaptor{}
|
||||
case constant.ChannelTypeVidu:
|
||||
return &taskVidu.TaskAdaptor{}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -33,6 +35,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
platform = GetTaskPlatform(c)
|
||||
}
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
adaptor := GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
||||
@@ -197,6 +200,9 @@ func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
|
||||
if taskErr != nil {
|
||||
return taskErr
|
||||
}
|
||||
if len(respBody) == 0 {
|
||||
respBody = []byte("{\"code\":\"success\",\"data\":null}")
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
@@ -276,10 +282,92 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
return
|
||||
}
|
||||
|
||||
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: TaskModel2Dto(originTask),
|
||||
})
|
||||
func() {
|
||||
channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
if channelModel.Type != constant.ChannelTypeVertexAi {
|
||||
return
|
||||
}
|
||||
baseURL := constant.ChannelBaseURLs[channelModel.Type]
|
||||
if channelModel.GetBaseURL() != "" {
|
||||
baseURL = channelModel.GetBaseURL()
|
||||
}
|
||||
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
|
||||
if adaptor == nil {
|
||||
return
|
||||
}
|
||||
resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
|
||||
"task_id": originTask.TaskID,
|
||||
"action": originTask.Action,
|
||||
})
|
||||
if err2 != nil || resp == nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err2 := io.ReadAll(resp.Body)
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
ti, err2 := adaptor.ParseTaskResult(body)
|
||||
if err2 == nil && ti != nil {
|
||||
if ti.Status != "" {
|
||||
originTask.Status = model.TaskStatus(ti.Status)
|
||||
}
|
||||
if ti.Progress != "" {
|
||||
originTask.Progress = ti.Progress
|
||||
}
|
||||
if ti.Url != "" {
|
||||
originTask.FailReason = ti.Url
|
||||
}
|
||||
_ = originTask.Update()
|
||||
var raw map[string]any
|
||||
_ = json.Unmarshal(body, &raw)
|
||||
format := "mp4"
|
||||
if respObj, ok := raw["response"].(map[string]any); ok {
|
||||
if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
|
||||
if v0, ok := vids[0].(map[string]any); ok {
|
||||
if mt, ok := v0["mimeType"].(string); ok && mt != "" {
|
||||
if strings.Contains(mt, "mp4") {
|
||||
format = "mp4"
|
||||
} else {
|
||||
format = mt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
status := "processing"
|
||||
switch originTask.Status {
|
||||
case model.TaskStatusSuccess:
|
||||
status = "succeeded"
|
||||
case model.TaskStatusFailure:
|
||||
status = "failed"
|
||||
case model.TaskStatusQueued, model.TaskStatusSubmitted:
|
||||
status = "queued"
|
||||
}
|
||||
out := map[string]any{
|
||||
"error": nil,
|
||||
"format": format,
|
||||
"metadata": nil,
|
||||
"status": status,
|
||||
"task_id": originTask.TaskID,
|
||||
"url": originTask.FailReason,
|
||||
}
|
||||
respBody, _ = json.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: out,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
if len(respBody) == 0 {
|
||||
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: TaskModel2Dto(originTask),
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -41,7 +41,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
}
|
||||
adaptor.Init(info)
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
|
||||
@@ -82,7 +82,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
httpResp = resp.(*http.Response)
|
||||
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -31,6 +31,21 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth)
|
||||
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
|
||||
// OAuth2 Server endpoints
|
||||
apiRouter.GET("/.well-known/jwks.json", controller.GetJWKS)
|
||||
apiRouter.GET("/.well-known/openid-configuration", controller.OAuthOIDCConfiguration)
|
||||
apiRouter.GET("/.well-known/oauth-authorization-server", controller.OAuthServerInfo)
|
||||
apiRouter.POST("/oauth/token", middleware.CriticalRateLimit(), controller.OAuthTokenEndpoint)
|
||||
apiRouter.GET("/oauth/authorize", controller.OAuthAuthorizeEndpoint)
|
||||
apiRouter.POST("/oauth/introspect", middleware.AdminAuth(), controller.OAuthIntrospect)
|
||||
apiRouter.POST("/oauth/revoke", middleware.CriticalRateLimit(), controller.OAuthRevoke)
|
||||
apiRouter.GET("/oauth/userinfo", middleware.OAuthJWTAuth(), controller.OAuthUserInfo)
|
||||
|
||||
// OAuth2 管理API (前端使用)
|
||||
apiRouter.GET("/oauth/jwks", controller.GetJWKS)
|
||||
apiRouter.GET("/oauth/server-info", controller.OAuthServerInfo)
|
||||
|
||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
||||
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
|
||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||
@@ -40,6 +55,17 @@ func SetApiRouter(router *gin.Engine) {
|
||||
|
||||
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
|
||||
|
||||
// OAuth2 admin operations
|
||||
oauthAdmin := apiRouter.Group("/oauth")
|
||||
oauthAdmin.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.RootAuth())
|
||||
{
|
||||
oauthAdmin.POST("/keys/rotate", controller.RotateOAuthSigningKey)
|
||||
oauthAdmin.GET("/keys", controller.ListOAuthSigningKeys)
|
||||
oauthAdmin.DELETE("/keys/:kid", controller.DeleteOAuthSigningKey)
|
||||
oauthAdmin.POST("/keys/generate_file", controller.GenerateOAuthSigningKeyFile)
|
||||
oauthAdmin.POST("/keys/import_pem", controller.ImportOAuthSigningKey)
|
||||
}
|
||||
|
||||
userRoute := apiRouter.Group("/user")
|
||||
{
|
||||
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
|
||||
@@ -60,6 +86,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute.DELETE("/self", controller.DeleteSelf)
|
||||
selfRoute.GET("/token", controller.GenerateAccessToken)
|
||||
selfRoute.GET("/aff", controller.GetAffCode)
|
||||
selfRoute.GET("/topup/info", controller.GetTopUpInfo)
|
||||
selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
|
||||
selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestEpay)
|
||||
selfRoute.POST("/amount", controller.RequestAmount)
|
||||
@@ -77,7 +104,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
}
|
||||
|
||||
adminRoute := userRoute.Group("/")
|
||||
adminRoute.Use(middleware.AdminAuth())
|
||||
adminRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
|
||||
{
|
||||
adminRoute.GET("/", controller.GetAllUsers)
|
||||
adminRoute.GET("/search", controller.SearchUsers)
|
||||
@@ -93,7 +120,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
}
|
||||
}
|
||||
optionRoute := apiRouter.Group("/option")
|
||||
optionRoute.Use(middleware.RootAuth())
|
||||
optionRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.RootAuth())
|
||||
{
|
||||
optionRoute.GET("/", controller.GetOptions)
|
||||
optionRoute.PUT("/", controller.UpdateOption)
|
||||
@@ -107,7 +134,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios)
|
||||
}
|
||||
channelRoute := apiRouter.Group("/channel")
|
||||
channelRoute.Use(middleware.AdminAuth())
|
||||
channelRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
|
||||
{
|
||||
channelRoute.GET("/", controller.GetAllChannels)
|
||||
channelRoute.GET("/search", controller.SearchChannels)
|
||||
@@ -158,7 +185,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
}
|
||||
|
||||
redemptionRoute := apiRouter.Group("/redemption")
|
||||
redemptionRoute.Use(middleware.AdminAuth())
|
||||
redemptionRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
|
||||
{
|
||||
redemptionRoute.GET("/", controller.GetAllRedemptions)
|
||||
redemptionRoute.GET("/search", controller.SearchRedemptions)
|
||||
@@ -186,13 +213,13 @@ func SetApiRouter(router *gin.Engine) {
|
||||
logRoute.GET("/token", controller.GetLogByKey)
|
||||
}
|
||||
groupRoute := apiRouter.Group("/group")
|
||||
groupRoute.Use(middleware.AdminAuth())
|
||||
groupRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
|
||||
{
|
||||
groupRoute.GET("/", controller.GetGroups)
|
||||
}
|
||||
|
||||
prefillGroupRoute := apiRouter.Group("/prefill_group")
|
||||
prefillGroupRoute.Use(middleware.AdminAuth())
|
||||
prefillGroupRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
|
||||
{
|
||||
prefillGroupRoute.GET("/", controller.GetPrefillGroups)
|
||||
prefillGroupRoute.POST("/", controller.CreatePrefillGroup)
|
||||
@@ -234,5 +261,17 @@ func SetApiRouter(router *gin.Engine) {
|
||||
modelsRoute.PUT("/", controller.UpdateModelMeta)
|
||||
modelsRoute.DELETE("/:id", controller.DeleteModelMeta)
|
||||
}
|
||||
|
||||
oauthClientsRoute := apiRouter.Group("/oauth_clients")
|
||||
oauthClientsRoute.Use(middleware.AdminAuth())
|
||||
{
|
||||
oauthClientsRoute.GET("/", controller.GetAllOAuthClients)
|
||||
oauthClientsRoute.GET("/search", controller.SearchOAuthClients)
|
||||
oauthClientsRoute.GET("/:id", controller.GetOAuthClient)
|
||||
oauthClientsRoute.POST("/", controller.CreateOAuthClient)
|
||||
oauthClientsRoute.PUT("/", controller.UpdateOAuthClient)
|
||||
oauthClientsRoute.DELETE("/:id", controller.DeleteOAuthClient)
|
||||
oauthClientsRoute.POST("/:id/regenerate_secret", controller.RegenerateOAuthClientSecret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -21,14 +21,14 @@ type WorkerRequest struct {
|
||||
|
||||
// DoWorkerRequest 通过Worker发送请求
|
||||
func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
|
||||
if !setting.EnableWorker() {
|
||||
if !system_setting.EnableWorker() {
|
||||
return nil, fmt.Errorf("worker not enabled")
|
||||
}
|
||||
if !setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") {
|
||||
if !system_setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") {
|
||||
return nil, fmt.Errorf("only support https url")
|
||||
}
|
||||
|
||||
workerUrl := setting.WorkerUrl
|
||||
workerUrl := system_setting.WorkerUrl
|
||||
if !strings.HasSuffix(workerUrl, "/") {
|
||||
workerUrl += "/"
|
||||
}
|
||||
@@ -43,11 +43,11 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
|
||||
}
|
||||
|
||||
func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
|
||||
if setting.EnableWorker() {
|
||||
if system_setting.EnableWorker() {
|
||||
common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
|
||||
req := &WorkerRequest{
|
||||
URL: originUrl,
|
||||
Key: setting.WorkerValidKey,
|
||||
Key: system_setting.WorkerValidKey,
|
||||
}
|
||||
return DoWorkerRequest(req)
|
||||
} else {
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
)
|
||||
|
||||
func GetCallbackAddress() string {
|
||||
if setting.CustomCallbackAddress == "" {
|
||||
return setting.ServerAddress
|
||||
if operation_setting.CustomCallbackAddress == "" {
|
||||
return system_setting.ServerAddress
|
||||
}
|
||||
return setting.CustomCallbackAddress
|
||||
return operation_setting.CustomCallbackAddress
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -78,7 +80,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
|
||||
return claudeErr
|
||||
}
|
||||
|
||||
func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
|
||||
func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
|
||||
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
@@ -94,7 +96,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
|
||||
newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
|
||||
} else {
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
|
||||
logger.LogInfo(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
|
||||
}
|
||||
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -13,13 +13,13 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) {
|
||||
if preConsumedQuota != 0 {
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota)))
|
||||
func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||
if relayInfo.FinalPreConsumedQuota != 0 {
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(relayInfo.FinalPreConsumedQuota)))
|
||||
gopool.Go(func() {
|
||||
relayInfoCopy := *relayInfo
|
||||
|
||||
err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
|
||||
err := PostConsumeQuota(&relayInfoCopy, -relayInfo.FinalPreConsumedQuota, 0, false)
|
||||
if err != nil {
|
||||
common.SysLog("error return pre-consumed quota: " + err.Error())
|
||||
}
|
||||
@@ -29,16 +29,16 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr
|
||||
|
||||
// PreConsumeQuota checks if the user has enough quota to pre-consume.
|
||||
// It returns the pre-consumed quota if successful, or an error if not.
|
||||
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) {
|
||||
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if userQuota <= 0 {
|
||||
return 0, types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
|
||||
trustQuota := common.GetTrustQuota()
|
||||
@@ -65,14 +65,14 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
if preConsumedQuota > 0 {
|
||||
err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
|
||||
return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
|
||||
}
|
||||
relayInfo.FinalPreConsumedQuota = preConsumedQuota
|
||||
return preConsumedQuota, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"one-api/logger"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -534,7 +534,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
|
||||
}
|
||||
if quotaTooLow {
|
||||
prompt := "您的额度即将用尽"
|
||||
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
|
||||
topUpLink := fmt.Sprintf("%s/topup", system_setting.ServerAddress)
|
||||
|
||||
// 根据通知方式生成不同的内容格式
|
||||
var content string
|
||||
|
||||
@@ -336,7 +336,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
for i, file := range meta.Files {
|
||||
switch file.FileType {
|
||||
case types.FileTypeImage:
|
||||
if info.RelayFormat == types.RelayFormatGemini && !strings.HasPrefix(model, "gemini-2.5-flash-image-preview") {
|
||||
if info.RelayFormat == types.RelayFormatGemini {
|
||||
tkm += 256
|
||||
} else {
|
||||
token, err := getImageToken(file, model, info.IsStream)
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -91,11 +91,11 @@ func sendBarkNotify(barkURL string, data dto.Notify) error {
|
||||
var resp *http.Response
|
||||
var err error
|
||||
|
||||
if setting.EnableWorker() {
|
||||
if system_setting.EnableWorker() {
|
||||
// 使用worker发送请求
|
||||
workerReq := &WorkerRequest{
|
||||
URL: finalURL,
|
||||
Key: setting.WorkerValidKey,
|
||||
Key: system_setting.WorkerValidKey,
|
||||
Method: http.MethodGet,
|
||||
Headers: map[string]string{
|
||||
"User-Agent": "OneAPI-Bark-Notify/1.0",
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -56,11 +56,11 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
|
||||
var req *http.Request
|
||||
var resp *http.Response
|
||||
|
||||
if setting.EnableWorker() {
|
||||
if system_setting.EnableWorker() {
|
||||
// 构建worker请求数据
|
||||
workerReq := &WorkerRequest{
|
||||
URL: webhookURL,
|
||||
Key: setting.WorkerValidKey,
|
||||
Key: system_setting.WorkerValidKey,
|
||||
Method: http.MethodPost,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
|
||||
@@ -26,7 +26,6 @@ var defaultGeminiSettings = GeminiSettings{
|
||||
SupportedImagineModels: []string{
|
||||
"gemini-2.0-flash-exp-image-generation",
|
||||
"gemini-2.0-flash-exp",
|
||||
"gemini-2.5-flash-image-preview",
|
||||
},
|
||||
ThinkingAdapterEnabled: false,
|
||||
ThinkingAdapterBudgetTokensPercentage: 0.6,
|
||||
|
||||
23
setting/operation_setting/payment_setting.go
Normal file
23
setting/operation_setting/payment_setting.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package operation_setting
|
||||
|
||||
import "one-api/setting/config"
|
||||
|
||||
type PaymentSetting struct {
|
||||
AmountOptions []int `json:"amount_options"`
|
||||
AmountDiscount map[int]float64 `json:"amount_discount"` // 充值金额对应的折扣,例如 100 元 0.9 表示 100 元充值享受 9 折优惠
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var paymentSetting = PaymentSetting{
|
||||
AmountOptions: []int{10, 20, 50, 100, 200, 500},
|
||||
AmountDiscount: map[int]float64{},
|
||||
}
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器
|
||||
config.GlobalConfig.Register("payment_setting", &paymentSetting)
|
||||
}
|
||||
|
||||
func GetPaymentSetting() *PaymentSetting {
|
||||
return &paymentSetting
|
||||
}
|
||||
@@ -1,6 +1,13 @@
|
||||
package setting
|
||||
/**
|
||||
此文件为旧版支付设置文件,如需增加新的参数、变量等,请在 payment_setting.go 中添加
|
||||
This file is the old version of the payment settings file. If you need to add new parameters, variables, etc., please add them in payment_setting.go
|
||||
*/
|
||||
|
||||
import "encoding/json"
|
||||
package operation_setting
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
var PayAddress = ""
|
||||
var CustomCallbackAddress = ""
|
||||
@@ -21,15 +28,21 @@ var PayMethods = []map[string]string{
|
||||
"color": "rgba(var(--semi-green-5), 1)",
|
||||
"type": "wxpay",
|
||||
},
|
||||
{
|
||||
"name": "自定义1",
|
||||
"color": "black",
|
||||
"type": "custom1",
|
||||
"min_topup": "50",
|
||||
},
|
||||
}
|
||||
|
||||
func UpdatePayMethodsByJsonString(jsonString string) error {
|
||||
PayMethods = make([]map[string]string, 0)
|
||||
return json.Unmarshal([]byte(jsonString), &PayMethods)
|
||||
return common.Unmarshal([]byte(jsonString), &PayMethods)
|
||||
}
|
||||
|
||||
func PayMethods2JsonString() string {
|
||||
jsonBytes, err := json.Marshal(PayMethods)
|
||||
jsonBytes, err := common.Marshal(PayMethods)
|
||||
if err != nil {
|
||||
return "[]"
|
||||
}
|
||||
@@ -24,10 +24,6 @@ const (
|
||||
ClaudeWebSearchPrice = 10.00
|
||||
)
|
||||
|
||||
const (
|
||||
Gemini25FlashImagePreviewImageOutputPrice = 30.00
|
||||
)
|
||||
|
||||
func GetClaudeWebSearchPricePerThousand() float64 {
|
||||
return ClaudeWebSearchPrice
|
||||
}
|
||||
@@ -69,10 +65,3 @@ func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 {
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func GetGeminiImageOutputPricePerMillionTokens(modelName string) float64 {
|
||||
if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") {
|
||||
return Gemini25FlashImagePreviewImageOutputPrice
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -178,7 +178,6 @@ var defaultModelRatio = map[string]float64{
|
||||
"gemini-2.5-flash-lite-preview-thinking-*": 0.05,
|
||||
"gemini-2.5-flash-lite-preview-06-17": 0.05,
|
||||
"gemini-2.5-flash": 0.15,
|
||||
"gemini-2.5-flash-image-preview": 0.15, // $0.30(text/image) / 1M tokens
|
||||
"text-embedding-004": 0.001,
|
||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
@@ -294,11 +293,10 @@ var (
|
||||
)
|
||||
|
||||
var defaultCompletionRatio = map[string]float64{
|
||||
"gpt-4-gizmo-*": 2,
|
||||
"gpt-4o-gizmo-*": 3,
|
||||
"gpt-4-all": 2,
|
||||
"gpt-image-1": 8,
|
||||
"gemini-2.5-flash-image-preview": 8.3333333333,
|
||||
"gpt-4-gizmo-*": 2,
|
||||
"gpt-4o-gizmo-*": 3,
|
||||
"gpt-4-all": 2,
|
||||
"gpt-image-1": 8,
|
||||
}
|
||||
|
||||
// InitRatioSettings initializes all model related settings maps
|
||||
|
||||
74
setting/system_setting/oauth2.go
Normal file
74
setting/system_setting/oauth2.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package system_setting
|
||||
|
||||
import "one-api/setting/config"
|
||||
|
||||
type OAuth2Settings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Issuer string `json:"issuer"`
|
||||
AccessTokenTTL int `json:"access_token_ttl"` // in minutes
|
||||
RefreshTokenTTL int `json:"refresh_token_ttl"` // in minutes
|
||||
AllowedGrantTypes []string `json:"allowed_grant_types"` // client_credentials, authorization_code, refresh_token
|
||||
RequirePKCE bool `json:"require_pkce"` // force PKCE for authorization code flow
|
||||
JWTSigningAlgorithm string `json:"jwt_signing_algorithm"`
|
||||
JWTKeyID string `json:"jwt_key_id"`
|
||||
JWTPrivateKeyFile string `json:"jwt_private_key_file"`
|
||||
AutoCreateUser bool `json:"auto_create_user"` // auto create user on first OAuth2 login
|
||||
DefaultUserRole int `json:"default_user_role"` // default role for auto-created users
|
||||
DefaultUserGroup string `json:"default_user_group"` // default group for auto-created users
|
||||
ScopeMappings map[string][]string `json:"scope_mappings"` // scope to permissions mapping
|
||||
MaxJWKSKeys int `json:"max_jwks_keys"` // maximum number of JWKS signing keys to retain
|
||||
DefaultPrivateKeyPath string `json:"default_private_key_path"` // suggested private key file path
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var defaultOAuth2Settings = OAuth2Settings{
|
||||
Enabled: false,
|
||||
AccessTokenTTL: 10, // 10 minutes
|
||||
RefreshTokenTTL: 720, // 12 hours
|
||||
AllowedGrantTypes: []string{"client_credentials", "authorization_code", "refresh_token"},
|
||||
RequirePKCE: true,
|
||||
JWTSigningAlgorithm: "RS256",
|
||||
JWTKeyID: "oauth2-key-1",
|
||||
AutoCreateUser: false,
|
||||
DefaultUserRole: 1, // common user
|
||||
DefaultUserGroup: "default",
|
||||
ScopeMappings: map[string][]string{
|
||||
"api:read": {"read"},
|
||||
"api:write": {"write"},
|
||||
"admin": {"admin"},
|
||||
},
|
||||
MaxJWKSKeys: 3,
|
||||
DefaultPrivateKeyPath: "/etc/new-api/oauth2-private.pem",
|
||||
}
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器
|
||||
config.GlobalConfig.Register("oauth2", &defaultOAuth2Settings)
|
||||
}
|
||||
|
||||
func GetOAuth2Settings() *OAuth2Settings {
|
||||
return &defaultOAuth2Settings
|
||||
}
|
||||
|
||||
// UpdateOAuth2Settings 更新OAuth2配置
|
||||
func UpdateOAuth2Settings(settings OAuth2Settings) {
|
||||
defaultOAuth2Settings = settings
|
||||
}
|
||||
|
||||
// ValidateGrantType 验证授权类型是否被允许
|
||||
func (s *OAuth2Settings) ValidateGrantType(grantType string) bool {
|
||||
for _, allowedType := range s.AllowedGrantTypes {
|
||||
if allowedType == grantType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetScopePermissions 获取scope对应的权限
|
||||
func (s *OAuth2Settings) GetScopePermissions(scope string) []string {
|
||||
if perms, exists := s.ScopeMappings[scope]; exists {
|
||||
return perms
|
||||
}
|
||||
return []string{}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package setting
|
||||
package system_setting
|
||||
|
||||
var ServerAddress = "http://localhost:3000"
|
||||
var WorkerUrl = ""
|
||||
1069
src/oauth/server.go
Normal file
1069
src/oauth/server.go
Normal file
File diff suppressed because it is too large
Load Diff
82
src/oauth/store.go
Normal file
82
src/oauth/store.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// KVStore is a minimal TTL key-value abstraction used by OAuth flows.
|
||||
type KVStore interface {
|
||||
Set(key, value string, ttl time.Duration) error
|
||||
Get(key string) (string, bool)
|
||||
Del(key string) error
|
||||
}
|
||||
|
||||
type redisStore struct{}
|
||||
|
||||
func (r *redisStore) Set(key, value string, ttl time.Duration) error {
|
||||
return common.RedisSet(key, value, ttl)
|
||||
}
|
||||
func (r *redisStore) Get(key string) (string, bool) {
|
||||
v, err := common.RedisGet(key)
|
||||
if err != nil || v == "" {
|
||||
return "", false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
func (r *redisStore) Del(key string) error {
|
||||
return common.RedisDel(key)
|
||||
}
|
||||
|
||||
type memEntry struct {
|
||||
val string
|
||||
exp int64 // unix seconds, 0 means no expiry
|
||||
}
|
||||
|
||||
type memoryStore struct {
|
||||
m sync.Map // key -> memEntry
|
||||
}
|
||||
|
||||
func (m *memoryStore) Set(key, value string, ttl time.Duration) error {
|
||||
var exp int64
|
||||
if ttl > 0 {
|
||||
exp = time.Now().Add(ttl).Unix()
|
||||
}
|
||||
m.m.Store(key, memEntry{val: value, exp: exp})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memoryStore) Get(key string) (string, bool) {
|
||||
v, ok := m.m.Load(key)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
e := v.(memEntry)
|
||||
if e.exp > 0 && time.Now().Unix() > e.exp {
|
||||
m.m.Delete(key)
|
||||
return "", false
|
||||
}
|
||||
return e.val, true
|
||||
}
|
||||
|
||||
func (m *memoryStore) Del(key string) error {
|
||||
m.m.Delete(key)
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
memStore = &memoryStore{}
|
||||
rdsStore = &redisStore{}
|
||||
)
|
||||
|
||||
func getStore() KVStore {
|
||||
if common.RedisEnabled {
|
||||
return rdsStore
|
||||
}
|
||||
return memStore
|
||||
}
|
||||
|
||||
func storeSet(key, val string, ttl time.Duration) error { return getStore().Set(key, val, ttl) }
|
||||
func storeGet(key string) (string, bool) { return getStore().Get(key) }
|
||||
func storeDel(key string) error { return getStore().Del(key) }
|
||||
59
src/oauth/util.go
Normal file
59
src/oauth/util.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// getFormOrBasicAuth extracts client_id/client_secret from Basic Auth first, then form
|
||||
func getFormOrBasicAuth(c *gin.Context) (clientID, clientSecret string) {
|
||||
id, secret, ok := c.Request.BasicAuth()
|
||||
if ok {
|
||||
return strings.TrimSpace(id), strings.TrimSpace(secret)
|
||||
}
|
||||
return strings.TrimSpace(c.PostForm("client_id")), strings.TrimSpace(c.PostForm("client_secret"))
|
||||
}
|
||||
|
||||
// genCode generates URL-safe random string based on nBytes of entropy
|
||||
func genCode(nBytes int) (string, error) {
|
||||
b := make([]byte, nBytes)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// s256Base64URL computes base64url-encoded SHA256 digest
|
||||
func s256Base64URL(verifier string) string {
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// writeNoStore sets no-store cache headers for OAuth responses
|
||||
func writeNoStore(c *gin.Context) {
|
||||
c.Header("Cache-Control", "no-store")
|
||||
c.Header("Pragma", "no-cache")
|
||||
}
|
||||
|
||||
// writeOAuthRedirectError builds an error redirect to redirect_uri as RFC6749
|
||||
func writeOAuthRedirectError(c *gin.Context, redirectURI, errCode, description, state string) {
|
||||
writeNoStore(c)
|
||||
q := "error=" + url.QueryEscape(errCode)
|
||||
if description != "" {
|
||||
q += "&error_description=" + url.QueryEscape(description)
|
||||
}
|
||||
if state != "" {
|
||||
q += "&state=" + url.QueryEscape(state)
|
||||
}
|
||||
sep := "?"
|
||||
if strings.Contains(redirectURI, "?") {
|
||||
sep = "&"
|
||||
}
|
||||
c.Redirect(http.StatusFound, redirectURI+sep+q)
|
||||
}
|
||||
@@ -185,6 +185,14 @@ func (e *NewAPIError) ToClaudeError() ClaudeError {
|
||||
type NewAPIErrorOptions func(*NewAPIError)
|
||||
|
||||
func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
var newErr *NewAPIError
|
||||
// 保留深层传递的 new err
|
||||
if errors.As(err, &newErr) {
|
||||
for _, op := range ops {
|
||||
op(newErr)
|
||||
}
|
||||
return newErr
|
||||
}
|
||||
e := &NewAPIError{
|
||||
Err: err,
|
||||
RelayError: nil,
|
||||
@@ -199,8 +207,21 @@ func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPI
|
||||
}
|
||||
|
||||
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
if errorCode == ErrorCodeDoRequestFailed {
|
||||
err = errors.New("upstream error: do request failed")
|
||||
var newErr *NewAPIError
|
||||
// 保留深层传递的 new err
|
||||
if errors.As(err, &newErr) {
|
||||
if newErr.RelayError == nil {
|
||||
openaiError := OpenAIError{
|
||||
Message: newErr.Error(),
|
||||
Type: string(errorCode),
|
||||
Code: errorCode,
|
||||
}
|
||||
newErr.RelayError = openaiError
|
||||
}
|
||||
for _, op := range ops {
|
||||
op(newErr)
|
||||
}
|
||||
return newErr
|
||||
}
|
||||
openaiError := OpenAIError{
|
||||
Message: err.Error(),
|
||||
@@ -305,6 +326,15 @@ func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions {
|
||||
}
|
||||
}
|
||||
|
||||
func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions {
|
||||
return func(e *NewAPIError) {
|
||||
if common.DebugEnabled {
|
||||
fmt.Printf("ErrOptionWithHideErrMsg: %s, origin error: %s", replaceStr, e.Err)
|
||||
}
|
||||
e.Err = errors.New(replaceStr)
|
||||
}
|
||||
}
|
||||
|
||||
func IsRecordErrorLog(e *NewAPIError) bool {
|
||||
if e == nil {
|
||||
return false
|
||||
|
||||
662
web/public/oauth-demo.html
Normal file
662
web/public/oauth-demo.html
Normal file
@@ -0,0 +1,662 @@
|
||||
<!-- This file is a copy of examples/oauth-demo.html for direct serving under /oauth-demo.html -->
|
||||
<!doctype html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>OAuth2/OIDC 授权码 + PKCE 前端演示</title>
|
||||
<style>
|
||||
:root {
|
||||
--bg: #0b0c10;
|
||||
--panel: #111317;
|
||||
--muted: #aab2bf;
|
||||
--accent: #3b82f6;
|
||||
--ok: #16a34a;
|
||||
--warn: #f59e0b;
|
||||
--err: #ef4444;
|
||||
--border: #1f2430;
|
||||
}
|
||||
body {
|
||||
margin: 0;
|
||||
font-family:
|
||||
ui-sans-serif,
|
||||
system-ui,
|
||||
-apple-system,
|
||||
Segoe UI,
|
||||
Roboto,
|
||||
Helvetica,
|
||||
Arial;
|
||||
background: var(--bg);
|
||||
color: #e5e7eb;
|
||||
}
|
||||
.wrap {
|
||||
max-width: 980px;
|
||||
margin: 32px auto;
|
||||
padding: 0 16px;
|
||||
}
|
||||
h1 {
|
||||
font-size: 22px;
|
||||
margin: 0 0 16px;
|
||||
}
|
||||
.card {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 10px;
|
||||
padding: 16px;
|
||||
margin: 12px 0;
|
||||
}
|
||||
.row {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.col {
|
||||
flex: 1 1 280px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
label {
|
||||
font-size: 12px;
|
||||
color: var(--muted);
|
||||
margin-bottom: 6px;
|
||||
}
|
||||
input,
|
||||
textarea,
|
||||
select {
|
||||
background: #0f1115;
|
||||
color: #e5e7eb;
|
||||
border: 1px solid var(--border);
|
||||
padding: 10px 12px;
|
||||
border-radius: 8px;
|
||||
outline: none;
|
||||
}
|
||||
textarea {
|
||||
min-height: 100px;
|
||||
resize: vertical;
|
||||
}
|
||||
.btns {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
flex-wrap: wrap;
|
||||
margin-top: 8px;
|
||||
}
|
||||
button {
|
||||
background: #1a1f2b;
|
||||
color: #e5e7eb;
|
||||
border: 1px solid var(--border);
|
||||
padding: 8px 12px;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
}
|
||||
button.primary {
|
||||
background: var(--accent);
|
||||
border-color: var(--accent);
|
||||
color: white;
|
||||
}
|
||||
button.ok {
|
||||
background: var(--ok);
|
||||
border-color: var(--ok);
|
||||
color: white;
|
||||
}
|
||||
button.warn {
|
||||
background: var(--warn);
|
||||
border-color: var(--warn);
|
||||
color: black;
|
||||
}
|
||||
button.ghost {
|
||||
background: transparent;
|
||||
}
|
||||
.muted {
|
||||
color: var(--muted);
|
||||
font-size: 12px;
|
||||
}
|
||||
.mono {
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas,
|
||||
'Liberation Mono', 'Courier New', monospace;
|
||||
}
|
||||
.grid2 {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: 12px;
|
||||
}
|
||||
@media (max-width: 880px) {
|
||||
.grid2 {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
.ok {
|
||||
color: #10b981;
|
||||
}
|
||||
.err {
|
||||
color: #ef4444;
|
||||
}
|
||||
.sep {
|
||||
height: 1px;
|
||||
background: var(--border);
|
||||
margin: 12px 0;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="wrap">
|
||||
<h1>OAuth2/OIDC 授权码 + PKCE 前端演示</h1>
|
||||
<div class="card">
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label
|
||||
>Issuer(可选,用于自动发现
|
||||
/.well-known/openid-configuration)</label
|
||||
>
|
||||
<input id="issuer" placeholder="https://your-domain" />
|
||||
<div class="btns">
|
||||
<button class="" id="btnDiscover">自动发现端点</button>
|
||||
</div>
|
||||
<div class="muted">提示:若未配置 Issuer,可直接填写下方端点。</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>Response Type</label>
|
||||
<select id="response_type">
|
||||
<option value="code" selected>code</option>
|
||||
<option value="token">token</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="col">
|
||||
<label>Authorization Endpoint</label
|
||||
><input
|
||||
id="authorization_endpoint"
|
||||
placeholder="https://domain/api/oauth/authorize"
|
||||
/>
|
||||
</div>
|
||||
<div class="col">
|
||||
<label>Token Endpoint</label
|
||||
><input
|
||||
id="token_endpoint"
|
||||
placeholder="https://domain/api/oauth/token"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>UserInfo Endpoint(可选)</label
|
||||
><input
|
||||
id="userinfo_endpoint"
|
||||
placeholder="https://domain/api/oauth/userinfo"
|
||||
/>
|
||||
</div>
|
||||
<div class="col">
|
||||
<label>Client ID</label
|
||||
><input id="client_id" placeholder="your-public-client-id" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>Client Secret(可选,机密客户端)</label
|
||||
><input id="client_secret" placeholder="留空表示公开客户端" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>Redirect URI(当前页地址或你的回调)</label
|
||||
><input id="redirect_uri" />
|
||||
</div>
|
||||
<div class="col">
|
||||
<label>Scope</label
|
||||
><input id="scope" value="openid profile email" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col"><label>State</label><input id="state" /></div>
|
||||
<div class="col"><label>Nonce</label><input id="nonce" /></div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>Code Verifier(自动生成,不会上送)</label
|
||||
><input id="code_verifier" class="mono" readonly />
|
||||
</div>
|
||||
<div class="col">
|
||||
<label>Code Challenge(S256)</label
|
||||
><input id="code_challenge" class="mono" readonly />
|
||||
</div>
|
||||
</div>
|
||||
<div class="btns">
|
||||
<button id="btnGenPkce">生成 PKCE</button>
|
||||
<button id="btnRandomState">随机 State</button>
|
||||
<button id="btnRandomNonce">随机 Nonce</button>
|
||||
<button id="btnMakeAuthURL">生成授权链接</button>
|
||||
<button id="btnAuthorize" class="primary">跳转授权</button>
|
||||
</div>
|
||||
<div class="row" style="margin-top: 8px">
|
||||
<div class="col">
|
||||
<label>授权链接(只生成不跳转)</label>
|
||||
<textarea
|
||||
id="authorize_url"
|
||||
class="mono"
|
||||
placeholder="(空)"
|
||||
></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnCopyAuthURL">复制链接</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="sep"></div>
|
||||
<div class="muted">
|
||||
说明:
|
||||
<ul>
|
||||
<li>
|
||||
本页为纯前端演示,适用于公开客户端(不需要 client_secret)。
|
||||
</li>
|
||||
<li>
|
||||
如跨域调用 Token/UserInfo,需要服务端正确设置 CORS;建议将此 demo
|
||||
部署到同源域名下。
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class="sep"></div>
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label
|
||||
>粘贴 OIDC Discovery
|
||||
JSON(/.well-known/openid-configuration)</label
|
||||
>
|
||||
<textarea
|
||||
id="conf_json"
|
||||
class="mono"
|
||||
placeholder='{"issuer":"https://...","authorization_endpoint":"...","token_endpoint":"...","userinfo_endpoint":"..."}'
|
||||
></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnParseConf">解析并填充端点</button>
|
||||
<button id="btnGenConf">用当前端点生成 JSON</button>
|
||||
</div>
|
||||
<div class="muted">
|
||||
可将服务端返回的 OIDC Discovery JSON
|
||||
粘贴到此处,点击“解析并填充端点”。
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="card">
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>授权结果</label>
|
||||
<div id="authResult" class="muted">等待授权...</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="grid2" style="margin-top: 12px">
|
||||
<div>
|
||||
<label>Access Token</label>
|
||||
<textarea
|
||||
id="access_token"
|
||||
class="mono"
|
||||
placeholder="(空)"
|
||||
></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnCopyAT">复制</button
|
||||
><button id="btnCallUserInfo" class="ok">调用 UserInfo</button>
|
||||
</div>
|
||||
<div id="userinfoOut" class="muted" style="margin-top: 6px"></div>
|
||||
</div>
|
||||
<div>
|
||||
<label>ID Token(JWT)</label>
|
||||
<textarea id="id_token" class="mono" placeholder="(空)"></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnDecodeJWT">解码显示 Claims</button>
|
||||
</div>
|
||||
<pre
|
||||
id="jwtClaims"
|
||||
class="mono"
|
||||
style="
|
||||
white-space: pre-wrap;
|
||||
word-break: break-all;
|
||||
margin-top: 6px;
|
||||
"
|
||||
></pre>
|
||||
</div>
|
||||
</div>
|
||||
<div class="grid2" style="margin-top: 12px">
|
||||
<div>
|
||||
<label>Refresh Token</label>
|
||||
<textarea
|
||||
id="refresh_token"
|
||||
class="mono"
|
||||
placeholder="(空)"
|
||||
></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnRefreshToken">使用 Refresh Token 刷新</button>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<label>原始 Token 响应</label>
|
||||
<textarea id="token_raw" class="mono" placeholder="(空)"></textarea>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<script>
|
||||
const $ = (id) => document.getElementById(id);
|
||||
const toB64Url = (buf) =>
|
||||
btoa(String.fromCharCode(...new Uint8Array(buf)))
|
||||
.replace(/\+/g, '-')
|
||||
.replace(/\//g, '_')
|
||||
.replace(/=+$/, '');
|
||||
async function sha256B64Url(str) {
|
||||
const data = new TextEncoder().encode(str);
|
||||
const digest = await crypto.subtle.digest('SHA-256', data);
|
||||
return toB64Url(digest);
|
||||
}
|
||||
function randStr(len = 64) {
|
||||
const chars =
|
||||
'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~';
|
||||
const arr = new Uint8Array(len);
|
||||
crypto.getRandomValues(arr);
|
||||
return Array.from(arr, (v) => chars[v % chars.length]).join('');
|
||||
}
|
||||
function setAuthInfo(msg, ok = true) {
|
||||
const el = $('authResult');
|
||||
el.textContent = msg;
|
||||
el.className = ok ? 'ok' : 'err';
|
||||
}
|
||||
function qs(name) {
|
||||
const u = new URL(location.href);
|
||||
return u.searchParams.get(name);
|
||||
}
|
||||
function persist(k, v) {
|
||||
sessionStorage.setItem('demo_' + k, v);
|
||||
}
|
||||
function load(k) {
|
||||
return sessionStorage.getItem('demo_' + k) || '';
|
||||
}
|
||||
(function init() {
|
||||
$('redirect_uri').value =
|
||||
window.location.origin + window.location.pathname;
|
||||
const iss = load('issuer');
|
||||
if (iss) $('issuer').value = iss;
|
||||
const cid = load('client_id');
|
||||
if (cid) $('client_id').value = cid;
|
||||
const scp = load('scope');
|
||||
if (scp) $('scope').value = scp;
|
||||
})();
|
||||
$('btnDiscover').onclick = async () => {
|
||||
const iss = $('issuer').value.trim();
|
||||
if (!iss) {
|
||||
alert('请填写 Issuer');
|
||||
return;
|
||||
}
|
||||
try {
|
||||
persist('issuer', iss);
|
||||
const res = await fetch(
|
||||
iss.replace(/\/$/, '') + '/api/.well-known/openid-configuration',
|
||||
);
|
||||
const d = await res.json();
|
||||
$('authorization_endpoint').value = d.authorization_endpoint || '';
|
||||
$('token_endpoint').value = d.token_endpoint || '';
|
||||
$('userinfo_endpoint').value = d.userinfo_endpoint || '';
|
||||
if (d.issuer) {
|
||||
$('issuer').value = d.issuer;
|
||||
persist('issuer', d.issuer);
|
||||
}
|
||||
$('conf_json').value = JSON.stringify(d, null, 2);
|
||||
setAuthInfo('已从发现文档加载端点', true);
|
||||
} catch (e) {
|
||||
setAuthInfo('自动发现失败:' + e, false);
|
||||
}
|
||||
};
|
||||
$('btnGenPkce').onclick = async () => {
|
||||
const v = randStr(64);
|
||||
const c = await sha256B64Url(v);
|
||||
$('code_verifier').value = v;
|
||||
$('code_challenge').value = c;
|
||||
persist('code_verifier', v);
|
||||
persist('code_challenge', c);
|
||||
setAuthInfo('已生成 PKCE 参数', true);
|
||||
};
|
||||
$('btnRandomState').onclick = () => {
|
||||
$('state').value = randStr(16);
|
||||
persist('state', $('state').value);
|
||||
};
|
||||
$('btnRandomNonce').onclick = () => {
|
||||
$('nonce').value = randStr(16);
|
||||
persist('nonce', $('nonce').value);
|
||||
};
|
||||
function buildAuthorizeURLFromFields() {
|
||||
const auth = $('authorization_endpoint').value.trim();
|
||||
const token = $('token_endpoint').value.trim();
|
||||
const cid = $('client_id').value.trim();
|
||||
const red = $('redirect_uri').value.trim();
|
||||
const scp = $('scope').value.trim() || 'openid profile email';
|
||||
const rt = $('response_type').value;
|
||||
const st = $('state').value.trim() || randStr(16);
|
||||
const no = $('nonce').value.trim() || randStr(16);
|
||||
const cc = $('code_challenge').value.trim();
|
||||
const cv = $('code_verifier').value.trim();
|
||||
if (!auth || !cid || !red) {
|
||||
throw new Error('请先完善端点/ClientID/RedirectURI');
|
||||
}
|
||||
if (rt === 'code' && (!cc || !cv)) {
|
||||
throw new Error('请先生成 PKCE');
|
||||
}
|
||||
persist('authorization_endpoint', auth);
|
||||
persist('token_endpoint', token);
|
||||
persist('client_id', cid);
|
||||
persist('redirect_uri', red);
|
||||
persist('scope', scp);
|
||||
persist('state', st);
|
||||
persist('nonce', no);
|
||||
persist('code_verifier', cv);
|
||||
const u = new URL(auth);
|
||||
u.searchParams.set('response_type', rt);
|
||||
u.searchParams.set('client_id', cid);
|
||||
u.searchParams.set('redirect_uri', red);
|
||||
u.searchParams.set('scope', scp);
|
||||
u.searchParams.set('state', st);
|
||||
if (no) u.searchParams.set('nonce', no);
|
||||
if (rt === 'code') {
|
||||
u.searchParams.set('code_challenge', cc);
|
||||
u.searchParams.set('code_challenge_method', 'S256');
|
||||
}
|
||||
return u.toString();
|
||||
}
|
||||
$('btnMakeAuthURL').onclick = () => {
|
||||
try {
|
||||
const url = buildAuthorizeURLFromFields();
|
||||
$('authorize_url').value = url;
|
||||
setAuthInfo('已生成授权链接', true);
|
||||
} catch (e) {
|
||||
setAuthInfo(e.message, false);
|
||||
}
|
||||
};
|
||||
$('btnAuthorize').onclick = () => {
|
||||
try {
|
||||
const url = buildAuthorizeURLFromFields();
|
||||
location.href = url;
|
||||
} catch (e) {
|
||||
setAuthInfo(e.message, false);
|
||||
}
|
||||
};
|
||||
$('btnCopyAuthURL').onclick = async () => {
|
||||
try {
|
||||
await navigator.clipboard.writeText($('authorize_url').value);
|
||||
} catch {}
|
||||
};
|
||||
async function postForm(url, data, basic) {
|
||||
const body = Object.entries(data)
|
||||
.map(([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(v)}`)
|
||||
.join('&');
|
||||
const headers = { 'Content-Type': 'application/x-www-form-urlencoded' };
|
||||
if (basic && basic.id && basic.secret) {
|
||||
headers['Authorization'] =
|
||||
'Basic ' + btoa(`${basic.id}:${basic.secret}`);
|
||||
}
|
||||
const res = await fetch(url, { method: 'POST', headers, body });
|
||||
if (!res.ok) {
|
||||
const t = await res.text();
|
||||
throw new Error(`HTTP ${res.status} ${t}`);
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
async function handleCallback() {
|
||||
const frag =
|
||||
location.hash && location.hash.startsWith('#')
|
||||
? new URLSearchParams(location.hash.slice(1))
|
||||
: null;
|
||||
const at = frag ? frag.get('access_token') : null;
|
||||
const err = qs('error') || (frag ? frag.get('error') : null);
|
||||
const state = qs('state') || (frag ? frag.get('state') : null);
|
||||
if (err) {
|
||||
setAuthInfo('授权失败:' + err, false);
|
||||
return;
|
||||
}
|
||||
if (at) {
|
||||
$('access_token').value = at || '';
|
||||
$('token_raw').value = JSON.stringify(
|
||||
{
|
||||
access_token: at,
|
||||
token_type: frag.get('token_type'),
|
||||
expires_in: frag.get('expires_in'),
|
||||
scope: frag.get('scope'),
|
||||
state,
|
||||
},
|
||||
null,
|
||||
2,
|
||||
);
|
||||
setAuthInfo('隐式模式已获取 Access Token', true);
|
||||
return;
|
||||
}
|
||||
const code = qs('code');
|
||||
if (!code) {
|
||||
setAuthInfo('等待授权...', true);
|
||||
return;
|
||||
}
|
||||
if (state && load('state') && state !== load('state')) {
|
||||
setAuthInfo('state 不匹配,已拒绝', false);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const tokenEp = load('token_endpoint');
|
||||
const cid = load('client_id');
|
||||
const csec = $('client_secret').value.trim();
|
||||
const basic = csec ? { id: cid, secret: csec } : null;
|
||||
const data = await postForm(
|
||||
tokenEp,
|
||||
{
|
||||
grant_type: 'authorization_code',
|
||||
code,
|
||||
client_id: cid,
|
||||
redirect_uri: load('redirect_uri'),
|
||||
code_verifier: load('code_verifier'),
|
||||
},
|
||||
basic,
|
||||
);
|
||||
$('access_token').value = data.access_token || '';
|
||||
$('id_token').value = data.id_token || '';
|
||||
$('refresh_token').value = data.refresh_token || '';
|
||||
$('token_raw').value = JSON.stringify(data, null, 2);
|
||||
setAuthInfo('授权成功,已获取令牌', true);
|
||||
} catch (e) {
|
||||
setAuthInfo('交换令牌失败:' + e.message, false);
|
||||
}
|
||||
}
|
||||
handleCallback();
|
||||
$('btnCopyAT').onclick = async () => {
|
||||
try {
|
||||
await navigator.clipboard.writeText($('access_token').value);
|
||||
} catch {}
|
||||
};
|
||||
$('btnDecodeJWT').onclick = () => {
|
||||
const t = $('id_token').value.trim();
|
||||
if (!t) {
|
||||
$('jwtClaims').textContent = '(空)';
|
||||
return;
|
||||
}
|
||||
const parts = t.split('.');
|
||||
if (parts.length < 2) {
|
||||
$('jwtClaims').textContent = '格式错误';
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const json = JSON.parse(
|
||||
atob(parts[1].replace(/-/g, '+').replace(/_/g, '/')),
|
||||
);
|
||||
$('jwtClaims').textContent = JSON.stringify(json, null, 2);
|
||||
} catch (e) {
|
||||
$('jwtClaims').textContent = '解码失败:' + e;
|
||||
}
|
||||
};
|
||||
$('btnCallUserInfo').onclick = async () => {
|
||||
const at = $('access_token').value.trim();
|
||||
const ep = $('userinfo_endpoint').value.trim();
|
||||
if (!at || !ep) {
|
||||
alert('请填写UserInfo端点并获取AccessToken');
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const res = await fetch(ep, {
|
||||
headers: { Authorization: 'Bearer ' + at },
|
||||
});
|
||||
const data = await res.json();
|
||||
$('userinfoOut').textContent = JSON.stringify(data, null, 2);
|
||||
} catch (e) {
|
||||
$('userinfoOut').textContent = '调用失败:' + e;
|
||||
}
|
||||
};
|
||||
$('btnRefreshToken').onclick = async () => {
|
||||
const rt = $('refresh_token').value.trim();
|
||||
if (!rt) {
|
||||
alert('没有刷新令牌');
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const tokenEp = load('token_endpoint');
|
||||
const cid = load('client_id');
|
||||
const csec = $('client_secret').value.trim();
|
||||
const basic = csec ? { id: cid, secret: csec } : null;
|
||||
const data = await postForm(
|
||||
tokenEp,
|
||||
{ grant_type: 'refresh_token', refresh_token: rt, client_id: cid },
|
||||
basic,
|
||||
);
|
||||
$('access_token').value = data.access_token || '';
|
||||
$('id_token').value = data.id_token || '';
|
||||
$('refresh_token').value = data.refresh_token || '';
|
||||
$('token_raw').value = JSON.stringify(data, null, 2);
|
||||
setAuthInfo('刷新成功', true);
|
||||
} catch (e) {
|
||||
setAuthInfo('刷新失败:' + e.message, false);
|
||||
}
|
||||
};
|
||||
$('btnParseConf').onclick = () => {
|
||||
const txt = $('conf_json').value.trim();
|
||||
if (!txt) {
|
||||
alert('请先粘贴 JSON');
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const d = JSON.parse(txt);
|
||||
if (d.issuer) {
|
||||
$('issuer').value = d.issuer;
|
||||
persist('issuer', d.issuer);
|
||||
}
|
||||
if (d.authorization_endpoint)
|
||||
$('authorization_endpoint').value = d.authorization_endpoint;
|
||||
if (d.token_endpoint) $('token_endpoint').value = d.token_endpoint;
|
||||
if (d.userinfo_endpoint)
|
||||
$('userinfo_endpoint').value = d.userinfo_endpoint;
|
||||
setAuthInfo('已解析配置并填充端点', true);
|
||||
} catch (e) {
|
||||
setAuthInfo('解析失败:' + e, false);
|
||||
}
|
||||
};
|
||||
$('btnGenConf').onclick = () => {
|
||||
const d = {
|
||||
issuer: $('issuer').value.trim() || undefined,
|
||||
authorization_endpoint:
|
||||
$('authorization_endpoint').value.trim() || undefined,
|
||||
token_endpoint: $('token_endpoint').value.trim() || undefined,
|
||||
userinfo_endpoint: $('userinfo_endpoint').value.trim() || undefined,
|
||||
};
|
||||
$('conf_json').value = JSON.stringify(d, null, 2);
|
||||
};
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -44,6 +44,7 @@ import Task from './pages/Task';
|
||||
import ModelPage from './pages/Model';
|
||||
import Playground from './pages/Playground';
|
||||
import OAuth2Callback from './components/auth/OAuth2Callback';
|
||||
import OAuthConsent from './pages/OAuth';
|
||||
import PersonalSetting from './components/settings/PersonalSetting';
|
||||
import Setup from './pages/Setup';
|
||||
import SetupCheck from './components/layout/SetupCheck';
|
||||
@@ -198,6 +199,14 @@ function App() {
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/consent'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<OAuthConsent />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/linuxdo'
|
||||
element={
|
||||
|
||||
@@ -176,7 +176,11 @@ const LoginForm = () => {
|
||||
centered: true,
|
||||
});
|
||||
}
|
||||
navigate('/console');
|
||||
// 优先跳回 next(仅允许相对路径)
|
||||
const sp = new URLSearchParams(window.location.search);
|
||||
const next = sp.get('next');
|
||||
const isSafeInternalPath = next && next.startsWith('/') && !next.startsWith('//');
|
||||
navigate(isSafeInternalPath ? next : '/console');
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
@@ -286,7 +290,10 @@ const LoginForm = () => {
|
||||
setUserData(data);
|
||||
updateAPI();
|
||||
showSuccess('登录成功!');
|
||||
navigate('/console');
|
||||
const sp = new URLSearchParams(window.location.search);
|
||||
const next = sp.get('next');
|
||||
const isSafeInternalPath = next && next.startsWith('/') && !next.startsWith('//');
|
||||
navigate(isSafeInternalPath ? next : '/console');
|
||||
};
|
||||
|
||||
// 返回登录页面
|
||||
|
||||
@@ -135,7 +135,9 @@ const TwoFactorAuthModal = ({
|
||||
autoFocus
|
||||
/>
|
||||
<Typography.Text type='tertiary' size='small' className='mt-2 block'>
|
||||
{t('支持6位TOTP验证码或8位备用码')}
|
||||
{t(
|
||||
'支持6位TOTP验证码或8位备用码,可到`个人设置-安全设置-两步验证设置`配置或查看。',
|
||||
)}
|
||||
</Typography.Text>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -21,7 +21,7 @@ import React, { useState, useMemo, useCallback } from 'react';
|
||||
import { Button, Tooltip, Toast } from '@douyinfe/semi-ui';
|
||||
import { Copy, ChevronDown, ChevronUp } from 'lucide-react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { copy } from '../../helpers';
|
||||
import { copy } from '../../../helpers';
|
||||
|
||||
const PERFORMANCE_CONFIG = {
|
||||
MAX_DISPLAY_LENGTH: 50000, // 最大显示字符数
|
||||
135
web/src/components/common/ui/ResponsiveModal.jsx
Normal file
135
web/src/components/common/ui/ResponsiveModal.jsx
Normal file
@@ -0,0 +1,135 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { Modal, Typography } from '@douyinfe/semi-ui';
|
||||
import PropTypes from 'prop-types';
|
||||
import { useIsMobile } from '../../../hooks/common/useIsMobile';
|
||||
|
||||
const { Title } = Typography;
|
||||
|
||||
/**
|
||||
* ResponsiveModal 响应式模态框组件
|
||||
*
|
||||
* 特性:
|
||||
* - 响应式布局:移动端和桌面端不同的宽度和布局
|
||||
* - 自定义头部:标题左对齐,操作按钮右对齐,移动端自动换行
|
||||
* - Tailwind CSS 样式支持
|
||||
* - 保持原 Modal 组件的所有功能
|
||||
*/
|
||||
const ResponsiveModal = ({
|
||||
visible,
|
||||
onCancel,
|
||||
title,
|
||||
headerActions = [],
|
||||
children,
|
||||
width = { mobile: '95%', desktop: 600 },
|
||||
className = '',
|
||||
footer = null,
|
||||
titleProps = {},
|
||||
headerClassName = '',
|
||||
actionsClassName = '',
|
||||
...props
|
||||
}) => {
|
||||
const isMobile = useIsMobile();
|
||||
|
||||
// 自定义 Header 组件
|
||||
const CustomHeader = () => {
|
||||
if (!title && (!headerActions || headerActions.length === 0)) return null;
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`flex w-full gap-3 justify-between ${
|
||||
isMobile ? 'flex-col items-start' : 'flex-row items-center'
|
||||
} ${headerClassName}`}
|
||||
>
|
||||
{title && (
|
||||
<Title heading={5} className='m-0 min-w-fit' {...titleProps}>
|
||||
{title}
|
||||
</Title>
|
||||
)}
|
||||
{headerActions && headerActions.length > 0 && (
|
||||
<div
|
||||
className={`flex flex-wrap gap-2 items-center ${
|
||||
isMobile ? 'w-full justify-start' : 'w-auto justify-end'
|
||||
} ${actionsClassName}`}
|
||||
>
|
||||
{headerActions.map((action, index) => (
|
||||
<React.Fragment key={index}>{action}</React.Fragment>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// 计算模态框宽度
|
||||
const getModalWidth = () => {
|
||||
if (typeof width === 'object') {
|
||||
return isMobile ? width.mobile : width.desktop;
|
||||
}
|
||||
return width;
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
visible={visible}
|
||||
title={<CustomHeader />}
|
||||
onCancel={onCancel}
|
||||
footer={footer}
|
||||
width={getModalWidth()}
|
||||
className={`!top-12 ${className}`}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
ResponsiveModal.propTypes = {
|
||||
// Modal 基础属性
|
||||
visible: PropTypes.bool.isRequired,
|
||||
onCancel: PropTypes.func.isRequired,
|
||||
children: PropTypes.node,
|
||||
|
||||
// 自定义头部
|
||||
title: PropTypes.oneOfType([PropTypes.string, PropTypes.node]),
|
||||
headerActions: PropTypes.arrayOf(PropTypes.node),
|
||||
|
||||
// 样式和布局
|
||||
width: PropTypes.oneOfType([
|
||||
PropTypes.number,
|
||||
PropTypes.string,
|
||||
PropTypes.shape({
|
||||
mobile: PropTypes.oneOfType([PropTypes.number, PropTypes.string]),
|
||||
desktop: PropTypes.oneOfType([PropTypes.number, PropTypes.string]),
|
||||
}),
|
||||
]),
|
||||
className: PropTypes.string,
|
||||
footer: PropTypes.node,
|
||||
|
||||
// 标题自定义属性
|
||||
titleProps: PropTypes.object,
|
||||
|
||||
// 自定义 CSS 类
|
||||
headerClassName: PropTypes.string,
|
||||
actionsClassName: PropTypes.string,
|
||||
};
|
||||
|
||||
export default ResponsiveModal;
|
||||
@@ -28,7 +28,7 @@ import {
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { Code, Zap, Clock, X, Eye, Send } from 'lucide-react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import CodeViewer from './CodeViewer';
|
||||
import CodeViewer from '../common/ui/CodeViewer';
|
||||
|
||||
const DebugPanel = ({
|
||||
debugData,
|
||||
|
||||
72
web/src/components/settings/OAuth2Setting.jsx
Normal file
72
web/src/components/settings/OAuth2Setting.jsx
Normal file
@@ -0,0 +1,72 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Spin } from '@douyinfe/semi-ui';
|
||||
import { API, showError } from '../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import OAuth2ServerSettings from './oauth2/OAuth2ServerSettings';
|
||||
import OAuth2ClientSettings from './oauth2/OAuth2ClientSettings';
|
||||
|
||||
const OAuth2Setting = () => {
|
||||
const { t } = useTranslation();
|
||||
const [options, setOptions] = useState({});
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const getOptions = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.get('/api/option/');
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
const map = {};
|
||||
for (const item of data) {
|
||||
map[item.key] = item.value;
|
||||
}
|
||||
setOptions(map);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('获取OAuth2设置失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const refresh = () => {
|
||||
getOptions();
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
getOptions();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Spin spinning={loading} size='large'>
|
||||
{/* 服务器配置 */}
|
||||
<OAuth2ServerSettings options={options} refresh={refresh} />
|
||||
|
||||
{/* 客户端管理 */}
|
||||
<OAuth2ClientSettings />
|
||||
</Spin>
|
||||
);
|
||||
};
|
||||
|
||||
export default OAuth2Setting;
|
||||
@@ -37,6 +37,8 @@ const PaymentSetting = () => {
|
||||
TopupGroupRatio: '',
|
||||
CustomCallbackAddress: '',
|
||||
PayMethods: '',
|
||||
AmountOptions: '',
|
||||
AmountDiscount: '',
|
||||
|
||||
StripeApiSecret: '',
|
||||
StripeWebhookSecret: '',
|
||||
@@ -66,6 +68,30 @@ const PaymentSetting = () => {
|
||||
newInputs[item.key] = item.value;
|
||||
}
|
||||
break;
|
||||
case 'payment_setting.amount_options':
|
||||
try {
|
||||
newInputs['AmountOptions'] = JSON.stringify(
|
||||
JSON.parse(item.value),
|
||||
null,
|
||||
2,
|
||||
);
|
||||
} catch (error) {
|
||||
console.error('解析AmountOptions出错:', error);
|
||||
newInputs['AmountOptions'] = item.value;
|
||||
}
|
||||
break;
|
||||
case 'payment_setting.amount_discount':
|
||||
try {
|
||||
newInputs['AmountDiscount'] = JSON.stringify(
|
||||
JSON.parse(item.value),
|
||||
null,
|
||||
2,
|
||||
);
|
||||
} catch (error) {
|
||||
console.error('解析AmountDiscount出错:', error);
|
||||
newInputs['AmountDiscount'] = item.value;
|
||||
}
|
||||
break;
|
||||
case 'Price':
|
||||
case 'MinTopUp':
|
||||
case 'StripeUnitPrice':
|
||||
|
||||
400
web/src/components/settings/oauth2/OAuth2ClientSettings.jsx
Normal file
400
web/src/components/settings/oauth2/OAuth2ClientSettings.jsx
Normal file
@@ -0,0 +1,400 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import {
|
||||
Card,
|
||||
Table,
|
||||
Button,
|
||||
Space,
|
||||
Tag,
|
||||
Typography,
|
||||
Input,
|
||||
Popconfirm,
|
||||
Empty,
|
||||
Tooltip,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { IconSearch } from '@douyinfe/semi-icons';
|
||||
import { User } from 'lucide-react';
|
||||
import {
|
||||
IllustrationNoResult,
|
||||
IllustrationNoResultDark,
|
||||
} from '@douyinfe/semi-illustrations';
|
||||
import { API, showError, showSuccess } from '../../../helpers';
|
||||
import OAuth2ClientModal from './modals/OAuth2ClientModal';
|
||||
import SecretDisplayModal from './modals/SecretDisplayModal';
|
||||
import ServerInfoModal from './modals/ServerInfoModal';
|
||||
import JWKSInfoModal from './modals/JWKSInfoModal';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
export default function OAuth2ClientSettings() {
|
||||
const { t } = useTranslation();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [clients, setClients] = useState([]);
|
||||
const [filteredClients, setFilteredClients] = useState([]);
|
||||
const [searchKeyword, setSearchKeyword] = useState('');
|
||||
const [showModal, setShowModal] = useState(false);
|
||||
const [editingClient, setEditingClient] = useState(null);
|
||||
const [showSecretModal, setShowSecretModal] = useState(false);
|
||||
const [currentSecret, setCurrentSecret] = useState('');
|
||||
const [showServerInfoModal, setShowServerInfoModal] = useState(false);
|
||||
const [showJWKSModal, setShowJWKSModal] = useState(false);
|
||||
|
||||
// 加载客户端列表
|
||||
const loadClients = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.get('/api/oauth_clients/');
|
||||
if (res.data.success) {
|
||||
setClients(res.data.data || []);
|
||||
setFilteredClients(res.data.data || []);
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('加载OAuth2客户端失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 搜索过滤
|
||||
const handleSearch = (value) => {
|
||||
setSearchKeyword(value);
|
||||
if (!value) {
|
||||
setFilteredClients(clients);
|
||||
} else {
|
||||
const filtered = clients.filter(
|
||||
(client) =>
|
||||
client.name?.toLowerCase().includes(value.toLowerCase()) ||
|
||||
client.id?.toLowerCase().includes(value.toLowerCase()) ||
|
||||
client.description?.toLowerCase().includes(value.toLowerCase()),
|
||||
);
|
||||
setFilteredClients(filtered);
|
||||
}
|
||||
};
|
||||
|
||||
// 删除客户端
|
||||
const handleDelete = async (client) => {
|
||||
try {
|
||||
const res = await API.delete(`/api/oauth_clients/${client.id}`);
|
||||
if (res.data.success) {
|
||||
showSuccess(t('删除成功'));
|
||||
loadClients();
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('删除失败'));
|
||||
}
|
||||
};
|
||||
|
||||
// 重新生成密钥
|
||||
const handleRegenerateSecret = async (client) => {
|
||||
try {
|
||||
const res = await API.post(
|
||||
`/api/oauth_clients/${client.id}/regenerate_secret`,
|
||||
);
|
||||
if (res.data.success) {
|
||||
setCurrentSecret(res.data.client_secret);
|
||||
setShowSecretModal(true);
|
||||
loadClients();
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('重新生成密钥失败'));
|
||||
}
|
||||
};
|
||||
|
||||
// 查看服务器信息
|
||||
const showServerInfo = () => {
|
||||
setShowServerInfoModal(true);
|
||||
};
|
||||
|
||||
// 查看JWKS
|
||||
const showJWKS = () => {
|
||||
setShowJWKSModal(true);
|
||||
};
|
||||
|
||||
// 表格列定义
|
||||
const columns = [
|
||||
{
|
||||
title: t('客户端名称'),
|
||||
dataIndex: 'name',
|
||||
render: (name, record) => (
|
||||
<div className='flex items-center cursor-help'>
|
||||
<User size={16} className='mr-1.5 text-gray-500' />
|
||||
<Tooltip content={record.description || t('暂无描述')} position='top'>
|
||||
<Text strong>{name}</Text>
|
||||
</Tooltip>
|
||||
</div>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: t('客户端ID'),
|
||||
dataIndex: 'id',
|
||||
render: (id) => (
|
||||
<Text type='tertiary' size='small' code copyable>
|
||||
{id}
|
||||
</Text>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: t('状态'),
|
||||
dataIndex: 'status',
|
||||
render: (status) => (
|
||||
<Tag color={status === 1 ? 'green' : 'red'} shape='circle'>
|
||||
{status === 1 ? t('启用') : t('禁用')}
|
||||
</Tag>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: t('类型'),
|
||||
dataIndex: 'client_type',
|
||||
render: (text) => (
|
||||
<Tag color='white' shape='circle'>
|
||||
{text === 'confidential' ? t('机密客户端') : t('公开客户端')}
|
||||
</Tag>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: t('授权类型'),
|
||||
dataIndex: 'grant_types',
|
||||
render: (grantTypes) => {
|
||||
const types =
|
||||
typeof grantTypes === 'string'
|
||||
? grantTypes.split(',')
|
||||
: grantTypes || [];
|
||||
const typeMap = {
|
||||
client_credentials: t('客户端凭证'),
|
||||
authorization_code: t('授权码'),
|
||||
refresh_token: t('刷新令牌'),
|
||||
};
|
||||
return (
|
||||
<div className='flex flex-wrap gap-1'>
|
||||
{types.slice(0, 2).map((type) => (
|
||||
<Tag color='white' key={type} size='small' shape='circle'>
|
||||
{typeMap[type] || type}
|
||||
</Tag>
|
||||
))}
|
||||
{types.length > 2 && (
|
||||
<Tooltip
|
||||
content={types
|
||||
.slice(2)
|
||||
.map((t) => typeMap[t] || t)
|
||||
.join(', ')}
|
||||
>
|
||||
<Tag color='white' size='small' shape='circle'>
|
||||
+{types.length - 2}
|
||||
</Tag>
|
||||
</Tooltip>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: t('创建时间'),
|
||||
dataIndex: 'created_time',
|
||||
render: (time) => new Date(time * 1000).toLocaleString(),
|
||||
},
|
||||
{
|
||||
title: t('操作'),
|
||||
render: (_, record) => (
|
||||
<Space size={4} wrap>
|
||||
<Button
|
||||
type='primary'
|
||||
size='small'
|
||||
onClick={() => {
|
||||
setEditingClient(record);
|
||||
setShowModal(true);
|
||||
}}
|
||||
>
|
||||
{t('编辑')}
|
||||
</Button>
|
||||
{record.client_type === 'confidential' && (
|
||||
<Popconfirm
|
||||
title={t('确认重新生成客户端密钥?')}
|
||||
content={t('操作不可撤销,旧密钥将立即失效。')}
|
||||
onConfirm={() => handleRegenerateSecret(record)}
|
||||
okText={t('确认')}
|
||||
cancelText={t('取消')}
|
||||
position='bottomLeft'
|
||||
>
|
||||
<Button type='secondary' size='small'>
|
||||
{t('重新生成密钥')}
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
)}
|
||||
<Popconfirm
|
||||
title={t('请再次确认删除该客户端')}
|
||||
content={t('删除后无法恢复,相关 API 调用将立即失效。')}
|
||||
onConfirm={() => handleDelete(record)}
|
||||
okText={t('确定删除')}
|
||||
cancelText={t('取消')}
|
||||
position='bottomLeft'
|
||||
>
|
||||
<Button type='danger' size='small'>
|
||||
{t('删除')}
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
</Space>
|
||||
),
|
||||
fixed: 'right',
|
||||
},
|
||||
];
|
||||
|
||||
useEffect(() => {
|
||||
loadClients();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Card
|
||||
className='!rounded-2xl shadow-sm border-0'
|
||||
style={{ marginTop: 10 }}
|
||||
title={
|
||||
<div
|
||||
className='flex flex-col sm:flex-row sm:items-center sm:justify-between w-full gap-3 sm:gap-0'
|
||||
style={{ paddingRight: '8px' }}
|
||||
>
|
||||
<div className='flex items-center'>
|
||||
<User size={18} className='mr-2' />
|
||||
<Text strong>{t('OAuth2 客户端管理')}</Text>
|
||||
<Tag color='white' shape='circle' size='small' className='ml-2'>
|
||||
{filteredClients.length} {t('个客户端')}
|
||||
</Tag>
|
||||
</div>
|
||||
<div className='flex items-center gap-2 sm:flex-shrink-0 flex-wrap'>
|
||||
<Input
|
||||
prefix={<IconSearch />}
|
||||
placeholder={t('搜索客户端名称、ID或描述')}
|
||||
value={searchKeyword}
|
||||
onChange={handleSearch}
|
||||
showClear
|
||||
size='small'
|
||||
style={{ width: 300 }}
|
||||
/>
|
||||
<Button type='tertiary' onClick={loadClients} size='small'>
|
||||
{t('刷新')}
|
||||
</Button>
|
||||
<Button type='secondary' onClick={showServerInfo} size='small'>
|
||||
{t('服务器信息')}
|
||||
</Button>
|
||||
<Button type='secondary' onClick={showJWKS} size='small'>
|
||||
{t('查看JWKS')}
|
||||
</Button>
|
||||
<Button
|
||||
type='primary'
|
||||
onClick={() => {
|
||||
setEditingClient(null);
|
||||
setShowModal(true);
|
||||
}}
|
||||
size='small'
|
||||
>
|
||||
{t('创建客户端')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<div className='mb-4'>
|
||||
<Text type='tertiary'>
|
||||
{t(
|
||||
'管理OAuth2客户端应用程序,每个客户端代表一个可以访问API的应用程序。机密客户端用于服务器端应用,公开客户端用于移动应用或单页应用。',
|
||||
)}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{/* 客户端表格 */}
|
||||
<Table
|
||||
columns={columns}
|
||||
dataSource={filteredClients}
|
||||
rowKey='id'
|
||||
loading={loading}
|
||||
scroll={{ x: 'max-content' }}
|
||||
pagination={{
|
||||
showSizeChanger: true,
|
||||
showQuickJumper: true,
|
||||
showTotal: true,
|
||||
pageSize: 10,
|
||||
}}
|
||||
empty={
|
||||
<Empty
|
||||
image={<IllustrationNoResult style={{ width: 150, height: 150 }} />}
|
||||
darkModeImage={
|
||||
<IllustrationNoResultDark style={{ width: 150, height: 150 }} />
|
||||
}
|
||||
title={t('暂无OAuth2客户端')}
|
||||
description={t(
|
||||
'还没有创建任何客户端,点击下方按钮创建第一个客户端',
|
||||
)}
|
||||
style={{ padding: 30 }}
|
||||
>
|
||||
<Button
|
||||
type='primary'
|
||||
onClick={() => {
|
||||
setEditingClient(null);
|
||||
setShowModal(true);
|
||||
}}
|
||||
>
|
||||
{t('创建第一个客户端')}
|
||||
</Button>
|
||||
</Empty>
|
||||
}
|
||||
/>
|
||||
|
||||
{/* OAuth2 客户端模态框 */}
|
||||
<OAuth2ClientModal
|
||||
visible={showModal}
|
||||
client={editingClient}
|
||||
onCancel={() => {
|
||||
setShowModal(false);
|
||||
setEditingClient(null);
|
||||
}}
|
||||
onSuccess={() => {
|
||||
setShowModal(false);
|
||||
setEditingClient(null);
|
||||
loadClients();
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* 密钥显示模态框 */}
|
||||
<SecretDisplayModal
|
||||
visible={showSecretModal}
|
||||
onClose={() => setShowSecretModal(false)}
|
||||
secret={currentSecret}
|
||||
/>
|
||||
|
||||
{/* 服务器信息模态框 */}
|
||||
<ServerInfoModal
|
||||
visible={showServerInfoModal}
|
||||
onClose={() => setShowServerInfoModal(false)}
|
||||
/>
|
||||
|
||||
{/* JWKS信息模态框 */}
|
||||
<JWKSInfoModal
|
||||
visible={showJWKSModal}
|
||||
onClose={() => setShowJWKSModal(false)}
|
||||
/>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
473
web/src/components/settings/oauth2/OAuth2ServerSettings.jsx
Normal file
473
web/src/components/settings/oauth2/OAuth2ServerSettings.jsx
Normal file
@@ -0,0 +1,473 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState, useRef } from 'react';
|
||||
import {
|
||||
Banner,
|
||||
Button,
|
||||
Col,
|
||||
Form,
|
||||
Row,
|
||||
Card,
|
||||
Typography,
|
||||
Badge,
|
||||
Divider,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { Server } from 'lucide-react';
|
||||
import JWKSManagerModal from './modals/JWKSManagerModal';
|
||||
import {
|
||||
compareObjects,
|
||||
API,
|
||||
showError,
|
||||
showSuccess,
|
||||
showWarning,
|
||||
} from '../../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
export default function OAuth2ServerSettings(props) {
|
||||
const { t } = useTranslation();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [inputs, setInputs] = useState({
|
||||
'oauth2.enabled': false,
|
||||
'oauth2.issuer': '',
|
||||
'oauth2.access_token_ttl': 10,
|
||||
'oauth2.refresh_token_ttl': 720,
|
||||
'oauth2.jwt_signing_algorithm': 'RS256',
|
||||
'oauth2.jwt_key_id': 'oauth2-key-1',
|
||||
'oauth2.allowed_grant_types': [
|
||||
'client_credentials',
|
||||
'authorization_code',
|
||||
'refresh_token',
|
||||
],
|
||||
'oauth2.require_pkce': true,
|
||||
'oauth2.max_jwks_keys': 3,
|
||||
});
|
||||
const refForm = useRef();
|
||||
const [inputsRow, setInputsRow] = useState(inputs);
|
||||
const [keysReady, setKeysReady] = useState(true);
|
||||
const [keysLoading, setKeysLoading] = useState(false);
|
||||
const [serverInfo, setServerInfo] = useState(null);
|
||||
const enabledRef = useRef(inputs['oauth2.enabled']);
|
||||
|
||||
// 模态框状态
|
||||
const [jwksVisible, setJwksVisible] = useState(false);
|
||||
|
||||
function handleFieldChange(fieldName) {
|
||||
return (value) => {
|
||||
setInputs((inputs) => ({ ...inputs, [fieldName]: value }));
|
||||
};
|
||||
}
|
||||
|
||||
function onSubmit() {
|
||||
const updateArray = compareObjects(inputs, inputsRow);
|
||||
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
|
||||
const requestQueue = updateArray.map((item) => {
|
||||
let value = '';
|
||||
if (typeof inputs[item.key] === 'boolean') {
|
||||
value = String(inputs[item.key]);
|
||||
} else if (Array.isArray(inputs[item.key])) {
|
||||
value = JSON.stringify(inputs[item.key]);
|
||||
} else {
|
||||
value = inputs[item.key];
|
||||
}
|
||||
return API.put('/api/option/', {
|
||||
key: item.key,
|
||||
value,
|
||||
});
|
||||
});
|
||||
setLoading(true);
|
||||
Promise.all(requestQueue)
|
||||
.then((res) => {
|
||||
if (requestQueue.length === 1) {
|
||||
if (res.includes(undefined)) return;
|
||||
} else if (requestQueue.length > 1) {
|
||||
if (res.includes(undefined))
|
||||
return showError(t('部分保存失败,请重试'));
|
||||
}
|
||||
showSuccess(t('保存成功'));
|
||||
if (props && props.refresh) {
|
||||
props.refresh();
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
showError(t('保存失败,请重试'));
|
||||
})
|
||||
.finally(() => {
|
||||
setLoading(false);
|
||||
});
|
||||
}
|
||||
|
||||
// 测试OAuth2连接(默认静默,仅用户点击时弹提示)
|
||||
const testOAuth2 = async (silent = true) => {
|
||||
// 未启用时不触发测试,避免 404
|
||||
if (!enabledRef.current) return;
|
||||
try {
|
||||
const res = await API.get('/api/oauth/server-info', {
|
||||
skipErrorHandler: true,
|
||||
});
|
||||
if (!enabledRef.current) return;
|
||||
if (
|
||||
res.status === 200 &&
|
||||
(res.data.issuer || res.data.authorization_endpoint)
|
||||
) {
|
||||
if (!silent) showSuccess('OAuth2服务器运行正常');
|
||||
setServerInfo(res.data);
|
||||
} else {
|
||||
if (!enabledRef.current) return;
|
||||
if (!silent) showError('OAuth2服务器测试失败');
|
||||
}
|
||||
} catch (error) {
|
||||
if (!enabledRef.current) return;
|
||||
if (!silent) showError('OAuth2服务器连接测试失败');
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (props && props.options) {
|
||||
const currentInputs = {};
|
||||
for (let key in props.options) {
|
||||
if (Object.keys(inputs).includes(key)) {
|
||||
if (key === 'oauth2.allowed_grant_types') {
|
||||
try {
|
||||
currentInputs[key] = JSON.parse(
|
||||
props.options[key] ||
|
||||
'["client_credentials","authorization_code","refresh_token"]',
|
||||
);
|
||||
} catch {
|
||||
currentInputs[key] = [
|
||||
'client_credentials',
|
||||
'authorization_code',
|
||||
'refresh_token',
|
||||
];
|
||||
}
|
||||
} else if (typeof inputs[key] === 'boolean') {
|
||||
currentInputs[key] = props.options[key] === 'true';
|
||||
} else if (typeof inputs[key] === 'number') {
|
||||
currentInputs[key] = parseInt(props.options[key]) || inputs[key];
|
||||
} else {
|
||||
currentInputs[key] = props.options[key];
|
||||
}
|
||||
}
|
||||
}
|
||||
setInputs({ ...inputs, ...currentInputs });
|
||||
setInputsRow(structuredClone({ ...inputs, ...currentInputs }));
|
||||
if (refForm.current) {
|
||||
refForm.current.setValues({ ...inputs, ...currentInputs });
|
||||
}
|
||||
}
|
||||
}, [props]);
|
||||
|
||||
useEffect(() => {
|
||||
enabledRef.current = inputs['oauth2.enabled'];
|
||||
}, [inputs['oauth2.enabled']]);
|
||||
|
||||
useEffect(() => {
|
||||
const loadKeys = async () => {
|
||||
try {
|
||||
setKeysLoading(true);
|
||||
const res = await API.get('/api/oauth/keys', {
|
||||
skipErrorHandler: true,
|
||||
});
|
||||
const list = res?.data?.data || [];
|
||||
setKeysReady(list.length > 0);
|
||||
} catch {
|
||||
setKeysReady(false);
|
||||
} finally {
|
||||
setKeysLoading(false);
|
||||
}
|
||||
};
|
||||
if (inputs['oauth2.enabled']) {
|
||||
loadKeys();
|
||||
testOAuth2(true);
|
||||
} else {
|
||||
// 禁用时清理状态,避免残留状态与不必要的请求
|
||||
setKeysReady(true);
|
||||
setServerInfo(null);
|
||||
setKeysLoading(false);
|
||||
}
|
||||
}, [inputs['oauth2.enabled']]);
|
||||
|
||||
const isEnabled = inputs['oauth2.enabled'];
|
||||
|
||||
return (
|
||||
<div>
|
||||
{/* OAuth2 服务端管理 */}
|
||||
<Card
|
||||
className='!rounded-2xl shadow-sm border-0'
|
||||
style={{ marginTop: 10 }}
|
||||
title={
|
||||
<div
|
||||
className='flex flex-col sm:flex-row sm:items-center sm:justify-between w-full gap-3 sm:gap-0'
|
||||
style={{ paddingRight: '8px' }}
|
||||
>
|
||||
<div className='flex items-center'>
|
||||
<Server size={18} className='mr-2' />
|
||||
<Text strong>{t('OAuth2 服务端管理')}</Text>
|
||||
{isEnabled ? (
|
||||
serverInfo ? (
|
||||
<Badge
|
||||
count={t('运行正常')}
|
||||
type='success'
|
||||
style={{ marginLeft: 8 }}
|
||||
/>
|
||||
) : (
|
||||
<Badge
|
||||
count={t('配置中')}
|
||||
type='warning'
|
||||
style={{ marginLeft: 8 }}
|
||||
/>
|
||||
)
|
||||
) : (
|
||||
<Badge
|
||||
count={t('未启用')}
|
||||
type='tertiary'
|
||||
style={{ marginLeft: 8 }}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<div className='flex items-center gap-2 sm:flex-shrink-0'>
|
||||
{isEnabled && (
|
||||
<Button
|
||||
type='secondary'
|
||||
onClick={() => setJwksVisible(true)}
|
||||
size='small'
|
||||
>
|
||||
{t('密钥管理')}
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
type='primary'
|
||||
onClick={onSubmit}
|
||||
loading={loading}
|
||||
size='small'
|
||||
>
|
||||
{t('保存配置')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<Form
|
||||
initValues={inputs}
|
||||
getFormApi={(formAPI) => (refForm.current = formAPI)}
|
||||
>
|
||||
{!keysReady && isEnabled && (
|
||||
<Banner
|
||||
type='warning'
|
||||
className='!rounded-lg'
|
||||
closeIcon={null}
|
||||
description={t(
|
||||
'尚未准备签名密钥,建议立即初始化或轮换以发布 JWKS。签名密钥用于 JWT 令牌的安全签发。',
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Row gutter={[16, 24]}>
|
||||
<Col xs={24} lg={12}>
|
||||
<Form.Switch
|
||||
field='oauth2.enabled'
|
||||
label={t('启用 OAuth2 & SSO')}
|
||||
value={inputs['oauth2.enabled']}
|
||||
onChange={handleFieldChange('oauth2.enabled')}
|
||||
extraText={t('开启后将允许以 OAuth2/OIDC 标准进行授权与登录')}
|
||||
/>
|
||||
</Col>
|
||||
<Col xs={24} lg={12}>
|
||||
<Form.Input
|
||||
field='oauth2.issuer'
|
||||
label={t('发行人 (Issuer)')}
|
||||
placeholder={window.location.origin}
|
||||
value={inputs['oauth2.issuer']}
|
||||
onChange={handleFieldChange('oauth2.issuer')}
|
||||
extraText={t('为空则按请求自动推断(含 X-Forwarded-Proto)')}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
{/* 令牌配置 */}
|
||||
<Divider margin='24px'>{t('令牌配置')}</Divider>
|
||||
|
||||
<Row gutter={[16, 24]}>
|
||||
<Col xs={24} sm={12} lg={8}>
|
||||
<Form.InputNumber
|
||||
field='oauth2.access_token_ttl'
|
||||
label={t('访问令牌有效期')}
|
||||
suffix={t('分钟')}
|
||||
min={1}
|
||||
max={1440}
|
||||
value={inputs['oauth2.access_token_ttl']}
|
||||
onChange={handleFieldChange('oauth2.access_token_ttl')}
|
||||
extraText={t('访问令牌的有效时间,建议较短(10-60分钟)')}
|
||||
style={{
|
||||
width: '100%',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
disabled={!isEnabled}
|
||||
/>
|
||||
</Col>
|
||||
<Col xs={24} sm={12} lg={8}>
|
||||
<Form.InputNumber
|
||||
field='oauth2.refresh_token_ttl'
|
||||
label={t('刷新令牌有效期')}
|
||||
suffix={t('小时')}
|
||||
min={1}
|
||||
max={8760}
|
||||
value={inputs['oauth2.refresh_token_ttl']}
|
||||
onChange={handleFieldChange('oauth2.refresh_token_ttl')}
|
||||
extraText={t('刷新令牌的有效时间,建议较长(12-720小时)')}
|
||||
style={{
|
||||
width: '100%',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
disabled={!isEnabled}
|
||||
/>
|
||||
</Col>
|
||||
<Col xs={24} sm={12} lg={8}>
|
||||
<Form.InputNumber
|
||||
field='oauth2.max_jwks_keys'
|
||||
label={t('JWKS历史保留上限')}
|
||||
min={1}
|
||||
max={10}
|
||||
value={inputs['oauth2.max_jwks_keys']}
|
||||
onChange={handleFieldChange('oauth2.max_jwks_keys')}
|
||||
extraText={t('轮换后最多保留的历史签名密钥数量')}
|
||||
style={{
|
||||
width: '100%',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
disabled={!isEnabled}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row gutter={[16, 24]} style={{ marginTop: 16 }}>
|
||||
<Col xs={24} lg={12}>
|
||||
<Form.Select
|
||||
field='oauth2.jwt_signing_algorithm'
|
||||
label={t('JWT签名算法')}
|
||||
value={inputs['oauth2.jwt_signing_algorithm']}
|
||||
onChange={handleFieldChange('oauth2.jwt_signing_algorithm')}
|
||||
extraText={t('JWT令牌的签名算法,推荐使用RS256')}
|
||||
style={{
|
||||
width: '100%',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
disabled={!isEnabled}
|
||||
>
|
||||
<Form.Select.Option value='RS256'>
|
||||
RS256 (RSA with SHA-256)
|
||||
</Form.Select.Option>
|
||||
<Form.Select.Option value='HS256'>
|
||||
HS256 (HMAC with SHA-256)
|
||||
</Form.Select.Option>
|
||||
</Form.Select>
|
||||
</Col>
|
||||
<Col xs={24} lg={12}>
|
||||
<Form.Input
|
||||
field='oauth2.jwt_key_id'
|
||||
label={t('JWT密钥ID')}
|
||||
placeholder='oauth2-key-1'
|
||||
value={inputs['oauth2.jwt_key_id']}
|
||||
onChange={handleFieldChange('oauth2.jwt_key_id')}
|
||||
extraText={t('用于标识JWT签名密钥,支持密钥轮换')}
|
||||
style={{
|
||||
width: '100%',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
disabled={!isEnabled}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
{/* 授权配置 */}
|
||||
<Divider margin='24px'>{t('授权配置')}</Divider>
|
||||
|
||||
<Row gutter={[16, 24]}>
|
||||
<Col xs={24} lg={12}>
|
||||
<Form.Select
|
||||
field='oauth2.allowed_grant_types'
|
||||
label={t('允许的授权类型')}
|
||||
multiple
|
||||
value={inputs['oauth2.allowed_grant_types']}
|
||||
onChange={handleFieldChange('oauth2.allowed_grant_types')}
|
||||
extraText={t('选择允许的OAuth2授权流程')}
|
||||
style={{
|
||||
width: '100%',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
disabled={!isEnabled}
|
||||
>
|
||||
<Form.Select.Option value='client_credentials'>
|
||||
{t('Client Credentials(客户端凭证)')}
|
||||
</Form.Select.Option>
|
||||
<Form.Select.Option value='authorization_code'>
|
||||
{t('Authorization Code(授权码)')}
|
||||
</Form.Select.Option>
|
||||
<Form.Select.Option value='refresh_token'>
|
||||
{t('Refresh Token(刷新令牌)')}
|
||||
</Form.Select.Option>
|
||||
</Form.Select>
|
||||
</Col>
|
||||
<Col xs={24} lg={12}>
|
||||
<Form.Switch
|
||||
field='oauth2.require_pkce'
|
||||
label={t('强制PKCE验证')}
|
||||
value={inputs['oauth2.require_pkce']}
|
||||
onChange={handleFieldChange('oauth2.require_pkce')}
|
||||
extraText={t('为授权码流程强制启用PKCE,提高安全性')}
|
||||
disabled={!isEnabled}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<div style={{ marginTop: 16 }}>
|
||||
<Text type='tertiary' size='small'>
|
||||
<div className='space-y-1'>
|
||||
<div>• {t('OAuth2 服务器提供标准的 API 认证与授权')}</div>
|
||||
<div>
|
||||
•{' '}
|
||||
{t(
|
||||
'支持 Client Credentials、Authorization Code + PKCE 等标准流程',
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
•{' '}
|
||||
{t(
|
||||
'配置保存后多数项即时生效;签名密钥轮换与 JWKS 发布为即时操作',
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
• {t('生产环境务必启用 HTTPS,并妥善管理 JWT 签名密钥')}
|
||||
</div>
|
||||
</div>
|
||||
</Text>
|
||||
</div>
|
||||
</Form>
|
||||
</Card>
|
||||
|
||||
{/* 模态框 */}
|
||||
<JWKSManagerModal
|
||||
visible={jwksVisible}
|
||||
onClose={() => setJwksVisible(false)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { Modal, Banner, Typography } from '@douyinfe/semi-ui';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
const ClientInfoModal = ({ visible, onClose, clientId, clientSecret }) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={t('客户端创建成功')}
|
||||
visible={visible}
|
||||
onCancel={onClose}
|
||||
onOk={onClose}
|
||||
cancelText=''
|
||||
okText={t('我已复制保存')}
|
||||
width={650}
|
||||
bodyStyle={{ padding: '20px 24px' }}
|
||||
>
|
||||
<Banner
|
||||
type='success'
|
||||
closeIcon={null}
|
||||
description={t(
|
||||
'客户端信息如下,请立即复制保存。关闭此窗口后将无法再次查看密钥。',
|
||||
)}
|
||||
className='mb-5 !rounded-lg'
|
||||
/>
|
||||
|
||||
<div className='space-y-4'>
|
||||
<div className='flex justify-center items-center'>
|
||||
<div className='text-center'>
|
||||
<Text strong className='block mb-2'>
|
||||
{t('客户端ID')}
|
||||
</Text>
|
||||
<Text code copyable>
|
||||
{clientId}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{clientSecret && (
|
||||
<div className='flex justify-center items-center'>
|
||||
<div className='text-center'>
|
||||
<Text strong className='block mb-2'>
|
||||
{t('客户端密钥(仅此一次显示)')}
|
||||
</Text>
|
||||
<Text code copyable>
|
||||
{clientSecret}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
export default ClientInfoModal;
|
||||
70
web/src/components/settings/oauth2/modals/JWKSInfoModal.jsx
Normal file
70
web/src/components/settings/oauth2/modals/JWKSInfoModal.jsx
Normal file
@@ -0,0 +1,70 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Modal } from '@douyinfe/semi-ui';
|
||||
import { API, showError } from '../../../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import CodeViewer from '../../../common/ui/CodeViewer';
|
||||
|
||||
const JWKSInfoModal = ({ visible, onClose }) => {
|
||||
const { t } = useTranslation();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [jwksInfo, setJwksInfo] = useState(null);
|
||||
|
||||
const loadJWKSInfo = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.get('/api/oauth/jwks');
|
||||
setJwksInfo(res.data);
|
||||
} catch (error) {
|
||||
showError(t('获取JWKS失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
loadJWKSInfo();
|
||||
}
|
||||
}, [visible]);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={t('JWKS 信息')}
|
||||
visible={visible}
|
||||
onCancel={onClose}
|
||||
onOk={onClose}
|
||||
cancelText=''
|
||||
okText={t('关闭')}
|
||||
width={650}
|
||||
bodyStyle={{ padding: '20px 24px' }}
|
||||
confirmLoading={loading}
|
||||
>
|
||||
<CodeViewer
|
||||
content={jwksInfo ? JSON.stringify(jwksInfo, null, 2) : t('加载中...')}
|
||||
title={t('JWKS 密钥集')}
|
||||
language='json'
|
||||
/>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
export default JWKSInfoModal;
|
||||
399
web/src/components/settings/oauth2/modals/JWKSManagerModal.jsx
Normal file
399
web/src/components/settings/oauth2/modals/JWKSManagerModal.jsx
Normal file
@@ -0,0 +1,399 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import {
|
||||
Table,
|
||||
Button,
|
||||
Space,
|
||||
Tag,
|
||||
Typography,
|
||||
Popconfirm,
|
||||
Toast,
|
||||
Form,
|
||||
Card,
|
||||
Tabs,
|
||||
TabPane,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { API, showError, showSuccess } from '../../../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import ResponsiveModal from '../../../common/ui/ResponsiveModal';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
// 操作模式枚举
|
||||
const OPERATION_MODES = {
|
||||
VIEW: 'view',
|
||||
IMPORT: 'import',
|
||||
GENERATE: 'generate',
|
||||
};
|
||||
|
||||
export default function JWKSManagerModal({ visible, onClose }) {
|
||||
const { t } = useTranslation();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [keys, setKeys] = useState([]);
|
||||
const [activeTab, setActiveTab] = useState(OPERATION_MODES.VIEW);
|
||||
|
||||
const load = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.get('/api/oauth/keys');
|
||||
if (res?.data?.success) setKeys(res.data.data || []);
|
||||
else showError(res?.data?.message || t('获取密钥列表失败'));
|
||||
} catch {
|
||||
showError(t('获取密钥列表失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const rotate = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.post('/api/oauth/keys/rotate', {});
|
||||
if (res?.data?.success) {
|
||||
showSuccess(t('签名密钥已轮换:{{kid}}', { kid: res.data.kid }));
|
||||
await load();
|
||||
} else showError(res?.data?.message || t('密钥轮换失败'));
|
||||
} catch {
|
||||
showError(t('密钥轮换失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const del = async (kid) => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.delete(`/api/oauth/keys/${kid}`);
|
||||
if (res?.data?.success) {
|
||||
Toast.success(t('已删除:{{kid}}', { kid }));
|
||||
await load();
|
||||
} else showError(res?.data?.message || t('删除失败'));
|
||||
} catch {
|
||||
showError(t('删除失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Import PEM state
|
||||
const [pem, setPem] = useState('');
|
||||
const [customKid, setCustomKid] = useState('');
|
||||
|
||||
// Generate PEM file state
|
||||
const [genPath, setGenPath] = useState('/etc/new-api/oauth2-private.pem');
|
||||
const [genKid, setGenKid] = useState('');
|
||||
|
||||
// 重置表单数据
|
||||
const resetForms = () => {
|
||||
setPem('');
|
||||
setCustomKid('');
|
||||
setGenKid('');
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
load();
|
||||
// 重置到主视图
|
||||
setActiveTab(OPERATION_MODES.VIEW);
|
||||
} else {
|
||||
// 模态框关闭时重置表单数据
|
||||
resetForms();
|
||||
}
|
||||
}, [visible]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!visible) return;
|
||||
(async () => {
|
||||
try {
|
||||
const res = await API.get('/api/oauth/server-info');
|
||||
const p = res?.data?.default_private_key_path;
|
||||
if (p) setGenPath(p);
|
||||
} catch {}
|
||||
})();
|
||||
}, [visible]);
|
||||
|
||||
// 导入 PEM 私钥
|
||||
const importPem = async () => {
|
||||
if (!pem.trim()) return Toast.warning(t('请粘贴 PEM 私钥'));
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.post('/api/oauth/keys/import_pem', {
|
||||
pem,
|
||||
kid: customKid.trim(),
|
||||
});
|
||||
if (res?.data?.success) {
|
||||
Toast.success(
|
||||
t('已导入私钥并切换到 kid={{kid}}', { kid: res.data.kid }),
|
||||
);
|
||||
resetForms();
|
||||
setActiveTab(OPERATION_MODES.VIEW);
|
||||
await load();
|
||||
} else {
|
||||
Toast.error(res?.data?.message || t('导入失败'));
|
||||
}
|
||||
} catch {
|
||||
Toast.error(t('导入失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 生成 PEM 文件
|
||||
const generatePemFile = async () => {
|
||||
if (!genPath.trim()) return Toast.warning(t('请填写保存路径'));
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.post('/api/oauth/keys/generate_file', {
|
||||
path: genPath.trim(),
|
||||
kid: genKid.trim(),
|
||||
});
|
||||
if (res?.data?.success) {
|
||||
Toast.success(t('已生成并生效:{{path}}', { path: res.data.path }));
|
||||
resetForms();
|
||||
setActiveTab(OPERATION_MODES.VIEW);
|
||||
await load();
|
||||
} else {
|
||||
Toast.error(res?.data?.message || t('生成失败'));
|
||||
}
|
||||
} catch {
|
||||
Toast.error(t('生成失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const columns = [
|
||||
{
|
||||
title: 'KID',
|
||||
dataIndex: 'kid',
|
||||
render: (kid) => (
|
||||
<Text code copyable>
|
||||
{kid}
|
||||
</Text>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: t('创建时间'),
|
||||
dataIndex: 'created_at',
|
||||
render: (ts) => (ts ? new Date(ts * 1000).toLocaleString() : '-'),
|
||||
},
|
||||
{
|
||||
title: t('状态'),
|
||||
dataIndex: 'current',
|
||||
render: (cur) =>
|
||||
cur ? (
|
||||
<Tag color='green' shape='circle'>
|
||||
{t('当前')}
|
||||
</Tag>
|
||||
) : (
|
||||
<Tag shape='circle'>{t('历史')}</Tag>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: t('操作'),
|
||||
render: (_, r) => (
|
||||
<Space>
|
||||
{!r.current && (
|
||||
<Popconfirm
|
||||
title={t('确定删除密钥 {{kid}} ?', { kid: r.kid })}
|
||||
content={t(
|
||||
'删除后使用该 kid 签发的旧令牌仍可被验证(外部 JWKS 缓存可能仍保留)',
|
||||
)}
|
||||
okText={t('删除')}
|
||||
onConfirm={() => del(r.kid)}
|
||||
>
|
||||
<Button size='small' type='danger'>
|
||||
{t('删除')}
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
)}
|
||||
</Space>
|
||||
),
|
||||
},
|
||||
];
|
||||
|
||||
// 头部操作按钮 - 根据当前标签页动态生成
|
||||
const getHeaderActions = () => {
|
||||
if (activeTab === OPERATION_MODES.VIEW) {
|
||||
const hasKeys = Array.isArray(keys) && keys.length > 0;
|
||||
return [
|
||||
<Button key='refresh' onClick={load} loading={loading} size='small'>
|
||||
{t('刷新')}
|
||||
</Button>,
|
||||
<Button
|
||||
key='rotate'
|
||||
type='primary'
|
||||
onClick={rotate}
|
||||
loading={loading}
|
||||
size='small'
|
||||
>
|
||||
{hasKeys ? t('轮换密钥') : t('初始化密钥')}
|
||||
</Button>,
|
||||
];
|
||||
}
|
||||
|
||||
if (activeTab === OPERATION_MODES.IMPORT) {
|
||||
return [
|
||||
<Button
|
||||
key='import'
|
||||
type='primary'
|
||||
onClick={importPem}
|
||||
loading={loading}
|
||||
size='small'
|
||||
>
|
||||
{t('导入并生效')}
|
||||
</Button>,
|
||||
];
|
||||
}
|
||||
|
||||
if (activeTab === OPERATION_MODES.GENERATE) {
|
||||
return [
|
||||
<Button
|
||||
key='generate'
|
||||
type='primary'
|
||||
onClick={generatePemFile}
|
||||
loading={loading}
|
||||
size='small'
|
||||
>
|
||||
{t('生成并生效')}
|
||||
</Button>,
|
||||
];
|
||||
}
|
||||
|
||||
return [];
|
||||
};
|
||||
|
||||
// 渲染密钥列表视图
|
||||
const renderKeysView = () => (
|
||||
<Card
|
||||
className='!rounded-lg'
|
||||
title={
|
||||
<Text className='text-blue-700 dark:text-blue-300'>
|
||||
{t(
|
||||
'提示:当前密钥用于签发 JWT 令牌。建议定期轮换密钥以提升安全性。只有历史密钥可以删除。',
|
||||
)}
|
||||
</Text>
|
||||
}
|
||||
>
|
||||
<Table
|
||||
dataSource={keys}
|
||||
columns={columns}
|
||||
rowKey='kid'
|
||||
loading={loading}
|
||||
pagination={false}
|
||||
empty={<Text type='tertiary'>{t('暂无密钥')}</Text>}
|
||||
/>
|
||||
</Card>
|
||||
);
|
||||
|
||||
// 渲染导入 PEM 私钥视图
|
||||
const renderImportView = () => (
|
||||
<Card
|
||||
className='!rounded-lg'
|
||||
title={
|
||||
<Text className='text-yellow-700 dark:text-yellow-300'>
|
||||
{t(
|
||||
'建议:优先使用内存签名密钥与 JWKS 轮换;仅在有合规要求时导入外部私钥。请确保私钥来源可信。',
|
||||
)}
|
||||
</Text>
|
||||
}
|
||||
>
|
||||
<Form labelPosition='left' labelWidth={120}>
|
||||
<Form.Input
|
||||
field='kid'
|
||||
label={t('自定义 KID')}
|
||||
placeholder={t('可留空自动生成')}
|
||||
value={customKid}
|
||||
onChange={setCustomKid}
|
||||
/>
|
||||
<Form.TextArea
|
||||
field='pem'
|
||||
label={t('PEM 私钥')}
|
||||
value={pem}
|
||||
onChange={setPem}
|
||||
rows={8}
|
||||
placeholder={
|
||||
'-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----'
|
||||
}
|
||||
/>
|
||||
</Form>
|
||||
</Card>
|
||||
);
|
||||
|
||||
// 渲染生成 PEM 文件视图
|
||||
const renderGenerateView = () => (
|
||||
<Card
|
||||
className='!rounded-lg'
|
||||
title={
|
||||
<Text className='text-orange-700 dark:text-orange-300'>
|
||||
{t(
|
||||
'建议:仅在合规要求下使用文件私钥。请确保目录权限安全(建议 0600),并妥善备份。',
|
||||
)}
|
||||
</Text>
|
||||
}
|
||||
>
|
||||
<Form labelPosition='left' labelWidth={120}>
|
||||
<Form.Input
|
||||
field='path'
|
||||
label={t('保存路径')}
|
||||
value={genPath}
|
||||
onChange={setGenPath}
|
||||
placeholder='/secure/path/oauth2-private.pem'
|
||||
/>
|
||||
<Form.Input
|
||||
field='genKid'
|
||||
label={t('自定义 KID')}
|
||||
value={genKid}
|
||||
onChange={setGenKid}
|
||||
placeholder={t('可留空自动生成')}
|
||||
/>
|
||||
</Form>
|
||||
</Card>
|
||||
);
|
||||
|
||||
return (
|
||||
<ResponsiveModal
|
||||
visible={visible}
|
||||
title={t('JWKS 管理')}
|
||||
headerActions={getHeaderActions()}
|
||||
onCancel={onClose}
|
||||
footer={null}
|
||||
width={{ mobile: '95%', desktop: 800 }}
|
||||
>
|
||||
<Tabs
|
||||
activeKey={activeTab}
|
||||
onChange={setActiveTab}
|
||||
type='card'
|
||||
size='medium'
|
||||
className='!-mt-2'
|
||||
>
|
||||
<TabPane tab={t('密钥列表')} itemKey={OPERATION_MODES.VIEW}>
|
||||
{renderKeysView()}
|
||||
</TabPane>
|
||||
<TabPane tab={t('导入 PEM 私钥')} itemKey={OPERATION_MODES.IMPORT}>
|
||||
{renderImportView()}
|
||||
</TabPane>
|
||||
<TabPane tab={t('生成 PEM 文件')} itemKey={OPERATION_MODES.GENERATE}>
|
||||
{renderGenerateView()}
|
||||
</TabPane>
|
||||
</Tabs>
|
||||
</ResponsiveModal>
|
||||
);
|
||||
}
|
||||
730
web/src/components/settings/oauth2/modals/OAuth2ClientModal.jsx
Normal file
730
web/src/components/settings/oauth2/modals/OAuth2ClientModal.jsx
Normal file
@@ -0,0 +1,730 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState, useRef } from 'react';
|
||||
import {
|
||||
SideSheet,
|
||||
Form,
|
||||
Input,
|
||||
Select,
|
||||
Space,
|
||||
Typography,
|
||||
Button,
|
||||
Card,
|
||||
Avatar,
|
||||
Tag,
|
||||
Spin,
|
||||
Radio,
|
||||
Divider,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IconKey,
|
||||
IconLink,
|
||||
IconSave,
|
||||
IconClose,
|
||||
IconPlus,
|
||||
IconDelete,
|
||||
} from '@douyinfe/semi-icons';
|
||||
import { API, showError, showSuccess } from '../../../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIsMobile } from '../../../../hooks/common/useIsMobile';
|
||||
import ClientInfoModal from './ClientInfoModal';
|
||||
|
||||
const { Text, Title } = Typography;
|
||||
const { Option } = Select;
|
||||
|
||||
const AUTH_CODE = 'authorization_code';
|
||||
const CLIENT_CREDENTIALS = 'client_credentials';
|
||||
|
||||
// 子组件:重定向URI编辑卡片
|
||||
function RedirectUriCard({
|
||||
t,
|
||||
isAuthCodeSelected,
|
||||
redirectUris,
|
||||
onAdd,
|
||||
onUpdate,
|
||||
onRemove,
|
||||
onFillTemplate,
|
||||
}) {
|
||||
return (
|
||||
<Card
|
||||
header={
|
||||
<div className='flex justify-between items-center'>
|
||||
<div className='flex items-center'>
|
||||
<Avatar size='small' color='purple' className='mr-2 shadow-md'>
|
||||
<IconLink size={16} />
|
||||
</Avatar>
|
||||
<div>
|
||||
<Text className='text-lg font-medium'>{t('重定向URI配置')}</Text>
|
||||
<div className='text-xs text-gray-600'>
|
||||
{t('用于授权码流程的重定向地址')}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
type='tertiary'
|
||||
onClick={onFillTemplate}
|
||||
size='small'
|
||||
disabled={!isAuthCodeSelected}
|
||||
>
|
||||
{t('填入示例模板')}
|
||||
</Button>
|
||||
</div>
|
||||
}
|
||||
headerStyle={{ padding: '12px 16px' }}
|
||||
bodyStyle={{ padding: '16px' }}
|
||||
className='!rounded-2xl shadow-sm border-0'
|
||||
>
|
||||
<div className='space-y-1'>
|
||||
{redirectUris.length === 0 && (
|
||||
<div className='text-center py-4 px-4'>
|
||||
<Text type='tertiary' className='text-gray-500 text-sm'>
|
||||
{t('暂无重定向URI,点击下方按钮添加')}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{redirectUris.map((uri, index) => (
|
||||
<div
|
||||
key={index}
|
||||
style={{
|
||||
marginBottom: 8,
|
||||
display: 'flex',
|
||||
gap: 8,
|
||||
alignItems: 'center',
|
||||
}}
|
||||
>
|
||||
<Input
|
||||
placeholder={t('例如:https://your-app.com/callback')}
|
||||
value={uri}
|
||||
onChange={(value) => onUpdate(index, value)}
|
||||
style={{ flex: 1 }}
|
||||
disabled={!isAuthCodeSelected}
|
||||
/>
|
||||
<Button
|
||||
icon={<IconDelete />}
|
||||
type='danger'
|
||||
theme='borderless'
|
||||
onClick={() => onRemove(index)}
|
||||
disabled={!isAuthCodeSelected}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
|
||||
<div className='py-2 flex justify-center gap-2'>
|
||||
<Button
|
||||
icon={<IconPlus />}
|
||||
type='primary'
|
||||
theme='outline'
|
||||
onClick={onAdd}
|
||||
disabled={!isAuthCodeSelected}
|
||||
>
|
||||
{t('添加重定向URI')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Divider margin='12px' align='center'>
|
||||
<Text type='tertiary' size='small'>
|
||||
{isAuthCodeSelected
|
||||
? t(
|
||||
'用户授权后将重定向到这些URI。必须使用HTTPS(本地开发可使用HTTP,仅限localhost/127.0.0.1)',
|
||||
)
|
||||
: t('仅在选择“授权码”授权类型时需要配置重定向URI')}
|
||||
</Text>
|
||||
</Divider>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
const OAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => {
|
||||
const { t } = useTranslation();
|
||||
const isMobile = useIsMobile();
|
||||
const formApiRef = useRef(null);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [redirectUris, setRedirectUris] = useState([]);
|
||||
const [clientType, setClientType] = useState('confidential');
|
||||
const [grantTypes, setGrantTypes] = useState([]);
|
||||
const [allowedGrantTypes, setAllowedGrantTypes] = useState([
|
||||
CLIENT_CREDENTIALS,
|
||||
AUTH_CODE,
|
||||
'refresh_token',
|
||||
]);
|
||||
|
||||
// ClientInfoModal 状态
|
||||
const [showClientInfo, setShowClientInfo] = useState(false);
|
||||
const [clientInfo, setClientInfo] = useState({
|
||||
clientId: '',
|
||||
clientSecret: '',
|
||||
});
|
||||
|
||||
const isEdit = client?.id !== undefined;
|
||||
const [mode, setMode] = useState('create'); // 'create' | 'edit'
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
setMode(isEdit ? 'edit' : 'create');
|
||||
}
|
||||
}, [visible, isEdit]);
|
||||
|
||||
const getInitValues = () => ({
|
||||
name: '',
|
||||
description: '',
|
||||
client_type: 'confidential',
|
||||
grant_types: [],
|
||||
scopes: [],
|
||||
require_pkce: true,
|
||||
status: 1,
|
||||
});
|
||||
|
||||
// 加载后端允许的授权类型
|
||||
useEffect(() => {
|
||||
let mounted = true;
|
||||
(async () => {
|
||||
try {
|
||||
const res = await API.get('/api/option/');
|
||||
const { success, data } = res.data || {};
|
||||
if (!success || !Array.isArray(data)) return;
|
||||
const found = data.find((i) => i.key === 'oauth2.allowed_grant_types');
|
||||
if (!found) return;
|
||||
let parsed = [];
|
||||
try {
|
||||
parsed = JSON.parse(found.value || '[]');
|
||||
} catch (_) {}
|
||||
if (mounted && Array.isArray(parsed) && parsed.length) {
|
||||
setAllowedGrantTypes(parsed);
|
||||
}
|
||||
} catch (_) {
|
||||
// 忽略错误,使用默认allowedGrantTypes
|
||||
}
|
||||
})();
|
||||
return () => {
|
||||
mounted = false;
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
setGrantTypes((prev) => {
|
||||
const normalizedPrev = Array.isArray(prev) ? prev : [];
|
||||
// 移除不被允许或与客户端类型冲突的类型
|
||||
let next = normalizedPrev.filter((g) => allowedGrantTypes.includes(g));
|
||||
if (clientType === 'public') {
|
||||
next = next.filter((g) => g !== CLIENT_CREDENTIALS);
|
||||
}
|
||||
return next.length ? next : [];
|
||||
});
|
||||
}, [clientType, allowedGrantTypes]);
|
||||
|
||||
// 初始化表单数据(编辑模式)
|
||||
useEffect(() => {
|
||||
if (client && visible && isEdit) {
|
||||
setLoading(true);
|
||||
// 解析授权类型
|
||||
let parsedGrantTypes = [];
|
||||
if (typeof client.grant_types === 'string') {
|
||||
parsedGrantTypes = client.grant_types.split(',');
|
||||
} else if (Array.isArray(client.grant_types)) {
|
||||
parsedGrantTypes = client.grant_types;
|
||||
}
|
||||
|
||||
// 解析Scope
|
||||
let parsedScopes = [];
|
||||
if (typeof client.scopes === 'string') {
|
||||
parsedScopes = client.scopes.split(',');
|
||||
} else if (Array.isArray(client.scopes)) {
|
||||
parsedScopes = client.scopes;
|
||||
}
|
||||
if (!parsedScopes || parsedScopes.length === 0) {
|
||||
parsedScopes = ['openid', 'profile', 'email', 'api:read'];
|
||||
}
|
||||
|
||||
// 解析重定向URI
|
||||
let parsedRedirectUris = [];
|
||||
if (client.redirect_uris) {
|
||||
try {
|
||||
const parsed =
|
||||
typeof client.redirect_uris === 'string'
|
||||
? JSON.parse(client.redirect_uris)
|
||||
: client.redirect_uris;
|
||||
if (Array.isArray(parsed) && parsed.length > 0) {
|
||||
parsedRedirectUris = parsed;
|
||||
}
|
||||
} catch (e) {}
|
||||
}
|
||||
|
||||
// 过滤不被允许或不兼容的授权类型
|
||||
const filteredGrantTypes = (parsedGrantTypes || []).filter((g) =>
|
||||
allowedGrantTypes.includes(g),
|
||||
);
|
||||
const finalGrantTypes =
|
||||
client.client_type === 'public'
|
||||
? filteredGrantTypes.filter((g) => g !== CLIENT_CREDENTIALS)
|
||||
: filteredGrantTypes;
|
||||
|
||||
setClientType(client.client_type);
|
||||
setGrantTypes(finalGrantTypes);
|
||||
// 不自动新增空白URI,保持与创建模式一致的手动添加体验
|
||||
setRedirectUris(parsedRedirectUris);
|
||||
|
||||
// 设置表单值
|
||||
const formValues = {
|
||||
id: client.id,
|
||||
name: client.name,
|
||||
description: client.description,
|
||||
client_type: client.client_type,
|
||||
grant_types: finalGrantTypes,
|
||||
scopes: parsedScopes,
|
||||
require_pkce: !!client.require_pkce,
|
||||
status: client.status,
|
||||
};
|
||||
|
||||
setTimeout(() => {
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValues(formValues);
|
||||
}
|
||||
setLoading(false);
|
||||
}, 100);
|
||||
} else if (visible && !isEdit) {
|
||||
// 创建模式,重置状态
|
||||
setClientType('confidential');
|
||||
setGrantTypes([]);
|
||||
setRedirectUris([]);
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValues(getInitValues());
|
||||
}
|
||||
}
|
||||
}, [client, visible, isEdit, allowedGrantTypes]);
|
||||
|
||||
const isAuthCodeSelected = grantTypes.includes(AUTH_CODE);
|
||||
const isGrantTypeDisabled = (value) => {
|
||||
if (!allowedGrantTypes.includes(value)) return true;
|
||||
if (clientType === 'public' && value === CLIENT_CREDENTIALS) return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
// URL校验:允许 https;http 仅限本地开发域名
|
||||
const isValidRedirectUri = (uri) => {
|
||||
if (!uri || !uri.trim()) return false;
|
||||
try {
|
||||
const u = new URL(uri.trim());
|
||||
if (u.protocol === 'https:') return true;
|
||||
if (u.protocol === 'http:') {
|
||||
const host = u.hostname;
|
||||
return (
|
||||
host === 'localhost' ||
|
||||
host === '127.0.0.1' ||
|
||||
host.endsWith('.local')
|
||||
);
|
||||
}
|
||||
return false;
|
||||
} catch (_) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
// 处理提交
|
||||
const handleSubmit = async (values) => {
|
||||
setLoading(true);
|
||||
try {
|
||||
// 过滤空的重定向URI
|
||||
const validRedirectUris = redirectUris
|
||||
.map((u) => (u || '').trim())
|
||||
.filter((u) => u.length > 0);
|
||||
|
||||
// 业务校验
|
||||
if (!grantTypes.length) {
|
||||
showError(t('请至少选择一种授权类型'));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// 校验是否包含不被允许的授权类型
|
||||
const invalids = grantTypes.filter((g) => !allowedGrantTypes.includes(g));
|
||||
if (invalids.length) {
|
||||
showError(
|
||||
t('不被允许的授权类型: {{types}}', { types: invalids.join(', ') }),
|
||||
);
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (clientType === 'public' && grantTypes.includes(CLIENT_CREDENTIALS)) {
|
||||
showError(t('公开客户端不允许使用client_credentials授权类型'));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (grantTypes.includes(AUTH_CODE)) {
|
||||
if (!validRedirectUris.length) {
|
||||
showError(t('选择授权码授权类型时,必须填写至少一个重定向URI'));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
const allValid = validRedirectUris.every(isValidRedirectUri);
|
||||
if (!allValid) {
|
||||
showError(t('重定向URI格式不合法:仅支持https,或本地开发使用http'));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// 避免把 Radio 组件对象形式的 client_type 直接传给后端
|
||||
const { client_type: _formClientType, ...restValues } = values || {};
|
||||
const payload = {
|
||||
...restValues,
|
||||
client_type: clientType,
|
||||
grant_types: grantTypes,
|
||||
redirect_uris: validRedirectUris,
|
||||
};
|
||||
|
||||
let res;
|
||||
if (isEdit) {
|
||||
res = await API.put('/api/oauth_clients/', payload);
|
||||
} else {
|
||||
res = await API.post('/api/oauth_clients/', payload);
|
||||
}
|
||||
|
||||
const { success, message, client_id, client_secret } = res.data;
|
||||
|
||||
if (success) {
|
||||
if (isEdit) {
|
||||
showSuccess(t('OAuth2客户端更新成功'));
|
||||
resetForm();
|
||||
onSuccess();
|
||||
} else {
|
||||
showSuccess(t('OAuth2客户端创建成功'));
|
||||
// 显示客户端信息
|
||||
setClientInfo({
|
||||
clientId: client_id,
|
||||
clientSecret: client_secret,
|
||||
});
|
||||
setShowClientInfo(true);
|
||||
}
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(isEdit ? t('更新OAuth2客户端失败') : t('创建OAuth2客户端失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 重置表单
|
||||
const resetForm = () => {
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.reset();
|
||||
}
|
||||
setClientType('confidential');
|
||||
setGrantTypes([]);
|
||||
setRedirectUris([]);
|
||||
};
|
||||
|
||||
// 处理ClientInfoModal关闭
|
||||
const handleClientInfoClose = () => {
|
||||
setShowClientInfo(false);
|
||||
setClientInfo({ clientId: '', clientSecret: '' });
|
||||
resetForm();
|
||||
onSuccess();
|
||||
};
|
||||
|
||||
// 处理取消
|
||||
const handleCancel = () => {
|
||||
resetForm();
|
||||
onCancel();
|
||||
};
|
||||
|
||||
// 添加重定向URI
|
||||
const addRedirectUri = () => {
|
||||
setRedirectUris([...redirectUris, '']);
|
||||
};
|
||||
|
||||
// 删除重定向URI
|
||||
const removeRedirectUri = (index) => {
|
||||
setRedirectUris(redirectUris.filter((_, i) => i !== index));
|
||||
};
|
||||
|
||||
// 更新重定向URI
|
||||
const updateRedirectUri = (index, value) => {
|
||||
const newUris = [...redirectUris];
|
||||
newUris[index] = value;
|
||||
setRedirectUris(newUris);
|
||||
};
|
||||
|
||||
// 填入示例重定向URI模板
|
||||
const fillRedirectUriTemplate = () => {
|
||||
const template = [
|
||||
'https://your-app.com/auth/callback',
|
||||
'https://localhost:3000/callback',
|
||||
];
|
||||
setRedirectUris(template);
|
||||
};
|
||||
|
||||
// 授权类型变化处理(清理非法项,只设置一次)
|
||||
const handleGrantTypesChange = (values) => {
|
||||
const allowed = Array.isArray(values)
|
||||
? values.filter((v) => allowedGrantTypes.includes(v))
|
||||
: [];
|
||||
const sanitized =
|
||||
clientType === 'public'
|
||||
? allowed.filter((v) => v !== CLIENT_CREDENTIALS)
|
||||
: allowed;
|
||||
setGrantTypes(sanitized);
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue('grant_types', sanitized);
|
||||
}
|
||||
};
|
||||
|
||||
// 客户端类型变化处理(兼容 RadioGroup 事件对象与直接值)
|
||||
const handleClientTypeChange = (next) => {
|
||||
const value = next && next.target ? next.target.value : next;
|
||||
setClientType(value);
|
||||
// 公开客户端自动移除 client_credentials,并同步表单字段
|
||||
const current = Array.isArray(grantTypes) ? grantTypes : [];
|
||||
const sanitized =
|
||||
value === 'public'
|
||||
? current.filter((g) => g !== CLIENT_CREDENTIALS)
|
||||
: current;
|
||||
if (sanitized !== current) {
|
||||
setGrantTypes(sanitized);
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue('grant_types', sanitized);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<SideSheet
|
||||
placement={mode === 'edit' ? 'right' : 'left'}
|
||||
title={
|
||||
<Space>
|
||||
{mode === 'edit' ? (
|
||||
<Tag color='blue' shape='circle'>
|
||||
{t('编辑')}
|
||||
</Tag>
|
||||
) : (
|
||||
<Tag color='green' shape='circle'>
|
||||
{t('创建')}
|
||||
</Tag>
|
||||
)}
|
||||
<Title heading={4} className='m-0'>
|
||||
{mode === 'edit' ? t('编辑OAuth2客户端') : t('创建OAuth2客户端')}
|
||||
</Title>
|
||||
</Space>
|
||||
}
|
||||
bodyStyle={{ padding: '0' }}
|
||||
visible={visible}
|
||||
width={isMobile ? '100%' : 700}
|
||||
footer={
|
||||
<div className='flex justify-end bg-white'>
|
||||
<Space>
|
||||
<Button
|
||||
theme='solid'
|
||||
className='!rounded-lg'
|
||||
onClick={() => formApiRef.current?.submitForm()}
|
||||
icon={<IconSave />}
|
||||
loading={loading}
|
||||
>
|
||||
{isEdit ? t('保存') : t('创建')}
|
||||
</Button>
|
||||
<Button
|
||||
theme='light'
|
||||
className='!rounded-lg'
|
||||
type='primary'
|
||||
onClick={handleCancel}
|
||||
icon={<IconClose />}
|
||||
>
|
||||
{t('取消')}
|
||||
</Button>
|
||||
</Space>
|
||||
</div>
|
||||
}
|
||||
closeIcon={null}
|
||||
onCancel={handleCancel}
|
||||
>
|
||||
<Spin spinning={loading}>
|
||||
<Form
|
||||
key={isEdit ? `edit-${client?.id}` : 'create'}
|
||||
initValues={getInitValues()}
|
||||
getFormApi={(api) => (formApiRef.current = api)}
|
||||
onSubmit={handleSubmit}
|
||||
>
|
||||
{() => (
|
||||
<div className='p-2'>
|
||||
{/* 表单内容 */}
|
||||
{/* 基本信息 */}
|
||||
<Card className='!rounded-2xl shadow-sm border-0'>
|
||||
<div className='flex items-center mb-4'>
|
||||
<Avatar size='small' color='blue' className='mr-2 shadow-md'>
|
||||
<IconKey size={16} />
|
||||
</Avatar>
|
||||
<div>
|
||||
<Text className='text-lg font-medium'>{t('基本信息')}</Text>
|
||||
<div className='text-xs text-gray-600'>
|
||||
{t('设置客户端的基本信息')}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{isEdit && (
|
||||
<>
|
||||
<Form.Select
|
||||
field='status'
|
||||
label={t('状态')}
|
||||
rules={[{ required: true, message: t('请选择状态') }]}
|
||||
required
|
||||
>
|
||||
<Option value={1}>{t('启用')}</Option>
|
||||
<Option value={2}>{t('禁用')}</Option>
|
||||
</Form.Select>
|
||||
<Form.Input field='id' label={t('客户端ID')} disabled />
|
||||
</>
|
||||
)}
|
||||
<Form.Input
|
||||
field='name'
|
||||
label={t('客户端名称')}
|
||||
placeholder={t('输入客户端名称')}
|
||||
rules={[{ required: true, message: t('请输入客户端名称') }]}
|
||||
required
|
||||
showClear
|
||||
/>
|
||||
<Form.TextArea
|
||||
field='description'
|
||||
label={t('描述')}
|
||||
placeholder={t('输入客户端描述')}
|
||||
rows={3}
|
||||
showClear
|
||||
/>
|
||||
<Form.RadioGroup
|
||||
label={t('客户端类型')}
|
||||
field='client_type'
|
||||
value={clientType}
|
||||
onChange={handleClientTypeChange}
|
||||
type='card'
|
||||
aria-label={t('选择客户端类型')}
|
||||
disabled={isEdit}
|
||||
rules={[{ required: true, message: t('请选择客户端类型') }]}
|
||||
required
|
||||
>
|
||||
<Radio
|
||||
value='confidential'
|
||||
extra={t('服务器端应用,安全地存储客户端密钥')}
|
||||
style={{ width: isMobile ? '100%' : 'auto' }}
|
||||
>
|
||||
{t('机密客户端(Confidential)')}
|
||||
</Radio>
|
||||
<Radio
|
||||
value='public'
|
||||
extra={t('移动应用或单页应用,无法安全存储密钥')}
|
||||
style={{ width: isMobile ? '100%' : 'auto' }}
|
||||
>
|
||||
{t('公开客户端(Public)')}
|
||||
</Radio>
|
||||
</Form.RadioGroup>
|
||||
<Form.Select
|
||||
field='grant_types'
|
||||
label={t('允许的授权类型')}
|
||||
multiple
|
||||
value={grantTypes}
|
||||
onChange={handleGrantTypesChange}
|
||||
rules={[
|
||||
{ required: true, message: t('请选择至少一种授权类型') },
|
||||
]}
|
||||
required
|
||||
placeholder={t('请选择授权类型(可多选)')}
|
||||
>
|
||||
{clientType !== 'public' && (
|
||||
<Option
|
||||
value={CLIENT_CREDENTIALS}
|
||||
disabled={isGrantTypeDisabled(CLIENT_CREDENTIALS)}
|
||||
>
|
||||
{t('Client Credentials(客户端凭证)')}
|
||||
</Option>
|
||||
)}
|
||||
<Option
|
||||
value={AUTH_CODE}
|
||||
disabled={isGrantTypeDisabled(AUTH_CODE)}
|
||||
>
|
||||
{t('Authorization Code(授权码)')}
|
||||
</Option>
|
||||
<Option
|
||||
value='refresh_token'
|
||||
disabled={isGrantTypeDisabled('refresh_token')}
|
||||
>
|
||||
{t('Refresh Token(刷新令牌)')}
|
||||
</Option>
|
||||
</Form.Select>
|
||||
<Form.Select
|
||||
field='scopes'
|
||||
label={t('允许的权限范围(Scope)')}
|
||||
multiple
|
||||
rules={[
|
||||
{ required: true, message: t('请选择至少一个权限范围') },
|
||||
]}
|
||||
required
|
||||
placeholder={t('请选择权限范围(可多选)')}
|
||||
>
|
||||
<Option value='openid'>{t('openid(OIDC 基础身份)')}</Option>
|
||||
<Option value='profile'>
|
||||
{t('profile(用户名/昵称等)')}
|
||||
</Option>
|
||||
<Option value='email'>{t('email(邮箱信息)')}</Option>
|
||||
<Option value='api:read'>
|
||||
{`api:read (${t('读取API')})`}
|
||||
</Option>
|
||||
<Option value='api:write'>
|
||||
{`api:write (${t('写入API')})`}
|
||||
</Option>
|
||||
<Option value='admin'>{t('admin(管理员权限)')}</Option>
|
||||
</Form.Select>
|
||||
<Form.Switch
|
||||
field='require_pkce'
|
||||
label={t('强制PKCE验证')}
|
||||
size='large'
|
||||
extraText={t(
|
||||
'PKCE(Proof Key for Code Exchange)可提高授权码流程的安全性。',
|
||||
)}
|
||||
/>
|
||||
</Card>
|
||||
|
||||
{/* 重定向URI */}
|
||||
<RedirectUriCard
|
||||
t={t}
|
||||
isAuthCodeSelected={isAuthCodeSelected}
|
||||
redirectUris={redirectUris}
|
||||
onAdd={addRedirectUri}
|
||||
onUpdate={updateRedirectUri}
|
||||
onRemove={removeRedirectUri}
|
||||
onFillTemplate={fillRedirectUriTemplate}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</Form>
|
||||
</Spin>
|
||||
|
||||
{/* 客户端信息展示模态框 */}
|
||||
<ClientInfoModal
|
||||
visible={showClientInfo}
|
||||
onClose={handleClientInfoClose}
|
||||
clientId={clientInfo.clientId}
|
||||
clientSecret={clientInfo.clientSecret}
|
||||
/>
|
||||
</SideSheet>
|
||||
);
|
||||
};
|
||||
|
||||
export default OAuth2ClientModal;
|
||||
@@ -0,0 +1,57 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { Modal, Banner, Typography } from '@douyinfe/semi-ui';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
const SecretDisplayModal = ({ visible, onClose, secret }) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={t('客户端密钥已重新生成')}
|
||||
visible={visible}
|
||||
onCancel={onClose}
|
||||
onOk={onClose}
|
||||
cancelText=''
|
||||
okText={t('我已复制保存')}
|
||||
width={650}
|
||||
bodyStyle={{ padding: '20px 24px' }}
|
||||
>
|
||||
<Banner
|
||||
type='success'
|
||||
closeIcon={null}
|
||||
description={t(
|
||||
'新的客户端密钥如下,请立即复制保存。关闭此窗口后将无法再次查看。',
|
||||
)}
|
||||
className='mb-5 !rounded-lg'
|
||||
/>
|
||||
<div className='flex justify-center items-center'>
|
||||
<Text code copyable>
|
||||
{secret}
|
||||
</Text>
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
export default SecretDisplayModal;
|
||||
@@ -0,0 +1,72 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Modal } from '@douyinfe/semi-ui';
|
||||
import { API, showError } from '../../../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import CodeViewer from '../../../common/ui/CodeViewer';
|
||||
|
||||
const ServerInfoModal = ({ visible, onClose }) => {
|
||||
const { t } = useTranslation();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [serverInfo, setServerInfo] = useState(null);
|
||||
|
||||
const loadServerInfo = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.get('/api/oauth/server-info');
|
||||
setServerInfo(res.data);
|
||||
} catch (error) {
|
||||
showError(t('获取服务器信息失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
loadServerInfo();
|
||||
}
|
||||
}, [visible]);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={t('OAuth2 服务器信息')}
|
||||
visible={visible}
|
||||
onCancel={onClose}
|
||||
onOk={onClose}
|
||||
cancelText=''
|
||||
okText={t('关闭')}
|
||||
width={650}
|
||||
bodyStyle={{ padding: '20px 24px' }}
|
||||
confirmLoading={loading}
|
||||
>
|
||||
<CodeViewer
|
||||
content={
|
||||
serverInfo ? JSON.stringify(serverInfo, null, 2) : t('加载中...')
|
||||
}
|
||||
title={t('OAuth2 服务器配置')}
|
||||
language='json'
|
||||
/>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
export default ServerInfoModal;
|
||||
@@ -40,7 +40,7 @@ import {
|
||||
showSuccess,
|
||||
showError,
|
||||
} from '../../../../helpers';
|
||||
import CodeViewer from '../../../playground/CodeViewer';
|
||||
import CodeViewer from '../../../common/ui/CodeViewer';
|
||||
import { StatusContext } from '../../../../context/Status';
|
||||
import { UserContext } from '../../../../context/User';
|
||||
import { useUserPermissions } from '../../../../hooks/common/useUserPermissions';
|
||||
|
||||
@@ -142,6 +142,8 @@ const EditChannelModal = (props) => {
|
||||
system_prompt: '',
|
||||
system_prompt_override: false,
|
||||
settings: '',
|
||||
// 仅 Vertex: 密钥格式(存入 settings.vertex_key_type)
|
||||
vertex_key_type: 'json',
|
||||
};
|
||||
const [batch, setBatch] = useState(false);
|
||||
const [multiToSingle, setMultiToSingle] = useState(false);
|
||||
@@ -409,11 +411,17 @@ const EditChannelModal = (props) => {
|
||||
const parsedSettings = JSON.parse(data.settings);
|
||||
data.azure_responses_version =
|
||||
parsedSettings.azure_responses_version || '';
|
||||
// 读取 Vertex 密钥格式
|
||||
data.vertex_key_type = parsedSettings.vertex_key_type || 'json';
|
||||
} catch (error) {
|
||||
console.error('解析其他设置失败:', error);
|
||||
data.azure_responses_version = '';
|
||||
data.region = '';
|
||||
data.vertex_key_type = 'json';
|
||||
}
|
||||
} else {
|
||||
// 兼容历史数据:老渠道没有 settings 时,默认按 json 展示
|
||||
data.vertex_key_type = 'json';
|
||||
}
|
||||
|
||||
setInputs(data);
|
||||
@@ -745,59 +753,58 @@ const EditChannelModal = (props) => {
|
||||
let localInputs = { ...formValues };
|
||||
|
||||
if (localInputs.type === 41) {
|
||||
if (useManualInput) {
|
||||
// 手动输入模式
|
||||
if (localInputs.key && localInputs.key.trim() !== '') {
|
||||
try {
|
||||
// 验证 JSON 格式
|
||||
const parsedKey = JSON.parse(localInputs.key);
|
||||
// 确保是有效的密钥格式
|
||||
localInputs.key = JSON.stringify(parsedKey);
|
||||
} catch (err) {
|
||||
showError(t('密钥格式无效,请输入有效的 JSON 格式密钥'));
|
||||
return;
|
||||
}
|
||||
} else if (!isEdit) {
|
||||
const keyType = localInputs.vertex_key_type || 'json';
|
||||
if (keyType === 'api_key') {
|
||||
// 直接作为普通字符串密钥处理
|
||||
if (!isEdit && (!localInputs.key || localInputs.key.trim() === '')) {
|
||||
showInfo(t('请输入密钥!'));
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
// 文件上传模式
|
||||
let keys = vertexKeys;
|
||||
|
||||
// 若当前未选择文件,尝试从已上传文件列表解析(异步读取)
|
||||
if (keys.length === 0 && vertexFileList.length > 0) {
|
||||
try {
|
||||
const parsed = await Promise.all(
|
||||
vertexFileList.map(async (item) => {
|
||||
const fileObj = item.fileInstance;
|
||||
if (!fileObj) return null;
|
||||
const txt = await fileObj.text();
|
||||
return JSON.parse(txt);
|
||||
}),
|
||||
);
|
||||
keys = parsed.filter(Boolean);
|
||||
} catch (err) {
|
||||
showError(t('解析密钥文件失败: {{msg}}', { msg: err.message }));
|
||||
// JSON 服务账号密钥
|
||||
if (useManualInput) {
|
||||
if (localInputs.key && localInputs.key.trim() !== '') {
|
||||
try {
|
||||
const parsedKey = JSON.parse(localInputs.key);
|
||||
localInputs.key = JSON.stringify(parsedKey);
|
||||
} catch (err) {
|
||||
showError(t('密钥格式无效,请输入有效的 JSON 格式密钥'));
|
||||
return;
|
||||
}
|
||||
} else if (!isEdit) {
|
||||
showInfo(t('请输入密钥!'));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// 创建模式必须上传密钥;编辑模式可选
|
||||
if (keys.length === 0) {
|
||||
if (!isEdit) {
|
||||
showInfo(t('请上传密钥文件!'));
|
||||
return;
|
||||
} else {
|
||||
// 编辑模式且未上传新密钥,不修改 key
|
||||
delete localInputs.key;
|
||||
}
|
||||
} else {
|
||||
// 有新密钥,则覆盖
|
||||
if (batch) {
|
||||
localInputs.key = JSON.stringify(keys);
|
||||
// 文件上传模式
|
||||
let keys = vertexKeys;
|
||||
if (keys.length === 0 && vertexFileList.length > 0) {
|
||||
try {
|
||||
const parsed = await Promise.all(
|
||||
vertexFileList.map(async (item) => {
|
||||
const fileObj = item.fileInstance;
|
||||
if (!fileObj) return null;
|
||||
const txt = await fileObj.text();
|
||||
return JSON.parse(txt);
|
||||
}),
|
||||
);
|
||||
keys = parsed.filter(Boolean);
|
||||
} catch (err) {
|
||||
showError(t('解析密钥文件失败: {{msg}}', { msg: err.message }));
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (keys.length === 0) {
|
||||
if (!isEdit) {
|
||||
showInfo(t('请上传密钥文件!'));
|
||||
return;
|
||||
} else {
|
||||
delete localInputs.key;
|
||||
}
|
||||
} else {
|
||||
localInputs.key = JSON.stringify(keys[0]);
|
||||
localInputs.key = batch
|
||||
? JSON.stringify(keys)
|
||||
: JSON.stringify(keys[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -853,6 +860,8 @@ const EditChannelModal = (props) => {
|
||||
delete localInputs.pass_through_body_enabled;
|
||||
delete localInputs.system_prompt;
|
||||
delete localInputs.system_prompt_override;
|
||||
// 顶层的 vertex_key_type 不应发送给后端
|
||||
delete localInputs.vertex_key_type;
|
||||
|
||||
let res;
|
||||
localInputs.auto_ban = localInputs.auto_ban ? 1 : 0;
|
||||
@@ -1178,8 +1187,44 @@ const EditChannelModal = (props) => {
|
||||
autoComplete='new-password'
|
||||
/>
|
||||
|
||||
{inputs.type === 41 && (
|
||||
<Form.Select
|
||||
field='vertex_key_type'
|
||||
label={t('密钥格式')}
|
||||
placeholder={t('请选择密钥格式')}
|
||||
optionList={[
|
||||
{ label: 'JSON', value: 'json' },
|
||||
{ label: 'API Key', value: 'api_key' },
|
||||
]}
|
||||
style={{ width: '100%' }}
|
||||
value={inputs.vertex_key_type || 'json'}
|
||||
onChange={(value) => {
|
||||
// 更新设置中的 vertex_key_type
|
||||
handleChannelOtherSettingsChange(
|
||||
'vertex_key_type',
|
||||
value,
|
||||
);
|
||||
// 切换为 api_key 时,关闭批量与手动/文件切换,并清理已选文件
|
||||
if (value === 'api_key') {
|
||||
setBatch(false);
|
||||
setUseManualInput(false);
|
||||
setVertexKeys([]);
|
||||
setVertexFileList([]);
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue('vertex_files', []);
|
||||
}
|
||||
}
|
||||
}}
|
||||
extraText={
|
||||
inputs.vertex_key_type === 'api_key'
|
||||
? t('API Key 模式下不支持批量创建')
|
||||
: t('JSON 模式支持手动输入或上传服务账号 JSON')
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{batch ? (
|
||||
inputs.type === 41 ? (
|
||||
inputs.type === 41 &&
|
||||
(inputs.vertex_key_type || 'json') === 'json' ? (
|
||||
<Form.Upload
|
||||
field='vertex_files'
|
||||
label={t('密钥文件 (.json)')}
|
||||
@@ -1243,7 +1288,8 @@ const EditChannelModal = (props) => {
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
{inputs.type === 41 ? (
|
||||
{inputs.type === 41 &&
|
||||
(inputs.vertex_key_type || 'json') === 'json' ? (
|
||||
<>
|
||||
{!batch && (
|
||||
<div className='flex items-center justify-between mb-3'>
|
||||
|
||||
@@ -21,6 +21,7 @@ import React, { useRef } from 'react';
|
||||
import {
|
||||
Avatar,
|
||||
Typography,
|
||||
Tag,
|
||||
Card,
|
||||
Button,
|
||||
Banner,
|
||||
@@ -30,6 +31,7 @@ import {
|
||||
Row,
|
||||
Col,
|
||||
Spin,
|
||||
Tooltip,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { SiAlipay, SiWechat, SiStripe } from 'react-icons/si';
|
||||
import { CreditCard, Coins, Wallet, BarChart2, TrendingUp } from 'lucide-react';
|
||||
@@ -68,6 +70,7 @@ const RechargeCard = ({
|
||||
userState,
|
||||
renderQuota,
|
||||
statusLoading,
|
||||
topupInfo,
|
||||
}) => {
|
||||
const onlineFormApiRef = useRef(null);
|
||||
const redeemFormApiRef = useRef(null);
|
||||
@@ -261,44 +264,74 @@ const RechargeCard = ({
|
||||
</Col>
|
||||
<Col xs={24} sm={24} md={24} lg={14} xl={14}>
|
||||
<Form.Slot label={t('选择支付方式')}>
|
||||
<Space wrap>
|
||||
{payMethods.map((payMethod) => (
|
||||
<Button
|
||||
key={payMethod.type}
|
||||
theme='outline'
|
||||
type='tertiary'
|
||||
onClick={() => preTopUp(payMethod.type)}
|
||||
disabled={
|
||||
(!enableOnlineTopUp &&
|
||||
payMethod.type !== 'stripe') ||
|
||||
(!enableStripeTopUp &&
|
||||
payMethod.type === 'stripe')
|
||||
}
|
||||
loading={
|
||||
paymentLoading && payWay === payMethod.type
|
||||
}
|
||||
icon={
|
||||
payMethod.type === 'alipay' ? (
|
||||
<SiAlipay size={18} color='#1677FF' />
|
||||
) : payMethod.type === 'wxpay' ? (
|
||||
<SiWechat size={18} color='#07C160' />
|
||||
) : payMethod.type === 'stripe' ? (
|
||||
<SiStripe size={18} color='#635BFF' />
|
||||
) : (
|
||||
<CreditCard
|
||||
size={18}
|
||||
color={
|
||||
payMethod.color ||
|
||||
'var(--semi-color-text-2)'
|
||||
}
|
||||
/>
|
||||
)
|
||||
}
|
||||
>
|
||||
{payMethod.name}
|
||||
</Button>
|
||||
))}
|
||||
</Space>
|
||||
{payMethods && payMethods.length > 0 ? (
|
||||
<Space wrap>
|
||||
{payMethods.map((payMethod) => {
|
||||
const minTopupVal =
|
||||
Number(payMethod.min_topup) || 0;
|
||||
const isStripe = payMethod.type === 'stripe';
|
||||
const disabled =
|
||||
(!enableOnlineTopUp && !isStripe) ||
|
||||
(!enableStripeTopUp && isStripe) ||
|
||||
minTopupVal > Number(topUpCount || 0);
|
||||
|
||||
const buttonEl = (
|
||||
<Button
|
||||
key={payMethod.type}
|
||||
theme='outline'
|
||||
type='tertiary'
|
||||
onClick={() => preTopUp(payMethod.type)}
|
||||
disabled={disabled}
|
||||
loading={
|
||||
paymentLoading && payWay === payMethod.type
|
||||
}
|
||||
icon={
|
||||
payMethod.type === 'alipay' ? (
|
||||
<SiAlipay size={18} color='#1677FF' />
|
||||
) : payMethod.type === 'wxpay' ? (
|
||||
<SiWechat size={18} color='#07C160' />
|
||||
) : payMethod.type === 'stripe' ? (
|
||||
<SiStripe size={18} color='#635BFF' />
|
||||
) : (
|
||||
<CreditCard
|
||||
size={18}
|
||||
color={
|
||||
payMethod.color ||
|
||||
'var(--semi-color-text-2)'
|
||||
}
|
||||
/>
|
||||
)
|
||||
}
|
||||
className='!rounded-lg !px-4 !py-2'
|
||||
>
|
||||
{payMethod.name}
|
||||
</Button>
|
||||
);
|
||||
|
||||
return disabled &&
|
||||
minTopupVal > Number(topUpCount || 0) ? (
|
||||
<Tooltip
|
||||
content={
|
||||
t('此支付方式最低充值金额为') +
|
||||
' ' +
|
||||
minTopupVal
|
||||
}
|
||||
key={payMethod.type}
|
||||
>
|
||||
{buttonEl}
|
||||
</Tooltip>
|
||||
) : (
|
||||
<React.Fragment key={payMethod.type}>
|
||||
{buttonEl}
|
||||
</React.Fragment>
|
||||
);
|
||||
})}
|
||||
</Space>
|
||||
) : (
|
||||
<div className='text-gray-500 text-sm p-3 bg-gray-50 rounded-lg border border-dashed border-gray-300'>
|
||||
{t('暂无可用的支付方式,请联系管理员配置')}
|
||||
</div>
|
||||
)}
|
||||
</Form.Slot>
|
||||
</Col>
|
||||
</Row>
|
||||
@@ -306,41 +339,75 @@ const RechargeCard = ({
|
||||
|
||||
{(enableOnlineTopUp || enableStripeTopUp) && (
|
||||
<Form.Slot label={t('选择充值额度')}>
|
||||
<Space wrap>
|
||||
{presetAmounts.map((preset, index) => (
|
||||
<Button
|
||||
key={index}
|
||||
theme={
|
||||
selectedPreset === preset.value
|
||||
? 'solid'
|
||||
: 'outline'
|
||||
}
|
||||
type={
|
||||
selectedPreset === preset.value
|
||||
? 'primary'
|
||||
: 'tertiary'
|
||||
}
|
||||
onClick={() => {
|
||||
selectPresetAmount(preset);
|
||||
onlineFormApiRef.current?.setValue(
|
||||
'topUpCount',
|
||||
preset.value,
|
||||
);
|
||||
}}
|
||||
className='!rounded-lg !py-2 !px-3'
|
||||
>
|
||||
<div className='flex items-center gap-2'>
|
||||
<Coins size={14} className='opacity-80' />
|
||||
<span className='font-medium'>
|
||||
{formatLargeNumber(preset.value)}
|
||||
</span>
|
||||
<span className='text-xs text-gray-500'>
|
||||
¥{(preset.value * priceRatio).toFixed(2)}
|
||||
</span>
|
||||
</div>
|
||||
</Button>
|
||||
))}
|
||||
</Space>
|
||||
<div className='grid grid-cols-2 sm:grid-cols-3 md:grid-cols-4 gap-2'>
|
||||
{presetAmounts.map((preset, index) => {
|
||||
const discount =
|
||||
preset.discount ||
|
||||
topupInfo?.discount?.[preset.value] ||
|
||||
1.0;
|
||||
const originalPrice = preset.value * priceRatio;
|
||||
const discountedPrice = originalPrice * discount;
|
||||
const hasDiscount = discount < 1.0;
|
||||
const actualPay = discountedPrice;
|
||||
const save = originalPrice - discountedPrice;
|
||||
|
||||
return (
|
||||
<Card
|
||||
key={index}
|
||||
style={{
|
||||
cursor: 'pointer',
|
||||
border:
|
||||
selectedPreset === preset.value
|
||||
? '2px solid var(--semi-color-primary)'
|
||||
: '1px solid var(--semi-color-border)',
|
||||
height: '100%',
|
||||
width: '100%',
|
||||
}}
|
||||
bodyStyle={{ padding: '12px' }}
|
||||
onClick={() => {
|
||||
selectPresetAmount(preset);
|
||||
onlineFormApiRef.current?.setValue(
|
||||
'topUpCount',
|
||||
preset.value,
|
||||
);
|
||||
}}
|
||||
>
|
||||
<div style={{ textAlign: 'center' }}>
|
||||
<Typography.Title
|
||||
heading={6}
|
||||
style={{ margin: '0 0 8px 0' }}
|
||||
>
|
||||
<Coins size={18} />
|
||||
{formatLargeNumber(preset.value)}
|
||||
{hasDiscount && (
|
||||
<Tag style={{ marginLeft: 4 }} color='green'>
|
||||
{t('折').includes('off')
|
||||
? (
|
||||
(1 - parseFloat(discount)) *
|
||||
100
|
||||
).toFixed(1)
|
||||
: (discount * 10).toFixed(1)}
|
||||
{t('折')}
|
||||
</Tag>
|
||||
)}
|
||||
</Typography.Title>
|
||||
<div
|
||||
style={{
|
||||
color: 'var(--semi-color-text-2)',
|
||||
fontSize: '12px',
|
||||
margin: '4px 0',
|
||||
}}
|
||||
>
|
||||
{t('实付')} {actualPay.toFixed(2)},
|
||||
{hasDiscount
|
||||
? `${t('节省')} ${save.toFixed(2)}`
|
||||
: `${t('节省')} 0.00`}
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</Form.Slot>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -81,6 +81,12 @@ const TopUp = () => {
|
||||
const [presetAmounts, setPresetAmounts] = useState([]);
|
||||
const [selectedPreset, setSelectedPreset] = useState(null);
|
||||
|
||||
// 充值配置信息
|
||||
const [topupInfo, setTopupInfo] = useState({
|
||||
amount_options: [],
|
||||
discount: {},
|
||||
});
|
||||
|
||||
const topUp = async () => {
|
||||
if (redemptionCode === '') {
|
||||
showInfo(t('请输入兑换码!'));
|
||||
@@ -248,6 +254,108 @@ const TopUp = () => {
|
||||
}
|
||||
};
|
||||
|
||||
// 获取充值配置信息
|
||||
const getTopupInfo = async () => {
|
||||
try {
|
||||
const res = await API.get('/api/user/topup/info');
|
||||
const { message, data, success } = res.data;
|
||||
if (success) {
|
||||
setTopupInfo({
|
||||
amount_options: data.amount_options || [],
|
||||
discount: data.discount || {},
|
||||
});
|
||||
|
||||
// 处理支付方式
|
||||
let payMethods = data.pay_methods || [];
|
||||
try {
|
||||
if (typeof payMethods === 'string') {
|
||||
payMethods = JSON.parse(payMethods);
|
||||
}
|
||||
if (payMethods && payMethods.length > 0) {
|
||||
// 检查name和type是否为空
|
||||
payMethods = payMethods.filter((method) => {
|
||||
return method.name && method.type;
|
||||
});
|
||||
// 如果没有color,则设置默认颜色
|
||||
payMethods = payMethods.map((method) => {
|
||||
// 规范化最小充值数
|
||||
const normalizedMinTopup = Number(method.min_topup);
|
||||
method.min_topup = Number.isFinite(normalizedMinTopup)
|
||||
? normalizedMinTopup
|
||||
: 0;
|
||||
|
||||
// Stripe 的最小充值从后端字段回填
|
||||
if (
|
||||
method.type === 'stripe' &&
|
||||
(!method.min_topup || method.min_topup <= 0)
|
||||
) {
|
||||
const stripeMin = Number(data.stripe_min_topup);
|
||||
if (Number.isFinite(stripeMin)) {
|
||||
method.min_topup = stripeMin;
|
||||
}
|
||||
}
|
||||
|
||||
if (!method.color) {
|
||||
if (method.type === 'alipay') {
|
||||
method.color = 'rgba(var(--semi-blue-5), 1)';
|
||||
} else if (method.type === 'wxpay') {
|
||||
method.color = 'rgba(var(--semi-green-5), 1)';
|
||||
} else if (method.type === 'stripe') {
|
||||
method.color = 'rgba(var(--semi-purple-5), 1)';
|
||||
} else {
|
||||
method.color = 'rgba(var(--semi-primary-5), 1)';
|
||||
}
|
||||
}
|
||||
return method;
|
||||
});
|
||||
} else {
|
||||
payMethods = [];
|
||||
}
|
||||
|
||||
// 如果启用了 Stripe 支付,添加到支付方法列表
|
||||
// 这个逻辑现在由后端处理,如果 Stripe 启用,后端会在 pay_methods 中包含它
|
||||
|
||||
setPayMethods(payMethods);
|
||||
const enableStripeTopUp = data.enable_stripe_topup || false;
|
||||
const enableOnlineTopUp = data.enable_online_topup || false;
|
||||
const minTopUpValue = enableOnlineTopUp
|
||||
? data.min_topup
|
||||
: enableStripeTopUp
|
||||
? data.stripe_min_topup
|
||||
: 1;
|
||||
setEnableOnlineTopUp(enableOnlineTopUp);
|
||||
setEnableStripeTopUp(enableStripeTopUp);
|
||||
setMinTopUp(minTopUpValue);
|
||||
setTopUpCount(minTopUpValue);
|
||||
|
||||
// 如果没有自定义充值数量选项,根据最小充值金额生成预设充值额度选项
|
||||
if (topupInfo.amount_options.length === 0) {
|
||||
setPresetAmounts(generatePresetAmounts(minTopUpValue));
|
||||
}
|
||||
|
||||
// 初始化显示实付金额
|
||||
getAmount(minTopUpValue);
|
||||
} catch (e) {
|
||||
console.log('解析支付方式失败:', e);
|
||||
setPayMethods([]);
|
||||
}
|
||||
|
||||
// 如果有自定义充值数量选项,使用它们替换默认的预设选项
|
||||
if (data.amount_options && data.amount_options.length > 0) {
|
||||
const customPresets = data.amount_options.map((amount) => ({
|
||||
value: amount,
|
||||
discount: data.discount[amount] || 1.0,
|
||||
}));
|
||||
setPresetAmounts(customPresets);
|
||||
}
|
||||
} else {
|
||||
console.error('获取充值配置失败:', data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('获取充值配置异常:', error);
|
||||
}
|
||||
};
|
||||
|
||||
// 获取邀请链接
|
||||
const getAffLink = async () => {
|
||||
const res = await API.get('/api/user/aff');
|
||||
@@ -290,52 +398,7 @@ const TopUp = () => {
|
||||
getUserQuota().then();
|
||||
}
|
||||
setTransferAmount(getQuotaPerUnit());
|
||||
|
||||
let payMethods = localStorage.getItem('pay_methods');
|
||||
try {
|
||||
payMethods = JSON.parse(payMethods);
|
||||
if (payMethods && payMethods.length > 0) {
|
||||
// 检查name和type是否为空
|
||||
payMethods = payMethods.filter((method) => {
|
||||
return method.name && method.type;
|
||||
});
|
||||
// 如果没有color,则设置默认颜色
|
||||
payMethods = payMethods.map((method) => {
|
||||
if (!method.color) {
|
||||
if (method.type === 'alipay') {
|
||||
method.color = 'rgba(var(--semi-blue-5), 1)';
|
||||
} else if (method.type === 'wxpay') {
|
||||
method.color = 'rgba(var(--semi-green-5), 1)';
|
||||
} else if (method.type === 'stripe') {
|
||||
method.color = 'rgba(var(--semi-purple-5), 1)';
|
||||
} else {
|
||||
method.color = 'rgba(var(--semi-primary-5), 1)';
|
||||
}
|
||||
}
|
||||
return method;
|
||||
});
|
||||
} else {
|
||||
payMethods = [];
|
||||
}
|
||||
|
||||
// 如果启用了 Stripe 支付,添加到支付方法列表
|
||||
if (statusState?.status?.enable_stripe_topup) {
|
||||
const hasStripe = payMethods.some((method) => method.type === 'stripe');
|
||||
if (!hasStripe) {
|
||||
payMethods.push({
|
||||
name: 'Stripe',
|
||||
type: 'stripe',
|
||||
color: 'rgba(var(--semi-purple-5), 1)',
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
setPayMethods(payMethods);
|
||||
} catch (e) {
|
||||
console.log(e);
|
||||
showError(t('支付方式配置错误, 请联系管理员'));
|
||||
}
|
||||
}, [statusState?.status?.enable_stripe_topup]);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (affFetchedRef.current) return;
|
||||
@@ -343,20 +406,18 @@ const TopUp = () => {
|
||||
getAffLink().then();
|
||||
}, []);
|
||||
|
||||
// 在 statusState 可用时获取充值信息
|
||||
useEffect(() => {
|
||||
getTopupInfo().then();
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (statusState?.status) {
|
||||
const minTopUpValue = statusState.status.min_topup || 1;
|
||||
setMinTopUp(minTopUpValue);
|
||||
setTopUpCount(minTopUpValue);
|
||||
// const minTopUpValue = statusState.status.min_topup || 1;
|
||||
// setMinTopUp(minTopUpValue);
|
||||
// setTopUpCount(minTopUpValue);
|
||||
setTopUpLink(statusState.status.top_up_link || '');
|
||||
setEnableOnlineTopUp(statusState.status.enable_online_topup || false);
|
||||
setPriceRatio(statusState.status.price || 1);
|
||||
setEnableStripeTopUp(statusState.status.enable_stripe_topup || false);
|
||||
|
||||
// 根据最小充值金额生成预设充值额度选项
|
||||
setPresetAmounts(generatePresetAmounts(minTopUpValue));
|
||||
// 初始化显示实付金额
|
||||
getAmount(minTopUpValue);
|
||||
|
||||
setStatusLoading(false);
|
||||
}
|
||||
@@ -431,7 +492,11 @@ const TopUp = () => {
|
||||
const selectPresetAmount = (preset) => {
|
||||
setTopUpCount(preset.value);
|
||||
setSelectedPreset(preset.value);
|
||||
setAmount(preset.value * priceRatio);
|
||||
|
||||
// 计算实际支付金额,考虑折扣
|
||||
const discount = preset.discount || topupInfo.discount[preset.value] || 1.0;
|
||||
const discountedAmount = preset.value * priceRatio * discount;
|
||||
setAmount(discountedAmount);
|
||||
};
|
||||
|
||||
// 格式化大数字显示
|
||||
@@ -475,6 +540,8 @@ const TopUp = () => {
|
||||
renderAmount={renderAmount}
|
||||
payWay={payWay}
|
||||
payMethods={payMethods}
|
||||
amountNumber={amount}
|
||||
discountRate={topupInfo?.discount?.[topUpCount] || 1.0}
|
||||
/>
|
||||
|
||||
{/* 用户信息头部 */}
|
||||
@@ -512,6 +579,7 @@ const TopUp = () => {
|
||||
userState={userState}
|
||||
renderQuota={renderQuota}
|
||||
statusLoading={statusLoading}
|
||||
topupInfo={topupInfo}
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -36,7 +36,14 @@ const PaymentConfirmModal = ({
|
||||
renderAmount,
|
||||
payWay,
|
||||
payMethods,
|
||||
// 新增:用于显示折扣明细
|
||||
amountNumber,
|
||||
discountRate,
|
||||
}) => {
|
||||
const hasDiscount =
|
||||
discountRate && discountRate > 0 && discountRate < 1 && amountNumber > 0;
|
||||
const originalAmount = hasDiscount ? amountNumber / discountRate : 0;
|
||||
const discountAmount = hasDiscount ? originalAmount - amountNumber : 0;
|
||||
return (
|
||||
<Modal
|
||||
title={
|
||||
@@ -71,11 +78,38 @@ const PaymentConfirmModal = ({
|
||||
{amountLoading ? (
|
||||
<Skeleton.Title style={{ width: '60px', height: '16px' }} />
|
||||
) : (
|
||||
<Text strong className='font-bold' style={{ color: 'red' }}>
|
||||
{renderAmount()}
|
||||
</Text>
|
||||
<div className='flex items-baseline space-x-2'>
|
||||
<Text strong className='font-bold' style={{ color: 'red' }}>
|
||||
{renderAmount()}
|
||||
</Text>
|
||||
{hasDiscount && (
|
||||
<Text size='small' className='text-rose-500'>
|
||||
{Math.round(discountRate * 100)}%
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{hasDiscount && !amountLoading && (
|
||||
<>
|
||||
<div className='flex justify-between items-center'>
|
||||
<Text className='text-slate-500 dark:text-slate-400'>
|
||||
{t('原价')}:
|
||||
</Text>
|
||||
<Text delete className='text-slate-500 dark:text-slate-400'>
|
||||
{`${originalAmount.toFixed(2)} ${t('元')}`}
|
||||
</Text>
|
||||
</div>
|
||||
<div className='flex justify-between items-center'>
|
||||
<Text className='text-slate-500 dark:text-slate-400'>
|
||||
{t('优惠')}:
|
||||
</Text>
|
||||
<Text className='text-emerald-600 dark:text-emerald-400'>
|
||||
{`- ${discountAmount.toFixed(2)} ${t('元')}`}
|
||||
</Text>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
<div className='flex justify-between items-center'>
|
||||
<Text strong className='text-slate-700 dark:text-slate-200'>
|
||||
{t('支付方式')}:
|
||||
|
||||
@@ -36,7 +36,11 @@ export const AuthRedirect = ({ children }) => {
|
||||
const user = localStorage.getItem('user');
|
||||
|
||||
if (user) {
|
||||
return <Navigate to='/console' replace />;
|
||||
// 优先使用登录页上的 next 参数(仅允许站内相对路径)
|
||||
const sp = new URLSearchParams(window.location.search);
|
||||
const next = sp.get('next');
|
||||
const isSafeInternalPath = next && next.startsWith('/') && !next.startsWith('//');
|
||||
return <Navigate to={isSafeInternalPath ? next : '/console'} replace />;
|
||||
}
|
||||
|
||||
return children;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user