refactor: use common taskSubmitReq

This commit is contained in:
feitianbubu
2025-09-12 21:52:32 +08:00
parent b712279b2a
commit 6ed775be8f
5 changed files with 92 additions and 76 deletions

View File

@@ -18,7 +18,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkg/errors" "github.com/pkg/errors"
"one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
@@ -89,22 +88,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
// ValidateRequestAndSetAction parses body, validates fields and sets default action. // ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action. // Accept only POST /v1/video/generations as "generate" action.
action := constant.TaskActionGenerate return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
info.Action = action
req := relaycommon.TaskSubmitReq{}
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Prompt) == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
return
}
// Store into context for later usage
c.Set("task_request", req)
return nil
} }
// BuildRequestURL constructs the upstream URL. // BuildRequestURL constructs the upstream URL.

View File

@@ -16,7 +16,6 @@ import (
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/pkg/errors" "github.com/pkg/errors"
"one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
@@ -28,16 +27,6 @@ import (
// Request / Response structures // Request / Response structures
// ============================ // ============================
type SubmitReq struct {
Prompt string `json:"prompt"`
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type TrajectoryPoint struct { type TrajectoryPoint struct {
X int `json:"x"` X int `json:"x"`
Y int `json:"y"` Y int `json:"y"`
@@ -121,23 +110,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
// ValidateRequestAndSetAction parses body, validates fields and sets default action. // ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action. // Use the standard validation method for TaskSubmitReq
action := constant.TaskActionGenerate return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
info.Action = action
var req SubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Prompt) == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
return
}
// Store into context for later usage
c.Set("task_request", req)
return nil
} }
// BuildRequestURL constructs the upstream URL. // BuildRequestURL constructs the upstream URL.
@@ -166,7 +140,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
if !exists { if !exists {
return nil, fmt.Errorf("request not found in context") return nil, fmt.Errorf("request not found in context")
} }
req := v.(SubmitReq) req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req) body, err := a.convertToRequestPayload(&req)
if err != nil { if err != nil {
@@ -255,7 +229,7 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers // helpers
// ============================ // ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
r := requestPayload{ r := requestPayload{
Prompt: req.Prompt, Prompt: req.Prompt,
Image: req.Image, Image: req.Image,

View File

@@ -23,16 +23,6 @@ import (
// Request / Response structures // Request / Response structures
// ============================ // ============================
type SubmitReq struct {
Prompt string `json:"prompt"`
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type requestPayload struct { type requestPayload struct {
Model string `json:"model"` Model string `json:"model"`
Images []string `json:"images"` Images []string `json:"images"`
@@ -90,23 +80,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
var req SubmitReq // Use the unified validation method for TaskSubmitReq with image-based action determination
if err := c.ShouldBindJSON(&req); err != nil { return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
}
if req.Prompt == "" {
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
}
if req.Image != "" {
info.Action = constant.TaskActionGenerate
} else {
info.Action = constant.TaskActionTextGenerate
}
c.Set("task_request", req)
return nil
} }
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
@@ -114,7 +89,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo)
if !exists { if !exists {
return nil, fmt.Errorf("request not found in context") return nil, fmt.Errorf("request not found in context")
} }
req := v.(SubmitReq) req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req) body, err := a.convertToRequestPayload(&req)
if err != nil { if err != nil {
@@ -211,7 +186,7 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers // helpers
// ============================ // ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
var images []string var images []string
if req.Image != "" { if req.Image != "" {
images = []string{req.Image} images = []string{req.Image}

View File

@@ -486,6 +486,14 @@ type TaskSubmitReq struct {
Metadata map[string]interface{} `json:"metadata,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"`
} }
func (t TaskSubmitReq) GetPrompt() string {
return t.Prompt
}
func (t TaskSubmitReq) GetImage() string {
return t.Image
}
type TaskInfo struct { type TaskInfo struct {
Code int `json:"code"` Code int `json:"code"`
TaskID string `json:"task_id"` TaskID string `json:"task_id"`

View File

@@ -2,12 +2,23 @@ package common
import ( import (
"fmt" "fmt"
"net/http"
"one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type HasPrompt interface {
GetPrompt() string
}
type HasImage interface {
GetImage() string
}
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
@@ -30,3 +41,67 @@ func GetAPIVersion(c *gin.Context) string {
} }
return apiVersion return apiVersion
} }
func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
return &dto.TaskError{
Code: code,
Message: err.Error(),
StatusCode: statusCode,
LocalError: localError,
Error: err,
}
}
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
info.Action = action
c.Set("task_request", requestObj)
}
func validatePrompt(prompt string) *dto.TaskError {
if strings.TrimSpace(prompt) == "" {
return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
}
return nil
}
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
var req TaskSubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
}
if taskErr := validatePrompt(req.Prompt); taskErr != nil {
return taskErr
}
storeTaskRequest(c, info, action, req)
return nil
}
func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
hasPrompt, ok := requestObj.(HasPrompt)
if !ok {
return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
}
if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
return taskErr
}
action := constant.TaskActionTextGenerate
if hasImage, ok := requestObj.(HasImage); ok && strings.TrimSpace(hasImage.GetImage()) != "" {
action = constant.TaskActionGenerate
}
storeTaskRequest(c, info, action, requestObj)
return nil
}
func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
var req TaskSubmitReq
if err := c.ShouldBindJSON(&req); err != nil {
return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
}
return ValidateTaskRequestWithImage(c, info, req)
}