feat(task): add model redirection, per-call billing, and multipart retry fix for async tasks

1. Async task model redirection (aligned with sync tasks):
   - Integrate ModelMappedHelper in RelayTaskSubmit after model name
     determination, populating OriginModelName / UpstreamModelName on RelayInfo.
   - All task adaptors now send UpstreamModelName to upstream providers:
     - Gemini & Vertex: BuildRequestURL uses UpstreamModelName.
     - Doubao & Ali: BuildRequestBody conditionally overwrites body.Model.
     - Vidu, Kling, Hailuo, Jimeng: convertToRequestPayload accepts RelayInfo
       and unconditionally uses info.UpstreamModelName.
     - Sora: BuildRequestBody parses JSON and multipart bodies to replace
       the "model" field with UpstreamModelName.
   - Frontend log visibility: LogTaskConsumption and taskBillingOther now
     emit is_model_mapped / upstream_model_name in the "other" JSON field.
   - Billing safety: RecalculateTaskQuotaByTokens reads model name from
     BillingContext.OriginModelName (via taskModelName) instead of
     task.Data["model"], preventing billing leaks from upstream model names.

2. Per-call billing (TaskPricePatches lifecycle):
   - Rename TaskBillingContext.ModelName → OriginModelName; add PerCallBilling
     bool field, populated from TaskPricePatches at submission time.
   - settleTaskBillingOnComplete short-circuits when PerCallBilling is true,
     skipping both adaptor adjustments and token-based recalculation.
   - Remove ModelName from TaskSubmitResult; use relayInfo.OriginModelName
     consistently in controller/relay.go for billing context and logging.

3. Multipart retry boundary mismatch fix:
   - Root cause: after Sora (or OpenAI audio) rebuilds a multipart body with a
     new boundary and overwrites c.Request.Header["Content-Type"], subsequent
     calls to ParseMultipartFormReusable on retry would parse the cached
     original body with the wrong boundary, causing "NextPart: EOF".
   - Fix: ParseMultipartFormReusable now caches the original Content-Type in
     gin context key "_original_multipart_ct" on first call and reuses it for
     all subsequent parses, making multipart parsing retry-safe globally.
   - Sora adaptor reverted to the standard pattern (direct header set/get),
     which is now safe thanks to the root fix.

4. Tests:
   - task_billing_test.go: update makeTask to use OriginModelName; add
     PerCallBilling settlement tests (skip adaptor adjust, skip token recalc);
     add non-per-call adaptor adjustment test with refund verification.
This commit is contained in:
CaIon
2026-02-22 15:32:33 +08:00
parent 9976b311ef
commit ec5c6b28ea
19 changed files with 277 additions and 78 deletions

View File

