diff --git a/controller/video_proxy.go b/controller/video_proxy.go new file mode 100644 index 000000000..55ba707c0 --- /dev/null +++ b/controller/video_proxy.go @@ -0,0 +1,129 @@ +package controller + +import ( + "fmt" + "io" + "net/http" + "one-api/logger" + "one-api/model" + "time" + + "github.com/gin-gonic/gin" +) + +func VideoProxy(c *gin.Context) { + taskID := c.Param("task_id") + if taskID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "task_id is required", + "type": "invalid_request_error", + }, + }) + return + } + + task, exists, err := model.GetByOnlyTaskId(taskID) + if err != nil { + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error())) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Failed to query task", + "type": "server_error", + }, + }) + return + } + if !exists || task == nil { + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %s", taskID, err.Error())) + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "message": "Task not found", + "type": "invalid_request_error", + }, + }) + return + } + + if task.Status != model.TaskStatusSuccess { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status), + "type": "invalid_request_error", + }, + }) + return + } + + channel, err := model.CacheGetChannel(task.ChannelId) + if err != nil { + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel %d: %s", task.ChannelId, err.Error())) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Failed to retrieve channel information", + "type": "server_error", + }, + }) + return + } + baseURL := channel.GetBaseURL() + if baseURL == "" { + baseURL = "https://api.openai.com" + } + videoURL := fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID) + + client := &http.Client{ + Timeout: 60 * time.Second, + } + + req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, videoURL, nil) + if err != nil { + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request for %s: %s", videoURL, err.Error())) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Failed to create proxy request", + "type": "server_error", + }, + }) + return + } + + req.Header.Set("Authorization", "Bearer "+channel.Key) + + resp, err := client.Do(req) + if err != nil { + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error())) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "message": "Failed to fetch video content", + "type": "server_error", + }, + }) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL)) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode), + "type": "server_error", + }, + }) + return + } + + for key, values := range resp.Header { + for _, value := range values { + c.Writer.Header().Add(key, value) + } + } + + c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error())) + } +} diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index fa46eafff..49fb8a852 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -12,6 +12,7 @@ import ( "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" + "one-api/setting/system_setting" "github.com/gin-gonic/gin" "github.com/pkg/errors" @@ -166,7 +167,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e taskResult.Status = model.TaskStatusInProgress case "completed": taskResult.Status = model.TaskStatusSuccess - taskResult.Url = fmt.Sprintf("%s/v1/videos/%s/content", a.baseURL, resTask.ID) + taskResult.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, resTask.ID) case "failed", "cancelled": taskResult.Status = model.TaskStatusFailure if resTask.Error != nil { diff --git a/router/video-router.go b/router/video-router.go index 2a5b22610..dd541fffa 100644 --- a/router/video-router.go +++ b/router/video-router.go @@ -9,6 +9,7 @@ import ( func SetVideoRouter(router *gin.Engine) { videoV1Router := router.Group("/v1") + videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy) videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) { videoV1Router.POST("/video/generations", controller.RelayTask)