mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 04:40:59 +00:00
feat: channel affinity (#2669)
* feat: channel affinity * feat: channel affinity -> model setting * fix: channel affinity * feat: channel affinity op * feat: channel_type setting * feat: clean * feat: cache supports both memory and Redis. * feat: Optimise ui/ux * feat: Optimise ui/ux * feat: Optimise codex usage ui/ux * feat: Optimise ui/ux * feat: Optimise ui/ux * feat: Optimise ui/ux * feat: If the affinitized channel fails and a retry succeeds on another channel, update the affinity to the successful channel
This commit is contained in:
487
service/channel_affinity.go
Normal file
487
service/channel_affinity.go
Normal file
@@ -0,0 +1,487 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/pkg/cachex"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/hot"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
	// Keys under which per-request channel affinity state is stored on the
	// gin context.
	ginKeyChannelAffinityCacheKey   = "channel_affinity_cache_key"
	ginKeyChannelAffinityTTLSeconds = "channel_affinity_ttl_seconds"
	ginKeyChannelAffinityMeta       = "channel_affinity_meta"
	ginKeyChannelAffinityLogInfo    = "channel_affinity_log_info"

	// channelAffinityCacheNamespace is the prefix shared by every channel
	// affinity cache key.
	channelAffinityCacheNamespace = "new-api:channel_affinity:v1"
)
|
||||
|
||||
var (
|
||||
channelAffinityCacheOnce sync.Once
|
||||
channelAffinityCache *cachex.HybridCache[int]
|
||||
|
||||
channelAffinityRegexCache sync.Map // map[string]*regexp.Regexp
|
||||
)
|
||||
|
||||
// channelAffinityMeta describes the affinity rule match made for a single
// request. It is stored on the gin context so later stages can record the
// chosen channel and report how the affinity key was derived.
type channelAffinityMeta struct {
	CacheKey       string // full cache key, including the namespace prefix
	TTLSeconds     int    // TTL to apply when the affinity is recorded
	RuleName       string // name of the rule that matched
	KeySourceType  string // type of the key source that produced the value
	KeySourceKey   string // context key of that source, if any
	KeySourcePath  string // gjson path of that source, if any
	KeyFingerprint string // short hash of the affinity value (never the raw value)
	UsingGroup     string // group the request was evaluated under
	ModelName      string // requested model name
	RequestPath    string // request URL path
}
|
||||
|
||||
// ChannelAffinityCacheStats is a snapshot of the channel affinity cache
// contents, broken down by the rule that produced each entry.
type ChannelAffinityCacheStats struct {
	Enabled       bool           `json:"enabled"`        // whether channel affinity is enabled in settings
	Total         int            `json:"total"`          // number of cached keys
	Unknown       int            `json:"unknown"`        // keys that could not be attributed to a rule
	ByRuleName    map[string]int `json:"by_rule_name"`   // key count per attributable rule
	CacheCapacity int            `json:"cache_capacity"` // capacity reported by the cache
	CacheAlgo     string         `json:"cache_algo"`     // eviction algorithm reported by the cache
}
|
||||
|
||||
func getChannelAffinityCache() *cachex.HybridCache[int] {
|
||||
channelAffinityCacheOnce.Do(func() {
|
||||
setting := operation_setting.GetChannelAffinitySetting()
|
||||
capacity := setting.MaxEntries
|
||||
if capacity <= 0 {
|
||||
capacity = 100_000
|
||||
}
|
||||
defaultTTLSeconds := setting.DefaultTTLSeconds
|
||||
if defaultTTLSeconds <= 0 {
|
||||
defaultTTLSeconds = 3600
|
||||
}
|
||||
|
||||
channelAffinityCache = cachex.NewHybridCache[int](cachex.HybridCacheConfig[int]{
|
||||
Namespace: cachex.Namespace(channelAffinityCacheNamespace),
|
||||
Redis: common.RDB,
|
||||
RedisEnabled: func() bool {
|
||||
return common.RedisEnabled && common.RDB != nil
|
||||
},
|
||||
RedisCodec: cachex.IntCodec{},
|
||||
Memory: func() *hot.HotCache[string, int] {
|
||||
return hot.NewHotCache[string, int](hot.LRU, capacity).
|
||||
WithTTL(time.Duration(defaultTTLSeconds) * time.Second).
|
||||
WithJanitor().
|
||||
Build()
|
||||
},
|
||||
})
|
||||
})
|
||||
return channelAffinityCache
|
||||
}
|
||||
|
||||
func GetChannelAffinityCacheStats() ChannelAffinityCacheStats {
|
||||
setting := operation_setting.GetChannelAffinitySetting()
|
||||
if setting == nil {
|
||||
return ChannelAffinityCacheStats{
|
||||
Enabled: false,
|
||||
Total: 0,
|
||||
Unknown: 0,
|
||||
ByRuleName: map[string]int{},
|
||||
}
|
||||
}
|
||||
|
||||
cache := getChannelAffinityCache()
|
||||
mainCap, _ := cache.Capacity()
|
||||
mainAlgo, _ := cache.Algorithm()
|
||||
|
||||
rules := setting.Rules
|
||||
ruleByName := make(map[string]operation_setting.ChannelAffinityRule, len(rules))
|
||||
for _, r := range rules {
|
||||
name := strings.TrimSpace(r.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if !r.IncludeRuleName {
|
||||
continue
|
||||
}
|
||||
ruleByName[name] = r
|
||||
}
|
||||
|
||||
byRuleName := make(map[string]int, len(ruleByName))
|
||||
for name := range ruleByName {
|
||||
byRuleName[name] = 0
|
||||
}
|
||||
|
||||
keys, err := cache.Keys()
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("channel affinity cache list keys failed: err=%v", err))
|
||||
keys = nil
|
||||
}
|
||||
total := len(keys)
|
||||
unknown := 0
|
||||
for _, k := range keys {
|
||||
prefix := channelAffinityCacheNamespace + ":"
|
||||
if !strings.HasPrefix(k, prefix) {
|
||||
unknown++
|
||||
continue
|
||||
}
|
||||
rest := strings.TrimPrefix(k, prefix)
|
||||
parts := strings.Split(rest, ":")
|
||||
if len(parts) < 2 {
|
||||
unknown++
|
||||
continue
|
||||
}
|
||||
ruleName := parts[0]
|
||||
rule, ok := ruleByName[ruleName]
|
||||
if !ok {
|
||||
unknown++
|
||||
continue
|
||||
}
|
||||
if rule.IncludeUsingGroup {
|
||||
if len(parts) < 3 {
|
||||
unknown++
|
||||
continue
|
||||
}
|
||||
}
|
||||
byRuleName[ruleName]++
|
||||
}
|
||||
|
||||
return ChannelAffinityCacheStats{
|
||||
Enabled: setting.Enabled,
|
||||
Total: total,
|
||||
Unknown: unknown,
|
||||
ByRuleName: byRuleName,
|
||||
CacheCapacity: mainCap,
|
||||
CacheAlgo: mainAlgo,
|
||||
}
|
||||
}
|
||||
|
||||
func ClearChannelAffinityCacheAll() int {
|
||||
cache := getChannelAffinityCache()
|
||||
keys, err := cache.Keys()
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("channel affinity cache list keys failed: err=%v", err))
|
||||
keys = nil
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
if _, err := cache.DeleteMany(keys); err != nil {
|
||||
common.SysError(fmt.Sprintf("channel affinity cache delete many failed: err=%v", err))
|
||||
}
|
||||
}
|
||||
return len(keys)
|
||||
}
|
||||
|
||||
func ClearChannelAffinityCacheByRuleName(ruleName string) (int, error) {
|
||||
ruleName = strings.TrimSpace(ruleName)
|
||||
if ruleName == "" {
|
||||
return 0, fmt.Errorf("rule_name 不能为空")
|
||||
}
|
||||
|
||||
setting := operation_setting.GetChannelAffinitySetting()
|
||||
if setting == nil {
|
||||
return 0, fmt.Errorf("channel_affinity_setting 未初始化")
|
||||
}
|
||||
|
||||
var matchedRule *operation_setting.ChannelAffinityRule
|
||||
for i := range setting.Rules {
|
||||
r := &setting.Rules[i]
|
||||
if strings.TrimSpace(r.Name) != ruleName {
|
||||
continue
|
||||
}
|
||||
matchedRule = r
|
||||
break
|
||||
}
|
||||
if matchedRule == nil {
|
||||
return 0, fmt.Errorf("未知规则名称")
|
||||
}
|
||||
if !matchedRule.IncludeRuleName {
|
||||
return 0, fmt.Errorf("该规则未启用 include_rule_name,无法按规则清空缓存")
|
||||
}
|
||||
|
||||
cache := getChannelAffinityCache()
|
||||
deleted, err := cache.DeleteByPrefix(ruleName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func matchAnyRegexCached(patterns []string, s string) bool {
|
||||
if len(patterns) == 0 || s == "" {
|
||||
return false
|
||||
}
|
||||
for _, pattern := range patterns {
|
||||
if pattern == "" {
|
||||
continue
|
||||
}
|
||||
re, ok := channelAffinityRegexCache.Load(pattern)
|
||||
if !ok {
|
||||
compiled, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
re = compiled
|
||||
channelAffinityRegexCache.Store(pattern, re)
|
||||
}
|
||||
if re.(*regexp.Regexp).MatchString(s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchAnyIncludeFold reports whether s contains at least one of patterns as
// a substring, ignoring case.
//
// Patterns are trimmed of surrounding whitespace first, and patterns that are
// empty after trimming are skipped. An empty patterns slice or an empty s
// never matches.
func matchAnyIncludeFold(patterns []string, s string) bool {
	if len(patterns) == 0 || s == "" {
		return false
	}
	haystack := strings.ToLower(s)
	for _, p := range patterns {
		needle := strings.ToLower(strings.TrimSpace(p))
		if needle != "" && strings.Contains(haystack, needle) {
			return true
		}
	}
	return false
}
|
||||
|
||||
func extractChannelAffinityValue(c *gin.Context, src operation_setting.ChannelAffinityKeySource) string {
|
||||
switch src.Type {
|
||||
case "context_int":
|
||||
if src.Key == "" {
|
||||
return ""
|
||||
}
|
||||
v := c.GetInt(src.Key)
|
||||
if v <= 0 {
|
||||
return ""
|
||||
}
|
||||
return strconv.Itoa(v)
|
||||
case "context_string":
|
||||
if src.Key == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(c.GetString(src.Key))
|
||||
case "gjson":
|
||||
if src.Path == "" {
|
||||
return ""
|
||||
}
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil || len(body) == 0 {
|
||||
return ""
|
||||
}
|
||||
res := gjson.GetBytes(body, src.Path)
|
||||
if !res.Exists() {
|
||||
return ""
|
||||
}
|
||||
switch res.Type {
|
||||
case gjson.String, gjson.Number, gjson.True, gjson.False:
|
||||
return strings.TrimSpace(res.String())
|
||||
default:
|
||||
return strings.TrimSpace(res.Raw)
|
||||
}
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func buildChannelAffinityCacheKeySuffix(rule operation_setting.ChannelAffinityRule, usingGroup string, affinityValue string) string {
|
||||
parts := make([]string, 0, 3)
|
||||
if rule.IncludeRuleName && rule.Name != "" {
|
||||
parts = append(parts, rule.Name)
|
||||
}
|
||||
if rule.IncludeUsingGroup && usingGroup != "" {
|
||||
parts = append(parts, usingGroup)
|
||||
}
|
||||
parts = append(parts, affinityValue)
|
||||
return strings.Join(parts, ":")
|
||||
}
|
||||
|
||||
func setChannelAffinityContext(c *gin.Context, meta channelAffinityMeta) {
|
||||
c.Set(ginKeyChannelAffinityCacheKey, meta.CacheKey)
|
||||
c.Set(ginKeyChannelAffinityTTLSeconds, meta.TTLSeconds)
|
||||
c.Set(ginKeyChannelAffinityMeta, meta)
|
||||
}
|
||||
|
||||
func getChannelAffinityContext(c *gin.Context) (string, int, bool) {
|
||||
keyAny, ok := c.Get(ginKeyChannelAffinityCacheKey)
|
||||
if !ok {
|
||||
return "", 0, false
|
||||
}
|
||||
key, ok := keyAny.(string)
|
||||
if !ok || key == "" {
|
||||
return "", 0, false
|
||||
}
|
||||
ttlAny, ok := c.Get(ginKeyChannelAffinityTTLSeconds)
|
||||
if !ok {
|
||||
return key, 0, true
|
||||
}
|
||||
ttlSeconds, _ := ttlAny.(int)
|
||||
return key, ttlSeconds, true
|
||||
}
|
||||
|
||||
func getChannelAffinityMeta(c *gin.Context) (channelAffinityMeta, bool) {
|
||||
anyMeta, ok := c.Get(ginKeyChannelAffinityMeta)
|
||||
if !ok {
|
||||
return channelAffinityMeta{}, false
|
||||
}
|
||||
meta, ok := anyMeta.(channelAffinityMeta)
|
||||
if !ok {
|
||||
return channelAffinityMeta{}, false
|
||||
}
|
||||
return meta, true
|
||||
}
|
||||
|
||||
func affinityFingerprint(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
hex := common.Sha1([]byte(s))
|
||||
if len(hex) >= 8 {
|
||||
return hex[:8]
|
||||
}
|
||||
return hex
|
||||
}
|
||||
|
||||
func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup string) (int, bool) {
|
||||
setting := operation_setting.GetChannelAffinitySetting()
|
||||
if setting == nil || !setting.Enabled {
|
||||
return 0, false
|
||||
}
|
||||
path := ""
|
||||
if c != nil && c.Request != nil && c.Request.URL != nil {
|
||||
path = c.Request.URL.Path
|
||||
}
|
||||
userAgent := ""
|
||||
if c != nil && c.Request != nil {
|
||||
userAgent = c.Request.UserAgent()
|
||||
}
|
||||
|
||||
for _, rule := range setting.Rules {
|
||||
if !matchAnyRegexCached(rule.ModelRegex, modelName) {
|
||||
continue
|
||||
}
|
||||
if len(rule.PathRegex) > 0 && !matchAnyRegexCached(rule.PathRegex, path) {
|
||||
continue
|
||||
}
|
||||
if len(rule.UserAgentInclude) > 0 && !matchAnyIncludeFold(rule.UserAgentInclude, userAgent) {
|
||||
continue
|
||||
}
|
||||
var affinityValue string
|
||||
var usedSource operation_setting.ChannelAffinityKeySource
|
||||
for _, src := range rule.KeySources {
|
||||
affinityValue = extractChannelAffinityValue(c, src)
|
||||
if affinityValue != "" {
|
||||
usedSource = src
|
||||
break
|
||||
}
|
||||
}
|
||||
if affinityValue == "" {
|
||||
continue
|
||||
}
|
||||
if rule.ValueRegex != "" && !matchAnyRegexCached([]string{rule.ValueRegex}, affinityValue) {
|
||||
continue
|
||||
}
|
||||
|
||||
ttlSeconds := rule.TTLSeconds
|
||||
if ttlSeconds <= 0 {
|
||||
ttlSeconds = setting.DefaultTTLSeconds
|
||||
}
|
||||
cacheKeySuffix := buildChannelAffinityCacheKeySuffix(rule, usingGroup, affinityValue)
|
||||
cacheKeyFull := channelAffinityCacheNamespace + ":" + cacheKeySuffix
|
||||
setChannelAffinityContext(c, channelAffinityMeta{
|
||||
CacheKey: cacheKeyFull,
|
||||
TTLSeconds: ttlSeconds,
|
||||
RuleName: rule.Name,
|
||||
KeySourceType: strings.TrimSpace(usedSource.Type),
|
||||
KeySourceKey: strings.TrimSpace(usedSource.Key),
|
||||
KeySourcePath: strings.TrimSpace(usedSource.Path),
|
||||
KeyFingerprint: affinityFingerprint(affinityValue),
|
||||
UsingGroup: usingGroup,
|
||||
ModelName: modelName,
|
||||
RequestPath: path,
|
||||
})
|
||||
|
||||
cache := getChannelAffinityCache()
|
||||
channelID, found, err := cache.Get(cacheKeySuffix)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("channel affinity cache get failed: key=%s, err=%v", cacheKeyFull, err))
|
||||
return 0, false
|
||||
}
|
||||
if found {
|
||||
return channelID, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) {
|
||||
if c == nil || channelID <= 0 {
|
||||
return
|
||||
}
|
||||
meta, ok := getChannelAffinityMeta(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
info := map[string]interface{}{
|
||||
"reason": meta.RuleName,
|
||||
"rule_name": meta.RuleName,
|
||||
"using_group": meta.UsingGroup,
|
||||
"selected_group": selectedGroup,
|
||||
"model": meta.ModelName,
|
||||
"request_path": meta.RequestPath,
|
||||
"channel_id": channelID,
|
||||
"key_source": meta.KeySourceType,
|
||||
"key_key": meta.KeySourceKey,
|
||||
"key_path": meta.KeySourcePath,
|
||||
"key_fp": meta.KeyFingerprint,
|
||||
}
|
||||
c.Set(ginKeyChannelAffinityLogInfo, info)
|
||||
}
|
||||
|
||||
func AppendChannelAffinityAdminInfo(c *gin.Context, adminInfo map[string]interface{}) {
|
||||
if c == nil || adminInfo == nil {
|
||||
return
|
||||
}
|
||||
anyInfo, ok := c.Get(ginKeyChannelAffinityLogInfo)
|
||||
if !ok || anyInfo == nil {
|
||||
return
|
||||
}
|
||||
adminInfo["channel_affinity"] = anyInfo
|
||||
}
|
||||
|
||||
func RecordChannelAffinity(c *gin.Context, channelID int) {
|
||||
if channelID <= 0 {
|
||||
return
|
||||
}
|
||||
setting := operation_setting.GetChannelAffinitySetting()
|
||||
if setting == nil || !setting.Enabled {
|
||||
return
|
||||
}
|
||||
if setting.SwitchOnSuccess && c != nil {
|
||||
if successChannelID := c.GetInt("channel_id"); successChannelID > 0 {
|
||||
channelID = successChannelID
|
||||
}
|
||||
}
|
||||
cacheKey, ttlSeconds, ok := getChannelAffinityContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if ttlSeconds <= 0 {
|
||||
ttlSeconds = setting.DefaultTTLSeconds
|
||||
}
|
||||
if ttlSeconds <= 0 {
|
||||
ttlSeconds = 3600
|
||||
}
|
||||
cache := getChannelAffinityCache()
|
||||
if err := cache.SetWithTTL(cacheKey, channelID, time.Duration(ttlSeconds)*time.Second); err != nil {
|
||||
common.SysError(fmt.Sprintf("channel affinity cache set failed: key=%s, err=%v", cacheKey, err))
|
||||
}
|
||||
}
|
||||
@@ -68,6 +68,8 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
adminInfo["local_count_tokens"] = isLocalCountTokens
|
||||
}
|
||||
|
||||
AppendChannelAffinityAdminInfo(ctx, adminInfo)
|
||||
|
||||
other["admin_info"] = adminInfo
|
||||
appendRequestPath(ctx, relayInfo, other)
|
||||
appendRequestConversionChain(relayInfo, other)
|
||||
|
||||
@@ -5,10 +5,10 @@ import (
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
)
|
||||
|
||||
func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, model string) bool {
|
||||
return openaicompat.ShouldChatCompletionsUseResponsesPolicy(policy, channelID, model)
|
||||
func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, channelType int, model string) bool {
|
||||
return openaicompat.ShouldChatCompletionsUseResponsesPolicy(policy, channelID, channelType, model)
|
||||
}
|
||||
|
||||
func ShouldChatCompletionsUseResponsesGlobal(channelID int, model string) bool {
|
||||
return openaicompat.ShouldChatCompletionsUseResponsesGlobal(channelID, model)
|
||||
func ShouldChatCompletionsUseResponsesGlobal(channelID int, channelType int, model string) bool {
|
||||
return openaicompat.ShouldChatCompletionsUseResponsesGlobal(channelID, channelType, model)
|
||||
}
|
||||
|
||||
@@ -2,17 +2,18 @@ package openaicompat
|
||||
|
||||
import "github.com/QuantumNous/new-api/setting/model_setting"
|
||||
|
||||
func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, model string) bool {
|
||||
if !policy.IsChannelEnabled(channelID) {
|
||||
func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, channelType int, model string) bool {
|
||||
if !policy.IsChannelEnabled(channelID, channelType) {
|
||||
return false
|
||||
}
|
||||
return matchAnyRegex(policy.ModelPatterns, model)
|
||||
}
|
||||
|
||||
func ShouldChatCompletionsUseResponsesGlobal(channelID int, model string) bool {
|
||||
func ShouldChatCompletionsUseResponsesGlobal(channelID int, channelType int, model string) bool {
|
||||
return ShouldChatCompletionsUseResponsesPolicy(
|
||||
model_setting.GetGlobalSettings().ChatCompletionsToResponsesPolicy,
|
||||
channelID,
|
||||
channelType,
|
||||
model,
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user