mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 00:46:42 +00:00
refactor(task): add CAS-guarded updates to prevent concurrent billing conflicts
Replace all bare task.Update() (DB.Save) calls with UpdateWithStatus(), which adds a WHERE status = ? guard to prevent concurrent processes from overwriting each other's state transitions. Key changes: model/task.go: - Add taskSnapshot struct with Equal() method for change detection - Add Snapshot() method to capture pre-update state - Add UpdateWithStatus(fromStatus) using DB.Where().Save() for CAS semantics with full-struct save (no explicit field listing needed) model/midjourney.go: - Add UpdateWithStatus(fromStatus string) with same CAS pattern service/task_polling.go (updateVideoSingleTask): - Snapshot before processing upstream response; skip DB write if unchanged - Terminal transitions (SUCCESS/FAILURE) use UpdateWithStatus CAS: billing/refund only executes if this process wins the transition - Non-terminal updates also use UpdateWithStatus to prevent overwriting a concurrent terminal transition back to IN_PROGRESS - Defer settleTaskBillingOnComplete to after CAS check (shouldSettle flag) relay/relay_task.go (tryRealtimeFetch): - Add snapshot + change detection; use UpdateWithStatus for CAS safety controller/midjourney.go (UpdateMidjourneyTaskBulk): - Capture preStatus before mutations; use UpdateWithStatus CAS - Gate refund (IncreaseUserQuota) on CAS success (won && shouldReturnQuota) This prevents the multi-instance race condition where: 1. Instance A reads task (IN_PROGRESS), fetches upstream (still IN_PROGRESS) 2. Instance B reads same task, fetches upstream (now SUCCESS), writes SUCCESS 3. Instance A's bare Save() overwrites SUCCESS back to IN_PROGRESS
This commit is contained in:
@@ -157,6 +157,17 @@ func (midjourney *Midjourney) Update() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
|
||||
// Returns (true, nil) if this caller won the update, (false, nil) if
|
||||
// another process already moved the task out of fromStatus.
|
||||
func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) {
|
||||
result := DB.Where("status = ?", fromStatus).Save(midjourney)
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
func MjBulkUpdate(mjIds []string, params map[string]any) error {
|
||||
return DB.Model(&Midjourney{}).
|
||||
Where("mj_id in (?)", mjIds).
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"time"
|
||||
@@ -340,38 +341,59 @@ func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
|
||||
return task, nil
|
||||
}
|
||||
|
||||
func TaskUpdateProgress(id int64, progress string) error {
|
||||
return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
|
||||
}
|
||||
|
||||
func (Task *Task) Insert() error {
|
||||
var err error
|
||||
err = DB.Create(Task).Error
|
||||
return err
|
||||
}
|
||||
|
||||
type taskSnapshot struct {
|
||||
Status TaskStatus
|
||||
Progress string
|
||||
StartTime int64
|
||||
FinishTime int64
|
||||
FailReason string
|
||||
ResultURL string
|
||||
Data json.RawMessage
|
||||
}
|
||||
|
||||
func (s taskSnapshot) Equal(other taskSnapshot) bool {
|
||||
return s.Status == other.Status &&
|
||||
s.Progress == other.Progress &&
|
||||
s.StartTime == other.StartTime &&
|
||||
s.FinishTime == other.FinishTime &&
|
||||
s.FailReason == other.FailReason &&
|
||||
s.ResultURL == other.ResultURL &&
|
||||
bytes.Equal(s.Data, other.Data)
|
||||
}
|
||||
|
||||
func (t *Task) Snapshot() taskSnapshot {
|
||||
return taskSnapshot{
|
||||
Status: t.Status,
|
||||
Progress: t.Progress,
|
||||
StartTime: t.StartTime,
|
||||
FinishTime: t.FinishTime,
|
||||
FailReason: t.FailReason,
|
||||
ResultURL: t.PrivateData.ResultURL,
|
||||
Data: t.Data,
|
||||
}
|
||||
}
|
||||
|
||||
func (Task *Task) Update() error {
|
||||
var err error
|
||||
err = DB.Save(Task).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
|
||||
if len(TaskIds) == 0 {
|
||||
return nil
|
||||
// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
|
||||
// Returns (true, nil) if this caller won the update, (false, nil) if
|
||||
// another process already moved the task out of fromStatus.
|
||||
func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) {
|
||||
result := DB.Where("status = ?", fromStatus).Save(t)
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
return DB.Model(&Task{}).
|
||||
Where("task_id in (?)", TaskIds).
|
||||
Updates(params).Error
|
||||
}
|
||||
|
||||
func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
|
||||
if len(taskIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return DB.Model(&Task{}).
|
||||
Where("id in (?)", taskIDs).
|
||||
Updates(params).Error
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
|
||||
@@ -388,37 +410,6 @@ type TaskQuotaUsage struct {
|
||||
Count float64 `json:"count"`
|
||||
}
|
||||
|
||||
func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
|
||||
query := DB.Model(Task{})
|
||||
// 添加过滤条件
|
||||
if queryParams.ChannelID != "" {
|
||||
query = query.Where("channel_id = ?", queryParams.ChannelID)
|
||||
}
|
||||
if queryParams.UserID != "" {
|
||||
query = query.Where("user_id = ?", queryParams.UserID)
|
||||
}
|
||||
if len(queryParams.UserIDs) != 0 {
|
||||
query = query.Where("user_id in (?)", queryParams.UserIDs)
|
||||
}
|
||||
if queryParams.TaskID != "" {
|
||||
query = query.Where("task_id = ?", queryParams.TaskID)
|
||||
}
|
||||
if queryParams.Action != "" {
|
||||
query = query.Where("action = ?", queryParams.Action)
|
||||
}
|
||||
if queryParams.Status != "" {
|
||||
query = query.Where("status = ?", queryParams.Status)
|
||||
}
|
||||
if queryParams.StartTimestamp != 0 {
|
||||
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
||||
}
|
||||
if queryParams.EndTimestamp != 0 {
|
||||
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
||||
}
|
||||
err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
|
||||
return stat, err
|
||||
}
|
||||
|
||||
// TaskCountAllTasks returns total tasks that match the given query params (admin usage)
|
||||
func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
|
||||
var total int64
|
||||
|
||||
Reference in New Issue
Block a user