mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:05:21 +00:00
feat: add sora video submit task
This commit is contained in:
@@ -3,6 +3,7 @@ package common
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
@@ -113,3 +114,26 @@ func ApiSuccess(c *gin.Context, data any) {
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
|
||||
requestBody, err := GetRequestBody(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
boundary := ""
|
||||
if idx := strings.Index(contentType, "boundary="); idx != -1 {
|
||||
boundary = contentType[idx+9:]
|
||||
}
|
||||
|
||||
reader := multipart.NewReader(bytes.NewReader(requestBody), boundary)
|
||||
form, err := reader.ReadForm(32 << 20) // 32 MB max memory
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reset request body
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return form, nil
|
||||
}
|
||||
|
||||
@@ -52,6 +52,7 @@ const (
|
||||
ChannelTypeVidu = 52
|
||||
ChannelTypeSubmodel = 53
|
||||
ChannelTypeDoubaoVideo = 54
|
||||
ChannelTypeSora = 55
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
@@ -112,6 +113,7 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.vidu.cn", //52
|
||||
"https://llm.submodel.ai", //53
|
||||
"https://ark.cn-beijing.volces.com", //54
|
||||
"https://api.openai.com", //55
|
||||
}
|
||||
|
||||
var ChannelTypeNames = map[int]string{
|
||||
@@ -166,6 +168,7 @@ var ChannelTypeNames = map[int]string{
|
||||
ChannelTypeVidu: "Vidu",
|
||||
ChannelTypeSubmodel: "Submodel",
|
||||
ChannelTypeDoubaoVideo: "DoubaoVideo",
|
||||
ChannelTypeSora: "Sora",
|
||||
}
|
||||
|
||||
func GetChannelTypeName(channelType int) string {
|
||||
|
||||
@@ -233,6 +233,16 @@ type Usage struct {
|
||||
Cost any `json:"cost,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIVideoResponse struct {
|
||||
Id string `json:"id" example:"file-abc123"`
|
||||
Object string `json:"object" example:"file"`
|
||||
Bytes int64 `json:"bytes" example:"120000"`
|
||||
CreatedAt int64 `json:"created_at" example:"1677610602"`
|
||||
ExpiresAt int64 `json:"expires_at" example:"1677614202"`
|
||||
Filename string `json:"filename" example:"mydata.jsonl"`
|
||||
Purpose string `json:"purpose" example:"fine-tune"`
|
||||
}
|
||||
|
||||
type InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
CachedCreationTokens int `json:"-"`
|
||||
|
||||
@@ -165,6 +165,18 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/videos") {
|
||||
//curl https://api.openai.com/v1/videos \
|
||||
// -H "Authorization: Bearer $OPENAI_API_KEY" \
|
||||
// -F "model=sora-2" \
|
||||
// -F "prompt=A calico cat playing a piano on stage"
|
||||
// -F input_reference="@image.jpg"
|
||||
relayMode := relayconstant.RelayModeUnknown
|
||||
if c.Request.Method == http.MethodPost {
|
||||
relayMode = relayconstant.RelayModeVideoSubmit
|
||||
modelRequest.Model = c.PostForm("model")
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||
relayMode := relayconstant.RelayModeUnknown
|
||||
if c.Request.Method == http.MethodPost {
|
||||
|
||||
192
relay/channel/task/sora/adaptor.go
Normal file
192
relay/channel/task/sora/adaptor.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package sora
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type ContentItem struct {
|
||||
Type string `json:"type"` // "text" or "image_url"
|
||||
Text string `json:"text,omitempty"` // for text type
|
||||
ImageURL *ImageURL `json:"image_url,omitempty"` // for image_url type
|
||||
}
|
||||
|
||||
type ImageURL struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
ID string `json:"id"` // task_id
|
||||
}
|
||||
|
||||
type responseTask struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Status string `json:"status"`
|
||||
Content struct {
|
||||
VideoURL string `json:"video_url"`
|
||||
} `json:"content"`
|
||||
Seed int `json:"seed"`
|
||||
Resolution string `json:"resolution"`
|
||||
Duration int `json:"duration"`
|
||||
AspectRatio string `json:"aspect_ratio"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.ChannelBaseUrl
|
||||
a.apiKey = info.ApiKey
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
return relaycommon.ValidateMultipartDirect(c, info)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/videos", a.baseURL), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Authorization", "Bearer "+a.apiKey)
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
cachedBody, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get_request_body_failed")
|
||||
}
|
||||
return bytes.NewReader(cachedBody), nil
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
// Parse Sora response
|
||||
var dResp responsePayload
|
||||
if err := json.Unmarshal(responseBody, &dResp); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if dResp.ID == "" {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": dResp.ID})
|
||||
return dResp.ID, responseBody, nil
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf("%s/v1/videos/generations/%s", baseUrl, taskID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, uri, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
resTask := responseTask{}
|
||||
if err := json.Unmarshal(respBody, &resTask); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||||
}
|
||||
|
||||
taskResult := relaycommon.TaskInfo{
|
||||
Code: 0,
|
||||
}
|
||||
|
||||
// Map Sora status to internal status
|
||||
switch resTask.Status {
|
||||
case "pending", "queued":
|
||||
taskResult.Status = model.TaskStatusQueued
|
||||
taskResult.Progress = "10%"
|
||||
case "processing", "running":
|
||||
taskResult.Status = model.TaskStatusInProgress
|
||||
taskResult.Progress = "50%"
|
||||
case "succeeded", "completed":
|
||||
taskResult.Status = model.TaskStatusSuccess
|
||||
taskResult.Progress = "100%"
|
||||
taskResult.Url = resTask.Content.VideoURL
|
||||
// Parse usage information for billing
|
||||
taskResult.CompletionTokens = resTask.Usage.CompletionTokens
|
||||
taskResult.TotalTokens = resTask.Usage.TotalTokens
|
||||
case "failed", "cancelled":
|
||||
taskResult.Status = model.TaskStatusFailure
|
||||
taskResult.Progress = "100%"
|
||||
taskResult.Reason = "task failed"
|
||||
default:
|
||||
// Unknown status, treat as processing
|
||||
taskResult.Status = model.TaskStatusInProgress
|
||||
taskResult.Progress = "30%"
|
||||
}
|
||||
|
||||
return &taskResult, nil
|
||||
}
|
||||
8
relay/channel/task/sora/constants.go
Normal file
8
relay/channel/task/sora/constants.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package sora
|
||||
|
||||
var ModelList = []string{
|
||||
"sora-2",
|
||||
"sora-2-pro",
|
||||
}
|
||||
|
||||
var ChannelName = "sora"
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -52,7 +53,7 @@ func createTaskError(err error, code string, statusCode int, localError bool) *d
|
||||
}
|
||||
}
|
||||
|
||||
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
|
||||
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) {
|
||||
info.Action = action
|
||||
c.Set("task_request", requestObj)
|
||||
}
|
||||
@@ -64,9 +65,97 @@ func validatePrompt(prompt string) *dto.TaskError {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
|
||||
func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) {
|
||||
var req TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
if _, err := c.MultipartForm(); err != nil {
|
||||
return req, err
|
||||
}
|
||||
|
||||
formData := c.Request.PostForm
|
||||
req = TaskSubmitReq{
|
||||
Prompt: formData.Get("prompt"),
|
||||
Model: formData.Get("model"),
|
||||
Mode: formData.Get("mode"),
|
||||
Image: formData.Get("image"),
|
||||
Size: formData.Get("size"),
|
||||
Metadata: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
if durationStr := formData.Get("seconds"); durationStr != "" {
|
||||
if duration, err := strconv.Atoi(durationStr); err == nil {
|
||||
req.Duration = duration
|
||||
}
|
||||
}
|
||||
|
||||
if images := formData["images"]; len(images) > 0 {
|
||||
req.Images = images
|
||||
}
|
||||
|
||||
for key, values := range formData {
|
||||
if len(values) > 0 && !isKnownTaskField(key) {
|
||||
if intVal, err := strconv.Atoi(values[0]); err == nil {
|
||||
req.Metadata[key] = intVal
|
||||
} else if floatVal, err := strconv.ParseFloat(values[0], 64); err == nil {
|
||||
req.Metadata[key] = floatVal
|
||||
} else {
|
||||
req.Metadata[key] = values[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
|
||||
form, err := common.ParseMultipartFormReusable(c)
|
||||
if err != nil {
|
||||
return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
|
||||
}
|
||||
defer form.RemoveAll()
|
||||
|
||||
prompts, ok := form.Value["prompt"]
|
||||
if !ok || len(prompts) == 0 {
|
||||
return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true)
|
||||
}
|
||||
if taskErr := validatePrompt(prompts[0]); taskErr != nil {
|
||||
return taskErr
|
||||
}
|
||||
|
||||
if _, ok := form.Value["model"]; !ok {
|
||||
return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
|
||||
}
|
||||
action := constant.TaskActionTextGenerate
|
||||
if _, ok := form.File["input_reference"]; ok {
|
||||
action = constant.TaskActionGenerate
|
||||
}
|
||||
info.Action = action
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isKnownTaskField(field string) bool {
|
||||
knownFields := map[string]bool{
|
||||
"prompt": true,
|
||||
"model": true,
|
||||
"mode": true,
|
||||
"image": true,
|
||||
"images": true,
|
||||
"size": true,
|
||||
"duration": true,
|
||||
"input_reference": true, // Sora 特有字段
|
||||
}
|
||||
return knownFields[field]
|
||||
}
|
||||
|
||||
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
|
||||
var err error
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
var req TaskSubmitReq
|
||||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
req, err = validateMultipartTaskRequest(c, info, action)
|
||||
if err != nil {
|
||||
return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
|
||||
}
|
||||
} else if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
|
||||
}
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
taskdoubao "one-api/relay/channel/task/doubao"
|
||||
taskjimeng "one-api/relay/channel/task/jimeng"
|
||||
"one-api/relay/channel/task/kling"
|
||||
tasksora "one-api/relay/channel/task/sora"
|
||||
"one-api/relay/channel/task/suno"
|
||||
taskvertex "one-api/relay/channel/task/vertex"
|
||||
taskVidu "one-api/relay/channel/task/vidu"
|
||||
@@ -137,6 +138,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
|
||||
return &taskVidu.TaskAdaptor{}
|
||||
case constant.ChannelTypeDoubaoVideo:
|
||||
return &taskdoubao.TaskAdaptor{}
|
||||
case constant.ChannelTypeSora:
|
||||
return &tasksora.TaskAdaptor{}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -14,6 +14,11 @@ func SetVideoRouter(router *gin.Engine) {
|
||||
videoV1Router.POST("/video/generations", controller.RelayTask)
|
||||
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
|
||||
}
|
||||
// openai compatible API video routes
|
||||
// docs: https://platform.openai.com/docs/api-reference/videos/create
|
||||
{
|
||||
videoV1Router.POST("/videos", controller.RelayTask)
|
||||
}
|
||||
|
||||
klingV1Router := router.Group("/kling/v1")
|
||||
klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
|
||||
|
||||
@@ -169,6 +169,11 @@ export const CHANNEL_OPTIONS = [
|
||||
color: 'blue',
|
||||
label: '豆包视频',
|
||||
},
|
||||
{
|
||||
value: 55,
|
||||
color: 'green',
|
||||
label: 'Sora',
|
||||
},
|
||||
];
|
||||
|
||||
export const MODEL_TABLE_PAGE_SIZE = 10;
|
||||
|
||||
Reference in New Issue
Block a user