refactor(relay): update channel retrieval to use RelayInfo structure

This commit is contained in:
CaIon
2025-12-12 22:04:38 +08:00
parent 2ac6a5b02f
commit ce6fb95f96
4 changed files with 29 additions and 21 deletions

View File

@@ -9,6 +9,7 @@ import (
"github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/middleware" "github.com/QuantumNous/new-api/middleware"
"github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/model"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types" "github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -31,8 +32,11 @@ func Playground(c *gin.Context) {
return return
} }
group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup) relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatOpenAI, nil, nil)
modelName := c.GetString("original_model") if err != nil {
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
return
}
userId := c.GetInt("id") userId := c.GetInt("id")
@@ -46,11 +50,11 @@ func Playground(c *gin.Context) {
tempToken := &model.Token{ tempToken := &model.Token{
UserId: userId, UserId: userId,
Name: fmt.Sprintf("playground-%s", group), Name: fmt.Sprintf("playground-%s", relayInfo.UsingGroup),
Group: group, Group: relayInfo.UsingGroup,
} }
_ = middleware.SetupContextForToken(c, tempToken) _ = middleware.SetupContextForToken(c, tempToken)
_, newAPIError = getChannel(c, group, modelName, 0) _, newAPIError = getChannel(c, relayInfo, 0)
if newAPIError != nil { if newAPIError != nil {
return return
} }

View File

@@ -64,8 +64,8 @@ func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewA
func Relay(c *gin.Context, relayFormat types.RelayFormat) { func Relay(c *gin.Context, relayFormat types.RelayFormat) {
requestId := c.GetString(common.RequestIdKey) requestId := c.GetString(common.RequestIdKey)
group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup) //group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel) //originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
var ( var (
newAPIError *types.NewAPIError newAPIError *types.NewAPIError
@@ -158,7 +158,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
}() }()
for i := 0; i <= common.RetryTimes; i++ { for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i) channel, err := getChannel(c, relayInfo, i)
if err != nil { if err != nil {
logger.LogError(c, err.Error()) logger.LogError(c, err.Error())
newAPIError = err newAPIError = err
@@ -211,7 +211,7 @@ func addUsedChannel(c *gin.Context, channelId int) {
c.Set("use_channel", useChannel) c.Set("use_channel", useChannel)
} }
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) { func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryCount int) (*model.Channel, *types.NewAPIError) {
if retryCount == 0 { if retryCount == 0 {
autoBan := c.GetBool("auto_ban") autoBan := c.GetBool("auto_ban")
autoBanInt := 1 autoBanInt := 1
@@ -225,14 +225,18 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
AutoBan: &autoBanInt, AutoBan: &autoBanInt,
}, nil }, nil
} }
channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, info.TokenGroup, info.OriginModelName, retryCount)
info.PriceData.GroupRatioInfo = helper.HandleGroupRatio(c, info)
if err != nil { if err != nil {
return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败retry: %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败retry: %s", selectGroup, info.OriginModelName, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
} }
if channel == nil { if channel == nil {
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在retry", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在retry", selectGroup, info.OriginModelName), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
} }
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, info.OriginModelName)
if newAPIError != nil { if newAPIError != nil {
return nil, newAPIError return nil, newAPIError
} }
@@ -392,8 +396,6 @@ func RelayNotFound(c *gin.Context) {
func RelayTask(c *gin.Context) { func RelayTask(c *gin.Context) {
retryTimes := common.RetryTimes retryTimes := common.RetryTimes
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
group := c.GetString("group")
originalModel := c.GetString("original_model")
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
if err != nil { if err != nil {
@@ -404,7 +406,7 @@ func RelayTask(c *gin.Context) {
retryTimes = 0 retryTimes = 0
} }
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, newAPIError := getChannel(c, group, originalModel, i) channel, newAPIError := getChannel(c, relayInfo, i)
if newAPIError != nil { if newAPIError != nil {
logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError) taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)

View File

@@ -81,6 +81,7 @@ type TokenCountMeta struct {
type RelayInfo struct { type RelayInfo struct {
TokenId int TokenId int
TokenKey string TokenKey string
TokenGroup string
UserId int UserId int
UsingGroup string // 使用的分组 UsingGroup string // 使用的分组
UserGroup string // 用户所在分组 UserGroup string // 用户所在分组
@@ -400,6 +401,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId), TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey), TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited), TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
TokenGroup: common.GetContextKeyString(c, constant.ContextKeyTokenGroup),
isFirstResponse: true, isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),

View File

@@ -12,12 +12,12 @@ import (
) )
// CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements. // CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements.
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName string, retry int) (*model.Channel, string, error) { func CacheGetRandomSatisfiedChannel(c *gin.Context, tokenGroup string, modelName string, retry int) (*model.Channel, string, error) {
var channel *model.Channel var channel *model.Channel
var err error var err error
selectGroup := group selectGroup := tokenGroup
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup) userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
if group == "auto" { if tokenGroup == "auto" {
if len(setting.GetAutoGroups()) == 0 { if len(setting.GetAutoGroups()) == 0 {
return nil, selectGroup, errors.New("auto groups is not enabled") return nil, selectGroup, errors.New("auto groups is not enabled")
} }
@@ -49,9 +49,9 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName stri
} }
} }
} else { } else {
channel, err = model.GetRandomSatisfiedChannel(group, modelName, retry) channel, err = model.GetRandomSatisfiedChannel(tokenGroup, modelName, retry)
if err != nil { if err != nil {
return nil, group, err return nil, tokenGroup, err
} }
} }
return channel, selectGroup, nil return channel, selectGroup, nil