mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-19 09:58:38 +00:00
refactor: use common taskSubmitReq
This commit is contained in:
@@ -18,7 +18,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
@@ -89,22 +88,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||||
// Accept only POST /v1/video/generations as "generate" action.
|
// Accept only POST /v1/video/generations as "generate" action.
|
||||||
action := constant.TaskActionGenerate
|
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||||
info.Action = action
|
|
||||||
|
|
||||||
req := relaycommon.TaskSubmitReq{}
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
|
||||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(req.Prompt) == "" {
|
|
||||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store into context for later usage
|
|
||||||
c.Set("task_request", req)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildRequestURL constructs the upstream URL.
|
// BuildRequestURL constructs the upstream URL.
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
@@ -28,16 +27,6 @@ import (
|
|||||||
// Request / Response structures
|
// Request / Response structures
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
type SubmitReq struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Model string `json:"model,omitempty"`
|
|
||||||
Mode string `json:"mode,omitempty"`
|
|
||||||
Image string `json:"image,omitempty"`
|
|
||||||
Size string `json:"size,omitempty"`
|
|
||||||
Duration int `json:"duration,omitempty"`
|
|
||||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TrajectoryPoint struct {
|
type TrajectoryPoint struct {
|
||||||
X int `json:"x"`
|
X int `json:"x"`
|
||||||
Y int `json:"y"`
|
Y int `json:"y"`
|
||||||
@@ -121,23 +110,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
|
|
||||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||||
// Accept only POST /v1/video/generations as "generate" action.
|
// Use the standard validation method for TaskSubmitReq
|
||||||
action := constant.TaskActionGenerate
|
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||||
info.Action = action
|
|
||||||
|
|
||||||
var req SubmitReq
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
|
||||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(req.Prompt) == "" {
|
|
||||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store into context for later usage
|
|
||||||
c.Set("task_request", req)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildRequestURL constructs the upstream URL.
|
// BuildRequestURL constructs the upstream URL.
|
||||||
@@ -166,7 +140,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
if !exists {
|
if !exists {
|
||||||
return nil, fmt.Errorf("request not found in context")
|
return nil, fmt.Errorf("request not found in context")
|
||||||
}
|
}
|
||||||
req := v.(SubmitReq)
|
req := v.(relaycommon.TaskSubmitReq)
|
||||||
|
|
||||||
body, err := a.convertToRequestPayload(&req)
|
body, err := a.convertToRequestPayload(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -255,7 +229,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
|||||||
// helpers
|
// helpers
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||||
r := requestPayload{
|
r := requestPayload{
|
||||||
Prompt: req.Prompt,
|
Prompt: req.Prompt,
|
||||||
Image: req.Image,
|
Image: req.Image,
|
||||||
|
|||||||
@@ -23,16 +23,6 @@ import (
|
|||||||
// Request / Response structures
|
// Request / Response structures
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
type SubmitReq struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Model string `json:"model,omitempty"`
|
|
||||||
Mode string `json:"mode,omitempty"`
|
|
||||||
Image string `json:"image,omitempty"`
|
|
||||||
Size string `json:"size,omitempty"`
|
|
||||||
Duration int `json:"duration,omitempty"`
|
|
||||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type requestPayload struct {
|
type requestPayload struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Images []string `json:"images"`
|
Images []string `json:"images"`
|
||||||
@@ -90,23 +80,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
||||||
var req SubmitReq
|
// Use the unified validation method for TaskSubmitReq with image-based action determination
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
|
||||||
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Prompt == "" {
|
|
||||||
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Image != "" {
|
|
||||||
info.Action = constant.TaskActionGenerate
|
|
||||||
} else {
|
|
||||||
info.Action = constant.TaskActionTextGenerate
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Set("task_request", req)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
|
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
|
||||||
@@ -114,7 +89,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo)
|
|||||||
if !exists {
|
if !exists {
|
||||||
return nil, fmt.Errorf("request not found in context")
|
return nil, fmt.Errorf("request not found in context")
|
||||||
}
|
}
|
||||||
req := v.(SubmitReq)
|
req := v.(relaycommon.TaskSubmitReq)
|
||||||
|
|
||||||
body, err := a.convertToRequestPayload(&req)
|
body, err := a.convertToRequestPayload(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -211,7 +186,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
|||||||
// helpers
|
// helpers
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||||
var images []string
|
var images []string
|
||||||
if req.Image != "" {
|
if req.Image != "" {
|
||||||
images = []string{req.Image}
|
images = []string{req.Image}
|
||||||
|
|||||||
@@ -486,6 +486,14 @@ type TaskSubmitReq struct {
|
|||||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t TaskSubmitReq) GetPrompt() string {
|
||||||
|
return t.Prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t TaskSubmitReq) GetImage() string {
|
||||||
|
return t.Image
|
||||||
|
}
|
||||||
|
|
||||||
type TaskInfo struct {
|
type TaskInfo struct {
|
||||||
Code int `json:"code"`
|
Code int `json:"code"`
|
||||||
TaskID string `json:"task_id"`
|
TaskID string `json:"task_id"`
|
||||||
|
|||||||
@@ -2,12 +2,23 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type HasPrompt interface {
|
||||||
|
GetPrompt() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type HasImage interface {
|
||||||
|
GetImage() string
|
||||||
|
}
|
||||||
|
|
||||||
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
|
|
||||||
@@ -30,3 +41,67 @@ func GetAPIVersion(c *gin.Context) string {
|
|||||||
}
|
}
|
||||||
return apiVersion
|
return apiVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
|
||||||
|
return &dto.TaskError{
|
||||||
|
Code: code,
|
||||||
|
Message: err.Error(),
|
||||||
|
StatusCode: statusCode,
|
||||||
|
LocalError: localError,
|
||||||
|
Error: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
|
||||||
|
info.Action = action
|
||||||
|
c.Set("task_request", requestObj)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validatePrompt(prompt string) *dto.TaskError {
|
||||||
|
if strings.TrimSpace(prompt) == "" {
|
||||||
|
return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
|
||||||
|
var req TaskSubmitReq
|
||||||
|
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||||
|
return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
if taskErr := validatePrompt(req.Prompt); taskErr != nil {
|
||||||
|
return taskErr
|
||||||
|
}
|
||||||
|
|
||||||
|
storeTaskRequest(c, info, action, req)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
|
||||||
|
hasPrompt, ok := requestObj.(HasPrompt)
|
||||||
|
if !ok {
|
||||||
|
return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
|
||||||
|
return taskErr
|
||||||
|
}
|
||||||
|
|
||||||
|
action := constant.TaskActionTextGenerate
|
||||||
|
if hasImage, ok := requestObj.(HasImage); ok && strings.TrimSpace(hasImage.GetImage()) != "" {
|
||||||
|
action = constant.TaskActionGenerate
|
||||||
|
}
|
||||||
|
|
||||||
|
storeTaskRequest(c, info, action, requestObj)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
|
||||||
|
var req TaskSubmitReq
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ValidateTaskRequestWithImage(c, info, req)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user