diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml index c87fcfceb..3e3ddc53b 100644 --- a/.github/workflows/linux-release.yml +++ b/.github/workflows/linux-release.yml @@ -38,21 +38,21 @@ jobs: - name: Build Backend (amd64) run: | go mod download - go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api + go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o new-api - name: Build Backend (arm64) run: | sudo apt-get update DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu - CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-arm64 + CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o new-api-arm64 - name: Release uses: softprops/action-gh-release@v1 if: startsWith(github.ref, 'refs/tags/') with: files: | - one-api - one-api-arm64 + new-api + new-api-arm64 draft: true generate_release_notes: true env: diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml index 1bc786ac0..8eaf2d67a 100644 --- a/.github/workflows/macos-release.yml +++ b/.github/workflows/macos-release.yml @@ -39,12 +39,12 @@ jobs: - name: Build Backend run: | go mod download - go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos + go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o new-api-macos - name: Release uses: softprops/action-gh-release@v1 if: startsWith(github.ref, 'refs/tags/') with: - files: one-api-macos + files: new-api-macos draft: true generate_release_notes: true env: diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml index de3d83d5e..30e864f34 100644 --- a/.github/workflows/windows-release.yml +++ b/.github/workflows/windows-release.yml @@ -41,12 +41,12 @@ jobs: - name: Build Backend run: | go mod download - go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe + go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o new-api.exe - name: Release uses: softprops/action-gh-release@v1 if: startsWith(github.ref, 'refs/tags/') with: - files: one-api.exe + files: new-api.exe draft: true generate_release_notes: true env: diff --git a/common/ip.go b/common/ip.go new file mode 100644 index 000000000..bfb64ee7f --- /dev/null +++ b/common/ip.go @@ -0,0 +1,22 @@ +package common + +import "net" + +func IsPrivateIP(ip net.IP) bool { + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + + private := []net.IPNet{ + {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, + {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, + {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, + } + + for _, privateNet := range private { + if privateNet.Contains(ip) { + return true + } + } + return false +} diff --git a/common/ssrf_protection.go b/common/ssrf_protection.go new file mode 100644 index 000000000..6f7d289f1 --- /dev/null +++ b/common/ssrf_protection.go @@ -0,0 +1,327 @@ +package common + +import ( + "fmt" + "net" + "net/url" + "strconv" + "strings" +) + +// SSRFProtection SSRF防护配置 +type SSRFProtection struct { + AllowPrivateIp bool + DomainFilterMode bool // true: 白名单, false: 黑名单 + DomainList []string // domain format, e.g. example.com, *.example.com + IpFilterMode bool // true: 白名单, false: 黑名单 + IpList []string // CIDR or single IP + AllowedPorts []int // 允许的端口范围 + ApplyIPFilterForDomain bool // 对域名启用IP过滤 +} + +// DefaultSSRFProtection 默认SSRF防护配置 +var DefaultSSRFProtection = &SSRFProtection{ + AllowPrivateIp: false, + DomainFilterMode: true, + DomainList: []string{}, + IpFilterMode: true, + IpList: []string{}, + AllowedPorts: []int{}, +} + +// isPrivateIP 检查IP是否为私有地址 +func isPrivateIP(ip net.IP) bool { + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + + // 检查私有网段 + private := []net.IPNet{ + {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8 + {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12 + {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16 + {IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8 + {IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地) + {IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播) + {IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留) + } + + for _, privateNet := range private { + if privateNet.Contains(ip) { + return true + } + } + + // 检查IPv6私有地址 + if ip.To4() == nil { + // IPv6 loopback + if ip.Equal(net.IPv6loopback) { + return true + } + // IPv6 link-local + if strings.HasPrefix(ip.String(), "fe80:") { + return true + } + // IPv6 unique local + if strings.HasPrefix(ip.String(), "fc") || strings.HasPrefix(ip.String(), "fd") { + return true + } + } + + return false +} + +// parsePortRanges 解析端口范围配置 +// 支持格式: "80", "443", "8000-9000" +func parsePortRanges(portConfigs []string) ([]int, error) { + var ports []int + + for _, config := range portConfigs { + config = strings.TrimSpace(config) + if config == "" { + continue + } + + if strings.Contains(config, "-") { + // 处理端口范围 "8000-9000" + parts := strings.Split(config, "-") + if len(parts) != 2 { + return nil, fmt.Errorf("invalid port range format: %s", config) + } + + startPort, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err != nil { + return nil, fmt.Errorf("invalid start port in range %s: %v", config, err) + } + + endPort, err := strconv.Atoi(strings.TrimSpace(parts[1])) + if err != nil { + return nil, fmt.Errorf("invalid end port in range %s: %v", config, err) + } + + if startPort > endPort { + return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", config) + } + + if startPort < 1 || startPort > 65535 || endPort < 1 || endPort > 65535 { + return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", config) + } + + // 添加范围内的所有端口 + for port := startPort; port <= endPort; port++ { + ports = append(ports, port) + } + } else { + // 处理单个端口 "80" + port, err := strconv.Atoi(config) + if err != nil { + return nil, fmt.Errorf("invalid port number: %s", config) + } + + if port < 1 || port > 65535 { + return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", port) + } + + ports = append(ports, port) + } + } + + return ports, nil +} + +// isAllowedPort 检查端口是否被允许 +func (p *SSRFProtection) isAllowedPort(port int) bool { + if len(p.AllowedPorts) == 0 { + return true // 如果没有配置端口限制,则允许所有端口 + } + + for _, allowedPort := range p.AllowedPorts { + if port == allowedPort { + return true + } + } + return false +} + +// isDomainWhitelisted 检查域名是否在白名单中 +func isDomainListed(domain string, list []string) bool { + if len(list) == 0 { + return false + } + + domain = strings.ToLower(domain) + for _, item := range list { + item = strings.ToLower(strings.TrimSpace(item)) + if item == "" { + continue + } + // 精确匹配 + if domain == item { + return true + } + // 通配符匹配 (*.example.com) + if strings.HasPrefix(item, "*.") { + suffix := strings.TrimPrefix(item, "*.") + if strings.HasSuffix(domain, "."+suffix) || domain == suffix { + return true + } + } + } + return false +} + +func (p *SSRFProtection) isDomainAllowed(domain string) bool { + listed := isDomainListed(domain, p.DomainList) + if p.DomainFilterMode { // 白名单 + return listed + } + // 黑名单 + return !listed +} + +// isIPWhitelisted 检查IP是否在白名单中 + +func isIPListed(ip net.IP, list []string) bool { + if len(list) == 0 { + return false + } + + for _, whitelistCIDR := range list { + _, network, err := net.ParseCIDR(whitelistCIDR) + if err != nil { + // 尝试作为单个IP处理 + if whitelistIP := net.ParseIP(whitelistCIDR); whitelistIP != nil { + if ip.Equal(whitelistIP) { + return true + } + } + continue + } + + if network.Contains(ip) { + return true + } + } + return false +} + +// IsIPAccessAllowed 检查IP是否允许访问 +func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool { + // 私有IP限制 + if isPrivateIP(ip) && !p.AllowPrivateIp { + return false + } + + listed := isIPListed(ip, p.IpList) + if p.IpFilterMode { // 白名单 + return listed + } + // 黑名单 + return !listed +} + +// ValidateURL 验证URL是否安全 +func (p *SSRFProtection) ValidateURL(urlStr string) error { + // 解析URL + u, err := url.Parse(urlStr) + if err != nil { + return fmt.Errorf("invalid URL format: %v", err) + } + + // 只允许HTTP/HTTPS协议 + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme) + } + + // 解析主机和端口 + host, portStr, err := net.SplitHostPort(u.Host) + if err != nil { + // 没有端口,使用默认端口 + host = u.Hostname() + if u.Scheme == "https" { + portStr = "443" + } else { + portStr = "80" + } + } + + // 验证端口 + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("invalid port: %s", portStr) + } + + if !p.isAllowedPort(port) { + return fmt.Errorf("port %d is not allowed", port) + } + + // 如果 host 是 IP,则跳过域名检查 + if ip := net.ParseIP(host); ip != nil { + if !p.IsIPAccessAllowed(ip) { + if isPrivateIP(ip) { + return fmt.Errorf("private IP address not allowed: %s", ip.String()) + } + if p.IpFilterMode { + return fmt.Errorf("ip not in whitelist: %s", ip.String()) + } + return fmt.Errorf("ip in blacklist: %s", ip.String()) + } + return nil + } + + // 先进行域名过滤 + if !p.isDomainAllowed(host) { + if p.DomainFilterMode { + return fmt.Errorf("domain not in whitelist: %s", host) + } + return fmt.Errorf("domain in blacklist: %s", host) + } + + // 若未启用对域名应用IP过滤,则到此通过 + if !p.ApplyIPFilterForDomain { + return nil + } + + // 解析域名对应IP并检查 + ips, err := net.LookupIP(host) + if err != nil { + return fmt.Errorf("DNS resolution failed for %s: %v", host, err) + } + for _, ip := range ips { + if !p.IsIPAccessAllowed(ip) { + if isPrivateIP(ip) && !p.AllowPrivateIp { + return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String()) + } + if p.IpFilterMode { + return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String()) + } + return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String()) + } + } + return nil +} + +// ValidateURLWithFetchSetting 使用FetchSetting配置验证URL +func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error { + // 如果SSRF防护被禁用,直接返回成功 + if !enableSSRFProtection { + return nil + } + + // 解析端口范围配置 + allowedPortInts, err := parsePortRanges(allowedPorts) + if err != nil { + return fmt.Errorf("request reject - invalid port configuration: %v", err) + } + + protection := &SSRFProtection{ + AllowPrivateIp: allowPrivateIp, + DomainFilterMode: domainFilterMode, + DomainList: domainList, + IpFilterMode: ipFilterMode, + IpList: ipList, + AllowedPorts: allowedPortInts, + ApplyIPFilterForDomain: applyIPFilterForDomain, + } + return protection.ValidateURL(urlStr) +} diff --git a/common/sys_log.go b/common/sys_log.go index 478015f07..b29adc3e6 100644 --- a/common/sys_log.go +++ b/common/sys_log.go @@ -2,9 +2,10 @@ package common import ( "fmt" - "github.com/gin-gonic/gin" "os" "time" + + "github.com/gin-gonic/gin" ) func SysLog(s string) { @@ -22,3 +23,33 @@ func FatalLog(v ...any) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) os.Exit(1) } + +func LogStartupSuccess(startTime time.Time, port string) { + + duration := time.Since(startTime) + durationMs := duration.Milliseconds() + + // Get network IPs + networkIps := GetNetworkIps() + + // Print blank line for spacing + fmt.Fprintf(gin.DefaultWriter, "\n") + + // Print the main success message + fmt.Fprintf(gin.DefaultWriter, " \033[32m%s %s\033[0m ready in %d ms\n", SystemName, Version, durationMs) + fmt.Fprintf(gin.DefaultWriter, "\n") + + // Skip fancy startup message in container environments + if !IsRunningInContainer() { + // Print local URL + fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mLocal:\033[0m http://localhost:%s/\n", port) + } + + // Print network URLs + for _, ip := range networkIps { + fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mNetwork:\033[0m http://%s:%s/\n", ip, port) + } + + // Print blank line for spacing + fmt.Fprintf(gin.DefaultWriter, "\n") +} diff --git a/common/utils.go b/common/utils.go index 883abfd1a..21f72ec6a 100644 --- a/common/utils.go +++ b/common/utils.go @@ -68,6 +68,78 @@ func GetIp() (ip string) { return } +func GetNetworkIps() []string { + var networkIps []string + ips, err := net.InterfaceAddrs() + if err != nil { + log.Println(err) + return networkIps + } + + for _, a := range ips { + if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.To4() != nil { + ip := ipNet.IP.String() + // Include common private network ranges + if strings.HasPrefix(ip, "10.") || + strings.HasPrefix(ip, "172.") || + strings.HasPrefix(ip, "192.168.") { + networkIps = append(networkIps, ip) + } + } + } + } + return networkIps +} + +// IsRunningInContainer detects if the application is running inside a container +func IsRunningInContainer() bool { + // Method 1: Check for .dockerenv file (Docker containers) + if _, err := os.Stat("/.dockerenv"); err == nil { + return true + } + + // Method 2: Check cgroup for container indicators + if data, err := os.ReadFile("/proc/1/cgroup"); err == nil { + content := string(data) + if strings.Contains(content, "docker") || + strings.Contains(content, "containerd") || + strings.Contains(content, "kubepods") || + strings.Contains(content, "/lxc/") { + return true + } + } + + // Method 3: Check environment variables commonly set by container runtimes + containerEnvVars := []string{ + "KUBERNETES_SERVICE_HOST", + "DOCKER_CONTAINER", + "container", + } + + for _, envVar := range containerEnvVars { + if os.Getenv(envVar) != "" { + return true + } + } + + // Method 4: Check if init process is not the traditional init + if data, err := os.ReadFile("/proc/1/comm"); err == nil { + comm := strings.TrimSpace(string(data)) + // In containers, process 1 is often not "init" or "systemd" + if comm != "init" && comm != "systemd" { + // Additional check: if it's a common container entrypoint + if strings.Contains(comm, "docker") || + strings.Contains(comm, "containerd") || + strings.Contains(comm, "runc") { + return true + } + } + } + + return false +} + var sizeKB = 1024 var sizeMB = sizeKB * 1024 var sizeGB = sizeMB * 1024 diff --git a/constant/task.go b/constant/task.go index 21790145b..e174fd60e 100644 --- a/constant/task.go +++ b/constant/task.go @@ -11,8 +11,10 @@ const ( SunoActionMusic = "MUSIC" SunoActionLyrics = "LYRICS" - TaskActionGenerate = "generate" - TaskActionTextGenerate = "textGenerate" + TaskActionGenerate = "generate" + TaskActionTextGenerate = "textGenerate" + TaskActionFirstTailGenerate = "firstTailGenerate" + TaskActionReferenceGenerate = "referenceGenerate" ) var SunoModel2Action = map[string]string{ diff --git a/controller/channel-test.go b/controller/channel-test.go index 5a668c488..9ea6eed75 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -90,6 +90,11 @@ func testChannel(channel *model.Channel, testModel string) testResult { requestPath = "/v1/embeddings" // 修改请求路径 } + // VolcEngine 图像生成模型 + if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { + requestPath = "/v1/images/generations" + } + c.Request = &http.Request{ Method: "POST", URL: &url.URL{Path: requestPath}, // 使用动态路径 @@ -109,6 +114,21 @@ func testChannel(channel *model.Channel, testModel string) testResult { } } + // 重新检查模型类型并更新请求路径 + if strings.Contains(strings.ToLower(testModel), "embedding") || + strings.HasPrefix(testModel, "m3e") || + strings.Contains(testModel, "bge-") || + strings.Contains(testModel, "embed") || + channel.Type == constant.ChannelTypeMokaAI { + requestPath = "/v1/embeddings" + c.Request.URL.Path = requestPath + } + + if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { + requestPath = "/v1/images/generations" + c.Request.URL.Path = requestPath + } + cache, err := model.GetUserCache(1) if err != nil { return testResult{ @@ -140,6 +160,9 @@ func testChannel(channel *model.Channel, testModel string) testResult { if c.Request.URL.Path == "/v1/embeddings" { relayFormat = types.RelayFormatEmbedding } + if c.Request.URL.Path == "/v1/images/generations" { + relayFormat = types.RelayFormatOpenAIImage + } info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil) @@ -201,6 +224,22 @@ func testChannel(channel *model.Channel, testModel string) testResult { } // 调用专门用于 Embedding 的转换函数 convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest) + } else if info.RelayMode == relayconstant.RelayModeImagesGenerations { + // 创建一个 ImageRequest + prompt := "cat" + if request.Prompt != nil { + if promptStr, ok := request.Prompt.(string); ok && promptStr != "" { + prompt = promptStr + } + } + imageRequest := dto.ImageRequest{ + Prompt: prompt, + Model: request.Model, + N: uint(request.N), + Size: request.Size, + } + // 调用专门用于图像生成的转换函数 + convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest) } else { // 对其他所有请求类型(如 Chat),保持原有逻辑 convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request) diff --git a/controller/channel.go b/controller/channel.go index 403eb04cc..5d075f3c5 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -8,6 +8,7 @@ import ( "one-api/constant" "one-api/dto" "one-api/model" + "one-api/service" "strconv" "strings" @@ -188,6 +189,8 @@ func FetchUpstreamModels(c *gin.Context) { url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader case constant.ChannelTypeAli: url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL) + case constant.ChannelTypeZhipu_v4: + url = fmt.Sprintf("%s/api/paas/v4/models", baseURL) default: url = fmt.Sprintf("%s/v1/models", baseURL) } @@ -501,9 +504,10 @@ func validateChannel(channel *model.Channel, isAdd bool) error { } type AddChannelRequest struct { - Mode string `json:"mode"` - MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` - Channel *model.Channel `json:"channel"` + Mode string `json:"mode"` + MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` + BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"` + Channel *model.Channel `json:"channel"` } func getVertexArrayKeys(keys string) ([]string, error) { @@ -616,6 +620,13 @@ func AddChannel(c *gin.Context) { } localChannel := addChannelRequest.Channel localChannel.Key = key + if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 { + keyPrefix := localChannel.Key + if len(localChannel.Key) > 8 { + keyPrefix = localChannel.Key[:8] + } + localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix) + } channels = append(channels, *localChannel) } err = model.BatchInsertChannels(channels) @@ -623,6 +634,7 @@ func AddChannel(c *gin.Context) { common.ApiError(c, err) return } + service.ResetProxyClientCache() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -884,6 +896,7 @@ func UpdateChannel(c *gin.Context) { return } model.InitChannelCache() + service.ResetProxyClientCache() channel.Key = "" clearChannelInfo(&channel.Channel) c.JSON(http.StatusOK, gin.H{ @@ -1093,8 +1106,8 @@ func CopyChannel(c *gin.Context) { // MultiKeyManageRequest represents the request for multi-key management operations type MultiKeyManageRequest struct { ChannelId int `json:"channel_id"` - Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status" - KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions + Action string `json:"action"` // "disable_key", "enable_key", "delete_key", "delete_disabled_keys", "get_key_status" + KeyIndex *int `json:"key_index,omitempty"` // for disable_key, enable_key, and delete_key actions Page int `json:"page,omitempty"` // for get_key_status pagination PageSize int `json:"page_size,omitempty"` // for get_key_status pagination Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all @@ -1422,6 +1435,86 @@ func ManageMultiKeys(c *gin.Context) { }) return + case "delete_key": + if request.KeyIndex == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "未指定要删除的密钥索引", + }) + return + } + + keyIndex := *request.KeyIndex + if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "密钥索引超出范围", + }) + return + } + + keys := channel.GetKeys() + var remainingKeys []string + var newStatusList = make(map[int]int) + var newDisabledTime = make(map[int]int64) + var newDisabledReason = make(map[int]string) + + newIndex := 0 + for i, key := range keys { + // 跳过要删除的密钥 + if i == keyIndex { + continue + } + + remainingKeys = append(remainingKeys, key) + + // 保留其他密钥的状态信息,重新索引 + if channel.ChannelInfo.MultiKeyStatusList != nil { + if status, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists && status != 1 { + newStatusList[newIndex] = status + } + } + if channel.ChannelInfo.MultiKeyDisabledTime != nil { + if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists { + newDisabledTime[newIndex] = t + } + } + if channel.ChannelInfo.MultiKeyDisabledReason != nil { + if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists { + newDisabledReason[newIndex] = r + } + } + newIndex++ + } + + if len(remainingKeys) == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "不能删除最后一个密钥", + }) + return + } + + // Update channel with remaining keys + channel.Key = strings.Join(remainingKeys, "\n") + channel.ChannelInfo.MultiKeySize = len(remainingKeys) + channel.ChannelInfo.MultiKeyStatusList = newStatusList + channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime + channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason + + err = channel.Update() + if err != nil { + common.ApiError(c, err) + return + } + + model.InitChannelCache() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "密钥已删除", + }) + return + case "delete_disabled_keys": keys := channel.GetKeys() var remainingKeys []string diff --git a/controller/option.go b/controller/option.go index e5f2b75b0..7d1c676f5 100644 --- a/controller/option.go +++ b/controller/option.go @@ -128,6 +128,33 @@ func UpdateOption(c *gin.Context) { }) return } + case "ImageRatio": + err = ratio_setting.UpdateImageRatioByJSONString(option.Value.(string)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "图片倍率设置失败: " + err.Error(), + }) + return + } + case "AudioRatio": + err = ratio_setting.UpdateAudioRatioByJSONString(option.Value.(string)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "音频倍率设置失败: " + err.Error(), + }) + return + } + case "AudioCompletionRatio": + err = ratio_setting.UpdateAudioCompletionRatioByJSONString(option.Value.(string)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "音频补全倍率设置失败: " + err.Error(), + }) + return + } case "ModelRequestRateLimitGroup": err = setting.CheckModelRequestRateLimitGroup(option.Value.(string)) if err != nil { diff --git a/controller/setup.go b/controller/setup.go index 44a7b3a73..3ae255e94 100644 --- a/controller/setup.go +++ b/controller/setup.go @@ -53,7 +53,7 @@ func GetSetup(c *gin.Context) { func PostSetup(c *gin.Context) { // Check if setup is already completed if constant.Setup { - c.JSON(400, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "系统已经初始化完成", }) @@ -66,7 +66,7 @@ func PostSetup(c *gin.Context) { var req SetupRequest err := c.ShouldBindJSON(&req) if err != nil { - c.JSON(400, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "请求参数有误", }) @@ -77,7 +77,7 @@ func PostSetup(c *gin.Context) { if !rootExists { // Validate username length: max 12 characters to align with model.User validation if len(req.Username) > 12 { - c.JSON(400, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "用户名长度不能超过12个字符", }) @@ -85,7 +85,7 @@ func PostSetup(c *gin.Context) { } // Validate password if req.Password != req.ConfirmPassword { - c.JSON(400, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "两次输入的密码不一致", }) @@ -93,7 +93,7 @@ func PostSetup(c *gin.Context) { } if len(req.Password) < 8 { - c.JSON(400, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "密码长度至少为8个字符", }) @@ -103,7 +103,7 @@ func PostSetup(c *gin.Context) { // Create root user hashedPassword, err := common.Password2Hash(req.Password) if err != nil { - c.JSON(500, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "系统错误: " + err.Error(), }) @@ -120,7 +120,7 @@ func PostSetup(c *gin.Context) { } err = model.DB.Create(&rootUser).Error if err != nil { - c.JSON(500, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "创建管理员账号失败: " + err.Error(), }) @@ -135,7 +135,7 @@ func PostSetup(c *gin.Context) { // Save operation modes to database for persistence err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled)) if err != nil { - c.JSON(500, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "保存自用模式设置失败: " + err.Error(), }) @@ -144,7 +144,7 @@ func PostSetup(c *gin.Context) { err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled)) if err != nil { - c.JSON(500, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "保存演示站点模式设置失败: " + err.Error(), }) @@ -160,7 +160,7 @@ func PostSetup(c *gin.Context) { } err = model.DB.Create(&setup).Error if err != nil { - c.JSON(500, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "系统初始化失败: " + err.Error(), }) diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go index d462acb4b..9a568d857 100644 --- a/controller/topup_stripe.go +++ b/controller/topup_stripe.go @@ -217,7 +217,7 @@ func genStripeLink(referenceId string, customerId string, email string, amount i params := &stripe.CheckoutSessionParams{ ClientReferenceID: stripe.String(referenceId), - SuccessURL: stripe.String(system_setting.ServerAddress + "/log"), + SuccessURL: stripe.String(system_setting.ServerAddress + "/console/log"), CancelURL: stripe.String(system_setting.ServerAddress + "/topup"), LineItems: []*stripe.CheckoutSessionLineItemParams{ { @@ -225,7 +225,8 @@ func genStripeLink(referenceId string, customerId string, email string, amount i Quantity: stripe.Int64(amount), }, }, - Mode: stripe.String(string(stripe.CheckoutSessionModePayment)), + Mode: stripe.String(string(stripe.CheckoutSessionModePayment)), + AllowPromotionCodes: stripe.Bool(setting.StripePromotionCodesEnabled), } if "" == customerId { diff --git a/dto/channel_settings.go b/dto/channel_settings.go index 8791f516e..d6d6e0848 100644 --- a/dto/channel_settings.go +++ b/dto/channel_settings.go @@ -19,4 +19,12 @@ const ( type ChannelOtherSettings struct { AzureResponsesVersion string `json:"azure_responses_version,omitempty"` VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key" + OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"` +} + +func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool { + if s == nil || s.OpenRouterEnterprise == nil { + return false + } + return *s.OpenRouterEnterprise } diff --git a/dto/gemini.go b/dto/gemini.go index 5df67ba0b..80552aade 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -14,7 +14,30 @@ type GeminiChatRequest struct { SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"` GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"` Tools json.RawMessage `json:"tools,omitempty"` + ToolConfig *ToolConfig `json:"toolConfig,omitempty"` SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"` + CachedContent string `json:"cachedContent,omitempty"` +} + +type ToolConfig struct { + FunctionCallingConfig *FunctionCallingConfig `json:"functionCallingConfig,omitempty"` + RetrievalConfig *RetrievalConfig `json:"retrievalConfig,omitempty"` +} + +type FunctionCallingConfig struct { + Mode FunctionCallingConfigMode `json:"mode,omitempty"` + AllowedFunctionNames []string `json:"allowedFunctionNames,omitempty"` +} +type FunctionCallingConfigMode string + +type RetrievalConfig struct { + LatLng *LatLng `json:"latLng,omitempty"` + LanguageCode string `json:"languageCode,omitempty"` +} + +type LatLng struct { + Latitude *float64 `json:"latitude,omitempty"` + Longitude *float64 `json:"longitude,omitempty"` } func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta { @@ -228,6 +251,7 @@ type GeminiChatTool struct { GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"` CodeExecution any `json:"codeExecution,omitempty"` FunctionDeclarations any `json:"functionDeclarations,omitempty"` + URLContext any `json:"urlContext,omitempty"` } type GeminiChatGenerationConfig struct { @@ -239,12 +263,20 @@ type GeminiChatGenerationConfig struct { StopSequences []string `json:"stopSequences,omitempty"` ResponseMimeType string `json:"responseMimeType,omitempty"` ResponseSchema any `json:"responseSchema,omitempty"` + ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"` + PresencePenalty *float32 `json:"presencePenalty,omitempty"` + FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"` + ResponseLogprobs bool `json:"responseLogprobs,omitempty"` + Logprobs *int32 `json:"logprobs,omitempty"` + MediaResolution MediaResolution `json:"mediaResolution,omitempty"` Seed int64 `json:"seed,omitempty"` ResponseModalities []string `json:"responseModalities,omitempty"` ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"` SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config } +type MediaResolution string + type GeminiChatCandidate struct { Content GeminiChatContent `json:"content"` FinishReason *string `json:"finishReason"` diff --git a/dto/openai_request.go b/dto/openai_request.go index cd05a63c9..191fa638f 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -772,11 +772,12 @@ type OpenAIResponsesRequest struct { Instructions json.RawMessage `json:"instructions,omitempty"` MaxOutputTokens uint `json:"max_output_tokens,omitempty"` Metadata json.RawMessage `json:"metadata,omitempty"` - ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` + ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"` PreviousResponseID string `json:"previous_response_id,omitempty"` Reasoning *Reasoning `json:"reasoning,omitempty"` ServiceTier string `json:"service_tier,omitempty"` - Store bool `json:"store,omitempty"` + Store json.RawMessage `json:"store,omitempty"` + PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"` Stream bool `json:"stream,omitempty"` Temperature float64 `json:"temperature,omitempty"` Text json.RawMessage `json:"text,omitempty"` diff --git a/dto/openai_response.go b/dto/openai_response.go index 966748cb5..6353c15ff 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -6,6 +6,10 @@ import ( "one-api/types" ) +const ( + ResponsesOutputTypeImageGenerationCall = "image_generation_call" +) + type SimpleResponse struct { Usage `json:"usage"` Error any `json:"error"` @@ -273,6 +277,42 @@ func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError { return GetOpenAIError(o.Error) } +func (o *OpenAIResponsesResponse) HasImageGenerationCall() bool { + if len(o.Output) == 0 { + return false + } + for _, output := range o.Output { + if output.Type == ResponsesOutputTypeImageGenerationCall { + return true + } + } + return false +} + +func (o *OpenAIResponsesResponse) GetQuality() string { + if len(o.Output) == 0 { + return "" + } + for _, output := range o.Output { + if output.Type == ResponsesOutputTypeImageGenerationCall { + return output.Quality + } + } + return "" +} + +func (o *OpenAIResponsesResponse) GetSize() string { + if len(o.Output) == 0 { + return "" + } + for _, output := range o.Output { + if output.Type == ResponsesOutputTypeImageGenerationCall { + return output.Size + } + } + return "" +} + type IncompleteDetails struct { Reasoning string `json:"reasoning"` } @@ -283,6 +323,8 @@ type ResponsesOutput struct { Status string `json:"status"` Role string `json:"role"` Content []ResponsesOutputContent `json:"content"` + Quality string `json:"quality"` + Size string `json:"size"` } type ResponsesOutputContent struct { diff --git a/main.go b/main.go index 0caf53617..e12dddc5a 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "embed" "fmt" "log" @@ -16,6 +17,8 @@ import ( "one-api/setting/ratio_setting" "os" "strconv" + "strings" + "time" "github.com/bytedance/gopkg/util/gopool" "github.com/gin-contrib/sessions" @@ -33,6 +36,7 @@ var buildFS embed.FS var indexPage []byte func main() { + startTime := time.Now() err := InitResources() if err != nil { @@ -145,11 +149,31 @@ func main() { }) server.Use(sessions.Sessions("session", store)) + analyticsInjectBuilder := &strings.Builder{} + if os.Getenv("UMAMI_WEBSITE_ID") != "" { + umamiSiteID := os.Getenv("UMAMI_WEBSITE_ID") + umamiScriptURL := os.Getenv("UMAMI_SCRIPT_URL") + if umamiScriptURL == "" { + umamiScriptURL = "https://analytics.umami.is/script.js" + } + analyticsInjectBuilder.WriteString("") + } + analyticsInject := analyticsInjectBuilder.String() + indexPage = bytes.ReplaceAll(indexPage, []byte("\n"), []byte(analyticsInject)) + router.SetRouter(server, buildFS, indexPage) var port = os.Getenv("PORT") if port == "" { port = strconv.Itoa(*common.Port) } + + // Log startup success message + common.LogStartupSuccess(startTime, port) + err = server.Run(":" + port) if err != nil { common.FatalLog("failed to start HTTP server: " + err.Error()) @@ -204,4 +228,4 @@ func InitResources() error { return err } return nil -} \ No newline at end of file +} diff --git a/model/option.go b/model/option.go index fefee4e7d..9ace8fece 100644 --- a/model/option.go +++ b/model/option.go @@ -82,6 +82,7 @@ func InitOptionMap() { common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret common.OptionMap["StripePriceId"] = setting.StripePriceId common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(setting.StripeUnitPrice, 'f', -1, 64) + common.OptionMap["StripePromotionCodesEnabled"] = strconv.FormatBool(setting.StripePromotionCodesEnabled) common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() common.OptionMap["Chats"] = setting.Chats2JsonString() common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString() @@ -112,6 +113,9 @@ func InitOptionMap() { common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString() common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString() + common.OptionMap["ImageRatio"] = ratio_setting.ImageRatio2JSONString() + common.OptionMap["AudioRatio"] = ratio_setting.AudioRatio2JSONString() + common.OptionMap["AudioCompletionRatio"] = ratio_setting.AudioCompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink //common.OptionMap["ChatLink"] = common.ChatLink //common.OptionMap["ChatLink2"] = common.ChatLink2 @@ -327,6 +331,8 @@ func updateOptionMap(key string, value string) (err error) { setting.StripeUnitPrice, _ = strconv.ParseFloat(value, 64) case "StripeMinTopUp": setting.StripeMinTopUp, _ = strconv.Atoi(value) + case "StripePromotionCodesEnabled": + setting.StripePromotionCodesEnabled = value == "true" case "TopupGroupRatio": err = common.UpdateTopupGroupRatioByJSONString(value) case "GitHubClientId": @@ -397,6 +403,12 @@ func updateOptionMap(key string, value string) (err error) { err = ratio_setting.UpdateModelPriceByJSONString(value) case "CacheRatio": err = ratio_setting.UpdateCacheRatioByJSONString(value) + case "ImageRatio": + err = ratio_setting.UpdateImageRatioByJSONString(value) + case "AudioRatio": + err = ratio_setting.UpdateAudioRatioByJSONString(value) + case "AudioCompletionRatio": + err = ratio_setting.UpdateAudioCompletionRatioByJSONString(value) case "TopUpLink": common.TopUpLink = value //case "ChatLink": diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index a065caff7..79a0f7060 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -265,6 +265,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http resp, err := client.Do(req) if err != nil { + logger.LogError(c, "do request failed: "+err.Error()) return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed")) } if resp == nil { diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go index 72d0f9890..5ac7ce998 100644 --- a/relay/channel/aws/constants.go +++ b/relay/channel/aws/constants.go @@ -21,6 +21,10 @@ var awsModelIDMap = map[string]string{ "nova-lite-v1:0": "amazon.nova-lite-v1:0", "nova-pro-v1:0": "amazon.nova-pro-v1:0", "nova-premier-v1:0": "amazon.nova-premier-v1:0", + "nova-canvas-v1:0": "amazon.nova-canvas-v1:0", + "nova-reel-v1:0": "amazon.nova-reel-v1:0", + "nova-reel-v1:1": "amazon.nova-reel-v1:1", + "nova-sonic-v1:0": "amazon.nova-sonic-v1:0", } var awsModelCanCrossRegionMap = map[string]map[string]bool{ @@ -82,10 +86,27 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{ "apac": true, }, "amazon.nova-premier-v1:0": { + "us": true, + }, + "amazon.nova-canvas-v1:0": { "us": true, "eu": true, "apac": true, - }} + }, + "amazon.nova-reel-v1:0": { + "us": true, + "eu": true, + "apac": true, + }, + "amazon.nova-reel-v1:1": { + "us": true, + }, + "amazon.nova-sonic-v1:0": { + "us": true, + "eu": true, + "apac": true, + }, +} var awsRegionCrossModelPrefixMap = map[string]string{ "us": "us", diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index 17d732ab0..962f8794a 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -3,17 +3,17 @@ package deepseek import ( "errors" "fmt" + "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/claude" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" "one-api/types" "strings" - - "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt } func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { - adaptor := openai.Adaptor{} + adaptor := claude.Adaptor{} return adaptor.ConvertClaudeRequest(c, info, req) } @@ -44,14 +44,19 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { fimBaseUrl := info.ChannelBaseUrl - if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") { - fimBaseUrl += "/beta" - } - switch info.RelayMode { - case constant.RelayModeCompletions: - return fmt.Sprintf("%s/completions", fimBaseUrl), nil + switch info.RelayFormat { + case types.RelayFormatClaude: + return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil default: - return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil + if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") { + fimBaseUrl += "/beta" + } + switch info.RelayMode { + case constant.RelayModeCompletions: + return fmt.Sprintf("%s/completions", fimBaseUrl), nil + default: + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil + } } } @@ -87,12 +92,17 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) - } else { - usage, err = openai.OpenaiHandler(c, info, resp) + switch info.RelayFormat { + case types.RelayFormatClaude: + if info.IsStream { + return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) + } else { + return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) + } + default: + adaptor := openai.Adaptor{} + return adaptor.DoResponse(c, resp, info) } - return } func (a *Adaptor) GetModelList() []string { diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 4968f78fe..57542aa5a 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -215,8 +215,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeGemini { - if strings.HasSuffix(info.RequestURLPath, ":embedContent") || - strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") { + if strings.Contains(info.RequestURLPath, ":embedContent") || + strings.Contains(info.RequestURLPath, ":batchEmbedContents") { return NativeGeminiEmbeddingHandler(c, resp, info) } if info.IsStream { diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index eb4afbae1..c8e9c7579 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -23,6 +23,7 @@ import ( "github.com/gin-gonic/gin" ) +// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob var geminiSupportedMimeTypes = map[string]bool{ "application/pdf": true, "audio/mpeg": true, @@ -30,6 +31,7 @@ var geminiSupportedMimeTypes = map[string]bool{ "audio/wav": true, "image/png": true, "image/jpeg": true, + "image/webp": true, "text/plain": true, "video/mov": true, "video/mpeg": true, @@ -243,6 +245,7 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools)) googleSearch := false codeExecution := false + urlContext := false for _, tool := range textRequest.Tools { if tool.Function.Name == "googleSearch" { googleSearch = true @@ -252,6 +255,10 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i codeExecution = true continue } + if tool.Function.Name == "urlContext" { + urlContext = true + continue + } if tool.Function.Parameters != nil { params, ok := tool.Function.Parameters.(map[string]interface{}) @@ -279,6 +286,11 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i GoogleSearch: make(map[string]string), }) } + if urlContext { + geminiTools = append(geminiTools, dto.GeminiChatTool{ + URLContext: make(map[string]string), + }) + } if len(functions) > 0 { geminiTools = append(geminiTools, dto.GeminiChatTool{ FunctionDeclarations: functions, diff --git a/relay/channel/moonshot/adaptor.go b/relay/channel/moonshot/adaptor.go index e290c239d..f24976bb3 100644 --- a/relay/channel/moonshot/adaptor.go +++ b/relay/channel/moonshot/adaptor.go @@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt } func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { - adaptor := openai.Adaptor{} + adaptor := claude.Adaptor{} return adaptor.ConvertClaudeRequest(c, info, req) } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index d6b5b697e..bafe73b92 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -10,6 +10,7 @@ import ( relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/types" + "strings" "github.com/gin-gonic/gin" ) @@ -17,10 +18,7 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") -} +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { return nil, errors.New("not implemented") } func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { openaiAdaptor := openai.Adaptor{} @@ -31,32 +29,21 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{ IncludeUsage: true, } - return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest)) + // map to ollama chat request (Claude -> OpenAI -> Ollama chat) + return openAIChatToOllamaChat(c, openaiRequest.(*dto.GeneralOpenAIRequest)) } -func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { - //TODO implement me - return nil, errors.New("not implemented") -} +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { return nil, errors.New("not implemented") } -func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") -} +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { return nil, errors.New("not implemented") } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - if info.RelayFormat == types.RelayFormatClaude { - return info.ChannelBaseUrl + "/v1/chat/completions", nil - } - switch info.RelayMode { - case relayconstant.RelayModeEmbeddings: - return info.ChannelBaseUrl + "/api/embed", nil - default: - return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil - } + if info.RelayMode == relayconstant.RelayModeEmbeddings { return info.ChannelBaseUrl + "/api/embed", nil } + if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { return info.ChannelBaseUrl + "/api/generate", nil } + return info.ChannelBaseUrl + "/api/chat", nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -66,10 +53,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel } func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") + if request == nil { return nil, errors.New("request is nil") } + // decide generate or chat + if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { + return openAIToGenerate(c, request) } - return requestOpenAI2Ollama(c, request) + return openAIChatToOllamaChat(c, request) } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { @@ -80,10 +69,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return requestOpenAI2Embeddings(request), nil } -func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { - // TODO implement me - return nil, errors.New("not implemented") -} +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { return nil, errors.New("not implemented") } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) @@ -92,15 +78,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayMode { case relayconstant.RelayModeEmbeddings: - usage, err = ollamaEmbeddingHandler(c, info, resp) + return ollamaEmbeddingHandler(c, info, resp) default: if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) - } else { - usage, err = openai.OpenaiHandler(c, info, resp) + return ollamaStreamHandler(c, info, resp) } + return ollamaChatHandler(c, info, resp) } - return } func (a *Adaptor) GetModelList() []string { diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go index 317c2a4a1..45e49ab43 100644 --- a/relay/channel/ollama/dto.go +++ b/relay/channel/ollama/dto.go @@ -2,48 +2,69 @@ package ollama import ( "encoding/json" - "one-api/dto" ) -type OllamaRequest struct { - Model string `json:"model,omitempty"` - Messages []dto.Message `json:"messages,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - Seed float64 `json:"seed,omitempty"` - Topp float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Stop any `json:"stop,omitempty"` - MaxTokens uint `json:"max_tokens,omitempty"` - Tools []dto.ToolCallRequest `json:"tools,omitempty"` - ResponseFormat any `json:"response_format,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - Suffix any `json:"suffix,omitempty"` - StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"` - Prompt any `json:"prompt,omitempty"` - Think json.RawMessage `json:"think,omitempty"` +type OllamaChatMessage struct { + Role string `json:"role"` + Content string `json:"content,omitempty"` + Images []string `json:"images,omitempty"` + ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"` + ToolName string `json:"tool_name,omitempty"` + Thinking json.RawMessage `json:"thinking,omitempty"` } -type Options struct { - Seed int `json:"seed,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopK int `json:"top_k,omitempty"` - TopP float64 `json:"top_p,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - NumPredict int `json:"num_predict,omitempty"` - NumCtx int `json:"num_ctx,omitempty"` +type OllamaToolFunction struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters interface{} `json:"parameters,omitempty"` +} + +type OllamaTool struct { + Type string `json:"type"` + Function OllamaToolFunction `json:"function"` +} + +type OllamaToolCall struct { + Function struct { + Name string `json:"name"` + Arguments interface{} `json:"arguments"` + } `json:"function"` +} + +type OllamaChatRequest struct { + Model string `json:"model"` + Messages []OllamaChatMessage `json:"messages"` + Tools interface{} `json:"tools,omitempty"` + Format interface{} `json:"format,omitempty"` + Stream bool `json:"stream,omitempty"` + Options map[string]any `json:"options,omitempty"` + KeepAlive interface{} `json:"keep_alive,omitempty"` + Think json.RawMessage `json:"think,omitempty"` +} + +type OllamaGenerateRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt,omitempty"` + Suffix string `json:"suffix,omitempty"` + Images []string `json:"images,omitempty"` + Format interface{} `json:"format,omitempty"` + Stream bool `json:"stream,omitempty"` + Options map[string]any `json:"options,omitempty"` + KeepAlive interface{} `json:"keep_alive,omitempty"` + Think json.RawMessage `json:"think,omitempty"` } type OllamaEmbeddingRequest struct { - Model string `json:"model,omitempty"` - Input []string `json:"input"` - Options *Options `json:"options,omitempty"` + Model string `json:"model"` + Input interface{} `json:"input"` + Options map[string]any `json:"options,omitempty"` + Dimensions int `json:"dimensions,omitempty"` } type OllamaEmbeddingResponse struct { - Error string `json:"error,omitempty"` - Model string `json:"model"` - Embedding [][]float64 `json:"embeddings,omitempty"` + Error string `json:"error,omitempty"` + Model string `json:"model"` + Embeddings [][]float64 `json:"embeddings"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` } + diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index 27c67b4ec..3b67f9525 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -1,6 +1,7 @@ package ollama import ( + "encoding/json" "fmt" "io" "net/http" @@ -14,121 +15,176 @@ import ( "github.com/gin-gonic/gin" ) -func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) { - messages := make([]dto.Message, 0, len(request.Messages)) - for _, message := range request.Messages { - if !message.IsStringContent() { - mediaMessages := message.ParseContent() - for j, mediaMessage := range mediaMessages { - if mediaMessage.Type == dto.ContentTypeImageURL { - imageUrl := mediaMessage.GetImageMedia() - // check if not base64 - if strings.HasPrefix(imageUrl.Url, "http") { - fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama") - if err != nil { - return nil, err +func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) { + chatReq := &OllamaChatRequest{ + Model: r.Model, + Stream: r.Stream, + Options: map[string]any{}, + Think: r.Think, + } + if r.ResponseFormat != nil { + if r.ResponseFormat.Type == "json" { + chatReq.Format = "json" + } else if r.ResponseFormat.Type == "json_schema" { + if len(r.ResponseFormat.JsonSchema) > 0 { + var schema any + _ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema) + chatReq.Format = schema + } + } + } + + // options mapping + if r.Temperature != nil { chatReq.Options["temperature"] = r.Temperature } + if r.TopP != 0 { chatReq.Options["top_p"] = r.TopP } + if r.TopK != 0 { chatReq.Options["top_k"] = r.TopK } + if r.FrequencyPenalty != 0 { chatReq.Options["frequency_penalty"] = r.FrequencyPenalty } + if r.PresencePenalty != 0 { chatReq.Options["presence_penalty"] = r.PresencePenalty } + if r.Seed != 0 { chatReq.Options["seed"] = int(r.Seed) } + if mt := r.GetMaxTokens(); mt != 0 { chatReq.Options["num_predict"] = int(mt) } + + if r.Stop != nil { + switch v := r.Stop.(type) { + case string: + chatReq.Options["stop"] = []string{v} + case []string: + chatReq.Options["stop"] = v + case []any: + arr := make([]string,0,len(v)) + for _, i := range v { if s,ok:=i.(string); ok { arr = append(arr,s) } } + if len(arr)>0 { chatReq.Options["stop"] = arr } + } + } + + if len(r.Tools) > 0 { + tools := make([]OllamaTool,0,len(r.Tools)) + for _, t := range r.Tools { + tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}}) + } + chatReq.Tools = tools + } + + chatReq.Messages = make([]OllamaChatMessage,0,len(r.Messages)) + for _, m := range r.Messages { + var textBuilder strings.Builder + var images []string + if m.IsStringContent() { + textBuilder.WriteString(m.StringContent()) + } else { + parts := m.ParseContent() + for _, part := range parts { + if part.Type == dto.ContentTypeImageURL { + img := part.GetImageMedia() + if img != nil && img.Url != "" { + var base64Data string + if strings.HasPrefix(img.Url, "http") { + fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat") + if err != nil { return nil, err } + base64Data = fileData.Base64Data + } else if strings.HasPrefix(img.Url, "data:") { + if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) { base64Data = img.Url[idx+1:] } + } else { + base64Data = img.Url } - imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data) + if base64Data != "" { images = append(images, base64Data) } } - mediaMessage.ImageUrl = imageUrl - mediaMessages[j] = mediaMessage + } else if part.Type == dto.ContentTypeText { + textBuilder.WriteString(part.Text) } } - message.SetMediaContent(mediaMessages) } - messages = append(messages, dto.Message{ - Role: message.Role, - Content: message.Content, - ToolCalls: message.ToolCalls, - ToolCallId: message.ToolCallId, - }) + cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()} + if len(images)>0 { cm.Images = images } + if m.Role == "tool" && m.Name != nil { cm.ToolName = *m.Name } + if m.ToolCalls != nil && len(m.ToolCalls) > 0 { + parsed := m.ParseToolCalls() + if len(parsed) > 0 { + calls := make([]OllamaToolCall,0,len(parsed)) + for _, tc := range parsed { + var args interface{} + if tc.Function.Arguments != "" { _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) } + if args==nil { args = map[string]any{} } + oc := OllamaToolCall{} + oc.Function.Name = tc.Function.Name + oc.Function.Arguments = args + calls = append(calls, oc) + } + cm.ToolCalls = calls + } + } + chatReq.Messages = append(chatReq.Messages, cm) } - str, ok := request.Stop.(string) - var Stop []string - if ok { - Stop = []string{str} - } else { - Stop, _ = request.Stop.([]string) - } - ollamaRequest := &OllamaRequest{ - Model: request.Model, - Messages: messages, - Stream: request.Stream, - Temperature: request.Temperature, - Seed: request.Seed, - Topp: request.TopP, - TopK: request.TopK, - Stop: Stop, - Tools: request.Tools, - MaxTokens: request.GetMaxTokens(), - ResponseFormat: request.ResponseFormat, - FrequencyPenalty: request.FrequencyPenalty, - PresencePenalty: request.PresencePenalty, - Prompt: request.Prompt, - StreamOptions: request.StreamOptions, - Suffix: request.Suffix, - } - ollamaRequest.Think = request.Think - return ollamaRequest, nil + return chatReq, nil } -func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest { - return &OllamaEmbeddingRequest{ - Model: request.Model, - Input: request.ParseInput(), - Options: &Options{ - Seed: int(request.Seed), - Temperature: request.Temperature, - TopP: request.TopP, - FrequencyPenalty: request.FrequencyPenalty, - PresencePenalty: request.PresencePenalty, - }, +// openAIToGenerate converts OpenAI completions request to Ollama generate +func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGenerateRequest, error) { + gen := &OllamaGenerateRequest{ + Model: r.Model, + Stream: r.Stream, + Options: map[string]any{}, + Think: r.Think, } + // Prompt may be in r.Prompt (string or []any) + if r.Prompt != nil { + switch v := r.Prompt.(type) { + case string: + gen.Prompt = v + case []any: + var sb strings.Builder + for _, it := range v { if s,ok:=it.(string); ok { sb.WriteString(s) } } + gen.Prompt = sb.String() + default: + gen.Prompt = fmt.Sprintf("%v", r.Prompt) + } + } + if r.Suffix != nil { if s,ok:=r.Suffix.(string); ok { gen.Suffix = s } } + if r.ResponseFormat != nil { + if r.ResponseFormat.Type == "json" { gen.Format = "json" } else if r.ResponseFormat.Type == "json_schema" { var schema any; _ = json.Unmarshal(r.ResponseFormat.JsonSchema,&schema); gen.Format=schema } + } + if r.Temperature != nil { gen.Options["temperature"] = r.Temperature } + if r.TopP != 0 { gen.Options["top_p"] = r.TopP } + if r.TopK != 0 { gen.Options["top_k"] = r.TopK } + if r.FrequencyPenalty != 0 { gen.Options["frequency_penalty"] = r.FrequencyPenalty } + if r.PresencePenalty != 0 { gen.Options["presence_penalty"] = r.PresencePenalty } + if r.Seed != 0 { gen.Options["seed"] = int(r.Seed) } + if mt := r.GetMaxTokens(); mt != 0 { gen.Options["num_predict"] = int(mt) } + if r.Stop != nil { + switch v := r.Stop.(type) { + case string: gen.Options["stop"] = []string{v} + case []string: gen.Options["stop"] = v + case []any: arr:=make([]string,0,len(v)); for _,i:= range v { if s,ok:=i.(string); ok { arr=append(arr,s) } }; if len(arr)>0 { gen.Options["stop"]=arr } + } + } + return gen, nil +} + +func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest { + opts := map[string]any{} + if r.Temperature != nil { opts["temperature"] = r.Temperature } + if r.TopP != 0 { opts["top_p"] = r.TopP } + if r.FrequencyPenalty != 0 { opts["frequency_penalty"] = r.FrequencyPenalty } + if r.PresencePenalty != 0 { opts["presence_penalty"] = r.PresencePenalty } + if r.Seed != 0 { opts["seed"] = int(r.Seed) } + if r.Dimensions != 0 { opts["dimensions"] = r.Dimensions } + input := r.ParseInput() + if len(input)==1 { return &OllamaEmbeddingRequest{Model:r.Model, Input: input[0], Options: opts, Dimensions:r.Dimensions} } + return &OllamaEmbeddingRequest{Model:r.Model, Input: input, Options: opts, Dimensions:r.Dimensions} } func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - var ollamaEmbeddingResponse OllamaEmbeddingResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } + var oResp OllamaEmbeddingResponse + body, err := io.ReadAll(resp.Body) + if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } service.CloseResponseBodyGracefully(resp) - err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse) - if err != nil { - return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } - if ollamaEmbeddingResponse.Error != "" { - return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } - flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding) - data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1) - data = append(data, dto.OpenAIEmbeddingResponseItem{ - Embedding: flattenedEmbeddings, - Object: "embedding", - }) - usage := &dto.Usage{ - TotalTokens: info.PromptTokens, - CompletionTokens: 0, - PromptTokens: info.PromptTokens, - } - embeddingResponse := &dto.OpenAIEmbeddingResponse{ - Object: "list", - Data: data, - Model: info.UpstreamModelName, - Usage: *usage, - } - doResponseBody, err := common.Marshal(embeddingResponse) - if err != nil { - return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } - service.IOCopyBytesGracefully(c, resp, doResponseBody) + if err = common.Unmarshal(body, &oResp); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } + if oResp.Error != "" { return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } + data := make([]dto.OpenAIEmbeddingResponseItem,0,len(oResp.Embeddings)) + for i, emb := range oResp.Embeddings { data = append(data, dto.OpenAIEmbeddingResponseItem{Index:i,Object:"embedding",Embedding:emb}) } + usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens:0, TotalTokens: oResp.PromptEvalCount} + embResp := &dto.OpenAIEmbeddingResponse{Object:"list", Data:data, Model: info.UpstreamModelName, Usage:*usage} + out, _ := common.Marshal(embResp) + service.IOCopyBytesGracefully(c, resp, out) return usage, nil } -func flattenEmbeddings(embeddings [][]float64) []float64 { - flattened := []float64{} - for _, row := range embeddings { - flattened = append(flattened, row...) - } - return flattened -} diff --git a/relay/channel/ollama/stream.go b/relay/channel/ollama/stream.go new file mode 100644 index 000000000..964f11d90 --- /dev/null +++ b/relay/channel/ollama/stream.go @@ -0,0 +1,210 @@ +package ollama + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/dto" + "one-api/logger" + relaycommon "one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "one-api/types" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +type ollamaChatStreamChunk struct { + Model string `json:"model"` + CreatedAt string `json:"created_at"` + // chat + Message *struct { + Role string `json:"role"` + Content string `json:"content"` + Thinking json.RawMessage `json:"thinking"` + ToolCalls []struct { + Function struct { + Name string `json:"name"` + Arguments interface{} `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } `json:"message"` + // generate + Response string `json:"response"` + Done bool `json:"done"` + DoneReason string `json:"done_reason"` + TotalDuration int64 `json:"total_duration"` + LoadDuration int64 `json:"load_duration"` + PromptEvalCount int `json:"prompt_eval_count"` + EvalCount int `json:"eval_count"` + PromptEvalDuration int64 `json:"prompt_eval_duration"` + EvalDuration int64 `json:"eval_duration"` +} + +func toUnix(ts string) int64 { + if ts == "" { return time.Now().Unix() } + // try time.RFC3339 or with nanoseconds + t, err := time.Parse(time.RFC3339Nano, ts) + if err != nil { t2, err2 := time.Parse(time.RFC3339, ts); if err2==nil { return t2.Unix() }; return time.Now().Unix() } + return t.Unix() +} + +func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) } + defer service.CloseResponseBodyGracefully(resp) + + helper.SetEventStreamHeaders(c) + scanner := bufio.NewScanner(resp.Body) + usage := &dto.Usage{} + var model = info.UpstreamModelName + var responseId = common.GetUUID() + var created = time.Now().Unix() + var toolCallIndex int + start := helper.GenerateStartEmptyResponse(responseId, created, model, nil) + if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) } + + for scanner.Scan() { + line := scanner.Text() + line = strings.TrimSpace(line) + if line == "" { continue } + var chunk ollamaChatStreamChunk + if err := json.Unmarshal([]byte(line), &chunk); err != nil { + logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line) + return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + if chunk.Model != "" { model = chunk.Model } + created = toUnix(chunk.CreatedAt) + + if !chunk.Done { + // delta content + var content string + if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response } + delta := dto.ChatCompletionsStreamResponse{ + Id: responseId, + Object: "chat.completion.chunk", + Created: created, + Model: model, + Choices: []dto.ChatCompletionsStreamResponseChoice{ { + Index: 0, + Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant" }, + } }, + } + if content != "" { delta.Choices[0].Delta.SetContentString(content) } + if chunk.Message != nil && len(chunk.Message.Thinking) > 0 { + raw := strings.TrimSpace(string(chunk.Message.Thinking)) + if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) } + } + // tool calls + if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 { + delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls)) + for _, tc := range chunk.Message.ToolCalls { + // arguments -> string + argBytes, _ := json.Marshal(tc.Function.Arguments) + toolId := fmt.Sprintf("call_%d", toolCallIndex) + tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}} + tr.SetIndex(toolCallIndex) + toolCallIndex++ + delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr) + } + } + if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) } + continue + } + // done frame + // finalize once and break loop + usage.PromptTokens = chunk.PromptEvalCount + usage.CompletionTokens = chunk.EvalCount + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + finishReason := chunk.DoneReason + if finishReason == "" { finishReason = "stop" } + // emit stop delta + if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil { + if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) } + } + // emit usage frame + if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil { + if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) } + } + // send [DONE] + helper.Done(c) + break + } + if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) } + return usage, nil +} + +// non-stream handler for chat/generate +func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + body, err := io.ReadAll(resp.Body) + if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } + service.CloseResponseBodyGracefully(resp) + raw := string(body) + if common.DebugEnabled { println("ollama non-stream raw resp:", raw) } + + lines := strings.Split(raw, "\n") + var ( + aggContent strings.Builder + reasoningBuilder strings.Builder + lastChunk ollamaChatStreamChunk + parsedAny bool + ) + for _, ln := range lines { + ln = strings.TrimSpace(ln) + if ln == "" { continue } + var ck ollamaChatStreamChunk + if err := json.Unmarshal([]byte(ln), &ck); err != nil { + if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } + continue + } + parsedAny = true + lastChunk = ck + if ck.Message != nil && len(ck.Message.Thinking) > 0 { + raw := strings.TrimSpace(string(ck.Message.Thinking)) + if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } + } + if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) } + } + + if !parsedAny { + var single ollamaChatStreamChunk + if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } + lastChunk = single + if single.Message != nil { + if len(single.Message.Thinking) > 0 { raw := strings.TrimSpace(string(single.Message.Thinking)); if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } } + aggContent.WriteString(single.Message.Content) + } else { aggContent.WriteString(single.Response) } + } + + model := lastChunk.Model + if model == "" { model = info.UpstreamModelName } + created := toUnix(lastChunk.CreatedAt) + usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount} + content := aggContent.String() + finishReason := lastChunk.DoneReason + if finishReason == "" { finishReason = "stop" } + + msg := dto.Message{Role: "assistant", Content: contentPtr(content)} + if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = rc } + full := dto.OpenAITextResponse{ + Id: common.GetUUID(), + Model: model, + Object: "chat.completion", + Created: created, + Choices: []dto.OpenAITextResponseChoice{ { + Index: 0, + Message: msg, + FinishReason: finishReason, + } }, + Usage: *usage, + } + out, _ := common.Marshal(full) + service.IOCopyBytesGracefully(c, resp, out) + return usage, nil +} + +func contentPtr(s string) *string { if s=="" { return nil }; return &s } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 4b13a7df1..a88b68502 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -12,6 +12,7 @@ import ( "one-api/constant" "one-api/dto" "one-api/logger" + "one-api/relay/channel/openrouter" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -185,10 +186,27 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo if common.DebugEnabled { println("upstream response body:", string(responseBody)) } + // Unmarshal to simpleResponse + if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() { + // 尝试解析为 openrouter enterprise + var enterpriseResponse openrouter.OpenRouterEnterpriseResponse + err = common.Unmarshal(responseBody, &enterpriseResponse) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + if enterpriseResponse.Success { + responseBody = enterpriseResponse.Data + } else { + logger.LogError(c, fmt.Sprintf("openrouter enterprise response success=false, data: %s", enterpriseResponse.Data)) + return nil, types.NewOpenAIError(fmt.Errorf("openrouter response success=false"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + } + err = common.Unmarshal(responseBody, &simpleResponse) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } + if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" { return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) } diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index e188889e4..85938a771 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -33,6 +33,12 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) } + if responsesResponse.HasImageGenerationCall() { + c.Set("image_generation_call", true) + c.Set("image_generation_call_quality", responsesResponse.GetQuality()) + c.Set("image_generation_call_size", responsesResponse.GetSize()) + } + // 写入新的 response body service.IOCopyBytesGracefully(c, resp, responseBody) @@ -80,18 +86,25 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp sendResponsesStreamData(c, streamResponse, data) switch streamResponse.Type { case "response.completed": - if streamResponse.Response != nil && streamResponse.Response.Usage != nil { - if streamResponse.Response.Usage.InputTokens != 0 { - usage.PromptTokens = streamResponse.Response.Usage.InputTokens + if streamResponse.Response != nil { + if streamResponse.Response.Usage != nil { + if streamResponse.Response.Usage.InputTokens != 0 { + usage.PromptTokens = streamResponse.Response.Usage.InputTokens + } + if streamResponse.Response.Usage.OutputTokens != 0 { + usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens + } + if streamResponse.Response.Usage.TotalTokens != 0 { + usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + } + if streamResponse.Response.Usage.InputTokensDetails != nil { + usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens + } } - if streamResponse.Response.Usage.OutputTokens != 0 { - usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens - } - if streamResponse.Response.Usage.TotalTokens != 0 { - usage.TotalTokens = streamResponse.Response.Usage.TotalTokens - } - if streamResponse.Response.Usage.InputTokensDetails != nil { - usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens + if streamResponse.Response.HasImageGenerationCall() { + c.Set("image_generation_call", true) + c.Set("image_generation_call_quality", streamResponse.Response.GetQuality()) + c.Set("image_generation_call_size", streamResponse.Response.GetSize()) } } case "response.output_text.delta": diff --git a/relay/channel/openrouter/dto.go b/relay/channel/openrouter/dto.go index 607f495bf..a32499852 100644 --- a/relay/channel/openrouter/dto.go +++ b/relay/channel/openrouter/dto.go @@ -1,5 +1,7 @@ package openrouter +import "encoding/json" + type RequestReasoning struct { // One of the following (not both): Effort string `json:"effort,omitempty"` // Can be "high", "medium", or "low" (OpenAI-style) @@ -7,3 +9,8 @@ type RequestReasoning struct { // Optional: Default is false. All models support this. Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response } + +type OpenRouterEnterpriseResponse struct { + Data json.RawMessage `json:"data"` + Success bool `json:"success"` +} diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 2bc45c547..a2545a273 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -36,6 +36,7 @@ type requestPayload struct { Prompt string `json:"prompt,omitempty"` Seed int64 `json:"seed"` AspectRatio string `json:"aspect_ratio"` + Frames int `json:"frames,omitempty"` } type responsePayload struct { @@ -93,6 +94,9 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { + if isNewAPIRelay(info.ApiKey) { + return fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil + } return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil } @@ -100,7 +104,12 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") - return a.signRequest(req, a.accessKey, a.secretKey) + if isNewAPIRelay(info.ApiKey) { + req.Header.Set("Authorization", "Bearer "+info.ApiKey) + } else { + return a.signRequest(req, a.accessKey, a.secretKey) + } + return nil } // BuildRequestBody converts request into Jimeng specific format. @@ -160,6 +169,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http } uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl) + if isNewAPIRelay(key) { + uri = fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncGetResult&Version=2022-08-31", a.baseURL) + } payload := map[string]string{ "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774 "task_id": taskID, @@ -177,17 +189,20 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/json") - keyParts := strings.Split(key, "|") - if len(keyParts) != 2 { - return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'") - } - accessKey := strings.TrimSpace(keyParts[0]) - secretKey := strings.TrimSpace(keyParts[1]) + if isNewAPIRelay(key) { + req.Header.Set("Authorization", "Bearer "+key) + } else { + keyParts := strings.Split(key, "|") + if len(keyParts) != 2 { + return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'") + } + accessKey := strings.TrimSpace(keyParts[0]) + secretKey := strings.TrimSpace(keyParts[1]) - if err := a.signRequest(req, accessKey, secretKey); err != nil { - return nil, errors.Wrap(err, "sign request failed") + if err := a.signRequest(req, accessKey, secretKey); err != nil { + return nil, errors.Wrap(err, "sign request failed") + } } - return service.GetHttpClient().Do(req) } @@ -311,10 +326,15 @@ func hmacSHA256(key []byte, data []byte) []byte { func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { r := requestPayload{ - ReqKey: "jimeng_vgfm_i2v_l20", - Prompt: req.Prompt, - AspectRatio: "16:9", // Default aspect ratio - Seed: -1, // Default to random + ReqKey: req.Model, + Prompt: req.Prompt, + } + + switch req.Duration { + case 10: + r.Frames = 241 // 24*10+1 = 241 + default: + r.Frames = 121 // 24*5+1 = 121 } // Handle one-of image_urls or binary_data_base64 @@ -334,6 +354,22 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* if err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } + + // 即梦视频3.0 ReqKey转换 + // https://www.volcengine.com/docs/85621/1792707 + if strings.Contains(r.ReqKey, "jimeng_v30") { + if len(r.ImageUrls) > 1 { + // 多张图片:首尾帧生成 + r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1) + } else if len(r.ImageUrls) == 1 { + // 单张图片:图生视频 + r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1) + } else { + // 无图片:文生视频 + r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_t2v_v30", 1) + } + } + return &r, nil } @@ -362,3 +398,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e taskResult.Url = resTask.Data.VideoUrl return &taskResult, nil } + +func isNewAPIRelay(apiKey string) bool { + return strings.HasPrefix(apiKey, "sk-") +} diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 13f2af972..fec3396ae 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -117,6 +117,11 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video") + + if isNewAPIRelay(info.ApiKey) { + return fmt.Sprintf("%s/kling%s", a.baseURL, path), nil + } + return fmt.Sprintf("%s%s", a.baseURL, path), nil } @@ -199,6 +204,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http } path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video") url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID) + if isNewAPIRelay(key) { + url = fmt.Sprintf("%s/kling%s/%s", baseUrl, path, taskID) + } req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { @@ -304,8 +312,13 @@ func (a *TaskAdaptor) createJWTToken() (string, error) { //} func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) { - + if isNewAPIRelay(apiKey) { + return apiKey, nil // new api relay + } keyParts := strings.Split(apiKey, "|") + if len(keyParts) != 2 { + return "", errors.New("invalid api_key, required format is accessKey|secretKey") + } accessKey := strings.TrimSpace(keyParts[0]) if len(keyParts) == 1 { return accessKey, nil @@ -352,3 +365,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } return taskInfo, nil } + +func isNewAPIRelay(apiKey string) bool { + return strings.HasPrefix(apiKey, "sk-") +} diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index a1140d1e7..358aef583 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -80,8 +80,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { - // Use the unified validation method for TaskSubmitReq with image-based action determination - return relaycommon.ValidateTaskRequestWithImageBinding(c, info) + return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) } func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { @@ -112,6 +111,10 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro switch info.Action { case constant.TaskActionGenerate: path = "/img2video" + case constant.TaskActionFirstTailGenerate: + path = "/start-end2video" + case constant.TaskActionReferenceGenerate: + path = "/reference2video" default: path = "/text2video" } @@ -187,14 +190,9 @@ func (a *TaskAdaptor) GetChannelName() string { // ============================ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { - var images []string - if req.Image != "" { - images = []string{req.Image} - } - r := requestPayload{ Model: defaultString(req.Model, "viduq1"), - Images: images, + Images: req.Images, Prompt: req.Prompt, Duration: defaultInt(req.Duration, 5), Resolution: defaultString(req.Size, "1080p"), diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 0af019da4..21d6e1705 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -9,6 +9,7 @@ import ( "mime/multipart" "net/http" "net/textproto" + channelconstant "one-api/constant" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/openai" @@ -41,6 +42,8 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { switch info.RelayMode { + case constant.RelayModeImagesGenerations: + return request, nil case constant.RelayModeImagesEdits: var requestBody bytes.Buffer @@ -186,20 +189,26 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + // 支持自定义域名,如果未设置则使用默认域名 + baseUrl := info.ChannelBaseUrl + if baseUrl == "" { + baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] + } + switch info.RelayMode { case constant.RelayModeChatCompletions: if strings.HasPrefix(info.UpstreamModelName, "bot") { - return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil + return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil } - return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil + return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil + return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil case constant.RelayModeImagesGenerations: - return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil + return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil case constant.RelayModeImagesEdits: - return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil + return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil case constant.RelayModeRerank: - return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil + return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil default: } return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) diff --git a/relay/channel/volcengine/constants.go b/relay/channel/volcengine/constants.go index 30cc902e7..87a12b27c 100644 --- a/relay/channel/volcengine/constants.go +++ b/relay/channel/volcengine/constants.go @@ -8,6 +8,12 @@ var ModelList = []string{ "Doubao-lite-32k", "Doubao-lite-4k", "Doubao-embedding", + "doubao-seedream-4-0-250828", + "seedream-4-0-250828", + "doubao-seedance-1-0-pro-250528", + "seedance-1-0-pro-250528", + "doubao-seed-1-6-thinking-250715", + "seed-1-6-thinking-250715", } var ChannelName = "volcengine" diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 9d5c190fe..9503d5d39 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -207,10 +207,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap return nil, nil, err } - defer func() { - conn.Close() - }() - data := requestOpenAI2Xunfei(textRequest, appId, domain) err = conn.WriteJSON(data) if err != nil { @@ -220,6 +216,9 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap dataChan := make(chan XunfeiChatResponse) stopChan := make(chan bool) go func() { + defer func() { + conn.Close() + }() for { _, msg, err := conn.ReadMessage() if err != nil { diff --git a/relay/claude_handler.go b/relay/claude_handler.go index dbdc6ee1c..59d12abe4 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -69,6 +70,31 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ info.UpstreamModelName = request.Model } + if info.ChannelSetting.SystemPrompt != "" { + if request.System == nil { + request.SetStringSystem(info.ChannelSetting.SystemPrompt) + } else if info.ChannelSetting.SystemPromptOverride { + common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) + if request.IsStringSystem() { + existing := strings.TrimSpace(request.GetStringSystem()) + if existing == "" { + request.SetStringSystem(info.ChannelSetting.SystemPrompt) + } else { + request.SetStringSystem(info.ChannelSetting.SystemPrompt + "\n" + existing) + } + } else { + systemContents := request.ParseSystem() + newSystem := dto.ClaudeMediaMessage{Type: dto.ContentTypeText} + newSystem.SetText(info.ChannelSetting.SystemPrompt) + if len(systemContents) == 0 { + request.System = []dto.ClaudeMediaMessage{newSystem} + } else { + request.System = append([]dto.ClaudeMediaMessage{newSystem}, systemContents...) + } + } + } + } + var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index cf6d08dda..3a721b479 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -79,34 +79,18 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d req.Images = []string{req.Image} } + if req.HasImage() { + action = constant.TaskActionGenerate + if info.ChannelType == constant.ChannelTypeVidu { + // vidu 增加 首尾帧生视频和参考图生视频 + if len(req.Images) == 2 { + action = constant.TaskActionFirstTailGenerate + } else if len(req.Images) > 2 { + action = constant.TaskActionReferenceGenerate + } + } + } + storeTaskRequest(c, info, action, req) return nil } - -func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError { - hasPrompt, ok := requestObj.(HasPrompt) - if !ok { - return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true) - } - - if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil { - return taskErr - } - - action := constant.TaskActionTextGenerate - if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() { - action = constant.TaskActionGenerate - } - - storeTaskRequest(c, info, action, requestObj) - return nil -} - -func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError { - var req TaskSubmitReq - if err := c.ShouldBindJSON(&req); err != nil { - return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false) - } - - return ValidateTaskRequestWithImage(c, info, req) -} diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index 01ab1fff4..38b820f72 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -90,41 +90,43 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types if info.ChannelSetting.SystemPrompt != "" { // 如果有系统提示,则将其添加到请求中 - request := convertedRequest.(*dto.GeneralOpenAIRequest) - containSystemPrompt := false - for _, message := range request.Messages { - if message.Role == request.GetSystemRoleName() { - containSystemPrompt = true - break - } - } - if !containSystemPrompt { - // 如果没有系统提示,则添加系统提示 - systemMessage := dto.Message{ - Role: request.GetSystemRoleName(), - Content: info.ChannelSetting.SystemPrompt, - } - request.Messages = append([]dto.Message{systemMessage}, request.Messages...) - } else if info.ChannelSetting.SystemPromptOverride { - common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) - // 如果有系统提示,且允许覆盖,则拼接到前面 - for i, message := range request.Messages { + request, ok := convertedRequest.(*dto.GeneralOpenAIRequest) + if ok { + containSystemPrompt := false + for _, message := range request.Messages { if message.Role == request.GetSystemRoleName() { - if message.IsStringContent() { - request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent()) - } else { - contents := message.ParseContent() - contents = append([]dto.MediaContent{ - { - Type: dto.ContentTypeText, - Text: info.ChannelSetting.SystemPrompt, - }, - }, contents...) - request.Messages[i].Content = contents - } + containSystemPrompt = true break } } + if !containSystemPrompt { + // 如果没有系统提示,则添加系统提示 + systemMessage := dto.Message{ + Role: request.GetSystemRoleName(), + Content: info.ChannelSetting.SystemPrompt, + } + request.Messages = append([]dto.Message{systemMessage}, request.Messages...) + } else if info.ChannelSetting.SystemPromptOverride { + common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) + // 如果有系统提示,且允许覆盖,则拼接到前面 + for i, message := range request.Messages { + if message.Role == request.GetSystemRoleName() { + if message.IsStringContent() { + request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent()) + } else { + contents := message.ParseContent() + contents = append([]dto.MediaContent{ + { + Type: dto.ContentTypeText, + Text: info.ChannelSetting.SystemPrompt, + }, + }, contents...) + request.Messages[i].Content = contents + } + break + } + } + } } } @@ -276,6 +278,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage fileSearchTool.CallCount, dFileSearchQuota.String()) } } + var dImageGenerationCallQuota decimal.Decimal + var imageGenerationCallPrice float64 + if ctx.GetBool("image_generation_call") { + imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size")) + dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit) + extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String()) + } var quotaCalculateDecimal decimal.Decimal @@ -331,6 +340,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota) // 添加 audio input 独立计费 quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota) + // 添加 image generation call 计费 + quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota) quota := int(quotaCalculateDecimal.Round(0).IntPart()) totalTokens := promptTokens + completionTokens @@ -429,6 +440,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage other["audio_input_token_count"] = audioTokens other["audio_input_price"] = audioInputPrice } + if !dImageGenerationCallQuota.IsZero() { + other["image_generation_call"] = true + other["image_generation_call_price"] = imageGenerationCallPrice + } model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: promptTokens, diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 0252d6578..1410da606 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/logger" "one-api/relay/channel/gemini" @@ -94,6 +95,32 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ adaptor.Init(info) + if info.ChannelSetting.SystemPrompt != "" { + if request.SystemInstructions == nil { + request.SystemInstructions = &dto.GeminiChatContent{ + Parts: []dto.GeminiPart{ + {Text: info.ChannelSetting.SystemPrompt}, + }, + } + } else if len(request.SystemInstructions.Parts) == 0 { + request.SystemInstructions.Parts = []dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}} + } else if info.ChannelSetting.SystemPromptOverride { + common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) + merged := false + for i := range request.SystemInstructions.Parts { + if request.SystemInstructions.Parts[i].Text == "" { + continue + } + request.SystemInstructions.Parts[i].Text = info.ChannelSetting.SystemPrompt + "\n" + request.SystemInstructions.Parts[i].Text + merged = true + break + } + if !merged { + request.SystemInstructions.Parts = append([]dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}, request.SystemInstructions.Parts...) + } + } + } + // Clean up empty system instruction if request.SystemInstructions != nil { hasContent := false diff --git a/relay/helper/price.go b/relay/helper/price.go index fdc5b66d8..c23c068b3 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -52,6 +52,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens var cacheRatio float64 var imageRatio float64 var cacheCreationRatio float64 + var audioRatio float64 + var audioCompletionRatio float64 if !usePrice { preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota) if meta.MaxTokens != 0 { @@ -73,6 +75,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName) cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName) imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName) + audioRatio = ratio_setting.GetAudioRatio(info.OriginModelName) + audioCompletionRatio = ratio_setting.GetAudioCompletionRatio(info.OriginModelName) ratio := modelRatio * groupRatioInfo.GroupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { @@ -90,6 +94,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens UsePrice: usePrice, CacheRatio: cacheRatio, ImageRatio: imageRatio, + AudioRatio: audioRatio, + AudioCompletionRatio: audioCompletionRatio, CacheCreationRatio: cacheCreationRatio, ShouldPreConsumedQuota: preConsumedQuota, } diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go index 4d1c1f9bb..f4a290ec6 100644 --- a/relay/helper/valid_request.go +++ b/relay/helper/valid_request.go @@ -21,7 +21,11 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt case types.RelayFormatOpenAI: request, err = GetAndValidateTextRequest(c, relayMode) case types.RelayFormatGemini: - request, err = GetAndValidateGeminiRequest(c) + if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") { + request, err = GetAndValidateGeminiEmbeddingRequest(c) + } else { + request, err = GetAndValidateGeminiRequest(c) + } case types.RelayFormatClaude: request, err = GetAndValidateClaudeRequest(c) case types.RelayFormatOpenAIResponses: @@ -288,7 +292,6 @@ func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenA } func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) { - request := &dto.GeminiChatRequest{} err := common.UnmarshalBodyReusable(c, request) if err != nil { @@ -304,3 +307,12 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) return request, nil } + +func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingRequest, error) { + request := &dto.GeminiEmbeddingRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + return request, nil +} diff --git a/service/cf_worker.go b/service/download.go similarity index 59% rename from service/cf_worker.go rename to service/download.go index d60b6fad5..036c43af8 100644 --- a/service/cf_worker.go +++ b/service/download.go @@ -28,6 +28,12 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { return nil, fmt.Errorf("only support https url") } + // SSRF防护:验证请求URL + fetchSetting := system_setting.GetFetchSetting() + if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { + return nil, fmt.Errorf("request reject: %v", err) + } + workerUrl := system_setting.WorkerUrl if !strings.HasSuffix(workerUrl, "/") { workerUrl += "/" @@ -51,7 +57,13 @@ func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, } return DoWorkerRequest(req) } else { - common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", "))) + // SSRF防护:验证请求URL(非Worker模式) + fetchSetting := system_setting.GetFetchSetting() + if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { + return nil, fmt.Errorf("request reject: %v", err) + } + + common.SysLog(fmt.Sprintf("downloading from origin: %s, reason: %s", common.MaskSensitiveInfo(originUrl), strings.Join(reason, ", "))) return http.Get(originUrl) } } diff --git a/service/http_client.go b/service/http_client.go index b191ddd78..d8fcfae01 100644 --- a/service/http_client.go +++ b/service/http_client.go @@ -7,12 +7,17 @@ import ( "net/http" "net/url" "one-api/common" + "sync" "time" "golang.org/x/net/proxy" ) -var httpClient *http.Client +var ( + httpClient *http.Client + proxyClientLock sync.Mutex + proxyClients = make(map[string]*http.Client) +) func InitHttpClient() { if common.RelayTimeout == 0 { @@ -28,12 +33,31 @@ func GetHttpClient() *http.Client { return httpClient } +// ResetProxyClientCache 清空代理客户端缓存,确保下次使用时重新初始化 +func ResetProxyClientCache() { + proxyClientLock.Lock() + defer proxyClientLock.Unlock() + for _, client := range proxyClients { + if transport, ok := client.Transport.(*http.Transport); ok && transport != nil { + transport.CloseIdleConnections() + } + } + proxyClients = make(map[string]*http.Client) +} + // NewProxyHttpClient 创建支持代理的 HTTP 客户端 func NewProxyHttpClient(proxyURL string) (*http.Client, error) { if proxyURL == "" { return http.DefaultClient, nil } + proxyClientLock.Lock() + if client, ok := proxyClients[proxyURL]; ok { + proxyClientLock.Unlock() + return client, nil + } + proxyClientLock.Unlock() + parsedURL, err := url.Parse(proxyURL) if err != nil { return nil, err @@ -41,11 +65,16 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) { switch parsedURL.Scheme { case "http", "https": - return &http.Client{ + client := &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyURL(parsedURL), }, - }, nil + } + client.Timeout = time.Duration(common.RelayTimeout) * time.Second + proxyClientLock.Lock() + proxyClients[proxyURL] = client + proxyClientLock.Unlock() + return client, nil case "socks5", "socks5h": // 获取认证信息 @@ -67,13 +96,18 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) { return nil, err } - return &http.Client{ + client := &http.Client{ Transport: &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.Dial(network, addr) }, }, - }, nil + } + client.Timeout = time.Duration(common.RelayTimeout) * time.Second + proxyClientLock.Lock() + proxyClients[proxyURL] = client + proxyClientLock.Unlock() + return client, nil default: return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme) diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go index 3cfabc1a4..0cf53513b 100644 --- a/service/pre_consume_quota.go +++ b/service/pre_consume_quota.go @@ -19,7 +19,7 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) { gopool.Go(func() { relayInfoCopy := *relayInfo - err := PostConsumeQuota(&relayInfoCopy, -relayInfo.FinalPreConsumedQuota, 0, false) + err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false) if err != nil { common.SysLog("error return pre-consumed quota: " + err.Error()) } diff --git a/service/user_notify.go b/service/user_notify.go index 972ca655c..fba12d9db 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -113,6 +113,12 @@ func sendBarkNotify(barkURL string, data dto.Notify) error { return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode) } } else { + // SSRF防护:验证Bark URL(非Worker模式) + fetchSetting := system_setting.GetFetchSetting() + if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { + return fmt.Errorf("request reject: %v", err) + } + // 直接发送请求 req, err = http.NewRequest(http.MethodGet, finalURL, nil) if err != nil { diff --git a/service/webhook.go b/service/webhook.go index 9c6ec8102..c678b8634 100644 --- a/service/webhook.go +++ b/service/webhook.go @@ -8,6 +8,7 @@ import ( "encoding/json" "fmt" "net/http" + "one-api/common" "one-api/dto" "one-api/setting/system_setting" "time" @@ -86,6 +87,12 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode) } } else { + // SSRF防护:验证Webhook URL(非Worker模式) + fetchSetting := system_setting.GetFetchSetting() + if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { + return fmt.Errorf("request reject: %v", err) + } + req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes)) if err != nil { return fmt.Errorf("failed to create webhook request: %v", err) diff --git a/setting/operation_setting/tools.go b/setting/operation_setting/tools.go index 549a1862e..5b89d6fec 100644 --- a/setting/operation_setting/tools.go +++ b/setting/operation_setting/tools.go @@ -10,6 +10,18 @@ const ( FileSearchPrice = 2.5 ) +const ( + GPTImage1Low1024x1024 = 0.011 + GPTImage1Low1024x1536 = 0.016 + GPTImage1Low1536x1024 = 0.016 + GPTImage1Medium1024x1024 = 0.042 + GPTImage1Medium1024x1536 = 0.063 + GPTImage1Medium1536x1024 = 0.063 + GPTImage1High1024x1024 = 0.167 + GPTImage1High1024x1536 = 0.25 + GPTImage1High1536x1024 = 0.25 +) + const ( // Gemini Audio Input Price Gemini25FlashPreviewInputAudioPrice = 1.00 @@ -65,3 +77,31 @@ func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 { } return 0 } + +func GetGPTImage1PriceOnceCall(quality string, size string) float64 { + prices := map[string]map[string]float64{ + "low": { + "1024x1024": GPTImage1Low1024x1024, + "1024x1536": GPTImage1Low1024x1536, + "1536x1024": GPTImage1Low1536x1024, + }, + "medium": { + "1024x1024": GPTImage1Medium1024x1024, + "1024x1536": GPTImage1Medium1024x1536, + "1536x1024": GPTImage1Medium1536x1024, + }, + "high": { + "1024x1024": GPTImage1High1024x1024, + "1024x1536": GPTImage1High1024x1536, + "1536x1024": GPTImage1High1536x1024, + }, + } + + if qualityMap, exists := prices[quality]; exists { + if price, exists := qualityMap[size]; exists { + return price + } + } + + return GPTImage1High1024x1024 +} diff --git a/setting/payment_stripe.go b/setting/payment_stripe.go index 80d877dfa..d97120c85 100644 --- a/setting/payment_stripe.go +++ b/setting/payment_stripe.go @@ -5,3 +5,4 @@ var StripeWebhookSecret = "" var StripePriceId = "" var StripeUnitPrice = 8.0 var StripeMinTopUp = 1 +var StripePromotionCodesEnabled = false diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index 097e048df..738ab64d6 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -178,6 +178,7 @@ var defaultModelRatio = map[string]float64{ "gemini-2.5-flash-lite-preview-thinking-*": 0.05, "gemini-2.5-flash-lite-preview-06-17": 0.05, "gemini-2.5-flash": 0.15, + "gemini-embedding-001": 0.075, "text-embedding-004": 0.001, "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens @@ -289,6 +290,18 @@ var defaultModelPrice = map[string]float64{ "mj_upload": 0.05, } +var defaultAudioRatio = map[string]float64{ + "gpt-4o-audio-preview": 16, + "gpt-4o-mini-audio-preview": 66.67, + "gpt-4o-realtime-preview": 8, + "gpt-4o-mini-realtime-preview": 16.67, +} + +var defaultAudioCompletionRatio = map[string]float64{ + "gpt-4o-realtime": 2, + "gpt-4o-mini-realtime": 2, +} + var ( modelPriceMap map[string]float64 = nil modelPriceMapMutex = sync.RWMutex{} @@ -337,6 +350,15 @@ func InitRatioSettings() { imageRatioMap = defaultImageRatio imageRatioMapMutex.Unlock() + // initialize audioRatioMap + audioRatioMapMutex.Lock() + audioRatioMap = defaultAudioRatio + audioRatioMapMutex.Unlock() + + // initialize audioCompletionRatioMap + audioCompletionRatioMapMutex.Lock() + audioCompletionRatioMap = defaultAudioCompletionRatio + audioCompletionRatioMapMutex.Unlock() } func GetModelPriceMap() map[string]float64 { @@ -428,6 +450,18 @@ func GetDefaultModelRatioMap() map[string]float64 { return defaultModelRatio } +func GetDefaultImageRatioMap() map[string]float64 { + return defaultImageRatio +} + +func GetDefaultAudioRatioMap() map[string]float64 { + return defaultAudioRatio +} + +func GetDefaultAudioCompletionRatioMap() map[string]float64 { + return defaultAudioCompletionRatio +} + func GetCompletionRatioMap() map[string]float64 { CompletionRatioMutex.RLock() defer CompletionRatioMutex.RUnlock() @@ -595,32 +629,22 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) { } func GetAudioRatio(name string) float64 { - if strings.Contains(name, "-realtime") { - if strings.HasSuffix(name, "gpt-4o-realtime-preview") { - return 8 - } else if strings.Contains(name, "gpt-4o-mini-realtime-preview") { - return 10 / 0.6 - } else { - return 20 - } - } - if strings.Contains(name, "-audio") { - if strings.HasPrefix(name, "gpt-4o-audio-preview") { - return 40 / 2.5 - } else if strings.HasPrefix(name, "gpt-4o-mini-audio-preview") { - return 10 / 0.15 - } else { - return 40 - } + audioRatioMapMutex.RLock() + defer audioRatioMapMutex.RUnlock() + name = FormatMatchingModelName(name) + if ratio, ok := audioRatioMap[name]; ok { + return ratio } return 20 } func GetAudioCompletionRatio(name string) float64 { - if strings.HasPrefix(name, "gpt-4o-realtime") { - return 2 - } else if strings.HasPrefix(name, "gpt-4o-mini-realtime") { - return 2 + audioCompletionRatioMapMutex.RLock() + defer audioCompletionRatioMapMutex.RUnlock() + name = FormatMatchingModelName(name) + if ratio, ok := audioCompletionRatioMap[name]; ok { + + return ratio } return 2 } @@ -641,6 +665,14 @@ var defaultImageRatio = map[string]float64{ } var imageRatioMap map[string]float64 var imageRatioMapMutex sync.RWMutex +var ( + audioRatioMap map[string]float64 = nil + audioRatioMapMutex = sync.RWMutex{} +) +var ( + audioCompletionRatioMap map[string]float64 = nil + audioCompletionRatioMapMutex = sync.RWMutex{} +) func ImageRatio2JSONString() string { imageRatioMapMutex.RLock() @@ -669,6 +701,71 @@ func GetImageRatio(name string) (float64, bool) { return ratio, true } +func AudioRatio2JSONString() string { + audioRatioMapMutex.RLock() + defer audioRatioMapMutex.RUnlock() + jsonBytes, err := common.Marshal(audioRatioMap) + if err != nil { + common.SysError("error marshalling audio ratio: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateAudioRatioByJSONString(jsonStr string) error { + + tmp := make(map[string]float64) + if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil { + return err + } + audioRatioMapMutex.Lock() + audioRatioMap = tmp + audioRatioMapMutex.Unlock() + InvalidateExposedDataCache() + return nil +} + +func GetAudioRatioCopy() map[string]float64 { + audioRatioMapMutex.RLock() + defer audioRatioMapMutex.RUnlock() + copyMap := make(map[string]float64, len(audioRatioMap)) + for k, v := range audioRatioMap { + copyMap[k] = v + } + return copyMap +} + +func AudioCompletionRatio2JSONString() string { + audioCompletionRatioMapMutex.RLock() + defer audioCompletionRatioMapMutex.RUnlock() + jsonBytes, err := common.Marshal(audioCompletionRatioMap) + if err != nil { + common.SysError("error marshalling audio completion ratio: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateAudioCompletionRatioByJSONString(jsonStr string) error { + tmp := make(map[string]float64) + if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil { + return err + } + audioCompletionRatioMapMutex.Lock() + audioCompletionRatioMap = tmp + audioCompletionRatioMapMutex.Unlock() + InvalidateExposedDataCache() + return nil +} + +func GetAudioCompletionRatioCopy() map[string]float64 { + audioCompletionRatioMapMutex.RLock() + defer audioCompletionRatioMapMutex.RUnlock() + copyMap := make(map[string]float64, len(audioCompletionRatioMap)) + for k, v := range audioCompletionRatioMap { + copyMap[k] = v + } + return copyMap +} + func GetModelRatioCopy() map[string]float64 { modelRatioMapMutex.RLock() defer modelRatioMapMutex.RUnlock() diff --git a/setting/system_setting/fetch_setting.go b/setting/system_setting/fetch_setting.go new file mode 100644 index 000000000..c41b930af --- /dev/null +++ b/setting/system_setting/fetch_setting.go @@ -0,0 +1,34 @@ +package system_setting + +import "one-api/setting/config" + +type FetchSetting struct { + EnableSSRFProtection bool `json:"enable_ssrf_protection"` // 是否启用SSRF防护 + AllowPrivateIp bool `json:"allow_private_ip"` + DomainFilterMode bool `json:"domain_filter_mode"` // 域名过滤模式,true: 白名单模式,false: 黑名单模式 + IpFilterMode bool `json:"ip_filter_mode"` // IP过滤模式,true: 白名单模式,false: 黑名单模式 + DomainList []string `json:"domain_list"` // domain format, e.g. example.com, *.example.com + IpList []string `json:"ip_list"` // CIDR format + AllowedPorts []string `json:"allowed_ports"` // port range format, e.g. 80, 443, 8000-9000 + ApplyIPFilterForDomain bool `json:"apply_ip_filter_for_domain"` // 对域名启用IP过滤(实验性) +} + +var defaultFetchSetting = FetchSetting{ + EnableSSRFProtection: true, // 默认开启SSRF防护 + AllowPrivateIp: false, + DomainFilterMode: false, + IpFilterMode: false, + DomainList: []string{}, + IpList: []string{}, + AllowedPorts: []string{"80", "443", "8080", "8443"}, + ApplyIPFilterForDomain: false, +} + +func init() { + // 注册到全局配置管理器 + config.GlobalConfig.Register("fetch_setting", &defaultFetchSetting) +} + +func GetFetchSetting() *FetchSetting { + return &defaultFetchSetting +} diff --git a/types/error.go b/types/error.go index 883ee0641..a42e84385 100644 --- a/types/error.go +++ b/types/error.go @@ -122,6 +122,9 @@ func (e *NewAPIError) MaskSensitiveError() string { return string(e.errorCode) } errStr := e.Err.Error() + if e.errorCode == ErrorCodeCountTokenFailed { + return errStr + } return common.MaskSensitiveInfo(errStr) } @@ -153,8 +156,9 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError { Code: e.errorCode, } } - - result.Message = common.MaskSensitiveInfo(result.Message) + if e.errorCode != ErrorCodeCountTokenFailed { + result.Message = common.MaskSensitiveInfo(result.Message) + } return result } @@ -178,7 +182,9 @@ func (e *NewAPIError) ToClaudeError() ClaudeError { Type: string(e.errorType), } } - result.Message = common.MaskSensitiveInfo(result.Message) + if e.errorCode != ErrorCodeCountTokenFailed { + result.Message = common.MaskSensitiveInfo(result.Message) + } return result } diff --git a/types/price_data.go b/types/price_data.go index f6a92d7e3..ec7fcdfe9 100644 --- a/types/price_data.go +++ b/types/price_data.go @@ -15,6 +15,8 @@ type PriceData struct { CacheRatio float64 CacheCreationRatio float64 ImageRatio float64 + AudioRatio float64 + AudioCompletionRatio float64 UsePrice bool ShouldPreConsumedQuota int GroupRatioInfo GroupRatioInfo @@ -27,5 +29,5 @@ type PerCallPriceData struct { } func (p PriceData) ToSetting() string { - return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio) } diff --git a/web/index.html b/web/index.html index 09d87ae1a..df6b0e398 100644 --- a/web/index.html +++ b/web/index.html @@ -10,6 +10,7 @@ content="OpenAI 接口聚合管理,支持多种渠道包括 Azure,可用于二次分发管理 key,仅单可执行文件,已打包好 Docker 镜像,一键部署,开箱即用" /> New API + diff --git a/web/jsconfig.json b/web/jsconfig.json new file mode 100644 index 000000000..ced4d0543 --- /dev/null +++ b/web/jsconfig.json @@ -0,0 +1,9 @@ +{ + "compilerOptions": { + "baseUrl": "./", + "paths": { + "@/*": ["src/*"] + } + }, + "include": ["src/**/*"] +} \ No newline at end of file diff --git a/web/src/components/common/markdown/MarkdownRenderer.jsx b/web/src/components/common/markdown/MarkdownRenderer.jsx index f1283a640..05419f8cc 100644 --- a/web/src/components/common/markdown/MarkdownRenderer.jsx +++ b/web/src/components/common/markdown/MarkdownRenderer.jsx @@ -181,8 +181,8 @@ export function PreCode(props) { e.preventDefault(); e.stopPropagation(); if (ref.current) { - const code = - ref.current.querySelector('code')?.innerText ?? ''; + const codeElement = ref.current.querySelector('code'); + const code = codeElement?.textContent ?? ''; copy(code).then((success) => { if (success) { Toast.success(t('代码已复制到剪贴板')); diff --git a/web/src/components/layout/headerbar/UserArea.jsx b/web/src/components/layout/headerbar/UserArea.jsx index 8ea70f47f..9fc011da1 100644 --- a/web/src/components/layout/headerbar/UserArea.jsx +++ b/web/src/components/layout/headerbar/UserArea.jsx @@ -17,7 +17,7 @@ along with this program. If not, see . For commercial licensing, please contact support@quantumnous.com */ -import React from 'react'; +import React, { useRef } from 'react'; import { Link } from 'react-router-dom'; import { Avatar, Button, Dropdown, Typography } from '@douyinfe/semi-ui'; import { ChevronDown } from 'lucide-react'; @@ -39,6 +39,7 @@ const UserArea = ({ navigate, t, }) => { + const dropdownRef = useRef(null); if (isLoading) { return ( - { - navigate('/console/personal'); - }} - className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white' - > -
- - {t('个人设置')} -
-
- { - navigate('/console/token'); - }} - className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white' - > -
- - {t('令牌管理')} -
-
- { - navigate('/console/topup'); - }} - className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white' - > -
- - {t('钱包管理')} -
-
- -
- - {t('退出')} -
-
- - } - > - - + + {userState.user.username[0].toUpperCase()} + + + + {userState.user.username} + + + + + + ); } else { const showRegisterButton = !isSelfUseMode; diff --git a/web/src/components/settings/PaymentSetting.jsx b/web/src/components/settings/PaymentSetting.jsx index faaa9561b..220c86642 100644 --- a/web/src/components/settings/PaymentSetting.jsx +++ b/web/src/components/settings/PaymentSetting.jsx @@ -45,6 +45,7 @@ const PaymentSetting = () => { StripePriceId: '', StripeUnitPrice: 8.0, StripeMinTopUp: 1, + StripePromotionCodesEnabled: false, }); let [loading, setLoading] = useState(false); diff --git a/web/src/components/settings/PersonalSetting.jsx b/web/src/components/settings/PersonalSetting.jsx index 3ba8dcfd3..15dfbd973 100644 --- a/web/src/components/settings/PersonalSetting.jsx +++ b/web/src/components/settings/PersonalSetting.jsx @@ -19,7 +19,14 @@ For commercial licensing, please contact support@quantumnous.com import React, { useContext, useEffect, useState } from 'react'; import { useNavigate } from 'react-router-dom'; -import { API, copy, showError, showInfo, showSuccess } from '../../helpers'; +import { + API, + copy, + showError, + showInfo, + showSuccess, + setStatusData, +} from '../../helpers'; import { UserContext } from '../../context/User'; import { Modal } from '@douyinfe/semi-ui'; import { useTranslation } from 'react-i18next'; @@ -71,18 +78,40 @@ const PersonalSetting = () => { }); useEffect(() => { - let status = localStorage.getItem('status'); - if (status) { - status = JSON.parse(status); - setStatus(status); - if (status.turnstile_check) { + let saved = localStorage.getItem('status'); + if (saved) { + const parsed = JSON.parse(saved); + setStatus(parsed); + if (parsed.turnstile_check) { setTurnstileEnabled(true); - setTurnstileSiteKey(status.turnstile_site_key); + setTurnstileSiteKey(parsed.turnstile_site_key); + } else { + setTurnstileEnabled(false); + setTurnstileSiteKey(''); } } - getUserData().then((res) => { - console.log(userState); - }); + // Always refresh status from server to avoid stale flags (e.g., admin just enabled OAuth) + (async () => { + try { + const res = await API.get('/api/status'); + const { success, data } = res.data; + if (success && data) { + setStatus(data); + setStatusData(data); + if (data.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(data.turnstile_site_key); + } else { + setTurnstileEnabled(false); + setTurnstileSiteKey(''); + } + } + } catch (e) { + // ignore and keep local status + } + })(); + + getUserData(); }, []); useEffect(() => { diff --git a/web/src/components/settings/RatioSetting.jsx b/web/src/components/settings/RatioSetting.jsx index 096722bba..f5d8ef99d 100644 --- a/web/src/components/settings/RatioSetting.jsx +++ b/web/src/components/settings/RatioSetting.jsx @@ -39,6 +39,9 @@ const RatioSetting = () => { CompletionRatio: '', GroupRatio: '', GroupGroupRatio: '', + ImageRatio: '', + AudioRatio: '', + AudioCompletionRatio: '', AutoGroups: '', DefaultUseAutoGroup: false, ExposeRatioEnabled: false, @@ -61,7 +64,10 @@ const RatioSetting = () => { item.key === 'UserUsableGroups' || item.key === 'CompletionRatio' || item.key === 'ModelPrice' || - item.key === 'CacheRatio' + item.key === 'CacheRatio' || + item.key === 'ImageRatio' || + item.key === 'AudioRatio' || + item.key === 'AudioCompletionRatio' ) { try { item.value = JSON.stringify(JSON.parse(item.value), null, 2); diff --git a/web/src/components/settings/SystemSetting.jsx b/web/src/components/settings/SystemSetting.jsx index 9c7eeaadc..f9a2c019d 100644 --- a/web/src/components/settings/SystemSetting.jsx +++ b/web/src/components/settings/SystemSetting.jsx @@ -29,6 +29,7 @@ import { TagInput, Spin, Card, + Radio, } from '@douyinfe/semi-ui'; const { Text } = Typography; import { @@ -44,6 +45,7 @@ import { useTranslation } from 'react-i18next'; const SystemSetting = () => { const { t } = useTranslation(); let [inputs, setInputs] = useState({ + PasswordLoginEnabled: '', PasswordRegisterEnabled: '', EmailVerificationEnabled: '', @@ -87,6 +89,15 @@ const SystemSetting = () => { LinuxDOClientSecret: '', LinuxDOMinimumTrustLevel: '', ServerAddress: '', + // SSRF防护配置 + 'fetch_setting.enable_ssrf_protection': true, + 'fetch_setting.allow_private_ip': '', + 'fetch_setting.domain_filter_mode': false, // true 白名单,false 黑名单 + 'fetch_setting.ip_filter_mode': false, // true 白名单,false 黑名单 + 'fetch_setting.domain_list': [], + 'fetch_setting.ip_list': [], + 'fetch_setting.allowed_ports': [], + 'fetch_setting.apply_ip_filter_for_domain': false, }); const [originInputs, setOriginInputs] = useState({}); @@ -98,6 +109,11 @@ const SystemSetting = () => { useState(false); const [linuxDOOAuthEnabled, setLinuxDOOAuthEnabled] = useState(false); const [emailToAdd, setEmailToAdd] = useState(''); + const [domainFilterMode, setDomainFilterMode] = useState(true); + const [ipFilterMode, setIpFilterMode] = useState(true); + const [domainList, setDomainList] = useState([]); + const [ipList, setIpList] = useState([]); + const [allowedPorts, setAllowedPorts] = useState([]); const getOptions = async () => { setLoading(true); @@ -113,6 +129,37 @@ const SystemSetting = () => { case 'EmailDomainWhitelist': setEmailDomainWhitelist(item.value ? item.value.split(',') : []); break; + case 'fetch_setting.allow_private_ip': + case 'fetch_setting.enable_ssrf_protection': + case 'fetch_setting.domain_filter_mode': + case 'fetch_setting.ip_filter_mode': + case 'fetch_setting.apply_ip_filter_for_domain': + item.value = toBoolean(item.value); + break; + case 'fetch_setting.domain_list': + try { + const domains = item.value ? JSON.parse(item.value) : []; + setDomainList(Array.isArray(domains) ? domains : []); + } catch (e) { + setDomainList([]); + } + break; + case 'fetch_setting.ip_list': + try { + const ips = item.value ? JSON.parse(item.value) : []; + setIpList(Array.isArray(ips) ? ips : []); + } catch (e) { + setIpList([]); + } + break; + case 'fetch_setting.allowed_ports': + try { + const ports = item.value ? JSON.parse(item.value) : []; + setAllowedPorts(Array.isArray(ports) ? ports : []); + } catch (e) { + setAllowedPorts(['80', '443', '8080', '8443']); + } + break; case 'PasswordLoginEnabled': case 'PasswordRegisterEnabled': case 'EmailVerificationEnabled': @@ -140,6 +187,13 @@ const SystemSetting = () => { }); setInputs(newInputs); setOriginInputs(newInputs); + // 同步模式布尔到本地状态 + if (typeof newInputs['fetch_setting.domain_filter_mode'] !== 'undefined') { + setDomainFilterMode(!!newInputs['fetch_setting.domain_filter_mode']); + } + if (typeof newInputs['fetch_setting.ip_filter_mode'] !== 'undefined') { + setIpFilterMode(!!newInputs['fetch_setting.ip_filter_mode']); + } if (formApiRef.current) { formApiRef.current.setValues(newInputs); } @@ -276,6 +330,46 @@ const SystemSetting = () => { } }; + const submitSSRF = async () => { + const options = []; + + // 处理域名过滤模式与列表 + options.push({ + key: 'fetch_setting.domain_filter_mode', + value: domainFilterMode, + }); + if (Array.isArray(domainList)) { + options.push({ + key: 'fetch_setting.domain_list', + value: JSON.stringify(domainList), + }); + } + + // 处理IP过滤模式与列表 + options.push({ + key: 'fetch_setting.ip_filter_mode', + value: ipFilterMode, + }); + if (Array.isArray(ipList)) { + options.push({ + key: 'fetch_setting.ip_list', + value: JSON.stringify(ipList), + }); + } + + // 处理端口配置 + if (Array.isArray(allowedPorts)) { + options.push({ + key: 'fetch_setting.allowed_ports', + value: JSON.stringify(allowedPorts), + }); + } + + if (options.length > 0) { + await updateOptions(options); + } + }; + const handleAddEmail = () => { if (emailToAdd && emailToAdd.trim() !== '') { const domain = emailToAdd.trim(); @@ -587,6 +681,179 @@ const SystemSetting = () => { + + + + {t('配置服务器端请求伪造(SSRF)防护,用于保护内网资源安全')} + + + + + handleCheckboxChange('fetch_setting.enable_ssrf_protection', e) + } + > + {t('启用SSRF防护(推荐开启以保护服务器安全)')} + + + + + + + + handleCheckboxChange('fetch_setting.allow_private_ip', e) + } + > + {t('允许访问私有IP地址(127.0.0.1、192.168.x.x等内网地址)')} + + + + + + + + handleCheckboxChange('fetch_setting.apply_ip_filter_for_domain', e) + } + style={{ marginBottom: 8 }} + > + {t('对域名启用 IP 过滤(实验性)')} + + + {t(domainFilterMode ? '域名白名单' : '域名黑名单')} + + + {t('支持通配符格式,如:example.com, *.api.example.com')} + + { + const selected = val && val.target ? val.target.value : val; + const isWhitelist = selected === 'whitelist'; + setDomainFilterMode(isWhitelist); + setInputs(prev => ({ + ...prev, + 'fetch_setting.domain_filter_mode': isWhitelist, + })); + }} + style={{ marginBottom: 8 }} + > + {t('白名单')} + {t('黑名单')} + + { + setDomainList(value); + // 触发Form的onChange事件 + setInputs(prev => ({ + ...prev, + 'fetch_setting.domain_list': value + })); + }} + placeholder={t('输入域名后回车,如:example.com')} + style={{ width: '100%' }} + /> + + + + + + + {t(ipFilterMode ? 'IP白名单' : 'IP黑名单')} + + + {t('支持CIDR格式,如:8.8.8.8, 192.168.1.0/24')} + + { + const selected = val && val.target ? val.target.value : val; + const isWhitelist = selected === 'whitelist'; + setIpFilterMode(isWhitelist); + setInputs(prev => ({ + ...prev, + 'fetch_setting.ip_filter_mode': isWhitelist, + })); + }} + style={{ marginBottom: 8 }} + > + {t('白名单')} + {t('黑名单')} + + { + setIpList(value); + // 触发Form的onChange事件 + setInputs(prev => ({ + ...prev, + 'fetch_setting.ip_list': value + })); + }} + placeholder={t('输入IP地址后回车,如:8.8.8.8')} + style={{ width: '100%' }} + /> + + + + + + {t('允许的端口')} + + {t('支持单个端口和端口范围,如:80, 443, 8000-8999')} + + { + setAllowedPorts(value); + // 触发Form的onChange事件 + setInputs(prev => ({ + ...prev, + 'fetch_setting.allowed_ports': value + })); + }} + placeholder={t('输入端口后回车,如:80 或 8000-8999')} + style={{ width: '100%' }} + /> + + {t('端口配置详细说明')} + + + + + + + + ); }; + const isBound = (accountId) => Boolean(accountId); + const [showTelegramBindModal, setShowTelegramBindModal] = React.useState(false); + return ( {/* 卡片头部 */} @@ -142,7 +146,7 @@ const AccountManagement = ({ size='small' onClick={() => setShowEmailBindModal(true)} > - {userState.user && userState.user.email !== '' + {isBound(userState.user?.email) ? t('修改绑定') : t('绑定')} @@ -165,9 +169,11 @@ const AccountManagement = ({ {t('微信')}
- {userState.user && userState.user.wechat_id !== '' - ? t('已绑定') - : t('未绑定')} + {!status.wechat_login + ? t('未启用') + : isBound(userState.user?.wechat_id) + ? t('已绑定') + : t('未绑定')}
@@ -179,7 +185,7 @@ const AccountManagement = ({ disabled={!status.wechat_login} onClick={() => setShowWeChatBindModal(true)} > - {userState.user && userState.user.wechat_id !== '' + {isBound(userState.user?.wechat_id) ? t('修改绑定') : status.wechat_login ? t('绑定') @@ -220,8 +226,7 @@ const AccountManagement = ({ onGitHubOAuthClicked(status.github_client_id) } disabled={ - (userState.user && userState.user.github_id !== '') || - !status.github_oauth + isBound(userState.user?.github_id) || !status.github_oauth } > {status.github_oauth ? t('绑定') : t('未启用')} @@ -264,8 +269,7 @@ const AccountManagement = ({ ) } disabled={ - (userState.user && userState.user.oidc_id !== '') || - !status.oidc_enabled + isBound(userState.user?.oidc_id) || !status.oidc_enabled } > {status.oidc_enabled ? t('绑定') : t('未启用')} @@ -298,26 +302,56 @@ const AccountManagement = ({
{status.telegram_oauth ? ( - userState.user.telegram_id !== '' ? ( - ) : ( -
- -
+ ) ) : ( - )}
+ setShowTelegramBindModal(false)} + footer={null} + > +
+ {t('点击下方按钮通过 Telegram 完成绑定')} +
+
+
+ +
+
+
{/* LinuxDO绑定 */} @@ -350,8 +384,7 @@ const AccountManagement = ({ onLinuxDOOAuthClicked(status.linuxdo_client_id) } disabled={ - (userState.user && userState.user.linux_do_id !== '') || - !status.linuxdo_oauth + isBound(userState.user?.linux_do_id) || !status.linuxdo_oauth } > {status.linuxdo_oauth ? t('绑定') : t('未启用')} diff --git a/web/src/components/settings/personal/cards/NotificationSettings.jsx b/web/src/components/settings/personal/cards/NotificationSettings.jsx index 0b097eaff..aad612d2c 100644 --- a/web/src/components/settings/personal/cards/NotificationSettings.jsx +++ b/web/src/components/settings/personal/cards/NotificationSettings.jsx @@ -44,6 +44,7 @@ import CodeViewer from '../../../playground/CodeViewer'; import { StatusContext } from '../../../../context/Status'; import { UserContext } from '../../../../context/User'; import { useUserPermissions } from '../../../../hooks/common/useUserPermissions'; +import { useSidebar } from '../../../../hooks/common/useSidebar'; const NotificationSettings = ({ t, @@ -97,6 +98,9 @@ const NotificationSettings = ({ isSidebarModuleAllowed, } = useUserPermissions(); + // 使用useSidebar钩子获取刷新方法 + const { refreshUserConfig } = useSidebar(); + // 左侧边栏设置处理函数 const handleSectionChange = (sectionKey) => { return (checked) => { @@ -132,6 +136,9 @@ const NotificationSettings = ({ }); if (res.data.success) { showSuccess(t('侧边栏设置保存成功')); + + // 刷新useSidebar钩子中的用户配置,实现实时更新 + await refreshUserConfig(); } else { showError(res.data.message); } @@ -334,7 +341,7 @@ const NotificationSettings = ({ loading={sidebarLoading} className='!rounded-lg' > - {t('保存边栏设置')} + {t('保存设置')} ) : ( diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index c0a216246..2eb480e7a 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -85,6 +85,26 @@ const REGION_EXAMPLE = { 'claude-3-5-sonnet-20240620': 'europe-west1', }; +// 支持并且已适配通过接口获取模型列表的渠道类型 +const MODEL_FETCHABLE_TYPES = new Set([ + 1, + 4, + 14, + 34, + 17, + 26, + 24, + 47, + 25, + 20, + 23, + 31, + 35, + 40, + 42, + 48, +]); + function type2secretPrompt(type) { // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') switch (type) { @@ -144,6 +164,8 @@ const EditChannelModal = (props) => { settings: '', // 仅 Vertex: 密钥格式(存入 settings.vertex_key_type) vertex_key_type: 'json', + // 企业账户设置 + is_enterprise_account: false, }; const [batch, setBatch] = useState(false); const [multiToSingle, setMultiToSingle] = useState(false); @@ -169,6 +191,7 @@ const EditChannelModal = (props) => { const [channelSearchValue, setChannelSearchValue] = useState(''); const [useManualInput, setUseManualInput] = useState(false); // 是否使用手动输入模式 const [keyMode, setKeyMode] = useState('append'); // 密钥模式:replace(覆盖)或 append(追加) + const [isEnterpriseAccount, setIsEnterpriseAccount] = useState(false); // 是否为企业账户 // 2FA验证查看密钥相关状态 const [twoFAState, setTwoFAState] = useState({ @@ -215,7 +238,7 @@ const EditChannelModal = (props) => { pass_through_body_enabled: false, system_prompt: '', }); - const showApiConfigCard = inputs.type !== 45; // 控制是否显示 API 配置卡片(仅当渠道类型不是 豆包 时显示) + const showApiConfigCard = true; // 控制是否显示 API 配置卡片 const getInitValues = () => ({ ...originInputs }); // 处理渠道额外设置的更新 @@ -322,6 +345,10 @@ const EditChannelModal = (props) => { case 36: localModels = ['suno_music', 'suno_lyrics']; break; + case 45: + localModels = getChannelModels(value); + setInputs((prevInputs) => ({ ...prevInputs, base_url: 'https://ark.cn-beijing.volces.com' })); + break; default: localModels = getChannelModels(value); break; @@ -413,15 +440,27 @@ const EditChannelModal = (props) => { parsedSettings.azure_responses_version || ''; // 读取 Vertex 密钥格式 data.vertex_key_type = parsedSettings.vertex_key_type || 'json'; + // 读取企业账户设置 + data.is_enterprise_account = parsedSettings.openrouter_enterprise === true; } catch (error) { console.error('解析其他设置失败:', error); data.azure_responses_version = ''; data.region = ''; data.vertex_key_type = 'json'; + data.is_enterprise_account = false; } } else { // 兼容历史数据:老渠道没有 settings 时,默认按 json 展示 data.vertex_key_type = 'json'; + data.is_enterprise_account = false; + } + + if ( + data.type === 45 && + (!data.base_url || + (typeof data.base_url === 'string' && data.base_url.trim() === '')) + ) { + data.base_url = 'https://ark.cn-beijing.volces.com'; } setInputs(data); @@ -433,6 +472,8 @@ const EditChannelModal = (props) => { } else { setAutoBan(true); } + // 同步企业账户状态 + setIsEnterpriseAccount(data.is_enterprise_account || false); setBasicModels(getChannelModels(data.type)); // 同步更新channelSettings状态显示 setChannelSettings({ @@ -692,6 +733,8 @@ const EditChannelModal = (props) => { }); // 重置密钥模式状态 setKeyMode('append'); + // 重置企业账户状态 + setIsEnterpriseAccount(false); // 清空表单中的key_mode字段 if (formApiRef.current) { formApiRef.current.setValue('key_mode', undefined); @@ -802,7 +845,9 @@ const EditChannelModal = (props) => { delete localInputs.key; } } else { - localInputs.key = batch ? JSON.stringify(keys) : JSON.stringify(keys[0]); + localInputs.key = batch + ? JSON.stringify(keys) + : JSON.stringify(keys[0]); } } } @@ -822,6 +867,10 @@ const EditChannelModal = (props) => { showInfo(t('请至少选择一个模型!')); return; } + if (localInputs.type === 45 && (!localInputs.base_url || localInputs.base_url.trim() === '')) { + showInfo(t('请输入API地址!')); + return; + } if ( localInputs.model_mapping && localInputs.model_mapping !== '' && @@ -851,6 +900,21 @@ const EditChannelModal = (props) => { }; localInputs.setting = JSON.stringify(channelExtraSettings); + // 处理type === 20的企业账户设置 + if (localInputs.type === 20) { + let settings = {}; + if (localInputs.settings) { + try { + settings = JSON.parse(localInputs.settings); + } catch (error) { + console.error('解析settings失败:', error); + } + } + // 设置企业账户标识,无论是true还是false都要传到后端 + settings.openrouter_enterprise = localInputs.is_enterprise_account === true; + localInputs.settings = JSON.stringify(settings); + } + // 清理不需要发送到后端的字段 delete localInputs.force_format; delete localInputs.thinking_to_content; @@ -858,6 +922,7 @@ const EditChannelModal = (props) => { delete localInputs.pass_through_body_enabled; delete localInputs.system_prompt; delete localInputs.system_prompt_override; + delete localInputs.is_enterprise_account; // 顶层的 vertex_key_type 不应发送给后端 delete localInputs.vertex_key_type; @@ -899,6 +964,56 @@ const EditChannelModal = (props) => { } }; + // 密钥去重函数 + const deduplicateKeys = () => { + const currentKey = formApiRef.current?.getValue('key') || inputs.key || ''; + + if (!currentKey.trim()) { + showInfo(t('请先输入密钥')); + return; + } + + // 按行分割密钥 + const keyLines = currentKey.split('\n'); + const beforeCount = keyLines.length; + + // 使用哈希表去重,保持原有顺序 + const keySet = new Set(); + const deduplicatedKeys = []; + + keyLines.forEach((line) => { + const trimmedLine = line.trim(); + if (trimmedLine && !keySet.has(trimmedLine)) { + keySet.add(trimmedLine); + deduplicatedKeys.push(trimmedLine); + } + }); + + const afterCount = deduplicatedKeys.length; + const deduplicatedKeyText = deduplicatedKeys.join('\n'); + + // 更新表单和状态 + if (formApiRef.current) { + formApiRef.current.setValue('key', deduplicatedKeyText); + } + handleInputChange('key', deduplicatedKeyText); + + // 显示去重结果 + const message = t( + '去重完成:去重前 {{before}} 个密钥,去重后 {{after}} 个密钥', + { + before: beforeCount, + after: afterCount, + }, + ); + + if (beforeCount === afterCount) { + showInfo(t('未发现重复密钥')); + } else { + showSuccess(message); + } + }; + const addCustomModels = () => { if (customModel.trim() === '') return; const modelArray = customModel.split(',').map((model) => model.trim()); @@ -994,24 +1109,41 @@ const EditChannelModal = (props) => { )} {batch && ( - { - setMultiToSingle((prev) => !prev); - setInputs((prev) => { - const newInputs = { ...prev }; - if (!multiToSingle) { - newInputs.multi_key_mode = multiKeyMode; - } else { - delete newInputs.multi_key_mode; - } - return newInputs; - }); - }} - > - {t('密钥聚合模式')} - + <> + { + setMultiToSingle((prev) => { + const nextValue = !prev; + setInputs((prevInputs) => { + const newInputs = { ...prevInputs }; + if (nextValue) { + newInputs.multi_key_mode = multiKeyMode; + } else { + delete newInputs.multi_key_mode; + } + return newInputs; + }); + return nextValue; + }); + }} + > + {t('密钥聚合模式')} + + + {inputs.type !== 41 && ( + + )} + )} ) : null; @@ -1175,6 +1307,21 @@ const EditChannelModal = (props) => { onChange={(value) => handleInputChange('type', value)} /> + {inputs.type === 20 && ( + { + setIsEnterpriseAccount(value); + handleInputChange('is_enterprise_account', value); + }} + extraText={t('企业账户为特殊返回格式,需要特殊处理,如果非企业账户,请勿勾选')} + initValue={inputs.is_enterprise_account} + /> + )} + { value={inputs.vertex_key_type || 'json'} onChange={(value) => { // 更新设置中的 vertex_key_type - handleChannelOtherSettingsChange('vertex_key_type', value); + handleChannelOtherSettingsChange( + 'vertex_key_type', + value, + ); // 切换为 api_key 时,关闭批量与手动/文件切换,并清理已选文件 if (value === 'api_key') { setBatch(false); @@ -1218,7 +1368,8 @@ const EditChannelModal = (props) => { /> )} {batch ? ( - inputs.type === 41 && (inputs.vertex_key_type || 'json') === 'json' ? ( + inputs.type === 41 && + (inputs.vertex_key_type || 'json') === 'json' ? ( { autoComplete='new-password' onChange={(value) => handleInputChange('key', value)} extraText={ -
+
{isEdit && isMultiKeyChannel && keyMode === 'append' && ( @@ -1282,7 +1433,8 @@ const EditChannelModal = (props) => { ) ) : ( <> - {inputs.type === 41 && (inputs.vertex_key_type || 'json') === 'json' ? ( + {inputs.type === 41 && + (inputs.vertex_key_type || 'json') === 'json' ? ( <> {!batch && (
@@ -1789,6 +1941,30 @@ const EditChannelModal = (props) => { />
)} + + {inputs.type === 45 && ( +
+ + handleInputChange('base_url', value) + } + optionList={[ + { + value: 'https://ark.cn-beijing.volces.com', + label: 'https://ark.cn-beijing.volces.com' + }, + { + value: 'https://ark.ap-southeast.bytepluses.com', + label: 'https://ark.ap-southeast.bytepluses.com' + } + ]} + defaultValue='https://ark.cn-beijing.volces.com' + /> +
+ )} )} @@ -1872,13 +2048,15 @@ const EditChannelModal = (props) => { > {t('填入所有模型')} - + {MODEL_FETCHABLE_TYPES.has(inputs.type) && ( + + )} )} + handleDeleteKey(record.index)} + okType={'danger'} + position={'topRight'} + > + + ), }, diff --git a/web/src/components/table/mj-logs/MjLogsFilters.jsx b/web/src/components/table/mj-logs/MjLogsFilters.jsx index 44c6bcfcd..6db96e791 100644 --- a/web/src/components/table/mj-logs/MjLogsFilters.jsx +++ b/web/src/components/table/mj-logs/MjLogsFilters.jsx @@ -21,6 +21,8 @@ import React from 'react'; import { Button, Form } from '@douyinfe/semi-ui'; import { IconSearch } from '@douyinfe/semi-icons'; +import { DATE_RANGE_PRESETS } from '../../../constants/console.constants'; + const MjLogsFilters = ({ formInitValues, setFormApi, @@ -54,6 +56,11 @@ const MjLogsFilters = ({ showClear pure size='small' + presets={DATE_RANGE_PRESETS.map(preset => ({ + text: t(preset.text), + start: preset.start(), + end: preset.end() + }))} />
diff --git a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx index 766c17158..b63c7dd4f 100644 --- a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx +++ b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx @@ -35,8 +35,9 @@ import { Sparkles, } from 'lucide-react'; import { - TASK_ACTION_GENERATE, - TASK_ACTION_TEXT_GENERATE, + TASK_ACTION_FIRST_TAIL_GENERATE, + TASK_ACTION_GENERATE, TASK_ACTION_REFERENCE_GENERATE, + TASK_ACTION_TEXT_GENERATE } from '../../../constants/common.constant'; import { CHANNEL_OPTIONS } from '../../../constants/channel.constants'; @@ -111,6 +112,18 @@ const renderType = (type, t) => { {t('文生视频')} ); + case TASK_ACTION_FIRST_TAIL_GENERATE: + return ( + }> + {t('首尾生视频')} + + ); + case TASK_ACTION_REFERENCE_GENERATE: + return ( + }> + {t('参照生视频')} + + ); default: return ( }> @@ -343,7 +356,9 @@ export const getTaskLogsColumns = ({ // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接 const isVideoTask = record.action === TASK_ACTION_GENERATE || - record.action === TASK_ACTION_TEXT_GENERATE; + record.action === TASK_ACTION_TEXT_GENERATE || + record.action === TASK_ACTION_FIRST_TAIL_GENERATE || + record.action === TASK_ACTION_REFERENCE_GENERATE; const isSuccess = record.status === 'SUCCESS'; const isUrl = typeof text === 'string' && /^https?:\/\//.test(text); if (isSuccess && isVideoTask && isUrl) { diff --git a/web/src/components/table/task-logs/TaskLogsFilters.jsx b/web/src/components/table/task-logs/TaskLogsFilters.jsx index d5e081ab7..e27cea867 100644 --- a/web/src/components/table/task-logs/TaskLogsFilters.jsx +++ b/web/src/components/table/task-logs/TaskLogsFilters.jsx @@ -21,6 +21,8 @@ import React from 'react'; import { Button, Form } from '@douyinfe/semi-ui'; import { IconSearch } from '@douyinfe/semi-icons'; +import { DATE_RANGE_PRESETS } from '../../../constants/console.constants'; + const TaskLogsFilters = ({ formInitValues, setFormApi, @@ -54,6 +56,11 @@ const TaskLogsFilters = ({ showClear pure size='small' + presets={DATE_RANGE_PRESETS.map(preset => ({ + text: t(preset.text), + start: preset.start(), + end: preset.end() + }))} />
diff --git a/web/src/components/table/task-logs/modals/ContentModal.jsx b/web/src/components/table/task-logs/modals/ContentModal.jsx index 3d747b77d..3bfba37b1 100644 --- a/web/src/components/table/task-logs/modals/ContentModal.jsx +++ b/web/src/components/table/task-logs/modals/ContentModal.jsx @@ -17,8 +17,11 @@ along with this program. If not, see . For commercial licensing, please contact support@quantumnous.com */ -import React from 'react'; -import { Modal } from '@douyinfe/semi-ui'; +import React, { useState, useEffect } from 'react'; +import { Modal, Button, Typography, Spin } from '@douyinfe/semi-ui'; +import { IconExternalOpen, IconCopy } from '@douyinfe/semi-icons'; + +const { Text } = Typography; const ContentModal = ({ isModalOpen, @@ -26,17 +29,120 @@ const ContentModal = ({ modalContent, isVideo, }) => { + const [videoError, setVideoError] = useState(false); + const [isLoading, setIsLoading] = useState(false); + + useEffect(() => { + if (isModalOpen && isVideo) { + setVideoError(false); + setIsLoading(true); + } + }, [isModalOpen, isVideo]); + + const handleVideoError = () => { + setVideoError(true); + setIsLoading(false); + }; + + const handleVideoLoaded = () => { + setIsLoading(false); + }; + + const handleCopyUrl = () => { + navigator.clipboard.writeText(modalContent); + }; + + const handleOpenInNewTab = () => { + window.open(modalContent, '_blank'); + }; + + const renderVideoContent = () => { + if (videoError) { + return ( +
+ + 视频无法在当前浏览器中播放,这可能是由于: + + + • 视频服务商的跨域限制 + + + • 需要特定的请求头或认证 + + + • 防盗链保护机制 + + +
+ + +
+ +
+ + {modalContent} + +
+
+ ); + } + + return ( +
+ {isLoading && ( +
+ +
+ )} +
+ ); + }; + return ( setIsModalOpen(false)} onCancel={() => setIsModalOpen(false)} closable={null} - bodyStyle={{ height: '400px', overflow: 'auto' }} + bodyStyle={{ + height: isVideo ? '450px' : '400px', + overflow: 'auto', + padding: isVideo && videoError ? '0' : '24px' + }} width={800} > {isVideo ? ( -