From 9629c8a77166368580e898ca425fb07b75f67fa9 Mon Sep 17 00:00:00 2001 From: Seefs <40468931+seefs001@users.noreply.github.com> Date: Fri, 31 Oct 2025 15:29:17 +0800 Subject: [PATCH] fix veo3 (#2140) --- controller/video_proxy.go | 26 ++++- controller/video_proxy_gemini.go | 158 +++++++++++++++++++++++++++ model/task.go | 53 +++++++-- relay/channel/task/gemini/adaptor.go | 52 ++++++++- relay/common/relay_info.go | 1 + relay/relay_task.go | 2 +- 6 files changed, 278 insertions(+), 14 deletions(-) create mode 100644 controller/video_proxy_gemini.go diff --git a/controller/video_proxy.go b/controller/video_proxy.go index 2bfb0dc29..c9801b4b5 100644 --- a/controller/video_proxy.go +++ b/controller/video_proxy.go @@ -92,8 +92,30 @@ func VideoProxy(c *gin.Context) { } if channel.Type == constant.ChannelTypeGemini { - videoURL = fmt.Sprintf("%s&key=%s", c.Query("url"), channel.Key) - req.Header.Set("x-goog-api-key", channel.Key) + apiKey := task.PrivateData.Key + if apiKey == "" { + logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID)) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "API key not stored for task", + "type": "server_error", + }, + }) + return + } + + videoURL, err = getGeminiVideoURL(channel, task, apiKey) + if err != nil { + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error())) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "message": "Failed to resolve Gemini video URL", + "type": "server_error", + }, + }) + return + } + req.Header.Set("x-goog-api-key", apiKey) } else { // Default (Sora, etc.): Use original logic videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID) diff --git a/controller/video_proxy_gemini.go b/controller/video_proxy_gemini.go new file mode 100644 index 000000000..4e2e60e62 --- /dev/null +++ b/controller/video_proxy_gemini.go @@ -0,0 +1,158 @@ +package controller + +import ( + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/relay" +) + +func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string) (string, error) { + if channel == nil || task == nil { + return "", fmt.Errorf("invalid channel or task") + } + + if url := extractGeminiVideoURLFromTaskData(task); url != "" { + return ensureAPIKey(url, apiKey), nil + } + + baseURL := constant.ChannelBaseURLs[channel.Type] + if channel.GetBaseURL() != "" { + baseURL = channel.GetBaseURL() + } + + adaptor := relay.GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channel.Type))) + if adaptor == nil { + return "", fmt.Errorf("gemini task adaptor not found") + } + + if apiKey == "" { + return "", fmt.Errorf("api key not available for task") + } + + resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{ + "task_id": task.TaskID, + "action": task.Action, + }) + if err != nil { + return "", fmt.Errorf("fetch task failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read task response failed: %w", err) + } + + taskInfo, parseErr := adaptor.ParseTaskResult(body) + if parseErr == nil && taskInfo != nil && taskInfo.RemoteUrl != "" { + return ensureAPIKey(taskInfo.RemoteUrl, apiKey), nil + } + + if url := extractGeminiVideoURLFromPayload(body); url != "" { + return ensureAPIKey(url, apiKey), nil + } + + if parseErr != nil { + return "", fmt.Errorf("parse task result failed: %w", parseErr) + } + + return "", fmt.Errorf("gemini video url not found") +} + +func extractGeminiVideoURLFromTaskData(task *model.Task) string { + if task == nil || len(task.Data) == 0 { + return "" + } + var payload map[string]any + if err := json.Unmarshal(task.Data, &payload); err != nil { + return "" + } + return extractGeminiVideoURLFromMap(payload) +} + +func extractGeminiVideoURLFromPayload(body []byte) string { + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return "" + } + return extractGeminiVideoURLFromMap(payload) +} + +func extractGeminiVideoURLFromMap(payload map[string]any) string { + if payload == nil { + return "" + } + if uri, ok := payload["uri"].(string); ok && uri != "" { + return uri + } + if resp, ok := payload["response"].(map[string]any); ok { + if uri := extractGeminiVideoURLFromResponse(resp); uri != "" { + return uri + } + } + return "" +} + +func extractGeminiVideoURLFromResponse(resp map[string]any) string { + if resp == nil { + return "" + } + if gvr, ok := resp["generateVideoResponse"].(map[string]any); ok { + if uri := extractGeminiVideoURLFromGeneratedSamples(gvr); uri != "" { + return uri + } + } + if videos, ok := resp["videos"].([]any); ok { + for _, video := range videos { + if vm, ok := video.(map[string]any); ok { + if uri, ok := vm["uri"].(string); ok && uri != "" { + return uri + } + } + } + } + if uri, ok := resp["video"].(string); ok && uri != "" { + return uri + } + if uri, ok := resp["uri"].(string); ok && uri != "" { + return uri + } + return "" +} + +func extractGeminiVideoURLFromGeneratedSamples(gvr map[string]any) string { + if gvr == nil { + return "" + } + if samples, ok := gvr["generatedSamples"].([]any); ok { + for _, sample := range samples { + if sm, ok := sample.(map[string]any); ok { + if video, ok := sm["video"].(map[string]any); ok { + if uri, ok := video["uri"].(string); ok && uri != "" { + return uri + } + } + } + } + } + return "" +} + +func ensureAPIKey(uri, key string) string { + if key == "" || uri == "" { + return uri + } + if strings.Contains(uri, "key=") { + return uri + } + if strings.Contains(uri, "?") { + return fmt.Sprintf("%s&key=%s", uri, key) + } + return fmt.Sprintf("%s?key=%s", uri, key) +} diff --git a/model/task.go b/model/task.go index a8c3a7d4d..994dd25c7 100644 --- a/model/task.go +++ b/model/task.go @@ -57,8 +57,9 @@ type Task struct { FinishTime int64 `json:"finish_time" gorm:"index"` Progress string `json:"progress" gorm:"type:varchar(20);index"` Properties Properties `json:"properties" gorm:"type:json"` - - Data json.RawMessage `json:"data" gorm:"type:json"` + // 禁止返回给用户,内部可能包含key等隐私信息 + PrivateData TaskPrivateData `json:"-" gorm:"column:private_data;type:json"` + Data json.RawMessage `json:"data" gorm:"type:json"` } func (t *Task) SetData(data any) { @@ -77,13 +78,39 @@ type Properties struct { func (m *Properties) Scan(val interface{}) error { bytesValue, _ := val.([]byte) + if len(bytesValue) == 0 { + m.Input = "" + return nil + } return json.Unmarshal(bytesValue, m) } func (m Properties) Value() (driver.Value, error) { + if m.Input == "" { + return nil, nil + } return json.Marshal(m) } +type TaskPrivateData struct { + Key string `json:"key,omitempty"` +} + +func (p *TaskPrivateData) Scan(val interface{}) error { + bytesValue, _ := val.([]byte) + if len(bytesValue) == 0 { + return nil + } + return json.Unmarshal(bytesValue, p) +} + +func (p TaskPrivateData) Value() (driver.Value, error) { + if (p == TaskPrivateData{}) { + return nil, nil + } + return json.Marshal(p) +} + // SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 type SyncTaskQueryParams struct { Platform constant.TaskPlatform @@ -98,14 +125,22 @@ type SyncTaskQueryParams struct { } func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) *Task { + properties := Properties{} + privateData := TaskPrivateData{} + if relayInfo != nil && relayInfo.ChannelMeta != nil && relayInfo.ChannelMeta.ChannelType == constant.ChannelTypeGemini { + privateData.Key = relayInfo.ChannelMeta.ApiKey + } + t := &Task{ - UserId: relayInfo.UserId, - Group: relayInfo.UsingGroup, - SubmitTime: time.Now().Unix(), - Status: TaskStatusNotStart, - Progress: "0%", - ChannelId: relayInfo.ChannelId, - Platform: platform, + UserId: relayInfo.UserId, + Group: relayInfo.UsingGroup, + SubmitTime: time.Now().Unix(), + Status: TaskStatusNotStart, + Progress: "0%", + ChannelId: relayInfo.ChannelId, + Platform: platform, + Properties: properties, + PrivateData: privateData, } return t } diff --git a/relay/channel/task/gemini/adaptor.go b/relay/channel/task/gemini/adaptor.go index 092059c67..0fa9dda4b 100644 --- a/relay/channel/task/gemini/adaptor.go +++ b/relay/channel/task/gemini/adaptor.go @@ -7,9 +7,11 @@ import ( "fmt" "io" "net/http" + "regexp" "strings" "time" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" @@ -248,17 +250,45 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e ti.Status = model.TaskStatusSuccess ti.Progress = "100%" + taskID := encodeLocalTaskID(op.Name) + ti.TaskID = taskID + ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID) + // Extract URL from generateVideoResponse if available if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 { if uri := op.Response.GenerateVideoResponse.GeneratedSamples[0].Video.URI; uri != "" { - taskID := encodeLocalTaskID(op.Name) - ti.Url = fmt.Sprintf("%s/v1/videos/%s/content?url=%s", system_setting.ServerAddress, taskID, uri) + ti.RemoteUrl = uri } } return ti, nil } +func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { + upstreamName, err := decodeLocalTaskID(task.TaskID) + if err != nil { + upstreamName = "" + } + modelName := extractModelFromOperationName(upstreamName) + if strings.TrimSpace(modelName) == "" { + modelName = "veo-3.0-generate-001" + } + + video := dto.NewOpenAIVideo() + video.ID = task.TaskID + video.Model = modelName + video.Status = task.Status.ToVideoStatus() + video.SetProgressStr(task.Progress) + video.CreatedAt = task.CreatedAt + if task.FinishTime > 0 { + video.CompletedAt = task.FinishTime + } else if task.UpdatedAt > 0 { + video.CompletedAt = task.UpdatedAt + } + + return common.Marshal(video) +} + // ============================ // helpers // ============================ @@ -274,3 +304,21 @@ func decodeLocalTaskID(local string) (string, error) { } return string(b), nil } + +var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`) + +func extractModelFromOperationName(name string) string { + if name == "" { + return "" + } + if m := modelRe.FindStringSubmatch(name); len(m) == 2 { + return m[1] + } + if idx := strings.Index(name, "models/"); idx >= 0 { + s := name[idx+len("models/"):] + if p := strings.Index(s, "/operations/"); p > 0 { + return s[:p] + } + } + return "" +} diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 8f59a9056..b67c4143d 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -509,6 +509,7 @@ type TaskInfo struct { Status string `json:"status"` Reason string `json:"reason,omitempty"` Url string `json:"url,omitempty"` + RemoteUrl string `json:"remote_url,omitempty"` Progress string `json:"progress,omitempty"` CompletionTokens int `json:"completion_tokens,omitempty"` // 用于按倍率计费 TotalTokens int `json:"total_tokens,omitempty"` // 用于按倍率计费 diff --git a/relay/relay_task.go b/relay/relay_task.go index ca1b0bb1f..61e2af523 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -319,7 +319,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d if err2 != nil { return } - if channelModel.Type != constant.ChannelTypeVertexAi { + if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini { return } baseURL := constant.ChannelBaseURLs[channelModel.Type]