Files
new-api/controller/midjourney.go
CaIon 5ec4633cb8 refactor(task): add CAS-guarded updates to prevent concurrent billing conflicts
Replace all bare task.Update() (DB.Save) calls with UpdateWithStatus(),
which adds a WHERE status = ? guard to prevent concurrent processes from
overwriting each other's state transitions.

Key changes:

model/task.go:
- Add taskSnapshot struct with Equal() method for change detection
- Add Snapshot() method to capture pre-update state
- Add UpdateWithStatus(fromStatus) using DB.Where().Save() for CAS
  semantics with full-struct save (no explicit field listing needed)

model/midjourney.go:
- Add UpdateWithStatus(fromStatus string) with same CAS pattern

service/task_polling.go (updateVideoSingleTask):
- Snapshot before processing upstream response; skip DB write if unchanged
- Terminal transitions (SUCCESS/FAILURE) use UpdateWithStatus CAS:
  billing/refund only executes if this process wins the transition
- Non-terminal updates also use UpdateWithStatus to prevent overwriting
  a concurrent terminal transition back to IN_PROGRESS
- Defer settleTaskBillingOnComplete to after CAS check (shouldSettle flag)

relay/relay_task.go (tryRealtimeFetch):
- Add snapshot + change detection; use UpdateWithStatus for CAS safety

controller/midjourney.go (UpdateMidjourneyTaskBulk):
- Capture preStatus before mutations; use UpdateWithStatus CAS
- Gate refund (IncreaseUserQuota) on CAS success (won && shouldReturnQuota)

This prevents the multi-instance race condition where:
1. Instance A reads task (IN_PROGRESS), fetches upstream (still IN_PROGRESS)
2. Instance B reads same task, fetches upstream (now SUCCESS), writes SUCCESS
3. Instance A's bare Save() overwrites SUCCESS back to IN_PROGRESS
2026-02-22 16:01:19 +08:00