@@ -243,7 +243,15 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
return nil, err
}
contentType := c.Request.Header.Get("Content-Type")
// Use the original Content-Type saved on first call to avoid boundary
// mismatch when callers overwrite c.Request.Header after multipart rebuild.
var contentType string
if saved, ok := c.Get("_original_multipart_ct"); ok {
contentType = saved.(string)
} else {
contentType = c.Request.Header.Get("Content-Type")
c.Set("_original_multipart_ct", contentType)
}
boundary, err := parseBoundary(contentType)
if err != nil {
return nil, err

View File

@@ -518,7 +518,7 @@ func RelayTask(c *gin.Context) {
}
addUsedChannel(c, channel.Id)
requestBody, bodyErr := common.GetRequestBody(c)
bodyStorage, bodyErr := common.GetBodyStorage(c)
if bodyErr != nil {
if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge)
@@ -527,7 +527,7 @@ func RelayTask(c *gin.Context) {
}
break
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
c.Request.Body = io.NopCloser(bodyStorage)
result, taskErr = relay.RelayTaskSubmit(c, relayInfo)
if taskErr == nil {
@@ -557,7 +557,7 @@ func RelayTask(c *gin.Context) {
if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil {
common.SysError("settle task billing error: " + settleErr.Error())
}
service.LogTaskConsumption(c, relayInfo, result.ModelName)
service.LogTaskConsumption(c, relayInfo)
task := model.InitTask(result.Platform, relayInfo)
task.PrivateData.UpstreamTaskID = result.UpstreamTaskID
@@ -565,11 +565,12 @@ func RelayTask(c *gin.Context) {
task.PrivateData.SubscriptionId = relayInfo.SubscriptionId
task.PrivateData.TokenId = relayInfo.TokenId
task.PrivateData.BillingContext = &model.TaskBillingContext{
ModelPrice: relayInfo.PriceData.ModelPrice,
GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio,
ModelRatio: relayInfo.PriceData.ModelRatio,
OtherRatios: relayInfo.PriceData.OtherRatios,
ModelName: result.ModelName,
ModelPrice: relayInfo.PriceData.ModelPrice,
GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio,
ModelRatio: relayInfo.PriceData.ModelRatio,
OtherRatios: relayInfo.PriceData.OtherRatios,
OriginModelName: relayInfo.OriginModelName,
PerCallBilling: common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName),
}
task.Quota = result.Quota
task.Data = result.TaskData

View File

@@ -9,6 +9,7 @@ import (
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
@@ -37,7 +38,7 @@ func GetAllTask(c *gin.Context) {
items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
total := model.TaskCountAllTasks(queryParams)
pageInfo.SetTotal(int(total))
pageInfo.SetItems(tasksToDto(items))
pageInfo.SetItems(tasksToDto(items, true))
common.ApiSuccess(c, pageInfo)
}
@@ -61,13 +62,32 @@ func GetUserTask(c *gin.Context) {
items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
total := model.TaskCountAllUserTask(userId, queryParams)
pageInfo.SetTotal(int(total))
pageInfo.SetItems(tasksToDto(items))
pageInfo.SetItems(tasksToDto(items, false))
common.ApiSuccess(c, pageInfo)
}
func tasksToDto(tasks []*model.Task) []*dto.TaskDto {
func tasksToDto(tasks []*model.Task, fillUser bool) []*dto.TaskDto {
var userIdMap map[int]*model.UserBase
if fillUser {
userIdMap = make(map[int]*model.UserBase)
userIds := types.NewSet[int]()
for _, task := range tasks {
userIds.Add(task.UserId)
}
for _, userId := range userIds.Items() {
cacheUser, err := model.GetUserCache(userId)
if err == nil {
userIdMap[userId] = cacheUser
}
}
}
result := make([]*dto.TaskDto, len(tasks))
for i, task := range tasks {
if fillUser {
if user, ok := userIdMap[task.UserId]; ok {
task.Username = user.Username
}
}
result[i] = relay.TaskModel2Dto(task)
}
return result

View File

@@ -109,11 +109,12 @@ type TaskPrivateData struct {
// TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。
type TaskBillingContext struct {
ModelPrice float64 `json:"model_price,omitempty"` // 模型单价
GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率
ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率
OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等)
ModelName string `json:"model_name,omitempty"` // 模型名称
ModelPrice float64 `json:"model_price,omitempty"` // 模型单价
GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率
ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率
OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等)
OriginModelName string `json:"origin_model_name,omitempty"` // 模型名称必须为OriginModelName
PerCallBilling bool `json:"per_call_billing,omitempty"` // 按次计费:跳过轮询阶段的差额结算
}
// GetUpstreamTaskID 获取上游真实 task ID用于与 provider 通信)

View File

@@ -253,8 +253,12 @@ func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error)
}
func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
upstreamModel := req.Model
if info.IsModelMapped {
upstreamModel = info.UpstreamModelName
}
aliReq := &AliVideoRequest{
Model: req.Model,
Model: upstreamModel,
Input: AliVideoInput{
Prompt: req.Prompt,
ImgURL: req.InputReference,
@@ -332,7 +336,7 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
}
}
if aliReq.Model != req.Model {
if aliReq.Model != upstreamModel {
return nil, errors.New("can't change model with metadata")
}

View File

@@ -131,7 +131,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
info.UpstreamModelName = body.Model
if info.IsModelMapped {
body.Model = info.UpstreamModelName
} else {
info.UpstreamModelName = body.Model
}
data, err := common.Marshal(body)
if err != nil {
return nil, err

View File

@@ -105,7 +105,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
modelName := info.OriginModelName
modelName := info.UpstreamModelName
version := model_setting.GetGeminiVersionSetting(modelName)
return fmt.Sprintf(

View File

@@ -61,7 +61,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
return nil, fmt.Errorf("invalid request type in context")
}
body, err := a.convertToRequestPayload(&req)
body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
@@ -142,8 +142,8 @@ func (a *TaskAdaptor) GetChannelName() string {
return ChannelName
}
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) {
modelConfig := GetModelConfig(req.Model)
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*VideoRequest, error) {
modelConfig := GetModelConfig(info.UpstreamModelName)
duration := DefaultDuration
if req.Duration > 0 {
duration = req.Duration
@@ -154,7 +154,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
}
videoRequest := &VideoRequest{
Model: req.Model,
Model: info.UpstreamModelName,
Prompt: req.Prompt,
Duration: &duration,
Resolution: resolution,

View File

@@ -165,7 +165,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
}
body, err := a.convertToRequestPayload(&req)
body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
@@ -378,9 +378,9 @@ func hmacSHA256(key []byte, data []byte) []byte {
return h.Sum(nil)
}
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
r := requestPayload{
ReqKey: req.Model,
ReqKey: info.UpstreamModelName,
Prompt: req.Prompt,
}

View File

@@ -150,7 +150,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req)
body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, err
}
@@ -248,15 +248,15 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
r := requestPayload{
Prompt: req.Prompt,
Image: req.Image,
Mode: taskcommon.DefaultString(req.Mode, "std"),
Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)),
AspectRatio: a.getAspectRatio(req.Size),
ModelName: req.Model,
Model: req.Model, // Keep consistent with model_name, double writing improves compatibility
ModelName: info.UpstreamModelName,
Model: info.UpstreamModelName,
CfgScale: 0.5,
StaticMask: "",
DynamicMasks: []DynamicMask{},
@@ -266,6 +266,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
}
if r.ModelName == "" {
r.ModelName = "kling-v1"
r.Model = "kling-v1"
}
if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")

View File

@@ -1,8 +1,10 @@
package sora
import (
"bytes"
"fmt"
"io"
"mime/multipart"
"net/http"
"strconv"
"strings"
@@ -145,6 +147,59 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
if err != nil {
return nil, errors.Wrap(err, "get_request_body_failed")
}
cachedBody, err := storage.Bytes()
if err != nil {
return nil, errors.Wrap(err, "read_body_bytes_failed")
}
contentType := c.GetHeader("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
var bodyMap map[string]interface{}
if err := common.Unmarshal(cachedBody, &bodyMap); err == nil {
bodyMap["model"] = info.UpstreamModelName
if newBody, err := common.Marshal(bodyMap); err == nil {
return bytes.NewReader(newBody), nil
}
}
return bytes.NewReader(cachedBody), nil
}
if strings.Contains(contentType, "multipart/form-data") {
formData, err := common.ParseMultipartFormReusable(c)
if err != nil {
return bytes.NewReader(cachedBody), nil
}
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
writer.WriteField("model", info.UpstreamModelName)
for key, values := range formData.Value {
if key == "model" {
continue
}
for _, v := range values {
writer.WriteField(key, v)
}
}
for fieldName, fileHeaders := range formData.File {
for _, fh := range fileHeaders {
f, err := fh.Open()
if err != nil {
continue
}
part, err := writer.CreateFormFile(fieldName, fh.Filename)
if err != nil {
f.Close()
continue
}
io.Copy(part, f)
f.Close()
}
}
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return &buf, nil
}
return common.ReaderOnly(storage), nil
}

View File

@@ -86,7 +86,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
return "", fmt.Errorf("failed to decode credentials: %w", err)
}
modelName := info.OriginModelName
modelName := info.UpstreamModelName
if modelName == "" {
modelName = "veo-3.0-generate-001"
}

