From 81e29aaa3db696a180077f3960d04a23ecde0157 Mon Sep 17 00:00:00 2001 From: Sh1n3zZ Date: Tue, 26 Aug 2025 08:29:26 +0800 Subject: [PATCH 01/32] feat: vertex veo (#1450) --- common/database.go | 2 +- controller/setup.go | 2 +- controller/task_video.go | 42 ++- main.go | 2 +- middleware/distributor.go | 2 +- relay/channel/task/vertex/adaptor.go | 344 ++++++++++++++++++++++++ relay/channel/vertex/adaptor.go | 1 + relay/channel/vertex/relay-vertex.go | 5 +- relay/channel/vertex/service_account.go | 47 +++- relay/relay_adaptor.go | 6 +- relay/relay_task.go | 96 ++++++- 11 files changed, 534 insertions(+), 15 deletions(-) create mode 100644 relay/channel/task/vertex/adaptor.go diff --git a/common/database.go b/common/database.go index 71dbd94d5..38a54d5e6 100644 --- a/common/database.go +++ b/common/database.go @@ -12,4 +12,4 @@ var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries var UsingMySQL = false var UsingClickHouse = false -var SQLitePath = "one-api.db?_busy_timeout=30000" +var SQLitePath = "one-api.db?_busy_timeout=30000" \ No newline at end of file diff --git a/controller/setup.go b/controller/setup.go index 8943a1a02..44a7b3a73 100644 --- a/controller/setup.go +++ b/controller/setup.go @@ -178,4 +178,4 @@ func boolToString(b bool) string { return "true" } return "false" -} +} \ No newline at end of file diff --git a/controller/task_video.go b/controller/task_video.go index ffb6728ba..73d5c39b1 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -94,7 +94,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) } else { - task.Data = responseBody + task.Data = redactVideoResponseBody(responseBody) } now := time.Now().Unix() @@ -113,11 +113,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task.StartTime = now } case model.TaskStatusSuccess: - task.Progress = "100%" + task.Progress = "100%" if task.FinishTime == 0 { task.FinishTime = now } - task.FailReason = taskResult.Url + if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") { + task.FailReason = taskResult.Url + } case model.TaskStatusFailure: task.Status = model.TaskStatusFailure task.Progress = "100%" @@ -146,3 +148,37 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha return nil } + +func redactVideoResponseBody(body []byte) []byte { + var m map[string]any + if err := json.Unmarshal(body, &m); err != nil { + return body + } + resp, _ := m["response"].(map[string]any) + if resp != nil { + delete(resp, "bytesBase64Encoded") + if v, ok := resp["video"].(string); ok { + resp["video"] = truncateBase64(v) + } + if vs, ok := resp["videos"].([]any); ok { + for i := range vs { + if vm, ok := vs[i].(map[string]any); ok { + delete(vm, "bytesBase64Encoded") + } + } + } + } + b, err := json.Marshal(m) + if err != nil { + return body + } + return b +} + +func truncateBase64(s string) string { + const maxKeep = 256 + if len(s) <= maxKeep { + return s + } + return s[:maxKeep] + "..." +} diff --git a/main.go b/main.go index 2dfddaccf..91311b867 100644 --- a/main.go +++ b/main.go @@ -208,4 +208,4 @@ func InitResources() error { return err } return nil -} +} \ No newline at end of file diff --git a/middleware/distributor.go b/middleware/distributor.go index 1e6df872d..7fefeda49 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -166,9 +166,9 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("relay_mode", relayMode) } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") { - err = common.UnmarshalBodyReusable(c, &modelRequest) relayMode := relayconstant.RelayModeUnknown if c.Request.Method == http.MethodPost { + err = common.UnmarshalBodyReusable(c, &modelRequest) relayMode = relayconstant.RelayModeVideoSubmit } else if c.Request.Method == http.MethodGet { relayMode = relayconstant.RelayModeVideoFetchByID diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go new file mode 100644 index 000000000..d2ab826d0 --- /dev/null +++ b/relay/channel/task/vertex/adaptor.go @@ -0,0 +1,344 @@ +package vertex + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "regexp" + "strings" + + "github.com/gin-gonic/gin" + + "one-api/common" + "one-api/constant" + "one-api/dto" + "one-api/relay/channel" + vertexcore "one-api/relay/channel/vertex" + relaycommon "one-api/relay/common" + "one-api/service" +) + +type requestPayload struct { + Instances []map[string]any `json:"instances"` + Parameters map[string]any `json:"parameters,omitempty"` +} + +type submitResponse struct { + Name string `json:"name"` +} + +type operationVideo struct { + MimeType string `json:"mimeType"` + BytesBase64Encoded string `json:"bytesBase64Encoded"` + Encoding string `json:"encoding"` +} + +type operationResponse struct { + Name string `json:"name"` + Done bool `json:"done"` + Response struct { + Type string `json:"@type"` + RaiMediaFilteredCount int `json:"raiMediaFilteredCount"` + Videos []operationVideo `json:"videos"` + BytesBase64Encoded string `json:"bytesBase64Encoded"` + Encoding string `json:"encoding"` + Video string `json:"video"` + } `json:"response"` + Error struct { + Message string `json:"message"` + } `json:"error"` +} + +type TaskAdaptor struct{} + +func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {} + +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { + info.Action = constant.TaskActionTextGenerate + + req := relaycommon.TaskSubmitReq{} + if err := common.UnmarshalBodyReusable(c, &req); err != nil { + return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) + } + if strings.TrimSpace(req.Prompt) == "" { + return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) + } + c.Set("task_request", req) + return nil +} + +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { + adc := &vertexcore.Credentials{} + if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { + return "", fmt.Errorf("failed to decode credentials: %w", err) + } + modelName := info.OriginModelName + if v, ok := getRequestModelFromContext(info); ok { + modelName = v + } + if modelName == "" { + modelName = "veo-3.0-generate-001" + } + + region := vertexcore.GetModelRegion(info.ApiVersion, modelName) + if strings.TrimSpace(region) == "" { + region = "global" + } + if region == "global" { + return fmt.Sprintf( + "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning", + adc.ProjectID, + modelName, + ), nil + } + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning", + region, + adc.ProjectID, + region, + modelName, + ), nil +} + +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + adc := &vertexcore.Credentials{} + if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { + return fmt.Errorf("failed to decode credentials: %w", err) + } + + token, err := vertexcore.AcquireAccessToken(*adc, info.ChannelSetting.Proxy) + if err != nil { + return fmt.Errorf("failed to acquire access token: %w", err) + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("x-goog-user-project", adc.ProjectID) + return nil +} + +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) { + v, ok := c.Get("task_request") + if !ok { + return nil, fmt.Errorf("request not found in context") + } + req := v.(relaycommon.TaskSubmitReq) + + body := requestPayload{ + Instances: []map[string]any{{"prompt": req.Prompt}}, + Parameters: map[string]any{}, + } + if req.Metadata != nil { + if v, ok := req.Metadata["storageUri"]; ok { + body.Parameters["storageUri"] = v + } + if v, ok := req.Metadata["sampleCount"]; ok { + body.Parameters["sampleCount"] = v + } + } + if _, ok := body.Parameters["sampleCount"]; !ok { + body.Parameters["sampleCount"] = 1 + } + + data, err := json.Marshal(body) + if err != nil { + return nil, err + } + return bytes.NewReader(data), nil +} + +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + _ = resp.Body.Close() + + var s submitResponse + if err := json.Unmarshal(responseBody, &s); err != nil { + return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) + } + if strings.TrimSpace(s.Name) == "" { + return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError) + } + localID := encodeLocalTaskID(s.Name) + c.JSON(http.StatusOK, gin.H{"task_id": localID}) + return localID, responseBody, nil +} + +func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} } +func (a *TaskAdaptor) GetChannelName() string { return "vertex" } + +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + taskID, ok := body["task_id"].(string) + if !ok { + return nil, fmt.Errorf("invalid task_id") + } + upstreamName, err := decodeLocalTaskID(taskID) + if err != nil { + return nil, fmt.Errorf("decode task_id failed: %w", err) + } + region := extractRegionFromOperationName(upstreamName) + if region == "" { + region = "us-central1" + } + project := extractProjectFromOperationName(upstreamName) + model := extractModelFromOperationName(upstreamName) + if project == "" || model == "" { + return nil, fmt.Errorf("cannot extract project/model from operation name") + } + var url string + if region == "global" { + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, model) + } else { + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, model) + } + payload := map[string]string{"operationName": upstreamName} + data, err := json.Marshal(payload) + if err != nil { + return nil, err + } + adc := &vertexcore.Credentials{} + if err := json.Unmarshal([]byte(key), adc); err != nil { + return nil, fmt.Errorf("failed to decode credentials: %w", err) + } + token, err := vertexcore.AcquireAccessToken(*adc, "") + if err != nil { + return nil, fmt.Errorf("failed to acquire access token: %w", err) + } + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("x-goog-user-project", adc.ProjectID) + return service.GetHttpClient().Do(req) +} + +func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { + var op operationResponse + if err := json.Unmarshal(respBody, &op); err != nil { + return nil, fmt.Errorf("unmarshal operation response failed: %w", err) + } + ti := &relaycommon.TaskInfo{} + if op.Error.Message != "" { + ti.Status = "FAILURE" + ti.Reason = op.Error.Message + ti.Progress = "100%" + return ti, nil + } + if !op.Done { + ti.Status = "IN_PROGRESS" + ti.Progress = "50%" + return ti, nil + } + ti.Status = "SUCCESS" + ti.Progress = "100%" + if len(op.Response.Videos) > 0 { + v0 := op.Response.Videos[0] + if v0.BytesBase64Encoded != "" { + mime := strings.TrimSpace(v0.MimeType) + if mime == "" { + enc := strings.TrimSpace(v0.Encoding) + if enc == "" { + enc = "mp4" + } + if strings.Contains(enc, "/") { + mime = enc + } else { + mime = "video/" + enc + } + } + ti.Url = "data:" + mime + ";base64," + v0.BytesBase64Encoded + return ti, nil + } + } + if op.Response.BytesBase64Encoded != "" { + enc := strings.TrimSpace(op.Response.Encoding) + if enc == "" { + enc = "mp4" + } + mime := enc + if !strings.Contains(enc, "/") { + mime = "video/" + enc + } + ti.Url = "data:" + mime + ";base64," + op.Response.BytesBase64Encoded + return ti, nil + } + if op.Response.Video != "" { // some variants use `video` as base64 + enc := strings.TrimSpace(op.Response.Encoding) + if enc == "" { + enc = "mp4" + } + mime := enc + if !strings.Contains(enc, "/") { + mime = "video/" + enc + } + ti.Url = "data:" + mime + ";base64," + op.Response.Video + return ti, nil + } + return ti, nil +} + +func getRequestModelFromContext(info *relaycommon.TaskRelayInfo) (string, bool) { + return info.OriginModelName, info.OriginModelName != "" +} + +func encodeLocalTaskID(name string) string { + return base64.RawURLEncoding.EncodeToString([]byte(name)) +} + +func decodeLocalTaskID(local string) (string, error) { + b, err := base64.RawURLEncoding.DecodeString(local) + if err != nil { + return "", err + } + return string(b), nil +} + +var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`) + +func extractRegionFromOperationName(name string) string { + m := regionRe.FindStringSubmatch(name) + if len(m) == 2 { + return m[1] + } + return "" +} + +var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`) + +func extractModelFromOperationName(name string) string { + m := modelRe.FindStringSubmatch(name) + if len(m) == 2 { + return m[1] + } + idx := strings.Index(name, "models/") + if idx >= 0 { + s := name[idx+len("models/"):] + if p := strings.Index(s, "/operations/"); p > 0 { + return s[:p] + } + } + return "" +} + +var projectRe = regexp.MustCompile(`projects/([^/]+)/locations/`) + +func extractProjectFromOperationName(name string) string { + m := projectRe.FindStringSubmatch(name) + if len(m) == 2 { + return m[1] + } + return "" +} diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 0b6b26743..d15592bf8 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -174,6 +174,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return err } req.Set("Authorization", "Bearer "+accessToken) + req.Set("x-goog-user-project", a.AccountCredentials.ProjectID) return nil } diff --git a/relay/channel/vertex/relay-vertex.go b/relay/channel/vertex/relay-vertex.go index 5ed876654..f0b84906a 100644 --- a/relay/channel/vertex/relay-vertex.go +++ b/relay/channel/vertex/relay-vertex.go @@ -12,7 +12,10 @@ func GetModelRegion(other string, localModelName string) string { if m[localModelName] != nil { return m[localModelName].(string) } else { - return m["default"].(string) + if v, ok := m["default"]; ok { + return v.(string) + } + return "global" } } return other diff --git a/relay/channel/vertex/service_account.go b/relay/channel/vertex/service_account.go index 9a4650d98..f90d5454d 100644 --- a/relay/channel/vertex/service_account.go +++ b/relay/channel/vertex/service_account.go @@ -6,14 +6,15 @@ import ( "encoding/json" "encoding/pem" "errors" - "github.com/bytedance/gopkg/cache/asynccache" - "github.com/golang-jwt/jwt" "net/http" "net/url" relaycommon "one-api/relay/common" "one-api/service" "strings" + "github.com/bytedance/gopkg/cache/asynccache" + "github.com/golang-jwt/jwt" + "fmt" "time" ) @@ -137,3 +138,45 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s return "", fmt.Errorf("failed to get access token: %v", result) } + +func AcquireAccessToken(creds Credentials, proxy string) (string, error) { + signedJWT, err := createSignedJWT(creds.ClientEmail, creds.PrivateKey) + if err != nil { + return "", fmt.Errorf("failed to create signed JWT: %w", err) + } + return exchangeJwtForAccessTokenWithProxy(signedJWT, proxy) +} + +func exchangeJwtForAccessTokenWithProxy(signedJWT string, proxy string) (string, error) { + authURL := "https://www.googleapis.com/oauth2/v4/token" + data := url.Values{} + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer") + data.Set("assertion", signedJWT) + + var client *http.Client + var err error + if proxy != "" { + client, err = service.NewProxyHttpClient(proxy) + if err != nil { + return "", fmt.Errorf("new proxy http client failed: %w", err) + } + } else { + client = service.GetHttpClient() + } + + resp, err := client.PostForm(authURL, data) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", err + } + + if accessToken, ok := result["access_token"].(string); ok { + return accessToken, nil + } + return "", fmt.Errorf("failed to get access token: %v", result) +} diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 1ee85986c..0c271210b 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -1,7 +1,6 @@ package relay import ( - "github.com/gin-gonic/gin" "one-api/constant" "one-api/relay/channel" "one-api/relay/channel/ali" @@ -28,6 +27,7 @@ import ( taskjimeng "one-api/relay/channel/task/jimeng" "one-api/relay/channel/task/kling" "one-api/relay/channel/task/suno" + taskvertex "one-api/relay/channel/task/vertex" taskVidu "one-api/relay/channel/task/vidu" "one-api/relay/channel/tencent" "one-api/relay/channel/vertex" @@ -37,6 +37,8 @@ import ( "one-api/relay/channel/zhipu" "one-api/relay/channel/zhipu_4v" "strconv" + + "github.com/gin-gonic/gin" ) func GetAdaptor(apiType int) channel.Adaptor { @@ -126,6 +128,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor { return &kling.TaskAdaptor{} case constant.ChannelTypeJimeng: return &taskjimeng.TaskAdaptor{} + case constant.ChannelTypeVertexAi: + return &taskvertex.TaskAdaptor{} case constant.ChannelTypeVidu: return &taskVidu.TaskAdaptor{} } diff --git a/relay/relay_task.go b/relay/relay_task.go index 95b8083b3..6faec176d 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -15,6 +15,8 @@ import ( relayconstant "one-api/relay/constant" "one-api/service" "one-api/setting/ratio_setting" + "strconv" + "strings" "github.com/gin-gonic/gin" ) @@ -32,6 +34,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { if err != nil { return service.TaskErrorWrapper(err, "gen_relay_info_failed", http.StatusInternalServerError) } + relayInfo.InitChannelMeta(c) adaptor := GetTaskAdaptor(platform) if adaptor == nil { @@ -197,6 +200,9 @@ func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) { if taskErr != nil { return taskErr } + if len(respBody) == 0 { + respBody = []byte("{\"code\":\"success\",\"data\":null}") + } c.Writer.Header().Set("Content-Type", "application/json") _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody)) @@ -276,10 +282,92 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d return } - respBody, err = json.Marshal(dto.TaskResponse[any]{ - Code: "success", - Data: TaskModel2Dto(originTask), - }) + func() { + channelModel, err2 := model.GetChannelById(originTask.ChannelId, true) + if err2 != nil { + return + } + if channelModel.Type != constant.ChannelTypeVertexAi { + return + } + baseURL := constant.ChannelBaseURLs[channelModel.Type] + if channelModel.GetBaseURL() != "" { + baseURL = channelModel.GetBaseURL() + } + adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) + if adaptor == nil { + return + } + resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ + "task_id": originTask.TaskID, + "action": originTask.Action, + }) + if err2 != nil || resp == nil { + return + } + defer resp.Body.Close() + body, err2 := io.ReadAll(resp.Body) + if err2 != nil { + return + } + ti, err2 := adaptor.ParseTaskResult(body) + if err2 == nil && ti != nil { + if ti.Status != "" { + originTask.Status = model.TaskStatus(ti.Status) + } + if ti.Progress != "" { + originTask.Progress = ti.Progress + } + if ti.Url != "" { + originTask.FailReason = ti.Url + } + _ = originTask.Update() + var raw map[string]any + _ = json.Unmarshal(body, &raw) + format := "mp4" + if respObj, ok := raw["response"].(map[string]any); ok { + if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 { + if v0, ok := vids[0].(map[string]any); ok { + if mt, ok := v0["mimeType"].(string); ok && mt != "" { + if strings.Contains(mt, "mp4") { + format = "mp4" + } else { + format = mt + } + } + } + } + } + status := "processing" + switch originTask.Status { + case model.TaskStatusSuccess: + status = "succeeded" + case model.TaskStatusFailure: + status = "failed" + case model.TaskStatusQueued, model.TaskStatusSubmitted: + status = "queued" + } + out := map[string]any{ + "error": nil, + "format": format, + "metadata": nil, + "status": status, + "task_id": originTask.TaskID, + "url": originTask.FailReason, + } + respBody, _ = json.Marshal(dto.TaskResponse[any]{ + Code: "success", + Data: out, + }) + } + }() + + if len(respBody) == 0 { + respBody, err = json.Marshal(dto.TaskResponse[any]{ + Code: "success", + Data: TaskModel2Dto(originTask), + }) + } return } From af94e11c7da4895c55acbf737b1868d03fdb7729 Mon Sep 17 00:00:00 2001 From: yunayj Date: Fri, 29 Aug 2025 19:06:01 +0800 Subject: [PATCH 02/32] =?UTF-8?q?=E4=BF=AE=E6=94=B9claude=20system?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E4=B8=BA=E6=95=B0=E7=BB=84=E6=A0=BC=E5=BC=8F?= =?UTF-8?q?=EF=BC=8C=E6=8F=90=E5=8D=87API=E5=85=BC=E5=AE=B9=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/channel/claude/relay-claude.go | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 0c445bb9a..7550a97c8 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -274,19 +274,28 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe claudeMessages := make([]dto.ClaudeMessage, 0) isFirstMessage := true + // 初始化system消息数组,用于累积多个system消息 + var systemMessages []dto.ClaudeMediaMessage + for _, message := range formatMessages { if message.Role == "system" { + // 根据Claude API规范,system字段使用数组格式更有通用性 if message.IsStringContent() { - claudeRequest.System = message.StringContent() + systemMessages = append(systemMessages, dto.ClaudeMediaMessage{ + Type: "text", + Text: common.GetPointer[string](message.StringContent()), + }) } else { - contents := message.ParseContent() - content := "" - for _, ctx := range contents { + // 支持复合内容的system消息(虽然不常见,但需要考虑完整性) + for _, ctx := range message.ParseContent() { if ctx.Type == "text" { - content += ctx.Text + systemMessages = append(systemMessages, dto.ClaudeMediaMessage{ + Type: "text", + Text: common.GetPointer[string](ctx.Text), + }) } + // 未来可以在这里扩展对图片等其他类型的支持 } - claudeRequest.System = content } } else { if isFirstMessage { @@ -392,6 +401,12 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe claudeMessages = append(claudeMessages, claudeMessage) } } + + // 设置累积的system消息 + if len(systemMessages) > 0 { + claudeRequest.System = systemMessages + } + claudeRequest.Prompt = "" claudeRequest.Messages = claudeMessages return &claudeRequest, nil From 3064ff093a5a7705cdfd59555cbbaf20f0b3dc20 Mon Sep 17 00:00:00 2001 From: Calcium-Ion <61247483+Calcium-Ion@users.noreply.github.com> Date: Wed, 3 Sep 2025 14:45:00 +0800 Subject: [PATCH 03/32] Add request format conversion functionality Updated the features list to include request format conversion functionality and adjusted the order of items. --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 45b048340..48218cd70 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,11 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do - 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`) 16. 🔄 思考转内容功能 17. 🔄 针对用户的模型限流功能 -18. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费: +18. 🔄 请求格式转换功能,支持以下三种格式转换: + 1. OpenAI Chat Completions => Claude Messages + 2. Clade Messages => OpenAI Chat Completions (可用于Claude Code调用第三方模型) + 3. OpenAI Chat Completions => Gemini Chat +20. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费: 1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项 2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费 3. 支持的渠道: From b29efbde5263c777d111c02b858ac0e2e6516d67 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Sun, 7 Sep 2025 23:03:19 +0800 Subject: [PATCH 04/32] feat(relay-claude): mapping stop reason and send text delta on block start type - convert claude stop reason "max_tokens" to openai "length" - send content_block_start content text delta --- relay/channel/claude/relay-claude.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 0c445bb9a..3c5524fa9 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -32,7 +32,7 @@ func stopReasonClaude2OpenAI(reason string) string { case "end_turn": return "stop" case "max_tokens": - return "max_tokens" + return "length" case "tool_use": return "tool_calls" default: @@ -426,7 +426,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse choice.Delta.Role = "assistant" } else if claudeResponse.Type == "content_block_start" { if claudeResponse.ContentBlock != nil { - //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text) + // 如果是文本块,尽可能发送首段文本(若存在) + if claudeResponse.ContentBlock.Type == "text" && claudeResponse.ContentBlock.Text != nil { + choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text) + } if claudeResponse.ContentBlock.Type == "tool_use" { tools = append(tools, dto.ToolCallResponse{ Index: common.GetPointer(fcIdx), From c40a4f5444d123f1fb0eb5d109c852980871cd67 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Tue, 9 Sep 2025 23:18:07 +0800 Subject: [PATCH 05/32] fix: claude header was not set correctly --- relay/channel/aws/relay-aws.go | 7 ++++++- relay/channel/claude/relay-claude.go | 12 ++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 5822e363a..26e234fa3 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -130,7 +130,12 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* Usage: &dto.Usage{}, } - handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage) + // 复制上游 Content-Type 到客户端响应头 + if awsResp.ContentType != nil && *awsResp.ContentType != "" { + c.Writer.Header().Set("Content-Type", *awsResp.ContentType) + } + + handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, RequestModeMessage) if handlerErr != nil { return handlerErr, nil } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 511db2c6b..682256416 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -276,7 +276,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe isFirstMessage := true // 初始化system消息数组,用于累积多个system消息 var systemMessages []dto.ClaudeMediaMessage - + for _, message := range formatMessages { if message.Role == "system" { // 根据Claude API规范,system字段使用数组格式更有通用性 @@ -401,12 +401,12 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe claudeMessages = append(claudeMessages, claudeMessage) } } - + // 设置累积的system消息 if len(systemMessages) > 0 { claudeRequest.System = systemMessages } - + claudeRequest.Prompt = "" claudeRequest.Messages = claudeMessages return &claudeRequest, nil @@ -716,7 +716,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. return claudeInfo.Usage, nil } -func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError { +func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, httpResp *http.Response, data []byte, requestMode int) *types.NewAPIError { var claudeResponse dto.ClaudeResponse err := common.Unmarshal(data, &claudeResponse) if err != nil { @@ -754,7 +754,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests) } - service.IOCopyBytesGracefully(c, nil, responseData) + service.IOCopyBytesGracefully(c, httpResp, responseData) return nil } @@ -775,7 +775,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI if common.DebugEnabled { println("responseBody: ", string(responseBody)) } - handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode) + handleErr := HandleClaudeResponseData(c, info, claudeInfo, resp, responseBody, requestMode) if handleErr != nil { return nil, handleErr } From 041782c49e0289b9d2e64a318e81e4f75754dabf Mon Sep 17 00:00:00 2001 From: CaIon Date: Tue, 9 Sep 2025 23:23:53 +0800 Subject: [PATCH 06/32] chore: remove PR branching strategy workflow file --- .github/workflows/pr-target-branch-check.yml | 21 -------------------- 1 file changed, 21 deletions(-) delete mode 100644 .github/workflows/pr-target-branch-check.yml diff --git a/.github/workflows/pr-target-branch-check.yml b/.github/workflows/pr-target-branch-check.yml deleted file mode 100644 index e7bd4c817..000000000 --- a/.github/workflows/pr-target-branch-check.yml +++ /dev/null @@ -1,21 +0,0 @@ -name: Check PR Branching Strategy -on: - pull_request: - types: [opened, synchronize, reopened, edited] - -jobs: - check-branching-strategy: - runs-on: ubuntu-latest - steps: - - name: Enforce branching strategy - run: | - if [[ "${{ github.base_ref }}" == "main" ]]; then - if [[ "${{ github.head_ref }}" != "alpha" ]]; then - echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch." - exit 1 - fi - elif [[ "${{ github.base_ref }}" != "alpha" ]]; then - echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch." - exit 1 - fi - echo "Branching strategy check passed." \ No newline at end of file From 3f9698bb470a8a6b6499c79a5f98c9ba3cfafab4 Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Wed, 10 Sep 2025 15:29:07 +0800 Subject: [PATCH 07/32] =?UTF-8?q?feat:=20dalle=20=E8=87=AA=E5=AE=9A?= =?UTF-8?q?=E4=B9=89=E5=AD=97=E6=AE=B5=E9=80=8F=E4=BC=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dto/openai_image.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/dto/openai_image.go b/dto/openai_image.go index 9e838688e..bc888dc71 100644 --- a/dto/openai_image.go +++ b/dto/openai_image.go @@ -59,6 +59,29 @@ func (i *ImageRequest) UnmarshalJSON(data []byte) error { return nil } +// 序列化时需要重新把字段平铺 +func (r ImageRequest) MarshalJSON() ([]byte, error) { + // 将已定义字段转为 map + type Alias ImageRequest + alias := Alias(r) + base, err := json.Marshal(alias) + if err != nil { + return nil, err + } + + var baseMap map[string]json.RawMessage + if err := json.Unmarshal(base, &baseMap); err != nil { + return nil, err + } + + // 合并 ExtraFields + for k, v := range r.Extra { + baseMap[k] = v + } + + return json.Marshal(baseMap) +} + func GetJSONFieldNames(t reflect.Type) map[string]struct{} { fields := make(map[string]struct{}) for i := 0; i < t.NumField(); i++ { From fcdfd027cd0140c98861cdc8e05050846344a75e Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Wed, 10 Sep 2025 15:30:23 +0800 Subject: [PATCH 08/32] =?UTF-8?q?fix:=20openai=20=E6=A0=BC=E5=BC=8F?= =?UTF-8?q?=E8=AF=B7=E6=B1=82=20claude=20=E6=B2=A1=E8=AE=A1=E8=B4=B9=20cre?= =?UTF-8?q?ate=20cache=20token?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-test.go | 2 +- relay/audio_handler.go | 2 +- relay/claude_handler.go | 2 +- relay/compatible_handler.go | 20 ++++++++++++++++++-- relay/embedding_handler.go | 2 +- relay/gemini_handler.go | 4 ++-- relay/image_handler.go | 2 +- relay/rerank_handler.go | 2 +- relay/responses_handler.go | 2 +- service/error.go | 6 ++++-- 10 files changed, 31 insertions(+), 13 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 5fc6d749c..5a668c488 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -235,7 +235,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - err := service.RelayErrorHandler(httpResp, true) + err := service.RelayErrorHandler(c.Request.Context(), httpResp, true) return testResult{ context: c, localErr: err, diff --git a/relay/audio_handler.go b/relay/audio_handler.go index 711cc7a9b..1357e3816 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -53,7 +53,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 59c052f62..dbdc6ee1c 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -111,7 +111,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ httpResp = resp.(*http.Response) info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index a3c6ace6e..8f27fd60b 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -158,7 +158,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types httpResp = resp.(*http.Response) info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newApiErr := service.RelayErrorHandler(httpResp, false) + newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newApiErr, statusCodeMappingStr) return newApiErr @@ -195,6 +195,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage imageTokens := usage.PromptTokensDetails.ImageTokens audioTokens := usage.PromptTokensDetails.AudioTokens completionTokens := usage.CompletionTokens + cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens + modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") @@ -204,6 +206,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage modelRatio := relayInfo.PriceData.ModelRatio groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio modelPrice := relayInfo.PriceData.ModelPrice + cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio // Convert values to decimal for precise calculation dPromptTokens := decimal.NewFromInt(int64(promptTokens)) @@ -211,12 +214,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage dImageTokens := decimal.NewFromInt(int64(imageTokens)) dAudioTokens := decimal.NewFromInt(int64(audioTokens)) dCompletionTokens := decimal.NewFromInt(int64(completionTokens)) + dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens)) dCompletionRatio := decimal.NewFromFloat(completionRatio) dCacheRatio := decimal.NewFromFloat(cacheRatio) dImageRatio := decimal.NewFromFloat(imageRatio) dModelRatio := decimal.NewFromFloat(modelRatio) dGroupRatio := decimal.NewFromFloat(groupRatio) dModelPrice := decimal.NewFromFloat(modelPrice) + dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio) dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) ratio := dModelRatio.Mul(dGroupRatio) @@ -284,6 +289,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage baseTokens = baseTokens.Sub(dCacheTokens) cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio) } + var dCachedCreationTokensWithRatio decimal.Decimal + if !dCachedCreationTokens.IsZero() { + baseTokens = baseTokens.Sub(dCachedCreationTokens) + dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio) + } // 减去 image tokens var imageTokensWithRatio decimal.Decimal @@ -302,7 +312,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()) } } - promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio) + promptQuota := baseTokens.Add(cachedTokensWithRatio). + Add(imageTokensWithRatio). + Add(dCachedCreationTokensWithRatio) completionQuota := dCompletionTokens.Mul(dCompletionRatio) @@ -395,6 +407,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage other["image_ratio"] = imageRatio other["image_output"] = imageTokens } + if cachedCreationTokens != 0 { + other["cache_creation_tokens"] = cachedCreationTokens + other["cache_creation_ratio"] = cachedCreationRatio + } if !dWebSearchQuota.IsZero() { if relayInfo.ResponsesUsageInfo != nil { if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists { diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index 26dcf9719..3d8962bb4 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -58,7 +58,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 460fd2f58..0252d6578 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -152,7 +152,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ httpResp = resp.(*http.Response) info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError @@ -249,7 +249,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } diff --git a/relay/image_handler.go b/relay/image_handler.go index 14a7103c3..e2789ae5e 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -91,7 +91,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type httpResp = resp.(*http.Response) info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index fa3c7bbb4..46d2e25f6 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -81,7 +81,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/relay/responses_handler.go b/relay/responses_handler.go index f5f624c92..d1c5d2158 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -82,7 +82,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/service/error.go b/service/error.go index ef5cbbde6..5c3bddd6e 100644 --- a/service/error.go +++ b/service/error.go @@ -1,12 +1,14 @@ package service import ( + "context" "errors" "fmt" "io" "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/types" "strconv" "strings" @@ -78,7 +80,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude return claudeErr } -func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) { +func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) { newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode) responseBody, err := io.ReadAll(resp.Body) @@ -94,7 +96,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)) } else { if common.DebugEnabled { - println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))) + logger.LogInfo(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))) } newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode) } From 27a0a447d0cf12c3b527f3797f4140dacd6498bc Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Wed, 10 Sep 2025 15:31:35 +0800 Subject: [PATCH 09/32] =?UTF-8?q?fix:=20err=20=E5=A6=82=E6=9E=9C=E6=98=AF?= =?UTF-8?q?=20newApiErr=20=E5=88=99=E4=BF=9D=E7=95=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/channel/api_request.go | 3 +-- types/error.go | 34 ++++++++++++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index a50d5bdb5..a065caff7 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -264,9 +264,8 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http } resp, err := client.Do(req) - if err != nil { - return nil, err + return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed")) } if resp == nil { return nil, errors.New("resp is nil") diff --git a/types/error.go b/types/error.go index f653e9a28..883ee0641 100644 --- a/types/error.go +++ b/types/error.go @@ -185,6 +185,14 @@ func (e *NewAPIError) ToClaudeError() ClaudeError { type NewAPIErrorOptions func(*NewAPIError) func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError { + var newErr *NewAPIError + // 保留深层传递的 new err + if errors.As(err, &newErr) { + for _, op := range ops { + op(newErr) + } + return newErr + } e := &NewAPIError{ Err: err, RelayError: nil, @@ -199,8 +207,21 @@ func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPI } func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { - if errorCode == ErrorCodeDoRequestFailed { - err = errors.New("upstream error: do request failed") + var newErr *NewAPIError + // 保留深层传递的 new err + if errors.As(err, &newErr) { + if newErr.RelayError == nil { + openaiError := OpenAIError{ + Message: newErr.Error(), + Type: string(errorCode), + Code: errorCode, + } + newErr.RelayError = openaiError + } + for _, op := range ops { + op(newErr) + } + return newErr } openaiError := OpenAIError{ Message: err.Error(), @@ -305,6 +326,15 @@ func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions { } } +func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions { + return func(e *NewAPIError) { + if common.DebugEnabled { + fmt.Printf("ErrOptionWithHideErrMsg: %s, origin error: %s", replaceStr, e.Err) + } + e.Err = errors.New(replaceStr) + } +} + func IsRecordErrorLog(e *NewAPIError) bool { if e == nil { return false From cda73a2ec5be50c8b6723b8a84440845a4b30f45 Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Wed, 10 Sep 2025 19:53:32 +0800 Subject: [PATCH 10/32] =?UTF-8?q?fix:=20dalle=20log=20=E6=98=BE=E7=A4=BA?= =?UTF-8?q?=E5=BC=A0=E6=95=B0=20N?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/relay.go | 13 ++++++------- relay/image_handler.go | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index d3d93192e..07c3aeaac 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -277,14 +277,13 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) { logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) - - gopool.Go(func() { - // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 - // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously - if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan { + // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 + // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously + if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan { + gopool.Go(func() { service.DisableChannel(channelError, err.Error()) - } - }) + }) + } if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) { // 保存错误日志到mysql中 diff --git a/relay/image_handler.go b/relay/image_handler.go index e2789ae5e..9c873d47f 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -120,7 +120,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type var logContent string if len(request.Size) > 0 { - logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality) + logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N) } postConsumeQuota(c, info, usage.(*dto.Usage), logContent) From 47aaa695b2c90b0a169a3010b0e91ab4d5fe9640 Mon Sep 17 00:00:00 2001 From: huanghejian Date: Wed, 10 Sep 2025 20:30:00 +0800 Subject: [PATCH 11/32] feat: support amazon nova --- relay/channel/aws/adaptor.go | 10 +++++ relay/channel/aws/constants.go | 11 +++++ relay/channel/aws/dto.go | 53 +++++++++++++++++++++++ relay/channel/aws/relay-aws.go | 78 ++++++++++++++++++++++++++++++++++ 4 files changed, 152 insertions(+) diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 1526a7f75..9d5e5891e 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -60,7 +60,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } + // 检查是否为Nova模型 + if isNovaModel(request.Model) { + novaReq := convertToNovaRequest(request) + c.Set("request_model", request.Model) + c.Set("converted_request", novaReq) + c.Set("is_nova_model", true) + return novaReq, nil + } + // 原有的Claude模型处理逻辑 var claudeReq *dto.ClaudeRequest var err error claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request) @@ -69,6 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn } c.Set("request_model", claudeReq.Model) c.Set("converted_request", claudeReq) + c.Set("is_nova_model", false) return claudeReq, err } diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go index 3f8800b1e..8ed8f0318 100644 --- a/relay/channel/aws/constants.go +++ b/relay/channel/aws/constants.go @@ -1,5 +1,7 @@ package aws +import "strings" + var awsModelIDMap = map[string]string{ "claude-instant-1.2": "anthropic.claude-instant-v1", "claude-2.0": "anthropic.claude-v2", @@ -14,6 +16,10 @@ var awsModelIDMap = map[string]string{ "claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0", "claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0", "claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0", + // Nova models + "amazon.nova-micro-v1:0": "us.amazon.nova-micro-v1:0", + "amazon.nova-lite-v1:0": "us.amazon.nova-lite-v1:0", + "amazon.nova-pro-v1:0": "us.amazon.nova-pro-v1:0", } var awsModelCanCrossRegionMap = map[string]map[string]bool{ @@ -67,3 +73,8 @@ var awsRegionCrossModelPrefixMap = map[string]string{ } var ChannelName = "aws" + +// 判断是否为Nova模型 +func isNovaModel(modelId string) bool { + return strings.HasPrefix(modelId, "amazon.nova-") +} diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go index 0188c30a9..25851ff6f 100644 --- a/relay/channel/aws/dto.go +++ b/relay/channel/aws/dto.go @@ -34,3 +34,56 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest { Thinking: req.Thinking, } } + +// Nova模型使用messages-v1格式 +type NovaMessage struct { + Role string `json:"role"` + Content []NovaContent `json:"content"` +} + +type NovaContent struct { + Text string `json:"text"` +} + +type NovaRequest struct { + SchemaVersion string `json:"schemaVersion"` + Messages []NovaMessage `json:"messages"` + InferenceConfig NovaInferenceConfig `json:"inferenceConfig,omitempty"` +} + +type NovaInferenceConfig struct { + MaxTokens int `json:"maxTokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` +} + +// 转换OpenAI请求为Nova格式 +func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest { + novaMessages := make([]NovaMessage, len(req.Messages)) + for i, msg := range req.Messages { + novaMessages[i] = NovaMessage{ + Role: msg.Role, + Content: []NovaContent{{Text: msg.StringContent()}}, + } + } + + novaReq := &NovaRequest{ + SchemaVersion: "messages-v1", + Messages: novaMessages, + } + + // 设置推理配置 + if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 { + if req.MaxTokens != 0 { + novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens) + } + if req.Temperature != nil && *req.Temperature != 0 { + novaReq.InferenceConfig.Temperature = *req.Temperature + } + if req.TopP != 0 { + novaReq.InferenceConfig.TopP = req.TopP + } + } + + return novaReq +} diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 26e234fa3..3df6b33dd 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -1,6 +1,7 @@ package aws import ( + "encoding/json" "fmt" "net/http" "one-api/common" @@ -93,7 +94,13 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* } awsModelId := awsModelID(c.GetString("request_model")) + // 检查是否为Nova模型 + isNova, _ := c.Get("is_nova_model") + if isNova == true { + return handleNovaRequest(c, awsCli, info, awsModelId) + } + // 原有的Claude处理逻辑 awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region) canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix) if canCrossRegion { @@ -209,3 +216,74 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage) return nil, claudeInfo.Usage } + +// Nova模型处理函数 +func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) { + novaReq_, ok := c.Get("converted_request") + if !ok { + return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil + } + novaReq := novaReq_.(*NovaRequest) + + // 使用InvokeModel API,但使用Nova格式的请求体 + awsReq := &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + reqBody, err := json.Marshal(novaReq) + if err != nil { + return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil + } + awsReq.Body = reqBody + + awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) + if err != nil { + return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil + } + + // 解析Nova响应 + var novaResp struct { + Output struct { + Message struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + } `json:"message"` + } `json:"output"` + Usage struct { + InputTokens int `json:"inputTokens"` + OutputTokens int `json:"outputTokens"` + TotalTokens int `json:"totalTokens"` + } `json:"usage"` + } + + if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil { + return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil + } + + // 构造OpenAI格式响应 + response := dto.OpenAITextResponse{ + Id: helper.GetResponseID(c), + Object: "chat.completion", + Created: common.GetTimestamp(), + Model: info.UpstreamModelName, + Choices: []dto.OpenAITextResponseChoice{{ + Index: 0, + Message: dto.Message{ + Role: "assistant", + Content: novaResp.Output.Message.Content[0].Text, + }, + FinishReason: "stop", + }}, + Usage: dto.Usage{ + PromptTokens: novaResp.Usage.InputTokens, + CompletionTokens: novaResp.Usage.OutputTokens, + TotalTokens: novaResp.Usage.TotalTokens, + }, + } + + c.JSON(http.StatusOK, response) + return nil, &response.Usage +} From 684caa36731ea63ab19a630a29debfbb26d435ec Mon Sep 17 00:00:00 2001 From: huanghejian Date: Thu, 11 Sep 2025 10:01:54 +0800 Subject: [PATCH 12/32] feat: amazon.nova-premier-v1:0 --- relay/channel/aws/constants.go | 7 ++++--- relay/channel/aws/dto.go | 16 +++++++++------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go index 8ed8f0318..7f18d57a1 100644 --- a/relay/channel/aws/constants.go +++ b/relay/channel/aws/constants.go @@ -17,9 +17,10 @@ var awsModelIDMap = map[string]string{ "claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0", "claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0", // Nova models - "amazon.nova-micro-v1:0": "us.amazon.nova-micro-v1:0", - "amazon.nova-lite-v1:0": "us.amazon.nova-lite-v1:0", - "amazon.nova-pro-v1:0": "us.amazon.nova-pro-v1:0", + "amazon.nova-micro-v1:0": "us.amazon.nova-micro-v1:0", + "amazon.nova-lite-v1:0": "us.amazon.nova-lite-v1:0", + "amazon.nova-pro-v1:0": "us.amazon.nova-pro-v1:0", + "amazon.nova-premier-v1:0": "us.amazon.nova-premier-v1:0", } var awsModelCanCrossRegionMap = map[string]map[string]bool{ diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go index 25851ff6f..cef16c11f 100644 --- a/relay/channel/aws/dto.go +++ b/relay/channel/aws/dto.go @@ -35,7 +35,7 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest { } } -// Nova模型使用messages-v1格式 +// NovaMessage Nova模型使用messages-v1格式 type NovaMessage struct { Role string `json:"role"` Content []NovaContent `json:"content"` @@ -46,15 +46,17 @@ type NovaContent struct { } type NovaRequest struct { - SchemaVersion string `json:"schemaVersion"` - Messages []NovaMessage `json:"messages"` - InferenceConfig NovaInferenceConfig `json:"inferenceConfig,omitempty"` + SchemaVersion string `json:"schemaVersion"` // 请求版本,例如 "1.0" + Messages []NovaMessage `json:"messages"` // 对话消息列表 + InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"` // 推理配置,可选 } type NovaInferenceConfig struct { - MaxTokens int `json:"maxTokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` + MaxTokens int `json:"maxTokens,omitempty"` // 最大生成的 token 数 + Temperature float64 `json:"temperature,omitempty"` // 随机性 (默认 0.7, 范围 0-1) + TopP float64 `json:"topP,omitempty"` // nucleus sampling (默认 0.9, 范围 0-1) + TopK int `json:"topK,omitempty"` // 限制候选 token 数 (默认 50, 范围 0-128) + StopSequences []string `json:"stopSequences,omitempty"` // 停止生成的序列 } // 转换OpenAI请求为Nova格式 From e3bc40f11b8bd3c57ca3435ba09af0b5b65a1c56 Mon Sep 17 00:00:00 2001 From: huanghejian Date: Thu, 11 Sep 2025 12:17:16 +0800 Subject: [PATCH 13/32] pref: support amazon nova --- relay/channel/aws/constants.go | 32 ++++++++++++++++++++++++++------ relay/channel/aws/dto.go | 1 + relay/channel/aws/relay-aws.go | 6 ++++++ 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go index 7f18d57a1..72d0f9890 100644 --- a/relay/channel/aws/constants.go +++ b/relay/channel/aws/constants.go @@ -17,10 +17,10 @@ var awsModelIDMap = map[string]string{ "claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0", "claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0", // Nova models - "amazon.nova-micro-v1:0": "us.amazon.nova-micro-v1:0", - "amazon.nova-lite-v1:0": "us.amazon.nova-lite-v1:0", - "amazon.nova-pro-v1:0": "us.amazon.nova-pro-v1:0", - "amazon.nova-premier-v1:0": "us.amazon.nova-premier-v1:0", + "nova-micro-v1:0": "amazon.nova-micro-v1:0", + "nova-lite-v1:0": "amazon.nova-lite-v1:0", + "nova-pro-v1:0": "amazon.nova-pro-v1:0", + "nova-premier-v1:0": "amazon.nova-premier-v1:0", } var awsModelCanCrossRegionMap = map[string]map[string]bool{ @@ -65,7 +65,27 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{ "anthropic.claude-opus-4-1-20250805-v1:0": { "us": true, }, -} + // Nova models - all support three major regions + "amazon.nova-micro-v1:0": { + "us": true, + "eu": true, + "apac": true, + }, + "amazon.nova-lite-v1:0": { + "us": true, + "eu": true, + "apac": true, + }, + "amazon.nova-pro-v1:0": { + "us": true, + "eu": true, + "apac": true, + }, + "amazon.nova-premier-v1:0": { + "us": true, + "eu": true, + "apac": true, + }} var awsRegionCrossModelPrefixMap = map[string]string{ "us": "us", @@ -77,5 +97,5 @@ var ChannelName = "aws" // 判断是否为Nova模型 func isNovaModel(modelId string) bool { - return strings.HasPrefix(modelId, "amazon.nova-") + return strings.HasPrefix(modelId, "nova-") } diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go index cef16c11f..53daef288 100644 --- a/relay/channel/aws/dto.go +++ b/relay/channel/aws/dto.go @@ -76,6 +76,7 @@ func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest { // 设置推理配置 if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 { + novaReq.InferenceConfig = &NovaInferenceConfig{} if req.MaxTokens != 0 { novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens) } diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 3df6b33dd..eef26855a 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -97,6 +97,12 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* // 检查是否为Nova模型 isNova, _ := c.Get("is_nova_model") if isNova == true { + // Nova模型也支持跨区域 + awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region) + canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix) + if canCrossRegion { + awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix) + } return handleNovaRequest(c, awsCli, info, awsModelId) } From db6a788e0d4798c62922714a8e33d3f4780f095e Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Thu, 11 Sep 2025 12:28:57 +0800 Subject: [PATCH 14/32] =?UTF-8?q?fix:=20=E4=BC=98=E5=8C=96=20ImageRequest?= =?UTF-8?q?=20=E7=9A=84=20JSON=20=E5=BA=8F=E5=88=97=E5=8C=96=EF=BC=8C?= =?UTF-8?q?=E9=81=BF=E5=85=8D=E8=A6=86=E7=9B=96=E5=90=88=E5=B9=B6=20ExtraF?= =?UTF-8?q?ields?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dto/openai_image.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dto/openai_image.go b/dto/openai_image.go index bc888dc71..5aece25f2 100644 --- a/dto/openai_image.go +++ b/dto/openai_image.go @@ -64,19 +64,21 @@ func (r ImageRequest) MarshalJSON() ([]byte, error) { // 将已定义字段转为 map type Alias ImageRequest alias := Alias(r) - base, err := json.Marshal(alias) + base, err := common.Marshal(alias) if err != nil { return nil, err } var baseMap map[string]json.RawMessage - if err := json.Unmarshal(base, &baseMap); err != nil { + if err := common.Unmarshal(base, &baseMap); err != nil { return nil, err } // 合并 ExtraFields for k, v := range r.Extra { - baseMap[k] = v + if _, exists := baseMap[k]; !exists { + baseMap[k] = v + } } return json.Marshal(baseMap) From 70c27bc662fd4edb6487261538208bc0a2e802a9 Mon Sep 17 00:00:00 2001 From: huanghejian Date: Thu, 11 Sep 2025 12:31:43 +0800 Subject: [PATCH 15/32] feat: improve nova config --- relay/channel/aws/dto.go | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go index 53daef288..9c9fe946f 100644 --- a/relay/channel/aws/dto.go +++ b/relay/channel/aws/dto.go @@ -75,7 +75,7 @@ func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest { } // 设置推理配置 - if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 { + if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 || req.TopK != 0 || req.Stop != nil { novaReq.InferenceConfig = &NovaInferenceConfig{} if req.MaxTokens != 0 { novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens) @@ -86,7 +86,40 @@ func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest { if req.TopP != 0 { novaReq.InferenceConfig.TopP = req.TopP } + if req.TopK != 0 { + novaReq.InferenceConfig.TopK = req.TopK + } + if req.Stop != nil { + if stopSequences := parseStopSequences(req.Stop); len(stopSequences) > 0 { + novaReq.InferenceConfig.StopSequences = stopSequences + } + } } return novaReq } + +// parseStopSequences 解析停止序列,支持字符串或字符串数组 +func parseStopSequences(stop any) []string { + if stop == nil { + return nil + } + + switch v := stop.(type) { + case string: + if v != "" { + return []string{v} + } + case []string: + return v + case []interface{}: + var sequences []string + for _, item := range v { + if str, ok := item.(string); ok && str != "" { + sequences = append(sequences, str) + } + } + return sequences + } + return nil +} From b25ac0bfb69ba6a5f1bd3f352567c7c8ad9a8f9e Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Thu, 11 Sep 2025 16:04:32 +0800 Subject: [PATCH 16/32] =?UTF-8?q?fix:=20=E9=A2=84=E6=89=A3=E9=A2=9D?= =?UTF-8?q?=E5=BA=A6=E4=BD=BF=E7=94=A8=20relay=20info=20=E4=BC=A0=E9=80=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/relay.go | 6 +++--- service/pre_consume_quota.go | 22 +++++++++++----------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 07c3aeaac..23d725153 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -139,15 +139,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { // common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta) - preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if newAPIError != nil { return } defer func() { // Only return quota if downstream failed and quota was actually pre-consumed - if newAPIError != nil && preConsumedQuota != 0 { - service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota) + if newAPIError != nil && relayInfo.FinalPreConsumedQuota != 0 { + service.ReturnPreConsumedQuota(c, relayInfo) } }() diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go index 86b04e526..3cfabc1a4 100644 --- a/service/pre_consume_quota.go +++ b/service/pre_consume_quota.go @@ -13,13 +13,13 @@ import ( "github.com/gin-gonic/gin" ) -func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) { - if preConsumedQuota != 0 { - logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota))) +func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) { + if relayInfo.FinalPreConsumedQuota != 0 { + logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(relayInfo.FinalPreConsumedQuota))) gopool.Go(func() { relayInfoCopy := *relayInfo - err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) + err := PostConsumeQuota(&relayInfoCopy, -relayInfo.FinalPreConsumedQuota, 0, false) if err != nil { common.SysLog("error return pre-consumed quota: " + err.Error()) } @@ -29,16 +29,16 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr // PreConsumeQuota checks if the user has enough quota to pre-consume. // It returns the pre-consumed quota if successful, or an error if not. -func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) { +func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError { userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { - return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) + return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) } if userQuota <= 0 { - return 0, types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) } if userQuota-preConsumedQuota < 0 { - return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) } trustQuota := common.GetTrustQuota() @@ -65,14 +65,14 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if preConsumedQuota > 0 { err := PreConsumeTokenQuota(relayInfo, preConsumedQuota) if err != nil { - return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) } err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) if err != nil { - return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) + return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) } logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota))) } relayInfo.FinalPreConsumedQuota = preConsumedQuota - return preConsumedQuota, nil + return nil } From 93adcd57d7d851d90ee051e1daf8db7ea6b52655 Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 11 Sep 2025 21:02:12 +0800 Subject: [PATCH 17/32] fix(responses): allow pass-through body for specific channel settings. (close #1762) --- relay/responses_handler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/responses_handler.go b/relay/responses_handler.go index d1c5d2158..0c57a303f 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -41,7 +41,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * } adaptor.Init(info) var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry()) From b6c547ae982e83e34a1182578d68e3a8a9e86cf6 Mon Sep 17 00:00:00 2001 From: Zhaokun Zhang Date: Thu, 11 Sep 2025 21:34:49 +0800 Subject: [PATCH 18/32] =?UTF-8?q?fix:=20UI=20=E6=9C=AA=E5=AF=B9=E9=BD=90?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/pages/Setting/Operation/SettingsGeneral.jsx | 2 +- web/src/pages/Setting/Operation/SettingsHeaderNavModules.jsx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/web/src/pages/Setting/Operation/SettingsGeneral.jsx b/web/src/pages/Setting/Operation/SettingsGeneral.jsx index c94c0dd5a..37b3dd984 100644 --- a/web/src/pages/Setting/Operation/SettingsGeneral.jsx +++ b/web/src/pages/Setting/Operation/SettingsGeneral.jsx @@ -194,7 +194,7 @@ export default function GeneralSettings(props) { /> - + From e68eed3d400785401c74f5bb3db21fd8b2f27b6a Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 12 Sep 2025 14:06:09 +0800 Subject: [PATCH 19/32] feat(channel): add support for Vertex AI key type configuration in settings --- controller/channel.go | 7 +- dto/channel_settings.go | 10 +- model/channel.go | 3 +- relay/channel/vertex/adaptor.go | 116 ++++++++------- .../channels/modals/EditChannelModal.jsx | 133 +++++++++++------- 5 files changed, 166 insertions(+), 103 deletions(-) diff --git a/controller/channel.go b/controller/channel.go index 70be91d42..403eb04cc 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/constant" + "one-api/dto" "one-api/model" "strconv" "strings" @@ -560,7 +561,7 @@ func AddChannel(c *gin.Context) { case "multi_to_single": addChannelRequest.Channel.ChannelInfo.IsMultiKey = true addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode - if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { + if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey { array, err := getVertexArrayKeys(addChannelRequest.Channel.Key) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -585,7 +586,7 @@ func AddChannel(c *gin.Context) { } keys = []string{addChannelRequest.Channel.Key} case "batch": - if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { + if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey { // multi json keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key) if err != nil { @@ -840,7 +841,7 @@ func UpdateChannel(c *gin.Context) { } // 处理 Vertex AI 的特殊情况 - if channel.Type == constant.ChannelTypeVertexAi { + if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey { // 尝试解析新密钥为JSON数组 if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") { array, err := getVertexArrayKeys(channel.Key) diff --git a/dto/channel_settings.go b/dto/channel_settings.go index 2c58795cb..8791f516e 100644 --- a/dto/channel_settings.go +++ b/dto/channel_settings.go @@ -9,6 +9,14 @@ type ChannelSettings struct { SystemPromptOverride bool `json:"system_prompt_override,omitempty"` } +type VertexKeyType string + +const ( + VertexKeyTypeJSON VertexKeyType = "json" + VertexKeyTypeAPIKey VertexKeyType = "api_key" +) + type ChannelOtherSettings struct { - AzureResponsesVersion string `json:"azure_responses_version,omitempty"` + AzureResponsesVersion string `json:"azure_responses_version,omitempty"` + VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key" } diff --git a/model/channel.go b/model/channel.go index a61b3eccf..534e2f3f2 100644 --- a/model/channel.go +++ b/model/channel.go @@ -42,7 +42,6 @@ type Channel struct { Priority *int64 `json:"priority" gorm:"bigint;default:0"` AutoBan *int `json:"auto_ban" gorm:"default:1"` OtherInfo string `json:"other_info"` - OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置 Tag *string `json:"tag" gorm:"index"` Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置 ParamOverride *string `json:"param_override" gorm:"type:text"` @@ -51,6 +50,8 @@ type Channel struct { // add after v0.8.5 ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"` + OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置,存储azure版本等不需要检索的信息,详见dto.ChannelOtherSettings + // cache info Keys []string `json:"-" gorm:"-"` } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 0b6b26743..b6a78b7aa 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "one-api/common" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/claude" @@ -80,16 +81,64 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } } -func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - adc := &Credentials{} - if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { - return "", fmt.Errorf("failed to decode credentials file: %w", err) - } +func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix string) (string, error) { region := GetModelRegion(info.ApiVersion, info.OriginModelName) - a.AccountCredentials = *adc + if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey { + adc := &Credentials{} + if err := common.Unmarshal([]byte(info.ApiKey), adc); err != nil { + return "", fmt.Errorf("failed to decode credentials file: %w", err) + } + a.AccountCredentials = *adc + + if a.RequestMode == RequestModeLlama { + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", + region, + adc.ProjectID, + region, + ), nil + } + + if region == "global" { + return fmt.Sprintf( + "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s", + adc.ProjectID, + modelName, + suffix, + ), nil + } else { + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", + region, + adc.ProjectID, + region, + modelName, + suffix, + ), nil + } + } else { + if region == "global" { + return fmt.Sprintf( + "https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s", + modelName, + suffix, + info.ApiKey, + ), nil + } else { + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s", + region, + modelName, + suffix, + info.ApiKey, + ), nil + } + } +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { suffix := "" if a.RequestMode == RequestModeGemini { - if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { // 新增逻辑:处理 -thinking- 格式 if strings.Contains(info.UpstreamModelName, "-thinking-") { @@ -112,23 +161,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { suffix = "predict" } - if region == "global" { - return fmt.Sprintf( - "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s", - adc.ProjectID, - info.UpstreamModelName, - suffix, - ), nil - } else { - return fmt.Sprintf( - "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", - region, - adc.ProjectID, - region, - info.UpstreamModelName, - suffix, - ), nil - } + return a.getRequestUrl(info, info.UpstreamModelName, suffix) } else if a.RequestMode == RequestModeClaude { if info.IsStream { suffix = "streamRawPredict?alt=sse" @@ -139,41 +172,22 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if v, ok := claudeModelMap[info.UpstreamModelName]; ok { model = v } - if region == "global" { - return fmt.Sprintf( - "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s", - adc.ProjectID, - model, - suffix, - ), nil - } else { - return fmt.Sprintf( - "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", - region, - adc.ProjectID, - region, - model, - suffix, - ), nil - } + return a.getRequestUrl(info, model, suffix) } else if a.RequestMode == RequestModeLlama { - return fmt.Sprintf( - "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", - region, - adc.ProjectID, - region, - ), nil + return a.getRequestUrl(info, "", "") } return "", errors.New("unsupported request mode") } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - accessToken, err := getAccessToken(a, info) - if err != nil { - return err + if info.ChannelOtherSettings.VertexKeyType == "json" { + accessToken, err := getAccessToken(a, info) + if err != nil { + return err + } + req.Set("Authorization", "Bearer "+accessToken) } - req.Set("Authorization", "Bearer "+accessToken) return nil } diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index 7a86fa114..c0a216246 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -142,6 +142,8 @@ const EditChannelModal = (props) => { system_prompt: '', system_prompt_override: false, settings: '', + // 仅 Vertex: 密钥格式(存入 settings.vertex_key_type) + vertex_key_type: 'json', }; const [batch, setBatch] = useState(false); const [multiToSingle, setMultiToSingle] = useState(false); @@ -409,11 +411,17 @@ const EditChannelModal = (props) => { const parsedSettings = JSON.parse(data.settings); data.azure_responses_version = parsedSettings.azure_responses_version || ''; + // 读取 Vertex 密钥格式 + data.vertex_key_type = parsedSettings.vertex_key_type || 'json'; } catch (error) { console.error('解析其他设置失败:', error); data.azure_responses_version = ''; data.region = ''; + data.vertex_key_type = 'json'; } + } else { + // 兼容历史数据:老渠道没有 settings 时,默认按 json 展示 + data.vertex_key_type = 'json'; } setInputs(data); @@ -745,59 +753,56 @@ const EditChannelModal = (props) => { let localInputs = { ...formValues }; if (localInputs.type === 41) { - if (useManualInput) { - // 手动输入模式 - if (localInputs.key && localInputs.key.trim() !== '') { - try { - // 验证 JSON 格式 - const parsedKey = JSON.parse(localInputs.key); - // 确保是有效的密钥格式 - localInputs.key = JSON.stringify(parsedKey); - } catch (err) { - showError(t('密钥格式无效,请输入有效的 JSON 格式密钥')); - return; - } - } else if (!isEdit) { + const keyType = localInputs.vertex_key_type || 'json'; + if (keyType === 'api_key') { + // 直接作为普通字符串密钥处理 + if (!isEdit && (!localInputs.key || localInputs.key.trim() === '')) { showInfo(t('请输入密钥!')); return; } } else { - // 文件上传模式 - let keys = vertexKeys; - - // 若当前未选择文件,尝试从已上传文件列表解析(异步读取) - if (keys.length === 0 && vertexFileList.length > 0) { - try { - const parsed = await Promise.all( - vertexFileList.map(async (item) => { - const fileObj = item.fileInstance; - if (!fileObj) return null; - const txt = await fileObj.text(); - return JSON.parse(txt); - }), - ); - keys = parsed.filter(Boolean); - } catch (err) { - showError(t('解析密钥文件失败: {{msg}}', { msg: err.message })); + // JSON 服务账号密钥 + if (useManualInput) { + if (localInputs.key && localInputs.key.trim() !== '') { + try { + const parsedKey = JSON.parse(localInputs.key); + localInputs.key = JSON.stringify(parsedKey); + } catch (err) { + showError(t('密钥格式无效,请输入有效的 JSON 格式密钥')); + return; + } + } else if (!isEdit) { + showInfo(t('请输入密钥!')); return; } - } - - // 创建模式必须上传密钥;编辑模式可选 - if (keys.length === 0) { - if (!isEdit) { - showInfo(t('请上传密钥文件!')); - return; - } else { - // 编辑模式且未上传新密钥,不修改 key - delete localInputs.key; - } } else { - // 有新密钥,则覆盖 - if (batch) { - localInputs.key = JSON.stringify(keys); + // 文件上传模式 + let keys = vertexKeys; + if (keys.length === 0 && vertexFileList.length > 0) { + try { + const parsed = await Promise.all( + vertexFileList.map(async (item) => { + const fileObj = item.fileInstance; + if (!fileObj) return null; + const txt = await fileObj.text(); + return JSON.parse(txt); + }), + ); + keys = parsed.filter(Boolean); + } catch (err) { + showError(t('解析密钥文件失败: {{msg}}', { msg: err.message })); + return; + } + } + if (keys.length === 0) { + if (!isEdit) { + showInfo(t('请上传密钥文件!')); + return; + } else { + delete localInputs.key; + } } else { - localInputs.key = JSON.stringify(keys[0]); + localInputs.key = batch ? JSON.stringify(keys) : JSON.stringify(keys[0]); } } } @@ -853,6 +858,8 @@ const EditChannelModal = (props) => { delete localInputs.pass_through_body_enabled; delete localInputs.system_prompt; delete localInputs.system_prompt_override; + // 顶层的 vertex_key_type 不应发送给后端 + delete localInputs.vertex_key_type; let res; localInputs.auto_ban = localInputs.auto_ban ? 1 : 0; @@ -1178,8 +1185,40 @@ const EditChannelModal = (props) => { autoComplete='new-password' /> + {inputs.type === 41 && ( + { + // 更新设置中的 vertex_key_type + handleChannelOtherSettingsChange('vertex_key_type', value); + // 切换为 api_key 时,关闭批量与手动/文件切换,并清理已选文件 + if (value === 'api_key') { + setBatch(false); + setUseManualInput(false); + setVertexKeys([]); + setVertexFileList([]); + if (formApiRef.current) { + formApiRef.current.setValue('vertex_files', []); + } + } + }} + extraText={ + inputs.vertex_key_type === 'api_key' + ? t('API Key 模式下不支持批量创建') + : t('JSON 模式支持手动输入或上传服务账号 JSON') + } + /> + )} {batch ? ( - inputs.type === 41 ? ( + inputs.type === 41 && (inputs.vertex_key_type || 'json') === 'json' ? ( { ) ) : ( <> - {inputs.type === 41 ? ( + {inputs.type === 41 && (inputs.vertex_key_type || 'json') === 'json' ? ( <> {!batch && (
From d8410d2f11fdce79376531b1d752552efd17283f Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 12 Sep 2025 19:11:17 +0800 Subject: [PATCH 20/32] feat(payment): add payment settings configuration and update payment methods handling --- controller/channel-billing.go | 4 +- controller/channel.go | 7 +- controller/misc.go | 12 +- controller/topup.go | 63 +++++- controller/topup_stripe.go | 11 +- dto/channel_settings.go | 10 +- model/channel.go | 3 +- model/option.go | 24 +-- relay/channel/vertex/adaptor.go | 116 ++++++----- router/api-router.go | 1 + service/epay.go | 5 +- setting/operation_setting/payment_setting.go | 23 +++ .../payment_setting_old.go} | 21 +- .../components/settings/PaymentSetting.jsx | 26 +++ .../channels/modals/EditChannelModal.jsx | 133 ++++++++----- web/src/components/topup/RechargeCard.jsx | 182 +++++++++++------- web/src/components/topup/index.jsx | 173 +++++++++++------ .../topup/modals/PaymentConfirmModal.jsx | 39 +++- web/src/helpers/data.js | 1 - .../Payment/SettingsPaymentGateway.jsx | 76 ++++++++ 20 files changed, 655 insertions(+), 275 deletions(-) create mode 100644 setting/operation_setting/payment_setting.go rename setting/{payment.go => operation_setting/payment_setting_old.go} (57%) diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 18acf2319..1082b9e73 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -10,7 +10,7 @@ import ( "one-api/constant" "one-api/model" "one-api/service" - "one-api/setting" + "one-api/setting/operation_setting" "one-api/types" "strconv" "time" @@ -342,7 +342,7 @@ func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) { return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode) } availableBalanceCny := response.Data.AvailableBalance - availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64() + availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(operation_setting.Price)).InexactFloat64() channel.UpdateBalance(availableBalanceUsd) return availableBalanceUsd, nil } diff --git a/controller/channel.go b/controller/channel.go index 70be91d42..403eb04cc 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/constant" + "one-api/dto" "one-api/model" "strconv" "strings" @@ -560,7 +561,7 @@ func AddChannel(c *gin.Context) { case "multi_to_single": addChannelRequest.Channel.ChannelInfo.IsMultiKey = true addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode - if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { + if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey { array, err := getVertexArrayKeys(addChannelRequest.Channel.Key) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -585,7 +586,7 @@ func AddChannel(c *gin.Context) { } keys = []string{addChannelRequest.Channel.Key} case "batch": - if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { + if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey { // multi json keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key) if err != nil { @@ -840,7 +841,7 @@ func UpdateChannel(c *gin.Context) { } // 处理 Vertex AI 的特殊情况 - if channel.Type == constant.ChannelTypeVertexAi { + if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey { // 尝试解析新密钥为JSON数组 if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") { array, err := getVertexArrayKeys(channel.Key) diff --git a/controller/misc.go b/controller/misc.go index 897dad254..085829302 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -59,10 +59,6 @@ func GetStatus(c *gin.Context) { "wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_login": common.WeChatAuthEnabled, "server_address": setting.ServerAddress, - "price": setting.Price, - "stripe_unit_price": setting.StripeUnitPrice, - "min_topup": setting.MinTopUp, - "stripe_min_topup": setting.StripeMinTopUp, "turnstile_check": common.TurnstileCheckEnabled, "turnstile_site_key": common.TurnstileSiteKey, "top_up_link": common.TopUpLink, @@ -75,15 +71,15 @@ func GetStatus(c *gin.Context) { "enable_data_export": common.DataExportEnabled, "data_export_default_time": common.DataExportDefaultTime, "default_collapse_sidebar": common.DefaultCollapseSidebar, - "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "", - "enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "", "mj_notify_enabled": setting.MjNotifyEnabled, "chats": setting.Chats, "demo_site_enabled": operation_setting.DemoSiteEnabled, "self_use_mode_enabled": operation_setting.SelfUseModeEnabled, "default_use_auto_group": setting.DefaultUseAutoGroup, - "pay_methods": setting.PayMethods, - "usd_exchange_rate": setting.USDExchangeRate, + + "usd_exchange_rate": operation_setting.USDExchangeRate, + "price": operation_setting.Price, + "stripe_unit_price": setting.StripeUnitPrice, // 面板启用开关 "api_info_enabled": cs.ApiInfoEnabled, diff --git a/controller/topup.go b/controller/topup.go index 3f3c86231..93f3e58e0 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -9,6 +9,7 @@ import ( "one-api/model" "one-api/service" "one-api/setting" + "one-api/setting/operation_setting" "strconv" "sync" "time" @@ -19,6 +20,44 @@ import ( "github.com/shopspring/decimal" ) +func GetTopUpInfo(c *gin.Context) { + // 获取支付方式 + payMethods := operation_setting.PayMethods + + // 如果启用了 Stripe 支付,添加到支付方法列表 + if setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "" { + // 检查是否已经包含 Stripe + hasStripe := false + for _, method := range payMethods { + if method["type"] == "stripe" { + hasStripe = true + break + } + } + + if !hasStripe { + stripeMethod := map[string]string{ + "name": "Stripe", + "type": "stripe", + "color": "rgba(var(--semi-purple-5), 1)", + "min_topup": strconv.Itoa(setting.StripeMinTopUp), + } + payMethods = append(payMethods, stripeMethod) + } + } + + data := gin.H{ + "enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "", + "enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "", + "pay_methods": payMethods, + "min_topup": operation_setting.MinTopUp, + "stripe_min_topup": setting.StripeMinTopUp, + "amount_options": operation_setting.GetPaymentSetting().AmountOptions, + "discount": operation_setting.GetPaymentSetting().AmountDiscount, + } + common.ApiSuccess(c, data) +} + type EpayRequest struct { Amount int64 `json:"amount"` PaymentMethod string `json:"payment_method"` @@ -31,13 +70,13 @@ type AmountRequest struct { } func GetEpayClient() *epay.Client { - if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" { + if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" { return nil } withUrl, err := epay.NewClient(&epay.Config{ - PartnerID: setting.EpayId, - Key: setting.EpayKey, - }, setting.PayAddress) + PartnerID: operation_setting.EpayId, + Key: operation_setting.EpayKey, + }, operation_setting.PayAddress) if err != nil { return nil } @@ -58,15 +97,23 @@ func getPayMoney(amount int64, group string) float64 { } dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio) - dPrice := decimal.NewFromFloat(setting.Price) + dPrice := decimal.NewFromFloat(operation_setting.Price) + // apply optional preset discount by the original request amount (if configured), default 1.0 + discount := 1.0 + if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(amount)]; ok { + if ds > 0 { + discount = ds + } + } + dDiscount := decimal.NewFromFloat(discount) - payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio) + payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio).Mul(dDiscount) return payMoney.InexactFloat64() } func getMinTopup() int64 { - minTopup := setting.MinTopUp + minTopup := operation_setting.MinTopUp if !common.DisplayInCurrencyEnabled { dMinTopup := decimal.NewFromInt(int64(minTopup)) dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) @@ -99,7 +146,7 @@ func RequestEpay(c *gin.Context) { return } - if !setting.ContainsPayMethod(req.PaymentMethod) { + if !operation_setting.ContainsPayMethod(req.PaymentMethod) { c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"}) return } diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go index eb3208092..bf0d7bf36 100644 --- a/controller/topup_stripe.go +++ b/controller/topup_stripe.go @@ -8,6 +8,7 @@ import ( "one-api/common" "one-api/model" "one-api/setting" + "one-api/setting/operation_setting" "strconv" "strings" "time" @@ -254,6 +255,7 @@ func GetChargedAmount(count float64, user model.User) float64 { } func getStripePayMoney(amount float64, group string) float64 { + originalAmount := amount if !common.DisplayInCurrencyEnabled { amount = amount / common.QuotaPerUnit } @@ -262,7 +264,14 @@ func getStripePayMoney(amount float64, group string) float64 { if topupGroupRatio == 0 { topupGroupRatio = 1 } - payMoney := amount * setting.StripeUnitPrice * topupGroupRatio + // apply optional preset discount by the original request amount (if configured), default 1.0 + discount := 1.0 + if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(originalAmount)]; ok { + if ds > 0 { + discount = ds + } + } + payMoney := amount * setting.StripeUnitPrice * topupGroupRatio * discount return payMoney } diff --git a/dto/channel_settings.go b/dto/channel_settings.go index 2c58795cb..8791f516e 100644 --- a/dto/channel_settings.go +++ b/dto/channel_settings.go @@ -9,6 +9,14 @@ type ChannelSettings struct { SystemPromptOverride bool `json:"system_prompt_override,omitempty"` } +type VertexKeyType string + +const ( + VertexKeyTypeJSON VertexKeyType = "json" + VertexKeyTypeAPIKey VertexKeyType = "api_key" +) + type ChannelOtherSettings struct { - AzureResponsesVersion string `json:"azure_responses_version,omitempty"` + AzureResponsesVersion string `json:"azure_responses_version,omitempty"` + VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key" } diff --git a/model/channel.go b/model/channel.go index a61b3eccf..534e2f3f2 100644 --- a/model/channel.go +++ b/model/channel.go @@ -42,7 +42,6 @@ type Channel struct { Priority *int64 `json:"priority" gorm:"bigint;default:0"` AutoBan *int `json:"auto_ban" gorm:"default:1"` OtherInfo string `json:"other_info"` - OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置 Tag *string `json:"tag" gorm:"index"` Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置 ParamOverride *string `json:"param_override" gorm:"type:text"` @@ -51,6 +50,8 @@ type Channel struct { // add after v0.8.5 ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"` + OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置,存储azure版本等不需要检索的信息,详见dto.ChannelOtherSettings + // cache info Keys []string `json:"-" gorm:"-"` } diff --git a/model/option.go b/model/option.go index 2121710ce..73fe92ad1 100644 --- a/model/option.go +++ b/model/option.go @@ -73,9 +73,9 @@ func InitOptionMap() { common.OptionMap["CustomCallbackAddress"] = "" common.OptionMap["EpayId"] = "" common.OptionMap["EpayKey"] = "" - common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64) - common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(setting.USDExchangeRate, 'f', -1, 64) - common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp) + common.OptionMap["Price"] = strconv.FormatFloat(operation_setting.Price, 'f', -1, 64) + common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(operation_setting.USDExchangeRate, 'f', -1, 64) + common.OptionMap["MinTopUp"] = strconv.Itoa(operation_setting.MinTopUp) common.OptionMap["StripeMinTopUp"] = strconv.Itoa(setting.StripeMinTopUp) common.OptionMap["StripeApiSecret"] = setting.StripeApiSecret common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret @@ -85,7 +85,7 @@ func InitOptionMap() { common.OptionMap["Chats"] = setting.Chats2JsonString() common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString() common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup) - common.OptionMap["PayMethods"] = setting.PayMethods2JsonString() + common.OptionMap["PayMethods"] = operation_setting.PayMethods2JsonString() common.OptionMap["GitHubClientId"] = "" common.OptionMap["GitHubClientSecret"] = "" common.OptionMap["TelegramBotToken"] = "" @@ -299,23 +299,23 @@ func updateOptionMap(key string, value string) (err error) { case "WorkerValidKey": setting.WorkerValidKey = value case "PayAddress": - setting.PayAddress = value + operation_setting.PayAddress = value case "Chats": err = setting.UpdateChatsByJsonString(value) case "AutoGroups": err = setting.UpdateAutoGroupsByJsonString(value) case "CustomCallbackAddress": - setting.CustomCallbackAddress = value + operation_setting.CustomCallbackAddress = value case "EpayId": - setting.EpayId = value + operation_setting.EpayId = value case "EpayKey": - setting.EpayKey = value + operation_setting.EpayKey = value case "Price": - setting.Price, _ = strconv.ParseFloat(value, 64) + operation_setting.Price, _ = strconv.ParseFloat(value, 64) case "USDExchangeRate": - setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64) + operation_setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64) case "MinTopUp": - setting.MinTopUp, _ = strconv.Atoi(value) + operation_setting.MinTopUp, _ = strconv.Atoi(value) case "StripeApiSecret": setting.StripeApiSecret = value case "StripeWebhookSecret": @@ -413,7 +413,7 @@ func updateOptionMap(key string, value string) (err error) { case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) case "PayMethods": - err = setting.UpdatePayMethodsByJsonString(value) + err = operation_setting.UpdatePayMethodsByJsonString(value) } return err } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 0b6b26743..b6a78b7aa 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "one-api/common" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/claude" @@ -80,16 +81,64 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } } -func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - adc := &Credentials{} - if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { - return "", fmt.Errorf("failed to decode credentials file: %w", err) - } +func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix string) (string, error) { region := GetModelRegion(info.ApiVersion, info.OriginModelName) - a.AccountCredentials = *adc + if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey { + adc := &Credentials{} + if err := common.Unmarshal([]byte(info.ApiKey), adc); err != nil { + return "", fmt.Errorf("failed to decode credentials file: %w", err) + } + a.AccountCredentials = *adc + + if a.RequestMode == RequestModeLlama { + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", + region, + adc.ProjectID, + region, + ), nil + } + + if region == "global" { + return fmt.Sprintf( + "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s", + adc.ProjectID, + modelName, + suffix, + ), nil + } else { + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", + region, + adc.ProjectID, + region, + modelName, + suffix, + ), nil + } + } else { + if region == "global" { + return fmt.Sprintf( + "https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s", + modelName, + suffix, + info.ApiKey, + ), nil + } else { + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s", + region, + modelName, + suffix, + info.ApiKey, + ), nil + } + } +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { suffix := "" if a.RequestMode == RequestModeGemini { - if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { // 新增逻辑:处理 -thinking- 格式 if strings.Contains(info.UpstreamModelName, "-thinking-") { @@ -112,23 +161,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { suffix = "predict" } - if region == "global" { - return fmt.Sprintf( - "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s", - adc.ProjectID, - info.UpstreamModelName, - suffix, - ), nil - } else { - return fmt.Sprintf( - "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", - region, - adc.ProjectID, - region, - info.UpstreamModelName, - suffix, - ), nil - } + return a.getRequestUrl(info, info.UpstreamModelName, suffix) } else if a.RequestMode == RequestModeClaude { if info.IsStream { suffix = "streamRawPredict?alt=sse" @@ -139,41 +172,22 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if v, ok := claudeModelMap[info.UpstreamModelName]; ok { model = v } - if region == "global" { - return fmt.Sprintf( - "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s", - adc.ProjectID, - model, - suffix, - ), nil - } else { - return fmt.Sprintf( - "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", - region, - adc.ProjectID, - region, - model, - suffix, - ), nil - } + return a.getRequestUrl(info, model, suffix) } else if a.RequestMode == RequestModeLlama { - return fmt.Sprintf( - "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", - region, - adc.ProjectID, - region, - ), nil + return a.getRequestUrl(info, "", "") } return "", errors.New("unsupported request mode") } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - accessToken, err := getAccessToken(a, info) - if err != nil { - return err + if info.ChannelOtherSettings.VertexKeyType == "json" { + accessToken, err := getAccessToken(a, info) + if err != nil { + return err + } + req.Set("Authorization", "Bearer "+accessToken) } - req.Set("Authorization", "Bearer "+accessToken) return nil } diff --git a/router/api-router.go b/router/api-router.go index 773857385..e16d06628 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -60,6 +60,7 @@ func SetApiRouter(router *gin.Engine) { selfRoute.DELETE("/self", controller.DeleteSelf) selfRoute.GET("/token", controller.GenerateAccessToken) selfRoute.GET("/aff", controller.GetAffCode) + selfRoute.GET("/topup/info", controller.GetTopUpInfo) selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp) selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestEpay) selfRoute.POST("/amount", controller.RequestAmount) diff --git a/service/epay.go b/service/epay.go index a8259d21d..a1ff484e6 100644 --- a/service/epay.go +++ b/service/epay.go @@ -2,11 +2,12 @@ package service import ( "one-api/setting" + "one-api/setting/operation_setting" ) func GetCallbackAddress() string { - if setting.CustomCallbackAddress == "" { + if operation_setting.CustomCallbackAddress == "" { return setting.ServerAddress } - return setting.CustomCallbackAddress + return operation_setting.CustomCallbackAddress } diff --git a/setting/operation_setting/payment_setting.go b/setting/operation_setting/payment_setting.go new file mode 100644 index 000000000..c8df039cf --- /dev/null +++ b/setting/operation_setting/payment_setting.go @@ -0,0 +1,23 @@ +package operation_setting + +import "one-api/setting/config" + +type PaymentSetting struct { + AmountOptions []int `json:"amount_options"` + AmountDiscount map[int]float64 `json:"amount_discount"` // 充值金额对应的折扣,例如 100 元 0.9 表示 100 元充值享受 9 折优惠 +} + +// 默认配置 +var paymentSetting = PaymentSetting{ + AmountOptions: []int{10, 20, 50, 100, 200, 500}, + AmountDiscount: map[int]float64{}, +} + +func init() { + // 注册到全局配置管理器 + config.GlobalConfig.Register("payment_setting", &paymentSetting) +} + +func GetPaymentSetting() *PaymentSetting { + return &paymentSetting +} diff --git a/setting/payment.go b/setting/operation_setting/payment_setting_old.go similarity index 57% rename from setting/payment.go rename to setting/operation_setting/payment_setting_old.go index 7fc5ad3fd..a6313179e 100644 --- a/setting/payment.go +++ b/setting/operation_setting/payment_setting_old.go @@ -1,6 +1,13 @@ -package setting +/** +此文件为旧版支付设置文件,如需增加新的参数、变量等,请在 payment_setting.go 中添加 +This file is the old version of the payment settings file. If you need to add new parameters, variables, etc., please add them in payment_setting.go +*/ -import "encoding/json" +package operation_setting + +import ( + "one-api/common" +) var PayAddress = "" var CustomCallbackAddress = "" @@ -21,15 +28,21 @@ var PayMethods = []map[string]string{ "color": "rgba(var(--semi-green-5), 1)", "type": "wxpay", }, + { + "name": "自定义1", + "color": "black", + "type": "custom1", + "min_topup": "50", + }, } func UpdatePayMethodsByJsonString(jsonString string) error { PayMethods = make([]map[string]string, 0) - return json.Unmarshal([]byte(jsonString), &PayMethods) + return common.Unmarshal([]byte(jsonString), &PayMethods) } func PayMethods2JsonString() string { - jsonBytes, err := json.Marshal(PayMethods) + jsonBytes, err := common.Marshal(PayMethods) if err != nil { return "[]" } diff --git a/web/src/components/settings/PaymentSetting.jsx b/web/src/components/settings/PaymentSetting.jsx index a632760aa..faaa9561b 100644 --- a/web/src/components/settings/PaymentSetting.jsx +++ b/web/src/components/settings/PaymentSetting.jsx @@ -37,6 +37,8 @@ const PaymentSetting = () => { TopupGroupRatio: '', CustomCallbackAddress: '', PayMethods: '', + AmountOptions: '', + AmountDiscount: '', StripeApiSecret: '', StripeWebhookSecret: '', @@ -66,6 +68,30 @@ const PaymentSetting = () => { newInputs[item.key] = item.value; } break; + case 'payment_setting.amount_options': + try { + newInputs['AmountOptions'] = JSON.stringify( + JSON.parse(item.value), + null, + 2, + ); + } catch (error) { + console.error('解析AmountOptions出错:', error); + newInputs['AmountOptions'] = item.value; + } + break; + case 'payment_setting.amount_discount': + try { + newInputs['AmountDiscount'] = JSON.stringify( + JSON.parse(item.value), + null, + 2, + ); + } catch (error) { + console.error('解析AmountDiscount出错:', error); + newInputs['AmountDiscount'] = item.value; + } + break; case 'Price': case 'MinTopUp': case 'StripeUnitPrice': diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index 7a86fa114..c0a216246 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -142,6 +142,8 @@ const EditChannelModal = (props) => { system_prompt: '', system_prompt_override: false, settings: '', + // 仅 Vertex: 密钥格式(存入 settings.vertex_key_type) + vertex_key_type: 'json', }; const [batch, setBatch] = useState(false); const [multiToSingle, setMultiToSingle] = useState(false); @@ -409,11 +411,17 @@ const EditChannelModal = (props) => { const parsedSettings = JSON.parse(data.settings); data.azure_responses_version = parsedSettings.azure_responses_version || ''; + // 读取 Vertex 密钥格式 + data.vertex_key_type = parsedSettings.vertex_key_type || 'json'; } catch (error) { console.error('解析其他设置失败:', error); data.azure_responses_version = ''; data.region = ''; + data.vertex_key_type = 'json'; } + } else { + // 兼容历史数据:老渠道没有 settings 时,默认按 json 展示 + data.vertex_key_type = 'json'; } setInputs(data); @@ -745,59 +753,56 @@ const EditChannelModal = (props) => { let localInputs = { ...formValues }; if (localInputs.type === 41) { - if (useManualInput) { - // 手动输入模式 - if (localInputs.key && localInputs.key.trim() !== '') { - try { - // 验证 JSON 格式 - const parsedKey = JSON.parse(localInputs.key); - // 确保是有效的密钥格式 - localInputs.key = JSON.stringify(parsedKey); - } catch (err) { - showError(t('密钥格式无效,请输入有效的 JSON 格式密钥')); - return; - } - } else if (!isEdit) { + const keyType = localInputs.vertex_key_type || 'json'; + if (keyType === 'api_key') { + // 直接作为普通字符串密钥处理 + if (!isEdit && (!localInputs.key || localInputs.key.trim() === '')) { showInfo(t('请输入密钥!')); return; } } else { - // 文件上传模式 - let keys = vertexKeys; - - // 若当前未选择文件,尝试从已上传文件列表解析(异步读取) - if (keys.length === 0 && vertexFileList.length > 0) { - try { - const parsed = await Promise.all( - vertexFileList.map(async (item) => { - const fileObj = item.fileInstance; - if (!fileObj) return null; - const txt = await fileObj.text(); - return JSON.parse(txt); - }), - ); - keys = parsed.filter(Boolean); - } catch (err) { - showError(t('解析密钥文件失败: {{msg}}', { msg: err.message })); + // JSON 服务账号密钥 + if (useManualInput) { + if (localInputs.key && localInputs.key.trim() !== '') { + try { + const parsedKey = JSON.parse(localInputs.key); + localInputs.key = JSON.stringify(parsedKey); + } catch (err) { + showError(t('密钥格式无效,请输入有效的 JSON 格式密钥')); + return; + } + } else if (!isEdit) { + showInfo(t('请输入密钥!')); return; } - } - - // 创建模式必须上传密钥;编辑模式可选 - if (keys.length === 0) { - if (!isEdit) { - showInfo(t('请上传密钥文件!')); - return; - } else { - // 编辑模式且未上传新密钥,不修改 key - delete localInputs.key; - } } else { - // 有新密钥,则覆盖 - if (batch) { - localInputs.key = JSON.stringify(keys); + // 文件上传模式 + let keys = vertexKeys; + if (keys.length === 0 && vertexFileList.length > 0) { + try { + const parsed = await Promise.all( + vertexFileList.map(async (item) => { + const fileObj = item.fileInstance; + if (!fileObj) return null; + const txt = await fileObj.text(); + return JSON.parse(txt); + }), + ); + keys = parsed.filter(Boolean); + } catch (err) { + showError(t('解析密钥文件失败: {{msg}}', { msg: err.message })); + return; + } + } + if (keys.length === 0) { + if (!isEdit) { + showInfo(t('请上传密钥文件!')); + return; + } else { + delete localInputs.key; + } } else { - localInputs.key = JSON.stringify(keys[0]); + localInputs.key = batch ? JSON.stringify(keys) : JSON.stringify(keys[0]); } } } @@ -853,6 +858,8 @@ const EditChannelModal = (props) => { delete localInputs.pass_through_body_enabled; delete localInputs.system_prompt; delete localInputs.system_prompt_override; + // 顶层的 vertex_key_type 不应发送给后端 + delete localInputs.vertex_key_type; let res; localInputs.auto_ban = localInputs.auto_ban ? 1 : 0; @@ -1178,8 +1185,40 @@ const EditChannelModal = (props) => { autoComplete='new-password' /> + {inputs.type === 41 && ( + { + // 更新设置中的 vertex_key_type + handleChannelOtherSettingsChange('vertex_key_type', value); + // 切换为 api_key 时,关闭批量与手动/文件切换,并清理已选文件 + if (value === 'api_key') { + setBatch(false); + setUseManualInput(false); + setVertexKeys([]); + setVertexFileList([]); + if (formApiRef.current) { + formApiRef.current.setValue('vertex_files', []); + } + } + }} + extraText={ + inputs.vertex_key_type === 'api_key' + ? t('API Key 模式下不支持批量创建') + : t('JSON 模式支持手动输入或上传服务账号 JSON') + } + /> + )} {batch ? ( - inputs.type === 41 ? ( + inputs.type === 41 && (inputs.vertex_key_type || 'json') === 'json' ? ( { ) ) : ( <> - {inputs.type === 41 ? ( + {inputs.type === 41 && (inputs.vertex_key_type || 'json') === 'json' ? ( <> {!batch && (
diff --git a/web/src/components/topup/RechargeCard.jsx b/web/src/components/topup/RechargeCard.jsx index 7fb06b0ca..f23381f40 100644 --- a/web/src/components/topup/RechargeCard.jsx +++ b/web/src/components/topup/RechargeCard.jsx @@ -21,6 +21,7 @@ import React, { useRef } from 'react'; import { Avatar, Typography, + Tag, Card, Button, Banner, @@ -29,7 +30,7 @@ import { Space, Row, Col, - Spin, + Spin, Tooltip } from '@douyinfe/semi-ui'; import { SiAlipay, SiWechat, SiStripe } from 'react-icons/si'; import { CreditCard, Coins, Wallet, BarChart2, TrendingUp } from 'lucide-react'; @@ -68,6 +69,7 @@ const RechargeCard = ({ userState, renderQuota, statusLoading, + topupInfo, }) => { const onlineFormApiRef = useRef(null); const redeemFormApiRef = useRef(null); @@ -261,44 +263,58 @@ const RechargeCard = ({ - - {payMethods.map((payMethod) => ( - - ))} - + {payMethods && payMethods.length > 0 ? ( + + {payMethods.map((payMethod) => { + const minTopupVal = Number(payMethod.min_topup) || 0; + const isStripe = payMethod.type === 'stripe'; + const disabled = + (!enableOnlineTopUp && !isStripe) || + (!enableStripeTopUp && isStripe) || + minTopupVal > Number(topUpCount || 0); + + const buttonEl = ( + + ); + + return disabled && minTopupVal > Number(topUpCount || 0) ? ( + + {buttonEl} + + ) : ( + {buttonEl} + ); + })} + + ) : ( +
+ {t('暂无可用的支付方式,请联系管理员配置')} +
+ )}
@@ -306,41 +322,59 @@ const RechargeCard = ({ {(enableOnlineTopUp || enableStripeTopUp) && ( - - {presetAmounts.map((preset, index) => ( - - ))} - +
+ {presetAmounts.map((preset, index) => { + const discount = preset.discount || topupInfo?.discount?.[preset.value] || 1.0; + const originalPrice = preset.value * priceRatio; + const discountedPrice = originalPrice * discount; + const hasDiscount = discount < 1.0; + const actualPay = discountedPrice; + const save = originalPrice - discountedPrice; + + return ( + { + selectPresetAmount(preset); + onlineFormApiRef.current?.setValue( + 'topUpCount', + preset.value, + ); + }} + > +
+ + {formatLargeNumber(preset.value)} {t('美元额度')} + {hasDiscount && ( + + {t('折').includes('off') ? + ((1 - discount) * 100).toFixed(1) : + (discount * 10).toFixed(1)}{t('折')} + + )} + +
+ {t('实付')} {actualPay.toFixed(2)}, + {hasDiscount ? `${t('节省')} ${save.toFixed(2)}` : `${t('节省')} 0.00`} +
+
+
+ ); + })} +
)}
diff --git a/web/src/components/topup/index.jsx b/web/src/components/topup/index.jsx index a09244488..929a47e39 100644 --- a/web/src/components/topup/index.jsx +++ b/web/src/components/topup/index.jsx @@ -80,6 +80,12 @@ const TopUp = () => { // 预设充值额度选项 const [presetAmounts, setPresetAmounts] = useState([]); const [selectedPreset, setSelectedPreset] = useState(null); + + // 充值配置信息 + const [topupInfo, setTopupInfo] = useState({ + amount_options: [], + discount: {} + }); const topUp = async () => { if (redemptionCode === '') { @@ -248,6 +254,99 @@ const TopUp = () => { } }; + // 获取充值配置信息 + const getTopupInfo = async () => { + try { + const res = await API.get('/api/user/topup/info'); + const { message, data, success } = res.data; + if (success) { + setTopupInfo({ + amount_options: data.amount_options || [], + discount: data.discount || {} + }); + + // 处理支付方式 + let payMethods = data.pay_methods || []; + try { + if (typeof payMethods === 'string') { + payMethods = JSON.parse(payMethods); + } + if (payMethods && payMethods.length > 0) { + // 检查name和type是否为空 + payMethods = payMethods.filter((method) => { + return method.name && method.type; + }); + // 如果没有color,则设置默认颜色 + payMethods = payMethods.map((method) => { + // 规范化最小充值数 + const normalizedMinTopup = Number(method.min_topup); + method.min_topup = Number.isFinite(normalizedMinTopup) ? normalizedMinTopup : 0; + + // Stripe 的最小充值从后端字段回填 + if (method.type === 'stripe' && (!method.min_topup || method.min_topup <= 0)) { + const stripeMin = Number(data.stripe_min_topup); + if (Number.isFinite(stripeMin)) { + method.min_topup = stripeMin; + } + } + + if (!method.color) { + if (method.type === 'alipay') { + method.color = 'rgba(var(--semi-blue-5), 1)'; + } else if (method.type === 'wxpay') { + method.color = 'rgba(var(--semi-green-5), 1)'; + } else if (method.type === 'stripe') { + method.color = 'rgba(var(--semi-purple-5), 1)'; + } else { + method.color = 'rgba(var(--semi-primary-5), 1)'; + } + } + return method; + }); + } else { + payMethods = []; + } + + // 如果启用了 Stripe 支付,添加到支付方法列表 + // 这个逻辑现在由后端处理,如果 Stripe 启用,后端会在 pay_methods 中包含它 + + setPayMethods(payMethods); + const enableStripeTopUp = data.enable_stripe_topup || false; + const enableOnlineTopUp = data.enable_online_topup || false; + const minTopUpValue = enableOnlineTopUp? data.min_topup : enableStripeTopUp? data.stripe_min_topup : 1; + setEnableOnlineTopUp(enableOnlineTopUp); + setEnableStripeTopUp(enableStripeTopUp); + setMinTopUp(minTopUpValue); + setTopUpCount(minTopUpValue); + + // 如果没有自定义充值数量选项,根据最小充值金额生成预设充值额度选项 + if (topupInfo.amount_options.length === 0) { + setPresetAmounts(generatePresetAmounts(minTopUpValue)); + } + + // 初始化显示实付金额 + getAmount(minTopUpValue); + } catch (e) { + console.log('解析支付方式失败:', e); + setPayMethods([]); + } + + // 如果有自定义充值数量选项,使用它们替换默认的预设选项 + if (data.amount_options && data.amount_options.length > 0) { + const customPresets = data.amount_options.map(amount => ({ + value: amount, + discount: data.discount[amount] || 1.0 + })); + setPresetAmounts(customPresets); + } + } else { + console.error('获取充值配置失败:', data); + } + } catch (error) { + console.error('获取充值配置异常:', error); + } + }; + // 获取邀请链接 const getAffLink = async () => { const res = await API.get('/api/user/aff'); @@ -290,52 +389,7 @@ const TopUp = () => { getUserQuota().then(); } setTransferAmount(getQuotaPerUnit()); - - let payMethods = localStorage.getItem('pay_methods'); - try { - payMethods = JSON.parse(payMethods); - if (payMethods && payMethods.length > 0) { - // 检查name和type是否为空 - payMethods = payMethods.filter((method) => { - return method.name && method.type; - }); - // 如果没有color,则设置默认颜色 - payMethods = payMethods.map((method) => { - if (!method.color) { - if (method.type === 'alipay') { - method.color = 'rgba(var(--semi-blue-5), 1)'; - } else if (method.type === 'wxpay') { - method.color = 'rgba(var(--semi-green-5), 1)'; - } else if (method.type === 'stripe') { - method.color = 'rgba(var(--semi-purple-5), 1)'; - } else { - method.color = 'rgba(var(--semi-primary-5), 1)'; - } - } - return method; - }); - } else { - payMethods = []; - } - - // 如果启用了 Stripe 支付,添加到支付方法列表 - if (statusState?.status?.enable_stripe_topup) { - const hasStripe = payMethods.some((method) => method.type === 'stripe'); - if (!hasStripe) { - payMethods.push({ - name: 'Stripe', - type: 'stripe', - color: 'rgba(var(--semi-purple-5), 1)', - }); - } - } - - setPayMethods(payMethods); - } catch (e) { - console.log(e); - showError(t('支付方式配置错误, 请联系管理员')); - } - }, [statusState?.status?.enable_stripe_topup]); + }, []); useEffect(() => { if (affFetchedRef.current) return; @@ -343,20 +397,18 @@ const TopUp = () => { getAffLink().then(); }, []); + // 在 statusState 可用时获取充值信息 + useEffect(() => { + getTopupInfo().then(); + }, []); + useEffect(() => { if (statusState?.status) { - const minTopUpValue = statusState.status.min_topup || 1; - setMinTopUp(minTopUpValue); - setTopUpCount(minTopUpValue); + // const minTopUpValue = statusState.status.min_topup || 1; + // setMinTopUp(minTopUpValue); + // setTopUpCount(minTopUpValue); setTopUpLink(statusState.status.top_up_link || ''); - setEnableOnlineTopUp(statusState.status.enable_online_topup || false); setPriceRatio(statusState.status.price || 1); - setEnableStripeTopUp(statusState.status.enable_stripe_topup || false); - - // 根据最小充值金额生成预设充值额度选项 - setPresetAmounts(generatePresetAmounts(minTopUpValue)); - // 初始化显示实付金额 - getAmount(minTopUpValue); setStatusLoading(false); } @@ -431,7 +483,11 @@ const TopUp = () => { const selectPresetAmount = (preset) => { setTopUpCount(preset.value); setSelectedPreset(preset.value); - setAmount(preset.value * priceRatio); + + // 计算实际支付金额,考虑折扣 + const discount = preset.discount || topupInfo.discount[preset.value] || 1.0; + const discountedAmount = preset.value * priceRatio * discount; + setAmount(discountedAmount); }; // 格式化大数字显示 @@ -475,6 +531,8 @@ const TopUp = () => { renderAmount={renderAmount} payWay={payWay} payMethods={payMethods} + amountNumber={amount} + discountRate={topupInfo?.discount?.[topUpCount] || 1.0} /> {/* 用户信息头部 */} @@ -512,6 +570,7 @@ const TopUp = () => { userState={userState} renderQuota={renderQuota} statusLoading={statusLoading} + topupInfo={topupInfo} />
diff --git a/web/src/components/topup/modals/PaymentConfirmModal.jsx b/web/src/components/topup/modals/PaymentConfirmModal.jsx index 76ea5eb22..1bffbfed1 100644 --- a/web/src/components/topup/modals/PaymentConfirmModal.jsx +++ b/web/src/components/topup/modals/PaymentConfirmModal.jsx @@ -36,7 +36,13 @@ const PaymentConfirmModal = ({ renderAmount, payWay, payMethods, + // 新增:用于显示折扣明细 + amountNumber, + discountRate, }) => { + const hasDiscount = discountRate && discountRate > 0 && discountRate < 1 && amountNumber > 0; + const originalAmount = hasDiscount ? (amountNumber / discountRate) : 0; + const discountAmount = hasDiscount ? (originalAmount - amountNumber) : 0; return ( ) : ( - - {renderAmount()} - +
+ + {renderAmount()} + + {hasDiscount && ( + + {Math.round(discountRate * 100)}% + + )} +
)} + {hasDiscount && !amountLoading && ( + <> +
+ + {t('原价')}: + + + {`${originalAmount.toFixed(2)} ${t('元')}`} + +
+
+ + {t('优惠')}: + + + {`- ${discountAmount.toFixed(2)} ${t('元')}`} + +
+ + )}
{t('支付方式')}: diff --git a/web/src/helpers/data.js b/web/src/helpers/data.js index 62353327c..b894a953c 100644 --- a/web/src/helpers/data.js +++ b/web/src/helpers/data.js @@ -28,7 +28,6 @@ export function setStatusData(data) { localStorage.setItem('enable_task', data.enable_task); localStorage.setItem('enable_data_export', data.enable_data_export); localStorage.setItem('chats', JSON.stringify(data.chats)); - localStorage.setItem('pay_methods', JSON.stringify(data.pay_methods)); localStorage.setItem( 'data_export_default_time', data.data_export_default_time, diff --git a/web/src/pages/Setting/Payment/SettingsPaymentGateway.jsx b/web/src/pages/Setting/Payment/SettingsPaymentGateway.jsx index ce8958dca..d681b6a27 100644 --- a/web/src/pages/Setting/Payment/SettingsPaymentGateway.jsx +++ b/web/src/pages/Setting/Payment/SettingsPaymentGateway.jsx @@ -41,6 +41,8 @@ export default function SettingsPaymentGateway(props) { TopupGroupRatio: '', CustomCallbackAddress: '', PayMethods: '', + AmountOptions: '', + AmountDiscount: '', }); const [originInputs, setOriginInputs] = useState({}); const formApiRef = useRef(null); @@ -62,7 +64,30 @@ export default function SettingsPaymentGateway(props) { TopupGroupRatio: props.options.TopupGroupRatio || '', CustomCallbackAddress: props.options.CustomCallbackAddress || '', PayMethods: props.options.PayMethods || '', + AmountOptions: props.options.AmountOptions || '', + AmountDiscount: props.options.AmountDiscount || '', }; + + // 美化 JSON 展示 + try { + if (currentInputs.AmountOptions) { + currentInputs.AmountOptions = JSON.stringify( + JSON.parse(currentInputs.AmountOptions), + null, + 2, + ); + } + } catch {} + try { + if (currentInputs.AmountDiscount) { + currentInputs.AmountDiscount = JSON.stringify( + JSON.parse(currentInputs.AmountDiscount), + null, + 2, + ); + } + } catch {} + setInputs(currentInputs); setOriginInputs({ ...currentInputs }); formApiRef.current.setValues(currentInputs); @@ -93,6 +118,20 @@ export default function SettingsPaymentGateway(props) { } } + if (originInputs['AmountOptions'] !== inputs.AmountOptions && inputs.AmountOptions.trim() !== '') { + if (!verifyJSON(inputs.AmountOptions)) { + showError(t('自定义充值数量选项不是合法的 JSON 数组')); + return; + } + } + + if (originInputs['AmountDiscount'] !== inputs.AmountDiscount && inputs.AmountDiscount.trim() !== '') { + if (!verifyJSON(inputs.AmountDiscount)) { + showError(t('充值金额折扣配置不是合法的 JSON 对象')); + return; + } + } + setLoading(true); try { const options = [ @@ -123,6 +162,12 @@ export default function SettingsPaymentGateway(props) { if (originInputs['PayMethods'] !== inputs.PayMethods) { options.push({ key: 'PayMethods', value: inputs.PayMethods }); } + if (originInputs['AmountOptions'] !== inputs.AmountOptions) { + options.push({ key: 'payment_setting.amount_options', value: inputs.AmountOptions }); + } + if (originInputs['AmountDiscount'] !== inputs.AmountDiscount) { + options.push({ key: 'payment_setting.amount_discount', value: inputs.AmountDiscount }); + } // 发送请求 const requestQueue = options.map((opt) => @@ -228,6 +273,37 @@ export default function SettingsPaymentGateway(props) { placeholder={t('为一个 JSON 文本')} autosize /> + + + + + + + + + + + + + From 1bffe3081dde8b6c9c35d4bced59bb23f3b7d396 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 12 Sep 2025 21:14:10 +0800 Subject: [PATCH 21/32] =?UTF-8?q?feat(settings):=20=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E5=8D=95=E4=BD=8D=E7=BE=8E=E5=85=83=E9=A2=9D=E5=BA=A6=E8=AE=BE?= =?UTF-8?q?=E7=BD=AE=E9=A1=B9=EF=BC=8C=E4=B8=BA=E5=90=8E=E7=BB=AD=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E4=BD=9C=E5=87=86=E5=A4=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Setting/Operation/SettingsGeneral.jsx | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/web/src/pages/Setting/Operation/SettingsGeneral.jsx b/web/src/pages/Setting/Operation/SettingsGeneral.jsx index 37b3dd984..5af750ec3 100644 --- a/web/src/pages/Setting/Operation/SettingsGeneral.jsx +++ b/web/src/pages/Setting/Operation/SettingsGeneral.jsx @@ -130,17 +130,19 @@ export default function GeneralSettings(props) { showClear /> - - setShowQuotaWarning(true)} - /> - + {inputs.QuotaPerUnit !== '500000' && inputs.QuotaPerUnit !== 500000 && ( + + setShowQuotaWarning(true)} + /> + + )} Date: Fri, 12 Sep 2025 21:53:21 +0800 Subject: [PATCH 22/32] feat(i18n): update TOTP verification message with configuration details --- web/src/components/common/modals/TwoFactorAuthModal.jsx | 2 +- web/src/i18n/locales/en.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/web/src/components/common/modals/TwoFactorAuthModal.jsx b/web/src/components/common/modals/TwoFactorAuthModal.jsx index b0fc28e2a..2a9a8b25b 100644 --- a/web/src/components/common/modals/TwoFactorAuthModal.jsx +++ b/web/src/components/common/modals/TwoFactorAuthModal.jsx @@ -135,7 +135,7 @@ const TwoFactorAuthModal = ({ autoFocus /> - {t('支持6位TOTP验证码或8位备用码')} + {t('支持6位TOTP验证码或8位备用码,可到`个人设置-安全设置-两步验证设置`配置或查看。')}
diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index f47839f2e..73dfbebe7 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1993,7 +1993,7 @@ "安全验证": "Security verification", "验证": "Verify", "为了保护账户安全,请验证您的两步验证码。": "To protect account security, please verify your two-factor authentication code.", - "支持6位TOTP验证码或8位备用码": "Supports 6-digit TOTP verification code or 8-digit backup code", + "支持6位TOTP验证码或8位备用码,可到`个人设置-安全设置-两步验证设置`配置或查看。": "Supports 6-digit TOTP verification code or 8-digit backup code, can be configured or viewed in `Personal Settings - Security Settings - Two-Factor Authentication Settings`.", "获取密钥失败": "Failed to get key", "查看密钥": "View key", "查看渠道密钥": "View channel key", From 6ed775be8f55787f0af6eb98b60634c73be2d94d Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Fri, 12 Sep 2025 21:52:32 +0800 Subject: [PATCH 23/32] refactor: use common taskSubmitReq --- relay/channel/task/jimeng/adaptor.go | 18 +------ relay/channel/task/kling/adaptor.go | 34 ++----------- relay/channel/task/vidu/adaptor.go | 33 ++---------- relay/common/relay_info.go | 8 +++ relay/common/relay_utils.go | 75 ++++++++++++++++++++++++++++ 5 files changed, 92 insertions(+), 76 deletions(-) diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 955e592a2..f838bdb16 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -18,7 +18,6 @@ import ( "github.com/gin-gonic/gin" "github.com/pkg/errors" - "one-api/common" "one-api/constant" "one-api/dto" "one-api/relay/channel" @@ -89,22 +88,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { // ValidateRequestAndSetAction parses body, validates fields and sets default action. func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { // Accept only POST /v1/video/generations as "generate" action. - action := constant.TaskActionGenerate - info.Action = action - - req := relaycommon.TaskSubmitReq{} - if err := common.UnmarshalBodyReusable(c, &req); err != nil { - taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) - return - } - if strings.TrimSpace(req.Prompt) == "" { - taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) - return - } - - // Store into context for later usage - c.Set("task_request", req) - return nil + return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) } // BuildRequestURL constructs the upstream URL. diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 3d6da253b..13f2af972 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -16,7 +16,6 @@ import ( "github.com/golang-jwt/jwt" "github.com/pkg/errors" - "one-api/common" "one-api/constant" "one-api/dto" "one-api/relay/channel" @@ -28,16 +27,6 @@ import ( // Request / Response structures // ============================ -type SubmitReq struct { - Prompt string `json:"prompt"` - Model string `json:"model,omitempty"` - Mode string `json:"mode,omitempty"` - Image string `json:"image,omitempty"` - Size string `json:"size,omitempty"` - Duration int `json:"duration,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` -} - type TrajectoryPoint struct { X int `json:"x"` Y int `json:"y"` @@ -121,23 +110,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { // ValidateRequestAndSetAction parses body, validates fields and sets default action. func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { - // Accept only POST /v1/video/generations as "generate" action. - action := constant.TaskActionGenerate - info.Action = action - - var req SubmitReq - if err := common.UnmarshalBodyReusable(c, &req); err != nil { - taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) - return - } - if strings.TrimSpace(req.Prompt) == "" { - taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) - return - } - - // Store into context for later usage - c.Set("task_request", req) - return nil + // Use the standard validation method for TaskSubmitReq + return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) } // BuildRequestURL constructs the upstream URL. @@ -166,7 +140,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if !exists { return nil, fmt.Errorf("request not found in context") } - req := v.(SubmitReq) + req := v.(relaycommon.TaskSubmitReq) body, err := a.convertToRequestPayload(&req) if err != nil { @@ -255,7 +229,7 @@ func (a *TaskAdaptor) GetChannelName() string { // helpers // ============================ -func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { r := requestPayload{ Prompt: req.Prompt, Image: req.Image, diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index c82c1c0e8..a1140d1e7 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -23,16 +23,6 @@ import ( // Request / Response structures // ============================ -type SubmitReq struct { - Prompt string `json:"prompt"` - Model string `json:"model,omitempty"` - Mode string `json:"mode,omitempty"` - Image string `json:"image,omitempty"` - Size string `json:"size,omitempty"` - Duration int `json:"duration,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` -} - type requestPayload struct { Model string `json:"model"` Images []string `json:"images"` @@ -90,23 +80,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { - var req SubmitReq - if err := c.ShouldBindJSON(&req); err != nil { - return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest) - } - - if req.Prompt == "" { - return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest) - } - - if req.Image != "" { - info.Action = constant.TaskActionGenerate - } else { - info.Action = constant.TaskActionTextGenerate - } - - c.Set("task_request", req) - return nil + // Use the unified validation method for TaskSubmitReq with image-based action determination + return relaycommon.ValidateTaskRequestWithImageBinding(c, info) } func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { @@ -114,7 +89,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) if !exists { return nil, fmt.Errorf("request not found in context") } - req := v.(SubmitReq) + req := v.(relaycommon.TaskSubmitReq) body, err := a.convertToRequestPayload(&req) if err != nil { @@ -211,7 +186,7 @@ func (a *TaskAdaptor) GetChannelName() string { // helpers // ============================ -func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { var images []string if req.Image != "" { images = []string{req.Image} diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index da572c070..eb292de23 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -486,6 +486,14 @@ type TaskSubmitReq struct { Metadata map[string]interface{} `json:"metadata,omitempty"` } +func (t TaskSubmitReq) GetPrompt() string { + return t.Prompt +} + +func (t TaskSubmitReq) GetImage() string { + return t.Image +} + type TaskInfo struct { Code int `json:"code"` TaskID string `json:"task_id"` diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 3d5efcb6d..108395613 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -2,12 +2,23 @@ package common import ( "fmt" + "net/http" + "one-api/common" "one-api/constant" + "one-api/dto" "strings" "github.com/gin-gonic/gin" ) +type HasPrompt interface { + GetPrompt() string +} + +type HasImage interface { + GetImage() string +} + func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) @@ -30,3 +41,67 @@ func GetAPIVersion(c *gin.Context) string { } return apiVersion } + +func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError { + return &dto.TaskError{ + Code: code, + Message: err.Error(), + StatusCode: statusCode, + LocalError: localError, + Error: err, + } +} + +func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) { + info.Action = action + c.Set("task_request", requestObj) +} + +func validatePrompt(prompt string) *dto.TaskError { + if strings.TrimSpace(prompt) == "" { + return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true) + } + return nil +} + +func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError { + var req TaskSubmitReq + if err := common.UnmarshalBodyReusable(c, &req); err != nil { + return createTaskError(err, "invalid_request", http.StatusBadRequest, true) + } + + if taskErr := validatePrompt(req.Prompt); taskErr != nil { + return taskErr + } + + storeTaskRequest(c, info, action, req) + return nil +} + +func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError { + hasPrompt, ok := requestObj.(HasPrompt) + if !ok { + return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true) + } + + if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil { + return taskErr + } + + action := constant.TaskActionTextGenerate + if hasImage, ok := requestObj.(HasImage); ok && strings.TrimSpace(hasImage.GetImage()) != "" { + action = constant.TaskActionGenerate + } + + storeTaskRequest(c, info, action, requestObj) + return nil +} + +func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError { + var req TaskSubmitReq + if err := c.ShouldBindJSON(&req); err != nil { + return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false) + } + + return ValidateTaskRequestWithImage(c, info, req) +} From f14b06ec3a88023b3f4ef17f90e6e815bd4a75d2 Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Fri, 12 Sep 2025 22:19:45 +0800 Subject: [PATCH 24/32] feat: jimeng video add images --- relay/channel/task/jimeng/adaptor.go | 8 ++++---- relay/common/relay_info.go | 5 +++-- relay/common/relay_utils.go | 9 +++++++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index f838bdb16..2bc45c547 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -318,11 +318,11 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* } // Handle one-of image_urls or binary_data_base64 - if req.Image != "" { - if strings.HasPrefix(req.Image, "http") { - r.ImageUrls = []string{req.Image} + if req.HasImage() { + if strings.HasPrefix(req.Images[0], "http") { + r.ImageUrls = req.Images } else { - r.BinaryDataBase64 = []string{req.Image} + r.BinaryDataBase64 = req.Images } } metadata := req.Metadata diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index eb292de23..99925dc5d 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -481,6 +481,7 @@ type TaskSubmitReq struct { Model string `json:"model,omitempty"` Mode string `json:"mode,omitempty"` Image string `json:"image,omitempty"` + Images []string `json:"images,omitempty"` Size string `json:"size,omitempty"` Duration int `json:"duration,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"` @@ -490,8 +491,8 @@ func (t TaskSubmitReq) GetPrompt() string { return t.Prompt } -func (t TaskSubmitReq) GetImage() string { - return t.Image +func (t TaskSubmitReq) HasImage() bool { + return len(t.Images) > 0 } type TaskInfo struct { diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 108395613..cf6d08dda 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -16,7 +16,7 @@ type HasPrompt interface { } type HasImage interface { - GetImage() string + HasImage() bool } func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { @@ -74,6 +74,11 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d return taskErr } + if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" { + // 兼容单图上传 + req.Images = []string{req.Image} + } + storeTaskRequest(c, info, action, req) return nil } @@ -89,7 +94,7 @@ func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj in } action := constant.TaskActionTextGenerate - if hasImage, ok := requestObj.(HasImage); ok && strings.TrimSpace(hasImage.GetImage()) != "" { + if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() { action = constant.TaskActionGenerate } From 6451158680ac671e65f7691f1197b0f9f51c4637 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 13 Sep 2025 12:53:28 +0800 Subject: [PATCH 25/32] =?UTF-8?q?Revert=20"feat:=20gemini-2.5-flash-image-?= =?UTF-8?q?preview=20=E6=96=87=E6=9C=AC=E5=92=8C=E5=9B=BE=E7=89=87?= =?UTF-8?q?=E8=BE=93=E5=87=BA=E8=AE=A1=E8=B4=B9"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit e732c5842675d2aeeb3faa2af633341fb9d9c1ac. --- dto/gemini.go | 16 ++++---- relay/channel/gemini/relay-gemini-native.go | 36 ------------------ relay/compatible_handler.go | 15 -------- service/token_counter.go | 2 +- setting/model_setting/gemini.go | 1 - setting/operation_setting/tools.go | 11 ------ setting/ratio_setting/model_ratio.go | 10 ++--- web/src/helpers/render.jsx | 38 ++++--------------- web/src/hooks/usage-logs/useUsageLogsData.jsx | 2 - 9 files changed, 20 insertions(+), 111 deletions(-) diff --git a/dto/gemini.go b/dto/gemini.go index cd5d74cdd..5df67ba0b 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -2,12 +2,11 @@ package dto import ( "encoding/json" + "github.com/gin-gonic/gin" "one-api/common" "one-api/logger" "one-api/types" "strings" - - "github.com/gin-gonic/gin" ) type GeminiChatRequest struct { @@ -269,15 +268,14 @@ type GeminiChatResponse struct { } type GeminiUsageMetadata struct { - PromptTokenCount int `json:"promptTokenCount"` - CandidatesTokenCount int `json:"candidatesTokenCount"` - TotalTokenCount int `json:"totalTokenCount"` - ThoughtsTokenCount int `json:"thoughtsTokenCount"` - PromptTokensDetails []GeminiModalityTokenCount `json:"promptTokensDetails"` - CandidatesTokensDetails []GeminiModalityTokenCount `json:"candidatesTokensDetails"` + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + ThoughtsTokenCount int `json:"thoughtsTokenCount"` + PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"` } -type GeminiModalityTokenCount struct { +type GeminiPromptTokensDetails struct { Modality string `json:"modality"` TokenCount int `json:"tokenCount"` } diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 564b86908..974a22f50 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -46,32 +46,6 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount - if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") { - imageOutputCounts := 0 - for _, candidate := range geminiResponse.Candidates { - for _, part := range candidate.Content.Parts { - if part.InlineData != nil && strings.HasPrefix(part.InlineData.MimeType, "image/") { - imageOutputCounts++ - } - } - } - if imageOutputCounts != 0 { - usage.CompletionTokens = usage.CompletionTokens - imageOutputCounts*1290 - usage.TotalTokens = usage.TotalTokens - imageOutputCounts*1290 - c.Set("gemini_image_tokens", imageOutputCounts*1290) - } - } - - // if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") { - // for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails { - // if detail.Modality == "IMAGE" { - // usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount - // usage.TotalTokens = usage.TotalTokens - detail.TokenCount - // c.Set("gemini_image_tokens", detail.TokenCount) - // } - // } - // } - for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { if detail.Modality == "AUDIO" { usage.PromptTokensDetails.AudioTokens = detail.TokenCount @@ -162,16 +136,6 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn usage.PromptTokensDetails.TextTokens = detail.TokenCount } } - - if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") { - for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails { - if detail.Modality == "IMAGE" { - usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount - usage.TotalTokens = usage.TotalTokens - detail.TokenCount - c.Set("gemini_image_tokens", detail.TokenCount) - } - } - } } // 直接发送 GeminiChatResponse 响应 diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index 8f27fd60b..01ab1fff4 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -326,22 +326,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage } else { quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio) } - var dGeminiImageOutputQuota decimal.Decimal - var imageOutputPrice float64 - if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") { - imageOutputPrice = operation_setting.GetGeminiImageOutputPricePerMillionTokens(modelName) - if imageOutputPrice > 0 { - dImageOutputTokens := decimal.NewFromInt(int64(ctx.GetInt("gemini_image_tokens"))) - dGeminiImageOutputQuota = decimal.NewFromFloat(imageOutputPrice).Div(decimal.NewFromInt(1000000)).Mul(dImageOutputTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit) - } - } // 添加 responses tools call 调用的配额 quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota) quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota) // 添加 audio input 独立计费 quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota) - // 添加 Gemini image output 计费 - quotaCalculateDecimal = quotaCalculateDecimal.Add(dGeminiImageOutputQuota) quota := int(quotaCalculateDecimal.Round(0).IntPart()) totalTokens := promptTokens + completionTokens @@ -440,10 +429,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage other["audio_input_token_count"] = audioTokens other["audio_input_price"] = audioInputPrice } - if !dGeminiImageOutputQuota.IsZero() { - other["image_output_token_count"] = ctx.GetInt("gemini_image_tokens") - other["image_output_price"] = imageOutputPrice - } model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: promptTokens, diff --git a/service/token_counter.go b/service/token_counter.go index da56523fe..be5c2e80c 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -336,7 +336,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco for i, file := range meta.Files { switch file.FileType { case types.FileTypeImage: - if info.RelayFormat == types.RelayFormatGemini && !strings.HasPrefix(model, "gemini-2.5-flash-image-preview") { + if info.RelayFormat == types.RelayFormatGemini { tkm += 256 } else { token, err := getImageToken(file, model, info.IsStream) diff --git a/setting/model_setting/gemini.go b/setting/model_setting/gemini.go index 5412155f1..f132fec88 100644 --- a/setting/model_setting/gemini.go +++ b/setting/model_setting/gemini.go @@ -26,7 +26,6 @@ var defaultGeminiSettings = GeminiSettings{ SupportedImagineModels: []string{ "gemini-2.0-flash-exp-image-generation", "gemini-2.0-flash-exp", - "gemini-2.5-flash-image-preview", }, ThinkingAdapterEnabled: false, ThinkingAdapterBudgetTokensPercentage: 0.6, diff --git a/setting/operation_setting/tools.go b/setting/operation_setting/tools.go index b87265ee1..549a1862e 100644 --- a/setting/operation_setting/tools.go +++ b/setting/operation_setting/tools.go @@ -24,10 +24,6 @@ const ( ClaudeWebSearchPrice = 10.00 ) -const ( - Gemini25FlashImagePreviewImageOutputPrice = 30.00 -) - func GetClaudeWebSearchPricePerThousand() float64 { return ClaudeWebSearchPrice } @@ -69,10 +65,3 @@ func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 { } return 0 } - -func GetGeminiImageOutputPricePerMillionTokens(modelName string) float64 { - if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") { - return Gemini25FlashImagePreviewImageOutputPrice - } - return 0 -} diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index 1a1b0afa8..f06cd71ef 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -178,7 +178,6 @@ var defaultModelRatio = map[string]float64{ "gemini-2.5-flash-lite-preview-thinking-*": 0.05, "gemini-2.5-flash-lite-preview-06-17": 0.05, "gemini-2.5-flash": 0.15, - "gemini-2.5-flash-image-preview": 0.15, // $0.30(text/image) / 1M tokens "text-embedding-004": 0.001, "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens @@ -294,11 +293,10 @@ var ( ) var defaultCompletionRatio = map[string]float64{ - "gpt-4-gizmo-*": 2, - "gpt-4o-gizmo-*": 3, - "gpt-4-all": 2, - "gpt-image-1": 8, - "gemini-2.5-flash-image-preview": 8.3333333333, + "gpt-4-gizmo-*": 2, + "gpt-4o-gizmo-*": 3, + "gpt-4-all": 2, + "gpt-image-1": 8, } // InitRatioSettings initializes all model related settings maps diff --git a/web/src/helpers/render.jsx b/web/src/helpers/render.jsx index 3d9d8d710..65332701b 100644 --- a/web/src/helpers/render.jsx +++ b/web/src/helpers/render.jsx @@ -1017,7 +1017,7 @@ export function renderModelPrice( cacheRatio = 1.0, image = false, imageRatio = 1.0, - imageInputTokens = 0, + imageOutputTokens = 0, webSearch = false, webSearchCallCount = 0, webSearchPrice = 0, @@ -1027,8 +1027,6 @@ export function renderModelPrice( audioInputSeperatePrice = false, audioInputTokens = 0, audioInputPrice = 0, - imageOutputTokens = 0, - imageOutputPrice = 0, ) { const { ratio: effectiveGroupRatio, label: ratioLabel } = getEffectiveRatio( groupRatio, @@ -1059,9 +1057,9 @@ export function renderModelPrice( let effectiveInputTokens = inputTokens - cacheTokens + cacheTokens * cacheRatio; // Handle image tokens if present - if (image && imageInputTokens > 0) { + if (image && imageOutputTokens > 0) { effectiveInputTokens = - inputTokens - imageInputTokens + imageInputTokens * imageRatio; + inputTokens - imageOutputTokens + imageOutputTokens * imageRatio; } if (audioInputTokens > 0) { effectiveInputTokens -= audioInputTokens; @@ -1071,8 +1069,7 @@ export function renderModelPrice( (audioInputTokens / 1000000) * audioInputPrice * groupRatio + (completionTokens / 1000000) * completionRatioPrice * groupRatio + (webSearchCallCount / 1000) * webSearchPrice * groupRatio + - (fileSearchCallCount / 1000) * fileSearchPrice * groupRatio + - (imageOutputTokens / 1000000) * imageOutputPrice * groupRatio; + (fileSearchCallCount / 1000) * fileSearchPrice * groupRatio; return ( <> @@ -1107,7 +1104,7 @@ export function renderModelPrice( )}

)} - {image && imageInputTokens > 0 && ( + {image && imageOutputTokens > 0 && (

{i18next.t( '图片输入价格:${{price}} * {{ratio}} = ${{total}} / 1M tokens (图片倍率: {{imageRatio}})', @@ -1134,26 +1131,17 @@ export function renderModelPrice( })}

)} - {imageOutputPrice > 0 && imageOutputTokens > 0 && ( -

- {i18next.t('图片输出价格:${{price}} * 分组倍率{{ratio}} = ${{total}} / 1M tokens', { - price: imageOutputPrice, - ratio: groupRatio, - total: imageOutputPrice * groupRatio, - })} -

- )}

{(() => { // 构建输入部分描述 let inputDesc = ''; - if (image && imageInputTokens > 0) { + if (image && imageOutputTokens > 0) { inputDesc = i18next.t( '(输入 {{nonImageInput}} tokens + 图片输入 {{imageInput}} tokens * {{imageRatio}} / 1M tokens * ${{price}}', { - nonImageInput: inputTokens - imageInputTokens, - imageInput: imageInputTokens, + nonImageInput: inputTokens - imageOutputTokens, + imageInput: imageOutputTokens, imageRatio: imageRatio, price: inputRatioPrice, }, @@ -1223,16 +1211,6 @@ export function renderModelPrice( }, ) : '', - imageOutputPrice > 0 && imageOutputTokens > 0 - ? i18next.t( - ' + 图片输出 {{tokenCounts}} tokens * ${{price}} / 1M tokens * 分组倍率{{ratio}}', - { - tokenCounts: imageOutputTokens, - price: imageOutputPrice, - ratio: groupRatio, - }, - ) - : '', ].join(''); return i18next.t( diff --git a/web/src/hooks/usage-logs/useUsageLogsData.jsx b/web/src/hooks/usage-logs/useUsageLogsData.jsx index 3584f1d9b..81f3f539a 100644 --- a/web/src/hooks/usage-logs/useUsageLogsData.jsx +++ b/web/src/hooks/usage-logs/useUsageLogsData.jsx @@ -447,8 +447,6 @@ export const useLogsData = () => { other?.audio_input_seperate_price || false, other?.audio_input_token_count || 0, other?.audio_input_price || 0, - other?.image_output_token_count || 0, - other?.image_output_price || 0, ); } expandDataLocal.push({ From c1d7ecdeec73ad5eaaad0626ee0262930ce67142 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 13 Sep 2025 12:53:41 +0800 Subject: [PATCH 26/32] fix(adaptor): correct VertexKeyType condition in SetupRequestHeader --- relay/channel/vertex/adaptor.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index b6a78b7aa..7e2fdcad3 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -160,7 +160,6 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if strings.HasPrefix(info.UpstreamModelName, "imagen") { suffix = "predict" } - return a.getRequestUrl(info, info.UpstreamModelName, suffix) } else if a.RequestMode == RequestModeClaude { if info.IsStream { @@ -181,7 +180,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - if info.ChannelOtherSettings.VertexKeyType == "json" { + if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey { accessToken, err := getAccessToken(a, info) if err != nil { return err From 28ed42130c9e6397580be3172a12ebd5dc2da096 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 13 Sep 2025 15:24:48 +0800 Subject: [PATCH 27/32] fix: update references from setting to system_setting for ServerAddress --- controller/midjourney.go | 5 +++-- controller/misc.go | 4 ++-- controller/oidc.go | 3 +-- controller/topup.go | 3 ++- controller/topup_stripe.go | 5 +++-- model/option.go | 15 ++++++++------- relay/mjproxy_handler.go | 3 ++- service/epay.go | 4 ++-- service/quota.go | 4 ++-- .../system_setting_old.go} | 2 +- 10 files changed, 26 insertions(+), 22 deletions(-) rename setting/{system_setting.go => system_setting/system_setting_old.go} (89%) diff --git a/controller/midjourney.go b/controller/midjourney.go index a67d39c23..3a7304419 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -13,6 +13,7 @@ import ( "one-api/model" "one-api/service" "one-api/setting" + "one-api/setting/system_setting" "time" "github.com/gin-gonic/gin" @@ -259,7 +260,7 @@ func GetAllMidjourney(c *gin.Context) { if setting.MjForwardUrlEnabled { for i, midjourney := range items { - midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId + midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId items[i] = midjourney } } @@ -284,7 +285,7 @@ func GetUserMidjourney(c *gin.Context) { if setting.MjForwardUrlEnabled { for i, midjourney := range items { - midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId + midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId items[i] = midjourney } } diff --git a/controller/misc.go b/controller/misc.go index 085829302..875142ffb 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -58,7 +58,7 @@ func GetStatus(c *gin.Context) { "footer_html": common.Footer, "wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_login": common.WeChatAuthEnabled, - "server_address": setting.ServerAddress, + "server_address": system_setting.ServerAddress, "turnstile_check": common.TurnstileCheckEnabled, "turnstile_site_key": common.TurnstileSiteKey, "top_up_link": common.TopUpLink, @@ -249,7 +249,7 @@ func SendPasswordResetEmail(c *gin.Context) { } code := common.GenerateVerificationCode(0) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) - link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code) + link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", system_setting.ServerAddress, email, code) subject := fmt.Sprintf("%s密码重置", common.SystemName) content := fmt.Sprintf("

您好,你正在进行%s密码重置。

"+ "

点击 此处 进行密码重置。

"+ diff --git a/controller/oidc.go b/controller/oidc.go index f3def0e34..8e254d38f 100644 --- a/controller/oidc.go +++ b/controller/oidc.go @@ -8,7 +8,6 @@ import ( "net/url" "one-api/common" "one-api/model" - "one-api/setting" "one-api/setting/system_setting" "strconv" "strings" @@ -45,7 +44,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret) values.Set("code", code) values.Set("grant_type", "authorization_code") - values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress)) + values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress)) formData := values.Encode() req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData)) if err != nil { diff --git a/controller/topup.go b/controller/topup.go index 93f3e58e0..243e67940 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -10,6 +10,7 @@ import ( "one-api/service" "one-api/setting" "one-api/setting/operation_setting" + "one-api/setting/system_setting" "strconv" "sync" "time" @@ -152,7 +153,7 @@ func RequestEpay(c *gin.Context) { } callBackAddress := service.GetCallbackAddress() - returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log") + returnUrl, _ := url.Parse(system_setting.ServerAddress + "/console/log") notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify") tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix()) tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo) diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go index bf0d7bf36..d462acb4b 100644 --- a/controller/topup_stripe.go +++ b/controller/topup_stripe.go @@ -9,6 +9,7 @@ import ( "one-api/model" "one-api/setting" "one-api/setting/operation_setting" + "one-api/setting/system_setting" "strconv" "strings" "time" @@ -216,8 +217,8 @@ func genStripeLink(referenceId string, customerId string, email string, amount i params := &stripe.CheckoutSessionParams{ ClientReferenceID: stripe.String(referenceId), - SuccessURL: stripe.String(setting.ServerAddress + "/log"), - CancelURL: stripe.String(setting.ServerAddress + "/topup"), + SuccessURL: stripe.String(system_setting.ServerAddress + "/log"), + CancelURL: stripe.String(system_setting.ServerAddress + "/topup"), LineItems: []*stripe.CheckoutSessionLineItemParams{ { Price: stripe.String(setting.StripePriceId), diff --git a/model/option.go b/model/option.go index 73fe92ad1..fefee4e7d 100644 --- a/model/option.go +++ b/model/option.go @@ -6,6 +6,7 @@ import ( "one-api/setting/config" "one-api/setting/operation_setting" "one-api/setting/ratio_setting" + "one-api/setting/system_setting" "strconv" "strings" "time" @@ -66,9 +67,9 @@ func InitOptionMap() { common.OptionMap["SystemName"] = common.SystemName common.OptionMap["Logo"] = common.Logo common.OptionMap["ServerAddress"] = "" - common.OptionMap["WorkerUrl"] = setting.WorkerUrl - common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey - common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(setting.WorkerAllowHttpImageRequestEnabled) + common.OptionMap["WorkerUrl"] = system_setting.WorkerUrl + common.OptionMap["WorkerValidKey"] = system_setting.WorkerValidKey + common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(system_setting.WorkerAllowHttpImageRequestEnabled) common.OptionMap["PayAddress"] = "" common.OptionMap["CustomCallbackAddress"] = "" common.OptionMap["EpayId"] = "" @@ -271,7 +272,7 @@ func updateOptionMap(key string, value string) (err error) { case "SMTPSSLEnabled": common.SMTPSSLEnabled = boolValue case "WorkerAllowHttpImageRequestEnabled": - setting.WorkerAllowHttpImageRequestEnabled = boolValue + system_setting.WorkerAllowHttpImageRequestEnabled = boolValue case "DefaultUseAutoGroup": setting.DefaultUseAutoGroup = boolValue case "ExposeRatioEnabled": @@ -293,11 +294,11 @@ func updateOptionMap(key string, value string) (err error) { case "SMTPToken": common.SMTPToken = value case "ServerAddress": - setting.ServerAddress = value + system_setting.ServerAddress = value case "WorkerUrl": - setting.WorkerUrl = value + system_setting.WorkerUrl = value case "WorkerValidKey": - setting.WorkerValidKey = value + system_setting.WorkerValidKey = value case "PayAddress": operation_setting.PayAddress = value case "Chats": diff --git a/relay/mjproxy_handler.go b/relay/mjproxy_handler.go index 7c52cb6be..ec8dfc6b2 100644 --- a/relay/mjproxy_handler.go +++ b/relay/mjproxy_handler.go @@ -16,6 +16,7 @@ import ( "one-api/relay/helper" "one-api/service" "one-api/setting" + "one-api/setting/system_setting" "strconv" "strings" "time" @@ -131,7 +132,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo midjourneyTask.FinishTime = originTask.FinishTime midjourneyTask.ImageUrl = "" if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled { - midjourneyTask.ImageUrl = setting.ServerAddress + "/mj/image/" + originTask.MjId + midjourneyTask.ImageUrl = system_setting.ServerAddress + "/mj/image/" + originTask.MjId if originTask.Status != "SUCCESS" { midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) } diff --git a/service/epay.go b/service/epay.go index a1ff484e6..48b84dd58 100644 --- a/service/epay.go +++ b/service/epay.go @@ -1,13 +1,13 @@ package service import ( - "one-api/setting" "one-api/setting/operation_setting" + "one-api/setting/system_setting" ) func GetCallbackAddress() string { if operation_setting.CustomCallbackAddress == "" { - return setting.ServerAddress + return system_setting.ServerAddress } return operation_setting.CustomCallbackAddress } diff --git a/service/quota.go b/service/quota.go index e078a1ad1..12017e11e 100644 --- a/service/quota.go +++ b/service/quota.go @@ -11,8 +11,8 @@ import ( "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" - "one-api/setting" "one-api/setting/ratio_setting" + "one-api/setting/system_setting" "one-api/types" "strings" "time" @@ -534,7 +534,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon } if quotaTooLow { prompt := "您的额度即将用尽" - topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) + topUpLink := fmt.Sprintf("%s/topup", system_setting.ServerAddress) // 根据通知方式生成不同的内容格式 var content string diff --git a/setting/system_setting.go b/setting/system_setting/system_setting_old.go similarity index 89% rename from setting/system_setting.go rename to setting/system_setting/system_setting_old.go index c37a61235..4e0f1a502 100644 --- a/setting/system_setting.go +++ b/setting/system_setting/system_setting_old.go @@ -1,4 +1,4 @@ -package setting +package system_setting var ServerAddress = "http://localhost:3000" var WorkerUrl = "" From da6f24a3d48c286e4509a4f0befcb263133ec41b Mon Sep 17 00:00:00 2001 From: Seefs Date: Sat, 13 Sep 2025 16:26:14 +0800 Subject: [PATCH 28/32] fix veo3 adapter --- relay/channel/task/vertex/adaptor.go | 85 ++++++++++++++++------------ 1 file changed, 48 insertions(+), 37 deletions(-) diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index d2ab826d0..4a236b2f0 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -7,12 +7,12 @@ import ( "fmt" "io" "net/http" + "one-api/model" "regexp" "strings" "github.com/gin-gonic/gin" - "one-api/common" "one-api/constant" "one-api/dto" "one-api/relay/channel" @@ -21,6 +21,10 @@ import ( "one-api/service" ) +// ============================ +// Request / Response structures +// ============================ + type requestPayload struct { Instances []map[string]any `json:"instances"` Parameters map[string]any `json:"parameters,omitempty"` @@ -52,33 +56,35 @@ type operationResponse struct { } `json:"error"` } -type TaskAdaptor struct{} +// ============================ +// Adaptor implementation +// ============================ -func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {} - -func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { - info.Action = constant.TaskActionTextGenerate - - req := relaycommon.TaskSubmitReq{} - if err := common.UnmarshalBodyReusable(c, &req); err != nil { - return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) - } - if strings.TrimSpace(req.Prompt) == "" { - return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) - } - c.Set("task_request", req) - return nil +type TaskAdaptor struct { + ChannelType int + apiKey string + baseURL string } -func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.ChannelBaseUrl + a.apiKey = info.ApiKey +} + +// ValidateRequestAndSetAction parses body, validates fields and sets default action. +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { + // Use the standard validation method for TaskSubmitReq + return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate) +} + +// BuildRequestURL constructs the upstream URL. +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { + if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { return "", fmt.Errorf("failed to decode credentials: %w", err) } modelName := info.OriginModelName - if v, ok := getRequestModelFromContext(info); ok { - modelName = v - } if modelName == "" { modelName = "veo-3.0-generate-001" } @@ -103,16 +109,17 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, ), nil } -func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { +// BuildRequestHeader sets required headers. +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { + if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { return fmt.Errorf("failed to decode credentials: %w", err) } - token, err := vertexcore.AcquireAccessToken(*adc, info.ChannelSetting.Proxy) + token, err := vertexcore.AcquireAccessToken(*adc, "") if err != nil { return fmt.Errorf("failed to acquire access token: %w", err) } @@ -121,7 +128,8 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info return nil } -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) { +// BuildRequestBody converts request into Vertex specific format. +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, ok := c.Get("task_request") if !ok { return nil, fmt.Errorf("request not found in context") @@ -151,11 +159,13 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayI return bytes.NewReader(data), nil } -func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +// DoRequest delegates to common helper. +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoTaskApiRequest(a, c, info, requestBody) } -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +// DoResponse handles upstream response, returns taskID etc. +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -177,6 +187,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} } func (a *TaskAdaptor) GetChannelName() string { return "vertex" } +// FetchTask fetch task status func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { taskID, ok := body["task_id"].(string) if !ok { @@ -191,15 +202,15 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http region = "us-central1" } project := extractProjectFromOperationName(upstreamName) - model := extractModelFromOperationName(upstreamName) - if project == "" || model == "" { + modelName := extractModelFromOperationName(upstreamName) + if project == "" || modelName == "" { return nil, fmt.Errorf("cannot extract project/model from operation name") } var url string if region == "global" { - url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, model) + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName) } else { - url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, model) + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName) } payload := map[string]string{"operationName": upstreamName} data, err := json.Marshal(payload) @@ -232,17 +243,17 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } ti := &relaycommon.TaskInfo{} if op.Error.Message != "" { - ti.Status = "FAILURE" + ti.Status = model.TaskStatusFailure ti.Reason = op.Error.Message ti.Progress = "100%" return ti, nil } if !op.Done { - ti.Status = "IN_PROGRESS" + ti.Status = model.TaskStatusInProgress ti.Progress = "50%" return ti, nil } - ti.Status = "SUCCESS" + ti.Status = model.TaskStatusSuccess ti.Progress = "100%" if len(op.Response.Videos) > 0 { v0 := op.Response.Videos[0] @@ -290,9 +301,9 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e return ti, nil } -func getRequestModelFromContext(info *relaycommon.TaskRelayInfo) (string, bool) { - return info.OriginModelName, info.OriginModelName != "" -} +// ============================ +// helpers +// ============================ func encodeLocalTaskID(name string) string { return base64.RawURLEncoding.EncodeToString([]byte(name)) From 8563eafc57e1886a5c413ff3b977ca366cecc496 Mon Sep 17 00:00:00 2001 From: Seefs Date: Sun, 14 Sep 2025 12:59:44 +0800 Subject: [PATCH 29/32] fix: settings --- service/cf_worker.go | 12 ++++++------ service/user_notify.go | 6 +++--- service/webhook.go | 6 +++--- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/service/cf_worker.go b/service/cf_worker.go index 4a7b43760..d60b6fad5 100644 --- a/service/cf_worker.go +++ b/service/cf_worker.go @@ -6,7 +6,7 @@ import ( "fmt" "net/http" "one-api/common" - "one-api/setting" + "one-api/setting/system_setting" "strings" ) @@ -21,14 +21,14 @@ type WorkerRequest struct { // DoWorkerRequest 通过Worker发送请求 func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { - if !setting.EnableWorker() { + if !system_setting.EnableWorker() { return nil, fmt.Errorf("worker not enabled") } - if !setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") { + if !system_setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") { return nil, fmt.Errorf("only support https url") } - workerUrl := setting.WorkerUrl + workerUrl := system_setting.WorkerUrl if !strings.HasSuffix(workerUrl, "/") { workerUrl += "/" } @@ -43,11 +43,11 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { } func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) { - if setting.EnableWorker() { + if system_setting.EnableWorker() { common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", "))) req := &WorkerRequest{ URL: originUrl, - Key: setting.WorkerValidKey, + Key: system_setting.WorkerValidKey, } return DoWorkerRequest(req) } else { diff --git a/service/user_notify.go b/service/user_notify.go index c4a3ea91f..972ca655c 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -7,7 +7,7 @@ import ( "one-api/common" "one-api/dto" "one-api/model" - "one-api/setting" + "one-api/setting/system_setting" "strings" ) @@ -91,11 +91,11 @@ func sendBarkNotify(barkURL string, data dto.Notify) error { var resp *http.Response var err error - if setting.EnableWorker() { + if system_setting.EnableWorker() { // 使用worker发送请求 workerReq := &WorkerRequest{ URL: finalURL, - Key: setting.WorkerValidKey, + Key: system_setting.WorkerValidKey, Method: http.MethodGet, Headers: map[string]string{ "User-Agent": "OneAPI-Bark-Notify/1.0", diff --git a/service/webhook.go b/service/webhook.go index 8faccda30..9c6ec8102 100644 --- a/service/webhook.go +++ b/service/webhook.go @@ -9,7 +9,7 @@ import ( "fmt" "net/http" "one-api/dto" - "one-api/setting" + "one-api/setting/system_setting" "time" ) @@ -56,11 +56,11 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error var req *http.Request var resp *http.Response - if setting.EnableWorker() { + if system_setting.EnableWorker() { // 构建worker请求数据 workerReq := &WorkerRequest{ URL: webhookURL, - Key: setting.WorkerValidKey, + Key: system_setting.WorkerValidKey, Method: http.MethodPost, Headers: map[string]string{ "Content-Type": "application/json", From 33bf267ce82b82d5a43eeadff9d7b74424ddc2e0 Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Mon, 15 Sep 2025 14:31:55 +0800 Subject: [PATCH 30/32] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=8D=B3?= =?UTF-8?q?=E6=A2=A6=E8=A7=86=E9=A2=913.0,=E6=96=B0=E5=A2=9E10s(frames=3D2?= =?UTF-8?q?41)=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/channel/task/jimeng/adaptor.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 2bc45c547..e870a6590 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -36,6 +36,7 @@ type requestPayload struct { Prompt string `json:"prompt,omitempty"` Seed int64 `json:"seed"` AspectRatio string `json:"aspect_ratio"` + Frames int `json:"frames,omitempty"` } type responsePayload struct { @@ -311,10 +312,15 @@ func hmacSHA256(key []byte, data []byte) []byte { func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { r := requestPayload{ - ReqKey: "jimeng_vgfm_i2v_l20", - Prompt: req.Prompt, - AspectRatio: "16:9", // Default aspect ratio - Seed: -1, // Default to random + ReqKey: req.Model, + Prompt: req.Prompt, + } + + switch req.Duration { + case 10: + r.Frames = 241 // 24*10+1 = 241 + default: + r.Frames = 121 // 24*5+1 = 121 } // Handle one-of image_urls or binary_data_base64 From f3e220b196028d29ddc2947daa7b3b8da21267a0 Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Mon, 15 Sep 2025 15:53:41 +0800 Subject: [PATCH 31/32] feat: jimeng video 3.0 req_key convert --- relay/channel/task/jimeng/adaptor.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index e870a6590..b954d7b88 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -340,6 +340,22 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* if err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } + + // 即梦视频3.0 ReqKey转换 + // https://www.volcengine.com/docs/85621/1792707 + if strings.Contains(r.ReqKey, "jimeng_v30") { + if len(r.ImageUrls) > 1 { + // 多张图片:首尾帧生成 + r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1) + } else if len(r.ImageUrls) == 1 { + // 单张图片:图生视频 + r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1) + } else { + // 无图片:文生视频 + r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_t2v_v30", 1) + } + } + return &r, nil } From f236785ed5594c3229ba5ab56d915424436a5281 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Mon, 15 Sep 2025 16:22:37 +0800 Subject: [PATCH 32/32] =?UTF-8?q?fix:=20stripe=E6=94=AF=E4=BB=98=E6=88=90?= =?UTF-8?q?=E5=8A=9F=E6=9C=AA=E6=AD=A3=E7=A1=AE=E8=B7=B3=E8=BD=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/topup_stripe.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go index d462acb4b..ccde91dbe 100644 --- a/controller/topup_stripe.go +++ b/controller/topup_stripe.go @@ -217,7 +217,7 @@ func genStripeLink(referenceId string, customerId string, email string, amount i params := &stripe.CheckoutSessionParams{ ClientReferenceID: stripe.String(referenceId), - SuccessURL: stripe.String(system_setting.ServerAddress + "/log"), + SuccessURL: stripe.String(system_setting.ServerAddress + "/console/log"), CancelURL: stripe.String(system_setting.ServerAddress + "/topup"), LineItems: []*stripe.CheckoutSessionLineItemParams{ {