From c51936e068c112ad77526fcd4c0b4fd517eb5435 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 13 Dec 2025 16:43:38 +0800 Subject: [PATCH] refactor(channel_select): enhance retry logic and context key usage for channel selection --- constant/context_key.go | 5 +- controller/playground.go | 9 --- controller/relay.go | 35 ++++++--- middleware/auth.go | 2 +- middleware/distributor.go | 7 +- service/channel_select.go | 147 ++++++++++++++++++++++++++++++-------- service/quota.go | 2 +- 7 files changed, 155 insertions(+), 52 deletions(-) diff --git a/constant/context_key.go b/constant/context_key.go index ecc5178ee..833aabae1 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -21,7 +21,6 @@ const ( ContextKeyTokenCrossGroupRetry ContextKey = "token_cross_group_retry" /* channel related keys */ - ContextKeyAutoGroupIndex ContextKey = "auto_group_index" ContextKeyChannelId ContextKey = "channel_id" ContextKeyChannelName ContextKey = "channel_name" ContextKeyChannelCreateTime ContextKey = "channel_create_time" @@ -39,6 +38,10 @@ const ( ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index" ContextKeyChannelKey ContextKey = "channel_key" + ContextKeyAutoGroup ContextKey = "auto_group" + ContextKeyAutoGroupIndex ContextKey = "auto_group_index" + ContextKeyAutoGroupRetryIndex ContextKey = "auto_group_retry_index" + /* user related keys */ ContextKeyUserId ContextKey = "id" ContextKeyUserSetting ContextKey = "user_setting" diff --git a/controller/playground.go b/controller/playground.go index d9e2ba9a1..501c4e156 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -3,10 +3,7 @@ package controller import ( "errors" "fmt" - "time" - "github.com/QuantumNous/new-api/common" - "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/middleware" "github.com/QuantumNous/new-api/model" relaycommon "github.com/QuantumNous/new-api/relay/common" @@ -54,12 +51,6 @@ func Playground(c *gin.Context) { Group: relayInfo.UsingGroup, } _ = middleware.SetupContextForToken(c, tempToken) - _, newAPIError = getChannel(c, relayInfo, 0) - if newAPIError != nil { - return - } - //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) - common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) Relay(c, types.RelayFormatOpenAI) } diff --git a/controller/relay.go b/controller/relay.go index 2013b9c0f..a0618452c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -157,8 +157,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { } }() - for i := 0; i <= common.RetryTimes; i++ { - channel, err := getChannel(c, relayInfo, i) + retryParam := &service.RetryParam{ + Ctx: c, + TokenGroup: relayInfo.TokenGroup, + ModelName: relayInfo.OriginModelName, + Retry: common.GetPointer(0), + } + + for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() { + channel, err := getChannel(c, relayInfo, retryParam) if err != nil { logger.LogError(c, err.Error()) newAPIError = err @@ -186,7 +193,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) - if !shouldRetry(c, newAPIError, common.RetryTimes-i) { + if !shouldRetry(c, newAPIError, common.RetryTimes-retryParam.GetRetry()) { break } } @@ -211,8 +218,8 @@ func addUsedChannel(c *gin.Context, channelId int) { c.Set("use_channel", useChannel) } -func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryCount int) (*model.Channel, *types.NewAPIError) { - if retryCount == 0 { +func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryParam *service.RetryParam) (*model.Channel, *types.NewAPIError) { + if info.ChannelMeta == nil { autoBan := c.GetBool("auto_ban") autoBanInt := 1 if !autoBan { @@ -225,7 +232,7 @@ func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryCount int) (*m AutoBan: &autoBanInt, }, nil } - channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, info.TokenGroup, info.OriginModelName, retryCount) + channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(retryParam) info.PriceData.GroupRatioInfo = helper.HandleGroupRatio(c, info) @@ -370,7 +377,7 @@ func RelayMidjourney(c *gin.Context) { } func RelayNotImplemented(c *gin.Context) { - err := dto.OpenAIError{ + err := types.OpenAIError{ Message: "API not implemented", Type: "new_api_error", Param: "", @@ -382,7 +389,7 @@ func RelayNotImplemented(c *gin.Context) { } func RelayNotFound(c *gin.Context) { - err := dto.OpenAIError{ + err := types.OpenAIError{ Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), Type: "invalid_request_error", Param: "", @@ -405,8 +412,14 @@ func RelayTask(c *gin.Context) { if taskErr == nil { retryTimes = 0 } - for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { - channel, newAPIError := getChannel(c, relayInfo, i) + retryParam := &service.RetryParam{ + Ctx: c, + TokenGroup: relayInfo.TokenGroup, + ModelName: relayInfo.OriginModelName, + Retry: common.GetPointer(0), + } + for ; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && retryParam.GetRetry() < retryTimes; retryParam.IncreaseRetry() { + channel, newAPIError := getChannel(c, relayInfo, retryParam) if newAPIError != nil { logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError) @@ -416,7 +429,7 @@ func RelayTask(c *gin.Context) { useChannel := c.GetStringSlice("use_channel") useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) c.Set("use_channel", useChannel) - logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) + logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry())) //middleware.SetupContextForSelectedChannel(c, channel, originalModel) requestBody, _ := common.GetRequestBody(c) diff --git a/middleware/auth.go b/middleware/auth.go index cefc4e068..d24120042 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -308,7 +308,7 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e c.Set("token_model_limit_enabled", false) } common.SetContextKey(c, constant.ContextKeyTokenGroup, token.Group) - c.Set("token_cross_group_retry", token.CrossGroupRetry) + common.SetContextKey(c, constant.ContextKeyTokenCrossGroupRetry, token.CrossGroupRetry) if len(parts) > 1 { if model.IsAdmin(token.UserId) { c.Set("specific_channel_id", parts[1]) diff --git a/middleware/distributor.go b/middleware/distributor.go index 3c8529d96..390dc059f 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -97,7 +97,12 @@ func Distribute() func(c *gin.Context) { common.SetContextKey(c, constant.ContextKeyUsingGroup, usingGroup) } } - channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(c, usingGroup, modelRequest.Model, 0) + channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(&service.RetryParam{ + Ctx: c, + ModelName: modelRequest.Model, + TokenGroup: usingGroup, + Retry: common.GetPointer(0), + }) if err != nil { showGroup := usingGroup if usingGroup == "auto" { diff --git a/service/channel_select.go b/service/channel_select.go index ab33bcd19..afaf4f04e 100644 --- a/service/channel_select.go +++ b/service/channel_select.go @@ -11,50 +11,141 @@ import ( "github.com/gin-gonic/gin" ) +type RetryParam struct { + Ctx *gin.Context + TokenGroup string + ModelName string + Retry *int +} + +func (p *RetryParam) GetRetry() int { + if p.Retry == nil { + return 0 + } + return *p.Retry +} + +func (p *RetryParam) SetRetry(retry int) { + p.Retry = &retry +} + +func (p *RetryParam) IncreaseRetry() { + if p.Retry == nil { + p.Retry = new(int) + } + *p.Retry++ +} + // CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements. -func CacheGetRandomSatisfiedChannel(c *gin.Context, tokenGroup string, modelName string, retry int) (*model.Channel, string, error) { +// 尝试获取一个满足要求的随机渠道。 +// +// For "auto" tokenGroup with cross-group Retry enabled: +// 对于启用了跨分组重试的 "auto" tokenGroup: +// +// - Each group will exhaust all its priorities before moving to the next group. +// 每个分组会用完所有优先级后才会切换到下一个分组。 +// +// - Uses ContextKeyAutoGroupIndex to track current group index. +// 使用 ContextKeyAutoGroupIndex 跟踪当前分组索引。 +// +// - Uses ContextKeyAutoGroupRetryIndex to track the global Retry count when current group started. +// 使用 ContextKeyAutoGroupRetryIndex 跟踪当前分组开始时的全局重试次数。 +// +// - priorityRetry = Retry - startRetryIndex, represents the priority level within current group. +// priorityRetry = Retry - startRetryIndex,表示当前分组内的优先级级别。 +// +// - When GetRandomSatisfiedChannel returns nil (priorities exhausted), moves to next group. +// 当 GetRandomSatisfiedChannel 返回 nil(优先级用完)时,切换到下一个分组。 +// +// Example flow (2 groups, each with 2 priorities, RetryTimes=3): +// 示例流程(2个分组,每个有2个优先级,RetryTimes=3): +// +// Retry=0: GroupA, priority0 (startRetryIndex=0, priorityRetry=0) +// 分组A, 优先级0 +// +// Retry=1: GroupA, priority1 (startRetryIndex=0, priorityRetry=1) +// 分组A, 优先级1 +// +// Retry=2: GroupA exhausted → GroupB, priority0 (startRetryIndex=2, priorityRetry=0) +// 分组A用完 → 分组B, 优先级0 +// +// Retry=3: GroupB, priority1 (startRetryIndex=2, priorityRetry=1) +// 分组B, 优先级1 +func CacheGetRandomSatisfiedChannel(param *RetryParam) (*model.Channel, string, error) { var channel *model.Channel var err error - selectGroup := tokenGroup - userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup) - if tokenGroup == "auto" { + selectGroup := param.TokenGroup + userGroup := common.GetContextKeyString(param.Ctx, constant.ContextKeyUserGroup) + + if param.TokenGroup == "auto" { if len(setting.GetAutoGroups()) == 0 { return nil, selectGroup, errors.New("auto groups is not enabled") } autoGroups := GetUserAutoGroup(userGroup) - startIndex := 0 - priorityRetry := retry - crossGroupRetry := common.GetContextKeyBool(c, constant.ContextKeyTokenCrossGroupRetry) - if crossGroupRetry && retry > 0 { - logger.LogDebug(c, "Auto group retry cross group, retry: %d", retry) - if lastIndex, exists := common.GetContextKey(c, constant.ContextKeyAutoGroupIndex); exists { - if idx, ok := lastIndex.(int); ok { - startIndex = idx + 1 - priorityRetry = 0 - } + + // startGroupIndex: the group index to start searching from + // startGroupIndex: 开始搜索的分组索引 + startGroupIndex := 0 + crossGroupRetry := common.GetContextKeyBool(param.Ctx, constant.ContextKeyTokenCrossGroupRetry) + + if lastGroupIndex, exists := common.GetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex); exists { + if idx, ok := lastGroupIndex.(int); ok { + startGroupIndex = idx } - logger.LogDebug(c, "Auto group retry cross group, start index: %d", startIndex) } - for i := startIndex; i < len(autoGroups); i++ { + for i := startGroupIndex; i < len(autoGroups); i++ { autoGroup := autoGroups[i] - logger.LogDebug(c, "Auto selecting group: %s", autoGroup) - channel, _ = model.GetRandomSatisfiedChannel(autoGroup, modelName, priorityRetry) - if channel == nil { + // Calculate priorityRetry for current group + // 计算当前分组的 priorityRetry + priorityRetry := param.GetRetry() + // If moved to a new group, reset priorityRetry and update startRetryIndex + // 如果切换到新分组,重置 priorityRetry 并更新 startRetryIndex + if i > startGroupIndex { priorityRetry = 0 - continue - } else { - c.Set("auto_group", autoGroup) - common.SetContextKey(c, constant.ContextKeyAutoGroupIndex, i) - selectGroup = autoGroup - logger.LogDebug(c, "Auto selected group: %s", autoGroup) - break } + logger.LogDebug(param.Ctx, "Auto selecting group: %s, priorityRetry: %d", autoGroup, priorityRetry) + + channel, _ = model.GetRandomSatisfiedChannel(autoGroup, param.ModelName, priorityRetry) + if channel == nil { + // Current group has no available channel for this model, try next group + // 当前分组没有该模型的可用渠道,尝试下一个分组 + logger.LogDebug(param.Ctx, "No available channel in group %s for model %s at priorityRetry %d, trying next group", autoGroup, param.ModelName, priorityRetry) + // 重置状态以尝试下一个分组 + common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex, i+1) + common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupRetryIndex, 0) + // Reset retry counter so outer loop can continue for next group + // 重置重试计数器,以便外层循环可以为下一个分组继续 + param.SetRetry(0) + continue + } + common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroup, autoGroup) + selectGroup = autoGroup + logger.LogDebug(param.Ctx, "Auto selected group: %s", autoGroup) + + // Prepare state for next retry + // 为下一次重试准备状态 + if crossGroupRetry && priorityRetry >= common.RetryTimes { + // Current group has exhausted all retries, prepare to switch to next group + // This request still uses current group, but next retry will use next group + // 当前分组已用完所有重试次数,准备切换到下一个分组 + // 本次请求仍使用当前分组,但下次重试将使用下一个分组 + logger.LogDebug(param.Ctx, "Current group %s retries exhausted (priorityRetry=%d >= RetryTimes=%d), preparing switch to next group for next retry", autoGroup, priorityRetry, common.RetryTimes) + common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex, i+1) + // Reset retry counter so outer loop can continue for next group + // 重置重试计数器,以便外层循环可以为下一个分组继续 + param.SetRetry(-1) + } else { + // Stay in current group, save current state + // 保持在当前分组,保存当前状态 + common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex, i) + } + break } } else { - channel, err = model.GetRandomSatisfiedChannel(tokenGroup, modelName, retry) + channel, err = model.GetRandomSatisfiedChannel(param.TokenGroup, param.ModelName, param.GetRetry()) if err != nil { - return nil, tokenGroup, err + return nil, param.TokenGroup, err } } return channel, selectGroup, nil diff --git a/service/quota.go b/service/quota.go index 0f41b851b..0da8dafd3 100644 --- a/service/quota.go +++ b/service/quota.go @@ -108,7 +108,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup) modelRatio, _, _ := ratio_setting.GetModelRatio(modelName) - autoGroup, exists := ctx.Get("auto_group") + autoGroup, exists := common.GetContextKey(ctx, constant.ContextKeyAutoGroup) if exists { groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string)) log.Printf("final group ratio: %f", groupRatio)