Files
new-api/relay/relay_task.go
CaIon ec5c6b28ea feat(task): add model redirection, per-call billing, and multipart retry fix for async tasks
1. Async task model redirection (aligned with sync tasks):
   - Integrate ModelMappedHelper in RelayTaskSubmit after model name
     determination, populating OriginModelName / UpstreamModelName on RelayInfo.
   - All task adaptors now send UpstreamModelName to upstream providers:
     - Gemini & Vertex: BuildRequestURL uses UpstreamModelName.
     - Doubao & Ali: BuildRequestBody conditionally overwrites body.Model.
     - Vidu, Kling, Hailuo, Jimeng: convertToRequestPayload accepts RelayInfo
       and unconditionally uses info.UpstreamModelName.
     - Sora: BuildRequestBody parses JSON and multipart bodies to replace
       the "model" field with UpstreamModelName.
   - Frontend log visibility: LogTaskConsumption and taskBillingOther now
     emit is_model_mapped / upstream_model_name in the "other" JSON field.
   - Billing safety: RecalculateTaskQuotaByTokens reads model name from
     BillingContext.OriginModelName (via taskModelName) instead of
     task.Data["model"], preventing billing leaks from upstream model names.

2. Per-call billing (TaskPricePatches lifecycle):
   - Rename TaskBillingContext.ModelName → OriginModelName; add PerCallBilling
     bool field, populated from TaskPricePatches at submission time.
   - settleTaskBillingOnComplete short-circuits when PerCallBilling is true,
     skipping both adaptor adjustments and token-based recalculation.
   - Remove ModelName from TaskSubmitResult; use relayInfo.OriginModelName
     consistently in controller/relay.go for billing context and logging.

3. Multipart retry boundary mismatch fix:
   - Root cause: after Sora (or OpenAI audio) rebuilds a multipart body with a
     new boundary and overwrites c.Request.Header["Content-Type"], subsequent
     calls to ParseMultipartFormReusable on retry would parse the cached
     original body with the wrong boundary, causing "NextPart: EOF".
   - Fix: ParseMultipartFormReusable now caches the original Content-Type in
     gin context key "_original_multipart_ct" on first call and reuses it for
     all subsequent parses, making multipart parsing retry-safe globally.
   - Sora adaptor reverted to the standard pattern (direct header set/get),
     which is now safe thanks to the root fix.

4. Tests:
   - task_billing_test.go: update makeTask to use OriginModelName; add
     PerCallBilling settlement tests (skip adaptor adjust, skip token recalc);
     add non-per-call adaptor adjustment test with refund verification.
2026-02-22 16:33:00 +08:00

