Merge branch 'feitianbubu-pr/add-jimeng-video-images'

This commit is contained in:
creamlike1024
2025-09-13 09:57:01 +08:00
5 changed files with 102 additions and 80 deletions

View File

@@ -18,7 +18,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
@@ -89,22 +88,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action.
action := constant.TaskActionGenerate
info.Action = action
req := relaycommon.TaskSubmitReq{}
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Prompt) == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
return
}
// Store into context for later usage
c.Set("task_request", req)
return nil
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
}
// BuildRequestURL constructs the upstream URL.
@@ -334,11 +318,11 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
}
// Handle one-of image_urls or binary_data_base64
if req.Image != "" {
if strings.HasPrefix(req.Image, "http") {
r.ImageUrls = []string{req.Image}
if req.HasImage() {
if strings.HasPrefix(req.Images[0], "http") {
r.ImageUrls = req.Images
} else {
r.BinaryDataBase64 = []string{req.Image}
r.BinaryDataBase64 = req.Images
}
}
metadata := req.Metadata

View File

@@ -16,7 +16,6 @@ import (
"github.com/golang-jwt/jwt"
"github.com/pkg/errors"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
@@ -28,16 +27,6 @@ import (
// Request / Response structures
// ============================
type SubmitReq struct {
Prompt string `json:"prompt"`
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type TrajectoryPoint struct {
X int `json:"x"`
Y int `json:"y"`
@@ -121,23 +110,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action.
action := constant.TaskActionGenerate
info.Action = action
var req SubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Prompt) == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
return
}
// Store into context for later usage
c.Set("task_request", req)
return nil
// Use the standard validation method for TaskSubmitReq
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
}
// BuildRequestURL constructs the upstream URL.
@@ -166,7 +140,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
if !exists {
return nil, fmt.Errorf("request not found in context")
}
req := v.(SubmitReq)
req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req)
if err != nil {
@@ -255,7 +229,7 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
r := requestPayload{
Prompt: req.Prompt,
Image: req.Image,

View File

@@ -23,16 +23,6 @@ import (
// Request / Response structures
// ============================
type SubmitReq struct {
Prompt string `json:"prompt"`
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type requestPayload struct {
Model string `json:"model"`
Images []string `json:"images"`
@@ -90,23 +80,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
var req SubmitReq
if err := c.ShouldBindJSON(&req); err != nil {
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
}
if req.Prompt == "" {
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
}
if req.Image != "" {
info.Action = constant.TaskActionGenerate
} else {
info.Action = constant.TaskActionTextGenerate
}
c.Set("task_request", req)
return nil
// Use the unified validation method for TaskSubmitReq with image-based action determination
return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
}
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
@@ -114,7 +89,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo)
if !exists {
return nil, fmt.Errorf("request not found in context")
}
req := v.(SubmitReq)
req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req)
if err != nil {
@@ -211,7 +186,7 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
var images []string
if req.Image != "" {
images = []string{req.Image}

View File

@@ -481,11 +481,20 @@ type TaskSubmitReq struct {
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Images []string `json:"images,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
func (t TaskSubmitReq) GetPrompt() string {
return t.Prompt
}
func (t TaskSubmitReq) HasImage() bool {
return len(t.Images) > 0
}
type TaskInfo struct {
Code int `json:"code"`
TaskID string `json:"task_id"`

View File

@@ -2,12 +2,23 @@ package common
import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"strings"
"github.com/gin-gonic/gin"
)
type HasPrompt interface {
GetPrompt() string
}
type HasImage interface {
HasImage() bool
}
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
@@ -30,3 +41,72 @@ func GetAPIVersion(c *gin.Context) string {
}
return apiVersion
}
func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
return &dto.TaskError{
Code: code,
Message: err.Error(),
StatusCode: statusCode,
LocalError: localError,
Error: err,
}
}
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
info.Action = action
c.Set("task_request", requestObj)
}
func validatePrompt(prompt string) *dto.TaskError {
if strings.TrimSpace(prompt) == "" {
return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
}
return nil
}
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
var req TaskSubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
}
if taskErr := validatePrompt(req.Prompt); taskErr != nil {
return taskErr
}
if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
// 兼容单图上传
req.Images = []string{req.Image}
}
storeTaskRequest(c, info, action, req)
return nil
}
func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
hasPrompt, ok := requestObj.(HasPrompt)
if !ok {
return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
}
if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
return taskErr
}
action := constant.TaskActionTextGenerate
if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() {
action = constant.TaskActionGenerate
}
storeTaskRequest(c, info, action, requestObj)
return nil
}
func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
var req TaskSubmitReq
if err := c.ShouldBindJSON(&req); err != nil {
return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
}
return ValidateTaskRequestWithImage(c, info, req)
}