feat(token): add cross-group retry option for token processing

This commit is contained in:
CaIon
2025-12-12 17:59:21 +08:00
parent 0b9f6a58bc
commit 01b4039e96
8 changed files with 38 additions and 8 deletions

View File

@@ -18,8 +18,10 @@ const (
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id" ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled" ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
ContextKeyTokenModelLimit ContextKey = "token_model_limit" ContextKeyTokenModelLimit ContextKey = "token_model_limit"
ContextKeyTokenCrossGroupRetry ContextKey = "token_cross_group_retry"
/* channel related keys */ /* channel related keys */
ContextKeyAutoGroupIndex ContextKey = "auto_group_index"
ContextKeyChannelId ContextKey = "channel_id" ContextKeyChannelId ContextKey = "channel_id"
ContextKeyChannelName ContextKey = "channel_name" ContextKeyChannelName ContextKey = "channel_name"
ContextKeyChannelCreateTime ContextKey = "channel_create_time" ContextKeyChannelCreateTime ContextKey = "channel_create_time"

View File

@@ -248,6 +248,7 @@ func UpdateToken(c *gin.Context) {
cleanToken.ModelLimits = token.ModelLimits cleanToken.ModelLimits = token.ModelLimits
cleanToken.AllowIps = token.AllowIps cleanToken.AllowIps = token.AllowIps
cleanToken.Group = token.Group cleanToken.Group = token.Group
cleanToken.CrossGroupRetry = token.CrossGroupRetry
} }
err = cleanToken.Update() err = cleanToken.Update()
if err != nil { if err != nil {

View File

@@ -308,6 +308,7 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
c.Set("token_model_limit_enabled", false) c.Set("token_model_limit_enabled", false)
} }
c.Set("token_group", token.Group) c.Set("token_group", token.Group)
c.Set("token_cross_group_retry", token.CrossGroupRetry)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("specific_channel_id", parts[1]) c.Set("specific_channel_id", parts[1])

View File

@@ -27,6 +27,7 @@ type Token struct {
AllowIps *string `json:"allow_ips" gorm:"default:''"` AllowIps *string `json:"allow_ips" gorm:"default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
Group string `json:"group" gorm:"default:''"` Group string `json:"group" gorm:"default:''"`
CrossGroupRetry bool `json:"cross_group_retry" gorm:"default:false"` // 跨分组重试仅auto分组有效
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index"`
} }
@@ -185,7 +186,7 @@ func (token *Token) Update() (err error) {
} }
}() }()
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error "model_limits_enabled", "model_limits", "allow_ips", "group", "cross_group_retry").Updates(token).Error
return err return err
} }

View File

@@ -172,7 +172,7 @@ func handleLastResponse(lastStreamData string, responseId *string, createAt *int
shouldSendLastResp *bool) error { shouldSendLastResp *bool) error {
var lastStreamResponse dto.ChatCompletionsStreamResponse var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil { if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
return err return err
} }

View File

@@ -11,6 +11,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements.
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName string, retry int) (*model.Channel, string, error) { func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName string, retry int) (*model.Channel, string, error) {
var channel *model.Channel var channel *model.Channel
var err error var err error
@@ -20,15 +21,28 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName stri
if len(setting.GetAutoGroups()) == 0 { if len(setting.GetAutoGroups()) == 0 {
return nil, selectGroup, errors.New("auto groups is not enabled") return nil, selectGroup, errors.New("auto groups is not enabled")
} }
for _, autoGroup := range GetUserAutoGroup(userGroup) { autoGroups := GetUserAutoGroup(userGroup)
logger.LogDebug(c, "Auto selecting group:", autoGroup) // 如果 token 启用了跨分组重试,获取上次失败的 auto group 索引,从下一个开始尝试
channel, _ = model.GetRandomSatisfiedChannel(autoGroup, modelName, retry) startIndex := 0
crossGroupRetry := common.GetContextKeyBool(c, constant.ContextKeyTokenCrossGroupRetry)
if crossGroupRetry && retry > 0 {
logger.LogDebug(c, "Auto group retry cross group, retry: %d", retry)
if lastIndex, exists := c.Get(string(constant.ContextKeyAutoGroupIndex)); exists {
startIndex = lastIndex.(int) + 1
}
logger.LogDebug(c, "Auto group retry cross group, start index: %d", startIndex)
}
for i := startIndex; i < len(autoGroups); i++ {
autoGroup := autoGroups[i]
logger.LogDebug(c, "Auto selecting group: %s", autoGroup)
channel, _ = model.GetRandomSatisfiedChannel(autoGroup, modelName, 0)
if channel == nil { if channel == nil {
continue continue
} else { } else {
c.Set("auto_group", autoGroup) c.Set("auto_group", autoGroup)
c.Set(string(constant.ContextKeyAutoGroupIndex), i)
selectGroup = autoGroup selectGroup = autoGroup
logger.LogDebug(c, "Auto selected group:", autoGroup) logger.LogDebug(c, "Auto selected group: %s", autoGroup)
break break
} }
} }

View File

@@ -317,7 +317,7 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela
for i, file := range meta.Files { for i, file := range meta.Files {
switch file.FileType { switch file.FileType {
case types.FileTypeImage: case types.FileTypeImage:
if common.IsOpenAITextModel(info.OriginModelName) { if common.IsOpenAITextModel(model) {
token, err := getImageToken(file, model, info.IsStream) token, err := getImageToken(file, model, info.IsStream)
if err != nil { if err != nil {
return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err) return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)

View File

@@ -73,6 +73,7 @@ const EditTokenModal = (props) => {
model_limits: [], model_limits: [],
allow_ips: '', allow_ips: '',
group: '', group: '',
cross_group_retry: false,
tokenCount: 1, tokenCount: 1,
}); });
@@ -377,6 +378,16 @@ const EditTokenModal = (props) => {
/> />
)} )}
</Col> </Col>
<Col span={24} style={{ display: values.group === 'auto' ? 'block' : 'none' }}>
<Form.Switch
field='cross_group_retry'
label={t('跨分组重试')}
size='default'
extraText={t(
'开启后,当前分组渠道失败时会按顺序尝试下一个分组的渠道',
)}
/>
</Col>
<Col xs={24} sm={24} md={24} lg={10} xl={10}> <Col xs={24} sm={24} md={24} lg={10} xl={10}>
<Form.DatePicker <Form.DatePicker
field='expired_time' field='expired_time'
@@ -499,7 +510,7 @@ const EditTokenModal = (props) => {
<Form.Switch <Form.Switch
field='unlimited_quota' field='unlimited_quota'
label={t('无限额度')} label={t('无限额度')}
size='large' size='default'
extraText={t( extraText={t(
'令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制', '令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制',
)} )}