diff --git a/.dockerignore b/.dockerignore index e4e8e72ea..0670cd7d1 100644 --- a/.dockerignore +++ b/.dockerignore @@ -4,4 +4,5 @@ .vscode .gitignore Makefile -docs \ No newline at end of file +docs +.eslintcache \ No newline at end of file diff --git a/.env.example b/.env.example index ea2464270..c7851385b 100644 --- a/.env.example +++ b/.env.example @@ -47,7 +47,7 @@ # 所有请求超时时间,单位秒,默认为0,表示不限制 # RELAY_TIMEOUT=0 # 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值 -# STREAMING_TIMEOUT=120 +# STREAMING_TIMEOUT=300 # Gemini 识别图片 最大图片数量 # GEMINI_VISION_MAX_IMAGE_NUM=16 @@ -56,8 +56,6 @@ # SESSION_SECRET=random_string # 其他配置 -# 渠道测试频率(单位:秒) -# CHANNEL_TEST_FREQUENCY=10 # 生成默认token # GENERATE_DEFAULT_TOKEN=false # Cohere 安全设置 diff --git a/.gitignore b/.gitignore index 6a23f89e1..1382829fd 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ web/dist .env one-api .DS_Store -tiktoken_cache \ No newline at end of file +tiktoken_cache +.eslintcache \ No newline at end of file diff --git a/LICENSE b/LICENSE index 261eeb9e9..71284f6d0 100644 --- a/LICENSE +++ b/LICENSE @@ -1,201 +1,103 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ +# **New API 许可协议 (Licensing)** - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION +本项目采用**基于使用场景的双重许可 (Usage-Based Dual Licensing)** 模式。 - 1. Definitions. +**核心原则:** - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. +- **默认许可:** 本项目默认在 **GNU Affero 通用公共许可证 v3.0 (AGPLv3)** 下提供。任何用户在遵守 AGPLv3 条款和下述附加限制的前提下,均可免费使用。 +- **商业许可:** 在特定商业场景下,或当您希望获得 AGPLv3 之外的权利时,**必须**获取**商业许可证 (Commercial License)**。 - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. +--- - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. +## **1. 开源许可证 (Open Source License): AGPLv3 - 适用于基础使用** - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. +- 在遵守 **AGPLv3** 条款的前提下,您可以自由地使用、修改和分发 New API。AGPLv3 的完整文本可以访问 [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html) 获取。 +- **核心义务:** AGPLv3 的一个关键要求是,如果您修改了 New API 并通过网络提供服务 (SaaS),或者分发了修改后的版本,您必须以 AGPLv3 许可证向所有用户提供相应的**完整源代码**。 +- **附加限制 (重要):** 在仅使用 AGPLv3 开源许可证的情况下,您**必须**完整保留项目代码中原有的品牌标识、LOGO 及版权声明信息。**禁止以任何形式修改、移除或遮盖**这些信息。如需移除,必须获取商业许可证。 +- 使用前请务必仔细阅读并理解 AGPLv3 的所有条款及上述附加限制。 - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. +## **2. 商业许可证 (Commercial License) - 适用于高级场景及闭源需求** - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. +在以下任一情况下,您**必须**联系我们获取并签署一份商业许可证,才能合法使用 New API: - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). +- **场景一:移除品牌和版权信息** + 您希望在您的产品或服务中移除 New API 的 LOGO、UI界面中的版权声明或其他品牌标识。 - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. +- **场景二:规避 AGPLv3 开源义务** + 您基于 New API 进行了修改,并希望: + - 通过网络提供服务(SaaS),但**不希望**向您的服务用户公开您修改后的源代码。 + - 分发一个集成了 New API 的软件产品,但**不希望**以 AGPLv3 许可证发布您的产品或公开源代码。 - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." +- **场景三:企业政策与集成需求** + - 您所在公司的政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件。 + - 您需要进行 OEM 集成,将 New API 作为您闭源商业产品的一部分进行再分发。 - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. +- **场景四:需要商业支持与保障** + 您需要 AGPLv3 未提供的商业保障,如官方技术支持等。 - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. +**获取商业许可:** +请通过电子邮件 **support@quantumnous.com** 联系 New API 团队洽谈商业授权事宜。 - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. +## **3. 贡献 (Contributions)** - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: +- 我们欢迎社区对 New API 的贡献。所有向本项目提交的贡献(例如通过 Pull Request)都将被视为在 **AGPLv3** 许可证下提供。 +- 通过向本项目提交贡献,即表示您同意您的代码以 AGPLv3 许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。 +- 您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 New API 版本中。 - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and +## **4. 其他条款 (Other Terms)** - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and +- 关于商业许可证的具体条款、条件和价格,以双方签署的正式商业许可协议为准。 +- 项目维护者保留根据需要更新本许可政策的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。 - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and +--- - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. +# **New API Licensing** - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. +This project uses a **Usage-Based Dual Licensing** model. - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. +**Core Principles:** - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. +- **Default License:** This project is available by default under the **GNU Affero General Public License v3.0 (AGPLv3)**. Any user may use it free of charge, provided they comply with both the AGPLv3 terms and the additional restrictions listed below. +- **Commercial License:** For specific commercial scenarios, or if you require rights beyond those granted by AGPLv3, you **must** obtain a **Commercial License**. - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. +--- - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. +## **1. Open Source License: AGPLv3 – For Basic Usage** - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. +- Under the terms of the **AGPLv3**, you are free to use, modify, and distribute New API. The complete AGPLv3 license text can be viewed at [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html). +- **Core Obligation:** A key AGPLv3 requirement is that if you modify New API and provide it as a network service (SaaS), or distribute a modified version, you must make the **complete corresponding source code** available to all users under the AGPLv3 license. +- **Additional Restriction (Important):** When using only the AGPLv3 open-source license, you **must** retain all original branding, logos, and copyright statements within the project’s code. **You are strictly prohibited from modifying, removing, or concealing** any such information. If you wish to remove this, you must obtain a Commercial License. +- Please read and ensure that you fully understand all AGPLv3 terms and the above additional restriction before use. - END OF TERMS AND CONDITIONS +## **2. Commercial License – For Advanced Scenarios & Closed Source Needs** - APPENDIX: How to apply the Apache License to your work. +You **must** contact us to obtain and sign a Commercial License in any of the following scenarios in order to legally use New API: - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. +- **Scenario 1: Removal of Branding and Copyright** + You wish to remove the New API logo, copyright statement, or other branding elements from your product or service. - Copyright [yyyy] [name of copyright owner] +- **Scenario 2: Avoidance of AGPLv3 Open Source Obligations** + You have modified New API and wish to: + - Offer it as a network service (SaaS) **without** disclosing your modifications' source code to your users. + - Distribute a software product integrated with New API **without** releasing your product under AGPLv3 or open-sourcing the code. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +- **Scenario 3: Enterprise Policy & Integration Needs** + - Your organization’s policies, client contracts, or project requirements prohibit the use of AGPLv3-licensed software. + - You require OEM integration and need to redistribute New API as part of your closed-source commercial product. - http://www.apache.org/licenses/LICENSE-2.0 +- **Scenario 4: Commercial Support and Assurances** + You require commercial assurances not provided by AGPLv3, such as official technical support. - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +**Obtaining a Commercial License:** +Please contact the New API team via email at **support@quantumnous.com** to discuss commercial licensing. + +## **3. Contributions** + +- We welcome community contributions to New API. All contributions (e.g., via Pull Request) are deemed to be provided under the **AGPLv3** license. +- By submitting a contribution, you agree that your code is licensed to this project and all downstream users under the AGPLv3 license (regardless of whether those users ultimately operate under AGPLv3 or a Commercial License). +- You also acknowledge and agree that your contribution may be included in New API releases distributed under a Commercial License. + +## **4. Other Terms** + +- The specific terms, conditions, and pricing of the Commercial License are governed by the formal commercial license agreement executed by both parties. +- Project maintainers reserve the right to update this licensing policy as needed. Updates will be communicated via official project channels (e.g., repository, official website). diff --git a/README.md b/README.md index 48218cd70..d68b3e135 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do 1. OpenAI Chat Completions => Claude Messages 2. Clade Messages => OpenAI Chat Completions (可用于Claude Code调用第三方模型) 3. OpenAI Chat Completions => Gemini Chat -20. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费: +19. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费: 1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项 2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费 3. 支持的渠道: diff --git a/common/api_type.go b/common/api_type.go index f045866ac..5ac46c863 100644 --- a/common/api_type.go +++ b/common/api_type.go @@ -65,6 +65,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = constant.APITypeCoze case constant.ChannelTypeJimeng: apiType = constant.APITypeJimeng + case constant.ChannelTypeMoonshot: + apiType = constant.APITypeMoonshot } if apiType == -1 { return constant.APITypeOpenAI, false diff --git a/common/constants.go b/common/constants.go index 305224115..e6d59d101 100644 --- a/common/constants.go +++ b/common/constants.go @@ -83,6 +83,7 @@ var GitHubClientId = "" var GitHubClientSecret = "" var LinuxDOClientId = "" var LinuxDOClientSecret = "" +var LinuxDOMinimumTrustLevel = 0 var WeChatServerAddress = "" var WeChatServerToken = "" diff --git a/common/copy.go b/common/copy.go new file mode 100644 index 000000000..3edb2fa25 --- /dev/null +++ b/common/copy.go @@ -0,0 +1,19 @@ +package common + +import ( + "fmt" + + "github.com/jinzhu/copier" +) + +func DeepCopy[T any](src *T) (*T, error) { + if src == nil { + return nil, fmt.Errorf("copy source cannot be nil") + } + var dst T + err := copier.CopyWithOption(&dst, src, copier.Option{DeepCopy: true, IgnoreEmpty: true}) + if err != nil { + return nil, err + } + return &dst, nil +} diff --git a/common/custom-event.go b/common/custom-event.go index d8f9ec9fb..256db5469 100644 --- a/common/custom-event.go +++ b/common/custom-event.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "strings" + "sync" ) type stringWriter interface { @@ -52,6 +53,8 @@ type CustomEvent struct { Id string Retry uint Data interface{} + + Mutex sync.Mutex } func encode(writer io.Writer, event CustomEvent) error { @@ -73,6 +76,8 @@ func (r CustomEvent) Render(w http.ResponseWriter) error { } func (r CustomEvent) WriteContentType(w http.ResponseWriter) { + r.Mutex.Lock() + defer r.Mutex.Unlock() header := w.Header() header["Content-Type"] = contentType diff --git a/common/database.go b/common/database.go index 9cbaf46a7..71dbd94d5 100644 --- a/common/database.go +++ b/common/database.go @@ -12,4 +12,4 @@ var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries var UsingMySQL = false var UsingClickHouse = false -var SQLitePath = "one-api.db?_busy_timeout=5000" +var SQLitePath = "one-api.db?_busy_timeout=30000" diff --git a/common/endpoint_defaults.go b/common/endpoint_defaults.go new file mode 100644 index 000000000..ffc263507 --- /dev/null +++ b/common/endpoint_defaults.go @@ -0,0 +1,32 @@ +package common + +import "one-api/constant" + +// EndpointInfo 描述单个端点的默认请求信息 +// path: 上游路径 +// method: HTTP 请求方式,例如 POST/GET +// 目前均为 POST,后续可扩展 +// +// json 标签用于直接序列化到 API 输出 +// 例如:{"path":"/v1/chat/completions","method":"POST"} + +type EndpointInfo struct { + Path string `json:"path"` + Method string `json:"method"` +} + +// defaultEndpointInfoMap 保存内置端点的默认 Path 与 Method +var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{ + constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"}, + constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"}, + constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"}, + constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"}, + constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"}, + constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"}, +} + +// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在 +func GetDefaultEndpointInfo(et constant.EndpointType) (EndpointInfo, bool) { + info, ok := defaultEndpointInfoMap[et] + return info, ok +} diff --git a/common/gin.go b/common/gin.go index 8c67bb4d5..2cb358444 100644 --- a/common/gin.go +++ b/common/gin.go @@ -2,12 +2,13 @@ package common import ( "bytes" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/constant" "strings" "time" + + "github.com/gin-gonic/gin" ) const KeyRequestBody = "key_request_body" @@ -31,6 +32,9 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { if err != nil { return err } + //if DebugEnabled { + // println("UnmarshalBodyReusable request body:", string(requestBody)) + //} contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { err = Unmarshal(requestBody, &v) diff --git a/common/init.go b/common/init.go index d70a09dd1..c4626f9ae 100644 --- a/common/init.go +++ b/common/init.go @@ -101,7 +101,7 @@ func InitEnv() { } func initConstantEnv() { - constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120) + constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 300) constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true) constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20) // ForceStreamOption 覆盖请求参数,强制返回usage信息 diff --git a/common/json.go b/common/json.go index 69aa952e9..13e23a460 100644 --- a/common/json.go +++ b/common/json.go @@ -20,3 +20,25 @@ func DecodeJson(reader *bytes.Reader, v any) error { func Marshal(v any) ([]byte, error) { return json.Marshal(v) } + +func GetJsonType(data json.RawMessage) string { + data = bytes.TrimSpace(data) + if len(data) == 0 { + return "unknown" + } + firstChar := bytes.TrimSpace(data)[0] + switch firstChar { + case '{': + return "object" + case '[': + return "array" + case '"': + return "string" + case 't', 'f': + return "boolean" + case 'n': + return "null" + default: + return "number" + } +} diff --git a/common/page_info.go b/common/page_info.go index 5e4535e3d..2378a5d81 100644 --- a/common/page_info.go +++ b/common/page_info.go @@ -41,7 +41,7 @@ func (p *PageInfo) SetItems(items any) { func GetPageQuery(c *gin.Context) *PageInfo { pageInfo := &PageInfo{} // 手动获取并处理每个参数 - if page, err := strconv.Atoi(c.Query("page")); err == nil { + if page, err := strconv.Atoi(c.Query("p")); err == nil { pageInfo.Page = page } if pageSize, err := strconv.Atoi(c.Query("page_size")); err == nil { diff --git a/common/quota.go b/common/quota.go new file mode 100644 index 000000000..dfd65d273 --- /dev/null +++ b/common/quota.go @@ -0,0 +1,5 @@ +package common + +func GetTrustQuota() int { + return int(10 * QuotaPerUnit) +} diff --git a/common/str.go b/common/str.go index 88b58c720..6debce28b 100644 --- a/common/str.go +++ b/common/str.go @@ -4,7 +4,10 @@ import ( "encoding/base64" "encoding/json" "math/rand" + "net/url" + "regexp" "strconv" + "strings" "unsafe" ) @@ -95,3 +98,140 @@ func GetJsonString(data any) string { b, _ := json.Marshal(data) return string(b) } + +// MaskEmail masks a user email to prevent PII leakage in logs +// Returns "***masked***" if email is empty, otherwise shows only the domain part +func MaskEmail(email string) string { + if email == "" { + return "***masked***" + } + + // Find the @ symbol + atIndex := strings.Index(email, "@") + if atIndex == -1 { + // No @ symbol found, return masked + return "***masked***" + } + + // Return only the domain part with @ symbol + return "***@" + email[atIndex+1:] +} + +// maskHostTail returns the tail parts of a domain/host that should be preserved. +// It keeps 2 parts for likely country-code TLDs (e.g., co.uk, com.cn), otherwise keeps only the TLD. +func maskHostTail(parts []string) []string { + if len(parts) < 2 { + return parts + } + lastPart := parts[len(parts)-1] + secondLastPart := parts[len(parts)-2] + if len(lastPart) == 2 && len(secondLastPart) <= 3 { + // Likely country code TLD like co.uk, com.cn + return []string{secondLastPart, lastPart} + } + return []string{lastPart} +} + +// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail. +// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk +func maskHostForURL(host string) string { + parts := strings.Split(host, ".") + if len(parts) < 2 { + return "***" + } + tail := maskHostTail(parts) + return "***." + strings.Join(tail, ".") +} + +// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***. +// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk +func maskHostForPlainDomain(domain string) string { + parts := strings.Split(domain, ".") + if len(parts) < 2 { + return domain + } + tail := maskHostTail(parts) + numStars := len(parts) - len(tail) + if numStars < 1 { + numStars = 1 + } + stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".") + return stars + "." + strings.Join(tail, ".") +} + +// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string +// Example: +// http://example.com -> http://***.com +// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=*** +// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/*** +// 192.168.1.1 -> ***.***.***.*** +// openai.com -> ***.com +// www.openai.com -> ***.***.com +// api.openai.com -> ***.***.com +func MaskSensitiveInfo(str string) string { + // Mask URLs + urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`) + str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string { + u, err := url.Parse(urlStr) + if err != nil { + return urlStr + } + + host := u.Host + if host == "" { + return urlStr + } + + // Mask host with unified logic + maskedHost := maskHostForURL(host) + + result := u.Scheme + "://" + maskedHost + + // Mask path + if u.Path != "" && u.Path != "/" { + pathParts := strings.Split(strings.Trim(u.Path, "/"), "/") + maskedPathParts := make([]string, len(pathParts)) + for i := range pathParts { + if pathParts[i] != "" { + maskedPathParts[i] = "***" + } + } + if len(maskedPathParts) > 0 { + result += "/" + strings.Join(maskedPathParts, "/") + } + } else if u.Path == "/" { + result += "/" + } + + // Mask query parameters + if u.RawQuery != "" { + values, err := url.ParseQuery(u.RawQuery) + if err != nil { + // If can't parse query, just mask the whole query string + result += "?***" + } else { + maskedParams := make([]string, 0, len(values)) + for key := range values { + maskedParams = append(maskedParams, key+"=***") + } + if len(maskedParams) > 0 { + result += "?" + strings.Join(maskedParams, "&") + } + } + } + + return result + }) + + // Mask domain names without protocol (like openai.com, www.openai.com) + domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`) + str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string { + return maskHostForPlainDomain(domain) + }) + + // Mask IP addresses + ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`) + str = ipPattern.ReplaceAllString(str, "***.***.***.***") + + return str +} diff --git a/common/sys_log.go b/common/sys_log.go new file mode 100644 index 000000000..478015f07 --- /dev/null +++ b/common/sys_log.go @@ -0,0 +1,24 @@ +package common + +import ( + "fmt" + "github.com/gin-gonic/gin" + "os" + "time" +) + +func SysLog(s string) { + t := time.Now() + _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) +} + +func SysError(s string) { + t := time.Now() + _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) +} + +func FatalLog(v ...any) { + t := time.Now() + _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) + os.Exit(1) +} diff --git a/common/totp.go b/common/totp.go new file mode 100644 index 000000000..400f9d05c --- /dev/null +++ b/common/totp.go @@ -0,0 +1,150 @@ +package common + +import ( + "crypto/rand" + "fmt" + "os" + "strconv" + "strings" + + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" +) + +const ( + // 备用码配置 + BackupCodeLength = 8 // 备用码长度 + BackupCodeCount = 4 // 生成备用码数量 + + // 限制配置 + MaxFailAttempts = 5 // 最大失败尝试次数 + LockoutDuration = 300 // 锁定时间(秒) +) + +// GenerateTOTPSecret 生成TOTP密钥和配置 +func GenerateTOTPSecret(accountName string) (*otp.Key, error) { + issuer := Get2FAIssuer() + return totp.Generate(totp.GenerateOpts{ + Issuer: issuer, + AccountName: accountName, + Period: 30, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }) +} + +// ValidateTOTPCode 验证TOTP验证码 +func ValidateTOTPCode(secret, code string) bool { + // 清理验证码格式 + cleanCode := strings.ReplaceAll(code, " ", "") + if len(cleanCode) != 6 { + return false + } + + // 验证验证码 + return totp.Validate(cleanCode, secret) +} + +// GenerateBackupCodes 生成备用恢复码 +func GenerateBackupCodes() ([]string, error) { + codes := make([]string, BackupCodeCount) + + for i := 0; i < BackupCodeCount; i++ { + code, err := generateRandomBackupCode() + if err != nil { + return nil, err + } + codes[i] = code + } + + return codes, nil +} + +// generateRandomBackupCode 生成单个备用码 +func generateRandomBackupCode() (string, error) { + const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + code := make([]byte, BackupCodeLength) + + for i := range code { + randomBytes := make([]byte, 1) + _, err := rand.Read(randomBytes) + if err != nil { + return "", err + } + code[i] = charset[int(randomBytes[0])%len(charset)] + } + + // 格式化为 XXXX-XXXX 格式 + return fmt.Sprintf("%s-%s", string(code[:4]), string(code[4:])), nil +} + +// ValidateBackupCode 验证备用码格式 +func ValidateBackupCode(code string) bool { + // 移除所有分隔符并转为大写 + cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", "")) + if len(cleanCode) != BackupCodeLength { + return false + } + + // 检查字符是否合法 + for _, char := range cleanCode { + if !((char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9')) { + return false + } + } + + return true +} + +// NormalizeBackupCode 标准化备用码格式 +func NormalizeBackupCode(code string) string { + cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", "")) + if len(cleanCode) == BackupCodeLength { + return fmt.Sprintf("%s-%s", cleanCode[:4], cleanCode[4:]) + } + return code +} + +// HashBackupCode 对备用码进行哈希 +func HashBackupCode(code string) (string, error) { + normalizedCode := NormalizeBackupCode(code) + return Password2Hash(normalizedCode) +} + +// Get2FAIssuer 获取2FA发行者名称 +func Get2FAIssuer() string { + return SystemName +} + +// getEnvOrDefault 获取环境变量或默认值 +func getEnvOrDefault(key, defaultValue string) string { + if value, exists := os.LookupEnv(key); exists { + return value + } + return defaultValue +} + +// ValidateNumericCode 验证数字验证码格式 +func ValidateNumericCode(code string) (string, error) { + // 移除空格 + code = strings.ReplaceAll(code, " ", "") + + if len(code) != 6 { + return "", fmt.Errorf("验证码必须是6位数字") + } + + // 检查是否为纯数字 + if _, err := strconv.Atoi(code); err != nil { + return "", fmt.Errorf("验证码只能包含数字") + } + + return code, nil +} + +// GenerateQRCodeData 生成二维码数据 +func GenerateQRCodeData(secret, username string) string { + issuer := Get2FAIssuer() + accountName := fmt.Sprintf("%s (%s)", username, issuer) + return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&digits=6&period=30", + issuer, accountName, secret, issuer) +} diff --git a/common/utils.go b/common/utils.go index 17aecd950..883abfd1a 100644 --- a/common/utils.go +++ b/common/utils.go @@ -123,8 +123,16 @@ func Interface2String(inter interface{}) string { return fmt.Sprintf("%d", inter.(int)) case float64: return fmt.Sprintf("%f", inter.(float64)) + case bool: + if inter.(bool) { + return "true" + } else { + return "false" + } + case nil: + return "" } - return "Not Implemented" + return fmt.Sprintf("%v", inter) } func UnescapeHTML(x string) interface{} { @@ -257,32 +265,32 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64 if err != nil { return 0, errors.Wrap(err, "failed to get audio duration") } - durationStr := string(bytes.TrimSpace(output)) - if durationStr == "N/A" { - // Create a temporary output file name - tmpFp, err := os.CreateTemp("", "audio-*"+ext) - if err != nil { - return 0, errors.Wrap(err, "failed to create temporary file") - } - tmpName := tmpFp.Name() - // Close immediately so ffmpeg can open the file on Windows. - _ = tmpFp.Close() - defer os.Remove(tmpName) + durationStr := string(bytes.TrimSpace(output)) + if durationStr == "N/A" { + // Create a temporary output file name + tmpFp, err := os.CreateTemp("", "audio-*"+ext) + if err != nil { + return 0, errors.Wrap(err, "failed to create temporary file") + } + tmpName := tmpFp.Name() + // Close immediately so ffmpeg can open the file on Windows. + _ = tmpFp.Close() + defer os.Remove(tmpName) - // ffmpeg -y -i filename -vcodec copy -acodec copy - ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName) - if err := ffmpegCmd.Run(); err != nil { - return 0, errors.Wrap(err, "failed to run ffmpeg") - } + // ffmpeg -y -i filename -vcodec copy -acodec copy + ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName) + if err := ffmpegCmd.Run(); err != nil { + return 0, errors.Wrap(err, "failed to run ffmpeg") + } - // Recalculate the duration of the new file - c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName) - output, err := c.Output() - if err != nil { - return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg") - } - durationStr = string(bytes.TrimSpace(output)) - } + // Recalculate the duration of the new file + c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName) + output, err := c.Output() + if err != nil { + return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg") + } + durationStr = string(bytes.TrimSpace(output)) + } return strconv.ParseFloat(durationStr, 64) } diff --git a/constant/api_type.go b/constant/api_type.go index 6ba5f2574..f62d91d53 100644 --- a/constant/api_type.go +++ b/constant/api_type.go @@ -31,5 +31,6 @@ const ( APITypeXai APITypeCoze APITypeJimeng - APITypeDummy // this one is only for count, do not add any channel after this + APITypeMoonshot // this one is only for count, do not add any channel after this + APITypeDummy // this one is only for count, do not add any channel after this ) diff --git a/constant/channel.go b/constant/channel.go index 224121e70..2e1cc5b07 100644 --- a/constant/channel.go +++ b/constant/channel.go @@ -49,6 +49,7 @@ const ( ChannelTypeCoze = 49 ChannelTypeKling = 50 ChannelTypeJimeng = 51 + ChannelTypeVidu = 52 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -106,4 +107,5 @@ var ChannelBaseURLs = []string{ "https://api.coze.cn", //49 "https://api.klingai.com", //50 "https://visual.volcengineapi.com", //51 + "https://api.vidu.cn", //52 } diff --git a/constant/context_key.go b/constant/context_key.go index 4eaf3d007..f7640272c 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -3,6 +3,9 @@ package constant type ContextKey string const ( + ContextKeyTokenCountMeta ContextKey = "token_count_meta" + ContextKeyPromptTokens ContextKey = "prompt_tokens" + ContextKeyOriginalModel ContextKey = "original_model" ContextKeyRequestStartTime ContextKey = "request_start_time" @@ -11,7 +14,6 @@ const ( ContextKeyTokenKey ContextKey = "token_key" ContextKeyTokenId ContextKey = "token_id" ContextKeyTokenGroup ContextKey = "token_group" - ContextKeyTokenAllowIps ContextKey = "allow_ips" ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id" ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled" ContextKeyTokenModelLimit ContextKey = "token_model_limit" @@ -23,7 +25,9 @@ const ( ContextKeyChannelBaseUrl ContextKey = "base_url" ContextKeyChannelType ContextKey = "channel_type" ContextKeyChannelSetting ContextKey = "channel_setting" + ContextKeyChannelOtherSetting ContextKey = "channel_other_setting" ContextKeyChannelParamOverride ContextKey = "param_override" + ContextKeyChannelHeaderOverride ContextKey = "header_override" ContextKeyChannelOrganization ContextKey = "channel_organization" ContextKeyChannelAutoBan ContextKey = "auto_ban" ContextKeyChannelModelMapping ContextKey = "model_mapping" @@ -41,4 +45,6 @@ const ( ContextKeyUserGroup ContextKey = "user_group" ContextKeyUsingGroup ContextKey = "group" ContextKeyUserName ContextKey = "username" + + ContextKeySystemPromptOverride ContextKey = "system_prompt_override" ) diff --git a/constant/task.go b/constant/task.go index e7af39a6e..21790145b 100644 --- a/constant/task.go +++ b/constant/task.go @@ -5,8 +5,6 @@ type TaskPlatform string const ( TaskPlatformSuno TaskPlatform = "suno" TaskPlatformMidjourney = "mj" - TaskPlatformKling TaskPlatform = "kling" - TaskPlatformJimeng TaskPlatform = "jimeng" ) const ( diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 5152e0608..18acf2319 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -135,7 +135,11 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He for k := range headers { req.Header.Add(k, headers.Get(k)) } - res, err := service.GetHttpClient().Do(req) + client, err := service.NewProxyHttpClient(channel.GetSetting().Proxy) + if err != nil { + return nil, err + } + res, err := client.Do(req) if err != nil { return nil, err } diff --git a/controller/channel-test.go b/controller/channel-test.go index 8c4a26ae6..5fc6d749c 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -20,6 +20,7 @@ import ( relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" + "one-api/setting/operation_setting" "one-api/types" "strconv" "strings" @@ -69,6 +70,12 @@ func testChannel(channel *model.Channel, testModel string) testResult { newAPIError: nil, } } + if channel.Type == constant.ChannelTypeVidu { + return testResult{ + localErr: errors.New("vidu channel test is not supported"), + newAPIError: nil, + } + } w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -126,10 +133,27 @@ func testChannel(channel *model.Channel, testModel string) testResult { newAPIError: newAPIError, } } + request := buildTestRequest(testModel) - info := relaycommon.GenRelayInfo(c) + // Determine relay format based on request path + relayFormat := types.RelayFormatOpenAI + if c.Request.URL.Path == "/v1/embeddings" { + relayFormat = types.RelayFormatEmbedding + } - err = helper.ModelMappedHelper(c, info, nil) + info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil) + + if err != nil { + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed), + } + } + + info.InitChannelMeta(c) + + err = helper.ModelMappedHelper(c, info, request) if err != nil { return testResult{ context: c, @@ -137,7 +161,9 @@ func testChannel(channel *model.Channel, testModel string) testResult { newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError), } } + testModel = info.UpstreamModelName + request.Model = testModel apiType, _ := common.ChannelType2APIType(channel.Type) adaptor := relay.GetAdaptor(apiType) @@ -149,13 +175,12 @@ func testChannel(channel *model.Channel, testModel string) testResult { } } - request := buildTestRequest(testModel) - // 创建一个用于日志的 info 副本,移除 ApiKey - logInfo := *info - logInfo.ApiKey = "" - common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) + //// 创建一个用于日志的 info 副本,移除 ApiKey + //logInfo := info + //logInfo.ApiKey = "" + common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString())) - priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens)) + priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta()) if err != nil { return testResult{ context: c, @@ -203,7 +228,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { return testResult{ context: c, localErr: err, - newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed), + newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError), } } var httpResp *http.Response @@ -214,7 +239,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { return testResult{ context: c, localErr: err, - newAPIError: types.NewError(err, types.ErrorCodeBadResponse), + newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError), } } } @@ -230,7 +255,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { return testResult{ context: c, localErr: errors.New("usage is nil"), - newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody), + newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError), } } usage := usageA.(*dto.Usage) @@ -240,7 +265,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { return testResult{ context: c, localErr: err, - newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed), + newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), } } info.PromptTokens = usage.PromptTokens @@ -269,7 +294,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { Quota: quota, Content: "模型测试", UseTimeSeconds: int(consumedTime), - IsStream: false, + IsStream: info.IsStream, Group: info.UsingGroup, Other: other, }) @@ -326,8 +351,11 @@ func TestChannel(c *gin.Context) { } channel, err := model.CacheGetChannel(channelId) if err != nil { - common.ApiError(c, err) - return + channel, err = model.GetChannelById(channelId, true) + if err != nil { + common.ApiError(c, err) + return + } } //defer func() { // if channel.ChannelInfo.IsMultiKey { @@ -411,14 +439,14 @@ func testAllChannels(notify bool) error { if common.AutomaticDisableChannelEnabled && !shouldBanChannel { if milliseconds > disableThreshold { err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) - newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded) + newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout) shouldBanChannel = true } } // disable channel if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() { - go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) + processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) } // enable channel @@ -450,15 +478,26 @@ func TestAllChannels(c *gin.Context) { return } -func AutomaticallyTestChannels(frequency int) { - if frequency <= 0 { - common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test") - return - } - for { - time.Sleep(time.Duration(frequency) * time.Minute) - common.SysLog("testing all channels") - _ = testAllChannels(false) - common.SysLog("channel test finished") - } +var autoTestChannelsOnce sync.Once + +func AutomaticallyTestChannels() { + autoTestChannelsOnce.Do(func() { + for { + if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled { + time.Sleep(10 * time.Minute) + continue + } + frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes + common.SysLog(fmt.Sprintf("automatically test channels with interval %d minutes", frequency)) + for { + time.Sleep(time.Duration(frequency) * time.Minute) + common.SysLog("automatically testing all channels") + _ = testAllChannels(false) + common.SysLog("automatically channel test finished") + if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled { + break + } + } + } + }) } diff --git a/controller/channel.go b/controller/channel.go index d3bfa202a..70be91d42 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -52,6 +52,13 @@ func parseStatusFilter(statusParam string) int { } } +func clearChannelInfo(channel *model.Channel) { + if channel.ChannelInfo.IsMultiKey { + channel.ChannelInfo.MultiKeyDisabledReason = nil + channel.ChannelInfo.MultiKeyDisabledTime = nil + } +} + func GetAllChannels(c *gin.Context) { pageInfo := common.GetPageQuery(c) channelData := make([]*model.Channel, 0) @@ -126,6 +133,10 @@ func GetAllChannels(c *gin.Context) { } } + for _, datum := range channelData { + clearChannelInfo(datum) + } + countQuery := model.DB.Model(&model.Channel{}) if statusFilter == common.ChannelStatusEnabled { countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled) @@ -168,14 +179,26 @@ func FetchUpstreamModels(c *gin.Context) { if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } - url := fmt.Sprintf("%s/v1/models", baseURL) + + var url string switch channel.Type { case constant.ChannelTypeGemini: - url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) + // curl https://example.com/v1beta/models?key=$GEMINI_API_KEY + url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader case constant.ChannelTypeAli: url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL) + default: + url = fmt.Sprintf("%s/v1/models", baseURL) + } + + // 获取响应体 - 根据渠道类型决定是否添加 AuthHeader + var body []byte + key := strings.Split(channel.Key, "\n")[0] + if channel.Type == constant.ChannelTypeGemini { + body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) // Use AuthHeader since Gemini now forces it + } else { + body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) } - body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) if err != nil { common.ApiError(c, err) return @@ -319,6 +342,10 @@ func SearchChannels(c *gin.Context) { pagedData := channelData[startIdx:endIdx] + for _, datum := range pagedData { + clearChannelInfo(datum) + } + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -342,6 +369,9 @@ func GetChannel(c *gin.Context) { common.ApiError(c, err) return } + if channel != nil { + clearChannelInfo(channel) + } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -350,6 +380,85 @@ func GetChannel(c *gin.Context) { return } +// GetChannelKey 验证2FA后获取渠道密钥 +func GetChannelKey(c *gin.Context) { + type GetChannelKeyRequest struct { + Code string `json:"code" binding:"required"` + } + + var req GetChannelKeyRequest + if err := c.ShouldBindJSON(&req); err != nil { + common.ApiError(c, fmt.Errorf("参数错误: %v", err)) + return + } + + userId := c.GetInt("id") + channelId, err := strconv.Atoi(c.Param("id")) + if err != nil { + common.ApiError(c, fmt.Errorf("渠道ID格式错误: %v", err)) + return + } + + // 获取2FA记录并验证 + twoFA, err := model.GetTwoFAByUserId(userId) + if err != nil { + common.ApiError(c, fmt.Errorf("获取2FA信息失败: %v", err)) + return + } + + if twoFA == nil || !twoFA.IsEnabled { + common.ApiError(c, fmt.Errorf("用户未启用2FA,无法查看密钥")) + return + } + + // 统一的2FA验证逻辑 + if !validateTwoFactorAuth(twoFA, req.Code) { + common.ApiError(c, fmt.Errorf("验证码或备用码错误,请重试")) + return + } + + // 获取渠道信息(包含密钥) + channel, err := model.GetChannelById(channelId, true) + if err != nil { + common.ApiError(c, fmt.Errorf("获取渠道信息失败: %v", err)) + return + } + + if channel == nil { + common.ApiError(c, fmt.Errorf("渠道不存在")) + return + } + + // 记录操作日志 + model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("查看渠道密钥信息 (渠道ID: %d)", channelId)) + + // 统一的成功响应格式 + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "验证成功", + "data": map[string]interface{}{ + "key": channel.Key, + }, + }) +} + +// validateTwoFactorAuth 统一的2FA验证函数 +func validateTwoFactorAuth(twoFA *model.TwoFA, code string) bool { + // 尝试验证TOTP + if cleanCode, err := common.ValidateNumericCode(code); err == nil { + if isValid, _ := twoFA.ValidateTOTPAndUpdateUsage(cleanCode); isValid { + return true + } + } + + // 尝试验证备用码 + if isValid, err := twoFA.ValidateBackupCodeAndUpdateUsage(code); err == nil && isValid { + return true + } + + return false +} + // validateChannel 通用的渠道校验函数 func validateChannel(channel *model.Channel, isAdd bool) error { // 校验 channel settings @@ -669,6 +778,7 @@ func DeleteChannelBatch(c *gin.Context) { type PatchChannel struct { model.Channel MultiKeyMode *string `json:"multi_key_mode"` + KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加 } func UpdateChannel(c *gin.Context) { @@ -688,7 +798,7 @@ func UpdateChannel(c *gin.Context) { return } // Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request. - originChannel, err := model.GetChannelById(channel.Id, false) + originChannel, err := model.GetChannelById(channel.Id, true) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -704,6 +814,69 @@ func UpdateChannel(c *gin.Context) { if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" { channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode) } + + // 处理多key模式下的密钥追加/覆盖逻辑 + if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey { + switch *channel.KeyMode { + case "append": + // 追加模式:将新密钥添加到现有密钥列表 + if originChannel.Key != "" { + var newKeys []string + var existingKeys []string + + // 解析现有密钥 + if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") { + // JSON数组格式 + var arr []json.RawMessage + if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil { + existingKeys = make([]string, len(arr)) + for i, v := range arr { + existingKeys[i] = string(v) + } + } + } else { + // 换行分隔格式 + existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n") + } + + // 处理 Vertex AI 的特殊情况 + if channel.Type == constant.ChannelTypeVertexAi { + // 尝试解析新密钥为JSON数组 + if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") { + array, err := getVertexArrayKeys(channel.Key) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "追加密钥解析失败: " + err.Error(), + }) + return + } + newKeys = array + } else { + // 单个JSON密钥 + newKeys = []string{channel.Key} + } + // 合并密钥 + allKeys := append(existingKeys, newKeys...) + channel.Key = strings.Join(allKeys, "\n") + } else { + // 普通渠道的处理 + inputKeys := strings.Split(channel.Key, "\n") + for _, key := range inputKeys { + key = strings.TrimSpace(key) + if key != "" { + newKeys = append(newKeys, key) + } + } + // 合并密钥 + allKeys := append(existingKeys, newKeys...) + channel.Key = strings.Join(allKeys, "\n") + } + } + case "replace": + // 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理) + } + } err = channel.Update() if err != nil { common.ApiError(c, err) @@ -711,6 +884,7 @@ func UpdateChannel(c *gin.Context) { } model.InitChannelCache() channel.Key = "" + clearChannelInfo(&channel.Channel) c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -914,3 +1088,413 @@ func CopyChannel(c *gin.Context) { // success c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}}) } + +// MultiKeyManageRequest represents the request for multi-key management operations +type MultiKeyManageRequest struct { + ChannelId int `json:"channel_id"` + Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status" + KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions + Page int `json:"page,omitempty"` // for get_key_status pagination + PageSize int `json:"page_size,omitempty"` // for get_key_status pagination + Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all +} + +// MultiKeyStatusResponse represents the response for key status query +type MultiKeyStatusResponse struct { + Keys []KeyStatus `json:"keys"` + Total int `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalPages int `json:"total_pages"` + // Statistics + EnabledCount int `json:"enabled_count"` + ManualDisabledCount int `json:"manual_disabled_count"` + AutoDisabledCount int `json:"auto_disabled_count"` +} + +type KeyStatus struct { + Index int `json:"index"` + Status int `json:"status"` // 1: enabled, 2: disabled + DisabledTime int64 `json:"disabled_time,omitempty"` + Reason string `json:"reason,omitempty"` + KeyPreview string `json:"key_preview"` // first 10 chars of key for identification +} + +// ManageMultiKeys handles multi-key management operations +func ManageMultiKeys(c *gin.Context) { + request := MultiKeyManageRequest{} + err := c.ShouldBindJSON(&request) + if err != nil { + common.ApiError(c, err) + return + } + + channel, err := model.GetChannelById(request.ChannelId, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "渠道不存在", + }) + return + } + + if !channel.ChannelInfo.IsMultiKey { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该渠道不是多密钥模式", + }) + return + } + + lock := model.GetChannelPollingLock(channel.Id) + lock.Lock() + defer lock.Unlock() + + switch request.Action { + case "get_key_status": + keys := channel.GetKeys() + + // Default pagination parameters + page := request.Page + pageSize := request.PageSize + if page <= 0 { + page = 1 + } + if pageSize <= 0 { + pageSize = 50 // Default page size + } + + // Statistics for all keys (unchanged by filtering) + var enabledCount, manualDisabledCount, autoDisabledCount int + + // Build all key status data first + var allKeyStatusList []KeyStatus + for i, key := range keys { + status := 1 // default enabled + var disabledTime int64 + var reason string + + if channel.ChannelInfo.MultiKeyStatusList != nil { + if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists { + status = s + } + } + + // Count for statistics (all keys) + switch status { + case 1: + enabledCount++ + case 2: + manualDisabledCount++ + case 3: + autoDisabledCount++ + } + + if status != 1 { + if channel.ChannelInfo.MultiKeyDisabledTime != nil { + disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i] + } + if channel.ChannelInfo.MultiKeyDisabledReason != nil { + reason = channel.ChannelInfo.MultiKeyDisabledReason[i] + } + } + + // Create key preview (first 10 chars) + keyPreview := key + if len(key) > 10 { + keyPreview = key[:10] + "..." + } + + allKeyStatusList = append(allKeyStatusList, KeyStatus{ + Index: i, + Status: status, + DisabledTime: disabledTime, + Reason: reason, + KeyPreview: keyPreview, + }) + } + + // Apply status filter if specified + var filteredKeyStatusList []KeyStatus + if request.Status != nil { + for _, keyStatus := range allKeyStatusList { + if keyStatus.Status == *request.Status { + filteredKeyStatusList = append(filteredKeyStatusList, keyStatus) + } + } + } else { + filteredKeyStatusList = allKeyStatusList + } + + // Calculate pagination based on filtered results + filteredTotal := len(filteredKeyStatusList) + totalPages := (filteredTotal + pageSize - 1) / pageSize + if totalPages == 0 { + totalPages = 1 + } + if page > totalPages { + page = totalPages + } + + // Calculate range for current page + start := (page - 1) * pageSize + end := start + pageSize + if end > filteredTotal { + end = filteredTotal + } + + // Get the page data + var pageKeyStatusList []KeyStatus + if start < filteredTotal { + pageKeyStatusList = filteredKeyStatusList[start:end] + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": MultiKeyStatusResponse{ + Keys: pageKeyStatusList, + Total: filteredTotal, // Total of filtered results + Page: page, + PageSize: pageSize, + TotalPages: totalPages, + EnabledCount: enabledCount, // Overall statistics + ManualDisabledCount: manualDisabledCount, // Overall statistics + AutoDisabledCount: autoDisabledCount, // Overall statistics + }, + }) + return + + case "disable_key": + if request.KeyIndex == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "未指定要禁用的密钥索引", + }) + return + } + + keyIndex := *request.KeyIndex + if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "密钥索引超出范围", + }) + return + } + + if channel.ChannelInfo.MultiKeyStatusList == nil { + channel.ChannelInfo.MultiKeyStatusList = make(map[int]int) + } + if channel.ChannelInfo.MultiKeyDisabledTime == nil { + channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64) + } + if channel.ChannelInfo.MultiKeyDisabledReason == nil { + channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string) + } + + channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled + + err = channel.Update() + if err != nil { + common.ApiError(c, err) + return + } + + model.InitChannelCache() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "密钥已禁用", + }) + return + + case "enable_key": + if request.KeyIndex == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "未指定要启用的密钥索引", + }) + return + } + + keyIndex := *request.KeyIndex + if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "密钥索引超出范围", + }) + return + } + + // 从状态列表中删除该密钥的记录,使其回到默认启用状态 + if channel.ChannelInfo.MultiKeyStatusList != nil { + delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex) + } + if channel.ChannelInfo.MultiKeyDisabledTime != nil { + delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex) + } + if channel.ChannelInfo.MultiKeyDisabledReason != nil { + delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex) + } + + err = channel.Update() + if err != nil { + common.ApiError(c, err) + return + } + + model.InitChannelCache() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "密钥已启用", + }) + return + + case "enable_all_keys": + // 清空所有禁用状态,使所有密钥回到默认启用状态 + var enabledCount int + if channel.ChannelInfo.MultiKeyStatusList != nil { + enabledCount = len(channel.ChannelInfo.MultiKeyStatusList) + } + + channel.ChannelInfo.MultiKeyStatusList = make(map[int]int) + channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64) + channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string) + + err = channel.Update() + if err != nil { + common.ApiError(c, err) + return + } + + model.InitChannelCache() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": fmt.Sprintf("已启用 %d 个密钥", enabledCount), + }) + return + + case "disable_all_keys": + // 禁用所有启用的密钥 + if channel.ChannelInfo.MultiKeyStatusList == nil { + channel.ChannelInfo.MultiKeyStatusList = make(map[int]int) + } + if channel.ChannelInfo.MultiKeyDisabledTime == nil { + channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64) + } + if channel.ChannelInfo.MultiKeyDisabledReason == nil { + channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string) + } + + var disabledCount int + for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ { + status := 1 // default enabled + if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists { + status = s + } + + // 只禁用当前启用的密钥 + if status == 1 { + channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled + disabledCount++ + } + } + + if disabledCount == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "没有可禁用的密钥", + }) + return + } + + err = channel.Update() + if err != nil { + common.ApiError(c, err) + return + } + + model.InitChannelCache() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount), + }) + return + + case "delete_disabled_keys": + keys := channel.GetKeys() + var remainingKeys []string + var deletedCount int + var newStatusList = make(map[int]int) + var newDisabledTime = make(map[int]int64) + var newDisabledReason = make(map[int]string) + + newIndex := 0 + for i, key := range keys { + status := 1 // default enabled + if channel.ChannelInfo.MultiKeyStatusList != nil { + if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists { + status = s + } + } + + // 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥 + if status == 3 { + deletedCount++ + } else { + remainingKeys = append(remainingKeys, key) + // 保留非自动禁用密钥的状态信息,重新索引 + if status != 1 { + newStatusList[newIndex] = status + if channel.ChannelInfo.MultiKeyDisabledTime != nil { + if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists { + newDisabledTime[newIndex] = t + } + } + if channel.ChannelInfo.MultiKeyDisabledReason != nil { + if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists { + newDisabledReason[newIndex] = r + } + } + } + newIndex++ + } + } + + if deletedCount == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "没有需要删除的自动禁用密钥", + }) + return + } + + // Update channel with remaining keys + channel.Key = strings.Join(remainingKeys, "\n") + channel.ChannelInfo.MultiKeySize = len(remainingKeys) + channel.ChannelInfo.MultiKeyStatusList = newStatusList + channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime + channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason + + err = channel.Update() + if err != nil { + common.ApiError(c, err) + return + } + + model.InitChannelCache() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount), + "data": deletedCount, + }) + return + + default: + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "不支持的操作", + }) + return + } +} diff --git a/controller/console_migrate.go b/controller/console_migrate.go index d25f199b8..f0812c3d6 100644 --- a/controller/console_migrate.go +++ b/controller/console_migrate.go @@ -3,101 +3,102 @@ package controller import ( - "encoding/json" - "net/http" - "one-api/common" - "one-api/model" - "github.com/gin-gonic/gin" + "encoding/json" + "net/http" + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" ) // MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.* func MigrateConsoleSetting(c *gin.Context) { - // 读取全部 option - opts, err := model.AllOption() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()}) - return - } - // 建立 map - valMap := map[string]string{} - for _, o := range opts { - valMap[o.Key] = o.Value - } + // 读取全部 option + opts, err := model.AllOption() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()}) + return + } + // 建立 map + valMap := map[string]string{} + for _, o := range opts { + valMap[o.Key] = o.Value + } - // 处理 APIInfo - if v := valMap["ApiInfo"]; v != "" { - var arr []map[string]interface{} - if err := json.Unmarshal([]byte(v), &arr); err == nil { - if len(arr) > 50 { - arr = arr[:50] - } - bytes, _ := json.Marshal(arr) - model.UpdateOption("console_setting.api_info", string(bytes)) - } - model.UpdateOption("ApiInfo", "") - } - // Announcements 直接搬 - if v := valMap["Announcements"]; v != "" { - model.UpdateOption("console_setting.announcements", v) - model.UpdateOption("Announcements", "") - } - // FAQ 转换 - if v := valMap["FAQ"]; v != "" { - var arr []map[string]interface{} - if err := json.Unmarshal([]byte(v), &arr); err == nil { - out := []map[string]interface{}{} - for _, item := range arr { - q, _ := item["question"].(string) - if q == "" { - q, _ = item["title"].(string) - } - a, _ := item["answer"].(string) - if a == "" { - a, _ = item["content"].(string) - } - if q != "" && a != "" { - out = append(out, map[string]interface{}{"question": q, "answer": a}) - } - } - if len(out) > 50 { - out = out[:50] - } - bytes, _ := json.Marshal(out) - model.UpdateOption("console_setting.faq", string(bytes)) - } - model.UpdateOption("FAQ", "") - } - // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups) - url := valMap["UptimeKumaUrl"] - slug := valMap["UptimeKumaSlug"] - if url != "" && slug != "" { - // 仅当同时存在 URL 与 Slug 时才进行迁移 - groups := []map[string]interface{}{ - { - "id": 1, - "categoryName": "old", - "url": url, - "slug": slug, - "description": "", - }, - } - bytes, _ := json.Marshal(groups) - model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes)) - } - // 清空旧键内容 - if url != "" { - model.UpdateOption("UptimeKumaUrl", "") - } - if slug != "" { - model.UpdateOption("UptimeKumaSlug", "") - } + // 处理 APIInfo + if v := valMap["ApiInfo"]; v != "" { + var arr []map[string]interface{} + if err := json.Unmarshal([]byte(v), &arr); err == nil { + if len(arr) > 50 { + arr = arr[:50] + } + bytes, _ := json.Marshal(arr) + model.UpdateOption("console_setting.api_info", string(bytes)) + } + model.UpdateOption("ApiInfo", "") + } + // Announcements 直接搬 + if v := valMap["Announcements"]; v != "" { + model.UpdateOption("console_setting.announcements", v) + model.UpdateOption("Announcements", "") + } + // FAQ 转换 + if v := valMap["FAQ"]; v != "" { + var arr []map[string]interface{} + if err := json.Unmarshal([]byte(v), &arr); err == nil { + out := []map[string]interface{}{} + for _, item := range arr { + q, _ := item["question"].(string) + if q == "" { + q, _ = item["title"].(string) + } + a, _ := item["answer"].(string) + if a == "" { + a, _ = item["content"].(string) + } + if q != "" && a != "" { + out = append(out, map[string]interface{}{"question": q, "answer": a}) + } + } + if len(out) > 50 { + out = out[:50] + } + bytes, _ := json.Marshal(out) + model.UpdateOption("console_setting.faq", string(bytes)) + } + model.UpdateOption("FAQ", "") + } + // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups) + url := valMap["UptimeKumaUrl"] + slug := valMap["UptimeKumaSlug"] + if url != "" && slug != "" { + // 仅当同时存在 URL 与 Slug 时才进行迁移 + groups := []map[string]interface{}{ + { + "id": 1, + "categoryName": "old", + "url": url, + "slug": slug, + "description": "", + }, + } + bytes, _ := json.Marshal(groups) + model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes)) + } + // 清空旧键内容 + if url != "" { + model.UpdateOption("UptimeKumaUrl", "") + } + if slug != "" { + model.UpdateOption("UptimeKumaSlug", "") + } - // 删除旧键记录 - oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"} - model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{}) + // 删除旧键记录 + oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"} + model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{}) - // 重新加载 OptionMap - model.InitOptionMap() - common.SysLog("console setting migrated") - c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"}) -} \ No newline at end of file + // 重新加载 OptionMap + model.InitOptionMap() + common.SysLog("console setting migrated") + c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"}) +} diff --git a/controller/linuxdo.go b/controller/linuxdo.go index 65380b65a..9fa156157 100644 --- a/controller/linuxdo.go +++ b/controller/linuxdo.go @@ -220,21 +220,29 @@ func LinuxdoOAuth(c *gin.Context) { } } else { if common.RegisterEnabled { - user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1) - user.DisplayName = linuxdoUser.Name - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled + if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel { + user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1) + user.DisplayName = linuxdoUser.Name + user.Role = common.RoleCommonUser + user.Status = common.UserStatusEnabled - affCode := session.Get("aff") - inviterId := 0 - if affCode != nil { - inviterId, _ = model.GetUserIdByAffCode(affCode.(string)) - } + affCode := session.Get("aff") + inviterId := 0 + if affCode != nil { + inviterId, _ = model.GetUserIdByAffCode(affCode.(string)) + } - if err := user.Insert(inviterId); err != nil { + if err := user.Insert(inviterId); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": err.Error(), + "message": "Linux DO 信任等级未达到管理员设置的最低信任等级", }) return } diff --git a/controller/midjourney.go b/controller/midjourney.go index 02ad708fb..a67d39c23 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -9,6 +9,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/service" "one-api/setting" @@ -28,7 +29,7 @@ func UpdateMidjourneyTaskBulk() { continue } - common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) + logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) taskChannelM := make(map[int][]string) taskM := make(map[string]*model.Midjourney) nullTaskIds := make([]int, 0) @@ -47,9 +48,9 @@ func UpdateMidjourneyTaskBulk() { "progress": "100%", }) if err != nil { - common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) } else { - common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) + logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) } } if len(taskChannelM) == 0 { @@ -57,20 +58,20 @@ func UpdateMidjourneyTaskBulk() { } for channelId, taskIds := range taskChannelM { - common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) if len(taskIds) == 0 { continue } midjourneyChannel, err := model.CacheGetChannel(channelId) if err != nil { - common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err)) + logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err)) err := model.MjBulkUpdate(taskIds, map[string]any{ "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), "status": "FAILURE", "progress": "100%", }) if err != nil { - common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) + logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) } continue } @@ -81,7 +82,7 @@ func UpdateMidjourneyTaskBulk() { }) req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body)) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) continue } // 设置超时时间 @@ -93,22 +94,22 @@ func UpdateMidjourneyTaskBulk() { req.Header.Set("mj-api-secret", midjourneyChannel.Key) resp, err := service.GetHttpClient().Do(req) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) continue } if resp.StatusCode != http.StatusOK { - common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) continue } responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) continue } var responseItems []dto.MidjourneyDto err = json.Unmarshal(responseBody, &responseItems) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) continue } resp.Body.Close() @@ -145,9 +146,25 @@ func UpdateMidjourneyTaskBulk() { buttonStr, _ := json.Marshal(responseItem.Buttons) task.Buttons = string(buttonStr) } + // 映射 VideoUrl + task.VideoUrl = responseItem.VideoUrl + + // 映射 VideoUrls - 将数组序列化为 JSON 字符串 + if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 { + videoUrlsStr, err := json.Marshal(responseItem.VideoUrls) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err)) + task.VideoUrls = "[]" // 失败时设置为空数组 + } else { + task.VideoUrls = string(videoUrlsStr) + } + } else { + task.VideoUrls = "" // 空值时清空字段 + } + shouldReturnQuota := false if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { - common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) + logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) task.Progress = "100%" if task.Quota != 0 { shouldReturnQuota = true @@ -155,14 +172,14 @@ func UpdateMidjourneyTaskBulk() { } err = task.Update() if err != nil { - common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) + logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) } else { if shouldReturnQuota { err = model.IncreaseUserQuota(task.UserId, task.Quota, false) if err != nil { - common.LogError(ctx, "fail to increase user quota: "+err.Error()) + logger.LogError(ctx, "fail to increase user quota: "+err.Error()) } - logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota)) + logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } } @@ -208,6 +225,20 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) if oldTask.Progress != "100%" && newTask.FailReason != "" { return true } + // 检查 VideoUrl 是否需要更新 + if oldTask.VideoUrl != newTask.VideoUrl { + return true + } + // 检查 VideoUrls 是否需要更新 + if newTask.VideoUrls != nil && len(newTask.VideoUrls) > 0 { + newVideoUrlsStr, _ := json.Marshal(newTask.VideoUrls) + if oldTask.VideoUrls != string(newVideoUrlsStr) { + return true + } + } else if oldTask.VideoUrls != "" { + // 如果新数据没有 VideoUrls 但旧数据有,需要更新(清空) + return true + } return false } diff --git a/controller/misc.go b/controller/misc.go index a3ed9be9a..897dad254 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -39,48 +39,51 @@ func TestStatus(c *gin.Context) { func GetStatus(c *gin.Context) { cs := console_setting.GetConsoleSetting() + common.OptionMapRWMutex.RLock() + defer common.OptionMapRWMutex.RUnlock() data := gin.H{ - "version": common.Version, - "start_time": common.StartTime, - "email_verification": common.EmailVerificationEnabled, - "github_oauth": common.GitHubOAuthEnabled, - "github_client_id": common.GitHubClientId, - "linuxdo_oauth": common.LinuxDOOAuthEnabled, - "linuxdo_client_id": common.LinuxDOClientId, - "telegram_oauth": common.TelegramOAuthEnabled, - "telegram_bot_name": common.TelegramBotName, - "system_name": common.SystemName, - "logo": common.Logo, - "footer_html": common.Footer, - "wechat_qrcode": common.WeChatAccountQRCodeImageURL, - "wechat_login": common.WeChatAuthEnabled, - "server_address": setting.ServerAddress, - "price": setting.Price, - "stripe_unit_price": setting.StripeUnitPrice, - "min_topup": setting.MinTopUp, - "stripe_min_topup": setting.StripeMinTopUp, - "turnstile_check": common.TurnstileCheckEnabled, - "turnstile_site_key": common.TurnstileSiteKey, - "top_up_link": common.TopUpLink, - "docs_link": operation_setting.GetGeneralSetting().DocsLink, - "quota_per_unit": common.QuotaPerUnit, - "display_in_currency": common.DisplayInCurrencyEnabled, - "enable_batch_update": common.BatchUpdateEnabled, - "enable_drawing": common.DrawingEnabled, - "enable_task": common.TaskEnabled, - "enable_data_export": common.DataExportEnabled, - "data_export_default_time": common.DataExportDefaultTime, - "default_collapse_sidebar": common.DefaultCollapseSidebar, - "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "", - "enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "", - "mj_notify_enabled": setting.MjNotifyEnabled, - "chats": setting.Chats, - "demo_site_enabled": operation_setting.DemoSiteEnabled, - "self_use_mode_enabled": operation_setting.SelfUseModeEnabled, - "default_use_auto_group": setting.DefaultUseAutoGroup, - "pay_methods": setting.PayMethods, - "usd_exchange_rate": setting.USDExchangeRate, + "version": common.Version, + "start_time": common.StartTime, + "email_verification": common.EmailVerificationEnabled, + "github_oauth": common.GitHubOAuthEnabled, + "github_client_id": common.GitHubClientId, + "linuxdo_oauth": common.LinuxDOOAuthEnabled, + "linuxdo_client_id": common.LinuxDOClientId, + "linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel, + "telegram_oauth": common.TelegramOAuthEnabled, + "telegram_bot_name": common.TelegramBotName, + "system_name": common.SystemName, + "logo": common.Logo, + "footer_html": common.Footer, + "wechat_qrcode": common.WeChatAccountQRCodeImageURL, + "wechat_login": common.WeChatAuthEnabled, + "server_address": setting.ServerAddress, + "price": setting.Price, + "stripe_unit_price": setting.StripeUnitPrice, + "min_topup": setting.MinTopUp, + "stripe_min_topup": setting.StripeMinTopUp, + "turnstile_check": common.TurnstileCheckEnabled, + "turnstile_site_key": common.TurnstileSiteKey, + "top_up_link": common.TopUpLink, + "docs_link": operation_setting.GetGeneralSetting().DocsLink, + "quota_per_unit": common.QuotaPerUnit, + "display_in_currency": common.DisplayInCurrencyEnabled, + "enable_batch_update": common.BatchUpdateEnabled, + "enable_drawing": common.DrawingEnabled, + "enable_task": common.TaskEnabled, + "enable_data_export": common.DataExportEnabled, + "data_export_default_time": common.DataExportDefaultTime, + "default_collapse_sidebar": common.DefaultCollapseSidebar, + "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "", + "enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "", + "mj_notify_enabled": setting.MjNotifyEnabled, + "chats": setting.Chats, + "demo_site_enabled": operation_setting.DemoSiteEnabled, + "self_use_mode_enabled": operation_setting.SelfUseModeEnabled, + "default_use_auto_group": setting.DefaultUseAutoGroup, + "pay_methods": setting.PayMethods, + "usd_exchange_rate": setting.USDExchangeRate, // 面板启用开关 "api_info_enabled": cs.ApiInfoEnabled, @@ -88,6 +91,10 @@ func GetStatus(c *gin.Context) { "announcements_enabled": cs.AnnouncementsEnabled, "faq_enabled": cs.FAQEnabled, + // 模块管理配置 + "HeaderNavModules": common.OptionMap["HeaderNavModules"], + "SidebarModulesAdmin": common.OptionMap["SidebarModulesAdmin"], + "oidc_enabled": system_setting.GetOIDCSettings().Enabled, "oidc_client_id": system_setting.GetOIDCSettings().ClientId, "oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint, diff --git a/controller/missing_models.go b/controller/missing_models.go new file mode 100644 index 000000000..425f9b25f --- /dev/null +++ b/controller/missing_models.go @@ -0,0 +1,27 @@ +package controller + +import ( + "net/http" + "one-api/model" + + "github.com/gin-gonic/gin" +) + +// GetMissingModels returns the list of model names that are referenced by channels +// but do not have corresponding records in the models meta table. +// This helps administrators quickly discover models that need configuration. +func GetMissingModels(c *gin.Context) { + missing, err := model.GetMissingModels() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": missing, + }) +} diff --git a/controller/model.go b/controller/model.go index 31a66b297..f0571b995 100644 --- a/controller/model.go +++ b/controller/model.go @@ -16,6 +16,7 @@ import ( "one-api/relay/channel/moonshot" relaycommon "one-api/relay/common" "one-api/setting" + "time" ) // https://platform.openai.com/docs/api-reference/models/list @@ -92,7 +93,9 @@ func init() { if !success || apiType == constant.APITypeAIProxyLibrary { continue } - meta := &relaycommon.RelayInfo{ChannelType: i} + meta := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{ + ChannelType: i, + }} adaptor := relay.GetAdaptor(apiType) adaptor.Init(meta) channelId2Models[i] = adaptor.GetModelList() @@ -102,7 +105,7 @@ func init() { }) } -func ListModels(c *gin.Context) { +func ListModels(c *gin.Context, modelType int) { userOpenAiModels := make([]dto.OpenAIModels, 0) modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) @@ -171,10 +174,42 @@ func ListModels(c *gin.Context) { } } } - c.JSON(200, gin.H{ - "success": true, - "data": userOpenAiModels, - }) + switch modelType { + case constant.ChannelTypeAnthropic: + useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels)) + for i, model := range userOpenAiModels { + useranthropicModels[i] = dto.AnthropicModel{ + ID: model.Id, + CreatedAt: time.Unix(int64(model.Created), 0).UTC().Format(time.RFC3339), + DisplayName: model.Id, + Type: "model", + } + } + c.JSON(200, gin.H{ + "data": useranthropicModels, + "first_id": useranthropicModels[0].ID, + "has_more": false, + "last_id": useranthropicModels[len(useranthropicModels)-1].ID, + }) + case constant.ChannelTypeGemini: + userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels)) + for i, model := range userOpenAiModels { + userGeminiModels[i] = dto.GeminiModel{ + Name: model.Id, + DisplayName: model.Id, + } + } + c.JSON(200, gin.H{ + "models": userGeminiModels, + "nextPageToken": nil, + }) + default: + c.JSON(200, gin.H{ + "success": true, + "data": userOpenAiModels, + "object": "list", + }) + } } func ChannelListModels(c *gin.Context) { @@ -198,10 +233,20 @@ func EnabledListModels(c *gin.Context) { }) } -func RetrieveModel(c *gin.Context) { +func RetrieveModel(c *gin.Context, modelType int) { modelId := c.Param("model") if aiModel, ok := openAIModelsMap[modelId]; ok { - c.JSON(200, aiModel) + switch modelType { + case constant.ChannelTypeAnthropic: + c.JSON(200, dto.AnthropicModel{ + ID: aiModel.Id, + CreatedAt: time.Unix(int64(aiModel.Created), 0).UTC().Format(time.RFC3339), + DisplayName: aiModel.Id, + Type: "model", + }) + default: + c.JSON(200, aiModel) + } } else { openAIError := dto.OpenAIError{ Message: fmt.Sprintf("The model '%s' does not exist", modelId), diff --git a/controller/model_meta.go b/controller/model_meta.go new file mode 100644 index 000000000..31ea64f35 --- /dev/null +++ b/controller/model_meta.go @@ -0,0 +1,330 @@ +package controller + +import ( + "encoding/json" + "sort" + "strconv" + "strings" + + "one-api/common" + "one-api/constant" + "one-api/model" + + "github.com/gin-gonic/gin" +) + +// GetAllModelsMeta 获取模型列表(分页) +func GetAllModelsMeta(c *gin.Context) { + + pageInfo := common.GetPageQuery(c) + modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) + if err != nil { + common.ApiError(c, err) + return + } + // 批量填充附加字段,提升列表接口性能 + enrichModels(modelsMeta) + var total int64 + model.DB.Model(&model.Model{}).Count(&total) + + // 统计供应商计数(全部数据,不受分页影响) + vendorCounts, _ := model.GetVendorModelCounts() + + pageInfo.SetTotal(int(total)) + pageInfo.SetItems(modelsMeta) + common.ApiSuccess(c, gin.H{ + "items": modelsMeta, + "total": total, + "page": pageInfo.GetPage(), + "page_size": pageInfo.GetPageSize(), + "vendor_counts": vendorCounts, + }) +} + +// SearchModelsMeta 搜索模型列表 +func SearchModelsMeta(c *gin.Context) { + + keyword := c.Query("keyword") + vendor := c.Query("vendor") + pageInfo := common.GetPageQuery(c) + + modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) + if err != nil { + common.ApiError(c, err) + return + } + // 批量填充附加字段,提升列表接口性能 + enrichModels(modelsMeta) + pageInfo.SetTotal(int(total)) + pageInfo.SetItems(modelsMeta) + common.ApiSuccess(c, pageInfo) +} + +// GetModelMeta 根据 ID 获取单条模型信息 +func GetModelMeta(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + var m model.Model + if err := model.DB.First(&m, id).Error; err != nil { + common.ApiError(c, err) + return + } + enrichModels([]*model.Model{&m}) + common.ApiSuccess(c, &m) +} + +// CreateModelMeta 新建模型 +func CreateModelMeta(c *gin.Context) { + var m model.Model + if err := c.ShouldBindJSON(&m); err != nil { + common.ApiError(c, err) + return + } + if m.ModelName == "" { + common.ApiErrorMsg(c, "模型名称不能为空") + return + } + // 名称冲突检查 + if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "模型名称已存在") + return + } + + if err := m.Insert(); err != nil { + common.ApiError(c, err) + return + } + model.RefreshPricing() + common.ApiSuccess(c, &m) +} + +// UpdateModelMeta 更新模型 +func UpdateModelMeta(c *gin.Context) { + statusOnly := c.Query("status_only") == "true" + + var m model.Model + if err := c.ShouldBindJSON(&m); err != nil { + common.ApiError(c, err) + return + } + if m.Id == 0 { + common.ApiErrorMsg(c, "缺少模型 ID") + return + } + + if statusOnly { + // 只更新状态,防止误清空其他字段 + if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil { + common.ApiError(c, err) + return + } + } else { + // 名称冲突检查 + if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "模型名称已存在") + return + } + + if err := m.Update(); err != nil { + common.ApiError(c, err) + return + } + } + model.RefreshPricing() + common.ApiSuccess(c, &m) +} + +// DeleteModelMeta 删除模型 +func DeleteModelMeta(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + if err := model.DB.Delete(&model.Model{}, id).Error; err != nil { + common.ApiError(c, err) + return + } + model.RefreshPricing() + common.ApiSuccess(c, nil) +} + +// enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询 +func enrichModels(models []*model.Model) { + if len(models) == 0 { + return + } + + // 1) 拆分精确与规则匹配 + exactNames := make([]string, 0) + exactIdx := make(map[string][]int) // modelName -> indices in models + ruleIndices := make([]int, 0) + for i, m := range models { + if m == nil { + continue + } + if m.NameRule == model.NameRuleExact { + exactNames = append(exactNames, m.ModelName) + exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i) + } else { + ruleIndices = append(ruleIndices, i) + } + } + + // 2) 批量查询精确模型的绑定渠道 + channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames) + + // 3) 精确模型:端点从缓存、渠道批量映射、分组/计费类型从缓存 + for name, indices := range exactIdx { + chs := channelsByModel[name] + for _, idx := range indices { + mm := models[idx] + if mm.Endpoints == "" { + eps := model.GetModelSupportEndpointTypes(mm.ModelName) + if b, err := json.Marshal(eps); err == nil { + mm.Endpoints = string(b) + } + } + mm.BoundChannels = chs + mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName) + mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName) + } + } + + if len(ruleIndices) == 0 { + return + } + + // 4) 一次性读取定价缓存,内存匹配所有规则模型 + pricings := model.GetPricing() + + // 为全部规则模型收集匹配名集合、端点并集、分组并集、配额集合 + matchedNamesByIdx := make(map[int][]string) + endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{}) + groupSetByIdx := make(map[int]map[string]struct{}) + quotaSetByIdx := make(map[int]map[int]struct{}) + + for _, p := range pricings { + for _, idx := range ruleIndices { + mm := models[idx] + var matched bool + switch mm.NameRule { + case model.NameRulePrefix: + matched = strings.HasPrefix(p.ModelName, mm.ModelName) + case model.NameRuleSuffix: + matched = strings.HasSuffix(p.ModelName, mm.ModelName) + case model.NameRuleContains: + matched = strings.Contains(p.ModelName, mm.ModelName) + } + if !matched { + continue + } + matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName) + + es := endpointSetByIdx[idx] + if es == nil { + es = make(map[constant.EndpointType]struct{}) + endpointSetByIdx[idx] = es + } + for _, et := range p.SupportedEndpointTypes { + es[et] = struct{}{} + } + + gs := groupSetByIdx[idx] + if gs == nil { + gs = make(map[string]struct{}) + groupSetByIdx[idx] = gs + } + for _, g := range p.EnableGroup { + gs[g] = struct{}{} + } + + qs := quotaSetByIdx[idx] + if qs == nil { + qs = make(map[int]struct{}) + quotaSetByIdx[idx] = qs + } + qs[p.QuotaType] = struct{}{} + } + } + + // 5) 汇总所有匹配到的模型名称,批量查询一次渠道 + allMatchedSet := make(map[string]struct{}) + for _, names := range matchedNamesByIdx { + for _, n := range names { + allMatchedSet[n] = struct{}{} + } + } + allMatched := make([]string, 0, len(allMatchedSet)) + for n := range allMatchedSet { + allMatched = append(allMatched, n) + } + matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched) + + // 6) 回填每个规则模型的并集信息 + for _, idx := range ruleIndices { + mm := models[idx] + + // 端点并集 -> 序列化 + if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" { + eps := make([]constant.EndpointType, 0, len(es)) + for et := range es { + eps = append(eps, et) + } + if b, err := json.Marshal(eps); err == nil { + mm.Endpoints = string(b) + } + } + + // 分组并集 + if gs, ok := groupSetByIdx[idx]; ok { + groups := make([]string, 0, len(gs)) + for g := range gs { + groups = append(groups, g) + } + mm.EnableGroups = groups + } + + // 配额类型集合(保持去重并排序) + if qs, ok := quotaSetByIdx[idx]; ok { + arr := make([]int, 0, len(qs)) + for k := range qs { + arr = append(arr, k) + } + sort.Ints(arr) + mm.QuotaTypes = arr + } + + // 渠道并集 + names := matchedNamesByIdx[idx] + channelSet := make(map[string]model.BoundChannel) + for _, n := range names { + for _, ch := range matchedChannelsByModel[n] { + key := ch.Name + "_" + strconv.Itoa(ch.Type) + channelSet[key] = ch + } + } + if len(channelSet) > 0 { + chs := make([]model.BoundChannel, 0, len(channelSet)) + for _, ch := range channelSet { + chs = append(chs, ch) + } + mm.BoundChannels = chs + } + + // 匹配信息 + mm.MatchedModels = names + mm.MatchedCount = len(names) + } +} diff --git a/controller/model_sync.go b/controller/model_sync.go new file mode 100644 index 000000000..74034b51a --- /dev/null +++ b/controller/model_sync.go @@ -0,0 +1,604 @@ +package controller + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "strings" + "sync" + "time" + + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +// 上游地址 +const ( + upstreamModelsURL = "https://basellm.github.io/llm-metadata/api/newapi/models.json" + upstreamVendorsURL = "https://basellm.github.io/llm-metadata/api/newapi/vendors.json" +) + +func normalizeLocale(locale string) (string, bool) { + l := strings.ToLower(strings.TrimSpace(locale)) + switch l { + case "en", "zh", "ja": + return l, true + default: + return "", false + } +} + +func getUpstreamBase() string { + return common.GetEnvOrDefaultString("SYNC_UPSTREAM_BASE", "https://basellm.github.io/llm-metadata") +} + +func getUpstreamURLs(locale string) (modelsURL, vendorsURL string) { + base := strings.TrimRight(getUpstreamBase(), "/") + if l, ok := normalizeLocale(locale); ok && l != "" { + return fmt.Sprintf("%s/api/i18n/%s/newapi/models.json", base, l), + fmt.Sprintf("%s/api/i18n/%s/newapi/vendors.json", base, l) + } + return fmt.Sprintf("%s/api/newapi/models.json", base), fmt.Sprintf("%s/api/newapi/vendors.json", base) +} + +type upstreamEnvelope[T any] struct { + Success bool `json:"success"` + Message string `json:"message"` + Data []T `json:"data"` +} + +type upstreamModel struct { + Description string `json:"description"` + Endpoints json.RawMessage `json:"endpoints"` + Icon string `json:"icon"` + ModelName string `json:"model_name"` + NameRule int `json:"name_rule"` + Status int `json:"status"` + Tags string `json:"tags"` + VendorName string `json:"vendor_name"` +} + +type upstreamVendor struct { + Description string `json:"description"` + Icon string `json:"icon"` + Name string `json:"name"` + Status int `json:"status"` +} + +var ( + etagCache = make(map[string]string) + bodyCache = make(map[string][]byte) + cacheMutex sync.RWMutex +) + +type overwriteField struct { + ModelName string `json:"model_name"` + Fields []string `json:"fields"` +} + +type syncRequest struct { + Overwrite []overwriteField `json:"overwrite"` + Locale string `json:"locale"` +} + +func newHTTPClient() *http.Client { + timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 10) + dialer := &net.Dialer{Timeout: time.Duration(timeoutSec) * time.Second} + transport := &http.Transport{ + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: time.Duration(timeoutSec) * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ResponseHeaderTimeout: time.Duration(timeoutSec) * time.Second, + } + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + if strings.HasSuffix(host, "github.io") { + if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil { + return conn, nil + } + return dialer.DialContext(ctx, "tcp6", addr) + } + return dialer.DialContext(ctx, network, addr) + } + return &http.Client{Transport: transport} +} + +var httpClient = newHTTPClient() + +func fetchJSON[T any](ctx context.Context, url string, out *upstreamEnvelope[T]) error { + var lastErr error + attempts := common.GetEnvOrDefault("SYNC_HTTP_RETRY", 3) + if attempts < 1 { + attempts = 1 + } + baseDelay := 200 * time.Millisecond + maxMB := common.GetEnvOrDefault("SYNC_HTTP_MAX_MB", 10) + maxBytes := int64(maxMB) << 20 + for attempt := 0; attempt < attempts; attempt++ { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + // ETag conditional request + cacheMutex.RLock() + if et := etagCache[url]; et != "" { + req.Header.Set("If-None-Match", et) + } + cacheMutex.RUnlock() + + resp, err := httpClient.Do(req) + if err != nil { + lastErr = err + // backoff with jitter + sleep := baseDelay * time.Duration(1< id + vendorIDCache := make(map[string]int) + + for _, name := range missing { + up, ok := modelByName[name] + if !ok { + skipped = append(skipped, name) + continue + } + + // 若本地已存在且设置为不同步,则跳过(极端情况:缺失列表与本地状态不同步时) + var existing model.Model + if err := model.DB.Where("model_name = ?", name).First(&existing).Error; err == nil { + if existing.SyncOfficial == 0 { + skipped = append(skipped, name) + continue + } + } + + // 确保 vendor 存在 + vendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors) + + // 创建模型 + mi := &model.Model{ + ModelName: name, + Description: up.Description, + Icon: up.Icon, + Tags: up.Tags, + VendorID: vendorID, + Status: chooseStatus(up.Status, 1), + NameRule: up.NameRule, + } + if err := mi.Insert(); err == nil { + createdModels++ + createdList = append(createdList, name) + } else { + skipped = append(skipped, name) + } + } + + // 4) 处理可选覆盖(更新本地已有模型的差异字段) + if len(req.Overwrite) > 0 { + // vendorIDCache 已用于创建阶段,可复用 + for _, ow := range req.Overwrite { + up, ok := modelByName[ow.ModelName] + if !ok { + continue + } + var local model.Model + if err := model.DB.Where("model_name = ?", ow.ModelName).First(&local).Error; err != nil { + continue + } + + // 跳过被禁用官方同步的模型 + if local.SyncOfficial == 0 { + continue + } + + // 映射 vendor + newVendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors) + + // 应用字段覆盖(事务) + _ = model.DB.Transaction(func(tx *gorm.DB) error { + needUpdate := false + if containsField(ow.Fields, "description") { + local.Description = up.Description + needUpdate = true + } + if containsField(ow.Fields, "icon") { + local.Icon = up.Icon + needUpdate = true + } + if containsField(ow.Fields, "tags") { + local.Tags = up.Tags + needUpdate = true + } + if containsField(ow.Fields, "vendor") { + local.VendorID = newVendorID + needUpdate = true + } + if containsField(ow.Fields, "name_rule") { + local.NameRule = up.NameRule + needUpdate = true + } + if containsField(ow.Fields, "status") { + local.Status = chooseStatus(up.Status, local.Status) + needUpdate = true + } + if !needUpdate { + return nil + } + if err := tx.Save(&local).Error; err != nil { + return err + } + updatedModels++ + updatedList = append(updatedList, ow.ModelName) + return nil + }) + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "created_models": createdModels, + "created_vendors": createdVendors, + "updated_models": updatedModels, + "skipped_models": skipped, + "created_list": createdList, + "updated_list": updatedList, + "source": gin.H{ + "locale": req.Locale, + "models_url": modelsURL, + "vendors_url": vendorsURL, + }, + }, + }) +} + +func containsField(fields []string, key string) bool { + key = strings.ToLower(strings.TrimSpace(key)) + for _, f := range fields { + if strings.ToLower(strings.TrimSpace(f)) == key { + return true + } + } + return false +} + +func coalesce(a, b string) string { + if strings.TrimSpace(a) != "" { + return a + } + return b +} + +func chooseStatus(primary, fallback int) int { + if primary == 0 && fallback != 0 { + return fallback + } + if primary != 0 { + return primary + } + return 1 +} + +// SyncUpstreamPreview 预览上游与本地的差异(仅用于弹窗选择) +func SyncUpstreamPreview(c *gin.Context) { + // 1) 拉取上游数据 + timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 15) + ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(timeoutSec)*time.Second) + defer cancel() + + locale := c.Query("locale") + modelsURL, vendorsURL := getUpstreamURLs(locale) + + var vendorsEnv upstreamEnvelope[upstreamVendor] + var modelsEnv upstreamEnvelope[upstreamModel] + var fetchErr error + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + _ = fetchJSON(ctx, vendorsURL, &vendorsEnv) + }() + go func() { + defer wg.Done() + if err := fetchJSON(ctx, modelsURL, &modelsEnv); err != nil { + fetchErr = err + } + }() + wg.Wait() + if fetchErr != nil { + c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + fetchErr.Error(), "locale": locale, "source_urls": gin.H{"models_url": modelsURL, "vendors_url": vendorsURL}}) + return + } + + vendorByName := make(map[string]upstreamVendor) + for _, v := range vendorsEnv.Data { + if v.Name != "" { + vendorByName[v.Name] = v + } + } + modelByName := make(map[string]upstreamModel) + upstreamNames := make([]string, 0, len(modelsEnv.Data)) + for _, m := range modelsEnv.Data { + if m.ModelName != "" { + modelByName[m.ModelName] = m + upstreamNames = append(upstreamNames, m.ModelName) + } + } + + // 2) 本地已有模型 + var locals []model.Model + if len(upstreamNames) > 0 { + _ = model.DB.Where("model_name IN ? AND sync_official <> 0", upstreamNames).Find(&locals).Error + } + + // 本地 vendor 名称映射 + vendorIdSet := make(map[int]struct{}) + for _, m := range locals { + if m.VendorID != 0 { + vendorIdSet[m.VendorID] = struct{}{} + } + } + vendorIDs := make([]int, 0, len(vendorIdSet)) + for id := range vendorIdSet { + vendorIDs = append(vendorIDs, id) + } + idToVendorName := make(map[int]string) + if len(vendorIDs) > 0 { + var dbVendors []model.Vendor + _ = model.DB.Where("id IN ?", vendorIDs).Find(&dbVendors).Error + for _, v := range dbVendors { + idToVendorName[v.Id] = v.Name + } + } + + // 3) 缺失且上游存在的模型 + missingList, _ := model.GetMissingModels() + var missing []string + for _, name := range missingList { + if _, ok := modelByName[name]; ok { + missing = append(missing, name) + } + } + + // 4) 计算冲突字段 + type conflictField struct { + Field string `json:"field"` + Local interface{} `json:"local"` + Upstream interface{} `json:"upstream"` + } + type conflictItem struct { + ModelName string `json:"model_name"` + Fields []conflictField `json:"fields"` + } + + var conflicts []conflictItem + for _, local := range locals { + up, ok := modelByName[local.ModelName] + if !ok { + continue + } + fields := make([]conflictField, 0, 6) + if strings.TrimSpace(local.Description) != strings.TrimSpace(up.Description) { + fields = append(fields, conflictField{Field: "description", Local: local.Description, Upstream: up.Description}) + } + if strings.TrimSpace(local.Icon) != strings.TrimSpace(up.Icon) { + fields = append(fields, conflictField{Field: "icon", Local: local.Icon, Upstream: up.Icon}) + } + if strings.TrimSpace(local.Tags) != strings.TrimSpace(up.Tags) { + fields = append(fields, conflictField{Field: "tags", Local: local.Tags, Upstream: up.Tags}) + } + // vendor 对比使用名称 + localVendor := idToVendorName[local.VendorID] + if strings.TrimSpace(localVendor) != strings.TrimSpace(up.VendorName) { + fields = append(fields, conflictField{Field: "vendor", Local: localVendor, Upstream: up.VendorName}) + } + if local.NameRule != up.NameRule { + fields = append(fields, conflictField{Field: "name_rule", Local: local.NameRule, Upstream: up.NameRule}) + } + if local.Status != chooseStatus(up.Status, local.Status) { + fields = append(fields, conflictField{Field: "status", Local: local.Status, Upstream: up.Status}) + } + if len(fields) > 0 { + conflicts = append(conflicts, conflictItem{ModelName: local.ModelName, Fields: fields}) + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "missing": missing, + "conflicts": conflicts, + "source": gin.H{ + "locale": locale, + "models_url": modelsURL, + "vendors_url": vendorsURL, + }, + }, + }) +} diff --git a/controller/oidc.go b/controller/oidc.go index df8ea1c40..f3def0e34 100644 --- a/controller/oidc.go +++ b/controller/oidc.go @@ -69,7 +69,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { } if oidcResponse.AccessToken == "" { - common.SysError("OIDC 获取 Token 失败,请检查设置!") + common.SysLog("OIDC 获取 Token 失败,请检查设置!") return nil, errors.New("OIDC 获取 Token 失败,请检查设置!") } @@ -85,7 +85,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { } defer res2.Body.Close() if res2.StatusCode != http.StatusOK { - common.SysError("OIDC 获取用户信息失败!请检查设置!") + common.SysLog("OIDC 获取用户信息失败!请检查设置!") return nil, errors.New("OIDC 获取用户信息失败!请检查设置!") } @@ -95,7 +95,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { return nil, err } if oidcUser.OpenID == "" || oidcUser.Email == "" { - common.SysError("OIDC 获取用户信息为空!请检查设置!") + common.SysLog("OIDC 获取用户信息为空!请检查设置!") return nil, errors.New("OIDC 获取用户信息为空!请检查设置!") } return &oidcUser, nil diff --git a/controller/option.go b/controller/option.go index decdb0d40..e5f2b75b0 100644 --- a/controller/option.go +++ b/controller/option.go @@ -2,6 +2,7 @@ package controller import ( "encoding/json" + "fmt" "net/http" "one-api/common" "one-api/model" @@ -35,8 +36,13 @@ func GetOptions(c *gin.Context) { return } +type OptionUpdateRequest struct { + Key string `json:"key"` + Value any `json:"value"` +} + func UpdateOption(c *gin.Context) { - var option model.Option + var option OptionUpdateRequest err := json.NewDecoder(c.Request.Body).Decode(&option) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ @@ -45,6 +51,16 @@ func UpdateOption(c *gin.Context) { }) return } + switch option.Value.(type) { + case bool: + option.Value = common.Interface2String(option.Value.(bool)) + case float64: + option.Value = common.Interface2String(option.Value.(float64)) + case int: + option.Value = common.Interface2String(option.Value.(int)) + default: + option.Value = fmt.Sprintf("%v", option.Value) + } switch option.Key { case "GitHubOAuthEnabled": if option.Value == "true" && common.GitHubClientId == "" { @@ -104,7 +120,7 @@ func UpdateOption(c *gin.Context) { return } case "GroupRatio": - err = ratio_setting.CheckGroupRatio(option.Value) + err = ratio_setting.CheckGroupRatio(option.Value.(string)) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -113,7 +129,7 @@ func UpdateOption(c *gin.Context) { return } case "ModelRequestRateLimitGroup": - err = setting.CheckModelRequestRateLimitGroup(option.Value) + err = setting.CheckModelRequestRateLimitGroup(option.Value.(string)) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -122,7 +138,7 @@ func UpdateOption(c *gin.Context) { return } case "console_setting.api_info": - err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo") + err = console_setting.ValidateConsoleSettings(option.Value.(string), "ApiInfo") if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -131,7 +147,7 @@ func UpdateOption(c *gin.Context) { return } case "console_setting.announcements": - err = console_setting.ValidateConsoleSettings(option.Value, "Announcements") + err = console_setting.ValidateConsoleSettings(option.Value.(string), "Announcements") if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -140,7 +156,7 @@ func UpdateOption(c *gin.Context) { return } case "console_setting.faq": - err = console_setting.ValidateConsoleSettings(option.Value, "FAQ") + err = console_setting.ValidateConsoleSettings(option.Value.(string), "FAQ") if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -149,7 +165,7 @@ func UpdateOption(c *gin.Context) { return } case "console_setting.uptime_kuma_groups": - err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups") + err = console_setting.ValidateConsoleSettings(option.Value.(string), "UptimeKumaGroups") if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -158,7 +174,7 @@ func UpdateOption(c *gin.Context) { return } } - err = model.UpdateOption(option.Key, option.Value) + err = model.UpdateOption(option.Key, option.Value.(string)) if err != nil { common.ApiError(c, err) return diff --git a/controller/playground.go b/controller/playground.go index 0073cf060..8a1cb2b67 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -5,10 +5,8 @@ import ( "fmt" "one-api/common" "one-api/constant" - "one-api/dto" "one-api/middleware" "one-api/model" - "one-api/setting" "one-api/types" "time" @@ -28,41 +26,19 @@ func Playground(c *gin.Context) { useAccessToken := c.GetBool("use_access_token") if useAccessToken { - newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied) + newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry()) return } - playgroundRequest := &dto.PlayGroundRequest{} - err := common.UnmarshalBodyReusable(c, playgroundRequest) - if err != nil { - newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest) - return - } - - if playgroundRequest.Model == "" { - newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest) - return - } - c.Set("original_model", playgroundRequest.Model) - group := playgroundRequest.Group - userGroup := c.GetString("group") - - if group == "" { - group = userGroup - } else { - if !setting.GroupInUserUsableGroups(group) && group != userGroup { - newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied) - return - } - c.Set("group", group) - } + group := c.GetString("group") + modelName := c.GetString("original_model") userId := c.GetInt("id") // Write user context to ensure acceptUnsetRatio is available userCache, err := model.GetUserCache(userId) if err != nil { - newAPIError = types.NewError(err, types.ErrorCodeQueryDataError) + newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) return } userCache.WriteContext(c) @@ -73,12 +49,12 @@ func Playground(c *gin.Context) { Group: group, } _ = middleware.SetupContextForToken(c, tempToken) - _, newAPIError = getChannel(c, group, playgroundRequest.Model, 0) + _, newAPIError = getChannel(c, group, modelName, 0) if newAPIError != nil { return } //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) - Relay(c) + Relay(c, types.RelayFormatOpenAI) } diff --git a/controller/prefill_group.go b/controller/prefill_group.go new file mode 100644 index 000000000..d912d6098 --- /dev/null +++ b/controller/prefill_group.go @@ -0,0 +1,90 @@ +package controller + +import ( + "strconv" + + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" +) + +// GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤 +func GetPrefillGroups(c *gin.Context) { + groupType := c.Query("type") + groups, err := model.GetAllPrefillGroups(groupType) + if err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, groups) +} + +// CreatePrefillGroup 创建新的预填组 +func CreatePrefillGroup(c *gin.Context) { + var g model.PrefillGroup + if err := c.ShouldBindJSON(&g); err != nil { + common.ApiError(c, err) + return + } + if g.Name == "" || g.Type == "" { + common.ApiErrorMsg(c, "组名称和类型不能为空") + return + } + // 创建前检查名称 + if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "组名称已存在") + return + } + + if err := g.Insert(); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, &g) +} + +// UpdatePrefillGroup 更新预填组 +func UpdatePrefillGroup(c *gin.Context) { + var g model.PrefillGroup + if err := c.ShouldBindJSON(&g); err != nil { + common.ApiError(c, err) + return + } + if g.Id == 0 { + common.ApiErrorMsg(c, "缺少组 ID") + return + } + // 名称冲突检查 + if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "组名称已存在") + return + } + + if err := g.Update(); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, &g) +} + +// DeletePrefillGroup 删除预填组 +func DeletePrefillGroup(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + if err := model.DeletePrefillGroupByID(id); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, nil) +} diff --git a/controller/pricing.go b/controller/pricing.go index f27336b72..4b7cc86d5 100644 --- a/controller/pricing.go +++ b/controller/pricing.go @@ -39,10 +39,13 @@ func GetPricing(c *gin.Context) { } c.JSON(200, gin.H{ - "success": true, - "data": pricing, - "group_ratio": groupRatio, - "usable_group": usableGroup, + "success": true, + "data": pricing, + "vendors": model.GetVendors(), + "group_ratio": groupRatio, + "usable_group": usableGroup, + "supported_endpoint": model.GetSupportedEndpointMap(), + "auto_groups": setting.AutoGroups, }) } diff --git a/controller/ratio_config.go b/controller/ratio_config.go index 6ddc3d9ef..0cb4aa73b 100644 --- a/controller/ratio_config.go +++ b/controller/ratio_config.go @@ -1,24 +1,24 @@ package controller import ( - "net/http" - "one-api/setting/ratio_setting" + "net/http" + "one-api/setting/ratio_setting" - "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin" ) func GetRatioConfig(c *gin.Context) { - if !ratio_setting.IsExposeRatioEnabled() { - c.JSON(http.StatusForbidden, gin.H{ - "success": false, - "message": "倍率配置接口未启用", - }) - return - } + if !ratio_setting.IsExposeRatioEnabled() { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "倍率配置接口未启用", + }) + return + } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": ratio_setting.GetExposedData(), - }) -} \ No newline at end of file + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": ratio_setting.GetExposedData(), + }) +} diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go index 0453870d0..7a481c476 100644 --- a/controller/ratio_sync.go +++ b/controller/ratio_sync.go @@ -1,474 +1,539 @@ package controller import ( - "context" - "encoding/json" - "fmt" - "net/http" - "strings" - "sync" - "time" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "one-api/logger" + "strings" + "sync" + "time" - "one-api/common" - "one-api/dto" - "one-api/model" - "one-api/setting/ratio_setting" + "one-api/dto" + "one-api/model" + "one-api/setting/ratio_setting" - "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin" ) const ( - defaultTimeoutSeconds = 10 - defaultEndpoint = "/api/ratio_config" - maxConcurrentFetches = 8 + defaultTimeoutSeconds = 10 + defaultEndpoint = "/api/ratio_config" + maxConcurrentFetches = 8 + maxRatioConfigBytes = 10 << 20 // 10MB + floatEpsilon = 1e-9 ) +func nearlyEqual(a, b float64) bool { + if a > b { + return a-b < floatEpsilon + } + return b-a < floatEpsilon +} + +func valuesEqual(a, b interface{}) bool { + af, aok := a.(float64) + bf, bok := b.(float64) + if aok && bok { + return nearlyEqual(af, bf) + } + return a == b +} + var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} type upstreamResult struct { - Name string `json:"name"` - Data map[string]any `json:"data,omitempty"` - Err string `json:"err,omitempty"` + Name string `json:"name"` + Data map[string]any `json:"data,omitempty"` + Err string `json:"err,omitempty"` } func FetchUpstreamRatios(c *gin.Context) { - var req dto.UpstreamRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) - return - } + var req dto.UpstreamRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) + return + } - if req.Timeout <= 0 { - req.Timeout = defaultTimeoutSeconds - } + if req.Timeout <= 0 { + req.Timeout = defaultTimeoutSeconds + } - var upstreams []dto.UpstreamDTO + var upstreams []dto.UpstreamDTO - if len(req.Upstreams) > 0 { - for _, u := range req.Upstreams { - if strings.HasPrefix(u.BaseURL, "http") { - if u.Endpoint == "" { - u.Endpoint = defaultEndpoint - } - u.BaseURL = strings.TrimRight(u.BaseURL, "/") - upstreams = append(upstreams, u) - } - } - } else if len(req.ChannelIDs) > 0 { - intIds := make([]int, 0, len(req.ChannelIDs)) - for _, id64 := range req.ChannelIDs { - intIds = append(intIds, int(id64)) - } - dbChannels, err := model.GetChannelsByIds(intIds) - if err != nil { - common.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) - c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) - return - } - for _, ch := range dbChannels { - if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { - upstreams = append(upstreams, dto.UpstreamDTO{ - ID: ch.Id, - Name: ch.Name, - BaseURL: strings.TrimRight(base, "/"), - Endpoint: "", - }) - } - } - } + if len(req.Upstreams) > 0 { + for _, u := range req.Upstreams { + if strings.HasPrefix(u.BaseURL, "http") { + if u.Endpoint == "" { + u.Endpoint = defaultEndpoint + } + u.BaseURL = strings.TrimRight(u.BaseURL, "/") + upstreams = append(upstreams, u) + } + } + } else if len(req.ChannelIDs) > 0 { + intIds := make([]int, 0, len(req.ChannelIDs)) + for _, id64 := range req.ChannelIDs { + intIds = append(intIds, int(id64)) + } + dbChannels, err := model.GetChannelsByIds(intIds) + if err != nil { + logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) + return + } + for _, ch := range dbChannels { + if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { + upstreams = append(upstreams, dto.UpstreamDTO{ + ID: ch.Id, + Name: ch.Name, + BaseURL: strings.TrimRight(base, "/"), + Endpoint: "", + }) + } + } + } - if len(upstreams) == 0 { - c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) - return - } + if len(upstreams) == 0 { + c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) + return + } - var wg sync.WaitGroup - ch := make(chan upstreamResult, len(upstreams)) + var wg sync.WaitGroup + ch := make(chan upstreamResult, len(upstreams)) - sem := make(chan struct{}, maxConcurrentFetches) + sem := make(chan struct{}, maxConcurrentFetches) - client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}} + dialer := &net.Dialer{Timeout: 10 * time.Second} + transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second} + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + // 对 github.io 优先尝试 IPv4,失败则回退 IPv6 + if strings.HasSuffix(host, "github.io") { + if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil { + return conn, nil + } + return dialer.DialContext(ctx, "tcp6", addr) + } + return dialer.DialContext(ctx, network, addr) + } + client := &http.Client{Transport: transport} - for _, chn := range upstreams { - wg.Add(1) - go func(chItem dto.UpstreamDTO) { - defer wg.Done() + for _, chn := range upstreams { + wg.Add(1) + go func(chItem dto.UpstreamDTO) { + defer wg.Done() - sem <- struct{}{} - defer func() { <-sem }() + sem <- struct{}{} + defer func() { <-sem }() - endpoint := chItem.Endpoint - if endpoint == "" { - endpoint = defaultEndpoint - } else if !strings.HasPrefix(endpoint, "/") { - endpoint = "/" + endpoint - } - fullURL := chItem.BaseURL + endpoint + endpoint := chItem.Endpoint + var fullURL string + if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") { + fullURL = endpoint + } else { + if endpoint == "" { + endpoint = defaultEndpoint + } else if !strings.HasPrefix(endpoint, "/") { + endpoint = "/" + endpoint + } + fullURL = chItem.BaseURL + endpoint + } - uniqueName := chItem.Name - if chItem.ID != 0 { - uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID) - } + uniqueName := chItem.Name + if chItem.ID != 0 { + uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID) + } - ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) - defer cancel() + ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) + defer cancel() - httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) - if err != nil { - common.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: err.Error()} - return - } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + return + } - resp, err := client.Do(httpReq) - if err != nil { - common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: err.Error()} - return - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status) - ch <- upstreamResult{Name: uniqueName, Err: resp.Status} - return - } - // 兼容两种上游接口格式: - // type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price - // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式 - var body struct { - Success bool `json:"success"` - Data json.RawMessage `json:"data"` - Message string `json:"message"` - } + // 简单重试:最多 3 次,指数退避 + var resp *http.Response + var lastErr error + for attempt := 0; attempt < 3; attempt++ { + resp, lastErr = client.Do(httpReq) + if lastErr == nil { + break + } + time.Sleep(time.Duration(200*(1< data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price + // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式 + var body struct { + Success bool `json:"success"` + Data json.RawMessage `json:"data"` + Message string `json:"message"` + } - if !body.Success { - ch <- upstreamResult{Name: uniqueName, Err: body.Message} - return - } + if err := json.NewDecoder(limited).Decode(&body); err != nil { + logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + return + } - // 尝试按 type1 解析 - var type1Data map[string]any - if err := json.Unmarshal(body.Data, &type1Data); err == nil { - // 如果包含至少一个 ratioTypes 字段,则认为是 type1 - isType1 := false - for _, rt := range ratioTypes { - if _, ok := type1Data[rt]; ok { - isType1 = true - break - } - } - if isType1 { - ch <- upstreamResult{Name: uniqueName, Data: type1Data} - return - } - } + if !body.Success { + ch <- upstreamResult{Name: uniqueName, Err: body.Message} + return + } - // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析 - var pricingItems []struct { - ModelName string `json:"model_name"` - QuotaType int `json:"quota_type"` - ModelRatio float64 `json:"model_ratio"` - ModelPrice float64 `json:"model_price"` - CompletionRatio float64 `json:"completion_ratio"` - } - if err := json.Unmarshal(body.Data, &pricingItems); err != nil { - common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"} - return - } + // 若 Data 为空,将继续按 type1 尝试解析(与多数静态 ratio_config 兼容) - modelRatioMap := make(map[string]float64) - completionRatioMap := make(map[string]float64) - modelPriceMap := make(map[string]float64) + // 尝试按 type1 解析 + var type1Data map[string]any + if err := json.Unmarshal(body.Data, &type1Data); err == nil { + // 如果包含至少一个 ratioTypes 字段,则认为是 type1 + isType1 := false + for _, rt := range ratioTypes { + if _, ok := type1Data[rt]; ok { + isType1 = true + break + } + } + if isType1 { + ch <- upstreamResult{Name: uniqueName, Data: type1Data} + return + } + } - for _, item := range pricingItems { - if item.QuotaType == 1 { - modelPriceMap[item.ModelName] = item.ModelPrice - } else { - modelRatioMap[item.ModelName] = item.ModelRatio - // completionRatio 可能为 0,此时也直接赋值,保持与上游一致 - completionRatioMap[item.ModelName] = item.CompletionRatio - } - } + // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析 + var pricingItems []struct { + ModelName string `json:"model_name"` + QuotaType int `json:"quota_type"` + ModelRatio float64 `json:"model_ratio"` + ModelPrice float64 `json:"model_price"` + CompletionRatio float64 `json:"completion_ratio"` + } + if err := json.Unmarshal(body.Data, &pricingItems); err != nil { + logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"} + return + } - converted := make(map[string]any) + modelRatioMap := make(map[string]float64) + completionRatioMap := make(map[string]float64) + modelPriceMap := make(map[string]float64) - if len(modelRatioMap) > 0 { - ratioAny := make(map[string]any, len(modelRatioMap)) - for k, v := range modelRatioMap { - ratioAny[k] = v - } - converted["model_ratio"] = ratioAny - } + for _, item := range pricingItems { + if item.QuotaType == 1 { + modelPriceMap[item.ModelName] = item.ModelPrice + } else { + modelRatioMap[item.ModelName] = item.ModelRatio + // completionRatio 可能为 0,此时也直接赋值,保持与上游一致 + completionRatioMap[item.ModelName] = item.CompletionRatio + } + } - if len(completionRatioMap) > 0 { - compAny := make(map[string]any, len(completionRatioMap)) - for k, v := range completionRatioMap { - compAny[k] = v - } - converted["completion_ratio"] = compAny - } + converted := make(map[string]any) - if len(modelPriceMap) > 0 { - priceAny := make(map[string]any, len(modelPriceMap)) - for k, v := range modelPriceMap { - priceAny[k] = v - } - converted["model_price"] = priceAny - } + if len(modelRatioMap) > 0 { + ratioAny := make(map[string]any, len(modelRatioMap)) + for k, v := range modelRatioMap { + ratioAny[k] = v + } + converted["model_ratio"] = ratioAny + } - ch <- upstreamResult{Name: uniqueName, Data: converted} - }(chn) - } + if len(completionRatioMap) > 0 { + compAny := make(map[string]any, len(completionRatioMap)) + for k, v := range completionRatioMap { + compAny[k] = v + } + converted["completion_ratio"] = compAny + } - wg.Wait() - close(ch) + if len(modelPriceMap) > 0 { + priceAny := make(map[string]any, len(modelPriceMap)) + for k, v := range modelPriceMap { + priceAny[k] = v + } + converted["model_price"] = priceAny + } - localData := ratio_setting.GetExposedData() + ch <- upstreamResult{Name: uniqueName, Data: converted} + }(chn) + } - var testResults []dto.TestResult - var successfulChannels []struct { - name string - data map[string]any - } + wg.Wait() + close(ch) - for r := range ch { - if r.Err != "" { - testResults = append(testResults, dto.TestResult{ - Name: r.Name, - Status: "error", - Error: r.Err, - }) - } else { - testResults = append(testResults, dto.TestResult{ - Name: r.Name, - Status: "success", - }) - successfulChannels = append(successfulChannels, struct { - name string - data map[string]any - }{name: r.Name, data: r.Data}) - } - } + localData := ratio_setting.GetExposedData() - differences := buildDifferences(localData, successfulChannels) + var testResults []dto.TestResult + var successfulChannels []struct { + name string + data map[string]any + } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "data": gin.H{ - "differences": differences, - "test_results": testResults, - }, - }) + for r := range ch { + if r.Err != "" { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "error", + Error: r.Err, + }) + } else { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "success", + }) + successfulChannels = append(successfulChannels, struct { + name string + data map[string]any + }{name: r.Name, data: r.Data}) + } + } + + differences := buildDifferences(localData, successfulChannels) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "differences": differences, + "test_results": testResults, + }, + }) } func buildDifferences(localData map[string]any, successfulChannels []struct { - name string - data map[string]any + name string + data map[string]any }) map[string]map[string]dto.DifferenceItem { - differences := make(map[string]map[string]dto.DifferenceItem) + differences := make(map[string]map[string]dto.DifferenceItem) - allModels := make(map[string]struct{}) - - for _, ratioType := range ratioTypes { - if localRatioAny, ok := localData[ratioType]; ok { - if localRatio, ok := localRatioAny.(map[string]float64); ok { - for modelName := range localRatio { - allModels[modelName] = struct{}{} - } - } - } - } - - for _, channel := range successfulChannels { - for _, ratioType := range ratioTypes { - if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { - for modelName := range upstreamRatio { - allModels[modelName] = struct{}{} - } - } - } - } + allModels := make(map[string]struct{}) - confidenceMap := make(map[string]map[string]bool) - - // 预处理阶段:检查pricing接口的可信度 - for _, channel := range successfulChannels { - confidenceMap[channel.name] = make(map[string]bool) - - modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any) - completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any) - - if hasModelRatio && hasCompletionRatio { - // 遍历所有模型,检查是否满足不可信条件 - for modelName := range allModels { - // 默认为可信 - confidenceMap[channel.name][modelName] = true - - // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1 - if modelRatioVal, ok := modelRatios[modelName]; ok { - if completionRatioVal, ok := completionRatios[modelName]; ok { - // 转换为float64进行比较 - if modelRatioFloat, ok := modelRatioVal.(float64); ok { - if completionRatioFloat, ok := completionRatioVal.(float64); ok { - if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 { - confidenceMap[channel.name][modelName] = false - } - } - } - } - } - } - } else { - // 如果不是从pricing接口获取的数据,则全部标记为可信 - for modelName := range allModels { - confidenceMap[channel.name][modelName] = true - } - } - } + for _, ratioType := range ratioTypes { + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + for modelName := range localRatio { + allModels[modelName] = struct{}{} + } + } + } + } - for modelName := range allModels { - for _, ratioType := range ratioTypes { - var localValue interface{} = nil - if localRatioAny, ok := localData[ratioType]; ok { - if localRatio, ok := localRatioAny.(map[string]float64); ok { - if val, exists := localRatio[modelName]; exists { - localValue = val - } - } - } + for _, channel := range successfulChannels { + for _, ratioType := range ratioTypes { + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + for modelName := range upstreamRatio { + allModels[modelName] = struct{}{} + } + } + } + } - upstreamValues := make(map[string]interface{}) - confidenceValues := make(map[string]bool) - hasUpstreamValue := false - hasDifference := false + confidenceMap := make(map[string]map[string]bool) - for _, channel := range successfulChannels { - var upstreamValue interface{} = nil - - if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { - if val, exists := upstreamRatio[modelName]; exists { - upstreamValue = val - hasUpstreamValue = true - - if localValue != nil && localValue != val { - hasDifference = true - } else if localValue == val { - upstreamValue = "same" - } - } - } - if upstreamValue == nil && localValue == nil { - upstreamValue = "same" - } - - if localValue == nil && upstreamValue != nil && upstreamValue != "same" { - hasDifference = true - } - - upstreamValues[channel.name] = upstreamValue - - confidenceValues[channel.name] = confidenceMap[channel.name][modelName] - } + // 预处理阶段:检查pricing接口的可信度 + for _, channel := range successfulChannels { + confidenceMap[channel.name] = make(map[string]bool) - shouldInclude := false - - if localValue != nil { - if hasDifference { - shouldInclude = true - } - } else { - if hasUpstreamValue { - shouldInclude = true - } - } + modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any) + completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any) - if shouldInclude { - if differences[modelName] == nil { - differences[modelName] = make(map[string]dto.DifferenceItem) - } - differences[modelName][ratioType] = dto.DifferenceItem{ - Current: localValue, - Upstreams: upstreamValues, - Confidence: confidenceValues, - } - } - } - } + if hasModelRatio && hasCompletionRatio { + // 遍历所有模型,检查是否满足不可信条件 + for modelName := range allModels { + // 默认为可信 + confidenceMap[channel.name][modelName] = true - channelHasDiff := make(map[string]bool) - for _, ratioMap := range differences { - for _, item := range ratioMap { - for chName, val := range item.Upstreams { - if val != nil && val != "same" { - channelHasDiff[chName] = true - } - } - } - } + // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1 + if modelRatioVal, ok := modelRatios[modelName]; ok { + if completionRatioVal, ok := completionRatios[modelName]; ok { + // 转换为float64进行比较 + if modelRatioFloat, ok := modelRatioVal.(float64); ok { + if completionRatioFloat, ok := completionRatioVal.(float64); ok { + if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 { + confidenceMap[channel.name][modelName] = false + } + } + } + } + } + } + } else { + // 如果不是从pricing接口获取的数据,则全部标记为可信 + for modelName := range allModels { + confidenceMap[channel.name][modelName] = true + } + } + } - for modelName, ratioMap := range differences { - for ratioType, item := range ratioMap { - for chName := range item.Upstreams { - if !channelHasDiff[chName] { - delete(item.Upstreams, chName) - delete(item.Confidence, chName) - } - } + for modelName := range allModels { + for _, ratioType := range ratioTypes { + var localValue interface{} = nil + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + if val, exists := localRatio[modelName]; exists { + localValue = val + } + } + } - allSame := true - for _, v := range item.Upstreams { - if v != "same" { - allSame = false - break - } - } - if len(item.Upstreams) == 0 || allSame { - delete(ratioMap, ratioType) - } else { - differences[modelName][ratioType] = item - } - } + upstreamValues := make(map[string]interface{}) + confidenceValues := make(map[string]bool) + hasUpstreamValue := false + hasDifference := false - if len(ratioMap) == 0 { - delete(differences, modelName) - } - } + for _, channel := range successfulChannels { + var upstreamValue interface{} = nil - return differences + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + if val, exists := upstreamRatio[modelName]; exists { + upstreamValue = val + hasUpstreamValue = true + + if localValue != nil && !valuesEqual(localValue, val) { + hasDifference = true + } else if valuesEqual(localValue, val) { + upstreamValue = "same" + } + } + } + if upstreamValue == nil && localValue == nil { + upstreamValue = "same" + } + + if localValue == nil && upstreamValue != nil && upstreamValue != "same" { + hasDifference = true + } + + upstreamValues[channel.name] = upstreamValue + + confidenceValues[channel.name] = confidenceMap[channel.name][modelName] + } + + shouldInclude := false + + if localValue != nil { + if hasDifference { + shouldInclude = true + } + } else { + if hasUpstreamValue { + shouldInclude = true + } + } + + if shouldInclude { + if differences[modelName] == nil { + differences[modelName] = make(map[string]dto.DifferenceItem) + } + differences[modelName][ratioType] = dto.DifferenceItem{ + Current: localValue, + Upstreams: upstreamValues, + Confidence: confidenceValues, + } + } + } + } + + channelHasDiff := make(map[string]bool) + for _, ratioMap := range differences { + for _, item := range ratioMap { + for chName, val := range item.Upstreams { + if val != nil && val != "same" { + channelHasDiff[chName] = true + } + } + } + } + + for modelName, ratioMap := range differences { + for ratioType, item := range ratioMap { + for chName := range item.Upstreams { + if !channelHasDiff[chName] { + delete(item.Upstreams, chName) + delete(item.Confidence, chName) + } + } + + allSame := true + for _, v := range item.Upstreams { + if v != "same" { + allSame = false + break + } + } + if len(item.Upstreams) == 0 || allSame { + delete(ratioMap, ratioType) + } else { + differences[modelName][ratioType] = item + } + } + + if len(ratioMap) == 0 { + delete(differences, modelName) + } + } + + return differences } func GetSyncableChannels(c *gin.Context) { - channels, err := model.GetAllChannels(0, 0, true, false) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } + channels, err := model.GetAllChannels(0, 0, true, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } - var syncableChannels []dto.SyncableChannel - for _, channel := range channels { - if channel.GetBaseURL() != "" { - syncableChannels = append(syncableChannels, dto.SyncableChannel{ - ID: channel.Id, - Name: channel.Name, - BaseURL: channel.GetBaseURL(), - Status: channel.Status, - }) - } - } + var syncableChannels []dto.SyncableChannel + for _, channel := range channels { + if channel.GetBaseURL() != "" { + syncableChannels = append(syncableChannels, dto.SyncableChannel{ + ID: channel.Id, + Name: channel.Name, + BaseURL: channel.GetBaseURL(), + Status: channel.Status, + }) + } + } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": syncableChannels, - }) -} \ No newline at end of file + syncableChannels = append(syncableChannels, dto.SyncableChannel{ + ID: -100, + Name: "官方倍率预设", + BaseURL: "https://basellm.github.io", + Status: 1, + }) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": syncableChannels, + }) +} diff --git a/controller/redemption.go b/controller/redemption.go index 83ec19ad6..1e305e3d8 100644 --- a/controller/redemption.go +++ b/controller/redemption.go @@ -6,6 +6,7 @@ import ( "one-api/common" "one-api/model" "strconv" + "unicode/utf8" "github.com/gin-gonic/gin" ) @@ -63,7 +64,7 @@ func AddRedemption(c *gin.Context) { common.ApiError(c, err) return } - if len(redemption.Name) == 0 || len(redemption.Name) > 20 { + if utf8.RuneCountInString(redemption.Name) == 0 || utf8.RuneCountInString(redemption.Name) > 20 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "兑换码名称长度必须在1-20之间", diff --git a/controller/relay.go b/controller/relay.go index b224b42c1..d3d93192e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -2,115 +2,193 @@ package controller import ( "bytes" - "errors" "fmt" "io" "log" "net/http" "one-api/common" "one-api/constant" - constant2 "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/middleware" "one-api/model" "one-api/relay" + relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" + "one-api/setting" "one-api/types" "strings" + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" "github.com/gorilla/websocket" ) -func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError { +func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError { var err *types.NewAPIError - switch relayMode { + switch info.RelayMode { case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits: - err = relay.ImageHelper(c) + err = relay.ImageHelper(c, info) case relayconstant.RelayModeAudioSpeech: fallthrough case relayconstant.RelayModeAudioTranslation: fallthrough case relayconstant.RelayModeAudioTranscription: - err = relay.AudioHelper(c) + err = relay.AudioHelper(c, info) case relayconstant.RelayModeRerank: - err = relay.RerankHelper(c, relayMode) + err = relay.RerankHelper(c, info) case relayconstant.RelayModeEmbeddings: - err = relay.EmbeddingHelper(c) + err = relay.EmbeddingHelper(c, info) case relayconstant.RelayModeResponses: - err = relay.ResponsesHelper(c) - case relayconstant.RelayModeGemini: - err = relay.GeminiHelper(c) + err = relay.ResponsesHelper(c, info) default: - err = relay.TextHelper(c) + err = relay.TextHelper(c, info) } - - if constant2.ErrorLogEnabled && err != nil { - // 保存错误日志到mysql中 - userId := c.GetInt("id") - tokenName := c.GetString("token_name") - modelName := c.GetString("original_model") - tokenId := c.GetInt("token_id") - userGroup := c.GetString("group") - channelId := c.GetInt("channel_id") - other := make(map[string]interface{}) - other["error_type"] = err.ErrorType - other["error_code"] = err.GetErrorCode() - other["status_code"] = err.StatusCode - other["channel_id"] = channelId - other["channel_name"] = c.GetString("channel_name") - other["channel_type"] = c.GetInt("channel_type") - - model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other) - } - return err } -func Relay(c *gin.Context) { - relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) +func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError { + var err *types.NewAPIError + if strings.Contains(c.Request.URL.Path, "embed") { + err = relay.GeminiEmbeddingHandler(c, info) + } else { + err = relay.GeminiHelper(c, info) + } + return err +} + +func Relay(c *gin.Context, relayFormat types.RelayFormat) { + requestId := c.GetString(common.RequestIdKey) - group := c.GetString("group") - originalModel := c.GetString("original_model") - var newAPIError *types.NewAPIError + group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup) + originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel) + + var ( + newAPIError *types.NewAPIError + ws *websocket.Conn + ) + + if relayFormat == types.RelayFormatOpenAIRealtime { + var err error + ws, err = upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError()) + return + } + defer ws.Close() + } + + defer func() { + if newAPIError != nil { + newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) + switch relayFormat { + case types.RelayFormatOpenAIRealtime: + helper.WssError(c, ws, newAPIError.ToOpenAIError()) + case types.RelayFormatClaude: + c.JSON(newAPIError.StatusCode, gin.H{ + "type": "error", + "error": newAPIError.ToClaudeError(), + }) + default: + c.JSON(newAPIError.StatusCode, gin.H{ + "error": newAPIError.ToOpenAIError(), + }) + } + } + }() + + request, err := helper.GetAndValidateRequest(c, relayFormat) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest) + return + } + + relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed) + return + } + + meta := request.GetTokenCountMeta() + + if setting.ShouldCheckPromptSensitive() { + contains, words := service.CheckSensitiveText(meta.CombineText) + if contains { + logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) + newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected) + return + } + } + + tokens, err := service.CountRequestToken(c, meta, relayInfo) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed) + return + } + + relayInfo.SetPromptTokens(tokens) + + priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta) + if err != nil { + newAPIError = types.NewError(err, types.ErrorCodeModelPriceError) + return + } + + // common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta) + + preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if newAPIError != nil { + return + } + + defer func() { + // Only return quota if downstream failed and quota was actually pre-consumed + if newAPIError != nil && preConsumedQuota != 0 { + service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota) + } + }() for i := 0; i <= common.RetryTimes; i++ { channel, err := getChannel(c, group, originalModel, i) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) newAPIError = err break } - newAPIError = relayRequest(c, relayMode, channel) + addUsedChannel(c, channel.Id) + requestBody, _ := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - if newAPIError == nil { - return // 成功处理请求,直接返回 + switch relayFormat { + case types.RelayFormatOpenAIRealtime: + newAPIError = relay.WssHelper(c, relayInfo) + case types.RelayFormatClaude: + newAPIError = relay.ClaudeHelper(c, relayInfo) + case types.RelayFormatGemini: + newAPIError = geminiRelayHandler(c, relayInfo) + default: + newAPIError = relayHandler(c, relayInfo) } - go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) + if newAPIError == nil { + return + } + + processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) if !shouldRetry(c, newAPIError, common.RetryTimes-i) { break } } + useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) - } - - if newAPIError != nil { - //if newAPIError.StatusCode == http.StatusTooManyRequests { - // common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error())) - // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") - //} - newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) - c.JSON(newAPIError.StatusCode, gin.H{ - "error": newAPIError.ToOpenAIError(), - }) + logger.LogInfo(c, retryLogStr) } } @@ -121,122 +199,6 @@ var upgrader = websocket.Upgrader{ }, } -func WssRelay(c *gin.Context) { - // 将 HTTP 连接升级为 WebSocket 连接 - - ws, err := upgrader.Upgrade(c.Writer, c.Request, nil) - defer ws.Close() - - if err != nil { - helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError()) - return - } - - relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) - requestId := c.GetString(common.RequestIdKey) - group := c.GetString("group") - //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01 - originalModel := c.GetString("original_model") - var newAPIError *types.NewAPIError - - for i := 0; i <= common.RetryTimes; i++ { - channel, err := getChannel(c, group, originalModel, i) - if err != nil { - common.LogError(c, err.Error()) - newAPIError = err - break - } - - newAPIError = wssRequest(c, ws, relayMode, channel) - - if newAPIError == nil { - return // 成功处理请求,直接返回 - } - - go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) - - if !shouldRetry(c, newAPIError, common.RetryTimes-i) { - break - } - } - useChannel := c.GetStringSlice("use_channel") - if len(useChannel) > 1 { - retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) - } - - if newAPIError != nil { - //if newAPIError.StatusCode == http.StatusTooManyRequests { - // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试") - //} - newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) - helper.WssError(c, ws, newAPIError.ToOpenAIError()) - } -} - -func RelayClaude(c *gin.Context) { - //relayMode := constant.Path2RelayMode(c.Request.URL.Path) - requestId := c.GetString(common.RequestIdKey) - group := c.GetString("group") - originalModel := c.GetString("original_model") - var newAPIError *types.NewAPIError - - for i := 0; i <= common.RetryTimes; i++ { - channel, err := getChannel(c, group, originalModel, i) - if err != nil { - common.LogError(c, err.Error()) - newAPIError = err - break - } - - newAPIError = claudeRequest(c, channel) - - if newAPIError == nil { - return // 成功处理请求,直接返回 - } - - go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) - - if !shouldRetry(c, newAPIError, common.RetryTimes-i) { - break - } - } - useChannel := c.GetStringSlice("use_channel") - if len(useChannel) > 1 { - retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) - } - - if newAPIError != nil { - newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) - c.JSON(newAPIError.StatusCode, gin.H{ - "type": "error", - "error": newAPIError.ToClaudeError(), - }) - } -} - -func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError { - addUsedChannel(c, channel.Id) - requestBody, _ := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - return relayHandler(c, relayMode) -} - -func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError { - addUsedChannel(c, channel.Id) - requestBody, _ := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - return relay.WssHelper(c, ws) -} - -func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError { - addUsedChannel(c, channel.Id) - requestBody, _ := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - return relay.ClaudeHelper(c) -} - func addUsedChannel(c *gin.Context, channelId int) { useChannel := c.GetStringSlice("use_channel") useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) @@ -259,10 +221,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m } channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) if err != nil { - if group == "auto" { - return nil, types.NewError(errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())), types.ErrorCodeGetChannelFailed) - } - return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed) + return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) + } + if channel == nil { + return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel) if newAPIError != nil { @@ -278,7 +240,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b if types.IsChannelError(openaiErr) { return true } - if types.IsLocalError(openaiErr) { + if types.IsSkipRetryError(openaiErr) { return false } if retryTimes <= 0 { @@ -301,10 +263,6 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b return true } if openaiErr.StatusCode == http.StatusBadRequest { - channelType := c.GetInt("channel_type") - if channelType == constant.ChannelTypeAnthropic { - return true - } return false } if openaiErr.StatusCode == 408 { @@ -318,44 +276,84 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b } func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) { - // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 - // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously - common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) - if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan { - service.DisableChannel(channelError, err.Error()) + logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) + + gopool.Go(func() { + // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 + // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously + if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan { + service.DisableChannel(channelError, err.Error()) + } + }) + + if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) { + // 保存错误日志到mysql中 + userId := c.GetInt("id") + tokenName := c.GetString("token_name") + modelName := c.GetString("original_model") + tokenId := c.GetInt("token_id") + userGroup := c.GetString("group") + channelId := c.GetInt("channel_id") + other := make(map[string]interface{}) + other["error_type"] = err.GetErrorType() + other["error_code"] = err.GetErrorCode() + other["status_code"] = err.StatusCode + other["channel_id"] = channelId + other["channel_name"] = c.GetString("channel_name") + other["channel_type"] = c.GetInt("channel_type") + adminInfo := make(map[string]interface{}) + adminInfo["use_channel"] = c.GetStringSlice("use_channel") + isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey) + if isMultiKey { + adminInfo["is_multi_key"] = true + adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex) + } + other["admin_info"] = adminInfo + model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other) } + } func RelayMidjourney(c *gin.Context) { - relayMode := c.GetInt("relay_mode") - var err *dto.MidjourneyResponse - switch relayMode { + relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil) + + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "description": fmt.Sprintf("failed to generate relay info: %s", err.Error()), + "type": "upstream_error", + "code": 4, + }) + return + } + + var mjErr *dto.MidjourneyResponse + switch relayInfo.RelayMode { case relayconstant.RelayModeMidjourneyNotify: - err = relay.RelayMidjourneyNotify(c) + mjErr = relay.RelayMidjourneyNotify(c) case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition: - err = relay.RelayMidjourneyTask(c, relayMode) + mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode) case relayconstant.RelayModeMidjourneyTaskImageSeed: - err = relay.RelayMidjourneyTaskImageSeed(c) + mjErr = relay.RelayMidjourneyTaskImageSeed(c) case relayconstant.RelayModeSwapFace: - err = relay.RelaySwapFace(c) + mjErr = relay.RelaySwapFace(c, relayInfo) default: - err = relay.RelayMidjourneySubmit(c, relayMode) + mjErr = relay.RelayMidjourneySubmit(c, relayInfo) } //err = relayMidjourneySubmit(c, relayMode) - log.Println(err) - if err != nil { + log.Println(mjErr) + if mjErr != nil { statusCode := http.StatusBadRequest - if err.Code == 30 { - err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" + if mjErr.Code == 30 { + mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" statusCode = http.StatusTooManyRequests } c.JSON(statusCode, gin.H{ - "description": fmt.Sprintf("%s %s", err.Description, err.Result), + "description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result), "type": "upstream_error", - "code": err.Code, + "code": mjErr.Code, }) channelId := c.GetInt("channel_id") - common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result))) + logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result))) } } @@ -386,18 +384,21 @@ func RelayNotFound(c *gin.Context) { func RelayTask(c *gin.Context) { retryTimes := common.RetryTimes channelId := c.GetInt("channel_id") - relayMode := c.GetInt("relay_mode") group := c.GetString("group") originalModel := c.GetString("original_model") c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) - taskErr := taskRelayHandler(c, relayMode) + relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) + if err != nil { + return + } + taskErr := taskRelayHandler(c, relayInfo) if taskErr == nil { retryTimes = 0 } for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { channel, newAPIError := getChannel(c, group, originalModel, i) if newAPIError != nil { - common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) + logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError) break } @@ -405,17 +406,17 @@ func RelayTask(c *gin.Context) { useChannel := c.GetStringSlice("use_channel") useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) c.Set("use_channel", useChannel) - common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) + logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) //middleware.SetupContextForSelectedChannel(c, channel, originalModel) requestBody, _ := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - taskErr = taskRelayHandler(c, relayMode) + taskErr = taskRelayHandler(c, relayInfo) } useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c, retryLogStr) + logger.LogInfo(c, retryLogStr) } if taskErr != nil { if taskErr.StatusCode == http.StatusTooManyRequests { @@ -425,13 +426,13 @@ func RelayTask(c *gin.Context) { } } -func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError { +func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError { var err *dto.TaskError - switch relayMode { - case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID: - err = relay.RelayTaskFetch(c, relayMode) + switch relayInfo.RelayMode { + case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID: + err = relay.RelayTaskFetch(c, relayInfo.RelayMode) default: - err = relay.RelayTaskSubmit(c, relayMode) + err = relay.RelayTaskSubmit(c, relayInfo) } return err } diff --git a/controller/swag_video.go b/controller/swag_video.go index 185fd5159..68dd6345f 100644 --- a/controller/swag_video.go +++ b/controller/swag_video.go @@ -114,3 +114,23 @@ type KlingImage2VideoRequest struct { CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"` ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-002"` } + +// KlingImage2videoTaskId godoc +// @Summary 可灵任务查询--图生视频 +// @Description Query the status and result of a Kling video generation task by task ID +// @Tags Origin +// @Accept json +// @Produce json +// @Param task_id path string true "Task ID" +// @Router /kling/v1/videos/image2video/{task_id} [get] +func KlingImage2videoTaskId(c *gin.Context) {} + +// KlingText2videoTaskId godoc +// @Summary 可灵任务查询--文生视频 +// @Description Query the status and result of a Kling text-to-video generation task by task ID +// @Tags Origin +// @Accept json +// @Produce json +// @Param task_id path string true "Task ID" +// @Router /kling/v1/videos/text2video/{task_id} [get] +func KlingText2videoTaskId(c *gin.Context) {} diff --git a/controller/task.go b/controller/task.go index 78674d8b6..1082d7a11 100644 --- a/controller/task.go +++ b/controller/task.go @@ -10,6 +10,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/relay" "sort" @@ -54,9 +55,9 @@ func UpdateTaskBulk() { "progress": "100%", }) if err != nil { - common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) } else { - common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) + logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) } } if len(taskChannelM) == 0 { @@ -75,10 +76,10 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][ //_ = UpdateMidjourneyTaskAll(context.Background(), tasks) case constant.TaskPlatformSuno: _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) - case constant.TaskPlatformKling, constant.TaskPlatformJimeng: - _ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM) default: - common.SysLog("未知平台") + if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil { + common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err)) + } } } @@ -86,14 +87,14 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM for channelId, taskIds := range taskChannelM { err := updateSunoTaskAll(ctx, channelId, taskIds, taskM) if err != nil { - common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error())) + logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error())) } } return nil } func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { - common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) if len(taskIds) == 0 { return nil } @@ -106,7 +107,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas "progress": "100%", }) if err != nil { - common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) + common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) } return err } @@ -118,23 +119,23 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas "ids": taskIds, }) if err != nil { - common.SysError(fmt.Sprintf("Get Task Do req error: %v", err)) + common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) return err } if resp.StatusCode != http.StatusOK { - common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) } defer resp.Body.Close() responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.SysError(fmt.Sprintf("Get Task parse body error: %v", err)) + common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err)) return err } var responseItems dto.TaskResponse[[]dto.SunoDataResponse] err = json.Unmarshal(responseBody, &responseItems) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) return err } if !responseItems.IsSuccess() { @@ -154,19 +155,19 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { - common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) + logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) task.Progress = "100%" //err = model.CacheUpdateUserQuota(task.UserId) ? if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) + logger.LogError(ctx, "error update user quota cache: "+err.Error()) } else { quota := task.Quota if quota != 0 { err = model.IncreaseUserQuota(task.UserId, quota, false) if err != nil { - common.LogError(ctx, "fail to increase user quota: "+err.Error()) + logger.LogError(ctx, "fail to increase user quota: "+err.Error()) } - logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota)) + logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } } @@ -178,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas err = task.Update() if err != nil { - common.SysError("UpdateMidjourneyTask task error: " + err.Error()) + common.SysLog("UpdateMidjourneyTask task error: " + err.Error()) } } return nil diff --git a/controller/task_video.go b/controller/task_video.go index b62978a75..84b78f901 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -2,27 +2,31 @@ package controller import ( "context" + "encoding/json" "fmt" "io" "one-api/common" "one-api/constant" + "one-api/dto" + "one-api/logger" "one-api/model" "one-api/relay" "one-api/relay/channel" + relaycommon "one-api/relay/common" "time" ) func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { for channelId, taskIds := range taskChannelM { if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil { - common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) + logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) } } return nil } func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { - common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) + logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) if len(taskIds) == 0 { return nil } @@ -34,7 +38,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha "progress": "100%", }) if errUpdate != nil { - common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) + common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) } return fmt.Errorf("CacheGetChannel failed: %w", err) } @@ -44,7 +48,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha } for _, taskId := range taskIds { if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { - common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) + logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) } } return nil @@ -58,7 +62,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task := taskM[taskId] if task == nil { - common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) + logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) return fmt.Errorf("task %s not found", taskId) } resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{ @@ -77,13 +81,21 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha return fmt.Errorf("readAll failed for task %s: %w", taskId, err) } - taskResult, err := adaptor.ParseTaskResult(responseBody) - if err != nil { + taskResult := &relaycommon.TaskInfo{} + // try parse as New API response format + var responseItems dto.TaskResponse[model.Task] + if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() { + t := responseItems.Data + taskResult.TaskID = t.TaskID + taskResult.Status = string(t.Status) + taskResult.Url = t.FailReason + taskResult.Progress = t.Progress + taskResult.Reason = t.FailReason + } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) + } else { + task.Data = responseBody } - //if taskResult.Code != 0 { - // return fmt.Errorf("video task fetch failed for task %s", taskId) - //} now := time.Now().Unix() if taskResult.Status == "" { @@ -113,13 +125,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task.FinishTime = now } task.FailReason = taskResult.Reason - common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) + logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) quota := task.Quota if quota != 0 { if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { - common.LogError(ctx, "Failed to increase user quota: "+err.Error()) + logger.LogError(ctx, "Failed to increase user quota: "+err.Error()) } - logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota)) + logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } default: @@ -128,10 +140,8 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha if taskResult.Progress != "" { task.Progress = taskResult.Progress } - - task.Data = responseBody if err := task.Update(); err != nil { - common.SysError("UpdateVideoTask task error: " + err.Error()) + common.SysLog("UpdateVideoTask task error: " + err.Error()) } return nil diff --git a/controller/token.go b/controller/token.go index 62eb5474e..8ed8b9570 100644 --- a/controller/token.go +++ b/controller/token.go @@ -5,6 +5,7 @@ import ( "one-api/common" "one-api/model" "strconv" + "strings" "github.com/gin-gonic/gin" ) @@ -82,6 +83,57 @@ func GetTokenStatus(c *gin.Context) { }) } +func GetTokenUsage(c *gin.Context) { + authHeader := c.GetHeader("Authorization") + if authHeader == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "success": false, + "message": "No Authorization header", + }) + return + } + + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { + c.JSON(http.StatusUnauthorized, gin.H{ + "success": false, + "message": "Invalid Bearer token", + }) + return + } + tokenKey := parts[1] + + token, err := model.GetTokenByKey(strings.TrimPrefix(tokenKey, "sk-"), false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + expiredAt := token.ExpiredTime + if expiredAt == -1 { + expiredAt = 0 + } + + c.JSON(http.StatusOK, gin.H{ + "code": true, + "message": "ok", + "data": gin.H{ + "object": "token_usage", + "name": token.Name, + "total_granted": token.RemainQuota + token.UsedQuota, + "total_used": token.UsedQuota, + "total_available": token.RemainQuota, + "unlimited_quota": token.UnlimitedQuota, + "model_limits": token.GetModelLimitsMap(), + "model_limits_enabled": token.ModelLimitsEnabled, + "expires_at": expiredAt, + }, + }) +} + func AddToken(c *gin.Context) { token := model.Token{} err := c.ShouldBindJSON(&token) @@ -102,7 +154,7 @@ func AddToken(c *gin.Context) { "success": false, "message": "生成令牌失败", }) - common.SysError("failed to generate token key: " + err.Error()) + common.SysLog("failed to generate token key: " + err.Error()) return } cleanToken := model.Token{ diff --git a/controller/topup.go b/controller/topup.go index 827dda393..3f3c86231 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -5,6 +5,7 @@ import ( "log" "net/url" "one-api/common" + "one-api/logger" "one-api/model" "one-api/service" "one-api/setting" @@ -231,7 +232,7 @@ func EpayNotify(c *gin.Context) { return } log.Printf("易支付回调更新用户成功 %v", topUp) - model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money)) + model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money)) } } else { log.Printf("易支付异常回调: %v", verifyInfo) diff --git a/controller/twofa.go b/controller/twofa.go new file mode 100644 index 000000000..1859a1284 --- /dev/null +++ b/controller/twofa.go @@ -0,0 +1,553 @@ +package controller + +import ( + "errors" + "fmt" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" +) + +// Setup2FARequest 设置2FA请求结构 +type Setup2FARequest struct { + Code string `json:"code" binding:"required"` +} + +// Verify2FARequest 验证2FA请求结构 +type Verify2FARequest struct { + Code string `json:"code" binding:"required"` +} + +// Setup2FAResponse 设置2FA响应结构 +type Setup2FAResponse struct { + Secret string `json:"secret"` + QRCodeData string `json:"qr_code_data"` + BackupCodes []string `json:"backup_codes"` +} + +// Setup2FA 初始化2FA设置 +func Setup2FA(c *gin.Context) { + userId := c.GetInt("id") + + // 检查用户是否已经启用2FA + existing, err := model.GetTwoFAByUserId(userId) + if err != nil { + common.ApiError(c, err) + return + } + if existing != nil && existing.IsEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户已启用2FA,请先禁用后重新设置", + }) + return + } + + // 如果存在已禁用的2FA记录,先删除它 + if existing != nil && !existing.IsEnabled { + if err := existing.Delete(); err != nil { + common.ApiError(c, err) + return + } + existing = nil // 重置为nil,后续将创建新记录 + } + + // 获取用户信息 + user, err := model.GetUserById(userId, false) + if err != nil { + common.ApiError(c, err) + return + } + + // 生成TOTP密钥 + key, err := common.GenerateTOTPSecret(user.Username) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "生成2FA密钥失败", + }) + common.SysLog("生成TOTP密钥失败: " + err.Error()) + return + } + + // 生成备用码 + backupCodes, err := common.GenerateBackupCodes() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "生成备用码失败", + }) + common.SysLog("生成备用码失败: " + err.Error()) + return + } + + // 生成二维码数据 + qrCodeData := common.GenerateQRCodeData(key.Secret(), user.Username) + + // 创建或更新2FA记录(暂未启用) + twoFA := &model.TwoFA{ + UserId: userId, + Secret: key.Secret(), + IsEnabled: false, + } + + if existing != nil { + // 更新现有记录 + twoFA.Id = existing.Id + err = twoFA.Update() + } else { + // 创建新记录 + err = twoFA.Create() + } + + if err != nil { + common.ApiError(c, err) + return + } + + // 创建备用码记录 + if err := model.CreateBackupCodes(userId, backupCodes); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "保存备用码失败", + }) + common.SysLog("保存备用码失败: " + err.Error()) + return + } + + // 记录操作日志 + model.RecordLog(userId, model.LogTypeSystem, "开始设置两步验证") + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "2FA设置初始化成功,请使用认证器扫描二维码并输入验证码完成设置", + "data": Setup2FAResponse{ + Secret: key.Secret(), + QRCodeData: qrCodeData, + BackupCodes: backupCodes, + }, + }) +} + +// Enable2FA 启用2FA +func Enable2FA(c *gin.Context) { + var req Setup2FARequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "参数错误", + }) + return + } + + userId := c.GetInt("id") + + // 获取2FA记录 + twoFA, err := model.GetTwoFAByUserId(userId) + if err != nil { + common.ApiError(c, err) + return + } + if twoFA == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "请先完成2FA初始化设置", + }) + return + } + if twoFA.IsEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "2FA已经启用", + }) + return + } + + // 验证TOTP验证码 + cleanCode, err := common.ValidateNumericCode(req.Code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + if !common.ValidateTOTPCode(twoFA.Secret, cleanCode) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "验证码或备用码错误,请重试", + }) + return + } + + // 启用2FA + if err := twoFA.Enable(); err != nil { + common.ApiError(c, err) + return + } + + // 记录操作日志 + model.RecordLog(userId, model.LogTypeSystem, "成功启用两步验证") + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "两步验证启用成功", + }) +} + +// Disable2FA 禁用2FA +func Disable2FA(c *gin.Context) { + var req Verify2FARequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "参数错误", + }) + return + } + + userId := c.GetInt("id") + + // 获取2FA记录 + twoFA, err := model.GetTwoFAByUserId(userId) + if err != nil { + common.ApiError(c, err) + return + } + if twoFA == nil || !twoFA.IsEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户未启用2FA", + }) + return + } + + // 验证TOTP验证码或备用码 + cleanCode, err := common.ValidateNumericCode(req.Code) + isValidTOTP := false + isValidBackup := false + + if err == nil { + // 尝试验证TOTP + isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode) + } + + if !isValidTOTP { + // 尝试验证备用码 + isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } + + if !isValidTOTP && !isValidBackup { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "验证码或备用码错误,请重试", + }) + return + } + + // 禁用2FA + if err := model.DisableTwoFA(userId); err != nil { + common.ApiError(c, err) + return + } + + // 记录操作日志 + model.RecordLog(userId, model.LogTypeSystem, "禁用两步验证") + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "两步验证已禁用", + }) +} + +// Get2FAStatus 获取用户2FA状态 +func Get2FAStatus(c *gin.Context) { + userId := c.GetInt("id") + + twoFA, err := model.GetTwoFAByUserId(userId) + if err != nil { + common.ApiError(c, err) + return + } + + status := map[string]interface{}{ + "enabled": false, + "locked": false, + } + + if twoFA != nil { + status["enabled"] = twoFA.IsEnabled + status["locked"] = twoFA.IsLocked() + if twoFA.IsEnabled { + // 获取剩余备用码数量 + backupCount, err := model.GetUnusedBackupCodeCount(userId) + if err != nil { + common.SysLog("获取备用码数量失败: " + err.Error()) + } else { + status["backup_codes_remaining"] = backupCount + } + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": status, + }) +} + +// RegenerateBackupCodes 重新生成备用码 +func RegenerateBackupCodes(c *gin.Context) { + var req Verify2FARequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "参数错误", + }) + return + } + + userId := c.GetInt("id") + + // 获取2FA记录 + twoFA, err := model.GetTwoFAByUserId(userId) + if err != nil { + common.ApiError(c, err) + return + } + if twoFA == nil || !twoFA.IsEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户未启用2FA", + }) + return + } + + // 验证TOTP验证码 + cleanCode, err := common.ValidateNumericCode(req.Code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + valid, err := twoFA.ValidateTOTPAndUpdateUsage(cleanCode) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if !valid { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "验证码或备用码错误,请重试", + }) + return + } + + // 生成新的备用码 + backupCodes, err := common.GenerateBackupCodes() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "生成备用码失败", + }) + common.SysLog("生成备用码失败: " + err.Error()) + return + } + + // 保存新的备用码 + if err := model.CreateBackupCodes(userId, backupCodes); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "保存备用码失败", + }) + common.SysLog("保存备用码失败: " + err.Error()) + return + } + + // 记录操作日志 + model.RecordLog(userId, model.LogTypeSystem, "重新生成两步验证备用码") + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "备用码重新生成成功", + "data": map[string]interface{}{ + "backup_codes": backupCodes, + }, + }) +} + +// Verify2FALogin 登录时验证2FA +func Verify2FALogin(c *gin.Context) { + var req Verify2FARequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "参数错误", + }) + return + } + + // 从会话中获取pending用户信息 + session := sessions.Default(c) + pendingUserId := session.Get("pending_user_id") + if pendingUserId == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "会话已过期,请重新登录", + }) + return + } + userId, ok := pendingUserId.(int) + if !ok { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "会话数据无效,请重新登录", + }) + return + } + // 获取用户信息 + user, err := model.GetUserById(userId, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户不存在", + }) + return + } + + // 获取2FA记录 + twoFA, err := model.GetTwoFAByUserId(user.Id) + if err != nil { + common.ApiError(c, err) + return + } + if twoFA == nil || !twoFA.IsEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户未启用2FA", + }) + return + } + + // 验证TOTP验证码或备用码 + cleanCode, err := common.ValidateNumericCode(req.Code) + isValidTOTP := false + isValidBackup := false + + if err == nil { + // 尝试验证TOTP + isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode) + } + + if !isValidTOTP { + // 尝试验证备用码 + isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } + + if !isValidTOTP && !isValidBackup { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "验证码或备用码错误,请重试", + }) + return + } + + // 2FA验证成功,清理pending会话信息并完成登录 + session.Delete("pending_username") + session.Delete("pending_user_id") + session.Save() + + setupLogin(user, c) +} + +// Admin2FAStats 管理员获取2FA统计信息 +func Admin2FAStats(c *gin.Context) { + stats, err := model.GetTwoFAStats() + if err != nil { + common.ApiError(c, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": stats, + }) +} + +// AdminDisable2FA 管理员强制禁用用户2FA +func AdminDisable2FA(c *gin.Context) { + userIdStr := c.Param("id") + userId, err := strconv.Atoi(userIdStr) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户ID格式错误", + }) + return + } + + // 检查目标用户权限 + targetUser, err := model.GetUserById(userId, false) + if err != nil { + common.ApiError(c, err) + return + } + + myRole := c.GetInt("role") + if myRole <= targetUser.Role && myRole != common.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权操作同级或更高级用户的2FA设置", + }) + return + } + + // 禁用2FA + if err := model.DisableTwoFA(userId); err != nil { + if errors.Is(err, model.ErrTwoFANotEnabled) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户未启用2FA", + }) + return + } + common.ApiError(c, err) + return + } + + // 记录操作日志 + adminId := c.GetInt("id") + model.RecordLog(userId, model.LogTypeManage, + fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId)) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "用户2FA已被强制禁用", + }) +} diff --git a/controller/uptime_kuma.go b/controller/uptime_kuma.go index 05d6297eb..41b9695c3 100644 --- a/controller/uptime_kuma.go +++ b/controller/uptime_kuma.go @@ -31,7 +31,7 @@ type Monitor struct { type UptimeGroupResult struct { CategoryName string `json:"categoryName"` - Monitors []Monitor `json:"monitors"` + Monitors []Monitor `json:"monitors"` } func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error { @@ -57,29 +57,29 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st url, _ := groupConfig["url"].(string) slug, _ := groupConfig["slug"].(string) categoryName, _ := groupConfig["categoryName"].(string) - + result := UptimeGroupResult{ CategoryName: categoryName, - Monitors: []Monitor{}, + Monitors: []Monitor{}, } - + if url == "" || slug == "" { return result } baseURL := strings.TrimSuffix(url, "/") - + var statusData struct { PublicGroupList []struct { - ID int `json:"id"` - Name string `json:"name"` + ID int `json:"id"` + Name string `json:"name"` MonitorList []struct { ID int `json:"id"` Name string `json:"name"` } `json:"monitorList"` } `json:"publicGroupList"` } - + var heartbeatData struct { HeartbeatList map[string][]struct { Status int `json:"status"` @@ -88,11 +88,11 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st } g, gCtx := errgroup.WithContext(ctx) - g.Go(func() error { - return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData) + g.Go(func() error { + return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData) }) - g.Go(func() error { - return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData) + g.Go(func() error { + return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData) }) if g.Wait() != nil { @@ -139,7 +139,7 @@ func GetUptimeKumaStatus(c *gin.Context) { client := &http.Client{Timeout: httpTimeout} results := make([]UptimeGroupResult, len(groups)) - + g, gCtx := errgroup.WithContext(ctx) for i, group := range groups { i, group := i, group @@ -148,7 +148,7 @@ func GetUptimeKumaStatus(c *gin.Context) { return nil }) } - + g.Wait() c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results}) -} \ No newline at end of file +} diff --git a/controller/user.go b/controller/user.go index 292ed8c6e..982329cec 100644 --- a/controller/user.go +++ b/controller/user.go @@ -7,6 +7,7 @@ import ( "net/url" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/model" "one-api/setting" "strconv" @@ -62,6 +63,32 @@ func Login(c *gin.Context) { }) return } + + // 检查是否启用2FA + if model.IsTwoFAEnabled(user.Id) { + // 设置pending session,等待2FA验证 + session := sessions.Default(c) + session.Set("pending_username", user.Username) + session.Set("pending_user_id", user.Id) + err := session.Save() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": "无法保存会话信息,请重试", + "success": false, + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "message": "请输入两步验证码", + "success": true, + "data": map[string]interface{}{ + "require_2fa": true, + }, + }) + return + } + setupLogin(&user, c) } @@ -166,7 +193,7 @@ func Register(c *gin.Context) { "success": false, "message": "数据库错误,请稍后重试", }) - common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err)) + common.SysLog(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err)) return } if exist { @@ -183,6 +210,7 @@ func Register(c *gin.Context) { Password: user.Password, DisplayName: user.Username, InviterId: inviterId, + Role: common.RoleCommonUser, // 明确设置角色为普通用户 } if common.EmailVerificationEnabled { cleanUser.Email = user.Email @@ -209,7 +237,7 @@ func Register(c *gin.Context) { "success": false, "message": "生成默认令牌失败", }) - common.SysError("failed to generate token key: " + err.Error()) + common.SysLog("failed to generate token key: " + err.Error()) return } // 生成默认令牌 @@ -316,7 +344,7 @@ func GenerateAccessToken(c *gin.Context) { "success": false, "message": "生成失败", }) - common.SysError("failed to generate key: " + err.Error()) + common.SysLog("failed to generate key: " + err.Error()) return } user.SetAccessToken(key) @@ -399,6 +427,7 @@ func GetAffCode(c *gin.Context) { func GetSelf(c *gin.Context) { id := c.GetInt("id") + userRole := c.GetInt("role") user, err := model.GetUserById(id, false) if err != nil { common.ApiError(c, err) @@ -407,14 +436,134 @@ func GetSelf(c *gin.Context) { // Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users user.Remark = "" + // 计算用户权限信息 + permissions := calculateUserPermissions(userRole) + + // 获取用户设置并提取sidebar_modules + userSetting := user.GetSetting() + + // 构建响应数据,包含用户信息和权限 + responseData := map[string]interface{}{ + "id": user.Id, + "username": user.Username, + "display_name": user.DisplayName, + "role": user.Role, + "status": user.Status, + "email": user.Email, + "group": user.Group, + "quota": user.Quota, + "used_quota": user.UsedQuota, + "request_count": user.RequestCount, + "aff_code": user.AffCode, + "aff_count": user.AffCount, + "aff_quota": user.AffQuota, + "aff_history_quota": user.AffHistoryQuota, + "inviter_id": user.InviterId, + "linux_do_id": user.LinuxDOId, + "setting": user.Setting, + "stripe_customer": user.StripeCustomer, + "sidebar_modules": userSetting.SidebarModules, // 正确提取sidebar_modules字段 + "permissions": permissions, // 新增权限字段 + } + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": user, + "data": responseData, }) return } +// 计算用户权限的辅助函数 +func calculateUserPermissions(userRole int) map[string]interface{} { + permissions := map[string]interface{}{} + + // 根据用户角色计算权限 + if userRole == common.RoleRootUser { + // 超级管理员不需要边栏设置功能 + permissions["sidebar_settings"] = false + permissions["sidebar_modules"] = map[string]interface{}{} + } else if userRole == common.RoleAdminUser { + // 管理员可以设置边栏,但不包含系统设置功能 + permissions["sidebar_settings"] = true + permissions["sidebar_modules"] = map[string]interface{}{ + "admin": map[string]interface{}{ + "setting": false, // 管理员不能访问系统设置 + }, + } + } else { + // 普通用户只能设置个人功能,不包含管理员区域 + permissions["sidebar_settings"] = true + permissions["sidebar_modules"] = map[string]interface{}{ + "admin": false, // 普通用户不能访问管理员区域 + } + } + + return permissions +} + +// 根据用户角色生成默认的边栏配置 +func generateDefaultSidebarConfig(userRole int) string { + defaultConfig := map[string]interface{}{} + + // 聊天区域 - 所有用户都可以访问 + defaultConfig["chat"] = map[string]interface{}{ + "enabled": true, + "playground": true, + "chat": true, + } + + // 控制台区域 - 所有用户都可以访问 + defaultConfig["console"] = map[string]interface{}{ + "enabled": true, + "detail": true, + "token": true, + "log": true, + "midjourney": true, + "task": true, + } + + // 个人中心区域 - 所有用户都可以访问 + defaultConfig["personal"] = map[string]interface{}{ + "enabled": true, + "topup": true, + "personal": true, + } + + // 管理员区域 - 根据角色决定 + if userRole == common.RoleAdminUser { + // 管理员可以访问管理员区域,但不能访问系统设置 + defaultConfig["admin"] = map[string]interface{}{ + "enabled": true, + "channel": true, + "models": true, + "redemption": true, + "user": true, + "setting": false, // 管理员不能访问系统设置 + } + } else if userRole == common.RoleRootUser { + // 超级管理员可以访问所有功能 + defaultConfig["admin"] = map[string]interface{}{ + "enabled": true, + "channel": true, + "models": true, + "redemption": true, + "user": true, + "setting": true, + } + } + // 普通用户不包含admin区域 + + // 转换为JSON字符串 + configBytes, err := json.Marshal(defaultConfig) + if err != nil { + common.SysLog("生成默认边栏配置失败: " + err.Error()) + return "" + } + + return string(configBytes) +} + func GetUserModels(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { @@ -491,7 +640,7 @@ func UpdateUser(c *gin.Context) { return } if originUser.Quota != updatedUser.Quota { - model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) + model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota))) } c.JSON(http.StatusOK, gin.H{ "success": true, @@ -501,8 +650,8 @@ func UpdateUser(c *gin.Context) { } func UpdateSelf(c *gin.Context) { - var user model.User - err := json.NewDecoder(c.Request.Body).Decode(&user) + var requestData map[string]interface{} + err := json.NewDecoder(c.Request.Body).Decode(&requestData) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -510,6 +659,60 @@ func UpdateSelf(c *gin.Context) { }) return } + + // 检查是否是sidebar_modules更新请求 + if sidebarModules, exists := requestData["sidebar_modules"]; exists { + userId := c.GetInt("id") + user, err := model.GetUserById(userId, false) + if err != nil { + common.ApiError(c, err) + return + } + + // 获取当前用户设置 + currentSetting := user.GetSetting() + + // 更新sidebar_modules字段 + if sidebarModulesStr, ok := sidebarModules.(string); ok { + currentSetting.SidebarModules = sidebarModulesStr + } + + // 保存更新后的设置 + user.SetSetting(currentSetting) + if err := user.Update(false); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "更新设置失败: " + err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "设置更新成功", + }) + return + } + + // 原有的用户信息更新逻辑 + var user model.User + requestDataBytes, err := json.Marshal(requestData) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + err = json.Unmarshal(requestDataBytes, &user) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + if user.Password == "" { user.Password = "$I_LOVE_U" // make Validator happy :) } @@ -652,6 +855,7 @@ func CreateUser(c *gin.Context) { Username: user.Username, Password: user.Password, DisplayName: user.DisplayName, + Role: user.Role, // 保持管理员设置的角色 } if err := cleanUser.Insert(0); err != nil { common.ApiError(c, err) @@ -817,18 +1021,64 @@ type topUpRequest struct { Key string `json:"key"` } -var topUpLock = sync.Mutex{} +var topUpLocks sync.Map +var topUpCreateLock sync.Mutex + +type topUpTryLock struct { + ch chan struct{} +} + +func newTopUpTryLock() *topUpTryLock { + return &topUpTryLock{ch: make(chan struct{}, 1)} +} + +func (l *topUpTryLock) TryLock() bool { + select { + case l.ch <- struct{}{}: + return true + default: + return false + } +} + +func (l *topUpTryLock) Unlock() { + select { + case <-l.ch: + default: + } +} + +func getTopUpLock(userID int) *topUpTryLock { + if v, ok := topUpLocks.Load(userID); ok { + return v.(*topUpTryLock) + } + topUpCreateLock.Lock() + defer topUpCreateLock.Unlock() + if v, ok := topUpLocks.Load(userID); ok { + return v.(*topUpTryLock) + } + l := newTopUpTryLock() + topUpLocks.Store(userID, l) + return l +} func TopUp(c *gin.Context) { - topUpLock.Lock() - defer topUpLock.Unlock() + id := c.GetInt("id") + lock := getTopUpLock(id) + if !lock.TryLock() { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "充值处理中,请稍后重试", + }) + return + } + defer lock.Unlock() req := topUpRequest{} err := c.ShouldBindJSON(&req) if err != nil { common.ApiError(c, err) return } - id := c.GetInt("id") quota, err := model.Redeem(req.Key, id) if err != nil { common.ApiError(c, err) @@ -839,7 +1089,6 @@ func TopUp(c *gin.Context) { "message": "", "data": quota, }) - return } type UpdateUserSettingRequest struct { @@ -848,6 +1097,7 @@ type UpdateUserSettingRequest struct { WebhookUrl string `json:"webhook_url,omitempty"` WebhookSecret string `json:"webhook_secret,omitempty"` NotificationEmail string `json:"notification_email,omitempty"` + BarkUrl string `json:"bark_url,omitempty"` AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"` RecordIpLog bool `json:"record_ip_log"` } @@ -863,7 +1113,7 @@ func UpdateUserSetting(c *gin.Context) { } // 验证预警类型 - if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook { + if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook && req.QuotaWarningType != dto.NotifyTypeBark { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无效的预警类型", @@ -911,6 +1161,33 @@ func UpdateUserSetting(c *gin.Context) { } } + // 如果是Bark类型,验证Bark URL + if req.QuotaWarningType == dto.NotifyTypeBark { + if req.BarkUrl == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "Bark推送URL不能为空", + }) + return + } + // 验证URL格式 + if _, err := url.ParseRequestURI(req.BarkUrl); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的Bark推送URL", + }) + return + } + // 检查是否是HTTP或HTTPS + if !strings.HasPrefix(req.BarkUrl, "https://") && !strings.HasPrefix(req.BarkUrl, "http://") { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "Bark推送URL必须以http://或https://开头", + }) + return + } + } + userId := c.GetInt("id") user, err := model.GetUserById(userId, true) if err != nil { @@ -939,6 +1216,11 @@ func UpdateUserSetting(c *gin.Context) { settings.NotificationEmail = req.NotificationEmail } + // 如果是Bark类型,添加Bark URL到设置中 + if req.QuotaWarningType == dto.NotifyTypeBark { + settings.BarkUrl = req.BarkUrl + } + // 更新用户设置 user.SetSetting(settings) if err := user.Update(false); err != nil { diff --git a/controller/vendor_meta.go b/controller/vendor_meta.go new file mode 100644 index 000000000..21d5a21db --- /dev/null +++ b/controller/vendor_meta.go @@ -0,0 +1,124 @@ +package controller + +import ( + "strconv" + + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" +) + +// GetAllVendors 获取供应商列表(分页) +func GetAllVendors(c *gin.Context) { + pageInfo := common.GetPageQuery(c) + vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) + if err != nil { + common.ApiError(c, err) + return + } + var total int64 + model.DB.Model(&model.Vendor{}).Count(&total) + pageInfo.SetTotal(int(total)) + pageInfo.SetItems(vendors) + common.ApiSuccess(c, pageInfo) +} + +// SearchVendors 搜索供应商 +func SearchVendors(c *gin.Context) { + keyword := c.Query("keyword") + pageInfo := common.GetPageQuery(c) + vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) + if err != nil { + common.ApiError(c, err) + return + } + pageInfo.SetTotal(int(total)) + pageInfo.SetItems(vendors) + common.ApiSuccess(c, pageInfo) +} + +// GetVendorMeta 根据 ID 获取供应商 +func GetVendorMeta(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + v, err := model.GetVendorByID(id) + if err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, v) +} + +// CreateVendorMeta 新建供应商 +func CreateVendorMeta(c *gin.Context) { + var v model.Vendor + if err := c.ShouldBindJSON(&v); err != nil { + common.ApiError(c, err) + return + } + if v.Name == "" { + common.ApiErrorMsg(c, "供应商名称不能为空") + return + } + // 创建前先检查名称 + if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "供应商名称已存在") + return + } + + if err := v.Insert(); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, &v) +} + +// UpdateVendorMeta 更新供应商 +func UpdateVendorMeta(c *gin.Context) { + var v model.Vendor + if err := c.ShouldBindJSON(&v); err != nil { + common.ApiError(c, err) + return + } + if v.Id == 0 { + common.ApiErrorMsg(c, "缺少供应商 ID") + return + } + // 名称冲突检查 + if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "供应商名称已存在") + return + } + + if err := v.Update(); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, &v) +} + +// DeleteVendorMeta 删除供应商 +func DeleteVendorMeta(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, nil) +} diff --git a/docker-compose.yml b/docker-compose.yml index 57ad0b30a..d98fd706e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -16,7 +16,7 @@ services: - REDIS_CONN_STRING=redis://redis - TZ=Asia/Shanghai - ERROR_LOG_ENABLED=true # 是否启用错误日志记录 - # - STREAMING_TIMEOUT=120 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值 + # - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值 # - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!! # - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment # - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed diff --git a/dto/audio.go b/dto/audio.go index c36b3da54..9d71f6f76 100644 --- a/dto/audio.go +++ b/dto/audio.go @@ -1,5 +1,11 @@ package dto +import ( + "one-api/types" + + "github.com/gin-gonic/gin" +) + type AudioRequest struct { Model string `json:"model"` Input string `json:"input"` @@ -8,6 +14,24 @@ type AudioRequest struct { ResponseFormat string `json:"response_format,omitempty"` } +func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta { + meta := &types.TokenCountMeta{ + CombineText: r.Input, + TokenType: types.TokenTypeTextNumber, + } + return meta +} + +func (r *AudioRequest) IsStream(c *gin.Context) bool { + return false +} + +func (r *AudioRequest) SetModelName(modelName string) { + if modelName != "" { + r.Model = modelName + } +} + type AudioResponse struct { Text string `json:"text"` } diff --git a/dto/channel_settings.go b/dto/channel_settings.go index 871d67169..2c58795cb 100644 --- a/dto/channel_settings.go +++ b/dto/channel_settings.go @@ -1,7 +1,14 @@ package dto type ChannelSettings struct { - ForceFormat bool `json:"force_format,omitempty"` - ThinkingToContent bool `json:"thinking_to_content,omitempty"` - Proxy string `json:"proxy"` + ForceFormat bool `json:"force_format,omitempty"` + ThinkingToContent bool `json:"thinking_to_content,omitempty"` + Proxy string `json:"proxy"` + PassThroughBodyEnabled bool `json:"pass_through_body_enabled,omitempty"` + SystemPrompt string `json:"system_prompt,omitempty"` + SystemPromptOverride bool `json:"system_prompt_override,omitempty"` +} + +type ChannelOtherSettings struct { + AzureResponsesVersion string `json:"azure_responses_version,omitempty"` } diff --git a/dto/claude.go b/dto/claude.go index 1a7eacb18..963e588bf 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -2,8 +2,12 @@ package dto import ( "encoding/json" + "fmt" "one-api/common" "one-api/types" + "strings" + + "github.com/gin-gonic/gin" ) type ClaudeMetadata struct { @@ -80,7 +84,7 @@ func (c *ClaudeMediaMessage) GetStringContent() string { } func (c *ClaudeMediaMessage) GetJsonRowString() string { - jsonContent, _ := json.Marshal(c) + jsonContent, _ := common.Marshal(c) return string(jsonContent) } @@ -198,6 +202,147 @@ type ClaudeRequest struct { Thinking *Thinking `json:"thinking,omitempty"` } +func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta { + var tokenCountMeta = types.TokenCountMeta{ + TokenType: types.TokenTypeTokenizer, + MaxTokens: int(c.MaxTokens), + } + + var texts = make([]string, 0) + var fileMeta = make([]*types.FileMeta, 0) + + // system + if c.System != nil { + if c.IsStringSystem() { + sys := c.GetStringSystem() + if sys != "" { + texts = append(texts, sys) + } + } else { + systemMedia := c.ParseSystem() + for _, media := range systemMedia { + switch media.Type { + case "text": + texts = append(texts, media.GetText()) + case "image": + if media.Source != nil { + data := media.Source.Url + if data == "" { + data = common.Interface2String(media.Source.Data) + } + if data != "" { + fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data}) + } + } + } + } + } + } + + // messages + for _, message := range c.Messages { + tokenCountMeta.MessagesCount++ + texts = append(texts, message.Role) + if message.IsStringContent() { + content := message.GetStringContent() + if content != "" { + texts = append(texts, content) + } + continue + } + + content, _ := message.ParseContent() + for _, media := range content { + switch media.Type { + case "text": + texts = append(texts, media.GetText()) + case "image": + if media.Source != nil { + data := media.Source.Url + if data == "" { + data = common.Interface2String(media.Source.Data) + } + if data != "" { + fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data}) + } + } + case "tool_use": + if media.Name != "" { + texts = append(texts, media.Name) + } + if media.Input != nil { + b, _ := common.Marshal(media.Input) + texts = append(texts, string(b)) + } + case "tool_result": + if media.Content != nil { + b, _ := common.Marshal(media.Content) + texts = append(texts, string(b)) + } + } + } + } + + // tools + if c.Tools != nil { + tools := c.GetTools() + normalTools, webSearchTools := ProcessTools(tools) + if normalTools != nil { + for _, t := range normalTools { + tokenCountMeta.ToolsCount++ + if t.Name != "" { + texts = append(texts, t.Name) + } + if t.Description != "" { + texts = append(texts, t.Description) + } + if t.InputSchema != nil { + b, _ := common.Marshal(t.InputSchema) + texts = append(texts, string(b)) + } + } + } + if webSearchTools != nil { + for _, t := range webSearchTools { + tokenCountMeta.ToolsCount++ + if t.Name != "" { + texts = append(texts, t.Name) + } + if t.UserLocation != nil { + b, _ := common.Marshal(t.UserLocation) + texts = append(texts, string(b)) + } + } + } + } + + tokenCountMeta.CombineText = strings.Join(texts, "\n") + tokenCountMeta.Files = fileMeta + return &tokenCountMeta +} + +func (c *ClaudeRequest) IsStream(ctx *gin.Context) bool { + return c.Stream +} + +func (c *ClaudeRequest) SetModelName(modelName string) { + if modelName != "" { + c.Model = modelName + } +} + +func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string { + for _, message := range c.Messages { + content, _ := message.ParseContent() + for _, mediaMessage := range content { + if mediaMessage.Id == toolCallId { + return mediaMessage.Name + } + } + } + return "" +} + // AddTool 添加工具到请求中 func (c *ClaudeRequest) AddTool(tool any) { if c.Tools == nil { @@ -284,14 +429,9 @@ func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage { return mediaContent } -type ClaudeError struct { - Type string `json:"type,omitempty"` - Message string `json:"message,omitempty"` -} - type ClaudeErrorWithStatusCode struct { - Error ClaudeError `json:"error"` - StatusCode int `json:"status_code"` + Error types.ClaudeError `json:"error"` + StatusCode int `json:"status_code"` LocalError bool } @@ -303,7 +443,7 @@ type ClaudeResponse struct { Completion string `json:"completion,omitempty"` StopReason string `json:"stop_reason,omitempty"` Model string `json:"model,omitempty"` - Error *types.ClaudeError `json:"error,omitempty"` + Error any `json:"error,omitempty"` Usage *ClaudeUsage `json:"usage,omitempty"` Index *int `json:"index,omitempty"` ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"` @@ -324,12 +464,48 @@ func (c *ClaudeResponse) GetIndex() int { return *c.Index } +// GetClaudeError 从动态错误类型中提取ClaudeError结构 +func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError { + if c.Error == nil { + return nil + } + + switch err := c.Error.(type) { + case types.ClaudeError: + return &err + case *types.ClaudeError: + return err + case map[string]interface{}: + // 处理从JSON解析来的map结构 + claudeErr := &types.ClaudeError{} + if errType, ok := err["type"].(string); ok { + claudeErr.Type = errType + } + if errMsg, ok := err["message"].(string); ok { + claudeErr.Message = errMsg + } + return claudeErr + case string: + // 处理简单字符串错误 + return &types.ClaudeError{ + Type: "upstream_error", + Message: err, + } + default: + // 未知类型,尝试转换为字符串 + return &types.ClaudeError{ + Type: "unknown_upstream_error", + Message: fmt.Sprintf("unknown_error: %v", err), + } + } +} + type ClaudeUsage struct { InputTokens int `json:"input_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"` OutputTokens int `json:"output_tokens"` - ServerToolUse *ClaudeServerToolUse `json:"server_tool_use"` + ServerToolUse *ClaudeServerToolUse `json:"server_tool_use,omitempty"` } type ClaudeServerToolUse struct { diff --git a/dto/dalle.go b/dto/dalle.go deleted file mode 100644 index ce2f6361c..000000000 --- a/dto/dalle.go +++ /dev/null @@ -1,29 +0,0 @@ -package dto - -import "encoding/json" - -type ImageRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt" binding:"required"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - Quality string `json:"quality,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Style string `json:"style,omitempty"` - User string `json:"user,omitempty"` - ExtraFields json.RawMessage `json:"extra_fields,omitempty"` - Background string `json:"background,omitempty"` - Moderation string `json:"moderation,omitempty"` - OutputFormat string `json:"output_format,omitempty"` - Watermark *bool `json:"watermark,omitempty"` -} - -type ImageResponse struct { - Data []ImageData `json:"data"` - Created int64 `json:"created"` -} -type ImageData struct { - Url string `json:"url"` - B64Json string `json:"b64_json"` - RevisedPrompt string `json:"revised_prompt"` -} diff --git a/dto/embedding.go b/dto/embedding.go index 9d7222920..b473b7228 100644 --- a/dto/embedding.go +++ b/dto/embedding.go @@ -1,5 +1,12 @@ package dto +import ( + "one-api/types" + "strings" + + "github.com/gin-gonic/gin" +) + type EmbeddingOptions struct { Seed int `json:"seed,omitempty"` Temperature *float64 `json:"temperature,omitempty"` @@ -24,9 +31,32 @@ type EmbeddingRequest struct { PresencePenalty float64 `json:"presence_penalty,omitempty"` } -func (r EmbeddingRequest) ParseInput() []string { +func (r *EmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta { + var texts = make([]string, 0) + + inputs := r.ParseInput() + for _, input := range inputs { + texts = append(texts, input) + } + + return &types.TokenCountMeta{ + CombineText: strings.Join(texts, "\n"), + } +} + +func (r *EmbeddingRequest) IsStream(c *gin.Context) bool { + return false +} + +func (r *EmbeddingRequest) SetModelName(modelName string) { + if modelName != "" { + r.Model = modelName + } +} + +func (r *EmbeddingRequest) ParseInput() []string { if r.Input == nil { - return nil + return make([]string, 0) } var input []string switch r.Input.(type) { diff --git a/relay/channel/gemini/dto.go b/dto/gemini.go similarity index 57% rename from relay/channel/gemini/dto.go rename to dto/gemini.go index b22e092a6..cd5d74cdd 100644 --- a/relay/channel/gemini/dto.go +++ b/dto/gemini.go @@ -1,15 +1,118 @@ -package gemini +package dto -import "encoding/json" +import ( + "encoding/json" + "one-api/common" + "one-api/logger" + "one-api/types" + "strings" + + "github.com/gin-gonic/gin" +) type GeminiChatRequest struct { Contents []GeminiChatContent `json:"contents"` SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"` GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"` - Tools []GeminiChatTool `json:"tools,omitempty"` + Tools json.RawMessage `json:"tools,omitempty"` SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"` } +func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta { + var files []*types.FileMeta = make([]*types.FileMeta, 0) + + var maxTokens int + + if r.GenerationConfig.MaxOutputTokens > 0 { + maxTokens = int(r.GenerationConfig.MaxOutputTokens) + } + + var inputTexts []string + for _, content := range r.Contents { + for _, part := range content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + if part.InlineData != nil && part.InlineData.Data != "" { + if strings.HasPrefix(part.InlineData.MimeType, "image/") { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeImage, + OriginData: part.InlineData.Data, + }) + } else if strings.HasPrefix(part.InlineData.MimeType, "audio/") { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeAudio, + OriginData: part.InlineData.Data, + }) + } else if strings.HasPrefix(part.InlineData.MimeType, "video/") { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeVideo, + OriginData: part.InlineData.Data, + }) + } else { + files = append(files, &types.FileMeta{ + FileType: types.FileTypeFile, + OriginData: part.InlineData.Data, + }) + } + } + } + } + + inputText := strings.Join(inputTexts, "\n") + return &types.TokenCountMeta{ + CombineText: inputText, + Files: files, + MaxTokens: maxTokens, + } +} + +func (r *GeminiChatRequest) IsStream(c *gin.Context) bool { + if c.Query("alt") == "sse" { + return true + } + return false +} + +func (r *GeminiChatRequest) SetModelName(modelName string) { + // GeminiChatRequest does not have a model field, so this method does nothing. +} + +func (r *GeminiChatRequest) GetTools() []GeminiChatTool { + var tools []GeminiChatTool + if strings.HasSuffix(string(r.Tools), "[") { + // is array + if err := common.Unmarshal(r.Tools, &tools); err != nil { + logger.LogError(nil, "error_unmarshalling_tools: "+err.Error()) + return nil + } + } else if strings.HasPrefix(string(r.Tools), "{") { + // is object + singleTool := GeminiChatTool{} + if err := common.Unmarshal(r.Tools, &singleTool); err != nil { + logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error()) + return nil + } + tools = []GeminiChatTool{singleTool} + } + return tools +} + +func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) { + if len(tools) == 0 { + r.Tools = json.RawMessage("[]") + return + } + + // Marshal the tools to JSON + data, err := common.Marshal(tools) + if err != nil { + logger.LogError(nil, "error_marshalling_tools: "+err.Error()) + return + } + r.Tools = data +} + type GeminiThinkingConfig struct { IncludeThoughts bool `json:"includeThoughts,omitempty"` ThinkingBudget *int `json:"thinkingBudget,omitempty"` @@ -32,7 +135,7 @@ func (g *GeminiInlineData) UnmarshalJSON(data []byte) error { MimeTypeSnake string `json:"mime_type"` } - if err := json.Unmarshal(data, &aux); err != nil { + if err := common.Unmarshal(data, &aux); err != nil { return err } @@ -53,7 +156,7 @@ type FunctionCall struct { Arguments any `json:"args"` } -type FunctionResponse struct { +type GeminiFunctionResponse struct { Name string `json:"name"` Response map[string]interface{} `json:"response"` } @@ -78,7 +181,7 @@ type GeminiPart struct { Thought bool `json:"thought,omitempty"` InlineData *GeminiInlineData `json:"inlineData,omitempty"` FunctionCall *FunctionCall `json:"functionCall,omitempty"` - FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` + FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"` FileData *GeminiFileData `json:"fileData,omitempty"` ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"` CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"` @@ -93,7 +196,7 @@ func (p *GeminiPart) UnmarshalJSON(data []byte) error { InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant } - if err := json.Unmarshal(data, &aux); err != nil { + if err := common.Unmarshal(data, &aux); err != nil { return err } @@ -166,14 +269,15 @@ type GeminiChatResponse struct { } type GeminiUsageMetadata struct { - PromptTokenCount int `json:"promptTokenCount"` - CandidatesTokenCount int `json:"candidatesTokenCount"` - TotalTokenCount int `json:"totalTokenCount"` - ThoughtsTokenCount int `json:"thoughtsTokenCount"` - PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"` + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + ThoughtsTokenCount int `json:"thoughtsTokenCount"` + PromptTokensDetails []GeminiModalityTokenCount `json:"promptTokensDetails"` + CandidatesTokensDetails []GeminiModalityTokenCount `json:"candidatesTokensDetails"` } -type GeminiPromptTokensDetails struct { +type GeminiModalityTokenCount struct { Modality string `json:"modality"` TokenCount int `json:"tokenCount"` } @@ -207,16 +311,76 @@ type GeminiImagePrediction struct { // Embedding related structs type GeminiEmbeddingRequest struct { + Model string `json:"model,omitempty"` Content GeminiChatContent `json:"content"` TaskType string `json:"taskType,omitempty"` Title string `json:"title,omitempty"` OutputDimensionality int `json:"outputDimensionality,omitempty"` } +func (r *GeminiEmbeddingRequest) IsStream(c *gin.Context) bool { + // Gemini embedding requests are not streamed + return false +} + +func (r *GeminiEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta { + var inputTexts []string + for _, part := range r.Content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + } + inputText := strings.Join(inputTexts, "\n") + return &types.TokenCountMeta{ + CombineText: inputText, + } +} + +func (r *GeminiEmbeddingRequest) SetModelName(modelName string) { + if modelName != "" { + r.Model = modelName + } +} + +type GeminiBatchEmbeddingRequest struct { + Requests []*GeminiEmbeddingRequest `json:"requests"` +} + +func (r *GeminiBatchEmbeddingRequest) IsStream(c *gin.Context) bool { + // Gemini batch embedding requests are not streamed + return false +} + +func (r *GeminiBatchEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta { + var inputTexts []string + for _, request := range r.Requests { + meta := request.GetTokenCountMeta() + if meta != nil && meta.CombineText != "" { + inputTexts = append(inputTexts, meta.CombineText) + } + } + inputText := strings.Join(inputTexts, "\n") + return &types.TokenCountMeta{ + CombineText: inputText, + } +} + +func (r *GeminiBatchEmbeddingRequest) SetModelName(modelName string) { + if modelName != "" { + for _, req := range r.Requests { + req.SetModelName(modelName) + } + } +} + type GeminiEmbeddingResponse struct { Embedding ContentEmbedding `json:"embedding"` } +type GeminiBatchEmbeddingResponse struct { + Embeddings []*ContentEmbedding `json:"embeddings"` +} + type ContentEmbedding struct { Values []float64 `json:"values"` } diff --git a/dto/openai_image.go b/dto/openai_image.go new file mode 100644 index 000000000..9e838688e --- /dev/null +++ b/dto/openai_image.go @@ -0,0 +1,147 @@ +package dto + +import ( + "encoding/json" + "one-api/common" + "one-api/types" + "reflect" + "strings" + + "github.com/gin-gonic/gin" +) + +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N uint `json:"n,omitempty"` + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style json.RawMessage `json:"style,omitempty"` + User json.RawMessage `json:"user,omitempty"` + ExtraFields json.RawMessage `json:"extra_fields,omitempty"` + Background json.RawMessage `json:"background,omitempty"` + Moderation json.RawMessage `json:"moderation,omitempty"` + OutputFormat json.RawMessage `json:"output_format,omitempty"` + OutputCompression json.RawMessage `json:"output_compression,omitempty"` + PartialImages json.RawMessage `json:"partial_images,omitempty"` + // Stream bool `json:"stream,omitempty"` + Watermark *bool `json:"watermark,omitempty"` + // 用匿名参数接收额外参数 + Extra map[string]json.RawMessage `json:"-"` +} + +func (i *ImageRequest) UnmarshalJSON(data []byte) error { + // 先解析成 map[string]interface{} + var rawMap map[string]json.RawMessage + if err := common.Unmarshal(data, &rawMap); err != nil { + return err + } + + // 用 struct tag 获取所有已定义字段名 + knownFields := GetJSONFieldNames(reflect.TypeOf(*i)) + + // 再正常解析已定义字段 + type Alias ImageRequest + var known Alias + if err := common.Unmarshal(data, &known); err != nil { + return err + } + *i = ImageRequest(known) + + // 提取多余字段 + i.Extra = make(map[string]json.RawMessage) + for k, v := range rawMap { + if _, ok := knownFields[k]; !ok { + i.Extra[k] = v + } + } + return nil +} + +func GetJSONFieldNames(t reflect.Type) map[string]struct{} { + fields := make(map[string]struct{}) + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + // 跳过匿名字段(例如 ExtraFields) + if field.Anonymous { + continue + } + + tag := field.Tag.Get("json") + if tag == "-" || tag == "" { + continue + } + + // 取逗号前字段名(排除 omitempty 等) + name := tag + if commaIdx := indexComma(tag); commaIdx != -1 { + name = tag[:commaIdx] + } + fields[name] = struct{}{} + } + return fields +} + +func indexComma(s string) int { + for i := 0; i < len(s); i++ { + if s[i] == ',' { + return i + } + } + return -1 +} + +func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta { + var sizeRatio = 1.0 + var qualityRatio = 1.0 + + if strings.HasPrefix(i.Model, "dall-e") { + // Size + if i.Size == "256x256" { + sizeRatio = 0.4 + } else if i.Size == "512x512" { + sizeRatio = 0.45 + } else if i.Size == "1024x1024" { + sizeRatio = 1 + } else if i.Size == "1024x1792" || i.Size == "1792x1024" { + sizeRatio = 2 + } + + if i.Model == "dall-e-3" && i.Quality == "hd" { + qualityRatio = 2.0 + if i.Size == "1024x1792" || i.Size == "1792x1024" { + qualityRatio = 1.5 + } + } + } + + // not support token count for dalle + return &types.TokenCountMeta{ + CombineText: i.Prompt, + MaxTokens: 1584, + ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N), + } +} + +func (i *ImageRequest) IsStream(c *gin.Context) bool { + return false +} + +func (i *ImageRequest) SetModelName(modelName string) { + if modelName != "" { + i.Model = modelName + } +} + +type ImageResponse struct { + Data []ImageData `json:"data"` + Created int64 `json:"created"` + Extra any `json:"extra,omitempty"` +} +type ImageData struct { + Url string `json:"url"` + B64Json string `json:"b64_json"` + RevisedPrompt string `json:"revised_prompt"` +} diff --git a/dto/openai_request.go b/dto/openai_request.go index 88d3bd6cc..cd05a63c9 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -2,20 +2,24 @@ package dto import ( "encoding/json" + "fmt" "one-api/common" + "one-api/types" "strings" + + "github.com/gin-gonic/gin" ) type ResponseFormat struct { - Type string `json:"type,omitempty"` - JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"` + Type string `json:"type,omitempty"` + JsonSchema json.RawMessage `json:"json_schema,omitempty"` } type FormatJsonSchema struct { - Description string `json:"description,omitempty"` - Name string `json:"name"` - Schema any `json:"schema,omitempty"` - Strict any `json:"strict,omitempty"` + Description string `json:"description,omitempty"` + Name string `json:"name"` + Schema any `json:"schema,omitempty"` + Strict json.RawMessage `json:"strict,omitempty"` } type GeneralOpenAIRequest struct { @@ -29,6 +33,7 @@ type GeneralOpenAIRequest struct { MaxTokens uint `json:"max_tokens,omitempty"` MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"` ReasoningEffort string `json:"reasoning_effort,omitempty"` + Verbosity json.RawMessage `json:"verbosity,omitempty"` // gpt-5 Temperature *float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` @@ -52,16 +57,142 @@ type GeneralOpenAIRequest struct { Dimensions int `json:"dimensions,omitempty"` Modalities json.RawMessage `json:"modalities,omitempty"` Audio json.RawMessage `json:"audio,omitempty"` - EnableThinking any `json:"enable_thinking,omitempty"` // ali - THINKING json.RawMessage `json:"thinking,omitempty"` // doubao - ExtraBody json.RawMessage `json:"extra_body,omitempty"` - SearchParameters any `json:"search_parameters,omitempty"` //xai - WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"` + // gemini + ExtraBody json.RawMessage `json:"extra_body,omitempty"` + //xai + SearchParameters json.RawMessage `json:"search_parameters,omitempty"` + // claude + WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"` // OpenRouter Params Usage json.RawMessage `json:"usage,omitempty"` Reasoning json.RawMessage `json:"reasoning,omitempty"` // Ali Qwen Params VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"` + EnableThinking any `json:"enable_thinking,omitempty"` + // ollama Params + Think json.RawMessage `json:"think,omitempty"` + // baidu v2 + WebSearch json.RawMessage `json:"web_search,omitempty"` + // doubao,zhipu_v4 + THINKING json.RawMessage `json:"thinking,omitempty"` +} + +func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { + var tokenCountMeta types.TokenCountMeta + var texts = make([]string, 0) + var fileMeta = make([]*types.FileMeta, 0) + + if r.Prompt != nil { + switch v := r.Prompt.(type) { + case string: + texts = append(texts, v) + case []any: + for _, item := range v { + if str, ok := item.(string); ok { + texts = append(texts, str) + } + } + default: + texts = append(texts, fmt.Sprintf("%v", r.Prompt)) + } + } + + if r.Input != nil { + inputs := r.ParseInput() + texts = append(texts, inputs...) + } + + if r.MaxCompletionTokens > r.MaxTokens { + tokenCountMeta.MaxTokens = int(r.MaxCompletionTokens) + } else { + tokenCountMeta.MaxTokens = int(r.MaxTokens) + } + + for _, message := range r.Messages { + tokenCountMeta.MessagesCount++ + texts = append(texts, message.Role) + if message.Content != nil { + if message.Name != nil { + tokenCountMeta.NameCount++ + texts = append(texts, *message.Name) + } + arrayContent := message.ParseContent() + for _, m := range arrayContent { + if m.Type == ContentTypeImageURL { + imageUrl := m.GetImageMedia() + if imageUrl != nil { + if imageUrl.Url != "" { + meta := &types.FileMeta{ + FileType: types.FileTypeImage, + } + meta.OriginData = imageUrl.Url + meta.Detail = imageUrl.Detail + fileMeta = append(fileMeta, meta) + } + } + } else if m.Type == ContentTypeInputAudio { + inputAudio := m.GetInputAudio() + if inputAudio != nil { + meta := &types.FileMeta{ + FileType: types.FileTypeAudio, + } + meta.OriginData = inputAudio.Data + fileMeta = append(fileMeta, meta) + } + } else if m.Type == ContentTypeFile { + file := m.GetFile() + if file != nil { + meta := &types.FileMeta{ + FileType: types.FileTypeFile, + } + meta.OriginData = file.FileData + fileMeta = append(fileMeta, meta) + } + } else if m.Type == ContentTypeVideoUrl { + videoUrl := m.GetVideoUrl() + if videoUrl != nil && videoUrl.Url != "" { + meta := &types.FileMeta{ + FileType: types.FileTypeVideo, + } + meta.OriginData = videoUrl.Url + fileMeta = append(fileMeta, meta) + } + } else { + texts = append(texts, m.Text) + } + } + } + } + + if r.Tools != nil { + openaiTools := r.Tools + for _, tool := range openaiTools { + tokenCountMeta.ToolsCount++ + texts = append(texts, tool.Function.Name) + if tool.Function.Description != "" { + texts = append(texts, tool.Function.Description) + } + if tool.Function.Parameters != nil { + texts = append(texts, fmt.Sprintf("%v", tool.Function.Parameters)) + } + } + //toolTokens := CountTokenInput(countStr, request.Model) + //tkm += 8 + //tkm += toolTokens + } + tokenCountMeta.CombineText = strings.Join(texts, "\n") + tokenCountMeta.Files = fileMeta + return &tokenCountMeta +} + +func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool { + return r.Stream +} + +func (r *GeneralOpenAIRequest) SetModelName(modelName string) { + if modelName != "" { + r.Model = modelName + } } func (r *GeneralOpenAIRequest) ToMap() map[string]any { @@ -71,6 +202,17 @@ func (r *GeneralOpenAIRequest) ToMap() map[string]any { return result } +func (r *GeneralOpenAIRequest) GetSystemRoleName() string { + if strings.HasPrefix(r.Model, "o") { + if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") { + return "developer" + } + } else if strings.HasPrefix(r.Model, "gpt-5") { + return "developer" + } + return "system" +} + type ToolCallRequest struct { ID string `json:"id,omitempty"` Type string `json:"type"` @@ -88,8 +230,11 @@ type StreamOptions struct { IncludeUsage bool `json:"include_usage,omitempty"` } -func (r *GeneralOpenAIRequest) GetMaxTokens() int { - return int(r.MaxTokens) +func (r *GeneralOpenAIRequest) GetMaxTokens() uint { + if r.MaxCompletionTokens != 0 { + return r.MaxCompletionTokens + } + return r.MaxTokens } func (r *GeneralOpenAIRequest) ParseInput() []string { @@ -185,6 +330,21 @@ func (m *MediaContent) GetFile() *MessageFile { return nil } +func (m *MediaContent) GetVideoUrl() *MessageVideoUrl { + if m.VideoUrl != nil { + if _, ok := m.VideoUrl.(*MessageVideoUrl); ok { + return m.VideoUrl.(*MessageVideoUrl) + } + if itemMap, ok := m.VideoUrl.(map[string]any); ok { + out := &MessageVideoUrl{ + Url: common.Interface2String(itemMap["url"]), + } + return out + } + } + return nil +} + type MessageImageUrl struct { Url string `json:"url"` Detail string `json:"detail"` @@ -216,6 +376,7 @@ const ( ContentTypeInputAudio = "input_audio" ContentTypeFile = "file" ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别 + //ContentTypeAudioUrl = "audio_url" ) func (m *Message) GetPrefix() bool { @@ -605,27 +766,104 @@ type WebSearchOptions struct { // https://platform.openai.com/docs/api-reference/responses/create type OpenAIResponsesRequest struct { - Model string `json:"model"` - Input json.RawMessage `json:"input,omitempty"` - Include json.RawMessage `json:"include,omitempty"` - Instructions json.RawMessage `json:"instructions,omitempty"` - MaxOutputTokens uint `json:"max_output_tokens,omitempty"` - Metadata json.RawMessage `json:"metadata,omitempty"` - ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` - PreviousResponseID string `json:"previous_response_id,omitempty"` - Reasoning *Reasoning `json:"reasoning,omitempty"` - ServiceTier string `json:"service_tier,omitempty"` - Store bool `json:"store,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - Text json.RawMessage `json:"text,omitempty"` - ToolChoice json.RawMessage `json:"tool_choice,omitempty"` - Tools []map[string]any `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map - TopP float64 `json:"top_p,omitempty"` - Truncation string `json:"truncation,omitempty"` - User string `json:"user,omitempty"` - MaxToolCalls uint `json:"max_tool_calls,omitempty"` - Prompt json.RawMessage `json:"prompt,omitempty"` + Model string `json:"model"` + Input json.RawMessage `json:"input,omitempty"` + Include json.RawMessage `json:"include,omitempty"` + Instructions json.RawMessage `json:"instructions,omitempty"` + MaxOutputTokens uint `json:"max_output_tokens,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` + ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` + Reasoning *Reasoning `json:"reasoning,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` + Store bool `json:"store,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Text json.RawMessage `json:"text,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map + TopP float64 `json:"top_p,omitempty"` + Truncation string `json:"truncation,omitempty"` + User string `json:"user,omitempty"` + MaxToolCalls uint `json:"max_tool_calls,omitempty"` + Prompt json.RawMessage `json:"prompt,omitempty"` +} + +func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta { + var fileMeta = make([]*types.FileMeta, 0) + var texts = make([]string, 0) + + if r.Input != nil { + inputs := r.ParseInput() + for _, input := range inputs { + if input.Type == "input_image" { + if input.ImageUrl != "" { + fileMeta = append(fileMeta, &types.FileMeta{ + FileType: types.FileTypeImage, + OriginData: input.ImageUrl, + Detail: input.Detail, + }) + } + } else if input.Type == "input_file" { + if input.FileUrl != "" { + fileMeta = append(fileMeta, &types.FileMeta{ + FileType: types.FileTypeFile, + OriginData: input.FileUrl, + }) + } + } else { + texts = append(texts, input.Text) + } + } + } + + if len(r.Instructions) > 0 { + texts = append(texts, string(r.Instructions)) + } + + if len(r.Metadata) > 0 { + texts = append(texts, string(r.Metadata)) + } + + if len(r.Text) > 0 { + texts = append(texts, string(r.Text)) + } + + if len(r.ToolChoice) > 0 { + texts = append(texts, string(r.ToolChoice)) + } + + if len(r.Prompt) > 0 { + texts = append(texts, string(r.Prompt)) + } + + if len(r.Tools) > 0 { + texts = append(texts, string(r.Tools)) + } + + return &types.TokenCountMeta{ + CombineText: strings.Join(texts, "\n"), + Files: fileMeta, + MaxTokens: int(r.MaxOutputTokens), + } +} + +func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool { + return r.Stream +} + +func (r *OpenAIResponsesRequest) SetModelName(modelName string) { + if modelName != "" { + r.Model = modelName + } +} + +func (r *OpenAIResponsesRequest) GetToolsMap() []map[string]any { + var toolsMap []map[string]any + if len(r.Tools) > 0 { + _ = common.Unmarshal(r.Tools, &toolsMap) + } + return toolsMap } type Reasoning struct { @@ -633,23 +871,88 @@ type Reasoning struct { Summary string `json:"summary,omitempty"` } -//type ResponsesToolsCall struct { -// Type string `json:"type"` -// // Web Search -// UserLocation json.RawMessage `json:"user_location,omitempty"` -// SearchContextSize string `json:"search_context_size,omitempty"` -// // File Search -// VectorStoreIds []string `json:"vector_store_ids,omitempty"` -// MaxNumResults uint `json:"max_num_results,omitempty"` -// Filters json.RawMessage `json:"filters,omitempty"` -// // Computer Use -// DisplayWidth uint `json:"display_width,omitempty"` -// DisplayHeight uint `json:"display_height,omitempty"` -// Environment string `json:"environment,omitempty"` -// // Function -// Name string `json:"name,omitempty"` -// Description string `json:"description,omitempty"` -// Parameters json.RawMessage `json:"parameters,omitempty"` -// Function json.RawMessage `json:"function,omitempty"` -// Container json.RawMessage `json:"container,omitempty"` -//} +type MediaInput struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + FileUrl string `json:"file_url,omitempty"` + ImageUrl string `json:"image_url,omitempty"` + Detail string `json:"detail,omitempty"` // 仅 input_image 有效 +} + +// ParseInput parses the Responses API `input` field into a normalized slice of MediaInput. +// Reference implementation mirrors Message.ParseContent: +// - input can be a string, treated as an input_text item +// - input can be an array of objects with a `type` field +// supported types: input_text, input_image, input_file +func (r *OpenAIResponsesRequest) ParseInput() []MediaInput { + if r.Input == nil { + return nil + } + + var inputs []MediaInput + + // Try string first + // if str, ok := common.GetJsonType(r.Input); ok { + // inputs = append(inputs, MediaInput{Type: "input_text", Text: str}) + // return inputs + // } + if common.GetJsonType(r.Input) == "string" { + var str string + _ = common.Unmarshal(r.Input, &str) + inputs = append(inputs, MediaInput{Type: "input_text", Text: str}) + return inputs + } + + // Try array of parts + if common.GetJsonType(r.Input) == "array" { + var array []any + _ = common.Unmarshal(r.Input, &array) + for _, itemAny := range array { + // Already parsed MediaInput + if media, ok := itemAny.(MediaInput); ok { + inputs = append(inputs, media) + continue + } + // Generic map + item, ok := itemAny.(map[string]any) + if !ok { + continue + } + typeVal, ok := item["type"].(string) + if !ok { + continue + } + switch typeVal { + case "input_text": + text, _ := item["text"].(string) + inputs = append(inputs, MediaInput{Type: "input_text", Text: text}) + case "input_image": + // image_url may be string or object with url field + var imageUrl string + switch v := item["image_url"].(type) { + case string: + imageUrl = v + case map[string]any: + if url, ok := v["url"].(string); ok { + imageUrl = url + } + } + inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl}) + case "input_file": + // file_url may be string or object with url field + var fileUrl string + switch v := item["file_url"].(type) { + case string: + fileUrl = v + case map[string]any: + if url, ok := v["url"].(string); ok { + fileUrl = url + } + } + inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl}) + } + } + } + + return inputs +} diff --git a/dto/openai_response.go b/dto/openai_response.go index 4e5348230..966748cb5 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -2,12 +2,18 @@ package dto import ( "encoding/json" + "fmt" "one-api/types" ) type SimpleResponse struct { Usage `json:"usage"` - Error *OpenAIError `json:"error"` + Error any `json:"error"` +} + +// GetOpenAIError 从动态错误类型中提取OpenAIError结构 +func (s *SimpleResponse) GetOpenAIError() *types.OpenAIError { + return GetOpenAIError(s.Error) } type TextResponse struct { @@ -31,10 +37,15 @@ type OpenAITextResponse struct { Object string `json:"object"` Created any `json:"created"` Choices []OpenAITextResponseChoice `json:"choices"` - Error *types.OpenAIError `json:"error,omitempty"` + Error any `json:"error,omitempty"` Usage `json:"usage"` } +// GetOpenAIError 从动态错误类型中提取OpenAIError结构 +func (o *OpenAITextResponse) GetOpenAIError() *types.OpenAIError { + return GetOpenAIError(o.Error) +} + type OpenAIEmbeddingResponseItem struct { Object string `json:"object"` Index int `json:"index"` @@ -99,7 +110,7 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) { c.ReasoningContent = &s - c.Reasoning = &s + //c.Reasoning = &s } type ToolCallResponse struct { @@ -132,6 +143,13 @@ type ChatCompletionsStreamResponse struct { Usage *Usage `json:"usage"` } +func (c *ChatCompletionsStreamResponse) IsFinished() bool { + if len(c.Choices) == 0 { + return false + } + return c.Choices[0].FinishReason != nil && *c.Choices[0].FinishReason != "" +} + func (c *ChatCompletionsStreamResponse) IsToolCall() bool { if len(c.Choices) == 0 { return false @@ -146,6 +164,19 @@ func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse { return nil } +func (c *ChatCompletionsStreamResponse) ClearToolCalls() { + if !c.IsToolCall() { + return + } + for choiceIdx := range c.Choices { + for callIdx := range c.Choices[choiceIdx].Delta.ToolCalls { + c.Choices[choiceIdx].Delta.ToolCalls[callIdx].ID = "" + c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Type = nil + c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Function.Name = "" + } + } +} + func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse { choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices)) copy(choices, c.Choices) @@ -217,7 +248,7 @@ type OpenAIResponsesResponse struct { Object string `json:"object"` CreatedAt int `json:"created_at"` Status string `json:"status"` - Error *types.OpenAIError `json:"error,omitempty"` + Error any `json:"error,omitempty"` IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` Instructions string `json:"instructions"` MaxOutputTokens int `json:"max_output_tokens"` @@ -237,6 +268,11 @@ type OpenAIResponsesResponse struct { Metadata json.RawMessage `json:"metadata"` } +// GetOpenAIError 从动态错误类型中提取OpenAIError结构 +func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError { + return GetOpenAIError(o.Error) +} + type IncompleteDetails struct { Reasoning string `json:"reasoning"` } @@ -276,3 +312,45 @@ type ResponsesStreamResponse struct { Delta string `json:"delta,omitempty"` Item *ResponsesOutput `json:"item,omitempty"` } + +// GetOpenAIError 从动态错误类型中提取OpenAIError结构 +func GetOpenAIError(errorField any) *types.OpenAIError { + if errorField == nil { + return nil + } + + switch err := errorField.(type) { + case types.OpenAIError: + return &err + case *types.OpenAIError: + return err + case map[string]interface{}: + // 处理从JSON解析来的map结构 + openaiErr := &types.OpenAIError{} + if errType, ok := err["type"].(string); ok { + openaiErr.Type = errType + } + if errMsg, ok := err["message"].(string); ok { + openaiErr.Message = errMsg + } + if errParam, ok := err["param"].(string); ok { + openaiErr.Param = errParam + } + if errCode, ok := err["code"]; ok { + openaiErr.Code = errCode + } + return openaiErr + case string: + // 处理简单字符串错误 + return &types.OpenAIError{ + Type: "error", + Message: err, + } + default: + // 未知类型,尝试转换为字符串 + return &types.OpenAIError{ + Type: "unknown_error", + Message: fmt.Sprintf("%v", err), + } + } +} diff --git a/dto/pricing.go b/dto/pricing.go index 0f317d9d5..bc024de30 100644 --- a/dto/pricing.go +++ b/dto/pricing.go @@ -2,6 +2,7 @@ package dto import "one-api/constant" +// 这里不好动就不动了,本来想独立出来的( type OpenAIModels struct { Id string `json:"id"` Object string `json:"object"` @@ -9,3 +10,26 @@ type OpenAIModels struct { OwnedBy string `json:"owned_by"` SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"` } + +type AnthropicModel struct { + ID string `json:"id"` + CreatedAt string `json:"created_at"` + DisplayName string `json:"display_name"` + Type string `json:"type"` +} + +type GeminiModel struct { + Name interface{} `json:"name"` + BaseModelId interface{} `json:"baseModelId"` + Version interface{} `json:"version"` + DisplayName interface{} `json:"displayName"` + Description interface{} `json:"description"` + InputTokenLimit interface{} `json:"inputTokenLimit"` + OutputTokenLimit interface{} `json:"outputTokenLimit"` + SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"` + Thinking interface{} `json:"thinking"` + Temperature interface{} `json:"temperature"` + MaxTemperature interface{} `json:"maxTemperature"` + TopP interface{} `json:"topP"` + TopK interface{} `json:"topK"` +} diff --git a/dto/ratio_sync.go b/dto/ratio_sync.go index 6315f31ae..d6bbf68e1 100644 --- a/dto/ratio_sync.go +++ b/dto/ratio_sync.go @@ -1,23 +1,23 @@ package dto type UpstreamDTO struct { - ID int `json:"id,omitempty"` - Name string `json:"name" binding:"required"` - BaseURL string `json:"base_url" binding:"required"` - Endpoint string `json:"endpoint"` + ID int `json:"id,omitempty"` + Name string `json:"name" binding:"required"` + BaseURL string `json:"base_url" binding:"required"` + Endpoint string `json:"endpoint"` } type UpstreamRequest struct { - ChannelIDs []int64 `json:"channel_ids"` - Upstreams []UpstreamDTO `json:"upstreams"` - Timeout int `json:"timeout"` + ChannelIDs []int64 `json:"channel_ids"` + Upstreams []UpstreamDTO `json:"upstreams"` + Timeout int `json:"timeout"` } // TestResult 上游测试连通性结果 type TestResult struct { - Name string `json:"name"` - Status string `json:"status"` - Error string `json:"error,omitempty"` + Name string `json:"name"` + Status string `json:"status"` + Error string `json:"error,omitempty"` } // DifferenceItem 差异项 @@ -25,14 +25,14 @@ type TestResult struct { // Upstreams 为各渠道的上游值,具体数值 / "same" / nil type DifferenceItem struct { - Current interface{} `json:"current"` - Upstreams map[string]interface{} `json:"upstreams"` - Confidence map[string]bool `json:"confidence"` + Current interface{} `json:"current"` + Upstreams map[string]interface{} `json:"upstreams"` + Confidence map[string]bool `json:"confidence"` } type SyncableChannel struct { - ID int `json:"id"` - Name string `json:"name"` - BaseURL string `json:"base_url"` - Status int `json:"status"` -} \ No newline at end of file + ID int `json:"id"` + Name string `json:"name"` + BaseURL string `json:"base_url"` + Status int `json:"status"` +} diff --git a/dto/request_common.go b/dto/request_common.go new file mode 100644 index 000000000..da3ac3c52 --- /dev/null +++ b/dto/request_common.go @@ -0,0 +1,25 @@ +package dto + +import ( + "github.com/gin-gonic/gin" + "one-api/types" +) + +type Request interface { + GetTokenCountMeta() *types.TokenCountMeta + IsStream(c *gin.Context) bool + SetModelName(modelName string) +} + +type BaseRequest struct { +} + +func (b *BaseRequest) GetTokenCountMeta() *types.TokenCountMeta { + return &types.TokenCountMeta{ + TokenType: types.TokenTypeTokenizer, + } +} +func (b *BaseRequest) IsStream(c *gin.Context) bool { + return false +} +func (b *BaseRequest) SetModelName(modelName string) {} diff --git a/dto/rerank.go b/dto/rerank.go index 5ea68cba1..46f4bce6f 100644 --- a/dto/rerank.go +++ b/dto/rerank.go @@ -1,5 +1,12 @@ package dto +import ( + "fmt" + "github.com/gin-gonic/gin" + "one-api/types" + "strings" +) + type RerankRequest struct { Documents []any `json:"documents"` Query string `json:"query"` @@ -10,6 +17,32 @@ type RerankRequest struct { OverLapTokens int `json:"overlap_tokens,omitempty"` } +func (r *RerankRequest) IsStream(c *gin.Context) bool { + return false +} + +func (r *RerankRequest) GetTokenCountMeta() *types.TokenCountMeta { + var texts = make([]string, 0) + + for _, document := range r.Documents { + texts = append(texts, fmt.Sprintf("%v", document)) + } + + if r.Query != "" { + texts = append(texts, r.Query) + } + + return &types.TokenCountMeta{ + CombineText: strings.Join(texts, "\n"), + } +} + +func (r *RerankRequest) SetModelName(modelName string) { + if modelName != "" { + r.Model = modelName + } +} + func (r *RerankRequest) GetReturnDocuments() bool { if r.ReturnDocuments == nil { return false diff --git a/dto/user_settings.go b/dto/user_settings.go index 2e1a15418..89dd926ef 100644 --- a/dto/user_settings.go +++ b/dto/user_settings.go @@ -6,11 +6,14 @@ type UserSetting struct { WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址 WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥 NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址 + BarkUrl string `json:"bark_url,omitempty"` // BarkUrl Bark推送URL AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型 RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP + SidebarModules string `json:"sidebar_modules,omitempty"` // SidebarModules 左侧边栏模块配置 } var ( NotifyTypeEmail = "email" // Email 邮件 NotifyTypeWebhook = "webhook" // Webhook + NotifyTypeBark = "bark" // Bark 推送 ) diff --git a/go.mod b/go.mod index 94873c88a..501d966d5 100644 --- a/go.mod +++ b/go.mod @@ -7,9 +7,10 @@ require ( github.com/Calcium-Ion/go-epay v0.0.4 github.com/andybalholm/brotli v1.1.1 github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 - github.com/aws/aws-sdk-go-v2 v1.26.1 + github.com/aws/aws-sdk-go-v2 v1.37.2 github.com/aws/aws-sdk-go-v2/credentials v1.17.11 - github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 + github.com/aws/smithy-go v1.22.5 github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b github.com/gin-contrib/cors v1.7.2 github.com/gin-contrib/gzip v0.0.6 @@ -22,13 +23,17 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.0 + github.com/jinzhu/copier v0.4.0 github.com/joho/godotenv v1.5.1 github.com/pkg/errors v0.9.1 + github.com/pquerna/otp v1.5.0 github.com/samber/lo v1.39.0 github.com/shirou/gopsutil v3.21.11+incompatible github.com/shopspring/decimal v1.4.0 github.com/stripe/stripe-go/v81 v81.4.0 github.com/thanhpk/randstr v1.0.6 + github.com/tidwall/gjson v1.18.0 + github.com/tidwall/sjson v1.2.5 github.com/tiktoken-go/tokenizer v0.6.2 golang.org/x/crypto v0.35.0 golang.org/x/image v0.23.0 @@ -41,10 +46,10 @@ require ( require ( github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect - github.com/aws/smithy-go v1.20.2 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect + github.com/boombuler/barcode v1.1.0 // indirect github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -80,6 +85,8 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect diff --git a/go.sum b/go.sum index 74eecd4c2..189d09de4 100644 --- a/go.sum +++ b/go.sum @@ -6,20 +6,23 @@ github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+Kc github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8= -github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= -github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= +github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo= +github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg= github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= -github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU= -github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg= -github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= -github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 h1:sPiRHLVUIIQcoVZTNwqQcdtjkqkPopyYmIX0M5ElRf4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2/go.mod h1:ik86P3sgV+Bk7c1tBFCwI3VxMoSEwl4YkRB9xn1s340= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 h1:ZdzDAg075H6stMZtbD2o+PyB933M/f20e9WmCBC17wA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2/go.mod h1:eE1IIzXG9sdZCB0pNNpMpsYTLl4YdOQD3njiVN1e/E4= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 h1:JzidOz4Hcn2RbP5fvIS1iAP+DcRv5VJtgixbEYDsI5g= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0/go.mod h1:9A4/PJYlWjvjEzzoOLGQjkLt4bYK9fRWi7uz1GSsAcA= +github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= +github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= +github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo= +github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0= github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= @@ -117,6 +120,8 @@ github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= +github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= @@ -169,6 +174,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= +github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= @@ -199,6 +206,15 @@ github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJ github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo= github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o= github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g= github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= diff --git a/i18n/zh-cn.json b/i18n/zh-cn.json deleted file mode 100644 index 7b57b51ac..000000000 --- a/i18n/zh-cn.json +++ /dev/null @@ -1,1041 +0,0 @@ -{ - "未登录或登录已过期,请重新登录": "未登录或登录已过期,请重新登录", - "登 录": "登 录", - "使用 微信 继续": "使用 微信 继续", - "使用 GitHub 继续": "使用 GitHub 继续", - "使用 LinuxDO 继续": "使用 LinuxDO 继续", - "使用 邮箱或用户名 登录": "使用 邮箱或用户名 登录", - "没有账户?": "没有账户?", - "用户名或邮箱": "用户名或邮箱", - "请输入您的用户名或邮箱地址": "请输入您的用户名或邮箱地址", - "请输入您的密码": "请输入您的密码", - "继续": "继续", - "忘记密码?": "忘记密码?", - "其他登录选项": "其他登录选项", - "微信扫码登录": "微信扫码登录", - "登录": "登录", - "微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)": "微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)", - "验证码": "验证码", - "处理中...": "处理中...", - "绑定成功!": "绑定成功!", - "登录成功!": "登录成功!", - "操作失败,重定向至登录界面中...": "操作失败,重定向至登录界面中...", - "出现错误,第 ${count} 次重试中...": "出现错误,第 ${count} 次重试中...", - "无效的重置链接,请重新发起密码重置请求": "无效的重置链接,请重新发起密码重置请求", - "密码已重置并已复制到剪贴板:": "密码已重置并已复制到剪贴板:", - "密码重置确认": "密码重置确认", - "等待获取邮箱信息...": "等待获取邮箱信息...", - "新密码": "新密码", - "密码已复制到剪贴板:": "密码已复制到剪贴板:", - "密码重置完成": "密码重置完成", - "确认重置密码": "确认重置密码", - "返回登录": "返回登录", - "请输入邮箱地址": "请输入邮箱地址", - "请稍后几秒重试,Turnstile 正在检查用户环境!": "请稍后几秒重试,Turnstile 正在检查用户环境!", - "重置邮件发送成功,请检查邮箱!": "重置邮件发送成功,请检查邮箱!", - "密码重置": "密码重置", - "请输入您的邮箱地址": "请输入您的邮箱地址", - "重试": "重试", - "想起来了?": "想起来了?", - "注 册": "注 册", - "使用 用户名 注册": "使用 用户名 注册", - "已有账户?": "已有账户?", - "用户名": "用户名", - "请输入用户名": "请输入用户名", - "输入密码,最短 8 位,最长 20 位": "输入密码,最短 8 位,最长 20 位", - "确认密码": "确认密码", - "输入邮箱地址": "输入邮箱地址", - "获取验证码": "获取验证码", - "输入验证码": "输入验证码", - "或": "或", - "其他注册选项": "其他注册选项", - "加载中...": "加载中...", - "复制代码": "复制代码", - "代码已复制到剪贴板": "代码已复制到剪贴板", - "复制失败,请手动复制": "复制失败,请手动复制", - "显示更多": "显示更多", - "关于我们": "关于我们", - "关于项目": "关于项目", - "联系我们": "联系我们", - "功能特性": "功能特性", - "快速开始": "快速开始", - "安装指南": "安装指南", - "API 文档": "API 文档", - "基于New API的项目": "基于New API的项目", - "版权所有": "版权所有", - "设计与开发由": "设计与开发由", - "首页": "首页", - "控制台": "控制台", - "文档": "文档", - "关于": "关于", - "注销成功!": "注销成功!", - "个人设置": "个人设置", - "API令牌": "API令牌", - "退出": "退出", - "关闭侧边栏": "关闭侧边栏", - "打开侧边栏": "打开侧边栏", - "关闭菜单": "关闭菜单", - "打开菜单": "打开菜单", - "演示站点": "演示站点", - "自用模式": "自用模式", - "系统公告": "系统公告", - "切换主题": "切换主题", - "切换语言": "切换语言", - "暂无公告": "暂无公告", - "暂无系统公告": "暂无系统公告", - "今日关闭": "今日关闭", - "关闭公告": "关闭公告", - "数据看板": "数据看板", - "绘图日志": "绘图日志", - "任务日志": "任务日志", - "渠道": "渠道", - "兑换码": "兑换码", - "用户管理": "用户管理", - "操练场": "操练场", - "聊天": "聊天", - "管理员": "管理员", - "个人中心": "个人中心", - "展开侧边栏": "展开侧边栏", - "AI 对话": "AI 对话", - "选择模型开始对话": "选择模型开始对话", - "显示调试": "显示调试", - "请输入您的问题...": "请输入您的问题...", - "已复制到剪贴板": "已复制到剪贴板", - "复制失败": "复制失败", - "正在构造请求体预览...": "正在构造请求体预览...", - "暂无请求数据": "暂无请求数据", - "暂无响应数据": "暂无响应数据", - "内容较大,已启用性能优化模式": "内容较大,已启用性能优化模式", - "内容较大,部分功能可能受限": "内容较大,部分功能可能受限", - "已复制": "已复制", - "正在处理大内容...": "正在处理大内容...", - "显示完整内容": "显示完整内容", - "收起": "收起", - "配置已导出到下载文件夹": "配置已导出到下载文件夹", - "导出配置失败: ": "导出配置失败: ", - "确认导入配置": "确认导入配置", - "导入的配置将覆盖当前设置,是否继续?": "导入的配置将覆盖当前设置,是否继续?", - "取消": "取消", - "配置导入成功": "配置导入成功", - "导入配置失败: ": "导入配置失败: ", - "重置配置": "重置配置", - "将清除所有保存的配置并恢复默认设置,此操作不可撤销。是否继续?": "将清除所有保存的配置并恢复默认设置,此操作不可撤销。是否继续?", - "重置选项": "重置选项", - "是否同时重置对话消息?选择\"是\"将清空所有对话记录并恢复默认示例;选择\"否\"将保留当前对话记录。": "是否同时重置对话消息?选择\"是\"将清空所有对话记录并恢复默认示例;选择\"否\"将保留当前对话记录。", - "同时重置消息": "同时重置消息", - "仅重置配置": "仅重置配置", - "配置和消息已全部重置": "配置和消息已全部重置", - "配置已重置,对话消息已保留": "配置已重置,对话消息已保留", - "已有保存的配置": "已有保存的配置", - "暂无保存的配置": "暂无保存的配置", - "导出配置": "导出配置", - "导入配置": "导入配置", - "导出": "导出", - "导入": "导入", - "调试信息": "调试信息", - "预览请求体": "预览请求体", - "实际请求体": "实际请求体", - "预览更新": "预览更新", - "最后请求": "最后请求", - "操作暂时被禁用": "操作暂时被禁用", - "复制": "复制", - "编辑": "编辑", - "切换为System角色": "切换为System角色", - "切换为Assistant角色": "切换为Assistant角色", - "删除": "删除", - "请求发生错误": "请求发生错误", - "系统消息": "系统消息", - "请输入消息内容...": "请输入消息内容...", - "保存": "保存", - "模型配置": "模型配置", - "分组": "分组", - "请选择分组": "请选择分组", - "请选择模型": "请选择模型", - "思考中...": "思考中...", - "思考过程": "思考过程", - "选择同步渠道": "选择同步渠道", - "搜索渠道名称或地址": "搜索渠道名称或地址", - "暂无渠道": "暂无渠道", - "暂无选择": "暂无选择", - "无搜索结果": "无搜索结果", - "公告已更新": "公告已更新", - "公告更新失败": "公告更新失败", - "系统名称已更新": "系统名称已更新", - "系统名称更新失败": "系统名称更新失败", - "系统信息": "系统信息", - "当前版本": "当前版本", - "检查更新": "检查更新", - "启动时间": "启动时间", - "通用设置": "通用设置", - "设置公告": "设置公告", - "个性化设置": "个性化设置", - "系统名称": "系统名称", - "在此输入系统名称": "在此输入系统名称", - "设置系统名称": "设置系统名称", - "Logo 图片地址": "Logo 图片地址", - "在此输入 Logo 图片地址": "在此输入 Logo 图片地址", - "首页内容": "首页内容", - "设置首页内容": "设置首页内容", - "设置关于": "设置关于", - "页脚": "页脚", - "设置页脚": "设置页脚", - "详情": "详情", - "刷新失败": "刷新失败", - "令牌已重置并已复制到剪贴板": "令牌已重置并已复制到剪贴板", - "加载模型列表失败": "加载模型列表失败", - "系统令牌已复制到剪切板": "系统令牌已复制到剪切板", - "请输入你的账户名以确认删除!": "请输入你的账户名以确认删除!", - "账户已删除!": "账户已删除!", - "微信账户绑定成功!": "微信账户绑定成功!", - "请输入原密码!": "请输入原密码!", - "请输入新密码!": "请输入新密码!", - "新密码需要和原密码不一致!": "新密码需要和原密码不一致!", - "两次输入的密码不一致!": "两次输入的密码不一致!", - "密码修改成功!": "密码修改成功!", - "验证码发送成功,请检查邮箱!": "验证码发送成功,请检查邮箱!", - "请输入邮箱验证码!": "请输入邮箱验证码!", - "邮箱账户绑定成功!": "邮箱账户绑定成功!", - "无法复制到剪贴板,请手动复制": "无法复制到剪贴板,请手动复制", - "设置保存成功": "设置保存成功", - "设置保存失败": "设置保存失败", - "超级管理员": "超级管理员", - "普通用户": "普通用户", - "当前余额": "当前余额", - "历史消耗": "历史消耗", - "请求次数": "请求次数", - "默认": "默认", - "可用模型": "可用模型", - "模型列表": "模型列表", - "点击模型名称可复制": "点击模型名称可复制", - "没有可用模型": "没有可用模型", - "该分类下没有可用模型": "该分类下没有可用模型", - "更多": "更多", - "个模型": "个模型", - "账户绑定": "账户绑定", - "未绑定": "未绑定", - "修改绑定": "修改绑定", - "微信": "微信", - "已绑定": "已绑定", - "未启用": "未启用", - "绑定": "绑定", - "安全设置": "安全设置", - "系统访问令牌": "系统访问令牌", - "用于API调用的身份验证令牌,请妥善保管": "用于API调用的身份验证令牌,请妥善保管", - "生成令牌": "生成令牌", - "密码管理": "密码管理", - "定期更改密码可以提高账户安全性": "定期更改密码可以提高账户安全性", - "修改密码": "修改密码", - "此操作不可逆,所有数据将被永久删除": "此操作不可逆,所有数据将被永久删除", - "删除账户": "删除账户", - "其他设置": "其他设置", - "通知设置": "通知设置", - "邮件通知": "邮件通知", - "通过邮件接收通知": "通过邮件接收通知", - "Webhook通知": "Webhook通知", - "通过HTTP请求接收通知": "通过HTTP请求接收通知", - "请输入Webhook地址,例如: https://example.com/webhook": "请输入Webhook地址,例如: https://example.com/webhook", - "只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求": "只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求", - "接口凭证(可选)": "接口凭证(可选)", - "请输入密钥": "请输入密钥", - "密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性": "密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性", - "通知邮箱": "通知邮箱", - "留空则使用账号绑定的邮箱": "留空则使用账号绑定的邮箱", - "设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱": "设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱", - "额度预警阈值": "额度预警阈值", - "请输入预警额度": "请输入预警额度", - "当剩余额度低于此数值时,系统将通过选择的方式发送通知": "当剩余额度低于此数值时,系统将通过选择的方式发送通知", - "接受未设置价格模型": "接受未设置价格模型", - "当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用": "当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用", - "IP记录": "IP记录", - "记录请求与错误日志 IP": "记录请求与错误日志 IP", - "开启后,仅“消费”和“错误”日志将记录您的客户端 IP 地址": "开启后,仅“消费”和“错误”日志将记录您的客户端 IP 地址", - "绑定邮箱地址": "绑定邮箱地址", - "重新发送": "重新发送", - "绑定微信账户": "绑定微信账户", - "删除账户确认": "删除账户确认", - "您正在删除自己的帐户,将清空所有数据且不可恢复": "您正在删除自己的帐户,将清空所有数据且不可恢复", - "请输入您的用户名以确认删除": "请输入您的用户名以确认删除", - "输入你的账户名{{username}}以确认删除": "输入你的账户名{{username}}以确认删除", - "原密码": "原密码", - "请输入原密码": "请输入原密码", - "请输入新密码": "请输入新密码", - "确认新密码": "确认新密码", - "请再次输入新密码": "请再次输入新密码", - "模型倍率设置": "模型倍率设置", - "可视化倍率设置": "可视化倍率设置", - "未设置倍率模型": "未设置倍率模型", - "上游倍率同步": "上游倍率同步", - "未知类型": "未知类型", - "标签聚合": "标签聚合", - "已启用": "已启用", - "自动禁用": "自动禁用", - "未知状态": "未知状态", - "未测试": "未测试", - "名称": "名称", - "类型": "类型", - "状态": "状态", - ",时间:": ",时间:", - "响应时间": "响应时间", - "已用/剩余": "已用/剩余", - "剩余额度$": "剩余额度$", - ",点击更新": ",点击更新", - "已用额度": "已用额度", - "修改子渠道优先级": "修改子渠道优先级", - "确定要修改所有子渠道优先级为 ": "确定要修改所有子渠道优先级为 ", - "权重": "权重", - "修改子渠道权重": "修改子渠道权重", - "确定要修改所有子渠道权重为 ": "确定要修改所有子渠道权重为 ", - "确定是否要删除此渠道?": "确定是否要删除此渠道?", - "此修改将不可逆": "此修改将不可逆", - "确定是否要复制此渠道?": "确定是否要复制此渠道?", - "复制渠道的所有信息": "复制渠道的所有信息", - "测试单个渠道操作项目组": "测试单个渠道操作项目组", - "禁用": "禁用", - "启用": "启用", - "启用全部": "启用全部", - "禁用全部": "禁用全部", - "重置": "重置", - "全选": "全选", - "_复制": "_复制", - "渠道未找到,请刷新页面后重试。": "渠道未找到,请刷新页面后重试。", - "渠道复制成功": "渠道复制成功", - "渠道复制失败: ": "渠道复制失败: ", - "操作成功完成!": "操作成功完成!", - "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。", - "已停止测试": "已停止测试", - "全部": "全部", - "请先选择要设置标签的渠道!": "请先选择要设置标签的渠道!", - "标签不能为空!": "标签不能为空!", - "已为 ${count} 个渠道设置标签!": "已为 ${count} 个渠道设置标签!", - "已成功开始测试所有已启用通道,请刷新页面查看结果。": "已成功开始测试所有已启用通道,请刷新页面查看结果。", - "已删除所有禁用渠道,共计 ${data} 个": "已删除所有禁用渠道,共计 ${data} 个", - "已更新完毕所有已启用通道余额!": "已更新完毕所有已启用通道余额!", - "通道 ${name} 余额更新成功!": "通道 ${name} 余额更新成功!", - "已删除 ${data} 个通道!": "已删除 ${data} 个通道!", - "已修复 ${data} 个通道!": "已修复 ${data} 个通道!", - "确定是否要删除所选通道?": "确定是否要删除所选通道?", - "删除所选通道": "删除所选通道", - "批量设置标签": "批量设置标签", - "确定要测试所有通道吗?": "确定要测试所有通道吗?", - "测试所有通道": "测试所有通道", - "确定要更新所有已启用通道余额吗?": "确定要更新所有已启用通道余额吗?", - "更新所有已启用通道余额": "更新所有已启用通道余额", - "确定是否要删除禁用通道?": "确定是否要删除禁用通道?", - "删除禁用通道": "删除禁用通道", - "确定是否要修复数据库一致性?": "确定是否要修复数据库一致性?", - "进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用", - "批量操作": "批量操作", - "使用ID排序": "使用ID排序", - "开启批量操作": "开启批量操作", - "标签聚合模式": "标签聚合模式", - "刷新": "刷新", - "列设置": "列设置", - "搜索渠道的 ID,名称,密钥和API地址 ...": "搜索渠道的 ID,名称,密钥和API地址 ...", - "模型关键字": "模型关键字", - "选择分组": "选择分组", - "查询": "查询", - "第 {{start}} - {{end}} 条,共 {{total}} 条": "第 {{start}} - {{end}} 条,共 {{total}} 条", - "搜索无结果": "搜索无结果", - "请输入要设置的标签名称": "请输入要设置的标签名称", - "请输入标签名称": "请输入标签名称", - "已选择 ${count} 个渠道": "已选择 ${count} 个渠道", - "共": "共", - "停止测试": "停止测试", - "测试中...": "测试中...", - "批量测试${count}个模型": "批量测试${count}个模型", - "搜索模型...": "搜索模型...", - "模型名称": "模型名称", - "测试中": "测试中", - "未开始": "未开始", - "失败": "失败", - "请求时长: ${time}s": "请求时长: ${time}s", - "充值": "充值", - "消费": "消费", - "系统": "系统", - "错误": "错误", - "流": "流", - "非流": "非流", - "请求并计费模型": "请求并计费模型", - "实际模型": "实际模型", - "用户": "用户", - "用时/首字": "用时/首字", - "提示": "提示", - "花费": "花费", - "只有当用户设置开启IP记录时,才会进行请求和错误类型日志的IP记录": "只有当用户设置开启IP记录时,才会进行请求和错误类型日志的IP记录", - "确定": "确定", - "用户信息": "用户信息", - "渠道信息": "渠道信息", - "语音输入": "语音输入", - "文字输入": "文字输入", - "文字输出": "文字输出", - "缓存创建 Tokens": "缓存创建 Tokens", - "日志详情": "日志详情", - "消耗额度": "消耗额度", - "开始时间": "开始时间", - "结束时间": "结束时间", - "用户名称": "用户名称", - "日志类型": "日志类型", - "绘图": "绘图", - "放大": "放大", - "变换": "变换", - "强变换": "强变换", - "平移": "平移", - "图生文": "图生文", - "图混合": "图混合", - "重绘": "重绘", - "局部重绘-提交": "局部重绘-提交", - "自定义变焦-提交": "自定义变焦-提交", - "窗口处理": "窗口处理", - "未知": "未知", - "已提交": "已提交", - "等待中": "等待中", - "重复提交": "重复提交", - "成功": "成功", - "未启动": "未启动", - "执行中": "执行中", - "窗口等待": "窗口等待", - "秒": "秒", - "提交时间": "提交时间", - "花费时间": "花费时间", - "任务ID": "任务ID", - "提交结果": "提交结果", - "任务状态": "任务状态", - "结果图片": "结果图片", - "查看图片": "查看图片", - "无": "无", - "失败原因": "失败原因", - "已复制:": "已复制:", - "当前未开启Midjourney回调,部分项目可能无法获得绘图结果,可在运营设置中开启。": "当前未开启Midjourney回调,部分项目可能无法获得绘图结果,可在运营设置中开启。", - "Midjourney 任务记录": "Midjourney 任务记录", - "任务 ID": "任务 ID", - "按次计费": "按次计费", - "按量计费": "按量计费", - "您的分组可以使用该模型": "您的分组可以使用该模型", - "可用性": "可用性", - "计费类型": "计费类型", - "当前查看的分组为:{{group}},倍率为:{{ratio}}": "当前查看的分组为:{{group}},倍率为:{{ratio}}", - "倍率": "倍率", - "倍率是为了方便换算不同价格的模型": "倍率是为了方便换算不同价格的模型", - "模型倍率": "模型倍率", - "补全倍率": "补全倍率", - "分组倍率": "分组倍率", - "模型价格": "模型价格", - "补全": "补全", - "模糊搜索模型名称": "模糊搜索模型名称", - "复制选中模型": "复制选中模型", - "模型定价": "模型定价", - "当前分组": "当前分组", - "未登录,使用默认分组倍率": "未登录,使用默认分组倍率", - "按量计费费用 = 分组倍率 × 模型倍率 × (提示token数 + 补全token数 × 补全倍率)/ 500000 (单位:美元)": "按量计费费用 = 分组倍率 × 模型倍率 × (提示token数 + 补全token数 × 补全倍率)/ 500000 (单位:美元)", - "已过期": "已过期", - "未使用": "未使用", - "已禁用": "已禁用", - "创建时间": "创建时间", - "过期时间": "过期时间", - "永不过期": "永不过期", - "确定是否要删除此兑换码?": "确定是否要删除此兑换码?", - "查看": "查看", - "已复制到剪贴板!": "已复制到剪贴板!", - "兑换码可以批量生成和分发,适合用于推广活动或批量充值。": "兑换码可以批量生成和分发,适合用于推广活动或批量充值。", - "添加兑换码": "添加兑换码", - "请至少选择一个兑换码!": "请至少选择一个兑换码!", - "复制所选兑换码到剪贴板": "复制所选兑换码到剪贴板", - "确定清除所有失效兑换码?": "确定清除所有失效兑换码?", - "将删除已使用、已禁用及过期的兑换码,此操作不可撤销。": "将删除已使用、已禁用及过期的兑换码,此操作不可撤销。", - "已删除 {{count}} 条失效兑换码": "已删除 {{count}} 条失效兑换码", - "关键字(id或者名称)": "关键字(id或者名称)", - "生成音乐": "生成音乐", - "生成歌词": "生成歌词", - "生成视频": "生成视频", - "排队中": "排队中", - "正在提交": "正在提交", - "平台": "平台", - "点击预览视频": "点击预览视频", - "任务记录": "任务记录", - "渠道 ID": "渠道 ID", - "已启用:限制模型": "已启用:限制模型", - "已耗尽": "已耗尽", - "剩余额度": "剩余额度", - "聊天链接配置错误,请联系管理员": "聊天链接配置错误,请联系管理员", - "令牌详情": "令牌详情", - "确定是否要删除此令牌?": "确定是否要删除此令牌?", - "项目操作按钮组": "项目操作按钮组", - "请联系管理员配置聊天链接": "请联系管理员配置聊天链接", - "令牌用于API访问认证,可以设置额度限制和模型权限。": "令牌用于API访问认证,可以设置额度限制和模型权限。", - "添加令牌": "添加令牌", - "请至少选择一个令牌!": "请至少选择一个令牌!", - "复制所选令牌到剪贴板": "复制所选令牌到剪贴板", - "搜索关键字": "搜索关键字", - "未知身份": "未知身份", - "已封禁": "已封禁", - "统计信息": "统计信息", - "剩余": "剩余", - "调用": "调用", - "邀请信息": "邀请信息", - "收益": "收益", - "无邀请人": "无邀请人", - "已注销": "已注销", - "确定要提升此用户吗?": "确定要提升此用户吗?", - "此操作将提升用户的权限级别": "此操作将提升用户的权限级别", - "确定要降级此用户吗?": "确定要降级此用户吗?", - "此操作将降低用户的权限级别": "此操作将降低用户的权限级别", - "确定是否要注销此用户?": "确定是否要注销此用户?", - "相当于删除用户,此修改将不可逆": "相当于删除用户,此修改将不可逆", - "用户管理页面,可以查看和管理所有注册用户的信息、权限和状态。": "用户管理页面,可以查看和管理所有注册用户的信息、权限和状态。", - "添加用户": "添加用户", - "支持搜索用户的 ID、用户名、显示名称和邮箱地址": "支持搜索用户的 ID、用户名、显示名称和邮箱地址", - "全部模型": "全部模型", - "智谱": "智谱", - "通义千问": "通义千问", - "文心一言": "文心一言", - "腾讯混元": "腾讯混元", - "360智脑": "360智脑", - "豆包": "豆包", - "用户分组": "用户分组", - "专属倍率": "专属倍率", - "输入价格:${{price}} / 1M tokens{{audioPrice}}": "输入价格:${{price}} / 1M tokens{{audioPrice}}", - "Web搜索价格:${{price}} / 1K 次": "Web搜索价格:${{price}} / 1K 次", - "文件搜索价格:${{price}} / 1K 次": "文件搜索价格:${{price}} / 1K 次", - "仅供参考,以实际扣费为准": "仅供参考,以实际扣费为准", - "价格:${{price}} * {{ratioType}}:{{ratio}}": "价格:${{price}} * {{ratioType}}:{{ratio}}", - "模型: {{ratio}} * {{ratioType}}:{{groupRatio}}": "模型: {{ratio}} * {{ratioType}}:{{groupRatio}}", - "提示价格:${{price}} / 1M tokens": "提示价格:${{price}} / 1M tokens", - "模型价格 ${{price}},{{ratioType}} {{ratio}}": "模型价格 ${{price}},{{ratioType}} {{ratio}}", - "模型: {{ratio}} * {{ratioType}}: {{groupRatio}}": "模型: {{ratio}} * {{ratioType}}: {{groupRatio}}", - "不是合法的 JSON 字符串": "不是合法的 JSON 字符串", - "请求发生错误: ": "请求发生错误: ", - "解析响应数据时发生错误": "解析响应数据时发生错误", - "连接已断开": "连接已断开", - "建立连接时发生错误": "建立连接时发生错误", - "加载模型失败": "加载模型失败", - "加载分组失败": "加载分组失败", - "消息已复制到剪贴板": "消息已复制到剪贴板", - "确认删除": "确认删除", - "确定要删除这条消息吗?": "确定要删除这条消息吗?", - "已删除消息及其回复": "已删除消息及其回复", - "消息已删除": "消息已删除", - "消息已编辑": "消息已编辑", - "检测到该消息后有AI回复,是否删除后续回复并重新生成?": "检测到该消息后有AI回复,是否删除后续回复并重新生成?", - "重新生成": "重新生成", - "消息已更新": "消息已更新", - "加载关于内容失败...": "加载关于内容失败...", - "可在设置页面设置关于内容,支持 HTML & Markdown": "可在设置页面设置关于内容,支持 HTML & Markdown", - "New API项目仓库地址:": "New API项目仓库地址:", - "| 基于": "| 基于", - "本项目根据": "本项目根据", - "MIT许可证": "MIT许可证", - "授权,需在遵守": "授权,需在遵守", - "Apache-2.0协议": "Apache-2.0协议", - "管理员暂时未设置任何关于内容": "管理员暂时未设置任何关于内容", - "仅支持 OpenAI 接口格式": "仅支持 OpenAI 接口格式", - "请填写密钥": "请填写密钥", - "获取模型列表成功": "获取模型列表成功", - "获取模型列表失败": "获取模型列表失败", - "请填写渠道名称和渠道密钥!": "请填写渠道名称和渠道密钥!", - "请至少选择一个模型!": "请至少选择一个模型!", - "模型映射必须是合法的 JSON 格式!": "模型映射必须是合法的 JSON 格式!", - "提交失败,请勿重复提交!": "提交失败,请勿重复提交!", - "渠道创建成功!": "渠道创建成功!", - "已新增 {{count}} 个模型:{{list}}": "已新增 {{count}} 个模型:{{list}}", - "未发现新增模型": "未发现新增模型", - "新建": "新建", - "更新渠道信息": "更新渠道信息", - "创建新的渠道": "创建新的渠道", - "基本信息": "基本信息", - "渠道的基本配置信息": "渠道的基本配置信息", - "请选择渠道类型": "请选择渠道类型", - "请为渠道命名": "请为渠道命名", - "请输入密钥,一行一个": "请输入密钥,一行一个", - "批量创建": "批量创建", - "API 配置": "API 配置", - "API 地址和相关配置": "API 地址和相关配置", - "2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的\".\"": "2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的\".\"", - "请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com": "请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com", - "请输入默认 API 版本,例如:2025-04-01-preview": "请输入默认 API 版本,例如:2025-04-01-preview", - "如果你对接的是上游One API或者New API等转发项目,请使用OpenAI类型,不要使用此类型,除非你知道你在做什么。": "如果你对接的是上游One API或者New API等转发项目,请使用OpenAI类型,不要使用此类型,除非你知道你在做什么。", - "完整的 Base URL,支持变量{model}": "完整的 Base URL,支持变量{model}", - "请输入完整的URL,例如:https://api.openai.com/v1/chat/completions": "请输入完整的URL,例如:https://api.openai.com/v1/chat/completions", - "Dify渠道只适配chatflow和agent,并且agent不支持图片!": "Dify渠道只适配chatflow和agent,并且agent不支持图片!", - "此项可选,用于通过自定义API地址来进行 API 调用,末尾不要带/v1和/": "此项可选,用于通过自定义API地址来进行 API 调用,末尾不要带/v1和/", - "对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写", - "私有部署地址": "私有部署地址", - "请输入私有部署地址,格式为:https://fastgpt.run/api/openapi": "请输入私有部署地址,格式为:https://fastgpt.run/api/openapi", - "注意非Chat API,请务必填写正确的API地址,否则可能导致无法使用": "注意非Chat API,请务必填写正确的API地址,否则可能导致无法使用", - "请输入到 /suno 前的路径,通常就是域名,例如:https://api.example.com": "请输入到 /suno 前的路径,通常就是域名,例如:https://api.example.com", - "模型选择和映射设置": "模型选择和映射设置", - "模型": "模型", - "请选择该渠道所支持的模型": "请选择该渠道所支持的模型", - "填入相关模型": "填入相关模型", - "填入所有模型": "填入所有模型", - "获取模型列表": "获取模型列表", - "清除所有模型": "清除所有模型", - "输入自定义模型名称": "输入自定义模型名称", - "模型重定向": "模型重定向", - "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:", - "填入模板": "填入模板", - "默认测试模型": "默认测试模型", - "不填则为模型列表第一个": "不填则为模型列表第一个", - "渠道的高级配置选项": "渠道的高级配置选项", - "请选择可以使用该渠道的分组": "请选择可以使用该渠道的分组", - "请在系统设置页面编辑分组倍率以添加新的分组:": "请在系统设置页面编辑分组倍率以添加新的分组:", - "部署地区": "部署地区", - "知识库 ID": "知识库 ID", - "渠道标签": "渠道标签", - "渠道优先级": "渠道优先级", - "渠道权重": "渠道权重", - "渠道额外设置": "渠道额外设置", - "此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:": "此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:", - "参数覆盖": "参数覆盖", - "此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:": "此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:", - "请输入组织org-xxx": "请输入组织org-xxx", - "组织,可选,不填则为默认组织": "组织,可选,不填则为默认组织", - "是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道": "是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道", - "状态码复写(仅影响本地判断,不修改返回到上游的状态码)": "状态码复写(仅影响本地判断,不修改返回到上游的状态码)", - "此项可选,用于复写返回的状态码,比如将claude渠道的400错误复写为500(用于重试),请勿滥用该功能,例如:": "此项可选,用于复写返回的状态码,比如将claude渠道的400错误复写为500(用于重试),请勿滥用该功能,例如:", - "编辑标签": "编辑标签", - "标签信息": "标签信息", - "标签的基本配置": "标签的基本配置", - "所有编辑均为覆盖操作,留空则不更改": "所有编辑均为覆盖操作,留空则不更改", - "标签名称": "标签名称", - "请输入新标签,留空则解散标签": "请输入新标签,留空则解散标签", - "当前模型列表为该标签下所有渠道模型列表最长的一个,并非所有渠道的并集,请注意可能导致某些渠道模型丢失。": "当前模型列表为该标签下所有渠道模型列表最长的一个,并非所有渠道的并集,请注意可能导致某些渠道模型丢失。", - "请选择该渠道所支持的模型,留空则不更改": "请选择该渠道所支持的模型,留空则不更改", - "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,留空则不更改": "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,留空则不更改", - "清空重定向": "清空重定向", - "分组设置": "分组设置", - "用户分组配置": "用户分组配置", - "请选择可以使用该渠道的分组,留空则不更改": "请选择可以使用该渠道的分组,留空则不更改", - "正在跳转...": "正在跳转...", - "小时": "小时", - "周": "周", - "模型调用次数占比": "模型调用次数占比", - "模型消耗分布": "模型消耗分布", - "总计": "总计", - "早上好": "早上好", - "中午好": "中午好", - "下午好": "下午好", - "账户数据": "账户数据", - "使用统计": "使用统计", - "统计次数": "统计次数", - "资源消耗": "资源消耗", - "统计额度": "统计额度", - "性能指标": "性能指标", - "平均RPM": "平均RPM", - "复制成功": "复制成功", - "进行中": "进行中", - "异常": "异常", - "正常": "正常", - "可用率": "可用率", - "有异常": "有异常", - "高延迟": "高延迟", - "维护中": "维护中", - "暂无监控数据": "暂无监控数据", - "搜索条件": "搜索条件", - "时间粒度": "时间粒度", - "模型数据分析": "模型数据分析", - "消耗分布": "消耗分布", - "调用次数分布": "调用次数分布", - "API信息": "API信息", - "暂无API信息": "暂无API信息", - "请联系管理员在系统设置中配置API信息": "请联系管理员在系统设置中配置API信息", - "显示最新20条": "显示最新20条", - "请联系管理员在系统设置中配置公告信息": "请联系管理员在系统设置中配置公告信息", - "暂无常见问答": "暂无常见问答", - "请联系管理员在系统设置中配置常见问答": "请联系管理员在系统设置中配置常见问答", - "服务可用性": "服务可用性", - "请联系管理员在系统设置中配置Uptime": "请联系管理员在系统设置中配置Uptime", - "加载首页内容失败...": "加载首页内容失败...", - "统一的大模型接口网关": "统一的大模型接口网关", - "更好的价格,更好的稳定性,无需订阅": "更好的价格,更好的稳定性,无需订阅", - "开始使用": "开始使用", - "支持众多的大模型供应商": "支持众多的大模型供应商", - "页面未找到,请检查您的浏览器地址是否正确": "页面未找到,请检查您的浏览器地址是否正确", - "登录过期,请重新登录!": "登录过期,请重新登录!", - "兑换码更新成功!": "兑换码更新成功!", - "兑换码创建成功!": "兑换码创建成功!", - "兑换码创建成功": "兑换码创建成功", - "兑换码创建成功,是否下载兑换码?": "兑换码创建成功,是否下载兑换码?", - "兑换码将以文本文件的形式下载,文件名为兑换码的名称。": "兑换码将以文本文件的形式下载,文件名为兑换码的名称。", - "更新兑换码信息": "更新兑换码信息", - "创建新的兑换码": "创建新的兑换码", - "设置兑换码的基本信息": "设置兑换码的基本信息", - "请输入名称": "请输入名称", - "选择过期时间(可选,留空为永久)": "选择过期时间(可选,留空为永久)", - "额度设置": "额度设置", - "设置兑换码的额度和数量": "设置兑换码的额度和数量", - "请输入额度": "请输入额度", - "生成数量": "生成数量", - "请输入生成数量": "请输入生成数量", - "你似乎并没有修改什么": "你似乎并没有修改什么", - "部分保存失败,请重试": "部分保存失败,请重试", - "保存成功": "保存成功", - "保存失败,请重试": "保存失败,请重试", - "请检查输入": "请检查输入", - "聊天配置": "聊天配置", - "为一个 JSON 文本": "为一个 JSON 文本", - "保存聊天设置": "保存聊天设置", - "设置已保存": "设置已保存", - "API地址": "API地址", - "说明": "说明", - "颜色": "颜色", - "API信息管理,可以配置多个API地址用于状态展示和负载均衡(最多50个)": "API信息管理,可以配置多个API地址用于状态展示和负载均衡(最多50个)", - "批量删除": "批量删除", - "保存设置": "保存设置", - "添加API": "添加API", - "请输入API地址": "请输入API地址", - "如:香港线路": "如:香港线路", - "请输入线路描述": "请输入线路描述", - "如:大带宽批量分析图片推荐": "如:大带宽批量分析图片推荐", - "请输入说明": "请输入说明", - "标识颜色": "标识颜色", - "确定要删除此API信息吗?": "确定要删除此API信息吗?", - "警告": "警告", - "发布时间": "发布时间", - "操作": "操作", - "系统公告管理,可以发布系统通知和重要消息(最多100个,前端显示最新20条)": "系统公告管理,可以发布系统通知和重要消息(最多100个,前端显示最新20条)", - "添加公告": "添加公告", - "编辑公告": "编辑公告", - "公告内容": "公告内容", - "请输入公告内容": "请输入公告内容", - "请选择发布日期": "请选择发布日期", - "公告类型": "公告类型", - "说明信息": "说明信息", - "可选,公告的补充说明": "可选,公告的补充说明", - "确定要删除此公告吗?": "确定要删除此公告吗?", - "数据看板设置": "数据看板设置", - "启用数据看板(实验性)": "启用数据看板(实验性)", - "数据看板更新间隔": "数据看板更新间隔", - "设置过短会影响数据库性能": "设置过短会影响数据库性能", - "数据看板默认时间粒度": "数据看板默认时间粒度", - "仅修改展示粒度,统计精确到小时": "仅修改展示粒度,统计精确到小时", - "保存数据看板设置": "保存数据看板设置", - "问题标题": "问题标题", - "回答内容": "回答内容", - "常见问答管理,为用户提供常见问题的答案(最多50个,前端显示最新20条)": "常见问答管理,为用户提供常见问题的答案(最多50个,前端显示最新20条)", - "添加问答": "添加问答", - "编辑问答": "编辑问答", - "请输入问题标题": "请输入问题标题", - "请输入回答内容": "请输入回答内容", - "确定要删除此问答吗?": "确定要删除此问答吗?", - "分类名称": "分类名称", - "Uptime Kuma地址": "Uptime Kuma地址", - "Uptime Kuma监控分类管理,可以配置多个监控分类用于服务状态展示(最多20个)": "Uptime Kuma监控分类管理,可以配置多个监控分类用于服务状态展示(最多20个)", - "编辑分类": "编辑分类", - "添加分类": "添加分类", - "请输入分类名称,如:OpenAI、Claude等": "请输入分类名称,如:OpenAI、Claude等", - "请输入分类名称": "请输入分类名称", - "请输入Uptime Kuma服务地址,如:https://status.example.com": "请输入Uptime Kuma服务地址,如:https://status.example.com", - "请输入Uptime Kuma地址": "请输入Uptime Kuma地址", - "请输入状态页面的Slug,如:my-status": "请输入状态页面的Slug,如:my-status", - "请输入状态页面Slug": "请输入状态页面Slug", - "确定要删除此分类吗?": "确定要删除此分类吗?", - "绘图设置": "绘图设置", - "启用绘图功能": "启用绘图功能", - "允许回调(会泄露服务器 IP 地址)": "允许回调(会泄露服务器 IP 地址)", - "允许 AccountFilter 参数": "允许 AccountFilter 参数", - "开启之后会清除用户提示词中的": "开启之后会清除用户提示词中的", - "以及": "以及", - "检测必须等待绘图成功才能进行放大等操作": "检测必须等待绘图成功才能进行放大等操作", - "保存绘图设置": "保存绘图设置", - "Claude设置": "Claude设置", - "Claude请求头覆盖": "Claude请求头覆盖", - "为一个 JSON 文本,例如:": "为一个 JSON 文本,例如:", - "缺省 MaxTokens": "缺省 MaxTokens", - "启用Claude思考适配(-thinking后缀)": "启用Claude思考适配(-thinking后缀)", - "思考适配 BudgetTokens 百分比": "思考适配 BudgetTokens 百分比", - "0.1-1之间的小数": "0.1-1之间的小数", - "Gemini设置": "Gemini设置", - "Gemini安全设置": "Gemini安全设置", - "default为默认设置,可单独设置每个模型的版本": "default为默认设置,可单独设置每个模型的版本", - "例如:": "例如:", - "Gemini思考适配设置": "Gemini思考适配设置", - "启用Gemini思考后缀适配": "启用Gemini思考后缀适配", - "适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "适配 -thinking、-thinking-预算数字 和 -nothinking 后缀", - "0.002-1之间的小数": "0.002-1之间的小数", - "全局设置": "全局设置", - "启用请求透传": "启用请求透传", - "连接保活设置": "连接保活设置", - "启用Ping间隔": "启用Ping间隔", - "Ping间隔(秒)": "Ping间隔(秒)", - "新用户初始额度": "新用户初始额度", - "请求预扣费额度": "请求预扣费额度", - "请求结束后多退少补": "请求结束后多退少补", - "邀请新用户奖励额度": "邀请新用户奖励额度", - "新用户使用邀请码奖励额度": "新用户使用邀请码奖励额度", - "例如:1000": "例如:1000", - "保存额度设置": "保存额度设置", - "例如发卡网站的购买链接": "例如发卡网站的购买链接", - "文档地址": "文档地址", - "单位美元额度": "单位美元额度", - "一单位货币能兑换的额度": "一单位货币能兑换的额度", - "失败重试次数": "失败重试次数", - "以货币形式显示额度": "以货币形式显示额度", - "额度查询接口返回令牌额度而非用户额度": "额度查询接口返回令牌额度而非用户额度", - "默认折叠侧边栏": "默认折叠侧边栏", - "开启后不限制:必须设置模型倍率": "开启后不限制:必须设置模型倍率", - "保存通用设置": "保存通用设置", - "请选择日志记录时间": "请选择日志记录时间", - "条日志已清理!": "条日志已清理!", - "日志清理失败:": "日志清理失败:", - "启用额度消费日志记录": "启用额度消费日志记录", - "日志记录时间": "日志记录时间", - "清除历史日志": "清除历史日志", - "保存日志设置": "保存日志设置", - "监控设置": "监控设置", - "测试所有渠道的最长响应时间": "测试所有渠道的最长响应时间", - "额度提醒阈值": "额度提醒阈值", - "低于此额度时将发送邮件提醒用户": "低于此额度时将发送邮件提醒用户", - "失败时自动禁用通道": "失败时自动禁用通道", - "成功时自动启用通道": "成功时自动启用通道", - "自动禁用关键词": "自动禁用关键词", - "一行一个,不区分大小写": "一行一个,不区分大小写", - "屏蔽词过滤设置": "屏蔽词过滤设置", - "启用屏蔽词过滤功能": "启用屏蔽词过滤功能", - "启用 Prompt 检查": "启用 Prompt 检查", - "一行一个屏蔽词,不需要符号分割": "一行一个屏蔽词,不需要符号分割", - "保存屏蔽词过滤设置": "保存屏蔽词过滤设置", - "更新成功": "更新成功", - "更新失败": "更新失败", - "服务器地址": "服务器地址", - "更新服务器地址": "更新服务器地址", - "请先填写服务器地址": "请先填写服务器地址", - "充值分组倍率不是合法的 JSON 字符串": "充值分组倍率不是合法的 JSON 字符串", - "充值方式设置不是合法的 JSON 字符串": "充值方式设置不是合法的 JSON 字符串", - "支付设置": "支付设置", - "(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!)": "(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!)", - "例如:https://yourdomain.com": "例如:https://yourdomain.com", - "易支付商户ID": "易支付商户ID", - "易支付商户密钥": "易支付商户密钥", - "敏感信息不会发送到前端显示": "敏感信息不会发送到前端显示", - "回调地址": "回调地址", - "充值价格(x元/美金)": "充值价格(x元/美金)", - "例如:7,就是7元/美金": "例如:7,就是7元/美金", - "最低充值美元数量": "最低充值美元数量", - "例如:2,就是最低充值2$": "例如:2,就是最低充值2$", - "为一个 JSON 文本,键为组名称,值为倍率": "为一个 JSON 文本,键为组名称,值为倍率", - "充值方式设置": "充值方式设置", - "更新支付设置": "更新支付设置", - "模型请求速率限制": "模型请求速率限制", - "启用用户模型请求速率限制(可能会影响高并发性能)": "启用用户模型请求速率限制(可能会影响高并发性能)", - "分钟": "分钟", - "频率限制的周期(分钟)": "频率限制的周期(分钟)", - "用户每周期最多请求次数": "用户每周期最多请求次数", - "包括失败请求的次数,0代表不限制": "包括失败请求的次数,0代表不限制", - "用户每周期最多请求完成次数": "用户每周期最多请求完成次数", - "只包括请求成功的次数": "只包括请求成功的次数", - "分组速率限制": "分组速率限制", - "使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}": "使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}", - "示例:{\"default\": [200, 100], \"vip\": [0, 1000]}。": "示例:{\"default\": [200, 100], \"vip\": [0, 1000]}。", - "[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。": "[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。", - "分组速率配置优先级高于全局速率限制。": "分组速率配置优先级高于全局速率限制。", - "限制周期统一使用上方配置的“限制周期”值。": "限制周期统一使用上方配置的“限制周期”值。", - "保存模型速率限制": "保存模型速率限制", - "保存失败": "保存失败", - "为一个 JSON 文本,键为分组名称,值为倍率": "为一个 JSON 文本,键为分组名称,值为倍率", - "用户可选分组": "用户可选分组", - "为一个 JSON 文本,键为分组名称,值为分组描述": "为一个 JSON 文本,键为分组名称,值为分组描述", - "自动分组auto,从第一个开始选择": "自动分组auto,从第一个开始选择", - "必须是有效的 JSON 字符串数组,例如:[\"g1\",\"g2\"]": "必须是有效的 JSON 字符串数组,例如:[\"g1\",\"g2\"]", - "模型固定价格": "模型固定价格", - "一次调用消耗多少刀,优先级大于模型倍率": "一次调用消耗多少刀,优先级大于模型倍率", - "为一个 JSON 文本,键为模型名称,值为倍率": "为一个 JSON 文本,键为模型名称,值为倍率", - "模型补全倍率(仅对自定义模型有效)": "模型补全倍率(仅对自定义模型有效)", - "仅对自定义模型有效": "仅对自定义模型有效", - "保存模型倍率设置": "保存模型倍率设置", - "确定重置模型倍率吗?": "确定重置模型倍率吗?", - "重置模型倍率": "重置模型倍率", - "获取启用模型失败:": "获取启用模型失败:", - "获取启用模型失败": "获取启用模型失败", - "JSON解析错误:": "JSON解析错误:", - "保存失败:": "保存失败:", - "输入模型倍率": "输入模型倍率", - "输入补全倍率": "输入补全倍率", - "请输入数字": "请输入数字", - "模型名称已存在": "模型名称已存在", - "请先选择需要批量设置的模型": "请先选择需要批量设置的模型", - "请输入模型倍率和补全倍率": "请输入模型倍率和补全倍率", - "请输入有效的数字": "请输入有效的数字", - "请输入填充值": "请输入填充值", - "批量设置成功": "批量设置成功", - "已为 {{count}} 个模型设置{{type}}": "已为 {{count}} 个模型设置{{type}}", - "模型倍率和补全倍率": "模型倍率和补全倍率", - "添加模型": "添加模型", - "批量设置": "批量设置", - "应用更改": "应用更改", - "搜索模型名称": "搜索模型名称", - "此页面仅显示未设置价格或倍率的模型,设置后将自动从列表中移除": "此页面仅显示未设置价格或倍率的模型,设置后将自动从列表中移除", - "定价模式": "定价模式", - "固定价格": "固定价格", - "固定价格(每次)": "固定价格(每次)", - "输入每次价格": "输入每次价格", - "输入补全价格": "输入补全价格", - "批量设置模型参数": "批量设置模型参数", - "设置类型": "设置类型", - "模型倍率和补全倍率同时设置": "模型倍率和补全倍率同时设置", - "模型倍率值": "模型倍率值", - "请输入模型倍率": "请输入模型倍率", - "补全倍率值": "补全倍率值", - "请输入补全倍率": "请输入补全倍率", - "请输入数值": "请输入数值", - "将为选中的 ": "将为选中的 ", - " 个模型设置相同的值": " 个模型设置相同的值", - "当前设置类型: ": "当前设置类型: ", - "默认补全倍率": "默认补全倍率", - "添加成功": "添加成功", - "价格设置方式": "价格设置方式", - "按倍率设置": "按倍率设置", - "按价格设置": "按价格设置", - "输入价格": "输入价格", - "输出价格": "输出价格", - "获取渠道失败:": "获取渠道失败:", - "请至少选择一个渠道": "请至少选择一个渠道", - "后端请求失败": "后端请求失败", - "部分渠道测试失败:": "部分渠道测试失败:", - "未找到差异化倍率,无需同步": "未找到差异化倍率,无需同步", - "请求后端接口失败:": "请求后端接口失败:", - "同步成功": "同步成功", - "部分保存失败": "部分保存失败", - "未找到匹配的模型": "未找到匹配的模型", - "暂无差异化倍率显示": "暂无差异化倍率显示", - "请先选择同步渠道": "请先选择同步渠道", - "倍率类型": "倍率类型", - "缓存倍率": "缓存倍率", - "当前值": "当前值", - "未设置": "未设置", - "与本地相同": "与本地相同", - "运营设置": "运营设置", - "聊天设置": "聊天设置", - "速率限制设置": "速率限制设置", - "模型相关设置": "模型相关设置", - "系统设置": "系统设置", - "仪表盘设置": "仪表盘设置", - "获取初始化状态失败": "获取初始化状态失败", - "表单引用错误,请刷新页面重试": "表单引用错误,请刷新页面重试", - "请输入管理员用户名": "请输入管理员用户名", - "密码长度至少为8个字符": "密码长度至少为8个字符", - "两次输入的密码不一致": "两次输入的密码不一致", - "系统初始化成功,正在跳转...": "系统初始化成功,正在跳转...", - "初始化失败,请重试": "初始化失败,请重试", - "系统初始化失败,请重试": "系统初始化失败,请重试", - "系统初始化": "系统初始化", - "欢迎使用,请完成以下设置以开始使用系统": "欢迎使用,请完成以下设置以开始使用系统", - "数据库信息": "数据库信息", - "管理员账号": "管理员账号", - "设置系统管理员的登录信息": "设置系统管理员的登录信息", - "管理员账号已经初始化过,请继续设置其他参数": "管理员账号已经初始化过,请继续设置其他参数", - "密码": "密码", - "请输入管理员密码": "请输入管理员密码", - "请确认管理员密码": "请确认管理员密码", - "选择适合您使用场景的模式": "选择适合您使用场景的模式", - "对外运营模式": "对外运营模式", - "适用于为多个用户提供服务的场景": "适用于为多个用户提供服务的场景", - "默认模式": "默认模式", - "适用于个人使用的场景,不需要设置模型价格": "适用于个人使用的场景,不需要设置模型价格", - "无需计费": "无需计费", - "演示站点模式": "演示站点模式", - "适用于展示系统功能的场景,提供基础功能演示": "适用于展示系统功能的场景,提供基础功能演示", - "初始化系统": "初始化系统", - "使用模式说明": "使用模式说明", - "我已了解": "我已了解", - "默认模式,适用于为多个用户提供服务的场景。": "默认模式,适用于为多个用户提供服务的场景。", - "此模式下,系统将计算每次调用的用量,您需要对每个模型都设置价格,如果没有设置价格,用户将无法使用该模型。": "此模式下,系统将计算每次调用的用量,您需要对每个模型都设置价格,如果没有设置价格,用户将无法使用该模型。", - "多用户支持": "多用户支持", - "适用于个人使用的场景。": "适用于个人使用的场景。", - "不需要设置模型价格,系统将弱化用量计算,您可专注于使用模型。": "不需要设置模型价格,系统将弱化用量计算,您可专注于使用模型。", - "个人使用": "个人使用", - "适用于展示系统功能的场景。": "适用于展示系统功能的场景。", - "提供基础功能演示,方便用户了解系统特性。": "提供基础功能演示,方便用户了解系统特性。", - "体验试用": "体验试用", - "自动选择": "自动选择", - "过期时间格式错误!": "过期时间格式错误!", - "令牌更新成功!": "令牌更新成功!", - "令牌创建成功,请在列表页面点击复制获取令牌!": "令牌创建成功,请在列表页面点击复制获取令牌!", - "更新令牌信息": "更新令牌信息", - "创建新的令牌": "创建新的令牌", - "设置令牌的基本信息": "设置令牌的基本信息", - "请选择过期时间": "请选择过期时间", - "一天": "一天", - "一个月": "一个月", - "设置令牌可用额度和数量": "设置令牌可用额度和数量", - "新建数量": "新建数量", - "请选择或输入创建令牌的数量": "请选择或输入创建令牌的数量", - "20个": "20个", - "100个": "100个", - "取消无限额度": "取消无限额度", - "设为无限额度": "设为无限额度", - "设置令牌的访问限制": "设置令牌的访问限制", - "IP白名单": "IP白名单", - "允许的IP,一行一个,不填写则不限制": "允许的IP,一行一个,不填写则不限制", - "请勿过度信任此功能,IP可能被伪造": "请勿过度信任此功能,IP可能被伪造", - "勾选启用模型限制后可选择": "勾选启用模型限制后可选择", - "非必要,不建议启用模型限制": "非必要,不建议启用模型限制", - "分组信息": "分组信息", - "设置令牌的分组": "设置令牌的分组", - "令牌分组,默认为用户的分组": "令牌分组,默认为用户的分组", - "管理员未设置用户可选分组": "管理员未设置用户可选分组", - "请输入兑换码!": "请输入兑换码!", - "兑换成功!": "兑换成功!", - "成功兑换额度:": "成功兑换额度:", - "请求失败": "请求失败", - "超级管理员未设置充值链接!": "超级管理员未设置充值链接!", - "管理员未开启在线充值!": "管理员未开启在线充值!", - "充值数量不能小于": "充值数量不能小于", - "支付请求失败": "支付请求失败", - "划转金额最低为": "划转金额最低为", - "邀请链接已复制到剪切板": "邀请链接已复制到剪切板", - "支付方式配置错误, 请联系管理员": "支付方式配置错误, 请联系管理员", - "划转邀请额度": "划转邀请额度", - "可用邀请额度": "可用邀请额度", - "划转额度": "划转额度", - "充值确认": "充值确认", - "充值数量": "充值数量", - "实付金额": "实付金额", - "支付方式": "支付方式", - "在线充值": "在线充值", - "快速方便的充值方式": "快速方便的充值方式", - "选择充值额度": "选择充值额度", - "实付": "实付", - "或输入自定义金额": "或输入自定义金额", - "充值数量,最低 ": "充值数量,最低 ", - "选择支付方式": "选择支付方式", - "处理中": "处理中", - "兑换码充值": "兑换码充值", - "使用兑换码快速充值": "使用兑换码快速充值", - "请输入兑换码": "请输入兑换码", - "兑换中...": "兑换中...", - "兑换": "兑换", - "邀请奖励": "邀请奖励", - "邀请好友获得额外奖励": "邀请好友获得额外奖励", - "待使用收益": "待使用收益", - "总收益": "总收益", - "邀请人数": "邀请人数", - "邀请链接": "邀请链接", - "邀请好友注册,好友充值后您可获得相应奖励": "邀请好友注册,好友充值后您可获得相应奖励", - "通过划转功能将奖励额度转入到您的账户余额中": "通过划转功能将奖励额度转入到您的账户余额中", - "邀请的好友越多,获得的奖励越多": "邀请的好友越多,获得的奖励越多", - "用户名和密码不能为空!": "用户名和密码不能为空!", - "用户账户创建成功!": "用户账户创建成功!", - "提交": "提交", - "创建新用户账户": "创建新用户账户", - "请输入显示名称": "请输入显示名称", - "请输入密码": "请输入密码", - "请输入备注(仅管理员可见)": "请输入备注(仅管理员可见)", - "编辑用户": "编辑用户", - "用户的基本账户信息": "用户的基本账户信息", - "请输入新的用户名": "请输入新的用户名", - "请输入新的密码,最短 8 位": "请输入新的密码,最短 8 位", - "显示名称": "显示名称", - "请输入新的显示名称": "请输入新的显示名称", - "权限设置": "权限设置", - "用户分组和额度管理": "用户分组和额度管理", - "请输入新的剩余额度": "请输入新的剩余额度", - "添加额度": "添加额度", - "第三方账户绑定状态(只读)": "第三方账户绑定状态(只读)", - "已绑定的 GitHub 账户": "已绑定的 GitHub 账户", - "已绑定的 OIDC 账户": "已绑定的 OIDC 账户", - "已绑定的微信账户": "已绑定的微信账户", - "已绑定的邮箱账户": "已绑定的邮箱账户", - "已绑定的 Telegram 账户": "已绑定的 Telegram 账户", - "新额度": "新额度", - "需要添加的额度(支持负数)": "需要添加的额度(支持负数)" -} \ No newline at end of file diff --git a/common/logger.go b/logger/logger.go similarity index 70% rename from common/logger.go rename to logger/logger.go index 0f6dc3c3b..d59e51cb8 100644 --- a/common/logger.go +++ b/logger/logger.go @@ -1,23 +1,26 @@ -package common +package logger import ( "context" "encoding/json" "fmt" - "github.com/bytedance/gopkg/util/gopool" - "github.com/gin-gonic/gin" "io" "log" + "one-api/common" "os" "path/filepath" "sync" "time" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" ) const ( loggerINFO = "INFO" loggerWarn = "WARN" loggerError = "ERR" + loggerDebug = "DEBUG" ) const maxLogCount = 1000000 @@ -27,7 +30,10 @@ var setupLogLock sync.Mutex var setupLogWorking bool func SetupLogger() { - if *LogDir != "" { + defer func() { + setupLogWorking = false + }() + if *common.LogDir != "" { ok := setupLogLock.TryLock() if !ok { log.Println("setup log is already working") @@ -35,9 +41,8 @@ func SetupLogger() { } defer func() { setupLogLock.Unlock() - setupLogWorking = false }() - logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) + logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal("failed to open log file") @@ -47,16 +52,6 @@ func SetupLogger() { } } -func SysLog(s string) { - t := time.Now() - _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) -} - -func SysError(s string) { - t := time.Now() - _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) -} - func LogInfo(ctx context.Context, msg string) { logHelper(ctx, loggerINFO, msg) } @@ -69,12 +64,18 @@ func LogError(ctx context.Context, msg string) { logHelper(ctx, loggerError, msg) } +func LogDebug(ctx context.Context, msg string) { + if common.DebugEnabled { + logHelper(ctx, loggerDebug, msg) + } +} + func logHelper(ctx context.Context, level string, msg string) { writer := gin.DefaultErrorWriter if level == loggerINFO { writer = gin.DefaultWriter } - id := ctx.Value(RequestIdKey) + id := ctx.Value(common.RequestIdKey) if id == nil { id = "SYSTEM" } @@ -90,23 +91,17 @@ func logHelper(ctx context.Context, level string, msg string) { } } -func FatalLog(v ...any) { - t := time.Now() - _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) - os.Exit(1) -} - func LogQuota(quota int) string { - if DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) + if common.DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f 额度", float64(quota)/common.QuotaPerUnit) } else { return fmt.Sprintf("%d 点额度", quota) } } func FormatQuota(quota int) string { - if DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit) + if common.DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f", float64(quota)/common.QuotaPerUnit) } else { return fmt.Sprintf("%d", quota) } diff --git a/main.go b/main.go index ca3da6012..cc2288a61 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "one-api/common" "one-api/constant" "one-api/controller" + "one-api/logger" "one-api/middleware" "one-api/model" "one-api/router" @@ -60,13 +61,13 @@ func main() { } if common.MemoryCacheEnabled { common.SysLog("memory cache enabled") - common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) + common.SysLog(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) // Add panic recovery and retry for InitChannelCache func() { defer func() { if r := recover(); r != nil { - common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) + common.SysLog(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) // Retry once _, _, fixErr := model.FixAbility() if fixErr != nil { @@ -93,13 +94,9 @@ func main() { } go controller.AutomaticallyUpdateChannels(frequency) } - if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { - frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) - if err != nil { - common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) - } - go controller.AutomaticallyTestChannels(frequency) - } + + go controller.AutomaticallyTestChannels() + if common.IsMasterNode && constant.UpdateTask { gopool.Go(func() { controller.UpdateMidjourneyTaskBulk() @@ -125,7 +122,7 @@ func main() { // Initialize HTTP server server := gin.New() server.Use(gin.CustomRecovery(func(c *gin.Context, err any) { - common.SysError(fmt.Sprintf("panic detected: %v", err)) + common.SysLog(fmt.Sprintf("panic detected: %v", err)) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), @@ -171,7 +168,7 @@ func InitResources() error { // 加载环境变量 common.InitEnv() - common.SetupLogger() + logger.SetupLogger() // Initialize model settings ratio_setting.InitRatioSettings() diff --git a/middleware/auth.go b/middleware/auth.go index a158318c5..25caf50d9 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -4,7 +4,10 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/constant" "one-api/model" + "one-api/setting" + "one-api/setting/ratio_setting" "strconv" "strings" @@ -122,6 +125,7 @@ func authHelper(c *gin.Context, minRole int) { c.Set("role", role) c.Set("id", id) c.Set("group", session.Get("group")) + c.Set("user_group", session.Get("group")) c.Set("use_access_token", useAccessToken) //userCache, err := model.GetUserCache(id.(int)) @@ -190,14 +194,15 @@ func TokenAuth() func(c *gin.Context) { } // 检查path包含/v1/messages if strings.Contains(c.Request.URL.Path, "/v1/messages") { - // 从x-api-key中获取key - key := c.Request.Header.Get("x-api-key") - if key != "" { - c.Request.Header.Set("Authorization", "Bearer "+key) + anthropicKey := c.Request.Header.Get("x-api-key") + if anthropicKey != "" { + c.Request.Header.Set("Authorization", "Bearer "+anthropicKey) } } // gemini api 从query中获取key - if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") { + if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models") || + strings.HasPrefix(c.Request.URL.Path, "/v1beta/openai/models") || + strings.HasPrefix(c.Request.URL.Path, "/v1/models/") { skKey := c.Query("key") if skKey != "" { c.Request.Header.Set("Authorization", "Bearer "+skKey) @@ -233,6 +238,16 @@ func TokenAuth() func(c *gin.Context) { abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) return } + + allowIpsMap := token.GetIpLimitsMap() + if len(allowIpsMap) != 0 { + clientIp := c.ClientIP() + if _, ok := allowIpsMap[clientIp]; !ok { + abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中") + return + } + } + userCache, err := model.GetUserCache(token.UserId) if err != nil { abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error()) @@ -246,6 +261,25 @@ func TokenAuth() func(c *gin.Context) { userCache.WriteContext(c) + userGroup := userCache.Group + tokenGroup := token.Group + if tokenGroup != "" { + // check common.UserUsableGroups[userGroup] + if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok { + abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup)) + return + } + // check group in common.GroupRatio + if !ratio_setting.ContainsGroupRatio(tokenGroup) { + if tokenGroup != "auto" { + abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup)) + return + } + } + userGroup = tokenGroup + } + common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup) + err = SetupContextForToken(c, token, parts...) if err != nil { return @@ -272,7 +306,6 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e } else { c.Set("token_model_limit_enabled", false) } - c.Set("allow_ips", token.GetIpLimitsMap()) c.Set("token_group", token.Group) if len(parts) > 1 { if model.IsAdmin(token.UserId) { diff --git a/middleware/disable-cache.go b/middleware/disable-cache.go new file mode 100644 index 000000000..3076e90a8 --- /dev/null +++ b/middleware/disable-cache.go @@ -0,0 +1,12 @@ +package middleware + +import "github.com/gin-gonic/gin" + +func DisableCache() gin.HandlerFunc { + return func(c *gin.Context) { + c.Header("Cache-Control", "no-store, no-cache, must-revalidate, private, max-age=0") + c.Header("Pragma", "no-cache") + c.Header("Expires", "0") + c.Next() + } +} diff --git a/middleware/distributor.go b/middleware/distributor.go index a6889e396..1e6df872d 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -27,14 +27,6 @@ type ModelRequest struct { func Distribute() func(c *gin.Context) { return func(c *gin.Context) { - allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps) - if len(allowIpsMap) != 0 { - clientIp := c.ClientIP() - if _, ok := allowIpsMap[clientIp]; !ok { - abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中") - return - } - } var channel *model.Channel channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId) modelRequest, shouldSelectChannel, err := getModelRequest(c) @@ -42,24 +34,6 @@ func Distribute() func(c *gin.Context) { abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error()) return } - userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup) - tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup) - if tokenGroup != "" { - // check common.UserUsableGroups[userGroup] - if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok { - abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup)) - return - } - // check group in common.GroupRatio - if !ratio_setting.ContainsGroupRatio(tokenGroup) { - if tokenGroup != "auto" { - abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup)) - return - } - } - userGroup = tokenGroup - } - common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup) if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { @@ -81,44 +55,63 @@ func Distribute() func(c *gin.Context) { modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) if modelLimitEnable { s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit) - var tokenModelLimit map[string]bool - if ok { - tokenModelLimit = s.(map[string]bool) - } else { - tokenModelLimit = map[string]bool{} - } - if tokenModelLimit != nil { - if _, ok := tokenModelLimit[modelRequest.Model]; !ok { - abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model) - return - } - } else { + if !ok { // token model limit is empty, all models are not allowed abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型") return } + var tokenModelLimit map[string]bool + tokenModelLimit, ok = s.(map[string]bool) + if !ok { + tokenModelLimit = map[string]bool{} + } + matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) // match gpts & thinking-* + if _, ok := tokenModelLimit[matchName]; !ok { + abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model) + return + } } if shouldSelectChannel { + if modelRequest.Model == "" { + abortWithOpenAiMessage(c, http.StatusBadRequest, "未指定模型名称,模型名称不能为空") + return + } var selectGroup string + userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup) + // check path is /pg/chat/completions + if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") { + playgroundRequest := &dto.PlayGroundRequest{} + err = common.UnmarshalBodyReusable(c, playgroundRequest) + if err != nil { + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) + return + } + if playgroundRequest.Group != "" { + if !setting.GroupInUserUsableGroups(playgroundRequest.Group) && playgroundRequest.Group != userGroup { + abortWithOpenAiMessage(c, http.StatusForbidden, "无权访问该分组") + return + } + userGroup = playgroundRequest.Group + } + } channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0) if err != nil { showGroup := userGroup if userGroup == "auto" { showGroup = fmt.Sprintf("auto(%s)", selectGroup) } - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model) + message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(数据库一致性已被破坏,distributor): %s", showGroup, modelRequest.Model, err.Error()) // 如果错误,但是渠道不为空,说明是数据库一致性问题 - if channel != nil { - common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) - message = "数据库一致性已被破坏,请联系管理员" - } - // 如果错误,而且渠道为空,说明是没有可用渠道 - abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message) + //if channel != nil { + // common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) + // message = "数据库一致性已被破坏,请联系管理员" + //} + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound)) return } if channel == nil { - abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) + abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound)) return } } @@ -174,23 +167,16 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { c.Set("relay_mode", relayMode) } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") { err = common.UnmarshalBodyReusable(c, &modelRequest) - var platform string - var relayMode int - if strings.HasPrefix(modelRequest.Model, "jimeng") { - platform = string(constant.TaskPlatformJimeng) - relayMode = relayconstant.Path2RelayJimeng(c.Request.Method, c.Request.URL.Path) - if relayMode == relayconstant.RelayModeJimengFetchByID { - shouldSelectChannel = false - } - } else { - platform = string(constant.TaskPlatformKling) - relayMode = relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path) - if relayMode == relayconstant.RelayModeKlingFetchByID { - shouldSelectChannel = false - } + relayMode := relayconstant.RelayModeUnknown + if c.Request.Method == http.MethodPost { + relayMode = relayconstant.RelayModeVideoSubmit + } else if c.Request.Method == http.MethodGet { + relayMode = relayconstant.RelayModeVideoFetchByID + shouldSelectChannel = false + } + if _, ok := c.Get("relay_mode"); !ok { + c.Set("relay_mode", relayMode) } - c.Set("platform", platform) - c.Set("relay_mode", relayMode) } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") { // Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent relayMode := relayconstant.RelayModeGemini @@ -199,7 +185,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { modelRequest.Model = modelName } c.Set("relay_mode", relayMode) - } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { + } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { err = common.UnmarshalBodyReusable(c, &modelRequest) } if err != nil { @@ -222,7 +208,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e") } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { - modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1") + //modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1") + if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { + modelRequest.Model = c.PostForm("model") + } } if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { relayMode := relayconstant.RelayModeAudioSpeech @@ -253,14 +242,16 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError { c.Set("original_model", modelName) // for retry if channel == nil { - return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed) + return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id) common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name) common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type) common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime) common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting()) + common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings()) common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride()) + common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, channel.GetHeaderOverride()) if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" { common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization) } @@ -275,11 +266,16 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode if channel.ChannelInfo.IsMultiKey { common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true) common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index) + } else { + // 必须设置为 false,否则在重试到单个 key 的时候会导致日志显示错误 + common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, false) } // c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key)) common.SetContextKey(c, constant.ContextKeyChannelKey, key) common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL()) + common.SetContextKey(c, constant.ContextKeySystemPromptOverride, false) + // TODO: api_version统一 switch channel.Type { case constant.ChannelTypeAzure: diff --git a/middleware/email-verification-rate-limit.go b/middleware/email-verification-rate-limit.go new file mode 100644 index 000000000..a7d828d96 --- /dev/null +++ b/middleware/email-verification-rate-limit.go @@ -0,0 +1,80 @@ +package middleware + +import ( + "context" + "fmt" + "net/http" + "one-api/common" + "time" + + "github.com/gin-gonic/gin" +) + +const ( + EmailVerificationRateLimitMark = "EV" + EmailVerificationMaxRequests = 2 // 30秒内最多2次 + EmailVerificationDuration = 30 // 30秒时间窗口 +) + +func redisEmailVerificationRateLimiter(c *gin.Context) { + ctx := context.Background() + rdb := common.RDB + key := "emailVerification:" + EmailVerificationRateLimitMark + ":" + c.ClientIP() + + count, err := rdb.Incr(ctx, key).Result() + if err != nil { + // fallback + memoryEmailVerificationRateLimiter(c) + return + } + + // 第一次设置键时设置过期时间 + if count == 1 { + _ = rdb.Expire(ctx, key, time.Duration(EmailVerificationDuration)*time.Second).Err() + } + + // 检查是否超出限制 + if count <= int64(EmailVerificationMaxRequests) { + c.Next() + return + } + + // 获取剩余等待时间 + ttl, err := rdb.TTL(ctx, key).Result() + waitSeconds := int64(EmailVerificationDuration) + if err == nil && ttl > 0 { + waitSeconds = int64(ttl.Seconds()) + } + + c.JSON(http.StatusTooManyRequests, gin.H{ + "success": false, + "message": fmt.Sprintf("发送过于频繁,请等待 %d 秒后再试", waitSeconds), + }) + c.Abort() +} + +func memoryEmailVerificationRateLimiter(c *gin.Context) { + key := EmailVerificationRateLimitMark + ":" + c.ClientIP() + + if !inMemoryRateLimiter.Request(key, EmailVerificationMaxRequests, EmailVerificationDuration) { + c.JSON(http.StatusTooManyRequests, gin.H{ + "success": false, + "message": "发送过于频繁,请稍后再试", + }) + c.Abort() + return + } + + c.Next() +} + +func EmailVerificationRateLimit() gin.HandlerFunc { + return func(c *gin.Context) { + if common.RedisEnabled { + redisEmailVerificationRateLimiter(c) + } else { + inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) + memoryEmailVerificationRateLimiter(c) + } + } +} diff --git a/middleware/jimeng_adapter.go b/middleware/jimeng_adapter.go new file mode 100644 index 000000000..ce5e14675 --- /dev/null +++ b/middleware/jimeng_adapter.go @@ -0,0 +1,66 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/constant" + relayconstant "one-api/relay/constant" +) + +func JimengRequestConvert() func(c *gin.Context) { + return func(c *gin.Context) { + action := c.Query("Action") + if action == "" { + abortWithOpenAiMessage(c, http.StatusBadRequest, "Action query parameter is required") + return + } + + // Handle Jimeng official API request + var originalReq map[string]interface{} + if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil { + abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request body") + return + } + model, _ := originalReq["req_key"].(string) + prompt, _ := originalReq["prompt"].(string) + + unifiedReq := map[string]interface{}{ + "model": model, + "prompt": prompt, + "metadata": originalReq, + } + + jsonData, err := json.Marshal(unifiedReq) + if err != nil { + abortWithOpenAiMessage(c, http.StatusInternalServerError, "Failed to marshal request body") + return + } + + // Update request body + c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData)) + c.Set(common.KeyRequestBody, jsonData) + + if image, ok := originalReq["image"]; !ok || image == "" { + c.Set("action", constant.TaskActionTextGenerate) + } + + c.Request.URL.Path = "/v1/video/generations" + + if action == "CVSync2AsyncGetResult" { + taskId, ok := originalReq["task_id"].(string) + if !ok || taskId == "" { + abortWithOpenAiMessage(c, http.StatusBadRequest, "task_id is required for CVSync2AsyncGetResult") + return + } + c.Request.URL.Path = "/v1/video/generations/" + taskId + c.Request.Method = http.MethodGet + c.Set("task_id", taskId) + c.Set("relay_mode", relayconstant.RelayModeVideoFetchByID) + } + c.Next() + } +} diff --git a/middleware/kling_adapter.go b/middleware/kling_adapter.go index 3d4943d28..20973c9f6 100644 --- a/middleware/kling_adapter.go +++ b/middleware/kling_adapter.go @@ -18,7 +18,11 @@ func KlingRequestConvert() func(c *gin.Context) { return } + // Support both model_name and model fields model, _ := originalReq["model_name"].(string) + if model == "" { + model, _ = originalReq["model"].(string) + } prompt, _ := originalReq["prompt"].(string) unifiedReq := map[string]interface{}{ diff --git a/middleware/recover.go b/middleware/recover.go index 51fc71908..d78c8137f 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - common.SysError(fmt.Sprintf("panic detected: %v", err)) - common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + common.SysLog(fmt.Sprintf("panic detected: %v", err)) + common.SysLog(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), diff --git a/middleware/stats.go b/middleware/stats.go index 1c97983f7..e49e56991 100644 --- a/middleware/stats.go +++ b/middleware/stats.go @@ -18,12 +18,12 @@ func StatsMiddleware() gin.HandlerFunc { return func(c *gin.Context) { // 增加活跃连接数 atomic.AddInt64(&globalStats.activeConnections, 1) - + // 确保在请求结束时减少连接数 defer func() { atomic.AddInt64(&globalStats.activeConnections, -1) }() - + c.Next() } } @@ -38,4 +38,4 @@ func GetStats() StatsInfo { return StatsInfo{ ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections), } -} \ No newline at end of file +} diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go index 26688810d..106a72781 100644 --- a/middleware/turnstile-check.go +++ b/middleware/turnstile-check.go @@ -37,7 +37,7 @@ func TurnstileCheck() gin.HandlerFunc { "remoteip": {c.ClientIP()}, }) if err != nil { - common.SysError(err.Error()) + common.SysLog(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -49,7 +49,7 @@ func TurnstileCheck() gin.HandlerFunc { var res turnstileCheckResponse err = json.NewDecoder(rawRes.Body).Decode(&res) if err != nil { - common.SysError(err.Error()) + common.SysLog(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/middleware/utils.go b/middleware/utils.go index 082f56571..77d1eb805 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -4,18 +4,24 @@ import ( "fmt" "github.com/gin-gonic/gin" "one-api/common" + "one-api/logger" ) -func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { +func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string, code ...string) { + codeStr := "" + if len(code) > 0 { + codeStr = code[0] + } userId := c.GetInt("id") c.JSON(statusCode, gin.H{ "error": gin.H{ "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), "type": "new_api_error", + "code": codeStr, }, }) c.Abort() - common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message)) + logger.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message)) } func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) { @@ -25,5 +31,5 @@ func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, descri "code": code, }) c.Abort() - common.LogError(c.Request.Context(), description) + logger.LogError(c.Request.Context(), description) } diff --git a/model/ability.go b/model/ability.go index f36ff7642..123fc7be5 100644 --- a/model/ability.go +++ b/model/ability.go @@ -136,13 +136,13 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, } } } else { - return nil, errors.New("channel not found") + return nil, nil } err = DB.First(&channel, "id = ?", channel.Id).Error return &channel, err } -func (channel *Channel) AddAbilities() error { +func (channel *Channel) AddAbilities(tx *gorm.DB) error { models_ := strings.Split(channel.Models, ",") groups_ := strings.Split(channel.Group, ",") abilitySet := make(map[string]struct{}) @@ -169,8 +169,13 @@ func (channel *Channel) AddAbilities() error { if len(abilities) == 0 { return nil } + // choose DB or provided tx + useDB := DB + if tx != nil { + useDB = tx + } for _, chunk := range lo.Chunk(abilities, 50) { - err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error + err := useDB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error if err != nil { return err } @@ -284,6 +289,21 @@ func FixAbility() (int, int, error) { return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试") } defer fixLock.Unlock() + + // truncate abilities table + if common.UsingSQLite { + err := DB.Exec("DELETE FROM abilities").Error + if err != nil { + common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error())) + return 0, 0, err + } + } else { + err := DB.Exec("TRUNCATE TABLE abilities").Error + if err != nil { + common.SysLog(fmt.Sprintf("Truncate abilities failed: %s", err.Error())) + return 0, 0, err + } + } var channels []*Channel // Find all channels err := DB.Model(&Channel{}).Find(&channels).Error @@ -300,15 +320,15 @@ func FixAbility() (int, int, error) { // Delete all abilities of this channel err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error if err != nil { - common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) + common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error())) failCount += len(chunk) continue } // Then add new abilities for _, channel := range chunk { - err = channel.AddAbilities() + err = channel.AddAbilities(nil) if err != nil { - common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) + common.SysLog(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) failCount++ } else { successCount++ diff --git a/model/channel.go b/model/channel.go index 6277fcda2..a61b3eccf 100644 --- a/model/channel.go +++ b/model/channel.go @@ -13,6 +13,7 @@ import ( "strings" "sync" + "github.com/samber/lo" "gorm.io/gorm" ) @@ -41,19 +42,27 @@ type Channel struct { Priority *int64 `json:"priority" gorm:"bigint;default:0"` AutoBan *int `json:"auto_ban" gorm:"default:1"` OtherInfo string `json:"other_info"` + OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置 Tag *string `json:"tag" gorm:"index"` Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置 ParamOverride *string `json:"param_override" gorm:"type:text"` + HeaderOverride *string `json:"header_override" gorm:"type:text"` + Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"` // add after v0.8.5 ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"` + + // cache info + Keys []string `json:"-" gorm:"-"` } type ChannelInfo struct { - IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式 - MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量 - MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status - MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引 - MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` + IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式 + MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量 + MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status + MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason,omitempty"` // key禁用原因列表,key index -> reason + MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time,omitempty"` // key禁用时间列表,key index -> time + MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引 + MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` } // Value implements driver.Valuer interface @@ -67,15 +76,18 @@ func (c *ChannelInfo) Scan(value interface{}) error { return common.Unmarshal(bytesValue, c) } -func (channel *Channel) getKeys() []string { +func (channel *Channel) GetKeys() []string { if channel.Key == "" { return []string{} } + if len(channel.Keys) > 0 { + return channel.Keys + } trimmed := strings.TrimSpace(channel.Key) // If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios) if strings.HasPrefix(trimmed, "[") { var arr []json.RawMessage - if err := json.Unmarshal([]byte(trimmed), &arr); err == nil { + if err := common.Unmarshal([]byte(trimmed), &arr); err == nil { res := make([]string, len(arr)) for i, v := range arr { res[i] = string(v) @@ -95,12 +107,16 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) { } // Obtain all keys (split by \n) - keys := channel.getKeys() + keys := channel.GetKeys() if len(keys) == 0 { // No keys available, return error, should disable the channel return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey) } + lock := GetChannelPollingLock(channel.Id) + lock.Lock() + defer lock.Unlock() + statusList := channel.ChannelInfo.MultiKeyStatusList // helper to get key status, default to enabled when missing getStatus := func(idx int) int { @@ -132,13 +148,10 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) { return keys[selectedIdx], selectedIdx, nil case constant.MultiKeyModePolling: // Use channel-specific lock to ensure thread-safe polling - lock := getChannelPollingLock(channel.Id) - lock.Lock() - defer lock.Unlock() channelInfo, err := CacheGetChannelInfo(channel.Id) if err != nil { - return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed) + return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } //println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex) defer func() { @@ -197,9 +210,9 @@ func (channel *Channel) GetGroups() []string { func (channel *Channel) GetOtherInfo() map[string]interface{} { otherInfo := make(map[string]interface{}) if channel.OtherInfo != "" { - err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo) + err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo) if err != nil { - common.SysError("failed to unmarshal other info: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to unmarshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err)) } } return otherInfo @@ -208,7 +221,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} { func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) { otherInfoBytes, err := json.Marshal(otherInfo) if err != nil { - common.SysError("failed to marshal other info: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to marshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err)) return } channel.OtherInfo = string(otherInfoBytes) @@ -236,6 +249,10 @@ func (channel *Channel) Save() error { return DB.Save(channel).Error } +func (channel *Channel) SaveWithoutKey() error { + return DB.Omit("key").Save(channel).Error +} + func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) { var channels []*Channel var err error @@ -328,38 +345,54 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { } func BatchInsertChannels(channels []Channel) error { - var err error - err = DB.Create(&channels).Error - if err != nil { - return err + if len(channels) == 0 { + return nil } - for _, channel_ := range channels { - err = channel_.AddAbilities() - if err != nil { + tx := DB.Begin() + if tx.Error != nil { + return tx.Error + } + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + + for _, chunk := range lo.Chunk(channels, 50) { + if err := tx.Create(&chunk).Error; err != nil { + tx.Rollback() return err } + for _, channel_ := range chunk { + if err := channel_.AddAbilities(tx); err != nil { + tx.Rollback() + return err + } + } } - return nil + return tx.Commit().Error } func BatchDeleteChannels(ids []int) error { - //使用事务 删除channel表和channel_ability表 + if len(ids) == 0 { + return nil + } + // 使用事务 分批删除channel表和abilities表 tx := DB.Begin() - err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error - if err != nil { - // 回滚事务 - tx.Rollback() - return err + if tx.Error != nil { + return tx.Error } - err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error - if err != nil { - // 回滚事务 - tx.Rollback() - return err + for _, chunk := range lo.Chunk(ids, 200) { + if err := tx.Where("id in (?)", chunk).Delete(&Channel{}).Error; err != nil { + tx.Rollback() + return err + } + if err := tx.Where("channel_id in (?)", chunk).Delete(&Ability{}).Error; err != nil { + tx.Rollback() + return err + } } - // 提交事务 - tx.Commit() - return err + return tx.Commit().Error } func (channel *Channel) GetPriority() int64 { @@ -380,7 +413,11 @@ func (channel *Channel) GetBaseURL() string { if channel.BaseURL == nil { return "" } - return *channel.BaseURL + url := *channel.BaseURL + if url == "" { + url = constant.ChannelBaseURLs[channel.Type] + } + return url } func (channel *Channel) GetModelMapping() string { @@ -403,7 +440,7 @@ func (channel *Channel) Insert() error { if err != nil { return err } - err = channel.AddAbilities() + err = channel.AddAbilities(nil) return err } @@ -425,7 +462,7 @@ func (channel *Channel) Update() error { trimmed := strings.TrimSpace(keyStr) if strings.HasPrefix(trimmed, "[") { var arr []json.RawMessage - if err := json.Unmarshal([]byte(trimmed), &arr); err == nil { + if err := common.Unmarshal([]byte(trimmed), &arr); err == nil { keys = make([]string, len(arr)) for i, v := range arr { keys[i] = string(v) @@ -462,7 +499,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { ResponseTime: int(responseTime), }).Error if err != nil { - common.SysError("failed to update response time: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update response time: channel_id=%d, error=%v", channel.Id, err)) } } @@ -472,7 +509,7 @@ func (channel *Channel) UpdateBalance(balance float64) { Balance: balance, }).Error if err != nil { - common.SysError("failed to update balance: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update balance: channel_id=%d, error=%v", channel.Id, err)) } } @@ -491,8 +528,8 @@ var channelStatusLock sync.Mutex // channelPollingLocks stores locks for each channel.id to ensure thread-safe polling var channelPollingLocks sync.Map -// getChannelPollingLock returns or creates a mutex for the given channel ID -func getChannelPollingLock(channelId int) *sync.Mutex { +// GetChannelPollingLock returns or creates a mutex for the given channel ID +func GetChannelPollingLock(channelId int) *sync.Mutex { if lock, exists := channelPollingLocks.Load(channelId); exists { return lock.(*sync.Mutex) } @@ -522,8 +559,8 @@ func CleanupChannelPollingLocks() { }) } -func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) { - keys := channel.getKeys() +func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) { + keys := channel.GetKeys() if len(keys) == 0 { channel.Status = status } else { @@ -541,6 +578,14 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) { delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex) } else { channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status + if channel.ChannelInfo.MultiKeyDisabledReason == nil { + channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string) + } + if channel.ChannelInfo.MultiKeyDisabledTime == nil { + channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64) + } + channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason + channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp() } if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize { channel.Status = common.ChannelStatusAutoDisabled @@ -562,8 +607,12 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri return false } if channelCache.ChannelInfo.IsMultiKey { + // Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey + pollingLock := GetChannelPollingLock(channelId) + pollingLock.Lock() // 如果是多Key模式,更新缓存中的状态 - handlerMultiKeyUpdate(channelCache, usingKey, status) + handlerMultiKeyUpdate(channelCache, usingKey, status, reason) + pollingLock.Unlock() //CacheUpdateChannel(channelCache) //return true } else { @@ -571,10 +620,6 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri if channelCache.Status == status { return false } - // 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回 - if status != common.ChannelStatusEnabled { - return false - } CacheUpdateChannelStatus(channelId, status) } } @@ -584,7 +629,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri if shouldUpdateAbilities { err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled) if err != nil { - common.SysError("failed to update ability status: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update ability status: channel_id=%d, error=%v", channelId, err)) } } }() @@ -598,7 +643,11 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri if channel.ChannelInfo.IsMultiKey { beforeStatus := channel.Status - handlerMultiKeyUpdate(channel, usingKey, status) + // Protect map writes with the same per-channel lock used by readers + pollingLock := GetChannelPollingLock(channelId) + pollingLock.Lock() + handlerMultiKeyUpdate(channel, usingKey, status, reason) + pollingLock.Unlock() if beforeStatus != channel.Status { shouldUpdateAbilities = true } @@ -610,9 +659,9 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri channel.Status = status shouldUpdateAbilities = true } - err = channel.Save() + err = channel.SaveWithoutKey() if err != nil { - common.SysError("failed to update channel status: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update channel status: channel_id=%d, status=%d, error=%v", channel.Id, status, err)) return false } } @@ -674,7 +723,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models * for _, channel := range channels { err = channel.UpdateAbilities(nil) if err != nil { - common.SysError("failed to update abilities: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update abilities: channel_id=%d, tag=%s, error=%v", channel.Id, channel.GetTag(), err)) } } } @@ -698,7 +747,7 @@ func UpdateChannelUsedQuota(id int, quota int) { func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { - common.SysError("failed to update channel used quota: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to update channel used quota: channel_id=%d, delta_quota=%d, error=%v", id, quota, err)) } } @@ -778,7 +827,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str func (channel *Channel) ValidateSettings() error { channelParams := &dto.ChannelSettings{} if channel.Setting != nil && *channel.Setting != "" { - err := json.Unmarshal([]byte(*channel.Setting), channelParams) + err := common.Unmarshal([]byte(*channel.Setting), channelParams) if err != nil { return err } @@ -789,9 +838,9 @@ func (channel *Channel) ValidateSettings() error { func (channel *Channel) GetSetting() dto.ChannelSettings { setting := dto.ChannelSettings{} if channel.Setting != nil && *channel.Setting != "" { - err := json.Unmarshal([]byte(*channel.Setting), &setting) + err := common.Unmarshal([]byte(*channel.Setting), &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err)) channel.Setting = nil // 清空设置以避免后续错误 _ = channel.Save() // 保存修改 } @@ -800,25 +849,58 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { } func (channel *Channel) SetSetting(setting dto.ChannelSettings) { - settingBytes, err := json.Marshal(setting) + settingBytes, err := common.Marshal(setting) if err != nil { - common.SysError("failed to marshal setting: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err)) return } channel.Setting = common.GetPointer[string](string(settingBytes)) } +func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings { + setting := dto.ChannelOtherSettings{} + if channel.OtherSettings != "" { + err := common.UnmarshalJsonStr(channel.OtherSettings, &setting) + if err != nil { + common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err)) + channel.OtherSettings = "{}" // 清空设置以避免后续错误 + _ = channel.Save() // 保存修改 + } + } + return setting +} + +func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) { + settingBytes, err := common.Marshal(setting) + if err != nil { + common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err)) + return + } + channel.OtherSettings = string(settingBytes) +} + func (channel *Channel) GetParamOverride() map[string]interface{} { paramOverride := make(map[string]interface{}) if channel.ParamOverride != nil && *channel.ParamOverride != "" { - err := json.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride) + err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride) if err != nil { - common.SysError("failed to unmarshal param override: " + err.Error()) + common.SysLog(fmt.Sprintf("failed to unmarshal param override: channel_id=%d, error=%v", channel.Id, err)) } } return paramOverride } +func (channel *Channel) GetHeaderOverride() map[string]interface{} { + headerOverride := make(map[string]interface{}) + if channel.HeaderOverride != nil && *channel.HeaderOverride != "" { + err := common.Unmarshal([]byte(*channel.HeaderOverride), &headerOverride) + if err != nil { + common.SysLog(fmt.Sprintf("failed to unmarshal header override: channel_id=%d, error=%v", channel.Id, err)) + } + } + return headerOverride +} + func GetChannelsByIds(ids []int) ([]*Channel, error) { var channels []*Channel err := DB.Where("id in (?)", ids).Find(&channels).Error diff --git a/model/channel_cache.go b/model/channel_cache.go index b24512489..86866e404 100644 --- a/model/channel_cache.go +++ b/model/channel_cache.go @@ -5,7 +5,9 @@ import ( "fmt" "math/rand" "one-api/common" + "one-api/constant" "one-api/setting" + "one-api/setting/ratio_setting" "sort" "strings" "sync" @@ -66,6 +68,20 @@ func InitChannelCache() { channelSyncLock.Lock() group2model2channels = newGroup2model2channels + //channelsIDM = newChannelId2channel + for i, channel := range newChannelId2channel { + if channel.ChannelInfo.IsMultiKey { + channel.Keys = channel.GetKeys() + if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling { + if oldChannel, ok := channelsIDM[i]; ok { + // 存在旧的渠道,如果是多key且轮询,保留轮询索引信息 + if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling { + channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex + } + } + } + } + } channelsIDM = newChannelId2channel channelSyncLock.Unlock() common.SysLog("channels synced from database") @@ -109,20 +125,10 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, return nil, group, err } } - if channel == nil { - return nil, group, errors.New("channel not found") - } return channel, selectGroup, nil } func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { - if strings.HasPrefix(model, "gpt-4-gizmo") { - model = "gpt-4-gizmo-*" - } - if strings.HasPrefix(model, "gpt-4o-gizmo") { - model = "gpt-4o-gizmo-*" - } - // if memory cache is disabled, get channel directly from database if !common.MemoryCacheEnabled { return GetRandomSatisfiedChannel(group, model, retry) @@ -130,10 +136,18 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, channelSyncLock.RLock() defer channelSyncLock.RUnlock() + + // First, try to find channels with the exact model name. channels := group2model2channels[group][model] + // If no channels found, try to find channels with the normalized model name. if len(channels) == 0 { - return nil, errors.New("channel not found") + normalizedModel := ratio_setting.FormatMatchingModelName(model) + channels = group2model2channels[group][normalizedModel] + } + + if len(channels) == 0 { + return nil, nil } if len(channels) == 1 { @@ -206,9 +220,6 @@ func CacheGetChannel(id int) (*Channel, error) { if !ok { return nil, fmt.Errorf("渠道# %d,已不存在", id) } - if c.Status != common.ChannelStatusEnabled { - return nil, fmt.Errorf("渠道# %d,已被禁用", id) - } return c, nil } @@ -227,9 +238,6 @@ func CacheGetChannelInfo(id int) (*ChannelInfo, error) { if !ok { return nil, fmt.Errorf("渠道# %d,已不存在", id) } - if c.Status != common.ChannelStatusEnabled { - return nil, fmt.Errorf("渠道# %d,已被禁用", id) - } return &c.ChannelInfo, nil } @@ -242,6 +250,20 @@ func CacheUpdateChannelStatus(id int, status int) { if channel, ok := channelsIDM[id]; ok { channel.Status = status } + if status != common.ChannelStatusEnabled { + // delete the channel from group2model2channels + for group, model2channels := range group2model2channels { + for model, channels := range model2channels { + for i, channelId := range channels { + if channelId == id { + // remove the channel from the slice + group2model2channels[group][model] = append(channels[:i], channels[i+1:]...) + break + } + } + } + } + } } func CacheUpdateChannel(channel *Channel) { diff --git a/model/log.go b/model/log.go index 2070cd6f3..979cbe7b2 100644 --- a/model/log.go +++ b/model/log.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "one-api/common" + "one-api/logger" + "one-api/types" "os" "strings" "time" @@ -87,13 +89,13 @@ func RecordLog(userId int, logType int, content string) { } err := LOG_DB.Create(log).Error if err != nil { - common.SysError("failed to record log: " + err.Error()) + common.SysLog("failed to record log: " + err.Error()) } } func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int, isStream bool, group string, other map[string]interface{}) { - common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content)) + logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content)) username := c.GetString("username") otherStr := common.MapToJsonStr(other) // 判断是否需要记录 IP @@ -129,7 +131,7 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, } err := LOG_DB.Create(log).Error if err != nil { - common.LogError(c, "failed to record log: "+err.Error()) + logger.LogError(c, "failed to record log: "+err.Error()) } } @@ -142,7 +144,6 @@ type RecordConsumeLogParams struct { Quota int `json:"quota"` Content string `json:"content"` TokenId int `json:"token_id"` - UserQuota int `json:"user_quota"` UseTimeSeconds int `json:"use_time_seconds"` IsStream bool `json:"is_stream"` Group string `json:"group"` @@ -150,10 +151,10 @@ type RecordConsumeLogParams struct { } func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) { - common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params))) if !common.LogConsumeEnabled { return } + logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params))) username := c.GetString("username") otherStr := common.MapToJsonStr(params.Other) // 判断是否需要记录 IP @@ -189,7 +190,7 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) } err := LOG_DB.Create(log).Error if err != nil { - common.LogError(c, "failed to record log: "+err.Error()) + logger.LogError(c, "failed to record log: "+err.Error()) } if common.DataExportEnabled { gopool.Go(func() { @@ -236,26 +237,22 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName return nil, 0, err } - channelIdsMap := make(map[int]struct{}) - channelMap := make(map[int]string) + channelIds := types.NewSet[int]() for _, log := range logs { if log.ChannelId != 0 { - channelIdsMap[log.ChannelId] = struct{}{} + channelIds.Add(log.ChannelId) } } - channelIds := make([]int, 0, len(channelIdsMap)) - for channelId := range channelIdsMap { - channelIds = append(channelIds, channelId) - } - if len(channelIds) > 0 { + if channelIds.Len() > 0 { var channels []struct { Id int `gorm:"column:id"` Name string `gorm:"column:name"` } - if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds).Find(&channels).Error; err != nil { + if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil { return logs, total, err } + channelMap := make(map[int]string, len(channels)) for _, channel := range channels { channelMap[channel.Id] = channel.Name } diff --git a/model/main.go b/model/main.go index 013beacda..1a38d371b 100644 --- a/model/main.go +++ b/model/main.go @@ -180,6 +180,12 @@ func InitDB() (err error) { db = db.Debug() } DB = db + // MySQL charset/collation startup check: ensure Chinese-capable charset + if common.UsingMySQL { + if err := checkMySQLChineseSupport(DB); err != nil { + panic(err) + } + } sqlDB, err := DB.DB() if err != nil { return err @@ -214,6 +220,12 @@ func InitLogDB() (err error) { db = db.Debug() } LOG_DB = db + // If log DB is MySQL, also ensure Chinese-capable charset + if common.LogSqlType == common.DatabaseTypeMySQL { + if err := checkMySQLChineseSupport(LOG_DB); err != nil { + panic(err) + } + } sqlDB, err := LOG_DB.DB() if err != nil { return err @@ -235,9 +247,6 @@ func InitLogDB() (err error) { } func migrateDB() error { - if !common.UsingPostgreSQL { - return migrateDBFast() - } err := DB.AutoMigrate( &Channel{}, &Token{}, @@ -250,7 +259,12 @@ func migrateDB() error { &TopUp{}, &QuotaData{}, &Task{}, + &Model{}, + &Vendor{}, + &PrefillGroup{}, &Setup{}, + &TwoFA{}, + &TwoFABackupCode{}, ) if err != nil { return err @@ -259,6 +273,7 @@ func migrateDB() error { } func migrateDBFast() error { + var wg sync.WaitGroup migrations := []struct { @@ -276,7 +291,12 @@ func migrateDBFast() error { {&TopUp{}, "TopUp"}, {&QuotaData{}, "QuotaData"}, {&Task{}, "Task"}, + {&Model{}, "Model"}, + {&Vendor{}, "Vendor"}, + {&PrefillGroup{}, "PrefillGroup"}, {&Setup{}, "Setup"}, + {&TwoFA{}, "TwoFA"}, + {&TwoFABackupCode{}, "TwoFABackupCode"}, } // 动态计算migration数量,确保errChan缓冲区足够大 errChan := make(chan error, len(migrations)) @@ -332,6 +352,98 @@ func CloseDB() error { return closeDB(DB) } +// checkMySQLChineseSupport ensures the MySQL connection and current schema +// default charset/collation can store Chinese characters. It allows common +// Chinese-capable charsets (utf8mb4, utf8, gbk, big5, gb18030) and panics otherwise. +func checkMySQLChineseSupport(db *gorm.DB) error { + // 仅检测:当前库默认字符集/排序规则 + 各表的排序规则(隐含字符集) + + // Read current schema defaults + var schemaCharset, schemaCollation string + err := db.Raw("SELECT DEFAULT_CHARACTER_SET_NAME, DEFAULT_COLLATION_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = DATABASE()").Row().Scan(&schemaCharset, &schemaCollation) + if err != nil { + return fmt.Errorf("读取当前库默认字符集/排序规则失败 / Failed to read schema default charset/collation: %v", err) + } + + toLower := func(s string) string { return strings.ToLower(s) } + // Allowed charsets that can store Chinese text + allowedCharsets := map[string]string{ + "utf8mb4": "utf8mb4_", + "utf8": "utf8_", + "gbk": "gbk_", + "big5": "big5_", + "gb18030": "gb18030_", + } + isChineseCapable := func(cs, cl string) bool { + csLower := toLower(cs) + clLower := toLower(cl) + if prefix, ok := allowedCharsets[csLower]; ok { + if clLower == "" { + return true + } + return strings.HasPrefix(clLower, prefix) + } + // 如果仅提供了排序规则,尝试按排序规则前缀判断 + for _, prefix := range allowedCharsets { + if strings.HasPrefix(clLower, prefix) { + return true + } + } + return false + } + + // 1) 当前库默认值必须支持中文 + if !isChineseCapable(schemaCharset, schemaCollation) { + return fmt.Errorf("当前库默认字符集/排序规则不支持中文:schema(%s/%s)。请将库设置为 utf8mb4/utf8/gbk/big5/gb18030 / Schema default charset/collation is not Chinese-capable: schema(%s/%s). Please set to utf8mb4/utf8/gbk/big5/gb18030", + schemaCharset, schemaCollation, schemaCharset, schemaCollation) + } + + // 2) 所有物理表的排序规则(隐含字符集)必须支持中文 + type tableInfo struct { + Name string + Collation *string + } + var tables []tableInfo + if err := db.Raw("SELECT TABLE_NAME, TABLE_COLLATION FROM information_schema.TABLES WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE'").Scan(&tables).Error; err != nil { + return fmt.Errorf("读取表排序规则失败 / Failed to read table collations: %v", err) + } + + var badTables []string + for _, t := range tables { + // NULL 或空表示继承库默认设置,已在上面校验库默认,视为通过 + if t.Collation == nil || *t.Collation == "" { + continue + } + cl := *t.Collation + // 仅凭排序规则判断是否中文可用 + ok := false + lower := strings.ToLower(cl) + for _, prefix := range allowedCharsets { + if strings.HasPrefix(lower, prefix) { + ok = true + break + } + } + if !ok { + badTables = append(badTables, fmt.Sprintf("%s(%s)", t.Name, cl)) + } + } + + if len(badTables) > 0 { + // 限制输出数量以避免日志过长 + maxShow := 20 + shown := badTables + if len(shown) > maxShow { + shown = shown[:maxShow] + } + return fmt.Errorf( + "存在不支持中文的表,请修复其排序规则/字符集。示例(最多展示 %d 项):%v / Found tables not Chinese-capable. Please fix their collation/charset. Examples (showing up to %d): %v", + maxShow, shown, maxShow, shown, + ) + } + return nil +} + var ( lastPingTime time.Time pingMutex sync.Mutex diff --git a/model/missing_models.go b/model/missing_models.go new file mode 100644 index 000000000..18191ba68 --- /dev/null +++ b/model/missing_models.go @@ -0,0 +1,30 @@ +package model + +// GetMissingModels returns model names that are referenced in the system +func GetMissingModels() ([]string, error) { + // 1. 获取所有已启用模型(去重) + models := GetEnabledModels() + if len(models) == 0 { + return []string{}, nil + } + + // 2. 查询已有的元数据模型名 + var existing []string + if err := DB.Model(&Model{}).Where("model_name IN ?", models).Pluck("model_name", &existing).Error; err != nil { + return nil, err + } + + existingSet := make(map[string]struct{}, len(existing)) + for _, e := range existing { + existingSet[e] = struct{}{} + } + + // 3. 收集缺失模型 + var missing []string + for _, name := range models { + if _, ok := existingSet[name]; !ok { + missing = append(missing, name) + } + } + return missing, nil +} diff --git a/model/model_extra.go b/model/model_extra.go new file mode 100644 index 000000000..71fd84e7b --- /dev/null +++ b/model/model_extra.go @@ -0,0 +1,31 @@ +package model + +func GetModelEnableGroups(modelName string) []string { + // 确保缓存最新 + GetPricing() + + if modelName == "" { + return make([]string, 0) + } + + modelEnableGroupsLock.RLock() + groups, ok := modelEnableGroups[modelName] + modelEnableGroupsLock.RUnlock() + if !ok { + return make([]string, 0) + } + return groups +} + +// GetModelQuotaTypes 返回指定模型的计费类型集合(来自缓存) +func GetModelQuotaTypes(modelName string) []int { + GetPricing() + + modelEnableGroupsLock.RLock() + quota, ok := modelQuotaTypeMap[modelName] + modelEnableGroupsLock.RUnlock() + if !ok { + return []int{} + } + return []int{quota} +} diff --git a/model/model_meta.go b/model/model_meta.go new file mode 100644 index 000000000..e41cbd090 --- /dev/null +++ b/model/model_meta.go @@ -0,0 +1,147 @@ +package model + +import ( + "one-api/common" + "strconv" + + "gorm.io/gorm" +) + +const ( + NameRuleExact = iota + NameRulePrefix + NameRuleContains + NameRuleSuffix +) + +type BoundChannel struct { + Name string `json:"name"` + Type int `json:"type"` +} + +type Model struct { + Id int `json:"id"` + ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name_delete_at,priority:1"` + Description string `json:"description,omitempty" gorm:"type:text"` + Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"` + Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"` + VendorID int `json:"vendor_id,omitempty" gorm:"index"` + Endpoints string `json:"endpoints,omitempty" gorm:"type:text"` + Status int `json:"status" gorm:"default:1"` + SyncOfficial int `json:"sync_official" gorm:"default:1"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + UpdatedTime int64 `json:"updated_time" gorm:"bigint"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name_delete_at,priority:2"` + + BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"` + EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"` + QuotaTypes []int `json:"quota_types,omitempty" gorm:"-"` + NameRule int `json:"name_rule" gorm:"default:0"` + + MatchedModels []string `json:"matched_models,omitempty" gorm:"-"` + MatchedCount int `json:"matched_count,omitempty" gorm:"-"` +} + +func (mi *Model) Insert() error { + now := common.GetTimestamp() + mi.CreatedTime = now + mi.UpdatedTime = now + return DB.Create(mi).Error +} + +func IsModelNameDuplicated(id int, name string) (bool, error) { + if name == "" { + return false, nil + } + var cnt int64 + err := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id).Count(&cnt).Error + return cnt > 0, err +} + +func (mi *Model) Update() error { + mi.UpdatedTime = common.GetTimestamp() + return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}). + Model(&Model{}). + Where("id = ?", mi.Id). + Omit("created_time"). + Select("*"). + Updates(mi).Error +} + +func (mi *Model) Delete() error { + return DB.Delete(mi).Error +} + +func GetVendorModelCounts() (map[int64]int64, error) { + var stats []struct { + VendorID int64 + Count int64 + } + if err := DB.Model(&Model{}). + Select("vendor_id as vendor_id, count(*) as count"). + Group("vendor_id"). + Scan(&stats).Error; err != nil { + return nil, err + } + m := make(map[int64]int64, len(stats)) + for _, s := range stats { + m[s.VendorID] = s.Count + } + return m, nil +} + +func GetAllModels(offset int, limit int) ([]*Model, error) { + var models []*Model + err := DB.Order("id DESC").Offset(offset).Limit(limit).Find(&models).Error + return models, err +} + +func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel, error) { + result := make(map[string][]BoundChannel) + if len(modelNames) == 0 { + return result, nil + } + type row struct { + Model string + Name string + Type int + } + var rows []row + err := DB.Table("channels"). + Select("abilities.model as model, channels.name as name, channels.type as type"). + Joins("JOIN abilities ON abilities.channel_id = channels.id"). + Where("abilities.model IN ? AND abilities.enabled = ?", modelNames, true). + Distinct(). + Scan(&rows).Error + if err != nil { + return nil, err + } + for _, r := range rows { + result[r.Model] = append(result[r.Model], BoundChannel{Name: r.Name, Type: r.Type}) + } + return result, nil +} + +func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) { + var models []*Model + db := DB.Model(&Model{}) + if keyword != "" { + like := "%" + keyword + "%" + db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like) + } + if vendor != "" { + if vid, err := strconv.Atoi(vendor); err == nil { + db = db.Where("models.vendor_id = ?", vid) + } else { + db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%") + } + } + var total int64 + if err := db.Count(&total).Error; err != nil { + return nil, 0, err + } + if err := db.Order("models.id DESC").Offset(offset).Limit(limit).Find(&models).Error; err != nil { + return nil, 0, err + } + return models, total, nil +} diff --git a/model/option.go b/model/option.go index 05b99b41a..2121710ce 100644 --- a/model/option.go +++ b/model/option.go @@ -150,7 +150,7 @@ func loadOptionsFromDatabase() { for _, option := range options { err := updateOptionMap(option.Key, option.Value) if err != nil { - common.SysError("failed to update option map: " + err.Error()) + common.SysLog("failed to update option map: " + err.Error()) } } } @@ -336,6 +336,8 @@ func updateOptionMap(key string, value string) (err error) { common.LinuxDOClientId = value case "LinuxDOClientSecret": common.LinuxDOClientSecret = value + case "LinuxDOMinimumTrustLevel": + common.LinuxDOMinimumTrustLevel, _ = strconv.Atoi(value) case "Footer": common.Footer = value case "SystemName": diff --git a/model/prefill_group.go b/model/prefill_group.go new file mode 100644 index 000000000..a21b76fe2 --- /dev/null +++ b/model/prefill_group.go @@ -0,0 +1,126 @@ +package model + +import ( + "database/sql/driver" + "encoding/json" + "one-api/common" + + "gorm.io/gorm" +) + +// PrefillGroup 用于存储可复用的“组”信息,例如模型组、标签组、端点组等。 +// Name 字段保持唯一,用于在前端下拉框中展示。 +// Type 字段用于区分组的类别,可选值如:model、tag、endpoint。 +// Items 字段使用 JSON 数组保存对应类型的字符串集合,示例: +// ["gpt-4o", "gpt-3.5-turbo"] +// 设计遵循 3NF,避免冗余,提供灵活扩展能力。 + +// JSONValue 基于 json.RawMessage 实现,支持从数据库的 []byte 和 string 两种类型读取 +type JSONValue json.RawMessage + +// Value 实现 driver.Valuer 接口,用于数据库写入 +func (j JSONValue) Value() (driver.Value, error) { + if j == nil { + return nil, nil + } + return []byte(j), nil +} + +// Scan 实现 sql.Scanner 接口,兼容不同驱动返回的类型 +func (j *JSONValue) Scan(value interface{}) error { + switch v := value.(type) { + case nil: + *j = nil + return nil + case []byte: + // 拷贝底层字节,避免保留底层缓冲区 + b := make([]byte, len(v)) + copy(b, v) + *j = JSONValue(b) + return nil + case string: + *j = JSONValue([]byte(v)) + return nil + default: + // 其他类型尝试序列化为 JSON + b, err := json.Marshal(v) + if err != nil { + return err + } + *j = JSONValue(b) + return nil + } +} + +// MarshalJSON 确保在对外编码时与 json.RawMessage 行为一致 +func (j JSONValue) MarshalJSON() ([]byte, error) { + if j == nil { + return []byte("null"), nil + } + return j, nil +} + +// UnmarshalJSON 确保在对外解码时与 json.RawMessage 行为一致 +func (j *JSONValue) UnmarshalJSON(data []byte) error { + if data == nil { + *j = nil + return nil + } + b := make([]byte, len(data)) + copy(b, data) + *j = JSONValue(b) + return nil +} + +type PrefillGroup struct { + Id int `json:"id"` + Name string `json:"name" gorm:"size:64;not null;uniqueIndex:uk_prefill_name,where:deleted_at IS NULL"` + Type string `json:"type" gorm:"size:32;index;not null"` + Items JSONValue `json:"items" gorm:"type:json"` + Description string `json:"description,omitempty" gorm:"type:varchar(255)"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + UpdatedTime int64 `json:"updated_time" gorm:"bigint"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` +} + +// Insert 新建组 +func (g *PrefillGroup) Insert() error { + now := common.GetTimestamp() + g.CreatedTime = now + g.UpdatedTime = now + return DB.Create(g).Error +} + +// IsPrefillGroupNameDuplicated 检查组名称是否重复(排除自身 ID) +func IsPrefillGroupNameDuplicated(id int, name string) (bool, error) { + if name == "" { + return false, nil + } + var cnt int64 + err := DB.Model(&PrefillGroup{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error + return cnt > 0, err +} + +// Update 更新组 +func (g *PrefillGroup) Update() error { + g.UpdatedTime = common.GetTimestamp() + return DB.Save(g).Error +} + +// DeleteByID 根据 ID 删除组 +func DeletePrefillGroupByID(id int) error { + return DB.Delete(&PrefillGroup{}, id).Error +} + +// GetAllPrefillGroups 获取全部组,可按类型过滤(为空则返回全部) +func GetAllPrefillGroups(groupType string) ([]*PrefillGroup, error) { + var groups []*PrefillGroup + query := DB.Model(&PrefillGroup{}) + if groupType != "" { + query = query.Where("type = ?", groupType) + } + if err := query.Order("updated_time DESC").Find(&groups).Error; err != nil { + return nil, err + } + return groups, nil +} diff --git a/model/pricing.go b/model/pricing.go index a280b5246..c1192a3d9 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -1,7 +1,10 @@ package model import ( + "encoding/json" "fmt" + "strings" + "one-api/common" "one-api/constant" "one-api/setting/ratio_setting" @@ -12,6 +15,10 @@ import ( type Pricing struct { ModelName string `json:"model_name"` + Description string `json:"description,omitempty"` + Icon string `json:"icon,omitempty"` + Tags string `json:"tags,omitempty"` + VendorID int `json:"vendor_id,omitempty"` QuotaType int `json:"quota_type"` ModelRatio float64 `json:"model_ratio"` ModelPrice float64 `json:"model_price"` @@ -21,10 +28,24 @@ type Pricing struct { SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"` } +type PricingVendor struct { + ID int `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Icon string `json:"icon,omitempty"` +} + var ( - pricingMap []Pricing - lastGetPricingTime time.Time - updatePricingLock sync.Mutex + pricingMap []Pricing + vendorsList []PricingVendor + supportedEndpointMap map[string]common.EndpointInfo + lastGetPricingTime time.Time + updatePricingLock sync.Mutex + + // 缓存映射:模型名 -> 启用分组 / 计费类型 + modelEnableGroups = make(map[string][]string) + modelQuotaTypeMap = make(map[string]int) + modelEnableGroupsLock = sync.RWMutex{} ) var ( @@ -46,6 +67,15 @@ func GetPricing() []Pricing { return pricingMap } +// GetVendors 返回当前定价接口使用到的供应商信息 +func GetVendors() []PricingVendor { + if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { + // 保证先刷新一次 + GetPricing() + } + return vendorsList +} + func GetModelSupportEndpointTypes(model string) []constant.EndpointType { if model == "" { return make([]constant.EndpointType, 0) @@ -62,9 +92,83 @@ func updatePricing() { //modelRatios := common.GetModelRatios() enableAbilities, err := GetAllEnableAbilityWithChannels() if err != nil { - common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) + common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) return } + // 预加载模型元数据与供应商一次,避免循环查询 + var allMeta []Model + _ = DB.Find(&allMeta).Error + metaMap := make(map[string]*Model) + prefixList := make([]*Model, 0) + suffixList := make([]*Model, 0) + containsList := make([]*Model, 0) + for i := range allMeta { + m := &allMeta[i] + if m.NameRule == NameRuleExact { + metaMap[m.ModelName] = m + } else { + switch m.NameRule { + case NameRulePrefix: + prefixList = append(prefixList, m) + case NameRuleSuffix: + suffixList = append(suffixList, m) + case NameRuleContains: + containsList = append(containsList, m) + } + } + } + + // 将非精确规则模型匹配到 metaMap + for _, m := range prefixList { + for _, pricingModel := range enableAbilities { + if strings.HasPrefix(pricingModel.Model, m.ModelName) { + if _, exists := metaMap[pricingModel.Model]; !exists { + metaMap[pricingModel.Model] = m + } + } + } + } + for _, m := range suffixList { + for _, pricingModel := range enableAbilities { + if strings.HasSuffix(pricingModel.Model, m.ModelName) { + if _, exists := metaMap[pricingModel.Model]; !exists { + metaMap[pricingModel.Model] = m + } + } + } + } + for _, m := range containsList { + for _, pricingModel := range enableAbilities { + if strings.Contains(pricingModel.Model, m.ModelName) { + if _, exists := metaMap[pricingModel.Model]; !exists { + metaMap[pricingModel.Model] = m + } + } + } + } + + // 预加载供应商 + var vendors []Vendor + _ = DB.Find(&vendors).Error + vendorMap := make(map[int]*Vendor) + for i := range vendors { + vendorMap[vendors[i].Id] = &vendors[i] + } + + // 初始化默认供应商映射 + initDefaultVendorMapping(metaMap, vendorMap, enableAbilities) + + // 构建对前端友好的供应商列表 + vendorsList = make([]PricingVendor, 0, len(vendorMap)) + for _, v := range vendorMap { + vendorsList = append(vendorsList, PricingVendor{ + ID: v.Id, + Name: v.Name, + Description: v.Description, + Icon: v.Icon, + }) + } + modelGroupsMap := make(map[string]*types.Set[string]) for _, ability := range enableAbilities { @@ -79,12 +183,9 @@ func updatePricing() { //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点 modelSupportEndpointsStr := make(map[string][]string) + // 先根据已有能力填充原生端点 for _, ability := range enableAbilities { - endpoints, ok := modelSupportEndpointsStr[ability.Model] - if !ok { - endpoints = make([]string, 0) - modelSupportEndpointsStr[ability.Model] = endpoints - } + endpoints := modelSupportEndpointsStr[ability.Model] channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model) for _, channelType := range channelTypes { if !common.StringsContains(endpoints, string(channelType)) { @@ -94,6 +195,23 @@ func updatePricing() { modelSupportEndpointsStr[ability.Model] = endpoints } + // 再补充模型自定义端点 + for modelName, meta := range metaMap { + if strings.TrimSpace(meta.Endpoints) == "" { + continue + } + var raw map[string]interface{} + if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { + endpoints := modelSupportEndpointsStr[modelName] + for k := range raw { + if !common.StringsContains(endpoints, k) { + endpoints = append(endpoints, k) + } + } + modelSupportEndpointsStr[modelName] = endpoints + } + } + modelSupportEndpointTypes = make(map[string][]constant.EndpointType) for model, endpoints := range modelSupportEndpointsStr { supportedEndpoints := make([]constant.EndpointType, 0) @@ -104,6 +222,45 @@ func updatePricing() { modelSupportEndpointTypes[model] = supportedEndpoints } + // 构建全局 supportedEndpointMap(默认 + 自定义覆盖) + supportedEndpointMap = make(map[string]common.EndpointInfo) + // 1. 默认端点 + for _, endpoints := range modelSupportEndpointTypes { + for _, et := range endpoints { + if info, ok := common.GetDefaultEndpointInfo(et); ok { + if _, exists := supportedEndpointMap[string(et)]; !exists { + supportedEndpointMap[string(et)] = info + } + } + } + } + // 2. 自定义端点(models 表)覆盖默认 + for _, meta := range metaMap { + if strings.TrimSpace(meta.Endpoints) == "" { + continue + } + var raw map[string]interface{} + if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { + for k, v := range raw { + switch val := v.(type) { + case string: + supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"} + case map[string]interface{}: + ep := common.EndpointInfo{Method: "POST"} + if p, ok := val["path"].(string); ok { + ep.Path = p + } + if m, ok := val["method"].(string); ok { + ep.Method = strings.ToUpper(m) + } + supportedEndpointMap[k] = ep + default: + // ignore unsupported types + } + } + } + } + pricingMap = make([]Pricing, 0) for model, groups := range modelGroupsMap { pricing := Pricing{ @@ -111,6 +268,18 @@ func updatePricing() { EnableGroup: groups.Items(), SupportedEndpointTypes: modelSupportEndpointTypes[model], } + + // 补充模型元数据(描述、标签、供应商、状态) + if meta, ok := metaMap[model]; ok { + // 若模型被禁用(status!=1),则直接跳过,不返回给前端 + if meta.Status != 1 { + continue + } + pricing.Description = meta.Description + pricing.Icon = meta.Icon + pricing.Tags = meta.Tags + pricing.VendorID = meta.VendorID + } modelPrice, findPrice := ratio_setting.GetModelPrice(model, false) if findPrice { pricing.ModelPrice = modelPrice @@ -123,5 +292,21 @@ func updatePricing() { } pricingMap = append(pricingMap, pricing) } + + // 刷新缓存映射,供高并发快速查询 + modelEnableGroupsLock.Lock() + modelEnableGroups = make(map[string][]string) + modelQuotaTypeMap = make(map[string]int) + for _, p := range pricingMap { + modelEnableGroups[p.ModelName] = p.EnableGroup + modelQuotaTypeMap[p.ModelName] = p.QuotaType + } + modelEnableGroupsLock.Unlock() + lastGetPricingTime = time.Now() } + +// GetSupportedEndpointMap 返回全局端点到路径的映射 +func GetSupportedEndpointMap() map[string]common.EndpointInfo { + return supportedEndpointMap +} diff --git a/model/pricing_default.go b/model/pricing_default.go new file mode 100644 index 000000000..db64cafbb --- /dev/null +++ b/model/pricing_default.go @@ -0,0 +1,128 @@ +package model + +import ( + "strings" +) + +// 简化的供应商映射规则 +var defaultVendorRules = map[string]string{ + "gpt": "OpenAI", + "dall-e": "OpenAI", + "whisper": "OpenAI", + "o1": "OpenAI", + "o3": "OpenAI", + "claude": "Anthropic", + "gemini": "Google", + "moonshot": "Moonshot", + "kimi": "Moonshot", + "chatglm": "智谱", + "glm-": "智谱", + "qwen": "阿里巴巴", + "deepseek": "DeepSeek", + "abab": "MiniMax", + "ernie": "百度", + "spark": "讯飞", + "hunyuan": "腾讯", + "command": "Cohere", + "@cf/": "Cloudflare", + "360": "360", + "yi": "零一万物", + "jina": "Jina", + "mistral": "Mistral", + "grok": "xAI", + "llama": "Meta", + "doubao": "字节跳动", + "kling": "快手", + "jimeng": "即梦", + "vidu": "Vidu", +} + +// 供应商默认图标映射 +var defaultVendorIcons = map[string]string{ + "OpenAI": "OpenAI", + "Anthropic": "Claude.Color", + "Google": "Gemini.Color", + "Moonshot": "Moonshot", + "智谱": "Zhipu.Color", + "阿里巴巴": "Qwen.Color", + "DeepSeek": "DeepSeek.Color", + "MiniMax": "Minimax.Color", + "百度": "Wenxin.Color", + "讯飞": "Spark.Color", + "腾讯": "Hunyuan.Color", + "Cohere": "Cohere.Color", + "Cloudflare": "Cloudflare.Color", + "360": "Ai360.Color", + "零一万物": "Yi.Color", + "Jina": "Jina", + "Mistral": "Mistral.Color", + "xAI": "XAI", + "Meta": "Ollama", + "字节跳动": "Doubao.Color", + "快手": "Kling.Color", + "即梦": "Jimeng.Color", + "Vidu": "Vidu", + "微软": "AzureAI", + "Microsoft": "AzureAI", + "Azure": "AzureAI", +} + +// initDefaultVendorMapping 简化的默认供应商映射 +func initDefaultVendorMapping(metaMap map[string]*Model, vendorMap map[int]*Vendor, enableAbilities []AbilityWithChannel) { + for _, ability := range enableAbilities { + modelName := ability.Model + if _, exists := metaMap[modelName]; exists { + continue + } + + // 匹配供应商 + vendorID := 0 + modelLower := strings.ToLower(modelName) + for pattern, vendorName := range defaultVendorRules { + if strings.Contains(modelLower, pattern) { + vendorID = getOrCreateVendor(vendorName, vendorMap) + break + } + } + + // 创建模型元数据 + metaMap[modelName] = &Model{ + ModelName: modelName, + VendorID: vendorID, + Status: 1, + NameRule: NameRuleExact, + } + } +} + +// 查找或创建供应商 +func getOrCreateVendor(vendorName string, vendorMap map[int]*Vendor) int { + // 查找现有供应商 + for id, vendor := range vendorMap { + if vendor.Name == vendorName { + return id + } + } + + // 创建新供应商 + newVendor := &Vendor{ + Name: vendorName, + Status: 1, + Icon: getDefaultVendorIcon(vendorName), + } + + if err := newVendor.Insert(); err != nil { + return 0 + } + + vendorMap[newVendor.Id] = newVendor + return newVendor.Id +} + +// 获取供应商默认图标 +func getDefaultVendorIcon(vendorName string) string { + if icon, exists := defaultVendorIcons[vendorName]; exists { + return icon + } + return "" +} diff --git a/model/pricing_refresh.go b/model/pricing_refresh.go new file mode 100644 index 000000000..cd0d75596 --- /dev/null +++ b/model/pricing_refresh.go @@ -0,0 +1,14 @@ +package model + +// RefreshPricing 强制立即重新计算与定价相关的缓存。 +// 该方法用于需要最新数据的内部管理 API, +// 因此会绕过默认的 1 分钟延迟刷新。 +func RefreshPricing() { + updatePricingLock.Lock() + defer updatePricingLock.Unlock() + + modelSupportEndpointsLock.Lock() + defer modelSupportEndpointsLock.Unlock() + + updatePricing() +} diff --git a/model/redemption.go b/model/redemption.go index bf2376685..1ab84f45c 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/logger" "strconv" "gorm.io/gorm" @@ -148,7 +149,7 @@ func Redeem(key string, userId int) (quota int, err error) { if err != nil { return 0, errors.New("兑换失败," + err.Error()) } - RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id)) + RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id)) return redemption.Quota, nil } diff --git a/model/task.go b/model/task.go index 9e4177ba0..4c64a5293 100644 --- a/model/task.go +++ b/model/task.go @@ -77,7 +77,7 @@ type SyncTaskQueryParams struct { UserIDs []int } -func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task { +func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) *Task { t := &Task{ UserId: relayInfo.UserId, SubmitTime: time.Now().Unix(), diff --git a/model/token.go b/model/token.go index e85a445ec..320b5cf04 100644 --- a/model/token.go +++ b/model/token.go @@ -91,7 +91,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExpired err := token.SelectUpdate() if err != nil { - common.SysError("failed to update token status" + err.Error()) + common.SysLog("failed to update token status" + err.Error()) } } return token, errors.New("该令牌已过期") @@ -102,7 +102,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExhausted err := token.SelectUpdate() if err != nil { - common.SysError("failed to update token status" + err.Error()) + common.SysLog("failed to update token status" + err.Error()) } } keyPrefix := key[:3] @@ -134,7 +134,7 @@ func GetTokenById(id int) (*Token, error) { if shouldUpdateRedis(true, err) { gopool.Go(func() { if err := cacheSetToken(token); err != nil { - common.SysError("failed to update user status cache: " + err.Error()) + common.SysLog("failed to update user status cache: " + err.Error()) } }) } @@ -147,7 +147,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) { if shouldUpdateRedis(fromDB, err) && token != nil { gopool.Go(func() { if err := cacheSetToken(*token); err != nil { - common.SysError("failed to update user status cache: " + err.Error()) + common.SysLog("failed to update user status cache: " + err.Error()) } }) } @@ -178,7 +178,7 @@ func (token *Token) Update() (err error) { gopool.Go(func() { err := cacheSetToken(*token) if err != nil { - common.SysError("failed to update token cache: " + err.Error()) + common.SysLog("failed to update token cache: " + err.Error()) } }) } @@ -194,7 +194,7 @@ func (token *Token) SelectUpdate() (err error) { gopool.Go(func() { err := cacheSetToken(*token) if err != nil { - common.SysError("failed to update token cache: " + err.Error()) + common.SysLog("failed to update token cache: " + err.Error()) } }) } @@ -209,7 +209,7 @@ func (token *Token) Delete() (err error) { gopool.Go(func() { err := cacheDeleteToken(token.Key) if err != nil { - common.SysError("failed to delete token cache: " + err.Error()) + common.SysLog("failed to delete token cache: " + err.Error()) } }) } @@ -269,7 +269,7 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) { gopool.Go(func() { err := cacheIncrTokenQuota(key, int64(quota)) if err != nil { - common.SysError("failed to increase token quota: " + err.Error()) + common.SysLog("failed to increase token quota: " + err.Error()) } }) } @@ -299,7 +299,7 @@ func DecreaseTokenQuota(id int, key string, quota int) (err error) { gopool.Go(func() { err := cacheDecrTokenQuota(key, int64(quota)) if err != nil { - common.SysError("failed to decrease token quota: " + err.Error()) + common.SysLog("failed to decrease token quota: " + err.Error()) } }) } diff --git a/model/topup.go b/model/topup.go index c34c0ce62..802c866f7 100644 --- a/model/topup.go +++ b/model/topup.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/logger" "gorm.io/gorm" ) @@ -94,7 +95,7 @@ func Recharge(referenceId string, customerId string) (err error) { return errors.New("充值失败," + err.Error()) } - RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.FormatQuota(int(quota)), topUp.Amount)) + RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount)) return nil } diff --git a/model/twofa.go b/model/twofa.go new file mode 100644 index 000000000..2a3d33530 --- /dev/null +++ b/model/twofa.go @@ -0,0 +1,322 @@ +package model + +import ( + "errors" + "fmt" + "one-api/common" + "time" + + "gorm.io/gorm" +) + +var ErrTwoFANotEnabled = errors.New("用户未启用2FA") + +// TwoFA 用户2FA设置表 +type TwoFA struct { + Id int `json:"id" gorm:"primaryKey"` + UserId int `json:"user_id" gorm:"unique;not null;index"` + Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥,不返回给前端 + IsEnabled bool `json:"is_enabled"` + FailedAttempts int `json:"failed_attempts" gorm:"default:0"` + LockedUntil *time.Time `json:"locked_until,omitempty"` + LastUsedAt *time.Time `json:"last_used_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` +} + +// TwoFABackupCode 备用码使用记录表 +type TwoFABackupCode struct { + Id int `json:"id" gorm:"primaryKey"` + UserId int `json:"user_id" gorm:"not null;index"` + CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希 + IsUsed bool `json:"is_used"` + UsedAt *time.Time `json:"used_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` +} + +// GetTwoFAByUserId 根据用户ID获取2FA设置 +func GetTwoFAByUserId(userId int) (*TwoFA, error) { + if userId == 0 { + return nil, errors.New("用户ID不能为空") + } + + var twoFA TwoFA + err := DB.Where("user_id = ?", userId).First(&twoFA).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil // 返回nil表示未设置2FA + } + return nil, err + } + + return &twoFA, nil +} + +// IsTwoFAEnabled 检查用户是否启用了2FA +func IsTwoFAEnabled(userId int) bool { + twoFA, err := GetTwoFAByUserId(userId) + if err != nil || twoFA == nil { + return false + } + return twoFA.IsEnabled +} + +// CreateTwoFA 创建2FA设置 +func (t *TwoFA) Create() error { + // 检查用户是否已存在2FA设置 + existing, err := GetTwoFAByUserId(t.UserId) + if err != nil { + return err + } + if existing != nil { + return errors.New("用户已存在2FA设置") + } + + // 验证用户存在 + var user User + if err := DB.First(&user, t.UserId).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("用户不存在") + } + return err + } + + return DB.Create(t).Error +} + +// Update 更新2FA设置 +func (t *TwoFA) Update() error { + if t.Id == 0 { + return errors.New("2FA记录ID不能为空") + } + return DB.Save(t).Error +} + +// Delete 删除2FA设置 +func (t *TwoFA) Delete() error { + if t.Id == 0 { + return errors.New("2FA记录ID不能为空") + } + + // 使用事务确保原子性 + return DB.Transaction(func(tx *gorm.DB) error { + // 同时删除相关的备用码记录(硬删除) + if err := tx.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil { + return err + } + + // 硬删除2FA记录 + return tx.Unscoped().Delete(t).Error + }) +} + +// ResetFailedAttempts 重置失败尝试次数 +func (t *TwoFA) ResetFailedAttempts() error { + t.FailedAttempts = 0 + t.LockedUntil = nil + return t.Update() +} + +// IncrementFailedAttempts 增加失败尝试次数 +func (t *TwoFA) IncrementFailedAttempts() error { + t.FailedAttempts++ + + // 检查是否需要锁定 + if t.FailedAttempts >= common.MaxFailAttempts { + lockUntil := time.Now().Add(time.Duration(common.LockoutDuration) * time.Second) + t.LockedUntil = &lockUntil + } + + return t.Update() +} + +// IsLocked 检查账户是否被锁定 +func (t *TwoFA) IsLocked() bool { + if t.LockedUntil == nil { + return false + } + return time.Now().Before(*t.LockedUntil) +} + +// CreateBackupCodes 创建备用码 +func CreateBackupCodes(userId int, codes []string) error { + return DB.Transaction(func(tx *gorm.DB) error { + // 先删除现有的备用码 + if err := tx.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil { + return err + } + + // 创建新的备用码记录 + for _, code := range codes { + hashedCode, err := common.HashBackupCode(code) + if err != nil { + return err + } + + backupCode := TwoFABackupCode{ + UserId: userId, + CodeHash: hashedCode, + IsUsed: false, + } + + if err := tx.Create(&backupCode).Error; err != nil { + return err + } + } + + return nil + }) +} + +// ValidateBackupCode 验证并使用备用码 +func ValidateBackupCode(userId int, code string) (bool, error) { + if !common.ValidateBackupCode(code) { + return false, errors.New("验证码或备用码不正确") + } + + normalizedCode := common.NormalizeBackupCode(code) + + // 查找未使用的备用码 + var backupCodes []TwoFABackupCode + if err := DB.Where("user_id = ? AND is_used = false", userId).Find(&backupCodes).Error; err != nil { + return false, err + } + + // 验证备用码 + for _, bc := range backupCodes { + if common.ValidatePasswordAndHash(normalizedCode, bc.CodeHash) { + // 标记为已使用 + now := time.Now() + bc.IsUsed = true + bc.UsedAt = &now + + if err := DB.Save(&bc).Error; err != nil { + return false, err + } + + return true, nil + } + } + + return false, nil +} + +// GetUnusedBackupCodeCount 获取未使用的备用码数量 +func GetUnusedBackupCodeCount(userId int) (int, error) { + var count int64 + err := DB.Model(&TwoFABackupCode{}).Where("user_id = ? AND is_used = false", userId).Count(&count).Error + return int(count), err +} + +// DisableTwoFA 禁用用户的2FA +func DisableTwoFA(userId int) error { + twoFA, err := GetTwoFAByUserId(userId) + if err != nil { + return err + } + if twoFA == nil { + return ErrTwoFANotEnabled + } + + // 删除2FA设置和备用码 + return twoFA.Delete() +} + +// EnableTwoFA 启用2FA +func (t *TwoFA) Enable() error { + t.IsEnabled = true + t.FailedAttempts = 0 + t.LockedUntil = nil + return t.Update() +} + +// ValidateTOTPAndUpdateUsage 验证TOTP并更新使用记录 +func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) { + // 检查是否被锁定 + if t.IsLocked() { + return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05")) + } + + // 验证TOTP码 + if !common.ValidateTOTPCode(t.Secret, code) { + // 增加失败次数 + if err := t.IncrementFailedAttempts(); err != nil { + common.SysLog("更新2FA失败次数失败: " + err.Error()) + } + return false, nil + } + + // 验证成功,重置失败次数并更新最后使用时间 + now := time.Now() + t.FailedAttempts = 0 + t.LockedUntil = nil + t.LastUsedAt = &now + + if err := t.Update(); err != nil { + common.SysLog("更新2FA使用记录失败: " + err.Error()) + } + + return true, nil +} + +// ValidateBackupCodeAndUpdateUsage 验证备用码并更新使用记录 +func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) { + // 检查是否被锁定 + if t.IsLocked() { + return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05")) + } + + // 验证备用码 + valid, err := ValidateBackupCode(t.UserId, code) + if err != nil { + return false, err + } + + if !valid { + // 增加失败次数 + if err := t.IncrementFailedAttempts(); err != nil { + common.SysLog("更新2FA失败次数失败: " + err.Error()) + } + return false, nil + } + + // 验证成功,重置失败次数并更新最后使用时间 + now := time.Now() + t.FailedAttempts = 0 + t.LockedUntil = nil + t.LastUsedAt = &now + + if err := t.Update(); err != nil { + common.SysLog("更新2FA使用记录失败: " + err.Error()) + } + + return true, nil +} + +// GetTwoFAStats 获取2FA统计信息(管理员使用) +func GetTwoFAStats() (map[string]interface{}, error) { + var totalUsers, enabledUsers int64 + + // 总用户数 + if err := DB.Model(&User{}).Count(&totalUsers).Error; err != nil { + return nil, err + } + + // 启用2FA的用户数 + if err := DB.Model(&TwoFA{}).Where("is_enabled = true").Count(&enabledUsers).Error; err != nil { + return nil, err + } + + enabledRate := float64(0) + if totalUsers > 0 { + enabledRate = float64(enabledUsers) / float64(totalUsers) * 100 + } + + return map[string]interface{}{ + "total_users": totalUsers, + "enabled_users": enabledUsers, + "enabled_rate": fmt.Sprintf("%.1f%%", enabledRate), + }, nil +} diff --git a/model/usedata.go b/model/usedata.go index 1255b0bed..7e525d2e1 100644 --- a/model/usedata.go +++ b/model/usedata.go @@ -21,12 +21,6 @@ type QuotaData struct { } func UpdateQuotaData() { - // recover - defer func() { - if r := recover(); r != nil { - common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r)) - } - }() for { if common.DataExportEnabled { common.SysLog("正在更新数据看板数据...") diff --git a/model/user.go b/model/user.go index 6021f495c..ea0584c5a 100644 --- a/model/user.go +++ b/model/user.go @@ -6,6 +6,7 @@ import ( "fmt" "one-api/common" "one-api/dto" + "one-api/logger" "strconv" "strings" @@ -75,7 +76,7 @@ func (user *User) GetSetting() dto.UserSetting { if user.Setting != "" { err := json.Unmarshal([]byte(user.Setting), &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog("failed to unmarshal setting: " + err.Error()) } } return setting @@ -84,12 +85,74 @@ func (user *User) GetSetting() dto.UserSetting { func (user *User) SetSetting(setting dto.UserSetting) { settingBytes, err := json.Marshal(setting) if err != nil { - common.SysError("failed to marshal setting: " + err.Error()) + common.SysLog("failed to marshal setting: " + err.Error()) return } user.Setting = string(settingBytes) } +// 根据用户角色生成默认的边栏配置 +func generateDefaultSidebarConfigForRole(userRole int) string { + defaultConfig := map[string]interface{}{} + + // 聊天区域 - 所有用户都可以访问 + defaultConfig["chat"] = map[string]interface{}{ + "enabled": true, + "playground": true, + "chat": true, + } + + // 控制台区域 - 所有用户都可以访问 + defaultConfig["console"] = map[string]interface{}{ + "enabled": true, + "detail": true, + "token": true, + "log": true, + "midjourney": true, + "task": true, + } + + // 个人中心区域 - 所有用户都可以访问 + defaultConfig["personal"] = map[string]interface{}{ + "enabled": true, + "topup": true, + "personal": true, + } + + // 管理员区域 - 根据角色决定 + if userRole == common.RoleAdminUser { + // 管理员可以访问管理员区域,但不能访问系统设置 + defaultConfig["admin"] = map[string]interface{}{ + "enabled": true, + "channel": true, + "models": true, + "redemption": true, + "user": true, + "setting": false, // 管理员不能访问系统设置 + } + } else if userRole == common.RoleRootUser { + // 超级管理员可以访问所有功能 + defaultConfig["admin"] = map[string]interface{}{ + "enabled": true, + "channel": true, + "models": true, + "redemption": true, + "user": true, + "setting": true, + } + } + // 普通用户不包含admin区域 + + // 转换为JSON字符串 + configBytes, err := json.Marshal(defaultConfig) + if err != nil { + common.SysLog("生成默认边栏配置失败: " + err.Error()) + return "" + } + + return string(configBytes) +} + // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil func CheckUserExistOrDeleted(username string, email string) (bool, error) { var user User @@ -274,7 +337,7 @@ func inviteUser(inviterId int) (err error) { func (user *User) TransferAffQuotaToQuota(quota int) error { // 检查quota是否小于最小额度 if float64(quota) < common.QuotaPerUnit { - return fmt.Errorf("转移额度最小为%s!", common.LogQuota(int(common.QuotaPerUnit))) + return fmt.Errorf("转移额度最小为%s!", logger.LogQuota(int(common.QuotaPerUnit))) } // 开始数据库事务 @@ -319,21 +382,45 @@ func (user *User) Insert(inviterId int) error { user.Quota = common.QuotaForNewUser //user.SetAccessToken(common.GetUUID()) user.AffCode = common.GetRandomString(4) + + // 初始化用户设置,包括默认的边栏配置 + if user.Setting == "" { + defaultSetting := dto.UserSetting{} + // 这里暂时不设置SidebarModules,因为需要在用户创建后根据角色设置 + user.SetSetting(defaultSetting) + } + result := DB.Create(user) if result.Error != nil { return result.Error } + + // 用户创建成功后,根据角色初始化边栏配置 + // 需要重新获取用户以确保有正确的ID和Role + var createdUser User + if err := DB.Where("username = ?", user.Username).First(&createdUser).Error; err == nil { + // 生成基于角色的默认边栏配置 + defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role) + if defaultSidebarConfig != "" { + currentSetting := createdUser.GetSetting() + currentSetting.SidebarModules = defaultSidebarConfig + createdUser.SetSetting(currentSetting) + createdUser.Update(false) + common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role)) + } + } + if common.QuotaForNewUser > 0 { - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser))) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser))) } if inviterId != 0 { if common.QuotaForInvitee > 0 { _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true) - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee))) } if common.QuotaForInviter > 0 { //_ = IncreaseUserQuota(inviterId, common.QuotaForInviter) - RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter))) + RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter))) _ = inviteUser(inviterId) } } @@ -517,7 +604,7 @@ func IsAdmin(userId int) bool { var user User err := DB.Where("id = ?", userId).Select("role").Find(&user).Error if err != nil { - common.SysError("no such user " + err.Error()) + common.SysLog("no such user " + err.Error()) return false } return user.Role >= common.RoleAdminUser @@ -572,7 +659,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserQuotaCache(id, quota); err != nil { - common.SysError("failed to update user quota cache: " + err.Error()) + common.SysLog("failed to update user quota cache: " + err.Error()) } }) } @@ -610,7 +697,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserGroupCache(id, group); err != nil { - common.SysError("failed to update user group cache: " + err.Error()) + common.SysLog("failed to update user group cache: " + err.Error()) } }) } @@ -639,7 +726,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserSettingCache(id, setting); err != nil { - common.SysError("failed to update user setting cache: " + err.Error()) + common.SysLog("failed to update user setting cache: " + err.Error()) } }) } @@ -669,7 +756,7 @@ func IncreaseUserQuota(id int, quota int, db bool) (err error) { gopool.Go(func() { err := cacheIncrUserQuota(id, int64(quota)) if err != nil { - common.SysError("failed to increase user quota: " + err.Error()) + common.SysLog("failed to increase user quota: " + err.Error()) } }) if !db && common.BatchUpdateEnabled { @@ -694,7 +781,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { gopool.Go(func() { err := cacheDecrUserQuota(id, int64(quota)) if err != nil { - common.SysError("failed to decrease user quota: " + err.Error()) + common.SysLog("failed to decrease user quota: " + err.Error()) } }) if common.BatchUpdateEnabled { @@ -750,7 +837,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { }, ).Error if err != nil { - common.SysError("failed to update user used quota and request count: " + err.Error()) + common.SysLog("failed to update user used quota and request count: " + err.Error()) return } @@ -767,14 +854,14 @@ func updateUserUsedQuota(id int, quota int) { }, ).Error if err != nil { - common.SysError("failed to update user used quota: " + err.Error()) + common.SysLog("failed to update user used quota: " + err.Error()) } } func updateUserRequestCount(id int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error if err != nil { - common.SysError("failed to update user request count: " + err.Error()) + common.SysLog("failed to update user request count: " + err.Error()) } } @@ -785,7 +872,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserNameCache(id, username); err != nil { - common.SysError("failed to update user name cache: " + err.Error()) + common.SysLog("failed to update user name cache: " + err.Error()) } }) } diff --git a/model/user_cache.go b/model/user_cache.go index a631457c2..936e1a431 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -37,7 +37,7 @@ func (user *UserBase) GetSetting() dto.UserSetting { if user.Setting != "" { err := common.Unmarshal([]byte(user.Setting), &setting) if err != nil { - common.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog("failed to unmarshal setting: " + err.Error()) } } return setting @@ -78,7 +78,7 @@ func GetUserCache(userId int) (userCache *UserBase, err error) { if shouldUpdateRedis(fromDB, err) && user != nil { gopool.Go(func() { if err := updateUserCache(*user); err != nil { - common.SysError("failed to update user status cache: " + err.Error()) + common.SysLog("failed to update user status cache: " + err.Error()) } }) } diff --git a/model/utils.go b/model/utils.go index 1f8a09631..dced2bc61 100644 --- a/model/utils.go +++ b/model/utils.go @@ -77,12 +77,12 @@ func batchUpdate() { case BatchUpdateTypeUserQuota: err := increaseUserQuota(key, value) if err != nil { - common.SysError("failed to batch update user quota: " + err.Error()) + common.SysLog("failed to batch update user quota: " + err.Error()) } case BatchUpdateTypeTokenQuota: err := increaseTokenQuota(key, value) if err != nil { - common.SysError("failed to batch update token quota: " + err.Error()) + common.SysLog("failed to batch update token quota: " + err.Error()) } case BatchUpdateTypeUsedQuota: updateUserUsedQuota(key, value) diff --git a/model/vendor_meta.go b/model/vendor_meta.go new file mode 100644 index 000000000..20deaea9b --- /dev/null +++ b/model/vendor_meta.go @@ -0,0 +1,88 @@ +package model + +import ( + "one-api/common" + + "gorm.io/gorm" +) + +// Vendor 用于存储供应商信息,供模型引用 +// Name 唯一,用于在模型中关联 +// Icon 采用 @lobehub/icons 的图标名,前端可直接渲染 +// Status 预留字段,1 表示启用 +// 本表同样遵循 3NF 设计范式 + +type Vendor struct { + Id int `json:"id"` + Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name_delete_at,priority:1"` + Description string `json:"description,omitempty" gorm:"type:text"` + Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"` + Status int `json:"status" gorm:"default:1"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + UpdatedTime int64 `json:"updated_time" gorm:"bigint"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name_delete_at,priority:2"` +} + +// Insert 创建新的供应商记录 +func (v *Vendor) Insert() error { + now := common.GetTimestamp() + v.CreatedTime = now + v.UpdatedTime = now + return DB.Create(v).Error +} + +// IsVendorNameDuplicated 检查供应商名称是否重复(排除自身 ID) +func IsVendorNameDuplicated(id int, name string) (bool, error) { + if name == "" { + return false, nil + } + var cnt int64 + err := DB.Model(&Vendor{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error + return cnt > 0, err +} + +// Update 更新供应商记录 +func (v *Vendor) Update() error { + v.UpdatedTime = common.GetTimestamp() + return DB.Save(v).Error +} + +// Delete 软删除供应商 +func (v *Vendor) Delete() error { + return DB.Delete(v).Error +} + +// GetVendorByID 根据 ID 获取供应商 +func GetVendorByID(id int) (*Vendor, error) { + var v Vendor + err := DB.First(&v, id).Error + if err != nil { + return nil, err + } + return &v, nil +} + +// GetAllVendors 获取全部供应商(分页) +func GetAllVendors(offset int, limit int) ([]*Vendor, error) { + var vendors []*Vendor + err := DB.Offset(offset).Limit(limit).Find(&vendors).Error + return vendors, err +} + +// SearchVendors 按关键字搜索供应商 +func SearchVendors(keyword string, offset int, limit int) ([]*Vendor, int64, error) { + db := DB.Model(&Vendor{}) + if keyword != "" { + like := "%" + keyword + "%" + db = db.Where("name LIKE ? OR description LIKE ?", like, like) + } + var total int64 + if err := db.Count(&total).Error; err != nil { + return nil, 0, err + } + var vendors []*Vendor + if err := db.Offset(offset).Limit(limit).Order("id DESC").Find(&vendors).Error; err != nil { + return nil, 0, err + } + return vendors, total, nil +} diff --git a/relay/audio_handler.go b/relay/audio_handler.go index f39dbd823..711cc7a9b 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -7,104 +7,43 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/types" - "strings" "github.com/gin-gonic/gin" ) -func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) { - audioRequest := &dto.AudioRequest{} - err := common.UnmarshalBodyReusable(c, audioRequest) +func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) + + audioReq, ok := info.Request.(*dto.AudioRequest) + if !ok { + return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + + request, err := common.DeepCopy(audioReq) if err != nil { - return nil, err + return types.NewError(fmt.Errorf("failed to copy request to AudioRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } - switch info.RelayMode { - case relayconstant.RelayModeAudioSpeech: - if audioRequest.Model == "" { - return nil, errors.New("model is required") - } - if setting.ShouldCheckPromptSensitive() { - words, err := service.CheckSensitiveInput(audioRequest.Input) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ","))) - return nil, err - } - } - default: - err = c.Request.ParseForm() - if err != nil { - return nil, err - } - formData := c.Request.PostForm - if audioRequest.Model == "" { - audioRequest.Model = formData.Get("model") - } - - if audioRequest.Model == "" { - return nil, errors.New("model is required") - } - audioRequest.ResponseFormat = formData.Get("response_format") - if audioRequest.ResponseFormat == "" { - audioRequest.ResponseFormat = "json" - } - } - return audioRequest, nil -} - -func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c) - audioRequest, err := getAndValidAudioRequest(c, relayInfo) + err = helper.ModelMappedHelper(c, info, request) if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - promptTokens := 0 - preConsumedTokens := common.PreConsumedQuota - if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { - promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model) - preConsumedTokens = promptTokens - relayInfo.PromptTokens = promptTokens - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) - } - - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if openaiErr != nil { - return openaiErr - } - defer func() { - if openaiErr != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - err = helper.ModelMappedHelper(c, relayInfo, audioRequest) - if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) - } - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) - ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest) + ioReader, err := adaptor.ConvertAudioRequest(c, info, *request) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - resp, err := adaptor.DoRequest(c, relayInfo, ioReader) + resp, err := adaptor.DoRequest(c, info, ioReader) if err != nil { return types.NewError(err, types.ErrorCodeDoRequestFailed) } @@ -121,14 +60,14 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index ab8836baa..02de99567 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -26,19 +26,20 @@ type Adaptor interface { GetModelList() []string GetChannelName() string ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) + ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) } type TaskAdaptor interface { - Init(info *relaycommon.TaskRelayInfo) + Init(info *relaycommon.RelayInfo) - ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError + ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError - BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) - BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error - BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) + BuildRequestURL(info *relaycommon.RelayInfo) (string, error) + BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error + BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) - DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) - DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError) + DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) + DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, err *dto.TaskError) GetModelList() []string GetChannelName() string diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index d941a1bc7..3ce9e22d3 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -7,10 +7,12 @@ import ( "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/claude" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" "one-api/types" + "strings" "github.com/gin-gonic/gin" ) @@ -18,10 +20,13 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { //TODO implement me - panic("implement me") - return nil, nil + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + return req, nil } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { @@ -29,18 +34,26 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { var fullRequestURL string - switch info.RelayMode { - case constant.RelayModeEmbeddings: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl) - case constant.RelayModeRerank: - fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl) - case constant.RelayModeImagesGenerations: - fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) - case constant.RelayModeCompletions: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl) + switch info.RelayFormat { + case types.RelayFormatClaude: + fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.ChannelBaseUrl) default: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) + switch info.RelayMode { + case constant.RelayModeEmbeddings: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.ChannelBaseUrl) + case constant.RelayModeRerank: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl) + case constant.RelayModeImagesGenerations: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl) + case constant.RelayModeImagesEdits: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl) + case constant.RelayModeCompletions: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl) + default: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.ChannelBaseUrl) + } } + return fullRequestURL, nil } @@ -53,6 +66,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel if c.GetString("plugin") != "" { req.Set("X-DashScope-Plugin", c.GetString("plugin")) } + if info.RelayMode == constant.RelayModeImagesGenerations { + req.Set("X-DashScope-Async", "enable") + } + if info.RelayMode == constant.RelayModeImagesEdits { + req.Set("Content-Type", "application/json") + } return nil } @@ -60,7 +79,13 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - + // docs: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712216 + // fix: InternalError.Algo.InvalidParameter: The value of the enable_thinking parameter is restricted to True. + if strings.Contains(request.Model, "thinking") { + request.EnableThinking = true + request.Stream = true + info.IsStream = true + } // fix: ali parameter.enable_thinking must be set to false for non-streaming calls if !info.IsStream { request.EnableThinking = false @@ -74,8 +99,30 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - aliRequest := oaiImage2Ali(request) - return aliRequest, nil + if info.RelayMode == constant.RelayModeImagesGenerations { + aliRequest, err := oaiImage2Ali(request) + if err != nil { + return nil, fmt.Errorf("convert image request failed: %w", err) + } + return aliRequest, nil + } else if info.RelayMode == constant.RelayModeImagesEdits { + // ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416 + // 如果用户使用表单,则需要解析表单数据 + if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { + aliRequest, err := oaiFormEdit2AliImageEdit(c, info, request) + if err != nil { + return nil, fmt.Errorf("convert image edit form request failed: %w", err) + } + return aliRequest, nil + } else { + aliRequest, err := oaiImage2Ali(request) + if err != nil { + return nil, fmt.Errorf("convert image request failed: %w", err) + } + return aliRequest, nil + } + } + return nil, fmt.Errorf("unsupported image relay mode: %d", info.RelayMode) } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { @@ -101,21 +148,27 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - switch info.RelayMode { - case constant.RelayModeImagesGenerations: - err, usage = aliImageHandler(c, resp, info) - case constant.RelayModeEmbeddings: - err, usage = aliEmbeddingHandler(c, resp) - case constant.RelayModeRerank: - err, usage = RerankHandler(c, resp, info) - default: + switch info.RelayFormat { + case types.RelayFormatClaude: if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) + return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) } else { - usage, err = openai.OpenaiHandler(c, info, resp) + return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) } + default: + switch info.RelayMode { + case constant.RelayModeImagesGenerations: + err, usage = aliImageHandler(c, resp, info) + case constant.RelayModeImagesEdits: + err, usage = aliImageEditHandler(c, resp, info) + case constant.RelayModeRerank: + err, usage = RerankHandler(c, resp, info) + default: + adaptor := openai.Adaptor{} + usage, err = adaptor.DoResponse(c, resp, info) + } + return usage, err } - return } func (a *Adaptor) GetModelList() []string { diff --git a/relay/channel/ali/dto.go b/relay/channel/ali/dto.go index dbd189687..0873c99f0 100644 --- a/relay/channel/ali/dto.go +++ b/relay/channel/ali/dto.go @@ -3,10 +3,15 @@ package ali import "one-api/dto" type AliMessage struct { - Content string `json:"content"` + Content any `json:"content"` Role string `json:"role"` } +type AliMediaContent struct { + Image string `json:"image,omitempty"` + Text string `json:"text,omitempty"` +} + type AliInput struct { Prompt string `json:"prompt,omitempty"` //History []AliMessage `json:"history,omitempty"` @@ -70,13 +75,14 @@ type TaskResult struct { } type AliOutput struct { - TaskId string `json:"task_id,omitempty"` - TaskStatus string `json:"task_status,omitempty"` - Text string `json:"text"` - FinishReason string `json:"finish_reason"` - Message string `json:"message,omitempty"` - Code string `json:"code,omitempty"` - Results []TaskResult `json:"results,omitempty"` + TaskId string `json:"task_id,omitempty"` + TaskStatus string `json:"task_status,omitempty"` + Text string `json:"text"` + FinishReason string `json:"finish_reason"` + Message string `json:"message,omitempty"` + Code string `json:"code,omitempty"` + Results []TaskResult `json:"results,omitempty"` + Choices []map[string]any `json:"choices,omitempty"` } type AliResponse struct { @@ -86,20 +92,26 @@ type AliResponse struct { } type AliImageRequest struct { - Model string `json:"model"` - Input struct { - Prompt string `json:"prompt"` - NegativePrompt string `json:"negative_prompt,omitempty"` - } `json:"input"` - Parameters struct { - Size string `json:"size,omitempty"` - N int `json:"n,omitempty"` - Steps string `json:"steps,omitempty"` - Scale string `json:"scale,omitempty"` - } `json:"parameters,omitempty"` + Model string `json:"model"` + Input any `json:"input"` + Parameters any `json:"parameters,omitempty"` ResponseFormat string `json:"response_format,omitempty"` } +type AliImageParameters struct { + Size string `json:"size,omitempty"` + N int `json:"n,omitempty"` + Steps string `json:"steps,omitempty"` + Scale string `json:"scale,omitempty"` + Watermark *bool `json:"watermark,omitempty"` +} + +type AliImageInput struct { + Prompt string `json:"prompt,omitempty"` + NegativePrompt string `json:"negative_prompt,omitempty"` + Messages []AliMessage `json:"messages,omitempty"` +} + type AliRerankParameters struct { TopN *int `json:"top_n,omitempty"` ReturnDocuments *bool `json:"return_documents,omitempty"` diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index 0d430c629..490c9d0ad 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -1,13 +1,16 @@ package ali import ( - "encoding/json" + "context" + "encoding/base64" "errors" "fmt" "io" + "mime/multipart" "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/service" "one-api/types" @@ -17,19 +20,139 @@ import ( "github.com/gin-gonic/gin" ) -func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest { +func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) { + var imageRequest AliImageRequest + imageRequest.Model = request.Model + imageRequest.ResponseFormat = request.ResponseFormat + logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra) + if request.Extra != nil { + if val, ok := request.Extra["parameters"]; ok { + err := common.Unmarshal(val, &imageRequest.Parameters) + if err != nil { + return nil, fmt.Errorf("invalid parameters field: %w", err) + } + } + if val, ok := request.Extra["input"]; ok { + err := common.Unmarshal(val, &imageRequest.Input) + if err != nil { + return nil, fmt.Errorf("invalid input field: %w", err) + } + } + } + + if imageRequest.Parameters == nil { + imageRequest.Parameters = AliImageParameters{ + Size: strings.Replace(request.Size, "x", "*", -1), + N: int(request.N), + Watermark: request.Watermark, + } + } + + if imageRequest.Input == nil { + imageRequest.Input = AliImageInput{ + Prompt: request.Prompt, + } + } + + return &imageRequest, nil +} + +func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) { var imageRequest AliImageRequest - imageRequest.Input.Prompt = request.Prompt imageRequest.Model = request.Model - imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1) - imageRequest.Parameters.N = request.N imageRequest.ResponseFormat = request.ResponseFormat - return &imageRequest + mf := c.Request.MultipartForm + if mf == nil { + if _, err := c.MultipartForm(); err != nil { + return nil, fmt.Errorf("failed to parse image edit form request: %w", err) + } + mf = c.Request.MultipartForm + } + + var imageFiles []*multipart.FileHeader + var exists bool + + // First check for standard "image" field + if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 { + // If not found, check for "image[]" field + if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 { + // If still not found, iterate through all fields to find any that start with "image[" + foundArrayImages := false + for fieldName, files := range mf.File { + if strings.HasPrefix(fieldName, "image[") && len(files) > 0 { + foundArrayImages = true + imageFiles = append(imageFiles, files...) + } + } + + // If no image fields found at all + if !foundArrayImages && (len(imageFiles) == 0) { + return nil, errors.New("image is required") + } + } + } + + if len(imageFiles) == 0 { + return nil, errors.New("image is required") + } + + if len(imageFiles) > 1 { + return nil, errors.New("only one image is supported for qwen edit") + } + + // 获取base64编码的图片 + var imageBase64s []string + for _, file := range imageFiles { + image, err := file.Open() + if err != nil { + return nil, errors.New("failed to open image file") + } + + // 读取文件内容 + imageData, err := io.ReadAll(image) + if err != nil { + return nil, errors.New("failed to read image file") + } + + // 获取MIME类型 + mimeType := http.DetectContentType(imageData) + + // 编码为base64 + base64Data := base64.StdEncoding.EncodeToString(imageData) + + // 构造data URL格式 + dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data) + imageBase64s = append(imageBase64s, dataURL) + image.Close() + } + + //dto.MediaContent{} + mediaContents := make([]AliMediaContent, len(imageBase64s)) + for i, b64 := range imageBase64s { + mediaContents[i] = AliMediaContent{ + Image: b64, + } + } + mediaContents = append(mediaContents, AliMediaContent{ + Text: request.Prompt, + }) + imageRequest.Input = AliImageInput{ + Messages: []AliMessage{ + { + Role: "user", + Content: mediaContents, + }, + }, + } + imageRequest.Parameters = AliImageParameters{ + Watermark: request.Watermark, + } + return &imageRequest, nil } func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) { - url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID) + url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID) var aliResponse AliResponse @@ -43,7 +166,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error client := &http.Client{} resp, err := client.Do(req) if err != nil { - common.SysError("updateTask client.Do err: " + err.Error()) + common.SysLog("updateTask client.Do err: " + err.Error()) return &aliResponse, err, nil } defer resp.Body.Close() @@ -51,17 +174,17 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error responseBody, err := io.ReadAll(resp.Body) var response AliResponse - err = json.Unmarshal(responseBody, &response) + err = common.Unmarshal(responseBody, &response) if err != nil { - common.SysError("updateTask NewDecoder err: " + err.Error()) + common.SysLog("updateTask NewDecoder err: " + err.Error()) return &aliResponse, err, nil } return &response, nil, responseBody } -func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) { - waitSeconds := 3 +func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) { + waitSeconds := 10 step := 0 maxStep := 20 @@ -69,11 +192,14 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, [] var responseBody []byte for { + logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds)) step++ rsp, err, body := updateTask(info, taskID) responseBody = body if err != nil { - return &taskResponse, responseBody, err + logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error()) + time.Sleep(time.Duration(waitSeconds) * time.Second) + continue } if rsp.Output.TaskStatus == "" { @@ -99,7 +225,7 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, [] return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout") } -func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse { +func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody []byte, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse { imageResponse := dto.ImageResponse{ Created: info.StartTime.Unix(), } @@ -109,7 +235,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc if responseFormat == "b64_json" { _, b64, err := service.GetImageFromUrl(data.Url) if err != nil { - common.LogError(c, "get_image_data_failed: "+err.Error()) + logger.LogError(c, "get_image_data_failed: "+err.Error()) continue } b64Json = b64 @@ -123,6 +249,9 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc RevisedPrompt: "", }) } + var mapResponse map[string]any + _ = common.Unmarshal(originBody, &mapResponse) + imageResponse.Extra = mapResponse return &imageResponse } @@ -132,20 +261,20 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela var aliTaskResponse AliResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) - err = json.Unmarshal(responseBody, &aliTaskResponse) + service.CloseResponseBodyGracefully(resp) + err = common.Unmarshal(responseBody, &aliTaskResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliTaskResponse.Message != "" { - common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message) + logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message) return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil } - aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId) + aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId) if err != nil { return types.NewError(err, types.ErrorCodeBadResponse), nil } @@ -159,13 +288,52 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela }, resp.StatusCode), nil } - fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat) - jsonResponse, err := json.Marshal(fullTextResponse) + fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat) + jsonResponse, err := common.Marshal(fullTextResponse) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - c.Writer.Write(jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) + return nil, &dto.Usage{} +} + +func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { + var aliResponse AliResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil + } + + service.CloseResponseBodyGracefully(resp) + err = common.Unmarshal(responseBody, &aliResponse) + if err != nil { + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil + } + + if aliResponse.Message != "" { + logger.LogError(c, "ali_task_failed: "+aliResponse.Message) + return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil + } + var fullTextResponse dto.ImageResponse + if len(aliResponse.Output.Choices) > 0 { + fullTextResponse = dto.ImageResponse{ + Created: info.StartTime.Unix(), + Data: []dto.ImageData{ + { + Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string), + B64Json: "", + }, + }, + } + } + + var mapResponse map[string]any + _ = common.Unmarshal(responseBody, &mapResponse) + fullTextResponse.Extra = mapResponse + jsonResponse, err := common.Marshal(fullTextResponse) + if err != nil { + return types.NewError(err, types.ErrorCodeBadResponseBody), nil + } + service.IOCopyBytesGracefully(c, resp, jsonResponse) return nil, &dto.Usage{} } diff --git a/relay/channel/ali/rerank.go b/relay/channel/ali/rerank.go index 59cb0a11a..e7d6b5141 100644 --- a/relay/channel/ali/rerank.go +++ b/relay/channel/ali/rerank.go @@ -4,9 +4,9 @@ import ( "encoding/json" "io" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -34,14 +34,14 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest { func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var aliResponse AliRerankResponse err = json.Unmarshal(responseBody, &aliResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliResponse.Code != "" { diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go index 6d90fa713..67b63286c 100644 --- a/relay/channel/ali/text.go +++ b/relay/channel/ali/text.go @@ -8,6 +8,7 @@ import ( "one-api/common" "one-api/dto" "one-api/relay/helper" + "one-api/service" "strings" "one-api/types" @@ -43,10 +44,10 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro var fullTextResponse dto.FlexibleEmbeddingResponse err := json.NewDecoder(resp.Body).Decode(&fullTextResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) model := c.GetString("model") if model == "" { @@ -148,7 +149,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, var aliResponse AliResponse err := json.Unmarshal([]byte(data), &aliResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } if aliResponse.Usage.OutputTokens != 0 { @@ -161,7 +162,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -171,7 +172,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, return false } }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, &usage } @@ -179,12 +180,12 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U var aliResponse AliResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliResponse.Code != "" { return types.WithOpenAIError(types.OpenAIError{ diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index ff7c63fab..a50d5bdb5 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -7,11 +7,13 @@ import ( "io" "net/http" common2 "one-api/common" + "one-api/logger" "one-api/relay/common" "one-api/relay/constant" "one-api/relay/helper" "one-api/service" "one-api/setting/operation_setting" + "one-api/types" "sync" "time" @@ -46,7 +48,19 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody if err != nil { return nil, fmt.Errorf("new request failed: %w", err) } - err = a.SetupRequestHeader(c, &req.Header, info) + headers := req.Header + headerOverride := make(map[string]string) + for k, v := range info.HeadersOverride { + if str, ok := v.(string); ok { + headerOverride[k] = str + } else { + return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid) + } + } + for key, value := range headerOverride { + headers.Set(key, value) + } + err = a.SetupRequestHeader(c, &headers, info) if err != nil { return nil, fmt.Errorf("setup request header failed: %w", err) } @@ -71,8 +85,19 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod } // set form data req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - - err = a.SetupRequestHeader(c, &req.Header, info) + headers := req.Header + headerOverride := make(map[string]string) + for k, v := range info.HeadersOverride { + if str, ok := v.(string); ok { + headerOverride[k] = str + } else { + return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid) + } + } + for key, value := range headerOverride { + headers.Set(key, value) + } + err = a.SetupRequestHeader(c, &headers, info) if err != nil { return nil, fmt.Errorf("setup request header failed: %w", err) } @@ -181,7 +206,7 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error { err := helper.PingData(c) if err != nil { - common2.LogError(c, "SSE ping error: "+err.Error()) + logger.LogError(c, "SSE ping error: "+err.Error()) done <- err return } @@ -223,7 +248,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http helper.SetEventStreamHeaders(c) // 处理流式请求的 ping 保活 generalSettings := operation_setting.GetGeneralSetting() - if generalSettings.PingIntervalEnabled { + if generalSettings.PingIntervalEnabled && !info.DisablePing { pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second stopPinger = startPingKeepAlive(c, pingInterval) // 使用defer确保在任何情况下都能停止ping goroutine @@ -252,7 +277,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http return resp, nil } -func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { fullRequestURL, err := a.BuildRequestURL(info) if err != nil { return nil, err @@ -269,7 +294,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, if err != nil { return nil, fmt.Errorf("setup request header failed: %w", err) } - resp, err := doRequest(c, req, info.RelayInfo) + resp, err := doRequest(c, req, info) if err != nil { return nil, fmt.Errorf("do request failed: %w", err) } diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index d3354f00d..1526a7f75 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -22,6 +22,11 @@ type Adaptor struct { RequestMode int } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { c.Set("request_model", request.Model) c.Set("converted_request", request) @@ -58,7 +63,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn var claudeReq *dto.ClaudeRequest var err error - claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request) + claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request) if err != nil { return nil, err } diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go index 64c7b747c..3f8800b1e 100644 --- a/relay/channel/aws/constants.go +++ b/relay/channel/aws/constants.go @@ -13,6 +13,7 @@ var awsModelIDMap = map[string]string{ "claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0", "claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0", "claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0", + "claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0", } var awsModelCanCrossRegionMap = map[string]map[string]bool{ @@ -54,6 +55,9 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{ "anthropic.claude-opus-4-20250514-v1:0": { "us": true, }, + "anthropic.claude-opus-4-1-20250805-v1:0": { + "us": true, + }, } var awsRegionCrossModelPrefixMap = map[string]string{ diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 0df19e07f..5822e363a 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -1,7 +1,6 @@ package aws import ( - "encoding/json" "fmt" "net/http" "one-api/common" @@ -19,20 +18,31 @@ import ( "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + "github.com/aws/smithy-go/auth/bearer" ) func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) { awsSecret := strings.Split(info.ApiKey, "|") - if len(awsSecret) != 3 { + var client *bedrockruntime.Client + switch len(awsSecret) { + case 2: + apiKey := awsSecret[0] + region := awsSecret[1] + client = bedrockruntime.New(bedrockruntime.Options{ + Region: region, + BearerAuthTokenProvider: bearer.StaticTokenProvider{Token: bearer.Token{Value: apiKey}}, + }) + case 3: + ak := awsSecret[0] + sk := awsSecret[1] + region := awsSecret[2] + client = bedrockruntime.New(bedrockruntime.Options{ + Region: region, + Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")), + }) + default: return nil, errors.New("invalid aws secret key") } - ak := awsSecret[0] - sk := awsSecret[1] - region := awsSecret[2] - client := bedrockruntime.New(bedrockruntime.Options{ - Region: region, - Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")), - }) return client, nil } @@ -102,14 +112,14 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* } claudeReq := claudeReq_.(*dto.ClaudeRequest) awsClaudeReq := copyRequest(claudeReq) - awsReq.Body, err = json.Marshal(awsClaudeReq) + awsReq.Body, err = common.Marshal(awsClaudeReq) if err != nil { return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil } awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) if err != nil { - return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil + return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil } claudeInfo := &claude.ClaudeResponseInfo{ @@ -154,14 +164,14 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel claudeReq := claudeReq_.(*dto.ClaudeRequest) awsClaudeReq := copyRequest(claudeReq) - awsReq.Body, err = json.Marshal(awsClaudeReq) + awsReq.Body, err = common.Marshal(awsClaudeReq) if err != nil { return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil } awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) if err != nil { - return types.NewError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeChannelAwsClientError), nil + return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil } stream := awsResp.GetStream() defer stream.Close() diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 22443354b..32e301eed 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -18,6 +18,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") @@ -96,7 +101,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { default: suffix += strings.ToLower(info.UpstreamModelName) } - fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix) + fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.ChannelBaseUrl, suffix) var accessToken string var err error if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil { diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index 06b48c205..31e8319e5 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -34,9 +34,9 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest { EnableCitation: false, UserId: request.User, } - if request.MaxTokens != 0 { - maxTokens := int(request.MaxTokens) - if request.MaxTokens == 1 { + if request.GetMaxTokens() != 0 { + maxTokens := int(request.GetMaxTokens()) + if request.GetMaxTokens() == 1 { maxTokens = 2 } baiduRequest.MaxOutputTokens = &maxTokens @@ -118,7 +118,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. var baiduResponse BaiduChatStreamResponse err := common.Unmarshal([]byte(data), &baiduResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } if baiduResponse.Usage.TotalTokens != 0 { @@ -129,11 +129,11 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. response := streamResponseBaidu2OpenAI(&baiduResponse) err = helper.ObjectData(c, response) if err != nil { - common.SysError("error sending stream response: " + err.Error()) + common.SysLog("error sending stream response: " + err.Error()) } return true }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, usage } @@ -143,7 +143,7 @@ func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil @@ -168,7 +168,7 @@ func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index 375fd5318..0577ebcb7 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/relay/constant" "one-api/types" "strings" @@ -18,10 +19,14 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { //TODO implement me - panic("implement me") - return nil, nil + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + adaptor := openai.Adaptor{} + return adaptor.ConvertClaudeRequest(c, info, req) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -38,20 +43,33 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil + switch info.RelayMode { + case constant.RelayModeChatCompletions: + return fmt.Sprintf("%s/v2/chat/completions", info.ChannelBaseUrl), nil + case constant.RelayModeEmbeddings: + return fmt.Sprintf("%s/v2/embeddings", info.ChannelBaseUrl), nil + case constant.RelayModeImagesGenerations: + return fmt.Sprintf("%s/v2/images/generations", info.ChannelBaseUrl), nil + case constant.RelayModeImagesEdits: + return fmt.Sprintf("%s/v2/images/edits", info.ChannelBaseUrl), nil + case constant.RelayModeRerank: + return fmt.Sprintf("%s/v2/rerank", info.ChannelBaseUrl), nil + default: + } + return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - keyParts := strings.Split(info.ApiKey, "|") + keyParts := strings.Split(info.ApiKey, "|") if len(keyParts) == 0 || keyParts[0] == "" { - return errors.New("invalid API key: authorization token is required") - } - if len(keyParts) > 1 { - if keyParts[1] != "" { - req.Set("appid", keyParts[1]) - } - } + return errors.New("invalid API key: authorization token is required") + } + if len(keyParts) > 1 { + if keyParts[1] != "" { + req.Set("appid", keyParts[1]) + } + } req.Set("Authorization", "Bearer "+keyParts[0]) return nil } @@ -63,20 +81,23 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if strings.HasSuffix(info.UpstreamModelName, "-search") { info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search") request.Model = info.UpstreamModelName - toMap := request.ToMap() - toMap["web_search"] = map[string]any{ - "enable": true, - "enable_citation": true, - "enable_trace": true, - "enable_status": false, + if len(request.WebSearch) == 0 { + toMap := request.ToMap() + toMap["web_search"] = map[string]any{ + "enable": true, + "enable_citation": true, + "enable_trace": true, + "enable_status": false, + } + return toMap, nil } - return toMap, nil + return request, nil } return request, nil } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { - return nil, nil + return nil, errors.New("not implemented") } func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { @@ -94,11 +115,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) - } else { - usage, err = openai.OpenaiHandler(c, info, resp) - } + adaptor := openai.Adaptor{} + usage, err = adaptor.DoResponse(c, resp, info) return } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 540742d64..959327e16 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -24,6 +24,11 @@ type Adaptor struct { RequestMode int } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { return request, nil } @@ -48,9 +53,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if a.RequestMode == RequestModeMessage { - return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl), nil } else { - return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl), nil } } @@ -73,7 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if a.RequestMode == RequestModeCompletion { return RequestOpenAI2ClaudeComplete(*request), nil } else { - return RequestOpenAI2ClaudeMessage(*request) + return RequestOpenAI2ClaudeMessage(c, *request) } } @@ -97,9 +102,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.IsStream { - err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode) + return ClaudeStreamHandler(c, resp, info, a.RequestMode) } else { - err, usage = ClaudeHandler(c, resp, a.RequestMode, info) + return ClaudeHandler(c, resp, info, a.RequestMode) } return } diff --git a/relay/channel/claude/constants.go b/relay/channel/claude/constants.go index e0e3c4215..a23543d21 100644 --- a/relay/channel/claude/constants.go +++ b/relay/channel/claude/constants.go @@ -17,6 +17,8 @@ var ModelList = []string{ "claude-sonnet-4-20250514-thinking", "claude-opus-4-20250514", "claude-opus-4-20250514-thinking", + "claude-opus-4-1-20250805", + "claude-opus-4-1-20250805-thinking", } var ChannelName = "claude" diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index f20b573d4..511db2c6b 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/relay/channel/openrouter" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -31,7 +32,7 @@ func stopReasonClaude2OpenAI(reason string) string { case "end_turn": return "stop" case "max_tokens": - return "max_tokens" + return "length" case "tool_use": return "tool_calls" default: @@ -70,7 +71,7 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla return &claudeRequest } -func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) { +func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) { claudeTools := make([]any, 0, len(textRequest.Tools)) for _, tool := range textRequest.Tools { @@ -149,7 +150,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla claudeRequest := dto.ClaudeRequest{ Model: textRequest.Model, - MaxTokens: textRequest.MaxTokens, + MaxTokens: textRequest.GetMaxTokens(), StopSequences: nil, Temperature: textRequest.Temperature, TopP: textRequest.TopP, @@ -273,19 +274,28 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla claudeMessages := make([]dto.ClaudeMessage, 0) isFirstMessage := true + // 初始化system消息数组,用于累积多个system消息 + var systemMessages []dto.ClaudeMediaMessage + for _, message := range formatMessages { if message.Role == "system" { + // 根据Claude API规范,system字段使用数组格式更有通用性 if message.IsStringContent() { - claudeRequest.System = message.StringContent() + systemMessages = append(systemMessages, dto.ClaudeMediaMessage{ + Type: "text", + Text: common.GetPointer[string](message.StringContent()), + }) } else { - contents := message.ParseContent() - content := "" - for _, ctx := range contents { + // 支持复合内容的system消息(虽然不常见,但需要考虑完整性) + for _, ctx := range message.ParseContent() { if ctx.Type == "text" { - content += ctx.Text + systemMessages = append(systemMessages, dto.ClaudeMediaMessage{ + Type: "text", + Text: common.GetPointer[string](ctx.Text), + }) } + // 未来可以在这里扩展对图片等其他类型的支持 } - claudeRequest.System = content } } else { if isFirstMessage { @@ -354,7 +364,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla // 判断是否是url if strings.HasPrefix(imageUrl.Url, "http") { // 是url,获取图片的类型和base64编码的数据 - fileData, err := service.GetFileBase64FromUrl(imageUrl.Url) + fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Claude") if err != nil { return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error()) } @@ -375,7 +385,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla for _, toolCall := range message.ParseToolCalls() { inputObj := make(map[string]any) if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil { - common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) + common.SysLog("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) continue } claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{ @@ -391,6 +401,12 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla claudeMessages = append(claudeMessages, claudeMessage) } } + + // 设置累积的system消息 + if len(systemMessages) > 0 { + claudeRequest.System = systemMessages + } + claudeRequest.Prompt = "" claudeRequest.Messages = claudeMessages return &claudeRequest, nil @@ -425,7 +441,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse choice.Delta.Role = "assistant" } else if claudeResponse.Type == "content_block_start" { if claudeResponse.ContentBlock != nil { - //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text) + // 如果是文本块,尽可能发送首段文本(若存在) + if claudeResponse.ContentBlock.Type == "text" && claudeResponse.ContentBlock.Text != nil { + choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text) + } if claudeResponse.ContentBlock.Type == "tool_use" { tools = append(tools, dto.ToolCallResponse{ Index: common.GetPointer(fcIdx), @@ -609,13 +628,13 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud var claudeResponse dto.ClaudeResponse err := common.UnmarshalJsonStr(data, &claudeResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return types.NewError(err, types.ErrorCodeBadResponseBody) } - if claudeResponse.Error != nil && claudeResponse.Error.Type != "" { - return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError) + if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" { + return types.WithClaudeError(*claudeError, http.StatusInternalServerError) } - if info.RelayFormat == relaycommon.RelayFormatClaude { + if info.RelayFormat == types.RelayFormatClaude { FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo) if requestMode == RequestModeCompletion { @@ -628,7 +647,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } } helper.ClaudeChunkData(c, claudeResponse, data) - } else if info.RelayFormat == relaycommon.RelayFormatOpenAI { + } else if info.RelayFormat == types.RelayFormatOpenAI { response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) { @@ -637,7 +656,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud err = helper.ObjectData(c, response) if err != nil { - common.LogError(c, "send_stream_response_failed: "+err.Error()) + logger.LogError(c, "send_stream_response_failed: "+err.Error()) } } return nil @@ -653,28 +672,27 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau } if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done { if common.DebugEnabled { - common.SysError("claude response usage is not complete, maybe upstream error") + common.SysLog("claude response usage is not complete, maybe upstream error") } claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) } } - if info.RelayFormat == relaycommon.RelayFormatClaude { + if info.RelayFormat == types.RelayFormatClaude { // - } else if info.RelayFormat == relaycommon.RelayFormatOpenAI { - + } else if info.RelayFormat == types.RelayFormatOpenAI { if info.ShouldIncludeUsage { response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) err := helper.ObjectData(c, response) if err != nil { - common.SysError("send final response failed: " + err.Error()) + common.SysLog("send final response failed: " + err.Error()) } } helper.Done(c) } } -func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) { +func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) { claudeInfo := &ClaudeResponseInfo{ ResponseId: helper.GetResponseID(c), Created: common.GetTimestamp(), @@ -691,11 +709,11 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. return true }) if err != nil { - return err, nil + return nil, err } HandleStreamFinalResponse(c, info, claudeInfo, requestMode) - return nil, claudeInfo.Usage + return claudeInfo.Usage, nil } func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError { @@ -704,8 +722,8 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody) } - if claudeResponse.Error != nil && claudeResponse.Error.Type != "" { - return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError) + if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" { + return types.WithClaudeError(*claudeError, http.StatusInternalServerError) } if requestMode == RequestModeCompletion { completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) @@ -721,14 +739,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } var responseData []byte switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) openaiResponse.Usage = *claudeInfo.Usage responseData, err = json.Marshal(openaiResponse) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody) } - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: responseData = data } @@ -736,12 +754,12 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests) } - common.IOCopyBytesGracefully(c, nil, responseData) + service.IOCopyBytesGracefully(c, nil, responseData) return nil } -func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { - defer common.CloseResponseBodyGracefully(resp) +func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) { + defer service.CloseResponseBodyGracefully(resp) claudeInfo := &ClaudeResponseInfo{ ResponseId: helper.GetResponseID(c), @@ -752,16 +770,16 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r } responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } if common.DebugEnabled { println("responseBody: ", string(responseBody)) } handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode) if handleErr != nil { - return handleErr, nil + return nil, handleErr } - return nil, claudeInfo.Usage + return claudeInfo.Usage, nil } func mapToolChoice(toolChoice any, parallelToolCalls *bool) *dto.ClaudeToolChoice { diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 6e59ad715..bdea72f01 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" "one-api/types" @@ -18,6 +19,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") @@ -30,11 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch info.RelayMode { case constant.RelayModeChatCompletions: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.ChannelBaseUrl, info.ApiVersion), nil case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.ChannelBaseUrl, info.ApiVersion), nil + case constant.RelayModeResponses: + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.ChannelBaseUrl, info.ApiVersion), nil default: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.ChannelBaseUrl, info.ApiVersion, info.UpstreamModelName), nil } } @@ -57,8 +65,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn } func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { - // TODO implement me - return nil, errors.New("not implemented") + return request, nil } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { @@ -105,6 +112,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom } else { err, usage = cfHandler(c, info, resp) } + case constant.RelayModeResponses: + if info.IsStream { + usage, err = openai.OaiResponsesStreamHandler(c, info, resp) + } else { + usage, err = openai.OaiResponsesHandler(c, info, resp) + } case constant.RelayModeAudioTranslation: fallthrough case constant.RelayModeAudioTranscription: diff --git a/relay/channel/cloudflare/dto.go b/relay/channel/cloudflare/dto.go index 62a45c400..72b406155 100644 --- a/relay/channel/cloudflare/dto.go +++ b/relay/channel/cloudflare/dto.go @@ -5,7 +5,7 @@ import "one-api/dto" type CfRequest struct { Messages []dto.Message `json:"messages,omitempty"` Lora string `json:"lora,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` Prompt string `json:"prompt,omitempty"` Raw bool `json:"raw,omitempty"` Stream bool `json:"stream,omitempty"` diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index 5e8fe7f92..00f6b6c5e 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -5,8 +5,8 @@ import ( "encoding/json" "io" "net/http" - "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -51,7 +51,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res var response dto.ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &response) if err != nil { - common.LogError(c, "error_unmarshalling_stream_response: "+err.Error()) + logger.LogError(c, "error_unmarshalling_stream_response: "+err.Error()) continue } for _, choice := range response.Choices { @@ -66,24 +66,24 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res info.FirstResponseTime = time.Now() } if err != nil { - common.LogError(c, "error_rendering_stream_response: "+err.Error()) + logger.LogError(c, "error_rendering_stream_response: "+err.Error()) } } if err := scanner.Err(); err != nil { - common.LogError(c, "error_scanning_stream_response: "+err.Error()) + logger.LogError(c, "error_scanning_stream_response: "+err.Error()) } usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) if info.ShouldIncludeUsage { response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) err := helper.ObjectData(c, response) if err != nil { - common.LogError(c, "error_rendering_final_usage_response: "+err.Error()) + logger.LogError(c, "error_rendering_final_usage_response: "+err.Error()) } } helper.Done(c) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, usage } @@ -93,7 +93,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var response dto.TextResponse err = json.Unmarshal(responseBody, &response) if err != nil { @@ -123,7 +123,7 @@ func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &cfResp) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index 4f3a96c32..c8a38d465 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -17,6 +17,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") @@ -38,9 +43,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else { - return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go index 410540c0f..d51279633 100644 --- a/relay/channel/cohere/dto.go +++ b/relay/channel/cohere/dto.go @@ -7,7 +7,7 @@ type CohereRequest struct { ChatHistory []ChatHistory `json:"chat_history"` Message string `json:"message"` Stream bool `json:"stream"` - MaxTokens int `json:"max_tokens"` + MaxTokens uint `json:"max_tokens"` SafetyMode string `json:"safety_mode,omitempty"` } diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index fcfb12b75..af3573480 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -118,7 +118,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http var cohereResp CohereResponse err := json.Unmarshal([]byte(data), &cohereResp) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } var openaiResp dto.ChatCompletionsStreamResponse @@ -153,7 +153,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http } jsonStr, err := json.Marshal(openaiResp) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) @@ -175,7 +175,7 @@ func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var cohereResp CohereResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { @@ -216,7 +216,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon. if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var cohereResp CohereRerankResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go index fe5f5f002..0f2a6fd3f 100644 --- a/relay/channel/coze/adaptor.go +++ b/relay/channel/coze/adaptor.go @@ -18,6 +18,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *common.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + // ConvertAudioRequest implements channel.Adaptor. func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) { return nil, errors.New("not implemented") @@ -117,7 +122,7 @@ func (a *Adaptor) GetModelList() []string { // GetRequestURL implements channel.Adaptor. func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil + return fmt.Sprintf("%s/v3/chat", info.ChannelBaseUrl), nil } // Init implements channel.Adaptor. diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 32cc69376..c480045f4 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -49,7 +49,7 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) // convert coze response to openai response var response dto.TextResponse var cozeResponse CozeChatDetailResponse @@ -154,7 +154,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var chatData CozeChatResponseData err := json.Unmarshal([]byte(data), &chatData) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } @@ -171,14 +171,14 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var messageData CozeChatV3MessageDetail err := json.Unmarshal([]byte(data), &messageData) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } var content string err = json.Unmarshal(messageData.Content, &content) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } @@ -203,16 +203,16 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var errorData CozeError err := json.Unmarshal([]byte(data), &errorData) if err != nil { - common.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } - common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) + common.SysLog(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) } } func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) { - requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl) + requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.ChannelBaseUrl) requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") // 将 conversationId和chatId作为参数发送get请求 @@ -258,7 +258,7 @@ func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo } func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) { - requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl) + requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.ChannelBaseUrl) requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") req, err := http.NewRequest("GET", requestURL, nil) diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index edfc7fd3b..17d732ab0 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -19,10 +19,14 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { //TODO implement me - panic("implement me") - return nil, nil + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + adaptor := openai.Adaptor{} + return adaptor.ConvertClaudeRequest(c, info, req) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -39,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - fimBaseUrl := info.BaseUrl - if !strings.HasSuffix(info.BaseUrl, "/beta") { + fimBaseUrl := info.ChannelBaseUrl + if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") { fimBaseUrl += "/beta" } switch info.RelayMode { case constant.RelayModeCompletions: return fmt.Sprintf("%s/completions", fimBaseUrl), nil default: - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 4ad167663..0a08d035a 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -24,6 +24,11 @@ type Adaptor struct { BotType int } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") @@ -56,13 +61,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch a.BotType { case BotTypeWorkFlow: - return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/workflows/run", info.ChannelBaseUrl), nil case BotTypeCompletion: - return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/completion-messages", info.ChannelBaseUrl), nil case BotTypeAgent: fallthrough default: - return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat-messages", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index 473371271..2336fd4c9 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -22,7 +22,7 @@ import ( ) func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile { - uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl) + uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.ChannelBaseUrl) switch media.Type { case dto.ContentTypeImageURL: // Decode base64 data @@ -36,14 +36,14 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Decode base64 string decodedData, err := base64.StdEncoding.DecodeString(base64Data) if err != nil { - common.SysError("failed to decode base64: " + err.Error()) + common.SysLog("failed to decode base64: " + err.Error()) return nil } // Create temporary file tempFile, err := os.CreateTemp("", "dify-upload-*") if err != nil { - common.SysError("failed to create temp file: " + err.Error()) + common.SysLog("failed to create temp file: " + err.Error()) return nil } defer tempFile.Close() @@ -51,7 +51,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Write decoded data to temp file if _, err := tempFile.Write(decodedData); err != nil { - common.SysError("failed to write to temp file: " + err.Error()) + common.SysLog("failed to write to temp file: " + err.Error()) return nil } @@ -61,7 +61,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Add user field if err := writer.WriteField("user", user); err != nil { - common.SysError("failed to add user field: " + err.Error()) + common.SysLog("failed to add user field: " + err.Error()) return nil } @@ -74,13 +74,13 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Create form file part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/"))) if err != nil { - common.SysError("failed to create form file: " + err.Error()) + common.SysLog("failed to create form file: " + err.Error()) return nil } // Copy file content to form if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil { - common.SysError("failed to copy file content: " + err.Error()) + common.SysLog("failed to copy file content: " + err.Error()) return nil } writer.Close() @@ -88,7 +88,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Create HTTP request req, err := http.NewRequest("POST", uploadUrl, body) if err != nil { - common.SysError("failed to create request: " + err.Error()) + common.SysLog("failed to create request: " + err.Error()) return nil } @@ -99,7 +99,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me client := service.GetHttpClient() resp, err := client.Do(req) if err != nil { - common.SysError("failed to send request: " + err.Error()) + common.SysLog("failed to send request: " + err.Error()) return nil } defer resp.Body.Close() @@ -109,7 +109,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me Id string `json:"id"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - common.SysError("failed to decode response: " + err.Error()) + common.SysLog("failed to decode response: " + err.Error()) return nil } @@ -219,7 +219,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R var difyResponse DifyChunkChatCompletionResponse err := json.Unmarshal([]byte(data), &difyResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } var openaiResponse dto.ChatCompletionsStreamResponse @@ -239,7 +239,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R } err = helper.ObjectData(c, openaiResponse) if err != nil { - common.SysError(err.Error()) + common.SysLog(err.Error()) } return true }) @@ -258,7 +258,7 @@ func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &difyResponse) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 71eb9ba43..4968f78fe 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -1,14 +1,13 @@ package gemini import ( - "encoding/json" "errors" "fmt" "io" "net/http" - "one-api/common" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" "one-api/setting/model_setting" @@ -21,10 +20,33 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil +func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) { + if len(request.Contents) > 0 { + for i, content := range request.Contents { + if i == 0 { + if request.Contents[0].Role == "" { + request.Contents[0].Role = "user" + } + } + for _, part := range content.Parts { + if part.FileData != nil { + if part.FileData.MimeType == "" && strings.Contains(part.FileData.FileUri, "www.youtube.com") { + part.FileData.MimeType = "video/webm" + } + } + } + } + } + return request, nil +} + +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + adaptor := openai.Adaptor{} + oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req) + if err != nil { + return nil, err + } + return a.ConvertOpenAIRequest(c, info, oaiReq.(*dto.GeneralOpenAIRequest)) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -37,26 +59,33 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf return nil, errors.New("not supported model for image generation") } - // convert size to aspect ratio + // convert size to aspect ratio but allow user to specify aspect ratio aspectRatio := "1:1" // default aspect ratio - switch request.Size { - case "1024x1024": - aspectRatio = "1:1" - case "1024x1792": - aspectRatio = "9:16" - case "1792x1024": - aspectRatio = "16:9" + size := strings.TrimSpace(request.Size) + if size != "" { + if strings.Contains(size, ":") { + aspectRatio = size + } else { + switch size { + case "1024x1024": + aspectRatio = "1:1" + case "1024x1792": + aspectRatio = "9:16" + case "1792x1024": + aspectRatio = "16:9" + } + } } // build gemini imagen request - geminiRequest := GeminiImageRequest{ - Instances: []GeminiImageInstance{ + geminiRequest := dto.GeminiImageRequest{ + Instances: []dto.GeminiImageInstance{ { Prompt: request.Prompt, }, }, - Parameters: GeminiImageParameters{ - SampleCount: request.N, + Parameters: dto.GeminiImageParameters{ + SampleCount: int(request.N), AspectRatio: aspectRatio, PersonGeneration: "allow_adult", // default allow adult }, @@ -86,20 +115,27 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName) if strings.HasPrefix(info.UpstreamModelName, "imagen") { - return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil + return fmt.Sprintf("%s/%s/models/%s:predict", info.ChannelBaseUrl, version, info.UpstreamModelName), nil } if strings.HasPrefix(info.UpstreamModelName, "text-embedding") || strings.HasPrefix(info.UpstreamModelName, "embedding") || strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") { - return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil + action := "embedContent" + if info.IsGeminiBatchEmbedding { + action = "batchEmbedContents" + } + return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil } action := "generateContent" if info.IsStream { action = "streamGenerateContent?alt=sse" + if info.RelayMode == constant.RelayModeGemini { + info.DisablePing = true + } } - return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil + return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -113,7 +149,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn return nil, errors.New("request is nil") } - geminiRequest, err := CovertGemini2OpenAI(*request, info) + geminiRequest, err := CovertGemini2OpenAI(c, *request, info) if err != nil { return nil, err } @@ -134,29 +170,38 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela if len(inputs) == 0 { return nil, errors.New("input is empty") } - - // only process the first input - geminiRequest := GeminiEmbeddingRequest{ - Content: GeminiChatContent{ - Parts: []GeminiPart{ - { - Text: inputs[0], + // We always build a batch-style payload with `requests`, so ensure we call the + // batch endpoint upstream to avoid payload/endpoint mismatches. + info.IsGeminiBatchEmbedding = true + // process all inputs + geminiRequests := make([]map[string]interface{}, 0, len(inputs)) + for _, input := range inputs { + geminiRequest := map[string]interface{}{ + "model": fmt.Sprintf("models/%s", info.UpstreamModelName), + "content": dto.GeminiChatContent{ + Parts: []dto.GeminiPart{ + { + Text: input, + }, }, }, - }, - } - - // set specific parameters for different models - // https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent - switch info.UpstreamModelName { - case "text-embedding-004": - // except embedding-001 supports setting `OutputDimensionality` - if request.Dimensions > 0 { - geminiRequest.OutputDimensionality = request.Dimensions } + + // set specific parameters for different models + // https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent + switch info.UpstreamModelName { + case "text-embedding-004", "gemini-embedding-exp-03-07", "gemini-embedding-001": + // Only newer models introduced after 2024 support OutputDimensionality + if request.Dimensions > 0 { + geminiRequest["outputDimensionality"] = request.Dimensions + } + } + geminiRequests = append(geminiRequests, geminiRequest) } - return geminiRequest, nil + return map[string]interface{}{ + "requests": geminiRequests, + }, nil } func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { @@ -170,6 +215,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeGemini { + if strings.HasSuffix(info.RequestURLPath, ":embedContent") || + strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") { + return NativeGeminiEmbeddingHandler(c, resp, info) + } if info.IsStream { return GeminiTextGenerationStreamHandler(c, info, resp) } else { @@ -194,72 +243,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom return GeminiChatHandler(c, info, resp) } - //if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 { - // // 没有请求-thinking的情况下,产生思考token,则按照思考模型计费 - // if !strings.HasSuffix(info.OriginModelName, "-thinking") && - // !strings.HasSuffix(info.OriginModelName, "-nothinking") { - // thinkingModelName := info.OriginModelName + "-thinking" - // if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) { - // info.OriginModelName = thinkingModelName - // } - // } - //} - - return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody) -} - -func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - responseBody, readErr := io.ReadAll(resp.Body) - if readErr != nil { - return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody) - } - _ = resp.Body.Close() - - var geminiResponse GeminiImageResponse - if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) - } - - if len(geminiResponse.Predictions) == 0 { - return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody) - } - - // convert to openai format response - openAIResponse := dto.ImageResponse{ - Created: common.GetTimestamp(), - Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)), - } - - for _, prediction := range geminiResponse.Predictions { - if prediction.RaiFilteredReason != "" { - continue // skip filtered image - } - openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{ - B64Json: prediction.BytesBase64Encoded, - }) - } - - jsonResponse, jsonErr := json.Marshal(openAIResponse) - if jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) - } - - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, _ = c.Writer.Write(jsonResponse) - - // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb - // each image has fixed 258 tokens - const imageTokens = 258 - generatedImages := len(openAIResponse.Data) - - usage := &dto.Usage{ - PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens - CompletionTokens: 0, // image generation does not calculate completion tokens - TotalTokens: imageTokens * generatedImages, - } - - return usage, nil } func (a *Adaptor) GetModelList() []string { diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 0870e3fab..564b86908 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -5,22 +5,25 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" "one-api/types" "strings" + "github.com/pkg/errors" + "github.com/gin-gonic/gin" ) func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) // 读取响应体 responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if common.DebugEnabled { @@ -28,10 +31,10 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re } // 解析为 Gemini 原生响应格式 - var geminiResponse GeminiChatResponse + var geminiResponse dto.GeminiChatResponse err = common.Unmarshal(responseBody, &geminiResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // 计算使用量(基于 UsageMetadata) @@ -43,6 +46,32 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount + if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") { + imageOutputCounts := 0 + for _, candidate := range geminiResponse.Candidates { + for _, part := range candidate.Content.Parts { + if part.InlineData != nil && strings.HasPrefix(part.InlineData.MimeType, "image/") { + imageOutputCounts++ + } + } + } + if imageOutputCounts != 0 { + usage.CompletionTokens = usage.CompletionTokens - imageOutputCounts*1290 + usage.TotalTokens = usage.TotalTokens - imageOutputCounts*1290 + c.Set("gemini_image_tokens", imageOutputCounts*1290) + } + } + + // if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") { + // for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails { + // if detail.Modality == "IMAGE" { + // usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount + // usage.TotalTokens = usage.TotalTokens - detail.TokenCount + // c.Set("gemini_image_tokens", detail.TokenCount) + // } + // } + // } + for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { if detail.Modality == "AUDIO" { usage.PromptTokensDetails.AudioTokens = detail.TokenCount @@ -51,17 +80,47 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re } } - // 直接返回 Gemini 原生格式的 JSON 响应 - jsonResponse, err := common.Marshal(geminiResponse) - if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) - } - - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, responseBody) return &usage, nil } +func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) { + defer service.CloseResponseBodyGracefully(resp) + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + if common.DebugEnabled { + println(string(responseBody)) + } + + usage := &dto.Usage{ + PromptTokens: info.PromptTokens, + TotalTokens: info.PromptTokens, + } + + if info.IsGeminiBatchEmbedding { + var geminiResponse dto.GeminiBatchEmbeddingResponse + err = common.Unmarshal(responseBody, &geminiResponse) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + } else { + var geminiResponse dto.GeminiEmbeddingResponse + err = common.Unmarshal(responseBody, &geminiResponse) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + } + + service.IOCopyBytesGracefully(c, resp, responseBody) + + return usage, nil +} + func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { var usage = &dto.Usage{} var imageCount int @@ -71,10 +130,10 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn responseText := strings.Builder{} helper.StreamScannerHandler(c, resp, info, func(data string) bool { - var geminiResponse GeminiChatResponse + var geminiResponse dto.GeminiChatResponse err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { - common.LogError(c, "error unmarshalling stream response: "+err.Error()) + logger.LogError(c, "error unmarshalling stream response: "+err.Error()) return false } @@ -103,17 +162,31 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn usage.PromptTokensDetails.TextTokens = detail.TokenCount } } + + if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") { + for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails { + if detail.Modality == "IMAGE" { + usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount + usage.TotalTokens = usage.TotalTokens - detail.TokenCount + c.Set("gemini_image_tokens", detail.TokenCount) + } + } + } } // 直接发送 GeminiChatResponse 响应 err = helper.StringData(c, data) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } - + info.SendResponseCount++ return true }) + if info.SendResponseCount == 0 { + return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError) + } + if imageCount != 0 { if usage.CompletionTokens == 0 { usage.CompletionTokens = imageCount * 258 diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 6f3babeb1..eb4afbae1 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -9,6 +9,8 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -48,12 +50,20 @@ const ( flash25LiteMaxBudget = 24576 ) -// clampThinkingBudget 根据模型名称将预算限制在允许的范围内 -func clampThinkingBudget(modelName string, budget int) int { - isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") && +func isNew25ProModel(modelName string) bool { + return strings.HasPrefix(modelName, "gemini-2.5-pro") && !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") && !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25") - is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite") +} + +func is25FlashLiteModel(modelName string) bool { + return strings.HasPrefix(modelName, "gemini-2.5-flash-lite") +} + +// clampThinkingBudget 根据模型名称将预算限制在允许的范围内 +func clampThinkingBudget(modelName string, budget int) int { + isNew25Pro := isNew25ProModel(modelName) + is25FlashLite := is25FlashLiteModel(modelName) if is25FlashLite { if budget < flash25LiteMinBudget { @@ -80,7 +90,34 @@ func clampThinkingBudget(modelName string, budget int) int { return budget } -func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) { +// "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens) +// "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens) +// "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens) +func clampThinkingBudgetByEffort(modelName string, effort string) int { + isNew25Pro := isNew25ProModel(modelName) + is25FlashLite := is25FlashLiteModel(modelName) + + maxBudget := 0 + if is25FlashLite { + maxBudget = flash25LiteMaxBudget + } + if isNew25Pro { + maxBudget = pro25MaxBudget + } else { + maxBudget = flash25MaxBudget + } + switch effort { + case "high": + maxBudget = maxBudget * 80 / 100 + case "medium": + maxBudget = maxBudget * 50 / 100 + case "low": + maxBudget = maxBudget * 20 / 100 + } + return clampThinkingBudget(modelName, maxBudget) +} + +func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { modelName := info.UpstreamModelName isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") && @@ -92,7 +129,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn if len(parts) == 2 && parts[1] != "" { if budgetTokens, err := strconv.Atoi(parts[1]); err == nil { clampedBudget := clampThinkingBudget(modelName, budgetTokens) - geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{ ThinkingBudget: common.GetPointer(clampedBudget), IncludeThoughts: true, } @@ -112,22 +149,27 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn } if isUnsupported { - geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{ IncludeThoughts: true, } } else { - geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{ IncludeThoughts: true, } if geminiRequest.GenerationConfig.MaxOutputTokens > 0 { budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens) clampedBudget := clampThinkingBudget(modelName, int(budgetTokens)) geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget) + } else { + if len(oaiRequest) > 0 { + // 如果有reasoningEffort参数,则根据其值设置思考预算 + geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampThinkingBudgetByEffort(modelName, oaiRequest[0].ReasoningEffort)) + } } } } else if strings.HasSuffix(modelName, "-nothinking") { if !isNew25Pro { - geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{ ThinkingBudget: common.GetPointer(0), } } @@ -136,14 +178,14 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn } // Setting safety to the lowest possible values since Gemini is already powerless enough -func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) { +func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) { - geminiRequest := GeminiChatRequest{ - Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), - GenerationConfig: GeminiChatGenerationConfig{ + geminiRequest := dto.GeminiChatRequest{ + Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)), + GenerationConfig: dto.GeminiChatGenerationConfig{ Temperature: textRequest.Temperature, TopP: textRequest.TopP, - MaxOutputTokens: textRequest.MaxTokens, + MaxOutputTokens: textRequest.GetMaxTokens(), Seed: int64(textRequest.Seed), }, } @@ -155,11 +197,41 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon } } - ThinkingAdaptor(&geminiRequest, info) + adaptorWithExtraBody := false - safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList)) + if len(textRequest.ExtraBody) > 0 { + if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") { + var extraBody map[string]interface{} + if err := common.Unmarshal(textRequest.ExtraBody, &extraBody); err != nil { + return nil, fmt.Errorf("invalid extra body: %w", err) + } + // eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}} + if googleBody, ok := extraBody["google"].(map[string]interface{}); ok { + adaptorWithExtraBody = true + if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok { + if budget, ok := thinkingConfig["thinking_budget"].(float64); ok { + budgetInt := int(budget) + geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{ + ThinkingBudget: common.GetPointer(budgetInt), + IncludeThoughts: true, + } + } else { + geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{ + IncludeThoughts: true, + } + } + } + } + } + } + + if !adaptorWithExtraBody { + ThinkingAdaptor(&geminiRequest, info, textRequest) + } + + safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList)) for _, category := range SafetySettingList { - safetySettings = append(safetySettings, GeminiChatSafetySettings{ + safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{ Category: category, Threshold: model_setting.GetGeminiSafetySetting(category), }) @@ -196,32 +268,35 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon tool.Function.Parameters = cleanedParams functions = append(functions, tool.Function) } + geminiTools := geminiRequest.GetTools() if codeExecution { - geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{ + geminiTools = append(geminiTools, dto.GeminiChatTool{ CodeExecution: make(map[string]string), }) } if googleSearch { - geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{ + geminiTools = append(geminiTools, dto.GeminiChatTool{ GoogleSearch: make(map[string]string), }) } if len(functions) > 0 { - geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{ + geminiTools = append(geminiTools, dto.GeminiChatTool{ FunctionDeclarations: functions, }) } - // common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools)) - // json_data, _ := json.Marshal(geminiRequest.Tools) - // common.SysLog("tools_json: " + string(json_data)) + geminiRequest.SetTools(geminiTools) } if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") { geminiRequest.GenerationConfig.ResponseMimeType = "application/json" - if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil { - cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0) - geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema + if len(textRequest.ResponseFormat.JsonSchema) > 0 { + // 先将json.RawMessage解析 + var jsonSchema dto.FormatJsonSchema + if err := common.Unmarshal(textRequest.ResponseFormat.JsonSchema, &jsonSchema); err == nil { + cleanedSchema := removeAdditionalPropertiesWithDepth(jsonSchema.Schema, 0) + geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema + } } } tool_call_ids := make(map[string]string) @@ -233,7 +308,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon continue } else if message.Role == "tool" || message.Role == "function" { if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" { - geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ + geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{ Role: "user", }) } @@ -260,18 +335,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon } } - functionResp := &FunctionResponse{ + functionResp := &dto.GeminiFunctionResponse{ Name: name, Response: contentMap, } - *parts = append(*parts, GeminiPart{ + *parts = append(*parts, dto.GeminiPart{ FunctionResponse: functionResp, }) continue } - var parts []GeminiPart - content := GeminiChatContent{ + var parts []dto.GeminiPart + content := dto.GeminiChatContent{ Role: message.Role, } // isToolCall := false @@ -285,8 +360,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments) } } - toolCall := GeminiPart{ - FunctionCall: &FunctionCall{ + toolCall := dto.GeminiPart{ + FunctionCall: &dto.FunctionCall{ FunctionName: call.Function.Name, Arguments: args, }, @@ -303,7 +378,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon if part.Text == "" { continue } - parts = append(parts, GeminiPart{ + parts = append(parts, dto.GeminiPart{ Text: part.Text, }) } else if part.Type == dto.ContentTypeImageURL { @@ -315,7 +390,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon // 判断是否是url if strings.HasPrefix(part.GetImageMedia().Url, "http") { // 是url,获取文件的类型和base64编码的数据 - fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url) + fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini") if err != nil { return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err) } @@ -326,8 +401,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList()) } - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ + parts = append(parts, dto.GeminiPart{ + InlineData: &dto.GeminiInlineData{ MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义 Data: fileData.Base64Data, }, @@ -337,8 +412,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon if err != nil { return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error()) } - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ + parts = append(parts, dto.GeminiPart{ + InlineData: &dto.GeminiInlineData{ MimeType: format, Data: base64String, }, @@ -352,8 +427,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon if err != nil { return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error()) } - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ + parts = append(parts, dto.GeminiPart{ + InlineData: &dto.GeminiInlineData{ MimeType: format, Data: base64String, }, @@ -366,8 +441,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon if err != nil { return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error()) } - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ + parts = append(parts, dto.GeminiPart{ + InlineData: &dto.GeminiInlineData{ MimeType: "audio/" + part.GetInputAudio().Format, Data: base64String, }, @@ -387,8 +462,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon } if len(system_content) > 0 { - geminiRequest.SystemInstructions = &GeminiChatContent{ - Parts: []GeminiPart{ + geminiRequest.SystemInstructions = &dto.GeminiChatContent{ + Parts: []dto.GeminiPart{ { Text: strings.Join(system_content, "\n"), }, @@ -631,7 +706,7 @@ func unescapeMapOrSlice(data interface{}) interface{} { return data } -func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse { +func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse { var argsBytes []byte var err error if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok { @@ -653,7 +728,7 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse { } } -func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse { +func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse { fullTextResponse := dto.OpenAITextResponse{ Id: helper.GetResponseID(c), Object: "chat.completion", @@ -674,7 +749,16 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dt var texts []string var toolCalls []dto.ToolCallResponse for _, part := range candidate.Content.Parts { - if part.FunctionCall != nil { + if part.InlineData != nil { + // 媒体内容 + if strings.HasPrefix(part.InlineData.MimeType, "image") { + imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")" + texts = append(texts, imgText) + } else { + // 其他媒体类型,直接显示链接 + texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data)) + } + } else if part.FunctionCall != nil { choice.FinishReason = constant.FinishReasonToolCalls if call := getResponseToolCall(&part); call != nil { toolCalls = append(toolCalls, *call) @@ -720,10 +804,9 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dt return &fullTextResponse } -func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) { +func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) { choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates)) isStop := false - hasImage := false for _, candidate := range geminiResponse.Candidates { if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" { isStop = true @@ -732,7 +815,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C choice := dto.ChatCompletionsStreamResponseChoice{ Index: int(candidate.Index), Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ - Role: "assistant", + //Role: "assistant", }, } var texts []string @@ -754,7 +837,6 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C if strings.HasPrefix(part.InlineData.MimeType, "image") { imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")" texts = append(texts, imgText) - hasImage = true } } else if part.FunctionCall != nil { isTools = true @@ -762,6 +844,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C call.SetIndex(len(choice.Delta.ToolCalls)) choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call) } + } else if part.Thought { isThought = true texts = append(texts, part.Text) @@ -791,28 +874,60 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" response.Choices = choices - return &response, isStop, hasImage + return &response, isStop +} + +func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error { + streamData, err := common.Marshal(resp) + if err != nil { + return fmt.Errorf("failed to marshal stream response: %w", err) + } + err = openai.HandleStreamFormat(c, info, string(streamData), info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) + if err != nil { + return fmt.Errorf("failed to handle stream format: %w", err) + } + return nil +} + +func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error { + streamData, err := common.Marshal(resp) + if err != nil { + return fmt.Errorf("failed to marshal stream response: %w", err) + } + openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, false) + return nil } func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { // responseText := "" id := helper.GetResponseID(c) createAt := common.GetTimestamp() + responseText := strings.Builder{} var usage = &dto.Usage{} var imageCount int + finishReason := constant.FinishReasonStop helper.StreamScannerHandler(c, resp, info, func(data string) bool { - var geminiResponse GeminiChatResponse + var geminiResponse dto.GeminiChatResponse err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { - common.LogError(c, "error unmarshalling stream response: "+err.Error()) + logger.LogError(c, "error unmarshalling stream response: "+err.Error()) return false } - response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse) - if hasImage { - imageCount++ + for _, candidate := range geminiResponse.Candidates { + for _, part := range candidate.Content.Parts { + if part.InlineData != nil && part.InlineData.MimeType != "" { + imageCount++ + } + if part.Text != "" { + responseText.WriteString(part.Text) + } + } } + + response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse) + response.Id = id response.Created = createAt response.Model = info.UpstreamModelName @@ -829,18 +944,47 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * } } } - err = helper.ObjectData(c, response) + logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount)) + if info.SendResponseCount == 0 { + // send first response + emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil) + if response.IsToolCall() { + emptyResponse.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse, 1) + emptyResponse.Choices[0].Delta.ToolCalls[0] = *response.GetFirstToolCall() + emptyResponse.Choices[0].Delta.ToolCalls[0].Function.Arguments = "" + finishReason = constant.FinishReasonToolCalls + err = handleStream(c, info, emptyResponse) + if err != nil { + logger.LogError(c, err.Error()) + } + + response.ClearToolCalls() + if response.IsFinished() { + response.Choices[0].FinishReason = nil + } + } else { + err = handleStream(c, info, emptyResponse) + if err != nil { + logger.LogError(c, err.Error()) + } + } + } + + err = handleStream(c, info, response) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } if isStop { - response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop) - helper.ObjectData(c, response) + _ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason)) } return true }) - var response *dto.ChatCompletionsStreamResponse + if info.SendResponseCount == 0 { + // 空补全,报错不计费 + // empty response, throw an error + return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError) + } if imageCount != 0 { if usage.CompletionTokens == 0 { @@ -851,14 +995,24 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * usage.PromptTokensDetails.TextTokens = usage.PromptTokens usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens - if info.ShouldIncludeUsage { - response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) - err := helper.ObjectData(c, response) - if err != nil { - common.SysError("send final response failed: " + err.Error()) + if usage.CompletionTokens == 0 { + str := responseText.String() + if len(str) > 0 { + usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens) + } else { + // 空补全,不需要使用量 + usage = &dto.Usage{} } } - helper.Done(c) + + response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) + err := handleFinalStream(c, info, response) + if err != nil { + common.SysLog("send final response failed: " + err.Error()) + } + //if info.RelayFormat == relaycommon.RelayFormatOpenAI { + // helper.Done(c) + //} //resp.Body.Close() return usage, nil } @@ -866,19 +1020,19 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) if common.DebugEnabled { println(string(responseBody)) } - var geminiResponse GeminiChatResponse + var geminiResponse dto.GeminiChatResponse err = common.Unmarshal(responseBody, &geminiResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if len(geminiResponse.Candidates) == 0 { - return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse) fullTextResponse.Model = info.UpstreamModelName @@ -900,40 +1054,55 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R } fullTextResponse.Usage = usage - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + + switch info.RelayFormat { + case types.RelayFormatOpenAI: + responseBody, err = common.Marshal(fullTextResponse) + if err != nil { + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + } + case types.RelayFormatClaude: + claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info) + claudeRespStr, err := common.Marshal(claudeResp) + if err != nil { + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + } + responseBody = claudeRespStr + case types.RelayFormatGemini: + break } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - c.Writer.Write(jsonResponse) + + service.IOCopyBytesGracefully(c, resp, responseBody) + return &usage, nil } func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, readErr := io.ReadAll(resp.Body) if readErr != nil { - return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - var geminiResponse GeminiEmbeddingResponse + var geminiResponse dto.GeminiBatchEmbeddingResponse if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // convert to openai format response openAIResponse := dto.OpenAIEmbeddingResponse{ Object: "list", - Data: []dto.OpenAIEmbeddingResponseItem{ - { - Object: "embedding", - Embedding: geminiResponse.Embedding.Values, - Index: 0, - }, - }, - Model: info.UpstreamModelName, + Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)), + Model: info.UpstreamModelName, + } + + for i, embedding := range geminiResponse.Embeddings { + openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{ + Object: "embedding", + Embedding: embedding.Values, + Index: i, + }) } // calculate usage @@ -949,10 +1118,64 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h openAIResponse.Usage = *usage jsonResponse, jsonErr := common.Marshal(openAIResponse) + if jsonErr != nil { + return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + service.IOCopyBytesGracefully(c, resp, jsonResponse) + return usage, nil +} + +func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + responseBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + _ = resp.Body.Close() + + var geminiResponse dto.GeminiImageResponse + if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { + return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + if len(geminiResponse.Predictions) == 0 { + return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + // convert to openai format response + openAIResponse := dto.ImageResponse{ + Created: common.GetTimestamp(), + Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)), + } + + for _, prediction := range geminiResponse.Predictions { + if prediction.RaiFilteredReason != "" { + continue // skip filtered image + } + openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{ + B64Json: prediction.BytesBase64Encoded, + }) + } + + jsonResponse, jsonErr := json.Marshal(openAIResponse) if jsonErr != nil { return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) } - common.IOCopyBytesGracefully(c, resp, jsonResponse) + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + + // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb + // each image has fixed 258 tokens + const imageTokens = 258 + generatedImages := len(openAIResponse.Data) + + usage := &dto.Usage{ + PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens + CompletionTokens: 0, // image generation does not calculate completion tokens + TotalTokens: imageTokens * generatedImages, + } + return usage, nil } diff --git a/relay/channel/jimeng/adaptor.go b/relay/channel/jimeng/adaptor.go index 0b743879d..885a1427f 100644 --- a/relay/channel/jimeng/adaptor.go +++ b/relay/channel/jimeng/adaptor.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -13,11 +12,18 @@ import ( relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/types" + + "github.com/gin-gonic/gin" ) type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { return nil, errors.New("not implemented") } @@ -26,7 +32,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.BaseUrl), nil + return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/jimeng/image.go b/relay/channel/jimeng/image.go index 3c6a1d991..11a0117bb 100644 --- a/relay/channel/jimeng/image.go +++ b/relay/channel/jimeng/image.go @@ -5,9 +5,9 @@ import ( "fmt" "io" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -52,13 +52,13 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R var jimengResponse ImageResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &jimengResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // Check if the response indicates an error diff --git a/relay/channel/jimeng/sign.go b/relay/channel/jimeng/sign.go index c9db66301..d8b598dce 100644 --- a/relay/channel/jimeng/sign.go +++ b/relay/channel/jimeng/sign.go @@ -12,7 +12,7 @@ import ( "io" "net/http" "net/url" - "one-api/common" + "one-api/logger" "sort" "strings" "time" @@ -44,7 +44,7 @@ func SetPayloadHash(c *gin.Context, req any) error { if err != nil { return err } - common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body)) + logger.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body)) payloadHash := sha256.Sum256(body) hexPayloadHash := hex.EncodeToString(payloadHash[:]) c.Set(HexPayloadHashKey, hexPayloadHash) diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index 408a5c6e4..a383728f7 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -19,6 +19,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") @@ -40,9 +45,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil } return "", errors.New("invalid relay mode") } diff --git a/relay/channel/minimax/relay-minimax.go b/relay/channel/minimax/relay-minimax.go index d0a15b0da..ff9b72ea3 100644 --- a/relay/channel/minimax/relay-minimax.go +++ b/relay/channel/minimax/relay-minimax.go @@ -6,5 +6,5 @@ import ( ) func GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.ChannelBaseUrl), nil } diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 434a1031c..f98ff8698 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -16,6 +16,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") @@ -36,7 +41,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/mistral/text.go b/relay/channel/mistral/text.go index e26c61019..aa9257811 100644 --- a/relay/channel/mistral/text.go +++ b/relay/channel/mistral/text.go @@ -71,7 +71,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI Messages: messages, Temperature: request.Temperature, TopP: request.TopP, - MaxTokens: request.MaxTokens, + MaxTokens: request.GetMaxTokens(), Tools: request.Tools, ToolChoice: request.ToolChoice, } diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go index b0b54b0c5..f9da685f2 100644 --- a/relay/channel/mokaai/adaptor.go +++ b/relay/channel/mokaai/adaptor.go @@ -18,6 +18,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") @@ -49,7 +54,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if strings.HasPrefix(info.UpstreamModelName, "m3e") { suffix = "embeddings" } - fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix) + fullRequestURL := fmt.Sprintf("%s/%s", info.ChannelBaseUrl, suffix) return fullRequestURL, nil } diff --git a/relay/channel/mokaai/constants.go b/relay/channel/mokaai/constants.go index 415d83b7f..385a0876b 100644 --- a/relay/channel/mokaai/constants.go +++ b/relay/channel/mokaai/constants.go @@ -6,4 +6,4 @@ var ModelList = []string{ "m3e-small", } -var ChannelName = "mokaai" \ No newline at end of file +var ChannelName = "mokaai" diff --git a/relay/channel/mokaai/relay-mokaai.go b/relay/channel/mokaai/relay-mokaai.go index 78f96d6d6..d91aceb3d 100644 --- a/relay/channel/mokaai/relay-mokaai.go +++ b/relay/channel/mokaai/relay-mokaai.go @@ -7,6 +7,7 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -56,7 +57,7 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) @@ -77,6 +78,6 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return &fullTextResponse.Usage, nil } diff --git a/relay/channel/moonshot/adaptor.go b/relay/channel/moonshot/adaptor.go new file mode 100644 index 000000000..e290c239d --- /dev/null +++ b/relay/channel/moonshot/adaptor.go @@ -0,0 +1,110 @@ +package moonshot + +import ( + "errors" + "fmt" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + "one-api/relay/channel/claude" + "one-api/relay/channel/openai" + relaycommon "one-api/relay/common" + "one-api/relay/constant" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +type Adaptor struct { +} + +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + adaptor := openai.Adaptor{} + return adaptor.ConvertClaudeRequest(c, info, req) +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not supported") +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + adaptor := openai.Adaptor{} + return adaptor.ConvertImageRequest(c, info, request) +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + switch info.RelayFormat { + case types.RelayFormatClaude: + return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil + default: + if info.RelayMode == constant.RelayModeRerank { + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil + } else if info.RelayMode == constant.RelayModeEmbeddings { + return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil + } else if info.RelayMode == constant.RelayModeChatCompletions { + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil + } else if info.RelayMode == constant.RelayModeCompletions { + return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil + } + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil + } +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + return nil +} + +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + return request, nil +} + +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return request, nil +} + +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + return request, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { + switch info.RelayFormat { + case types.RelayFormatClaude: + if info.IsStream { + return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) + } else { + return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) + } + default: + adaptor := openai.Adaptor{} + return adaptor.DoResponse(c, resp, info) + } +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index b9e304fcc..d6b5b697e 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -17,10 +17,21 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { //TODO implement me - panic("implement me") - return nil, nil + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { + openaiAdaptor := openai.Adaptor{} + openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request) + if err != nil { + return nil, err + } + openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{ + IncludeUsage: true, + } + return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest)) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -37,11 +48,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + if info.RelayFormat == types.RelayFormatClaude { + return info.ChannelBaseUrl + "/v1/chat/completions", nil + } switch info.RelayMode { case relayconstant.RelayModeEmbeddings: - return info.BaseUrl + "/api/embed", nil + return info.ChannelBaseUrl + "/api/embed", nil default: - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } } @@ -55,7 +69,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - return requestOpenAI2Ollama(*request) + return requestOpenAI2Ollama(c, request) } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { @@ -76,11 +90,12 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) - } else { - if info.RelayMode == relayconstant.RelayModeEmbeddings { - usage, err = ollamaEmbeddingHandler(c, info, resp) + switch info.RelayMode { + case relayconstant.RelayModeEmbeddings: + usage, err = ollamaEmbeddingHandler(c, info, resp) + default: + if info.IsStream { + usage, err = openai.OaiStreamHandler(c, info, resp) } else { usage, err = openai.OpenaiHandler(c, info, resp) } diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go index 15c64cdcd..317c2a4a1 100644 --- a/relay/channel/ollama/dto.go +++ b/relay/channel/ollama/dto.go @@ -1,6 +1,9 @@ package ollama -import "one-api/dto" +import ( + "encoding/json" + "one-api/dto" +) type OllamaRequest struct { Model string `json:"model,omitempty"` @@ -19,6 +22,7 @@ type OllamaRequest struct { Suffix any `json:"suffix,omitempty"` StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"` Prompt any `json:"prompt,omitempty"` + Think json.RawMessage `json:"think,omitempty"` } type Options struct { diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index 295349e31..27c67b4ec 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -14,7 +14,7 @@ import ( "github.com/gin-gonic/gin" ) -func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) { +func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) { messages := make([]dto.Message, 0, len(request.Messages)) for _, message := range request.Messages { if !message.IsStringContent() { @@ -24,7 +24,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err imageUrl := mediaMessage.GetImageMedia() // check if not base64 if strings.HasPrefix(imageUrl.Url, "http") { - fileData, err := service.GetFileBase64FromUrl(imageUrl.Url) + fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama") if err != nil { return nil, err } @@ -50,7 +50,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err } else { Stop, _ = request.Stop.([]string) } - return &OllamaRequest{ + ollamaRequest := &OllamaRequest{ Model: request.Model, Messages: messages, Stream: request.Stream, @@ -60,14 +60,16 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err TopK: request.TopK, Stop: Stop, Tools: request.Tools, - MaxTokens: request.MaxTokens, + MaxTokens: request.GetMaxTokens(), ResponseFormat: request.ResponseFormat, FrequencyPenalty: request.FrequencyPenalty, PresencePenalty: request.PresencePenalty, Prompt: request.Prompt, StreamOptions: request.StreamOptions, Suffix: request.Suffix, - }, nil + } + ollamaRequest.Think = request.Think + return ollamaRequest, nil } func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest { @@ -88,15 +90,15 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h var ollamaEmbeddingResponse OllamaEmbeddingResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if ollamaEmbeddingResponse.Error != "" { - return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding) data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1) @@ -117,9 +119,9 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h } doResponseBody, err := common.Marshal(embeddingResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - common.IOCopyBytesGracefully(c, resp, doResponseBody) + service.IOCopyBytesGracefully(c, resp, doResponseBody) return usage, nil } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index efd228781..1d8286a43 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -9,13 +9,13 @@ import ( "mime/multipart" "net/http" "net/textproto" + "one-api/common" "one-api/constant" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/ai360" "one-api/relay/channel/lingyiwanwu" "one-api/relay/channel/minimax" - "one-api/relay/channel/moonshot" "one-api/relay/channel/openrouter" "one-api/relay/channel/xinference" relaycommon "one-api/relay/common" @@ -34,15 +34,55 @@ type Adaptor struct { ResponseFormat string } +// parseReasoningEffortFromModelSuffix 从模型名称中解析推理级别 +// support OAI models: o1-mini/o3-mini/o4-mini/o1/o3 etc... +// minimal effort only available in gpt-5 +func parseReasoningEffortFromModelSuffix(model string) (string, string) { + effortSuffixes := []string{"-high", "-minimal", "-low", "-medium"} + for _, suffix := range effortSuffixes { + if strings.HasSuffix(model, suffix) { + effort := strings.TrimPrefix(suffix, "-") + originModel := strings.TrimSuffix(model, suffix) + return effort, originModel + } + } + return "", model +} + +func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) { + // 使用 service.GeminiToOpenAIRequest 转换请求格式 + openaiRequest, err := service.GeminiToOpenAIRequest(request, info) + if err != nil { + return nil, err + } + return a.ConvertOpenAIRequest(c, info, openaiRequest) +} + func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { //if !strings.Contains(request.Model, "claude") { // return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model) //} + //if common.DebugEnabled { + // bodyBytes := []byte(common.GetJsonString(request)) + // err := os.WriteFile(fmt.Sprintf("claude_request_%s.txt", c.GetString(common.RequestIdKey)), bodyBytes, 0644) + // if err != nil { + // println(fmt.Sprintf("failed to save request body to file: %v", err)) + // } + //} aiRequest, err := service.ClaudeToOpenAIRequest(*request, info) if err != nil { return nil, err } - if info.SupportStreamOptions { + //if common.DebugEnabled { + // println(fmt.Sprintf("convert claude to openai request result: %s", common.GetJsonString(aiRequest))) + // // Save request body to file for debugging + // bodyBytes := []byte(common.GetJsonString(aiRequest)) + // err = os.WriteFile(fmt.Sprintf("claude_to_openai_request_%s.txt", c.GetString(common.RequestIdKey)), bodyBytes, 0644) + // if err != nil { + // println(fmt.Sprintf("failed to save request body to file: %v", err)) + // } + //} + if info.SupportStreamOptions && info.IsStream { aiRequest.StreamOptions = &dto.StreamOptions{ IncludeUsage: true, } @@ -64,18 +104,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - if info.RelayFormat == relaycommon.RelayFormatClaude { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil - } if info.RelayMode == relayconstant.RelayModeRealtime { - if strings.HasPrefix(info.BaseUrl, "https://") { - baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") + if strings.HasPrefix(info.ChannelBaseUrl, "https://") { + baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "https://") baseUrl = "wss://" + baseUrl - info.BaseUrl = baseUrl - } else if strings.HasPrefix(info.BaseUrl, "http://") { - baseUrl := strings.TrimPrefix(info.BaseUrl, "http://") + info.ChannelBaseUrl = baseUrl + } else if strings.HasPrefix(info.ChannelBaseUrl, "http://") { + baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "http://") baseUrl = "ws://" + baseUrl - info.BaseUrl = baseUrl + info.ChannelBaseUrl = baseUrl } } switch info.ChannelType { @@ -89,10 +126,27 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) task := strings.TrimPrefix(requestURL, "/v1/") + if info.RelayFormat == types.RelayFormatClaude { + task = strings.TrimPrefix(task, "messages") + task = "chat/completions" + task + } + // 特殊处理 responses API if info.RelayMode == relayconstant.RelayModeResponses { - requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview") - return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil + responsesApiVersion := "preview" + + subUrl := "/openai/v1/responses" + if strings.Contains(info.ChannelBaseUrl, "cognitiveservices.azure.com") { + subUrl = "/openai/responses" + responsesApiVersion = apiVersion + } + + if info.ChannelOtherSettings.AzureResponsesVersion != "" { + responsesApiVersion = info.ChannelOtherSettings.AzureResponsesVersion + } + + requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion) + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil } model_ := info.UpstreamModelName @@ -105,15 +159,18 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == relayconstant.RelayModeRealtime { requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion) } - return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil case constant.ChannelTypeMiniMax: return minimax.GetRequestURL(info) case constant.ChannelTypeCustom: - url := info.BaseUrl + url := info.ChannelBaseUrl url = strings.Replace(url, "{model}", info.UpstreamModelName, -1) return url, nil default: - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + if info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini { + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil + } + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } } @@ -163,28 +220,105 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if len(request.Usage) == 0 { request.Usage = json.RawMessage(`{"include":true}`) } + // 适配 OpenRouter 的 thinking 后缀 + if strings.HasSuffix(info.UpstreamModelName, "-thinking") { + info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") + request.Model = info.UpstreamModelName + if len(request.Reasoning) == 0 { + reasoning := map[string]any{ + "enabled": true, + } + if request.ReasoningEffort != "" && request.ReasoningEffort != "none" { + reasoning["effort"] = request.ReasoningEffort + } + marshal, err := common.Marshal(reasoning) + if err != nil { + return nil, fmt.Errorf("error marshalling reasoning: %w", err) + } + request.Reasoning = marshal + } + // 清空多余的ReasoningEffort + request.ReasoningEffort = "" + } else { + if len(request.Reasoning) == 0 { + // 适配 OpenAI 的 ReasoningEffort 格式 + if request.ReasoningEffort != "" { + reasoning := map[string]any{ + "enabled": true, + } + if request.ReasoningEffort != "none" { + reasoning["effort"] = request.ReasoningEffort + marshal, err := common.Marshal(reasoning) + if err != nil { + return nil, fmt.Errorf("error marshalling reasoning: %w", err) + } + request.Reasoning = marshal + } + } + } + request.ReasoningEffort = "" + } + + // https://docs.anthropic.com/en/api/openai-sdk#extended-thinking-support + // 没有做排除3.5Haiku等,要出问题再加吧,最佳兼容性(不是 + if request.THINKING != nil && strings.HasPrefix(info.UpstreamModelName, "anthropic") { + var thinking dto.Thinking // Claude标准Thinking格式 + if err := json.Unmarshal(request.THINKING, &thinking); err != nil { + return nil, fmt.Errorf("error Unmarshal thinking: %w", err) + } + + // 只有当 thinking.Type 是 "enabled" 时才处理 + if thinking.Type == "enabled" { + // 检查 BudgetTokens 是否为 nil + if thinking.BudgetTokens == nil { + return nil, fmt.Errorf("BudgetTokens is nil when thinking is enabled") + } + + reasoning := openrouter.RequestReasoning{ + MaxTokens: *thinking.BudgetTokens, + } + + marshal, err := common.Marshal(reasoning) + if err != nil { + return nil, fmt.Errorf("error marshalling reasoning: %w", err) + } + + request.Reasoning = marshal + } + + // 清空 THINKING + request.THINKING = nil + } + } - if strings.HasPrefix(request.Model, "o") { + if strings.HasPrefix(info.UpstreamModelName, "o") || strings.HasPrefix(info.UpstreamModelName, "gpt-5") { if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 { request.MaxCompletionTokens = request.MaxTokens request.MaxTokens = 0 } - request.Temperature = nil - if strings.HasSuffix(request.Model, "-high") { - request.ReasoningEffort = "high" - request.Model = strings.TrimSuffix(request.Model, "-high") - } else if strings.HasSuffix(request.Model, "-low") { - request.ReasoningEffort = "low" - request.Model = strings.TrimSuffix(request.Model, "-low") - } else if strings.HasSuffix(request.Model, "-medium") { - request.ReasoningEffort = "medium" - request.Model = strings.TrimSuffix(request.Model, "-medium") + + if strings.HasPrefix(info.UpstreamModelName, "o") { + request.Temperature = nil } + + if strings.HasPrefix(info.UpstreamModelName, "gpt-5") { + if info.UpstreamModelName != "gpt-5-chat-latest" { + request.Temperature = nil + } + } + + // 转换模型推理力度后缀 + effort, originModel := parseReasoningEffortFromModelSuffix(info.UpstreamModelName) + if effort != "" { + request.ReasoningEffort = effort + info.UpstreamModelName = originModel + request.Model = originModel + } + info.ReasoningEffort = request.ReasoningEffort - info.UpstreamModelName = request.Model // o系列模型developer适配(o1-mini除外) - if !strings.HasPrefix(request.Model, "o1-mini") && !strings.HasPrefix(request.Model, "o1-preview") { + if !strings.HasPrefix(info.UpstreamModelName, "o1-mini") && !strings.HasPrefix(info.UpstreamModelName, "o1-preview") { //修改第一个Message的内容,将system改为developer if len(request.Messages) > 0 && request.Messages[0].Role == "system" { request.Messages[0].Role = "developer" @@ -260,40 +394,42 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf writer := multipart.NewWriter(&requestBody) writer.WriteField("model", request.Model) - // 获取所有表单字段 - formData := c.Request.PostForm - // 遍历表单字段并打印输出 - for key, values := range formData { - if key == "model" { - continue + // 使用已解析的 multipart 表单,避免重复解析 + mf := c.Request.MultipartForm + if mf == nil { + if _, err := c.MultipartForm(); err != nil { + return nil, errors.New("failed to parse multipart form") } - for _, value := range values { - writer.WriteField(key, value) + mf = c.Request.MultipartForm + } + + // 写入所有非文件字段 + if mf != nil { + for key, values := range mf.Value { + if key == "model" { + continue + } + for _, value := range values { + writer.WriteField(key, value) + } } } - // Parse the multipart form to handle both single image and multiple images - if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory - return nil, errors.New("failed to parse multipart form") - } - - if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil { + if mf != nil && mf.File != nil { // Check if "image" field exists in any form, including array notation var imageFiles []*multipart.FileHeader var exists bool // First check for standard "image" field - if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 { + if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 { // If not found, check for "image[]" field - if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 { + if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 { // If still not found, iterate through all fields to find any that start with "image[" foundArrayImages := false - for fieldName, files := range c.Request.MultipartForm.File { + for fieldName, files := range mf.File { if strings.HasPrefix(fieldName, "image[") && len(files) > 0 { foundArrayImages = true - for _, file := range files { - imageFiles = append(imageFiles, file) - } + imageFiles = append(imageFiles, files...) } } @@ -310,7 +446,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf if err != nil { return nil, fmt.Errorf("failed to open image file %d: %w", i, err) } - defer file.Close() // If multiple images, use image[] as the field name fieldName := "image" @@ -334,15 +469,18 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf if _, err := io.Copy(part, file); err != nil { return nil, fmt.Errorf("copy file failed for image %d: %w", i, err) } + + // 复制完立即关闭,避免在循环内使用 defer 占用资源 + _ = file.Close() } // Handle mask file if present - if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 { + if maskFiles, exists := mf.File["mask"]; exists && len(maskFiles) > 0 { maskFile, err := maskFiles[0].Open() if err != nil { return nil, errors.New("failed to open mask file") } - defer maskFile.Close() + // 复制完立即关闭,避免在循环内使用 defer 占用资源 // Determine MIME type for mask file mimeType := detectImageMimeType(maskFiles[0].Filename) @@ -360,6 +498,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf if _, err := io.Copy(maskPart, maskFile); err != nil { return nil, errors.New("copy mask file failed") } + _ = maskFile.Close() } } else { return nil, errors.New("no multipart form data found") @@ -368,7 +507,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf // 关闭 multipart 编写器以设置分界线 writer.Close() c.Request.Header.Set("Content-Type", writer.FormDataContentType()) - return bytes.NewReader(requestBody.Bytes()), nil + return &requestBody, nil default: return request, nil @@ -396,16 +535,17 @@ func detectImageMimeType(filename string) string { } func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { - // 模型后缀转换 reasoning effort - if strings.HasSuffix(request.Model, "-high") { - request.Reasoning.Effort = "high" - request.Model = strings.TrimSuffix(request.Model, "-high") - } else if strings.HasSuffix(request.Model, "-low") { - request.Reasoning.Effort = "low" - request.Model = strings.TrimSuffix(request.Model, "-low") - } else if strings.HasSuffix(request.Model, "-medium") { - request.Reasoning.Effort = "medium" - request.Model = strings.TrimSuffix(request.Model, "-medium") + // 转换模型推理力度后缀 + effort, originModel := parseReasoningEffortFromModelSuffix(request.Model) + if effort != "" { + if request.Reasoning == nil { + request.Reasoning = &dto.Reasoning{ + Effort: effort, + } + } else { + request.Reasoning.Effort = effort + } + request.Model = originModel } return request, nil } @@ -456,8 +596,6 @@ func (a *Adaptor) GetModelList() []string { switch a.ChannelType { case constant.ChannelType360: return ai360.ModelList - case constant.ChannelTypeMoonshot: - return moonshot.ModelList case constant.ChannelTypeLingYiWanWu: return lingyiwanwu.ModelList case constant.ChannelTypeMiniMax: @@ -475,8 +613,6 @@ func (a *Adaptor) GetChannelName() string { switch a.ChannelType { case constant.ChannelType360: return ai360.ChannelName - case constant.ChannelTypeMoonshot: - return moonshot.ChannelName case constant.ChannelTypeLingYiWanWu: return lingyiwanwu.ChannelName case constant.ChannelTypeMiniMax: diff --git a/relay/channel/openai/constant.go b/relay/channel/openai/constant.go index c703e414b..af5b67248 100644 --- a/relay/channel/openai/constant.go +++ b/relay/channel/openai/constant.go @@ -12,13 +12,25 @@ var ModelList = []string{ "gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "gpt-4o-mini", "gpt-4o-mini-2024-07-18", "gpt-4.5-preview", "gpt-4.5-preview-2025-02-27", + "gpt-4.1", "gpt-4.1-2025-04-14", + "gpt-4.1-mini", "gpt-4.1-mini-2025-04-14", + "gpt-4.1-nano", "gpt-4.1-nano-2025-04-14", + "o1", "o1-2024-12-17", "o1-preview", "o1-preview-2024-09-12", "o1-mini", "o1-mini-2024-09-12", + "o1-pro", "o1-pro-2025-03-19", "o3-mini", "o3-mini-2025-01-31", "o3-mini-high", "o3-mini-2025-01-31-high", "o3-mini-low", "o3-mini-2025-01-31-low", "o3-mini-medium", "o3-mini-2025-01-31-medium", - "o1", "o1-2024-12-17", + "o3", "o3-2025-04-16", + "o3-pro", "o3-pro-2025-06-10", + "o3-deep-research", "o3-deep-research-2025-06-26", + "o4-mini", "o4-mini-2025-04-16", + "o4-mini-deep-research", "o4-mini-deep-research-2025-06-26", + "gpt-5", "gpt-5-2025-08-07", "gpt-5-chat-latest", + "gpt-5-mini", "gpt-5-mini-2025-08-07", + "gpt-5-nano", "gpt-5-nano-2025-08-07", "gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01", "gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01", "gpt-4o-realtime-preview-2024-12-17", "gpt-4o-mini-realtime-preview", "gpt-4o-mini-realtime-preview-2024-12-17", @@ -27,7 +39,7 @@ var ModelList = []string{ "text-moderation-latest", "text-moderation-stable", "text-davinci-edit-001", "davinci-002", "babbage-002", - "dall-e-3", + "dall-e-3", "gpt-image-1", "whisper-1", "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", } diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index a068c544c..e84f6cc4a 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -4,30 +4,37 @@ import ( "encoding/json" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" + "one-api/types" "strings" + "github.com/samber/lo" + "github.com/gin-gonic/gin" ) // 辅助函数 -func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { +func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { info.SendResponseCount++ + switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: return sendStreamData(c, info, data, forceFormat, thinkToContent) - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: return handleClaudeFormat(c, data, info) + case types.RelayFormatGemini: + return handleGeminiFormat(c, data, info) } return nil } func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error { var streamResponse dto.ChatCompletionsStreamResponse - if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil { + if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil { return err } @@ -41,6 +48,32 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo return nil } +func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error { + var streamResponse dto.ChatCompletionsStreamResponse + if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil { + logger.LogError(c, "failed to unmarshal stream response: "+err.Error()) + return err + } + + geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info) + + // 如果返回 nil,表示没有实际内容,跳过发送 + if geminiResponse == nil { + return nil + } + + geminiResponseStr, err := common.Marshal(geminiResponse) + if err != nil { + logger.LogError(c, "failed to marshal gemini response: "+err.Error()) + return err + } + + // send gemini format response + c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)}) + _ = helper.FlushWriter(c) + return nil +} + func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error { for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) @@ -74,14 +107,14 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex var streamResponses []dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { var streamResponse dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { return err } if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil { - common.SysError("error processing stream response: " + err.Error()) + common.SysLog("error processing stream response: " + err.Error()) } } return nil @@ -110,7 +143,7 @@ func processCompletions(streamResp string, streamItems []string, responseTextBui var streamResponses []dto.CompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { var streamResponse dto.CompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { @@ -151,19 +184,21 @@ func handleLastResponse(lastStreamData string, responseId *string, createAt *int *containStreamUsage = true *usage = lastStreamResponse.Usage if !info.ShouldIncludeUsage { - *shouldSendLastResp = false + *shouldSendLastResp = lo.SomeBy(lastStreamResponse.Choices, func(choice dto.ChatCompletionsStreamResponseChoice) bool { + return choice.Delta.GetContentString() != "" || choice.Delta.GetReasoningContent() != "" + }) } } return nil } -func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string, +func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string, responseId string, createAt int64, model string, systemFingerprint string, usage *dto.Usage, containStreamUsage bool) { switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: if info.ShouldIncludeUsage && !containStreamUsage { response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage) response.SetSystemFingerprint(systemFingerprint) @@ -171,11 +206,11 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream } helper.Done(c) - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: info.ClaudeConvertInfo.Done = true var streamResponse dto.ChatCompletionsStreamResponse - if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { + common.SysLog("error unmarshalling stream response: " + err.Error()) return } @@ -183,8 +218,37 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info) for _, resp := range claudeResponses { - helper.ClaudeData(c, *resp) + _ = helper.ClaudeData(c, *resp) } + + case types.RelayFormatGemini: + var streamResponse dto.ChatCompletionsStreamResponse + if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { + common.SysLog("error unmarshalling stream response: " + err.Error()) + return + } + + // 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段 + // 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应 + // 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null + // 暂不知是否有程序会不兼容。 + + geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info) + + // openai 流响应开头的空数据 + if geminiResponse == nil { + return + } + + geminiResponseStr, err := common.Marshal(geminiResponse) + if err != nil { + common.SysLog("error marshalling gemini response: " + err.Error()) + return + } + + // 发送最终的 Gemini 响应 + c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)}) + _ = helper.FlushWriter(c) } } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index bfe8bcd39..4b13a7df1 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -2,6 +2,7 @@ package openai import ( "bytes" + "encoding/json" "fmt" "io" "math" @@ -10,6 +11,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -108,11 +110,11 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { if resp == nil || resp.Body == nil { - common.LogError(c, "invalid response or response body") - return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse) + logger.LogError(c, "invalid response or response body") + return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError) } - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) model := info.UpstreamModelName var responseId string @@ -123,30 +125,19 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re var toolCount int var usage = &dto.Usage{} var streamItems []string // store stream items - var forceFormat bool - var thinkToContent bool - - if info.ChannelSetting.ForceFormat { - forceFormat = true - } - - if info.ChannelSetting.ThinkingToContent { - thinkToContent = true - } - - var ( - lastStreamData string - ) + var lastStreamData string helper.StreamScannerHandler(c, resp, info, func(data string) bool { if lastStreamData != "" { - err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent) + err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) if err != nil { - common.SysError("error handling stream format: " + err.Error()) + common.SysLog("error handling stream format: " + err.Error()) } } - lastStreamData = data - streamItems = append(streamItems, data) + if len(data) > 0 { + lastStreamData = data + streamItems = append(streamItems, data) + } return true }) @@ -154,16 +145,18 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re shouldSendLastResp := true if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage, &containStreamUsage, info, &shouldSendLastResp); err != nil { - common.SysError("error handling last response: " + err.Error()) + logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData)) } - if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI { - _ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) + if info.RelayFormat == types.RelayFormatOpenAI { + if shouldSendLastResp { + _ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) + } } // 处理token计算 if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil { - common.SysError("error processing tokens: " + err.Error()) + logger.LogError(c, "error processing tokens: "+err.Error()) } if !containStreamUsage { @@ -176,26 +169,28 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re } } } - - handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage) + HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage) return usage, nil } func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) var simpleResponse dto.OpenAITextResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) + } + if common.DebugEnabled { + println("upstream response body:", string(responseBody)) } err = common.Unmarshal(responseBody, &simpleResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - if simpleResponse.Error != nil && simpleResponse.Error.Type != "" { - return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode) + if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" { + return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) } forceFormat := false @@ -203,21 +198,34 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo forceFormat = true } - if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { - completionTokens := 0 - for _, choice := range simpleResponse.Choices { - ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) - completionTokens += ctkm + usageModified := false + if simpleResponse.Usage.PromptTokens == 0 { + completionTokens := simpleResponse.Usage.CompletionTokens + if completionTokens == 0 { + for _, choice := range simpleResponse.Choices { + ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) + completionTokens += ctkm + } } simpleResponse.Usage = dto.Usage{ PromptTokens: info.PromptTokens, CompletionTokens: completionTokens, TotalTokens: info.PromptTokens + completionTokens, } + usageModified = true } switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: + if usageModified { + var bodyMap map[string]interface{} + err = common.Unmarshal(responseBody, &bodyMap) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + bodyMap["usage"] = simpleResponse.Usage + responseBody, _ = common.Marshal(bodyMap) + } if forceFormat { responseBody, err = common.Marshal(simpleResponse) if err != nil { @@ -226,16 +234,23 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo } else { break } - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info) claudeRespStr, err := common.Marshal(claudeResp) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } responseBody = claudeRespStr + case types.RelayFormatGemini: + geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info) + geminiRespStr, err := common.Marshal(geminiResp) + if err != nil { + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + } + responseBody = geminiRespStr } - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) return &simpleResponse.Usage, nil } @@ -247,7 +262,7 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel // if the upstream returns a specific status code, once the upstream has already written the header, // the subsequent failure of the response body should be regarded as a non-recoverable error, // and can be terminated directly. - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens usage.TotalTokens = info.PromptTokens @@ -258,26 +273,41 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel c.Writer.WriteHeaderNow() _, err := io.Copy(c.Writer, resp.Body) if err != nil { - common.LogError(c, err.Error()) + logger.LogError(c, err.Error()) } return usage } func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil + } + // 写入新的 response body + service.IOCopyBytesGracefully(c, resp, responseBody) + + var responseData struct { + Usage *dto.Usage `json:"usage"` + } + if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil { + if responseData.Usage.TotalTokens > 0 { + usage := responseData.Usage + if usage.PromptTokens == 0 { + usage.PromptTokens = usage.InputTokens + } + if usage.CompletionTokens == 0 { + usage.CompletionTokens = usage.OutputTokens + } + return nil, usage + } + } - // count tokens by audio file duration audioTokens, err := countAudioTokens(c) if err != nil { return types.NewError(err, types.ErrorCodeCountTokenFailed), nil } - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil - } - // 写入新的 response body - common.IOCopyBytesGracefully(c, resp, responseBody) - usage := &dto.Usage{} usage.PromptTokens = audioTokens usage.CompletionTokens = 0 @@ -386,7 +416,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. errChan <- fmt.Errorf("error counting text token: %v", err) return } - common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken localUsage.InputTokens += textToken + audioToken localUsage.InputTokenDetails.TextTokens += textToken @@ -459,7 +489,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. errChan <- fmt.Errorf("error counting text token: %v", err) return } - common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken info.IsFirstRequest = false localUsage.InputTokens += textToken + audioToken @@ -474,9 +504,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. localUsage = &dto.RealtimeUsage{} // print now usage } - common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) - common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) - common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) + logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { realtimeSession := realtimeEvent.Session @@ -491,7 +521,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. errChan <- fmt.Errorf("error counting text token: %v", err) return } - common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) + logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken localUsage.OutputTokens += textToken + audioToken localUsage.OutputTokenDetails.TextTokens += textToken @@ -517,7 +547,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types. case <-targetClosed: case err := <-errChan: //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil - common.LogError(c, "realtime error: "+err.Error()) + logger.LogError(c, "realtime error: "+err.Error()) case <-c.Done(): } @@ -553,21 +583,21 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R } func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } var usageResp dto.SimpleResponse err = common.Unmarshal(responseBody, &usageResp) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // 写入新的 response body - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) // Once we've written to the client, we should not return errors anymore // because the upstream has already consumed resources and returned content diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index d9dd96b90..e188889e4 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -16,43 +17,58 @@ import ( ) func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) // read response body var responsesResponse dto.OpenAIResponsesResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } err = common.Unmarshal(responseBody, &responsesResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - if responsesResponse.Error != nil { - return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode) + if oaiError := responsesResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" { + return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) } // 写入新的 response body - common.IOCopyBytesGracefully(c, resp, responseBody) + service.IOCopyBytesGracefully(c, resp, responseBody) // compute usage usage := dto.Usage{} - usage.PromptTokens = responsesResponse.Usage.InputTokens - usage.CompletionTokens = responsesResponse.Usage.OutputTokens - usage.TotalTokens = responsesResponse.Usage.TotalTokens + if responsesResponse.Usage != nil { + usage.PromptTokens = responsesResponse.Usage.InputTokens + usage.CompletionTokens = responsesResponse.Usage.OutputTokens + usage.TotalTokens = responsesResponse.Usage.TotalTokens + if responsesResponse.Usage.InputTokensDetails != nil { + usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens + } + } + if info == nil || info.ResponsesUsageInfo == nil || info.ResponsesUsageInfo.BuiltInTools == nil { + return &usage, nil + } // 解析 Tools 用量 for _, tool := range responsesResponse.Tools { - info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])].CallCount++ + buildToolinfo, ok := info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])] + if !ok || buildToolinfo == nil { + logger.LogError(c, fmt.Sprintf("BuiltInTools not found for tool type: %v", tool["type"])) + continue + } + buildToolinfo.CallCount++ } return &usage, nil } func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { if resp == nil || resp.Body == nil { - common.LogError(c, "invalid response or response body") + logger.LogError(c, "invalid response or response body") return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse) } + defer service.CloseResponseBodyGracefully(resp) + var usage = &dto.Usage{} var responseTextBuilder strings.Builder @@ -64,9 +80,20 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp sendResponsesStreamData(c, streamResponse, data) switch streamResponse.Type { case "response.completed": - usage.PromptTokens = streamResponse.Response.Usage.InputTokens - usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens - usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + if streamResponse.Response != nil && streamResponse.Response.Usage != nil { + if streamResponse.Response.Usage.InputTokens != 0 { + usage.PromptTokens = streamResponse.Response.Usage.InputTokens + } + if streamResponse.Response.Usage.OutputTokens != 0 { + usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens + } + if streamResponse.Response.Usage.TotalTokens != 0 { + usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + } + if streamResponse.Response.Usage.InputTokensDetails != nil { + usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens + } + } case "response.output_text.delta": // 处理输出文本 responseTextBuilder.WriteString(streamResponse.Delta) @@ -79,6 +106,8 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp } } } + } else { + logger.LogError(c, "failed to unmarshal stream response: "+err.Error()) } return true }) @@ -93,5 +122,11 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp } } + if usage.PromptTokens == 0 && usage.CompletionTokens != 0 { + usage.PromptTokens = info.PromptTokens + } + + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + return usage, nil } diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index a60dc4b28..2a022a1b8 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -17,6 +17,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") @@ -37,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil + return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 4db315739..3a6ec2f4b 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -18,30 +18,6 @@ import ( // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body -func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest { - palmRequest := PaLMChatRequest{ - Prompt: PaLMPrompt{ - Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), - }, - Temperature: textRequest.Temperature, - CandidateCount: textRequest.N, - TopP: textRequest.TopP, - TopK: textRequest.MaxTokens, - } - for _, message := range textRequest.Messages { - palmMessage := PaLMChatMessage{ - Content: message.StringContent(), - } - if message.Role == "user" { - palmMessage.Author = "0" - } else { - palmMessage.Author = "1" - } - palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage) - } - return &palmRequest -} - func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse { fullTextResponse := dto.OpenAITextResponse{ Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)), @@ -82,15 +58,15 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, go func() { responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.SysError("error reading stream response: " + err.Error()) + common.SysLog("error reading stream response: " + err.Error()) stopChan <- true return } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) stopChan <- true return } @@ -102,7 +78,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, } jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) stopChan <- true return } @@ -120,20 +96,20 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, return false } }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return nil, responseText } func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { return nil, types.WithOpenAIError(types.OpenAIError{ @@ -157,6 +133,6 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return &usage, nil } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 19830aca4..8ab9c8547 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -17,6 +17,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") @@ -37,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/chat/completions", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/perplexity/relay-perplexity.go b/relay/channel/perplexity/relay-perplexity.go index 9772aead3..7ebadd0f9 100644 --- a/relay/channel/perplexity/relay-perplexity.go +++ b/relay/channel/perplexity/relay-perplexity.go @@ -16,6 +16,6 @@ func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpen Messages: messages, Temperature: request.Temperature, TopP: request.TopP, - MaxTokens: request.MaxTokens, + MaxTokens: request.GetMaxTokens(), } } diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index 63c1c84d7..4c176c088 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -18,20 +18,24 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { //TODO implement me - panic("implement me") - return nil, nil + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + adaptor := openai.Adaptor{} + return adaptor.ConvertClaudeRequest(c, info, req) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - return nil, errors.New("not implemented") + return nil, errors.New("not supported") } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + adaptor := openai.Adaptor{} + return adaptor.ConvertImageRequest(c, info, request) } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { @@ -39,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeChatCompletions { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeCompletions { - return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil } - return "", errors.New("invalid relay mode") + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -81,16 +85,19 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom switch info.RelayMode { case constant.RelayModeRerank: usage, err = siliconflowRerankHandler(c, info, resp) + case constant.RelayModeEmbeddings: + usage, err = openai.OpenaiHandler(c, info, resp) case constant.RelayModeCompletions: fallthrough case constant.RelayModeChatCompletions: + fallthrough + default: if info.IsStream { usage, err = openai.OaiStreamHandler(c, info, resp) } else { usage, err = openai.OpenaiHandler(c, info, resp) } - case constant.RelayModeEmbeddings: - usage, err = openai.OpenaiHandler(c, info, resp) + } return } diff --git a/relay/channel/siliconflow/relay-siliconflow.go b/relay/channel/siliconflow/relay-siliconflow.go index fabaf9c63..b21faccb7 100644 --- a/relay/channel/siliconflow/relay-siliconflow.go +++ b/relay/channel/siliconflow/relay-siliconflow.go @@ -4,9 +4,9 @@ import ( "encoding/json" "io" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -15,13 +15,13 @@ import ( func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) var siliconflowResp SFRerankResponse err = json.Unmarshal(responseBody, &siliconflowResp) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } usage := &dto.Usage{ PromptTokens: siliconflowResp.Meta.Tokens.InputTokens, @@ -39,6 +39,6 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return usage, nil } diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 8d0575132..955e592a2 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -74,9 +74,9 @@ type TaskAdaptor struct { baseURL string } -func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType - a.baseURL = info.BaseUrl + a.baseURL = info.ChannelBaseUrl // apiKey format: "access_key|secret_key" keyParts := strings.Split(info.ApiKey, "|") @@ -87,7 +87,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { } // ValidateRequestAndSetAction parses body, validates fields and sets default action. -func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { // Accept only POST /v1/video/generations as "generate" action. action := constant.TaskActionGenerate info.Action = action @@ -108,19 +108,19 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom } // BuildRequestURL constructs the upstream URL. -func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil } // BuildRequestHeader sets required headers. -func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") return a.signRequest(req, a.accessKey, a.secretKey) } // BuildRequestBody converts request into Jimeng specific format. -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, exists := c.Get("task_request") if !exists { return nil, fmt.Errorf("request not found in context") @@ -139,12 +139,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel } // DoRequest delegates to common helper. -func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoTaskApiRequest(a, c, info, requestBody) } // DoResponse handles upstream response, returns taskID etc. -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index afa392016..3d6da253b 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -4,13 +4,14 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/samber/lo" "io" "net/http" "one-api/model" "strings" "time" + "github.com/samber/lo" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt" "github.com/pkg/errors" @@ -37,19 +38,52 @@ type SubmitReq struct { Metadata map[string]interface{} `json:"metadata,omitempty"` } +type TrajectoryPoint struct { + X int `json:"x"` + Y int `json:"y"` +} + +type DynamicMask struct { + Mask string `json:"mask,omitempty"` + Trajectories []TrajectoryPoint `json:"trajectories,omitempty"` +} + +type CameraConfig struct { + Horizontal float64 `json:"horizontal,omitempty"` + Vertical float64 `json:"vertical,omitempty"` + Pan float64 `json:"pan,omitempty"` + Tilt float64 `json:"tilt,omitempty"` + Roll float64 `json:"roll,omitempty"` + Zoom float64 `json:"zoom,omitempty"` +} + +type CameraControl struct { + Type string `json:"type,omitempty"` + Config *CameraConfig `json:"config,omitempty"` +} + type requestPayload struct { - Prompt string `json:"prompt,omitempty"` - Image string `json:"image,omitempty"` - Mode string `json:"mode,omitempty"` - Duration string `json:"duration,omitempty"` - AspectRatio string `json:"aspect_ratio,omitempty"` - ModelName string `json:"model_name,omitempty"` - CfgScale float64 `json:"cfg_scale,omitempty"` + Prompt string `json:"prompt,omitempty"` + Image string `json:"image,omitempty"` + ImageTail string `json:"image_tail,omitempty"` + NegativePrompt string `json:"negative_prompt,omitempty"` + Mode string `json:"mode,omitempty"` + Duration string `json:"duration,omitempty"` + AspectRatio string `json:"aspect_ratio,omitempty"` + ModelName string `json:"model_name,omitempty"` + Model string `json:"model,omitempty"` // Compatible with upstreams that only recognize "model" + CfgScale float64 `json:"cfg_scale,omitempty"` + StaticMask string `json:"static_mask,omitempty"` + DynamicMasks []DynamicMask `json:"dynamic_masks,omitempty"` + CameraControl *CameraControl `json:"camera_control,omitempty"` + CallbackUrl string `json:"callback_url,omitempty"` + ExternalTaskId string `json:"external_task_id,omitempty"` } type responsePayload struct { Code int `json:"code"` Message string `json:"message"` + TaskId string `json:"task_id"` RequestId string `json:"request_id"` Data struct { TaskId string `json:"task_id"` @@ -73,25 +107,20 @@ type responsePayload struct { type TaskAdaptor struct { ChannelType int - accessKey string - secretKey string + apiKey string baseURL string } -func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType - a.baseURL = info.BaseUrl + a.baseURL = info.ChannelBaseUrl + a.apiKey = info.ApiKey // apiKey format: "access_key|secret_key" - keyParts := strings.Split(info.ApiKey, "|") - if len(keyParts) == 2 { - a.accessKey = strings.TrimSpace(keyParts[0]) - a.secretKey = strings.TrimSpace(keyParts[1]) - } } // ValidateRequestAndSetAction parses body, validates fields and sets default action. -func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { // Accept only POST /v1/video/generations as "generate" action. action := constant.TaskActionGenerate info.Action = action @@ -112,13 +141,13 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom } // BuildRequestURL constructs the upstream URL. -func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video") return fmt.Sprintf("%s%s", a.baseURL, path), nil } // BuildRequestHeader sets required headers. -func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { token, err := a.createJWTToken() if err != nil { return fmt.Errorf("failed to create JWT token: %w", err) @@ -132,7 +161,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info } // BuildRequestBody converts request into Kling specific format. -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, exists := c.Get("task_request") if !exists { return nil, fmt.Errorf("request not found in context") @@ -143,6 +172,9 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel if err != nil { return nil, err } + if body.Image == "" && body.ImageTail == "" { + c.Set("action", constant.TaskActionTextGenerate) + } data, err := json.Marshal(body) if err != nil { return nil, err @@ -151,7 +183,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel } // DoRequest delegates to common helper. -func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { if action := c.GetString("action"); action != "" { info.Action = action } @@ -159,34 +191,26 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, } // DoResponse handles upstream response, returns taskID etc. -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) return } - // Attempt Kling response parse first. var kResp responsePayload - if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 { - c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskId}) - return kResp.Data.TaskId, responseBody, nil - } - - // Fallback generic task response. - var generic dto.TaskResponse[string] - if err := json.Unmarshal(responseBody, &generic); err != nil { - taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) + err = json.Unmarshal(responseBody, &kResp) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) return } - - if !generic.IsSuccess() { - taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError) + if kResp.Code != 0 { + taskErr = service.TaskErrorWrapperLocal(fmt.Errorf(kResp.Message), "task_failed", http.StatusBadRequest) return } - - c.JSON(http.StatusOK, gin.H{"task_id": generic.Data}) - return generic.Data, responseBody, nil + kResp.TaskId = kResp.Data.TaskId + c.JSON(http.StatusOK, kResp) + return kResp.Data.TaskId, responseBody, nil } // FetchTask fetch task status @@ -233,13 +257,19 @@ func (a *TaskAdaptor) GetChannelName() string { func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { r := requestPayload{ - Prompt: req.Prompt, - Image: req.Image, - Mode: defaultString(req.Mode, "std"), - Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)), - AspectRatio: a.getAspectRatio(req.Size), - ModelName: req.Model, - CfgScale: 0.5, + Prompt: req.Prompt, + Image: req.Image, + Mode: defaultString(req.Mode, "std"), + Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)), + AspectRatio: a.getAspectRatio(req.Size), + ModelName: req.Model, + Model: req.Model, // Keep consistent with model_name, double writing improves compatibility + CfgScale: 0.5, + StaticMask: "", + DynamicMasks: []DynamicMask{}, + CameraControl: nil, + CallbackUrl: "", + ExternalTaskId: "", } if r.ModelName == "" { r.ModelName = "kling-v1" @@ -288,21 +318,25 @@ func defaultInt(v int, def int) int { // ============================ func (a *TaskAdaptor) createJWTToken() (string, error) { - return a.createJWTTokenWithKeys(a.accessKey, a.secretKey) + return a.createJWTTokenWithKey(a.apiKey) } +//func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) { +// parts := strings.Split(apiKey, "|") +// if len(parts) != 2 { +// return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'") +// } +// return a.createJWTTokenWithKey(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])) +//} + func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) { - parts := strings.Split(apiKey, "|") - if len(parts) != 2 { - return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'") - } - return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])) -} -func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) { - if accessKey == "" || secretKey == "" { - return "", fmt.Errorf("access key and secret key are required") + keyParts := strings.Split(apiKey, "|") + accessKey := strings.TrimSpace(keyParts[0]) + if len(keyParts) == 1 { + return accessKey, nil } + secretKey := strings.TrimSpace(keyParts[1]) now := time.Now().Unix() claims := jwt.MapClaims{ "iss": accessKey, @@ -315,12 +349,12 @@ func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (strin } func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { + taskInfo := &relaycommon.TaskInfo{} resPayload := responsePayload{} err := json.Unmarshal(respBody, &resPayload) if err != nil { return nil, errors.Wrap(err, "failed to unmarshal response body") } - taskInfo := &relaycommon.TaskInfo{} taskInfo.Code = resPayload.Code taskInfo.TaskID = resPayload.Data.TaskId taskInfo.Reason = resPayload.Message diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 9c04c7ad4..237513d75 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -16,6 +15,8 @@ import ( "one-api/service" "strings" "time" + + "github.com/gin-gonic/gin" ) type TaskAdaptor struct { @@ -26,11 +27,11 @@ func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, fmt.Errorf("not implement") // todo implement this method if needed } -func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType } -func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { action := strings.ToUpper(c.Param("action")) var sunoRequest *dto.SunoSubmitReq @@ -58,20 +59,20 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom return nil } -func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { - baseURL := info.BaseUrl +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { + baseURL := info.ChannelBaseUrl fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action) return fullRequestURL, nil } -func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Authorization", "Bearer "+info.ApiKey) return nil } -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { sunoRequest, ok := c.Get("task_request") if !ok { err := common.UnmarshalBodyReusable(c, &sunoRequest) @@ -86,11 +87,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel return bytes.NewReader(data), nil } -func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoTaskApiRequest(a, c, info, requestBody) } -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -139,7 +140,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody)) if err != nil { - common.SysError(fmt.Sprintf("Get Task error: %v", err)) + common.SysLog(fmt.Sprintf("Get Task error: %v", err)) return nil, err } defer req.Body.Close() diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go new file mode 100644 index 000000000..c82c1c0e8 --- /dev/null +++ b/relay/channel/task/vidu/adaptor.go @@ -0,0 +1,285 @@ +package vidu + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" + + "one-api/constant" + "one-api/dto" + "one-api/model" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" + + "github.com/pkg/errors" +) + +// ============================ +// Request / Response structures +// ============================ + +type SubmitReq struct { + Prompt string `json:"prompt"` + Model string `json:"model,omitempty"` + Mode string `json:"mode,omitempty"` + Image string `json:"image,omitempty"` + Size string `json:"size,omitempty"` + Duration int `json:"duration,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +type requestPayload struct { + Model string `json:"model"` + Images []string `json:"images"` + Prompt string `json:"prompt,omitempty"` + Duration int `json:"duration,omitempty"` + Seed int `json:"seed,omitempty"` + Resolution string `json:"resolution,omitempty"` + MovementAmplitude string `json:"movement_amplitude,omitempty"` + Bgm bool `json:"bgm,omitempty"` + Payload string `json:"payload,omitempty"` + CallbackUrl string `json:"callback_url,omitempty"` +} + +type responsePayload struct { + TaskId string `json:"task_id"` + State string `json:"state"` + Model string `json:"model"` + Images []string `json:"images"` + Prompt string `json:"prompt"` + Duration int `json:"duration"` + Seed int `json:"seed"` + Resolution string `json:"resolution"` + Bgm bool `json:"bgm"` + MovementAmplitude string `json:"movement_amplitude"` + Payload string `json:"payload"` + CreatedAt string `json:"created_at"` +} + +type taskResultResponse struct { + State string `json:"state"` + ErrCode string `json:"err_code"` + Credits int `json:"credits"` + Payload string `json:"payload"` + Creations []creation `json:"creations"` +} + +type creation struct { + ID string `json:"id"` + URL string `json:"url"` + CoverURL string `json:"cover_url"` +} + +// ============================ +// Adaptor implementation +// ============================ + +type TaskAdaptor struct { + ChannelType int + baseURL string +} + +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.ChannelBaseUrl +} + +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { + var req SubmitReq + if err := c.ShouldBindJSON(&req); err != nil { + return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest) + } + + if req.Prompt == "" { + return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest) + } + + if req.Image != "" { + info.Action = constant.TaskActionGenerate + } else { + info.Action = constant.TaskActionTextGenerate + } + + c.Set("task_request", req) + return nil +} + +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { + v, exists := c.Get("task_request") + if !exists { + return nil, fmt.Errorf("request not found in context") + } + req := v.(SubmitReq) + + body, err := a.convertToRequestPayload(&req) + if err != nil { + return nil, err + } + + if len(body.Images) == 0 { + c.Set("action", constant.TaskActionTextGenerate) + } + + data, err := json.Marshal(body) + if err != nil { + return nil, err + } + return bytes.NewReader(data), nil +} + +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { + var path string + switch info.Action { + case constant.TaskActionGenerate: + path = "/img2video" + default: + path = "/text2video" + } + return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil +} + +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Token "+info.ApiKey) + return nil +} + +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + if action := c.GetString("action"); action != "" { + info.Action = action + } + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return + } + + var vResp responsePayload + err = json.Unmarshal(responseBody, &vResp) + if err != nil { + taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError) + return + } + + if vResp.State == "failed" { + taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task failed"), "task_failed", http.StatusBadRequest) + return + } + + c.JSON(http.StatusOK, vResp) + return vResp.TaskId, responseBody, nil +} + +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + taskID, ok := body["task_id"].(string) + if !ok { + return nil, fmt.Errorf("invalid task_id") + } + + url := fmt.Sprintf("%s/ent/v2/tasks/%s/creations", baseUrl, taskID) + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Token "+key) + + return service.GetHttpClient().Do(req) +} + +func (a *TaskAdaptor) GetModelList() []string { + return []string{"viduq1", "vidu2.0", "vidu1.5"} +} + +func (a *TaskAdaptor) GetChannelName() string { + return "vidu" +} + +// ============================ +// helpers +// ============================ + +func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { + var images []string + if req.Image != "" { + images = []string{req.Image} + } + + r := requestPayload{ + Model: defaultString(req.Model, "viduq1"), + Images: images, + Prompt: req.Prompt, + Duration: defaultInt(req.Duration, 5), + Resolution: defaultString(req.Size, "1080p"), + MovementAmplitude: "auto", + Bgm: false, + } + metadata := req.Metadata + medaBytes, err := json.Marshal(metadata) + if err != nil { + return nil, errors.Wrap(err, "metadata marshal metadata failed") + } + err = json.Unmarshal(medaBytes, &r) + if err != nil { + return nil, errors.Wrap(err, "unmarshal metadata failed") + } + return &r, nil +} + +func defaultString(value, defaultValue string) string { + if value == "" { + return defaultValue + } + return value +} + +func defaultInt(value, defaultValue int) int { + if value == 0 { + return defaultValue + } + return value +} + +func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { + taskInfo := &relaycommon.TaskInfo{} + + var taskResp taskResultResponse + err := json.Unmarshal(respBody, &taskResp) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal response body") + } + + state := taskResp.State + switch state { + case "created", "queueing": + taskInfo.Status = model.TaskStatusSubmitted + case "processing": + taskInfo.Status = model.TaskStatusInProgress + case "success": + taskInfo.Status = model.TaskStatusSuccess + if len(taskResp.Creations) > 0 { + taskInfo.Url = taskResp.Creations[0].URL + } + case "failed": + taskInfo.Status = model.TaskStatusFailure + if taskResp.ErrCode != "" { + taskInfo.Reason = taskResp.ErrCode + } + default: + return nil, fmt.Errorf("unknown task state: %s", state) + } + + return taskInfo, nil +} diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 520276a7e..ab96ecaa3 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -25,6 +25,11 @@ type Adaptor struct { Timestamp int64 } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") @@ -48,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/", info.BaseUrl), nil + return fmt.Sprintf("%s/", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index c3d96c49a..f33a275c6 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -106,7 +106,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt var tencentResponse TencentChatResponse err := json.Unmarshal([]byte(data), &tencentResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) continue } @@ -117,17 +117,17 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt err = helper.ObjectData(c, response) if err != nil { - common.SysError(err.Error()) + common.SysLog(err.Error()) } } if err := scanner.Err(); err != nil { - common.SysError("error reading stream: " + err.Error()) + common.SysLog("error reading stream: " + err.Error()) } helper.Done(c) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil } @@ -136,12 +136,12 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp var tencentSb TencentChatResponseSB responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &tencentSb) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if tencentSb.Response.Error.Code != 0 { return nil, types.WithOpenAIError(types.OpenAIError{ @@ -156,7 +156,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) - common.IOCopyBytesGracefully(c, resp, jsonResponse) + service.IOCopyBytesGracefully(c, resp, jsonResponse) return &fullTextResponse.Usage, nil } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index fa895de08..0b6b26743 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -35,6 +35,7 @@ var claudeModelMap = map[string]string{ "claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219", "claude-sonnet-4-20250514": "claude-sonnet-4@20250514", "claude-opus-4-20250514": "claude-opus-4@20250514", + "claude-opus-4-1-20250805": "claude-opus-4-1@20250805", } const anthropicVersion = "vertex-2023-10-16" @@ -44,6 +45,11 @@ type Adaptor struct { AccountCredentials Credentials } +func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) { + geminiAdaptor := gemini.Adaptor{} + return geminiAdaptor.ConvertGeminiRequest(c, info, request) +} + func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { if v, ok := claudeModelMap[info.UpstreamModelName]; ok { c.Set("request_model", v) @@ -60,17 +66,17 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + geminiAdaptor := gemini.Adaptor{} + return geminiAdaptor.ConvertImageRequest(c, info, request) } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { if strings.HasPrefix(info.UpstreamModelName, "claude") { a.RequestMode = RequestModeClaude - } else if strings.HasPrefix(info.UpstreamModelName, "gemini") { - a.RequestMode = RequestModeGemini } else if strings.Contains(info.UpstreamModelName, "llama") { a.RequestMode = RequestModeLlama + } else { + a.RequestMode = RequestModeGemini } } @@ -83,6 +89,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { a.AccountCredentials = *adc suffix := "" if a.RequestMode == RequestModeGemini { + if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { // 新增逻辑:处理 -thinking- 格式 if strings.Contains(info.UpstreamModelName, "-thinking-") { @@ -100,6 +107,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } else { suffix = "generateContent" } + + if strings.HasPrefix(info.UpstreamModelName, "imagen") { + suffix = "predict" + } + if region == "global" { return fmt.Sprintf( "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s", @@ -169,8 +181,62 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } + if a.RequestMode == RequestModeGemini && strings.HasPrefix(info.UpstreamModelName, "imagen") { + prompt := "" + for _, m := range request.Messages { + if m.Role == "user" { + prompt = m.StringContent() + if prompt != "" { + break + } + } + } + if prompt == "" { + if p, ok := request.Prompt.(string); ok { + prompt = p + } + } + if prompt == "" { + return nil, errors.New("prompt is required for image generation") + } + + imgReq := dto.ImageRequest{ + Model: request.Model, + Prompt: prompt, + N: 1, + Size: "1024x1024", + } + if request.N > 0 { + imgReq.N = uint(request.N) + } + if request.Size != "" { + imgReq.Size = request.Size + } + if len(request.ExtraBody) > 0 { + var extra map[string]any + if err := json.Unmarshal(request.ExtraBody, &extra); err == nil { + if n, ok := extra["n"].(float64); ok && n > 0 { + imgReq.N = uint(n) + } + if size, ok := extra["size"].(string); ok { + imgReq.Size = size + } + // accept aspectRatio in extra body (top-level or under parameters) + if ar, ok := extra["aspectRatio"].(string); ok && ar != "" { + imgReq.Size = ar + } + if params, ok := extra["parameters"].(map[string]any); ok { + if ar, ok := params["aspectRatio"].(string); ok && ar != "" { + imgReq.Size = ar + } + } + } + } + c.Set("request_model", request.Model) + return a.ConvertImageRequest(c, info, imgReq) + } if a.RequestMode == RequestModeClaude { - claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request) + claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request) if err != nil { return nil, err } @@ -179,7 +245,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn info.UpstreamModelName = claudeReq.Model return vertexClaudeReq, nil } else if a.RequestMode == RequestModeGemini { - geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info) + geminiRequest, err := gemini.CovertGemini2OpenAI(c, *request, info) if err != nil { return nil, err } @@ -213,28 +279,31 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { switch a.RequestMode { case RequestModeClaude: - err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) + return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) case RequestModeGemini: if info.RelayMode == constant.RelayModeGemini { - usage, err = gemini.GeminiTextGenerationStreamHandler(c, info, resp) + return gemini.GeminiTextGenerationStreamHandler(c, info, resp) } else { - usage, err = gemini.GeminiChatStreamHandler(c, info, resp) + return gemini.GeminiChatStreamHandler(c, info, resp) } case RequestModeLlama: - usage, err = openai.OaiStreamHandler(c, info, resp) + return openai.OaiStreamHandler(c, info, resp) } } else { switch a.RequestMode { case RequestModeClaude: - err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info) + return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) case RequestModeGemini: if info.RelayMode == constant.RelayModeGemini { - usage, err = gemini.GeminiTextGenerationHandler(c, info, resp) + return gemini.GeminiTextGenerationHandler(c, info, resp) } else { - usage, err = gemini.GeminiChatHandler(c, info, resp) + if strings.HasPrefix(info.UpstreamModelName, "imagen") { + return gemini.GeminiImageHandler(c, info, resp) + } + return gemini.GeminiChatHandler(c, info, resp) } case RequestModeLlama: - usage, err = openai.OpenaiHandler(c, info, resp) + return openai.OpenaiHandler(c, info, resp) } } return diff --git a/relay/channel/vertex/service_account.go b/relay/channel/vertex/service_account.go index 5a97c021e..9a4650d98 100644 --- a/relay/channel/vertex/service_account.go +++ b/relay/channel/vertex/service_account.go @@ -36,7 +36,12 @@ var Cache = asynccache.NewAsyncCache(asynccache.Options{ }) func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) { - cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId) + var cacheKey string + if info.ChannelIsMultiKey { + cacheKey = fmt.Sprintf("access-token-%d-%d", info.ChannelId, info.ChannelMultiKeyIndex) + } else { + cacheKey = fmt.Sprintf("access-token-%d", info.ChannelId) + } val, err := Cache.Get(cacheKey) if err == nil { return val.(string), nil diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index af15d6367..0af019da4 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -2,6 +2,7 @@ package volcengine import ( "bytes" + "encoding/json" "errors" "fmt" "io" @@ -23,10 +24,14 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { //TODO implement me - panic("implement me") - return nil, nil + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + adaptor := openai.Adaptor{} + return adaptor.ConvertClaudeRequest(c, info, req) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -184,13 +189,17 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch info.RelayMode { case constant.RelayModeChatCompletions: if strings.HasPrefix(info.UpstreamModelName, "bot") { - return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil } - return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil case constant.RelayModeImagesGenerations: - return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil + case constant.RelayModeImagesEdits: + return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil + case constant.RelayModeRerank: + return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil default: } return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) @@ -206,6 +215,12 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } + // 适配 方舟deepseek混合模型 的 thinking 后缀 + if strings.HasSuffix(info.UpstreamModelName, "-thinking") && strings.HasPrefix(info.UpstreamModelName, "deepseek") { + info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") + request.Model = info.UpstreamModelName + request.THINKING = json.RawMessage(`{"type": "enabled"}`) + } return request, nil } @@ -227,18 +242,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - switch info.RelayMode { - case constant.RelayModeChatCompletions: - if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) - } else { - usage, err = openai.OpenaiHandler(c, info, resp) - } - case constant.RelayModeEmbeddings: - usage, err = openai.OpenaiHandler(c, info, resp) - case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits: - usage, err = openai.OpenaiHandlerWithUsage(c, info, resp) - } + adaptor := openai.Adaptor{} + usage, err = adaptor.DoResponse(c, resp, info) return } diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go index 8d880137e..d5671ab2f 100644 --- a/relay/channel/xai/adaptor.go +++ b/relay/channel/xai/adaptor.go @@ -19,6 +19,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me //panic("implement me") @@ -34,7 +39,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf xaiRequest := ImageRequest{ Model: request.Model, Prompt: request.Prompt, - N: request.N, + N: int(request.N), ResponseFormat: request.ResponseFormat, } return xaiRequest, nil @@ -44,7 +49,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index 4d098102a..5cae9c0ae 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -47,7 +47,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re var xAIResp *dto.ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &xAIResp) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } @@ -63,7 +63,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount) err = helper.ObjectData(c, openaiResponse) if err != nil { - common.SysError(err.Error()) + common.SysLog(err.Error()) } return true }) @@ -74,12 +74,12 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re } helper.Done(c) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return usage, nil } func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - defer common.CloseResponseBodyGracefully(resp) + defer service.CloseResponseBodyGracefully(resp) responseBody, err := io.ReadAll(resp.Body) if err != nil { @@ -101,7 +101,7 @@ func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - common.IOCopyBytesGracefully(c, resp, encodeJson) + service.IOCopyBytesGracefully(c, resp, encodeJson) return xaiResponse.Usage, nil } diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 0d218adaf..7ee76f1ad 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -17,6 +17,11 @@ type Adaptor struct { request *dto.GeneralOpenAIRequest } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 373ad6054..9d5c190fe 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -48,7 +48,7 @@ func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, xunfeiRequest.Parameter.Chat.Domain = domain xunfeiRequest.Parameter.Chat.Temperature = request.Temperature xunfeiRequest.Parameter.Chat.TopK = request.N - xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens + xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens() xunfeiRequest.Payload.Message.Text = messages return &xunfeiRequest } @@ -143,7 +143,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a response := streamResponseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -206,6 +206,11 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap if err != nil || resp.StatusCode != 101 { return nil, nil, err } + + defer func() { + conn.Close() + }() + data := requestOpenAI2Xunfei(textRequest, appId, domain) err = conn.WriteJSON(data) if err != nil { @@ -218,20 +223,19 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap for { _, msg, err := conn.ReadMessage() if err != nil { - common.SysError("error reading stream response: " + err.Error()) + common.SysLog("error reading stream response: " + err.Error()) break } var response XunfeiChatResponse err = json.Unmarshal(msg, &response) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) break } dataChan <- response if response.Payload.Choices.Status == 2 { - err := conn.Close() if err != nil { - common.SysError("error closing websocket connection: " + err.Error()) + common.SysLog("error closing websocket connection: " + err.Error()) } break } diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 433444289..bd27c90b0 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -16,6 +16,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { //TODO implement me panic("implement me") @@ -40,7 +45,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.IsStream { method = "sse-invoke" } - return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil + return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.ChannelBaseUrl, info.UpstreamModelName, method), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 916a200de..8eb0dcc13 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -10,6 +10,7 @@ import ( "one-api/dto" relaycommon "one-api/relay/common" "one-api/relay/helper" + "one-api/service" "one-api/types" "strings" "sync" @@ -38,7 +39,7 @@ func getZhipuToken(apikey string) string { split := strings.Split(apikey, ".") if len(split) != 2 { - common.SysError("invalid zhipu key: " + apikey) + common.SysLog("invalid zhipu key: " + apikey) return "" } @@ -186,7 +187,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. response := streamResponseZhipu2OpenAI(data) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -195,13 +196,13 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. var zhipuResponse ZhipuStreamMetaResponse err := json.Unmarshal([]byte(data), &zhipuResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } usage = zhipuUsage @@ -212,7 +213,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. return false } }) - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) return usage, nil } @@ -220,12 +221,12 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon var zhipuResponse ZhipuResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &zhipuResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if !zhipuResponse.Success { return nil, types.WithOpenAIError(types.OpenAIError{ diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index edd7a5345..37c0c3521 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/claude" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" @@ -18,10 +19,13 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { //TODO implement me - panic("implement me") - return nil, nil + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + return req, nil } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -38,19 +42,22 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl) - switch info.RelayMode { - case relayconstant.RelayModeEmbeddings: - return fmt.Sprintf("%s/embeddings", baseUrl), nil + switch info.RelayFormat { + case types.RelayFormatClaude: + return fmt.Sprintf("%s/api/anthropic/v1/messages", info.ChannelBaseUrl), nil default: - return fmt.Sprintf("%s/chat/completions", baseUrl), nil + switch info.RelayMode { + case relayconstant.RelayModeEmbeddings: + return fmt.Sprintf("%s/api/paas/v4/embeddings", info.ChannelBaseUrl), nil + default: + return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.ChannelBaseUrl), nil + } } } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - token := getZhipuToken(info.ApiKey) - req.Set("Authorization", token) + req.Set("Authorization", "Bearer "+info.ApiKey) return nil } @@ -82,12 +89,17 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) - } else { - usage, err = openai.OpenaiHandler(c, info, resp) + switch info.RelayFormat { + case types.RelayFormatClaude: + if info.IsStream { + return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) + } else { + return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) + } + default: + adaptor := openai.Adaptor{} + return adaptor.DoResponse(c, resp, info) } - return } func (a *Adaptor) GetModelList() []string { diff --git a/relay/channel/zhipu_4v/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go index 271dda8ff..aec87dd5d 100644 --- a/relay/channel/zhipu_4v/relay-zhipu_v4.go +++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go @@ -1,69 +1,10 @@ package zhipu_4v import ( - "github.com/golang-jwt/jwt" - "one-api/common" "one-api/dto" "strings" - "sync" - "time" ) -// https://open.bigmodel.cn/doc/api#chatglm_std -// chatglm_std, chatglm_lite -// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke -// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke - -var zhipuTokens sync.Map -var expSeconds int64 = 24 * 3600 - -func getZhipuToken(apikey string) string { - data, ok := zhipuTokens.Load(apikey) - if ok { - tokenData := data.(tokenData) - if time.Now().Before(tokenData.ExpiryTime) { - return tokenData.Token - } - } - - split := strings.Split(apikey, ".") - if len(split) != 2 { - common.SysError("invalid zhipu key: " + apikey) - return "" - } - - id := split[0] - secret := split[1] - - expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6 - expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second) - - timestamp := time.Now().UnixNano() / 1e6 - - payload := jwt.MapClaims{ - "api_key": id, - "exp": expMillis, - "timestamp": timestamp, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) - - token.Header["alg"] = "HS256" - token.Header["sign_type"] = "SIGN" - - tokenString, err := token.SignedString([]byte(secret)) - if err != nil { - return "" - } - - zhipuTokens.Store(apikey, tokenData{ - Token: tokenString, - ExpiryTime: expiryTime, - }) - - return tokenString -} - func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { messages := make([]dto.Message, 0, len(request.Messages)) for _, message := range request.Messages { @@ -105,9 +46,10 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq Messages: messages, Temperature: request.Temperature, TopP: request.TopP, - MaxTokens: request.MaxTokens, + MaxTokens: request.GetMaxTokens(), Stop: Stop, Tools: request.Tools, ToolChoice: request.ToolChoice, + THINKING: request.THINKING, } } diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 5f38960e5..59c052f62 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "errors" "fmt" "io" "net/http" @@ -18,119 +17,99 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) { - textRequest = &dto.ClaudeRequest{} - err = c.ShouldBindJSON(textRequest) +func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + + info.InitChannelMeta(c) + + claudeReq, ok := info.Request.(*dto.ClaudeRequest) + + if !ok { + return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) + } + + request, err := common.DeepCopy(claudeReq) if err != nil { - return nil, err + return types.NewError(fmt.Errorf("failed to copy request to ClaudeRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } - if textRequest.Messages == nil || len(textRequest.Messages) == 0 { - return nil, errors.New("field messages is required") - } - if textRequest.Model == "" { - return nil, errors.New("field model is required") - } - return textRequest, nil -} -func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - - relayInfo := relaycommon.GenRelayInfoClaude(c) - - // get & validate textRequest 获取并验证文本请求 - textRequest, err := getAndValidateClaudeRequest(c) + err = helper.ModelMappedHelper(c, info, request) if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - if textRequest.Stream { - relayInfo.IsStream = true - } - - err = helper.ModelMappedHelper(c, relayInfo, textRequest) - if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) - } - - promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) - // count messages token error 计算promptTokens错误 - if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) - } - - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) - var requestBody io.Reader + adaptor.Init(info) - if textRequest.MaxTokens == 0 { - textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model)) + if request.MaxTokens == 0 { + request.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(request.Model)) } if model_setting.GetClaudeSettings().ThinkingAdapterEnabled && - strings.HasSuffix(textRequest.Model, "-thinking") { - if textRequest.Thinking == nil { + strings.HasSuffix(request.Model, "-thinking") { + if request.Thinking == nil { // 因为BudgetTokens 必须大于1024 - if textRequest.MaxTokens < 1280 { - textRequest.MaxTokens = 1280 + if request.MaxTokens < 1280 { + request.MaxTokens = 1280 } // BudgetTokens 为 max_tokens 的 80% - textRequest.Thinking = &dto.Thinking{ + request.Thinking = &dto.Thinking{ Type: "enabled", - BudgetTokens: common.GetPointer[int](int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)), + BudgetTokens: common.GetPointer[int](int(float64(request.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)), } // TODO: 临时处理 // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking - textRequest.TopP = 0 - textRequest.Temperature = common.GetPointer[float64](1.0) + request.TopP = 0 + request.Temperature = common.GetPointer[float64](1.0) } - textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking") - relayInfo.UpstreamModelName = textRequest.Model + request.Model = strings.TrimSuffix(request.Model, "-thinking") + info.UpstreamModelName = request.Model } - convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest) - if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + var requestBody io.Reader + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { + body, err := common.GetRequestBody(c) + if err != nil { + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) + } + requestBody = bytes.NewBuffer(body) + } else { + convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, request) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + jsonData, err := common.Marshal(convertedRequest) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + + // apply param override + if len(info.ParamOverride) > 0 { + jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + if err != nil { + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + } + } + + if common.DebugEnabled { + println("requestBody: ", string(jsonData)) + } + requestBody = bytes.NewBuffer(jsonData) } - jsonData, err := common.Marshal(convertedRequest) - if common.DebugEnabled { - println("requestBody: ", string(jsonData)) - } - if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) - } - requestBody = bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") var httpResp *http.Response - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 @@ -139,24 +118,14 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) //log.Printf("usage: %v", usage) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + + service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage)) return nil } - -func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) { - var promptTokens int - var err error - switch info.RelayMode { - default: - promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName) - } - info.PromptTokens = promptTokens - return promptTokens, err -} diff --git a/relay/common/override.go b/relay/common/override.go new file mode 100644 index 000000000..212cf7b47 --- /dev/null +++ b/relay/common/override.go @@ -0,0 +1,435 @@ +package common + +import ( + "encoding/json" + "fmt" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "regexp" + "strconv" + "strings" +) + +type ConditionOperation struct { + Path string `json:"path"` // JSON路径 + Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte + Value interface{} `json:"value"` // 匹配的值 + Invert bool `json:"invert"` // 反选功能,true表示取反结果 + PassMissingKey bool `json:"pass_missing_key"` // 未获取到json key时的行为 +} + +type ParamOperation struct { + Path string `json:"path"` + Mode string `json:"mode"` // delete, set, move, prepend, append + Value interface{} `json:"value"` + KeepOrigin bool `json:"keep_origin"` + From string `json:"from,omitempty"` + To string `json:"to,omitempty"` + Conditions []ConditionOperation `json:"conditions,omitempty"` // 条件列表 + Logic string `json:"logic,omitempty"` // AND, OR (默认OR) +} + +func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) { + if len(paramOverride) == 0 { + return jsonData, nil + } + + // 尝试断言为操作格式 + if operations, ok := tryParseOperations(paramOverride); ok { + // 使用新方法 + result, err := applyOperations(string(jsonData), operations) + return []byte(result), err + } + + // 直接使用旧方法 + return applyOperationsLegacy(jsonData, paramOverride) +} + +func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) { + // 检查是否包含 "operations" 字段 + if opsValue, exists := paramOverride["operations"]; exists { + if opsSlice, ok := opsValue.([]interface{}); ok { + var operations []ParamOperation + for _, op := range opsSlice { + if opMap, ok := op.(map[string]interface{}); ok { + operation := ParamOperation{} + + // 断言必要字段 + if path, ok := opMap["path"].(string); ok { + operation.Path = path + } + if mode, ok := opMap["mode"].(string); ok { + operation.Mode = mode + } else { + return nil, false // mode 是必需的 + } + + // 可选字段 + if value, exists := opMap["value"]; exists { + operation.Value = value + } + if keepOrigin, ok := opMap["keep_origin"].(bool); ok { + operation.KeepOrigin = keepOrigin + } + if from, ok := opMap["from"].(string); ok { + operation.From = from + } + if to, ok := opMap["to"].(string); ok { + operation.To = to + } + if logic, ok := opMap["logic"].(string); ok { + operation.Logic = logic + } else { + operation.Logic = "OR" // 默认为OR + } + + // 解析条件 + if conditions, exists := opMap["conditions"]; exists { + if condSlice, ok := conditions.([]interface{}); ok { + for _, cond := range condSlice { + if condMap, ok := cond.(map[string]interface{}); ok { + condition := ConditionOperation{} + if path, ok := condMap["path"].(string); ok { + condition.Path = path + } + if mode, ok := condMap["mode"].(string); ok { + condition.Mode = mode + } + if value, ok := condMap["value"]; ok { + condition.Value = value + } + if invert, ok := condMap["invert"].(bool); ok { + condition.Invert = invert + } + if passMissingKey, ok := condMap["pass_missing_key"].(bool); ok { + condition.PassMissingKey = passMissingKey + } + operation.Conditions = append(operation.Conditions, condition) + } + } + } + } + + operations = append(operations, operation) + } else { + return nil, false + } + } + return operations, true + } + } + + return nil, false +} + +func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) { + if len(conditions) == 0 { + return true, nil // 没有条件,直接通过 + } + results := make([]bool, len(conditions)) + for i, condition := range conditions { + result, err := checkSingleCondition(jsonStr, condition) + if err != nil { + return false, err + } + results[i] = result + } + + if strings.ToUpper(logic) == "AND" { + for _, result := range results { + if !result { + return false, nil + } + } + return true, nil + } else { + for _, result := range results { + if result { + return true, nil + } + } + return false, nil + } +} + +func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) { + // 处理负数索引 + path := processNegativeIndex(jsonStr, condition.Path) + value := gjson.Get(jsonStr, path) + if !value.Exists() { + if condition.PassMissingKey { + return true, nil + } + return false, nil + } + + // 利用gjson的类型解析 + targetBytes, err := json.Marshal(condition.Value) + if err != nil { + return false, fmt.Errorf("failed to marshal condition value: %v", err) + } + targetValue := gjson.ParseBytes(targetBytes) + + result, err := compareGjsonValues(value, targetValue, strings.ToLower(condition.Mode)) + if err != nil { + return false, fmt.Errorf("comparison failed for path %s: %v", condition.Path, err) + } + + if condition.Invert { + result = !result + } + return result, nil +} + +func processNegativeIndex(jsonStr string, path string) string { + re := regexp.MustCompile(`\.(-\d+)`) + matches := re.FindAllStringSubmatch(path, -1) + + if len(matches) == 0 { + return path + } + + result := path + for _, match := range matches { + negIndex := match[1] + index, _ := strconv.Atoi(negIndex) + + arrayPath := strings.Split(path, negIndex)[0] + if strings.HasSuffix(arrayPath, ".") { + arrayPath = arrayPath[:len(arrayPath)-1] + } + + array := gjson.Get(jsonStr, arrayPath) + if array.IsArray() { + length := len(array.Array()) + actualIndex := length + index + if actualIndex >= 0 && actualIndex < length { + result = strings.Replace(result, match[0], "."+strconv.Itoa(actualIndex), 1) + } + } + } + + return result +} + +// compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式 +func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) { + switch mode { + case "full": + return compareEqual(jsonValue, targetValue) + case "prefix": + return strings.HasPrefix(jsonValue.String(), targetValue.String()), nil + case "suffix": + return strings.HasSuffix(jsonValue.String(), targetValue.String()), nil + case "contains": + return strings.Contains(jsonValue.String(), targetValue.String()), nil + case "gt": + return compareNumeric(jsonValue, targetValue, "gt") + case "gte": + return compareNumeric(jsonValue, targetValue, "gte") + case "lt": + return compareNumeric(jsonValue, targetValue, "lt") + case "lte": + return compareNumeric(jsonValue, targetValue, "lte") + default: + return false, fmt.Errorf("unsupported comparison mode: %s", mode) + } +} + +func compareEqual(jsonValue, targetValue gjson.Result) (bool, error) { + // 对布尔值特殊处理 + if (jsonValue.Type == gjson.True || jsonValue.Type == gjson.False) && + (targetValue.Type == gjson.True || targetValue.Type == gjson.False) { + return jsonValue.Bool() == targetValue.Bool(), nil + } + + // 如果类型不同,报错 + if jsonValue.Type != targetValue.Type { + return false, fmt.Errorf("compare for different types, got %v and %v", jsonValue.Type, targetValue.Type) + } + + switch jsonValue.Type { + case gjson.True, gjson.False: + return jsonValue.Bool() == targetValue.Bool(), nil + case gjson.Number: + return jsonValue.Num == targetValue.Num, nil + case gjson.String: + return jsonValue.String() == targetValue.String(), nil + default: + return jsonValue.String() == targetValue.String(), nil + } +} + +func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, error) { + // 只有数字类型才支持数值比较 + if jsonValue.Type != gjson.Number || targetValue.Type != gjson.Number { + return false, fmt.Errorf("numeric comparison requires both values to be numbers, got %v and %v", jsonValue.Type, targetValue.Type) + } + + jsonNum := jsonValue.Num + targetNum := targetValue.Num + + switch operator { + case "gt": + return jsonNum > targetNum, nil + case "gte": + return jsonNum >= targetNum, nil + case "lt": + return jsonNum < targetNum, nil + case "lte": + return jsonNum <= targetNum, nil + default: + return false, fmt.Errorf("unsupported numeric operator: %s", operator) + } +} + +// applyOperationsLegacy 原参数覆盖方法 +func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) { + reqMap := make(map[string]interface{}) + err := json.Unmarshal(jsonData, &reqMap) + if err != nil { + return nil, err + } + + for key, value := range paramOverride { + reqMap[key] = value + } + + return json.Marshal(reqMap) +} + +func applyOperations(jsonStr string, operations []ParamOperation) (string, error) { + result := jsonStr + for _, op := range operations { + // 检查条件是否满足 + ok, err := checkConditions(result, op.Conditions, op.Logic) + if err != nil { + return "", err + } + if !ok { + continue // 条件不满足,跳过当前操作 + } + // 处理路径中的负数索引 + opPath := processNegativeIndex(result, op.Path) + opFrom := processNegativeIndex(result, op.From) + opTo := processNegativeIndex(result, op.To) + + switch op.Mode { + case "delete": + result, err = sjson.Delete(result, opPath) + case "set": + if op.KeepOrigin && gjson.Get(result, opPath).Exists() { + continue + } + result, err = sjson.Set(result, opPath, op.Value) + case "move": + result, err = moveValue(result, opFrom, opTo) + case "prepend": + result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, true) + case "append": + result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, false) + default: + return "", fmt.Errorf("unknown operation: %s", op.Mode) + } + if err != nil { + return "", fmt.Errorf("operation %s failed: %v", op.Mode, err) + } + } + return result, nil +} + +func moveValue(jsonStr, fromPath, toPath string) (string, error) { + sourceValue := gjson.Get(jsonStr, fromPath) + if !sourceValue.Exists() { + return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath) + } + result, err := sjson.Set(jsonStr, toPath, sourceValue.Value()) + if err != nil { + return "", err + } + return sjson.Delete(result, fromPath) +} + +func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) { + current := gjson.Get(jsonStr, path) + switch { + case current.IsArray(): + return modifyArray(jsonStr, path, value, isPrepend) + case current.Type == gjson.String: + return modifyString(jsonStr, path, value, isPrepend) + case current.Type == gjson.JSON: + return mergeObjects(jsonStr, path, value, keepOrigin) + } + return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type) +} + +func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) { + current := gjson.Get(jsonStr, path) + var newArray []interface{} + // 添加新值 + addValue := func() { + if arr, ok := value.([]interface{}); ok { + newArray = append(newArray, arr...) + } else { + newArray = append(newArray, value) + } + } + // 添加原值 + addOriginal := func() { + current.ForEach(func(_, val gjson.Result) bool { + newArray = append(newArray, val.Value()) + return true + }) + } + if isPrepend { + addValue() + addOriginal() + } else { + addOriginal() + addValue() + } + return sjson.Set(jsonStr, path, newArray) +} + +func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) { + current := gjson.Get(jsonStr, path) + valueStr := fmt.Sprintf("%v", value) + var newStr string + if isPrepend { + newStr = valueStr + current.String() + } else { + newStr = current.String() + valueStr + } + return sjson.Set(jsonStr, path, newStr) +} + +func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) { + current := gjson.Get(jsonStr, path) + var currentMap, newMap map[string]interface{} + + // 解析当前值 + if err := json.Unmarshal([]byte(current.Raw), ¤tMap); err != nil { + return "", err + } + // 解析新值 + switch v := value.(type) { + case map[string]interface{}: + newMap = v + default: + jsonBytes, _ := json.Marshal(v) + if err := json.Unmarshal(jsonBytes, &newMap); err != nil { + return "", err + } + } + // 合并 + result := make(map[string]interface{}) + for k, v := range currentMap { + result[k] = v + } + for k, v := range newMap { + if !keepOrigin || result[k] == nil { + result[k] = v + } + } + return sjson.Set(jsonStr, path, result) +} diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 45fde019c..da572c070 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -1,10 +1,13 @@ package common import ( + "errors" + "fmt" "one-api/common" "one-api/constant" "one-api/dto" relayconstant "one-api/relay/constant" + "one-api/types" "strings" "time" @@ -33,17 +36,6 @@ type ClaudeConvertInfo struct { Done bool } -const ( - RelayFormatOpenAI = "openai" - RelayFormatClaude = "claude" - RelayFormatGemini = "gemini" - RelayFormatOpenAIResponses = "openai_responses" - RelayFormatOpenAIAudio = "openai_audio" - RelayFormatOpenAIImage = "openai_image" - RelayFormatRerank = "rerank" - RelayFormatEmbedding = "embedding" -) - type RerankerInfo struct { Documents []any ReturnDocuments bool @@ -59,9 +51,27 @@ type ResponsesUsageInfo struct { BuiltInTools map[string]*BuildInToolInfo } +type ChannelMeta struct { + ChannelType int + ChannelId int + ChannelIsMultiKey bool + ChannelMultiKeyIndex int + ChannelBaseUrl string + ApiType int + ApiVersion string + ApiKey string + Organization string + ChannelCreateTime int64 + ParamOverride map[string]interface{} + HeadersOverride map[string]interface{} + ChannelSetting dto.ChannelSettings + ChannelOtherSettings dto.ChannelOtherSettings + UpstreamModelName string + IsModelMapped bool + SupportStreamOptions bool // 是否支持流式选项 +} + type RelayInfo struct { - ChannelType int - ChannelId int TokenId int TokenKey string UserId int @@ -72,43 +82,169 @@ type RelayInfo struct { FirstResponseTime time.Time isFirstResponse bool //SendLastReasoningResponse bool - ApiType int - IsStream bool - IsPlayground bool - UsePrice bool - RelayMode int - UpstreamModelName string - OriginModelName string - //RecodeModelName string - RequestURLPath string - ApiVersion string - PromptTokens int - ApiKey string - Organization string - BaseUrl string - SupportStreamOptions bool - ShouldIncludeUsage bool - IsModelMapped bool - ClientWs *websocket.Conn - TargetWs *websocket.Conn - InputAudioFormat string - OutputAudioFormat string - RealtimeTools []dto.RealTimeTool - IsFirstRequest bool - AudioUsage bool - ReasoningEffort string - ChannelSetting dto.ChannelSettings - ParamOverride map[string]interface{} - UserSetting dto.UserSetting - UserEmail string - UserQuota int - RelayFormat string - SendResponseCount int - ChannelCreateTime int64 + IsStream bool + IsGeminiBatchEmbedding bool + IsPlayground bool + UsePrice bool + RelayMode int + OriginModelName string + RequestURLPath string + PromptTokens int + ShouldIncludeUsage bool + DisablePing bool // 是否禁止向下游发送自定义 Ping + ClientWs *websocket.Conn + TargetWs *websocket.Conn + InputAudioFormat string + OutputAudioFormat string + RealtimeTools []dto.RealTimeTool + IsFirstRequest bool + AudioUsage bool + ReasoningEffort string + UserSetting dto.UserSetting + UserEmail string + UserQuota int + RelayFormat types.RelayFormat + SendResponseCount int + FinalPreConsumedQuota int // 最终预消耗的配额 + + PriceData types.PriceData + + Request dto.Request + ThinkingContentInfo *ClaudeConvertInfo *RerankerInfo *ResponsesUsageInfo + *ChannelMeta + *TaskRelayInfo +} + +func (info *RelayInfo) InitChannelMeta(c *gin.Context) { + channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) + paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) + headerOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelHeaderOverride) + apiType, _ := common.ChannelType2APIType(channelType) + channelMeta := &ChannelMeta{ + ChannelType: channelType, + ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId), + ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey), + ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex), + ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl), + ApiType: apiType, + ApiVersion: c.GetString("api_version"), + ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey), + Organization: c.GetString("channel_organization"), + ChannelCreateTime: c.GetInt64("channel_create_time"), + ParamOverride: paramOverride, + HeadersOverride: headerOverride, + UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), + IsModelMapped: false, + SupportStreamOptions: false, + } + + if channelType == constant.ChannelTypeAzure { + channelMeta.ApiVersion = GetAPIVersion(c) + } + if channelType == constant.ChannelTypeVertexAi { + channelMeta.ApiVersion = c.GetString("region") + } + + channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting) + if ok { + channelMeta.ChannelSetting = channelSetting + } + + channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting) + if ok { + channelMeta.ChannelOtherSettings = channelOtherSettings + } + + if streamSupportedChannels[channelMeta.ChannelType] { + channelMeta.SupportStreamOptions = true + } + + info.ChannelMeta = channelMeta + + // reset some fields based on channel meta + // 重置某些字段,例如模型名称等 + if info.Request != nil { + info.Request.SetModelName(info.OriginModelName) + } +} + +func (info *RelayInfo) ToString() string { + if info == nil { + return "RelayInfo" + } + + // Basic info + b := &strings.Builder{} + fmt.Fprintf(b, "RelayInfo{ ") + fmt.Fprintf(b, "RelayFormat: %s, ", info.RelayFormat) + fmt.Fprintf(b, "RelayMode: %d, ", info.RelayMode) + fmt.Fprintf(b, "IsStream: %t, ", info.IsStream) + fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground) + fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath) + fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName) + fmt.Fprintf(b, "PromptTokens: %d, ", info.PromptTokens) + fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage) + fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing) + fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount) + fmt.Fprintf(b, "FinalPreConsumedQuota: %d, ", info.FinalPreConsumedQuota) + + // User & token info (mask secrets) + fmt.Fprintf(b, "User{ Id: %d, Email: %q, Group: %q, UsingGroup: %q, Quota: %d }, ", + info.UserId, common.MaskEmail(info.UserEmail), info.UserGroup, info.UsingGroup, info.UserQuota) + fmt.Fprintf(b, "Token{ Id: %d, Unlimited: %t, Key: ***masked*** }, ", info.TokenId, info.TokenUnlimited) + + // Time info + latencyMs := info.FirstResponseTime.Sub(info.StartTime).Milliseconds() + fmt.Fprintf(b, "Timing{ Start: %s, FirstResponse: %s, LatencyMs: %d }, ", + info.StartTime.Format(time.RFC3339Nano), info.FirstResponseTime.Format(time.RFC3339Nano), latencyMs) + + // Audio / realtime + if info.InputAudioFormat != "" || info.OutputAudioFormat != "" || len(info.RealtimeTools) > 0 || info.AudioUsage { + fmt.Fprintf(b, "Realtime{ AudioUsage: %t, InFmt: %q, OutFmt: %q, Tools: %d }, ", + info.AudioUsage, info.InputAudioFormat, info.OutputAudioFormat, len(info.RealtimeTools)) + } + + // Reasoning + if info.ReasoningEffort != "" { + fmt.Fprintf(b, "ReasoningEffort: %q, ", info.ReasoningEffort) + } + + // Price data (non-sensitive) + if info.PriceData.UsePrice { + fmt.Fprintf(b, "PriceData{ %s }, ", info.PriceData.ToSetting()) + } + + // Channel metadata (mask ApiKey) + if info.ChannelMeta != nil { + cm := info.ChannelMeta + fmt.Fprintf(b, "ChannelMeta{ Type: %d, Id: %d, IsMultiKey: %t, MultiKeyIndex: %d, BaseURL: %q, ApiType: %d, ApiVersion: %q, Organization: %q, CreateTime: %d, UpstreamModelName: %q, IsModelMapped: %t, SupportStreamOptions: %t, ApiKey: ***masked*** }, ", + cm.ChannelType, cm.ChannelId, cm.ChannelIsMultiKey, cm.ChannelMultiKeyIndex, cm.ChannelBaseUrl, cm.ApiType, cm.ApiVersion, cm.Organization, cm.ChannelCreateTime, cm.UpstreamModelName, cm.IsModelMapped, cm.SupportStreamOptions) + } + + // Responses usage info (non-sensitive) + if info.ResponsesUsageInfo != nil && len(info.ResponsesUsageInfo.BuiltInTools) > 0 { + fmt.Fprintf(b, "ResponsesTools{ ") + first := true + for name, tool := range info.ResponsesUsageInfo.BuiltInTools { + if !first { + fmt.Fprintf(b, ", ") + } + first = false + if tool != nil { + fmt.Fprintf(b, "%s: calls=%d", name, tool.CallCount) + } else { + fmt.Fprintf(b, "%s: calls=0", name) + } + } + fmt.Fprintf(b, " }, ") + } + + fmt.Fprintf(b, "}") + return b.String() } // 定义支持流式选项的通道类型 @@ -127,7 +263,8 @@ var streamSupportedChannels = map[int]bool{ } func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { - info := GenRelayInfo(c) + info := genBaseRelayInfo(c, nil) + info.RelayFormat = types.RelayFormatOpenAIRealtime info.ClientWs = ws info.InputAudioFormat = "pcm16" info.OutputAudioFormat = "pcm16" @@ -135,9 +272,9 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { return info } -func GenRelayInfoClaude(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatClaude +func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatClaude info.ShouldIncludeUsage = false info.ClaudeConvertInfo = &ClaudeConvertInfo{ LastMessagesType: LastMessageTypeNone, @@ -145,41 +282,39 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo { return info } -func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { - info := GenRelayInfo(c) +func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo { + info := genBaseRelayInfo(c, request) info.RelayMode = relayconstant.RelayModeRerank - info.RelayFormat = RelayFormatRerank + info.RelayFormat = types.RelayFormatRerank info.RerankerInfo = &RerankerInfo{ - Documents: req.Documents, - ReturnDocuments: req.GetReturnDocuments(), + Documents: request.Documents, + ReturnDocuments: request.GetReturnDocuments(), } return info } -func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatOpenAIAudio +func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatOpenAIAudio return info } -func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatEmbedding +func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatEmbedding return info } -func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo { - info := GenRelayInfo(c) +func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo { + info := genBaseRelayInfo(c, request) info.RelayMode = relayconstant.RelayModeResponses - info.RelayFormat = RelayFormatOpenAIResponses - - info.SupportStreamOptions = false + info.RelayFormat = types.RelayFormatOpenAIResponses info.ResponsesUsageInfo = &ResponsesUsageInfo{ BuiltInTools: make(map[string]*BuildInToolInfo), } - if len(req.Tools) > 0 { - for _, tool := range req.Tools { + if len(request.Tools) > 0 { + for _, tool := range request.GetToolsMap() { toolType := common.Interface2String(tool["type"]) info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{ ToolName: toolType, @@ -195,93 +330,87 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel } } } - info.IsStream = req.Stream return info } -func GenRelayInfoGemini(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatGemini +func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatGemini info.ShouldIncludeUsage = false + return info } -func GenRelayInfoImage(c *gin.Context) *RelayInfo { - info := GenRelayInfo(c) - info.RelayFormat = RelayFormatOpenAIImage +func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatOpenAIImage return info } -func GenRelayInfo(c *gin.Context) *RelayInfo { - channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) - channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) - paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) +func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo { + info := genBaseRelayInfo(c, request) + info.RelayFormat = types.RelayFormatOpenAI + return info +} + +func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { + + //channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) + //channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) + //paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) - tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId) - tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey) - userId := common.GetContextKeyInt(c, constant.ContextKeyUserId) - tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited) startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime) + if startTime.IsZero() { + startTime = time.Now() + } + + isStream := false + + if request != nil { + isStream = request.IsStream(c) + } + // firstResponseTime = time.Now() - 1 second - apiType, _ := common.ChannelType2APIType(channelType) - info := &RelayInfo{ - UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota), - UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), - isFirstResponse: true, - RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), - BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl), - RequestURLPath: c.Request.URL.String(), - ChannelType: channelType, - ChannelId: channelId, - TokenId: tokenId, - TokenKey: tokenKey, - UserId: userId, - UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup), - UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup), - TokenUnlimited: tokenUnlimited, + Request: request, + + UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId), + UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup), + UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup), + UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota), + UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), + + OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), + PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens), + + TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId), + TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey), + TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited), + + isFirstResponse: true, + RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), + RequestURLPath: c.Request.URL.String(), + IsStream: isStream, + StartTime: startTime, FirstResponseTime: startTime.Add(-time.Second), - OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), - UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), - //RecodeModelName: c.GetString("original_model"), - IsModelMapped: false, - ApiType: apiType, - ApiVersion: c.GetString("api_version"), - ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey), - Organization: c.GetString("channel_organization"), - - ChannelCreateTime: c.GetInt64("channel_create_time"), - ParamOverride: paramOverride, - RelayFormat: RelayFormatOpenAI, ThinkingContentInfo: ThinkingContentInfo{ IsFirstThinkingContent: true, SendLastThinkingContent: false, }, } + + if info.RelayMode == relayconstant.RelayModeUnknown { + info.RelayMode = c.GetInt("relay_mode") + } + if strings.HasPrefix(c.Request.URL.Path, "/pg") { info.IsPlayground = true info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg") info.RequestURLPath = "/v1" + info.RequestURLPath } - if info.BaseUrl == "" { - info.BaseUrl = constant.ChannelBaseURLs[channelType] - } - if info.ChannelType == constant.ChannelTypeAzure { - info.ApiVersion = GetAPIVersion(c) - } - if info.ChannelType == constant.ChannelTypeVertexAi { - info.ApiVersion = c.GetString("region") - } - if streamSupportedChannels[info.ChannelType] { - info.SupportStreamOptions = true - } - channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting) - if ok { - info.ChannelSetting = channelSetting - } userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting) if ok { info.UserSetting = userSetting @@ -290,12 +419,43 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { return info } -func (info *RelayInfo) SetPromptTokens(promptTokens int) { - info.PromptTokens = promptTokens +func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) { + switch relayFormat { + case types.RelayFormatOpenAI: + return GenRelayInfoOpenAI(c, request), nil + case types.RelayFormatOpenAIAudio: + return GenRelayInfoOpenAIAudio(c, request), nil + case types.RelayFormatOpenAIImage: + return GenRelayInfoImage(c, request), nil + case types.RelayFormatOpenAIRealtime: + return GenRelayInfoWs(c, ws), nil + case types.RelayFormatClaude: + return GenRelayInfoClaude(c, request), nil + case types.RelayFormatRerank: + if request, ok := request.(*dto.RerankRequest); ok { + return GenRelayInfoRerank(c, request), nil + } + return nil, errors.New("request is not a RerankRequest") + case types.RelayFormatGemini: + return GenRelayInfoGemini(c, request), nil + case types.RelayFormatEmbedding: + return GenRelayInfoEmbedding(c, request), nil + case types.RelayFormatOpenAIResponses: + if request, ok := request.(*dto.OpenAIResponsesRequest); ok { + return GenRelayInfoResponses(c, request), nil + } + return nil, errors.New("request is not a OpenAIResponsesRequest") + case types.RelayFormatTask: + return genBaseRelayInfo(c, nil), nil + case types.RelayFormatMjProxy: + return genBaseRelayInfo(c, nil), nil + default: + return nil, errors.New("invalid relay format") + } } -func (info *RelayInfo) SetIsStream(isStream bool) { - info.IsStream = isStream +func (info *RelayInfo) SetPromptTokens(promptTokens int) { + info.PromptTokens = promptTokens } func (info *RelayInfo) SetFirstResponseTime() { @@ -310,20 +470,12 @@ func (info *RelayInfo) HasSendResponse() bool { } type TaskRelayInfo struct { - *RelayInfo Action string OriginTaskID string ConsumeQuota bool } -func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo { - info := &TaskRelayInfo{ - RelayInfo: GenRelayInfo(c), - } - return info -} - type TaskSubmitReq struct { Prompt string `json:"prompt"` Model string `json:"model,omitempty"` diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 290865854..3d5efcb6d 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -2,12 +2,10 @@ package common import ( "fmt" - "github.com/gin-gonic/gin" - _ "image/gif" - _ "image/jpeg" - _ "image/png" "one-api/constant" "strings" + + "github.com/gin-gonic/gin" ) func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go index ce823b3ab..05dbfa6d7 100644 --- a/relay/common_handler/rerank.go +++ b/relay/common_handler/rerank.go @@ -8,6 +8,7 @@ import ( "one-api/dto" "one-api/relay/channel/xinference" relaycommon "one-api/relay/common" + "one-api/service" "one-api/types" "github.com/gin-gonic/gin" @@ -16,9 +17,9 @@ import ( func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) + service.CloseResponseBodyGracefully(resp) if common.DebugEnabled { println("reranker response body: ", string(responseBody)) } @@ -27,7 +28,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo var xinRerankResponse xinference.XinRerankResponse err = common.Unmarshal(responseBody, &xinRerankResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results)) for i, result := range xinRerankResponse.Results { @@ -62,7 +63,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo } else { err = common.Unmarshal(responseBody, &jinaResp) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens } diff --git a/relay/relay-text.go b/relay/compatible_handler.go similarity index 51% rename from relay/relay-text.go rename to relay/compatible_handler.go index 603270741..a3c6ace6e 100644 --- a/relay/relay-text.go +++ b/relay/compatible_handler.go @@ -2,213 +2,152 @@ package relay import ( "bytes" - "encoding/json" - "errors" "fmt" "io" - "math" "net/http" "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/setting/operation_setting" "one-api/types" "strings" "time" - "github.com/bytedance/gopkg/util/gopool" "github.com/shopspring/decimal" "github.com/gin-gonic/gin" ) -func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) { - textRequest := &dto.GeneralOpenAIRequest{} - err := common.UnmarshalBodyReusable(c, textRequest) +func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) + + textReq, ok := info.Request.(*dto.GeneralOpenAIRequest) + if !ok { + return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) + } + + request, err := common.DeepCopy(textReq) if err != nil { - return nil, err - } - if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" { - textRequest.Model = "text-moderation-latest" - } - if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" { - textRequest.Model = c.Param("model") + return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } - if textRequest.MaxTokens > math.MaxInt32/2 { - return nil, errors.New("max_tokens is invalid") + if request.WebSearchOptions != nil { + c.Set("chat_completion_web_search_context_size", request.WebSearchOptions.SearchContextSize) } - if textRequest.Model == "" { - return nil, errors.New("model is required") - } - if textRequest.WebSearchOptions != nil { - if textRequest.WebSearchOptions.SearchContextSize != "" { - validSizes := map[string]bool{ - "high": true, - "medium": true, - "low": true, - } - if !validSizes[textRequest.WebSearchOptions.SearchContextSize] { - return nil, errors.New("invalid search_context_size, must be one of: high, medium, low") - } - } else { - textRequest.WebSearchOptions.SearchContextSize = "medium" - } - } - switch relayInfo.RelayMode { - case relayconstant.RelayModeCompletions: - if textRequest.Prompt == "" { - return nil, errors.New("field prompt is required") - } - case relayconstant.RelayModeChatCompletions: - if len(textRequest.Messages) == 0 { - return nil, errors.New("field messages is required") - } - case relayconstant.RelayModeEmbeddings: - case relayconstant.RelayModeModerations: - if textRequest.Input == nil || textRequest.Input == "" { - return nil, errors.New("field input is required") - } - case relayconstant.RelayModeEdits: - if textRequest.Instruction == "" { - return nil, errors.New("field instruction is required") - } - } - relayInfo.IsStream = textRequest.Stream - return textRequest, nil -} - -func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - - relayInfo := relaycommon.GenRelayInfo(c) - - // get & validate textRequest 获取并验证文本请求 - textRequest, err := getAndValidateTextRequest(c, relayInfo) + err = helper.ModelMappedHelper(c, info, request) if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - if textRequest.WebSearchOptions != nil { - c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize) - } - - if setting.ShouldCheckPromptSensitive() { - words, err := checkRequestSensitive(textRequest, relayInfo) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) - } - } - - err = helper.ModelMappedHelper(c, relayInfo, textRequest) - if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) - } - - // 获取 promptTokens,如果上下文中已经存在,则直接使用 - var promptTokens int - if value, exists := c.Get("prompt_tokens"); exists { - promptTokens = value.(int) - relayInfo.PromptTokens = promptTokens - } else { - promptTokens, err = getPromptTokens(textRequest, relayInfo) - // count messages token error 计算promptTokens错误 - if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed) - } - c.Set("prompt_tokens", promptTokens) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens)))) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) - } - - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newApiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newApiErr != nil { - return newApiErr - } - defer func() { - if newApiErr != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - includeUsage := false + includeUsage := true // 判断用户是否需要返回使用情况 - if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage { - includeUsage = true + if request.StreamOptions != nil { + includeUsage = request.StreamOptions.IncludeUsage } // 如果不支持StreamOptions,将StreamOptions设置为nil - if !relayInfo.SupportStreamOptions || !textRequest.Stream { - textRequest.StreamOptions = nil + if !info.SupportStreamOptions || !request.Stream { + request.StreamOptions = nil } else { // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions if constant.ForceStreamOption { - textRequest.StreamOptions = &dto.StreamOptions{ + request.StreamOptions = &dto.StreamOptions{ IncludeUsage: true, } } } - if includeUsage { - relayInfo.ShouldIncludeUsage = true - } + info.ShouldIncludeUsage = includeUsage - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader - if model_setting.GetGlobalSettings().PassThroughRequestEnabled { + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) + } + if common.DebugEnabled { + println("requestBody: ", string(body)) } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest) + convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - jsonData, err := json.Marshal(convertedRequest) + + if info.ChannelSetting.SystemPrompt != "" { + // 如果有系统提示,则将其添加到请求中 + request := convertedRequest.(*dto.GeneralOpenAIRequest) + containSystemPrompt := false + for _, message := range request.Messages { + if message.Role == request.GetSystemRoleName() { + containSystemPrompt = true + break + } + } + if !containSystemPrompt { + // 如果没有系统提示,则添加系统提示 + systemMessage := dto.Message{ + Role: request.GetSystemRoleName(), + Content: info.ChannelSetting.SystemPrompt, + } + request.Messages = append([]dto.Message{systemMessage}, request.Messages...) + } else if info.ChannelSetting.SystemPromptOverride { + common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) + // 如果有系统提示,且允许覆盖,则拼接到前面 + for i, message := range request.Messages { + if message.Role == request.GetSystemRoleName() { + if message.IsStringContent() { + request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent()) + } else { + contents := message.ParseContent() + contents = append([]dto.MediaContent{ + { + Type: dto.ContentTypeText, + Text: info.ChannelSetting.SystemPrompt, + }, + }, contents...) + request.Messages[i].Content = contents + } + break + } + } + } + } + + jsonData, err := common.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeJsonMarshalFailed, types.ErrOptionWithSkipRetry()) } // apply param override - if len(relayInfo.ParamOverride) > 0 { - reqMap := make(map[string]interface{}) - _ = common.Unmarshal(jsonData, &reqMap) - for key, value := range relayInfo.ParamOverride { - reqMap[key] = value - } - jsonData, err = common.Marshal(reqMap) + if len(info.ParamOverride) > 0 { + jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } - if common.DebugEnabled { - println("requestBody: ", string(jsonData)) - } + logger.LogDebug(c, fmt.Sprintf("text request body: %s", string(jsonData))) + requestBody = bytes.NewBuffer(jsonData) } var httpResp *http.Response - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) - + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -217,125 +156,31 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newApiErr = service.RelayErrorHandler(httpResp, false) + newApiErr := service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newApiErr, statusCodeMappingStr) return newApiErr } } - usage, newApiErr := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newApiErr := adaptor.DoResponse(c, httpResp, info) if newApiErr != nil { // reset status code 重置状态码 service.ResetStatusCode(newApiErr, statusCodeMappingStr) return newApiErr } - if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") { - service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") { + service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") } else { - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") } return nil } -func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) { - var promptTokens int - var err error - switch info.RelayMode { - case relayconstant.RelayModeChatCompletions: - promptTokens, err = service.CountTokenChatRequest(info, *textRequest) - case relayconstant.RelayModeCompletions: - promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model) - case relayconstant.RelayModeModerations: - promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model) - case relayconstant.RelayModeEmbeddings: - promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model) - default: - err = errors.New("unknown relay mode") - promptTokens = 0 - } - info.PromptTokens = promptTokens - return promptTokens, err -} - -func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) { - var err error - var words []string - switch info.RelayMode { - case relayconstant.RelayModeChatCompletions: - words, err = service.CheckSensitiveMessages(textRequest.Messages) - case relayconstant.RelayModeCompletions: - words, err = service.CheckSensitiveInput(textRequest.Prompt) - case relayconstant.RelayModeModerations: - words, err = service.CheckSensitiveInput(textRequest.Input) - case relayconstant.RelayModeEmbeddings: - words, err = service.CheckSensitiveInput(textRequest.Input) - } - return words, err -} - -// 预扣费并返回用户剩余配额 -func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) { - userQuota, err := model.GetUserQuota(relayInfo.UserId, false) - if err != nil { - return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError) - } - if userQuota <= 0 { - return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden) - } - if userQuota-preConsumedQuota < 0 { - return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden) - } - relayInfo.UserQuota = userQuota - if userQuota > 100*preConsumedQuota { - // 用户额度充足,判断令牌额度是否充足 - if !relayInfo.TokenUnlimited { - // 非无限令牌,判断令牌额度是否充足 - tokenQuota := c.GetInt("token_quota") - if tokenQuota > 100*preConsumedQuota { - // 令牌额度充足,信任令牌 - preConsumedQuota = 0 - common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) - } - } else { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota))) - } - } - - if preConsumedQuota > 0 { - err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) - if err != nil { - return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden) - } - err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) - if err != nil { - return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError) - } - } - return preConsumedQuota, userQuota, nil -} - -func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) { - if preConsumedQuota != 0 { - gopool.Go(func() { - relayInfoCopy := *relayInfo - - err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) - if err != nil { - common.SysError("error return pre-consumed quota: " + err.Error()) - } - }) - } -} - -func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { +func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) { if usage == nil { usage = &dto.Usage{ PromptTokens: relayInfo.PromptTokens, @@ -353,12 +198,12 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") - completionRatio := priceData.CompletionRatio - cacheRatio := priceData.CacheRatio - imageRatio := priceData.ImageRatio - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice + completionRatio := relayInfo.PriceData.CompletionRatio + cacheRatio := relayInfo.PriceData.CacheRatio + imageRatio := relayInfo.PriceData.ImageRatio + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice // Convert values to decimal for precise calculation dPromptTokens := decimal.NewFromInt(int64(promptTokens)) @@ -431,7 +276,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, var audioInputQuota decimal.Decimal var audioInputPrice float64 - if !priceData.UsePrice { + if !relayInfo.PriceData.UsePrice { baseTokens := dPromptTokens // 减去 cached tokens var cachedTokensWithRatio decimal.Decimal @@ -469,21 +314,27 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } else { quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio) } + var dGeminiImageOutputQuota decimal.Decimal + var imageOutputPrice float64 + if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") { + imageOutputPrice = operation_setting.GetGeminiImageOutputPricePerMillionTokens(modelName) + if imageOutputPrice > 0 { + dImageOutputTokens := decimal.NewFromInt(int64(ctx.GetInt("gemini_image_tokens"))) + dGeminiImageOutputQuota = decimal.NewFromFloat(imageOutputPrice).Div(decimal.NewFromInt(1000000)).Mul(dImageOutputTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit) + } + } // 添加 responses tools call 调用的配额 quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota) quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota) // 添加 audio input 独立计费 quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota) + // 添加 Gemini image output 计费 + quotaCalculateDecimal = quotaCalculateDecimal.Add(dGeminiImageOutputQuota) quota := int(quotaCalculateDecimal.Round(0).IntPart()) totalTokens := promptTokens + completionTokens var logContent string - if !priceData.UsePrice { - logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio) - } else { - logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) - } // record all the consume log even if quota is 0 if totalTokens == 0 { @@ -491,18 +342,38 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) } else { + if !ratio.IsZero() && quota == 0 { + quota = 1 + } model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - preConsumedQuota + quotaDelta := quota - relayInfo.FinalPreConsumedQuota + + //logger.LogInfo(ctx, fmt.Sprintf("request quota delta: %s", logger.FormatQuota(quotaDelta))) + + if quotaDelta > 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } else if quotaDelta < 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(-quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } + if quotaDelta != 0 { - err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + logger.LogError(ctx, "error consuming token remain quota: "+err.Error()) } } @@ -518,7 +389,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, if extraContent != "" { logContent += ", " + extraContent } - other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) if imageTokens != 0 { other["image"] = true other["image_ratio"] = imageRatio @@ -553,6 +424,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other["audio_input_token_count"] = audioTokens other["audio_input_price"] = audioInputPrice } + if !dGeminiImageOutputQuota.IsZero() { + other["image_output_token_count"] = ctx.GetInt("gemini_image_tokens") + other["image_output_price"] = imageOutputPrice + } model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: promptTokens, @@ -562,7 +437,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index b51957521..85a1b9c5f 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -40,11 +40,8 @@ const ( RelayModeSunoFetchByID RelayModeSunoSubmit - RelayModeKlingFetchByID - RelayModeKlingSubmit - - RelayModeJimengFetchByID - RelayModeJimengSubmit + RelayModeVideoFetchByID + RelayModeVideoSubmit RelayModeRerank @@ -87,6 +84,8 @@ func Path2RelayMode(path string) int { relayMode = RelayModeRealtime } else if strings.HasPrefix(path, "/v1beta/models") || strings.HasPrefix(path, "/v1/models") { relayMode = RelayModeGemini + } else if strings.HasPrefix(path, "/mj") { + relayMode = Path2RelayModeMidjourney(path) } return relayMode } @@ -145,23 +144,3 @@ func Path2RelaySuno(method, path string) int { } return relayMode } - -func Path2RelayKling(method, path string) int { - relayMode := RelayModeUnknown - if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") { - relayMode = RelayModeKlingSubmit - } else if method == http.MethodGet && strings.Contains(path, "/video/generations/") { - relayMode = RelayModeKlingFetchByID - } - return relayMode -} - -func Path2RelayJimeng(method, path string) int { - relayMode := RelayModeUnknown - if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") { - relayMode = RelayModeJimengSubmit - } else if method == http.MethodGet && strings.Contains(path, "/video/generations/") { - relayMode = RelayModeJimengFetchByID - } - return relayMode -} diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index be11bb2b8..26dcf9719 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -8,7 +8,6 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" "one-api/types" @@ -16,80 +15,41 @@ import ( "github.com/gin-gonic/gin" ) -func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { - token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model) - return token -} +func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) -func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error { - if embeddingRequest.Input == nil { - return fmt.Errorf("input is empty") + embeddingReq, ok := info.Request.(*dto.EmbeddingRequest) + if !ok { + return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { - embeddingRequest.Model = "omni-moderation-latest" - } - if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { - embeddingRequest.Model = c.Param("model") - } - return nil -} -func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoEmbedding(c) - - var embeddingRequest *dto.EmbeddingRequest - err := common.UnmarshalBodyReusable(c, &embeddingRequest) + request, err := common.DeepCopy(embeddingReq) if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(fmt.Errorf("failed to copy request to EmbeddingRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } - err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest) + err = helper.ModelMappedHelper(c, info, request) if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest) - if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) - } - - promptToken := getEmbeddingPromptToken(*embeddingRequest) - relayInfo.PromptTokens = promptToken - - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) - } - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) - - convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest) + adaptor.Init(info) + convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *request) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } jsonData, err := json.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } requestBody := bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -105,12 +65,12 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index e448b4913..460fd2f58 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -2,17 +2,16 @@ package relay import ( "bytes" - "encoding/json" - "errors" "fmt" + "io" "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/relay/channel/gemini" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/types" "strings" @@ -20,67 +19,13 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) { - request := &gemini.GeminiChatRequest{} - err := common.UnmarshalBodyReusable(c, request) - if err != nil { - return nil, err - } - if len(request.Contents) == 0 { - return nil, errors.New("contents is required") - } - return request, nil -} - -// 流模式 -// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx -func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) { - if c.Query("alt") == "sse" { - relayInfo.IsStream = true - } - - // if strings.Contains(c.Request.URL.Path, "streamGenerateContent") { - // relayInfo.IsStream = true - // } -} - -func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) { - var inputTexts []string - for _, content := range textRequest.Contents { - for _, part := range content.Parts { - if part.Text != "" { - inputTexts = append(inputTexts, part.Text) - } - } - } - if len(inputTexts) == 0 { - return nil, nil - } - - sensitiveWords, err := service.CheckSensitiveInput(inputTexts) - return sensitiveWords, err -} - -func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int { - // 计算输入 token 数量 - var inputTexts []string - for _, content := range req.Contents { - for _, part := range content.Parts { - if part.Text != "" { - inputTexts = append(inputTexts, part.Text) - } - } - } - - inputText := strings.Join(inputTexts, "\n") - inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName) - info.PromptTokens = inputTokens - return inputTokens -} - -func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool { +func isNoThinkingRequest(req *dto.GeminiChatRequest) bool { if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil { - return *req.GenerationConfig.ThinkingConfig.ThinkingBudget <= 0 + configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget + if configBudget != nil && *configBudget == 0 { + // 如果思考预算为 0,则认为是非思考请求 + return true + } } return false } @@ -105,108 +50,99 @@ func trimModelThinking(modelName string) string { return modelName } -func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - req, err := getAndValidateGeminiRequest(c) - if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) +func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) + + geminiReq, ok := info.Request.(*dto.GeminiChatRequest) + if !ok { + return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - relayInfo := relaycommon.GenRelayInfoGemini(c) - - // 检查 Gemini 流式模式 - checkGeminiStreamMode(c, relayInfo) - - if setting.ShouldCheckPromptSensitive() { - sensitiveWords, err := checkGeminiInputSensitive(req) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) - } + request, err := common.DeepCopy(geminiReq) + if err != nil { + return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } // model mapped 模型映射 - err = helper.ModelMappedHelper(c, relayInfo, req) + err = helper.ModelMappedHelper(c, info, request) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) - } - - if value, exists := c.Get("prompt_tokens"); exists { - promptTokens := value.(int) - relayInfo.SetPromptTokens(promptTokens) - } else { - promptTokens := getGeminiInputTokens(req, relayInfo) - c.Set("prompt_tokens", promptTokens) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { - if isNoThinkingRequest(req) { + if isNoThinkingRequest(request) { // check is thinking - if !strings.Contains(relayInfo.OriginModelName, "-nothinking") { + if !strings.Contains(info.OriginModelName, "-nothinking") { // try to get no thinking model price - noThinkingModelName := relayInfo.OriginModelName + "-nothinking" + noThinkingModelName := info.OriginModelName + "-nothinking" containPrice := helper.ContainPriceOrRatio(noThinkingModelName) if containPrice { - relayInfo.OriginModelName = noThinkingModelName - relayInfo.UpstreamModelName = noThinkingModelName + info.OriginModelName = noThinkingModelName + info.UpstreamModelName = noThinkingModelName } } } - if req.GenerationConfig.ThinkingConfig == nil { - gemini.ThinkingAdaptor(req, relayInfo) + if request.GenerationConfig.ThinkingConfig == nil { + gemini.ThinkingAdaptor(request, info) } } - priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens)) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) - } - - // pre consume quota - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) // Clean up empty system instruction - if req.SystemInstructions != nil { + if request.SystemInstructions != nil { hasContent := false - for _, part := range req.SystemInstructions.Parts { + for _, part := range request.SystemInstructions.Parts { if part.Text != "" { hasContent = true break } } if !hasContent { - req.SystemInstructions = nil + request.SystemInstructions = nil } } - requestBody, err := json.Marshal(req) - if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + var requestBody io.Reader + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { + body, err := common.GetRequestBody(c) + if err != nil { + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) + } + requestBody = bytes.NewReader(body) + } else { + // 使用 ConvertGeminiRequest 转换请求格式 + convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + jsonData, err := common.Marshal(convertedRequest) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + + // apply param override + if len(info.ParamOverride) > 0 { + jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + if err != nil { + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + } + } + + logger.LogDebug(c, "Gemini request body: "+string(jsonData)) + + requestBody = bytes.NewReader(jsonData) } - if common.DebugEnabled { - println("Gemini request body: %s", string(requestBody)) - } - - resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody)) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { - common.LogError(c, "Do gemini request failed: "+err.Error()) - return types.NewError(err, types.ErrorCodeDoRequestFailed) + logger.LogError(c, "Do gemini request failed: "+err.Error()) + return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } statusCodeMappingStr := c.GetString("status_code_mapping") @@ -214,7 +150,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { var httpResp *http.Response if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 @@ -223,12 +159,108 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo) + usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info) if openaiErr != nil { service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") + return nil +} + +func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) + + isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents") + info.IsGeminiBatchEmbedding = isBatch + + var req dto.Request + var err error + var inputTexts []string + + if isBatch { + batchRequest := &dto.GeminiBatchEmbeddingRequest{} + err = common.UnmarshalBodyReusable(c, batchRequest) + if err != nil { + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + req = batchRequest + for _, r := range batchRequest.Requests { + for _, part := range r.Content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + } + } + } else { + singleRequest := &dto.GeminiEmbeddingRequest{} + err = common.UnmarshalBodyReusable(c, singleRequest) + if err != nil { + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + req = singleRequest + for _, part := range singleRequest.Content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + } + } + + err = helper.ModelMappedHelper(c, info, req) + if err != nil { + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) + } + + adaptor := GetAdaptor(info.ApiType) + if adaptor == nil { + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + } + adaptor.Init(info) + + var requestBody io.Reader + jsonData, err := common.Marshal(req) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + + // apply param override + if len(info.ParamOverride) > 0 { + reqMap := make(map[string]interface{}) + _ = common.Unmarshal(jsonData, &reqMap) + for key, value := range info.ParamOverride { + reqMap[key] = value + } + jsonData, err = common.Marshal(reqMap) + if err != nil { + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + } + } + requestBody = bytes.NewReader(jsonData) + + resp, err := adaptor.DoRequest(c, info, requestBody) + if err != nil { + logger.LogError(c, "Do gemini request failed: "+err.Error()) + return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) + } + + statusCodeMappingStr := c.GetString("status_code_mapping") + var httpResp *http.Response + if resp != nil { + httpResp = resp.(*http.Response) + if httpResp.StatusCode != http.StatusOK { + newAPIError = service.RelayErrorHandler(httpResp, false) + service.ResetStatusCode(newAPIError, statusCodeMappingStr) + return newAPIError + } + } + + usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info) + if openaiErr != nil { + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } + + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/helper/common.go b/relay/helper/common.go index 5d23b5123..381147ae5 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -1,85 +1,80 @@ package helper import ( - "encoding/json" "errors" "fmt" "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/types" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" ) +func FlushWriter(c *gin.Context) error { + if c.Writer == nil { + return nil + } + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + return nil + } + return errors.New("streaming error: flusher not found") +} + func SetEventStreamHeaders(c *gin.Context) { // 检查是否已经设置过头部 if _, exists := c.Get("event_stream_headers_set"); exists { return } + // 设置标志,表示头部已经设置过 + c.Set("event_stream_headers_set", true) + c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Connection", "keep-alive") c.Writer.Header().Set("Transfer-Encoding", "chunked") c.Writer.Header().Set("X-Accel-Buffering", "no") - - // 设置标志,表示头部已经设置过 - c.Set("event_stream_headers_set", true) } func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error { - jsonData, err := json.Marshal(resp) + jsonData, err := common.Marshal(resp) if err != nil { common.SysError("error marshalling stream response: " + err.Error()) } else { c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)}) } - if flusher, ok := c.Writer.(http.Flusher); ok { - flusher.Flush() - } else { - return errors.New("streaming error: flusher not found") - } + _ = FlushWriter(c) return nil } func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) { c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)}) - if flusher, ok := c.Writer.(http.Flusher); ok { - flusher.Flush() - } + _ = FlushWriter(c) } func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) { c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)}) - if flusher, ok := c.Writer.(http.Flusher); ok { - flusher.Flush() - } + _ = FlushWriter(c) } func StringData(c *gin.Context, str string) error { //str = strings.TrimPrefix(str, "data: ") //str = strings.TrimSuffix(str, "\r") c.Render(-1, common.CustomEvent{Data: "data: " + str}) - if flusher, ok := c.Writer.(http.Flusher); ok { - flusher.Flush() - } else { - return errors.New("streaming error: flusher not found") - } + _ = FlushWriter(c) return nil } func PingData(c *gin.Context) error { c.Writer.Write([]byte(": PING\n\n")) - if flusher, ok := c.Writer.(http.Flusher); ok { - flusher.Flush() - } else { - return errors.New("streaming error: flusher not found") - } + _ = FlushWriter(c) return nil } @@ -100,7 +95,7 @@ func Done(c *gin.Context) { func WssString(c *gin.Context, ws *websocket.Conn, str string) error { if ws == nil { - common.LogError(c, "websocket connection is nil") + logger.LogError(c, "websocket connection is nil") return errors.New("websocket connection is nil") } //common.LogInfo(c, fmt.Sprintf("sending message: %s", str)) @@ -108,12 +103,12 @@ func WssString(c *gin.Context, ws *websocket.Conn, str string) error { } func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { - jsonData, err := json.Marshal(object) + jsonData, err := common.Marshal(object) if err != nil { return fmt.Errorf("error marshalling object: %w", err) } if ws == nil { - common.LogError(c, "websocket connection is nil") + logger.LogError(c, "websocket connection is nil") return errors.New("websocket connection is nil") } //common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData)) @@ -121,6 +116,9 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { } func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) { + if ws == nil { + return + } errorObj := &dto.RealtimeEvent{ Type: "error", EventId: GetLocalRealtimeID(c), @@ -139,6 +137,24 @@ func GetLocalRealtimeID(c *gin.Context) string { return fmt.Sprintf("evt_%s", logID) } +func GenerateStartEmptyResponse(id string, createAt int64, model string, systemFingerprint *string) *dto.ChatCompletionsStreamResponse { + return &dto.ChatCompletionsStreamResponse{ + Id: id, + Object: "chat.completion.chunk", + Created: createAt, + Model: model, + SystemFingerprint: systemFingerprint, + Choices: []dto.ChatCompletionsStreamResponseChoice{ + { + Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ + Role: "assistant", + Content: common.GetPointer(""), + }, + }, + }, + } +} + func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse { return &dto.ChatCompletionsStreamResponse{ Id: id, diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go index c17351497..5b64cd8b3 100644 --- a/relay/helper/model_mapped.go +++ b/relay/helper/model_mapped.go @@ -4,14 +4,12 @@ import ( "encoding/json" "errors" "fmt" - common2 "one-api/common" + "github.com/gin-gonic/gin" "one-api/dto" "one-api/relay/common" - - "github.com/gin-gonic/gin" ) -func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error { +func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request dto.Request) error { // map model name modelMapping := c.GetString("model_mapping") if modelMapping != "" && modelMapping != "{}" { @@ -53,40 +51,7 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) erro } } if request != nil { - switch info.RelayFormat { - case common.RelayFormatGemini: - // Gemini 模型映射 - case common.RelayFormatClaude: - if claudeRequest, ok := request.(*dto.ClaudeRequest); ok { - claudeRequest.Model = info.UpstreamModelName - } - case common.RelayFormatOpenAIResponses: - if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok { - openAIResponsesRequest.Model = info.UpstreamModelName - } - case common.RelayFormatOpenAIAudio: - if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok { - openAIAudioRequest.Model = info.UpstreamModelName - } - case common.RelayFormatOpenAIImage: - if imageRequest, ok := request.(*dto.ImageRequest); ok { - imageRequest.Model = info.UpstreamModelName - } - case common.RelayFormatRerank: - if rerankRequest, ok := request.(*dto.RerankRequest); ok { - rerankRequest.Model = info.UpstreamModelName - } - case common.RelayFormatEmbedding: - if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok { - embeddingRequest.Model = info.UpstreamModelName - } - default: - if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok { - openAIRequest.Model = info.UpstreamModelName - } else { - common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request)) - } - } + request.SetModelName(info.UpstreamModelName) } return nil } diff --git a/relay/helper/price.go b/relay/helper/price.go index e80578e57..fdc5b66d8 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -5,35 +5,14 @@ import ( "one-api/common" relaycommon "one-api/relay/common" "one-api/setting/ratio_setting" + "one-api/types" "github.com/gin-gonic/gin" ) -type GroupRatioInfo struct { - GroupRatio float64 - GroupSpecialRatio float64 - HasSpecialRatio bool -} - -type PriceData struct { - ModelPrice float64 - ModelRatio float64 - CompletionRatio float64 - CacheRatio float64 - CacheCreationRatio float64 - ImageRatio float64 - UsePrice bool - ShouldPreConsumedQuota int - GroupRatioInfo GroupRatioInfo -} - -func (p PriceData) ToSetting() string { - return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) -} - // HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present -func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo { - groupRatioInfo := GroupRatioInfo{ +func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) types.GroupRatioInfo { + groupRatioInfo := types.GroupRatioInfo{ GroupRatio: 1.0, // default ratio GroupSpecialRatio: -1, } @@ -62,7 +41,7 @@ func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupR return groupRatioInfo } -func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { +func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta) (types.PriceData, error) { modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false) groupRatioInfo := HandleGroupRatio(c, info) @@ -74,9 +53,9 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens var imageRatio float64 var cacheCreationRatio float64 if !usePrice { - preConsumedTokens := common.PreConsumedQuota - if maxTokens != 0 { - preConsumedTokens = promptTokens + maxTokens + preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota) + if meta.MaxTokens != 0 { + preConsumedTokens += meta.MaxTokens } var success bool var matchName string @@ -87,7 +66,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens acceptUnsetRatio = true } if !acceptUnsetRatio { - return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName) + return types.PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName) } } completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName) @@ -97,10 +76,13 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens ratio := modelRatio * groupRatioInfo.GroupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { + if meta.ImagePriceRatio != 0 { + modelPrice = modelPrice * meta.ImagePriceRatio + } preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) } - priceData := PriceData{ + priceData := types.PriceData{ ModelPrice: modelPrice, ModelRatio: modelRatio, CompletionRatio: completionRatio, @@ -115,18 +97,12 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens if common.DebugEnabled { println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting())) } - + info.PriceData = priceData return priceData, nil } -type PerCallPriceData struct { - ModelPrice float64 - Quota int - GroupRatioInfo GroupRatioInfo -} - // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task) -func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCallPriceData { +func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData { groupRatioInfo := HandleGroupRatio(c, info) modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) @@ -140,7 +116,7 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCal } } quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) - priceData := PerCallPriceData{ + priceData := types.PerCallPriceData{ ModelPrice: modelPrice, Quota: quota, GroupRatioInfo: groupRatioInfo, diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index b526b1c0f..725d178cc 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/constant" + "one-api/logger" relaycommon "one-api/relay/common" "one-api/setting/operation_setting" "strings" @@ -39,10 +40,6 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon }() streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second - if strings.HasPrefix(info.UpstreamModelName, "o") { - // twice timeout for thinking model - streamingTimeout *= 2 - } var ( stopChan = make(chan bool, 3) // 增加缓冲区避免阻塞 @@ -54,7 +51,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon ) generalSettings := operation_setting.GetGeneralSetting() - pingEnabled := generalSettings.PingIntervalEnabled + pingEnabled := generalSettings.PingIntervalEnabled && !info.DisablePing pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second if pingInterval <= 0 { pingInterval = DefaultPingInterval @@ -91,7 +88,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon select { case <-done: case <-time.After(5 * time.Second): - common.LogError(c, "timeout waiting for goroutines to exit") + logger.LogError(c, "timeout waiting for goroutines to exit") } close(stopChan) @@ -113,7 +110,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon defer func() { wg.Done() if r := recover(); r != nil { - common.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r)) + logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r)) common.SafeSendBool(stopChan, true) } if common.DebugEnabled { @@ -140,14 +137,14 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon select { case err := <-done: if err != nil { - common.LogError(c, "ping data error: "+err.Error()) + logger.LogError(c, "ping data error: "+err.Error()) return } if common.DebugEnabled { println("ping data sent") } case <-time.After(10 * time.Second): - common.LogError(c, "ping data send timeout") + logger.LogError(c, "ping data send timeout") return case <-ctx.Done(): return @@ -162,7 +159,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon // 监听客户端断开连接 return case <-pingTimeout.C: - common.LogError(c, "ping goroutine max duration reached") + logger.LogError(c, "ping goroutine max duration reached") return } } @@ -175,7 +172,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon defer func() { wg.Done() if r := recover(); r != nil { - common.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r)) + logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r)) } common.SafeSendBool(stopChan, true) if common.DebugEnabled { @@ -227,19 +224,25 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon return } case <-time.After(10 * time.Second): - common.LogError(c, "data handler timeout") + logger.LogError(c, "data handler timeout") return case <-ctx.Done(): return case <-stopChan: return } + } else { + // done, 处理完成标志,直接退出停止读取剩余数据防止出错 + if common.DebugEnabled { + println("received [DONE], stopping scanner") + } + return } } if err := scanner.Err(); err != nil { if err != io.EOF { - common.LogError(c, "scanner error: "+err.Error()) + logger.LogError(c, "scanner error: "+err.Error()) } } }) @@ -248,12 +251,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon select { case <-ticker.C: // 超时处理逻辑 - common.LogError(c, "streaming timeout") + logger.LogError(c, "streaming timeout") case <-stopChan: // 正常结束 - common.LogInfo(c, "streaming finished") + logger.LogInfo(c, "streaming finished") case <-c.Request.Context().Done(): // 客户端断开连接 - common.LogInfo(c, "client disconnected") + logger.LogInfo(c, "client disconnected") } } diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go new file mode 100644 index 000000000..4d1c1f9bb --- /dev/null +++ b/relay/helper/valid_request.go @@ -0,0 +1,306 @@ +package helper + +import ( + "errors" + "fmt" + "math" + "one-api/common" + "one-api/dto" + "one-api/logger" + relayconstant "one-api/relay/constant" + "one-api/types" + "strings" + + "github.com/gin-gonic/gin" +) + +func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dto.Request, err error) { + relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) + + switch format { + case types.RelayFormatOpenAI: + request, err = GetAndValidateTextRequest(c, relayMode) + case types.RelayFormatGemini: + request, err = GetAndValidateGeminiRequest(c) + case types.RelayFormatClaude: + request, err = GetAndValidateClaudeRequest(c) + case types.RelayFormatOpenAIResponses: + request, err = GetAndValidateResponsesRequest(c) + + case types.RelayFormatOpenAIImage: + request, err = GetAndValidOpenAIImageRequest(c, relayMode) + case types.RelayFormatEmbedding: + request, err = GetAndValidateEmbeddingRequest(c, relayMode) + case types.RelayFormatRerank: + request, err = GetAndValidateRerankRequest(c) + case types.RelayFormatOpenAIAudio: + request, err = GetAndValidAudioRequest(c, relayMode) + case types.RelayFormatOpenAIRealtime: + request = &dto.BaseRequest{} + default: + return nil, fmt.Errorf("unsupported relay format: %s", format) + } + return request, err +} + +func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest, error) { + audioRequest := &dto.AudioRequest{} + err := common.UnmarshalBodyReusable(c, audioRequest) + if err != nil { + return nil, err + } + switch relayMode { + case relayconstant.RelayModeAudioSpeech: + if audioRequest.Model == "" { + return nil, errors.New("model is required") + } + default: + err = c.Request.ParseForm() + if err != nil { + return nil, err + } + formData := c.Request.PostForm + if audioRequest.Model == "" { + audioRequest.Model = formData.Get("model") + } + + if audioRequest.Model == "" { + return nil, errors.New("model is required") + } + audioRequest.ResponseFormat = formData.Get("response_format") + if audioRequest.ResponseFormat == "" { + audioRequest.ResponseFormat = "json" + } + } + return audioRequest, nil +} + +func GetAndValidateRerankRequest(c *gin.Context) (*dto.RerankRequest, error) { + var rerankRequest *dto.RerankRequest + err := common.UnmarshalBodyReusable(c, &rerankRequest) + if err != nil { + logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) + return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + + if rerankRequest.Query == "" { + return nil, types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + if len(rerankRequest.Documents) == 0 { + return nil, types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + return rerankRequest, nil +} + +func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.EmbeddingRequest, error) { + var embeddingRequest *dto.EmbeddingRequest + err := common.UnmarshalBodyReusable(c, &embeddingRequest) + if err != nil { + logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) + return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + + if embeddingRequest.Input == nil { + return nil, fmt.Errorf("input is empty") + } + if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" { + embeddingRequest.Model = "omni-moderation-latest" + } + if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" { + embeddingRequest.Model = c.Param("model") + } + return embeddingRequest, nil +} + +func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { + request := &dto.OpenAIResponsesRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + if request.Model == "" { + return nil, errors.New("model is required") + } + if request.Input == nil { + return nil, errors.New("input is required") + } + return request, nil +} + +func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageRequest, error) { + imageRequest := &dto.ImageRequest{} + + switch relayMode { + case relayconstant.RelayModeImagesEdits: + if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { + _, err := c.MultipartForm() + if err != nil { + return nil, fmt.Errorf("failed to parse image edit form request: %w", err) + } + formData := c.Request.PostForm + imageRequest.Prompt = formData.Get("prompt") + imageRequest.Model = formData.Get("model") + imageRequest.N = uint(common.String2Int(formData.Get("n"))) + imageRequest.Quality = formData.Get("quality") + imageRequest.Size = formData.Get("size") + + if imageRequest.Model == "gpt-image-1" { + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" + } + } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + + watermark := formData.Has("watermark") + if watermark { + imageRequest.Watermark = &watermark + } + break + } + fallthrough + default: + err := common.UnmarshalBodyReusable(c, imageRequest) + if err != nil { + return nil, err + } + + if imageRequest.Model == "" { + //imageRequest.Model = "dall-e-3" + return nil, errors.New("model is required") + } + + if strings.Contains(imageRequest.Size, "×") { + return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") + } + + // Not "256x256", "512x512", or "1024x1024" + if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { + if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e") + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + } else if imageRequest.Model == "dall-e-3" { + if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { + return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3") + } + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + } else if imageRequest.Model == "gpt-image-1" { + if imageRequest.Quality == "" { + imageRequest.Quality = "auto" + } + } + + //if imageRequest.Prompt == "" { + // return nil, errors.New("prompt is required") + //} + + if imageRequest.N == 0 { + imageRequest.N = 1 + } + } + + return imageRequest, nil +} + +func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) { + textRequest = &dto.ClaudeRequest{} + err = c.ShouldBindJSON(textRequest) + if err != nil { + return nil, err + } + if textRequest.Messages == nil || len(textRequest.Messages) == 0 { + return nil, errors.New("field messages is required") + } + if textRequest.Model == "" { + return nil, errors.New("field model is required") + } + + //if textRequest.Stream { + // relayInfo.IsStream = true + //} + + return textRequest, nil +} + +func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenAIRequest, error) { + textRequest := &dto.GeneralOpenAIRequest{} + err := common.UnmarshalBodyReusable(c, textRequest) + if err != nil { + return nil, err + } + + if relayMode == relayconstant.RelayModeModerations && textRequest.Model == "" { + textRequest.Model = "text-moderation-latest" + } + if relayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" { + textRequest.Model = c.Param("model") + } + + if textRequest.MaxTokens > math.MaxInt32/2 { + return nil, errors.New("max_tokens is invalid") + } + if textRequest.Model == "" { + return nil, errors.New("model is required") + } + if textRequest.WebSearchOptions != nil { + if textRequest.WebSearchOptions.SearchContextSize != "" { + validSizes := map[string]bool{ + "high": true, + "medium": true, + "low": true, + } + if !validSizes[textRequest.WebSearchOptions.SearchContextSize] { + return nil, errors.New("invalid search_context_size, must be one of: high, medium, low") + } + } else { + textRequest.WebSearchOptions.SearchContextSize = "medium" + } + } + switch relayMode { + case relayconstant.RelayModeCompletions: + if textRequest.Prompt == "" { + return nil, errors.New("field prompt is required") + } + case relayconstant.RelayModeChatCompletions: + if len(textRequest.Messages) == 0 { + return nil, errors.New("field messages is required") + } + case relayconstant.RelayModeEmbeddings: + case relayconstant.RelayModeModerations: + if textRequest.Input == nil || textRequest.Input == "" { + return nil, errors.New("field input is required") + } + case relayconstant.RelayModeEdits: + if textRequest.Instruction == "" { + return nil, errors.New("field instruction is required") + } + } + return textRequest, nil +} + +func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) { + + request := &dto.GeminiChatRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + if len(request.Contents) == 0 { + return nil, errors.New("contents is required") + } + + //if c.Query("alt") == "sse" { + // relayInfo.IsStream = true + //} + + return request, nil +} diff --git a/relay/image_handler.go b/relay/image_handler.go index 8e0598630..14a7103c3 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -2,219 +2,94 @@ package relay import ( "bytes" - "encoding/json" - "errors" "fmt" "io" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" - "one-api/model" + "one-api/logger" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" - "one-api/setting" + "one-api/setting/model_setting" "one-api/types" "strings" "github.com/gin-gonic/gin" ) -func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) { - imageRequest := &dto.ImageRequest{} +func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) - switch info.RelayMode { - case relayconstant.RelayModeImagesEdits: - _, err := c.MultipartForm() - if err != nil { - return nil, err - } - formData := c.Request.PostForm - imageRequest.Prompt = formData.Get("prompt") - imageRequest.Model = formData.Get("model") - imageRequest.N = common.String2Int(formData.Get("n")) - imageRequest.Quality = formData.Get("quality") - imageRequest.Size = formData.Get("size") - - if imageRequest.Model == "gpt-image-1" { - if imageRequest.Quality == "" { - imageRequest.Quality = "standard" - } - } - if imageRequest.N == 0 { - imageRequest.N = 1 - } - - if info.ApiType == constant.APITypeVolcEngine { - watermark := formData.Has("watermark") - imageRequest.Watermark = &watermark - } - default: - err := common.UnmarshalBodyReusable(c, imageRequest) - if err != nil { - return nil, err - } - - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-3" - } - - if strings.Contains(imageRequest.Size, "×") { - return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") - } - - // Not "256x256", "512x512", or "1024x1024" - if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { - if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e") - } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - } else if imageRequest.Model == "dall-e-3" { - if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { - return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3") - } - if imageRequest.Quality == "" { - imageRequest.Quality = "standard" - } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - } else if imageRequest.Model == "gpt-image-1" { - if imageRequest.Quality == "" { - imageRequest.Quality = "auto" - } - } - - if imageRequest.Prompt == "" { - return nil, errors.New("prompt is required") - } - - if imageRequest.N == 0 { - imageRequest.N = 1 - } + imageReq, ok := info.Request.(*dto.ImageRequest) + if !ok { + return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.ImageRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - if setting.ShouldCheckPromptSensitive() { - words, err := service.CheckSensitiveInput(imageRequest.Prompt) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ","))) - return nil, err - } - } - return imageRequest, nil -} - -func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoImage(c) - - imageRequest, err := getAndValidImageRequest(c, relayInfo) + request, err := common.DeepCopy(imageReq) if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } - err = helper.ModelMappedHelper(c, relayInfo, imageRequest) + err = helper.ModelMappedHelper(c, info, request) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) - } - var preConsumedQuota int - var quota int - var userQuota int - if !priceData.UsePrice { - // modelRatio 16 = modelPrice $0.04 - // per 1 modelRatio = $0.04 / 16 - // priceData.ModelPrice = 0.0025 * priceData.ModelRatio - preConsumedQuota, userQuota, newAPIError = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - } else { - sizeRatio := 1.0 - qualityRatio := 1.0 - - if strings.HasPrefix(imageRequest.Model, "dall-e") { - // Size - if imageRequest.Size == "256x256" { - sizeRatio = 0.4 - } else if imageRequest.Size == "512x512" { - sizeRatio = 0.45 - } else if imageRequest.Size == "1024x1024" { - sizeRatio = 1 - } else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { - sizeRatio = 2 - } - - if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" { - qualityRatio = 2.0 - if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { - qualityRatio = 1.5 - } - } - } - - // reset model price - priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N) - quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit) - userQuota, err = model.GetUserQuota(relayInfo.UserId, false) - if err != nil { - return types.NewError(err, types.ErrorCodeQueryDataError) - } - if userQuota-quota < 0 { - return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota) - } - } - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader - convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest) - if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) - } - if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits { - requestBody = convertedRequest.(io.Reader) + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { + body, err := common.GetRequestBody(c) + if err != nil { + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) + } + requestBody = bytes.NewBuffer(body) } else { - jsonData, err := json.Marshal(convertedRequest) + convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed) } - requestBody = bytes.NewBuffer(jsonData) - } - if common.DebugEnabled { - println(fmt.Sprintf("image request body: %s", requestBody)) + switch convertedRequest.(type) { + case *bytes.Buffer: + requestBody = convertedRequest.(io.Reader) + default: + jsonData, err := common.Marshal(convertedRequest) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + + // apply param override + if len(info.ParamOverride) > 0 { + jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + if err != nil { + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + } + } + + if common.DebugEnabled { + logger.LogDebug(c, fmt.Sprintf("image request body: %s", string(jsonData))) + } + requestBody = bytes.NewBuffer(jsonData) + } } statusCodeMappingStr := c.GetString("status_code_mapping") - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } var httpResp *http.Response if resp != nil { httpResp = resp.(*http.Response) - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { newAPIError = service.RelayErrorHandler(httpResp, false) // reset status code 重置状态码 @@ -223,7 +98,7 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) @@ -231,17 +106,23 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } if usage.(*dto.Usage).TotalTokens == 0 { - usage.(*dto.Usage).TotalTokens = imageRequest.N + usage.(*dto.Usage).TotalTokens = int(request.N) } if usage.(*dto.Usage).PromptTokens == 0 { - usage.(*dto.Usage).PromptTokens = imageRequest.N + usage.(*dto.Usage).PromptTokens = int(request.N) } + quality := "standard" - if imageRequest.Quality == "hd" { + if request.Quality == "hd" { quality = "hd" } - logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, logContent) + var logContent string + + if len(request.Size) > 0 { + logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality) + } + + postConsumeQuota(c, info, usage.(*dto.Usage), logContent) return nil } diff --git a/relay/relay-mj.go b/relay/mjproxy_handler.go similarity index 87% rename from relay/relay-mj.go rename to relay/mjproxy_handler.go index e7f316b98..7c52cb6be 100644 --- a/relay/relay-mj.go +++ b/relay/mjproxy_handler.go @@ -170,26 +170,23 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo return } -func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { - startTime := time.Now().UnixNano() / int64(time.Millisecond) - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - //group := c.GetString("group") - channelId := c.GetInt("channel_id") - relayInfo := relaycommon.GenRelayInfo(c) +func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyResponse { var swapFaceRequest dto.SwapFaceRequest err := common.UnmarshalBodyReusable(c, &swapFaceRequest) if err != nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") } + + info.InitChannelMeta(c) + if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required") } modelName := service.CoverActionToModelName(constant.MjActionSwapFace) - priceData := helper.ModelPriceHelperPerCall(c, relayInfo) + priceData := helper.ModelPriceHelperPerCall(c, info) - userQuota, err := model.GetUserQuota(userId, false) + userQuota, err := model.GetUserQuota(info.UserId, false) if err != nil { return &dto.MidjourneyResponse{ Code: 4, @@ -212,32 +209,31 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { } defer func() { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { - err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) + err := service.PostConsumeQuota(info, priceData.Quota, 0, true) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + common.SysLog("error consuming token remain quota: " + err.Error()) } tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace) other := service.GenerateMjOtherInfo(priceData) - model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ - ChannelId: channelId, + model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ + ChannelId: info.ChannelId, ModelName: modelName, TokenName: tokenName, Quota: priceData.Quota, Content: logContent, - TokenId: tokenId, - UserQuota: userQuota, - Group: relayInfo.UsingGroup, + TokenId: info.TokenId, + Group: info.UsingGroup, Other: other, }) - model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) - model.UpdateChannelUsedQuota(channelId, priceData.Quota) + model.UpdateUserUsedQuotaAndRequestCount(info.UserId, priceData.Quota) + model.UpdateChannelUsedQuota(info.ChannelId, priceData.Quota) } }() midjResponse := &mjResp.Response midjourneyTask := &model.Midjourney{ - UserId: userId, + UserId: info.UserId, Code: midjResponse.Code, Action: constant.MjActionSwapFace, MjId: midjResponse.Result, @@ -245,7 +241,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { PromptEn: "", Description: midjResponse.Description, State: "", - SubmitTime: startTime, + SubmitTime: info.StartTime.UnixNano() / int64(time.Millisecond), StartTime: time.Now().UnixNano() / int64(time.Millisecond), FinishTime: 0, ImageUrl: "", @@ -300,7 +296,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { if err != nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") } - common.IOCopyBytesGracefully(c, nil, respBody) + service.IOCopyBytesGracefully(c, nil, respBody) return nil } @@ -369,14 +365,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse return nil } -func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse { - - //tokenId := c.GetInt("token_id") - //channelType := c.GetInt("channel") - userId := c.GetInt("id") - group := c.GetString("group") - channelId := c.GetInt("channel_id") - relayInfo := relaycommon.GenRelayInfo(c) +func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.MidjourneyResponse { consumeQuota := true var midjRequest dto.MidjourneyRequest err := common.UnmarshalBodyReusable(c, &midjRequest) @@ -384,35 +373,37 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") } - if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 + relayInfo.InitChannelMeta(c) + + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 mjErr := service.CoverPlusActionToNormalAction(&midjRequest) if mjErr != nil { return mjErr } - relayMode = relayconstant.RelayModeMidjourneyChange + relayInfo.RelayMode = relayconstant.RelayModeMidjourneyChange } - if relayMode == relayconstant.RelayModeMidjourneyVideo { + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo { midjRequest.Action = constant.MjActionVideo } - if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 if midjRequest.Prompt == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required") } midjRequest.Action = constant.MjActionImagine - } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 midjRequest.Action = constant.MjActionDescribe - } else if relayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复 midjRequest.Action = constant.MjActionEdits - } else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only midjRequest.Action = constant.MjActionShorten - } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 midjRequest.Action = constant.MjActionBlend - } else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复 midjRequest.Action = constant.MjActionUpload } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果 mjId := "" - if relayMode == relayconstant.RelayModeMidjourneyChange { + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyChange { if midjRequest.TaskId == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") } else if midjRequest.Action == "" { @@ -422,7 +413,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } //action = midjRequest.Action mjId = midjRequest.TaskId - } else if relayMode == relayconstant.RelayModeMidjourneySimpleChange { + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneySimpleChange { if midjRequest.Content == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required") } @@ -432,13 +423,13 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } mjId = params.TaskId midjRequest.Action = params.Action - } else if relayMode == relayconstant.RelayModeMidjourneyModal { + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyModal { //if midjRequest.MaskBase64 == "" { // return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required") //} mjId = midjRequest.TaskId midjRequest.Action = constant.MjActionModal - } else if relayMode == relayconstant.RelayModeMidjourneyVideo { + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo { midjRequest.Action = constant.MjActionVideo if midjRequest.TaskId == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") @@ -448,12 +439,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons mjId = midjRequest.TaskId } - originTask := model.GetByMJId(userId, mjId) + originTask := model.GetByMJId(relayInfo.UserId, mjId) if originTask == nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found") } else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理 if setting.MjActionCheckSuccessEnabled { - if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal { + if originTask.Status != "SUCCESS" && relayInfo.RelayMode != relayconstant.RelayModeMidjourneyModal { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success") } } @@ -496,7 +487,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons priceData := helper.ModelPriceHelperPerCall(c, relayInfo) - userQuota, err := model.GetUserQuota(userId, false) + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { return &dto.MidjourneyResponse{ Code: 4, @@ -521,24 +512,23 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if consumeQuota && midjResponseWithStatus.StatusCode == 200 { err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + common.SysLog("error consuming token remain quota: " + err.Error()) } tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result) other := service.GenerateMjOtherInfo(priceData) model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ - ChannelId: channelId, + ChannelId: relayInfo.ChannelId, ModelName: modelName, TokenName: tokenName, Quota: priceData.Quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, - Group: group, + Group: relayInfo.UsingGroup, Other: other, }) - model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) - model.UpdateChannelUsedQuota(channelId, priceData.Quota) + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, priceData.Quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, priceData.Quota) } }() @@ -550,7 +540,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons // 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}} // other: 提交错误,description为错误描述 midjourneyTask := &model.Midjourney{ - UserId: userId, + UserId: relayInfo.UserId, Code: midjResponse.Code, Action: midjRequest.Action, MjId: midjResponse.Result, @@ -572,7 +562,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons //无实例账号自动禁用渠道(No available account instance) channel, err := model.GetChannelById(midjourneyTask.ChannelId, true) if err != nil { - common.SysError("get_channel_null: " + err.Error()) + common.SysLog("get_channel_null: " + err.Error()) } if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled { model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance") diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 2ce12a872..1ee85986c 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -1,8 +1,8 @@ package relay import ( + "github.com/gin-gonic/gin" "one-api/constant" - commonconstant "one-api/constant" "one-api/relay/channel" "one-api/relay/channel/ali" "one-api/relay/channel/aws" @@ -19,6 +19,7 @@ import ( "one-api/relay/channel/jina" "one-api/relay/channel/mistral" "one-api/relay/channel/mokaai" + "one-api/relay/channel/moonshot" "one-api/relay/channel/ollama" "one-api/relay/channel/openai" "one-api/relay/channel/palm" @@ -27,6 +28,7 @@ import ( taskjimeng "one-api/relay/channel/task/jimeng" "one-api/relay/channel/task/kling" "one-api/relay/channel/task/suno" + taskVidu "one-api/relay/channel/task/vidu" "one-api/relay/channel/tencent" "one-api/relay/channel/vertex" "one-api/relay/channel/volcengine" @@ -34,6 +36,7 @@ import ( "one-api/relay/channel/xunfei" "one-api/relay/channel/zhipu" "one-api/relay/channel/zhipu_4v" + "strconv" ) func GetAdaptor(apiType int) channel.Adaptor { @@ -96,20 +99,36 @@ func GetAdaptor(apiType int) channel.Adaptor { return &coze.Adaptor{} case constant.APITypeJimeng: return &jimeng.Adaptor{} + case constant.APITypeMoonshot: + return &moonshot.Adaptor{} // Moonshot uses Claude API } return nil } -func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor { +func GetTaskPlatform(c *gin.Context) constant.TaskPlatform { + channelType := c.GetInt("channel_type") + if channelType > 0 { + return constant.TaskPlatform(strconv.Itoa(channelType)) + } + return constant.TaskPlatform(c.GetString("platform")) +} + +func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor { switch platform { //case constant.APITypeAIProxyLibrary: // return &aiproxy.Adaptor{} - case commonconstant.TaskPlatformSuno: + case constant.TaskPlatformSuno: return &suno.TaskAdaptor{} - case commonconstant.TaskPlatformKling: - return &kling.TaskAdaptor{} - case commonconstant.TaskPlatformJimeng: - return &taskjimeng.TaskAdaptor{} + } + if channelType, err := strconv.ParseInt(string(platform), 10, 64); err == nil { + switch channelType { + case constant.ChannelTypeKling: + return &kling.TaskAdaptor{} + case constant.ChannelTypeJimeng: + return &taskjimeng.TaskAdaptor{} + case constant.ChannelTypeVidu: + return &taskVidu.TaskAdaptor{} + } } return nil } diff --git a/relay/relay_task.go b/relay/relay_task.go index 25f63d40e..0754e0234 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -22,24 +22,31 @@ import ( /* Task 任务通过平台、Action 区分任务 */ -func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { +func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { + info.InitChannelMeta(c) + // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields + if info.TaskRelayInfo == nil { + info.TaskRelayInfo = &relaycommon.TaskRelayInfo{} + } platform := constant.TaskPlatform(c.GetString("platform")) - relayInfo := relaycommon.GenTaskRelayInfo(c) + if platform == "" { + platform = GetTaskPlatform(c) + } adaptor := GetTaskAdaptor(platform) if adaptor == nil { return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) } - adaptor.Init(relayInfo) + adaptor.Init(info) // get & validate taskRequest 获取并验证文本请求 - taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo) + taskErr = adaptor.ValidateRequestAndSetAction(c, info) if taskErr != nil { return } - modelName := relayInfo.OriginModelName + modelName := info.OriginModelName if modelName == "" { - modelName = service.CoverTaskActionToModelName(platform, relayInfo.Action) + modelName = service.CoverTaskActionToModelName(platform, info.Action) } modelPrice, success := ratio_setting.GetModelPrice(modelName, true) if !success { @@ -52,15 +59,15 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } // 预扣 - groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup) + groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup) var ratio float64 - userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup) + userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup) if hasUserGroupRatio { ratio = modelPrice * userGroupRatio } else { ratio = modelPrice * groupRatio } - userQuota, err := model.GetUserQuota(relayInfo.UserId, false) + userQuota, err := model.GetUserQuota(info.UserId, false) if err != nil { taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) return @@ -71,8 +78,8 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { return } - if relayInfo.OriginTaskID != "" { - originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID) + if info.OriginTaskID != "" { + originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) if err != nil { taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) return @@ -81,7 +88,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) return } - if originTask.ChannelId != relayInfo.ChannelId { + if originTask.ChannelId != info.ChannelId { channel, err := model.GetChannelById(originTask.ChannelId, true) if err != nil { taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) @@ -94,19 +101,19 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { c.Set("channel_id", originTask.ChannelId) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - relayInfo.BaseUrl = channel.GetBaseURL() - relayInfo.ChannelId = originTask.ChannelId + info.ChannelBaseUrl = channel.GetBaseURL() + info.ChannelId = originTask.ChannelId } } // build body - requestBody, err := adaptor.BuildRequestBody(c, relayInfo) + requestBody, err := adaptor.BuildRequestBody(c, info) if err != nil { taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) return } // do request - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return @@ -120,11 +127,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { defer func() { // release quota - if relayInfo.ConsumeQuota && taskErr == nil { + if info.ConsumeQuota && taskErr == nil { - err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true) + err := service.PostConsumeQuota(info, quota, 0, true) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + common.SysLog("error consuming token remain quota: " + err.Error()) } if quota != 0 { tokenName := c.GetString("token_name") @@ -132,41 +139,40 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { if hasUserGroupRatio { gRatio = userGroupRatio } - logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action) + logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, info.Action) other := make(map[string]interface{}) other["model_price"] = modelPrice other["group_ratio"] = groupRatio if hasUserGroupRatio { other["user_group_ratio"] = userGroupRatio } - model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ - ChannelId: relayInfo.ChannelId, + model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ + ChannelId: info.ChannelId, ModelName: modelName, TokenName: tokenName, Quota: quota, Content: logContent, - TokenId: relayInfo.TokenId, - UserQuota: userQuota, - Group: relayInfo.UsingGroup, + TokenId: info.TokenId, + Group: info.UsingGroup, Other: other, }) - model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) - model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota) + model.UpdateChannelUsedQuota(info.ChannelId, quota) } } }() - taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo) + taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) if taskErr != nil { return } - relayInfo.ConsumeQuota = true + info.ConsumeQuota = true // insert task - task := model.InitTask(platform, relayInfo) + task := model.InitTask(platform, info) task.TaskID = taskID task.Quota = quota task.Data = taskData - task.Action = relayInfo.Action + task.Action = info.Action err = task.Insert() if err != nil { taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) @@ -178,7 +184,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, - relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder, + relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder, } func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) { @@ -255,6 +261,9 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { taskId := c.Param("task_id") + if taskId == "" { + taskId = c.GetString("task_id") + } userId := c.GetInt("id") originTask, exist, err := model.GetByTaskId(userId, taskId) diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index a092de4bf..fa3c7bbb4 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -3,86 +3,75 @@ package relay import ( "bytes" "fmt" + "io" "net/http" "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "one-api/setting/model_setting" "one-api/types" "github.com/gin-gonic/gin" ) -func getRerankPromptToken(rerankRequest dto.RerankRequest) int { - token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model) - for _, document := range rerankRequest.Documents { - tkm := service.CountTokenInput(document, rerankRequest.Model) - token += tkm +func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) + + rerankReq, ok := info.Request.(*dto.RerankRequest) + if !ok { + return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.RerankRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - return token -} -func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError) { - - var rerankRequest *dto.RerankRequest - err := common.UnmarshalBodyReusable(c, &rerankRequest) + request, err := common.DeepCopy(rerankReq) if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } - relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest) - - if rerankRequest.Query == "" { - return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest) - } - if len(rerankRequest.Documents) == 0 { - return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest) - } - - err = helper.ModelMappedHelper(c, relayInfo, rerankRequest) + err = helper.ModelMappedHelper(c, info, request) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - promptToken := getRerankPromptToken(*rerankRequest) - relayInfo.PromptTokens = promptToken - - priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) - } - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) - convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest) - if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + var requestBody io.Reader + if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { + body, err := common.GetRequestBody(c) + if err != nil { + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) + } + requestBody = bytes.NewBuffer(body) + } else { + convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *request) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + jsonData, err := common.Marshal(convertedRequest) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + + // apply param override + if len(info.ParamOverride) > 0 { + jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) + if err != nil { + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) + } + } + + if common.DebugEnabled { + println(fmt.Sprintf("Rerank request body: %s", string(jsonData))) + } + requestBody = bytes.NewBuffer(jsonData) } - jsonData, err := common.Marshal(convertedRequest) - if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) - } - requestBody := bytes.NewBuffer(jsonData) - if common.DebugEnabled { - println(fmt.Sprintf("Rerank request body: %s", requestBody.String())) - } - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -99,12 +88,12 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") return nil } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 52d1db6ef..f5f624c92 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -2,8 +2,6 @@ package relay import ( "bytes" - "encoding/json" - "errors" "fmt" "io" "net/http" @@ -12,7 +10,6 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" - "one-api/setting" "one-api/setting/model_setting" "one-api/types" "strings" @@ -20,111 +17,50 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { - request := &dto.OpenAIResponsesRequest{} - err := common.UnmarshalBodyReusable(c, request) +func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) + + responsesReq, ok := info.Request.(*dto.OpenAIResponsesRequest) + if !ok { + return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) + } + + request, err := common.DeepCopy(responsesReq) if err != nil { - return nil, err + return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } - if request.Model == "" { - return nil, errors.New("model is required") - } - if len(request.Input) == 0 { - return nil, errors.New("input is required") - } - return request, nil -} - -func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) { - sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input) - return sensitiveWords, err -} - -func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int { - inputTokens := service.CountTokenInput(req.Input, req.Model) - info.PromptTokens = inputTokens - return inputTokens -} - -func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { - req, err := getAndValidateResponsesRequest(c) + err = helper.ModelMappedHelper(c, info, request) if err != nil { - common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } - relayInfo := relaycommon.GenRelayInfoResponses(c, req) - - if setting.ShouldCheckPromptSensitive() { - sensitiveWords, err := checkInputSensitive(req, relayInfo) - if err != nil { - common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) - } - } - - err = helper.ModelMappedHelper(c, relayInfo, req) - if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) - } - - if value, exists := c.Get("prompt_tokens"); exists { - promptTokens := value.(int) - relayInfo.SetPromptTokens(promptTokens) - } else { - promptTokens := getInputTokens(req, relayInfo) - c.Set("prompt_tokens", promptTokens) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens)) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) - } - // pre consume quota - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewError(err, types.ErrorCodeReadRequestBodyFailed) + return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req) + convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *request) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - jsonData, err := json.Marshal(convertedRequest) + jsonData, err := common.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override - if len(relayInfo.ParamOverride) > 0 { - reqMap := make(map[string]interface{}) - err = json.Unmarshal(jsonData, &reqMap) + if len(info.ParamOverride) > 0 { + jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) - } - for key, value := range relayInfo.ParamOverride { - reqMap[key] = value - } - jsonData, err = json.Marshal(reqMap) - if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } @@ -135,7 +71,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } var httpResp *http.Response - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -153,17 +89,17 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } } - usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, httpResp, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") { - service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") { + service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") } else { - postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + postConsumeQuota(c, info, usage.(*dto.Usage), "") } return nil } diff --git a/relay/websocket.go b/relay/websocket.go index 659e27d56..2d313154c 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -4,7 +4,6 @@ import ( "fmt" "one-api/dto" relaycommon "one-api/relay/common" - "one-api/relay/helper" "one-api/service" "one-api/types" @@ -12,65 +11,35 @@ import ( "github.com/gorilla/websocket" ) -func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoWs(c, ws) +func WssHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) - // get & validate textRequest 获取并验证文本请求 - //realtimeEvent, err := getAndValidateWssRequest(c, ws) - //if err != nil { - // common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error())) - // return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) - //} - - err := helper.ModelMappedHelper(c, relayInfo, nil) - if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) - } - - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) //var requestBody io.Reader //firstWssRequest, _ := c.Get("first_wss_request") //requestBody = bytes.NewBuffer(firstWssRequest.([]byte)) statusCodeMappingStr := c.GetString("status_code_mapping") - resp, err := adaptor.DoRequest(c, relayInfo, nil) + resp, err := adaptor.DoRequest(c, info, nil) if err != nil { return types.NewError(err, types.ErrorCodeDoRequestFailed) } if resp != nil { - relayInfo.TargetWs = resp.(*websocket.Conn) - defer relayInfo.TargetWs.Close() + info.TargetWs = resp.(*websocket.Conn) + defer info.TargetWs.Close() } - usage, newAPIError := adaptor.DoResponse(c, nil, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, nil, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota, - userQuota, priceData, "") + service.PostWssConsumeQuota(c, info, info.UpstreamModelName, usage.(*dto.RealtimeUsage), "") return nil } diff --git a/router/api-router.go b/router/api-router.go index bc49803a2..773857385 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -24,7 +24,7 @@ func SetApiRouter(router *gin.Engine) { //apiRouter.GET("/midjourney", controller.GetMidjourney) apiRouter.GET("/home_page_content", controller.GetHomePageContent) apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing) - apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) + apiRouter.GET("/verification", middleware.EmailVerificationRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) @@ -44,6 +44,7 @@ func SetApiRouter(router *gin.Engine) { { userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register) userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login) + userRoute.POST("/login/2fa", middleware.CriticalRateLimit(), controller.Verify2FALogin) //userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog) userRoute.GET("/logout", controller.Logout) userRoute.GET("/epay/notify", controller.EpayNotify) @@ -66,6 +67,13 @@ func SetApiRouter(router *gin.Engine) { selfRoute.POST("/stripe/amount", controller.RequestStripeAmount) selfRoute.POST("/aff_transfer", controller.TransferAffQuota) selfRoute.PUT("/setting", controller.UpdateUserSetting) + + // 2FA routes + selfRoute.GET("/2fa/status", controller.Get2FAStatus) + selfRoute.POST("/2fa/setup", controller.Setup2FA) + selfRoute.POST("/2fa/enable", controller.Enable2FA) + selfRoute.POST("/2fa/disable", controller.Disable2FA) + selfRoute.POST("/2fa/backup_codes", controller.RegenerateBackupCodes) } adminRoute := userRoute.Group("/") @@ -78,6 +86,10 @@ func SetApiRouter(router *gin.Engine) { adminRoute.POST("/manage", controller.ManageUser) adminRoute.PUT("/", controller.UpdateUser) adminRoute.DELETE("/:id", controller.DeleteUser) + + // Admin 2FA routes + adminRoute.GET("/2fa/stats", controller.Admin2FAStats) + adminRoute.DELETE("/:id/2fa", controller.AdminDisable2FA) } } optionRoute := apiRouter.Group("/option") @@ -102,6 +114,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute.GET("/models", controller.ChannelListModels) channelRoute.GET("/models_enabled", controller.EnabledListModels) channelRoute.GET("/:id", controller.GetChannel) + channelRoute.POST("/:id/key", middleware.CriticalRateLimit(), middleware.DisableCache(), controller.GetChannelKey) channelRoute.GET("/test", controller.TestAllChannels) channelRoute.GET("/test/:id", controller.TestChannel) channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance) @@ -120,6 +133,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute.POST("/batch/tag", controller.BatchSetChannelTag) channelRoute.GET("/tag/models", controller.GetTagModels) channelRoute.POST("/copy/:id", controller.CopyChannel) + channelRoute.POST("/multi_key/manage", controller.ManageMultiKeys) } tokenRoute := apiRouter.Group("/token") tokenRoute.Use(middleware.UserAuth()) @@ -132,6 +146,17 @@ func SetApiRouter(router *gin.Engine) { tokenRoute.DELETE("/:id", controller.DeleteToken) tokenRoute.POST("/batch", controller.DeleteTokenBatch) } + + usageRoute := apiRouter.Group("/usage") + usageRoute.Use(middleware.CriticalRateLimit()) + { + tokenUsageRoute := usageRoute.Group("/token") + tokenUsageRoute.Use(middleware.TokenAuth()) + { + tokenUsageRoute.GET("/", controller.GetTokenUsage) + } + } + redemptionRoute := apiRouter.Group("/redemption") redemptionRoute.Use(middleware.AdminAuth()) { @@ -159,13 +184,22 @@ func SetApiRouter(router *gin.Engine) { logRoute.Use(middleware.CORS()) { logRoute.GET("/token", controller.GetLogByKey) - } groupRoute := apiRouter.Group("/group") groupRoute.Use(middleware.AdminAuth()) { groupRoute.GET("/", controller.GetGroups) } + + prefillGroupRoute := apiRouter.Group("/prefill_group") + prefillGroupRoute.Use(middleware.AdminAuth()) + { + prefillGroupRoute.GET("/", controller.GetPrefillGroups) + prefillGroupRoute.POST("/", controller.CreatePrefillGroup) + prefillGroupRoute.PUT("/", controller.UpdatePrefillGroup) + prefillGroupRoute.DELETE("/:id", controller.DeletePrefillGroup) + } + mjRoute := apiRouter.Group("/mj") mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney) mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney) @@ -175,5 +209,30 @@ func SetApiRouter(router *gin.Engine) { taskRoute.GET("/self", middleware.UserAuth(), controller.GetUserTask) taskRoute.GET("/", middleware.AdminAuth(), controller.GetAllTask) } + + vendorRoute := apiRouter.Group("/vendors") + vendorRoute.Use(middleware.AdminAuth()) + { + vendorRoute.GET("/", controller.GetAllVendors) + vendorRoute.GET("/search", controller.SearchVendors) + vendorRoute.GET("/:id", controller.GetVendorMeta) + vendorRoute.POST("/", controller.CreateVendorMeta) + vendorRoute.PUT("/", controller.UpdateVendorMeta) + vendorRoute.DELETE("/:id", controller.DeleteVendorMeta) + } + + modelsRoute := apiRouter.Group("/models") + modelsRoute.Use(middleware.AdminAuth()) + { + modelsRoute.GET("/sync_upstream/preview", controller.SyncUpstreamPreview) + modelsRoute.POST("/sync_upstream", controller.SyncUpstreamModels) + modelsRoute.GET("/missing", controller.GetMissingModels) + modelsRoute.GET("/", controller.GetAllModelsMeta) + modelsRoute.GET("/search", controller.SearchModelsMeta) + modelsRoute.GET("/:id", controller.GetModelMeta) + modelsRoute.POST("/", controller.CreateModelMeta) + modelsRoute.PUT("/", controller.UpdateModelMeta) + modelsRoute.DELETE("/:id", controller.DeleteModelMeta) + } } } diff --git a/router/main.go b/router/main.go index 0d2bfdcea..235764270 100644 --- a/router/main.go +++ b/router/main.go @@ -3,11 +3,12 @@ package router import ( "embed" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "os" "strings" + + "github.com/gin-gonic/gin" ) func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { diff --git a/router/relay-router.go b/router/relay-router.go index 5b293dbdc..e0f05e97b 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -1,9 +1,11 @@ package router import ( + "one-api/constant" "one-api/controller" "one-api/middleware" "one-api/relay" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -16,9 +18,43 @@ func SetRelayRouter(router *gin.Engine) { modelsRouter := router.Group("/v1/models") modelsRouter.Use(middleware.TokenAuth()) { - modelsRouter.GET("", controller.ListModels) - modelsRouter.GET("/:model", controller.RetrieveModel) + modelsRouter.GET("", func(c *gin.Context) { + switch { + case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "": + controller.ListModels(c, constant.ChannelTypeAnthropic) + case c.GetHeader("x-goog-api-key") != "" || c.Query("key") != "": // 单独的适配 + controller.RetrieveModel(c, constant.ChannelTypeGemini) + default: + controller.ListModels(c, constant.ChannelTypeOpenAI) + } + }) + + modelsRouter.GET("/:model", func(c *gin.Context) { + switch { + case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "": + controller.RetrieveModel(c, constant.ChannelTypeAnthropic) + default: + controller.RetrieveModel(c, constant.ChannelTypeOpenAI) + } + }) } + + geminiRouter := router.Group("/v1beta/models") + geminiRouter.Use(middleware.TokenAuth()) + { + geminiRouter.GET("", func(c *gin.Context) { + controller.ListModels(c, constant.ChannelTypeGemini) + }) + } + + geminiCompatibleRouter := router.Group("/v1beta/openai/models") + geminiCompatibleRouter.Use(middleware.TokenAuth()) + { + geminiCompatibleRouter.GET("", func(c *gin.Context) { + controller.ListModels(c, constant.ChannelTypeOpenAI) + }) + } + playgroundRouter := router.Group("/pg") playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute()) { @@ -28,28 +64,83 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.Use(middleware.TokenAuth()) relayV1Router.Use(middleware.ModelRequestRateLimit()) { - // WebSocket 路由 + // WebSocket 路由(统一到 Relay) wsRouter := relayV1Router.Group("") wsRouter.Use(middleware.Distribute()) - wsRouter.GET("/realtime", controller.WssRelay) + wsRouter.GET("/realtime", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIRealtime) + }) } { //http router httpRouter := relayV1Router.Group("") httpRouter.Use(middleware.Distribute()) - httpRouter.POST("/messages", controller.RelayClaude) - httpRouter.POST("/completions", controller.Relay) - httpRouter.POST("/chat/completions", controller.Relay) - httpRouter.POST("/edits", controller.Relay) - httpRouter.POST("/images/generations", controller.Relay) - httpRouter.POST("/images/edits", controller.Relay) + + // claude related routes + httpRouter.POST("/messages", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatClaude) + }) + + // chat related routes + httpRouter.POST("/completions", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAI) + }) + httpRouter.POST("/chat/completions", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAI) + }) + + // response related routes + httpRouter.POST("/responses", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIResponses) + }) + + // image related routes + httpRouter.POST("/edits", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIImage) + }) + httpRouter.POST("/images/generations", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIImage) + }) + httpRouter.POST("/images/edits", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIImage) + }) + + // embedding related routes + httpRouter.POST("/embeddings", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatEmbedding) + }) + + // audio related routes + httpRouter.POST("/audio/transcriptions", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIAudio) + }) + httpRouter.POST("/audio/translations", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIAudio) + }) + httpRouter.POST("/audio/speech", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAIAudio) + }) + + // rerank related routes + httpRouter.POST("/rerank", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatRerank) + }) + + // gemini relay routes + httpRouter.POST("/engines/:model/embeddings", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatGemini) + }) + httpRouter.POST("/models/*path", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatGemini) + }) + + // other relay routes + httpRouter.POST("/moderations", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatOpenAI) + }) + + // not implemented httpRouter.POST("/images/variations", controller.RelayNotImplemented) - httpRouter.POST("/embeddings", controller.Relay) - httpRouter.POST("/engines/:model/embeddings", controller.Relay) - httpRouter.POST("/audio/transcriptions", controller.Relay) - httpRouter.POST("/audio/translations", controller.Relay) - httpRouter.POST("/audio/speech", controller.Relay) - httpRouter.POST("/responses", controller.Relay) httpRouter.GET("/files", controller.RelayNotImplemented) httpRouter.POST("/files", controller.RelayNotImplemented) httpRouter.DELETE("/files/:id", controller.RelayNotImplemented) @@ -61,9 +152,6 @@ func SetRelayRouter(router *gin.Engine) { httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) httpRouter.DELETE("/models/:model", controller.RelayNotImplemented) - httpRouter.POST("/moderations", controller.Relay) - httpRouter.POST("/rerank", controller.Relay) - httpRouter.POST("/models/*path", controller.Relay) } relayMjRouter := router.Group("/mj") @@ -87,7 +175,9 @@ func SetRelayRouter(router *gin.Engine) { relayGeminiRouter.Use(middleware.Distribute()) { // Gemini API 路径格式: /v1beta/models/{model_name}:{action} - relayGeminiRouter.POST("/models/*path", controller.Relay) + relayGeminiRouter.POST("/models/*path", func(c *gin.Context) { + controller.Relay(c, types.RelayFormatGemini) + }) } } diff --git a/router/video-router.go b/router/video-router.go index 9e605d541..bcc05eae9 100644 --- a/router/video-router.go +++ b/router/video-router.go @@ -20,5 +20,15 @@ func SetVideoRouter(router *gin.Engine) { { klingV1Router.POST("/videos/text2video", controller.RelayTask) klingV1Router.POST("/videos/image2video", controller.RelayTask) + klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTask) + klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTask) + } + + // Jimeng official API routes - direct mapping to official API format + jimengOfficialGroup := router.Group("jimeng") + jimengOfficialGroup.Use(middleware.JimengRequestConvert(), middleware.TokenAuth(), middleware.Distribute()) + { + // Maps to: /?Action=CVSync2AsyncSubmitTask&Version=2022-08-31 and /?Action=CVSync2AsyncGetResult&Version=2022-08-31 + jimengOfficialGroup.POST("/", controller.RelayTask) } } diff --git a/service/cf_worker.go b/service/cf_worker.go index ae6e1ffe9..4a7b43760 100644 --- a/service/cf_worker.go +++ b/service/cf_worker.go @@ -42,16 +42,16 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload)) } -func DoDownloadRequest(originUrl string) (resp *http.Response, err error) { +func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) { if setting.EnableWorker() { - common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl)) + common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", "))) req := &WorkerRequest{ URL: originUrl, Key: setting.WorkerValidKey, } return DoWorkerRequest(req) } else { - common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl)) + common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", "))) return http.Get(originUrl) } } diff --git a/service/channel.go b/service/channel.go index 4d38e6edc..6ddc8e9ec 100644 --- a/service/channel.go +++ b/service/channel.go @@ -18,6 +18,14 @@ func formatNotifyType(channelId int, status int) string { // disable & notify func DisableChannel(channelError types.ChannelError, reason string) { + common.SysLog(fmt.Sprintf("通道「%s」(#%d)发生错误,准备禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, reason)) + + // 检查是否启用自动禁用功能 + if !channelError.AutoBan { + common.SysLog(fmt.Sprintf("通道「%s」(#%d)未启用自动禁用功能,跳过禁用操作", channelError.ChannelName, channelError.ChannelId)) + return + } + success := model.UpdateChannelStatus(channelError.ChannelId, channelError.UsingKey, common.ChannelStatusAutoDisabled, reason) if success { subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelError.ChannelName, channelError.ChannelId) @@ -45,7 +53,7 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool { if types.IsChannelError(err) { return true } - if types.IsLocalError(err) { + if types.IsSkipRetryError(err) { return false } if err.StatusCode == http.StatusUnauthorized { diff --git a/service/convert.go b/service/convert.go index 593b59d94..b232ca396 100644 --- a/service/convert.go +++ b/service/convert.go @@ -153,9 +153,13 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re toolCalls = append(toolCalls, toolCall) case "tool_result": // Add tool result as a separate message + toolName := mediaMsg.Name + if toolName == "" { + toolName = claudeRequest.SearchToolNameByToolCallId(mediaMsg.ToolUseId) + } oaiToolMessage := dto.Message{ Role: "tool", - Name: &mediaMsg.Name, + Name: &toolName, ToolCallId: mediaMsg.ToolUseId, } //oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text) @@ -188,28 +192,6 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re return &openAIRequest, nil } -func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) *dto.ClaudeErrorWithStatusCode { - claudeError := dto.ClaudeError{ - Type: "new_api_error", - Message: openAIError.Error.Message, - } - return &dto.ClaudeErrorWithStatusCode{ - Error: claudeError, - StatusCode: openAIError.StatusCode, - } -} - -func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode { - openAIError := dto.OpenAIError{ - Message: claudeError.Error.Message, - Type: "new_api_error", - } - return &dto.OpenAIErrorWithStatusCode{ - Error: openAIError, - StatusCode: claudeError.StatusCode, - } -} - func generateStopBlock(index int) *dto.ClaudeResponse { return &dto.ClaudeResponse{ Type: "content_block_stop", @@ -240,40 +222,77 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon // Type: "ping", //}) if openAIResponse.IsToolCall() { + info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools resp := &dto.ClaudeResponse{ Type: "content_block_start", ContentBlock: &dto.ClaudeMediaMessage{ - Id: openAIResponse.GetFirstToolCall().ID, - Type: "tool_use", - Name: openAIResponse.GetFirstToolCall().Function.Name, + Id: openAIResponse.GetFirstToolCall().ID, + Type: "tool_use", + Name: openAIResponse.GetFirstToolCall().Function.Name, + Input: map[string]interface{}{}, }, } resp.SetIndex(0) claudeResponses = append(claudeResponses, resp) } else { - //resp := &dto.ClaudeResponse{ - // Type: "content_block_start", - // ContentBlock: &dto.ClaudeMediaMessage{ - // Type: "text", - // Text: common.GetPointer[string](""), - // }, - //} - //resp.SetIndex(0) - //claudeResponses = append(claudeResponses, resp) + + } + // 判断首个响应是否存在内容(非标准的 OpenAI 响应) + if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.GetContentString()) > 0 { + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Index: &info.ClaudeConvertInfo.Index, + Type: "content_block_start", + ContentBlock: &dto.ClaudeMediaMessage{ + Type: "text", + Text: common.GetPointer[string](""), + }, + }) + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Index: &info.ClaudeConvertInfo.Index, + Type: "content_block_delta", + Delta: &dto.ClaudeMediaMessage{ + Type: "text_delta", + Text: common.GetPointer[string](openAIResponse.Choices[0].Delta.GetContentString()), + }, + }) + info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText } return claudeResponses } if len(openAIResponse.Choices) == 0 { // no choices - // TODO: handle this case + // 可能为非标准的 OpenAI 响应,判断是否已经完成 + if info.Done { + claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index)) + oaiUsage := info.ClaudeConvertInfo.Usage + if oaiUsage != nil { + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Type: "message_delta", + Usage: &dto.ClaudeUsage{ + InputTokens: oaiUsage.PromptTokens, + OutputTokens: oaiUsage.CompletionTokens, + CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens, + CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens, + }, + Delta: &dto.ClaudeMediaMessage{ + StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)), + }, + }) + } + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Type: "message_stop", + }) + } return claudeResponses } else { chosenChoice := openAIResponse.Choices[0] if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" { // should be done info.FinishReason = *chosenChoice.FinishReason - return claudeResponses + if !info.Done { + return claudeResponses + } } if info.Done { claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index)) @@ -389,22 +408,26 @@ func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relayco } for _, choice := range openAIResponse.Choices { stopReason = stopReasonOpenAI2Claude(choice.FinishReason) - claudeContent := dto.ClaudeMediaMessage{} if choice.FinishReason == "tool_calls" { - claudeContent.Type = "tool_use" - claudeContent.Id = choice.Message.ToolCallId - claudeContent.Name = choice.Message.ParseToolCalls()[0].Function.Name - var mapParams map[string]interface{} - if err := json.Unmarshal([]byte(choice.Message.ParseToolCalls()[0].Function.Arguments), &mapParams); err == nil { - claudeContent.Input = mapParams - } else { - claudeContent.Input = choice.Message.ParseToolCalls()[0].Function.Arguments + for _, toolUse := range choice.Message.ParseToolCalls() { + claudeContent := dto.ClaudeMediaMessage{} + claudeContent.Type = "tool_use" + claudeContent.Id = toolUse.ID + claudeContent.Name = toolUse.Function.Name + var mapParams map[string]interface{} + if err := common.Unmarshal([]byte(toolUse.Function.Arguments), &mapParams); err == nil { + claudeContent.Input = mapParams + } else { + claudeContent.Input = toolUse.Function.Arguments + } + contents = append(contents, claudeContent) } } else { + claudeContent := dto.ClaudeMediaMessage{} claudeContent.Type = "text" claudeContent.SetText(choice.Message.StringContent()) + contents = append(contents, claudeContent) } - contents = append(contents, claudeContent) } claudeResponse.Content = contents claudeResponse.StopReason = stopReason @@ -422,6 +445,8 @@ func stopReasonOpenAI2Claude(reason string) string { return "end_turn" case "stop_sequence": return "stop_sequence" + case "length": + fallthrough case "max_tokens": return "max_tokens" case "tool_calls": @@ -438,3 +463,353 @@ func toJSONString(v interface{}) string { } return string(b) } + +func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) { + openaiRequest := &dto.GeneralOpenAIRequest{ + Model: info.UpstreamModelName, + Stream: info.IsStream, + } + + // 转换 messages + var messages []dto.Message + for _, content := range geminiRequest.Contents { + message := dto.Message{ + Role: convertGeminiRoleToOpenAI(content.Role), + } + + // 处理 parts + var mediaContents []dto.MediaContent + var toolCalls []dto.ToolCallRequest + for _, part := range content.Parts { + if part.Text != "" { + mediaContent := dto.MediaContent{ + Type: "text", + Text: part.Text, + } + mediaContents = append(mediaContents, mediaContent) + } else if part.InlineData != nil { + mediaContent := dto.MediaContent{ + Type: "image_url", + ImageUrl: &dto.MessageImageUrl{ + Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data), + Detail: "auto", + MimeType: part.InlineData.MimeType, + }, + } + mediaContents = append(mediaContents, mediaContent) + } else if part.FileData != nil { + mediaContent := dto.MediaContent{ + Type: "image_url", + ImageUrl: &dto.MessageImageUrl{ + Url: part.FileData.FileUri, + Detail: "auto", + MimeType: part.FileData.MimeType, + }, + } + mediaContents = append(mediaContents, mediaContent) + } else if part.FunctionCall != nil { + // 处理 Gemini 的工具调用 + toolCall := dto.ToolCallRequest{ + ID: fmt.Sprintf("call_%d", len(toolCalls)+1), // 生成唯一ID + Type: "function", + Function: dto.FunctionRequest{ + Name: part.FunctionCall.FunctionName, + Arguments: toJSONString(part.FunctionCall.Arguments), + }, + } + toolCalls = append(toolCalls, toolCall) + } else if part.FunctionResponse != nil { + // 处理 Gemini 的工具响应,创建单独的 tool 消息 + toolMessage := dto.Message{ + Role: "tool", + ToolCallId: fmt.Sprintf("call_%d", len(toolCalls)), // 使用对应的调用ID + } + toolMessage.SetStringContent(toJSONString(part.FunctionResponse.Response)) + messages = append(messages, toolMessage) + } + } + + // 设置消息内容 + if len(toolCalls) > 0 { + // 如果有工具调用,设置工具调用 + message.SetToolCalls(toolCalls) + } else if len(mediaContents) == 1 && mediaContents[0].Type == "text" { + // 如果只有一个文本内容,直接设置字符串 + message.Content = mediaContents[0].Text + } else if len(mediaContents) > 0 { + // 如果有多个内容或包含媒体,设置为数组 + message.SetMediaContent(mediaContents) + } + + // 只有当消息有内容或工具调用时才添加 + if len(message.ParseContent()) > 0 || len(message.ToolCalls) > 0 { + messages = append(messages, message) + } + } + + openaiRequest.Messages = messages + + if geminiRequest.GenerationConfig.Temperature != nil { + openaiRequest.Temperature = geminiRequest.GenerationConfig.Temperature + } + if geminiRequest.GenerationConfig.TopP > 0 { + openaiRequest.TopP = geminiRequest.GenerationConfig.TopP + } + if geminiRequest.GenerationConfig.TopK > 0 { + openaiRequest.TopK = int(geminiRequest.GenerationConfig.TopK) + } + if geminiRequest.GenerationConfig.MaxOutputTokens > 0 { + openaiRequest.MaxTokens = geminiRequest.GenerationConfig.MaxOutputTokens + } + // gemini stop sequences 最多 5 个,openai stop 最多 4 个 + if len(geminiRequest.GenerationConfig.StopSequences) > 0 { + openaiRequest.Stop = geminiRequest.GenerationConfig.StopSequences[:4] + } + if geminiRequest.GenerationConfig.CandidateCount > 0 { + openaiRequest.N = geminiRequest.GenerationConfig.CandidateCount + } + + // 转换工具调用 + if len(geminiRequest.GetTools()) > 0 { + var tools []dto.ToolCallRequest + for _, tool := range geminiRequest.GetTools() { + if tool.FunctionDeclarations != nil { + // 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest + functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest) + if ok { + for _, function := range functionDeclarations { + openAITool := dto.ToolCallRequest{ + Type: "function", + Function: dto.FunctionRequest{ + Name: function.Name, + Description: function.Description, + Parameters: function.Parameters, + }, + } + tools = append(tools, openAITool) + } + } + } + } + if len(tools) > 0 { + openaiRequest.Tools = tools + } + } + + // gemini system instructions + if geminiRequest.SystemInstructions != nil { + // 将系统指令作为第一条消息插入 + systemMessage := dto.Message{ + Role: "system", + Content: extractTextFromGeminiParts(geminiRequest.SystemInstructions.Parts), + } + openaiRequest.Messages = append([]dto.Message{systemMessage}, openaiRequest.Messages...) + } + + return openaiRequest, nil +} + +func convertGeminiRoleToOpenAI(geminiRole string) string { + switch geminiRole { + case "user": + return "user" + case "model": + return "assistant" + case "function": + return "function" + default: + return "user" + } +} + +func extractTextFromGeminiParts(parts []dto.GeminiPart) string { + var texts []string + for _, part := range parts { + if part.Text != "" { + texts = append(texts, part.Text) + } + } + return strings.Join(texts, "\n") +} + +// ResponseOpenAI2Gemini 将 OpenAI 响应转换为 Gemini 格式 +func ResponseOpenAI2Gemini(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse { + geminiResponse := &dto.GeminiChatResponse{ + Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)), + PromptFeedback: dto.GeminiChatPromptFeedback{ + SafetyRatings: []dto.GeminiChatSafetyRating{}, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: openAIResponse.PromptTokens, + CandidatesTokenCount: openAIResponse.CompletionTokens, + TotalTokenCount: openAIResponse.PromptTokens + openAIResponse.CompletionTokens, + }, + } + + for _, choice := range openAIResponse.Choices { + candidate := dto.GeminiChatCandidate{ + Index: int64(choice.Index), + SafetyRatings: []dto.GeminiChatSafetyRating{}, + } + + // 设置结束原因 + var finishReason string + switch choice.FinishReason { + case "stop": + finishReason = "STOP" + case "length": + finishReason = "MAX_TOKENS" + case "content_filter": + finishReason = "SAFETY" + case "tool_calls": + finishReason = "STOP" + default: + finishReason = "STOP" + } + candidate.FinishReason = &finishReason + + // 转换消息内容 + content := dto.GeminiChatContent{ + Role: "model", + Parts: make([]dto.GeminiPart, 0), + } + + // 处理工具调用 + toolCalls := choice.Message.ParseToolCalls() + if len(toolCalls) > 0 { + for _, toolCall := range toolCalls { + // 解析参数 + var args map[string]interface{} + if toolCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil { + args = map[string]interface{}{"arguments": toolCall.Function.Arguments} + } + } else { + args = make(map[string]interface{}) + } + + part := dto.GeminiPart{ + FunctionCall: &dto.FunctionCall{ + FunctionName: toolCall.Function.Name, + Arguments: args, + }, + } + content.Parts = append(content.Parts, part) + } + } else { + // 处理文本内容 + textContent := choice.Message.StringContent() + if textContent != "" { + part := dto.GeminiPart{ + Text: textContent, + } + content.Parts = append(content.Parts, part) + } + } + + candidate.Content = content + geminiResponse.Candidates = append(geminiResponse.Candidates, candidate) + } + + return geminiResponse +} + +// StreamResponseOpenAI2Gemini 将 OpenAI 流式响应转换为 Gemini 格式 +func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse { + // 检查是否有实际内容或结束标志 + hasContent := false + hasFinishReason := false + for _, choice := range openAIResponse.Choices { + if len(choice.Delta.GetContentString()) > 0 || (choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0) { + hasContent = true + } + if choice.FinishReason != nil { + hasFinishReason = true + } + } + + // 如果没有实际内容且没有结束标志,跳过。主要针对 openai 流响应开头的空数据 + if !hasContent && !hasFinishReason { + return nil + } + + geminiResponse := &dto.GeminiChatResponse{ + Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)), + PromptFeedback: dto.GeminiChatPromptFeedback{ + SafetyRatings: []dto.GeminiChatSafetyRating{}, + }, + UsageMetadata: dto.GeminiUsageMetadata{ + PromptTokenCount: info.PromptTokens, + CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息 + TotalTokenCount: info.PromptTokens, + }, + } + + for _, choice := range openAIResponse.Choices { + candidate := dto.GeminiChatCandidate{ + Index: int64(choice.Index), + SafetyRatings: []dto.GeminiChatSafetyRating{}, + } + + // 设置结束原因 + if choice.FinishReason != nil { + var finishReason string + switch *choice.FinishReason { + case "stop": + finishReason = "STOP" + case "length": + finishReason = "MAX_TOKENS" + case "content_filter": + finishReason = "SAFETY" + case "tool_calls": + finishReason = "STOP" + default: + finishReason = "STOP" + } + candidate.FinishReason = &finishReason + } + + // 转换消息内容 + content := dto.GeminiChatContent{ + Role: "model", + Parts: make([]dto.GeminiPart, 0), + } + + // 处理工具调用 + if choice.Delta.ToolCalls != nil { + for _, toolCall := range choice.Delta.ToolCalls { + // 解析参数 + var args map[string]interface{} + if toolCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil { + args = map[string]interface{}{"arguments": toolCall.Function.Arguments} + } + } else { + args = make(map[string]interface{}) + } + + part := dto.GeminiPart{ + FunctionCall: &dto.FunctionCall{ + FunctionName: toolCall.Function.Name, + Arguments: args, + }, + } + content.Parts = append(content.Parts, part) + } + } else { + // 处理文本内容 + textContent := choice.Delta.GetContentString() + if textContent != "" { + part := dto.GeminiPart{ + Text: textContent, + } + content.Parts = append(content.Parts, part) + } + } + + candidate.Content = content + geminiResponse.Candidates = append(geminiResponse.Candidates, candidate) + } + + return geminiResponse +} diff --git a/service/error.go b/service/error.go index a0713b55b..ef5cbbde6 100644 --- a/service/error.go +++ b/service/error.go @@ -1,7 +1,6 @@ package service import ( - "encoding/json" "errors" "fmt" "io" @@ -63,7 +62,7 @@ func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeError text = "请求上游地址失败" } } - claudeError := dto.ClaudeError{ + claudeError := types.ClaudeError{ Message: text, Type: "new_api_error", } @@ -80,16 +79,13 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude } func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) { - newApiErr = &types.NewAPIError{ - StatusCode: resp.StatusCode, - ErrorType: types.ErrorTypeOpenAIError, - } + newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode) responseBody, err := io.ReadAll(resp.Body) if err != nil { return } - common.CloseResponseBodyGracefully(resp) + CloseResponseBodyGracefully(resp) var errResponse dto.GeneralErrorResponse err = common.Unmarshal(responseBody, &errResponse) @@ -97,6 +93,9 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t if showBodyWhenFail { newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)) } else { + if common.DebugEnabled { + println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))) + } newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode) } return @@ -105,8 +104,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t // General format error (OpenAI, Anthropic, Gemini, etc.) newApiErr = types.WithOpenAIError(errResponse.Error, resp.StatusCode) } else { - newApiErr = types.NewErrorWithStatusCode(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode) - newApiErr.ErrorType = types.ErrorTypeOpenAIError + newApiErr = types.NewOpenAIError(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode) } return } @@ -116,7 +114,7 @@ func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string) return } statusCodeMapping := make(map[string]string) - err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping) + err := common.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping) if err != nil { return } diff --git a/service/file_decoder.go b/service/file_decoder.go index c1d4fb0c0..99fdc3f9a 100644 --- a/service/file_decoder.go +++ b/service/file_decoder.go @@ -1,19 +1,148 @@ package service import ( + "bytes" "encoding/base64" "fmt" + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" "io" + "net/http" "one-api/common" "one-api/constant" - "one-api/dto" + "one-api/logger" + "one-api/types" "strings" + + "github.com/gin-gonic/gin" ) -func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) { +// GetFileTypeFromUrl 获取文件类型,返回 mime type, 例如 image/jpeg, image/png, image/gif, image/bmp, image/tiff, application/pdf +// 如果获取失败,返回 application/octet-stream +func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, error) { + response, err := DoDownloadRequest(url, []string{"get_mime_type", strings.Join(reason, ", ")}...) + if err != nil { + common.SysLog(fmt.Sprintf("fail to get file type from url: %s, error: %s", url, err.Error())) + return "", err + } + defer response.Body.Close() + + if response.StatusCode != 200 { + logger.LogError(c, fmt.Sprintf("failed to download file from %s, status code: %d", url, response.StatusCode)) + return "", fmt.Errorf("failed to download file, status code: %d", response.StatusCode) + } + + if headerType := strings.TrimSpace(response.Header.Get("Content-Type")); headerType != "" { + if i := strings.Index(headerType, ";"); i != -1 { + headerType = headerType[:i] + } + if headerType != "application/octet-stream" { + return headerType, nil + } + } + + if cd := response.Header.Get("Content-Disposition"); cd != "" { + parts := strings.Split(cd, ";") + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(strings.ToLower(part), "filename=") { + name := strings.TrimSpace(strings.TrimPrefix(part, "filename=")) + if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' { + name = name[1 : len(name)-1] + } + if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) { + ext := strings.ToLower(name[dot+1:]) + if ext != "" { + mt := GetMimeTypeByExtension(ext) + if mt != "application/octet-stream" { + return mt, nil + } + } + } + break + } + } + } + + cleanedURL := url + if q := strings.Index(cleanedURL, "?"); q != -1 { + cleanedURL = cleanedURL[:q] + } + if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) { + last := cleanedURL[slash+1:] + if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) { + ext := strings.ToLower(last[dot+1:]) + if ext != "" { + mt := GetMimeTypeByExtension(ext) + if mt != "application/octet-stream" { + return mt, nil + } + } + } + } + + var readData []byte + limits := []int{512, 8 * 1024, 24 * 1024, 64 * 1024} + for _, limit := range limits { + logger.LogDebug(c, fmt.Sprintf("Trying to read %d bytes to determine file type", limit)) + if len(readData) < limit { + need := limit - len(readData) + tmp := make([]byte, need) + n, _ := io.ReadFull(response.Body, tmp) + if n > 0 { + readData = append(readData, tmp[:n]...) + } + } + + if len(readData) == 0 { + continue + } + + sniffed := http.DetectContentType(readData) + if sniffed != "" && sniffed != "application/octet-stream" { + return sniffed, nil + } + + if _, format, err := image.DecodeConfig(bytes.NewReader(readData)); err == nil { + switch strings.ToLower(format) { + case "jpeg", "jpg": + return "image/jpeg", nil + case "png": + return "image/png", nil + case "gif": + return "image/gif", nil + case "bmp": + return "image/bmp", nil + case "tiff": + return "image/tiff", nil + default: + if format != "" { + return "image/" + strings.ToLower(format), nil + } + } + } + } + + // Fallback + return "application/octet-stream", nil +} + +func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) { + contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url)) + + // Check if the file has already been downloaded in this request + if cachedData, exists := c.Get(contextKey); exists { + if common.DebugEnabled { + logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url)) + } + return cachedData.(*types.LocalFileData), nil + } + var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024 - resp, err := DoDownloadRequest(url) + resp, err := DoDownloadRequest(url, reason...) if err != nil { return nil, err } @@ -38,9 +167,7 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) { mimeType = strings.Split(mimeType, ";")[0] } if mimeType == "application/octet-stream" { - if common.DebugEnabled { - println("MIME type is application/octet-stream, trying to guess from URL or filename") - } + logger.LogDebug(c, fmt.Sprintf("MIME type is application/octet-stream for URL: %s", url)) // try to guess the MIME type from the url last segment urlParts := strings.Split(url, "/") if len(urlParts) > 0 { @@ -77,12 +204,15 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) { } } } - - return &dto.LocalFileData{ + data := &types.LocalFileData{ Base64Data: base64Data, MimeType: mimeType, Size: int64(len(fileBytes)), - }, nil + } + // Store the file data in the context to avoid re-downloading + c.Set(contextKey, data) + + return data, nil } func GetMimeTypeByExtension(ext string) string { diff --git a/common/http.go b/service/http.go similarity index 86% rename from common/http.go rename to service/http.go index d2e824efb..357a2e788 100644 --- a/common/http.go +++ b/service/http.go @@ -1,10 +1,12 @@ -package common +package service import ( "bytes" "fmt" "io" "net/http" + "one-api/common" + "one-api/logger" "github.com/gin-gonic/gin" ) @@ -15,7 +17,7 @@ func CloseResponseBodyGracefully(httpResponse *http.Response) { } err := httpResponse.Body.Close() if err != nil { - SysError("failed to close response body: " + err.Error()) + common.SysError("failed to close response body: " + err.Error()) } } @@ -52,6 +54,6 @@ func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) { _, err := io.Copy(c.Writer, body) if err != nil { - LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error())) + logger.LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error())) } } diff --git a/service/image.go b/service/image.go index 252093f1f..453d8dd1c 100644 --- a/service/image.go +++ b/service/image.go @@ -21,6 +21,10 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e base64String = base64String[idx+1:] } + if len(base64String) == 0 { + return image.Config{}, "", "", errors.New("base64 string is empty") + } + // 将base64字符串解码为字节切片 decodedData, err := base64.StdEncoding.DecodeString(base64String) if err != nil { diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 020a2ba9f..7a609c9f5 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -5,7 +5,7 @@ import ( "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" - "one-api/relay/helper" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -28,6 +28,12 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m other["is_model_mapped"] = true other["upstream_model_name"] = relayInfo.UpstreamModelName } + + isSystemPromptOverwritten := common.GetContextKeyBool(ctx, constant.ContextKeySystemPromptOverride) + if isSystemPromptOverwritten { + other["is_system_prompt_overwritten"] = true + } + adminInfo := make(map[string]interface{}) adminInfo["use_channel"] = ctx.GetStringSlice("use_channel") isMultiKey := common.GetContextKeyBool(ctx, constant.ContextKeyChannelIsMultiKey) @@ -72,7 +78,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, return info } -func GenerateMjOtherInfo(priceData helper.PerCallPriceData) map[string]interface{} { +func GenerateMjOtherInfo(priceData types.PerCallPriceData) map[string]interface{} { other := make(map[string]interface{}) other["model_price"] = priceData.ModelPrice other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio diff --git a/service/midjourney.go b/service/midjourney.go index 1fc196822..916d02d0b 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -212,7 +212,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU defer cancel() resp, err := GetHttpClient().Do(req) if err != nil { - common.SysError("do request failed: " + err.Error()) + common.SysLog("do request failed: " + err.Error()) return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err } statusCode := resp.StatusCode @@ -233,7 +233,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err } - common.CloseResponseBodyGracefully(resp) + CloseResponseBodyGracefully(resp) respStr := string(responseBody) log.Printf("respStr: %s", respStr) if respStr == "" { diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go new file mode 100644 index 000000000..86b04e526 --- /dev/null +++ b/service/pre_consume_quota.go @@ -0,0 +1,78 @@ +package service + +import ( + "fmt" + "net/http" + "one-api/common" + "one-api/logger" + "one-api/model" + relaycommon "one-api/relay/common" + "one-api/types" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" +) + +func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) { + if preConsumedQuota != 0 { + logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota))) + gopool.Go(func() { + relayInfoCopy := *relayInfo + + err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) + if err != nil { + common.SysLog("error return pre-consumed quota: " + err.Error()) + } + }) + } +} + +// PreConsumeQuota checks if the user has enough quota to pre-consume. +// It returns the pre-consumed quota if successful, or an error if not. +func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) { + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) + if err != nil { + return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) + } + if userQuota <= 0 { + return 0, types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + if userQuota-preConsumedQuota < 0 { + return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + + trustQuota := common.GetTrustQuota() + + relayInfo.UserQuota = userQuota + if userQuota > trustQuota { + // 用户额度充足,判断令牌额度是否充足 + if !relayInfo.TokenUnlimited { + // 非无限令牌,判断令牌额度是否充足 + tokenQuota := c.GetInt("token_quota") + if tokenQuota > trustQuota { + // 令牌额度充足,信任令牌 + preConsumedQuota = 0 + logger.LogInfo(c, fmt.Sprintf("用户 %d 剩余额度 %s 且令牌 %d 额度 %d 充足, 信任且不需要预扣费", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) + } + } else { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足且为无限额度令牌, 信任且不需要预扣费", relayInfo.UserId)) + } + } + + if preConsumedQuota > 0 { + err := PreConsumeTokenQuota(relayInfo, preConsumedQuota) + if err != nil { + return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) + if err != nil { + return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) + } + logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota))) + } + relayInfo.FinalPreConsumedQuota = preConsumedQuota + return preConsumedQuota, nil +} diff --git a/service/quota.go b/service/quota.go index 0f618402f..e078a1ad1 100644 --- a/service/quota.go +++ b/service/quota.go @@ -8,11 +8,12 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" - "one-api/relay/helper" "one-api/setting" "one-api/setting/ratio_setting" + "one-api/types" "strings" "time" @@ -37,6 +38,14 @@ type QuotaInfo struct { GroupRatio float64 } +func hasCustomModelRatio(modelName string, currentRatio float64) bool { + defaultRatio, exists := ratio_setting.GetDefaultModelRatioMap()[modelName] + if !exists { + return true + } + return currentRatio != defaultRatio +} + func calculateAudioQuota(info QuotaInfo) int { if info.UsePrice { modelPrice := decimal.NewFromFloat(info.ModelPrice) @@ -129,23 +138,23 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag quota := calculateAudioQuota(quotaInfo) if userQuota < quota { - return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)) + return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(quota)) } if !token.UnlimitedQuota && token.RemainQuota < quota { - return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota)) + return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota)) } err = PostConsumeQuota(relayInfo, quota, 0, false) if err != nil { return err } - common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota)) + logger.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota)) return nil } func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, - usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { + usage *dto.RealtimeUsage, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.InputTokenDetails.TextTokens @@ -159,10 +168,10 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName)) audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName)) - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice - usePrice := priceData.UsePrice + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice + usePrice := relayInfo.PriceData.UsePrice quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -196,8 +205,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) } else { model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) @@ -208,7 +217,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod logContent += ", " + extraContent } other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, - completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: usage.InputTokens, @@ -218,7 +227,6 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, @@ -226,8 +234,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod }) } -func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { +func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens @@ -235,21 +242,22 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") - completionRatio := priceData.CompletionRatio - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice - cacheRatio := priceData.CacheRatio + completionRatio := relayInfo.PriceData.CompletionRatio + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice + cacheRatio := relayInfo.PriceData.CacheRatio cacheTokens := usage.PromptTokensDetails.CachedTokens - cacheCreationRatio := priceData.CacheCreationRatio + cacheCreationRatio := relayInfo.PriceData.CacheCreationRatio cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens if relayInfo.ChannelType == constant.ChannelTypeOpenRouter { promptTokens -= cacheTokens - if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 { - maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData) - if promptTokens >= maybeCacheCreationTokens { + isUsingCustomSettings := relayInfo.PriceData.UsePrice || hasCustomModelRatio(modelName, relayInfo.PriceData.ModelRatio) + if cacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings { + maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData) + if maybeCacheCreationTokens >= 0 && promptTokens >= maybeCacheCreationTokens { cacheCreationTokens = maybeCacheCreationTokens } } @@ -257,7 +265,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } calculateQuota := 0.0 - if !priceData.UsePrice { + if !relayInfo.PriceData.UsePrice { calculateQuota = float64(promptTokens) calculateQuota += float64(cacheTokens) * cacheRatio calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio @@ -282,23 +290,38 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游出错)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) } else { model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - preConsumedQuota + quotaDelta := quota - relayInfo.FinalPreConsumedQuota + + if quotaDelta > 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } else if quotaDelta < 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(-quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } + if quotaDelta != 0 { - err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + logger.LogError(ctx, "error consuming token remain quota: "+err.Error()) } } other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, - cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: promptTokens, @@ -308,7 +331,6 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, @@ -317,7 +339,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } -func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int { +func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int { if priceData.CacheCreationRatio == 1 { return 0 } @@ -338,8 +360,7 @@ func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData (promptCacheCreatePrice - quotaPrice))) } -func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { +func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.PromptTokensDetails.TextTokens @@ -353,10 +374,10 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName)) audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)) - modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatioInfo.GroupRatio - modelPrice := priceData.ModelPrice - usePrice := priceData.UsePrice + modelRatio := relayInfo.PriceData.ModelRatio + groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio + modelPrice := relayInfo.PriceData.ModelPrice + usePrice := relayInfo.PriceData.UsePrice quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -390,18 +411,33 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota)) + logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, relayInfo.FinalPreConsumedQuota)) } else { model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - preConsumedQuota + quotaDelta := quota - relayInfo.FinalPreConsumedQuota + + if quotaDelta > 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } else if quotaDelta < 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(-quotaDelta), + logger.FormatQuota(quota), + logger.FormatQuota(relayInfo.FinalPreConsumedQuota), + )) + } + if quotaDelta != 0 { - err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) + err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + logger.LogError(ctx, "error consuming token remain quota: "+err.Error()) } } @@ -410,7 +446,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, logContent += ", " + extraContent } other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, - completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: usage.PromptTokens, @@ -420,7 +456,6 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, @@ -443,7 +478,7 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { return err } if !relayInfo.TokenUnlimited && token.RemainQuota < quota { - return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota)) + return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota)) } err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) if err != nil { @@ -500,8 +535,27 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon if quotaTooLow { prompt := "您的额度即将用尽" topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) - content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}" - err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink})) + + // 根据通知方式生成不同的内容格式 + var content string + var values []interface{} + + notifyType := userSetting.NotifyType + if notifyType == "" { + notifyType = dto.NotifyTypeEmail + } + + if notifyType == dto.NotifyTypeBark { + // Bark推送使用简短文本,不支持HTML + content = "{{value}},剩余额度:{{value}},请及时充值" + values = []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota)} + } else { + // 默认内容格式,适用于Email和Webhook + content = "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}" + values = []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink} + } + + err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, values)) if err != nil { common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error())) } diff --git a/service/sensitive.go b/service/sensitive.go index b3e3c4d66..ed033daac 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -2,7 +2,6 @@ package service import ( "errors" - "fmt" "one-api/dto" "one-api/setting" "strings" @@ -32,25 +31,8 @@ func CheckSensitiveMessages(messages []dto.Message) ([]string, error) { return nil, nil } -func CheckSensitiveText(text string) ([]string, error) { - if ok, words := SensitiveWordContains(text); ok { - return words, errors.New("sensitive words detected") - } - return nil, nil -} - -func CheckSensitiveInput(input any) ([]string, error) { - switch v := input.(type) { - case string: - return CheckSensitiveText(v) - case []string: - var builder strings.Builder - for _, s := range v { - builder.WriteString(s) - } - return CheckSensitiveText(builder.String()) - } - return CheckSensitiveText(fmt.Sprintf("%v", input)) +func CheckSensitiveText(text string) (bool, []string) { + return SensitiveWordContains(text) } // SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表 @@ -71,7 +53,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, return false, nil, text } checkText := strings.ToLower(text) - m := InitAc(setting.SensitiveWords) + m := getOrBuildAC(setting.SensitiveWords) hits := m.MultiPatternSearch([]rune(checkText), returnImmediately) if len(hits) > 0 { words := make([]string, 0, len(hits)) diff --git a/service/str.go b/service/str.go index 4390e99be..61054bdc4 100644 --- a/service/str.go +++ b/service/str.go @@ -3,8 +3,12 @@ package service import ( "bytes" "fmt" - goahocorasick "github.com/anknown/ahocorasick" + "hash/fnv" + "sort" "strings" + "sync" + + goahocorasick "github.com/anknown/ahocorasick" ) func SundaySearch(text string, pattern string) bool { @@ -56,26 +60,73 @@ func RemoveDuplicate(s []string) []string { return result } -func InitAc(words []string) *goahocorasick.Machine { +func InitAc(dict []string) *goahocorasick.Machine { m := new(goahocorasick.Machine) - dict := readRunes(words) - if err := m.Build(dict); err != nil { + runes := readRunes(dict) + if err := m.Build(runes); err != nil { fmt.Println(err) return nil } return m } -func readRunes(words []string) [][]rune { - var dict [][]rune +var acCache sync.Map - for _, word := range words { +func acKey(dict []string) string { + if len(dict) == 0 { + return "" + } + normalized := make([]string, 0, len(dict)) + for _, w := range dict { + w = strings.ToLower(strings.TrimSpace(w)) + if w != "" { + normalized = append(normalized, w) + } + } + if len(normalized) == 0 { + return "" + } + sort.Strings(normalized) + hasher := fnv.New64a() + for _, w := range normalized { + hasher.Write([]byte{0}) + hasher.Write([]byte(w)) + } + return fmt.Sprintf("%x", hasher.Sum64()) +} + +func getOrBuildAC(dict []string) *goahocorasick.Machine { + key := acKey(dict) + if key == "" { + return nil + } + if v, ok := acCache.Load(key); ok { + if m, ok2 := v.(*goahocorasick.Machine); ok2 { + return m + } + } + m := InitAc(dict) + if m == nil { + return nil + } + if actual, loaded := acCache.LoadOrStore(key, m); loaded { + if cached, ok := actual.(*goahocorasick.Machine); ok { + return cached + } + } + return m +} + +func readRunes(dict []string) [][]rune { + var runes [][]rune + + for _, word := range dict { word = strings.ToLower(word) l := bytes.TrimSpace([]byte(word)) - dict = append(dict, bytes.Runes(l)) + runes = append(runes, bytes.Runes(l)) } - return dict + return runes } func AcSearch(findText string, dict []string, stopImmediately bool) (bool, []string) { @@ -85,7 +136,7 @@ func AcSearch(findText string, dict []string, stopImmediately bool) (bool, []str if len(findText) == 0 { return false, nil } - m := InitAc(dict) + m := getOrBuildAC(dict) if m == nil { return false, nil } diff --git a/service/token_counter.go b/service/token_counter.go index eed5b5ca0..da56523fe 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -4,18 +4,24 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tiktoken-go/tokenizer" - "github.com/tiktoken-go/tokenizer/codec" "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" "log" "math" "one-api/common" "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/types" "strings" "sync" "unicode/utf8" + + "github.com/gin-gonic/gin" + "github.com/tiktoken-go/tokenizer" + "github.com/tiktoken-go/tokenizer/codec" ) // tokenEncoderMap won't grow after initialization @@ -72,52 +78,101 @@ func getTokenNum(tokenEncoder tokenizer.Codec, text string) int { return tkm } -func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) { - if imageUrl == nil { +func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) { + if fileMeta == nil { return 0, fmt.Errorf("image_url_is_nil") } + + // Defaults for 4o/4.1/4.5 family unless overridden below baseTokens := 85 - if model == "glm-4v" { + tileTokens := 170 + + // Model classification + lowerModel := strings.ToLower(model) + + // Special cases from existing behavior + if strings.HasPrefix(lowerModel, "glm-4") { return 1047, nil } - if imageUrl.Detail == "low" { + + // Patch-based models (32x32 patches, capped at 1536, with multiplier) + isPatchBased := false + multiplier := 1.0 + switch { + case strings.Contains(lowerModel, "gpt-4.1-mini"): + isPatchBased = true + multiplier = 1.62 + case strings.Contains(lowerModel, "gpt-4.1-nano"): + isPatchBased = true + multiplier = 2.46 + case strings.HasPrefix(lowerModel, "o4-mini"): + isPatchBased = true + multiplier = 1.72 + case strings.HasPrefix(lowerModel, "gpt-5-mini"): + isPatchBased = true + multiplier = 1.62 + case strings.HasPrefix(lowerModel, "gpt-5-nano"): + isPatchBased = true + multiplier = 2.46 + } + + // Tile-based model tokens and bases per doc + if !isPatchBased { + if strings.HasPrefix(lowerModel, "gpt-4o-mini") { + baseTokens = 2833 + tileTokens = 5667 + } else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) { + baseTokens = 70 + tileTokens = 140 + } else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") { + baseTokens = 75 + tileTokens = 150 + } else if strings.Contains(lowerModel, "computer-use-preview") { + baseTokens = 65 + tileTokens = 129 + } else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") { + baseTokens = 85 + tileTokens = 170 + } + } + + // Respect existing feature flags/short-circuits + if fileMeta.Detail == "low" && !isPatchBased { return baseTokens, nil } if !constant.GetMediaTokenNotStream && !stream { return 3 * baseTokens, nil } - - // 同步One API的图片计费逻辑 - if imageUrl.Detail == "auto" || imageUrl.Detail == "" { - imageUrl.Detail = "high" + // Normalize detail + if fileMeta.Detail == "auto" || fileMeta.Detail == "" { + fileMeta.Detail = "high" } - - tileTokens := 170 - if strings.HasPrefix(model, "gpt-4o-mini") { - tileTokens = 5667 - baseTokens = 2833 - } - // 是否统计图片token + // Whether to count image tokens at all if !constant.GetMediaToken { return 3 * baseTokens, nil } - if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic { - return 3 * baseTokens, nil - } + + // Decode image to get dimensions var config image.Config var err error var format string var b64str string - if strings.HasPrefix(imageUrl.Url, "http") { - config, format, err = DecodeUrlImageData(imageUrl.Url) + + if fileMeta.ParsedData != nil { + config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data) } else { - common.SysLog(fmt.Sprintf("decoding image")) - config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url) + if strings.HasPrefix(fileMeta.OriginData, "http") { + config, format, err = DecodeUrlImageData(fileMeta.OriginData) + } else { + common.SysLog(fmt.Sprintf("decoding image")) + config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData) + } + fileMeta.MimeType = format } + if err != nil { return 0, err } - imageUrl.MimeType = format if config.Width == 0 || config.Height == 0 { // not an image @@ -125,57 +180,183 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m // file type return 3 * baseTokens, nil } - return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url)) + return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.OriginData)) } - shortSide := config.Width - otherSide := config.Height - log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height) - // 缩放倍数 - scale := 1.0 - if config.Height < shortSide { - shortSide = config.Height - otherSide = config.Width + width := config.Width + height := config.Height + log.Printf("format: %s, width: %d, height: %d", format, width, height) + + if isPatchBased { + // 32x32 patch-based calculation with 1536 cap and model multiplier + ceilDiv := func(a, b int) int { return (a + b - 1) / b } + rawPatchesW := ceilDiv(width, 32) + rawPatchesH := ceilDiv(height, 32) + rawPatches := rawPatchesW * rawPatchesH + if rawPatches > 1536 { + // scale down + area := float64(width * height) + r := math.Sqrt(float64(32*32*1536) / area) + wScaled := float64(width) * r + hScaled := float64(height) * r + // adjust to fit whole number of patches after scaling + adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0) + adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0) + adj := math.Min(adjW, adjH) + if !math.IsNaN(adj) && adj > 0 { + r = r * adj + } + wScaled = float64(width) * r + hScaled = float64(height) * r + patchesW := math.Ceil(wScaled / 32.0) + patchesH := math.Ceil(hScaled / 32.0) + imageTokens := int(patchesW * patchesH) + if imageTokens > 1536 { + imageTokens = 1536 + } + return int(math.Round(float64(imageTokens) * multiplier)), nil + } + // below cap + imageTokens := rawPatches + return int(math.Round(float64(imageTokens) * multiplier)), nil } - // 将最小变的尺寸缩小到768以下,如果大于768,则缩放到768 - if shortSide > 768 { - scale = float64(shortSide) / 768 - shortSide = 768 + // Tile-based calculation for 4o/4.1/4.5/o1/o3/etc. + // Step 1: fit within 2048x2048 square + maxSide := math.Max(float64(width), float64(height)) + fitScale := 1.0 + if maxSide > 2048 { + fitScale = maxSide / 2048.0 } - // 将另一边按照相同的比例缩小,向上取整 - otherSide = int(math.Ceil(float64(otherSide) / scale)) - log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale) - // 计算图片的token数量(边的长度除以512,向上取整) - tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512) - log.Printf("tiles: %d", tiles) + fitW := int(math.Round(float64(width) / fitScale)) + fitH := int(math.Round(float64(height) / fitScale)) + + // Step 2: scale so that shortest side is exactly 768 + minSide := math.Min(float64(fitW), float64(fitH)) + if minSide == 0 { + return baseTokens, nil + } + shortScale := 768.0 / minSide + finalW := int(math.Round(float64(fitW) * shortScale)) + finalH := int(math.Round(float64(fitH) * shortScale)) + + // Count 512px tiles + tilesW := (finalW + 512 - 1) / 512 + tilesH := (finalH + 512 - 1) / 512 + tiles := tilesW * tilesH + + if common.DebugEnabled { + log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles) + } + return tiles*tileTokens + baseTokens, nil } -func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) { - tkm := 0 - msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream) - if err != nil { - return 0, err +func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) { + if !constant.GetMediaToken { + return 0, nil } - tkm += msgTokens - if request.Tools != nil { - openaiTools := request.Tools - countStr := "" - for _, tool := range openaiTools { - countStr = tool.Function.Name - if tool.Function.Description != "" { - countStr += tool.Function.Description - } - if tool.Function.Parameters != nil { - countStr += fmt.Sprintf("%v", tool.Function.Parameters) - } - } - toolTokens := CountTokenInput(countStr, request.Model) - tkm += 8 - tkm += toolTokens + if !constant.GetMediaTokenNotStream && !info.IsStream { + return 0, nil + } + if info.RelayFormat == types.RelayFormatOpenAIRealtime { + return 0, nil + } + if meta == nil { + return 0, errors.New("token count meta is nil") } + model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel) + tkm := 0 + + if meta.TokenType == types.TokenTypeTextNumber { + tkm += utf8.RuneCountInString(meta.CombineText) + } else { + tkm += CountTextToken(meta.CombineText, model) + } + + if info.RelayFormat == types.RelayFormatOpenAI { + tkm += meta.ToolsCount * 8 + tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量 + tkm += meta.NameCount * 3 + tkm += 3 + } + + shouldFetchFiles := true + + if info.RelayFormat == types.RelayFormatGemini { + shouldFetchFiles = false + } + + if shouldFetchFiles { + for _, file := range meta.Files { + if strings.HasPrefix(file.OriginData, "http") { + mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter") + if err != nil { + return 0, fmt.Errorf("error getting file base64 from url: %v", err) + } + if strings.HasPrefix(mineType, "image/") { + file.FileType = types.FileTypeImage + } else if strings.HasPrefix(mineType, "video/") { + file.FileType = types.FileTypeVideo + } else if strings.HasPrefix(mineType, "audio/") { + file.FileType = types.FileTypeAudio + } else { + file.FileType = types.FileTypeFile + } + file.MimeType = mineType + } else if strings.HasPrefix(file.OriginData, "data:") { + // get mime type from base64 header + parts := strings.SplitN(file.OriginData, ",", 2) + if len(parts) >= 1 { + header := parts[0] + // Extract mime type from "data:mime/type;base64" format + if strings.Contains(header, ":") && strings.Contains(header, ";") { + mimeStart := strings.Index(header, ":") + 1 + mimeEnd := strings.Index(header, ";") + if mimeStart < mimeEnd { + mineType := header[mimeStart:mimeEnd] + if strings.HasPrefix(mineType, "image/") { + file.FileType = types.FileTypeImage + } else if strings.HasPrefix(mineType, "video/") { + file.FileType = types.FileTypeVideo + } else if strings.HasPrefix(mineType, "audio/") { + file.FileType = types.FileTypeAudio + } else { + file.FileType = types.FileTypeFile + } + file.MimeType = mineType + } + } + } + } + } + } + + for i, file := range meta.Files { + switch file.FileType { + case types.FileTypeImage: + if info.RelayFormat == types.RelayFormatGemini && !strings.HasPrefix(model, "gemini-2.5-flash-image-preview") { + tkm += 256 + } else { + token, err := getImageToken(file, model, info.IsStream) + if err != nil { + return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err) + } + tkm += token + } + case types.FileTypeAudio: + tkm += 256 + case types.FileTypeVideo: + tkm += 4096 * 2 + case types.FileTypeFile: + tkm += 4096 + default: + tkm += 4096 // Default case for unknown file types + } + } + + common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm) return tkm, nil } @@ -338,59 +519,6 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, return textToken, audioToken, nil } -func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) { - //recover when panic - tokenEncoder := getTokenEncoder(model) - // Reference: - // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - // https://github.com/pkoukk/tiktoken-go/issues/6 - // - // Every message follows <|start|>{role/name}\n{content}<|end|>\n - var tokensPerMessage int - var tokensPerName int - if model == "gpt-3.5-turbo-0301" { - tokensPerMessage = 4 - tokensPerName = -1 // If there's a name, the role is omitted - } else { - tokensPerMessage = 3 - tokensPerName = 1 - } - tokenNum := 0 - for _, message := range messages { - tokenNum += tokensPerMessage - tokenNum += getTokenNum(tokenEncoder, message.Role) - if message.Content != nil { - if message.Name != nil { - tokenNum += tokensPerName - tokenNum += getTokenNum(tokenEncoder, *message.Name) - } - arrayContent := message.ParseContent() - for _, m := range arrayContent { - if m.Type == dto.ContentTypeImageURL { - imageUrl := m.GetImageMedia() - imageTokenNum, err := getImageToken(info, imageUrl, model, stream) - if err != nil { - return 0, err - } - tokenNum += imageTokenNum - log.Printf("image token num: %d", imageTokenNum) - } else if m.Type == dto.ContentTypeInputAudio { - // TODO: 音频token数量计算 - tokenNum += 100 - } else if m.Type == dto.ContentTypeFile { - tokenNum += 5000 - } else if m.Type == dto.ContentTypeVideoUrl { - tokenNum += 5000 - } else { - tokenNum += getTokenNum(tokenEncoder, m.Text) - } - } - } - } - tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> - return tokenNum, nil -} - func CountTokenInput(input any, model string) int { switch v := input.(type) { case string: diff --git a/service/user_notify.go b/service/user_notify.go index 966640076..c4a3ea91f 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -2,9 +2,12 @@ package service import ( "fmt" + "net/http" + "net/url" "one-api/common" "one-api/dto" "one-api/model" + "one-api/setting" "strings" ) @@ -12,7 +15,7 @@ func NotifyRootUser(t string, subject string, content string) { user := model.GetRootUser().ToBaseUser() err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil)) if err != nil { - common.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error())) + common.SysLog(fmt.Sprintf("failed to notify root user: %s", err.Error())) } } @@ -25,7 +28,7 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data // Check notification limit canSend, err := CheckNotificationLimit(userId, data.Type) if err != nil { - common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error())) + common.SysLog(fmt.Sprintf("failed to check notification limit: %s", err.Error())) return err } if !canSend { @@ -44,13 +47,20 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data case dto.NotifyTypeWebhook: webhookURLStr := userSetting.WebhookUrl if webhookURLStr == "" { - common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) + common.SysLog(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) return nil } // 获取 webhook secret webhookSecret := userSetting.WebhookSecret return SendWebhookNotify(webhookURLStr, webhookSecret, data) + case dto.NotifyTypeBark: + barkURL := userSetting.BarkUrl + if barkURL == "" { + common.SysLog(fmt.Sprintf("user %d has no bark url, skip sending bark", userId)) + return nil + } + return sendBarkNotify(barkURL, data) } return nil } @@ -64,3 +74,67 @@ func sendEmailNotify(userEmail string, data dto.Notify) error { } return common.SendEmail(data.Title, userEmail, content) } + +func sendBarkNotify(barkURL string, data dto.Notify) error { + // 处理占位符 + content := data.Content + for _, value := range data.Values { + content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1) + } + + // 替换模板变量 + finalURL := strings.ReplaceAll(barkURL, "{{title}}", url.QueryEscape(data.Title)) + finalURL = strings.ReplaceAll(finalURL, "{{content}}", url.QueryEscape(content)) + + // 发送GET请求到Bark + var req *http.Request + var resp *http.Response + var err error + + if setting.EnableWorker() { + // 使用worker发送请求 + workerReq := &WorkerRequest{ + URL: finalURL, + Key: setting.WorkerValidKey, + Method: http.MethodGet, + Headers: map[string]string{ + "User-Agent": "OneAPI-Bark-Notify/1.0", + }, + } + + resp, err = DoWorkerRequest(workerReq) + if err != nil { + return fmt.Errorf("failed to send bark request through worker: %v", err) + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode) + } + } else { + // 直接发送请求 + req, err = http.NewRequest(http.MethodGet, finalURL, nil) + if err != nil { + return fmt.Errorf("failed to create bark request: %v", err) + } + + // 设置User-Agent + req.Header.Set("User-Agent", "OneAPI-Bark-Notify/1.0") + + // 发送请求 + client := GetHttpClient() + resp, err = client.Do(req) + if err != nil { + return fmt.Errorf("failed to send bark request: %v", err) + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode) + } + } + + return nil +} diff --git a/setting/chat.go b/setting/chat.go index 53cb655a2..bd1e26e30 100644 --- a/setting/chat.go +++ b/setting/chat.go @@ -12,6 +12,9 @@ var Chats = []map[string]string{ { "Cherry Studio": "cherrystudio://providers/api-keys?v=1&data={cherryConfig}", }, + { + "流畅阅读": "fluentread", + }, { "Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}", }, @@ -34,7 +37,7 @@ func UpdateChatsByJsonString(jsonString string) error { func Chats2JsonString() string { jsonBytes, err := json.Marshal(Chats) if err != nil { - common.SysError("error marshalling chats: " + err.Error()) + common.SysLog("error marshalling chats: " + err.Error()) return "[]" } return string(jsonBytes) diff --git a/setting/console_setting/config.go b/setting/console_setting/config.go index 6327e5584..8cfcd0ed6 100644 --- a/setting/console_setting/config.go +++ b/setting/console_setting/config.go @@ -3,37 +3,37 @@ package console_setting import "one-api/setting/config" type ConsoleSetting struct { - ApiInfo string `json:"api_info"` // 控制台 API 信息 (JSON 数组字符串) - UptimeKumaGroups string `json:"uptime_kuma_groups"` // Uptime Kuma 分组配置 (JSON 数组字符串) - Announcements string `json:"announcements"` // 系统公告 (JSON 数组字符串) - FAQ string `json:"faq"` // 常见问题 (JSON 数组字符串) - ApiInfoEnabled bool `json:"api_info_enabled"` // 是否启用 API 信息面板 - UptimeKumaEnabled bool `json:"uptime_kuma_enabled"` // 是否启用 Uptime Kuma 面板 - AnnouncementsEnabled bool `json:"announcements_enabled"` // 是否启用系统公告面板 - FAQEnabled bool `json:"faq_enabled"` // 是否启用常见问答面板 + ApiInfo string `json:"api_info"` // 控制台 API 信息 (JSON 数组字符串) + UptimeKumaGroups string `json:"uptime_kuma_groups"` // Uptime Kuma 分组配置 (JSON 数组字符串) + Announcements string `json:"announcements"` // 系统公告 (JSON 数组字符串) + FAQ string `json:"faq"` // 常见问题 (JSON 数组字符串) + ApiInfoEnabled bool `json:"api_info_enabled"` // 是否启用 API 信息面板 + UptimeKumaEnabled bool `json:"uptime_kuma_enabled"` // 是否启用 Uptime Kuma 面板 + AnnouncementsEnabled bool `json:"announcements_enabled"` // 是否启用系统公告面板 + FAQEnabled bool `json:"faq_enabled"` // 是否启用常见问答面板 } // 默认配置 var defaultConsoleSetting = ConsoleSetting{ - ApiInfo: "", - UptimeKumaGroups: "", - Announcements: "", - FAQ: "", - ApiInfoEnabled: true, - UptimeKumaEnabled: true, - AnnouncementsEnabled: true, - FAQEnabled: true, + ApiInfo: "", + UptimeKumaGroups: "", + Announcements: "", + FAQ: "", + ApiInfoEnabled: true, + UptimeKumaEnabled: true, + AnnouncementsEnabled: true, + FAQEnabled: true, } // 全局实例 var consoleSetting = defaultConsoleSetting func init() { - // 注册到全局配置管理器,键名为 console_setting - config.GlobalConfig.Register("console_setting", &consoleSetting) + // 注册到全局配置管理器,键名为 console_setting + config.GlobalConfig.Register("console_setting", &consoleSetting) } // GetConsoleSetting 获取 ConsoleSetting 配置实例 func GetConsoleSetting() *ConsoleSetting { - return &consoleSetting -} \ No newline at end of file + return &consoleSetting +} diff --git a/setting/console_setting/validation.go b/setting/console_setting/validation.go index fda6453df..529457761 100644 --- a/setting/console_setting/validation.go +++ b/setting/console_setting/validation.go @@ -1,304 +1,304 @@ package console_setting import ( - "encoding/json" - "fmt" - "net/url" - "regexp" - "strings" - "time" - "sort" + "encoding/json" + "fmt" + "net/url" + "regexp" + "sort" + "strings" + "time" ) var ( - urlRegex = regexp.MustCompile(`^https?://(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?|(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))(?:\:[0-9]{1,5})?(?:/.*)?$`) - dangerousChars = []string{" 50 { - return fmt.Errorf("API信息数量不能超过50个") - } + if len(apiInfoList) > 50 { + return fmt.Errorf("API信息数量不能超过50个") + } - for i, apiInfo := range apiInfoList { - urlStr, ok := apiInfo["url"].(string) - if !ok || urlStr == "" { - return fmt.Errorf("第%d个API信息缺少URL字段", i+1) - } - route, ok := apiInfo["route"].(string) - if !ok || route == "" { - return fmt.Errorf("第%d个API信息缺少线路描述字段", i+1) - } - description, ok := apiInfo["description"].(string) - if !ok || description == "" { - return fmt.Errorf("第%d个API信息缺少说明字段", i+1) - } - color, ok := apiInfo["color"].(string) - if !ok || color == "" { - return fmt.Errorf("第%d个API信息缺少颜色字段", i+1) - } + for i, apiInfo := range apiInfoList { + urlStr, ok := apiInfo["url"].(string) + if !ok || urlStr == "" { + return fmt.Errorf("第%d个API信息缺少URL字段", i+1) + } + route, ok := apiInfo["route"].(string) + if !ok || route == "" { + return fmt.Errorf("第%d个API信息缺少线路描述字段", i+1) + } + description, ok := apiInfo["description"].(string) + if !ok || description == "" { + return fmt.Errorf("第%d个API信息缺少说明字段", i+1) + } + color, ok := apiInfo["color"].(string) + if !ok || color == "" { + return fmt.Errorf("第%d个API信息缺少颜色字段", i+1) + } - if err := validateURL(urlStr, i+1, "API信息"); err != nil { - return err - } + if err := validateURL(urlStr, i+1, "API信息"); err != nil { + return err + } - if len(urlStr) > 500 { - return fmt.Errorf("第%d个API信息的URL长度不能超过500字符", i+1) - } - if len(route) > 100 { - return fmt.Errorf("第%d个API信息的线路描述长度不能超过100字符", i+1) - } - if len(description) > 200 { - return fmt.Errorf("第%d个API信息的说明长度不能超过200字符", i+1) - } + if len(urlStr) > 500 { + return fmt.Errorf("第%d个API信息的URL长度不能超过500字符", i+1) + } + if len(route) > 100 { + return fmt.Errorf("第%d个API信息的线路描述长度不能超过100字符", i+1) + } + if len(description) > 200 { + return fmt.Errorf("第%d个API信息的说明长度不能超过200字符", i+1) + } - if !validColors[color] { - return fmt.Errorf("第%d个API信息的颜色值不合法", i+1) - } + if !validColors[color] { + return fmt.Errorf("第%d个API信息的颜色值不合法", i+1) + } - if err := checkDangerousContent(description, i+1, "API信息"); err != nil { - return err - } - if err := checkDangerousContent(route, i+1, "API信息"); err != nil { - return err - } - } - return nil + if err := checkDangerousContent(description, i+1, "API信息"); err != nil { + return err + } + if err := checkDangerousContent(route, i+1, "API信息"); err != nil { + return err + } + } + return nil } func GetApiInfo() []map[string]interface{} { - return getJSONList(GetConsoleSetting().ApiInfo) + return getJSONList(GetConsoleSetting().ApiInfo) } func validateAnnouncements(announcementsStr string) error { - list, err := parseJSONArray(announcementsStr, "系统公告") - if err != nil { - return err - } - if len(list) > 100 { - return fmt.Errorf("系统公告数量不能超过100个") - } - validTypes := map[string]bool{ - "default": true, "ongoing": true, "success": true, "warning": true, "error": true, - } - for i, ann := range list { - content, ok := ann["content"].(string) - if !ok || content == "" { - return fmt.Errorf("第%d个公告缺少内容字段", i+1) - } - publishDateAny, exists := ann["publishDate"] - if !exists { - return fmt.Errorf("第%d个公告缺少发布日期字段", i+1) - } - publishDateStr, ok := publishDateAny.(string) - if !ok || publishDateStr == "" { - return fmt.Errorf("第%d个公告的发布日期不能为空", i+1) - } - if _, err := time.Parse(time.RFC3339, publishDateStr); err != nil { - return fmt.Errorf("第%d个公告的发布日期格式错误", i+1) - } - if t, exists := ann["type"]; exists { - if typeStr, ok := t.(string); ok { - if !validTypes[typeStr] { - return fmt.Errorf("第%d个公告的类型值不合法", i+1) - } - } - } - if len(content) > 500 { - return fmt.Errorf("第%d个公告的内容长度不能超过500字符", i+1) - } - if extra, exists := ann["extra"]; exists { - if extraStr, ok := extra.(string); ok && len(extraStr) > 200 { - return fmt.Errorf("第%d个公告的说明长度不能超过200字符", i+1) - } - } - } - return nil + list, err := parseJSONArray(announcementsStr, "系统公告") + if err != nil { + return err + } + if len(list) > 100 { + return fmt.Errorf("系统公告数量不能超过100个") + } + validTypes := map[string]bool{ + "default": true, "ongoing": true, "success": true, "warning": true, "error": true, + } + for i, ann := range list { + content, ok := ann["content"].(string) + if !ok || content == "" { + return fmt.Errorf("第%d个公告缺少内容字段", i+1) + } + publishDateAny, exists := ann["publishDate"] + if !exists { + return fmt.Errorf("第%d个公告缺少发布日期字段", i+1) + } + publishDateStr, ok := publishDateAny.(string) + if !ok || publishDateStr == "" { + return fmt.Errorf("第%d个公告的发布日期不能为空", i+1) + } + if _, err := time.Parse(time.RFC3339, publishDateStr); err != nil { + return fmt.Errorf("第%d个公告的发布日期格式错误", i+1) + } + if t, exists := ann["type"]; exists { + if typeStr, ok := t.(string); ok { + if !validTypes[typeStr] { + return fmt.Errorf("第%d个公告的类型值不合法", i+1) + } + } + } + if len(content) > 500 { + return fmt.Errorf("第%d个公告的内容长度不能超过500字符", i+1) + } + if extra, exists := ann["extra"]; exists { + if extraStr, ok := extra.(string); ok && len(extraStr) > 200 { + return fmt.Errorf("第%d个公告的说明长度不能超过200字符", i+1) + } + } + } + return nil } func validateFAQ(faqStr string) error { - list, err := parseJSONArray(faqStr, "FAQ信息") - if err != nil { - return err - } - if len(list) > 100 { - return fmt.Errorf("FAQ数量不能超过100个") - } - for i, faq := range list { - question, ok := faq["question"].(string) - if !ok || question == "" { - return fmt.Errorf("第%d个FAQ缺少问题字段", i+1) - } - answer, ok := faq["answer"].(string) - if !ok || answer == "" { - return fmt.Errorf("第%d个FAQ缺少答案字段", i+1) - } - if len(question) > 200 { - return fmt.Errorf("第%d个FAQ的问题长度不能超过200字符", i+1) - } - if len(answer) > 1000 { - return fmt.Errorf("第%d个FAQ的答案长度不能超过1000字符", i+1) - } - } - return nil + list, err := parseJSONArray(faqStr, "FAQ信息") + if err != nil { + return err + } + if len(list) > 100 { + return fmt.Errorf("FAQ数量不能超过100个") + } + for i, faq := range list { + question, ok := faq["question"].(string) + if !ok || question == "" { + return fmt.Errorf("第%d个FAQ缺少问题字段", i+1) + } + answer, ok := faq["answer"].(string) + if !ok || answer == "" { + return fmt.Errorf("第%d个FAQ缺少答案字段", i+1) + } + if len(question) > 200 { + return fmt.Errorf("第%d个FAQ的问题长度不能超过200字符", i+1) + } + if len(answer) > 1000 { + return fmt.Errorf("第%d个FAQ的答案长度不能超过1000字符", i+1) + } + } + return nil } func getPublishTime(item map[string]interface{}) time.Time { - if v, ok := item["publishDate"]; ok { - if s, ok2 := v.(string); ok2 { - if t, err := time.Parse(time.RFC3339, s); err == nil { - return t - } - } - } - return time.Time{} + if v, ok := item["publishDate"]; ok { + if s, ok2 := v.(string); ok2 { + if t, err := time.Parse(time.RFC3339, s); err == nil { + return t + } + } + } + return time.Time{} } func GetAnnouncements() []map[string]interface{} { - list := getJSONList(GetConsoleSetting().Announcements) - sort.SliceStable(list, func(i, j int) bool { - return getPublishTime(list[i]).After(getPublishTime(list[j])) - }) - return list + list := getJSONList(GetConsoleSetting().Announcements) + sort.SliceStable(list, func(i, j int) bool { + return getPublishTime(list[i]).After(getPublishTime(list[j])) + }) + return list } func GetFAQ() []map[string]interface{} { - return getJSONList(GetConsoleSetting().FAQ) + return getJSONList(GetConsoleSetting().FAQ) } func validateUptimeKumaGroups(groupsStr string) error { - groups, err := parseJSONArray(groupsStr, "Uptime Kuma分组配置") - if err != nil { - return err - } + groups, err := parseJSONArray(groupsStr, "Uptime Kuma分组配置") + if err != nil { + return err + } - if len(groups) > 20 { - return fmt.Errorf("Uptime Kuma分组数量不能超过20个") - } + if len(groups) > 20 { + return fmt.Errorf("Uptime Kuma分组数量不能超过20个") + } - nameSet := make(map[string]bool) + nameSet := make(map[string]bool) - for i, group := range groups { - categoryName, ok := group["categoryName"].(string) - if !ok || categoryName == "" { - return fmt.Errorf("第%d个分组缺少分类名称字段", i+1) - } - if nameSet[categoryName] { - return fmt.Errorf("第%d个分组的分类名称与其他分组重复", i+1) - } - nameSet[categoryName] = true - urlStr, ok := group["url"].(string) - if !ok || urlStr == "" { - return fmt.Errorf("第%d个分组缺少URL字段", i+1) - } - slug, ok := group["slug"].(string) - if !ok || slug == "" { - return fmt.Errorf("第%d个分组缺少Slug字段", i+1) - } - description, ok := group["description"].(string) - if !ok { - description = "" - } + for i, group := range groups { + categoryName, ok := group["categoryName"].(string) + if !ok || categoryName == "" { + return fmt.Errorf("第%d个分组缺少分类名称字段", i+1) + } + if nameSet[categoryName] { + return fmt.Errorf("第%d个分组的分类名称与其他分组重复", i+1) + } + nameSet[categoryName] = true + urlStr, ok := group["url"].(string) + if !ok || urlStr == "" { + return fmt.Errorf("第%d个分组缺少URL字段", i+1) + } + slug, ok := group["slug"].(string) + if !ok || slug == "" { + return fmt.Errorf("第%d个分组缺少Slug字段", i+1) + } + description, ok := group["description"].(string) + if !ok { + description = "" + } - if err := validateURL(urlStr, i+1, "分组"); err != nil { - return err - } + if err := validateURL(urlStr, i+1, "分组"); err != nil { + return err + } - if len(categoryName) > 50 { - return fmt.Errorf("第%d个分组的分类名称长度不能超过50字符", i+1) - } - if len(urlStr) > 500 { - return fmt.Errorf("第%d个分组的URL长度不能超过500字符", i+1) - } - if len(slug) > 100 { - return fmt.Errorf("第%d个分组的Slug长度不能超过100字符", i+1) - } - if len(description) > 200 { - return fmt.Errorf("第%d个分组的描述长度不能超过200字符", i+1) - } + if len(categoryName) > 50 { + return fmt.Errorf("第%d个分组的分类名称长度不能超过50字符", i+1) + } + if len(urlStr) > 500 { + return fmt.Errorf("第%d个分组的URL长度不能超过500字符", i+1) + } + if len(slug) > 100 { + return fmt.Errorf("第%d个分组的Slug长度不能超过100字符", i+1) + } + if len(description) > 200 { + return fmt.Errorf("第%d个分组的描述长度不能超过200字符", i+1) + } - if !slugRegex.MatchString(slug) { - return fmt.Errorf("第%d个分组的Slug只能包含字母、数字、下划线和连字符", i+1) - } + if !slugRegex.MatchString(slug) { + return fmt.Errorf("第%d个分组的Slug只能包含字母、数字、下划线和连字符", i+1) + } - if err := checkDangerousContent(description, i+1, "分组"); err != nil { - return err - } - if err := checkDangerousContent(categoryName, i+1, "分组"); err != nil { - return err - } - } - return nil + if err := checkDangerousContent(description, i+1, "分组"); err != nil { + return err + } + if err := checkDangerousContent(categoryName, i+1, "分组"); err != nil { + return err + } + } + return nil } func GetUptimeKumaGroups() []map[string]interface{} { - return getJSONList(GetConsoleSetting().UptimeKumaGroups) -} \ No newline at end of file + return getJSONList(GetConsoleSetting().UptimeKumaGroups) +} diff --git a/setting/model_setting/gemini.go b/setting/model_setting/gemini.go index f132fec88..5412155f1 100644 --- a/setting/model_setting/gemini.go +++ b/setting/model_setting/gemini.go @@ -26,6 +26,7 @@ var defaultGeminiSettings = GeminiSettings{ SupportedImagineModels: []string{ "gemini-2.0-flash-exp-image-generation", "gemini-2.0-flash-exp", + "gemini-2.5-flash-image-preview", }, ThinkingAdapterEnabled: false, ThinkingAdapterBudgetTokensPercentage: 0.6, diff --git a/setting/operation_setting/monitor_setting.go b/setting/operation_setting/monitor_setting.go new file mode 100644 index 000000000..1d0bbec40 --- /dev/null +++ b/setting/operation_setting/monitor_setting.go @@ -0,0 +1,34 @@ +package operation_setting + +import ( + "one-api/setting/config" + "os" + "strconv" +) + +type MonitorSetting struct { + AutoTestChannelEnabled bool `json:"auto_test_channel_enabled"` + AutoTestChannelMinutes int `json:"auto_test_channel_minutes"` +} + +// 默认配置 +var monitorSetting = MonitorSetting{ + AutoTestChannelEnabled: false, + AutoTestChannelMinutes: 10, +} + +func init() { + // 注册到全局配置管理器 + config.GlobalConfig.Register("monitor_setting", &monitorSetting) +} + +func GetMonitorSetting() *MonitorSetting { + if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { + frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) + if err == nil && frequency > 0 { + monitorSetting.AutoTestChannelEnabled = true + monitorSetting.AutoTestChannelMinutes = frequency + } + } + return &monitorSetting +} diff --git a/setting/operation_setting/tools.go b/setting/operation_setting/tools.go index a59090ce0..b87265ee1 100644 --- a/setting/operation_setting/tools.go +++ b/setting/operation_setting/tools.go @@ -4,12 +4,8 @@ import "strings" const ( // Web search - WebSearchHighTierModelPriceLow = 30.00 - WebSearchHighTierModelPriceMedium = 35.00 - WebSearchHighTierModelPriceHigh = 50.00 - WebSearchPriceLow = 25.00 - WebSearchPriceMedium = 27.50 - WebSearchPriceHigh = 30.00 + WebSearchPriceHigh = 25.00 + WebSearchPrice = 10.00 // File search FileSearchPrice = 2.5 ) @@ -28,44 +24,29 @@ const ( ClaudeWebSearchPrice = 10.00 ) +const ( + Gemini25FlashImagePreviewImageOutputPrice = 30.00 +) + func GetClaudeWebSearchPricePerThousand() float64 { return ClaudeWebSearchPrice } func GetWebSearchPricePerThousand(modelName string, contextSize string) float64 { // 确定模型类型 - // https://platform.openai.com/docs/pricing Web search 价格按模型类型和 search context size 收费 - // gpt-4.1, gpt-4o, or gpt-4o-search-preview 更贵,gpt-4.1-mini, gpt-4o-mini, gpt-4o-mini-search-preview 更便宜 - isHighTierModel := (strings.HasPrefix(modelName, "gpt-4.1") || strings.HasPrefix(modelName, "gpt-4o")) && - !strings.Contains(modelName, "mini") - // 确定 search context size 对应的价格 + // https://platform.openai.com/docs/pricing Web search 价格按模型类型收费 + // 新版计费规则不再关联 search context size,故在const区域将各size的价格设为一致。 + // gpt-5, gpt-5-mini, gpt-5-nano 和 o 系列模型价格为 10.00 美元/千次调用,产生额外 token 计入 input_tokens + // gpt-4o, gpt-4.1, gpt-4o-mini 和 gpt-4.1-mini 价格为 25.00 美元/千次调用,不产生额外 token + isNormalPriceModel := + strings.HasPrefix(modelName, "o3") || + strings.HasPrefix(modelName, "o4") || + strings.HasPrefix(modelName, "gpt-5") var priceWebSearchPerThousandCalls float64 - switch contextSize { - case "low": - if isHighTierModel { - priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceLow - } else { - priceWebSearchPerThousandCalls = WebSearchPriceLow - } - case "medium": - if isHighTierModel { - priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceMedium - } else { - priceWebSearchPerThousandCalls = WebSearchPriceMedium - } - case "high": - if isHighTierModel { - priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceHigh - } else { - priceWebSearchPerThousandCalls = WebSearchPriceHigh - } - default: - // search context size 默认为 medium - if isHighTierModel { - priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceMedium - } else { - priceWebSearchPerThousandCalls = WebSearchPriceMedium - } + if isNormalPriceModel { + priceWebSearchPerThousandCalls = WebSearchPrice + } else { + priceWebSearchPerThousandCalls = WebSearchPriceHigh } return priceWebSearchPerThousandCalls } @@ -88,3 +69,10 @@ func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 { } return 0 } + +func GetGeminiImageOutputPricePerMillionTokens(modelName string) float64 { + if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") { + return Gemini25FlashImagePreviewImageOutputPrice + } + return 0 +} diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 53b53f885..141463e14 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -3,6 +3,7 @@ package setting import ( "encoding/json" "fmt" + "math" "one-api/common" "sync" ) @@ -20,7 +21,7 @@ func ModelRequestRateLimitGroup2JSONString() string { jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup) if err != nil { - common.SysError("error marshalling model ratio: " + err.Error()) + common.SysLog("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -58,6 +59,9 @@ func CheckModelRequestRateLimitGroup(jsonStr string) error { if limits[0] < 0 || limits[1] < 1 { return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1]) } + if limits[0] > math.MaxInt32 || limits[1] > math.MaxInt32 { + return fmt.Errorf("group %s [%d, %d] has max rate limits value 2147483647", group, limits[0], limits[1]) + } } return nil diff --git a/setting/ratio_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go index 51d473a80..5993cdeeb 100644 --- a/setting/ratio_setting/cache_ratio.go +++ b/setting/ratio_setting/cache_ratio.go @@ -25,6 +25,16 @@ var defaultCacheRatio = map[string]float64{ "gpt-4o-mini-realtime-preview": 0.5, "gpt-4.5-preview": 0.5, "gpt-4.5-preview-2025-02-27": 0.5, + "gpt-4.1": 0.25, + "gpt-4.1-mini": 0.25, + "gpt-4.1-nano": 0.25, + "gpt-5": 0.1, + "gpt-5-2025-08-07": 0.1, + "gpt-5-chat-latest": 0.1, + "gpt-5-mini": 0.1, + "gpt-5-mini-2025-08-07": 0.1, + "gpt-5-nano": 0.1, + "gpt-5-nano-2025-08-07": 0.1, "deepseek-chat": 0.25, "deepseek-reasoner": 0.25, "deepseek-coder": 0.25, @@ -40,6 +50,8 @@ var defaultCacheRatio = map[string]float64{ "claude-sonnet-4-20250514-thinking": 0.1, "claude-opus-4-20250514": 0.1, "claude-opus-4-20250514-thinking": 0.1, + "claude-opus-4-1-20250805": 0.1, + "claude-opus-4-1-20250805-thinking": 0.1, } var defaultCreateCacheRatio = map[string]float64{ @@ -55,6 +67,8 @@ var defaultCreateCacheRatio = map[string]float64{ "claude-sonnet-4-20250514-thinking": 1.25, "claude-opus-4-20250514": 1.25, "claude-opus-4-20250514-thinking": 1.25, + "claude-opus-4-1-20250805": 1.25, + "claude-opus-4-1-20250805-thinking": 1.25, } //var defaultCreateCacheRatio = map[string]float64{} @@ -75,7 +89,7 @@ func CacheRatio2JSONString() string { defer cacheRatioMapMutex.RUnlock() jsonBytes, err := json.Marshal(cacheRatioMap) if err != nil { - common.SysError("error marshalling cache ratio: " + err.Error()) + common.SysLog("error marshalling cache ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/expose_ratio.go b/setting/ratio_setting/expose_ratio.go index 8fca0bcb0..783d9778e 100644 --- a/setting/ratio_setting/expose_ratio.go +++ b/setting/ratio_setting/expose_ratio.go @@ -5,13 +5,13 @@ import "sync/atomic" var exposeRatioEnabled atomic.Bool func init() { - exposeRatioEnabled.Store(false) + exposeRatioEnabled.Store(false) } func SetExposeRatioEnabled(enabled bool) { - exposeRatioEnabled.Store(enabled) + exposeRatioEnabled.Store(enabled) } func IsExposeRatioEnabled() bool { - return exposeRatioEnabled.Load() -} \ No newline at end of file + return exposeRatioEnabled.Load() +} diff --git a/setting/ratio_setting/exposed_cache.go b/setting/ratio_setting/exposed_cache.go index 9e5b6c300..2fe2cd09b 100644 --- a/setting/ratio_setting/exposed_cache.go +++ b/setting/ratio_setting/exposed_cache.go @@ -1,55 +1,55 @@ package ratio_setting import ( - "sync" - "sync/atomic" - "time" + "sync" + "sync/atomic" + "time" - "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin" ) const exposedDataTTL = 30 * time.Second type exposedCache struct { - data gin.H - expiresAt time.Time + data gin.H + expiresAt time.Time } var ( - exposedData atomic.Value - rebuildMu sync.Mutex + exposedData atomic.Value + rebuildMu sync.Mutex ) func InvalidateExposedDataCache() { - exposedData.Store((*exposedCache)(nil)) + exposedData.Store((*exposedCache)(nil)) } func cloneGinH(src gin.H) gin.H { - dst := make(gin.H, len(src)) - for k, v := range src { - dst[k] = v - } - return dst + dst := make(gin.H, len(src)) + for k, v := range src { + dst[k] = v + } + return dst } func GetExposedData() gin.H { - if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { - return cloneGinH(c.data) - } - rebuildMu.Lock() - defer rebuildMu.Unlock() - if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { - return cloneGinH(c.data) - } - newData := gin.H{ - "model_ratio": GetModelRatioCopy(), - "completion_ratio": GetCompletionRatioCopy(), - "cache_ratio": GetCacheRatioCopy(), - "model_price": GetModelPriceCopy(), - } - exposedData.Store(&exposedCache{ - data: newData, - expiresAt: time.Now().Add(exposedDataTTL), - }) - return cloneGinH(newData) -} \ No newline at end of file + if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { + return cloneGinH(c.data) + } + rebuildMu.Lock() + defer rebuildMu.Unlock() + if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { + return cloneGinH(c.data) + } + newData := gin.H{ + "model_ratio": GetModelRatioCopy(), + "completion_ratio": GetCompletionRatioCopy(), + "cache_ratio": GetCacheRatioCopy(), + "model_price": GetModelPriceCopy(), + } + exposedData.Store(&exposedCache{ + data: newData, + expiresAt: time.Now().Add(exposedDataTTL), + }) + return cloneGinH(newData) +} diff --git a/setting/ratio_setting/group_ratio.go b/setting/ratio_setting/group_ratio.go index 86f4a8d19..c42553da0 100644 --- a/setting/ratio_setting/group_ratio.go +++ b/setting/ratio_setting/group_ratio.go @@ -48,7 +48,7 @@ func GroupRatio2JSONString() string { jsonBytes, err := json.Marshal(groupRatio) if err != nil { - common.SysError("error marshalling model ratio: " + err.Error()) + common.SysLog("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -67,7 +67,7 @@ func GetGroupRatio(name string) float64 { ratio, ok := groupRatio[name] if !ok { - common.SysError("group ratio not found: " + name) + common.SysLog("group ratio not found: " + name) return 1 } return ratio @@ -94,7 +94,7 @@ func GroupGroupRatio2JSONString() string { jsonBytes, err := json.Marshal(GroupGroupRatio) if err != nil { - common.SysError("error marshalling group-group ratio: " + err.Error()) + common.SysLog("error marshalling group-group ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index 8a1d6aaed..1a1b0afa8 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -52,27 +52,52 @@ var defaultModelRatio = map[string]float64{ "gpt-4o-realtime-preview-2024-12-17": 2.5, "gpt-4o-mini-realtime-preview": 0.3, "gpt-4o-mini-realtime-preview-2024-12-17": 0.3, - "gpt-image-1": 2.5, - "o1": 7.5, - "o1-2024-12-17": 7.5, - "o1-preview": 7.5, - "o1-preview-2024-09-12": 7.5, - "o1-mini": 0.55, - "o1-mini-2024-09-12": 0.55, - "o3-mini": 0.55, - "o3-mini-2025-01-31": 0.55, - "o3-mini-high": 0.55, - "o3-mini-2025-01-31-high": 0.55, - "o3-mini-low": 0.55, - "o3-mini-2025-01-31-low": 0.55, - "o3-mini-medium": 0.55, - "o3-mini-2025-01-31-medium": 0.55, - "gpt-4o-mini": 0.075, - "gpt-4o-mini-2024-07-18": 0.075, - "gpt-4-turbo": 5, // $0.01 / 1K tokens - "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens - "gpt-4.5-preview": 37.5, - "gpt-4.5-preview-2025-02-27": 37.5, + "gpt-4.1": 1.0, // $2 / 1M tokens + "gpt-4.1-2025-04-14": 1.0, // $2 / 1M tokens + "gpt-4.1-mini": 0.2, // $0.4 / 1M tokens + "gpt-4.1-mini-2025-04-14": 0.2, // $0.4 / 1M tokens + "gpt-4.1-nano": 0.05, // $0.1 / 1M tokens + "gpt-4.1-nano-2025-04-14": 0.05, // $0.1 / 1M tokens + "gpt-image-1": 2.5, // $5 / 1M tokens + "o1": 7.5, // $15 / 1M tokens + "o1-2024-12-17": 7.5, // $15 / 1M tokens + "o1-preview": 7.5, // $15 / 1M tokens + "o1-preview-2024-09-12": 7.5, // $15 / 1M tokens + "o1-mini": 0.55, // $1.1 / 1M tokens + "o1-mini-2024-09-12": 0.55, // $1.1 / 1M tokens + "o1-pro": 75.0, // $150 / 1M tokens + "o1-pro-2025-03-19": 75.0, // $150 / 1M tokens + "o3-mini": 0.55, + "o3-mini-2025-01-31": 0.55, + "o3-mini-high": 0.55, + "o3-mini-2025-01-31-high": 0.55, + "o3-mini-low": 0.55, + "o3-mini-2025-01-31-low": 0.55, + "o3-mini-medium": 0.55, + "o3-mini-2025-01-31-medium": 0.55, + "o3": 1.0, // $2 / 1M tokens + "o3-2025-04-16": 1.0, // $2 / 1M tokens + "o3-pro": 10.0, // $20 / 1M tokens + "o3-pro-2025-06-10": 10.0, // $20 / 1M tokens + "o3-deep-research": 5.0, // $10 / 1M tokens + "o3-deep-research-2025-06-26": 5.0, // $10 / 1M tokens + "o4-mini": 0.55, // $1.1 / 1M tokens + "o4-mini-2025-04-16": 0.55, // $1.1 / 1M tokens + "o4-mini-deep-research": 1.0, // $2 / 1M tokens + "o4-mini-deep-research-2025-06-26": 1.0, // $2 / 1M tokens + "gpt-4o-mini": 0.075, + "gpt-4o-mini-2024-07-18": 0.075, + "gpt-4-turbo": 5, // $0.01 / 1K tokens + "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens + "gpt-4.5-preview": 37.5, + "gpt-4.5-preview-2025-02-27": 37.5, + "gpt-5": 0.625, + "gpt-5-2025-08-07": 0.625, + "gpt-5-chat-latest": 0.625, + "gpt-5-mini": 0.125, + "gpt-5-mini-2025-08-07": 0.125, + "gpt-5-nano": 0.025, + "gpt-5-nano-2025-08-07": 0.025, //"gpt-3.5-turbo-0301": 0.75, //deprecated "gpt-3.5-turbo": 0.25, "gpt-3.5-turbo-0613": 0.75, @@ -118,6 +143,7 @@ var defaultModelRatio = map[string]float64{ "claude-sonnet-4-20250514": 1.5, "claude-3-opus-20240229": 7.5, // $15 / 1M tokens "claude-opus-4-20250514": 7.5, + "claude-opus-4-1-20250805": 7.5, "ERNIE-4.0-8K": 0.120 * RMB, "ERNIE-3.5-8K": 0.012 * RMB, "ERNIE-3.5-8K-0205": 0.024 * RMB, @@ -149,8 +175,10 @@ var defaultModelRatio = map[string]float64{ "gemini-2.5-flash-preview-05-20-nothinking": 0.075, "gemini-2.5-flash-thinking-*": 0.075, // 用于为后续所有2.5 flash thinking budget 模型设置默认倍率 "gemini-2.5-pro-thinking-*": 0.625, // 用于为后续所有2.5 pro thinking budget 模型设置默认倍率 + "gemini-2.5-flash-lite-preview-thinking-*": 0.05, "gemini-2.5-flash-lite-preview-06-17": 0.05, "gemini-2.5-flash": 0.15, + "gemini-2.5-flash-image-preview": 0.15, // $0.30(text/image) / 1M tokens "text-embedding-004": 0.001, "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens @@ -266,10 +294,11 @@ var ( ) var defaultCompletionRatio = map[string]float64{ - "gpt-4-gizmo-*": 2, - "gpt-4o-gizmo-*": 3, - "gpt-4-all": 2, - "gpt-image-1": 8, + "gpt-4-gizmo-*": 2, + "gpt-4o-gizmo-*": 3, + "gpt-4-all": 2, + "gpt-image-1": 8, + "gemini-2.5-flash-image-preview": 8.3333333333, } // InitRatioSettings initializes all model related settings maps @@ -311,7 +340,7 @@ func ModelPrice2JSONString() string { modelPriceMapMutex.RLock() defer modelPriceMapMutex.RUnlock() - jsonBytes, err := json.Marshal(modelPriceMap) + jsonBytes, err := common.Marshal(modelPriceMap) if err != nil { common.SysError("error marshalling model price: " + err.Error()) } @@ -334,12 +363,8 @@ func GetModelPrice(name string, printErr bool) (float64, bool) { modelPriceMapMutex.RLock() defer modelPriceMapMutex.RUnlock() - if strings.HasPrefix(name, "gpt-4-gizmo") { - name = "gpt-4-gizmo-*" - } - if strings.HasPrefix(name, "gpt-4o-gizmo") { - name = "gpt-4o-gizmo-*" - } + name = FormatMatchingModelName(name) + price, ok := modelPriceMap[name] if !ok { if printErr { @@ -354,7 +379,7 @@ func UpdateModelRatioByJSONString(jsonStr string) error { modelRatioMapMutex.Lock() defer modelRatioMapMutex.Unlock() modelRatioMap = make(map[string]float64) - err := json.Unmarshal([]byte(jsonStr), &modelRatioMap) + err := common.Unmarshal([]byte(jsonStr), &modelRatioMap) if err == nil { InvalidateExposedDataCache() } @@ -373,11 +398,8 @@ func GetModelRatio(name string) (float64, bool, string) { modelRatioMapMutex.RLock() defer modelRatioMapMutex.RUnlock() - name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*") - name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*") - if strings.HasPrefix(name, "gpt-4-gizmo") { - name = "gpt-4-gizmo-*" - } + name = FormatMatchingModelName(name) + ratio, ok := modelRatioMap[name] if !ok { return 37.5, operation_setting.SelfUseModeEnabled, name @@ -386,7 +408,7 @@ func GetModelRatio(name string) (float64, bool, string) { } func DefaultModelRatio2JSONString() string { - jsonBytes, err := json.Marshal(defaultModelRatio) + jsonBytes, err := common.Marshal(defaultModelRatio) if err != nil { common.SysError("error marshalling model ratio: " + err.Error()) } @@ -418,7 +440,7 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { CompletionRatioMutex.Lock() defer CompletionRatioMutex.Unlock() CompletionRatio = make(map[string]float64) - err := json.Unmarshal([]byte(jsonStr), &CompletionRatio) + err := common.Unmarshal([]byte(jsonStr), &CompletionRatio) if err == nil { InvalidateExposedDataCache() } @@ -428,12 +450,9 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { func GetCompletionRatio(name string) float64 { CompletionRatioMutex.RLock() defer CompletionRatioMutex.RUnlock() - if strings.HasPrefix(name, "gpt-4-gizmo") { - name = "gpt-4-gizmo-*" - } - if strings.HasPrefix(name, "gpt-4o-gizmo") { - name = "gpt-4o-gizmo-*" - } + + name = FormatMatchingModelName(name) + if strings.Contains(name, "/") { if ratio, ok := CompletionRatio[name]; ok { return ratio @@ -451,13 +470,23 @@ func GetCompletionRatio(name string) float64 { func getHardcodedCompletionModelRatio(name string) (float64, bool) { lowercaseName := strings.ToLower(name) - if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") { + + isReservedModel := strings.HasSuffix(name, "-all") || strings.HasSuffix(name, "-gizmo-*") + if isReservedModel { + return 2, false + } + + if strings.HasPrefix(name, "gpt-") { if strings.HasPrefix(name, "gpt-4o") { if name == "gpt-4o-2024-05-13" { return 3, true } return 4, true } + // gpt-5 匹配 + if strings.HasPrefix(name, "gpt-5") { + return 8, true + } // gpt-4.5-preview匹配 if strings.HasPrefix(name, "gpt-4.5-preview") { return 2, true @@ -512,12 +541,9 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) { return 3.5 / 0.15, false } if strings.HasPrefix(name, "gemini-2.5-flash-lite") { - if strings.HasPrefix(name, "gemini-2.5-flash-lite-preview") { - return 4, false - } return 4, false } - return 2.5 / 0.3, true + return 2.5 / 0.3, false } return 4, false } @@ -594,7 +620,7 @@ func ModelRatio2JSONString() string { modelRatioMapMutex.RLock() defer modelRatioMapMutex.RUnlock() - jsonBytes, err := json.Marshal(modelRatioMap) + jsonBytes, err := common.Marshal(modelRatioMap) if err != nil { common.SysError("error marshalling model ratio: " + err.Error()) } @@ -610,7 +636,7 @@ var imageRatioMapMutex sync.RWMutex func ImageRatio2JSONString() string { imageRatioMapMutex.RLock() defer imageRatioMapMutex.RUnlock() - jsonBytes, err := json.Marshal(imageRatioMap) + jsonBytes, err := common.Marshal(imageRatioMap) if err != nil { common.SysError("error marshalling cache ratio: " + err.Error()) } @@ -621,7 +647,7 @@ func UpdateImageRatioByJSONString(jsonStr string) error { imageRatioMapMutex.Lock() defer imageRatioMapMutex.Unlock() imageRatioMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &imageRatioMap) + return common.Unmarshal([]byte(jsonStr), &imageRatioMap) } func GetImageRatio(name string) (float64, bool) { @@ -663,3 +689,23 @@ func GetCompletionRatioCopy() map[string]float64 { } return copyMap } + +// 转换模型名,减少渠道必须配置各种带参数模型 +func FormatMatchingModelName(name string) string { + + if strings.HasPrefix(name, "gemini-2.5-flash-lite") { + name = handleThinkingBudgetModel(name, "gemini-2.5-flash-lite", "gemini-2.5-flash-lite-thinking-*") + } else if strings.HasPrefix(name, "gemini-2.5-flash") { + name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*") + } else if strings.HasPrefix(name, "gemini-2.5-pro") { + name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*") + } + + if strings.HasPrefix(name, "gpt-4-gizmo") { + name = "gpt-4-gizmo-*" + } + if strings.HasPrefix(name, "gpt-4o-gizmo") { + name = "gpt-4o-gizmo-*" + } + return name +} diff --git a/setting/user_usable_group.go b/setting/user_usable_group.go index 0ae132d0e..57e4beecf 100644 --- a/setting/user_usable_group.go +++ b/setting/user_usable_group.go @@ -29,7 +29,7 @@ func UserUsableGroups2JSONString() string { jsonBytes, err := json.Marshal(userUsableGroups) if err != nil { - common.SysError("error marshalling user groups: " + err.Error()) + common.SysLog("error marshalling user groups: " + err.Error()) } return string(jsonBytes) } diff --git a/types/error.go b/types/error.go index 5c8b37d22..f653e9a28 100644 --- a/types/error.go +++ b/types/error.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "one-api/common" "strings" ) @@ -15,8 +16,8 @@ type OpenAIError struct { } type ClaudeError struct { - Message string `json:"message,omitempty"` Type string `json:"type,omitempty"` + Message string `json:"message,omitempty"` } type ErrorType string @@ -28,6 +29,7 @@ const ( ErrorTypeMidjourneyError ErrorType = "midjourney_error" ErrorTypeGeminiError ErrorType = "gemini_error" ErrorTypeRerankError ErrorType = "rerank_error" + ErrorTypeUpstreamError ErrorType = "upstream_error" ) type ErrorCode string @@ -37,20 +39,22 @@ const ( ErrorCodeSensitiveWordsDetected ErrorCode = "sensitive_words_detected" // new api error - ErrorCodeCountTokenFailed ErrorCode = "count_token_failed" - ErrorCodeModelPriceError ErrorCode = "model_price_error" - ErrorCodeInvalidApiType ErrorCode = "invalid_api_type" - ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed" - ErrorCodeDoRequestFailed ErrorCode = "do_request_failed" - ErrorCodeGetChannelFailed ErrorCode = "get_channel_failed" + ErrorCodeCountTokenFailed ErrorCode = "count_token_failed" + ErrorCodeModelPriceError ErrorCode = "model_price_error" + ErrorCodeInvalidApiType ErrorCode = "invalid_api_type" + ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed" + ErrorCodeDoRequestFailed ErrorCode = "do_request_failed" + ErrorCodeGetChannelFailed ErrorCode = "get_channel_failed" + ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed" // channel error - ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key" - ErrorCodeChannelParamOverrideInvalid ErrorCode = "channel:param_override_invalid" - ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error" - ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error" - ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key" - ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded" + ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key" + ErrorCodeChannelParamOverrideInvalid ErrorCode = "channel:param_override_invalid" + ErrorCodeChannelHeaderOverrideInvalid ErrorCode = "channel:header_override_invalid" + ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error" + ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error" + ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key" + ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded" // client request error ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed" @@ -62,6 +66,9 @@ const ( ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code" ErrorCodeBadResponse ErrorCode = "bad_response" ErrorCodeBadResponseBody ErrorCode = "bad_response_body" + ErrorCodeEmptyResponse ErrorCode = "empty_response" + ErrorCodeAwsInvokeError ErrorCode = "aws_invoke_error" + ErrorCodeModelNotFound ErrorCode = "model_not_found" // sql error ErrorCodeQueryDataError ErrorCode = "query_data_error" @@ -73,11 +80,13 @@ const ( ) type NewAPIError struct { - Err error - RelayError any - ErrorType ErrorType - errorCode ErrorCode - StatusCode int + Err error + RelayError any + skipRetry bool + recordErrorLog *bool + errorType ErrorType + errorCode ErrorCode + StatusCode int } func (e *NewAPIError) GetErrorCode() ErrorCode { @@ -87,6 +96,13 @@ func (e *NewAPIError) GetErrorCode() ErrorCode { return e.errorCode } +func (e *NewAPIError) GetErrorType() ErrorType { + if e == nil { + return "" + } + return e.errorType +} + func (e *NewAPIError) Error() string { if e == nil { return "" @@ -98,100 +114,168 @@ func (e *NewAPIError) Error() string { return e.Err.Error() } +func (e *NewAPIError) MaskSensitiveError() string { + if e == nil { + return "" + } + if e.Err == nil { + return string(e.errorCode) + } + errStr := e.Err.Error() + return common.MaskSensitiveInfo(errStr) +} + func (e *NewAPIError) SetMessage(message string) { e.Err = errors.New(message) } func (e *NewAPIError) ToOpenAIError() OpenAIError { - switch e.ErrorType { + var result OpenAIError + switch e.errorType { case ErrorTypeOpenAIError: - return e.RelayError.(OpenAIError) + if openAIError, ok := e.RelayError.(OpenAIError); ok { + result = openAIError + } case ErrorTypeClaudeError: - claudeError := e.RelayError.(ClaudeError) - return OpenAIError{ - Message: e.Error(), - Type: claudeError.Type, - Param: "", - Code: e.errorCode, + if claudeError, ok := e.RelayError.(ClaudeError); ok { + result = OpenAIError{ + Message: e.Error(), + Type: claudeError.Type, + Param: "", + Code: e.errorCode, + } } default: - return OpenAIError{ + result = OpenAIError{ Message: e.Error(), - Type: string(e.ErrorType), + Type: string(e.errorType), Param: "", Code: e.errorCode, } } + + result.Message = common.MaskSensitiveInfo(result.Message) + return result } func (e *NewAPIError) ToClaudeError() ClaudeError { - switch e.ErrorType { + var result ClaudeError + switch e.errorType { case ErrorTypeOpenAIError: - openAIError := e.RelayError.(OpenAIError) - return ClaudeError{ - Message: e.Error(), - Type: fmt.Sprintf("%v", openAIError.Code), + if openAIError, ok := e.RelayError.(OpenAIError); ok { + result = ClaudeError{ + Message: e.Error(), + Type: fmt.Sprintf("%v", openAIError.Code), + } } case ErrorTypeClaudeError: - return e.RelayError.(ClaudeError) + if claudeError, ok := e.RelayError.(ClaudeError); ok { + result = claudeError + } default: - return ClaudeError{ + result = ClaudeError{ Message: e.Error(), - Type: string(e.ErrorType), + Type: string(e.errorType), } } + result.Message = common.MaskSensitiveInfo(result.Message) + return result } -func NewError(err error, errorCode ErrorCode) *NewAPIError { - return &NewAPIError{ +type NewAPIErrorOptions func(*NewAPIError) + +func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError { + e := &NewAPIError{ Err: err, RelayError: nil, - ErrorType: ErrorTypeNewAPIError, + errorType: ErrorTypeNewAPIError, StatusCode: http.StatusInternalServerError, errorCode: errorCode, } + for _, op := range ops { + op(e) + } + return e } -func NewOpenAIError(err error, errorCode ErrorCode, statusCode int) *NewAPIError { +func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { + if errorCode == ErrorCodeDoRequestFailed { + err = errors.New("upstream error: do request failed") + } openaiError := OpenAIError{ Message: err.Error(), Type: string(errorCode), + Code: errorCode, } - return WithOpenAIError(openaiError, statusCode) + return WithOpenAIError(openaiError, statusCode, ops...) } -func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *NewAPIError { - return &NewAPIError{ - Err: err, - RelayError: nil, - ErrorType: ErrorTypeNewAPIError, +func InitOpenAIError(errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { + openaiError := OpenAIError{ + Type: string(errorCode), + Code: errorCode, + } + return WithOpenAIError(openaiError, statusCode, ops...) +} + +func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { + e := &NewAPIError{ + Err: err, + RelayError: OpenAIError{ + Message: err.Error(), + Type: string(errorCode), + }, + errorType: ErrorTypeNewAPIError, StatusCode: statusCode, errorCode: errorCode, } + for _, op := range ops { + op(e) + } + + return e } -func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError { +func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { code, ok := openAIError.Code.(string) if !ok { - code = fmt.Sprintf("%v", openAIError.Code) + if openAIError.Code != nil { + code = fmt.Sprintf("%v", openAIError.Code) + } else { + code = "unknown_error" + } } - return &NewAPIError{ + if openAIError.Type == "" { + openAIError.Type = "upstream_error" + } + e := &NewAPIError{ RelayError: openAIError, - ErrorType: ErrorTypeOpenAIError, + errorType: ErrorTypeOpenAIError, StatusCode: statusCode, Err: errors.New(openAIError.Message), errorCode: ErrorCode(code), } + for _, op := range ops { + op(e) + } + return e } -func WithClaudeError(claudeError ClaudeError, statusCode int) *NewAPIError { - return &NewAPIError{ +func WithClaudeError(claudeError ClaudeError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { + if claudeError.Type == "" { + claudeError.Type = "upstream_error" + } + e := &NewAPIError{ RelayError: claudeError, - ErrorType: ErrorTypeClaudeError, + errorType: ErrorTypeClaudeError, StatusCode: statusCode, Err: errors.New(claudeError.Message), errorCode: ErrorCode(claudeError.Type), } + for _, op := range ops { + op(e) + } + return e } func IsChannelError(err *NewAPIError) bool { @@ -201,10 +285,33 @@ func IsChannelError(err *NewAPIError) bool { return strings.HasPrefix(string(err.errorCode), "channel:") } -func IsLocalError(err *NewAPIError) bool { +func IsSkipRetryError(err *NewAPIError) bool { if err == nil { return false } - return err.ErrorType == ErrorTypeNewAPIError + return err.skipRetry +} + +func ErrOptionWithSkipRetry() NewAPIErrorOptions { + return func(e *NewAPIError) { + e.skipRetry = true + } +} + +func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions { + return func(e *NewAPIError) { + e.recordErrorLog = common.GetPointer(false) + } +} + +func IsRecordErrorLog(e *NewAPIError) bool { + if e == nil { + return false + } + if e.recordErrorLog == nil { + // default to true if not set + return true + } + return *e.recordErrorLog } diff --git a/dto/file_data.go b/types/file_data.go similarity index 88% rename from dto/file_data.go rename to types/file_data.go index d5cf0f684..f1c82e21e 100644 --- a/dto/file_data.go +++ b/types/file_data.go @@ -1,4 +1,4 @@ -package dto +package types type LocalFileData struct { MimeType string diff --git a/types/price_data.go b/types/price_data.go new file mode 100644 index 000000000..f6a92d7e3 --- /dev/null +++ b/types/price_data.go @@ -0,0 +1,31 @@ +package types + +import "fmt" + +type GroupRatioInfo struct { + GroupRatio float64 + GroupSpecialRatio float64 + HasSpecialRatio bool +} + +type PriceData struct { + ModelPrice float64 + ModelRatio float64 + CompletionRatio float64 + CacheRatio float64 + CacheCreationRatio float64 + ImageRatio float64 + UsePrice bool + ShouldPreConsumedQuota int + GroupRatioInfo GroupRatioInfo +} + +type PerCallPriceData struct { + ModelPrice float64 + Quota int + GroupRatioInfo GroupRatioInfo +} + +func (p PriceData) ToSetting() string { + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) +} diff --git a/types/relay_format.go b/types/relay_format.go new file mode 100644 index 000000000..6d94a70bc --- /dev/null +++ b/types/relay_format.go @@ -0,0 +1,18 @@ +package types + +type RelayFormat string + +const ( + RelayFormatOpenAI RelayFormat = "openai" + RelayFormatClaude = "claude" + RelayFormatGemini = "gemini" + RelayFormatOpenAIResponses = "openai_responses" + RelayFormatOpenAIAudio = "openai_audio" + RelayFormatOpenAIImage = "openai_image" + RelayFormatOpenAIRealtime = "openai_realtime" + RelayFormatRerank = "rerank" + RelayFormatEmbedding = "embedding" + + RelayFormatTask = "task" + RelayFormatMjProxy = "mj_proxy" +) diff --git a/types/request_meta.go b/types/request_meta.go new file mode 100644 index 000000000..18f80832b --- /dev/null +++ b/types/request_meta.go @@ -0,0 +1,46 @@ +package types + +type FileType string + +const ( + FileTypeImage FileType = "image" // Image file type + FileTypeAudio FileType = "audio" // Audio file type + FileTypeVideo FileType = "video" // Video file type + FileTypeFile FileType = "file" // Generic file type +) + +type TokenType string + +const ( + TokenTypeTextNumber TokenType = "text_number" // Text or number tokens + TokenTypeTokenizer TokenType = "tokenizer" // Tokenizer tokens + TokenTypeImage TokenType = "image" // Image tokens +) + +type TokenCountMeta struct { + TokenType TokenType `json:"token_type,omitempty"` // Type of tokens used in the request + CombineText string `json:"combine_text,omitempty"` // Combined text from all messages + ToolsCount int `json:"tools_count,omitempty"` // Number of tools used + NameCount int `json:"name_count,omitempty"` // Number of names in the request + MessagesCount int `json:"messages_count,omitempty"` // Number of messages in the request + Files []*FileMeta `json:"files,omitempty"` // List of files, each with type and content + MaxTokens int `json:"max_tokens,omitempty"` // Maximum tokens allowed in the request + + ImagePriceRatio float64 `json:"image_ratio,omitempty"` // Ratio for image size, if applicable + //IsStreaming bool `json:"is_streaming,omitempty"` // Indicates if the request is streaming +} + +type FileMeta struct { + FileType + MimeType string + OriginData string // url or base64 data + Detail string + ParsedData *LocalFileData +} + +type RequestMeta struct { + OriginalModelName string `json:"original_model_name"` + UserUsingGroup string `json:"user_using_group"` + PromptTokens int `json:"prompt_tokens"` + PreConsumedQuota int `json:"pre_consumed_quota"` +} diff --git a/web/.eslintrc.cjs b/web/.eslintrc.cjs new file mode 100644 index 000000000..b1afd96f5 --- /dev/null +++ b/web/.eslintrc.cjs @@ -0,0 +1,42 @@ +module.exports = { + root: true, + env: { browser: true, es2021: true, node: true }, + parserOptions: { + ecmaVersion: 2020, + sourceType: 'module', + ecmaFeatures: { jsx: true }, + }, + plugins: ['header', 'react-hooks'], + overrides: [ + { + files: ['**/*.{js,jsx}'], + rules: { + 'header/header': [ + 2, + 'block', + [ + '', + 'Copyright (C) 2025 QuantumNous', + '', + 'This program is free software: you can redistribute it and/or modify', + 'it under the terms of the GNU Affero General Public License as', + 'published by the Free Software Foundation, either version 3 of the', + 'License, or (at your option) any later version.', + '', + 'This program is distributed in the hope that it will be useful,', + 'but WITHOUT ANY WARRANTY; without even the implied warranty of', + 'MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the', + 'GNU Affero General Public License for more details.', + '', + 'You should have received a copy of the GNU Affero General Public License', + 'along with this program. If not, see .', + '', + 'For commercial licensing, please contact support@quantumnous.com', + '', + ], + ], + 'no-multiple-empty-lines': ['error', { max: 1 }], + }, + }, + ], +}; diff --git a/web/bun.lock b/web/bun.lock index b78c149bf..53467aa5e 100644 --- a/web/bun.lock +++ b/web/bun.lock @@ -21,6 +21,7 @@ "lucide-react": "^0.511.0", "marked": "^4.1.1", "mermaid": "^11.6.0", + "qrcode.react": "^4.2.0", "react": "^18.2.0", "react-dom": "^18.2.0", "react-dropzone": "^14.2.3", @@ -46,6 +47,9 @@ "@so1ve/prettier-config": "^3.1.0", "@vitejs/plugin-react": "^4.2.1", "autoprefixer": "^10.4.21", + "eslint": "8.57.0", + "eslint-plugin-header": "^3.1.1", + "eslint-plugin-react-hooks": "^5.2.0", "postcss": "^8.5.3", "prettier": "^3.0.0", "tailwindcss": "^3", @@ -237,6 +241,14 @@ "@esbuild/win32-x64": ["@esbuild/win32-x64@0.21.5", "", { "os": "win32", "cpu": "x64" }, "sha512-tQd/1efJuzPC6rCFwEvLtci/xNFcTZknmXs98FYDfGE4wP9ClFV98nyKrzJKVPMhdDnjzLhdUyMX4PsQAPjwIw=="], + "@eslint-community/eslint-utils": ["@eslint-community/eslint-utils@4.7.0", "", { "dependencies": { "eslint-visitor-keys": "^3.4.3" }, "peerDependencies": { "eslint": "^6.0.0 || ^7.0.0 || >=8.0.0" } }, "sha512-dyybb3AcajC7uha6CvhdVRJqaKyn7w2YKqKyAN37NKYgZT36w+iRb0Dymmc5qEJ549c/S31cMMSFd75bteCpCw=="], + + "@eslint-community/regexpp": ["@eslint-community/regexpp@4.12.1", "", {}, "sha512-CCZCDJuduB9OUkFkY2IgppNZMi2lBQgD2qzwXkEia16cge2pijY/aXi96CJMquDMn3nJdlPV1A5KrJEXwfLNzQ=="], + + "@eslint/eslintrc": ["@eslint/eslintrc@2.1.4", "", { "dependencies": { "ajv": "^6.12.4", "debug": "^4.3.2", "espree": "^9.6.0", "globals": "^13.19.0", "ignore": "^5.2.0", "import-fresh": "^3.2.1", "js-yaml": "^4.1.0", "minimatch": "^3.1.2", "strip-json-comments": "^3.1.1" } }, "sha512-269Z39MS6wVJtsoUl10L60WdkhJVdPG24Q4eZTH3nnF6lpvSShEK3wQjDX9JRWAUPvPh7COouPpU9IrqaZFvtQ=="], + + "@eslint/js": ["@eslint/js@8.57.0", "", {}, "sha512-Ys+3g2TaW7gADOJzPt83SJtCDhMjndcDMFVQ/Tj9iA1BfJzFKD9mAUXT3OenpuPHbI6P/myECxRJrofUsDx/5g=="], + "@floating-ui/core": ["@floating-ui/core@1.7.0", "", { "dependencies": { "@floating-ui/utils": "^0.2.9" } }, "sha512-FRdBLykrPPA6P76GGGqlex/e7fbe0F1ykgxHYNXQsH/iTEtjMj/f9bpY5oQqbjt5VgZvgz/uKXbGuROijh3VLA=="], "@floating-ui/dom": ["@floating-ui/dom@1.7.0", "", { "dependencies": { "@floating-ui/core": "^1.7.0", "@floating-ui/utils": "^0.2.9" } }, "sha512-lGTor4VlXcesUMh1cupTUTDoCxMb0V6bm3CnxHzQcw8Eaf1jQbgQX4i02fYgT0vJ82tb5MZ4CZk1LRGkktJCzg=="], @@ -249,6 +261,12 @@ "@giscus/react": ["@giscus/react@3.1.0", "", { "dependencies": { "giscus": "^1.6.0" }, "peerDependencies": { "react": "^16 || ^17 || ^18 || ^19", "react-dom": "^16 || ^17 || ^18 || ^19" } }, "sha512-0TCO2TvL43+oOdyVVGHDItwxD1UMKP2ZYpT6gXmhFOqfAJtZxTzJ9hkn34iAF/b6YzyJ4Um89QIt9z/ajmAEeg=="], + "@humanwhocodes/config-array": ["@humanwhocodes/config-array@0.11.14", "", { "dependencies": { "@humanwhocodes/object-schema": "^2.0.2", "debug": "^4.3.1", "minimatch": "^3.0.5" } }, "sha512-3T8LkOmg45BV5FICb15QQMsyUSWrQ8AygVfC7ZG32zOalnqrilm018ZVCw0eapXux8FtA33q8PSRSstjee3jSg=="], + + "@humanwhocodes/module-importer": ["@humanwhocodes/module-importer@1.0.1", "", {}, "sha512-bxveV4V8v5Yb4ncFTT3rPSgZBOpCkjfK0y4oVVVJwIuDVBRMDXrPyXRL988i5ap9m9bnyEEjWfm5WkBmtffLfA=="], + + "@humanwhocodes/object-schema": ["@humanwhocodes/object-schema@2.0.3", "", {}, "sha512-93zYdMES/c1D69yZiKDBj0V24vqNzB/koF26KPaagAfd3P/4gUlh3Dys5ogAK+Exi9QyzlD8x/08Zt7wIKcDcA=="], + "@iconify/types": ["@iconify/types@2.0.0", "", {}, "sha512-+wluvCrRhXrhyOmRDJ3q8mux9JkKy5SJ/v8ol2tu4FVjyYvtEzkc/3pK15ET6RKg4b4w4BmTk1+gsCUhf21Ykg=="], "@iconify/utils": ["@iconify/utils@2.3.0", "", { "dependencies": { "@antfu/install-pkg": "^1.0.0", "@antfu/utils": "^8.1.0", "@iconify/types": "^2.0.0", "debug": "^4.4.0", "globals": "^15.14.0", "kolorist": "^1.8.0", "local-pkg": "^1.0.0", "mlly": "^1.7.4" } }, "sha512-GmQ78prtwYW6EtzXRU1rY+KwOKfz32PD7iJh6Iyqw68GiKuoZ2A6pRtzWONz5VQJbp50mEjXh/7NkumtrAgRKA=="], @@ -629,15 +647,17 @@ "abs-svg-path": ["abs-svg-path@0.1.1", "", {}, "sha512-d8XPSGjfyzlXC3Xx891DJRyZfqk5JU0BJrDQcsWomFIV1/BIzPW5HDH5iDdWpqWaav0YVIEzT1RHTwWr0FFshA=="], - "acorn": ["acorn@8.14.0", "", { "bin": { "acorn": "bin/acorn" } }, "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA=="], + "acorn": ["acorn@8.15.0", "", { "bin": { "acorn": "bin/acorn" } }, "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg=="], "acorn-jsx": ["acorn-jsx@5.3.2", "", { "peerDependencies": { "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" } }, "sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ=="], "ahooks": ["ahooks@3.8.5", "", { "dependencies": { "@babel/runtime": "^7.21.0", "dayjs": "^1.9.1", "intersection-observer": "^0.12.0", "js-cookie": "^3.0.5", "lodash": "^4.17.21", "react-fast-compare": "^3.2.2", "resize-observer-polyfill": "^1.5.1", "screenfull": "^5.0.0", "tslib": "^2.4.1" }, "peerDependencies": { "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-Y+MLoJpBXVdjsnnBjE5rOSPkQ4DK+8i5aPDzLJdIOsCpo/fiAeXcBY1Y7oWgtOK0TpOz0gFa/XcyO1UGdoqLcw=="], - "ansi-regex": ["ansi-regex@6.1.0", "", {}, "sha512-7HSX4QQb4CspciLpVFwyRe79O3xsIZDDLER21kERQ71oaPodF8jL725AgJMFAYbooIqolJoRLuM81SpeUkpkvA=="], + "ajv": ["ajv@6.12.6", "", { "dependencies": { "fast-deep-equal": "^3.1.1", "fast-json-stable-stringify": "^2.0.0", "json-schema-traverse": "^0.4.1", "uri-js": "^4.2.2" } }, "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g=="], - "ansi-styles": ["ansi-styles@6.2.1", "", {}, "sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug=="], + "ansi-regex": ["ansi-regex@5.0.1", "", {}, "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ=="], + + "ansi-styles": ["ansi-styles@4.3.0", "", { "dependencies": { "color-convert": "^2.0.1" } }, "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg=="], "antd": ["antd@5.25.2", "", { "dependencies": { "@ant-design/colors": "^7.2.0", "@ant-design/cssinjs": "^1.23.0", "@ant-design/cssinjs-utils": "^1.1.3", "@ant-design/fast-color": "^2.0.6", "@ant-design/icons": "^5.6.1", "@ant-design/react-slick": "~1.1.2", "@babel/runtime": "^7.26.0", "@rc-component/color-picker": "~2.0.1", "@rc-component/mutate-observer": "^1.1.0", "@rc-component/qrcode": "~1.0.0", "@rc-component/tour": "~1.15.1", "@rc-component/trigger": "^2.2.6", "classnames": "^2.5.1", "copy-to-clipboard": "^3.3.3", "dayjs": "^1.11.11", "rc-cascader": "~3.34.0", "rc-checkbox": "~3.5.0", "rc-collapse": "~3.9.0", "rc-dialog": "~9.6.0", "rc-drawer": "~7.2.0", "rc-dropdown": "~4.2.1", "rc-field-form": "~2.7.0", "rc-image": "~7.12.0", "rc-input": "~1.8.0", "rc-input-number": "~9.5.0", "rc-mentions": "~2.20.0", "rc-menu": "~9.16.1", "rc-motion": "^2.9.5", "rc-notification": "~5.6.4", "rc-pagination": "~5.1.0", "rc-picker": "~4.11.3", "rc-progress": "~4.0.0", "rc-rate": "~2.13.1", "rc-resize-observer": "^1.4.3", "rc-segmented": "~2.7.0", "rc-select": "~14.16.8", "rc-slider": "~11.1.8", "rc-steps": "~6.0.1", "rc-switch": "~4.1.0", "rc-table": "~7.50.5", "rc-tabs": "~15.6.1", "rc-textarea": "~1.10.0", "rc-tooltip": "~6.4.0", "rc-tree": "~5.13.1", "rc-tree-select": "~5.27.0", "rc-upload": "~4.9.0", "rc-util": "^5.44.4", "scroll-into-view-if-needed": "^3.1.0", "throttle-debounce": "^5.0.2" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-7R2nUvlHhey7Trx64+hCtGXOiy+DTUs1Lv5bwbV1LzEIZIhWb0at1AM6V3K108a5lyoR9n7DX3ptlLF7uYV/DQ=="], @@ -649,6 +669,8 @@ "arg": ["arg@5.0.2", "", {}, "sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg=="], + "argparse": ["argparse@2.0.1", "", {}, "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q=="], + "array-source": ["array-source@0.0.4", "", {}, "sha512-frNdc+zBn80vipY+GdcJkLEbMWj3xmzArYApmUGxoiV8uAu/ygcs9icPdsGdA26h0MkHUMW6EN2piIvVx+M5Mw=="], "assign-symbols": ["assign-symbols@1.0.0", "", {}, "sha512-Q+JC7Whu8HhmTdBph/Tq59IoRtoy6KAm5zzPv00WdujX82lbAL8K7WVjne7vdCsAmbF4AYaDOPyO3k0kl8qIrw=="], @@ -699,6 +721,8 @@ "ccount": ["ccount@2.0.1", "", {}, "sha512-eyrF0jiFpY+3drT6383f1qhkbGsLSifNAjA61IUjZjmLCWjItY6LB9ft9YhoDgwfmclB2zhu51Lc7+95b8NRAg=="], + "chalk": ["chalk@4.1.2", "", { "dependencies": { "ansi-styles": "^4.1.0", "supports-color": "^7.1.0" } }, "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA=="], + "character-entities": ["character-entities@2.0.2", "", {}, "sha512-shx7oQ0Awen/BRIdkjkvz54PnEEI/EjwXDSIZp86/KKdbafHh1Df/RYGBhn4hbe2+uKC9FnT5UCEdyPz3ai9hQ=="], "character-entities-html4": ["character-entities-html4@2.1.0", "", {}, "sha512-1v7fgQRj6hnSwFpq1Eu0ynr/CDEw0rXo2B61qXrLNdHZmPKgb7fqS1a2JwF0rISo9q77jDI8VMEHoApn8qDoZA=="], @@ -851,6 +875,8 @@ "decode-uri-component": ["decode-uri-component@0.4.1", "", {}, "sha512-+8VxcR21HhTy8nOt6jf20w0c9CADrw1O8d+VZ/YzzCt4bJ3uBjw+D1q2osAB8RnpwwaeYBxy0HyKQxD5JBMuuQ=="], + "deep-is": ["deep-is@0.1.4", "", {}, "sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ=="], + "delaunator": ["delaunator@5.0.1", "", { "dependencies": { "robust-predicates": "^3.0.2" } }, "sha512-8nvh+XBe96aCESrGOqMp/84b13H9cdKbG5P2ejQCh4d4sK9RL4371qou9drQjMhvnPmhWl5hnmqbEE0fXr9Xnw=="], "delayed-stream": ["delayed-stream@1.0.0", "", {}, "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ=="], @@ -865,6 +891,8 @@ "dlv": ["dlv@1.1.3", "", {}, "sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA=="], + "doctrine": ["doctrine@3.0.0", "", { "dependencies": { "esutils": "^2.0.2" } }, "sha512-yS+Q5i3hBf7GBkd4KG8a7eBNNWNGLTaEwwYWUijIYM7zrlYDM0BFXHjjPWlWZ1Rg7UaddZeIDmi9jF3HmqiQ2w=="], + "dompurify": ["dompurify@3.2.6", "", { "optionalDependencies": { "@types/trusted-types": "^2.0.7" } }, "sha512-/2GogDQlohXPZe6D6NOgQvXLPSYBqIWMnZ8zzOhn09REE4eyAzb+Hed3jhoM9OkuaJ8P6ZGTTVWQKAi8ieIzfQ=="], "eastasianwidth": ["eastasianwidth@0.2.0", "", {}, "sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA=="], @@ -887,7 +915,25 @@ "escalade": ["escalade@3.2.0", "", {}, "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA=="], - "escape-string-regexp": ["escape-string-regexp@5.0.0", "", {}, "sha512-/veY75JbMK4j1yjvuUxuVsiS/hr/4iHs9FTT6cgTexxdE0Ly/glccBAkloH/DofkjRbZU3bnoj38mOmhkZ0lHw=="], + "escape-string-regexp": ["escape-string-regexp@4.0.0", "", {}, "sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA=="], + + "eslint": ["eslint@8.57.0", "", { "dependencies": { "@eslint-community/eslint-utils": "^4.2.0", "@eslint-community/regexpp": "^4.6.1", "@eslint/eslintrc": "^2.1.4", "@eslint/js": "8.57.0", "@humanwhocodes/config-array": "^0.11.14", "@humanwhocodes/module-importer": "^1.0.1", "@nodelib/fs.walk": "^1.2.8", "@ungap/structured-clone": "^1.2.0", "ajv": "^6.12.4", "chalk": "^4.0.0", "cross-spawn": "^7.0.2", "debug": "^4.3.2", "doctrine": "^3.0.0", "escape-string-regexp": "^4.0.0", "eslint-scope": "^7.2.2", "eslint-visitor-keys": "^3.4.3", "espree": "^9.6.1", "esquery": "^1.4.2", "esutils": "^2.0.2", "fast-deep-equal": "^3.1.3", "file-entry-cache": "^6.0.1", "find-up": "^5.0.0", "glob-parent": "^6.0.2", "globals": "^13.19.0", "graphemer": "^1.4.0", "ignore": "^5.2.0", "imurmurhash": "^0.1.4", "is-glob": "^4.0.0", "is-path-inside": "^3.0.3", "js-yaml": "^4.1.0", "json-stable-stringify-without-jsonify": "^1.0.1", "levn": "^0.4.1", "lodash.merge": "^4.6.2", "minimatch": "^3.1.2", "natural-compare": "^1.4.0", "optionator": "^0.9.3", "strip-ansi": "^6.0.1", "text-table": "^0.2.0" }, "bin": { "eslint": "bin/eslint.js" } }, "sha512-dZ6+mexnaTIbSBZWgou51U6OmzIhYM2VcNdtiTtI7qPNZm35Akpr0f6vtw3w1Kmn5PYo+tZVfh13WrhpS6oLqQ=="], + + "eslint-plugin-header": ["eslint-plugin-header@3.1.1", "", { "peerDependencies": { "eslint": ">=7.7.0" } }, "sha512-9vlKxuJ4qf793CmeeSrZUvVClw6amtpghq3CuWcB5cUNnWHQhgcqy5eF8oVKFk1G3Y/CbchGfEaw3wiIJaNmVg=="], + + "eslint-plugin-react-hooks": ["eslint-plugin-react-hooks@5.2.0", "", { "peerDependencies": { "eslint": "^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0-0 || ^9.0.0" } }, "sha512-+f15FfK64YQwZdJNELETdn5ibXEUQmW1DZL6KXhNnc2heoy/sg9VJJeT7n8TlMWouzWqSWavFkIhHyIbIAEapg=="], + + "eslint-scope": ["eslint-scope@7.2.2", "", { "dependencies": { "esrecurse": "^4.3.0", "estraverse": "^5.2.0" } }, "sha512-dOt21O7lTMhDM+X9mB4GX+DZrZtCUJPL/wlcTqxyrx5IvO0IYtILdtrQGQp+8n5S0gwSVmOf9NQrjMOgfQZlIg=="], + + "eslint-visitor-keys": ["eslint-visitor-keys@3.4.3", "", {}, "sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag=="], + + "espree": ["espree@9.6.1", "", { "dependencies": { "acorn": "^8.9.0", "acorn-jsx": "^5.3.2", "eslint-visitor-keys": "^3.4.1" } }, "sha512-oruZaFkjorTpF32kDSI5/75ViwGeZginGGy2NoOSg3Q9bnwlnmDm4HLnkl0RE3n+njDXR037aY1+x58Z/zFdwQ=="], + + "esquery": ["esquery@1.6.0", "", { "dependencies": { "estraverse": "^5.1.0" } }, "sha512-ca9pw9fomFcKPvFLXhBKUK90ZvGibiGOvRJNbjljY7s7uq/5YO4BOzcYtJqExdx99rF6aAcnRxHmcUHcz6sQsg=="], + + "esrecurse": ["esrecurse@4.3.0", "", { "dependencies": { "estraverse": "^5.2.0" } }, "sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag=="], + + "estraverse": ["estraverse@5.3.0", "", {}, "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA=="], "estree-util-attach-comments": ["estree-util-attach-comments@3.0.0", "", { "dependencies": { "@types/estree": "^1.0.0" } }, "sha512-cKUwm/HUcTDsYh/9FgnuFqpfquUbwIqwKM26BVCGDPVgvaCl/nDCCjUfiLlx6lsEZ3Z4RFxNbOQ60pkaEwFxGw=="], @@ -903,6 +949,8 @@ "estree-walker": ["estree-walker@3.0.3", "", { "dependencies": { "@types/estree": "^1.0.0" } }, "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g=="], + "esutils": ["esutils@2.0.3", "", {}, "sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g=="], + "eventemitter3": ["eventemitter3@4.0.7", "", {}, "sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw=="], "exsolve": ["exsolve@1.0.5", "", {}, "sha512-pz5dvkYYKQ1AHVrgOzBKWeP4u4FRb3a6DNK2ucr0OoNwYIU4QWsJ+NM36LLzORT+z845MzKHHhpXiUF5nvQoJg=="], @@ -917,8 +965,14 @@ "fast-glob": ["fast-glob@3.3.3", "", { "dependencies": { "@nodelib/fs.stat": "^2.0.2", "@nodelib/fs.walk": "^1.2.3", "glob-parent": "^5.1.2", "merge2": "^1.3.0", "micromatch": "^4.0.8" } }, "sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg=="], + "fast-json-stable-stringify": ["fast-json-stable-stringify@2.1.0", "", {}, "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw=="], + + "fast-levenshtein": ["fast-levenshtein@2.0.6", "", {}, "sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw=="], + "fastq": ["fastq@1.19.1", "", { "dependencies": { "reusify": "^1.0.4" } }, "sha512-GwLTyxkCXjXbxqIhTsMI2Nui8huMPtnxg7krajPJAjnEG/iiOS7i+zCtWGZR9G0NBKbXKh6X9m9UIsYX/N6vvQ=="], + "file-entry-cache": ["file-entry-cache@6.0.1", "", { "dependencies": { "flat-cache": "^3.0.4" } }, "sha512-7Gps/XWymbLk2QLYK4NzpMOrYjMhdIxXuIvy2QBsLE6ljuodKvdkWs/cpyJJ3CVIVpH0Oi1Hvg1ovbMzLdFBBg=="], + "file-selector": ["file-selector@2.1.2", "", { "dependencies": { "tslib": "^2.7.0" } }, "sha512-QgXo+mXTe8ljeqUFaX3QVHc5osSItJ/Km+xpocx0aSqWGMSCf6qYs/VnzZgS864Pjn5iceMRFigeAV7AfTlaig=="], "file-source": ["file-source@0.6.1", "", { "dependencies": { "stream-source": "0.3" } }, "sha512-1R1KneL7eTXmXfKxC10V/9NeGOdbsAXJ+lQ//fvvcHUgtaZcZDWNJNblxAoVOyV1cj45pOtUrR3vZTBwqcW8XA=="], @@ -929,6 +983,12 @@ "find-root": ["find-root@1.1.0", "", {}, "sha512-NKfW6bec6GfKc0SGx1e07QZY9PE99u0Bft/0rzSD5k3sO/vwkVUpDUKVm5Gpp5Ue3YfShPFTX2070tDs5kB9Ng=="], + "find-up": ["find-up@5.0.0", "", { "dependencies": { "locate-path": "^6.0.0", "path-exists": "^4.0.0" } }, "sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng=="], + + "flat-cache": ["flat-cache@3.2.0", "", { "dependencies": { "flatted": "^3.2.9", "keyv": "^4.5.3", "rimraf": "^3.0.2" } }, "sha512-CYcENa+FtcUKLmhhqyctpclsq7QF38pKjZHsGNiSQF5r4FtoKDWabFDl3hzaEQMvT1LHEysw5twgLvpYYb4vbw=="], + + "flatted": ["flatted@3.3.3", "", {}, "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg=="], + "follow-redirects": ["follow-redirects@1.15.9", "", {}, "sha512-gew4GsXizNgdoRyqmyfMHyAmXsZDk6mHkSxZFCzW9gwlbtOW44CDtYavM+y+72qD/Vq2l550kMF52DT8fOLJqQ=="], "for-in": ["for-in@1.0.2", "", {}, "sha512-7EwmXrOjyL+ChxMhmG5lnW9MPt1aIeZEwKhQzoBUdTV0N3zuwWDZYVJatDvZ2OyzPUvdIAZDsCetk3coyMfcnQ=="], @@ -969,12 +1029,16 @@ "glob-parent": ["glob-parent@6.0.2", "", { "dependencies": { "is-glob": "^4.0.3" } }, "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A=="], - "globals": ["globals@15.15.0", "", {}, "sha512-7ACyT3wmyp3I61S4fG682L0VA2RGD9otkqGJIwNUMF1SWUombIIk+af1unuDYgMm082aHYwD+mzJvv9Iu8dsgg=="], + "globals": ["globals@13.24.0", "", { "dependencies": { "type-fest": "^0.20.2" } }, "sha512-AhO5QUcj8llrbG09iWhPU2B204J1xnPeL8kQmVorSsy+Sjj1sk8gIyh6cUocGmH4L0UuhAJy+hJMRA4mgA4mFQ=="], "graceful-fs": ["graceful-fs@4.2.11", "", {}, "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ=="], + "graphemer": ["graphemer@1.4.0", "", {}, "sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag=="], + "hachure-fill": ["hachure-fill@0.5.2", "", {}, "sha512-3GKBOn+m2LX9iq+JC1064cSFprJY4jL1jCXTcpnfER5HYE2l/4EfWSGzkPa/ZDBmYI0ZOEj5VHV/eKnPGkHuOg=="], + "has-flag": ["has-flag@4.0.0", "", {}, "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ=="], + "hasown": ["hasown@2.0.2", "", { "dependencies": { "function-bind": "^1.1.2" } }, "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ=="], "hast-util-from-dom": ["hast-util-from-dom@5.0.1", "", { "dependencies": { "@types/hast": "^3.0.0", "hastscript": "^9.0.0", "web-namespaces": "^2.0.0" } }, "sha512-N+LqofjR2zuzTjCPzyDUdSshy4Ma6li7p/c3pA78uTwzFgENbgbUrm2ugwsOdcjI1muO+o6Dgzp9p8WHtn/39Q=="], @@ -1025,12 +1089,16 @@ "ieee754": ["ieee754@1.2.1", "", {}, "sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA=="], + "ignore": ["ignore@5.3.2", "", {}, "sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g=="], + "immer": ["immer@10.1.1", "", {}, "sha512-s2MPrmjovJcoMaHtx6K11Ra7oD05NT97w1IC5zpMkT6Atjr7H8LjaDd81iIxUYpMKSRRNMJE703M1Fhr/TctHw=="], "immutable": ["immutable@5.1.2", "", {}, "sha512-qHKXW1q6liAk1Oys6umoaZbDRqjcjgSrbnrifHsfsttza7zcvRAsL7mMV6xWcyhwQy7Xj5v4hhbr6b+iDYwlmQ=="], "import-fresh": ["import-fresh@3.3.0", "", { "dependencies": { "parent-module": "^1.0.0", "resolve-from": "^4.0.0" } }, "sha512-veYYhQa+D1QBKznvhUHxb8faxlrwUnxseDAbAp457E0wLNio2bOSKnjYDhMj+YiAq61xrMGhQk9iXVk5FzgQMw=="], + "imurmurhash": ["imurmurhash@0.1.4", "", {}, "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA=="], + "inflight": ["inflight@1.0.6", "", { "dependencies": { "once": "^1.3.0", "wrappy": "1" } }, "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA=="], "inherits": ["inherits@2.0.4", "", {}, "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ=="], @@ -1065,6 +1133,8 @@ "is-number": ["is-number@7.0.0", "", {}, "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng=="], + "is-path-inside": ["is-path-inside@3.0.3", "", {}, "sha512-Fd4gABb+ycGAmKou8eMftCupSir5lRxqf4aD/vd0cD2qc4HL07OjCeuHMr8Ro4CoMaeCKDB0/ECBOVWjTwUvPQ=="], + "is-plain-obj": ["is-plain-obj@4.1.0", "", {}, "sha512-+Pgi+vMuUNkJyExiMBt5IlFoMyKnr5zhJ4Uspz58WOhBF5QoIZkFyNHIbBAtHwzVAgk5RtndVNsDRN61/mmDqg=="], "is-plain-object": ["is-plain-object@2.0.4", "", { "dependencies": { "isobject": "^3.0.1" } }, "sha512-h5PpgXkWitc38BBMYawTYMWJHFZJVnBquFE57xFpjB8pJFiF6gZ+bU+WyI/yqXiFR5mdLsgYNaPe8uao6Uv9Og=="], @@ -1083,10 +1153,18 @@ "js-tokens": ["js-tokens@4.0.0", "", {}, "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ=="], + "js-yaml": ["js-yaml@4.1.0", "", { "dependencies": { "argparse": "^2.0.1" }, "bin": { "js-yaml": "bin/js-yaml.js" } }, "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA=="], + "jsesc": ["jsesc@3.1.0", "", { "bin": { "jsesc": "bin/jsesc" } }, "sha512-/sM3dO2FOzXjKQhJuo0Q173wf2KOo8t4I8vHy6lF9poUp7bKT0/NHE8fPX23PwfhnykfqnC2xRxOnVw5XuGIaA=="], + "json-buffer": ["json-buffer@3.0.1", "", {}, "sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ=="], + "json-parse-even-better-errors": ["json-parse-even-better-errors@2.3.1", "", {}, "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w=="], + "json-schema-traverse": ["json-schema-traverse@0.4.1", "", {}, "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg=="], + + "json-stable-stringify-without-jsonify": ["json-stable-stringify-without-jsonify@1.0.1", "", {}, "sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw=="], + "json2mq": ["json2mq@0.2.0", "", { "dependencies": { "string-convert": "^0.2.0" } }, "sha512-SzoRg7ux5DWTII9J2qkrZrqV1gt+rTaoufMxEzXbS26Uid0NwaJd123HcoB80TgubEppxxIGdNxCx50fEoEWQA=="], "json5": ["json5@2.2.3", "", { "bin": { "json5": "lib/cli.js" } }, "sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg=="], @@ -1097,6 +1175,8 @@ "katex": ["katex@0.16.22", "", { "dependencies": { "commander": "^8.3.0" }, "bin": { "katex": "cli.js" } }, "sha512-XCHRdUw4lf3SKBaJe4EvgqIuWwkPSo9XoeO8GjQW94Bp7TWv9hNhzZjZ+OH9yf1UmLygb7DIT5GSFQiyt16zYg=="], + "keyv": ["keyv@4.5.4", "", { "dependencies": { "json-buffer": "3.0.1" } }, "sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw=="], + "khroma": ["khroma@2.1.0", "", {}, "sha512-Ls993zuzfayK269Svk9hzpeGUKob/sIgZzyHYdjQoAdQetRKpOLj+k/QQQ/6Qi0Yz65mlROrfd+Ev+1+7dz9Kw=="], "kolorist": ["kolorist@1.8.0", "", {}, "sha512-Y+60/zizpJ3HRH8DCss+q95yr6145JXZo46OTpFvDZWLfRCE4qChOyk1b26nMaNpfHHgxagk9dXT5OP0Tfe+dQ=="], @@ -1107,6 +1187,8 @@ "leva": ["leva@0.10.0", "", { "dependencies": { "@radix-ui/react-portal": "1.0.2", "@radix-ui/react-tooltip": "1.0.5", "@stitches/react": "^1.2.8", "@use-gesture/react": "^10.2.5", "colord": "^2.9.2", "dequal": "^2.0.2", "merge-value": "^1.0.0", "react-colorful": "^5.5.1", "react-dropzone": "^12.0.0", "v8n": "^1.3.3", "zustand": "^3.6.9" }, "peerDependencies": { "react": "^18.0.0 || ^19.0.0", "react-dom": "^18.0.0 || ^19.0.0" } }, "sha512-RiNJWmeqQdKIeHuVXgshmxIHu144a2AMYtLxKf8Nm1j93pisDPexuQDHKNdQlbo37wdyDQibLjY9JKGIiD7gaw=="], + "levn": ["levn@0.4.1", "", { "dependencies": { "prelude-ls": "^1.2.1", "type-check": "~0.4.0" } }, "sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ=="], + "lilconfig": ["lilconfig@3.1.3", "", {}, "sha512-/vlFKAoH5Cgt3Ie+JLhRbwOsCQePABiU3tJ1egGvyQ+33R/vcwM2Zl2QR/LzjsBeItPt3oSVXapn+m4nQDvpzw=="], "lines-and-columns": ["lines-and-columns@1.2.4", "", {}, "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg=="], @@ -1119,12 +1201,16 @@ "local-pkg": ["local-pkg@1.1.1", "", { "dependencies": { "mlly": "^1.7.4", "pkg-types": "^2.0.1", "quansync": "^0.2.8" } }, "sha512-WunYko2W1NcdfAFpuLUoucsgULmgDBRkdxHxWQ7mK0cQqwPiy8E1enjuRBrhLtZkB5iScJ1XIPdhVEFK8aOLSg=="], + "locate-path": ["locate-path@6.0.0", "", { "dependencies": { "p-locate": "^5.0.0" } }, "sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw=="], + "lodash": ["lodash@4.17.21", "", {}, "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg=="], "lodash-es": ["lodash-es@4.17.21", "", {}, "sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw=="], "lodash.debounce": ["lodash.debounce@4.0.8", "", {}, "sha512-FT1yDzDYEoYWhnSGnpE/4Kj1fLZkDFyqRb7fNt6FdYOSxlUWAtp42Eh6Wb0rGIv/m9Bgo7x4GhQbm5Ys4SG5ow=="], + "lodash.merge": ["lodash.merge@4.6.2", "", {}, "sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ=="], + "longest-streak": ["longest-streak@3.1.0", "", {}, "sha512-9Ri+o0JYgehTaVBBDoMqIl8GXtbWg711O3srftcHhZ0dqnETqLaoIK0x17fUw9rFSlK/0NlsKe0Ahhyl5pXE2g=="], "loose-envify": ["loose-envify@1.4.0", "", { "dependencies": { "js-tokens": "^3.0.0 || ^4.0.0" }, "bin": { "loose-envify": "cli.js" } }, "sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q=="], @@ -1285,6 +1371,8 @@ "nanoid": ["nanoid@3.3.8", "", { "bin": { "nanoid": "bin/nanoid.cjs" } }, "sha512-WNLf5Sd8oZxOm+TzppcYk8gVOgP+l58xNy58D0nbUnOxOWRWvlcCV4kUF7ltmI6PsrLl/BgKEyS4mqsGChFN0w=="], + "natural-compare": ["natural-compare@1.4.0", "", {}, "sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw=="], + "node-addon-api": ["node-addon-api@7.1.1", "", {}, "sha512-5m3bsyrjFWE1xf7nz7YXdN4udnVtXK6/Yfgn5qnahL6bCkf2yKt4k3nuTKAtT4r3IG8JNR2ncsIMdZuAzJjHQQ=="], "node-releases": ["node-releases@2.0.19", "", {}, "sha512-xxOWJsBKtzAq7DY0J+DTzuz58K8e7sJbdgwkbMWQe8UYB6ekmsQ45q0M/tJDsGaZmbC+l7n57UV8Hl5tHxO9uw=="], @@ -1307,6 +1395,12 @@ "oniguruma-to-es": ["oniguruma-to-es@4.3.3", "", { "dependencies": { "oniguruma-parser": "^0.12.1", "regex": "^6.0.1", "regex-recursion": "^6.0.2" } }, "sha512-rPiZhzC3wXwE59YQMRDodUwwT9FZ9nNBwQQfsd1wfdtlKEyCdRV0avrTcSZ5xlIvGRVPd/cx6ZN45ECmS39xvg=="], + "optionator": ["optionator@0.9.4", "", { "dependencies": { "deep-is": "^0.1.3", "fast-levenshtein": "^2.0.6", "levn": "^0.4.1", "prelude-ls": "^1.2.1", "type-check": "^0.4.0", "word-wrap": "^1.2.5" } }, "sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g=="], + + "p-limit": ["p-limit@3.1.0", "", { "dependencies": { "yocto-queue": "^0.1.0" } }, "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ=="], + + "p-locate": ["p-locate@5.0.0", "", { "dependencies": { "p-limit": "^3.0.2" } }, "sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw=="], + "package-json-from-dist": ["package-json-from-dist@1.0.1", "", {}, "sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw=="], "package-manager-detector": ["package-manager-detector@1.3.0", "", {}, "sha512-ZsEbbZORsyHuO00lY1kV3/t72yp6Ysay6Pd17ZAlNGuGwmWDLCJxFpRs0IzfXfj1o4icJOkUEioexFHzyPurSQ=="], @@ -1327,6 +1421,8 @@ "path-data-parser": ["path-data-parser@0.1.0", "", {}, "sha512-NOnmBpt5Y2RWbuv0LMzsayp3lVylAHLPUTut412ZA3l+C4uw4ZVkQbjShYCQ8TCpUMdPapr4YjUqLYD6v68j+w=="], + "path-exists": ["path-exists@4.0.0", "", {}, "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w=="], + "path-is-absolute": ["path-is-absolute@1.0.1", "", {}, "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg=="], "path-key": ["path-key@3.1.1", "", {}, "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q=="], @@ -1375,6 +1471,8 @@ "postcss-value-parser": ["postcss-value-parser@4.2.0", "", {}, "sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ=="], + "prelude-ls": ["prelude-ls@1.2.1", "", {}, "sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g=="], + "prettier": ["prettier@3.4.2", "", { "bin": { "prettier": "bin/prettier.cjs" } }, "sha512-e9MewbtFo+Fevyuxn/4rrcDAaq0IYxPGLvObpQjiZBMAzB9IGmzlnG9RZy3FFas+eBMu2vA0CszMeduow5dIuQ=="], "prettier-package-json": ["prettier-package-json@2.8.0", "", { "dependencies": { "@types/parse-author": "^2.0.0", "commander": "^4.0.1", "cosmiconfig": "^7.0.0", "fs-extra": "^10.0.0", "glob": "^7.1.6", "minimatch": "^3.0.4", "parse-author": "^2.0.0", "sort-object-keys": "^1.1.3", "sort-order": "^1.0.1" }, "bin": { "prettier-package-json": "bin/prettier-package-json" } }, "sha512-WxtodH/wWavfw3MR7yK/GrS4pASEQ+iSTkdtSxPJWvqzG55ir5nvbLt9rw5AOiEcqqPCRM92WCtR1rk3TG3JSQ=="], @@ -1393,6 +1491,10 @@ "protocol-buffers-schema": ["protocol-buffers-schema@3.6.0", "", {}, "sha512-TdDRD+/QNdrCGCE7v8340QyuXd4kIWIgapsE2+n/SaGiSSbomYl4TjHlvIoCWRpE7wFt02EpB35VVA2ImcBVqw=="], + "punycode": ["punycode@2.3.1", "", {}, "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg=="], + + "qrcode.react": ["qrcode.react@4.2.0", "", { "peerDependencies": { "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-QpgqWi8rD9DsS9EP3z7BT+5lY5SFhsqGjpgW5DY/i3mK4M9DTBNz3ErMi8BWYEfI3L0d8GIbGmcdFAS1uIRGjA=="], + "quansync": ["quansync@0.2.10", "", {}, "sha512-t41VRkMYbkHyCYmOvx/6URnN80H7k4X0lLdBMGsz+maAwrJQYB1djpV6vHrQIBE0WBSGqhtEHrK9U3DWWH8v7A=="], "query-string": ["query-string@9.2.0", "", { "dependencies": { "decode-uri-component": "^0.4.1", "filter-obj": "^5.1.0", "split-on-first": "^3.0.0" } }, "sha512-YIRhrHujoQxhexwRLxfy3VSjOXmvZRd2nyw1PwL1UUqZ/ys1dEZd1+NSgXkne2l/4X/7OXkigEAuhTX0g/ivJQ=="], @@ -1403,7 +1505,7 @@ "rc-checkbox": ["rc-checkbox@3.5.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "^2.3.2", "rc-util": "^5.25.2" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-aOAQc3E98HteIIsSqm6Xk2FPKIER6+5vyEFMZfo73TqM+VVAIqOkHoPjgKLqSNtVLWScoaM7vY2ZrGEheI79yg=="], - "rc-collapse": ["rc-collapse@3.9.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-swDdz4QZ4dFTo4RAUMLL50qP0EY62N2kvmk2We5xYdRwcRn8WcYtuetCJpwpaCbUfUt5+huLpVxhvmnK+PHrkA=="], + "rc-collapse": ["rc-collapse@4.0.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-SwoOByE39/3oIokDs/BnkqI+ltwirZbP8HZdq1/3SkPSBi7xDdvWHTp7cpNI9ullozkR6mwTWQi6/E/9huQVrA=="], "rc-dialog": ["rc-dialog@9.6.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "@rc-component/portal": "^1.0.0-8", "classnames": "^2.2.6", "rc-motion": "^2.3.0", "rc-util": "^5.21.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-ApoVi9Z8PaCQg6FsUzS8yvBEQy0ZL2PkuvAgrmohPkN3okps5WZ5WQWPc1RNuiOKaAYv8B97ACdsFU5LizzCqg=="], @@ -1577,6 +1679,8 @@ "reusify": ["reusify@1.1.0", "", {}, "sha512-g6QUff04oZpHs0eG5p83rFLhHeV00ug/Yf9nZM6fLeUrPguBTkTQOdpAWWspMh55TZfVQDPaN3NQJfbVRAxdIw=="], + "rimraf": ["rimraf@3.0.2", "", { "dependencies": { "glob": "^7.1.3" }, "bin": { "rimraf": "bin.js" } }, "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA=="], + "robust-predicates": ["robust-predicates@3.0.2", "", {}, "sha512-IXgzBWvWQwE6PrDI05OvmXUIruQTcoMDzRsOd5CDvHCVLcLHMTSYvOK5Cm46kWqlV3yAbuSpBZdJ5oP5OUoStg=="], "rollup": ["rollup@4.30.0", "", { "dependencies": { "@types/estree": "1.0.6" }, "optionalDependencies": { "@rollup/rollup-android-arm-eabi": "4.30.0", "@rollup/rollup-android-arm64": "4.30.0", "@rollup/rollup-darwin-arm64": "4.30.0", "@rollup/rollup-darwin-x64": "4.30.0", "@rollup/rollup-freebsd-arm64": "4.30.0", "@rollup/rollup-freebsd-x64": "4.30.0", "@rollup/rollup-linux-arm-gnueabihf": "4.30.0", "@rollup/rollup-linux-arm-musleabihf": "4.30.0", "@rollup/rollup-linux-arm64-gnu": "4.30.0", "@rollup/rollup-linux-arm64-musl": "4.30.0", "@rollup/rollup-linux-loongarch64-gnu": "4.30.0", "@rollup/rollup-linux-powerpc64le-gnu": "4.30.0", "@rollup/rollup-linux-riscv64-gnu": "4.30.0", "@rollup/rollup-linux-s390x-gnu": "4.30.0", "@rollup/rollup-linux-x64-gnu": "4.30.0", "@rollup/rollup-linux-x64-musl": "4.30.0", "@rollup/rollup-win32-arm64-msvc": "4.30.0", "@rollup/rollup-win32-ia32-msvc": "4.30.0", "@rollup/rollup-win32-x64-msvc": "4.30.0", "fsevents": "~2.3.2" }, "bin": { "rollup": "dist/bin/rollup" } }, "sha512-sDnr1pcjTgUT69qBksNF1N1anwfbyYG6TBQ22b03bII8EdiUQ7J0TlozVaTMjT/eEJAO49e1ndV7t+UZfL1+vA=="], @@ -1655,10 +1759,12 @@ "stringify-entities": ["stringify-entities@4.0.4", "", { "dependencies": { "character-entities-html4": "^2.0.0", "character-entities-legacy": "^3.0.0" } }, "sha512-IwfBptatlO+QCJUo19AqvrPNqlVMpW9YEL2LIVY+Rpv2qsjCGxaDLNRgeGsQWJhfItebuJhsGSLjaBbNSQ+ieg=="], - "strip-ansi": ["strip-ansi@7.1.0", "", { "dependencies": { "ansi-regex": "^6.0.1" } }, "sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ=="], + "strip-ansi": ["strip-ansi@6.0.1", "", { "dependencies": { "ansi-regex": "^5.0.1" } }, "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A=="], "strip-ansi-cjs": ["strip-ansi@6.0.1", "", { "dependencies": { "ansi-regex": "^5.0.1" } }, "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A=="], + "strip-json-comments": ["strip-json-comments@3.1.1", "", {}, "sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig=="], + "style-to-object": ["style-to-object@1.0.8", "", { "dependencies": { "inline-style-parser": "0.2.4" } }, "sha512-xT47I/Eo0rwJmaXC4oilDGDWLohVhR6o/xAQcPQN8q6QBuZVL8qMYL85kLmST5cPjAorwvqIA4qXTRQoYHaL6g=="], "stylis": ["stylis@4.3.6", "", {}, "sha512-yQ3rwFWRfwNUY7H5vpU0wfdkNSnvnJinhF9830Swlaxl03zsOjCfmX0ugac+3LtK0lYSgwL/KXc8oYL3mG4YFQ=="], @@ -1667,6 +1773,8 @@ "suf-log": ["suf-log@2.5.3", "", { "dependencies": { "s.color": "0.0.15" } }, "sha512-KvC8OPjzdNOe+xQ4XWJV2whQA0aM1kGVczMQ8+dStAO6KfEB140JEVQ9dE76ONZ0/Ylf67ni4tILPJB41U0eow=="], + "supports-color": ["supports-color@7.2.0", "", { "dependencies": { "has-flag": "^4.0.0" } }, "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw=="], + "supports-preserve-symlinks-flag": ["supports-preserve-symlinks-flag@1.0.0", "", {}, "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w=="], "swr": ["swr@2.3.3", "", { "dependencies": { "dequal": "^2.0.3", "use-sync-external-store": "^1.4.0" }, "peerDependencies": { "react": "^16.11.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-dshNvs3ExOqtZ6kJBaAsabhPdHyeY4P2cKwRCniDVifBMoG/SVI7tfLWqPXriVspf2Rg4tPzXJTnwaihIeFw2A=="], @@ -1677,6 +1785,8 @@ "text-encoding": ["text-encoding@0.6.4", "", {}, "sha512-hJnc6Qg3dWoOMkqP53F0dzRIgtmsAge09kxUIqGrEUS4qr5rWLckGYaQAVr+opBrIMRErGgy6f5aPnyPpyGRfg=="], + "text-table": ["text-table@0.2.0", "", {}, "sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw=="], + "thenify": ["thenify@3.3.1", "", { "dependencies": { "any-promise": "^1.0.0" } }, "sha512-RVZSIV5IG10Hk3enotrhvz0T9em6cyHBLkH/YAZuKqd8hRkKhSfCGIcP2KUY0EPxndzANBmNllzWPwak+bheSw=="], "thenify-all": ["thenify-all@1.6.0", "", { "dependencies": { "thenify": ">= 3.1.0 < 4" } }, "sha512-RNxQH/qI8/t3thXJDwcstUO4zeqo64+Uy/+sNVRBx4Xn2OX+OZ9oP+iJnNFqplFra2ZUVeKCSa2oVWi3T4uVmA=="], @@ -1705,6 +1815,10 @@ "tslib": ["tslib@2.8.1", "", {}, "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w=="], + "type-check": ["type-check@0.4.0", "", { "dependencies": { "prelude-ls": "^1.2.1" } }, "sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew=="], + + "type-fest": ["type-fest@0.20.2", "", {}, "sha512-Ne+eE4r0/iWnpAxD852z3A+N0Bt5RN//NjJwRd2VFHEmrywxf5vsZlh4R6lixl6B+wz/8d+maTSAkN1FIkI3LQ=="], + "typedarray": ["typedarray@0.0.6", "", {}, "sha512-/aCDEGatGvZ2BIk+HmLf4ifCJFwvKFNb9/JeZPMulfgFracn9QFcAf5GO8B/mweUjSoblS5In0cWhqpfs/5PQA=="], "typescript": ["typescript@4.4.2", "", { "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" } }, "sha512-gzP+t5W4hdy4c+68bfcv0t400HVJMMd2+H9B7gae1nQlBzCqvrXX+6GL/b3GAgyTH966pzrZ70/fRjwAtZksSQ=="], @@ -1733,6 +1847,8 @@ "update-browserslist-db": ["update-browserslist-db@1.1.3", "", { "dependencies": { "escalade": "^3.2.0", "picocolors": "^1.1.1" }, "peerDependencies": { "browserslist": ">= 4.21.0" }, "bin": { "update-browserslist-db": "cli.js" } }, "sha512-UxhIZQ+QInVdunkDAaiazvvT/+fXL5Osr0JZlJulepYu6Jd7qJtDZjlur0emRlT71EN3ScPoE7gvsuIKKNavKw=="], + "uri-js": ["uri-js@4.4.1", "", { "dependencies": { "punycode": "^2.1.0" } }, "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg=="], + "url-join": ["url-join@5.0.0", "", {}, "sha512-n2huDr9h9yzd6exQVnH/jU5mr+Pfx08LRXXZhkLLetAMESRj+anQsTAh940iMrIetKAmry9coFuZQ2jY8/p3WA=="], "use-debounce": ["use-debounce@10.0.4", "", { "peerDependencies": { "react": "*" } }, "sha512-6Cf7Yr7Wk7Kdv77nnJMf6de4HuDE4dTxKij+RqE9rufDsI6zsbjyAxcH5y2ueJCQAnfgKbzXbZHYlkFwmBlWkw=="], @@ -1777,6 +1893,8 @@ "which": ["which@2.0.2", "", { "dependencies": { "isexe": "^2.0.0" }, "bin": { "node-which": "./bin/node-which" } }, "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA=="], + "word-wrap": ["word-wrap@1.2.5", "", {}, "sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA=="], + "wrap-ansi": ["wrap-ansi@8.1.0", "", { "dependencies": { "ansi-styles": "^6.1.0", "string-width": "^5.0.1", "strip-ansi": "^7.0.1" } }, "sha512-si7QWI6zUMq56bESFvagtmzMdGOtoxfR+Sez11Mobfc7tm+VkUckk9bW2UeffTGVUbOksxmSw0AA2gs8g71NCQ=="], "wrap-ansi-cjs": ["wrap-ansi@7.0.0", "", { "dependencies": { "ansi-styles": "^4.0.0", "string-width": "^4.1.0", "strip-ansi": "^6.0.0" } }, "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q=="], @@ -1787,6 +1905,8 @@ "yaml": ["yaml@2.8.0", "", { "bin": { "yaml": "bin.mjs" } }, "sha512-4lLa/EcQCB0cJkyts+FpIRx5G/llPxfP6VQU5KByHEhLxY3IJCH0f0Hy1MHI8sClTvsIb8qwRJ6R/ZdlDJ/leQ=="], + "yocto-queue": ["yocto-queue@0.1.0", "", {}, "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q=="], + "zustand": ["zustand@3.7.2", "", { "peerDependencies": { "react": ">=16.8" }, "optionalPeers": ["react"] }, "sha512-PIJDIZKtokhof+9+60cpockVOq05sJzHCriyvaLBmEJixseQ1a5Kdov6fWZfWOu5SK9c+FhH1jU0tntLxRJYMA=="], "zwitch": ["zwitch@2.0.4", "", {}, "sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A=="], @@ -1807,8 +1927,6 @@ "@emotion/babel-plugin/convert-source-map": ["convert-source-map@1.9.0", "", {}, "sha512-ASFBup0Mz1uyiIjANan1jzLQami9z1PoYSZCiiYW2FczPbenXc45FZdBZLzOT+r6+iciuEModtmCti+hjaAk0A=="], - "@emotion/babel-plugin/escape-string-regexp": ["escape-string-regexp@4.0.0", "", {}, "sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA=="], - "@emotion/babel-plugin/source-map": ["source-map@0.5.7", "", {}, "sha512-LbrmJOMUSdEVxIKvdcJzQC+nQhe8FUZQTXQy6+I75skNgn3OoQ0DZA8YnFa7gp8tqtL3KPf1kmo0R5DoApeSGQ=="], "@emotion/babel-plugin/stylis": ["stylis@4.2.0", "", {}, "sha512-Orov6g6BB1sDfYgzWfTHDOxamtX1bE/zo104Dh9e6fqJ3PooipYyfJ0pUmrZO2wAvO8YbEyeFrkV91XTsGMSrw=="], @@ -1819,6 +1937,10 @@ "@emotion/serialize/@emotion/unitless": ["@emotion/unitless@0.10.0", "", {}, "sha512-dFoMUuQA20zvtVTuxZww6OHoJYgrzfKM1t52mVySDJnMSEa08ruEvdYQbhvyu6soU+NeLVd3yKfTfT0NeV6qGg=="], + "@iconify/utils/globals": ["globals@15.15.0", "", {}, "sha512-7ACyT3wmyp3I61S4fG682L0VA2RGD9otkqGJIwNUMF1SWUombIIk+af1unuDYgMm082aHYwD+mzJvv9Iu8dsgg=="], + + "@isaacs/cliui/strip-ansi": ["strip-ansi@7.1.0", "", { "dependencies": { "ansi-regex": "^6.0.1" } }, "sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ=="], + "@lobehub/fluent-emoji/lucide-react": ["lucide-react@0.469.0", "", { "peerDependencies": { "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-28vvUnnKQ/dBwiCQtwJw7QauYnE7yd2Cyp4tTTJpvglX4EMpbflcdBgrgToX2j71B3YvugK/NH3BGUk+E/p/Fw=="], "@lobehub/icons/lucide-react": ["lucide-react@0.469.0", "", { "peerDependencies": { "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-28vvUnnKQ/dBwiCQtwJw7QauYnE7yd2Cyp4tTTJpvglX4EMpbflcdBgrgToX2j71B3YvugK/NH3BGUk+E/p/Fw=="], @@ -1827,8 +1949,6 @@ "@lobehub/ui/lucide-react": ["lucide-react@0.484.0", "", { "peerDependencies": { "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-oZy8coK9kZzvqhSgfbGkPtTgyjpBvs3ukLgDPv14dSOZtBtboryWF5o8i3qen7QbGg7JhiJBz5mK1p8YoMZTLQ=="], - "@lobehub/ui/rc-collapse": ["rc-collapse@4.0.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-SwoOByE39/3oIokDs/BnkqI+ltwirZbP8HZdq1/3SkPSBi7xDdvWHTp7cpNI9ullozkR6mwTWQi6/E/9huQVrA=="], - "@radix-ui/react-dismissable-layer/@radix-ui/react-compose-refs": ["@radix-ui/react-compose-refs@1.0.0", "", { "dependencies": { "@babel/runtime": "^7.13.10" }, "peerDependencies": { "react": "^16.8 || ^17.0 || ^18.0" } }, "sha512-0KaSv6sx787/hK3eF53iOkiSLwAGlFMx5lotrqD2pTjB18KbybKoEIgkNZTKC60YECDQTKGTRcDBILwZVqVKvA=="], "@radix-ui/react-popper/@floating-ui/react-dom": ["@floating-ui/react-dom@0.7.2", "", { "dependencies": { "@floating-ui/dom": "^0.5.3", "use-isomorphic-layout-effect": "^1.1.1" }, "peerDependencies": { "react": ">=16.8.0", "react-dom": ">=16.8.0" } }, "sha512-1T0sJcpHgX/u4I1OzIEhlcrvkUN8ln39nz7fMoE/2HDHrPiMFoOGR7++GYyfUmIQHkkrTinaeQsO3XWubjSvGg=="], @@ -1845,6 +1965,8 @@ "@visactor/vrender-kits/roughjs": ["roughjs@4.5.2", "", { "dependencies": { "path-data-parser": "^0.1.0", "points-on-curve": "^0.2.0", "points-on-path": "^0.2.1" } }, "sha512-2xSlLDKdsWyFxrveYWk9YQ/Y9UfK38EAMRNkYkMqYBJvPX8abCa9PN0x3w02H8Oa6/0bcZICJU+U95VumPqseg=="], + "antd/rc-collapse": ["rc-collapse@3.9.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-swDdz4QZ4dFTo4RAUMLL50qP0EY62N2kvmk2We5xYdRwcRn8WcYtuetCJpwpaCbUfUt5+huLpVxhvmnK+PHrkA=="], + "antd/scroll-into-view-if-needed": ["scroll-into-view-if-needed@3.1.0", "", { "dependencies": { "compute-scroll-into-view": "^3.0.2" } }, "sha512-49oNpRjWRvnU8NyGVmUaYG4jtTkNonFZI86MmGRDqBphEK2EXT9gdEUoQPZhuBM8yWHxCWbobltqYO5M4XrUvQ=="], "chokidar/glob-parent": ["glob-parent@5.1.2", "", { "dependencies": { "is-glob": "^4.0.1" } }, "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow=="], @@ -1867,6 +1989,8 @@ "d3-sankey/d3-shape": ["d3-shape@1.3.7", "", { "dependencies": { "d3-path": "1" } }, "sha512-EUkvKjqPFUAZyOlhY5gzCxCeI0Aep04LwIRpsZ/mLFelJiUfnK56jo5JMDSE7yyP2kLSb6LtF+S5chMk7uqPqw=="], + "esast-util-from-js/acorn": ["acorn@8.14.0", "", { "bin": { "acorn": "bin/acorn" } }, "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA=="], + "extend-shallow/is-extendable": ["is-extendable@0.1.1", "", {}, "sha512-5BMULNob1vgFX6EjQw5izWDxrecWK9AM72rugNr0TFldMOi0fj6Jk+zeKIt0xGj4cEfQIJth4w3OKWOJ4f+AFw=="], "fast-glob/glob-parent": ["glob-parent@5.1.2", "", { "dependencies": { "is-glob": "^4.0.1" } }, "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow=="], @@ -1887,8 +2011,14 @@ "leva/react-dropzone": ["react-dropzone@12.1.0", "", { "dependencies": { "attr-accept": "^2.2.2", "file-selector": "^0.5.0", "prop-types": "^15.8.1" }, "peerDependencies": { "react": ">= 16.8" } }, "sha512-iBYHA1rbopIvtzokEX4QubO6qk5IF/x3BtKGu74rF2JkQDXnwC4uO/lHKpaw4PJIV6iIAYOlwLv2FpiGyqHNog=="], + "mdast-util-find-and-replace/escape-string-regexp": ["escape-string-regexp@5.0.0", "", {}, "sha512-/veY75JbMK4j1yjvuUxuVsiS/hr/4iHs9FTT6cgTexxdE0Ly/glccBAkloH/DofkjRbZU3bnoj38mOmhkZ0lHw=="], + "mermaid/marked": ["marked@15.0.12", "", { "bin": { "marked": "bin/marked.js" } }, "sha512-8dD6FusOQSrpv9Z1rdNMdlSgQOIP880DHqnohobOmYLElGEqAL/JvxvuxZO16r4HtjTlfPRDC1hbvxC9dPN2nA=="], + "micromark-extension-mdxjs/acorn": ["acorn@8.14.0", "", { "bin": { "acorn": "bin/acorn" } }, "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA=="], + + "mlly/acorn": ["acorn@8.14.0", "", { "bin": { "acorn": "bin/acorn" } }, "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA=="], + "mlly/pkg-types": ["pkg-types@1.3.1", "", { "dependencies": { "confbox": "^0.1.8", "mlly": "^1.7.4", "pathe": "^2.0.1" } }, "sha512-/Jm5M4RvtBFVkKWRu2BLUTNP8/M2a+UwuAX+ae4770q1qVGtfjG+WTCupoZixokjmHiry8uI+dlY8KXYV5HVVQ=="], "parse-entities/@types/unist": ["@types/unist@2.0.11", "", {}, "sha512-CmBKiL6NNo/OqgmMn95Fk9Whlp2mtvIv+KNpQKN2F4SjvrEesubTRWGYSg+BnWZOnlCaSTU1sMpsBOzgbYhnsA=="], @@ -1909,6 +2039,8 @@ "react-toastify/clsx": ["clsx@1.2.1", "", {}, "sha512-EcR6r5a8bj6pu3ycsa/E/cKVGuTgZJZdsyUYHOksG/UHIiKfjxzRxYJpyVBwYaQeOvghal9fcc4PidlgzugAQg=="], + "rimraf/glob": ["glob@7.2.3", "", { "dependencies": { "fs.realpath": "^1.0.0", "inflight": "^1.0.4", "inherits": "2", "minimatch": "^3.1.1", "once": "^1.3.0", "path-is-absolute": "^1.0.0" } }, "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q=="], + "sass/chokidar": ["chokidar@4.0.3", "", { "dependencies": { "readdirp": "^4.0.1" } }, "sha512-Qgzu8kfBvo+cA4962jnP1KkS6Dop5NS6g7R5LFYJr4b8Ub94PPQXUksCw9PvXoeXPRRddRNC5C1JQUR2SMGtnA=="], "set-value/is-extendable": ["is-extendable@0.1.1", "", {}, "sha512-5BMULNob1vgFX6EjQw5izWDxrecWK9AM72rugNr0TFldMOi0fj6Jk+zeKIt0xGj4cEfQIJth4w3OKWOJ4f+AFw=="], @@ -1921,12 +2053,10 @@ "string-width/emoji-regex": ["emoji-regex@9.2.2", "", {}, "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg=="], + "string-width/strip-ansi": ["strip-ansi@7.1.0", "", { "dependencies": { "ansi-regex": "^6.0.1" } }, "sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ=="], + "string-width-cjs/emoji-regex": ["emoji-regex@8.0.0", "", {}, "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A=="], - "string-width-cjs/strip-ansi": ["strip-ansi@6.0.1", "", { "dependencies": { "ansi-regex": "^5.0.1" } }, "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A=="], - - "strip-ansi-cjs/ansi-regex": ["ansi-regex@5.0.1", "", {}, "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ=="], - "sucrase/commander": ["commander@4.1.1", "", {}, "sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA=="], "topojson-client/commander": ["commander@2.20.3", "", {}, "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ=="], @@ -1935,12 +2065,12 @@ "vite/postcss": ["postcss@8.4.49", "", { "dependencies": { "nanoid": "^3.3.7", "picocolors": "^1.1.1", "source-map-js": "^1.2.1" } }, "sha512-OCVPnIObs4N29kxTjzLfUryOkvZEq+pf8jTF0lg8E7uETuWHA+v7j3c/xJmiqpX450191LlmZfUKkXxkTry7nA=="], - "wrap-ansi-cjs/ansi-styles": ["ansi-styles@4.3.0", "", { "dependencies": { "color-convert": "^2.0.1" } }, "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg=="], + "wrap-ansi/ansi-styles": ["ansi-styles@6.2.1", "", {}, "sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug=="], + + "wrap-ansi/strip-ansi": ["strip-ansi@7.1.0", "", { "dependencies": { "ansi-regex": "^6.0.1" } }, "sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ=="], "wrap-ansi-cjs/string-width": ["string-width@4.2.3", "", { "dependencies": { "emoji-regex": "^8.0.0", "is-fullwidth-code-point": "^3.0.0", "strip-ansi": "^6.0.1" } }, "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g=="], - "wrap-ansi-cjs/strip-ansi": ["strip-ansi@6.0.1", "", { "dependencies": { "ansi-regex": "^5.0.1" } }, "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A=="], - "@babel/helper-compilation-targets/browserslist/caniuse-lite": ["caniuse-lite@1.0.30001690", "", {}, "sha512-5ExiE3qQN6oF8Clf8ifIDcMRCRE/dMGcETG/XGMD8/XiXm6HXQgQTh1yZYLXXpSOsEUlJm1Xr7kGULZTuGtP/w=="], "@babel/helper-compilation-targets/browserslist/electron-to-chromium": ["electron-to-chromium@1.5.76", "", {}, "sha512-CjVQyG7n7Sr+eBXE86HIulnL5N8xZY1sgmOPGuq/F0Rr0FJq63lg0kEtOIDfZBk44FnDLf6FUJ+dsJcuiUDdDQ=="], @@ -1951,6 +2081,8 @@ "@babel/plugin-transform-runtime/@babel/helper-module-imports/@babel/types": ["@babel/types@7.27.1", "", { "dependencies": { "@babel/helper-string-parser": "^7.27.1", "@babel/helper-validator-identifier": "^7.27.1" } }, "sha512-+EzkxvLNfiUeKMgy/3luqfsCWFRXLb7U6wNQTk60tovuckwB15B191tJWvpp4HjiQWdJkCxO3Wbvc6jlk3Xb2Q=="], + "@isaacs/cliui/strip-ansi/ansi-regex": ["ansi-regex@6.1.0", "", {}, "sha512-7HSX4QQb4CspciLpVFwyRe79O3xsIZDDLER21kERQ71oaPodF8jL725AgJMFAYbooIqolJoRLuM81SpeUkpkvA=="], + "@radix-ui/react-popper/@floating-ui/react-dom/@floating-ui/dom": ["@floating-ui/dom@0.5.4", "", { "dependencies": { "@floating-ui/core": "^0.7.3" } }, "sha512-419BMceRLq0RrmTSDxn8hf9R3VCJv2K9PUfugh5JyEFmdjzDo+e8U5EdR8nzKq8Yj1htzLm3b6eQEEam3/rrtg=="], "@radix-ui/react-primitive/@radix-ui/react-slot/@radix-ui/react-compose-refs": ["@radix-ui/react-compose-refs@1.0.0", "", { "dependencies": { "@babel/runtime": "^7.13.10" }, "peerDependencies": { "react": "^16.8 || ^17.0 || ^18.0" } }, "sha512-0KaSv6sx787/hK3eF53iOkiSLwAGlFMx5lotrqD2pTjB18KbybKoEIgkNZTKC60YECDQTKGTRcDBILwZVqVKvA=="], @@ -1981,11 +2113,11 @@ "simplify-geojson/concat-stream/typedarray": ["typedarray@0.0.7", "", {}, "sha512-ueeb9YybpjhivjbHP2LdFDAjbS948fGEPj+ACAMs4xCMmh72OCOMQWBQKlaN4ZNQ04yfLSDLSx1tGRIoWimObQ=="], - "string-width-cjs/strip-ansi/ansi-regex": ["ansi-regex@5.0.1", "", {}, "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ=="], + "string-width/strip-ansi/ansi-regex": ["ansi-regex@6.1.0", "", {}, "sha512-7HSX4QQb4CspciLpVFwyRe79O3xsIZDDLER21kERQ71oaPodF8jL725AgJMFAYbooIqolJoRLuM81SpeUkpkvA=="], "wrap-ansi-cjs/string-width/emoji-regex": ["emoji-regex@8.0.0", "", {}, "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A=="], - "wrap-ansi-cjs/strip-ansi/ansi-regex": ["ansi-regex@5.0.1", "", {}, "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ=="], + "wrap-ansi/strip-ansi/ansi-regex": ["ansi-regex@6.1.0", "", {}, "sha512-7HSX4QQb4CspciLpVFwyRe79O3xsIZDDLER21kERQ71oaPodF8jL725AgJMFAYbooIqolJoRLuM81SpeUkpkvA=="], "@babel/plugin-transform-runtime/@babel/helper-module-imports/@babel/traverse/@babel/code-frame": ["@babel/code-frame@7.27.1", "", { "dependencies": { "@babel/helper-validator-identifier": "^7.27.1", "js-tokens": "^4.0.0", "picocolors": "^1.1.1" } }, "sha512-cjQ7ZlQ0Mv3b47hABuTevyTuYN4i+loJKGeV9flcCgIK37cCXRh+L1bd3iBHlynerhQ7BhCkn2BPbQUL+rGqFg=="], diff --git a/web/index.html b/web/index.html index 1e75f3d74..09d87ae1a 100644 --- a/web/index.html +++ b/web/index.html @@ -11,9 +11,10 @@ /> New API +
- + diff --git a/web/package.json b/web/package.json index a313e0f59..f014d84b9 100644 --- a/web/package.json +++ b/web/package.json @@ -21,6 +21,7 @@ "lucide-react": "^0.511.0", "marked": "^4.1.1", "mermaid": "^11.6.0", + "qrcode.react": "^4.2.0", "react": "^18.2.0", "react-dom": "^18.2.0", "react-dropzone": "^14.2.3", @@ -46,6 +47,8 @@ "build": "vite build", "lint": "prettier . --check", "lint:fix": "prettier . --write", + "eslint": "bunx eslint \"**/*.{js,jsx}\" --cache", + "eslint:fix": "bunx eslint \"**/*.{js,jsx}\" --fix --cache", "preview": "vite preview" }, "eslintConfig": { @@ -71,6 +74,9 @@ "@so1ve/prettier-config": "^3.1.0", "@vitejs/plugin-react": "^4.2.1", "autoprefixer": "^10.4.21", + "eslint": "8.57.0", + "eslint-plugin-header": "^3.1.1", + "eslint-plugin-react-hooks": "^5.2.0", "postcss": "^8.5.3", "prettier": "^3.0.0", "tailwindcss": "^3", diff --git a/web/postcss.config.js b/web/postcss.config.js index 2e7af2b7f..5731ce76e 100644 --- a/web/postcss.config.js +++ b/web/postcss.config.js @@ -1,6 +1,25 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + export default { plugins: { tailwindcss: {}, autoprefixer: {}, }, -} +}; diff --git a/web/public/cover-4.webp b/web/public/cover-4.webp new file mode 100644 index 000000000..0e9ecbf0d Binary files /dev/null and b/web/public/cover-4.webp differ diff --git a/web/src/App.js b/web/src/App.jsx similarity index 66% rename from web/src/App.js rename to web/src/App.jsx index 2d715767d..635742f91 100644 --- a/web/src/App.js +++ b/web/src/App.jsx @@ -1,38 +1,82 @@ -import React, { lazy, Suspense } from 'react'; +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React, { lazy, Suspense, useContext, useMemo } from 'react'; import { Route, Routes, useLocation } from 'react-router-dom'; -import Loading from './components/common/Loading.js'; +import Loading from './components/common/ui/Loading'; import User from './pages/User'; -import { AuthRedirect, PrivateRoute } from './helpers'; -import RegisterForm from './components/auth/RegisterForm.js'; -import LoginForm from './components/auth/LoginForm.js'; +import { AuthRedirect, PrivateRoute, AdminRoute } from './helpers'; +import RegisterForm from './components/auth/RegisterForm'; +import LoginForm from './components/auth/LoginForm'; import NotFound from './pages/NotFound'; +import Forbidden from './pages/Forbidden'; import Setting from './pages/Setting'; -import EditUser from './pages/User/EditUser'; -import PasswordResetForm from './components/auth/PasswordResetForm.js'; -import PasswordResetConfirm from './components/auth/PasswordResetConfirm.js'; +import { StatusContext } from './context/Status'; + +import PasswordResetForm from './components/auth/PasswordResetForm'; +import PasswordResetConfirm from './components/auth/PasswordResetConfirm'; import Channel from './pages/Channel'; import Token from './pages/Token'; -import EditChannel from './pages/Channel/EditChannel'; import Redemption from './pages/Redemption'; import TopUp from './pages/TopUp'; import Log from './pages/Log'; import Chat from './pages/Chat'; import Chat2Link from './pages/Chat2Link'; import Midjourney from './pages/Midjourney'; -import Pricing from './pages/Pricing/index.js'; -import Task from './pages/Task/index.js'; -import Playground from './pages/Playground/index.js'; -import OAuth2Callback from './components/auth/OAuth2Callback.js'; -import PersonalSetting from './components/settings/PersonalSetting.js'; -import Setup from './pages/Setup/index.js'; -import SetupCheck from './components/layout/SetupCheck.js'; +import Pricing from './pages/Pricing'; +import Task from './pages/Task'; +import ModelPage from './pages/Model'; +import Playground from './pages/Playground'; +import OAuth2Callback from './components/auth/OAuth2Callback'; +import PersonalSetting from './components/settings/PersonalSetting'; +import Setup from './pages/Setup'; +import SetupCheck from './components/layout/SetupCheck'; const Home = lazy(() => import('./pages/Home')); -const Detail = lazy(() => import('./pages/Detail')); +const Dashboard = lazy(() => import('./pages/Dashboard')); const About = lazy(() => import('./pages/About')); function App() { const location = useLocation(); + const [statusState] = useContext(StatusContext); + + // 获取模型广场权限配置 + const pricingRequireAuth = useMemo(() => { + const headerNavModulesConfig = statusState?.status?.HeaderNavModules; + if (headerNavModulesConfig) { + try { + const modules = JSON.parse(headerNavModulesConfig); + + // 处理向后兼容性:如果pricing是boolean,默认不需要登录 + if (typeof modules.pricing === 'boolean') { + return false; // 默认不需要登录鉴权 + } + + // 如果是对象格式,使用requireAuth配置 + return modules.pricing?.requireAuth === true; + } catch (error) { + console.error('解析顶栏模块配置失败:', error); + return false; // 默认不需要登录 + } + } + return false; // 默认不需要登录 + }, [statusState?.status?.HeaderNavModules]); return ( @@ -53,28 +97,21 @@ function App() { } /> + } /> + + + + } + /> + - - } - /> - } key={location.pathname}> - - - } - /> - } key={location.pathname}> - - + } /> + - + } /> + - - } - /> - } key={location.pathname}> - - - } - /> - } key={location.pathname}> - - + } /> + } key={location.pathname}> - + } /> } key={location.pathname}> - + } @@ -256,9 +277,20 @@ function App() { } key={location.pathname}> - - + pricingRequireAuth ? ( + + } + key={location.pathname} + > + + + + ) : ( + } key={location.pathname}> + + + ) } /> . + +For commercial licensing, please contact support@quantumnous.com +*/ + import React, { useContext, useEffect, useState } from 'react'; import { Link, useNavigate, useSearchParams } from 'react-router-dom'; -import { UserContext } from '../../context/User/index.js'; +import { UserContext } from '../../context/User'; import { API, getLogo, @@ -12,25 +31,19 @@ import { setUserData, onGitHubOAuthClicked, onOIDCClicked, - onLinuxDOOAuthClicked -} from '../../helpers/index.js'; + onLinuxDOOAuthClicked, +} from '../../helpers'; import Turnstile from 'react-turnstile'; -import { - Button, - Card, - Divider, - Form, - Icon, - Modal, -} from '@douyinfe/semi-ui'; +import { Button, Card, Divider, Form, Icon, Modal } from '@douyinfe/semi-ui'; import Title from '@douyinfe/semi-ui/lib/es/typography/title'; import Text from '@douyinfe/semi-ui/lib/es/typography/text'; import TelegramLoginButton from 'react-telegram-login'; import { IconGithubLogo, IconMail, IconLock } from '@douyinfe/semi-icons'; -import OIDCIcon from '../common/logo/OIDCIcon.js'; -import WeChatIcon from '../common/logo/WeChatIcon.js'; -import LinuxDoIcon from '../common/logo/LinuxDoIcon.js'; +import OIDCIcon from '../common/logo/OIDCIcon'; +import WeChatIcon from '../common/logo/WeChatIcon'; +import LinuxDoIcon from '../common/logo/LinuxDoIcon'; +import TwoFAVerification from './TwoFAVerification'; import { useTranslation } from 'react-i18next'; const LoginForm = () => { @@ -57,8 +70,10 @@ const LoginForm = () => { const [emailLoginLoading, setEmailLoginLoading] = useState(false); const [loginLoading, setLoginLoading] = useState(false); const [resetPasswordLoading, setResetPasswordLoading] = useState(false); - const [otherLoginOptionsLoading, setOtherLoginOptionsLoading] = useState(false); + const [otherLoginOptionsLoading, setOtherLoginOptionsLoading] = + useState(false); const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false); + const [showTwoFA, setShowTwoFA] = useState(false); const logo = getLogo(); const systemName = getSystemName(); @@ -143,6 +158,13 @@ const LoginForm = () => { ); const { success, message, data } = res.data; if (success) { + // 检查是否需要2FA验证 + if (data && data.require_2fa) { + setShowTwoFA(true); + setLoginLoading(false); + return; + } + userDispatch({ type: 'login', payload: data }); setUserData(data); updateAPI(); @@ -219,10 +241,7 @@ const LoginForm = () => { const handleOIDCClick = () => { setOidcLoading(true); try { - onOIDCClicked( - status.oidc_authorization_endpoint, - status.oidc_client_id - ); + onOIDCClicked(status.oidc_authorization_endpoint, status.oidc_client_id); } finally { // 由于重定向,这里不会执行到,但为了完整性添加 setTimeout(() => setOidcLoading(false), 3000); @@ -261,79 +280,104 @@ const LoginForm = () => { setOtherLoginOptionsLoading(false); }; + // 2FA验证成功处理 + const handle2FASuccess = (data) => { + userDispatch({ type: 'login', payload: data }); + setUserData(data); + updateAPI(); + showSuccess('登录成功!'); + navigate('/console'); + }; + + // 返回登录页面 + const handleBackToLogin = () => { + setShowTwoFA(false); + setInputs({ username: '', password: '', wechat_verification_code: '' }); + }; + const renderOAuthOptions = () => { return ( -
-
-
- Logo - {systemName} +
+
+
+ Logo + + {systemName} +
- -
- {t('登 录')} + +
+ + {t('登 录')} +
-
-
+
+
{status.wechat_login && ( )} {status.github_oauth && ( )} {status.oidc_enabled && ( )} {status.linuxdo_oauth && ( )} {status.telegram_oauth && ( -
+
{
{!status.self_use_mode_enabled && ( -
+
{t('没有账户?')}{' '} {t('注册')} @@ -380,47 +423,46 @@ const LoginForm = () => { const renderEmailLoginForm = () => { return ( -
-
-
- Logo +
+
+
+ Logo {systemName}
- -
- {t('登 录')} + +
+ + {t('登 录')} +
-
-
+
+ handleChange('username', value)} prefix={} /> handleChange('password', value)} prefix={} /> -
+
- {(status.github_oauth || status.oidc_enabled || status.wechat_login || status.linuxdo_oauth || status.telegram_oauth) && ( + {(status.github_oauth || + status.oidc_enabled || + status.wechat_login || + status.linuxdo_oauth || + status.telegram_oauth) && ( <> {t('或')} -
+
-
- {t('返回登录')} +
+ + + {t('返回登录')} + +
diff --git a/web/src/components/auth/PasswordResetForm.js b/web/src/components/auth/PasswordResetForm.jsx similarity index 54% rename from web/src/components/auth/PasswordResetForm.js rename to web/src/components/auth/PasswordResetForm.jsx index 033989e01..92afc2afa 100644 --- a/web/src/components/auth/PasswordResetForm.js +++ b/web/src/components/auth/PasswordResetForm.jsx @@ -1,5 +1,31 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + import React, { useEffect, useState } from 'react'; -import { API, getLogo, showError, showInfo, showSuccess, getSystemName } from '../../helpers'; +import { + API, + getLogo, + showError, + showInfo, + showSuccess, + getSystemName, +} from '../../helpers'; import Turnstile from 'react-turnstile'; import { Button, Card, Form, Typography } from '@douyinfe/semi-ui'; import { IconMail } from '@douyinfe/semi-icons'; @@ -78,59 +104,77 @@ const PasswordResetForm = () => { } return ( -
+
{/* 背景模糊晕染球 */} -
-
-
-
-
-
- Logo - {systemName} +
+
+
+
+
+
+ Logo + + {systemName} +
- -
- {t('密码重置')} + +
+ + {t('密码重置')} +
-
-
+
+ } /> -
+
-
- {t('想起来了?')} {t('登录')} +
+ + {t('想起来了?')}{' '} + + {t('登录')} + +
{turnstileEnabled && ( -
+
{ diff --git a/web/src/components/auth/RegisterForm.js b/web/src/components/auth/RegisterForm.jsx similarity index 61% rename from web/src/components/auth/RegisterForm.js rename to web/src/components/auth/RegisterForm.jsx index 9d213a600..9c98bdc3a 100644 --- a/web/src/components/auth/RegisterForm.js +++ b/web/src/components/auth/RegisterForm.jsx @@ -1,3 +1,22 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + import React, { useContext, useEffect, useState } from 'react'; import { Link, useNavigate } from 'react-router-dom'; import { @@ -8,30 +27,29 @@ import { showSuccess, updateAPI, getSystemName, - setUserData -} from '../../helpers/index.js'; + setUserData, +} from '../../helpers'; import Turnstile from 'react-turnstile'; -import { - Button, - Card, - Divider, - Form, - Icon, - Modal, -} from '@douyinfe/semi-ui'; +import { Button, Card, Divider, Form, Icon, Modal } from '@douyinfe/semi-ui'; import Title from '@douyinfe/semi-ui/lib/es/typography/title'; import Text from '@douyinfe/semi-ui/lib/es/typography/text'; -import { IconGithubLogo, IconMail, IconUser, IconLock, IconKey } from '@douyinfe/semi-icons'; +import { + IconGithubLogo, + IconMail, + IconUser, + IconLock, + IconKey, +} from '@douyinfe/semi-icons'; import { onGitHubOAuthClicked, onLinuxDOOAuthClicked, onOIDCClicked, -} from '../../helpers/index.js'; -import OIDCIcon from '../common/logo/OIDCIcon.js'; -import LinuxDoIcon from '../common/logo/LinuxDoIcon.js'; -import WeChatIcon from '../common/logo/WeChatIcon.js'; +} from '../../helpers'; +import OIDCIcon from '../common/logo/OIDCIcon'; +import LinuxDoIcon from '../common/logo/LinuxDoIcon'; +import WeChatIcon from '../common/logo/WeChatIcon'; import TelegramLoginButton from 'react-telegram-login/src'; -import { UserContext } from '../../context/User/index.js'; +import { UserContext } from '../../context/User'; import { useTranslation } from 'react-i18next'; const RegisterForm = () => { @@ -59,8 +77,11 @@ const RegisterForm = () => { const [emailRegisterLoading, setEmailRegisterLoading] = useState(false); const [registerLoading, setRegisterLoading] = useState(false); const [verificationCodeLoading, setVerificationCodeLoading] = useState(false); - const [otherRegisterOptionsLoading, setOtherRegisterOptionsLoading] = useState(false); + const [otherRegisterOptionsLoading, setOtherRegisterOptionsLoading] = + useState(false); const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false); + const [disableButton, setDisableButton] = useState(false); + const [countdown, setCountdown] = useState(30); const logo = getLogo(); const systemName = getSystemName(); @@ -87,6 +108,19 @@ const RegisterForm = () => { } }, [status]); + useEffect(() => { + let countdownInterval = null; + if (disableButton && countdown > 0) { + countdownInterval = setInterval(() => { + setCountdown(countdown - 1); + }, 1000); + } else if (countdown === 0) { + setDisableButton(false); + setCountdown(30); + } + return () => clearInterval(countdownInterval); // Clean up on unmount + }, [disableButton, countdown]); + const onWeChatLoginClicked = () => { setWechatLoading(true); setShowWeChatLoginModal(true); @@ -179,6 +213,7 @@ const RegisterForm = () => { const { success, message } = res.data; if (success) { showSuccess('验证码发送成功,请检查你的邮箱!'); + setDisableButton(true); // 发送成功后禁用按钮,开始倒计时 } else { showError(message); } @@ -201,10 +236,7 @@ const RegisterForm = () => { const handleOIDCClick = () => { setOidcLoading(true); try { - onOIDCClicked( - status.oidc_authorization_endpoint, - status.oidc_client_id - ); + onOIDCClicked(status.oidc_authorization_endpoint, status.oidc_client_id); } finally { setTimeout(() => setOidcLoading(false), 3000); } @@ -268,77 +300,87 @@ const RegisterForm = () => { const renderOAuthOptions = () => { return ( -
-
-
- Logo - {systemName} +
+
+
+ Logo + + {systemName} +
- -
- {t('注 册')} + +
+ + {t('注 册')} +
-
-
+
+
{status.wechat_login && ( )} {status.github_oauth && ( )} {status.oidc_enabled && ( )} {status.linuxdo_oauth && ( )} {status.telegram_oauth && ( -
+
{
-
- {t('已有账户?')} {t('登录')} +
+ + {t('已有账户?')}{' '} + + {t('登录')} + +
@@ -375,47 +424,48 @@ const RegisterForm = () => { const renderEmailRegisterForm = () => { return ( -
-
-
- Logo - {systemName} +
+
+
+ Logo + + {systemName} +
- -
- {t('注 册')} + +
+ + {t('注 册')} +
-
-
+
+ handleChange('username', value)} prefix={} /> handleChange('password', value)} prefix={} /> handleChange('password2', value)} prefix={} /> @@ -423,43 +473,44 @@ const RegisterForm = () => { {showEmailVerification && ( <> handleChange('email', value)} prefix={} suffix={ } /> handleChange('verification_code', value)} + name='verification_code' + onChange={(value) => + handleChange('verification_code', value) + } prefix={} /> )} -
+
- {(status.github_oauth || status.oidc_enabled || status.wechat_login || status.linuxdo_oauth || status.telegram_oauth) && ( + {(status.github_oauth || + status.oidc_enabled || + status.wechat_login || + status.linuxdo_oauth || + status.telegram_oauth) && ( <> {t('或')} -
+
+ + + + +
+ + + {onBack && ( + + )} +
+ +
+ + 提示: +
+ • 验证码每30秒更新一次 +
+ • 如果无法获取验证码,请使用备用码 +
• 每个备用码只能使用一次 +
+
+
+ ); + } + + return ( +
+ +
+ 两步验证 + + 请输入认证器应用显示的验证码完成登录 + +
+ +
+ + + + + + + +
+ + + {onBack && ( + + )} +
+ +
+ + 提示: +
+ • 验证码每30秒更新一次 +
+ • 如果无法获取验证码,请使用备用码 +
• 每个备用码只能使用一次 +
+
+
+
+ ); +}; + +export default TwoFAVerification; diff --git a/web/src/components/common/Loading.js b/web/src/components/common/Loading.js deleted file mode 100644 index 738227550..000000000 --- a/web/src/components/common/Loading.js +++ /dev/null @@ -1,16 +0,0 @@ -import React from 'react'; -import { Spin } from '@douyinfe/semi-ui'; - -const Loading = ({ size = 'small' }) => { - - return ( -
- -
- ); -}; - -export default Loading; diff --git a/web/src/components/common/logo/LinuxDoIcon.js b/web/src/components/common/logo/LinuxDoIcon.jsx similarity index 71% rename from web/src/components/common/logo/LinuxDoIcon.js rename to web/src/components/common/logo/LinuxDoIcon.jsx index f6ee9b313..861f19d4f 100644 --- a/web/src/components/common/logo/LinuxDoIcon.js +++ b/web/src/components/common/logo/LinuxDoIcon.jsx @@ -1,3 +1,22 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + import React from 'react'; import { Icon } from '@douyinfe/semi-ui'; diff --git a/web/src/components/common/logo/OIDCIcon.js b/web/src/components/common/logo/OIDCIcon.jsx similarity index 68% rename from web/src/components/common/logo/OIDCIcon.js rename to web/src/components/common/logo/OIDCIcon.jsx index bd98c8fba..28d538eb0 100644 --- a/web/src/components/common/logo/OIDCIcon.js +++ b/web/src/components/common/logo/OIDCIcon.jsx @@ -1,3 +1,22 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + import React from 'react'; import { Icon } from '@douyinfe/semi-ui'; diff --git a/web/src/components/common/logo/WeChatIcon.js b/web/src/components/common/logo/WeChatIcon.jsx similarity index 71% rename from web/src/components/common/logo/WeChatIcon.js rename to web/src/components/common/logo/WeChatIcon.jsx index 723c7ecb2..f9f7057cf 100644 --- a/web/src/components/common/logo/WeChatIcon.js +++ b/web/src/components/common/logo/WeChatIcon.jsx @@ -1,3 +1,22 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + import React from 'react'; import { Icon } from '@douyinfe/semi-ui'; diff --git a/web/src/components/common/markdown/MarkdownRenderer.js b/web/src/components/common/markdown/MarkdownRenderer.jsx similarity index 68% rename from web/src/components/common/markdown/MarkdownRenderer.js rename to web/src/components/common/markdown/MarkdownRenderer.jsx index a48d34d1a..f1283a640 100644 --- a/web/src/components/common/markdown/MarkdownRenderer.js +++ b/web/src/components/common/markdown/MarkdownRenderer.jsx @@ -1,3 +1,22 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + import ReactMarkdown from 'react-markdown'; import 'katex/dist/katex.min.css'; import 'highlight.js/styles/github.css'; @@ -141,7 +160,7 @@ export function PreCode(props) { }} >
@@ -348,7 +374,16 @@ function _MarkdownContent(props) { components={{ pre: PreCode, code: CustomCode, - p: (pProps) =>

, + p: (pProps) => ( +

+ ), a: (aProps) => { const href = aProps.href || ''; if (/\.(aac|mp3|opus|wav)$/.test(href)) { @@ -360,13 +395,16 @@ function _MarkdownContent(props) { } if (/\.(3gp|3g2|webm|ogv|mpeg|mp4|avi)$/.test(href)) { return ( -

, - h2: (props) =>

, - h3: (props) =>

, - h4: (props) =>

, - h5: (props) =>

, - h6: (props) =>
, + h1: (props) => ( +

+ ), + h2: (props) => ( +

+ ), + h3: (props) => ( +

+ ), + h4: (props) => ( +

+ ), + h5: (props) => ( +

+ ), + h6: (props) => ( +
+ ), blockquote: (props) => (
), - ul: (props) =>
    , - ol: (props) =>
      , - li: (props) =>
    1. , + ul: (props) => ( +
        + ), + ol: (props) => ( +
          + ), + li: (props) => ( +
        1. + ), table: (props) => (
          @@ -477,25 +614,29 @@ export function MarkdownRenderer(props) { color: 'var(--semi-color-text-0)', ...style, }} - dir="auto" + dir='auto' {...otherProps} > {loading ? ( -
          -
          +
          +
          正在渲染...
          ) : ( @@ -510,4 +651,4 @@ export function MarkdownRenderer(props) { ); } -export default MarkdownRenderer; \ No newline at end of file +export default MarkdownRenderer; diff --git a/web/src/components/common/markdown/markdown.css b/web/src/components/common/markdown/markdown.css index 3b5c1067d..e1e9e9cb4 100644 --- a/web/src/components/common/markdown/markdown.css +++ b/web/src/components/common/markdown/markdown.css @@ -59,12 +59,12 @@ } .user-message a { - color: #87CEEB !important; + color: #87ceeb !important; /* 浅蓝色链接 */ } .user-message a:hover { - color: #B0E0E6 !important; + color: #b0e0e6 !important; /* hover时更浅的蓝色 */ } @@ -298,7 +298,12 @@ pre:hover .copy-code-button { .markdown-body hr { border: none; height: 1px; - background: linear-gradient(to right, transparent, var(--semi-color-border), transparent); + background: linear-gradient( + to right, + transparent, + var(--semi-color-border), + transparent + ); margin: 24px 0; } @@ -332,7 +337,7 @@ pre:hover .copy-code-button { } /* 任务列表样式 */ -.markdown-body input[type="checkbox"] { +.markdown-body input[type='checkbox'] { margin-right: 8px; transform: scale(1.1); } @@ -441,4 +446,4 @@ pre:hover .copy-code-button { .animate-fade-in { animation: fade-in 0.6s cubic-bezier(0.22, 1, 0.36, 1) both; will-change: opacity, transform; -} \ No newline at end of file +} diff --git a/web/src/components/common/modals/TwoFactorAuthModal.jsx b/web/src/components/common/modals/TwoFactorAuthModal.jsx new file mode 100644 index 000000000..b0fc28e2a --- /dev/null +++ b/web/src/components/common/modals/TwoFactorAuthModal.jsx @@ -0,0 +1,146 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React from 'react'; +import { useTranslation } from 'react-i18next'; +import { Modal, Button, Input, Typography } from '@douyinfe/semi-ui'; + +/** + * 可复用的两步验证模态框组件 + * @param {Object} props + * @param {boolean} props.visible - 是否显示模态框 + * @param {string} props.code - 验证码值 + * @param {boolean} props.loading - 是否正在验证 + * @param {Function} props.onCodeChange - 验证码变化回调 + * @param {Function} props.onVerify - 验证回调 + * @param {Function} props.onCancel - 取消回调 + * @param {string} props.title - 模态框标题 + * @param {string} props.description - 验证描述文本 + * @param {string} props.placeholder - 输入框占位文本 + */ +const TwoFactorAuthModal = ({ + visible, + code, + loading, + onCodeChange, + onVerify, + onCancel, + title, + description, + placeholder, +}) => { + const { t } = useTranslation(); + + const handleKeyDown = (e) => { + if (e.key === 'Enter' && code && !loading) { + onVerify(); + } + }; + + return ( + +
          + + + +
          + {title || t('安全验证')} +
          + } + visible={visible} + onCancel={onCancel} + footer={ + <> + + + + } + width={500} + style={{ maxWidth: '90vw' }} + > +
          + {/* 安全提示 */} +
          +
          + + + +
          + + {t('安全验证')} + + + {description || t('为了保护账户安全,请验证您的两步验证码。')} + +
          +
          +
          + + {/* 验证码输入 */} +
          + + {t('验证身份')} + + + + {t('支持6位TOTP验证码或8位备用码')} + +
          +
          + + ); +}; + +export default TwoFactorAuthModal; diff --git a/web/src/components/common/ui/CardPro.jsx b/web/src/components/common/ui/CardPro.jsx new file mode 100644 index 000000000..2c95f97c7 --- /dev/null +++ b/web/src/components/common/ui/CardPro.jsx @@ -0,0 +1,200 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React, { useState } from 'react'; +import { Card, Divider, Typography, Button } from '@douyinfe/semi-ui'; +import PropTypes from 'prop-types'; +import { useIsMobile } from '../../../hooks/common/useIsMobile'; +import { IconEyeOpened, IconEyeClosed } from '@douyinfe/semi-icons'; + +const { Text } = Typography; + +/** + * CardPro 高级卡片组件 + * + * 布局分为6个区域: + * 1. 统计信息区域 (statsArea) + * 2. 描述信息区域 (descriptionArea) + * 3. 类型切换/标签区域 (tabsArea) + * 4. 操作按钮区域 (actionsArea) + * 5. 搜索表单区域 (searchArea) + * 6. 分页区域 (paginationArea) - 固定在卡片底部 + * + * 支持三种布局类型: + * - type1: 操作型 (如TokensTable) - 描述信息 + 操作按钮 + 搜索表单 + * - type2: 查询型 (如LogsTable) - 统计信息 + 搜索表单 + * - type3: 复杂型 (如ChannelsTable) - 描述信息 + 类型切换 + 操作按钮 + 搜索表单 + */ +const CardPro = ({ + type = 'type1', + className = '', + children, + // 各个区域的内容 + statsArea, + descriptionArea, + tabsArea, + actionsArea, + searchArea, + paginationArea, // 新增分页区域 + // 卡片属性 + shadows = '', + bordered = true, + // 自定义样式 + style, + // 国际化函数 + t = (key) => key, + ...props +}) => { + const isMobile = useIsMobile(); + const [showMobileActions, setShowMobileActions] = useState(false); + + const toggleMobileActions = () => { + setShowMobileActions(!showMobileActions); + }; + + const hasMobileHideableContent = actionsArea || searchArea; + + const renderHeader = () => { + const hasContent = + statsArea || descriptionArea || tabsArea || actionsArea || searchArea; + if (!hasContent) return null; + + return ( +
          + {/* 统计信息区域 - 用于type2 */} + {type === 'type2' && statsArea && <>{statsArea}} + + {/* 描述信息区域 - 用于type1和type3 */} + {(type === 'type1' || type === 'type3') && descriptionArea && ( + <>{descriptionArea} + )} + + {/* 第一个分隔线 - 在描述信息或统计信息后面 */} + {((type === 'type1' || type === 'type3') && descriptionArea) || + (type === 'type2' && statsArea) ? ( + + ) : null} + + {/* 类型切换/标签区域 - 主要用于type3 */} + {type === 'type3' && tabsArea && <>{tabsArea}} + + {/* 移动端操作切换按钮 */} + {isMobile && hasMobileHideableContent && ( + <> +
          + +
          + + )} + + {/* 操作按钮和搜索表单的容器 */} +
          + {/* 操作按钮区域 - 用于type1和type3 */} + {(type === 'type1' || type === 'type3') && + actionsArea && + (Array.isArray(actionsArea) ? ( + actionsArea.map((area, idx) => ( + + {idx !== 0 && } +
          {area}
          +
          + )) + ) : ( +
          {actionsArea}
          + ))} + + {/* 当同时存在操作区和搜索区时,插入分隔线 */} + {actionsArea && searchArea && } + + {/* 搜索表单区域 - 所有类型都可能有 */} + {searchArea &&
          {searchArea}
          } +
          +
          + ); + }; + + const headerContent = renderHeader(); + + // 渲染分页区域 + const renderFooter = () => { + if (!paginationArea) return null; + + return ( +
          + {paginationArea} +
          + ); + }; + + const footerContent = renderFooter(); + + return ( + + {children} + + ); +}; + +CardPro.propTypes = { + // 布局类型 + type: PropTypes.oneOf(['type1', 'type2', 'type3']), + // 样式相关 + className: PropTypes.string, + style: PropTypes.object, + shadows: PropTypes.oneOfType([PropTypes.string, PropTypes.bool]), + bordered: PropTypes.bool, + // 内容区域 + statsArea: PropTypes.node, + descriptionArea: PropTypes.node, + tabsArea: PropTypes.node, + actionsArea: PropTypes.oneOfType([ + PropTypes.node, + PropTypes.arrayOf(PropTypes.node), + ]), + searchArea: PropTypes.node, + paginationArea: PropTypes.node, + // 表格内容 + children: PropTypes.node, + // 国际化函数 + t: PropTypes.func, +}; + +export default CardPro; diff --git a/web/src/components/common/ui/CardTable.jsx b/web/src/components/common/ui/CardTable.jsx new file mode 100644 index 000000000..8a331d07e --- /dev/null +++ b/web/src/components/common/ui/CardTable.jsx @@ -0,0 +1,242 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React, { useState, useEffect, useRef } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + Table, + Card, + Skeleton, + Pagination, + Empty, + Button, + Collapsible, +} from '@douyinfe/semi-ui'; +import { IconChevronDown, IconChevronUp } from '@douyinfe/semi-icons'; +import PropTypes from 'prop-types'; +import { useIsMobile } from '../../../hooks/common/useIsMobile'; +import { useMinimumLoadingTime } from '../../../hooks/common/useMinimumLoadingTime'; + +/** + * CardTable 响应式表格组件 + * + * 在桌面端渲染 Semi-UI 的 Table 组件,在移动端则将每一行数据渲染成 Card 形式。 + * 该组件与 Table 组件的大部分 API 保持一致,只需将原 Table 换成 CardTable 即可。 + */ +const CardTable = ({ + columns = [], + dataSource = [], + loading = false, + rowKey = 'key', + hidePagination = false, + ...tableProps +}) => { + const isMobile = useIsMobile(); + const { t } = useTranslation(); + + const showSkeleton = useMinimumLoadingTime(loading); + + const getRowKey = (record, index) => { + if (typeof rowKey === 'function') return rowKey(record); + return record[rowKey] !== undefined ? record[rowKey] : index; + }; + + if (!isMobile) { + const finalTableProps = hidePagination + ? { ...tableProps, pagination: false } + : tableProps; + + return ( +
          + ); + } + + if (showSkeleton) { + const visibleCols = columns.filter((col) => { + if (tableProps?.visibleColumns && col.key) { + return tableProps.visibleColumns[col.key]; + } + return true; + }); + + const renderSkeletonCard = (key) => { + const placeholder = ( +
          + {visibleCols.map((col, idx) => { + if (!col.title) { + return ( +
          + +
          + ); + } + + return ( +
          + + +
          + ); + })} +
          + ); + + return ( + + + + ); + }; + + return ( +
          + {[1, 2, 3].map((i) => renderSkeletonCard(i))} +
          + ); + } + + const isEmpty = !showSkeleton && (!dataSource || dataSource.length === 0); + + const MobileRowCard = ({ record, index }) => { + const [showDetails, setShowDetails] = useState(false); + const rowKeyVal = getRowKey(record, index); + + const hasDetails = + tableProps.expandedRowRender && + (!tableProps.rowExpandable || tableProps.rowExpandable(record)); + + return ( + + {columns.map((col, colIdx) => { + if ( + tableProps?.visibleColumns && + !tableProps.visibleColumns[col.key] + ) { + return null; + } + + const title = col.title; + const cellContent = col.render + ? col.render(record[col.dataIndex], record, index) + : record[col.dataIndex]; + + if (!title) { + return ( +
          + {cellContent} +
          + ); + } + + return ( +
          + + {title} + +
          + {cellContent !== undefined && cellContent !== null + ? cellContent + : '-'} +
          +
          + ); + })} + + {hasDetails && ( + <> + + +
          + {tableProps.expandedRowRender(record, index)} +
          +
          + + )} +
          + ); + }; + + if (isEmpty) { + if (tableProps.empty) return tableProps.empty; + return ( +
          + +
          + ); + } + + return ( +
          + {dataSource.map((record, index) => ( + + ))} + {!hidePagination && tableProps.pagination && dataSource.length > 0 && ( +
          + +
          + )} +
          + ); +}; + +CardTable.propTypes = { + columns: PropTypes.array.isRequired, + dataSource: PropTypes.array, + loading: PropTypes.bool, + rowKey: PropTypes.oneOfType([PropTypes.string, PropTypes.func]), + hidePagination: PropTypes.bool, +}; + +export default CardTable; diff --git a/web/src/components/common/ui/ChannelKeyDisplay.jsx b/web/src/components/common/ui/ChannelKeyDisplay.jsx new file mode 100644 index 000000000..79aa3eec7 --- /dev/null +++ b/web/src/components/common/ui/ChannelKeyDisplay.jsx @@ -0,0 +1,280 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React from 'react'; +import { useTranslation } from 'react-i18next'; +import { Card, Button, Typography, Tag } from '@douyinfe/semi-ui'; +import { copy, showSuccess } from '../../../helpers'; + +/** + * 解析密钥数据,支持多种格式 + * @param {string} keyData - 密钥数据 + * @param {Function} t - 翻译函数 + * @returns {Array} 解析后的密钥数组 + */ +const parseChannelKeys = (keyData, t) => { + if (!keyData) return []; + + const trimmed = keyData.trim(); + + // 检查是否是JSON数组格式(如Vertex AI) + if (trimmed.startsWith('[')) { + try { + const parsed = JSON.parse(trimmed); + if (Array.isArray(parsed)) { + return parsed.map((item, index) => ({ + id: index, + content: + typeof item === 'string' ? item : JSON.stringify(item, null, 2), + type: typeof item === 'string' ? 'text' : 'json', + label: `${t('密钥')} ${index + 1}`, + })); + } + } catch (e) { + // 如果解析失败,按普通文本处理 + console.warn('Failed to parse JSON keys:', e); + } + } + + // 检查是否是多行密钥(按换行符分割) + const lines = trimmed.split('\n').filter((line) => line.trim()); + if (lines.length > 1) { + return lines.map((line, index) => ({ + id: index, + content: line.trim(), + type: 'text', + label: `${t('密钥')} ${index + 1}`, + })); + } + + // 单个密钥 + return [ + { + id: 0, + content: trimmed, + type: trimmed.startsWith('{') ? 'json' : 'text', + label: t('密钥'), + }, + ]; +}; + +/** + * 可复用的密钥显示组件 + * @param {Object} props + * @param {string} props.keyData - 密钥数据 + * @param {boolean} props.showSuccessIcon - 是否显示成功图标 + * @param {string} props.successText - 成功文本 + * @param {boolean} props.showWarning - 是否显示安全警告 + * @param {string} props.warningText - 警告文本 + */ +const ChannelKeyDisplay = ({ + keyData, + showSuccessIcon = true, + successText, + showWarning = true, + warningText, +}) => { + const { t } = useTranslation(); + + const parsedKeys = parseChannelKeys(keyData, t); + const isMultipleKeys = parsedKeys.length > 1; + + const handleCopyAll = () => { + copy(keyData); + showSuccess(t('所有密钥已复制到剪贴板')); + }; + + const handleCopyKey = (content) => { + copy(content); + showSuccess(t('密钥已复制到剪贴板')); + }; + + return ( +
          + {/* 成功状态 */} + {showSuccessIcon && ( +
          + + + + + {successText || t('验证成功')} + +
          + )} + + {/* 密钥内容 */} +
          +
          + + {isMultipleKeys ? t('渠道密钥列表') : t('渠道密钥')} + + {isMultipleKeys && ( +
          + + {t('共 {{count}} 个密钥', { count: parsedKeys.length })} + + +
          + )} +
          + +
          + {parsedKeys.map((keyItem) => ( + +
          +
          + + {keyItem.label} + +
          + {keyItem.type === 'json' && ( + + {t('JSON')} + + )} + +
          +
          + +
          + + {keyItem.content} + +
          + + {keyItem.type === 'json' && ( + + {t('JSON格式密钥,请确保格式正确')} + + )} +
          +
          + ))} +
          + + {isMultipleKeys && ( +
          + + + + + {t( + '检测到多个密钥,您可以单独复制每个密钥,或点击复制全部获取完整内容。', + )} + +
          + )} +
          + + {/* 安全警告 */} + {showWarning && ( +
          +
          + + + +
          + + {t('安全提醒')} + + + {warningText || + t( + '请妥善保管密钥信息,不要泄露给他人。如有安全疑虑,请及时更换密钥。', + )} + +
          +
          +
          + )} +
          + ); +}; + +export default ChannelKeyDisplay; diff --git a/web/src/components/common/ui/CompactModeToggle.jsx b/web/src/components/common/ui/CompactModeToggle.jsx new file mode 100644 index 000000000..40da0abc0 --- /dev/null +++ b/web/src/components/common/ui/CompactModeToggle.jsx @@ -0,0 +1,68 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React from 'react'; +import { Button } from '@douyinfe/semi-ui'; +import PropTypes from 'prop-types'; +import { useIsMobile } from '../../../hooks/common/useIsMobile'; + +/** + * 紧凑模式切换按钮组件 + * 用于在自适应列表和紧凑列表之间切换 + * 在移动端时自动隐藏,因为移动端使用"显示操作项"按钮来控制内容显示 + */ +const CompactModeToggle = ({ + compactMode, + setCompactMode, + t, + size = 'small', + type = 'tertiary', + className = '', + ...props +}) => { + const isMobile = useIsMobile(); + + // 在移动端隐藏紧凑列表切换按钮 + if (isMobile) { + return null; + } + + return ( + + ); +}; + +CompactModeToggle.propTypes = { + compactMode: PropTypes.bool.isRequired, + setCompactMode: PropTypes.func.isRequired, + t: PropTypes.func.isRequired, + size: PropTypes.string, + type: PropTypes.string, + className: PropTypes.string, +}; + +export default CompactModeToggle; diff --git a/web/src/components/common/ui/JSONEditor.jsx b/web/src/components/common/ui/JSONEditor.jsx new file mode 100644 index 000000000..d89753872 --- /dev/null +++ b/web/src/components/common/ui/JSONEditor.jsx @@ -0,0 +1,714 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React, { useState, useEffect, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + Button, + Form, + Typography, + Banner, + Tabs, + TabPane, + Card, + Input, + InputNumber, + Switch, + TextArea, + Row, + Col, + Divider, + Tooltip, +} from '@douyinfe/semi-ui'; +import { IconPlus, IconDelete, IconAlertTriangle } from '@douyinfe/semi-icons'; + +const { Text } = Typography; + +// 唯一 ID 生成器,确保在组件生命周期内稳定且递增 +const generateUniqueId = (() => { + let counter = 0; + return () => `kv_${counter++}`; +})(); + +const JSONEditor = ({ + value = '', + onChange, + field, + label, + placeholder, + extraText, + extraFooter, + showClear = true, + template, + templateLabel, + editorType = 'keyValue', + rules = [], + formApi = null, + ...props +}) => { + const { t } = useTranslation(); + + // 将对象转换为键值对数组(包含唯一ID) + const objectToKeyValueArray = useCallback((obj, prevPairs = []) => { + if (!obj || typeof obj !== 'object') return []; + + const entries = Object.entries(obj); + return entries.map(([key, value], index) => { + // 如果上一次转换后同位置的键一致,则沿用其 id,保持 React key 稳定 + const prev = prevPairs[index]; + const shouldReuseId = prev && prev.key === key; + return { + id: shouldReuseId ? prev.id : generateUniqueId(), + key, + value, + }; + }); + }, []); + + // 将键值对数组转换为对象(重复键时后面的会覆盖前面的) + const keyValueArrayToObject = useCallback((arr) => { + const result = {}; + arr.forEach((item) => { + if (item.key) { + result[item.key] = item.value; + } + }); + return result; + }, []); + + // 初始化键值对数组 + const [keyValuePairs, setKeyValuePairs] = useState(() => { + if (typeof value === 'string' && value.trim()) { + try { + const parsed = JSON.parse(value); + return objectToKeyValueArray(parsed); + } catch (error) { + return []; + } + } + if (typeof value === 'object' && value !== null) { + return objectToKeyValueArray(value); + } + return []; + }); + + // 手动模式下的本地文本缓冲 + const [manualText, setManualText] = useState(() => { + if (typeof value === 'string') return value; + if (value && typeof value === 'object') + return JSON.stringify(value, null, 2); + return ''; + }); + + // 根据键数量决定默认编辑模式 + const [editMode, setEditMode] = useState(() => { + if (typeof value === 'string' && value.trim()) { + try { + const parsed = JSON.parse(value); + const keyCount = Object.keys(parsed).length; + return keyCount > 10 ? 'manual' : 'visual'; + } catch (error) { + return 'manual'; + } + } + return 'visual'; + }); + + const [jsonError, setJsonError] = useState(''); + + // 计算重复的键 + const duplicateKeys = useMemo(() => { + const keyCount = {}; + const duplicates = new Set(); + + keyValuePairs.forEach((pair) => { + if (pair.key) { + keyCount[pair.key] = (keyCount[pair.key] || 0) + 1; + if (keyCount[pair.key] > 1) { + duplicates.add(pair.key); + } + } + }); + + return duplicates; + }, [keyValuePairs]); + + // 数据同步 - 当value变化时更新键值对数组 + useEffect(() => { + try { + let parsed = {}; + if (typeof value === 'string' && value.trim()) { + parsed = JSON.parse(value); + } else if (typeof value === 'object' && value !== null) { + parsed = value; + } + + // 只在外部值真正改变时更新,避免循环更新 + const currentObj = keyValueArrayToObject(keyValuePairs); + if (JSON.stringify(parsed) !== JSON.stringify(currentObj)) { + setKeyValuePairs(objectToKeyValueArray(parsed, keyValuePairs)); + } + setJsonError(''); + } catch (error) { + console.log('JSON解析失败:', error.message); + setJsonError(error.message); + } + }, [value]); + + // 外部 value 变化时,若不在手动模式,则同步手动文本 + useEffect(() => { + if (editMode !== 'manual') { + if (typeof value === 'string') setManualText(value); + else if (value && typeof value === 'object') + setManualText(JSON.stringify(value, null, 2)); + else setManualText(''); + } + }, [value, editMode]); + + // 处理可视化编辑的数据变化 + const handleVisualChange = useCallback( + (newPairs) => { + setKeyValuePairs(newPairs); + const jsonObject = keyValueArrayToObject(newPairs); + const jsonString = + Object.keys(jsonObject).length === 0 + ? '' + : JSON.stringify(jsonObject, null, 2); + + setJsonError(''); + + // 通过formApi设置值 + if (formApi && field) { + formApi.setValue(field, jsonString); + } + + onChange?.(jsonString); + }, + [onChange, formApi, field, keyValueArrayToObject], + ); + + // 处理手动编辑的数据变化 + const handleManualChange = useCallback( + (newValue) => { + setManualText(newValue); + if (newValue && newValue.trim()) { + try { + const parsed = JSON.parse(newValue); + setKeyValuePairs(objectToKeyValueArray(parsed, keyValuePairs)); + setJsonError(''); + onChange?.(newValue); + } catch (error) { + setJsonError(error.message); + } + } else { + setKeyValuePairs([]); + setJsonError(''); + onChange?.(''); + } + }, + [onChange, objectToKeyValueArray, keyValuePairs], + ); + + // 切换编辑模式 + const toggleEditMode = useCallback(() => { + if (editMode === 'visual') { + const jsonObject = keyValueArrayToObject(keyValuePairs); + setManualText( + Object.keys(jsonObject).length === 0 + ? '' + : JSON.stringify(jsonObject, null, 2), + ); + setEditMode('manual'); + } else { + try { + let parsed = {}; + if (manualText && manualText.trim()) { + parsed = JSON.parse(manualText); + } else if (typeof value === 'string' && value.trim()) { + parsed = JSON.parse(value); + } else if (typeof value === 'object' && value !== null) { + parsed = value; + } + setKeyValuePairs(objectToKeyValueArray(parsed, keyValuePairs)); + setJsonError(''); + setEditMode('visual'); + } catch (error) { + setJsonError(error.message); + return; + } + } + }, [ + editMode, + value, + manualText, + keyValuePairs, + keyValueArrayToObject, + objectToKeyValueArray, + ]); + + // 添加键值对 + const addKeyValue = useCallback(() => { + const newPairs = [...keyValuePairs]; + const existingKeys = newPairs.map((p) => p.key); + let counter = 1; + let newKey = `field_${counter}`; + while (existingKeys.includes(newKey)) { + counter += 1; + newKey = `field_${counter}`; + } + newPairs.push({ + id: generateUniqueId(), + key: newKey, + value: '', + }); + handleVisualChange(newPairs); + }, [keyValuePairs, handleVisualChange]); + + // 删除键值对 + const removeKeyValue = useCallback( + (id) => { + const newPairs = keyValuePairs.filter((pair) => pair.id !== id); + handleVisualChange(newPairs); + }, + [keyValuePairs, handleVisualChange], + ); + + // 更新键名 + const updateKey = useCallback( + (id, newKey) => { + const newPairs = keyValuePairs.map((pair) => + pair.id === id ? { ...pair, key: newKey } : pair, + ); + handleVisualChange(newPairs); + }, + [keyValuePairs, handleVisualChange], + ); + + // 更新值 + const updateValue = useCallback( + (id, newValue) => { + const newPairs = keyValuePairs.map((pair) => + pair.id === id ? { ...pair, value: newValue } : pair, + ); + handleVisualChange(newPairs); + }, + [keyValuePairs, handleVisualChange], + ); + + // 填入模板 + const fillTemplate = useCallback(() => { + if (template) { + const templateString = JSON.stringify(template, null, 2); + + if (formApi && field) { + formApi.setValue(field, templateString); + } + + setManualText(templateString); + setKeyValuePairs(objectToKeyValueArray(template, keyValuePairs)); + onChange?.(templateString); + setJsonError(''); + } + }, [ + template, + onChange, + formApi, + field, + objectToKeyValueArray, + keyValuePairs, + ]); + + // 渲染值输入控件(支持嵌套) + const renderValueInput = (pairId, value) => { + const valueType = typeof value; + + if (valueType === 'boolean') { + return ( +
          + updateValue(pairId, newValue)} + /> + + {value ? t('true') : t('false')} + +
          + ); + } + + if (valueType === 'number') { + return ( + updateValue(pairId, newValue)} + style={{ width: '100%' }} + placeholder={t('输入数字')} + /> + ); + } + + if (valueType === 'object' && value !== null) { + // 简化嵌套对象的处理,使用TextArea + return ( +