diff --git a/constant/context_key.go b/constant/context_key.go
index 4de704619..ecc5178ee 100644
--- a/constant/context_key.go
+++ b/constant/context_key.go
@@ -18,8 +18,10 @@ const (
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
+ ContextKeyTokenCrossGroupRetry ContextKey = "token_cross_group_retry"
/* channel related keys */
+ ContextKeyAutoGroupIndex ContextKey = "auto_group_index"
ContextKeyChannelId ContextKey = "channel_id"
ContextKeyChannelName ContextKey = "channel_name"
ContextKeyChannelCreateTime ContextKey = "channel_create_time"
diff --git a/controller/token.go b/controller/token.go
index 0f0ae7fdf..eca4ce002 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -248,6 +248,7 @@ func UpdateToken(c *gin.Context) {
cleanToken.ModelLimits = token.ModelLimits
cleanToken.AllowIps = token.AllowIps
cleanToken.Group = token.Group
+ cleanToken.CrossGroupRetry = token.CrossGroupRetry
}
err = cleanToken.Update()
if err != nil {
diff --git a/middleware/auth.go b/middleware/auth.go
index dc59df9af..b1fca4712 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -308,6 +308,7 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
c.Set("token_model_limit_enabled", false)
}
c.Set("token_group", token.Group)
+ c.Set("token_cross_group_retry", token.CrossGroupRetry)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("specific_channel_id", parts[1])
diff --git a/model/token.go b/model/token.go
index c1fe2a670..a6a307ac2 100644
--- a/model/token.go
+++ b/model/token.go
@@ -27,6 +27,7 @@ type Token struct {
AllowIps *string `json:"allow_ips" gorm:"default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
Group string `json:"group" gorm:"default:''"`
+ CrossGroupRetry bool `json:"cross_group_retry" gorm:"default:false"` // 跨分组重试,仅auto分组有效
DeletedAt gorm.DeletedAt `gorm:"index"`
}
@@ -185,7 +186,7 @@ func (token *Token) Update() (err error) {
}
}()
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
- "model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
+ "model_limits_enabled", "model_limits", "allow_ips", "group", "cross_group_retry").Updates(token).Error
return err
}
diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go
index 69731d4d2..18cada8e0 100644
--- a/relay/channel/openai/helper.go
+++ b/relay/channel/openai/helper.go
@@ -172,7 +172,7 @@ func handleLastResponse(lastStreamData string, responseId *string, createAt *int
shouldSendLastResp *bool) error {
var lastStreamResponse dto.ChatCompletionsStreamResponse
- if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
+ if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
return err
}
diff --git a/service/channel_select.go b/service/channel_select.go
index 53f7d2c2a..b95aa025b 100644
--- a/service/channel_select.go
+++ b/service/channel_select.go
@@ -11,6 +11,7 @@ import (
"github.com/gin-gonic/gin"
)
+// CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements.
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName string, retry int) (*model.Channel, string, error) {
var channel *model.Channel
var err error
@@ -20,15 +21,30 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName stri
if len(setting.GetAutoGroups()) == 0 {
return nil, selectGroup, errors.New("auto groups is not enabled")
}
- for _, autoGroup := range GetUserAutoGroup(userGroup) {
- logger.LogDebug(c, "Auto selecting group:", autoGroup)
- channel, _ = model.GetRandomSatisfiedChannel(autoGroup, modelName, retry)
+ autoGroups := GetUserAutoGroup(userGroup)
+ // If the token has cross-group retry enabled, read the index of the auto group that was tried last time and start from the next one.
+ startIndex := 0
+ crossGroupRetry := common.GetContextKeyBool(c, constant.ContextKeyTokenCrossGroupRetry)
+ if crossGroupRetry && retry > 0 {
+ logger.LogDebug(c, "Auto group retry cross group, retry: %d", retry)
+ if lastIndex, exists := common.GetContextKey(c, constant.ContextKeyAutoGroupIndex); exists {
+ if idx, ok := lastIndex.(int); ok {
+ startIndex = idx + 1
+ }
+ }
+ logger.LogDebug(c, "Auto group retry cross group, start index: %d", startIndex)
+ }
+ for i := startIndex; i < len(autoGroups); i++ {
+ autoGroup := autoGroups[i]
+ logger.LogDebug(c, "Auto selecting group: %s", autoGroup)
+ channel, _ = model.GetRandomSatisfiedChannel(autoGroup, modelName, 0)
if channel == nil {
continue
} else {
c.Set("auto_group", autoGroup)
+ common.SetContextKey(c, constant.ContextKeyAutoGroupIndex, i)
selectGroup = autoGroup
- logger.LogDebug(c, "Auto selected group:", autoGroup)
+ logger.LogDebug(c, "Auto selected group: %s", autoGroup)
break
}
}
diff --git a/service/token_counter.go b/service/token_counter.go
index ebf0e243d..c70c54a88 100644
--- a/service/token_counter.go
+++ b/service/token_counter.go
@@ -317,7 +317,7 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela
for i, file := range meta.Files {
switch file.FileType {
case types.FileTypeImage:
- if common.IsOpenAITextModel(info.OriginModelName) {
+ if common.IsOpenAITextModel(model) {
token, err := getImageToken(file, model, info.IsStream)
if err != nil {
return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)
diff --git a/web/src/components/table/tokens/TokensColumnDefs.jsx b/web/src/components/table/tokens/TokensColumnDefs.jsx
index 4e092f9cc..ce8eab807 100644
--- a/web/src/components/table/tokens/TokensColumnDefs.jsx
+++ b/web/src/components/table/tokens/TokensColumnDefs.jsx
@@ -88,7 +88,7 @@ const renderStatus = (text, record, t) => {
};
// Render group column
-const renderGroupColumn = (text, t) => {
+const renderGroupColumn = (text, record, t) => {
if (text === 'auto') {
return (