From ec5c6b28eafb165d402b897f3fd252e0ffe98028 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sun, 22 Feb 2026 15:32:33 +0800 Subject: [PATCH] feat(task): add model redirection, per-call billing, and multipart retry fix for async tasks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Async task model redirection (aligned with sync tasks): - Integrate ModelMappedHelper in RelayTaskSubmit after model name determination, populating OriginModelName / UpstreamModelName on RelayInfo. - All task adaptors now send UpstreamModelName to upstream providers: - Gemini & Vertex: BuildRequestURL uses UpstreamModelName. - Doubao & Ali: BuildRequestBody conditionally overwrites body.Model. - Vidu, Kling, Hailuo, Jimeng: convertToRequestPayload accepts RelayInfo and unconditionally uses info.UpstreamModelName. - Sora: BuildRequestBody parses JSON and multipart bodies to replace the "model" field with UpstreamModelName. - Frontend log visibility: LogTaskConsumption and taskBillingOther now emit is_model_mapped / upstream_model_name in the "other" JSON field. - Billing safety: RecalculateTaskQuotaByTokens reads model name from BillingContext.OriginModelName (via taskModelName) instead of task.Data["model"], preventing billing leaks from upstream model names. 2. Per-call billing (TaskPricePatches lifecycle): - Rename TaskBillingContext.ModelName → OriginModelName; add PerCallBilling bool field, populated from TaskPricePatches at submission time. - settleTaskBillingOnComplete short-circuits when PerCallBilling is true, skipping both adaptor adjustments and token-based recalculation. - Remove ModelName from TaskSubmitResult; use relayInfo.OriginModelName consistently in controller/relay.go for billing context and logging. 3. Multipart retry boundary mismatch fix: - Root cause: after Sora (or OpenAI audio) rebuilds a multipart body with a new boundary and overwrites c.Request.Header["Content-Type"], subsequent calls to ParseMultipartFormReusable on retry would parse the cached original body with the wrong boundary, causing "NextPart: EOF". - Fix: ParseMultipartFormReusable now caches the original Content-Type in gin context key "_original_multipart_ct" on first call and reuses it for all subsequent parses, making multipart parsing retry-safe globally. - Sora adaptor reverted to the standard pattern (direct header set/get), which is now safe thanks to the root fix. 4. Tests: - task_billing_test.go: update makeTask to use OriginModelName; add PerCallBilling settlement tests (skip adaptor adjust, skip token recalc); add non-per-call adaptor adjustment test with refund verification. --- common/gin.go | 10 +- controller/relay.go | 17 +-- controller/task.go | 26 ++++- model/task.go | 11 +- relay/channel/task/ali/adaptor.go | 8 +- relay/channel/task/doubao/adaptor.go | 6 +- relay/channel/task/gemini/adaptor.go | 2 +- relay/channel/task/hailuo/adaptor.go | 8 +- relay/channel/task/jimeng/adaptor.go | 6 +- relay/channel/task/kling/adaptor.go | 9 +- relay/channel/task/sora/adaptor.go | 55 +++++++++ relay/channel/task/vertex/adaptor.go | 2 +- relay/channel/task/vidu/adaptor.go | 6 +- relay/relay_task.go | 9 +- service/task_billing.go | 29 ++--- service/task_billing_test.go | 108 +++++++++++++++++- service/task_polling.go | 5 + .../table/task-logs/TaskLogsColumnDefs.jsx | 36 +++--- web/src/components/table/task-logs/index.jsx | 2 - 19 files changed, 277 insertions(+), 78 deletions(-) diff --git a/common/gin.go b/common/gin.go index 48971c130..009e39080 100644 --- a/common/gin.go +++ b/common/gin.go @@ -243,7 +243,15 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) { return nil, err } - contentType := c.Request.Header.Get("Content-Type") + // Use the original Content-Type saved on first call to avoid boundary + // mismatch when callers overwrite c.Request.Header after multipart rebuild. + var contentType string + if saved, ok := c.Get("_original_multipart_ct"); ok { + contentType = saved.(string) + } else { + contentType = c.Request.Header.Get("Content-Type") + c.Set("_original_multipart_ct", contentType) + } boundary, err := parseBoundary(contentType) if err != nil { return nil, err diff --git a/controller/relay.go b/controller/relay.go index 6951974c5..7e7922e75 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -518,7 +518,7 @@ func RelayTask(c *gin.Context) { } addUsedChannel(c, channel.Id) - requestBody, bodyErr := common.GetRequestBody(c) + bodyStorage, bodyErr := common.GetBodyStorage(c) if bodyErr != nil { if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) { taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge) @@ -527,7 +527,7 @@ func RelayTask(c *gin.Context) { } break } - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + c.Request.Body = io.NopCloser(bodyStorage) result, taskErr = relay.RelayTaskSubmit(c, relayInfo) if taskErr == nil { @@ -557,7 +557,7 @@ func RelayTask(c *gin.Context) { if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil { common.SysError("settle task billing error: " + settleErr.Error()) } - service.LogTaskConsumption(c, relayInfo, result.ModelName) + service.LogTaskConsumption(c, relayInfo) task := model.InitTask(result.Platform, relayInfo) task.PrivateData.UpstreamTaskID = result.UpstreamTaskID @@ -565,11 +565,12 @@ func RelayTask(c *gin.Context) { task.PrivateData.SubscriptionId = relayInfo.SubscriptionId task.PrivateData.TokenId = relayInfo.TokenId task.PrivateData.BillingContext = &model.TaskBillingContext{ - ModelPrice: relayInfo.PriceData.ModelPrice, - GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio, - ModelRatio: relayInfo.PriceData.ModelRatio, - OtherRatios: relayInfo.PriceData.OtherRatios, - ModelName: result.ModelName, + ModelPrice: relayInfo.PriceData.ModelPrice, + GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio, + ModelRatio: relayInfo.PriceData.ModelRatio, + OtherRatios: relayInfo.PriceData.OtherRatios, + OriginModelName: relayInfo.OriginModelName, + PerCallBilling: common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName), } task.Quota = result.Quota task.Data = result.TaskData diff --git a/controller/task.go b/controller/task.go index ec713c5d2..eac7db153 100644 --- a/controller/task.go +++ b/controller/task.go @@ -9,6 +9,7 @@ import ( "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay" "github.com/QuantumNous/new-api/service" + "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" ) @@ -37,7 +38,7 @@ func GetAllTask(c *gin.Context) { items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams) total := model.TaskCountAllTasks(queryParams) pageInfo.SetTotal(int(total)) - pageInfo.SetItems(tasksToDto(items)) + pageInfo.SetItems(tasksToDto(items, true)) common.ApiSuccess(c, pageInfo) } @@ -61,13 +62,32 @@ func GetUserTask(c *gin.Context) { items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams) total := model.TaskCountAllUserTask(userId, queryParams) pageInfo.SetTotal(int(total)) - pageInfo.SetItems(tasksToDto(items)) + pageInfo.SetItems(tasksToDto(items, false)) common.ApiSuccess(c, pageInfo) } -func tasksToDto(tasks []*model.Task) []*dto.TaskDto { +func tasksToDto(tasks []*model.Task, fillUser bool) []*dto.TaskDto { + var userIdMap map[int]*model.UserBase + if fillUser { + userIdMap = make(map[int]*model.UserBase) + userIds := types.NewSet[int]() + for _, task := range tasks { + userIds.Add(task.UserId) + } + for _, userId := range userIds.Items() { + cacheUser, err := model.GetUserCache(userId) + if err == nil { + userIdMap[userId] = cacheUser + } + } + } result := make([]*dto.TaskDto, len(tasks)) for i, task := range tasks { + if fillUser { + if user, ok := userIdMap[task.UserId]; ok { + task.Username = user.Username + } + } result[i] = relay.TaskModel2Dto(task) } return result diff --git a/model/task.go b/model/task.go index 0cf6bd47e..da3be34ed 100644 --- a/model/task.go +++ b/model/task.go @@ -109,11 +109,12 @@ type TaskPrivateData struct { // TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。 type TaskBillingContext struct { - ModelPrice float64 `json:"model_price,omitempty"` // 模型单价 - GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率 - ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率 - OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等) - ModelName string `json:"model_name,omitempty"` // 模型名称 + ModelPrice float64 `json:"model_price,omitempty"` // 模型单价 + GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率 + ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率 + OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等) + OriginModelName string `json:"origin_model_name,omitempty"` // 模型名称,必须为OriginModelName + PerCallBilling bool `json:"per_call_billing,omitempty"` // 按次计费:跳过轮询阶段的差额结算 } // GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信) diff --git a/relay/channel/task/ali/adaptor.go b/relay/channel/task/ali/adaptor.go index f55178b3b..f698fc9f6 100644 --- a/relay/channel/task/ali/adaptor.go +++ b/relay/channel/task/ali/adaptor.go @@ -253,8 +253,12 @@ func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error) } func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) { + upstreamModel := req.Model + if info.IsModelMapped { + upstreamModel = info.UpstreamModelName + } aliReq := &AliVideoRequest{ - Model: req.Model, + Model: upstreamModel, Input: AliVideoInput{ Prompt: req.Prompt, ImgURL: req.InputReference, @@ -332,7 +336,7 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay } } - if aliReq.Model != req.Model { + if aliReq.Model != upstreamModel { return nil, errors.New("can't change model with metadata") } diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go index eca421bd3..8f1d748ce 100644 --- a/relay/channel/task/doubao/adaptor.go +++ b/relay/channel/task/doubao/adaptor.go @@ -131,7 +131,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } - info.UpstreamModelName = body.Model + if info.IsModelMapped { + body.Model = info.UpstreamModelName + } else { + info.UpstreamModelName = body.Model + } data, err := common.Marshal(body) if err != nil { return nil, err diff --git a/relay/channel/task/gemini/adaptor.go b/relay/channel/task/gemini/adaptor.go index 06c00a469..5644cd5dc 100644 --- a/relay/channel/task/gemini/adaptor.go +++ b/relay/channel/task/gemini/adaptor.go @@ -105,7 +105,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { - modelName := info.OriginModelName + modelName := info.UpstreamModelName version := model_setting.GetGeminiVersionSetting(modelName) return fmt.Sprintf( diff --git a/relay/channel/task/hailuo/adaptor.go b/relay/channel/task/hailuo/adaptor.go index ab83d659b..28b3a97f1 100644 --- a/relay/channel/task/hailuo/adaptor.go +++ b/relay/channel/task/hailuo/adaptor.go @@ -61,7 +61,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, fmt.Errorf("invalid request type in context") } - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } @@ -142,8 +142,8 @@ func (a *TaskAdaptor) GetChannelName() string { return ChannelName } -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) { - modelConfig := GetModelConfig(req.Model) +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*VideoRequest, error) { + modelConfig := GetModelConfig(info.UpstreamModelName) duration := DefaultDuration if req.Duration > 0 { duration = req.Duration @@ -154,7 +154,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* } videoRequest := &VideoRequest{ - Model: req.Model, + Model: info.UpstreamModelName, Prompt: req.Prompt, Duration: &duration, Resolution: resolution, diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index b61cca418..e6211b1e4 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -165,7 +165,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } } - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } @@ -378,9 +378,9 @@ func hmacSHA256(key []byte, data []byte) []byte { return h.Sum(nil) } -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) { r := requestPayload{ - ReqKey: req.Model, + ReqKey: info.UpstreamModelName, Prompt: req.Prompt, } diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 46e210f19..cdbb56878 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -150,7 +150,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } req := v.(relaycommon.TaskSubmitReq) - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, err } @@ -248,15 +248,15 @@ func (a *TaskAdaptor) GetChannelName() string { // helpers // ============================ -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) { r := requestPayload{ Prompt: req.Prompt, Image: req.Image, Mode: taskcommon.DefaultString(req.Mode, "std"), Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)), AspectRatio: a.getAspectRatio(req.Size), - ModelName: req.Model, - Model: req.Model, // Keep consistent with model_name, double writing improves compatibility + ModelName: info.UpstreamModelName, + Model: info.UpstreamModelName, CfgScale: 0.5, StaticMask: "", DynamicMasks: []DynamicMask{}, @@ -266,6 +266,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* } if r.ModelName == "" { r.ModelName = "kling-v1" + r.Model = "kling-v1" } if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index bf2f70053..33db8fe55 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -1,8 +1,10 @@ package sora import ( + "bytes" "fmt" "io" + "mime/multipart" "net/http" "strconv" "strings" @@ -145,6 +147,59 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if err != nil { return nil, errors.Wrap(err, "get_request_body_failed") } + cachedBody, err := storage.Bytes() + if err != nil { + return nil, errors.Wrap(err, "read_body_bytes_failed") + } + contentType := c.GetHeader("Content-Type") + + if strings.HasPrefix(contentType, "application/json") { + var bodyMap map[string]interface{} + if err := common.Unmarshal(cachedBody, &bodyMap); err == nil { + bodyMap["model"] = info.UpstreamModelName + if newBody, err := common.Marshal(bodyMap); err == nil { + return bytes.NewReader(newBody), nil + } + } + return bytes.NewReader(cachedBody), nil + } + + if strings.Contains(contentType, "multipart/form-data") { + formData, err := common.ParseMultipartFormReusable(c) + if err != nil { + return bytes.NewReader(cachedBody), nil + } + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + writer.WriteField("model", info.UpstreamModelName) + for key, values := range formData.Value { + if key == "model" { + continue + } + for _, v := range values { + writer.WriteField(key, v) + } + } + for fieldName, fileHeaders := range formData.File { + for _, fh := range fileHeaders { + f, err := fh.Open() + if err != nil { + continue + } + part, err := writer.CreateFormFile(fieldName, fh.Filename) + if err != nil { + f.Close() + continue + } + io.Copy(part, f) + f.Close() + } + } + writer.Close() + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + return &buf, nil + } + return common.ReaderOnly(storage), nil } diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index 4931002dd..700e60976 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -86,7 +86,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil { return "", fmt.Errorf("failed to decode credentials: %w", err) } - modelName := info.OriginModelName + modelName := info.UpstreamModelName if modelName == "" { modelName = "veo-3.0-generate-001" } diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index e689bf888..6ae1c181b 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -116,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } req := v.(relaycommon.TaskSubmitReq) - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, err } @@ -224,9 +224,9 @@ func (a *TaskAdaptor) GetChannelName() string { // helpers // ============================ -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) { r := requestPayload{ - Model: taskcommon.DefaultString(req.Model, "viduq1"), + Model: taskcommon.DefaultString(info.UpstreamModelName, "viduq1"), Images: req.Images, Prompt: req.Prompt, Duration: taskcommon.DefaultInt(req.Duration, 5), diff --git a/relay/relay_task.go b/relay/relay_task.go index cd43e6ebb..c740facdb 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -26,7 +26,6 @@ type TaskSubmitResult struct { UpstreamTaskID string TaskData []byte Platform constant.TaskPlatform - ModelName string Quota int //PerCallPrice types.PriceData } @@ -163,6 +162,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe modelName = service.CoverTaskActionToModelName(platform, info.Action) } + // 2.5 应用渠道的模型映射(与同步任务对齐) + info.OriginModelName = modelName + info.UpstreamModelName = modelName + if err := helper.ModelMappedHelper(c, info, nil); err != nil { + return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest) + } + // 3. 预生成公开 task ID(仅首次) if info.PublicTaskID == "" { info.PublicTaskID = model.GenerateTaskID() @@ -241,7 +247,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe UpstreamTaskID: upstreamTaskID, TaskData: taskData, Platform: platform, - ModelName: modelName, Quota: finalQuota, }, nil } diff --git a/service/task_billing.go b/service/task_billing.go index 78ad0fc09..0da4cf431 100644 --- a/service/task_billing.go +++ b/service/task_billing.go @@ -16,11 +16,11 @@ import ( // LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。 // 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。 -func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) { +func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("操作 %s", info.Action) // 支持任务仅按次计费 - if common.StringsContains(constant.TaskPricePatches, modelName) { + if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) { logContent = fmt.Sprintf("%s,按次计费", logContent) } else { if len(info.PriceData.OtherRatios) > 0 { @@ -42,9 +42,13 @@ func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName s if info.PriceData.GroupRatioInfo.HasSpecialRatio { other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio } + if info.IsModelMapped { + other["is_model_mapped"] = true + other["upstream_model_name"] = info.UpstreamModelName + } model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ ChannelId: info.ChannelId, - ModelName: modelName, + ModelName: info.OriginModelName, TokenName: tokenName, Quota: info.PriceData.Quota, Content: logContent, @@ -120,13 +124,18 @@ func taskBillingOther(task *model.Task) map[string]interface{} { } } } + props := task.Properties + if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName { + other["is_model_mapped"] = true + other["upstream_model_name"] = props.UpstreamModelName + } return other } // taskModelName 从 BillingContext 或 Properties 中获取模型名称。 func taskModelName(task *model.Task) string { - if bc := task.PrivateData.BillingContext; bc != nil && bc.ModelName != "" { - return bc.ModelName + if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" { + return bc.OriginModelName } return task.Properties.OriginModelName } @@ -237,15 +246,7 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo return } - // 获取模型名称 - var taskData map[string]interface{} - if err := common.Unmarshal(task.Data, &taskData); err != nil { - return - } - modelName, ok := taskData["model"].(string) - if !ok || modelName == "" { - return - } + modelName := taskModelName(task) // 获取模型价格和倍率 modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName) diff --git a/service/task_billing_test.go b/service/task_billing_test.go index 6c2d231d5..1145bba54 100644 --- a/service/task_billing_test.go +++ b/service/task_billing_test.go @@ -3,12 +3,14 @@ package service import ( "context" "encoding/json" + "net/http" "os" "testing" "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" + relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/glebarez/sqlite" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -125,7 +127,7 @@ func makeTask(userId, channelId, quota, tokenId int, billingSource string, subsc BillingContext: &model.TaskBillingContext{ ModelPrice: 0.02, GroupRatio: 1.0, - ModelName: "test-model", + OriginModelName: "test-model", }, }, } @@ -604,3 +606,107 @@ func TestNonTerminalUpdate_NoBilling(t *testing.T) { require.NoError(t, model.DB.First(&reloaded, task.ID).Error) assert.Equal(t, "50%", reloaded.Progress) } + +// =========================================================================== +// Mock adaptor for settleTaskBillingOnComplete tests +// =========================================================================== + +type mockAdaptor struct { + adjustReturn int +} + +func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {} +func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { return nil, nil } +func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil } +func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int { + return m.adjustReturn +} + +// =========================================================================== +// PerCallBilling tests — settleTaskBillingOnComplete +// =========================================================================== + +func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 30, 30, 30 + const initQuota, preConsumed = 10000, 5000 + const tokenRemain = 8000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.PrivateData.BillingContext.PerCallBilling = true + + adaptor := &mockAdaptor{adjustReturn: 2000} + taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess} + + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + + // Per-call: no adjustment despite adaptor returning 2000 + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) + assert.Equal(t, preConsumed, task.Quota) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 31, 31, 31 + const initQuota, preConsumed = 10000, 4000 + const tokenRemain = 7000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.PrivateData.BillingContext.PerCallBilling = true + + adaptor := &mockAdaptor{adjustReturn: 0} + taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999} + + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + + // Per-call: no recalculation by tokens + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) + assert.Equal(t, preConsumed, task.Quota) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 32, 32, 32 + const initQuota, preConsumed = 10000, 5000 + const adaptorQuota = 3000 + const tokenRemain = 8000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + // PerCallBilling defaults to false + + adaptor := &mockAdaptor{adjustReturn: adaptorQuota} + taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess} + + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + + // Non-per-call: adaptor adjustment applies (refund 2000) + assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID)) + assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID)) + assert.Equal(t, adaptorQuota, task.Quota) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} diff --git a/service/task_polling.go b/service/task_polling.go index 7e92d14ba..a03fc9b88 100644 --- a/service/task_polling.go +++ b/service/task_polling.go @@ -467,6 +467,11 @@ func truncateBase64(s string) string { // 2. taskResult.TotalTokens > 0 → 按 token 重算 // 3. 都不满足 → 保持预扣额度不变 func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) { + // 0. 按次计费的任务不做差额结算 + if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling { + logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID)) + return + } // 1. 优先让 adaptor 决定最终额度 if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 { RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整") diff --git a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx index 4bce45256..7fddb0a50 100644 --- a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx +++ b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx @@ -84,8 +84,8 @@ function renderDuration(submit_time, finishTime) { // 返回带有样式的颜色标签 return ( - }> - {durationSec} 秒 + + {durationSec} s ); } @@ -149,7 +149,7 @@ const renderPlatform = (platform, t) => { ); if (option) { return ( - }> + {option.label} ); @@ -157,13 +157,13 @@ const renderPlatform = (platform, t) => { switch (platform) { case 'suno': return ( - }> + Suno ); default: return ( - }> + {t('未知')} ); @@ -240,7 +240,6 @@ export const getTaskLogsColumns = ({ openContentModal, isAdminUser, openVideoModal, - showUserInfoFunc, }) => { return [ { @@ -278,7 +277,6 @@ export const getTaskLogsColumns = ({ color={colors[parseInt(text) % colors.length]} size='large' shape='circle' - prefixIcon={} onClick={() => { copyText(text); }} @@ -294,7 +292,7 @@ export const getTaskLogsColumns = ({ { key: COLUMN_KEYS.USERNAME, title: t('用户'), - dataIndex: 'user_id', + dataIndex: 'username', render: (userId, record, index) => { if (!isAdminUser) { return <>; @@ -302,22 +300,14 @@ export const getTaskLogsColumns = ({ const displayText = String(record.username || userId || '?'); return ( - - showUserInfoFunc && showUserInfoFunc(userId)} - > - {displayText.slice(0, 1)} - - - showUserInfoFunc && showUserInfoFunc(userId)} + - {userId} + {displayText.slice(0, 1)} + + + {displayText} ); diff --git a/web/src/components/table/task-logs/index.jsx b/web/src/components/table/task-logs/index.jsx index 140725a89..bc5b91787 100644 --- a/web/src/components/table/task-logs/index.jsx +++ b/web/src/components/table/task-logs/index.jsx @@ -25,7 +25,6 @@ import TaskLogsActions from './TaskLogsActions'; import TaskLogsFilters from './TaskLogsFilters'; import ColumnSelectorModal from './modals/ColumnSelectorModal'; import ContentModal from './modals/ContentModal'; -import UserInfoModal from '../usage-logs/modals/UserInfoModal'; import { useTaskLogsData } from '../../../hooks/task-logs/useTaskLogsData'; import { useIsMobile } from '../../../hooks/common/useIsMobile'; import { createCardProPagination } from '../../../helpers/utils'; @@ -46,7 +45,6 @@ const TaskLogsPage = () => { modalContent={taskLogsData.videoUrl} isVideo={true} /> -