diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 955e592a2..2bc45c547 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -18,7 +18,6 @@ import ( "github.com/gin-gonic/gin" "github.com/pkg/errors" - "one-api/common" "one-api/constant" "one-api/dto" "one-api/relay/channel" @@ -89,22 +88,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { // ValidateRequestAndSetAction parses body, validates fields and sets default action. func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { // Accept only POST /v1/video/generations as "generate" action. - action := constant.TaskActionGenerate - info.Action = action - - req := relaycommon.TaskSubmitReq{} - if err := common.UnmarshalBodyReusable(c, &req); err != nil { - taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) - return - } - if strings.TrimSpace(req.Prompt) == "" { - taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) - return - } - - // Store into context for later usage - c.Set("task_request", req) - return nil + return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) } // BuildRequestURL constructs the upstream URL. @@ -334,11 +318,11 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* } // Handle one-of image_urls or binary_data_base64 - if req.Image != "" { - if strings.HasPrefix(req.Image, "http") { - r.ImageUrls = []string{req.Image} + if req.HasImage() { + if strings.HasPrefix(req.Images[0], "http") { + r.ImageUrls = req.Images } else { - r.BinaryDataBase64 = []string{req.Image} + r.BinaryDataBase64 = req.Images } } metadata := req.Metadata diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 3d6da253b..13f2af972 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -16,7 +16,6 @@ import ( "github.com/golang-jwt/jwt" "github.com/pkg/errors" - "one-api/common" "one-api/constant" "one-api/dto" "one-api/relay/channel" @@ -28,16 +27,6 @@ import ( // Request / Response structures // ============================ -type SubmitReq struct { - Prompt string `json:"prompt"` - Model string `json:"model,omitempty"` - Mode string `json:"mode,omitempty"` - Image string `json:"image,omitempty"` - Size string `json:"size,omitempty"` - Duration int `json:"duration,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` -} - type TrajectoryPoint struct { X int `json:"x"` Y int `json:"y"` @@ -121,23 +110,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { // ValidateRequestAndSetAction parses body, validates fields and sets default action. func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { - // Accept only POST /v1/video/generations as "generate" action. - action := constant.TaskActionGenerate - info.Action = action - - var req SubmitReq - if err := common.UnmarshalBodyReusable(c, &req); err != nil { - taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) - return - } - if strings.TrimSpace(req.Prompt) == "" { - taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) - return - } - - // Store into context for later usage - c.Set("task_request", req) - return nil + // Use the standard validation method for TaskSubmitReq + return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) } // BuildRequestURL constructs the upstream URL. @@ -166,7 +140,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if !exists { return nil, fmt.Errorf("request not found in context") } - req := v.(SubmitReq) + req := v.(relaycommon.TaskSubmitReq) body, err := a.convertToRequestPayload(&req) if err != nil { @@ -255,7 +229,7 @@ func (a *TaskAdaptor) GetChannelName() string { // helpers // ============================ -func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { r := requestPayload{ Prompt: req.Prompt, Image: req.Image, diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index c82c1c0e8..a1140d1e7 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -23,16 +23,6 @@ import ( // Request / Response structures // ============================ -type SubmitReq struct { - Prompt string `json:"prompt"` - Model string `json:"model,omitempty"` - Mode string `json:"mode,omitempty"` - Image string `json:"image,omitempty"` - Size string `json:"size,omitempty"` - Duration int `json:"duration,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` -} - type requestPayload struct { Model string `json:"model"` Images []string `json:"images"` @@ -90,23 +80,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { - var req SubmitReq - if err := c.ShouldBindJSON(&req); err != nil { - return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest) - } - - if req.Prompt == "" { - return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest) - } - - if req.Image != "" { - info.Action = constant.TaskActionGenerate - } else { - info.Action = constant.TaskActionTextGenerate - } - - c.Set("task_request", req) - return nil + // Use the unified validation method for TaskSubmitReq with image-based action determination + return relaycommon.ValidateTaskRequestWithImageBinding(c, info) } func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { @@ -114,7 +89,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) if !exists { return nil, fmt.Errorf("request not found in context") } - req := v.(SubmitReq) + req := v.(relaycommon.TaskSubmitReq) body, err := a.convertToRequestPayload(&req) if err != nil { @@ -211,7 +186,7 @@ func (a *TaskAdaptor) GetChannelName() string { // helpers // ============================ -func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { var images []string if req.Image != "" { images = []string{req.Image} diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index da572c070..99925dc5d 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -481,11 +481,20 @@ type TaskSubmitReq struct { Model string `json:"model,omitempty"` Mode string `json:"mode,omitempty"` Image string `json:"image,omitempty"` + Images []string `json:"images,omitempty"` Size string `json:"size,omitempty"` Duration int `json:"duration,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"` } +func (t TaskSubmitReq) GetPrompt() string { + return t.Prompt +} + +func (t TaskSubmitReq) HasImage() bool { + return len(t.Images) > 0 +} + type TaskInfo struct { Code int `json:"code"` TaskID string `json:"task_id"` diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 3d5efcb6d..cf6d08dda 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -2,12 +2,23 @@ package common import ( "fmt" + "net/http" + "one-api/common" "one-api/constant" + "one-api/dto" "strings" "github.com/gin-gonic/gin" ) +type HasPrompt interface { + GetPrompt() string +} + +type HasImage interface { + HasImage() bool +} + func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) @@ -30,3 +41,72 @@ func GetAPIVersion(c *gin.Context) string { } return apiVersion } + +func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError { + return &dto.TaskError{ + Code: code, + Message: err.Error(), + StatusCode: statusCode, + LocalError: localError, + Error: err, + } +} + +func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) { + info.Action = action + c.Set("task_request", requestObj) +} + +func validatePrompt(prompt string) *dto.TaskError { + if strings.TrimSpace(prompt) == "" { + return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true) + } + return nil +} + +func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError { + var req TaskSubmitReq + if err := common.UnmarshalBodyReusable(c, &req); err != nil { + return createTaskError(err, "invalid_request", http.StatusBadRequest, true) + } + + if taskErr := validatePrompt(req.Prompt); taskErr != nil { + return taskErr + } + + if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" { + // 兼容单图上传 + req.Images = []string{req.Image} + } + + storeTaskRequest(c, info, action, req) + return nil +} + +func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError { + hasPrompt, ok := requestObj.(HasPrompt) + if !ok { + return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true) + } + + if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil { + return taskErr + } + + action := constant.TaskActionTextGenerate + if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() { + action = constant.TaskActionGenerate + } + + storeTaskRequest(c, info, action, requestObj) + return nil +} + +func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError { + var req TaskSubmitReq + if err := c.ShouldBindJSON(&req); err != nil { + return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false) + } + + return ValidateTaskRequestWithImage(c, info, req) +}