diff --git a/constant/task.go b/constant/task.go
index 21790145b..e174fd60e 100644
--- a/constant/task.go
+++ b/constant/task.go
@@ -11,8 +11,10 @@ const (
SunoActionMusic = "MUSIC"
SunoActionLyrics = "LYRICS"
- TaskActionGenerate = "generate"
- TaskActionTextGenerate = "textGenerate"
+ TaskActionGenerate = "generate"
+ TaskActionTextGenerate = "textGenerate"
+ TaskActionFirstTailGenerate = "firstTailGenerate"
+ TaskActionReferenceGenerate = "referenceGenerate"
)
var SunoModel2Action = map[string]string{
diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go
index a1140d1e7..358aef583 100644
--- a/relay/channel/task/vidu/adaptor.go
+++ b/relay/channel/task/vidu/adaptor.go
@@ -80,8 +80,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
- // Use the unified validation method for TaskSubmitReq with image-based action determination
- return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
+ return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
}
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
@@ -112,6 +111,10 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
switch info.Action {
case constant.TaskActionGenerate:
path = "/img2video"
+ case constant.TaskActionFirstTailGenerate:
+ path = "/start-end2video"
+ case constant.TaskActionReferenceGenerate:
+ path = "/reference2video"
default:
path = "/text2video"
}
@@ -187,14 +190,9 @@ func (a *TaskAdaptor) GetChannelName() string {
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
- var images []string
- if req.Image != "" {
- images = []string{req.Image}
- }
-
r := requestPayload{
Model: defaultString(req.Model, "viduq1"),
- Images: images,
+ Images: req.Images,
Prompt: req.Prompt,
Duration: defaultInt(req.Duration, 5),
Resolution: defaultString(req.Size, "1080p"),
diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go
index cf6d08dda..3a721b479 100644
--- a/relay/common/relay_utils.go
+++ b/relay/common/relay_utils.go
@@ -79,34 +79,18 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d
req.Images = []string{req.Image}
}
+ if req.HasImage() {
+ action = constant.TaskActionGenerate
+ if info.ChannelType == constant.ChannelTypeVidu {
+ // vidu 增加 首尾帧生视频和参考图生视频
+ if len(req.Images) == 2 {
+ action = constant.TaskActionFirstTailGenerate
+ } else if len(req.Images) > 2 {
+ action = constant.TaskActionReferenceGenerate
+ }
+ }
+ }
+
storeTaskRequest(c, info, action, req)
return nil
}
-
-func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
- hasPrompt, ok := requestObj.(HasPrompt)
- if !ok {
- return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
- }
-
- if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
- return taskErr
- }
-
- action := constant.TaskActionTextGenerate
- if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() {
- action = constant.TaskActionGenerate
- }
-
- storeTaskRequest(c, info, action, requestObj)
- return nil
-}
-
-func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
- var req TaskSubmitReq
- if err := c.ShouldBindJSON(&req); err != nil {
- return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
- }
-
- return ValidateTaskRequestWithImage(c, info, req)
-}
diff --git a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx
index 766c17158..b63c7dd4f 100644
--- a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx
+++ b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx
@@ -35,8 +35,9 @@ import {
Sparkles,
} from 'lucide-react';
import {
- TASK_ACTION_GENERATE,
- TASK_ACTION_TEXT_GENERATE,
+ TASK_ACTION_FIRST_TAIL_GENERATE,
+ TASK_ACTION_GENERATE, TASK_ACTION_REFERENCE_GENERATE,
+ TASK_ACTION_TEXT_GENERATE
} from '../../../constants/common.constant';
import { CHANNEL_OPTIONS } from '../../../constants/channel.constants';
@@ -111,6 +112,18 @@ const renderType = (type, t) => {
{t('文生视频')}
);
+ case TASK_ACTION_FIRST_TAIL_GENERATE:
+ return (
+ }>
+ {t('首尾生视频')}
+
+ );
+ case TASK_ACTION_REFERENCE_GENERATE:
+ return (
+ }>
+ {t('参照生视频')}
+
+ );
default:
return (
}>
@@ -343,7 +356,9 @@ export const getTaskLogsColumns = ({
// 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接
const isVideoTask =
record.action === TASK_ACTION_GENERATE ||
- record.action === TASK_ACTION_TEXT_GENERATE;
+ record.action === TASK_ACTION_TEXT_GENERATE ||
+ record.action === TASK_ACTION_FIRST_TAIL_GENERATE ||
+ record.action === TASK_ACTION_REFERENCE_GENERATE;
const isSuccess = record.status === 'SUCCESS';
const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
if (isSuccess && isVideoTask && isUrl) {
diff --git a/web/src/constants/common.constant.js b/web/src/constants/common.constant.js
index 277bb9a54..57fbbbde5 100644
--- a/web/src/constants/common.constant.js
+++ b/web/src/constants/common.constant.js
@@ -40,3 +40,5 @@ export const API_ENDPOINTS = [
export const TASK_ACTION_GENERATE = 'generate';
export const TASK_ACTION_TEXT_GENERATE = 'textGenerate';
+export const TASK_ACTION_FIRST_TAIL_GENERATE = 'firstTailGenerate';
+export const TASK_ACTION_REFERENCE_GENERATE = 'referenceGenerate';