mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 09:13:31 +00:00
feat(task): add model redirection, per-call billing, and multipart retry fix for async tasks
1. Async task model redirection (aligned with sync tasks):
- Integrate ModelMappedHelper in RelayTaskSubmit after model name
determination, populating OriginModelName / UpstreamModelName on RelayInfo.
- All task adaptors now send UpstreamModelName to upstream providers:
- Gemini & Vertex: BuildRequestURL uses UpstreamModelName.
- Doubao & Ali: BuildRequestBody conditionally overwrites body.Model.
- Vidu, Kling, Hailuo, Jimeng: convertToRequestPayload accepts RelayInfo
and unconditionally uses info.UpstreamModelName.
- Sora: BuildRequestBody parses JSON and multipart bodies to replace
the "model" field with UpstreamModelName.
- Frontend log visibility: LogTaskConsumption and taskBillingOther now
emit is_model_mapped / upstream_model_name in the "other" JSON field.
- Billing safety: RecalculateTaskQuotaByTokens reads model name from
BillingContext.OriginModelName (via taskModelName) instead of
task.Data["model"], preventing billing leaks from upstream model names.
2. Per-call billing (TaskPricePatches lifecycle):
- Rename TaskBillingContext.ModelName → OriginModelName; add PerCallBilling
bool field, populated from TaskPricePatches at submission time.
- settleTaskBillingOnComplete short-circuits when PerCallBilling is true,
skipping both adaptor adjustments and token-based recalculation.
- Remove ModelName from TaskSubmitResult; use relayInfo.OriginModelName
consistently in controller/relay.go for billing context and logging.
3. Multipart retry boundary mismatch fix:
- Root cause: after Sora (or OpenAI audio) rebuilds a multipart body with a
new boundary and overwrites c.Request.Header["Content-Type"], subsequent
calls to ParseMultipartFormReusable on retry would parse the cached
original body with the wrong boundary, causing "NextPart: EOF".
- Fix: ParseMultipartFormReusable now caches the original Content-Type in
gin context key "_original_multipart_ct" on first call and reuses it for
all subsequent parses, making multipart parsing retry-safe globally.
- Sora adaptor reverted to the standard pattern (direct header set/get),
which is now safe thanks to the root fix.
4. Tests:
- task_billing_test.go: update makeTask to use OriginModelName; add
PerCallBilling settlement tests (skip adaptor adjust, skip token recalc);
add non-per-call adaptor adjustment test with refund verification.
This commit is contained in:
@@ -243,7 +243,15 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
// Use the original Content-Type saved on first call to avoid boundary
|
||||
// mismatch when callers overwrite c.Request.Header after multipart rebuild.
|
||||
var contentType string
|
||||
if saved, ok := c.Get("_original_multipart_ct"); ok {
|
||||
contentType = saved.(string)
|
||||
} else {
|
||||
contentType = c.Request.Header.Get("Content-Type")
|
||||
c.Set("_original_multipart_ct", contentType)
|
||||
}
|
||||
boundary, err := parseBoundary(contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -518,7 +518,7 @@ func RelayTask(c *gin.Context) {
|
||||
}
|
||||
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, bodyErr := common.GetRequestBody(c)
|
||||
bodyStorage, bodyErr := common.GetBodyStorage(c)
|
||||
if bodyErr != nil {
|
||||
if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
|
||||
taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge)
|
||||
@@ -527,7 +527,7 @@ func RelayTask(c *gin.Context) {
|
||||
}
|
||||
break
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
c.Request.Body = io.NopCloser(bodyStorage)
|
||||
|
||||
result, taskErr = relay.RelayTaskSubmit(c, relayInfo)
|
||||
if taskErr == nil {
|
||||
@@ -557,7 +557,7 @@ func RelayTask(c *gin.Context) {
|
||||
if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil {
|
||||
common.SysError("settle task billing error: " + settleErr.Error())
|
||||
}
|
||||
service.LogTaskConsumption(c, relayInfo, result.ModelName)
|
||||
service.LogTaskConsumption(c, relayInfo)
|
||||
|
||||
task := model.InitTask(result.Platform, relayInfo)
|
||||
task.PrivateData.UpstreamTaskID = result.UpstreamTaskID
|
||||
@@ -565,11 +565,12 @@ func RelayTask(c *gin.Context) {
|
||||
task.PrivateData.SubscriptionId = relayInfo.SubscriptionId
|
||||
task.PrivateData.TokenId = relayInfo.TokenId
|
||||
task.PrivateData.BillingContext = &model.TaskBillingContext{
|
||||
ModelPrice: relayInfo.PriceData.ModelPrice,
|
||||
GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio,
|
||||
ModelRatio: relayInfo.PriceData.ModelRatio,
|
||||
OtherRatios: relayInfo.PriceData.OtherRatios,
|
||||
ModelName: result.ModelName,
|
||||
ModelPrice: relayInfo.PriceData.ModelPrice,
|
||||
GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio,
|
||||
ModelRatio: relayInfo.PriceData.ModelRatio,
|
||||
OtherRatios: relayInfo.PriceData.OtherRatios,
|
||||
OriginModelName: relayInfo.OriginModelName,
|
||||
PerCallBilling: common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName),
|
||||
}
|
||||
task.Quota = result.Quota
|
||||
task.Data = result.TaskData
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -37,7 +38,7 @@ func GetAllTask(c *gin.Context) {
|
||||
items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.TaskCountAllTasks(queryParams)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(tasksToDto(items))
|
||||
pageInfo.SetItems(tasksToDto(items, true))
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
@@ -61,13 +62,32 @@ func GetUserTask(c *gin.Context) {
|
||||
items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.TaskCountAllUserTask(userId, queryParams)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(tasksToDto(items))
|
||||
pageInfo.SetItems(tasksToDto(items, false))
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
func tasksToDto(tasks []*model.Task) []*dto.TaskDto {
|
||||
func tasksToDto(tasks []*model.Task, fillUser bool) []*dto.TaskDto {
|
||||
var userIdMap map[int]*model.UserBase
|
||||
if fillUser {
|
||||
userIdMap = make(map[int]*model.UserBase)
|
||||
userIds := types.NewSet[int]()
|
||||
for _, task := range tasks {
|
||||
userIds.Add(task.UserId)
|
||||
}
|
||||
for _, userId := range userIds.Items() {
|
||||
cacheUser, err := model.GetUserCache(userId)
|
||||
if err == nil {
|
||||
userIdMap[userId] = cacheUser
|
||||
}
|
||||
}
|
||||
}
|
||||
result := make([]*dto.TaskDto, len(tasks))
|
||||
for i, task := range tasks {
|
||||
if fillUser {
|
||||
if user, ok := userIdMap[task.UserId]; ok {
|
||||
task.Username = user.Username
|
||||
}
|
||||
}
|
||||
result[i] = relay.TaskModel2Dto(task)
|
||||
}
|
||||
return result
|
||||
|
||||
@@ -109,11 +109,12 @@ type TaskPrivateData struct {
|
||||
|
||||
// TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。
|
||||
type TaskBillingContext struct {
|
||||
ModelPrice float64 `json:"model_price,omitempty"` // 模型单价
|
||||
GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率
|
||||
ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率
|
||||
OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等)
|
||||
ModelName string `json:"model_name,omitempty"` // 模型名称
|
||||
ModelPrice float64 `json:"model_price,omitempty"` // 模型单价
|
||||
GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率
|
||||
ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率
|
||||
OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等)
|
||||
OriginModelName string `json:"origin_model_name,omitempty"` // 模型名称,必须为OriginModelName
|
||||
PerCallBilling bool `json:"per_call_billing,omitempty"` // 按次计费:跳过轮询阶段的差额结算
|
||||
}
|
||||
|
||||
// GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信)
|
||||
|
||||
@@ -253,8 +253,12 @@ func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
|
||||
upstreamModel := req.Model
|
||||
if info.IsModelMapped {
|
||||
upstreamModel = info.UpstreamModelName
|
||||
}
|
||||
aliReq := &AliVideoRequest{
|
||||
Model: req.Model,
|
||||
Model: upstreamModel,
|
||||
Input: AliVideoInput{
|
||||
Prompt: req.Prompt,
|
||||
ImgURL: req.InputReference,
|
||||
@@ -332,7 +336,7 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
|
||||
}
|
||||
}
|
||||
|
||||
if aliReq.Model != req.Model {
|
||||
if aliReq.Model != upstreamModel {
|
||||
return nil, errors.New("can't change model with metadata")
|
||||
}
|
||||
|
||||
|
||||
@@ -131,7 +131,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert request payload failed")
|
||||
}
|
||||
info.UpstreamModelName = body.Model
|
||||
if info.IsModelMapped {
|
||||
body.Model = info.UpstreamModelName
|
||||
} else {
|
||||
info.UpstreamModelName = body.Model
|
||||
}
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -105,7 +105,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
modelName := info.OriginModelName
|
||||
modelName := info.UpstreamModelName
|
||||
version := model_setting.GetGeminiVersionSetting(modelName)
|
||||
|
||||
return fmt.Sprintf(
|
||||
|
||||
@@ -61,7 +61,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
return nil, fmt.Errorf("invalid request type in context")
|
||||
}
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
body, err := a.convertToRequestPayload(&req, info)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert request payload failed")
|
||||
}
|
||||
@@ -142,8 +142,8 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) {
|
||||
modelConfig := GetModelConfig(req.Model)
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*VideoRequest, error) {
|
||||
modelConfig := GetModelConfig(info.UpstreamModelName)
|
||||
duration := DefaultDuration
|
||||
if req.Duration > 0 {
|
||||
duration = req.Duration
|
||||
@@ -154,7 +154,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
}
|
||||
|
||||
videoRequest := &VideoRequest{
|
||||
Model: req.Model,
|
||||
Model: info.UpstreamModelName,
|
||||
Prompt: req.Prompt,
|
||||
Duration: &duration,
|
||||
Resolution: resolution,
|
||||
|
||||
@@ -165,7 +165,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
}
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
body, err := a.convertToRequestPayload(&req, info)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert request payload failed")
|
||||
}
|
||||
@@ -378,9 +378,9 @@ func hmacSHA256(key []byte, data []byte) []byte {
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
ReqKey: req.Model,
|
||||
ReqKey: info.UpstreamModelName,
|
||||
Prompt: req.Prompt,
|
||||
}
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
body, err := a.convertToRequestPayload(&req, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -248,15 +248,15 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
Prompt: req.Prompt,
|
||||
Image: req.Image,
|
||||
Mode: taskcommon.DefaultString(req.Mode, "std"),
|
||||
Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)),
|
||||
AspectRatio: a.getAspectRatio(req.Size),
|
||||
ModelName: req.Model,
|
||||
Model: req.Model, // Keep consistent with model_name, double writing improves compatibility
|
||||
ModelName: info.UpstreamModelName,
|
||||
Model: info.UpstreamModelName,
|
||||
CfgScale: 0.5,
|
||||
StaticMask: "",
|
||||
DynamicMasks: []DynamicMask{},
|
||||
@@ -266,6 +266,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
}
|
||||
if r.ModelName == "" {
|
||||
r.ModelName = "kling-v1"
|
||||
r.Model = "kling-v1"
|
||||
}
|
||||
if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package sora
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -145,6 +147,59 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get_request_body_failed")
|
||||
}
|
||||
cachedBody, err := storage.Bytes()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "read_body_bytes_failed")
|
||||
}
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
var bodyMap map[string]interface{}
|
||||
if err := common.Unmarshal(cachedBody, &bodyMap); err == nil {
|
||||
bodyMap["model"] = info.UpstreamModelName
|
||||
if newBody, err := common.Marshal(bodyMap); err == nil {
|
||||
return bytes.NewReader(newBody), nil
|
||||
}
|
||||
}
|
||||
return bytes.NewReader(cachedBody), nil
|
||||
}
|
||||
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
formData, err := common.ParseMultipartFormReusable(c)
|
||||
if err != nil {
|
||||
return bytes.NewReader(cachedBody), nil
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
writer.WriteField("model", info.UpstreamModelName)
|
||||
for key, values := range formData.Value {
|
||||
if key == "model" {
|
||||
continue
|
||||
}
|
||||
for _, v := range values {
|
||||
writer.WriteField(key, v)
|
||||
}
|
||||
}
|
||||
for fieldName, fileHeaders := range formData.File {
|
||||
for _, fh := range fileHeaders {
|
||||
f, err := fh.Open()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
part, err := writer.CreateFormFile(fieldName, fh.Filename)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
continue
|
||||
}
|
||||
io.Copy(part, f)
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
writer.Close()
|
||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return &buf, nil
|
||||
}
|
||||
|
||||
return common.ReaderOnly(storage), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
|
||||
if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
modelName := info.OriginModelName
|
||||
modelName := info.UpstreamModelName
|
||||
if modelName == "" {
|
||||
modelName = "veo-3.0-generate-001"
|
||||
}
|
||||
|
||||
@@ -116,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
body, err := a.convertToRequestPayload(&req, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -224,9 +224,9 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
Model: taskcommon.DefaultString(req.Model, "viduq1"),
|
||||
Model: taskcommon.DefaultString(info.UpstreamModelName, "viduq1"),
|
||||
Images: req.Images,
|
||||
Prompt: req.Prompt,
|
||||
Duration: taskcommon.DefaultInt(req.Duration, 5),
|
||||
|
||||
@@ -26,7 +26,6 @@ type TaskSubmitResult struct {
|
||||
UpstreamTaskID string
|
||||
TaskData []byte
|
||||
Platform constant.TaskPlatform
|
||||
ModelName string
|
||||
Quota int
|
||||
//PerCallPrice types.PriceData
|
||||
}
|
||||
@@ -163,6 +162,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
|
||||
modelName = service.CoverTaskActionToModelName(platform, info.Action)
|
||||
}
|
||||
|
||||
// 2.5 应用渠道的模型映射(与同步任务对齐)
|
||||
info.OriginModelName = modelName
|
||||
info.UpstreamModelName = modelName
|
||||
if err := helper.ModelMappedHelper(c, info, nil); err != nil {
|
||||
return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// 3. 预生成公开 task ID(仅首次)
|
||||
if info.PublicTaskID == "" {
|
||||
info.PublicTaskID = model.GenerateTaskID()
|
||||
@@ -241,7 +247,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
|
||||
UpstreamTaskID: upstreamTaskID,
|
||||
TaskData: taskData,
|
||||
Platform: platform,
|
||||
ModelName: modelName,
|
||||
Quota: finalQuota,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -16,11 +16,11 @@ import (
|
||||
|
||||
// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
|
||||
// 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。
|
||||
func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) {
|
||||
func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("操作 %s", info.Action)
|
||||
// 支持任务仅按次计费
|
||||
if common.StringsContains(constant.TaskPricePatches, modelName) {
|
||||
if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) {
|
||||
logContent = fmt.Sprintf("%s,按次计费", logContent)
|
||||
} else {
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
@@ -42,9 +42,13 @@ func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName s
|
||||
if info.PriceData.GroupRatioInfo.HasSpecialRatio {
|
||||
other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
|
||||
}
|
||||
if info.IsModelMapped {
|
||||
other["is_model_mapped"] = true
|
||||
other["upstream_model_name"] = info.UpstreamModelName
|
||||
}
|
||||
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: info.ChannelId,
|
||||
ModelName: modelName,
|
||||
ModelName: info.OriginModelName,
|
||||
TokenName: tokenName,
|
||||
Quota: info.PriceData.Quota,
|
||||
Content: logContent,
|
||||
@@ -120,13 +124,18 @@ func taskBillingOther(task *model.Task) map[string]interface{} {
|
||||
}
|
||||
}
|
||||
}
|
||||
props := task.Properties
|
||||
if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName {
|
||||
other["is_model_mapped"] = true
|
||||
other["upstream_model_name"] = props.UpstreamModelName
|
||||
}
|
||||
return other
|
||||
}
|
||||
|
||||
// taskModelName 从 BillingContext 或 Properties 中获取模型名称。
|
||||
func taskModelName(task *model.Task) string {
|
||||
if bc := task.PrivateData.BillingContext; bc != nil && bc.ModelName != "" {
|
||||
return bc.ModelName
|
||||
if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" {
|
||||
return bc.OriginModelName
|
||||
}
|
||||
return task.Properties.OriginModelName
|
||||
}
|
||||
@@ -237,15 +246,7 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo
|
||||
return
|
||||
}
|
||||
|
||||
// 获取模型名称
|
||||
var taskData map[string]interface{}
|
||||
if err := common.Unmarshal(task.Data, &taskData); err != nil {
|
||||
return
|
||||
}
|
||||
modelName, ok := taskData["model"].(string)
|
||||
if !ok || modelName == "" {
|
||||
return
|
||||
}
|
||||
modelName := taskModelName(task)
|
||||
|
||||
// 获取模型价格和倍率
|
||||
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
|
||||
|
||||
@@ -3,12 +3,14 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -125,7 +127,7 @@ func makeTask(userId, channelId, quota, tokenId int, billingSource string, subsc
|
||||
BillingContext: &model.TaskBillingContext{
|
||||
ModelPrice: 0.02,
|
||||
GroupRatio: 1.0,
|
||||
ModelName: "test-model",
|
||||
OriginModelName: "test-model",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -604,3 +606,107 @@ func TestNonTerminalUpdate_NoBilling(t *testing.T) {
|
||||
require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
|
||||
assert.Equal(t, "50%", reloaded.Progress)
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// Mock adaptor for settleTaskBillingOnComplete tests
|
||||
// ===========================================================================
|
||||
|
||||
type mockAdaptor struct {
|
||||
adjustReturn int
|
||||
}
|
||||
|
||||
func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {}
|
||||
func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { return nil, nil }
|
||||
func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil }
|
||||
func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
|
||||
return m.adjustReturn
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// PerCallBilling tests — settleTaskBillingOnComplete
|
||||
// ===========================================================================
|
||||
|
||||
func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 30, 30, 30
|
||||
const initQuota, preConsumed = 10000, 5000
|
||||
const tokenRemain = 8000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
task.PrivateData.BillingContext.PerCallBilling = true
|
||||
|
||||
adaptor := &mockAdaptor{adjustReturn: 2000}
|
||||
taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
|
||||
|
||||
settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
|
||||
|
||||
// Per-call: no adjustment despite adaptor returning 2000
|
||||
assert.Equal(t, initQuota, getUserQuota(t, userID))
|
||||
assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
|
||||
assert.Equal(t, preConsumed, task.Quota)
|
||||
assert.Equal(t, int64(0), countLogs(t))
|
||||
}
|
||||
|
||||
func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 31, 31, 31
|
||||
const initQuota, preConsumed = 10000, 4000
|
||||
const tokenRemain = 7000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
task.PrivateData.BillingContext.PerCallBilling = true
|
||||
|
||||
adaptor := &mockAdaptor{adjustReturn: 0}
|
||||
taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999}
|
||||
|
||||
settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
|
||||
|
||||
// Per-call: no recalculation by tokens
|
||||
assert.Equal(t, initQuota, getUserQuota(t, userID))
|
||||
assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
|
||||
assert.Equal(t, preConsumed, task.Quota)
|
||||
assert.Equal(t, int64(0), countLogs(t))
|
||||
}
|
||||
|
||||
func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 32, 32, 32
|
||||
const initQuota, preConsumed = 10000, 5000
|
||||
const adaptorQuota = 3000
|
||||
const tokenRemain = 8000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
// PerCallBilling defaults to false
|
||||
|
||||
adaptor := &mockAdaptor{adjustReturn: adaptorQuota}
|
||||
taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
|
||||
|
||||
settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
|
||||
|
||||
// Non-per-call: adaptor adjustment applies (refund 2000)
|
||||
assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID))
|
||||
assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID))
|
||||
assert.Equal(t, adaptorQuota, task.Quota)
|
||||
|
||||
log := getLastLog(t)
|
||||
require.NotNil(t, log)
|
||||
assert.Equal(t, model.LogTypeRefund, log.Type)
|
||||
}
|
||||
|
||||
@@ -467,6 +467,11 @@ func truncateBase64(s string) string {
|
||||
// 2. taskResult.TotalTokens > 0 → 按 token 重算
|
||||
// 3. 都不满足 → 保持预扣额度不变
|
||||
func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) {
|
||||
// 0. 按次计费的任务不做差额结算
|
||||
if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID))
|
||||
return
|
||||
}
|
||||
// 1. 优先让 adaptor 决定最终额度
|
||||
if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 {
|
||||
RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整")
|
||||
|
||||
@@ -84,8 +84,8 @@ function renderDuration(submit_time, finishTime) {
|
||||
|
||||
// 返回带有样式的颜色标签
|
||||
return (
|
||||
<Tag color={color} shape='circle' prefixIcon={<Clock size={14} />}>
|
||||
{durationSec} 秒
|
||||
<Tag color={color} shape='circle'>
|
||||
{durationSec} s
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
@@ -149,7 +149,7 @@ const renderPlatform = (platform, t) => {
|
||||
);
|
||||
if (option) {
|
||||
return (
|
||||
<Tag color={option.color} shape='circle' prefixIcon={<Video size={14} />}>
|
||||
<Tag color={option.color} shape='circle'>
|
||||
{option.label}
|
||||
</Tag>
|
||||
);
|
||||
@@ -157,13 +157,13 @@ const renderPlatform = (platform, t) => {
|
||||
switch (platform) {
|
||||
case 'suno':
|
||||
return (
|
||||
<Tag color='green' shape='circle' prefixIcon={<Music size={14} />}>
|
||||
<Tag color='green' shape='circle'>
|
||||
Suno
|
||||
</Tag>
|
||||
);
|
||||
default:
|
||||
return (
|
||||
<Tag color='white' shape='circle' prefixIcon={<HelpCircle size={14} />}>
|
||||
<Tag color='white' shape='circle'>
|
||||
{t('未知')}
|
||||
</Tag>
|
||||
);
|
||||
@@ -240,7 +240,6 @@ export const getTaskLogsColumns = ({
|
||||
openContentModal,
|
||||
isAdminUser,
|
||||
openVideoModal,
|
||||
showUserInfoFunc,
|
||||
}) => {
|
||||
return [
|
||||
{
|
||||
@@ -278,7 +277,6 @@ export const getTaskLogsColumns = ({
|
||||
color={colors[parseInt(text) % colors.length]}
|
||||
size='large'
|
||||
shape='circle'
|
||||
prefixIcon={<Hash size={14} />}
|
||||
onClick={() => {
|
||||
copyText(text);
|
||||
}}
|
||||
@@ -294,7 +292,7 @@ export const getTaskLogsColumns = ({
|
||||
{
|
||||
key: COLUMN_KEYS.USERNAME,
|
||||
title: t('用户'),
|
||||
dataIndex: 'user_id',
|
||||
dataIndex: 'username',
|
||||
render: (userId, record, index) => {
|
||||
if (!isAdminUser) {
|
||||
return <></>;
|
||||
@@ -302,22 +300,14 @@ export const getTaskLogsColumns = ({
|
||||
const displayText = String(record.username || userId || '?');
|
||||
return (
|
||||
<Space>
|
||||
<Tooltip content={displayText}>
|
||||
<Avatar
|
||||
size='extra-small'
|
||||
color={stringToColor(displayText)}
|
||||
style={{ cursor: 'pointer' }}
|
||||
onClick={() => showUserInfoFunc && showUserInfoFunc(userId)}
|
||||
>
|
||||
{displayText.slice(0, 1)}
|
||||
</Avatar>
|
||||
</Tooltip>
|
||||
<Typography.Text
|
||||
ellipsis={{ showTooltip: true }}
|
||||
style={{ cursor: 'pointer', color: 'var(--semi-color-primary)' }}
|
||||
onClick={() => showUserInfoFunc && showUserInfoFunc(userId)}
|
||||
<Avatar
|
||||
size='extra-small'
|
||||
color={stringToColor(displayText)}
|
||||
>
|
||||
{userId}
|
||||
{displayText.slice(0, 1)}
|
||||
</Avatar>
|
||||
<Typography.Text>
|
||||
{displayText}
|
||||
</Typography.Text>
|
||||
</Space>
|
||||
);
|
||||
|
||||
@@ -25,7 +25,6 @@ import TaskLogsActions from './TaskLogsActions';
|
||||
import TaskLogsFilters from './TaskLogsFilters';
|
||||
import ColumnSelectorModal from './modals/ColumnSelectorModal';
|
||||
import ContentModal from './modals/ContentModal';
|
||||
import UserInfoModal from '../usage-logs/modals/UserInfoModal';
|
||||
import { useTaskLogsData } from '../../../hooks/task-logs/useTaskLogsData';
|
||||
import { useIsMobile } from '../../../hooks/common/useIsMobile';
|
||||
import { createCardProPagination } from '../../../helpers/utils';
|
||||
@@ -46,7 +45,6 @@ const TaskLogsPage = () => {
|
||||
modalContent={taskLogsData.videoUrl}
|
||||
isVideo={true}
|
||||
/>
|
||||
<UserInfoModal {...taskLogsData} />
|
||||
|
||||
<Layout>
|
||||
<CardPro
|
||||
|
||||
Reference in New Issue
Block a user