diff --git a/controller/relay.go b/controller/relay.go index 1477df8f7..6951974c5 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -497,11 +497,24 @@ func RelayTask(c *gin.Context) { } for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() { - channel, channelErr := getChannel(c, relayInfo, retryParam) - if channelErr != nil { - logger.LogError(c, channelErr.Error()) - taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError) - break + var channel *model.Channel + + if lockedCh, ok := relayInfo.LockedChannel.(*model.Channel); ok && lockedCh != nil { + channel = lockedCh + if retryParam.GetRetry() > 0 { + if setupErr := middleware.SetupContextForSelectedChannel(c, channel, relayInfo.OriginModelName); setupErr != nil { + taskErr = service.TaskErrorWrapperLocal(setupErr.Err, "setup_locked_channel_failed", http.StatusInternalServerError) + break + } + } + } else { + var channelErr *types.NewAPIError + channel, channelErr = getChannel(c, relayInfo, retryParam) + if channelErr != nil { + logger.LogError(c, channelErr.Error()) + taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError) + break + } } addUsedChannel(c, channel.Id) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index b68826812..541f1b9f8 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -619,6 +619,11 @@ type TaskRelayInfo struct { PublicTaskID string ConsumeQuota bool + + // LockedChannel holds the full channel object when the request is bound to + // a specific channel (e.g., remix on origin task's channel). Stored as any + // to avoid an import cycle with model; callers type-assert to *model.Channel. + LockedChannel any } type TaskSubmitReq struct { diff --git a/relay/relay_task.go b/relay/relay_task.go index cc4d0e450..8d0e61d72 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -32,8 +32,9 @@ type TaskSubmitResult struct { } // ResolveOriginTask 处理基于已有任务的提交(remix / continuation): -// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道(并通过 -// specific_channel_id 禁止重试),以及提取 OtherRatios(时长、分辨率)。 +// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道 +// (通过 info.LockedChannel,重试时复用同一渠道并轮换 key), +// 以及提取 OtherRatios(时长、分辨率)。 // 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。 func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { // 检测 remix action @@ -77,15 +78,17 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr } } - // 锁定到原始任务的渠道(如果与当前选中的不同) + // 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key) + ch, err := model.GetChannelById(originTask.ChannelId, true) + if err != nil { + return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) + } + if ch.Status != common.ChannelStatusEnabled { + return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) + } + info.LockedChannel = ch + if originTask.ChannelId != info.ChannelId { - ch, err := model.GetChannelById(originTask.ChannelId, true) - if err != nil { - return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) - } - if ch.Status != common.ChannelStatusEnabled { - return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) - } key, _, newAPIError := ch.GetNextEnabledKey() if newAPIError != nil { return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode) @@ -101,9 +104,6 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr info.ApiKey = key } - // 渠道已锁定到原始任务 → 禁止重试切换到其他渠道 - c.Set("specific_channel_id", fmt.Sprintf("%d", originTask.ChannelId)) - // 提取 remix 参数(时长、分辨率 → OtherRatios) if info.Action == constant.TaskActionRemix { if originTask.PrivateData.BillingContext != nil {