feat: channel affinity (#2669)

* feat: channel affinity

* feat: channel affinity -> model setting

* fix: channel affinity

* feat: channel affinity op

* feat: channel_type setting

* feat: clean

* feat: cache supports both memory and Redis.

* feat: Optimise ui/ux

* feat: Optimise ui/ux

* feat: Optimise codex usage ui/ux

* feat: Optimise ui/ux

* feat: Optimise ui/ux

* feat: Optimise ui/ux

* feat: If the affinitized channel fails and a retry succeeds on another channel, update the affinity to the successful channel
This commit is contained in:
Seefs
2026-01-26 19:57:41 +08:00
committed by GitHub
parent 7da04be52b
commit d7d3a2f763
24 changed files with 2542 additions and 125 deletions

487
service/channel_affinity.go Normal file
View File

@@ -0,0 +1,487 @@
package service
import (
"fmt"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/pkg/cachex"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/gin-gonic/gin"
"github.com/samber/hot"
"github.com/tidwall/gjson"
)
const (
ginKeyChannelAffinityCacheKey = "channel_affinity_cache_key"
ginKeyChannelAffinityTTLSeconds = "channel_affinity_ttl_seconds"
ginKeyChannelAffinityMeta = "channel_affinity_meta"
ginKeyChannelAffinityLogInfo = "channel_affinity_log_info"
channelAffinityCacheNamespace = "new-api:channel_affinity:v1"
)
var (
channelAffinityCacheOnce sync.Once
channelAffinityCache *cachex.HybridCache[int]
channelAffinityRegexCache sync.Map // map[string]*regexp.Regexp
)
type channelAffinityMeta struct {
CacheKey string
TTLSeconds int
RuleName string
KeySourceType string
KeySourceKey string
KeySourcePath string
KeyFingerprint string
UsingGroup string
ModelName string
RequestPath string
}
type ChannelAffinityCacheStats struct {
Enabled bool `json:"enabled"`
Total int `json:"total"`
Unknown int `json:"unknown"`
ByRuleName map[string]int `json:"by_rule_name"`
CacheCapacity int `json:"cache_capacity"`
CacheAlgo string `json:"cache_algo"`
}
func getChannelAffinityCache() *cachex.HybridCache[int] {
channelAffinityCacheOnce.Do(func() {
setting := operation_setting.GetChannelAffinitySetting()
capacity := setting.MaxEntries
if capacity <= 0 {
capacity = 100_000
}
defaultTTLSeconds := setting.DefaultTTLSeconds
if defaultTTLSeconds <= 0 {
defaultTTLSeconds = 3600
}
channelAffinityCache = cachex.NewHybridCache[int](cachex.HybridCacheConfig[int]{
Namespace: cachex.Namespace(channelAffinityCacheNamespace),
Redis: common.RDB,
RedisEnabled: func() bool {
return common.RedisEnabled && common.RDB != nil
},
RedisCodec: cachex.IntCodec{},
Memory: func() *hot.HotCache[string, int] {
return hot.NewHotCache[string, int](hot.LRU, capacity).
WithTTL(time.Duration(defaultTTLSeconds) * time.Second).
WithJanitor().
Build()
},
})
})
return channelAffinityCache
}
func GetChannelAffinityCacheStats() ChannelAffinityCacheStats {
setting := operation_setting.GetChannelAffinitySetting()
if setting == nil {
return ChannelAffinityCacheStats{
Enabled: false,
Total: 0,
Unknown: 0,
ByRuleName: map[string]int{},
}
}
cache := getChannelAffinityCache()
mainCap, _ := cache.Capacity()
mainAlgo, _ := cache.Algorithm()
rules := setting.Rules
ruleByName := make(map[string]operation_setting.ChannelAffinityRule, len(rules))
for _, r := range rules {
name := strings.TrimSpace(r.Name)
if name == "" {
continue
}
if !r.IncludeRuleName {
continue
}
ruleByName[name] = r
}
byRuleName := make(map[string]int, len(ruleByName))
for name := range ruleByName {
byRuleName[name] = 0
}
keys, err := cache.Keys()
if err != nil {
common.SysError(fmt.Sprintf("channel affinity cache list keys failed: err=%v", err))
keys = nil
}
total := len(keys)
unknown := 0
for _, k := range keys {
prefix := channelAffinityCacheNamespace + ":"
if !strings.HasPrefix(k, prefix) {
unknown++
continue
}
rest := strings.TrimPrefix(k, prefix)
parts := strings.Split(rest, ":")
if len(parts) < 2 {
unknown++
continue
}
ruleName := parts[0]
rule, ok := ruleByName[ruleName]
if !ok {
unknown++
continue
}
if rule.IncludeUsingGroup {
if len(parts) < 3 {
unknown++
continue
}
}
byRuleName[ruleName]++
}
return ChannelAffinityCacheStats{
Enabled: setting.Enabled,
Total: total,
Unknown: unknown,
ByRuleName: byRuleName,
CacheCapacity: mainCap,
CacheAlgo: mainAlgo,
}
}
func ClearChannelAffinityCacheAll() int {
cache := getChannelAffinityCache()
keys, err := cache.Keys()
if err != nil {
common.SysError(fmt.Sprintf("channel affinity cache list keys failed: err=%v", err))
keys = nil
}
if len(keys) > 0 {
if _, err := cache.DeleteMany(keys); err != nil {
common.SysError(fmt.Sprintf("channel affinity cache delete many failed: err=%v", err))
}
}
return len(keys)
}
func ClearChannelAffinityCacheByRuleName(ruleName string) (int, error) {
ruleName = strings.TrimSpace(ruleName)
if ruleName == "" {
return 0, fmt.Errorf("rule_name 不能为空")
}
setting := operation_setting.GetChannelAffinitySetting()
if setting == nil {
return 0, fmt.Errorf("channel_affinity_setting 未初始化")
}
var matchedRule *operation_setting.ChannelAffinityRule
for i := range setting.Rules {
r := &setting.Rules[i]
if strings.TrimSpace(r.Name) != ruleName {
continue
}
matchedRule = r
break
}
if matchedRule == nil {
return 0, fmt.Errorf("未知规则名称")
}
if !matchedRule.IncludeRuleName {
return 0, fmt.Errorf("该规则未启用 include_rule_name无法按规则清空缓存")
}
cache := getChannelAffinityCache()
deleted, err := cache.DeleteByPrefix(ruleName)
if err != nil {
return 0, err
}
return deleted, nil
}
func matchAnyRegexCached(patterns []string, s string) bool {
if len(patterns) == 0 || s == "" {
return false
}
for _, pattern := range patterns {
if pattern == "" {
continue
}
re, ok := channelAffinityRegexCache.Load(pattern)
if !ok {
compiled, err := regexp.Compile(pattern)
if err != nil {
continue
}
re = compiled
channelAffinityRegexCache.Store(pattern, re)
}
if re.(*regexp.Regexp).MatchString(s) {
return true
}
}
return false
}
func matchAnyIncludeFold(patterns []string, s string) bool {
if len(patterns) == 0 || s == "" {
return false
}
sLower := strings.ToLower(s)
for _, p := range patterns {
p = strings.TrimSpace(p)
if p == "" {
continue
}
if strings.Contains(sLower, strings.ToLower(p)) {
return true
}
}
return false
}
func extractChannelAffinityValue(c *gin.Context, src operation_setting.ChannelAffinityKeySource) string {
switch src.Type {
case "context_int":
if src.Key == "" {
return ""
}
v := c.GetInt(src.Key)
if v <= 0 {
return ""
}
return strconv.Itoa(v)
case "context_string":
if src.Key == "" {
return ""
}
return strings.TrimSpace(c.GetString(src.Key))
case "gjson":
if src.Path == "" {
return ""
}
body, err := common.GetRequestBody(c)
if err != nil || len(body) == 0 {
return ""
}
res := gjson.GetBytes(body, src.Path)
if !res.Exists() {
return ""
}
switch res.Type {
case gjson.String, gjson.Number, gjson.True, gjson.False:
return strings.TrimSpace(res.String())
default:
return strings.TrimSpace(res.Raw)
}
default:
return ""
}
}
func buildChannelAffinityCacheKeySuffix(rule operation_setting.ChannelAffinityRule, usingGroup string, affinityValue string) string {
parts := make([]string, 0, 3)
if rule.IncludeRuleName && rule.Name != "" {
parts = append(parts, rule.Name)
}
if rule.IncludeUsingGroup && usingGroup != "" {
parts = append(parts, usingGroup)
}
parts = append(parts, affinityValue)
return strings.Join(parts, ":")
}
func setChannelAffinityContext(c *gin.Context, meta channelAffinityMeta) {
c.Set(ginKeyChannelAffinityCacheKey, meta.CacheKey)
c.Set(ginKeyChannelAffinityTTLSeconds, meta.TTLSeconds)
c.Set(ginKeyChannelAffinityMeta, meta)
}
func getChannelAffinityContext(c *gin.Context) (string, int, bool) {
keyAny, ok := c.Get(ginKeyChannelAffinityCacheKey)
if !ok {
return "", 0, false
}
key, ok := keyAny.(string)
if !ok || key == "" {
return "", 0, false
}
ttlAny, ok := c.Get(ginKeyChannelAffinityTTLSeconds)
if !ok {
return key, 0, true
}
ttlSeconds, _ := ttlAny.(int)
return key, ttlSeconds, true
}
func getChannelAffinityMeta(c *gin.Context) (channelAffinityMeta, bool) {
anyMeta, ok := c.Get(ginKeyChannelAffinityMeta)
if !ok {
return channelAffinityMeta{}, false
}
meta, ok := anyMeta.(channelAffinityMeta)
if !ok {
return channelAffinityMeta{}, false
}
return meta, true
}
func affinityFingerprint(s string) string {
if s == "" {
return ""
}
hex := common.Sha1([]byte(s))
if len(hex) >= 8 {
return hex[:8]
}
return hex
}
func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup string) (int, bool) {
setting := operation_setting.GetChannelAffinitySetting()
if setting == nil || !setting.Enabled {
return 0, false
}
path := ""
if c != nil && c.Request != nil && c.Request.URL != nil {
path = c.Request.URL.Path
}
userAgent := ""
if c != nil && c.Request != nil {
userAgent = c.Request.UserAgent()
}
for _, rule := range setting.Rules {
if !matchAnyRegexCached(rule.ModelRegex, modelName) {
continue
}
if len(rule.PathRegex) > 0 && !matchAnyRegexCached(rule.PathRegex, path) {
continue
}
if len(rule.UserAgentInclude) > 0 && !matchAnyIncludeFold(rule.UserAgentInclude, userAgent) {
continue
}
var affinityValue string
var usedSource operation_setting.ChannelAffinityKeySource
for _, src := range rule.KeySources {
affinityValue = extractChannelAffinityValue(c, src)
if affinityValue != "" {
usedSource = src
break
}
}
if affinityValue == "" {
continue
}
if rule.ValueRegex != "" && !matchAnyRegexCached([]string{rule.ValueRegex}, affinityValue) {
continue
}
ttlSeconds := rule.TTLSeconds
if ttlSeconds <= 0 {
ttlSeconds = setting.DefaultTTLSeconds
}
cacheKeySuffix := buildChannelAffinityCacheKeySuffix(rule, usingGroup, affinityValue)
cacheKeyFull := channelAffinityCacheNamespace + ":" + cacheKeySuffix
setChannelAffinityContext(c, channelAffinityMeta{
CacheKey: cacheKeyFull,
TTLSeconds: ttlSeconds,
RuleName: rule.Name,
KeySourceType: strings.TrimSpace(usedSource.Type),
KeySourceKey: strings.TrimSpace(usedSource.Key),
KeySourcePath: strings.TrimSpace(usedSource.Path),
KeyFingerprint: affinityFingerprint(affinityValue),
UsingGroup: usingGroup,
ModelName: modelName,
RequestPath: path,
})
cache := getChannelAffinityCache()
channelID, found, err := cache.Get(cacheKeySuffix)
if err != nil {
common.SysError(fmt.Sprintf("channel affinity cache get failed: key=%s, err=%v", cacheKeyFull, err))
return 0, false
}
if found {
return channelID, true
}
return 0, false
}
return 0, false
}
func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) {
if c == nil || channelID <= 0 {
return
}
meta, ok := getChannelAffinityMeta(c)
if !ok {
return
}
info := map[string]interface{}{
"reason": meta.RuleName,
"rule_name": meta.RuleName,
"using_group": meta.UsingGroup,
"selected_group": selectedGroup,
"model": meta.ModelName,
"request_path": meta.RequestPath,
"channel_id": channelID,
"key_source": meta.KeySourceType,
"key_key": meta.KeySourceKey,
"key_path": meta.KeySourcePath,
"key_fp": meta.KeyFingerprint,
}
c.Set(ginKeyChannelAffinityLogInfo, info)
}
func AppendChannelAffinityAdminInfo(c *gin.Context, adminInfo map[string]interface{}) {
if c == nil || adminInfo == nil {
return
}
anyInfo, ok := c.Get(ginKeyChannelAffinityLogInfo)
if !ok || anyInfo == nil {
return
}
adminInfo["channel_affinity"] = anyInfo
}
func RecordChannelAffinity(c *gin.Context, channelID int) {
if channelID <= 0 {
return
}
setting := operation_setting.GetChannelAffinitySetting()
if setting == nil || !setting.Enabled {
return
}
if setting.SwitchOnSuccess && c != nil {
if successChannelID := c.GetInt("channel_id"); successChannelID > 0 {
channelID = successChannelID
}
}
cacheKey, ttlSeconds, ok := getChannelAffinityContext(c)
if !ok {
return
}
if ttlSeconds <= 0 {
ttlSeconds = setting.DefaultTTLSeconds
}
if ttlSeconds <= 0 {
ttlSeconds = 3600
}
cache := getChannelAffinityCache()
if err := cache.SetWithTTL(cacheKey, channelID, time.Duration(ttlSeconds)*time.Second); err != nil {
common.SysError(fmt.Sprintf("channel affinity cache set failed: key=%s, err=%v", cacheKey, err))
}
}

