mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:25:00 +00:00
feat(task): introduce task timeout configuration and cleanup unfinished tasks
- Added TaskTimeoutMinutes constant to configure the timeout duration for asynchronous tasks. - Implemented sweepTimedOutTasks function to identify and handle unfinished tasks that exceed the timeout limit, marking them as failed and processing refunds if applicable. - Enhanced task polling loop to include the new timeout handling logic, ensuring timely cleanup of stale tasks.
This commit is contained in:
@@ -145,6 +145,8 @@ func initConstantEnv() {
|
|||||||
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||||
// 任务轮询时查询的最大数量
|
// 任务轮询时查询的最大数量
|
||||||
constant.TaskQueryLimit = GetEnvOrDefault("TASK_QUERY_LIMIT", 1000)
|
constant.TaskQueryLimit = GetEnvOrDefault("TASK_QUERY_LIMIT", 1000)
|
||||||
|
// 异步任务超时时间(分钟),超过此时间未完成的任务将被标记为失败并退款。0 表示禁用。
|
||||||
|
constant.TaskTimeoutMinutes = GetEnvOrDefault("TASK_TIMEOUT_MINUTES", 1440)
|
||||||
|
|
||||||
soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
|
soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
|
||||||
if soraPatchStr != "" {
|
if soraPatchStr != "" {
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ var NotificationLimitDurationMinute int
|
|||||||
var GenerateDefaultToken bool
|
var GenerateDefaultToken bool
|
||||||
var ErrorLogEnabled bool
|
var ErrorLogEnabled bool
|
||||||
var TaskQueryLimit int
|
var TaskQueryLimit int
|
||||||
|
var TaskTimeoutMinutes int
|
||||||
|
|
||||||
// temporary variable for sora patch, will be removed in future
|
// temporary variable for sora patch, will be removed in future
|
||||||
var TaskPricePatches []string
|
var TaskPricePatches []string
|
||||||
|
|||||||
@@ -288,6 +288,20 @@ func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*
|
|||||||
return tasks
|
return tasks
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetTimedOutUnfinishedTasks(cutoffUnix int64, limit int) []*Task {
|
||||||
|
var tasks []*Task
|
||||||
|
err := DB.Where("progress != ?", "100%").
|
||||||
|
Where("status NOT IN ?", []string{TaskStatusFailure, TaskStatusSuccess}).
|
||||||
|
Where("submit_time < ?", cutoffUnix).
|
||||||
|
Order("submit_time").
|
||||||
|
Limit(limit).
|
||||||
|
Find(&tasks).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return tasks
|
||||||
|
}
|
||||||
|
|
||||||
func GetAllUnFinishSyncTasks(limit int) []*Task {
|
func GetAllUnFinishSyncTasks(limit int) []*Task {
|
||||||
var tasks []*Task
|
var tasks []*Task
|
||||||
var err error
|
var err error
|
||||||
@@ -401,6 +415,11 @@ func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) {
|
|||||||
return result.RowsAffected > 0, nil
|
return result.RowsAffected > 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TaskBulkUpdateByID performs an unconditional bulk UPDATE by primary key IDs.
|
||||||
|
// WARNING: This function has NO CAS (Compare-And-Swap) guard — it will overwrite
|
||||||
|
// any concurrent status changes. DO NOT use in billing/quota lifecycle flows
|
||||||
|
// (e.g., timeout, success, failure transitions that trigger refunds or settlements).
|
||||||
|
// For status transitions that involve billing, use Task.UpdateWithStatus() instead.
|
||||||
func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
|
func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
|
||||||
if len(ids) == 0 {
|
if len(ids) == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -35,12 +35,65 @@ type TaskPollingAdaptor interface {
|
|||||||
// 打破 service -> relay -> relay/channel -> service 的循环依赖。
|
// 打破 service -> relay -> relay/channel -> service 的循环依赖。
|
||||||
var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor
|
var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor
|
||||||
|
|
||||||
|
// sweepTimedOutTasks 在主轮询之前独立清理超时任务。
|
||||||
|
// 每次最多处理 100 条,剩余的下个周期继续处理。
|
||||||
|
// 使用 per-task CAS (UpdateWithStatus) 防止覆盖被正常轮询已推进的任务。
|
||||||
|
func sweepTimedOutTasks(ctx context.Context) {
|
||||||
|
if constant.TaskTimeoutMinutes <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cutoff := time.Now().Unix() - int64(constant.TaskTimeoutMinutes)*60
|
||||||
|
tasks := model.GetTimedOutUnfinishedTasks(cutoff, 100)
|
||||||
|
if len(tasks) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const legacyTaskCutoff int64 = 1740182400 // 2026-02-22 00:00:00 UTC
|
||||||
|
reason := fmt.Sprintf("任务超时(%d分钟)", constant.TaskTimeoutMinutes)
|
||||||
|
legacyReason := "任务超时(旧系统遗留任务,不进行退款,请联系管理员)"
|
||||||
|
now := time.Now().Unix()
|
||||||
|
timedOutCount := 0
|
||||||
|
|
||||||
|
for _, task := range tasks {
|
||||||
|
isLegacy := task.SubmitTime > 0 && task.SubmitTime < legacyTaskCutoff
|
||||||
|
|
||||||
|
oldStatus := task.Status
|
||||||
|
task.Status = model.TaskStatusFailure
|
||||||
|
task.Progress = "100%"
|
||||||
|
task.FinishTime = now
|
||||||
|
if isLegacy {
|
||||||
|
task.FailReason = legacyReason
|
||||||
|
} else {
|
||||||
|
task.FailReason = reason
|
||||||
|
}
|
||||||
|
|
||||||
|
won, err := task.UpdateWithStatus(oldStatus)
|
||||||
|
if err != nil {
|
||||||
|
logger.LogError(ctx, fmt.Sprintf("sweepTimedOutTasks CAS update error for task %s: %v", task.TaskID, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !won {
|
||||||
|
logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: task %s already transitioned, skip", task.TaskID))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
timedOutCount++
|
||||||
|
if !isLegacy && task.Quota != 0 {
|
||||||
|
RefundTaskQuota(ctx, task, reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if timedOutCount > 0 {
|
||||||
|
logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: timed out %d tasks", timedOutCount))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务
|
// TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务
|
||||||
func TaskPollingLoop() {
|
func TaskPollingLoop() {
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(15) * time.Second)
|
time.Sleep(time.Duration(15) * time.Second)
|
||||||
common.SysLog("任务进度轮询开始")
|
common.SysLog("任务进度轮询开始")
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
|
sweepTimedOutTasks(ctx)
|
||||||
allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
|
allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
|
||||||
platformTask := make(map[constant.TaskPlatform][]*model.Task)
|
platformTask := make(map[constant.TaskPlatform][]*model.Task)
|
||||||
for _, t := range allTasks {
|
for _, t := range allTasks {
|
||||||
|
|||||||
Reference in New Issue
Block a user