diff --git a/README.fr.md b/README.fr.md index 77fd0cd1c..6b4d0ceba 100644 --- a/README.fr.md +++ b/README.fr.md @@ -30,8 +30,8 @@

- - Calcium-Ion%2Fnew-api | Trendshift + + QuantumNous%2Fnew-api | Trendshift
diff --git a/README.ja.md b/README.ja.md index 2cb00affb..2b35bdfe9 100644 --- a/README.ja.md +++ b/README.ja.md @@ -30,8 +30,8 @@

- - Calcium-Ion%2Fnew-api | Trendshift + + QuantumNous%2Fnew-api | Trendshift
diff --git a/README.md b/README.md index 5f64a0d0b..8f23d5dcd 100644 --- a/README.md +++ b/README.md @@ -30,8 +30,8 @@

- - Calcium-Ion%2Fnew-api | Trendshift + + QuantumNous%2Fnew-api | Trendshift
diff --git a/README.zh_CN.md b/README.zh_CN.md index 55265d9a8..fd3204950 100644 --- a/README.zh_CN.md +++ b/README.zh_CN.md @@ -30,8 +30,8 @@

- - Calcium-Ion%2Fnew-api | Trendshift + + QuantumNous%2Fnew-api | Trendshift
diff --git a/README.zh_TW.md b/README.zh_TW.md index 2fa93157e..9264bc722 100644 --- a/README.zh_TW.md +++ b/README.zh_TW.md @@ -30,8 +30,8 @@

- - Calcium-Ion%2Fnew-api | Trendshift + + QuantumNous%2Fnew-api | Trendshift
diff --git a/common/gin.go b/common/gin.go index 48971c130..5cad6e5c9 100644 --- a/common/gin.go +++ b/common/gin.go @@ -243,7 +243,15 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) { return nil, err } - contentType := c.Request.Header.Get("Content-Type") + // Use the original Content-Type saved on first call to avoid boundary + // mismatch when callers overwrite c.Request.Header after multipart rebuild. + var contentType string + if saved, ok := c.Get("_original_multipart_ct"); ok { + contentType = saved.(string) + } else { + contentType = c.Request.Header.Get("Content-Type") + c.Set("_original_multipart_ct", contentType) + } boundary, err := parseBoundary(contentType) if err != nil { return nil, err @@ -295,7 +303,13 @@ func parseFormData(data []byte, v any) error { } func parseMultipartFormData(c *gin.Context, data []byte, v any) error { - contentType := c.Request.Header.Get("Content-Type") + var contentType string + if saved, ok := c.Get("_original_multipart_ct"); ok { + contentType = saved.(string) + } else { + contentType = c.Request.Header.Get("Content-Type") + c.Set("_original_multipart_ct", contentType) + } boundary, err := parseBoundary(contentType) if err != nil { if errors.Is(err, errBoundaryNotFound) { diff --git a/common/init.go b/common/init.go index 6d2c3572b..e4ddbb453 100644 --- a/common/init.go +++ b/common/init.go @@ -145,6 +145,8 @@ func initConstantEnv() { constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false) // 任务轮询时查询的最大数量 constant.TaskQueryLimit = GetEnvOrDefault("TASK_QUERY_LIMIT", 1000) + // 异步任务超时时间(分钟),超过此时间未完成的任务将被标记为失败并退款。0 表示禁用。 + constant.TaskTimeoutMinutes = GetEnvOrDefault("TASK_TIMEOUT_MINUTES", 1440) soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "") if soraPatchStr != "" { diff --git a/constant/env.go b/constant/env.go index 957f68669..d5aff1b0b 100644 --- a/constant/env.go +++ b/constant/env.go @@ -16,6 +16,7 @@ var NotificationLimitDurationMinute int var GenerateDefaultToken bool var ErrorLogEnabled bool var TaskQueryLimit int +var TaskTimeoutMinutes int // temporary variable for sora patch, will be removed in future var TaskPricePatches []string diff --git a/controller/codex_oauth.go b/controller/codex_oauth.go index 3071413c6..de9743ab7 100644 --- a/controller/codex_oauth.go +++ b/controller/codex_oauth.go @@ -145,6 +145,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) { return } + channelProxy := "" if channelID > 0 { ch, err := model.GetChannelById(channelID, false) if err != nil { @@ -159,6 +160,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) { c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"}) return } + channelProxy = ch.GetSetting().Proxy } session := sessions.Default(c) @@ -176,7 +178,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) { ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second) defer cancel() - tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier) + tokenRes, err := service.ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, channelProxy) if err != nil { common.SysError("failed to exchange codex authorization code: " + err.Error()) c.JSON(http.StatusOK, gin.H{"success": false, "message": "授权码交换失败,请重试"}) diff --git a/controller/codex_usage.go b/controller/codex_usage.go index 62b7a754f..52fdbdf6f 100644 --- a/controller/codex_usage.go +++ b/controller/codex_usage.go @@ -2,7 +2,6 @@ package controller import ( "context" - "encoding/json" "fmt" "net/http" "strconv" @@ -80,7 +79,7 @@ func GetCodexChannelUsage(c *gin.Context) { refreshCtx, refreshCancel := context.WithTimeout(c.Request.Context(), 10*time.Second) defer refreshCancel() - res, refreshErr := service.RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken) + res, refreshErr := service.RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy) if refreshErr == nil { oauthKey.AccessToken = res.AccessToken oauthKey.RefreshToken = res.RefreshToken @@ -109,7 +108,7 @@ func GetCodexChannelUsage(c *gin.Context) { } var payload any - if json.Unmarshal(body, &payload) != nil { + if common.Unmarshal(body, &payload) != nil { payload = string(body) } diff --git a/controller/custom_oauth.go b/controller/custom_oauth.go index e2245f880..c21ec7910 100644 --- a/controller/custom_oauth.go +++ b/controller/custom_oauth.go @@ -1,8 +1,13 @@ package controller import ( + "context" + "io" "net/http" + "net/url" "strconv" + "strings" + "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" @@ -16,6 +21,7 @@ type CustomOAuthProviderResponse struct { Id int `json:"id"` Name string `json:"name"` Slug string `json:"slug"` + Icon string `json:"icon"` Enabled bool `json:"enabled"` ClientId string `json:"client_id"` AuthorizationEndpoint string `json:"authorization_endpoint"` @@ -28,6 +34,16 @@ type CustomOAuthProviderResponse struct { EmailField string `json:"email_field"` WellKnown string `json:"well_known"` AuthStyle int `json:"auth_style"` + AccessPolicy string `json:"access_policy"` + AccessDeniedMessage string `json:"access_denied_message"` +} + +type UserOAuthBindingResponse struct { + ProviderId int `json:"provider_id"` + ProviderName string `json:"provider_name"` + ProviderSlug string `json:"provider_slug"` + ProviderIcon string `json:"provider_icon"` + ProviderUserId string `json:"provider_user_id"` } func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse { @@ -35,6 +51,7 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro Id: p.Id, Name: p.Name, Slug: p.Slug, + Icon: p.Icon, Enabled: p.Enabled, ClientId: p.ClientId, AuthorizationEndpoint: p.AuthorizationEndpoint, @@ -47,6 +64,8 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro EmailField: p.EmailField, WellKnown: p.WellKnown, AuthStyle: p.AuthStyle, + AccessPolicy: p.AccessPolicy, + AccessDeniedMessage: p.AccessDeniedMessage, } } @@ -96,6 +115,7 @@ func GetCustomOAuthProvider(c *gin.Context) { type CreateCustomOAuthProviderRequest struct { Name string `json:"name" binding:"required"` Slug string `json:"slug" binding:"required"` + Icon string `json:"icon"` Enabled bool `json:"enabled"` ClientId string `json:"client_id" binding:"required"` ClientSecret string `json:"client_secret" binding:"required"` @@ -109,6 +129,85 @@ type CreateCustomOAuthProviderRequest struct { EmailField string `json:"email_field"` WellKnown string `json:"well_known"` AuthStyle int `json:"auth_style"` + AccessPolicy string `json:"access_policy"` + AccessDeniedMessage string `json:"access_denied_message"` +} + +type FetchCustomOAuthDiscoveryRequest struct { + WellKnownURL string `json:"well_known_url"` + IssuerURL string `json:"issuer_url"` +} + +// FetchCustomOAuthDiscovery fetches OIDC discovery document via backend (root-only route) +func FetchCustomOAuthDiscovery(c *gin.Context) { + var req FetchCustomOAuthDiscoveryRequest + if err := c.ShouldBindJSON(&req); err != nil { + common.ApiErrorMsg(c, "无效的请求参数: "+err.Error()) + return + } + + wellKnownURL := strings.TrimSpace(req.WellKnownURL) + issuerURL := strings.TrimSpace(req.IssuerURL) + + if wellKnownURL == "" && issuerURL == "" { + common.ApiErrorMsg(c, "请先填写 Discovery URL 或 Issuer URL") + return + } + + targetURL := wellKnownURL + if targetURL == "" { + targetURL = strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration" + } + targetURL = strings.TrimSpace(targetURL) + + parsedURL, err := url.Parse(targetURL) + if err != nil || parsedURL.Host == "" || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") { + common.ApiErrorMsg(c, "Discovery URL 无效,仅支持 http/https") + return + } + + ctx, cancel := context.WithTimeout(c.Request.Context(), 20*time.Second) + defer cancel() + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) + if err != nil { + common.ApiErrorMsg(c, "创建 Discovery 请求失败: "+err.Error()) + return + } + httpReq.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: 20 * time.Second} + resp, err := client.Do(httpReq) + if err != nil { + common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+err.Error()) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + message := strings.TrimSpace(string(body)) + if message == "" { + message = resp.Status + } + common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+message) + return + } + + var discovery map[string]any + if err = common.DecodeJson(resp.Body, &discovery); err != nil { + common.ApiErrorMsg(c, "解析 Discovery 配置失败: "+err.Error()) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": gin.H{ + "well_known_url": targetURL, + "discovery": discovery, + }, + }) } // CreateCustomOAuthProvider creates a new custom OAuth provider @@ -134,6 +233,7 @@ func CreateCustomOAuthProvider(c *gin.Context) { provider := &model.CustomOAuthProvider{ Name: req.Name, Slug: req.Slug, + Icon: req.Icon, Enabled: req.Enabled, ClientId: req.ClientId, ClientSecret: req.ClientSecret, @@ -147,6 +247,8 @@ func CreateCustomOAuthProvider(c *gin.Context) { EmailField: req.EmailField, WellKnown: req.WellKnown, AuthStyle: req.AuthStyle, + AccessPolicy: req.AccessPolicy, + AccessDeniedMessage: req.AccessDeniedMessage, } if err := model.CreateCustomOAuthProvider(provider); err != nil { @@ -168,9 +270,10 @@ func CreateCustomOAuthProvider(c *gin.Context) { type UpdateCustomOAuthProviderRequest struct { Name string `json:"name"` Slug string `json:"slug"` - Enabled *bool `json:"enabled"` // Optional: if nil, keep existing + Icon *string `json:"icon"` // Optional: if nil, keep existing + Enabled *bool `json:"enabled"` // Optional: if nil, keep existing ClientId string `json:"client_id"` - ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing + ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing AuthorizationEndpoint string `json:"authorization_endpoint"` TokenEndpoint string `json:"token_endpoint"` UserInfoEndpoint string `json:"user_info_endpoint"` @@ -181,6 +284,8 @@ type UpdateCustomOAuthProviderRequest struct { EmailField string `json:"email_field"` WellKnown *string `json:"well_known"` // Optional: if nil, keep existing AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing + AccessPolicy *string `json:"access_policy"` // Optional: if nil, keep existing + AccessDeniedMessage *string `json:"access_denied_message"` // Optional: if nil, keep existing } // UpdateCustomOAuthProvider updates an existing custom OAuth provider @@ -227,6 +332,9 @@ func UpdateCustomOAuthProvider(c *gin.Context) { if req.Slug != "" { provider.Slug = req.Slug } + if req.Icon != nil { + provider.Icon = *req.Icon + } if req.Enabled != nil { provider.Enabled = *req.Enabled } @@ -266,6 +374,12 @@ func UpdateCustomOAuthProvider(c *gin.Context) { if req.AuthStyle != nil { provider.AuthStyle = *req.AuthStyle } + if req.AccessPolicy != nil { + provider.AccessPolicy = *req.AccessPolicy + } + if req.AccessDeniedMessage != nil { + provider.AccessDeniedMessage = *req.AccessDeniedMessage + } if err := model.UpdateCustomOAuthProvider(provider); err != nil { common.ApiError(c, err) @@ -327,6 +441,30 @@ func DeleteCustomOAuthProvider(c *gin.Context) { }) } +func buildUserOAuthBindingsResponse(userId int) ([]UserOAuthBindingResponse, error) { + bindings, err := model.GetUserOAuthBindingsByUserId(userId) + if err != nil { + return nil, err + } + + response := make([]UserOAuthBindingResponse, 0, len(bindings)) + for _, binding := range bindings { + provider, err := model.GetCustomOAuthProviderById(binding.ProviderId) + if err != nil { + continue + } + response = append(response, UserOAuthBindingResponse{ + ProviderId: binding.ProviderId, + ProviderName: provider.Name, + ProviderSlug: provider.Slug, + ProviderIcon: provider.Icon, + ProviderUserId: binding.ProviderUserId, + }) + } + + return response, nil +} + // GetUserOAuthBindings returns all OAuth bindings for the current user func GetUserOAuthBindings(c *gin.Context) { userId := c.GetInt("id") @@ -335,32 +473,43 @@ func GetUserOAuthBindings(c *gin.Context) { return } - bindings, err := model.GetUserOAuthBindingsByUserId(userId) + response, err := buildUserOAuthBindingsResponse(userId) if err != nil { common.ApiError(c, err) return } - // Build response with provider info - type BindingResponse struct { - ProviderId int `json:"provider_id"` - ProviderName string `json:"provider_name"` - ProviderSlug string `json:"provider_slug"` - ProviderUserId string `json:"provider_user_id"` + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": response, + }) +} + +func GetUserOAuthBindingsByAdmin(c *gin.Context) { + userIdStr := c.Param("id") + userId, err := strconv.Atoi(userIdStr) + if err != nil { + common.ApiErrorMsg(c, "invalid user id") + return } - response := make([]BindingResponse, 0) - for _, binding := range bindings { - provider, err := model.GetCustomOAuthProviderById(binding.ProviderId) - if err != nil { - continue // Skip if provider not found - } - response = append(response, BindingResponse{ - ProviderId: binding.ProviderId, - ProviderName: provider.Name, - ProviderSlug: provider.Slug, - ProviderUserId: binding.ProviderUserId, - }) + targetUser, err := model.GetUserById(userId, false) + if err != nil { + common.ApiError(c, err) + return + } + + myRole := c.GetInt("role") + if myRole <= targetUser.Role && myRole != common.RoleRootUser { + common.ApiErrorMsg(c, "no permission") + return + } + + response, err := buildUserOAuthBindingsResponse(userId) + if err != nil { + common.ApiError(c, err) + return } c.JSON(http.StatusOK, gin.H{ @@ -395,3 +544,41 @@ func UnbindCustomOAuth(c *gin.Context) { "message": "解绑成功", }) } + +func UnbindCustomOAuthByAdmin(c *gin.Context) { + userIdStr := c.Param("id") + userId, err := strconv.Atoi(userIdStr) + if err != nil { + common.ApiErrorMsg(c, "invalid user id") + return + } + + targetUser, err := model.GetUserById(userId, false) + if err != nil { + common.ApiError(c, err) + return + } + + myRole := c.GetInt("role") + if myRole <= targetUser.Role && myRole != common.RoleRootUser { + common.ApiErrorMsg(c, "no permission") + return + } + + providerIdStr := c.Param("provider_id") + providerId, err := strconv.Atoi(providerIdStr) + if err != nil { + common.ApiErrorMsg(c, "invalid provider id") + return + } + + if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil { + common.ApiError(c, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "success", + }) +} diff --git a/controller/midjourney.go b/controller/midjourney.go index c480c12bb..69aa5ccd4 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -105,13 +105,13 @@ func UpdateMidjourneyTaskBulk() { } responseBody, err := io.ReadAll(resp.Body) if err != nil { - logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Mjp Task parse body error: %v", err)) continue } var responseItems []dto.MidjourneyDto err = json.Unmarshal(responseBody, &responseItems) if err != nil { - logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + logger.LogError(ctx, fmt.Sprintf("Get Mjp Task parse body error2: %v, body: %s", err, string(responseBody))) continue } resp.Body.Close() @@ -130,6 +130,7 @@ func UpdateMidjourneyTaskBulk() { if !checkMjTaskNeedUpdate(task, responseItem) { continue } + preStatus := task.Status task.Code = 1 task.Progress = responseItem.Progress task.PromptEn = responseItem.PromptEn @@ -172,18 +173,26 @@ func UpdateMidjourneyTaskBulk() { shouldReturnQuota = true } } - err = task.Update() + won, err := task.UpdateWithStatus(preStatus) if err != nil { logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) - } else { - if shouldReturnQuota { - err = model.IncreaseUserQuota(task.UserId, task.Quota, false) - if err != nil { - logger.LogError(ctx, "fail to increase user quota: "+err.Error()) - } - logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } else if won && shouldReturnQuota { + err = model.IncreaseUserQuota(task.UserId, task.Quota, false) + if err != nil { + logger.LogError(ctx, "fail to increase user quota: "+err.Error()) } + model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{ + UserId: task.UserId, + LogType: model.LogTypeRefund, + Content: "", + ChannelId: task.ChannelId, + ModelName: service.CovertMjpActionToModelName(task.Action), + Quota: task.Quota, + Other: map[string]interface{}{ + "task_id": task.MjId, + "reason": "构图失败", + }, + }) } } } diff --git a/controller/misc.go b/controller/misc.go index a16e2d554..b24a74adf 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -134,8 +134,10 @@ func GetStatus(c *gin.Context) { customProviders := oauth.GetEnabledCustomProviders() if len(customProviders) > 0 { type CustomOAuthInfo struct { + Id int `json:"id"` Name string `json:"name"` Slug string `json:"slug"` + Icon string `json:"icon"` ClientId string `json:"client_id"` AuthorizationEndpoint string `json:"authorization_endpoint"` Scopes string `json:"scopes"` @@ -144,8 +146,10 @@ func GetStatus(c *gin.Context) { for _, p := range customProviders { config := p.GetConfig() providersInfo = append(providersInfo, CustomOAuthInfo{ + Id: config.Id, Name: config.Name, Slug: config.Slug, + Icon: config.Icon, ClientId: config.ClientId, AuthorizationEndpoint: config.AuthorizationEndpoint, Scopes: config.Scopes, diff --git a/controller/oauth.go b/controller/oauth.go index 65e18f9da..818a28f84 100644 --- a/controller/oauth.go +++ b/controller/oauth.go @@ -237,6 +237,16 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o // Set up new user user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1) + + if oauthUser.Username != "" { + if exists, err := model.CheckUserExistOrDeleted(oauthUser.Username, ""); err == nil && !exists { + // 防止索引退化 + if len(oauthUser.Username) <= model.UserNameMaxLength { + user.Username = oauthUser.Username + } + } + } + if oauthUser.DisplayName != "" { user.DisplayName = oauthUser.DisplayName } else if oauthUser.Username != "" { @@ -295,12 +305,12 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o // Set the provider user ID on the user model and update provider.SetProviderUserID(user, oauthUser.ProviderUserID) if err := tx.Model(user).Updates(map[string]interface{}{ - "github_id": user.GitHubId, - "discord_id": user.DiscordId, - "oidc_id": user.OidcId, - "linux_do_id": user.LinuxDOId, - "wechat_id": user.WeChatId, - "telegram_id": user.TelegramId, + "github_id": user.GitHubId, + "discord_id": user.DiscordId, + "oidc_id": user.OidcId, + "linux_do_id": user.LinuxDOId, + "wechat_id": user.WeChatId, + "telegram_id": user.TelegramId, }).Error; err != nil { return err } @@ -340,6 +350,8 @@ func handleOAuthError(c *gin.Context, err error) { } else { common.ApiErrorI18n(c, e.MsgKey) } + case *oauth.AccessDeniedError: + common.ApiErrorMsg(c, e.Message) case *oauth.TrustLevelError: common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow) default: diff --git a/controller/relay.go b/controller/relay.go index e3e92bc51..1788b25b7 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -455,72 +455,147 @@ func RelayNotFound(c *gin.Context) { }) } -func RelayTask(c *gin.Context) { - retryTimes := common.RetryTimes - channelId := c.GetInt("channel_id") - c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) +func RelayTaskFetch(c *gin.Context) { relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) if err != nil { + c.JSON(http.StatusInternalServerError, &dto.TaskError{ + Code: "gen_relay_info_failed", + Message: err.Error(), + StatusCode: http.StatusInternalServerError, + }) return } - taskErr := taskRelayHandler(c, relayInfo) - if taskErr == nil { - retryTimes = 0 + if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil { + respondTaskError(c, taskErr) } +} + +func RelayTask(c *gin.Context) { + relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) + if err != nil { + c.JSON(http.StatusInternalServerError, &dto.TaskError{ + Code: "gen_relay_info_failed", + Message: err.Error(), + StatusCode: http.StatusInternalServerError, + }) + return + } + + if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil { + respondTaskError(c, taskErr) + return + } + + var result *relay.TaskSubmitResult + var taskErr *dto.TaskError + defer func() { + if taskErr != nil && relayInfo.Billing != nil { + relayInfo.Billing.Refund(c) + } + }() + retryParam := &service.RetryParam{ Ctx: c, TokenGroup: relayInfo.TokenGroup, ModelName: relayInfo.OriginModelName, Retry: common.GetPointer(0), } - for ; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && retryParam.GetRetry() < retryTimes; retryParam.IncreaseRetry() { - channel, newAPIError := getChannel(c, relayInfo, retryParam) - if newAPIError != nil { - logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) - taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError) - break - } - channelId = channel.Id - useChannel := c.GetStringSlice("use_channel") - useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) - c.Set("use_channel", useChannel) - logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry())) - //middleware.SetupContextForSelectedChannel(c, channel, originalModel) - bodyStorage, err := common.GetBodyStorage(c) - if err != nil { - if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) { - taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge) + for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() { + var channel *model.Channel + + if lockedCh, ok := relayInfo.LockedChannel.(*model.Channel); ok && lockedCh != nil { + channel = lockedCh + if retryParam.GetRetry() > 0 { + if setupErr := middleware.SetupContextForSelectedChannel(c, channel, relayInfo.OriginModelName); setupErr != nil { + taskErr = service.TaskErrorWrapperLocal(setupErr.Err, "setup_locked_channel_failed", http.StatusInternalServerError) + break + } + } + } else { + var channelErr *types.NewAPIError + channel, channelErr = getChannel(c, relayInfo, retryParam) + if channelErr != nil { + logger.LogError(c, channelErr.Error()) + taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError) + break + } + } + + addUsedChannel(c, channel.Id) + bodyStorage, bodyErr := common.GetBodyStorage(c) + if bodyErr != nil { + if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) { + taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge) } else { - taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusBadRequest) + taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusBadRequest) } break } c.Request.Body = io.NopCloser(bodyStorage) - taskErr = taskRelayHandler(c, relayInfo) + + result, taskErr = relay.RelayTaskSubmit(c, relayInfo) + if taskErr == nil { + break + } + + if !taskErr.LocalError { + processChannelError(c, + *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, + common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), + types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode)) + } + + if !shouldRetryTaskRelay(c, channel.Id, taskErr, common.RetryTimes-retryParam.GetRetry()) { + break + } } + useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) logger.LogInfo(c, retryLogStr) } - if taskErr != nil { - if taskErr.StatusCode == http.StatusTooManyRequests { - taskErr.Message = "当前分组上游负载已饱和,请稍后再试" + + // ── 成功:结算 + 日志 + 插入任务 ── + if taskErr == nil { + if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil { + common.SysError("settle task billing error: " + settleErr.Error()) } - c.JSON(taskErr.StatusCode, taskErr) + service.LogTaskConsumption(c, relayInfo) + + task := model.InitTask(result.Platform, relayInfo) + task.PrivateData.UpstreamTaskID = result.UpstreamTaskID + task.PrivateData.BillingSource = relayInfo.BillingSource + task.PrivateData.SubscriptionId = relayInfo.SubscriptionId + task.PrivateData.TokenId = relayInfo.TokenId + task.PrivateData.BillingContext = &model.TaskBillingContext{ + ModelPrice: relayInfo.PriceData.ModelPrice, + GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio, + ModelRatio: relayInfo.PriceData.ModelRatio, + OtherRatios: relayInfo.PriceData.OtherRatios, + OriginModelName: relayInfo.OriginModelName, + PerCallBilling: common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName), + } + task.Quota = result.Quota + task.Data = result.TaskData + task.Action = relayInfo.Action + if insertErr := task.Insert(); insertErr != nil { + common.SysError("insert task error: " + insertErr.Error()) + } + } + + if taskErr != nil { + respondTaskError(c, taskErr) } } -func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError { - var err *dto.TaskError - switch relayInfo.RelayMode { - case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID: - err = relay.RelayTaskFetch(c, relayInfo.RelayMode) - default: - err = relay.RelayTaskSubmit(c, relayInfo) +// respondTaskError 统一输出 Task 错误响应(含 429 限流提示改写) +func respondTaskError(c *gin.Context, taskErr *dto.TaskError) { + if taskErr.StatusCode == http.StatusTooManyRequests { + taskErr.Message = "当前分组上游负载已饱和,请稍后再试" } - return err + c.JSON(taskErr.StatusCode, taskErr) } func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool { @@ -544,7 +619,7 @@ func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, } if taskErr.StatusCode/100 == 5 { // 超时不重试 - if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 { + if operation_setting.IsAlwaysSkipRetryStatusCode(taskErr.StatusCode) { return false } return true diff --git a/controller/task.go b/controller/task.go index 244f9161c..eac7db153 100644 --- a/controller/task.go +++ b/controller/task.go @@ -1,231 +1,22 @@ package controller import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "sort" "strconv" - "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" - "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay" + "github.com/QuantumNous/new-api/service" + "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" - "github.com/samber/lo" ) +// UpdateTaskBulk 薄入口,实际轮询逻辑在 service 层 func UpdateTaskBulk() { - //revocer - //imageModel := "midjourney" - for { - time.Sleep(time.Duration(15) * time.Second) - common.SysLog("任务进度轮询开始") - ctx := context.TODO() - allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit) - platformTask := make(map[constant.TaskPlatform][]*model.Task) - for _, t := range allTasks { - platformTask[t.Platform] = append(platformTask[t.Platform], t) - } - for platform, tasks := range platformTask { - if len(tasks) == 0 { - continue - } - taskChannelM := make(map[int][]string) - taskM := make(map[string]*model.Task) - nullTaskIds := make([]int64, 0) - for _, task := range tasks { - if task.TaskID == "" { - // 统计失败的未完成任务 - nullTaskIds = append(nullTaskIds, task.ID) - continue - } - taskM[task.TaskID] = task - taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID) - } - if len(nullTaskIds) > 0 { - err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{ - "status": "FAILURE", - "progress": "100%", - }) - if err != nil { - logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) - } else { - logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) - } - } - if len(taskChannelM) == 0 { - continue - } - - UpdateTaskByPlatform(platform, taskChannelM, taskM) - } - common.SysLog("任务进度轮询完成") - } -} - -func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) { - switch platform { - case constant.TaskPlatformMidjourney: - //_ = UpdateMidjourneyTaskAll(context.Background(), tasks) - case constant.TaskPlatformSuno: - _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) - default: - if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil { - common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err)) - } - } -} - -func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { - for channelId, taskIds := range taskChannelM { - err := updateSunoTaskAll(ctx, channelId, taskIds, taskM) - if err != nil { - logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error())) - } - } - return nil -} - -func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { - logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) - if len(taskIds) == 0 { - return nil - } - channel, err := model.CacheGetChannel(channelId) - if err != nil { - common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) - err = model.TaskBulkUpdate(taskIds, map[string]any{ - "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), - "status": "FAILURE", - "progress": "100%", - }) - if err != nil { - common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) - } - return err - } - adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno) - if adaptor == nil { - return errors.New("adaptor not found") - } - proxy := channel.GetSetting().Proxy - resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{ - "ids": taskIds, - }, proxy) - if err != nil { - common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) - return err - } - if resp.StatusCode != http.StatusOK { - logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) - return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) - } - defer resp.Body.Close() - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err)) - return err - } - var responseItems dto.TaskResponse[[]dto.SunoDataResponse] - err = json.Unmarshal(responseBody, &responseItems) - if err != nil { - logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) - return err - } - if !responseItems.IsSuccess() { - common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody))) - return err - } - - for _, responseItem := range responseItems.Data { - task := taskM[responseItem.TaskID] - if !checkTaskNeedUpdate(task, responseItem) { - continue - } - - task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status) - task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason) - task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime) - task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) - task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) - if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { - logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) - task.Progress = "100%" - //err = model.CacheUpdateUserQuota(task.UserId) ? - if err != nil { - logger.LogError(ctx, "error update user quota cache: "+err.Error()) - } else { - quota := task.Quota - if quota != 0 { - err = model.IncreaseUserQuota(task.UserId, quota, false) - if err != nil { - logger.LogError(ctx, "fail to increase user quota: "+err.Error()) - } - logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } - } - if responseItem.Status == model.TaskStatusSuccess { - task.Progress = "100%" - } - task.Data = responseItem.Data - - err = task.Update() - if err != nil { - common.SysLog("UpdateMidjourneyTask task error: " + err.Error()) - } - } - return nil -} - -func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool { - - if oldTask.SubmitTime != newTask.SubmitTime { - return true - } - if oldTask.StartTime != newTask.StartTime { - return true - } - if oldTask.FinishTime != newTask.FinishTime { - return true - } - if string(oldTask.Status) != newTask.Status { - return true - } - if oldTask.FailReason != newTask.FailReason { - return true - } - if oldTask.FinishTime != newTask.FinishTime { - return true - } - - if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" { - return true - } - - oldData, _ := json.Marshal(oldTask.Data) - newData, _ := json.Marshal(newTask.Data) - - sort.Slice(oldData, func(i, j int) bool { - return oldData[i] < oldData[j] - }) - sort.Slice(newData, func(i, j int) bool { - return newData[i] < newData[j] - }) - - if string(oldData) != string(newData) { - return true - } - return false + service.TaskPollingLoop() } func GetAllTask(c *gin.Context) { @@ -247,7 +38,7 @@ func GetAllTask(c *gin.Context) { items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams) total := model.TaskCountAllTasks(queryParams) pageInfo.SetTotal(int(total)) - pageInfo.SetItems(items) + pageInfo.SetItems(tasksToDto(items, true)) common.ApiSuccess(c, pageInfo) } @@ -271,6 +62,33 @@ func GetUserTask(c *gin.Context) { items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams) total := model.TaskCountAllUserTask(userId, queryParams) pageInfo.SetTotal(int(total)) - pageInfo.SetItems(items) + pageInfo.SetItems(tasksToDto(items, false)) common.ApiSuccess(c, pageInfo) } + +func tasksToDto(tasks []*model.Task, fillUser bool) []*dto.TaskDto { + var userIdMap map[int]*model.UserBase + if fillUser { + userIdMap = make(map[int]*model.UserBase) + userIds := types.NewSet[int]() + for _, task := range tasks { + userIds.Add(task.UserId) + } + for _, userId := range userIds.Items() { + cacheUser, err := model.GetUserCache(userId) + if err == nil { + userIdMap[userId] = cacheUser + } + } + } + result := make([]*dto.TaskDto, len(tasks)) + for i, task := range tasks { + if fillUser { + if user, ok := userIdMap[task.UserId]; ok { + task.Username = user.Username + } + } + result[i] = relay.TaskModel2Dto(task) + } + return result +} diff --git a/controller/task_video.go b/controller/task_video.go deleted file mode 100644 index d7c19e620..000000000 --- a/controller/task_video.go +++ /dev/null @@ -1,313 +0,0 @@ -package controller - -import ( - "context" - "encoding/json" - "fmt" - "io" - "time" - - "github.com/QuantumNous/new-api/common" - "github.com/QuantumNous/new-api/constant" - "github.com/QuantumNous/new-api/dto" - "github.com/QuantumNous/new-api/logger" - "github.com/QuantumNous/new-api/model" - "github.com/QuantumNous/new-api/relay" - "github.com/QuantumNous/new-api/relay/channel" - relaycommon "github.com/QuantumNous/new-api/relay/common" - "github.com/QuantumNous/new-api/setting/ratio_setting" -) - -func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { - for channelId, taskIds := range taskChannelM { - if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil { - logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) - } - } - return nil -} - -func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { - logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) - if len(taskIds) == 0 { - return nil - } - cacheGetChannel, err := model.CacheGetChannel(channelId) - if err != nil { - errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{ - "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), - "status": "FAILURE", - "progress": "100%", - }) - if errUpdate != nil { - common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) - } - return fmt.Errorf("CacheGetChannel failed: %w", err) - } - adaptor := relay.GetTaskAdaptor(platform) - if adaptor == nil { - return fmt.Errorf("video adaptor not found") - } - info := &relaycommon.RelayInfo{} - info.ChannelMeta = &relaycommon.ChannelMeta{ - ChannelBaseUrl: cacheGetChannel.GetBaseURL(), - } - info.ApiKey = cacheGetChannel.Key - adaptor.Init(info) - for _, taskId := range taskIds { - if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { - logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) - } - } - return nil -} - -func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error { - baseURL := constant.ChannelBaseURLs[channel.Type] - if channel.GetBaseURL() != "" { - baseURL = channel.GetBaseURL() - } - proxy := channel.GetSetting().Proxy - - task := taskM[taskId] - if task == nil { - logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) - return fmt.Errorf("task %s not found", taskId) - } - key := channel.Key - - privateData := task.PrivateData - if privateData.Key != "" { - key = privateData.Key - } - resp, err := adaptor.FetchTask(baseURL, key, map[string]any{ - "task_id": taskId, - "action": task.Action, - }, proxy) - if err != nil { - return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) - } - //if resp.StatusCode != http.StatusOK { - //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode) - //} - defer resp.Body.Close() - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("readAll failed for task %s: %w", taskId, err) - } - - logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody))) - - taskResult := &relaycommon.TaskInfo{} - // try parse as New API response format - var responseItems dto.TaskResponse[model.Task] - if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() { - logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems)) - t := responseItems.Data - taskResult.TaskID = t.TaskID - taskResult.Status = string(t.Status) - taskResult.Url = t.FailReason - taskResult.Progress = t.Progress - taskResult.Reason = t.FailReason - task.Data = t.Data - } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { - return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) - } else { - task.Data = redactVideoResponseBody(responseBody) - } - - logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult)) - - now := time.Now().Unix() - if taskResult.Status == "" { - //return fmt.Errorf("task %s status is empty", taskId) - taskResult = relaycommon.FailTaskInfo("upstream returned empty status") - } - - // 记录原本的状态,防止重复退款 - shouldRefund := false - quota := task.Quota - preStatus := task.Status - - task.Status = model.TaskStatus(taskResult.Status) - switch taskResult.Status { - case model.TaskStatusSubmitted: - task.Progress = "10%" - case model.TaskStatusQueued: - task.Progress = "20%" - case model.TaskStatusInProgress: - task.Progress = "30%" - if task.StartTime == 0 { - task.StartTime = now - } - case model.TaskStatusSuccess: - task.Progress = "100%" - if task.FinishTime == 0 { - task.FinishTime = now - } - if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") { - task.FailReason = taskResult.Url - } - - // 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费 - if taskResult.TotalTokens > 0 { - // 获取模型名称 - var taskData map[string]interface{} - if err := json.Unmarshal(task.Data, &taskData); err == nil { - if modelName, ok := taskData["model"].(string); ok && modelName != "" { - // 获取模型价格和倍率 - modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName) - // 只有配置了倍率(非固定价格)时才按 token 重新计费 - if hasRatioSetting && modelRatio > 0 { - // 获取用户和组的倍率信息 - group := task.Group - if group == "" { - user, err := model.GetUserById(task.UserId, false) - if err == nil { - group = user.Group - } - } - if group != "" { - groupRatio := ratio_setting.GetGroupRatio(group) - userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group) - - var finalGroupRatio float64 - if hasUserGroupRatio { - finalGroupRatio = userGroupRatio - } else { - finalGroupRatio = groupRatio - } - - // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio - actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio) - - // 计算差额 - preConsumedQuota := task.Quota - quotaDelta := actualQuota - preConsumedQuota - - if quotaDelta > 0 { - // 需要补扣费 - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)", - task.TaskID, - logger.LogQuota(quotaDelta), - logger.LogQuota(actualQuota), - logger.LogQuota(preConsumedQuota), - taskResult.TotalTokens, - )) - if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil { - logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error())) - } else { - model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) - model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) - task.Quota = actualQuota // 更新任务记录的实际扣费额度 - - // 记录消费日志 - logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s", - modelRatio, finalGroupRatio, taskResult.TotalTokens, - logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } else if quotaDelta < 0 { - // 需要退还多扣的费用 - refundQuota := -quotaDelta - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)", - task.TaskID, - logger.LogQuota(refundQuota), - logger.LogQuota(actualQuota), - logger.LogQuota(preConsumedQuota), - taskResult.TotalTokens, - )) - if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil { - logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error())) - } else { - task.Quota = actualQuota // 更新任务记录的实际扣费额度 - - // 记录退款日志 - logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s", - modelRatio, finalGroupRatio, taskResult.TotalTokens, - logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } else { - // quotaDelta == 0, 预扣费刚好准确 - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)", - task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens)) - } - } - } - } - } - } - case model.TaskStatusFailure: - logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) - task.Status = model.TaskStatusFailure - task.Progress = "100%" - if task.FinishTime == 0 { - task.FinishTime = now - } - task.FailReason = taskResult.Reason - logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) - taskResult.Progress = "100%" - if quota != 0 { - if preStatus != model.TaskStatusFailure { - shouldRefund = true - } else { - logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID)) - } - } - default: - return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId) - } - if taskResult.Progress != "" { - task.Progress = taskResult.Progress - } - if err := task.Update(); err != nil { - common.SysLog("UpdateVideoTask task error: " + err.Error()) - shouldRefund = false - } - - if shouldRefund { - // 任务失败且之前状态不是失败才退还额度,防止重复退还 - if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { - logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error()) - } - logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - - return nil -} - -func redactVideoResponseBody(body []byte) []byte { - var m map[string]any - if err := json.Unmarshal(body, &m); err != nil { - return body - } - resp, _ := m["response"].(map[string]any) - if resp != nil { - delete(resp, "bytesBase64Encoded") - if v, ok := resp["video"].(string); ok { - resp["video"] = truncateBase64(v) - } - if vs, ok := resp["videos"].([]any); ok { - for i := range vs { - if vm, ok := vs[i].(map[string]any); ok { - delete(vm, "bytesBase64Encoded") - } - } - } - } - b, err := json.Marshal(m) - if err != nil { - return body - } - return b -} - -func truncateBase64(s string) string { - const maxKeep = 256 - if len(s) <= maxKeep { - return s - } - return s[:maxKeep] + "..." -} diff --git a/controller/user.go b/controller/user.go index db078071e..b58eab88f 100644 --- a/controller/user.go +++ b/controller/user.go @@ -582,6 +582,44 @@ func UpdateUser(c *gin.Context) { return } +func AdminClearUserBinding(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + common.ApiErrorI18n(c, i18n.MsgInvalidParams) + return + } + + bindingType := strings.ToLower(strings.TrimSpace(c.Param("binding_type"))) + if bindingType == "" { + common.ApiErrorI18n(c, i18n.MsgInvalidParams) + return + } + + user, err := model.GetUserById(id, false) + if err != nil { + common.ApiError(c, err) + return + } + + myRole := c.GetInt("role") + if myRole <= user.Role && myRole != common.RoleRootUser { + common.ApiErrorI18n(c, i18n.MsgUserNoPermissionSameLevel) + return + } + + if err := user.ClearBinding(bindingType); err != nil { + common.ApiError(c, err) + return + } + + model.RecordLog(user.Id, model.LogTypeManage, fmt.Sprintf("admin cleared %s binding for user %s", bindingType, user.Username)) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "success", + }) +} + func UpdateSelf(c *gin.Context) { var requestData map[string]interface{} err := json.NewDecoder(c.Request.Body).Decode(&requestData) diff --git a/controller/video_proxy.go b/controller/video_proxy.go index f102baae4..f1dd2bc92 100644 --- a/controller/video_proxy.go +++ b/controller/video_proxy.go @@ -16,59 +16,44 @@ import ( "github.com/gin-gonic/gin" ) +// videoProxyError returns a standardized OpenAI-style error response. +func videoProxyError(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "message": message, + "type": errType, + }, + }) +} + func VideoProxy(c *gin.Context) { taskID := c.Param("task_id") if taskID == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "task_id is required", - "type": "invalid_request_error", - }, - }) + videoProxyError(c, http.StatusBadRequest, "invalid_request_error", "task_id is required") return } task, exists, err := model.GetByOnlyTaskId(taskID) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to query task", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to query task") return } if !exists || task == nil { - logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err)) - c.JSON(http.StatusNotFound, gin.H{ - "error": gin.H{ - "message": "Task not found", - "type": "invalid_request_error", - }, - }) + videoProxyError(c, http.StatusNotFound, "invalid_request_error", "Task not found") return } if task.Status != model.TaskStatusSuccess { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status), - "type": "invalid_request_error", - }, - }) + videoProxyError(c, http.StatusBadRequest, "invalid_request_error", + fmt.Sprintf("Task is not completed yet, current status: %s", task.Status)) return } channel, err := model.CacheGetChannel(task.ChannelId) if err != nil { - logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID)) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to retrieve channel information", - "type": "server_error", - }, - }) + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel for task %s: %s", taskID, err.Error())) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to retrieve channel information") return } baseURL := channel.GetBaseURL() @@ -81,12 +66,7 @@ func VideoProxy(c *gin.Context) { client, err := service.GetHttpClientWithProxy(proxy) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to create proxy client", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy client") return } @@ -95,12 +75,7 @@ func VideoProxy(c *gin.Context) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to create proxy request", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request") return } @@ -109,68 +84,43 @@ func VideoProxy(c *gin.Context) { apiKey := task.PrivateData.Key if apiKey == "" { logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID)) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "API key not stored for task", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "API key not stored for task") return } - videoURL, err = getGeminiVideoURL(channel, task, apiKey) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error())) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "message": "Failed to resolve Gemini video URL", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Gemini video URL") return } req.Header.Set("x-goog-api-key", apiKey) case constant.ChannelTypeOpenAI, constant.ChannelTypeSora: - videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID) + videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID()) req.Header.Set("Authorization", "Bearer "+channel.Key) default: - // Video URL is directly in task.FailReason - videoURL = task.FailReason + // Video URL is stored in PrivateData.ResultURL (fallback to FailReason for old data) + videoURL = task.GetResultURL() } req.URL, err = url.Parse(videoURL) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to create proxy request", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request") return } resp, err := client.Do(req) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error())) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "message": "Failed to fetch video content", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content") return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL)) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode), - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusBadGateway, "server_error", + fmt.Sprintf("Upstream service returned status %d", resp.StatusCode)) return } @@ -180,10 +130,9 @@ func VideoProxy(c *gin.Context) { } } - c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours + c.Writer.Header().Set("Cache-Control", "public, max-age=86400") c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { + if _, err = io.Copy(c.Writer, resp.Body); err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error())) } } diff --git a/controller/video_proxy_gemini.go b/controller/video_proxy_gemini.go index 053ac6515..a63a2a5c4 100644 --- a/controller/video_proxy_gemini.go +++ b/controller/video_proxy_gemini.go @@ -1,12 +1,12 @@ package controller import ( - "encoding/json" "fmt" "io" "strconv" "strings" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay" @@ -37,7 +37,7 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string) proxy := channel.GetSetting().Proxy resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{ - "task_id": task.TaskID, + "task_id": task.GetUpstreamTaskID(), "action": task.Action, }, proxy) if err != nil { @@ -71,7 +71,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string { return "" } var payload map[string]any - if err := json.Unmarshal(task.Data, &payload); err != nil { + if err := common.Unmarshal(task.Data, &payload); err != nil { return "" } return extractGeminiVideoURLFromMap(payload) @@ -79,7 +79,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string { func extractGeminiVideoURLFromPayload(body []byte) string { var payload map[string]any - if err := json.Unmarshal(body, &payload); err != nil { + if err := common.Unmarshal(body, &payload); err != nil { return "" } return extractGeminiVideoURLFromMap(payload) diff --git a/dto/channel_settings.go b/dto/channel_settings.go index 74bceb281..72fdf460c 100644 --- a/dto/channel_settings.go +++ b/dto/channel_settings.go @@ -24,14 +24,16 @@ const ( ) type ChannelOtherSettings struct { - AzureResponsesVersion string `json:"azure_responses_version,omitempty"` - VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key" - OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"` - ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true - AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费) - DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用) - AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私) - AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"` + AzureResponsesVersion string `json:"azure_responses_version,omitempty"` + VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key" + OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"` + ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true + AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费) + AllowInferenceGeo bool `json:"allow_inference_geo,omitempty"` // 是否允许 inference_geo 透传(仅 Claude,默认过滤以满足数据驻留合规) + DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用) + AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私) + AllowIncludeObfuscation bool `json:"allow_include_obfuscation,omitempty"` // 是否允许 stream_options.include_obfuscation 透传(默认过滤以避免关闭流混淆保护) + AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"` } func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool { diff --git a/dto/claude.go b/dto/claude.go index 8b6b495f6..32e31710b 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -190,10 +190,13 @@ type ClaudeToolChoice struct { } type ClaudeRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt,omitempty"` - System any `json:"system,omitempty"` - Messages []ClaudeMessage `json:"messages,omitempty"` + Model string `json:"model"` + Prompt string `json:"prompt,omitempty"` + System any `json:"system,omitempty"` + Messages []ClaudeMessage `json:"messages,omitempty"` + // InferenceGeo controls Claude data residency region. + // This field is filtered by default and can be enabled via channel setting allow_inference_geo. + InferenceGeo string `json:"inference_geo,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"` MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` @@ -210,7 +213,8 @@ type ClaudeRequest struct { Thinking *Thinking `json:"thinking,omitempty"` McpServers json.RawMessage `json:"mcp_servers,omitempty"` Metadata json.RawMessage `json:"metadata,omitempty"` - // 服务层级字段,用于指定 API 服务等级。允许透传可能导致实际计费高于预期,默认应过滤 + // ServiceTier specifies upstream service level and may affect billing. + // This field is filtered by default and can be enabled via channel setting allow_service_tier. ServiceTier string `json:"service_tier,omitempty"` } diff --git a/dto/gemini.go b/dto/gemini.go index 0fd74c639..b97f19ec6 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -324,25 +324,26 @@ type GeminiChatTool struct { } type GeminiChatGenerationConfig struct { - Temperature *float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK float64 `json:"topK,omitempty"` - MaxOutputTokens uint `json:"maxOutputTokens,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - StopSequences []string `json:"stopSequences,omitempty"` - ResponseMimeType string `json:"responseMimeType,omitempty"` - ResponseSchema any `json:"responseSchema,omitempty"` - ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"` - PresencePenalty *float32 `json:"presencePenalty,omitempty"` - FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"` - ResponseLogprobs bool `json:"responseLogprobs,omitempty"` - Logprobs *int32 `json:"logprobs,omitempty"` - MediaResolution MediaResolution `json:"mediaResolution,omitempty"` - Seed int64 `json:"seed,omitempty"` - ResponseModalities []string `json:"responseModalities,omitempty"` - ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"` - SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config - ImageConfig json.RawMessage `json:"imageConfig,omitempty"` // RawMessage to allow flexible image config + Temperature *float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens uint `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + ResponseMimeType string `json:"responseMimeType,omitempty"` + ResponseSchema any `json:"responseSchema,omitempty"` + ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"` + PresencePenalty *float32 `json:"presencePenalty,omitempty"` + FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"` + ResponseLogprobs bool `json:"responseLogprobs,omitempty"` + Logprobs *int32 `json:"logprobs,omitempty"` + EnableEnhancedCivicAnswers *bool `json:"enableEnhancedCivicAnswers,omitempty"` + MediaResolution MediaResolution `json:"mediaResolution,omitempty"` + Seed int64 `json:"seed,omitempty"` + ResponseModalities []string `json:"responseModalities,omitempty"` + ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"` + SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config + ImageConfig json.RawMessage `json:"imageConfig,omitempty"` // RawMessage to allow flexible image config } // UnmarshalJSON allows GeminiChatGenerationConfig to accept both snake_case and camelCase fields. @@ -350,22 +351,23 @@ func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error { type Alias GeminiChatGenerationConfig var aux struct { Alias - TopPSnake float64 `json:"top_p,omitempty"` - TopKSnake float64 `json:"top_k,omitempty"` - MaxOutputTokensSnake uint `json:"max_output_tokens,omitempty"` - CandidateCountSnake int `json:"candidate_count,omitempty"` - StopSequencesSnake []string `json:"stop_sequences,omitempty"` - ResponseMimeTypeSnake string `json:"response_mime_type,omitempty"` - ResponseSchemaSnake any `json:"response_schema,omitempty"` - ResponseJsonSchemaSnake json.RawMessage `json:"response_json_schema,omitempty"` - PresencePenaltySnake *float32 `json:"presence_penalty,omitempty"` - FrequencyPenaltySnake *float32 `json:"frequency_penalty,omitempty"` - ResponseLogprobsSnake bool `json:"response_logprobs,omitempty"` - MediaResolutionSnake MediaResolution `json:"media_resolution,omitempty"` - ResponseModalitiesSnake []string `json:"response_modalities,omitempty"` - ThinkingConfigSnake *GeminiThinkingConfig `json:"thinking_config,omitempty"` - SpeechConfigSnake json.RawMessage `json:"speech_config,omitempty"` - ImageConfigSnake json.RawMessage `json:"image_config,omitempty"` + TopPSnake float64 `json:"top_p,omitempty"` + TopKSnake float64 `json:"top_k,omitempty"` + MaxOutputTokensSnake uint `json:"max_output_tokens,omitempty"` + CandidateCountSnake int `json:"candidate_count,omitempty"` + StopSequencesSnake []string `json:"stop_sequences,omitempty"` + ResponseMimeTypeSnake string `json:"response_mime_type,omitempty"` + ResponseSchemaSnake any `json:"response_schema,omitempty"` + ResponseJsonSchemaSnake json.RawMessage `json:"response_json_schema,omitempty"` + PresencePenaltySnake *float32 `json:"presence_penalty,omitempty"` + FrequencyPenaltySnake *float32 `json:"frequency_penalty,omitempty"` + ResponseLogprobsSnake bool `json:"response_logprobs,omitempty"` + EnableEnhancedCivicAnswersSnake *bool `json:"enable_enhanced_civic_answers,omitempty"` + MediaResolutionSnake MediaResolution `json:"media_resolution,omitempty"` + ResponseModalitiesSnake []string `json:"response_modalities,omitempty"` + ThinkingConfigSnake *GeminiThinkingConfig `json:"thinking_config,omitempty"` + SpeechConfigSnake json.RawMessage `json:"speech_config,omitempty"` + ImageConfigSnake json.RawMessage `json:"image_config,omitempty"` } if err := common.Unmarshal(data, &aux); err != nil { @@ -408,6 +410,9 @@ func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error { if aux.ResponseLogprobsSnake { c.ResponseLogprobs = aux.ResponseLogprobsSnake } + if aux.EnableEnhancedCivicAnswersSnake != nil { + c.EnableEnhancedCivicAnswers = aux.EnableEnhancedCivicAnswersSnake + } if aux.MediaResolutionSnake != "" { c.MediaResolution = aux.MediaResolutionSnake } @@ -453,12 +458,14 @@ type GeminiChatResponse struct { } type GeminiUsageMetadata struct { - PromptTokenCount int `json:"promptTokenCount"` - CandidatesTokenCount int `json:"candidatesTokenCount"` - TotalTokenCount int `json:"totalTokenCount"` - ThoughtsTokenCount int `json:"thoughtsTokenCount"` - CachedContentTokenCount int `json:"cachedContentTokenCount"` - PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"` + PromptTokenCount int `json:"promptTokenCount"` + ToolUsePromptTokenCount int `json:"toolUsePromptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + ThoughtsTokenCount int `json:"thoughtsTokenCount"` + CachedContentTokenCount int `json:"cachedContentTokenCount"` + PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"` + ToolUsePromptTokensDetails []GeminiPromptTokensDetails `json:"toolUsePromptTokensDetails"` } type GeminiPromptTokensDetails struct { diff --git a/dto/openai_request.go b/dto/openai_request.go index 9113a086e..c0a69a376 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -54,18 +54,22 @@ type GeneralOpenAIRequest struct { ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"` Tools []ToolCallRequest `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` + FunctionCall json.RawMessage `json:"function_call,omitempty"` User string `json:"user,omitempty"` - LogProbs bool `json:"logprobs,omitempty"` - TopLogProbs int `json:"top_logprobs,omitempty"` - Dimensions int `json:"dimensions,omitempty"` - Modalities json.RawMessage `json:"modalities,omitempty"` - Audio json.RawMessage `json:"audio,omitempty"` + // ServiceTier specifies upstream service level and may affect billing. + // This field is filtered by default and can be enabled via channel setting allow_service_tier. + ServiceTier string `json:"service_tier,omitempty"` + LogProbs bool `json:"logprobs,omitempty"` + TopLogProbs int `json:"top_logprobs,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + Modalities json.RawMessage `json:"modalities,omitempty"` + Audio json.RawMessage `json:"audio,omitempty"` // 安全标识符,用于帮助 OpenAI 检测可能违反使用政策的应用程序用户 - // 注意:此字段会向 OpenAI 发送用户标识信息,默认过滤以保护用户隐私 + // 注意:此字段会向 OpenAI 发送用户标识信息,默认过滤,可通过 allow_safety_identifier 开启 SafetyIdentifier string `json:"safety_identifier,omitempty"` // Whether or not to store the output of this chat completion request for use in our model distillation or evals products. // 是否存储此次请求数据供 OpenAI 用于评估和优化产品 - // 注意:默认过滤此字段以保护用户隐私,但过滤后可能导致 Codex 无法正常使用 + // 注意:默认允许透传,可通过 disable_store 禁用;禁用后可能导致 Codex 无法正常使用 Store json.RawMessage `json:"store,omitempty"` // Used by OpenAI to cache responses for similar requests to optimize your cache hit rates. Replaces the user field PromptCacheKey string `json:"prompt_cache_key,omitempty"` @@ -261,6 +265,9 @@ type FunctionRequest struct { type StreamOptions struct { IncludeUsage bool `json:"include_usage,omitempty"` + // IncludeObfuscation is only for /v1/responses stream payload. + // This field is filtered by default and can be enabled via channel setting allow_include_obfuscation. + IncludeObfuscation bool `json:"include_obfuscation,omitempty"` } func (r *GeneralOpenAIRequest) GetMaxTokens() uint { @@ -799,30 +806,42 @@ type WebSearchOptions struct { // https://platform.openai.com/docs/api-reference/responses/create type OpenAIResponsesRequest struct { - Model string `json:"model"` - Input json.RawMessage `json:"input,omitempty"` - Include json.RawMessage `json:"include,omitempty"` + Model string `json:"model"` + Input json.RawMessage `json:"input,omitempty"` + Include json.RawMessage `json:"include,omitempty"` + // 在后台运行推理,暂时还不支持依赖的接口 + // Background json.RawMessage `json:"background,omitempty"` + Conversation json.RawMessage `json:"conversation,omitempty"` + ContextManagement json.RawMessage `json:"context_management,omitempty"` Instructions json.RawMessage `json:"instructions,omitempty"` MaxOutputTokens uint `json:"max_output_tokens,omitempty"` + TopLogProbs *int `json:"top_logprobs,omitempty"` Metadata json.RawMessage `json:"metadata,omitempty"` ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"` PreviousResponseID string `json:"previous_response_id,omitempty"` Reasoning *Reasoning `json:"reasoning,omitempty"` - // 服务层级字段,用于指定 API 服务等级。允许透传可能导致实际计费高于预期,默认应过滤 - ServiceTier string `json:"service_tier,omitempty"` + // ServiceTier specifies upstream service level and may affect billing. + // This field is filtered by default and can be enabled via channel setting allow_service_tier. + ServiceTier string `json:"service_tier,omitempty"` + // Store controls whether upstream may store request/response data. + // This field is allowed by default and can be disabled via channel setting disable_store. Store json.RawMessage `json:"store,omitempty"` PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"` PromptCacheRetention json.RawMessage `json:"prompt_cache_retention,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - Text json.RawMessage `json:"text,omitempty"` - ToolChoice json.RawMessage `json:"tool_choice,omitempty"` - Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map - TopP *float64 `json:"top_p,omitempty"` - Truncation string `json:"truncation,omitempty"` - User string `json:"user,omitempty"` - MaxToolCalls uint `json:"max_tool_calls,omitempty"` - Prompt json.RawMessage `json:"prompt,omitempty"` + // SafetyIdentifier carries client identity for policy abuse detection. + // This field is filtered by default and can be enabled via channel setting allow_safety_identifier. + SafetyIdentifier string `json:"safety_identifier,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + Text json.RawMessage `json:"text,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map + TopP *float64 `json:"top_p,omitempty"` + Truncation string `json:"truncation,omitempty"` + User string `json:"user,omitempty"` + MaxToolCalls uint `json:"max_tool_calls,omitempty"` + Prompt json.RawMessage `json:"prompt,omitempty"` // qwen EnableThinking json.RawMessage `json:"enable_thinking,omitempty"` // perplexity diff --git a/dto/suno.go b/dto/suno.go index a6bb3ebae..90e11b810 100644 --- a/dto/suno.go +++ b/dto/suno.go @@ -4,10 +4,6 @@ import ( "encoding/json" ) -type TaskData interface { - SunoDataResponse | []SunoDataResponse | string | any -} - type SunoSubmitReq struct { GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"` Prompt string `json:"prompt,omitempty"` @@ -20,10 +16,6 @@ type SunoSubmitReq struct { MakeInstrumental bool `json:"make_instrumental"` } -type FetchReq struct { - IDs []string `json:"ids"` -} - type SunoDataResponse struct { TaskID string `json:"task_id" gorm:"type:varchar(50);index"` Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode @@ -66,30 +58,6 @@ type SunoLyrics struct { Text string `json:"text"` } -const TaskSuccessCode = "success" - -type TaskResponse[T TaskData] struct { - Code string `json:"code"` - Message string `json:"message"` - Data T `json:"data"` -} - -func (t *TaskResponse[T]) IsSuccess() bool { - return t.Code == TaskSuccessCode -} - -type TaskDto struct { - TaskID string `json:"task_id"` // 第三方id,不一定有/ song id\ Task id - Action string `json:"action"` // 任务类型, song, lyrics, description-mode - Status string `json:"status"` // 任务状态, submitted, queueing, processing, success, failed - FailReason string `json:"fail_reason"` - SubmitTime int64 `json:"submit_time"` - StartTime int64 `json:"start_time"` - FinishTime int64 `json:"finish_time"` - Progress string `json:"progress"` - Data json.RawMessage `json:"data"` -} - type SunoGoAPISubmitReq struct { CustomMode bool `json:"custom_mode"` diff --git a/dto/task.go b/dto/task.go index afc186b41..4a9a8e2e6 100644 --- a/dto/task.go +++ b/dto/task.go @@ -1,5 +1,9 @@ package dto +import ( + "encoding/json" +) + type TaskError struct { Code string `json:"code"` Message string `json:"message"` @@ -8,3 +12,46 @@ type TaskError struct { LocalError bool `json:"-"` Error error `json:"-"` } + +type TaskData interface { + SunoDataResponse | []SunoDataResponse | string | any +} + +const TaskSuccessCode = "success" + +type TaskResponse[T TaskData] struct { + Code string `json:"code"` + Message string `json:"message"` + Data T `json:"data"` +} + +func (t *TaskResponse[T]) IsSuccess() bool { + return t.Code == TaskSuccessCode +} + +type TaskDto struct { + ID int64 `json:"id"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + TaskID string `json:"task_id"` + Platform string `json:"platform"` + UserId int `json:"user_id"` + Group string `json:"group"` + ChannelId int `json:"channel_id"` + Quota int `json:"quota"` + Action string `json:"action"` + Status string `json:"status"` + FailReason string `json:"fail_reason"` + ResultURL string `json:"result_url,omitempty"` // 任务结果 URL(视频地址等) + SubmitTime int64 `json:"submit_time"` + StartTime int64 `json:"start_time"` + FinishTime int64 `json:"finish_time"` + Progress string `json:"progress"` + Properties any `json:"properties"` + Username string `json:"username,omitempty"` + Data json.RawMessage `json:"data"` +} + +type FetchReq struct { + IDs []string `json:"ids"` +} diff --git a/logger/logger.go b/logger/logger.go index 61b1d49d8..90cf5006e 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -2,7 +2,6 @@ package logger import ( "context" - "encoding/json" "fmt" "io" "log" @@ -151,7 +150,7 @@ func FormatQuota(quota int) string { // LogJson 仅供测试使用 only for test func LogJson(ctx context.Context, msg string, obj any) { - jsonStr, err := json.Marshal(obj) + jsonStr, err := common.Marshal(obj) if err != nil { LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error())) return diff --git a/main.go b/main.go index 852e1a0a8..476a2ed24 100644 --- a/main.go +++ b/main.go @@ -19,6 +19,7 @@ import ( "github.com/QuantumNous/new-api/middleware" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/oauth" + "github.com/QuantumNous/new-api/relay" "github.com/QuantumNous/new-api/router" "github.com/QuantumNous/new-api/service" _ "github.com/QuantumNous/new-api/setting/performance_setting" @@ -111,6 +112,15 @@ func main() { // Subscription quota reset task (daily/weekly/monthly/custom) service.StartSubscriptionQuotaResetTask() + // Wire task polling adaptor factory (breaks service -> relay import cycle) + service.GetTaskAdaptorFunc = func(platform constant.TaskPlatform) service.TaskPollingAdaptor { + a := relay.GetTaskAdaptor(platform) + if a == nil { + return nil + } + return a + } + if common.IsMasterNode && constant.UpdateTask { gopool.Go(func() { controller.UpdateMidjourneyTaskBulk() diff --git a/middleware/auth.go b/middleware/auth.go index cf1843510..342e7f498 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -170,6 +170,24 @@ func WssAuth(c *gin.Context) { } +// TokenOrUserAuth allows either session-based user auth or API token auth. +// Used for endpoints that need to be accessible from both the dashboard and API clients. +func TokenOrUserAuth() func(c *gin.Context) { + return func(c *gin.Context) { + // Try session auth first (dashboard users) + session := sessions.Default(c) + if id := session.Get("id"); id != nil { + if status, ok := session.Get("status").(int); ok && status == common.UserStatusEnabled { + c.Set("id", id) + c.Next() + return + } + } + // Fall back to token auth (API clients) + TokenAuth()(c) + } +} + // TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。 // 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。 // 即使令牌已过期、已耗尽或已禁用,也允许访问。 diff --git a/middleware/logger.go b/middleware/logger.go index b4ed8c89d..151008d9f 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -7,14 +7,28 @@ import ( "github.com/gin-gonic/gin" ) +const RouteTagKey = "route_tag" + +func RouteTag(tag string) gin.HandlerFunc { + return func(c *gin.Context) { + c.Set(RouteTagKey, tag) + c.Next() + } +} + func SetUpLogger(server *gin.Engine) { server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { var requestID string if param.Keys != nil { - requestID = param.Keys[common.RequestIdKey].(string) + requestID, _ = param.Keys[common.RequestIdKey].(string) } - return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", + tag, _ := param.Keys[RouteTagKey].(string) + if tag == "" { + tag = "web" + } + return fmt.Sprintf("[GIN] %s | %s | %s | %3d | %13v | %15s | %7s %s\n", param.TimeStamp.Format("2006/01/02 - 15:04:05"), + tag, requestID, param.StatusCode, param.Latency, diff --git a/model/custom_oauth_provider.go b/model/custom_oauth_provider.go index 43c69833a..12b4d1111 100644 --- a/model/custom_oauth_provider.go +++ b/model/custom_oauth_provider.go @@ -2,32 +2,65 @@ package model import ( "errors" + "fmt" "strings" "time" + + "github.com/QuantumNous/new-api/common" ) +type accessPolicyPayload struct { + Logic string `json:"logic"` + Conditions []accessConditionItem `json:"conditions"` + Groups []accessPolicyPayload `json:"groups"` +} + +type accessConditionItem struct { + Field string `json:"field"` + Op string `json:"op"` + Value any `json:"value"` +} + +var supportedAccessPolicyOps = map[string]struct{}{ + "eq": {}, + "ne": {}, + "gt": {}, + "gte": {}, + "lt": {}, + "lte": {}, + "in": {}, + "not_in": {}, + "contains": {}, + "not_contains": {}, + "exists": {}, + "not_exists": {}, +} + // CustomOAuthProvider stores configuration for custom OAuth providers type CustomOAuthProvider struct { - Id int `json:"id" gorm:"primaryKey"` - Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise" - Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise" - Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled - ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID - ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend) - AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL - TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL - UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL - Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes + Id int `json:"id" gorm:"primaryKey"` + Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise" + Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise" + Icon string `json:"icon" gorm:"type:varchar(128);default:''"` // Icon name from @lobehub/icons + Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled + ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID + ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend) + AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL + TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL + UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL + Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes // Field mapping configuration (supports JSONPath via gjson) - UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id" - UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path - DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path - EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path + UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id" + UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path + DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path + EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path // Advanced options - WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional) - AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth) + WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional) + AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth) + AccessPolicy string `json:"access_policy" gorm:"type:text"` // JSON policy for access control based on user info + AccessDeniedMessage string `json:"access_denied_message" gorm:"type:varchar(512)"` // Custom error message template when access is denied CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` @@ -158,6 +191,57 @@ func validateCustomOAuthProvider(provider *CustomOAuthProvider) error { if provider.Scopes == "" { provider.Scopes = "openid profile email" } + if strings.TrimSpace(provider.AccessPolicy) != "" { + var policy accessPolicyPayload + if err := common.UnmarshalJsonStr(provider.AccessPolicy, &policy); err != nil { + return errors.New("access_policy must be valid JSON") + } + if err := validateAccessPolicyPayload(&policy); err != nil { + return fmt.Errorf("access_policy is invalid: %w", err) + } + } + + return nil +} + +func validateAccessPolicyPayload(policy *accessPolicyPayload) error { + if policy == nil { + return errors.New("policy is nil") + } + + logic := strings.ToLower(strings.TrimSpace(policy.Logic)) + if logic == "" { + logic = "and" + } + if logic != "and" && logic != "or" { + return fmt.Errorf("unsupported logic: %s", logic) + } + + if len(policy.Conditions) == 0 && len(policy.Groups) == 0 { + return errors.New("policy requires at least one condition or group") + } + + for index, condition := range policy.Conditions { + field := strings.TrimSpace(condition.Field) + if field == "" { + return fmt.Errorf("condition[%d].field is required", index) + } + op := strings.ToLower(strings.TrimSpace(condition.Op)) + if _, ok := supportedAccessPolicyOps[op]; !ok { + return fmt.Errorf("condition[%d].op is unsupported: %s", index, op) + } + if op == "in" || op == "not_in" { + if _, ok := condition.Value.([]any); !ok { + return fmt.Errorf("condition[%d].value must be an array for op %s", index, op) + } + } + } + + for index := range policy.Groups { + if err := validateAccessPolicyPayload(&policy.Groups[index]); err != nil { + return fmt.Errorf("group[%d]: %w", index, err) + } + } return nil } diff --git a/model/log.go b/model/log.go index d7cd97a42..2d4782fa5 100644 --- a/model/log.go +++ b/model/log.go @@ -199,6 +199,49 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) } } +type RecordTaskBillingLogParams struct { + UserId int + LogType int + Content string + ChannelId int + ModelName string + Quota int + TokenId int + Group string + Other map[string]interface{} +} + +func RecordTaskBillingLog(params RecordTaskBillingLogParams) { + if params.LogType == LogTypeConsume && !common.LogConsumeEnabled { + return + } + username, _ := GetUsernameById(params.UserId, false) + tokenName := "" + if params.TokenId > 0 { + if token, err := GetTokenById(params.TokenId); err == nil { + tokenName = token.Name + } + } + log := &Log{ + UserId: params.UserId, + Username: username, + CreatedAt: common.GetTimestamp(), + Type: params.LogType, + Content: params.Content, + TokenName: tokenName, + ModelName: params.ModelName, + Quota: params.Quota, + ChannelId: params.ChannelId, + TokenId: params.TokenId, + Group: params.Group, + Other: common.MapToJsonStr(params.Other), + } + err := LOG_DB.Create(log).Error + if err != nil { + common.SysLog("failed to record task billing log: " + err.Error()) + } +} + func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string, requestId string) (logs []*Log, total int64, err error) { var tx *gorm.DB if logType == LogTypeUnknown { @@ -252,8 +295,24 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName Id int `gorm:"column:id"` Name string `gorm:"column:name"` } - if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil { - return logs, total, err + if common.MemoryCacheEnabled { + // Cache get channel + for _, channelId := range channelIds.Items() { + if cacheChannel, err := CacheGetChannel(channelId); err == nil { + channels = append(channels, struct { + Id int `gorm:"column:id"` + Name string `gorm:"column:name"` + }{ + Id: channelId, + Name: cacheChannel.Name, + }) + } + } + } else { + // Bulk query channels from DB + if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil { + return logs, total, err + } } channelMap := make(map[int]string, len(channels)) for _, channel := range channels { diff --git a/model/midjourney.go b/model/midjourney.go index c6ef5de5b..e1a8d772b 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -157,6 +157,19 @@ func (midjourney *Midjourney) Update() error { return err } +// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). +// Returns (true, nil) if this caller won the update, (false, nil) if +// another process already moved the task out of fromStatus. +// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). +// Uses Model().Select("*").Updates() to avoid GORM Save()'s INSERT fallback. +func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) { + result := DB.Model(midjourney).Where("status = ?", fromStatus).Select("*").Updates(midjourney) + if result.Error != nil { + return false, result.Error + } + return result.RowsAffected > 0, nil +} + func MjBulkUpdate(mjIds []string, params map[string]any) error { return DB.Model(&Midjourney{}). Where("mj_id in (?)", mjIds). diff --git a/model/task.go b/model/task.go index 82c2e978a..984445083 100644 --- a/model/task.go +++ b/model/task.go @@ -1,10 +1,12 @@ package model import ( + "bytes" "database/sql/driver" "encoding/json" "time" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" commonRelay "github.com/QuantumNous/new-api/relay/common" @@ -64,13 +66,12 @@ type Task struct { } func (t *Task) SetData(data any) { - b, _ := json.Marshal(data) + b, _ := common.Marshal(data) t.Data = json.RawMessage(b) } func (t *Task) GetData(v any) error { - err := json.Unmarshal(t.Data, &v) - return err + return common.Unmarshal(t.Data, &v) } type Properties struct { @@ -85,18 +86,59 @@ func (m *Properties) Scan(val interface{}) error { *m = Properties{} return nil } - return json.Unmarshal(bytesValue, m) + return common.Unmarshal(bytesValue, m) } func (m Properties) Value() (driver.Value, error) { if m == (Properties{}) { return nil, nil } - return json.Marshal(m) + return common.Marshal(m) } type TaskPrivateData struct { - Key string `json:"key,omitempty"` + Key string `json:"key,omitempty"` + UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID + ResultURL string `json:"result_url,omitempty"` // 任务成功后的结果 URL(视频地址等) + // 计费上下文:用于异步退款/差额结算(轮询阶段读取) + BillingSource string `json:"billing_source,omitempty"` // "wallet" 或 "subscription" + SubscriptionId int `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款 + TokenId int `json:"token_id,omitempty"` // 令牌 ID,用于令牌额度退款 + BillingContext *TaskBillingContext `json:"billing_context,omitempty"` // 计费参数快照(用于轮询阶段重新计算) +} + +// TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。 +type TaskBillingContext struct { + ModelPrice float64 `json:"model_price,omitempty"` // 模型单价 + GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率 + ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率 + OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等) + OriginModelName string `json:"origin_model_name,omitempty"` // 模型名称,必须为OriginModelName + PerCallBilling bool `json:"per_call_billing,omitempty"` // 按次计费:跳过轮询阶段的差额结算 +} + +// GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信) +// 旧数据没有 UpstreamTaskID 时,TaskID 本身就是上游 ID +func (t *Task) GetUpstreamTaskID() string { + if t.PrivateData.UpstreamTaskID != "" { + return t.PrivateData.UpstreamTaskID + } + return t.TaskID +} + +// GetResultURL 获取任务结果 URL(视频地址等) +// 新数据存在 PrivateData.ResultURL 中;旧数据回退到 FailReason(历史兼容) +func (t *Task) GetResultURL() string { + if t.PrivateData.ResultURL != "" { + return t.PrivateData.ResultURL + } + return t.FailReason +} + +// GenerateTaskID 生成对外暴露的 task_xxxx 格式 ID +func GenerateTaskID() string { + key, _ := common.GenerateRandomCharsKey(32) + return "task_" + key } func (p *TaskPrivateData) Scan(val interface{}) error { @@ -104,14 +146,14 @@ func (p *TaskPrivateData) Scan(val interface{}) error { if len(bytesValue) == 0 { return nil } - return json.Unmarshal(bytesValue, p) + return common.Unmarshal(bytesValue, p) } func (p TaskPrivateData) Value() (driver.Value, error) { if (p == TaskPrivateData{}) { return nil, nil } - return json.Marshal(p) + return common.Marshal(p) } // SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 @@ -142,7 +184,16 @@ func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) } } + // 使用预生成的公开 ID(如果有),否则新生成 + taskID := "" + if relayInfo.TaskRelayInfo != nil && relayInfo.TaskRelayInfo.PublicTaskID != "" { + taskID = relayInfo.TaskRelayInfo.PublicTaskID + } else { + taskID = GenerateTaskID() + } + t := &Task{ + TaskID: taskID, UserId: relayInfo.UserId, Group: relayInfo.UsingGroup, SubmitTime: time.Now().Unix(), @@ -237,6 +288,20 @@ func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []* return tasks } +func GetTimedOutUnfinishedTasks(cutoffUnix int64, limit int) []*Task { + var tasks []*Task + err := DB.Where("progress != ?", "100%"). + Where("status NOT IN ?", []string{TaskStatusFailure, TaskStatusSuccess}). + Where("submit_time < ?", cutoffUnix). + Order("submit_time"). + Limit(limit). + Find(&tasks).Error + if err != nil { + return nil + } + return tasks +} + func GetAllUnFinishSyncTasks(limit int) []*Task { var tasks []*Task var err error @@ -291,40 +356,70 @@ func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) { return task, nil } -func TaskUpdateProgress(id int64, progress string) error { - return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error -} - func (Task *Task) Insert() error { var err error err = DB.Create(Task).Error return err } +type taskSnapshot struct { + Status TaskStatus + Progress string + StartTime int64 + FinishTime int64 + FailReason string + ResultURL string + Data json.RawMessage +} + +func (s taskSnapshot) Equal(other taskSnapshot) bool { + return s.Status == other.Status && + s.Progress == other.Progress && + s.StartTime == other.StartTime && + s.FinishTime == other.FinishTime && + s.FailReason == other.FailReason && + s.ResultURL == other.ResultURL && + bytes.Equal(s.Data, other.Data) +} + +func (t *Task) Snapshot() taskSnapshot { + return taskSnapshot{ + Status: t.Status, + Progress: t.Progress, + StartTime: t.StartTime, + FinishTime: t.FinishTime, + FailReason: t.FailReason, + ResultURL: t.PrivateData.ResultURL, + Data: t.Data, + } +} + func (Task *Task) Update() error { var err error err = DB.Save(Task).Error return err } -func TaskBulkUpdate(TaskIds []string, params map[string]any) error { - if len(TaskIds) == 0 { - return nil +// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). +// Returns (true, nil) if this caller won the update, (false, nil) if +// another process already moved the task out of fromStatus. +// +// Uses Model().Select("*").Updates() instead of Save() because GORM's Save +// falls back to INSERT ON CONFLICT when the WHERE-guarded UPDATE matches +// zero rows, which silently bypasses the CAS guard. +func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) { + result := DB.Model(t).Where("status = ?", fromStatus).Select("*").Updates(t) + if result.Error != nil { + return false, result.Error } - return DB.Model(&Task{}). - Where("task_id in (?)", TaskIds). - Updates(params).Error -} - -func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error { - if len(taskIDs) == 0 { - return nil - } - return DB.Model(&Task{}). - Where("id in (?)", taskIDs). - Updates(params).Error + return result.RowsAffected > 0, nil } +// TaskBulkUpdateByID performs an unconditional bulk UPDATE by primary key IDs. +// WARNING: This function has NO CAS (Compare-And-Swap) guard — it will overwrite +// any concurrent status changes. DO NOT use in billing/quota lifecycle flows +// (e.g., timeout, success, failure transitions that trigger refunds or settlements). +// For status transitions that involve billing, use Task.UpdateWithStatus() instead. func TaskBulkUpdateByID(ids []int64, params map[string]any) error { if len(ids) == 0 { return nil @@ -339,37 +434,6 @@ type TaskQuotaUsage struct { Count float64 `json:"count"` } -func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) { - query := DB.Model(Task{}) - // 添加过滤条件 - if queryParams.ChannelID != "" { - query = query.Where("channel_id = ?", queryParams.ChannelID) - } - if queryParams.UserID != "" { - query = query.Where("user_id = ?", queryParams.UserID) - } - if len(queryParams.UserIDs) != 0 { - query = query.Where("user_id in (?)", queryParams.UserIDs) - } - if queryParams.TaskID != "" { - query = query.Where("task_id = ?", queryParams.TaskID) - } - if queryParams.Action != "" { - query = query.Where("action = ?", queryParams.Action) - } - if queryParams.Status != "" { - query = query.Where("status = ?", queryParams.Status) - } - if queryParams.StartTimestamp != 0 { - query = query.Where("submit_time >= ?", queryParams.StartTimestamp) - } - if queryParams.EndTimestamp != 0 { - query = query.Where("submit_time <= ?", queryParams.EndTimestamp) - } - err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error - return stat, err -} - // TaskCountAllTasks returns total tasks that match the given query params (admin usage) func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 { var total int64 @@ -438,6 +502,6 @@ func (t *Task) ToOpenAIVideo() *dto.OpenAIVideo { openAIVideo.SetProgressStr(t.Progress) openAIVideo.CreatedAt = t.CreatedAt openAIVideo.CompletedAt = t.UpdatedAt - openAIVideo.SetMetadata("url", t.FailReason) + openAIVideo.SetMetadata("url", t.GetResultURL()) return openAIVideo } diff --git a/model/task_cas_test.go b/model/task_cas_test.go new file mode 100644 index 000000000..3449c6d26 --- /dev/null +++ b/model/task_cas_test.go @@ -0,0 +1,217 @@ +package model + +import ( + "encoding/json" + "os" + "sync" + "testing" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func TestMain(m *testing.M) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + panic("failed to open test db: " + err.Error()) + } + DB = db + LOG_DB = db + + common.UsingSQLite = true + common.RedisEnabled = false + common.BatchUpdateEnabled = false + common.LogConsumeEnabled = true + + sqlDB, err := db.DB() + if err != nil { + panic("failed to get sql.DB: " + err.Error()) + } + sqlDB.SetMaxOpenConns(1) + + if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil { + panic("failed to migrate: " + err.Error()) + } + + os.Exit(m.Run()) +} + +func truncateTables(t *testing.T) { + t.Helper() + t.Cleanup(func() { + DB.Exec("DELETE FROM tasks") + DB.Exec("DELETE FROM users") + DB.Exec("DELETE FROM tokens") + DB.Exec("DELETE FROM logs") + DB.Exec("DELETE FROM channels") + }) +} + +func insertTask(t *testing.T, task *Task) { + t.Helper() + task.CreatedAt = time.Now().Unix() + task.UpdatedAt = time.Now().Unix() + require.NoError(t, DB.Create(task).Error) +} + +// --------------------------------------------------------------------------- +// Snapshot / Equal — pure logic tests (no DB) +// --------------------------------------------------------------------------- + +func TestSnapshotEqual_Same(t *testing.T) { + s := taskSnapshot{ + Status: TaskStatusInProgress, + Progress: "50%", + StartTime: 1000, + FinishTime: 0, + FailReason: "", + ResultURL: "", + Data: json.RawMessage(`{"key":"value"}`), + } + assert.True(t, s.Equal(s)) +} + +func TestSnapshotEqual_DifferentStatus(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)} + b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)} + assert.False(t, a.Equal(b)) +} + +func TestSnapshotEqual_DifferentProgress(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)} + b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)} + assert.False(t, a.Equal(b)) +} + +func TestSnapshotEqual_DifferentData(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)} + b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)} + assert.False(t, a.Equal(b)) +} + +func TestSnapshotEqual_NilVsEmpty(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Data: nil} + b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}} + // bytes.Equal(nil, []byte{}) == true + assert.True(t, a.Equal(b)) +} + +func TestSnapshot_Roundtrip(t *testing.T) { + task := &Task{ + Status: TaskStatusInProgress, + Progress: "42%", + StartTime: 1234, + FinishTime: 5678, + FailReason: "timeout", + PrivateData: TaskPrivateData{ + ResultURL: "https://example.com/result.mp4", + }, + Data: json.RawMessage(`{"model":"test-model"}`), + } + snap := task.Snapshot() + assert.Equal(t, task.Status, snap.Status) + assert.Equal(t, task.Progress, snap.Progress) + assert.Equal(t, task.StartTime, snap.StartTime) + assert.Equal(t, task.FinishTime, snap.FinishTime) + assert.Equal(t, task.FailReason, snap.FailReason) + assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL) + assert.JSONEq(t, string(task.Data), string(snap.Data)) +} + +// --------------------------------------------------------------------------- +// UpdateWithStatus CAS — DB integration tests +// --------------------------------------------------------------------------- + +func TestUpdateWithStatus_Win(t *testing.T) { + truncateTables(t) + + task := &Task{ + TaskID: "task_cas_win", + Status: TaskStatusInProgress, + Progress: "50%", + Data: json.RawMessage(`{}`), + } + insertTask(t, task) + + task.Status = TaskStatusSuccess + task.Progress = "100%" + won, err := task.UpdateWithStatus(TaskStatusInProgress) + require.NoError(t, err) + assert.True(t, won) + + var reloaded Task + require.NoError(t, DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, TaskStatusSuccess, reloaded.Status) + assert.Equal(t, "100%", reloaded.Progress) +} + +func TestUpdateWithStatus_Lose(t *testing.T) { + truncateTables(t) + + task := &Task{ + TaskID: "task_cas_lose", + Status: TaskStatusFailure, + Data: json.RawMessage(`{}`), + } + insertTask(t, task) + + task.Status = TaskStatusSuccess + won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus + require.NoError(t, err) + assert.False(t, won) + + var reloaded Task + require.NoError(t, DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged +} + +func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) { + truncateTables(t) + + task := &Task{ + TaskID: "task_cas_race", + Status: TaskStatusInProgress, + Quota: 1000, + Data: json.RawMessage(`{}`), + } + insertTask(t, task) + + const goroutines = 5 + wins := make([]bool, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + t := &Task{} + *t = Task{ + ID: task.ID, + TaskID: task.TaskID, + Status: TaskStatusSuccess, + Progress: "100%", + Quota: task.Quota, + Data: json.RawMessage(`{}`), + } + t.CreatedAt = task.CreatedAt + t.UpdatedAt = time.Now().Unix() + won, err := t.UpdateWithStatus(TaskStatusInProgress) + if err == nil { + wins[idx] = won + } + }(i) + } + wg.Wait() + + winCount := 0 + for _, w := range wins { + if w { + winCount++ + } + } + assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS") +} diff --git a/model/token.go b/model/token.go index 9e05b63ca..773b2d792 100644 --- a/model/token.go +++ b/model/token.go @@ -360,7 +360,7 @@ func DeleteTokenById(id int, userId int) (err error) { return token.Delete() } -func IncreaseTokenQuota(id int, key string, quota int) (err error) { +func IncreaseTokenQuota(tokenId int, key string, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -373,10 +373,10 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) { }) } if common.BatchUpdateEnabled { - addNewRecord(BatchUpdateTypeTokenQuota, id, quota) + addNewRecord(BatchUpdateTypeTokenQuota, tokenId, quota) return nil } - return increaseTokenQuota(id, quota) + return increaseTokenQuota(tokenId, quota) } func increaseTokenQuota(id int, quota int) (err error) { diff --git a/model/user.go b/model/user.go index e0c9c686f..1210b5435 100644 --- a/model/user.go +++ b/model/user.go @@ -1,6 +1,7 @@ package model import ( + "database/sql" "encoding/json" "errors" "fmt" @@ -15,6 +16,8 @@ import ( "gorm.io/gorm" ) +const UserNameMaxLength = 20 + // User if you add sensitive fields, don't forget to clean them in setupLogin function. // Otherwise, the sensitive information will be saved on local storage in plain text! type User struct { @@ -536,6 +539,37 @@ func (user *User) Edit(updatePassword bool) error { return updateUserCache(*user) } +func (user *User) ClearBinding(bindingType string) error { + if user.Id == 0 { + return errors.New("user id is empty") + } + + bindingColumnMap := map[string]string{ + "email": "email", + "github": "github_id", + "discord": "discord_id", + "oidc": "oidc_id", + "wechat": "wechat_id", + "telegram": "telegram_id", + "linuxdo": "linux_do_id", + } + + column, ok := bindingColumnMap[bindingType] + if !ok { + return errors.New("invalid binding type") + } + + if err := DB.Model(&User{}).Where("id = ?", user.Id).Update(column, "").Error; err != nil { + return err + } + + if err := DB.Where("id = ?", user.Id).First(user).Error; err != nil { + return err + } + + return updateUserCache(*user) +} + func (user *User) Delete() error { if user.Id == 0 { return errors.New("id 为空!") @@ -820,10 +854,17 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) // Don't return error - fall through to DB } fromDB = true - err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error + // can be nil setting + var safeSetting sql.NullString + err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&safeSetting).Error if err != nil { return settingMap, err } + if safeSetting.Valid { + setting = safeSetting.String + } else { + setting = "" + } userBase := &UserBase{ Setting: setting, } diff --git a/oauth/generic.go b/oauth/generic.go index c7aa87931..bc18054d5 100644 --- a/oauth/generic.go +++ b/oauth/generic.go @@ -3,19 +3,24 @@ package oauth import ( "context" "encoding/base64" - "encoding/json" + stdjson "encoding/json" + "errors" "fmt" "io" "net/http" "net/url" + "regexp" + "strconv" "strings" "time" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/i18n" "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/setting/system_setting" "github.com/gin-gonic/gin" + "github.com/samber/lo" "github.com/tidwall/gjson" ) @@ -31,6 +36,40 @@ type GenericOAuthProvider struct { config *model.CustomOAuthProvider } +type accessPolicy struct { + Logic string `json:"logic"` + Conditions []accessCondition `json:"conditions"` + Groups []accessPolicy `json:"groups"` +} + +type accessCondition struct { + Field string `json:"field"` + Op string `json:"op"` + Value any `json:"value"` +} + +type accessPolicyFailure struct { + Field string + Op string + Expected any + Current any +} + +var supportedAccessPolicyOps = []string{ + "eq", + "ne", + "gt", + "gte", + "lt", + "lte", + "in", + "not_in", + "contains", + "not_contains", + "exists", + "not_exists", +} + // NewGenericOAuthProvider creates a new generic OAuth provider from config func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider { return &GenericOAuthProvider{config: config} @@ -125,7 +164,7 @@ func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c ErrorDesc string `json:"error_description"` } - if err := json.Unmarshal(body, &tokenResponse); err != nil { + if err := common.Unmarshal(body, &tokenResponse); err != nil { // Try to parse as URL-encoded (some OAuth servers like GitHub return this format) parsedValues, parseErr := url.ParseQuery(bodyStr) if parseErr != nil { @@ -227,11 +266,30 @@ func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToke logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s", p.config.Slug, userId, username, displayName, email) + policyRaw := strings.TrimSpace(p.config.AccessPolicy) + if policyRaw != "" { + policy, err := parseAccessPolicy(policyRaw) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] invalid access policy: %s", p.config.Slug, err.Error())) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, nil, "invalid access policy configuration") + } + allowed, failure := evaluateAccessPolicy(bodyStr, policy) + if !allowed { + message := renderAccessDeniedMessage(p.config.AccessDeniedMessage, p.config.Name, bodyStr, failure) + logger.LogWarn(ctx, fmt.Sprintf("[OAuth-Generic-%s] access denied by policy: field=%s op=%s expected=%v current=%v", + p.config.Slug, failure.Field, failure.Op, failure.Expected, failure.Current)) + return nil, &AccessDeniedError{Message: message} + } + } + return &OAuthUser{ ProviderUserID: userId, Username: username, DisplayName: displayName, Email: email, + Extra: map[string]any{ + "provider": p.config.Slug, + }, }, nil } @@ -266,3 +324,345 @@ func (p *GenericOAuthProvider) GetProviderId() int { func (p *GenericOAuthProvider) IsGenericProvider() bool { return true } + +func parseAccessPolicy(raw string) (*accessPolicy, error) { + var policy accessPolicy + if err := common.UnmarshalJsonStr(raw, &policy); err != nil { + return nil, err + } + if err := validateAccessPolicy(&policy); err != nil { + return nil, err + } + return &policy, nil +} + +func validateAccessPolicy(policy *accessPolicy) error { + if policy == nil { + return errors.New("policy is nil") + } + + logic := strings.ToLower(strings.TrimSpace(policy.Logic)) + if logic == "" { + logic = "and" + } + if !lo.Contains([]string{"and", "or"}, logic) { + return fmt.Errorf("unsupported policy logic: %s", logic) + } + policy.Logic = logic + + if len(policy.Conditions) == 0 && len(policy.Groups) == 0 { + return errors.New("policy requires at least one condition or group") + } + + for index := range policy.Conditions { + if err := validateAccessCondition(&policy.Conditions[index], index); err != nil { + return err + } + } + + for index := range policy.Groups { + if err := validateAccessPolicy(&policy.Groups[index]); err != nil { + return fmt.Errorf("invalid policy group[%d]: %w", index, err) + } + } + + return nil +} + +func validateAccessCondition(condition *accessCondition, index int) error { + if condition == nil { + return fmt.Errorf("condition[%d] is nil", index) + } + + condition.Field = strings.TrimSpace(condition.Field) + if condition.Field == "" { + return fmt.Errorf("condition[%d].field is required", index) + } + + condition.Op = normalizePolicyOp(condition.Op) + if !lo.Contains(supportedAccessPolicyOps, condition.Op) { + return fmt.Errorf("condition[%d].op is unsupported: %s", index, condition.Op) + } + + if lo.Contains([]string{"in", "not_in"}, condition.Op) { + if _, ok := condition.Value.([]any); !ok { + return fmt.Errorf("condition[%d].value must be an array for op %s", index, condition.Op) + } + } + + return nil +} + +func evaluateAccessPolicy(body string, policy *accessPolicy) (bool, *accessPolicyFailure) { + if policy == nil { + return true, nil + } + + logic := strings.ToLower(strings.TrimSpace(policy.Logic)) + if logic == "" { + logic = "and" + } + + hasAny := len(policy.Conditions) > 0 || len(policy.Groups) > 0 + if !hasAny { + return true, nil + } + + if logic == "or" { + var firstFailure *accessPolicyFailure + for _, cond := range policy.Conditions { + ok, failure := evaluateAccessCondition(body, cond) + if ok { + return true, nil + } + if firstFailure == nil { + firstFailure = failure + } + } + for _, group := range policy.Groups { + ok, failure := evaluateAccessPolicy(body, &group) + if ok { + return true, nil + } + if firstFailure == nil { + firstFailure = failure + } + } + return false, firstFailure + } + + for _, cond := range policy.Conditions { + ok, failure := evaluateAccessCondition(body, cond) + if !ok { + return false, failure + } + } + for _, group := range policy.Groups { + ok, failure := evaluateAccessPolicy(body, &group) + if !ok { + return false, failure + } + } + return true, nil +} + +func evaluateAccessCondition(body string, cond accessCondition) (bool, *accessPolicyFailure) { + path := cond.Field + op := cond.Op + result := gjson.Get(body, path) + current := gjsonResultToValue(result) + failure := &accessPolicyFailure{ + Field: path, + Op: op, + Expected: cond.Value, + Current: current, + } + + switch op { + case "exists": + return result.Exists(), failure + case "not_exists": + return !result.Exists(), failure + case "eq": + return compareAny(current, cond.Value) == 0, failure + case "ne": + return compareAny(current, cond.Value) != 0, failure + case "gt": + return compareAny(current, cond.Value) > 0, failure + case "gte": + return compareAny(current, cond.Value) >= 0, failure + case "lt": + return compareAny(current, cond.Value) < 0, failure + case "lte": + return compareAny(current, cond.Value) <= 0, failure + case "in": + return valueInSlice(current, cond.Value), failure + case "not_in": + return !valueInSlice(current, cond.Value), failure + case "contains": + return containsValue(current, cond.Value), failure + case "not_contains": + return !containsValue(current, cond.Value), failure + default: + return false, failure + } +} + +func normalizePolicyOp(op string) string { + return strings.ToLower(strings.TrimSpace(op)) +} + +func gjsonResultToValue(result gjson.Result) any { + if !result.Exists() { + return nil + } + if result.IsArray() { + arr := result.Array() + values := make([]any, 0, len(arr)) + for _, item := range arr { + values = append(values, gjsonResultToValue(item)) + } + return values + } + switch result.Type { + case gjson.Null: + return nil + case gjson.True: + return true + case gjson.False: + return false + case gjson.Number: + return result.Num + case gjson.String: + return result.String() + case gjson.JSON: + var data any + if err := common.UnmarshalJsonStr(result.Raw, &data); err == nil { + return data + } + return result.Raw + default: + return result.Value() + } +} + +func compareAny(left any, right any) int { + if lf, ok := toFloat(left); ok { + if rf, ok2 := toFloat(right); ok2 { + switch { + case lf < rf: + return -1 + case lf > rf: + return 1 + default: + return 0 + } + } + } + + ls := strings.TrimSpace(fmt.Sprint(left)) + rs := strings.TrimSpace(fmt.Sprint(right)) + switch { + case ls < rs: + return -1 + case ls > rs: + return 1 + default: + return 0 + } +} + +func toFloat(v any) (float64, bool) { + switch value := v.(type) { + case float64: + return value, true + case float32: + return float64(value), true + case int: + return float64(value), true + case int8: + return float64(value), true + case int16: + return float64(value), true + case int32: + return float64(value), true + case int64: + return float64(value), true + case uint: + return float64(value), true + case uint8: + return float64(value), true + case uint16: + return float64(value), true + case uint32: + return float64(value), true + case uint64: + return float64(value), true + case stdjson.Number: + n, err := value.Float64() + if err == nil { + return n, true + } + case string: + n, err := strconv.ParseFloat(strings.TrimSpace(value), 64) + if err == nil { + return n, true + } + } + return 0, false +} + +func valueInSlice(current any, expected any) bool { + list, ok := expected.([]any) + if !ok { + return false + } + return lo.ContainsBy(list, func(item any) bool { + return compareAny(current, item) == 0 + }) +} + +func containsValue(current any, expected any) bool { + switch value := current.(type) { + case string: + target := strings.TrimSpace(fmt.Sprint(expected)) + return strings.Contains(value, target) + case []any: + return lo.ContainsBy(value, func(item any) bool { + return compareAny(item, expected) == 0 + }) + } + return false +} + +func renderAccessDeniedMessage(template string, providerName string, body string, failure *accessPolicyFailure) string { + defaultMessage := "Access denied: your account does not meet this provider's access requirements." + message := strings.TrimSpace(template) + if message == "" { + return defaultMessage + } + + if failure == nil { + failure = &accessPolicyFailure{} + } + + replacements := map[string]string{ + "{{provider}}": providerName, + "{{field}}": failure.Field, + "{{op}}": failure.Op, + "{{required}}": fmt.Sprint(failure.Expected), + "{{current}}": fmt.Sprint(failure.Current), + } + + for key, value := range replacements { + message = strings.ReplaceAll(message, key, value) + } + + currentPattern := regexp.MustCompile(`\{\{current\.([^}]+)\}\}`) + message = currentPattern.ReplaceAllStringFunc(message, func(token string) string { + match := currentPattern.FindStringSubmatch(token) + if len(match) != 2 { + return "" + } + path := strings.TrimSpace(match[1]) + if path == "" { + return "" + } + return strings.TrimSpace(gjson.Get(body, path).String()) + }) + + requiredPattern := regexp.MustCompile(`\{\{required\.([^}]+)\}\}`) + message = requiredPattern.ReplaceAllStringFunc(message, func(token string) string { + match := requiredPattern.FindStringSubmatch(token) + if len(match) != 2 { + return "" + } + path := strings.TrimSpace(match[1]) + if failure.Field == path { + return fmt.Sprint(failure.Expected) + } + return "" + }) + + return strings.TrimSpace(message) +} diff --git a/oauth/types.go b/oauth/types.go index 1b0e3646a..383e6f351 100644 --- a/oauth/types.go +++ b/oauth/types.go @@ -57,3 +57,12 @@ func NewOAuthErrorWithRaw(msgKey string, params map[string]any, rawError string) RawError: rawError, } } + +// AccessDeniedError is a direct user-facing access denial message. +type AccessDeniedError struct { + Message string +} + +func (e *AccessDeniedError) Error() string { + return e.Message +} diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index ff7606e2e..d2f7c6bb6 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -36,6 +36,32 @@ type TaskAdaptor interface { ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError + // ── Billing ────────────────────────────────────────────────────── + + // EstimateBilling returns OtherRatios for pre-charge based on user request. + // Called after ValidateRequestAndSetAction, before price calculation. + // Adaptors should extract duration, resolution, etc. from the parsed request + // and return them as ratio multipliers (e.g. {"seconds": 5, "size": 1.666}). + // Return nil to use the base model price without extra ratios. + EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 + + // AdjustBillingOnSubmit returns adjusted OtherRatios from the upstream + // submit response. Called after a successful DoResponse. + // If the upstream returned actual parameters that differ from the estimate + // (e.g. actual seconds), return updated ratios so the caller can recalculate + // the quota and settle the delta with the pre-charge. + // Return nil if no adjustment is needed. + AdjustBillingOnSubmit(info *relaycommon.RelayInfo, taskData []byte) map[string]float64 + + // AdjustBillingOnComplete returns the actual quota when a task reaches a + // terminal state (success/failure) during polling. + // Called by the polling loop after ParseTaskResult. + // Return a positive value to trigger delta settlement (supplement / refund). + // Return 0 to keep the pre-charged amount unchanged. + AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int + + // ── Request / Response ─────────────────────────────────────────── + BuildRequestURL(info *relaycommon.RelayInfo) (string, error) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) @@ -46,9 +72,9 @@ type TaskAdaptor interface { GetModelList() []string GetChannelName() string - // FetchTask - FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) + // ── Polling ────────────────────────────────────────────────────── + FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) } diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index dcdff584b..49773e1e6 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -61,8 +61,9 @@ var passthroughSkipHeaderNamesLower = map[string]struct{}{ "cookie": {}, // Additional headers that should not be forwarded by name-matching passthrough rules. - "host": {}, - "content-length": {}, + "host": {}, + "content-length": {}, + "accept-encoding": {}, // Do not passthrough credentials by wildcard/regex. "authorization": {}, diff --git a/relay/channel/api_request_test.go b/relay/channel/api_request_test.go index 31e15340a..791379b90 100644 --- a/relay/channel/api_request_test.go +++ b/relay/channel/api_request_test.go @@ -110,3 +110,30 @@ func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) { _, ok := headers["X-Legacy"] require.False(t, ok) } + +func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + ctx.Request.Header.Set("X-Trace-Id", "trace-123") + ctx.Request.Header.Set("Accept-Encoding", "gzip") + + info := &relaycommon.RelayInfo{ + IsChannelTest: false, + ChannelMeta: &relaycommon.ChannelMeta{ + HeadersOverride: map[string]any{ + "*": "", + }, + }, + } + + headers, err := processHeaderOverride(info, ctx) + require.NoError(t, err) + require.Equal(t, "trace-123", headers["X-Trace-Id"]) + + _, hasAcceptEncoding := headers["Accept-Encoding"] + require.False(t, hasAcceptEncoding) +} diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 39485b16f..1a434a432 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -42,22 +42,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re } // 计算使用量(基于 UsageMetadata) - usage := dto.Usage{ - PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount, - CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount, - TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount, - } - - usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount - usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount - - for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { - if detail.Modality == "AUDIO" { - usage.PromptTokensDetails.AudioTokens = detail.TokenCount - } else if detail.Modality == "TEXT" { - usage.PromptTokensDetails.TextTokens = detail.TokenCount - } - } + usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens()) service.IOCopyBytesGracefully(c, resp, responseBody) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index b10ec06c7..b81a148a3 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -1032,6 +1032,46 @@ func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse { } } +func buildUsageFromGeminiMetadata(metadata dto.GeminiUsageMetadata, fallbackPromptTokens int) dto.Usage { + promptTokens := metadata.PromptTokenCount + metadata.ToolUsePromptTokenCount + if promptTokens <= 0 && fallbackPromptTokens > 0 { + promptTokens = fallbackPromptTokens + } + + usage := dto.Usage{ + PromptTokens: promptTokens, + CompletionTokens: metadata.CandidatesTokenCount + metadata.ThoughtsTokenCount, + TotalTokens: metadata.TotalTokenCount, + } + usage.CompletionTokenDetails.ReasoningTokens = metadata.ThoughtsTokenCount + usage.PromptTokensDetails.CachedTokens = metadata.CachedContentTokenCount + + for _, detail := range metadata.PromptTokensDetails { + if detail.Modality == "AUDIO" { + usage.PromptTokensDetails.AudioTokens += detail.TokenCount + } else if detail.Modality == "TEXT" { + usage.PromptTokensDetails.TextTokens += detail.TokenCount + } + } + for _, detail := range metadata.ToolUsePromptTokensDetails { + if detail.Modality == "AUDIO" { + usage.PromptTokensDetails.AudioTokens += detail.TokenCount + } else if detail.Modality == "TEXT" { + usage.PromptTokensDetails.TextTokens += detail.TokenCount + } + } + + if usage.TotalTokens > 0 && usage.CompletionTokens <= 0 { + usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens + } + + if usage.PromptTokens > 0 && usage.PromptTokensDetails.TextTokens == 0 && usage.PromptTokensDetails.AudioTokens == 0 { + usage.PromptTokensDetails.TextTokens = usage.PromptTokens + } + + return usage +} + func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse { fullTextResponse := dto.OpenAITextResponse{ Id: helper.GetResponseID(c), @@ -1272,18 +1312,8 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http // 更新使用量统计 if geminiResponse.UsageMetadata.TotalTokenCount != 0 { - usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount - usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount - usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount - usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount - usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount - for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { - if detail.Modality == "AUDIO" { - usage.PromptTokensDetails.AudioTokens = detail.TokenCount - } else if detail.Modality == "TEXT" { - usage.PromptTokensDetails.TextTokens = detail.TokenCount - } - } + mappedUsage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens()) + *usage = mappedUsage } return callback(data, &geminiResponse) @@ -1295,11 +1325,6 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http } } - usage.PromptTokensDetails.TextTokens = usage.PromptTokens - if usage.TotalTokens > 0 { - usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens - } - if usage.CompletionTokens <= 0 { if info.ReceivedResponseCount > 0 { usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens()) @@ -1416,21 +1441,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if len(geminiResponse.Candidates) == 0 { - usage := dto.Usage{ - PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount, - } - usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount - usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount - for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { - if detail.Modality == "AUDIO" { - usage.PromptTokensDetails.AudioTokens = detail.TokenCount - } else if detail.Modality == "TEXT" { - usage.PromptTokensDetails.TextTokens = detail.TokenCount - } - } - if usage.PromptTokens <= 0 { - usage.PromptTokens = info.GetEstimatePromptTokens() - } + usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens()) var newAPIError *types.NewAPIError if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil { @@ -1466,23 +1477,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R } fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse) fullTextResponse.Model = info.UpstreamModelName - usage := dto.Usage{ - PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount, - CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount, - TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount, - } - - usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount - usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount - usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens - - for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { - if detail.Modality == "AUDIO" { - usage.PromptTokensDetails.AudioTokens = detail.TokenCount - } else if detail.Modality == "TEXT" { - usage.PromptTokensDetails.TextTokens = detail.TokenCount - } - } + usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens()) fullTextResponse.Usage = usage diff --git a/relay/channel/gemini/relay_gemini_usage_test.go b/relay/channel/gemini/relay_gemini_usage_test.go new file mode 100644 index 000000000..c8f9f8343 --- /dev/null +++ b/relay/channel/gemini/relay_gemini_usage_test.go @@ -0,0 +1,333 @@ +package gemini + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/dto" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/types" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + info := &relaycommon.RelayInfo{ + RelayFormat: types.RelayFormatGemini, + OriginModelName: "gemini-3-flash-preview", + ChannelMeta: &relaycommon.ChannelMeta{ + UpstreamModelName: "gemini-3-flash-preview", + }, + } + + payload := dto.GeminiChatResponse{ + Candidates: []dto.GeminiChatCandidate{ + { + Content: dto.GeminiChatContent{ + Role: "model", + Parts: []dto.GeminiPart{ + {Text: "ok"}, + }, + }, + }, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: 151, + ToolUsePromptTokenCount: 18329, + CandidatesTokenCount: 1089, + ThoughtsTokenCount: 1120, + TotalTokenCount: 20689, + }, + } + + body, err := common.Marshal(payload) + require.NoError(t, err) + + resp := &http.Response{ + Body: io.NopCloser(bytes.NewReader(body)), + } + + usage, newAPIError := GeminiChatHandler(c, info, resp) + require.Nil(t, newAPIError) + require.NotNil(t, usage) + require.Equal(t, 18480, usage.PromptTokens) + require.Equal(t, 2209, usage.CompletionTokens) + require.Equal(t, 20689, usage.TotalTokens) + require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens) +} + +func TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) { + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + oldStreamingTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 300 + t.Cleanup(func() { + constant.StreamingTimeout = oldStreamingTimeout + }) + + info := &relaycommon.RelayInfo{ + OriginModelName: "gemini-3-flash-preview", + ChannelMeta: &relaycommon.ChannelMeta{ + UpstreamModelName: "gemini-3-flash-preview", + }, + } + + chunk := dto.GeminiChatResponse{ + Candidates: []dto.GeminiChatCandidate{ + { + Content: dto.GeminiChatContent{ + Role: "model", + Parts: []dto.GeminiPart{ + {Text: "partial"}, + }, + }, + }, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: 151, + ToolUsePromptTokenCount: 18329, + CandidatesTokenCount: 1089, + ThoughtsTokenCount: 1120, + TotalTokenCount: 20689, + }, + } + + chunkData, err := common.Marshal(chunk) + require.NoError(t, err) + + streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n") + resp := &http.Response{ + Body: io.NopCloser(bytes.NewReader(streamBody)), + } + + usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool { + return true + }) + require.Nil(t, newAPIError) + require.NotNil(t, usage) + require.Equal(t, 18480, usage.PromptTokens) + require.Equal(t, 2209, usage.CompletionTokens) + require.Equal(t, 20689, usage.TotalTokens) + require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens) +} + +func TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil) + + info := &relaycommon.RelayInfo{ + OriginModelName: "gemini-3-flash-preview", + ChannelMeta: &relaycommon.ChannelMeta{ + UpstreamModelName: "gemini-3-flash-preview", + }, + } + + payload := dto.GeminiChatResponse{ + Candidates: []dto.GeminiChatCandidate{ + { + Content: dto.GeminiChatContent{ + Role: "model", + Parts: []dto.GeminiPart{ + {Text: "ok"}, + }, + }, + }, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: 151, + ToolUsePromptTokenCount: 18329, + CandidatesTokenCount: 1089, + ThoughtsTokenCount: 1120, + TotalTokenCount: 20689, + }, + } + + body, err := common.Marshal(payload) + require.NoError(t, err) + + resp := &http.Response{ + Body: io.NopCloser(bytes.NewReader(body)), + } + + usage, newAPIError := GeminiTextGenerationHandler(c, info, resp) + require.Nil(t, newAPIError) + require.NotNil(t, usage) + require.Equal(t, 18480, usage.PromptTokens) + require.Equal(t, 2209, usage.CompletionTokens) + require.Equal(t, 20689, usage.TotalTokens) + require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens) +} + +func TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + info := &relaycommon.RelayInfo{ + RelayFormat: types.RelayFormatGemini, + OriginModelName: "gemini-3-flash-preview", + ChannelMeta: &relaycommon.ChannelMeta{ + UpstreamModelName: "gemini-3-flash-preview", + }, + } + info.SetEstimatePromptTokens(20) + + payload := dto.GeminiChatResponse{ + Candidates: []dto.GeminiChatCandidate{ + { + Content: dto.GeminiChatContent{ + Role: "model", + Parts: []dto.GeminiPart{ + {Text: "ok"}, + }, + }, + }, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: 0, + ToolUsePromptTokenCount: 0, + CandidatesTokenCount: 90, + ThoughtsTokenCount: 10, + TotalTokenCount: 110, + }, + } + + body, err := common.Marshal(payload) + require.NoError(t, err) + + resp := &http.Response{ + Body: io.NopCloser(bytes.NewReader(body)), + } + + usage, newAPIError := GeminiChatHandler(c, info, resp) + require.Nil(t, newAPIError) + require.NotNil(t, usage) + require.Equal(t, 20, usage.PromptTokens) + require.Equal(t, 100, usage.CompletionTokens) + require.Equal(t, 110, usage.TotalTokens) +} + +func TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) { + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + oldStreamingTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 300 + t.Cleanup(func() { + constant.StreamingTimeout = oldStreamingTimeout + }) + + info := &relaycommon.RelayInfo{ + OriginModelName: "gemini-3-flash-preview", + ChannelMeta: &relaycommon.ChannelMeta{ + UpstreamModelName: "gemini-3-flash-preview", + }, + } + info.SetEstimatePromptTokens(20) + + chunk := dto.GeminiChatResponse{ + Candidates: []dto.GeminiChatCandidate{ + { + Content: dto.GeminiChatContent{ + Role: "model", + Parts: []dto.GeminiPart{ + {Text: "partial"}, + }, + }, + }, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: 0, + ToolUsePromptTokenCount: 0, + CandidatesTokenCount: 90, + ThoughtsTokenCount: 10, + TotalTokenCount: 110, + }, + } + + chunkData, err := common.Marshal(chunk) + require.NoError(t, err) + + streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n") + resp := &http.Response{ + Body: io.NopCloser(bytes.NewReader(streamBody)), + } + + usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool { + return true + }) + require.Nil(t, newAPIError) + require.NotNil(t, usage) + require.Equal(t, 20, usage.PromptTokens) + require.Equal(t, 100, usage.CompletionTokens) + require.Equal(t, 110, usage.TotalTokens) +} + +func TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil) + + info := &relaycommon.RelayInfo{ + OriginModelName: "gemini-3-flash-preview", + ChannelMeta: &relaycommon.ChannelMeta{ + UpstreamModelName: "gemini-3-flash-preview", + }, + } + info.SetEstimatePromptTokens(20) + + payload := dto.GeminiChatResponse{ + Candidates: []dto.GeminiChatCandidate{ + { + Content: dto.GeminiChatContent{ + Role: "model", + Parts: []dto.GeminiPart{ + {Text: "ok"}, + }, + }, + }, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: 0, + ToolUsePromptTokenCount: 0, + CandidatesTokenCount: 90, + ThoughtsTokenCount: 10, + TotalTokenCount: 110, + }, + } + + body, err := common.Marshal(payload) + require.NoError(t, err) + + resp := &http.Response{ + Body: io.NopCloser(bytes.NewReader(body)), + } + + usage, newAPIError := GeminiTextGenerationHandler(c, info, resp) + require.Nil(t, newAPIError) + require.NotNil(t, usage) + require.Equal(t, 20, usage.PromptTokens) + require.Equal(t, 100, usage.CompletionTokens) + require.Equal(t, 110, usage.TotalTokens) +} diff --git a/relay/channel/minimax/adaptor.go b/relay/channel/minimax/adaptor.go index 8235abc05..d244e695a 100644 --- a/relay/channel/minimax/adaptor.go +++ b/relay/channel/minimax/adaptor.go @@ -10,6 +10,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + "github.com/QuantumNous/new-api/relay/channel/claude" "github.com/QuantumNous/new-api/relay/channel/openai" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/relay/constant" @@ -26,7 +27,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt } func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { - return nil, errors.New("not implemented") + adaptor := claude.Adaptor{} + return adaptor.ConvertClaudeRequest(c, info, req) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -119,8 +121,14 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom return handleTTSResponse(c, resp, info) } - adaptor := openai.Adaptor{} - return adaptor.DoResponse(c, resp, info) + switch info.RelayFormat { + case types.RelayFormatClaude: + adaptor := claude.Adaptor{} + return adaptor.DoResponse(c, resp, info) + default: + adaptor := openai.Adaptor{} + return adaptor.DoResponse(c, resp, info) + } } func (a *Adaptor) GetModelList() []string { diff --git a/relay/channel/minimax/relay-minimax.go b/relay/channel/minimax/relay-minimax.go index b314e69d7..c249de6a4 100644 --- a/relay/channel/minimax/relay-minimax.go +++ b/relay/channel/minimax/relay-minimax.go @@ -6,6 +6,7 @@ import ( channelconstant "github.com/QuantumNous/new-api/constant" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/relay/constant" + "github.com/QuantumNous/new-api/types" ) func GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -13,13 +14,17 @@ func GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if baseUrl == "" { baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeMiniMax] } - - switch info.RelayMode { - case constant.RelayModeChatCompletions: - return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil - case constant.RelayModeAudioSpeech: - return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil + switch info.RelayFormat { + case types.RelayFormatClaude: + return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil default: - return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) + switch info.RelayMode { + case constant.RelayModeChatCompletions: + return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil + case constant.RelayModeAudioSpeech: + return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil + default: + return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) + } } } diff --git a/relay/channel/task/ali/adaptor.go b/relay/channel/task/ali/adaptor.go index d55452c08..f698fc9f6 100644 --- a/relay/channel/task/ali/adaptor.go +++ b/relay/channel/task/ali/adaptor.go @@ -13,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" "github.com/samber/lo" @@ -108,10 +109,10 @@ type AliMetadata struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string - aliReq *AliVideoRequest } func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { @@ -121,17 +122,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { - // 阿里通义万相支持 JSON 格式,不使用 multipart - var taskReq relaycommon.TaskSubmitReq - if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil { - return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest) - } - aliReq, err := a.convertToAliRequest(info, taskReq) - if err != nil { - return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError) - } - a.aliReq = aliReq - logger.LogJson(c, "ali video request body", aliReq) + // ValidateMultipartDirect 负责解析并将原始 TaskSubmitReq 存入 context return relaycommon.ValidateMultipartDirect(c, info) } @@ -148,11 +139,21 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info } func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { - bodyBytes, err := common.Marshal(a.aliReq) + taskReq, err := relaycommon.GetTaskRequest(c) + if err != nil { + return nil, errors.Wrap(err, "get_task_request_failed") + } + + aliReq, err := a.convertToAliRequest(info, taskReq) + if err != nil { + return nil, errors.Wrap(err, "convert_to_ali_request_failed") + } + logger.LogJson(c, "ali video request body", aliReq) + + bodyBytes, err := common.Marshal(aliReq) if err != nil { return nil, errors.Wrap(err, "marshal_ali_request_failed") } - return bytes.NewReader(bodyBytes), nil } @@ -252,8 +253,12 @@ func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error) } func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) { + upstreamModel := req.Model + if info.IsModelMapped { + upstreamModel = info.UpstreamModelName + } aliReq := &AliVideoRequest{ - Model: req.Model, + Model: upstreamModel, Input: AliVideoInput{ Prompt: req.Prompt, ImgURL: req.InputReference, @@ -331,23 +336,37 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay } } - if aliReq.Model != req.Model { + if aliReq.Model != upstreamModel { return nil, errors.New("can't change model with metadata") } - info.PriceData.OtherRatios = map[string]float64{ + return aliReq, nil +} + +// EstimateBilling 根据用户请求参数计算 OtherRatios(时长、分辨率等)。 +// 在 ValidateRequestAndSetAction 之后、价格计算之前调用。 +func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 { + taskReq, err := relaycommon.GetTaskRequest(c) + if err != nil { + return nil + } + + aliReq, err := a.convertToAliRequest(info, taskReq) + if err != nil { + return nil + } + + otherRatios := map[string]float64{ "seconds": float64(aliReq.Parameters.Duration), } - ratios, err := ProcessAliOtherRatios(aliReq) if err != nil { - return nil, err + return otherRatios } - for s, f := range ratios { - info.PriceData.OtherRatios[s] = f + for k, v := range ratios { + otherRatios[k] = v } - - return aliReq, nil + return otherRatios } // DoRequest delegates to common helper @@ -384,7 +403,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela // 转换为 OpenAI 格式响应 openAIResp := dto.NewOpenAIVideo() - openAIResp.ID = aliResp.Output.TaskID + openAIResp.ID = info.PublicTaskID + openAIResp.TaskID = info.PublicTaskID openAIResp.Model = c.GetString("model") if openAIResp.Model == "" && info != nil { openAIResp.Model = info.OriginModelName diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go index 6ebecb3c0..8f1d748ce 100644 --- a/relay/channel/task/doubao/adaptor.go +++ b/relay/channel/task/doubao/adaptor.go @@ -2,7 +2,6 @@ package doubao import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -14,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -89,6 +89,7 @@ type responseTask struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string @@ -130,8 +131,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } - info.UpstreamModelName = body.Model - data, err := json.Marshal(body) + if info.IsModelMapped { + body.Model = info.UpstreamModelName + } else { + info.UpstreamModelName = body.Model + } + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -154,7 +159,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela // Parse Doubao response var dResp responsePayload - if err := json.Unmarshal(responseBody, &dResp); err != nil { + if err := common.Unmarshal(responseBody, &dResp); err != nil { taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) return } @@ -165,8 +170,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = dResp.ID - ov.TaskID = dResp.ID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName @@ -234,12 +239,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* } metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } @@ -248,7 +248,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { resTask := responseTask{} - if err := json.Unmarshal(respBody, &resTask); err != nil { + if err := common.Unmarshal(respBody, &resTask); err != nil { return nil, errors.Wrap(err, "unmarshal task result failed") } @@ -286,7 +286,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var dResp responseTask - if err := json.Unmarshal(originTask.Data, &dResp); err != nil { + if err := common.Unmarshal(originTask.Data, &dResp); err != nil { return nil, errors.Wrap(err, "unmarshal doubao task data failed") } @@ -307,6 +307,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } diff --git a/relay/channel/task/gemini/adaptor.go b/relay/channel/task/gemini/adaptor.go index 16c6919b7..5644cd5dc 100644 --- a/relay/channel/task/gemini/adaptor.go +++ b/relay/channel/task/gemini/adaptor.go @@ -2,8 +2,6 @@ package gemini import ( "bytes" - "encoding/base64" - "encoding/json" "fmt" "io" "net/http" @@ -16,10 +14,10 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/setting/model_setting" - "github.com/QuantumNous/new-api/setting/system_setting" "github.com/gin-gonic/gin" "github.com/pkg/errors" ) @@ -87,6 +85,7 @@ type operationResponse struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string @@ -106,7 +105,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { - modelName := info.OriginModelName + modelName := info.UpstreamModelName version := model_setting.GetGeminiVersionSetting(modelName) return fmt.Sprintf( @@ -145,16 +144,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &body.Parameters) - if err != nil { + if err := taskcommon.UnmarshalMetadata(metadata, &body.Parameters); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -175,16 +169,16 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela _ = resp.Body.Close() var s submitResponse - if err := json.Unmarshal(responseBody, &s); err != nil { + if err := common.Unmarshal(responseBody, &s); err != nil { return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) } if strings.TrimSpace(s.Name) == "" { return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError) } - taskID = encodeLocalTaskID(s.Name) + taskID = taskcommon.EncodeLocalTaskID(s.Name) ov := dto.NewOpenAIVideo() - ov.ID = taskID - ov.TaskID = taskID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -206,7 +200,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy return nil, fmt.Errorf("invalid task_id") } - upstreamName, err := decodeLocalTaskID(taskID) + upstreamName, err := taskcommon.DecodeLocalTaskID(taskID) if err != nil { return nil, fmt.Errorf("decode task_id failed: %w", err) } @@ -232,7 +226,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { var op operationResponse - if err := json.Unmarshal(respBody, &op); err != nil { + if err := common.Unmarshal(respBody, &op); err != nil { return nil, fmt.Errorf("unmarshal operation response failed: %w", err) } @@ -254,9 +248,8 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e ti.Status = model.TaskStatusSuccess ti.Progress = "100%" - taskID := encodeLocalTaskID(op.Name) - ti.TaskID = taskID - ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID) + ti.TaskID = taskcommon.EncodeLocalTaskID(op.Name) + // Url intentionally left empty — the caller constructs the proxy URL using the public task ID // Extract URL from generateVideoResponse if available if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 { @@ -269,7 +262,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { - upstreamName, err := decodeLocalTaskID(task.TaskID) + // Use GetUpstreamTaskID() to get the real upstream operation name for model extraction. + // task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name. + upstreamTaskID := task.GetUpstreamTaskID() + upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID) if err != nil { upstreamName = "" } @@ -297,18 +293,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { // helpers // ============================ -func encodeLocalTaskID(name string) string { - return base64.RawURLEncoding.EncodeToString([]byte(name)) -} - -func decodeLocalTaskID(local string) (string, error) { - b, err := base64.RawURLEncoding.DecodeString(local) - if err != nil { - return "", err - } - return string(b), nil -} - var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`) func extractModelFromOperationName(name string) string { diff --git a/relay/channel/task/hailuo/adaptor.go b/relay/channel/task/hailuo/adaptor.go index c77905bfb..28b3a97f1 100644 --- a/relay/channel/task/hailuo/adaptor.go +++ b/relay/channel/task/hailuo/adaptor.go @@ -2,7 +2,6 @@ package hailuo import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -18,12 +17,14 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" ) // https://platform.minimaxi.com/docs/api-reference/video-generation-intro type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string @@ -60,12 +61,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, fmt.Errorf("invalid request type in context") } - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -86,7 +87,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela _ = resp.Body.Close() var hResp VideoResponse - if err := json.Unmarshal(responseBody, &hResp); err != nil { + if err := common.Unmarshal(responseBody, &hResp); err != nil { taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) return } @@ -101,8 +102,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = hResp.TaskID - ov.TaskID = hResp.TaskID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName @@ -141,8 +142,8 @@ func (a *TaskAdaptor) GetChannelName() string { return ChannelName } -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) { - modelConfig := GetModelConfig(req.Model) +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*VideoRequest, error) { + modelConfig := GetModelConfig(info.UpstreamModelName) duration := DefaultDuration if req.Duration > 0 { duration = req.Duration @@ -153,7 +154,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* } videoRequest := &VideoRequest{ - Model: req.Model, + Model: info.UpstreamModelName, Prompt: req.Prompt, Duration: &duration, Resolution: resolution, @@ -182,7 +183,7 @@ func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConf func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { resTask := QueryTaskResponse{} - if err := json.Unmarshal(respBody, &resTask); err != nil { + if err := common.Unmarshal(respBody, &resTask); err != nil { return nil, errors.Wrap(err, "unmarshal task result failed") } @@ -224,7 +225,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var hailuoResp QueryTaskResponse - if err := json.Unmarshal(originTask.Data, &hailuoResp); err != nil { + if err := common.Unmarshal(originTask.Data, &hailuoResp); err != nil { return nil, errors.Wrap(err, "unmarshal hailuo task data failed") } @@ -271,7 +272,7 @@ func (a *TaskAdaptor) buildVideoURL(_, fileID string) string { } var retrieveResp RetrieveFileResponse - if err := json.Unmarshal(responseBody, &retrieveResp); err != nil { + if err := common.Unmarshal(responseBody, &retrieveResp); err != nil { return "" } diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 1522a967f..e6211b1e4 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -6,7 +6,6 @@ import ( "crypto/sha256" "encoding/base64" "encoding/hex" - "encoding/json" "fmt" "io" "net/http" @@ -25,6 +24,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" ) @@ -77,6 +77,7 @@ const ( // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int accessKey string secretKey string @@ -164,11 +165,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } } - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -191,7 +192,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela // Parse Jimeng response var jResp responsePayload - if err := json.Unmarshal(responseBody, &jResp); err != nil { + if err := common.Unmarshal(responseBody, &jResp); err != nil { taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) return } @@ -202,8 +203,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = jResp.Data.TaskID - ov.TaskID = jResp.Data.TaskID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -225,7 +226,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774 "task_id": taskID, } - payloadBytes, err := json.Marshal(payload) + payloadBytes, err := common.Marshal(payload) if err != nil { return nil, errors.Wrap(err, "marshal fetch task payload failed") } @@ -377,9 +378,9 @@ func hmacSHA256(key []byte, data []byte) []byte { return h.Sum(nil) } -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) { r := requestPayload{ - ReqKey: req.Model, + ReqKey: info.UpstreamModelName, Prompt: req.Prompt, } @@ -398,13 +399,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* r.BinaryDataBase64 = req.Images } } - metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } @@ -432,7 +427,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { resTask := responseTask{} - if err := json.Unmarshal(respBody, &resTask); err != nil { + if err := common.Unmarshal(respBody, &resTask); err != nil { return nil, errors.Wrap(err, "unmarshal task result failed") } taskResult := relaycommon.TaskInfo{} @@ -458,7 +453,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var jimengResp responseTask - if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil { + if err := common.Unmarshal(originTask.Data, &jimengResp); err != nil { return nil, errors.Wrap(err, "unmarshal jimeng task data failed") } @@ -477,8 +472,7 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } func isNewAPIRelay(apiKey string) bool { diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 5fb853481..cdbb56878 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -2,7 +2,6 @@ package kling import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -21,6 +20,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" ) @@ -97,6 +97,7 @@ type responsePayload struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string @@ -149,14 +150,14 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } req := v.(relaycommon.TaskSubmitReq) - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, err } if body.Image == "" && body.ImageTail == "" { c.Set("action", constant.TaskActionTextGenerate) } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -180,7 +181,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } var kResp responsePayload - err = json.Unmarshal(responseBody, &kResp) + err = common.Unmarshal(responseBody, &kResp) if err != nil { taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) return @@ -190,8 +191,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return } ov := dto.NewOpenAIVideo() - ov.ID = kResp.Data.TaskId - ov.TaskID = kResp.Data.TaskId + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -247,15 +248,15 @@ func (a *TaskAdaptor) GetChannelName() string { // helpers // ============================ -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) { r := requestPayload{ Prompt: req.Prompt, Image: req.Image, - Mode: defaultString(req.Mode, "std"), - Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)), + Mode: taskcommon.DefaultString(req.Mode, "std"), + Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)), AspectRatio: a.getAspectRatio(req.Size), - ModelName: req.Model, - Model: req.Model, // Keep consistent with model_name, double writing improves compatibility + ModelName: info.UpstreamModelName, + Model: info.UpstreamModelName, CfgScale: 0.5, StaticMask: "", DynamicMasks: []DynamicMask{}, @@ -265,14 +266,9 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* } if r.ModelName == "" { r.ModelName = "kling-v1" + r.Model = "kling-v1" } - metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } return &r, nil @@ -291,20 +287,6 @@ func (a *TaskAdaptor) getAspectRatio(size string) string { } } -func defaultString(s, def string) string { - if strings.TrimSpace(s) == "" { - return def - } - return s -} - -func defaultInt(v int, def int) int { - if v == 0 { - return def - } - return v -} - // ============================ // JWT helpers // ============================ @@ -340,7 +322,7 @@ func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) { func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { taskInfo := &relaycommon.TaskInfo{} resPayload := responsePayload{} - err := json.Unmarshal(respBody, &resPayload) + err := common.Unmarshal(respBody, &resPayload) if err != nil { return nil, errors.Wrap(err, "failed to unmarshal response body") } @@ -374,7 +356,7 @@ func isNewAPIRelay(apiKey string) bool { func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var klingResp responsePayload - if err := json.Unmarshal(originTask.Data, &klingResp); err != nil { + if err := common.Unmarshal(originTask.Data, &klingResp); err != nil { return nil, errors.Wrap(err, "unmarshal kling task data failed") } @@ -401,6 +383,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro Code: fmt.Sprintf("%d", klingResp.Code), } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index c149f9663..e9029aa20 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -1,9 +1,13 @@ package sora import ( + "bytes" "fmt" "io" + "mime/multipart" "net/http" + "net/textproto" + "strconv" "strings" "github.com/QuantumNous/new-api/common" @@ -11,12 +15,13 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" - "github.com/QuantumNous/new-api/setting/system_setting" "github.com/gin-gonic/gin" "github.com/pkg/errors" + "github.com/tidwall/sjson" ) // ============================ @@ -57,6 +62,7 @@ type responseTask struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string @@ -69,15 +75,15 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func validateRemixRequest(c *gin.Context) *dto.TaskError { - var req struct { - Prompt string `json:"prompt"` - } + var req relaycommon.TaskSubmitReq if err := common.UnmarshalBodyReusable(c, &req); err != nil { return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) } if strings.TrimSpace(req.Prompt) == "" { return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest) } + // 存储原始请求到 context,与 ValidateMultipartDirect 路径保持一致 + c.Set("task_request", req) return nil } @@ -88,6 +94,41 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom return relaycommon.ValidateMultipartDirect(c, info) } +// EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。 +func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 { + // remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置 + if info.Action == constant.TaskActionRemix { + return nil + } + + req, err := relaycommon.GetTaskRequest(c) + if err != nil { + return nil + } + + seconds, _ := strconv.Atoi(req.Seconds) + if seconds == 0 { + seconds = req.Duration + } + if seconds <= 0 { + seconds = 4 + } + + size := req.Size + if size == "" { + size = "720x1280" + } + + ratios := map[string]float64{ + "seconds": float64(seconds), + "size": 1, + } + if size == "1792x1024" || size == "1024x1792" { + ratios["size"] = 1.666667 + } + return ratios +} + func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.Action == constant.TaskActionRemix { return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil @@ -107,6 +148,74 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if err != nil { return nil, errors.Wrap(err, "get_request_body_failed") } + cachedBody, err := storage.Bytes() + if err != nil { + return nil, errors.Wrap(err, "read_body_bytes_failed") + } + contentType := c.GetHeader("Content-Type") + + if strings.HasPrefix(contentType, "application/json") { + var bodyMap map[string]interface{} + if err := common.Unmarshal(cachedBody, &bodyMap); err == nil { + bodyMap["model"] = info.UpstreamModelName + if newBody, err := common.Marshal(bodyMap); err == nil { + return bytes.NewReader(newBody), nil + } + } + return bytes.NewReader(cachedBody), nil + } + + if strings.Contains(contentType, "multipart/form-data") { + formData, err := common.ParseMultipartFormReusable(c) + if err != nil { + return bytes.NewReader(cachedBody), nil + } + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + writer.WriteField("model", info.UpstreamModelName) + for key, values := range formData.Value { + if key == "model" { + continue + } + for _, v := range values { + writer.WriteField(key, v) + } + } + for fieldName, fileHeaders := range formData.File { + for _, fh := range fileHeaders { + f, err := fh.Open() + if err != nil { + continue + } + ct := fh.Header.Get("Content-Type") + if ct == "" || ct == "application/octet-stream" { + buf512 := make([]byte, 512) + n, _ := io.ReadFull(f, buf512) + ct = http.DetectContentType(buf512[:n]) + // Re-open after sniffing so the full content is copied below + f.Close() + f, err = fh.Open() + if err != nil { + continue + } + } + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fh.Filename)) + h.Set("Content-Type", ct) + part, err := writer.CreatePart(h) + if err != nil { + f.Close() + continue + } + io.Copy(part, f) + f.Close() + } + } + writer.Close() + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + return &buf, nil + } + return common.ReaderOnly(storage), nil } @@ -116,7 +225,7 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, req } // DoResponse handles upstream response, returns taskID etc. -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -131,17 +240,20 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco return } - if dResp.ID == "" { - if dResp.TaskID == "" { - taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError) - return - } - dResp.ID = dResp.TaskID - dResp.TaskID = "" + upstreamID := dResp.ID + if upstreamID == "" { + upstreamID = dResp.TaskID + } + if upstreamID == "" { + taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError) + return } + // 使用公开 task_xxxx ID 返回给客户端 + dResp.ID = info.PublicTaskID + dResp.TaskID = info.PublicTaskID c.JSON(http.StatusOK, dResp) - return dResp.ID, responseBody, nil + return upstreamID, responseBody, nil } // FetchTask fetch task status @@ -192,7 +304,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e taskResult.Status = model.TaskStatusInProgress case "completed": taskResult.Status = model.TaskStatusSuccess - taskResult.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, resTask.ID) + // Url intentionally left empty — the caller constructs the proxy URL using the public task ID case "failed", "cancelled": taskResult.Status = model.TaskStatusFailure if resTask.Error != nil { @@ -210,5 +322,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { - return task.Data, nil + data := task.Data + var err error + if data, err = sjson.SetBytes(data, "id", task.TaskID); err != nil { + return nil, errors.Wrap(err, "set id failed") + } + return data, nil } diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 8ea9a1c7f..35b5e423b 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -2,18 +2,16 @@ package suno import ( "bytes" - "context" - "encoding/json" "fmt" "io" "net/http" "strings" - "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -21,11 +19,16 @@ import ( ) type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int } +// ParseTaskResult is not used for Suno tasks. +// Suno polling uses a dedicated batch-fetch path (service.UpdateSunoTasks) that +// receives dto.TaskResponse[[]dto.SunoDataResponse] from the upstream /fetch API. +// This differs from the per-task polling used by video adaptors. func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { - return nil, fmt.Errorf("not implement") // todo implement this method if needed + return nil, fmt.Errorf("suno uses batch polling via UpdateSunoTasks, ParseTaskResult is not applicable") } func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { @@ -47,13 +50,13 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom return } - if sunoRequest.ContinueClipId != "" { - if sunoRequest.TaskID == "" { - taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest) - return - } - info.OriginTaskID = sunoRequest.TaskID - } + //if sunoRequest.ContinueClipId != "" { + // if sunoRequest.TaskID == "" { + // taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest) + // return + // } + // info.OriginTaskID = sunoRequest.TaskID + //} info.Action = action c.Set("task_request", sunoRequest) @@ -76,12 +79,9 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { sunoRequest, ok := c.Get("task_request") if !ok { - err := common.UnmarshalBodyReusable(c, &sunoRequest) - if err != nil { - return nil, err - } + return nil, fmt.Errorf("task_request not found in context") } - data, err := json.Marshal(sunoRequest) + data, err := common.Marshal(sunoRequest) if err != nil { return nil, err } @@ -99,7 +99,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return } var sunoResponse dto.TaskResponse[string] - err = json.Unmarshal(responseBody, &sunoResponse) + err = common.Unmarshal(responseBody, &sunoResponse) if err != nil { taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) return @@ -109,17 +109,13 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return } - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - - _, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody)) - if err != nil { - taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - return + // 使用公开 task_xxxx ID 替换上游 ID 返回给客户端 + publicResponse := dto.TaskResponse[string]{ + Code: sunoResponse.Code, + Message: sunoResponse.Message, + Data: info.PublicTaskID, } + c.JSON(http.StatusOK, publicResponse) return sunoResponse.Data, nil, nil } @@ -134,7 +130,7 @@ func (a *TaskAdaptor) GetChannelName() string { func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl) - byteBody, err := json.Marshal(body) + byteBody, err := common.Marshal(body) if err != nil { return nil, err } @@ -144,13 +140,6 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy common.SysLog(fmt.Sprintf("Get Task error: %v", err)) return nil, err } - defer req.Body.Close() - // 设置超时时间 - timeout := time.Second * 15 - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - // 使用带有超时的 context 创建新的请求 - req = req.WithContext(ctx) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+key) client, err := service.GetHttpClientWithProxy(proxy) diff --git a/relay/channel/task/taskcommon/helpers.go b/relay/channel/task/taskcommon/helpers.go new file mode 100644 index 000000000..27d6612d4 --- /dev/null +++ b/relay/channel/task/taskcommon/helpers.go @@ -0,0 +1,95 @@ +package taskcommon + +import ( + "encoding/base64" + "fmt" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/setting/system_setting" + "github.com/gin-gonic/gin" +) + +// UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip. +// This replaces the repeated pattern: json.Marshal(metadata) → json.Unmarshal(bytes, &target). +func UnmarshalMetadata(metadata map[string]any, target any) error { + if metadata == nil { + return nil + } + metaBytes, err := common.Marshal(metadata) + if err != nil { + return fmt.Errorf("marshal metadata failed: %w", err) + } + if err := common.Unmarshal(metaBytes, target); err != nil { + return fmt.Errorf("unmarshal metadata failed: %w", err) + } + return nil +} + +// DefaultString returns val if non-empty, otherwise fallback. +func DefaultString(val, fallback string) string { + if val == "" { + return fallback + } + return val +} + +// DefaultInt returns val if non-zero, otherwise fallback. +func DefaultInt(val, fallback int) int { + if val == 0 { + return fallback + } + return val +} + +// EncodeLocalTaskID encodes an upstream operation name to a URL-safe base64 string. +// Used by Gemini/Vertex to store upstream names as task IDs. +func EncodeLocalTaskID(name string) string { + return base64.RawURLEncoding.EncodeToString([]byte(name)) +} + +// DecodeLocalTaskID decodes a base64-encoded upstream operation name. +func DecodeLocalTaskID(id string) (string, error) { + b, err := base64.RawURLEncoding.DecodeString(id) + if err != nil { + return "", err + } + return string(b), nil +} + +// BuildProxyURL constructs the video proxy URL using the public task ID. +// e.g., "https://your-server.com/v1/videos/task_xxxx/content" +func BuildProxyURL(taskID string) string { + return fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID) +} + +// Status-to-progress mapping constants for polling updates. +const ( + ProgressSubmitted = "10%" + ProgressQueued = "20%" + ProgressInProgress = "30%" + ProgressComplete = "100%" +) + +// --------------------------------------------------------------------------- +// BaseBilling — embeddable no-op implementations for TaskAdaptor billing methods. +// Adaptors that do not need custom billing can embed this struct directly. +// --------------------------------------------------------------------------- + +type BaseBilling struct{} + +// EstimateBilling returns nil (no extra ratios; use base model price). +func (BaseBilling) EstimateBilling(_ *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 { + return nil +} + +// AdjustBillingOnSubmit returns nil (no submit-time adjustment). +func (BaseBilling) AdjustBillingOnSubmit(_ *relaycommon.RelayInfo, _ []byte) map[string]float64 { + return nil +} + +// AdjustBillingOnComplete returns 0 (keep pre-charged amount). +func (BaseBilling) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int { + return 0 +} diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index 8ec77266e..700e60976 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -2,13 +2,12 @@ package vertex import ( "bytes" - "encoding/base64" - "encoding/json" "fmt" "io" "net/http" "regexp" "strings" + "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" @@ -17,6 +16,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" vertexcore "github.com/QuantumNous/new-api/relay/channel/vertex" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -62,6 +62,7 @@ type operationResponse struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string @@ -82,10 +83,10 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { + if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil { return "", fmt.Errorf("failed to decode credentials: %w", err) } - modelName := info.OriginModelName + modelName := info.UpstreamModelName if modelName == "" { modelName = "veo-3.0-generate-001" } @@ -116,7 +117,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info req.Header.Set("Accept", "application/json") adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { + if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil { return fmt.Errorf("failed to decode credentials: %w", err) } @@ -133,6 +134,28 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info return nil } +// EstimateBilling 根据用户请求中的 sampleCount 计算 OtherRatios。 +func (a *TaskAdaptor) EstimateBilling(c *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 { + sampleCount := 1 + v, ok := c.Get("task_request") + if ok { + req := v.(relaycommon.TaskSubmitReq) + if req.Metadata != nil { + if sc, exists := req.Metadata["sampleCount"]; exists { + if i, ok := sc.(int); ok && i > 0 { + sampleCount = i + } + if f, ok := sc.(float64); ok && int(f) > 0 { + sampleCount = int(f) + } + } + } + } + return map[string]float64{ + "sampleCount": float64(sampleCount), + } +} + // BuildRequestBody converts request into Vertex specific format. func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, ok := c.Get("task_request") @@ -166,25 +189,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, fmt.Errorf("sampleCount must be greater than 0") } - // if req.Duration > 0 { - // body.Parameters["durationSeconds"] = req.Duration - // } else if req.Seconds != "" { - // seconds, err := strconv.Atoi(req.Seconds) - // if err != nil { - // return nil, errors.Wrap(err, "convert seconds to int failed") - // } - // body.Parameters["durationSeconds"] = seconds - // } - - info.PriceData.OtherRatios = map[string]float64{ - "sampleCount": float64(body.Parameters["sampleCount"].(int)), - } - - // if v, ok := body.Parameters["durationSeconds"]; ok { - // info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int)) - // } - - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -205,14 +210,19 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela _ = resp.Body.Close() var s submitResponse - if err := json.Unmarshal(responseBody, &s); err != nil { + if err := common.Unmarshal(responseBody, &s); err != nil { return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) } if strings.TrimSpace(s.Name) == "" { return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError) } - localID := encodeLocalTaskID(s.Name) - c.JSON(http.StatusOK, gin.H{"task_id": localID}) + localID := taskcommon.EncodeLocalTaskID(s.Name) + ov := dto.NewOpenAIVideo() + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID + ov.CreatedAt = time.Now().Unix() + ov.Model = info.OriginModelName + c.JSON(http.StatusOK, ov) return localID, responseBody, nil } @@ -225,7 +235,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy if !ok { return nil, fmt.Errorf("invalid task_id") } - upstreamName, err := decodeLocalTaskID(taskID) + upstreamName, err := taskcommon.DecodeLocalTaskID(taskID) if err != nil { return nil, fmt.Errorf("decode task_id failed: %w", err) } @@ -245,12 +255,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName) } payload := map[string]string{"operationName": upstreamName} - data, err := json.Marshal(payload) + data, err := common.Marshal(payload) if err != nil { return nil, err } adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(key), adc); err != nil { + if err := common.Unmarshal([]byte(key), adc); err != nil { return nil, fmt.Errorf("failed to decode credentials: %w", err) } token, err := vertexcore.AcquireAccessToken(*adc, proxy) @@ -274,7 +284,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { var op operationResponse - if err := json.Unmarshal(respBody, &op); err != nil { + if err := common.Unmarshal(respBody, &op); err != nil { return nil, fmt.Errorf("unmarshal operation response failed: %w", err) } ti := &relaycommon.TaskInfo{} @@ -338,7 +348,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { - upstreamName, err := decodeLocalTaskID(task.TaskID) + // Use GetUpstreamTaskID() to get the real upstream operation name for model extraction. + // task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name. + upstreamTaskID := task.GetUpstreamTaskID() + upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID) if err != nil { upstreamName = "" } @@ -353,8 +366,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { v.SetProgressStr(task.Progress) v.CreatedAt = task.CreatedAt v.CompletedAt = task.UpdatedAt - if strings.HasPrefix(task.FailReason, "data:") && len(task.FailReason) > 0 { - v.SetMetadata("url", task.FailReason) + if resultURL := task.GetResultURL(); strings.HasPrefix(resultURL, "data:") && len(resultURL) > 0 { + v.SetMetadata("url", resultURL) } return common.Marshal(v) @@ -364,18 +377,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { // helpers // ============================ -func encodeLocalTaskID(name string) string { - return base64.RawURLEncoding.EncodeToString([]byte(name)) -} - -func decodeLocalTaskID(local string) (string, error) { - b, err := base64.RawURLEncoding.DecodeString(local) - if err != nil { - return "", err - } - return string(b), nil -} - var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`) func extractRegionFromOperationName(name string) string { diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index 3657161c0..6ae1c181b 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -2,7 +2,6 @@ package vidu import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -16,6 +15,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -73,6 +73,7 @@ type creation struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int baseURL string } @@ -115,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } req := v.(relaycommon.TaskSubmitReq) - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, err } @@ -127,7 +128,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -168,7 +169,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } var vResp responsePayload - err = json.Unmarshal(responseBody, &vResp) + err = common.Unmarshal(responseBody, &vResp) if err != nil { taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError) return @@ -180,8 +181,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = vResp.TaskId - ov.TaskID = vResp.TaskId + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -223,47 +224,27 @@ func (a *TaskAdaptor) GetChannelName() string { // helpers // ============================ -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) { r := requestPayload{ - Model: defaultString(req.Model, "viduq1"), + Model: taskcommon.DefaultString(info.UpstreamModelName, "viduq1"), Images: req.Images, Prompt: req.Prompt, - Duration: defaultInt(req.Duration, 5), - Resolution: defaultString(req.Size, "1080p"), + Duration: taskcommon.DefaultInt(req.Duration, 5), + Resolution: taskcommon.DefaultString(req.Size, "1080p"), MovementAmplitude: "auto", Bgm: false, } - metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } return &r, nil } -func defaultString(value, defaultValue string) string { - if value == "" { - return defaultValue - } - return value -} - -func defaultInt(value, defaultValue int) int { - if value == 0 { - return defaultValue - } - return value -} - func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { taskInfo := &relaycommon.TaskInfo{} var taskResp taskResultResponse - err := json.Unmarshal(respBody, &taskResp) + err := common.Unmarshal(respBody, &taskResp) if err != nil { return nil, errors.Wrap(err, "failed to unmarshal response body") } @@ -293,7 +274,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var viduResp taskResultResponse - if err := json.Unmarshal(originTask.Data, &viduResp); err != nil { + if err := common.Unmarshal(originTask.Data, &viduResp); err != nil { return nil, errors.Wrap(err, "unmarshal vidu task data failed") } @@ -315,6 +296,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } diff --git a/relay/chat_completions_via_responses.go b/relay/chat_completions_via_responses.go index 580cba5f4..8f69b9375 100644 --- a/relay/chat_completions_via_responses.go +++ b/relay/chat_completions_via_responses.go @@ -75,7 +75,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings) + chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled) if err != nil { return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -119,7 +119,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings) + jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled) if err != nil { return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 2dfa09df5..1722cd9b2 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -146,7 +146,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ } // remove disabled fields for Claude API - jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings) + jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } diff --git a/relay/common/override_test.go b/relay/common/override_test.go index a37eb78f9..0fc24467d 100644 --- a/relay/common/override_test.go +++ b/relay/common/override_test.go @@ -6,6 +6,9 @@ import ( "testing" "github.com/QuantumNous/new-api/types" + + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/setting/model_setting" ) func TestApplyParamOverrideTrimPrefix(t *testing.T) { @@ -1311,6 +1314,76 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) { } } +func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) { + input := `{ + "service_tier":"flex", + "safety_identifier":"user-123", + "store":true, + "stream_options":{"include_obfuscation":false} + }` + settings := dto.ChannelOtherSettings{} + + out, err := RemoveDisabledFields([]byte(input), settings, true) + if err != nil { + t.Fatalf("RemoveDisabledFields returned error: %v", err) + } + assertJSONEqual(t, input, string(out)) +} + +func TestRemoveDisabledFieldsSkipWhenGlobalPassThroughEnabled(t *testing.T) { + original := model_setting.GetGlobalSettings().PassThroughRequestEnabled + model_setting.GetGlobalSettings().PassThroughRequestEnabled = true + t.Cleanup(func() { + model_setting.GetGlobalSettings().PassThroughRequestEnabled = original + }) + + input := `{ + "service_tier":"flex", + "safety_identifier":"user-123", + "stream_options":{"include_obfuscation":false} + }` + settings := dto.ChannelOtherSettings{} + + out, err := RemoveDisabledFields([]byte(input), settings, false) + if err != nil { + t.Fatalf("RemoveDisabledFields returned error: %v", err) + } + assertJSONEqual(t, input, string(out)) +} + +func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) { + input := `{ + "service_tier":"flex", + "inference_geo":"eu", + "safety_identifier":"user-123", + "store":true, + "stream_options":{"include_obfuscation":false} + }` + settings := dto.ChannelOtherSettings{} + + out, err := RemoveDisabledFields([]byte(input), settings, false) + if err != nil { + t.Fatalf("RemoveDisabledFields returned error: %v", err) + } + assertJSONEqual(t, `{"store":true}`, string(out)) +} + +func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) { + input := `{ + "inference_geo":"eu", + "store":true + }` + settings := dto.ChannelOtherSettings{ + AllowInferenceGeo: true, + } + + out, err := RemoveDisabledFields([]byte(input), settings, false) + if err != nil { + t.Fatalf("RemoveDisabledFields returned error: %v", err) + } + assertJSONEqual(t, `{"inference_geo":"eu","store":true}`, string(out)) +} + func assertJSONEqual(t *testing.T, want, got string) { t.Helper() diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index e5a0a06f5..8b0789c0d 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -119,8 +119,12 @@ type RelayInfo struct { SendResponseCount int ReceivedResponseCount int FinalPreConsumedQuota int // 最终预消耗的配额 + // ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路, + // 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行, + // 必须在提交前锁定全额。 + ForcePreConsume bool // Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。 - // 免费模型和按次计费(MJ/Task)时为 nil。 + // 免费模型时为 nil。 Billing BillingSettler // BillingSource indicates whether this request is billed from wallet quota or subscription. // "" or "wallet" => wallet; "subscription" => subscription @@ -153,7 +157,8 @@ type RelayInfo struct { // RequestConversionChain records request format conversions in order, e.g. // ["openai", "openai_responses"] or ["openai", "claude"]. RequestConversionChain []types.RelayFormat - // 最终请求到上游的格式 TODO: 当前仅设置了Claude + // 最终请求到上游的格式。可由 adaptor 显式设置; + // 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。 FinalRequestRelayFormat types.RelayFormat ThinkingContentInfo @@ -552,8 +557,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req return nil, errors.New("request is not a OpenAIResponsesCompactionRequest") case types.RelayFormatTask: info = genBaseRelayInfo(c, nil) + info.TaskRelayInfo = &TaskRelayInfo{} case types.RelayFormatMjProxy: info = genBaseRelayInfo(c, nil) + info.TaskRelayInfo = &TaskRelayInfo{} default: err = errors.New("invalid relay format") } @@ -600,6 +607,19 @@ func (info *RelayInfo) AppendRequestConversion(format types.RelayFormat) { info.RequestConversionChain = append(info.RequestConversionChain, format) } +func (info *RelayInfo) GetFinalRequestRelayFormat() types.RelayFormat { + if info == nil { + return "" + } + if info.FinalRequestRelayFormat != "" { + return info.FinalRequestRelayFormat + } + if n := len(info.RequestConversionChain); n > 0 { + return info.RequestConversionChain[n-1] + } + return info.RelayFormat +} + func GenRelayInfoResponsesCompaction(c *gin.Context, request *dto.OpenAIResponsesCompactionRequest) *RelayInfo { info := genBaseRelayInfo(c, request) if info.RelayMode == relayconstant.RelayModeUnknown { @@ -635,8 +655,16 @@ func (info *RelayInfo) HasSendResponse() bool { type TaskRelayInfo struct { Action string OriginTaskID string + // PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID, + // 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID)。 + PublicTaskID string ConsumeQuota bool + + // LockedChannel holds the full channel object when the request is bound to + // a specific channel (e.g., remix on origin task's channel). Stored as any + // to avoid an import cycle with model; callers type-assert to *model.Channel. + LockedChannel any } type TaskSubmitReq struct { @@ -694,11 +722,11 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error { func (t *TaskSubmitReq) UnmarshalMetadata(v any) error { metadata := t.Metadata if metadata != nil { - metadataBytes, err := json.Marshal(metadata) + metadataBytes, err := common.Marshal(metadata) if err != nil { return fmt.Errorf("marshal metadata failed: %w", err) } - err = json.Unmarshal(metadataBytes, v) + err = common.Unmarshal(metadataBytes, v) if err != nil { return fmt.Errorf("unmarshal metadata to target failed: %w", err) } @@ -727,9 +755,15 @@ func FailTaskInfo(reason string) *TaskInfo { // RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段 // service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持) +// inference_geo: Claude 数据驻留推理区域字段(仅 Claude 支持,默认过滤) // store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用) // safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私) -func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) ([]byte, error) { +// stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持) +func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings, channelPassThroughEnabled bool) ([]byte, error) { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled { + return jsonData, nil + } + var data map[string]interface{} if err := common.Unmarshal(jsonData, &data); err != nil { common.SysError("RemoveDisabledFields Unmarshal error :" + err.Error()) @@ -743,6 +777,13 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther } } + // 默认移除 inference_geo,除非明确允许(避免在未授权情况下透传数据驻留区域) + if !channelOtherSettings.AllowInferenceGeo { + if _, exists := data["inference_geo"]; exists { + delete(data, "inference_geo") + } + } + // 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用) if channelOtherSettings.DisableStore { if _, exists := data["store"]; exists { @@ -757,6 +798,22 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther } } + // 默认移除 stream_options.include_obfuscation,除非明确允许(避免关闭响应流混淆保护) + if !channelOtherSettings.AllowIncludeObfuscation { + if streamOptionsAny, exists := data["stream_options"]; exists { + if streamOptions, ok := streamOptionsAny.(map[string]interface{}); ok { + if _, includeExists := streamOptions["include_obfuscation"]; includeExists { + delete(streamOptions, "include_obfuscation") + } + if len(streamOptions) == 0 { + delete(data, "stream_options") + } else { + data["stream_options"] = streamOptions + } + } + } + } + jsonDataAfter, err := common.Marshal(data) if err != nil { common.SysError("RemoveDisabledFields Marshal error :" + err.Error()) diff --git a/relay/common/relay_info_test.go b/relay/common/relay_info_test.go new file mode 100644 index 000000000..e53ec804c --- /dev/null +++ b/relay/common/relay_info_test.go @@ -0,0 +1,40 @@ +package common + +import ( + "testing" + + "github.com/QuantumNous/new-api/types" + "github.com/stretchr/testify/require" +) + +func TestRelayInfoGetFinalRequestRelayFormatPrefersExplicitFinal(t *testing.T) { + info := &RelayInfo{ + RelayFormat: types.RelayFormatOpenAI, + RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude}, + FinalRequestRelayFormat: types.RelayFormatOpenAIResponses, + } + + require.Equal(t, types.RelayFormat(types.RelayFormatOpenAIResponses), info.GetFinalRequestRelayFormat()) +} + +func TestRelayInfoGetFinalRequestRelayFormatFallsBackToConversionChain(t *testing.T) { + info := &RelayInfo{ + RelayFormat: types.RelayFormatOpenAI, + RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude}, + } + + require.Equal(t, types.RelayFormat(types.RelayFormatClaude), info.GetFinalRequestRelayFormat()) +} + +func TestRelayInfoGetFinalRequestRelayFormatFallsBackToRelayFormat(t *testing.T) { + info := &RelayInfo{ + RelayFormat: types.RelayFormatGemini, + } + + require.Equal(t, types.RelayFormat(types.RelayFormatGemini), info.GetFinalRequestRelayFormat()) +} + +func TestRelayInfoGetFinalRequestRelayFormatNilReceiver(t *testing.T) { + var info *RelayInfo + require.Equal(t, types.RelayFormat(""), info.GetFinalRequestRelayFormat()) +} diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index b662f9053..3cbb18c22 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -173,16 +173,10 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError { if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) { return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true) } - info.PriceData.OtherRatios = map[string]float64{ - "seconds": float64(seconds), - "size": 1, - } - if lo.Contains([]string{"1792x1024", "1024x1792"}, size) { - info.PriceData.OtherRatios["size"] = 1.666667 - } + // OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置 } - info.Action = action + storeTaskRequest(c, info, action, req) return nil } diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index 7f4b99488..9a25237c7 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -165,7 +165,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types } // remove disabled fields for OpenAI API - jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings) + jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } @@ -232,7 +232,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage } if originUsage != nil { - service.ObserveChannelAffinityUsageCacheFromContext(ctx, usage) + service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat()) } adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason) @@ -336,7 +336,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage var audioInputQuota decimal.Decimal var audioInputPrice float64 - isClaudeUsageSemantic := relayInfo.FinalRequestRelayFormat == types.RelayFormatClaude + isClaudeUsageSemantic := relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude if !relayInfo.PriceData.UsePrice { baseTokens := dPromptTokens // 减去 cached tokens diff --git a/relay/helper/price.go b/relay/helper/price.go index c310220fe..1cb04166f 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -140,7 +140,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens } // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task) -func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData { +func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PriceData { groupRatioInfo := HandleGroupRatio(c, info) modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) @@ -154,7 +154,18 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types. } } quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) - priceData := types.PerCallPriceData{ + + // 免费模型检测(与 ModelPriceHelper 对齐) + freeModel := false + if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume { + if groupRatioInfo.GroupRatio == 0 || modelPrice == 0 { + quota = 0 + freeModel = true + } + } + + priceData := types.PriceData{ + FreeModel: freeModel, ModelPrice: modelPrice, Quota: quota, GroupRatioInfo: groupRatioInfo, diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index 4f3ab2363..ae70f53c0 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -176,10 +176,32 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon }) } + dataChan := make(chan string, 10) + + wg.Add(1) + gopool.Go(func() { + defer func() { + wg.Done() + if r := recover(); r != nil { + logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r)) + } + common.SafeSendBool(stopChan, true) + }() + for data := range dataChan { + writeMutex.Lock() + success := dataHandler(data) + writeMutex.Unlock() + if !success { + return + } + } + }) + // Scanner goroutine with improved error handling wg.Add(1) common.RelayCtxGo(ctx, func() { defer func() { + close(dataChan) wg.Done() if r := recover(); r != nil { logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r)) @@ -215,27 +237,16 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon continue } data = data[5:] - data = strings.TrimLeft(data, " ") - data = strings.TrimSuffix(data, "\r") + data = strings.TrimSpace(data) + if data == "" { + continue + } if !strings.HasPrefix(data, "[DONE]") { info.SetFirstResponseTime() info.ReceivedResponseCount++ - // 使用超时机制防止写操作阻塞 - done := make(chan bool, 1) - gopool.Go(func() { - writeMutex.Lock() - defer writeMutex.Unlock() - done <- dataHandler(data) - }) select { - case success := <-done: - if !success { - return - } - case <-time.After(10 * time.Second): - logger.LogError(c, "data handler timeout") - return + case dataChan <- data: case <-ctx.Done(): return case <-stopChan: diff --git a/relay/helper/stream_scanner_test.go b/relay/helper/stream_scanner_test.go new file mode 100644 index 000000000..6890d82a5 --- /dev/null +++ b/relay/helper/stream_scanner_test.go @@ -0,0 +1,521 @@ +package helper + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/QuantumNous/new-api/constant" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/setting/operation_setting" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func setupStreamTest(t *testing.T, body io.Reader) (*gin.Context, *http.Response, *relaycommon.RelayInfo) { + t.Helper() + + oldTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 30 + t.Cleanup(func() { + constant.StreamingTimeout = oldTimeout + }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + resp := &http.Response{ + Body: io.NopCloser(body), + } + + info := &relaycommon.RelayInfo{ + ChannelMeta: &relaycommon.ChannelMeta{}, + } + + return c, resp, info +} + +func buildSSEBody(n int) string { + var b strings.Builder + for i := 0; i < n; i++ { + fmt.Fprintf(&b, "data: {\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}\n", i, i) + } + b.WriteString("data: [DONE]\n") + return b.String() +} + +// slowReader wraps a reader and injects a delay before each Read call, +// simulating a slow upstream that trickles data. +type slowReader struct { + r io.Reader + delay time.Duration +} + +func (s *slowReader) Read(p []byte) (int, error) { + time.Sleep(s.delay) + return s.r.Read(p) +} + +// ---------- Basic correctness ---------- + +func TestStreamScannerHandler_NilInputs(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} + + StreamScannerHandler(c, nil, info, func(data string) bool { return true }) + StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil) +} + +func TestStreamScannerHandler_EmptyBody(t *testing.T) { + t.Parallel() + + c, resp, info := setupStreamTest(t, strings.NewReader("")) + + var called atomic.Bool + StreamScannerHandler(c, resp, info, func(data string) bool { + called.Store(true) + return true + }) + + assert.False(t, called.Load(), "handler should not be called for empty body") +} + +func TestStreamScannerHandler_1000Chunks(t *testing.T) { + t.Parallel() + + const numChunks = 1000 + body := buildSSEBody(numChunks) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + var count atomic.Int64 + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + + assert.Equal(t, int64(numChunks), count.Load()) + assert.Equal(t, numChunks, info.ReceivedResponseCount) +} + +func TestStreamScannerHandler_10000Chunks(t *testing.T) { + t.Parallel() + + const numChunks = 10000 + body := buildSSEBody(numChunks) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + var count atomic.Int64 + start := time.Now() + + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + + elapsed := time.Since(start) + assert.Equal(t, int64(numChunks), count.Load()) + assert.Equal(t, numChunks, info.ReceivedResponseCount) + t.Logf("10000 chunks processed in %v", elapsed) +} + +func TestStreamScannerHandler_OrderPreserved(t *testing.T) { + t.Parallel() + + const numChunks = 500 + body := buildSSEBody(numChunks) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + var mu sync.Mutex + received := make([]string, 0, numChunks) + + StreamScannerHandler(c, resp, info, func(data string) bool { + mu.Lock() + received = append(received, data) + mu.Unlock() + return true + }) + + require.Equal(t, numChunks, len(received)) + for i := 0; i < numChunks; i++ { + expected := fmt.Sprintf("{\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}", i, i) + assert.Equal(t, expected, received[i], "chunk %d out of order", i) + } +} + +func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) { + t.Parallel() + + body := buildSSEBody(50) + "data: should_not_appear\n" + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + var count atomic.Int64 + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + + assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed") +} + +func TestStreamScannerHandler_HandlerFailureStops(t *testing.T) { + t.Parallel() + + const numChunks = 200 + body := buildSSEBody(numChunks) + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + const failAt = 50 + var count atomic.Int64 + StreamScannerHandler(c, resp, info, func(data string) bool { + n := count.Add(1) + return n < failAt + }) + + // The worker stops at failAt; the scanner may have read ahead, + // but the handler should not be called beyond failAt. + assert.Equal(t, int64(failAt), count.Load()) +} + +func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) { + t.Parallel() + + var b strings.Builder + b.WriteString(": comment line\n") + b.WriteString("event: message\n") + b.WriteString("id: 12345\n") + b.WriteString("retry: 5000\n") + for i := 0; i < 100; i++ { + fmt.Fprintf(&b, "data: payload_%d\n", i) + b.WriteString(": interleaved comment\n") + } + b.WriteString("data: [DONE]\n") + + c, resp, info := setupStreamTest(t, strings.NewReader(b.String())) + + var count atomic.Int64 + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + + assert.Equal(t, int64(100), count.Load()) +} + +func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) { + t.Parallel() + + body := "data: {\"trimmed\":true} \ndata: [DONE]\n" + c, resp, info := setupStreamTest(t, strings.NewReader(body)) + + var got string + StreamScannerHandler(c, resp, info, func(data string) bool { + got = data + return true + }) + + assert.Equal(t, "{\"trimmed\":true}", got) +} + +// ---------- Decoupling: scanner not blocked by slow handler ---------- + +func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) { + t.Parallel() + + // Strategy: use a slow upstream (io.Pipe, 10ms per chunk) AND a slow handler (20ms per chunk). + // If the scanner were synchronously coupled to the handler, total time would be + // ~numChunks * (10ms + 20ms) = 30ms * 50 = 1500ms. + // With decoupling, total time should be closer to + // ~numChunks * max(10ms, 20ms) = 20ms * 50 = 1000ms + // because the scanner reads ahead into the buffer while the handler processes. + const numChunks = 50 + const upstreamDelay = 10 * time.Millisecond + const handlerDelay = 20 * time.Millisecond + + pr, pw := io.Pipe() + go func() { + defer pw.Close() + for i := 0; i < numChunks; i++ { + fmt.Fprintf(pw, "data: {\"id\":%d}\n", i) + time.Sleep(upstreamDelay) + } + fmt.Fprint(pw, "data: [DONE]\n") + }() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + oldTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 30 + t.Cleanup(func() { constant.StreamingTimeout = oldTimeout }) + + resp := &http.Response{Body: pr} + info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} + + var count atomic.Int64 + start := time.Now() + done := make(chan struct{}) + go func() { + StreamScannerHandler(c, resp, info, func(data string) bool { + time.Sleep(handlerDelay) + count.Add(1) + return true + }) + close(done) + }() + + select { + case <-done: + case <-time.After(15 * time.Second): + t.Fatal("StreamScannerHandler did not complete in time") + } + + elapsed := time.Since(start) + assert.Equal(t, int64(numChunks), count.Load()) + + coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay) + t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime) + + // If decoupled, elapsed should be well under the coupled estimate. + assert.Less(t, elapsed, coupledTime*85/100, + "decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime) +} + +func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) { + t.Parallel() + + const numChunks = 50 + body := buildSSEBody(numChunks) + reader := &slowReader{r: strings.NewReader(body), delay: 2 * time.Millisecond} + c, resp, info := setupStreamTest(t, reader) + + var count atomic.Int64 + start := time.Now() + + done := make(chan struct{}) + go func() { + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + close(done) + }() + + select { + case <-done: + case <-time.After(15 * time.Second): + t.Fatal("timed out with slow upstream") + } + + elapsed := time.Since(start) + assert.Equal(t, int64(numChunks), count.Load()) + t.Logf("slow upstream (%d chunks, 2ms/read): %v", numChunks, elapsed) +} + +// ---------- Ping tests ---------- + +func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) { + t.Parallel() + + setting := operation_setting.GetGeneralSetting() + oldEnabled := setting.PingIntervalEnabled + oldSeconds := setting.PingIntervalSeconds + setting.PingIntervalEnabled = true + setting.PingIntervalSeconds = 1 + t.Cleanup(func() { + setting.PingIntervalEnabled = oldEnabled + setting.PingIntervalSeconds = oldSeconds + }) + + // Create a reader that delivers data slowly: one chunk every 500ms over 3.5 seconds. + // The ping interval is 1s, so we should see at least 2 pings. + pr, pw := io.Pipe() + go func() { + defer pw.Close() + for i := 0; i < 7; i++ { + fmt.Fprintf(pw, "data: chunk_%d\n", i) + time.Sleep(500 * time.Millisecond) + } + fmt.Fprint(pw, "data: [DONE]\n") + }() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + oldTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 30 + t.Cleanup(func() { + constant.StreamingTimeout = oldTimeout + }) + + resp := &http.Response{Body: pr} + info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} + + var count atomic.Int64 + done := make(chan struct{}) + go func() { + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + close(done) + }() + + select { + case <-done: + case <-time.After(15 * time.Second): + t.Fatal("timed out waiting for stream to finish") + } + + assert.Equal(t, int64(7), count.Load()) + + body := recorder.Body.String() + pingCount := strings.Count(body, ": PING") + t.Logf("received %d pings in response body", pingCount) + assert.GreaterOrEqual(t, pingCount, 2, + "expected at least 2 pings during 3.5s stream with 1s interval; got %d", pingCount) +} + +func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) { + t.Parallel() + + setting := operation_setting.GetGeneralSetting() + oldEnabled := setting.PingIntervalEnabled + oldSeconds := setting.PingIntervalSeconds + setting.PingIntervalEnabled = true + setting.PingIntervalSeconds = 1 + t.Cleanup(func() { + setting.PingIntervalEnabled = oldEnabled + setting.PingIntervalSeconds = oldSeconds + }) + + pr, pw := io.Pipe() + go func() { + defer pw.Close() + for i := 0; i < 5; i++ { + fmt.Fprintf(pw, "data: chunk_%d\n", i) + time.Sleep(500 * time.Millisecond) + } + fmt.Fprint(pw, "data: [DONE]\n") + }() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + oldTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 30 + t.Cleanup(func() { + constant.StreamingTimeout = oldTimeout + }) + + resp := &http.Response{Body: pr} + info := &relaycommon.RelayInfo{ + DisablePing: true, + ChannelMeta: &relaycommon.ChannelMeta{}, + } + + var count atomic.Int64 + done := make(chan struct{}) + go func() { + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + close(done) + }() + + select { + case <-done: + case <-time.After(15 * time.Second): + t.Fatal("timed out") + } + + assert.Equal(t, int64(5), count.Load()) + + body := recorder.Body.String() + pingCount := strings.Count(body, ": PING") + assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true") +} + +func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) { + t.Parallel() + + setting := operation_setting.GetGeneralSetting() + oldEnabled := setting.PingIntervalEnabled + oldSeconds := setting.PingIntervalSeconds + setting.PingIntervalEnabled = true + setting.PingIntervalSeconds = 1 + t.Cleanup(func() { + setting.PingIntervalEnabled = oldEnabled + setting.PingIntervalSeconds = oldSeconds + }) + + // Slow upstream + slow handler. Total stream takes ~5 seconds. + // The ping goroutine stays alive as long as the scanner is reading, + // so pings should fire between data writes. + pr, pw := io.Pipe() + go func() { + defer pw.Close() + for i := 0; i < 10; i++ { + fmt.Fprintf(pw, "data: chunk_%d\n", i) + time.Sleep(500 * time.Millisecond) + } + fmt.Fprint(pw, "data: [DONE]\n") + }() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + oldTimeout := constant.StreamingTimeout + constant.StreamingTimeout = 30 + t.Cleanup(func() { + constant.StreamingTimeout = oldTimeout + }) + + resp := &http.Response{Body: pr} + info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} + + var count atomic.Int64 + done := make(chan struct{}) + go func() { + StreamScannerHandler(c, resp, info, func(data string) bool { + count.Add(1) + return true + }) + close(done) + }() + + select { + case <-done: + case <-time.After(15 * time.Second): + t.Fatal("timed out") + } + + assert.Equal(t, int64(10), count.Load()) + + body := recorder.Body.String() + pingCount := strings.Count(body, ": PING") + t.Logf("received %d pings interleaved with 10 chunks over 5s", pingCount) + assert.GreaterOrEqual(t, pingCount, 3, + "expected at least 3 pings during 5s stream with 1s ping interval; got %d", pingCount) +} diff --git a/relay/mjproxy_handler.go b/relay/mjproxy_handler.go index 8916ab181..8e7c61e9c 100644 --- a/relay/mjproxy_handler.go +++ b/relay/mjproxy_handler.go @@ -184,7 +184,7 @@ func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyR if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required") } - modelName := service.CoverActionToModelName(constant.MjActionSwapFace) + modelName := service.CovertMjpActionToModelName(constant.MjActionSwapFace) priceData := helper.ModelPriceHelperPerCall(c, info) @@ -485,7 +485,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dt fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - modelName := service.CoverActionToModelName(midjRequest.Action) + modelName := service.CovertMjpActionToModelName(midjRequest.Action) priceData := helper.ModelPriceHelperPerCall(c, relayInfo) diff --git a/relay/relay_task.go b/relay/relay_task.go index ebbd1f65d..c740facdb 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "encoding/json" "errors" "fmt" "io" @@ -15,29 +14,33 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" relayconstant "github.com/QuantumNous/new-api/relay/constant" + "github.com/QuantumNous/new-api/relay/helper" "github.com/QuantumNous/new-api/service" - "github.com/QuantumNous/new-api/setting/ratio_setting" - "github.com/gin-gonic/gin" ) -/* -Task 任务通过平台、Action 区分任务 -*/ -func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { - info.InitChannelMeta(c) - // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields - if info.TaskRelayInfo == nil { - info.TaskRelayInfo = &relaycommon.TaskRelayInfo{} - } +type TaskSubmitResult struct { + UpstreamTaskID string + TaskData []byte + Platform constant.TaskPlatform + Quota int + //PerCallPrice types.PriceData +} + +// ResolveOriginTask 处理基于已有任务的提交(remix / continuation): +// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道 +// (通过 info.LockedChannel,重试时复用同一渠道并轮换 key), +// 以及提取 OtherRatios(时长、分辨率)。 +// 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。 +func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { + // 检测 remix action path := c.Request.URL.Path if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") { info.Action = constant.TaskActionRemix } - - // 提取 remix 任务的 video_id if info.Action == constant.TaskActionRemix { videoID := c.Param("video_id") if strings.TrimSpace(videoID) == "" { @@ -46,64 +49,71 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto. info.OriginTaskID = videoID } - platform := constant.TaskPlatform(c.GetString("platform")) + if info.OriginTaskID == "" { + return nil + } - // 获取原始任务信息 - if info.OriginTaskID != "" { - originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) - if err != nil { - taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) - return - } - if !exist { - taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) - return - } - if info.OriginModelName == "" { - if originTask.Properties.OriginModelName != "" { - info.OriginModelName = originTask.Properties.OriginModelName - } else if originTask.Properties.UpstreamModelName != "" { - info.OriginModelName = originTask.Properties.UpstreamModelName - } else { - var taskData map[string]interface{} - _ = json.Unmarshal(originTask.Data, &taskData) - if m, ok := taskData["model"].(string); ok && m != "" { - info.OriginModelName = m - platform = originTask.Platform - } - } - } - if originTask.ChannelId != info.ChannelId { - channel, err := model.GetChannelById(originTask.ChannelId, true) - if err != nil { - taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) - return - } - if channel.Status != common.ChannelStatusEnabled { - taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) - return - } - key, _, newAPIError := channel.GetNextEnabledKey() - if newAPIError != nil { - taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode) - return - } - common.SetContextKey(c, constant.ContextKeyChannelKey, key) - common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type) - common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL()) - common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId) + // 查找原始任务 + originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) + if err != nil { + return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) + } + if !exist { + return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) + } - info.ChannelBaseUrl = channel.GetBaseURL() - info.ChannelId = originTask.ChannelId - info.ChannelType = channel.Type - info.ApiKey = key - platform = originTask.Platform - } - - // 使用原始任务的参数 - if info.Action == constant.TaskActionRemix { + // 从原始任务推导模型名称 + if info.OriginModelName == "" { + if originTask.Properties.OriginModelName != "" { + info.OriginModelName = originTask.Properties.OriginModelName + } else if originTask.Properties.UpstreamModelName != "" { + info.OriginModelName = originTask.Properties.UpstreamModelName + } else { var taskData map[string]interface{} - _ = json.Unmarshal(originTask.Data, &taskData) + _ = common.Unmarshal(originTask.Data, &taskData) + if m, ok := taskData["model"].(string); ok && m != "" { + info.OriginModelName = m + } + } + } + + // 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key) + ch, err := model.GetChannelById(originTask.ChannelId, true) + if err != nil { + return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) + } + if ch.Status != common.ChannelStatusEnabled { + return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) + } + info.LockedChannel = ch + + if originTask.ChannelId != info.ChannelId { + key, _, newAPIError := ch.GetNextEnabledKey() + if newAPIError != nil { + return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode) + } + common.SetContextKey(c, constant.ContextKeyChannelKey, key) + common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type) + common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL()) + common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId) + + info.ChannelBaseUrl = ch.GetBaseURL() + info.ChannelId = originTask.ChannelId + info.ChannelType = ch.Type + info.ApiKey = key + } + + // 提取 remix 参数(时长、分辨率 → OtherRatios) + if info.Action == constant.TaskActionRemix { + if originTask.PrivateData.BillingContext != nil { + // 新的 remix 逻辑:直接从原始任务的 BillingContext 中提取 OtherRatios(如果存在) + for s, f := range originTask.PrivateData.BillingContext.OtherRatios { + info.PriceData.AddOtherRatio(s, f) + } + } else { + // 旧的 remix 逻辑:直接从 task data 解析 seconds 和 size(如果存在) + var taskData map[string]interface{} + _ = common.Unmarshal(originTask.Data, &taskData) secondsStr, _ := taskData["seconds"].(string) seconds, _ := strconv.Atoi(secondsStr) if seconds <= 0 { @@ -120,167 +130,146 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto. } } } + + return nil +} + +// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次): +// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → +// 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→ +// 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。 +// 控制器负责 defer Refund 和成功后 Settle。 +func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) { + info.InitChannelMeta(c) + + // 1. 确定 platform → 创建适配器 → 验证请求 + platform := constant.TaskPlatform(c.GetString("platform")) if platform == "" { platform = GetTaskPlatform(c) } - - info.InitChannelMeta(c) adaptor := GetTaskAdaptor(platform) if adaptor == nil { - return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) + return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) } adaptor.Init(info) - // get & validate taskRequest 获取并验证文本请求 - taskErr = adaptor.ValidateRequestAndSetAction(c, info) - if taskErr != nil { - return + if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil { + return nil, taskErr } + // 2. 确定模型名称 modelName := info.OriginModelName if modelName == "" { modelName = service.CoverTaskActionToModelName(platform, info.Action) } - modelPrice, success := ratio_setting.GetModelPrice(modelName, true) - if !success { - defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName] - if !ok { - modelPrice = float64(common.PreConsumedQuota) / common.QuotaPerUnit - } else { - modelPrice = defaultPrice + + // 2.5 应用渠道的模型映射(与同步任务对齐) + info.OriginModelName = modelName + info.UpstreamModelName = modelName + if err := helper.ModelMappedHelper(c, info, nil); err != nil { + return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest) + } + + // 3. 预生成公开 task ID(仅首次) + if info.PublicTaskID == "" { + info.PublicTaskID = model.GenerateTaskID() + } + + // 4. 价格计算:基础模型价格 + info.OriginModelName = modelName + info.PriceData = helper.ModelPriceHelperPerCall(c, info) + + // 5. 计费估算:让适配器根据用户请求提供 OtherRatios(时长、分辨率等) + // 必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData)。 + // ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios,此处合并。 + if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 { + for k, v := range estimatedRatios { + info.PriceData.AddOtherRatio(k, v) } } - // 处理 auto 分组:从 context 获取实际选中的分组 - // 当使用 auto 分组时,Distribute 中间件会将实际选中的分组存储在 ContextKeyAutoGroup 中 - if autoGroup, exists := common.GetContextKey(c, constant.ContextKeyAutoGroup); exists { - if groupStr, ok := autoGroup.(string); ok && groupStr != "" { - info.UsingGroup = groupStr - } - } - - // 预扣 - groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup) - var ratio float64 - userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup) - if hasUserGroupRatio { - ratio = modelPrice * userGroupRatio - } else { - ratio = modelPrice * groupRatio - } - // FIXME: 临时修补,支持任务仅按次计费 + // 6. 将 OtherRatios 应用到基础额度 if !common.StringsContains(constant.TaskPricePatches, modelName) { - if len(info.PriceData.OtherRatios) > 0 { - for _, ra := range info.PriceData.OtherRatios { - if 1.0 != ra { - ratio *= ra - } + for _, ra := range info.PriceData.OtherRatios { + if ra != 1.0 { + info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra) } } } - println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio)) - userQuota, err := model.GetUserQuota(info.UserId, false) - if err != nil { - taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) - return - } - quota := int(ratio * common.QuotaPerUnit) - if userQuota-quota < 0 { - taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden) - return + + // 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过) + if info.Billing == nil && !info.PriceData.FreeModel { + info.ForcePreConsume = true + if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil { + return nil, service.TaskErrorFromAPIError(apiErr) + } } - // build body + // 8. 构建请求体 requestBody, err := adaptor.BuildRequestBody(c, info) if err != nil { - taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) - return + return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) } - // do request + + // 9. 发送请求 resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { - taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) - return + return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - // handle response if resp != nil && resp.StatusCode != http.StatusOK { responseBody, _ := io.ReadAll(resp.Body) - taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode) - return + return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode) } - defer func() { - // release quota - if info.ConsumeQuota && taskErr == nil { + // 10. 返回 OtherRatios 给下游(header 必须在 DoResponse 写 body 之前设置) + otherRatios := info.PriceData.OtherRatios + if otherRatios == nil { + otherRatios = map[string]float64{} + } + ratiosJSON, _ := common.Marshal(otherRatios) + c.Header("X-New-Api-Other-Ratios", string(ratiosJSON)) - err := service.PostConsumeQuota(info, quota, 0, true) - if err != nil { - common.SysLog("error consuming token remain quota: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - //gRatio := groupRatio - //if hasUserGroupRatio { - // gRatio = userGroupRatio - //} - logContent := fmt.Sprintf("操作 %s", info.Action) - // FIXME: 临时修补,支持任务仅按次计费 - if common.StringsContains(constant.TaskPricePatches, modelName) { - logContent = fmt.Sprintf("%s,按次计费", logContent) - } else { - if len(info.PriceData.OtherRatios) > 0 { - var contents []string - for key, ra := range info.PriceData.OtherRatios { - if 1.0 != ra { - contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra)) - } - } - if len(contents) > 0 { - logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", ")) - } - } - } - other := make(map[string]interface{}) - if c != nil && c.Request != nil && c.Request.URL != nil { - other["request_path"] = c.Request.URL.Path - } - other["model_price"] = modelPrice - other["group_ratio"] = groupRatio - if hasUserGroupRatio { - other["user_group_ratio"] = userGroupRatio - } - model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ - ChannelId: info.ChannelId, - ModelName: modelName, - TokenName: tokenName, - Quota: quota, - Content: logContent, - TokenId: info.TokenId, - Group: info.UsingGroup, - Other: other, - }) - model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota) - model.UpdateChannelUsedQuota(info.ChannelId, quota) - } - } - }() - - taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) + // 11. 解析响应 + upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) if taskErr != nil { - return + return nil, taskErr } - info.ConsumeQuota = true - // insert task - task := model.InitTask(platform, info) - task.TaskID = taskID - task.Quota = quota - task.Data = taskData - task.Action = info.Action - err = task.Insert() - if err != nil { - taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) - return + + // 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios + finalQuota := info.PriceData.Quota + if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 { + // 基于调整后的 ratios 重新计算 quota + finalQuota = recalcQuotaFromRatios(info, adjustedRatios) + info.PriceData.OtherRatios = adjustedRatios + info.PriceData.Quota = finalQuota } - return nil + + return &TaskSubmitResult{ + UpstreamTaskID: upstreamTaskID, + TaskData: taskData, + Platform: platform, + Quota: finalQuota, + }, nil +} + +// recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。 +// 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。 +func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int { + // 从 PriceData 获取不含 OtherRatios 的基础价格 + baseQuota := info.PriceData.Quota + // 先除掉原有的 OtherRatios 恢复基础额度 + for _, ra := range info.PriceData.OtherRatios { + if ra != 1.0 && ra > 0 { + baseQuota = int(float64(baseQuota) / ra) + } + } + // 应用新的 ratios + result := float64(baseQuota) + for _, ra := range ratios { + if ra != 1.0 { + result *= ra + } + } + return int(result) } var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ @@ -336,7 +325,7 @@ func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.Ta } else { tasks = make([]any, 0) } - respBody, err = json.Marshal(dto.TaskResponse[[]any]{ + respBody, err = common.Marshal(dto.TaskResponse[[]any]{ Code: "success", Data: tasks, }) @@ -357,7 +346,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt return } - respBody, err = json.Marshal(dto.TaskResponse[any]{ + respBody, err = common.Marshal(dto.TaskResponse[any]{ Code: "success", Data: TaskModel2Dto(originTask), }) @@ -381,97 +370,16 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d return } - func() { - channelModel, err2 := model.GetChannelById(originTask.ChannelId, true) - if err2 != nil { - return - } - if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini { - return - } - baseURL := constant.ChannelBaseURLs[channelModel.Type] - if channelModel.GetBaseURL() != "" { - baseURL = channelModel.GetBaseURL() - } - proxy := channelModel.GetSetting().Proxy - adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) - if adaptor == nil { - return - } - resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ - "task_id": originTask.TaskID, - "action": originTask.Action, - }, proxy) - if err2 != nil || resp == nil { - return - } - defer resp.Body.Close() - body, err2 := io.ReadAll(resp.Body) - if err2 != nil { - return - } - ti, err2 := adaptor.ParseTaskResult(body) - if err2 == nil && ti != nil { - if ti.Status != "" { - originTask.Status = model.TaskStatus(ti.Status) - } - if ti.Progress != "" { - originTask.Progress = ti.Progress - } - if ti.Url != "" { - if strings.HasPrefix(ti.Url, "data:") { - } else { - originTask.FailReason = ti.Url - } - } - _ = originTask.Update() - var raw map[string]any - _ = json.Unmarshal(body, &raw) - format := "mp4" - if respObj, ok := raw["response"].(map[string]any); ok { - if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 { - if v0, ok := vids[0].(map[string]any); ok { - if mt, ok := v0["mimeType"].(string); ok && mt != "" { - if strings.Contains(mt, "mp4") { - format = "mp4" - } else { - format = mt - } - } - } - } - } - status := "processing" - switch originTask.Status { - case model.TaskStatusSuccess: - status = "succeeded" - case model.TaskStatusFailure: - status = "failed" - case model.TaskStatusQueued, model.TaskStatusSubmitted: - status = "queued" - } - if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") { - out := map[string]any{ - "error": nil, - "format": format, - "metadata": nil, - "status": status, - "task_id": originTask.TaskID, - "url": originTask.FailReason, - } - respBody, _ = json.Marshal(dto.TaskResponse[any]{ - Code: "success", - Data: out, - }) - } - } - }() + isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") - if len(respBody) != 0 { + // Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态 + if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 { + respBody = realtimeResp return } - if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") { + // OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo + if isOpenAIVideoAPI { adaptor := GetTaskAdaptor(originTask.Platform) if adaptor == nil { taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest) @@ -486,10 +394,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d respBody = openAIVideoData return } - taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented) + taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented) return } - respBody, err = json.Marshal(dto.TaskResponse[any]{ + + // 通用 TaskDto 格式 + respBody, err = common.Marshal(dto.TaskResponse[any]{ Code: "success", Data: TaskModel2Dto(originTask), }) @@ -499,16 +409,150 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d return } +// tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。 +// 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。 +// 当非 OpenAI Video API 时,还会构建自定义格式的响应体。 +func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte { + channelModel, err := model.GetChannelById(task.ChannelId, true) + if err != nil { + return nil + } + if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini { + return nil + } + + baseURL := constant.ChannelBaseURLs[channelModel.Type] + if channelModel.GetBaseURL() != "" { + baseURL = channelModel.GetBaseURL() + } + proxy := channelModel.GetSetting().Proxy + adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) + if adaptor == nil { + return nil + } + + resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ + "task_id": task.GetUpstreamTaskID(), + "action": task.Action, + }, proxy) + if err != nil || resp == nil { + return nil + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil + } + + ti, err := adaptor.ParseTaskResult(body) + if err != nil || ti == nil { + return nil + } + + snap := task.Snapshot() + + // 将上游最新状态更新到 task + if ti.Status != "" { + task.Status = model.TaskStatus(ti.Status) + } + if ti.Progress != "" { + task.Progress = ti.Progress + } + if strings.HasPrefix(ti.Url, "data:") { + // data: URI — kept in Data, not ResultURL + } else if ti.Url != "" { + task.PrivateData.ResultURL = ti.Url + } else if task.Status == model.TaskStatusSuccess { + // No URL from adaptor — construct proxy URL using public task ID + task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) + } + + if !snap.Equal(task.Snapshot()) { + _, _ = task.UpdateWithStatus(snap.Status) + } + + // OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理 + if isOpenAIVideoAPI { + return nil + } + + // 非 OpenAI Video API: 构建自定义格式响应 + format := detectVideoFormat(body) + out := map[string]any{ + "error": nil, + "format": format, + "metadata": nil, + "status": mapTaskStatusToSimple(task.Status), + "task_id": task.TaskID, + "url": task.GetResultURL(), + } + respBody, _ := common.Marshal(dto.TaskResponse[any]{ + Code: "success", + Data: out, + }) + return respBody +} + +// detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式 +func detectVideoFormat(rawBody []byte) string { + var raw map[string]any + if err := common.Unmarshal(rawBody, &raw); err != nil { + return "mp4" + } + respObj, ok := raw["response"].(map[string]any) + if !ok { + return "mp4" + } + vids, ok := respObj["videos"].([]any) + if !ok || len(vids) == 0 { + return "mp4" + } + v0, ok := vids[0].(map[string]any) + if !ok { + return "mp4" + } + mt, ok := v0["mimeType"].(string) + if !ok || mt == "" || strings.Contains(mt, "mp4") { + return "mp4" + } + return mt +} + +// mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串 +func mapTaskStatusToSimple(status model.TaskStatus) string { + switch status { + case model.TaskStatusSuccess: + return "succeeded" + case model.TaskStatusFailure: + return "failed" + case model.TaskStatusQueued, model.TaskStatusSubmitted: + return "queued" + default: + return "processing" + } +} + func TaskModel2Dto(task *model.Task) *dto.TaskDto { return &dto.TaskDto{ + ID: task.ID, + CreatedAt: task.CreatedAt, + UpdatedAt: task.UpdatedAt, TaskID: task.TaskID, + Platform: string(task.Platform), + UserId: task.UserId, + Group: task.Group, + ChannelId: task.ChannelId, + Quota: task.Quota, Action: task.Action, Status: string(task.Status), FailReason: task.FailReason, + ResultURL: task.GetResultURL(), SubmitTime: task.SubmitTime, StartTime: task.StartTime, FinishTime: task.FinishTime, Progress: task.Progress, + Properties: task.Properties, + Username: task.Username, Data: task.Data, } } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 3bcaa673f..18f1b7118 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -89,7 +89,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * } // remove disabled fields for OpenAI Responses API - jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings) + jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } diff --git a/router/api-router.go b/router/api-router.go index e2ef2f531..d48934000 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -13,6 +13,7 @@ import ( func SetApiRouter(router *gin.Engine) { apiRouter := router.Group("/api") + apiRouter.Use(middleware.RouteTag("api")) apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) apiRouter.Use(middleware.BodyStorageCleanup()) // 清理请求体存储 apiRouter.Use(middleware.GlobalAPIRateLimit()) @@ -114,6 +115,9 @@ func SetApiRouter(router *gin.Engine) { adminRoute.GET("/topup", controller.GetAllTopUps) adminRoute.POST("/topup/complete", controller.AdminCompleteTopUp) adminRoute.GET("/search", controller.SearchUsers) + adminRoute.GET("/:id/oauth/bindings", controller.GetUserOAuthBindingsByAdmin) + adminRoute.DELETE("/:id/oauth/bindings/:provider_id", controller.UnbindCustomOAuthByAdmin) + adminRoute.DELETE("/:id/bindings/:binding_type", controller.AdminClearUserBinding) adminRoute.GET("/:id", controller.GetUser) adminRoute.POST("/", controller.CreateUser) adminRoute.POST("/manage", controller.ManageUser) @@ -170,10 +174,11 @@ func SetApiRouter(router *gin.Engine) { optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除 } - // Custom OAuth provider management (admin only) + // Custom OAuth provider management (root only) customOAuthRoute := apiRouter.Group("/custom-oauth-provider") customOAuthRoute.Use(middleware.RootAuth()) { + customOAuthRoute.POST("/discovery", controller.FetchCustomOAuthDiscovery) customOAuthRoute.GET("/", controller.GetCustomOAuthProviders) customOAuthRoute.GET("/:id", controller.GetCustomOAuthProvider) customOAuthRoute.POST("/", controller.CreateCustomOAuthProvider) diff --git a/router/dashboard.go b/router/dashboard.go index 17132dfb2..2e486156d 100644 --- a/router/dashboard.go +++ b/router/dashboard.go @@ -9,6 +9,7 @@ import ( func SetDashboardRouter(router *gin.Engine) { apiRouter := router.Group("/") + apiRouter.Use(middleware.RouteTag("old_api")) apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) apiRouter.Use(middleware.GlobalAPIRateLimit()) apiRouter.Use(middleware.CORS()) diff --git a/router/main.go b/router/main.go index 45b3080f2..ac9506fe4 100644 --- a/router/main.go +++ b/router/main.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/middleware" "github.com/gin-gonic/gin" ) @@ -27,6 +28,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { } else { frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/") router.NoRoute(func(c *gin.Context) { + c.Set(middleware.RouteTagKey, "web") c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI)) }) } diff --git a/router/relay-router.go b/router/relay-router.go index 04584945b..3d38be5ee 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -17,6 +17,7 @@ func SetRelayRouter(router *gin.Engine) { router.Use(middleware.StatsMiddleware()) // https://platform.openai.com/docs/api-reference/introduction modelsRouter := router.Group("/v1/models") + modelsRouter.Use(middleware.RouteTag("relay")) modelsRouter.Use(middleware.TokenAuth()) { modelsRouter.GET("", func(c *gin.Context) { @@ -41,6 +42,7 @@ func SetRelayRouter(router *gin.Engine) { } geminiRouter := router.Group("/v1beta/models") + geminiRouter.Use(middleware.RouteTag("relay")) geminiRouter.Use(middleware.TokenAuth()) { geminiRouter.GET("", func(c *gin.Context) { @@ -49,6 +51,7 @@ func SetRelayRouter(router *gin.Engine) { } geminiCompatibleRouter := router.Group("/v1beta/openai/models") + geminiCompatibleRouter.Use(middleware.RouteTag("relay")) geminiCompatibleRouter.Use(middleware.TokenAuth()) { geminiCompatibleRouter.GET("", func(c *gin.Context) { @@ -57,12 +60,14 @@ func SetRelayRouter(router *gin.Engine) { } playgroundRouter := router.Group("/pg") + playgroundRouter.Use(middleware.RouteTag("relay")) playgroundRouter.Use(middleware.SystemPerformanceCheck()) playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute()) { playgroundRouter.POST("/chat/completions", controller.Playground) } relayV1Router := router.Group("/v1") + relayV1Router.Use(middleware.RouteTag("relay")) relayV1Router.Use(middleware.SystemPerformanceCheck()) relayV1Router.Use(middleware.TokenAuth()) relayV1Router.Use(middleware.ModelRequestRateLimit()) @@ -161,24 +166,28 @@ func SetRelayRouter(router *gin.Engine) { } relayMjRouter := router.Group("/mj") + relayMjRouter.Use(middleware.RouteTag("relay")) relayMjRouter.Use(middleware.SystemPerformanceCheck()) registerMjRouterGroup(relayMjRouter) relayMjModeRouter := router.Group("/:mode/mj") + relayMjModeRouter.Use(middleware.RouteTag("relay")) relayMjModeRouter.Use(middleware.SystemPerformanceCheck()) registerMjRouterGroup(relayMjModeRouter) //relayMjRouter.Use() relaySunoRouter := router.Group("/suno") + relaySunoRouter.Use(middleware.RouteTag("relay")) relaySunoRouter.Use(middleware.SystemPerformanceCheck()) relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute()) { relaySunoRouter.POST("/submit/:action", controller.RelayTask) - relaySunoRouter.POST("/fetch", controller.RelayTask) - relaySunoRouter.GET("/fetch/:id", controller.RelayTask) + relaySunoRouter.POST("/fetch", controller.RelayTaskFetch) + relaySunoRouter.GET("/fetch/:id", controller.RelayTaskFetch) } relayGeminiRouter := router.Group("/v1beta") + relayGeminiRouter.Use(middleware.RouteTag("relay")) relayGeminiRouter.Use(middleware.SystemPerformanceCheck()) relayGeminiRouter.Use(middleware.TokenAuth()) relayGeminiRouter.Use(middleware.ModelRequestRateLimit()) diff --git a/router/video-router.go b/router/video-router.go index d5fed1d78..461451104 100644 --- a/router/video-router.go +++ b/router/video-router.go @@ -8,32 +8,42 @@ import ( ) func SetVideoRouter(router *gin.Engine) { + // Video proxy: accepts either session auth (dashboard) or token auth (API clients) + videoProxyRouter := router.Group("/v1") + videoProxyRouter.Use(middleware.RouteTag("relay")) + videoProxyRouter.Use(middleware.TokenOrUserAuth()) + { + videoProxyRouter.GET("/videos/:task_id/content", controller.VideoProxy) + } + videoV1Router := router.Group("/v1") + videoV1Router.Use(middleware.RouteTag("relay")) videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) { - videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy) videoV1Router.POST("/video/generations", controller.RelayTask) - videoV1Router.GET("/video/generations/:task_id", controller.RelayTask) + videoV1Router.GET("/video/generations/:task_id", controller.RelayTaskFetch) videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask) } // openai compatible API video routes // docs: https://platform.openai.com/docs/api-reference/videos/create { videoV1Router.POST("/videos", controller.RelayTask) - videoV1Router.GET("/videos/:task_id", controller.RelayTask) + videoV1Router.GET("/videos/:task_id", controller.RelayTaskFetch) } klingV1Router := router.Group("/kling/v1") + klingV1Router.Use(middleware.RouteTag("relay")) klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute()) { klingV1Router.POST("/videos/text2video", controller.RelayTask) klingV1Router.POST("/videos/image2video", controller.RelayTask) - klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTask) - klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTask) + klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTaskFetch) + klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTaskFetch) } // Jimeng official API routes - direct mapping to official API format jimengOfficialGroup := router.Group("jimeng") + jimengOfficialGroup.Use(middleware.RouteTag("relay")) jimengOfficialGroup.Use(middleware.JimengRequestConvert(), middleware.TokenAuth(), middleware.Distribute()) { // Maps to: /?Action=CVSync2AsyncSubmitTask&Version=2022-08-31 and /?Action=CVSync2AsyncGetResult&Version=2022-08-31 diff --git a/router/web-router.go b/router/web-router.go index b053a3e63..17a8378dd 100644 --- a/router/web-router.go +++ b/router/web-router.go @@ -19,6 +19,7 @@ func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { router.Use(middleware.Cache()) router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/dist"))) router.NoRoute(func(c *gin.Context) { + c.Set(middleware.RouteTagKey, "web") if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") || strings.HasPrefix(c.Request.RequestURI, "/assets") { controller.RelayNotFound(c) return diff --git a/service/billing_session.go b/service/billing_session.go index 1a31316b5..f24b68e55 100644 --- a/service/billing_session.go +++ b/service/billing_session.go @@ -193,6 +193,11 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro // shouldTrust 统一信任额度检查,适用于钱包和订阅。 func (s *BillingSession) shouldTrust(c *gin.Context) bool { + // 异步任务(ForcePreConsume=true)必须预扣全额,不允许信任旁路 + if s.relayInfo.ForcePreConsume { + return false + } + trustQuota := common.GetTrustQuota() if trustQuota <= 0 { return false diff --git a/service/channel_affinity.go b/service/channel_affinity.go index fe1524c59..524c6574a 100644 --- a/service/channel_affinity.go +++ b/service/channel_affinity.go @@ -13,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/pkg/cachex" "github.com/QuantumNous/new-api/setting/operation_setting" + "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" "github.com/samber/hot" "github.com/tidwall/gjson" @@ -61,6 +62,12 @@ type ChannelAffinityStatsContext struct { TTLSeconds int64 } +const ( + cacheTokenRateModeCachedOverPrompt = "cached_over_prompt" + cacheTokenRateModeCachedOverPromptPlusCached = "cached_over_prompt_plus_cached" + cacheTokenRateModeMixed = "mixed" +) + type ChannelAffinityCacheStats struct { Enabled bool `json:"enabled"` Total int `json:"total"` @@ -565,9 +572,10 @@ func RecordChannelAffinity(c *gin.Context, channelID int) { } type ChannelAffinityUsageCacheStats struct { - RuleName string `json:"rule_name"` - UsingGroup string `json:"using_group"` - KeyFingerprint string `json:"key_fp"` + RuleName string `json:"rule_name"` + UsingGroup string `json:"using_group"` + KeyFingerprint string `json:"key_fp"` + CachedTokenRateMode string `json:"cached_token_rate_mode"` Hit int64 `json:"hit"` Total int64 `json:"total"` @@ -582,6 +590,8 @@ type ChannelAffinityUsageCacheStats struct { } type ChannelAffinityUsageCacheCounters struct { + CachedTokenRateMode string `json:"cached_token_rate_mode"` + Hit int64 `json:"hit"` Total int64 `json:"total"` WindowSeconds int64 `json:"window_seconds"` @@ -596,12 +606,17 @@ type ChannelAffinityUsageCacheCounters struct { var channelAffinityUsageCacheStatsLocks [64]sync.Mutex -func ObserveChannelAffinityUsageCacheFromContext(c *gin.Context, usage *dto.Usage) { +// ObserveChannelAffinityUsageCacheByRelayFormat records usage cache stats with a stable rate mode derived from relay format. +func ObserveChannelAffinityUsageCacheByRelayFormat(c *gin.Context, usage *dto.Usage, relayFormat types.RelayFormat) { + ObserveChannelAffinityUsageCacheFromContext(c, usage, cachedTokenRateModeByRelayFormat(relayFormat)) +} + +func ObserveChannelAffinityUsageCacheFromContext(c *gin.Context, usage *dto.Usage, cachedTokenRateMode string) { statsCtx, ok := GetChannelAffinityStatsContext(c) if !ok { return } - observeChannelAffinityUsageCache(statsCtx, usage) + observeChannelAffinityUsageCache(statsCtx, usage, cachedTokenRateMode) } func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) ChannelAffinityUsageCacheStats { @@ -628,6 +643,7 @@ func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) Chann } } return ChannelAffinityUsageCacheStats{ + CachedTokenRateMode: v.CachedTokenRateMode, RuleName: ruleName, UsingGroup: usingGroup, KeyFingerprint: keyFp, @@ -643,7 +659,7 @@ func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) Chann } } -func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usage *dto.Usage) { +func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usage *dto.Usage, cachedTokenRateMode string) { entryKey := channelAffinityUsageCacheEntryKey(statsCtx.RuleName, statsCtx.UsingGroup, statsCtx.KeyFingerprint) if entryKey == "" { return @@ -669,6 +685,14 @@ func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usag if !found { next = ChannelAffinityUsageCacheCounters{} } + currentMode := normalizeCachedTokenRateMode(cachedTokenRateMode) + if currentMode != "" { + if next.CachedTokenRateMode == "" { + next.CachedTokenRateMode = currentMode + } else if next.CachedTokenRateMode != currentMode && next.CachedTokenRateMode != cacheTokenRateModeMixed { + next.CachedTokenRateMode = cacheTokenRateModeMixed + } + } next.Total++ hit, cachedTokens, promptCacheHitTokens := usageCacheSignals(usage) if hit { @@ -684,6 +708,30 @@ func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usag _ = cache.SetWithTTL(entryKey, next, ttl) } +func normalizeCachedTokenRateMode(mode string) string { + switch mode { + case cacheTokenRateModeCachedOverPrompt: + return cacheTokenRateModeCachedOverPrompt + case cacheTokenRateModeCachedOverPromptPlusCached: + return cacheTokenRateModeCachedOverPromptPlusCached + case cacheTokenRateModeMixed: + return cacheTokenRateModeMixed + default: + return "" + } +} + +func cachedTokenRateModeByRelayFormat(relayFormat types.RelayFormat) string { + switch relayFormat { + case types.RelayFormatOpenAI, types.RelayFormatOpenAIResponses, types.RelayFormatOpenAIResponsesCompaction: + return cacheTokenRateModeCachedOverPrompt + case types.RelayFormatClaude: + return cacheTokenRateModeCachedOverPromptPlusCached + default: + return "" + } +} + func channelAffinityUsageCacheEntryKey(ruleName, usingGroup, keyFp string) string { ruleName = strings.TrimSpace(ruleName) usingGroup = strings.TrimSpace(usingGroup) diff --git a/service/channel_affinity_usage_cache_test.go b/service/channel_affinity_usage_cache_test.go new file mode 100644 index 000000000..64d3d715b --- /dev/null +++ b/service/channel_affinity_usage_cache_test.go @@ -0,0 +1,105 @@ +package service + +import ( + "fmt" + "net/http/httptest" + "testing" + "time" + + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/types" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP string) *gin.Context { + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + setChannelAffinityContext(ctx, channelAffinityMeta{ + CacheKey: fmt.Sprintf("test:%s:%s:%s", ruleName, usingGroup, keyFP), + TTLSeconds: 600, + RuleName: ruleName, + UsingGroup: usingGroup, + KeyFingerprint: keyFP, + }) + return ctx +} + +func TestObserveChannelAffinityUsageCacheByRelayFormat_ClaudeMode(t *testing.T) { + ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano()) + usingGroup := "default" + keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano()) + ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP) + + usage := &dto.Usage{ + PromptTokens: 100, + CompletionTokens: 40, + TotalTokens: 140, + PromptTokensDetails: dto.InputTokenDetails{ + CachedTokens: 30, + }, + } + + ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, types.RelayFormatClaude) + stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP) + + require.EqualValues(t, 1, stats.Total) + require.EqualValues(t, 1, stats.Hit) + require.EqualValues(t, 100, stats.PromptTokens) + require.EqualValues(t, 40, stats.CompletionTokens) + require.EqualValues(t, 140, stats.TotalTokens) + require.EqualValues(t, 30, stats.CachedTokens) + require.Equal(t, cacheTokenRateModeCachedOverPromptPlusCached, stats.CachedTokenRateMode) +} + +func TestObserveChannelAffinityUsageCacheByRelayFormat_MixedMode(t *testing.T) { + ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano()) + usingGroup := "default" + keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano()) + ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP) + + openAIUsage := &dto.Usage{ + PromptTokens: 100, + PromptTokensDetails: dto.InputTokenDetails{ + CachedTokens: 10, + }, + } + claudeUsage := &dto.Usage{ + PromptTokens: 80, + PromptTokensDetails: dto.InputTokenDetails{ + CachedTokens: 20, + }, + } + + ObserveChannelAffinityUsageCacheByRelayFormat(ctx, openAIUsage, types.RelayFormatOpenAI) + ObserveChannelAffinityUsageCacheByRelayFormat(ctx, claudeUsage, types.RelayFormatClaude) + stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP) + + require.EqualValues(t, 2, stats.Total) + require.EqualValues(t, 2, stats.Hit) + require.EqualValues(t, 180, stats.PromptTokens) + require.EqualValues(t, 30, stats.CachedTokens) + require.Equal(t, cacheTokenRateModeMixed, stats.CachedTokenRateMode) +} + +func TestObserveChannelAffinityUsageCacheByRelayFormat_UnsupportedModeKeepsEmpty(t *testing.T) { + ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano()) + usingGroup := "default" + keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano()) + ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP) + + usage := &dto.Usage{ + PromptTokens: 100, + PromptTokensDetails: dto.InputTokenDetails{ + CachedTokens: 25, + }, + } + + ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, types.RelayFormatGemini) + stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP) + + require.EqualValues(t, 1, stats.Total) + require.EqualValues(t, 1, stats.Hit) + require.EqualValues(t, 25, stats.CachedTokens) + require.Equal(t, "", stats.CachedTokenRateMode) +} diff --git a/service/codex_credential_refresh.go b/service/codex_credential_refresh.go index 0290fe516..2e681ee61 100644 --- a/service/codex_credential_refresh.go +++ b/service/codex_credential_refresh.go @@ -62,7 +62,7 @@ func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts Code refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - res, err := RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken) + res, err := RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy) if err != nil { return nil, nil, err } diff --git a/service/codex_oauth.go b/service/codex_oauth.go index 4c2dce1cc..33ef1d60a 100644 --- a/service/codex_oauth.go +++ b/service/codex_oauth.go @@ -12,6 +12,8 @@ import ( "net/url" "strings" "time" + + "github.com/QuantumNous/new-api/common" ) const ( @@ -38,12 +40,26 @@ type CodexOAuthAuthorizationFlow struct { } func RefreshCodexOAuthToken(ctx context.Context, refreshToken string) (*CodexOAuthTokenResult, error) { - client := &http.Client{Timeout: defaultHTTPTimeout} + return RefreshCodexOAuthTokenWithProxy(ctx, refreshToken, "") +} + +func RefreshCodexOAuthTokenWithProxy(ctx context.Context, refreshToken string, proxyURL string) (*CodexOAuthTokenResult, error) { + client, err := getCodexOAuthHTTPClient(proxyURL) + if err != nil { + return nil, err + } return refreshCodexOAuthToken(ctx, client, codexOAuthTokenURL, codexOAuthClientID, refreshToken) } func ExchangeCodexAuthorizationCode(ctx context.Context, code string, verifier string) (*CodexOAuthTokenResult, error) { - client := &http.Client{Timeout: defaultHTTPTimeout} + return ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, "") +} + +func ExchangeCodexAuthorizationCodeWithProxy(ctx context.Context, code string, verifier string, proxyURL string) (*CodexOAuthTokenResult, error) { + client, err := getCodexOAuthHTTPClient(proxyURL) + if err != nil { + return nil, err + } return exchangeCodexAuthorizationCode(ctx, client, codexOAuthTokenURL, codexOAuthClientID, code, verifier, codexOAuthRedirectURI) } @@ -104,7 +120,7 @@ func refreshCodexOAuthToken( ExpiresIn int `json:"expires_in"` } - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + if err := common.DecodeJson(resp.Body, &payload); err != nil { return nil, err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { @@ -165,7 +181,7 @@ func exchangeCodexAuthorizationCode( RefreshToken string `json:"refresh_token"` ExpiresIn int `json:"expires_in"` } - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + if err := common.DecodeJson(resp.Body, &payload); err != nil { return nil, err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { @@ -181,6 +197,19 @@ func exchangeCodexAuthorizationCode( }, nil } +func getCodexOAuthHTTPClient(proxyURL string) (*http.Client, error) { + baseClient, err := GetHttpClientWithProxy(strings.TrimSpace(proxyURL)) + if err != nil { + return nil, err + } + if baseClient == nil { + return &http.Client{Timeout: defaultHTTPTimeout}, nil + } + clientCopy := *baseClient + clientCopy.Timeout = defaultHTTPTimeout + return &clientCopy, nil +} + func buildCodexAuthorizeURL(state string, challenge string) (string, error) { u, err := url.Parse(codexOAuthAuthorizeURL) if err != nil { diff --git a/service/error.go b/service/error.go index 7a9d7a815..a2ff0aad7 100644 --- a/service/error.go +++ b/service/error.go @@ -206,3 +206,16 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError { return taskError } + +// TaskErrorFromAPIError 将 PreConsumeBilling 返回的 NewAPIError 转换为 TaskError。 +func TaskErrorFromAPIError(apiErr *types.NewAPIError) *dto.TaskError { + if apiErr == nil { + return nil + } + return &dto.TaskError{ + Code: string(apiErr.GetErrorCode()), + Message: apiErr.Err.Error(), + StatusCode: apiErr.StatusCode, + Error: apiErr.Err, + } +} diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 771da5b77..1c440911b 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -204,7 +204,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, return info } -func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PerCallPriceData) map[string]interface{} { +func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PriceData) map[string]interface{} { other := make(map[string]interface{}) other["model_price"] = priceData.ModelPrice other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio diff --git a/service/midjourney.go b/service/midjourney.go index 9b2eb5ca7..bdb0fe50a 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -19,7 +19,7 @@ import ( "github.com/gin-gonic/gin" ) -func CoverActionToModelName(mjAction string) string { +func CovertMjpActionToModelName(mjAction string) string { modelName := "mj_" + strings.ToLower(mjAction) if mjAction == constant.MjActionSwapFace { modelName = "swap_face" @@ -70,7 +70,7 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false } } - modelName := CoverActionToModelName(action) + modelName := CovertMjpActionToModelName(action) return modelName, nil, true } diff --git a/service/quota.go b/service/quota.go index 50421017e..7ee70edd5 100644 --- a/service/quota.go +++ b/service/quota.go @@ -236,6 +236,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod } func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) { + if usage != nil { + ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat()) + } useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens diff --git a/service/task_billing.go b/service/task_billing.go new file mode 100644 index 000000000..0da4cf431 --- /dev/null +++ b/service/task_billing.go @@ -0,0 +1,285 @@ +package service + +import ( + "context" + "fmt" + "strings" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/setting/ratio_setting" + "github.com/gin-gonic/gin" +) + +// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。 +// 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。 +func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("操作 %s", info.Action) + // 支持任务仅按次计费 + if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) { + logContent = fmt.Sprintf("%s,按次计费", logContent) + } else { + if len(info.PriceData.OtherRatios) > 0 { + var contents []string + for key, ra := range info.PriceData.OtherRatios { + if 1.0 != ra { + contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra)) + } + } + if len(contents) > 0 { + logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", ")) + } + } + } + other := make(map[string]interface{}) + other["request_path"] = c.Request.URL.Path + other["model_price"] = info.PriceData.ModelPrice + other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio + if info.PriceData.GroupRatioInfo.HasSpecialRatio { + other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio + } + if info.IsModelMapped { + other["is_model_mapped"] = true + other["upstream_model_name"] = info.UpstreamModelName + } + model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ + ChannelId: info.ChannelId, + ModelName: info.OriginModelName, + TokenName: tokenName, + Quota: info.PriceData.Quota, + Content: logContent, + TokenId: info.TokenId, + Group: info.UsingGroup, + Other: other, + }) + model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota) + model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota) +} + +// --------------------------------------------------------------------------- +// 异步任务计费辅助函数 +// --------------------------------------------------------------------------- + +// resolveTokenKey 通过 TokenId 运行时获取令牌 Key(用于 Redis 缓存操作)。 +// 如果令牌已被删除或查询失败,返回空字符串。 +func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string { + token, err := model.GetTokenById(tokenId) + if err != nil { + logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error())) + return "" + } + return token.Key +} + +// taskIsSubscription 判断任务是否通过订阅计费。 +func taskIsSubscription(task *model.Task) bool { + return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0 +} + +// taskAdjustFunding 调整任务的资金来源(钱包或订阅),delta > 0 表示扣费,delta < 0 表示退还。 +func taskAdjustFunding(task *model.Task, delta int) error { + if taskIsSubscription(task) { + return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta)) + } + if delta > 0 { + return model.DecreaseUserQuota(task.UserId, delta) + } + return model.IncreaseUserQuota(task.UserId, -delta, false) +} + +// taskAdjustTokenQuota 调整任务的令牌额度,delta > 0 表示扣费,delta < 0 表示退还。 +// 需要通过 resolveTokenKey 运行时获取 key(不从 PrivateData 中读取)。 +func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) { + if task.PrivateData.TokenId <= 0 || delta == 0 { + return + } + tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID) + if tokenKey == "" { + return + } + var err error + if delta > 0 { + err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta) + } else { + err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta) + } + if err != nil { + logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error())) + } +} + +// taskBillingOther 从 task 的 BillingContext 构建日志 Other 字段。 +func taskBillingOther(task *model.Task) map[string]interface{} { + other := make(map[string]interface{}) + if bc := task.PrivateData.BillingContext; bc != nil { + other["model_price"] = bc.ModelPrice + other["group_ratio"] = bc.GroupRatio + if len(bc.OtherRatios) > 0 { + for k, v := range bc.OtherRatios { + other[k] = v + } + } + } + props := task.Properties + if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName { + other["is_model_mapped"] = true + other["upstream_model_name"] = props.UpstreamModelName + } + return other +} + +// taskModelName 从 BillingContext 或 Properties 中获取模型名称。 +func taskModelName(task *model.Task) string { + if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" { + return bc.OriginModelName + } + return task.Properties.OriginModelName +} + +// RefundTaskQuota 统一的任务失败退款逻辑。 +// 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。 +func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) { + quota := task.Quota + if quota == 0 { + return + } + + // 1. 退还资金来源(钱包或订阅) + if err := taskAdjustFunding(task, -quota); err != nil { + logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error())) + return + } + + // 2. 退还令牌额度 + taskAdjustTokenQuota(ctx, task, -quota) + + // 3. 记录日志 + other := taskBillingOther(task) + other["task_id"] = task.TaskID + other["reason"] = reason + model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{ + UserId: task.UserId, + LogType: model.LogTypeRefund, + Content: "", + ChannelId: task.ChannelId, + ModelName: taskModelName(task), + Quota: quota, + TokenId: task.PrivateData.TokenId, + Group: task.Group, + Other: other, + }) +} + +// RecalculateTaskQuota 通用的异步差额结算。 +// actualQuota 是任务完成后的实际应扣额度,与预扣额度 (task.Quota) 做差额结算。 +// reason 用于日志记录(例如 "token重算" 或 "adaptor调整")。 +func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) { + if actualQuota <= 0 { + return + } + preConsumedQuota := task.Quota + quotaDelta := actualQuota - preConsumedQuota + + if quotaDelta == 0 { + logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)", + task.TaskID, logger.LogQuota(actualQuota), reason)) + return + } + + logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,%s)", + task.TaskID, + logger.LogQuota(quotaDelta), + logger.LogQuota(actualQuota), + logger.LogQuota(preConsumedQuota), + reason, + )) + + // 调整资金来源 + if err := taskAdjustFunding(task, quotaDelta); err != nil { + logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error())) + return + } + + // 调整令牌额度 + taskAdjustTokenQuota(ctx, task, quotaDelta) + + task.Quota = actualQuota + + var logType int + var logQuota int + if quotaDelta > 0 { + logType = model.LogTypeConsume + logQuota = quotaDelta + model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) + model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) + } else { + logType = model.LogTypeRefund + logQuota = -quotaDelta + } + other := taskBillingOther(task) + other["task_id"] = task.TaskID + other["reason"] = reason + other["pre_consumed_quota"] = preConsumedQuota + other["actual_quota"] = actualQuota + model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{ + UserId: task.UserId, + LogType: logType, + Content: "", + ChannelId: task.ChannelId, + ModelName: taskModelName(task), + Quota: logQuota, + TokenId: task.PrivateData.TokenId, + Group: task.Group, + Other: other, + }) +} + +// RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。 +// 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度, +// 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。 +func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) { + if totalTokens <= 0 { + return + } + + modelName := taskModelName(task) + + // 获取模型价格和倍率 + modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName) + // 只有配置了倍率(非固定价格)时才按 token 重新计费 + if !hasRatioSetting || modelRatio <= 0 { + return + } + + // 获取用户和组的倍率信息 + group := task.Group + if group == "" { + user, err := model.GetUserById(task.UserId, false) + if err == nil { + group = user.Group + } + } + if group == "" { + return + } + + groupRatio := ratio_setting.GetGroupRatio(group) + userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group) + + var finalGroupRatio float64 + if hasUserGroupRatio { + finalGroupRatio = userGroupRatio + } else { + finalGroupRatio = groupRatio + } + + // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio + actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio) + + reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio) + RecalculateTaskQuota(ctx, task, actualQuota, reason) +} diff --git a/service/task_billing_test.go b/service/task_billing_test.go new file mode 100644 index 000000000..1145bba54 --- /dev/null +++ b/service/task_billing_test.go @@ -0,0 +1,712 @@ +package service + +import ( + "context" + "encoding/json" + "net/http" + "os" + "testing" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func TestMain(m *testing.M) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + panic("failed to open test db: " + err.Error()) + } + sqlDB, err := db.DB() + if err != nil { + panic("failed to get sql.DB: " + err.Error()) + } + sqlDB.SetMaxOpenConns(1) + + model.DB = db + model.LOG_DB = db + + common.UsingSQLite = true + common.RedisEnabled = false + common.BatchUpdateEnabled = false + common.LogConsumeEnabled = true + + if err := db.AutoMigrate( + &model.Task{}, + &model.User{}, + &model.Token{}, + &model.Log{}, + &model.Channel{}, + &model.UserSubscription{}, + ); err != nil { + panic("failed to migrate: " + err.Error()) + } + + os.Exit(m.Run()) +} + +// --------------------------------------------------------------------------- +// Seed helpers +// --------------------------------------------------------------------------- + +func truncate(t *testing.T) { + t.Helper() + t.Cleanup(func() { + model.DB.Exec("DELETE FROM tasks") + model.DB.Exec("DELETE FROM users") + model.DB.Exec("DELETE FROM tokens") + model.DB.Exec("DELETE FROM logs") + model.DB.Exec("DELETE FROM channels") + model.DB.Exec("DELETE FROM user_subscriptions") + }) +} + +func seedUser(t *testing.T, id int, quota int) { + t.Helper() + user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled} + require.NoError(t, model.DB.Create(user).Error) +} + +func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) { + t.Helper() + token := &model.Token{ + Id: id, + UserId: userId, + Key: key, + Name: "test_token", + Status: common.TokenStatusEnabled, + RemainQuota: remainQuota, + UsedQuota: 0, + } + require.NoError(t, model.DB.Create(token).Error) +} + +func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) { + t.Helper() + sub := &model.UserSubscription{ + Id: id, + UserId: userId, + AmountTotal: amountTotal, + AmountUsed: amountUsed, + Status: "active", + StartTime: time.Now().Unix(), + EndTime: time.Now().Add(30 * 24 * time.Hour).Unix(), + } + require.NoError(t, model.DB.Create(sub).Error) +} + +func seedChannel(t *testing.T, id int) { + t.Helper() + ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled} + require.NoError(t, model.DB.Create(ch).Error) +} + +func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task { + return &model.Task{ + TaskID: "task_" + time.Now().Format("150405.000"), + UserId: userId, + ChannelId: channelId, + Quota: quota, + Status: model.TaskStatus(model.TaskStatusInProgress), + Group: "default", + Data: json.RawMessage(`{}`), + CreatedAt: time.Now().Unix(), + UpdatedAt: time.Now().Unix(), + Properties: model.Properties{ + OriginModelName: "test-model", + }, + PrivateData: model.TaskPrivateData{ + BillingSource: billingSource, + SubscriptionId: subscriptionId, + TokenId: tokenId, + BillingContext: &model.TaskBillingContext{ + ModelPrice: 0.02, + GroupRatio: 1.0, + OriginModelName: "test-model", + }, + }, + } +} + +// --------------------------------------------------------------------------- +// Read-back helpers +// --------------------------------------------------------------------------- + +func getUserQuota(t *testing.T, id int) int { + t.Helper() + var user model.User + require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error) + return user.Quota +} + +func getTokenRemainQuota(t *testing.T, id int) int { + t.Helper() + var token model.Token + require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error) + return token.RemainQuota +} + +func getTokenUsedQuota(t *testing.T, id int) int { + t.Helper() + var token model.Token + require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error) + return token.UsedQuota +} + +func getSubscriptionUsed(t *testing.T, id int) int64 { + t.Helper() + var sub model.UserSubscription + require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error) + return sub.AmountUsed +} + +func getLastLog(t *testing.T) *model.Log { + t.Helper() + var log model.Log + err := model.LOG_DB.Order("id desc").First(&log).Error + if err != nil { + return nil + } + return &log +} + +func countLogs(t *testing.T) int64 { + t.Helper() + var count int64 + model.LOG_DB.Model(&model.Log{}).Count(&count) + return count +} + +// =========================================================================== +// RefundTaskQuota tests +// =========================================================================== + +func TestRefundTaskQuota_Wallet(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 1, 1, 1 + const initQuota, preConsumed = 10000, 3000 + const tokenRemain = 5000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-test-key", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + + RefundTaskQuota(ctx, task, "task failed: upstream error") + + // User quota should increase by preConsumed + assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) + + // Token remain_quota should increase, used_quota should decrease + assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) + assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID)) + + // A refund log should be created + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) + assert.Equal(t, preConsumed, log.Quota) + assert.Equal(t, "test-model", log.ModelName) +} + +func TestRefundTaskQuota_Subscription(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID, subID = 2, 2, 2, 1 + const preConsumed = 2000 + const subTotal, subUsed int64 = 100000, 50000 + const tokenRemain = 8000 + + seedUser(t, userID, 0) + seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain) + seedChannel(t, channelID) + seedSubscription(t, subID, userID, subTotal, subUsed) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID) + + RefundTaskQuota(ctx, task, "subscription task failed") + + // Subscription used should decrease by preConsumed + assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID)) + + // Token should also be refunded + assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +func TestRefundTaskQuota_ZeroQuota(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID = 3 + seedUser(t, userID, 5000) + + task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0) + + RefundTaskQuota(ctx, task, "zero quota task") + + // No change to user quota + assert.Equal(t, 5000, getUserQuota(t, userID)) + + // No log created + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestRefundTaskQuota_NoToken(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, channelID = 4, 4 + const initQuota, preConsumed = 10000, 1500 + + seedUser(t, userID, initQuota) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0 + + RefundTaskQuota(ctx, task, "no token task failed") + + // User quota refunded + assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) + + // Log created + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +// =========================================================================== +// RecalculateTaskQuota tests +// =========================================================================== + +func TestRecalculate_PositiveDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 10, 10, 10 + const initQuota, preConsumed = 10000, 2000 + const actualQuota = 3000 // under-charged by 1000 + const tokenRemain = 5000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment") + + // User quota should decrease by the delta (1000 additional charge) + assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID)) + + // Token should also be charged the delta + assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID)) + + // task.Quota should be updated to actualQuota + assert.Equal(t, actualQuota, task.Quota) + + // Log type should be Consume (additional charge) + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeConsume, log.Type) + assert.Equal(t, actualQuota-preConsumed, log.Quota) +} + +func TestRecalculate_NegativeDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 11, 11, 11 + const initQuota, preConsumed = 10000, 5000 + const actualQuota = 3000 // over-charged by 2000 + const tokenRemain = 5000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment") + + // User quota should increase by abs(delta) = 2000 (refund overpayment) + assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID)) + + // Token should be refunded the difference + assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) + + // task.Quota updated + assert.Equal(t, actualQuota, task.Quota) + + // Log type should be Refund + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) + assert.Equal(t, preConsumed-actualQuota, log.Quota) +} + +func TestRecalculate_ZeroDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID = 12 + const initQuota, preConsumed = 10000, 3000 + + seedUser(t, userID, initQuota) + + task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, preConsumed, "exact match") + + // No change to user quota + assert.Equal(t, initQuota, getUserQuota(t, userID)) + + // No log created (delta is zero) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestRecalculate_ActualQuotaZero(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID = 13 + const initQuota = 10000 + + seedUser(t, userID, initQuota) + + task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, 0, "zero actual") + + // No change (early return) + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestRecalculate_Subscription_NegativeDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID, subID = 14, 14, 14, 2 + const preConsumed = 5000 + const actualQuota = 2000 // over-charged by 3000 + const subTotal, subUsed int64 = 100000, 50000 + const tokenRemain = 8000 + + seedUser(t, userID, 0) + seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain) + seedChannel(t, channelID) + seedSubscription(t, subID, userID, subTotal, subUsed) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID) + + RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge") + + // Subscription used should decrease by delta (refund 3000) + assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID)) + + // Token refunded + assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) + + assert.Equal(t, actualQuota, task.Quota) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +// =========================================================================== +// CAS + Billing integration tests +// Simulates the flow in updateVideoSingleTask (service/task_polling.go) +// =========================================================================== + +// simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask. +// It takes a persisted task (already in DB), applies the new status, and performs +// the conditional update + billing exactly as the polling loop does. +func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) { + snap := task.Snapshot() + + shouldRefund := false + shouldSettle := false + quota := task.Quota + + task.Status = newStatus + switch string(newStatus) { + case model.TaskStatusSuccess: + task.Progress = "100%" + task.FinishTime = 9999 + shouldSettle = true + case model.TaskStatusFailure: + task.Progress = "100%" + task.FinishTime = 9999 + task.FailReason = "upstream error" + if quota != 0 { + shouldRefund = true + } + default: + task.Progress = "50%" + } + + isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure) + if isDone && snap.Status != task.Status { + won, err := task.UpdateWithStatus(snap.Status) + if err != nil { + shouldRefund = false + shouldSettle = false + } else if !won { + shouldRefund = false + shouldSettle = false + } + } else if !snap.Equal(task.Snapshot()) { + _, _ = task.UpdateWithStatus(snap.Status) + } + + if shouldSettle && actualQuota > 0 { + RecalculateTaskQuota(ctx, task, actualQuota, "test settle") + } + if shouldRefund { + RefundTaskQuota(ctx, task, task.FailReason) + } +} + +func TestCASGuardedRefund_Win(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 20, 20, 20 + const initQuota, preConsumed = 10000, 4000 + const tokenRemain = 6000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + require.NoError(t, model.DB.Create(task).Error) + + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0) + + // CAS wins: task in DB should now be FAILURE + var reloaded model.Task + require.NoError(t, model.DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status) + + // Refund should have happened + assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +func TestCASGuardedRefund_Lose(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 21, 21, 21 + const initQuota, preConsumed = 10000, 4000 + const tokenRemain = 6000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain) + seedChannel(t, channelID) + + // Create task with IN_PROGRESS in DB + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + require.NoError(t, model.DB.Create(task).Error) + + // Simulate another process already transitioning to FAILURE + model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure) + + // Our process still has the old in-memory state (IN_PROGRESS) and tries to transition + // task.Status is still IN_PROGRESS in the snapshot + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0) + + // CAS lost: user quota should NOT change (no double refund) + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) + + // No billing log should be created + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestCASGuardedSettle_Win(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 22, 22, 22 + const initQuota, preConsumed = 10000, 5000 + const actualQuota = 3000 // over-charged, should get partial refund + const tokenRemain = 8000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + require.NoError(t, model.DB.Create(task).Error) + + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota) + + // CAS wins: task should be SUCCESS + var reloaded model.Task + require.NoError(t, model.DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status) + + // Settlement should refund the over-charge (5000 - 3000 = 2000 back to user) + assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID)) + assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) + + // task.Quota should be updated to actualQuota + assert.Equal(t, actualQuota, task.Quota) +} + +func TestNonTerminalUpdate_NoBilling(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, channelID = 23, 23 + const initQuota, preConsumed = 10000, 3000 + + seedUser(t, userID, initQuota) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + task.Progress = "20%" + require.NoError(t, model.DB.Create(task).Error) + + // Simulate a non-terminal poll update (still IN_PROGRESS, progress changed) + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0) + + // User quota should NOT change + assert.Equal(t, initQuota, getUserQuota(t, userID)) + + // No billing log + assert.Equal(t, int64(0), countLogs(t)) + + // Task progress should be updated in DB + var reloaded model.Task + require.NoError(t, model.DB.First(&reloaded, task.ID).Error) + assert.Equal(t, "50%", reloaded.Progress) +} + +// =========================================================================== +// Mock adaptor for settleTaskBillingOnComplete tests +// =========================================================================== + +type mockAdaptor struct { + adjustReturn int +} + +func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {} +func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { return nil, nil } +func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil } +func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int { + return m.adjustReturn +} + +// =========================================================================== +// PerCallBilling tests — settleTaskBillingOnComplete +// =========================================================================== + +func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 30, 30, 30 + const initQuota, preConsumed = 10000, 5000 + const tokenRemain = 8000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.PrivateData.BillingContext.PerCallBilling = true + + adaptor := &mockAdaptor{adjustReturn: 2000} + taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess} + + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + + // Per-call: no adjustment despite adaptor returning 2000 + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) + assert.Equal(t, preConsumed, task.Quota) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 31, 31, 31 + const initQuota, preConsumed = 10000, 4000 + const tokenRemain = 7000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.PrivateData.BillingContext.PerCallBilling = true + + adaptor := &mockAdaptor{adjustReturn: 0} + taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999} + + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + + // Per-call: no recalculation by tokens + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) + assert.Equal(t, preConsumed, task.Quota) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 32, 32, 32 + const initQuota, preConsumed = 10000, 5000 + const adaptorQuota = 3000 + const tokenRemain = 8000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + // PerCallBilling defaults to false + + adaptor := &mockAdaptor{adjustReturn: adaptorQuota} + taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess} + + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + + // Non-per-call: adaptor adjustment applies (refund 2000) + assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID)) + assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID)) + assert.Equal(t, adaptorQuota, task.Quota) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} diff --git a/service/task_polling.go b/service/task_polling.go new file mode 100644 index 000000000..3c5cab5b0 --- /dev/null +++ b/service/task_polling.go @@ -0,0 +1,539 @@ +package service + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "sort" + "strings" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" + relaycommon "github.com/QuantumNous/new-api/relay/common" + + "github.com/samber/lo" +) + +// TaskPollingAdaptor 定义轮询所需的最小适配器接口,避免 service -> relay 的循环依赖 +type TaskPollingAdaptor interface { + Init(info *relaycommon.RelayInfo) + FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error) + ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error) + // AdjustBillingOnComplete 在任务到达终态(成功/失败)时由轮询循环调用。 + // 返回正数触发差额结算(补扣/退还),返回 0 保持预扣费金额不变。 + AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int +} + +// GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。 +// 打破 service -> relay -> relay/channel -> service 的循环依赖。 +var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor + +// sweepTimedOutTasks 在主轮询之前独立清理超时任务。 +// 每次最多处理 100 条,剩余的下个周期继续处理。 +// 使用 per-task CAS (UpdateWithStatus) 防止覆盖被正常轮询已推进的任务。 +func sweepTimedOutTasks(ctx context.Context) { + if constant.TaskTimeoutMinutes <= 0 { + return + } + cutoff := time.Now().Unix() - int64(constant.TaskTimeoutMinutes)*60 + tasks := model.GetTimedOutUnfinishedTasks(cutoff, 100) + if len(tasks) == 0 { + return + } + + const legacyTaskCutoff int64 = 1740182400 // 2026-02-22 00:00:00 UTC + reason := fmt.Sprintf("任务超时(%d分钟)", constant.TaskTimeoutMinutes) + legacyReason := "任务超时(旧系统遗留任务,不进行退款,请联系管理员)" + now := time.Now().Unix() + timedOutCount := 0 + + for _, task := range tasks { + isLegacy := task.SubmitTime > 0 && task.SubmitTime < legacyTaskCutoff + + oldStatus := task.Status + task.Status = model.TaskStatusFailure + task.Progress = "100%" + task.FinishTime = now + if isLegacy { + task.FailReason = legacyReason + } else { + task.FailReason = reason + } + + won, err := task.UpdateWithStatus(oldStatus) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("sweepTimedOutTasks CAS update error for task %s: %v", task.TaskID, err)) + continue + } + if !won { + logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: task %s already transitioned, skip", task.TaskID)) + continue + } + timedOutCount++ + if !isLegacy && task.Quota != 0 { + RefundTaskQuota(ctx, task, reason) + } + } + + if timedOutCount > 0 { + logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: timed out %d tasks", timedOutCount)) + } +} + +// TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务 +func TaskPollingLoop() { + for { + time.Sleep(time.Duration(15) * time.Second) + common.SysLog("任务进度轮询开始") + ctx := context.TODO() + sweepTimedOutTasks(ctx) + allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit) + platformTask := make(map[constant.TaskPlatform][]*model.Task) + for _, t := range allTasks { + platformTask[t.Platform] = append(platformTask[t.Platform], t) + } + for platform, tasks := range platformTask { + if len(tasks) == 0 { + continue + } + taskChannelM := make(map[int][]string) + taskM := make(map[string]*model.Task) + nullTaskIds := make([]int64, 0) + for _, task := range tasks { + upstreamID := task.GetUpstreamTaskID() + if upstreamID == "" { + // 统计失败的未完成任务 + nullTaskIds = append(nullTaskIds, task.ID) + continue + } + taskM[upstreamID] = task + taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID) + } + if len(nullTaskIds) > 0 { + err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{ + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) + } else { + logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) + } + } + if len(taskChannelM) == 0 { + continue + } + + DispatchPlatformUpdate(platform, taskChannelM, taskM) + } + common.SysLog("任务进度轮询完成") + } +} + +// DispatchPlatformUpdate 按平台分发轮询更新 +func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) { + switch platform { + case constant.TaskPlatformMidjourney: + // MJ 轮询由其自身处理,这里预留入口 + case constant.TaskPlatformSuno: + _ = UpdateSunoTasks(context.Background(), taskChannelM, taskM) + default: + if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil { + common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err)) + } + } +} + +// UpdateSunoTasks 按渠道更新所有 Suno 任务 +func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { + for channelId, taskIds := range taskChannelM { + err := updateSunoTasks(ctx, channelId, taskIds, taskM) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error())) + } + } + return nil +} + +func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { + logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + return nil + } + ch, err := model.CacheGetChannel(channelId) + if err != nil { + common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) + // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values) + var failedIDs []int64 + for _, upstreamID := range taskIds { + if t, ok := taskM[upstreamID]; ok { + failedIDs = append(failedIDs, t.ID) + } + } + err = model.TaskBulkUpdateByID(failedIDs, map[string]any{ + "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", err)) + } + return err + } + adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno) + if adaptor == nil { + return errors.New("adaptor not found") + } + proxy := ch.GetSetting().Proxy + resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{ + "ids": taskIds, + }, proxy) + if err != nil { + common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) + return err + } + if resp.StatusCode != http.StatusOK { + logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + return fmt.Errorf("Get Task status code: %d", resp.StatusCode) + } + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.SysLog(fmt.Sprintf("Get Suno Task parse body error: %v", err)) + return err + } + var responseItems dto.TaskResponse[[]dto.SunoDataResponse] + err = common.Unmarshal(responseBody, &responseItems) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("Get Suno Task parse body error2: %v, body: %s", err, string(responseBody))) + return err + } + if !responseItems.IsSuccess() { + common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody))) + return err + } + + for _, responseItem := range responseItems.Data { + task := taskM[responseItem.TaskID] + if !taskNeedsUpdate(task, responseItem) { + continue + } + + task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status) + task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason) + task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime) + task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) + task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) + if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { + logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) + task.Progress = "100%" + RefundTaskQuota(ctx, task, task.FailReason) + } + if responseItem.Status == model.TaskStatusSuccess { + task.Progress = "100%" + } + task.Data = responseItem.Data + + err = task.Update() + if err != nil { + common.SysLog("UpdateSunoTask task error: " + err.Error()) + } + } + return nil +} + +// taskNeedsUpdate 检查 Suno 任务是否需要更新 +func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool { + if oldTask.SubmitTime != newTask.SubmitTime { + return true + } + if oldTask.StartTime != newTask.StartTime { + return true + } + if oldTask.FinishTime != newTask.FinishTime { + return true + } + if string(oldTask.Status) != newTask.Status { + return true + } + if oldTask.FailReason != newTask.FailReason { + return true + } + + if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" { + return true + } + + oldData, _ := common.Marshal(oldTask.Data) + newData, _ := common.Marshal(newTask.Data) + + sort.Slice(oldData, func(i, j int) bool { + return oldData[i] < oldData[j] + }) + sort.Slice(newData, func(i, j int) bool { + return newData[i] < newData[j] + }) + + if string(oldData) != string(newData) { + return true + } + return false +} + +// UpdateVideoTasks 按渠道更新所有视频任务 +func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { + for channelId, taskIds := range taskChannelM { + if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil { + logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) + } + } + return nil +} + +func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { + logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + return nil + } + cacheGetChannel, err := model.CacheGetChannel(channelId) + if err != nil { + // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values) + var failedIDs []int64 + for _, upstreamID := range taskIds { + if t, ok := taskM[upstreamID]; ok { + failedIDs = append(failedIDs, t.ID) + } + } + errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{ + "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if errUpdate != nil { + common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) + } + return fmt.Errorf("CacheGetChannel failed: %w", err) + } + adaptor := GetTaskAdaptorFunc(platform) + if adaptor == nil { + return fmt.Errorf("video adaptor not found") + } + info := &relaycommon.RelayInfo{} + info.ChannelMeta = &relaycommon.ChannelMeta{ + ChannelBaseUrl: cacheGetChannel.GetBaseURL(), + } + info.ApiKey = cacheGetChannel.Key + adaptor.Init(info) + for _, taskId := range taskIds { + if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { + logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) + } + } + return nil +} + +func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error { + baseURL := constant.ChannelBaseURLs[ch.Type] + if ch.GetBaseURL() != "" { + baseURL = ch.GetBaseURL() + } + proxy := ch.GetSetting().Proxy + + task := taskM[taskId] + if task == nil { + logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) + return fmt.Errorf("task %s not found", taskId) + } + key := ch.Key + + privateData := task.PrivateData + if privateData.Key != "" { + key = privateData.Key + } + resp, err := adaptor.FetchTask(baseURL, key, map[string]any{ + "task_id": task.GetUpstreamTaskID(), + "action": task.Action, + }, proxy) + if err != nil { + return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) + } + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("readAll failed for task %s: %w", taskId, err) + } + + logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody))) + + snap := task.Snapshot() + + taskResult := &relaycommon.TaskInfo{} + // try parse as New API response format + var responseItems dto.TaskResponse[model.Task] + if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() { + logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems)) + t := responseItems.Data + taskResult.TaskID = t.TaskID + taskResult.Status = string(t.Status) + taskResult.Url = t.GetResultURL() + taskResult.Progress = t.Progress + taskResult.Reason = t.FailReason + task.Data = t.Data + } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { + return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) + } else { + task.Data = redactVideoResponseBody(responseBody) + } + + logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult)) + + now := time.Now().Unix() + if taskResult.Status == "" { + taskResult = relaycommon.FailTaskInfo("upstream returned empty status") + } + + shouldRefund := false + shouldSettle := false + quota := task.Quota + + task.Status = model.TaskStatus(taskResult.Status) + switch taskResult.Status { + case model.TaskStatusSubmitted: + task.Progress = taskcommon.ProgressSubmitted + case model.TaskStatusQueued: + task.Progress = taskcommon.ProgressQueued + case model.TaskStatusInProgress: + task.Progress = taskcommon.ProgressInProgress + if task.StartTime == 0 { + task.StartTime = now + } + case model.TaskStatusSuccess: + task.Progress = taskcommon.ProgressComplete + if task.FinishTime == 0 { + task.FinishTime = now + } + if strings.HasPrefix(taskResult.Url, "data:") { + // data: URI (e.g. Vertex base64 encoded video) — keep in Data, not in ResultURL + } else if taskResult.Url != "" { + // Direct upstream URL (e.g. Kling, Ali, Doubao, etc.) + task.PrivateData.ResultURL = taskResult.Url + } else { + // No URL from adaptor — construct proxy URL using public task ID + task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) + } + shouldSettle = true + case model.TaskStatusFailure: + logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) + task.Status = model.TaskStatusFailure + task.Progress = taskcommon.ProgressComplete + if task.FinishTime == 0 { + task.FinishTime = now + } + task.FailReason = taskResult.Reason + logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) + taskResult.Progress = taskcommon.ProgressComplete + if quota != 0 { + shouldRefund = true + } + default: + return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, task.TaskID) + } + if taskResult.Progress != "" { + task.Progress = taskResult.Progress + } + + isDone := task.Status == model.TaskStatusSuccess || task.Status == model.TaskStatusFailure + if isDone && snap.Status != task.Status { + won, err := task.UpdateWithStatus(snap.Status) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("UpdateWithStatus failed for task %s: %s", task.TaskID, err.Error())) + shouldRefund = false + shouldSettle = false + } else if !won { + logger.LogWarn(ctx, fmt.Sprintf("Task %s already transitioned by another process, skip billing", task.TaskID)) + shouldRefund = false + shouldSettle = false + } + } else if !snap.Equal(task.Snapshot()) { + if _, err := task.UpdateWithStatus(snap.Status); err != nil { + logger.LogError(ctx, fmt.Sprintf("Failed to update task %s: %s", task.TaskID, err.Error())) + } + } else { + // No changes, skip update + logger.LogDebug(ctx, fmt.Sprintf("No update needed for task %s", task.TaskID)) + } + + if shouldSettle { + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + } + if shouldRefund { + RefundTaskQuota(ctx, task, task.FailReason) + } + + return nil +} + +func redactVideoResponseBody(body []byte) []byte { + var m map[string]any + if err := common.Unmarshal(body, &m); err != nil { + return body + } + resp, _ := m["response"].(map[string]any) + if resp != nil { + delete(resp, "bytesBase64Encoded") + if v, ok := resp["video"].(string); ok { + resp["video"] = truncateBase64(v) + } + if vs, ok := resp["videos"].([]any); ok { + for i := range vs { + if vm, ok := vs[i].(map[string]any); ok { + delete(vm, "bytesBase64Encoded") + } + } + } + } + b, err := common.Marshal(m) + if err != nil { + return body + } + return b +} + +func truncateBase64(s string) string { + const maxKeep = 256 + if len(s) <= maxKeep { + return s + } + return s[:maxKeep] + "..." +} + +// settleTaskBillingOnComplete 任务完成时的统一计费调整。 +// 优先级:1. adaptor.AdjustBillingOnComplete 返回正数 → 使用 adaptor 计算的额度 +// +// 2. taskResult.TotalTokens > 0 → 按 token 重算 +// 3. 都不满足 → 保持预扣额度不变 +func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) { + // 0. 按次计费的任务不做差额结算 + if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling { + logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID)) + return + } + // 1. 优先让 adaptor 决定最终额度 + if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 { + RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整") + return + } + // 2. 回退到 token 重算 + if taskResult.TotalTokens > 0 { + RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens) + return + } + // 3. 无调整,保持预扣额度 +} diff --git a/service/violation_fee.go b/service/violation_fee.go index 400c10dd5..455088561 100644 --- a/service/violation_fee.go +++ b/service/violation_fee.go @@ -18,8 +18,9 @@ import ( ) const ( - ViolationFeeCodePrefix = "violation_fee." - CSAMViolationMarker = "Failed check: SAFETY_CHECK_TYPE_CSAM" + ViolationFeeCodePrefix = "violation_fee." + CSAMViolationMarker = "Failed check: SAFETY_CHECK_TYPE" + ContentViolatesUsageMarker = "Content violates usage guidelines" ) func IsViolationFeeCode(code types.ErrorCode) bool { @@ -30,11 +31,11 @@ func HasCSAMViolationMarker(err *types.NewAPIError) bool { if err == nil { return false } - if strings.Contains(err.Error(), CSAMViolationMarker) { + if strings.Contains(err.Error(), CSAMViolationMarker) || strings.Contains(err.Error(), ContentViolatesUsageMarker) { return true } msg := err.ToOpenAIError().Message - return strings.Contains(msg, CSAMViolationMarker) + return strings.Contains(msg, CSAMViolationMarker) || strings.Contains(err.Error(), ContentViolatesUsageMarker) } func WrapAsViolationFeeGrokCSAM(err *types.NewAPIError) *types.NewAPIError { diff --git a/setting/operation_setting/status_code_ranges.go b/setting/operation_setting/status_code_ranges.go index 698c87c91..7e3bc847a 100644 --- a/setting/operation_setting/status_code_ranges.go +++ b/setting/operation_setting/status_code_ranges.go @@ -26,6 +26,11 @@ var AutomaticRetryStatusCodeRanges = []StatusCodeRange{ {Start: 525, End: 599}, } +var alwaysSkipRetryStatusCodes = map[int]struct{}{ + 504: {}, + 524: {}, +} + func AutomaticDisableStatusCodesToString() string { return statusCodeRangesToString(AutomaticDisableStatusCodeRanges) } @@ -56,7 +61,15 @@ func AutomaticRetryStatusCodesFromString(s string) error { return nil } +func IsAlwaysSkipRetryStatusCode(code int) bool { + _, exists := alwaysSkipRetryStatusCodes[code] + return exists +} + func ShouldRetryByStatusCode(code int) bool { + if IsAlwaysSkipRetryStatusCode(code) { + return false + } return shouldMatchStatusCodeRanges(AutomaticRetryStatusCodeRanges, code) } diff --git a/setting/operation_setting/status_code_ranges_test.go b/setting/operation_setting/status_code_ranges_test.go index 5801824ac..4e292a368 100644 --- a/setting/operation_setting/status_code_ranges_test.go +++ b/setting/operation_setting/status_code_ranges_test.go @@ -62,6 +62,8 @@ func TestShouldRetryByStatusCode(t *testing.T) { require.True(t, ShouldRetryByStatusCode(429)) require.True(t, ShouldRetryByStatusCode(500)) + require.False(t, ShouldRetryByStatusCode(504)) + require.False(t, ShouldRetryByStatusCode(524)) require.False(t, ShouldRetryByStatusCode(400)) require.False(t, ShouldRetryByStatusCode(200)) } @@ -77,3 +79,9 @@ func TestShouldRetryByStatusCode_DefaultMatchesLegacyBehavior(t *testing.T) { require.False(t, ShouldRetryByStatusCode(524)) require.True(t, ShouldRetryByStatusCode(599)) } + +func TestIsAlwaysSkipRetryStatusCode(t *testing.T) { + require.True(t, IsAlwaysSkipRetryStatusCode(504)) + require.True(t, IsAlwaysSkipRetryStatusCode(524)) + require.False(t, IsAlwaysSkipRetryStatusCode(500)) +} diff --git a/types/price_data.go b/types/price_data.go index 3f7121b8c..93bc6ae8d 100644 --- a/types/price_data.go +++ b/types/price_data.go @@ -22,7 +22,8 @@ type PriceData struct { AudioCompletionRatio float64 OtherRatios map[string]float64 UsePrice bool - QuotaToPreConsume int // 预消耗额度 + Quota int // 按次计费的最终额度(MJ / Task) + QuotaToPreConsume int // 按量计费的预消耗额度 GroupRatioInfo GroupRatioInfo } @@ -36,12 +37,6 @@ func (p *PriceData) AddOtherRatio(key string, ratio float64) { p.OtherRatios[key] = ratio } -type PerCallPriceData struct { - ModelPrice float64 - Quota int - GroupRatioInfo GroupRatioInfo -} - func (p *PriceData) ToSetting() string { return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio) } diff --git a/web/src/components/auth/LoginForm.jsx b/web/src/components/auth/LoginForm.jsx index 636317e44..7e8c0ce01 100644 --- a/web/src/components/auth/LoginForm.jsx +++ b/web/src/components/auth/LoginForm.jsx @@ -29,6 +29,7 @@ import { showSuccess, updateAPI, getSystemName, + getOAuthProviderIcon, setUserData, onGitHubOAuthClicked, onDiscordOAuthClicked, @@ -130,6 +131,17 @@ const LoginForm = () => { return {}; } }, [statusState?.status]); + const hasCustomOAuthProviders = + (status.custom_oauth_providers || []).length > 0; + const hasOAuthLoginOptions = Boolean( + status.github_oauth || + status.discord_oauth || + status.oidc_enabled || + status.wechat_login || + status.linuxdo_oauth || + status.telegram_oauth || + hasCustomOAuthProviders, + ); useEffect(() => { if (status?.turnstile_check) { @@ -598,7 +610,7 @@ const LoginForm = () => { theme='outline' className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors' type='tertiary' - icon={} + icon={getOAuthProviderIcon(provider.icon || '', 20)} onClick={() => handleCustomOAuthClick(provider)} loading={customOAuthLoading[provider.slug]} > @@ -817,12 +829,7 @@ const LoginForm = () => { - {(status.github_oauth || - status.discord_oauth || - status.oidc_enabled || - status.wechat_login || - status.linuxdo_oauth || - status.telegram_oauth) && ( + {hasOAuthLoginOptions && ( <> {t('或')} @@ -952,14 +959,7 @@ const LoginForm = () => { />

{showEmailLogin || - !( - status.github_oauth || - status.discord_oauth || - status.oidc_enabled || - status.wechat_login || - status.linuxdo_oauth || - status.telegram_oauth - ) + !hasOAuthLoginOptions ? renderEmailLoginForm() : renderOAuthOptions()} {renderWeChatLoginModal()} diff --git a/web/src/components/auth/RegisterForm.jsx b/web/src/components/auth/RegisterForm.jsx index 2edc499b1..0a755b194 100644 --- a/web/src/components/auth/RegisterForm.jsx +++ b/web/src/components/auth/RegisterForm.jsx @@ -27,8 +27,10 @@ import { showSuccess, updateAPI, getSystemName, + getOAuthProviderIcon, setUserData, onDiscordOAuthClicked, + onCustomOAuthClicked, } from '../../helpers'; import Turnstile from 'react-turnstile'; import { @@ -98,6 +100,7 @@ const RegisterForm = () => { const [otherRegisterOptionsLoading, setOtherRegisterOptionsLoading] = useState(false); const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false); + const [customOAuthLoading, setCustomOAuthLoading] = useState({}); const [disableButton, setDisableButton] = useState(false); const [countdown, setCountdown] = useState(30); const [agreedToTerms, setAgreedToTerms] = useState(false); @@ -126,6 +129,17 @@ const RegisterForm = () => { return {}; } }, [statusState?.status]); + const hasCustomOAuthProviders = + (status.custom_oauth_providers || []).length > 0; + const hasOAuthRegisterOptions = Boolean( + status.github_oauth || + status.discord_oauth || + status.oidc_enabled || + status.wechat_login || + status.linuxdo_oauth || + status.telegram_oauth || + hasCustomOAuthProviders, + ); const [showEmailVerification, setShowEmailVerification] = useState(false); @@ -319,6 +333,17 @@ const RegisterForm = () => { } }; + const handleCustomOAuthClick = (provider) => { + setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: true })); + try { + onCustomOAuthClicked(provider, { shouldLogout: true }); + } finally { + setTimeout(() => { + setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: false })); + }, 3000); + } + }; + const handleEmailRegisterClick = () => { setEmailRegisterLoading(true); setShowEmailRegister(true); @@ -469,6 +494,23 @@ const RegisterForm = () => { )} + {status.custom_oauth_providers && + status.custom_oauth_providers.map((provider) => ( + + ))} + {status.telegram_oauth && (
{
- {(status.github_oauth || - status.discord_oauth || - status.oidc_enabled || - status.wechat_login || - status.linuxdo_oauth || - status.telegram_oauth) && ( + {hasOAuthRegisterOptions && ( <> {t('或')} @@ -745,14 +782,7 @@ const RegisterForm = () => { />
{showEmailRegister || - !( - status.github_oauth || - status.discord_oauth || - status.oidc_enabled || - status.wechat_login || - status.linuxdo_oauth || - status.telegram_oauth - ) + !hasOAuthRegisterOptions ? renderEmailRegisterForm() : renderOAuthOptions()} {renderWeChatLoginModal()} diff --git a/web/src/components/common/modals/RiskAcknowledgementModal.jsx b/web/src/components/common/modals/RiskAcknowledgementModal.jsx new file mode 100644 index 000000000..54aa62eaa --- /dev/null +++ b/web/src/components/common/modals/RiskAcknowledgementModal.jsx @@ -0,0 +1,214 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React, { useCallback, useEffect, useMemo, useState } from 'react'; +import { + Modal, + Button, + Typography, + Checkbox, + Input, + Space, +} from '@douyinfe/semi-ui'; +import { IconAlertTriangle } from '@douyinfe/semi-icons'; +import { useIsMobile } from '../../../hooks/common/useIsMobile'; +import MarkdownRenderer from '../markdown/MarkdownRenderer'; + +const { Text } = Typography; + +const RiskMarkdownBlock = React.memo(function RiskMarkdownBlock({ + markdownContent, +}) { + if (!markdownContent) { + return null; + } + + return ( +
+ +
+ ); +}); + +const RiskAcknowledgementModal = React.memo(function RiskAcknowledgementModal({ + visible, + title, + markdownContent = '', + detailTitle = '', + detailItems = [], + checklist = [], + inputPrompt = '', + requiredText = '', + inputPlaceholder = '', + mismatchText = '', + cancelText = '', + confirmText = '', + onCancel, + onConfirm, +}) { + const isMobile = useIsMobile(); + const [checkedItems, setCheckedItems] = useState([]); + const [typedText, setTypedText] = useState(''); + + useEffect(() => { + if (!visible) return; + setCheckedItems(Array(checklist.length).fill(false)); + setTypedText(''); + }, [visible, checklist.length]); + + const allChecked = useMemo(() => { + if (checklist.length === 0) return true; + return checkedItems.length === checklist.length && checkedItems.every(Boolean); + }, [checkedItems, checklist.length]); + + const typedMatched = useMemo(() => { + if (!requiredText) return true; + return typedText.trim() === requiredText.trim(); + }, [typedText, requiredText]); + + const detailText = useMemo(() => detailItems.join(', '), [detailItems]); + const canConfirm = allChecked && typedMatched; + + const handleChecklistChange = useCallback((index, checked) => { + setCheckedItems((previous) => { + const next = [...previous]; + next[index] = checked; + return next; + }); + }, []); + + return ( + + + {title} + + } + width={isMobile ? '100%' : 860} + centered + maskClosable={false} + closeOnEsc={false} + onCancel={onCancel} + bodyStyle={{ + maxHeight: isMobile ? '70vh' : '72vh', + overflowY: 'auto', + padding: isMobile ? '12px 16px' : '18px 22px', + }} + footer={ + + + + + } + > +
+ + + + {detailItems.length > 0 ? ( +
+ {detailTitle ? {detailTitle} : null} +
+ {detailText} +
+
+ ) : null} + + {checklist.length > 0 ? ( +
+ {checklist.map((item, index) => ( + { + handleChecklistChange(index, event.target.checked); + }} + > + {item} + + ))} +
+ ) : null} + + {requiredText ? ( +
+ {inputPrompt ? {inputPrompt} : null} +
+ {requiredText} +
+ event.preventDefault()} + onCut={(event) => event.preventDefault()} + onPaste={(event) => event.preventDefault()} + onDrop={(event) => event.preventDefault()} + /> + {!typedMatched && typedText ? ( + + {mismatchText} + + ) : null} +
+ ) : null} +
+
+ ); +}); + +export default RiskAcknowledgementModal; diff --git a/web/src/components/settings/CustomOAuthSetting.jsx b/web/src/components/settings/CustomOAuthSetting.jsx index 4b6df4c81..0912160be 100644 --- a/web/src/components/settings/CustomOAuthSetting.jsx +++ b/web/src/components/settings/CustomOAuthSetting.jsx @@ -27,14 +27,20 @@ import { Modal, Banner, Card, + Collapse, + Switch, Table, Tag, Popconfirm, Space, - Select, } from '@douyinfe/semi-ui'; -import { IconPlus, IconEdit, IconDelete } from '@douyinfe/semi-icons'; -import { API, showError, showSuccess } from '../../helpers'; +import { + IconPlus, + IconEdit, + IconDelete, + IconRefresh, +} from '@douyinfe/semi-icons'; +import { API, showError, showSuccess, getOAuthProviderIcon } from '../../helpers'; import { useTranslation } from 'react-i18next'; const { Text } = Typography; @@ -120,6 +126,69 @@ const OAUTH_PRESETS = { }, }; +const OAUTH_PRESET_ICONS = { + 'github-enterprise': 'github', + gitlab: 'gitlab', + gitea: 'gitea', + nextcloud: 'nextcloud', + keycloak: 'keycloak', + authentik: 'authentik', + ory: 'openid', +}; + +const getPresetIcon = (preset) => OAUTH_PRESET_ICONS[preset] || ''; + +const PRESET_RESET_VALUES = { + name: '', + slug: '', + icon: '', + authorization_endpoint: '', + token_endpoint: '', + user_info_endpoint: '', + scopes: '', + user_id_field: '', + username_field: '', + display_name_field: '', + email_field: '', + well_known: '', + auth_style: 0, + access_policy: '', + access_denied_message: '', +}; + +const DISCOVERY_FIELD_LABELS = { + authorization_endpoint: 'Authorization Endpoint', + token_endpoint: 'Token Endpoint', + user_info_endpoint: 'User Info Endpoint', + scopes: 'Scopes', + user_id_field: 'User ID Field', + username_field: 'Username Field', + display_name_field: 'Display Name Field', + email_field: 'Email Field', +}; + +const ACCESS_POLICY_TEMPLATES = { + level_active: `{ + "logic": "and", + "conditions": [ + {"field": "trust_level", "op": "gte", "value": 2}, + {"field": "active", "op": "eq", "value": true} + ] +}`, + org_or_role: `{ + "logic": "or", + "conditions": [ + {"field": "org", "op": "eq", "value": "core"}, + {"field": "roles", "op": "contains", "value": "admin"} + ] +}`, +}; + +const ACCESS_DENIED_TEMPLATES = { + level_hint: '需要等级 {{required}},你当前等级 {{current}}(字段:{{field}})', + org_hint: '仅限指定组织或角色访问。组织={{current.org}},角色={{current.roles}}', +}; + const CustomOAuthSetting = ({ serverAddress }) => { const { t } = useTranslation(); const [providers, setProviders] = useState([]); @@ -129,8 +198,47 @@ const CustomOAuthSetting = ({ serverAddress }) => { const [formValues, setFormValues] = useState({}); const [selectedPreset, setSelectedPreset] = useState(''); const [baseUrl, setBaseUrl] = useState(''); + const [discoveryLoading, setDiscoveryLoading] = useState(false); + const [discoveryInfo, setDiscoveryInfo] = useState(null); + const [advancedActiveKeys, setAdvancedActiveKeys] = useState([]); const formApiRef = React.useRef(null); + const mergeFormValues = (newValues) => { + setFormValues((prev) => ({ ...prev, ...newValues })); + if (!formApiRef.current) return; + Object.entries(newValues).forEach(([key, value]) => { + formApiRef.current.setValue(key, value); + }); + }; + + const getLatestFormValues = () => { + const values = formApiRef.current?.getValues?.(); + return values && typeof values === 'object' ? values : formValues; + }; + + const normalizeBaseUrl = (url) => (url || '').trim().replace(/\/+$/, ''); + + const inferBaseUrlFromProvider = (provider) => { + const endpoint = provider?.authorization_endpoint || provider?.token_endpoint; + if (!endpoint) return ''; + try { + const url = new URL(endpoint); + return `${url.protocol}//${url.host}`; + } catch (error) { + return ''; + } + }; + + const resetDiscoveryState = () => { + setDiscoveryInfo(null); + }; + + const closeModal = () => { + setModalVisible(false); + resetDiscoveryState(); + setAdvancedActiveKeys([]); + }; + const fetchProviders = async () => { setLoading(true); try { @@ -154,23 +262,30 @@ const CustomOAuthSetting = ({ serverAddress }) => { setEditingProvider(null); setFormValues({ enabled: false, + icon: '', scopes: 'openid profile email', user_id_field: 'sub', username_field: 'preferred_username', display_name_field: 'name', email_field: 'email', auth_style: 0, + access_policy: '', + access_denied_message: '', }); setSelectedPreset(''); setBaseUrl(''); + resetDiscoveryState(); + setAdvancedActiveKeys([]); setModalVisible(true); }; const handleEdit = (provider) => { setEditingProvider(provider); setFormValues({ ...provider }); - setSelectedPreset(''); - setBaseUrl(''); + setSelectedPreset(OAUTH_PRESETS[provider.slug] ? provider.slug : ''); + setBaseUrl(inferBaseUrlFromProvider(provider)); + resetDiscoveryState(); + setAdvancedActiveKeys([]); setModalVisible(true); }; @@ -189,6 +304,8 @@ const CustomOAuthSetting = ({ serverAddress }) => { }; const handleSubmit = async () => { + const currentValues = getLatestFormValues(); + // Validate required fields const requiredFields = [ 'name', @@ -204,7 +321,7 @@ const CustomOAuthSetting = ({ serverAddress }) => { } for (const field of requiredFields) { - if (!formValues[field]) { + if (!currentValues[field]) { showError(t(`请填写 ${field}`)); return; } @@ -213,11 +330,11 @@ const CustomOAuthSetting = ({ serverAddress }) => { // Validate endpoint URLs must be full URLs const endpointFields = ['authorization_endpoint', 'token_endpoint', 'user_info_endpoint']; for (const field of endpointFields) { - const value = formValues[field]; + const value = currentValues[field]; if (value && !value.startsWith('http://') && !value.startsWith('https://')) { - // Check if user selected a preset but forgot to fill server address + // Check if user selected a preset but forgot to fill issuer URL if (selectedPreset && !baseUrl) { - showError(t('请先填写服务器地址,以自动生成完整的端点 URL')); + showError(t('请先填写 Issuer URL,以自动生成完整的端点 URL')); } else { showError(t('端点 URL 必须是完整地址(以 http:// 或 https:// 开头)')); } @@ -226,80 +343,199 @@ const CustomOAuthSetting = ({ serverAddress }) => { } try { + const payload = { ...currentValues, enabled: !!currentValues.enabled }; + delete payload.preset; + delete payload.base_url; + let res; if (editingProvider) { res = await API.put( `/api/custom-oauth-provider/${editingProvider.id}`, - formValues + payload ); } else { - res = await API.post('/api/custom-oauth-provider/', formValues); + res = await API.post('/api/custom-oauth-provider/', payload); } if (res.data.success) { showSuccess(editingProvider ? t('更新成功') : t('创建成功')); - setModalVisible(false); + closeModal(); fetchProviders(); } else { showError(res.data.message); } } catch (error) { - showError(editingProvider ? t('更新失败') : t('创建失败')); + showError( + error?.response?.data?.message || + (editingProvider ? t('更新失败') : t('创建失败')), + ); + } + }; + + const handleFetchFromDiscovery = async () => { + const cleanBaseUrl = normalizeBaseUrl(baseUrl); + const configuredWellKnown = (formValues.well_known || '').trim(); + const wellKnownUrl = + configuredWellKnown || + (cleanBaseUrl ? `${cleanBaseUrl}/.well-known/openid-configuration` : ''); + + if (!wellKnownUrl) { + showError(t('请先填写 Discovery URL 或 Issuer URL')); + return; + } + + setDiscoveryLoading(true); + try { + const res = await API.post('/api/custom-oauth-provider/discovery', { + well_known_url: configuredWellKnown || '', + issuer_url: cleanBaseUrl || '', + }); + if (!res.data.success) { + throw new Error(res.data.message || t('未知错误')); + } + const data = res.data.data?.discovery || {}; + const resolvedWellKnown = res.data.data?.well_known_url || wellKnownUrl; + + const discoveredValues = { + well_known: resolvedWellKnown, + }; + const autoFilledFields = []; + if (data.authorization_endpoint) { + discoveredValues.authorization_endpoint = data.authorization_endpoint; + autoFilledFields.push('authorization_endpoint'); + } + if (data.token_endpoint) { + discoveredValues.token_endpoint = data.token_endpoint; + autoFilledFields.push('token_endpoint'); + } + if (data.userinfo_endpoint) { + discoveredValues.user_info_endpoint = data.userinfo_endpoint; + autoFilledFields.push('user_info_endpoint'); + } + + const scopesSupported = Array.isArray(data.scopes_supported) + ? data.scopes_supported + : []; + if (scopesSupported.length > 0 && !formValues.scopes) { + const preferredScopes = ['openid', 'profile', 'email'].filter((scope) => + scopesSupported.includes(scope), + ); + discoveredValues.scopes = + preferredScopes.length > 0 + ? preferredScopes.join(' ') + : scopesSupported.slice(0, 5).join(' '); + autoFilledFields.push('scopes'); + } + + const claimsSupported = Array.isArray(data.claims_supported) + ? data.claims_supported + : []; + const claimMap = { + user_id_field: 'sub', + username_field: 'preferred_username', + display_name_field: 'name', + email_field: 'email', + }; + Object.entries(claimMap).forEach(([field, claim]) => { + if (!formValues[field] && claimsSupported.includes(claim)) { + discoveredValues[field] = claim; + autoFilledFields.push(field); + } + }); + + const hasCoreEndpoint = + discoveredValues.authorization_endpoint || + discoveredValues.token_endpoint || + discoveredValues.user_info_endpoint; + if (!hasCoreEndpoint) { + showError(t('未在 Discovery 响应中找到可用的 OAuth 端点')); + return; + } + + mergeFormValues(discoveredValues); + setDiscoveryInfo({ + wellKnown: wellKnownUrl, + autoFilledFields, + scopesSupported: scopesSupported.slice(0, 12), + claimsSupported: claimsSupported.slice(0, 12), + }); + showSuccess(t('已从 Discovery 自动填充配置')); + } catch (error) { + showError( + t('获取 Discovery 配置失败:') + (error?.message || t('未知错误')), + ); + } finally { + setDiscoveryLoading(false); } }; const handlePresetChange = (preset) => { setSelectedPreset(preset); - if (preset && OAUTH_PRESETS[preset]) { - const presetConfig = OAUTH_PRESETS[preset]; - const cleanUrl = baseUrl ? baseUrl.replace(/\/+$/, '') : ''; - const newValues = { - name: presetConfig.name, - slug: preset, - scopes: presetConfig.scopes, - user_id_field: presetConfig.user_id_field, - username_field: presetConfig.username_field, - display_name_field: presetConfig.display_name_field, - email_field: presetConfig.email_field, - auth_style: presetConfig.auth_style ?? 0, - }; - // Only fill endpoints if server address is provided - if (cleanUrl) { - newValues.authorization_endpoint = cleanUrl + presetConfig.authorization_endpoint; - newValues.token_endpoint = cleanUrl + presetConfig.token_endpoint; - newValues.user_info_endpoint = cleanUrl + presetConfig.user_info_endpoint; - } - setFormValues((prev) => ({ ...prev, ...newValues })); - // Update form fields directly via formApi - if (formApiRef.current) { - Object.entries(newValues).forEach(([key, value]) => { - formApiRef.current.setValue(key, value); - }); - } + resetDiscoveryState(); + const cleanUrl = normalizeBaseUrl(baseUrl); + if (!preset || !OAUTH_PRESETS[preset]) { + mergeFormValues(PRESET_RESET_VALUES); + return; } + + const presetConfig = OAUTH_PRESETS[preset]; + const newValues = { + ...PRESET_RESET_VALUES, + name: presetConfig.name, + slug: preset, + icon: getPresetIcon(preset), + scopes: presetConfig.scopes, + user_id_field: presetConfig.user_id_field, + username_field: presetConfig.username_field, + display_name_field: presetConfig.display_name_field, + email_field: presetConfig.email_field, + auth_style: presetConfig.auth_style ?? 0, + }; + if (cleanUrl) { + newValues.authorization_endpoint = + cleanUrl + presetConfig.authorization_endpoint; + newValues.token_endpoint = cleanUrl + presetConfig.token_endpoint; + newValues.user_info_endpoint = cleanUrl + presetConfig.user_info_endpoint; + } + mergeFormValues(newValues); }; const handleBaseUrlChange = (url) => { setBaseUrl(url); if (url && selectedPreset && OAUTH_PRESETS[selectedPreset]) { const presetConfig = OAUTH_PRESETS[selectedPreset]; - const cleanUrl = url.replace(/\/+$/, ''); // Remove trailing slashes + const cleanUrl = normalizeBaseUrl(url); const newValues = { authorization_endpoint: cleanUrl + presetConfig.authorization_endpoint, token_endpoint: cleanUrl + presetConfig.token_endpoint, user_info_endpoint: cleanUrl + presetConfig.user_info_endpoint, }; - setFormValues((prev) => ({ ...prev, ...newValues })); - // Update form fields directly via formApi (use merge mode to preserve other fields) - if (formApiRef.current) { - Object.entries(newValues).forEach(([key, value]) => { - formApiRef.current.setValue(key, value); - }); - } + mergeFormValues(newValues); } }; + const applyAccessPolicyTemplate = (templateKey) => { + const template = ACCESS_POLICY_TEMPLATES[templateKey]; + if (!template) return; + mergeFormValues({ access_policy: template }); + showSuccess(t('已填充策略模板')); + }; + + const applyDeniedTemplate = (templateKey) => { + const template = ACCESS_DENIED_TEMPLATES[templateKey]; + if (!template) return; + mergeFormValues({ access_denied_message: template }); + showSuccess(t('已填充提示模板')); + }; + const columns = [ + { + title: t('图标'), + dataIndex: 'icon', + key: 'icon', + width: 80, + render: (icon) => getOAuthProviderIcon(icon || '', 18), + }, { title: t('名称'), dataIndex: 'name', @@ -325,7 +561,10 @@ const CustomOAuthSetting = ({ serverAddress }) => { title: t('Client ID'), dataIndex: 'client_id', key: 'client_id', - render: (id) => (id ? id.substring(0, 20) + '...' : '-'), + render: (id) => { + if (!id) return '-'; + return id.length > 20 ? `${id.substring(0, 20)}...` : id; + }, }, { title: t('操作'), @@ -352,6 +591,10 @@ const CustomOAuthSetting = ({ serverAddress }) => { }, ]; + const discoveryAutoFilledLabels = (discoveryInfo?.autoFilledFields || []) + .map((field) => DISCOVERY_FIELD_LABELS[field] || field) + .join(', '); + return ( @@ -391,56 +634,142 @@ const CustomOAuthSetting = ({ serverAddress }) => { setModalVisible(false)} - okText={t('保存')} - cancelText={t('取消')} - width={800} + onCancel={closeModal} + width={860} + centered + bodyStyle={{ maxHeight: '72vh', overflowY: 'auto', paddingRight: 6 }} + footer={ +
+ + {t('启用供应商')} + mergeFormValues({ enabled: !!checked })} + /> + + {formValues.enabled ? t('已启用') : t('已禁用')} + + + + +
+ } >
setFormValues(values)} + onValueChange={() => { + setFormValues((prev) => ({ ...prev, ...getLatestFormValues() })); + }} getFormApi={(api) => (formApiRef.current = api)} > - {!editingProvider && ( - - - ({ - value: key, - label: config.name, - })), - ]} - /> - - - - - + + {t('Configuration')} + + + {t('先填写配置,再自动填充 OAuth 端点,能显著减少手工输入')} + + {discoveryInfo && ( + +
+ {t('已从 Discovery 获取配置,可继续手动修改所有字段。')} +
+ {discoveryAutoFilledLabels ? ( +
+ {t('自动填充字段')}: + {' '} + {discoveryAutoFilledLabels} +
+ ) : null} + {discoveryInfo.scopesSupported?.length ? ( +
+ {t('Discovery scopes')}: + {' '} + {discoveryInfo.scopesSupported.join(', ')} +
+ ) : null} + {discoveryInfo.claimsSupported?.length ? ( +
+ {t('Discovery claims')}: + {' '} + {discoveryInfo.claimsSupported.join(', ')} +
+ ) : null} +
+ } + /> )} + + + ({ + value: key, + label: config.name, + })), + ]} + /> + + + + + +
+ +
+ +
+ + + + + + { + + + + {t( + '图标使用 react-icons(Simple Icons)或 URL/emoji,例如:github、gitlab、si:google', + )} + + } + showClear + /> + + +
+ {getOAuthProviderIcon(formValues.icon || '', 24)} +
+ +
+ { label={t('Authorization Endpoint')} placeholder={ selectedPreset && OAUTH_PRESETS[selectedPreset] - ? t('填写服务器地址后自动生成:') + + ? t('填写 Issuer URL 后自动生成:') + OAUTH_PRESETS[selectedPreset].authorization_endpoint : 'https://example.com/oauth/authorize' } @@ -544,15 +908,14 @@ const CustomOAuthSetting = ({ serverAddress }) => { - - - @@ -568,7 +931,7 @@ const CustomOAuthSetting = ({ serverAddress }) => { @@ -576,7 +939,7 @@ const CustomOAuthSetting = ({ serverAddress }) => { @@ -586,41 +949,100 @@ const CustomOAuthSetting = ({ serverAddress }) => { - - {t('高级选项')} - + { + const keys = Array.isArray(activeKey) ? activeKey : [activeKey]; + setAdvancedActiveKeys(keys.filter(Boolean)); + }} + > + + + + + + - - - - - - - {t('启用此 OAuth 提供商')} - - - + + {t('准入策略')} + + + {t('可选:基于用户信息 JSON 做组合条件准入,条件不满足时返回自定义提示')} + + + + mergeFormValues({ access_policy: value })} + label={t('准入策略 JSON(可选)')} + rows={6} + placeholder={`{ + "logic": "and", + "conditions": [ + {"field": "trust_level", "op": "gte", "value": 2}, + {"field": "active", "op": "eq", "value": true} + ] +}`} + extraText={t('支持逻辑 and/or 与嵌套 groups;操作符支持 eq/ne/gt/gte/lt/lte/in/not_in/contains/exists')} + showClear + /> + + + + + + + + + mergeFormValues({ access_denied_message: value })} + label={t('拒绝提示模板(可选)')} + placeholder={t('例如:需要等级 {{required}},你当前等级 {{current}}')} + extraText={t('可用变量:{{provider}} {{field}} {{op}} {{required}} {{current}} 以及 {{current.path}}')} + showClear + /> + + + + + + + + diff --git a/web/src/components/settings/personal/cards/AccountManagement.jsx b/web/src/components/settings/personal/cards/AccountManagement.jsx index bc27630ba..29249caa1 100644 --- a/web/src/components/settings/personal/cards/AccountManagement.jsx +++ b/web/src/components/settings/personal/cards/AccountManagement.jsx @@ -50,6 +50,7 @@ import { onLinuxDOOAuthClicked, onDiscordOAuthClicked, onCustomOAuthClicked, + getOAuthProviderIcon, } from '../../../../helpers'; import TwoFASetting from '../components/TwoFASetting'; @@ -148,12 +149,14 @@ const AccountManagement = ({ // Check if custom OAuth provider is bound const isCustomOAuthBound = (providerId) => { - return customOAuthBindings.some((b) => b.provider_id === providerId); + const normalizedId = Number(providerId); + return customOAuthBindings.some((b) => Number(b.provider_id) === normalizedId); }; // Get binding info for a provider const getCustomOAuthBinding = (providerId) => { - return customOAuthBindings.find((b) => b.provider_id === providerId); + const normalizedId = Number(providerId); + return customOAuthBindings.find((b) => Number(b.provider_id) === normalizedId); }; React.useEffect(() => { @@ -524,10 +527,10 @@ const AccountManagement = ({
- + {getOAuthProviderIcon( + provider.icon || binding?.provider_icon || '', + 20, + )}
diff --git a/web/src/components/settings/personal/cards/NotificationSettings.jsx b/web/src/components/settings/personal/cards/NotificationSettings.jsx index 964a730e4..e57e39d63 100644 --- a/web/src/components/settings/personal/cards/NotificationSettings.jsx +++ b/web/src/components/settings/personal/cards/NotificationSettings.jsx @@ -86,6 +86,7 @@ const NotificationSettings = ({ channel: true, models: true, deployment: true, + subscription: true, redemption: true, user: true, setting: true, @@ -169,6 +170,7 @@ const NotificationSettings = ({ channel: true, models: true, deployment: true, + subscription: true, redemption: true, user: true, setting: true, @@ -296,6 +298,11 @@ const NotificationSettings = ({ title: t('模型部署'), description: t('模型部署管理'), }, + { + key: 'subscription', + title: t('订阅管理'), + description: t('订阅套餐管理'), + }, { key: 'redemption', title: t('兑换码管理'), diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index f54b6c41a..3a91207dc 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -62,9 +62,14 @@ import CodexOAuthModal from './CodexOAuthModal'; import ParamOverrideEditorModal from './ParamOverrideEditorModal'; import JSONEditor from '../../../common/ui/JSONEditor'; import SecureVerificationModal from '../../../common/modals/SecureVerificationModal'; +import StatusCodeRiskGuardModal from './StatusCodeRiskGuardModal'; import ChannelKeyDisplay from '../../../common/ui/ChannelKeyDisplay'; import { useSecureVerification } from '../../../../hooks/common/useSecureVerification'; import { createApiCalls } from '../../../../services/secureVerification'; +import { + collectInvalidStatusCodeEntries, + collectNewDisallowedStatusCodeRedirects, +} from './statusCodeRiskGuard'; import { IconSave, IconClose, @@ -195,6 +200,8 @@ const EditChannelModal = (props) => { allow_service_tier: false, disable_store: false, // false = 允许透传(默认开启) allow_safety_identifier: false, + allow_include_obfuscation: false, + allow_inference_geo: false, claude_beta_query: false, }; const [batch, setBatch] = useState(false); @@ -209,6 +216,7 @@ const EditChannelModal = (props) => { const [fullModels, setFullModels] = useState([]); const [modelGroups, setModelGroups] = useState([]); const [customModel, setCustomModel] = useState(''); + const [modelSearchValue, setModelSearchValue] = useState(''); const [modalImageUrl, setModalImageUrl] = useState(''); const [isModalOpenurl, setIsModalOpenurl] = useState(false); const [modelModalVisible, setModelModalVisible] = useState(false); @@ -249,6 +257,25 @@ const EditChannelModal = (props) => { return []; } }, [inputs.model_mapping]); + const modelSearchMatchedCount = useMemo(() => { + const keyword = modelSearchValue.trim(); + if (!keyword) { + return modelOptions.length; + } + return modelOptions.reduce( + (count, option) => count + (selectFilter(keyword, option) ? 1 : 0), + 0, + ); + }, [modelOptions, modelSearchValue]); + const modelSearchHintText = useMemo(() => { + const keyword = modelSearchValue.trim(); + if (!keyword || modelSearchMatchedCount !== 0) { + return ''; + } + return t('未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加', { + name: keyword, + }); + }, [modelSearchMatchedCount, modelSearchValue, t]); const paramOverrideMeta = useMemo(() => { const raw = typeof inputs.param_override === 'string' @@ -338,6 +365,12 @@ const EditChannelModal = (props) => { window.open(targetUrl, '_blank', 'noopener'); }; const [verifyLoading, setVerifyLoading] = useState(false); + const statusCodeRiskConfirmResolverRef = useRef(null); + const [statusCodeRiskConfirmVisible, setStatusCodeRiskConfirmVisible] = + useState(false); + const [statusCodeRiskDetailItems, setStatusCodeRiskDetailItems] = useState( + [], + ); // 表单块导航相关状态 const formSectionRefs = useRef({ @@ -359,6 +392,7 @@ const EditChannelModal = (props) => { const doubaoApiClickCountRef = useRef(0); const initialModelsRef = useRef([]); const initialModelMappingRef = useRef(''); + const initialStatusCodeMappingRef = useRef(''); // 2FA状态更新辅助函数 const updateTwoFAState = (updates) => { @@ -811,6 +845,10 @@ const EditChannelModal = (props) => { data.disable_store = parsedSettings.disable_store || false; data.allow_safety_identifier = parsedSettings.allow_safety_identifier || false; + data.allow_include_obfuscation = + parsedSettings.allow_include_obfuscation || false; + data.allow_inference_geo = + parsedSettings.allow_inference_geo || false; data.claude_beta_query = parsedSettings.claude_beta_query || false; } catch (error) { console.error('解析其他设置失败:', error); @@ -822,6 +860,8 @@ const EditChannelModal = (props) => { data.allow_service_tier = false; data.disable_store = false; data.allow_safety_identifier = false; + data.allow_include_obfuscation = false; + data.allow_inference_geo = false; data.claude_beta_query = false; } } else { @@ -832,6 +872,8 @@ const EditChannelModal = (props) => { data.allow_service_tier = false; data.disable_store = false; data.allow_safety_identifier = false; + data.allow_include_obfuscation = false; + data.allow_inference_geo = false; data.claude_beta_query = false; } @@ -868,6 +910,7 @@ const EditChannelModal = (props) => { .map((model) => (model || '').trim()) .filter(Boolean); initialModelMappingRef.current = data.model_mapping || ''; + initialStatusCodeMappingRef.current = data.status_code_mapping || ''; let parsedIonet = null; if (data.other_info) { @@ -1173,6 +1216,7 @@ const EditChannelModal = (props) => { }, [inputs]); useEffect(() => { + setModelSearchValue(''); if (props.visible) { if (isEdit) { loadChannel(); @@ -1194,11 +1238,22 @@ const EditChannelModal = (props) => { if (!isEdit) { initialModelsRef.current = []; initialModelMappingRef.current = ''; + initialStatusCodeMappingRef.current = ''; } }, [isEdit, props.visible]); + useEffect(() => { + return () => { + if (statusCodeRiskConfirmResolverRef.current) { + statusCodeRiskConfirmResolverRef.current(false); + statusCodeRiskConfirmResolverRef.current = null; + } + }; + }, []); + // 统一的模态框重置函数 const resetModalState = () => { + resolveStatusCodeRiskConfirm(false); formApiRef.current?.reset(); // 重置渠道设置状态 setChannelSettings({ @@ -1216,6 +1271,7 @@ const EditChannelModal = (props) => { // 重置豆包隐藏入口状态 setDoubaoApiEditUnlocked(false); doubaoApiClickCountRef.current = 0; + setModelSearchValue(''); // 清空表单中的key_mode字段 if (formApiRef.current) { formApiRef.current.setValue('key_mode', undefined); @@ -1328,6 +1384,22 @@ const EditChannelModal = (props) => { }); }); + const resolveStatusCodeRiskConfirm = (confirmed) => { + setStatusCodeRiskConfirmVisible(false); + setStatusCodeRiskDetailItems([]); + if (statusCodeRiskConfirmResolverRef.current) { + statusCodeRiskConfirmResolverRef.current(confirmed); + statusCodeRiskConfirmResolverRef.current = null; + } + }; + + const confirmStatusCodeRisk = (detailItems) => + new Promise((resolve) => { + statusCodeRiskConfirmResolverRef.current = resolve; + setStatusCodeRiskDetailItems(detailItems); + setStatusCodeRiskConfirmVisible(true); + }); + const hasModelConfigChanged = (normalizedModels, modelMappingStr) => { if (!isEdit) return true; const initialModels = initialModelsRef.current; @@ -1518,6 +1590,27 @@ const EditChannelModal = (props) => { } } + const invalidStatusCodeEntries = collectInvalidStatusCodeEntries( + localInputs.status_code_mapping, + ); + if (invalidStatusCodeEntries.length > 0) { + showError( + `${t('状态码复写包含无效的状态码')}: ${invalidStatusCodeEntries.join(', ')}`, + ); + return; + } + + const riskyStatusCodeRedirects = collectNewDisallowedStatusCodeRedirects( + initialStatusCodeMappingRef.current, + localInputs.status_code_mapping, + ); + if (riskyStatusCodeRedirects.length > 0) { + const confirmed = await confirmStatusCodeRisk(riskyStatusCodeRedirects); + if (!confirmed) { + return; + } + } + if (localInputs.base_url && localInputs.base_url.endsWith('/')) { localInputs.base_url = localInputs.base_url.slice( 0, @@ -1570,13 +1663,16 @@ const EditChannelModal = (props) => { // type === 1 (OpenAI) 或 type === 14 (Claude): 设置字段透传控制(显式保存布尔值) if (localInputs.type === 1 || localInputs.type === 14) { settings.allow_service_tier = localInputs.allow_service_tier === true; - // 仅 OpenAI 渠道需要 store 和 safety_identifier + // 仅 OpenAI 渠道需要 store / safety_identifier / include_obfuscation if (localInputs.type === 1) { settings.disable_store = localInputs.disable_store === true; settings.allow_safety_identifier = localInputs.allow_safety_identifier === true; + settings.allow_include_obfuscation = + localInputs.allow_include_obfuscation === true; } if (localInputs.type === 14) { + settings.allow_inference_geo = localInputs.allow_inference_geo === true; settings.claude_beta_query = localInputs.claude_beta_query === true; } } @@ -1599,6 +1695,8 @@ const EditChannelModal = (props) => { delete localInputs.allow_service_tier; delete localInputs.disable_store; delete localInputs.allow_safety_identifier; + delete localInputs.allow_include_obfuscation; + delete localInputs.allow_inference_geo; delete localInputs.claude_beta_query; let res; @@ -2917,9 +3015,18 @@ const EditChannelModal = (props) => { rules={[{ required: true, message: t('请选择模型') }]} multiple filter={selectFilter} + allowCreate autoClearSearchValue={false} searchPosition='dropdown' optionList={modelOptions} + onSearch={(value) => setModelSearchValue(value)} + innerBottomSlot={ + modelSearchHintText ? ( + + {modelSearchHintText} + + ) : null + } style={{ width: '100%' }} onChange={(value) => handleInputChange('models', value)} renderSelectedItem={(optionNode) => { @@ -3444,6 +3551,24 @@ const EditChannelModal = (props) => { 'safety_identifier 字段用于帮助 OpenAI 识别可能违反使用政策的应用程序用户。默认关闭以保护用户隐私', )} /> + + + handleChannelOtherSettingsChange( + 'allow_include_obfuscation', + value, + ) + } + extraText={t( + 'include_obfuscation 用于控制 Responses 流混淆字段。默认关闭以避免客户端关闭该安全保护', + )} + /> )} @@ -3469,6 +3594,22 @@ const EditChannelModal = (props) => { 'service_tier 字段用于指定服务层级,允许透传可能导致实际计费高于预期。默认关闭以避免额外费用', )} /> + + + handleChannelOtherSettingsChange( + 'allow_inference_geo', + value, + ) + } + extraText={t( + 'inference_geo 字段用于控制 Claude 数据驻留推理区域。默认关闭以避免未经授权透传地域信息', + )} + /> )} @@ -3613,6 +3754,12 @@ const EditChannelModal = (props) => { onVisibleChange={(visible) => setIsModalOpenurl(visible)} /> + resolveStatusCodeRiskConfirm(false)} + onConfirm={() => resolveStatusCodeRiskConfirm(true)} + /> {/* 使用通用安全验证模态框 */} . For commercial licensing, please contact support@quantumnous.com */ -import React, { useState, useEffect, useRef } from 'react'; +import React, { useState, useEffect, useRef, useMemo } from 'react'; import { API, showError, @@ -64,6 +64,7 @@ const EditTagModal = (props) => { const [modelOptions, setModelOptions] = useState([]); const [groupOptions, setGroupOptions] = useState([]); const [customModel, setCustomModel] = useState(''); + const [modelSearchValue, setModelSearchValue] = useState(''); const originInputs = { tag: '', new_tag: null, @@ -74,6 +75,25 @@ const EditTagModal = (props) => { header_override: null, }; const [inputs, setInputs] = useState(originInputs); + const modelSearchMatchedCount = useMemo(() => { + const keyword = modelSearchValue.trim(); + if (!keyword) { + return modelOptions.length; + } + return modelOptions.reduce( + (count, option) => count + (selectFilter(keyword, option) ? 1 : 0), + 0, + ); + }, [modelOptions, modelSearchValue]); + const modelSearchHintText = useMemo(() => { + const keyword = modelSearchValue.trim(); + if (!keyword || modelSearchMatchedCount !== 0) { + return ''; + } + return t('未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加', { + name: keyword, + }); + }, [modelSearchMatchedCount, modelSearchValue, t]); const formApiRef = useRef(null); const getInitValues = () => ({ ...originInputs }); @@ -292,6 +312,7 @@ const EditTagModal = (props) => { fetchModels().then(); fetchGroups().then(); fetchTagModels().then(); + setModelSearchValue(''); if (formApiRef.current) { formApiRef.current.setValues({ ...getInitValues(), @@ -461,9 +482,18 @@ const EditTagModal = (props) => { placeholder={t('请选择该渠道所支持的模型,留空则不更改')} multiple filter={selectFilter} + allowCreate autoClearSearchValue={false} searchPosition='dropdown' optionList={modelOptions} + onSearch={(value) => setModelSearchValue(value)} + innerBottomSlot={ + modelSearchHintText ? ( + + {modelSearchHintText} + + ) : null + } style={{ width: '100%' }} onChange={(value) => handleInputChange('models', value)} /> diff --git a/web/src/components/table/channels/modals/StatusCodeRiskGuardModal.jsx b/web/src/components/table/channels/modals/StatusCodeRiskGuardModal.jsx new file mode 100644 index 000000000..ba3f46f59 --- /dev/null +++ b/web/src/components/table/channels/modals/StatusCodeRiskGuardModal.jsx @@ -0,0 +1,41 @@ +import React, { useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import RiskAcknowledgementModal from '../../../common/modals/RiskAcknowledgementModal'; +import { + STATUS_CODE_RISK_I18N_KEYS, + STATUS_CODE_RISK_CHECKLIST_KEYS, +} from './statusCodeRiskGuard'; + +const StatusCodeRiskGuardModal = React.memo(function StatusCodeRiskGuardModal({ + visible, + detailItems, + onCancel, + onConfirm, +}) { + const { t, i18n } = useTranslation(); + const checklist = useMemo( + () => STATUS_CODE_RISK_CHECKLIST_KEYS.map((item) => t(item)), + [t, i18n.language], + ); + + return ( + + ); +}); + +export default StatusCodeRiskGuardModal; diff --git a/web/src/components/table/channels/modals/statusCodeRiskGuard.js b/web/src/components/table/channels/modals/statusCodeRiskGuard.js new file mode 100644 index 000000000..169736baa --- /dev/null +++ b/web/src/components/table/channels/modals/statusCodeRiskGuard.js @@ -0,0 +1,132 @@ +const NON_REDIRECTABLE_STATUS_CODES = new Set([504, 524]); + +export const STATUS_CODE_RISK_I18N_KEYS = { + title: '高危操作确认', + detailTitle: '检测到以下高危状态码重定向规则', + inputPrompt: '操作确认', + confirmButton: '我确认开启高危重试', + markdown: '高危状态码重试风险告知与免责声明Markdown', + confirmText: '高危状态码重试风险确认输入文本', + inputPlaceholder: '高危状态码重试风险输入框占位文案', + mismatchText: '高危状态码重试风险输入不匹配提示', +}; + +export const STATUS_CODE_RISK_CHECKLIST_KEYS = [ + '高危状态码重试风险确认项1', + '高危状态码重试风险确认项2', + '高危状态码重试风险确认项3', + '高危状态码重试风险确认项4', +]; + +function parseStatusCodeKey(rawKey) { + if (typeof rawKey !== 'string') { + return null; + } + const normalized = rawKey.trim(); + if (!/^[1-5]\d{2}$/.test(normalized)) { + return null; + } + return Number.parseInt(normalized, 10); +} + +function parseStatusCodeMappingTarget(rawValue) { + if (typeof rawValue === 'number' && Number.isInteger(rawValue)) { + return rawValue >= 100 && rawValue <= 599 ? rawValue : null; + } + if (typeof rawValue === 'string') { + const normalized = rawValue.trim(); + if (!/^[1-5]\d{2}$/.test(normalized)) { + return null; + } + const code = Number.parseInt(normalized, 10); + return code >= 100 && code <= 599 ? code : null; + } + return null; +} + +export function collectInvalidStatusCodeEntries(statusCodeMappingStr) { + if ( + typeof statusCodeMappingStr !== 'string' || + statusCodeMappingStr.trim() === '' + ) { + return []; + } + + let parsed; + try { + parsed = JSON.parse(statusCodeMappingStr); + } catch { + return []; + } + + if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) { + return []; + } + + const invalid = []; + for (const [rawKey, rawValue] of Object.entries(parsed)) { + const fromCode = parseStatusCodeKey(rawKey); + const toCode = parseStatusCodeMappingTarget(rawValue); + if (fromCode === null || toCode === null) { + invalid.push(`${rawKey} → ${rawValue}`); + } + } + + return invalid; +} + +export function collectDisallowedStatusCodeRedirects(statusCodeMappingStr) { + if ( + typeof statusCodeMappingStr !== 'string' || + statusCodeMappingStr.trim() === '' + ) { + return []; + } + + let parsed; + try { + parsed = JSON.parse(statusCodeMappingStr); + } catch (error) { + return []; + } + + if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) { + return []; + } + + const riskyMappings = []; + Object.entries(parsed).forEach(([rawFrom, rawTo]) => { + const fromCode = parseStatusCodeKey(rawFrom); + const toCode = parseStatusCodeMappingTarget(rawTo); + if (fromCode === null || toCode === null) { + return; + } + if (!NON_REDIRECTABLE_STATUS_CODES.has(fromCode)) { + return; + } + if (fromCode === toCode) { + return; + } + riskyMappings.push(`${fromCode} -> ${toCode}`); + }); + + return Array.from(new Set(riskyMappings)).sort(); +} + +export function collectNewDisallowedStatusCodeRedirects( + originalStatusCodeMappingStr, + currentStatusCodeMappingStr, +) { + const currentRisky = collectDisallowedStatusCodeRedirects( + currentStatusCodeMappingStr, + ); + if (currentRisky.length === 0) { + return []; + } + + const originalRiskySet = new Set( + collectDisallowedStatusCodeRedirects(originalStatusCodeMappingStr), + ); + + return currentRisky.filter((mapping) => !originalRiskySet.has(mapping)); +} diff --git a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx index c78d5773e..4097545e5 100644 --- a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx +++ b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx @@ -84,8 +84,8 @@ function renderDuration(submit_time, finishTime) { // 返回带有样式的颜色标签 return ( - }> - {durationSec} 秒 + + {durationSec} s ); } @@ -149,7 +149,7 @@ const renderPlatform = (platform, t) => { ); if (option) { return ( - }> + {option.label} ); @@ -157,13 +157,13 @@ const renderPlatform = (platform, t) => { switch (platform) { case 'suno': return ( - }> + Suno ); default: return ( - }> + {t('未知')} ); @@ -240,7 +240,7 @@ export const getTaskLogsColumns = ({ openContentModal, isAdminUser, openVideoModal, - showUserInfoFunc, + openAudioModal, }) => { return [ { @@ -278,7 +278,6 @@ export const getTaskLogsColumns = ({ color={colors[parseInt(text) % colors.length]} size='large' shape='circle' - prefixIcon={} onClick={() => { copyText(text); }} @@ -294,7 +293,7 @@ export const getTaskLogsColumns = ({ { key: COLUMN_KEYS.USERNAME, title: t('用户'), - dataIndex: 'user_id', + dataIndex: 'username', render: (userId, record, index) => { if (!isAdminUser) { return <>; @@ -302,22 +301,14 @@ export const getTaskLogsColumns = ({ const displayText = String(record.username || userId || '?'); return ( - - showUserInfoFunc && showUserInfoFunc(userId)} - > - {displayText.slice(0, 1)} - - - showUserInfoFunc && showUserInfoFunc(userId)} + - {userId} + {displayText.slice(0, 1)} + + + {displayText} ); @@ -396,7 +387,27 @@ export const getTaskLogsColumns = ({ dataIndex: 'fail_reason', fixed: 'right', render: (text, record, index) => { - // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接 + // Suno audio preview + const isSunoSuccess = + record.platform === 'suno' && + record.status === 'SUCCESS' && + Array.isArray(record.data) && + record.data.some((c) => c.audio_url); + if (isSunoSuccess) { + return ( + { + e.preventDefault(); + openAudioModal(record.data); + }} + > + {t('点击预览音乐')} + + ); + } + + // 视频预览:优先使用 result_url,兼容旧数据 fail_reason 中的 URL const isVideoTask = record.action === TASK_ACTION_GENERATE || record.action === TASK_ACTION_TEXT_GENERATE || @@ -404,14 +415,15 @@ export const getTaskLogsColumns = ({ record.action === TASK_ACTION_REFERENCE_GENERATE || record.action === TASK_ACTION_REMIX_GENERATE; const isSuccess = record.status === 'SUCCESS'; - const isUrl = typeof text === 'string' && /^https?:\/\//.test(text); - if (isSuccess && isVideoTask && isUrl) { + const resultUrl = record.result_url; + const hasResultUrl = typeof resultUrl === 'string' && /^https?:\/\//.test(resultUrl); + if (isSuccess && isVideoTask && hasResultUrl) { return ( { e.preventDefault(); - openVideoModal(text); + openVideoModal(resultUrl); }} > {t('点击预览视频')} diff --git a/web/src/components/table/task-logs/TaskLogsTable.jsx b/web/src/components/table/task-logs/TaskLogsTable.jsx index b62e15bd2..b3cec8ccc 100644 --- a/web/src/components/table/task-logs/TaskLogsTable.jsx +++ b/web/src/components/table/task-logs/TaskLogsTable.jsx @@ -40,6 +40,7 @@ const TaskLogsTable = (taskLogsData) => { copyText, openContentModal, openVideoModal, + openAudioModal, showUserInfoFunc, isAdminUser, t, @@ -54,10 +55,11 @@ const TaskLogsTable = (taskLogsData) => { copyText, openContentModal, openVideoModal, + openAudioModal, showUserInfoFunc, isAdminUser, }); - }, [t, COLUMN_KEYS, copyText, openContentModal, openVideoModal, showUserInfoFunc, isAdminUser]); + }, [t, COLUMN_KEYS, copyText, openContentModal, openVideoModal, openAudioModal, showUserInfoFunc, isAdminUser]); // Filter columns based on visibility settings const getVisibleColumns = () => { diff --git a/web/src/components/table/task-logs/index.jsx b/web/src/components/table/task-logs/index.jsx index 140725a89..07c387123 100644 --- a/web/src/components/table/task-logs/index.jsx +++ b/web/src/components/table/task-logs/index.jsx @@ -25,7 +25,7 @@ import TaskLogsActions from './TaskLogsActions'; import TaskLogsFilters from './TaskLogsFilters'; import ColumnSelectorModal from './modals/ColumnSelectorModal'; import ContentModal from './modals/ContentModal'; -import UserInfoModal from '../usage-logs/modals/UserInfoModal'; +import AudioPreviewModal from './modals/AudioPreviewModal'; import { useTaskLogsData } from '../../../hooks/task-logs/useTaskLogsData'; import { useIsMobile } from '../../../hooks/common/useIsMobile'; import { createCardProPagination } from '../../../helpers/utils'; @@ -46,7 +46,11 @@ const TaskLogsPage = () => { modalContent={taskLogsData.videoUrl} isVideo={true} /> - + . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React, { useState, useRef, useEffect } from 'react'; +import { Modal, Typography, Tag, Button } from '@douyinfe/semi-ui'; +import { IconExternalOpen, IconCopy } from '@douyinfe/semi-icons'; +import { useTranslation } from 'react-i18next'; + +const { Text, Title } = Typography; + +const formatDuration = (seconds) => { + if (!seconds || seconds <= 0) return '--:--'; + const m = Math.floor(seconds / 60); + const s = Math.floor(seconds % 60); + return `${m}:${s.toString().padStart(2, '0')}`; +}; + +const AudioClipCard = ({ clip }) => { + const { t } = useTranslation(); + const [hasError, setHasError] = useState(false); + const audioRef = useRef(null); + + useEffect(() => { + setHasError(false); + }, [clip.audio_url]); + + const title = clip.title || t('未命名'); + const tags = clip.tags || clip.metadata?.tags || ''; + const duration = clip.duration || clip.metadata?.duration; + const imageUrl = clip.image_url || clip.image_large_url; + const audioUrl = clip.audio_url; + + return ( +
+ {imageUrl && ( + {title} { + e.target.style.display = 'none'; + }} + /> + )} +
+
+ + {title} + + {duration > 0 && ( + + {formatDuration(duration)} + + )} +
+ + {tags && ( +
+ + {tags} + +
+ )} + + {hasError ? ( +
+ + {t('音频无法播放')} + + + +
+ ) : ( +
+
+ ); +}; + +const AudioPreviewModal = ({ isModalOpen, setIsModalOpen, audioClips }) => { + const { t } = useTranslation(); + const clips = Array.isArray(audioClips) ? audioClips : []; + + return ( + setIsModalOpen(false)} + onCancel={() => setIsModalOpen(false)} + closable={null} + footer={null} + bodyStyle={{ + maxHeight: '70vh', + overflow: 'auto', + padding: '16px', + }} + width={560} + > + {clips.length === 0 ? ( + {t('无')} + ) : ( +
+ {clips.map((clip, idx) => ( + + ))} +
+ )} +
+ ); +}; + +export default AudioPreviewModal; diff --git a/web/src/components/table/task-logs/modals/ContentModal.jsx b/web/src/components/table/task-logs/modals/ContentModal.jsx index 88df4d8ce..3527fd96d 100644 --- a/web/src/components/table/task-logs/modals/ContentModal.jsx +++ b/web/src/components/table/task-logs/modals/ContentModal.jsx @@ -144,8 +144,6 @@ const ContentModal = ({ maxHeight: '100%', objectFit: 'contain', }} - autoPlay - crossOrigin='anonymous' onError={handleVideoError} onLoadedData={handleVideoLoaded} onLoadStart={() => setIsLoading(true)} diff --git a/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx b/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx index f0dcd379e..b1538877a 100644 --- a/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx +++ b/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx @@ -133,6 +133,12 @@ function renderType(type, t) { {t('错误')} ); + case 6: + return ( + + {t('退款')} + + ); default: return ( @@ -368,7 +374,7 @@ export const getLogsColumns = ({ } return isAdminUser && - (record.type === 0 || record.type === 2 || record.type === 5) ? ( + (record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6) ? ( @@ -459,7 +465,7 @@ export const getLogsColumns = ({ title: t('令牌'), dataIndex: 'token_name', render: (text, record, index) => { - return record.type === 0 || record.type === 2 || record.type === 5 ? ( + return record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6 ? (
{ - if (record.type === 0 || record.type === 2 || record.type === 5) { + if (record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6) { if (record.group) { return <>{renderGroup(record.group)}; } else { @@ -522,7 +528,7 @@ export const getLogsColumns = ({ title: t('模型'), dataIndex: 'model_name', render: (text, record, index) => { - return record.type === 0 || record.type === 2 || record.type === 5 ? ( + return record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6 ? ( <>{renderModelName(record, copyText, t)} ) : ( <> @@ -589,7 +595,7 @@ export const getLogsColumns = ({ cacheText = `${t('缓存写')} ${formatTokenCount(cacheSummary.cacheWriteTokens)}`; } - return record.type === 0 || record.type === 2 || record.type === 5 ? ( + return record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6 ? (
{ return parseInt(text) > 0 && - (record.type === 0 || record.type === 2 || record.type === 5) ? ( + (record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6) ? ( <>{ {text} } ) : ( <> @@ -635,7 +641,7 @@ export const getLogsColumns = ({ title: t('花费'), dataIndex: 'quota', render: (text, record, index) => { - if (!(record.type === 0 || record.type === 2 || record.type === 5)) { + if (!(record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6)) { return <>; } const other = getLogOther(record.other); @@ -722,6 +728,16 @@ export const getLogsColumns = ({ fixed: 'right', render: (text, record, index) => { let other = getLogOther(record.other); + if (record.type === 6) { + return ( + + {t('异步任务退款')} + + ); + } if (other == null || record.type !== 2) { return ( {t('管理')} {t('系统')} {t('错误')} + {t('退款')}
diff --git a/web/src/components/table/usage-logs/modals/ChannelAffinityUsageCacheModal.jsx b/web/src/components/table/usage-logs/modals/ChannelAffinityUsageCacheModal.jsx index ea1a5c7fb..383ebabc1 100644 --- a/web/src/components/table/usage-logs/modals/ChannelAffinityUsageCacheModal.jsx +++ b/web/src/components/table/usage-logs/modals/ChannelAffinityUsageCacheModal.jsx @@ -39,6 +39,21 @@ function formatTokenRate(n, d) { return `${r.toFixed(2)}%`; } +function formatCachedTokenRate(cachedTokens, promptTokens, mode) { + if (mode === 'cached_over_prompt_plus_cached') { + const denominator = Number(promptTokens || 0) + Number(cachedTokens || 0); + return formatTokenRate(cachedTokens, denominator); + } + if (mode === 'cached_over_prompt') { + return formatTokenRate(cachedTokens, promptTokens); + } + return '-'; +} + +function hasTextValue(value) { + return typeof value === 'string' && value.trim() !== ''; +} + const ChannelAffinityUsageCacheModal = ({ t, showChannelAffinityUsageCacheModal, @@ -107,7 +122,7 @@ const ChannelAffinityUsageCacheModal = ({ t, ]); - const rows = useMemo(() => { + const { rows, supportsTokenStats } = useMemo(() => { const s = stats || {}; const hit = Number(s.hit || 0); const total = Number(s.total || 0); @@ -118,48 +133,62 @@ const ChannelAffinityUsageCacheModal = ({ const totalTokens = Number(s.total_tokens || 0); const cachedTokens = Number(s.cached_tokens || 0); const promptCacheHitTokens = Number(s.prompt_cache_hit_tokens || 0); + const cachedTokenRateMode = String(s.cached_token_rate_mode || '').trim(); + const supportsTokenStats = + cachedTokenRateMode === 'cached_over_prompt' || + cachedTokenRateMode === 'cached_over_prompt_plus_cached' || + cachedTokenRateMode === 'mixed'; - return [ - { key: t('规则'), value: s.rule_name || params.rule_name || '-' }, - { key: t('分组'), value: s.using_group || params.using_group || '-' }, - { - key: t('Key 摘要'), - value: params.key_hint || '-', - }, - { - key: t('Key 指纹'), - value: s.key_fp || params.key_fp || '-', - }, - { key: t('TTL(秒)'), value: windowSeconds > 0 ? windowSeconds : '-' }, - { - key: t('命中率'), - value: `${hit}/${total} (${formatRate(hit, total)})`, - }, - { - key: t('Prompt tokens'), - value: promptTokens, - }, - { - key: t('Cached tokens'), - value: `${cachedTokens} (${formatTokenRate(cachedTokens, promptTokens)})`, - }, - { - key: t('Prompt cache hit tokens'), - value: promptCacheHitTokens, - }, - { - key: t('Completion tokens'), - value: completionTokens, - }, - { - key: t('Total tokens'), - value: totalTokens, - }, - { - key: t('最近一次'), - value: lastSeenAt > 0 ? timestamp2string(lastSeenAt) : '-', - }, - ]; + const data = []; + const ruleName = String(s.rule_name || params.rule_name || '').trim(); + const usingGroup = String(s.using_group || params.using_group || '').trim(); + const keyHint = String(params.key_hint || '').trim(); + const keyFp = String(s.key_fp || params.key_fp || '').trim(); + + if (hasTextValue(ruleName)) { + data.push({ key: t('规则'), value: ruleName }); + } + if (hasTextValue(usingGroup)) { + data.push({ key: t('分组'), value: usingGroup }); + } + if (hasTextValue(keyHint)) { + data.push({ key: t('Key 摘要'), value: keyHint }); + } + if (hasTextValue(keyFp)) { + data.push({ key: t('Key 指纹'), value: keyFp }); + } + if (windowSeconds > 0) { + data.push({ key: t('TTL(秒)'), value: windowSeconds }); + } + if (total > 0) { + data.push({ key: t('命中率'), value: `${hit}/${total} (${formatRate(hit, total)})` }); + } + if (lastSeenAt > 0) { + data.push({ key: t('最近一次'), value: timestamp2string(lastSeenAt) }); + } + + if (supportsTokenStats) { + if (promptTokens > 0) { + data.push({ key: t('Prompt tokens'), value: promptTokens }); + } + if (promptTokens > 0 || cachedTokens > 0) { + data.push({ + key: t('Cached tokens'), + value: `${cachedTokens} (${formatCachedTokenRate(cachedTokens, promptTokens, cachedTokenRateMode)})`, + }); + } + if (promptCacheHitTokens > 0) { + data.push({ key: t('Prompt cache hit tokens'), value: promptCacheHitTokens }); + } + if (completionTokens > 0) { + data.push({ key: t('Completion tokens'), value: completionTokens }); + } + if (totalTokens > 0) { + data.push({ key: t('Total tokens'), value: totalTokens }); + } + } + + return { rows: data, supportsTokenStats }; }, [stats, params, t]); return ( @@ -179,15 +208,27 @@ const ChannelAffinityUsageCacheModal = ({ {t( '命中判定:usage 中存在 cached tokens(例如 cached_tokens/prompt_cache_hit_tokens)即视为命中。', )} + {' '} + {t( + 'Cached tokens 占比口径由后端返回:Claude 语义按 cached/(prompt+cached),其余按 cached/prompt。', + )} + {' '} + {t('当前仅 OpenAI / Claude 语义支持缓存 token 统计,其他通道将隐藏 token 相关字段。')} + {stats && !supportsTokenStats ? ( + <> + {' '} + {t('该记录不包含可用的 token 统计口径。')} + + ) : null}
- {stats ? ( + {stats && rows.length > 0 ? ( ) : (
- {loading ? t('加载中...') : t('暂无数据')} + {loading ? t('加载中...') : t('暂无可展示数据')}
)} diff --git a/web/src/components/table/users/modals/EditUserModal.jsx b/web/src/components/table/users/modals/EditUserModal.jsx index 32601daa8..90676d840 100644 --- a/web/src/components/table/users/modals/EditUserModal.jsx +++ b/web/src/components/table/users/modals/EditUserModal.jsx @@ -45,7 +45,6 @@ import { Avatar, Row, Col, - Input, InputNumber, } from '@douyinfe/semi-ui'; import { @@ -56,6 +55,7 @@ import { IconUserGroup, IconPlus, } from '@douyinfe/semi-icons'; +import UserBindingManagementModal from './UserBindingManagementModal'; const { Text, Title } = Typography; @@ -68,6 +68,7 @@ const EditUserModal = (props) => { const [addAmountLocal, setAddAmountLocal] = useState(''); const isMobile = useIsMobile(); const [groupOptions, setGroupOptions] = useState([]); + const [bindingModalVisible, setBindingModalVisible] = useState(false); const formApiRef = useRef(null); const isEdit = Boolean(userId); @@ -81,6 +82,7 @@ const EditUserModal = (props) => { discord_id: '', wechat_id: '', telegram_id: '', + linux_do_id: '', email: '', quota: 0, group: 'default', @@ -115,8 +117,17 @@ const EditUserModal = (props) => { useEffect(() => { loadUser(); if (userId) fetchGroups(); + setBindingModalVisible(false); }, [props.editingUser.id]); + const openBindingModal = () => { + setBindingModalVisible(true); + }; + + const closeBindingModal = () => { + setBindingModalVisible(false); + }; + /* ----------------------- submit ----------------------- */ const submit = async (values) => { setLoading(true); @@ -196,7 +207,7 @@ const EditUserModal = (props) => { onSubmit={submit} > {({ values }) => ( -
+
{/* 基本信息 */}
@@ -316,56 +327,51 @@ const EditUserModal = (props) => { )} - {/* 绑定信息 */} - -
- - - -
- - {t('绑定信息')} - -
- {t('第三方账户绑定状态(只读)')} + {/* 绑定信息入口 */} + {userId && ( + +
+
+ + + +
+ + {t('绑定信息')} + +
+ {t('管理用户已绑定的第三方账户,支持筛选与解绑')} +
+
+
-
- - - {[ - 'github_id', - 'discord_id', - 'oidc_id', - 'wechat_id', - 'email', - 'telegram_id', - ].map((field) => ( - - - - ))} - - + + )}
)} + + {/* 添加额度模态框 */} {
{t('金额')} - ({t('仅用于换算,实际保存的是额度')}) + + {' '} + ({t('仅用于换算,实际保存的是额度')}) +
{ onChange={(val) => { setAddAmountLocal(val); setAddQuotaLocal( - val != null && val !== '' ? displayAmountToQuota(Math.abs(val)) * Math.sign(val) : '', + val != null && val !== '' + ? displayAmountToQuota(Math.abs(val)) * Math.sign(val) + : '', ); }} style={{ width: '100%' }} @@ -430,7 +441,11 @@ const EditUserModal = (props) => { setAddQuotaLocal(val); setAddAmountLocal( val != null && val !== '' - ? Number((quotaToDisplayAmount(Math.abs(val)) * Math.sign(val)).toFixed(2)) + ? Number( + ( + quotaToDisplayAmount(Math.abs(val)) * Math.sign(val) + ).toFixed(2), + ) : '', ); }} diff --git a/web/src/components/table/users/modals/UserBindingManagementModal.jsx b/web/src/components/table/users/modals/UserBindingManagementModal.jsx new file mode 100644 index 000000000..c5b2a3a15 --- /dev/null +++ b/web/src/components/table/users/modals/UserBindingManagementModal.jsx @@ -0,0 +1,410 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React from 'react'; +import { useTranslation } from 'react-i18next'; +import { + API, + showError, + showSuccess, + getOAuthProviderIcon, +} from '../../../../helpers'; +import { + Modal, + Spin, + Typography, + Card, + Checkbox, + Tag, + Button, +} from '@douyinfe/semi-ui'; +import { + IconLink, + IconMail, + IconDelete, + IconGithubLogo, +} from '@douyinfe/semi-icons'; +import { SiDiscord, SiTelegram, SiWechat, SiLinux } from 'react-icons/si'; + +const { Text } = Typography; + +const UserBindingManagementModal = ({ + visible, + onCancel, + userId, + isMobile, + formApiRef, +}) => { + const { t } = useTranslation(); + const [bindingLoading, setBindingLoading] = React.useState(false); + const [showBoundOnly, setShowBoundOnly] = React.useState(true); + const [statusInfo, setStatusInfo] = React.useState({}); + const [customOAuthBindings, setCustomOAuthBindings] = React.useState([]); + const [bindingActionLoading, setBindingActionLoading] = React.useState({}); + + const loadBindingData = React.useCallback(async () => { + if (!userId) return; + + setBindingLoading(true); + try { + const [statusRes, customBindingRes] = await Promise.all([ + API.get('/api/status'), + API.get(`/api/user/${userId}/oauth/bindings`), + ]); + + if (statusRes.data?.success) { + setStatusInfo(statusRes.data.data || {}); + } else { + showError(statusRes.data?.message || t('操作失败')); + } + + if (customBindingRes.data?.success) { + setCustomOAuthBindings(customBindingRes.data.data || []); + } else { + showError(customBindingRes.data?.message || t('操作失败')); + } + } catch (error) { + showError( + error.response?.data?.message || error.message || t('操作失败'), + ); + } finally { + setBindingLoading(false); + } + }, [t, userId]); + + React.useEffect(() => { + if (!visible) return; + setShowBoundOnly(true); + setBindingActionLoading({}); + loadBindingData(); + }, [visible, loadBindingData]); + + const setBindingLoadingState = (key, value) => { + setBindingActionLoading((prev) => ({ ...prev, [key]: value })); + }; + + const handleUnbindBuiltInAccount = (bindingItem) => { + if (!userId) return; + + Modal.confirm({ + title: t('确认解绑'), + content: t('确定要解绑 {{name}} 吗?', { name: bindingItem.name }), + okText: t('确认'), + cancelText: t('取消'), + onOk: async () => { + const loadingKey = `builtin-${bindingItem.key}`; + setBindingLoadingState(loadingKey, true); + try { + const res = await API.delete( + `/api/user/${userId}/bindings/${bindingItem.key}`, + ); + if (!res.data?.success) { + showError(res.data?.message || t('操作失败')); + return; + } + formApiRef.current?.setValue(bindingItem.field, ''); + showSuccess(t('解绑成功')); + } catch (error) { + showError( + error.response?.data?.message || error.message || t('操作失败'), + ); + } finally { + setBindingLoadingState(loadingKey, false); + } + }, + }); + }; + + const handleUnbindCustomOAuthAccount = (provider) => { + if (!userId) return; + + Modal.confirm({ + title: t('确认解绑'), + content: t('确定要解绑 {{name}} 吗?', { name: provider.name }), + okText: t('确认'), + cancelText: t('取消'), + onOk: async () => { + const loadingKey = `custom-${provider.id}`; + setBindingLoadingState(loadingKey, true); + try { + const res = await API.delete( + `/api/user/${userId}/oauth/bindings/${provider.id}`, + ); + if (!res.data?.success) { + showError(res.data?.message || t('操作失败')); + return; + } + setCustomOAuthBindings((prev) => + prev.filter( + (item) => Number(item.provider_id) !== Number(provider.id), + ), + ); + showSuccess(t('解绑成功')); + } catch (error) { + showError( + error.response?.data?.message || error.message || t('操作失败'), + ); + } finally { + setBindingLoadingState(loadingKey, false); + } + }, + }); + }; + + const currentValues = formApiRef.current?.getValues?.() || {}; + + const builtInBindingItems = [ + { + key: 'email', + field: 'email', + name: t('邮箱'), + enabled: true, + value: currentValues.email, + icon: ( + + ), + }, + { + key: 'github', + field: 'github_id', + name: 'GitHub', + enabled: Boolean(statusInfo.github_oauth), + value: currentValues.github_id, + icon: ( + + ), + }, + { + key: 'discord', + field: 'discord_id', + name: 'Discord', + enabled: Boolean(statusInfo.discord_oauth), + value: currentValues.discord_id, + icon: ( + + ), + }, + { + key: 'oidc', + field: 'oidc_id', + name: 'OIDC', + enabled: Boolean(statusInfo.oidc_enabled), + value: currentValues.oidc_id, + icon: ( + + ), + }, + { + key: 'wechat', + field: 'wechat_id', + name: t('微信'), + enabled: Boolean(statusInfo.wechat_login), + value: currentValues.wechat_id, + icon: ( + + ), + }, + { + key: 'telegram', + field: 'telegram_id', + name: 'Telegram', + enabled: Boolean(statusInfo.telegram_oauth), + value: currentValues.telegram_id, + icon: ( + + ), + }, + { + key: 'linuxdo', + field: 'linux_do_id', + name: 'LinuxDO', + enabled: Boolean(statusInfo.linuxdo_oauth), + value: currentValues.linux_do_id, + icon: ( + + ), + }, + ]; + + const customBindingMap = new Map( + customOAuthBindings.map((item) => [Number(item.provider_id), item]), + ); + + const customProviderMap = new Map( + (statusInfo.custom_oauth_providers || []).map((provider) => [ + Number(provider.id), + provider, + ]), + ); + + customOAuthBindings.forEach((binding) => { + if (!customProviderMap.has(Number(binding.provider_id))) { + customProviderMap.set(Number(binding.provider_id), { + id: binding.provider_id, + name: binding.provider_name, + icon: binding.provider_icon, + }); + } + }); + + const customBindingItems = Array.from(customProviderMap.values()).map( + (provider) => { + const binding = customBindingMap.get(Number(provider.id)); + return { + key: `custom-${provider.id}`, + providerId: provider.id, + name: provider.name, + enabled: true, + value: binding?.provider_user_id || '', + icon: getOAuthProviderIcon( + provider.icon || binding?.provider_icon || '', + 20, + ), + }; + }, + ); + + const allBindingItems = [ + ...builtInBindingItems.map((item) => ({ ...item, type: 'builtin' })), + ...customBindingItems.map((item) => ({ ...item, type: 'custom' })), + ]; + + const boundCount = allBindingItems.filter((item) => + Boolean(item.value), + ).length; + + const visibleBindingItems = showBoundOnly + ? allBindingItems.filter((item) => Boolean(item.value)) + : allBindingItems; + + return ( + + + {t('账户绑定管理')} +
+ } + > + +
+
+ setShowBoundOnly(Boolean(e.target.checked))} + > + {t('仅显示已绑定')} + + + {t('已绑定')} {boundCount} / {allBindingItems.length} + +
+ + {visibleBindingItems.length === 0 ? ( + + {t('暂无已绑定项')} + + ) : ( +
+ {visibleBindingItems.map((item, index) => { + const isBound = Boolean(item.value); + const loadingKey = + item.type === 'builtin' + ? `builtin-${item.key}` + : `custom-${item.providerId}`; + const statusText = isBound + ? item.value + : item.enabled + ? t('未绑定') + : t('未启用'); + const shouldSpanTwoColsOnDesktop = + visibleBindingItems.length % 2 === 1 && + index === visibleBindingItems.length - 1; + + return ( + +
+
+
+ {item.icon} +
+
+
+ {item.name} + + {item.type === 'builtin' + ? t('内置') + : t('自定义')} + +
+
+ {statusText} +
+
+
+ +
+
+ ); + })} +
+ )} +
+
+
+ ); +}; + +export default UserBindingManagementModal; diff --git a/web/src/helpers/render.jsx b/web/src/helpers/render.jsx index ecc252cfd..3ba198cb3 100644 --- a/web/src/helpers/render.jsx +++ b/web/src/helpers/render.jsx @@ -76,6 +76,31 @@ import { Server, CalendarClock, } from 'lucide-react'; +import { + SiAtlassian, + SiAuth0, + SiAuthentik, + SiBitbucket, + SiDiscord, + SiDropbox, + SiFacebook, + SiGitea, + SiGithub, + SiGitlab, + SiGoogle, + SiKeycloak, + SiLinkedin, + SiNextcloud, + SiNotion, + SiOkta, + SiOpenid, + SiReddit, + SiSlack, + SiTelegram, + SiTwitch, + SiWechat, + SiX, +} from 'react-icons/si'; // 获取侧边栏Lucide图标组件 export function getLucideIcon(key, selected = false) { @@ -472,6 +497,106 @@ export function getLobeHubIcon(iconName, size = 14) { return ; } +const oauthProviderIconMap = { + github: SiGithub, + gitlab: SiGitlab, + gitea: SiGitea, + google: SiGoogle, + discord: SiDiscord, + facebook: SiFacebook, + linkedin: SiLinkedin, + x: SiX, + twitter: SiX, + slack: SiSlack, + telegram: SiTelegram, + wechat: SiWechat, + keycloak: SiKeycloak, + nextcloud: SiNextcloud, + authentik: SiAuthentik, + openid: SiOpenid, + okta: SiOkta, + auth0: SiAuth0, + atlassian: SiAtlassian, + bitbucket: SiBitbucket, + notion: SiNotion, + twitch: SiTwitch, + reddit: SiReddit, + dropbox: SiDropbox, +}; + +function isHttpUrl(value) { + return /^https?:\/\//i.test(value || ''); +} + +function isSimpleEmoji(value) { + if (!value) return false; + const trimmed = String(value).trim(); + return trimmed.length > 0 && trimmed.length <= 4 && !isHttpUrl(trimmed); +} + +function normalizeOAuthIconKey(raw) { + return raw + .trim() + .toLowerCase() + .replace(/^ri:/, '') + .replace(/^react-icons:/, '') + .replace(/^si:/, ''); +} + +/** + * Render custom OAuth provider icon with react-icons or URL/emoji fallback. + * Supported formats: + * - react-icons simple key: github / gitlab / google / keycloak + * - prefixed key: ri:github / si:github + * - full URL image: https://example.com/logo.png + * - emoji: 🐱 + */ +export function getOAuthProviderIcon(iconName, size = 20) { + const raw = String(iconName || '').trim(); + const iconSize = Number(size) > 0 ? Number(size) : 20; + + if (!raw) { + return ; + } + + if (isHttpUrl(raw)) { + return ( + provider icon + ); + } + + if (isSimpleEmoji(raw)) { + return ( + + {raw} + + ); + } + + const key = normalizeOAuthIconKey(raw); + const IconComp = oauthProviderIconMap[key]; + if (IconComp) { + return ; + } + + return {raw.charAt(0).toUpperCase()}; +} + // 颜色列表 const colors = [ 'amber', diff --git a/web/src/hooks/task-logs/useTaskLogsData.js b/web/src/hooks/task-logs/useTaskLogsData.js index a461e3522..6ba3de388 100644 --- a/web/src/hooks/task-logs/useTaskLogsData.js +++ b/web/src/hooks/task-logs/useTaskLogsData.js @@ -72,6 +72,10 @@ export const useTaskLogsData = () => { const [isVideoModalOpen, setIsVideoModalOpen] = useState(false); const [videoUrl, setVideoUrl] = useState(''); + // Audio preview modal state + const [isAudioModalOpen, setIsAudioModalOpen] = useState(false); + const [audioClips, setAudioClips] = useState([]); + // User info modal state const [showUserInfo, setShowUserInfoModal] = useState(false); const [userInfoData, setUserInfoData] = useState(null); @@ -277,6 +281,11 @@ export const useTaskLogsData = () => { setIsVideoModalOpen(true); }; + const openAudioModal = (clips) => { + setAudioClips(clips); + setIsAudioModalOpen(true); + }; + // User info function const showUserInfoFunc = async (userId) => { if (!isAdminUser) { @@ -319,6 +328,11 @@ export const useTaskLogsData = () => { setIsVideoModalOpen, videoUrl, + // Audio preview modal + isAudioModalOpen, + setIsAudioModalOpen, + audioClips, + // Form state formApi, setFormApi, @@ -351,7 +365,8 @@ export const useTaskLogsData = () => { refresh, copyText, openContentModal, - openVideoModal, // 新增 + openVideoModal, + openAudioModal, enrichLogs, syncPageData, diff --git a/web/src/hooks/usage-logs/useUsageLogsData.jsx b/web/src/hooks/usage-logs/useUsageLogsData.jsx index 14c021e41..b69a7cf18 100644 --- a/web/src/hooks/usage-logs/useUsageLogsData.jsx +++ b/web/src/hooks/usage-logs/useUsageLogsData.jsx @@ -344,7 +344,7 @@ export const useLogsData = () => { let other = getLogOther(logs[i].other); let expandDataLocal = []; - if (isAdminUser && (logs[i].type === 0 || logs[i].type === 2)) { + if (isAdminUser && (logs[i].type === 0 || logs[i].type === 2 || logs[i].type === 6)) { expandDataLocal.push({ key: t('渠道信息'), value: `${logs[i].channel} - ${logs[i].channel_name || '[未知]'}`, @@ -535,6 +535,24 @@ export const useLogsData = () => { }); } } + if (logs[i].type === 6) { + if (other?.task_id) { + expandDataLocal.push({ + key: t('任务ID'), + value: other.task_id, + }); + } + if (other?.reason) { + expandDataLocal.push({ + key: t('失败原因'), + value: ( +
+ {other.reason} +
+ ), + }); + } + } if (other?.request_path) { expandDataLocal.push({ key: t('请求路径'), @@ -590,13 +608,13 @@ export const useLogsData = () => { ), }); } - if (isAdminUser) { + if (isAdminUser && logs[i].type !== 6) { expandDataLocal.push({ key: t('请求转换'), value: requestConversionDisplayValue(other?.request_conversion), }); } - if (isAdminUser) { + if (isAdminUser && logs[i].type !== 6) { let localCountMode = ''; if (other?.admin_info?.local_count_tokens) { localCountMode = t('本地计费'); diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index 8b2b08529..f6c13e7d8 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -302,7 +302,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "Task ID", - "任务ID": "Task ID", "任务日志": "Task Logs", "任务状态": "Status", "任务记录": "Task Records", @@ -544,7 +543,6 @@ "创建": "Create", "创建令牌默认选择auto分组,初始令牌也将设为auto(否则留空,为用户默认分组)": "Create token with auto group by default, initial token will also be set to auto (otherwise leave blank for user default group)", "创建失败": "Creation failed", - "创建成功": "Creation successful", "创建或选择密钥时,将 Project 设置为 io.cloud": "When creating or selecting a key, set Project to io.cloud", "创建新用户账户": "Create new user account", "创建新的令牌": "Create New Token", @@ -787,7 +785,6 @@ "天": "day", "天前": "days ago", "失败": "Failed", - "失败原因": "Failure reason", "失败时自动禁用通道": "Automatically disable channel on failure", "失败重试次数": "Failed retry times", "奖励说明": "Reward description", @@ -1336,7 +1333,6 @@ "更新失败,请检查输入信息": "Update failed, please check the input information", "更新容器配置": "Update Container Configuration", "更新容器配置可能会导致容器重启,请确保在合适的时间进行此操作。": "Updating container configuration may cause the container to restart, please ensure you perform this operation at an appropriate time.", - "更新成功": "Update successful", "更新所有已启用通道余额": "Update balance for all enabled channels", "更新支付设置": "Update payment settings", "更新时间": "Update time", @@ -1638,10 +1634,14 @@ "点击查看差异": "Click to view differences", "点击此处": "click here", "点击预览视频": "Click to preview video", + "点击预览音乐": "Click to preview music", + "音乐预览": "Music Preview", + "音频无法播放": "Audio cannot be played", "点击验证按钮,使用您的生物特征或安全密钥": "Click the verification button and use your biometrics or security key", "版权所有": "All rights reserved", "状态": "Status", "状态码复写": "Status Code Override", + "状态码复写包含无效的状态码": "Status code override contains invalid status codes", "状态筛选": "Status filter", "状态页面Slug": "Status Page Slug", "环境变量": "Environment Variables", @@ -1767,7 +1767,6 @@ "确认清理不活跃的磁盘缓存?": "Confirm cleanup of inactive disk cache?", "确认禁用": "Confirm disable", "确认补单": "Confirm Order Completion", - "确认解绑": "Confirm Unbind", "确认解绑 Passkey": "Confirm Unbind Passkey", "确认设置并完成初始化": "Confirm settings and complete initialization", "确认重置 Passkey": "Confirm Passkey Reset", @@ -1945,7 +1944,6 @@ "自动分组auto,从第一个开始选择": "Auto grouping auto, select from the first one", "自动刷新": "Auto Refresh", "自动刷新中": "Auto refreshing", - "自动检测": "Auto Detect", "自动模式": "Auto Mode", "自动测试所有通道间隔时间": "Auto test interval for all channels", "自动禁用": "Auto disabled", @@ -1955,6 +1953,19 @@ "自动重试状态码": "Auto-retry status codes", "自动重试状态码格式不正确": "Invalid auto-retry status code format", "支持填写单个状态码或范围(含首尾),使用逗号分隔": "Supports single status codes or inclusive ranges; separate with commas", + "支持填写单个状态码或范围(含首尾),使用逗号分隔;504 和 524 始终不重试,不受此处配置影响": "Supports single status codes or inclusive ranges; separate with commas. 504 and 524 are never retried and are not affected by this setting", + "高危操作确认": "High-risk operation confirmation", + "检测到以下高危状态码重定向规则": "Detected high-risk status-code redirect rules", + "操作确认": "Operation confirmation", + "我确认开启高危重试": "I confirm enabling high-risk retry", + "高危状态码重试风险告知与免责声明Markdown": "### ⚠️ High-Risk Operation: Risk Notice and Disclaimer for 504/524 Retry\nBy default, this project does not retry for status codes `400` (bad request), `504` (gateway timeout), and `524` (timeout occurred).\n In many cases, 504 and 524 mean the request has reached the upstream AI service and processing has started, but the connection was closed due to long processing time.\n\nEnabling redirection/retry for these timeout status codes is a **high-risk operation**. Before enabling it, you must read and understand the consequences below:\n\n#### 1. Core Risks (Read Carefully)\n1. 💸 Duplicate/multiple billing risk: Most upstream AI providers **still charge** for requests that started processing but got interrupted by network timeout (504/524). If retry is triggered, a new upstream request will be sent, which can lead to **duplicate or multiple charges**.\n2. ⏳ Severe client timeout: If a single request already timed out, adding retries can multiply total latency and cause severe or unacceptable timeout behavior for your final client/caller.\n3. 💥 Request backlog and system crash risk: Forcing retries on timeout requests keeps threads and connections occupied for longer. Under high concurrency, this can cause serious backlog, exhaust system resources, trigger a cascading failure, and crash your proxy service.\n\n#### 2. Risk Acknowledgement\nIf you still choose to enable this feature, you acknowledge all of the following:", + "高危状态码重试风险确认输入文本": "I understand the duplicate billing and crash risks, and confirm enabling it.", + "高危状态码重试风险确认项1": "I have fully read and understood the risks and fully understand the destructive consequences of forcing retries for status codes 504 and 524.", + "高危状态码重试风险确认项2": "I have communicated with the upstream provider and confirmed that the timeout issue is an upstream bottleneck and cannot be resolved upstream at this time.", + "高危状态码重试风险确认项3": "I voluntarily accept all duplicate/multiple billing risks and will not file issues or complaints in this project repository regarding billing anomalies caused by this retry behavior.", + "高危状态码重试风险确认项4": "I voluntarily accept system stability risks, including severe client timeout and possible service crash. Any consequences caused by enabling this feature are my own responsibility.", + "高危状态码重试风险输入框占位文案": "Please type the exact text above", + "高危状态码重试风险输入不匹配提示": "The input does not match the required text", "例如:401, 403, 429, 500-599": "e.g. 401,403,429,500-599", "自动选择": "Auto Select", "自定义充值数量选项": "Custom Recharge Amount Options", @@ -2343,46 +2354,9 @@ "输入验证码完成设置": "Enter verification code to complete setup", "输出": "Output", "输出 {{completion}} tokens / 1M tokens * {{symbol}}{{compPrice}}) * {{ratioType}} {{ratio}}": "Output {{completion}} tokens / 1M tokens * {{symbol}}{{compPrice}} * {{ratioType}} {{ratio}}", - "磁盘缓存设置(磁盘换内存)": "Disk Cache Settings (Disk Swap Memory)", - "启用磁盘缓存后,大请求体将临时存储到磁盘而非内存,可显著降低内存占用,适用于处理包含大量图片/文件的请求。建议在 SSD 环境下使用。": "When enabled, large request bodies are temporarily stored on disk instead of memory, significantly reducing memory usage. Suitable for requests with large images/files. SSD recommended.", - "启用磁盘缓存": "Enable Disk Cache", - "将大请求体临时存储到磁盘": "Store large request bodies temporarily on disk", - "磁盘缓存阈值 (MB)": "Disk Cache Threshold (MB)", - "请求体超过此大小时使用磁盘缓存": "Use disk cache when request body exceeds this size", - "磁盘缓存最大总量 (MB)": "Max Disk Cache Size (MB)", - "可用空间: {{free}} / 总空间: {{total}}": "Free: {{free}} / Total: {{total}}", - "磁盘缓存占用的最大空间": "Maximum space occupied by disk cache", - "留空使用系统临时目录": "Leave empty to use system temp directory", - "例如 /var/cache/new-api": "e.g. /var/cache/new-api", - "性能监控": "Performance Monitor", - "刷新统计": "Refresh Stats", - "重置统计": "Reset Stats", - "执行 GC": "Run GC", - "请求体磁盘缓存": "Request Body Disk Cache", - "活跃文件": "Active Files", - "磁盘命中": "Disk Hits", - "请求体内存缓存": "Request Body Memory Cache", - "当前缓存大小": "Current Cache Size", - "活跃缓存数": "Active Cache Count", - "内存命中": "Memory Hits", - "缓存目录磁盘空间": "Cache Directory Disk Space", - "磁盘可用空间小于缓存最大总量设置": "Disk free space is less than max cache size setting", - "已分配内存": "Allocated Memory", - "总分配内存": "Total Allocated Memory", - "系统内存": "System Memory", - "GC 次数": "GC Count", - "Goroutine 数": "Goroutine Count", - "目录文件数": "Directory File Count", - "目录总大小": "Directory Total Size", - "磁盘缓存已清理": "Disk cache cleared", - "清理失败": "Cleanup failed", - "统计已重置": "Statistics reset", - "重置失败": "Reset failed", - "GC 已执行": "GC executed", "GC 执行失败": "GC execution failed", "缓存目录": "Cache Directory", "可用": "Available", - "输出价格": "Output Price", "输出价格:{{symbol}}{{price}} * {{completionRatio}} = {{symbol}}{{total}} / 1M tokens (补全倍率: {{completionRatio}})": "Output price: {{symbol}}{{price}} * {{completionRatio}} = {{symbol}}{{total}} / 1M tokens (Completion ratio: {{completionRatio}})", "输出倍率 {{completionRatio}}": "Output ratio {{completionRatio}}", "边栏设置": "Sidebar Settings", @@ -2545,6 +2519,11 @@ "销毁容器": "Destroy Container", "销毁容器失败": "Failed to destroy container", "错误": "errors", + "退款": "Refund", + "错误详情": "Error Details", + "异步任务退款": "Async Task Refund", + "任务ID": "Task ID", + "失败原因": "Failure Reason", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "The key is the group name, and the value is another JSON object. The key is the group name, and the value is the special group ratio for users in that group. For example: {\"vip\": {\"default\": 0.5, \"test\": 1}} means that users in the vip group have a ratio of 0.5 when using tokens from the default group, and a ratio of 1 when using tokens from the test group", "键为原状态码,值为要复写的状态码,仅影响本地判断": "The key is the original status code, and the value is the status code to override, only affects local judgment", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Keys are user group names and values are operation mappings. Inner keys prefixed with \"+:\" add the specified group (key is the group name, value is the description); keys prefixed with \"-:\" remove the specified group; keys without a prefix add that group directly. Example: {\"vip\": {\"+:premium\": \"Advanced group\", \"special\": \"Special group\", \"-:default\": \"Default group\"}} means vip users can access the premium and special groups while removing access to the default group.", @@ -2856,6 +2835,7 @@ "缓存写": "Cache Write", "写": "Write", "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。": "Per Anthropic conventions, /v1/messages input tokens count only non-cached input and exclude cache read/write tokens.", - "设计版本": "b80c3466cb6feafeb3990c7820e10e50" + "设计版本": "b80c3466cb6feafeb3990c7820e10e50", + "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加": "No matching models. Press Enter to add \"{{name}}\" as a custom model name." } } diff --git a/web/src/i18n/locales/fr.json b/web/src/i18n/locales/fr.json index d4c76db69..c36b969dd 100644 --- a/web/src/i18n/locales/fr.json +++ b/web/src/i18n/locales/fr.json @@ -304,7 +304,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "ID de la tâche", - "任务ID": "ID de la tâche", "任务日志": "Tâches", "任务状态": "Statut de la tâche", "任务记录": "Tâches", @@ -792,7 +791,6 @@ "天": "Jour", "天前": "il y a des jours", "失败": "Échec", - "失败原因": "Raison de l'échec", "失败时自动禁用通道": "Désactiver automatiquement le canal en cas d'échec", "失败重试次数": "Nombre de tentatives en cas d'échec", "奖励说明": "Description de la récompense", @@ -1648,10 +1646,14 @@ "点击查看差异": "Cliquez pour voir les différences", "点击此处": "cliquez ici", "点击预览视频": "Cliquez pour prévisualiser la vidéo", + "点击预览音乐": "Cliquez pour écouter la musique", + "音乐预览": "Aperçu musical", + "音频无法播放": "Impossible de lire l'audio", "点击验证按钮,使用您的生物特征或安全密钥": "Cliquez sur le bouton de vérification pour utiliser vos caractéristiques biométriques ou votre clé de sécurité", "版权所有": "Tous droits réservés", "状态": "Statut", "状态码复写": "Remplacement du code d'état", + "状态码复写包含无效的状态码": "Le remplacement du code d'état contient des codes d'état invalides", "状态筛选": "Filtre d'état", "状态页面Slug": "Slug de la page d'état", "环境变量": "Environment Variables", @@ -2508,6 +2510,11 @@ "销毁容器": "Destroy Container", "销毁容器失败": "Failed to destroy container", "错误": "Erreur", + "退款": "Remboursement", + "错误详情": "Détails de l'erreur", + "异步任务退款": "Remboursement de tâche asynchrone", + "任务ID": "ID de tâche", + "失败原因": "Raison de l'échec", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "La clé est le nom du groupe, la valeur est un autre objet JSON, la clé est le nom du groupe, la valeur est le ratio de groupe spécial des utilisateurs de ce groupe, par exemple : {\"vip\": {\"default\": 0.5, \"test\": 1}}, ce qui signifie que les utilisateurs du groupe vip ont un ratio de 0.5 lors de l'utilisation de jetons du groupe default et un ratio de 1 lors de l'utilisation du groupe test", "键为原状态码,值为要复写的状态码,仅影响本地判断": "La clé est le code d'état d'origine, la valeur est le code d'état à réécrire, n'affecte que le jugement local", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "La clé correspond au nom du groupe d'utilisateurs et la valeur à un objet de mappage des opérations. Les clés internes commençant par \"+:\" ajoutent le groupe indiqué (clé = nom du groupe, valeur = description), celles commençant par \"-:\" retirent le groupe indiqué, et les clés sans préfixe ajoutent directement ce groupe. Exemple : {\"vip\": {\"+:premium\": \"Groupe avancé\", \"special\": \"Groupe spécial\", \"-:default\": \"Groupe par défaut\"}} signifie que les utilisateurs du groupe vip peuvent accéder aux groupes premium et special tout en perdant l'accès au groupe default.", @@ -2730,6 +2737,7 @@ "缓存写": "Écriture cache", "写": "Écriture", "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。": "Selon la convention Anthropic, les tokens d'entrée de /v1/messages ne comptent que les entrées non mises en cache et excluent les tokens de lecture/écriture du cache.", - "设计版本": "b80c3466cb6feafeb3990c7820e10e50" + "设计版本": "b80c3466cb6feafeb3990c7820e10e50", + "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加": "Aucun modèle correspondant. Appuyez sur Entrée pour ajouter «{{name}}» comme nom de modèle personnalisé." } } diff --git a/web/src/i18n/locales/ja.json b/web/src/i18n/locales/ja.json index 9ab727ec4..2951e9ea3 100644 --- a/web/src/i18n/locales/ja.json +++ b/web/src/i18n/locales/ja.json @@ -300,7 +300,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "タスクID", - "任务ID": "タスクID", "任务日志": "タスク履歴", "任务状态": "タスクステータス", "任务记录": "タスク履歴", @@ -783,7 +782,6 @@ "天": "日", "天前": "日前", "失败": "失敗", - "失败原因": "失敗理由", "失败时自动禁用通道": "失敗時にチャネルを自動的に無効にする", "失败重试次数": "再試行回数", "奖励说明": "特典説明", @@ -1633,10 +1631,14 @@ "点击查看差异": "差分を表示", "点击此处": "こちらをクリック", "点击预览视频": "動画をプレビュー", + "点击预览音乐": "音楽をプレビュー", + "音乐预览": "音楽プレビュー", + "音频无法播放": "音声を再生できません", "点击验证按钮,使用您的生物特征或安全密钥": "認証ボタンをクリックし、生体情報またはセキュリティキーを使用してください", "版权所有": "All rights reserved", "状态": "ステータス", "状态码复写": "ステータスコードの上書き", + "状态码复写包含无效的状态码": "ステータスコードの上書きに無効なステータスコードが含まれています", "状态筛选": "ステータスフィルター", "状态页面Slug": "ステータスページスラッグ", "环境变量": "Environment Variables", @@ -2491,6 +2493,11 @@ "销毁容器": "Destroy Container", "销毁容器失败": "Failed to destroy container", "错误": "エラー", + "退款": "返金", + "错误详情": "エラー詳細", + "异步任务退款": "非同期タスク返金", + "任务ID": "タスクID", + "失败原因": "失敗の原因", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "キーはグループ名、値は別のJSONオブジェクトです。このオブジェクトのキーには、利用するトークンが属するグループ名を指定し、値にはそのユーザーグループに適用される特別な倍率を指定します。例:{\"vip\": {\"default\": 0.5, \"test\": 1}} は、vipグループのユーザーがdefaultグループのトークンを利用する際の倍率が0.5、testグループのトークンを利用する際の倍率が1になることを示します", "键为原状态码,值为要复写的状态码,仅影响本地判断": "キーは元のステータスコード、値は上書きするステータスコードで、ローカルでの判断にのみ影響します", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Keys are user group names and values are operation mappings. Inner keys prefixed with \"+:\" add the specified group (key is the group name, value is the description); keys prefixed with \"-:\" remove the specified group; keys without a prefix add that group directly. Example: {\"vip\": {\"+:premium\": \"Advanced group\", \"special\": \"Special group\", \"-:default\": \"Default group\"}} means vip users can access the premium and special groups while removing access to the default group.", @@ -2713,6 +2720,7 @@ "缓存写": "キャッシュ書込", "写": "書込", "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。": "Anthropic の仕様により、/v1/messages の入力 tokens は非キャッシュ入力のみを集計し、キャッシュ読み取り/書き込み tokens は含みません。", - "设计版本": "b80c3466cb6feafeb3990c7820e10e50" + "设计版本": "b80c3466cb6feafeb3990c7820e10e50", + "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加": "一致するモデルが見つかりません。Enterキーで「{{name}}」をカスタムモデル名として追加できます。" } } diff --git a/web/src/i18n/locales/ru.json b/web/src/i18n/locales/ru.json index 97e243d37..82ccb0edf 100644 --- a/web/src/i18n/locales/ru.json +++ b/web/src/i18n/locales/ru.json @@ -307,7 +307,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "ID задачи", - "任务ID": "ID задачи", "任务日志": "Журнал задач", "任务状态": "Статус задачи", "任务记录": "Записи задач", @@ -798,7 +797,6 @@ "天": "день", "天前": "дней назад", "失败": "Неудача", - "失败原因": "Причина неудачи", "失败时自动禁用通道": "Автоматически отключать канал при неудаче", "失败重试次数": "Количество повторных попыток при неудаче", "奖励说明": "Описание награды", @@ -1659,10 +1657,14 @@ "点击查看差异": "Нажмите для просмотра различий", "点击此处": "Нажмите здесь", "点击预览视频": "Нажмите для предварительного просмотра видео", + "点击预览音乐": "Нажмите для прослушивания музыки", + "音乐预览": "Предварительное прослушивание", + "音频无法播放": "Не удалось воспроизвести аудио", "点击验证按钮,使用您的生物特征或安全密钥": "Нажмите кнопку проверки, используйте ваши биометрические данные или ключ безопасности", "版权所有": "Все права защищены", "状态": "Статус", "状态码复写": "Перезапись кода состояния", + "状态码复写包含无效的状态码": "Перезапись кода состояния содержит недопустимые коды состояния", "状态筛选": "Фильтр по статусу", "状态页面Slug": "Slug страницы статуса", "环境变量": "Environment Variables", @@ -2521,6 +2523,11 @@ "销毁容器": "Destroy Container", "销毁容器失败": "Failed to destroy container", "错误": "Ошибка", + "退款": "Возврат", + "错误详情": "Детали ошибки", + "异步任务退款": "Возврат асинхронной задачи", + "任务ID": "ID задачи", + "失败原因": "Причина ошибки", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "Ключ - это имя группы, значение - другой JSON объект, ключ - имя группы, значение - специальный групповой коэффициент для пользователей этой группы, например: {\"vip\": {\"default\": 0.5, \"test\": 1}}, означает, что пользователи группы vip при использовании токенов группы default имеют коэффициент 0.5, при использовании группы test - коэффициент 1", "键为原状态码,值为要复写的状态码,仅影响本地判断": "Ключ - исходный код состояния, значение - код состояния для перезаписи, влияет только на локальную проверку", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Ключ — это название группы пользователей, значение — объект сопоставления операций. Внутренние ключи с префиксом \"+:\" добавляют указанные группы (ключ — название группы, значение — описание), с префиксом \"-:\" удаляют указанные группы, без префикса — сразу добавляют эту группу. Пример: {\"vip\": {\"+:premium\": \"Продвинутая группа\", \"special\": \"Особая группа\", \"-:default\": \"Группа по умолчанию\"}} означает, что пользователи группы vip могут использовать группы premium и special, одновременно теряя доступ к группе default.", @@ -2743,6 +2750,7 @@ "缓存写": "Запись в кэш", "写": "Запись", "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。": "Согласно соглашению Anthropic, входные токены /v1/messages учитывают только некэшированный ввод и не включают токены чтения/записи кэша.", - "设计版本": "b80c3466cb6feafeb3990c7820e10e50" + "设计版本": "b80c3466cb6feafeb3990c7820e10e50", + "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加": "Совпадающих моделей не найдено. Нажмите Enter, чтобы добавить «{{name}}» как пользовательское имя модели." } } diff --git a/web/src/i18n/locales/vi.json b/web/src/i18n/locales/vi.json index 8875b1b5f..f78620cff 100644 --- a/web/src/i18n/locales/vi.json +++ b/web/src/i18n/locales/vi.json @@ -301,7 +301,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "ID tác vụ", - "任务ID": "ID tác vụ", "任务日志": "Nhật ký tác vụ", "任务状态": "Trạng thái", "任务记录": "Hồ sơ tác vụ", @@ -784,7 +783,6 @@ "天": "ngày", "天前": "ngày trước", "失败": "Thất bại", - "失败原因": "Lý do thất bại", "失败时自动禁用通道": "Tự động vô hiệu hóa kênh khi thất bại", "失败重试次数": "Số lần thử lại thất bại", "奖励说明": "Mô tả phần thưởng", @@ -1775,6 +1773,9 @@ "点击链接重置密码": "Nhấp vào liên kết để đặt lại mật khẩu", "点击阅读": "Nhấp để đọc", "点击预览视频": "Nhấp để xem trước video", + "点击预览音乐": "Nhấp để nghe nhạc", + "音乐预览": "Xem trước nhạc", + "音频无法播放": "Không thể phát âm thanh", "点击验证按钮,使用您的生物特征或安全密钥": "Nhấp vào nút xác minh và sử dụng sinh trắc học hoặc khóa bảo mật của bạn", "版": "Phiên bản", "版本": "Phiên bản", @@ -1784,6 +1785,7 @@ "状态": "Trạng thái", "状态更新时间": "Thời gian cập nhật trạng thái", "状态码复写": "Ghi đè mã trạng thái", + "状态码复写包含无效的状态码": "Ghi đè mã trạng thái chứa mã trạng thái không hợp lệ", "状态筛选": "Lọc trạng thái", "状态页面Slug": "Slug trang trạng thái", "环境变量": "Environment Variables", @@ -3060,10 +3062,14 @@ "销毁容器失败": "Failed to destroy container", "锁定": "Khóa", "错误": "Lỗi", + "退款": "Hoàn tiền", "错误信息": "Thông tin lỗi", "错误日志": "Nhật ký lỗi", "错误码": "Mã lỗi", "错误详情": "Chi tiết lỗi", + "异步任务退款": "Hoàn tiền tác vụ bất đồng bộ", + "任务ID": "ID tác vụ", + "失败原因": "Nguyên nhân thất bại", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "Khóa là tên nhóm và giá trị là một đối tượng JSON khác. Khóa là tên nhóm và giá trị là tỷ lệ nhóm đặc biệt cho người dùng trong nhóm đó. Ví dụ: {\"vip\": {\"default\": 0.5, \"test\": 1}} có nghĩa là người dùng trong nhóm vip có tỷ lệ 0.5 khi sử dụng mã thông báo từ nhóm default và tỷ lệ 1 khi sử dụng mã thông báo từ nhóm test.", "键为原状态码,值为要复写的状态码,仅影响本地判断": "Khóa là mã trạng thái gốc và giá trị là mã trạng thái cần ghi đè, chỉ ảnh hưởng đến phán đoán cục bộ", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Keys are user group names and values are operation mappings. Inner keys prefixed with \"+:\" add the specified group (key is the group name, value is the description); keys prefixed with \"-:\" remove the specified group; keys without a prefix add that group directly. Example: {\"vip\": {\"+:premium\": \"Advanced group\", \"special\": \"Special group\", \"-:default\": \"Default group\"}} means vip users can access the premium and special groups while removing access to the default group.", @@ -3290,6 +3296,7 @@ "缓存写": "Ghi bộ nhớ đệm", "写": "Ghi", "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。": "Theo quy ước của Anthropic, input tokens của /v1/messages chỉ tính phần đầu vào không dùng cache và không bao gồm tokens đọc/ghi cache.", - "设计版本": "b80c3466cb6feafeb3990c7820e10e50" + "设计版本": "b80c3466cb6feafeb3990c7820e10e50", + "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加": "Không tìm thấy mô hình khớp. Nhấn Enter để thêm \"{{name}}\" làm tên mô hình tùy chỉnh." } } diff --git a/web/src/i18n/locales/zh-CN.json b/web/src/i18n/locales/zh-CN.json index 43ce65b7a..fb135f6fb 100644 --- a/web/src/i18n/locales/zh-CN.json +++ b/web/src/i18n/locales/zh-CN.json @@ -298,7 +298,6 @@ "价格重新计算中...": "价格重新计算中...", "价格预估": "价格预估", "任务 ID": "任务 ID", - "任务ID": "任务ID", "任务日志": "任务日志", "任务状态": "任务状态", "任务记录": "任务记录", @@ -539,7 +538,6 @@ "创建": "创建", "创建令牌默认选择auto分组,初始令牌也将设为auto(否则留空,为用户默认分组)": "创建令牌默认选择auto分组,初始令牌也将设为auto(否则留空,为用户默认分组)", "创建失败": "创建失败", - "创建成功": "创建成功", "创建或选择密钥时,将 Project 设置为 io.cloud": "创建或选择密钥时,将 Project 设置为 io.cloud", "创建新用户账户": "创建新用户账户", "创建新的令牌": "创建新的令牌", @@ -782,7 +780,6 @@ "天": "天", "天前": "天前", "失败": "失败", - "失败原因": "失败原因", "失败时自动禁用通道": "失败时自动禁用通道", "失败重试次数": "失败重试次数", "奖励说明": "奖励说明", @@ -1326,7 +1323,6 @@ "更新失败,请检查输入信息": "更新失败,请检查输入信息", "更新容器配置": "更新容器配置", "更新容器配置可能会导致容器重启,请确保在合适的时间进行此操作。": "更新容器配置可能会导致容器重启,请确保在合适的时间进行此操作。", - "更新成功": "更新成功", "更新所有已启用通道余额": "更新所有已启用通道余额", "更新支付设置": "更新支付设置", "更新时间": "更新时间", @@ -1628,10 +1624,14 @@ "点击查看差异": "点击查看差异", "点击此处": "点击此处", "点击预览视频": "点击预览视频", + "点击预览音乐": "点击预览音乐", + "音乐预览": "音乐预览", + "音频无法播放": "音频无法播放", "点击验证按钮,使用您的生物特征或安全密钥": "点击验证按钮,使用您的生物特征或安全密钥", "版权所有": "版权所有", "状态": "状态", "状态码复写": "状态码复写", + "状态码复写包含无效的状态码": "状态码复写包含无效的状态码", "状态筛选": "状态筛选", "状态页面Slug": "状态页面Slug", "环境变量": "环境变量", @@ -1754,7 +1754,6 @@ "确认清除历史日志": "确认清除历史日志", "确认禁用": "确认禁用", "确认补单": "确认补单", - "确认解绑": "确认解绑", "确认解绑 Passkey": "确认解绑 Passkey", "确认设置并完成初始化": "确认设置并完成初始化", "确认重置 Passkey": "确认重置 Passkey", @@ -1932,7 +1931,6 @@ "自动分组auto,从第一个开始选择": "自动分组auto,从第一个开始选择", "自动刷新": "自动刷新", "自动刷新中": "自动刷新中", - "自动检测": "自动检测", "自动模式": "自动模式", "自动测试所有通道间隔时间": "自动测试所有通道间隔时间", "自动禁用": "自动禁用", @@ -1942,6 +1940,19 @@ "自动重试状态码": "自动重试状态码", "自动重试状态码格式不正确": "自动重试状态码格式不正确", "支持填写单个状态码或范围(含首尾),使用逗号分隔": "支持填写单个状态码或范围(含首尾),使用逗号分隔", + "支持填写单个状态码或范围(含首尾),使用逗号分隔;504 和 524 始终不重试,不受此处配置影响": "支持填写单个状态码或范围(含首尾),使用逗号分隔;504 和 524 始终不重试,不受此处配置影响", + "高危操作确认": "高危操作确认", + "检测到以下高危状态码重定向规则": "检测到以下高危状态码重定向规则", + "操作确认": "操作确认", + "我确认开启高危重试": "我确认开启高危重试", + "高危状态码重试风险告知与免责声明Markdown": "### ⚠️ 高危操作:504/524 状态码重试风险告知与免责声明\n本项目默认对 `400 (请求错误)`、`504 (网关超时)`和 `524 (cdn发生超时)`状态码不进行重试。\n504 和 524 错误通常意味着**请求已成功送达上游 AI 服务,且上游正在处理,但因处理时间过长导致连接断开**。\n\n开启对此类超时状态码的重定向/重试属于**极高风险操作**。作为本开源项目的使用者,在开启该功能前,您必须仔细阅读并知悉以下严重后果:\n\n#### 一、 核心风险告知(请仔细阅读)\n1. 💸 双重/多重计费风险: 绝大多数 AI 上游厂商对于已经开始处理但因网络原因中断(504/524)的请求**依然会进行扣费**。此时若触发重试,将会向上游发起全新请求,导致您被**双重甚至多重计费**。\n2. ⏳ 客户端严重超时: 单次请求已经触发超时,叠加重试机制将会使总请求耗时成倍增加,导致您的最终客户端(或调用方)出现严重甚至完全无法接受的超时现象。\n3. 💥 请求积压与系统崩溃风险: 强制重试超时请求会长时间占用系统线程和连接数。在高并发场景下,这会导致严重的**请求积压**,进而耗尽系统资源,引发雪崩效应,导致您的整个代理服务崩溃。\n\n#### 二、 风险确认声明\n如果您坚持开启该功能,即代表您作出以下确认:", + "高危状态码重试风险确认输入文本": "我已了解多重计费与崩溃风险,确认开启", + "高危状态码重试风险确认项1": "我已充分阅读并理解:本人已完整阅读上述全部风险提示,完全理解强制重试 504 和 524 状态码可能带来的破坏性后果。", + "高危状态码重试风险确认项2": "我已与上游沟通并确认:本人确认,当前出现的超时问题属于上游服务的瓶颈。本人已与上游提供商进行过沟通,确认上游无法解决该超时问题,因此才采取强制重试方案作为妥协手段。", + "高危状态码重试风险确认项3": "我自愿承担计费损失:本人知晓并接受由此产生的全部双重/多重计费风险,承诺不会因重试导致的账单异常在本项目仓库中提交 Issue 或抱怨。", + "高危状态码重试风险确认项4": "我自愿承担系统稳定性风险:本人知晓该操作可能导致客户端严重超时及服务崩溃。若因本人开启此功能导致请求积压或服务不可用,后果由本人自行承担。", + "高危状态码重试风险输入框占位文案": "请完整输入上方文字", + "高危状态码重试风险输入不匹配提示": "输入内容与要求不一致", "例如:401, 403, 429, 500-599": "例如:401,403,429,500-599", "自动选择": "自动选择", "自定义充值数量选项": "自定义充值数量选项", @@ -2531,6 +2542,11 @@ "销毁容器": "销毁容器", "销毁容器失败": "销毁容器失败", "错误": "错误", + "退款": "退款", + "错误详情": "错误详情", + "异步任务退款": "异步任务退款", + "任务ID": "任务ID", + "失败原因": "失败原因", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1", "键为原状态码,值为要复写的状态码,仅影响本地判断": "键为原状态码,值为要复写的状态码,仅影响本地判断", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限", @@ -2796,6 +2812,7 @@ "缓存读": "缓存读", "缓存写": "缓存写", "写": "写", - "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。": "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。" + "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。": "根据 Anthropic 协定,/v1/messages 的输入 tokens 仅统计非缓存输入,不包含缓存读取与缓存写入 tokens。", + "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加": "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加" } } diff --git a/web/src/i18n/locales/zh-TW.json b/web/src/i18n/locales/zh-TW.json index 562a7d543..85be3f9f7 100644 --- a/web/src/i18n/locales/zh-TW.json +++ b/web/src/i18n/locales/zh-TW.json @@ -1628,10 +1628,14 @@ "点击查看差异": "點擊查看差異", "点击此处": "點擊此處", "点击预览视频": "點擊預覽影片", + "点击预览音乐": "點擊預覽音樂", + "音乐预览": "音樂預覽", + "音频无法播放": "音訊無法播放", "点击验证按钮,使用您的生物特征或安全密钥": "點擊驗證按鈕,使用您的生物特徵或安全密鑰", "版权所有": "版權所有", "状态": "狀態", "状态码复写": "狀態碼複寫", + "状态码复写包含无效的状态码": "狀態碼複寫包含無效的狀態碼", "状态筛选": "狀態篩選", "状态页面Slug": "狀態頁面Slug", "环境变量": "環境變數", @@ -1942,6 +1946,19 @@ "自动重试状态码": "自動重試狀態碼", "自动重试状态码格式不正确": "自動重試狀態碼格式不正確", "支持填写单个状态码或范围(含首尾),使用逗号分隔": "支援填寫單個狀態碼或範圍(含首尾),使用逗號分隔", + "支持填写单个状态码或范围(含首尾),使用逗号分隔;504 和 524 始终不重试,不受此处配置影响": "支援填寫單個狀態碼或範圍(含首尾),使用逗號分隔;504 和 524 一律不重試,不受此處設定影響", + "高危操作确认": "高風險操作確認", + "检测到以下高危状态码重定向规则": "檢測到以下高風險狀態碼重定向規則", + "操作确认": "操作確認", + "我确认开启高危重试": "我確認開啟高風險重試", + "高危状态码重试风险告知与免责声明Markdown": "### ⚠️ 高風險操作:504/524 狀態碼重試風險告知與免責聲明\n\n【背景提示】\n本專案預設對 `400`(請求錯誤)、`504`(閘道逾時)與 `524`(發生逾時)狀態碼不進行重試。504 與 524 錯誤通常代表**請求已成功送達上游 AI 服務,且上游正在處理,但因處理時間過長導致連線中斷**。\n\n開啟此類逾時狀態碼的重定向/重試屬於**極高風險操作**。作為本開源專案使用者,在開啟該功能前,您必須仔細閱讀並知悉以下嚴重後果:\n\n#### 一、 核心風險告知(請仔細閱讀)\n1. 💸 雙重/多重計費風險:多數 AI 上游廠商對於已開始處理但因網路原因中斷(504/524)的請求**仍然會扣費**。此時若觸發重試,將會向上游發起全新請求,導致您被**雙重甚至多重計費**。\n2. ⏳ 用戶端嚴重逾時:單次請求已觸發逾時,疊加重試機制會使總請求耗時成倍增加,導致最終用戶端(或呼叫方)出現嚴重甚至無法接受的逾時現象。\n3. 💥 請求積壓與系統崩潰風險:強制重試逾時請求會長時間占用系統執行緒與連線數。在高併發場景下,這將導致嚴重**請求積壓**,進而耗盡系統資源,引發雪崩效應,造成整個代理服務崩潰。\n\n#### 二、 風險確認聲明\n若您堅持開啟該功能,即代表您作出以下確認:", + "高危状态码重试风险确认输入文本": "我已了解多重計費與崩潰風險,確認開啟", + "高危状态码重试风险确认项1": "我已充分閱讀並理解:本人已完整閱讀上述全部風險提示,完全理解強制重試 504 與 524 狀態碼可能帶來的破壞性後果。", + "高危状态码重试风险确认项2": "我已與上游溝通並確認:本人確認,當前逾時問題屬於上游服務瓶頸。本人已與上游供應商溝通,確認上游無法解決該逾時問題,因此才採取強制重試方案作為妥協手段。", + "高危状态码重试风险确认项3": "我自願承擔計費損失:本人知悉並接受由此產生的全部雙重/多重計費風險,承諾不會因重試導致的帳單異常在本專案倉庫提交 Issue 或抱怨。", + "高危状态码重试风险确认项4": "我自願承擔系統穩定性風險:本人知悉該操作可能導致用戶端嚴重逾時及服務崩潰。若因本人開啟此功能導致請求積壓或服務不可用,後果由本人自行承擔。", + "高危状态码重试风险输入框占位文案": "請完整輸入上方文字", + "高危状态码重试风险输入不匹配提示": "輸入內容與要求不一致", "例如:401, 403, 429, 500-599": "例如:401,403,429,500-599", "自动选择": "自動選擇", "自定义充值数量选项": "自訂儲值數量選項", @@ -2788,6 +2805,7 @@ "填写服务器地址后自动生成:": "填寫伺服器位址後自動生成:", "自动生成:": "自動生成:", "请先填写服务器地址,以自动生成完整的端点 URL": "請先填寫伺服器位址,以自動生成完整的端點 URL", - "端点 URL 必须是完整地址(以 http:// 或 https:// 开头)": "端點 URL 必須是完整位址(以 http:// 或 https:// 開頭)" + "端点 URL 必须是完整地址(以 http:// 或 https:// 开头)": "端點 URL 必須是完整位址(以 http:// 或 https:// 開頭)", + "未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加": "未匹配到模型,按下 Enter 鍵可將「{{name}}」作為自訂模型名稱新增" } } diff --git a/web/src/pages/Setting/Operation/SettingsMonitoring.jsx b/web/src/pages/Setting/Operation/SettingsMonitoring.jsx index 29b55e56c..e4ee116f2 100644 --- a/web/src/pages/Setting/Operation/SettingsMonitoring.jsx +++ b/web/src/pages/Setting/Operation/SettingsMonitoring.jsx @@ -254,7 +254,7 @@ export default function SettingsMonitoring(props) { label={t('自动重试状态码')} placeholder={t('例如:401, 403, 429, 500-599')} extraText={t( - '支持填写单个状态码或范围(含首尾),使用逗号分隔', + '支持填写单个状态码或范围(含首尾),使用逗号分隔;504 和 524 始终不重试,不受此处配置影响', )} field={'AutomaticRetryStatusCodes'} onChange={(value) =>