From 3ac9ff602886e1b809a79787faa36e0d3f2ffc27 Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Thu, 23 Oct 2025 21:18:11 +0800 Subject: [PATCH] feat: doubao-seedream support image edit --- common/gin.go | 60 ++++++++ middleware/distributor.go | 4 +- relay/channel/volcengine/adaptor.go | 212 ++++++++++++++-------------- 3 files changed, 170 insertions(+), 106 deletions(-) diff --git a/common/gin.go b/common/gin.go index 4bc9f1ba7..cc83a5f98 100644 --- a/common/gin.go +++ b/common/gin.go @@ -2,9 +2,11 @@ package common import ( "bytes" + "encoding/json" "io" "mime/multipart" "net/http" + "net/url" "strings" "time" @@ -40,6 +42,10 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { err = Unmarshal(requestBody, &v) + } else if strings.Contains(contentType, gin.MIMEPOSTForm) { + err = parseFormData(requestBody, &v) + } else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) { + err = parseMultipartFormData(c, requestBody, &v) } else { // skip for now // TODO: someday non json request have variant model, we will need to implementation this @@ -138,3 +144,57 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) { c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) return form, nil } + +func parseFormData(data []byte, v any) error { + values, err := url.ParseQuery(string(data)) + if err != nil { + return err + } + formMap := make(map[string]any) + for key, vals := range values { + if len(vals) == 1 { + formMap[key] = vals[0] + } else { + formMap[key] = vals + } + } + jsonData, err := json.Marshal(formMap) + if err != nil { + return err + } + + return json.Unmarshal(jsonData, v) +} + +func parseMultipartFormData(c *gin.Context, data []byte, v any) error { + contentType := c.Request.Header.Get("Content-Type") + boundary := "" + if idx := strings.Index(contentType, "boundary="); idx != -1 { + boundary = contentType[idx+9:] + } + + if boundary == "" { + return json.Unmarshal(data, v) // Fallback to JSON + } + + reader := multipart.NewReader(bytes.NewReader(data), boundary) + form, err := reader.ReadForm(32 << 20) // 32 MB max memory + if err != nil { + return err + } + defer form.RemoveAll() + formMap := make(map[string]any) + for key, vals := range form.Value { + if len(vals) == 1 { + formMap[key] = vals[0] + } else { + formMap[key] = vals + } + } + jsonData, err := json.Marshal(formMap) + if err != nil { + return err + } + + return json.Unmarshal(jsonData, v) +} diff --git a/middleware/distributor.go b/middleware/distributor.go index 2ff79e6ca..3f7103990 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "slices" "strconv" "strings" "time" @@ -245,7 +246,8 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e") } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { //modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1") - if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { + contentType := c.Request.Header.Get("Content-Type") + if slices.Contains([]string{gin.MIMEPOSTForm, gin.MIMEMultipartPOSTForm}, contentType) { modelRequest.Model = c.PostForm("model") } } diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index f46328e37..a377b1dde 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -6,9 +6,7 @@ import ( "errors" "fmt" "io" - "mime/multipart" "net/http" - "net/textproto" "path/filepath" "strings" @@ -104,106 +102,107 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf switch info.RelayMode { case constant.RelayModeImagesGenerations: return request, nil - case constant.RelayModeImagesEdits: - - var requestBody bytes.Buffer - writer := multipart.NewWriter(&requestBody) - - writer.WriteField("model", request.Model) - - formData := c.Request.PostForm - for key, values := range formData { - if key == "model" { - continue - } - for _, value := range values { - writer.WriteField(key, value) - } - } - - if err := c.Request.ParseMultipartForm(32 << 20); err != nil { - return nil, errors.New("failed to parse multipart form") - } - - if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil { - var imageFiles []*multipart.FileHeader - var exists bool - - if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 { - if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 { - foundArrayImages := false - for fieldName, files := range c.Request.MultipartForm.File { - if strings.HasPrefix(fieldName, "image[") && len(files) > 0 { - foundArrayImages = true - for _, file := range files { - imageFiles = append(imageFiles, file) - } - } - } - - if !foundArrayImages && (len(imageFiles) == 0) { - return nil, errors.New("image is required") - } - } - } - - for i, fileHeader := range imageFiles { - file, err := fileHeader.Open() - if err != nil { - return nil, fmt.Errorf("failed to open image file %d: %w", i, err) - } - defer file.Close() - - fieldName := "image" - if len(imageFiles) > 1 { - fieldName = "image[]" - } - - mimeType := detectImageMimeType(fileHeader.Filename) - - h := make(textproto.MIMEHeader) - h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename)) - h.Set("Content-Type", mimeType) - - part, err := writer.CreatePart(h) - if err != nil { - return nil, fmt.Errorf("create form part failed for image %d: %w", i, err) - } - - if _, err := io.Copy(part, file); err != nil { - return nil, fmt.Errorf("copy file failed for image %d: %w", i, err) - } - } - - if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 { - maskFile, err := maskFiles[0].Open() - if err != nil { - return nil, errors.New("failed to open mask file") - } - defer maskFile.Close() - - mimeType := detectImageMimeType(maskFiles[0].Filename) - - h := make(textproto.MIMEHeader) - h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename)) - h.Set("Content-Type", mimeType) - - maskPart, err := writer.CreatePart(h) - if err != nil { - return nil, errors.New("create form file failed for mask") - } - - if _, err := io.Copy(maskPart, maskFile); err != nil { - return nil, errors.New("copy mask file failed") - } - } - } else { - return nil, errors.New("no multipart form data found") - } - - writer.Close() - c.Request.Header.Set("Content-Type", writer.FormDataContentType()) - return bytes.NewReader(requestBody.Bytes()), nil + // 根据官方文档,并没有发现豆包生图支持表单请求:https://www.volcengine.com/docs/82379/1824121 + //case constant.RelayModeImagesEdits: + // + // var requestBody bytes.Buffer + // writer := multipart.NewWriter(&requestBody) + // + // writer.WriteField("model", request.Model) + // + // formData := c.Request.PostForm + // for key, values := range formData { + // if key == "model" { + // continue + // } + // for _, value := range values { + // writer.WriteField(key, value) + // } + // } + // + // if err := c.Request.ParseMultipartForm(32 << 20); err != nil { + // return nil, errors.New("failed to parse multipart form") + // } + // + // if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil { + // var imageFiles []*multipart.FileHeader + // var exists bool + // + // if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 { + // if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 { + // foundArrayImages := false + // for fieldName, files := range c.Request.MultipartForm.File { + // if strings.HasPrefix(fieldName, "image[") && len(files) > 0 { + // foundArrayImages = true + // for _, file := range files { + // imageFiles = append(imageFiles, file) + // } + // } + // } + // + // if !foundArrayImages && (len(imageFiles) == 0) { + // return nil, errors.New("image is required") + // } + // } + // } + // + // for i, fileHeader := range imageFiles { + // file, err := fileHeader.Open() + // if err != nil { + // return nil, fmt.Errorf("failed to open image file %d: %w", i, err) + // } + // defer file.Close() + // + // fieldName := "image" + // if len(imageFiles) > 1 { + // fieldName = "image[]" + // } + // + // mimeType := detectImageMimeType(fileHeader.Filename) + // + // h := make(textproto.MIMEHeader) + // h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename)) + // h.Set("Content-Type", mimeType) + // + // part, err := writer.CreatePart(h) + // if err != nil { + // return nil, fmt.Errorf("create form part failed for image %d: %w", i, err) + // } + // + // if _, err := io.Copy(part, file); err != nil { + // return nil, fmt.Errorf("copy file failed for image %d: %w", i, err) + // } + // } + // + // if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 { + // maskFile, err := maskFiles[0].Open() + // if err != nil { + // return nil, errors.New("failed to open mask file") + // } + // defer maskFile.Close() + // + // mimeType := detectImageMimeType(maskFiles[0].Filename) + // + // h := make(textproto.MIMEHeader) + // h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename)) + // h.Set("Content-Type", mimeType) + // + // maskPart, err := writer.CreatePart(h) + // if err != nil { + // return nil, errors.New("create form file failed for mask") + // } + // + // if _, err := io.Copy(maskPart, maskFile); err != nil { + // return nil, errors.New("copy mask file failed") + // } + // } + // } else { + // return nil, errors.New("no multipart form data found") + // } + // + // writer.Close() + // c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + // return bytes.NewReader(requestBody.Bytes()), nil default: return request, nil @@ -251,10 +250,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil case constant.RelayModeEmbeddings: return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil - case constant.RelayModeImagesGenerations: + //豆包的图生图也走generations接口: https://www.volcengine.com/docs/82379/1824121 + case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits: return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil - case constant.RelayModeImagesEdits: - return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil + //case constant.RelayModeImagesEdits: + // return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil case constant.RelayModeRerank: return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil case constant.RelayModeAudioSpeech: @@ -278,6 +278,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel } req.Set("Content-Type", "application/json") return nil + } else if info.RelayMode == constant.RelayModeImagesEdits { + req.Set("Content-Type", gin.MIMEJSON) } req.Set("Authorization", "Bearer "+info.ApiKey)