559 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package relay
import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/gin-gonic/gin"
)
type TaskSubmitResult struct {
UpstreamTaskID string
TaskData []byte
Platform constant.TaskPlatform
Quota int
//PerCallPrice types.PriceData
}
// ResolveOriginTask 处理基于已有任务的提交remix / continuation
// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道
// (通过 info.LockedChannel重试时复用同一渠道并轮换 key
// 以及提取 OtherRatios时长、分辨率
// 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。
func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
// 检测 remix action
path := c.Request.URL.Path
if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
info.Action = constant.TaskActionRemix
}
if info.Action == constant.TaskActionRemix {
videoID := c.Param("video_id")
if strings.TrimSpace(videoID) == "" {
return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest)
}
info.OriginTaskID = videoID
}
if info.OriginTaskID == "" {
return nil
}
// 查找原始任务
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
if err != nil {
return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
}
if !exist {
return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
}
// 从原始任务推导模型名称
if info.OriginModelName == "" {
if originTask.Properties.OriginModelName != "" {
info.OriginModelName = originTask.Properties.OriginModelName
} else if originTask.Properties.UpstreamModelName != "" {
info.OriginModelName = originTask.Properties.UpstreamModelName
} else {
var taskData map[string]interface{}
_ = common.Unmarshal(originTask.Data, &taskData)
if m, ok := taskData["model"].(string); ok && m != "" {
info.OriginModelName = m
}
}
}
// 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key
ch, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
}
if ch.Status != common.ChannelStatusEnabled {
return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
}
info.LockedChannel = ch
if originTask.ChannelId != info.ChannelId {
key, _, newAPIError := ch.GetNextEnabledKey()
if newAPIError != nil {
return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
}
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type)
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL())
common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
info.ChannelBaseUrl = ch.GetBaseURL()
info.ChannelId = originTask.ChannelId
info.ChannelType = ch.Type
info.ApiKey = key
}
// 提取 remix 参数(时长、分辨率 → OtherRatios
if info.Action == constant.TaskActionRemix {
if originTask.PrivateData.BillingContext != nil {
// 新的 remix 逻辑:直接从原始任务的 BillingContext 中提取 OtherRatios如果存在
for s, f := range originTask.PrivateData.BillingContext.OtherRatios {
info.PriceData.AddOtherRatio(s, f)
}
} else {
// 旧的 remix 逻辑:直接从 task data 解析 seconds 和 size如果存在
var taskData map[string]interface{}
_ = common.Unmarshal(originTask.Data, &taskData)
secondsStr, _ := taskData["seconds"].(string)
seconds, _ := strconv.Atoi(secondsStr)
if seconds <= 0 {
seconds = 4
}
sizeStr, _ := taskData["size"].(string)
if info.PriceData.OtherRatios == nil {
info.PriceData.OtherRatios = map[string]float64{}
}
info.PriceData.OtherRatios["seconds"] = float64(seconds)
info.PriceData.OtherRatios["size"] = 1
if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
info.PriceData.OtherRatios["size"] = 1.666667
}
}
}
return nil
}
// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 →
// 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→
// 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。
// 控制器负责 defer Refund 和成功后 Settle。
func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
info.InitChannelMeta(c)
// 1. 确定 platform → 创建适配器 → 验证请求
platform := constant.TaskPlatform(c.GetString("platform"))
if platform == "" {
platform = GetTaskPlatform(c)
}
adaptor := GetTaskAdaptor(platform)
if adaptor == nil {
return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
}
adaptor.Init(info)
if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil {
return nil, taskErr
}
// 2. 确定模型名称
modelName := info.OriginModelName
if modelName == "" {
modelName = service.CoverTaskActionToModelName(platform, info.Action)
}
// 2.5 应用渠道的模型映射(与同步任务对齐)
info.OriginModelName = modelName
info.UpstreamModelName = modelName
if err := helper.ModelMappedHelper(c, info, nil); err != nil {
return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest)
}
// 3. 预生成公开 task ID仅首次
if info.PublicTaskID == "" {
info.PublicTaskID = model.GenerateTaskID()
}
// 4. 价格计算:基础模型价格
info.OriginModelName = modelName
info.PriceData = helper.ModelPriceHelperPerCall(c, info)
// 5. 计费估算:让适配器根据用户请求提供 OtherRatios时长、分辨率等
// 必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData
// ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios此处合并。
if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 {
for k, v := range estimatedRatios {
info.PriceData.AddOtherRatio(k, v)
}
}
// 6. 将 OtherRatios 应用到基础额度
if !common.StringsContains(constant.TaskPricePatches, modelName) {
for _, ra := range info.PriceData.OtherRatios {
if ra != 1.0 {
info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra)
}
}
}
// 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
if info.Billing == nil && !info.PriceData.FreeModel {
info.ForcePreConsume = true
if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
return nil, service.TaskErrorFromAPIError(apiErr)
}
}
// 8. 构建请求体
requestBody, err := adaptor.BuildRequestBody(c, info)
if err != nil {
return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
}
// 9. 发送请求
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
if resp != nil && resp.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(resp.Body)
return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
}
// 10. 返回 OtherRatios 给下游header 必须在 DoResponse 写 body 之前设置)
otherRatios := info.PriceData.OtherRatios
if otherRatios == nil {
otherRatios = map[string]float64{}
}
ratiosJSON, _ := common.Marshal(otherRatios)
c.Header("X-New-Api-Other-Ratios", string(ratiosJSON))
// 11. 解析响应
upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
if taskErr != nil {
return nil, taskErr
}
// 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios
finalQuota := info.PriceData.Quota
if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 {
// 基于调整后的 ratios 重新计算 quota
finalQuota = recalcQuotaFromRatios(info, adjustedRatios)
info.PriceData.OtherRatios = adjustedRatios
info.PriceData.Quota = finalQuota
}
return &TaskSubmitResult{
UpstreamTaskID: upstreamTaskID,
TaskData: taskData,
Platform: platform,
Quota: finalQuota,
}, nil
}
// recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。
// 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。
func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int {
// 从 PriceData 获取不含 OtherRatios 的基础价格
baseQuota := info.PriceData.Quota
// 先除掉原有的 OtherRatios 恢复基础额度
for _, ra := range info.PriceData.OtherRatios {
if ra != 1.0 && ra > 0 {
baseQuota = int(float64(baseQuota) / ra)
}
}
// 应用新的 ratios
result := float64(baseQuota)
for _, ra := range ratios {
if ra != 1.0 {
result *= ra
}
}
return int(result)
}
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
}
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
respBuilder, ok := fetchRespBuilders[relayMode]
if !ok {
taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
}
respBody, taskErr := respBuilder(c)
if taskErr != nil {
return taskErr
}
if len(respBody) == 0 {
respBody = []byte("{\"code\":\"success\",\"data\":null}")
}
c.Writer.Header().Set("Content-Type", "application/json")
_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
return
}
return
}
func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
userId := c.GetInt("id")
var condition = struct {
IDs []any `json:"ids"`
Action string `json:"action"`
}{}
err := c.BindJSON(&condition)
if err != nil {
taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
return
}
var tasks []any
if len(condition.IDs) > 0 {
taskModels, err := model.GetByTaskIds(userId, condition.IDs)
if err != nil {
taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
return
}
for _, task := range taskModels {
tasks = append(tasks, TaskModel2Dto(task))
}
} else {
tasks = make([]any, 0)
}
respBody, err = common.Marshal(dto.TaskResponse[[]any]{
Code: "success",
Data: tasks,
})
return
}
func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
taskId := c.Param("id")
userId := c.GetInt("id")
originTask, exist, err := model.GetByTaskId(userId, taskId)
if err != nil {
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
return
}
if !exist {
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
return
}
respBody, err = common.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: TaskModel2Dto(originTask),
})
return
}
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
taskId := c.Param("task_id")
if taskId == "" {
taskId = c.GetString("task_id")
}
userId := c.GetInt("id")
originTask, exist, err := model.GetByTaskId(userId, taskId)
if err != nil {
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
return
}
if !exist {
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
return
}
isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/")
// Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态
if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 {
respBody = realtimeResp
return
}
// OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo
if isOpenAIVideoAPI {
adaptor := GetTaskAdaptor(originTask.Platform)
if adaptor == nil {
taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
return
}
if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok {
openAIVideoData, err := converter.ConvertToOpenAIVideo(originTask)
if err != nil {
taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError)
return
}
respBody = openAIVideoData
return
}
taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented)
return
}
// 通用 TaskDto 格式
respBody, err = common.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: TaskModel2Dto(originTask),
})
if err != nil {
taskResp = service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
}
return
}
// tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。
// 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。
// 当非 OpenAI Video API 时,还会构建自定义格式的响应体。
func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte {
channelModel, err := model.GetChannelById(task.ChannelId, true)
if err != nil {
return nil
}
if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
return nil
}
baseURL := constant.ChannelBaseURLs[channelModel.Type]
if channelModel.GetBaseURL() != "" {
baseURL = channelModel.GetBaseURL()
}
proxy := channelModel.GetSetting().Proxy
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
if adaptor == nil {
return nil
}
resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
"task_id": task.GetUpstreamTaskID(),
"action": task.Action,
}, proxy)
if err != nil || resp == nil {
return nil
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil
}
ti, err := adaptor.ParseTaskResult(body)
if err != nil || ti == nil {
return nil
}
snap := task.Snapshot()
// 将上游最新状态更新到 task
if ti.Status != "" {
task.Status = model.TaskStatus(ti.Status)
}
if ti.Progress != "" {
task.Progress = ti.Progress
}
if strings.HasPrefix(ti.Url, "data:") {
// data: URI — kept in Data, not ResultURL
} else if ti.Url != "" {
task.PrivateData.ResultURL = ti.Url
} else if task.Status == model.TaskStatusSuccess {
// No URL from adaptor — construct proxy URL using public task ID
task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
}
if !snap.Equal(task.Snapshot()) {
_, _ = task.UpdateWithStatus(snap.Status)
}
// OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理
if isOpenAIVideoAPI {
return nil
}
// 非 OpenAI Video API: 构建自定义格式响应
format := detectVideoFormat(body)
out := map[string]any{
"error": nil,
"format": format,
"metadata": nil,
"status": mapTaskStatusToSimple(task.Status),
"task_id": task.TaskID,
"url": task.GetResultURL(),
}
respBody, _ := common.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: out,
})
return respBody
}
// detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式
func detectVideoFormat(rawBody []byte) string {
var raw map[string]any
if err := common.Unmarshal(rawBody, &raw); err != nil {
return "mp4"
}
respObj, ok := raw["response"].(map[string]any)
if !ok {
return "mp4"
}
vids, ok := respObj["videos"].([]any)
if !ok || len(vids) == 0 {
return "mp4"
}
v0, ok := vids[0].(map[string]any)
if !ok {
return "mp4"
}
mt, ok := v0["mimeType"].(string)
if !ok || mt == "" || strings.Contains(mt, "mp4") {
return "mp4"
}
return mt
}
// mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串
func mapTaskStatusToSimple(status model.TaskStatus) string {
switch status {
case model.TaskStatusSuccess:
return "succeeded"
case model.TaskStatusFailure:
return "failed"
case model.TaskStatusQueued, model.TaskStatusSubmitted:
return "queued"
default:
return "processing"
}
}
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
return &dto.TaskDto{
ID: task.ID,
CreatedAt: task.CreatedAt,
UpdatedAt: task.UpdatedAt,
TaskID: task.TaskID,
Platform: string(task.Platform),
UserId: task.UserId,
Group: task.Group,
ChannelId: task.ChannelId,
Quota: task.Quota,
Action: task.Action,
Status: string(task.Status),
FailReason: task.FailReason,
ResultURL: task.GetResultURL(),
SubmitTime: task.SubmitTime,
StartTime: task.StartTime,
FinishTime: task.FinishTime,
Progress: task.Progress,
Properties: task.Properties,
Username: task.Username,
Data: task.Data,
}
}