feat: doubao-seedream support image edit

This commit is contained in:
feitianbubu
2025-10-23 21:18:11 +08:00
parent fcf0f952b1
commit 3ac9ff6028
3 changed files with 170 additions and 106 deletions

View File

@@ -2,9 +2,11 @@ package common
import (
"bytes"
"encoding/json"
"io"
"mime/multipart"
"net/http"
"net/url"
"strings"
"time"
@@ -40,6 +42,10 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = Unmarshal(requestBody, &v)
} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
err = parseFormData(requestBody, &v)
} else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) {
err = parseMultipartFormData(c, requestBody, &v)
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
@@ -138,3 +144,57 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return form, nil
}
func parseFormData(data []byte, v any) error {
values, err := url.ParseQuery(string(data))
if err != nil {
return err
}
formMap := make(map[string]any)
for key, vals := range values {
if len(vals) == 1 {
formMap[key] = vals[0]
} else {
formMap[key] = vals
}
}
jsonData, err := json.Marshal(formMap)
if err != nil {
return err
}
return json.Unmarshal(jsonData, v)
}
func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
contentType := c.Request.Header.Get("Content-Type")
boundary := ""
if idx := strings.Index(contentType, "boundary="); idx != -1 {
boundary = contentType[idx+9:]
}
if boundary == "" {
return json.Unmarshal(data, v) // Fallback to JSON
}
reader := multipart.NewReader(bytes.NewReader(data), boundary)
form, err := reader.ReadForm(32 << 20) // 32 MB max memory
if err != nil {
return err
}
defer form.RemoveAll()
formMap := make(map[string]any)
for key, vals := range form.Value {
if len(vals) == 1 {
formMap[key] = vals[0]
} else {
formMap[key] = vals
}
}
jsonData, err := json.Marshal(formMap)
if err != nil {
return err
}
return json.Unmarshal(jsonData, v)
}

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net/http"
"slices"
"strconv"
"strings"
"time"
@@ -245,7 +246,8 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
//modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
contentType := c.Request.Header.Get("Content-Type")
if slices.Contains([]string{gin.MIMEPOSTForm, gin.MIMEMultipartPOSTForm}, contentType) {
modelRequest.Model = c.PostForm("model")
}
}

View File

@@ -6,9 +6,7 @@ import (
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"path/filepath"
"strings"
@@ -104,106 +102,107 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
return request, nil
case constant.RelayModeImagesEdits:
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
writer.WriteField("model", request.Model)
formData := c.Request.PostForm
for key, values := range formData {
if key == "model" {
continue
}
for _, value := range values {
writer.WriteField(key, value)
}
}
if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
return nil, errors.New("failed to parse multipart form")
}
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
var imageFiles []*multipart.FileHeader
var exists bool
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
foundArrayImages := false
for fieldName, files := range c.Request.MultipartForm.File {
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
foundArrayImages = true
for _, file := range files {
imageFiles = append(imageFiles, file)
}
}
}
if !foundArrayImages && (len(imageFiles) == 0) {
return nil, errors.New("image is required")
}
}
}
for i, fileHeader := range imageFiles {
file, err := fileHeader.Open()
if err != nil {
return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
}
defer file.Close()
fieldName := "image"
if len(imageFiles) > 1 {
fieldName = "image[]"
}
mimeType := detectImageMimeType(fileHeader.Filename)
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
h.Set("Content-Type", mimeType)
part, err := writer.CreatePart(h)
if err != nil {
return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
}
if _, err := io.Copy(part, file); err != nil {
return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
}
}
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
maskFile, err := maskFiles[0].Open()
if err != nil {
return nil, errors.New("failed to open mask file")
}
defer maskFile.Close()
mimeType := detectImageMimeType(maskFiles[0].Filename)
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
h.Set("Content-Type", mimeType)
maskPart, err := writer.CreatePart(h)
if err != nil {
return nil, errors.New("create form file failed for mask")
}
if _, err := io.Copy(maskPart, maskFile); err != nil {
return nil, errors.New("copy mask file failed")
}
}
} else {
return nil, errors.New("no multipart form data found")
}
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return bytes.NewReader(requestBody.Bytes()), nil
// 根据官方文档,并没有发现豆包生图支持表单请求:https://www.volcengine.com/docs/82379/1824121
//case constant.RelayModeImagesEdits:
//
// var requestBody bytes.Buffer
// writer := multipart.NewWriter(&requestBody)
//
// writer.WriteField("model", request.Model)
//
// formData := c.Request.PostForm
// for key, values := range formData {
// if key == "model" {
// continue
// }
// for _, value := range values {
// writer.WriteField(key, value)
// }
// }
//
// if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
// return nil, errors.New("failed to parse multipart form")
// }
//
// if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
// var imageFiles []*multipart.FileHeader
// var exists bool
//
// if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
// if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
// foundArrayImages := false
// for fieldName, files := range c.Request.MultipartForm.File {
// if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
// foundArrayImages = true
// for _, file := range files {
// imageFiles = append(imageFiles, file)
// }
// }
// }
//
// if !foundArrayImages && (len(imageFiles) == 0) {
// return nil, errors.New("image is required")
// }
// }
// }
//
// for i, fileHeader := range imageFiles {
// file, err := fileHeader.Open()
// if err != nil {
// return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
// }
// defer file.Close()
//
// fieldName := "image"
// if len(imageFiles) > 1 {
// fieldName = "image[]"
// }
//
// mimeType := detectImageMimeType(fileHeader.Filename)
//
// h := make(textproto.MIMEHeader)
// h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
// h.Set("Content-Type", mimeType)
//
// part, err := writer.CreatePart(h)
// if err != nil {
// return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
// }
//
// if _, err := io.Copy(part, file); err != nil {
// return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
// }
// }
//
// if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
// maskFile, err := maskFiles[0].Open()
// if err != nil {
// return nil, errors.New("failed to open mask file")
// }
// defer maskFile.Close()
//
// mimeType := detectImageMimeType(maskFiles[0].Filename)
//
// h := make(textproto.MIMEHeader)
// h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
// h.Set("Content-Type", mimeType)
//
// maskPart, err := writer.CreatePart(h)
// if err != nil {
// return nil, errors.New("create form file failed for mask")
// }
//
// if _, err := io.Copy(maskPart, maskFile); err != nil {
// return nil, errors.New("copy mask file failed")
// }
// }
// } else {
// return nil, errors.New("no multipart form data found")
// }
//
// writer.Close()
// c.Request.Header.Set("Content-Type", writer.FormDataContentType())
// return bytes.NewReader(requestBody.Bytes()), nil
default:
return request, nil
@@ -251,10 +250,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
case constant.RelayModeImagesGenerations:
//豆包的图生图也走generations接口: https://www.volcengine.com/docs/82379/1824121
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
case constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
//case constant.RelayModeImagesEdits:
// return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
case constant.RelayModeRerank:
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
case constant.RelayModeAudioSpeech:
@@ -278,6 +278,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
}
req.Set("Content-Type", "application/json")
return nil
} else if info.RelayMode == constant.RelayModeImagesEdits {
req.Set("Content-Type", gin.MIMEJSON)
}
req.Set("Authorization", "Bearer "+info.ApiKey)