View File

@@ -116,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req)
body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, err
}
@@ -224,9 +224,9 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
r := requestPayload{
Model: taskcommon.DefaultString(req.Model, "viduq1"),
Model: taskcommon.DefaultString(info.UpstreamModelName, "viduq1"),
Images: req.Images,
Prompt: req.Prompt,
Duration: taskcommon.DefaultInt(req.Duration, 5),

View File

@@ -26,7 +26,6 @@ type TaskSubmitResult struct {
UpstreamTaskID string
TaskData []byte
Platform constant.TaskPlatform
ModelName string
Quota int
//PerCallPrice types.PriceData
}
@@ -163,6 +162,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
modelName = service.CoverTaskActionToModelName(platform, info.Action)
}
// 2.5 应用渠道的模型映射(与同步任务对齐)
info.OriginModelName = modelName
info.UpstreamModelName = modelName
if err := helper.ModelMappedHelper(c, info, nil); err != nil {
return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest)
}
// 3. 预生成公开 task ID仅首次
if info.PublicTaskID == "" {
info.PublicTaskID = model.GenerateTaskID()
@@ -241,7 +247,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
UpstreamTaskID: upstreamTaskID,
TaskData: taskData,
Platform: platform,
ModelName: modelName,
Quota: finalQuota,
}, nil
}