296 lines
8.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package controller
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
)
func UpdateMidjourneyTaskBulk() {
//imageModel := "midjourney"
ctx := context.TODO()
for {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) == 0 {
continue
}
logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney)
nullTaskIds := make([]int, 0)
for _, task := range tasks {
if task.MjId == "" {
// 统计失败的未完成任务
nullTaskIds = append(nullTaskIds, task.Id)
continue
}
taskM[task.MjId] = task
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
}
if len(nullTaskIds) > 0 {
err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
} else {
logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
}
}
if len(taskChannelM) == 0 {
continue
}
for channelId, taskIds := range taskChannelM {
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
continue
}
midjourneyChannel, err := model.CacheGetChannel(channelId)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
err := model.MjBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
}
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
body, _ := json.Marshal(map[string]any{
"ids": taskIds,
})
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
if err != nil {
logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
// 设置超时时间
timeout := time.Second * 15
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := service.GetHttpClient().Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
if resp.StatusCode != http.StatusOK {
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
continue
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []dto.MidjourneyDto
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
continue
}
resp.Body.Close()
req.Body.Close()
cancel()
for _, responseItem := range responseItems {
task := taskM[responseItem.MjId]
useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
// 如果时间超过一小时且进度不是100%,则认为任务失败
if useTime > 3600000 && task.Progress != "100%" {
responseItem.FailReason = "上游任务超时超过1小时"
responseItem.Status = "FAILURE"
}
if !checkMjTaskNeedUpdate(task, responseItem) {
continue
}
preStatus := task.Status
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
task.State = responseItem.State
task.SubmitTime = responseItem.SubmitTime
task.StartTime = responseItem.StartTime
task.FinishTime = responseItem.FinishTime
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if responseItem.Properties != nil {
propertiesStr, _ := json.Marshal(responseItem.Properties)
task.Properties = string(propertiesStr)
}
if responseItem.Buttons != nil {
buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr)
}
// 映射 VideoUrl
task.VideoUrl = responseItem.VideoUrl
// 映射 VideoUrls - 将数组序列化为 JSON 字符串
if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 {
videoUrlsStr, err := json.Marshal(responseItem.VideoUrls)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err))
task.VideoUrls = "[]" // 失败时设置为空数组
} else {
task.VideoUrls = string(videoUrlsStr)
}
} else {
task.VideoUrls = "" // 空值时清空字段
}
shouldReturnQuota := false
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
if task.Quota != 0 {
shouldReturnQuota = true
}
}
won, err := task.UpdateWithStatus(preStatus)
if err != nil {
logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else if won && shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
if err != nil {
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, logger.LogQuota(task.Quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
}
}
// checkMjTaskNeedUpdate reports whether the stored task (oldTask) differs from
// the upstream snapshot (newTask) and therefore has to be written back.
//
// It returns true when the stored Code is not 1, when any of the compared
// scalar fields differ, when upstream reports a failure reason for a task
// whose stored progress is not "100%", or when the JSON encoding of the
// upstream VideoUrls does not match the stored string.
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
	switch {
	case oldTask.Code != 1,
		oldTask.Progress != newTask.Progress,
		oldTask.PromptEn != newTask.PromptEn,
		oldTask.State != newTask.State,
		oldTask.SubmitTime != newTask.SubmitTime,
		oldTask.StartTime != newTask.StartTime,
		oldTask.FinishTime != newTask.FinishTime,
		oldTask.ImageUrl != newTask.ImageUrl,
		oldTask.Status != newTask.Status,
		oldTask.FailReason != newTask.FailReason,
		oldTask.Progress != "100%" && newTask.FailReason != "",
		// Check whether VideoUrl needs updating.
		oldTask.VideoUrl != newTask.VideoUrl:
		return true
	}

	// Check whether VideoUrls needs updating. With nothing upstream, an
	// update is only required to clear a previously stored value.
	if len(newTask.VideoUrls) == 0 {
		return oldTask.VideoUrls != ""
	}
	encoded, _ := json.Marshal(newTask.VideoUrls)
	return oldTask.VideoUrls != string(encoded)
}
// GetAllMidjourney responds with one page of Midjourney tasks, filtered by the
// optional channel_id, mj_id, start_timestamp and end_timestamp query
// parameters. When setting.MjForwardUrlEnabled is set, each item's ImageUrl is
// replaced by a "/mj/image/<MjId>" link under system_setting.ServerAddress.
func GetAllMidjourney(c *gin.Context) {
	page := common.GetPageQuery(c)

	// Gather the remaining query filters.
	filter := model.TaskQueryParams{
		ChannelID:      c.Query("channel_id"),
		MjID:           c.Query("mj_id"),
		StartTimestamp: c.Query("start_timestamp"),
		EndTimestamp:   c.Query("end_timestamp"),
	}

	records := model.GetAllTasks(page.GetStartIdx(), page.GetPageSize(), filter)
	count := model.CountAllTasks(filter)

	if setting.MjForwardUrlEnabled {
		for idx := range records {
			records[idx].ImageUrl = system_setting.ServerAddress + "/mj/image/" + records[idx].MjId
		}
	}

	page.SetTotal(int(count))
	page.SetItems(records)
	common.ApiSuccess(c, page)
}
// GetUserMidjourney responds with one page of Midjourney tasks for the user id
// stored under "id" in the gin context, filtered by the optional mj_id,
// start_timestamp and end_timestamp query parameters. When
// setting.MjForwardUrlEnabled is set, each item's ImageUrl is replaced by a
// "/mj/image/<MjId>" link under system_setting.ServerAddress.
func GetUserMidjourney(c *gin.Context) {
	page := common.GetPageQuery(c)
	uid := c.GetInt("id")

	filter := model.TaskQueryParams{
		MjID:           c.Query("mj_id"),
		StartTimestamp: c.Query("start_timestamp"),
		EndTimestamp:   c.Query("end_timestamp"),
	}

	records := model.GetAllUserTask(uid, page.GetStartIdx(), page.GetPageSize(), filter)
	count := model.CountAllUserTask(uid, filter)

	if setting.MjForwardUrlEnabled {
		for idx := range records {
			records[idx].ImageUrl = system_setting.ServerAddress + "/mj/image/" + records[idx].MjId
		}
	}

	page.SetTotal(int(count))
	page.SetItems(records)
	common.ApiSuccess(c, page)
}