diff --git a/constant/task.go b/constant/task.go index 21790145b..e174fd60e 100644 --- a/constant/task.go +++ b/constant/task.go @@ -11,8 +11,10 @@ const ( SunoActionMusic = "MUSIC" SunoActionLyrics = "LYRICS" - TaskActionGenerate = "generate" - TaskActionTextGenerate = "textGenerate" + TaskActionGenerate = "generate" + TaskActionTextGenerate = "textGenerate" + TaskActionFirstTailGenerate = "firstTailGenerate" + TaskActionReferenceGenerate = "referenceGenerate" ) var SunoModel2Action = map[string]string{ diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index a1140d1e7..358aef583 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -80,8 +80,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { - // Use the unified validation method for TaskSubmitReq with image-based action determination - return relaycommon.ValidateTaskRequestWithImageBinding(c, info) + return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) } func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { @@ -112,6 +111,10 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro switch info.Action { case constant.TaskActionGenerate: path = "/img2video" + case constant.TaskActionFirstTailGenerate: + path = "/start-end2video" + case constant.TaskActionReferenceGenerate: + path = "/reference2video" default: path = "/text2video" } @@ -187,14 +190,9 @@ func (a *TaskAdaptor) GetChannelName() string { // ============================ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { - var images []string - if req.Image != "" { - images = []string{req.Image} - } - r := requestPayload{ Model: defaultString(req.Model, "viduq1"), - Images: images, + Images: req.Images, Prompt: req.Prompt, Duration: defaultInt(req.Duration, 5), Resolution: defaultString(req.Size, "1080p"), diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index cf6d08dda..3a721b479 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -79,34 +79,18 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d req.Images = []string{req.Image} } + if req.HasImage() { + action = constant.TaskActionGenerate + if info.ChannelType == constant.ChannelTypeVidu { + // vidu 增加 首尾帧生视频和参考图生视频 + if len(req.Images) == 2 { + action = constant.TaskActionFirstTailGenerate + } else if len(req.Images) > 2 { + action = constant.TaskActionReferenceGenerate + } + } + } + storeTaskRequest(c, info, action, req) return nil } - -func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError { - hasPrompt, ok := requestObj.(HasPrompt) - if !ok { - return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true) - } - - if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil { - return taskErr - } - - action := constant.TaskActionTextGenerate - if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() { - action = constant.TaskActionGenerate - } - - storeTaskRequest(c, info, action, requestObj) - return nil -} - -func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError { - var req TaskSubmitReq - if err := c.ShouldBindJSON(&req); err != nil { - return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false) - } - - return ValidateTaskRequestWithImage(c, info, req) -} diff --git a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx index 766c17158..b63c7dd4f 100644 --- a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx +++ b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx @@ -35,8 +35,9 @@ import { Sparkles, } from 'lucide-react'; import { - TASK_ACTION_GENERATE, - TASK_ACTION_TEXT_GENERATE, + TASK_ACTION_FIRST_TAIL_GENERATE, + TASK_ACTION_GENERATE, TASK_ACTION_REFERENCE_GENERATE, + TASK_ACTION_TEXT_GENERATE } from '../../../constants/common.constant'; import { CHANNEL_OPTIONS } from '../../../constants/channel.constants'; @@ -111,6 +112,18 @@ const renderType = (type, t) => { {t('文生视频')} ); + case TASK_ACTION_FIRST_TAIL_GENERATE: + return ( + }> + {t('首尾生视频')} + + ); + case TASK_ACTION_REFERENCE_GENERATE: + return ( + }> + {t('参照生视频')} + + ); default: return ( }> @@ -343,7 +356,9 @@ export const getTaskLogsColumns = ({ // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接 const isVideoTask = record.action === TASK_ACTION_GENERATE || - record.action === TASK_ACTION_TEXT_GENERATE; + record.action === TASK_ACTION_TEXT_GENERATE || + record.action === TASK_ACTION_FIRST_TAIL_GENERATE || + record.action === TASK_ACTION_REFERENCE_GENERATE; const isSuccess = record.status === 'SUCCESS'; const isUrl = typeof text === 'string' && /^https?:\/\//.test(text); if (isSuccess && isVideoTask && isUrl) { diff --git a/web/src/constants/common.constant.js b/web/src/constants/common.constant.js index 277bb9a54..57fbbbde5 100644 --- a/web/src/constants/common.constant.js +++ b/web/src/constants/common.constant.js @@ -40,3 +40,5 @@ export const API_ENDPOINTS = [ export const TASK_ACTION_GENERATE = 'generate'; export const TASK_ACTION_TEXT_GENERATE = 'textGenerate'; +export const TASK_ACTION_FIRST_TAIL_GENERATE = 'firstTailGenerate'; +export const TASK_ACTION_REFERENCE_GENERATE = 'referenceGenerate';