diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
index 064445489..adce01822 100644
--- a/relay/channel/ali/adaptor.go
+++ b/relay/channel/ali/adaptor.go
@@ -47,7 +47,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		case constant.RelayModeImagesGenerations:
 			fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
 		case constant.RelayModeImagesEdits:
-			fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
+			if isWanModel(info.OriginModelName) {
+				fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image2image/image-synthesis", info.ChannelBaseUrl)
+			} else {
+				fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
+			}
 		case constant.RelayModeCompletions:
 			fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl)
 		default:
@@ -71,6 +75,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
 		req.Set("X-DashScope-Async", "enable")
 	}
 	if info.RelayMode == constant.RelayModeImagesEdits {
+		if isWanModel(info.OriginModelName) {
+			req.Set("X-DashScope-Async", "enable")
+		}
 		req.Set("Content-Type", "application/json")
 	}
 	return nil
@@ -107,6 +114,9 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 		}
 		return aliRequest, nil
 	} else if info.RelayMode == constant.RelayModeImagesEdits {
+		if isWanModel(info.OriginModelName) {
+			return oaiFormEdit2WanxImageEdit(c, info, request)
+		}
 		// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
 		// 如果用户使用表单,则需要解析表单数据
 		if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
@@ -161,7 +171,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		case constant.RelayModeImagesGenerations:
 			err, usage = aliImageHandler(c, resp, info)
 		case constant.RelayModeImagesEdits:
-			err, usage = aliImageEditHandler(c, resp, info)
+			if isWanModel(info.OriginModelName) {
+				err, usage = aliImageHandler(c, resp, info)
+			} else {
+				err, usage = aliImageEditHandler(c, resp, info)
+			}
 		case constant.RelayModeRerank:
 			err, usage = RerankHandler(c, resp, info)
 		default:
diff --git a/relay/channel/ali/dto.go b/relay/channel/ali/dto.go
index b8072b601..26f14a6c0 100644
--- a/relay/channel/ali/dto.go
+++ b/relay/channel/ali/dto.go
@@ -112,6 +112,23 @@ type AliImageInput struct {
 	Messages       []AliMessage `json:"messages,omitempty"`
 }
 
+// WanImageInput is the payload assigned to AliImageRequest.Input for Wan image edits.
+type WanImageInput struct {
+	Prompt         string   `json:"prompt"`                    // required: text prompt describing the elements and visual characteristics expected in the generated image
+	Images         []string `json:"images"`                    // required: image URL array, at most 2 entries; HTTP/HTTPS URLs or Base64-encoded data
+	NegativePrompt string   `json:"negative_prompt,omitempty"` // optional: negative prompt describing content that should not appear in the image
+}
+
+// WanImageParameters is the payload assigned to AliImageRequest.Parameters for
+// Wan image edits. Seed and Strength are pointers because 0 lies inside both
+// documented ranges and omitempty would drop a plain zero value.
+type WanImageParameters struct {
+	N         int      `json:"n,omitempty"`         // number of images to generate, range 1-4, default 4
+	Watermark *bool    `json:"watermark,omitempty"` // whether to add a watermark, default false
+	Seed      *int     `json:"seed,omitempty"`      // random seed, range [0, 2147483647]
+	Strength  *float64 `json:"strength,omitempty"`  // edit strength 0.0-1.0, default 0.5 (supported by some models only)
+}
+
 type AliRerankParameters struct {
 	TopN            *int  `json:"top_n,omitempty"`
 	ReturnDocuments *bool `json:"return_documents,omitempty"`
diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go
index bda6cbe3e..0e3fe1ea0 100644
--- a/relay/channel/ali/image.go
+++ b/relay/channel/ali/image.go
@@ -58,11 +58,7 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
 	return &imageRequest, nil
 }
 
-func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
-	var imageRequest AliImageRequest
-	imageRequest.Model = request.Model
-	imageRequest.ResponseFormat = request.ResponseFormat
-
+func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
 	mf := c.Request.MultipartForm
 	if mf == nil {
 		if _, err := c.MultipartForm(); err != nil {
@@ -127,7 +123,18 @@ func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, reque
 		imageBase64s = append(imageBase64s, dataURL)
 		image.Close()
 	}
+	return imageBase64s, nil
+}
 
+func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
+	var imageRequest AliImageRequest
+	imageRequest.Model = request.Model
+	imageRequest.ResponseFormat = request.ResponseFormat
+
+	imageBase64s, err := getImageBase64sFromForm(c, "image")
+	if err != nil {
+		return nil, fmt.Errorf("get image base64s from form failed: %w", err)
+	}
 	//dto.MediaContent{}
 	mediaContents := make([]AliMediaContent, len(imageBase64s))
 	for i, b64 := range imageBase64s {
diff --git a/relay/channel/ali/image_wan.go b/relay/channel/ali/image_wan.go
new file mode 100644
index 000000000..4bd1a2701
--- /dev/null
+++ b/relay/channel/ali/image_wan.go
@@ -0,0 +1,54 @@
+package ali
+
+import (
+	"errors"
+	"fmt"
+	"strings"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/dto"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+
+	"github.com/gin-gonic/gin"
+)
+
+// oaiFormEdit2WanxImageEdit builds the AliImageRequest for a Wan image edit.
+//
+// The input starts with request.Prompt, is then passed to
+// common.UnmarshalBodyReusable together with the request context, and finally
+// receives the images read from the "image" form field. Only request.N is
+// forwarded as a parameter; info is not used.
+func oaiFormEdit2WanxImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
+	wanInput := WanImageInput{
+		Prompt: request.Prompt,
+	}
+	if err := common.UnmarshalBodyReusable(c, &wanInput); err != nil {
+		return nil, err
+	}
+
+	images, err := getImageBase64sFromForm(c, "image")
+	if err != nil {
+		return nil, fmt.Errorf("get image base64s from form failed: %w", err)
+	}
+	if len(images) == 0 {
+		// Images is a required input field, so fail before calling upstream.
+		return nil, errors.New("image is required")
+	}
+	wanInput.Images = images
+
+	var imageRequest AliImageRequest
+	imageRequest.Model = request.Model
+	imageRequest.ResponseFormat = request.ResponseFormat
+	imageRequest.Input = wanInput
+	imageRequest.Parameters = WanImageParameters{
+		N: int(request.N),
+	}
+	return &imageRequest, nil
+}
+
+// isWanModel reports whether modelName contains the substring "wan".
+// The test is case-sensitive and matches anywhere in the name, so every
+// model whose name merely contains "wan" takes the Wan image edit path.
+func isWanModel(modelName string) bool {
+	return strings.Contains(modelName, "wan")
+}