refactor(task): extract billing and polling logic from controller to service layer

Restructure the task relay system for better separation of concerns:
- Extract task billing into service/task_billing.go with unified settlement flow
- Move task polling loop from controller to service/task_polling.go (supports Suno + video platforms)
- Split RelayTask into fetch/submit paths with dedicated retry logic (taskSubmitWithRetry)
- Add TaskDto, TaskResponse generics, and FetchReq to dto/task.go
- Add taskcommon/helpers.go for shared task adaptor utilities
- Remove controller/task_video.go (logic consolidated into service layer)
- Update all task adaptors (ali, doubao, gemini, hailuo, jimeng, kling, sora, suno, vertex, vidu)
- Simplify frontend task logs to use new TaskDto response format
This commit is contained in:
CaIon
2026-02-10 20:40:33 +08:00
parent e0a6ee1cb8
commit 9e3954428d
34 changed files with 1465 additions and 1191 deletions

View File

@@ -193,6 +193,11 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro
// shouldTrust 统一信任额度检查,适用于钱包和订阅。
func (s *BillingSession) shouldTrust(c *gin.Context) bool {
// 异步任务ForcePreConsume=true必须预扣全额不允许信任旁路
if s.relayInfo.ForcePreConsume {
return false
}
trustQuota := common.GetTrustQuota()
if trustQuota <= 0 {
return false

View File

@@ -206,3 +206,16 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
return taskError
}
// TaskErrorFromAPIError 将 PreConsumeBilling 返回的 NewAPIError 转换为 TaskError。
func TaskErrorFromAPIError(apiErr *types.NewAPIError) *dto.TaskError {
if apiErr == nil {
return nil
}
return &dto.TaskError{
Code: string(apiErr.GetErrorCode()),
Message: apiErr.Err.Error(),
StatusCode: apiErr.StatusCode,
Error: apiErr.Err,
}
}

View File

@@ -204,7 +204,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
return info
}
func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PerCallPriceData) map[string]interface{} {
func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PriceData) map[string]interface{} {
other := make(map[string]interface{})
other["model_price"] = priceData.ModelPrice
other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio

227
service/task_billing.go Normal file
View File

@@ -0,0 +1,227 @@
package service
import (
"context"
"fmt"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
// 实际扣费已由 BillingSessionPreConsumeBilling + SettleBilling完成。
func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("操作 %s", info.Action)
// 支持任务仅按次计费
if common.StringsContains(constant.TaskPricePatches, modelName) {
logContent = fmt.Sprintf("%s按次计费", logContent)
} else {
if len(info.PriceData.OtherRatios) > 0 {
var contents []string
for key, ra := range info.PriceData.OtherRatios {
if 1.0 != ra {
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
}
}
if len(contents) > 0 {
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
}
}
}
other := make(map[string]interface{})
other["request_path"] = c.Request.URL.Path
other["model_price"] = info.PriceData.ModelPrice
other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio
if info.PriceData.GroupRatioInfo.HasSpecialRatio {
other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
}
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
ChannelId: info.ChannelId,
ModelName: modelName,
TokenName: tokenName,
Quota: info.PriceData.Quota,
Content: logContent,
TokenId: info.TokenId,
Group: info.UsingGroup,
Other: other,
})
model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota)
model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota)
}
// ---------------------------------------------------------------------------
// 异步任务计费辅助函数
// ---------------------------------------------------------------------------
// resolveTokenKey 通过 TokenId 运行时获取令牌 Key用于 Redis 缓存操作)。
// 如果令牌已被删除或查询失败,返回空字符串。
func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string {
token, err := model.GetTokenById(tokenId)
if err != nil {
logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error()))
return ""
}
return token.Key
}
// taskIsSubscription 判断任务是否通过订阅计费。
func taskIsSubscription(task *model.Task) bool {
return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0
}
// taskAdjustFunding 调整任务的资金来源钱包或订阅delta > 0 表示扣费delta < 0 表示退还。
func taskAdjustFunding(task *model.Task, delta int) error {
if taskIsSubscription(task) {
return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta))
}
if delta > 0 {
return model.DecreaseUserQuota(task.UserId, delta)
}
return model.IncreaseUserQuota(task.UserId, -delta, false)
}
// taskAdjustTokenQuota 调整任务的令牌额度delta > 0 表示扣费delta < 0 表示退还。
// 需要通过 resolveTokenKey 运行时获取 key不从 PrivateData 中读取)。
func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) {
if task.PrivateData.TokenId <= 0 || delta == 0 {
return
}
tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID)
if tokenKey == "" {
return
}
var err error
if delta > 0 {
err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta)
} else {
err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta)
}
if err != nil {
logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error()))
}
}
// RefundTaskQuota 统一的任务失败退款逻辑。
// 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。
func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) {
quota := task.Quota
if quota == 0 {
return
}
// 1. 退还资金来源(钱包或订阅)
if err := taskAdjustFunding(task, -quota); err != nil {
logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error()))
return
}
// 2. 退还令牌额度
taskAdjustTokenQuota(ctx, task, -quota)
// 3. 记录日志
logContent := fmt.Sprintf("异步任务执行失败 %s补偿 %s原因%s", task.TaskID, logger.LogQuota(quota), reason)
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
// RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。
// 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度,
// 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。
func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) {
if totalTokens <= 0 {
return
}
// 获取模型名称
var taskData map[string]interface{}
if err := common.Unmarshal(task.Data, &taskData); err != nil {
return
}
modelName, ok := taskData["model"].(string)
if !ok || modelName == "" {
return
}
// 获取模型价格和倍率
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
// 只有配置了倍率(非固定价格)时才按 token 重新计费
if !hasRatioSetting || modelRatio <= 0 {
return
}
// 获取用户和组的倍率信息
group := task.Group
if group == "" {
user, err := model.GetUserById(task.UserId, false)
if err == nil {
group = user.Group
}
}
if group == "" {
return
}
groupRatio := ratio_setting.GetGroupRatio(group)
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
var finalGroupRatio float64
if hasUserGroupRatio {
finalGroupRatio = userGroupRatio
} else {
finalGroupRatio = groupRatio
}
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
// 计算差额(正数=需要补扣,负数=需要退还)
preConsumedQuota := task.Quota
quotaDelta := actualQuota - preConsumedQuota
if quotaDelta == 0 {
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%stokens%d",
task.TaskID, logger.LogQuota(actualQuota), totalTokens))
return
}
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 差额结算delta=%s实际%s预扣%stokens%d",
task.TaskID,
logger.LogQuota(quotaDelta),
logger.LogQuota(actualQuota),
logger.LogQuota(preConsumedQuota),
totalTokens,
))
// 调整资金来源
if err := taskAdjustFunding(task, quotaDelta); err != nil {
logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
return
}
// 调整令牌额度
taskAdjustTokenQuota(ctx, task, quotaDelta)
// 更新统计(仅补扣时更新,退还不影响已用统计)
if quotaDelta > 0 {
model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
}
task.Quota = actualQuota
var action string
if quotaDelta > 0 {
action = "补扣费"
} else {
action = "退还"
}
logContent := fmt.Sprintf("视频任务成功%s模型倍率 %.2f,分组倍率 %.2ftokens %d预扣费 %s实际扣费 %s",
action, modelRatio, finalGroupRatio, totalTokens,
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}

