diff --git a/.env.example b/.env.example index f4b9d02ee..0a64758dd 100644 --- a/.env.example +++ b/.env.example @@ -85,3 +85,8 @@ LINUX_DO_USER_ENDPOINT=https://connect.linux.do/api/user # 节点类型 # 如果是主节点则为master # NODE_TYPE=master + +# 可信任重定向域名列表(逗号分隔,支持子域名匹配) +# 用于验证支付成功/取消回调URL的域名安全性 +# 示例: example.com,myapp.io 将允许 example.com, sub.example.com, myapp.io 等 +# TRUSTED_REDIRECT_DOMAINS=example.com,myapp.io diff --git a/common/init.go b/common/init.go index 9501ce3be..cf62f408e 100644 --- a/common/init.go +++ b/common/init.go @@ -159,4 +159,17 @@ func initConstantEnv() { } constant.TaskPricePatches = taskPricePatches } + + // Initialize trusted redirect domains for URL validation + trustedDomainsStr := GetEnvOrDefaultString("TRUSTED_REDIRECT_DOMAINS", "") + var trustedDomains []string + domains := strings.Split(trustedDomainsStr, ",") + for _, domain := range domains { + trimmedDomain := strings.TrimSpace(domain) + if trimmedDomain != "" { + // Normalize domain to lowercase + trustedDomains = append(trustedDomains, strings.ToLower(trimmedDomain)) + } + } + constant.TrustedRedirectDomains = trustedDomains } diff --git a/common/url_validator.go b/common/url_validator.go new file mode 100644 index 000000000..151f643f1 --- /dev/null +++ b/common/url_validator.go @@ -0,0 +1,39 @@ +package common + +import ( + "fmt" + "net/url" + "strings" + + "github.com/QuantumNous/new-api/constant" +) + +// ValidateRedirectURL validates that a redirect URL is safe to use. +// It checks that: +// - The URL is properly formatted +// - The scheme is either http or https +// - The domain is in the trusted domains list (exact match or subdomain) +// +// Returns nil if the URL is valid and trusted, otherwise returns an error +// describing why the validation failed. +func ValidateRedirectURL(rawURL string) error { + // Parse the URL + parsedURL, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL format: %s", err.Error()) + } + + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return fmt.Errorf("invalid URL scheme: only http and https are allowed") + } + + domain := strings.ToLower(parsedURL.Hostname()) + + for _, trustedDomain := range constant.TrustedRedirectDomains { + if domain == trustedDomain || strings.HasSuffix(domain, "."+trustedDomain) { + return nil + } + } + + return fmt.Errorf("domain %s is not in the trusted domains list", domain) +} diff --git a/common/url_validator_test.go b/common/url_validator_test.go new file mode 100644 index 000000000..b87b6787e --- /dev/null +++ b/common/url_validator_test.go @@ -0,0 +1,134 @@ +package common + +import ( + "testing" + + "github.com/QuantumNous/new-api/constant" +) + +func TestValidateRedirectURL(t *testing.T) { + // Save original trusted domains and restore after test + originalDomains := constant.TrustedRedirectDomains + defer func() { + constant.TrustedRedirectDomains = originalDomains + }() + + tests := []struct { + name string + url string + trustedDomains []string + wantErr bool + errContains string + }{ + // Valid cases + { + name: "exact domain match with https", + url: "https://example.com/success", + trustedDomains: []string{"example.com"}, + wantErr: false, + }, + { + name: "exact domain match with http", + url: "http://example.com/callback", + trustedDomains: []string{"example.com"}, + wantErr: false, + }, + { + name: "subdomain match", + url: "https://sub.example.com/success", + trustedDomains: []string{"example.com"}, + wantErr: false, + }, + { + name: "case insensitive domain", + url: "https://EXAMPLE.COM/success", + trustedDomains: []string{"example.com"}, + wantErr: false, + }, + + // Invalid cases - untrusted domain + { + name: "untrusted domain", + url: "https://evil.com/phishing", + trustedDomains: []string{"example.com"}, + wantErr: true, + errContains: "not in the trusted domains list", + }, + { + name: "suffix attack - fakeexample.com", + url: "https://fakeexample.com/success", + trustedDomains: []string{"example.com"}, + wantErr: true, + errContains: "not in the trusted domains list", + }, + { + name: "empty trusted domains list", + url: "https://example.com/success", + trustedDomains: []string{}, + wantErr: true, + errContains: "not in the trusted domains list", + }, + + // Invalid cases - scheme + { + name: "javascript scheme", + url: "javascript:alert('xss')", + trustedDomains: []string{"example.com"}, + wantErr: true, + errContains: "invalid URL scheme", + }, + { + name: "data scheme", + url: "data:text/html,", + trustedDomains: []string{"example.com"}, + wantErr: true, + errContains: "invalid URL scheme", + }, + + // Edge cases + { + name: "empty URL", + url: "", + trustedDomains: []string{"example.com"}, + wantErr: true, + errContains: "invalid URL scheme", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up trusted domains for this test case + constant.TrustedRedirectDomains = tt.trustedDomains + + err := ValidateRedirectURL(tt.url) + + if tt.wantErr { + if err == nil { + t.Errorf("ValidateRedirectURL(%q) expected error containing %q, got nil", tt.url, tt.errContains) + return + } + if tt.errContains != "" && !contains(err.Error(), tt.errContains) { + t.Errorf("ValidateRedirectURL(%q) error = %q, want error containing %q", tt.url, err.Error(), tt.errContains) + } + } else { + if err != nil { + t.Errorf("ValidateRedirectURL(%q) unexpected error: %v", tt.url, err) + } + } + }) + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > 0 && len(substr) > 0 && findSubstring(s, substr))) +} + +func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/constant/env.go b/constant/env.go index c561c207d..873427a7c 100644 --- a/constant/env.go +++ b/constant/env.go @@ -20,3 +20,7 @@ var TaskQueryLimit int // temporary variable for sora patch, will be removed in future var TaskPricePatches []string + +// TrustedRedirectDomains is a list of trusted domains for redirect URL validation. +// Domains support subdomain matching (e.g., "example.com" matches "sub.example.com"). +var TrustedRedirectDomains []string diff --git a/controller/channel_affinity_cache.go b/controller/channel_affinity_cache.go index bb5cab20a..a72b04b8b 100644 --- a/controller/channel_affinity_cache.go +++ b/controller/channel_affinity_cache.go @@ -58,3 +58,31 @@ func ClearChannelAffinityCache(c *gin.Context) { }, }) } + +func GetChannelAffinityUsageCacheStats(c *gin.Context) { + ruleName := strings.TrimSpace(c.Query("rule_name")) + usingGroup := strings.TrimSpace(c.Query("using_group")) + keyFp := strings.TrimSpace(c.Query("key_fp")) + + if ruleName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "missing param: rule_name", + }) + return + } + if keyFp == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "missing param: key_fp", + }) + return + } + + stats := service.GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": stats, + }) +} diff --git a/controller/performance_unix.go b/controller/performance_unix.go index 3421b6acf..b6ff62d2a 100644 --- a/controller/performance_unix.go +++ b/controller/performance_unix.go @@ -24,10 +24,11 @@ func getDiskSpaceInfo() DiskSpaceInfo { return info } - // 计算磁盘空间 - info.Total = stat.Blocks * uint64(stat.Bsize) - info.Free = stat.Bavail * uint64(stat.Bsize) - info.Used = info.Total - stat.Bfree*uint64(stat.Bsize) + // 计算磁盘空间 (显式转换以兼容 FreeBSD,其字段类型为 int64) + bsize := uint64(stat.Bsize) + info.Total = uint64(stat.Blocks) * bsize + info.Free = uint64(stat.Bavail) * bsize + info.Used = info.Total - uint64(stat.Bfree)*bsize if info.Total > 0 { info.UsedPercent = float64(info.Used) / float64(info.Total) * 100 diff --git a/controller/relay.go b/controller/relay.go index 3a929d8d7..78d21e54f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -311,6 +311,9 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b if openaiErr == nil { return false } + if service.ShouldSkipRetryAfterChannelAffinityFailure(c) { + return false + } if types.IsChannelError(openaiErr) { return true } @@ -514,6 +517,9 @@ func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, if taskErr == nil { return false } + if service.ShouldSkipRetryAfterChannelAffinityFailure(c) { + return false + } if retryTimes <= 0 { return false } diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go index 01479e8a8..e1718cc5e 100644 --- a/controller/topup_stripe.go +++ b/controller/topup_stripe.go @@ -29,9 +29,18 @@ const ( var stripeAdaptor = &StripeAdaptor{} +// StripePayRequest represents a payment request for Stripe checkout. type StripePayRequest struct { - Amount int64 `json:"amount"` + // Amount is the quantity of units to purchase. + Amount int64 `json:"amount"` + // PaymentMethod specifies the payment method (e.g., "stripe"). PaymentMethod string `json:"payment_method"` + // SuccessURL is the optional custom URL to redirect after successful payment. + // If empty, defaults to the server's console log page. + SuccessURL string `json:"success_url,omitempty"` + // CancelURL is the optional custom URL to redirect when payment is canceled. + // If empty, defaults to the server's console topup page. + CancelURL string `json:"cancel_url,omitempty"` } type StripeAdaptor struct { @@ -70,6 +79,16 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) { return } + if req.SuccessURL != "" && common.ValidateRedirectURL(req.SuccessURL) != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": "支付成功重定向URL不在可信任域名列表中", "data": ""}) + return + } + + if req.CancelURL != "" && common.ValidateRedirectURL(req.CancelURL) != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": "支付取消重定向URL不在可信任域名列表中", "data": ""}) + return + } + id := c.GetInt("id") user, _ := model.GetUserById(id, false) chargedMoney := GetChargedAmount(float64(req.Amount), *user) @@ -77,7 +96,7 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) { reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4)) referenceId := "ref_" + common.Sha1([]byte(reference)) - payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount) + payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount, req.SuccessURL, req.CancelURL) if err != nil { log.Println("获取Stripe Checkout支付链接失败", err) c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"}) @@ -237,17 +256,37 @@ func sessionExpired(event stripe.Event) { log.Println("充值订单已过期", referenceId) } -func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) { +// genStripeLink generates a Stripe Checkout session URL for payment. +// It creates a new checkout session with the specified parameters and returns the payment URL. +// +// Parameters: +// - referenceId: unique reference identifier for the transaction +// - customerId: existing Stripe customer ID (empty string if new customer) +// - email: customer email address for new customer creation +// - amount: quantity of units to purchase +// - successURL: custom URL to redirect after successful payment (empty for default) +// - cancelURL: custom URL to redirect when payment is canceled (empty for default) +// +// Returns the checkout session URL or an error if the session creation fails. +func genStripeLink(referenceId string, customerId string, email string, amount int64, successURL string, cancelURL string) (string, error) { if !strings.HasPrefix(setting.StripeApiSecret, "sk_") && !strings.HasPrefix(setting.StripeApiSecret, "rk_") { return "", fmt.Errorf("无效的Stripe API密钥") } stripe.Key = setting.StripeApiSecret + // Use custom URLs if provided, otherwise use defaults + if successURL == "" { + successURL = system_setting.ServerAddress + "/console/log" + } + if cancelURL == "" { + cancelURL = system_setting.ServerAddress + "/console/topup" + } + params := &stripe.CheckoutSessionParams{ ClientReferenceID: stripe.String(referenceId), - SuccessURL: stripe.String(system_setting.ServerAddress + "/console/log"), - CancelURL: stripe.String(system_setting.ServerAddress + "/console/topup"), + SuccessURL: stripe.String(successURL), + CancelURL: stripe.String(cancelURL), LineItems: []*stripe.CheckoutSessionLineItemParams{ { Price: stripe.String(setting.StripePriceId), diff --git a/dto/gemini.go b/dto/gemini.go index 17881c521..b330f8b1b 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -449,11 +449,12 @@ type GeminiChatResponse struct { } type GeminiUsageMetadata struct { - PromptTokenCount int `json:"promptTokenCount"` - CandidatesTokenCount int `json:"candidatesTokenCount"` - TotalTokenCount int `json:"totalTokenCount"` - ThoughtsTokenCount int `json:"thoughtsTokenCount"` - PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"` + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + ThoughtsTokenCount int `json:"thoughtsTokenCount"` + CachedContentTokenCount int `json:"cachedContentTokenCount"` + PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"` } type GeminiPromptTokensDetails struct { diff --git a/main.go b/main.go index 0964530e1..47f966da2 100644 --- a/main.go +++ b/main.go @@ -19,7 +19,7 @@ import ( "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/router" "github.com/QuantumNous/new-api/service" - _ "github.com/QuantumNous/new-api/setting/performance_setting" // 注册性能设置 + _ "github.com/QuantumNous/new-api/setting/performance_setting" "github.com/QuantumNous/new-api/setting/ratio_setting" "github.com/bytedance/gopkg/util/gopool" diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index cd9d06db2..39485b16f 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -49,6 +49,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re } usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount + usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { if detail.Modality == "AUDIO" { diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index da114b64f..8edbe6d8a 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -1251,6 +1251,7 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount + usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { if detail.Modality == "AUDIO" { usage.PromptTokensDetails.AudioTokens = detail.TokenCount @@ -1395,6 +1396,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount, } usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount + usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { if detail.Modality == "AUDIO" { usage.PromptTokensDetails.AudioTokens = detail.TokenCount @@ -1447,6 +1449,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R } usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount + usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go index 6051c7e8b..6ebecb3c0 100644 --- a/relay/channel/task/doubao/adaptor.go +++ b/relay/channel/task/doubao/adaptor.go @@ -30,6 +30,7 @@ type ContentItem struct { Text string `json:"text,omitempty"` // for text type ImageURL *ImageURL `json:"image_url,omitempty"` // for image_url type Video *VideoReference `json:"video,omitempty"` // for video (sample) type + Role string `json:"role,omitempty"` // reference_image / first_frame / last_frame } type ImageURL struct { diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index b2706730d..74abfe5b2 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -219,6 +219,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types } func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent ...string) { + originUsage := usage if usage == nil { usage = &dto.Usage{ PromptTokens: relayInfo.GetEstimatePromptTokens(), @@ -228,6 +229,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage extraContent = append(extraContent, "上游无计费信息") } + if originUsage != nil { + service.ObserveChannelAffinityUsageCacheFromContext(ctx, usage) + } + adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason) useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() diff --git a/relay/relay_task.go b/relay/relay_task.go index 61588f93b..ebbd1f65d 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -144,7 +144,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto. if !success { defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName] if !ok { - modelPrice = 0.1 + modelPrice = float64(common.PreConsumedQuota) / common.QuotaPerUnit } else { modelPrice = defaultPrice } diff --git a/router/api-router.go b/router/api-router.go index 5587ac7ad..50c817f49 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -251,6 +251,7 @@ func SetApiRouter(router *gin.Engine) { logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs) logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat) logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat) + logRoute.GET("/channel_affinity_usage_cache", middleware.AdminAuth(), controller.GetChannelAffinityUsageCacheStats) logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs) logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs) diff --git a/service/channel_affinity.go b/service/channel_affinity.go index 5aa50adb6..a94eb29e8 100644 --- a/service/channel_affinity.go +++ b/service/channel_affinity.go @@ -2,6 +2,7 @@ package service import ( "fmt" + "hash/fnv" "regexp" "strconv" "strings" @@ -9,6 +10,7 @@ import ( "time" "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/pkg/cachex" "github.com/QuantumNous/new-api/setting/operation_setting" "github.com/gin-gonic/gin" @@ -21,14 +23,19 @@ const ( ginKeyChannelAffinityTTLSeconds = "channel_affinity_ttl_seconds" ginKeyChannelAffinityMeta = "channel_affinity_meta" ginKeyChannelAffinityLogInfo = "channel_affinity_log_info" + ginKeyChannelAffinitySkipRetry = "channel_affinity_skip_retry_on_failure" - channelAffinityCacheNamespace = "new-api:channel_affinity:v1" + channelAffinityCacheNamespace = "new-api:channel_affinity:v1" + channelAffinityUsageCacheStatsNamespace = "new-api:channel_affinity_usage_cache_stats:v1" ) var ( channelAffinityCacheOnce sync.Once channelAffinityCache *cachex.HybridCache[int] + channelAffinityUsageCacheStatsOnce sync.Once + channelAffinityUsageCacheStatsCache *cachex.HybridCache[ChannelAffinityUsageCacheCounters] + channelAffinityRegexCache sync.Map // map[string]*regexp.Regexp ) @@ -36,15 +43,24 @@ type channelAffinityMeta struct { CacheKey string TTLSeconds int RuleName string + SkipRetry bool KeySourceType string KeySourceKey string KeySourcePath string + KeyHint string KeyFingerprint string UsingGroup string ModelName string RequestPath string } +type ChannelAffinityStatsContext struct { + RuleName string + UsingGroup string + KeyFingerprint string + TTLSeconds int64 +} + type ChannelAffinityCacheStats struct { Enabled bool `json:"enabled"` Total int `json:"total"` @@ -338,6 +354,32 @@ func getChannelAffinityMeta(c *gin.Context) (channelAffinityMeta, bool) { return meta, true } +func GetChannelAffinityStatsContext(c *gin.Context) (ChannelAffinityStatsContext, bool) { + if c == nil { + return ChannelAffinityStatsContext{}, false + } + meta, ok := getChannelAffinityMeta(c) + if !ok { + return ChannelAffinityStatsContext{}, false + } + ruleName := strings.TrimSpace(meta.RuleName) + keyFp := strings.TrimSpace(meta.KeyFingerprint) + usingGroup := strings.TrimSpace(meta.UsingGroup) + if ruleName == "" || keyFp == "" { + return ChannelAffinityStatsContext{}, false + } + ttlSeconds := int64(meta.TTLSeconds) + if ttlSeconds <= 0 { + return ChannelAffinityStatsContext{}, false + } + return ChannelAffinityStatsContext{ + RuleName: ruleName, + UsingGroup: usingGroup, + KeyFingerprint: keyFp, + TTLSeconds: ttlSeconds, + }, true +} + func affinityFingerprint(s string) string { if s == "" { return "" @@ -349,6 +391,19 @@ func affinityFingerprint(s string) string { return hex } +func buildChannelAffinityKeyHint(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "" + } + s = strings.ReplaceAll(s, "\n", " ") + s = strings.ReplaceAll(s, "\r", " ") + if len(s) <= 12 { + return s + } + return s[:4] + "..." + s[len(s)-4:] +} + func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup string) (int, bool) { setting := operation_setting.GetChannelAffinitySetting() if setting == nil || !setting.Enabled { @@ -399,9 +454,11 @@ func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup CacheKey: cacheKeyFull, TTLSeconds: ttlSeconds, RuleName: rule.Name, + SkipRetry: rule.SkipRetryOnFailure, KeySourceType: strings.TrimSpace(usedSource.Type), KeySourceKey: strings.TrimSpace(usedSource.Key), KeySourcePath: strings.TrimSpace(usedSource.Path), + KeyHint: buildChannelAffinityKeyHint(affinityValue), KeyFingerprint: affinityFingerprint(affinityValue), UsingGroup: usingGroup, ModelName: modelName, @@ -422,6 +479,21 @@ func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup return 0, false } +func ShouldSkipRetryAfterChannelAffinityFailure(c *gin.Context) bool { + if c == nil { + return false + } + v, ok := c.Get(ginKeyChannelAffinitySkipRetry) + if !ok { + return false + } + b, ok := v.(bool) + if !ok { + return false + } + return b +} + func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) { if c == nil || channelID <= 0 { return @@ -430,6 +502,7 @@ func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int if !ok { return } + c.Set(ginKeyChannelAffinitySkipRetry, meta.SkipRetry) info := map[string]interface{}{ "reason": meta.RuleName, "rule_name": meta.RuleName, @@ -441,6 +514,7 @@ func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int "key_source": meta.KeySourceType, "key_key": meta.KeySourceKey, "key_path": meta.KeySourcePath, + "key_hint": meta.KeyHint, "key_fp": meta.KeyFingerprint, } c.Set(ginKeyChannelAffinityLogInfo, info) @@ -485,3 +559,225 @@ func RecordChannelAffinity(c *gin.Context, channelID int) { common.SysError(fmt.Sprintf("channel affinity cache set failed: key=%s, err=%v", cacheKey, err)) } } + +type ChannelAffinityUsageCacheStats struct { + RuleName string `json:"rule_name"` + UsingGroup string `json:"using_group"` + KeyFingerprint string `json:"key_fp"` + + Hit int64 `json:"hit"` + Total int64 `json:"total"` + WindowSeconds int64 `json:"window_seconds"` + + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + CachedTokens int64 `json:"cached_tokens"` + PromptCacheHitTokens int64 `json:"prompt_cache_hit_tokens"` + LastSeenAt int64 `json:"last_seen_at"` +} + +type ChannelAffinityUsageCacheCounters struct { + Hit int64 `json:"hit"` + Total int64 `json:"total"` + WindowSeconds int64 `json:"window_seconds"` + + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` + CachedTokens int64 `json:"cached_tokens"` + PromptCacheHitTokens int64 `json:"prompt_cache_hit_tokens"` + LastSeenAt int64 `json:"last_seen_at"` +} + +var channelAffinityUsageCacheStatsLocks [64]sync.Mutex + +func ObserveChannelAffinityUsageCacheFromContext(c *gin.Context, usage *dto.Usage) { + statsCtx, ok := GetChannelAffinityStatsContext(c) + if !ok { + return + } + observeChannelAffinityUsageCache(statsCtx, usage) +} + +func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) ChannelAffinityUsageCacheStats { + ruleName = strings.TrimSpace(ruleName) + usingGroup = strings.TrimSpace(usingGroup) + keyFp = strings.TrimSpace(keyFp) + + entryKey := channelAffinityUsageCacheEntryKey(ruleName, usingGroup, keyFp) + if entryKey == "" { + return ChannelAffinityUsageCacheStats{ + RuleName: ruleName, + UsingGroup: usingGroup, + KeyFingerprint: keyFp, + } + } + + cache := getChannelAffinityUsageCacheStatsCache() + v, found, err := cache.Get(entryKey) + if err != nil || !found { + return ChannelAffinityUsageCacheStats{ + RuleName: ruleName, + UsingGroup: usingGroup, + KeyFingerprint: keyFp, + } + } + return ChannelAffinityUsageCacheStats{ + RuleName: ruleName, + UsingGroup: usingGroup, + KeyFingerprint: keyFp, + Hit: v.Hit, + Total: v.Total, + WindowSeconds: v.WindowSeconds, + PromptTokens: v.PromptTokens, + CompletionTokens: v.CompletionTokens, + TotalTokens: v.TotalTokens, + CachedTokens: v.CachedTokens, + PromptCacheHitTokens: v.PromptCacheHitTokens, + LastSeenAt: v.LastSeenAt, + } +} + +func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usage *dto.Usage) { + entryKey := channelAffinityUsageCacheEntryKey(statsCtx.RuleName, statsCtx.UsingGroup, statsCtx.KeyFingerprint) + if entryKey == "" { + return + } + + windowSeconds := statsCtx.TTLSeconds + if windowSeconds <= 0 { + return + } + + cache := getChannelAffinityUsageCacheStatsCache() + ttl := time.Duration(windowSeconds) * time.Second + + lock := channelAffinityUsageCacheStatsLock(entryKey) + lock.Lock() + defer lock.Unlock() + + prev, found, err := cache.Get(entryKey) + if err != nil { + return + } + next := prev + if !found { + next = ChannelAffinityUsageCacheCounters{} + } + next.Total++ + hit, cachedTokens, promptCacheHitTokens := usageCacheSignals(usage) + if hit { + next.Hit++ + } + next.WindowSeconds = windowSeconds + next.LastSeenAt = time.Now().Unix() + next.CachedTokens += cachedTokens + next.PromptCacheHitTokens += promptCacheHitTokens + next.PromptTokens += int64(usagePromptTokens(usage)) + next.CompletionTokens += int64(usageCompletionTokens(usage)) + next.TotalTokens += int64(usageTotalTokens(usage)) + _ = cache.SetWithTTL(entryKey, next, ttl) +} + +func channelAffinityUsageCacheEntryKey(ruleName, usingGroup, keyFp string) string { + ruleName = strings.TrimSpace(ruleName) + usingGroup = strings.TrimSpace(usingGroup) + keyFp = strings.TrimSpace(keyFp) + if ruleName == "" || keyFp == "" { + return "" + } + return ruleName + "\n" + usingGroup + "\n" + keyFp +} + +func usageCacheSignals(usage *dto.Usage) (hit bool, cachedTokens int64, promptCacheHitTokens int64) { + if usage == nil { + return false, 0, 0 + } + + cached := int64(0) + if usage.PromptTokensDetails.CachedTokens > 0 { + cached = int64(usage.PromptTokensDetails.CachedTokens) + } else if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 { + cached = int64(usage.InputTokensDetails.CachedTokens) + } + pcht := int64(0) + if usage.PromptCacheHitTokens > 0 { + pcht = int64(usage.PromptCacheHitTokens) + } + return cached > 0 || pcht > 0, cached, pcht +} + +func usagePromptTokens(usage *dto.Usage) int { + if usage == nil { + return 0 + } + if usage.PromptTokens > 0 { + return usage.PromptTokens + } + return usage.InputTokens +} + +func usageCompletionTokens(usage *dto.Usage) int { + if usage == nil { + return 0 + } + if usage.CompletionTokens > 0 { + return usage.CompletionTokens + } + return usage.OutputTokens +} + +func usageTotalTokens(usage *dto.Usage) int { + if usage == nil { + return 0 + } + if usage.TotalTokens > 0 { + return usage.TotalTokens + } + pt := usagePromptTokens(usage) + ct := usageCompletionTokens(usage) + if pt > 0 || ct > 0 { + return pt + ct + } + return 0 +} + +func getChannelAffinityUsageCacheStatsCache() *cachex.HybridCache[ChannelAffinityUsageCacheCounters] { + channelAffinityUsageCacheStatsOnce.Do(func() { + setting := operation_setting.GetChannelAffinitySetting() + capacity := 100_000 + defaultTTLSeconds := 3600 + if setting != nil { + if setting.MaxEntries > 0 { + capacity = setting.MaxEntries + } + if setting.DefaultTTLSeconds > 0 { + defaultTTLSeconds = setting.DefaultTTLSeconds + } + } + + channelAffinityUsageCacheStatsCache = cachex.NewHybridCache[ChannelAffinityUsageCacheCounters](cachex.HybridCacheConfig[ChannelAffinityUsageCacheCounters]{ + Namespace: cachex.Namespace(channelAffinityUsageCacheStatsNamespace), + Redis: common.RDB, + RedisEnabled: func() bool { + return common.RedisEnabled && common.RDB != nil + }, + RedisCodec: cachex.JSONCodec[ChannelAffinityUsageCacheCounters]{}, + Memory: func() *hot.HotCache[string, ChannelAffinityUsageCacheCounters] { + return hot.NewHotCache[string, ChannelAffinityUsageCacheCounters](hot.LRU, capacity). + WithTTL(time.Duration(defaultTTLSeconds) * time.Second). + WithJanitor(). + Build() + }, + }) + }) + return channelAffinityUsageCacheStatsCache +} + +func channelAffinityUsageCacheStatsLock(key string) *sync.Mutex { + h := fnv.New32a() + _, _ = h.Write([]byte(key)) + idx := h.Sum32() % uint32(len(channelAffinityUsageCacheStatsLocks)) + return &channelAffinityUsageCacheStatsLocks[idx] +} diff --git a/setting/operation_setting/channel_affinity_setting.go b/setting/operation_setting/channel_affinity_setting.go index f95ac6969..7173f7b78 100644 --- a/setting/operation_setting/channel_affinity_setting.go +++ b/setting/operation_setting/channel_affinity_setting.go @@ -18,6 +18,8 @@ type ChannelAffinityRule struct { ValueRegex string `json:"value_regex"` TTLSeconds int `json:"ttl_seconds"` + SkipRetryOnFailure bool `json:"skip_retry_on_failure,omitempty"` + IncludeUsingGroup bool `json:"include_using_group"` IncludeRuleName bool `json:"include_rule_name"` } diff --git a/setting/ratio_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go index cf54cb313..665c2f593 100644 --- a/setting/ratio_setting/cache_ratio.go +++ b/setting/ratio_setting/cache_ratio.go @@ -8,6 +8,8 @@ import ( ) var defaultCacheRatio = map[string]float64{ + "gemini-3-flash-preview": 0.25, + "gemini-3-pro-preview": 0.25, "gpt-4": 0.5, "o1": 0.5, "o1-2024-12-17": 0.5, diff --git a/web/src/components/playground/CodeViewer.jsx b/web/src/components/playground/CodeViewer.jsx index 9d8ae453a..ce21d43cc 100644 --- a/web/src/components/playground/CodeViewer.jsx +++ b/web/src/components/playground/CodeViewer.jsx @@ -106,6 +106,21 @@ const highlightJson = (str) => { ); }; +const linkRegex = /(https?:\/\/[^\s<"'\]),;}]+)/g; + +const linkifyHtml = (html) => { + const parts = html.split(/(<[^>]+>)/g); + return parts + .map((part) => { + if (part.startsWith('<')) return part; + return part.replace( + linkRegex, + (url) => `${url}`, + ); + }) + .join(''); +}; + const isJsonLike = (content, language) => { if (language === 'json') return true; const trimmed = content.trim(); @@ -179,6 +194,10 @@ const CodeViewer = ({ content, title, language = 'json' }) => { return displayContent; }, [displayContent, language, contentMetrics.isVeryLarge, isExpanded]); + const renderedContent = useMemo(() => { + return linkifyHtml(highlightedContent); + }, [highlightedContent]); + const handleCopy = useCallback(async () => { try { const textToCopy = @@ -276,6 +295,8 @@ const CodeViewer = ({ content, title, language = 'json' }) => { style={{ ...codeThemeStyles.content, paddingTop: contentPadding, + whiteSpace: 'pre-wrap', + wordBreak: 'break-word', }} className='model-settings-scroll' > @@ -303,7 +324,7 @@ const CodeViewer = ({ content, title, language = 'json' }) => { {t('正在处理大内容...')} ) : ( -
+ )} diff --git a/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx b/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx index 96f2214a6..941f47004 100644 --- a/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx +++ b/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx @@ -20,6 +20,7 @@ For commercial licensing, please contact support@quantumnous.com import React from 'react'; import { Avatar, + Button, Space, Tag, Tooltip, @@ -71,6 +72,34 @@ function formatRatio(ratio) { return String(ratio); } +function buildChannelAffinityTooltip(affinity, t) { + if (!affinity) { + return null; + } + + const keySource = affinity.key_source || '-'; + const keyPath = affinity.key_path || affinity.key_key || '-'; + const keyHint = affinity.key_hint || ''; + const keyFp = affinity.key_fp ? `#${affinity.key_fp}` : ''; + const keyText = `${keySource}:${keyPath}${keyFp}`; + + const lines = [ + t('渠道亲和性'), + `${t('规则')}:${affinity.rule_name || '-'}`, + `${t('分组')}:${affinity.selected_group || '-'}`, + `${t('Key')}:${keyText}`, + ...(keyHint ? [`${t('Key 摘要')}:${keyHint}`] : []), + ]; + + return ( +