diff --git a/common/gin.go b/common/gin.go index 2cb358444..91a3a441a 100644 --- a/common/gin.go +++ b/common/gin.go @@ -3,6 +3,7 @@ package common import ( "bytes" "io" + "mime/multipart" "net/http" "one-api/constant" "strings" @@ -113,3 +114,26 @@ func ApiSuccess(c *gin.Context, data any) { "data": data, }) } + +func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) { + requestBody, err := GetRequestBody(c) + if err != nil { + return nil, err + } + + contentType := c.Request.Header.Get("Content-Type") + boundary := "" + if idx := strings.Index(contentType, "boundary="); idx != -1 { + boundary = contentType[idx+9:] + } + + reader := multipart.NewReader(bytes.NewReader(requestBody), boundary) + form, err := reader.ReadForm(32 << 20) // 32 MB max memory + if err != nil { + return nil, err + } + + // Reset request body + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + return form, nil +} diff --git a/constant/channel.go b/constant/channel.go index 1b5c2724b..426477e13 100644 --- a/constant/channel.go +++ b/constant/channel.go @@ -52,6 +52,7 @@ const ( ChannelTypeVidu = 52 ChannelTypeSubmodel = 53 ChannelTypeDoubaoVideo = 54 + ChannelTypeSora = 55 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -112,6 +113,7 @@ var ChannelBaseURLs = []string{ "https://api.vidu.cn", //52 "https://llm.submodel.ai", //53 "https://ark.cn-beijing.volces.com", //54 + "https://api.openai.com", //55 } var ChannelTypeNames = map[int]string{ @@ -166,6 +168,7 @@ var ChannelTypeNames = map[int]string{ ChannelTypeVidu: "Vidu", ChannelTypeSubmodel: "Submodel", ChannelTypeDoubaoVideo: "DoubaoVideo", + ChannelTypeSora: "Sora", } func GetChannelTypeName(channelType int) string { diff --git a/controller/task_video.go b/controller/task_video.go index ded011fe9..9bbf7a902 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -47,6 +47,11 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha if adaptor == nil { return fmt.Errorf("video adaptor not found") } + info := &relaycommon.RelayInfo{} + info.ChannelMeta = &relaycommon.ChannelMeta{ + ChannelBaseUrl: cacheGetChannel.GetBaseURL(), + } + adaptor.Init(info) for _, taskId := range taskIds { if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) diff --git a/controller/video_proxy.go b/controller/video_proxy.go new file mode 100644 index 000000000..55ba707c0 --- /dev/null +++ b/controller/video_proxy.go @@ -0,0 +1,129 @@ +package controller + +import ( + "fmt" + "io" + "net/http" + "one-api/logger" + "one-api/model" + "time" + + "github.com/gin-gonic/gin" +) + +func VideoProxy(c *gin.Context) { + taskID := c.Param("task_id") + if taskID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "task_id is required", + "type": "invalid_request_error", + }, + }) + return + } + + task, exists, err := model.GetByOnlyTaskId(taskID) + if err != nil { + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error())) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Failed to query task", + "type": "server_error", + }, + }) + return + } + if !exists || task == nil { + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %s", taskID, err.Error())) + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "message": "Task not found", + "type": "invalid_request_error", + }, + }) + return + } + + if task.Status != model.TaskStatusSuccess { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status), + "type": "invalid_request_error", + }, + }) + return + } + + channel, err := model.CacheGetChannel(task.ChannelId) + if err != nil { + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel %d: %s", task.ChannelId, err.Error())) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Failed to retrieve channel information", + "type": "server_error", + }, + }) + return + } + baseURL := channel.GetBaseURL() + if baseURL == "" { + baseURL = "https://api.openai.com" + } + videoURL := fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID) + + client := &http.Client{ + Timeout: 60 * time.Second, + } + + req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, videoURL, nil) + if err != nil { + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request for %s: %s", videoURL, err.Error())) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Failed to create proxy request", + "type": "server_error", + }, + }) + return + } + + req.Header.Set("Authorization", "Bearer "+channel.Key) + + resp, err := client.Do(req) + if err != nil { + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error())) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "message": "Failed to fetch video content", + "type": "server_error", + }, + }) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL)) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode), + "type": "server_error", + }, + }) + return + } + + for key, values := range resp.Header { + for _, value := range values { + c.Writer.Header().Add(key, value) + } + } + + c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error())) + } +} diff --git a/dto/openai_response.go b/dto/openai_response.go index 6353c15ff..7a3ddc68e 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -233,6 +233,16 @@ type Usage struct { Cost any `json:"cost,omitempty"` } +type OpenAIVideoResponse struct { + Id string `json:"id" example:"file-abc123"` + Object string `json:"object" example:"file"` + Bytes int64 `json:"bytes" example:"120000"` + CreatedAt int64 `json:"created_at" example:"1677610602"` + ExpiresAt int64 `json:"expires_at" example:"1677614202"` + Filename string `json:"filename" example:"mydata.jsonl"` + Purpose string `json:"purpose" example:"fine-tune"` +} + type InputTokenDetails struct { CachedTokens int `json:"cached_tokens"` CachedCreationTokens int `json:"-"` diff --git a/middleware/distributor.go b/middleware/distributor.go index 3d929df49..bbd3ba26a 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -165,6 +165,18 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("relay_mode", relayMode) + } else if strings.Contains(c.Request.URL.Path, "/v1/videos") { + //curl https://api.openai.com/v1/videos \ + // -H "Authorization: Bearer $OPENAI_API_KEY" \ + // -F "model=sora-2" \ + // -F "prompt=A calico cat playing a piano on stage" + // -F input_reference="@image.jpg" + relayMode := relayconstant.RelayModeUnknown + if c.Request.Method == http.MethodPost { + relayMode = relayconstant.RelayModeVideoSubmit + modelRequest.Model = c.PostForm("model") + } + c.Set("relay_mode", relayMode) } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") { relayMode := relayconstant.RelayModeUnknown if c.Request.Method == http.MethodPost { diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go new file mode 100644 index 000000000..49fb8a852 --- /dev/null +++ b/relay/channel/task/sora/adaptor.go @@ -0,0 +1,185 @@ +package sora + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/dto" + "one-api/model" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" + "one-api/setting/system_setting" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +// ============================ +// Request / Response structures +// ============================ + +type ContentItem struct { + Type string `json:"type"` // "text" or "image_url" + Text string `json:"text,omitempty"` // for text type + ImageURL *ImageURL `json:"image_url,omitempty"` // for image_url type +} + +type ImageURL struct { + URL string `json:"url"` +} + +type responsePayload struct { + ID string `json:"id"` // task_id +} + +type responseTask struct { + ID string `json:"id"` + Object string `json:"object"` + Model string `json:"model"` + Status string `json:"status"` + Progress int `json:"progress"` + CreatedAt int64 `json:"created_at"` + CompletedAt int64 `json:"completed_at,omitempty"` + ExpiresAt int64 `json:"expires_at,omitempty"` + Seconds string `json:"seconds,omitempty"` + Size string `json:"size,omitempty"` + RemixedFromVideoID string `json:"remixed_from_video_id,omitempty"` + Error *struct { + Message string `json:"message"` + Code string `json:"code"` + } `json:"error,omitempty"` +} + +// ============================ +// Adaptor implementation +// ============================ + +type TaskAdaptor struct { + ChannelType int + apiKey string + baseURL string +} + +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.ChannelBaseUrl + a.apiKey = info.ApiKey +} + +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { + return relaycommon.ValidateMultipartDirect(c, info) +} + +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s/v1/videos", a.baseURL), nil +} + +// BuildRequestHeader sets required headers. +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + req.Header.Set("Authorization", "Bearer "+a.apiKey) + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + return nil +} + +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { + cachedBody, err := common.GetRequestBody(c) + if err != nil { + return nil, errors.Wrap(err, "get_request_body_failed") + } + return bytes.NewReader(cachedBody), nil +} + +// DoRequest delegates to common helper. +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +// DoResponse handles upstream response, returns taskID etc. +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return + } + _ = resp.Body.Close() + + // Parse Sora response + var dResp responsePayload + if err := json.Unmarshal(responseBody, &dResp); err != nil { + taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) + return + } + + if dResp.ID == "" { + taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{"task_id": dResp.ID}) + return dResp.ID, responseBody, nil +} + +// FetchTask fetch task status +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + taskID, ok := body["task_id"].(string) + if !ok { + return nil, fmt.Errorf("invalid task_id") + } + + uri := fmt.Sprintf("%s/v1/videos/%s", baseUrl, taskID) + + req, err := http.NewRequest(http.MethodGet, uri, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", "Bearer "+key) + + return service.GetHttpClient().Do(req) +} + +func (a *TaskAdaptor) GetModelList() []string { + return ModelList +} + +func (a *TaskAdaptor) GetChannelName() string { + return ChannelName +} + +func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { + resTask := responseTask{} + if err := json.Unmarshal(respBody, &resTask); err != nil { + return nil, errors.Wrap(err, "unmarshal task result failed") + } + + taskResult := relaycommon.TaskInfo{ + Code: 0, + } + + switch resTask.Status { + case "queued", "pending": + taskResult.Status = model.TaskStatusQueued + case "processing", "in_progress": + taskResult.Status = model.TaskStatusInProgress + case "completed": + taskResult.Status = model.TaskStatusSuccess + taskResult.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, resTask.ID) + case "failed", "cancelled": + taskResult.Status = model.TaskStatusFailure + if resTask.Error != nil { + taskResult.Reason = resTask.Error.Message + } else { + taskResult.Reason = "task failed" + } + default: + } + if resTask.Progress > 0 && resTask.Progress < 100 { + taskResult.Progress = fmt.Sprintf("%d%%", resTask.Progress) + } + + return &taskResult, nil +} diff --git a/relay/channel/task/sora/constants.go b/relay/channel/task/sora/constants.go new file mode 100644 index 000000000..e2f6536ea --- /dev/null +++ b/relay/channel/task/sora/constants.go @@ -0,0 +1,8 @@ +package sora + +var ModelList = []string{ + "sora-2", + "sora-2-pro", +} + +var ChannelName = "sora" diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 3a721b479..f18c43741 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -6,6 +6,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "strconv" "strings" "github.com/gin-gonic/gin" @@ -52,7 +53,7 @@ func createTaskError(err error, code string, statusCode int, localError bool) *d } } -func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) { +func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) { info.Action = action c.Set("task_request", requestObj) } @@ -64,9 +65,97 @@ func validatePrompt(prompt string) *dto.TaskError { return nil } -func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError { +func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) { var req TaskSubmitReq - if err := common.UnmarshalBodyReusable(c, &req); err != nil { + if _, err := c.MultipartForm(); err != nil { + return req, err + } + + formData := c.Request.PostForm + req = TaskSubmitReq{ + Prompt: formData.Get("prompt"), + Model: formData.Get("model"), + Mode: formData.Get("mode"), + Image: formData.Get("image"), + Size: formData.Get("size"), + Metadata: make(map[string]interface{}), + } + + if durationStr := formData.Get("seconds"); durationStr != "" { + if duration, err := strconv.Atoi(durationStr); err == nil { + req.Duration = duration + } + } + + if images := formData["images"]; len(images) > 0 { + req.Images = images + } + + for key, values := range formData { + if len(values) > 0 && !isKnownTaskField(key) { + if intVal, err := strconv.Atoi(values[0]); err == nil { + req.Metadata[key] = intVal + } else if floatVal, err := strconv.ParseFloat(values[0], 64); err == nil { + req.Metadata[key] = floatVal + } else { + req.Metadata[key] = values[0] + } + } + } + return req, nil +} + +func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError { + form, err := common.ParseMultipartFormReusable(c) + if err != nil { + return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true) + } + defer form.RemoveAll() + + prompts, ok := form.Value["prompt"] + if !ok || len(prompts) == 0 { + return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true) + } + if taskErr := validatePrompt(prompts[0]); taskErr != nil { + return taskErr + } + + if _, ok := form.Value["model"]; !ok { + return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true) + } + action := constant.TaskActionTextGenerate + if _, ok := form.File["input_reference"]; ok { + action = constant.TaskActionGenerate + } + info.Action = action + + return nil +} + +func isKnownTaskField(field string) bool { + knownFields := map[string]bool{ + "prompt": true, + "model": true, + "mode": true, + "image": true, + "images": true, + "size": true, + "duration": true, + "input_reference": true, // Sora 特有字段 + } + return knownFields[field] +} + +func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError { + var err error + contentType := c.GetHeader("Content-Type") + var req TaskSubmitReq + if strings.HasPrefix(contentType, "multipart/form-data") { + req, err = validateMultipartTaskRequest(c, info, action) + if err != nil { + return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true) + } + } else if err := common.UnmarshalBodyReusable(c, &req); err != nil { return createTaskError(err, "invalid_request", http.StatusBadRequest, true) } diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index c8fd51a11..2017c9a6f 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -29,6 +29,7 @@ import ( taskdoubao "one-api/relay/channel/task/doubao" taskjimeng "one-api/relay/channel/task/jimeng" "one-api/relay/channel/task/kling" + tasksora "one-api/relay/channel/task/sora" "one-api/relay/channel/task/suno" taskvertex "one-api/relay/channel/task/vertex" taskVidu "one-api/relay/channel/task/vidu" @@ -137,6 +138,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor { return &taskVidu.TaskAdaptor{} case constant.ChannelTypeDoubaoVideo: return &taskdoubao.TaskAdaptor{} + case constant.ChannelTypeSora: + return &tasksora.TaskAdaptor{} } } return nil diff --git a/router/video-router.go b/router/video-router.go index bcc05eae9..dd541fffa 100644 --- a/router/video-router.go +++ b/router/video-router.go @@ -9,11 +9,17 @@ import ( func SetVideoRouter(router *gin.Engine) { videoV1Router := router.Group("/v1") + videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy) videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) { videoV1Router.POST("/video/generations", controller.RelayTask) videoV1Router.GET("/video/generations/:task_id", controller.RelayTask) } + // openai compatible API video routes + // docs: https://platform.openai.com/docs/api-reference/videos/create + { + videoV1Router.POST("/videos", controller.RelayTask) + } klingV1Router := router.Group("/kling/v1") klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute()) diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index ad6999365..dc55f0342 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -169,6 +169,11 @@ export const CHANNEL_OPTIONS = [ color: 'blue', label: '豆包视频', }, + { + value: 55, + color: 'green', + label: 'Sora', + }, ]; export const MODEL_TABLE_PAGE_SIZE = 10;