Merge branch 'feitianbubu-pr/add-jimeng-video-images'

This commit is contained in:
creamlike1024
2025-09-13 09:57:01 +08:00
5 changed files with 102 additions and 80 deletions

View File

@@ -18,7 +18,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkg/errors" "github.com/pkg/errors"
"one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
@@ -89,22 +88,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
// ValidateRequestAndSetAction parses body, validates fields and sets default action. // ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action. // Accept only POST /v1/video/generations as "generate" action.
action := constant.TaskActionGenerate return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
info.Action = action
req := relaycommon.TaskSubmitReq{}
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Prompt) == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
return
}
// Store into context for later usage
c.Set("task_request", req)
return nil
} }
// BuildRequestURL constructs the upstream URL. // BuildRequestURL constructs the upstream URL.
@@ -334,11 +318,11 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
} }
// Handle one-of image_urls or binary_data_base64 // Handle one-of image_urls or binary_data_base64
if req.Image != "" { if req.HasImage() {
if strings.HasPrefix(req.Image, "http") { if strings.HasPrefix(req.Images[0], "http") {
r.ImageUrls = []string{req.Image} r.ImageUrls = req.Images
} else { } else {
r.BinaryDataBase64 = []string{req.Image} r.BinaryDataBase64 = req.Images
} }
} }
metadata := req.Metadata metadata := req.Metadata

View File

@@ -16,7 +16,6 @@ import (
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/pkg/errors" "github.com/pkg/errors"
"one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
@@ -28,16 +27,6 @@ import (
// Request / Response structures // Request / Response structures
// ============================ // ============================
type SubmitReq struct {
Prompt string `json:"prompt"`
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type TrajectoryPoint struct { type TrajectoryPoint struct {
X int `json:"x"` X int `json:"x"`
Y int `json:"y"` Y int `json:"y"`
@@ -121,23 +110,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
// ValidateRequestAndSetAction parses body, validates fields and sets default action. // ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action. // Use the standard validation method for TaskSubmitReq
action := constant.TaskActionGenerate return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
info.Action = action
var req SubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Prompt) == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
return
}
// Store into context for later usage
c.Set("task_request", req)
return nil
} }
// BuildRequestURL constructs the upstream URL. // BuildRequestURL constructs the upstream URL.
@@ -166,7 +140,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
if !exists { if !exists {
return nil, fmt.Errorf("request not found in context") return nil, fmt.Errorf("request not found in context")
} }
req := v.(SubmitReq) req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req) body, err := a.convertToRequestPayload(&req)
if err != nil { if err != nil {
@@ -255,7 +229,7 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers // helpers
// ============================ // ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
r := requestPayload{ r := requestPayload{
Prompt: req.Prompt, Prompt: req.Prompt,
Image: req.Image, Image: req.Image,

View File

@@ -23,16 +23,6 @@ import (
// Request / Response structures // Request / Response structures
// ============================ // ============================
type SubmitReq struct {
Prompt string `json:"prompt"`
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type requestPayload struct { type requestPayload struct {
Model string `json:"model"` Model string `json:"model"`
Images []string `json:"images"` Images []string `json:"images"`
@@ -90,23 +80,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
var req SubmitReq // Use the unified validation method for TaskSubmitReq with image-based action determination
if err := c.ShouldBindJSON(&req); err != nil { return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
}
if req.Prompt == "" {
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
}
if req.Image != "" {
info.Action = constant.TaskActionGenerate
} else {
info.Action = constant.TaskActionTextGenerate
}
c.Set("task_request", req)
return nil
} }
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
@@ -114,7 +89,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo)
if !exists { if !exists {
return nil, fmt.Errorf("request not found in context") return nil, fmt.Errorf("request not found in context")
} }
req := v.(SubmitReq) req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req) body, err := a.convertToRequestPayload(&req)
if err != nil { if err != nil {
@@ -211,7 +186,7 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers // helpers
// ============================ // ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
var images []string var images []string
if req.Image != "" { if req.Image != "" {
images = []string{req.Image} images = []string{req.Image}

View File

@@ -481,11 +481,20 @@ type TaskSubmitReq struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"` Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"` Image string `json:"image,omitempty"`
Images []string `json:"images,omitempty"`
Size string `json:"size,omitempty"` Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"` Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"`
} }
func (t TaskSubmitReq) GetPrompt() string {
return t.Prompt
}
func (t TaskSubmitReq) HasImage() bool {
return len(t.Images) > 0
}
type TaskInfo struct { type TaskInfo struct {
Code int `json:"code"` Code int `json:"code"`
TaskID string `json:"task_id"` TaskID string `json:"task_id"`

View File

@@ -2,12 +2,23 @@ package common
import ( import (
"fmt" "fmt"
"net/http"
"one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type HasPrompt interface {
GetPrompt() string
}
type HasImage interface {
HasImage() bool
}
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
@@ -30,3 +41,72 @@ func GetAPIVersion(c *gin.Context) string {
} }
return apiVersion return apiVersion
} }
func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
return &dto.TaskError{
Code: code,
Message: err.Error(),
StatusCode: statusCode,
LocalError: localError,
Error: err,
}
}
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
info.Action = action
c.Set("task_request", requestObj)
}
func validatePrompt(prompt string) *dto.TaskError {
if strings.TrimSpace(prompt) == "" {
return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
}
return nil
}
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
var req TaskSubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
}
if taskErr := validatePrompt(req.Prompt); taskErr != nil {
return taskErr
}
if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
// 兼容单图上传
req.Images = []string{req.Image}
}
storeTaskRequest(c, info, action, req)
return nil
}
func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
hasPrompt, ok := requestObj.(HasPrompt)
if !ok {
return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
}
if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
return taskErr
}
action := constant.TaskActionTextGenerate
if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() {
action = constant.TaskActionGenerate
}
storeTaskRequest(c, info, action, requestObj)
return nil
}
func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
var req TaskSubmitReq
if err := c.ShouldBindJSON(&req); err != nil {
return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
}
return ValidateTaskRequestWithImage(c, info, req)
}