feat: add openai video remix endpoint

This commit is contained in:
creamlike1024
2025-12-05 11:39:01 +08:00
parent 0b9f6a58bc
commit d732cdd259
4 changed files with 85 additions and 28 deletions

View File

@@ -181,6 +181,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode)
} else if strings.Contains(c.Request.URL.Path, "/v1/videos/") && strings.HasSuffix(c.Request.URL.Path, "/remix") {
relayMode := relayconstant.RelayModeVideoSubmit
c.Set("relay_mode", relayMode)
shouldSelectChannel = false
} else if strings.Contains(c.Request.URL.Path, "/v1/videos") {
//curl https://api.openai.com/v1/videos \
// -H "Authorization: Bearer $OPENAI_API_KEY" \

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
@@ -67,11 +68,30 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
a.apiKey = info.ApiKey
}
func validateRemixRequest(c *gin.Context) *dto.TaskError {
var req struct {
Prompt string `json:"prompt"`
}
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
}
if strings.TrimSpace(req.Prompt) == "" {
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
}
return nil
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
if info.Action == "remix" {
return validateRemixRequest(c)
}
return relaycommon.ValidateMultipartDirect(c, info)
}
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.Action == "remix" {
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
}
return fmt.Sprintf("%s/v1/videos", a.baseURL), nil
}

View File

@@ -32,7 +32,67 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
if info.TaskRelayInfo == nil {
info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
}
path := c.Request.URL.Path
if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
info.Action = "remix"
}
// 提取 remix 任务的 video_id
if info.Action == "remix" {
videoID := c.Param("video_id")
if strings.TrimSpace(videoID) == "" {
return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest)
}
info.OriginTaskID = videoID
}
platform := constant.TaskPlatform(c.GetString("platform"))
// 获取原始任务信息
if info.OriginTaskID != "" {
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
return
}
if !exist {
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
return
}
if info.OriginModelName == "" {
if originTask.Properties.OriginModelName != "" {
info.OriginModelName = originTask.Properties.OriginModelName
} else if originTask.Properties.UpstreamModelName != "" {
info.OriginModelName = originTask.Properties.UpstreamModelName
} else {
var taskData map[string]interface{}
_ = json.Unmarshal(originTask.Data, &taskData)
if m, ok := taskData["model"].(string); ok && m != "" {
info.OriginModelName = m
platform = originTask.Platform
}
}
}
if originTask.ChannelId != info.ChannelId {
channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
return
}
if channel.Status != common.ChannelStatusEnabled {
taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
return
}
c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
info.ChannelBaseUrl = channel.GetBaseURL()
info.ChannelId = originTask.ChannelId
platform = originTask.Platform
}
}
if platform == "" {
platform = GetTaskPlatform(c)
}
@@ -94,34 +154,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
return
}
if info.OriginTaskID != "" {
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
return
}
if !exist {
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
return
}
if originTask.ChannelId != info.ChannelId {
channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
return
}
if channel.Status != common.ChannelStatusEnabled {
return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
}
c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
info.ChannelBaseUrl = channel.GetBaseURL()
info.ChannelId = originTask.ChannelId
}
}
// build body
requestBody, err := adaptor.BuildRequestBody(c, info)
if err != nil {

View File

@@ -14,6 +14,7 @@ func SetVideoRouter(router *gin.Engine) {
videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy)
videoV1Router.POST("/video/generations", controller.RelayTask)
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask)
}
// openai compatible API video routes
// docs: https://platform.openai.com/docs/api-reference/videos/create