mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-01 08:13:45 +00:00
Compare commits
7 Commits
v0.9.15-pa
...
task-model
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
797c7acd13 | ||
|
|
f15b85f745 | ||
|
|
10a473993b | ||
|
|
ff11c92713 | ||
|
|
347ad047f9 | ||
|
|
c651727bab | ||
|
|
7fc25a57cf |
@@ -98,9 +98,9 @@ func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, reque
|
||||
return nil, errors.New("image is required")
|
||||
}
|
||||
|
||||
//if len(imageFiles) > 1 {
|
||||
// return nil, errors.New("only one image is supported for qwen edit")
|
||||
//}
|
||||
if len(imageFiles) > 1 {
|
||||
return nil, errors.New("only one image is supported for qwen edit")
|
||||
}
|
||||
|
||||
// 获取base64编码的图片
|
||||
var imageBase64s []string
|
||||
|
||||
@@ -5,12 +5,10 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
@@ -110,7 +108,6 @@ type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
aliReq *AliVideoRequest
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -121,16 +118,6 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// 阿里通义万相支持 JSON 格式,不使用 multipart
|
||||
var taskReq relaycommon.TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
|
||||
return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest)
|
||||
}
|
||||
aliReq, err := a.convertToAliRequest(info, taskReq)
|
||||
if err != nil {
|
||||
return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
a.aliReq = aliReq
|
||||
logger.LogJson(c, "ali video request body", aliReq)
|
||||
return relaycommon.ValidateMultipartDirect(c, info)
|
||||
}
|
||||
|
||||
@@ -147,7 +134,13 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
bodyBytes, err := common.Marshal(a.aliReq)
|
||||
var taskReq relaycommon.TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal_task_request_failed")
|
||||
}
|
||||
aliReq := a.convertToAliRequest(taskReq)
|
||||
|
||||
bodyBytes, err := common.Marshal(aliReq)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshal_ali_request_failed")
|
||||
}
|
||||
@@ -155,31 +148,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
return bytes.NewReader(bodyBytes), nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
|
||||
otherRatios := map[string]map[string]float64{
|
||||
"wan2.5-i2v-preview": {
|
||||
"480P": 1,
|
||||
"720P": 2,
|
||||
"1080P": 1 / 0.3,
|
||||
},
|
||||
"wan2.2-i2v-plus": {
|
||||
"480P": 1,
|
||||
"1080P": 0.7 / 0.14,
|
||||
},
|
||||
"wan2.2-kf2v-flash": {
|
||||
"480P": 1,
|
||||
"720P": 2,
|
||||
"1080P": 4.8,
|
||||
},
|
||||
"wan2.2-i2v-flash": {
|
||||
"480P": 1,
|
||||
"720P": 2,
|
||||
},
|
||||
"wan2.2-s2v": {
|
||||
"480P": 1,
|
||||
"720P": 0.9 / 0.5,
|
||||
},
|
||||
}
|
||||
func (a *TaskAdaptor) convertToAliRequest(req relaycommon.TaskSubmitReq) *AliVideoRequest {
|
||||
aliReq := &AliVideoRequest{
|
||||
Model: req.Model,
|
||||
Input: AliVideoInput{
|
||||
@@ -216,13 +185,6 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
|
||||
// 处理时长
|
||||
if req.Duration > 0 {
|
||||
aliReq.Parameters.Duration = req.Duration
|
||||
} else if req.Seconds != "" {
|
||||
seconds, err := strconv.Atoi(req.Seconds)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert seconds to int failed")
|
||||
} else {
|
||||
aliReq.Parameters.Duration = seconds
|
||||
}
|
||||
} else {
|
||||
aliReq.Parameters.Duration = 5 // 默认5秒
|
||||
}
|
||||
@@ -230,32 +192,11 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
|
||||
// 从 metadata 中提取额外参数
|
||||
if req.Metadata != nil {
|
||||
if metadataBytes, err := common.Marshal(req.Metadata); err == nil {
|
||||
err = common.Unmarshal(metadataBytes, aliReq)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
} else {
|
||||
return nil, errors.Wrap(err, "marshal metadata failed")
|
||||
_ = common.Unmarshal(metadataBytes, aliReq)
|
||||
}
|
||||
}
|
||||
|
||||
if aliReq.Model != req.Model {
|
||||
return nil, errors.New("can't change model with metadata")
|
||||
}
|
||||
|
||||
info.PriceData.OtherRatios = map[string]float64{
|
||||
"seconds": float64(aliReq.Parameters.Duration),
|
||||
}
|
||||
|
||||
if otherRatio, ok := otherRatios[req.Model]; ok {
|
||||
if ratio, ok := otherRatio[aliReq.Parameters.Resolution]; ok {
|
||||
info.PriceData.OtherRatios[fmt.Sprintf("resolution-%s", aliReq.Parameters.Resolution)] = ratio
|
||||
}
|
||||
}
|
||||
|
||||
// println(fmt.Sprintf("other ratios: %v", info.PriceData.OtherRatios))
|
||||
|
||||
return aliReq, nil
|
||||
return aliReq
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper
|
||||
|
||||
@@ -2,9 +2,13 @@ package sora
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
@@ -87,9 +91,107 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get_request_body_failed")
|
||||
}
|
||||
|
||||
// 检查是否需要模型重定向
|
||||
if !info.IsModelMapped {
|
||||
// 如果不需要重定向,直接返回原始请求体
|
||||
return bytes.NewReader(cachedBody), nil
|
||||
}
|
||||
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
|
||||
// 处理multipart/form-data请求
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
return buildRequestBodyWithMappedModel(cachedBody, contentType, info.UpstreamModelName)
|
||||
}
|
||||
// 处理JSON请求
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
var jsonData map[string]interface{}
|
||||
if err := json.Unmarshal(cachedBody, &jsonData); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal_json_failed")
|
||||
}
|
||||
|
||||
// 替换model字段为映射后的模型名
|
||||
jsonData["model"] = info.UpstreamModelName
|
||||
|
||||
// 重新编码为JSON
|
||||
newBody, err := json.Marshal(jsonData)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshal_json_failed")
|
||||
}
|
||||
|
||||
return bytes.NewReader(newBody), nil
|
||||
}
|
||||
|
||||
return bytes.NewReader(cachedBody), nil
|
||||
}
|
||||
|
||||
func buildRequestBodyWithMappedModel(originalBody []byte, contentType, redirectedModel string) (io.Reader, error) {
|
||||
newBuffer := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(newBuffer)
|
||||
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parse_content_type_failed")
|
||||
}
|
||||
boundary, ok := params["boundary"]
|
||||
if !ok {
|
||||
return nil, errors.New("boundary_not_found_in_content_type")
|
||||
}
|
||||
if err := writer.SetBoundary(boundary); err != nil {
|
||||
return nil, errors.Wrap(err, "set_boundary_failed")
|
||||
}
|
||||
r := multipart.NewReader(bytes.NewReader(originalBody), boundary)
|
||||
|
||||
for {
|
||||
part, err := r.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "read_multipart_part_failed")
|
||||
}
|
||||
|
||||
fieldName := part.FormName()
|
||||
|
||||
if fieldName == "model" {
|
||||
// 修改 model 字段为映射后的模型名
|
||||
if err := writer.WriteField("model", redirectedModel); err != nil {
|
||||
return nil, errors.Wrap(err, "write_model_field_failed")
|
||||
}
|
||||
} else {
|
||||
// 对于其他字段,保留原始内容
|
||||
if part.FileName() != "" {
|
||||
newPart, err := writer.CreatePart(part.Header)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "create_form_file_failed")
|
||||
}
|
||||
if _, err := io.Copy(newPart, part); err != nil {
|
||||
return nil, errors.Wrap(err, "copy_file_content_failed")
|
||||
}
|
||||
} else {
|
||||
newPart, err := writer.CreatePart(part.Header)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "create_form_field_failed")
|
||||
}
|
||||
if _, err := io.Copy(newPart, part); err != nil {
|
||||
return nil, errors.Wrap(err, "copy_field_content_failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := part.Close(); err != nil {
|
||||
return nil, errors.Wrap(err, "close_part_failed")
|
||||
}
|
||||
}
|
||||
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, errors.Wrap(err, "close_multipart_writer_failed")
|
||||
}
|
||||
|
||||
return newBuffer, nil
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
|
||||
@@ -121,7 +121,6 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
|
||||
|
||||
prompt = req.Prompt
|
||||
model = req.Model
|
||||
size = req.Size
|
||||
seconds, _ = strconv.Atoi(req.Seconds)
|
||||
if seconds == 0 {
|
||||
seconds = req.Duration
|
||||
@@ -224,6 +223,11 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d
|
||||
}
|
||||
}
|
||||
|
||||
// 模型映射
|
||||
if info.IsModelMapped {
|
||||
req.Model = info.UpstreamModelName
|
||||
}
|
||||
|
||||
storeTaskRequest(c, info, action, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
|
||||
@@ -38,6 +39,11 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
}
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
// 模型映射
|
||||
if err := helper.ModelMappedHelper(c, info, nil); err != nil {
|
||||
return service.TaskErrorWrapper(err, "model_mapped_failed", http.StatusBadRequest)
|
||||
}
|
||||
adaptor := GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
||||
@@ -208,6 +214,10 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
task.Quota = quota
|
||||
task.Data = taskData
|
||||
task.Action = info.Action
|
||||
task.Properties = model.Properties{
|
||||
UpstreamModelName: info.UpstreamModelName,
|
||||
OriginModelName: info.OriginModelName,
|
||||
}
|
||||
err = task.Insert()
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|
||||
|
||||
Reference in New Issue
Block a user