mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-19 07:17:26 +00:00
feat: support qwen-image-edit
This commit is contained in:
117
dto/dalle.go
117
dto/dalle.go
@@ -1,117 +0,0 @@
|
|||||||
package dto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ImageRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Prompt string `json:"prompt" binding:"required"`
|
|
||||||
N int `json:"n,omitempty"`
|
|
||||||
Size string `json:"size,omitempty"`
|
|
||||||
Quality string `json:"quality,omitempty"`
|
|
||||||
ResponseFormat string `json:"response_format,omitempty"`
|
|
||||||
Style string `json:"style,omitempty"`
|
|
||||||
User string `json:"user,omitempty"`
|
|
||||||
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
|
|
||||||
Background string `json:"background,omitempty"`
|
|
||||||
Moderation string `json:"moderation,omitempty"`
|
|
||||||
OutputFormat string `json:"output_format,omitempty"`
|
|
||||||
// 用匿名字段接住额外的字段
|
|
||||||
Extra map[string]json.RawMessage `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ImageRequest) UnmarshalJSON(data []byte) error {
|
|
||||||
// 先解析成 map[string]interface{}
|
|
||||||
var rawMap map[string]json.RawMessage
|
|
||||||
if err := json.Unmarshal(data, &rawMap); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 用 struct tag 获取所有已定义字段名
|
|
||||||
knownFields := GetJSONFieldNames(reflect.TypeOf(*r))
|
|
||||||
|
|
||||||
// 再正常解析已定义字段
|
|
||||||
type Alias ImageRequest
|
|
||||||
var known Alias
|
|
||||||
if err := json.Unmarshal(data, &known); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
*r = ImageRequest(known)
|
|
||||||
|
|
||||||
// 提取多余字段
|
|
||||||
r.Extra = make(map[string]json.RawMessage)
|
|
||||||
for k, v := range rawMap {
|
|
||||||
if _, ok := knownFields[k]; !ok {
|
|
||||||
r.Extra[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r ImageRequest) MarshalJSON() ([]byte, error) {
|
|
||||||
// 将已定义字段转为 map
|
|
||||||
type Alias ImageRequest
|
|
||||||
alias := Alias(r)
|
|
||||||
base, err := json.Marshal(alias)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var baseMap map[string]json.RawMessage
|
|
||||||
if err := json.Unmarshal(base, &baseMap); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 合并 ExtraFields
|
|
||||||
for k, v := range r.Extra {
|
|
||||||
baseMap[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
return json.Marshal(baseMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ImageResponse struct {
|
|
||||||
Data []ImageData `json:"data"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
}
|
|
||||||
type ImageData struct {
|
|
||||||
Url string `json:"url"`
|
|
||||||
B64Json string `json:"b64_json"`
|
|
||||||
RevisedPrompt string `json:"revised_prompt"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
|
|
||||||
fields := make(map[string]struct{})
|
|
||||||
for i := 0; i < t.NumField(); i++ {
|
|
||||||
field := t.Field(i)
|
|
||||||
|
|
||||||
// 跳过匿名字段(例如 ExtraFields)
|
|
||||||
if field.Anonymous {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
tag := field.Tag.Get("json")
|
|
||||||
if tag == "-" || tag == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 取逗号前字段名(排除 omitempty 等)
|
|
||||||
name := tag
|
|
||||||
if commaIdx := indexComma(tag); commaIdx != -1 {
|
|
||||||
name = tag[:commaIdx]
|
|
||||||
}
|
|
||||||
fields[name] = struct{}{}
|
|
||||||
}
|
|
||||||
return fields
|
|
||||||
}
|
|
||||||
|
|
||||||
func indexComma(s string) int {
|
|
||||||
for i := 0; i < len(s); i++ {
|
|
||||||
if s[i] == ',' {
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
@@ -2,7 +2,9 @@ package dto
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"one-api/common"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -29,6 +31,68 @@ type ImageRequest struct {
|
|||||||
Extra map[string]json.RawMessage `json:"-"`
|
Extra map[string]json.RawMessage `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *ImageRequest) UnmarshalJSON(data []byte) error {
|
||||||
|
// 先解析成 map[string]interface{}
|
||||||
|
var rawMap map[string]json.RawMessage
|
||||||
|
if err := common.Unmarshal(data, &rawMap); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用 struct tag 获取所有已定义字段名
|
||||||
|
knownFields := GetJSONFieldNames(reflect.TypeOf(*i))
|
||||||
|
|
||||||
|
// 再正常解析已定义字段
|
||||||
|
type Alias ImageRequest
|
||||||
|
var known Alias
|
||||||
|
if err := common.Unmarshal(data, &known); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*i = ImageRequest(known)
|
||||||
|
|
||||||
|
// 提取多余字段
|
||||||
|
i.Extra = make(map[string]json.RawMessage)
|
||||||
|
for k, v := range rawMap {
|
||||||
|
if _, ok := knownFields[k]; !ok {
|
||||||
|
i.Extra[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
|
||||||
|
fields := make(map[string]struct{})
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
field := t.Field(i)
|
||||||
|
|
||||||
|
// 跳过匿名字段(例如 ExtraFields)
|
||||||
|
if field.Anonymous {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tag := field.Tag.Get("json")
|
||||||
|
if tag == "-" || tag == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 取逗号前字段名(排除 omitempty 等)
|
||||||
|
name := tag
|
||||||
|
if commaIdx := indexComma(tag); commaIdx != -1 {
|
||||||
|
name = tag[:commaIdx]
|
||||||
|
}
|
||||||
|
fields[name] = struct{}{}
|
||||||
|
}
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexComma(s string) int {
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
if s[i] == ',' {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
var sizeRatio = 1.0
|
var sizeRatio = 1.0
|
||||||
var qualityRatio = 1.0
|
var qualityRatio = 1.0
|
||||||
|
|||||||
@@ -67,8 +67,8 @@ type GeneralOpenAIRequest struct {
|
|||||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||||
// Ali Qwen Params
|
// Ali Qwen Params
|
||||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||||
// 用匿名参数接收额外参数,例如ollama的think参数在此接收
|
// ollama Params
|
||||||
Extra map[string]json.RawMessage `json:"-"`
|
Think json.RawMessage `json:"think,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
modelRequest.Model = modelName
|
modelRequest.Model = modelName
|
||||||
}
|
}
|
||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -208,7 +208,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||||
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
|
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
||||||
modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
|
//modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
|
||||||
|
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||||
|
modelRequest.Model = c.PostForm("model")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||||
relayMode := relayconstant.RelayModeAudioSpeech
|
relayMode := relayconstant.RelayModeAudioSpeech
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package ali
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
@@ -14,6 +13,8 @@ import (
|
|||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
@@ -44,6 +45,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
|
||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
||||||
|
case constant.RelayModeImagesEdits:
|
||||||
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
||||||
case constant.RelayModeCompletions:
|
case constant.RelayModeCompletions:
|
||||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl)
|
||||||
default:
|
default:
|
||||||
@@ -66,6 +69,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||||
req.Set("X-DashScope-Async", "enable")
|
req.Set("X-DashScope-Async", "enable")
|
||||||
}
|
}
|
||||||
|
if info.RelayMode == constant.RelayModeImagesEdits {
|
||||||
|
req.Set("Content-Type", "application/json")
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,11 +99,30 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||||
aliRequest, err := oaiImage2Ali(request)
|
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||||
if err != nil {
|
aliRequest, err := oaiImage2Ali(request)
|
||||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||||
|
}
|
||||||
|
return aliRequest, nil
|
||||||
|
} else if info.RelayMode == constant.RelayModeImagesEdits {
|
||||||
|
// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
|
||||||
|
// 如果用户使用表单,则需要解析表单数据
|
||||||
|
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||||
|
aliRequest, err := oaiFormEdit2AliImageEdit(c, info, request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert image edit form request failed: %w", err)
|
||||||
|
}
|
||||||
|
return aliRequest, nil
|
||||||
|
} else {
|
||||||
|
aliRequest, err := oaiImage2Ali(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||||
|
}
|
||||||
|
return aliRequest, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return aliRequest, nil
|
return nil, fmt.Errorf("unsupported image relay mode: %d", info.RelayMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||||
@@ -134,6 +159,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
err, usage = aliImageHandler(c, resp, info)
|
err, usage = aliImageHandler(c, resp, info)
|
||||||
|
case constant.RelayModeImagesEdits:
|
||||||
|
err, usage = aliImageEditHandler(c, resp, info)
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
err, usage = RerankHandler(c, resp, info)
|
err, usage = RerankHandler(c, resp, info)
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -3,10 +3,15 @@ package ali
|
|||||||
import "one-api/dto"
|
import "one-api/dto"
|
||||||
|
|
||||||
type AliMessage struct {
|
type AliMessage struct {
|
||||||
Content string `json:"content"`
|
Content any `json:"content"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AliMediaContent struct {
|
||||||
|
Image string `json:"image,omitempty"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type AliInput struct {
|
type AliInput struct {
|
||||||
Prompt string `json:"prompt,omitempty"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
//History []AliMessage `json:"history,omitempty"`
|
//History []AliMessage `json:"history,omitempty"`
|
||||||
@@ -70,13 +75,14 @@ type TaskResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AliOutput struct {
|
type AliOutput struct {
|
||||||
TaskId string `json:"task_id,omitempty"`
|
TaskId string `json:"task_id,omitempty"`
|
||||||
TaskStatus string `json:"task_status,omitempty"`
|
TaskStatus string `json:"task_status,omitempty"`
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
FinishReason string `json:"finish_reason"`
|
FinishReason string `json:"finish_reason"`
|
||||||
Message string `json:"message,omitempty"`
|
Message string `json:"message,omitempty"`
|
||||||
Code string `json:"code,omitempty"`
|
Code string `json:"code,omitempty"`
|
||||||
Results []TaskResult `json:"results,omitempty"`
|
Results []TaskResult `json:"results,omitempty"`
|
||||||
|
Choices []map[string]any `json:"choices,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AliResponse struct {
|
type AliResponse struct {
|
||||||
@@ -101,8 +107,9 @@ type AliImageParameters struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AliImageInput struct {
|
type AliImageInput struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||||
|
Messages []AliMessage `json:"messages,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AliRerankParameters struct {
|
type AliRerankParameters struct {
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
package ali
|
package ali
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
@@ -21,7 +24,7 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
|||||||
var imageRequest AliImageRequest
|
var imageRequest AliImageRequest
|
||||||
imageRequest.Model = request.Model
|
imageRequest.Model = request.Model
|
||||||
imageRequest.ResponseFormat = request.ResponseFormat
|
imageRequest.ResponseFormat = request.ResponseFormat
|
||||||
|
logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
|
||||||
if request.Extra != nil {
|
if request.Extra != nil {
|
||||||
if val, ok := request.Extra["parameters"]; ok {
|
if val, ok := request.Extra["parameters"]; ok {
|
||||||
err := common.Unmarshal(val, &imageRequest.Parameters)
|
err := common.Unmarshal(val, &imageRequest.Parameters)
|
||||||
@@ -54,6 +57,100 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
|||||||
return &imageRequest, nil
|
return &imageRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
|
||||||
|
var imageRequest AliImageRequest
|
||||||
|
imageRequest.Model = request.Model
|
||||||
|
imageRequest.ResponseFormat = request.ResponseFormat
|
||||||
|
|
||||||
|
mf := c.Request.MultipartForm
|
||||||
|
if mf == nil {
|
||||||
|
if _, err := c.MultipartForm(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
|
||||||
|
}
|
||||||
|
mf = c.Request.MultipartForm
|
||||||
|
}
|
||||||
|
|
||||||
|
var imageFiles []*multipart.FileHeader
|
||||||
|
var exists bool
|
||||||
|
|
||||||
|
// First check for standard "image" field
|
||||||
|
if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
|
||||||
|
// If not found, check for "image[]" field
|
||||||
|
if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
|
||||||
|
// If still not found, iterate through all fields to find any that start with "image["
|
||||||
|
foundArrayImages := false
|
||||||
|
for fieldName, files := range mf.File {
|
||||||
|
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
||||||
|
foundArrayImages = true
|
||||||
|
imageFiles = append(imageFiles, files...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no image fields found at all
|
||||||
|
if !foundArrayImages && (len(imageFiles) == 0) {
|
||||||
|
return nil, errors.New("image is required")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(imageFiles) == 0 {
|
||||||
|
return nil, errors.New("image is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(imageFiles) > 1 {
|
||||||
|
return nil, errors.New("only one image is supported for qwen edit")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取base64编码的图片
|
||||||
|
var imageBase64s []string
|
||||||
|
for _, file := range imageFiles {
|
||||||
|
image, err := file.Open()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("failed to open image file")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 读取文件内容
|
||||||
|
imageData, err := io.ReadAll(image)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("failed to read image file")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取MIME类型
|
||||||
|
mimeType := http.DetectContentType(imageData)
|
||||||
|
|
||||||
|
// 编码为base64
|
||||||
|
base64Data := base64.StdEncoding.EncodeToString(imageData)
|
||||||
|
|
||||||
|
// 构造data URL格式
|
||||||
|
dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
|
||||||
|
imageBase64s = append(imageBase64s, dataURL)
|
||||||
|
image.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
//dto.MediaContent{}
|
||||||
|
mediaContents := make([]AliMediaContent, len(imageBase64s))
|
||||||
|
for i, b64 := range imageBase64s {
|
||||||
|
mediaContents[i] = AliMediaContent{
|
||||||
|
Image: b64,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mediaContents = append(mediaContents, AliMediaContent{
|
||||||
|
Text: request.Prompt,
|
||||||
|
})
|
||||||
|
imageRequest.Input = AliImageInput{
|
||||||
|
Messages: []AliMessage{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: mediaContents,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
imageRequest.Parameters = AliImageParameters{
|
||||||
|
Watermark: request.Watermark,
|
||||||
|
}
|
||||||
|
return &imageRequest, nil
|
||||||
|
}
|
||||||
|
|
||||||
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
|
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
|
||||||
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
|
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
|
||||||
|
|
||||||
@@ -196,8 +293,47 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
return nil, &dto.Usage{}
|
||||||
c.Writer.Write(jsonResponse)
|
}
|
||||||
|
|
||||||
|
func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||||
|
var aliResponse AliResponse
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
service.CloseResponseBodyGracefully(resp)
|
||||||
|
err = common.Unmarshal(responseBody, &aliResponse)
|
||||||
|
if err != nil {
|
||||||
|
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if aliResponse.Message != "" {
|
||||||
|
logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
|
||||||
|
return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
|
||||||
|
}
|
||||||
|
var fullTextResponse dto.ImageResponse
|
||||||
|
if len(aliResponse.Output.Choices) > 0 {
|
||||||
|
fullTextResponse = dto.ImageResponse{
|
||||||
|
Created: info.StartTime.Unix(),
|
||||||
|
Data: []dto.ImageData{
|
||||||
|
{
|
||||||
|
Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
|
||||||
|
B64Json: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var mapResponse map[string]any
|
||||||
|
_ = common.Unmarshal(responseBody, &mapResponse)
|
||||||
|
fullTextResponse.Extra = mapResponse
|
||||||
|
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
|
}
|
||||||
|
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||||
return nil, &dto.Usage{}
|
return nil, &dto.Usage{}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,9 +68,7 @@ func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*O
|
|||||||
StreamOptions: request.StreamOptions,
|
StreamOptions: request.StreamOptions,
|
||||||
Suffix: request.Suffix,
|
Suffix: request.Suffix,
|
||||||
}
|
}
|
||||||
if think, ok := request.Extra["think"]; ok {
|
ollamaRequest.Think = request.Think
|
||||||
ollamaRequest.Think = think
|
|
||||||
}
|
|
||||||
return ollamaRequest, nil
|
return ollamaRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -132,30 +132,34 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
|
|||||||
|
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relayconstant.RelayModeImagesEdits:
|
case relayconstant.RelayModeImagesEdits:
|
||||||
_, err := c.MultipartForm()
|
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||||
if err != nil {
|
_, err := c.MultipartForm()
|
||||||
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
|
if err != nil {
|
||||||
}
|
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
|
||||||
formData := c.Request.PostForm
|
|
||||||
imageRequest.Prompt = formData.Get("prompt")
|
|
||||||
imageRequest.Model = formData.Get("model")
|
|
||||||
imageRequest.N = uint(common.String2Int(formData.Get("n")))
|
|
||||||
imageRequest.Quality = formData.Get("quality")
|
|
||||||
imageRequest.Size = formData.Get("size")
|
|
||||||
|
|
||||||
if imageRequest.Model == "gpt-image-1" {
|
|
||||||
if imageRequest.Quality == "" {
|
|
||||||
imageRequest.Quality = "standard"
|
|
||||||
}
|
}
|
||||||
}
|
formData := c.Request.PostForm
|
||||||
if imageRequest.N == 0 {
|
imageRequest.Prompt = formData.Get("prompt")
|
||||||
imageRequest.N = 1
|
imageRequest.Model = formData.Get("model")
|
||||||
}
|
imageRequest.N = uint(common.String2Int(formData.Get("n")))
|
||||||
|
imageRequest.Quality = formData.Get("quality")
|
||||||
|
imageRequest.Size = formData.Get("size")
|
||||||
|
|
||||||
watermark := formData.Has("watermark")
|
if imageRequest.Model == "gpt-image-1" {
|
||||||
if watermark {
|
if imageRequest.Quality == "" {
|
||||||
imageRequest.Watermark = &watermark
|
imageRequest.Quality = "standard"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if imageRequest.N == 0 {
|
||||||
|
imageRequest.N = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
watermark := formData.Has("watermark")
|
||||||
|
if watermark {
|
||||||
|
imageRequest.Watermark = &watermark
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
fallthrough
|
||||||
default:
|
default:
|
||||||
err := common.UnmarshalBodyReusable(c, imageRequest)
|
err := common.UnmarshalBodyReusable(c, imageRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -163,7 +167,8 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
|
|||||||
}
|
}
|
||||||
|
|
||||||
if imageRequest.Model == "" {
|
if imageRequest.Model == "" {
|
||||||
imageRequest.Model = "dall-e-3"
|
//imageRequest.Model = "dall-e-3"
|
||||||
|
return nil, errors.New("model is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(imageRequest.Size, "×") {
|
if strings.Contains(imageRequest.Size, "×") {
|
||||||
@@ -194,9 +199,9 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if imageRequest.Prompt == "" {
|
//if imageRequest.Prompt == "" {
|
||||||
return nil, errors.New("prompt is required")
|
// return nil, errors.New("prompt is required")
|
||||||
}
|
//}
|
||||||
|
|
||||||
if imageRequest.N == 0 {
|
if imageRequest.N == 0 {
|
||||||
imageRequest.N = 1
|
imageRequest.N = 1
|
||||||
|
|||||||
@@ -2,14 +2,13 @@ package relay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
@@ -56,10 +55,12 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
if info.RelayMode == relayconstant.RelayModeImagesEdits {
|
|
||||||
|
switch convertedRequest.(type) {
|
||||||
|
case *bytes.Buffer:
|
||||||
requestBody = convertedRequest.(io.Reader)
|
requestBody = convertedRequest.(io.Reader)
|
||||||
} else {
|
default:
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := common.Marshal(convertedRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
@@ -73,7 +74,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
|||||||
}
|
}
|
||||||
|
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
println(fmt.Sprintf("image request body: %s", string(jsonData)))
|
logger.LogDebug(c, fmt.Sprintf("image request body: %s", string(jsonData)))
|
||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonData)
|
requestBody = bytes.NewBuffer(jsonData)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user