Files
new-api/relay/common/relay_utils.go

258 lines
6.7 KiB
Go

package common
import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
)
type (
	// HasPrompt is satisfied by any type that exposes a prompt string through
	// a GetPrompt method.
	HasPrompt interface {
		GetPrompt() string
	}

	// HasImage is satisfied by any type that can report, via a HasImage
	// method, whether it carries image input.
	HasImage interface {
		HasImage() bool
	}
)
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case constant.ChannelTypeOpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case constant.ChannelTypeAzure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}
// GetAPIVersion returns the "api-version" query parameter of the current
// request. When that parameter is absent or empty, it falls back to the string
// stored in the gin context under "api_version" (empty if unset).
func GetAPIVersion(c *gin.Context) string {
	if fromQuery := c.Request.URL.Query().Get("api-version"); fromQuery != "" {
		return fromQuery
	}
	return c.GetString("api_version")
}
// createTaskError wraps err in a *dto.TaskError.
//
// The result carries the given error code, HTTP status code and local-error
// flag, and its Message is err.Error(). err must be non-nil, because
// err.Error() is called unconditionally.
func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
	taskErr := &dto.TaskError{Error: err, Message: err.Error()}
	taskErr.Code = code
	taskErr.StatusCode = statusCode
	taskErr.LocalError = localError
	return taskErr
}
// storeTaskRequest records action on info.Action and saves requestObj in the
// gin context under the "task_request" key, where it can be read back with
// c.Get.
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) {
	const taskRequestKey = "task_request"
	c.Set(taskRequestKey, requestObj)
	info.Action = action
}
// validatePrompt rejects a prompt that is empty or contains only whitespace.
//
// In that case it returns a local HTTP 400 task error with code
// "invalid_request". Any other prompt yields nil.
func validatePrompt(prompt string) *dto.TaskError {
	if len(strings.TrimSpace(prompt)) > 0 {
		return nil
	}
	return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
}
// validateMultipartTaskRequest parses a multipart/form-data task submission
// into a TaskSubmitReq.
//
// Field handling:
//   - prompt, model, mode, image and size are copied into the matching fields.
//   - "seconds" is parsed with strconv.Atoi into Duration; unparseable values
//     are ignored.
//   - All "images" values are collected into Images.
//   - Every field for which isKnownTaskField reports false is copied into
//     Metadata. Its first value is stored as an int if it parses as one,
//     otherwise as a float64 if it parses as one, otherwise as the raw string.
//
// The form is read with c.MultipartForm and the fields come from
// c.Request.PostForm. info and action are currently unused. When form parsing
// fails, the zero TaskSubmitReq is returned with the error.
func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) {
	if _, parseErr := c.MultipartForm(); parseErr != nil {
		return TaskSubmitReq{}, parseErr
	}
	fields := c.Request.PostForm

	// inferType tries int, then float64, and falls back to the raw string.
	inferType := func(raw string) interface{} {
		if asInt, convErr := strconv.Atoi(raw); convErr == nil {
			return asInt
		}
		if asFloat, convErr := strconv.ParseFloat(raw, 64); convErr == nil {
			return asFloat
		}
		return raw
	}

	out := TaskSubmitReq{
		Prompt:   fields.Get("prompt"),
		Model:    fields.Get("model"),
		Mode:     fields.Get("mode"),
		Image:    fields.Get("image"),
		Size:     fields.Get("size"),
		Metadata: make(map[string]interface{}),
	}

	if rawSeconds := fields.Get("seconds"); rawSeconds != "" {
		if secs, convErr := strconv.Atoi(rawSeconds); convErr == nil {
			out.Duration = secs
		}
	}

	if imgs := fields["images"]; len(imgs) > 0 {
		out.Images = imgs
	}

	for name, vals := range fields {
		if len(vals) == 0 || isKnownTaskField(name) {
			continue
		}
		out.Metadata[name] = inferType(vals[0])
	}

	return out, nil
}
// ValidateMultipartDirect validates a video-task submission that arrives
// either as multipart/form-data or as a JSON body, and records the derived
// task action on info.Action.
//
// Both encodings must supply a non-blank prompt and a non-blank model. The
// action is TaskActionTextGenerate unless the request carries reference
// input (a multipart "input_reference" file, or an image in the JSON body, as
// reported by HasImage); in that case it is TaskActionGenerate.
//
// For models whose name starts with "sora-2":
//   - size defaults to "720x1280" and seconds to 4 when absent or non-positive.
//   - "sora-2" and "sora-2-pro" have their size checked against an allow-list.
//   - info.PriceData.OtherRatios is replaced with a "seconds" multiplier (the
//     duration) and a "size" multiplier (1, or 1.666667 for 1792x1024 and
//     1024x1792).
//
// It returns nil on success, or a local HTTP 400 *dto.TaskError describing
// the first validation failure.
func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
	var (
		prompt            string
		model             string
		seconds           int
		size              string
		hasInputReference bool
	)

	if strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data") {
		form, err := common.ParseMultipartFormReusable(c)
		if err != nil {
			return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
		}
		defer form.RemoveAll()

		// firstValue returns the first value of a form field. It reports false
		// when the field is absent or has no values, so indexing never panics.
		firstValue := func(key string) (string, bool) {
			if vals := form.Value[key]; len(vals) > 0 {
				return vals[0], true
			}
			return "", false
		}

		var hasPrompt bool
		prompt, hasPrompt = firstValue("prompt")
		if !hasPrompt {
			return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true)
		}
		// A missing model is rejected by the shared blank-model check below.
		model, _ = firstValue("model")
		_, hasInputReference = form.File["input_reference"]
		if rawSeconds, found := firstValue("seconds"); found {
			// Only positive durations are kept; otherwise seconds stays 0.
			if n := common.String2Int(rawSeconds); n > 0 {
				seconds = n
			}
		}
		size, _ = firstValue("size")
	} else {
		var req TaskSubmitReq
		if err := common.UnmarshalBodyReusable(c, &req); err != nil {
			return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
		}
		prompt = req.Prompt
		model = req.Model
		// Without this, JSON requests always fell back to the default size,
		// bypassing size validation and mispricing large resolutions.
		size = req.Size
		seconds = req.Duration
		hasInputReference = req.HasImage()
	}

	// Apply the same blank-model rule to both encodings; previously only the
	// JSON path rejected a present-but-empty model.
	if strings.TrimSpace(model) == "" {
		return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
	}
	if taskErr := validatePrompt(prompt); taskErr != nil {
		return taskErr
	}

	action := constant.TaskActionTextGenerate
	if hasInputReference {
		action = constant.TaskActionGenerate
	}

	if strings.HasPrefix(model, "sora-2") {
		if size == "" {
			size = "720x1280"
		}
		if seconds <= 0 {
			seconds = 4
		}
		if model == "sora-2" && !lo.Contains([]string{"720x1280", "1280x720"}, size) {
			return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
		}
		if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
			return createTaskError(fmt.Errorf("sora-2-pro size is invalid"), "invalid_size", http.StatusBadRequest, true)
		}
		// Price multiplier for the two large resolutions.
		// NOTE(review): presumably mirrors upstream per-second pricing — confirm.
		sizeRatio := 1.0
		if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
			sizeRatio = 1.666667
		}
		info.PriceData.OtherRatios = map[string]float64{
			"seconds": float64(seconds),
			"size":    sizeRatio,
		}
	}

	info.Action = action
	return nil
}
// isKnownTaskField reports whether a multipart form field name maps to a
// dedicated TaskSubmitReq field. Such fields are excluded from the
// free-form Metadata map.
//
// A switch is used instead of building a lookup map on every call, because
// this runs once per submitted form field.
//
// Note: "seconds" is not listed. It is therefore also copied into Metadata
// by the multipart parser, in addition to being parsed into Duration.
func isKnownTaskField(field string) bool {
	switch field {
	case "prompt",
		"model",
		"mode",
		"image",
		"images",
		"size",
		"duration",
		"input_reference": // Sora-specific field
		return true
	default:
		return false
	}
}
// ValidateBasicTaskRequest validates a task submission and stores it on the
// gin context via storeTaskRequest.
//
// Steps:
//   - Parse the body as multipart/form-data (when the Content-Type says so)
//     or as JSON.
//   - Require a non-blank prompt.
//   - Derive the final task action.
//
// action is the caller-supplied default. It becomes TaskActionGenerate when
// the request carries an image. For Vidu channels it is refined further:
// TaskActionFirstTailGenerate for exactly two images, or
// TaskActionReferenceGenerate for more than two.
//
// Parse failures return a local HTTP 400 *dto.TaskError: "invalid_multipart_form"
// for multipart bodies, "invalid_request" for JSON bodies.
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
	var req TaskSubmitReq
	if strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data") {
		parsed, err := validateMultipartTaskRequest(c, info, action)
		if err != nil {
			return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
		}
		req = parsed
	} else if err := common.UnmarshalBodyReusable(c, &req); err != nil {
		return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
	}

	if taskErr := validatePrompt(req.Prompt); taskErr != nil {
		return taskErr
	}

	// Backward compatibility with single-image uploads: a lone "image" value
	// is promoted to a one-element Images list.
	if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
		req.Images = []string{req.Image}
	}

	if req.HasImage() {
		action = constant.TaskActionGenerate
		if info.ChannelType == constant.ChannelTypeVidu {
			// Vidu adds first/last-frame-to-video and reference-image-to-video modes.
			switch n := len(req.Images); {
			case n == 2:
				action = constant.TaskActionFirstTailGenerate
			case n > 2:
				action = constant.TaskActionReferenceGenerate
			}
		}
	}

	storeTaskRequest(c, info, action, req)
	return nil
}