diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index 17aec18f0..d0bcc7111 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -2,9 +2,12 @@ package sora import ( "bytes" + "encoding/json" "fmt" "io" + "mime/multipart" "net/http" + "strings" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/dto" @@ -87,9 +90,96 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if err != nil { return nil, errors.Wrap(err, "get_request_body_failed") } + + // 检查是否需要模型重定向 + if !info.IsModelMapped { + // 如果不需要重定向,直接返回原始请求体 + return bytes.NewReader(cachedBody), nil + } + + contentType := c.Request.Header.Get("Content-Type") + + // 处理multipart/form-data请求 + if strings.Contains(contentType, "multipart/form-data") { + return buildRequestBodyWithMappedModel(cachedBody, contentType, info.UpstreamModelName) + } + // 处理JSON请求 + if strings.Contains(contentType, "application/json") { + var jsonData map[string]interface{} + if err := json.Unmarshal(cachedBody, &jsonData); err != nil { + return nil, errors.Wrap(err, "unmarshal_json_failed") + } + + // 替换model字段为映射后的模型名 + jsonData["model"] = info.UpstreamModelName + + // 重新编码为JSON + newBody, err := json.Marshal(jsonData) + if err != nil { + return nil, errors.Wrap(err, "marshal_json_failed") + } + + return bytes.NewReader(newBody), nil + } + return bytes.NewReader(cachedBody), nil } +func buildRequestBodyWithMappedModel(originalBody []byte, contentType, redirectedModel string) (io.Reader, error) { + newBuffer := &bytes.Buffer{} + writer := multipart.NewWriter(newBuffer) + + r := multipart.NewReader(bytes.NewReader(originalBody), strings.TrimPrefix(contentType, "multipart/form-data; boundary=")) + + for { + part, err := r.NextPart() + if err == io.EOF { + break + } + if err != nil { + return nil, errors.Wrap(err, "read_multipart_part_failed") + } + + fieldName := part.FormName() + + if fieldName == "model" { + // 修改 model 字段为映射后的模型名 + if err := writer.WriteField("model", redirectedModel); err != nil { + return nil, errors.Wrap(err, "write_model_field_failed") + } + } else { + // 对于其他字段,保留原始内容 + if part.FileName() != "" { + newPart, err := writer.CreateFormFile(fieldName, part.FileName()) + if err != nil { + return nil, errors.Wrap(err, "create_form_file_failed") + } + if _, err := io.Copy(newPart, part); err != nil { + return nil, errors.Wrap(err, "copy_file_content_failed") + } + } else { + content, err := io.ReadAll(part) + if err != nil { + return nil, errors.Wrap(err, "read_field_content_failed") + } + if err := writer.WriteField(fieldName, string(content)); err != nil { + return nil, errors.Wrap(err, "write_field_failed") + } + } + } + + if err := part.Close(); err != nil { + return nil, errors.Wrap(err, "close_part_failed") + } + } + + if err := writer.Close(); err != nil { + return nil, errors.Wrap(err, "close_multipart_writer_failed") + } + + return newBuffer, nil +} + // DoRequest delegates to common helper. func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoTaskApiRequest(a, c, info, requestBody) diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index b38baf13a..26dd2afc2 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -252,6 +252,11 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d } } + // 模型映射 + if info.IsModelMapped { + req.Model = info.UpstreamModelName + } + storeTaskRequest(c, info, action, req) return nil } diff --git a/relay/relay_task.go b/relay/relay_task.go index ca1b0bb1f..db543319a 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -17,6 +17,7 @@ import ( "github.com/QuantumNous/new-api/relay/channel" relaycommon "github.com/QuantumNous/new-api/relay/common" relayconstant "github.com/QuantumNous/new-api/relay/constant" + "github.com/QuantumNous/new-api/relay/helper" "github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/setting/ratio_setting" @@ -38,6 +39,11 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto. } info.InitChannelMeta(c) + + // 模型映射 + if err := helper.ModelMappedHelper(c, info, nil); err != nil { + return service.TaskErrorWrapper(err, "model_mapped_failed", http.StatusBadRequest) + } adaptor := GetTaskAdaptor(platform) if adaptor == nil { return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)