diff --git a/constant/context_key.go b/constant/context_key.go index 4de704619..ecc5178ee 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -18,8 +18,10 @@ const ( ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id" ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled" ContextKeyTokenModelLimit ContextKey = "token_model_limit" + ContextKeyTokenCrossGroupRetry ContextKey = "token_cross_group_retry" /* channel related keys */ + ContextKeyAutoGroupIndex ContextKey = "auto_group_index" ContextKeyChannelId ContextKey = "channel_id" ContextKeyChannelName ContextKey = "channel_name" ContextKeyChannelCreateTime ContextKey = "channel_create_time" diff --git a/controller/token.go b/controller/token.go index 04e31f8c1..832438e83 100644 --- a/controller/token.go +++ b/controller/token.go @@ -248,6 +248,7 @@ func UpdateToken(c *gin.Context) { cleanToken.ModelLimits = token.ModelLimits cleanToken.AllowIps = token.AllowIps cleanToken.Group = token.Group + cleanToken.CrossGroupRetry = token.CrossGroupRetry } err = cleanToken.Update() if err != nil { diff --git a/middleware/auth.go b/middleware/auth.go index dc59df9af..b1fca4712 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -308,6 +308,7 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e c.Set("token_model_limit_enabled", false) } c.Set("token_group", token.Group) + c.Set("token_cross_group_retry", token.CrossGroupRetry) if len(parts) > 1 { if model.IsAdmin(token.UserId) { c.Set("specific_channel_id", parts[1]) diff --git a/model/token.go b/model/token.go index c1fe2a670..a6a307ac2 100644 --- a/model/token.go +++ b/model/token.go @@ -27,6 +27,7 @@ type Token struct { AllowIps *string `json:"allow_ips" gorm:"default:''"` UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota Group string `json:"group" gorm:"default:''"` + CrossGroupRetry bool `json:"cross_group_retry" gorm:"default:false"` // 跨分组重试,仅auto分组有效 DeletedAt gorm.DeletedAt `gorm:"index"` } @@ -185,7 +186,7 @@ func (token *Token) Update() (err error) { } }() err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", - "model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error + "model_limits_enabled", "model_limits", "allow_ips", "group", "cross_group_retry").Updates(token).Error return err } diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index 69731d4d2..18cada8e0 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -172,7 +172,7 @@ func handleLastResponse(lastStreamData string, responseId *string, createAt *int shouldSendLastResp *bool) error { var lastStreamResponse dto.ChatCompletionsStreamResponse - if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil { + if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil { return err } diff --git a/service/channel_select.go b/service/channel_select.go index 53f7d2c2a..348b89e55 100644 --- a/service/channel_select.go +++ b/service/channel_select.go @@ -11,6 +11,7 @@ import ( "github.com/gin-gonic/gin" ) +// CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements. func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName string, retry int) (*model.Channel, string, error) { var channel *model.Channel var err error @@ -20,15 +21,28 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName stri if len(setting.GetAutoGroups()) == 0 { return nil, selectGroup, errors.New("auto groups is not enabled") } - for _, autoGroup := range GetUserAutoGroup(userGroup) { - logger.LogDebug(c, "Auto selecting group:", autoGroup) - channel, _ = model.GetRandomSatisfiedChannel(autoGroup, modelName, retry) + autoGroups := GetUserAutoGroup(userGroup) + // 如果 token 启用了跨分组重试,获取上次失败的 auto group 索引,从下一个开始尝试 + startIndex := 0 + crossGroupRetry := common.GetContextKeyBool(c, constant.ContextKeyTokenCrossGroupRetry) + if crossGroupRetry && retry > 0 { + logger.LogDebug(c, "Auto group retry cross group, retry: %d", retry) + if lastIndex, exists := c.Get(string(constant.ContextKeyAutoGroupIndex)); exists { + startIndex = lastIndex.(int) + 1 + } + logger.LogDebug(c, "Auto group retry cross group, start index: %d", startIndex) + } + for i := startIndex; i < len(autoGroups); i++ { + autoGroup := autoGroups[i] + logger.LogDebug(c, "Auto selecting group: %s", autoGroup) + channel, _ = model.GetRandomSatisfiedChannel(autoGroup, modelName, 0) if channel == nil { continue } else { c.Set("auto_group", autoGroup) + c.Set(string(constant.ContextKeyAutoGroupIndex), i) selectGroup = autoGroup - logger.LogDebug(c, "Auto selected group:", autoGroup) + logger.LogDebug(c, "Auto selected group: %s", autoGroup) break } } diff --git a/service/token_counter.go b/service/token_counter.go index ebf0e243d..c70c54a88 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -317,7 +317,7 @@ func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *rela for i, file := range meta.Files { switch file.FileType { case types.FileTypeImage: - if common.IsOpenAITextModel(info.OriginModelName) { + if common.IsOpenAITextModel(model) { token, err := getImageToken(file, model, info.IsStream) if err != nil { return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err) diff --git a/web/src/components/table/tokens/modals/EditTokenModal.jsx b/web/src/components/table/tokens/modals/EditTokenModal.jsx index 59a3894af..c7db40d66 100644 --- a/web/src/components/table/tokens/modals/EditTokenModal.jsx +++ b/web/src/components/table/tokens/modals/EditTokenModal.jsx @@ -73,6 +73,7 @@ const EditTokenModal = (props) => { model_limits: [], allow_ips: '', group: '', + cross_group_retry: false, tokenCount: 1, }); @@ -377,6 +378,16 @@ const EditTokenModal = (props) => { /> )} + + + {