Files
new-api/service/channel_affinity.go
Seefs d7d3a2f763 feat: channel affinity (#2669)
* feat: channel affinity

* feat: channel affinity -> model setting

* fix: channel affinity

* feat: channel affinity op

* feat: channel_type setting

* feat: clean

* feat: cache supports both memory and Redis.

* feat: Optimise ui/ux

* feat: Optimise ui/ux

* feat: Optimise codex usage ui/ux

* feat: Optimise ui/ux

* feat: Optimise ui/ux

* feat: Optimise ui/ux

* feat: If the affinitized channel fails and a retry succeeds on another channel, update the affinity to the successful channel
2026-01-26 19:57:41 +08:00

488 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"fmt"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/pkg/cachex"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/gin-gonic/gin"
"github.com/samber/hot"
"github.com/tidwall/gjson"
)
// Gin context keys used to carry per-request channel affinity state between
// the functions in this file, plus the namespace used for affinity cache keys.
const (
// ginKeyChannelAffinityCacheKey holds the full (namespace-prefixed) cache key string.
ginKeyChannelAffinityCacheKey = "channel_affinity_cache_key"
// ginKeyChannelAffinityTTLSeconds holds the TTL (int, seconds) chosen for the matched rule.
ginKeyChannelAffinityTTLSeconds = "channel_affinity_ttl_seconds"
// ginKeyChannelAffinityMeta holds the channelAffinityMeta value for the matched rule.
ginKeyChannelAffinityMeta = "channel_affinity_meta"
// ginKeyChannelAffinityLogInfo holds the log map written by MarkChannelAffinityUsed
// and read by AppendChannelAffinityAdminInfo.
ginKeyChannelAffinityLogInfo = "channel_affinity_log_info"
// channelAffinityCacheNamespace is passed to the cache as its namespace and is
// also used, followed by ":", as the prefix of full cache keys.
channelAffinityCacheNamespace = "new-api:channel_affinity:v1"
)
// Package-level state for the channel affinity feature.
var (
// channelAffinityCacheOnce guards the one-time construction of channelAffinityCache.
channelAffinityCacheOnce sync.Once
// channelAffinityCache maps affinity cache keys to channel IDs; it is built
// lazily by getChannelAffinityCache.
channelAffinityCache *cachex.HybridCache[int]
// channelAffinityRegexCache memoizes successfully compiled patterns, keyed by
// the pattern source string.
channelAffinityRegexCache sync.Map // map[string]*regexp.Regexp
)
// channelAffinityMeta describes the affinity rule match made for one request.
// It is stored in the gin context by setChannelAffinityContext and later used
// to build log information.
type channelAffinityMeta struct {
// CacheKey is the full cache key (namespace + ":" + key suffix).
CacheKey string
// TTLSeconds is the rule TTL, or the setting's default TTL when the rule has none.
TTLSeconds int
// RuleName is the name of the matched rule.
RuleName string
// KeySourceType, KeySourceKey and KeySourcePath are the trimmed Type, Key and
// Path of the key source that produced the affinity value.
KeySourceType string
KeySourceKey string
KeySourcePath string
// KeyFingerprint is a short fingerprint of the affinity value (see
// affinityFingerprint); the raw value itself is not stored here.
KeyFingerprint string
// UsingGroup is the group name passed to GetPreferredChannelByAffinity.
UsingGroup string
// ModelName is the model name passed to GetPreferredChannelByAffinity.
ModelName string
// RequestPath is the request URL path, or "" when the request had none.
RequestPath string
}
// ChannelAffinityCacheStats is the summary returned by
// GetChannelAffinityCacheStats.
type ChannelAffinityCacheStats struct {
// Enabled mirrors the Enabled flag of the channel affinity setting.
Enabled bool `json:"enabled"`
// Total is the number of keys listed from the cache.
Total int `json:"total"`
// Unknown counts keys that could not be attributed to a configured rule.
Unknown int `json:"unknown"`
// ByRuleName counts keys per rule; only rules with a non-empty name and
// IncludeRuleName set appear here.
ByRuleName map[string]int `json:"by_rule_name"`
// CacheCapacity and CacheAlgo are the values reported by the cache's
// Capacity() and Algorithm() methods.
CacheCapacity int `json:"cache_capacity"`
CacheAlgo string `json:"cache_algo"`
}
// getChannelAffinityCache returns the shared channel affinity cache, building
// it on first use.
//
// The in-memory capacity and default TTL are read from the channel affinity
// setting; non-positive values, or a missing setting, fall back to 100000
// entries and 3600 seconds. Construction is guarded by sync.Once, so the
// values read on the first call stay in effect afterwards.
func getChannelAffinityCache() *cachex.HybridCache[int] {
	channelAffinityCacheOnce.Do(func() {
		capacity := 0
		defaultTTLSeconds := 0
		// The setting can be nil (other functions in this file guard against
		// that). Dereferencing nil here would panic inside Once.Do, which marks
		// the Once as done and leaves the cache nil for every later caller, so
		// fall back to the defaults instead.
		if setting := operation_setting.GetChannelAffinitySetting(); setting != nil {
			capacity = setting.MaxEntries
			defaultTTLSeconds = setting.DefaultTTLSeconds
		}
		if capacity <= 0 {
			capacity = 100_000
		}
		if defaultTTLSeconds <= 0 {
			defaultTTLSeconds = 3600
		}
		channelAffinityCache = cachex.NewHybridCache[int](cachex.HybridCacheConfig[int]{
			Namespace: cachex.Namespace(channelAffinityCacheNamespace),
			Redis:     common.RDB,
			// Evaluated by the cache on demand rather than once at construction.
			RedisEnabled: func() bool {
				return common.RedisEnabled && common.RDB != nil
			},
			RedisCodec: cachex.IntCodec{},
			Memory: func() *hot.HotCache[string, int] {
				return hot.NewHotCache[string, int](hot.LRU, capacity).
					WithTTL(time.Duration(defaultTTLSeconds) * time.Second).
					WithJanitor().
					Build()
			},
		})
	})
	return channelAffinityCache
}
// GetChannelAffinityCacheStats summarizes the current contents of the channel
// affinity cache.
//
// Every listed key is either attributed to a configured rule (by the first
// segment after the namespace prefix) or counted as unknown. Only rules with a
// non-empty trimmed name and IncludeRuleName set can be attributed. When the
// setting is nil, a disabled, empty result is returned.
func GetChannelAffinityCacheStats() ChannelAffinityCacheStats {
	setting := operation_setting.GetChannelAffinitySetting()
	if setting == nil {
		return ChannelAffinityCacheStats{ByRuleName: map[string]int{}}
	}

	cache := getChannelAffinityCache()
	capacity, _ := cache.Capacity()
	algorithm, _ := cache.Algorithm()

	// groupScoped records, per attributable rule name, whether its keys also
	// carry a using-group segment; counts starts every such rule at zero so it
	// appears in the result even without cache entries.
	groupScoped := make(map[string]bool, len(setting.Rules))
	counts := make(map[string]int, len(setting.Rules))
	for _, rule := range setting.Rules {
		name := strings.TrimSpace(rule.Name)
		if name == "" || !rule.IncludeRuleName {
			continue
		}
		groupScoped[name] = rule.IncludeUsingGroup
		counts[name] = 0
	}

	keys, err := cache.Keys()
	if err != nil {
		common.SysError(fmt.Sprintf("channel affinity cache list keys failed: err=%v", err))
		keys = nil
	}

	prefix := channelAffinityCacheNamespace + ":"
	unknown := 0
	for _, key := range keys {
		if !strings.HasPrefix(key, prefix) {
			unknown++
			continue
		}
		segments := strings.Split(key[len(prefix):], ":")
		if len(segments) < 2 {
			unknown++
			continue
		}
		needsGroup, known := groupScoped[segments[0]]
		switch {
		case !known, needsGroup && len(segments) < 3:
			// Not a configured rule, or too short to contain rule, group and value.
			unknown++
		default:
			counts[segments[0]]++
		}
	}

	return ChannelAffinityCacheStats{
		Enabled:       setting.Enabled,
		Total:         len(keys),
		Unknown:       unknown,
		ByRuleName:    counts,
		CacheCapacity: capacity,
		CacheAlgo:     algorithm,
	}
}
// ClearChannelAffinityCacheAll removes every key currently listed by the
// channel affinity cache.
//
// It returns the number of keys that were listed for deletion. Listing and
// deletion errors are logged, not returned; a deletion error does not change
// the returned count.
func ClearChannelAffinityCacheAll() int {
	cache := getChannelAffinityCache()

	keys, listErr := cache.Keys()
	if listErr != nil {
		common.SysError(fmt.Sprintf("channel affinity cache list keys failed: err=%v", listErr))
		return 0
	}
	if len(keys) == 0 {
		return 0
	}

	_, delErr := cache.DeleteMany(keys)
	if delErr != nil {
		common.SysError(fmt.Sprintf("channel affinity cache delete many failed: err=%v", delErr))
	}
	return len(keys)
}
// ClearChannelAffinityCacheByRuleName deletes the cache entries that belong to
// the rule with the given name and returns how many were deleted.
//
// An error is returned when the name is empty, the setting is nil, no rule has
// that (trimmed) name, the rule does not set IncludeRuleName, or the cache
// reports a failure.
func ClearChannelAffinityCacheByRuleName(ruleName string) (int, error) {
	ruleName = strings.TrimSpace(ruleName)
	if ruleName == "" {
		return 0, fmt.Errorf("rule_name 不能为空")
	}

	setting := operation_setting.GetChannelAffinitySetting()
	if setting == nil {
		return 0, fmt.Errorf("channel_affinity_setting 未初始化")
	}

	// Take the first rule whose trimmed name matches.
	var target *operation_setting.ChannelAffinityRule
	for idx := range setting.Rules {
		if candidate := &setting.Rules[idx]; strings.TrimSpace(candidate.Name) == ruleName {
			target = candidate
			break
		}
	}

	switch {
	case target == nil:
		return 0, fmt.Errorf("未知规则名称")
	case !target.IncludeRuleName:
		// Without the rule name in the key there is no prefix to delete by.
		return 0, fmt.Errorf("该规则未启用 include_rule_name无法按规则清空缓存")
	}

	// NOTE(review): the bare rule name is used as the prefix. Whether
	// DeleteByPrefix matches on a ":" boundary (so rule "a" does not also clear
	// rule "ab") depends on cachex — confirm.
	removed, err := getChannelAffinityCache().DeleteByPrefix(ruleName)
	if err != nil {
		return 0, err
	}
	return removed, nil
}
// matchAnyRegexCached reports whether s matches at least one of patterns.
//
// Compiled expressions are memoized in channelAffinityRegexCache. Empty
// patterns and patterns that fail to compile are skipped; failures are not
// cached. An empty s or an empty pattern list never matches.
func matchAnyRegexCached(patterns []string, s string) bool {
	if s == "" || len(patterns) == 0 {
		return false
	}
	for _, expr := range patterns {
		if expr == "" {
			continue
		}
		var re *regexp.Regexp
		if cached, hit := channelAffinityRegexCache.Load(expr); hit {
			re = cached.(*regexp.Regexp)
		} else {
			fresh, err := regexp.Compile(expr)
			if err != nil {
				continue
			}
			channelAffinityRegexCache.Store(expr, fresh)
			re = fresh
		}
		if re.MatchString(s) {
			return true
		}
	}
	return false
}
// matchAnyIncludeFold reports whether s contains any of patterns as a
// substring, ignoring case.
//
// Patterns are trimmed first and blank ones are skipped. An empty s or an
// empty pattern list never matches.
func matchAnyIncludeFold(patterns []string, s string) bool {
	if s == "" || len(patterns) == 0 {
		return false
	}
	haystack := strings.ToLower(s)
	for _, raw := range patterns {
		needle := strings.TrimSpace(raw)
		if needle != "" && strings.Contains(haystack, strings.ToLower(needle)) {
			return true
		}
	}
	return false
}
// extractChannelAffinityValue reads the affinity value described by src from
// the request and returns it as a string, or "" when nothing usable is found.
//
// Supported source types:
//   - "context_int": positive int stored in the gin context under src.Key
//   - "context_string": trimmed string stored in the gin context under src.Key
//   - "gjson": value at src.Path in the request body; scalars use their string
//     form, anything else its raw JSON text, both trimmed
//
// Any other type yields "".
func extractChannelAffinityValue(c *gin.Context, src operation_setting.ChannelAffinityKeySource) string {
	switch src.Type {
	case "context_int":
		if src.Key == "" {
			return ""
		}
		if n := c.GetInt(src.Key); n > 0 {
			return strconv.Itoa(n)
		}
		return ""

	case "context_string":
		if src.Key == "" {
			return ""
		}
		return strings.TrimSpace(c.GetString(src.Key))

	case "gjson":
		if src.Path == "" {
			return ""
		}
		payload, err := common.GetRequestBody(c)
		if err != nil || len(payload) == 0 {
			return ""
		}
		found := gjson.GetBytes(payload, src.Path)
		if !found.Exists() {
			return ""
		}
		// Default to the raw JSON text; scalar results use String() instead.
		text := found.Raw
		switch found.Type {
		case gjson.String, gjson.Number, gjson.True, gjson.False:
			text = found.String()
		}
		return strings.TrimSpace(text)
	}
	return ""
}
// buildChannelAffinityCacheKeySuffix builds the part of the cache key that
// follows the namespace: "[rule name:][using group:]affinity value".
//
// The rule name is included only when rule.IncludeRuleName is set and the name
// is non-empty; the group only when rule.IncludeUsingGroup is set and
// usingGroup is non-empty.
func buildChannelAffinityCacheKeySuffix(rule operation_setting.ChannelAffinityRule, usingGroup string, affinityValue string) string {
	var key strings.Builder
	if rule.IncludeRuleName && rule.Name != "" {
		key.WriteString(rule.Name)
		key.WriteByte(':')
	}
	if rule.IncludeUsingGroup && usingGroup != "" {
		key.WriteString(usingGroup)
		key.WriteByte(':')
	}
	key.WriteString(affinityValue)
	return key.String()
}
// setChannelAffinityContext stores the matched rule's metadata in the gin
// context: the whole meta value plus its cache key and TTL under their own
// keys, so they can be read back without the meta.
func setChannelAffinityContext(c *gin.Context, meta channelAffinityMeta) {
	c.Set(ginKeyChannelAffinityMeta, meta)
	c.Set(ginKeyChannelAffinityCacheKey, meta.CacheKey)
	c.Set(ginKeyChannelAffinityTTLSeconds, meta.TTLSeconds)
}
// getChannelAffinityContext reads back the cache key and TTL (seconds) stored
// in the gin context.
//
// The bool is false when no non-empty string key is present. A missing or
// non-int TTL is reported as 0.
func getChannelAffinityContext(c *gin.Context) (string, int, bool) {
	// A missing entry yields a nil value, which fails the string assertion and
	// leaves cacheKey empty, so one check covers absent, mistyped and empty.
	rawKey, _ := c.Get(ginKeyChannelAffinityCacheKey)
	cacheKey, _ := rawKey.(string)
	if cacheKey == "" {
		return "", 0, false
	}

	ttl := 0
	if rawTTL, hasTTL := c.Get(ginKeyChannelAffinityTTLSeconds); hasTTL {
		ttl, _ = rawTTL.(int)
	}
	return cacheKey, ttl, true
}
// getChannelAffinityMeta returns the channelAffinityMeta stored in the gin
// context. The bool is false when nothing is stored or the stored value has a
// different type.
func getChannelAffinityMeta(c *gin.Context) (channelAffinityMeta, bool) {
	if stored, found := c.Get(ginKeyChannelAffinityMeta); found {
		if meta, isMeta := stored.(channelAffinityMeta); isMeta {
			return meta, true
		}
	}
	return channelAffinityMeta{}, false
}
// affinityFingerprint returns a short identifier for s: the first 8 characters
// of common.Sha1(s), or the whole result if it is shorter. An empty s yields "".
func affinityFingerprint(s string) string {
	const fingerprintLen = 8
	if s == "" {
		return ""
	}
	digest := common.Sha1([]byte(s))
	if len(digest) < fingerprintLen {
		return digest
	}
	return digest[:fingerprintLen]
}
// GetPreferredChannelByAffinity looks up the channel previously recorded for
// this request's affinity key.
//
// Rules are tried in order. A rule applies when the model name matches
// ModelRegex, the path matches PathRegex (if set), the User-Agent contains one
// of UserAgentInclude (if set), one of its key sources yields a non-empty
// value, and that value matches ValueRegex (if set). The first applicable rule
// decides the result: its cache key, TTL and metadata are stored in the gin
// context, and the cache is consulted once.
//
// It returns (channelID, true) on a cache hit. It returns (0, false) when
// affinity is disabled, c is nil, no rule applies, the key is not cached, or
// the cache lookup fails (the failure is logged).
func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup string) (int, bool) {
	setting := operation_setting.GetChannelAffinitySetting()
	if setting == nil || !setting.Enabled {
		return 0, false
	}
	// Key extraction and context storage below call methods on c, so a nil
	// context cannot take part in affinity; bail out instead of panicking.
	if c == nil {
		return 0, false
	}

	path := ""
	userAgent := ""
	if c.Request != nil {
		if c.Request.URL != nil {
			path = c.Request.URL.Path
		}
		userAgent = c.Request.UserAgent()
	}

	for _, rule := range setting.Rules {
		if !matchAnyRegexCached(rule.ModelRegex, modelName) {
			continue
		}
		if len(rule.PathRegex) > 0 && !matchAnyRegexCached(rule.PathRegex, path) {
			continue
		}
		if len(rule.UserAgentInclude) > 0 && !matchAnyIncludeFold(rule.UserAgentInclude, userAgent) {
			continue
		}

		// Use the first key source that produces a value.
		var affinityValue string
		var usedSource operation_setting.ChannelAffinityKeySource
		for _, src := range rule.KeySources {
			affinityValue = extractChannelAffinityValue(c, src)
			if affinityValue != "" {
				usedSource = src
				break
			}
		}
		if affinityValue == "" {
			continue
		}
		if rule.ValueRegex != "" && !matchAnyRegexCached([]string{rule.ValueRegex}, affinityValue) {
			continue
		}

		ttlSeconds := rule.TTLSeconds
		if ttlSeconds <= 0 {
			ttlSeconds = setting.DefaultTTLSeconds
		}

		cacheKeySuffix := buildChannelAffinityCacheKeySuffix(rule, usingGroup, affinityValue)
		cacheKeyFull := channelAffinityCacheNamespace + ":" + cacheKeySuffix

		// Store the match before the lookup so the key is available in the
		// context even on a cache miss or lookup error.
		setChannelAffinityContext(c, channelAffinityMeta{
			CacheKey:       cacheKeyFull,
			TTLSeconds:     ttlSeconds,
			RuleName:       rule.Name,
			KeySourceType:  strings.TrimSpace(usedSource.Type),
			KeySourceKey:   strings.TrimSpace(usedSource.Key),
			KeySourcePath:  strings.TrimSpace(usedSource.Path),
			KeyFingerprint: affinityFingerprint(affinityValue),
			UsingGroup:     usingGroup,
			ModelName:      modelName,
			RequestPath:    path,
		})

		cache := getChannelAffinityCache()
		channelID, found, err := cache.Get(cacheKeySuffix)
		if err != nil {
			common.SysError(fmt.Sprintf("channel affinity cache get failed: key=%s, err=%v", cacheKeyFull, err))
			return 0, false
		}
		if found {
			return channelID, true
		}
		return 0, false
	}
	return 0, false
}
// MarkChannelAffinityUsed stores a log record in the gin context describing
// the channel chosen and the affinity rule match behind it.
//
// It does nothing when c is nil, channelID is not positive, or no affinity
// metadata is stored in the context.
func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) {
	if c == nil || channelID <= 0 {
		return
	}
	if meta, ok := getChannelAffinityMeta(c); ok {
		c.Set(ginKeyChannelAffinityLogInfo, map[string]interface{}{
			"reason":         meta.RuleName,
			"rule_name":      meta.RuleName,
			"using_group":    meta.UsingGroup,
			"selected_group": selectedGroup,
			"model":          meta.ModelName,
			"request_path":   meta.RequestPath,
			"channel_id":     channelID,
			"key_source":     meta.KeySourceType,
			"key_key":        meta.KeySourceKey,
			"key_path":       meta.KeySourcePath,
			"key_fp":         meta.KeyFingerprint,
		})
	}
}
// AppendChannelAffinityAdminInfo copies the affinity log record from the gin
// context, if one is stored, into adminInfo under the "channel_affinity" key.
// A nil context, nil map, or absent/nil record leaves adminInfo untouched.
func AppendChannelAffinityAdminInfo(c *gin.Context, adminInfo map[string]interface{}) {
	if c == nil || adminInfo == nil {
		return
	}
	if logInfo, exists := c.Get(ginKeyChannelAffinityLogInfo); exists && logInfo != nil {
		adminInfo["channel_affinity"] = logInfo
	}
}
// RecordChannelAffinity writes channelID to the affinity cache under the key
// that GetPreferredChannelByAffinity stored in the gin context.
//
// It does nothing when c is nil, channelID is not positive, affinity is
// disabled, or the context holds no affinity key. When the setting's
// SwitchOnSuccess flag is set and the context holds a positive "channel_id",
// that value is recorded instead of the channelID argument.
//
// The TTL is taken from the context, then the setting's default, then 3600
// seconds. Cache write errors are logged, not returned.
func RecordChannelAffinity(c *gin.Context, channelID int) {
	// The context is read unconditionally below, so a nil context must be
	// rejected up front rather than only inside the SwitchOnSuccess branch.
	if c == nil || channelID <= 0 {
		return
	}
	setting := operation_setting.GetChannelAffinitySetting()
	if setting == nil || !setting.Enabled {
		return
	}
	if setting.SwitchOnSuccess {
		// presumably "channel_id" is the channel that finally served the
		// request (e.g. after a retry) — verify against the code that sets it.
		if successChannelID := c.GetInt("channel_id"); successChannelID > 0 {
			channelID = successChannelID
		}
	}
	cacheKey, ttlSeconds, ok := getChannelAffinityContext(c)
	if !ok {
		return
	}
	if ttlSeconds <= 0 {
		ttlSeconds = setting.DefaultTTLSeconds
	}
	if ttlSeconds <= 0 {
		ttlSeconds = 3600
	}
	cache := getChannelAffinityCache()
	if err := cache.SetWithTTL(cacheKey, channelID, time.Duration(ttlSeconds)*time.Second); err != nil {
		common.SysError(fmt.Sprintf("channel affinity cache set failed: key=%s, err=%v", cacheKey, err))
	}
}