From ce6fb95f96c581ba320657c2244a91acee636d21 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 12 Dec 2025 22:04:38 +0800 Subject: [PATCH] refactor(relay): update channel retrieval to use RelayInfo structure --- controller/playground.go | 14 +++++++++----- controller/relay.go | 24 +++++++++++++----------- relay/common/relay_info.go | 2 ++ service/channel_select.go | 10 +++++----- 4 files changed, 29 insertions(+), 21 deletions(-) diff --git a/controller/playground.go b/controller/playground.go index 342f47cf0..d9e2ba9a1 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -9,6 +9,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/middleware" "github.com/QuantumNous/new-api/model" + relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" @@ -31,8 +32,11 @@ func Playground(c *gin.Context) { return } - group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup) - modelName := c.GetString("original_model") + relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatOpenAI, nil, nil) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + return + } userId := c.GetInt("id") @@ -46,11 +50,11 @@ func Playground(c *gin.Context) { tempToken := &model.Token{ UserId: userId, - Name: fmt.Sprintf("playground-%s", group), - Group: group, + Name: fmt.Sprintf("playground-%s", relayInfo.UsingGroup), + Group: relayInfo.UsingGroup, } _ = middleware.SetupContextForToken(c, tempToken) - _, newAPIError = getChannel(c, group, modelName, 0) + _, newAPIError = getChannel(c, relayInfo, 0) if newAPIError != nil { return } diff --git a/controller/relay.go b/controller/relay.go index 50ad9dabb..2013b9c0f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -64,8 +64,8 @@ func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewA func Relay(c *gin.Context, relayFormat types.RelayFormat) { requestId := c.GetString(common.RequestIdKey) - group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup) - originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel) + //group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup) + //originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel) var ( newAPIError *types.NewAPIError @@ -158,7 +158,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { }() for i := 0; i <= common.RetryTimes; i++ { - channel, err := getChannel(c, group, originalModel, i) + channel, err := getChannel(c, relayInfo, i) if err != nil { logger.LogError(c, err.Error()) newAPIError = err @@ -211,7 +211,7 @@ func addUsedChannel(c *gin.Context, channelId int) { c.Set("use_channel", useChannel) } -func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) { +func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryCount int) (*model.Channel, *types.NewAPIError) { if retryCount == 0 { autoBan := c.GetBool("auto_ban") autoBanInt := 1 @@ -225,14 +225,18 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m AutoBan: &autoBanInt, }, nil } - channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) + channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, info.TokenGroup, info.OriginModelName, retryCount) + + info.PriceData.GroupRatioInfo = helper.HandleGroupRatio(c, info) + if err != nil { - return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) + return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, info.OriginModelName, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } if channel == nil { - return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) + return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(retry)", selectGroup, info.OriginModelName), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } - newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel) + + newAPIError := middleware.SetupContextForSelectedChannel(c, channel, info.OriginModelName) if newAPIError != nil { return nil, newAPIError } @@ -392,8 +396,6 @@ func RelayNotFound(c *gin.Context) { func RelayTask(c *gin.Context) { retryTimes := common.RetryTimes channelId := c.GetInt("channel_id") - group := c.GetString("group") - originalModel := c.GetString("original_model") c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) if err != nil { @@ -404,7 +406,7 @@ func RelayTask(c *gin.Context) { retryTimes = 0 } for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { - channel, newAPIError := getChannel(c, group, originalModel, i) + channel, newAPIError := getChannel(c, relayInfo, i) if newAPIError != nil { logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 1882eca89..8bc47bb52 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -81,6 +81,7 @@ type TokenCountMeta struct { type RelayInfo struct { TokenId int TokenKey string + TokenGroup string UserId int UsingGroup string // 使用的分组 UserGroup string // 用户所在分组 @@ -400,6 +401,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId), TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey), TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited), + TokenGroup: common.GetContextKeyString(c, constant.ContextKeyTokenGroup), isFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), diff --git a/service/channel_select.go b/service/channel_select.go index b95aa025b..aea522d96 100644 --- a/service/channel_select.go +++ b/service/channel_select.go @@ -12,12 +12,12 @@ import ( ) // CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements. -func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName string, retry int) (*model.Channel, string, error) { +func CacheGetRandomSatisfiedChannel(c *gin.Context, tokenGroup string, modelName string, retry int) (*model.Channel, string, error) { var channel *model.Channel var err error - selectGroup := group + selectGroup := tokenGroup userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup) - if group == "auto" { + if tokenGroup == "auto" { if len(setting.GetAutoGroups()) == 0 { return nil, selectGroup, errors.New("auto groups is not enabled") } @@ -49,9 +49,9 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName stri } } } else { - channel, err = model.GetRandomSatisfiedChannel(group, modelName, retry) + channel, err = model.GetRandomSatisfiedChannel(tokenGroup, modelName, retry) if err != nil { - return nil, group, err + return nil, tokenGroup, err } } return channel, selectGroup, nil