From 9e3954428dc0bf6bf5c29eed415d8b213affad22 Mon Sep 17 00:00:00 2001 From: CaIon Date: Tue, 10 Feb 2026 20:40:33 +0800 Subject: [PATCH] refactor(task): extract billing and polling logic from controller to service layer Restructure the task relay system for better separation of concerns: - Extract task billing into service/task_billing.go with unified settlement flow - Move task polling loop from controller to service/task_polling.go (supports Suno + video platforms) - Split RelayTask into fetch/submit paths with dedicated retry logic (taskSubmitWithRetry) - Add TaskDto, TaskResponse generics, and FetchReq to dto/task.go - Add taskcommon/helpers.go for shared task adaptor utilities - Remove controller/task_video.go (logic consolidated into service layer) - Update all task adaptors (ali, doubao, gemini, hailuo, jimeng, kling, sora, suno, vertex, vidu) - Simplify frontend task logs to use new TaskDto response format --- controller/relay.go | 122 +++- controller/task.go | 228 +------ controller/task_video.go | 313 ---------- controller/video_proxy.go | 111 +--- controller/video_proxy_gemini.go | 8 +- dto/suno.go | 32 - dto/task.go | 47 ++ main.go | 10 + middleware/auth.go | 18 + model/task.go | 57 +- model/token.go | 6 +- relay/channel/task/ali/adaptor.go | 3 +- relay/channel/task/doubao/adaptor.go | 24 +- relay/channel/task/gemini/adaptor.go | 47 +- relay/channel/task/hailuo/adaptor.go | 15 +- relay/channel/task/jimeng/adaptor.go | 27 +- relay/channel/task/kling/adaptor.go | 43 +- relay/channel/task/sora/adaptor.go | 24 +- relay/channel/task/suno/adaptor.go | 29 +- relay/channel/task/taskcommon/helpers.go | 70 +++ relay/channel/task/vertex/adaptor.go | 50 +- relay/channel/task/vidu/adaptor.go | 45 +- relay/common/relay_info.go | 15 +- relay/helper/price.go | 15 +- relay/relay_task.go | 576 +++++++++--------- router/video-router.go | 8 +- service/billing_session.go | 5 + service/error.go | 13 + service/log_info_generate.go | 2 +- service/task_billing.go | 227 +++++++ service/task_polling.go | 446 ++++++++++++++ types/price_data.go | 9 +- .../table/task-logs/TaskLogsColumnDefs.jsx | 9 +- .../table/task-logs/modals/ContentModal.jsx | 2 - 34 files changed, 1465 insertions(+), 1191 deletions(-) delete mode 100644 controller/task_video.go create mode 100644 relay/channel/task/taskcommon/helpers.go create mode 100644 service/task_billing.go create mode 100644 service/task_polling.go diff --git a/controller/relay.go b/controller/relay.go index 0b30e6e9e..132fee9ba 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -451,17 +451,102 @@ func RelayNotFound(c *gin.Context) { } func RelayTask(c *gin.Context) { - retryTimes := common.RetryTimes channelId := c.GetInt("channel_id") c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) if err != nil { + c.JSON(http.StatusInternalServerError, &dto.TaskError{ + Code: "gen_relay_info_failed", + Message: err.Error(), + StatusCode: http.StatusInternalServerError, + }) return } - taskErr := taskRelayHandler(c, relayInfo) - if taskErr == nil { - retryTimes = 0 + + // Fetch 操作是纯 DB 查询(或 task 自带 channelId 的上游查询),不依赖上下文 channel,无需重试 + // TODO: 在video-route层面优化,避免无谓的 channel 选择和上下文设置,也没必要吧代码放到这里来写这么多屎山 + switch relayInfo.RelayMode { + case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID: + if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil { + respondTaskError(c, taskErr) + } + return } + + // ── Submit 路径 ───────────────────────────────────────────────── + + // 1. 解析原始任务(remix / continuation),一次性,可能锁定渠道并禁止重试 + if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil { + respondTaskError(c, taskErr) + return + } + + // 2. defer Refund(全部失败时回滚预扣费) + var result *relay.TaskSubmitResult + var taskErr *dto.TaskError + defer func() { + if taskErr != nil && relayInfo.Billing != nil { + relayInfo.Billing.Refund(c) + } + }() + + // 3. 执行 + 重试(RelayTaskSubmit 内部在首次调用时自动预扣费) + taskErr = taskSubmitWithRetry(c, relayInfo, channelId, common.RetryTimes, func() *dto.TaskError { + var te *dto.TaskError + result, te = relay.RelayTaskSubmit(c, relayInfo) + return te + }) + + // 4. 成功:结算 + 日志 + 插入任务 + if taskErr == nil { + if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil { + common.SysError("settle task billing error: " + settleErr.Error()) + } + service.LogTaskConsumption(c, relayInfo, result.ModelName) + + task := model.InitTask(result.Platform, relayInfo) + task.PrivateData.UpstreamTaskID = result.UpstreamTaskID + task.PrivateData.BillingSource = relayInfo.BillingSource + task.PrivateData.SubscriptionId = relayInfo.SubscriptionId + task.PrivateData.TokenId = relayInfo.TokenId + task.Quota = result.Quota + task.Data = result.TaskData + task.Action = relayInfo.Action + if insertErr := task.Insert(); insertErr != nil { + //taskErr = service.TaskErrorWrapper(insertErr, "insert_task_failed", http.StatusInternalServerError) + common.SysError("insert task error: " + insertErr.Error()) + } + } + + if taskErr != nil { + respondTaskError(c, taskErr) + } +} + +// respondTaskError 统一输出 Task 错误响应(含 429 限流提示改写) +func respondTaskError(c *gin.Context, taskErr *dto.TaskError) { + if taskErr.StatusCode == http.StatusTooManyRequests { + taskErr.Message = "当前分组上游负载已饱和,请稍后再试" + } + c.JSON(taskErr.StatusCode, taskErr) +} + +// taskSubmitWithRetry 执行首次尝试并在失败时切换渠道重试,返回最终的 taskErr。 +// attempt 闭包负责实际的上游请求,不涉及计费。 +func taskSubmitWithRetry(c *gin.Context, relayInfo *relaycommon.RelayInfo, + channelId int, retryTimes int, attempt func() *dto.TaskError) *dto.TaskError { + + taskErr := attempt() + if taskErr == nil { + return nil + } + if !taskErr.LocalError { + processChannelError(c, + *types.NewChannelError(channelId, c.GetInt("channel_type"), c.GetString("channel_name"), common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey), + common.GetContextKeyString(c, constant.ContextKeyChannelKey), common.GetContextKeyBool(c, constant.ContextKeyChannelAutoBan)), + types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode)) + } + retryParam := &service.RetryParam{ Ctx: c, TokenGroup: relayInfo.TokenGroup, @@ -480,7 +565,7 @@ func RelayTask(c *gin.Context) { useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) c.Set("use_channel", useChannel) logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry())) - //middleware.SetupContextForSelectedChannel(c, channel, originalModel) + middleware.SetupContextForSelectedChannel(c, channel, c.GetString("original_model")) bodyStorage, err := common.GetBodyStorage(c) if err != nil { @@ -492,30 +577,21 @@ func RelayTask(c *gin.Context) { break } c.Request.Body = io.NopCloser(bodyStorage) - taskErr = taskRelayHandler(c, relayInfo) + taskErr = attempt() + if taskErr != nil && !taskErr.LocalError { + processChannelError(c, + *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, + common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), + types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode)) + } } + useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) logger.LogInfo(c, retryLogStr) } - if taskErr != nil { - if taskErr.StatusCode == http.StatusTooManyRequests { - taskErr.Message = "当前分组上游负载已饱和,请稍后再试" - } - c.JSON(taskErr.StatusCode, taskErr) - } -} - -func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError { - var err *dto.TaskError - switch relayInfo.RelayMode { - case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID: - err = relay.RelayTaskFetch(c, relayInfo.RelayMode) - default: - err = relay.RelayTaskSubmit(c, relayInfo) - } - return err + return taskErr } func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool { diff --git a/controller/task.go b/controller/task.go index 244f9161c..ec713c5d2 100644 --- a/controller/task.go +++ b/controller/task.go @@ -1,231 +1,21 @@ package controller import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "sort" "strconv" - "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" - "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay" + "github.com/QuantumNous/new-api/service" "github.com/gin-gonic/gin" - "github.com/samber/lo" ) +// UpdateTaskBulk 薄入口,实际轮询逻辑在 service 层 func UpdateTaskBulk() { - //revocer - //imageModel := "midjourney" - for { - time.Sleep(time.Duration(15) * time.Second) - common.SysLog("任务进度轮询开始") - ctx := context.TODO() - allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit) - platformTask := make(map[constant.TaskPlatform][]*model.Task) - for _, t := range allTasks { - platformTask[t.Platform] = append(platformTask[t.Platform], t) - } - for platform, tasks := range platformTask { - if len(tasks) == 0 { - continue - } - taskChannelM := make(map[int][]string) - taskM := make(map[string]*model.Task) - nullTaskIds := make([]int64, 0) - for _, task := range tasks { - if task.TaskID == "" { - // 统计失败的未完成任务 - nullTaskIds = append(nullTaskIds, task.ID) - continue - } - taskM[task.TaskID] = task - taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID) - } - if len(nullTaskIds) > 0 { - err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{ - "status": "FAILURE", - "progress": "100%", - }) - if err != nil { - logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) - } else { - logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) - } - } - if len(taskChannelM) == 0 { - continue - } - - UpdateTaskByPlatform(platform, taskChannelM, taskM) - } - common.SysLog("任务进度轮询完成") - } -} - -func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) { - switch platform { - case constant.TaskPlatformMidjourney: - //_ = UpdateMidjourneyTaskAll(context.Background(), tasks) - case constant.TaskPlatformSuno: - _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) - default: - if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil { - common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err)) - } - } -} - -func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { - for channelId, taskIds := range taskChannelM { - err := updateSunoTaskAll(ctx, channelId, taskIds, taskM) - if err != nil { - logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error())) - } - } - return nil -} - -func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { - logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) - if len(taskIds) == 0 { - return nil - } - channel, err := model.CacheGetChannel(channelId) - if err != nil { - common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) - err = model.TaskBulkUpdate(taskIds, map[string]any{ - "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), - "status": "FAILURE", - "progress": "100%", - }) - if err != nil { - common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) - } - return err - } - adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno) - if adaptor == nil { - return errors.New("adaptor not found") - } - proxy := channel.GetSetting().Proxy - resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{ - "ids": taskIds, - }, proxy) - if err != nil { - common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) - return err - } - if resp.StatusCode != http.StatusOK { - logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) - return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) - } - defer resp.Body.Close() - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err)) - return err - } - var responseItems dto.TaskResponse[[]dto.SunoDataResponse] - err = json.Unmarshal(responseBody, &responseItems) - if err != nil { - logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) - return err - } - if !responseItems.IsSuccess() { - common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody))) - return err - } - - for _, responseItem := range responseItems.Data { - task := taskM[responseItem.TaskID] - if !checkTaskNeedUpdate(task, responseItem) { - continue - } - - task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status) - task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason) - task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime) - task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) - task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) - if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { - logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) - task.Progress = "100%" - //err = model.CacheUpdateUserQuota(task.UserId) ? - if err != nil { - logger.LogError(ctx, "error update user quota cache: "+err.Error()) - } else { - quota := task.Quota - if quota != 0 { - err = model.IncreaseUserQuota(task.UserId, quota, false) - if err != nil { - logger.LogError(ctx, "fail to increase user quota: "+err.Error()) - } - logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } - } - if responseItem.Status == model.TaskStatusSuccess { - task.Progress = "100%" - } - task.Data = responseItem.Data - - err = task.Update() - if err != nil { - common.SysLog("UpdateMidjourneyTask task error: " + err.Error()) - } - } - return nil -} - -func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool { - - if oldTask.SubmitTime != newTask.SubmitTime { - return true - } - if oldTask.StartTime != newTask.StartTime { - return true - } - if oldTask.FinishTime != newTask.FinishTime { - return true - } - if string(oldTask.Status) != newTask.Status { - return true - } - if oldTask.FailReason != newTask.FailReason { - return true - } - if oldTask.FinishTime != newTask.FinishTime { - return true - } - - if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" { - return true - } - - oldData, _ := json.Marshal(oldTask.Data) - newData, _ := json.Marshal(newTask.Data) - - sort.Slice(oldData, func(i, j int) bool { - return oldData[i] < oldData[j] - }) - sort.Slice(newData, func(i, j int) bool { - return newData[i] < newData[j] - }) - - if string(oldData) != string(newData) { - return true - } - return false + service.TaskPollingLoop() } func GetAllTask(c *gin.Context) { @@ -247,7 +37,7 @@ func GetAllTask(c *gin.Context) { items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams) total := model.TaskCountAllTasks(queryParams) pageInfo.SetTotal(int(total)) - pageInfo.SetItems(items) + pageInfo.SetItems(tasksToDto(items)) common.ApiSuccess(c, pageInfo) } @@ -271,6 +61,14 @@ func GetUserTask(c *gin.Context) { items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams) total := model.TaskCountAllUserTask(userId, queryParams) pageInfo.SetTotal(int(total)) - pageInfo.SetItems(items) + pageInfo.SetItems(tasksToDto(items)) common.ApiSuccess(c, pageInfo) } + +func tasksToDto(tasks []*model.Task) []*dto.TaskDto { + result := make([]*dto.TaskDto, len(tasks)) + for i, task := range tasks { + result[i] = relay.TaskModel2Dto(task) + } + return result +} diff --git a/controller/task_video.go b/controller/task_video.go deleted file mode 100644 index d7c19e620..000000000 --- a/controller/task_video.go +++ /dev/null @@ -1,313 +0,0 @@ -package controller - -import ( - "context" - "encoding/json" - "fmt" - "io" - "time" - - "github.com/QuantumNous/new-api/common" - "github.com/QuantumNous/new-api/constant" - "github.com/QuantumNous/new-api/dto" - "github.com/QuantumNous/new-api/logger" - "github.com/QuantumNous/new-api/model" - "github.com/QuantumNous/new-api/relay" - "github.com/QuantumNous/new-api/relay/channel" - relaycommon "github.com/QuantumNous/new-api/relay/common" - "github.com/QuantumNous/new-api/setting/ratio_setting" -) - -func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { - for channelId, taskIds := range taskChannelM { - if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil { - logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) - } - } - return nil -} - -func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { - logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) - if len(taskIds) == 0 { - return nil - } - cacheGetChannel, err := model.CacheGetChannel(channelId) - if err != nil { - errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{ - "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), - "status": "FAILURE", - "progress": "100%", - }) - if errUpdate != nil { - common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) - } - return fmt.Errorf("CacheGetChannel failed: %w", err) - } - adaptor := relay.GetTaskAdaptor(platform) - if adaptor == nil { - return fmt.Errorf("video adaptor not found") - } - info := &relaycommon.RelayInfo{} - info.ChannelMeta = &relaycommon.ChannelMeta{ - ChannelBaseUrl: cacheGetChannel.GetBaseURL(), - } - info.ApiKey = cacheGetChannel.Key - adaptor.Init(info) - for _, taskId := range taskIds { - if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { - logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) - } - } - return nil -} - -func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error { - baseURL := constant.ChannelBaseURLs[channel.Type] - if channel.GetBaseURL() != "" { - baseURL = channel.GetBaseURL() - } - proxy := channel.GetSetting().Proxy - - task := taskM[taskId] - if task == nil { - logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) - return fmt.Errorf("task %s not found", taskId) - } - key := channel.Key - - privateData := task.PrivateData - if privateData.Key != "" { - key = privateData.Key - } - resp, err := adaptor.FetchTask(baseURL, key, map[string]any{ - "task_id": taskId, - "action": task.Action, - }, proxy) - if err != nil { - return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) - } - //if resp.StatusCode != http.StatusOK { - //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode) - //} - defer resp.Body.Close() - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("readAll failed for task %s: %w", taskId, err) - } - - logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody))) - - taskResult := &relaycommon.TaskInfo{} - // try parse as New API response format - var responseItems dto.TaskResponse[model.Task] - if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() { - logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems)) - t := responseItems.Data - taskResult.TaskID = t.TaskID - taskResult.Status = string(t.Status) - taskResult.Url = t.FailReason - taskResult.Progress = t.Progress - taskResult.Reason = t.FailReason - task.Data = t.Data - } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { - return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) - } else { - task.Data = redactVideoResponseBody(responseBody) - } - - logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult)) - - now := time.Now().Unix() - if taskResult.Status == "" { - //return fmt.Errorf("task %s status is empty", taskId) - taskResult = relaycommon.FailTaskInfo("upstream returned empty status") - } - - // 记录原本的状态,防止重复退款 - shouldRefund := false - quota := task.Quota - preStatus := task.Status - - task.Status = model.TaskStatus(taskResult.Status) - switch taskResult.Status { - case model.TaskStatusSubmitted: - task.Progress = "10%" - case model.TaskStatusQueued: - task.Progress = "20%" - case model.TaskStatusInProgress: - task.Progress = "30%" - if task.StartTime == 0 { - task.StartTime = now - } - case model.TaskStatusSuccess: - task.Progress = "100%" - if task.FinishTime == 0 { - task.FinishTime = now - } - if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") { - task.FailReason = taskResult.Url - } - - // 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费 - if taskResult.TotalTokens > 0 { - // 获取模型名称 - var taskData map[string]interface{} - if err := json.Unmarshal(task.Data, &taskData); err == nil { - if modelName, ok := taskData["model"].(string); ok && modelName != "" { - // 获取模型价格和倍率 - modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName) - // 只有配置了倍率(非固定价格)时才按 token 重新计费 - if hasRatioSetting && modelRatio > 0 { - // 获取用户和组的倍率信息 - group := task.Group - if group == "" { - user, err := model.GetUserById(task.UserId, false) - if err == nil { - group = user.Group - } - } - if group != "" { - groupRatio := ratio_setting.GetGroupRatio(group) - userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group) - - var finalGroupRatio float64 - if hasUserGroupRatio { - finalGroupRatio = userGroupRatio - } else { - finalGroupRatio = groupRatio - } - - // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio - actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio) - - // 计算差额 - preConsumedQuota := task.Quota - quotaDelta := actualQuota - preConsumedQuota - - if quotaDelta > 0 { - // 需要补扣费 - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)", - task.TaskID, - logger.LogQuota(quotaDelta), - logger.LogQuota(actualQuota), - logger.LogQuota(preConsumedQuota), - taskResult.TotalTokens, - )) - if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil { - logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error())) - } else { - model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) - model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) - task.Quota = actualQuota // 更新任务记录的实际扣费额度 - - // 记录消费日志 - logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s", - modelRatio, finalGroupRatio, taskResult.TotalTokens, - logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } else if quotaDelta < 0 { - // 需要退还多扣的费用 - refundQuota := -quotaDelta - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)", - task.TaskID, - logger.LogQuota(refundQuota), - logger.LogQuota(actualQuota), - logger.LogQuota(preConsumedQuota), - taskResult.TotalTokens, - )) - if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil { - logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error())) - } else { - task.Quota = actualQuota // 更新任务记录的实际扣费额度 - - // 记录退款日志 - logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s", - modelRatio, finalGroupRatio, taskResult.TotalTokens, - logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } else { - // quotaDelta == 0, 预扣费刚好准确 - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)", - task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens)) - } - } - } - } - } - } - case model.TaskStatusFailure: - logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) - task.Status = model.TaskStatusFailure - task.Progress = "100%" - if task.FinishTime == 0 { - task.FinishTime = now - } - task.FailReason = taskResult.Reason - logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) - taskResult.Progress = "100%" - if quota != 0 { - if preStatus != model.TaskStatusFailure { - shouldRefund = true - } else { - logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID)) - } - } - default: - return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId) - } - if taskResult.Progress != "" { - task.Progress = taskResult.Progress - } - if err := task.Update(); err != nil { - common.SysLog("UpdateVideoTask task error: " + err.Error()) - shouldRefund = false - } - - if shouldRefund { - // 任务失败且之前状态不是失败才退还额度,防止重复退还 - if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { - logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error()) - } - logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - - return nil -} - -func redactVideoResponseBody(body []byte) []byte { - var m map[string]any - if err := json.Unmarshal(body, &m); err != nil { - return body - } - resp, _ := m["response"].(map[string]any) - if resp != nil { - delete(resp, "bytesBase64Encoded") - if v, ok := resp["video"].(string); ok { - resp["video"] = truncateBase64(v) - } - if vs, ok := resp["videos"].([]any); ok { - for i := range vs { - if vm, ok := vs[i].(map[string]any); ok { - delete(vm, "bytesBase64Encoded") - } - } - } - } - b, err := json.Marshal(m) - if err != nil { - return body - } - return b -} - -func truncateBase64(s string) string { - const maxKeep = 256 - if len(s) <= maxKeep { - return s - } - return s[:maxKeep] + "..." -} diff --git a/controller/video_proxy.go b/controller/video_proxy.go index f102baae4..f1dd2bc92 100644 --- a/controller/video_proxy.go +++ b/controller/video_proxy.go @@ -16,59 +16,44 @@ import ( "github.com/gin-gonic/gin" ) +// videoProxyError returns a standardized OpenAI-style error response. +func videoProxyError(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "message": message, + "type": errType, + }, + }) +} + func VideoProxy(c *gin.Context) { taskID := c.Param("task_id") if taskID == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "task_id is required", - "type": "invalid_request_error", - }, - }) + videoProxyError(c, http.StatusBadRequest, "invalid_request_error", "task_id is required") return } task, exists, err := model.GetByOnlyTaskId(taskID) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to query task", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to query task") return } if !exists || task == nil { - logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err)) - c.JSON(http.StatusNotFound, gin.H{ - "error": gin.H{ - "message": "Task not found", - "type": "invalid_request_error", - }, - }) + videoProxyError(c, http.StatusNotFound, "invalid_request_error", "Task not found") return } if task.Status != model.TaskStatusSuccess { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status), - "type": "invalid_request_error", - }, - }) + videoProxyError(c, http.StatusBadRequest, "invalid_request_error", + fmt.Sprintf("Task is not completed yet, current status: %s", task.Status)) return } channel, err := model.CacheGetChannel(task.ChannelId) if err != nil { - logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID)) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to retrieve channel information", - "type": "server_error", - }, - }) + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel for task %s: %s", taskID, err.Error())) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to retrieve channel information") return } baseURL := channel.GetBaseURL() @@ -81,12 +66,7 @@ func VideoProxy(c *gin.Context) { client, err := service.GetHttpClientWithProxy(proxy) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to create proxy client", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy client") return } @@ -95,12 +75,7 @@ func VideoProxy(c *gin.Context) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to create proxy request", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request") return } @@ -109,68 +84,43 @@ func VideoProxy(c *gin.Context) { apiKey := task.PrivateData.Key if apiKey == "" { logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID)) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "API key not stored for task", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "API key not stored for task") return } - videoURL, err = getGeminiVideoURL(channel, task, apiKey) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error())) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "message": "Failed to resolve Gemini video URL", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Gemini video URL") return } req.Header.Set("x-goog-api-key", apiKey) case constant.ChannelTypeOpenAI, constant.ChannelTypeSora: - videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID) + videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID()) req.Header.Set("Authorization", "Bearer "+channel.Key) default: - // Video URL is directly in task.FailReason - videoURL = task.FailReason + // Video URL is stored in PrivateData.ResultURL (fallback to FailReason for old data) + videoURL = task.GetResultURL() } req.URL, err = url.Parse(videoURL) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to create proxy request", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request") return } resp, err := client.Do(req) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error())) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "message": "Failed to fetch video content", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content") return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL)) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode), - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusBadGateway, "server_error", + fmt.Sprintf("Upstream service returned status %d", resp.StatusCode)) return } @@ -180,10 +130,9 @@ func VideoProxy(c *gin.Context) { } } - c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours + c.Writer.Header().Set("Cache-Control", "public, max-age=86400") c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { + if _, err = io.Copy(c.Writer, resp.Body); err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error())) } } diff --git a/controller/video_proxy_gemini.go b/controller/video_proxy_gemini.go index 053ac6515..a63a2a5c4 100644 --- a/controller/video_proxy_gemini.go +++ b/controller/video_proxy_gemini.go @@ -1,12 +1,12 @@ package controller import ( - "encoding/json" "fmt" "io" "strconv" "strings" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay" @@ -37,7 +37,7 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string) proxy := channel.GetSetting().Proxy resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{ - "task_id": task.TaskID, + "task_id": task.GetUpstreamTaskID(), "action": task.Action, }, proxy) if err != nil { @@ -71,7 +71,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string { return "" } var payload map[string]any - if err := json.Unmarshal(task.Data, &payload); err != nil { + if err := common.Unmarshal(task.Data, &payload); err != nil { return "" } return extractGeminiVideoURLFromMap(payload) @@ -79,7 +79,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string { func extractGeminiVideoURLFromPayload(body []byte) string { var payload map[string]any - if err := json.Unmarshal(body, &payload); err != nil { + if err := common.Unmarshal(body, &payload); err != nil { return "" } return extractGeminiVideoURLFromMap(payload) diff --git a/dto/suno.go b/dto/suno.go index a6bb3ebae..90e11b810 100644 --- a/dto/suno.go +++ b/dto/suno.go @@ -4,10 +4,6 @@ import ( "encoding/json" ) -type TaskData interface { - SunoDataResponse | []SunoDataResponse | string | any -} - type SunoSubmitReq struct { GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"` Prompt string `json:"prompt,omitempty"` @@ -20,10 +16,6 @@ type SunoSubmitReq struct { MakeInstrumental bool `json:"make_instrumental"` } -type FetchReq struct { - IDs []string `json:"ids"` -} - type SunoDataResponse struct { TaskID string `json:"task_id" gorm:"type:varchar(50);index"` Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode @@ -66,30 +58,6 @@ type SunoLyrics struct { Text string `json:"text"` } -const TaskSuccessCode = "success" - -type TaskResponse[T TaskData] struct { - Code string `json:"code"` - Message string `json:"message"` - Data T `json:"data"` -} - -func (t *TaskResponse[T]) IsSuccess() bool { - return t.Code == TaskSuccessCode -} - -type TaskDto struct { - TaskID string `json:"task_id"` // 第三方id,不一定有/ song id\ Task id - Action string `json:"action"` // 任务类型, song, lyrics, description-mode - Status string `json:"status"` // 任务状态, submitted, queueing, processing, success, failed - FailReason string `json:"fail_reason"` - SubmitTime int64 `json:"submit_time"` - StartTime int64 `json:"start_time"` - FinishTime int64 `json:"finish_time"` - Progress string `json:"progress"` - Data json.RawMessage `json:"data"` -} - type SunoGoAPISubmitReq struct { CustomMode bool `json:"custom_mode"` diff --git a/dto/task.go b/dto/task.go index afc186b41..4a9a8e2e6 100644 --- a/dto/task.go +++ b/dto/task.go @@ -1,5 +1,9 @@ package dto +import ( + "encoding/json" +) + type TaskError struct { Code string `json:"code"` Message string `json:"message"` @@ -8,3 +12,46 @@ type TaskError struct { LocalError bool `json:"-"` Error error `json:"-"` } + +type TaskData interface { + SunoDataResponse | []SunoDataResponse | string | any +} + +const TaskSuccessCode = "success" + +type TaskResponse[T TaskData] struct { + Code string `json:"code"` + Message string `json:"message"` + Data T `json:"data"` +} + +func (t *TaskResponse[T]) IsSuccess() bool { + return t.Code == TaskSuccessCode +} + +type TaskDto struct { + ID int64 `json:"id"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + TaskID string `json:"task_id"` + Platform string `json:"platform"` + UserId int `json:"user_id"` + Group string `json:"group"` + ChannelId int `json:"channel_id"` + Quota int `json:"quota"` + Action string `json:"action"` + Status string `json:"status"` + FailReason string `json:"fail_reason"` + ResultURL string `json:"result_url,omitempty"` // 任务结果 URL(视频地址等) + SubmitTime int64 `json:"submit_time"` + StartTime int64 `json:"start_time"` + FinishTime int64 `json:"finish_time"` + Progress string `json:"progress"` + Properties any `json:"properties"` + Username string `json:"username,omitempty"` + Data json.RawMessage `json:"data"` +} + +type FetchReq struct { + IDs []string `json:"ids"` +} diff --git a/main.go b/main.go index 852e1a0a8..476a2ed24 100644 --- a/main.go +++ b/main.go @@ -19,6 +19,7 @@ import ( "github.com/QuantumNous/new-api/middleware" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/oauth" + "github.com/QuantumNous/new-api/relay" "github.com/QuantumNous/new-api/router" "github.com/QuantumNous/new-api/service" _ "github.com/QuantumNous/new-api/setting/performance_setting" @@ -111,6 +112,15 @@ func main() { // Subscription quota reset task (daily/weekly/monthly/custom) service.StartSubscriptionQuotaResetTask() + // Wire task polling adaptor factory (breaks service -> relay import cycle) + service.GetTaskAdaptorFunc = func(platform constant.TaskPlatform) service.TaskPollingAdaptor { + a := relay.GetTaskAdaptor(platform) + if a == nil { + return nil + } + return a + } + if common.IsMasterNode && constant.UpdateTask { gopool.Go(func() { controller.UpdateMidjourneyTaskBulk() diff --git a/middleware/auth.go b/middleware/auth.go index cf1843510..342e7f498 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -170,6 +170,24 @@ func WssAuth(c *gin.Context) { } +// TokenOrUserAuth allows either session-based user auth or API token auth. +// Used for endpoints that need to be accessible from both the dashboard and API clients. +func TokenOrUserAuth() func(c *gin.Context) { + return func(c *gin.Context) { + // Try session auth first (dashboard users) + session := sessions.Default(c) + if id := session.Get("id"); id != nil { + if status, ok := session.Get("status").(int); ok && status == common.UserStatusEnabled { + c.Set("id", id) + c.Next() + return + } + } + // Fall back to token auth (API clients) + TokenAuth()(c) + } +} + // TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。 // 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。 // 即使令牌已过期、已耗尽或已禁用,也允许访问。 diff --git a/model/task.go b/model/task.go index 82c2e978a..38bb4d05a 100644 --- a/model/task.go +++ b/model/task.go @@ -5,6 +5,7 @@ import ( "encoding/json" "time" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" commonRelay "github.com/QuantumNous/new-api/relay/common" @@ -64,13 +65,12 @@ type Task struct { } func (t *Task) SetData(data any) { - b, _ := json.Marshal(data) + b, _ := common.Marshal(data) t.Data = json.RawMessage(b) } func (t *Task) GetData(v any) error { - err := json.Unmarshal(t.Data, &v) - return err + return common.Unmarshal(t.Data, &v) } type Properties struct { @@ -85,18 +85,48 @@ func (m *Properties) Scan(val interface{}) error { *m = Properties{} return nil } - return json.Unmarshal(bytesValue, m) + return common.Unmarshal(bytesValue, m) } func (m Properties) Value() (driver.Value, error) { if m == (Properties{}) { return nil, nil } - return json.Marshal(m) + return common.Marshal(m) } type TaskPrivateData struct { - Key string `json:"key,omitempty"` + Key string `json:"key,omitempty"` + UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID + ResultURL string `json:"result_url,omitempty"` // 任务成功后的结果 URL(视频地址等) + // 计费上下文:用于异步退款/差额结算(轮询阶段读取) + BillingSource string `json:"billing_source,omitempty"` // "wallet" 或 "subscription" + SubscriptionId int `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款 + TokenId int `json:"token_id,omitempty"` // 令牌 ID,用于令牌额度退款 +} + +// GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信) +// 旧数据没有 UpstreamTaskID 时,TaskID 本身就是上游 ID +func (t *Task) GetUpstreamTaskID() string { + if t.PrivateData.UpstreamTaskID != "" { + return t.PrivateData.UpstreamTaskID + } + return t.TaskID +} + +// GetResultURL 获取任务结果 URL(视频地址等) +// 新数据存在 PrivateData.ResultURL 中;旧数据回退到 FailReason(历史兼容) +func (t *Task) GetResultURL() string { + if t.PrivateData.ResultURL != "" { + return t.PrivateData.ResultURL + } + return t.FailReason +} + +// GenerateTaskID 生成对外暴露的 task_xxxx 格式 ID +func GenerateTaskID() string { + key, _ := common.GenerateRandomCharsKey(32) + return "task_" + key } func (p *TaskPrivateData) Scan(val interface{}) error { @@ -104,14 +134,14 @@ func (p *TaskPrivateData) Scan(val interface{}) error { if len(bytesValue) == 0 { return nil } - return json.Unmarshal(bytesValue, p) + return common.Unmarshal(bytesValue, p) } func (p TaskPrivateData) Value() (driver.Value, error) { if (p == TaskPrivateData{}) { return nil, nil } - return json.Marshal(p) + return common.Marshal(p) } // SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 @@ -142,7 +172,16 @@ func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) } } + // 使用预生成的公开 ID(如果有),否则新生成 + taskID := "" + if relayInfo.TaskRelayInfo != nil && relayInfo.TaskRelayInfo.PublicTaskID != "" { + taskID = relayInfo.TaskRelayInfo.PublicTaskID + } else { + taskID = GenerateTaskID() + } + t := &Task{ + TaskID: taskID, UserId: relayInfo.UserId, Group: relayInfo.UsingGroup, SubmitTime: time.Now().Unix(), @@ -438,6 +477,6 @@ func (t *Task) ToOpenAIVideo() *dto.OpenAIVideo { openAIVideo.SetProgressStr(t.Progress) openAIVideo.CreatedAt = t.CreatedAt openAIVideo.CompletedAt = t.UpdatedAt - openAIVideo.SetMetadata("url", t.FailReason) + openAIVideo.SetMetadata("url", t.GetResultURL()) return openAIVideo } diff --git a/model/token.go b/model/token.go index 9e05b63ca..773b2d792 100644 --- a/model/token.go +++ b/model/token.go @@ -360,7 +360,7 @@ func DeleteTokenById(id int, userId int) (err error) { return token.Delete() } -func IncreaseTokenQuota(id int, key string, quota int) (err error) { +func IncreaseTokenQuota(tokenId int, key string, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -373,10 +373,10 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) { }) } if common.BatchUpdateEnabled { - addNewRecord(BatchUpdateTypeTokenQuota, id, quota) + addNewRecord(BatchUpdateTypeTokenQuota, tokenId, quota) return nil } - return increaseTokenQuota(id, quota) + return increaseTokenQuota(tokenId, quota) } func increaseTokenQuota(id int, quota int) (err error) { diff --git a/relay/channel/task/ali/adaptor.go b/relay/channel/task/ali/adaptor.go index d55452c08..5d14ff655 100644 --- a/relay/channel/task/ali/adaptor.go +++ b/relay/channel/task/ali/adaptor.go @@ -384,7 +384,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela // 转换为 OpenAI 格式响应 openAIResp := dto.NewOpenAIVideo() - openAIResp.ID = aliResp.Output.TaskID + openAIResp.ID = info.PublicTaskID + openAIResp.TaskID = info.PublicTaskID openAIResp.Model = c.GetString("model") if openAIResp.Model == "" && info != nil { openAIResp.Model = info.OriginModelName diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go index 6ebecb3c0..3da125afc 100644 --- a/relay/channel/task/doubao/adaptor.go +++ b/relay/channel/task/doubao/adaptor.go @@ -2,7 +2,6 @@ package doubao import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -14,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -131,7 +131,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, errors.Wrap(err, "convert request payload failed") } info.UpstreamModelName = body.Model - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -154,7 +154,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela // Parse Doubao response var dResp responsePayload - if err := json.Unmarshal(responseBody, &dResp); err != nil { + if err := common.Unmarshal(responseBody, &dResp); err != nil { taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) return } @@ -165,8 +165,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = dResp.ID - ov.TaskID = dResp.ID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName @@ -234,12 +234,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* } metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } @@ -248,7 +243,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { resTask := responseTask{} - if err := json.Unmarshal(respBody, &resTask); err != nil { + if err := common.Unmarshal(respBody, &resTask); err != nil { return nil, errors.Wrap(err, "unmarshal task result failed") } @@ -286,7 +281,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var dResp responseTask - if err := json.Unmarshal(originTask.Data, &dResp); err != nil { + if err := common.Unmarshal(originTask.Data, &dResp); err != nil { return nil, errors.Wrap(err, "unmarshal doubao task data failed") } @@ -307,6 +302,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } diff --git a/relay/channel/task/gemini/adaptor.go b/relay/channel/task/gemini/adaptor.go index 16c6919b7..a863ea852 100644 --- a/relay/channel/task/gemini/adaptor.go +++ b/relay/channel/task/gemini/adaptor.go @@ -2,8 +2,6 @@ package gemini import ( "bytes" - "encoding/base64" - "encoding/json" "fmt" "io" "net/http" @@ -16,10 +14,10 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/setting/model_setting" - "github.com/QuantumNous/new-api/setting/system_setting" "github.com/gin-gonic/gin" "github.com/pkg/errors" ) @@ -145,16 +143,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &body.Parameters) - if err != nil { + if err := taskcommon.UnmarshalMetadata(metadata, &body.Parameters); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -175,16 +168,16 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela _ = resp.Body.Close() var s submitResponse - if err := json.Unmarshal(responseBody, &s); err != nil { + if err := common.Unmarshal(responseBody, &s); err != nil { return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) } if strings.TrimSpace(s.Name) == "" { return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError) } - taskID = encodeLocalTaskID(s.Name) + taskID = taskcommon.EncodeLocalTaskID(s.Name) ov := dto.NewOpenAIVideo() - ov.ID = taskID - ov.TaskID = taskID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -206,7 +199,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy return nil, fmt.Errorf("invalid task_id") } - upstreamName, err := decodeLocalTaskID(taskID) + upstreamName, err := taskcommon.DecodeLocalTaskID(taskID) if err != nil { return nil, fmt.Errorf("decode task_id failed: %w", err) } @@ -232,7 +225,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { var op operationResponse - if err := json.Unmarshal(respBody, &op); err != nil { + if err := common.Unmarshal(respBody, &op); err != nil { return nil, fmt.Errorf("unmarshal operation response failed: %w", err) } @@ -254,9 +247,8 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e ti.Status = model.TaskStatusSuccess ti.Progress = "100%" - taskID := encodeLocalTaskID(op.Name) - ti.TaskID = taskID - ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID) + ti.TaskID = taskcommon.EncodeLocalTaskID(op.Name) + // Url intentionally left empty — the caller constructs the proxy URL using the public task ID // Extract URL from generateVideoResponse if available if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 { @@ -269,7 +261,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { - upstreamName, err := decodeLocalTaskID(task.TaskID) + // Use GetUpstreamTaskID() to get the real upstream operation name for model extraction. + // task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name. + upstreamTaskID := task.GetUpstreamTaskID() + upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID) if err != nil { upstreamName = "" } @@ -297,18 +292,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { // helpers // ============================ -func encodeLocalTaskID(name string) string { - return base64.RawURLEncoding.EncodeToString([]byte(name)) -} - -func decodeLocalTaskID(local string) (string, error) { - b, err := base64.RawURLEncoding.DecodeString(local) - if err != nil { - return "", err - } - return string(b), nil -} - var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`) func extractModelFromOperationName(name string) string { diff --git a/relay/channel/task/hailuo/adaptor.go b/relay/channel/task/hailuo/adaptor.go index c77905bfb..67a68a10e 100644 --- a/relay/channel/task/hailuo/adaptor.go +++ b/relay/channel/task/hailuo/adaptor.go @@ -2,7 +2,6 @@ package hailuo import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -65,7 +64,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, errors.Wrap(err, "convert request payload failed") } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -86,7 +85,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela _ = resp.Body.Close() var hResp VideoResponse - if err := json.Unmarshal(responseBody, &hResp); err != nil { + if err := common.Unmarshal(responseBody, &hResp); err != nil { taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) return } @@ -101,8 +100,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = hResp.TaskID - ov.TaskID = hResp.TaskID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName @@ -182,7 +181,7 @@ func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConf func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { resTask := QueryTaskResponse{} - if err := json.Unmarshal(respBody, &resTask); err != nil { + if err := common.Unmarshal(respBody, &resTask); err != nil { return nil, errors.Wrap(err, "unmarshal task result failed") } @@ -224,7 +223,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var hailuoResp QueryTaskResponse - if err := json.Unmarshal(originTask.Data, &hailuoResp); err != nil { + if err := common.Unmarshal(originTask.Data, &hailuoResp); err != nil { return nil, errors.Wrap(err, "unmarshal hailuo task data failed") } @@ -271,7 +270,7 @@ func (a *TaskAdaptor) buildVideoURL(_, fileID string) string { } var retrieveResp RetrieveFileResponse - if err := json.Unmarshal(responseBody, &retrieveResp); err != nil { + if err := common.Unmarshal(responseBody, &retrieveResp); err != nil { return "" } diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 1522a967f..7f88be248 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -6,7 +6,6 @@ import ( "crypto/sha256" "encoding/base64" "encoding/hex" - "encoding/json" "fmt" "io" "net/http" @@ -25,6 +24,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" ) @@ -168,7 +168,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -191,7 +191,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela // Parse Jimeng response var jResp responsePayload - if err := json.Unmarshal(responseBody, &jResp); err != nil { + if err := common.Unmarshal(responseBody, &jResp); err != nil { taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) return } @@ -202,8 +202,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = jResp.Data.TaskID - ov.TaskID = jResp.Data.TaskID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -225,7 +225,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774 "task_id": taskID, } - payloadBytes, err := json.Marshal(payload) + payloadBytes, err := common.Marshal(payload) if err != nil { return nil, errors.Wrap(err, "marshal fetch task payload failed") } @@ -398,13 +398,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* r.BinaryDataBase64 = req.Images } } - metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } @@ -432,7 +426,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { resTask := responseTask{} - if err := json.Unmarshal(respBody, &resTask); err != nil { + if err := common.Unmarshal(respBody, &resTask); err != nil { return nil, errors.Wrap(err, "unmarshal task result failed") } taskResult := relaycommon.TaskInfo{} @@ -458,7 +452,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var jimengResp responseTask - if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil { + if err := common.Unmarshal(originTask.Data, &jimengResp); err != nil { return nil, errors.Wrap(err, "unmarshal jimeng task data failed") } @@ -477,8 +471,7 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } func isNewAPIRelay(apiKey string) bool { diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 5fb853481..4458626b2 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -2,7 +2,6 @@ package kling import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -21,6 +20,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" ) @@ -156,7 +156,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if body.Image == "" && body.ImageTail == "" { c.Set("action", constant.TaskActionTextGenerate) } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -180,7 +180,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } var kResp responsePayload - err = json.Unmarshal(responseBody, &kResp) + err = common.Unmarshal(responseBody, &kResp) if err != nil { taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) return @@ -190,8 +190,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return } ov := dto.NewOpenAIVideo() - ov.ID = kResp.Data.TaskId - ov.TaskID = kResp.Data.TaskId + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -251,8 +251,8 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* r := requestPayload{ Prompt: req.Prompt, Image: req.Image, - Mode: defaultString(req.Mode, "std"), - Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)), + Mode: taskcommon.DefaultString(req.Mode, "std"), + Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)), AspectRatio: a.getAspectRatio(req.Size), ModelName: req.Model, Model: req.Model, // Keep consistent with model_name, double writing improves compatibility @@ -266,13 +266,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* if r.ModelName == "" { r.ModelName = "kling-v1" } - metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } return &r, nil @@ -291,20 +285,6 @@ func (a *TaskAdaptor) getAspectRatio(size string) string { } } -func defaultString(s, def string) string { - if strings.TrimSpace(s) == "" { - return def - } - return s -} - -func defaultInt(v int, def int) int { - if v == 0 { - return def - } - return v -} - // ============================ // JWT helpers // ============================ @@ -340,7 +320,7 @@ func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) { func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { taskInfo := &relaycommon.TaskInfo{} resPayload := responsePayload{} - err := json.Unmarshal(respBody, &resPayload) + err := common.Unmarshal(respBody, &resPayload) if err != nil { return nil, errors.Wrap(err, "failed to unmarshal response body") } @@ -374,7 +354,7 @@ func isNewAPIRelay(apiKey string) bool { func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var klingResp responsePayload - if err := json.Unmarshal(originTask.Data, &klingResp); err != nil { + if err := common.Unmarshal(originTask.Data, &klingResp); err != nil { return nil, errors.Wrap(err, "unmarshal kling task data failed") } @@ -401,6 +381,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro Code: fmt.Sprintf("%d", klingResp.Code), } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index c149f9663..ee69a3e48 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -13,7 +13,6 @@ import ( "github.com/QuantumNous/new-api/relay/channel" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" - "github.com/QuantumNous/new-api/setting/system_setting" "github.com/gin-gonic/gin" "github.com/pkg/errors" @@ -116,7 +115,7 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, req } // DoResponse handles upstream response, returns taskID etc. -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -131,17 +130,20 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco return } - if dResp.ID == "" { - if dResp.TaskID == "" { - taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError) - return - } - dResp.ID = dResp.TaskID - dResp.TaskID = "" + upstreamID := dResp.ID + if upstreamID == "" { + upstreamID = dResp.TaskID + } + if upstreamID == "" { + taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError) + return } + // 使用公开 task_xxxx ID 返回给客户端 + dResp.ID = info.PublicTaskID + dResp.TaskID = info.PublicTaskID c.JSON(http.StatusOK, dResp) - return dResp.ID, responseBody, nil + return upstreamID, responseBody, nil } // FetchTask fetch task status @@ -192,7 +194,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e taskResult.Status = model.TaskStatusInProgress case "completed": taskResult.Status = model.TaskStatusSuccess - taskResult.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, resTask.ID) + // Url intentionally left empty — the caller constructs the proxy URL using the public task ID case "failed", "cancelled": taskResult.Status = model.TaskStatusFailure if resTask.Error != nil { diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 8ea9a1c7f..5dd62a70f 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -3,7 +3,6 @@ package suno import ( "bytes" "context" - "encoding/json" "fmt" "io" "net/http" @@ -24,8 +23,12 @@ type TaskAdaptor struct { ChannelType int } +// ParseTaskResult is not used for Suno tasks. +// Suno polling uses a dedicated batch-fetch path (service.UpdateSunoTasks) that +// receives dto.TaskResponse[[]dto.SunoDataResponse] from the upstream /fetch API. +// This differs from the per-task polling used by video adaptors. func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { - return nil, fmt.Errorf("not implement") // todo implement this method if needed + return nil, fmt.Errorf("suno uses batch polling via UpdateSunoTasks, ParseTaskResult is not applicable") } func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { @@ -81,7 +84,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, err } } - data, err := json.Marshal(sunoRequest) + data, err := common.Marshal(sunoRequest) if err != nil { return nil, err } @@ -99,7 +102,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return } var sunoResponse dto.TaskResponse[string] - err = json.Unmarshal(responseBody, &sunoResponse) + err = common.Unmarshal(responseBody, &sunoResponse) if err != nil { taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) return @@ -109,17 +112,13 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return } - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - - _, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody)) - if err != nil { - taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - return + // 使用公开 task_xxxx ID 替换上游 ID 返回给客户端 + publicResponse := dto.TaskResponse[string]{ + Code: sunoResponse.Code, + Message: sunoResponse.Message, + Data: info.PublicTaskID, } + c.JSON(http.StatusOK, publicResponse) return sunoResponse.Data, nil, nil } @@ -134,7 +133,7 @@ func (a *TaskAdaptor) GetChannelName() string { func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl) - byteBody, err := json.Marshal(body) + byteBody, err := common.Marshal(body) if err != nil { return nil, err } diff --git a/relay/channel/task/taskcommon/helpers.go b/relay/channel/task/taskcommon/helpers.go new file mode 100644 index 000000000..b1dde998b --- /dev/null +++ b/relay/channel/task/taskcommon/helpers.go @@ -0,0 +1,70 @@ +package taskcommon + +import ( + "encoding/base64" + "fmt" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/setting/system_setting" +) + +// UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip. +// This replaces the repeated pattern: json.Marshal(metadata) → json.Unmarshal(bytes, &target). +func UnmarshalMetadata(metadata map[string]any, target any) error { + if metadata == nil { + return nil + } + metaBytes, err := common.Marshal(metadata) + if err != nil { + return fmt.Errorf("marshal metadata failed: %w", err) + } + if err := common.Unmarshal(metaBytes, target); err != nil { + return fmt.Errorf("unmarshal metadata failed: %w", err) + } + return nil +} + +// DefaultString returns val if non-empty, otherwise fallback. +func DefaultString(val, fallback string) string { + if val == "" { + return fallback + } + return val +} + +// DefaultInt returns val if non-zero, otherwise fallback. +func DefaultInt(val, fallback int) int { + if val == 0 { + return fallback + } + return val +} + +// EncodeLocalTaskID encodes an upstream operation name to a URL-safe base64 string. +// Used by Gemini/Vertex to store upstream names as task IDs. +func EncodeLocalTaskID(name string) string { + return base64.RawURLEncoding.EncodeToString([]byte(name)) +} + +// DecodeLocalTaskID decodes a base64-encoded upstream operation name. +func DecodeLocalTaskID(id string) (string, error) { + b, err := base64.RawURLEncoding.DecodeString(id) + if err != nil { + return "", err + } + return string(b), nil +} + +// BuildProxyURL constructs the video proxy URL using the public task ID. +// e.g., "https://your-server.com/v1/videos/task_xxxx/content" +func BuildProxyURL(taskID string) string { + return fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID) +} + +// Status-to-progress mapping constants for polling updates. +const ( + ProgressSubmitted = "10%" + ProgressQueued = "20%" + ProgressInProgress = "30%" + ProgressComplete = "100%" +) diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index 8ec77266e..fb3a313ff 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -2,13 +2,12 @@ package vertex import ( "bytes" - "encoding/base64" - "encoding/json" "fmt" "io" "net/http" "regexp" "strings" + "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" @@ -17,6 +16,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" vertexcore "github.com/QuantumNous/new-api/relay/channel/vertex" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -82,7 +82,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { + if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil { return "", fmt.Errorf("failed to decode credentials: %w", err) } modelName := info.OriginModelName @@ -116,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info req.Header.Set("Accept", "application/json") adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { + if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil { return fmt.Errorf("failed to decode credentials: %w", err) } @@ -184,7 +184,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn // info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int)) // } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -205,14 +205,19 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela _ = resp.Body.Close() var s submitResponse - if err := json.Unmarshal(responseBody, &s); err != nil { + if err := common.Unmarshal(responseBody, &s); err != nil { return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) } if strings.TrimSpace(s.Name) == "" { return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError) } - localID := encodeLocalTaskID(s.Name) - c.JSON(http.StatusOK, gin.H{"task_id": localID}) + localID := taskcommon.EncodeLocalTaskID(s.Name) + ov := dto.NewOpenAIVideo() + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID + ov.CreatedAt = time.Now().Unix() + ov.Model = info.OriginModelName + c.JSON(http.StatusOK, ov) return localID, responseBody, nil } @@ -225,7 +230,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy if !ok { return nil, fmt.Errorf("invalid task_id") } - upstreamName, err := decodeLocalTaskID(taskID) + upstreamName, err := taskcommon.DecodeLocalTaskID(taskID) if err != nil { return nil, fmt.Errorf("decode task_id failed: %w", err) } @@ -245,12 +250,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName) } payload := map[string]string{"operationName": upstreamName} - data, err := json.Marshal(payload) + data, err := common.Marshal(payload) if err != nil { return nil, err } adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(key), adc); err != nil { + if err := common.Unmarshal([]byte(key), adc); err != nil { return nil, fmt.Errorf("failed to decode credentials: %w", err) } token, err := vertexcore.AcquireAccessToken(*adc, proxy) @@ -274,7 +279,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { var op operationResponse - if err := json.Unmarshal(respBody, &op); err != nil { + if err := common.Unmarshal(respBody, &op); err != nil { return nil, fmt.Errorf("unmarshal operation response failed: %w", err) } ti := &relaycommon.TaskInfo{} @@ -338,7 +343,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { - upstreamName, err := decodeLocalTaskID(task.TaskID) + // Use GetUpstreamTaskID() to get the real upstream operation name for model extraction. + // task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name. + upstreamTaskID := task.GetUpstreamTaskID() + upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID) if err != nil { upstreamName = "" } @@ -353,8 +361,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { v.SetProgressStr(task.Progress) v.CreatedAt = task.CreatedAt v.CompletedAt = task.UpdatedAt - if strings.HasPrefix(task.FailReason, "data:") && len(task.FailReason) > 0 { - v.SetMetadata("url", task.FailReason) + if resultURL := task.GetResultURL(); strings.HasPrefix(resultURL, "data:") && len(resultURL) > 0 { + v.SetMetadata("url", resultURL) } return common.Marshal(v) @@ -364,18 +372,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { // helpers // ============================ -func encodeLocalTaskID(name string) string { - return base64.RawURLEncoding.EncodeToString([]byte(name)) -} - -func decodeLocalTaskID(local string) (string, error) { - b, err := base64.RawURLEncoding.DecodeString(local) - if err != nil { - return "", err - } - return string(b), nil -} - var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`) func extractRegionFromOperationName(name string) string { diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index 3657161c0..1bab12f03 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -2,7 +2,6 @@ package vidu import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -16,6 +15,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -127,7 +127,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -168,7 +168,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } var vResp responsePayload - err = json.Unmarshal(responseBody, &vResp) + err = common.Unmarshal(responseBody, &vResp) if err != nil { taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError) return @@ -180,8 +180,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = vResp.TaskId - ov.TaskID = vResp.TaskId + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -225,45 +225,25 @@ func (a *TaskAdaptor) GetChannelName() string { func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { r := requestPayload{ - Model: defaultString(req.Model, "viduq1"), + Model: taskcommon.DefaultString(req.Model, "viduq1"), Images: req.Images, Prompt: req.Prompt, - Duration: defaultInt(req.Duration, 5), - Resolution: defaultString(req.Size, "1080p"), + Duration: taskcommon.DefaultInt(req.Duration, 5), + Resolution: taskcommon.DefaultString(req.Size, "1080p"), MovementAmplitude: "auto", Bgm: false, } - metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } return &r, nil } -func defaultString(value, defaultValue string) string { - if value == "" { - return defaultValue - } - return value -} - -func defaultInt(value, defaultValue int) int { - if value == 0 { - return defaultValue - } - return value -} - func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { taskInfo := &relaycommon.TaskInfo{} var taskResp taskResultResponse - err := json.Unmarshal(respBody, &taskResp) + err := common.Unmarshal(respBody, &taskResp) if err != nil { return nil, errors.Wrap(err, "failed to unmarshal response body") } @@ -293,7 +273,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var viduResp taskResultResponse - if err := json.Unmarshal(originTask.Data, &viduResp); err != nil { + if err := common.Unmarshal(originTask.Data, &viduResp); err != nil { return nil, errors.Wrap(err, "unmarshal vidu task data failed") } @@ -315,6 +295,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 81b7d21d6..b68826812 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -118,8 +118,12 @@ type RelayInfo struct { SendResponseCount int ReceivedResponseCount int FinalPreConsumedQuota int // 最终预消耗的配额 + // ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路, + // 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行, + // 必须在提交前锁定全额。 + ForcePreConsume bool // Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。 - // 免费模型和按次计费(MJ/Task)时为 nil。 + // 免费模型时为 nil。 Billing BillingSettler // BillingSource indicates whether this request is billed from wallet quota or subscription. // "" or "wallet" => wallet; "subscription" => subscription @@ -525,8 +529,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req return nil, errors.New("request is not a OpenAIResponsesCompactionRequest") case types.RelayFormatTask: info = genBaseRelayInfo(c, nil) + info.TaskRelayInfo = &TaskRelayInfo{} case types.RelayFormatMjProxy: info = genBaseRelayInfo(c, nil) + info.TaskRelayInfo = &TaskRelayInfo{} default: err = errors.New("invalid relay format") } @@ -608,6 +614,9 @@ func (info *RelayInfo) HasSendResponse() bool { type TaskRelayInfo struct { Action string OriginTaskID string + // PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID, + // 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID)。 + PublicTaskID string ConsumeQuota bool } @@ -667,11 +676,11 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error { func (t *TaskSubmitReq) UnmarshalMetadata(v any) error { metadata := t.Metadata if metadata != nil { - metadataBytes, err := json.Marshal(metadata) + metadataBytes, err := common.Marshal(metadata) if err != nil { return fmt.Errorf("marshal metadata failed: %w", err) } - err = json.Unmarshal(metadataBytes, v) + err = common.Unmarshal(metadataBytes, v) if err != nil { return fmt.Errorf("unmarshal metadata to target failed: %w", err) } diff --git a/relay/helper/price.go b/relay/helper/price.go index c310220fe..1cb04166f 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -140,7 +140,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens } // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task) -func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData { +func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PriceData { groupRatioInfo := HandleGroupRatio(c, info) modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) @@ -154,7 +154,18 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types. } } quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) - priceData := types.PerCallPriceData{ + + // 免费模型检测(与 ModelPriceHelper 对齐) + freeModel := false + if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume { + if groupRatioInfo.GroupRatio == 0 || modelPrice == 0 { + quota = 0 + freeModel = true + } + } + + priceData := types.PriceData{ + FreeModel: freeModel, ModelPrice: modelPrice, Quota: quota, GroupRatioInfo: groupRatioInfo, diff --git a/relay/relay_task.go b/relay/relay_task.go index ebbd1f65d..d372ca2e8 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "encoding/json" "errors" "fmt" "io" @@ -15,29 +14,33 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" relayconstant "github.com/QuantumNous/new-api/relay/constant" + "github.com/QuantumNous/new-api/relay/helper" "github.com/QuantumNous/new-api/service" - "github.com/QuantumNous/new-api/setting/ratio_setting" - "github.com/gin-gonic/gin" ) -/* -Task 任务通过平台、Action 区分任务 -*/ -func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { - info.InitChannelMeta(c) - // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields - if info.TaskRelayInfo == nil { - info.TaskRelayInfo = &relaycommon.TaskRelayInfo{} - } +type TaskSubmitResult struct { + UpstreamTaskID string + TaskData []byte + Platform constant.TaskPlatform + ModelName string + Quota int + //PerCallPrice types.PriceData +} + +// ResolveOriginTask 处理基于已有任务的提交(remix / continuation): +// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道(并通过 +// specific_channel_id 禁止重试),以及提取 OtherRatios(时长、分辨率)。 +// 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。 +func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { + // 检测 remix action path := c.Request.URL.Path if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") { info.Action = constant.TaskActionRemix } - - // 提取 remix 任务的 video_id if info.Action == constant.TaskActionRemix { videoID := c.Param("video_id") if strings.TrimSpace(videoID) == "" { @@ -46,241 +49,164 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto. info.OriginTaskID = videoID } - platform := constant.TaskPlatform(c.GetString("platform")) + if info.OriginTaskID == "" { + return nil + } - // 获取原始任务信息 - if info.OriginTaskID != "" { - originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) - if err != nil { - taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) - return - } - if !exist { - taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) - return - } - if info.OriginModelName == "" { - if originTask.Properties.OriginModelName != "" { - info.OriginModelName = originTask.Properties.OriginModelName - } else if originTask.Properties.UpstreamModelName != "" { - info.OriginModelName = originTask.Properties.UpstreamModelName - } else { - var taskData map[string]interface{} - _ = json.Unmarshal(originTask.Data, &taskData) - if m, ok := taskData["model"].(string); ok && m != "" { - info.OriginModelName = m - platform = originTask.Platform - } - } - } - if originTask.ChannelId != info.ChannelId { - channel, err := model.GetChannelById(originTask.ChannelId, true) - if err != nil { - taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) - return - } - if channel.Status != common.ChannelStatusEnabled { - taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) - return - } - key, _, newAPIError := channel.GetNextEnabledKey() - if newAPIError != nil { - taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode) - return - } - common.SetContextKey(c, constant.ContextKeyChannelKey, key) - common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type) - common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL()) - common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId) + // 查找原始任务 + originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) + if err != nil { + return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) + } + if !exist { + return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) + } - info.ChannelBaseUrl = channel.GetBaseURL() - info.ChannelId = originTask.ChannelId - info.ChannelType = channel.Type - info.ApiKey = key - platform = originTask.Platform - } - - // 使用原始任务的参数 - if info.Action == constant.TaskActionRemix { + // 从原始任务推导模型名称 + if info.OriginModelName == "" { + if originTask.Properties.OriginModelName != "" { + info.OriginModelName = originTask.Properties.OriginModelName + } else if originTask.Properties.UpstreamModelName != "" { + info.OriginModelName = originTask.Properties.UpstreamModelName + } else { var taskData map[string]interface{} - _ = json.Unmarshal(originTask.Data, &taskData) - secondsStr, _ := taskData["seconds"].(string) - seconds, _ := strconv.Atoi(secondsStr) - if seconds <= 0 { - seconds = 4 - } - sizeStr, _ := taskData["size"].(string) - if info.PriceData.OtherRatios == nil { - info.PriceData.OtherRatios = map[string]float64{} - } - info.PriceData.OtherRatios["seconds"] = float64(seconds) - info.PriceData.OtherRatios["size"] = 1 - if sizeStr == "1792x1024" || sizeStr == "1024x1792" { - info.PriceData.OtherRatios["size"] = 1.666667 + _ = common.Unmarshal(originTask.Data, &taskData) + if m, ok := taskData["model"].(string); ok && m != "" { + info.OriginModelName = m } } } + + // 锁定到原始任务的渠道(如果与当前选中的不同) + if originTask.ChannelId != info.ChannelId { + ch, err := model.GetChannelById(originTask.ChannelId, true) + if err != nil { + return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) + } + if ch.Status != common.ChannelStatusEnabled { + return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) + } + key, _, newAPIError := ch.GetNextEnabledKey() + if newAPIError != nil { + return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode) + } + common.SetContextKey(c, constant.ContextKeyChannelKey, key) + common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type) + common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL()) + common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId) + + info.ChannelBaseUrl = ch.GetBaseURL() + info.ChannelId = originTask.ChannelId + info.ChannelType = ch.Type + info.ApiKey = key + } + + // 渠道已锁定到原始任务 → 禁止重试切换到其他渠道 + c.Set("specific_channel_id", fmt.Sprintf("%d", originTask.ChannelId)) + + // 提取 remix 参数(时长、分辨率 → OtherRatios) + if info.Action == constant.TaskActionRemix { + var taskData map[string]interface{} + _ = common.Unmarshal(originTask.Data, &taskData) + secondsStr, _ := taskData["seconds"].(string) + seconds, _ := strconv.Atoi(secondsStr) + if seconds <= 0 { + seconds = 4 + } + sizeStr, _ := taskData["size"].(string) + if info.PriceData.OtherRatios == nil { + info.PriceData.OtherRatios = map[string]float64{} + } + info.PriceData.OtherRatios["seconds"] = float64(seconds) + info.PriceData.OtherRatios["size"] = 1 + if sizeStr == "1792x1024" || sizeStr == "1024x1792" { + info.PriceData.OtherRatios["size"] = 1.666667 + } + } + + return nil +} + +// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次): +// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → 计算价格 → +// 预扣费(仅首次,通过 info.Billing==nil 守卫)→ 构建/发送/解析上游请求。 +// 控制器负责 defer Refund 和成功后 Settle。 +func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) { + info.InitChannelMeta(c) + + // 1. 确定 platform → 创建适配器 → 验证请求 + platform := constant.TaskPlatform(c.GetString("platform")) if platform == "" { platform = GetTaskPlatform(c) } - - info.InitChannelMeta(c) adaptor := GetTaskAdaptor(platform) if adaptor == nil { - return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) + return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) } adaptor.Init(info) - // get & validate taskRequest 获取并验证文本请求 - taskErr = adaptor.ValidateRequestAndSetAction(c, info) - if taskErr != nil { - return + if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil { + return nil, taskErr } + // 2. 确定模型名称 modelName := info.OriginModelName if modelName == "" { modelName = service.CoverTaskActionToModelName(platform, info.Action) } - modelPrice, success := ratio_setting.GetModelPrice(modelName, true) - if !success { - defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName] - if !ok { - modelPrice = float64(common.PreConsumedQuota) / common.QuotaPerUnit - } else { - modelPrice = defaultPrice - } + + // 3. 预生成公开 task ID(仅首次) + if info.PublicTaskID == "" { + info.PublicTaskID = model.GenerateTaskID() } - // 处理 auto 分组:从 context 获取实际选中的分组 - // 当使用 auto 分组时,Distribute 中间件会将实际选中的分组存储在 ContextKeyAutoGroup 中 - if autoGroup, exists := common.GetContextKey(c, constant.ContextKeyAutoGroup); exists { - if groupStr, ok := autoGroup.(string); ok && groupStr != "" { - info.UsingGroup = groupStr - } - } + // 4. 价格计算 + info.OriginModelName = modelName + info.PriceData = helper.ModelPriceHelperPerCall(c, info) - // 预扣 - groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup) - var ratio float64 - userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup) - if hasUserGroupRatio { - ratio = modelPrice * userGroupRatio - } else { - ratio = modelPrice * groupRatio - } - // FIXME: 临时修补,支持任务仅按次计费 if !common.StringsContains(constant.TaskPricePatches, modelName) { - if len(info.PriceData.OtherRatios) > 0 { - for _, ra := range info.PriceData.OtherRatios { - if 1.0 != ra { - ratio *= ra - } + for _, ra := range info.PriceData.OtherRatios { + if ra != 1.0 { + info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra) } } } - println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio)) - userQuota, err := model.GetUserQuota(info.UserId, false) - if err != nil { - taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) - return - } - quota := int(ratio * common.QuotaPerUnit) - if userQuota-quota < 0 { - taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden) - return + + // 5. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过) + if info.Billing == nil && !info.PriceData.FreeModel { + info.ForcePreConsume = true + if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil { + return nil, service.TaskErrorFromAPIError(apiErr) + } } - // build body + // 6. 构建请求体 requestBody, err := adaptor.BuildRequestBody(c, info) if err != nil { - taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) - return + return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) } - // do request + + // 7. 发送请求 resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { - taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) - return + return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - // handle response if resp != nil && resp.StatusCode != http.StatusOK { responseBody, _ := io.ReadAll(resp.Body) - taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode) - return + return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode) } - defer func() { - // release quota - if info.ConsumeQuota && taskErr == nil { - - err := service.PostConsumeQuota(info, quota, 0, true) - if err != nil { - common.SysLog("error consuming token remain quota: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - //gRatio := groupRatio - //if hasUserGroupRatio { - // gRatio = userGroupRatio - //} - logContent := fmt.Sprintf("操作 %s", info.Action) - // FIXME: 临时修补,支持任务仅按次计费 - if common.StringsContains(constant.TaskPricePatches, modelName) { - logContent = fmt.Sprintf("%s,按次计费", logContent) - } else { - if len(info.PriceData.OtherRatios) > 0 { - var contents []string - for key, ra := range info.PriceData.OtherRatios { - if 1.0 != ra { - contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra)) - } - } - if len(contents) > 0 { - logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", ")) - } - } - } - other := make(map[string]interface{}) - if c != nil && c.Request != nil && c.Request.URL != nil { - other["request_path"] = c.Request.URL.Path - } - other["model_price"] = modelPrice - other["group_ratio"] = groupRatio - if hasUserGroupRatio { - other["user_group_ratio"] = userGroupRatio - } - model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ - ChannelId: info.ChannelId, - ModelName: modelName, - TokenName: tokenName, - Quota: quota, - Content: logContent, - TokenId: info.TokenId, - Group: info.UsingGroup, - Other: other, - }) - model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota) - model.UpdateChannelUsedQuota(info.ChannelId, quota) - } - } - }() - - taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) + // 8. 解析响应 + upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) if taskErr != nil { - return + return nil, taskErr } - info.ConsumeQuota = true - // insert task - task := model.InitTask(platform, info) - task.TaskID = taskID - task.Quota = quota - task.Data = taskData - task.Action = info.Action - err = task.Insert() - if err != nil { - taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) - return - } - return nil + + return &TaskSubmitResult{ + UpstreamTaskID: upstreamTaskID, + TaskData: taskData, + Platform: platform, + ModelName: modelName, + }, nil } var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ @@ -336,7 +262,7 @@ func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.Ta } else { tasks = make([]any, 0) } - respBody, err = json.Marshal(dto.TaskResponse[[]any]{ + respBody, err = common.Marshal(dto.TaskResponse[[]any]{ Code: "success", Data: tasks, }) @@ -357,7 +283,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt return } - respBody, err = json.Marshal(dto.TaskResponse[any]{ + respBody, err = common.Marshal(dto.TaskResponse[any]{ Code: "success", Data: TaskModel2Dto(originTask), }) @@ -381,97 +307,16 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d return } - func() { - channelModel, err2 := model.GetChannelById(originTask.ChannelId, true) - if err2 != nil { - return - } - if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini { - return - } - baseURL := constant.ChannelBaseURLs[channelModel.Type] - if channelModel.GetBaseURL() != "" { - baseURL = channelModel.GetBaseURL() - } - proxy := channelModel.GetSetting().Proxy - adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) - if adaptor == nil { - return - } - resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ - "task_id": originTask.TaskID, - "action": originTask.Action, - }, proxy) - if err2 != nil || resp == nil { - return - } - defer resp.Body.Close() - body, err2 := io.ReadAll(resp.Body) - if err2 != nil { - return - } - ti, err2 := adaptor.ParseTaskResult(body) - if err2 == nil && ti != nil { - if ti.Status != "" { - originTask.Status = model.TaskStatus(ti.Status) - } - if ti.Progress != "" { - originTask.Progress = ti.Progress - } - if ti.Url != "" { - if strings.HasPrefix(ti.Url, "data:") { - } else { - originTask.FailReason = ti.Url - } - } - _ = originTask.Update() - var raw map[string]any - _ = json.Unmarshal(body, &raw) - format := "mp4" - if respObj, ok := raw["response"].(map[string]any); ok { - if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 { - if v0, ok := vids[0].(map[string]any); ok { - if mt, ok := v0["mimeType"].(string); ok && mt != "" { - if strings.Contains(mt, "mp4") { - format = "mp4" - } else { - format = mt - } - } - } - } - } - status := "processing" - switch originTask.Status { - case model.TaskStatusSuccess: - status = "succeeded" - case model.TaskStatusFailure: - status = "failed" - case model.TaskStatusQueued, model.TaskStatusSubmitted: - status = "queued" - } - if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") { - out := map[string]any{ - "error": nil, - "format": format, - "metadata": nil, - "status": status, - "task_id": originTask.TaskID, - "url": originTask.FailReason, - } - respBody, _ = json.Marshal(dto.TaskResponse[any]{ - Code: "success", - Data: out, - }) - } - } - }() + isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") - if len(respBody) != 0 { + // Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态 + if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 { + respBody = realtimeResp return } - if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") { + // OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo + if isOpenAIVideoAPI { adaptor := GetTaskAdaptor(originTask.Platform) if adaptor == nil { taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest) @@ -486,10 +331,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d respBody = openAIVideoData return } - taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented) + taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented) return } - respBody, err = json.Marshal(dto.TaskResponse[any]{ + + // 通用 TaskDto 格式 + respBody, err = common.Marshal(dto.TaskResponse[any]{ Code: "success", Data: TaskModel2Dto(originTask), }) @@ -499,16 +346,145 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d return } +// tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。 +// 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。 +// 当非 OpenAI Video API 时,还会构建自定义格式的响应体。 +func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte { + channelModel, err := model.GetChannelById(task.ChannelId, true) + if err != nil { + return nil + } + if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini { + return nil + } + + baseURL := constant.ChannelBaseURLs[channelModel.Type] + if channelModel.GetBaseURL() != "" { + baseURL = channelModel.GetBaseURL() + } + proxy := channelModel.GetSetting().Proxy + adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) + if adaptor == nil { + return nil + } + + resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ + "task_id": task.GetUpstreamTaskID(), + "action": task.Action, + }, proxy) + if err != nil || resp == nil { + return nil + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil + } + + ti, err := adaptor.ParseTaskResult(body) + if err != nil || ti == nil { + return nil + } + + // 将上游最新状态更新到 task + if ti.Status != "" { + task.Status = model.TaskStatus(ti.Status) + } + if ti.Progress != "" { + task.Progress = ti.Progress + } + if strings.HasPrefix(ti.Url, "data:") { + // data: URI — kept in Data, not ResultURL + } else if ti.Url != "" { + task.PrivateData.ResultURL = ti.Url + } else if task.Status == model.TaskStatusSuccess { + // No URL from adaptor — construct proxy URL using public task ID + task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) + } + _ = task.Update() + + // OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理 + if isOpenAIVideoAPI { + return nil + } + + // 非 OpenAI Video API: 构建自定义格式响应 + format := detectVideoFormat(body) + out := map[string]any{ + "error": nil, + "format": format, + "metadata": nil, + "status": mapTaskStatusToSimple(task.Status), + "task_id": task.TaskID, + "url": task.GetResultURL(), + } + respBody, _ := common.Marshal(dto.TaskResponse[any]{ + Code: "success", + Data: out, + }) + return respBody +} + +// detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式 +func detectVideoFormat(rawBody []byte) string { + var raw map[string]any + if err := common.Unmarshal(rawBody, &raw); err != nil { + return "mp4" + } + respObj, ok := raw["response"].(map[string]any) + if !ok { + return "mp4" + } + vids, ok := respObj["videos"].([]any) + if !ok || len(vids) == 0 { + return "mp4" + } + v0, ok := vids[0].(map[string]any) + if !ok { + return "mp4" + } + mt, ok := v0["mimeType"].(string) + if !ok || mt == "" || strings.Contains(mt, "mp4") { + return "mp4" + } + return mt +} + +// mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串 +func mapTaskStatusToSimple(status model.TaskStatus) string { + switch status { + case model.TaskStatusSuccess: + return "succeeded" + case model.TaskStatusFailure: + return "failed" + case model.TaskStatusQueued, model.TaskStatusSubmitted: + return "queued" + default: + return "processing" + } +} + func TaskModel2Dto(task *model.Task) *dto.TaskDto { return &dto.TaskDto{ + ID: task.ID, + CreatedAt: task.CreatedAt, + UpdatedAt: task.UpdatedAt, TaskID: task.TaskID, + Platform: string(task.Platform), + UserId: task.UserId, + Group: task.Group, + ChannelId: task.ChannelId, + Quota: task.Quota, Action: task.Action, Status: string(task.Status), FailReason: task.FailReason, + ResultURL: task.GetResultURL(), SubmitTime: task.SubmitTime, StartTime: task.StartTime, FinishTime: task.FinishTime, Progress: task.Progress, + Properties: task.Properties, + Username: task.Username, Data: task.Data, } } diff --git a/router/video-router.go b/router/video-router.go index d5fed1d78..d2bce42b2 100644 --- a/router/video-router.go +++ b/router/video-router.go @@ -8,10 +8,16 @@ import ( ) func SetVideoRouter(router *gin.Engine) { + // Video proxy: accepts either session auth (dashboard) or token auth (API clients) + videoProxyRouter := router.Group("/v1") + videoProxyRouter.Use(middleware.TokenOrUserAuth()) + { + videoProxyRouter.GET("/videos/:task_id/content", controller.VideoProxy) + } + videoV1Router := router.Group("/v1") videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) { - videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy) videoV1Router.POST("/video/generations", controller.RelayTask) videoV1Router.GET("/video/generations/:task_id", controller.RelayTask) videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask) diff --git a/service/billing_session.go b/service/billing_session.go index 1a31316b5..f24b68e55 100644 --- a/service/billing_session.go +++ b/service/billing_session.go @@ -193,6 +193,11 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro // shouldTrust 统一信任额度检查,适用于钱包和订阅。 func (s *BillingSession) shouldTrust(c *gin.Context) bool { + // 异步任务(ForcePreConsume=true)必须预扣全额,不允许信任旁路 + if s.relayInfo.ForcePreConsume { + return false + } + trustQuota := common.GetTrustQuota() if trustQuota <= 0 { return false diff --git a/service/error.go b/service/error.go index 7a9d7a815..a2ff0aad7 100644 --- a/service/error.go +++ b/service/error.go @@ -206,3 +206,16 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError { return taskError } + +// TaskErrorFromAPIError 将 PreConsumeBilling 返回的 NewAPIError 转换为 TaskError。 +func TaskErrorFromAPIError(apiErr *types.NewAPIError) *dto.TaskError { + if apiErr == nil { + return nil + } + return &dto.TaskError{ + Code: string(apiErr.GetErrorCode()), + Message: apiErr.Err.Error(), + StatusCode: apiErr.StatusCode, + Error: apiErr.Err, + } +} diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 771da5b77..1c440911b 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -204,7 +204,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, return info } -func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PerCallPriceData) map[string]interface{} { +func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PriceData) map[string]interface{} { other := make(map[string]interface{}) other["model_price"] = priceData.ModelPrice other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio diff --git a/service/task_billing.go b/service/task_billing.go new file mode 100644 index 000000000..ec0094bd9 --- /dev/null +++ b/service/task_billing.go @@ -0,0 +1,227 @@ +package service + +import ( + "context" + "fmt" + "strings" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/setting/ratio_setting" + "github.com/gin-gonic/gin" +) + +// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。 +// 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。 +func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("操作 %s", info.Action) + // 支持任务仅按次计费 + if common.StringsContains(constant.TaskPricePatches, modelName) { + logContent = fmt.Sprintf("%s,按次计费", logContent) + } else { + if len(info.PriceData.OtherRatios) > 0 { + var contents []string + for key, ra := range info.PriceData.OtherRatios { + if 1.0 != ra { + contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra)) + } + } + if len(contents) > 0 { + logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", ")) + } + } + } + other := make(map[string]interface{}) + other["request_path"] = c.Request.URL.Path + other["model_price"] = info.PriceData.ModelPrice + other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio + if info.PriceData.GroupRatioInfo.HasSpecialRatio { + other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio + } + model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ + ChannelId: info.ChannelId, + ModelName: modelName, + TokenName: tokenName, + Quota: info.PriceData.Quota, + Content: logContent, + TokenId: info.TokenId, + Group: info.UsingGroup, + Other: other, + }) + model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota) + model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota) +} + +// --------------------------------------------------------------------------- +// 异步任务计费辅助函数 +// --------------------------------------------------------------------------- + +// resolveTokenKey 通过 TokenId 运行时获取令牌 Key(用于 Redis 缓存操作)。 +// 如果令牌已被删除或查询失败,返回空字符串。 +func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string { + token, err := model.GetTokenById(tokenId) + if err != nil { + logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error())) + return "" + } + return token.Key +} + +// taskIsSubscription 判断任务是否通过订阅计费。 +func taskIsSubscription(task *model.Task) bool { + return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0 +} + +// taskAdjustFunding 调整任务的资金来源(钱包或订阅),delta > 0 表示扣费,delta < 0 表示退还。 +func taskAdjustFunding(task *model.Task, delta int) error { + if taskIsSubscription(task) { + return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta)) + } + if delta > 0 { + return model.DecreaseUserQuota(task.UserId, delta) + } + return model.IncreaseUserQuota(task.UserId, -delta, false) +} + +// taskAdjustTokenQuota 调整任务的令牌额度,delta > 0 表示扣费,delta < 0 表示退还。 +// 需要通过 resolveTokenKey 运行时获取 key(不从 PrivateData 中读取)。 +func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) { + if task.PrivateData.TokenId <= 0 || delta == 0 { + return + } + tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID) + if tokenKey == "" { + return + } + var err error + if delta > 0 { + err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta) + } else { + err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta) + } + if err != nil { + logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error())) + } +} + +// RefundTaskQuota 统一的任务失败退款逻辑。 +// 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。 +func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) { + quota := task.Quota + if quota == 0 { + return + } + + // 1. 退还资金来源(钱包或订阅) + if err := taskAdjustFunding(task, -quota); err != nil { + logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error())) + return + } + + // 2. 退还令牌额度 + taskAdjustTokenQuota(ctx, task, -quota) + + // 3. 记录日志 + logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s,原因:%s", task.TaskID, logger.LogQuota(quota), reason) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) +} + +// RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。 +// 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度, +// 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。 +func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) { + if totalTokens <= 0 { + return + } + + // 获取模型名称 + var taskData map[string]interface{} + if err := common.Unmarshal(task.Data, &taskData); err != nil { + return + } + modelName, ok := taskData["model"].(string) + if !ok || modelName == "" { + return + } + + // 获取模型价格和倍率 + modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName) + // 只有配置了倍率(非固定价格)时才按 token 重新计费 + if !hasRatioSetting || modelRatio <= 0 { + return + } + + // 获取用户和组的倍率信息 + group := task.Group + if group == "" { + user, err := model.GetUserById(task.UserId, false) + if err == nil { + group = user.Group + } + } + if group == "" { + return + } + + groupRatio := ratio_setting.GetGroupRatio(group) + userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group) + + var finalGroupRatio float64 + if hasUserGroupRatio { + finalGroupRatio = userGroupRatio + } else { + finalGroupRatio = groupRatio + } + + // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio + actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio) + + // 计算差额(正数=需要补扣,负数=需要退还) + preConsumedQuota := task.Quota + quotaDelta := actualQuota - preConsumedQuota + + if quotaDelta == 0 { + logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)", + task.TaskID, logger.LogQuota(actualQuota), totalTokens)) + return + } + + logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,tokens:%d)", + task.TaskID, + logger.LogQuota(quotaDelta), + logger.LogQuota(actualQuota), + logger.LogQuota(preConsumedQuota), + totalTokens, + )) + + // 调整资金来源 + if err := taskAdjustFunding(task, quotaDelta); err != nil { + logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error())) + return + } + + // 调整令牌额度 + taskAdjustTokenQuota(ctx, task, quotaDelta) + + // 更新统计(仅补扣时更新,退还不影响已用统计) + if quotaDelta > 0 { + model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) + model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) + } + task.Quota = actualQuota + + var action string + if quotaDelta > 0 { + action = "补扣费" + } else { + action = "退还" + } + logContent := fmt.Sprintf("视频任务成功%s,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s", + action, modelRatio, finalGroupRatio, totalTokens, + logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) +} diff --git a/service/task_polling.go b/service/task_polling.go new file mode 100644 index 000000000..847e1659b --- /dev/null +++ b/service/task_polling.go @@ -0,0 +1,446 @@ +package service + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "sort" + "strings" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" + relaycommon "github.com/QuantumNous/new-api/relay/common" + + "github.com/samber/lo" +) + +// TaskPollingAdaptor 定义轮询所需的最小适配器接口,避免 service -> relay 的循环依赖 +type TaskPollingAdaptor interface { + Init(info *relaycommon.RelayInfo) + FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error) + ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error) +} + +// GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。 +// 打破 service -> relay -> relay/channel -> service 的循环依赖。 +var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor + +// TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务 +func TaskPollingLoop() { + for { + time.Sleep(time.Duration(15) * time.Second) + common.SysLog("任务进度轮询开始") + ctx := context.TODO() + allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit) + platformTask := make(map[constant.TaskPlatform][]*model.Task) + for _, t := range allTasks { + platformTask[t.Platform] = append(platformTask[t.Platform], t) + } + for platform, tasks := range platformTask { + if len(tasks) == 0 { + continue + } + taskChannelM := make(map[int][]string) + taskM := make(map[string]*model.Task) + nullTaskIds := make([]int64, 0) + for _, task := range tasks { + upstreamID := task.GetUpstreamTaskID() + if upstreamID == "" { + // 统计失败的未完成任务 + nullTaskIds = append(nullTaskIds, task.ID) + continue + } + taskM[upstreamID] = task + taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID) + } + if len(nullTaskIds) > 0 { + err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{ + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) + } else { + logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) + } + } + if len(taskChannelM) == 0 { + continue + } + + DispatchPlatformUpdate(platform, taskChannelM, taskM) + } + common.SysLog("任务进度轮询完成") + } +} + +// DispatchPlatformUpdate 按平台分发轮询更新 +func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) { + switch platform { + case constant.TaskPlatformMidjourney: + // MJ 轮询由其自身处理,这里预留入口 + case constant.TaskPlatformSuno: + _ = UpdateSunoTasks(context.Background(), taskChannelM, taskM) + default: + if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil { + common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err)) + } + } +} + +// UpdateSunoTasks 按渠道更新所有 Suno 任务 +func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { + for channelId, taskIds := range taskChannelM { + err := updateSunoTasks(ctx, channelId, taskIds, taskM) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error())) + } + } + return nil +} + +func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { + logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + return nil + } + ch, err := model.CacheGetChannel(channelId) + if err != nil { + common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) + // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values) + var failedIDs []int64 + for _, upstreamID := range taskIds { + if t, ok := taskM[upstreamID]; ok { + failedIDs = append(failedIDs, t.ID) + } + } + err = model.TaskBulkUpdateByID(failedIDs, map[string]any{ + "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", err)) + } + return err + } + adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno) + if adaptor == nil { + return errors.New("adaptor not found") + } + proxy := ch.GetSetting().Proxy + resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{ + "ids": taskIds, + }, proxy) + if err != nil { + common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) + return err + } + if resp.StatusCode != http.StatusOK { + logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + return fmt.Errorf("Get Task status code: %d", resp.StatusCode) + } + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err)) + return err + } + var responseItems dto.TaskResponse[[]dto.SunoDataResponse] + err = common.Unmarshal(responseBody, &responseItems) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + return err + } + if !responseItems.IsSuccess() { + common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody))) + return err + } + + for _, responseItem := range responseItems.Data { + task := taskM[responseItem.TaskID] + if !taskNeedsUpdate(task, responseItem) { + continue + } + + task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status) + task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason) + task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime) + task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) + task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) + if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { + logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) + task.Progress = "100%" + RefundTaskQuota(ctx, task, task.FailReason) + } + if responseItem.Status == model.TaskStatusSuccess { + task.Progress = "100%" + } + task.Data = responseItem.Data + + err = task.Update() + if err != nil { + common.SysLog("UpdateSunoTask task error: " + err.Error()) + } + } + return nil +} + +// taskNeedsUpdate 检查 Suno 任务是否需要更新 +func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool { + if oldTask.SubmitTime != newTask.SubmitTime { + return true + } + if oldTask.StartTime != newTask.StartTime { + return true + } + if oldTask.FinishTime != newTask.FinishTime { + return true + } + if string(oldTask.Status) != newTask.Status { + return true + } + if oldTask.FailReason != newTask.FailReason { + return true + } + + if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" { + return true + } + + oldData, _ := common.Marshal(oldTask.Data) + newData, _ := common.Marshal(newTask.Data) + + sort.Slice(oldData, func(i, j int) bool { + return oldData[i] < oldData[j] + }) + sort.Slice(newData, func(i, j int) bool { + return newData[i] < newData[j] + }) + + if string(oldData) != string(newData) { + return true + } + return false +} + +// UpdateVideoTasks 按渠道更新所有视频任务 +func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { + for channelId, taskIds := range taskChannelM { + if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil { + logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) + } + } + return nil +} + +func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { + logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + return nil + } + cacheGetChannel, err := model.CacheGetChannel(channelId) + if err != nil { + // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values) + var failedIDs []int64 + for _, upstreamID := range taskIds { + if t, ok := taskM[upstreamID]; ok { + failedIDs = append(failedIDs, t.ID) + } + } + errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{ + "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if errUpdate != nil { + common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) + } + return fmt.Errorf("CacheGetChannel failed: %w", err) + } + adaptor := GetTaskAdaptorFunc(platform) + if adaptor == nil { + return fmt.Errorf("video adaptor not found") + } + info := &relaycommon.RelayInfo{} + info.ChannelMeta = &relaycommon.ChannelMeta{ + ChannelBaseUrl: cacheGetChannel.GetBaseURL(), + } + info.ApiKey = cacheGetChannel.Key + adaptor.Init(info) + for _, taskId := range taskIds { + if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { + logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) + } + } + return nil +} + +func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error { + baseURL := constant.ChannelBaseURLs[ch.Type] + if ch.GetBaseURL() != "" { + baseURL = ch.GetBaseURL() + } + proxy := ch.GetSetting().Proxy + + task := taskM[taskId] + if task == nil { + logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) + return fmt.Errorf("task %s not found", taskId) + } + key := ch.Key + + privateData := task.PrivateData + if privateData.Key != "" { + key = privateData.Key + } + resp, err := adaptor.FetchTask(baseURL, key, map[string]any{ + "task_id": task.GetUpstreamTaskID(), + "action": task.Action, + }, proxy) + if err != nil { + return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) + } + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("readAll failed for task %s: %w", taskId, err) + } + + logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody))) + + taskResult := &relaycommon.TaskInfo{} + // try parse as New API response format + var responseItems dto.TaskResponse[model.Task] + if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() { + logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems)) + t := responseItems.Data + taskResult.TaskID = t.TaskID + taskResult.Status = string(t.Status) + taskResult.Url = t.GetResultURL() + taskResult.Progress = t.Progress + taskResult.Reason = t.FailReason + task.Data = t.Data + } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { + return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) + } else { + task.Data = redactVideoResponseBody(responseBody) + } + + logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult)) + + now := time.Now().Unix() + if taskResult.Status == "" { + taskResult = relaycommon.FailTaskInfo("upstream returned empty status") + } + + // 记录原本的状态,防止重复退款 + shouldRefund := false + quota := task.Quota + preStatus := task.Status + + task.Status = model.TaskStatus(taskResult.Status) + switch taskResult.Status { + case model.TaskStatusSubmitted: + task.Progress = taskcommon.ProgressSubmitted + case model.TaskStatusQueued: + task.Progress = taskcommon.ProgressQueued + case model.TaskStatusInProgress: + task.Progress = taskcommon.ProgressInProgress + if task.StartTime == 0 { + task.StartTime = now + } + case model.TaskStatusSuccess: + task.Progress = taskcommon.ProgressComplete + if task.FinishTime == 0 { + task.FinishTime = now + } + if strings.HasPrefix(taskResult.Url, "data:") { + // data: URI (e.g. Vertex base64 encoded video) — keep in Data, not in ResultURL + } else if taskResult.Url != "" { + // Direct upstream URL (e.g. Kling, Ali, Doubao, etc.) + task.PrivateData.ResultURL = taskResult.Url + } else { + // No URL from adaptor — construct proxy URL using public task ID + task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) + } + + // 如果返回了 total_tokens,根据模型倍率重新计费 + if taskResult.TotalTokens > 0 { + RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens) + } + case model.TaskStatusFailure: + logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) + task.Status = model.TaskStatusFailure + task.Progress = taskcommon.ProgressComplete + if task.FinishTime == 0 { + task.FinishTime = now + } + task.FailReason = taskResult.Reason + logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) + taskResult.Progress = taskcommon.ProgressComplete + if quota != 0 { + if preStatus != model.TaskStatusFailure { + shouldRefund = true + } else { + logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID)) + } + } + default: + return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId) + } + if taskResult.Progress != "" { + task.Progress = taskResult.Progress + } + if err := task.Update(); err != nil { + common.SysLog("UpdateVideoTask task error: " + err.Error()) + shouldRefund = false + } + + if shouldRefund { + RefundTaskQuota(ctx, task, task.FailReason) + } + + return nil +} + +func redactVideoResponseBody(body []byte) []byte { + var m map[string]any + if err := common.Unmarshal(body, &m); err != nil { + return body + } + resp, _ := m["response"].(map[string]any) + if resp != nil { + delete(resp, "bytesBase64Encoded") + if v, ok := resp["video"].(string); ok { + resp["video"] = truncateBase64(v) + } + if vs, ok := resp["videos"].([]any); ok { + for i := range vs { + if vm, ok := vs[i].(map[string]any); ok { + delete(vm, "bytesBase64Encoded") + } + } + } + } + b, err := common.Marshal(m) + if err != nil { + return body + } + return b +} + +func truncateBase64(s string) string { + const maxKeep = 256 + if len(s) <= maxKeep { + return s + } + return s[:maxKeep] + "..." +} diff --git a/types/price_data.go b/types/price_data.go index 3f7121b8c..93bc6ae8d 100644 --- a/types/price_data.go +++ b/types/price_data.go @@ -22,7 +22,8 @@ type PriceData struct { AudioCompletionRatio float64 OtherRatios map[string]float64 UsePrice bool - QuotaToPreConsume int // 预消耗额度 + Quota int // 按次计费的最终额度(MJ / Task) + QuotaToPreConsume int // 按量计费的预消耗额度 GroupRatioInfo GroupRatioInfo } @@ -36,12 +37,6 @@ func (p *PriceData) AddOtherRatio(key string, ratio float64) { p.OtherRatios[key] = ratio } -type PerCallPriceData struct { - ModelPrice float64 - Quota int - GroupRatioInfo GroupRatioInfo -} - func (p *PriceData) ToSetting() string { return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio) } diff --git a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx index c78d5773e..4bce45256 100644 --- a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx +++ b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx @@ -396,7 +396,7 @@ export const getTaskLogsColumns = ({ dataIndex: 'fail_reason', fixed: 'right', render: (text, record, index) => { - // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接 + // 视频预览:优先使用 result_url,兼容旧数据 fail_reason 中的 URL const isVideoTask = record.action === TASK_ACTION_GENERATE || record.action === TASK_ACTION_TEXT_GENERATE || @@ -404,14 +404,15 @@ export const getTaskLogsColumns = ({ record.action === TASK_ACTION_REFERENCE_GENERATE || record.action === TASK_ACTION_REMIX_GENERATE; const isSuccess = record.status === 'SUCCESS'; - const isUrl = typeof text === 'string' && /^https?:\/\//.test(text); - if (isSuccess && isVideoTask && isUrl) { + const resultUrl = record.result_url; + const hasResultUrl = typeof resultUrl === 'string' && /^https?:\/\//.test(resultUrl); + if (isSuccess && isVideoTask && hasResultUrl) { return ( { e.preventDefault(); - openVideoModal(text); + openVideoModal(resultUrl); }} > {t('点击预览视频')} diff --git a/web/src/components/table/task-logs/modals/ContentModal.jsx b/web/src/components/table/task-logs/modals/ContentModal.jsx index 88df4d8ce..3527fd96d 100644 --- a/web/src/components/table/task-logs/modals/ContentModal.jsx +++ b/web/src/components/table/task-logs/modals/ContentModal.jsx @@ -144,8 +144,6 @@ const ContentModal = ({ maxHeight: '100%', objectFit: 'contain', }} - autoPlay - crossOrigin='anonymous' onError={handleVideoError} onLoadedData={handleVideoLoaded} onLoadStart={() => setIsLoading(true)}