View File

@@ -16,11 +16,11 @@ import (
// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
// 实际扣费已由 BillingSessionPreConsumeBilling + SettleBilling完成。
func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) {
func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("操作 %s", info.Action)
// 支持任务仅按次计费
if common.StringsContains(constant.TaskPricePatches, modelName) {
if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) {
logContent = fmt.Sprintf("%s按次计费", logContent)
} else {
if len(info.PriceData.OtherRatios) > 0 {
@@ -42,9 +42,13 @@ func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName s
if info.PriceData.GroupRatioInfo.HasSpecialRatio {
other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
}
if info.IsModelMapped {
other["is_model_mapped"] = true
other["upstream_model_name"] = info.UpstreamModelName
}
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
ChannelId: info.ChannelId,
ModelName: modelName,
ModelName: info.OriginModelName,
TokenName: tokenName,
Quota: info.PriceData.Quota,
Content: logContent,
@@ -120,13 +124,18 @@ func taskBillingOther(task *model.Task) map[string]interface{} {
}
}
}
props := task.Properties
if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName {
other["is_model_mapped"] = true
other["upstream_model_name"] = props.UpstreamModelName
}
return other
}
// taskModelName 从 BillingContext 或 Properties 中获取模型名称。
func taskModelName(task *model.Task) string {
if bc := task.PrivateData.BillingContext; bc != nil && bc.ModelName != "" {
return bc.ModelName
if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" {
return bc.OriginModelName
}
return task.Properties.OriginModelName
}
@@ -237,15 +246,7 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo
return
}
// 获取模型名称
var taskData map[string]interface{}
if err := common.Unmarshal(task.Data, &taskData); err != nil {
return
}
modelName, ok := taskData["model"].(string)
if !ok || modelName == "" {
return
}
modelName := taskModelName(task)
// 获取模型价格和倍率
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)

View File

@@ -3,12 +3,14 @@ package service
import (
"context"
"encoding/json"
"net/http"
"os"
"testing"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/glebarez/sqlite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -125,7 +127,7 @@ func makeTask(userId, channelId, quota, tokenId int, billingSource string, subsc
BillingContext: &model.TaskBillingContext{
ModelPrice: 0.02,
GroupRatio: 1.0,
ModelName: "test-model",
OriginModelName: "test-model",
},
},
}
@@ -604,3 +606,107 @@ func TestNonTerminalUpdate_NoBilling(t *testing.T) {
require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
assert.Equal(t, "50%", reloaded.Progress)
}
// ===========================================================================
// Mock adaptor for settleTaskBillingOnComplete tests
// ===========================================================================
type mockAdaptor struct {
adjustReturn int
}
func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {}
func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { return nil, nil }
func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil }
func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
return m.adjustReturn
}
// ===========================================================================
// PerCallBilling tests — settleTaskBillingOnComplete
// ===========================================================================
func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) {
truncate(t)
ctx := context.Background()
const userID, tokenID, channelID = 30, 30, 30
const initQuota, preConsumed = 10000, 5000
const tokenRemain = 8000
seedUser(t, userID, initQuota)
seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain)
seedChannel(t, channelID)
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
task.PrivateData.BillingContext.PerCallBilling = true
adaptor := &mockAdaptor{adjustReturn: 2000}
taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
// Per-call: no adjustment despite adaptor returning 2000
assert.Equal(t, initQuota, getUserQuota(t, userID))
assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
assert.Equal(t, preConsumed, task.Quota)
assert.Equal(t, int64(0), countLogs(t))
}
func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) {
truncate(t)
ctx := context.Background()
const userID, tokenID, channelID = 31, 31, 31
const initQuota, preConsumed = 10000, 4000
const tokenRemain = 7000
seedUser(t, userID, initQuota)
seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain)
seedChannel(t, channelID)
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
task.PrivateData.BillingContext.PerCallBilling = true
adaptor := &mockAdaptor{adjustReturn: 0}
taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999}
settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
// Per-call: no recalculation by tokens
assert.Equal(t, initQuota, getUserQuota(t, userID))
assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
assert.Equal(t, preConsumed, task.Quota)
assert.Equal(t, int64(0), countLogs(t))
}
func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) {
truncate(t)
ctx := context.Background()
const userID, tokenID, channelID = 32, 32, 32
const initQuota, preConsumed = 10000, 5000
const adaptorQuota = 3000
const tokenRemain = 8000
seedUser(t, userID, initQuota)
seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain)
seedChannel(t, channelID)
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
// PerCallBilling defaults to false
adaptor := &mockAdaptor{adjustReturn: adaptorQuota}
taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
// Non-per-call: adaptor adjustment applies (refund 2000)
assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID))
assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID))
assert.Equal(t, adaptorQuota, task.Quota)
log := getLastLog(t)
require.NotNil(t, log)
assert.Equal(t, model.LogTypeRefund, log.Type)
}

