From 5ec4633cb8ed92c7c863a106fdf9b5cfa389c66c Mon Sep 17 00:00:00 2001 From: CaIon Date: Sun, 22 Feb 2026 00:52:35 +0800 Subject: [PATCH] refactor(task): add CAS-guarded updates to prevent concurrent billing conflicts Replace all bare task.Update() (DB.Save) calls with UpdateWithStatus(), which adds a WHERE status = ? guard to prevent concurrent processes from overwriting each other's state transitions. Key changes: model/task.go: - Add taskSnapshot struct with Equal() method for change detection - Add Snapshot() method to capture pre-update state - Add UpdateWithStatus(fromStatus) using DB.Where().Save() for CAS semantics with full-struct save (no explicit field listing needed) model/midjourney.go: - Add UpdateWithStatus(fromStatus string) with same CAS pattern service/task_polling.go (updateVideoSingleTask): - Snapshot before processing upstream response; skip DB write if unchanged - Terminal transitions (SUCCESS/FAILURE) use UpdateWithStatus CAS: billing/refund only executes if this process wins the transition - Non-terminal updates also use UpdateWithStatus to prevent overwriting a concurrent terminal transition back to IN_PROGRESS - Defer settleTaskBillingOnComplete to after CAS check (shouldSettle flag) relay/relay_task.go (tryRealtimeFetch): - Add snapshot + change detection; use UpdateWithStatus for CAS safety controller/midjourney.go (UpdateMidjourneyTaskBulk): - Capture preStatus before mutations; use UpdateWithStatus CAS - Gate refund (IncreaseUserQuota) on CAS success (won && shouldReturnQuota) This prevents the multi-instance race condition where: 1. Instance A reads task (IN_PROGRESS), fetches upstream (still IN_PROGRESS) 2. Instance B reads same task, fetches upstream (now SUCCESS), writes SUCCESS 3. Instance A's bare Save() overwrites SUCCESS back to IN_PROGRESS --- controller/midjourney.go | 17 ++++---- model/midjourney.go | 11 +++++ model/task.go | 91 ++++++++++++++++++---------------------- relay/relay_task.go | 7 +++- service/task_polling.go | 43 ++++++++++++------- 5 files changed, 95 insertions(+), 74 deletions(-) diff --git a/controller/midjourney.go b/controller/midjourney.go index c480c12bb..4045a5509 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -130,6 +130,7 @@ func UpdateMidjourneyTaskBulk() { if !checkMjTaskNeedUpdate(task, responseItem) { continue } + preStatus := task.Status task.Code = 1 task.Progress = responseItem.Progress task.PromptEn = responseItem.PromptEn @@ -172,18 +173,16 @@ func UpdateMidjourneyTaskBulk() { shouldReturnQuota = true } } - err = task.Update() + won, err := task.UpdateWithStatus(preStatus) if err != nil { logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) - } else { - if shouldReturnQuota { - err = model.IncreaseUserQuota(task.UserId, task.Quota, false) - if err != nil { - logger.LogError(ctx, "fail to increase user quota: "+err.Error()) - } - logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } else if won && shouldReturnQuota { + err = model.IncreaseUserQuota(task.UserId, task.Quota, false) + if err != nil { + logger.LogError(ctx, "fail to increase user quota: "+err.Error()) } + logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } } } diff --git a/model/midjourney.go b/model/midjourney.go index c6ef5de5b..9867e8a96 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -157,6 +157,17 @@ func (midjourney *Midjourney) Update() error { return err } +// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). +// Returns (true, nil) if this caller won the update, (false, nil) if +// another process already moved the task out of fromStatus. +func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) { + result := DB.Where("status = ?", fromStatus).Save(midjourney) + if result.Error != nil { + return false, result.Error + } + return result.RowsAffected > 0, nil +} + func MjBulkUpdate(mjIds []string, params map[string]any) error { return DB.Model(&Midjourney{}). Where("mj_id in (?)", mjIds). diff --git a/model/task.go b/model/task.go index 592643ebb..4d1482f8b 100644 --- a/model/task.go +++ b/model/task.go @@ -1,6 +1,7 @@ package model import ( + "bytes" "database/sql/driver" "encoding/json" "time" @@ -340,38 +341,59 @@ func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) { return task, nil } -func TaskUpdateProgress(id int64, progress string) error { - return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error -} - func (Task *Task) Insert() error { var err error err = DB.Create(Task).Error return err } +type taskSnapshot struct { + Status TaskStatus + Progress string + StartTime int64 + FinishTime int64 + FailReason string + ResultURL string + Data json.RawMessage +} + +func (s taskSnapshot) Equal(other taskSnapshot) bool { + return s.Status == other.Status && + s.Progress == other.Progress && + s.StartTime == other.StartTime && + s.FinishTime == other.FinishTime && + s.FailReason == other.FailReason && + s.ResultURL == other.ResultURL && + bytes.Equal(s.Data, other.Data) +} + +func (t *Task) Snapshot() taskSnapshot { + return taskSnapshot{ + Status: t.Status, + Progress: t.Progress, + StartTime: t.StartTime, + FinishTime: t.FinishTime, + FailReason: t.FailReason, + ResultURL: t.PrivateData.ResultURL, + Data: t.Data, + } +} + func (Task *Task) Update() error { var err error err = DB.Save(Task).Error return err } -func TaskBulkUpdate(TaskIds []string, params map[string]any) error { - if len(TaskIds) == 0 { - return nil +// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). +// Returns (true, nil) if this caller won the update, (false, nil) if +// another process already moved the task out of fromStatus. +func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) { + result := DB.Where("status = ?", fromStatus).Save(t) + if result.Error != nil { + return false, result.Error } - return DB.Model(&Task{}). - Where("task_id in (?)", TaskIds). - Updates(params).Error -} - -func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error { - if len(taskIDs) == 0 { - return nil - } - return DB.Model(&Task{}). - Where("id in (?)", taskIDs). - Updates(params).Error + return result.RowsAffected > 0, nil } func TaskBulkUpdateByID(ids []int64, params map[string]any) error { @@ -388,37 +410,6 @@ type TaskQuotaUsage struct { Count float64 `json:"count"` } -func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) { - query := DB.Model(Task{}) - // 添加过滤条件 - if queryParams.ChannelID != "" { - query = query.Where("channel_id = ?", queryParams.ChannelID) - } - if queryParams.UserID != "" { - query = query.Where("user_id = ?", queryParams.UserID) - } - if len(queryParams.UserIDs) != 0 { - query = query.Where("user_id in (?)", queryParams.UserIDs) - } - if queryParams.TaskID != "" { - query = query.Where("task_id = ?", queryParams.TaskID) - } - if queryParams.Action != "" { - query = query.Where("action = ?", queryParams.Action) - } - if queryParams.Status != "" { - query = query.Where("status = ?", queryParams.Status) - } - if queryParams.StartTimestamp != 0 { - query = query.Where("submit_time >= ?", queryParams.StartTimestamp) - } - if queryParams.EndTimestamp != 0 { - query = query.Where("submit_time <= ?", queryParams.EndTimestamp) - } - err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error - return stat, err -} - // TaskCountAllTasks returns total tasks that match the given query params (admin usage) func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 { var total int64 diff --git a/relay/relay_task.go b/relay/relay_task.go index 8d0e61d72..cd43e6ebb 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -444,6 +444,8 @@ func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte { return nil } + snap := task.Snapshot() + // 将上游最新状态更新到 task if ti.Status != "" { task.Status = model.TaskStatus(ti.Status) @@ -459,7 +461,10 @@ func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte { // No URL from adaptor — construct proxy URL using public task ID task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) } - _ = task.Update() + + if !snap.Equal(task.Snapshot()) { + _, _ = task.UpdateWithStatus(snap.Status) + } // OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理 if isOpenAIVideoAPI { diff --git a/service/task_polling.go b/service/task_polling.go index efbad8afa..7e92d14ba 100644 --- a/service/task_polling.go +++ b/service/task_polling.go @@ -319,6 +319,8 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody))) + snap := task.Snapshot() + taskResult := &relaycommon.TaskInfo{} // try parse as New API response format var responseItems dto.TaskResponse[model.Task] @@ -344,10 +346,9 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * taskResult = relaycommon.FailTaskInfo("upstream returned empty status") } - // 记录原本的状态,防止重复退款 shouldRefund := false + shouldSettle := false quota := task.Quota - preStatus := task.Status task.Status = model.TaskStatus(taskResult.Status) switch taskResult.Status { @@ -374,9 +375,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * // No URL from adaptor — construct proxy URL using public task ID task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) } - - // 完成时计费调整:优先由 adaptor 计算,回退到 token 重算 - settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + shouldSettle = true case model.TaskStatusFailure: logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) task.Status = model.TaskStatusFailure @@ -388,23 +387,39 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) taskResult.Progress = taskcommon.ProgressComplete if quota != 0 { - if preStatus != model.TaskStatusFailure { - shouldRefund = true - } else { - logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID)) - } + shouldRefund = true } default: - return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId) + return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, task.TaskID) } if taskResult.Progress != "" { task.Progress = taskResult.Progress } - if err := task.Update(); err != nil { - common.SysLog("UpdateVideoTask task error: " + err.Error()) - shouldRefund = false + + isDone := task.Status == model.TaskStatusSuccess || task.Status == model.TaskStatusFailure + if isDone && snap.Status != task.Status { + won, err := task.UpdateWithStatus(snap.Status) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("UpdateWithStatus failed for task %s: %s", task.TaskID, err.Error())) + shouldRefund = false + shouldSettle = false + } else if !won { + logger.LogWarn(ctx, fmt.Sprintf("Task %s already transitioned by another process, skip billing", task.TaskID)) + shouldRefund = false + shouldSettle = false + } + } else if !snap.Equal(task.Snapshot()) { + if _, err := task.UpdateWithStatus(snap.Status); err != nil { + logger.LogError(ctx, fmt.Sprintf("Failed to update task %s: %s", task.TaskID, err.Error())) + } + } else { + // No changes, skip update + logger.LogDebug(ctx, fmt.Sprintf("No update needed for task %s", task.TaskID)) } + if shouldSettle { + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + } if shouldRefund { RefundTaskQuota(ctx, task, task.FailReason) }