mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-19 09:08:38 +00:00
refactor(relay): update channel retrieval to use RelayInfo structure
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/QuantumNous/new-api/constant"
|
"github.com/QuantumNous/new-api/constant"
|
||||||
"github.com/QuantumNous/new-api/middleware"
|
"github.com/QuantumNous/new-api/middleware"
|
||||||
"github.com/QuantumNous/new-api/model"
|
"github.com/QuantumNous/new-api/model"
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
"github.com/QuantumNous/new-api/types"
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -31,8 +32,11 @@ func Playground(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatOpenAI, nil, nil)
|
||||||
modelName := c.GetString("original_model")
|
if err != nil {
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
@@ -46,11 +50,11 @@ func Playground(c *gin.Context) {
|
|||||||
|
|
||||||
tempToken := &model.Token{
|
tempToken := &model.Token{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
Name: fmt.Sprintf("playground-%s", group),
|
Name: fmt.Sprintf("playground-%s", relayInfo.UsingGroup),
|
||||||
Group: group,
|
Group: relayInfo.UsingGroup,
|
||||||
}
|
}
|
||||||
_ = middleware.SetupContextForToken(c, tempToken)
|
_ = middleware.SetupContextForToken(c, tempToken)
|
||||||
_, newAPIError = getChannel(c, group, modelName, 0)
|
_, newAPIError = getChannel(c, relayInfo, 0)
|
||||||
if newAPIError != nil {
|
if newAPIError != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,8 +64,8 @@ func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewA
|
|||||||
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||||
|
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
//group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
||||||
originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
//originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
newAPIError *types.NewAPIError
|
newAPIError *types.NewAPIError
|
||||||
@@ -158,7 +158,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
for i := 0; i <= common.RetryTimes; i++ {
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, relayInfo, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LogError(c, err.Error())
|
logger.LogError(c, err.Error())
|
||||||
newAPIError = err
|
newAPIError = err
|
||||||
@@ -211,7 +211,7 @@ func addUsedChannel(c *gin.Context, channelId int) {
|
|||||||
c.Set("use_channel", useChannel)
|
c.Set("use_channel", useChannel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
|
func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryCount int) (*model.Channel, *types.NewAPIError) {
|
||||||
if retryCount == 0 {
|
if retryCount == 0 {
|
||||||
autoBan := c.GetBool("auto_ban")
|
autoBan := c.GetBool("auto_ban")
|
||||||
autoBanInt := 1
|
autoBanInt := 1
|
||||||
@@ -225,14 +225,18 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
|||||||
AutoBan: &autoBanInt,
|
AutoBan: &autoBanInt,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, info.TokenGroup, info.OriginModelName, retryCount)
|
||||||
|
|
||||||
|
info.PriceData.GroupRatioInfo = helper.HandleGroupRatio(c, info)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, info.OriginModelName, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
if channel == nil {
|
if channel == nil {
|
||||||
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(retry)", selectGroup, info.OriginModelName), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
|
||||||
|
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, info.OriginModelName)
|
||||||
if newAPIError != nil {
|
if newAPIError != nil {
|
||||||
return nil, newAPIError
|
return nil, newAPIError
|
||||||
}
|
}
|
||||||
@@ -392,8 +396,6 @@ func RelayNotFound(c *gin.Context) {
|
|||||||
func RelayTask(c *gin.Context) {
|
func RelayTask(c *gin.Context) {
|
||||||
retryTimes := common.RetryTimes
|
retryTimes := common.RetryTimes
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
group := c.GetString("group")
|
|
||||||
originalModel := c.GetString("original_model")
|
|
||||||
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
||||||
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -404,7 +406,7 @@ func RelayTask(c *gin.Context) {
|
|||||||
retryTimes = 0
|
retryTimes = 0
|
||||||
}
|
}
|
||||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||||
channel, newAPIError := getChannel(c, group, originalModel, i)
|
channel, newAPIError := getChannel(c, relayInfo, i)
|
||||||
if newAPIError != nil {
|
if newAPIError != nil {
|
||||||
logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
||||||
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ type TokenCountMeta struct {
|
|||||||
type RelayInfo struct {
|
type RelayInfo struct {
|
||||||
TokenId int
|
TokenId int
|
||||||
TokenKey string
|
TokenKey string
|
||||||
|
TokenGroup string
|
||||||
UserId int
|
UserId int
|
||||||
UsingGroup string // 使用的分组
|
UsingGroup string // 使用的分组
|
||||||
UserGroup string // 用户所在分组
|
UserGroup string // 用户所在分组
|
||||||
@@ -400,6 +401,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
|
|||||||
TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
|
TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
|
||||||
TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
|
TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
|
||||||
TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
|
TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
|
||||||
|
TokenGroup: common.GetContextKeyString(c, constant.ContextKeyTokenGroup),
|
||||||
|
|
||||||
isFirstResponse: true,
|
isFirstResponse: true,
|
||||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||||
|
|||||||
@@ -12,12 +12,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements.
|
// CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements.
|
||||||
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName string, retry int) (*model.Channel, string, error) {
|
func CacheGetRandomSatisfiedChannel(c *gin.Context, tokenGroup string, modelName string, retry int) (*model.Channel, string, error) {
|
||||||
var channel *model.Channel
|
var channel *model.Channel
|
||||||
var err error
|
var err error
|
||||||
selectGroup := group
|
selectGroup := tokenGroup
|
||||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||||
if group == "auto" {
|
if tokenGroup == "auto" {
|
||||||
if len(setting.GetAutoGroups()) == 0 {
|
if len(setting.GetAutoGroups()) == 0 {
|
||||||
return nil, selectGroup, errors.New("auto groups is not enabled")
|
return nil, selectGroup, errors.New("auto groups is not enabled")
|
||||||
}
|
}
|
||||||
@@ -49,9 +49,9 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName stri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
channel, err = model.GetRandomSatisfiedChannel(group, modelName, retry)
|
channel, err = model.GetRandomSatisfiedChannel(tokenGroup, modelName, retry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, group, err
|
return nil, tokenGroup, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return channel, selectGroup, nil
|
return channel, selectGroup, nil
|
||||||
|
|||||||
Reference in New Issue
Block a user