Compare commits

...

7 Commits

Author SHA1 Message Date
creamlike1024
797c7acd13 fix: 修复multipart表单字段内容复制问题 2025-10-31 20:13:27 +08:00
creamlike1024
f15b85f745 fix(: 修复multipart请求边界设置和文件字段处理问题 2025-10-31 20:06:01 +08:00
creamlike1024
10a473993b refactor(relay): remove IsModelMapped properties 2025-10-31 19:53:46 +08:00
creamlike1024
ff11c92713 Merge branch 'main' into task-model-mapper 2025-10-31 19:49:05 +08:00
creamlike1024
347ad047f9 feat: 保存重定向信息到 task.Properties 2025-10-31 19:45:37 +08:00
creamlike1024
c651727bab fix(adaptor): 修复解析multipart请求时获取boundary的问题 2025-10-31 19:16:55 +08:00
creamlike1024
7fc25a57cf feat(relay): 添加视频模型映射功能支持 2025-10-31 18:58:03 +08:00
3 changed files with 117 additions and 0 deletions

View File

@@ -2,9 +2,13 @@ package sora
import (
"bytes"
"encoding/json"
"fmt"
"io"
"mime"
"mime/multipart"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
@@ -87,9 +91,107 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
if err != nil {
return nil, errors.Wrap(err, "get_request_body_failed")
}
// 检查是否需要模型重定向
if !info.IsModelMapped {
// 如果不需要重定向,直接返回原始请求体
return bytes.NewReader(cachedBody), nil
}
contentType := c.Request.Header.Get("Content-Type")
// 处理multipart/form-data请求
if strings.Contains(contentType, "multipart/form-data") {
return buildRequestBodyWithMappedModel(cachedBody, contentType, info.UpstreamModelName)
}
// 处理JSON请求
if strings.Contains(contentType, "application/json") {
var jsonData map[string]interface{}
if err := json.Unmarshal(cachedBody, &jsonData); err != nil {
return nil, errors.Wrap(err, "unmarshal_json_failed")
}
// 替换model字段为映射后的模型名
jsonData["model"] = info.UpstreamModelName
// 重新编码为JSON
newBody, err := json.Marshal(jsonData)
if err != nil {
return nil, errors.Wrap(err, "marshal_json_failed")
}
return bytes.NewReader(newBody), nil
}
return bytes.NewReader(cachedBody), nil
}
func buildRequestBodyWithMappedModel(originalBody []byte, contentType, redirectedModel string) (io.Reader, error) {
newBuffer := &bytes.Buffer{}
writer := multipart.NewWriter(newBuffer)
_, params, err := mime.ParseMediaType(contentType)
if err != nil {
return nil, errors.Wrap(err, "parse_content_type_failed")
}
boundary, ok := params["boundary"]
if !ok {
return nil, errors.New("boundary_not_found_in_content_type")
}
if err := writer.SetBoundary(boundary); err != nil {
return nil, errors.Wrap(err, "set_boundary_failed")
}
r := multipart.NewReader(bytes.NewReader(originalBody), boundary)
for {
part, err := r.NextPart()
if err == io.EOF {
break
}
if err != nil {
return nil, errors.Wrap(err, "read_multipart_part_failed")
}
fieldName := part.FormName()
if fieldName == "model" {
// 修改 model 字段为映射后的模型名
if err := writer.WriteField("model", redirectedModel); err != nil {
return nil, errors.Wrap(err, "write_model_field_failed")
}
} else {
// 对于其他字段,保留原始内容
if part.FileName() != "" {
newPart, err := writer.CreatePart(part.Header)
if err != nil {
return nil, errors.Wrap(err, "create_form_file_failed")
}
if _, err := io.Copy(newPart, part); err != nil {
return nil, errors.Wrap(err, "copy_file_content_failed")
}
} else {
newPart, err := writer.CreatePart(part.Header)
if err != nil {
return nil, errors.Wrap(err, "create_form_field_failed")
}
if _, err := io.Copy(newPart, part); err != nil {
return nil, errors.Wrap(err, "copy_field_content_failed")
}
}
}
if err := part.Close(); err != nil {
return nil, errors.Wrap(err, "close_part_failed")
}
}
if err := writer.Close(); err != nil {
return nil, errors.Wrap(err, "close_multipart_writer_failed")
}
return newBuffer, nil
}
// DoRequest delegates to common helper.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoTaskApiRequest(a, c, info, requestBody)

View File

@@ -223,6 +223,11 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d
}
}
// 模型映射
if info.IsModelMapped {
req.Model = info.UpstreamModelName
}
storeTaskRequest(c, info, action, req)
return nil
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/QuantumNous/new-api/relay/channel"
relaycommon "github.com/QuantumNous/new-api/relay/common"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/ratio_setting"
@@ -38,6 +39,11 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
}
info.InitChannelMeta(c)
// 模型映射
if err := helper.ModelMappedHelper(c, info, nil); err != nil {
return service.TaskErrorWrapper(err, "model_mapped_failed", http.StatusBadRequest)
}
adaptor := GetTaskAdaptor(platform)
if adaptor == nil {
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
@@ -208,6 +214,10 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
task.Quota = quota
task.Data = taskData
task.Action = info.Action
task.Properties = model.Properties{
UpstreamModelName: info.UpstreamModelName,
OriginModelName: info.OriginModelName,
}
err = task.Insert()
if err != nil {
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)