Files
new-api/relay/channel/ali/image.go
2025-11-14 20:30:18 +08:00

348 lines
9.7 KiB
Go

package ali
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
var imageRequest AliImageRequest
imageRequest.Model = request.Model
imageRequest.ResponseFormat = request.ResponseFormat
logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
if request.Extra != nil {
if val, ok := request.Extra["parameters"]; ok {
err := common.Unmarshal(val, &imageRequest.Parameters)
if err != nil {
return nil, fmt.Errorf("invalid parameters field: %w", err)
}
}
if val, ok := request.Extra["input"]; ok {
err := common.Unmarshal(val, &imageRequest.Input)
if err != nil {
return nil, fmt.Errorf("invalid input field: %w", err)
}
}
}
if imageRequest.Parameters == nil {
imageRequest.Parameters = AliImageParameters{
Size: strings.Replace(request.Size, "x", "*", -1),
N: int(request.N),
Watermark: request.Watermark,
}
}
if imageRequest.Input == nil {
imageRequest.Input = AliImageInput{
Prompt: request.Prompt,
}
}
return &imageRequest, nil
}
func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
mf := c.Request.MultipartForm
if mf == nil {
if _, err := c.MultipartForm(); err != nil {
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
}
mf = c.Request.MultipartForm
}
var imageFiles []*multipart.FileHeader
var exists bool
// First check for standard "image" field
if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
// If not found, check for "image[]" field
if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
// If still not found, iterate through all fields to find any that start with "image["
foundArrayImages := false
for fieldName, files := range mf.File {
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
foundArrayImages = true
imageFiles = append(imageFiles, files...)
}
}
// If no image fields found at all
if !foundArrayImages && (len(imageFiles) == 0) {
return nil, errors.New("image is required")
}
}
}
if len(imageFiles) == 0 {
return nil, errors.New("image is required")
}
//if len(imageFiles) > 1 {
// return nil, errors.New("only one image is supported for qwen edit")
//}
// 获取base64编码的图片
var imageBase64s []string
for _, file := range imageFiles {
image, err := file.Open()
if err != nil {
return nil, errors.New("failed to open image file")
}
// 读取文件内容
imageData, err := io.ReadAll(image)
if err != nil {
return nil, errors.New("failed to read image file")
}
// 获取MIME类型
mimeType := http.DetectContentType(imageData)
// 编码为base64
base64Data := base64.StdEncoding.EncodeToString(imageData)
// 构造data URL格式
dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
imageBase64s = append(imageBase64s, dataURL)
image.Close()
}
return imageBase64s, nil
}
func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
var imageRequest AliImageRequest
imageRequest.Model = request.Model
imageRequest.ResponseFormat = request.ResponseFormat
imageBase64s, err := getImageBase64sFromForm(c, "image")
if err != nil {
return nil, fmt.Errorf("get image base64s from form failed: %w", err)
}
//dto.MediaContent{}
mediaContents := make([]AliMediaContent, len(imageBase64s))
for i, b64 := range imageBase64s {
mediaContents[i] = AliMediaContent{
Image: b64,
}
}
mediaContents = append(mediaContents, AliMediaContent{
Text: request.Prompt,
})
imageRequest.Input = AliImageInput{
Messages: []AliMessage{
{
Role: "user",
Content: mediaContents,
},
},
}
imageRequest.Parameters = AliImageParameters{
Watermark: request.Watermark,
}
return &imageRequest, nil
}
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
var aliResponse AliResponse
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return &aliResponse, err, nil
}
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
common.SysLog("updateTask client.Do err: " + err.Error())
return &aliResponse, err, nil
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
var response AliResponse
err = common.Unmarshal(responseBody, &response)
if err != nil {
common.SysLog("updateTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil
}
return &response, nil, responseBody
}
func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
waitSeconds := 10
step := 0
maxStep := 20
var taskResponse AliResponse
var responseBody []byte
for {
logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
step++
rsp, err, body := updateTask(info, taskID)
responseBody = body
if err != nil {
logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error())
time.Sleep(time.Duration(waitSeconds) * time.Second)
continue
}
if rsp.Output.TaskStatus == "" {
return &taskResponse, responseBody, nil
}
switch rsp.Output.TaskStatus {
case "FAILED":
fallthrough
case "CANCELED":
fallthrough
case "SUCCEEDED":
fallthrough
case "UNKNOWN":
return rsp, responseBody, nil
}
if step >= maxStep {
break
}
time.Sleep(time.Duration(waitSeconds) * time.Second)
}
return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
}
func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody []byte, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
imageResponse := dto.ImageResponse{
Created: info.StartTime.Unix(),
}
for _, data := range response.Output.Results {
var b64Json string
if responseFormat == "b64_json" {
_, b64, err := service.GetImageFromUrl(data.Url)
if err != nil {
logger.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
} else {
b64Json = data.B64Image
}
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
Url: data.Url,
B64Json: b64Json,
RevisedPrompt: "",
})
}
var mapResponse map[string]any
_ = common.Unmarshal(originBody, &mapResponse)
imageResponse.Extra = mapResponse
return &imageResponse
}
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
responseFormat := c.GetString("response_format")
var aliTaskResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
service.CloseResponseBodyGracefully(resp)
err = common.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {
logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
}
aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponse), nil
}
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
return types.WithOpenAIError(types.OpenAIError{
Message: aliResponse.Output.Message,
Type: "ali_error",
Param: "",
Code: aliResponse.Output.Code,
}, resp.StatusCode), nil
}
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
service.IOCopyBytesGracefully(c, resp, jsonResponse)
return nil, &dto.Usage{}
}
func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
var aliResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
service.CloseResponseBodyGracefully(resp)
err = common.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliResponse.Message != "" {
logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
}
var fullTextResponse dto.ImageResponse
if len(aliResponse.Output.Choices) > 0 {
fullTextResponse = dto.ImageResponse{
Created: info.StartTime.Unix(),
Data: []dto.ImageData{
{
Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
B64Json: "",
},
},
}
}
var mapResponse map[string]any
_ = common.Unmarshal(responseBody, &mapResponse)
fullTextResponse.Extra = mapResponse
jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
service.IOCopyBytesGracefully(c, resp, jsonResponse)
return nil, &dto.Usage{}
}