diff --git a/README.fr.md b/README.fr.md
index 77fd0cd1c..6b4d0ceba 100644
--- a/README.fr.md
+++ b/README.fr.md
@@ -30,8 +30,8 @@
-
-
+
+
diff --git a/README.ja.md b/README.ja.md
index 2cb00affb..2b35bdfe9 100644
--- a/README.ja.md
+++ b/README.ja.md
@@ -30,8 +30,8 @@
-
-
+
+
diff --git a/README.md b/README.md
index 5f64a0d0b..8f23d5dcd 100644
--- a/README.md
+++ b/README.md
@@ -30,8 +30,8 @@
-
-
+
+
diff --git a/README.zh_CN.md b/README.zh_CN.md
index 55265d9a8..fd3204950 100644
--- a/README.zh_CN.md
+++ b/README.zh_CN.md
@@ -30,8 +30,8 @@
-
-
+
+
diff --git a/README.zh_TW.md b/README.zh_TW.md
index 2fa93157e..9264bc722 100644
--- a/README.zh_TW.md
+++ b/README.zh_TW.md
@@ -30,8 +30,8 @@
-
-
+
+
diff --git a/common/gin.go b/common/gin.go
index 48971c130..5cad6e5c9 100644
--- a/common/gin.go
+++ b/common/gin.go
@@ -243,7 +243,15 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
return nil, err
}
- contentType := c.Request.Header.Get("Content-Type")
+ // Use the original Content-Type saved on first call to avoid boundary
+ // mismatch when callers overwrite c.Request.Header after multipart rebuild.
+ var contentType string
+ if saved, ok := c.Get("_original_multipart_ct"); ok {
+ contentType = saved.(string)
+ } else {
+ contentType = c.Request.Header.Get("Content-Type")
+ c.Set("_original_multipart_ct", contentType)
+ }
boundary, err := parseBoundary(contentType)
if err != nil {
return nil, err
@@ -295,7 +303,13 @@ func parseFormData(data []byte, v any) error {
}
func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
- contentType := c.Request.Header.Get("Content-Type")
+ var contentType string
+ if saved, ok := c.Get("_original_multipart_ct"); ok {
+ contentType = saved.(string)
+ } else {
+ contentType = c.Request.Header.Get("Content-Type")
+ c.Set("_original_multipart_ct", contentType)
+ }
boundary, err := parseBoundary(contentType)
if err != nil {
if errors.Is(err, errBoundaryNotFound) {
diff --git a/common/init.go b/common/init.go
index 6d2c3572b..e4ddbb453 100644
--- a/common/init.go
+++ b/common/init.go
@@ -145,6 +145,8 @@ func initConstantEnv() {
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
// 任务轮询时查询的最大数量
constant.TaskQueryLimit = GetEnvOrDefault("TASK_QUERY_LIMIT", 1000)
+ // 异步任务超时时间(分钟),超过此时间未完成的任务将被标记为失败并退款。0 表示禁用。
+ constant.TaskTimeoutMinutes = GetEnvOrDefault("TASK_TIMEOUT_MINUTES", 1440)
soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
if soraPatchStr != "" {
diff --git a/constant/env.go b/constant/env.go
index 957f68669..d5aff1b0b 100644
--- a/constant/env.go
+++ b/constant/env.go
@@ -16,6 +16,7 @@ var NotificationLimitDurationMinute int
var GenerateDefaultToken bool
var ErrorLogEnabled bool
var TaskQueryLimit int
+var TaskTimeoutMinutes int
// temporary variable for sora patch, will be removed in future
var TaskPricePatches []string
diff --git a/controller/codex_oauth.go b/controller/codex_oauth.go
index 3071413c6..de9743ab7 100644
--- a/controller/codex_oauth.go
+++ b/controller/codex_oauth.go
@@ -145,6 +145,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
return
}
+ channelProxy := ""
if channelID > 0 {
ch, err := model.GetChannelById(channelID, false)
if err != nil {
@@ -159,6 +160,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
return
}
+ channelProxy = ch.GetSetting().Proxy
}
session := sessions.Default(c)
@@ -176,7 +178,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
defer cancel()
- tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier)
+ tokenRes, err := service.ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, channelProxy)
if err != nil {
common.SysError("failed to exchange codex authorization code: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "授权码交换失败,请重试"})
diff --git a/controller/codex_usage.go b/controller/codex_usage.go
index 62b7a754f..52fdbdf6f 100644
--- a/controller/codex_usage.go
+++ b/controller/codex_usage.go
@@ -2,7 +2,6 @@ package controller
import (
"context"
- "encoding/json"
"fmt"
"net/http"
"strconv"
@@ -80,7 +79,7 @@ func GetCodexChannelUsage(c *gin.Context) {
refreshCtx, refreshCancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
defer refreshCancel()
- res, refreshErr := service.RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
+ res, refreshErr := service.RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy)
if refreshErr == nil {
oauthKey.AccessToken = res.AccessToken
oauthKey.RefreshToken = res.RefreshToken
@@ -109,7 +108,7 @@ func GetCodexChannelUsage(c *gin.Context) {
}
var payload any
- if json.Unmarshal(body, &payload) != nil {
+ if common.Unmarshal(body, &payload) != nil {
payload = string(body)
}
diff --git a/controller/custom_oauth.go b/controller/custom_oauth.go
index e2245f880..c21ec7910 100644
--- a/controller/custom_oauth.go
+++ b/controller/custom_oauth.go
@@ -1,8 +1,13 @@
package controller
import (
+ "context"
+ "io"
"net/http"
+ "net/url"
"strconv"
+ "strings"
+ "time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
@@ -16,6 +21,7 @@ type CustomOAuthProviderResponse struct {
Id int `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
+ Icon string `json:"icon"`
Enabled bool `json:"enabled"`
ClientId string `json:"client_id"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
@@ -28,6 +34,16 @@ type CustomOAuthProviderResponse struct {
EmailField string `json:"email_field"`
WellKnown string `json:"well_known"`
AuthStyle int `json:"auth_style"`
+ AccessPolicy string `json:"access_policy"`
+ AccessDeniedMessage string `json:"access_denied_message"`
+}
+
+type UserOAuthBindingResponse struct {
+ ProviderId int `json:"provider_id"`
+ ProviderName string `json:"provider_name"`
+ ProviderSlug string `json:"provider_slug"`
+ ProviderIcon string `json:"provider_icon"`
+ ProviderUserId string `json:"provider_user_id"`
}
func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse {
@@ -35,6 +51,7 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro
Id: p.Id,
Name: p.Name,
Slug: p.Slug,
+ Icon: p.Icon,
Enabled: p.Enabled,
ClientId: p.ClientId,
AuthorizationEndpoint: p.AuthorizationEndpoint,
@@ -47,6 +64,8 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro
EmailField: p.EmailField,
WellKnown: p.WellKnown,
AuthStyle: p.AuthStyle,
+ AccessPolicy: p.AccessPolicy,
+ AccessDeniedMessage: p.AccessDeniedMessage,
}
}
@@ -96,6 +115,7 @@ func GetCustomOAuthProvider(c *gin.Context) {
type CreateCustomOAuthProviderRequest struct {
Name string `json:"name" binding:"required"`
Slug string `json:"slug" binding:"required"`
+ Icon string `json:"icon"`
Enabled bool `json:"enabled"`
ClientId string `json:"client_id" binding:"required"`
ClientSecret string `json:"client_secret" binding:"required"`
@@ -109,6 +129,85 @@ type CreateCustomOAuthProviderRequest struct {
EmailField string `json:"email_field"`
WellKnown string `json:"well_known"`
AuthStyle int `json:"auth_style"`
+ AccessPolicy string `json:"access_policy"`
+ AccessDeniedMessage string `json:"access_denied_message"`
+}
+
+type FetchCustomOAuthDiscoveryRequest struct {
+ WellKnownURL string `json:"well_known_url"`
+ IssuerURL string `json:"issuer_url"`
+}
+
+// FetchCustomOAuthDiscovery fetches OIDC discovery document via backend (root-only route)
+func FetchCustomOAuthDiscovery(c *gin.Context) {
+ var req FetchCustomOAuthDiscoveryRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
+ return
+ }
+
+ wellKnownURL := strings.TrimSpace(req.WellKnownURL)
+ issuerURL := strings.TrimSpace(req.IssuerURL)
+
+ if wellKnownURL == "" && issuerURL == "" {
+ common.ApiErrorMsg(c, "请先填写 Discovery URL 或 Issuer URL")
+ return
+ }
+
+ targetURL := wellKnownURL
+ if targetURL == "" {
+ targetURL = strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration"
+ }
+ targetURL = strings.TrimSpace(targetURL)
+
+ parsedURL, err := url.Parse(targetURL)
+ if err != nil || parsedURL.Host == "" || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") {
+ common.ApiErrorMsg(c, "Discovery URL 无效,仅支持 http/https")
+ return
+ }
+
+ ctx, cancel := context.WithTimeout(c.Request.Context(), 20*time.Second)
+ defer cancel()
+
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil)
+ if err != nil {
+ common.ApiErrorMsg(c, "创建 Discovery 请求失败: "+err.Error())
+ return
+ }
+ httpReq.Header.Set("Accept", "application/json")
+
+ client := &http.Client{Timeout: 20 * time.Second}
+ resp, err := client.Do(httpReq)
+ if err != nil {
+ common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+err.Error())
+ return
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
+ message := strings.TrimSpace(string(body))
+ if message == "" {
+ message = resp.Status
+ }
+ common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+message)
+ return
+ }
+
+ var discovery map[string]any
+ if err = common.DecodeJson(resp.Body, &discovery); err != nil {
+ common.ApiErrorMsg(c, "解析 Discovery 配置失败: "+err.Error())
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": gin.H{
+ "well_known_url": targetURL,
+ "discovery": discovery,
+ },
+ })
}
// CreateCustomOAuthProvider creates a new custom OAuth provider
@@ -134,6 +233,7 @@ func CreateCustomOAuthProvider(c *gin.Context) {
provider := &model.CustomOAuthProvider{
Name: req.Name,
Slug: req.Slug,
+ Icon: req.Icon,
Enabled: req.Enabled,
ClientId: req.ClientId,
ClientSecret: req.ClientSecret,
@@ -147,6 +247,8 @@ func CreateCustomOAuthProvider(c *gin.Context) {
EmailField: req.EmailField,
WellKnown: req.WellKnown,
AuthStyle: req.AuthStyle,
+ AccessPolicy: req.AccessPolicy,
+ AccessDeniedMessage: req.AccessDeniedMessage,
}
if err := model.CreateCustomOAuthProvider(provider); err != nil {
@@ -168,9 +270,10 @@ func CreateCustomOAuthProvider(c *gin.Context) {
type UpdateCustomOAuthProviderRequest struct {
Name string `json:"name"`
Slug string `json:"slug"`
- Enabled *bool `json:"enabled"` // Optional: if nil, keep existing
+ Icon *string `json:"icon"` // Optional: if nil, keep existing
+ Enabled *bool `json:"enabled"` // Optional: if nil, keep existing
ClientId string `json:"client_id"`
- ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
+ ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"user_info_endpoint"`
@@ -181,6 +284,8 @@ type UpdateCustomOAuthProviderRequest struct {
EmailField string `json:"email_field"`
WellKnown *string `json:"well_known"` // Optional: if nil, keep existing
AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing
+ AccessPolicy *string `json:"access_policy"` // Optional: if nil, keep existing
+ AccessDeniedMessage *string `json:"access_denied_message"` // Optional: if nil, keep existing
}
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
@@ -227,6 +332,9 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
if req.Slug != "" {
provider.Slug = req.Slug
}
+ if req.Icon != nil {
+ provider.Icon = *req.Icon
+ }
if req.Enabled != nil {
provider.Enabled = *req.Enabled
}
@@ -266,6 +374,12 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
if req.AuthStyle != nil {
provider.AuthStyle = *req.AuthStyle
}
+ if req.AccessPolicy != nil {
+ provider.AccessPolicy = *req.AccessPolicy
+ }
+ if req.AccessDeniedMessage != nil {
+ provider.AccessDeniedMessage = *req.AccessDeniedMessage
+ }
if err := model.UpdateCustomOAuthProvider(provider); err != nil {
common.ApiError(c, err)
@@ -327,6 +441,30 @@ func DeleteCustomOAuthProvider(c *gin.Context) {
})
}
+func buildUserOAuthBindingsResponse(userId int) ([]UserOAuthBindingResponse, error) {
+ bindings, err := model.GetUserOAuthBindingsByUserId(userId)
+ if err != nil {
+ return nil, err
+ }
+
+ response := make([]UserOAuthBindingResponse, 0, len(bindings))
+ for _, binding := range bindings {
+ provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
+ if err != nil {
+ continue
+ }
+ response = append(response, UserOAuthBindingResponse{
+ ProviderId: binding.ProviderId,
+ ProviderName: provider.Name,
+ ProviderSlug: provider.Slug,
+ ProviderIcon: provider.Icon,
+ ProviderUserId: binding.ProviderUserId,
+ })
+ }
+
+ return response, nil
+}
+
// GetUserOAuthBindings returns all OAuth bindings for the current user
func GetUserOAuthBindings(c *gin.Context) {
userId := c.GetInt("id")
@@ -335,32 +473,43 @@ func GetUserOAuthBindings(c *gin.Context) {
return
}
- bindings, err := model.GetUserOAuthBindingsByUserId(userId)
+ response, err := buildUserOAuthBindingsResponse(userId)
if err != nil {
common.ApiError(c, err)
return
}
- // Build response with provider info
- type BindingResponse struct {
- ProviderId int `json:"provider_id"`
- ProviderName string `json:"provider_name"`
- ProviderSlug string `json:"provider_slug"`
- ProviderUserId string `json:"provider_user_id"`
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": response,
+ })
+}
+
+func GetUserOAuthBindingsByAdmin(c *gin.Context) {
+ userIdStr := c.Param("id")
+ userId, err := strconv.Atoi(userIdStr)
+ if err != nil {
+ common.ApiErrorMsg(c, "invalid user id")
+ return
}
- response := make([]BindingResponse, 0)
- for _, binding := range bindings {
- provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
- if err != nil {
- continue // Skip if provider not found
- }
- response = append(response, BindingResponse{
- ProviderId: binding.ProviderId,
- ProviderName: provider.Name,
- ProviderSlug: provider.Slug,
- ProviderUserId: binding.ProviderUserId,
- })
+ targetUser, err := model.GetUserById(userId, false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ myRole := c.GetInt("role")
+ if myRole <= targetUser.Role && myRole != common.RoleRootUser {
+ common.ApiErrorMsg(c, "no permission")
+ return
+ }
+
+ response, err := buildUserOAuthBindingsResponse(userId)
+ if err != nil {
+ common.ApiError(c, err)
+ return
}
c.JSON(http.StatusOK, gin.H{
@@ -395,3 +544,41 @@ func UnbindCustomOAuth(c *gin.Context) {
"message": "解绑成功",
})
}
+
+func UnbindCustomOAuthByAdmin(c *gin.Context) {
+ userIdStr := c.Param("id")
+ userId, err := strconv.Atoi(userIdStr)
+ if err != nil {
+ common.ApiErrorMsg(c, "invalid user id")
+ return
+ }
+
+ targetUser, err := model.GetUserById(userId, false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ myRole := c.GetInt("role")
+ if myRole <= targetUser.Role && myRole != common.RoleRootUser {
+ common.ApiErrorMsg(c, "no permission")
+ return
+ }
+
+ providerIdStr := c.Param("provider_id")
+ providerId, err := strconv.Atoi(providerIdStr)
+ if err != nil {
+ common.ApiErrorMsg(c, "invalid provider id")
+ return
+ }
+
+ if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "success",
+ })
+}
diff --git a/controller/midjourney.go b/controller/midjourney.go
index c480c12bb..69aa5ccd4 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -105,13 +105,13 @@ func UpdateMidjourneyTaskBulk() {
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
+ logger.LogError(ctx, fmt.Sprintf("Get Mjp Task parse body error: %v", err))
continue
}
var responseItems []dto.MidjourneyDto
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
- logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+ logger.LogError(ctx, fmt.Sprintf("Get Mjp Task parse body error2: %v, body: %s", err, string(responseBody)))
continue
}
resp.Body.Close()
@@ -130,6 +130,7 @@ func UpdateMidjourneyTaskBulk() {
if !checkMjTaskNeedUpdate(task, responseItem) {
continue
}
+ preStatus := task.Status
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
@@ -172,18 +173,26 @@ func UpdateMidjourneyTaskBulk() {
shouldReturnQuota = true
}
}
- err = task.Update()
+ won, err := task.UpdateWithStatus(preStatus)
if err != nil {
logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
- } else {
- if shouldReturnQuota {
- err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
- if err != nil {
- logger.LogError(ctx, "fail to increase user quota: "+err.Error())
- }
- logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+ } else if won && shouldReturnQuota {
+ err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
+ if err != nil {
+ logger.LogError(ctx, "fail to increase user quota: "+err.Error())
}
+ model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
+ UserId: task.UserId,
+ LogType: model.LogTypeRefund,
+ Content: "",
+ ChannelId: task.ChannelId,
+ ModelName: service.CovertMjpActionToModelName(task.Action),
+ Quota: task.Quota,
+ Other: map[string]interface{}{
+ "task_id": task.MjId,
+ "reason": "构图失败",
+ },
+ })
}
}
}
diff --git a/controller/misc.go b/controller/misc.go
index a16e2d554..b24a74adf 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -134,8 +134,10 @@ func GetStatus(c *gin.Context) {
customProviders := oauth.GetEnabledCustomProviders()
if len(customProviders) > 0 {
type CustomOAuthInfo struct {
+ Id int `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
+ Icon string `json:"icon"`
ClientId string `json:"client_id"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
Scopes string `json:"scopes"`
@@ -144,8 +146,10 @@ func GetStatus(c *gin.Context) {
for _, p := range customProviders {
config := p.GetConfig()
providersInfo = append(providersInfo, CustomOAuthInfo{
+ Id: config.Id,
Name: config.Name,
Slug: config.Slug,
+ Icon: config.Icon,
ClientId: config.ClientId,
AuthorizationEndpoint: config.AuthorizationEndpoint,
Scopes: config.Scopes,
diff --git a/controller/oauth.go b/controller/oauth.go
index 65e18f9da..818a28f84 100644
--- a/controller/oauth.go
+++ b/controller/oauth.go
@@ -237,6 +237,16 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
// Set up new user
user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1)
+
+ if oauthUser.Username != "" {
+ if exists, err := model.CheckUserExistOrDeleted(oauthUser.Username, ""); err == nil && !exists {
+ // 防止索引退化
+ if len(oauthUser.Username) <= model.UserNameMaxLength {
+ user.Username = oauthUser.Username
+ }
+ }
+ }
+
if oauthUser.DisplayName != "" {
user.DisplayName = oauthUser.DisplayName
} else if oauthUser.Username != "" {
@@ -295,12 +305,12 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
// Set the provider user ID on the user model and update
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
if err := tx.Model(user).Updates(map[string]interface{}{
- "github_id": user.GitHubId,
- "discord_id": user.DiscordId,
- "oidc_id": user.OidcId,
- "linux_do_id": user.LinuxDOId,
- "wechat_id": user.WeChatId,
- "telegram_id": user.TelegramId,
+ "github_id": user.GitHubId,
+ "discord_id": user.DiscordId,
+ "oidc_id": user.OidcId,
+ "linux_do_id": user.LinuxDOId,
+ "wechat_id": user.WeChatId,
+ "telegram_id": user.TelegramId,
}).Error; err != nil {
return err
}
@@ -340,6 +350,8 @@ func handleOAuthError(c *gin.Context, err error) {
} else {
common.ApiErrorI18n(c, e.MsgKey)
}
+ case *oauth.AccessDeniedError:
+ common.ApiErrorMsg(c, e.Message)
case *oauth.TrustLevelError:
common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow)
default:
diff --git a/controller/relay.go b/controller/relay.go
index e3e92bc51..1788b25b7 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -455,72 +455,147 @@ func RelayNotFound(c *gin.Context) {
})
}
-func RelayTask(c *gin.Context) {
- retryTimes := common.RetryTimes
- channelId := c.GetInt("channel_id")
- c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
+func RelayTaskFetch(c *gin.Context) {
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
if err != nil {
+ c.JSON(http.StatusInternalServerError, &dto.TaskError{
+ Code: "gen_relay_info_failed",
+ Message: err.Error(),
+ StatusCode: http.StatusInternalServerError,
+ })
return
}
- taskErr := taskRelayHandler(c, relayInfo)
- if taskErr == nil {
- retryTimes = 0
+ if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil {
+ respondTaskError(c, taskErr)
}
+}
+
+func RelayTask(c *gin.Context) {
+ relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, &dto.TaskError{
+ Code: "gen_relay_info_failed",
+ Message: err.Error(),
+ StatusCode: http.StatusInternalServerError,
+ })
+ return
+ }
+
+ if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil {
+ respondTaskError(c, taskErr)
+ return
+ }
+
+ var result *relay.TaskSubmitResult
+ var taskErr *dto.TaskError
+ defer func() {
+ if taskErr != nil && relayInfo.Billing != nil {
+ relayInfo.Billing.Refund(c)
+ }
+ }()
+
retryParam := &service.RetryParam{
Ctx: c,
TokenGroup: relayInfo.TokenGroup,
ModelName: relayInfo.OriginModelName,
Retry: common.GetPointer(0),
}
- for ; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && retryParam.GetRetry() < retryTimes; retryParam.IncreaseRetry() {
- channel, newAPIError := getChannel(c, relayInfo, retryParam)
- if newAPIError != nil {
- logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
- taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
- break
- }
- channelId = channel.Id
- useChannel := c.GetStringSlice("use_channel")
- useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
- c.Set("use_channel", useChannel)
- logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
- //middleware.SetupContextForSelectedChannel(c, channel, originalModel)
- bodyStorage, err := common.GetBodyStorage(c)
- if err != nil {
- if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
- taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge)
+ for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
+ var channel *model.Channel
+
+ if lockedCh, ok := relayInfo.LockedChannel.(*model.Channel); ok && lockedCh != nil {
+ channel = lockedCh
+ if retryParam.GetRetry() > 0 {
+ if setupErr := middleware.SetupContextForSelectedChannel(c, channel, relayInfo.OriginModelName); setupErr != nil {
+ taskErr = service.TaskErrorWrapperLocal(setupErr.Err, "setup_locked_channel_failed", http.StatusInternalServerError)
+ break
+ }
+ }
+ } else {
+ var channelErr *types.NewAPIError
+ channel, channelErr = getChannel(c, relayInfo, retryParam)
+ if channelErr != nil {
+ logger.LogError(c, channelErr.Error())
+ taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError)
+ break
+ }
+ }
+
+ addUsedChannel(c, channel.Id)
+ bodyStorage, bodyErr := common.GetBodyStorage(c)
+ if bodyErr != nil {
+ if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
+ taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge)
} else {
- taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusBadRequest)
+ taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusBadRequest)
}
break
}
c.Request.Body = io.NopCloser(bodyStorage)
- taskErr = taskRelayHandler(c, relayInfo)
+
+ result, taskErr = relay.RelayTaskSubmit(c, relayInfo)
+ if taskErr == nil {
+ break
+ }
+
+ if !taskErr.LocalError {
+ processChannelError(c,
+ *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey,
+ common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()),
+ types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode))
+ }
+
+ if !shouldRetryTaskRelay(c, channel.Id, taskErr, common.RetryTimes-retryParam.GetRetry()) {
+ break
+ }
}
+
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
logger.LogInfo(c, retryLogStr)
}
- if taskErr != nil {
- if taskErr.StatusCode == http.StatusTooManyRequests {
- taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
+
+ // ── 成功:结算 + 日志 + 插入任务 ──
+ if taskErr == nil {
+ if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil {
+ common.SysError("settle task billing error: " + settleErr.Error())
}
- c.JSON(taskErr.StatusCode, taskErr)
+ service.LogTaskConsumption(c, relayInfo)
+
+ task := model.InitTask(result.Platform, relayInfo)
+ task.PrivateData.UpstreamTaskID = result.UpstreamTaskID
+ task.PrivateData.BillingSource = relayInfo.BillingSource
+ task.PrivateData.SubscriptionId = relayInfo.SubscriptionId
+ task.PrivateData.TokenId = relayInfo.TokenId
+ task.PrivateData.BillingContext = &model.TaskBillingContext{
+ ModelPrice: relayInfo.PriceData.ModelPrice,
+ GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio,
+ ModelRatio: relayInfo.PriceData.ModelRatio,
+ OtherRatios: relayInfo.PriceData.OtherRatios,
+ OriginModelName: relayInfo.OriginModelName,
+ PerCallBilling: common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName),
+ }
+ task.Quota = result.Quota
+ task.Data = result.TaskData
+ task.Action = relayInfo.Action
+ if insertErr := task.Insert(); insertErr != nil {
+ common.SysError("insert task error: " + insertErr.Error())
+ }
+ }
+
+ if taskErr != nil {
+ respondTaskError(c, taskErr)
}
}
-func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError {
- var err *dto.TaskError
- switch relayInfo.RelayMode {
- case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
- err = relay.RelayTaskFetch(c, relayInfo.RelayMode)
- default:
- err = relay.RelayTaskSubmit(c, relayInfo)
+// respondTaskError 统一输出 Task 错误响应(含 429 限流提示改写)
+func respondTaskError(c *gin.Context, taskErr *dto.TaskError) {
+ if taskErr.StatusCode == http.StatusTooManyRequests {
+ taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
}
- return err
+ c.JSON(taskErr.StatusCode, taskErr)
}
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
@@ -544,7 +619,7 @@ func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError,
}
if taskErr.StatusCode/100 == 5 {
// 超时不重试
- if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
+ if operation_setting.IsAlwaysSkipRetryStatusCode(taskErr.StatusCode) {
return false
}
return true
diff --git a/controller/task.go b/controller/task.go
index 244f9161c..eac7db153 100644
--- a/controller/task.go
+++ b/controller/task.go
@@ -1,231 +1,22 @@
package controller
import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "sort"
"strconv"
- "time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
- "github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay"
+ "github.com/QuantumNous/new-api/service"
+ "github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
- "github.com/samber/lo"
)
+// UpdateTaskBulk 薄入口,实际轮询逻辑在 service 层
func UpdateTaskBulk() {
- //revocer
- //imageModel := "midjourney"
- for {
- time.Sleep(time.Duration(15) * time.Second)
- common.SysLog("任务进度轮询开始")
- ctx := context.TODO()
- allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
- platformTask := make(map[constant.TaskPlatform][]*model.Task)
- for _, t := range allTasks {
- platformTask[t.Platform] = append(platformTask[t.Platform], t)
- }
- for platform, tasks := range platformTask {
- if len(tasks) == 0 {
- continue
- }
- taskChannelM := make(map[int][]string)
- taskM := make(map[string]*model.Task)
- nullTaskIds := make([]int64, 0)
- for _, task := range tasks {
- if task.TaskID == "" {
- // 统计失败的未完成任务
- nullTaskIds = append(nullTaskIds, task.ID)
- continue
- }
- taskM[task.TaskID] = task
- taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
- }
- if len(nullTaskIds) > 0 {
- err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
- "status": "FAILURE",
- "progress": "100%",
- })
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
- } else {
- logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
- }
- }
- if len(taskChannelM) == 0 {
- continue
- }
-
- UpdateTaskByPlatform(platform, taskChannelM, taskM)
- }
- common.SysLog("任务进度轮询完成")
- }
-}
-
-func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
- switch platform {
- case constant.TaskPlatformMidjourney:
- //_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
- case constant.TaskPlatformSuno:
- _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
- default:
- if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
- common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
- }
- }
-}
-
-func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
- for channelId, taskIds := range taskChannelM {
- err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
- }
- }
- return nil
-}
-
-func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
- logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
- if len(taskIds) == 0 {
- return nil
- }
- channel, err := model.CacheGetChannel(channelId)
- if err != nil {
- common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
- err = model.TaskBulkUpdate(taskIds, map[string]any{
- "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
- "status": "FAILURE",
- "progress": "100%",
- })
- if err != nil {
- common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
- }
- return err
- }
- adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
- if adaptor == nil {
- return errors.New("adaptor not found")
- }
- proxy := channel.GetSetting().Proxy
- resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
- "ids": taskIds,
- }, proxy)
- if err != nil {
- common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
- return err
- }
- if resp.StatusCode != http.StatusOK {
- logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
- return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
- }
- defer resp.Body.Close()
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
- return err
- }
- var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
- err = json.Unmarshal(responseBody, &responseItems)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
- return err
- }
- if !responseItems.IsSuccess() {
- common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
- return err
- }
-
- for _, responseItem := range responseItems.Data {
- task := taskM[responseItem.TaskID]
- if !checkTaskNeedUpdate(task, responseItem) {
- continue
- }
-
- task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
- task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
- task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
- task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
- task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
- if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
- logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
- task.Progress = "100%"
- //err = model.CacheUpdateUserQuota(task.UserId) ?
- if err != nil {
- logger.LogError(ctx, "error update user quota cache: "+err.Error())
- } else {
- quota := task.Quota
- if quota != 0 {
- err = model.IncreaseUserQuota(task.UserId, quota, false)
- if err != nil {
- logger.LogError(ctx, "fail to increase user quota: "+err.Error())
- }
- logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
- }
- }
- }
- if responseItem.Status == model.TaskStatusSuccess {
- task.Progress = "100%"
- }
- task.Data = responseItem.Data
-
- err = task.Update()
- if err != nil {
- common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
- }
- }
- return nil
-}
-
-func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
-
- if oldTask.SubmitTime != newTask.SubmitTime {
- return true
- }
- if oldTask.StartTime != newTask.StartTime {
- return true
- }
- if oldTask.FinishTime != newTask.FinishTime {
- return true
- }
- if string(oldTask.Status) != newTask.Status {
- return true
- }
- if oldTask.FailReason != newTask.FailReason {
- return true
- }
- if oldTask.FinishTime != newTask.FinishTime {
- return true
- }
-
- if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
- return true
- }
-
- oldData, _ := json.Marshal(oldTask.Data)
- newData, _ := json.Marshal(newTask.Data)
-
- sort.Slice(oldData, func(i, j int) bool {
- return oldData[i] < oldData[j]
- })
- sort.Slice(newData, func(i, j int) bool {
- return newData[i] < newData[j]
- })
-
- if string(oldData) != string(newData) {
- return true
- }
- return false
+ service.TaskPollingLoop()
}
func GetAllTask(c *gin.Context) {
@@ -247,7 +38,7 @@ func GetAllTask(c *gin.Context) {
items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
total := model.TaskCountAllTasks(queryParams)
pageInfo.SetTotal(int(total))
- pageInfo.SetItems(items)
+ pageInfo.SetItems(tasksToDto(items, true))
common.ApiSuccess(c, pageInfo)
}
@@ -271,6 +62,33 @@ func GetUserTask(c *gin.Context) {
items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
total := model.TaskCountAllUserTask(userId, queryParams)
pageInfo.SetTotal(int(total))
- pageInfo.SetItems(items)
+ pageInfo.SetItems(tasksToDto(items, false))
common.ApiSuccess(c, pageInfo)
}
+
+func tasksToDto(tasks []*model.Task, fillUser bool) []*dto.TaskDto {
+ var userIdMap map[int]*model.UserBase
+ if fillUser {
+ userIdMap = make(map[int]*model.UserBase)
+ userIds := types.NewSet[int]()
+ for _, task := range tasks {
+ userIds.Add(task.UserId)
+ }
+ for _, userId := range userIds.Items() {
+ cacheUser, err := model.GetUserCache(userId)
+ if err == nil {
+ userIdMap[userId] = cacheUser
+ }
+ }
+ }
+ result := make([]*dto.TaskDto, len(tasks))
+ for i, task := range tasks {
+ if fillUser {
+ if user, ok := userIdMap[task.UserId]; ok {
+ task.Username = user.Username
+ }
+ }
+ result[i] = relay.TaskModel2Dto(task)
+ }
+ return result
+}
diff --git a/controller/task_video.go b/controller/task_video.go
deleted file mode 100644
index d7c19e620..000000000
--- a/controller/task_video.go
+++ /dev/null
@@ -1,313 +0,0 @@
-package controller
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "io"
- "time"
-
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/dto"
- "github.com/QuantumNous/new-api/logger"
- "github.com/QuantumNous/new-api/model"
- "github.com/QuantumNous/new-api/relay"
- "github.com/QuantumNous/new-api/relay/channel"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/QuantumNous/new-api/setting/ratio_setting"
-)
-
-func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
- for channelId, taskIds := range taskChannelM {
- if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
- logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
- }
- }
- return nil
-}
-
-func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
- logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
- if len(taskIds) == 0 {
- return nil
- }
- cacheGetChannel, err := model.CacheGetChannel(channelId)
- if err != nil {
- errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
- "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
- "status": "FAILURE",
- "progress": "100%",
- })
- if errUpdate != nil {
- common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
- }
- return fmt.Errorf("CacheGetChannel failed: %w", err)
- }
- adaptor := relay.GetTaskAdaptor(platform)
- if adaptor == nil {
- return fmt.Errorf("video adaptor not found")
- }
- info := &relaycommon.RelayInfo{}
- info.ChannelMeta = &relaycommon.ChannelMeta{
- ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
- }
- info.ApiKey = cacheGetChannel.Key
- adaptor.Init(info)
- for _, taskId := range taskIds {
- if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
- logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
- }
- }
- return nil
-}
-
-func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
- baseURL := constant.ChannelBaseURLs[channel.Type]
- if channel.GetBaseURL() != "" {
- baseURL = channel.GetBaseURL()
- }
- proxy := channel.GetSetting().Proxy
-
- task := taskM[taskId]
- if task == nil {
- logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
- return fmt.Errorf("task %s not found", taskId)
- }
- key := channel.Key
-
- privateData := task.PrivateData
- if privateData.Key != "" {
- key = privateData.Key
- }
- resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
- "task_id": taskId,
- "action": task.Action,
- }, proxy)
- if err != nil {
- return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
- }
- //if resp.StatusCode != http.StatusOK {
- //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
- //}
- defer resp.Body.Close()
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
- }
-
- logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
-
- taskResult := &relaycommon.TaskInfo{}
- // try parse as New API response format
- var responseItems dto.TaskResponse[model.Task]
- if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
- logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
- t := responseItems.Data
- taskResult.TaskID = t.TaskID
- taskResult.Status = string(t.Status)
- taskResult.Url = t.FailReason
- taskResult.Progress = t.Progress
- taskResult.Reason = t.FailReason
- task.Data = t.Data
- } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
- return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
- } else {
- task.Data = redactVideoResponseBody(responseBody)
- }
-
- logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
-
- now := time.Now().Unix()
- if taskResult.Status == "" {
- //return fmt.Errorf("task %s status is empty", taskId)
- taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
- }
-
- // 记录原本的状态,防止重复退款
- shouldRefund := false
- quota := task.Quota
- preStatus := task.Status
-
- task.Status = model.TaskStatus(taskResult.Status)
- switch taskResult.Status {
- case model.TaskStatusSubmitted:
- task.Progress = "10%"
- case model.TaskStatusQueued:
- task.Progress = "20%"
- case model.TaskStatusInProgress:
- task.Progress = "30%"
- if task.StartTime == 0 {
- task.StartTime = now
- }
- case model.TaskStatusSuccess:
- task.Progress = "100%"
- if task.FinishTime == 0 {
- task.FinishTime = now
- }
- if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
- task.FailReason = taskResult.Url
- }
-
- // 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
- if taskResult.TotalTokens > 0 {
- // 获取模型名称
- var taskData map[string]interface{}
- if err := json.Unmarshal(task.Data, &taskData); err == nil {
- if modelName, ok := taskData["model"].(string); ok && modelName != "" {
- // 获取模型价格和倍率
- modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
- // 只有配置了倍率(非固定价格)时才按 token 重新计费
- if hasRatioSetting && modelRatio > 0 {
- // 获取用户和组的倍率信息
- group := task.Group
- if group == "" {
- user, err := model.GetUserById(task.UserId, false)
- if err == nil {
- group = user.Group
- }
- }
- if group != "" {
- groupRatio := ratio_setting.GetGroupRatio(group)
- userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
-
- var finalGroupRatio float64
- if hasUserGroupRatio {
- finalGroupRatio = userGroupRatio
- } else {
- finalGroupRatio = groupRatio
- }
-
- // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
- actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
-
- // 计算差额
- preConsumedQuota := task.Quota
- quotaDelta := actualQuota - preConsumedQuota
-
- if quotaDelta > 0 {
- // 需要补扣费
- logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
- task.TaskID,
- logger.LogQuota(quotaDelta),
- logger.LogQuota(actualQuota),
- logger.LogQuota(preConsumedQuota),
- taskResult.TotalTokens,
- ))
- if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil {
- logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
- } else {
- model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
- model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
- task.Quota = actualQuota // 更新任务记录的实际扣费额度
-
- // 记录消费日志
- logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s",
- modelRatio, finalGroupRatio, taskResult.TotalTokens,
- logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
- }
- } else if quotaDelta < 0 {
- // 需要退还多扣的费用
- refundQuota := -quotaDelta
- logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
- task.TaskID,
- logger.LogQuota(refundQuota),
- logger.LogQuota(actualQuota),
- logger.LogQuota(preConsumedQuota),
- taskResult.TotalTokens,
- ))
- if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
- logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
- } else {
- task.Quota = actualQuota // 更新任务记录的实际扣费额度
-
- // 记录退款日志
- logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s",
- modelRatio, finalGroupRatio, taskResult.TotalTokens,
- logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
- }
- } else {
- // quotaDelta == 0, 预扣费刚好准确
- logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
- task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
- }
- }
- }
- }
- }
- }
- case model.TaskStatusFailure:
- logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
- task.Status = model.TaskStatusFailure
- task.Progress = "100%"
- if task.FinishTime == 0 {
- task.FinishTime = now
- }
- task.FailReason = taskResult.Reason
- logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
- taskResult.Progress = "100%"
- if quota != 0 {
- if preStatus != model.TaskStatusFailure {
- shouldRefund = true
- } else {
- logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
- }
- }
- default:
- return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
- }
- if taskResult.Progress != "" {
- task.Progress = taskResult.Progress
- }
- if err := task.Update(); err != nil {
- common.SysLog("UpdateVideoTask task error: " + err.Error())
- shouldRefund = false
- }
-
- if shouldRefund {
- // 任务失败且之前状态不是失败才退还额度,防止重复退还
- if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
- logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
- }
- logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
- }
-
- return nil
-}
-
-func redactVideoResponseBody(body []byte) []byte {
- var m map[string]any
- if err := json.Unmarshal(body, &m); err != nil {
- return body
- }
- resp, _ := m["response"].(map[string]any)
- if resp != nil {
- delete(resp, "bytesBase64Encoded")
- if v, ok := resp["video"].(string); ok {
- resp["video"] = truncateBase64(v)
- }
- if vs, ok := resp["videos"].([]any); ok {
- for i := range vs {
- if vm, ok := vs[i].(map[string]any); ok {
- delete(vm, "bytesBase64Encoded")
- }
- }
- }
- }
- b, err := json.Marshal(m)
- if err != nil {
- return body
- }
- return b
-}
-
-func truncateBase64(s string) string {
- const maxKeep = 256
- if len(s) <= maxKeep {
- return s
- }
- return s[:maxKeep] + "..."
-}
diff --git a/controller/user.go b/controller/user.go
index db078071e..b58eab88f 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -582,6 +582,44 @@ func UpdateUser(c *gin.Context) {
return
}
+func AdminClearUserBinding(c *gin.Context) {
+ id, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ common.ApiErrorI18n(c, i18n.MsgInvalidParams)
+ return
+ }
+
+ bindingType := strings.ToLower(strings.TrimSpace(c.Param("binding_type")))
+ if bindingType == "" {
+ common.ApiErrorI18n(c, i18n.MsgInvalidParams)
+ return
+ }
+
+ user, err := model.GetUserById(id, false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ myRole := c.GetInt("role")
+ if myRole <= user.Role && myRole != common.RoleRootUser {
+ common.ApiErrorI18n(c, i18n.MsgUserNoPermissionSameLevel)
+ return
+ }
+
+ if err := user.ClearBinding(bindingType); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ model.RecordLog(user.Id, model.LogTypeManage, fmt.Sprintf("admin cleared %s binding for user %s", bindingType, user.Username))
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "success",
+ })
+}
+
func UpdateSelf(c *gin.Context) {
var requestData map[string]interface{}
err := json.NewDecoder(c.Request.Body).Decode(&requestData)
diff --git a/controller/video_proxy.go b/controller/video_proxy.go
index f102baae4..f1dd2bc92 100644
--- a/controller/video_proxy.go
+++ b/controller/video_proxy.go
@@ -16,59 +16,44 @@ import (
"github.com/gin-gonic/gin"
)
+// videoProxyError returns a standardized OpenAI-style error response.
+func videoProxyError(c *gin.Context, status int, errType, message string) {
+ c.JSON(status, gin.H{
+ "error": gin.H{
+ "message": message,
+ "type": errType,
+ },
+ })
+}
+
func VideoProxy(c *gin.Context) {
taskID := c.Param("task_id")
if taskID == "" {
- c.JSON(http.StatusBadRequest, gin.H{
- "error": gin.H{
- "message": "task_id is required",
- "type": "invalid_request_error",
- },
- })
+ videoProxyError(c, http.StatusBadRequest, "invalid_request_error", "task_id is required")
return
}
task, exists, err := model.GetByOnlyTaskId(taskID)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": gin.H{
- "message": "Failed to query task",
- "type": "server_error",
- },
- })
+ videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to query task")
return
}
if !exists || task == nil {
- logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err))
- c.JSON(http.StatusNotFound, gin.H{
- "error": gin.H{
- "message": "Task not found",
- "type": "invalid_request_error",
- },
- })
+ videoProxyError(c, http.StatusNotFound, "invalid_request_error", "Task not found")
return
}
if task.Status != model.TaskStatusSuccess {
- c.JSON(http.StatusBadRequest, gin.H{
- "error": gin.H{
- "message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status),
- "type": "invalid_request_error",
- },
- })
+ videoProxyError(c, http.StatusBadRequest, "invalid_request_error",
+ fmt.Sprintf("Task is not completed yet, current status: %s", task.Status))
return
}
channel, err := model.CacheGetChannel(task.ChannelId)
if err != nil {
- logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID))
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": gin.H{
- "message": "Failed to retrieve channel information",
- "type": "server_error",
- },
- })
+ logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel for task %s: %s", taskID, err.Error()))
+ videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to retrieve channel information")
return
}
baseURL := channel.GetBaseURL()
@@ -81,12 +66,7 @@ func VideoProxy(c *gin.Context) {
client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error()))
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": gin.H{
- "message": "Failed to create proxy client",
- "type": "server_error",
- },
- })
+ videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy client")
return
}
@@ -95,12 +75,7 @@ func VideoProxy(c *gin.Context) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": gin.H{
- "message": "Failed to create proxy request",
- "type": "server_error",
- },
- })
+ videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
return
}
@@ -109,68 +84,43 @@ func VideoProxy(c *gin.Context) {
apiKey := task.PrivateData.Key
if apiKey == "" {
logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": gin.H{
- "message": "API key not stored for task",
- "type": "server_error",
- },
- })
+ videoProxyError(c, http.StatusInternalServerError, "server_error", "API key not stored for task")
return
}
-
videoURL, err = getGeminiVideoURL(channel, task, apiKey)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error()))
- c.JSON(http.StatusBadGateway, gin.H{
- "error": gin.H{
- "message": "Failed to resolve Gemini video URL",
- "type": "server_error",
- },
- })
+ videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Gemini video URL")
return
}
req.Header.Set("x-goog-api-key", apiKey)
case constant.ChannelTypeOpenAI, constant.ChannelTypeSora:
- videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
+ videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID())
req.Header.Set("Authorization", "Bearer "+channel.Key)
default:
- // Video URL is directly in task.FailReason
- videoURL = task.FailReason
+ // Video URL is stored in PrivateData.ResultURL (fallback to FailReason for old data)
+ videoURL = task.GetResultURL()
}
req.URL, err = url.Parse(videoURL)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": gin.H{
- "message": "Failed to create proxy request",
- "type": "server_error",
- },
- })
+ videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
return
}
resp, err := client.Do(req)
if err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
- c.JSON(http.StatusBadGateway, gin.H{
- "error": gin.H{
- "message": "Failed to fetch video content",
- "type": "server_error",
- },
- })
+ videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
- c.JSON(http.StatusBadGateway, gin.H{
- "error": gin.H{
- "message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode),
- "type": "server_error",
- },
- })
+ videoProxyError(c, http.StatusBadGateway, "server_error",
+ fmt.Sprintf("Upstream service returned status %d", resp.StatusCode))
return
}
@@ -180,10 +130,9 @@ func VideoProxy(c *gin.Context) {
}
}
- c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours
+ c.Writer.Header().Set("Cache-Control", "public, max-age=86400")
c.Writer.WriteHeader(resp.StatusCode)
- _, err = io.Copy(c.Writer, resp.Body)
- if err != nil {
+ if _, err = io.Copy(c.Writer, resp.Body); err != nil {
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
}
}
diff --git a/controller/video_proxy_gemini.go b/controller/video_proxy_gemini.go
index 053ac6515..a63a2a5c4 100644
--- a/controller/video_proxy_gemini.go
+++ b/controller/video_proxy_gemini.go
@@ -1,12 +1,12 @@
package controller
import (
- "encoding/json"
"fmt"
"io"
"strconv"
"strings"
+ "github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay"
@@ -37,7 +37,7 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string)
proxy := channel.GetSetting().Proxy
resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{
- "task_id": task.TaskID,
+ "task_id": task.GetUpstreamTaskID(),
"action": task.Action,
}, proxy)
if err != nil {
@@ -71,7 +71,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string {
return ""
}
var payload map[string]any
- if err := json.Unmarshal(task.Data, &payload); err != nil {
+ if err := common.Unmarshal(task.Data, &payload); err != nil {
return ""
}
return extractGeminiVideoURLFromMap(payload)
@@ -79,7 +79,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string {
func extractGeminiVideoURLFromPayload(body []byte) string {
var payload map[string]any
- if err := json.Unmarshal(body, &payload); err != nil {
+ if err := common.Unmarshal(body, &payload); err != nil {
return ""
}
return extractGeminiVideoURLFromMap(payload)
diff --git a/dto/channel_settings.go b/dto/channel_settings.go
index 74bceb281..72fdf460c 100644
--- a/dto/channel_settings.go
+++ b/dto/channel_settings.go
@@ -24,14 +24,16 @@ const (
)
type ChannelOtherSettings struct {
- AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
- VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
- OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"`
- ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true
- AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费)
- DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
- AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
- AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"`
+ AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
+ VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
+ OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"`
+ ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true
+ AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费)
+ AllowInferenceGeo bool `json:"allow_inference_geo,omitempty"` // 是否允许 inference_geo 透传(仅 Claude,默认过滤以满足数据驻留合规)
+ DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
+ AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
+ AllowIncludeObfuscation bool `json:"allow_include_obfuscation,omitempty"` // 是否允许 stream_options.include_obfuscation 透传(默认过滤以避免关闭流混淆保护)
+ AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"`
}
func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {
diff --git a/dto/claude.go b/dto/claude.go
index 8b6b495f6..32e31710b 100644
--- a/dto/claude.go
+++ b/dto/claude.go
@@ -190,10 +190,13 @@ type ClaudeToolChoice struct {
}
type ClaudeRequest struct {
- Model string `json:"model"`
- Prompt string `json:"prompt,omitempty"`
- System any `json:"system,omitempty"`
- Messages []ClaudeMessage `json:"messages,omitempty"`
+ Model string `json:"model"`
+ Prompt string `json:"prompt,omitempty"`
+ System any `json:"system,omitempty"`
+ Messages []ClaudeMessage `json:"messages,omitempty"`
+ // InferenceGeo controls Claude data residency region.
+ // This field is filtered by default and can be enabled via channel setting allow_inference_geo.
+ InferenceGeo string `json:"inference_geo,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
@@ -210,7 +213,8 @@ type ClaudeRequest struct {
Thinking *Thinking `json:"thinking,omitempty"`
McpServers json.RawMessage `json:"mcp_servers,omitempty"`
Metadata json.RawMessage `json:"metadata,omitempty"`
- // 服务层级字段,用于指定 API 服务等级。允许透传可能导致实际计费高于预期,默认应过滤
+ // ServiceTier specifies upstream service level and may affect billing.
+ // This field is filtered by default and can be enabled via channel setting allow_service_tier.
ServiceTier string `json:"service_tier,omitempty"`
}
diff --git a/dto/gemini.go b/dto/gemini.go
index 0fd74c639..b97f19ec6 100644
--- a/dto/gemini.go
+++ b/dto/gemini.go
@@ -324,25 +324,26 @@ type GeminiChatTool struct {
}
type GeminiChatGenerationConfig struct {
- Temperature *float64 `json:"temperature,omitempty"`
- TopP float64 `json:"topP,omitempty"`
- TopK float64 `json:"topK,omitempty"`
- MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
- CandidateCount int `json:"candidateCount,omitempty"`
- StopSequences []string `json:"stopSequences,omitempty"`
- ResponseMimeType string `json:"responseMimeType,omitempty"`
- ResponseSchema any `json:"responseSchema,omitempty"`
- ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"`
- PresencePenalty *float32 `json:"presencePenalty,omitempty"`
- FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"`
- ResponseLogprobs bool `json:"responseLogprobs,omitempty"`
- Logprobs *int32 `json:"logprobs,omitempty"`
- MediaResolution MediaResolution `json:"mediaResolution,omitempty"`
- Seed int64 `json:"seed,omitempty"`
- ResponseModalities []string `json:"responseModalities,omitempty"`
- ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
- SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
- ImageConfig json.RawMessage `json:"imageConfig,omitempty"` // RawMessage to allow flexible image config
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+ TopK float64 `json:"topK,omitempty"`
+ MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
+ CandidateCount int `json:"candidateCount,omitempty"`
+ StopSequences []string `json:"stopSequences,omitempty"`
+ ResponseMimeType string `json:"responseMimeType,omitempty"`
+ ResponseSchema any `json:"responseSchema,omitempty"`
+ ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"`
+ PresencePenalty *float32 `json:"presencePenalty,omitempty"`
+ FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"`
+ ResponseLogprobs bool `json:"responseLogprobs,omitempty"`
+ Logprobs *int32 `json:"logprobs,omitempty"`
+ EnableEnhancedCivicAnswers *bool `json:"enableEnhancedCivicAnswers,omitempty"`
+ MediaResolution MediaResolution `json:"mediaResolution,omitempty"`
+ Seed int64 `json:"seed,omitempty"`
+ ResponseModalities []string `json:"responseModalities,omitempty"`
+ ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
+ SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
+ ImageConfig json.RawMessage `json:"imageConfig,omitempty"` // RawMessage to allow flexible image config
}
// UnmarshalJSON allows GeminiChatGenerationConfig to accept both snake_case and camelCase fields.
@@ -350,22 +351,23 @@ func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error {
type Alias GeminiChatGenerationConfig
var aux struct {
Alias
- TopPSnake float64 `json:"top_p,omitempty"`
- TopKSnake float64 `json:"top_k,omitempty"`
- MaxOutputTokensSnake uint `json:"max_output_tokens,omitempty"`
- CandidateCountSnake int `json:"candidate_count,omitempty"`
- StopSequencesSnake []string `json:"stop_sequences,omitempty"`
- ResponseMimeTypeSnake string `json:"response_mime_type,omitempty"`
- ResponseSchemaSnake any `json:"response_schema,omitempty"`
- ResponseJsonSchemaSnake json.RawMessage `json:"response_json_schema,omitempty"`
- PresencePenaltySnake *float32 `json:"presence_penalty,omitempty"`
- FrequencyPenaltySnake *float32 `json:"frequency_penalty,omitempty"`
- ResponseLogprobsSnake bool `json:"response_logprobs,omitempty"`
- MediaResolutionSnake MediaResolution `json:"media_resolution,omitempty"`
- ResponseModalitiesSnake []string `json:"response_modalities,omitempty"`
- ThinkingConfigSnake *GeminiThinkingConfig `json:"thinking_config,omitempty"`
- SpeechConfigSnake json.RawMessage `json:"speech_config,omitempty"`
- ImageConfigSnake json.RawMessage `json:"image_config,omitempty"`
+ TopPSnake float64 `json:"top_p,omitempty"`
+ TopKSnake float64 `json:"top_k,omitempty"`
+ MaxOutputTokensSnake uint `json:"max_output_tokens,omitempty"`
+ CandidateCountSnake int `json:"candidate_count,omitempty"`
+ StopSequencesSnake []string `json:"stop_sequences,omitempty"`
+ ResponseMimeTypeSnake string `json:"response_mime_type,omitempty"`
+ ResponseSchemaSnake any `json:"response_schema,omitempty"`
+ ResponseJsonSchemaSnake json.RawMessage `json:"response_json_schema,omitempty"`
+ PresencePenaltySnake *float32 `json:"presence_penalty,omitempty"`
+ FrequencyPenaltySnake *float32 `json:"frequency_penalty,omitempty"`
+ ResponseLogprobsSnake bool `json:"response_logprobs,omitempty"`
+ EnableEnhancedCivicAnswersSnake *bool `json:"enable_enhanced_civic_answers,omitempty"`
+ MediaResolutionSnake MediaResolution `json:"media_resolution,omitempty"`
+ ResponseModalitiesSnake []string `json:"response_modalities,omitempty"`
+ ThinkingConfigSnake *GeminiThinkingConfig `json:"thinking_config,omitempty"`
+ SpeechConfigSnake json.RawMessage `json:"speech_config,omitempty"`
+ ImageConfigSnake json.RawMessage `json:"image_config,omitempty"`
}
if err := common.Unmarshal(data, &aux); err != nil {
@@ -408,6 +410,9 @@ func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error {
if aux.ResponseLogprobsSnake {
c.ResponseLogprobs = aux.ResponseLogprobsSnake
}
+ if aux.EnableEnhancedCivicAnswersSnake != nil {
+ c.EnableEnhancedCivicAnswers = aux.EnableEnhancedCivicAnswersSnake
+ }
if aux.MediaResolutionSnake != "" {
c.MediaResolution = aux.MediaResolutionSnake
}
@@ -453,12 +458,14 @@ type GeminiChatResponse struct {
}
type GeminiUsageMetadata struct {
- PromptTokenCount int `json:"promptTokenCount"`
- CandidatesTokenCount int `json:"candidatesTokenCount"`
- TotalTokenCount int `json:"totalTokenCount"`
- ThoughtsTokenCount int `json:"thoughtsTokenCount"`
- CachedContentTokenCount int `json:"cachedContentTokenCount"`
- PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
+ PromptTokenCount int `json:"promptTokenCount"`
+ ToolUsePromptTokenCount int `json:"toolUsePromptTokenCount"`
+ CandidatesTokenCount int `json:"candidatesTokenCount"`
+ TotalTokenCount int `json:"totalTokenCount"`
+ ThoughtsTokenCount int `json:"thoughtsTokenCount"`
+ CachedContentTokenCount int `json:"cachedContentTokenCount"`
+ PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
+ ToolUsePromptTokensDetails []GeminiPromptTokensDetails `json:"toolUsePromptTokensDetails"`
}
type GeminiPromptTokensDetails struct {
diff --git a/dto/openai_request.go b/dto/openai_request.go
index 9113a086e..c0a69a376 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -54,18 +54,22 @@ type GeneralOpenAIRequest struct {
ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
Tools []ToolCallRequest `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
+ FunctionCall json.RawMessage `json:"function_call,omitempty"`
User string `json:"user,omitempty"`
- LogProbs bool `json:"logprobs,omitempty"`
- TopLogProbs int `json:"top_logprobs,omitempty"`
- Dimensions int `json:"dimensions,omitempty"`
- Modalities json.RawMessage `json:"modalities,omitempty"`
- Audio json.RawMessage `json:"audio,omitempty"`
+ // ServiceTier specifies upstream service level and may affect billing.
+ // This field is filtered by default and can be enabled via channel setting allow_service_tier.
+ ServiceTier string `json:"service_tier,omitempty"`
+ LogProbs bool `json:"logprobs,omitempty"`
+ TopLogProbs int `json:"top_logprobs,omitempty"`
+ Dimensions int `json:"dimensions,omitempty"`
+ Modalities json.RawMessage `json:"modalities,omitempty"`
+ Audio json.RawMessage `json:"audio,omitempty"`
// 安全标识符,用于帮助 OpenAI 检测可能违反使用政策的应用程序用户
- // 注意:此字段会向 OpenAI 发送用户标识信息,默认过滤以保护用户隐私
+ // 注意:此字段会向 OpenAI 发送用户标识信息,默认过滤,可通过 allow_safety_identifier 开启
SafetyIdentifier string `json:"safety_identifier,omitempty"`
// Whether or not to store the output of this chat completion request for use in our model distillation or evals products.
// 是否存储此次请求数据供 OpenAI 用于评估和优化产品
- // 注意:默认过滤此字段以保护用户隐私,但过滤后可能导致 Codex 无法正常使用
+ // 注意:默认允许透传,可通过 disable_store 禁用;禁用后可能导致 Codex 无法正常使用
Store json.RawMessage `json:"store,omitempty"`
// Used by OpenAI to cache responses for similar requests to optimize your cache hit rates. Replaces the user field
PromptCacheKey string `json:"prompt_cache_key,omitempty"`
@@ -261,6 +265,9 @@ type FunctionRequest struct {
type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
+ // IncludeObfuscation is only for /v1/responses stream payload.
+ // This field is filtered by default and can be enabled via channel setting allow_include_obfuscation.
+ IncludeObfuscation bool `json:"include_obfuscation,omitempty"`
}
func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
@@ -799,30 +806,42 @@ type WebSearchOptions struct {
// https://platform.openai.com/docs/api-reference/responses/create
type OpenAIResponsesRequest struct {
- Model string `json:"model"`
- Input json.RawMessage `json:"input,omitempty"`
- Include json.RawMessage `json:"include,omitempty"`
+ Model string `json:"model"`
+ Input json.RawMessage `json:"input,omitempty"`
+ Include json.RawMessage `json:"include,omitempty"`
+ // 在后台运行推理,暂时还不支持依赖的接口
+ // Background json.RawMessage `json:"background,omitempty"`
+ Conversation json.RawMessage `json:"conversation,omitempty"`
+ ContextManagement json.RawMessage `json:"context_management,omitempty"`
Instructions json.RawMessage `json:"instructions,omitempty"`
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
+ TopLogProbs *int `json:"top_logprobs,omitempty"`
Metadata json.RawMessage `json:"metadata,omitempty"`
ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"`
PreviousResponseID string `json:"previous_response_id,omitempty"`
Reasoning *Reasoning `json:"reasoning,omitempty"`
- // 服务层级字段,用于指定 API 服务等级。允许透传可能导致实际计费高于预期,默认应过滤
- ServiceTier string `json:"service_tier,omitempty"`
+ // ServiceTier specifies upstream service level and may affect billing.
+ // This field is filtered by default and can be enabled via channel setting allow_service_tier.
+ ServiceTier string `json:"service_tier,omitempty"`
+ // Store controls whether upstream may store request/response data.
+ // This field is allowed by default and can be disabled via channel setting disable_store.
Store json.RawMessage `json:"store,omitempty"`
PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"`
PromptCacheRetention json.RawMessage `json:"prompt_cache_retention,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- Text json.RawMessage `json:"text,omitempty"`
- ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
- Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
- TopP *float64 `json:"top_p,omitempty"`
- Truncation string `json:"truncation,omitempty"`
- User string `json:"user,omitempty"`
- MaxToolCalls uint `json:"max_tool_calls,omitempty"`
- Prompt json.RawMessage `json:"prompt,omitempty"`
+ // SafetyIdentifier carries client identity for policy abuse detection.
+ // This field is filtered by default and can be enabled via channel setting allow_safety_identifier.
+ SafetyIdentifier string `json:"safety_identifier,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ StreamOptions *StreamOptions `json:"stream_options,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ Text json.RawMessage `json:"text,omitempty"`
+ ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
+ Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
+ TopP *float64 `json:"top_p,omitempty"`
+ Truncation string `json:"truncation,omitempty"`
+ User string `json:"user,omitempty"`
+ MaxToolCalls uint `json:"max_tool_calls,omitempty"`
+ Prompt json.RawMessage `json:"prompt,omitempty"`
// qwen
EnableThinking json.RawMessage `json:"enable_thinking,omitempty"`
// perplexity
diff --git a/dto/suno.go b/dto/suno.go
index a6bb3ebae..90e11b810 100644
--- a/dto/suno.go
+++ b/dto/suno.go
@@ -4,10 +4,6 @@ import (
"encoding/json"
)
-type TaskData interface {
- SunoDataResponse | []SunoDataResponse | string | any
-}
-
type SunoSubmitReq struct {
GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"`
Prompt string `json:"prompt,omitempty"`
@@ -20,10 +16,6 @@ type SunoSubmitReq struct {
MakeInstrumental bool `json:"make_instrumental"`
}
-type FetchReq struct {
- IDs []string `json:"ids"`
-}
-
type SunoDataResponse struct {
TaskID string `json:"task_id" gorm:"type:varchar(50);index"`
Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
@@ -66,30 +58,6 @@ type SunoLyrics struct {
Text string `json:"text"`
}
-const TaskSuccessCode = "success"
-
-type TaskResponse[T TaskData] struct {
- Code string `json:"code"`
- Message string `json:"message"`
- Data T `json:"data"`
-}
-
-func (t *TaskResponse[T]) IsSuccess() bool {
- return t.Code == TaskSuccessCode
-}
-
-type TaskDto struct {
- TaskID string `json:"task_id"` // 第三方id,不一定有/ song id\ Task id
- Action string `json:"action"` // 任务类型, song, lyrics, description-mode
- Status string `json:"status"` // 任务状态, submitted, queueing, processing, success, failed
- FailReason string `json:"fail_reason"`
- SubmitTime int64 `json:"submit_time"`
- StartTime int64 `json:"start_time"`
- FinishTime int64 `json:"finish_time"`
- Progress string `json:"progress"`
- Data json.RawMessage `json:"data"`
-}
-
type SunoGoAPISubmitReq struct {
CustomMode bool `json:"custom_mode"`
diff --git a/dto/task.go b/dto/task.go
index afc186b41..4a9a8e2e6 100644
--- a/dto/task.go
+++ b/dto/task.go
@@ -1,5 +1,9 @@
package dto
+import (
+ "encoding/json"
+)
+
type TaskError struct {
Code string `json:"code"`
Message string `json:"message"`
@@ -8,3 +12,46 @@ type TaskError struct {
LocalError bool `json:"-"`
Error error `json:"-"`
}
+
+type TaskData interface {
+ SunoDataResponse | []SunoDataResponse | string | any
+}
+
+const TaskSuccessCode = "success"
+
+type TaskResponse[T TaskData] struct {
+ Code string `json:"code"`
+ Message string `json:"message"`
+ Data T `json:"data"`
+}
+
+func (t *TaskResponse[T]) IsSuccess() bool {
+ return t.Code == TaskSuccessCode
+}
+
+type TaskDto struct {
+ ID int64 `json:"id"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
+ TaskID string `json:"task_id"`
+ Platform string `json:"platform"`
+ UserId int `json:"user_id"`
+ Group string `json:"group"`
+ ChannelId int `json:"channel_id"`
+ Quota int `json:"quota"`
+ Action string `json:"action"`
+ Status string `json:"status"`
+ FailReason string `json:"fail_reason"`
+ ResultURL string `json:"result_url,omitempty"` // 任务结果 URL(视频地址等)
+ SubmitTime int64 `json:"submit_time"`
+ StartTime int64 `json:"start_time"`
+ FinishTime int64 `json:"finish_time"`
+ Progress string `json:"progress"`
+ Properties any `json:"properties"`
+ Username string `json:"username,omitempty"`
+ Data json.RawMessage `json:"data"`
+}
+
+type FetchReq struct {
+ IDs []string `json:"ids"`
+}
diff --git a/logger/logger.go b/logger/logger.go
index 61b1d49d8..90cf5006e 100644
--- a/logger/logger.go
+++ b/logger/logger.go
@@ -2,7 +2,6 @@ package logger
import (
"context"
- "encoding/json"
"fmt"
"io"
"log"
@@ -151,7 +150,7 @@ func FormatQuota(quota int) string {
// LogJson 仅供测试使用 only for test
func LogJson(ctx context.Context, msg string, obj any) {
- jsonStr, err := json.Marshal(obj)
+ jsonStr, err := common.Marshal(obj)
if err != nil {
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
return
diff --git a/main.go b/main.go
index 852e1a0a8..476a2ed24 100644
--- a/main.go
+++ b/main.go
@@ -19,6 +19,7 @@ import (
"github.com/QuantumNous/new-api/middleware"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/oauth"
+ "github.com/QuantumNous/new-api/relay"
"github.com/QuantumNous/new-api/router"
"github.com/QuantumNous/new-api/service"
_ "github.com/QuantumNous/new-api/setting/performance_setting"
@@ -111,6 +112,15 @@ func main() {
// Subscription quota reset task (daily/weekly/monthly/custom)
service.StartSubscriptionQuotaResetTask()
+ // Wire task polling adaptor factory (breaks service -> relay import cycle)
+ service.GetTaskAdaptorFunc = func(platform constant.TaskPlatform) service.TaskPollingAdaptor {
+ a := relay.GetTaskAdaptor(platform)
+ if a == nil {
+ return nil
+ }
+ return a
+ }
+
if common.IsMasterNode && constant.UpdateTask {
gopool.Go(func() {
controller.UpdateMidjourneyTaskBulk()
diff --git a/middleware/auth.go b/middleware/auth.go
index cf1843510..342e7f498 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -170,6 +170,24 @@ func WssAuth(c *gin.Context) {
}
+// TokenOrUserAuth allows either session-based user auth or API token auth.
+// Used for endpoints that need to be accessible from both the dashboard and API clients.
+func TokenOrUserAuth() func(c *gin.Context) {
+ return func(c *gin.Context) {
+ // Try session auth first (dashboard users)
+ session := sessions.Default(c)
+ if id := session.Get("id"); id != nil {
+ if status, ok := session.Get("status").(int); ok && status == common.UserStatusEnabled {
+ c.Set("id", id)
+ c.Next()
+ return
+ }
+ }
+ // Fall back to token auth (API clients)
+ TokenAuth()(c)
+ }
+}
+
// TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。
// 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。
// 即使令牌已过期、已耗尽或已禁用,也允许访问。
diff --git a/middleware/logger.go b/middleware/logger.go
index b4ed8c89d..151008d9f 100644
--- a/middleware/logger.go
+++ b/middleware/logger.go
@@ -7,14 +7,28 @@ import (
"github.com/gin-gonic/gin"
)
+const RouteTagKey = "route_tag"
+
+func RouteTag(tag string) gin.HandlerFunc {
+ return func(c *gin.Context) {
+ c.Set(RouteTagKey, tag)
+ c.Next()
+ }
+}
+
func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string
if param.Keys != nil {
- requestID = param.Keys[common.RequestIdKey].(string)
+ requestID, _ = param.Keys[common.RequestIdKey].(string)
}
- return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
+ tag, _ := param.Keys[RouteTagKey].(string)
+ if tag == "" {
+ tag = "web"
+ }
+ return fmt.Sprintf("[GIN] %s | %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
+ tag,
requestID,
param.StatusCode,
param.Latency,
diff --git a/model/custom_oauth_provider.go b/model/custom_oauth_provider.go
index 43c69833a..12b4d1111 100644
--- a/model/custom_oauth_provider.go
+++ b/model/custom_oauth_provider.go
@@ -2,32 +2,65 @@ package model
import (
"errors"
+ "fmt"
"strings"
"time"
+
+ "github.com/QuantumNous/new-api/common"
)
+type accessPolicyPayload struct {
+ Logic string `json:"logic"`
+ Conditions []accessConditionItem `json:"conditions"`
+ Groups []accessPolicyPayload `json:"groups"`
+}
+
+type accessConditionItem struct {
+ Field string `json:"field"`
+ Op string `json:"op"`
+ Value any `json:"value"`
+}
+
+var supportedAccessPolicyOps = map[string]struct{}{
+ "eq": {},
+ "ne": {},
+ "gt": {},
+ "gte": {},
+ "lt": {},
+ "lte": {},
+ "in": {},
+ "not_in": {},
+ "contains": {},
+ "not_contains": {},
+ "exists": {},
+ "not_exists": {},
+}
+
// CustomOAuthProvider stores configuration for custom OAuth providers
type CustomOAuthProvider struct {
- Id int `json:"id" gorm:"primaryKey"`
- Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise"
- Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise"
- Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled
- ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID
- ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend)
- AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL
- TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL
- UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL
- Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes
+ Id int `json:"id" gorm:"primaryKey"`
+ Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise"
+ Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise"
+ Icon string `json:"icon" gorm:"type:varchar(128);default:''"` // Icon name from @lobehub/icons
+ Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled
+ ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID
+ ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend)
+ AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL
+ TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL
+ UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL
+ Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes
// Field mapping configuration (supports JSONPath via gjson)
- UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id"
- UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path
- DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path
- EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path
+ UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id"
+ UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path
+ DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path
+ EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path
// Advanced options
- WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional)
- AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth)
+ WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional)
+ AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth)
+ AccessPolicy string `json:"access_policy" gorm:"type:text"` // JSON policy for access control based on user info
+ AccessDeniedMessage string `json:"access_denied_message" gorm:"type:varchar(512)"` // Custom error message template when access is denied
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
@@ -158,6 +191,57 @@ func validateCustomOAuthProvider(provider *CustomOAuthProvider) error {
if provider.Scopes == "" {
provider.Scopes = "openid profile email"
}
+ if strings.TrimSpace(provider.AccessPolicy) != "" {
+ var policy accessPolicyPayload
+ if err := common.UnmarshalJsonStr(provider.AccessPolicy, &policy); err != nil {
+ return errors.New("access_policy must be valid JSON")
+ }
+ if err := validateAccessPolicyPayload(&policy); err != nil {
+ return fmt.Errorf("access_policy is invalid: %w", err)
+ }
+ }
+
+ return nil
+}
+
+func validateAccessPolicyPayload(policy *accessPolicyPayload) error {
+ if policy == nil {
+ return errors.New("policy is nil")
+ }
+
+ logic := strings.ToLower(strings.TrimSpace(policy.Logic))
+ if logic == "" {
+ logic = "and"
+ }
+ if logic != "and" && logic != "or" {
+ return fmt.Errorf("unsupported logic: %s", logic)
+ }
+
+ if len(policy.Conditions) == 0 && len(policy.Groups) == 0 {
+ return errors.New("policy requires at least one condition or group")
+ }
+
+ for index, condition := range policy.Conditions {
+ field := strings.TrimSpace(condition.Field)
+ if field == "" {
+ return fmt.Errorf("condition[%d].field is required", index)
+ }
+ op := strings.ToLower(strings.TrimSpace(condition.Op))
+ if _, ok := supportedAccessPolicyOps[op]; !ok {
+ return fmt.Errorf("condition[%d].op is unsupported: %s", index, op)
+ }
+ if op == "in" || op == "not_in" {
+ if _, ok := condition.Value.([]any); !ok {
+ return fmt.Errorf("condition[%d].value must be an array for op %s", index, op)
+ }
+ }
+ }
+
+ for index := range policy.Groups {
+ if err := validateAccessPolicyPayload(&policy.Groups[index]); err != nil {
+ return fmt.Errorf("group[%d]: %w", index, err)
+ }
+ }
return nil
}
diff --git a/model/log.go b/model/log.go
index d7cd97a42..2d4782fa5 100644
--- a/model/log.go
+++ b/model/log.go
@@ -199,6 +199,49 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
}
}
+type RecordTaskBillingLogParams struct {
+ UserId int
+ LogType int
+ Content string
+ ChannelId int
+ ModelName string
+ Quota int
+ TokenId int
+ Group string
+ Other map[string]interface{}
+}
+
+func RecordTaskBillingLog(params RecordTaskBillingLogParams) {
+ if params.LogType == LogTypeConsume && !common.LogConsumeEnabled {
+ return
+ }
+ username, _ := GetUsernameById(params.UserId, false)
+ tokenName := ""
+ if params.TokenId > 0 {
+ if token, err := GetTokenById(params.TokenId); err == nil {
+ tokenName = token.Name
+ }
+ }
+ log := &Log{
+ UserId: params.UserId,
+ Username: username,
+ CreatedAt: common.GetTimestamp(),
+ Type: params.LogType,
+ Content: params.Content,
+ TokenName: tokenName,
+ ModelName: params.ModelName,
+ Quota: params.Quota,
+ ChannelId: params.ChannelId,
+ TokenId: params.TokenId,
+ Group: params.Group,
+ Other: common.MapToJsonStr(params.Other),
+ }
+ err := LOG_DB.Create(log).Error
+ if err != nil {
+ common.SysLog("failed to record task billing log: " + err.Error())
+ }
+}
+
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string, requestId string) (logs []*Log, total int64, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
@@ -252,8 +295,24 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
Id int `gorm:"column:id"`
Name string `gorm:"column:name"`
}
- if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil {
- return logs, total, err
+ if common.MemoryCacheEnabled {
+ // Cache get channel
+ for _, channelId := range channelIds.Items() {
+ if cacheChannel, err := CacheGetChannel(channelId); err == nil {
+ channels = append(channels, struct {
+ Id int `gorm:"column:id"`
+ Name string `gorm:"column:name"`
+ }{
+ Id: channelId,
+ Name: cacheChannel.Name,
+ })
+ }
+ }
+ } else {
+ // Bulk query channels from DB
+ if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil {
+ return logs, total, err
+ }
}
channelMap := make(map[int]string, len(channels))
for _, channel := range channels {
diff --git a/model/midjourney.go b/model/midjourney.go
index c6ef5de5b..e1a8d772b 100644
--- a/model/midjourney.go
+++ b/model/midjourney.go
@@ -157,6 +157,19 @@ func (midjourney *Midjourney) Update() error {
return err
}
+// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
+// Returns (true, nil) if this caller won the update, (false, nil) if
+// another process already moved the task out of fromStatus.
+//
+// Uses Model().Select("*").Updates() to avoid GORM Save()'s INSERT fallback.
+func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) {
+ result := DB.Model(midjourney).Where("status = ?", fromStatus).Select("*").Updates(midjourney)
+ if result.Error != nil {
+ return false, result.Error
+ }
+ return result.RowsAffected > 0, nil
+}
+
func MjBulkUpdate(mjIds []string, params map[string]any) error {
return DB.Model(&Midjourney{}).
Where("mj_id in (?)", mjIds).
diff --git a/model/task.go b/model/task.go
index 82c2e978a..984445083 100644
--- a/model/task.go
+++ b/model/task.go
@@ -1,10 +1,12 @@
package model
import (
+ "bytes"
"database/sql/driver"
"encoding/json"
"time"
+ "github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
commonRelay "github.com/QuantumNous/new-api/relay/common"
@@ -64,13 +66,12 @@ type Task struct {
}
func (t *Task) SetData(data any) {
- b, _ := json.Marshal(data)
+ b, _ := common.Marshal(data)
t.Data = json.RawMessage(b)
}
func (t *Task) GetData(v any) error {
- err := json.Unmarshal(t.Data, &v)
- return err
+ return common.Unmarshal(t.Data, &v)
}
type Properties struct {
@@ -85,18 +86,59 @@ func (m *Properties) Scan(val interface{}) error {
*m = Properties{}
return nil
}
- return json.Unmarshal(bytesValue, m)
+ return common.Unmarshal(bytesValue, m)
}
func (m Properties) Value() (driver.Value, error) {
if m == (Properties{}) {
return nil, nil
}
- return json.Marshal(m)
+ return common.Marshal(m)
}
type TaskPrivateData struct {
- Key string `json:"key,omitempty"`
+ Key string `json:"key,omitempty"`
+ UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID
+ ResultURL string `json:"result_url,omitempty"` // 任务成功后的结果 URL(视频地址等)
+ // 计费上下文:用于异步退款/差额结算(轮询阶段读取)
+ BillingSource string `json:"billing_source,omitempty"` // "wallet" 或 "subscription"
+ SubscriptionId int `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款
+ TokenId int `json:"token_id,omitempty"` // 令牌 ID,用于令牌额度退款
+ BillingContext *TaskBillingContext `json:"billing_context,omitempty"` // 计费参数快照(用于轮询阶段重新计算)
+}
+
+// TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。
+type TaskBillingContext struct {
+ ModelPrice float64 `json:"model_price,omitempty"` // 模型单价
+ GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率
+ ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率
+ OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等)
+ OriginModelName string `json:"origin_model_name,omitempty"` // 模型名称,必须为OriginModelName
+ PerCallBilling bool `json:"per_call_billing,omitempty"` // 按次计费:跳过轮询阶段的差额结算
+}
+
+// GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信)
+// 旧数据没有 UpstreamTaskID 时,TaskID 本身就是上游 ID
+func (t *Task) GetUpstreamTaskID() string {
+ if t.PrivateData.UpstreamTaskID != "" {
+ return t.PrivateData.UpstreamTaskID
+ }
+ return t.TaskID
+}
+
+// GetResultURL 获取任务结果 URL(视频地址等)
+// 新数据存在 PrivateData.ResultURL 中;旧数据回退到 FailReason(历史兼容)
+func (t *Task) GetResultURL() string {
+ if t.PrivateData.ResultURL != "" {
+ return t.PrivateData.ResultURL
+ }
+ return t.FailReason
+}
+
+// GenerateTaskID 生成对外暴露的 task_xxxx 格式 ID
+func GenerateTaskID() string {
+ key, _ := common.GenerateRandomCharsKey(32)
+ return "task_" + key
}
func (p *TaskPrivateData) Scan(val interface{}) error {
@@ -104,14 +146,14 @@ func (p *TaskPrivateData) Scan(val interface{}) error {
if len(bytesValue) == 0 {
return nil
}
- return json.Unmarshal(bytesValue, p)
+ return common.Unmarshal(bytesValue, p)
}
func (p TaskPrivateData) Value() (driver.Value, error) {
if (p == TaskPrivateData{}) {
return nil, nil
}
- return json.Marshal(p)
+ return common.Marshal(p)
}
// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
@@ -142,7 +184,16 @@ func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo)
}
}
+ // 使用预生成的公开 ID(如果有),否则新生成
+ taskID := ""
+ if relayInfo.TaskRelayInfo != nil && relayInfo.TaskRelayInfo.PublicTaskID != "" {
+ taskID = relayInfo.TaskRelayInfo.PublicTaskID
+ } else {
+ taskID = GenerateTaskID()
+ }
+
t := &Task{
+ TaskID: taskID,
UserId: relayInfo.UserId,
Group: relayInfo.UsingGroup,
SubmitTime: time.Now().Unix(),
@@ -237,6 +288,20 @@ func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*
return tasks
}
+func GetTimedOutUnfinishedTasks(cutoffUnix int64, limit int) []*Task {
+ var tasks []*Task
+ err := DB.Where("progress != ?", "100%").
+ Where("status NOT IN ?", []string{TaskStatusFailure, TaskStatusSuccess}).
+ Where("submit_time < ?", cutoffUnix).
+ Order("submit_time").
+ Limit(limit).
+ Find(&tasks).Error
+ if err != nil {
+ return nil
+ }
+ return tasks
+}
+
func GetAllUnFinishSyncTasks(limit int) []*Task {
var tasks []*Task
var err error
@@ -291,40 +356,70 @@ func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
return task, nil
}
-func TaskUpdateProgress(id int64, progress string) error {
- return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
-}
-
func (Task *Task) Insert() error {
var err error
err = DB.Create(Task).Error
return err
}
+type taskSnapshot struct {
+ Status TaskStatus
+ Progress string
+ StartTime int64
+ FinishTime int64
+ FailReason string
+ ResultURL string
+ Data json.RawMessage
+}
+
+func (s taskSnapshot) Equal(other taskSnapshot) bool {
+ return s.Status == other.Status &&
+ s.Progress == other.Progress &&
+ s.StartTime == other.StartTime &&
+ s.FinishTime == other.FinishTime &&
+ s.FailReason == other.FailReason &&
+ s.ResultURL == other.ResultURL &&
+ bytes.Equal(s.Data, other.Data)
+}
+
+func (t *Task) Snapshot() taskSnapshot {
+ return taskSnapshot{
+ Status: t.Status,
+ Progress: t.Progress,
+ StartTime: t.StartTime,
+ FinishTime: t.FinishTime,
+ FailReason: t.FailReason,
+ ResultURL: t.PrivateData.ResultURL,
+ Data: t.Data,
+ }
+}
+
func (Task *Task) Update() error {
var err error
err = DB.Save(Task).Error
return err
}
-func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
- if len(TaskIds) == 0 {
- return nil
+// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
+// Returns (true, nil) if this caller won the update, (false, nil) if
+// another process already moved the task out of fromStatus.
+//
+// Uses Model().Select("*").Updates() instead of Save() because GORM's Save
+// falls back to INSERT ON CONFLICT when the WHERE-guarded UPDATE matches
+// zero rows, which silently bypasses the CAS guard.
+func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) {
+ result := DB.Model(t).Where("status = ?", fromStatus).Select("*").Updates(t)
+ if result.Error != nil {
+ return false, result.Error
}
- return DB.Model(&Task{}).
- Where("task_id in (?)", TaskIds).
- Updates(params).Error
-}
-
-func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
- if len(taskIDs) == 0 {
- return nil
- }
- return DB.Model(&Task{}).
- Where("id in (?)", taskIDs).
- Updates(params).Error
+ return result.RowsAffected > 0, nil
}
+// TaskBulkUpdateByID performs an unconditional bulk UPDATE by primary key IDs.
+// WARNING: This function has NO CAS (Compare-And-Swap) guard — it will overwrite
+// any concurrent status changes. DO NOT use in billing/quota lifecycle flows
+// (e.g., timeout, success, failure transitions that trigger refunds or settlements).
+// For status transitions that involve billing, use Task.UpdateWithStatus() instead.
func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
if len(ids) == 0 {
return nil
@@ -339,37 +434,6 @@ type TaskQuotaUsage struct {
Count float64 `json:"count"`
}
-func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
- query := DB.Model(Task{})
- // 添加过滤条件
- if queryParams.ChannelID != "" {
- query = query.Where("channel_id = ?", queryParams.ChannelID)
- }
- if queryParams.UserID != "" {
- query = query.Where("user_id = ?", queryParams.UserID)
- }
- if len(queryParams.UserIDs) != 0 {
- query = query.Where("user_id in (?)", queryParams.UserIDs)
- }
- if queryParams.TaskID != "" {
- query = query.Where("task_id = ?", queryParams.TaskID)
- }
- if queryParams.Action != "" {
- query = query.Where("action = ?", queryParams.Action)
- }
- if queryParams.Status != "" {
- query = query.Where("status = ?", queryParams.Status)
- }
- if queryParams.StartTimestamp != 0 {
- query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
- }
- if queryParams.EndTimestamp != 0 {
- query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
- }
- err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
- return stat, err
-}
-
// TaskCountAllTasks returns total tasks that match the given query params (admin usage)
func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
var total int64
@@ -438,6 +502,6 @@ func (t *Task) ToOpenAIVideo() *dto.OpenAIVideo {
openAIVideo.SetProgressStr(t.Progress)
openAIVideo.CreatedAt = t.CreatedAt
openAIVideo.CompletedAt = t.UpdatedAt
- openAIVideo.SetMetadata("url", t.FailReason)
+ openAIVideo.SetMetadata("url", t.GetResultURL())
return openAIVideo
}
diff --git a/model/task_cas_test.go b/model/task_cas_test.go
new file mode 100644
index 000000000..3449c6d26
--- /dev/null
+++ b/model/task_cas_test.go
@@ -0,0 +1,217 @@
+package model
+
+import (
+ "encoding/json"
+ "os"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/QuantumNous/new-api/common"
+ "github.com/glebarez/sqlite"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gorm.io/gorm"
+)
+
+func TestMain(m *testing.M) {
+ db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+ if err != nil {
+ panic("failed to open test db: " + err.Error())
+ }
+ DB = db
+ LOG_DB = db
+
+ common.UsingSQLite = true
+ common.RedisEnabled = false
+ common.BatchUpdateEnabled = false
+ common.LogConsumeEnabled = true
+
+ sqlDB, err := db.DB()
+ if err != nil {
+ panic("failed to get sql.DB: " + err.Error())
+ }
+ sqlDB.SetMaxOpenConns(1)
+
+ if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil {
+ panic("failed to migrate: " + err.Error())
+ }
+
+ os.Exit(m.Run())
+}
+
+func truncateTables(t *testing.T) {
+ t.Helper()
+ t.Cleanup(func() {
+ DB.Exec("DELETE FROM tasks")
+ DB.Exec("DELETE FROM users")
+ DB.Exec("DELETE FROM tokens")
+ DB.Exec("DELETE FROM logs")
+ DB.Exec("DELETE FROM channels")
+ })
+}
+
+func insertTask(t *testing.T, task *Task) {
+ t.Helper()
+ task.CreatedAt = time.Now().Unix()
+ task.UpdatedAt = time.Now().Unix()
+ require.NoError(t, DB.Create(task).Error)
+}
+
+// ---------------------------------------------------------------------------
+// Snapshot / Equal — pure logic tests (no DB)
+// ---------------------------------------------------------------------------
+
+func TestSnapshotEqual_Same(t *testing.T) {
+ s := taskSnapshot{
+ Status: TaskStatusInProgress,
+ Progress: "50%",
+ StartTime: 1000,
+ FinishTime: 0,
+ FailReason: "",
+ ResultURL: "",
+ Data: json.RawMessage(`{"key":"value"}`),
+ }
+ assert.True(t, s.Equal(s))
+}
+
+func TestSnapshotEqual_DifferentStatus(t *testing.T) {
+ a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)}
+ b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)}
+ assert.False(t, a.Equal(b))
+}
+
+func TestSnapshotEqual_DifferentProgress(t *testing.T) {
+ a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)}
+ b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)}
+ assert.False(t, a.Equal(b))
+}
+
+func TestSnapshotEqual_DifferentData(t *testing.T) {
+ a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)}
+ b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)}
+ assert.False(t, a.Equal(b))
+}
+
+func TestSnapshotEqual_NilVsEmpty(t *testing.T) {
+ a := taskSnapshot{Status: TaskStatusInProgress, Data: nil}
+ b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}}
+ // bytes.Equal treats a nil slice and an empty slice as equal, so the two snapshots compare equal
+ assert.True(t, a.Equal(b))
+}
+
+func TestSnapshot_Roundtrip(t *testing.T) {
+ task := &Task{
+ Status: TaskStatusInProgress,
+ Progress: "42%",
+ StartTime: 1234,
+ FinishTime: 5678,
+ FailReason: "timeout",
+ PrivateData: TaskPrivateData{
+ ResultURL: "https://example.com/result.mp4",
+ },
+ Data: json.RawMessage(`{"model":"test-model"}`),
+ }
+ snap := task.Snapshot()
+ assert.Equal(t, task.Status, snap.Status)
+ assert.Equal(t, task.Progress, snap.Progress)
+ assert.Equal(t, task.StartTime, snap.StartTime)
+ assert.Equal(t, task.FinishTime, snap.FinishTime)
+ assert.Equal(t, task.FailReason, snap.FailReason)
+ assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL)
+ assert.JSONEq(t, string(task.Data), string(snap.Data))
+}
+
+// ---------------------------------------------------------------------------
+// UpdateWithStatus CAS — DB integration tests
+// ---------------------------------------------------------------------------
+
+func TestUpdateWithStatus_Win(t *testing.T) {
+ truncateTables(t)
+
+ task := &Task{
+ TaskID: "task_cas_win",
+ Status: TaskStatusInProgress,
+ Progress: "50%",
+ Data: json.RawMessage(`{}`),
+ }
+ insertTask(t, task)
+
+ task.Status = TaskStatusSuccess
+ task.Progress = "100%"
+ won, err := task.UpdateWithStatus(TaskStatusInProgress)
+ require.NoError(t, err)
+ assert.True(t, won)
+
+ var reloaded Task
+ require.NoError(t, DB.First(&reloaded, task.ID).Error)
+ assert.EqualValues(t, TaskStatusSuccess, reloaded.Status)
+ assert.Equal(t, "100%", reloaded.Progress)
+}
+
+func TestUpdateWithStatus_Lose(t *testing.T) {
+ truncateTables(t)
+
+ task := &Task{
+ TaskID: "task_cas_lose",
+ Status: TaskStatusFailure,
+ Data: json.RawMessage(`{}`),
+ }
+ insertTask(t, task)
+
+ task.Status = TaskStatusSuccess
+ won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus
+ require.NoError(t, err)
+ assert.False(t, won)
+
+ var reloaded Task
+ require.NoError(t, DB.First(&reloaded, task.ID).Error)
+ assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged
+}
+
+func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) {
+ truncateTables(t)
+
+ task := &Task{
+ TaskID: "task_cas_race",
+ Status: TaskStatusInProgress,
+ Quota: 1000,
+ Data: json.RawMessage(`{}`),
+ }
+ insertTask(t, task)
+
+ const goroutines = 5
+ wins := make([]bool, goroutines)
+ var wg sync.WaitGroup
+ wg.Add(goroutines)
+
+ for i := 0; i < goroutines; i++ {
+ go func(idx int) {
+ defer wg.Done()
+ t := &Task{}
+ *t = Task{
+ ID: task.ID,
+ TaskID: task.TaskID,
+ Status: TaskStatusSuccess,
+ Progress: "100%",
+ Quota: task.Quota,
+ Data: json.RawMessage(`{}`),
+ }
+ t.CreatedAt = task.CreatedAt
+ t.UpdatedAt = time.Now().Unix()
+ won, err := t.UpdateWithStatus(TaskStatusInProgress)
+ if err == nil {
+ wins[idx] = won
+ }
+ }(i)
+ }
+ wg.Wait()
+
+ winCount := 0
+ for _, w := range wins {
+ if w {
+ winCount++
+ }
+ }
+ assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS")
+}
diff --git a/model/token.go b/model/token.go
index 9e05b63ca..773b2d792 100644
--- a/model/token.go
+++ b/model/token.go
@@ -360,7 +360,7 @@ func DeleteTokenById(id int, userId int) (err error) {
return token.Delete()
}
-func IncreaseTokenQuota(id int, key string, quota int) (err error) {
+func IncreaseTokenQuota(tokenId int, key string, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@@ -373,10 +373,10 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) {
})
}
if common.BatchUpdateEnabled {
- addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
+ addNewRecord(BatchUpdateTypeTokenQuota, tokenId, quota)
return nil
}
- return increaseTokenQuota(id, quota)
+ return increaseTokenQuota(tokenId, quota)
}
func increaseTokenQuota(id int, quota int) (err error) {
diff --git a/model/user.go b/model/user.go
index e0c9c686f..1210b5435 100644
--- a/model/user.go
+++ b/model/user.go
@@ -1,6 +1,7 @@
package model
import (
+ "database/sql"
"encoding/json"
"errors"
"fmt"
@@ -15,6 +16,8 @@ import (
"gorm.io/gorm"
)
+const UserNameMaxLength = 20
+
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
// Otherwise, the sensitive information will be saved on local storage in plain text!
type User struct {
@@ -536,6 +539,37 @@ func (user *User) Edit(updatePassword bool) error {
return updateUserCache(*user)
}
+func (user *User) ClearBinding(bindingType string) error {
+ if user.Id == 0 {
+ return errors.New("user id is empty")
+ }
+
+ bindingColumnMap := map[string]string{
+ "email": "email",
+ "github": "github_id",
+ "discord": "discord_id",
+ "oidc": "oidc_id",
+ "wechat": "wechat_id",
+ "telegram": "telegram_id",
+ "linuxdo": "linux_do_id",
+ }
+
+ column, ok := bindingColumnMap[bindingType]
+ if !ok {
+ return errors.New("invalid binding type")
+ }
+
+ if err := DB.Model(&User{}).Where("id = ?", user.Id).Update(column, "").Error; err != nil {
+ return err
+ }
+
+ if err := DB.Where("id = ?", user.Id).First(user).Error; err != nil {
+ return err
+ }
+
+ return updateUserCache(*user)
+}
+
func (user *User) Delete() error {
if user.Id == 0 {
return errors.New("id 为空!")
@@ -820,10 +854,17 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error)
// Don't return error - fall through to DB
}
fromDB = true
- err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
+ // the setting column may be NULL in the database; scan into sql.NullString to avoid a Scan error
+ var safeSetting sql.NullString
+ err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&safeSetting).Error
if err != nil {
return settingMap, err
}
+ if safeSetting.Valid {
+ setting = safeSetting.String
+ } else {
+ setting = ""
+ }
userBase := &UserBase{
Setting: setting,
}
diff --git a/oauth/generic.go b/oauth/generic.go
index c7aa87931..bc18054d5 100644
--- a/oauth/generic.go
+++ b/oauth/generic.go
@@ -3,19 +3,24 @@ package oauth
import (
"context"
"encoding/base64"
- "encoding/json"
+ stdjson "encoding/json"
+ "errors"
"fmt"
"io"
"net/http"
"net/url"
+ "regexp"
+ "strconv"
"strings"
"time"
+ "github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
+ "github.com/samber/lo"
"github.com/tidwall/gjson"
)
@@ -31,6 +36,40 @@ type GenericOAuthProvider struct {
config *model.CustomOAuthProvider
}
+type accessPolicy struct {
+ Logic string `json:"logic"`
+ Conditions []accessCondition `json:"conditions"`
+ Groups []accessPolicy `json:"groups"`
+}
+
+type accessCondition struct {
+ Field string `json:"field"`
+ Op string `json:"op"`
+ Value any `json:"value"`
+}
+
+type accessPolicyFailure struct {
+ Field string
+ Op string
+ Expected any
+ Current any
+}
+
+var supportedAccessPolicyOps = []string{
+ "eq",
+ "ne",
+ "gt",
+ "gte",
+ "lt",
+ "lte",
+ "in",
+ "not_in",
+ "contains",
+ "not_contains",
+ "exists",
+ "not_exists",
+}
+
// NewGenericOAuthProvider creates a new generic OAuth provider from config
func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider {
return &GenericOAuthProvider{config: config}
@@ -125,7 +164,7 @@ func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c
ErrorDesc string `json:"error_description"`
}
- if err := json.Unmarshal(body, &tokenResponse); err != nil {
+ if err := common.Unmarshal(body, &tokenResponse); err != nil {
// Try to parse as URL-encoded (some OAuth servers like GitHub return this format)
parsedValues, parseErr := url.ParseQuery(bodyStr)
if parseErr != nil {
@@ -227,11 +266,30 @@ func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToke
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s",
p.config.Slug, userId, username, displayName, email)
+ policyRaw := strings.TrimSpace(p.config.AccessPolicy)
+ if policyRaw != "" {
+ policy, err := parseAccessPolicy(policyRaw)
+ if err != nil {
+ logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] invalid access policy: %s", p.config.Slug, err.Error()))
+ return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, nil, "invalid access policy configuration")
+ }
+ allowed, failure := evaluateAccessPolicy(bodyStr, policy)
+ if !allowed {
+ message := renderAccessDeniedMessage(p.config.AccessDeniedMessage, p.config.Name, bodyStr, failure)
+ logger.LogWarn(ctx, fmt.Sprintf("[OAuth-Generic-%s] access denied by policy: field=%s op=%s expected=%v current=%v",
+ p.config.Slug, failure.Field, failure.Op, failure.Expected, failure.Current))
+ return nil, &AccessDeniedError{Message: message}
+ }
+ }
+
return &OAuthUser{
ProviderUserID: userId,
Username: username,
DisplayName: displayName,
Email: email,
+ Extra: map[string]any{
+ "provider": p.config.Slug,
+ },
}, nil
}
@@ -266,3 +324,345 @@ func (p *GenericOAuthProvider) GetProviderId() int {
func (p *GenericOAuthProvider) IsGenericProvider() bool {
return true
}
+
+func parseAccessPolicy(raw string) (*accessPolicy, error) {
+ var policy accessPolicy
+ if err := common.UnmarshalJsonStr(raw, &policy); err != nil {
+ return nil, err
+ }
+ if err := validateAccessPolicy(&policy); err != nil {
+ return nil, err
+ }
+ return &policy, nil
+}
+
+func validateAccessPolicy(policy *accessPolicy) error {
+ if policy == nil {
+ return errors.New("policy is nil")
+ }
+
+ logic := strings.ToLower(strings.TrimSpace(policy.Logic))
+ if logic == "" {
+ logic = "and"
+ }
+ if !lo.Contains([]string{"and", "or"}, logic) {
+ return fmt.Errorf("unsupported policy logic: %s", logic)
+ }
+ policy.Logic = logic
+
+ if len(policy.Conditions) == 0 && len(policy.Groups) == 0 {
+ return errors.New("policy requires at least one condition or group")
+ }
+
+ for index := range policy.Conditions {
+ if err := validateAccessCondition(&policy.Conditions[index], index); err != nil {
+ return err
+ }
+ }
+
+ for index := range policy.Groups {
+ if err := validateAccessPolicy(&policy.Groups[index]); err != nil {
+ return fmt.Errorf("invalid policy group[%d]: %w", index, err)
+ }
+ }
+
+ return nil
+}
+
+func validateAccessCondition(condition *accessCondition, index int) error {
+ if condition == nil {
+ return fmt.Errorf("condition[%d] is nil", index)
+ }
+
+ condition.Field = strings.TrimSpace(condition.Field)
+ if condition.Field == "" {
+ return fmt.Errorf("condition[%d].field is required", index)
+ }
+
+ condition.Op = normalizePolicyOp(condition.Op)
+ if !lo.Contains(supportedAccessPolicyOps, condition.Op) {
+ return fmt.Errorf("condition[%d].op is unsupported: %s", index, condition.Op)
+ }
+
+ if lo.Contains([]string{"in", "not_in"}, condition.Op) {
+ if _, ok := condition.Value.([]any); !ok {
+ return fmt.Errorf("condition[%d].value must be an array for op %s", index, condition.Op)
+ }
+ }
+
+ return nil
+}
+
+func evaluateAccessPolicy(body string, policy *accessPolicy) (bool, *accessPolicyFailure) {
+ if policy == nil {
+ return true, nil
+ }
+
+ logic := strings.ToLower(strings.TrimSpace(policy.Logic))
+ if logic == "" {
+ logic = "and"
+ }
+
+ hasAny := len(policy.Conditions) > 0 || len(policy.Groups) > 0
+ if !hasAny {
+ return true, nil
+ }
+
+ if logic == "or" {
+ var firstFailure *accessPolicyFailure
+ for _, cond := range policy.Conditions {
+ ok, failure := evaluateAccessCondition(body, cond)
+ if ok {
+ return true, nil
+ }
+ if firstFailure == nil {
+ firstFailure = failure
+ }
+ }
+ for _, group := range policy.Groups {
+ ok, failure := evaluateAccessPolicy(body, &group)
+ if ok {
+ return true, nil
+ }
+ if firstFailure == nil {
+ firstFailure = failure
+ }
+ }
+ return false, firstFailure
+ }
+
+ for _, cond := range policy.Conditions {
+ ok, failure := evaluateAccessCondition(body, cond)
+ if !ok {
+ return false, failure
+ }
+ }
+ for _, group := range policy.Groups {
+ ok, failure := evaluateAccessPolicy(body, &group)
+ if !ok {
+ return false, failure
+ }
+ }
+ return true, nil
+}
+
+func evaluateAccessCondition(body string, cond accessCondition) (bool, *accessPolicyFailure) {
+ path := cond.Field
+ op := cond.Op
+ result := gjson.Get(body, path)
+ current := gjsonResultToValue(result)
+ failure := &accessPolicyFailure{
+ Field: path,
+ Op: op,
+ Expected: cond.Value,
+ Current: current,
+ }
+
+ switch op {
+ case "exists":
+ return result.Exists(), failure
+ case "not_exists":
+ return !result.Exists(), failure
+ case "eq":
+ return compareAny(current, cond.Value) == 0, failure
+ case "ne":
+ return compareAny(current, cond.Value) != 0, failure
+ case "gt":
+ return compareAny(current, cond.Value) > 0, failure
+ case "gte":
+ return compareAny(current, cond.Value) >= 0, failure
+ case "lt":
+ return compareAny(current, cond.Value) < 0, failure
+ case "lte":
+ return compareAny(current, cond.Value) <= 0, failure
+ case "in":
+ return valueInSlice(current, cond.Value), failure
+ case "not_in":
+ return !valueInSlice(current, cond.Value), failure
+ case "contains":
+ return containsValue(current, cond.Value), failure
+ case "not_contains":
+ return !containsValue(current, cond.Value), failure
+ default:
+ return false, failure
+ }
+}
+
+func normalizePolicyOp(op string) string {
+ return strings.ToLower(strings.TrimSpace(op))
+}
+
+func gjsonResultToValue(result gjson.Result) any {
+ if !result.Exists() {
+ return nil
+ }
+ if result.IsArray() {
+ arr := result.Array()
+ values := make([]any, 0, len(arr))
+ for _, item := range arr {
+ values = append(values, gjsonResultToValue(item))
+ }
+ return values
+ }
+ switch result.Type {
+ case gjson.Null:
+ return nil
+ case gjson.True:
+ return true
+ case gjson.False:
+ return false
+ case gjson.Number:
+ return result.Num
+ case gjson.String:
+ return result.String()
+ case gjson.JSON:
+ var data any
+ if err := common.UnmarshalJsonStr(result.Raw, &data); err == nil {
+ return data
+ }
+ return result.Raw
+ default:
+ return result.Value()
+ }
+}
+
+func compareAny(left any, right any) int {
+ if lf, ok := toFloat(left); ok {
+ if rf, ok2 := toFloat(right); ok2 {
+ switch {
+ case lf < rf:
+ return -1
+ case lf > rf:
+ return 1
+ default:
+ return 0
+ }
+ }
+ }
+
+ ls := strings.TrimSpace(fmt.Sprint(left))
+ rs := strings.TrimSpace(fmt.Sprint(right))
+ switch {
+ case ls < rs:
+ return -1
+ case ls > rs:
+ return 1
+ default:
+ return 0
+ }
+}
+
+func toFloat(v any) (float64, bool) {
+ switch value := v.(type) {
+ case float64:
+ return value, true
+ case float32:
+ return float64(value), true
+ case int:
+ return float64(value), true
+ case int8:
+ return float64(value), true
+ case int16:
+ return float64(value), true
+ case int32:
+ return float64(value), true
+ case int64:
+ return float64(value), true
+ case uint:
+ return float64(value), true
+ case uint8:
+ return float64(value), true
+ case uint16:
+ return float64(value), true
+ case uint32:
+ return float64(value), true
+ case uint64:
+ return float64(value), true
+ case stdjson.Number:
+ n, err := value.Float64()
+ if err == nil {
+ return n, true
+ }
+ case string:
+ n, err := strconv.ParseFloat(strings.TrimSpace(value), 64)
+ if err == nil {
+ return n, true
+ }
+ }
+ return 0, false
+}
+
+func valueInSlice(current any, expected any) bool {
+ list, ok := expected.([]any)
+ if !ok {
+ return false
+ }
+ return lo.ContainsBy(list, func(item any) bool {
+ return compareAny(current, item) == 0
+ })
+}
+
+func containsValue(current any, expected any) bool {
+ switch value := current.(type) {
+ case string:
+ target := strings.TrimSpace(fmt.Sprint(expected))
+ return strings.Contains(value, target)
+ case []any:
+ return lo.ContainsBy(value, func(item any) bool {
+ return compareAny(item, expected) == 0
+ })
+ }
+ return false
+}
+
+func renderAccessDeniedMessage(template string, providerName string, body string, failure *accessPolicyFailure) string {
+ defaultMessage := "Access denied: your account does not meet this provider's access requirements."
+ message := strings.TrimSpace(template)
+ if message == "" {
+ return defaultMessage
+ }
+
+ if failure == nil {
+ failure = &accessPolicyFailure{}
+ }
+
+ replacements := map[string]string{
+ "{{provider}}": providerName,
+ "{{field}}": failure.Field,
+ "{{op}}": failure.Op,
+ "{{required}}": fmt.Sprint(failure.Expected),
+ "{{current}}": fmt.Sprint(failure.Current),
+ }
+
+ for key, value := range replacements {
+ message = strings.ReplaceAll(message, key, value)
+ }
+
+ currentPattern := regexp.MustCompile(`\{\{current\.([^}]+)\}\}`)
+ message = currentPattern.ReplaceAllStringFunc(message, func(token string) string {
+ match := currentPattern.FindStringSubmatch(token)
+ if len(match) != 2 {
+ return ""
+ }
+ path := strings.TrimSpace(match[1])
+ if path == "" {
+ return ""
+ }
+ return strings.TrimSpace(gjson.Get(body, path).String())
+ })
+
+ requiredPattern := regexp.MustCompile(`\{\{required\.([^}]+)\}\}`)
+ message = requiredPattern.ReplaceAllStringFunc(message, func(token string) string {
+ match := requiredPattern.FindStringSubmatch(token)
+ if len(match) != 2 {
+ return ""
+ }
+ path := strings.TrimSpace(match[1])
+ if failure.Field == path {
+ return fmt.Sprint(failure.Expected)
+ }
+ return ""
+ })
+
+ return strings.TrimSpace(message)
+}
diff --git a/oauth/types.go b/oauth/types.go
index 1b0e3646a..383e6f351 100644
--- a/oauth/types.go
+++ b/oauth/types.go
@@ -57,3 +57,12 @@ func NewOAuthErrorWithRaw(msgKey string, params map[string]any, rawError string)
RawError: rawError,
}
}
+
+// AccessDeniedError is returned when an OAuth access policy rejects the user; Message is shown to the user verbatim.
+type AccessDeniedError struct {
+ Message string
+}
+
+func (e *AccessDeniedError) Error() string {
+ return e.Message
+}
diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go
index ff7606e2e..d2f7c6bb6 100644
--- a/relay/channel/adapter.go
+++ b/relay/channel/adapter.go
@@ -36,6 +36,32 @@ type TaskAdaptor interface {
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
+ // ── Billing ──────────────────────────────────────────────────────
+
+ // EstimateBilling returns OtherRatios for pre-charge based on user request.
+ // Called after ValidateRequestAndSetAction, before price calculation.
+ // Adaptors should extract duration, resolution, etc. from the parsed request
+ // and return them as ratio multipliers (e.g. {"seconds": 5, "size": 1.666}).
+ // Return nil to use the base model price without extra ratios.
+ EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64
+
+ // AdjustBillingOnSubmit returns adjusted OtherRatios from the upstream
+ // submit response. Called after a successful DoResponse.
+ // If the upstream returned actual parameters that differ from the estimate
+ // (e.g. actual seconds), return updated ratios so the caller can recalculate
+ // the quota and settle the delta with the pre-charge.
+ // Return nil if no adjustment is needed.
+ AdjustBillingOnSubmit(info *relaycommon.RelayInfo, taskData []byte) map[string]float64
+
+ // AdjustBillingOnComplete returns the actual quota when a task reaches a
+ // terminal state (success/failure) during polling.
+ // Called by the polling loop after ParseTaskResult.
+ // Return a positive value to trigger delta settlement (supplement / refund).
+ // Return 0 to keep the pre-charged amount unchanged.
+ AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
+
+ // ── Request / Response ───────────────────────────────────────────
+
BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
@@ -46,9 +72,9 @@ type TaskAdaptor interface {
GetModelList() []string
GetChannelName() string
- // FetchTask
- FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
+ // ── Polling ──────────────────────────────────────────────────────
+ FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
}
diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go
index dcdff584b..49773e1e6 100644
--- a/relay/channel/api_request.go
+++ b/relay/channel/api_request.go
@@ -61,8 +61,9 @@ var passthroughSkipHeaderNamesLower = map[string]struct{}{
"cookie": {},
// Additional headers that should not be forwarded by name-matching passthrough rules.
- "host": {},
- "content-length": {},
+ "host": {},
+ "content-length": {},
+ "accept-encoding": {},
// Do not passthrough credentials by wildcard/regex.
"authorization": {},
diff --git a/relay/channel/api_request_test.go b/relay/channel/api_request_test.go
index 31e15340a..791379b90 100644
--- a/relay/channel/api_request_test.go
+++ b/relay/channel/api_request_test.go
@@ -110,3 +110,30 @@ func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
_, ok := headers["X-Legacy"]
require.False(t, ok)
}
+
+func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
+ t.Parallel()
+
+ gin.SetMode(gin.TestMode)
+ recorder := httptest.NewRecorder()
+ ctx, _ := gin.CreateTestContext(recorder)
+ ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+ ctx.Request.Header.Set("X-Trace-Id", "trace-123")
+ ctx.Request.Header.Set("Accept-Encoding", "gzip")
+
+ info := &relaycommon.RelayInfo{
+ IsChannelTest: false,
+ ChannelMeta: &relaycommon.ChannelMeta{
+ HeadersOverride: map[string]any{
+ "*": "",
+ },
+ },
+ }
+
+ headers, err := processHeaderOverride(info, ctx)
+ require.NoError(t, err)
+ require.Equal(t, "trace-123", headers["X-Trace-Id"])
+
+ _, hasAcceptEncoding := headers["Accept-Encoding"]
+ require.False(t, hasAcceptEncoding)
+}
diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go
index 39485b16f..1a434a432 100644
--- a/relay/channel/gemini/relay-gemini-native.go
+++ b/relay/channel/gemini/relay-gemini-native.go
@@ -42,22 +42,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
}
// 计算使用量(基于 UsageMetadata)
- usage := dto.Usage{
- PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
- CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
- TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
- }
-
- usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
- usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
-
- for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
- if detail.Modality == "AUDIO" {
- usage.PromptTokensDetails.AudioTokens = detail.TokenCount
- } else if detail.Modality == "TEXT" {
- usage.PromptTokensDetails.TextTokens = detail.TokenCount
- }
- }
+ usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
service.IOCopyBytesGracefully(c, resp, responseBody)
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index b10ec06c7..b81a148a3 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -1032,6 +1032,46 @@ func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
}
}
+func buildUsageFromGeminiMetadata(metadata dto.GeminiUsageMetadata, fallbackPromptTokens int) dto.Usage {
+ promptTokens := metadata.PromptTokenCount + metadata.ToolUsePromptTokenCount
+ if promptTokens <= 0 && fallbackPromptTokens > 0 {
+ promptTokens = fallbackPromptTokens
+ }
+
+ usage := dto.Usage{
+ PromptTokens: promptTokens,
+ CompletionTokens: metadata.CandidatesTokenCount + metadata.ThoughtsTokenCount,
+ TotalTokens: metadata.TotalTokenCount,
+ }
+ usage.CompletionTokenDetails.ReasoningTokens = metadata.ThoughtsTokenCount
+ usage.PromptTokensDetails.CachedTokens = metadata.CachedContentTokenCount
+
+ for _, detail := range metadata.PromptTokensDetails {
+ if detail.Modality == "AUDIO" {
+ usage.PromptTokensDetails.AudioTokens += detail.TokenCount
+ } else if detail.Modality == "TEXT" {
+ usage.PromptTokensDetails.TextTokens += detail.TokenCount
+ }
+ }
+ for _, detail := range metadata.ToolUsePromptTokensDetails {
+ if detail.Modality == "AUDIO" {
+ usage.PromptTokensDetails.AudioTokens += detail.TokenCount
+ } else if detail.Modality == "TEXT" {
+ usage.PromptTokensDetails.TextTokens += detail.TokenCount
+ }
+ }
+
+ if usage.TotalTokens > 0 && usage.CompletionTokens <= 0 {
+ usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+ }
+
+ if usage.PromptTokens > 0 && usage.PromptTokensDetails.TextTokens == 0 && usage.PromptTokensDetails.AudioTokens == 0 {
+ usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+ }
+
+ return usage
+}
+
func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: helper.GetResponseID(c),
@@ -1272,18 +1312,8 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
// 更新使用量统计
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
- usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
- usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
- usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
- usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
- usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
- for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
- if detail.Modality == "AUDIO" {
- usage.PromptTokensDetails.AudioTokens = detail.TokenCount
- } else if detail.Modality == "TEXT" {
- usage.PromptTokensDetails.TextTokens = detail.TokenCount
- }
- }
+ mappedUsage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
+ *usage = mappedUsage
}
return callback(data, &geminiResponse)
@@ -1295,11 +1325,6 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
}
}
- usage.PromptTokensDetails.TextTokens = usage.PromptTokens
- if usage.TotalTokens > 0 {
- usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
- }
-
if usage.CompletionTokens <= 0 {
if info.ReceivedResponseCount > 0 {
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
@@ -1416,21 +1441,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if len(geminiResponse.Candidates) == 0 {
- usage := dto.Usage{
- PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
- }
- usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
- usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
- for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
- if detail.Modality == "AUDIO" {
- usage.PromptTokensDetails.AudioTokens = detail.TokenCount
- } else if detail.Modality == "TEXT" {
- usage.PromptTokensDetails.TextTokens = detail.TokenCount
- }
- }
- if usage.PromptTokens <= 0 {
- usage.PromptTokens = info.GetEstimatePromptTokens()
- }
+ usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
var newAPIError *types.NewAPIError
if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
@@ -1466,23 +1477,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
}
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
fullTextResponse.Model = info.UpstreamModelName
- usage := dto.Usage{
- PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
- CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
- TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
- }
-
- usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
- usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
- usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
-
- for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
- if detail.Modality == "AUDIO" {
- usage.PromptTokensDetails.AudioTokens = detail.TokenCount
- } else if detail.Modality == "TEXT" {
- usage.PromptTokensDetails.TextTokens = detail.TokenCount
- }
- }
+ usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
fullTextResponse.Usage = usage
diff --git a/relay/channel/gemini/relay_gemini_usage_test.go b/relay/channel/gemini/relay_gemini_usage_test.go
new file mode 100644
index 000000000..c8f9f8343
--- /dev/null
+++ b/relay/channel/gemini/relay_gemini_usage_test.go
@@ -0,0 +1,333 @@
+package gemini
+
+import (
+ "bytes"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/constant"
+ "github.com/QuantumNous/new-api/dto"
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
+ "github.com/QuantumNous/new-api/types"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
+ t.Parallel()
+
+ gin.SetMode(gin.TestMode)
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+ info := &relaycommon.RelayInfo{
+ RelayFormat: types.RelayFormatGemini,
+ OriginModelName: "gemini-3-flash-preview",
+ ChannelMeta: &relaycommon.ChannelMeta{
+ UpstreamModelName: "gemini-3-flash-preview",
+ },
+ }
+
+ payload := dto.GeminiChatResponse{
+ Candidates: []dto.GeminiChatCandidate{
+ {
+ Content: dto.GeminiChatContent{
+ Role: "model",
+ Parts: []dto.GeminiPart{
+ {Text: "ok"},
+ },
+ },
+ },
+ },
+ UsageMetadata: dto.GeminiUsageMetadata{
+ PromptTokenCount: 151,
+ ToolUsePromptTokenCount: 18329,
+ CandidatesTokenCount: 1089,
+ ThoughtsTokenCount: 1120,
+ TotalTokenCount: 20689,
+ },
+ }
+
+ body, err := common.Marshal(payload)
+ require.NoError(t, err)
+
+ resp := &http.Response{
+ Body: io.NopCloser(bytes.NewReader(body)),
+ }
+
+ usage, newAPIError := GeminiChatHandler(c, info, resp)
+ require.Nil(t, newAPIError)
+ require.NotNil(t, usage)
+ require.Equal(t, 18480, usage.PromptTokens)
+ require.Equal(t, 2209, usage.CompletionTokens)
+ require.Equal(t, 20689, usage.TotalTokens)
+ require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
+}
+
+func TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+ oldStreamingTimeout := constant.StreamingTimeout
+ constant.StreamingTimeout = 300
+ t.Cleanup(func() {
+ constant.StreamingTimeout = oldStreamingTimeout
+ })
+
+ info := &relaycommon.RelayInfo{
+ OriginModelName: "gemini-3-flash-preview",
+ ChannelMeta: &relaycommon.ChannelMeta{
+ UpstreamModelName: "gemini-3-flash-preview",
+ },
+ }
+
+ chunk := dto.GeminiChatResponse{
+ Candidates: []dto.GeminiChatCandidate{
+ {
+ Content: dto.GeminiChatContent{
+ Role: "model",
+ Parts: []dto.GeminiPart{
+ {Text: "partial"},
+ },
+ },
+ },
+ },
+ UsageMetadata: dto.GeminiUsageMetadata{
+ PromptTokenCount: 151,
+ ToolUsePromptTokenCount: 18329,
+ CandidatesTokenCount: 1089,
+ ThoughtsTokenCount: 1120,
+ TotalTokenCount: 20689,
+ },
+ }
+
+ chunkData, err := common.Marshal(chunk)
+ require.NoError(t, err)
+
+ streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
+ resp := &http.Response{
+ Body: io.NopCloser(bytes.NewReader(streamBody)),
+ }
+
+ usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
+ return true
+ })
+ require.Nil(t, newAPIError)
+ require.NotNil(t, usage)
+ require.Equal(t, 18480, usage.PromptTokens)
+ require.Equal(t, 2209, usage.CompletionTokens)
+ require.Equal(t, 20689, usage.TotalTokens)
+ require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
+}
+
+func TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens(t *testing.T) {
+ t.Parallel()
+
+ gin.SetMode(gin.TestMode)
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
+
+ info := &relaycommon.RelayInfo{
+ OriginModelName: "gemini-3-flash-preview",
+ ChannelMeta: &relaycommon.ChannelMeta{
+ UpstreamModelName: "gemini-3-flash-preview",
+ },
+ }
+
+ payload := dto.GeminiChatResponse{
+ Candidates: []dto.GeminiChatCandidate{
+ {
+ Content: dto.GeminiChatContent{
+ Role: "model",
+ Parts: []dto.GeminiPart{
+ {Text: "ok"},
+ },
+ },
+ },
+ },
+ UsageMetadata: dto.GeminiUsageMetadata{
+ PromptTokenCount: 151,
+ ToolUsePromptTokenCount: 18329,
+ CandidatesTokenCount: 1089,
+ ThoughtsTokenCount: 1120,
+ TotalTokenCount: 20689,
+ },
+ }
+
+ body, err := common.Marshal(payload)
+ require.NoError(t, err)
+
+ resp := &http.Response{
+ Body: io.NopCloser(bytes.NewReader(body)),
+ }
+
+ usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
+ require.Nil(t, newAPIError)
+ require.NotNil(t, usage)
+ require.Equal(t, 18480, usage.PromptTokens)
+ require.Equal(t, 2209, usage.CompletionTokens)
+ require.Equal(t, 20689, usage.TotalTokens)
+ require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
+}
+
+func TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
+ t.Parallel()
+
+ gin.SetMode(gin.TestMode)
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+ info := &relaycommon.RelayInfo{
+ RelayFormat: types.RelayFormatGemini,
+ OriginModelName: "gemini-3-flash-preview",
+ ChannelMeta: &relaycommon.ChannelMeta{
+ UpstreamModelName: "gemini-3-flash-preview",
+ },
+ }
+ info.SetEstimatePromptTokens(20)
+
+ payload := dto.GeminiChatResponse{
+ Candidates: []dto.GeminiChatCandidate{
+ {
+ Content: dto.GeminiChatContent{
+ Role: "model",
+ Parts: []dto.GeminiPart{
+ {Text: "ok"},
+ },
+ },
+ },
+ },
+ UsageMetadata: dto.GeminiUsageMetadata{
+ PromptTokenCount: 0,
+ ToolUsePromptTokenCount: 0,
+ CandidatesTokenCount: 90,
+ ThoughtsTokenCount: 10,
+ TotalTokenCount: 110,
+ },
+ }
+
+ body, err := common.Marshal(payload)
+ require.NoError(t, err)
+
+ resp := &http.Response{
+ Body: io.NopCloser(bytes.NewReader(body)),
+ }
+
+ usage, newAPIError := GeminiChatHandler(c, info, resp)
+ require.Nil(t, newAPIError)
+ require.NotNil(t, usage)
+ require.Equal(t, 20, usage.PromptTokens)
+ require.Equal(t, 100, usage.CompletionTokens)
+ require.Equal(t, 110, usage.TotalTokens)
+}
+
+func TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+ oldStreamingTimeout := constant.StreamingTimeout
+ constant.StreamingTimeout = 300
+ t.Cleanup(func() {
+ constant.StreamingTimeout = oldStreamingTimeout
+ })
+
+ info := &relaycommon.RelayInfo{
+ OriginModelName: "gemini-3-flash-preview",
+ ChannelMeta: &relaycommon.ChannelMeta{
+ UpstreamModelName: "gemini-3-flash-preview",
+ },
+ }
+ info.SetEstimatePromptTokens(20)
+
+ chunk := dto.GeminiChatResponse{
+ Candidates: []dto.GeminiChatCandidate{
+ {
+ Content: dto.GeminiChatContent{
+ Role: "model",
+ Parts: []dto.GeminiPart{
+ {Text: "partial"},
+ },
+ },
+ },
+ },
+ UsageMetadata: dto.GeminiUsageMetadata{
+ PromptTokenCount: 0,
+ ToolUsePromptTokenCount: 0,
+ CandidatesTokenCount: 90,
+ ThoughtsTokenCount: 10,
+ TotalTokenCount: 110,
+ },
+ }
+
+ chunkData, err := common.Marshal(chunk)
+ require.NoError(t, err)
+
+ streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
+ resp := &http.Response{
+ Body: io.NopCloser(bytes.NewReader(streamBody)),
+ }
+
+ usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
+ return true
+ })
+ require.Nil(t, newAPIError)
+ require.NotNil(t, usage)
+ require.Equal(t, 20, usage.PromptTokens)
+ require.Equal(t, 100, usage.CompletionTokens)
+ require.Equal(t, 110, usage.TotalTokens)
+}
+
+func TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
+ t.Parallel()
+
+ gin.SetMode(gin.TestMode)
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
+
+ info := &relaycommon.RelayInfo{
+ OriginModelName: "gemini-3-flash-preview",
+ ChannelMeta: &relaycommon.ChannelMeta{
+ UpstreamModelName: "gemini-3-flash-preview",
+ },
+ }
+ info.SetEstimatePromptTokens(20)
+
+ payload := dto.GeminiChatResponse{
+ Candidates: []dto.GeminiChatCandidate{
+ {
+ Content: dto.GeminiChatContent{
+ Role: "model",
+ Parts: []dto.GeminiPart{
+ {Text: "ok"},
+ },
+ },
+ },
+ },
+ UsageMetadata: dto.GeminiUsageMetadata{
+ PromptTokenCount: 0,
+ ToolUsePromptTokenCount: 0,
+ CandidatesTokenCount: 90,
+ ThoughtsTokenCount: 10,
+ TotalTokenCount: 110,
+ },
+ }
+
+ body, err := common.Marshal(payload)
+ require.NoError(t, err)
+
+ resp := &http.Response{
+ Body: io.NopCloser(bytes.NewReader(body)),
+ }
+
+ usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
+ require.Nil(t, newAPIError)
+ require.NotNil(t, usage)
+ require.Equal(t, 20, usage.PromptTokens)
+ require.Equal(t, 100, usage.CompletionTokens)
+ require.Equal(t, 110, usage.TotalTokens)
+}
diff --git a/relay/channel/minimax/adaptor.go b/relay/channel/minimax/adaptor.go
index 8235abc05..d244e695a 100644
--- a/relay/channel/minimax/adaptor.go
+++ b/relay/channel/minimax/adaptor.go
@@ -10,6 +10,7 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
+ "github.com/QuantumNous/new-api/relay/channel/claude"
"github.com/QuantumNous/new-api/relay/channel/openai"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/constant"
@@ -26,7 +27,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
- return nil, errors.New("not implemented")
+ adaptor := claude.Adaptor{}
+ return adaptor.ConvertClaudeRequest(c, info, req)
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -119,8 +121,14 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return handleTTSResponse(c, resp, info)
}
- adaptor := openai.Adaptor{}
- return adaptor.DoResponse(c, resp, info)
+ switch info.RelayFormat {
+ case types.RelayFormatClaude:
+ adaptor := claude.Adaptor{}
+ return adaptor.DoResponse(c, resp, info)
+ default:
+ adaptor := openai.Adaptor{}
+ return adaptor.DoResponse(c, resp, info)
+ }
}
func (a *Adaptor) GetModelList() []string {
diff --git a/relay/channel/minimax/relay-minimax.go b/relay/channel/minimax/relay-minimax.go
index b314e69d7..c249de6a4 100644
--- a/relay/channel/minimax/relay-minimax.go
+++ b/relay/channel/minimax/relay-minimax.go
@@ -6,6 +6,7 @@ import (
channelconstant "github.com/QuantumNous/new-api/constant"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/constant"
+ "github.com/QuantumNous/new-api/types"
)
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -13,13 +14,17 @@ func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if baseUrl == "" {
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeMiniMax]
}
-
- switch info.RelayMode {
- case constant.RelayModeChatCompletions:
- return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil
- case constant.RelayModeAudioSpeech:
- return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil
+ switch info.RelayFormat {
+ case types.RelayFormatClaude:
+ return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
default:
- return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
+ switch info.RelayMode {
+ case constant.RelayModeChatCompletions:
+ return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil
+ case constant.RelayModeAudioSpeech:
+ return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil
+ default:
+ return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
+ }
}
}
diff --git a/relay/channel/task/ali/adaptor.go b/relay/channel/task/ali/adaptor.go
index d55452c08..f698fc9f6 100644
--- a/relay/channel/task/ali/adaptor.go
+++ b/relay/channel/task/ali/adaptor.go
@@ -13,6 +13,7 @@ import (
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
+ "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/samber/lo"
@@ -108,10 +109,10 @@ type AliMetadata struct {
// ============================
type TaskAdaptor struct {
+ taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
- aliReq *AliVideoRequest
}
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
@@ -121,17 +122,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
- // 阿里通义万相支持 JSON 格式,不使用 multipart
- var taskReq relaycommon.TaskSubmitReq
- if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
- return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest)
- }
- aliReq, err := a.convertToAliRequest(info, taskReq)
- if err != nil {
- return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError)
- }
- a.aliReq = aliReq
- logger.LogJson(c, "ali video request body", aliReq)
+ // ValidateMultipartDirect 负责解析并将原始 TaskSubmitReq 存入 context
return relaycommon.ValidateMultipartDirect(c, info)
}
@@ -148,11 +139,21 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
}
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
- bodyBytes, err := common.Marshal(a.aliReq)
+ taskReq, err := relaycommon.GetTaskRequest(c)
+ if err != nil {
+ return nil, errors.Wrap(err, "get_task_request_failed")
+ }
+
+ aliReq, err := a.convertToAliRequest(info, taskReq)
+ if err != nil {
+ return nil, errors.Wrap(err, "convert_to_ali_request_failed")
+ }
+ logger.LogJson(c, "ali video request body", aliReq)
+
+ bodyBytes, err := common.Marshal(aliReq)
if err != nil {
return nil, errors.Wrap(err, "marshal_ali_request_failed")
}
-
return bytes.NewReader(bodyBytes), nil
}
@@ -252,8 +253,12 @@ func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error)
}
func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
+ upstreamModel := req.Model
+ if info.IsModelMapped {
+ upstreamModel = info.UpstreamModelName
+ }
aliReq := &AliVideoRequest{
- Model: req.Model,
+ Model: upstreamModel,
Input: AliVideoInput{
Prompt: req.Prompt,
ImgURL: req.InputReference,
@@ -331,23 +336,37 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
}
}
- if aliReq.Model != req.Model {
+ if aliReq.Model != upstreamModel {
return nil, errors.New("can't change model with metadata")
}
- info.PriceData.OtherRatios = map[string]float64{
+ return aliReq, nil
+}
+
+// EstimateBilling 根据用户请求参数计算 OtherRatios(时长、分辨率等)。
+// 在 ValidateRequestAndSetAction 之后、价格计算之前调用。
+func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
+ taskReq, err := relaycommon.GetTaskRequest(c)
+ if err != nil {
+ return nil
+ }
+
+ aliReq, err := a.convertToAliRequest(info, taskReq)
+ if err != nil {
+ return nil
+ }
+
+ otherRatios := map[string]float64{
"seconds": float64(aliReq.Parameters.Duration),
}
-
ratios, err := ProcessAliOtherRatios(aliReq)
if err != nil {
- return nil, err
+ return otherRatios
}
- for s, f := range ratios {
- info.PriceData.OtherRatios[s] = f
+ for k, v := range ratios {
+ otherRatios[k] = v
}
-
- return aliReq, nil
+ return otherRatios
}
// DoRequest delegates to common helper
@@ -384,7 +403,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
// 转换为 OpenAI 格式响应
openAIResp := dto.NewOpenAIVideo()
- openAIResp.ID = aliResp.Output.TaskID
+ openAIResp.ID = info.PublicTaskID
+ openAIResp.TaskID = info.PublicTaskID
openAIResp.Model = c.GetString("model")
if openAIResp.Model == "" && info != nil {
openAIResp.Model = info.OriginModelName
diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go
index 6ebecb3c0..8f1d748ce 100644
--- a/relay/channel/task/doubao/adaptor.go
+++ b/relay/channel/task/doubao/adaptor.go
@@ -2,7 +2,6 @@ package doubao
import (
"bytes"
- "encoding/json"
"fmt"
"io"
"net/http"
@@ -14,6 +13,7 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
+ taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
@@ -89,6 +89,7 @@ type responseTask struct {
// ============================
type TaskAdaptor struct {
+ taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
@@ -130,8 +131,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
- info.UpstreamModelName = body.Model
- data, err := json.Marshal(body)
+ if info.IsModelMapped {
+ body.Model = info.UpstreamModelName
+ } else {
+ info.UpstreamModelName = body.Model
+ }
+ data, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -154,7 +159,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
// Parse Doubao response
var dResp responsePayload
- if err := json.Unmarshal(responseBody, &dResp); err != nil {
+ if err := common.Unmarshal(responseBody, &dResp); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
@@ -165,8 +170,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
ov := dto.NewOpenAIVideo()
- ov.ID = dResp.ID
- ov.TaskID = dResp.ID
+ ov.ID = info.PublicTaskID
+ ov.TaskID = info.PublicTaskID
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
@@ -234,12 +239,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
}
metadata := req.Metadata
- medaBytes, err := json.Marshal(metadata)
- if err != nil {
- return nil, errors.Wrap(err, "metadata marshal metadata failed")
- }
- err = json.Unmarshal(medaBytes, &r)
- if err != nil {
+ if err := taskcommon.UnmarshalMetadata(metadata, &r); err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
@@ -248,7 +248,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
resTask := responseTask{}
- if err := json.Unmarshal(respBody, &resTask); err != nil {
+ if err := common.Unmarshal(respBody, &resTask); err != nil {
return nil, errors.Wrap(err, "unmarshal task result failed")
}
@@ -286,7 +286,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
var dResp responseTask
- if err := json.Unmarshal(originTask.Data, &dResp); err != nil {
+ if err := common.Unmarshal(originTask.Data, &dResp); err != nil {
return nil, errors.Wrap(err, "unmarshal doubao task data failed")
}
@@ -307,6 +307,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
}
}
- jsonData, _ := common.Marshal(openAIVideo)
- return jsonData, nil
+ return common.Marshal(openAIVideo)
}
diff --git a/relay/channel/task/gemini/adaptor.go b/relay/channel/task/gemini/adaptor.go
index 16c6919b7..5644cd5dc 100644
--- a/relay/channel/task/gemini/adaptor.go
+++ b/relay/channel/task/gemini/adaptor.go
@@ -2,8 +2,6 @@ package gemini
import (
"bytes"
- "encoding/base64"
- "encoding/json"
"fmt"
"io"
"net/http"
@@ -16,10 +14,10 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
+ taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
- "github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
)
@@ -87,6 +85,7 @@ type operationResponse struct {
// ============================
type TaskAdaptor struct {
+ taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
@@ -106,7 +105,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
- modelName := info.OriginModelName
+ modelName := info.UpstreamModelName
version := model_setting.GetGeminiVersionSetting(modelName)
return fmt.Sprintf(
@@ -145,16 +144,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
metadata := req.Metadata
- medaBytes, err := json.Marshal(metadata)
- if err != nil {
- return nil, errors.Wrap(err, "metadata marshal metadata failed")
- }
- err = json.Unmarshal(medaBytes, &body.Parameters)
- if err != nil {
+ if err := taskcommon.UnmarshalMetadata(metadata, &body.Parameters); err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
- data, err := json.Marshal(body)
+ data, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -175,16 +169,16 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
_ = resp.Body.Close()
var s submitResponse
- if err := json.Unmarshal(responseBody, &s); err != nil {
+ if err := common.Unmarshal(responseBody, &s); err != nil {
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
}
if strings.TrimSpace(s.Name) == "" {
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
}
- taskID = encodeLocalTaskID(s.Name)
+ taskID = taskcommon.EncodeLocalTaskID(s.Name)
ov := dto.NewOpenAIVideo()
- ov.ID = taskID
- ov.TaskID = taskID
+ ov.ID = info.PublicTaskID
+ ov.TaskID = info.PublicTaskID
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
c.JSON(http.StatusOK, ov)
@@ -206,7 +200,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
return nil, fmt.Errorf("invalid task_id")
}
- upstreamName, err := decodeLocalTaskID(taskID)
+ upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
if err != nil {
return nil, fmt.Errorf("decode task_id failed: %w", err)
}
@@ -232,7 +226,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
var op operationResponse
- if err := json.Unmarshal(respBody, &op); err != nil {
+ if err := common.Unmarshal(respBody, &op); err != nil {
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
}
@@ -254,9 +248,8 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
ti.Status = model.TaskStatusSuccess
ti.Progress = "100%"
- taskID := encodeLocalTaskID(op.Name)
- ti.TaskID = taskID
- ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
+ ti.TaskID = taskcommon.EncodeLocalTaskID(op.Name)
+ // Url intentionally left empty — the caller constructs the proxy URL using the public task ID
// Extract URL from generateVideoResponse if available
if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 {
@@ -269,7 +262,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
- upstreamName, err := decodeLocalTaskID(task.TaskID)
+ // Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
+ // task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
+ upstreamTaskID := task.GetUpstreamTaskID()
+ upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
if err != nil {
upstreamName = ""
}
@@ -297,18 +293,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
// helpers
// ============================
-func encodeLocalTaskID(name string) string {
- return base64.RawURLEncoding.EncodeToString([]byte(name))
-}
-
-func decodeLocalTaskID(local string) (string, error) {
- b, err := base64.RawURLEncoding.DecodeString(local)
- if err != nil {
- return "", err
- }
- return string(b), nil
-}
-
var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
func extractModelFromOperationName(name string) string {
diff --git a/relay/channel/task/hailuo/adaptor.go b/relay/channel/task/hailuo/adaptor.go
index c77905bfb..28b3a97f1 100644
--- a/relay/channel/task/hailuo/adaptor.go
+++ b/relay/channel/task/hailuo/adaptor.go
@@ -2,7 +2,6 @@ package hailuo
import (
"bytes"
- "encoding/json"
"fmt"
"io"
"net/http"
@@ -18,12 +17,14 @@ import (
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
+ taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
)
// https://platform.minimaxi.com/docs/api-reference/video-generation-intro
type TaskAdaptor struct {
+ taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
@@ -60,12 +61,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
return nil, fmt.Errorf("invalid request type in context")
}
- body, err := a.convertToRequestPayload(&req)
+ body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
- data, err := json.Marshal(body)
+ data, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -86,7 +87,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
_ = resp.Body.Close()
var hResp VideoResponse
- if err := json.Unmarshal(responseBody, &hResp); err != nil {
+ if err := common.Unmarshal(responseBody, &hResp); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
@@ -101,8 +102,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
ov := dto.NewOpenAIVideo()
- ov.ID = hResp.TaskID
- ov.TaskID = hResp.TaskID
+ ov.ID = info.PublicTaskID
+ ov.TaskID = info.PublicTaskID
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
@@ -141,8 +142,8 @@ func (a *TaskAdaptor) GetChannelName() string {
return ChannelName
}
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) {
- modelConfig := GetModelConfig(req.Model)
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*VideoRequest, error) {
+ modelConfig := GetModelConfig(info.UpstreamModelName)
duration := DefaultDuration
if req.Duration > 0 {
duration = req.Duration
@@ -153,7 +154,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
}
videoRequest := &VideoRequest{
- Model: req.Model,
+ Model: info.UpstreamModelName,
Prompt: req.Prompt,
Duration: &duration,
Resolution: resolution,
@@ -182,7 +183,7 @@ func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConf
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
resTask := QueryTaskResponse{}
- if err := json.Unmarshal(respBody, &resTask); err != nil {
+ if err := common.Unmarshal(respBody, &resTask); err != nil {
return nil, errors.Wrap(err, "unmarshal task result failed")
}
@@ -224,7 +225,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
var hailuoResp QueryTaskResponse
- if err := json.Unmarshal(originTask.Data, &hailuoResp); err != nil {
+ if err := common.Unmarshal(originTask.Data, &hailuoResp); err != nil {
return nil, errors.Wrap(err, "unmarshal hailuo task data failed")
}
@@ -271,7 +272,7 @@ func (a *TaskAdaptor) buildVideoURL(_, fileID string) string {
}
var retrieveResp RetrieveFileResponse
- if err := json.Unmarshal(responseBody, &retrieveResp); err != nil {
+ if err := common.Unmarshal(responseBody, &retrieveResp); err != nil {
return ""
}
diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go
index 1522a967f..e6211b1e4 100644
--- a/relay/channel/task/jimeng/adaptor.go
+++ b/relay/channel/task/jimeng/adaptor.go
@@ -6,7 +6,6 @@ import (
"crypto/sha256"
"encoding/base64"
"encoding/hex"
- "encoding/json"
"fmt"
"io"
"net/http"
@@ -25,6 +24,7 @@ import (
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
+ taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
)
@@ -77,6 +77,7 @@ const (
// ============================
type TaskAdaptor struct {
+ taskcommon.BaseBilling
ChannelType int
accessKey string
secretKey string
@@ -164,11 +165,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
}
- body, err := a.convertToRequestPayload(&req)
+ body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
- data, err := json.Marshal(body)
+ data, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -191,7 +192,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
// Parse Jimeng response
var jResp responsePayload
- if err := json.Unmarshal(responseBody, &jResp); err != nil {
+ if err := common.Unmarshal(responseBody, &jResp); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
@@ -202,8 +203,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
ov := dto.NewOpenAIVideo()
- ov.ID = jResp.Data.TaskID
- ov.TaskID = jResp.Data.TaskID
+ ov.ID = info.PublicTaskID
+ ov.TaskID = info.PublicTaskID
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
c.JSON(http.StatusOK, ov)
@@ -225,7 +226,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
"task_id": taskID,
}
- payloadBytes, err := json.Marshal(payload)
+ payloadBytes, err := common.Marshal(payload)
if err != nil {
return nil, errors.Wrap(err, "marshal fetch task payload failed")
}
@@ -377,9 +378,9 @@ func hmacSHA256(key []byte, data []byte) []byte {
return h.Sum(nil)
}
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
r := requestPayload{
- ReqKey: req.Model,
+ ReqKey: info.UpstreamModelName,
Prompt: req.Prompt,
}
@@ -398,13 +399,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
r.BinaryDataBase64 = req.Images
}
}
- metadata := req.Metadata
- medaBytes, err := json.Marshal(metadata)
- if err != nil {
- return nil, errors.Wrap(err, "metadata marshal metadata failed")
- }
- err = json.Unmarshal(medaBytes, &r)
- if err != nil {
+ if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
@@ -432,7 +427,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
resTask := responseTask{}
- if err := json.Unmarshal(respBody, &resTask); err != nil {
+ if err := common.Unmarshal(respBody, &resTask); err != nil {
return nil, errors.Wrap(err, "unmarshal task result failed")
}
taskResult := relaycommon.TaskInfo{}
@@ -458,7 +453,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
var jimengResp responseTask
- if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil {
+ if err := common.Unmarshal(originTask.Data, &jimengResp); err != nil {
return nil, errors.Wrap(err, "unmarshal jimeng task data failed")
}
@@ -477,8 +472,7 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
}
}
- jsonData, _ := common.Marshal(openAIVideo)
- return jsonData, nil
+ return common.Marshal(openAIVideo)
}
func isNewAPIRelay(apiKey string) bool {
diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go
index 5fb853481..cdbb56878 100644
--- a/relay/channel/task/kling/adaptor.go
+++ b/relay/channel/task/kling/adaptor.go
@@ -2,7 +2,6 @@ package kling
import (
"bytes"
- "encoding/json"
"fmt"
"io"
"net/http"
@@ -21,6 +20,7 @@ import (
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
+ taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
)
@@ -97,6 +97,7 @@ type responsePayload struct {
// ============================
type TaskAdaptor struct {
+ taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
@@ -149,14 +150,14 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
req := v.(relaycommon.TaskSubmitReq)
- body, err := a.convertToRequestPayload(&req)
+ body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, err
}
if body.Image == "" && body.ImageTail == "" {
c.Set("action", constant.TaskActionTextGenerate)
}
- data, err := json.Marshal(body)
+ data, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -180,7 +181,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
var kResp responsePayload
- err = json.Unmarshal(responseBody, &kResp)
+ err = common.Unmarshal(responseBody, &kResp)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
return
@@ -190,8 +191,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return
}
ov := dto.NewOpenAIVideo()
- ov.ID = kResp.Data.TaskId
- ov.TaskID = kResp.Data.TaskId
+ ov.ID = info.PublicTaskID
+ ov.TaskID = info.PublicTaskID
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
c.JSON(http.StatusOK, ov)
@@ -247,15 +248,15 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers
// ============================
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
r := requestPayload{
Prompt: req.Prompt,
Image: req.Image,
- Mode: defaultString(req.Mode, "std"),
- Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
+ Mode: taskcommon.DefaultString(req.Mode, "std"),
+ Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)),
AspectRatio: a.getAspectRatio(req.Size),
- ModelName: req.Model,
- Model: req.Model, // Keep consistent with model_name, double writing improves compatibility
+ ModelName: info.UpstreamModelName,
+ Model: info.UpstreamModelName,
CfgScale: 0.5,
StaticMask: "",
DynamicMasks: []DynamicMask{},
@@ -265,14 +266,9 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
}
if r.ModelName == "" {
r.ModelName = "kling-v1"
+ r.Model = "kling-v1"
}
- metadata := req.Metadata
- medaBytes, err := json.Marshal(metadata)
- if err != nil {
- return nil, errors.Wrap(err, "metadata marshal metadata failed")
- }
- err = json.Unmarshal(medaBytes, &r)
- if err != nil {
+ if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
return &r, nil
@@ -291,20 +287,6 @@ func (a *TaskAdaptor) getAspectRatio(size string) string {
}
}
-func defaultString(s, def string) string {
- if strings.TrimSpace(s) == "" {
- return def
- }
- return s
-}
-
-func defaultInt(v int, def int) int {
- if v == 0 {
- return def
- }
- return v
-}
-
// ============================
// JWT helpers
// ============================
@@ -340,7 +322,7 @@ func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
taskInfo := &relaycommon.TaskInfo{}
resPayload := responsePayload{}
- err := json.Unmarshal(respBody, &resPayload)
+ err := common.Unmarshal(respBody, &resPayload)
if err != nil {
return nil, errors.Wrap(err, "failed to unmarshal response body")
}
@@ -374,7 +356,7 @@ func isNewAPIRelay(apiKey string) bool {
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
var klingResp responsePayload
- if err := json.Unmarshal(originTask.Data, &klingResp); err != nil {
+ if err := common.Unmarshal(originTask.Data, &klingResp); err != nil {
return nil, errors.Wrap(err, "unmarshal kling task data failed")
}
@@ -401,6 +383,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
Code: fmt.Sprintf("%d", klingResp.Code),
}
}
- jsonData, _ := common.Marshal(openAIVideo)
- return jsonData, nil
+ return common.Marshal(openAIVideo)
}
diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go
index c149f9663..e9029aa20 100644
--- a/relay/channel/task/sora/adaptor.go
+++ b/relay/channel/task/sora/adaptor.go
@@ -1,9 +1,13 @@
package sora
import (
+ "bytes"
"fmt"
"io"
+ "mime/multipart"
"net/http"
+ "net/textproto"
+ "strconv"
"strings"
"github.com/QuantumNous/new-api/common"
@@ -11,12 +15,13 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
+ taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
- "github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
+ "github.com/tidwall/sjson"
)
// ============================
@@ -57,6 +62,7 @@ type responseTask struct {
// ============================
type TaskAdaptor struct {
+ taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
@@ -69,15 +75,15 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
}
func validateRemixRequest(c *gin.Context) *dto.TaskError {
- var req struct {
- Prompt string `json:"prompt"`
- }
+ var req relaycommon.TaskSubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
}
if strings.TrimSpace(req.Prompt) == "" {
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
}
+ // Store the original request in the context, consistent with the ValidateMultipartDirect path
+ c.Set("task_request", req)
return nil
}
@@ -88,6 +94,41 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
return relaycommon.ValidateMultipartDirect(c, info)
}
+// EstimateBilling computes OtherRatios from the user-requested seconds and size.
+func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
+ // For the remix path, OtherRatios has already been set in ResolveOriginTask
+ if info.Action == constant.TaskActionRemix {
+ return nil
+ }
+
+ req, err := relaycommon.GetTaskRequest(c)
+ if err != nil {
+ return nil
+ }
+
+ seconds, _ := strconv.Atoi(req.Seconds)
+ if seconds == 0 {
+ seconds = req.Duration
+ }
+ if seconds <= 0 {
+ seconds = 4
+ }
+
+ size := req.Size
+ if size == "" {
+ size = "720x1280"
+ }
+
+ ratios := map[string]float64{
+ "seconds": float64(seconds),
+ "size": 1,
+ }
+ if size == "1792x1024" || size == "1024x1792" {
+ ratios["size"] = 1.666667
+ }
+ return ratios
+}
+
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.Action == constant.TaskActionRemix {
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
@@ -107,6 +148,74 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
if err != nil {
return nil, errors.Wrap(err, "get_request_body_failed")
}
+ cachedBody, err := storage.Bytes()
+ if err != nil {
+ return nil, errors.Wrap(err, "read_body_bytes_failed")
+ }
+ contentType := c.GetHeader("Content-Type")
+
+ if strings.HasPrefix(contentType, "application/json") {
+ var bodyMap map[string]interface{}
+ if err := common.Unmarshal(cachedBody, &bodyMap); err == nil {
+ bodyMap["model"] = info.UpstreamModelName
+ if newBody, err := common.Marshal(bodyMap); err == nil {
+ return bytes.NewReader(newBody), nil
+ }
+ }
+ return bytes.NewReader(cachedBody), nil
+ }
+
+ if strings.Contains(contentType, "multipart/form-data") {
+ formData, err := common.ParseMultipartFormReusable(c)
+ if err != nil {
+ return bytes.NewReader(cachedBody), nil
+ }
+ var buf bytes.Buffer
+ writer := multipart.NewWriter(&buf)
+ writer.WriteField("model", info.UpstreamModelName)
+ for key, values := range formData.Value {
+ if key == "model" {
+ continue
+ }
+ for _, v := range values {
+ writer.WriteField(key, v)
+ }
+ }
+ for fieldName, fileHeaders := range formData.File {
+ for _, fh := range fileHeaders {
+ f, err := fh.Open()
+ if err != nil {
+ continue
+ }
+ ct := fh.Header.Get("Content-Type")
+ if ct == "" || ct == "application/octet-stream" {
+ buf512 := make([]byte, 512)
+ n, _ := io.ReadFull(f, buf512)
+ ct = http.DetectContentType(buf512[:n])
+ // Re-open after sniffing so the full content is copied below
+ f.Close()
+ f, err = fh.Open()
+ if err != nil {
+ continue
+ }
+ }
+ h := make(textproto.MIMEHeader)
+ h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fh.Filename))
+ h.Set("Content-Type", ct)
+ part, err := writer.CreatePart(h)
+ if err != nil {
+ f.Close()
+ continue
+ }
+ io.Copy(part, f)
+ f.Close()
+ }
+ }
+ writer.Close()
+ c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+ return &buf, nil
+ }
+
return common.ReaderOnly(storage), nil
}
@@ -116,7 +225,7 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, req
}
// DoResponse handles upstream response, returns taskID etc.
-func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
@@ -131,17 +240,20 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
return
}
- if dResp.ID == "" {
- if dResp.TaskID == "" {
- taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
- return
- }
- dResp.ID = dResp.TaskID
- dResp.TaskID = ""
+ upstreamID := dResp.ID
+ if upstreamID == "" {
+ upstreamID = dResp.TaskID
+ }
+ if upstreamID == "" {
+ taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
+ return
}
+ // Return the public task_xxxx ID to the client
+ dResp.ID = info.PublicTaskID
+ dResp.TaskID = info.PublicTaskID
c.JSON(http.StatusOK, dResp)
- return dResp.ID, responseBody, nil
+ return upstreamID, responseBody, nil
}
// FetchTask fetch task status
@@ -192,7 +304,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
taskResult.Status = model.TaskStatusInProgress
case "completed":
taskResult.Status = model.TaskStatusSuccess
- taskResult.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, resTask.ID)
+ // Url intentionally left empty — the caller constructs the proxy URL using the public task ID
case "failed", "cancelled":
taskResult.Status = model.TaskStatusFailure
if resTask.Error != nil {
@@ -210,5 +322,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
- return task.Data, nil
+ data := task.Data
+ var err error
+ if data, err = sjson.SetBytes(data, "id", task.TaskID); err != nil {
+ return nil, errors.Wrap(err, "set id failed")
+ }
+ return data, nil
}
diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go
index 8ea9a1c7f..35b5e423b 100644
--- a/relay/channel/task/suno/adaptor.go
+++ b/relay/channel/task/suno/adaptor.go
@@ -2,18 +2,16 @@ package suno
import (
"bytes"
- "context"
- "encoding/json"
"fmt"
"io"
"net/http"
"strings"
- "time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
+ taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
@@ -21,11 +19,16 @@ import (
)
type TaskAdaptor struct {
+ taskcommon.BaseBilling
ChannelType int
}
+// ParseTaskResult is not used for Suno tasks.
+// Suno polling uses a dedicated batch-fetch path (service.UpdateSunoTasks) that
+// receives dto.TaskResponse[[]dto.SunoDataResponse] from the upstream /fetch API.
+// This differs from the per-task polling used by video adaptors.
func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
- return nil, fmt.Errorf("not implement") // todo implement this method if needed
+ return nil, fmt.Errorf("suno uses batch polling via UpdateSunoTasks, ParseTaskResult is not applicable")
}
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
@@ -47,13 +50,13 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
return
}
- if sunoRequest.ContinueClipId != "" {
- if sunoRequest.TaskID == "" {
- taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
- return
- }
- info.OriginTaskID = sunoRequest.TaskID
- }
+ //if sunoRequest.ContinueClipId != "" {
+ // if sunoRequest.TaskID == "" {
+ // taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
+ // return
+ // }
+ // info.OriginTaskID = sunoRequest.TaskID
+ //}
info.Action = action
c.Set("task_request", sunoRequest)
@@ -76,12 +79,9 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
sunoRequest, ok := c.Get("task_request")
if !ok {
- err := common.UnmarshalBodyReusable(c, &sunoRequest)
- if err != nil {
- return nil, err
- }
+ return nil, fmt.Errorf("task_request not found in context")
}
- data, err := json.Marshal(sunoRequest)
+ data, err := common.Marshal(sunoRequest)
if err != nil {
return nil, err
}
@@ -99,7 +99,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return
}
var sunoResponse dto.TaskResponse[string]
- err = json.Unmarshal(responseBody, &sunoResponse)
+ err = common.Unmarshal(responseBody, &sunoResponse)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
return
@@ -109,17 +109,13 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return
}
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
-
- _, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
- return
+ // Replace the upstream ID with the public task_xxxx ID before returning to the client
+ publicResponse := dto.TaskResponse[string]{
+ Code: sunoResponse.Code,
+ Message: sunoResponse.Message,
+ Data: info.PublicTaskID,
}
+ c.JSON(http.StatusOK, publicResponse)
return sunoResponse.Data, nil, nil
}
@@ -134,7 +130,7 @@ func (a *TaskAdaptor) GetChannelName() string {
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
- byteBody, err := json.Marshal(body)
+ byteBody, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -144,13 +140,6 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
common.SysLog(fmt.Sprintf("Get Task error: %v", err))
return nil, err
}
- defer req.Body.Close()
- // 设置超时时间
- timeout := time.Second * 15
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- // 使用带有超时的 context 创建新的请求
- req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+key)
client, err := service.GetHttpClientWithProxy(proxy)
diff --git a/relay/channel/task/taskcommon/helpers.go b/relay/channel/task/taskcommon/helpers.go
new file mode 100644
index 000000000..27d6612d4
--- /dev/null
+++ b/relay/channel/task/taskcommon/helpers.go
@@ -0,0 +1,95 @@
+package taskcommon
+
+import (
+ "encoding/base64"
+ "fmt"
+
+ "github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/model"
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
+ "github.com/QuantumNous/new-api/setting/system_setting"
+ "github.com/gin-gonic/gin"
+)
+
+// UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip.
+// This replaces the repeated pattern: json.Marshal(metadata) → json.Unmarshal(bytes, &target).
+func UnmarshalMetadata(metadata map[string]any, target any) error {
+ if metadata == nil {
+ return nil
+ }
+ metaBytes, err := common.Marshal(metadata)
+ if err != nil {
+ return fmt.Errorf("marshal metadata failed: %w", err)
+ }
+ if err := common.Unmarshal(metaBytes, target); err != nil {
+ return fmt.Errorf("unmarshal metadata failed: %w", err)
+ }
+ return nil
+}
+
+// DefaultString returns val if non-empty, otherwise fallback.
+func DefaultString(val, fallback string) string {
+ if val == "" {
+ return fallback
+ }
+ return val
+}
+
+// DefaultInt returns val if non-zero, otherwise fallback.
+func DefaultInt(val, fallback int) int {
+ if val == 0 {
+ return fallback
+ }
+ return val
+}
+
+// EncodeLocalTaskID encodes an upstream operation name to a URL-safe base64 string.
+// Used by Gemini/Vertex to store upstream names as task IDs.
+func EncodeLocalTaskID(name string) string {
+ return base64.RawURLEncoding.EncodeToString([]byte(name))
+}
+
+// DecodeLocalTaskID decodes a base64-encoded upstream operation name.
+func DecodeLocalTaskID(id string) (string, error) {
+ b, err := base64.RawURLEncoding.DecodeString(id)
+ if err != nil {
+ return "", err
+ }
+ return string(b), nil
+}
+
+// BuildProxyURL constructs the video proxy URL using the public task ID.
+// e.g., "https://your-server.com/v1/videos/task_xxxx/content"
+func BuildProxyURL(taskID string) string {
+ return fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
+}
+
+// Status-to-progress mapping constants for polling updates.
+const (
+ ProgressSubmitted = "10%"
+ ProgressQueued = "20%"
+ ProgressInProgress = "30%"
+ ProgressComplete = "100%"
+)
+
+// ---------------------------------------------------------------------------
+// BaseBilling — embeddable no-op implementations for TaskAdaptor billing methods.
+// Adaptors that do not need custom billing can embed this struct directly.
+// ---------------------------------------------------------------------------
+
+type BaseBilling struct{}
+
+// EstimateBilling returns nil (no extra ratios; use base model price).
+func (BaseBilling) EstimateBilling(_ *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
+ return nil
+}
+
+// AdjustBillingOnSubmit returns nil (no submit-time adjustment).
+func (BaseBilling) AdjustBillingOnSubmit(_ *relaycommon.RelayInfo, _ []byte) map[string]float64 {
+ return nil
+}
+
+// AdjustBillingOnComplete returns 0 (keep pre-charged amount).
+func (BaseBilling) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
+ return 0
+}
diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go
index 8ec77266e..700e60976 100644
--- a/relay/channel/task/vertex/adaptor.go
+++ b/relay/channel/task/vertex/adaptor.go
@@ -2,13 +2,12 @@ package vertex
import (
"bytes"
- "encoding/base64"
- "encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strings"
+ "time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
@@ -17,6 +16,7 @@ import (
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
+ taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
vertexcore "github.com/QuantumNous/new-api/relay/channel/vertex"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
@@ -62,6 +62,7 @@ type operationResponse struct {
// ============================
type TaskAdaptor struct {
+ taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
@@ -82,10 +83,10 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
adc := &vertexcore.Credentials{}
- if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
+ if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
return "", fmt.Errorf("failed to decode credentials: %w", err)
}
- modelName := info.OriginModelName
+ modelName := info.UpstreamModelName
if modelName == "" {
modelName = "veo-3.0-generate-001"
}
@@ -116,7 +117,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
req.Header.Set("Accept", "application/json")
adc := &vertexcore.Credentials{}
- if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
+ if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
return fmt.Errorf("failed to decode credentials: %w", err)
}
@@ -133,6 +134,28 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
return nil
}
+// EstimateBilling computes OtherRatios from the sampleCount in the user request.
+func (a *TaskAdaptor) EstimateBilling(c *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
+ sampleCount := 1
+ v, ok := c.Get("task_request")
+ if ok {
+ req := v.(relaycommon.TaskSubmitReq)
+ if req.Metadata != nil {
+ if sc, exists := req.Metadata["sampleCount"]; exists {
+ if i, ok := sc.(int); ok && i > 0 {
+ sampleCount = i
+ }
+ if f, ok := sc.(float64); ok && int(f) > 0 {
+ sampleCount = int(f)
+ }
+ }
+ }
+ }
+ return map[string]float64{
+ "sampleCount": float64(sampleCount),
+ }
+}
+
// BuildRequestBody converts request into Vertex specific format.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
v, ok := c.Get("task_request")
@@ -166,25 +189,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
return nil, fmt.Errorf("sampleCount must be greater than 0")
}
- // if req.Duration > 0 {
- // body.Parameters["durationSeconds"] = req.Duration
- // } else if req.Seconds != "" {
- // seconds, err := strconv.Atoi(req.Seconds)
- // if err != nil {
- // return nil, errors.Wrap(err, "convert seconds to int failed")
- // }
- // body.Parameters["durationSeconds"] = seconds
- // }
-
- info.PriceData.OtherRatios = map[string]float64{
- "sampleCount": float64(body.Parameters["sampleCount"].(int)),
- }
-
- // if v, ok := body.Parameters["durationSeconds"]; ok {
- // info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int))
- // }
-
- data, err := json.Marshal(body)
+ data, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -205,14 +210,19 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
_ = resp.Body.Close()
var s submitResponse
- if err := json.Unmarshal(responseBody, &s); err != nil {
+ if err := common.Unmarshal(responseBody, &s); err != nil {
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
}
if strings.TrimSpace(s.Name) == "" {
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
}
- localID := encodeLocalTaskID(s.Name)
- c.JSON(http.StatusOK, gin.H{"task_id": localID})
+ localID := taskcommon.EncodeLocalTaskID(s.Name)
+ ov := dto.NewOpenAIVideo()
+ ov.ID = info.PublicTaskID
+ ov.TaskID = info.PublicTaskID
+ ov.CreatedAt = time.Now().Unix()
+ ov.Model = info.OriginModelName
+ c.JSON(http.StatusOK, ov)
return localID, responseBody, nil
}
@@ -225,7 +235,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
- upstreamName, err := decodeLocalTaskID(taskID)
+ upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
if err != nil {
return nil, fmt.Errorf("decode task_id failed: %w", err)
}
@@ -245,12 +255,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
}
payload := map[string]string{"operationName": upstreamName}
- data, err := json.Marshal(payload)
+ data, err := common.Marshal(payload)
if err != nil {
return nil, err
}
adc := &vertexcore.Credentials{}
- if err := json.Unmarshal([]byte(key), adc); err != nil {
+ if err := common.Unmarshal([]byte(key), adc); err != nil {
return nil, fmt.Errorf("failed to decode credentials: %w", err)
}
token, err := vertexcore.AcquireAccessToken(*adc, proxy)
@@ -274,7 +284,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
var op operationResponse
- if err := json.Unmarshal(respBody, &op); err != nil {
+ if err := common.Unmarshal(respBody, &op); err != nil {
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
}
ti := &relaycommon.TaskInfo{}
@@ -338,7 +348,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
- upstreamName, err := decodeLocalTaskID(task.TaskID)
+ // Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
+ // task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
+ upstreamTaskID := task.GetUpstreamTaskID()
+ upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
if err != nil {
upstreamName = ""
}
@@ -353,8 +366,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
v.SetProgressStr(task.Progress)
v.CreatedAt = task.CreatedAt
v.CompletedAt = task.UpdatedAt
- if strings.HasPrefix(task.FailReason, "data:") && len(task.FailReason) > 0 {
- v.SetMetadata("url", task.FailReason)
+ if resultURL := task.GetResultURL(); strings.HasPrefix(resultURL, "data:") && len(resultURL) > 0 {
+ v.SetMetadata("url", resultURL)
}
return common.Marshal(v)
@@ -364,18 +377,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
// helpers
// ============================
-func encodeLocalTaskID(name string) string {
- return base64.RawURLEncoding.EncodeToString([]byte(name))
-}
-
-func decodeLocalTaskID(local string) (string, error) {
- b, err := base64.RawURLEncoding.DecodeString(local)
- if err != nil {
- return "", err
- }
- return string(b), nil
-}
-
var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
func extractRegionFromOperationName(name string) string {
diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go
index 3657161c0..6ae1c181b 100644
--- a/relay/channel/task/vidu/adaptor.go
+++ b/relay/channel/task/vidu/adaptor.go
@@ -2,7 +2,6 @@ package vidu
import (
"bytes"
- "encoding/json"
"fmt"
"io"
"net/http"
@@ -16,6 +15,7 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
+ taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
@@ -73,6 +73,7 @@ type creation struct {
// ============================
type TaskAdaptor struct {
+ taskcommon.BaseBilling
ChannelType int
baseURL string
}
@@ -115,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
req := v.(relaycommon.TaskSubmitReq)
- body, err := a.convertToRequestPayload(&req)
+ body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, err
}
@@ -127,7 +128,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
}
- data, err := json.Marshal(body)
+ data, err := common.Marshal(body)
if err != nil {
return nil, err
}
@@ -168,7 +169,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
var vResp responsePayload
- err = json.Unmarshal(responseBody, &vResp)
+ err = common.Unmarshal(responseBody, &vResp)
if err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
return
@@ -180,8 +181,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
ov := dto.NewOpenAIVideo()
- ov.ID = vResp.TaskId
- ov.TaskID = vResp.TaskId
+ ov.ID = info.PublicTaskID
+ ov.TaskID = info.PublicTaskID
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
c.JSON(http.StatusOK, ov)
@@ -223,47 +224,27 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers
// ============================
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
r := requestPayload{
- Model: defaultString(req.Model, "viduq1"),
+ Model: taskcommon.DefaultString(info.UpstreamModelName, "viduq1"),
Images: req.Images,
Prompt: req.Prompt,
- Duration: defaultInt(req.Duration, 5),
- Resolution: defaultString(req.Size, "1080p"),
+ Duration: taskcommon.DefaultInt(req.Duration, 5),
+ Resolution: taskcommon.DefaultString(req.Size, "1080p"),
MovementAmplitude: "auto",
Bgm: false,
}
- metadata := req.Metadata
- medaBytes, err := json.Marshal(metadata)
- if err != nil {
- return nil, errors.Wrap(err, "metadata marshal metadata failed")
- }
- err = json.Unmarshal(medaBytes, &r)
- if err != nil {
+ if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
return &r, nil
}
-func defaultString(value, defaultValue string) string {
- if value == "" {
- return defaultValue
- }
- return value
-}
-
-func defaultInt(value, defaultValue int) int {
- if value == 0 {
- return defaultValue
- }
- return value
-}
-
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
taskInfo := &relaycommon.TaskInfo{}
var taskResp taskResultResponse
- err := json.Unmarshal(respBody, &taskResp)
+ err := common.Unmarshal(respBody, &taskResp)
if err != nil {
return nil, errors.Wrap(err, "failed to unmarshal response body")
}
@@ -293,7 +274,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
var viduResp taskResultResponse
- if err := json.Unmarshal(originTask.Data, &viduResp); err != nil {
+ if err := common.Unmarshal(originTask.Data, &viduResp); err != nil {
return nil, errors.Wrap(err, "unmarshal vidu task data failed")
}
@@ -315,6 +296,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
}
}
- jsonData, _ := common.Marshal(openAIVideo)
- return jsonData, nil
+ return common.Marshal(openAIVideo)
}
diff --git a/relay/chat_completions_via_responses.go b/relay/chat_completions_via_responses.go
index 580cba5f4..8f69b9375 100644
--- a/relay/chat_completions_via_responses.go
+++ b/relay/chat_completions_via_responses.go
@@ -75,7 +75,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
- chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings)
+ chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -119,7 +119,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
- jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
+ jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
diff --git a/relay/claude_handler.go b/relay/claude_handler.go
index 2dfa09df5..1722cd9b2 100644
--- a/relay/claude_handler.go
+++ b/relay/claude_handler.go
@@ -146,7 +146,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
}
// remove disabled fields for Claude API
- jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
+ jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
diff --git a/relay/common/override_test.go b/relay/common/override_test.go
index a37eb78f9..0fc24467d 100644
--- a/relay/common/override_test.go
+++ b/relay/common/override_test.go
@@ -6,6 +6,9 @@ import (
"testing"
"github.com/QuantumNous/new-api/types"
+
+ "github.com/QuantumNous/new-api/dto"
+ "github.com/QuantumNous/new-api/setting/model_setting"
)
func TestApplyParamOverrideTrimPrefix(t *testing.T) {
@@ -1311,6 +1314,76 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
}
}
+func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) {
+ input := `{
+ "service_tier":"flex",
+ "safety_identifier":"user-123",
+ "store":true,
+ "stream_options":{"include_obfuscation":false}
+ }`
+ settings := dto.ChannelOtherSettings{}
+
+ out, err := RemoveDisabledFields([]byte(input), settings, true)
+ if err != nil {
+ t.Fatalf("RemoveDisabledFields returned error: %v", err)
+ }
+ assertJSONEqual(t, input, string(out))
+}
+
+func TestRemoveDisabledFieldsSkipWhenGlobalPassThroughEnabled(t *testing.T) {
+ original := model_setting.GetGlobalSettings().PassThroughRequestEnabled
+ model_setting.GetGlobalSettings().PassThroughRequestEnabled = true
+ t.Cleanup(func() {
+ model_setting.GetGlobalSettings().PassThroughRequestEnabled = original
+ })
+
+ input := `{
+ "service_tier":"flex",
+ "safety_identifier":"user-123",
+ "stream_options":{"include_obfuscation":false}
+ }`
+ settings := dto.ChannelOtherSettings{}
+
+ out, err := RemoveDisabledFields([]byte(input), settings, false)
+ if err != nil {
+ t.Fatalf("RemoveDisabledFields returned error: %v", err)
+ }
+ assertJSONEqual(t, input, string(out))
+}
+
+func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) {
+ input := `{
+ "service_tier":"flex",
+ "inference_geo":"eu",
+ "safety_identifier":"user-123",
+ "store":true,
+ "stream_options":{"include_obfuscation":false}
+ }`
+ settings := dto.ChannelOtherSettings{}
+
+ out, err := RemoveDisabledFields([]byte(input), settings, false)
+ if err != nil {
+ t.Fatalf("RemoveDisabledFields returned error: %v", err)
+ }
+ assertJSONEqual(t, `{"store":true}`, string(out))
+}
+
+func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
+ input := `{
+ "inference_geo":"eu",
+ "store":true
+ }`
+ settings := dto.ChannelOtherSettings{
+ AllowInferenceGeo: true,
+ }
+
+ out, err := RemoveDisabledFields([]byte(input), settings, false)
+ if err != nil {
+ t.Fatalf("RemoveDisabledFields returned error: %v", err)
+ }
+ assertJSONEqual(t, `{"inference_geo":"eu","store":true}`, string(out))
+}
+
func assertJSONEqual(t *testing.T, want, got string) {
t.Helper()
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index e5a0a06f5..8b0789c0d 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -119,8 +119,12 @@ type RelayInfo struct {
SendResponseCount int
ReceivedResponseCount int
FinalPreConsumedQuota int // 最终预消耗的配额
+ // ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路,
+ // 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行,
+ // 必须在提交前锁定全额。
+ ForcePreConsume bool
// Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。
- // 免费模型和按次计费(MJ/Task)时为 nil。
+ // 免费模型时为 nil。
Billing BillingSettler
// BillingSource indicates whether this request is billed from wallet quota or subscription.
// "" or "wallet" => wallet; "subscription" => subscription
@@ -153,7 +157,8 @@ type RelayInfo struct {
// RequestConversionChain records request format conversions in order, e.g.
// ["openai", "openai_responses"] or ["openai", "claude"].
RequestConversionChain []types.RelayFormat
- // 最终请求到上游的格式 TODO: 当前仅设置了Claude
+ // 最终请求到上游的格式。可由 adaptor 显式设置;
+ // 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。
FinalRequestRelayFormat types.RelayFormat
ThinkingContentInfo
@@ -552,8 +557,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req
return nil, errors.New("request is not a OpenAIResponsesCompactionRequest")
case types.RelayFormatTask:
info = genBaseRelayInfo(c, nil)
+ info.TaskRelayInfo = &TaskRelayInfo{}
case types.RelayFormatMjProxy:
info = genBaseRelayInfo(c, nil)
+ info.TaskRelayInfo = &TaskRelayInfo{}
default:
err = errors.New("invalid relay format")
}
@@ -600,6 +607,19 @@ func (info *RelayInfo) AppendRequestConversion(format types.RelayFormat) {
info.RequestConversionChain = append(info.RequestConversionChain, format)
}
+func (info *RelayInfo) GetFinalRequestRelayFormat() types.RelayFormat {
+ if info == nil {
+ return ""
+ }
+ if info.FinalRequestRelayFormat != "" {
+ return info.FinalRequestRelayFormat
+ }
+ if n := len(info.RequestConversionChain); n > 0 {
+ return info.RequestConversionChain[n-1]
+ }
+ return info.RelayFormat
+}
+
func GenRelayInfoResponsesCompaction(c *gin.Context, request *dto.OpenAIResponsesCompactionRequest) *RelayInfo {
info := genBaseRelayInfo(c, request)
if info.RelayMode == relayconstant.RelayModeUnknown {
@@ -635,8 +655,16 @@ func (info *RelayInfo) HasSendResponse() bool {
type TaskRelayInfo struct {
Action string
OriginTaskID string
+ // PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID,
+ // 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID)。
+ PublicTaskID string
ConsumeQuota bool
+
+ // LockedChannel holds the full channel object when the request is bound to
+ // a specific channel (e.g., remix on origin task's channel). Stored as any
+ // to avoid an import cycle with model; callers type-assert to *model.Channel.
+ LockedChannel any
}
type TaskSubmitReq struct {
@@ -694,11 +722,11 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
func (t *TaskSubmitReq) UnmarshalMetadata(v any) error {
metadata := t.Metadata
if metadata != nil {
- metadataBytes, err := json.Marshal(metadata)
+ metadataBytes, err := common.Marshal(metadata)
if err != nil {
return fmt.Errorf("marshal metadata failed: %w", err)
}
- err = json.Unmarshal(metadataBytes, v)
+ err = common.Unmarshal(metadataBytes, v)
if err != nil {
return fmt.Errorf("unmarshal metadata to target failed: %w", err)
}
@@ -727,9 +755,15 @@ func FailTaskInfo(reason string) *TaskInfo {
// RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
// service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持)
+// inference_geo: Claude 数据驻留推理区域字段(仅 Claude 支持,默认过滤)
// store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)
// safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私)
-func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) ([]byte, error) {
+// stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持)
+func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings, channelPassThroughEnabled bool) ([]byte, error) {
+ if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
+ return jsonData, nil
+ }
+
var data map[string]interface{}
if err := common.Unmarshal(jsonData, &data); err != nil {
common.SysError("RemoveDisabledFields Unmarshal error :" + err.Error())
@@ -743,6 +777,13 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
}
}
+ // 默认移除 inference_geo,除非明确允许(避免在未授权情况下透传数据驻留区域)
+ if !channelOtherSettings.AllowInferenceGeo {
+ if _, exists := data["inference_geo"]; exists {
+ delete(data, "inference_geo")
+ }
+ }
+
// 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用)
if channelOtherSettings.DisableStore {
if _, exists := data["store"]; exists {
@@ -757,6 +798,22 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
}
}
+ // 默认移除 stream_options.include_obfuscation,除非明确允许(避免关闭响应流混淆保护)
+ if !channelOtherSettings.AllowIncludeObfuscation {
+ if streamOptionsAny, exists := data["stream_options"]; exists {
+ if streamOptions, ok := streamOptionsAny.(map[string]interface{}); ok {
+ if _, includeExists := streamOptions["include_obfuscation"]; includeExists {
+ delete(streamOptions, "include_obfuscation")
+ }
+ if len(streamOptions) == 0 {
+ delete(data, "stream_options")
+ } else {
+ data["stream_options"] = streamOptions
+ }
+ }
+ }
+ }
+
jsonDataAfter, err := common.Marshal(data)
if err != nil {
common.SysError("RemoveDisabledFields Marshal error :" + err.Error())
diff --git a/relay/common/relay_info_test.go b/relay/common/relay_info_test.go
new file mode 100644
index 000000000..e53ec804c
--- /dev/null
+++ b/relay/common/relay_info_test.go
@@ -0,0 +1,40 @@
+package common
+
+import (
+ "testing"
+
+ "github.com/QuantumNous/new-api/types"
+ "github.com/stretchr/testify/require"
+)
+
+func TestRelayInfoGetFinalRequestRelayFormatPrefersExplicitFinal(t *testing.T) {
+ info := &RelayInfo{
+ RelayFormat: types.RelayFormatOpenAI,
+ RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude},
+ FinalRequestRelayFormat: types.RelayFormatOpenAIResponses,
+ }
+
+ require.Equal(t, types.RelayFormat(types.RelayFormatOpenAIResponses), info.GetFinalRequestRelayFormat())
+}
+
+func TestRelayInfoGetFinalRequestRelayFormatFallsBackToConversionChain(t *testing.T) {
+ info := &RelayInfo{
+ RelayFormat: types.RelayFormatOpenAI,
+ RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude},
+ }
+
+ require.Equal(t, types.RelayFormat(types.RelayFormatClaude), info.GetFinalRequestRelayFormat())
+}
+
+func TestRelayInfoGetFinalRequestRelayFormatFallsBackToRelayFormat(t *testing.T) {
+ info := &RelayInfo{
+ RelayFormat: types.RelayFormatGemini,
+ }
+
+ require.Equal(t, types.RelayFormat(types.RelayFormatGemini), info.GetFinalRequestRelayFormat())
+}
+
+func TestRelayInfoGetFinalRequestRelayFormatNilReceiver(t *testing.T) {
+ var info *RelayInfo
+ require.Equal(t, types.RelayFormat(""), info.GetFinalRequestRelayFormat())
+}
diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go
index b662f9053..3cbb18c22 100644
--- a/relay/common/relay_utils.go
+++ b/relay/common/relay_utils.go
@@ -173,16 +173,10 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
}
- info.PriceData.OtherRatios = map[string]float64{
- "seconds": float64(seconds),
- "size": 1,
- }
- if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
- info.PriceData.OtherRatios["size"] = 1.666667
- }
+ // OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置
}
- info.Action = action
+ storeTaskRequest(c, info, action, req)
return nil
}
diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go
index 7f4b99488..9a25237c7 100644
--- a/relay/compatible_handler.go
+++ b/relay/compatible_handler.go
@@ -165,7 +165,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
}
// remove disabled fields for OpenAI API
- jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
+ jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
@@ -232,7 +232,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
}
if originUsage != nil {
- service.ObserveChannelAffinityUsageCacheFromContext(ctx, usage)
+ service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
}
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
@@ -336,7 +336,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
var audioInputQuota decimal.Decimal
var audioInputPrice float64
- isClaudeUsageSemantic := relayInfo.FinalRequestRelayFormat == types.RelayFormatClaude
+ isClaudeUsageSemantic := relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude
if !relayInfo.PriceData.UsePrice {
baseTokens := dPromptTokens
// 减去 cached tokens
diff --git a/relay/helper/price.go b/relay/helper/price.go
index c310220fe..1cb04166f 100644
--- a/relay/helper/price.go
+++ b/relay/helper/price.go
@@ -140,7 +140,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
}
// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
-func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
+func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PriceData {
groupRatioInfo := HandleGroupRatio(c, info)
modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
@@ -154,7 +154,18 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.
}
}
quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
- priceData := types.PerCallPriceData{
+
+ // 免费模型检测(与 ModelPriceHelper 对齐)
+ freeModel := false
+ if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
+ if groupRatioInfo.GroupRatio == 0 || modelPrice == 0 {
+ quota = 0
+ freeModel = true
+ }
+ }
+
+ priceData := types.PriceData{
+ FreeModel: freeModel,
ModelPrice: modelPrice,
Quota: quota,
GroupRatioInfo: groupRatioInfo,
diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go
index 4f3ab2363..ae70f53c0 100644
--- a/relay/helper/stream_scanner.go
+++ b/relay/helper/stream_scanner.go
@@ -176,10 +176,32 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
})
}
+ dataChan := make(chan string, 10)
+
+ wg.Add(1)
+ gopool.Go(func() {
+ defer func() {
+ wg.Done()
+ if r := recover(); r != nil {
+ logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r))
+ }
+ common.SafeSendBool(stopChan, true)
+ }()
+ for data := range dataChan {
+ writeMutex.Lock()
+ success := dataHandler(data)
+ writeMutex.Unlock()
+ if !success {
+ return
+ }
+ }
+ })
+
// Scanner goroutine with improved error handling
wg.Add(1)
common.RelayCtxGo(ctx, func() {
defer func() {
+ close(dataChan)
wg.Done()
if r := recover(); r != nil {
logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
@@ -215,27 +237,16 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
continue
}
data = data[5:]
- data = strings.TrimLeft(data, " ")
- data = strings.TrimSuffix(data, "\r")
+ data = strings.TrimSpace(data)
+ if data == "" {
+ continue
+ }
if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime()
info.ReceivedResponseCount++
- // 使用超时机制防止写操作阻塞
- done := make(chan bool, 1)
- gopool.Go(func() {
- writeMutex.Lock()
- defer writeMutex.Unlock()
- done <- dataHandler(data)
- })
select {
- case success := <-done:
- if !success {
- return
- }
- case <-time.After(10 * time.Second):
- logger.LogError(c, "data handler timeout")
- return
+ case dataChan <- data:
case <-ctx.Done():
return
case <-stopChan:
diff --git a/relay/helper/stream_scanner_test.go b/relay/helper/stream_scanner_test.go
new file mode 100644
index 000000000..6890d82a5
--- /dev/null
+++ b/relay/helper/stream_scanner_test.go
@@ -0,0 +1,521 @@
+package helper
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/QuantumNous/new-api/constant"
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
+ "github.com/QuantumNous/new-api/setting/operation_setting"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func init() {
+ gin.SetMode(gin.TestMode)
+}
+
+func setupStreamTest(t *testing.T, body io.Reader) (*gin.Context, *http.Response, *relaycommon.RelayInfo) {
+ t.Helper()
+
+ oldTimeout := constant.StreamingTimeout
+ constant.StreamingTimeout = 30
+ t.Cleanup(func() {
+ constant.StreamingTimeout = oldTimeout
+ })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+ resp := &http.Response{
+ Body: io.NopCloser(body),
+ }
+
+ info := &relaycommon.RelayInfo{
+ ChannelMeta: &relaycommon.ChannelMeta{},
+ }
+
+ return c, resp, info
+}
+
+func buildSSEBody(n int) string {
+ var b strings.Builder
+ for i := 0; i < n; i++ {
+ fmt.Fprintf(&b, "data: {\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}\n", i, i)
+ }
+ b.WriteString("data: [DONE]\n")
+ return b.String()
+}
+
+// slowReader wraps a reader and injects a delay before each Read call,
+// simulating a slow upstream that trickles data.
+type slowReader struct {
+ r io.Reader
+ delay time.Duration
+}
+
+func (s *slowReader) Read(p []byte) (int, error) {
+ time.Sleep(s.delay)
+ return s.r.Read(p)
+}
+
+// ---------- Basic correctness ----------
+
+func TestStreamScannerHandler_NilInputs(t *testing.T) {
+ t.Parallel()
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
+
+ StreamScannerHandler(c, nil, info, func(data string) bool { return true })
+ StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
+}
+
+func TestStreamScannerHandler_EmptyBody(t *testing.T) {
+ t.Parallel()
+
+ c, resp, info := setupStreamTest(t, strings.NewReader(""))
+
+ var called atomic.Bool
+ StreamScannerHandler(c, resp, info, func(data string) bool {
+ called.Store(true)
+ return true
+ })
+
+ assert.False(t, called.Load(), "handler should not be called for empty body")
+}
+
+func TestStreamScannerHandler_1000Chunks(t *testing.T) {
+ t.Parallel()
+
+ const numChunks = 1000
+ body := buildSSEBody(numChunks)
+ c, resp, info := setupStreamTest(t, strings.NewReader(body))
+
+ var count atomic.Int64
+ StreamScannerHandler(c, resp, info, func(data string) bool {
+ count.Add(1)
+ return true
+ })
+
+ assert.Equal(t, int64(numChunks), count.Load())
+ assert.Equal(t, numChunks, info.ReceivedResponseCount)
+}
+
+func TestStreamScannerHandler_10000Chunks(t *testing.T) {
+ t.Parallel()
+
+ const numChunks = 10000
+ body := buildSSEBody(numChunks)
+ c, resp, info := setupStreamTest(t, strings.NewReader(body))
+
+ var count atomic.Int64
+ start := time.Now()
+
+ StreamScannerHandler(c, resp, info, func(data string) bool {
+ count.Add(1)
+ return true
+ })
+
+ elapsed := time.Since(start)
+ assert.Equal(t, int64(numChunks), count.Load())
+ assert.Equal(t, numChunks, info.ReceivedResponseCount)
+ t.Logf("10000 chunks processed in %v", elapsed)
+}
+
+func TestStreamScannerHandler_OrderPreserved(t *testing.T) {
+ t.Parallel()
+
+ const numChunks = 500
+ body := buildSSEBody(numChunks)
+ c, resp, info := setupStreamTest(t, strings.NewReader(body))
+
+ var mu sync.Mutex
+ received := make([]string, 0, numChunks)
+
+ StreamScannerHandler(c, resp, info, func(data string) bool {
+ mu.Lock()
+ received = append(received, data)
+ mu.Unlock()
+ return true
+ })
+
+ require.Equal(t, numChunks, len(received))
+ for i := 0; i < numChunks; i++ {
+ expected := fmt.Sprintf("{\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}", i, i)
+ assert.Equal(t, expected, received[i], "chunk %d out of order", i)
+ }
+}
+
+func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) {
+ t.Parallel()
+
+ body := buildSSEBody(50) + "data: should_not_appear\n"
+ c, resp, info := setupStreamTest(t, strings.NewReader(body))
+
+ var count atomic.Int64
+ StreamScannerHandler(c, resp, info, func(data string) bool {
+ count.Add(1)
+ return true
+ })
+
+ assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed")
+}
+
+func TestStreamScannerHandler_HandlerFailureStops(t *testing.T) {
+ t.Parallel()
+
+ const numChunks = 200
+ body := buildSSEBody(numChunks)
+ c, resp, info := setupStreamTest(t, strings.NewReader(body))
+
+ const failAt = 50
+ var count atomic.Int64
+ StreamScannerHandler(c, resp, info, func(data string) bool {
+ n := count.Add(1)
+ return n < failAt
+ })
+
+ // The worker stops at failAt; the scanner may have read ahead,
+ // but the handler should not be called beyond failAt.
+ assert.Equal(t, int64(failAt), count.Load())
+}
+
+func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
+ t.Parallel()
+
+ var b strings.Builder
+ b.WriteString(": comment line\n")
+ b.WriteString("event: message\n")
+ b.WriteString("id: 12345\n")
+ b.WriteString("retry: 5000\n")
+ for i := 0; i < 100; i++ {
+ fmt.Fprintf(&b, "data: payload_%d\n", i)
+ b.WriteString(": interleaved comment\n")
+ }
+ b.WriteString("data: [DONE]\n")
+
+ c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
+
+ var count atomic.Int64
+ StreamScannerHandler(c, resp, info, func(data string) bool {
+ count.Add(1)
+ return true
+ })
+
+ assert.Equal(t, int64(100), count.Load())
+}
+
+func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) {
+ t.Parallel()
+
+ body := "data: {\"trimmed\":true} \ndata: [DONE]\n"
+ c, resp, info := setupStreamTest(t, strings.NewReader(body))
+
+ var got string
+ StreamScannerHandler(c, resp, info, func(data string) bool {
+ got = data
+ return true
+ })
+
+ assert.Equal(t, "{\"trimmed\":true}", got)
+}
+
+// ---------- Decoupling: scanner not blocked by slow handler ----------
+
+func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
+ t.Parallel()
+
+ // Strategy: use a slow upstream (io.Pipe, 10ms per chunk) AND a slow handler (20ms per chunk).
+ // If the scanner were synchronously coupled to the handler, total time would be
+ // ~numChunks * (10ms + 20ms) = 30ms * 50 = 1500ms.
+ // With decoupling, total time should be closer to
+ // ~numChunks * max(10ms, 20ms) = 20ms * 50 = 1000ms
+ // because the scanner reads ahead into the buffer while the handler processes.
+ const numChunks = 50
+ const upstreamDelay = 10 * time.Millisecond
+ const handlerDelay = 20 * time.Millisecond
+
+ pr, pw := io.Pipe()
+ go func() {
+ defer pw.Close()
+ for i := 0; i < numChunks; i++ {
+ fmt.Fprintf(pw, "data: {\"id\":%d}\n", i)
+ time.Sleep(upstreamDelay)
+ }
+ fmt.Fprint(pw, "data: [DONE]\n")
+ }()
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+ oldTimeout := constant.StreamingTimeout
+ constant.StreamingTimeout = 30
+ t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
+
+ resp := &http.Response{Body: pr}
+ info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
+
+ var count atomic.Int64
+ start := time.Now()
+ done := make(chan struct{})
+ go func() {
+ StreamScannerHandler(c, resp, info, func(data string) bool {
+ time.Sleep(handlerDelay)
+ count.Add(1)
+ return true
+ })
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(15 * time.Second):
+ t.Fatal("StreamScannerHandler did not complete in time")
+ }
+
+ elapsed := time.Since(start)
+ assert.Equal(t, int64(numChunks), count.Load())
+
+ coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay)
+ t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime)
+
+ // If decoupled, elapsed should be well under the coupled estimate.
+ assert.Less(t, elapsed, coupledTime*85/100,
+ "decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime)
+}
+
+func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) {
+ t.Parallel()
+
+ const numChunks = 50
+ body := buildSSEBody(numChunks)
+ reader := &slowReader{r: strings.NewReader(body), delay: 2 * time.Millisecond}
+ c, resp, info := setupStreamTest(t, reader)
+
+ var count atomic.Int64
+ start := time.Now()
+
+ done := make(chan struct{})
+ go func() {
+ StreamScannerHandler(c, resp, info, func(data string) bool {
+ count.Add(1)
+ return true
+ })
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(15 * time.Second):
+ t.Fatal("timed out with slow upstream")
+ }
+
+ elapsed := time.Since(start)
+ assert.Equal(t, int64(numChunks), count.Load())
+ t.Logf("slow upstream (%d chunks, 2ms/read): %v", numChunks, elapsed)
+}
+
+// ---------- Ping tests ----------
+
+func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
+ t.Parallel()
+
+ setting := operation_setting.GetGeneralSetting()
+ oldEnabled := setting.PingIntervalEnabled
+ oldSeconds := setting.PingIntervalSeconds
+ setting.PingIntervalEnabled = true
+ setting.PingIntervalSeconds = 1
+ t.Cleanup(func() {
+ setting.PingIntervalEnabled = oldEnabled
+ setting.PingIntervalSeconds = oldSeconds
+ })
+
+ // Create a reader that delivers data slowly: one chunk every 500ms over 3.5 seconds.
+ // The ping interval is 1s, so we should see at least 2 pings.
+ pr, pw := io.Pipe()
+ go func() {
+ defer pw.Close()
+ for i := 0; i < 7; i++ {
+ fmt.Fprintf(pw, "data: chunk_%d\n", i)
+ time.Sleep(500 * time.Millisecond)
+ }
+ fmt.Fprint(pw, "data: [DONE]\n")
+ }()
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+ oldTimeout := constant.StreamingTimeout
+ constant.StreamingTimeout = 30
+ t.Cleanup(func() {
+ constant.StreamingTimeout = oldTimeout
+ })
+
+ resp := &http.Response{Body: pr}
+ info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
+
+ var count atomic.Int64
+ done := make(chan struct{})
+ go func() {
+ StreamScannerHandler(c, resp, info, func(data string) bool {
+ count.Add(1)
+ return true
+ })
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(15 * time.Second):
+ t.Fatal("timed out waiting for stream to finish")
+ }
+
+ assert.Equal(t, int64(7), count.Load())
+
+ body := recorder.Body.String()
+ pingCount := strings.Count(body, ": PING")
+ t.Logf("received %d pings in response body", pingCount)
+ assert.GreaterOrEqual(t, pingCount, 2,
+ "expected at least 2 pings during 3.5s stream with 1s interval; got %d", pingCount)
+}
+
+func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
+ t.Parallel()
+
+ setting := operation_setting.GetGeneralSetting()
+ oldEnabled := setting.PingIntervalEnabled
+ oldSeconds := setting.PingIntervalSeconds
+ setting.PingIntervalEnabled = true
+ setting.PingIntervalSeconds = 1
+ t.Cleanup(func() {
+ setting.PingIntervalEnabled = oldEnabled
+ setting.PingIntervalSeconds = oldSeconds
+ })
+
+ pr, pw := io.Pipe()
+ go func() {
+ defer pw.Close()
+ for i := 0; i < 5; i++ {
+ fmt.Fprintf(pw, "data: chunk_%d\n", i)
+ time.Sleep(500 * time.Millisecond)
+ }
+ fmt.Fprint(pw, "data: [DONE]\n")
+ }()
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+ oldTimeout := constant.StreamingTimeout
+ constant.StreamingTimeout = 30
+ t.Cleanup(func() {
+ constant.StreamingTimeout = oldTimeout
+ })
+
+ resp := &http.Response{Body: pr}
+ info := &relaycommon.RelayInfo{
+ DisablePing: true,
+ ChannelMeta: &relaycommon.ChannelMeta{},
+ }
+
+ var count atomic.Int64
+ done := make(chan struct{})
+ go func() {
+ StreamScannerHandler(c, resp, info, func(data string) bool {
+ count.Add(1)
+ return true
+ })
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(15 * time.Second):
+ t.Fatal("timed out")
+ }
+
+ assert.Equal(t, int64(5), count.Load())
+
+ body := recorder.Body.String()
+ pingCount := strings.Count(body, ": PING")
+ assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true")
+}
+
+func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
+ t.Parallel()
+
+ setting := operation_setting.GetGeneralSetting()
+ oldEnabled := setting.PingIntervalEnabled
+ oldSeconds := setting.PingIntervalSeconds
+ setting.PingIntervalEnabled = true
+ setting.PingIntervalSeconds = 1
+ t.Cleanup(func() {
+ setting.PingIntervalEnabled = oldEnabled
+ setting.PingIntervalSeconds = oldSeconds
+ })
+
+ // Slow upstream + slow handler. Total stream takes ~5 seconds.
+ // The ping goroutine stays alive as long as the scanner is reading,
+ // so pings should fire between data writes.
+ pr, pw := io.Pipe()
+ go func() {
+ defer pw.Close()
+ for i := 0; i < 10; i++ {
+ fmt.Fprintf(pw, "data: chunk_%d\n", i)
+ time.Sleep(500 * time.Millisecond)
+ }
+ fmt.Fprint(pw, "data: [DONE]\n")
+ }()
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+ oldTimeout := constant.StreamingTimeout
+ constant.StreamingTimeout = 30
+ t.Cleanup(func() {
+ constant.StreamingTimeout = oldTimeout
+ })
+
+ resp := &http.Response{Body: pr}
+ info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
+
+ var count atomic.Int64
+ done := make(chan struct{})
+ go func() {
+ StreamScannerHandler(c, resp, info, func(data string) bool {
+ count.Add(1)
+ return true
+ })
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(15 * time.Second):
+ t.Fatal("timed out")
+ }
+
+ assert.Equal(t, int64(10), count.Load())
+
+ body := recorder.Body.String()
+ pingCount := strings.Count(body, ": PING")
+ t.Logf("received %d pings interleaved with 10 chunks over 5s", pingCount)
+ assert.GreaterOrEqual(t, pingCount, 3,
+ "expected at least 3 pings during 5s stream with 1s ping interval; got %d", pingCount)
+}
diff --git a/relay/mjproxy_handler.go b/relay/mjproxy_handler.go
index 8916ab181..8e7c61e9c 100644
--- a/relay/mjproxy_handler.go
+++ b/relay/mjproxy_handler.go
@@ -184,7 +184,7 @@ func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyR
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
}
- modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
+ modelName := service.CovertMjpActionToModelName(constant.MjActionSwapFace)
priceData := helper.ModelPriceHelperPerCall(c, info)
@@ -485,7 +485,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dt
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
- modelName := service.CoverActionToModelName(midjRequest.Action)
+ modelName := service.CovertMjpActionToModelName(midjRequest.Action)
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
diff --git a/relay/relay_task.go b/relay/relay_task.go
index ebbd1f65d..c740facdb 100644
--- a/relay/relay_task.go
+++ b/relay/relay_task.go
@@ -2,7 +2,6 @@ package relay
import (
"bytes"
- "encoding/json"
"errors"
"fmt"
"io"
@@ -15,29 +14,33 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
+ "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
+ "github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
- "github.com/QuantumNous/new-api/setting/ratio_setting"
-
"github.com/gin-gonic/gin"
)
-/*
-Task 任务通过平台、Action 区分任务
-*/
-func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
- info.InitChannelMeta(c)
- // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
- if info.TaskRelayInfo == nil {
- info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
- }
+type TaskSubmitResult struct {
+ UpstreamTaskID string
+ TaskData []byte
+ Platform constant.TaskPlatform
+ Quota int
+ //PerCallPrice types.PriceData
+}
+
+// ResolveOriginTask 处理基于已有任务的提交(remix / continuation):
+// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道
+// (通过 info.LockedChannel,重试时复用同一渠道并轮换 key),
+// 以及提取 OtherRatios(时长、分辨率)。
+// 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。
+func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
+ // 检测 remix action
path := c.Request.URL.Path
if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
info.Action = constant.TaskActionRemix
}
-
- // 提取 remix 任务的 video_id
if info.Action == constant.TaskActionRemix {
videoID := c.Param("video_id")
if strings.TrimSpace(videoID) == "" {
@@ -46,64 +49,71 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
info.OriginTaskID = videoID
}
- platform := constant.TaskPlatform(c.GetString("platform"))
+ if info.OriginTaskID == "" {
+ return nil
+ }
- // 获取原始任务信息
- if info.OriginTaskID != "" {
- originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
- return
- }
- if !exist {
- taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
- return
- }
- if info.OriginModelName == "" {
- if originTask.Properties.OriginModelName != "" {
- info.OriginModelName = originTask.Properties.OriginModelName
- } else if originTask.Properties.UpstreamModelName != "" {
- info.OriginModelName = originTask.Properties.UpstreamModelName
- } else {
- var taskData map[string]interface{}
- _ = json.Unmarshal(originTask.Data, &taskData)
- if m, ok := taskData["model"].(string); ok && m != "" {
- info.OriginModelName = m
- platform = originTask.Platform
- }
- }
- }
- if originTask.ChannelId != info.ChannelId {
- channel, err := model.GetChannelById(originTask.ChannelId, true)
- if err != nil {
- taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
- return
- }
- if channel.Status != common.ChannelStatusEnabled {
- taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
- return
- }
- key, _, newAPIError := channel.GetNextEnabledKey()
- if newAPIError != nil {
- taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
- return
- }
- common.SetContextKey(c, constant.ContextKeyChannelKey, key)
- common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
- common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
- common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
+ // 查找原始任务
+ originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
+ if err != nil {
+ return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
+ }
+ if !exist {
+ return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
+ }
- info.ChannelBaseUrl = channel.GetBaseURL()
- info.ChannelId = originTask.ChannelId
- info.ChannelType = channel.Type
- info.ApiKey = key
- platform = originTask.Platform
- }
-
- // 使用原始任务的参数
- if info.Action == constant.TaskActionRemix {
+ // 从原始任务推导模型名称
+ if info.OriginModelName == "" {
+ if originTask.Properties.OriginModelName != "" {
+ info.OriginModelName = originTask.Properties.OriginModelName
+ } else if originTask.Properties.UpstreamModelName != "" {
+ info.OriginModelName = originTask.Properties.UpstreamModelName
+ } else {
var taskData map[string]interface{}
- _ = json.Unmarshal(originTask.Data, &taskData)
+ _ = common.Unmarshal(originTask.Data, &taskData)
+ if m, ok := taskData["model"].(string); ok && m != "" {
+ info.OriginModelName = m
+ }
+ }
+ }
+
+ // 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key)
+ ch, err := model.GetChannelById(originTask.ChannelId, true)
+ if err != nil {
+ return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
+ }
+ if ch.Status != common.ChannelStatusEnabled {
+ return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
+ }
+ info.LockedChannel = ch
+
+ if originTask.ChannelId != info.ChannelId {
+ key, _, newAPIError := ch.GetNextEnabledKey()
+ if newAPIError != nil {
+ return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
+ }
+ common.SetContextKey(c, constant.ContextKeyChannelKey, key)
+ common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type)
+ common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL())
+ common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
+
+ info.ChannelBaseUrl = ch.GetBaseURL()
+ info.ChannelId = originTask.ChannelId
+ info.ChannelType = ch.Type
+ info.ApiKey = key
+ }
+
+ // 提取 remix 参数(时长、分辨率 → OtherRatios)
+ if info.Action == constant.TaskActionRemix {
+ if originTask.PrivateData.BillingContext != nil {
+ // 新的 remix 逻辑:直接从原始任务的 BillingContext 中提取 OtherRatios(如果存在)
+ for s, f := range originTask.PrivateData.BillingContext.OtherRatios {
+ info.PriceData.AddOtherRatio(s, f)
+ }
+ } else {
+ // 旧的 remix 逻辑:直接从 task data 解析 seconds 和 size(如果存在)
+ var taskData map[string]interface{}
+ _ = common.Unmarshal(originTask.Data, &taskData)
secondsStr, _ := taskData["seconds"].(string)
seconds, _ := strconv.Atoi(secondsStr)
if seconds <= 0 {
@@ -120,167 +130,146 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
}
}
}
+
+ return nil
+}
+
+// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
+// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 →
+// 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→
+// 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。
+// 控制器负责 defer Refund 和成功后 Settle。
+func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
+ info.InitChannelMeta(c)
+
+ // 1. 确定 platform → 创建适配器 → 验证请求
+ platform := constant.TaskPlatform(c.GetString("platform"))
if platform == "" {
platform = GetTaskPlatform(c)
}
-
- info.InitChannelMeta(c)
adaptor := GetTaskAdaptor(platform)
if adaptor == nil {
- return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
+ return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
}
adaptor.Init(info)
- // get & validate taskRequest 获取并验证文本请求
- taskErr = adaptor.ValidateRequestAndSetAction(c, info)
- if taskErr != nil {
- return
+ if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil {
+ return nil, taskErr
}
+ // 2. 确定模型名称
modelName := info.OriginModelName
if modelName == "" {
modelName = service.CoverTaskActionToModelName(platform, info.Action)
}
- modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
- if !success {
- defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]
- if !ok {
- modelPrice = float64(common.PreConsumedQuota) / common.QuotaPerUnit
- } else {
- modelPrice = defaultPrice
+
+ // 2.5 应用渠道的模型映射(与同步任务对齐)
+ info.OriginModelName = modelName
+ info.UpstreamModelName = modelName
+ if err := helper.ModelMappedHelper(c, info, nil); err != nil {
+ return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest)
+ }
+
+ // 3. 预生成公开 task ID(仅首次)
+ if info.PublicTaskID == "" {
+ info.PublicTaskID = model.GenerateTaskID()
+ }
+
+ // 4. 价格计算:基础模型价格
+ info.OriginModelName = modelName
+ info.PriceData = helper.ModelPriceHelperPerCall(c, info)
+
+ // 5. 计费估算:让适配器根据用户请求提供 OtherRatios(时长、分辨率等)
+ // 必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData)。
+ // ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios,此处合并。
+ if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 {
+ for k, v := range estimatedRatios {
+ info.PriceData.AddOtherRatio(k, v)
}
}
- // 处理 auto 分组:从 context 获取实际选中的分组
- // 当使用 auto 分组时,Distribute 中间件会将实际选中的分组存储在 ContextKeyAutoGroup 中
- if autoGroup, exists := common.GetContextKey(c, constant.ContextKeyAutoGroup); exists {
- if groupStr, ok := autoGroup.(string); ok && groupStr != "" {
- info.UsingGroup = groupStr
- }
- }
-
- // 预扣
- groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
- var ratio float64
- userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
- if hasUserGroupRatio {
- ratio = modelPrice * userGroupRatio
- } else {
- ratio = modelPrice * groupRatio
- }
- // FIXME: 临时修补,支持任务仅按次计费
+ // 6. 将 OtherRatios 应用到基础额度
if !common.StringsContains(constant.TaskPricePatches, modelName) {
- if len(info.PriceData.OtherRatios) > 0 {
- for _, ra := range info.PriceData.OtherRatios {
- if 1.0 != ra {
- ratio *= ra
- }
+ for _, ra := range info.PriceData.OtherRatios {
+ if ra != 1.0 {
+ info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra)
}
}
}
- println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio))
- userQuota, err := model.GetUserQuota(info.UserId, false)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
- return
- }
- quota := int(ratio * common.QuotaPerUnit)
- if userQuota-quota < 0 {
- taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
- return
+
+ // 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
+ if info.Billing == nil && !info.PriceData.FreeModel {
+ info.ForcePreConsume = true
+ if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
+ return nil, service.TaskErrorFromAPIError(apiErr)
+ }
}
- // build body
+ // 8. 构建请求体
requestBody, err := adaptor.BuildRequestBody(c, info)
if err != nil {
- taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
- return
+ return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
}
- // do request
+
+ // 9. 发送请求
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
- taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- return
+ return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
- // handle response
if resp != nil && resp.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(resp.Body)
- taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
- return
+ return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
}
- defer func() {
- // release quota
- if info.ConsumeQuota && taskErr == nil {
+ // 10. 返回 OtherRatios 给下游(header 必须在 DoResponse 写 body 之前设置)
+ otherRatios := info.PriceData.OtherRatios
+ if otherRatios == nil {
+ otherRatios = map[string]float64{}
+ }
+ ratiosJSON, _ := common.Marshal(otherRatios)
+ c.Header("X-New-Api-Other-Ratios", string(ratiosJSON))
- err := service.PostConsumeQuota(info, quota, 0, true)
- if err != nil {
- common.SysLog("error consuming token remain quota: " + err.Error())
- }
- if quota != 0 {
- tokenName := c.GetString("token_name")
- //gRatio := groupRatio
- //if hasUserGroupRatio {
- // gRatio = userGroupRatio
- //}
- logContent := fmt.Sprintf("操作 %s", info.Action)
- // FIXME: 临时修补,支持任务仅按次计费
- if common.StringsContains(constant.TaskPricePatches, modelName) {
- logContent = fmt.Sprintf("%s,按次计费", logContent)
- } else {
- if len(info.PriceData.OtherRatios) > 0 {
- var contents []string
- for key, ra := range info.PriceData.OtherRatios {
- if 1.0 != ra {
- contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
- }
- }
- if len(contents) > 0 {
- logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
- }
- }
- }
- other := make(map[string]interface{})
- if c != nil && c.Request != nil && c.Request.URL != nil {
- other["request_path"] = c.Request.URL.Path
- }
- other["model_price"] = modelPrice
- other["group_ratio"] = groupRatio
- if hasUserGroupRatio {
- other["user_group_ratio"] = userGroupRatio
- }
- model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
- ChannelId: info.ChannelId,
- ModelName: modelName,
- TokenName: tokenName,
- Quota: quota,
- Content: logContent,
- TokenId: info.TokenId,
- Group: info.UsingGroup,
- Other: other,
- })
- model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
- model.UpdateChannelUsedQuota(info.ChannelId, quota)
- }
- }
- }()
-
- taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
+ // 11. 解析响应
+ upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
if taskErr != nil {
- return
+ return nil, taskErr
}
- info.ConsumeQuota = true
- // insert task
- task := model.InitTask(platform, info)
- task.TaskID = taskID
- task.Quota = quota
- task.Data = taskData
- task.Action = info.Action
- err = task.Insert()
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
- return
+
+	// 12. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios
+ finalQuota := info.PriceData.Quota
+ if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 {
+ // 基于调整后的 ratios 重新计算 quota
+ finalQuota = recalcQuotaFromRatios(info, adjustedRatios)
+ info.PriceData.OtherRatios = adjustedRatios
+ info.PriceData.Quota = finalQuota
}
- return nil
+
+ return &TaskSubmitResult{
+ UpstreamTaskID: upstreamTaskID,
+ TaskData: taskData,
+ Platform: platform,
+ Quota: finalQuota,
+ }, nil
+}
+
+// recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。
+// 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。
+func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int {
+	// 从 PriceData 取当前额度(此时已应用 OtherRatios),随后先除回不含 ratios 的基础额度
+ baseQuota := info.PriceData.Quota
+ // 先除掉原有的 OtherRatios 恢复基础额度
+ for _, ra := range info.PriceData.OtherRatios {
+ if ra != 1.0 && ra > 0 {
+ baseQuota = int(float64(baseQuota) / ra)
+ }
+ }
+ // 应用新的 ratios
+ result := float64(baseQuota)
+ for _, ra := range ratios {
+ if ra != 1.0 {
+ result *= ra
+ }
+ }
+ return int(result)
}
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
@@ -336,7 +325,7 @@ func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.Ta
} else {
tasks = make([]any, 0)
}
- respBody, err = json.Marshal(dto.TaskResponse[[]any]{
+ respBody, err = common.Marshal(dto.TaskResponse[[]any]{
Code: "success",
Data: tasks,
})
@@ -357,7 +346,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
return
}
- respBody, err = json.Marshal(dto.TaskResponse[any]{
+ respBody, err = common.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: TaskModel2Dto(originTask),
})
@@ -381,97 +370,16 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
return
}
- func() {
- channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
- if err2 != nil {
- return
- }
- if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
- return
- }
- baseURL := constant.ChannelBaseURLs[channelModel.Type]
- if channelModel.GetBaseURL() != "" {
- baseURL = channelModel.GetBaseURL()
- }
- proxy := channelModel.GetSetting().Proxy
- adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
- if adaptor == nil {
- return
- }
- resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
- "task_id": originTask.TaskID,
- "action": originTask.Action,
- }, proxy)
- if err2 != nil || resp == nil {
- return
- }
- defer resp.Body.Close()
- body, err2 := io.ReadAll(resp.Body)
- if err2 != nil {
- return
- }
- ti, err2 := adaptor.ParseTaskResult(body)
- if err2 == nil && ti != nil {
- if ti.Status != "" {
- originTask.Status = model.TaskStatus(ti.Status)
- }
- if ti.Progress != "" {
- originTask.Progress = ti.Progress
- }
- if ti.Url != "" {
- if strings.HasPrefix(ti.Url, "data:") {
- } else {
- originTask.FailReason = ti.Url
- }
- }
- _ = originTask.Update()
- var raw map[string]any
- _ = json.Unmarshal(body, &raw)
- format := "mp4"
- if respObj, ok := raw["response"].(map[string]any); ok {
- if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
- if v0, ok := vids[0].(map[string]any); ok {
- if mt, ok := v0["mimeType"].(string); ok && mt != "" {
- if strings.Contains(mt, "mp4") {
- format = "mp4"
- } else {
- format = mt
- }
- }
- }
- }
- }
- status := "processing"
- switch originTask.Status {
- case model.TaskStatusSuccess:
- status = "succeeded"
- case model.TaskStatusFailure:
- status = "failed"
- case model.TaskStatusQueued, model.TaskStatusSubmitted:
- status = "queued"
- }
- if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
- out := map[string]any{
- "error": nil,
- "format": format,
- "metadata": nil,
- "status": status,
- "task_id": originTask.TaskID,
- "url": originTask.FailReason,
- }
- respBody, _ = json.Marshal(dto.TaskResponse[any]{
- Code: "success",
- Data: out,
- })
- }
- }
- }()
+ isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/")
- if len(respBody) != 0 {
+ // Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态
+ if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 {
+ respBody = realtimeResp
return
}
- if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
+ // OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo
+ if isOpenAIVideoAPI {
adaptor := GetTaskAdaptor(originTask.Platform)
if adaptor == nil {
taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
@@ -486,10 +394,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
respBody = openAIVideoData
return
}
- taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
+ taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented)
return
}
- respBody, err = json.Marshal(dto.TaskResponse[any]{
+
+ // 通用 TaskDto 格式
+ respBody, err = common.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: TaskModel2Dto(originTask),
})
@@ -499,16 +409,150 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
return
}
+// tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。
+// 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。
+// 当非 OpenAI Video API 时,还会构建自定义格式的响应体。
+func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte {
+ channelModel, err := model.GetChannelById(task.ChannelId, true)
+ if err != nil {
+ return nil
+ }
+ if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
+ return nil
+ }
+
+ baseURL := constant.ChannelBaseURLs[channelModel.Type]
+ if channelModel.GetBaseURL() != "" {
+ baseURL = channelModel.GetBaseURL()
+ }
+ proxy := channelModel.GetSetting().Proxy
+ adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
+ if adaptor == nil {
+ return nil
+ }
+
+ resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
+ "task_id": task.GetUpstreamTaskID(),
+ "action": task.Action,
+ }, proxy)
+ if err != nil || resp == nil {
+ return nil
+ }
+ defer resp.Body.Close()
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil
+ }
+
+ ti, err := adaptor.ParseTaskResult(body)
+ if err != nil || ti == nil {
+ return nil
+ }
+
+ snap := task.Snapshot()
+
+ // 将上游最新状态更新到 task
+ if ti.Status != "" {
+ task.Status = model.TaskStatus(ti.Status)
+ }
+ if ti.Progress != "" {
+ task.Progress = ti.Progress
+ }
+ if strings.HasPrefix(ti.Url, "data:") {
+ // data: URI — kept in Data, not ResultURL
+ } else if ti.Url != "" {
+ task.PrivateData.ResultURL = ti.Url
+ } else if task.Status == model.TaskStatusSuccess {
+ // No URL from adaptor — construct proxy URL using public task ID
+ task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
+ }
+
+ if !snap.Equal(task.Snapshot()) {
+ _, _ = task.UpdateWithStatus(snap.Status)
+ }
+
+ // OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理
+ if isOpenAIVideoAPI {
+ return nil
+ }
+
+ // 非 OpenAI Video API: 构建自定义格式响应
+ format := detectVideoFormat(body)
+ out := map[string]any{
+ "error": nil,
+ "format": format,
+ "metadata": nil,
+ "status": mapTaskStatusToSimple(task.Status),
+ "task_id": task.TaskID,
+ "url": task.GetResultURL(),
+ }
+ respBody, _ := common.Marshal(dto.TaskResponse[any]{
+ Code: "success",
+ Data: out,
+ })
+ return respBody
+}
+
+// detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式
+func detectVideoFormat(rawBody []byte) string {
+ var raw map[string]any
+ if err := common.Unmarshal(rawBody, &raw); err != nil {
+ return "mp4"
+ }
+ respObj, ok := raw["response"].(map[string]any)
+ if !ok {
+ return "mp4"
+ }
+ vids, ok := respObj["videos"].([]any)
+ if !ok || len(vids) == 0 {
+ return "mp4"
+ }
+ v0, ok := vids[0].(map[string]any)
+ if !ok {
+ return "mp4"
+ }
+ mt, ok := v0["mimeType"].(string)
+ if !ok || mt == "" || strings.Contains(mt, "mp4") {
+ return "mp4"
+ }
+ return mt
+}
+
+// mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串
+func mapTaskStatusToSimple(status model.TaskStatus) string {
+ switch status {
+ case model.TaskStatusSuccess:
+ return "succeeded"
+ case model.TaskStatusFailure:
+ return "failed"
+ case model.TaskStatusQueued, model.TaskStatusSubmitted:
+ return "queued"
+ default:
+ return "processing"
+ }
+}
+
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
return &dto.TaskDto{
+ ID: task.ID,
+ CreatedAt: task.CreatedAt,
+ UpdatedAt: task.UpdatedAt,
TaskID: task.TaskID,
+ Platform: string(task.Platform),
+ UserId: task.UserId,
+ Group: task.Group,
+ ChannelId: task.ChannelId,
+ Quota: task.Quota,
Action: task.Action,
Status: string(task.Status),
FailReason: task.FailReason,
+ ResultURL: task.GetResultURL(),
SubmitTime: task.SubmitTime,
StartTime: task.StartTime,
FinishTime: task.FinishTime,
Progress: task.Progress,
+ Properties: task.Properties,
+ Username: task.Username,
Data: task.Data,
}
}
diff --git a/relay/responses_handler.go b/relay/responses_handler.go
index 3bcaa673f..18f1b7118 100644
--- a/relay/responses_handler.go
+++ b/relay/responses_handler.go
@@ -89,7 +89,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
}
// remove disabled fields for OpenAI Responses API
- jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
+ jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
diff --git a/router/api-router.go b/router/api-router.go
index e2ef2f531..d48934000 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -13,6 +13,7 @@ import (
func SetApiRouter(router *gin.Engine) {
apiRouter := router.Group("/api")
+ apiRouter.Use(middleware.RouteTag("api"))
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
apiRouter.Use(middleware.BodyStorageCleanup()) // 清理请求体存储
apiRouter.Use(middleware.GlobalAPIRateLimit())
@@ -114,6 +115,9 @@ func SetApiRouter(router *gin.Engine) {
adminRoute.GET("/topup", controller.GetAllTopUps)
adminRoute.POST("/topup/complete", controller.AdminCompleteTopUp)
adminRoute.GET("/search", controller.SearchUsers)
+ adminRoute.GET("/:id/oauth/bindings", controller.GetUserOAuthBindingsByAdmin)
+ adminRoute.DELETE("/:id/oauth/bindings/:provider_id", controller.UnbindCustomOAuthByAdmin)
+ adminRoute.DELETE("/:id/bindings/:binding_type", controller.AdminClearUserBinding)
adminRoute.GET("/:id", controller.GetUser)
adminRoute.POST("/", controller.CreateUser)
adminRoute.POST("/manage", controller.ManageUser)
@@ -170,10 +174,11 @@ func SetApiRouter(router *gin.Engine) {
optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
}
- // Custom OAuth provider management (admin only)
+ // Custom OAuth provider management (root only)
customOAuthRoute := apiRouter.Group("/custom-oauth-provider")
customOAuthRoute.Use(middleware.RootAuth())
{
+ customOAuthRoute.POST("/discovery", controller.FetchCustomOAuthDiscovery)
customOAuthRoute.GET("/", controller.GetCustomOAuthProviders)
customOAuthRoute.GET("/:id", controller.GetCustomOAuthProvider)
customOAuthRoute.POST("/", controller.CreateCustomOAuthProvider)
diff --git a/router/dashboard.go b/router/dashboard.go
index 17132dfb2..2e486156d 100644
--- a/router/dashboard.go
+++ b/router/dashboard.go
@@ -9,6 +9,7 @@ import (
func SetDashboardRouter(router *gin.Engine) {
apiRouter := router.Group("/")
+ apiRouter.Use(middleware.RouteTag("old_api"))
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
apiRouter.Use(middleware.GlobalAPIRateLimit())
apiRouter.Use(middleware.CORS())
diff --git a/router/main.go b/router/main.go
index 45b3080f2..ac9506fe4 100644
--- a/router/main.go
+++ b/router/main.go
@@ -8,6 +8,7 @@ import (
"strings"
"github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/middleware"
"github.com/gin-gonic/gin"
)
@@ -27,6 +28,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
} else {
frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/")
router.NoRoute(func(c *gin.Context) {
+ c.Set(middleware.RouteTagKey, "web")
c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI))
})
}
diff --git a/router/relay-router.go b/router/relay-router.go
index 04584945b..3d38be5ee 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -17,6 +17,7 @@ func SetRelayRouter(router *gin.Engine) {
router.Use(middleware.StatsMiddleware())
// https://platform.openai.com/docs/api-reference/introduction
modelsRouter := router.Group("/v1/models")
+ modelsRouter.Use(middleware.RouteTag("relay"))
modelsRouter.Use(middleware.TokenAuth())
{
modelsRouter.GET("", func(c *gin.Context) {
@@ -41,6 +42,7 @@ func SetRelayRouter(router *gin.Engine) {
}
geminiRouter := router.Group("/v1beta/models")
+ geminiRouter.Use(middleware.RouteTag("relay"))
geminiRouter.Use(middleware.TokenAuth())
{
geminiRouter.GET("", func(c *gin.Context) {
@@ -49,6 +51,7 @@ func SetRelayRouter(router *gin.Engine) {
}
geminiCompatibleRouter := router.Group("/v1beta/openai/models")
+ geminiCompatibleRouter.Use(middleware.RouteTag("relay"))
geminiCompatibleRouter.Use(middleware.TokenAuth())
{
geminiCompatibleRouter.GET("", func(c *gin.Context) {
@@ -57,12 +60,14 @@ func SetRelayRouter(router *gin.Engine) {
}
playgroundRouter := router.Group("/pg")
+ playgroundRouter.Use(middleware.RouteTag("relay"))
playgroundRouter.Use(middleware.SystemPerformanceCheck())
playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
{
playgroundRouter.POST("/chat/completions", controller.Playground)
}
relayV1Router := router.Group("/v1")
+ relayV1Router.Use(middleware.RouteTag("relay"))
relayV1Router.Use(middleware.SystemPerformanceCheck())
relayV1Router.Use(middleware.TokenAuth())
relayV1Router.Use(middleware.ModelRequestRateLimit())
@@ -161,24 +166,28 @@ func SetRelayRouter(router *gin.Engine) {
}
relayMjRouter := router.Group("/mj")
+ relayMjRouter.Use(middleware.RouteTag("relay"))
relayMjRouter.Use(middleware.SystemPerformanceCheck())
registerMjRouterGroup(relayMjRouter)
relayMjModeRouter := router.Group("/:mode/mj")
+ relayMjModeRouter.Use(middleware.RouteTag("relay"))
relayMjModeRouter.Use(middleware.SystemPerformanceCheck())
registerMjRouterGroup(relayMjModeRouter)
//relayMjRouter.Use()
relaySunoRouter := router.Group("/suno")
+ relaySunoRouter.Use(middleware.RouteTag("relay"))
relaySunoRouter.Use(middleware.SystemPerformanceCheck())
relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute())
{
relaySunoRouter.POST("/submit/:action", controller.RelayTask)
- relaySunoRouter.POST("/fetch", controller.RelayTask)
- relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
+ relaySunoRouter.POST("/fetch", controller.RelayTaskFetch)
+ relaySunoRouter.GET("/fetch/:id", controller.RelayTaskFetch)
}
relayGeminiRouter := router.Group("/v1beta")
+ relayGeminiRouter.Use(middleware.RouteTag("relay"))
relayGeminiRouter.Use(middleware.SystemPerformanceCheck())
relayGeminiRouter.Use(middleware.TokenAuth())
relayGeminiRouter.Use(middleware.ModelRequestRateLimit())
diff --git a/router/video-router.go b/router/video-router.go
index d5fed1d78..461451104 100644
--- a/router/video-router.go
+++ b/router/video-router.go
@@ -8,32 +8,42 @@ import (
)
func SetVideoRouter(router *gin.Engine) {
+ // Video proxy: accepts either session auth (dashboard) or token auth (API clients)
+ videoProxyRouter := router.Group("/v1")
+ videoProxyRouter.Use(middleware.RouteTag("relay"))
+ videoProxyRouter.Use(middleware.TokenOrUserAuth())
+ {
+ videoProxyRouter.GET("/videos/:task_id/content", controller.VideoProxy)
+ }
+
videoV1Router := router.Group("/v1")
+ videoV1Router.Use(middleware.RouteTag("relay"))
videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
{
- videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy)
videoV1Router.POST("/video/generations", controller.RelayTask)
- videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
+ videoV1Router.GET("/video/generations/:task_id", controller.RelayTaskFetch)
videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask)
}
// openai compatible API video routes
// docs: https://platform.openai.com/docs/api-reference/videos/create
{
videoV1Router.POST("/videos", controller.RelayTask)
- videoV1Router.GET("/videos/:task_id", controller.RelayTask)
+ videoV1Router.GET("/videos/:task_id", controller.RelayTaskFetch)
}
klingV1Router := router.Group("/kling/v1")
+ klingV1Router.Use(middleware.RouteTag("relay"))
klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
{
klingV1Router.POST("/videos/text2video", controller.RelayTask)
klingV1Router.POST("/videos/image2video", controller.RelayTask)
- klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTask)
- klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTask)
+ klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTaskFetch)
+ klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTaskFetch)
}
// Jimeng official API routes - direct mapping to official API format
jimengOfficialGroup := router.Group("jimeng")
+ jimengOfficialGroup.Use(middleware.RouteTag("relay"))
jimengOfficialGroup.Use(middleware.JimengRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
{
// Maps to: /?Action=CVSync2AsyncSubmitTask&Version=2022-08-31 and /?Action=CVSync2AsyncGetResult&Version=2022-08-31
diff --git a/router/web-router.go b/router/web-router.go
index b053a3e63..17a8378dd 100644
--- a/router/web-router.go
+++ b/router/web-router.go
@@ -19,6 +19,7 @@ func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
router.Use(middleware.Cache())
router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/dist")))
router.NoRoute(func(c *gin.Context) {
+ c.Set(middleware.RouteTagKey, "web")
if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") || strings.HasPrefix(c.Request.RequestURI, "/assets") {
controller.RelayNotFound(c)
return
diff --git a/service/billing_session.go b/service/billing_session.go
index 1a31316b5..f24b68e55 100644
--- a/service/billing_session.go
+++ b/service/billing_session.go
@@ -193,6 +193,11 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro
// shouldTrust 统一信任额度检查,适用于钱包和订阅。
func (s *BillingSession) shouldTrust(c *gin.Context) bool {
+ // 异步任务(ForcePreConsume=true)必须预扣全额,不允许信任旁路
+ if s.relayInfo.ForcePreConsume {
+ return false
+ }
+
trustQuota := common.GetTrustQuota()
if trustQuota <= 0 {
return false
diff --git a/service/channel_affinity.go b/service/channel_affinity.go
index fe1524c59..524c6574a 100644
--- a/service/channel_affinity.go
+++ b/service/channel_affinity.go
@@ -13,6 +13,7 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/pkg/cachex"
"github.com/QuantumNous/new-api/setting/operation_setting"
+ "github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/samber/hot"
"github.com/tidwall/gjson"
@@ -61,6 +62,12 @@ type ChannelAffinityStatsContext struct {
TTLSeconds int64
}
+const (
+ cacheTokenRateModeCachedOverPrompt = "cached_over_prompt"
+ cacheTokenRateModeCachedOverPromptPlusCached = "cached_over_prompt_plus_cached"
+ cacheTokenRateModeMixed = "mixed"
+)
+
type ChannelAffinityCacheStats struct {
Enabled bool `json:"enabled"`
Total int `json:"total"`
@@ -565,9 +572,10 @@ func RecordChannelAffinity(c *gin.Context, channelID int) {
}
type ChannelAffinityUsageCacheStats struct {
- RuleName string `json:"rule_name"`
- UsingGroup string `json:"using_group"`
- KeyFingerprint string `json:"key_fp"`
+ RuleName string `json:"rule_name"`
+ UsingGroup string `json:"using_group"`
+ KeyFingerprint string `json:"key_fp"`
+ CachedTokenRateMode string `json:"cached_token_rate_mode"`
Hit int64 `json:"hit"`
Total int64 `json:"total"`
@@ -582,6 +590,8 @@ type ChannelAffinityUsageCacheStats struct {
}
type ChannelAffinityUsageCacheCounters struct {
+ CachedTokenRateMode string `json:"cached_token_rate_mode"`
+
Hit int64 `json:"hit"`
Total int64 `json:"total"`
WindowSeconds int64 `json:"window_seconds"`
@@ -596,12 +606,17 @@ type ChannelAffinityUsageCacheCounters struct {
var channelAffinityUsageCacheStatsLocks [64]sync.Mutex
-func ObserveChannelAffinityUsageCacheFromContext(c *gin.Context, usage *dto.Usage) {
+// ObserveChannelAffinityUsageCacheByRelayFormat records usage cache stats with a stable rate mode derived from relay format.
+func ObserveChannelAffinityUsageCacheByRelayFormat(c *gin.Context, usage *dto.Usage, relayFormat types.RelayFormat) {
+ ObserveChannelAffinityUsageCacheFromContext(c, usage, cachedTokenRateModeByRelayFormat(relayFormat))
+}
+
+func ObserveChannelAffinityUsageCacheFromContext(c *gin.Context, usage *dto.Usage, cachedTokenRateMode string) {
statsCtx, ok := GetChannelAffinityStatsContext(c)
if !ok {
return
}
- observeChannelAffinityUsageCache(statsCtx, usage)
+ observeChannelAffinityUsageCache(statsCtx, usage, cachedTokenRateMode)
}
func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) ChannelAffinityUsageCacheStats {
@@ -628,6 +643,7 @@ func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) Chann
}
}
return ChannelAffinityUsageCacheStats{
+ CachedTokenRateMode: v.CachedTokenRateMode,
RuleName: ruleName,
UsingGroup: usingGroup,
KeyFingerprint: keyFp,
@@ -643,7 +659,7 @@ func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) Chann
}
}
-func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usage *dto.Usage) {
+func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usage *dto.Usage, cachedTokenRateMode string) {
entryKey := channelAffinityUsageCacheEntryKey(statsCtx.RuleName, statsCtx.UsingGroup, statsCtx.KeyFingerprint)
if entryKey == "" {
return
@@ -669,6 +685,14 @@ func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usag
if !found {
next = ChannelAffinityUsageCacheCounters{}
}
+ currentMode := normalizeCachedTokenRateMode(cachedTokenRateMode)
+ if currentMode != "" {
+ if next.CachedTokenRateMode == "" {
+ next.CachedTokenRateMode = currentMode
+ } else if next.CachedTokenRateMode != currentMode && next.CachedTokenRateMode != cacheTokenRateModeMixed {
+ next.CachedTokenRateMode = cacheTokenRateModeMixed
+ }
+ }
next.Total++
hit, cachedTokens, promptCacheHitTokens := usageCacheSignals(usage)
if hit {
@@ -684,6 +708,30 @@ func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usag
_ = cache.SetWithTTL(entryKey, next, ttl)
}
+func normalizeCachedTokenRateMode(mode string) string {
+ switch mode {
+ case cacheTokenRateModeCachedOverPrompt:
+ return cacheTokenRateModeCachedOverPrompt
+ case cacheTokenRateModeCachedOverPromptPlusCached:
+ return cacheTokenRateModeCachedOverPromptPlusCached
+ case cacheTokenRateModeMixed:
+ return cacheTokenRateModeMixed
+ default:
+ return ""
+ }
+}
+
+func cachedTokenRateModeByRelayFormat(relayFormat types.RelayFormat) string {
+ switch relayFormat {
+ case types.RelayFormatOpenAI, types.RelayFormatOpenAIResponses, types.RelayFormatOpenAIResponsesCompaction:
+ return cacheTokenRateModeCachedOverPrompt
+ case types.RelayFormatClaude:
+ return cacheTokenRateModeCachedOverPromptPlusCached
+ default:
+ return ""
+ }
+}
+
func channelAffinityUsageCacheEntryKey(ruleName, usingGroup, keyFp string) string {
ruleName = strings.TrimSpace(ruleName)
usingGroup = strings.TrimSpace(usingGroup)
diff --git a/service/channel_affinity_usage_cache_test.go b/service/channel_affinity_usage_cache_test.go
new file mode 100644
index 000000000..64d3d715b
--- /dev/null
+++ b/service/channel_affinity_usage_cache_test.go
@@ -0,0 +1,105 @@
+package service
+
+import (
+ "fmt"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/QuantumNous/new-api/dto"
+ "github.com/QuantumNous/new-api/types"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP string) *gin.Context {
+ rec := httptest.NewRecorder()
+ ctx, _ := gin.CreateTestContext(rec)
+ setChannelAffinityContext(ctx, channelAffinityMeta{
+ CacheKey: fmt.Sprintf("test:%s:%s:%s", ruleName, usingGroup, keyFP),
+ TTLSeconds: 600,
+ RuleName: ruleName,
+ UsingGroup: usingGroup,
+ KeyFingerprint: keyFP,
+ })
+ return ctx
+}
+
+func TestObserveChannelAffinityUsageCacheByRelayFormat_ClaudeMode(t *testing.T) {
+ ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano())
+ usingGroup := "default"
+ keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano())
+ ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP)
+
+ usage := &dto.Usage{
+ PromptTokens: 100,
+ CompletionTokens: 40,
+ TotalTokens: 140,
+ PromptTokensDetails: dto.InputTokenDetails{
+ CachedTokens: 30,
+ },
+ }
+
+ ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, types.RelayFormatClaude)
+ stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP)
+
+ require.EqualValues(t, 1, stats.Total)
+ require.EqualValues(t, 1, stats.Hit)
+ require.EqualValues(t, 100, stats.PromptTokens)
+ require.EqualValues(t, 40, stats.CompletionTokens)
+ require.EqualValues(t, 140, stats.TotalTokens)
+ require.EqualValues(t, 30, stats.CachedTokens)
+ require.Equal(t, cacheTokenRateModeCachedOverPromptPlusCached, stats.CachedTokenRateMode)
+}
+
+func TestObserveChannelAffinityUsageCacheByRelayFormat_MixedMode(t *testing.T) {
+ ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano())
+ usingGroup := "default"
+ keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano())
+ ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP)
+
+ openAIUsage := &dto.Usage{
+ PromptTokens: 100,
+ PromptTokensDetails: dto.InputTokenDetails{
+ CachedTokens: 10,
+ },
+ }
+ claudeUsage := &dto.Usage{
+ PromptTokens: 80,
+ PromptTokensDetails: dto.InputTokenDetails{
+ CachedTokens: 20,
+ },
+ }
+
+ ObserveChannelAffinityUsageCacheByRelayFormat(ctx, openAIUsage, types.RelayFormatOpenAI)
+ ObserveChannelAffinityUsageCacheByRelayFormat(ctx, claudeUsage, types.RelayFormatClaude)
+ stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP)
+
+ require.EqualValues(t, 2, stats.Total)
+ require.EqualValues(t, 2, stats.Hit)
+ require.EqualValues(t, 180, stats.PromptTokens)
+ require.EqualValues(t, 30, stats.CachedTokens)
+ require.Equal(t, cacheTokenRateModeMixed, stats.CachedTokenRateMode)
+}
+
+func TestObserveChannelAffinityUsageCacheByRelayFormat_UnsupportedModeKeepsEmpty(t *testing.T) {
+ ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano())
+ usingGroup := "default"
+ keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano())
+ ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP)
+
+ usage := &dto.Usage{
+ PromptTokens: 100,
+ PromptTokensDetails: dto.InputTokenDetails{
+ CachedTokens: 25,
+ },
+ }
+
+ ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, types.RelayFormatGemini)
+ stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP)
+
+ require.EqualValues(t, 1, stats.Total)
+ require.EqualValues(t, 1, stats.Hit)
+ require.EqualValues(t, 25, stats.CachedTokens)
+ require.Equal(t, "", stats.CachedTokenRateMode)
+}
diff --git a/service/codex_credential_refresh.go b/service/codex_credential_refresh.go
index 0290fe516..2e681ee61 100644
--- a/service/codex_credential_refresh.go
+++ b/service/codex_credential_refresh.go
@@ -62,7 +62,7 @@ func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts Code
refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
- res, err := RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
+ res, err := RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy)
if err != nil {
return nil, nil, err
}
diff --git a/service/codex_oauth.go b/service/codex_oauth.go
index 4c2dce1cc..33ef1d60a 100644
--- a/service/codex_oauth.go
+++ b/service/codex_oauth.go
@@ -12,6 +12,8 @@ import (
"net/url"
"strings"
"time"
+
+ "github.com/QuantumNous/new-api/common"
)
const (
@@ -38,12 +40,26 @@ type CodexOAuthAuthorizationFlow struct {
}
func RefreshCodexOAuthToken(ctx context.Context, refreshToken string) (*CodexOAuthTokenResult, error) {
- client := &http.Client{Timeout: defaultHTTPTimeout}
+ return RefreshCodexOAuthTokenWithProxy(ctx, refreshToken, "")
+}
+
+func RefreshCodexOAuthTokenWithProxy(ctx context.Context, refreshToken string, proxyURL string) (*CodexOAuthTokenResult, error) {
+ client, err := getCodexOAuthHTTPClient(proxyURL)
+ if err != nil {
+ return nil, err
+ }
return refreshCodexOAuthToken(ctx, client, codexOAuthTokenURL, codexOAuthClientID, refreshToken)
}
func ExchangeCodexAuthorizationCode(ctx context.Context, code string, verifier string) (*CodexOAuthTokenResult, error) {
- client := &http.Client{Timeout: defaultHTTPTimeout}
+ return ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, "")
+}
+
+func ExchangeCodexAuthorizationCodeWithProxy(ctx context.Context, code string, verifier string, proxyURL string) (*CodexOAuthTokenResult, error) {
+ client, err := getCodexOAuthHTTPClient(proxyURL)
+ if err != nil {
+ return nil, err
+ }
return exchangeCodexAuthorizationCode(ctx, client, codexOAuthTokenURL, codexOAuthClientID, code, verifier, codexOAuthRedirectURI)
}
@@ -104,7 +120,7 @@ func refreshCodexOAuthToken(
ExpiresIn int `json:"expires_in"`
}
- if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
+ if err := common.DecodeJson(resp.Body, &payload); err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
@@ -165,7 +181,7 @@ func exchangeCodexAuthorizationCode(
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
}
- if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
+ if err := common.DecodeJson(resp.Body, &payload); err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
@@ -181,6 +197,19 @@ func exchangeCodexAuthorizationCode(
}, nil
}
+func getCodexOAuthHTTPClient(proxyURL string) (*http.Client, error) {
+ baseClient, err := GetHttpClientWithProxy(strings.TrimSpace(proxyURL))
+ if err != nil {
+ return nil, err
+ }
+ if baseClient == nil {
+ return &http.Client{Timeout: defaultHTTPTimeout}, nil
+ }
+ clientCopy := *baseClient
+ clientCopy.Timeout = defaultHTTPTimeout
+ return &clientCopy, nil
+}
+
func buildCodexAuthorizeURL(state string, challenge string) (string, error) {
u, err := url.Parse(codexOAuthAuthorizeURL)
if err != nil {
diff --git a/service/error.go b/service/error.go
index 7a9d7a815..a2ff0aad7 100644
--- a/service/error.go
+++ b/service/error.go
@@ -206,3 +206,16 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
return taskError
}
+
+// TaskErrorFromAPIError 将 PreConsumeBilling 返回的 NewAPIError 转换为 TaskError。
+func TaskErrorFromAPIError(apiErr *types.NewAPIError) *dto.TaskError {
+ if apiErr == nil {
+ return nil
+ }
+ return &dto.TaskError{
+ Code: string(apiErr.GetErrorCode()),
+ Message: apiErr.Err.Error(),
+ StatusCode: apiErr.StatusCode,
+ Error: apiErr.Err,
+ }
+}
diff --git a/service/log_info_generate.go b/service/log_info_generate.go
index 771da5b77..1c440911b 100644
--- a/service/log_info_generate.go
+++ b/service/log_info_generate.go
@@ -204,7 +204,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
return info
}
-func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PerCallPriceData) map[string]interface{} {
+func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PriceData) map[string]interface{} {
other := make(map[string]interface{})
other["model_price"] = priceData.ModelPrice
other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio
diff --git a/service/midjourney.go b/service/midjourney.go
index 9b2eb5ca7..bdb0fe50a 100644
--- a/service/midjourney.go
+++ b/service/midjourney.go
@@ -19,7 +19,7 @@ import (
"github.com/gin-gonic/gin"
)
-func CoverActionToModelName(mjAction string) string {
+func CovertMjpActionToModelName(mjAction string) string {
modelName := "mj_" + strings.ToLower(mjAction)
if mjAction == constant.MjActionSwapFace {
modelName = "swap_face"
@@ -70,7 +70,7 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
}
}
- modelName := CoverActionToModelName(action)
+ modelName := CovertMjpActionToModelName(action)
return modelName, nil, true
}
diff --git a/service/quota.go b/service/quota.go
index 50421017e..7ee70edd5 100644
--- a/service/quota.go
+++ b/service/quota.go
@@ -236,6 +236,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
}
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) {
+ if usage != nil {
+ ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
+ }
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
diff --git a/service/task_billing.go b/service/task_billing.go
new file mode 100644
index 000000000..0da4cf431
--- /dev/null
+++ b/service/task_billing.go
@@ -0,0 +1,285 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/constant"
+ "github.com/QuantumNous/new-api/logger"
+ "github.com/QuantumNous/new-api/model"
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
+ "github.com/QuantumNous/new-api/setting/ratio_setting"
+ "github.com/gin-gonic/gin"
+)
+
+// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
+// 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。
+func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) {
+ tokenName := c.GetString("token_name")
+ logContent := fmt.Sprintf("操作 %s", info.Action)
+ // 支持任务仅按次计费
+ if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) {
+ logContent = fmt.Sprintf("%s,按次计费", logContent)
+ } else {
+ if len(info.PriceData.OtherRatios) > 0 {
+ var contents []string
+ for key, ra := range info.PriceData.OtherRatios {
+ if 1.0 != ra {
+ contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
+ }
+ }
+ if len(contents) > 0 {
+ logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
+ }
+ }
+ }
+ other := make(map[string]interface{})
+ other["request_path"] = c.Request.URL.Path
+ other["model_price"] = info.PriceData.ModelPrice
+ other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio
+ if info.PriceData.GroupRatioInfo.HasSpecialRatio {
+ other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
+ }
+ if info.IsModelMapped {
+ other["is_model_mapped"] = true
+ other["upstream_model_name"] = info.UpstreamModelName
+ }
+ model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
+ ChannelId: info.ChannelId,
+ ModelName: info.OriginModelName,
+ TokenName: tokenName,
+ Quota: info.PriceData.Quota,
+ Content: logContent,
+ TokenId: info.TokenId,
+ Group: info.UsingGroup,
+ Other: other,
+ })
+ model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota)
+ model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota)
+}
+
+// ---------------------------------------------------------------------------
+// 异步任务计费辅助函数
+// ---------------------------------------------------------------------------
+
+// resolveTokenKey 通过 TokenId 运行时获取令牌 Key(用于 Redis 缓存操作)。
+// 如果令牌已被删除或查询失败,返回空字符串。
+func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string {
+ token, err := model.GetTokenById(tokenId)
+ if err != nil {
+ logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error()))
+ return ""
+ }
+ return token.Key
+}
+
+// taskIsSubscription 判断任务是否通过订阅计费。
+func taskIsSubscription(task *model.Task) bool {
+ return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0
+}
+
+// taskAdjustFunding 调整任务的资金来源(钱包或订阅),delta > 0 表示扣费,delta < 0 表示退还。
+func taskAdjustFunding(task *model.Task, delta int) error {
+ if taskIsSubscription(task) {
+ return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta))
+ }
+ if delta > 0 {
+ return model.DecreaseUserQuota(task.UserId, delta)
+ }
+ return model.IncreaseUserQuota(task.UserId, -delta, false)
+}
+
+// taskAdjustTokenQuota 调整任务的令牌额度,delta > 0 表示扣费,delta < 0 表示退还。
+// 需要通过 resolveTokenKey 运行时获取 key(不从 PrivateData 中读取)。
+func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) {
+ if task.PrivateData.TokenId <= 0 || delta == 0 {
+ return
+ }
+ tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID)
+ if tokenKey == "" {
+ return
+ }
+ var err error
+ if delta > 0 {
+ err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta)
+ } else {
+ err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta)
+ }
+ if err != nil {
+ logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error()))
+ }
+}
+
+// taskBillingOther 从 task 的 BillingContext 构建日志 Other 字段。
+func taskBillingOther(task *model.Task) map[string]interface{} {
+ other := make(map[string]interface{})
+ if bc := task.PrivateData.BillingContext; bc != nil {
+ other["model_price"] = bc.ModelPrice
+ other["group_ratio"] = bc.GroupRatio
+ if len(bc.OtherRatios) > 0 {
+ for k, v := range bc.OtherRatios {
+ other[k] = v
+ }
+ }
+ }
+ props := task.Properties
+ if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName {
+ other["is_model_mapped"] = true
+ other["upstream_model_name"] = props.UpstreamModelName
+ }
+ return other
+}
+
+// taskModelName 从 BillingContext 或 Properties 中获取模型名称。
+func taskModelName(task *model.Task) string {
+ if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" {
+ return bc.OriginModelName
+ }
+ return task.Properties.OriginModelName
+}
+
+// RefundTaskQuota 统一的任务失败退款逻辑。
+// 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。
+func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) {
+ quota := task.Quota
+ if quota == 0 {
+ return
+ }
+
+ // 1. 退还资金来源(钱包或订阅)
+ if err := taskAdjustFunding(task, -quota); err != nil {
+ logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error()))
+ return
+ }
+
+ // 2. 退还令牌额度
+ taskAdjustTokenQuota(ctx, task, -quota)
+
+ // 3. 记录日志
+ other := taskBillingOther(task)
+ other["task_id"] = task.TaskID
+ other["reason"] = reason
+ model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
+ UserId: task.UserId,
+ LogType: model.LogTypeRefund,
+ Content: "",
+ ChannelId: task.ChannelId,
+ ModelName: taskModelName(task),
+ Quota: quota,
+ TokenId: task.PrivateData.TokenId,
+ Group: task.Group,
+ Other: other,
+ })
+}
+
+// RecalculateTaskQuota 通用的异步差额结算。
+// actualQuota 是任务完成后的实际应扣额度,与预扣额度 (task.Quota) 做差额结算。
+// reason 用于日志记录(例如 "token重算" 或 "adaptor调整")。
+func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) {
+ if actualQuota <= 0 {
+ return
+ }
+ preConsumedQuota := task.Quota
+ quotaDelta := actualQuota - preConsumedQuota
+
+ if quotaDelta == 0 {
+ logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)",
+ task.TaskID, logger.LogQuota(actualQuota), reason))
+ return
+ }
+
+ logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,%s)",
+ task.TaskID,
+ logger.LogQuota(quotaDelta),
+ logger.LogQuota(actualQuota),
+ logger.LogQuota(preConsumedQuota),
+ reason,
+ ))
+
+ // 调整资金来源
+ if err := taskAdjustFunding(task, quotaDelta); err != nil {
+ logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
+ return
+ }
+
+ // 调整令牌额度
+ taskAdjustTokenQuota(ctx, task, quotaDelta)
+
+ task.Quota = actualQuota
+
+ var logType int
+ var logQuota int
+ if quotaDelta > 0 {
+ logType = model.LogTypeConsume
+ logQuota = quotaDelta
+ model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
+ model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
+ } else {
+ logType = model.LogTypeRefund
+ logQuota = -quotaDelta
+ }
+ other := taskBillingOther(task)
+ other["task_id"] = task.TaskID
+ other["reason"] = reason
+ other["pre_consumed_quota"] = preConsumedQuota
+ other["actual_quota"] = actualQuota
+ model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
+ UserId: task.UserId,
+ LogType: logType,
+ Content: "",
+ ChannelId: task.ChannelId,
+ ModelName: taskModelName(task),
+ Quota: logQuota,
+ TokenId: task.PrivateData.TokenId,
+ Group: task.Group,
+ Other: other,
+ })
+}
+
+// RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。
+// 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度,
+// 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。
+func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) {
+ if totalTokens <= 0 {
+ return
+ }
+
+ modelName := taskModelName(task)
+
+ // 获取模型价格和倍率
+ modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
+ // 只有配置了倍率(非固定价格)时才按 token 重新计费
+ if !hasRatioSetting || modelRatio <= 0 {
+ return
+ }
+
+ // 获取用户和组的倍率信息
+ group := task.Group
+ if group == "" {
+ user, err := model.GetUserById(task.UserId, false)
+ if err == nil {
+ group = user.Group
+ }
+ }
+ if group == "" {
+ return
+ }
+
+ groupRatio := ratio_setting.GetGroupRatio(group)
+ userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
+
+ var finalGroupRatio float64
+ if hasUserGroupRatio {
+ finalGroupRatio = userGroupRatio
+ } else {
+ finalGroupRatio = groupRatio
+ }
+
+ // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
+ actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
+
+ reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio)
+ RecalculateTaskQuota(ctx, task, actualQuota, reason)
+}
diff --git a/service/task_billing_test.go b/service/task_billing_test.go
new file mode 100644
index 000000000..1145bba54
--- /dev/null
+++ b/service/task_billing_test.go
@@ -0,0 +1,712 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/model"
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
+ "github.com/glebarez/sqlite"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gorm.io/gorm"
+)
+
+func TestMain(m *testing.M) {
+ db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+ if err != nil {
+ panic("failed to open test db: " + err.Error())
+ }
+ sqlDB, err := db.DB()
+ if err != nil {
+ panic("failed to get sql.DB: " + err.Error())
+ }
+ sqlDB.SetMaxOpenConns(1)
+
+ model.DB = db
+ model.LOG_DB = db
+
+ common.UsingSQLite = true
+ common.RedisEnabled = false
+ common.BatchUpdateEnabled = false
+ common.LogConsumeEnabled = true
+
+ if err := db.AutoMigrate(
+ &model.Task{},
+ &model.User{},
+ &model.Token{},
+ &model.Log{},
+ &model.Channel{},
+ &model.UserSubscription{},
+ ); err != nil {
+ panic("failed to migrate: " + err.Error())
+ }
+
+ os.Exit(m.Run())
+}
+
+// ---------------------------------------------------------------------------
+// Seed helpers
+// ---------------------------------------------------------------------------
+
+func truncate(t *testing.T) {
+ t.Helper()
+ t.Cleanup(func() {
+ model.DB.Exec("DELETE FROM tasks")
+ model.DB.Exec("DELETE FROM users")
+ model.DB.Exec("DELETE FROM tokens")
+ model.DB.Exec("DELETE FROM logs")
+ model.DB.Exec("DELETE FROM channels")
+ model.DB.Exec("DELETE FROM user_subscriptions")
+ })
+}
+
+func seedUser(t *testing.T, id int, quota int) {
+ t.Helper()
+ user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled}
+ require.NoError(t, model.DB.Create(user).Error)
+}
+
+func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) {
+ t.Helper()
+ token := &model.Token{
+ Id: id,
+ UserId: userId,
+ Key: key,
+ Name: "test_token",
+ Status: common.TokenStatusEnabled,
+ RemainQuota: remainQuota,
+ UsedQuota: 0,
+ }
+ require.NoError(t, model.DB.Create(token).Error)
+}
+
+func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) {
+ t.Helper()
+ sub := &model.UserSubscription{
+ Id: id,
+ UserId: userId,
+ AmountTotal: amountTotal,
+ AmountUsed: amountUsed,
+ Status: "active",
+ StartTime: time.Now().Unix(),
+ EndTime: time.Now().Add(30 * 24 * time.Hour).Unix(),
+ }
+ require.NoError(t, model.DB.Create(sub).Error)
+}
+
+func seedChannel(t *testing.T, id int) {
+ t.Helper()
+ ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled}
+ require.NoError(t, model.DB.Create(ch).Error)
+}
+
+func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task {
+ return &model.Task{
+ TaskID: "task_" + time.Now().Format("150405.000"),
+ UserId: userId,
+ ChannelId: channelId,
+ Quota: quota,
+ Status: model.TaskStatus(model.TaskStatusInProgress),
+ Group: "default",
+ Data: json.RawMessage(`{}`),
+ CreatedAt: time.Now().Unix(),
+ UpdatedAt: time.Now().Unix(),
+ Properties: model.Properties{
+ OriginModelName: "test-model",
+ },
+ PrivateData: model.TaskPrivateData{
+ BillingSource: billingSource,
+ SubscriptionId: subscriptionId,
+ TokenId: tokenId,
+ BillingContext: &model.TaskBillingContext{
+ ModelPrice: 0.02,
+ GroupRatio: 1.0,
+ OriginModelName: "test-model",
+ },
+ },
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Read-back helpers
+// ---------------------------------------------------------------------------
+
+func getUserQuota(t *testing.T, id int) int {
+ t.Helper()
+ var user model.User
+ require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error)
+ return user.Quota
+}
+
+func getTokenRemainQuota(t *testing.T, id int) int {
+ t.Helper()
+ var token model.Token
+ require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error)
+ return token.RemainQuota
+}
+
+func getTokenUsedQuota(t *testing.T, id int) int {
+ t.Helper()
+ var token model.Token
+ require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error)
+ return token.UsedQuota
+}
+
+func getSubscriptionUsed(t *testing.T, id int) int64 {
+ t.Helper()
+ var sub model.UserSubscription
+ require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error)
+ return sub.AmountUsed
+}
+
+func getLastLog(t *testing.T) *model.Log {
+ t.Helper()
+ var log model.Log
+ err := model.LOG_DB.Order("id desc").First(&log).Error
+ if err != nil {
+ return nil
+ }
+ return &log
+}
+
+func countLogs(t *testing.T) int64 {
+ t.Helper()
+ var count int64
+ model.LOG_DB.Model(&model.Log{}).Count(&count)
+ return count
+}
+
+// ===========================================================================
+// RefundTaskQuota tests
+// ===========================================================================
+
+func TestRefundTaskQuota_Wallet(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, tokenID, channelID = 1, 1, 1
+ const initQuota, preConsumed = 10000, 3000
+ const tokenRemain = 5000
+
+ seedUser(t, userID, initQuota)
+ seedToken(t, tokenID, userID, "sk-test-key", tokenRemain)
+ seedChannel(t, channelID)
+
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+
+ RefundTaskQuota(ctx, task, "task failed: upstream error")
+
+ // User quota should increase by preConsumed
+ assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
+
+ // Token remain_quota should increase, used_quota should decrease
+ assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
+ assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID))
+
+ // A refund log should be created
+ log := getLastLog(t)
+ require.NotNil(t, log)
+ assert.Equal(t, model.LogTypeRefund, log.Type)
+ assert.Equal(t, preConsumed, log.Quota)
+ assert.Equal(t, "test-model", log.ModelName)
+}
+
+func TestRefundTaskQuota_Subscription(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, tokenID, channelID, subID = 2, 2, 2, 1
+ const preConsumed = 2000
+ const subTotal, subUsed int64 = 100000, 50000
+ const tokenRemain = 8000
+
+ seedUser(t, userID, 0)
+ seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain)
+ seedChannel(t, channelID)
+ seedSubscription(t, subID, userID, subTotal, subUsed)
+
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
+
+ RefundTaskQuota(ctx, task, "subscription task failed")
+
+ // Subscription used should decrease by preConsumed
+ assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID))
+
+ // Token should also be refunded
+ assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
+
+ log := getLastLog(t)
+ require.NotNil(t, log)
+ assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+func TestRefundTaskQuota_ZeroQuota(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID = 3
+ seedUser(t, userID, 5000)
+
+ task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0)
+
+ RefundTaskQuota(ctx, task, "zero quota task")
+
+ // No change to user quota
+ assert.Equal(t, 5000, getUserQuota(t, userID))
+
+ // No log created
+ assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestRefundTaskQuota_NoToken(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, channelID = 4, 4
+ const initQuota, preConsumed = 10000, 1500
+
+ seedUser(t, userID, initQuota)
+ seedChannel(t, channelID)
+
+ task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0
+
+ RefundTaskQuota(ctx, task, "no token task failed")
+
+ // User quota refunded
+ assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
+
+ // Log created
+ log := getLastLog(t)
+ require.NotNil(t, log)
+ assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+// ===========================================================================
+// RecalculateTaskQuota tests
+// ===========================================================================
+
+func TestRecalculate_PositiveDelta(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, tokenID, channelID = 10, 10, 10
+ const initQuota, preConsumed = 10000, 2000
+ const actualQuota = 3000 // under-charged by 1000
+ const tokenRemain = 5000
+
+ seedUser(t, userID, initQuota)
+ seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain)
+ seedChannel(t, channelID)
+
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+
+ RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
+
+ // User quota should decrease by the delta (1000 additional charge)
+ assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID))
+
+ // Token should also be charged the delta
+ assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID))
+
+ // task.Quota should be updated to actualQuota
+ assert.Equal(t, actualQuota, task.Quota)
+
+ // Log type should be Consume (additional charge)
+ log := getLastLog(t)
+ require.NotNil(t, log)
+ assert.Equal(t, model.LogTypeConsume, log.Type)
+ assert.Equal(t, actualQuota-preConsumed, log.Quota)
+}
+
+func TestRecalculate_NegativeDelta(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, tokenID, channelID = 11, 11, 11
+ const initQuota, preConsumed = 10000, 5000
+ const actualQuota = 3000 // over-charged by 2000
+ const tokenRemain = 5000
+
+ seedUser(t, userID, initQuota)
+ seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain)
+ seedChannel(t, channelID)
+
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+
+ RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
+
+ // User quota should increase by abs(delta) = 2000 (refund overpayment)
+ assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
+
+ // Token should be refunded the difference
+ assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
+
+ // task.Quota updated
+ assert.Equal(t, actualQuota, task.Quota)
+
+ // Log type should be Refund
+ log := getLastLog(t)
+ require.NotNil(t, log)
+ assert.Equal(t, model.LogTypeRefund, log.Type)
+ assert.Equal(t, preConsumed-actualQuota, log.Quota)
+}
+
+func TestRecalculate_ZeroDelta(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID = 12
+ const initQuota, preConsumed = 10000, 3000
+
+ seedUser(t, userID, initQuota)
+
+ task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0)
+
+ RecalculateTaskQuota(ctx, task, preConsumed, "exact match")
+
+ // No change to user quota
+ assert.Equal(t, initQuota, getUserQuota(t, userID))
+
+ // No log created (delta is zero)
+ assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestRecalculate_ActualQuotaZero(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID = 13
+ const initQuota = 10000
+
+ seedUser(t, userID, initQuota)
+
+ task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0)
+
+ RecalculateTaskQuota(ctx, task, 0, "zero actual")
+
+ // No change (early return)
+ assert.Equal(t, initQuota, getUserQuota(t, userID))
+ assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestRecalculate_Subscription_NegativeDelta(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, tokenID, channelID, subID = 14, 14, 14, 2
+ const preConsumed = 5000
+ const actualQuota = 2000 // over-charged by 3000
+ const subTotal, subUsed int64 = 100000, 50000
+ const tokenRemain = 8000
+
+ seedUser(t, userID, 0)
+ seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain)
+ seedChannel(t, channelID)
+ seedSubscription(t, subID, userID, subTotal, subUsed)
+
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
+
+ RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge")
+
+ // Subscription used should decrease by delta (refund 3000)
+ assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID))
+
+ // Token refunded
+ assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
+
+ assert.Equal(t, actualQuota, task.Quota)
+
+ log := getLastLog(t)
+ require.NotNil(t, log)
+ assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+// ===========================================================================
+// CAS + Billing integration tests
+// Simulates the flow in updateVideoSingleTask (service/task_polling.go)
+// ===========================================================================
+
+// simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask.
+// It takes a persisted task (already in DB), applies the new status, and performs
+// the conditional update + billing exactly as the polling loop does.
+func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) {
+ snap := task.Snapshot()
+
+ shouldRefund := false
+ shouldSettle := false
+ quota := task.Quota
+
+ task.Status = newStatus
+ switch string(newStatus) {
+ case model.TaskStatusSuccess:
+ task.Progress = "100%"
+ task.FinishTime = 9999
+ shouldSettle = true
+ case model.TaskStatusFailure:
+ task.Progress = "100%"
+ task.FinishTime = 9999
+ task.FailReason = "upstream error"
+ if quota != 0 {
+ shouldRefund = true
+ }
+ default:
+ task.Progress = "50%"
+ }
+
+ isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure)
+ if isDone && snap.Status != task.Status {
+ won, err := task.UpdateWithStatus(snap.Status)
+ if err != nil {
+ shouldRefund = false
+ shouldSettle = false
+ } else if !won {
+ shouldRefund = false
+ shouldSettle = false
+ }
+ } else if !snap.Equal(task.Snapshot()) {
+ _, _ = task.UpdateWithStatus(snap.Status)
+ }
+
+ if shouldSettle && actualQuota > 0 {
+ RecalculateTaskQuota(ctx, task, actualQuota, "test settle")
+ }
+ if shouldRefund {
+ RefundTaskQuota(ctx, task, task.FailReason)
+ }
+}
+
+func TestCASGuardedRefund_Win(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, tokenID, channelID = 20, 20, 20
+ const initQuota, preConsumed = 10000, 4000
+ const tokenRemain = 6000
+
+ seedUser(t, userID, initQuota)
+ seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain)
+ seedChannel(t, channelID)
+
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+ task.Status = model.TaskStatus(model.TaskStatusInProgress)
+ require.NoError(t, model.DB.Create(task).Error)
+
+ simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
+
+ // CAS wins: task in DB should now be FAILURE
+ var reloaded model.Task
+ require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
+ assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status)
+
+ // Refund should have happened
+ assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
+ assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
+
+ log := getLastLog(t)
+ require.NotNil(t, log)
+ assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+func TestCASGuardedRefund_Lose(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, tokenID, channelID = 21, 21, 21
+ const initQuota, preConsumed = 10000, 4000
+ const tokenRemain = 6000
+
+ seedUser(t, userID, initQuota)
+ seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain)
+ seedChannel(t, channelID)
+
+ // Create task with IN_PROGRESS in DB
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+ task.Status = model.TaskStatus(model.TaskStatusInProgress)
+ require.NoError(t, model.DB.Create(task).Error)
+
+ // Simulate another process already transitioning to FAILURE
+ model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure)
+
+ // Our process still has the old in-memory state (IN_PROGRESS) and tries to transition
+ // task.Status is still IN_PROGRESS in the snapshot
+ simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
+
+ // CAS lost: user quota should NOT change (no double refund)
+ assert.Equal(t, initQuota, getUserQuota(t, userID))
+ assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
+
+ // No billing log should be created
+ assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestCASGuardedSettle_Win(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, tokenID, channelID = 22, 22, 22
+ const initQuota, preConsumed = 10000, 5000
+ const actualQuota = 3000 // over-charged, should get partial refund
+ const tokenRemain = 8000
+
+ seedUser(t, userID, initQuota)
+ seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain)
+ seedChannel(t, channelID)
+
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+ task.Status = model.TaskStatus(model.TaskStatusInProgress)
+ require.NoError(t, model.DB.Create(task).Error)
+
+ simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota)
+
+ // CAS wins: task should be SUCCESS
+ var reloaded model.Task
+ require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
+ assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status)
+
+ // Settlement should refund the over-charge (5000 - 3000 = 2000 back to user)
+ assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
+ assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
+
+ // task.Quota should be updated to actualQuota
+ assert.Equal(t, actualQuota, task.Quota)
+}
+
+func TestNonTerminalUpdate_NoBilling(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, channelID = 23, 23
+ const initQuota, preConsumed = 10000, 3000
+
+ seedUser(t, userID, initQuota)
+ seedChannel(t, channelID)
+
+ task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0)
+ task.Status = model.TaskStatus(model.TaskStatusInProgress)
+ task.Progress = "20%"
+ require.NoError(t, model.DB.Create(task).Error)
+
+ // Simulate a non-terminal poll update (still IN_PROGRESS, progress changed)
+ simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0)
+
+ // User quota should NOT change
+ assert.Equal(t, initQuota, getUserQuota(t, userID))
+
+ // No billing log
+ assert.Equal(t, int64(0), countLogs(t))
+
+ // Task progress should be updated in DB
+ var reloaded model.Task
+ require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
+ assert.Equal(t, "50%", reloaded.Progress)
+}
+
+// ===========================================================================
+// Mock adaptor for settleTaskBillingOnComplete tests
+// ===========================================================================
+
+type mockAdaptor struct {
+ adjustReturn int
+}
+
+func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {}
+func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { return nil, nil }
+func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil }
+func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
+ return m.adjustReturn
+}
+
+// ===========================================================================
+// PerCallBilling tests — settleTaskBillingOnComplete
+// ===========================================================================
+
+func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, tokenID, channelID = 30, 30, 30
+ const initQuota, preConsumed = 10000, 5000
+ const tokenRemain = 8000
+
+ seedUser(t, userID, initQuota)
+ seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain)
+ seedChannel(t, channelID)
+
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+ task.PrivateData.BillingContext.PerCallBilling = true
+
+ adaptor := &mockAdaptor{adjustReturn: 2000}
+ taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
+
+ settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+ // Per-call: no adjustment despite adaptor returning 2000
+ assert.Equal(t, initQuota, getUserQuota(t, userID))
+ assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
+ assert.Equal(t, preConsumed, task.Quota)
+ assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, tokenID, channelID = 31, 31, 31
+ const initQuota, preConsumed = 10000, 4000
+ const tokenRemain = 7000
+
+ seedUser(t, userID, initQuota)
+ seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain)
+ seedChannel(t, channelID)
+
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+ task.PrivateData.BillingContext.PerCallBilling = true
+
+ adaptor := &mockAdaptor{adjustReturn: 0}
+ taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999}
+
+ settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+ // Per-call: no recalculation by tokens
+ assert.Equal(t, initQuota, getUserQuota(t, userID))
+ assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
+ assert.Equal(t, preConsumed, task.Quota)
+ assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, tokenID, channelID = 32, 32, 32
+ const initQuota, preConsumed = 10000, 5000
+ const adaptorQuota = 3000
+ const tokenRemain = 8000
+
+ seedUser(t, userID, initQuota)
+ seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain)
+ seedChannel(t, channelID)
+
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+ // PerCallBilling defaults to false
+
+ adaptor := &mockAdaptor{adjustReturn: adaptorQuota}
+ taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
+
+ settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+ // Non-per-call: adaptor adjustment applies (refund 2000)
+ assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID))
+ assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID))
+ assert.Equal(t, adaptorQuota, task.Quota)
+
+ log := getLastLog(t)
+ require.NotNil(t, log)
+ assert.Equal(t, model.LogTypeRefund, log.Type)
+}
diff --git a/service/task_polling.go b/service/task_polling.go
new file mode 100644
index 000000000..3c5cab5b0
--- /dev/null
+++ b/service/task_polling.go
@@ -0,0 +1,539 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "sort"
+ "strings"
+ "time"
+
+ "github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/constant"
+ "github.com/QuantumNous/new-api/dto"
+ "github.com/QuantumNous/new-api/logger"
+ "github.com/QuantumNous/new-api/model"
+ "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
+
+ "github.com/samber/lo"
+)
+
+// TaskPollingAdaptor 定义轮询所需的最小适配器接口,避免 service -> relay 的循环依赖
+type TaskPollingAdaptor interface {
+ Init(info *relaycommon.RelayInfo)
+ FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error)
+ ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error)
+ // AdjustBillingOnComplete 在任务到达终态(成功/失败)时由轮询循环调用。
+ // 返回正数触发差额结算(补扣/退还),返回 0 保持预扣费金额不变。
+ AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
+}
+
+// GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。
+// 打破 service -> relay -> relay/channel -> service 的循环依赖。
+var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor
+
+// sweepTimedOutTasks cleans up timed-out tasks independently, before the main
+// polling pass. At most 100 tasks are processed per sweep; any remainder is
+// picked up on the next cycle. A per-task CAS (UpdateWithStatus) prevents
+// clobbering tasks that the normal polling path has already advanced.
+func sweepTimedOutTasks(ctx context.Context) {
+	if constant.TaskTimeoutMinutes <= 0 {
+		return
+	}
+	cutoff := time.Now().Unix() - int64(constant.TaskTimeoutMinutes)*60
+	tasks := model.GetTimedOutUnfinishedTasks(cutoff, 100)
+	if len(tasks) == 0 {
+		return
+	}
+
+	// NOTE(review): 1740182400 is 2025-02-22 00:00:00 UTC, not 2026 as the
+	// original comment claimed — confirm which cutoff date was intended.
+	const legacyTaskCutoff int64 = 1740182400 // 2025-02-22 00:00:00 UTC
+	reason := fmt.Sprintf("任务超时(%d分钟)", constant.TaskTimeoutMinutes)
+	legacyReason := "任务超时(旧系统遗留任务,不进行退款,请联系管理员)"
+	now := time.Now().Unix()
+	timedOutCount := 0
+
+	for _, task := range tasks {
+		// Tasks submitted before the cutoff predate the current billing
+		// system and are failed without a refund (see legacyReason).
+		isLegacy := task.SubmitTime > 0 && task.SubmitTime < legacyTaskCutoff
+
+		oldStatus := task.Status
+		task.Status = model.TaskStatusFailure
+		task.Progress = "100%"
+		task.FinishTime = now
+		if isLegacy {
+			task.FailReason = legacyReason
+		} else {
+			task.FailReason = reason
+		}
+
+		// CAS against the previously observed status; losing means another
+		// process already transitioned this task, so skip it (and the refund).
+		won, err := task.UpdateWithStatus(oldStatus)
+		if err != nil {
+			logger.LogError(ctx, fmt.Sprintf("sweepTimedOutTasks CAS update error for task %s: %v", task.TaskID, err))
+			continue
+		}
+		if !won {
+			logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: task %s already transitioned, skip", task.TaskID))
+			continue
+		}
+		timedOutCount++
+		// Refund only after winning the CAS, so a refund happens at most once.
+		if !isLegacy && task.Quota != 0 {
+			RefundTaskQuota(ctx, task, reason)
+		}
+	}
+
+	if timedOutCount > 0 {
+		logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: timed out %d tasks", timedOutCount))
+	}
+}
+
+// TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务
+func TaskPollingLoop() {
+ for {
+ time.Sleep(time.Duration(15) * time.Second)
+ common.SysLog("任务进度轮询开始")
+ ctx := context.TODO()
+ sweepTimedOutTasks(ctx)
+ allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
+ platformTask := make(map[constant.TaskPlatform][]*model.Task)
+ for _, t := range allTasks {
+ platformTask[t.Platform] = append(platformTask[t.Platform], t)
+ }
+ for platform, tasks := range platformTask {
+ if len(tasks) == 0 {
+ continue
+ }
+ taskChannelM := make(map[int][]string)
+ taskM := make(map[string]*model.Task)
+ nullTaskIds := make([]int64, 0)
+ for _, task := range tasks {
+ upstreamID := task.GetUpstreamTaskID()
+ if upstreamID == "" {
+ // 统计失败的未完成任务
+ nullTaskIds = append(nullTaskIds, task.ID)
+ continue
+ }
+ taskM[upstreamID] = task
+ taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID)
+ }
+ if len(nullTaskIds) > 0 {
+ err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
+ "status": "FAILURE",
+ "progress": "100%",
+ })
+ if err != nil {
+ logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
+ } else {
+ logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
+ }
+ }
+ if len(taskChannelM) == 0 {
+ continue
+ }
+
+ DispatchPlatformUpdate(platform, taskChannelM, taskM)
+ }
+ common.SysLog("任务进度轮询完成")
+ }
+}
+
+// DispatchPlatformUpdate 按平台分发轮询更新
+func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
+ switch platform {
+ case constant.TaskPlatformMidjourney:
+ // MJ 轮询由其自身处理,这里预留入口
+ case constant.TaskPlatformSuno:
+ _ = UpdateSunoTasks(context.Background(), taskChannelM, taskM)
+ default:
+ if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil {
+ common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err))
+ }
+ }
+}
+
+// UpdateSunoTasks 按渠道更新所有 Suno 任务
+func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+ for channelId, taskIds := range taskChannelM {
+ err := updateSunoTasks(ctx, channelId, taskIds, taskM)
+ if err != nil {
+ logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
+ }
+ }
+ return nil
+}
+
+// updateSunoTasks fetches the latest upstream state of the pending Suno tasks
+// on one channel and persists any changes, refunding quota for tasks that
+// failed upstream. taskIds are upstream task IDs keyed into taskM.
+func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+	logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+	if len(taskIds) == 0 {
+		return nil
+	}
+	ch, err := model.CacheGetChannel(channelId)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
+		// Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
+		var failedIDs []int64
+		for _, upstreamID := range taskIds {
+			if t, ok := taskM[upstreamID]; ok {
+				failedIDs = append(failedIDs, t.ID)
+			}
+		}
+		if errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{
+			"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
+			"status":      "FAILURE",
+			"progress":    "100%",
+		}); errUpdate != nil {
+			common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", errUpdate))
+		}
+		// Surface the root cause (consistent with updateVideoTasks); the
+		// original returned only the bulk-update error, hiding this one.
+		return fmt.Errorf("CacheGetChannel failed: %w", err)
+	}
+	adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno)
+	if adaptor == nil {
+		return errors.New("adaptor not found")
+	}
+	proxy := ch.GetSetting().Proxy
+	resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{
+		"ids": taskIds,
+	}, proxy)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
+		return err
+	}
+	// Close the body on every path — the original leaked it on the
+	// non-200 early return below because the defer came after that check.
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+		return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
+	}
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("Get Suno Task parse body error: %v", err))
+		return err
+	}
+	var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
+	if err = common.Unmarshal(responseBody, &responseItems); err != nil {
+		logger.LogError(ctx, fmt.Sprintf("Get Suno Task parse body error2: %v, body: %s", err, string(responseBody)))
+		return err
+	}
+	if !responseItems.IsSuccess() {
+		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
+		// The original returned err here, which is nil at this point,
+		// silently masking the upstream failure.
+		return fmt.Errorf("suno task query unsuccessful: %s", string(responseBody))
+	}
+
+	for _, responseItem := range responseItems.Data {
+		task := taskM[responseItem.TaskID]
+		// Guard against upstream returning a task we did not ask about;
+		// the original would nil-panic inside taskNeedsUpdate.
+		if task == nil || !taskNeedsUpdate(task, responseItem) {
+			continue
+		}
+
+		// Only overwrite fields the upstream actually populated.
+		task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
+		task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
+		task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
+		task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
+		task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
+		if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
+			logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
+			task.Progress = "100%"
+			RefundTaskQuota(ctx, task, task.FailReason)
+		}
+		if responseItem.Status == model.TaskStatusSuccess {
+			task.Progress = "100%"
+		}
+		task.Data = responseItem.Data
+
+		if err := task.Update(); err != nil {
+			common.SysLog("UpdateSunoTask task error: " + err.Error())
+		}
+	}
+	return nil
+}
+
+// taskNeedsUpdate reports whether a Suno task's persisted state differs from
+// the freshly fetched upstream state, i.e. whether a DB update is warranted.
+func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
+	if oldTask.SubmitTime != newTask.SubmitTime {
+		return true
+	}
+	if oldTask.StartTime != newTask.StartTime {
+		return true
+	}
+	if oldTask.FinishTime != newTask.FinishTime {
+		return true
+	}
+	if string(oldTask.Status) != newTask.Status {
+		return true
+	}
+	if oldTask.FailReason != newTask.FailReason {
+		return true
+	}
+
+	// Terminal tasks whose progress was never marked complete still need one
+	// final write to set progress to 100%.
+	if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
+		return true
+	}
+
+	oldData, _ := common.Marshal(oldTask.Data)
+	newData, _ := common.Marshal(newTask.Data)
+
+	// NOTE(review): this byte-sorts the two JSON encodings before comparing,
+	// which makes the check order-insensitive but also field-insensitive —
+	// two different payloads with the same multiset of bytes compare equal.
+	// Presumably acceptable as a cheap "has anything changed" filter; confirm
+	// if exact structural equality matters here.
+	sort.Slice(oldData, func(i, j int) bool {
+		return oldData[i] < oldData[j]
+	})
+	sort.Slice(newData, func(i, j int) bool {
+		return newData[i] < newData[j]
+	})
+
+	if string(oldData) != string(newData) {
+		return true
+	}
+	return false
+}
+
+// UpdateVideoTasks 按渠道更新所有视频任务
+func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+ for channelId, taskIds := range taskChannelM {
+ if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil {
+ logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
+ }
+ }
+ return nil
+}
+
+func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+ logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
+ if len(taskIds) == 0 {
+ return nil
+ }
+ cacheGetChannel, err := model.CacheGetChannel(channelId)
+ if err != nil {
+ // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
+ var failedIDs []int64
+ for _, upstreamID := range taskIds {
+ if t, ok := taskM[upstreamID]; ok {
+ failedIDs = append(failedIDs, t.ID)
+ }
+ }
+ errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{
+ "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
+ "status": "FAILURE",
+ "progress": "100%",
+ })
+ if errUpdate != nil {
+ common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
+ }
+ return fmt.Errorf("CacheGetChannel failed: %w", err)
+ }
+ adaptor := GetTaskAdaptorFunc(platform)
+ if adaptor == nil {
+ return fmt.Errorf("video adaptor not found")
+ }
+ info := &relaycommon.RelayInfo{}
+ info.ChannelMeta = &relaycommon.ChannelMeta{
+ ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
+ }
+ info.ApiKey = cacheGetChannel.Key
+ adaptor.Init(info)
+ for _, taskId := range taskIds {
+ if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
+ logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
+ }
+ }
+ return nil
+}
+
+func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error {
+ baseURL := constant.ChannelBaseURLs[ch.Type]
+ if ch.GetBaseURL() != "" {
+ baseURL = ch.GetBaseURL()
+ }
+ proxy := ch.GetSetting().Proxy
+
+ task := taskM[taskId]
+ if task == nil {
+ logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
+ return fmt.Errorf("task %s not found", taskId)
+ }
+ key := ch.Key
+
+ privateData := task.PrivateData
+ if privateData.Key != "" {
+ key = privateData.Key
+ }
+ resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
+ "task_id": task.GetUpstreamTaskID(),
+ "action": task.Action,
+ }, proxy)
+ if err != nil {
+ return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
+ }
+ defer resp.Body.Close()
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
+ }
+
+ logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody)))
+
+ snap := task.Snapshot()
+
+ taskResult := &relaycommon.TaskInfo{}
+ // try parse as New API response format
+ var responseItems dto.TaskResponse[model.Task]
+ if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
+ logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems))
+ t := responseItems.Data
+ taskResult.TaskID = t.TaskID
+ taskResult.Status = string(t.Status)
+ taskResult.Url = t.GetResultURL()
+ taskResult.Progress = t.Progress
+ taskResult.Reason = t.FailReason
+ task.Data = t.Data
+ } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
+ return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
+ } else {
+ task.Data = redactVideoResponseBody(responseBody)
+ }
+
+ logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult))
+
+ now := time.Now().Unix()
+ if taskResult.Status == "" {
+ taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
+ }
+
+ shouldRefund := false
+ shouldSettle := false
+ quota := task.Quota
+
+ task.Status = model.TaskStatus(taskResult.Status)
+ switch taskResult.Status {
+ case model.TaskStatusSubmitted:
+ task.Progress = taskcommon.ProgressSubmitted
+ case model.TaskStatusQueued:
+ task.Progress = taskcommon.ProgressQueued
+ case model.TaskStatusInProgress:
+ task.Progress = taskcommon.ProgressInProgress
+ if task.StartTime == 0 {
+ task.StartTime = now
+ }
+ case model.TaskStatusSuccess:
+ task.Progress = taskcommon.ProgressComplete
+ if task.FinishTime == 0 {
+ task.FinishTime = now
+ }
+ if strings.HasPrefix(taskResult.Url, "data:") {
+ // data: URI (e.g. Vertex base64 encoded video) — keep in Data, not in ResultURL
+ } else if taskResult.Url != "" {
+ // Direct upstream URL (e.g. Kling, Ali, Doubao, etc.)
+ task.PrivateData.ResultURL = taskResult.Url
+ } else {
+ // No URL from adaptor — construct proxy URL using public task ID
+ task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
+ }
+ shouldSettle = true
+ case model.TaskStatusFailure:
+ logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
+ task.Status = model.TaskStatusFailure
+ task.Progress = taskcommon.ProgressComplete
+ if task.FinishTime == 0 {
+ task.FinishTime = now
+ }
+ task.FailReason = taskResult.Reason
+ logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
+ taskResult.Progress = taskcommon.ProgressComplete
+ if quota != 0 {
+ shouldRefund = true
+ }
+ default:
+ return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, task.TaskID)
+ }
+ if taskResult.Progress != "" {
+ task.Progress = taskResult.Progress
+ }
+
+ isDone := task.Status == model.TaskStatusSuccess || task.Status == model.TaskStatusFailure
+ if isDone && snap.Status != task.Status {
+ won, err := task.UpdateWithStatus(snap.Status)
+ if err != nil {
+ logger.LogError(ctx, fmt.Sprintf("UpdateWithStatus failed for task %s: %s", task.TaskID, err.Error()))
+ shouldRefund = false
+ shouldSettle = false
+ } else if !won {
+ logger.LogWarn(ctx, fmt.Sprintf("Task %s already transitioned by another process, skip billing", task.TaskID))
+ shouldRefund = false
+ shouldSettle = false
+ }
+ } else if !snap.Equal(task.Snapshot()) {
+ if _, err := task.UpdateWithStatus(snap.Status); err != nil {
+ logger.LogError(ctx, fmt.Sprintf("Failed to update task %s: %s", task.TaskID, err.Error()))
+ }
+ } else {
+ // No changes, skip update
+ logger.LogDebug(ctx, fmt.Sprintf("No update needed for task %s", task.TaskID))
+ }
+
+ if shouldSettle {
+ settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+ }
+ if shouldRefund {
+ RefundTaskQuota(ctx, task, task.FailReason)
+ }
+
+ return nil
+}
+
+func redactVideoResponseBody(body []byte) []byte {
+ var m map[string]any
+ if err := common.Unmarshal(body, &m); err != nil {
+ return body
+ }
+ resp, _ := m["response"].(map[string]any)
+ if resp != nil {
+ delete(resp, "bytesBase64Encoded")
+ if v, ok := resp["video"].(string); ok {
+ resp["video"] = truncateBase64(v)
+ }
+ if vs, ok := resp["videos"].([]any); ok {
+ for i := range vs {
+ if vm, ok := vs[i].(map[string]any); ok {
+ delete(vm, "bytesBase64Encoded")
+ }
+ }
+ }
+ }
+ b, err := common.Marshal(m)
+ if err != nil {
+ return body
+ }
+ return b
+}
+
+func truncateBase64(s string) string {
+ const maxKeep = 256
+ if len(s) <= maxKeep {
+ return s
+ }
+ return s[:maxKeep] + "..."
+}
+
+// settleTaskBillingOnComplete performs the unified billing adjustment when a
+// task reaches a terminal state. Priority:
+//
+//  1. adaptor.AdjustBillingOnComplete returns a positive value → settle using
+//     the adaptor-computed quota (charge or refund the difference).
+//  2. taskResult.TotalTokens > 0 → recalculate by token usage.
+//  3. Neither applies → keep the pre-deducted quota unchanged.
+func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) {
+	// 0. Per-call billed tasks never get a difference settlement.
+	if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling {
+		logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID))
+		return
+	}
+	// 1. Let the adaptor decide the final quota first.
+	if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 {
+		RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整")
+		return
+	}
+	// 2. Fall back to recalculation by token usage.
+	if taskResult.TotalTokens > 0 {
+		RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens)
+		return
+	}
+	// 3. No adjustment; the pre-deducted quota stands.
+}
diff --git a/service/violation_fee.go b/service/violation_fee.go
index 400c10dd5..455088561 100644
--- a/service/violation_fee.go
+++ b/service/violation_fee.go
@@ -18,8 +18,9 @@ import (
)
// Markers used to classify upstream safety/violation errors.
const (
	// ViolationFeeCodePrefix prefixes error codes that carry a violation fee.
	ViolationFeeCodePrefix = "violation_fee."
	// CSAMViolationMarker matches any "Failed check: SAFETY_CHECK_TYPE*"
	// message (broadened from the CSAM-only suffix to all safety check types).
	CSAMViolationMarker = "Failed check: SAFETY_CHECK_TYPE"
	// ContentViolatesUsageMarker matches generic usage-guideline violations.
	ContentViolatesUsageMarker = "Content violates usage guidelines"
)
func IsViolationFeeCode(code types.ErrorCode) bool {
@@ -30,11 +31,11 @@ func HasCSAMViolationMarker(err *types.NewAPIError) bool {
if err == nil {
return false
}
- if strings.Contains(err.Error(), CSAMViolationMarker) {
+ if strings.Contains(err.Error(), CSAMViolationMarker) || strings.Contains(err.Error(), ContentViolatesUsageMarker) {
return true
}
msg := err.ToOpenAIError().Message
- return strings.Contains(msg, CSAMViolationMarker)
+ return strings.Contains(msg, CSAMViolationMarker) || strings.Contains(err.Error(), ContentViolatesUsageMarker)
}
func WrapAsViolationFeeGrokCSAM(err *types.NewAPIError) *types.NewAPIError {
diff --git a/setting/operation_setting/status_code_ranges.go b/setting/operation_setting/status_code_ranges.go
index 698c87c91..7e3bc847a 100644
--- a/setting/operation_setting/status_code_ranges.go
+++ b/setting/operation_setting/status_code_ranges.go
@@ -26,6 +26,11 @@ var AutomaticRetryStatusCodeRanges = []StatusCodeRange{
{Start: 525, End: 599},
}
// alwaysSkipRetryStatusCodes lists status codes that are never retried even
// when they fall inside AutomaticRetryStatusCodeRanges: 504 (Gateway Timeout)
// and 524 (Cloudflare "A Timeout Occurred"). NOTE(review): presumably skipped
// because the upstream request may still be in flight — confirm intent.
var alwaysSkipRetryStatusCodes = map[int]struct{}{
	504: {},
	524: {},
}
+
func AutomaticDisableStatusCodesToString() string {
return statusCodeRangesToString(AutomaticDisableStatusCodeRanges)
}
@@ -56,7 +61,15 @@ func AutomaticRetryStatusCodesFromString(s string) error {
return nil
}
+func IsAlwaysSkipRetryStatusCode(code int) bool {
+ _, exists := alwaysSkipRetryStatusCodes[code]
+ return exists
+}
+
func ShouldRetryByStatusCode(code int) bool {
+ if IsAlwaysSkipRetryStatusCode(code) {
+ return false
+ }
return shouldMatchStatusCodeRanges(AutomaticRetryStatusCodeRanges, code)
}
diff --git a/setting/operation_setting/status_code_ranges_test.go b/setting/operation_setting/status_code_ranges_test.go
index 5801824ac..4e292a368 100644
--- a/setting/operation_setting/status_code_ranges_test.go
+++ b/setting/operation_setting/status_code_ranges_test.go
@@ -62,6 +62,8 @@ func TestShouldRetryByStatusCode(t *testing.T) {
require.True(t, ShouldRetryByStatusCode(429))
require.True(t, ShouldRetryByStatusCode(500))
+ require.False(t, ShouldRetryByStatusCode(504))
+ require.False(t, ShouldRetryByStatusCode(524))
require.False(t, ShouldRetryByStatusCode(400))
require.False(t, ShouldRetryByStatusCode(200))
}
@@ -77,3 +79,9 @@ func TestShouldRetryByStatusCode_DefaultMatchesLegacyBehavior(t *testing.T) {
require.False(t, ShouldRetryByStatusCode(524))
require.True(t, ShouldRetryByStatusCode(599))
}
+
+func TestIsAlwaysSkipRetryStatusCode(t *testing.T) {
+ require.True(t, IsAlwaysSkipRetryStatusCode(504))
+ require.True(t, IsAlwaysSkipRetryStatusCode(524))
+ require.False(t, IsAlwaysSkipRetryStatusCode(500))
+}
diff --git a/types/price_data.go b/types/price_data.go
index 3f7121b8c..93bc6ae8d 100644
--- a/types/price_data.go
+++ b/types/price_data.go
@@ -22,7 +22,8 @@ type PriceData struct {
AudioCompletionRatio float64
OtherRatios map[string]float64
UsePrice bool
- QuotaToPreConsume int // 预消耗额度
+ Quota int // 按次计费的最终额度(MJ / Task)
+ QuotaToPreConsume int // 按量计费的预消耗额度
GroupRatioInfo GroupRatioInfo
}
@@ -36,12 +37,6 @@ func (p *PriceData) AddOtherRatio(key string, ratio float64) {
p.OtherRatios[key] = ratio
}
-type PerCallPriceData struct {
- ModelPrice float64
- Quota int
- GroupRatioInfo GroupRatioInfo
-}
-
// ToSetting renders the effective pricing configuration as a single
// human-readable string for logging/diagnostics. Note it reports
// QuotaToPreConsume but not the per-call Quota field.
func (p *PriceData) ToSetting() string {
	return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
}
diff --git a/web/src/components/auth/LoginForm.jsx b/web/src/components/auth/LoginForm.jsx
index 636317e44..7e8c0ce01 100644
--- a/web/src/components/auth/LoginForm.jsx
+++ b/web/src/components/auth/LoginForm.jsx
@@ -29,6 +29,7 @@ import {
showSuccess,
updateAPI,
getSystemName,
+ getOAuthProviderIcon,
setUserData,
onGitHubOAuthClicked,
onDiscordOAuthClicked,
@@ -130,6 +131,17 @@ const LoginForm = () => {
return {};
}
}, [statusState?.status]);
+ const hasCustomOAuthProviders =
+ (status.custom_oauth_providers || []).length > 0;
+ const hasOAuthLoginOptions = Boolean(
+ status.github_oauth ||
+ status.discord_oauth ||
+ status.oidc_enabled ||
+ status.wechat_login ||
+ status.linuxdo_oauth ||
+ status.telegram_oauth ||
+ hasCustomOAuthProviders,
+ );
useEffect(() => {
if (status?.turnstile_check) {
@@ -598,7 +610,7 @@ const LoginForm = () => {
theme='outline'
className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
type='tertiary'
- icon={ {
+ e.target.style.display = 'none';
+ }}
+ />
+ )}
+
+ );
+ }
+
+ if (isSimpleEmoji(raw)) {
+ return (
+
+ {raw}
+
+ );
+ }
+
+ const key = normalizeOAuthIconKey(raw);
+ const IconComp = oauthProviderIconMap[key];
+ if (IconComp) {
+ return