refactor(channel_select): enhance retry logic and context key usage for channel selection

This commit is contained in:
CaIon
2025-12-13 16:43:38 +08:00
parent b58fa3debc
commit c51936e068
7 changed files with 155 additions and 52 deletions

View File

@@ -21,7 +21,6 @@ const (
ContextKeyTokenCrossGroupRetry ContextKey = "token_cross_group_retry"
/* channel related keys */
ContextKeyAutoGroupIndex ContextKey = "auto_group_index"
ContextKeyChannelId ContextKey = "channel_id"
ContextKeyChannelName ContextKey = "channel_name"
ContextKeyChannelCreateTime ContextKey = "channel_create_time"
@@ -39,6 +38,10 @@ const (
ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
ContextKeyChannelKey ContextKey = "channel_key"
ContextKeyAutoGroup ContextKey = "auto_group"
ContextKeyAutoGroupIndex ContextKey = "auto_group_index"
ContextKeyAutoGroupRetryIndex ContextKey = "auto_group_retry_index"
/* user related keys */
ContextKeyUserId ContextKey = "id"
ContextKeyUserSetting ContextKey = "user_setting"

View File

@@ -3,10 +3,7 @@ package controller
import (
"errors"
"fmt"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/middleware"
"github.com/QuantumNous/new-api/model"
relaycommon "github.com/QuantumNous/new-api/relay/common"
@@ -54,12 +51,6 @@ func Playground(c *gin.Context) {
Group: relayInfo.UsingGroup,
}
_ = middleware.SetupContextForToken(c, tempToken)
_, newAPIError = getChannel(c, relayInfo, 0)
if newAPIError != nil {
return
}
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
Relay(c, types.RelayFormatOpenAI)
}

View File

@@ -157,8 +157,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
}
}()
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, relayInfo, i)
retryParam := &service.RetryParam{
Ctx: c,
TokenGroup: relayInfo.TokenGroup,
ModelName: relayInfo.OriginModelName,
Retry: common.GetPointer(0),
}
for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
channel, err := getChannel(c, relayInfo, retryParam)
if err != nil {
logger.LogError(c, err.Error())
newAPIError = err
@@ -186,7 +193,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
if !shouldRetry(c, newAPIError, common.RetryTimes-retryParam.GetRetry()) {
break
}
}
@@ -211,8 +218,8 @@ func addUsedChannel(c *gin.Context, channelId int) {
c.Set("use_channel", useChannel)
}
func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryCount int) (*model.Channel, *types.NewAPIError) {
if retryCount == 0 {
func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryParam *service.RetryParam) (*model.Channel, *types.NewAPIError) {
if info.ChannelMeta == nil {
autoBan := c.GetBool("auto_ban")
autoBanInt := 1
if !autoBan {
@@ -225,7 +232,7 @@ func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryCount int) (*m
AutoBan: &autoBanInt,
}, nil
}
channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, info.TokenGroup, info.OriginModelName, retryCount)
channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(retryParam)
info.PriceData.GroupRatioInfo = helper.HandleGroupRatio(c, info)
@@ -370,7 +377,7 @@ func RelayMidjourney(c *gin.Context) {
}
func RelayNotImplemented(c *gin.Context) {
err := dto.OpenAIError{
err := types.OpenAIError{
Message: "API not implemented",
Type: "new_api_error",
Param: "",
@@ -382,7 +389,7 @@ func RelayNotImplemented(c *gin.Context) {
}
func RelayNotFound(c *gin.Context) {
err := dto.OpenAIError{
err := types.OpenAIError{
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error",
Param: "",
@@ -405,8 +412,14 @@ func RelayTask(c *gin.Context) {
if taskErr == nil {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, newAPIError := getChannel(c, relayInfo, i)
retryParam := &service.RetryParam{
Ctx: c,
TokenGroup: relayInfo.TokenGroup,
ModelName: relayInfo.OriginModelName,
Retry: common.GetPointer(0),
}
for ; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && retryParam.GetRetry() < retryTimes; retryParam.IncreaseRetry() {
channel, newAPIError := getChannel(c, relayInfo, retryParam)
if newAPIError != nil {
logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
@@ -416,7 +429,7 @@ func RelayTask(c *gin.Context) {
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, _ := common.GetRequestBody(c)

View File

@@ -308,7 +308,7 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
c.Set("token_model_limit_enabled", false)
}
common.SetContextKey(c, constant.ContextKeyTokenGroup, token.Group)
c.Set("token_cross_group_retry", token.CrossGroupRetry)
common.SetContextKey(c, constant.ContextKeyTokenCrossGroupRetry, token.CrossGroupRetry)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("specific_channel_id", parts[1])

View File

@@ -97,7 +97,12 @@ func Distribute() func(c *gin.Context) {
common.SetContextKey(c, constant.ContextKeyUsingGroup, usingGroup)
}
}
channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(c, usingGroup, modelRequest.Model, 0)
channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(&service.RetryParam{
Ctx: c,
ModelName: modelRequest.Model,
TokenGroup: usingGroup,
Retry: common.GetPointer(0),
})
if err != nil {
showGroup := usingGroup
if usingGroup == "auto" {

View File

@@ -11,50 +11,141 @@ import (
"github.com/gin-gonic/gin"
)
type RetryParam struct {
Ctx *gin.Context
TokenGroup string
ModelName string
Retry *int
}
func (p *RetryParam) GetRetry() int {
if p.Retry == nil {
return 0
}
return *p.Retry
}
func (p *RetryParam) SetRetry(retry int) {
p.Retry = &retry
}
func (p *RetryParam) IncreaseRetry() {
if p.Retry == nil {
p.Retry = new(int)
}
*p.Retry++
}
// CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements.
func CacheGetRandomSatisfiedChannel(c *gin.Context, tokenGroup string, modelName string, retry int) (*model.Channel, string, error) {
// 尝试获取一个满足要求的随机渠道。
//
// For "auto" tokenGroup with cross-group Retry enabled:
// 对于启用了跨分组重试的 "auto" tokenGroup
//
// - Each group will exhaust all its priorities before moving to the next group.
// 每个分组会用完所有优先级后才会切换到下一个分组。
//
// - Uses ContextKeyAutoGroupIndex to track current group index.
// 使用 ContextKeyAutoGroupIndex 跟踪当前分组索引。
//
// - Uses ContextKeyAutoGroupRetryIndex to track the global Retry count when current group started.
// 使用 ContextKeyAutoGroupRetryIndex 跟踪当前分组开始时的全局重试次数。
//
// - priorityRetry = Retry - startRetryIndex, represents the priority level within current group.
// priorityRetry = Retry - startRetryIndex表示当前分组内的优先级级别。
//
// - When GetRandomSatisfiedChannel returns nil (priorities exhausted), moves to next group.
// 当 GetRandomSatisfiedChannel 返回 nil优先级用完切换到下一个分组。
//
// Example flow (2 groups, each with 2 priorities, RetryTimes=3):
// 示例流程2个分组每个有2个优先级RetryTimes=3
//
// Retry=0: GroupA, priority0 (startRetryIndex=0, priorityRetry=0)
// 分组A, 优先级0
//
// Retry=1: GroupA, priority1 (startRetryIndex=0, priorityRetry=1)
// 分组A, 优先级1
//
// Retry=2: GroupA exhausted → GroupB, priority0 (startRetryIndex=2, priorityRetry=0)
// 分组A用完 → 分组B, 优先级0
//
// Retry=3: GroupB, priority1 (startRetryIndex=2, priorityRetry=1)
// 分组B, 优先级1
func CacheGetRandomSatisfiedChannel(param *RetryParam) (*model.Channel, string, error) {
var channel *model.Channel
var err error
selectGroup := tokenGroup
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
if tokenGroup == "auto" {
selectGroup := param.TokenGroup
userGroup := common.GetContextKeyString(param.Ctx, constant.ContextKeyUserGroup)
if param.TokenGroup == "auto" {
if len(setting.GetAutoGroups()) == 0 {
return nil, selectGroup, errors.New("auto groups is not enabled")
}
autoGroups := GetUserAutoGroup(userGroup)
startIndex := 0
priorityRetry := retry
crossGroupRetry := common.GetContextKeyBool(c, constant.ContextKeyTokenCrossGroupRetry)
if crossGroupRetry && retry > 0 {
logger.LogDebug(c, "Auto group retry cross group, retry: %d", retry)
if lastIndex, exists := common.GetContextKey(c, constant.ContextKeyAutoGroupIndex); exists {
if idx, ok := lastIndex.(int); ok {
startIndex = idx + 1
priorityRetry = 0
}
// startGroupIndex: the group index to start searching from
// startGroupIndex: 开始搜索的分组索引
startGroupIndex := 0
crossGroupRetry := common.GetContextKeyBool(param.Ctx, constant.ContextKeyTokenCrossGroupRetry)
if lastGroupIndex, exists := common.GetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex); exists {
if idx, ok := lastGroupIndex.(int); ok {
startGroupIndex = idx
}
logger.LogDebug(c, "Auto group retry cross group, start index: %d", startIndex)
}
for i := startIndex; i < len(autoGroups); i++ {
for i := startGroupIndex; i < len(autoGroups); i++ {
autoGroup := autoGroups[i]
logger.LogDebug(c, "Auto selecting group: %s", autoGroup)
channel, _ = model.GetRandomSatisfiedChannel(autoGroup, modelName, priorityRetry)
if channel == nil {
// Calculate priorityRetry for current group
// 计算当前分组的 priorityRetry
priorityRetry := param.GetRetry()
// If moved to a new group, reset priorityRetry and update startRetryIndex
// 如果切换到新分组,重置 priorityRetry 并更新 startRetryIndex
if i > startGroupIndex {
priorityRetry = 0
continue
} else {
c.Set("auto_group", autoGroup)
common.SetContextKey(c, constant.ContextKeyAutoGroupIndex, i)
selectGroup = autoGroup
logger.LogDebug(c, "Auto selected group: %s", autoGroup)
break
}
logger.LogDebug(param.Ctx, "Auto selecting group: %s, priorityRetry: %d", autoGroup, priorityRetry)
channel, _ = model.GetRandomSatisfiedChannel(autoGroup, param.ModelName, priorityRetry)
if channel == nil {
// Current group has no available channel for this model, try next group
// 当前分组没有该模型的可用渠道,尝试下一个分组
logger.LogDebug(param.Ctx, "No available channel in group %s for model %s at priorityRetry %d, trying next group", autoGroup, param.ModelName, priorityRetry)
// 重置状态以尝试下一个分组
common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex, i+1)
common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupRetryIndex, 0)
// Reset retry counter so outer loop can continue for next group
// 重置重试计数器,以便外层循环可以为下一个分组继续
param.SetRetry(0)
continue
}
common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroup, autoGroup)
selectGroup = autoGroup
logger.LogDebug(param.Ctx, "Auto selected group: %s", autoGroup)
// Prepare state for next retry
// 为下一次重试准备状态
if crossGroupRetry && priorityRetry >= common.RetryTimes {
// Current group has exhausted all retries, prepare to switch to next group
// This request still uses current group, but next retry will use next group
// 当前分组已用完所有重试次数,准备切换到下一个分组
// 本次请求仍使用当前分组,但下次重试将使用下一个分组
logger.LogDebug(param.Ctx, "Current group %s retries exhausted (priorityRetry=%d >= RetryTimes=%d), preparing switch to next group for next retry", autoGroup, priorityRetry, common.RetryTimes)
common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex, i+1)
// Reset retry counter so outer loop can continue for next group
// 重置重试计数器,以便外层循环可以为下一个分组继续
param.SetRetry(-1)
} else {
// Stay in current group, save current state
// 保持在当前分组,保存当前状态
common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex, i)
}
break
}
} else {
channel, err = model.GetRandomSatisfiedChannel(tokenGroup, modelName, retry)
channel, err = model.GetRandomSatisfiedChannel(param.TokenGroup, param.ModelName, param.GetRetry())
if err != nil {
return nil, tokenGroup, err
return nil, param.TokenGroup, err
}
}
return channel, selectGroup, nil

View File

@@ -108,7 +108,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
modelRatio, _, _ := ratio_setting.GetModelRatio(modelName)
autoGroup, exists := ctx.Get("auto_group")
autoGroup, exists := common.GetContextKey(ctx, constant.ContextKeyAutoGroup)
if exists {
groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string))
log.Printf("final group ratio: %f", groupRatio)