From a62d96c1f10dccea0feb27d7aa07208bf14252ca Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Mon, 10 Nov 2025 15:13:41 +0800 Subject: [PATCH 1/2] feat: vidu specify reference2video via metadata action --- relay/channel/task/vidu/adaptor.go | 33 ++++++++++++++++++++++-------- relay/common/relay_utils.go | 23 ++++++++++----------- 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index 7ccac2ff2..06257c6a4 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -82,7 +82,29 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { - return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) + if err := relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate); err != nil { + return err + } + req, err := relaycommon.GetTaskRequest(c) + if err != nil { + return service.TaskErrorWrapper(err, "get_task_request_failed", http.StatusBadRequest) + } + action := constant.TaskActionTextGenerate + if meatAction, ok := req.Metadata["action"]; ok { + action, _ = meatAction.(string) + } else if req.HasImage() { + action = constant.TaskActionGenerate + if info.ChannelType == constant.ChannelTypeVidu { + // vidu 增加 首尾帧生视频和参考图生视频 + if len(req.Images) == 2 { + action = constant.TaskActionFirstTailGenerate + } else if len(req.Images) > 2 { + action = constant.TaskActionReferenceGenerate + } + } + } + info.Action = action + return nil } func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { @@ -97,10 +119,6 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) return nil, err } - if len(body.Images) == 0 { - c.Set("action", constant.TaskActionTextGenerate) - } - data, err := json.Marshal(body) if err != nil { return nil, err @@ -131,9 +149,6 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info } func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { - if action := c.GetString("action"); action != "" { - info.Action = action - } return channel.DoTaskApiRequest(a, c, info, requestBody) } @@ -185,7 +200,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http } func (a *TaskAdaptor) GetModelList() []string { - return []string{"viduq1", "vidu2.0", "vidu1.5"} + return []string{"viduq2", "viduq1", "vidu2.0", "vidu1.5"} } func (a *TaskAdaptor) GetChannelName() string { diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 05c42dfd8..b662f9053 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -59,6 +59,17 @@ func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj info.Action = action c.Set("task_request", requestObj) } +func GetTaskRequest(c *gin.Context) (TaskSubmitReq, error) { + v, exists := c.Get("task_request") + if !exists { + return TaskSubmitReq{}, fmt.Errorf("request not found in context") + } + req, ok := v.(TaskSubmitReq) + if !ok { + return TaskSubmitReq{}, fmt.Errorf("invalid task request type") + } + return req, nil +} func validatePrompt(prompt string) *dto.TaskError { if strings.TrimSpace(prompt) == "" { @@ -212,18 +223,6 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d req.Images = []string{req.Image} } - if req.HasImage() { - action = constant.TaskActionGenerate - if info.ChannelType == constant.ChannelTypeVidu { - // vidu 增加 首尾帧生视频和参考图生视频 - if len(req.Images) == 2 { - action = constant.TaskActionFirstTailGenerate - } else if len(req.Images) > 2 { - action = constant.TaskActionReferenceGenerate - } - } - } - storeTaskRequest(c, info, action, req) return nil } From 1a8d89c410ad29d707d83a148eba6f65f54323b7 Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Mon, 10 Nov 2025 16:34:47 +0800 Subject: [PATCH 2/2] feat: vidu reference2video only viduq2 --- relay/channel/task/vidu/adaptor.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index 06257c6a4..6b62f1f01 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "strings" "time" "github.com/QuantumNous/new-api/common" @@ -107,7 +108,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom return nil } -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, exists := c.Get("task_request") if !exists { return nil, fmt.Errorf("request not found in context") @@ -119,6 +120,13 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) return nil, err } + if info.Action == constant.TaskActionReferenceGenerate { + if strings.Contains(body.Model, "viduq2") { + // 参考图生视频只能用 viduq2 模型, 不能带有pro或turbo后缀 https://platform.vidu.cn/docs/reference-to-video + body.Model = "viduq2" + } + } + data, err := json.Marshal(body) if err != nil { return nil, err