diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml index c87fcfceb..3e3ddc53b 100644 --- a/.github/workflows/linux-release.yml +++ b/.github/workflows/linux-release.yml @@ -38,21 +38,21 @@ jobs: - name: Build Backend (amd64) run: | go mod download - go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api + go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o new-api - name: Build Backend (arm64) run: | sudo apt-get update DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu - CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-arm64 + CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o new-api-arm64 - name: Release uses: softprops/action-gh-release@v1 if: startsWith(github.ref, 'refs/tags/') with: files: | - one-api - one-api-arm64 + new-api + new-api-arm64 draft: true generate_release_notes: true env: diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml index 1bc786ac0..8eaf2d67a 100644 --- a/.github/workflows/macos-release.yml +++ b/.github/workflows/macos-release.yml @@ -39,12 +39,12 @@ jobs: - name: Build Backend run: | go mod download - go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos + go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o new-api-macos - name: Release uses: softprops/action-gh-release@v1 if: startsWith(github.ref, 'refs/tags/') with: - files: one-api-macos + files: new-api-macos draft: true generate_release_notes: true env: diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml index de3d83d5e..30e864f34 100644 --- a/.github/workflows/windows-release.yml +++ b/.github/workflows/windows-release.yml @@ -41,12 +41,12 @@ jobs: - name: Build Backend run: | go mod download - go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe + go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o new-api.exe - name: Release uses: softprops/action-gh-release@v1 if: startsWith(github.ref, 'refs/tags/') with: - files: one-api.exe + files: new-api.exe draft: true generate_release_notes: true env: diff --git a/common/ip.go b/common/ip.go new file mode 100644 index 000000000..bfb64ee7f --- /dev/null +++ b/common/ip.go @@ -0,0 +1,22 @@ +package common + +import "net" + +func IsPrivateIP(ip net.IP) bool { + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + + private := []net.IPNet{ + {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, + {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, + {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, + } + + for _, privateNet := range private { + if privateNet.Contains(ip) { + return true + } + } + return false +} diff --git a/common/ssrf_protection.go b/common/ssrf_protection.go new file mode 100644 index 000000000..6f7d289f1 --- /dev/null +++ b/common/ssrf_protection.go @@ -0,0 +1,327 @@ +package common + +import ( + "fmt" + "net" + "net/url" + "strconv" + "strings" +) + +// SSRFProtection SSRF防护配置 +type SSRFProtection struct { + AllowPrivateIp bool + DomainFilterMode bool // true: 白名单, false: 黑名单 + DomainList []string // domain format, e.g. example.com, *.example.com + IpFilterMode bool // true: 白名单, false: 黑名单 + IpList []string // CIDR or single IP + AllowedPorts []int // 允许的端口范围 + ApplyIPFilterForDomain bool // 对域名启用IP过滤 +} + +// DefaultSSRFProtection 默认SSRF防护配置 +var DefaultSSRFProtection = &SSRFProtection{ + AllowPrivateIp: false, + DomainFilterMode: true, + DomainList: []string{}, + IpFilterMode: true, + IpList: []string{}, + AllowedPorts: []int{}, +} + +// isPrivateIP 检查IP是否为私有地址 +func isPrivateIP(ip net.IP) bool { + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + + // 检查私有网段 + private := []net.IPNet{ + {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8 + {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12 + {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16 + {IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8 + {IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地) + {IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播) + {IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留) + } + + for _, privateNet := range private { + if privateNet.Contains(ip) { + return true + } + } + + // 检查IPv6私有地址 + if ip.To4() == nil { + // IPv6 loopback + if ip.Equal(net.IPv6loopback) { + return true + } + // IPv6 link-local + if strings.HasPrefix(ip.String(), "fe80:") { + return true + } + // IPv6 unique local + if strings.HasPrefix(ip.String(), "fc") || strings.HasPrefix(ip.String(), "fd") { + return true + } + } + + return false +} + +// parsePortRanges 解析端口范围配置 +// 支持格式: "80", "443", "8000-9000" +func parsePortRanges(portConfigs []string) ([]int, error) { + var ports []int + + for _, config := range portConfigs { + config = strings.TrimSpace(config) + if config == "" { + continue + } + + if strings.Contains(config, "-") { + // 处理端口范围 "8000-9000" + parts := strings.Split(config, "-") + if len(parts) != 2 { + return nil, fmt.Errorf("invalid port range format: %s", config) + } + + startPort, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err != nil { + return nil, fmt.Errorf("invalid start port in range %s: %v", config, err) + } + + endPort, err := strconv.Atoi(strings.TrimSpace(parts[1])) + if err != nil { + return nil, fmt.Errorf("invalid end port in range %s: %v", config, err) + } + + if startPort > endPort { + return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", config) + } + + if startPort < 1 || startPort > 65535 || endPort < 1 || endPort > 65535 { + return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", config) + } + + // 添加范围内的所有端口 + for port := startPort; port <= endPort; port++ { + ports = append(ports, port) + } + } else { + // 处理单个端口 "80" + port, err := strconv.Atoi(config) + if err != nil { + return nil, fmt.Errorf("invalid port number: %s", config) + } + + if port < 1 || port > 65535 { + return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", port) + } + + ports = append(ports, port) + } + } + + return ports, nil +} + +// isAllowedPort 检查端口是否被允许 +func (p *SSRFProtection) isAllowedPort(port int) bool { + if len(p.AllowedPorts) == 0 { + return true // 如果没有配置端口限制,则允许所有端口 + } + + for _, allowedPort := range p.AllowedPorts { + if port == allowedPort { + return true + } + } + return false +} + +// isDomainWhitelisted 检查域名是否在白名单中 +func isDomainListed(domain string, list []string) bool { + if len(list) == 0 { + return false + } + + domain = strings.ToLower(domain) + for _, item := range list { + item = strings.ToLower(strings.TrimSpace(item)) + if item == "" { + continue + } + // 精确匹配 + if domain == item { + return true + } + // 通配符匹配 (*.example.com) + if strings.HasPrefix(item, "*.") { + suffix := strings.TrimPrefix(item, "*.") + if strings.HasSuffix(domain, "."+suffix) || domain == suffix { + return true + } + } + } + return false +} + +func (p *SSRFProtection) isDomainAllowed(domain string) bool { + listed := isDomainListed(domain, p.DomainList) + if p.DomainFilterMode { // 白名单 + return listed + } + // 黑名单 + return !listed +} + +// isIPWhitelisted 检查IP是否在白名单中 + +func isIPListed(ip net.IP, list []string) bool { + if len(list) == 0 { + return false + } + + for _, whitelistCIDR := range list { + _, network, err := net.ParseCIDR(whitelistCIDR) + if err != nil { + // 尝试作为单个IP处理 + if whitelistIP := net.ParseIP(whitelistCIDR); whitelistIP != nil { + if ip.Equal(whitelistIP) { + return true + } + } + continue + } + + if network.Contains(ip) { + return true + } + } + return false +} + +// IsIPAccessAllowed 检查IP是否允许访问 +func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool { + // 私有IP限制 + if isPrivateIP(ip) && !p.AllowPrivateIp { + return false + } + + listed := isIPListed(ip, p.IpList) + if p.IpFilterMode { // 白名单 + return listed + } + // 黑名单 + return !listed +} + +// ValidateURL 验证URL是否安全 +func (p *SSRFProtection) ValidateURL(urlStr string) error { + // 解析URL + u, err := url.Parse(urlStr) + if err != nil { + return fmt.Errorf("invalid URL format: %v", err) + } + + // 只允许HTTP/HTTPS协议 + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme) + } + + // 解析主机和端口 + host, portStr, err := net.SplitHostPort(u.Host) + if err != nil { + // 没有端口,使用默认端口 + host = u.Hostname() + if u.Scheme == "https" { + portStr = "443" + } else { + portStr = "80" + } + } + + // 验证端口 + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("invalid port: %s", portStr) + } + + if !p.isAllowedPort(port) { + return fmt.Errorf("port %d is not allowed", port) + } + + // 如果 host 是 IP,则跳过域名检查 + if ip := net.ParseIP(host); ip != nil { + if !p.IsIPAccessAllowed(ip) { + if isPrivateIP(ip) { + return fmt.Errorf("private IP address not allowed: %s", ip.String()) + } + if p.IpFilterMode { + return fmt.Errorf("ip not in whitelist: %s", ip.String()) + } + return fmt.Errorf("ip in blacklist: %s", ip.String()) + } + return nil + } + + // 先进行域名过滤 + if !p.isDomainAllowed(host) { + if p.DomainFilterMode { + return fmt.Errorf("domain not in whitelist: %s", host) + } + return fmt.Errorf("domain in blacklist: %s", host) + } + + // 若未启用对域名应用IP过滤,则到此通过 + if !p.ApplyIPFilterForDomain { + return nil + } + + // 解析域名对应IP并检查 + ips, err := net.LookupIP(host) + if err != nil { + return fmt.Errorf("DNS resolution failed for %s: %v", host, err) + } + for _, ip := range ips { + if !p.IsIPAccessAllowed(ip) { + if isPrivateIP(ip) && !p.AllowPrivateIp { + return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String()) + } + if p.IpFilterMode { + return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String()) + } + return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String()) + } + } + return nil +} + +// ValidateURLWithFetchSetting 使用FetchSetting配置验证URL +func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error { + // 如果SSRF防护被禁用,直接返回成功 + if !enableSSRFProtection { + return nil + } + + // 解析端口范围配置 + allowedPortInts, err := parsePortRanges(allowedPorts) + if err != nil { + return fmt.Errorf("request reject - invalid port configuration: %v", err) + } + + protection := &SSRFProtection{ + AllowPrivateIp: allowPrivateIp, + DomainFilterMode: domainFilterMode, + DomainList: domainList, + IpFilterMode: ipFilterMode, + IpList: ipList, + AllowedPorts: allowedPortInts, + ApplyIPFilterForDomain: applyIPFilterForDomain, + } + return protection.ValidateURL(urlStr) +} diff --git a/constant/task.go b/constant/task.go index 21790145b..e174fd60e 100644 --- a/constant/task.go +++ b/constant/task.go @@ -11,8 +11,10 @@ const ( SunoActionMusic = "MUSIC" SunoActionLyrics = "LYRICS" - TaskActionGenerate = "generate" - TaskActionTextGenerate = "textGenerate" + TaskActionGenerate = "generate" + TaskActionTextGenerate = "textGenerate" + TaskActionFirstTailGenerate = "firstTailGenerate" + TaskActionReferenceGenerate = "referenceGenerate" ) var SunoModel2Action = map[string]string{ diff --git a/controller/channel-test.go b/controller/channel-test.go index 5a668c488..9ea6eed75 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -90,6 +90,11 @@ func testChannel(channel *model.Channel, testModel string) testResult { requestPath = "/v1/embeddings" // 修改请求路径 } + // VolcEngine 图像生成模型 + if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { + requestPath = "/v1/images/generations" + } + c.Request = &http.Request{ Method: "POST", URL: &url.URL{Path: requestPath}, // 使用动态路径 @@ -109,6 +114,21 @@ func testChannel(channel *model.Channel, testModel string) testResult { } } + // 重新检查模型类型并更新请求路径 + if strings.Contains(strings.ToLower(testModel), "embedding") || + strings.HasPrefix(testModel, "m3e") || + strings.Contains(testModel, "bge-") || + strings.Contains(testModel, "embed") || + channel.Type == constant.ChannelTypeMokaAI { + requestPath = "/v1/embeddings" + c.Request.URL.Path = requestPath + } + + if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { + requestPath = "/v1/images/generations" + c.Request.URL.Path = requestPath + } + cache, err := model.GetUserCache(1) if err != nil { return testResult{ @@ -140,6 +160,9 @@ func testChannel(channel *model.Channel, testModel string) testResult { if c.Request.URL.Path == "/v1/embeddings" { relayFormat = types.RelayFormatEmbedding } + if c.Request.URL.Path == "/v1/images/generations" { + relayFormat = types.RelayFormatOpenAIImage + } info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil) @@ -201,6 +224,22 @@ func testChannel(channel *model.Channel, testModel string) testResult { } // 调用专门用于 Embedding 的转换函数 convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest) + } else if info.RelayMode == relayconstant.RelayModeImagesGenerations { + // 创建一个 ImageRequest + prompt := "cat" + if request.Prompt != nil { + if promptStr, ok := request.Prompt.(string); ok && promptStr != "" { + prompt = promptStr + } + } + imageRequest := dto.ImageRequest{ + Prompt: prompt, + Model: request.Model, + N: uint(request.N), + Size: request.Size, + } + // 调用专门用于图像生成的转换函数 + convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest) } else { // 对其他所有请求类型(如 Chat),保持原有逻辑 convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request) diff --git a/controller/channel.go b/controller/channel.go index 17154ab0f..480d5b4f3 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -188,6 +188,8 @@ func FetchUpstreamModels(c *gin.Context) { url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader case constant.ChannelTypeAli: url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL) + case constant.ChannelTypeZhipu_v4: + url = fmt.Sprintf("%s/api/paas/v4/models", baseURL) default: url = fmt.Sprintf("%s/v1/models", baseURL) } @@ -1101,8 +1103,8 @@ func CopyChannel(c *gin.Context) { // MultiKeyManageRequest represents the request for multi-key management operations type MultiKeyManageRequest struct { ChannelId int `json:"channel_id"` - Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status" - KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions + Action string `json:"action"` // "disable_key", "enable_key", "delete_key", "delete_disabled_keys", "get_key_status" + KeyIndex *int `json:"key_index,omitempty"` // for disable_key, enable_key, and delete_key actions Page int `json:"page,omitempty"` // for get_key_status pagination PageSize int `json:"page_size,omitempty"` // for get_key_status pagination Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all @@ -1430,6 +1432,86 @@ func ManageMultiKeys(c *gin.Context) { }) return + case "delete_key": + if request.KeyIndex == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "未指定要删除的密钥索引", + }) + return + } + + keyIndex := *request.KeyIndex + if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "密钥索引超出范围", + }) + return + } + + keys := channel.GetKeys() + var remainingKeys []string + var newStatusList = make(map[int]int) + var newDisabledTime = make(map[int]int64) + var newDisabledReason = make(map[int]string) + + newIndex := 0 + for i, key := range keys { + // 跳过要删除的密钥 + if i == keyIndex { + continue + } + + remainingKeys = append(remainingKeys, key) + + // 保留其他密钥的状态信息,重新索引 + if channel.ChannelInfo.MultiKeyStatusList != nil { + if status, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists && status != 1 { + newStatusList[newIndex] = status + } + } + if channel.ChannelInfo.MultiKeyDisabledTime != nil { + if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists { + newDisabledTime[newIndex] = t + } + } + if channel.ChannelInfo.MultiKeyDisabledReason != nil { + if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists { + newDisabledReason[newIndex] = r + } + } + newIndex++ + } + + if len(remainingKeys) == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "不能删除最后一个密钥", + }) + return + } + + // Update channel with remaining keys + channel.Key = strings.Join(remainingKeys, "\n") + channel.ChannelInfo.MultiKeySize = len(remainingKeys) + channel.ChannelInfo.MultiKeyStatusList = newStatusList + channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime + channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason + + err = channel.Update() + if err != nil { + common.ApiError(c, err) + return + } + + model.InitChannelCache() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "密钥已删除", + }) + return + case "delete_disabled_keys": keys := channel.GetKeys() var remainingKeys []string diff --git a/controller/option.go b/controller/option.go index 3e59c68e0..7d1c676f5 100644 --- a/controller/option.go +++ b/controller/option.go @@ -129,7 +129,7 @@ func UpdateOption(c *gin.Context) { return } case "ImageRatio": - err = ratio_setting.UpdateImageRatioByJSONString(option.Value) + err = ratio_setting.UpdateImageRatioByJSONString(option.Value.(string)) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -138,7 +138,7 @@ func UpdateOption(c *gin.Context) { return } case "AudioRatio": - err = ratio_setting.UpdateAudioRatioByJSONString(option.Value) + err = ratio_setting.UpdateAudioRatioByJSONString(option.Value.(string)) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -147,7 +147,7 @@ func UpdateOption(c *gin.Context) { return } case "AudioCompletionRatio": - err = ratio_setting.UpdateAudioCompletionRatioByJSONString(option.Value) + err = ratio_setting.UpdateAudioCompletionRatioByJSONString(option.Value.(string)) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go index d462acb4b..9a568d857 100644 --- a/controller/topup_stripe.go +++ b/controller/topup_stripe.go @@ -217,7 +217,7 @@ func genStripeLink(referenceId string, customerId string, email string, amount i params := &stripe.CheckoutSessionParams{ ClientReferenceID: stripe.String(referenceId), - SuccessURL: stripe.String(system_setting.ServerAddress + "/log"), + SuccessURL: stripe.String(system_setting.ServerAddress + "/console/log"), CancelURL: stripe.String(system_setting.ServerAddress + "/topup"), LineItems: []*stripe.CheckoutSessionLineItemParams{ { @@ -225,7 +225,8 @@ func genStripeLink(referenceId string, customerId string, email string, amount i Quantity: stripe.Int64(amount), }, }, - Mode: stripe.String(string(stripe.CheckoutSessionModePayment)), + Mode: stripe.String(string(stripe.CheckoutSessionModePayment)), + AllowPromotionCodes: stripe.Bool(setting.StripePromotionCodesEnabled), } if "" == customerId { diff --git a/dto/gemini.go b/dto/gemini.go index 5df67ba0b..bc05c6aab 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -14,7 +14,30 @@ type GeminiChatRequest struct { SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"` GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"` Tools json.RawMessage `json:"tools,omitempty"` + ToolConfig *ToolConfig `json:"toolConfig,omitempty"` SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"` + CachedContent string `json:"cachedContent,omitempty"` +} + +type ToolConfig struct { + FunctionCallingConfig *FunctionCallingConfig `json:"functionCallingConfig,omitempty"` + RetrievalConfig *RetrievalConfig `json:"retrievalConfig,omitempty"` +} + +type FunctionCallingConfig struct { + Mode FunctionCallingConfigMode `json:"mode,omitempty"` + AllowedFunctionNames []string `json:"allowedFunctionNames,omitempty"` +} +type FunctionCallingConfigMode string + +type RetrievalConfig struct { + LatLng *LatLng `json:"latLng,omitempty"` + LanguageCode string `json:"languageCode,omitempty"` +} + +type LatLng struct { + Latitude *float64 `json:"latitude,omitempty"` + Longitude *float64 `json:"longitude,omitempty"` } func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta { @@ -239,12 +262,20 @@ type GeminiChatGenerationConfig struct { StopSequences []string `json:"stopSequences,omitempty"` ResponseMimeType string `json:"responseMimeType,omitempty"` ResponseSchema any `json:"responseSchema,omitempty"` + ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"` + PresencePenalty *float32 `json:"presencePenalty,omitempty"` + FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"` + ResponseLogprobs bool `json:"responseLogprobs,omitempty"` + Logprobs *int32 `json:"logprobs,omitempty"` + MediaResolution MediaResolution `json:"mediaResolution,omitempty"` Seed int64 `json:"seed,omitempty"` ResponseModalities []string `json:"responseModalities,omitempty"` ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"` SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config } +type MediaResolution string + type GeminiChatCandidate struct { Content GeminiChatContent `json:"content"` FinishReason *string `json:"finishReason"` diff --git a/dto/openai_request.go b/dto/openai_request.go index cd05a63c9..191fa638f 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -772,11 +772,12 @@ type OpenAIResponsesRequest struct { Instructions json.RawMessage `json:"instructions,omitempty"` MaxOutputTokens uint `json:"max_output_tokens,omitempty"` Metadata json.RawMessage `json:"metadata,omitempty"` - ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` + ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"` PreviousResponseID string `json:"previous_response_id,omitempty"` Reasoning *Reasoning `json:"reasoning,omitempty"` ServiceTier string `json:"service_tier,omitempty"` - Store bool `json:"store,omitempty"` + Store json.RawMessage `json:"store,omitempty"` + PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"` Stream bool `json:"stream,omitempty"` Temperature float64 `json:"temperature,omitempty"` Text json.RawMessage `json:"text,omitempty"` diff --git a/dto/openai_response.go b/dto/openai_response.go index 966748cb5..6353c15ff 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -6,6 +6,10 @@ import ( "one-api/types" ) +const ( + ResponsesOutputTypeImageGenerationCall = "image_generation_call" +) + type SimpleResponse struct { Usage `json:"usage"` Error any `json:"error"` @@ -273,6 +277,42 @@ func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError { return GetOpenAIError(o.Error) } +func (o *OpenAIResponsesResponse) HasImageGenerationCall() bool { + if len(o.Output) == 0 { + return false + } + for _, output := range o.Output { + if output.Type == ResponsesOutputTypeImageGenerationCall { + return true + } + } + return false +} + +func (o *OpenAIResponsesResponse) GetQuality() string { + if len(o.Output) == 0 { + return "" + } + for _, output := range o.Output { + if output.Type == ResponsesOutputTypeImageGenerationCall { + return output.Quality + } + } + return "" +} + +func (o *OpenAIResponsesResponse) GetSize() string { + if len(o.Output) == 0 { + return "" + } + for _, output := range o.Output { + if output.Type == ResponsesOutputTypeImageGenerationCall { + return output.Size + } + } + return "" +} + type IncompleteDetails struct { Reasoning string `json:"reasoning"` } @@ -283,6 +323,8 @@ type ResponsesOutput struct { Status string `json:"status"` Role string `json:"role"` Content []ResponsesOutputContent `json:"content"` + Quality string `json:"quality"` + Size string `json:"size"` } type ResponsesOutputContent struct { diff --git a/model/option.go b/model/option.go index ceecff658..9ace8fece 100644 --- a/model/option.go +++ b/model/option.go @@ -82,6 +82,7 @@ func InitOptionMap() { common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret common.OptionMap["StripePriceId"] = setting.StripePriceId common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(setting.StripeUnitPrice, 'f', -1, 64) + common.OptionMap["StripePromotionCodesEnabled"] = strconv.FormatBool(setting.StripePromotionCodesEnabled) common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() common.OptionMap["Chats"] = setting.Chats2JsonString() common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString() @@ -330,6 +331,8 @@ func updateOptionMap(key string, value string) (err error) { setting.StripeUnitPrice, _ = strconv.ParseFloat(value, 64) case "StripeMinTopUp": setting.StripeMinTopUp, _ = strconv.Atoi(value) + case "StripePromotionCodesEnabled": + setting.StripePromotionCodesEnabled = value == "true" case "TopupGroupRatio": err = common.UpdateTopupGroupRatioByJSONString(value) case "GitHubClientId": diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go index 72d0f9890..5ac7ce998 100644 --- a/relay/channel/aws/constants.go +++ b/relay/channel/aws/constants.go @@ -21,6 +21,10 @@ var awsModelIDMap = map[string]string{ "nova-lite-v1:0": "amazon.nova-lite-v1:0", "nova-pro-v1:0": "amazon.nova-pro-v1:0", "nova-premier-v1:0": "amazon.nova-premier-v1:0", + "nova-canvas-v1:0": "amazon.nova-canvas-v1:0", + "nova-reel-v1:0": "amazon.nova-reel-v1:0", + "nova-reel-v1:1": "amazon.nova-reel-v1:1", + "nova-sonic-v1:0": "amazon.nova-sonic-v1:0", } var awsModelCanCrossRegionMap = map[string]map[string]bool{ @@ -82,10 +86,27 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{ "apac": true, }, "amazon.nova-premier-v1:0": { + "us": true, + }, + "amazon.nova-canvas-v1:0": { "us": true, "eu": true, "apac": true, - }} + }, + "amazon.nova-reel-v1:0": { + "us": true, + "eu": true, + "apac": true, + }, + "amazon.nova-reel-v1:1": { + "us": true, + }, + "amazon.nova-sonic-v1:0": { + "us": true, + "eu": true, + "apac": true, + }, +} var awsRegionCrossModelPrefixMap = map[string]string{ "us": "us", diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index 17d732ab0..962f8794a 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -3,17 +3,17 @@ package deepseek import ( "errors" "fmt" + "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/claude" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" "one-api/types" "strings" - - "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt } func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { - adaptor := openai.Adaptor{} + adaptor := claude.Adaptor{} return adaptor.ConvertClaudeRequest(c, info, req) } @@ -44,14 +44,19 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { fimBaseUrl := info.ChannelBaseUrl - if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") { - fimBaseUrl += "/beta" - } - switch info.RelayMode { - case constant.RelayModeCompletions: - return fmt.Sprintf("%s/completions", fimBaseUrl), nil + switch info.RelayFormat { + case types.RelayFormatClaude: + return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil default: - return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil + if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") { + fimBaseUrl += "/beta" + } + switch info.RelayMode { + case constant.RelayModeCompletions: + return fmt.Sprintf("%s/completions", fimBaseUrl), nil + default: + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil + } } } @@ -87,12 +92,17 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) - } else { - usage, err = openai.OpenaiHandler(c, info, resp) + switch info.RelayFormat { + case types.RelayFormatClaude: + if info.IsStream { + return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) + } else { + return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) + } + default: + adaptor := openai.Adaptor{} + return adaptor.DoResponse(c, resp, info) } - return } func (a *Adaptor) GetModelList() []string { diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 4968f78fe..57542aa5a 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -215,8 +215,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeGemini { - if strings.HasSuffix(info.RequestURLPath, ":embedContent") || - strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") { + if strings.Contains(info.RequestURLPath, ":embedContent") || + strings.Contains(info.RequestURLPath, ":batchEmbedContents") { return NativeGeminiEmbeddingHandler(c, resp, info) } if info.IsStream { diff --git a/relay/channel/moonshot/adaptor.go b/relay/channel/moonshot/adaptor.go index e290c239d..f24976bb3 100644 --- a/relay/channel/moonshot/adaptor.go +++ b/relay/channel/moonshot/adaptor.go @@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt } func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { - adaptor := openai.Adaptor{} + adaptor := claude.Adaptor{} return adaptor.ConvertClaudeRequest(c, info, req) } diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index e188889e4..85938a771 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -33,6 +33,12 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) } + if responsesResponse.HasImageGenerationCall() { + c.Set("image_generation_call", true) + c.Set("image_generation_call_quality", responsesResponse.GetQuality()) + c.Set("image_generation_call_size", responsesResponse.GetSize()) + } + // 写入新的 response body service.IOCopyBytesGracefully(c, resp, responseBody) @@ -80,18 +86,25 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp sendResponsesStreamData(c, streamResponse, data) switch streamResponse.Type { case "response.completed": - if streamResponse.Response != nil && streamResponse.Response.Usage != nil { - if streamResponse.Response.Usage.InputTokens != 0 { - usage.PromptTokens = streamResponse.Response.Usage.InputTokens + if streamResponse.Response != nil { + if streamResponse.Response.Usage != nil { + if streamResponse.Response.Usage.InputTokens != 0 { + usage.PromptTokens = streamResponse.Response.Usage.InputTokens + } + if streamResponse.Response.Usage.OutputTokens != 0 { + usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens + } + if streamResponse.Response.Usage.TotalTokens != 0 { + usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + } + if streamResponse.Response.Usage.InputTokensDetails != nil { + usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens + } } - if streamResponse.Response.Usage.OutputTokens != 0 { - usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens - } - if streamResponse.Response.Usage.TotalTokens != 0 { - usage.TotalTokens = streamResponse.Response.Usage.TotalTokens - } - if streamResponse.Response.Usage.InputTokensDetails != nil { - usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens + if streamResponse.Response.HasImageGenerationCall() { + c.Set("image_generation_call", true) + c.Set("image_generation_call_quality", streamResponse.Response.GetQuality()) + c.Set("image_generation_call_size", streamResponse.Response.GetSize()) } } case "response.output_text.delta": diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 95f3cb269..a2545a273 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -36,6 +36,7 @@ type requestPayload struct { Prompt string `json:"prompt,omitempty"` Seed int64 `json:"seed"` AspectRatio string `json:"aspect_ratio"` + Frames int `json:"frames,omitempty"` } type responsePayload struct { @@ -325,10 +326,15 @@ func hmacSHA256(key []byte, data []byte) []byte { func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { r := requestPayload{ - ReqKey: "jimeng_vgfm_i2v_l20", - Prompt: req.Prompt, - AspectRatio: "16:9", // Default aspect ratio - Seed: -1, // Default to random + ReqKey: req.Model, + Prompt: req.Prompt, + } + + switch req.Duration { + case 10: + r.Frames = 241 // 24*10+1 = 241 + default: + r.Frames = 121 // 24*5+1 = 121 } // Handle one-of image_urls or binary_data_base64 @@ -348,6 +354,22 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* if err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } + + // 即梦视频3.0 ReqKey转换 + // https://www.volcengine.com/docs/85621/1792707 + if strings.Contains(r.ReqKey, "jimeng_v30") { + if len(r.ImageUrls) > 1 { + // 多张图片:首尾帧生成 + r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1) + } else if len(r.ImageUrls) == 1 { + // 单张图片:图生视频 + r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1) + } else { + // 无图片:文生视频 + r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_t2v_v30", 1) + } + } + return &r, nil } diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index a1140d1e7..358aef583 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -80,8 +80,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { - // Use the unified validation method for TaskSubmitReq with image-based action determination - return relaycommon.ValidateTaskRequestWithImageBinding(c, info) + return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) } func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { @@ -112,6 +111,10 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro switch info.Action { case constant.TaskActionGenerate: path = "/img2video" + case constant.TaskActionFirstTailGenerate: + path = "/start-end2video" + case constant.TaskActionReferenceGenerate: + path = "/reference2video" default: path = "/text2video" } @@ -187,14 +190,9 @@ func (a *TaskAdaptor) GetChannelName() string { // ============================ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { - var images []string - if req.Image != "" { - images = []string{req.Image} - } - r := requestPayload{ Model: defaultString(req.Model, "viduq1"), - Images: images, + Images: req.Images, Prompt: req.Prompt, Duration: defaultInt(req.Duration, 5), Resolution: defaultString(req.Size, "1080p"), diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 0af019da4..eb88412af 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -41,6 +41,8 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { switch info.RelayMode { + case constant.RelayModeImagesGenerations: + return request, nil case constant.RelayModeImagesEdits: var requestBody bytes.Buffer diff --git a/relay/channel/volcengine/constants.go b/relay/channel/volcengine/constants.go index 30cc902e7..fca10e7c1 100644 --- a/relay/channel/volcengine/constants.go +++ b/relay/channel/volcengine/constants.go @@ -8,6 +8,7 @@ var ModelList = []string{ "Doubao-lite-32k", "Doubao-lite-4k", "Doubao-embedding", + "doubao-seedream-4-0-250828", } var ChannelName = "volcengine" diff --git a/relay/claude_handler.go b/relay/claude_handler.go index dbdc6ee1c..59d12abe4 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -69,6 +70,31 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ info.UpstreamModelName = request.Model } + if info.ChannelSetting.SystemPrompt != "" { + if request.System == nil { + request.SetStringSystem(info.ChannelSetting.SystemPrompt) + } else if info.ChannelSetting.SystemPromptOverride { + common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) + if request.IsStringSystem() { + existing := strings.TrimSpace(request.GetStringSystem()) + if existing == "" { + request.SetStringSystem(info.ChannelSetting.SystemPrompt) + } else { + request.SetStringSystem(info.ChannelSetting.SystemPrompt + "\n" + existing) + } + } else { + systemContents := request.ParseSystem() + newSystem := dto.ClaudeMediaMessage{Type: dto.ContentTypeText} + newSystem.SetText(info.ChannelSetting.SystemPrompt) + if len(systemContents) == 0 { + request.System = []dto.ClaudeMediaMessage{newSystem} + } else { + request.System = append([]dto.ClaudeMediaMessage{newSystem}, systemContents...) + } + } + } + } + var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index cf6d08dda..3a721b479 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -79,34 +79,18 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d req.Images = []string{req.Image} } + if req.HasImage() { + action = constant.TaskActionGenerate + if info.ChannelType == constant.ChannelTypeVidu { + // vidu 增加 首尾帧生视频和参考图生视频 + if len(req.Images) == 2 { + action = constant.TaskActionFirstTailGenerate + } else if len(req.Images) > 2 { + action = constant.TaskActionReferenceGenerate + } + } + } + storeTaskRequest(c, info, action, req) return nil } - -func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError { - hasPrompt, ok := requestObj.(HasPrompt) - if !ok { - return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true) - } - - if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil { - return taskErr - } - - action := constant.TaskActionTextGenerate - if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() { - action = constant.TaskActionGenerate - } - - storeTaskRequest(c, info, action, requestObj) - return nil -} - -func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError { - var req TaskSubmitReq - if err := c.ShouldBindJSON(&req); err != nil { - return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false) - } - - return ValidateTaskRequestWithImage(c, info, req) -} diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index c2d6b6fa1..38b820f72 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -278,6 +278,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage fileSearchTool.CallCount, dFileSearchQuota.String()) } } + var dImageGenerationCallQuota decimal.Decimal + var imageGenerationCallPrice float64 + if ctx.GetBool("image_generation_call") { + imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size")) + dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit) + extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String()) + } var quotaCalculateDecimal decimal.Decimal @@ -333,6 +340,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota) // 添加 audio input 独立计费 quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota) + // 添加 image generation call 计费 + quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota) quota := int(quotaCalculateDecimal.Round(0).IntPart()) totalTokens := promptTokens + completionTokens @@ -431,6 +440,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage other["audio_input_token_count"] = audioTokens other["audio_input_price"] = audioInputPrice } + if !dImageGenerationCallQuota.IsZero() { + other["image_generation_call"] = true + other["image_generation_call_price"] = imageGenerationCallPrice + } model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: promptTokens, diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 0252d6578..1410da606 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/logger" "one-api/relay/channel/gemini" @@ -94,6 +95,32 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ adaptor.Init(info) + if info.ChannelSetting.SystemPrompt != "" { + if request.SystemInstructions == nil { + request.SystemInstructions = &dto.GeminiChatContent{ + Parts: []dto.GeminiPart{ + {Text: info.ChannelSetting.SystemPrompt}, + }, + } + } else if len(request.SystemInstructions.Parts) == 0 { + request.SystemInstructions.Parts = []dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}} + } else if info.ChannelSetting.SystemPromptOverride { + common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) + merged := false + for i := range request.SystemInstructions.Parts { + if request.SystemInstructions.Parts[i].Text == "" { + continue + } + request.SystemInstructions.Parts[i].Text = info.ChannelSetting.SystemPrompt + "\n" + request.SystemInstructions.Parts[i].Text + merged = true + break + } + if !merged { + request.SystemInstructions.Parts = append([]dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}, request.SystemInstructions.Parts...) + } + } + } + // Clean up empty system instruction if request.SystemInstructions != nil { hasContent := false diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go index 4d1c1f9bb..f4a290ec6 100644 --- a/relay/helper/valid_request.go +++ b/relay/helper/valid_request.go @@ -21,7 +21,11 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt case types.RelayFormatOpenAI: request, err = GetAndValidateTextRequest(c, relayMode) case types.RelayFormatGemini: - request, err = GetAndValidateGeminiRequest(c) + if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") { + request, err = GetAndValidateGeminiEmbeddingRequest(c) + } else { + request, err = GetAndValidateGeminiRequest(c) + } case types.RelayFormatClaude: request, err = GetAndValidateClaudeRequest(c) case types.RelayFormatOpenAIResponses: @@ -288,7 +292,6 @@ func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenA } func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) { - request := &dto.GeminiChatRequest{} err := common.UnmarshalBodyReusable(c, request) if err != nil { @@ -304,3 +307,12 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) return request, nil } + +func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingRequest, error) { + request := &dto.GeminiEmbeddingRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + return request, nil +} diff --git a/service/cf_worker.go b/service/cf_worker.go deleted file mode 100644 index 4a7b43760..000000000 --- a/service/cf_worker.go +++ /dev/null @@ -1,57 +0,0 @@ -package service - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - "one-api/common" - "one-api/setting" - "strings" -) - -// WorkerRequest Worker请求的数据结构 -type WorkerRequest struct { - URL string `json:"url"` - Key string `json:"key"` - Method string `json:"method,omitempty"` - Headers map[string]string `json:"headers,omitempty"` - Body json.RawMessage `json:"body,omitempty"` -} - -// DoWorkerRequest 通过Worker发送请求 -func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { - if !setting.EnableWorker() { - return nil, fmt.Errorf("worker not enabled") - } - if !setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") { - return nil, fmt.Errorf("only support https url") - } - - workerUrl := setting.WorkerUrl - if !strings.HasSuffix(workerUrl, "/") { - workerUrl += "/" - } - - // 序列化worker请求数据 - workerPayload, err := json.Marshal(req) - if err != nil { - return nil, fmt.Errorf("failed to marshal worker payload: %v", err) - } - - return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload)) -} - -func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) { - if setting.EnableWorker() { - common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", "))) - req := &WorkerRequest{ - URL: originUrl, - Key: setting.WorkerValidKey, - } - return DoWorkerRequest(req) - } else { - common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", "))) - return http.Get(originUrl) - } -} diff --git a/service/download.go b/service/download.go new file mode 100644 index 000000000..036c43af8 --- /dev/null +++ b/service/download.go @@ -0,0 +1,69 @@ +package service + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "one-api/common" + "one-api/setting/system_setting" + "strings" +) + +// WorkerRequest Worker请求的数据结构 +type WorkerRequest struct { + URL string `json:"url"` + Key string `json:"key"` + Method string `json:"method,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + Body json.RawMessage `json:"body,omitempty"` +} + +// DoWorkerRequest 通过Worker发送请求 +func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { + if !system_setting.EnableWorker() { + return nil, fmt.Errorf("worker not enabled") + } + if !system_setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") { + return nil, fmt.Errorf("only support https url") + } + + // SSRF防护:验证请求URL + fetchSetting := system_setting.GetFetchSetting() + if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { + return nil, fmt.Errorf("request reject: %v", err) + } + + workerUrl := system_setting.WorkerUrl + if !strings.HasSuffix(workerUrl, "/") { + workerUrl += "/" + } + + // 序列化worker请求数据 + workerPayload, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal worker payload: %v", err) + } + + return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload)) +} + +func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) { + if system_setting.EnableWorker() { + common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", "))) + req := &WorkerRequest{ + URL: originUrl, + Key: system_setting.WorkerValidKey, + } + return DoWorkerRequest(req) + } else { + // SSRF防护:验证请求URL(非Worker模式) + fetchSetting := system_setting.GetFetchSetting() + if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { + return nil, fmt.Errorf("request reject: %v", err) + } + + common.SysLog(fmt.Sprintf("downloading from origin: %s, reason: %s", common.MaskSensitiveInfo(originUrl), strings.Join(reason, ", "))) + return http.Get(originUrl) + } +} diff --git a/service/user_notify.go b/service/user_notify.go index c4a3ea91f..fba12d9db 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -7,7 +7,7 @@ import ( "one-api/common" "one-api/dto" "one-api/model" - "one-api/setting" + "one-api/setting/system_setting" "strings" ) @@ -91,11 +91,11 @@ func sendBarkNotify(barkURL string, data dto.Notify) error { var resp *http.Response var err error - if setting.EnableWorker() { + if system_setting.EnableWorker() { // 使用worker发送请求 workerReq := &WorkerRequest{ URL: finalURL, - Key: setting.WorkerValidKey, + Key: system_setting.WorkerValidKey, Method: http.MethodGet, Headers: map[string]string{ "User-Agent": "OneAPI-Bark-Notify/1.0", @@ -113,6 +113,12 @@ func sendBarkNotify(barkURL string, data dto.Notify) error { return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode) } } else { + // SSRF防护:验证Bark URL(非Worker模式) + fetchSetting := system_setting.GetFetchSetting() + if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { + return fmt.Errorf("request reject: %v", err) + } + // 直接发送请求 req, err = http.NewRequest(http.MethodGet, finalURL, nil) if err != nil { diff --git a/service/webhook.go b/service/webhook.go index 8faccda30..c678b8634 100644 --- a/service/webhook.go +++ b/service/webhook.go @@ -8,8 +8,9 @@ import ( "encoding/json" "fmt" "net/http" + "one-api/common" "one-api/dto" - "one-api/setting" + "one-api/setting/system_setting" "time" ) @@ -56,11 +57,11 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error var req *http.Request var resp *http.Response - if setting.EnableWorker() { + if system_setting.EnableWorker() { // 构建worker请求数据 workerReq := &WorkerRequest{ URL: webhookURL, - Key: setting.WorkerValidKey, + Key: system_setting.WorkerValidKey, Method: http.MethodPost, Headers: map[string]string{ "Content-Type": "application/json", @@ -86,6 +87,12 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode) } } else { + // SSRF防护:验证Webhook URL(非Worker模式) + fetchSetting := system_setting.GetFetchSetting() + if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { + return fmt.Errorf("request reject: %v", err) + } + req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes)) if err != nil { return fmt.Errorf("failed to create webhook request: %v", err) diff --git a/setting/operation_setting/tools.go b/setting/operation_setting/tools.go index 549a1862e..5b89d6fec 100644 --- a/setting/operation_setting/tools.go +++ b/setting/operation_setting/tools.go @@ -10,6 +10,18 @@ const ( FileSearchPrice = 2.5 ) +const ( + GPTImage1Low1024x1024 = 0.011 + GPTImage1Low1024x1536 = 0.016 + GPTImage1Low1536x1024 = 0.016 + GPTImage1Medium1024x1024 = 0.042 + GPTImage1Medium1024x1536 = 0.063 + GPTImage1Medium1536x1024 = 0.063 + GPTImage1High1024x1024 = 0.167 + GPTImage1High1024x1536 = 0.25 + GPTImage1High1536x1024 = 0.25 +) + const ( // Gemini Audio Input Price Gemini25FlashPreviewInputAudioPrice = 1.00 @@ -65,3 +77,31 @@ func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 { } return 0 } + +func GetGPTImage1PriceOnceCall(quality string, size string) float64 { + prices := map[string]map[string]float64{ + "low": { + "1024x1024": GPTImage1Low1024x1024, + "1024x1536": GPTImage1Low1024x1536, + "1536x1024": GPTImage1Low1536x1024, + }, + "medium": { + "1024x1024": GPTImage1Medium1024x1024, + "1024x1536": GPTImage1Medium1024x1536, + "1536x1024": GPTImage1Medium1536x1024, + }, + "high": { + "1024x1024": GPTImage1High1024x1024, + "1024x1536": GPTImage1High1024x1536, + "1536x1024": GPTImage1High1536x1024, + }, + } + + if qualityMap, exists := prices[quality]; exists { + if price, exists := qualityMap[size]; exists { + return price + } + } + + return GPTImage1High1024x1024 +} diff --git a/setting/payment_stripe.go b/setting/payment_stripe.go index 80d877dfa..d97120c85 100644 --- a/setting/payment_stripe.go +++ b/setting/payment_stripe.go @@ -5,3 +5,4 @@ var StripeWebhookSecret = "" var StripePriceId = "" var StripeUnitPrice = 8.0 var StripeMinTopUp = 1 +var StripePromotionCodesEnabled = false diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index 5b47c875f..362c6fa1a 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -178,6 +178,7 @@ var defaultModelRatio = map[string]float64{ "gemini-2.5-flash-lite-preview-thinking-*": 0.05, "gemini-2.5-flash-lite-preview-06-17": 0.05, "gemini-2.5-flash": 0.15, + "gemini-embedding-001": 0.075, "text-embedding-004": 0.001, "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens diff --git a/setting/system_setting/fetch_setting.go b/setting/system_setting/fetch_setting.go new file mode 100644 index 000000000..c41b930af --- /dev/null +++ b/setting/system_setting/fetch_setting.go @@ -0,0 +1,34 @@ +package system_setting + +import "one-api/setting/config" + +type FetchSetting struct { + EnableSSRFProtection bool `json:"enable_ssrf_protection"` // 是否启用SSRF防护 + AllowPrivateIp bool `json:"allow_private_ip"` + DomainFilterMode bool `json:"domain_filter_mode"` // 域名过滤模式,true: 白名单模式,false: 黑名单模式 + IpFilterMode bool `json:"ip_filter_mode"` // IP过滤模式,true: 白名单模式,false: 黑名单模式 + DomainList []string `json:"domain_list"` // domain format, e.g. example.com, *.example.com + IpList []string `json:"ip_list"` // CIDR format + AllowedPorts []string `json:"allowed_ports"` // port range format, e.g. 80, 443, 8000-9000 + ApplyIPFilterForDomain bool `json:"apply_ip_filter_for_domain"` // 对域名启用IP过滤(实验性) +} + +var defaultFetchSetting = FetchSetting{ + EnableSSRFProtection: true, // 默认开启SSRF防护 + AllowPrivateIp: false, + DomainFilterMode: false, + IpFilterMode: false, + DomainList: []string{}, + IpList: []string{}, + AllowedPorts: []string{"80", "443", "8080", "8443"}, + ApplyIPFilterForDomain: false, +} + +func init() { + // 注册到全局配置管理器 + config.GlobalConfig.Register("fetch_setting", &defaultFetchSetting) +} + +func GetFetchSetting() *FetchSetting { + return &defaultFetchSetting +} diff --git a/types/error.go b/types/error.go index 883ee0641..a42e84385 100644 --- a/types/error.go +++ b/types/error.go @@ -122,6 +122,9 @@ func (e *NewAPIError) MaskSensitiveError() string { return string(e.errorCode) } errStr := e.Err.Error() + if e.errorCode == ErrorCodeCountTokenFailed { + return errStr + } return common.MaskSensitiveInfo(errStr) } @@ -153,8 +156,9 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError { Code: e.errorCode, } } - - result.Message = common.MaskSensitiveInfo(result.Message) + if e.errorCode != ErrorCodeCountTokenFailed { + result.Message = common.MaskSensitiveInfo(result.Message) + } return result } @@ -178,7 +182,9 @@ func (e *NewAPIError) ToClaudeError() ClaudeError { Type: string(e.errorType), } } - result.Message = common.MaskSensitiveInfo(result.Message) + if e.errorCode != ErrorCodeCountTokenFailed { + result.Message = common.MaskSensitiveInfo(result.Message) + } return result } diff --git a/web/jsconfig.json b/web/jsconfig.json new file mode 100644 index 000000000..ced4d0543 --- /dev/null +++ b/web/jsconfig.json @@ -0,0 +1,9 @@ +{ + "compilerOptions": { + "baseUrl": "./", + "paths": { + "@/*": ["src/*"] + } + }, + "include": ["src/**/*"] +} \ No newline at end of file diff --git a/web/src/components/common/markdown/MarkdownRenderer.jsx b/web/src/components/common/markdown/MarkdownRenderer.jsx index f1283a640..05419f8cc 100644 --- a/web/src/components/common/markdown/MarkdownRenderer.jsx +++ b/web/src/components/common/markdown/MarkdownRenderer.jsx @@ -181,8 +181,8 @@ export function PreCode(props) { e.preventDefault(); e.stopPropagation(); if (ref.current) { - const code = - ref.current.querySelector('code')?.innerText ?? ''; + const codeElement = ref.current.querySelector('code'); + const code = codeElement?.textContent ?? ''; copy(code).then((success) => { if (success) { Toast.success(t('代码已复制到剪贴板')); diff --git a/web/src/components/layout/headerbar/UserArea.jsx b/web/src/components/layout/headerbar/UserArea.jsx index 8ea70f47f..9fc011da1 100644 --- a/web/src/components/layout/headerbar/UserArea.jsx +++ b/web/src/components/layout/headerbar/UserArea.jsx @@ -17,7 +17,7 @@ along with this program. If not, see . For commercial licensing, please contact support@quantumnous.com */ -import React from 'react'; +import React, { useRef } from 'react'; import { Link } from 'react-router-dom'; import { Avatar, Button, Dropdown, Typography } from '@douyinfe/semi-ui'; import { ChevronDown } from 'lucide-react'; @@ -39,6 +39,7 @@ const UserArea = ({ navigate, t, }) => { + const dropdownRef = useRef(null); if (isLoading) { return ( - { - navigate('/console/personal'); - }} - className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white' - > -
- - {t('个人设置')} -
-
- { - navigate('/console/token'); - }} - className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white' - > -
- - {t('令牌管理')} -
-
- { - navigate('/console/topup'); - }} - className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white' - > -
- - {t('钱包管理')} -
-
- -
- - {t('退出')} -
-
- - } - > - - + + {userState.user.username[0].toUpperCase()} + + + + {userState.user.username} + + + + + + ); } else { const showRegisterButton = !isSelfUseMode; diff --git a/web/src/components/settings/PaymentSetting.jsx b/web/src/components/settings/PaymentSetting.jsx index faaa9561b..220c86642 100644 --- a/web/src/components/settings/PaymentSetting.jsx +++ b/web/src/components/settings/PaymentSetting.jsx @@ -45,6 +45,7 @@ const PaymentSetting = () => { StripePriceId: '', StripeUnitPrice: 8.0, StripeMinTopUp: 1, + StripePromotionCodesEnabled: false, }); let [loading, setLoading] = useState(false); diff --git a/web/src/components/settings/SystemSetting.jsx b/web/src/components/settings/SystemSetting.jsx index 9c7eeaadc..f9a2c019d 100644 --- a/web/src/components/settings/SystemSetting.jsx +++ b/web/src/components/settings/SystemSetting.jsx @@ -29,6 +29,7 @@ import { TagInput, Spin, Card, + Radio, } from '@douyinfe/semi-ui'; const { Text } = Typography; import { @@ -44,6 +45,7 @@ import { useTranslation } from 'react-i18next'; const SystemSetting = () => { const { t } = useTranslation(); let [inputs, setInputs] = useState({ + PasswordLoginEnabled: '', PasswordRegisterEnabled: '', EmailVerificationEnabled: '', @@ -87,6 +89,15 @@ const SystemSetting = () => { LinuxDOClientSecret: '', LinuxDOMinimumTrustLevel: '', ServerAddress: '', + // SSRF防护配置 + 'fetch_setting.enable_ssrf_protection': true, + 'fetch_setting.allow_private_ip': '', + 'fetch_setting.domain_filter_mode': false, // true 白名单,false 黑名单 + 'fetch_setting.ip_filter_mode': false, // true 白名单,false 黑名单 + 'fetch_setting.domain_list': [], + 'fetch_setting.ip_list': [], + 'fetch_setting.allowed_ports': [], + 'fetch_setting.apply_ip_filter_for_domain': false, }); const [originInputs, setOriginInputs] = useState({}); @@ -98,6 +109,11 @@ const SystemSetting = () => { useState(false); const [linuxDOOAuthEnabled, setLinuxDOOAuthEnabled] = useState(false); const [emailToAdd, setEmailToAdd] = useState(''); + const [domainFilterMode, setDomainFilterMode] = useState(true); + const [ipFilterMode, setIpFilterMode] = useState(true); + const [domainList, setDomainList] = useState([]); + const [ipList, setIpList] = useState([]); + const [allowedPorts, setAllowedPorts] = useState([]); const getOptions = async () => { setLoading(true); @@ -113,6 +129,37 @@ const SystemSetting = () => { case 'EmailDomainWhitelist': setEmailDomainWhitelist(item.value ? item.value.split(',') : []); break; + case 'fetch_setting.allow_private_ip': + case 'fetch_setting.enable_ssrf_protection': + case 'fetch_setting.domain_filter_mode': + case 'fetch_setting.ip_filter_mode': + case 'fetch_setting.apply_ip_filter_for_domain': + item.value = toBoolean(item.value); + break; + case 'fetch_setting.domain_list': + try { + const domains = item.value ? JSON.parse(item.value) : []; + setDomainList(Array.isArray(domains) ? domains : []); + } catch (e) { + setDomainList([]); + } + break; + case 'fetch_setting.ip_list': + try { + const ips = item.value ? JSON.parse(item.value) : []; + setIpList(Array.isArray(ips) ? ips : []); + } catch (e) { + setIpList([]); + } + break; + case 'fetch_setting.allowed_ports': + try { + const ports = item.value ? JSON.parse(item.value) : []; + setAllowedPorts(Array.isArray(ports) ? ports : []); + } catch (e) { + setAllowedPorts(['80', '443', '8080', '8443']); + } + break; case 'PasswordLoginEnabled': case 'PasswordRegisterEnabled': case 'EmailVerificationEnabled': @@ -140,6 +187,13 @@ const SystemSetting = () => { }); setInputs(newInputs); setOriginInputs(newInputs); + // 同步模式布尔到本地状态 + if (typeof newInputs['fetch_setting.domain_filter_mode'] !== 'undefined') { + setDomainFilterMode(!!newInputs['fetch_setting.domain_filter_mode']); + } + if (typeof newInputs['fetch_setting.ip_filter_mode'] !== 'undefined') { + setIpFilterMode(!!newInputs['fetch_setting.ip_filter_mode']); + } if (formApiRef.current) { formApiRef.current.setValues(newInputs); } @@ -276,6 +330,46 @@ const SystemSetting = () => { } }; + const submitSSRF = async () => { + const options = []; + + // 处理域名过滤模式与列表 + options.push({ + key: 'fetch_setting.domain_filter_mode', + value: domainFilterMode, + }); + if (Array.isArray(domainList)) { + options.push({ + key: 'fetch_setting.domain_list', + value: JSON.stringify(domainList), + }); + } + + // 处理IP过滤模式与列表 + options.push({ + key: 'fetch_setting.ip_filter_mode', + value: ipFilterMode, + }); + if (Array.isArray(ipList)) { + options.push({ + key: 'fetch_setting.ip_list', + value: JSON.stringify(ipList), + }); + } + + // 处理端口配置 + if (Array.isArray(allowedPorts)) { + options.push({ + key: 'fetch_setting.allowed_ports', + value: JSON.stringify(allowedPorts), + }); + } + + if (options.length > 0) { + await updateOptions(options); + } + }; + const handleAddEmail = () => { if (emailToAdd && emailToAdd.trim() !== '') { const domain = emailToAdd.trim(); @@ -587,6 +681,179 @@ const SystemSetting = () => { + + + + {t('配置服务器端请求伪造(SSRF)防护,用于保护内网资源安全')} + + + + + handleCheckboxChange('fetch_setting.enable_ssrf_protection', e) + } + > + {t('启用SSRF防护(推荐开启以保护服务器安全)')} + + + + + + + + handleCheckboxChange('fetch_setting.allow_private_ip', e) + } + > + {t('允许访问私有IP地址(127.0.0.1、192.168.x.x等内网地址)')} + + + + + + + + handleCheckboxChange('fetch_setting.apply_ip_filter_for_domain', e) + } + style={{ marginBottom: 8 }} + > + {t('对域名启用 IP 过滤(实验性)')} + + + {t(domainFilterMode ? '域名白名单' : '域名黑名单')} + + + {t('支持通配符格式,如:example.com, *.api.example.com')} + + { + const selected = val && val.target ? val.target.value : val; + const isWhitelist = selected === 'whitelist'; + setDomainFilterMode(isWhitelist); + setInputs(prev => ({ + ...prev, + 'fetch_setting.domain_filter_mode': isWhitelist, + })); + }} + style={{ marginBottom: 8 }} + > + {t('白名单')} + {t('黑名单')} + + { + setDomainList(value); + // 触发Form的onChange事件 + setInputs(prev => ({ + ...prev, + 'fetch_setting.domain_list': value + })); + }} + placeholder={t('输入域名后回车,如:example.com')} + style={{ width: '100%' }} + /> + + + + + + + {t(ipFilterMode ? 'IP白名单' : 'IP黑名单')} + + + {t('支持CIDR格式,如:8.8.8.8, 192.168.1.0/24')} + + { + const selected = val && val.target ? val.target.value : val; + const isWhitelist = selected === 'whitelist'; + setIpFilterMode(isWhitelist); + setInputs(prev => ({ + ...prev, + 'fetch_setting.ip_filter_mode': isWhitelist, + })); + }} + style={{ marginBottom: 8 }} + > + {t('白名单')} + {t('黑名单')} + + { + setIpList(value); + // 触发Form的onChange事件 + setInputs(prev => ({ + ...prev, + 'fetch_setting.ip_list': value + })); + }} + placeholder={t('输入IP地址后回车,如:8.8.8.8')} + style={{ width: '100%' }} + /> + + + + + + {t('允许的端口')} + + {t('支持单个端口和端口范围,如:80, 443, 8000-8999')} + + { + setAllowedPorts(value); + // 触发Form的onChange事件 + setInputs(prev => ({ + ...prev, + 'fetch_setting.allowed_ports': value + })); + }} + placeholder={t('输入端口后回车,如:80 或 8000-8999')} + style={{ width: '100%' }} + /> + + {t('端口配置详细说明')} + + + + + + + + { return (checked) => { @@ -132,6 +136,9 @@ const NotificationSettings = ({ }); if (res.data.success) { showSuccess(t('侧边栏设置保存成功')); + + // 刷新useSidebar钩子中的用户配置,实现实时更新 + await refreshUserConfig(); } else { showError(res.data.message); } @@ -334,7 +341,7 @@ const NotificationSettings = ({ loading={sidebarLoading} className='!rounded-lg' > - {t('保存边栏设置')} + {t('保存设置')} ) : ( diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index c0a216246..967bf88a2 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -85,6 +85,26 @@ const REGION_EXAMPLE = { 'claude-3-5-sonnet-20240620': 'europe-west1', }; +// 支持并且已适配通过接口获取模型列表的渠道类型 +const MODEL_FETCHABLE_TYPES = new Set([ + 1, + 4, + 14, + 34, + 17, + 26, + 24, + 47, + 25, + 20, + 23, + 31, + 35, + 40, + 42, + 48, +]); + function type2secretPrompt(type) { // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') switch (type) { @@ -1872,13 +1892,15 @@ const EditChannelModal = (props) => { > {t('填入所有模型')} - + {MODEL_FETCHABLE_TYPES.has(inputs.type) && ( + + )} )} + handleDeleteKey(record.index)} + okType={'danger'} + position={'topRight'} + > + + ), }, diff --git a/web/src/components/table/mj-logs/MjLogsFilters.jsx b/web/src/components/table/mj-logs/MjLogsFilters.jsx index 44c6bcfcd..6db96e791 100644 --- a/web/src/components/table/mj-logs/MjLogsFilters.jsx +++ b/web/src/components/table/mj-logs/MjLogsFilters.jsx @@ -21,6 +21,8 @@ import React from 'react'; import { Button, Form } from '@douyinfe/semi-ui'; import { IconSearch } from '@douyinfe/semi-icons'; +import { DATE_RANGE_PRESETS } from '../../../constants/console.constants'; + const MjLogsFilters = ({ formInitValues, setFormApi, @@ -54,6 +56,11 @@ const MjLogsFilters = ({ showClear pure size='small' + presets={DATE_RANGE_PRESETS.map(preset => ({ + text: t(preset.text), + start: preset.start(), + end: preset.end() + }))} /> diff --git a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx index 766c17158..b63c7dd4f 100644 --- a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx +++ b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx @@ -35,8 +35,9 @@ import { Sparkles, } from 'lucide-react'; import { - TASK_ACTION_GENERATE, - TASK_ACTION_TEXT_GENERATE, + TASK_ACTION_FIRST_TAIL_GENERATE, + TASK_ACTION_GENERATE, TASK_ACTION_REFERENCE_GENERATE, + TASK_ACTION_TEXT_GENERATE } from '../../../constants/common.constant'; import { CHANNEL_OPTIONS } from '../../../constants/channel.constants'; @@ -111,6 +112,18 @@ const renderType = (type, t) => { {t('文生视频')} ); + case TASK_ACTION_FIRST_TAIL_GENERATE: + return ( + }> + {t('首尾生视频')} + + ); + case TASK_ACTION_REFERENCE_GENERATE: + return ( + }> + {t('参照生视频')} + + ); default: return ( }> @@ -343,7 +356,9 @@ export const getTaskLogsColumns = ({ // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接 const isVideoTask = record.action === TASK_ACTION_GENERATE || - record.action === TASK_ACTION_TEXT_GENERATE; + record.action === TASK_ACTION_TEXT_GENERATE || + record.action === TASK_ACTION_FIRST_TAIL_GENERATE || + record.action === TASK_ACTION_REFERENCE_GENERATE; const isSuccess = record.status === 'SUCCESS'; const isUrl = typeof text === 'string' && /^https?:\/\//.test(text); if (isSuccess && isVideoTask && isUrl) { diff --git a/web/src/components/table/task-logs/TaskLogsFilters.jsx b/web/src/components/table/task-logs/TaskLogsFilters.jsx index d5e081ab7..e27cea867 100644 --- a/web/src/components/table/task-logs/TaskLogsFilters.jsx +++ b/web/src/components/table/task-logs/TaskLogsFilters.jsx @@ -21,6 +21,8 @@ import React from 'react'; import { Button, Form } from '@douyinfe/semi-ui'; import { IconSearch } from '@douyinfe/semi-icons'; +import { DATE_RANGE_PRESETS } from '../../../constants/console.constants'; + const TaskLogsFilters = ({ formInitValues, setFormApi, @@ -54,6 +56,11 @@ const TaskLogsFilters = ({ showClear pure size='small' + presets={DATE_RANGE_PRESETS.map(preset => ({ + text: t(preset.text), + start: preset.start(), + end: preset.end() + }))} /> diff --git a/web/src/components/table/usage-logs/UsageLogsFilters.jsx b/web/src/components/table/usage-logs/UsageLogsFilters.jsx index f76ec823e..58e5a4692 100644 --- a/web/src/components/table/usage-logs/UsageLogsFilters.jsx +++ b/web/src/components/table/usage-logs/UsageLogsFilters.jsx @@ -21,6 +21,8 @@ import React from 'react'; import { Button, Form } from '@douyinfe/semi-ui'; import { IconSearch } from '@douyinfe/semi-icons'; +import { DATE_RANGE_PRESETS } from '../../../constants/console.constants'; + const LogsFilters = ({ formInitValues, setFormApi, @@ -55,6 +57,11 @@ const LogsFilters = ({ showClear pure size='small' + presets={DATE_RANGE_PRESETS.map(preset => ({ + text: t(preset.text), + start: preset.start(), + end: preset.end() + }))} /> diff --git a/web/src/constants/common.constant.js b/web/src/constants/common.constant.js index 277bb9a54..57fbbbde5 100644 --- a/web/src/constants/common.constant.js +++ b/web/src/constants/common.constant.js @@ -40,3 +40,5 @@ export const API_ENDPOINTS = [ export const TASK_ACTION_GENERATE = 'generate'; export const TASK_ACTION_TEXT_GENERATE = 'textGenerate'; +export const TASK_ACTION_FIRST_TAIL_GENERATE = 'firstTailGenerate'; +export const TASK_ACTION_REFERENCE_GENERATE = 'referenceGenerate'; diff --git a/web/src/constants/console.constants.js b/web/src/constants/console.constants.js new file mode 100644 index 000000000..23ee1e17f --- /dev/null +++ b/web/src/constants/console.constants.js @@ -0,0 +1,49 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import dayjs from 'dayjs'; + +// ========== 日期预设常量 ========== +export const DATE_RANGE_PRESETS = [ + { + text: '今天', + start: () => dayjs().startOf('day').toDate(), + end: () => dayjs().endOf('day').toDate() + }, + { + text: '近 7 天', + start: () => dayjs().subtract(6, 'day').startOf('day').toDate(), + end: () => dayjs().endOf('day').toDate() + }, + { + text: '本周', + start: () => dayjs().startOf('week').toDate(), + end: () => dayjs().endOf('week').toDate() + }, + { + text: '近 30 天', + start: () => dayjs().subtract(29, 'day').startOf('day').toDate(), + end: () => dayjs().endOf('day').toDate() + }, + { + text: '本月', + start: () => dayjs().startOf('month').toDate(), + end: () => dayjs().endOf('month').toDate() + }, +]; diff --git a/web/src/helpers/api.js b/web/src/helpers/api.js index b7092fe77..bc389b2e8 100644 --- a/web/src/helpers/api.js +++ b/web/src/helpers/api.js @@ -118,7 +118,6 @@ export const buildApiPayload = ( model: inputs.model, group: inputs.group, messages: processedMessages, - group: inputs.group, stream: inputs.stream, }; @@ -132,13 +131,15 @@ export const buildApiPayload = ( seed: 'seed', }; + Object.entries(parameterMappings).forEach(([key, param]) => { - if ( - parameterEnabled[key] && - inputs[param] !== undefined && - inputs[param] !== null - ) { - payload[param] = inputs[param]; + const enabled = parameterEnabled[key]; + const value = inputs[param]; + const hasValue = value !== undefined && value !== null; + + + if (enabled && hasValue) { + payload[param] = value; } }); diff --git a/web/src/helpers/render.jsx b/web/src/helpers/render.jsx index 65332701b..c19e2849d 100644 --- a/web/src/helpers/render.jsx +++ b/web/src/helpers/render.jsx @@ -1027,6 +1027,8 @@ export function renderModelPrice( audioInputSeperatePrice = false, audioInputTokens = 0, audioInputPrice = 0, + imageGenerationCall = false, + imageGenerationCallPrice = 0, ) { const { ratio: effectiveGroupRatio, label: ratioLabel } = getEffectiveRatio( groupRatio, @@ -1069,7 +1071,8 @@ export function renderModelPrice( (audioInputTokens / 1000000) * audioInputPrice * groupRatio + (completionTokens / 1000000) * completionRatioPrice * groupRatio + (webSearchCallCount / 1000) * webSearchPrice * groupRatio + - (fileSearchCallCount / 1000) * fileSearchPrice * groupRatio; + (fileSearchCallCount / 1000) * fileSearchPrice * groupRatio + + (imageGenerationCallPrice * groupRatio); return ( <> @@ -1131,7 +1134,13 @@ export function renderModelPrice( })}

)} -

+ {imageGenerationCall && imageGenerationCallPrice > 0 && ( +

+ {i18next.t('图片生成调用:${{price}} / 1次', { + price: imageGenerationCallPrice, + })} +

+ )}

{(() => { // 构建输入部分描述 @@ -1211,6 +1220,16 @@ export function renderModelPrice( }, ) : '', + imageGenerationCall && imageGenerationCallPrice > 0 + ? i18next.t( + ' + 图片生成调用 ${{price}} / 1次 * {{ratioType}} {{ratio}}', + { + price: imageGenerationCallPrice, + ratio: groupRatio, + ratioType: ratioLabel, + }, + ) + : '', ].join(''); return i18next.t( diff --git a/web/src/helpers/utils.jsx b/web/src/helpers/utils.jsx index e446ea69d..bcd13230e 100644 --- a/web/src/helpers/utils.jsx +++ b/web/src/helpers/utils.jsx @@ -75,13 +75,17 @@ export async function copy(text) { await navigator.clipboard.writeText(text); } catch (e) { try { - // 构建input 执行 复制命令 - var _input = window.document.createElement('input'); - _input.value = text; - window.document.body.appendChild(_input); - _input.select(); - window.document.execCommand('Copy'); - window.document.body.removeChild(_input); + // 构建 textarea 执行复制命令,保留多行文本格式 + const textarea = window.document.createElement('textarea'); + textarea.value = text; + textarea.setAttribute('readonly', ''); + textarea.style.position = 'fixed'; + textarea.style.left = '-9999px'; + textarea.style.top = '-9999px'; + window.document.body.appendChild(textarea); + textarea.select(); + window.document.execCommand('copy'); + window.document.body.removeChild(textarea); } catch (e) { okay = false; console.error(e); diff --git a/web/src/hooks/common/useSidebar.js b/web/src/hooks/common/useSidebar.js index 5dce44f9e..13d76fd86 100644 --- a/web/src/hooks/common/useSidebar.js +++ b/web/src/hooks/common/useSidebar.js @@ -21,6 +21,10 @@ import { useState, useEffect, useMemo, useContext } from 'react'; import { StatusContext } from '../../context/Status'; import { API } from '../../helpers'; +// 创建一个全局事件系统来同步所有useSidebar实例 +const sidebarEventTarget = new EventTarget(); +const SIDEBAR_REFRESH_EVENT = 'sidebar-refresh'; + export const useSidebar = () => { const [statusState] = useContext(StatusContext); const [userConfig, setUserConfig] = useState(null); @@ -124,9 +128,12 @@ export const useSidebar = () => { // 刷新用户配置的方法(供外部调用) const refreshUserConfig = async () => { - if (Object.keys(adminConfig).length > 0) { + if (Object.keys(adminConfig).length > 0) { await loadUserConfig(); } + + // 触发全局刷新事件,通知所有useSidebar实例更新 + sidebarEventTarget.dispatchEvent(new CustomEvent(SIDEBAR_REFRESH_EVENT)); }; // 加载用户配置 @@ -137,6 +144,21 @@ export const useSidebar = () => { } }, [adminConfig]); + // 监听全局刷新事件 + useEffect(() => { + const handleRefresh = () => { + if (Object.keys(adminConfig).length > 0) { + loadUserConfig(); + } + }; + + sidebarEventTarget.addEventListener(SIDEBAR_REFRESH_EVENT, handleRefresh); + + return () => { + sidebarEventTarget.removeEventListener(SIDEBAR_REFRESH_EVENT, handleRefresh); + }; + }, [adminConfig]); + // 计算最终的显示配置 const finalConfig = useMemo(() => { const result = {}; diff --git a/web/src/hooks/dashboard/useDashboardStats.jsx b/web/src/hooks/dashboard/useDashboardStats.jsx index aa9677a50..dbf3b67e7 100644 --- a/web/src/hooks/dashboard/useDashboardStats.jsx +++ b/web/src/hooks/dashboard/useDashboardStats.jsx @@ -102,7 +102,7 @@ export const useDashboardStats = ( }, { title: t('统计Tokens'), - value: isNaN(consumeTokens) ? 0 : consumeTokens, + value: isNaN(consumeTokens) ? 0 : consumeTokens.toLocaleString(), icon: , avatarColor: 'pink', trendData: trendData.tokens, diff --git a/web/src/hooks/usage-logs/useUsageLogsData.jsx b/web/src/hooks/usage-logs/useUsageLogsData.jsx index 81f3f539a..d434e7333 100644 --- a/web/src/hooks/usage-logs/useUsageLogsData.jsx +++ b/web/src/hooks/usage-logs/useUsageLogsData.jsx @@ -447,6 +447,8 @@ export const useLogsData = () => { other?.audio_input_seperate_price || false, other?.audio_input_token_count || 0, other?.audio_input_price || 0, + other?.image_generation_call || false, + other?.image_generation_call_price || 0, ); } expandDataLocal.push({ diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index c86fb0e7f..e935c10cc 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -837,6 +837,7 @@ "确定要充值 $": "Confirm to top up $", "微信/支付宝 实付金额:": "WeChat/Alipay actual payment amount:", "Stripe 实付金额:": "Stripe actual payment amount:", + "允许在 Stripe 支付中输入促销码": "Allow entering promotion codes during Stripe checkout", "支付中...": "Paying", "支付宝": "Alipay", "收益统计": "Income statistics", @@ -1889,6 +1890,10 @@ "确定要删除所有已自动禁用的密钥吗?": "Are you sure you want to delete all automatically disabled keys?", "此操作不可撤销,将永久删除已自动禁用的密钥": "This operation cannot be undone, and all automatically disabled keys will be permanently deleted.", "删除自动禁用密钥": "Delete auto disabled keys", + "确定要删除此密钥吗?": "Are you sure you want to delete this key?", + "此操作不可撤销,将永久删除该密钥": "This operation cannot be undone, and the key will be permanently deleted.", + "密钥已删除": "Key has been deleted", + "删除密钥失败": "Failed to delete key", "图标": "Icon", "模型图标": "Model icon", "请输入图标名称": "Please enter the icon name", @@ -2094,5 +2099,36 @@ "原价": "Original price", "优惠": "Discount", "折": "% off", - "节省": "Save" + "节省": "Save", + "今天": "Today", + "近 7 天": "Last 7 Days", + "本周": "This Week", + "本月": "This Month", + "近 30 天": "Last 30 Days", + "代理设置": "Proxy Settings", + "更新Worker设置": "Update Worker Settings", + "SSRF防护设置": "SSRF Protection Settings", + "配置服务器端请求伪造(SSRF)防护,用于保护内网资源安全": "Configure Server-Side Request Forgery (SSRF) protection to secure internal network resources", + "SSRF防护详细说明": "SSRF protection prevents malicious users from using your server to access internal network resources. Configure whitelists for trusted domains/IPs and restrict allowed ports. Applies to file downloads, webhooks, and notifications.", + "启用SSRF防护(推荐开启以保护服务器安全)": "Enable SSRF Protection (Recommended for server security)", + "SSRF防护开关详细说明": "Master switch controls whether SSRF protection is enabled. When disabled, all SSRF checks are bypassed, allowing access to any URL. ⚠️ Only disable this feature in completely trusted environments.", + "允许访问私有IP地址(127.0.0.1、192.168.x.x等内网地址)": "Allow access to private IP addresses (127.0.0.1, 192.168.x.x and other internal addresses)", + "私有IP访问详细说明": "⚠️ Security Warning: Enabling this allows access to internal network resources (localhost, private networks). Only enable if you need to access internal services and understand the security implications.", + "域名白名单": "Domain Whitelist", + "支持通配符格式,如:example.com, *.api.example.com": "Supports wildcard format, e.g.: example.com, *.api.example.com", + "域名白名单详细说明": "Whitelisted domains bypass all SSRF checks and are allowed direct access. Supports exact domains (example.com) or wildcards (*.api.example.com) for subdomains. When whitelist is empty, all domains go through SSRF validation.", + "输入域名后回车,如:example.com": "Enter domain and press Enter, e.g.: example.com", + "支持CIDR格式,如:8.8.8.8, 192.168.1.0/24": "Supports CIDR format, e.g.: 8.8.8.8, 192.168.1.0/24", + "IP白名单详细说明": "Controls which IP addresses are allowed access. Use single IPs (8.8.8.8) or CIDR notation (192.168.1.0/24). Empty whitelist allows all IPs (subject to private IP settings), non-empty whitelist only allows listed IPs.", + "输入IP地址后回车,如:8.8.8.8": "Enter IP address and press Enter, e.g.: 8.8.8.8", + "允许的端口": "Allowed Ports", + "支持单个端口和端口范围,如:80, 443, 8000-8999": "Supports single ports and port ranges, e.g.: 80, 443, 8000-8999", + "端口配置详细说明": "Restrict external requests to specific ports. Use single ports (80, 443) or ranges (8000-8999). Empty list allows all ports. Default includes common web ports.", + "输入端口后回车,如:80 或 8000-8999": "Enter port and press Enter, e.g.: 80 or 8000-8999", + "更新SSRF防护设置": "Update SSRF Protection Settings", + "对域名启用 IP 过滤(实验性)": "Enable IP filtering for domains (experimental)", + "域名IP过滤详细说明": "⚠️ This is an experimental option. A domain may resolve to multiple IPv4/IPv6 addresses. If enabled, ensure the IP filter list covers these addresses, otherwise access may fail.", + "域名黑名单": "Domain Blacklist", + "白名单": "Whitelist", + "黑名单": "Blacklist" } diff --git a/web/src/i18n/locales/zh.json b/web/src/i18n/locales/zh.json index 5c7904fc5..4b6b1e680 100644 --- a/web/src/i18n/locales/zh.json +++ b/web/src/i18n/locales/zh.json @@ -9,5 +9,29 @@ "语言": "语言", "展开侧边栏": "展开侧边栏", "关闭侧边栏": "关闭侧边栏", - "注销成功!": "注销成功!" + "注销成功!": "注销成功!", + "代理设置": "代理设置", + "更新Worker设置": "更新Worker设置", + "SSRF防护设置": "SSRF防护设置", + "配置服务器端请求伪造(SSRF)防护,用于保护内网资源安全": "配置服务器端请求伪造(SSRF)防护,用于保护内网资源安全", + "SSRF防护详细说明": "SSRF防护可防止恶意用户利用您的服务器访问内网资源。您可以配置受信任域名/IP的白名单,并限制允许的端口。适用于文件下载、Webhook回调和通知功能。", + "启用SSRF防护(推荐开启以保护服务器安全)": "启用SSRF防护(推荐开启以保护服务器安全)", + "SSRF防护开关详细说明": "总开关控制是否启用SSRF防护功能。关闭后将跳过所有SSRF检查,允许访问任意URL。⚠️ 仅在完全信任环境中关闭此功能。", + "允许访问私有IP地址(127.0.0.1、192.168.x.x等内网地址)": "允许访问私有IP地址(127.0.0.1、192.168.x.x等内网地址)", + "私有IP访问详细说明": "⚠️ 安全警告:启用此选项将允许访问内网资源(本地主机、私有网络)。仅在需要访问内部服务且了解安全风险的情况下启用。", + "域名白名单": "域名白名单", + "支持通配符格式,如:example.com, *.api.example.com": "支持通配符格式,如:example.com, *.api.example.com", + "域名白名单详细说明": "白名单中的域名将绕过所有SSRF检查,直接允许访问。支持精确域名(example.com)或通配符(*.api.example.com)匹配子域名。白名单为空时,所有域名都需要通过SSRF检查。", + "输入域名后回车,如:example.com": "输入域名后回车,如:example.com", + "IP白名单": "IP白名单", + "支持CIDR格式,如:8.8.8.8, 192.168.1.0/24": "支持CIDR格式,如:8.8.8.8, 192.168.1.0/24", + "IP白名单详细说明": "控制允许访问的IP地址。支持单个IP(8.8.8.8)或CIDR网段(192.168.1.0/24)。空白名单允许所有IP(但仍受私有IP设置限制),非空白名单仅允许列表中的IP访问。", + "输入IP地址后回车,如:8.8.8.8": "输入IP地址后回车,如:8.8.8.8", + "允许的端口": "允许的端口", + "支持单个端口和端口范围,如:80, 443, 8000-8999": "支持单个端口和端口范围,如:80, 443, 8000-8999", + "端口配置详细说明": "限制外部请求只能访问指定端口。支持单个端口(80, 443)或端口范围(8000-8999)。空列表允许所有端口。默认包含常用Web端口。", + "输入端口后回车,如:80 或 8000-8999": "输入端口后回车,如:80 或 8000-8999", + "更新SSRF防护设置": "更新SSRF防护设置", + "域名IP过滤详细说明": "⚠️此功能为实验性选项,域名可能解析到多个 IPv4/IPv6 地址,若开启,请确保 IP 过滤列表覆盖这些地址,否则可能导致访问失败。", + "允许在 Stripe 支付中输入促销码": "允许在 Stripe 支付中输入促销码" } diff --git a/web/src/pages/Setting/Payment/SettingsPaymentGatewayStripe.jsx b/web/src/pages/Setting/Payment/SettingsPaymentGatewayStripe.jsx index 2f4ea210e..e4ddea110 100644 --- a/web/src/pages/Setting/Payment/SettingsPaymentGatewayStripe.jsx +++ b/web/src/pages/Setting/Payment/SettingsPaymentGatewayStripe.jsx @@ -45,6 +45,7 @@ export default function SettingsPaymentGateway(props) { StripePriceId: '', StripeUnitPrice: 8.0, StripeMinTopUp: 1, + StripePromotionCodesEnabled: false, }); const [originInputs, setOriginInputs] = useState({}); const formApiRef = useRef(null); @@ -63,6 +64,10 @@ export default function SettingsPaymentGateway(props) { props.options.StripeMinTopUp !== undefined ? parseFloat(props.options.StripeMinTopUp) : 1, + StripePromotionCodesEnabled: + props.options.StripePromotionCodesEnabled !== undefined + ? props.options.StripePromotionCodesEnabled + : false, }; setInputs(currentInputs); setOriginInputs({ ...currentInputs }); @@ -114,6 +119,16 @@ export default function SettingsPaymentGateway(props) { value: inputs.StripeMinTopUp.toString(), }); } + if ( + originInputs['StripePromotionCodesEnabled'] !== + inputs.StripePromotionCodesEnabled && + inputs.StripePromotionCodesEnabled !== undefined + ) { + options.push({ + key: 'StripePromotionCodesEnabled', + value: inputs.StripePromotionCodesEnabled ? 'true' : 'false', + }); + } // 发送请求 const requestQueue = options.map((opt) => @@ -225,6 +240,15 @@ export default function SettingsPaymentGateway(props) { placeholder={t('例如:2,就是最低充值2$')} /> + + + diff --git a/web/src/pages/Setting/Ratio/ModelRatioSettings.jsx b/web/src/pages/Setting/Ratio/ModelRatioSettings.jsx index b40951261..ed982edcf 100644 --- a/web/src/pages/Setting/Ratio/ModelRatioSettings.jsx +++ b/web/src/pages/Setting/Ratio/ModelRatioSettings.jsx @@ -225,8 +225,8 @@ export default function ModelRatioSettings(props) {