mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:05:21 +00:00
refactor(task): add CAS-guarded updates to prevent concurrent billing conflicts
Replace all bare task.Update() (DB.Save) calls with UpdateWithStatus(), which adds a WHERE status = ? guard to prevent concurrent processes from overwriting each other's state transitions. Key changes: model/task.go: - Add taskSnapshot struct with Equal() method for change detection - Add Snapshot() method to capture pre-update state - Add UpdateWithStatus(fromStatus) using DB.Where().Save() for CAS semantics with full-struct save (no explicit field listing needed) model/midjourney.go: - Add UpdateWithStatus(fromStatus string) with same CAS pattern service/task_polling.go (updateVideoSingleTask): - Snapshot before processing upstream response; skip DB write if unchanged - Terminal transitions (SUCCESS/FAILURE) use UpdateWithStatus CAS: billing/refund only executes if this process wins the transition - Non-terminal updates also use UpdateWithStatus to prevent overwriting a concurrent terminal transition back to IN_PROGRESS - Defer settleTaskBillingOnComplete to after CAS check (shouldSettle flag) relay/relay_task.go (tryRealtimeFetch): - Add snapshot + change detection; use UpdateWithStatus for CAS safety controller/midjourney.go (UpdateMidjourneyTaskBulk): - Capture preStatus before mutations; use UpdateWithStatus CAS - Gate refund (IncreaseUserQuota) on CAS success (won && shouldReturnQuota) This prevents the multi-instance race condition where: 1. Instance A reads task (IN_PROGRESS), fetches upstream (still IN_PROGRESS) 2. Instance B reads same task, fetches upstream (now SUCCESS), writes SUCCESS 3. Instance A's bare Save() overwrites SUCCESS back to IN_PROGRESS
This commit is contained in:
@@ -130,6 +130,7 @@ func UpdateMidjourneyTaskBulk() {
|
||||
if !checkMjTaskNeedUpdate(task, responseItem) {
|
||||
continue
|
||||
}
|
||||
preStatus := task.Status
|
||||
task.Code = 1
|
||||
task.Progress = responseItem.Progress
|
||||
task.PromptEn = responseItem.PromptEn
|
||||
@@ -172,18 +173,16 @@ func UpdateMidjourneyTaskBulk() {
|
||||
shouldReturnQuota = true
|
||||
}
|
||||
}
|
||||
err = task.Update()
|
||||
won, err := task.UpdateWithStatus(preStatus)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||
} else {
|
||||
if shouldReturnQuota {
|
||||
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
} else if won && shouldReturnQuota {
|
||||
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user