View File

@@ -467,6 +467,11 @@ func truncateBase64(s string) string {
// 2. taskResult.TotalTokens > 0 → 按 token 重算
// 3. 都不满足 → 保持预扣额度不变
func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) {
// 0. 按次计费的任务不做差额结算
if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling {
logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID))
return
}
// 1. 优先让 adaptor 决定最终额度
if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 {
RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整")

View File

@@ -84,8 +84,8 @@ function renderDuration(submit_time, finishTime) {
// 返回带有样式的颜色标签
return (
<Tag color={color} shape='circle' prefixIcon={<Clock size={14} />}>
{durationSec}
<Tag color={color} shape='circle'>
{durationSec} s
</Tag>
);
}
@@ -149,7 +149,7 @@ const renderPlatform = (platform, t) => {
);
if (option) {
return (
<Tag color={option.color} shape='circle' prefixIcon={<Video size={14} />}>
<Tag color={option.color} shape='circle'>
{option.label}
</Tag>
);
@@ -157,13 +157,13 @@ const renderPlatform = (platform, t) => {
switch (platform) {
case 'suno':
return (
<Tag color='green' shape='circle' prefixIcon={<Music size={14} />}>
<Tag color='green' shape='circle'>
Suno
</Tag>
);
default:
return (
<Tag color='white' shape='circle' prefixIcon={<HelpCircle size={14} />}>
<Tag color='white' shape='circle'>
{t('未知')}
</Tag>
);
@@ -240,7 +240,6 @@ export const getTaskLogsColumns = ({
openContentModal,
isAdminUser,
openVideoModal,
showUserInfoFunc,
}) => {
return [
{
@@ -278,7 +277,6 @@ export const getTaskLogsColumns = ({
color={colors[parseInt(text) % colors.length]}
size='large'
shape='circle'
prefixIcon={<Hash size={14} />}
onClick={() => {
copyText(text);
}}
@@ -294,7 +292,7 @@ export const getTaskLogsColumns = ({
{
key: COLUMN_KEYS.USERNAME,
title: t('用户'),
dataIndex: 'user_id',
dataIndex: 'username',
render: (userId, record, index) => {
if (!isAdminUser) {
return <></>;
@@ -302,22 +300,14 @@ export const getTaskLogsColumns = ({
const displayText = String(record.username || userId || '?');
return (
<Space>
<Tooltip content={displayText}>
<Avatar
size='extra-small'
color={stringToColor(displayText)}
style={{ cursor: 'pointer' }}
onClick={() => showUserInfoFunc && showUserInfoFunc(userId)}
>
{displayText.slice(0, 1)}
</Avatar>
</Tooltip>
<Typography.Text
ellipsis={{ showTooltip: true }}
style={{ cursor: 'pointer', color: 'var(--semi-color-primary)' }}
onClick={() => showUserInfoFunc && showUserInfoFunc(userId)}
<Avatar
size='extra-small'
color={stringToColor(displayText)}
>
{userId}
{displayText.slice(0, 1)}
</Avatar>
<Typography.Text>
{displayText}
</Typography.Text>
</Space>
);

View File

@@ -25,7 +25,6 @@ import TaskLogsActions from './TaskLogsActions';
import TaskLogsFilters from './TaskLogsFilters';
import ColumnSelectorModal from './modals/ColumnSelectorModal';
import ContentModal from './modals/ContentModal';
import UserInfoModal from '../usage-logs/modals/UserInfoModal';
import { useTaskLogsData } from '../../../hooks/task-logs/useTaskLogsData';
import { useIsMobile } from '../../../hooks/common/useIsMobile';
import { createCardProPagination } from '../../../helpers/utils';
@@ -46,7 +45,6 @@ const TaskLogsPage = () => {
modalContent={taskLogsData.videoUrl}
isVideo={true}
/>
<UserInfoModal {...taskLogsData} />
<Layout>
<CardPro