diff --git a/controller/channel_affinity_cache.go b/controller/channel_affinity_cache.go new file mode 100644 index 000000000..bb5cab20a --- /dev/null +++ b/controller/channel_affinity_cache.go @@ -0,0 +1,60 @@ +package controller + +import ( + "net/http" + "strings" + + "github.com/QuantumNous/new-api/service" + "github.com/gin-gonic/gin" +) + +func GetChannelAffinityCacheStats(c *gin.Context) { + stats := service.GetChannelAffinityCacheStats() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": stats, + }) +} + +func ClearChannelAffinityCache(c *gin.Context) { + all := strings.TrimSpace(c.Query("all")) + ruleName := strings.TrimSpace(c.Query("rule_name")) + + if all == "true" { + deleted := service.ClearChannelAffinityCacheAll() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": gin.H{ + "deleted": deleted, + }, + }) + return + } + + if ruleName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "缺少参数:rule_name,或使用 all=true 清空全部", + }) + return + } + + deleted, err := service.ClearChannelAffinityCacheByRuleName(ruleName) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": gin.H{ + "deleted": deleted, + }, + }) +} diff --git a/controller/relay.go b/controller/relay.go index 4fba947f7..906a6969b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -362,6 +362,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t adminInfo["is_multi_key"] = true adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex) } + service.AppendChannelAffinityAdminInfo(c, adminInfo) other["admin_info"] = adminInfo model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveErrorWithStatusCode(), tokenId, 0, false, userGroup, other) } diff --git a/go.mod b/go.mod index f4f133973..0ea30e998 100644 --- a/go.mod +++ b/go.mod @@ -55,16 +55,18 @@ require ( ) require ( + github.com/DmitriyVTitov/size v1.5.0 // indirect github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect + github.com/beorn7/perks v1.0.1 // indirect github.com/boombuler/barcode v1.1.0 // indirect github.com/bytedance/sonic v1.14.1 // indirect github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.5 // indirect github.com/dustin/go-humanize v1.0.1 // indirect @@ -94,7 +96,7 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.17.8 // indirect + github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -103,10 +105,17 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pelletier/go-toml/v2 v2.2.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/prometheus/client_golang v1.22.0 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.62.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/samber/go-singleflightx v0.3.2 // indirect + github.com/samber/hot v0.11.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect @@ -120,7 +129,7 @@ require ( golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.31.0 // indirect - google.golang.org/protobuf v1.34.2 // indirect + google.golang.org/protobuf v1.36.5 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/libc v1.66.10 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/go.sum b/go.sum index 697a313d8..7e9f3bd70 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A= github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U= +github.com/DmitriyVTitov/size v1.5.0 h1:/PzqxYrOyOUX1BXj6J9OuVRVGe+66VL4D9FlUaW515g= +github.com/DmitriyVTitov/size v1.5.0/go.mod h1:le6rNI4CoLQV1b9gzp1+3d7hMAD/uu2QcJ+aYbNgiU0= github.com/abema/go-mp4 v1.4.1 h1:YoS4VRqd+pAmddRPLFf8vMk74kuGl6ULSjzhsIqwr6M= github.com/abema/go-mp4 v1.4.1/go.mod h1:vPl9t5ZK7K0x68jh12/+ECWBCXoWuIDtNgPtU2f04ws= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= @@ -22,6 +24,8 @@ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 h1:JzidOz4Hcn2RbP5fv github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0/go.mod h1:9A4/PJYlWjvjEzzoOLGQjkLt4bYK9fRWi7uz1GSsAcA= github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo= github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= @@ -40,6 +44,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= @@ -110,6 +116,7 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -165,6 +172,8 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -200,6 +209,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= @@ -218,13 +229,27 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= +github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= +github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= +github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/samber/go-singleflightx v0.3.2 h1:jXbUU0fvis8Fdv4HGONboX5WdEZcYLoBEcKiE+ITCyQ= +github.com/samber/go-singleflightx v0.3.2/go.mod h1:X2BR+oheHIYc73PvxRMlcASg6KYYTQyUYpdVU7t/ux4= +github.com/samber/hot v0.11.0 h1:JhV9hk8SmZIqB0To8OyCzPubvszkuoSXWx/7FCEGO+Q= +github.com/samber/hot v0.11.0/go.mod h1:NB9v5U4NfDx7jmlrP+zHuqCuLUsywgAtCH7XOAkOxAg= github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw= github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= @@ -332,6 +357,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= +google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/middleware/distributor.go b/middleware/distributor.go index 95fa64a30..054763c9e 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -97,35 +97,64 @@ func Distribute() func(c *gin.Context) { common.SetContextKey(c, constant.ContextKeyUsingGroup, usingGroup) } } - channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(&service.RetryParam{ - Ctx: c, - ModelName: modelRequest.Model, - TokenGroup: usingGroup, - Retry: common.GetPointer(0), - }) - if err != nil { - showGroup := usingGroup - if usingGroup == "auto" { - showGroup = fmt.Sprintf("auto(%s)", selectGroup) + + if preferredChannelID, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found { + preferred, err := model.CacheGetChannel(preferredChannelID) + if err == nil && preferred != nil && preferred.Status == common.ChannelStatusEnabled { + if usingGroup == "auto" { + userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup) + autoGroups := service.GetUserAutoGroup(userGroup) + for _, g := range autoGroups { + if model.IsChannelEnabledForGroupModel(g, modelRequest.Model, preferred.Id) { + selectGroup = g + common.SetContextKey(c, constant.ContextKeyAutoGroup, g) + channel = preferred + service.MarkChannelAffinityUsed(c, g, preferred.Id) + break + } + } + } else if model.IsChannelEnabledForGroupModel(usingGroup, modelRequest.Model, preferred.Id) { + channel = preferred + selectGroup = usingGroup + service.MarkChannelAffinityUsed(c, usingGroup, preferred.Id) + } } - message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(distributor): %s", showGroup, modelRequest.Model, err.Error()) - // 如果错误,但是渠道不为空,说明是数据库一致性问题 - //if channel != nil { - // common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) - // message = "数据库一致性已被破坏,请联系管理员" - //} - abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, types.ErrorCodeModelNotFound) - return } + if channel == nil { - abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", usingGroup, modelRequest.Model), types.ErrorCodeModelNotFound) - return + channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(&service.RetryParam{ + Ctx: c, + ModelName: modelRequest.Model, + TokenGroup: usingGroup, + Retry: common.GetPointer(0), + }) + if err != nil { + showGroup := usingGroup + if usingGroup == "auto" { + showGroup = fmt.Sprintf("auto(%s)", selectGroup) + } + message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(distributor): %s", showGroup, modelRequest.Model, err.Error()) + // 如果错误,但是渠道不为空,说明是数据库一致性问题 + //if channel != nil { + // common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) + // message = "数据库一致性已被破坏,请联系管理员" + //} + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, types.ErrorCodeModelNotFound) + return + } + if channel == nil { + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", usingGroup, modelRequest.Model), types.ErrorCodeModelNotFound) + return + } } } } common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) SetupContextForSelectedChannel(c, channel, modelRequest.Model) c.Next() + if channel != nil && c.Writer != nil && c.Writer.Status() < http.StatusBadRequest { + service.RecordChannelAffinity(c, channel.Id) + } } } diff --git a/model/channel_satisfy.go b/model/channel_satisfy.go new file mode 100644 index 000000000..681f1e69b --- /dev/null +++ b/model/channel_satisfy.go @@ -0,0 +1,71 @@ +package model + +import ( + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/setting/ratio_setting" +) + +func IsChannelEnabledForGroupModel(group string, modelName string, channelID int) bool { + if group == "" || modelName == "" || channelID <= 0 { + return false + } + if !common.MemoryCacheEnabled { + return isChannelEnabledForGroupModelDB(group, modelName, channelID) + } + + channelSyncLock.RLock() + defer channelSyncLock.RUnlock() + + if group2model2channels == nil { + return false + } + + if isChannelIDInList(group2model2channels[group][modelName], channelID) { + return true + } + normalized := ratio_setting.FormatMatchingModelName(modelName) + if normalized != "" && normalized != modelName { + return isChannelIDInList(group2model2channels[group][normalized], channelID) + } + return false +} + +func IsChannelEnabledForAnyGroupModel(groups []string, modelName string, channelID int) bool { + if len(groups) == 0 { + return false + } + for _, g := range groups { + if IsChannelEnabledForGroupModel(g, modelName, channelID) { + return true + } + } + return false +} + +func isChannelEnabledForGroupModelDB(group string, modelName string, channelID int) bool { + var count int64 + err := DB.Model(&Ability{}). + Where(commonGroupCol+" = ? and model = ? and channel_id = ? and enabled = ?", group, modelName, channelID, true). + Count(&count).Error + if err == nil && count > 0 { + return true + } + normalized := ratio_setting.FormatMatchingModelName(modelName) + if normalized == "" || normalized == modelName { + return false + } + count = 0 + err = DB.Model(&Ability{}). + Where(commonGroupCol+" = ? and model = ? and channel_id = ? and enabled = ?", group, normalized, channelID, true). + Count(&count).Error + return err == nil && count > 0 +} + +func isChannelIDInList(list []int, channelID int) bool { + for _, id := range list { + if id == channelID { + return true + } + } + return false +} diff --git a/pkg/cachex/codec.go b/pkg/cachex/codec.go new file mode 100644 index 000000000..2e4957a84 --- /dev/null +++ b/pkg/cachex/codec.go @@ -0,0 +1,53 @@ +package cachex + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" +) + +type ValueCodec[V any] interface { + Encode(v V) (string, error) + Decode(s string) (V, error) +} + +type IntCodec struct{} + +func (c IntCodec) Encode(v int) (string, error) { + return strconv.Itoa(v), nil +} + +func (c IntCodec) Decode(s string) (int, error) { + s = strings.TrimSpace(s) + if s == "" { + return 0, fmt.Errorf("empty int value") + } + return strconv.Atoi(s) +} + +type StringCodec struct{} + +func (c StringCodec) Encode(v string) (string, error) { return v, nil } +func (c StringCodec) Decode(s string) (string, error) { return s, nil } + +type JSONCodec[V any] struct{} + +func (c JSONCodec[V]) Encode(v V) (string, error) { + b, err := json.Marshal(v) + if err != nil { + return "", err + } + return string(b), nil +} + +func (c JSONCodec[V]) Decode(s string) (V, error) { + var v V + if strings.TrimSpace(s) == "" { + return v, fmt.Errorf("empty json value") + } + if err := json.Unmarshal([]byte(s), &v); err != nil { + return v, err + } + return v, nil +} diff --git a/pkg/cachex/hybrid_cache.go b/pkg/cachex/hybrid_cache.go new file mode 100644 index 000000000..9df3cfe64 --- /dev/null +++ b/pkg/cachex/hybrid_cache.go @@ -0,0 +1,285 @@ +package cachex + +import ( + "context" + "errors" + "strings" + "sync" + "time" + + "github.com/go-redis/redis/v8" + "github.com/samber/hot" +) + +const ( + defaultRedisOpTimeout = 2 * time.Second + defaultRedisScanTimeout = 30 * time.Second + defaultRedisDelTimeout = 10 * time.Second +) + +type HybridCacheConfig[V any] struct { + Namespace Namespace + + // Redis is used when RedisEnabled returns true (or RedisEnabled is nil) and Redis is not nil. + Redis *redis.Client + RedisCodec ValueCodec[V] + RedisEnabled func() bool + + // Memory builds a hot cache used when Redis is disabled. Keys stored in memory are fully namespaced. + Memory func() *hot.HotCache[string, V] +} + +// HybridCache is a small helper that uses Redis when enabled, otherwise falls back to in-memory hot cache. +type HybridCache[V any] struct { + ns Namespace + + redis *redis.Client + redisCodec ValueCodec[V] + redisEnabled func() bool + + memOnce sync.Once + memInit func() *hot.HotCache[string, V] + mem *hot.HotCache[string, V] +} + +func NewHybridCache[V any](cfg HybridCacheConfig[V]) *HybridCache[V] { + return &HybridCache[V]{ + ns: cfg.Namespace, + redis: cfg.Redis, + redisCodec: cfg.RedisCodec, + redisEnabled: cfg.RedisEnabled, + memInit: cfg.Memory, + } +} + +func (c *HybridCache[V]) FullKey(key string) string { + return c.ns.FullKey(key) +} + +func (c *HybridCache[V]) redisOn() bool { + if c.redis == nil || c.redisCodec == nil { + return false + } + if c.redisEnabled == nil { + return true + } + return c.redisEnabled() +} + +func (c *HybridCache[V]) memCache() *hot.HotCache[string, V] { + c.memOnce.Do(func() { + if c.memInit == nil { + c.mem = hot.NewHotCache[string, V](hot.LRU, 1).Build() + return + } + c.mem = c.memInit() + }) + return c.mem +} + +func (c *HybridCache[V]) Get(key string) (value V, found bool, err error) { + full := c.ns.FullKey(key) + if full == "" { + var zero V + return zero, false, nil + } + + if c.redisOn() { + ctx, cancel := context.WithTimeout(context.Background(), defaultRedisOpTimeout) + defer cancel() + + raw, e := c.redis.Get(ctx, full).Result() + if e == nil { + v, decErr := c.redisCodec.Decode(raw) + if decErr != nil { + var zero V + return zero, false, decErr + } + return v, true, nil + } + if errors.Is(e, redis.Nil) { + var zero V + return zero, false, nil + } + var zero V + return zero, false, e + } + + return c.memCache().Get(full) +} + +func (c *HybridCache[V]) SetWithTTL(key string, v V, ttl time.Duration) error { + full := c.ns.FullKey(key) + if full == "" { + return nil + } + + if c.redisOn() { + raw, err := c.redisCodec.Encode(v) + if err != nil { + return err + } + ctx, cancel := context.WithTimeout(context.Background(), defaultRedisOpTimeout) + defer cancel() + return c.redis.Set(ctx, full, raw, ttl).Err() + } + + c.memCache().SetWithTTL(full, v, ttl) + return nil +} + +// Keys returns keys with valid values. In Redis, it returns all matching keys. +func (c *HybridCache[V]) Keys() ([]string, error) { + if c.redisOn() { + return c.scanKeys(c.ns.MatchPattern()) + } + return c.memCache().Keys(), nil +} + +func (c *HybridCache[V]) scanKeys(match string) ([]string, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultRedisScanTimeout) + defer cancel() + + var cursor uint64 + keys := make([]string, 0, 1024) + for { + k, next, err := c.redis.Scan(ctx, cursor, match, 1000).Result() + if err != nil { + return keys, err + } + keys = append(keys, k...) + cursor = next + if cursor == 0 { + break + } + } + return keys, nil +} + +func (c *HybridCache[V]) Purge() error { + if c.redisOn() { + keys, err := c.scanKeys(c.ns.MatchPattern()) + if err != nil { + return err + } + if len(keys) == 0 { + return nil + } + _, err = c.DeleteMany(keys) + return err + } + + c.memCache().Purge() + return nil +} + +func (c *HybridCache[V]) DeleteByPrefix(prefix string) (int, error) { + fullPrefix := c.ns.FullKey(prefix) + if fullPrefix == "" { + return 0, nil + } + if !strings.HasSuffix(fullPrefix, ":") { + fullPrefix += ":" + } + + if c.redisOn() { + match := fullPrefix + "*" + keys, err := c.scanKeys(match) + if err != nil { + return 0, err + } + if len(keys) == 0 { + return 0, nil + } + + res, err := c.DeleteMany(keys) + if err != nil { + return 0, err + } + deleted := 0 + for _, ok := range res { + if ok { + deleted++ + } + } + return deleted, nil + } + + // In memory, we filter keys and bulk delete. + allKeys := c.memCache().Keys() + keys := make([]string, 0, 128) + for _, k := range allKeys { + if strings.HasPrefix(k, fullPrefix) { + keys = append(keys, k) + } + } + if len(keys) == 0 { + return 0, nil + } + res, _ := c.DeleteMany(keys) + deleted := 0 + for _, ok := range res { + if ok { + deleted++ + } + } + return deleted, nil +} + +// DeleteMany accepts either fully namespaced keys or raw keys and deletes them. +// It returns a map keyed by fully namespaced keys. +func (c *HybridCache[V]) DeleteMany(keys []string) (map[string]bool, error) { + res := make(map[string]bool, len(keys)) + if len(keys) == 0 { + return res, nil + } + + fullKeys := make([]string, 0, len(keys)) + for _, k := range keys { + k = c.ns.FullKey(k) + if k == "" { + continue + } + fullKeys = append(fullKeys, k) + } + if len(fullKeys) == 0 { + return res, nil + } + + if c.redisOn() { + ctx, cancel := context.WithTimeout(context.Background(), defaultRedisDelTimeout) + defer cancel() + + pipe := c.redis.Pipeline() + cmds := make([]*redis.IntCmd, 0, len(fullKeys)) + for _, k := range fullKeys { + // UNLINK is non-blocking vs DEL for large key batches. + cmds = append(cmds, pipe.Unlink(ctx, k)) + } + _, err := pipe.Exec(ctx) + if err != nil && !errors.Is(err, redis.Nil) { + return res, err + } + for i, cmd := range cmds { + deleted := cmd != nil && cmd.Err() == nil && cmd.Val() > 0 + res[fullKeys[i]] = deleted + } + return res, nil + } + + return c.memCache().DeleteMany(fullKeys), nil +} + +func (c *HybridCache[V]) Capacity() (mainCacheCapacity int, missingCacheCapacity int) { + if c.redisOn() { + return 0, 0 + } + return c.memCache().Capacity() +} + +func (c *HybridCache[V]) Algorithm() (mainCacheAlgorithm string, missingCacheAlgorithm string) { + if c.redisOn() { + return "redis", "" + } + return c.memCache().Algorithm() +} diff --git a/pkg/cachex/namespace.go b/pkg/cachex/namespace.go new file mode 100644 index 000000000..e6806bf2f --- /dev/null +++ b/pkg/cachex/namespace.go @@ -0,0 +1,38 @@ +package cachex + +import "strings" + +// Namespace isolates keys between different cache use-cases. (e.g. "channel_affinity:v1"). +type Namespace string + +func (n Namespace) prefix() string { + ns := strings.TrimSpace(string(n)) + ns = strings.TrimRight(ns, ":") + if ns == "" { + return "" + } + return ns + ":" +} + +func (n Namespace) FullKey(key string) string { + key = strings.TrimSpace(key) + if key == "" { + return "" + } + p := n.prefix() + if p == "" { + return strings.TrimLeft(key, ":") + } + if strings.HasPrefix(key, p) { + return key + } + return p + strings.TrimLeft(key, ":") +} + +func (n Namespace) MatchPattern() string { + p := n.prefix() + if p == "" { + return "*" + } + return p + "*" +} diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index 5792715b2..b2706730d 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -79,7 +79,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types if info.RelayMode == relayconstant.RelayModeChatCompletions && !passThroughGlobal && !info.ChannelSetting.PassThroughBodyEnabled && - shouldChatCompletionsViaResponses(info) { + service.ShouldChatCompletionsUseResponsesGlobal(info.ChannelId, info.ChannelType, info.OriginModelName) { applySystemPromptIfNeeded(c, info, request) usage, newApiErr := chatCompletionsViaResponses(c, info, adaptor, request) if newApiErr != nil { @@ -218,16 +218,6 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types return nil } -func shouldChatCompletionsViaResponses(info *relaycommon.RelayInfo) bool { - if info == nil { - return false - } - if info.RelayMode != relayconstant.RelayModeChatCompletions { - return false - } - return service.ShouldChatCompletionsUseResponsesGlobal(info.ChannelId, info.OriginModelName) -} - func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent ...string) { if usage == nil { usage = &dto.Usage{ diff --git a/router/api-router.go b/router/api-router.go index f3ae4d970..39e43eee9 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -123,6 +123,8 @@ func SetApiRouter(router *gin.Engine) { { optionRoute.GET("/", controller.GetOptions) optionRoute.PUT("/", controller.UpdateOption) + optionRoute.GET("/channel_affinity_cache", controller.GetChannelAffinityCacheStats) + optionRoute.DELETE("/channel_affinity_cache", controller.ClearChannelAffinityCache) optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio) optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除 } diff --git a/service/channel_affinity.go b/service/channel_affinity.go new file mode 100644 index 000000000..5aa50adb6 --- /dev/null +++ b/service/channel_affinity.go @@ -0,0 +1,487 @@ +package service + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/pkg/cachex" + "github.com/QuantumNous/new-api/setting/operation_setting" + "github.com/gin-gonic/gin" + "github.com/samber/hot" + "github.com/tidwall/gjson" +) + +const ( + ginKeyChannelAffinityCacheKey = "channel_affinity_cache_key" + ginKeyChannelAffinityTTLSeconds = "channel_affinity_ttl_seconds" + ginKeyChannelAffinityMeta = "channel_affinity_meta" + ginKeyChannelAffinityLogInfo = "channel_affinity_log_info" + + channelAffinityCacheNamespace = "new-api:channel_affinity:v1" +) + +var ( + channelAffinityCacheOnce sync.Once + channelAffinityCache *cachex.HybridCache[int] + + channelAffinityRegexCache sync.Map // map[string]*regexp.Regexp +) + +type channelAffinityMeta struct { + CacheKey string + TTLSeconds int + RuleName string + KeySourceType string + KeySourceKey string + KeySourcePath string + KeyFingerprint string + UsingGroup string + ModelName string + RequestPath string +} + +type ChannelAffinityCacheStats struct { + Enabled bool `json:"enabled"` + Total int `json:"total"` + Unknown int `json:"unknown"` + ByRuleName map[string]int `json:"by_rule_name"` + CacheCapacity int `json:"cache_capacity"` + CacheAlgo string `json:"cache_algo"` +} + +func getChannelAffinityCache() *cachex.HybridCache[int] { + channelAffinityCacheOnce.Do(func() { + setting := operation_setting.GetChannelAffinitySetting() + capacity := setting.MaxEntries + if capacity <= 0 { + capacity = 100_000 + } + defaultTTLSeconds := setting.DefaultTTLSeconds + if defaultTTLSeconds <= 0 { + defaultTTLSeconds = 3600 + } + + channelAffinityCache = cachex.NewHybridCache[int](cachex.HybridCacheConfig[int]{ + Namespace: cachex.Namespace(channelAffinityCacheNamespace), + Redis: common.RDB, + RedisEnabled: func() bool { + return common.RedisEnabled && common.RDB != nil + }, + RedisCodec: cachex.IntCodec{}, + Memory: func() *hot.HotCache[string, int] { + return hot.NewHotCache[string, int](hot.LRU, capacity). + WithTTL(time.Duration(defaultTTLSeconds) * time.Second). + WithJanitor(). + Build() + }, + }) + }) + return channelAffinityCache +} + +func GetChannelAffinityCacheStats() ChannelAffinityCacheStats { + setting := operation_setting.GetChannelAffinitySetting() + if setting == nil { + return ChannelAffinityCacheStats{ + Enabled: false, + Total: 0, + Unknown: 0, + ByRuleName: map[string]int{}, + } + } + + cache := getChannelAffinityCache() + mainCap, _ := cache.Capacity() + mainAlgo, _ := cache.Algorithm() + + rules := setting.Rules + ruleByName := make(map[string]operation_setting.ChannelAffinityRule, len(rules)) + for _, r := range rules { + name := strings.TrimSpace(r.Name) + if name == "" { + continue + } + if !r.IncludeRuleName { + continue + } + ruleByName[name] = r + } + + byRuleName := make(map[string]int, len(ruleByName)) + for name := range ruleByName { + byRuleName[name] = 0 + } + + keys, err := cache.Keys() + if err != nil { + common.SysError(fmt.Sprintf("channel affinity cache list keys failed: err=%v", err)) + keys = nil + } + total := len(keys) + unknown := 0 + for _, k := range keys { + prefix := channelAffinityCacheNamespace + ":" + if !strings.HasPrefix(k, prefix) { + unknown++ + continue + } + rest := strings.TrimPrefix(k, prefix) + parts := strings.Split(rest, ":") + if len(parts) < 2 { + unknown++ + continue + } + ruleName := parts[0] + rule, ok := ruleByName[ruleName] + if !ok { + unknown++ + continue + } + if rule.IncludeUsingGroup { + if len(parts) < 3 { + unknown++ + continue + } + } + byRuleName[ruleName]++ + } + + return ChannelAffinityCacheStats{ + Enabled: setting.Enabled, + Total: total, + Unknown: unknown, + ByRuleName: byRuleName, + CacheCapacity: mainCap, + CacheAlgo: mainAlgo, + } +} + +func ClearChannelAffinityCacheAll() int { + cache := getChannelAffinityCache() + keys, err := cache.Keys() + if err != nil { + common.SysError(fmt.Sprintf("channel affinity cache list keys failed: err=%v", err)) + keys = nil + } + if len(keys) > 0 { + if _, err := cache.DeleteMany(keys); err != nil { + common.SysError(fmt.Sprintf("channel affinity cache delete many failed: err=%v", err)) + } + } + return len(keys) +} + +func ClearChannelAffinityCacheByRuleName(ruleName string) (int, error) { + ruleName = strings.TrimSpace(ruleName) + if ruleName == "" { + return 0, fmt.Errorf("rule_name 不能为空") + } + + setting := operation_setting.GetChannelAffinitySetting() + if setting == nil { + return 0, fmt.Errorf("channel_affinity_setting 未初始化") + } + + var matchedRule *operation_setting.ChannelAffinityRule + for i := range setting.Rules { + r := &setting.Rules[i] + if strings.TrimSpace(r.Name) != ruleName { + continue + } + matchedRule = r + break + } + if matchedRule == nil { + return 0, fmt.Errorf("未知规则名称") + } + if !matchedRule.IncludeRuleName { + return 0, fmt.Errorf("该规则未启用 include_rule_name,无法按规则清空缓存") + } + + cache := getChannelAffinityCache() + deleted, err := cache.DeleteByPrefix(ruleName) + if err != nil { + return 0, err + } + return deleted, nil +} + +func matchAnyRegexCached(patterns []string, s string) bool { + if len(patterns) == 0 || s == "" { + return false + } + for _, pattern := range patterns { + if pattern == "" { + continue + } + re, ok := channelAffinityRegexCache.Load(pattern) + if !ok { + compiled, err := regexp.Compile(pattern) + if err != nil { + continue + } + re = compiled + channelAffinityRegexCache.Store(pattern, re) + } + if re.(*regexp.Regexp).MatchString(s) { + return true + } + } + return false +} + +func matchAnyIncludeFold(patterns []string, s string) bool { + if len(patterns) == 0 || s == "" { + return false + } + sLower := strings.ToLower(s) + for _, p := range patterns { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if strings.Contains(sLower, strings.ToLower(p)) { + return true + } + } + return false +} + +func extractChannelAffinityValue(c *gin.Context, src operation_setting.ChannelAffinityKeySource) string { + switch src.Type { + case "context_int": + if src.Key == "" { + return "" + } + v := c.GetInt(src.Key) + if v <= 0 { + return "" + } + return strconv.Itoa(v) + case "context_string": + if src.Key == "" { + return "" + } + return strings.TrimSpace(c.GetString(src.Key)) + case "gjson": + if src.Path == "" { + return "" + } + body, err := common.GetRequestBody(c) + if err != nil || len(body) == 0 { + return "" + } + res := gjson.GetBytes(body, src.Path) + if !res.Exists() { + return "" + } + switch res.Type { + case gjson.String, gjson.Number, gjson.True, gjson.False: + return strings.TrimSpace(res.String()) + default: + return strings.TrimSpace(res.Raw) + } + default: + return "" + } +} + +func buildChannelAffinityCacheKeySuffix(rule operation_setting.ChannelAffinityRule, usingGroup string, affinityValue string) string { + parts := make([]string, 0, 3) + if rule.IncludeRuleName && rule.Name != "" { + parts = append(parts, rule.Name) + } + if rule.IncludeUsingGroup && usingGroup != "" { + parts = append(parts, usingGroup) + } + parts = append(parts, affinityValue) + return strings.Join(parts, ":") +} + +func setChannelAffinityContext(c *gin.Context, meta channelAffinityMeta) { + c.Set(ginKeyChannelAffinityCacheKey, meta.CacheKey) + c.Set(ginKeyChannelAffinityTTLSeconds, meta.TTLSeconds) + c.Set(ginKeyChannelAffinityMeta, meta) +} + +func getChannelAffinityContext(c *gin.Context) (string, int, bool) { + keyAny, ok := c.Get(ginKeyChannelAffinityCacheKey) + if !ok { + return "", 0, false + } + key, ok := keyAny.(string) + if !ok || key == "" { + return "", 0, false + } + ttlAny, ok := c.Get(ginKeyChannelAffinityTTLSeconds) + if !ok { + return key, 0, true + } + ttlSeconds, _ := ttlAny.(int) + return key, ttlSeconds, true +} + +func getChannelAffinityMeta(c *gin.Context) (channelAffinityMeta, bool) { + anyMeta, ok := c.Get(ginKeyChannelAffinityMeta) + if !ok { + return channelAffinityMeta{}, false + } + meta, ok := anyMeta.(channelAffinityMeta) + if !ok { + return channelAffinityMeta{}, false + } + return meta, true +} + +func affinityFingerprint(s string) string { + if s == "" { + return "" + } + hex := common.Sha1([]byte(s)) + if len(hex) >= 8 { + return hex[:8] + } + return hex +} + +func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup string) (int, bool) { + setting := operation_setting.GetChannelAffinitySetting() + if setting == nil || !setting.Enabled { + return 0, false + } + path := "" + if c != nil && c.Request != nil && c.Request.URL != nil { + path = c.Request.URL.Path + } + userAgent := "" + if c != nil && c.Request != nil { + userAgent = c.Request.UserAgent() + } + + for _, rule := range setting.Rules { + if !matchAnyRegexCached(rule.ModelRegex, modelName) { + continue + } + if len(rule.PathRegex) > 0 && !matchAnyRegexCached(rule.PathRegex, path) { + continue + } + if len(rule.UserAgentInclude) > 0 && !matchAnyIncludeFold(rule.UserAgentInclude, userAgent) { + continue + } + var affinityValue string + var usedSource operation_setting.ChannelAffinityKeySource + for _, src := range rule.KeySources { + affinityValue = extractChannelAffinityValue(c, src) + if affinityValue != "" { + usedSource = src + break + } + } + if affinityValue == "" { + continue + } + if rule.ValueRegex != "" && !matchAnyRegexCached([]string{rule.ValueRegex}, affinityValue) { + continue + } + + ttlSeconds := rule.TTLSeconds + if ttlSeconds <= 0 { + ttlSeconds = setting.DefaultTTLSeconds + } + cacheKeySuffix := buildChannelAffinityCacheKeySuffix(rule, usingGroup, affinityValue) + cacheKeyFull := channelAffinityCacheNamespace + ":" + cacheKeySuffix + setChannelAffinityContext(c, channelAffinityMeta{ + CacheKey: cacheKeyFull, + TTLSeconds: ttlSeconds, + RuleName: rule.Name, + KeySourceType: strings.TrimSpace(usedSource.Type), + KeySourceKey: strings.TrimSpace(usedSource.Key), + KeySourcePath: strings.TrimSpace(usedSource.Path), + KeyFingerprint: affinityFingerprint(affinityValue), + UsingGroup: usingGroup, + ModelName: modelName, + RequestPath: path, + }) + + cache := getChannelAffinityCache() + channelID, found, err := cache.Get(cacheKeySuffix) + if err != nil { + common.SysError(fmt.Sprintf("channel affinity cache get failed: key=%s, err=%v", cacheKeyFull, err)) + return 0, false + } + if found { + return channelID, true + } + return 0, false + } + return 0, false +} + +func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) { + if c == nil || channelID <= 0 { + return + } + meta, ok := getChannelAffinityMeta(c) + if !ok { + return + } + info := map[string]interface{}{ + "reason": meta.RuleName, + "rule_name": meta.RuleName, + "using_group": meta.UsingGroup, + "selected_group": selectedGroup, + "model": meta.ModelName, + "request_path": meta.RequestPath, + "channel_id": channelID, + "key_source": meta.KeySourceType, + "key_key": meta.KeySourceKey, + "key_path": meta.KeySourcePath, + "key_fp": meta.KeyFingerprint, + } + c.Set(ginKeyChannelAffinityLogInfo, info) +} + +func AppendChannelAffinityAdminInfo(c *gin.Context, adminInfo map[string]interface{}) { + if c == nil || adminInfo == nil { + return + } + anyInfo, ok := c.Get(ginKeyChannelAffinityLogInfo) + if !ok || anyInfo == nil { + return + } + adminInfo["channel_affinity"] = anyInfo +} + +func RecordChannelAffinity(c *gin.Context, channelID int) { + if channelID <= 0 { + return + } + setting := operation_setting.GetChannelAffinitySetting() + if setting == nil || !setting.Enabled { + return + } + if setting.SwitchOnSuccess && c != nil { + if successChannelID := c.GetInt("channel_id"); successChannelID > 0 { + channelID = successChannelID + } + } + cacheKey, ttlSeconds, ok := getChannelAffinityContext(c) + if !ok { + return + } + if ttlSeconds <= 0 { + ttlSeconds = setting.DefaultTTLSeconds + } + if ttlSeconds <= 0 { + ttlSeconds = 3600 + } + cache := getChannelAffinityCache() + if err := cache.SetWithTTL(cacheKey, channelID, time.Duration(ttlSeconds)*time.Second); err != nil { + common.SysError(fmt.Sprintf("channel affinity cache set failed: key=%s, err=%v", cacheKey, err)) + } +} diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 71a6bd32a..71dd22f47 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -68,6 +68,8 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m adminInfo["local_count_tokens"] = isLocalCountTokens } + AppendChannelAffinityAdminInfo(ctx, adminInfo) + other["admin_info"] = adminInfo appendRequestPath(ctx, relayInfo, other) appendRequestConversionChain(relayInfo, other) diff --git a/service/openai_chat_responses_mode.go b/service/openai_chat_responses_mode.go index a655a38be..c66c33c9d 100644 --- a/service/openai_chat_responses_mode.go +++ b/service/openai_chat_responses_mode.go @@ -5,10 +5,10 @@ import ( "github.com/QuantumNous/new-api/setting/model_setting" ) -func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, model string) bool { - return openaicompat.ShouldChatCompletionsUseResponsesPolicy(policy, channelID, model) +func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, channelType int, model string) bool { + return openaicompat.ShouldChatCompletionsUseResponsesPolicy(policy, channelID, channelType, model) } -func ShouldChatCompletionsUseResponsesGlobal(channelID int, model string) bool { - return openaicompat.ShouldChatCompletionsUseResponsesGlobal(channelID, model) +func ShouldChatCompletionsUseResponsesGlobal(channelID int, channelType int, model string) bool { + return openaicompat.ShouldChatCompletionsUseResponsesGlobal(channelID, channelType, model) } diff --git a/service/openaicompat/policy.go b/service/openaicompat/policy.go index 39b11ce5d..b600b0fdc 100644 --- a/service/openaicompat/policy.go +++ b/service/openaicompat/policy.go @@ -2,17 +2,18 @@ package openaicompat import "github.com/QuantumNous/new-api/setting/model_setting" -func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, model string) bool { - if !policy.IsChannelEnabled(channelID) { +func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, channelType int, model string) bool { + if !policy.IsChannelEnabled(channelID, channelType) { return false } return matchAnyRegex(policy.ModelPatterns, model) } -func ShouldChatCompletionsUseResponsesGlobal(channelID int, model string) bool { +func ShouldChatCompletionsUseResponsesGlobal(channelID int, channelType int, model string) bool { return ShouldChatCompletionsUseResponsesPolicy( model_setting.GetGlobalSettings().ChatCompletionsToResponsesPolicy, channelID, + channelType, model, ) } diff --git a/setting/model_setting/global.go b/setting/model_setting/global.go index 580171171..d0c4d3128 100644 --- a/setting/model_setting/global.go +++ b/setting/model_setting/global.go @@ -11,20 +11,25 @@ type ChatCompletionsToResponsesPolicy struct { Enabled bool `json:"enabled"` AllChannels bool `json:"all_channels"` ChannelIDs []int `json:"channel_ids,omitempty"` + ChannelTypes []int `json:"channel_types,omitempty"` ModelPatterns []string `json:"model_patterns,omitempty"` } -func (p ChatCompletionsToResponsesPolicy) IsChannelEnabled(channelID int) bool { +func (p ChatCompletionsToResponsesPolicy) IsChannelEnabled(channelID int, channelType int) bool { if !p.Enabled { return false } if p.AllChannels { return true } - if channelID == 0 || len(p.ChannelIDs) == 0 { - return false + + if channelID > 0 && len(p.ChannelIDs) > 0 && slices.Contains(p.ChannelIDs, channelID) { + return true } - return slices.Contains(p.ChannelIDs, channelID) + if channelType > 0 && len(p.ChannelTypes) > 0 && slices.Contains(p.ChannelTypes, channelType) { + return true + } + return false } type GlobalSettings struct { diff --git a/setting/operation_setting/channel_affinity_setting.go b/setting/operation_setting/channel_affinity_setting.go new file mode 100644 index 000000000..f95ac6969 --- /dev/null +++ b/setting/operation_setting/channel_affinity_setting.go @@ -0,0 +1,47 @@ +package operation_setting + +import "github.com/QuantumNous/new-api/setting/config" + +type ChannelAffinityKeySource struct { + Type string `json:"type"` // context_int, context_string, gjson + Key string `json:"key,omitempty"` + Path string `json:"path,omitempty"` +} + +type ChannelAffinityRule struct { + Name string `json:"name"` + ModelRegex []string `json:"model_regex"` + PathRegex []string `json:"path_regex"` + UserAgentInclude []string `json:"user_agent_include,omitempty"` + KeySources []ChannelAffinityKeySource `json:"key_sources"` + + ValueRegex string `json:"value_regex"` + TTLSeconds int `json:"ttl_seconds"` + + IncludeUsingGroup bool `json:"include_using_group"` + IncludeRuleName bool `json:"include_rule_name"` +} + +type ChannelAffinitySetting struct { + Enabled bool `json:"enabled"` + SwitchOnSuccess bool `json:"switch_on_success"` + MaxEntries int `json:"max_entries"` + DefaultTTLSeconds int `json:"default_ttl_seconds"` + Rules []ChannelAffinityRule `json:"rules"` +} + +var channelAffinitySetting = ChannelAffinitySetting{ + Enabled: false, + SwitchOnSuccess: true, + MaxEntries: 100_000, + DefaultTTLSeconds: 3600, + Rules: []ChannelAffinityRule{}, +} + +func init() { + config.GlobalConfig.Register("channel_affinity_setting", &channelAffinitySetting) +} + +func GetChannelAffinitySetting() *ChannelAffinitySetting { + return &channelAffinitySetting +} diff --git a/web/src/components/settings/ModelSetting.jsx b/web/src/components/settings/ModelSetting.jsx index 1f2206691..69ece885a 100644 --- a/web/src/components/settings/ModelSetting.jsx +++ b/web/src/components/settings/ModelSetting.jsx @@ -25,6 +25,7 @@ import { useTranslation } from 'react-i18next'; import SettingGeminiModel from '../../pages/Setting/Model/SettingGeminiModel'; import SettingClaudeModel from '../../pages/Setting/Model/SettingClaudeModel'; import SettingGlobalModel from '../../pages/Setting/Model/SettingGlobalModel'; +import SettingsChannelAffinity from '../../pages/Setting/Operation/SettingsChannelAffinity'; const ModelSetting = () => { const { t } = useTranslation(); @@ -109,6 +110,10 @@ const ModelSetting = () => { + {/* Channel affinity */} + + + {/* Gemini */} diff --git a/web/src/components/table/channels/modals/CodexUsageModal.jsx b/web/src/components/table/channels/modals/CodexUsageModal.jsx index df5e2c98b..16ad07610 100644 --- a/web/src/components/table/channels/modals/CodexUsageModal.jsx +++ b/web/src/components/table/channels/modals/CodexUsageModal.jsx @@ -17,8 +17,9 @@ along with this program. If not, see . For commercial licensing, please contact support@quantumnous.com */ -import React from 'react'; -import { Modal, Button, Progress, Tag, Typography } from '@douyinfe/semi-ui'; +import React, { useCallback, useEffect, useRef, useState } from 'react'; +import { Modal, Button, Progress, Tag, Typography, Spin } from '@douyinfe/semi-ui'; +import { API, showError } from '../../../../helpers'; const { Text } = Typography; @@ -101,7 +102,7 @@ const RateLimitWindowCard = ({ t, title, windowData }) => { ); }; -export const openCodexUsageModal = ({ t, record, payload, onCopy }) => { +const CodexUsageView = ({ t, record, payload, onCopy, onRefresh }) => { const tt = typeof t === 'function' ? t : (v) => v; const data = payload?.data ?? null; const rateLimit = data?.rate_limit ?? {}; @@ -123,61 +124,159 @@ export const openCodexUsageModal = ({ t, record, payload, onCopy }) => { const rawText = typeof data === 'string' ? data : JSON.stringify(data ?? payload, null, 2); - Modal.info({ - title: ( -
- {tt('Codex 用量')} - {statusTag} + return ( +
+
+ + {tt('渠道:')} + {record?.name || '-'} ({tt('编号:')} + {record?.id || '-'}) + +
+ {statusTag} + +
- ), + +
+ + {tt('上游状态码:')} + {upstreamStatus ?? '-'} + +
+ +
+ + +
+ +
+
+
{tt('原始 JSON')}
+ +
+
+          {rawText}
+        
+
+
+ ); +}; + +const CodexUsageLoader = ({ t, record, initialPayload, onCopy }) => { + const tt = typeof t === 'function' ? t : (v) => v; + const [loading, setLoading] = useState(!initialPayload); + const [payload, setPayload] = useState(initialPayload ?? null); + const hasShownErrorRef = useRef(false); + const mountedRef = useRef(true); + const recordId = record?.id; + + const fetchUsage = useCallback(async () => { + if (!recordId) { + if (mountedRef.current) setPayload(null); + return; + } + + if (mountedRef.current) setLoading(true); + try { + const res = await API.get(`/api/channel/${recordId}/codex/usage`, { + skipErrorHandler: true, + }); + if (!mountedRef.current) return; + setPayload(res?.data ?? null); + if (!res?.data?.success && !hasShownErrorRef.current) { + hasShownErrorRef.current = true; + showError(tt('获取用量失败')); + } + } catch (error) { + if (!mountedRef.current) return; + if (!hasShownErrorRef.current) { + hasShownErrorRef.current = true; + showError(tt('获取用量失败')); + } + setPayload({ success: false, message: String(error) }); + } finally { + if (mountedRef.current) setLoading(false); + } + }, [recordId, tt]); + + useEffect(() => { + mountedRef.current = true; + return () => { + mountedRef.current = false; + }; + }, []); + + useEffect(() => { + if (initialPayload) return; + fetchUsage().catch(() => {}); + }, [fetchUsage, initialPayload]); + + if (loading) { + return ( +
+ +
+ ); + } + + if (!payload) { + return ( +
+ {tt('获取用量失败')} +
+ +
+
+ ); + } + + return ( + + ); +}; + +export const openCodexUsageModal = ({ t, record, payload, onCopy }) => { + const tt = typeof t === 'function' ? t : (v) => v; + + Modal.info({ + title: tt('Codex 用量'), centered: true, width: 900, style: { maxWidth: '95vw' }, content: ( -
-
- - {tt('渠道:')} - {record?.name || '-'} ({tt('编号:')} - {record?.id || '-'}) - - - {tt('上游状态码:')} - {upstreamStatus ?? '-'} - -
- -
- - -
- -
-
-
{tt('原始 JSON')}
- -
-
-            {rawText}
-          
-
-
+ ), footer: (
diff --git a/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx b/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx index b3096c286..4c7fb1d84 100644 --- a/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx +++ b/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx @@ -40,7 +40,7 @@ import { renderClaudeModelPrice, renderModelPrice, } from '../../../helpers'; -import { IconHelpCircle } from '@douyinfe/semi-icons'; +import { IconHelpCircle, IconStarStroked } from '@douyinfe/semi-icons'; import { Route } from 'lucide-react'; const colors = [ @@ -498,6 +498,7 @@ export const getLogsColumns = ({ return <>; } let content = t('渠道') + `:${record.channel}`; + let affinity = null; if (record.other !== '') { let other = JSON.parse(record.other); if (other === null) { @@ -513,9 +514,55 @@ export const getLogsColumns = ({ let useChannelStr = useChannel.join('->'); content = t('渠道') + `:${useChannelStr}`; } + if (other.admin_info.channel_affinity) { + affinity = other.admin_info.channel_affinity; + } } } - return isAdminUser ?
{content}
: <>; + return isAdminUser ? ( + +
{content}
+ {affinity ? ( + + {t('渠道亲和性')} +
+ + {t('规则')}:{affinity.rule_name || '-'} + +
+
+ + {t('分组')}:{affinity.selected_group || '-'} + +
+
+ + {t('Key')}: + {(affinity.key_source || '-') + + ':' + + (affinity.key_path || affinity.key_key || '-') + + (affinity.key_fp ? `#${affinity.key_fp}` : '')} + +
+
+ } + > + + + + + {t('优选')} + + + + + ) : null} + + ) : ( + <> + ); }, }, { @@ -552,9 +599,13 @@ export const getLogsColumns = ({ other.cache_creation_tokens || 0, other.cache_creation_ratio || 1.0, other.cache_creation_tokens_5m || 0, - other.cache_creation_ratio_5m || other.cache_creation_ratio || 1.0, + other.cache_creation_ratio_5m || + other.cache_creation_ratio || + 1.0, other.cache_creation_tokens_1h || 0, - other.cache_creation_ratio_1h || other.cache_creation_ratio || 1.0, + other.cache_creation_ratio_1h || + other.cache_creation_ratio || + 1.0, false, 1.0, other?.is_system_prompt_overwritten, diff --git a/web/src/hooks/channels/useChannelsData.jsx b/web/src/hooks/channels/useChannelsData.jsx index 5e1feb162..1ce0785fb 100644 --- a/web/src/hooks/channels/useChannelsData.jsx +++ b/web/src/hooks/channels/useChannelsData.jsx @@ -747,28 +747,15 @@ export const useChannelsData = () => { const updateChannelBalance = async (record) => { if (record?.type === 57) { - try { - const res = await API.get(`/api/channel/${record.id}/codex/usage`, { - skipErrorHandler: true, - }); - if (!res?.data?.success) { - console.error('Codex usage fetch failed:', res?.data?.message); - showError(t('获取用量失败')); - } - openCodexUsageModal({ - t, - record, - payload: res?.data, - onCopy: async (text) => { - const ok = await copy(text); - if (ok) showSuccess(t('已复制')); - else showError(t('复制失败')); - }, - }); - } catch (error) { - console.error('Codex usage fetch error:', error); - showError(t('获取用量失败')); - } + openCodexUsageModal({ + t, + record, + onCopy: async (text) => { + const ok = await copy(text); + if (ok) showSuccess(t('已复制')); + else showError(t('复制失败')); + }, + }); return; } diff --git a/web/src/index.css b/web/src/index.css index dff5360b9..229095068 100644 --- a/web/src/index.css +++ b/web/src/index.css @@ -818,6 +818,34 @@ html.dark .with-pastel-balls::before { padding: 10px !important; } +/* ==================== 使用日志: channel affinity tag ==================== */ +.semi-tag.channel-affinity-tag { + border: 1px solid rgba(var(--semi-cyan-5), 0.35); + background-color: rgba(var(--semi-cyan-5), 0.15); + color: rgba(var(--semi-cyan-9), 1); + cursor: help; + transition: + background-color 120ms ease, + border-color 120ms ease, + box-shadow 120ms ease; +} + +.semi-tag.channel-affinity-tag:hover { + background-color: rgba(var(--semi-cyan-5), 0.22); + border-color: rgba(var(--semi-cyan-5), 0.6); + box-shadow: 0 0 0 2px rgba(var(--semi-cyan-5), 0.18); +} + +.semi-tag.channel-affinity-tag:active { + background-color: rgba(var(--semi-cyan-5), 0.28); +} + +.semi-tag.channel-affinity-tag .channel-affinity-tag-content { + display: inline-flex; + align-items: center; + gap: 0.25rem; +} + /* ==================== 自定义圆角样式 ==================== */ .semi-radio, .semi-tagInput, diff --git a/web/src/pages/Setting/Model/SettingGlobalModel.jsx b/web/src/pages/Setting/Model/SettingGlobalModel.jsx index 3d4cfd56e..9878875c7 100644 --- a/web/src/pages/Setting/Model/SettingGlobalModel.jsx +++ b/web/src/pages/Setting/Model/SettingGlobalModel.jsx @@ -49,6 +49,7 @@ const chatCompletionsToResponsesPolicyExample = JSON.stringify( enabled: true, all_channels: false, channel_ids: [1, 2], + channel_types: [1], model_patterns: ['^gpt-4o.*$', '^gpt-5.*$'], }, null, diff --git a/web/src/pages/Setting/Operation/SettingsChannelAffinity.jsx b/web/src/pages/Setting/Operation/SettingsChannelAffinity.jsx new file mode 100644 index 000000000..86c2bc321 --- /dev/null +++ b/web/src/pages/Setting/Operation/SettingsChannelAffinity.jsx @@ -0,0 +1,1139 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React, { useEffect, useRef, useState } from 'react'; +import { + Banner, + Button, + Col, + Collapse, + Divider, + Form, + Input, + Modal, + Row, + Select, + Space, + Spin, + Table, + Tag, + Typography, +} from '@douyinfe/semi-ui'; +import { + IconClose, + IconDelete, + IconEdit, + IconPlus, + IconRefresh, +} from '@douyinfe/semi-icons'; +import { + API, + compareObjects, + showError, + showSuccess, + showWarning, + toBoolean, + verifyJSON, +} from '../../../helpers'; +import { useTranslation } from 'react-i18next'; + +const KEY_ENABLED = 'channel_affinity_setting.enabled'; +const KEY_SWITCH_ON_SUCCESS = 'channel_affinity_setting.switch_on_success'; +const KEY_MAX_ENTRIES = 'channel_affinity_setting.max_entries'; +const KEY_DEFAULT_TTL = 'channel_affinity_setting.default_ttl_seconds'; +const KEY_RULES = 'channel_affinity_setting.rules'; + +const KEY_SOURCE_TYPES = [ + { label: 'context_int', value: 'context_int' }, + { label: 'context_string', value: 'context_string' }, + { label: 'gjson', value: 'gjson' }, +]; + +const RULE_TEMPLATES = { + codex: { + name: 'codex优选', + model_regex: ['^gpt-.*$'], + path_regex: ['/v1/responses'], + key_sources: [{ type: 'gjson', path: 'prompt_cache_key' }], + value_regex: '', + ttl_seconds: 0, + include_using_group: true, + include_rule_name: true, + }, + claudeCode: { + name: 'claude-code优选', + model_regex: ['^claude-.*$'], + path_regex: ['/v1/messages'], + key_sources: [{ type: 'gjson', path: 'metadata.user_id' }], + value_regex: '', + ttl_seconds: 0, + include_using_group: true, + include_rule_name: true, + }, +}; + +const CONTEXT_KEY_PRESETS = [ + { key: 'id', label: 'id(用户 ID)' }, + { key: 'token_id', label: 'token_id' }, + { key: 'token_key', label: 'token_key' }, + { key: 'token_group', label: 'token_group' }, + { key: 'group', label: 'group(using_group)' }, + { key: 'username', label: 'username' }, + { key: 'user_group', label: 'user_group' }, + { key: 'user_email', label: 'user_email' }, + { key: 'specific_channel_id', label: 'specific_channel_id' }, +]; + +const RULES_JSON_PLACEHOLDER = `[ + { + "name": "prefer-by-conversation-id", + "model_regex": ["^gpt-.*$"], + "path_regex": ["/v1/chat/completions"], + "user_agent_include": ["curl", "PostmanRuntime"], + "key_sources": [ + { "type": "gjson", "path": "metadata.conversation_id" }, + { "type": "context_string", "key": "conversation_id" } + ], + "value_regex": "^[-0-9A-Za-z._:]{1,128}$", + "ttl_seconds": 600, + "include_using_group": true, + "include_rule_name": true + } +]`; + +const normalizeStringList = (text) => { + if (!text) return []; + return text + .split('\n') + .map((s) => s.trim()) + .filter((s) => s.length > 0); +}; + +const stringifyPretty = (v) => JSON.stringify(v, null, 2); +const stringifyCompact = (v) => JSON.stringify(v); + +const parseRulesJson = (jsonString) => { + try { + const parsed = JSON.parse(jsonString || '[]'); + if (!Array.isArray(parsed)) return []; + return parsed.map((rule, index) => ({ + id: index, + ...(rule || {}), + })); + } catch (e) { + return []; + } +}; + +const rulesToJson = (rules) => { + const payload = (rules || []).map((r) => { + const { id, ...rest } = r || {}; + return rest; + }); + return stringifyPretty(payload); +}; + +const normalizeKeySource = (src) => { + const type = (src?.type || '').trim(); + const key = (src?.key || '').trim(); + const path = (src?.path || '').trim(); + return { type, key, path }; +}; + +const makeUniqueName = (existingNames, baseName) => { + const base = (baseName || '').trim() || 'rule'; + if (!existingNames.has(base)) return base; + for (let i = 2; i < 1000; i++) { + const n = `${base}-${i}`; + if (!existingNames.has(n)) return n; + } + return `${base}-${Date.now()}`; +}; + +const tryParseRulesJsonArray = (jsonString) => { + const raw = jsonString || '[]'; + if (!verifyJSON(raw)) return { ok: false, message: 'Rules JSON is invalid' }; + try { + const parsed = JSON.parse(raw); + if (!Array.isArray(parsed)) + return { ok: false, message: 'Rules JSON must be an array' }; + return { ok: true, value: parsed }; + } catch (e) { + return { ok: false, message: 'Rules JSON is invalid' }; + } +}; + +export default function SettingsChannelAffinity(props) { + const { t } = useTranslation(); + const { Text } = Typography; + const [loading, setLoading] = useState(false); + + const [cacheLoading, setCacheLoading] = useState(false); + const [cacheStats, setCacheStats] = useState({ + enabled: false, + total: 0, + unknown: 0, + by_rule_name: {}, + cache_capacity: 0, + cache_algo: '', + }); + + const [inputs, setInputs] = useState({ + [KEY_ENABLED]: false, + [KEY_SWITCH_ON_SUCCESS]: true, + [KEY_MAX_ENTRIES]: 100000, + [KEY_DEFAULT_TTL]: 3600, + [KEY_RULES]: '[]', + }); + const refForm = useRef(); + const [inputsRow, setInputsRow] = useState(inputs); + const [editMode, setEditMode] = useState('visual'); + const prevEditModeRef = useRef(editMode); + + const [rules, setRules] = useState([]); + const [modalVisible, setModalVisible] = useState(false); + const [editingRule, setEditingRule] = useState(null); + const [isEdit, setIsEdit] = useState(false); + const modalFormRef = useRef(); + const [modalInitValues, setModalInitValues] = useState(null); + const [modalFormKey, setModalFormKey] = useState(0); + const [modalAdvancedActiveKey, setModalAdvancedActiveKey] = useState([]); + + const effectiveDefaultTTLSeconds = + Number(inputs?.[KEY_DEFAULT_TTL] || 0) > 0 + ? Number(inputs?.[KEY_DEFAULT_TTL] || 0) + : 3600; + + const buildModalFormValues = (rule) => { + const r = rule || {}; + return { + name: r.name || '', + model_regex_text: (r.model_regex || []).join('\n'), + path_regex_text: (r.path_regex || []).join('\n'), + user_agent_include_text: (r.user_agent_include || []).join('\n'), + value_regex: r.value_regex || '', + ttl_seconds: Number(r.ttl_seconds || 0), + include_using_group: r.include_using_group ?? true, + include_rule_name: r.include_rule_name ?? true, + }; + }; + + const refreshCacheStats = async () => { + try { + setCacheLoading(true); + const res = await API.get('/api/option/channel_affinity_cache', { + disableDuplicate: true, + }); + const { success, message, data } = res.data; + if (!success) return showError(t(message)); + setCacheStats(data || {}); + } catch (e) { + showError(t('刷新缓存统计失败')); + } finally { + setCacheLoading(false); + } + }; + + const confirmClearAllCache = () => { + Modal.confirm({ + title: t('确认清空全部渠道亲和性缓存'), + content: ( +
+ {t('将删除所有仍在内存中的渠道亲和性缓存条目。')} +
+ ), + onOk: async () => { + const res = await API.delete('/api/option/channel_affinity_cache', { + params: { all: true }, + }); + const { success, message } = res.data; + if (!success) { + showError(t(message)); + return; + } + showSuccess(t('已清空')); + await refreshCacheStats(); + }, + }); + }; + + const confirmClearRuleCache = (rule) => { + const name = (rule?.name || '').trim(); + if (!name) return; + if (!rule?.include_rule_name) { + showWarning( + t('该规则未启用“作用域:包含规则名称”,无法按规则清空缓存。'), + ); + return; + } + Modal.confirm({ + title: t('确认清空该规则缓存'), + content: ( +
+ {t('规则')}: {name} +
+ ), + onOk: async () => { + const res = await API.delete('/api/option/channel_affinity_cache', { + params: { rule_name: name }, + }); + const { success, message } = res.data; + if (!success) { + showError(t(message)); + return; + } + showSuccess(t('已清空')); + await refreshCacheStats(); + }, + }); + }; + + const setRulesJsonToForm = (jsonString) => { + if (!refForm.current) return; + // Use setValue instead of setValues. Semi Form's setValues assigns undefined + // to every registered field not included in the payload, which can wipe other inputs. + refForm.current.setValue(KEY_RULES, jsonString || '[]'); + }; + + const switchToJsonMode = () => { + // Ensure a stable source of truth when entering JSON mode. + // Semi Form may ignore setValues() for an unmounted field, so we seed state first. + const jsonString = rulesToJson(rules); + setInputs((prev) => ({ ...(prev || {}), [KEY_RULES]: jsonString })); + setEditMode('json'); + }; + + const switchToVisualMode = () => { + const validation = tryParseRulesJsonArray(inputs[KEY_RULES] || '[]'); + if (!validation.ok) { + showError(t(validation.message)); + return; + } + setEditMode('visual'); + }; + + const updateRulesState = (nextRules) => { + setRules(nextRules); + const jsonString = rulesToJson(nextRules); + setInputs((prev) => ({ ...prev, [KEY_RULES]: jsonString })); + if (refForm.current && editMode === 'json') { + refForm.current.setValue(KEY_RULES, jsonString); + } + }; + + const appendCodexAndClaudeCodeTemplates = () => { + const doAppend = () => { + const existingNames = new Set( + (rules || []) + .map((r) => (r?.name || '').trim()) + .filter((x) => x.length > 0), + ); + + const templates = [RULE_TEMPLATES.codex, RULE_TEMPLATES.claudeCode].map( + (tpl) => { + const name = makeUniqueName(existingNames, tpl.name); + existingNames.add(name); + return { ...tpl, name }; + }, + ); + + const next = [...(rules || []), ...templates].map((r, idx) => ({ + ...(r || {}), + id: idx, + })); + updateRulesState(next); + showSuccess(t('已填充模版')); + }; + + if ((rules || []).length === 0) { + doAppend(); + return; + } + + Modal.confirm({ + title: t('填充 Codex / Claude Code 模版'), + content: ( +
+ {t('将追加 2 条规则到现有规则列表。')} +
+ ), + onOk: doAppend, + }); + }; + + const ruleColumns = [ + { + title: t('名称'), + dataIndex: 'name', + render: (text) => {text || '-'}, + }, + { + title: t('模型正则'), + dataIndex: 'model_regex', + render: (list) => + (list || []).length > 0 + ? (list || []).slice(0, 3).map((v, idx) => ( + + {v} + + )) + : '-', + }, + { + title: t('路径正则'), + dataIndex: 'path_regex', + render: (list) => + (list || []).length > 0 + ? (list || []).slice(0, 2).map((v, idx) => ( + + {v} + + )) + : '-', + }, + { + title: t('User-Agent include'), + dataIndex: 'user_agent_include', + render: (list) => + (list || []).length > 0 + ? (list || []).slice(0, 2).map((v, idx) => ( + + {v} + + )) + : '-', + }, + { + title: t('Key 来源'), + dataIndex: 'key_sources', + render: (list) => { + const xs = list || []; + if (xs.length === 0) return '-'; + return xs.slice(0, 3).map((src, idx) => { + const s = normalizeKeySource(src); + const detail = s.type === 'gjson' ? s.path : s.key; + return ( + + {s.type}:{detail} + + ); + }); + }, + }, + { + title: t('TTL(秒)'), + dataIndex: 'ttl_seconds', + render: (v) => {Number(v || 0) || '-'}, + }, + { + title: t('缓存条目数'), + render: (_, record) => { + const name = (record?.name || '').trim(); + if (!name || !record?.include_rule_name) { + return N/A; + } + const n = Number(cacheStats?.by_rule_name?.[name] || 0); + return {n}; + }, + }, + { + title: t('作用域'), + render: (_, record) => { + const tags = []; + if (record?.include_using_group) tags.push('分组'); + if (record?.include_rule_name) tags.push('规则'); + if (tags.length === 0) return '-'; + return tags.map((x) => ( + + {x} + + )); + }, + }, + { + title: t('操作'), + render: (_, record) => ( + + + + + + + + + + + {editMode === 'visual' ? ( + + ) : ( + verifyJSON(value || '[]'), + }, + ]} + onChange={(value) => + setInputs({ ...inputs, [KEY_RULES]: value }) + } + /> + )} + + + + + { + setModalVisible(false); + setEditingRule(null); + setModalInitValues(null); + setModalAdvancedActiveKey([]); + }} + onOk={handleModalSave} + okText={t('保存')} + cancelText={t('取消')} + width={720} + > +
{ + modalFormRef.current = formAPI; + }} + > + + setEditingRule((prev) => ({ ...(prev || {}), name: value })) + } + /> + + +
+ + + + + + + + { + const keys = Array.isArray(activeKey) ? activeKey : [activeKey]; + setModalAdvancedActiveKey(keys.filter(Boolean)); + }} + > + + + + + {t( + '可选。匹配入口请求的 User-Agent;任意一行作为子串匹配(忽略大小写)即命中。', + )} +
+ {t( + 'NewAPI 默认不会将入口请求的 User-Agent 透传到上游渠道;该条件仅用于识别访问本站点的客户端。', + )} +
+ {t( + '为保证匹配准确,请确保客户端直连本站点(避免反向代理/网关改写 User-Agent)。', + )} + + } + placeholder={'curl\nPostmanRuntime\nMyApp/…'} + autosize={{ minRows: 3, maxRows: 8 }} + /> + + + + +
+ + + + + {t('该规则的缓存保留时长;0 表示使用默认 TTL:')} + {effectiveDefaultTTLSeconds} + {t(' 秒。')} + + } + /> + + + + + + + + {t( + '开启后,using_group 会参与 cache key(不同分组隔离)。', + )} + + + + + + {t('开启后,规则名称会参与 cache key(不同规则隔离)。')} + + + + + + + + + {t('Key 来源')} + + + + {t( + 'context_int/context_string 从请求上下文读取;gjson 从入口请求的 JSON body 按 gjson path 读取。', + )} + +
+ + {t('常用上下文 Key(用于 context_*)')}: + +
+ {(CONTEXT_KEY_PRESETS || []).map((x) => ( + + {x.label} + + ))} +
+
+ +
( + + updateKeySource( + idx, + isGjson ? { path: value } : { key: value }, + ) + } + /> + ); + }, + }, + { + title: t('操作'), + width: 90, + render: (_, __, idx) => ( +