diff --git a/common/init.go b/common/init.go
index 6d2c3572b..e4ddbb453 100644
--- a/common/init.go
+++ b/common/init.go
@@ -145,6 +145,8 @@ func initConstantEnv() {
 	constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
 	// 任务轮询时查询的最大数量
 	constant.TaskQueryLimit = GetEnvOrDefault("TASK_QUERY_LIMIT", 1000)
+	// Async task timeout in minutes: unfinished tasks older than this are marked failed and refunded. 0 disables it.
+	constant.TaskTimeoutMinutes = GetEnvOrDefault("TASK_TIMEOUT_MINUTES", 1440)
 
 	soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
 	if soraPatchStr != "" {
diff --git a/constant/env.go b/constant/env.go
index 957f68669..d5aff1b0b 100644
--- a/constant/env.go
+++ b/constant/env.go
@@ -16,6 +16,7 @@ var NotificationLimitDurationMinute int
 var GenerateDefaultToken bool
 var ErrorLogEnabled bool
 var TaskQueryLimit int
+var TaskTimeoutMinutes int
 
 // temporary variable for sora patch, will be removed in future
 var TaskPricePatches []string
diff --git a/model/task.go b/model/task.go
index da3be34ed..984445083 100644
--- a/model/task.go
+++ b/model/task.go
@@ -288,6 +288,24 @@ func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*
 	return tasks
 }
 
+// GetTimedOutUnfinishedTasks returns up to limit tasks submitted before cutoffUnix
+// (compared against submit_time) whose progress is not "100%" and whose status is
+// neither failure nor success, ordered by submit_time. It returns nil if the query
+// fails; the error itself is not reported to the caller.
+func GetTimedOutUnfinishedTasks(cutoffUnix int64, limit int) []*Task {
+	var tasks []*Task
+	err := DB.Where("progress != ?", "100%").
+		Where("status NOT IN ?", []string{TaskStatusFailure, TaskStatusSuccess}).
+		Where("submit_time < ?", cutoffUnix).
+		Order("submit_time").
+		Limit(limit).
+		Find(&tasks).Error
+	if err != nil {
+		return nil
+	}
+	return tasks
+}
+
 func GetAllUnFinishSyncTasks(limit int) []*Task {
 	var tasks []*Task
 	var err error
@@ -401,6 +419,11 @@ func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) {
 	return result.RowsAffected > 0, nil
 }
 
+// TaskBulkUpdateByID performs an unconditional bulk UPDATE by primary key IDs.
+// WARNING: This function has NO CAS (Compare-And-Swap) guard — it will overwrite
+// any concurrent status changes. DO NOT use in billing/quota lifecycle flows
+// (e.g., timeout, success, failure transitions that trigger refunds or settlements).
+// For status transitions that involve billing, use Task.UpdateWithStatus() instead.
 func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
 	if len(ids) == 0 {
 		return nil
diff --git a/service/task_polling.go b/service/task_polling.go
index a03fc9b88..9ac4deddc 100644
--- a/service/task_polling.go
+++ b/service/task_polling.go
@@ -35,12 +35,69 @@ type TaskPollingAdaptor interface {
 // 打破 service -> relay -> relay/channel -> service 的循环依赖。
 var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor
 
+// sweepTimedOutTasks fails tasks older than constant.TaskTimeoutMinutes; it runs on its
+// own, before the main polling pass. At most 100 tasks are handled per call; the rest
+// are picked up on the next cycle. Each task is written with a per-task CAS
+// (UpdateWithStatus) so a task already advanced by regular polling is not overwritten.
+func sweepTimedOutTasks(ctx context.Context) {
+	if constant.TaskTimeoutMinutes <= 0 {
+		return
+	}
+	cutoff := time.Now().Unix() - int64(constant.TaskTimeoutMinutes)*60
+	tasks := model.GetTimedOutUnfinishedTasks(cutoff, 100)
+	if len(tasks) == 0 {
+		return
+	}
+
+	// NOTE(review): 1740182400 is 2025-02-22 00:00:00 UTC, not 2026-02-22 (that would be
+	// 1771718400). Confirm which year is the intended legacy cutoff; it decides refunds.
+	const legacyTaskCutoff int64 = 1740182400 // 2025-02-22 00:00:00 UTC
+	reason := fmt.Sprintf("任务超时(%d分钟)", constant.TaskTimeoutMinutes)
+	legacyReason := "任务超时(旧系统遗留任务,不进行退款,请联系管理员)"
+	now := time.Now().Unix()
+	timedOutCount := 0
+
+	for _, task := range tasks {
+		isLegacy := task.SubmitTime > 0 && task.SubmitTime < legacyTaskCutoff
+
+		oldStatus := task.Status
+		task.Status = model.TaskStatusFailure
+		task.Progress = "100%"
+		task.FinishTime = now
+		if isLegacy {
+			task.FailReason = legacyReason
+		} else {
+			task.FailReason = reason
+		}
+
+		won, err := task.UpdateWithStatus(oldStatus)
+		if err != nil {
+			logger.LogError(ctx, fmt.Sprintf("sweepTimedOutTasks CAS update error for task %s: %v", task.TaskID, err))
+			continue
+		}
+		if !won {
+			logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: task %s already transitioned, skip", task.TaskID))
+			continue
+		}
+		timedOutCount++
+		// Refund only after the CAS update succeeded, and never for legacy tasks.
+		if !isLegacy && task.Quota != 0 {
+			RefundTaskQuota(ctx, task, reason)
+		}
+	}
+
+	if timedOutCount > 0 {
+		logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: timed out %d tasks", timedOutCount))
+	}
+}
+
 // TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务
 func TaskPollingLoop() {
 	for {
 		time.Sleep(time.Duration(15) * time.Second)
 		common.SysLog("任务进度轮询开始")
 		ctx := context.TODO()
+		sweepTimedOutTasks(ctx)
 		allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
 		platformTask := make(map[constant.TaskPlatform][]*model.Task)
 		for _, t := range allTasks {