mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-19 09:08:38 +00:00
feat(relay): 添加视频模型映射功能支持
This commit is contained in:
@@ -2,9 +2,12 @@ package sora
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
"github.com/QuantumNous/new-api/dto"
|
"github.com/QuantumNous/new-api/dto"
|
||||||
@@ -87,9 +90,96 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "get_request_body_failed")
|
return nil, errors.Wrap(err, "get_request_body_failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否需要模型重定向
|
||||||
|
if !info.IsModelMapped {
|
||||||
|
// 如果不需要重定向,直接返回原始请求体
|
||||||
|
return bytes.NewReader(cachedBody), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
|
|
||||||
|
// 处理multipart/form-data请求
|
||||||
|
if strings.Contains(contentType, "multipart/form-data") {
|
||||||
|
return buildRequestBodyWithMappedModel(cachedBody, contentType, info.UpstreamModelName)
|
||||||
|
}
|
||||||
|
// 处理JSON请求
|
||||||
|
if strings.Contains(contentType, "application/json") {
|
||||||
|
var jsonData map[string]interface{}
|
||||||
|
if err := json.Unmarshal(cachedBody, &jsonData); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "unmarshal_json_failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 替换model字段为映射后的模型名
|
||||||
|
jsonData["model"] = info.UpstreamModelName
|
||||||
|
|
||||||
|
// 重新编码为JSON
|
||||||
|
newBody, err := json.Marshal(jsonData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "marshal_json_failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.NewReader(newBody), nil
|
||||||
|
}
|
||||||
|
|
||||||
return bytes.NewReader(cachedBody), nil
|
return bytes.NewReader(cachedBody), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildRequestBodyWithMappedModel(originalBody []byte, contentType, redirectedModel string) (io.Reader, error) {
|
||||||
|
newBuffer := &bytes.Buffer{}
|
||||||
|
writer := multipart.NewWriter(newBuffer)
|
||||||
|
|
||||||
|
r := multipart.NewReader(bytes.NewReader(originalBody), strings.TrimPrefix(contentType, "multipart/form-data; boundary="))
|
||||||
|
|
||||||
|
for {
|
||||||
|
part, err := r.NextPart()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "read_multipart_part_failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldName := part.FormName()
|
||||||
|
|
||||||
|
if fieldName == "model" {
|
||||||
|
// 修改 model 字段为映射后的模型名
|
||||||
|
if err := writer.WriteField("model", redirectedModel); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "write_model_field_failed")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 对于其他字段,保留原始内容
|
||||||
|
if part.FileName() != "" {
|
||||||
|
newPart, err := writer.CreateFormFile(fieldName, part.FileName())
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "create_form_file_failed")
|
||||||
|
}
|
||||||
|
if _, err := io.Copy(newPart, part); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "copy_file_content_failed")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content, err := io.ReadAll(part)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "read_field_content_failed")
|
||||||
|
}
|
||||||
|
if err := writer.WriteField(fieldName, string(content)); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "write_field_failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := part.Close(); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "close_part_failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "close_multipart_writer_failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return newBuffer, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DoRequest delegates to common helper.
|
// DoRequest delegates to common helper.
|
||||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||||
|
|||||||
@@ -252,6 +252,11 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 模型映射
|
||||||
|
if info.IsModelMapped {
|
||||||
|
req.Model = info.UpstreamModelName
|
||||||
|
}
|
||||||
|
|
||||||
storeTaskRequest(c, info, action, req)
|
storeTaskRequest(c, info, action, req)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/QuantumNous/new-api/relay/channel"
|
"github.com/QuantumNous/new-api/relay/channel"
|
||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||||
|
"github.com/QuantumNous/new-api/relay/helper"
|
||||||
"github.com/QuantumNous/new-api/service"
|
"github.com/QuantumNous/new-api/service"
|
||||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||||
|
|
||||||
@@ -38,6 +39,11 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
|||||||
}
|
}
|
||||||
|
|
||||||
info.InitChannelMeta(c)
|
info.InitChannelMeta(c)
|
||||||
|
|
||||||
|
// 模型映射
|
||||||
|
if err := helper.ModelMappedHelper(c, info, nil); err != nil {
|
||||||
|
return service.TaskErrorWrapper(err, "model_mapped_failed", http.StatusBadRequest)
|
||||||
|
}
|
||||||
adaptor := GetTaskAdaptor(platform)
|
adaptor := GetTaskAdaptor(platform)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
||||||
|
|||||||
Reference in New Issue
Block a user