mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-18 02:17:28 +00:00
feat: support qwen-image-edit
This commit is contained in:
117
dto/dalle.go
117
dto/dalle.go
@@ -1,117 +0,0 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
|
||||
Background string `json:"background,omitempty"`
|
||||
Moderation string `json:"moderation,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
// 用匿名字段接住额外的字段
|
||||
Extra map[string]json.RawMessage `json:"-"`
|
||||
}
|
||||
|
||||
func (r *ImageRequest) UnmarshalJSON(data []byte) error {
|
||||
// 先解析成 map[string]interface{}
|
||||
var rawMap map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &rawMap); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 用 struct tag 获取所有已定义字段名
|
||||
knownFields := GetJSONFieldNames(reflect.TypeOf(*r))
|
||||
|
||||
// 再正常解析已定义字段
|
||||
type Alias ImageRequest
|
||||
var known Alias
|
||||
if err := json.Unmarshal(data, &known); err != nil {
|
||||
return err
|
||||
}
|
||||
*r = ImageRequest(known)
|
||||
|
||||
// 提取多余字段
|
||||
r.Extra = make(map[string]json.RawMessage)
|
||||
for k, v := range rawMap {
|
||||
if _, ok := knownFields[k]; !ok {
|
||||
r.Extra[k] = v
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r ImageRequest) MarshalJSON() ([]byte, error) {
|
||||
// 将已定义字段转为 map
|
||||
type Alias ImageRequest
|
||||
alias := Alias(r)
|
||||
base, err := json.Marshal(alias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var baseMap map[string]json.RawMessage
|
||||
if err := json.Unmarshal(base, &baseMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 合并 ExtraFields
|
||||
for k, v := range r.Extra {
|
||||
baseMap[k] = v
|
||||
}
|
||||
|
||||
return json.Marshal(baseMap)
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
Data []ImageData `json:"data"`
|
||||
Created int64 `json:"created"`
|
||||
}
|
||||
type ImageData struct {
|
||||
Url string `json:"url"`
|
||||
B64Json string `json:"b64_json"`
|
||||
RevisedPrompt string `json:"revised_prompt"`
|
||||
}
|
||||
|
||||
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
|
||||
fields := make(map[string]struct{})
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
|
||||
// 跳过匿名字段(例如 ExtraFields)
|
||||
if field.Anonymous {
|
||||
continue
|
||||
}
|
||||
|
||||
tag := field.Tag.Get("json")
|
||||
if tag == "-" || tag == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 取逗号前字段名(排除 omitempty 等)
|
||||
name := tag
|
||||
if commaIdx := indexComma(tag); commaIdx != -1 {
|
||||
name = tag[:commaIdx]
|
||||
}
|
||||
fields[name] = struct{}{}
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
func indexComma(s string) int {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == ',' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
@@ -2,7 +2,9 @@ package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -29,6 +31,68 @@ type ImageRequest struct {
|
||||
Extra map[string]json.RawMessage `json:"-"`
|
||||
}
|
||||
|
||||
func (i *ImageRequest) UnmarshalJSON(data []byte) error {
|
||||
// 先解析成 map[string]interface{}
|
||||
var rawMap map[string]json.RawMessage
|
||||
if err := common.Unmarshal(data, &rawMap); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 用 struct tag 获取所有已定义字段名
|
||||
knownFields := GetJSONFieldNames(reflect.TypeOf(*i))
|
||||
|
||||
// 再正常解析已定义字段
|
||||
type Alias ImageRequest
|
||||
var known Alias
|
||||
if err := common.Unmarshal(data, &known); err != nil {
|
||||
return err
|
||||
}
|
||||
*i = ImageRequest(known)
|
||||
|
||||
// 提取多余字段
|
||||
i.Extra = make(map[string]json.RawMessage)
|
||||
for k, v := range rawMap {
|
||||
if _, ok := knownFields[k]; !ok {
|
||||
i.Extra[k] = v
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
|
||||
fields := make(map[string]struct{})
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
|
||||
// 跳过匿名字段(例如 ExtraFields)
|
||||
if field.Anonymous {
|
||||
continue
|
||||
}
|
||||
|
||||
tag := field.Tag.Get("json")
|
||||
if tag == "-" || tag == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 取逗号前字段名(排除 omitempty 等)
|
||||
name := tag
|
||||
if commaIdx := indexComma(tag); commaIdx != -1 {
|
||||
name = tag[:commaIdx]
|
||||
}
|
||||
fields[name] = struct{}{}
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
func indexComma(s string) int {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == ',' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
var sizeRatio = 1.0
|
||||
var qualityRatio = 1.0
|
||||
|
||||
@@ -67,8 +67,8 @@ type GeneralOpenAIRequest struct {
|
||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||
// Ali Qwen Params
|
||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||
// 用匿名参数接收额外参数,例如ollama的think参数在此接收
|
||||
Extra map[string]json.RawMessage `json:"-"`
|
||||
// ollama Params
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
|
||||
@@ -185,7 +185,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
modelRequest.Model = modelName
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
if err != nil {
|
||||
@@ -208,7 +208,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
||||
modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
|
||||
//modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
|
||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
modelRequest.Model = c.PostForm("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||
relayMode := relayconstant.RelayModeAudioSpeech
|
||||
|
||||
@@ -3,7 +3,6 @@ package ali
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
@@ -14,6 +13,8 @@ import (
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
@@ -44,6 +45,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
||||
case constant.RelayModeImagesEdits:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
||||
case constant.RelayModeCompletions:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl)
|
||||
default:
|
||||
@@ -66,6 +69,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
req.Set("X-DashScope-Async", "enable")
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
req.Set("Content-Type", "application/json")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -93,11 +99,30 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
aliRequest, err := oaiImage2Ali(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
aliRequest, err := oaiImage2Ali(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
} else if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
|
||||
// 如果用户使用表单,则需要解析表单数据
|
||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
aliRequest, err := oaiFormEdit2AliImageEdit(c, info, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image edit form request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
} else {
|
||||
aliRequest, err := oaiImage2Ali(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
}
|
||||
}
|
||||
return aliRequest, nil
|
||||
return nil, fmt.Errorf("unsupported image relay mode: %d", info.RelayMode)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
@@ -134,6 +159,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
case constant.RelayModeImagesEdits:
|
||||
err, usage = aliImageEditHandler(c, resp, info)
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = RerankHandler(c, resp, info)
|
||||
default:
|
||||
|
||||
@@ -3,10 +3,15 @@ package ali
|
||||
import "one-api/dto"
|
||||
|
||||
type AliMessage struct {
|
||||
Content string `json:"content"`
|
||||
Content any `json:"content"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type AliMediaContent struct {
|
||||
Image string `json:"image,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type AliInput struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
//History []AliMessage `json:"history,omitempty"`
|
||||
@@ -70,13 +75,14 @@ type TaskResult struct {
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
TaskId string `json:"task_id,omitempty"`
|
||||
TaskStatus string `json:"task_status,omitempty"`
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Results []TaskResult `json:"results,omitempty"`
|
||||
TaskId string `json:"task_id,omitempty"`
|
||||
TaskStatus string `json:"task_status,omitempty"`
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Results []TaskResult `json:"results,omitempty"`
|
||||
Choices []map[string]any `json:"choices,omitempty"`
|
||||
}
|
||||
|
||||
type AliResponse struct {
|
||||
@@ -101,8 +107,9 @@ type AliImageParameters struct {
|
||||
}
|
||||
|
||||
type AliImageInput struct {
|
||||
Prompt string `json:"prompt"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||
Messages []AliMessage `json:"messages,omitempty"`
|
||||
}
|
||||
|
||||
type AliRerankParameters struct {
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
@@ -21,7 +24,7 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
var imageRequest AliImageRequest
|
||||
imageRequest.Model = request.Model
|
||||
imageRequest.ResponseFormat = request.ResponseFormat
|
||||
|
||||
logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
|
||||
if request.Extra != nil {
|
||||
if val, ok := request.Extra["parameters"]; ok {
|
||||
err := common.Unmarshal(val, &imageRequest.Parameters)
|
||||
@@ -54,6 +57,100 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
return &imageRequest, nil
|
||||
}
|
||||
|
||||
func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
var imageRequest AliImageRequest
|
||||
imageRequest.Model = request.Model
|
||||
imageRequest.ResponseFormat = request.ResponseFormat
|
||||
|
||||
mf := c.Request.MultipartForm
|
||||
if mf == nil {
|
||||
if _, err := c.MultipartForm(); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
|
||||
}
|
||||
mf = c.Request.MultipartForm
|
||||
}
|
||||
|
||||
var imageFiles []*multipart.FileHeader
|
||||
var exists bool
|
||||
|
||||
// First check for standard "image" field
|
||||
if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
|
||||
// If not found, check for "image[]" field
|
||||
if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
|
||||
// If still not found, iterate through all fields to find any that start with "image["
|
||||
foundArrayImages := false
|
||||
for fieldName, files := range mf.File {
|
||||
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
||||
foundArrayImages = true
|
||||
imageFiles = append(imageFiles, files...)
|
||||
}
|
||||
}
|
||||
|
||||
// If no image fields found at all
|
||||
if !foundArrayImages && (len(imageFiles) == 0) {
|
||||
return nil, errors.New("image is required")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(imageFiles) == 0 {
|
||||
return nil, errors.New("image is required")
|
||||
}
|
||||
|
||||
if len(imageFiles) > 1 {
|
||||
return nil, errors.New("only one image is supported for qwen edit")
|
||||
}
|
||||
|
||||
// 获取base64编码的图片
|
||||
var imageBase64s []string
|
||||
for _, file := range imageFiles {
|
||||
image, err := file.Open()
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to open image file")
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
imageData, err := io.ReadAll(image)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to read image file")
|
||||
}
|
||||
|
||||
// 获取MIME类型
|
||||
mimeType := http.DetectContentType(imageData)
|
||||
|
||||
// 编码为base64
|
||||
base64Data := base64.StdEncoding.EncodeToString(imageData)
|
||||
|
||||
// 构造data URL格式
|
||||
dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
|
||||
imageBase64s = append(imageBase64s, dataURL)
|
||||
image.Close()
|
||||
}
|
||||
|
||||
//dto.MediaContent{}
|
||||
mediaContents := make([]AliMediaContent, len(imageBase64s))
|
||||
for i, b64 := range imageBase64s {
|
||||
mediaContents[i] = AliMediaContent{
|
||||
Image: b64,
|
||||
}
|
||||
}
|
||||
mediaContents = append(mediaContents, AliMediaContent{
|
||||
Text: request.Prompt,
|
||||
})
|
||||
imageRequest.Input = AliImageInput{
|
||||
Messages: []AliMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: mediaContents,
|
||||
},
|
||||
},
|
||||
}
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
Watermark: request.Watermark,
|
||||
}
|
||||
return &imageRequest, nil
|
||||
}
|
||||
|
||||
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
|
||||
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
|
||||
|
||||
@@ -196,8 +293,47 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
c.Writer.Write(jsonResponse)
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
var aliResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = common.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Message != "" {
|
||||
logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
|
||||
return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
var fullTextResponse dto.ImageResponse
|
||||
if len(aliResponse.Output.Choices) > 0 {
|
||||
fullTextResponse = dto.ImageResponse{
|
||||
Created: info.StartTime.Unix(),
|
||||
Data: []dto.ImageData{
|
||||
{
|
||||
Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
|
||||
B64Json: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var mapResponse map[string]any
|
||||
_ = common.Unmarshal(responseBody, &mapResponse)
|
||||
fullTextResponse.Extra = mapResponse
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
@@ -68,9 +68,7 @@ func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*O
|
||||
StreamOptions: request.StreamOptions,
|
||||
Suffix: request.Suffix,
|
||||
}
|
||||
if think, ok := request.Extra["think"]; ok {
|
||||
ollamaRequest.Think = think
|
||||
}
|
||||
ollamaRequest.Think = request.Think
|
||||
return ollamaRequest, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -132,30 +132,34 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
|
||||
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeImagesEdits:
|
||||
_, err := c.MultipartForm()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
|
||||
}
|
||||
formData := c.Request.PostForm
|
||||
imageRequest.Prompt = formData.Get("prompt")
|
||||
imageRequest.Model = formData.Get("model")
|
||||
imageRequest.N = uint(common.String2Int(formData.Get("n")))
|
||||
imageRequest.Quality = formData.Get("quality")
|
||||
imageRequest.Size = formData.Get("size")
|
||||
|
||||
if imageRequest.Model == "gpt-image-1" {
|
||||
if imageRequest.Quality == "" {
|
||||
imageRequest.Quality = "standard"
|
||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
_, err := c.MultipartForm()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
|
||||
}
|
||||
}
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
formData := c.Request.PostForm
|
||||
imageRequest.Prompt = formData.Get("prompt")
|
||||
imageRequest.Model = formData.Get("model")
|
||||
imageRequest.N = uint(common.String2Int(formData.Get("n")))
|
||||
imageRequest.Quality = formData.Get("quality")
|
||||
imageRequest.Size = formData.Get("size")
|
||||
|
||||
watermark := formData.Has("watermark")
|
||||
if watermark {
|
||||
imageRequest.Watermark = &watermark
|
||||
if imageRequest.Model == "gpt-image-1" {
|
||||
if imageRequest.Quality == "" {
|
||||
imageRequest.Quality = "standard"
|
||||
}
|
||||
}
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
|
||||
watermark := formData.Has("watermark")
|
||||
if watermark {
|
||||
imageRequest.Watermark = &watermark
|
||||
}
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
err := common.UnmarshalBodyReusable(c, imageRequest)
|
||||
if err != nil {
|
||||
@@ -163,7 +167,8 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
|
||||
}
|
||||
|
||||
if imageRequest.Model == "" {
|
||||
imageRequest.Model = "dall-e-3"
|
||||
//imageRequest.Model = "dall-e-3"
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
|
||||
if strings.Contains(imageRequest.Size, "×") {
|
||||
@@ -194,9 +199,9 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
|
||||
}
|
||||
}
|
||||
|
||||
if imageRequest.Prompt == "" {
|
||||
return nil, errors.New("prompt is required")
|
||||
}
|
||||
//if imageRequest.Prompt == "" {
|
||||
// return nil, errors.New("prompt is required")
|
||||
//}
|
||||
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
|
||||
@@ -2,14 +2,13 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
@@ -56,10 +55,12 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
if info.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||
|
||||
switch convertedRequest.(type) {
|
||||
case *bytes.Buffer:
|
||||
requestBody = convertedRequest.(io.Reader)
|
||||
} else {
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
default:
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -73,7 +74,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("image request body: %s", string(jsonData)))
|
||||
logger.LogDebug(c, fmt.Sprintf("image request body: %s", string(jsonData)))
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user