refactor(relay): improve channel locking and retry logic in RelayTask

- Enhanced the RelayTask function to utilize a locked channel when available, allowing for better reuse during retries.
- Updated error handling to ensure proper context setup for the selected channel.
- Clarified comments in ResolveOriginTask regarding channel locking and retry behavior.
- Introduced a new field in TaskRelayInfo to store the locked channel object, improving type safety and reducing import cycles.
This commit is contained in:
CaIon
2026-02-21 23:47:55 +08:00
parent 76892e8376
commit cda540180b
3 changed files with 36 additions and 18 deletions

View File

@@ -619,6 +619,11 @@ type TaskRelayInfo struct {
PublicTaskID string
ConsumeQuota bool
// LockedChannel holds the full channel object when the request is bound to
// a specific channel (e.g., remix on origin task's channel). Stored as any
// to avoid an import cycle with model; callers type-assert to *model.Channel.
LockedChannel any
}
type TaskSubmitReq struct {

View File

@@ -32,8 +32,9 @@ type TaskSubmitResult struct {
}
// ResolveOriginTask 处理基于已有任务的提交remix / continuation
// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道(并通过
// specific_channel_id 禁止重试),以及提取 OtherRatios时长、分辨率
// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道
// (通过 info.LockedChannel重试时复用同一渠道并轮换 key
// 以及提取 OtherRatios时长、分辨率
// 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。
func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
// 检测 remix action
@@ -77,15 +78,17 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr
}
}
// 锁定到原始任务的渠道(如果与当前选中的不同
// 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key
ch, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
}
if ch.Status != common.ChannelStatusEnabled {
return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
}
info.LockedChannel = ch
if originTask.ChannelId != info.ChannelId {
ch, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
}
if ch.Status != common.ChannelStatusEnabled {
return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
}
key, _, newAPIError := ch.GetNextEnabledKey()
if newAPIError != nil {
return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
@@ -101,9 +104,6 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr
info.ApiKey = key
}
// 渠道已锁定到原始任务 → 禁止重试切换到其他渠道
c.Set("specific_channel_id", fmt.Sprintf("%d", originTask.ChannelId))
// 提取 remix 参数(时长、分辨率 → OtherRatios
if info.Action == constant.TaskActionRemix {
if originTask.PrivateData.BillingContext != nil {