From 956244c742842f8f096dd8e47a97404107a8f777 Mon Sep 17 00:00:00 2001 From: huanghejian Date: Tue, 16 Sep 2025 14:30:12 +0800 Subject: [PATCH] fix: VolcEngine doubao-seedream-4-0-250828 --- controller/channel-test.go | 39 +++++++++++++++++++++++++++ relay/channel/volcengine/adaptor.go | 2 ++ relay/channel/volcengine/constants.go | 1 + 3 files changed, 42 insertions(+) diff --git a/controller/channel-test.go b/controller/channel-test.go index 5a668c488..9ea6eed75 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -90,6 +90,11 @@ func testChannel(channel *model.Channel, testModel string) testResult { requestPath = "/v1/embeddings" // 修改请求路径 } + // VolcEngine 图像生成模型 + if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { + requestPath = "/v1/images/generations" + } + c.Request = &http.Request{ Method: "POST", URL: &url.URL{Path: requestPath}, // 使用动态路径 @@ -109,6 +114,21 @@ func testChannel(channel *model.Channel, testModel string) testResult { } } + // 重新检查模型类型并更新请求路径 + if strings.Contains(strings.ToLower(testModel), "embedding") || + strings.HasPrefix(testModel, "m3e") || + strings.Contains(testModel, "bge-") || + strings.Contains(testModel, "embed") || + channel.Type == constant.ChannelTypeMokaAI { + requestPath = "/v1/embeddings" + c.Request.URL.Path = requestPath + } + + if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { + requestPath = "/v1/images/generations" + c.Request.URL.Path = requestPath + } + cache, err := model.GetUserCache(1) if err != nil { return testResult{ @@ -140,6 +160,9 @@ func testChannel(channel *model.Channel, testModel string) testResult { if c.Request.URL.Path == "/v1/embeddings" { relayFormat = types.RelayFormatEmbedding } + if c.Request.URL.Path == "/v1/images/generations" { + relayFormat = types.RelayFormatOpenAIImage + } info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil) @@ -201,6 +224,22 @@ func testChannel(channel *model.Channel, testModel string) testResult { } // 调用专门用于 Embedding 的转换函数 convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest) + } else if info.RelayMode == relayconstant.RelayModeImagesGenerations { + // 创建一个 ImageRequest + prompt := "cat" + if request.Prompt != nil { + if promptStr, ok := request.Prompt.(string); ok && promptStr != "" { + prompt = promptStr + } + } + imageRequest := dto.ImageRequest{ + Prompt: prompt, + Model: request.Model, + N: uint(request.N), + Size: request.Size, + } + // 调用专门用于图像生成的转换函数 + convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest) } else { // 对其他所有请求类型(如 Chat),保持原有逻辑 convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request) diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 0af019da4..eb88412af 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -41,6 +41,8 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { switch info.RelayMode { + case constant.RelayModeImagesGenerations: + return request, nil case constant.RelayModeImagesEdits: var requestBody bytes.Buffer diff --git a/relay/channel/volcengine/constants.go b/relay/channel/volcengine/constants.go index 30cc902e7..fca10e7c1 100644 --- a/relay/channel/volcengine/constants.go +++ b/relay/channel/volcengine/constants.go @@ -8,6 +8,7 @@ var ModelList = []string{ "Doubao-lite-32k", "Doubao-lite-4k", "Doubao-embedding", + "doubao-seedream-4-0-250828", } var ChannelName = "volcengine"