From 81e29aaa3db696a180077f3960d04a23ecde0157 Mon Sep 17 00:00:00 2001 From: Sh1n3zZ Date: Tue, 26 Aug 2025 08:29:26 +0800 Subject: [PATCH] feat: vertex veo (#1450) --- common/database.go | 2 +- controller/setup.go | 2 +- controller/task_video.go | 42 ++- main.go | 2 +- middleware/distributor.go | 2 +- relay/channel/task/vertex/adaptor.go | 344 ++++++++++++++++++++++++ relay/channel/vertex/adaptor.go | 1 + relay/channel/vertex/relay-vertex.go | 5 +- relay/channel/vertex/service_account.go | 47 +++- relay/relay_adaptor.go | 6 +- relay/relay_task.go | 96 ++++++- 11 files changed, 534 insertions(+), 15 deletions(-) create mode 100644 relay/channel/task/vertex/adaptor.go diff --git a/common/database.go b/common/database.go index 71dbd94d5..38a54d5e6 100644 --- a/common/database.go +++ b/common/database.go @@ -12,4 +12,4 @@ var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries var UsingMySQL = false var UsingClickHouse = false -var SQLitePath = "one-api.db?_busy_timeout=30000" +var SQLitePath = "one-api.db?_busy_timeout=30000" \ No newline at end of file diff --git a/controller/setup.go b/controller/setup.go index 8943a1a02..44a7b3a73 100644 --- a/controller/setup.go +++ b/controller/setup.go @@ -178,4 +178,4 @@ func boolToString(b bool) string { return "true" } return "false" -} +} \ No newline at end of file diff --git a/controller/task_video.go b/controller/task_video.go index ffb6728ba..73d5c39b1 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -94,7 +94,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) } else { - task.Data = responseBody + task.Data = redactVideoResponseBody(responseBody) } now := time.Now().Unix() @@ -113,11 +113,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task.StartTime = now } case model.TaskStatusSuccess: - task.Progress = "100%" + task.Progress = "100%" if task.FinishTime == 0 { task.FinishTime = now } - task.FailReason = taskResult.Url + if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") { + task.FailReason = taskResult.Url + } case model.TaskStatusFailure: task.Status = model.TaskStatusFailure task.Progress = "100%" @@ -146,3 +148,37 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha return nil } + +func redactVideoResponseBody(body []byte) []byte { + var m map[string]any + if err := json.Unmarshal(body, &m); err != nil { + return body + } + resp, _ := m["response"].(map[string]any) + if resp != nil { + delete(resp, "bytesBase64Encoded") + if v, ok := resp["video"].(string); ok { + resp["video"] = truncateBase64(v) + } + if vs, ok := resp["videos"].([]any); ok { + for i := range vs { + if vm, ok := vs[i].(map[string]any); ok { + delete(vm, "bytesBase64Encoded") + } + } + } + } + b, err := json.Marshal(m) + if err != nil { + return body + } + return b +} + +func truncateBase64(s string) string { + const maxKeep = 256 + if len(s) <= maxKeep { + return s + } + return s[:maxKeep] + "..." +} diff --git a/main.go b/main.go index 2dfddaccf..91311b867 100644 --- a/main.go +++ b/main.go @@ -208,4 +208,4 @@ func InitResources() error { return err } return nil -} +} \ No newline at end of file diff --git a/middleware/distributor.go b/middleware/distributor.go index 1e6df872d..7fefeda49 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -166,9 +166,9 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("relay_mode", relayMode) } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") { - err = common.UnmarshalBodyReusable(c, &modelRequest) relayMode := relayconstant.RelayModeUnknown if c.Request.Method == http.MethodPost { + err = common.UnmarshalBodyReusable(c, &modelRequest) relayMode = relayconstant.RelayModeVideoSubmit } else if c.Request.Method == http.MethodGet { relayMode = relayconstant.RelayModeVideoFetchByID diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go new file mode 100644 index 000000000..d2ab826d0 --- /dev/null +++ b/relay/channel/task/vertex/adaptor.go @@ -0,0 +1,344 @@ +package vertex + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "regexp" + "strings" + + "github.com/gin-gonic/gin" + + "one-api/common" + "one-api/constant" + "one-api/dto" + "one-api/relay/channel" + vertexcore "one-api/relay/channel/vertex" + relaycommon "one-api/relay/common" + "one-api/service" +) + +type requestPayload struct { + Instances []map[string]any `json:"instances"` + Parameters map[string]any `json:"parameters,omitempty"` +} + +type submitResponse struct { + Name string `json:"name"` +} + +type operationVideo struct { + MimeType string `json:"mimeType"` + BytesBase64Encoded string `json:"bytesBase64Encoded"` + Encoding string `json:"encoding"` +} + +type operationResponse struct { + Name string `json:"name"` + Done bool `json:"done"` + Response struct { + Type string `json:"@type"` + RaiMediaFilteredCount int `json:"raiMediaFilteredCount"` + Videos []operationVideo `json:"videos"` + BytesBase64Encoded string `json:"bytesBase64Encoded"` + Encoding string `json:"encoding"` + Video string `json:"video"` + } `json:"response"` + Error struct { + Message string `json:"message"` + } `json:"error"` +} + +type TaskAdaptor struct{} + +func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {} + +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { + info.Action = constant.TaskActionTextGenerate + + req := relaycommon.TaskSubmitReq{} + if err := common.UnmarshalBodyReusable(c, &req); err != nil { + return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) + } + if strings.TrimSpace(req.Prompt) == "" { + return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) + } + c.Set("task_request", req) + return nil +} + +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { + adc := &vertexcore.Credentials{} + if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { + return "", fmt.Errorf("failed to decode credentials: %w", err) + } + modelName := info.OriginModelName + if v, ok := getRequestModelFromContext(info); ok { + modelName = v + } + if modelName == "" { + modelName = "veo-3.0-generate-001" + } + + region := vertexcore.GetModelRegion(info.ApiVersion, modelName) + if strings.TrimSpace(region) == "" { + region = "global" + } + if region == "global" { + return fmt.Sprintf( + "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning", + adc.ProjectID, + modelName, + ), nil + } + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning", + region, + adc.ProjectID, + region, + modelName, + ), nil +} + +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + adc := &vertexcore.Credentials{} + if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { + return fmt.Errorf("failed to decode credentials: %w", err) + } + + token, err := vertexcore.AcquireAccessToken(*adc, info.ChannelSetting.Proxy) + if err != nil { + return fmt.Errorf("failed to acquire access token: %w", err) + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("x-goog-user-project", adc.ProjectID) + return nil +} + +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) { + v, ok := c.Get("task_request") + if !ok { + return nil, fmt.Errorf("request not found in context") + } + req := v.(relaycommon.TaskSubmitReq) + + body := requestPayload{ + Instances: []map[string]any{{"prompt": req.Prompt}}, + Parameters: map[string]any{}, + } + if req.Metadata != nil { + if v, ok := req.Metadata["storageUri"]; ok { + body.Parameters["storageUri"] = v + } + if v, ok := req.Metadata["sampleCount"]; ok { + body.Parameters["sampleCount"] = v + } + } + if _, ok := body.Parameters["sampleCount"]; !ok { + body.Parameters["sampleCount"] = 1 + } + + data, err := json.Marshal(body) + if err != nil { + return nil, err + } + return bytes.NewReader(data), nil +} + +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + _ = resp.Body.Close() + + var s submitResponse + if err := json.Unmarshal(responseBody, &s); err != nil { + return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) + } + if strings.TrimSpace(s.Name) == "" { + return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError) + } + localID := encodeLocalTaskID(s.Name) + c.JSON(http.StatusOK, gin.H{"task_id": localID}) + return localID, responseBody, nil +} + +func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} } +func (a *TaskAdaptor) GetChannelName() string { return "vertex" } + +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + taskID, ok := body["task_id"].(string) + if !ok { + return nil, fmt.Errorf("invalid task_id") + } + upstreamName, err := decodeLocalTaskID(taskID) + if err != nil { + return nil, fmt.Errorf("decode task_id failed: %w", err) + } + region := extractRegionFromOperationName(upstreamName) + if region == "" { + region = "us-central1" + } + project := extractProjectFromOperationName(upstreamName) + model := extractModelFromOperationName(upstreamName) + if project == "" || model == "" { + return nil, fmt.Errorf("cannot extract project/model from operation name") + } + var url string + if region == "global" { + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, model) + } else { + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, model) + } + payload := map[string]string{"operationName": upstreamName} + data, err := json.Marshal(payload) + if err != nil { + return nil, err + } + adc := &vertexcore.Credentials{} + if err := json.Unmarshal([]byte(key), adc); err != nil { + return nil, fmt.Errorf("failed to decode credentials: %w", err) + } + token, err := vertexcore.AcquireAccessToken(*adc, "") + if err != nil { + return nil, fmt.Errorf("failed to acquire access token: %w", err) + } + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("x-goog-user-project", adc.ProjectID) + return service.GetHttpClient().Do(req) +} + +func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { + var op operationResponse + if err := json.Unmarshal(respBody, &op); err != nil { + return nil, fmt.Errorf("unmarshal operation response failed: %w", err) + } + ti := &relaycommon.TaskInfo{} + if op.Error.Message != "" { + ti.Status = "FAILURE" + ti.Reason = op.Error.Message + ti.Progress = "100%" + return ti, nil + } + if !op.Done { + ti.Status = "IN_PROGRESS" + ti.Progress = "50%" + return ti, nil + } + ti.Status = "SUCCESS" + ti.Progress = "100%" + if len(op.Response.Videos) > 0 { + v0 := op.Response.Videos[0] + if v0.BytesBase64Encoded != "" { + mime := strings.TrimSpace(v0.MimeType) + if mime == "" { + enc := strings.TrimSpace(v0.Encoding) + if enc == "" { + enc = "mp4" + } + if strings.Contains(enc, "/") { + mime = enc + } else { + mime = "video/" + enc + } + } + ti.Url = "data:" + mime + ";base64," + v0.BytesBase64Encoded + return ti, nil + } + } + if op.Response.BytesBase64Encoded != "" { + enc := strings.TrimSpace(op.Response.Encoding) + if enc == "" { + enc = "mp4" + } + mime := enc + if !strings.Contains(enc, "/") { + mime = "video/" + enc + } + ti.Url = "data:" + mime + ";base64," + op.Response.BytesBase64Encoded + return ti, nil + } + if op.Response.Video != "" { // some variants use `video` as base64 + enc := strings.TrimSpace(op.Response.Encoding) + if enc == "" { + enc = "mp4" + } + mime := enc + if !strings.Contains(enc, "/") { + mime = "video/" + enc + } + ti.Url = "data:" + mime + ";base64," + op.Response.Video + return ti, nil + } + return ti, nil +} + +func getRequestModelFromContext(info *relaycommon.TaskRelayInfo) (string, bool) { + return info.OriginModelName, info.OriginModelName != "" +} + +func encodeLocalTaskID(name string) string { + return base64.RawURLEncoding.EncodeToString([]byte(name)) +} + +func decodeLocalTaskID(local string) (string, error) { + b, err := base64.RawURLEncoding.DecodeString(local) + if err != nil { + return "", err + } + return string(b), nil +} + +var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`) + +func extractRegionFromOperationName(name string) string { + m := regionRe.FindStringSubmatch(name) + if len(m) == 2 { + return m[1] + } + return "" +} + +var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`) + +func extractModelFromOperationName(name string) string { + m := modelRe.FindStringSubmatch(name) + if len(m) == 2 { + return m[1] + } + idx := strings.Index(name, "models/") + if idx >= 0 { + s := name[idx+len("models/"):] + if p := strings.Index(s, "/operations/"); p > 0 { + return s[:p] + } + } + return "" +} + +var projectRe = regexp.MustCompile(`projects/([^/]+)/locations/`) + +func extractProjectFromOperationName(name string) string { + m := projectRe.FindStringSubmatch(name) + if len(m) == 2 { + return m[1] + } + return "" +} diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 0b6b26743..d15592bf8 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -174,6 +174,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return err } req.Set("Authorization", "Bearer "+accessToken) + req.Set("x-goog-user-project", a.AccountCredentials.ProjectID) return nil } diff --git a/relay/channel/vertex/relay-vertex.go b/relay/channel/vertex/relay-vertex.go index 5ed876654..f0b84906a 100644 --- a/relay/channel/vertex/relay-vertex.go +++ b/relay/channel/vertex/relay-vertex.go @@ -12,7 +12,10 @@ func GetModelRegion(other string, localModelName string) string { if m[localModelName] != nil { return m[localModelName].(string) } else { - return m["default"].(string) + if v, ok := m["default"]; ok { + return v.(string) + } + return "global" } } return other diff --git a/relay/channel/vertex/service_account.go b/relay/channel/vertex/service_account.go index 9a4650d98..f90d5454d 100644 --- a/relay/channel/vertex/service_account.go +++ b/relay/channel/vertex/service_account.go @@ -6,14 +6,15 @@ import ( "encoding/json" "encoding/pem" "errors" - "github.com/bytedance/gopkg/cache/asynccache" - "github.com/golang-jwt/jwt" "net/http" "net/url" relaycommon "one-api/relay/common" "one-api/service" "strings" + "github.com/bytedance/gopkg/cache/asynccache" + "github.com/golang-jwt/jwt" + "fmt" "time" ) @@ -137,3 +138,45 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s return "", fmt.Errorf("failed to get access token: %v", result) } + +func AcquireAccessToken(creds Credentials, proxy string) (string, error) { + signedJWT, err := createSignedJWT(creds.ClientEmail, creds.PrivateKey) + if err != nil { + return "", fmt.Errorf("failed to create signed JWT: %w", err) + } + return exchangeJwtForAccessTokenWithProxy(signedJWT, proxy) +} + +func exchangeJwtForAccessTokenWithProxy(signedJWT string, proxy string) (string, error) { + authURL := "https://www.googleapis.com/oauth2/v4/token" + data := url.Values{} + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer") + data.Set("assertion", signedJWT) + + var client *http.Client + var err error + if proxy != "" { + client, err = service.NewProxyHttpClient(proxy) + if err != nil { + return "", fmt.Errorf("new proxy http client failed: %w", err) + } + } else { + client = service.GetHttpClient() + } + + resp, err := client.PostForm(authURL, data) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", err + } + + if accessToken, ok := result["access_token"].(string); ok { + return accessToken, nil + } + return "", fmt.Errorf("failed to get access token: %v", result) +} diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 1ee85986c..0c271210b 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -1,7 +1,6 @@ package relay import ( - "github.com/gin-gonic/gin" "one-api/constant" "one-api/relay/channel" "one-api/relay/channel/ali" @@ -28,6 +27,7 @@ import ( taskjimeng "one-api/relay/channel/task/jimeng" "one-api/relay/channel/task/kling" "one-api/relay/channel/task/suno" + taskvertex "one-api/relay/channel/task/vertex" taskVidu "one-api/relay/channel/task/vidu" "one-api/relay/channel/tencent" "one-api/relay/channel/vertex" @@ -37,6 +37,8 @@ import ( "one-api/relay/channel/zhipu" "one-api/relay/channel/zhipu_4v" "strconv" + + "github.com/gin-gonic/gin" ) func GetAdaptor(apiType int) channel.Adaptor { @@ -126,6 +128,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor { return &kling.TaskAdaptor{} case constant.ChannelTypeJimeng: return &taskjimeng.TaskAdaptor{} + case constant.ChannelTypeVertexAi: + return &taskvertex.TaskAdaptor{} case constant.ChannelTypeVidu: return &taskVidu.TaskAdaptor{} } diff --git a/relay/relay_task.go b/relay/relay_task.go index 95b8083b3..6faec176d 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -15,6 +15,8 @@ import ( relayconstant "one-api/relay/constant" "one-api/service" "one-api/setting/ratio_setting" + "strconv" + "strings" "github.com/gin-gonic/gin" ) @@ -32,6 +34,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { if err != nil { return service.TaskErrorWrapper(err, "gen_relay_info_failed", http.StatusInternalServerError) } + relayInfo.InitChannelMeta(c) adaptor := GetTaskAdaptor(platform) if adaptor == nil { @@ -197,6 +200,9 @@ func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) { if taskErr != nil { return taskErr } + if len(respBody) == 0 { + respBody = []byte("{\"code\":\"success\",\"data\":null}") + } c.Writer.Header().Set("Content-Type", "application/json") _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody)) @@ -276,10 +282,92 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d return } - respBody, err = json.Marshal(dto.TaskResponse[any]{ - Code: "success", - Data: TaskModel2Dto(originTask), - }) + func() { + channelModel, err2 := model.GetChannelById(originTask.ChannelId, true) + if err2 != nil { + return + } + if channelModel.Type != constant.ChannelTypeVertexAi { + return + } + baseURL := constant.ChannelBaseURLs[channelModel.Type] + if channelModel.GetBaseURL() != "" { + baseURL = channelModel.GetBaseURL() + } + adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) + if adaptor == nil { + return + } + resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ + "task_id": originTask.TaskID, + "action": originTask.Action, + }) + if err2 != nil || resp == nil { + return + } + defer resp.Body.Close() + body, err2 := io.ReadAll(resp.Body) + if err2 != nil { + return + } + ti, err2 := adaptor.ParseTaskResult(body) + if err2 == nil && ti != nil { + if ti.Status != "" { + originTask.Status = model.TaskStatus(ti.Status) + } + if ti.Progress != "" { + originTask.Progress = ti.Progress + } + if ti.Url != "" { + originTask.FailReason = ti.Url + } + _ = originTask.Update() + var raw map[string]any + _ = json.Unmarshal(body, &raw) + format := "mp4" + if respObj, ok := raw["response"].(map[string]any); ok { + if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 { + if v0, ok := vids[0].(map[string]any); ok { + if mt, ok := v0["mimeType"].(string); ok && mt != "" { + if strings.Contains(mt, "mp4") { + format = "mp4" + } else { + format = mt + } + } + } + } + } + status := "processing" + switch originTask.Status { + case model.TaskStatusSuccess: + status = "succeeded" + case model.TaskStatusFailure: + status = "failed" + case model.TaskStatusQueued, model.TaskStatusSubmitted: + status = "queued" + } + out := map[string]any{ + "error": nil, + "format": format, + "metadata": nil, + "status": status, + "task_id": originTask.TaskID, + "url": originTask.FailReason, + } + respBody, _ = json.Marshal(dto.TaskResponse[any]{ + Code: "success", + Data: out, + }) + } + }() + + if len(respBody) == 0 { + respBody, err = json.Marshal(dto.TaskResponse[any]{ + Code: "success", + Data: TaskModel2Dto(originTask), + }) + } return }