mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:05:21 +00:00
refactor(channel_select): enhance retry logic and context key usage for channel selection
This commit is contained in:
@@ -21,7 +21,6 @@ const (
|
||||
ContextKeyTokenCrossGroupRetry ContextKey = "token_cross_group_retry"
|
||||
|
||||
/* channel related keys */
|
||||
ContextKeyAutoGroupIndex ContextKey = "auto_group_index"
|
||||
ContextKeyChannelId ContextKey = "channel_id"
|
||||
ContextKeyChannelName ContextKey = "channel_name"
|
||||
ContextKeyChannelCreateTime ContextKey = "channel_create_time"
|
||||
@@ -39,6 +38,10 @@ const (
|
||||
ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
|
||||
ContextKeyChannelKey ContextKey = "channel_key"
|
||||
|
||||
ContextKeyAutoGroup ContextKey = "auto_group"
|
||||
ContextKeyAutoGroupIndex ContextKey = "auto_group_index"
|
||||
ContextKeyAutoGroupRetryIndex ContextKey = "auto_group_retry_index"
|
||||
|
||||
/* user related keys */
|
||||
ContextKeyUserId ContextKey = "id"
|
||||
ContextKeyUserSetting ContextKey = "user_setting"
|
||||
|
||||
@@ -3,10 +3,7 @@ package controller
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/middleware"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
@@ -54,12 +51,6 @@ func Playground(c *gin.Context) {
|
||||
Group: relayInfo.UsingGroup,
|
||||
}
|
||||
_ = middleware.SetupContextForToken(c, tempToken)
|
||||
_, newAPIError = getChannel(c, relayInfo, 0)
|
||||
if newAPIError != nil {
|
||||
return
|
||||
}
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||
|
||||
Relay(c, types.RelayFormatOpenAI)
|
||||
}
|
||||
|
||||
@@ -157,8 +157,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
}
|
||||
}()
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, relayInfo, i)
|
||||
retryParam := &service.RetryParam{
|
||||
Ctx: c,
|
||||
TokenGroup: relayInfo.TokenGroup,
|
||||
ModelName: relayInfo.OriginModelName,
|
||||
Retry: common.GetPointer(0),
|
||||
}
|
||||
|
||||
for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
|
||||
channel, err := getChannel(c, relayInfo, retryParam)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
newAPIError = err
|
||||
@@ -186,7 +193,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
|
||||
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-retryParam.GetRetry()) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -211,8 +218,8 @@ func addUsedChannel(c *gin.Context, channelId int) {
|
||||
c.Set("use_channel", useChannel)
|
||||
}
|
||||
|
||||
func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryCount int) (*model.Channel, *types.NewAPIError) {
|
||||
if retryCount == 0 {
|
||||
func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryParam *service.RetryParam) (*model.Channel, *types.NewAPIError) {
|
||||
if info.ChannelMeta == nil {
|
||||
autoBan := c.GetBool("auto_ban")
|
||||
autoBanInt := 1
|
||||
if !autoBan {
|
||||
@@ -225,7 +232,7 @@ func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryCount int) (*m
|
||||
AutoBan: &autoBanInt,
|
||||
}, nil
|
||||
}
|
||||
channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, info.TokenGroup, info.OriginModelName, retryCount)
|
||||
channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(retryParam)
|
||||
|
||||
info.PriceData.GroupRatioInfo = helper.HandleGroupRatio(c, info)
|
||||
|
||||
@@ -370,7 +377,7 @@ func RelayMidjourney(c *gin.Context) {
|
||||
}
|
||||
|
||||
func RelayNotImplemented(c *gin.Context) {
|
||||
err := dto.OpenAIError{
|
||||
err := types.OpenAIError{
|
||||
Message: "API not implemented",
|
||||
Type: "new_api_error",
|
||||
Param: "",
|
||||
@@ -382,7 +389,7 @@ func RelayNotImplemented(c *gin.Context) {
|
||||
}
|
||||
|
||||
func RelayNotFound(c *gin.Context) {
|
||||
err := dto.OpenAIError{
|
||||
err := types.OpenAIError{
|
||||
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
|
||||
Type: "invalid_request_error",
|
||||
Param: "",
|
||||
@@ -405,8 +412,14 @@ func RelayTask(c *gin.Context) {
|
||||
if taskErr == nil {
|
||||
retryTimes = 0
|
||||
}
|
||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||
channel, newAPIError := getChannel(c, relayInfo, i)
|
||||
retryParam := &service.RetryParam{
|
||||
Ctx: c,
|
||||
TokenGroup: relayInfo.TokenGroup,
|
||||
ModelName: relayInfo.OriginModelName,
|
||||
Retry: common.GetPointer(0),
|
||||
}
|
||||
for ; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && retryParam.GetRetry() < retryTimes; retryParam.IncreaseRetry() {
|
||||
channel, newAPIError := getChannel(c, relayInfo, retryParam)
|
||||
if newAPIError != nil {
|
||||
logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
||||
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
||||
@@ -416,7 +429,7 @@ func RelayTask(c *gin.Context) {
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
c.Set("use_channel", useChannel)
|
||||
logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||
logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
|
||||
@@ -308,7 +308,7 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyTokenGroup, token.Group)
|
||||
c.Set("token_cross_group_retry", token.CrossGroupRetry)
|
||||
common.SetContextKey(c, constant.ContextKeyTokenCrossGroupRetry, token.CrossGroupRetry)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
|
||||
@@ -97,7 +97,12 @@ func Distribute() func(c *gin.Context) {
|
||||
common.SetContextKey(c, constant.ContextKeyUsingGroup, usingGroup)
|
||||
}
|
||||
}
|
||||
channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(c, usingGroup, modelRequest.Model, 0)
|
||||
channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(&service.RetryParam{
|
||||
Ctx: c,
|
||||
ModelName: modelRequest.Model,
|
||||
TokenGroup: usingGroup,
|
||||
Retry: common.GetPointer(0),
|
||||
})
|
||||
if err != nil {
|
||||
showGroup := usingGroup
|
||||
if usingGroup == "auto" {
|
||||
|
||||
@@ -11,50 +11,141 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type RetryParam struct {
|
||||
Ctx *gin.Context
|
||||
TokenGroup string
|
||||
ModelName string
|
||||
Retry *int
|
||||
}
|
||||
|
||||
func (p *RetryParam) GetRetry() int {
|
||||
if p.Retry == nil {
|
||||
return 0
|
||||
}
|
||||
return *p.Retry
|
||||
}
|
||||
|
||||
func (p *RetryParam) SetRetry(retry int) {
|
||||
p.Retry = &retry
|
||||
}
|
||||
|
||||
func (p *RetryParam) IncreaseRetry() {
|
||||
if p.Retry == nil {
|
||||
p.Retry = new(int)
|
||||
}
|
||||
*p.Retry++
|
||||
}
|
||||
|
||||
// CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements.
|
||||
func CacheGetRandomSatisfiedChannel(c *gin.Context, tokenGroup string, modelName string, retry int) (*model.Channel, string, error) {
|
||||
// 尝试获取一个满足要求的随机渠道。
|
||||
//
|
||||
// For "auto" tokenGroup with cross-group Retry enabled:
|
||||
// 对于启用了跨分组重试的 "auto" tokenGroup:
|
||||
//
|
||||
// - Each group will exhaust all its priorities before moving to the next group.
|
||||
// 每个分组会用完所有优先级后才会切换到下一个分组。
|
||||
//
|
||||
// - Uses ContextKeyAutoGroupIndex to track current group index.
|
||||
// 使用 ContextKeyAutoGroupIndex 跟踪当前分组索引。
|
||||
//
|
||||
// - Uses ContextKeyAutoGroupRetryIndex to track the global Retry count when current group started.
|
||||
// 使用 ContextKeyAutoGroupRetryIndex 跟踪当前分组开始时的全局重试次数。
|
||||
//
|
||||
// - priorityRetry = Retry - startRetryIndex, represents the priority level within current group.
|
||||
// priorityRetry = Retry - startRetryIndex,表示当前分组内的优先级级别。
|
||||
//
|
||||
// - When GetRandomSatisfiedChannel returns nil (priorities exhausted), moves to next group.
|
||||
// 当 GetRandomSatisfiedChannel 返回 nil(优先级用完)时,切换到下一个分组。
|
||||
//
|
||||
// Example flow (2 groups, each with 2 priorities, RetryTimes=3):
|
||||
// 示例流程(2个分组,每个有2个优先级,RetryTimes=3):
|
||||
//
|
||||
// Retry=0: GroupA, priority0 (startRetryIndex=0, priorityRetry=0)
|
||||
// 分组A, 优先级0
|
||||
//
|
||||
// Retry=1: GroupA, priority1 (startRetryIndex=0, priorityRetry=1)
|
||||
// 分组A, 优先级1
|
||||
//
|
||||
// Retry=2: GroupA exhausted → GroupB, priority0 (startRetryIndex=2, priorityRetry=0)
|
||||
// 分组A用完 → 分组B, 优先级0
|
||||
//
|
||||
// Retry=3: GroupB, priority1 (startRetryIndex=2, priorityRetry=1)
|
||||
// 分组B, 优先级1
|
||||
func CacheGetRandomSatisfiedChannel(param *RetryParam) (*model.Channel, string, error) {
|
||||
var channel *model.Channel
|
||||
var err error
|
||||
selectGroup := tokenGroup
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
if tokenGroup == "auto" {
|
||||
selectGroup := param.TokenGroup
|
||||
userGroup := common.GetContextKeyString(param.Ctx, constant.ContextKeyUserGroup)
|
||||
|
||||
if param.TokenGroup == "auto" {
|
||||
if len(setting.GetAutoGroups()) == 0 {
|
||||
return nil, selectGroup, errors.New("auto groups is not enabled")
|
||||
}
|
||||
autoGroups := GetUserAutoGroup(userGroup)
|
||||
startIndex := 0
|
||||
priorityRetry := retry
|
||||
crossGroupRetry := common.GetContextKeyBool(c, constant.ContextKeyTokenCrossGroupRetry)
|
||||
if crossGroupRetry && retry > 0 {
|
||||
logger.LogDebug(c, "Auto group retry cross group, retry: %d", retry)
|
||||
if lastIndex, exists := common.GetContextKey(c, constant.ContextKeyAutoGroupIndex); exists {
|
||||
if idx, ok := lastIndex.(int); ok {
|
||||
startIndex = idx + 1
|
||||
priorityRetry = 0
|
||||
}
|
||||
|
||||
// startGroupIndex: the group index to start searching from
|
||||
// startGroupIndex: 开始搜索的分组索引
|
||||
startGroupIndex := 0
|
||||
crossGroupRetry := common.GetContextKeyBool(param.Ctx, constant.ContextKeyTokenCrossGroupRetry)
|
||||
|
||||
if lastGroupIndex, exists := common.GetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex); exists {
|
||||
if idx, ok := lastGroupIndex.(int); ok {
|
||||
startGroupIndex = idx
|
||||
}
|
||||
logger.LogDebug(c, "Auto group retry cross group, start index: %d", startIndex)
|
||||
}
|
||||
|
||||
for i := startIndex; i < len(autoGroups); i++ {
|
||||
for i := startGroupIndex; i < len(autoGroups); i++ {
|
||||
autoGroup := autoGroups[i]
|
||||
logger.LogDebug(c, "Auto selecting group: %s", autoGroup)
|
||||
channel, _ = model.GetRandomSatisfiedChannel(autoGroup, modelName, priorityRetry)
|
||||
if channel == nil {
|
||||
// Calculate priorityRetry for current group
|
||||
// 计算当前分组的 priorityRetry
|
||||
priorityRetry := param.GetRetry()
|
||||
// If moved to a new group, reset priorityRetry and update startRetryIndex
|
||||
// 如果切换到新分组,重置 priorityRetry 并更新 startRetryIndex
|
||||
if i > startGroupIndex {
|
||||
priorityRetry = 0
|
||||
continue
|
||||
} else {
|
||||
c.Set("auto_group", autoGroup)
|
||||
common.SetContextKey(c, constant.ContextKeyAutoGroupIndex, i)
|
||||
selectGroup = autoGroup
|
||||
logger.LogDebug(c, "Auto selected group: %s", autoGroup)
|
||||
break
|
||||
}
|
||||
logger.LogDebug(param.Ctx, "Auto selecting group: %s, priorityRetry: %d", autoGroup, priorityRetry)
|
||||
|
||||
channel, _ = model.GetRandomSatisfiedChannel(autoGroup, param.ModelName, priorityRetry)
|
||||
if channel == nil {
|
||||
// Current group has no available channel for this model, try next group
|
||||
// 当前分组没有该模型的可用渠道,尝试下一个分组
|
||||
logger.LogDebug(param.Ctx, "No available channel in group %s for model %s at priorityRetry %d, trying next group", autoGroup, param.ModelName, priorityRetry)
|
||||
// 重置状态以尝试下一个分组
|
||||
common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex, i+1)
|
||||
common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupRetryIndex, 0)
|
||||
// Reset retry counter so outer loop can continue for next group
|
||||
// 重置重试计数器,以便外层循环可以为下一个分组继续
|
||||
param.SetRetry(0)
|
||||
continue
|
||||
}
|
||||
common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroup, autoGroup)
|
||||
selectGroup = autoGroup
|
||||
logger.LogDebug(param.Ctx, "Auto selected group: %s", autoGroup)
|
||||
|
||||
// Prepare state for next retry
|
||||
// 为下一次重试准备状态
|
||||
if crossGroupRetry && priorityRetry >= common.RetryTimes {
|
||||
// Current group has exhausted all retries, prepare to switch to next group
|
||||
// This request still uses current group, but next retry will use next group
|
||||
// 当前分组已用完所有重试次数,准备切换到下一个分组
|
||||
// 本次请求仍使用当前分组,但下次重试将使用下一个分组
|
||||
logger.LogDebug(param.Ctx, "Current group %s retries exhausted (priorityRetry=%d >= RetryTimes=%d), preparing switch to next group for next retry", autoGroup, priorityRetry, common.RetryTimes)
|
||||
common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex, i+1)
|
||||
// Reset retry counter so outer loop can continue for next group
|
||||
// 重置重试计数器,以便外层循环可以为下一个分组继续
|
||||
param.SetRetry(-1)
|
||||
} else {
|
||||
// Stay in current group, save current state
|
||||
// 保持在当前分组,保存当前状态
|
||||
common.SetContextKey(param.Ctx, constant.ContextKeyAutoGroupIndex, i)
|
||||
}
|
||||
break
|
||||
}
|
||||
} else {
|
||||
channel, err = model.GetRandomSatisfiedChannel(tokenGroup, modelName, retry)
|
||||
channel, err = model.GetRandomSatisfiedChannel(param.TokenGroup, param.ModelName, param.GetRetry())
|
||||
if err != nil {
|
||||
return nil, tokenGroup, err
|
||||
return nil, param.TokenGroup, err
|
||||
}
|
||||
}
|
||||
return channel, selectGroup, nil
|
||||
|
||||
@@ -108,7 +108,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
|
||||
modelRatio, _, _ := ratio_setting.GetModelRatio(modelName)
|
||||
|
||||
autoGroup, exists := ctx.Get("auto_group")
|
||||
autoGroup, exists := common.GetContextKey(ctx, constant.ContextKeyAutoGroup)
|
||||
if exists {
|
||||
groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string))
|
||||
log.Printf("final group ratio: %f", groupRatio)
|
||||
|
||||
Reference in New Issue
Block a user