diff --git a/common/gin.go b/common/gin.go
index 48971c130..009e39080 100644
--- a/common/gin.go
+++ b/common/gin.go
@@ -243,7 +243,15 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
return nil, err
}
- contentType := c.Request.Header.Get("Content-Type")
+ // Use the original Content-Type saved on first call to avoid boundary
+ // mismatch when callers overwrite c.Request.Header after multipart rebuild.
+ var contentType string
+ if saved, ok := c.Get("_original_multipart_ct"); ok {
+ contentType = saved.(string)
+ } else {
+ contentType = c.Request.Header.Get("Content-Type")
+ c.Set("_original_multipart_ct", contentType)
+ }
boundary, err := parseBoundary(contentType)
if err != nil {
return nil, err
diff --git a/controller/relay.go b/controller/relay.go
index 6951974c5..7e7922e75 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -518,7 +518,7 @@ func RelayTask(c *gin.Context) {
}
addUsedChannel(c, channel.Id)
- requestBody, bodyErr := common.GetRequestBody(c)
+ bodyStorage, bodyErr := common.GetBodyStorage(c)
if bodyErr != nil {
if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge)
@@ -527,7 +527,7 @@ func RelayTask(c *gin.Context) {
}
break
}
- c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+ c.Request.Body = io.NopCloser(bodyStorage)
result, taskErr = relay.RelayTaskSubmit(c, relayInfo)
if taskErr == nil {
@@ -557,7 +557,7 @@ func RelayTask(c *gin.Context) {
if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil {
common.SysError("settle task billing error: " + settleErr.Error())
}
- service.LogTaskConsumption(c, relayInfo, result.ModelName)
+ service.LogTaskConsumption(c, relayInfo)
task := model.InitTask(result.Platform, relayInfo)
task.PrivateData.UpstreamTaskID = result.UpstreamTaskID
@@ -565,11 +565,12 @@ func RelayTask(c *gin.Context) {
task.PrivateData.SubscriptionId = relayInfo.SubscriptionId
task.PrivateData.TokenId = relayInfo.TokenId
task.PrivateData.BillingContext = &model.TaskBillingContext{
- ModelPrice: relayInfo.PriceData.ModelPrice,
- GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio,
- ModelRatio: relayInfo.PriceData.ModelRatio,
- OtherRatios: relayInfo.PriceData.OtherRatios,
- ModelName: result.ModelName,
+ ModelPrice: relayInfo.PriceData.ModelPrice,
+ GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio,
+ ModelRatio: relayInfo.PriceData.ModelRatio,
+ OtherRatios: relayInfo.PriceData.OtherRatios,
+ OriginModelName: relayInfo.OriginModelName,
+ PerCallBilling: common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName),
}
task.Quota = result.Quota
task.Data = result.TaskData
diff --git a/controller/task.go b/controller/task.go
index ec713c5d2..eac7db153 100644
--- a/controller/task.go
+++ b/controller/task.go
@@ -9,6 +9,7 @@ import (
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay"
"github.com/QuantumNous/new-api/service"
+ "github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
@@ -37,7 +38,7 @@ func GetAllTask(c *gin.Context) {
items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
total := model.TaskCountAllTasks(queryParams)
pageInfo.SetTotal(int(total))
- pageInfo.SetItems(tasksToDto(items))
+ pageInfo.SetItems(tasksToDto(items, true))
common.ApiSuccess(c, pageInfo)
}
@@ -61,13 +62,32 @@ func GetUserTask(c *gin.Context) {
items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
total := model.TaskCountAllUserTask(userId, queryParams)
pageInfo.SetTotal(int(total))
- pageInfo.SetItems(tasksToDto(items))
+ pageInfo.SetItems(tasksToDto(items, false))
common.ApiSuccess(c, pageInfo)
}
-func tasksToDto(tasks []*model.Task) []*dto.TaskDto {
+func tasksToDto(tasks []*model.Task, fillUser bool) []*dto.TaskDto {
+ var userIdMap map[int]*model.UserBase
+ if fillUser {
+ userIdMap = make(map[int]*model.UserBase)
+ userIds := types.NewSet[int]()
+ for _, task := range tasks {
+ userIds.Add(task.UserId)
+ }
+ for _, userId := range userIds.Items() {
+ cacheUser, err := model.GetUserCache(userId)
+ if err == nil {
+ userIdMap[userId] = cacheUser
+ }
+ }
+ }
result := make([]*dto.TaskDto, len(tasks))
for i, task := range tasks {
+ if fillUser {
+ if user, ok := userIdMap[task.UserId]; ok {
+ task.Username = user.Username
+ }
+ }
result[i] = relay.TaskModel2Dto(task)
}
return result
diff --git a/model/task.go b/model/task.go
index 0cf6bd47e..da3be34ed 100644
--- a/model/task.go
+++ b/model/task.go
@@ -109,11 +109,12 @@ type TaskPrivateData struct {
// TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。
type TaskBillingContext struct {
- ModelPrice float64 `json:"model_price,omitempty"` // 模型单价
- GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率
- ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率
- OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等)
- ModelName string `json:"model_name,omitempty"` // 模型名称
+ ModelPrice float64 `json:"model_price,omitempty"` // 模型单价
+ GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率
+ ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率
+ OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等)
+ OriginModelName string `json:"origin_model_name,omitempty"` // 模型名称,必须为OriginModelName
+ PerCallBilling bool `json:"per_call_billing,omitempty"` // 按次计费:跳过轮询阶段的差额结算
}
// GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信)
diff --git a/relay/channel/task/ali/adaptor.go b/relay/channel/task/ali/adaptor.go
index f55178b3b..f698fc9f6 100644
--- a/relay/channel/task/ali/adaptor.go
+++ b/relay/channel/task/ali/adaptor.go
@@ -253,8 +253,12 @@ func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error)
}
func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
+ upstreamModel := req.Model
+ if info.IsModelMapped {
+ upstreamModel = info.UpstreamModelName
+ }
aliReq := &AliVideoRequest{
- Model: req.Model,
+ Model: upstreamModel,
Input: AliVideoInput{
Prompt: req.Prompt,
ImgURL: req.InputReference,
@@ -332,7 +336,7 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
}
}
- if aliReq.Model != req.Model {
+ if aliReq.Model != upstreamModel {
return nil, errors.New("can't change model with metadata")
}
diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go
index eca421bd3..8f1d748ce 100644
--- a/relay/channel/task/doubao/adaptor.go
+++ b/relay/channel/task/doubao/adaptor.go
@@ -131,7 +131,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
- info.UpstreamModelName = body.Model
+ if info.IsModelMapped {
+ body.Model = info.UpstreamModelName
+ } else {
+ info.UpstreamModelName = body.Model
+ }
data, err := common.Marshal(body)
if err != nil {
return nil, err
diff --git a/relay/channel/task/gemini/adaptor.go b/relay/channel/task/gemini/adaptor.go
index 06c00a469..5644cd5dc 100644
--- a/relay/channel/task/gemini/adaptor.go
+++ b/relay/channel/task/gemini/adaptor.go
@@ -105,7 +105,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
- modelName := info.OriginModelName
+ modelName := info.UpstreamModelName
version := model_setting.GetGeminiVersionSetting(modelName)
return fmt.Sprintf(
diff --git a/relay/channel/task/hailuo/adaptor.go b/relay/channel/task/hailuo/adaptor.go
index ab83d659b..28b3a97f1 100644
--- a/relay/channel/task/hailuo/adaptor.go
+++ b/relay/channel/task/hailuo/adaptor.go
@@ -61,7 +61,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
return nil, fmt.Errorf("invalid request type in context")
}
- body, err := a.convertToRequestPayload(&req)
+ body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
@@ -142,8 +142,8 @@ func (a *TaskAdaptor) GetChannelName() string {
return ChannelName
}
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) {
- modelConfig := GetModelConfig(req.Model)
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*VideoRequest, error) {
+ modelConfig := GetModelConfig(info.UpstreamModelName)
duration := DefaultDuration
if req.Duration > 0 {
duration = req.Duration
@@ -154,7 +154,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
}
videoRequest := &VideoRequest{
- Model: req.Model,
+ Model: info.UpstreamModelName,
Prompt: req.Prompt,
Duration: &duration,
Resolution: resolution,
diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go
index b61cca418..e6211b1e4 100644
--- a/relay/channel/task/jimeng/adaptor.go
+++ b/relay/channel/task/jimeng/adaptor.go
@@ -165,7 +165,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
}
- body, err := a.convertToRequestPayload(&req)
+ body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
@@ -378,9 +378,9 @@ func hmacSHA256(key []byte, data []byte) []byte {
return h.Sum(nil)
}
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
r := requestPayload{
- ReqKey: req.Model,
+ ReqKey: info.UpstreamModelName,
Prompt: req.Prompt,
}
diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go
index 46e210f19..cdbb56878 100644
--- a/relay/channel/task/kling/adaptor.go
+++ b/relay/channel/task/kling/adaptor.go
@@ -150,7 +150,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
req := v.(relaycommon.TaskSubmitReq)
- body, err := a.convertToRequestPayload(&req)
+ body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, err
}
@@ -248,15 +248,15 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers
// ============================
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
r := requestPayload{
Prompt: req.Prompt,
Image: req.Image,
Mode: taskcommon.DefaultString(req.Mode, "std"),
Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)),
AspectRatio: a.getAspectRatio(req.Size),
- ModelName: req.Model,
- Model: req.Model, // Keep consistent with model_name, double writing improves compatibility
+ ModelName: info.UpstreamModelName,
+ Model: info.UpstreamModelName,
CfgScale: 0.5,
StaticMask: "",
DynamicMasks: []DynamicMask{},
@@ -266,6 +266,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
}
if r.ModelName == "" {
r.ModelName = "kling-v1"
+ r.Model = "kling-v1"
}
if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go
index bf2f70053..33db8fe55 100644
--- a/relay/channel/task/sora/adaptor.go
+++ b/relay/channel/task/sora/adaptor.go
@@ -1,8 +1,10 @@
package sora
import (
+ "bytes"
"fmt"
"io"
+ "mime/multipart"
"net/http"
"strconv"
"strings"
@@ -145,6 +147,59 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
if err != nil {
return nil, errors.Wrap(err, "get_request_body_failed")
}
+ cachedBody, err := storage.Bytes()
+ if err != nil {
+ return nil, errors.Wrap(err, "read_body_bytes_failed")
+ }
+ contentType := c.GetHeader("Content-Type")
+
+ if strings.HasPrefix(contentType, "application/json") {
+ var bodyMap map[string]interface{}
+ if err := common.Unmarshal(cachedBody, &bodyMap); err == nil {
+ bodyMap["model"] = info.UpstreamModelName
+ if newBody, err := common.Marshal(bodyMap); err == nil {
+ return bytes.NewReader(newBody), nil
+ }
+ }
+ return bytes.NewReader(cachedBody), nil
+ }
+
+ if strings.Contains(contentType, "multipart/form-data") {
+ formData, err := common.ParseMultipartFormReusable(c)
+ if err != nil {
+ return bytes.NewReader(cachedBody), nil
+ }
+ var buf bytes.Buffer
+ writer := multipart.NewWriter(&buf)
+ writer.WriteField("model", info.UpstreamModelName)
+ for key, values := range formData.Value {
+ if key == "model" {
+ continue
+ }
+ for _, v := range values {
+ writer.WriteField(key, v)
+ }
+ }
+ for fieldName, fileHeaders := range formData.File {
+ for _, fh := range fileHeaders {
+ f, err := fh.Open()
+ if err != nil {
+ continue
+ }
+ part, err := writer.CreateFormFile(fieldName, fh.Filename)
+ if err != nil {
+ f.Close()
+ continue
+ }
+ io.Copy(part, f)
+ f.Close()
+ }
+ }
+ writer.Close()
+ c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+ return &buf, nil
+ }
+
return common.ReaderOnly(storage), nil
}
diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go
index 4931002dd..700e60976 100644
--- a/relay/channel/task/vertex/adaptor.go
+++ b/relay/channel/task/vertex/adaptor.go
@@ -86,7 +86,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
return "", fmt.Errorf("failed to decode credentials: %w", err)
}
- modelName := info.OriginModelName
+ modelName := info.UpstreamModelName
if modelName == "" {
modelName = "veo-3.0-generate-001"
}
diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go
index e689bf888..6ae1c181b 100644
--- a/relay/channel/task/vidu/adaptor.go
+++ b/relay/channel/task/vidu/adaptor.go
@@ -116,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
}
req := v.(relaycommon.TaskSubmitReq)
- body, err := a.convertToRequestPayload(&req)
+ body, err := a.convertToRequestPayload(&req, info)
if err != nil {
return nil, err
}
@@ -224,9 +224,9 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers
// ============================
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
r := requestPayload{
- Model: taskcommon.DefaultString(req.Model, "viduq1"),
+ Model: taskcommon.DefaultString(info.UpstreamModelName, "viduq1"),
Images: req.Images,
Prompt: req.Prompt,
Duration: taskcommon.DefaultInt(req.Duration, 5),
diff --git a/relay/relay_task.go b/relay/relay_task.go
index cd43e6ebb..c740facdb 100644
--- a/relay/relay_task.go
+++ b/relay/relay_task.go
@@ -26,7 +26,6 @@ type TaskSubmitResult struct {
UpstreamTaskID string
TaskData []byte
Platform constant.TaskPlatform
- ModelName string
Quota int
//PerCallPrice types.PriceData
}
@@ -163,6 +162,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
modelName = service.CoverTaskActionToModelName(platform, info.Action)
}
+ // 2.5 应用渠道的模型映射(与同步任务对齐)
+ info.OriginModelName = modelName
+ info.UpstreamModelName = modelName
+ if err := helper.ModelMappedHelper(c, info, nil); err != nil {
+ return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest)
+ }
+
// 3. 预生成公开 task ID(仅首次)
if info.PublicTaskID == "" {
info.PublicTaskID = model.GenerateTaskID()
@@ -241,7 +247,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
UpstreamTaskID: upstreamTaskID,
TaskData: taskData,
Platform: platform,
- ModelName: modelName,
Quota: finalQuota,
}, nil
}
diff --git a/service/task_billing.go b/service/task_billing.go
index 78ad0fc09..0da4cf431 100644
--- a/service/task_billing.go
+++ b/service/task_billing.go
@@ -16,11 +16,11 @@ import (
// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
// 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。
-func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) {
+func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("操作 %s", info.Action)
// 支持任务仅按次计费
- if common.StringsContains(constant.TaskPricePatches, modelName) {
+ if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) {
logContent = fmt.Sprintf("%s,按次计费", logContent)
} else {
if len(info.PriceData.OtherRatios) > 0 {
@@ -42,9 +42,13 @@ func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName s
if info.PriceData.GroupRatioInfo.HasSpecialRatio {
other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
}
+ if info.IsModelMapped {
+ other["is_model_mapped"] = true
+ other["upstream_model_name"] = info.UpstreamModelName
+ }
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
ChannelId: info.ChannelId,
- ModelName: modelName,
+ ModelName: info.OriginModelName,
TokenName: tokenName,
Quota: info.PriceData.Quota,
Content: logContent,
@@ -120,13 +124,18 @@ func taskBillingOther(task *model.Task) map[string]interface{} {
}
}
}
+ props := task.Properties
+ if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName {
+ other["is_model_mapped"] = true
+ other["upstream_model_name"] = props.UpstreamModelName
+ }
return other
}
// taskModelName 从 BillingContext 或 Properties 中获取模型名称。
func taskModelName(task *model.Task) string {
- if bc := task.PrivateData.BillingContext; bc != nil && bc.ModelName != "" {
- return bc.ModelName
+ if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" {
+ return bc.OriginModelName
}
return task.Properties.OriginModelName
}
@@ -237,15 +246,7 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo
return
}
- // 获取模型名称
- var taskData map[string]interface{}
- if err := common.Unmarshal(task.Data, &taskData); err != nil {
- return
- }
- modelName, ok := taskData["model"].(string)
- if !ok || modelName == "" {
- return
- }
+ modelName := taskModelName(task)
// 获取模型价格和倍率
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
diff --git a/service/task_billing_test.go b/service/task_billing_test.go
index 6c2d231d5..1145bba54 100644
--- a/service/task_billing_test.go
+++ b/service/task_billing_test.go
@@ -3,12 +3,14 @@ package service
import (
"context"
"encoding/json"
+ "net/http"
"os"
"testing"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/glebarez/sqlite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -125,7 +127,7 @@ func makeTask(userId, channelId, quota, tokenId int, billingSource string, subsc
BillingContext: &model.TaskBillingContext{
ModelPrice: 0.02,
GroupRatio: 1.0,
- ModelName: "test-model",
+ OriginModelName: "test-model",
},
},
}
@@ -604,3 +606,107 @@ func TestNonTerminalUpdate_NoBilling(t *testing.T) {
require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
assert.Equal(t, "50%", reloaded.Progress)
}
+
+// ===========================================================================
+// Mock adaptor for settleTaskBillingOnComplete tests
+// ===========================================================================
+
+type mockAdaptor struct {
+ adjustReturn int
+}
+
+func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {}
+func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { return nil, nil }
+func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil }
+func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
+ return m.adjustReturn
+}
+
+// ===========================================================================
+// PerCallBilling tests — settleTaskBillingOnComplete
+// ===========================================================================
+
+func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, tokenID, channelID = 30, 30, 30
+ const initQuota, preConsumed = 10000, 5000
+ const tokenRemain = 8000
+
+ seedUser(t, userID, initQuota)
+ seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain)
+ seedChannel(t, channelID)
+
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+ task.PrivateData.BillingContext.PerCallBilling = true
+
+ adaptor := &mockAdaptor{adjustReturn: 2000}
+ taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
+
+ settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+ // Per-call: no adjustment despite adaptor returning 2000
+ assert.Equal(t, initQuota, getUserQuota(t, userID))
+ assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
+ assert.Equal(t, preConsumed, task.Quota)
+ assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, tokenID, channelID = 31, 31, 31
+ const initQuota, preConsumed = 10000, 4000
+ const tokenRemain = 7000
+
+ seedUser(t, userID, initQuota)
+ seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain)
+ seedChannel(t, channelID)
+
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+ task.PrivateData.BillingContext.PerCallBilling = true
+
+ adaptor := &mockAdaptor{adjustReturn: 0}
+ taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999}
+
+ settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+ // Per-call: no recalculation by tokens
+ assert.Equal(t, initQuota, getUserQuota(t, userID))
+ assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
+ assert.Equal(t, preConsumed, task.Quota)
+ assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) {
+ truncate(t)
+ ctx := context.Background()
+
+ const userID, tokenID, channelID = 32, 32, 32
+ const initQuota, preConsumed = 10000, 5000
+ const adaptorQuota = 3000
+ const tokenRemain = 8000
+
+ seedUser(t, userID, initQuota)
+ seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain)
+ seedChannel(t, channelID)
+
+ task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+ // PerCallBilling defaults to false
+
+ adaptor := &mockAdaptor{adjustReturn: adaptorQuota}
+ taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
+
+ settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+ // Non-per-call: adaptor adjustment applies (refund 2000)
+ assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID))
+ assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID))
+ assert.Equal(t, adaptorQuota, task.Quota)
+
+ log := getLastLog(t)
+ require.NotNil(t, log)
+ assert.Equal(t, model.LogTypeRefund, log.Type)
+}
diff --git a/service/task_polling.go b/service/task_polling.go
index 7e92d14ba..a03fc9b88 100644
--- a/service/task_polling.go
+++ b/service/task_polling.go
@@ -467,6 +467,11 @@ func truncateBase64(s string) string {
// 2. taskResult.TotalTokens > 0 → 按 token 重算
// 3. 都不满足 → 保持预扣额度不变
func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) {
+ // 0. 按次计费的任务不做差额结算
+ if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling {
+ logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID))
+ return
+ }
// 1. 优先让 adaptor 决定最终额度
if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 {
RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整")
diff --git a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx
index 4bce45256..7fddb0a50 100644
--- a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx
+++ b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx
@@ -84,8 +84,8 @@ function renderDuration(submit_time, finishTime) {
// 返回带有样式的颜色标签
return (
- }>
- {durationSec} 秒
+
+ {durationSec} s
);
}
@@ -149,7 +149,7 @@ const renderPlatform = (platform, t) => {
);
if (option) {
return (
- }>
+
{option.label}
);
@@ -157,13 +157,13 @@ const renderPlatform = (platform, t) => {
switch (platform) {
case 'suno':
return (
- }>
+
Suno
);
default:
return (
- }>
+
{t('未知')}
);
@@ -240,7 +240,6 @@ export const getTaskLogsColumns = ({
openContentModal,
isAdminUser,
openVideoModal,
- showUserInfoFunc,
}) => {
return [
{
@@ -278,7 +277,6 @@ export const getTaskLogsColumns = ({
color={colors[parseInt(text) % colors.length]}
size='large'
shape='circle'
- prefixIcon={}
onClick={() => {
copyText(text);
}}
@@ -294,7 +292,7 @@ export const getTaskLogsColumns = ({
{
key: COLUMN_KEYS.USERNAME,
title: t('用户'),
- dataIndex: 'user_id',
+ dataIndex: 'username',
render: (userId, record, index) => {
if (!isAdminUser) {
return <>>;
@@ -302,22 +300,14 @@ export const getTaskLogsColumns = ({
const displayText = String(record.username || userId || '?');
return (
-
- showUserInfoFunc && showUserInfoFunc(userId)}
- >
- {displayText.slice(0, 1)}
-
-
- showUserInfoFunc && showUserInfoFunc(userId)}
+
- {userId}
+ {displayText.slice(0, 1)}
+
+
+ {displayText}
);
diff --git a/web/src/components/table/task-logs/index.jsx b/web/src/components/table/task-logs/index.jsx
index 140725a89..bc5b91787 100644
--- a/web/src/components/table/task-logs/index.jsx
+++ b/web/src/components/table/task-logs/index.jsx
@@ -25,7 +25,6 @@ import TaskLogsActions from './TaskLogsActions';
import TaskLogsFilters from './TaskLogsFilters';
import ColumnSelectorModal from './modals/ColumnSelectorModal';
import ContentModal from './modals/ContentModal';
-import UserInfoModal from '../usage-logs/modals/UserInfoModal';
import { useTaskLogsData } from '../../../hooks/task-logs/useTaskLogsData';
import { useIsMobile } from '../../../hooks/common/useIsMobile';
import { createCardProPagination } from '../../../helpers/utils';
@@ -46,7 +45,6 @@ const TaskLogsPage = () => {
modalContent={taskLogsData.videoUrl}
isVideo={true}
/>
-