diff --git a/middleware/distributor.go b/middleware/distributor.go index 5a9deb23c..3c8529d96 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -181,6 +181,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("relay_mode", relayMode) + } else if strings.Contains(c.Request.URL.Path, "/v1/videos/") && strings.HasSuffix(c.Request.URL.Path, "/remix") { + relayMode := relayconstant.RelayModeVideoSubmit + c.Set("relay_mode", relayMode) + shouldSelectChannel = false } else if strings.Contains(c.Request.URL.Path, "/v1/videos") { //curl https://api.openai.com/v1/videos \ // -H "Authorization: Bearer $OPENAI_API_KEY" \ diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index 17aec18f0..8486abf23 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "strings" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/dto" @@ -67,11 +68,30 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { a.apiKey = info.ApiKey } +func validateRemixRequest(c *gin.Context) *dto.TaskError { + var req struct { + Prompt string `json:"prompt"` + } + if err := common.UnmarshalBodyReusable(c, &req); err != nil { + return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) + } + if strings.TrimSpace(req.Prompt) == "" { + return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest) + } + return nil +} + func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { + if info.Action == "remix" { + return validateRemixRequest(c) + } return relaycommon.ValidateMultipartDirect(c, info) } func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { + if info.Action == "remix" { + return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil + } return fmt.Sprintf("%s/v1/videos", a.baseURL), nil } diff --git a/relay/relay_task.go b/relay/relay_task.go index 61e2af523..ff4c73cfe 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -32,7 +32,67 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto. if info.TaskRelayInfo == nil { info.TaskRelayInfo = &relaycommon.TaskRelayInfo{} } + path := c.Request.URL.Path + if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") { + info.Action = "remix" + } + + // 提取 remix 任务的 video_id + if info.Action == "remix" { + videoID := c.Param("video_id") + if strings.TrimSpace(videoID) == "" { + return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest) + } + info.OriginTaskID = videoID + } + platform := constant.TaskPlatform(c.GetString("platform")) + + // 获取原始任务信息 + if info.OriginTaskID != "" { + originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) + return + } + if !exist { + taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) + return + } + if info.OriginModelName == "" { + if originTask.Properties.OriginModelName != "" { + info.OriginModelName = originTask.Properties.OriginModelName + } else if originTask.Properties.UpstreamModelName != "" { + info.OriginModelName = originTask.Properties.UpstreamModelName + } else { + var taskData map[string]interface{} + _ = json.Unmarshal(originTask.Data, &taskData) + if m, ok := taskData["model"].(string); ok && m != "" { + info.OriginModelName = m + platform = originTask.Platform + } + } + } + if originTask.ChannelId != info.ChannelId { + channel, err := model.GetChannelById(originTask.ChannelId, true) + if err != nil { + taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) + return + } + if channel.Status != common.ChannelStatusEnabled { + taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) + return + } + c.Set("base_url", channel.GetBaseURL()) + c.Set("channel_id", originTask.ChannelId) + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + + info.ChannelBaseUrl = channel.GetBaseURL() + info.ChannelId = originTask.ChannelId + platform = originTask.Platform + } + + } if platform == "" { platform = GetTaskPlatform(c) } @@ -94,34 +154,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto. return } - if info.OriginTaskID != "" { - originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) - if err != nil { - taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) - return - } - if !exist { - taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) - return - } - if originTask.ChannelId != info.ChannelId { - channel, err := model.GetChannelById(originTask.ChannelId, true) - if err != nil { - taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) - return - } - if channel.Status != common.ChannelStatusEnabled { - return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest) - } - c.Set("base_url", channel.GetBaseURL()) - c.Set("channel_id", originTask.ChannelId) - c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - - info.ChannelBaseUrl = channel.GetBaseURL() - info.ChannelId = originTask.ChannelId - } - } - // build body requestBody, err := adaptor.BuildRequestBody(c, info) if err != nil { diff --git a/router/video-router.go b/router/video-router.go index 87097cf86..d5fed1d78 100644 --- a/router/video-router.go +++ b/router/video-router.go @@ -14,6 +14,7 @@ func SetVideoRouter(router *gin.Engine) { videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy) videoV1Router.POST("/video/generations", controller.RelayTask) videoV1Router.GET("/video/generations/:task_id", controller.RelayTask) + videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask) } // openai compatible API video routes // docs: https://platform.openai.com/docs/api-reference/videos/create