View File

@@ -68,6 +68,8 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
adminInfo["local_count_tokens"] = isLocalCountTokens
}
AppendChannelAffinityAdminInfo(ctx, adminInfo)
other["admin_info"] = adminInfo
appendRequestPath(ctx, relayInfo, other)
appendRequestConversionChain(relayInfo, other)

View File

@@ -5,10 +5,10 @@ import (
"github.com/QuantumNous/new-api/setting/model_setting"
)
func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, model string) bool {
return openaicompat.ShouldChatCompletionsUseResponsesPolicy(policy, channelID, model)
func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, channelType int, model string) bool {
return openaicompat.ShouldChatCompletionsUseResponsesPolicy(policy, channelID, channelType, model)
}
func ShouldChatCompletionsUseResponsesGlobal(channelID int, model string) bool {
return openaicompat.ShouldChatCompletionsUseResponsesGlobal(channelID, model)
func ShouldChatCompletionsUseResponsesGlobal(channelID int, channelType int, model string) bool {
return openaicompat.ShouldChatCompletionsUseResponsesGlobal(channelID, channelType, model)
}

View File

@@ -2,17 +2,18 @@ package openaicompat
import "github.com/QuantumNous/new-api/setting/model_setting"
func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, model string) bool {
if !policy.IsChannelEnabled(channelID) {
func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, channelType int, model string) bool {
if !policy.IsChannelEnabled(channelID, channelType) {
return false
}
return matchAnyRegex(policy.ModelPatterns, model)
}
func ShouldChatCompletionsUseResponsesGlobal(channelID int, model string) bool {
func ShouldChatCompletionsUseResponsesGlobal(channelID int, channelType int, model string) bool {
return ShouldChatCompletionsUseResponsesPolicy(
model_setting.GetGlobalSettings().ChatCompletionsToResponsesPolicy,
channelID,
channelType,
model,
)
}