mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 06:20:56 +00:00
refactor(task): extract billing and polling logic from controller to service layer
Restructure the task relay system for better separation of concerns: - Extract task billing into service/task_billing.go with unified settlement flow - Move task polling loop from controller to service/task_polling.go (supports Suno + video platforms) - Split RelayTask into fetch/submit paths with dedicated retry logic (taskSubmitWithRetry) - Add TaskDto, TaskResponse generics, and FetchReq to dto/task.go - Add taskcommon/helpers.go for shared task adaptor utilities - Remove controller/task_video.go (logic consolidated into service layer) - Update all task adaptors (ali, doubao, gemini, hailuo, jimeng, kling, sora, suno, vertex, vidu) - Simplify frontend task logs to use new TaskDto response format
This commit is contained in:
@@ -451,17 +451,102 @@ func RelayNotFound(c *gin.Context) {
|
||||
}
|
||||
|
||||
func RelayTask(c *gin.Context) {
|
||||
retryTimes := common.RetryTimes
|
||||
channelId := c.GetInt("channel_id")
|
||||
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
||||
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, &dto.TaskError{
|
||||
Code: "gen_relay_info_failed",
|
||||
Message: err.Error(),
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
})
|
||||
return
|
||||
}
|
||||
taskErr := taskRelayHandler(c, relayInfo)
|
||||
if taskErr == nil {
|
||||
retryTimes = 0
|
||||
|
||||
// Fetch 操作是纯 DB 查询(或 task 自带 channelId 的上游查询),不依赖上下文 channel,无需重试
|
||||
// TODO: 在video-route层面优化,避免无谓的 channel 选择和上下文设置,也没必要吧代码放到这里来写这么多屎山
|
||||
switch relayInfo.RelayMode {
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
|
||||
if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil {
|
||||
respondTaskError(c, taskErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ── Submit 路径 ─────────────────────────────────────────────────
|
||||
|
||||
// 1. 解析原始任务(remix / continuation),一次性,可能锁定渠道并禁止重试
|
||||
if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil {
|
||||
respondTaskError(c, taskErr)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. defer Refund(全部失败时回滚预扣费)
|
||||
var result *relay.TaskSubmitResult
|
||||
var taskErr *dto.TaskError
|
||||
defer func() {
|
||||
if taskErr != nil && relayInfo.Billing != nil {
|
||||
relayInfo.Billing.Refund(c)
|
||||
}
|
||||
}()
|
||||
|
||||
// 3. 执行 + 重试(RelayTaskSubmit 内部在首次调用时自动预扣费)
|
||||
taskErr = taskSubmitWithRetry(c, relayInfo, channelId, common.RetryTimes, func() *dto.TaskError {
|
||||
var te *dto.TaskError
|
||||
result, te = relay.RelayTaskSubmit(c, relayInfo)
|
||||
return te
|
||||
})
|
||||
|
||||
// 4. 成功:结算 + 日志 + 插入任务
|
||||
if taskErr == nil {
|
||||
if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil {
|
||||
common.SysError("settle task billing error: " + settleErr.Error())
|
||||
}
|
||||
service.LogTaskConsumption(c, relayInfo, result.ModelName)
|
||||
|
||||
task := model.InitTask(result.Platform, relayInfo)
|
||||
task.PrivateData.UpstreamTaskID = result.UpstreamTaskID
|
||||
task.PrivateData.BillingSource = relayInfo.BillingSource
|
||||
task.PrivateData.SubscriptionId = relayInfo.SubscriptionId
|
||||
task.PrivateData.TokenId = relayInfo.TokenId
|
||||
task.Quota = result.Quota
|
||||
task.Data = result.TaskData
|
||||
task.Action = relayInfo.Action
|
||||
if insertErr := task.Insert(); insertErr != nil {
|
||||
//taskErr = service.TaskErrorWrapper(insertErr, "insert_task_failed", http.StatusInternalServerError)
|
||||
common.SysError("insert task error: " + insertErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if taskErr != nil {
|
||||
respondTaskError(c, taskErr)
|
||||
}
|
||||
}
|
||||
|
||||
// respondTaskError 统一输出 Task 错误响应(含 429 限流提示改写)
|
||||
func respondTaskError(c *gin.Context, taskErr *dto.TaskError) {
|
||||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||||
taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
c.JSON(taskErr.StatusCode, taskErr)
|
||||
}
|
||||
|
||||
// taskSubmitWithRetry 执行首次尝试并在失败时切换渠道重试,返回最终的 taskErr。
|
||||
// attempt 闭包负责实际的上游请求,不涉及计费。
|
||||
func taskSubmitWithRetry(c *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
channelId int, retryTimes int, attempt func() *dto.TaskError) *dto.TaskError {
|
||||
|
||||
taskErr := attempt()
|
||||
if taskErr == nil {
|
||||
return nil
|
||||
}
|
||||
if !taskErr.LocalError {
|
||||
processChannelError(c,
|
||||
*types.NewChannelError(channelId, c.GetInt("channel_type"), c.GetString("channel_name"), common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
|
||||
common.GetContextKeyString(c, constant.ContextKeyChannelKey), common.GetContextKeyBool(c, constant.ContextKeyChannelAutoBan)),
|
||||
types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode))
|
||||
}
|
||||
|
||||
retryParam := &service.RetryParam{
|
||||
Ctx: c,
|
||||
TokenGroup: relayInfo.TokenGroup,
|
||||
@@ -480,7 +565,7 @@ func RelayTask(c *gin.Context) {
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
c.Set("use_channel", useChannel)
|
||||
logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
middleware.SetupContextForSelectedChannel(c, channel, c.GetString("original_model"))
|
||||
|
||||
bodyStorage, err := common.GetBodyStorage(c)
|
||||
if err != nil {
|
||||
@@ -492,30 +577,21 @@ func RelayTask(c *gin.Context) {
|
||||
break
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bodyStorage)
|
||||
taskErr = taskRelayHandler(c, relayInfo)
|
||||
taskErr = attempt()
|
||||
if taskErr != nil && !taskErr.LocalError {
|
||||
processChannelError(c,
|
||||
*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey,
|
||||
common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()),
|
||||
types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode))
|
||||
}
|
||||
}
|
||||
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
logger.LogInfo(c, retryLogStr)
|
||||
}
|
||||
if taskErr != nil {
|
||||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||||
taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
c.JSON(taskErr.StatusCode, taskErr)
|
||||
}
|
||||
}
|
||||
|
||||
func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError {
|
||||
var err *dto.TaskError
|
||||
switch relayInfo.RelayMode {
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
|
||||
err = relay.RelayTaskFetch(c, relayInfo.RelayMode)
|
||||
default:
|
||||
err = relay.RelayTaskSubmit(c, relayInfo)
|
||||
}
|
||||
return err
|
||||
return taskErr
|
||||
}
|
||||
|
||||
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
|
||||
|
||||
@@ -1,231 +1,21 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
// UpdateTaskBulk 薄入口,实际轮询逻辑在 service 层
|
||||
func UpdateTaskBulk() {
|
||||
//revocer
|
||||
//imageModel := "midjourney"
|
||||
for {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
common.SysLog("任务进度轮询开始")
|
||||
ctx := context.TODO()
|
||||
allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
|
||||
platformTask := make(map[constant.TaskPlatform][]*model.Task)
|
||||
for _, t := range allTasks {
|
||||
platformTask[t.Platform] = append(platformTask[t.Platform], t)
|
||||
}
|
||||
for platform, tasks := range platformTask {
|
||||
if len(tasks) == 0 {
|
||||
continue
|
||||
}
|
||||
taskChannelM := make(map[int][]string)
|
||||
taskM := make(map[string]*model.Task)
|
||||
nullTaskIds := make([]int64, 0)
|
||||
for _, task := range tasks {
|
||||
if task.TaskID == "" {
|
||||
// 统计失败的未完成任务
|
||||
nullTaskIds = append(nullTaskIds, task.ID)
|
||||
continue
|
||||
}
|
||||
taskM[task.TaskID] = task
|
||||
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
|
||||
}
|
||||
if len(nullTaskIds) > 0 {
|
||||
err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
||||
} else {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
||||
}
|
||||
}
|
||||
if len(taskChannelM) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
UpdateTaskByPlatform(platform, taskChannelM, taskM)
|
||||
}
|
||||
common.SysLog("任务进度轮询完成")
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
|
||||
switch platform {
|
||||
case constant.TaskPlatformMidjourney:
|
||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||
case constant.TaskPlatformSuno:
|
||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
default:
|
||||
if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
channel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
|
||||
err = model.TaskBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
|
||||
if adaptor == nil {
|
||||
return errors.New("adaptor not found")
|
||||
}
|
||||
proxy := channel.GetSetting().Proxy
|
||||
resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
|
||||
"ids": taskIds,
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
return err
|
||||
}
|
||||
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
||||
err = json.Unmarshal(responseBody, &responseItems)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
return err
|
||||
}
|
||||
if !responseItems.IsSuccess() {
|
||||
common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
|
||||
return err
|
||||
}
|
||||
|
||||
for _, responseItem := range responseItems.Data {
|
||||
task := taskM[responseItem.TaskID]
|
||||
if !checkTaskNeedUpdate(task, responseItem) {
|
||||
continue
|
||||
}
|
||||
|
||||
task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
|
||||
task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
|
||||
task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
|
||||
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
|
||||
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
|
||||
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
||||
logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
//err = model.CacheUpdateUserQuota(task.UserId) ?
|
||||
if err != nil {
|
||||
logger.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
} else {
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
err = model.IncreaseUserQuota(task.UserId, quota, false)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
if responseItem.Status == model.TaskStatusSuccess {
|
||||
task.Progress = "100%"
|
||||
}
|
||||
task.Data = responseItem.Data
|
||||
|
||||
err = task.Update()
|
||||
if err != nil {
|
||||
common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
|
||||
|
||||
if oldTask.SubmitTime != newTask.SubmitTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.StartTime != newTask.StartTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
if string(oldTask.Status) != newTask.Status {
|
||||
return true
|
||||
}
|
||||
if oldTask.FailReason != newTask.FailReason {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
|
||||
if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
|
||||
return true
|
||||
}
|
||||
|
||||
oldData, _ := json.Marshal(oldTask.Data)
|
||||
newData, _ := json.Marshal(newTask.Data)
|
||||
|
||||
sort.Slice(oldData, func(i, j int) bool {
|
||||
return oldData[i] < oldData[j]
|
||||
})
|
||||
sort.Slice(newData, func(i, j int) bool {
|
||||
return newData[i] < newData[j]
|
||||
})
|
||||
|
||||
if string(oldData) != string(newData) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
service.TaskPollingLoop()
|
||||
}
|
||||
|
||||
func GetAllTask(c *gin.Context) {
|
||||
@@ -247,7 +37,7 @@ func GetAllTask(c *gin.Context) {
|
||||
items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.TaskCountAllTasks(queryParams)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
pageInfo.SetItems(tasksToDto(items))
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
@@ -271,6 +61,14 @@ func GetUserTask(c *gin.Context) {
|
||||
items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.TaskCountAllUserTask(userId, queryParams)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
pageInfo.SetItems(tasksToDto(items))
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
func tasksToDto(tasks []*model.Task) []*dto.TaskDto {
|
||||
result := make([]*dto.TaskDto, len(tasks))
|
||||
for i, task := range tasks {
|
||||
result[i] = relay.TaskModel2Dto(task)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -1,313 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
)
|
||||
|
||||
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if errUpdate != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||
}
|
||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||
}
|
||||
adaptor := relay.GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("video adaptor not found")
|
||||
}
|
||||
info := &relaycommon.RelayInfo{}
|
||||
info.ChannelMeta = &relaycommon.ChannelMeta{
|
||||
ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
|
||||
}
|
||||
info.ApiKey = cacheGetChannel.Key
|
||||
adaptor.Init(info)
|
||||
for _, taskId := range taskIds {
|
||||
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
proxy := channel.GetSetting().Proxy
|
||||
|
||||
task := taskM[taskId]
|
||||
if task == nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||
return fmt.Errorf("task %s not found", taskId)
|
||||
}
|
||||
key := channel.Key
|
||||
|
||||
privateData := task.PrivateData
|
||||
if privateData.Key != "" {
|
||||
key = privateData.Key
|
||||
}
|
||||
resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
|
||||
"task_id": taskId,
|
||||
"action": task.Action,
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
//if resp.StatusCode != http.StatusOK {
|
||||
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
|
||||
//}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
|
||||
|
||||
taskResult := &relaycommon.TaskInfo{}
|
||||
// try parse as New API response format
|
||||
var responseItems dto.TaskResponse[model.Task]
|
||||
if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
|
||||
t := responseItems.Data
|
||||
taskResult.TaskID = t.TaskID
|
||||
taskResult.Status = string(t.Status)
|
||||
taskResult.Url = t.FailReason
|
||||
taskResult.Progress = t.Progress
|
||||
taskResult.Reason = t.FailReason
|
||||
task.Data = t.Data
|
||||
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
} else {
|
||||
task.Data = redactVideoResponseBody(responseBody)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
|
||||
|
||||
now := time.Now().Unix()
|
||||
if taskResult.Status == "" {
|
||||
//return fmt.Errorf("task %s status is empty", taskId)
|
||||
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
|
||||
}
|
||||
|
||||
// 记录原本的状态,防止重复退款
|
||||
shouldRefund := false
|
||||
quota := task.Quota
|
||||
preStatus := task.Status
|
||||
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
case model.TaskStatusSubmitted:
|
||||
task.Progress = "10%"
|
||||
case model.TaskStatusQueued:
|
||||
task.Progress = "20%"
|
||||
case model.TaskStatusInProgress:
|
||||
task.Progress = "30%"
|
||||
if task.StartTime == 0 {
|
||||
task.StartTime = now
|
||||
}
|
||||
case model.TaskStatusSuccess:
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
|
||||
task.FailReason = taskResult.Url
|
||||
}
|
||||
|
||||
// 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
|
||||
if taskResult.TotalTokens > 0 {
|
||||
// 获取模型名称
|
||||
var taskData map[string]interface{}
|
||||
if err := json.Unmarshal(task.Data, &taskData); err == nil {
|
||||
if modelName, ok := taskData["model"].(string); ok && modelName != "" {
|
||||
// 获取模型价格和倍率
|
||||
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
|
||||
// 只有配置了倍率(非固定价格)时才按 token 重新计费
|
||||
if hasRatioSetting && modelRatio > 0 {
|
||||
// 获取用户和组的倍率信息
|
||||
group := task.Group
|
||||
if group == "" {
|
||||
user, err := model.GetUserById(task.UserId, false)
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
}
|
||||
}
|
||||
if group != "" {
|
||||
groupRatio := ratio_setting.GetGroupRatio(group)
|
||||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
|
||||
|
||||
var finalGroupRatio float64
|
||||
if hasUserGroupRatio {
|
||||
finalGroupRatio = userGroupRatio
|
||||
} else {
|
||||
finalGroupRatio = groupRatio
|
||||
}
|
||||
|
||||
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
|
||||
actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
|
||||
|
||||
// 计算差额
|
||||
preConsumedQuota := task.Quota
|
||||
quotaDelta := actualQuota - preConsumedQuota
|
||||
|
||||
if quotaDelta > 0 {
|
||||
// 需要补扣费
|
||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
|
||||
task.TaskID,
|
||||
logger.LogQuota(quotaDelta),
|
||||
logger.LogQuota(actualQuota),
|
||||
logger.LogQuota(preConsumedQuota),
|
||||
taskResult.TotalTokens,
|
||||
))
|
||||
if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
|
||||
} else {
|
||||
model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
|
||||
model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
|
||||
task.Quota = actualQuota // 更新任务记录的实际扣费额度
|
||||
|
||||
// 记录消费日志
|
||||
logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s",
|
||||
modelRatio, finalGroupRatio, taskResult.TotalTokens,
|
||||
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
} else if quotaDelta < 0 {
|
||||
// 需要退还多扣的费用
|
||||
refundQuota := -quotaDelta
|
||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
|
||||
task.TaskID,
|
||||
logger.LogQuota(refundQuota),
|
||||
logger.LogQuota(actualQuota),
|
||||
logger.LogQuota(preConsumedQuota),
|
||||
taskResult.TotalTokens,
|
||||
))
|
||||
if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
|
||||
} else {
|
||||
task.Quota = actualQuota // 更新任务记录的实际扣费额度
|
||||
|
||||
// 记录退款日志
|
||||
logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s",
|
||||
modelRatio, finalGroupRatio, taskResult.TotalTokens,
|
||||
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
} else {
|
||||
// quotaDelta == 0, 预扣费刚好准确
|
||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
|
||||
task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case model.TaskStatusFailure:
|
||||
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
taskResult.Progress = "100%"
|
||||
if quota != 0 {
|
||||
if preStatus != model.TaskStatusFailure {
|
||||
shouldRefund = true
|
||||
} else {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
|
||||
}
|
||||
if taskResult.Progress != "" {
|
||||
task.Progress = taskResult.Progress
|
||||
}
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysLog("UpdateVideoTask task error: " + err.Error())
|
||||
shouldRefund = false
|
||||
}
|
||||
|
||||
if shouldRefund {
|
||||
// 任务失败且之前状态不是失败才退还额度,防止重复退还
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func redactVideoResponseBody(body []byte) []byte {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return body
|
||||
}
|
||||
resp, _ := m["response"].(map[string]any)
|
||||
if resp != nil {
|
||||
delete(resp, "bytesBase64Encoded")
|
||||
if v, ok := resp["video"].(string); ok {
|
||||
resp["video"] = truncateBase64(v)
|
||||
}
|
||||
if vs, ok := resp["videos"].([]any); ok {
|
||||
for i := range vs {
|
||||
if vm, ok := vs[i].(map[string]any); ok {
|
||||
delete(vm, "bytesBase64Encoded")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func truncateBase64(s string) string {
|
||||
const maxKeep = 256
|
||||
if len(s) <= maxKeep {
|
||||
return s
|
||||
}
|
||||
return s[:maxKeep] + "..."
|
||||
}
|
||||
@@ -16,59 +16,44 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// videoProxyError returns a standardized OpenAI-style error response.
|
||||
func videoProxyError(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"message": message,
|
||||
"type": errType,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func VideoProxy(c *gin.Context) {
|
||||
taskID := c.Param("task_id")
|
||||
if taskID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "task_id is required",
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusBadRequest, "invalid_request_error", "task_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
task, exists, err := model.GetByOnlyTaskId(taskID)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to query task",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to query task")
|
||||
return
|
||||
}
|
||||
if !exists || task == nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err))
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Task not found",
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusNotFound, "invalid_request_error", "Task not found")
|
||||
return
|
||||
}
|
||||
|
||||
if task.Status != model.TaskStatusSuccess {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status),
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusBadRequest, "invalid_request_error",
|
||||
fmt.Sprintf("Task is not completed yet, current status: %s", task.Status))
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := model.CacheGetChannel(task.ChannelId)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to retrieve channel information",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel for task %s: %s", taskID, err.Error()))
|
||||
videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to retrieve channel information")
|
||||
return
|
||||
}
|
||||
baseURL := channel.GetBaseURL()
|
||||
@@ -81,12 +66,7 @@ func VideoProxy(c *gin.Context) {
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to create proxy client",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy client")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -95,12 +75,7 @@ func VideoProxy(c *gin.Context) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to create proxy request",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -109,68 +84,43 @@ func VideoProxy(c *gin.Context) {
|
||||
apiKey := task.PrivateData.Key
|
||||
if apiKey == "" {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "API key not stored for task",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusInternalServerError, "server_error", "API key not stored for task")
|
||||
return
|
||||
}
|
||||
|
||||
videoURL, err = getGeminiVideoURL(channel, task, apiKey)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error()))
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to resolve Gemini video URL",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Gemini video URL")
|
||||
return
|
||||
}
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
case constant.ChannelTypeOpenAI, constant.ChannelTypeSora:
|
||||
videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
|
||||
videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID())
|
||||
req.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
default:
|
||||
// Video URL is directly in task.FailReason
|
||||
videoURL = task.FailReason
|
||||
// Video URL is stored in PrivateData.ResultURL (fallback to FailReason for old data)
|
||||
videoURL = task.GetResultURL()
|
||||
}
|
||||
|
||||
req.URL, err = url.Parse(videoURL)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to create proxy request",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to fetch video content",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode),
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusBadGateway, "server_error",
|
||||
fmt.Sprintf("Upstream service returned status %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -180,10 +130,9 @@ func VideoProxy(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours
|
||||
c.Writer.Header().Set("Cache-Control", "public, max-age=86400")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
if _, err = io.Copy(c.Writer, resp.Body); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay"
|
||||
@@ -37,7 +37,7 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string)
|
||||
|
||||
proxy := channel.GetSetting().Proxy
|
||||
resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{
|
||||
"task_id": task.TaskID,
|
||||
"task_id": task.GetUpstreamTaskID(),
|
||||
"action": task.Action,
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
@@ -71,7 +71,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string {
|
||||
return ""
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(task.Data, &payload); err != nil {
|
||||
if err := common.Unmarshal(task.Data, &payload); err != nil {
|
||||
return ""
|
||||
}
|
||||
return extractGeminiVideoURLFromMap(payload)
|
||||
@@ -79,7 +79,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string {
|
||||
|
||||
func extractGeminiVideoURLFromPayload(body []byte) string {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return ""
|
||||
}
|
||||
return extractGeminiVideoURLFromMap(payload)
|
||||
|
||||
Reference in New Issue
Block a user