446
service/task_polling.go Normal file
View File

@@ -0,0 +1,446 @@
package service
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"sort"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/samber/lo"
)
// TaskPollingAdaptor 定义轮询所需的最小适配器接口,避免 service -> relay 的循环依赖
type TaskPollingAdaptor interface {
Init(info *relaycommon.RelayInfo)
FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error)
ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error)
}
// GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。
// 打破 service -> relay -> relay/channel -> service 的循环依赖。
var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor
// TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务
func TaskPollingLoop() {
for {
time.Sleep(time.Duration(15) * time.Second)
common.SysLog("任务进度轮询开始")
ctx := context.TODO()
allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
platformTask := make(map[constant.TaskPlatform][]*model.Task)
for _, t := range allTasks {
platformTask[t.Platform] = append(platformTask[t.Platform], t)
}
for platform, tasks := range platformTask {
if len(tasks) == 0 {
continue
}
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Task)
nullTaskIds := make([]int64, 0)
for _, task := range tasks {
upstreamID := task.GetUpstreamTaskID()
if upstreamID == "" {
// 统计失败的未完成任务
nullTaskIds = append(nullTaskIds, task.ID)
continue
}
taskM[upstreamID] = task
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID)
}
if len(nullTaskIds) > 0 {
err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
} else {
logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
}
}
if len(taskChannelM) == 0 {
continue
}
DispatchPlatformUpdate(platform, taskChannelM, taskM)
}
common.SysLog("任务进度轮询完成")
}
}
// DispatchPlatformUpdate 按平台分发轮询更新
func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
switch platform {
case constant.TaskPlatformMidjourney:
// MJ 轮询由其自身处理,这里预留入口
case constant.TaskPlatformSuno:
_ = UpdateSunoTasks(context.Background(), taskChannelM, taskM)
default:
if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil {
common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err))
}
}
}
// UpdateSunoTasks 按渠道更新所有 Suno 任务
func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
for channelId, taskIds := range taskChannelM {
err := updateSunoTasks(ctx, channelId, taskIds, taskM)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
}
}
return nil
}
func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
}
ch, err := model.CacheGetChannel(channelId)
if err != nil {
common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
// Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
var failedIDs []int64
for _, upstreamID := range taskIds {
if t, ok := taskM[upstreamID]; ok {
failedIDs = append(failedIDs, t.ID)
}
}
err = model.TaskBulkUpdateByID(failedIDs, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", err))
}
return err
}
adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno)
if adaptor == nil {
return errors.New("adaptor not found")
}
proxy := ch.GetSetting().Proxy
resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{
"ids": taskIds,
}, proxy)
if err != nil {
common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
return err
}
if resp.StatusCode != http.StatusOK {
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
return err
}
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
err = common.Unmarshal(responseBody, &responseItems)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
return err
}
if !responseItems.IsSuccess() {
common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
return err
}
for _, responseItem := range responseItems.Data {
task := taskM[responseItem.TaskID]
if !taskNeedsUpdate(task, responseItem) {
continue
}
task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
task.Progress = "100%"
RefundTaskQuota(ctx, task, task.FailReason)
}
if responseItem.Status == model.TaskStatusSuccess {
task.Progress = "100%"
}
task.Data = responseItem.Data
err = task.Update()
if err != nil {
common.SysLog("UpdateSunoTask task error: " + err.Error())
}
}
return nil
}
// taskNeedsUpdate 检查 Suno 任务是否需要更新
func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
if oldTask.SubmitTime != newTask.SubmitTime {
return true
}
if oldTask.StartTime != newTask.StartTime {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if string(oldTask.Status) != newTask.Status {
return true
}
if oldTask.FailReason != newTask.FailReason {
return true
}
if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
return true
}
oldData, _ := common.Marshal(oldTask.Data)
newData, _ := common.Marshal(newTask.Data)
sort.Slice(oldData, func(i, j int) bool {
return oldData[i] < oldData[j]
})
sort.Slice(newData, func(i, j int) bool {
return newData[i] < newData[j]
})
if string(oldData) != string(newData) {
return true
}
return false
}
// UpdateVideoTasks 按渠道更新所有视频任务
func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
for channelId, taskIds := range taskChannelM {
if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil {
logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
}
}
return nil
}
func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
}
cacheGetChannel, err := model.CacheGetChannel(channelId)
if err != nil {
// Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
var failedIDs []int64
for _, upstreamID := range taskIds {
if t, ok := taskM[upstreamID]; ok {
failedIDs = append(failedIDs, t.ID)
}
}
errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if errUpdate != nil {
common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
}
return fmt.Errorf("CacheGetChannel failed: %w", err)
}
adaptor := GetTaskAdaptorFunc(platform)
if adaptor == nil {
return fmt.Errorf("video adaptor not found")
}
info := &relaycommon.RelayInfo{}
info.ChannelMeta = &relaycommon.ChannelMeta{
ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
}
info.ApiKey = cacheGetChannel.Key
adaptor.Init(info)
for _, taskId := range taskIds {
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
}
}
return nil
}
func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error {
baseURL := constant.ChannelBaseURLs[ch.Type]
if ch.GetBaseURL() != "" {
baseURL = ch.GetBaseURL()
}
proxy := ch.GetSetting().Proxy
task := taskM[taskId]
if task == nil {
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
return fmt.Errorf("task %s not found", taskId)
}
key := ch.Key
privateData := task.PrivateData
if privateData.Key != "" {
key = privateData.Key
}
resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
"task_id": task.GetUpstreamTaskID(),
"action": task.Action,
}, proxy)
if err != nil {
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
}
logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody)))
taskResult := &relaycommon.TaskInfo{}
// try parse as New API response format
var responseItems dto.TaskResponse[model.Task]
if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems))
t := responseItems.Data
taskResult.TaskID = t.TaskID
taskResult.Status = string(t.Status)
taskResult.Url = t.GetResultURL()
taskResult.Progress = t.Progress
taskResult.Reason = t.FailReason
task.Data = t.Data
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
} else {
task.Data = redactVideoResponseBody(responseBody)
}
logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult))
now := time.Now().Unix()
if taskResult.Status == "" {
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
}
// 记录原本的状态,防止重复退款
shouldRefund := false
quota := task.Quota
preStatus := task.Status
task.Status = model.TaskStatus(taskResult.Status)
switch taskResult.Status {
case model.TaskStatusSubmitted:
task.Progress = taskcommon.ProgressSubmitted
case model.TaskStatusQueued:
task.Progress = taskcommon.ProgressQueued
case model.TaskStatusInProgress:
task.Progress = taskcommon.ProgressInProgress
if task.StartTime == 0 {
task.StartTime = now
}
case model.TaskStatusSuccess:
task.Progress = taskcommon.ProgressComplete
if task.FinishTime == 0 {
task.FinishTime = now
}
if strings.HasPrefix(taskResult.Url, "data:") {
// data: URI (e.g. Vertex base64 encoded video) — keep in Data, not in ResultURL
} else if taskResult.Url != "" {
// Direct upstream URL (e.g. Kling, Ali, Doubao, etc.)
task.PrivateData.ResultURL = taskResult.Url
} else {
// No URL from adaptor — construct proxy URL using public task ID
task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
}
// 如果返回了 total_tokens根据模型倍率重新计费
if taskResult.TotalTokens > 0 {
RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens)
}
case model.TaskStatusFailure:
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
task.Status = model.TaskStatusFailure
task.Progress = taskcommon.ProgressComplete
if task.FinishTime == 0 {
task.FinishTime = now
}
task.FailReason = taskResult.Reason
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
taskResult.Progress = taskcommon.ProgressComplete
if quota != 0 {
if preStatus != model.TaskStatusFailure {
shouldRefund = true
} else {
logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
}
}
default:
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
}
if taskResult.Progress != "" {
task.Progress = taskResult.Progress
}
if err := task.Update(); err != nil {
common.SysLog("UpdateVideoTask task error: " + err.Error())
shouldRefund = false
}
if shouldRefund {
RefundTaskQuota(ctx, task, task.FailReason)
}
return nil
}
func redactVideoResponseBody(body []byte) []byte {
var m map[string]any
if err := common.Unmarshal(body, &m); err != nil {
return body
}
resp, _ := m["response"].(map[string]any)
if resp != nil {
delete(resp, "bytesBase64Encoded")
if v, ok := resp["video"].(string); ok {
resp["video"] = truncateBase64(v)
}
if vs, ok := resp["videos"].([]any); ok {
for i := range vs {
if vm, ok := vs[i].(map[string]any); ok {
delete(vm, "bytesBase64Encoded")
}
}
}
}
b, err := common.Marshal(m)
if err != nil {
return body
}
return b
}
func truncateBase64(s string) string {
const maxKeep = 256
if len(s) <= maxKeep {
return s
}
return s[:maxKeep] + "..."
}