mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-19 06:47:26 +00:00
Merge branch 'alpha' into base
This commit is contained in:
@@ -4,4 +4,5 @@
|
|||||||
.vscode
|
.vscode
|
||||||
.gitignore
|
.gitignore
|
||||||
Makefile
|
Makefile
|
||||||
docs
|
docs
|
||||||
|
.eslintcache
|
||||||
@@ -47,7 +47,7 @@
|
|||||||
# 所有请求超时时间,单位秒,默认为0,表示不限制
|
# 所有请求超时时间,单位秒,默认为0,表示不限制
|
||||||
# RELAY_TIMEOUT=0
|
# RELAY_TIMEOUT=0
|
||||||
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
|
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
|
||||||
# STREAMING_TIMEOUT=120
|
# STREAMING_TIMEOUT=300
|
||||||
|
|
||||||
# Gemini 识别图片 最大图片数量
|
# Gemini 识别图片 最大图片数量
|
||||||
# GEMINI_VISION_MAX_IMAGE_NUM=16
|
# GEMINI_VISION_MAX_IMAGE_NUM=16
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -10,4 +10,5 @@ web/dist
|
|||||||
.env
|
.env
|
||||||
one-api
|
one-api
|
||||||
.DS_Store
|
.DS_Store
|
||||||
tiktoken_cache
|
tiktoken_cache
|
||||||
|
.eslintcache
|
||||||
240
LICENSE
240
LICENSE
@@ -1,201 +1,103 @@
|
|||||||
Apache License
|
# **New API 许可协议 (Licensing)**
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
本项目采用**基于使用场景的双重许可 (Usage-Based Dual Licensing)** 模式。
|
||||||
|
|
||||||
1. Definitions.
|
**核心原则:**
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
- **默认许可:** 本项目默认在 **GNU Affero 通用公共许可证 v3.0 (AGPLv3)** 下提供。任何用户在遵守 AGPLv3 条款和下述附加限制的前提下,均可免费使用。
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
- **商业许可:** 在特定商业场景下,或当您希望获得 AGPLv3 之外的权利时,**必须**获取**商业许可证 (Commercial License)**。
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
---
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
## **1. 开源许可证 (Open Source License): AGPLv3 - 适用于基础使用**
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
- 在遵守 **AGPLv3** 条款的前提下,您可以自由地使用、修改和分发 New API。AGPLv3 的完整文本可以访问 [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html) 获取。
|
||||||
exercising permissions granted by this License.
|
- **核心义务:** AGPLv3 的一个关键要求是,如果您修改了 New API 并通过网络提供服务 (SaaS),或者分发了修改后的版本,您必须以 AGPLv3 许可证向所有用户提供相应的**完整源代码**。
|
||||||
|
- **附加限制 (重要):** 在仅使用 AGPLv3 开源许可证的情况下,您**必须**完整保留项目代码中原有的品牌标识、LOGO 及版权声明信息。**禁止以任何形式修改、移除或遮盖**这些信息。如需移除,必须获取商业许可证。
|
||||||
|
- 使用前请务必仔细阅读并理解 AGPLv3 的所有条款及上述附加限制。
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
## **2. 商业许可证 (Commercial License) - 适用于高级场景及闭源需求**
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
在以下任一情况下,您**必须**联系我们获取并签署一份商业许可证,才能合法使用 New API:
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
- **场景一:移除品牌和版权信息**
|
||||||
Object form, made available under the License, as indicated by a
|
您希望在您的产品或服务中移除 New API 的 LOGO、UI界面中的版权声明或其他品牌标识。
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
- **场景二:规避 AGPLv3 开源义务**
|
||||||
form, that is based on (or derived from) the Work and for which the
|
您基于 New API 进行了修改,并希望:
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
- 通过网络提供服务(SaaS),但**不希望**向您的服务用户公开您修改后的源代码。
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
- 分发一个集成了 New API 的软件产品,但**不希望**以 AGPLv3 许可证发布您的产品或公开源代码。
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
- **场景三:企业政策与集成需求**
|
||||||
the original version of the Work and any modifications or additions
|
- 您所在公司的政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件。
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
- 您需要进行 OEM 集成,将 New API 作为您闭源商业产品的一部分进行再分发。
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
- **场景四:需要商业支持与保障**
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
您需要 AGPLv3 未提供的商业保障,如官方技术支持等。
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
**获取商业许可:**
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
请通过电子邮件 **support@quantumnous.com** 联系 New API 团队洽谈商业授权事宜。
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
## **3. 贡献 (Contributions)**
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
- 我们欢迎社区对 New API 的贡献。所有向本项目提交的贡献(例如通过 Pull Request)都将被视为在 **AGPLv3** 许可证下提供。
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
- 通过向本项目提交贡献,即表示您同意您的代码以 AGPLv3 许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。
|
||||||
modifications, and in Source or Object form, provided that You
|
- 您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 New API 版本中。
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
## **4. 其他条款 (Other Terms)**
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
- 关于商业许可证的具体条款、条件和价格,以双方签署的正式商业许可协议为准。
|
||||||
stating that You changed the files; and
|
- 项目维护者保留根据需要更新本许可政策的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
---
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
# **New API Licensing**
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
This project uses a **Usage-Based Dual Licensing** model.
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
**Core Principles:**
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
- **Default License:** This project is available by default under the **GNU Affero General Public License v3.0 (AGPLv3)**. Any user may use it free of charge, provided they comply with both the AGPLv3 terms and the additional restrictions listed below.
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
- **Commercial License:** For specific commercial scenarios, or if you require rights beyond those granted by AGPLv3, you **must** obtain a **Commercial License**.
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
---
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
## **1. Open Source License: AGPLv3 – For Basic Usage**
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
- Under the terms of the **AGPLv3**, you are free to use, modify, and distribute New API. The complete AGPLv3 license text can be viewed at [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html).
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
- **Core Obligation:** A key AGPLv3 requirement is that if you modify New API and provide it as a network service (SaaS), or distribute a modified version, you must make the **complete corresponding source code** available to all users under the AGPLv3 license.
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
- **Additional Restriction (Important):** When using only the AGPLv3 open-source license, you **must** retain all original branding, logos, and copyright statements within the project’s code. **You are strictly prohibited from modifying, removing, or concealing** any such information. If you wish to remove this, you must obtain a Commercial License.
|
||||||
or other liability obligations and/or rights consistent with this
|
- Please read and ensure that you fully understand all AGPLv3 terms and the above additional restriction before use.
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
## **2. Commercial License – For Advanced Scenarios & Closed Source Needs**
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
You **must** contact us to obtain and sign a Commercial License in any of the following scenarios in order to legally use New API:
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
- **Scenario 1: Removal of Branding and Copyright**
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
You wish to remove the New API logo, copyright statement, or other branding elements from your product or service.
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
- **Scenario 2: Avoidance of AGPLv3 Open Source Obligations**
|
||||||
|
You have modified New API and wish to:
|
||||||
|
- Offer it as a network service (SaaS) **without** disclosing your modifications' source code to your users.
|
||||||
|
- Distribute a software product integrated with New API **without** releasing your product under AGPLv3 or open-sourcing the code.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
- **Scenario 3: Enterprise Policy & Integration Needs**
|
||||||
you may not use this file except in compliance with the License.
|
- Your organization’s policies, client contracts, or project requirements prohibit the use of AGPLv3-licensed software.
|
||||||
You may obtain a copy of the License at
|
- You require OEM integration and need to redistribute New API as part of your closed-source commercial product.
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
- **Scenario 4: Commercial Support and Assurances**
|
||||||
|
You require commercial assurances not provided by AGPLv3, such as official technical support.
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
**Obtaining a Commercial License:**
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
Please contact the New API team via email at **support@quantumnous.com** to discuss commercial licensing.
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
## **3. Contributions**
|
||||||
limitations under the License.
|
|
||||||
|
- We welcome community contributions to New API. All contributions (e.g., via Pull Request) are deemed to be provided under the **AGPLv3** license.
|
||||||
|
- By submitting a contribution, you agree that your code is licensed to this project and all downstream users under the AGPLv3 license (regardless of whether those users ultimately operate under AGPLv3 or a Commercial License).
|
||||||
|
- You also acknowledge and agree that your contribution may be included in New API releases distributed under a Commercial License.
|
||||||
|
|
||||||
|
## **4. Other Terms**
|
||||||
|
|
||||||
|
- The specific terms, conditions, and pricing of the Commercial License are governed by the formal commercial license agreement executed by both parties.
|
||||||
|
- Project maintainers reserve the right to update this licensing policy as needed. Updates will be communicated via official project channels (e.g., repository, official website).
|
||||||
|
|||||||
@@ -65,6 +65,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
|||||||
apiType = constant.APITypeCoze
|
apiType = constant.APITypeCoze
|
||||||
case constant.ChannelTypeJimeng:
|
case constant.ChannelTypeJimeng:
|
||||||
apiType = constant.APITypeJimeng
|
apiType = constant.APITypeJimeng
|
||||||
|
case constant.ChannelTypeMoonshot:
|
||||||
|
apiType = constant.APITypeMoonshot
|
||||||
}
|
}
|
||||||
if apiType == -1 {
|
if apiType == -1 {
|
||||||
return constant.APITypeOpenAI, false
|
return constant.APITypeOpenAI, false
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ var GitHubClientId = ""
|
|||||||
var GitHubClientSecret = ""
|
var GitHubClientSecret = ""
|
||||||
var LinuxDOClientId = ""
|
var LinuxDOClientId = ""
|
||||||
var LinuxDOClientSecret = ""
|
var LinuxDOClientSecret = ""
|
||||||
|
var LinuxDOMinimumTrustLevel = 0
|
||||||
|
|
||||||
var WeChatServerAddress = ""
|
var WeChatServerAddress = ""
|
||||||
var WeChatServerToken = ""
|
var WeChatServerToken = ""
|
||||||
|
|||||||
19
common/copy.go
Normal file
19
common/copy.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/jinzhu/copier"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DeepCopy[T any](src *T) (*T, error) {
|
||||||
|
if src == nil {
|
||||||
|
return nil, fmt.Errorf("copy source cannot be nil")
|
||||||
|
}
|
||||||
|
var dst T
|
||||||
|
err := copier.CopyWithOption(&dst, src, copier.Option{DeepCopy: true, IgnoreEmpty: true})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &dst, nil
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type stringWriter interface {
|
type stringWriter interface {
|
||||||
@@ -52,6 +53,8 @@ type CustomEvent struct {
|
|||||||
Id string
|
Id string
|
||||||
Retry uint
|
Retry uint
|
||||||
Data interface{}
|
Data interface{}
|
||||||
|
|
||||||
|
Mutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func encode(writer io.Writer, event CustomEvent) error {
|
func encode(writer io.Writer, event CustomEvent) error {
|
||||||
@@ -73,6 +76,8 @@ func (r CustomEvent) Render(w http.ResponseWriter) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
|
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
|
||||||
|
r.Mutex.Lock()
|
||||||
|
defer r.Mutex.Unlock()
|
||||||
header := w.Header()
|
header := w.Header()
|
||||||
header["Content-Type"] = contentType
|
header["Content-Type"] = contentType
|
||||||
|
|
||||||
|
|||||||
@@ -12,4 +12,4 @@ var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
|
|||||||
var UsingMySQL = false
|
var UsingMySQL = false
|
||||||
var UsingClickHouse = false
|
var UsingClickHouse = false
|
||||||
|
|
||||||
var SQLitePath = "one-api.db?_busy_timeout=5000"
|
var SQLitePath = "one-api.db?_busy_timeout=30000"
|
||||||
|
|||||||
32
common/endpoint_defaults.go
Normal file
32
common/endpoint_defaults.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "one-api/constant"
|
||||||
|
|
||||||
|
// EndpointInfo 描述单个端点的默认请求信息
|
||||||
|
// path: 上游路径
|
||||||
|
// method: HTTP 请求方式,例如 POST/GET
|
||||||
|
// 目前均为 POST,后续可扩展
|
||||||
|
//
|
||||||
|
// json 标签用于直接序列化到 API 输出
|
||||||
|
// 例如:{"path":"/v1/chat/completions","method":"POST"}
|
||||||
|
|
||||||
|
type EndpointInfo struct {
|
||||||
|
Path string `json:"path"`
|
||||||
|
Method string `json:"method"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultEndpointInfoMap 保存内置端点的默认 Path 与 Method
|
||||||
|
var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{
|
||||||
|
constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"},
|
||||||
|
constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"},
|
||||||
|
constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"},
|
||||||
|
constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
|
||||||
|
constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"},
|
||||||
|
constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在
|
||||||
|
func GetDefaultEndpointInfo(et constant.EndpointType) (EndpointInfo, bool) {
|
||||||
|
info, ok := defaultEndpointInfoMap[et]
|
||||||
|
return info, ok
|
||||||
|
}
|
||||||
@@ -2,12 +2,13 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
const KeyRequestBody = "key_request_body"
|
const KeyRequestBody = "key_request_body"
|
||||||
@@ -31,6 +32,9 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
//if DebugEnabled {
|
||||||
|
// println("UnmarshalBodyReusable request body:", string(requestBody))
|
||||||
|
//}
|
||||||
contentType := c.Request.Header.Get("Content-Type")
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
if strings.HasPrefix(contentType, "application/json") {
|
if strings.HasPrefix(contentType, "application/json") {
|
||||||
err = Unmarshal(requestBody, &v)
|
err = Unmarshal(requestBody, &v)
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ func InitEnv() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func initConstantEnv() {
|
func initConstantEnv() {
|
||||||
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120)
|
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 300)
|
||||||
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||||
|
|||||||
@@ -20,3 +20,25 @@ func DecodeJson(reader *bytes.Reader, v any) error {
|
|||||||
func Marshal(v any) ([]byte, error) {
|
func Marshal(v any) ([]byte, error) {
|
||||||
return json.Marshal(v)
|
return json.Marshal(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetJsonType(data json.RawMessage) string {
|
||||||
|
data = bytes.TrimSpace(data)
|
||||||
|
if len(data) == 0 {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
firstChar := bytes.TrimSpace(data)[0]
|
||||||
|
switch firstChar {
|
||||||
|
case '{':
|
||||||
|
return "object"
|
||||||
|
case '[':
|
||||||
|
return "array"
|
||||||
|
case '"':
|
||||||
|
return "string"
|
||||||
|
case 't', 'f':
|
||||||
|
return "boolean"
|
||||||
|
case 'n':
|
||||||
|
return "null"
|
||||||
|
default:
|
||||||
|
return "number"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ func (p *PageInfo) SetItems(items any) {
|
|||||||
func GetPageQuery(c *gin.Context) *PageInfo {
|
func GetPageQuery(c *gin.Context) *PageInfo {
|
||||||
pageInfo := &PageInfo{}
|
pageInfo := &PageInfo{}
|
||||||
// 手动获取并处理每个参数
|
// 手动获取并处理每个参数
|
||||||
if page, err := strconv.Atoi(c.Query("page")); err == nil {
|
if page, err := strconv.Atoi(c.Query("p")); err == nil {
|
||||||
pageInfo.Page = page
|
pageInfo.Page = page
|
||||||
}
|
}
|
||||||
if pageSize, err := strconv.Atoi(c.Query("page_size")); err == nil {
|
if pageSize, err := strconv.Atoi(c.Query("page_size")); err == nil {
|
||||||
|
|||||||
5
common/quota.go
Normal file
5
common/quota.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
func GetTrustQuota() int {
|
||||||
|
return int(10 * QuotaPerUnit)
|
||||||
|
}
|
||||||
140
common/str.go
140
common/str.go
@@ -4,7 +4,10 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -95,3 +98,140 @@ func GetJsonString(data any) string {
|
|||||||
b, _ := json.Marshal(data)
|
b, _ := json.Marshal(data)
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MaskEmail masks a user email to prevent PII leakage in logs
|
||||||
|
// Returns "***masked***" if email is empty, otherwise shows only the domain part
|
||||||
|
func MaskEmail(email string) string {
|
||||||
|
if email == "" {
|
||||||
|
return "***masked***"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the @ symbol
|
||||||
|
atIndex := strings.Index(email, "@")
|
||||||
|
if atIndex == -1 {
|
||||||
|
// No @ symbol found, return masked
|
||||||
|
return "***masked***"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return only the domain part with @ symbol
|
||||||
|
return "***@" + email[atIndex+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskHostTail returns the tail parts of a domain/host that should be preserved.
|
||||||
|
// It keeps 2 parts for likely country-code TLDs (e.g., co.uk, com.cn), otherwise keeps only the TLD.
|
||||||
|
func maskHostTail(parts []string) []string {
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
lastPart := parts[len(parts)-1]
|
||||||
|
secondLastPart := parts[len(parts)-2]
|
||||||
|
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
|
||||||
|
// Likely country code TLD like co.uk, com.cn
|
||||||
|
return []string{secondLastPart, lastPart}
|
||||||
|
}
|
||||||
|
return []string{lastPart}
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail.
|
||||||
|
// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk
|
||||||
|
func maskHostForURL(host string) string {
|
||||||
|
parts := strings.Split(host, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return "***"
|
||||||
|
}
|
||||||
|
tail := maskHostTail(parts)
|
||||||
|
return "***." + strings.Join(tail, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***.
|
||||||
|
// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk
|
||||||
|
func maskHostForPlainDomain(domain string) string {
|
||||||
|
parts := strings.Split(domain, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return domain
|
||||||
|
}
|
||||||
|
tail := maskHostTail(parts)
|
||||||
|
numStars := len(parts) - len(tail)
|
||||||
|
if numStars < 1 {
|
||||||
|
numStars = 1
|
||||||
|
}
|
||||||
|
stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".")
|
||||||
|
return stars + "." + strings.Join(tail, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
|
||||||
|
// Example:
|
||||||
|
// http://example.com -> http://***.com
|
||||||
|
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
|
||||||
|
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
|
||||||
|
// 192.168.1.1 -> ***.***.***.***
|
||||||
|
// openai.com -> ***.com
|
||||||
|
// www.openai.com -> ***.***.com
|
||||||
|
// api.openai.com -> ***.***.com
|
||||||
|
func MaskSensitiveInfo(str string) string {
|
||||||
|
// Mask URLs
|
||||||
|
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
||||||
|
str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
|
||||||
|
u, err := url.Parse(urlStr)
|
||||||
|
if err != nil {
|
||||||
|
return urlStr
|
||||||
|
}
|
||||||
|
|
||||||
|
host := u.Host
|
||||||
|
if host == "" {
|
||||||
|
return urlStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mask host with unified logic
|
||||||
|
maskedHost := maskHostForURL(host)
|
||||||
|
|
||||||
|
result := u.Scheme + "://" + maskedHost
|
||||||
|
|
||||||
|
// Mask path
|
||||||
|
if u.Path != "" && u.Path != "/" {
|
||||||
|
pathParts := strings.Split(strings.Trim(u.Path, "/"), "/")
|
||||||
|
maskedPathParts := make([]string, len(pathParts))
|
||||||
|
for i := range pathParts {
|
||||||
|
if pathParts[i] != "" {
|
||||||
|
maskedPathParts[i] = "***"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(maskedPathParts) > 0 {
|
||||||
|
result += "/" + strings.Join(maskedPathParts, "/")
|
||||||
|
}
|
||||||
|
} else if u.Path == "/" {
|
||||||
|
result += "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mask query parameters
|
||||||
|
if u.RawQuery != "" {
|
||||||
|
values, err := url.ParseQuery(u.RawQuery)
|
||||||
|
if err != nil {
|
||||||
|
// If can't parse query, just mask the whole query string
|
||||||
|
result += "?***"
|
||||||
|
} else {
|
||||||
|
maskedParams := make([]string, 0, len(values))
|
||||||
|
for key := range values {
|
||||||
|
maskedParams = append(maskedParams, key+"=***")
|
||||||
|
}
|
||||||
|
if len(maskedParams) > 0 {
|
||||||
|
result += "?" + strings.Join(maskedParams, "&")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
})
|
||||||
|
|
||||||
|
// Mask domain names without protocol (like openai.com, www.openai.com)
|
||||||
|
domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
|
||||||
|
str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
|
||||||
|
return maskHostForPlainDomain(domain)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Mask IP addresses
|
||||||
|
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
||||||
|
str = ipPattern.ReplaceAllString(str, "***.***.***.***")
|
||||||
|
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|||||||
24
common/sys_log.go
Normal file
24
common/sys_log.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SysLog(s string) {
|
||||||
|
t := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SysError(s string) {
|
||||||
|
t := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func FatalLog(v ...any) {
|
||||||
|
t := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
150
common/totp.go
Normal file
150
common/totp.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pquerna/otp"
|
||||||
|
"github.com/pquerna/otp/totp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// 备用码配置
|
||||||
|
BackupCodeLength = 8 // 备用码长度
|
||||||
|
BackupCodeCount = 4 // 生成备用码数量
|
||||||
|
|
||||||
|
// 限制配置
|
||||||
|
MaxFailAttempts = 5 // 最大失败尝试次数
|
||||||
|
LockoutDuration = 300 // 锁定时间(秒)
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateTOTPSecret 生成TOTP密钥和配置
|
||||||
|
func GenerateTOTPSecret(accountName string) (*otp.Key, error) {
|
||||||
|
issuer := Get2FAIssuer()
|
||||||
|
return totp.Generate(totp.GenerateOpts{
|
||||||
|
Issuer: issuer,
|
||||||
|
AccountName: accountName,
|
||||||
|
Period: 30,
|
||||||
|
Digits: otp.DigitsSix,
|
||||||
|
Algorithm: otp.AlgorithmSHA1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateTOTPCode 验证TOTP验证码
|
||||||
|
func ValidateTOTPCode(secret, code string) bool {
|
||||||
|
// 清理验证码格式
|
||||||
|
cleanCode := strings.ReplaceAll(code, " ", "")
|
||||||
|
if len(cleanCode) != 6 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证验证码
|
||||||
|
return totp.Validate(cleanCode, secret)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateBackupCodes 生成备用恢复码
|
||||||
|
func GenerateBackupCodes() ([]string, error) {
|
||||||
|
codes := make([]string, BackupCodeCount)
|
||||||
|
|
||||||
|
for i := 0; i < BackupCodeCount; i++ {
|
||||||
|
code, err := generateRandomBackupCode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
codes[i] = code
|
||||||
|
}
|
||||||
|
|
||||||
|
return codes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRandomBackupCode 生成单个备用码
|
||||||
|
func generateRandomBackupCode() (string, error) {
|
||||||
|
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
|
code := make([]byte, BackupCodeLength)
|
||||||
|
|
||||||
|
for i := range code {
|
||||||
|
randomBytes := make([]byte, 1)
|
||||||
|
_, err := rand.Read(randomBytes)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
code[i] = charset[int(randomBytes[0])%len(charset)]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 格式化为 XXXX-XXXX 格式
|
||||||
|
return fmt.Sprintf("%s-%s", string(code[:4]), string(code[4:])), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateBackupCode 验证备用码格式
|
||||||
|
func ValidateBackupCode(code string) bool {
|
||||||
|
// 移除所有分隔符并转为大写
|
||||||
|
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
|
||||||
|
if len(cleanCode) != BackupCodeLength {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查字符是否合法
|
||||||
|
for _, char := range cleanCode {
|
||||||
|
if !((char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9')) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeBackupCode 标准化备用码格式
|
||||||
|
func NormalizeBackupCode(code string) string {
|
||||||
|
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
|
||||||
|
if len(cleanCode) == BackupCodeLength {
|
||||||
|
return fmt.Sprintf("%s-%s", cleanCode[:4], cleanCode[4:])
|
||||||
|
}
|
||||||
|
return code
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashBackupCode 对备用码进行哈希
|
||||||
|
func HashBackupCode(code string) (string, error) {
|
||||||
|
normalizedCode := NormalizeBackupCode(code)
|
||||||
|
return Password2Hash(normalizedCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get2FAIssuer 获取2FA发行者名称
|
||||||
|
func Get2FAIssuer() string {
|
||||||
|
return SystemName
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEnvOrDefault 获取环境变量或默认值
|
||||||
|
func getEnvOrDefault(key, defaultValue string) string {
|
||||||
|
if value, exists := os.LookupEnv(key); exists {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateNumericCode 验证数字验证码格式
|
||||||
|
func ValidateNumericCode(code string) (string, error) {
|
||||||
|
// 移除空格
|
||||||
|
code = strings.ReplaceAll(code, " ", "")
|
||||||
|
|
||||||
|
if len(code) != 6 {
|
||||||
|
return "", fmt.Errorf("验证码必须是6位数字")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否为纯数字
|
||||||
|
if _, err := strconv.Atoi(code); err != nil {
|
||||||
|
return "", fmt.Errorf("验证码只能包含数字")
|
||||||
|
}
|
||||||
|
|
||||||
|
return code, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateQRCodeData 生成二维码数据
|
||||||
|
func GenerateQRCodeData(secret, username string) string {
|
||||||
|
issuer := Get2FAIssuer()
|
||||||
|
accountName := fmt.Sprintf("%s (%s)", username, issuer)
|
||||||
|
return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&digits=6&period=30",
|
||||||
|
issuer, accountName, secret, issuer)
|
||||||
|
}
|
||||||
@@ -257,32 +257,32 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to get audio duration")
|
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||||
}
|
}
|
||||||
durationStr := string(bytes.TrimSpace(output))
|
durationStr := string(bytes.TrimSpace(output))
|
||||||
if durationStr == "N/A" {
|
if durationStr == "N/A" {
|
||||||
// Create a temporary output file name
|
// Create a temporary output file name
|
||||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to create temporary file")
|
return 0, errors.Wrap(err, "failed to create temporary file")
|
||||||
}
|
}
|
||||||
tmpName := tmpFp.Name()
|
tmpName := tmpFp.Name()
|
||||||
// Close immediately so ffmpeg can open the file on Windows.
|
// Close immediately so ffmpeg can open the file on Windows.
|
||||||
_ = tmpFp.Close()
|
_ = tmpFp.Close()
|
||||||
defer os.Remove(tmpName)
|
defer os.Remove(tmpName)
|
||||||
|
|
||||||
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
||||||
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
||||||
if err := ffmpegCmd.Run(); err != nil {
|
if err := ffmpegCmd.Run(); err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recalculate the duration of the new file
|
// Recalculate the duration of the new file
|
||||||
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
||||||
output, err := c.Output()
|
output, err := c.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
||||||
}
|
}
|
||||||
durationStr = string(bytes.TrimSpace(output))
|
durationStr = string(bytes.TrimSpace(output))
|
||||||
}
|
}
|
||||||
return strconv.ParseFloat(durationStr, 64)
|
return strconv.ParseFloat(durationStr, 64)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -31,5 +31,6 @@ const (
|
|||||||
APITypeXai
|
APITypeXai
|
||||||
APITypeCoze
|
APITypeCoze
|
||||||
APITypeJimeng
|
APITypeJimeng
|
||||||
APITypeDummy // this one is only for count, do not add any channel after this
|
APITypeMoonshot // this one is only for count, do not add any channel after this
|
||||||
|
APITypeDummy // this one is only for count, do not add any channel after this
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ const (
|
|||||||
ChannelTypeCoze = 49
|
ChannelTypeCoze = 49
|
||||||
ChannelTypeKling = 50
|
ChannelTypeKling = 50
|
||||||
ChannelTypeJimeng = 51
|
ChannelTypeJimeng = 51
|
||||||
|
ChannelTypeVidu = 52
|
||||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
|
||||||
)
|
)
|
||||||
@@ -106,4 +107,5 @@ var ChannelBaseURLs = []string{
|
|||||||
"https://api.coze.cn", //49
|
"https://api.coze.cn", //49
|
||||||
"https://api.klingai.com", //50
|
"https://api.klingai.com", //50
|
||||||
"https://visual.volcengineapi.com", //51
|
"https://visual.volcengineapi.com", //51
|
||||||
|
"https://api.vidu.cn", //52
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ package constant
|
|||||||
type ContextKey string
|
type ContextKey string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
ContextKeyTokenCountMeta ContextKey = "token_count_meta"
|
||||||
|
ContextKeyPromptTokens ContextKey = "prompt_tokens"
|
||||||
|
|
||||||
ContextKeyOriginalModel ContextKey = "original_model"
|
ContextKeyOriginalModel ContextKey = "original_model"
|
||||||
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
||||||
|
|
||||||
@@ -11,7 +14,6 @@ const (
|
|||||||
ContextKeyTokenKey ContextKey = "token_key"
|
ContextKeyTokenKey ContextKey = "token_key"
|
||||||
ContextKeyTokenId ContextKey = "token_id"
|
ContextKeyTokenId ContextKey = "token_id"
|
||||||
ContextKeyTokenGroup ContextKey = "token_group"
|
ContextKeyTokenGroup ContextKey = "token_group"
|
||||||
ContextKeyTokenAllowIps ContextKey = "allow_ips"
|
|
||||||
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
||||||
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
||||||
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||||
@@ -23,7 +25,9 @@ const (
|
|||||||
ContextKeyChannelBaseUrl ContextKey = "base_url"
|
ContextKeyChannelBaseUrl ContextKey = "base_url"
|
||||||
ContextKeyChannelType ContextKey = "channel_type"
|
ContextKeyChannelType ContextKey = "channel_type"
|
||||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||||
|
ContextKeyChannelOtherSetting ContextKey = "channel_other_setting"
|
||||||
ContextKeyChannelParamOverride ContextKey = "param_override"
|
ContextKeyChannelParamOverride ContextKey = "param_override"
|
||||||
|
ContextKeyChannelHeaderOverride ContextKey = "header_override"
|
||||||
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
||||||
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
||||||
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
||||||
@@ -41,4 +45,6 @@ const (
|
|||||||
ContextKeyUserGroup ContextKey = "user_group"
|
ContextKeyUserGroup ContextKey = "user_group"
|
||||||
ContextKeyUsingGroup ContextKey = "group"
|
ContextKeyUsingGroup ContextKey = "group"
|
||||||
ContextKeyUserName ContextKey = "username"
|
ContextKeyUserName ContextKey = "username"
|
||||||
|
|
||||||
|
ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,8 +5,6 @@ type TaskPlatform string
|
|||||||
const (
|
const (
|
||||||
TaskPlatformSuno TaskPlatform = "suno"
|
TaskPlatformSuno TaskPlatform = "suno"
|
||||||
TaskPlatformMidjourney = "mj"
|
TaskPlatformMidjourney = "mj"
|
||||||
TaskPlatformKling TaskPlatform = "kling"
|
|
||||||
TaskPlatformJimeng TaskPlatform = "jimeng"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -135,7 +135,11 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
|
|||||||
for k := range headers {
|
for k := range headers {
|
||||||
req.Header.Add(k, headers.Get(k))
|
req.Header.Add(k, headers.Get(k))
|
||||||
}
|
}
|
||||||
res, err := service.GetHttpClient().Do(req)
|
client, err := service.NewProxyHttpClient(channel.GetSetting().Proxy)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -69,6 +69,12 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
newAPIError: nil,
|
newAPIError: nil,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if channel.Type == constant.ChannelTypeVidu {
|
||||||
|
return testResult{
|
||||||
|
localErr: errors.New("vidu channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
@@ -126,10 +132,27 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
newAPIError: newAPIError,
|
newAPIError: newAPIError,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
request := buildTestRequest(testModel)
|
||||||
|
|
||||||
info := relaycommon.GenRelayInfo(c)
|
// Determine relay format based on request path
|
||||||
|
relayFormat := types.RelayFormatOpenAI
|
||||||
|
if c.Request.URL.Path == "/v1/embeddings" {
|
||||||
|
relayFormat = types.RelayFormatEmbedding
|
||||||
|
}
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, info, nil)
|
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info.InitChannelMeta(c)
|
||||||
|
|
||||||
|
err = helper.ModelMappedHelper(c, info, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return testResult{
|
return testResult{
|
||||||
context: c,
|
context: c,
|
||||||
@@ -137,7 +160,9 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
testModel = info.UpstreamModelName
|
testModel = info.UpstreamModelName
|
||||||
|
request.Model = testModel
|
||||||
|
|
||||||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
@@ -149,13 +174,12 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
request := buildTestRequest(testModel)
|
//// 创建一个用于日志的 info 副本,移除 ApiKey
|
||||||
// 创建一个用于日志的 info 副本,移除 ApiKey
|
//logInfo := info
|
||||||
logInfo := *info
|
//logInfo.ApiKey = ""
|
||||||
logInfo.ApiKey = ""
|
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString()))
|
||||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
|
|
||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return testResult{
|
return testResult{
|
||||||
context: c,
|
context: c,
|
||||||
@@ -203,7 +227,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
return testResult{
|
return testResult{
|
||||||
context: c,
|
context: c,
|
||||||
localErr: err,
|
localErr: err,
|
||||||
newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
|
newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var httpResp *http.Response
|
var httpResp *http.Response
|
||||||
@@ -214,7 +238,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
return testResult{
|
return testResult{
|
||||||
context: c,
|
context: c,
|
||||||
localErr: err,
|
localErr: err,
|
||||||
newAPIError: types.NewError(err, types.ErrorCodeBadResponse),
|
newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -230,7 +254,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
return testResult{
|
return testResult{
|
||||||
context: c,
|
context: c,
|
||||||
localErr: errors.New("usage is nil"),
|
localErr: errors.New("usage is nil"),
|
||||||
newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody),
|
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
usage := usageA.(*dto.Usage)
|
usage := usageA.(*dto.Usage)
|
||||||
@@ -240,7 +264,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
return testResult{
|
return testResult{
|
||||||
context: c,
|
context: c,
|
||||||
localErr: err,
|
localErr: err,
|
||||||
newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
|
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
info.PromptTokens = usage.PromptTokens
|
info.PromptTokens = usage.PromptTokens
|
||||||
@@ -269,7 +293,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
Quota: quota,
|
Quota: quota,
|
||||||
Content: "模型测试",
|
Content: "模型测试",
|
||||||
UseTimeSeconds: int(consumedTime),
|
UseTimeSeconds: int(consumedTime),
|
||||||
IsStream: false,
|
IsStream: info.IsStream,
|
||||||
Group: info.UsingGroup,
|
Group: info.UsingGroup,
|
||||||
Other: other,
|
Other: other,
|
||||||
})
|
})
|
||||||
@@ -326,8 +350,11 @@ func TestChannel(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
channel, err := model.CacheGetChannel(channelId)
|
channel, err := model.CacheGetChannel(channelId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
channel, err = model.GetChannelById(channelId, true)
|
||||||
return
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
//defer func() {
|
//defer func() {
|
||||||
// if channel.ChannelInfo.IsMultiKey {
|
// if channel.ChannelInfo.IsMultiKey {
|
||||||
@@ -411,14 +438,14 @@ func testAllChannels(notify bool) error {
|
|||||||
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
||||||
if milliseconds > disableThreshold {
|
if milliseconds > disableThreshold {
|
||||||
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
|
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
|
||||||
shouldBanChannel = true
|
shouldBanChannel = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// disable channel
|
// disable channel
|
||||||
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
||||||
go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// enable channel
|
// enable channel
|
||||||
|
|||||||
@@ -52,6 +52,13 @@ func parseStatusFilter(statusParam string) int {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func clearChannelInfo(channel *model.Channel) {
|
||||||
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
|
channel.ChannelInfo.MultiKeyDisabledReason = nil
|
||||||
|
channel.ChannelInfo.MultiKeyDisabledTime = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func GetAllChannels(c *gin.Context) {
|
func GetAllChannels(c *gin.Context) {
|
||||||
pageInfo := common.GetPageQuery(c)
|
pageInfo := common.GetPageQuery(c)
|
||||||
channelData := make([]*model.Channel, 0)
|
channelData := make([]*model.Channel, 0)
|
||||||
@@ -126,6 +133,10 @@ func GetAllChannels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, datum := range channelData {
|
||||||
|
clearChannelInfo(datum)
|
||||||
|
}
|
||||||
|
|
||||||
countQuery := model.DB.Model(&model.Channel{})
|
countQuery := model.DB.Model(&model.Channel{})
|
||||||
if statusFilter == common.ChannelStatusEnabled {
|
if statusFilter == common.ChannelStatusEnabled {
|
||||||
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
|
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||||
@@ -168,14 +179,26 @@ func FetchUpstreamModels(c *gin.Context) {
|
|||||||
if channel.GetBaseURL() != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
baseURL = channel.GetBaseURL()
|
baseURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
|
||||||
|
var url string
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case constant.ChannelTypeGemini:
|
case constant.ChannelTypeGemini:
|
||||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
|
||||||
|
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
|
||||||
case constant.ChannelTypeAli:
|
case constant.ChannelTypeAli:
|
||||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||||
|
default:
|
||||||
|
url = fmt.Sprintf("%s/v1/models", baseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
|
||||||
|
var body []byte
|
||||||
|
key := strings.Split(channel.Key, "\n")[0]
|
||||||
|
if channel.Type == constant.ChannelTypeGemini {
|
||||||
|
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) // Use AuthHeader since Gemini now forces it
|
||||||
|
} else {
|
||||||
|
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key))
|
||||||
}
|
}
|
||||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
@@ -319,6 +342,10 @@ func SearchChannels(c *gin.Context) {
|
|||||||
|
|
||||||
pagedData := channelData[startIdx:endIdx]
|
pagedData := channelData[startIdx:endIdx]
|
||||||
|
|
||||||
|
for _, datum := range pagedData {
|
||||||
|
clearChannelInfo(datum)
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
@@ -342,6 +369,9 @@ func GetChannel(c *gin.Context) {
|
|||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if channel != nil {
|
||||||
|
clearChannelInfo(channel)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
@@ -350,6 +380,85 @@ func GetChannel(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetChannelKey 验证2FA后获取渠道密钥
|
||||||
|
func GetChannelKey(c *gin.Context) {
|
||||||
|
type GetChannelKeyRequest struct {
|
||||||
|
Code string `json:"code" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req GetChannelKeyRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
common.ApiError(c, fmt.Errorf("参数错误: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
channelId, err := strconv.Atoi(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, fmt.Errorf("渠道ID格式错误: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取2FA记录并验证
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, fmt.Errorf("获取2FA信息失败: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if twoFA == nil || !twoFA.IsEnabled {
|
||||||
|
common.ApiError(c, fmt.Errorf("用户未启用2FA,无法查看密钥"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统一的2FA验证逻辑
|
||||||
|
if !validateTwoFactorAuth(twoFA, req.Code) {
|
||||||
|
common.ApiError(c, fmt.Errorf("验证码或备用码错误,请重试"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取渠道信息(包含密钥)
|
||||||
|
channel, err := model.GetChannelById(channelId, true)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, fmt.Errorf("获取渠道信息失败: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if channel == nil {
|
||||||
|
common.ApiError(c, fmt.Errorf("渠道不存在"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("查看渠道密钥信息 (渠道ID: %d)", channelId))
|
||||||
|
|
||||||
|
// 统一的成功响应格式
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "验证成功",
|
||||||
|
"data": map[string]interface{}{
|
||||||
|
"key": channel.Key,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateTwoFactorAuth 统一的2FA验证函数
|
||||||
|
func validateTwoFactorAuth(twoFA *model.TwoFA, code string) bool {
|
||||||
|
// 尝试验证TOTP
|
||||||
|
if cleanCode, err := common.ValidateNumericCode(code); err == nil {
|
||||||
|
if isValid, _ := twoFA.ValidateTOTPAndUpdateUsage(cleanCode); isValid {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试验证备用码
|
||||||
|
if isValid, err := twoFA.ValidateBackupCodeAndUpdateUsage(code); err == nil && isValid {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// validateChannel 通用的渠道校验函数
|
// validateChannel 通用的渠道校验函数
|
||||||
func validateChannel(channel *model.Channel, isAdd bool) error {
|
func validateChannel(channel *model.Channel, isAdd bool) error {
|
||||||
// 校验 channel settings
|
// 校验 channel settings
|
||||||
@@ -669,6 +778,7 @@ func DeleteChannelBatch(c *gin.Context) {
|
|||||||
type PatchChannel struct {
|
type PatchChannel struct {
|
||||||
model.Channel
|
model.Channel
|
||||||
MultiKeyMode *string `json:"multi_key_mode"`
|
MultiKeyMode *string `json:"multi_key_mode"`
|
||||||
|
KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateChannel(c *gin.Context) {
|
func UpdateChannel(c *gin.Context) {
|
||||||
@@ -688,7 +798,7 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
|
// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
|
||||||
originChannel, err := model.GetChannelById(channel.Id, false)
|
originChannel, err := model.GetChannelById(channel.Id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -704,6 +814,69 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
|
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
|
||||||
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
|
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 处理多key模式下的密钥追加/覆盖逻辑
|
||||||
|
if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
|
||||||
|
switch *channel.KeyMode {
|
||||||
|
case "append":
|
||||||
|
// 追加模式:将新密钥添加到现有密钥列表
|
||||||
|
if originChannel.Key != "" {
|
||||||
|
var newKeys []string
|
||||||
|
var existingKeys []string
|
||||||
|
|
||||||
|
// 解析现有密钥
|
||||||
|
if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
|
||||||
|
// JSON数组格式
|
||||||
|
var arr []json.RawMessage
|
||||||
|
if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
|
||||||
|
existingKeys = make([]string, len(arr))
|
||||||
|
for i, v := range arr {
|
||||||
|
existingKeys[i] = string(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 换行分隔格式
|
||||||
|
existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理 Vertex AI 的特殊情况
|
||||||
|
if channel.Type == constant.ChannelTypeVertexAi {
|
||||||
|
// 尝试解析新密钥为JSON数组
|
||||||
|
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
|
||||||
|
array, err := getVertexArrayKeys(channel.Key)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "追加密钥解析失败: " + err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
newKeys = array
|
||||||
|
} else {
|
||||||
|
// 单个JSON密钥
|
||||||
|
newKeys = []string{channel.Key}
|
||||||
|
}
|
||||||
|
// 合并密钥
|
||||||
|
allKeys := append(existingKeys, newKeys...)
|
||||||
|
channel.Key = strings.Join(allKeys, "\n")
|
||||||
|
} else {
|
||||||
|
// 普通渠道的处理
|
||||||
|
inputKeys := strings.Split(channel.Key, "\n")
|
||||||
|
for _, key := range inputKeys {
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if key != "" {
|
||||||
|
newKeys = append(newKeys, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 合并密钥
|
||||||
|
allKeys := append(existingKeys, newKeys...)
|
||||||
|
channel.Key = strings.Join(allKeys, "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "replace":
|
||||||
|
// 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
|
||||||
|
}
|
||||||
|
}
|
||||||
err = channel.Update()
|
err = channel.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
@@ -711,6 +884,7 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
model.InitChannelCache()
|
model.InitChannelCache()
|
||||||
channel.Key = ""
|
channel.Key = ""
|
||||||
|
clearChannelInfo(&channel.Channel)
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
@@ -914,3 +1088,413 @@ func CopyChannel(c *gin.Context) {
|
|||||||
// success
|
// success
|
||||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
|
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MultiKeyManageRequest represents the request for multi-key management operations
|
||||||
|
type MultiKeyManageRequest struct {
|
||||||
|
ChannelId int `json:"channel_id"`
|
||||||
|
Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status"
|
||||||
|
KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions
|
||||||
|
Page int `json:"page,omitempty"` // for get_key_status pagination
|
||||||
|
PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
|
||||||
|
Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultiKeyStatusResponse represents the response for key status query
|
||||||
|
type MultiKeyStatusResponse struct {
|
||||||
|
Keys []KeyStatus `json:"keys"`
|
||||||
|
Total int `json:"total"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
TotalPages int `json:"total_pages"`
|
||||||
|
// Statistics
|
||||||
|
EnabledCount int `json:"enabled_count"`
|
||||||
|
ManualDisabledCount int `json:"manual_disabled_count"`
|
||||||
|
AutoDisabledCount int `json:"auto_disabled_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyStatus struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Status int `json:"status"` // 1: enabled, 2: disabled
|
||||||
|
DisabledTime int64 `json:"disabled_time,omitempty"`
|
||||||
|
Reason string `json:"reason,omitempty"`
|
||||||
|
KeyPreview string `json:"key_preview"` // first 10 chars of key for identification
|
||||||
|
}
|
||||||
|
|
||||||
|
// ManageMultiKeys handles multi-key management operations
|
||||||
|
func ManageMultiKeys(c *gin.Context) {
|
||||||
|
request := MultiKeyManageRequest{}
|
||||||
|
err := c.ShouldBindJSON(&request)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
channel, err := model.GetChannelById(request.ChannelId, true)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "渠道不存在",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !channel.ChannelInfo.IsMultiKey {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "该渠道不是多密钥模式",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lock := model.GetChannelPollingLock(channel.Id)
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
switch request.Action {
|
||||||
|
case "get_key_status":
|
||||||
|
keys := channel.GetKeys()
|
||||||
|
|
||||||
|
// Default pagination parameters
|
||||||
|
page := request.Page
|
||||||
|
pageSize := request.PageSize
|
||||||
|
if page <= 0 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if pageSize <= 0 {
|
||||||
|
pageSize = 50 // Default page size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Statistics for all keys (unchanged by filtering)
|
||||||
|
var enabledCount, manualDisabledCount, autoDisabledCount int
|
||||||
|
|
||||||
|
// Build all key status data first
|
||||||
|
var allKeyStatusList []KeyStatus
|
||||||
|
for i, key := range keys {
|
||||||
|
status := 1 // default enabled
|
||||||
|
var disabledTime int64
|
||||||
|
var reason string
|
||||||
|
|
||||||
|
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||||
|
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||||||
|
status = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count for statistics (all keys)
|
||||||
|
switch status {
|
||||||
|
case 1:
|
||||||
|
enabledCount++
|
||||||
|
case 2:
|
||||||
|
manualDisabledCount++
|
||||||
|
case 3:
|
||||||
|
autoDisabledCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
if status != 1 {
|
||||||
|
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||||||
|
disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
|
||||||
|
}
|
||||||
|
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||||||
|
reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create key preview (first 10 chars)
|
||||||
|
keyPreview := key
|
||||||
|
if len(key) > 10 {
|
||||||
|
keyPreview = key[:10] + "..."
|
||||||
|
}
|
||||||
|
|
||||||
|
allKeyStatusList = append(allKeyStatusList, KeyStatus{
|
||||||
|
Index: i,
|
||||||
|
Status: status,
|
||||||
|
DisabledTime: disabledTime,
|
||||||
|
Reason: reason,
|
||||||
|
KeyPreview: keyPreview,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply status filter if specified
|
||||||
|
var filteredKeyStatusList []KeyStatus
|
||||||
|
if request.Status != nil {
|
||||||
|
for _, keyStatus := range allKeyStatusList {
|
||||||
|
if keyStatus.Status == *request.Status {
|
||||||
|
filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
filteredKeyStatusList = allKeyStatusList
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate pagination based on filtered results
|
||||||
|
filteredTotal := len(filteredKeyStatusList)
|
||||||
|
totalPages := (filteredTotal + pageSize - 1) / pageSize
|
||||||
|
if totalPages == 0 {
|
||||||
|
totalPages = 1
|
||||||
|
}
|
||||||
|
if page > totalPages {
|
||||||
|
page = totalPages
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate range for current page
|
||||||
|
start := (page - 1) * pageSize
|
||||||
|
end := start + pageSize
|
||||||
|
if end > filteredTotal {
|
||||||
|
end = filteredTotal
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the page data
|
||||||
|
var pageKeyStatusList []KeyStatus
|
||||||
|
if start < filteredTotal {
|
||||||
|
pageKeyStatusList = filteredKeyStatusList[start:end]
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": MultiKeyStatusResponse{
|
||||||
|
Keys: pageKeyStatusList,
|
||||||
|
Total: filteredTotal, // Total of filtered results
|
||||||
|
Page: page,
|
||||||
|
PageSize: pageSize,
|
||||||
|
TotalPages: totalPages,
|
||||||
|
EnabledCount: enabledCount, // Overall statistics
|
||||||
|
ManualDisabledCount: manualDisabledCount, // Overall statistics
|
||||||
|
AutoDisabledCount: autoDisabledCount, // Overall statistics
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
case "disable_key":
|
||||||
|
if request.KeyIndex == nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "未指定要禁用的密钥索引",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
keyIndex := *request.KeyIndex
|
||||||
|
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "密钥索引超出范围",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||||||
|
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||||
|
}
|
||||||
|
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||||||
|
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||||
|
}
|
||||||
|
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||||||
|
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled
|
||||||
|
|
||||||
|
err = channel.Update()
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
model.InitChannelCache()
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "密钥已禁用",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
case "enable_key":
|
||||||
|
if request.KeyIndex == nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "未指定要启用的密钥索引",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
keyIndex := *request.KeyIndex
|
||||||
|
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "密钥索引超出范围",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从状态列表中删除该密钥的记录,使其回到默认启用状态
|
||||||
|
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||||
|
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
||||||
|
}
|
||||||
|
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||||||
|
delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
|
||||||
|
}
|
||||||
|
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||||||
|
delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = channel.Update()
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
model.InitChannelCache()
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "密钥已启用",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
case "enable_all_keys":
|
||||||
|
// 清空所有禁用状态,使所有密钥回到默认启用状态
|
||||||
|
var enabledCount int
|
||||||
|
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||||
|
enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
|
||||||
|
}
|
||||||
|
|
||||||
|
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||||
|
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||||
|
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||||
|
|
||||||
|
err = channel.Update()
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
model.InitChannelCache()
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
case "disable_all_keys":
|
||||||
|
// 禁用所有启用的密钥
|
||||||
|
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||||||
|
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||||
|
}
|
||||||
|
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||||||
|
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||||
|
}
|
||||||
|
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||||||
|
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
var disabledCount int
|
||||||
|
for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
|
||||||
|
status := 1 // default enabled
|
||||||
|
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||||||
|
status = s
|
||||||
|
}
|
||||||
|
|
||||||
|
// 只禁用当前启用的密钥
|
||||||
|
if status == 1 {
|
||||||
|
channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled
|
||||||
|
disabledCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if disabledCount == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "没有可禁用的密钥",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = channel.Update()
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
model.InitChannelCache()
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
case "delete_disabled_keys":
|
||||||
|
keys := channel.GetKeys()
|
||||||
|
var remainingKeys []string
|
||||||
|
var deletedCount int
|
||||||
|
var newStatusList = make(map[int]int)
|
||||||
|
var newDisabledTime = make(map[int]int64)
|
||||||
|
var newDisabledReason = make(map[int]string)
|
||||||
|
|
||||||
|
newIndex := 0
|
||||||
|
for i, key := range keys {
|
||||||
|
status := 1 // default enabled
|
||||||
|
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||||
|
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||||||
|
status = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥
|
||||||
|
if status == 3 {
|
||||||
|
deletedCount++
|
||||||
|
} else {
|
||||||
|
remainingKeys = append(remainingKeys, key)
|
||||||
|
// 保留非自动禁用密钥的状态信息,重新索引
|
||||||
|
if status != 1 {
|
||||||
|
newStatusList[newIndex] = status
|
||||||
|
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||||||
|
if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
|
||||||
|
newDisabledTime[newIndex] = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||||||
|
if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
|
||||||
|
newDisabledReason[newIndex] = r
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newIndex++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if deletedCount == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "没有需要删除的自动禁用密钥",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update channel with remaining keys
|
||||||
|
channel.Key = strings.Join(remainingKeys, "\n")
|
||||||
|
channel.ChannelInfo.MultiKeySize = len(remainingKeys)
|
||||||
|
channel.ChannelInfo.MultiKeyStatusList = newStatusList
|
||||||
|
channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
|
||||||
|
channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
|
||||||
|
|
||||||
|
err = channel.Update()
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
model.InitChannelCache()
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
|
||||||
|
"data": deletedCount,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
default:
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "不支持的操作",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,101 +3,102 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
|
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
|
||||||
func MigrateConsoleSetting(c *gin.Context) {
|
func MigrateConsoleSetting(c *gin.Context) {
|
||||||
// 读取全部 option
|
// 读取全部 option
|
||||||
opts, err := model.AllOption()
|
opts, err := model.AllOption()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 建立 map
|
// 建立 map
|
||||||
valMap := map[string]string{}
|
valMap := map[string]string{}
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
valMap[o.Key] = o.Value
|
valMap[o.Key] = o.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理 APIInfo
|
// 处理 APIInfo
|
||||||
if v := valMap["ApiInfo"]; v != "" {
|
if v := valMap["ApiInfo"]; v != "" {
|
||||||
var arr []map[string]interface{}
|
var arr []map[string]interface{}
|
||||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||||
if len(arr) > 50 {
|
if len(arr) > 50 {
|
||||||
arr = arr[:50]
|
arr = arr[:50]
|
||||||
}
|
}
|
||||||
bytes, _ := json.Marshal(arr)
|
bytes, _ := json.Marshal(arr)
|
||||||
model.UpdateOption("console_setting.api_info", string(bytes))
|
model.UpdateOption("console_setting.api_info", string(bytes))
|
||||||
}
|
}
|
||||||
model.UpdateOption("ApiInfo", "")
|
model.UpdateOption("ApiInfo", "")
|
||||||
}
|
}
|
||||||
// Announcements 直接搬
|
// Announcements 直接搬
|
||||||
if v := valMap["Announcements"]; v != "" {
|
if v := valMap["Announcements"]; v != "" {
|
||||||
model.UpdateOption("console_setting.announcements", v)
|
model.UpdateOption("console_setting.announcements", v)
|
||||||
model.UpdateOption("Announcements", "")
|
model.UpdateOption("Announcements", "")
|
||||||
}
|
}
|
||||||
// FAQ 转换
|
// FAQ 转换
|
||||||
if v := valMap["FAQ"]; v != "" {
|
if v := valMap["FAQ"]; v != "" {
|
||||||
var arr []map[string]interface{}
|
var arr []map[string]interface{}
|
||||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||||
out := []map[string]interface{}{}
|
out := []map[string]interface{}{}
|
||||||
for _, item := range arr {
|
for _, item := range arr {
|
||||||
q, _ := item["question"].(string)
|
q, _ := item["question"].(string)
|
||||||
if q == "" {
|
if q == "" {
|
||||||
q, _ = item["title"].(string)
|
q, _ = item["title"].(string)
|
||||||
}
|
}
|
||||||
a, _ := item["answer"].(string)
|
a, _ := item["answer"].(string)
|
||||||
if a == "" {
|
if a == "" {
|
||||||
a, _ = item["content"].(string)
|
a, _ = item["content"].(string)
|
||||||
}
|
}
|
||||||
if q != "" && a != "" {
|
if q != "" && a != "" {
|
||||||
out = append(out, map[string]interface{}{"question": q, "answer": a})
|
out = append(out, map[string]interface{}{"question": q, "answer": a})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(out) > 50 {
|
if len(out) > 50 {
|
||||||
out = out[:50]
|
out = out[:50]
|
||||||
}
|
}
|
||||||
bytes, _ := json.Marshal(out)
|
bytes, _ := json.Marshal(out)
|
||||||
model.UpdateOption("console_setting.faq", string(bytes))
|
model.UpdateOption("console_setting.faq", string(bytes))
|
||||||
}
|
}
|
||||||
model.UpdateOption("FAQ", "")
|
model.UpdateOption("FAQ", "")
|
||||||
}
|
}
|
||||||
// Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
|
// Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
|
||||||
url := valMap["UptimeKumaUrl"]
|
url := valMap["UptimeKumaUrl"]
|
||||||
slug := valMap["UptimeKumaSlug"]
|
slug := valMap["UptimeKumaSlug"]
|
||||||
if url != "" && slug != "" {
|
if url != "" && slug != "" {
|
||||||
// 仅当同时存在 URL 与 Slug 时才进行迁移
|
// 仅当同时存在 URL 与 Slug 时才进行迁移
|
||||||
groups := []map[string]interface{}{
|
groups := []map[string]interface{}{
|
||||||
{
|
{
|
||||||
"id": 1,
|
"id": 1,
|
||||||
"categoryName": "old",
|
"categoryName": "old",
|
||||||
"url": url,
|
"url": url,
|
||||||
"slug": slug,
|
"slug": slug,
|
||||||
"description": "",
|
"description": "",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
bytes, _ := json.Marshal(groups)
|
bytes, _ := json.Marshal(groups)
|
||||||
model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
|
model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
|
||||||
}
|
}
|
||||||
// 清空旧键内容
|
// 清空旧键内容
|
||||||
if url != "" {
|
if url != "" {
|
||||||
model.UpdateOption("UptimeKumaUrl", "")
|
model.UpdateOption("UptimeKumaUrl", "")
|
||||||
}
|
}
|
||||||
if slug != "" {
|
if slug != "" {
|
||||||
model.UpdateOption("UptimeKumaSlug", "")
|
model.UpdateOption("UptimeKumaSlug", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 删除旧键记录
|
// 删除旧键记录
|
||||||
oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
|
oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
|
||||||
model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
|
model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
|
||||||
|
|
||||||
// 重新加载 OptionMap
|
// 重新加载 OptionMap
|
||||||
model.InitOptionMap()
|
model.InitOptionMap()
|
||||||
common.SysLog("console setting migrated")
|
common.SysLog("console setting migrated")
|
||||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
|
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -220,21 +220,29 @@ func LinuxdoOAuth(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if common.RegisterEnabled {
|
if common.RegisterEnabled {
|
||||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
|
||||||
user.DisplayName = linuxdoUser.Name
|
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||||
user.Role = common.RoleCommonUser
|
user.DisplayName = linuxdoUser.Name
|
||||||
user.Status = common.UserStatusEnabled
|
user.Role = common.RoleCommonUser
|
||||||
|
user.Status = common.UserStatusEnabled
|
||||||
|
|
||||||
affCode := session.Get("aff")
|
affCode := session.Get("aff")
|
||||||
inviterId := 0
|
inviterId := 0
|
||||||
if affCode != nil {
|
if affCode != nil {
|
||||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := user.Insert(inviterId); err != nil {
|
if err := user.Insert(inviterId); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
@@ -28,7 +29,7 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||||
taskChannelM := make(map[int][]string)
|
taskChannelM := make(map[int][]string)
|
||||||
taskM := make(map[string]*model.Midjourney)
|
taskM := make(map[string]*model.Midjourney)
|
||||||
nullTaskIds := make([]int, 0)
|
nullTaskIds := make([]int, 0)
|
||||||
@@ -47,9 +48,9 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
|
||||||
} else {
|
} else {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
|
logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(taskChannelM) == 0 {
|
if len(taskChannelM) == 0 {
|
||||||
@@ -57,20 +58,20 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for channelId, taskIds := range taskChannelM {
|
for channelId, taskIds := range taskChannelM {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||||
if len(taskIds) == 0 {
|
if len(taskIds) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
midjourneyChannel, err := model.CacheGetChannel(channelId)
|
midjourneyChannel, err := model.CacheGetChannel(channelId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
|
||||||
err := model.MjBulkUpdate(taskIds, map[string]any{
|
err := model.MjBulkUpdate(taskIds, map[string]any{
|
||||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||||
"status": "FAILURE",
|
"status": "FAILURE",
|
||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -81,7 +82,7 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
})
|
})
|
||||||
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
|
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 设置超时时间
|
// 设置超时时间
|
||||||
@@ -93,22 +94,22 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
||||||
resp, err := service.GetHttpClient().Do(req)
|
resp, err := service.GetHttpClient().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var responseItems []dto.MidjourneyDto
|
var responseItems []dto.MidjourneyDto
|
||||||
err = json.Unmarshal(responseBody, &responseItems)
|
err = json.Unmarshal(responseBody, &responseItems)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
@@ -145,9 +146,25 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
buttonStr, _ := json.Marshal(responseItem.Buttons)
|
buttonStr, _ := json.Marshal(responseItem.Buttons)
|
||||||
task.Buttons = string(buttonStr)
|
task.Buttons = string(buttonStr)
|
||||||
}
|
}
|
||||||
|
// 映射 VideoUrl
|
||||||
|
task.VideoUrl = responseItem.VideoUrl
|
||||||
|
|
||||||
|
// 映射 VideoUrls - 将数组序列化为 JSON 字符串
|
||||||
|
if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 {
|
||||||
|
videoUrlsStr, err := json.Marshal(responseItem.VideoUrls)
|
||||||
|
if err != nil {
|
||||||
|
logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err))
|
||||||
|
task.VideoUrls = "[]" // 失败时设置为空数组
|
||||||
|
} else {
|
||||||
|
task.VideoUrls = string(videoUrlsStr)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
task.VideoUrls = "" // 空值时清空字段
|
||||||
|
}
|
||||||
|
|
||||||
shouldReturnQuota := false
|
shouldReturnQuota := false
|
||||||
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
||||||
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||||
task.Progress = "100%"
|
task.Progress = "100%"
|
||||||
if task.Quota != 0 {
|
if task.Quota != 0 {
|
||||||
shouldReturnQuota = true
|
shouldReturnQuota = true
|
||||||
@@ -155,14 +172,14 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
}
|
}
|
||||||
err = task.Update()
|
err = task.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||||
} else {
|
} else {
|
||||||
if shouldReturnQuota {
|
if shouldReturnQuota {
|
||||||
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
|
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||||
}
|
}
|
||||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
|
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
|
||||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -208,6 +225,20 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
|
|||||||
if oldTask.Progress != "100%" && newTask.FailReason != "" {
|
if oldTask.Progress != "100%" && newTask.FailReason != "" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
// 检查 VideoUrl 是否需要更新
|
||||||
|
if oldTask.VideoUrl != newTask.VideoUrl {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// 检查 VideoUrls 是否需要更新
|
||||||
|
if newTask.VideoUrls != nil && len(newTask.VideoUrls) > 0 {
|
||||||
|
newVideoUrlsStr, _ := json.Marshal(newTask.VideoUrls)
|
||||||
|
if oldTask.VideoUrls != string(newVideoUrlsStr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
} else if oldTask.VideoUrls != "" {
|
||||||
|
// 如果新数据没有 VideoUrls 但旧数据有,需要更新(清空)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,48 +39,51 @@ func TestStatus(c *gin.Context) {
|
|||||||
func GetStatus(c *gin.Context) {
|
func GetStatus(c *gin.Context) {
|
||||||
|
|
||||||
cs := console_setting.GetConsoleSetting()
|
cs := console_setting.GetConsoleSetting()
|
||||||
|
common.OptionMapRWMutex.RLock()
|
||||||
|
defer common.OptionMapRWMutex.RUnlock()
|
||||||
|
|
||||||
data := gin.H{
|
data := gin.H{
|
||||||
"version": common.Version,
|
"version": common.Version,
|
||||||
"start_time": common.StartTime,
|
"start_time": common.StartTime,
|
||||||
"email_verification": common.EmailVerificationEnabled,
|
"email_verification": common.EmailVerificationEnabled,
|
||||||
"github_oauth": common.GitHubOAuthEnabled,
|
"github_oauth": common.GitHubOAuthEnabled,
|
||||||
"github_client_id": common.GitHubClientId,
|
"github_client_id": common.GitHubClientId,
|
||||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||||
"linuxdo_client_id": common.LinuxDOClientId,
|
"linuxdo_client_id": common.LinuxDOClientId,
|
||||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
"linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,
|
||||||
"telegram_bot_name": common.TelegramBotName,
|
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||||
"system_name": common.SystemName,
|
"telegram_bot_name": common.TelegramBotName,
|
||||||
"logo": common.Logo,
|
"system_name": common.SystemName,
|
||||||
"footer_html": common.Footer,
|
"logo": common.Logo,
|
||||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
"footer_html": common.Footer,
|
||||||
"wechat_login": common.WeChatAuthEnabled,
|
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||||
"server_address": setting.ServerAddress,
|
"wechat_login": common.WeChatAuthEnabled,
|
||||||
"price": setting.Price,
|
"server_address": setting.ServerAddress,
|
||||||
"stripe_unit_price": setting.StripeUnitPrice,
|
"price": setting.Price,
|
||||||
"min_topup": setting.MinTopUp,
|
"stripe_unit_price": setting.StripeUnitPrice,
|
||||||
"stripe_min_topup": setting.StripeMinTopUp,
|
"min_topup": setting.MinTopUp,
|
||||||
"turnstile_check": common.TurnstileCheckEnabled,
|
"stripe_min_topup": setting.StripeMinTopUp,
|
||||||
"turnstile_site_key": common.TurnstileSiteKey,
|
"turnstile_check": common.TurnstileCheckEnabled,
|
||||||
"top_up_link": common.TopUpLink,
|
"turnstile_site_key": common.TurnstileSiteKey,
|
||||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
"top_up_link": common.TopUpLink,
|
||||||
"quota_per_unit": common.QuotaPerUnit,
|
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
"quota_per_unit": common.QuotaPerUnit,
|
||||||
"enable_batch_update": common.BatchUpdateEnabled,
|
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||||
"enable_drawing": common.DrawingEnabled,
|
"enable_batch_update": common.BatchUpdateEnabled,
|
||||||
"enable_task": common.TaskEnabled,
|
"enable_drawing": common.DrawingEnabled,
|
||||||
"enable_data_export": common.DataExportEnabled,
|
"enable_task": common.TaskEnabled,
|
||||||
"data_export_default_time": common.DataExportDefaultTime,
|
"enable_data_export": common.DataExportEnabled,
|
||||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
"data_export_default_time": common.DataExportDefaultTime,
|
||||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||||
"chats": setting.Chats,
|
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
"chats": setting.Chats,
|
||||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||||
"pay_methods": setting.PayMethods,
|
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||||
"usd_exchange_rate": setting.USDExchangeRate,
|
"pay_methods": setting.PayMethods,
|
||||||
|
"usd_exchange_rate": setting.USDExchangeRate,
|
||||||
|
|
||||||
// 面板启用开关
|
// 面板启用开关
|
||||||
"api_info_enabled": cs.ApiInfoEnabled,
|
"api_info_enabled": cs.ApiInfoEnabled,
|
||||||
@@ -88,6 +91,10 @@ func GetStatus(c *gin.Context) {
|
|||||||
"announcements_enabled": cs.AnnouncementsEnabled,
|
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||||
"faq_enabled": cs.FAQEnabled,
|
"faq_enabled": cs.FAQEnabled,
|
||||||
|
|
||||||
|
// 模块管理配置
|
||||||
|
"HeaderNavModules": common.OptionMap["HeaderNavModules"],
|
||||||
|
"SidebarModulesAdmin": common.OptionMap["SidebarModulesAdmin"],
|
||||||
|
|
||||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||||
|
|||||||
27
controller/missing_models.go
Normal file
27
controller/missing_models.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetMissingModels returns the list of model names that are referenced by channels
|
||||||
|
// but do not have corresponding records in the models meta table.
|
||||||
|
// This helps administrators quickly discover models that need configuration.
|
||||||
|
func GetMissingModels(c *gin.Context) {
|
||||||
|
missing, err := model.GetMissingModels()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": missing,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"one-api/relay/channel/moonshot"
|
"one-api/relay/channel/moonshot"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/models/list
|
// https://platform.openai.com/docs/api-reference/models/list
|
||||||
@@ -92,7 +93,9 @@ func init() {
|
|||||||
if !success || apiType == constant.APITypeAIProxyLibrary {
|
if !success || apiType == constant.APITypeAIProxyLibrary {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
meta := &relaycommon.RelayInfo{ChannelType: i}
|
meta := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
ChannelType: i,
|
||||||
|
}}
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
adaptor.Init(meta)
|
adaptor.Init(meta)
|
||||||
channelId2Models[i] = adaptor.GetModelList()
|
channelId2Models[i] = adaptor.GetModelList()
|
||||||
@@ -102,7 +105,7 @@ func init() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListModels(c *gin.Context) {
|
func ListModels(c *gin.Context, modelType int) {
|
||||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||||
|
|
||||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||||
@@ -171,11 +174,42 @@ func ListModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
switch modelType {
|
||||||
"success": true,
|
case constant.ChannelTypeAnthropic:
|
||||||
"data": userOpenAiModels,
|
useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
|
||||||
"object": "list",
|
for i, model := range userOpenAiModels {
|
||||||
})
|
useranthropicModels[i] = dto.AnthropicModel{
|
||||||
|
ID: model.Id,
|
||||||
|
CreatedAt: time.Unix(int64(model.Created), 0).UTC().Format(time.RFC3339),
|
||||||
|
DisplayName: model.Id,
|
||||||
|
Type: "model",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"data": useranthropicModels,
|
||||||
|
"first_id": useranthropicModels[0].ID,
|
||||||
|
"has_more": false,
|
||||||
|
"last_id": useranthropicModels[len(useranthropicModels)-1].ID,
|
||||||
|
})
|
||||||
|
case constant.ChannelTypeGemini:
|
||||||
|
userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels))
|
||||||
|
for i, model := range userOpenAiModels {
|
||||||
|
userGeminiModels[i] = dto.GeminiModel{
|
||||||
|
Name: model.Id,
|
||||||
|
DisplayName: model.Id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"models": userGeminiModels,
|
||||||
|
"nextPageToken": nil,
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": userOpenAiModels,
|
||||||
|
"object": "list",
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ChannelListModels(c *gin.Context) {
|
func ChannelListModels(c *gin.Context) {
|
||||||
@@ -199,10 +233,20 @@ func EnabledListModels(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func RetrieveModel(c *gin.Context) {
|
func RetrieveModel(c *gin.Context, modelType int) {
|
||||||
modelId := c.Param("model")
|
modelId := c.Param("model")
|
||||||
if aiModel, ok := openAIModelsMap[modelId]; ok {
|
if aiModel, ok := openAIModelsMap[modelId]; ok {
|
||||||
c.JSON(200, aiModel)
|
switch modelType {
|
||||||
|
case constant.ChannelTypeAnthropic:
|
||||||
|
c.JSON(200, dto.AnthropicModel{
|
||||||
|
ID: aiModel.Id,
|
||||||
|
CreatedAt: time.Unix(int64(aiModel.Created), 0).UTC().Format(time.RFC3339),
|
||||||
|
DisplayName: aiModel.Id,
|
||||||
|
Type: "model",
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
c.JSON(200, aiModel)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
openAIError := dto.OpenAIError{
|
openAIError := dto.OpenAIError{
|
||||||
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
||||||
|
|||||||
330
controller/model_meta.go
Normal file
330
controller/model_meta.go
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetAllModelsMeta 获取模型列表(分页)
|
||||||
|
func GetAllModelsMeta(c *gin.Context) {
|
||||||
|
|
||||||
|
pageInfo := common.GetPageQuery(c)
|
||||||
|
modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 批量填充附加字段,提升列表接口性能
|
||||||
|
enrichModels(modelsMeta)
|
||||||
|
var total int64
|
||||||
|
model.DB.Model(&model.Model{}).Count(&total)
|
||||||
|
|
||||||
|
// 统计供应商计数(全部数据,不受分页影响)
|
||||||
|
vendorCounts, _ := model.GetVendorModelCounts()
|
||||||
|
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
|
pageInfo.SetItems(modelsMeta)
|
||||||
|
common.ApiSuccess(c, gin.H{
|
||||||
|
"items": modelsMeta,
|
||||||
|
"total": total,
|
||||||
|
"page": pageInfo.GetPage(),
|
||||||
|
"page_size": pageInfo.GetPageSize(),
|
||||||
|
"vendor_counts": vendorCounts,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchModelsMeta 搜索模型列表
|
||||||
|
func SearchModelsMeta(c *gin.Context) {
|
||||||
|
|
||||||
|
keyword := c.Query("keyword")
|
||||||
|
vendor := c.Query("vendor")
|
||||||
|
pageInfo := common.GetPageQuery(c)
|
||||||
|
|
||||||
|
modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 批量填充附加字段,提升列表接口性能
|
||||||
|
enrichModels(modelsMeta)
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
|
pageInfo.SetItems(modelsMeta)
|
||||||
|
common.ApiSuccess(c, pageInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelMeta 根据 ID 获取单条模型信息
|
||||||
|
func GetModelMeta(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var m model.Model
|
||||||
|
if err := model.DB.First(&m, id).Error; err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
enrichModels([]*model.Model{&m})
|
||||||
|
common.ApiSuccess(c, &m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateModelMeta 新建模型
|
||||||
|
func CreateModelMeta(c *gin.Context) {
|
||||||
|
var m model.Model
|
||||||
|
if err := c.ShouldBindJSON(&m); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m.ModelName == "" {
|
||||||
|
common.ApiErrorMsg(c, "模型名称不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 名称冲突检查
|
||||||
|
if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "模型名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Insert(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model.RefreshPricing()
|
||||||
|
common.ApiSuccess(c, &m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateModelMeta 更新模型
|
||||||
|
func UpdateModelMeta(c *gin.Context) {
|
||||||
|
statusOnly := c.Query("status_only") == "true"
|
||||||
|
|
||||||
|
var m model.Model
|
||||||
|
if err := c.ShouldBindJSON(&m); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m.Id == 0 {
|
||||||
|
common.ApiErrorMsg(c, "缺少模型 ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if statusOnly {
|
||||||
|
// 只更新状态,防止误清空其他字段
|
||||||
|
if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 名称冲突检查
|
||||||
|
if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "模型名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Update(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
model.RefreshPricing()
|
||||||
|
common.ApiSuccess(c, &m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteModelMeta 删除模型
|
||||||
|
func DeleteModelMeta(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model.RefreshPricing()
|
||||||
|
common.ApiSuccess(c, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询
|
||||||
|
func enrichModels(models []*model.Model) {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1) 拆分精确与规则匹配
|
||||||
|
exactNames := make([]string, 0)
|
||||||
|
exactIdx := make(map[string][]int) // modelName -> indices in models
|
||||||
|
ruleIndices := make([]int, 0)
|
||||||
|
for i, m := range models {
|
||||||
|
if m == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if m.NameRule == model.NameRuleExact {
|
||||||
|
exactNames = append(exactNames, m.ModelName)
|
||||||
|
exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
|
||||||
|
} else {
|
||||||
|
ruleIndices = append(ruleIndices, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 批量查询精确模型的绑定渠道
|
||||||
|
channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)
|
||||||
|
|
||||||
|
// 3) 精确模型:端点从缓存、渠道批量映射、分组/计费类型从缓存
|
||||||
|
for name, indices := range exactIdx {
|
||||||
|
chs := channelsByModel[name]
|
||||||
|
for _, idx := range indices {
|
||||||
|
mm := models[idx]
|
||||||
|
if mm.Endpoints == "" {
|
||||||
|
eps := model.GetModelSupportEndpointTypes(mm.ModelName)
|
||||||
|
if b, err := json.Marshal(eps); err == nil {
|
||||||
|
mm.Endpoints = string(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mm.BoundChannels = chs
|
||||||
|
mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
|
||||||
|
mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ruleIndices) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 一次性读取定价缓存,内存匹配所有规则模型
|
||||||
|
pricings := model.GetPricing()
|
||||||
|
|
||||||
|
// 为全部规则模型收集匹配名集合、端点并集、分组并集、配额集合
|
||||||
|
matchedNamesByIdx := make(map[int][]string)
|
||||||
|
endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
|
||||||
|
groupSetByIdx := make(map[int]map[string]struct{})
|
||||||
|
quotaSetByIdx := make(map[int]map[int]struct{})
|
||||||
|
|
||||||
|
for _, p := range pricings {
|
||||||
|
for _, idx := range ruleIndices {
|
||||||
|
mm := models[idx]
|
||||||
|
var matched bool
|
||||||
|
switch mm.NameRule {
|
||||||
|
case model.NameRulePrefix:
|
||||||
|
matched = strings.HasPrefix(p.ModelName, mm.ModelName)
|
||||||
|
case model.NameRuleSuffix:
|
||||||
|
matched = strings.HasSuffix(p.ModelName, mm.ModelName)
|
||||||
|
case model.NameRuleContains:
|
||||||
|
matched = strings.Contains(p.ModelName, mm.ModelName)
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)
|
||||||
|
|
||||||
|
es := endpointSetByIdx[idx]
|
||||||
|
if es == nil {
|
||||||
|
es = make(map[constant.EndpointType]struct{})
|
||||||
|
endpointSetByIdx[idx] = es
|
||||||
|
}
|
||||||
|
for _, et := range p.SupportedEndpointTypes {
|
||||||
|
es[et] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
gs := groupSetByIdx[idx]
|
||||||
|
if gs == nil {
|
||||||
|
gs = make(map[string]struct{})
|
||||||
|
groupSetByIdx[idx] = gs
|
||||||
|
}
|
||||||
|
for _, g := range p.EnableGroup {
|
||||||
|
gs[g] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
qs := quotaSetByIdx[idx]
|
||||||
|
if qs == nil {
|
||||||
|
qs = make(map[int]struct{})
|
||||||
|
quotaSetByIdx[idx] = qs
|
||||||
|
}
|
||||||
|
qs[p.QuotaType] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5) 汇总所有匹配到的模型名称,批量查询一次渠道
|
||||||
|
allMatchedSet := make(map[string]struct{})
|
||||||
|
for _, names := range matchedNamesByIdx {
|
||||||
|
for _, n := range names {
|
||||||
|
allMatchedSet[n] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
allMatched := make([]string, 0, len(allMatchedSet))
|
||||||
|
for n := range allMatchedSet {
|
||||||
|
allMatched = append(allMatched, n)
|
||||||
|
}
|
||||||
|
matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)
|
||||||
|
|
||||||
|
// 6) 回填每个规则模型的并集信息
|
||||||
|
for _, idx := range ruleIndices {
|
||||||
|
mm := models[idx]
|
||||||
|
|
||||||
|
// 端点并集 -> 序列化
|
||||||
|
if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
|
||||||
|
eps := make([]constant.EndpointType, 0, len(es))
|
||||||
|
for et := range es {
|
||||||
|
eps = append(eps, et)
|
||||||
|
}
|
||||||
|
if b, err := json.Marshal(eps); err == nil {
|
||||||
|
mm.Endpoints = string(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 分组并集
|
||||||
|
if gs, ok := groupSetByIdx[idx]; ok {
|
||||||
|
groups := make([]string, 0, len(gs))
|
||||||
|
for g := range gs {
|
||||||
|
groups = append(groups, g)
|
||||||
|
}
|
||||||
|
mm.EnableGroups = groups
|
||||||
|
}
|
||||||
|
|
||||||
|
// 配额类型集合(保持去重并排序)
|
||||||
|
if qs, ok := quotaSetByIdx[idx]; ok {
|
||||||
|
arr := make([]int, 0, len(qs))
|
||||||
|
for k := range qs {
|
||||||
|
arr = append(arr, k)
|
||||||
|
}
|
||||||
|
sort.Ints(arr)
|
||||||
|
mm.QuotaTypes = arr
|
||||||
|
}
|
||||||
|
|
||||||
|
// 渠道并集
|
||||||
|
names := matchedNamesByIdx[idx]
|
||||||
|
channelSet := make(map[string]model.BoundChannel)
|
||||||
|
for _, n := range names {
|
||||||
|
for _, ch := range matchedChannelsByModel[n] {
|
||||||
|
key := ch.Name + "_" + strconv.Itoa(ch.Type)
|
||||||
|
channelSet[key] = ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(channelSet) > 0 {
|
||||||
|
chs := make([]model.BoundChannel, 0, len(channelSet))
|
||||||
|
for _, ch := range channelSet {
|
||||||
|
chs = append(chs, ch)
|
||||||
|
}
|
||||||
|
mm.BoundChannels = chs
|
||||||
|
}
|
||||||
|
|
||||||
|
// 匹配信息
|
||||||
|
mm.MatchedModels = names
|
||||||
|
mm.MatchedCount = len(names)
|
||||||
|
}
|
||||||
|
}
|
||||||
463
controller/model_sync.go
Normal file
463
controller/model_sync.go
Normal file
@@ -0,0 +1,463 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 上游地址
|
||||||
|
const (
|
||||||
|
upstreamModelsURL = "https://basellm.github.io/llm-metadata/api/newapi/models.json"
|
||||||
|
upstreamVendorsURL = "https://basellm.github.io/llm-metadata/api/newapi/vendors.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type upstreamEnvelope[T any] struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data []T `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type upstreamModel struct {
|
||||||
|
Description string `json:"description"`
|
||||||
|
Endpoints json.RawMessage `json:"endpoints"`
|
||||||
|
Icon string `json:"icon"`
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
NameRule int `json:"name_rule"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
Tags string `json:"tags"`
|
||||||
|
VendorName string `json:"vendor_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type upstreamVendor struct {
|
||||||
|
Description string `json:"description"`
|
||||||
|
Icon string `json:"icon"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type overwriteField struct {
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
Fields []string `json:"fields"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type syncRequest struct {
|
||||||
|
Overwrite []overwriteField `json:"overwrite"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPClient() *http.Client {
|
||||||
|
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||||
|
transport := &http.Transport{
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
ResponseHeaderTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
host = addr
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(host, "github.io") {
|
||||||
|
if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, "tcp6", addr)
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
return &http.Client{Transport: transport}
|
||||||
|
}
|
||||||
|
|
||||||
|
var httpClient = newHTTPClient()
|
||||||
|
|
||||||
|
func fetchJSON[T any](ctx context.Context, url string, out *upstreamEnvelope[T]) error {
|
||||||
|
var lastErr error
|
||||||
|
for attempt := 0; attempt < 3; attempt++ {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
func() {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
lastErr = errors.New(resp.Status)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
limited := io.LimitReader(resp.Body, 10<<20)
|
||||||
|
if err := json.NewDecoder(limited).Decode(out); err != nil {
|
||||||
|
lastErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !out.Success && len(out.Data) == 0 && out.Message == "" {
|
||||||
|
out.Success = true
|
||||||
|
}
|
||||||
|
lastErr = nil
|
||||||
|
}()
|
||||||
|
if lastErr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
|
||||||
|
}
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureVendorID(vendorName string, vendorByName map[string]upstreamVendor, vendorIDCache map[string]int, createdVendors *int) int {
|
||||||
|
if vendorName == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if id, ok := vendorIDCache[vendorName]; ok {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
var existing model.Vendor
|
||||||
|
if err := model.DB.Where("name = ?", vendorName).First(&existing).Error; err == nil {
|
||||||
|
vendorIDCache[vendorName] = existing.Id
|
||||||
|
return existing.Id
|
||||||
|
}
|
||||||
|
uv := vendorByName[vendorName]
|
||||||
|
v := &model.Vendor{
|
||||||
|
Name: vendorName,
|
||||||
|
Description: uv.Description,
|
||||||
|
Icon: coalesce(uv.Icon, ""),
|
||||||
|
Status: chooseStatus(uv.Status, 1),
|
||||||
|
}
|
||||||
|
if err := v.Insert(); err == nil {
|
||||||
|
*createdVendors++
|
||||||
|
vendorIDCache[vendorName] = v.Id
|
||||||
|
return v.Id
|
||||||
|
}
|
||||||
|
vendorIDCache[vendorName] = 0
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncUpstreamModels 同步上游模型与供应商,仅对「未配置模型」生效
|
||||||
|
func SyncUpstreamModels(c *gin.Context) {
|
||||||
|
var req syncRequest
|
||||||
|
// 允许空体
|
||||||
|
_ = c.ShouldBindJSON(&req)
|
||||||
|
// 1) 获取未配置模型列表
|
||||||
|
missing, err := model.GetMissingModels()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(missing) == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true, "data": gin.H{
|
||||||
|
"created_models": 0,
|
||||||
|
"created_vendors": 0,
|
||||||
|
"skipped_models": []string{},
|
||||||
|
}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 拉取上游 vendors 与 models
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var vendorsEnv upstreamEnvelope[upstreamVendor]
|
||||||
|
_ = fetchJSON(ctx, upstreamVendorsURL, &vendorsEnv) // 若失败不拦截,后续降级
|
||||||
|
|
||||||
|
var modelsEnv upstreamEnvelope[upstreamModel]
|
||||||
|
if err := fetchJSON(ctx, upstreamModelsURL, &modelsEnv); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 建立映射
|
||||||
|
vendorByName := make(map[string]upstreamVendor)
|
||||||
|
for _, v := range vendorsEnv.Data {
|
||||||
|
if v.Name != "" {
|
||||||
|
vendorByName[v.Name] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelByName := make(map[string]upstreamModel)
|
||||||
|
for _, m := range modelsEnv.Data {
|
||||||
|
if m.ModelName != "" {
|
||||||
|
modelByName[m.ModelName] = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) 执行同步:仅创建缺失模型;若上游缺失该模型则跳过
|
||||||
|
createdModels := 0
|
||||||
|
createdVendors := 0
|
||||||
|
updatedModels := 0
|
||||||
|
var skipped []string
|
||||||
|
var createdList []string
|
||||||
|
var updatedList []string
|
||||||
|
|
||||||
|
// 本地缓存:vendorName -> id
|
||||||
|
vendorIDCache := make(map[string]int)
|
||||||
|
|
||||||
|
for _, name := range missing {
|
||||||
|
up, ok := modelByName[name]
|
||||||
|
if !ok {
|
||||||
|
skipped = append(skipped, name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 若本地已存在且设置为不同步,则跳过(极端情况:缺失列表与本地状态不同步时)
|
||||||
|
var existing model.Model
|
||||||
|
if err := model.DB.Where("model_name = ?", name).First(&existing).Error; err == nil {
|
||||||
|
if existing.SyncOfficial == 0 {
|
||||||
|
skipped = append(skipped, name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保 vendor 存在
|
||||||
|
vendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
|
||||||
|
|
||||||
|
// 创建模型
|
||||||
|
mi := &model.Model{
|
||||||
|
ModelName: name,
|
||||||
|
Description: up.Description,
|
||||||
|
Icon: up.Icon,
|
||||||
|
Tags: up.Tags,
|
||||||
|
VendorID: vendorID,
|
||||||
|
Status: chooseStatus(up.Status, 1),
|
||||||
|
NameRule: up.NameRule,
|
||||||
|
}
|
||||||
|
if err := mi.Insert(); err == nil {
|
||||||
|
createdModels++
|
||||||
|
createdList = append(createdList, name)
|
||||||
|
} else {
|
||||||
|
skipped = append(skipped, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 处理可选覆盖(更新本地已有模型的差异字段)
|
||||||
|
if len(req.Overwrite) > 0 {
|
||||||
|
// vendorIDCache 已用于创建阶段,可复用
|
||||||
|
for _, ow := range req.Overwrite {
|
||||||
|
up, ok := modelByName[ow.ModelName]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var local model.Model
|
||||||
|
if err := model.DB.Where("model_name = ?", ow.ModelName).First(&local).Error; err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 跳过被禁用官方同步的模型
|
||||||
|
if local.SyncOfficial == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 映射 vendor
|
||||||
|
newVendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
|
||||||
|
|
||||||
|
// 应用字段覆盖(事务)
|
||||||
|
_ = model.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
needUpdate := false
|
||||||
|
if containsField(ow.Fields, "description") {
|
||||||
|
local.Description = up.Description
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "icon") {
|
||||||
|
local.Icon = up.Icon
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "tags") {
|
||||||
|
local.Tags = up.Tags
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "vendor") {
|
||||||
|
local.VendorID = newVendorID
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "name_rule") {
|
||||||
|
local.NameRule = up.NameRule
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "status") {
|
||||||
|
local.Status = chooseStatus(up.Status, local.Status)
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if !needUpdate {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := tx.Save(&local).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
updatedModels++
|
||||||
|
updatedList = append(updatedList, ow.ModelName)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": gin.H{
|
||||||
|
"created_models": createdModels,
|
||||||
|
"created_vendors": createdVendors,
|
||||||
|
"updated_models": updatedModels,
|
||||||
|
"skipped_models": skipped,
|
||||||
|
"created_list": createdList,
|
||||||
|
"updated_list": updatedList,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsField(fields []string, key string) bool {
|
||||||
|
key = strings.ToLower(strings.TrimSpace(key))
|
||||||
|
for _, f := range fields {
|
||||||
|
if strings.ToLower(strings.TrimSpace(f)) == key {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func coalesce(a, b string) string {
|
||||||
|
if strings.TrimSpace(a) != "" {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func chooseStatus(primary, fallback int) int {
|
||||||
|
if primary == 0 && fallback != 0 {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
if primary != 0 {
|
||||||
|
return primary
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncUpstreamPreview 预览上游与本地的差异(仅用于弹窗选择)
|
||||||
|
func SyncUpstreamPreview(c *gin.Context) {
|
||||||
|
// 1) 拉取上游数据
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var vendorsEnv upstreamEnvelope[upstreamVendor]
|
||||||
|
_ = fetchJSON(ctx, upstreamVendorsURL, &vendorsEnv)
|
||||||
|
|
||||||
|
var modelsEnv upstreamEnvelope[upstreamModel]
|
||||||
|
if err := fetchJSON(ctx, upstreamModelsURL, &modelsEnv); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vendorByName := make(map[string]upstreamVendor)
|
||||||
|
for _, v := range vendorsEnv.Data {
|
||||||
|
if v.Name != "" {
|
||||||
|
vendorByName[v.Name] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelByName := make(map[string]upstreamModel)
|
||||||
|
upstreamNames := make([]string, 0, len(modelsEnv.Data))
|
||||||
|
for _, m := range modelsEnv.Data {
|
||||||
|
if m.ModelName != "" {
|
||||||
|
modelByName[m.ModelName] = m
|
||||||
|
upstreamNames = append(upstreamNames, m.ModelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 本地已有模型
|
||||||
|
var locals []model.Model
|
||||||
|
if len(upstreamNames) > 0 {
|
||||||
|
_ = model.DB.Where("model_name IN ? AND sync_official <> 0", upstreamNames).Find(&locals).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// 本地 vendor 名称映射
|
||||||
|
vendorIdSet := make(map[int]struct{})
|
||||||
|
for _, m := range locals {
|
||||||
|
if m.VendorID != 0 {
|
||||||
|
vendorIdSet[m.VendorID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
vendorIDs := make([]int, 0, len(vendorIdSet))
|
||||||
|
for id := range vendorIdSet {
|
||||||
|
vendorIDs = append(vendorIDs, id)
|
||||||
|
}
|
||||||
|
idToVendorName := make(map[int]string)
|
||||||
|
if len(vendorIDs) > 0 {
|
||||||
|
var dbVendors []model.Vendor
|
||||||
|
_ = model.DB.Where("id IN ?", vendorIDs).Find(&dbVendors).Error
|
||||||
|
for _, v := range dbVendors {
|
||||||
|
idToVendorName[v.Id] = v.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) 缺失且上游存在的模型
|
||||||
|
missingList, _ := model.GetMissingModels()
|
||||||
|
var missing []string
|
||||||
|
for _, name := range missingList {
|
||||||
|
if _, ok := modelByName[name]; ok {
|
||||||
|
missing = append(missing, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 计算冲突字段
|
||||||
|
type conflictField struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
Local interface{} `json:"local"`
|
||||||
|
Upstream interface{} `json:"upstream"`
|
||||||
|
}
|
||||||
|
type conflictItem struct {
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
Fields []conflictField `json:"fields"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var conflicts []conflictItem
|
||||||
|
for _, local := range locals {
|
||||||
|
up, ok := modelByName[local.ModelName]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fields := make([]conflictField, 0, 6)
|
||||||
|
if strings.TrimSpace(local.Description) != strings.TrimSpace(up.Description) {
|
||||||
|
fields = append(fields, conflictField{Field: "description", Local: local.Description, Upstream: up.Description})
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(local.Icon) != strings.TrimSpace(up.Icon) {
|
||||||
|
fields = append(fields, conflictField{Field: "icon", Local: local.Icon, Upstream: up.Icon})
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(local.Tags) != strings.TrimSpace(up.Tags) {
|
||||||
|
fields = append(fields, conflictField{Field: "tags", Local: local.Tags, Upstream: up.Tags})
|
||||||
|
}
|
||||||
|
// vendor 对比使用名称
|
||||||
|
localVendor := idToVendorName[local.VendorID]
|
||||||
|
if strings.TrimSpace(localVendor) != strings.TrimSpace(up.VendorName) {
|
||||||
|
fields = append(fields, conflictField{Field: "vendor", Local: localVendor, Upstream: up.VendorName})
|
||||||
|
}
|
||||||
|
if local.NameRule != up.NameRule {
|
||||||
|
fields = append(fields, conflictField{Field: "name_rule", Local: local.NameRule, Upstream: up.NameRule})
|
||||||
|
}
|
||||||
|
if local.Status != chooseStatus(up.Status, local.Status) {
|
||||||
|
fields = append(fields, conflictField{Field: "status", Local: local.Status, Upstream: up.Status})
|
||||||
|
}
|
||||||
|
if len(fields) > 0 {
|
||||||
|
conflicts = append(conflicts, conflictItem{ModelName: local.ModelName, Fields: fields})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": gin.H{
|
||||||
|
"missing": missing,
|
||||||
|
"conflicts": conflicts,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -69,7 +69,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if oidcResponse.AccessToken == "" {
|
if oidcResponse.AccessToken == "" {
|
||||||
common.SysError("OIDC 获取 Token 失败,请检查设置!")
|
common.SysLog("OIDC 获取 Token 失败,请检查设置!")
|
||||||
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
}
|
}
|
||||||
defer res2.Body.Close()
|
defer res2.Body.Close()
|
||||||
if res2.StatusCode != http.StatusOK {
|
if res2.StatusCode != http.StatusOK {
|
||||||
common.SysError("OIDC 获取用户信息失败!请检查设置!")
|
common.SysLog("OIDC 获取用户信息失败!请检查设置!")
|
||||||
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,7 +95,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
||||||
common.SysError("OIDC 获取用户信息为空!请检查设置!")
|
common.SysLog("OIDC 获取用户信息为空!请检查设置!")
|
||||||
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
||||||
}
|
}
|
||||||
return &oidcUser, nil
|
return &oidcUser, nil
|
||||||
|
|||||||
@@ -5,10 +5,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -28,41 +26,19 @@ func Playground(c *gin.Context) {
|
|||||||
|
|
||||||
useAccessToken := c.GetBool("use_access_token")
|
useAccessToken := c.GetBool("use_access_token")
|
||||||
if useAccessToken {
|
if useAccessToken {
|
||||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
|
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
playgroundRequest := &dto.PlayGroundRequest{}
|
group := c.GetString("group")
|
||||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
modelName := c.GetString("original_model")
|
||||||
if err != nil {
|
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if playgroundRequest.Model == "" {
|
|
||||||
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Set("original_model", playgroundRequest.Model)
|
|
||||||
group := playgroundRequest.Group
|
|
||||||
userGroup := c.GetString("group")
|
|
||||||
|
|
||||||
if group == "" {
|
|
||||||
group = userGroup
|
|
||||||
} else {
|
|
||||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
|
||||||
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Set("group", group)
|
|
||||||
}
|
|
||||||
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
// Write user context to ensure acceptUnsetRatio is available
|
// Write user context to ensure acceptUnsetRatio is available
|
||||||
userCache, err := model.GetUserCache(userId)
|
userCache, err := model.GetUserCache(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
|
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userCache.WriteContext(c)
|
userCache.WriteContext(c)
|
||||||
@@ -73,12 +49,12 @@ func Playground(c *gin.Context) {
|
|||||||
Group: group,
|
Group: group,
|
||||||
}
|
}
|
||||||
_ = middleware.SetupContextForToken(c, tempToken)
|
_ = middleware.SetupContextForToken(c, tempToken)
|
||||||
_, newAPIError = getChannel(c, group, playgroundRequest.Model, 0)
|
_, newAPIError = getChannel(c, group, modelName, 0)
|
||||||
if newAPIError != nil {
|
if newAPIError != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||||
|
|
||||||
Relay(c)
|
Relay(c, types.RelayFormatOpenAI)
|
||||||
}
|
}
|
||||||
|
|||||||
90
controller/prefill_group.go
Normal file
90
controller/prefill_group.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤
|
||||||
|
func GetPrefillGroups(c *gin.Context) {
|
||||||
|
groupType := c.Query("type")
|
||||||
|
groups, err := model.GetAllPrefillGroups(groupType)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, groups)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePrefillGroup 创建新的预填组
|
||||||
|
func CreatePrefillGroup(c *gin.Context) {
|
||||||
|
var g model.PrefillGroup
|
||||||
|
if err := c.ShouldBindJSON(&g); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if g.Name == "" || g.Type == "" {
|
||||||
|
common.ApiErrorMsg(c, "组名称和类型不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 创建前检查名称
|
||||||
|
if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "组名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.Insert(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, &g)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePrefillGroup 更新预填组
|
||||||
|
func UpdatePrefillGroup(c *gin.Context) {
|
||||||
|
var g model.PrefillGroup
|
||||||
|
if err := c.ShouldBindJSON(&g); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if g.Id == 0 {
|
||||||
|
common.ApiErrorMsg(c, "缺少组 ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 名称冲突检查
|
||||||
|
if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "组名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.Update(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, &g)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePrefillGroup 删除预填组
|
||||||
|
func DeletePrefillGroup(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := model.DeletePrefillGroupByID(id); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, nil)
|
||||||
|
}
|
||||||
@@ -39,10 +39,13 @@ func GetPricing(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": pricing,
|
"data": pricing,
|
||||||
"group_ratio": groupRatio,
|
"vendors": model.GetVendors(),
|
||||||
"usable_group": usableGroup,
|
"group_ratio": groupRatio,
|
||||||
|
"usable_group": usableGroup,
|
||||||
|
"supported_endpoint": model.GetSupportedEndpointMap(),
|
||||||
|
"auto_groups": setting.AutoGroups,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,24 +1,24 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/setting/ratio_setting"
|
"one-api/setting/ratio_setting"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetRatioConfig(c *gin.Context) {
|
func GetRatioConfig(c *gin.Context) {
|
||||||
if !ratio_setting.IsExposeRatioEnabled() {
|
if !ratio_setting.IsExposeRatioEnabled() {
|
||||||
c.JSON(http.StatusForbidden, gin.H{
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "倍率配置接口未启用",
|
"message": "倍率配置接口未启用",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": ratio_setting.GetExposedData(),
|
"data": ratio_setting.GetExposedData(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,474 +1,539 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"io"
|
||||||
"strings"
|
"net"
|
||||||
"sync"
|
"net/http"
|
||||||
"time"
|
"one-api/logger"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"one-api/common"
|
"one-api/dto"
|
||||||
"one-api/dto"
|
"one-api/model"
|
||||||
"one-api/model"
|
"one-api/setting/ratio_setting"
|
||||||
"one-api/setting/ratio_setting"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultTimeoutSeconds = 10
|
defaultTimeoutSeconds = 10
|
||||||
defaultEndpoint = "/api/ratio_config"
|
defaultEndpoint = "/api/ratio_config"
|
||||||
maxConcurrentFetches = 8
|
maxConcurrentFetches = 8
|
||||||
|
maxRatioConfigBytes = 10 << 20 // 10MB
|
||||||
|
floatEpsilon = 1e-9
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func nearlyEqual(a, b float64) bool {
|
||||||
|
if a > b {
|
||||||
|
return a-b < floatEpsilon
|
||||||
|
}
|
||||||
|
return b-a < floatEpsilon
|
||||||
|
}
|
||||||
|
|
||||||
|
func valuesEqual(a, b interface{}) bool {
|
||||||
|
af, aok := a.(float64)
|
||||||
|
bf, bok := b.(float64)
|
||||||
|
if aok && bok {
|
||||||
|
return nearlyEqual(af, bf)
|
||||||
|
}
|
||||||
|
return a == b
|
||||||
|
}
|
||||||
|
|
||||||
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||||
|
|
||||||
type upstreamResult struct {
|
type upstreamResult struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Data map[string]any `json:"data,omitempty"`
|
Data map[string]any `json:"data,omitempty"`
|
||||||
Err string `json:"err,omitempty"`
|
Err string `json:"err,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func FetchUpstreamRatios(c *gin.Context) {
|
func FetchUpstreamRatios(c *gin.Context) {
|
||||||
var req dto.UpstreamRequest
|
var req dto.UpstreamRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Timeout <= 0 {
|
if req.Timeout <= 0 {
|
||||||
req.Timeout = defaultTimeoutSeconds
|
req.Timeout = defaultTimeoutSeconds
|
||||||
}
|
}
|
||||||
|
|
||||||
var upstreams []dto.UpstreamDTO
|
var upstreams []dto.UpstreamDTO
|
||||||
|
|
||||||
if len(req.Upstreams) > 0 {
|
if len(req.Upstreams) > 0 {
|
||||||
for _, u := range req.Upstreams {
|
for _, u := range req.Upstreams {
|
||||||
if strings.HasPrefix(u.BaseURL, "http") {
|
if strings.HasPrefix(u.BaseURL, "http") {
|
||||||
if u.Endpoint == "" {
|
if u.Endpoint == "" {
|
||||||
u.Endpoint = defaultEndpoint
|
u.Endpoint = defaultEndpoint
|
||||||
}
|
}
|
||||||
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
|
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
|
||||||
upstreams = append(upstreams, u)
|
upstreams = append(upstreams, u)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if len(req.ChannelIDs) > 0 {
|
} else if len(req.ChannelIDs) > 0 {
|
||||||
intIds := make([]int, 0, len(req.ChannelIDs))
|
intIds := make([]int, 0, len(req.ChannelIDs))
|
||||||
for _, id64 := range req.ChannelIDs {
|
for _, id64 := range req.ChannelIDs {
|
||||||
intIds = append(intIds, int(id64))
|
intIds = append(intIds, int(id64))
|
||||||
}
|
}
|
||||||
dbChannels, err := model.GetChannelsByIds(intIds)
|
dbChannels, err := model.GetChannelsByIds(intIds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, ch := range dbChannels {
|
for _, ch := range dbChannels {
|
||||||
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
||||||
upstreams = append(upstreams, dto.UpstreamDTO{
|
upstreams = append(upstreams, dto.UpstreamDTO{
|
||||||
ID: ch.Id,
|
ID: ch.Id,
|
||||||
Name: ch.Name,
|
Name: ch.Name,
|
||||||
BaseURL: strings.TrimRight(base, "/"),
|
BaseURL: strings.TrimRight(base, "/"),
|
||||||
Endpoint: "",
|
Endpoint: "",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(upstreams) == 0 {
|
if len(upstreams) == 0 {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
ch := make(chan upstreamResult, len(upstreams))
|
ch := make(chan upstreamResult, len(upstreams))
|
||||||
|
|
||||||
sem := make(chan struct{}, maxConcurrentFetches)
|
sem := make(chan struct{}, maxConcurrentFetches)
|
||||||
|
|
||||||
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||||
|
transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second}
|
||||||
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
host = addr
|
||||||
|
}
|
||||||
|
// 对 github.io 优先尝试 IPv4,失败则回退 IPv6
|
||||||
|
if strings.HasSuffix(host, "github.io") {
|
||||||
|
if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, "tcp6", addr)
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
client := &http.Client{Transport: transport}
|
||||||
|
|
||||||
for _, chn := range upstreams {
|
for _, chn := range upstreams {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(chItem dto.UpstreamDTO) {
|
go func(chItem dto.UpstreamDTO) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
sem <- struct{}{}
|
sem <- struct{}{}
|
||||||
defer func() { <-sem }()
|
defer func() { <-sem }()
|
||||||
|
|
||||||
endpoint := chItem.Endpoint
|
endpoint := chItem.Endpoint
|
||||||
if endpoint == "" {
|
var fullURL string
|
||||||
endpoint = defaultEndpoint
|
if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
|
||||||
} else if !strings.HasPrefix(endpoint, "/") {
|
fullURL = endpoint
|
||||||
endpoint = "/" + endpoint
|
} else {
|
||||||
}
|
if endpoint == "" {
|
||||||
fullURL := chItem.BaseURL + endpoint
|
endpoint = defaultEndpoint
|
||||||
|
} else if !strings.HasPrefix(endpoint, "/") {
|
||||||
|
endpoint = "/" + endpoint
|
||||||
|
}
|
||||||
|
fullURL = chItem.BaseURL + endpoint
|
||||||
|
}
|
||||||
|
|
||||||
uniqueName := chItem.Name
|
uniqueName := chItem.Name
|
||||||
if chItem.ID != 0 {
|
if chItem.ID != 0 {
|
||||||
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
|
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
||||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := client.Do(httpReq)
|
// 简单重试:最多 3 次,指数退避
|
||||||
if err != nil {
|
var resp *http.Response
|
||||||
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
var lastErr error
|
||||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
for attempt := 0; attempt < 3; attempt++ {
|
||||||
return
|
resp, lastErr = client.Do(httpReq)
|
||||||
}
|
if lastErr == nil {
|
||||||
defer resp.Body.Close()
|
break
|
||||||
if resp.StatusCode != http.StatusOK {
|
}
|
||||||
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
|
||||||
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
}
|
||||||
return
|
if lastErr != nil {
|
||||||
}
|
logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+lastErr.Error())
|
||||||
// 兼容两种上游接口格式:
|
ch <- upstreamResult{Name: uniqueName, Err: lastErr.Error()}
|
||||||
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
return
|
||||||
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
}
|
||||||
var body struct {
|
defer resp.Body.Close()
|
||||||
Success bool `json:"success"`
|
if resp.StatusCode != http.StatusOK {
|
||||||
Data json.RawMessage `json:"data"`
|
logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
||||||
Message string `json:"message"`
|
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
||||||
}
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
// Content-Type 和响应体大小校验
|
||||||
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
if ct := resp.Header.Get("Content-Type"); ct != "" && !strings.Contains(strings.ToLower(ct), "application/json") {
|
||||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct)
|
||||||
return
|
}
|
||||||
}
|
limited := io.LimitReader(resp.Body, maxRatioConfigBytes)
|
||||||
|
// 兼容两种上游接口格式:
|
||||||
|
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
||||||
|
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
||||||
|
var body struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data json.RawMessage `json:"data"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
if !body.Success {
|
if err := json.NewDecoder(limited).Decode(&body); err != nil {
|
||||||
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
|
logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||||
return
|
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||||
}
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 尝试按 type1 解析
|
if !body.Success {
|
||||||
var type1Data map[string]any
|
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
|
||||||
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
return
|
||||||
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
}
|
||||||
isType1 := false
|
|
||||||
for _, rt := range ratioTypes {
|
|
||||||
if _, ok := type1Data[rt]; ok {
|
|
||||||
isType1 = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if isType1 {
|
|
||||||
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
// 若 Data 为空,将继续按 type1 尝试解析(与多数静态 ratio_config 兼容)
|
||||||
var pricingItems []struct {
|
|
||||||
ModelName string `json:"model_name"`
|
|
||||||
QuotaType int `json:"quota_type"`
|
|
||||||
ModelRatio float64 `json:"model_ratio"`
|
|
||||||
ModelPrice float64 `json:"model_price"`
|
|
||||||
CompletionRatio float64 `json:"completion_ratio"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
|
|
||||||
common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
|
||||||
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
modelRatioMap := make(map[string]float64)
|
// 尝试按 type1 解析
|
||||||
completionRatioMap := make(map[string]float64)
|
var type1Data map[string]any
|
||||||
modelPriceMap := make(map[string]float64)
|
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
||||||
|
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
||||||
|
isType1 := false
|
||||||
|
for _, rt := range ratioTypes {
|
||||||
|
if _, ok := type1Data[rt]; ok {
|
||||||
|
isType1 = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isType1 {
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, item := range pricingItems {
|
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
||||||
if item.QuotaType == 1 {
|
var pricingItems []struct {
|
||||||
modelPriceMap[item.ModelName] = item.ModelPrice
|
ModelName string `json:"model_name"`
|
||||||
} else {
|
QuotaType int `json:"quota_type"`
|
||||||
modelRatioMap[item.ModelName] = item.ModelRatio
|
ModelRatio float64 `json:"model_ratio"`
|
||||||
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
ModelPrice float64 `json:"model_price"`
|
||||||
completionRatioMap[item.ModelName] = item.CompletionRatio
|
CompletionRatio float64 `json:"completion_ratio"`
|
||||||
}
|
}
|
||||||
}
|
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
|
||||||
|
logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
converted := make(map[string]any)
|
modelRatioMap := make(map[string]float64)
|
||||||
|
completionRatioMap := make(map[string]float64)
|
||||||
|
modelPriceMap := make(map[string]float64)
|
||||||
|
|
||||||
if len(modelRatioMap) > 0 {
|
for _, item := range pricingItems {
|
||||||
ratioAny := make(map[string]any, len(modelRatioMap))
|
if item.QuotaType == 1 {
|
||||||
for k, v := range modelRatioMap {
|
modelPriceMap[item.ModelName] = item.ModelPrice
|
||||||
ratioAny[k] = v
|
} else {
|
||||||
}
|
modelRatioMap[item.ModelName] = item.ModelRatio
|
||||||
converted["model_ratio"] = ratioAny
|
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
||||||
}
|
completionRatioMap[item.ModelName] = item.CompletionRatio
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(completionRatioMap) > 0 {
|
converted := make(map[string]any)
|
||||||
compAny := make(map[string]any, len(completionRatioMap))
|
|
||||||
for k, v := range completionRatioMap {
|
|
||||||
compAny[k] = v
|
|
||||||
}
|
|
||||||
converted["completion_ratio"] = compAny
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(modelPriceMap) > 0 {
|
if len(modelRatioMap) > 0 {
|
||||||
priceAny := make(map[string]any, len(modelPriceMap))
|
ratioAny := make(map[string]any, len(modelRatioMap))
|
||||||
for k, v := range modelPriceMap {
|
for k, v := range modelRatioMap {
|
||||||
priceAny[k] = v
|
ratioAny[k] = v
|
||||||
}
|
}
|
||||||
converted["model_price"] = priceAny
|
converted["model_ratio"] = ratioAny
|
||||||
}
|
}
|
||||||
|
|
||||||
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
if len(completionRatioMap) > 0 {
|
||||||
}(chn)
|
compAny := make(map[string]any, len(completionRatioMap))
|
||||||
}
|
for k, v := range completionRatioMap {
|
||||||
|
compAny[k] = v
|
||||||
|
}
|
||||||
|
converted["completion_ratio"] = compAny
|
||||||
|
}
|
||||||
|
|
||||||
wg.Wait()
|
if len(modelPriceMap) > 0 {
|
||||||
close(ch)
|
priceAny := make(map[string]any, len(modelPriceMap))
|
||||||
|
for k, v := range modelPriceMap {
|
||||||
|
priceAny[k] = v
|
||||||
|
}
|
||||||
|
converted["model_price"] = priceAny
|
||||||
|
}
|
||||||
|
|
||||||
localData := ratio_setting.GetExposedData()
|
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
||||||
|
}(chn)
|
||||||
|
}
|
||||||
|
|
||||||
var testResults []dto.TestResult
|
wg.Wait()
|
||||||
var successfulChannels []struct {
|
close(ch)
|
||||||
name string
|
|
||||||
data map[string]any
|
|
||||||
}
|
|
||||||
|
|
||||||
for r := range ch {
|
localData := ratio_setting.GetExposedData()
|
||||||
if r.Err != "" {
|
|
||||||
testResults = append(testResults, dto.TestResult{
|
|
||||||
Name: r.Name,
|
|
||||||
Status: "error",
|
|
||||||
Error: r.Err,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
testResults = append(testResults, dto.TestResult{
|
|
||||||
Name: r.Name,
|
|
||||||
Status: "success",
|
|
||||||
})
|
|
||||||
successfulChannels = append(successfulChannels, struct {
|
|
||||||
name string
|
|
||||||
data map[string]any
|
|
||||||
}{name: r.Name, data: r.Data})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
differences := buildDifferences(localData, successfulChannels)
|
var testResults []dto.TestResult
|
||||||
|
var successfulChannels []struct {
|
||||||
|
name string
|
||||||
|
data map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
for r := range ch {
|
||||||
"success": true,
|
if r.Err != "" {
|
||||||
"data": gin.H{
|
testResults = append(testResults, dto.TestResult{
|
||||||
"differences": differences,
|
Name: r.Name,
|
||||||
"test_results": testResults,
|
Status: "error",
|
||||||
},
|
Error: r.Err,
|
||||||
})
|
})
|
||||||
|
} else {
|
||||||
|
testResults = append(testResults, dto.TestResult{
|
||||||
|
Name: r.Name,
|
||||||
|
Status: "success",
|
||||||
|
})
|
||||||
|
successfulChannels = append(successfulChannels, struct {
|
||||||
|
name string
|
||||||
|
data map[string]any
|
||||||
|
}{name: r.Name, data: r.Data})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
differences := buildDifferences(localData, successfulChannels)
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": gin.H{
|
||||||
|
"differences": differences,
|
||||||
|
"test_results": testResults,
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildDifferences(localData map[string]any, successfulChannels []struct {
|
func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||||
name string
|
name string
|
||||||
data map[string]any
|
data map[string]any
|
||||||
}) map[string]map[string]dto.DifferenceItem {
|
}) map[string]map[string]dto.DifferenceItem {
|
||||||
differences := make(map[string]map[string]dto.DifferenceItem)
|
differences := make(map[string]map[string]dto.DifferenceItem)
|
||||||
|
|
||||||
allModels := make(map[string]struct{})
|
allModels := make(map[string]struct{})
|
||||||
|
|
||||||
for _, ratioType := range ratioTypes {
|
|
||||||
if localRatioAny, ok := localData[ratioType]; ok {
|
|
||||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
|
||||||
for modelName := range localRatio {
|
|
||||||
allModels[modelName] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, channel := range successfulChannels {
|
|
||||||
for _, ratioType := range ratioTypes {
|
|
||||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
|
||||||
for modelName := range upstreamRatio {
|
|
||||||
allModels[modelName] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
confidenceMap := make(map[string]map[string]bool)
|
for _, ratioType := range ratioTypes {
|
||||||
|
if localRatioAny, ok := localData[ratioType]; ok {
|
||||||
// 预处理阶段:检查pricing接口的可信度
|
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||||
for _, channel := range successfulChannels {
|
for modelName := range localRatio {
|
||||||
confidenceMap[channel.name] = make(map[string]bool)
|
allModels[modelName] = struct{}{}
|
||||||
|
}
|
||||||
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
}
|
||||||
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
}
|
||||||
|
}
|
||||||
if hasModelRatio && hasCompletionRatio {
|
|
||||||
// 遍历所有模型,检查是否满足不可信条件
|
|
||||||
for modelName := range allModels {
|
|
||||||
// 默认为可信
|
|
||||||
confidenceMap[channel.name][modelName] = true
|
|
||||||
|
|
||||||
// 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
|
|
||||||
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
|
||||||
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
|
||||||
// 转换为float64进行比较
|
|
||||||
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
|
||||||
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
|
||||||
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
|
||||||
confidenceMap[channel.name][modelName] = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// 如果不是从pricing接口获取的数据,则全部标记为可信
|
|
||||||
for modelName := range allModels {
|
|
||||||
confidenceMap[channel.name][modelName] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for modelName := range allModels {
|
for _, channel := range successfulChannels {
|
||||||
for _, ratioType := range ratioTypes {
|
for _, ratioType := range ratioTypes {
|
||||||
var localValue interface{} = nil
|
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||||
if localRatioAny, ok := localData[ratioType]; ok {
|
for modelName := range upstreamRatio {
|
||||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
allModels[modelName] = struct{}{}
|
||||||
if val, exists := localRatio[modelName]; exists {
|
}
|
||||||
localValue = val
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
upstreamValues := make(map[string]interface{})
|
confidenceMap := make(map[string]map[string]bool)
|
||||||
confidenceValues := make(map[string]bool)
|
|
||||||
hasUpstreamValue := false
|
|
||||||
hasDifference := false
|
|
||||||
|
|
||||||
for _, channel := range successfulChannels {
|
// 预处理阶段:检查pricing接口的可信度
|
||||||
var upstreamValue interface{} = nil
|
for _, channel := range successfulChannels {
|
||||||
|
confidenceMap[channel.name] = make(map[string]bool)
|
||||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
|
||||||
if val, exists := upstreamRatio[modelName]; exists {
|
|
||||||
upstreamValue = val
|
|
||||||
hasUpstreamValue = true
|
|
||||||
|
|
||||||
if localValue != nil && localValue != val {
|
|
||||||
hasDifference = true
|
|
||||||
} else if localValue == val {
|
|
||||||
upstreamValue = "same"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if upstreamValue == nil && localValue == nil {
|
|
||||||
upstreamValue = "same"
|
|
||||||
}
|
|
||||||
|
|
||||||
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
|
||||||
hasDifference = true
|
|
||||||
}
|
|
||||||
|
|
||||||
upstreamValues[channel.name] = upstreamValue
|
|
||||||
|
|
||||||
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
|
|
||||||
}
|
|
||||||
|
|
||||||
shouldInclude := false
|
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
||||||
|
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
||||||
if localValue != nil {
|
|
||||||
if hasDifference {
|
|
||||||
shouldInclude = true
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if hasUpstreamValue {
|
|
||||||
shouldInclude = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if shouldInclude {
|
if hasModelRatio && hasCompletionRatio {
|
||||||
if differences[modelName] == nil {
|
// 遍历所有模型,检查是否满足不可信条件
|
||||||
differences[modelName] = make(map[string]dto.DifferenceItem)
|
for modelName := range allModels {
|
||||||
}
|
// 默认为可信
|
||||||
differences[modelName][ratioType] = dto.DifferenceItem{
|
confidenceMap[channel.name][modelName] = true
|
||||||
Current: localValue,
|
|
||||||
Upstreams: upstreamValues,
|
|
||||||
Confidence: confidenceValues,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
channelHasDiff := make(map[string]bool)
|
// 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
|
||||||
for _, ratioMap := range differences {
|
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
||||||
for _, item := range ratioMap {
|
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
||||||
for chName, val := range item.Upstreams {
|
// 转换为float64进行比较
|
||||||
if val != nil && val != "same" {
|
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
||||||
channelHasDiff[chName] = true
|
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
||||||
}
|
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
||||||
}
|
confidenceMap[channel.name][modelName] = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 如果不是从pricing接口获取的数据,则全部标记为可信
|
||||||
|
for modelName := range allModels {
|
||||||
|
confidenceMap[channel.name][modelName] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for modelName, ratioMap := range differences {
|
for modelName := range allModels {
|
||||||
for ratioType, item := range ratioMap {
|
for _, ratioType := range ratioTypes {
|
||||||
for chName := range item.Upstreams {
|
var localValue interface{} = nil
|
||||||
if !channelHasDiff[chName] {
|
if localRatioAny, ok := localData[ratioType]; ok {
|
||||||
delete(item.Upstreams, chName)
|
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||||
delete(item.Confidence, chName)
|
if val, exists := localRatio[modelName]; exists {
|
||||||
}
|
localValue = val
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
allSame := true
|
upstreamValues := make(map[string]interface{})
|
||||||
for _, v := range item.Upstreams {
|
confidenceValues := make(map[string]bool)
|
||||||
if v != "same" {
|
hasUpstreamValue := false
|
||||||
allSame = false
|
hasDifference := false
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(item.Upstreams) == 0 || allSame {
|
|
||||||
delete(ratioMap, ratioType)
|
|
||||||
} else {
|
|
||||||
differences[modelName][ratioType] = item
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(ratioMap) == 0 {
|
for _, channel := range successfulChannels {
|
||||||
delete(differences, modelName)
|
var upstreamValue interface{} = nil
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return differences
|
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||||
|
if val, exists := upstreamRatio[modelName]; exists {
|
||||||
|
upstreamValue = val
|
||||||
|
hasUpstreamValue = true
|
||||||
|
|
||||||
|
if localValue != nil && !valuesEqual(localValue, val) {
|
||||||
|
hasDifference = true
|
||||||
|
} else if valuesEqual(localValue, val) {
|
||||||
|
upstreamValue = "same"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if upstreamValue == nil && localValue == nil {
|
||||||
|
upstreamValue = "same"
|
||||||
|
}
|
||||||
|
|
||||||
|
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
||||||
|
hasDifference = true
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamValues[channel.name] = upstreamValue
|
||||||
|
|
||||||
|
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldInclude := false
|
||||||
|
|
||||||
|
if localValue != nil {
|
||||||
|
if hasDifference {
|
||||||
|
shouldInclude = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if hasUpstreamValue {
|
||||||
|
shouldInclude = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldInclude {
|
||||||
|
if differences[modelName] == nil {
|
||||||
|
differences[modelName] = make(map[string]dto.DifferenceItem)
|
||||||
|
}
|
||||||
|
differences[modelName][ratioType] = dto.DifferenceItem{
|
||||||
|
Current: localValue,
|
||||||
|
Upstreams: upstreamValues,
|
||||||
|
Confidence: confidenceValues,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
channelHasDiff := make(map[string]bool)
|
||||||
|
for _, ratioMap := range differences {
|
||||||
|
for _, item := range ratioMap {
|
||||||
|
for chName, val := range item.Upstreams {
|
||||||
|
if val != nil && val != "same" {
|
||||||
|
channelHasDiff[chName] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for modelName, ratioMap := range differences {
|
||||||
|
for ratioType, item := range ratioMap {
|
||||||
|
for chName := range item.Upstreams {
|
||||||
|
if !channelHasDiff[chName] {
|
||||||
|
delete(item.Upstreams, chName)
|
||||||
|
delete(item.Confidence, chName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
allSame := true
|
||||||
|
for _, v := range item.Upstreams {
|
||||||
|
if v != "same" {
|
||||||
|
allSame = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(item.Upstreams) == 0 || allSame {
|
||||||
|
delete(ratioMap, ratioType)
|
||||||
|
} else {
|
||||||
|
differences[modelName][ratioType] = item
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ratioMap) == 0 {
|
||||||
|
delete(differences, modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return differences
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSyncableChannels(c *gin.Context) {
|
func GetSyncableChannels(c *gin.Context) {
|
||||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var syncableChannels []dto.SyncableChannel
|
var syncableChannels []dto.SyncableChannel
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
if channel.GetBaseURL() != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||||
ID: channel.Id,
|
ID: channel.Id,
|
||||||
Name: channel.Name,
|
Name: channel.Name,
|
||||||
BaseURL: channel.GetBaseURL(),
|
BaseURL: channel.GetBaseURL(),
|
||||||
Status: channel.Status,
|
Status: channel.Status,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||||
"success": true,
|
ID: -100,
|
||||||
"message": "",
|
Name: "官方倍率预设",
|
||||||
"data": syncableChannels,
|
BaseURL: "https://basellm.github.io",
|
||||||
})
|
Status: 1,
|
||||||
}
|
})
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": syncableChannels,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -63,7 +64,7 @@ func AddRedemption(c *gin.Context) {
|
|||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
|
if utf8.RuneCountInString(redemption.Name) == 0 || utf8.RuneCountInString(redemption.Name) > 20 {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "兑换码名称长度必须在1-20之间",
|
"message": "兑换码名称长度必须在1-20之间",
|
||||||
|
|||||||
@@ -2,115 +2,193 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
constant2 "one-api/constant"
|
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
|
||||||
var err *types.NewAPIError
|
var err *types.NewAPIError
|
||||||
switch relayMode {
|
switch info.RelayMode {
|
||||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||||
err = relay.ImageHelper(c)
|
err = relay.ImageHelper(c, info)
|
||||||
case relayconstant.RelayModeAudioSpeech:
|
case relayconstant.RelayModeAudioSpeech:
|
||||||
fallthrough
|
fallthrough
|
||||||
case relayconstant.RelayModeAudioTranslation:
|
case relayconstant.RelayModeAudioTranslation:
|
||||||
fallthrough
|
fallthrough
|
||||||
case relayconstant.RelayModeAudioTranscription:
|
case relayconstant.RelayModeAudioTranscription:
|
||||||
err = relay.AudioHelper(c)
|
err = relay.AudioHelper(c, info)
|
||||||
case relayconstant.RelayModeRerank:
|
case relayconstant.RelayModeRerank:
|
||||||
err = relay.RerankHelper(c, relayMode)
|
err = relay.RerankHelper(c, info)
|
||||||
case relayconstant.RelayModeEmbeddings:
|
case relayconstant.RelayModeEmbeddings:
|
||||||
err = relay.EmbeddingHelper(c)
|
err = relay.EmbeddingHelper(c, info)
|
||||||
case relayconstant.RelayModeResponses:
|
case relayconstant.RelayModeResponses:
|
||||||
err = relay.ResponsesHelper(c)
|
err = relay.ResponsesHelper(c, info)
|
||||||
case relayconstant.RelayModeGemini:
|
|
||||||
err = relay.GeminiHelper(c)
|
|
||||||
default:
|
default:
|
||||||
err = relay.TextHelper(c)
|
err = relay.TextHelper(c, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
if constant2.ErrorLogEnabled && err != nil {
|
|
||||||
// 保存错误日志到mysql中
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
tokenName := c.GetString("token_name")
|
|
||||||
modelName := c.GetString("original_model")
|
|
||||||
tokenId := c.GetInt("token_id")
|
|
||||||
userGroup := c.GetString("group")
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
other := make(map[string]interface{})
|
|
||||||
other["error_type"] = err.ErrorType
|
|
||||||
other["error_code"] = err.GetErrorCode()
|
|
||||||
other["status_code"] = err.StatusCode
|
|
||||||
other["channel_id"] = channelId
|
|
||||||
other["channel_name"] = c.GetString("channel_name")
|
|
||||||
other["channel_type"] = c.GetInt("channel_type")
|
|
||||||
|
|
||||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
|
||||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
var err *types.NewAPIError
|
||||||
|
if strings.Contains(c.Request.URL.Path, "embed") {
|
||||||
|
err = relay.GeminiEmbeddingHandler(c, info)
|
||||||
|
} else {
|
||||||
|
err = relay.GeminiHelper(c, info)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||||
|
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
group := c.GetString("group")
|
group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
||||||
var newAPIError *types.NewAPIError
|
|
||||||
|
var (
|
||||||
|
newAPIError *types.NewAPIError
|
||||||
|
ws *websocket.Conn
|
||||||
|
)
|
||||||
|
|
||||||
|
if relayFormat == types.RelayFormatOpenAIRealtime {
|
||||||
|
var err error
|
||||||
|
ws, err = upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if newAPIError != nil {
|
||||||
|
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||||
|
switch relayFormat {
|
||||||
|
case types.RelayFormatOpenAIRealtime:
|
||||||
|
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
||||||
|
case types.RelayFormatClaude:
|
||||||
|
c.JSON(newAPIError.StatusCode, gin.H{
|
||||||
|
"type": "error",
|
||||||
|
"error": newAPIError.ToClaudeError(),
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
c.JSON(newAPIError.StatusCode, gin.H{
|
||||||
|
"error": newAPIError.ToOpenAIError(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
request, err := helper.GetAndValidateRequest(c, relayFormat)
|
||||||
|
if err != nil {
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
|
||||||
|
if err != nil {
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := request.GetTokenCountMeta()
|
||||||
|
|
||||||
|
if setting.ShouldCheckPromptSensitive() {
|
||||||
|
contains, words := service.CheckSensitiveText(meta.CombineText)
|
||||||
|
if contains {
|
||||||
|
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := service.CountRequestToken(c, meta, relayInfo)
|
||||||
|
if err != nil {
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
relayInfo.SetPromptTokens(tokens)
|
||||||
|
|
||||||
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
|
||||||
|
if err != nil {
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
|
||||||
|
|
||||||
|
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
|
if newAPIError != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
// Only return quota if downstream failed and quota was actually pre-consumed
|
||||||
|
if newAPIError != nil && preConsumedQuota != 0 {
|
||||||
|
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
for i := 0; i <= common.RetryTimes; i++ {
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, err.Error())
|
logger.LogError(c, err.Error())
|
||||||
newAPIError = err
|
newAPIError = err
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
newAPIError = relayRequest(c, relayMode, channel)
|
addUsedChannel(c, channel.Id)
|
||||||
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
|
||||||
if newAPIError == nil {
|
switch relayFormat {
|
||||||
return // 成功处理请求,直接返回
|
case types.RelayFormatOpenAIRealtime:
|
||||||
|
newAPIError = relay.WssHelper(c, relayInfo)
|
||||||
|
case types.RelayFormatClaude:
|
||||||
|
newAPIError = relay.ClaudeHelper(c, relayInfo)
|
||||||
|
case types.RelayFormatGemini:
|
||||||
|
newAPIError = geminiRelayHandler(c, relayInfo)
|
||||||
|
default:
|
||||||
|
newAPIError = relayHandler(c, relayInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
if newAPIError == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||||
|
|
||||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
if len(useChannel) > 1 {
|
if len(useChannel) > 1 {
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||||
common.LogInfo(c, retryLogStr)
|
logger.LogInfo(c, retryLogStr)
|
||||||
}
|
|
||||||
|
|
||||||
if newAPIError != nil {
|
|
||||||
//if newAPIError.StatusCode == http.StatusTooManyRequests {
|
|
||||||
// common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
|
|
||||||
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
|
||||||
//}
|
|
||||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
|
||||||
c.JSON(newAPIError.StatusCode, gin.H{
|
|
||||||
"error": newAPIError.ToOpenAIError(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,122 +199,6 @@ var upgrader = websocket.Upgrader{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func WssRelay(c *gin.Context) {
|
|
||||||
// 将 HTTP 连接升级为 WebSocket 连接
|
|
||||||
|
|
||||||
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
||||||
defer ws.Close()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
|
||||||
group := c.GetString("group")
|
|
||||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
|
||||||
originalModel := c.GetString("original_model")
|
|
||||||
var newAPIError *types.NewAPIError
|
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(c, err.Error())
|
|
||||||
newAPIError = err
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
newAPIError = wssRequest(c, ws, relayMode, channel)
|
|
||||||
|
|
||||||
if newAPIError == nil {
|
|
||||||
return // 成功处理请求,直接返回
|
|
||||||
}
|
|
||||||
|
|
||||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
|
||||||
|
|
||||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
|
||||||
if len(useChannel) > 1 {
|
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
|
||||||
common.LogInfo(c, retryLogStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if newAPIError != nil {
|
|
||||||
//if newAPIError.StatusCode == http.StatusTooManyRequests {
|
|
||||||
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
|
||||||
//}
|
|
||||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
|
||||||
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func RelayClaude(c *gin.Context) {
|
|
||||||
//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
|
||||||
group := c.GetString("group")
|
|
||||||
originalModel := c.GetString("original_model")
|
|
||||||
var newAPIError *types.NewAPIError
|
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(c, err.Error())
|
|
||||||
newAPIError = err
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
newAPIError = claudeRequest(c, channel)
|
|
||||||
|
|
||||||
if newAPIError == nil {
|
|
||||||
return // 成功处理请求,直接返回
|
|
||||||
}
|
|
||||||
|
|
||||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
|
||||||
|
|
||||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
|
||||||
if len(useChannel) > 1 {
|
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
|
||||||
common.LogInfo(c, retryLogStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if newAPIError != nil {
|
|
||||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
|
||||||
c.JSON(newAPIError.StatusCode, gin.H{
|
|
||||||
"type": "error",
|
|
||||||
"error": newAPIError.ToClaudeError(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
|
|
||||||
addUsedChannel(c, channel.Id)
|
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
||||||
return relayHandler(c, relayMode)
|
|
||||||
}
|
|
||||||
|
|
||||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
|
|
||||||
addUsedChannel(c, channel.Id)
|
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
||||||
return relay.WssHelper(c, ws)
|
|
||||||
}
|
|
||||||
|
|
||||||
func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
|
|
||||||
addUsedChannel(c, channel.Id)
|
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
||||||
return relay.ClaudeHelper(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func addUsedChannel(c *gin.Context, channelId int) {
|
func addUsedChannel(c *gin.Context, channelId int) {
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||||
@@ -259,10 +221,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
|||||||
}
|
}
|
||||||
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if group == "auto" {
|
return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
|
}
|
||||||
}
|
if channel == nil {
|
||||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
|
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
if newAPIError != nil {
|
if newAPIError != nil {
|
||||||
@@ -278,7 +240,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
|||||||
if types.IsChannelError(openaiErr) {
|
if types.IsChannelError(openaiErr) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if types.IsLocalError(openaiErr) {
|
if types.IsSkipRetryError(openaiErr) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if retryTimes <= 0 {
|
if retryTimes <= 0 {
|
||||||
@@ -301,10 +263,6 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||||
channelType := c.GetInt("channel_type")
|
|
||||||
if channelType == constant.ChannelTypeAnthropic {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if openaiErr.StatusCode == 408 {
|
if openaiErr.StatusCode == 408 {
|
||||||
@@ -318,44 +276,84 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
|||||||
}
|
}
|
||||||
|
|
||||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
|
||||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
gopool.Go(func() {
|
||||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||||
service.DisableChannel(channelError, err.Error())
|
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||||
|
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||||
|
service.DisableChannel(channelError, err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
|
||||||
|
// 保存错误日志到mysql中
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
tokenName := c.GetString("token_name")
|
||||||
|
modelName := c.GetString("original_model")
|
||||||
|
tokenId := c.GetInt("token_id")
|
||||||
|
userGroup := c.GetString("group")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
other := make(map[string]interface{})
|
||||||
|
other["error_type"] = err.GetErrorType()
|
||||||
|
other["error_code"] = err.GetErrorCode()
|
||||||
|
other["status_code"] = err.StatusCode
|
||||||
|
other["channel_id"] = channelId
|
||||||
|
other["channel_name"] = c.GetString("channel_name")
|
||||||
|
other["channel_type"] = c.GetInt("channel_type")
|
||||||
|
adminInfo := make(map[string]interface{})
|
||||||
|
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
||||||
|
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
||||||
|
if isMultiKey {
|
||||||
|
adminInfo["is_multi_key"] = true
|
||||||
|
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
||||||
|
}
|
||||||
|
other["admin_info"] = adminInfo
|
||||||
|
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourney(c *gin.Context) {
|
func RelayMidjourney(c *gin.Context) {
|
||||||
relayMode := c.GetInt("relay_mode")
|
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil)
|
||||||
var err *dto.MidjourneyResponse
|
|
||||||
switch relayMode {
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
|
"description": fmt.Sprintf("failed to generate relay info: %s", err.Error()),
|
||||||
|
"type": "upstream_error",
|
||||||
|
"code": 4,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var mjErr *dto.MidjourneyResponse
|
||||||
|
switch relayInfo.RelayMode {
|
||||||
case relayconstant.RelayModeMidjourneyNotify:
|
case relayconstant.RelayModeMidjourneyNotify:
|
||||||
err = relay.RelayMidjourneyNotify(c)
|
mjErr = relay.RelayMidjourneyNotify(c)
|
||||||
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
||||||
err = relay.RelayMidjourneyTask(c, relayMode)
|
mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode)
|
||||||
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
||||||
err = relay.RelayMidjourneyTaskImageSeed(c)
|
mjErr = relay.RelayMidjourneyTaskImageSeed(c)
|
||||||
case relayconstant.RelayModeSwapFace:
|
case relayconstant.RelayModeSwapFace:
|
||||||
err = relay.RelaySwapFace(c)
|
mjErr = relay.RelaySwapFace(c, relayInfo)
|
||||||
default:
|
default:
|
||||||
err = relay.RelayMidjourneySubmit(c, relayMode)
|
mjErr = relay.RelayMidjourneySubmit(c, relayInfo)
|
||||||
}
|
}
|
||||||
//err = relayMidjourneySubmit(c, relayMode)
|
//err = relayMidjourneySubmit(c, relayMode)
|
||||||
log.Println(err)
|
log.Println(mjErr)
|
||||||
if err != nil {
|
if mjErr != nil {
|
||||||
statusCode := http.StatusBadRequest
|
statusCode := http.StatusBadRequest
|
||||||
if err.Code == 30 {
|
if mjErr.Code == 30 {
|
||||||
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||||
statusCode = http.StatusTooManyRequests
|
statusCode = http.StatusTooManyRequests
|
||||||
}
|
}
|
||||||
c.JSON(statusCode, gin.H{
|
c.JSON(statusCode, gin.H{
|
||||||
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
|
"description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result),
|
||||||
"type": "upstream_error",
|
"type": "upstream_error",
|
||||||
"code": err.Code,
|
"code": mjErr.Code,
|
||||||
})
|
})
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -386,18 +384,21 @@ func RelayNotFound(c *gin.Context) {
|
|||||||
func RelayTask(c *gin.Context) {
|
func RelayTask(c *gin.Context) {
|
||||||
retryTimes := common.RetryTimes
|
retryTimes := common.RetryTimes
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
relayMode := c.GetInt("relay_mode")
|
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := c.GetString("original_model")
|
||||||
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
||||||
taskErr := taskRelayHandler(c, relayMode)
|
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
taskErr := taskRelayHandler(c, relayInfo)
|
||||||
if taskErr == nil {
|
if taskErr == nil {
|
||||||
retryTimes = 0
|
retryTimes = 0
|
||||||
}
|
}
|
||||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||||
channel, newAPIError := getChannel(c, group, originalModel, i)
|
channel, newAPIError := getChannel(c, group, originalModel, i)
|
||||||
if newAPIError != nil {
|
if newAPIError != nil {
|
||||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
||||||
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -405,17 +406,17 @@ func RelayTask(c *gin.Context) {
|
|||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||||
c.Set("use_channel", useChannel)
|
c.Set("use_channel", useChannel)
|
||||||
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||||
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
taskErr = taskRelayHandler(c, relayMode)
|
taskErr = taskRelayHandler(c, relayInfo)
|
||||||
}
|
}
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
if len(useChannel) > 1 {
|
if len(useChannel) > 1 {
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||||
common.LogInfo(c, retryLogStr)
|
logger.LogInfo(c, retryLogStr)
|
||||||
}
|
}
|
||||||
if taskErr != nil {
|
if taskErr != nil {
|
||||||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||||||
@@ -425,13 +426,13 @@ func RelayTask(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError {
|
||||||
var err *dto.TaskError
|
var err *dto.TaskError
|
||||||
switch relayMode {
|
switch relayInfo.RelayMode {
|
||||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
|
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
|
||||||
err = relay.RelayTaskFetch(c, relayMode)
|
err = relay.RelayTaskFetch(c, relayInfo.RelayMode)
|
||||||
default:
|
default:
|
||||||
err = relay.RelayTaskSubmit(c, relayMode)
|
err = relay.RelayTaskSubmit(c, relayInfo)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -114,3 +114,23 @@ type KlingImage2VideoRequest struct {
|
|||||||
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
|
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
|
||||||
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-002"`
|
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-002"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// KlingImage2videoTaskId godoc
|
||||||
|
// @Summary 可灵任务查询--图生视频
|
||||||
|
// @Description Query the status and result of a Kling video generation task by task ID
|
||||||
|
// @Tags Origin
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param task_id path string true "Task ID"
|
||||||
|
// @Router /kling/v1/videos/image2video/{task_id} [get]
|
||||||
|
func KlingImage2videoTaskId(c *gin.Context) {}
|
||||||
|
|
||||||
|
// KlingText2videoTaskId godoc
|
||||||
|
// @Summary 可灵任务查询--文生视频
|
||||||
|
// @Description Query the status and result of a Kling text-to-video generation task by task ID
|
||||||
|
// @Tags Origin
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param task_id path string true "Task ID"
|
||||||
|
// @Router /kling/v1/videos/text2video/{task_id} [get]
|
||||||
|
func KlingText2videoTaskId(c *gin.Context) {}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"sort"
|
"sort"
|
||||||
@@ -54,9 +55,9 @@ func UpdateTaskBulk() {
|
|||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
||||||
} else {
|
} else {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(taskChannelM) == 0 {
|
if len(taskChannelM) == 0 {
|
||||||
@@ -75,10 +76,10 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
|||||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||||
case constant.TaskPlatformSuno:
|
case constant.TaskPlatformSuno:
|
||||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||||
case constant.TaskPlatformKling, constant.TaskPlatformJimeng:
|
|
||||||
_ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM)
|
|
||||||
default:
|
default:
|
||||||
common.SysLog("未知平台")
|
if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
|
||||||
|
common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,14 +87,14 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM
|
|||||||
for channelId, taskIds := range taskChannelM {
|
for channelId, taskIds := range taskChannelM {
|
||||||
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
|
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
|
logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||||
if len(taskIds) == 0 {
|
if len(taskIds) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -106,7 +107,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -118,23 +119,23 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
"ids": taskIds,
|
"ids": taskIds,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
|
common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
|
common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
||||||
err = json.Unmarshal(responseBody, &responseItems)
|
err = json.Unmarshal(responseBody, &responseItems)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !responseItems.IsSuccess() {
|
if !responseItems.IsSuccess() {
|
||||||
@@ -154,19 +155,19 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
|
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
|
||||||
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
|
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
|
||||||
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
||||||
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
||||||
task.Progress = "100%"
|
task.Progress = "100%"
|
||||||
//err = model.CacheUpdateUserQuota(task.UserId) ?
|
//err = model.CacheUpdateUserQuota(task.UserId) ?
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
logger.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||||
} else {
|
} else {
|
||||||
quota := task.Quota
|
quota := task.Quota
|
||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
err = model.IncreaseUserQuota(task.UserId, quota, false)
|
err = model.IncreaseUserQuota(task.UserId, quota, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||||
}
|
}
|
||||||
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
|
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
|
||||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -178,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
|
|
||||||
err = task.Update()
|
err = task.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("UpdateMidjourneyTask task error: " + err.Error())
|
common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -2,27 +2,31 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||||
for channelId, taskIds := range taskChannelM {
|
for channelId, taskIds := range taskChannelM {
|
||||||
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||||
if len(taskIds) == 0 {
|
if len(taskIds) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -34,7 +38,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
|
|||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if errUpdate != nil {
|
if errUpdate != nil {
|
||||||
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||||
}
|
}
|
||||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -44,7 +48,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
|
|||||||
}
|
}
|
||||||
for _, taskId := range taskIds {
|
for _, taskId := range taskIds {
|
||||||
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -58,7 +62,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
|
|
||||||
task := taskM[taskId]
|
task := taskM[taskId]
|
||||||
if task == nil {
|
if task == nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||||
return fmt.Errorf("task %s not found", taskId)
|
return fmt.Errorf("task %s not found", taskId)
|
||||||
}
|
}
|
||||||
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||||
@@ -77,13 +81,21 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
taskResult, err := adaptor.ParseTaskResult(responseBody)
|
taskResult := &relaycommon.TaskInfo{}
|
||||||
if err != nil {
|
// try parse as New API response format
|
||||||
|
var responseItems dto.TaskResponse[model.Task]
|
||||||
|
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||||
|
t := responseItems.Data
|
||||||
|
taskResult.TaskID = t.TaskID
|
||||||
|
taskResult.Status = string(t.Status)
|
||||||
|
taskResult.Url = t.FailReason
|
||||||
|
taskResult.Progress = t.Progress
|
||||||
|
taskResult.Reason = t.FailReason
|
||||||
|
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||||
|
} else {
|
||||||
|
task.Data = responseBody
|
||||||
}
|
}
|
||||||
//if taskResult.Code != 0 {
|
|
||||||
// return fmt.Errorf("video task fetch failed for task %s", taskId)
|
|
||||||
//}
|
|
||||||
|
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
if taskResult.Status == "" {
|
if taskResult.Status == "" {
|
||||||
@@ -113,13 +125,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
task.FinishTime = now
|
task.FinishTime = now
|
||||||
}
|
}
|
||||||
task.FailReason = taskResult.Reason
|
task.FailReason = taskResult.Reason
|
||||||
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||||
quota := task.Quota
|
quota := task.Quota
|
||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||||
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||||
}
|
}
|
||||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
|
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@@ -128,10 +140,8 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
if taskResult.Progress != "" {
|
if taskResult.Progress != "" {
|
||||||
task.Progress = taskResult.Progress
|
task.Progress = taskResult.Progress
|
||||||
}
|
}
|
||||||
|
|
||||||
task.Data = responseBody
|
|
||||||
if err := task.Update(); err != nil {
|
if err := task.Update(); err != nil {
|
||||||
common.SysError("UpdateVideoTask task error: " + err.Error())
|
common.SysLog("UpdateVideoTask task error: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -82,6 +83,57 @@ func GetTokenStatus(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetTokenUsage(c *gin.Context) {
|
||||||
|
authHeader := c.GetHeader("Authorization")
|
||||||
|
if authHeader == "" {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "No Authorization header",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(authHeader, " ")
|
||||||
|
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "Invalid Bearer token",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tokenKey := parts[1]
|
||||||
|
|
||||||
|
token, err := model.GetTokenByKey(strings.TrimPrefix(tokenKey, "sk-"), false)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
expiredAt := token.ExpiredTime
|
||||||
|
if expiredAt == -1 {
|
||||||
|
expiredAt = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": true,
|
||||||
|
"message": "ok",
|
||||||
|
"data": gin.H{
|
||||||
|
"object": "token_usage",
|
||||||
|
"name": token.Name,
|
||||||
|
"total_granted": token.RemainQuota + token.UsedQuota,
|
||||||
|
"total_used": token.UsedQuota,
|
||||||
|
"total_available": token.RemainQuota,
|
||||||
|
"unlimited_quota": token.UnlimitedQuota,
|
||||||
|
"model_limits": token.GetModelLimitsMap(),
|
||||||
|
"model_limits_enabled": token.ModelLimitsEnabled,
|
||||||
|
"expires_at": expiredAt,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func AddToken(c *gin.Context) {
|
func AddToken(c *gin.Context) {
|
||||||
token := model.Token{}
|
token := model.Token{}
|
||||||
err := c.ShouldBindJSON(&token)
|
err := c.ShouldBindJSON(&token)
|
||||||
@@ -102,7 +154,7 @@ func AddToken(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成令牌失败",
|
"message": "生成令牌失败",
|
||||||
})
|
})
|
||||||
common.SysError("failed to generate token key: " + err.Error())
|
common.SysLog("failed to generate token key: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cleanToken := model.Token{
|
cleanToken := model.Token{
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
@@ -231,7 +232,7 @@ func EpayNotify(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("易支付回调更新用户成功 %v", topUp)
|
log.Printf("易支付回调更新用户成功 %v", topUp)
|
||||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money))
|
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Printf("易支付异常回调: %v", verifyInfo)
|
log.Printf("易支付异常回调: %v", verifyInfo)
|
||||||
|
|||||||
553
controller/twofa.go
Normal file
553
controller/twofa.go
Normal file
@@ -0,0 +1,553 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Setup2FARequest 设置2FA请求结构
|
||||||
|
type Setup2FARequest struct {
|
||||||
|
Code string `json:"code" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify2FARequest 验证2FA请求结构
|
||||||
|
type Verify2FARequest struct {
|
||||||
|
Code string `json:"code" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup2FAResponse 设置2FA响应结构
|
||||||
|
type Setup2FAResponse struct {
|
||||||
|
Secret string `json:"secret"`
|
||||||
|
QRCodeData string `json:"qr_code_data"`
|
||||||
|
BackupCodes []string `json:"backup_codes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup2FA 初始化2FA设置
|
||||||
|
func Setup2FA(c *gin.Context) {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
// 检查用户是否已经启用2FA
|
||||||
|
existing, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if existing != nil && existing.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户已启用2FA,请先禁用后重新设置",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果存在已禁用的2FA记录,先删除它
|
||||||
|
if existing != nil && !existing.IsEnabled {
|
||||||
|
if err := existing.Delete(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
existing = nil // 重置为nil,后续将创建新记录
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取用户信息
|
||||||
|
user, err := model.GetUserById(userId, false)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成TOTP密钥
|
||||||
|
key, err := common.GenerateTOTPSecret(user.Username)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "生成2FA密钥失败",
|
||||||
|
})
|
||||||
|
common.SysLog("生成TOTP密钥失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成备用码
|
||||||
|
backupCodes, err := common.GenerateBackupCodes()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "生成备用码失败",
|
||||||
|
})
|
||||||
|
common.SysLog("生成备用码失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成二维码数据
|
||||||
|
qrCodeData := common.GenerateQRCodeData(key.Secret(), user.Username)
|
||||||
|
|
||||||
|
// 创建或更新2FA记录(暂未启用)
|
||||||
|
twoFA := &model.TwoFA{
|
||||||
|
UserId: userId,
|
||||||
|
Secret: key.Secret(),
|
||||||
|
IsEnabled: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
if existing != nil {
|
||||||
|
// 更新现有记录
|
||||||
|
twoFA.Id = existing.Id
|
||||||
|
err = twoFA.Update()
|
||||||
|
} else {
|
||||||
|
// 创建新记录
|
||||||
|
err = twoFA.Create()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建备用码记录
|
||||||
|
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "保存备用码失败",
|
||||||
|
})
|
||||||
|
common.SysLog("保存备用码失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
model.RecordLog(userId, model.LogTypeSystem, "开始设置两步验证")
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "2FA设置初始化成功,请使用认证器扫描二维码并输入验证码完成设置",
|
||||||
|
"data": Setup2FAResponse{
|
||||||
|
Secret: key.Secret(),
|
||||||
|
QRCodeData: qrCodeData,
|
||||||
|
BackupCodes: backupCodes,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable2FA 启用2FA
|
||||||
|
func Enable2FA(c *gin.Context) {
|
||||||
|
var req Setup2FARequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "参数错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
// 获取2FA记录
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA == nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "请先完成2FA初始化设置",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "2FA已经启用",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证TOTP验证码
|
||||||
|
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !common.ValidateTOTPCode(twoFA.Secret, cleanCode) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "验证码或备用码错误,请重试",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启用2FA
|
||||||
|
if err := twoFA.Enable(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
model.RecordLog(userId, model.LogTypeSystem, "成功启用两步验证")
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "两步验证启用成功",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable2FA 禁用2FA
|
||||||
|
func Disable2FA(c *gin.Context) {
|
||||||
|
var req Verify2FARequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "参数错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
// 获取2FA记录
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA == nil || !twoFA.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户未启用2FA",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证TOTP验证码或备用码
|
||||||
|
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||||
|
isValidTOTP := false
|
||||||
|
isValidBackup := false
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
// 尝试验证TOTP
|
||||||
|
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidTOTP {
|
||||||
|
// 尝试验证备用码
|
||||||
|
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidTOTP && !isValidBackup {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "验证码或备用码错误,请重试",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 禁用2FA
|
||||||
|
if err := model.DisableTwoFA(userId); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
model.RecordLog(userId, model.LogTypeSystem, "禁用两步验证")
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "两步验证已禁用",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get2FAStatus 获取用户2FA状态
|
||||||
|
func Get2FAStatus(c *gin.Context) {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
status := map[string]interface{}{
|
||||||
|
"enabled": false,
|
||||||
|
"locked": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
if twoFA != nil {
|
||||||
|
status["enabled"] = twoFA.IsEnabled
|
||||||
|
status["locked"] = twoFA.IsLocked()
|
||||||
|
if twoFA.IsEnabled {
|
||||||
|
// 获取剩余备用码数量
|
||||||
|
backupCount, err := model.GetUnusedBackupCodeCount(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog("获取备用码数量失败: " + err.Error())
|
||||||
|
} else {
|
||||||
|
status["backup_codes_remaining"] = backupCount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegenerateBackupCodes 重新生成备用码
|
||||||
|
func RegenerateBackupCodes(c *gin.Context) {
|
||||||
|
var req Verify2FARequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "参数错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
// 获取2FA记录
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA == nil || !twoFA.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户未启用2FA",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证TOTP验证码
|
||||||
|
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
valid, err := twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "验证码或备用码错误,请重试",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成新的备用码
|
||||||
|
backupCodes, err := common.GenerateBackupCodes()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "生成备用码失败",
|
||||||
|
})
|
||||||
|
common.SysLog("生成备用码失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存新的备用码
|
||||||
|
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "保存备用码失败",
|
||||||
|
})
|
||||||
|
common.SysLog("保存备用码失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
model.RecordLog(userId, model.LogTypeSystem, "重新生成两步验证备用码")
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "备用码重新生成成功",
|
||||||
|
"data": map[string]interface{}{
|
||||||
|
"backup_codes": backupCodes,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify2FALogin 登录时验证2FA
|
||||||
|
func Verify2FALogin(c *gin.Context) {
|
||||||
|
var req Verify2FARequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "参数错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从会话中获取pending用户信息
|
||||||
|
session := sessions.Default(c)
|
||||||
|
pendingUserId := session.Get("pending_user_id")
|
||||||
|
if pendingUserId == nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "会话已过期,请重新登录",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userId, ok := pendingUserId.(int)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "会话数据无效,请重新登录",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 获取用户信息
|
||||||
|
user, err := model.GetUserById(userId, false)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户不存在",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取2FA记录
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(user.Id)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA == nil || !twoFA.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户未启用2FA",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证TOTP验证码或备用码
|
||||||
|
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||||
|
isValidTOTP := false
|
||||||
|
isValidBackup := false
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
// 尝试验证TOTP
|
||||||
|
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidTOTP {
|
||||||
|
// 尝试验证备用码
|
||||||
|
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidTOTP && !isValidBackup {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "验证码或备用码错误,请重试",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2FA验证成功,清理pending会话信息并完成登录
|
||||||
|
session.Delete("pending_username")
|
||||||
|
session.Delete("pending_user_id")
|
||||||
|
session.Save()
|
||||||
|
|
||||||
|
setupLogin(user, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Admin2FAStats 管理员获取2FA统计信息
|
||||||
|
func Admin2FAStats(c *gin.Context) {
|
||||||
|
stats, err := model.GetTwoFAStats()
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": stats,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminDisable2FA 管理员强制禁用用户2FA
|
||||||
|
func AdminDisable2FA(c *gin.Context) {
|
||||||
|
userIdStr := c.Param("id")
|
||||||
|
userId, err := strconv.Atoi(userIdStr)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户ID格式错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查目标用户权限
|
||||||
|
targetUser, err := model.GetUserById(userId, false)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
myRole := c.GetInt("role")
|
||||||
|
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无权操作同级或更高级用户的2FA设置",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 禁用2FA
|
||||||
|
if err := model.DisableTwoFA(userId); err != nil {
|
||||||
|
if errors.Is(err, model.ErrTwoFANotEnabled) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户未启用2FA",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
adminId := c.GetInt("id")
|
||||||
|
model.RecordLog(userId, model.LogTypeManage,
|
||||||
|
fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "用户2FA已被强制禁用",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -31,7 +31,7 @@ type Monitor struct {
|
|||||||
|
|
||||||
type UptimeGroupResult struct {
|
type UptimeGroupResult struct {
|
||||||
CategoryName string `json:"categoryName"`
|
CategoryName string `json:"categoryName"`
|
||||||
Monitors []Monitor `json:"monitors"`
|
Monitors []Monitor `json:"monitors"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
|
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
|
||||||
@@ -57,29 +57,29 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st
|
|||||||
url, _ := groupConfig["url"].(string)
|
url, _ := groupConfig["url"].(string)
|
||||||
slug, _ := groupConfig["slug"].(string)
|
slug, _ := groupConfig["slug"].(string)
|
||||||
categoryName, _ := groupConfig["categoryName"].(string)
|
categoryName, _ := groupConfig["categoryName"].(string)
|
||||||
|
|
||||||
result := UptimeGroupResult{
|
result := UptimeGroupResult{
|
||||||
CategoryName: categoryName,
|
CategoryName: categoryName,
|
||||||
Monitors: []Monitor{},
|
Monitors: []Monitor{},
|
||||||
}
|
}
|
||||||
|
|
||||||
if url == "" || slug == "" {
|
if url == "" || slug == "" {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURL := strings.TrimSuffix(url, "/")
|
baseURL := strings.TrimSuffix(url, "/")
|
||||||
|
|
||||||
var statusData struct {
|
var statusData struct {
|
||||||
PublicGroupList []struct {
|
PublicGroupList []struct {
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
MonitorList []struct {
|
MonitorList []struct {
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
} `json:"monitorList"`
|
} `json:"monitorList"`
|
||||||
} `json:"publicGroupList"`
|
} `json:"publicGroupList"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var heartbeatData struct {
|
var heartbeatData struct {
|
||||||
HeartbeatList map[string][]struct {
|
HeartbeatList map[string][]struct {
|
||||||
Status int `json:"status"`
|
Status int `json:"status"`
|
||||||
@@ -88,11 +88,11 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st
|
|||||||
}
|
}
|
||||||
|
|
||||||
g, gCtx := errgroup.WithContext(ctx)
|
g, gCtx := errgroup.WithContext(ctx)
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
||||||
})
|
})
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
||||||
})
|
})
|
||||||
|
|
||||||
if g.Wait() != nil {
|
if g.Wait() != nil {
|
||||||
@@ -139,7 +139,7 @@ func GetUptimeKumaStatus(c *gin.Context) {
|
|||||||
|
|
||||||
client := &http.Client{Timeout: httpTimeout}
|
client := &http.Client{Timeout: httpTimeout}
|
||||||
results := make([]UptimeGroupResult, len(groups))
|
results := make([]UptimeGroupResult, len(groups))
|
||||||
|
|
||||||
g, gCtx := errgroup.WithContext(ctx)
|
g, gCtx := errgroup.WithContext(ctx)
|
||||||
for i, group := range groups {
|
for i, group := range groups {
|
||||||
i, group := i, group
|
i, group := i, group
|
||||||
@@ -148,7 +148,7 @@ func GetUptimeKumaStatus(c *gin.Context) {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
g.Wait()
|
g.Wait()
|
||||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
|
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -62,6 +63,32 @@ func Login(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否启用2FA
|
||||||
|
if model.IsTwoFAEnabled(user.Id) {
|
||||||
|
// 设置pending session,等待2FA验证
|
||||||
|
session := sessions.Default(c)
|
||||||
|
session.Set("pending_username", user.Username)
|
||||||
|
session.Set("pending_user_id", user.Id)
|
||||||
|
err := session.Save()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "无法保存会话信息,请重试",
|
||||||
|
"success": false,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "请输入两步验证码",
|
||||||
|
"success": true,
|
||||||
|
"data": map[string]interface{}{
|
||||||
|
"require_2fa": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
setupLogin(&user, c)
|
setupLogin(&user, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,7 +193,7 @@ func Register(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "数据库错误,请稍后重试",
|
"message": "数据库错误,请稍后重试",
|
||||||
})
|
})
|
||||||
common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
|
common.SysLog(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if exist {
|
if exist {
|
||||||
@@ -183,6 +210,7 @@ func Register(c *gin.Context) {
|
|||||||
Password: user.Password,
|
Password: user.Password,
|
||||||
DisplayName: user.Username,
|
DisplayName: user.Username,
|
||||||
InviterId: inviterId,
|
InviterId: inviterId,
|
||||||
|
Role: common.RoleCommonUser, // 明确设置角色为普通用户
|
||||||
}
|
}
|
||||||
if common.EmailVerificationEnabled {
|
if common.EmailVerificationEnabled {
|
||||||
cleanUser.Email = user.Email
|
cleanUser.Email = user.Email
|
||||||
@@ -209,7 +237,7 @@ func Register(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成默认令牌失败",
|
"message": "生成默认令牌失败",
|
||||||
})
|
})
|
||||||
common.SysError("failed to generate token key: " + err.Error())
|
common.SysLog("failed to generate token key: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 生成默认令牌
|
// 生成默认令牌
|
||||||
@@ -316,7 +344,7 @@ func GenerateAccessToken(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成失败",
|
"message": "生成失败",
|
||||||
})
|
})
|
||||||
common.SysError("failed to generate key: " + err.Error())
|
common.SysLog("failed to generate key: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.SetAccessToken(key)
|
user.SetAccessToken(key)
|
||||||
@@ -399,6 +427,7 @@ func GetAffCode(c *gin.Context) {
|
|||||||
|
|
||||||
func GetSelf(c *gin.Context) {
|
func GetSelf(c *gin.Context) {
|
||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
|
userRole := c.GetInt("role")
|
||||||
user, err := model.GetUserById(id, false)
|
user, err := model.GetUserById(id, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
@@ -407,14 +436,134 @@ func GetSelf(c *gin.Context) {
|
|||||||
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
||||||
user.Remark = ""
|
user.Remark = ""
|
||||||
|
|
||||||
|
// 计算用户权限信息
|
||||||
|
permissions := calculateUserPermissions(userRole)
|
||||||
|
|
||||||
|
// 获取用户设置并提取sidebar_modules
|
||||||
|
userSetting := user.GetSetting()
|
||||||
|
|
||||||
|
// 构建响应数据,包含用户信息和权限
|
||||||
|
responseData := map[string]interface{}{
|
||||||
|
"id": user.Id,
|
||||||
|
"username": user.Username,
|
||||||
|
"display_name": user.DisplayName,
|
||||||
|
"role": user.Role,
|
||||||
|
"status": user.Status,
|
||||||
|
"email": user.Email,
|
||||||
|
"group": user.Group,
|
||||||
|
"quota": user.Quota,
|
||||||
|
"used_quota": user.UsedQuota,
|
||||||
|
"request_count": user.RequestCount,
|
||||||
|
"aff_code": user.AffCode,
|
||||||
|
"aff_count": user.AffCount,
|
||||||
|
"aff_quota": user.AffQuota,
|
||||||
|
"aff_history_quota": user.AffHistoryQuota,
|
||||||
|
"inviter_id": user.InviterId,
|
||||||
|
"linux_do_id": user.LinuxDOId,
|
||||||
|
"setting": user.Setting,
|
||||||
|
"stripe_customer": user.StripeCustomer,
|
||||||
|
"sidebar_modules": userSetting.SidebarModules, // 正确提取sidebar_modules字段
|
||||||
|
"permissions": permissions, // 新增权限字段
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": user,
|
"data": responseData,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 计算用户权限的辅助函数
|
||||||
|
func calculateUserPermissions(userRole int) map[string]interface{} {
|
||||||
|
permissions := map[string]interface{}{}
|
||||||
|
|
||||||
|
// 根据用户角色计算权限
|
||||||
|
if userRole == common.RoleRootUser {
|
||||||
|
// 超级管理员不需要边栏设置功能
|
||||||
|
permissions["sidebar_settings"] = false
|
||||||
|
permissions["sidebar_modules"] = map[string]interface{}{}
|
||||||
|
} else if userRole == common.RoleAdminUser {
|
||||||
|
// 管理员可以设置边栏,但不包含系统设置功能
|
||||||
|
permissions["sidebar_settings"] = true
|
||||||
|
permissions["sidebar_modules"] = map[string]interface{}{
|
||||||
|
"admin": map[string]interface{}{
|
||||||
|
"setting": false, // 管理员不能访问系统设置
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 普通用户只能设置个人功能,不包含管理员区域
|
||||||
|
permissions["sidebar_settings"] = true
|
||||||
|
permissions["sidebar_modules"] = map[string]interface{}{
|
||||||
|
"admin": false, // 普通用户不能访问管理员区域
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return permissions
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据用户角色生成默认的边栏配置
|
||||||
|
func generateDefaultSidebarConfig(userRole int) string {
|
||||||
|
defaultConfig := map[string]interface{}{}
|
||||||
|
|
||||||
|
// 聊天区域 - 所有用户都可以访问
|
||||||
|
defaultConfig["chat"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"playground": true,
|
||||||
|
"chat": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 控制台区域 - 所有用户都可以访问
|
||||||
|
defaultConfig["console"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"detail": true,
|
||||||
|
"token": true,
|
||||||
|
"log": true,
|
||||||
|
"midjourney": true,
|
||||||
|
"task": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 个人中心区域 - 所有用户都可以访问
|
||||||
|
defaultConfig["personal"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"topup": true,
|
||||||
|
"personal": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 管理员区域 - 根据角色决定
|
||||||
|
if userRole == common.RoleAdminUser {
|
||||||
|
// 管理员可以访问管理员区域,但不能访问系统设置
|
||||||
|
defaultConfig["admin"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"channel": true,
|
||||||
|
"models": true,
|
||||||
|
"redemption": true,
|
||||||
|
"user": true,
|
||||||
|
"setting": false, // 管理员不能访问系统设置
|
||||||
|
}
|
||||||
|
} else if userRole == common.RoleRootUser {
|
||||||
|
// 超级管理员可以访问所有功能
|
||||||
|
defaultConfig["admin"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"channel": true,
|
||||||
|
"models": true,
|
||||||
|
"redemption": true,
|
||||||
|
"user": true,
|
||||||
|
"setting": true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 普通用户不包含admin区域
|
||||||
|
|
||||||
|
// 转换为JSON字符串
|
||||||
|
configBytes, err := json.Marshal(defaultConfig)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog("生成默认边栏配置失败: " + err.Error())
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(configBytes)
|
||||||
|
}
|
||||||
|
|
||||||
func GetUserModels(c *gin.Context) {
|
func GetUserModels(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -491,7 +640,7 @@ func UpdateUser(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if originUser.Quota != updatedUser.Quota {
|
if originUser.Quota != updatedUser.Quota {
|
||||||
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
|
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota)))
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
@@ -501,8 +650,8 @@ func UpdateUser(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateSelf(c *gin.Context) {
|
func UpdateSelf(c *gin.Context) {
|
||||||
var user model.User
|
var requestData map[string]interface{}
|
||||||
err := json.NewDecoder(c.Request.Body).Decode(&user)
|
err := json.NewDecoder(c.Request.Body).Decode(&requestData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -510,6 +659,60 @@ func UpdateSelf(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否是sidebar_modules更新请求
|
||||||
|
if sidebarModules, exists := requestData["sidebar_modules"]; exists {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
user, err := model.GetUserById(userId, false)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取当前用户设置
|
||||||
|
currentSetting := user.GetSetting()
|
||||||
|
|
||||||
|
// 更新sidebar_modules字段
|
||||||
|
if sidebarModulesStr, ok := sidebarModules.(string); ok {
|
||||||
|
currentSetting.SidebarModules = sidebarModulesStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存更新后的设置
|
||||||
|
user.SetSetting(currentSetting)
|
||||||
|
if err := user.Update(false); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "更新设置失败: " + err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "设置更新成功",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原有的用户信息更新逻辑
|
||||||
|
var user model.User
|
||||||
|
requestDataBytes, err := json.Marshal(requestData)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的参数",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(requestDataBytes, &user)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的参数",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if user.Password == "" {
|
if user.Password == "" {
|
||||||
user.Password = "$I_LOVE_U" // make Validator happy :)
|
user.Password = "$I_LOVE_U" // make Validator happy :)
|
||||||
}
|
}
|
||||||
@@ -652,6 +855,7 @@ func CreateUser(c *gin.Context) {
|
|||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Password: user.Password,
|
Password: user.Password,
|
||||||
DisplayName: user.DisplayName,
|
DisplayName: user.DisplayName,
|
||||||
|
Role: user.Role, // 保持管理员设置的角色
|
||||||
}
|
}
|
||||||
if err := cleanUser.Insert(0); err != nil {
|
if err := cleanUser.Insert(0); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
@@ -817,18 +1021,64 @@ type topUpRequest struct {
|
|||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var topUpLock = sync.Mutex{}
|
var topUpLocks sync.Map
|
||||||
|
var topUpCreateLock sync.Mutex
|
||||||
|
|
||||||
|
type topUpTryLock struct {
|
||||||
|
ch chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTopUpTryLock() *topUpTryLock {
|
||||||
|
return &topUpTryLock{ch: make(chan struct{}, 1)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *topUpTryLock) TryLock() bool {
|
||||||
|
select {
|
||||||
|
case l.ch <- struct{}{}:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *topUpTryLock) Unlock() {
|
||||||
|
select {
|
||||||
|
case <-l.ch:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTopUpLock(userID int) *topUpTryLock {
|
||||||
|
if v, ok := topUpLocks.Load(userID); ok {
|
||||||
|
return v.(*topUpTryLock)
|
||||||
|
}
|
||||||
|
topUpCreateLock.Lock()
|
||||||
|
defer topUpCreateLock.Unlock()
|
||||||
|
if v, ok := topUpLocks.Load(userID); ok {
|
||||||
|
return v.(*topUpTryLock)
|
||||||
|
}
|
||||||
|
l := newTopUpTryLock()
|
||||||
|
topUpLocks.Store(userID, l)
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
func TopUp(c *gin.Context) {
|
func TopUp(c *gin.Context) {
|
||||||
topUpLock.Lock()
|
id := c.GetInt("id")
|
||||||
defer topUpLock.Unlock()
|
lock := getTopUpLock(id)
|
||||||
|
if !lock.TryLock() {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "充值处理中,请稍后重试",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer lock.Unlock()
|
||||||
req := topUpRequest{}
|
req := topUpRequest{}
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
id := c.GetInt("id")
|
|
||||||
quota, err := model.Redeem(req.Key, id)
|
quota, err := model.Redeem(req.Key, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
@@ -839,7 +1089,6 @@ func TopUp(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": quota,
|
"data": quota,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdateUserSettingRequest struct {
|
type UpdateUserSettingRequest struct {
|
||||||
@@ -848,6 +1097,7 @@ type UpdateUserSettingRequest struct {
|
|||||||
WebhookUrl string `json:"webhook_url,omitempty"`
|
WebhookUrl string `json:"webhook_url,omitempty"`
|
||||||
WebhookSecret string `json:"webhook_secret,omitempty"`
|
WebhookSecret string `json:"webhook_secret,omitempty"`
|
||||||
NotificationEmail string `json:"notification_email,omitempty"`
|
NotificationEmail string `json:"notification_email,omitempty"`
|
||||||
|
BarkUrl string `json:"bark_url,omitempty"`
|
||||||
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
||||||
RecordIpLog bool `json:"record_ip_log"`
|
RecordIpLog bool `json:"record_ip_log"`
|
||||||
}
|
}
|
||||||
@@ -863,7 +1113,7 @@ func UpdateUserSetting(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证预警类型
|
// 验证预警类型
|
||||||
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
|
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook && req.QuotaWarningType != dto.NotifyTypeBark {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无效的预警类型",
|
"message": "无效的预警类型",
|
||||||
@@ -911,6 +1161,33 @@ func UpdateUserSetting(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果是Bark类型,验证Bark URL
|
||||||
|
if req.QuotaWarningType == dto.NotifyTypeBark {
|
||||||
|
if req.BarkUrl == "" {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "Bark推送URL不能为空",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 验证URL格式
|
||||||
|
if _, err := url.ParseRequestURI(req.BarkUrl); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的Bark推送URL",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 检查是否是HTTP或HTTPS
|
||||||
|
if !strings.HasPrefix(req.BarkUrl, "https://") && !strings.HasPrefix(req.BarkUrl, "http://") {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "Bark推送URL必须以http://或https://开头",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
user, err := model.GetUserById(userId, true)
|
user, err := model.GetUserById(userId, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -939,6 +1216,11 @@ func UpdateUserSetting(c *gin.Context) {
|
|||||||
settings.NotificationEmail = req.NotificationEmail
|
settings.NotificationEmail = req.NotificationEmail
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果是Bark类型,添加Bark URL到设置中
|
||||||
|
if req.QuotaWarningType == dto.NotifyTypeBark {
|
||||||
|
settings.BarkUrl = req.BarkUrl
|
||||||
|
}
|
||||||
|
|
||||||
// 更新用户设置
|
// 更新用户设置
|
||||||
user.SetSetting(settings)
|
user.SetSetting(settings)
|
||||||
if err := user.Update(false); err != nil {
|
if err := user.Update(false); err != nil {
|
||||||
|
|||||||
124
controller/vendor_meta.go
Normal file
124
controller/vendor_meta.go
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetAllVendors 获取供应商列表(分页)
|
||||||
|
func GetAllVendors(c *gin.Context) {
|
||||||
|
pageInfo := common.GetPageQuery(c)
|
||||||
|
vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var total int64
|
||||||
|
model.DB.Model(&model.Vendor{}).Count(&total)
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
|
pageInfo.SetItems(vendors)
|
||||||
|
common.ApiSuccess(c, pageInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchVendors 搜索供应商
|
||||||
|
func SearchVendors(c *gin.Context) {
|
||||||
|
keyword := c.Query("keyword")
|
||||||
|
pageInfo := common.GetPageQuery(c)
|
||||||
|
vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
|
pageInfo.SetItems(vendors)
|
||||||
|
common.ApiSuccess(c, pageInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVendorMeta 根据 ID 获取供应商
|
||||||
|
func GetVendorMeta(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
v, err := model.GetVendorByID(id)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateVendorMeta 新建供应商
|
||||||
|
func CreateVendorMeta(c *gin.Context) {
|
||||||
|
var v model.Vendor
|
||||||
|
if err := c.ShouldBindJSON(&v); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if v.Name == "" {
|
||||||
|
common.ApiErrorMsg(c, "供应商名称不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 创建前先检查名称
|
||||||
|
if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "供应商名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := v.Insert(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, &v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateVendorMeta 更新供应商
|
||||||
|
func UpdateVendorMeta(c *gin.Context) {
|
||||||
|
var v model.Vendor
|
||||||
|
if err := c.ShouldBindJSON(&v); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if v.Id == 0 {
|
||||||
|
common.ApiErrorMsg(c, "缺少供应商 ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 名称冲突检查
|
||||||
|
if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "供应商名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := v.Update(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, &v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteVendorMeta 删除供应商
|
||||||
|
func DeleteVendorMeta(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, nil)
|
||||||
|
}
|
||||||
@@ -16,7 +16,7 @@ services:
|
|||||||
- REDIS_CONN_STRING=redis://redis
|
- REDIS_CONN_STRING=redis://redis
|
||||||
- TZ=Asia/Shanghai
|
- TZ=Asia/Shanghai
|
||||||
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
||||||
# - STREAMING_TIMEOUT=120 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
|
# - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
|
||||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
||||||
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
||||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||||
|
|||||||
24
dto/audio.go
24
dto/audio.go
@@ -1,5 +1,11 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
type AudioRequest struct {
|
type AudioRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Input string `json:"input"`
|
Input string `json:"input"`
|
||||||
@@ -8,6 +14,24 @@ type AudioRequest struct {
|
|||||||
ResponseFormat string `json:"response_format,omitempty"`
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
meta := &types.TokenCountMeta{
|
||||||
|
CombineText: r.Input,
|
||||||
|
TokenType: types.TokenTypeTextNumber,
|
||||||
|
}
|
||||||
|
return meta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AudioRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AudioRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type AudioResponse struct {
|
type AudioResponse struct {
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,14 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
type ChannelSettings struct {
|
type ChannelSettings struct {
|
||||||
ForceFormat bool `json:"force_format,omitempty"`
|
ForceFormat bool `json:"force_format,omitempty"`
|
||||||
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
|
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
|
||||||
Proxy string `json:"proxy"`
|
Proxy string `json:"proxy"`
|
||||||
|
PassThroughBodyEnabled bool `json:"pass_through_body_enabled,omitempty"`
|
||||||
|
SystemPrompt string `json:"system_prompt,omitempty"`
|
||||||
|
SystemPromptOverride bool `json:"system_prompt_override,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChannelOtherSettings struct {
|
||||||
|
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
196
dto/claude.go
196
dto/claude.go
@@ -2,8 +2,12 @@ package dto
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ClaudeMetadata struct {
|
type ClaudeMetadata struct {
|
||||||
@@ -80,7 +84,7 @@ func (c *ClaudeMediaMessage) GetStringContent() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaudeMediaMessage) GetJsonRowString() string {
|
func (c *ClaudeMediaMessage) GetJsonRowString() string {
|
||||||
jsonContent, _ := json.Marshal(c)
|
jsonContent, _ := common.Marshal(c)
|
||||||
return string(jsonContent)
|
return string(jsonContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,6 +202,147 @@ type ClaudeRequest struct {
|
|||||||
Thinking *Thinking `json:"thinking,omitempty"`
|
Thinking *Thinking `json:"thinking,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var tokenCountMeta = types.TokenCountMeta{
|
||||||
|
TokenType: types.TokenTypeTokenizer,
|
||||||
|
MaxTokens: int(c.MaxTokens),
|
||||||
|
}
|
||||||
|
|
||||||
|
var texts = make([]string, 0)
|
||||||
|
var fileMeta = make([]*types.FileMeta, 0)
|
||||||
|
|
||||||
|
// system
|
||||||
|
if c.System != nil {
|
||||||
|
if c.IsStringSystem() {
|
||||||
|
sys := c.GetStringSystem()
|
||||||
|
if sys != "" {
|
||||||
|
texts = append(texts, sys)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
systemMedia := c.ParseSystem()
|
||||||
|
for _, media := range systemMedia {
|
||||||
|
switch media.Type {
|
||||||
|
case "text":
|
||||||
|
texts = append(texts, media.GetText())
|
||||||
|
case "image":
|
||||||
|
if media.Source != nil {
|
||||||
|
data := media.Source.Url
|
||||||
|
if data == "" {
|
||||||
|
data = common.Interface2String(media.Source.Data)
|
||||||
|
}
|
||||||
|
if data != "" {
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// messages
|
||||||
|
for _, message := range c.Messages {
|
||||||
|
tokenCountMeta.MessagesCount++
|
||||||
|
texts = append(texts, message.Role)
|
||||||
|
if message.IsStringContent() {
|
||||||
|
content := message.GetStringContent()
|
||||||
|
if content != "" {
|
||||||
|
texts = append(texts, content)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
content, _ := message.ParseContent()
|
||||||
|
for _, media := range content {
|
||||||
|
switch media.Type {
|
||||||
|
case "text":
|
||||||
|
texts = append(texts, media.GetText())
|
||||||
|
case "image":
|
||||||
|
if media.Source != nil {
|
||||||
|
data := media.Source.Url
|
||||||
|
if data == "" {
|
||||||
|
data = common.Interface2String(media.Source.Data)
|
||||||
|
}
|
||||||
|
if data != "" {
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "tool_use":
|
||||||
|
if media.Name != "" {
|
||||||
|
texts = append(texts, media.Name)
|
||||||
|
}
|
||||||
|
if media.Input != nil {
|
||||||
|
b, _ := common.Marshal(media.Input)
|
||||||
|
texts = append(texts, string(b))
|
||||||
|
}
|
||||||
|
case "tool_result":
|
||||||
|
if media.Content != nil {
|
||||||
|
b, _ := common.Marshal(media.Content)
|
||||||
|
texts = append(texts, string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tools
|
||||||
|
if c.Tools != nil {
|
||||||
|
tools := c.GetTools()
|
||||||
|
normalTools, webSearchTools := ProcessTools(tools)
|
||||||
|
if normalTools != nil {
|
||||||
|
for _, t := range normalTools {
|
||||||
|
tokenCountMeta.ToolsCount++
|
||||||
|
if t.Name != "" {
|
||||||
|
texts = append(texts, t.Name)
|
||||||
|
}
|
||||||
|
if t.Description != "" {
|
||||||
|
texts = append(texts, t.Description)
|
||||||
|
}
|
||||||
|
if t.InputSchema != nil {
|
||||||
|
b, _ := common.Marshal(t.InputSchema)
|
||||||
|
texts = append(texts, string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if webSearchTools != nil {
|
||||||
|
for _, t := range webSearchTools {
|
||||||
|
tokenCountMeta.ToolsCount++
|
||||||
|
if t.Name != "" {
|
||||||
|
texts = append(texts, t.Name)
|
||||||
|
}
|
||||||
|
if t.UserLocation != nil {
|
||||||
|
b, _ := common.Marshal(t.UserLocation)
|
||||||
|
texts = append(texts, string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenCountMeta.CombineText = strings.Join(texts, "\n")
|
||||||
|
tokenCountMeta.Files = fileMeta
|
||||||
|
return &tokenCountMeta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeRequest) IsStream(ctx *gin.Context) bool {
|
||||||
|
return c.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
c.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
|
||||||
|
for _, message := range c.Messages {
|
||||||
|
content, _ := message.ParseContent()
|
||||||
|
for _, mediaMessage := range content {
|
||||||
|
if mediaMessage.Id == toolCallId {
|
||||||
|
return mediaMessage.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// AddTool 添加工具到请求中
|
// AddTool 添加工具到请求中
|
||||||
func (c *ClaudeRequest) AddTool(tool any) {
|
func (c *ClaudeRequest) AddTool(tool any) {
|
||||||
if c.Tools == nil {
|
if c.Tools == nil {
|
||||||
@@ -284,14 +429,9 @@ func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
|
|||||||
return mediaContent
|
return mediaContent
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClaudeError struct {
|
|
||||||
Type string `json:"type,omitempty"`
|
|
||||||
Message string `json:"message,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ClaudeErrorWithStatusCode struct {
|
type ClaudeErrorWithStatusCode struct {
|
||||||
Error ClaudeError `json:"error"`
|
Error types.ClaudeError `json:"error"`
|
||||||
StatusCode int `json:"status_code"`
|
StatusCode int `json:"status_code"`
|
||||||
LocalError bool
|
LocalError bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -303,7 +443,7 @@ type ClaudeResponse struct {
|
|||||||
Completion string `json:"completion,omitempty"`
|
Completion string `json:"completion,omitempty"`
|
||||||
StopReason string `json:"stop_reason,omitempty"`
|
StopReason string `json:"stop_reason,omitempty"`
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
Error *types.ClaudeError `json:"error,omitempty"`
|
Error any `json:"error,omitempty"`
|
||||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||||
Index *int `json:"index,omitempty"`
|
Index *int `json:"index,omitempty"`
|
||||||
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
||||||
@@ -324,12 +464,48 @@ func (c *ClaudeResponse) GetIndex() int {
|
|||||||
return *c.Index
|
return *c.Index
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetClaudeError 从动态错误类型中提取ClaudeError结构
|
||||||
|
func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
|
||||||
|
if c.Error == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch err := c.Error.(type) {
|
||||||
|
case types.ClaudeError:
|
||||||
|
return &err
|
||||||
|
case *types.ClaudeError:
|
||||||
|
return err
|
||||||
|
case map[string]interface{}:
|
||||||
|
// 处理从JSON解析来的map结构
|
||||||
|
claudeErr := &types.ClaudeError{}
|
||||||
|
if errType, ok := err["type"].(string); ok {
|
||||||
|
claudeErr.Type = errType
|
||||||
|
}
|
||||||
|
if errMsg, ok := err["message"].(string); ok {
|
||||||
|
claudeErr.Message = errMsg
|
||||||
|
}
|
||||||
|
return claudeErr
|
||||||
|
case string:
|
||||||
|
// 处理简单字符串错误
|
||||||
|
return &types.ClaudeError{
|
||||||
|
Type: "upstream_error",
|
||||||
|
Message: err,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// 未知类型,尝试转换为字符串
|
||||||
|
return &types.ClaudeError{
|
||||||
|
Type: "unknown_upstream_error",
|
||||||
|
Message: fmt.Sprintf("unknown_error: %v", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type ClaudeUsage struct {
|
type ClaudeUsage struct {
|
||||||
InputTokens int `json:"input_tokens"`
|
InputTokens int `json:"input_tokens"`
|
||||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use"`
|
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClaudeServerToolUse struct {
|
type ClaudeServerToolUse struct {
|
||||||
|
|||||||
29
dto/dalle.go
29
dto/dalle.go
@@ -1,29 +0,0 @@
|
|||||||
package dto
|
|
||||||
|
|
||||||
import "encoding/json"
|
|
||||||
|
|
||||||
type ImageRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Prompt string `json:"prompt" binding:"required"`
|
|
||||||
N int `json:"n,omitempty"`
|
|
||||||
Size string `json:"size,omitempty"`
|
|
||||||
Quality string `json:"quality,omitempty"`
|
|
||||||
ResponseFormat string `json:"response_format,omitempty"`
|
|
||||||
Style string `json:"style,omitempty"`
|
|
||||||
User string `json:"user,omitempty"`
|
|
||||||
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
|
|
||||||
Background string `json:"background,omitempty"`
|
|
||||||
Moderation string `json:"moderation,omitempty"`
|
|
||||||
OutputFormat string `json:"output_format,omitempty"`
|
|
||||||
Watermark *bool `json:"watermark,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ImageResponse struct {
|
|
||||||
Data []ImageData `json:"data"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
}
|
|
||||||
type ImageData struct {
|
|
||||||
Url string `json:"url"`
|
|
||||||
B64Json string `json:"b64_json"`
|
|
||||||
RevisedPrompt string `json:"revised_prompt"`
|
|
||||||
}
|
|
||||||
@@ -1,5 +1,12 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
type EmbeddingOptions struct {
|
type EmbeddingOptions struct {
|
||||||
Seed int `json:"seed,omitempty"`
|
Seed int `json:"seed,omitempty"`
|
||||||
Temperature *float64 `json:"temperature,omitempty"`
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
@@ -24,9 +31,32 @@ type EmbeddingRequest struct {
|
|||||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r EmbeddingRequest) ParseInput() []string {
|
func (r *EmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var texts = make([]string, 0)
|
||||||
|
|
||||||
|
inputs := r.ParseInput()
|
||||||
|
for _, input := range inputs {
|
||||||
|
texts = append(texts, input)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: strings.Join(texts, "\n"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *EmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *EmbeddingRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *EmbeddingRequest) ParseInput() []string {
|
||||||
if r.Input == nil {
|
if r.Input == nil {
|
||||||
return nil
|
return make([]string, 0)
|
||||||
}
|
}
|
||||||
var input []string
|
var input []string
|
||||||
switch r.Input.(type) {
|
switch r.Input.(type) {
|
||||||
|
|||||||
@@ -1,15 +1,117 @@
|
|||||||
package gemini
|
package dto
|
||||||
|
|
||||||
import "encoding/json"
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/logger"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
type GeminiChatRequest struct {
|
type GeminiChatRequest struct {
|
||||||
Contents []GeminiChatContent `json:"contents"`
|
Contents []GeminiChatContent `json:"contents"`
|
||||||
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
||||||
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||||
Tools []GeminiChatTool `json:"tools,omitempty"`
|
Tools json.RawMessage `json:"tools,omitempty"`
|
||||||
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
|
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var files []*types.FileMeta = make([]*types.FileMeta, 0)
|
||||||
|
|
||||||
|
var maxTokens int
|
||||||
|
|
||||||
|
if r.GenerationConfig.MaxOutputTokens > 0 {
|
||||||
|
maxTokens = int(r.GenerationConfig.MaxOutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
var inputTexts []string
|
||||||
|
for _, content := range r.Contents {
|
||||||
|
for _, part := range content.Parts {
|
||||||
|
if part.Text != "" {
|
||||||
|
inputTexts = append(inputTexts, part.Text)
|
||||||
|
}
|
||||||
|
if part.InlineData != nil && part.InlineData.Data != "" {
|
||||||
|
if strings.HasPrefix(part.InlineData.MimeType, "image/") {
|
||||||
|
files = append(files, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeImage,
|
||||||
|
OriginData: part.InlineData.Data,
|
||||||
|
})
|
||||||
|
} else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
|
||||||
|
files = append(files, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeAudio,
|
||||||
|
OriginData: part.InlineData.Data,
|
||||||
|
})
|
||||||
|
} else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
|
||||||
|
files = append(files, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeVideo,
|
||||||
|
OriginData: part.InlineData.Data,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
files = append(files, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeFile,
|
||||||
|
OriginData: part.InlineData.Data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputText := strings.Join(inputTexts, "\n")
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: inputText,
|
||||||
|
Files: files,
|
||||||
|
MaxTokens: maxTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
|
||||||
|
if c.Query("alt") == "sse" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) SetModelName(modelName string) {
|
||||||
|
// GeminiChatRequest does not have a model field, so this method does nothing.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
|
||||||
|
var tools []GeminiChatTool
|
||||||
|
if strings.HasSuffix(string(r.Tools), "[") {
|
||||||
|
// is array
|
||||||
|
if err := common.Unmarshal(r.Tools, &tools); err != nil {
|
||||||
|
logger.LogError(nil, "error_unmarshalling_tools: "+err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(string(r.Tools), "{") {
|
||||||
|
// is object
|
||||||
|
singleTool := GeminiChatTool{}
|
||||||
|
if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
|
||||||
|
logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tools = []GeminiChatTool{singleTool}
|
||||||
|
}
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
|
||||||
|
if len(tools) == 0 {
|
||||||
|
r.Tools = json.RawMessage("[]")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal the tools to JSON
|
||||||
|
data, err := common.Marshal(tools)
|
||||||
|
if err != nil {
|
||||||
|
logger.LogError(nil, "error_marshalling_tools: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.Tools = data
|
||||||
|
}
|
||||||
|
|
||||||
type GeminiThinkingConfig struct {
|
type GeminiThinkingConfig struct {
|
||||||
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||||
@@ -32,7 +134,7 @@ func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
|
|||||||
MimeTypeSnake string `json:"mime_type"`
|
MimeTypeSnake string `json:"mime_type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(data, &aux); err != nil {
|
if err := common.Unmarshal(data, &aux); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,7 +155,7 @@ type FunctionCall struct {
|
|||||||
Arguments any `json:"args"`
|
Arguments any `json:"args"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type FunctionResponse struct {
|
type GeminiFunctionResponse struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Response map[string]interface{} `json:"response"`
|
Response map[string]interface{} `json:"response"`
|
||||||
}
|
}
|
||||||
@@ -78,7 +180,7 @@ type GeminiPart struct {
|
|||||||
Thought bool `json:"thought,omitempty"`
|
Thought bool `json:"thought,omitempty"`
|
||||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||||
FileData *GeminiFileData `json:"fileData,omitempty"`
|
FileData *GeminiFileData `json:"fileData,omitempty"`
|
||||||
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
||||||
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
||||||
@@ -93,7 +195,7 @@ func (p *GeminiPart) UnmarshalJSON(data []byte) error {
|
|||||||
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
|
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(data, &aux); err != nil {
|
if err := common.Unmarshal(data, &aux); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,16 +309,76 @@ type GeminiImagePrediction struct {
|
|||||||
|
|
||||||
// Embedding related structs
|
// Embedding related structs
|
||||||
type GeminiEmbeddingRequest struct {
|
type GeminiEmbeddingRequest struct {
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
Content GeminiChatContent `json:"content"`
|
Content GeminiChatContent `json:"content"`
|
||||||
TaskType string `json:"taskType,omitempty"`
|
TaskType string `json:"taskType,omitempty"`
|
||||||
Title string `json:"title,omitempty"`
|
Title string `json:"title,omitempty"`
|
||||||
OutputDimensionality int `json:"outputDimensionality,omitempty"`
|
OutputDimensionality int `json:"outputDimensionality,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *GeminiEmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||||
|
// Gemini embedding requests are not streamed
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var inputTexts []string
|
||||||
|
for _, part := range r.Content.Parts {
|
||||||
|
if part.Text != "" {
|
||||||
|
inputTexts = append(inputTexts, part.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inputText := strings.Join(inputTexts, "\n")
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: inputText,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiEmbeddingRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiBatchEmbeddingRequest struct {
|
||||||
|
Requests []*GeminiEmbeddingRequest `json:"requests"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiBatchEmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||||
|
// Gemini batch embedding requests are not streamed
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiBatchEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var inputTexts []string
|
||||||
|
for _, request := range r.Requests {
|
||||||
|
meta := request.GetTokenCountMeta()
|
||||||
|
if meta != nil && meta.CombineText != "" {
|
||||||
|
inputTexts = append(inputTexts, meta.CombineText)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inputText := strings.Join(inputTexts, "\n")
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: inputText,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiBatchEmbeddingRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
for _, req := range r.Requests {
|
||||||
|
req.SetModelName(modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type GeminiEmbeddingResponse struct {
|
type GeminiEmbeddingResponse struct {
|
||||||
Embedding ContentEmbedding `json:"embedding"`
|
Embedding ContentEmbedding `json:"embedding"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GeminiBatchEmbeddingResponse struct {
|
||||||
|
Embeddings []*ContentEmbedding `json:"embeddings"`
|
||||||
|
}
|
||||||
|
|
||||||
type ContentEmbedding struct {
|
type ContentEmbedding struct {
|
||||||
Values []float64 `json:"values"`
|
Values []float64 `json:"values"`
|
||||||
}
|
}
|
||||||
147
dto/openai_image.go
Normal file
147
dto/openai_image.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ImageRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt" binding:"required"`
|
||||||
|
N uint `json:"n,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
Quality string `json:"quality,omitempty"`
|
||||||
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
|
Style json.RawMessage `json:"style,omitempty"`
|
||||||
|
User json.RawMessage `json:"user,omitempty"`
|
||||||
|
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
|
||||||
|
Background json.RawMessage `json:"background,omitempty"`
|
||||||
|
Moderation json.RawMessage `json:"moderation,omitempty"`
|
||||||
|
OutputFormat json.RawMessage `json:"output_format,omitempty"`
|
||||||
|
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
|
||||||
|
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
||||||
|
// Stream bool `json:"stream,omitempty"`
|
||||||
|
Watermark *bool `json:"watermark,omitempty"`
|
||||||
|
// 用匿名参数接收额外参数
|
||||||
|
Extra map[string]json.RawMessage `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ImageRequest) UnmarshalJSON(data []byte) error {
|
||||||
|
// 先解析成 map[string]interface{}
|
||||||
|
var rawMap map[string]json.RawMessage
|
||||||
|
if err := common.Unmarshal(data, &rawMap); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用 struct tag 获取所有已定义字段名
|
||||||
|
knownFields := GetJSONFieldNames(reflect.TypeOf(*i))
|
||||||
|
|
||||||
|
// 再正常解析已定义字段
|
||||||
|
type Alias ImageRequest
|
||||||
|
var known Alias
|
||||||
|
if err := common.Unmarshal(data, &known); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*i = ImageRequest(known)
|
||||||
|
|
||||||
|
// 提取多余字段
|
||||||
|
i.Extra = make(map[string]json.RawMessage)
|
||||||
|
for k, v := range rawMap {
|
||||||
|
if _, ok := knownFields[k]; !ok {
|
||||||
|
i.Extra[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
|
||||||
|
fields := make(map[string]struct{})
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
field := t.Field(i)
|
||||||
|
|
||||||
|
// 跳过匿名字段(例如 ExtraFields)
|
||||||
|
if field.Anonymous {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tag := field.Tag.Get("json")
|
||||||
|
if tag == "-" || tag == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 取逗号前字段名(排除 omitempty 等)
|
||||||
|
name := tag
|
||||||
|
if commaIdx := indexComma(tag); commaIdx != -1 {
|
||||||
|
name = tag[:commaIdx]
|
||||||
|
}
|
||||||
|
fields[name] = struct{}{}
|
||||||
|
}
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexComma(s string) int {
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
if s[i] == ',' {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var sizeRatio = 1.0
|
||||||
|
var qualityRatio = 1.0
|
||||||
|
|
||||||
|
if strings.HasPrefix(i.Model, "dall-e") {
|
||||||
|
// Size
|
||||||
|
if i.Size == "256x256" {
|
||||||
|
sizeRatio = 0.4
|
||||||
|
} else if i.Size == "512x512" {
|
||||||
|
sizeRatio = 0.45
|
||||||
|
} else if i.Size == "1024x1024" {
|
||||||
|
sizeRatio = 1
|
||||||
|
} else if i.Size == "1024x1792" || i.Size == "1792x1024" {
|
||||||
|
sizeRatio = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if i.Model == "dall-e-3" && i.Quality == "hd" {
|
||||||
|
qualityRatio = 2.0
|
||||||
|
if i.Size == "1024x1792" || i.Size == "1792x1024" {
|
||||||
|
qualityRatio = 1.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// not support token count for dalle
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: i.Prompt,
|
||||||
|
MaxTokens: 1584,
|
||||||
|
ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ImageRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ImageRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
i.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageResponse struct {
|
||||||
|
Data []ImageData `json:"data"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Extra any `json:"extra,omitempty"`
|
||||||
|
}
|
||||||
|
type ImageData struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
B64Json string `json:"b64_json"`
|
||||||
|
RevisedPrompt string `json:"revised_prompt"`
|
||||||
|
}
|
||||||
@@ -2,20 +2,24 @@ package dto
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ResponseFormat struct {
|
type ResponseFormat struct {
|
||||||
Type string `json:"type,omitempty"`
|
Type string `json:"type,omitempty"`
|
||||||
JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
|
JsonSchema json.RawMessage `json:"json_schema,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type FormatJsonSchema struct {
|
type FormatJsonSchema struct {
|
||||||
Description string `json:"description,omitempty"`
|
Description string `json:"description,omitempty"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Schema any `json:"schema,omitempty"`
|
Schema any `json:"schema,omitempty"`
|
||||||
Strict any `json:"strict,omitempty"`
|
Strict json.RawMessage `json:"strict,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeneralOpenAIRequest struct {
|
type GeneralOpenAIRequest struct {
|
||||||
@@ -29,6 +33,7 @@ type GeneralOpenAIRequest struct {
|
|||||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||||
|
Verbosity json.RawMessage `json:"verbosity,omitempty"` // gpt-5
|
||||||
Temperature *float64 `json:"temperature,omitempty"`
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
TopK int `json:"top_k,omitempty"`
|
TopK int `json:"top_k,omitempty"`
|
||||||
@@ -52,16 +57,142 @@ type GeneralOpenAIRequest struct {
|
|||||||
Dimensions int `json:"dimensions,omitempty"`
|
Dimensions int `json:"dimensions,omitempty"`
|
||||||
Modalities json.RawMessage `json:"modalities,omitempty"`
|
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||||
Audio json.RawMessage `json:"audio,omitempty"`
|
Audio json.RawMessage `json:"audio,omitempty"`
|
||||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
// gemini
|
||||||
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
|
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
//xai
|
||||||
SearchParameters any `json:"search_parameters,omitempty"` //xai
|
SearchParameters json.RawMessage `json:"search_parameters,omitempty"`
|
||||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
// claude
|
||||||
|
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||||
// OpenRouter Params
|
// OpenRouter Params
|
||||||
Usage json.RawMessage `json:"usage,omitempty"`
|
Usage json.RawMessage `json:"usage,omitempty"`
|
||||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||||
// Ali Qwen Params
|
// Ali Qwen Params
|
||||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||||
|
EnableThinking any `json:"enable_thinking,omitempty"`
|
||||||
|
// ollama Params
|
||||||
|
Think json.RawMessage `json:"think,omitempty"`
|
||||||
|
// baidu v2
|
||||||
|
WebSearch json.RawMessage `json:"web_search,omitempty"`
|
||||||
|
// doubao,zhipu_v4
|
||||||
|
THINKING json.RawMessage `json:"thinking,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var tokenCountMeta types.TokenCountMeta
|
||||||
|
var texts = make([]string, 0)
|
||||||
|
var fileMeta = make([]*types.FileMeta, 0)
|
||||||
|
|
||||||
|
if r.Prompt != nil {
|
||||||
|
switch v := r.Prompt.(type) {
|
||||||
|
case string:
|
||||||
|
texts = append(texts, v)
|
||||||
|
case []any:
|
||||||
|
for _, item := range v {
|
||||||
|
if str, ok := item.(string); ok {
|
||||||
|
texts = append(texts, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
texts = append(texts, fmt.Sprintf("%v", r.Prompt))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Input != nil {
|
||||||
|
inputs := r.ParseInput()
|
||||||
|
texts = append(texts, inputs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.MaxCompletionTokens > r.MaxTokens {
|
||||||
|
tokenCountMeta.MaxTokens = int(r.MaxCompletionTokens)
|
||||||
|
} else {
|
||||||
|
tokenCountMeta.MaxTokens = int(r.MaxTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, message := range r.Messages {
|
||||||
|
tokenCountMeta.MessagesCount++
|
||||||
|
texts = append(texts, message.Role)
|
||||||
|
if message.Content != nil {
|
||||||
|
if message.Name != nil {
|
||||||
|
tokenCountMeta.NameCount++
|
||||||
|
texts = append(texts, *message.Name)
|
||||||
|
}
|
||||||
|
arrayContent := message.ParseContent()
|
||||||
|
for _, m := range arrayContent {
|
||||||
|
if m.Type == ContentTypeImageURL {
|
||||||
|
imageUrl := m.GetImageMedia()
|
||||||
|
if imageUrl != nil {
|
||||||
|
if imageUrl.Url != "" {
|
||||||
|
meta := &types.FileMeta{
|
||||||
|
FileType: types.FileTypeImage,
|
||||||
|
}
|
||||||
|
meta.OriginData = imageUrl.Url
|
||||||
|
meta.Detail = imageUrl.Detail
|
||||||
|
fileMeta = append(fileMeta, meta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if m.Type == ContentTypeInputAudio {
|
||||||
|
inputAudio := m.GetInputAudio()
|
||||||
|
if inputAudio != nil {
|
||||||
|
meta := &types.FileMeta{
|
||||||
|
FileType: types.FileTypeAudio,
|
||||||
|
}
|
||||||
|
meta.OriginData = inputAudio.Data
|
||||||
|
fileMeta = append(fileMeta, meta)
|
||||||
|
}
|
||||||
|
} else if m.Type == ContentTypeFile {
|
||||||
|
file := m.GetFile()
|
||||||
|
if file != nil {
|
||||||
|
meta := &types.FileMeta{
|
||||||
|
FileType: types.FileTypeFile,
|
||||||
|
}
|
||||||
|
meta.OriginData = file.FileData
|
||||||
|
fileMeta = append(fileMeta, meta)
|
||||||
|
}
|
||||||
|
} else if m.Type == ContentTypeVideoUrl {
|
||||||
|
videoUrl := m.GetVideoUrl()
|
||||||
|
if videoUrl != nil && videoUrl.Url != "" {
|
||||||
|
meta := &types.FileMeta{
|
||||||
|
FileType: types.FileTypeVideo,
|
||||||
|
}
|
||||||
|
meta.OriginData = videoUrl.Url
|
||||||
|
fileMeta = append(fileMeta, meta)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
texts = append(texts, m.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Tools != nil {
|
||||||
|
openaiTools := r.Tools
|
||||||
|
for _, tool := range openaiTools {
|
||||||
|
tokenCountMeta.ToolsCount++
|
||||||
|
texts = append(texts, tool.Function.Name)
|
||||||
|
if tool.Function.Description != "" {
|
||||||
|
texts = append(texts, tool.Function.Description)
|
||||||
|
}
|
||||||
|
if tool.Function.Parameters != nil {
|
||||||
|
texts = append(texts, fmt.Sprintf("%v", tool.Function.Parameters))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//toolTokens := CountTokenInput(countStr, request.Model)
|
||||||
|
//tkm += 8
|
||||||
|
//tkm += toolTokens
|
||||||
|
}
|
||||||
|
tokenCountMeta.CombineText = strings.Join(texts, "\n")
|
||||||
|
tokenCountMeta.Files = fileMeta
|
||||||
|
return &tokenCountMeta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return r.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeneralOpenAIRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||||
@@ -71,6 +202,17 @@ func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
|
||||||
|
if strings.HasPrefix(r.Model, "o") {
|
||||||
|
if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
|
||||||
|
return "developer"
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(r.Model, "gpt-5") {
|
||||||
|
return "developer"
|
||||||
|
}
|
||||||
|
return "system"
|
||||||
|
}
|
||||||
|
|
||||||
type ToolCallRequest struct {
|
type ToolCallRequest struct {
|
||||||
ID string `json:"id,omitempty"`
|
ID string `json:"id,omitempty"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
@@ -88,8 +230,11 @@ type StreamOptions struct {
|
|||||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *GeneralOpenAIRequest) GetMaxTokens() int {
|
func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
|
||||||
return int(r.MaxTokens)
|
if r.MaxCompletionTokens != 0 {
|
||||||
|
return r.MaxCompletionTokens
|
||||||
|
}
|
||||||
|
return r.MaxTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *GeneralOpenAIRequest) ParseInput() []string {
|
func (r *GeneralOpenAIRequest) ParseInput() []string {
|
||||||
@@ -185,6 +330,21 @@ func (m *MediaContent) GetFile() *MessageFile {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MediaContent) GetVideoUrl() *MessageVideoUrl {
|
||||||
|
if m.VideoUrl != nil {
|
||||||
|
if _, ok := m.VideoUrl.(*MessageVideoUrl); ok {
|
||||||
|
return m.VideoUrl.(*MessageVideoUrl)
|
||||||
|
}
|
||||||
|
if itemMap, ok := m.VideoUrl.(map[string]any); ok {
|
||||||
|
out := &MessageVideoUrl{
|
||||||
|
Url: common.Interface2String(itemMap["url"]),
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type MessageImageUrl struct {
|
type MessageImageUrl struct {
|
||||||
Url string `json:"url"`
|
Url string `json:"url"`
|
||||||
Detail string `json:"detail"`
|
Detail string `json:"detail"`
|
||||||
@@ -216,6 +376,7 @@ const (
|
|||||||
ContentTypeInputAudio = "input_audio"
|
ContentTypeInputAudio = "input_audio"
|
||||||
ContentTypeFile = "file"
|
ContentTypeFile = "file"
|
||||||
ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别
|
ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别
|
||||||
|
//ContentTypeAudioUrl = "audio_url"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (m *Message) GetPrefix() bool {
|
func (m *Message) GetPrefix() bool {
|
||||||
@@ -605,27 +766,104 @@ type WebSearchOptions struct {
|
|||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/responses/create
|
// https://platform.openai.com/docs/api-reference/responses/create
|
||||||
type OpenAIResponsesRequest struct {
|
type OpenAIResponsesRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Input json.RawMessage `json:"input,omitempty"`
|
Input json.RawMessage `json:"input,omitempty"`
|
||||||
Include json.RawMessage `json:"include,omitempty"`
|
Include json.RawMessage `json:"include,omitempty"`
|
||||||
Instructions json.RawMessage `json:"instructions,omitempty"`
|
Instructions json.RawMessage `json:"instructions,omitempty"`
|
||||||
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
||||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||||
ServiceTier string `json:"service_tier,omitempty"`
|
ServiceTier string `json:"service_tier,omitempty"`
|
||||||
Store bool `json:"store,omitempty"`
|
Store bool `json:"store,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
Text json.RawMessage `json:"text,omitempty"`
|
Text json.RawMessage `json:"text,omitempty"`
|
||||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||||
Tools []map[string]any `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
|
Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
Truncation string `json:"truncation,omitempty"`
|
Truncation string `json:"truncation,omitempty"`
|
||||||
User string `json:"user,omitempty"`
|
User string `json:"user,omitempty"`
|
||||||
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
|
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
|
||||||
Prompt json.RawMessage `json:"prompt,omitempty"`
|
Prompt json.RawMessage `json:"prompt,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var fileMeta = make([]*types.FileMeta, 0)
|
||||||
|
var texts = make([]string, 0)
|
||||||
|
|
||||||
|
if r.Input != nil {
|
||||||
|
inputs := r.ParseInput()
|
||||||
|
for _, input := range inputs {
|
||||||
|
if input.Type == "input_image" {
|
||||||
|
if input.ImageUrl != "" {
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeImage,
|
||||||
|
OriginData: input.ImageUrl,
|
||||||
|
Detail: input.Detail,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else if input.Type == "input_file" {
|
||||||
|
if input.FileUrl != "" {
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeFile,
|
||||||
|
OriginData: input.FileUrl,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
texts = append(texts, input.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Instructions) > 0 {
|
||||||
|
texts = append(texts, string(r.Instructions))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Metadata) > 0 {
|
||||||
|
texts = append(texts, string(r.Metadata))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Text) > 0 {
|
||||||
|
texts = append(texts, string(r.Text))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.ToolChoice) > 0 {
|
||||||
|
texts = append(texts, string(r.ToolChoice))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Prompt) > 0 {
|
||||||
|
texts = append(texts, string(r.Prompt))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Tools) > 0 {
|
||||||
|
texts = append(texts, string(r.Tools))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: strings.Join(texts, "\n"),
|
||||||
|
Files: fileMeta,
|
||||||
|
MaxTokens: int(r.MaxOutputTokens),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return r.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OpenAIResponsesRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OpenAIResponsesRequest) GetToolsMap() []map[string]any {
|
||||||
|
var toolsMap []map[string]any
|
||||||
|
if len(r.Tools) > 0 {
|
||||||
|
_ = common.Unmarshal(r.Tools, &toolsMap)
|
||||||
|
}
|
||||||
|
return toolsMap
|
||||||
}
|
}
|
||||||
|
|
||||||
type Reasoning struct {
|
type Reasoning struct {
|
||||||
@@ -633,23 +871,88 @@ type Reasoning struct {
|
|||||||
Summary string `json:"summary,omitempty"`
|
Summary string `json:"summary,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
//type ResponsesToolsCall struct {
|
type MediaInput struct {
|
||||||
// Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
// // Web Search
|
Text string `json:"text,omitempty"`
|
||||||
// UserLocation json.RawMessage `json:"user_location,omitempty"`
|
FileUrl string `json:"file_url,omitempty"`
|
||||||
// SearchContextSize string `json:"search_context_size,omitempty"`
|
ImageUrl string `json:"image_url,omitempty"`
|
||||||
// // File Search
|
Detail string `json:"detail,omitempty"` // 仅 input_image 有效
|
||||||
// VectorStoreIds []string `json:"vector_store_ids,omitempty"`
|
}
|
||||||
// MaxNumResults uint `json:"max_num_results,omitempty"`
|
|
||||||
// Filters json.RawMessage `json:"filters,omitempty"`
|
// ParseInput parses the Responses API `input` field into a normalized slice of MediaInput.
|
||||||
// // Computer Use
|
// Reference implementation mirrors Message.ParseContent:
|
||||||
// DisplayWidth uint `json:"display_width,omitempty"`
|
// - input can be a string, treated as an input_text item
|
||||||
// DisplayHeight uint `json:"display_height,omitempty"`
|
// - input can be an array of objects with a `type` field
|
||||||
// Environment string `json:"environment,omitempty"`
|
// supported types: input_text, input_image, input_file
|
||||||
// // Function
|
func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
|
||||||
// Name string `json:"name,omitempty"`
|
if r.Input == nil {
|
||||||
// Description string `json:"description,omitempty"`
|
return nil
|
||||||
// Parameters json.RawMessage `json:"parameters,omitempty"`
|
}
|
||||||
// Function json.RawMessage `json:"function,omitempty"`
|
|
||||||
// Container json.RawMessage `json:"container,omitempty"`
|
var inputs []MediaInput
|
||||||
//}
|
|
||||||
|
// Try string first
|
||||||
|
// if str, ok := common.GetJsonType(r.Input); ok {
|
||||||
|
// inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
|
||||||
|
// return inputs
|
||||||
|
// }
|
||||||
|
if common.GetJsonType(r.Input) == "string" {
|
||||||
|
var str string
|
||||||
|
_ = common.Unmarshal(r.Input, &str)
|
||||||
|
inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
|
||||||
|
return inputs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try array of parts
|
||||||
|
if common.GetJsonType(r.Input) == "array" {
|
||||||
|
var array []any
|
||||||
|
_ = common.Unmarshal(r.Input, &array)
|
||||||
|
for _, itemAny := range array {
|
||||||
|
// Already parsed MediaInput
|
||||||
|
if media, ok := itemAny.(MediaInput); ok {
|
||||||
|
inputs = append(inputs, media)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Generic map
|
||||||
|
item, ok := itemAny.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
typeVal, ok := item["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch typeVal {
|
||||||
|
case "input_text":
|
||||||
|
text, _ := item["text"].(string)
|
||||||
|
inputs = append(inputs, MediaInput{Type: "input_text", Text: text})
|
||||||
|
case "input_image":
|
||||||
|
// image_url may be string or object with url field
|
||||||
|
var imageUrl string
|
||||||
|
switch v := item["image_url"].(type) {
|
||||||
|
case string:
|
||||||
|
imageUrl = v
|
||||||
|
case map[string]any:
|
||||||
|
if url, ok := v["url"].(string); ok {
|
||||||
|
imageUrl = url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
|
||||||
|
case "input_file":
|
||||||
|
// file_url may be string or object with url field
|
||||||
|
var fileUrl string
|
||||||
|
switch v := item["file_url"].(type) {
|
||||||
|
case string:
|
||||||
|
fileUrl = v
|
||||||
|
case map[string]any:
|
||||||
|
if url, ok := v["url"].(string); ok {
|
||||||
|
fileUrl = url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,12 +2,18 @@ package dto
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SimpleResponse struct {
|
type SimpleResponse struct {
|
||||||
Usage `json:"usage"`
|
Usage `json:"usage"`
|
||||||
Error *OpenAIError `json:"error"`
|
Error any `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||||
|
func (s *SimpleResponse) GetOpenAIError() *types.OpenAIError {
|
||||||
|
return GetOpenAIError(s.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextResponse struct {
|
type TextResponse struct {
|
||||||
@@ -31,10 +37,15 @@ type OpenAITextResponse struct {
|
|||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Created any `json:"created"`
|
Created any `json:"created"`
|
||||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||||
Error *types.OpenAIError `json:"error,omitempty"`
|
Error any `json:"error,omitempty"`
|
||||||
Usage `json:"usage"`
|
Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||||
|
func (o *OpenAITextResponse) GetOpenAIError() *types.OpenAIError {
|
||||||
|
return GetOpenAIError(o.Error)
|
||||||
|
}
|
||||||
|
|
||||||
type OpenAIEmbeddingResponseItem struct {
|
type OpenAIEmbeddingResponseItem struct {
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
@@ -99,7 +110,7 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string
|
|||||||
|
|
||||||
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
|
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
|
||||||
c.ReasoningContent = &s
|
c.ReasoningContent = &s
|
||||||
c.Reasoning = &s
|
//c.Reasoning = &s
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolCallResponse struct {
|
type ToolCallResponse struct {
|
||||||
@@ -132,6 +143,13 @@ type ChatCompletionsStreamResponse struct {
|
|||||||
Usage *Usage `json:"usage"`
|
Usage *Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ChatCompletionsStreamResponse) IsFinished() bool {
|
||||||
|
if len(c.Choices) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return c.Choices[0].FinishReason != nil && *c.Choices[0].FinishReason != ""
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
|
func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
|
||||||
if len(c.Choices) == 0 {
|
if len(c.Choices) == 0 {
|
||||||
return false
|
return false
|
||||||
@@ -146,6 +164,19 @@ func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ChatCompletionsStreamResponse) ClearToolCalls() {
|
||||||
|
if !c.IsToolCall() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for choiceIdx := range c.Choices {
|
||||||
|
for callIdx := range c.Choices[choiceIdx].Delta.ToolCalls {
|
||||||
|
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].ID = ""
|
||||||
|
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Type = nil
|
||||||
|
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Function.Name = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
|
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
|
||||||
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
|
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
|
||||||
copy(choices, c.Choices)
|
copy(choices, c.Choices)
|
||||||
@@ -217,7 +248,7 @@ type OpenAIResponsesResponse struct {
|
|||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
CreatedAt int `json:"created_at"`
|
CreatedAt int `json:"created_at"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Error *types.OpenAIError `json:"error,omitempty"`
|
Error any `json:"error,omitempty"`
|
||||||
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
||||||
Instructions string `json:"instructions"`
|
Instructions string `json:"instructions"`
|
||||||
MaxOutputTokens int `json:"max_output_tokens"`
|
MaxOutputTokens int `json:"max_output_tokens"`
|
||||||
@@ -237,6 +268,11 @@ type OpenAIResponsesResponse struct {
|
|||||||
Metadata json.RawMessage `json:"metadata"`
|
Metadata json.RawMessage `json:"metadata"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||||
|
func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
|
||||||
|
return GetOpenAIError(o.Error)
|
||||||
|
}
|
||||||
|
|
||||||
type IncompleteDetails struct {
|
type IncompleteDetails struct {
|
||||||
Reasoning string `json:"reasoning"`
|
Reasoning string `json:"reasoning"`
|
||||||
}
|
}
|
||||||
@@ -276,3 +312,45 @@ type ResponsesStreamResponse struct {
|
|||||||
Delta string `json:"delta,omitempty"`
|
Delta string `json:"delta,omitempty"`
|
||||||
Item *ResponsesOutput `json:"item,omitempty"`
|
Item *ResponsesOutput `json:"item,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||||
|
func GetOpenAIError(errorField any) *types.OpenAIError {
|
||||||
|
if errorField == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch err := errorField.(type) {
|
||||||
|
case types.OpenAIError:
|
||||||
|
return &err
|
||||||
|
case *types.OpenAIError:
|
||||||
|
return err
|
||||||
|
case map[string]interface{}:
|
||||||
|
// 处理从JSON解析来的map结构
|
||||||
|
openaiErr := &types.OpenAIError{}
|
||||||
|
if errType, ok := err["type"].(string); ok {
|
||||||
|
openaiErr.Type = errType
|
||||||
|
}
|
||||||
|
if errMsg, ok := err["message"].(string); ok {
|
||||||
|
openaiErr.Message = errMsg
|
||||||
|
}
|
||||||
|
if errParam, ok := err["param"].(string); ok {
|
||||||
|
openaiErr.Param = errParam
|
||||||
|
}
|
||||||
|
if errCode, ok := err["code"]; ok {
|
||||||
|
openaiErr.Code = errCode
|
||||||
|
}
|
||||||
|
return openaiErr
|
||||||
|
case string:
|
||||||
|
// 处理简单字符串错误
|
||||||
|
return &types.OpenAIError{
|
||||||
|
Type: "error",
|
||||||
|
Message: err,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// 未知类型,尝试转换为字符串
|
||||||
|
return &types.OpenAIError{
|
||||||
|
Type: "unknown_error",
|
||||||
|
Message: fmt.Sprintf("%v", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package dto
|
|||||||
|
|
||||||
import "one-api/constant"
|
import "one-api/constant"
|
||||||
|
|
||||||
|
// 这里不好动就不动了,本来想独立出来的(
|
||||||
type OpenAIModels struct {
|
type OpenAIModels struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
@@ -9,3 +10,26 @@ type OpenAIModels struct {
|
|||||||
OwnedBy string `json:"owned_by"`
|
OwnedBy string `json:"owned_by"`
|
||||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AnthropicModel struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
DisplayName string `json:"display_name"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiModel struct {
|
||||||
|
Name interface{} `json:"name"`
|
||||||
|
BaseModelId interface{} `json:"baseModelId"`
|
||||||
|
Version interface{} `json:"version"`
|
||||||
|
DisplayName interface{} `json:"displayName"`
|
||||||
|
Description interface{} `json:"description"`
|
||||||
|
InputTokenLimit interface{} `json:"inputTokenLimit"`
|
||||||
|
OutputTokenLimit interface{} `json:"outputTokenLimit"`
|
||||||
|
SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"`
|
||||||
|
Thinking interface{} `json:"thinking"`
|
||||||
|
Temperature interface{} `json:"temperature"`
|
||||||
|
MaxTemperature interface{} `json:"maxTemperature"`
|
||||||
|
TopP interface{} `json:"topP"`
|
||||||
|
TopK interface{} `json:"topK"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,23 +1,23 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
type UpstreamDTO struct {
|
type UpstreamDTO struct {
|
||||||
ID int `json:"id,omitempty"`
|
ID int `json:"id,omitempty"`
|
||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
BaseURL string `json:"base_url" binding:"required"`
|
BaseURL string `json:"base_url" binding:"required"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpstreamRequest struct {
|
type UpstreamRequest struct {
|
||||||
ChannelIDs []int64 `json:"channel_ids"`
|
ChannelIDs []int64 `json:"channel_ids"`
|
||||||
Upstreams []UpstreamDTO `json:"upstreams"`
|
Upstreams []UpstreamDTO `json:"upstreams"`
|
||||||
Timeout int `json:"timeout"`
|
Timeout int `json:"timeout"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestResult 上游测试连通性结果
|
// TestResult 上游测试连通性结果
|
||||||
type TestResult struct {
|
type TestResult struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Error string `json:"error,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// DifferenceItem 差异项
|
// DifferenceItem 差异项
|
||||||
@@ -25,14 +25,14 @@ type TestResult struct {
|
|||||||
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
|
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
|
||||||
|
|
||||||
type DifferenceItem struct {
|
type DifferenceItem struct {
|
||||||
Current interface{} `json:"current"`
|
Current interface{} `json:"current"`
|
||||||
Upstreams map[string]interface{} `json:"upstreams"`
|
Upstreams map[string]interface{} `json:"upstreams"`
|
||||||
Confidence map[string]bool `json:"confidence"`
|
Confidence map[string]bool `json:"confidence"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SyncableChannel struct {
|
type SyncableChannel struct {
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
BaseURL string `json:"base_url"`
|
BaseURL string `json:"base_url"`
|
||||||
Status int `json:"status"`
|
Status int `json:"status"`
|
||||||
}
|
}
|
||||||
|
|||||||
25
dto/request_common.go
Normal file
25
dto/request_common.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Request interface {
|
||||||
|
GetTokenCountMeta() *types.TokenCountMeta
|
||||||
|
IsStream(c *gin.Context) bool
|
||||||
|
SetModelName(modelName string)
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaseRequest struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BaseRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
TokenType: types.TokenTypeTokenizer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (b *BaseRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
func (b *BaseRequest) SetModelName(modelName string) {}
|
||||||
@@ -1,5 +1,12 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
type RerankRequest struct {
|
type RerankRequest struct {
|
||||||
Documents []any `json:"documents"`
|
Documents []any `json:"documents"`
|
||||||
Query string `json:"query"`
|
Query string `json:"query"`
|
||||||
@@ -10,6 +17,32 @@ type RerankRequest struct {
|
|||||||
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *RerankRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RerankRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var texts = make([]string, 0)
|
||||||
|
|
||||||
|
for _, document := range r.Documents {
|
||||||
|
texts = append(texts, fmt.Sprintf("%v", document))
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Query != "" {
|
||||||
|
texts = append(texts, r.Query)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: strings.Join(texts, "\n"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RerankRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *RerankRequest) GetReturnDocuments() bool {
|
func (r *RerankRequest) GetReturnDocuments() bool {
|
||||||
if r.ReturnDocuments == nil {
|
if r.ReturnDocuments == nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -6,11 +6,14 @@ type UserSetting struct {
|
|||||||
WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
|
WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
|
||||||
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
|
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
|
||||||
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
|
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
|
||||||
|
BarkUrl string `json:"bark_url,omitempty"` // BarkUrl Bark推送URL
|
||||||
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
||||||
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
|
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
|
||||||
|
SidebarModules string `json:"sidebar_modules,omitempty"` // SidebarModules 左侧边栏模块配置
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
NotifyTypeEmail = "email" // Email 邮件
|
NotifyTypeEmail = "email" // Email 邮件
|
||||||
NotifyTypeWebhook = "webhook" // Webhook
|
NotifyTypeWebhook = "webhook" // Webhook
|
||||||
|
NotifyTypeBark = "bark" // Bark 推送
|
||||||
)
|
)
|
||||||
|
|||||||
19
go.mod
19
go.mod
@@ -7,9 +7,10 @@ require (
|
|||||||
github.com/Calcium-Ion/go-epay v0.0.4
|
github.com/Calcium-Ion/go-epay v0.0.4
|
||||||
github.com/andybalholm/brotli v1.1.1
|
github.com/andybalholm/brotli v1.1.1
|
||||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
|
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
|
||||||
github.com/aws/aws-sdk-go-v2 v1.26.1
|
github.com/aws/aws-sdk-go-v2 v1.37.2
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0
|
||||||
|
github.com/aws/smithy-go v1.22.5
|
||||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
||||||
github.com/gin-contrib/cors v1.7.2
|
github.com/gin-contrib/cors v1.7.2
|
||||||
github.com/gin-contrib/gzip v0.0.6
|
github.com/gin-contrib/gzip v0.0.6
|
||||||
@@ -22,13 +23,17 @@ require (
|
|||||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/gorilla/websocket v1.5.0
|
github.com/gorilla/websocket v1.5.0
|
||||||
|
github.com/jinzhu/copier v0.4.0
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
|
github.com/pquerna/otp v1.5.0
|
||||||
github.com/samber/lo v1.39.0
|
github.com/samber/lo v1.39.0
|
||||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||||
github.com/shopspring/decimal v1.4.0
|
github.com/shopspring/decimal v1.4.0
|
||||||
github.com/stripe/stripe-go/v81 v81.4.0
|
github.com/stripe/stripe-go/v81 v81.4.0
|
||||||
github.com/thanhpk/randstr v1.0.6
|
github.com/thanhpk/randstr v1.0.6
|
||||||
|
github.com/tidwall/gjson v1.18.0
|
||||||
|
github.com/tidwall/sjson v1.2.5
|
||||||
github.com/tiktoken-go/tokenizer v0.6.2
|
github.com/tiktoken-go/tokenizer v0.6.2
|
||||||
golang.org/x/crypto v0.35.0
|
golang.org/x/crypto v0.35.0
|
||||||
golang.org/x/image v0.23.0
|
golang.org/x/image v0.23.0
|
||||||
@@ -41,10 +46,10 @@ require (
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
|
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect
|
||||||
github.com/aws/smithy-go v1.20.2 // indirect
|
github.com/boombuler/barcode v1.1.0 // indirect
|
||||||
github.com/bytedance/sonic v1.11.6 // indirect
|
github.com/bytedance/sonic v1.11.6 // indirect
|
||||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
@@ -80,6 +85,8 @@ require (
|
|||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
|
|||||||
40
go.sum
40
go.sum
@@ -6,20 +6,23 @@ github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+Kc
|
|||||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
||||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
||||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
|
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
|
github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
|
github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg=
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg=
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
|
||||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 h1:sPiRHLVUIIQcoVZTNwqQcdtjkqkPopyYmIX0M5ElRf4=
|
||||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2/go.mod h1:ik86P3sgV+Bk7c1tBFCwI3VxMoSEwl4YkRB9xn1s340=
|
||||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 h1:ZdzDAg075H6stMZtbD2o+PyB933M/f20e9WmCBC17wA=
|
||||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2/go.mod h1:eE1IIzXG9sdZCB0pNNpMpsYTLl4YdOQD3njiVN1e/E4=
|
||||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 h1:JzidOz4Hcn2RbP5fvIS1iAP+DcRv5VJtgixbEYDsI5g=
|
||||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0/go.mod h1:9A4/PJYlWjvjEzzoOLGQjkLt4bYK9fRWi7uz1GSsAcA=
|
||||||
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
|
github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw=
|
||||||
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
|
github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
|
||||||
|
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||||
|
github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo=
|
||||||
|
github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
|
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
|
||||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
|
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
|
||||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||||
@@ -117,6 +120,8 @@ github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
|
|||||||
github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
|
github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
|
||||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
|
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
|
||||||
|
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
@@ -169,6 +174,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
|||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
|
||||||
|
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
@@ -199,6 +206,15 @@ github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJ
|
|||||||
github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
|
github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
|
||||||
github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
|
github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
|
||||||
github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
|
github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
|
||||||
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||||
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
|
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||||
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
|
github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
|
||||||
github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
|
github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
|
||||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||||
|
|||||||
1041
i18n/zh-cn.json
1041
i18n/zh-cn.json
File diff suppressed because it is too large
Load Diff
@@ -1,23 +1,26 @@
|
|||||||
package common
|
package logger
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/bytedance/gopkg/util/gopool"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"one-api/common"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
loggerINFO = "INFO"
|
loggerINFO = "INFO"
|
||||||
loggerWarn = "WARN"
|
loggerWarn = "WARN"
|
||||||
loggerError = "ERR"
|
loggerError = "ERR"
|
||||||
|
loggerDebug = "DEBUG"
|
||||||
)
|
)
|
||||||
|
|
||||||
const maxLogCount = 1000000
|
const maxLogCount = 1000000
|
||||||
@@ -27,7 +30,10 @@ var setupLogLock sync.Mutex
|
|||||||
var setupLogWorking bool
|
var setupLogWorking bool
|
||||||
|
|
||||||
func SetupLogger() {
|
func SetupLogger() {
|
||||||
if *LogDir != "" {
|
defer func() {
|
||||||
|
setupLogWorking = false
|
||||||
|
}()
|
||||||
|
if *common.LogDir != "" {
|
||||||
ok := setupLogLock.TryLock()
|
ok := setupLogLock.TryLock()
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Println("setup log is already working")
|
log.Println("setup log is already working")
|
||||||
@@ -35,9 +41,8 @@ func SetupLogger() {
|
|||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
setupLogLock.Unlock()
|
setupLogLock.Unlock()
|
||||||
setupLogWorking = false
|
|
||||||
}()
|
}()
|
||||||
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
|
logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
|
||||||
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("failed to open log file")
|
log.Fatal("failed to open log file")
|
||||||
@@ -47,16 +52,6 @@ func SetupLogger() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SysLog(s string) {
|
|
||||||
t := time.Now()
|
|
||||||
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func SysError(s string) {
|
|
||||||
t := time.Now()
|
|
||||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogInfo(ctx context.Context, msg string) {
|
func LogInfo(ctx context.Context, msg string) {
|
||||||
logHelper(ctx, loggerINFO, msg)
|
logHelper(ctx, loggerINFO, msg)
|
||||||
}
|
}
|
||||||
@@ -69,12 +64,18 @@ func LogError(ctx context.Context, msg string) {
|
|||||||
logHelper(ctx, loggerError, msg)
|
logHelper(ctx, loggerError, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func LogDebug(ctx context.Context, msg string) {
|
||||||
|
if common.DebugEnabled {
|
||||||
|
logHelper(ctx, loggerDebug, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func logHelper(ctx context.Context, level string, msg string) {
|
func logHelper(ctx context.Context, level string, msg string) {
|
||||||
writer := gin.DefaultErrorWriter
|
writer := gin.DefaultErrorWriter
|
||||||
if level == loggerINFO {
|
if level == loggerINFO {
|
||||||
writer = gin.DefaultWriter
|
writer = gin.DefaultWriter
|
||||||
}
|
}
|
||||||
id := ctx.Value(RequestIdKey)
|
id := ctx.Value(common.RequestIdKey)
|
||||||
if id == nil {
|
if id == nil {
|
||||||
id = "SYSTEM"
|
id = "SYSTEM"
|
||||||
}
|
}
|
||||||
@@ -90,23 +91,17 @@ func logHelper(ctx context.Context, level string, msg string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func FatalLog(v ...any) {
|
|
||||||
t := time.Now()
|
|
||||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogQuota(quota int) string {
|
func LogQuota(quota int) string {
|
||||||
if DisplayInCurrencyEnabled {
|
if common.DisplayInCurrencyEnabled {
|
||||||
return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
|
return fmt.Sprintf("$%.6f 额度", float64(quota)/common.QuotaPerUnit)
|
||||||
} else {
|
} else {
|
||||||
return fmt.Sprintf("%d 点额度", quota)
|
return fmt.Sprintf("%d 点额度", quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func FormatQuota(quota int) string {
|
func FormatQuota(quota int) string {
|
||||||
if DisplayInCurrencyEnabled {
|
if common.DisplayInCurrencyEnabled {
|
||||||
return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
|
return fmt.Sprintf("$%.6f", float64(quota)/common.QuotaPerUnit)
|
||||||
} else {
|
} else {
|
||||||
return fmt.Sprintf("%d", quota)
|
return fmt.Sprintf("%d", quota)
|
||||||
}
|
}
|
||||||
9
main.go
9
main.go
@@ -8,6 +8,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/controller"
|
"one-api/controller"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/router"
|
"one-api/router"
|
||||||
@@ -60,13 +61,13 @@ func main() {
|
|||||||
}
|
}
|
||||||
if common.MemoryCacheEnabled {
|
if common.MemoryCacheEnabled {
|
||||||
common.SysLog("memory cache enabled")
|
common.SysLog("memory cache enabled")
|
||||||
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
common.SysLog(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
||||||
|
|
||||||
// Add panic recovery and retry for InitChannelCache
|
// Add panic recovery and retry for InitChannelCache
|
||||||
func() {
|
func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
|
common.SysLog(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
|
||||||
// Retry once
|
// Retry once
|
||||||
_, _, fixErr := model.FixAbility()
|
_, _, fixErr := model.FixAbility()
|
||||||
if fixErr != nil {
|
if fixErr != nil {
|
||||||
@@ -125,7 +126,7 @@ func main() {
|
|||||||
// Initialize HTTP server
|
// Initialize HTTP server
|
||||||
server := gin.New()
|
server := gin.New()
|
||||||
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
|
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
|
||||||
common.SysError(fmt.Sprintf("panic detected: %v", err))
|
common.SysLog(fmt.Sprintf("panic detected: %v", err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
||||||
@@ -171,7 +172,7 @@ func InitResources() error {
|
|||||||
// 加载环境变量
|
// 加载环境变量
|
||||||
common.InitEnv()
|
common.InitEnv()
|
||||||
|
|
||||||
common.SetupLogger()
|
logger.SetupLogger()
|
||||||
|
|
||||||
// Initialize model settings
|
// Initialize model settings
|
||||||
ratio_setting.InitRatioSettings()
|
ratio_setting.InitRatioSettings()
|
||||||
|
|||||||
@@ -4,7 +4,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
"one-api/setting"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -122,6 +125,7 @@ func authHelper(c *gin.Context, minRole int) {
|
|||||||
c.Set("role", role)
|
c.Set("role", role)
|
||||||
c.Set("id", id)
|
c.Set("id", id)
|
||||||
c.Set("group", session.Get("group"))
|
c.Set("group", session.Get("group"))
|
||||||
|
c.Set("user_group", session.Get("group"))
|
||||||
c.Set("use_access_token", useAccessToken)
|
c.Set("use_access_token", useAccessToken)
|
||||||
|
|
||||||
//userCache, err := model.GetUserCache(id.(int))
|
//userCache, err := model.GetUserCache(id.(int))
|
||||||
@@ -190,14 +194,15 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
// 检查path包含/v1/messages
|
// 检查path包含/v1/messages
|
||||||
if strings.Contains(c.Request.URL.Path, "/v1/messages") {
|
if strings.Contains(c.Request.URL.Path, "/v1/messages") {
|
||||||
// 从x-api-key中获取key
|
anthropicKey := c.Request.Header.Get("x-api-key")
|
||||||
key := c.Request.Header.Get("x-api-key")
|
if anthropicKey != "" {
|
||||||
if key != "" {
|
c.Request.Header.Set("Authorization", "Bearer "+anthropicKey)
|
||||||
c.Request.Header.Set("Authorization", "Bearer "+key)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// gemini api 从query中获取key
|
// gemini api 从query中获取key
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/v1beta/openai/models") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||||
skKey := c.Query("key")
|
skKey := c.Query("key")
|
||||||
if skKey != "" {
|
if skKey != "" {
|
||||||
c.Request.Header.Set("Authorization", "Bearer "+skKey)
|
c.Request.Header.Set("Authorization", "Bearer "+skKey)
|
||||||
@@ -233,6 +238,16 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allowIpsMap := token.GetIpLimitsMap()
|
||||||
|
if len(allowIpsMap) != 0 {
|
||||||
|
clientIp := c.ClientIP()
|
||||||
|
if _, ok := allowIpsMap[clientIp]; !ok {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
userCache, err := model.GetUserCache(token.UserId)
|
userCache, err := model.GetUserCache(token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||||
@@ -246,6 +261,25 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
|
|
||||||
userCache.WriteContext(c)
|
userCache.WriteContext(c)
|
||||||
|
|
||||||
|
userGroup := userCache.Group
|
||||||
|
tokenGroup := token.Group
|
||||||
|
if tokenGroup != "" {
|
||||||
|
// check common.UserUsableGroups[userGroup]
|
||||||
|
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// check group in common.GroupRatio
|
||||||
|
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
||||||
|
if tokenGroup != "auto" {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
userGroup = tokenGroup
|
||||||
|
}
|
||||||
|
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
||||||
|
|
||||||
err = SetupContextForToken(c, token, parts...)
|
err = SetupContextForToken(c, token, parts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@@ -272,7 +306,6 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
|
|||||||
} else {
|
} else {
|
||||||
c.Set("token_model_limit_enabled", false)
|
c.Set("token_model_limit_enabled", false)
|
||||||
}
|
}
|
||||||
c.Set("allow_ips", token.GetIpLimitsMap())
|
|
||||||
c.Set("token_group", token.Group)
|
c.Set("token_group", token.Group)
|
||||||
if len(parts) > 1 {
|
if len(parts) > 1 {
|
||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
|
|||||||
12
middleware/disable-cache.go
Normal file
12
middleware/disable-cache.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import "github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
func DisableCache() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
c.Header("Cache-Control", "no-store, no-cache, must-revalidate, private, max-age=0")
|
||||||
|
c.Header("Pragma", "no-cache")
|
||||||
|
c.Header("Expires", "0")
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -27,14 +27,6 @@ type ModelRequest struct {
|
|||||||
|
|
||||||
func Distribute() func(c *gin.Context) {
|
func Distribute() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
|
|
||||||
if len(allowIpsMap) != 0 {
|
|
||||||
clientIp := c.ClientIP()
|
|
||||||
if _, ok := allowIpsMap[clientIp]; !ok {
|
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var channel *model.Channel
|
var channel *model.Channel
|
||||||
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
|
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
|
||||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||||
@@ -42,24 +34,6 @@ func Distribute() func(c *gin.Context) {
|
|||||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
|
||||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
|
||||||
if tokenGroup != "" {
|
|
||||||
// check common.UserUsableGroups[userGroup]
|
|
||||||
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// check group in common.GroupRatio
|
|
||||||
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
|
||||||
if tokenGroup != "auto" {
|
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
userGroup = tokenGroup
|
|
||||||
}
|
|
||||||
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
|
||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -81,44 +55,63 @@ func Distribute() func(c *gin.Context) {
|
|||||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||||
if modelLimitEnable {
|
if modelLimitEnable {
|
||||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||||
var tokenModelLimit map[string]bool
|
if !ok {
|
||||||
if ok {
|
|
||||||
tokenModelLimit = s.(map[string]bool)
|
|
||||||
} else {
|
|
||||||
tokenModelLimit = map[string]bool{}
|
|
||||||
}
|
|
||||||
if tokenModelLimit != nil {
|
|
||||||
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
|
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// token model limit is empty, all models are not allowed
|
// token model limit is empty, all models are not allowed
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
var tokenModelLimit map[string]bool
|
||||||
|
tokenModelLimit, ok = s.(map[string]bool)
|
||||||
|
if !ok {
|
||||||
|
tokenModelLimit = map[string]bool{}
|
||||||
|
}
|
||||||
|
matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) // match gpts & thinking-*
|
||||||
|
if _, ok := tokenModelLimit[matchName]; !ok {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if shouldSelectChannel {
|
if shouldSelectChannel {
|
||||||
|
if modelRequest.Model == "" {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "未指定模型名称,模型名称不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
var selectGroup string
|
var selectGroup string
|
||||||
|
userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
||||||
|
// check path is /pg/chat/completions
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
|
||||||
|
playgroundRequest := &dto.PlayGroundRequest{}
|
||||||
|
err = common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||||
|
if err != nil {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if playgroundRequest.Group != "" {
|
||||||
|
if !setting.GroupInUserUsableGroups(playgroundRequest.Group) && playgroundRequest.Group != userGroup {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, "无权访问该分组")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userGroup = playgroundRequest.Group
|
||||||
|
}
|
||||||
|
}
|
||||||
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
showGroup := userGroup
|
showGroup := userGroup
|
||||||
if userGroup == "auto" {
|
if userGroup == "auto" {
|
||||||
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
||||||
}
|
}
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
|
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(数据库一致性已被破坏,distributor): %s", showGroup, modelRequest.Model, err.Error())
|
||||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||||
if channel != nil {
|
//if channel != nil {
|
||||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
message = "数据库一致性已被破坏,请联系管理员"
|
// message = "数据库一致性已被破坏,请联系管理员"
|
||||||
}
|
//}
|
||||||
// 如果错误,而且渠道为空,说明是没有可用渠道
|
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound))
|
||||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if channel == nil {
|
if channel == nil {
|
||||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -174,23 +167,16 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
var platform string
|
relayMode := relayconstant.RelayModeUnknown
|
||||||
var relayMode int
|
if c.Request.Method == http.MethodPost {
|
||||||
if strings.HasPrefix(modelRequest.Model, "jimeng") {
|
relayMode = relayconstant.RelayModeVideoSubmit
|
||||||
platform = string(constant.TaskPlatformJimeng)
|
} else if c.Request.Method == http.MethodGet {
|
||||||
relayMode = relayconstant.Path2RelayJimeng(c.Request.Method, c.Request.URL.Path)
|
relayMode = relayconstant.RelayModeVideoFetchByID
|
||||||
if relayMode == relayconstant.RelayModeJimengFetchByID {
|
shouldSelectChannel = false
|
||||||
shouldSelectChannel = false
|
}
|
||||||
}
|
if _, ok := c.Get("relay_mode"); !ok {
|
||||||
} else {
|
c.Set("relay_mode", relayMode)
|
||||||
platform = string(constant.TaskPlatformKling)
|
|
||||||
relayMode = relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
|
|
||||||
if relayMode == relayconstant.RelayModeKlingFetchByID {
|
|
||||||
shouldSelectChannel = false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
c.Set("platform", platform)
|
|
||||||
c.Set("relay_mode", relayMode)
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||||
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
||||||
relayMode := relayconstant.RelayModeGemini
|
relayMode := relayconstant.RelayModeGemini
|
||||||
@@ -199,7 +185,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
modelRequest.Model = modelName
|
modelRequest.Model = modelName
|
||||||
}
|
}
|
||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -222,7 +208,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||||
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
|
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
||||||
modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
|
//modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
|
||||||
|
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||||
|
modelRequest.Model = c.PostForm("model")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||||
relayMode := relayconstant.RelayModeAudioSpeech
|
relayMode := relayconstant.RelayModeAudioSpeech
|
||||||
@@ -253,14 +242,16 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
|
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
|
||||||
c.Set("original_model", modelName) // for retry
|
c.Set("original_model", modelName) // for retry
|
||||||
if channel == nil {
|
if channel == nil {
|
||||||
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed)
|
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
|
common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings())
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
|
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, channel.GetHeaderOverride())
|
||||||
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
|
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
|
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
|
||||||
}
|
}
|
||||||
@@ -275,11 +266,16 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
if channel.ChannelInfo.IsMultiKey {
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
|
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
|
||||||
|
} else {
|
||||||
|
// 必须设置为 false,否则在重试到单个 key 的时候会导致日志显示错误
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, false)
|
||||||
}
|
}
|
||||||
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||||
|
|
||||||
|
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, false)
|
||||||
|
|
||||||
// TODO: api_version统一
|
// TODO: api_version统一
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case constant.ChannelTypeAzure:
|
case constant.ChannelTypeAzure:
|
||||||
|
|||||||
80
middleware/email-verification-rate-limit.go
Normal file
80
middleware/email-verification-rate-limit.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
EmailVerificationRateLimitMark = "EV"
|
||||||
|
EmailVerificationMaxRequests = 2 // 30秒内最多2次
|
||||||
|
EmailVerificationDuration = 30 // 30秒时间窗口
|
||||||
|
)
|
||||||
|
|
||||||
|
func redisEmailVerificationRateLimiter(c *gin.Context) {
|
||||||
|
ctx := context.Background()
|
||||||
|
rdb := common.RDB
|
||||||
|
key := "emailVerification:" + EmailVerificationRateLimitMark + ":" + c.ClientIP()
|
||||||
|
|
||||||
|
count, err := rdb.Incr(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
// fallback
|
||||||
|
memoryEmailVerificationRateLimiter(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第一次设置键时设置过期时间
|
||||||
|
if count == 1 {
|
||||||
|
_ = rdb.Expire(ctx, key, time.Duration(EmailVerificationDuration)*time.Second).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否超出限制
|
||||||
|
if count <= int64(EmailVerificationMaxRequests) {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取剩余等待时间
|
||||||
|
ttl, err := rdb.TTL(ctx, key).Result()
|
||||||
|
waitSeconds := int64(EmailVerificationDuration)
|
||||||
|
if err == nil && ttl > 0 {
|
||||||
|
waitSeconds = int64(ttl.Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": fmt.Sprintf("发送过于频繁,请等待 %d 秒后再试", waitSeconds),
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
|
||||||
|
func memoryEmailVerificationRateLimiter(c *gin.Context) {
|
||||||
|
key := EmailVerificationRateLimitMark + ":" + c.ClientIP()
|
||||||
|
|
||||||
|
if !inMemoryRateLimiter.Request(key, EmailVerificationMaxRequests, EmailVerificationDuration) {
|
||||||
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "发送过于频繁,请稍后再试",
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
func EmailVerificationRateLimit() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if common.RedisEnabled {
|
||||||
|
redisEmailVerificationRateLimiter(c)
|
||||||
|
} else {
|
||||||
|
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
|
||||||
|
memoryEmailVerificationRateLimiter(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
66
middleware/jimeng_adapter.go
Normal file
66
middleware/jimeng_adapter.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
|
)
|
||||||
|
|
||||||
|
func JimengRequestConvert() func(c *gin.Context) {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
action := c.Query("Action")
|
||||||
|
if action == "" {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "Action query parameter is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Jimeng official API request
|
||||||
|
var originalReq map[string]interface{}
|
||||||
|
if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model, _ := originalReq["req_key"].(string)
|
||||||
|
prompt, _ := originalReq["prompt"].(string)
|
||||||
|
|
||||||
|
unifiedReq := map[string]interface{}{
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"metadata": originalReq,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(unifiedReq)
|
||||||
|
if err != nil {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, "Failed to marshal request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update request body
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
|
||||||
|
c.Set(common.KeyRequestBody, jsonData)
|
||||||
|
|
||||||
|
if image, ok := originalReq["image"]; !ok || image == "" {
|
||||||
|
c.Set("action", constant.TaskActionTextGenerate)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.URL.Path = "/v1/video/generations"
|
||||||
|
|
||||||
|
if action == "CVSync2AsyncGetResult" {
|
||||||
|
taskId, ok := originalReq["task_id"].(string)
|
||||||
|
if !ok || taskId == "" {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "task_id is required for CVSync2AsyncGetResult")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Request.URL.Path = "/v1/video/generations/" + taskId
|
||||||
|
c.Request.Method = http.MethodGet
|
||||||
|
c.Set("task_id", taskId)
|
||||||
|
c.Set("relay_mode", relayconstant.RelayModeVideoFetchByID)
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -18,7 +18,11 @@ func KlingRequestConvert() func(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Support both model_name and model fields
|
||||||
model, _ := originalReq["model_name"].(string)
|
model, _ := originalReq["model_name"].(string)
|
||||||
|
if model == "" {
|
||||||
|
model, _ = originalReq["model"].(string)
|
||||||
|
}
|
||||||
prompt, _ := originalReq["prompt"].(string)
|
prompt, _ := originalReq["prompt"].(string)
|
||||||
|
|
||||||
unifiedReq := map[string]interface{}{
|
unifiedReq := map[string]interface{}{
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
common.SysError(fmt.Sprintf("panic detected: %v", err))
|
common.SysLog(fmt.Sprintf("panic detected: %v", err))
|
||||||
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
common.SysLog(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
||||||
|
|||||||
@@ -18,12 +18,12 @@ func StatsMiddleware() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// 增加活跃连接数
|
// 增加活跃连接数
|
||||||
atomic.AddInt64(&globalStats.activeConnections, 1)
|
atomic.AddInt64(&globalStats.activeConnections, 1)
|
||||||
|
|
||||||
// 确保在请求结束时减少连接数
|
// 确保在请求结束时减少连接数
|
||||||
defer func() {
|
defer func() {
|
||||||
atomic.AddInt64(&globalStats.activeConnections, -1)
|
atomic.AddInt64(&globalStats.activeConnections, -1)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -38,4 +38,4 @@ func GetStats() StatsInfo {
|
|||||||
return StatsInfo{
|
return StatsInfo{
|
||||||
ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
|
ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ func TurnstileCheck() gin.HandlerFunc {
|
|||||||
"remoteip": {c.ClientIP()},
|
"remoteip": {c.ClientIP()},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(err.Error())
|
common.SysLog(err.Error())
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
@@ -49,7 +49,7 @@ func TurnstileCheck() gin.HandlerFunc {
|
|||||||
var res turnstileCheckResponse
|
var res turnstileCheckResponse
|
||||||
err = json.NewDecoder(rawRes.Body).Decode(&res)
|
err = json.NewDecoder(rawRes.Body).Decode(&res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(err.Error())
|
common.SysLog(err.Error())
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
|
|||||||
@@ -4,18 +4,24 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
|
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string, code ...string) {
|
||||||
|
codeStr := ""
|
||||||
|
if len(code) > 0 {
|
||||||
|
codeStr = code[0]
|
||||||
|
}
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
c.JSON(statusCode, gin.H{
|
c.JSON(statusCode, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
||||||
"type": "new_api_error",
|
"type": "new_api_error",
|
||||||
|
"code": codeStr,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
c.Abort()
|
c.Abort()
|
||||||
common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
|
logger.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
|
||||||
}
|
}
|
||||||
|
|
||||||
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
|
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
|
||||||
@@ -25,5 +31,5 @@ func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, descri
|
|||||||
"code": code,
|
"code": code,
|
||||||
})
|
})
|
||||||
c.Abort()
|
c.Abort()
|
||||||
common.LogError(c.Request.Context(), description)
|
logger.LogError(c.Request.Context(), description)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -136,13 +136,13 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return nil, errors.New("channel not found")
|
return nil, nil
|
||||||
}
|
}
|
||||||
err = DB.First(&channel, "id = ?", channel.Id).Error
|
err = DB.First(&channel, "id = ?", channel.Id).Error
|
||||||
return &channel, err
|
return &channel, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) AddAbilities() error {
|
func (channel *Channel) AddAbilities(tx *gorm.DB) error {
|
||||||
models_ := strings.Split(channel.Models, ",")
|
models_ := strings.Split(channel.Models, ",")
|
||||||
groups_ := strings.Split(channel.Group, ",")
|
groups_ := strings.Split(channel.Group, ",")
|
||||||
abilitySet := make(map[string]struct{})
|
abilitySet := make(map[string]struct{})
|
||||||
@@ -169,8 +169,13 @@ func (channel *Channel) AddAbilities() error {
|
|||||||
if len(abilities) == 0 {
|
if len(abilities) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
// choose DB or provided tx
|
||||||
|
useDB := DB
|
||||||
|
if tx != nil {
|
||||||
|
useDB = tx
|
||||||
|
}
|
||||||
for _, chunk := range lo.Chunk(abilities, 50) {
|
for _, chunk := range lo.Chunk(abilities, 50) {
|
||||||
err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
|
err := useDB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -284,6 +289,21 @@ func FixAbility() (int, int, error) {
|
|||||||
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
|
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
|
||||||
}
|
}
|
||||||
defer fixLock.Unlock()
|
defer fixLock.Unlock()
|
||||||
|
|
||||||
|
// truncate abilities table
|
||||||
|
if common.UsingSQLite {
|
||||||
|
err := DB.Exec("DELETE FROM abilities").Error
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err := DB.Exec("TRUNCATE TABLE abilities").Error
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
var channels []*Channel
|
var channels []*Channel
|
||||||
// Find all channels
|
// Find all channels
|
||||||
err := DB.Model(&Channel{}).Find(&channels).Error
|
err := DB.Model(&Channel{}).Find(&channels).Error
|
||||||
@@ -300,15 +320,15 @@ func FixAbility() (int, int, error) {
|
|||||||
// Delete all abilities of this channel
|
// Delete all abilities of this channel
|
||||||
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
|
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||||
failCount += len(chunk)
|
failCount += len(chunk)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Then add new abilities
|
// Then add new abilities
|
||||||
for _, channel := range chunk {
|
for _, channel := range chunk {
|
||||||
err = channel.AddAbilities()
|
err = channel.AddAbilities(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
|
common.SysLog(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
|
||||||
failCount++
|
failCount++
|
||||||
} else {
|
} else {
|
||||||
successCount++
|
successCount++
|
||||||
|
|||||||
202
model/channel.go
202
model/channel.go
@@ -13,6 +13,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/samber/lo"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -41,19 +42,27 @@ type Channel struct {
|
|||||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||||
OtherInfo string `json:"other_info"`
|
OtherInfo string `json:"other_info"`
|
||||||
|
OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置
|
||||||
Tag *string `json:"tag" gorm:"index"`
|
Tag *string `json:"tag" gorm:"index"`
|
||||||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||||
|
HeaderOverride *string `json:"header_override" gorm:"type:text"`
|
||||||
|
Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
|
||||||
// add after v0.8.5
|
// add after v0.8.5
|
||||||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||||
|
|
||||||
|
// cache info
|
||||||
|
Keys []string `json:"-" gorm:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChannelInfo struct {
|
type ChannelInfo struct {
|
||||||
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||||||
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
|
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
|
||||||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||||||
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason,omitempty"` // key禁用原因列表,key index -> reason
|
||||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time,omitempty"` // key禁用时间列表,key index -> time
|
||||||
|
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||||||
|
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements driver.Valuer interface
|
// Value implements driver.Valuer interface
|
||||||
@@ -67,15 +76,18 @@ func (c *ChannelInfo) Scan(value interface{}) error {
|
|||||||
return common.Unmarshal(bytesValue, c)
|
return common.Unmarshal(bytesValue, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) getKeys() []string {
|
func (channel *Channel) GetKeys() []string {
|
||||||
if channel.Key == "" {
|
if channel.Key == "" {
|
||||||
return []string{}
|
return []string{}
|
||||||
}
|
}
|
||||||
|
if len(channel.Keys) > 0 {
|
||||||
|
return channel.Keys
|
||||||
|
}
|
||||||
trimmed := strings.TrimSpace(channel.Key)
|
trimmed := strings.TrimSpace(channel.Key)
|
||||||
// If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
|
// If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
|
||||||
if strings.HasPrefix(trimmed, "[") {
|
if strings.HasPrefix(trimmed, "[") {
|
||||||
var arr []json.RawMessage
|
var arr []json.RawMessage
|
||||||
if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
|
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||||
res := make([]string, len(arr))
|
res := make([]string, len(arr))
|
||||||
for i, v := range arr {
|
for i, v := range arr {
|
||||||
res[i] = string(v)
|
res[i] = string(v)
|
||||||
@@ -95,12 +107,16 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Obtain all keys (split by \n)
|
// Obtain all keys (split by \n)
|
||||||
keys := channel.getKeys()
|
keys := channel.GetKeys()
|
||||||
if len(keys) == 0 {
|
if len(keys) == 0 {
|
||||||
// No keys available, return error, should disable the channel
|
// No keys available, return error, should disable the channel
|
||||||
return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
|
return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lock := GetChannelPollingLock(channel.Id)
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
statusList := channel.ChannelInfo.MultiKeyStatusList
|
statusList := channel.ChannelInfo.MultiKeyStatusList
|
||||||
// helper to get key status, default to enabled when missing
|
// helper to get key status, default to enabled when missing
|
||||||
getStatus := func(idx int) int {
|
getStatus := func(idx int) int {
|
||||||
@@ -132,13 +148,10 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
|
|||||||
return keys[selectedIdx], selectedIdx, nil
|
return keys[selectedIdx], selectedIdx, nil
|
||||||
case constant.MultiKeyModePolling:
|
case constant.MultiKeyModePolling:
|
||||||
// Use channel-specific lock to ensure thread-safe polling
|
// Use channel-specific lock to ensure thread-safe polling
|
||||||
lock := getChannelPollingLock(channel.Id)
|
|
||||||
lock.Lock()
|
|
||||||
defer lock.Unlock()
|
|
||||||
|
|
||||||
channelInfo, err := CacheGetChannelInfo(channel.Id)
|
channelInfo, err := CacheGetChannelInfo(channel.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed)
|
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
|
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -197,9 +210,9 @@ func (channel *Channel) GetGroups() []string {
|
|||||||
func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
||||||
otherInfo := make(map[string]interface{})
|
otherInfo := make(map[string]interface{})
|
||||||
if channel.OtherInfo != "" {
|
if channel.OtherInfo != "" {
|
||||||
err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to unmarshal other info: " + err.Error())
|
common.SysLog(fmt.Sprintf("failed to unmarshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return otherInfo
|
return otherInfo
|
||||||
@@ -208,7 +221,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
|||||||
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
|
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
|
||||||
otherInfoBytes, err := json.Marshal(otherInfo)
|
otherInfoBytes, err := json.Marshal(otherInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to marshal other info: " + err.Error())
|
common.SysLog(fmt.Sprintf("failed to marshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel.OtherInfo = string(otherInfoBytes)
|
channel.OtherInfo = string(otherInfoBytes)
|
||||||
@@ -236,6 +249,10 @@ func (channel *Channel) Save() error {
|
|||||||
return DB.Save(channel).Error
|
return DB.Save(channel).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) SaveWithoutKey() error {
|
||||||
|
return DB.Omit("key").Save(channel).Error
|
||||||
|
}
|
||||||
|
|
||||||
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
|
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
|
||||||
var channels []*Channel
|
var channels []*Channel
|
||||||
var err error
|
var err error
|
||||||
@@ -328,38 +345,54 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BatchInsertChannels(channels []Channel) error {
|
func BatchInsertChannels(channels []Channel) error {
|
||||||
var err error
|
if len(channels) == 0 {
|
||||||
err = DB.Create(&channels).Error
|
return nil
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
for _, channel_ := range channels {
|
tx := DB.Begin()
|
||||||
err = channel_.AddAbilities()
|
if tx.Error != nil {
|
||||||
if err != nil {
|
return tx.Error
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for _, chunk := range lo.Chunk(channels, 50) {
|
||||||
|
if err := tx.Create(&chunk).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
for _, channel_ := range chunk {
|
||||||
|
if err := channel_.AddAbilities(tx); err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return tx.Commit().Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func BatchDeleteChannels(ids []int) error {
|
func BatchDeleteChannels(ids []int) error {
|
||||||
//使用事务 删除channel表和channel_ability表
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// 使用事务 分批删除channel表和abilities表
|
||||||
tx := DB.Begin()
|
tx := DB.Begin()
|
||||||
err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error
|
if tx.Error != nil {
|
||||||
if err != nil {
|
return tx.Error
|
||||||
// 回滚事务
|
|
||||||
tx.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error
|
for _, chunk := range lo.Chunk(ids, 200) {
|
||||||
if err != nil {
|
if err := tx.Where("id in (?)", chunk).Delete(&Channel{}).Error; err != nil {
|
||||||
// 回滚事务
|
tx.Rollback()
|
||||||
tx.Rollback()
|
return err
|
||||||
return err
|
}
|
||||||
|
if err := tx.Where("channel_id in (?)", chunk).Delete(&Ability{}).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// 提交事务
|
return tx.Commit().Error
|
||||||
tx.Commit()
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) GetPriority() int64 {
|
func (channel *Channel) GetPriority() int64 {
|
||||||
@@ -380,7 +413,11 @@ func (channel *Channel) GetBaseURL() string {
|
|||||||
if channel.BaseURL == nil {
|
if channel.BaseURL == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return *channel.BaseURL
|
url := *channel.BaseURL
|
||||||
|
if url == "" {
|
||||||
|
url = constant.ChannelBaseURLs[channel.Type]
|
||||||
|
}
|
||||||
|
return url
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) GetModelMapping() string {
|
func (channel *Channel) GetModelMapping() string {
|
||||||
@@ -403,7 +440,7 @@ func (channel *Channel) Insert() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = channel.AddAbilities()
|
err = channel.AddAbilities(nil)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -425,7 +462,7 @@ func (channel *Channel) Update() error {
|
|||||||
trimmed := strings.TrimSpace(keyStr)
|
trimmed := strings.TrimSpace(keyStr)
|
||||||
if strings.HasPrefix(trimmed, "[") {
|
if strings.HasPrefix(trimmed, "[") {
|
||||||
var arr []json.RawMessage
|
var arr []json.RawMessage
|
||||||
if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
|
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||||
keys = make([]string, len(arr))
|
keys = make([]string, len(arr))
|
||||||
for i, v := range arr {
|
for i, v := range arr {
|
||||||
keys[i] = string(v)
|
keys[i] = string(v)
|
||||||
@@ -462,7 +499,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
|||||||
ResponseTime: int(responseTime),
|
ResponseTime: int(responseTime),
|
||||||
}).Error
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update response time: " + err.Error())
|
common.SysLog(fmt.Sprintf("failed to update response time: channel_id=%d, error=%v", channel.Id, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -472,7 +509,7 @@ func (channel *Channel) UpdateBalance(balance float64) {
|
|||||||
Balance: balance,
|
Balance: balance,
|
||||||
}).Error
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update balance: " + err.Error())
|
common.SysLog(fmt.Sprintf("failed to update balance: channel_id=%d, error=%v", channel.Id, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -491,8 +528,8 @@ var channelStatusLock sync.Mutex
|
|||||||
// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
|
// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
|
||||||
var channelPollingLocks sync.Map
|
var channelPollingLocks sync.Map
|
||||||
|
|
||||||
// getChannelPollingLock returns or creates a mutex for the given channel ID
|
// GetChannelPollingLock returns or creates a mutex for the given channel ID
|
||||||
func getChannelPollingLock(channelId int) *sync.Mutex {
|
func GetChannelPollingLock(channelId int) *sync.Mutex {
|
||||||
if lock, exists := channelPollingLocks.Load(channelId); exists {
|
if lock, exists := channelPollingLocks.Load(channelId); exists {
|
||||||
return lock.(*sync.Mutex)
|
return lock.(*sync.Mutex)
|
||||||
}
|
}
|
||||||
@@ -522,8 +559,8 @@ func CleanupChannelPollingLocks() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
|
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) {
|
||||||
keys := channel.getKeys()
|
keys := channel.GetKeys()
|
||||||
if len(keys) == 0 {
|
if len(keys) == 0 {
|
||||||
channel.Status = status
|
channel.Status = status
|
||||||
} else {
|
} else {
|
||||||
@@ -541,6 +578,14 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
|
|||||||
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
||||||
} else {
|
} else {
|
||||||
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
|
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
|
||||||
|
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||||||
|
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||||
|
}
|
||||||
|
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||||||
|
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||||
|
}
|
||||||
|
channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
|
||||||
|
channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
|
||||||
}
|
}
|
||||||
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
|
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
|
||||||
channel.Status = common.ChannelStatusAutoDisabled
|
channel.Status = common.ChannelStatusAutoDisabled
|
||||||
@@ -563,7 +608,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
|||||||
}
|
}
|
||||||
if channelCache.ChannelInfo.IsMultiKey {
|
if channelCache.ChannelInfo.IsMultiKey {
|
||||||
// 如果是多Key模式,更新缓存中的状态
|
// 如果是多Key模式,更新缓存中的状态
|
||||||
handlerMultiKeyUpdate(channelCache, usingKey, status)
|
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
|
||||||
//CacheUpdateChannel(channelCache)
|
//CacheUpdateChannel(channelCache)
|
||||||
//return true
|
//return true
|
||||||
} else {
|
} else {
|
||||||
@@ -571,10 +616,6 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
|||||||
if channelCache.Status == status {
|
if channelCache.Status == status {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
|
|
||||||
if status != common.ChannelStatusEnabled {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
CacheUpdateChannelStatus(channelId, status)
|
CacheUpdateChannelStatus(channelId, status)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -584,7 +625,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
|||||||
if shouldUpdateAbilities {
|
if shouldUpdateAbilities {
|
||||||
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
|
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update ability status: " + err.Error())
|
common.SysLog(fmt.Sprintf("failed to update ability status: channel_id=%d, error=%v", channelId, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -598,7 +639,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
|||||||
|
|
||||||
if channel.ChannelInfo.IsMultiKey {
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
beforeStatus := channel.Status
|
beforeStatus := channel.Status
|
||||||
handlerMultiKeyUpdate(channel, usingKey, status)
|
handlerMultiKeyUpdate(channel, usingKey, status, reason)
|
||||||
if beforeStatus != channel.Status {
|
if beforeStatus != channel.Status {
|
||||||
shouldUpdateAbilities = true
|
shouldUpdateAbilities = true
|
||||||
}
|
}
|
||||||
@@ -610,9 +651,9 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
|||||||
channel.Status = status
|
channel.Status = status
|
||||||
shouldUpdateAbilities = true
|
shouldUpdateAbilities = true
|
||||||
}
|
}
|
||||||
err = channel.Save()
|
err = channel.SaveWithoutKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update channel status: " + err.Error())
|
common.SysLog(fmt.Sprintf("failed to update channel status: channel_id=%d, status=%d, error=%v", channel.Id, status, err))
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -674,7 +715,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
|
|||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
err = channel.UpdateAbilities(nil)
|
err = channel.UpdateAbilities(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update abilities: " + err.Error())
|
common.SysLog(fmt.Sprintf("failed to update abilities: channel_id=%d, tag=%s, error=%v", channel.Id, channel.GetTag(), err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -698,7 +739,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
|
|||||||
func updateChannelUsedQuota(id int, quota int) {
|
func updateChannelUsedQuota(id int, quota int) {
|
||||||
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update channel used quota: " + err.Error())
|
common.SysLog(fmt.Sprintf("failed to update channel used quota: channel_id=%d, delta_quota=%d, error=%v", id, quota, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -778,7 +819,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
|||||||
func (channel *Channel) ValidateSettings() error {
|
func (channel *Channel) ValidateSettings() error {
|
||||||
channelParams := &dto.ChannelSettings{}
|
channelParams := &dto.ChannelSettings{}
|
||||||
if channel.Setting != nil && *channel.Setting != "" {
|
if channel.Setting != nil && *channel.Setting != "" {
|
||||||
err := json.Unmarshal([]byte(*channel.Setting), channelParams)
|
err := common.Unmarshal([]byte(*channel.Setting), channelParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -789,9 +830,9 @@ func (channel *Channel) ValidateSettings() error {
|
|||||||
func (channel *Channel) GetSetting() dto.ChannelSettings {
|
func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||||||
setting := dto.ChannelSettings{}
|
setting := dto.ChannelSettings{}
|
||||||
if channel.Setting != nil && *channel.Setting != "" {
|
if channel.Setting != nil && *channel.Setting != "" {
|
||||||
err := json.Unmarshal([]byte(*channel.Setting), &setting)
|
err := common.Unmarshal([]byte(*channel.Setting), &setting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err))
|
||||||
channel.Setting = nil // 清空设置以避免后续错误
|
channel.Setting = nil // 清空设置以避免后续错误
|
||||||
_ = channel.Save() // 保存修改
|
_ = channel.Save() // 保存修改
|
||||||
}
|
}
|
||||||
@@ -800,25 +841,58 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||||||
settingBytes, err := json.Marshal(setting)
|
settingBytes, err := common.Marshal(setting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to marshal setting: " + err.Error())
|
common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel.Setting = common.GetPointer[string](string(settingBytes))
|
channel.Setting = common.GetPointer[string](string(settingBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
|
||||||
|
setting := dto.ChannelOtherSettings{}
|
||||||
|
if channel.OtherSettings != "" {
|
||||||
|
err := common.UnmarshalJsonStr(channel.OtherSettings, &setting)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err))
|
||||||
|
channel.OtherSettings = "{}" // 清空设置以避免后续错误
|
||||||
|
_ = channel.Save() // 保存修改
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return setting
|
||||||
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) {
|
||||||
|
settingBytes, err := common.Marshal(setting)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
channel.OtherSettings = string(settingBytes)
|
||||||
|
}
|
||||||
|
|
||||||
func (channel *Channel) GetParamOverride() map[string]interface{} {
|
func (channel *Channel) GetParamOverride() map[string]interface{} {
|
||||||
paramOverride := make(map[string]interface{})
|
paramOverride := make(map[string]interface{})
|
||||||
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
|
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
|
||||||
err := json.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to unmarshal param override: " + err.Error())
|
common.SysLog(fmt.Sprintf("failed to unmarshal param override: channel_id=%d, error=%v", channel.Id, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return paramOverride
|
return paramOverride
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetHeaderOverride() map[string]interface{} {
|
||||||
|
headerOverride := make(map[string]interface{})
|
||||||
|
if channel.HeaderOverride != nil && *channel.HeaderOverride != "" {
|
||||||
|
err := common.Unmarshal([]byte(*channel.HeaderOverride), &headerOverride)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog(fmt.Sprintf("failed to unmarshal header override: channel_id=%d, error=%v", channel.Id, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return headerOverride
|
||||||
|
}
|
||||||
|
|
||||||
func GetChannelsByIds(ids []int) ([]*Channel, error) {
|
func GetChannelsByIds(ids []int) ([]*Channel, error) {
|
||||||
var channels []*Channel
|
var channels []*Channel
|
||||||
err := DB.Where("id in (?)", ids).Find(&channels).Error
|
err := DB.Where("id in (?)", ids).Find(&channels).Error
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -66,6 +68,20 @@ func InitChannelCache() {
|
|||||||
|
|
||||||
channelSyncLock.Lock()
|
channelSyncLock.Lock()
|
||||||
group2model2channels = newGroup2model2channels
|
group2model2channels = newGroup2model2channels
|
||||||
|
//channelsIDM = newChannelId2channel
|
||||||
|
for i, channel := range newChannelId2channel {
|
||||||
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
|
channel.Keys = channel.GetKeys()
|
||||||
|
if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||||||
|
if oldChannel, ok := channelsIDM[i]; ok {
|
||||||
|
// 存在旧的渠道,如果是多key且轮询,保留轮询索引信息
|
||||||
|
if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||||||
|
channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
channelsIDM = newChannelId2channel
|
channelsIDM = newChannelId2channel
|
||||||
channelSyncLock.Unlock()
|
channelSyncLock.Unlock()
|
||||||
common.SysLog("channels synced from database")
|
common.SysLog("channels synced from database")
|
||||||
@@ -109,20 +125,10 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string,
|
|||||||
return nil, group, err
|
return nil, group, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if channel == nil {
|
|
||||||
return nil, group, errors.New("channel not found")
|
|
||||||
}
|
|
||||||
return channel, selectGroup, nil
|
return channel, selectGroup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
|
||||||
model = "gpt-4-gizmo-*"
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(model, "gpt-4o-gizmo") {
|
|
||||||
model = "gpt-4o-gizmo-*"
|
|
||||||
}
|
|
||||||
|
|
||||||
// if memory cache is disabled, get channel directly from database
|
// if memory cache is disabled, get channel directly from database
|
||||||
if !common.MemoryCacheEnabled {
|
if !common.MemoryCacheEnabled {
|
||||||
return GetRandomSatisfiedChannel(group, model, retry)
|
return GetRandomSatisfiedChannel(group, model, retry)
|
||||||
@@ -130,10 +136,18 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
|||||||
|
|
||||||
channelSyncLock.RLock()
|
channelSyncLock.RLock()
|
||||||
defer channelSyncLock.RUnlock()
|
defer channelSyncLock.RUnlock()
|
||||||
|
|
||||||
|
// First, try to find channels with the exact model name.
|
||||||
channels := group2model2channels[group][model]
|
channels := group2model2channels[group][model]
|
||||||
|
|
||||||
|
// If no channels found, try to find channels with the normalized model name.
|
||||||
if len(channels) == 0 {
|
if len(channels) == 0 {
|
||||||
return nil, errors.New("channel not found")
|
normalizedModel := ratio_setting.FormatMatchingModelName(model)
|
||||||
|
channels = group2model2channels[group][normalizedModel]
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(channels) == 0 {
|
||||||
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(channels) == 1 {
|
if len(channels) == 1 {
|
||||||
@@ -206,9 +220,6 @@ func CacheGetChannel(id int) (*Channel, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||||
}
|
}
|
||||||
if c.Status != common.ChannelStatusEnabled {
|
|
||||||
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
|
||||||
}
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,9 +238,6 @@ func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||||
}
|
}
|
||||||
if c.Status != common.ChannelStatusEnabled {
|
|
||||||
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
|
||||||
}
|
|
||||||
return &c.ChannelInfo, nil
|
return &c.ChannelInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -242,6 +250,20 @@ func CacheUpdateChannelStatus(id int, status int) {
|
|||||||
if channel, ok := channelsIDM[id]; ok {
|
if channel, ok := channelsIDM[id]; ok {
|
||||||
channel.Status = status
|
channel.Status = status
|
||||||
}
|
}
|
||||||
|
if status != common.ChannelStatusEnabled {
|
||||||
|
// delete the channel from group2model2channels
|
||||||
|
for group, model2channels := range group2model2channels {
|
||||||
|
for model, channels := range model2channels {
|
||||||
|
for i, channelId := range channels {
|
||||||
|
if channelId == id {
|
||||||
|
// remove the channel from the slice
|
||||||
|
group2model2channels[group][model] = append(channels[:i], channels[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheUpdateChannel(channel *Channel) {
|
func CacheUpdateChannel(channel *Channel) {
|
||||||
|
|||||||
27
model/log.go
27
model/log.go
@@ -4,6 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/logger"
|
||||||
|
"one-api/types"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -87,13 +89,13 @@ func RecordLog(userId int, logType int, content string) {
|
|||||||
}
|
}
|
||||||
err := LOG_DB.Create(log).Error
|
err := LOG_DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to record log: " + err.Error())
|
common.SysLog("failed to record log: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
|
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
|
||||||
isStream bool, group string, other map[string]interface{}) {
|
isStream bool, group string, other map[string]interface{}) {
|
||||||
common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
|
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
|
||||||
username := c.GetString("username")
|
username := c.GetString("username")
|
||||||
otherStr := common.MapToJsonStr(other)
|
otherStr := common.MapToJsonStr(other)
|
||||||
// 判断是否需要记录 IP
|
// 判断是否需要记录 IP
|
||||||
@@ -129,7 +131,7 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
|||||||
}
|
}
|
||||||
err := LOG_DB.Create(log).Error
|
err := LOG_DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, "failed to record log: "+err.Error())
|
logger.LogError(c, "failed to record log: "+err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,7 +144,6 @@ type RecordConsumeLogParams struct {
|
|||||||
Quota int `json:"quota"`
|
Quota int `json:"quota"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
TokenId int `json:"token_id"`
|
TokenId int `json:"token_id"`
|
||||||
UserQuota int `json:"user_quota"`
|
|
||||||
UseTimeSeconds int `json:"use_time_seconds"`
|
UseTimeSeconds int `json:"use_time_seconds"`
|
||||||
IsStream bool `json:"is_stream"`
|
IsStream bool `json:"is_stream"`
|
||||||
Group string `json:"group"`
|
Group string `json:"group"`
|
||||||
@@ -150,10 +151,10 @@ type RecordConsumeLogParams struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
|
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
|
||||||
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
|
|
||||||
if !common.LogConsumeEnabled {
|
if !common.LogConsumeEnabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
|
||||||
username := c.GetString("username")
|
username := c.GetString("username")
|
||||||
otherStr := common.MapToJsonStr(params.Other)
|
otherStr := common.MapToJsonStr(params.Other)
|
||||||
// 判断是否需要记录 IP
|
// 判断是否需要记录 IP
|
||||||
@@ -189,7 +190,7 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
|
|||||||
}
|
}
|
||||||
err := LOG_DB.Create(log).Error
|
err := LOG_DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, "failed to record log: "+err.Error())
|
logger.LogError(c, "failed to record log: "+err.Error())
|
||||||
}
|
}
|
||||||
if common.DataExportEnabled {
|
if common.DataExportEnabled {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
@@ -236,26 +237,22 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
|||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
channelIdsMap := make(map[int]struct{})
|
channelIds := types.NewSet[int]()
|
||||||
channelMap := make(map[int]string)
|
|
||||||
for _, log := range logs {
|
for _, log := range logs {
|
||||||
if log.ChannelId != 0 {
|
if log.ChannelId != 0 {
|
||||||
channelIdsMap[log.ChannelId] = struct{}{}
|
channelIds.Add(log.ChannelId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
channelIds := make([]int, 0, len(channelIdsMap))
|
if channelIds.Len() > 0 {
|
||||||
for channelId := range channelIdsMap {
|
|
||||||
channelIds = append(channelIds, channelId)
|
|
||||||
}
|
|
||||||
if len(channelIds) > 0 {
|
|
||||||
var channels []struct {
|
var channels []struct {
|
||||||
Id int `gorm:"column:id"`
|
Id int `gorm:"column:id"`
|
||||||
Name string `gorm:"column:name"`
|
Name string `gorm:"column:name"`
|
||||||
}
|
}
|
||||||
if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds).Find(&channels).Error; err != nil {
|
if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil {
|
||||||
return logs, total, err
|
return logs, total, err
|
||||||
}
|
}
|
||||||
|
channelMap := make(map[int]string, len(channels))
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
channelMap[channel.Id] = channel.Name
|
channelMap[channel.Id] = channel.Name
|
||||||
}
|
}
|
||||||
|
|||||||
118
model/main.go
118
model/main.go
@@ -180,6 +180,12 @@ func InitDB() (err error) {
|
|||||||
db = db.Debug()
|
db = db.Debug()
|
||||||
}
|
}
|
||||||
DB = db
|
DB = db
|
||||||
|
// MySQL charset/collation startup check: ensure Chinese-capable charset
|
||||||
|
if common.UsingMySQL {
|
||||||
|
if err := checkMySQLChineseSupport(DB); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
sqlDB, err := DB.DB()
|
sqlDB, err := DB.DB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -214,6 +220,12 @@ func InitLogDB() (err error) {
|
|||||||
db = db.Debug()
|
db = db.Debug()
|
||||||
}
|
}
|
||||||
LOG_DB = db
|
LOG_DB = db
|
||||||
|
// If log DB is MySQL, also ensure Chinese-capable charset
|
||||||
|
if common.LogSqlType == common.DatabaseTypeMySQL {
|
||||||
|
if err := checkMySQLChineseSupport(LOG_DB); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
sqlDB, err := LOG_DB.DB()
|
sqlDB, err := LOG_DB.DB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -235,9 +247,6 @@ func InitLogDB() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func migrateDB() error {
|
func migrateDB() error {
|
||||||
if !common.UsingPostgreSQL {
|
|
||||||
return migrateDBFast()
|
|
||||||
}
|
|
||||||
err := DB.AutoMigrate(
|
err := DB.AutoMigrate(
|
||||||
&Channel{},
|
&Channel{},
|
||||||
&Token{},
|
&Token{},
|
||||||
@@ -250,7 +259,12 @@ func migrateDB() error {
|
|||||||
&TopUp{},
|
&TopUp{},
|
||||||
&QuotaData{},
|
&QuotaData{},
|
||||||
&Task{},
|
&Task{},
|
||||||
|
&Model{},
|
||||||
|
&Vendor{},
|
||||||
|
&PrefillGroup{},
|
||||||
&Setup{},
|
&Setup{},
|
||||||
|
&TwoFA{},
|
||||||
|
&TwoFABackupCode{},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -259,6 +273,7 @@ func migrateDB() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func migrateDBFast() error {
|
func migrateDBFast() error {
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
migrations := []struct {
|
migrations := []struct {
|
||||||
@@ -276,7 +291,12 @@ func migrateDBFast() error {
|
|||||||
{&TopUp{}, "TopUp"},
|
{&TopUp{}, "TopUp"},
|
||||||
{&QuotaData{}, "QuotaData"},
|
{&QuotaData{}, "QuotaData"},
|
||||||
{&Task{}, "Task"},
|
{&Task{}, "Task"},
|
||||||
|
{&Model{}, "Model"},
|
||||||
|
{&Vendor{}, "Vendor"},
|
||||||
|
{&PrefillGroup{}, "PrefillGroup"},
|
||||||
{&Setup{}, "Setup"},
|
{&Setup{}, "Setup"},
|
||||||
|
{&TwoFA{}, "TwoFA"},
|
||||||
|
{&TwoFABackupCode{}, "TwoFABackupCode"},
|
||||||
}
|
}
|
||||||
// 动态计算migration数量,确保errChan缓冲区足够大
|
// 动态计算migration数量,确保errChan缓冲区足够大
|
||||||
errChan := make(chan error, len(migrations))
|
errChan := make(chan error, len(migrations))
|
||||||
@@ -332,6 +352,98 @@ func CloseDB() error {
|
|||||||
return closeDB(DB)
|
return closeDB(DB)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkMySQLChineseSupport ensures the MySQL connection and current schema
|
||||||
|
// default charset/collation can store Chinese characters. It allows common
|
||||||
|
// Chinese-capable charsets (utf8mb4, utf8, gbk, big5, gb18030) and panics otherwise.
|
||||||
|
func checkMySQLChineseSupport(db *gorm.DB) error {
|
||||||
|
// 仅检测:当前库默认字符集/排序规则 + 各表的排序规则(隐含字符集)
|
||||||
|
|
||||||
|
// Read current schema defaults
|
||||||
|
var schemaCharset, schemaCollation string
|
||||||
|
err := db.Raw("SELECT DEFAULT_CHARACTER_SET_NAME, DEFAULT_COLLATION_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = DATABASE()").Row().Scan(&schemaCharset, &schemaCollation)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("读取当前库默认字符集/排序规则失败 / Failed to read schema default charset/collation: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
toLower := func(s string) string { return strings.ToLower(s) }
|
||||||
|
// Allowed charsets that can store Chinese text
|
||||||
|
allowedCharsets := map[string]string{
|
||||||
|
"utf8mb4": "utf8mb4_",
|
||||||
|
"utf8": "utf8_",
|
||||||
|
"gbk": "gbk_",
|
||||||
|
"big5": "big5_",
|
||||||
|
"gb18030": "gb18030_",
|
||||||
|
}
|
||||||
|
isChineseCapable := func(cs, cl string) bool {
|
||||||
|
csLower := toLower(cs)
|
||||||
|
clLower := toLower(cl)
|
||||||
|
if prefix, ok := allowedCharsets[csLower]; ok {
|
||||||
|
if clLower == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return strings.HasPrefix(clLower, prefix)
|
||||||
|
}
|
||||||
|
// 如果仅提供了排序规则,尝试按排序规则前缀判断
|
||||||
|
for _, prefix := range allowedCharsets {
|
||||||
|
if strings.HasPrefix(clLower, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1) 当前库默认值必须支持中文
|
||||||
|
if !isChineseCapable(schemaCharset, schemaCollation) {
|
||||||
|
return fmt.Errorf("当前库默认字符集/排序规则不支持中文:schema(%s/%s)。请将库设置为 utf8mb4/utf8/gbk/big5/gb18030 / Schema default charset/collation is not Chinese-capable: schema(%s/%s). Please set to utf8mb4/utf8/gbk/big5/gb18030",
|
||||||
|
schemaCharset, schemaCollation, schemaCharset, schemaCollation)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 所有物理表的排序规则(隐含字符集)必须支持中文
|
||||||
|
type tableInfo struct {
|
||||||
|
Name string
|
||||||
|
Collation *string
|
||||||
|
}
|
||||||
|
var tables []tableInfo
|
||||||
|
if err := db.Raw("SELECT TABLE_NAME, TABLE_COLLATION FROM information_schema.TABLES WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE'").Scan(&tables).Error; err != nil {
|
||||||
|
return fmt.Errorf("读取表排序规则失败 / Failed to read table collations: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var badTables []string
|
||||||
|
for _, t := range tables {
|
||||||
|
// NULL 或空表示继承库默认设置,已在上面校验库默认,视为通过
|
||||||
|
if t.Collation == nil || *t.Collation == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cl := *t.Collation
|
||||||
|
// 仅凭排序规则判断是否中文可用
|
||||||
|
ok := false
|
||||||
|
lower := strings.ToLower(cl)
|
||||||
|
for _, prefix := range allowedCharsets {
|
||||||
|
if strings.HasPrefix(lower, prefix) {
|
||||||
|
ok = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
badTables = append(badTables, fmt.Sprintf("%s(%s)", t.Name, cl))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(badTables) > 0 {
|
||||||
|
// 限制输出数量以避免日志过长
|
||||||
|
maxShow := 20
|
||||||
|
shown := badTables
|
||||||
|
if len(shown) > maxShow {
|
||||||
|
shown = shown[:maxShow]
|
||||||
|
}
|
||||||
|
return fmt.Errorf(
|
||||||
|
"存在不支持中文的表,请修复其排序规则/字符集。示例(最多展示 %d 项):%v / Found tables not Chinese-capable. Please fix their collation/charset. Examples (showing up to %d): %v",
|
||||||
|
maxShow, shown, maxShow, shown,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
lastPingTime time.Time
|
lastPingTime time.Time
|
||||||
pingMutex sync.Mutex
|
pingMutex sync.Mutex
|
||||||
|
|||||||
30
model/missing_models.go
Normal file
30
model/missing_models.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
// GetMissingModels returns model names that are referenced in the system
|
||||||
|
func GetMissingModels() ([]string, error) {
|
||||||
|
// 1. 获取所有已启用模型(去重)
|
||||||
|
models := GetEnabledModels()
|
||||||
|
if len(models) == 0 {
|
||||||
|
return []string{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 查询已有的元数据模型名
|
||||||
|
var existing []string
|
||||||
|
if err := DB.Model(&Model{}).Where("model_name IN ?", models).Pluck("model_name", &existing).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
existingSet := make(map[string]struct{}, len(existing))
|
||||||
|
for _, e := range existing {
|
||||||
|
existingSet[e] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 收集缺失模型
|
||||||
|
var missing []string
|
||||||
|
for _, name := range models {
|
||||||
|
if _, ok := existingSet[name]; !ok {
|
||||||
|
missing = append(missing, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return missing, nil
|
||||||
|
}
|
||||||
31
model/model_extra.go
Normal file
31
model/model_extra.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
func GetModelEnableGroups(modelName string) []string {
|
||||||
|
// 确保缓存最新
|
||||||
|
GetPricing()
|
||||||
|
|
||||||
|
if modelName == "" {
|
||||||
|
return make([]string, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelEnableGroupsLock.RLock()
|
||||||
|
groups, ok := modelEnableGroups[modelName]
|
||||||
|
modelEnableGroupsLock.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return make([]string, 0)
|
||||||
|
}
|
||||||
|
return groups
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelQuotaTypes 返回指定模型的计费类型集合(来自缓存)
|
||||||
|
func GetModelQuotaTypes(modelName string) []int {
|
||||||
|
GetPricing()
|
||||||
|
|
||||||
|
modelEnableGroupsLock.RLock()
|
||||||
|
quota, ok := modelQuotaTypeMap[modelName]
|
||||||
|
modelEnableGroupsLock.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return []int{}
|
||||||
|
}
|
||||||
|
return []int{quota}
|
||||||
|
}
|
||||||
147
model/model_meta.go
Normal file
147
model/model_meta.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/common"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
NameRuleExact = iota
|
||||||
|
NameRulePrefix
|
||||||
|
NameRuleContains
|
||||||
|
NameRuleSuffix
|
||||||
|
)
|
||||||
|
|
||||||
|
type BoundChannel struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type int `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
Id int `json:"id"`
|
||||||
|
ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name_delete_at,priority:1"`
|
||||||
|
Description string `json:"description,omitempty" gorm:"type:text"`
|
||||||
|
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
||||||
|
Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
|
||||||
|
VendorID int `json:"vendor_id,omitempty" gorm:"index"`
|
||||||
|
Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
|
||||||
|
Status int `json:"status" gorm:"default:1"`
|
||||||
|
SyncOfficial int `json:"sync_official" gorm:"default:1"`
|
||||||
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
|
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||||
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name_delete_at,priority:2"`
|
||||||
|
|
||||||
|
BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
|
||||||
|
EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
|
||||||
|
QuotaTypes []int `json:"quota_types,omitempty" gorm:"-"`
|
||||||
|
NameRule int `json:"name_rule" gorm:"default:0"`
|
||||||
|
|
||||||
|
MatchedModels []string `json:"matched_models,omitempty" gorm:"-"`
|
||||||
|
MatchedCount int `json:"matched_count,omitempty" gorm:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mi *Model) Insert() error {
|
||||||
|
now := common.GetTimestamp()
|
||||||
|
mi.CreatedTime = now
|
||||||
|
mi.UpdatedTime = now
|
||||||
|
return DB.Create(mi).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsModelNameDuplicated(id int, name string) (bool, error) {
|
||||||
|
if name == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
var cnt int64
|
||||||
|
err := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id).Count(&cnt).Error
|
||||||
|
return cnt > 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mi *Model) Update() error {
|
||||||
|
mi.UpdatedTime = common.GetTimestamp()
|
||||||
|
return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}).
|
||||||
|
Model(&Model{}).
|
||||||
|
Where("id = ?", mi.Id).
|
||||||
|
Omit("created_time").
|
||||||
|
Select("*").
|
||||||
|
Updates(mi).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mi *Model) Delete() error {
|
||||||
|
return DB.Delete(mi).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetVendorModelCounts() (map[int64]int64, error) {
|
||||||
|
var stats []struct {
|
||||||
|
VendorID int64
|
||||||
|
Count int64
|
||||||
|
}
|
||||||
|
if err := DB.Model(&Model{}).
|
||||||
|
Select("vendor_id as vendor_id, count(*) as count").
|
||||||
|
Group("vendor_id").
|
||||||
|
Scan(&stats).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m := make(map[int64]int64, len(stats))
|
||||||
|
for _, s := range stats {
|
||||||
|
m[s.VendorID] = s.Count
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAllModels(offset int, limit int) ([]*Model, error) {
|
||||||
|
var models []*Model
|
||||||
|
err := DB.Order("id DESC").Offset(offset).Limit(limit).Find(&models).Error
|
||||||
|
return models, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel, error) {
|
||||||
|
result := make(map[string][]BoundChannel)
|
||||||
|
if len(modelNames) == 0 {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
type row struct {
|
||||||
|
Model string
|
||||||
|
Name string
|
||||||
|
Type int
|
||||||
|
}
|
||||||
|
var rows []row
|
||||||
|
err := DB.Table("channels").
|
||||||
|
Select("abilities.model as model, channels.name as name, channels.type as type").
|
||||||
|
Joins("JOIN abilities ON abilities.channel_id = channels.id").
|
||||||
|
Where("abilities.model IN ? AND abilities.enabled = ?", modelNames, true).
|
||||||
|
Distinct().
|
||||||
|
Scan(&rows).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, r := range rows {
|
||||||
|
result[r.Model] = append(result[r.Model], BoundChannel{Name: r.Name, Type: r.Type})
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
|
||||||
|
var models []*Model
|
||||||
|
db := DB.Model(&Model{})
|
||||||
|
if keyword != "" {
|
||||||
|
like := "%" + keyword + "%"
|
||||||
|
db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like)
|
||||||
|
}
|
||||||
|
if vendor != "" {
|
||||||
|
if vid, err := strconv.Atoi(vendor); err == nil {
|
||||||
|
db = db.Where("models.vendor_id = ?", vid)
|
||||||
|
} else {
|
||||||
|
db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var total int64
|
||||||
|
if err := db.Count(&total).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
if err := db.Order("models.id DESC").Offset(offset).Limit(limit).Find(&models).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return models, total, nil
|
||||||
|
}
|
||||||
@@ -150,7 +150,7 @@ func loadOptionsFromDatabase() {
|
|||||||
for _, option := range options {
|
for _, option := range options {
|
||||||
err := updateOptionMap(option.Key, option.Value)
|
err := updateOptionMap(option.Key, option.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update option map: " + err.Error())
|
common.SysLog("failed to update option map: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -336,6 +336,8 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
common.LinuxDOClientId = value
|
common.LinuxDOClientId = value
|
||||||
case "LinuxDOClientSecret":
|
case "LinuxDOClientSecret":
|
||||||
common.LinuxDOClientSecret = value
|
common.LinuxDOClientSecret = value
|
||||||
|
case "LinuxDOMinimumTrustLevel":
|
||||||
|
common.LinuxDOMinimumTrustLevel, _ = strconv.Atoi(value)
|
||||||
case "Footer":
|
case "Footer":
|
||||||
common.Footer = value
|
common.Footer = value
|
||||||
case "SystemName":
|
case "SystemName":
|
||||||
|
|||||||
126
model/prefill_group.go
Normal file
126
model/prefill_group.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"one-api/common"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PrefillGroup 用于存储可复用的“组”信息,例如模型组、标签组、端点组等。
|
||||||
|
// Name 字段保持唯一,用于在前端下拉框中展示。
|
||||||
|
// Type 字段用于区分组的类别,可选值如:model、tag、endpoint。
|
||||||
|
// Items 字段使用 JSON 数组保存对应类型的字符串集合,示例:
|
||||||
|
// ["gpt-4o", "gpt-3.5-turbo"]
|
||||||
|
// 设计遵循 3NF,避免冗余,提供灵活扩展能力。
|
||||||
|
|
||||||
|
// JSONValue 基于 json.RawMessage 实现,支持从数据库的 []byte 和 string 两种类型读取
|
||||||
|
type JSONValue json.RawMessage
|
||||||
|
|
||||||
|
// Value 实现 driver.Valuer 接口,用于数据库写入
|
||||||
|
func (j JSONValue) Value() (driver.Value, error) {
|
||||||
|
if j == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return []byte(j), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan 实现 sql.Scanner 接口,兼容不同驱动返回的类型
|
||||||
|
func (j *JSONValue) Scan(value interface{}) error {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case nil:
|
||||||
|
*j = nil
|
||||||
|
return nil
|
||||||
|
case []byte:
|
||||||
|
// 拷贝底层字节,避免保留底层缓冲区
|
||||||
|
b := make([]byte, len(v))
|
||||||
|
copy(b, v)
|
||||||
|
*j = JSONValue(b)
|
||||||
|
return nil
|
||||||
|
case string:
|
||||||
|
*j = JSONValue([]byte(v))
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
// 其他类型尝试序列化为 JSON
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*j = JSONValue(b)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON 确保在对外编码时与 json.RawMessage 行为一致
|
||||||
|
func (j JSONValue) MarshalJSON() ([]byte, error) {
|
||||||
|
if j == nil {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
return j, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON 确保在对外解码时与 json.RawMessage 行为一致
|
||||||
|
func (j *JSONValue) UnmarshalJSON(data []byte) error {
|
||||||
|
if data == nil {
|
||||||
|
*j = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
b := make([]byte, len(data))
|
||||||
|
copy(b, data)
|
||||||
|
*j = JSONValue(b)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type PrefillGroup struct {
|
||||||
|
Id int `json:"id"`
|
||||||
|
Name string `json:"name" gorm:"size:64;not null;uniqueIndex:uk_prefill_name,where:deleted_at IS NULL"`
|
||||||
|
Type string `json:"type" gorm:"size:32;index;not null"`
|
||||||
|
Items JSONValue `json:"items" gorm:"type:json"`
|
||||||
|
Description string `json:"description,omitempty" gorm:"type:varchar(255)"`
|
||||||
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
|
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||||
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert 新建组
|
||||||
|
func (g *PrefillGroup) Insert() error {
|
||||||
|
now := common.GetTimestamp()
|
||||||
|
g.CreatedTime = now
|
||||||
|
g.UpdatedTime = now
|
||||||
|
return DB.Create(g).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsPrefillGroupNameDuplicated 检查组名称是否重复(排除自身 ID)
|
||||||
|
func IsPrefillGroupNameDuplicated(id int, name string) (bool, error) {
|
||||||
|
if name == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
var cnt int64
|
||||||
|
err := DB.Model(&PrefillGroup{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error
|
||||||
|
return cnt > 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新组
|
||||||
|
func (g *PrefillGroup) Update() error {
|
||||||
|
g.UpdatedTime = common.GetTimestamp()
|
||||||
|
return DB.Save(g).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByID 根据 ID 删除组
|
||||||
|
func DeletePrefillGroupByID(id int) error {
|
||||||
|
return DB.Delete(&PrefillGroup{}, id).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllPrefillGroups 获取全部组,可按类型过滤(为空则返回全部)
|
||||||
|
func GetAllPrefillGroups(groupType string) ([]*PrefillGroup, error) {
|
||||||
|
var groups []*PrefillGroup
|
||||||
|
query := DB.Model(&PrefillGroup{})
|
||||||
|
if groupType != "" {
|
||||||
|
query = query.Where("type = ?", groupType)
|
||||||
|
}
|
||||||
|
if err := query.Order("updated_time DESC").Find(&groups).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return groups, nil
|
||||||
|
}
|
||||||
203
model/pricing.go
203
model/pricing.go
@@ -1,7 +1,10 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/setting/ratio_setting"
|
"one-api/setting/ratio_setting"
|
||||||
@@ -12,6 +15,10 @@ import (
|
|||||||
|
|
||||||
type Pricing struct {
|
type Pricing struct {
|
||||||
ModelName string `json:"model_name"`
|
ModelName string `json:"model_name"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Icon string `json:"icon,omitempty"`
|
||||||
|
Tags string `json:"tags,omitempty"`
|
||||||
|
VendorID int `json:"vendor_id,omitempty"`
|
||||||
QuotaType int `json:"quota_type"`
|
QuotaType int `json:"quota_type"`
|
||||||
ModelRatio float64 `json:"model_ratio"`
|
ModelRatio float64 `json:"model_ratio"`
|
||||||
ModelPrice float64 `json:"model_price"`
|
ModelPrice float64 `json:"model_price"`
|
||||||
@@ -21,10 +28,24 @@ type Pricing struct {
|
|||||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PricingVendor struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Icon string `json:"icon,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
pricingMap []Pricing
|
pricingMap []Pricing
|
||||||
lastGetPricingTime time.Time
|
vendorsList []PricingVendor
|
||||||
updatePricingLock sync.Mutex
|
supportedEndpointMap map[string]common.EndpointInfo
|
||||||
|
lastGetPricingTime time.Time
|
||||||
|
updatePricingLock sync.Mutex
|
||||||
|
|
||||||
|
// 缓存映射:模型名 -> 启用分组 / 计费类型
|
||||||
|
modelEnableGroups = make(map[string][]string)
|
||||||
|
modelQuotaTypeMap = make(map[string]int)
|
||||||
|
modelEnableGroupsLock = sync.RWMutex{}
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -46,6 +67,15 @@ func GetPricing() []Pricing {
|
|||||||
return pricingMap
|
return pricingMap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetVendors 返回当前定价接口使用到的供应商信息
|
||||||
|
func GetVendors() []PricingVendor {
|
||||||
|
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||||
|
// 保证先刷新一次
|
||||||
|
GetPricing()
|
||||||
|
}
|
||||||
|
return vendorsList
|
||||||
|
}
|
||||||
|
|
||||||
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
|
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
|
||||||
if model == "" {
|
if model == "" {
|
||||||
return make([]constant.EndpointType, 0)
|
return make([]constant.EndpointType, 0)
|
||||||
@@ -62,9 +92,83 @@ func updatePricing() {
|
|||||||
//modelRatios := common.GetModelRatios()
|
//modelRatios := common.GetModelRatios()
|
||||||
enableAbilities, err := GetAllEnableAbilityWithChannels()
|
enableAbilities, err := GetAllEnableAbilityWithChannels()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 预加载模型元数据与供应商一次,避免循环查询
|
||||||
|
var allMeta []Model
|
||||||
|
_ = DB.Find(&allMeta).Error
|
||||||
|
metaMap := make(map[string]*Model)
|
||||||
|
prefixList := make([]*Model, 0)
|
||||||
|
suffixList := make([]*Model, 0)
|
||||||
|
containsList := make([]*Model, 0)
|
||||||
|
for i := range allMeta {
|
||||||
|
m := &allMeta[i]
|
||||||
|
if m.NameRule == NameRuleExact {
|
||||||
|
metaMap[m.ModelName] = m
|
||||||
|
} else {
|
||||||
|
switch m.NameRule {
|
||||||
|
case NameRulePrefix:
|
||||||
|
prefixList = append(prefixList, m)
|
||||||
|
case NameRuleSuffix:
|
||||||
|
suffixList = append(suffixList, m)
|
||||||
|
case NameRuleContains:
|
||||||
|
containsList = append(containsList, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将非精确规则模型匹配到 metaMap
|
||||||
|
for _, m := range prefixList {
|
||||||
|
for _, pricingModel := range enableAbilities {
|
||||||
|
if strings.HasPrefix(pricingModel.Model, m.ModelName) {
|
||||||
|
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||||||
|
metaMap[pricingModel.Model] = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, m := range suffixList {
|
||||||
|
for _, pricingModel := range enableAbilities {
|
||||||
|
if strings.HasSuffix(pricingModel.Model, m.ModelName) {
|
||||||
|
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||||||
|
metaMap[pricingModel.Model] = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, m := range containsList {
|
||||||
|
for _, pricingModel := range enableAbilities {
|
||||||
|
if strings.Contains(pricingModel.Model, m.ModelName) {
|
||||||
|
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||||||
|
metaMap[pricingModel.Model] = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 预加载供应商
|
||||||
|
var vendors []Vendor
|
||||||
|
_ = DB.Find(&vendors).Error
|
||||||
|
vendorMap := make(map[int]*Vendor)
|
||||||
|
for i := range vendors {
|
||||||
|
vendorMap[vendors[i].Id] = &vendors[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化默认供应商映射
|
||||||
|
initDefaultVendorMapping(metaMap, vendorMap, enableAbilities)
|
||||||
|
|
||||||
|
// 构建对前端友好的供应商列表
|
||||||
|
vendorsList = make([]PricingVendor, 0, len(vendorMap))
|
||||||
|
for _, v := range vendorMap {
|
||||||
|
vendorsList = append(vendorsList, PricingVendor{
|
||||||
|
ID: v.Id,
|
||||||
|
Name: v.Name,
|
||||||
|
Description: v.Description,
|
||||||
|
Icon: v.Icon,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
modelGroupsMap := make(map[string]*types.Set[string])
|
modelGroupsMap := make(map[string]*types.Set[string])
|
||||||
|
|
||||||
for _, ability := range enableAbilities {
|
for _, ability := range enableAbilities {
|
||||||
@@ -79,12 +183,9 @@ func updatePricing() {
|
|||||||
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
||||||
modelSupportEndpointsStr := make(map[string][]string)
|
modelSupportEndpointsStr := make(map[string][]string)
|
||||||
|
|
||||||
|
// 先根据已有能力填充原生端点
|
||||||
for _, ability := range enableAbilities {
|
for _, ability := range enableAbilities {
|
||||||
endpoints, ok := modelSupportEndpointsStr[ability.Model]
|
endpoints := modelSupportEndpointsStr[ability.Model]
|
||||||
if !ok {
|
|
||||||
endpoints = make([]string, 0)
|
|
||||||
modelSupportEndpointsStr[ability.Model] = endpoints
|
|
||||||
}
|
|
||||||
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
||||||
for _, channelType := range channelTypes {
|
for _, channelType := range channelTypes {
|
||||||
if !common.StringsContains(endpoints, string(channelType)) {
|
if !common.StringsContains(endpoints, string(channelType)) {
|
||||||
@@ -94,6 +195,23 @@ func updatePricing() {
|
|||||||
modelSupportEndpointsStr[ability.Model] = endpoints
|
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 再补充模型自定义端点
|
||||||
|
for modelName, meta := range metaMap {
|
||||||
|
if strings.TrimSpace(meta.Endpoints) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var raw map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
||||||
|
endpoints := modelSupportEndpointsStr[modelName]
|
||||||
|
for k := range raw {
|
||||||
|
if !common.StringsContains(endpoints, k) {
|
||||||
|
endpoints = append(endpoints, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelSupportEndpointsStr[modelName] = endpoints
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||||
for model, endpoints := range modelSupportEndpointsStr {
|
for model, endpoints := range modelSupportEndpointsStr {
|
||||||
supportedEndpoints := make([]constant.EndpointType, 0)
|
supportedEndpoints := make([]constant.EndpointType, 0)
|
||||||
@@ -104,6 +222,45 @@ func updatePricing() {
|
|||||||
modelSupportEndpointTypes[model] = supportedEndpoints
|
modelSupportEndpointTypes[model] = supportedEndpoints
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
|
||||||
|
supportedEndpointMap = make(map[string]common.EndpointInfo)
|
||||||
|
// 1. 默认端点
|
||||||
|
for _, endpoints := range modelSupportEndpointTypes {
|
||||||
|
for _, et := range endpoints {
|
||||||
|
if info, ok := common.GetDefaultEndpointInfo(et); ok {
|
||||||
|
if _, exists := supportedEndpointMap[string(et)]; !exists {
|
||||||
|
supportedEndpointMap[string(et)] = info
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 2. 自定义端点(models 表)覆盖默认
|
||||||
|
for _, meta := range metaMap {
|
||||||
|
if strings.TrimSpace(meta.Endpoints) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var raw map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
||||||
|
for k, v := range raw {
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
|
||||||
|
case map[string]interface{}:
|
||||||
|
ep := common.EndpointInfo{Method: "POST"}
|
||||||
|
if p, ok := val["path"].(string); ok {
|
||||||
|
ep.Path = p
|
||||||
|
}
|
||||||
|
if m, ok := val["method"].(string); ok {
|
||||||
|
ep.Method = strings.ToUpper(m)
|
||||||
|
}
|
||||||
|
supportedEndpointMap[k] = ep
|
||||||
|
default:
|
||||||
|
// ignore unsupported types
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pricingMap = make([]Pricing, 0)
|
pricingMap = make([]Pricing, 0)
|
||||||
for model, groups := range modelGroupsMap {
|
for model, groups := range modelGroupsMap {
|
||||||
pricing := Pricing{
|
pricing := Pricing{
|
||||||
@@ -111,6 +268,18 @@ func updatePricing() {
|
|||||||
EnableGroup: groups.Items(),
|
EnableGroup: groups.Items(),
|
||||||
SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 补充模型元数据(描述、标签、供应商、状态)
|
||||||
|
if meta, ok := metaMap[model]; ok {
|
||||||
|
// 若模型被禁用(status!=1),则直接跳过,不返回给前端
|
||||||
|
if meta.Status != 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pricing.Description = meta.Description
|
||||||
|
pricing.Icon = meta.Icon
|
||||||
|
pricing.Tags = meta.Tags
|
||||||
|
pricing.VendorID = meta.VendorID
|
||||||
|
}
|
||||||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||||||
if findPrice {
|
if findPrice {
|
||||||
pricing.ModelPrice = modelPrice
|
pricing.ModelPrice = modelPrice
|
||||||
@@ -123,5 +292,21 @@ func updatePricing() {
|
|||||||
}
|
}
|
||||||
pricingMap = append(pricingMap, pricing)
|
pricingMap = append(pricingMap, pricing)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 刷新缓存映射,供高并发快速查询
|
||||||
|
modelEnableGroupsLock.Lock()
|
||||||
|
modelEnableGroups = make(map[string][]string)
|
||||||
|
modelQuotaTypeMap = make(map[string]int)
|
||||||
|
for _, p := range pricingMap {
|
||||||
|
modelEnableGroups[p.ModelName] = p.EnableGroup
|
||||||
|
modelQuotaTypeMap[p.ModelName] = p.QuotaType
|
||||||
|
}
|
||||||
|
modelEnableGroupsLock.Unlock()
|
||||||
|
|
||||||
lastGetPricingTime = time.Now()
|
lastGetPricingTime = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSupportedEndpointMap 返回全局端点到路径的映射
|
||||||
|
func GetSupportedEndpointMap() map[string]common.EndpointInfo {
|
||||||
|
return supportedEndpointMap
|
||||||
|
}
|
||||||
|
|||||||
128
model/pricing_default.go
Normal file
128
model/pricing_default.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 简化的供应商映射规则
|
||||||
|
var defaultVendorRules = map[string]string{
|
||||||
|
"gpt": "OpenAI",
|
||||||
|
"dall-e": "OpenAI",
|
||||||
|
"whisper": "OpenAI",
|
||||||
|
"o1": "OpenAI",
|
||||||
|
"o3": "OpenAI",
|
||||||
|
"claude": "Anthropic",
|
||||||
|
"gemini": "Google",
|
||||||
|
"moonshot": "Moonshot",
|
||||||
|
"kimi": "Moonshot",
|
||||||
|
"chatglm": "智谱",
|
||||||
|
"glm-": "智谱",
|
||||||
|
"qwen": "阿里巴巴",
|
||||||
|
"deepseek": "DeepSeek",
|
||||||
|
"abab": "MiniMax",
|
||||||
|
"ernie": "百度",
|
||||||
|
"spark": "讯飞",
|
||||||
|
"hunyuan": "腾讯",
|
||||||
|
"command": "Cohere",
|
||||||
|
"@cf/": "Cloudflare",
|
||||||
|
"360": "360",
|
||||||
|
"yi": "零一万物",
|
||||||
|
"jina": "Jina",
|
||||||
|
"mistral": "Mistral",
|
||||||
|
"grok": "xAI",
|
||||||
|
"llama": "Meta",
|
||||||
|
"doubao": "字节跳动",
|
||||||
|
"kling": "快手",
|
||||||
|
"jimeng": "即梦",
|
||||||
|
"vidu": "Vidu",
|
||||||
|
}
|
||||||
|
|
||||||
|
// 供应商默认图标映射
|
||||||
|
var defaultVendorIcons = map[string]string{
|
||||||
|
"OpenAI": "OpenAI",
|
||||||
|
"Anthropic": "Claude.Color",
|
||||||
|
"Google": "Gemini.Color",
|
||||||
|
"Moonshot": "Moonshot",
|
||||||
|
"智谱": "Zhipu.Color",
|
||||||
|
"阿里巴巴": "Qwen.Color",
|
||||||
|
"DeepSeek": "DeepSeek.Color",
|
||||||
|
"MiniMax": "Minimax.Color",
|
||||||
|
"百度": "Wenxin.Color",
|
||||||
|
"讯飞": "Spark.Color",
|
||||||
|
"腾讯": "Hunyuan.Color",
|
||||||
|
"Cohere": "Cohere.Color",
|
||||||
|
"Cloudflare": "Cloudflare.Color",
|
||||||
|
"360": "Ai360.Color",
|
||||||
|
"零一万物": "Yi.Color",
|
||||||
|
"Jina": "Jina",
|
||||||
|
"Mistral": "Mistral.Color",
|
||||||
|
"xAI": "XAI",
|
||||||
|
"Meta": "Ollama",
|
||||||
|
"字节跳动": "Doubao.Color",
|
||||||
|
"快手": "Kling.Color",
|
||||||
|
"即梦": "Jimeng.Color",
|
||||||
|
"Vidu": "Vidu",
|
||||||
|
"微软": "AzureAI",
|
||||||
|
"Microsoft": "AzureAI",
|
||||||
|
"Azure": "AzureAI",
|
||||||
|
}
|
||||||
|
|
||||||
|
// initDefaultVendorMapping 简化的默认供应商映射
|
||||||
|
func initDefaultVendorMapping(metaMap map[string]*Model, vendorMap map[int]*Vendor, enableAbilities []AbilityWithChannel) {
|
||||||
|
for _, ability := range enableAbilities {
|
||||||
|
modelName := ability.Model
|
||||||
|
if _, exists := metaMap[modelName]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 匹配供应商
|
||||||
|
vendorID := 0
|
||||||
|
modelLower := strings.ToLower(modelName)
|
||||||
|
for pattern, vendorName := range defaultVendorRules {
|
||||||
|
if strings.Contains(modelLower, pattern) {
|
||||||
|
vendorID = getOrCreateVendor(vendorName, vendorMap)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建模型元数据
|
||||||
|
metaMap[modelName] = &Model{
|
||||||
|
ModelName: modelName,
|
||||||
|
VendorID: vendorID,
|
||||||
|
Status: 1,
|
||||||
|
NameRule: NameRuleExact,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找或创建供应商
|
||||||
|
func getOrCreateVendor(vendorName string, vendorMap map[int]*Vendor) int {
|
||||||
|
// 查找现有供应商
|
||||||
|
for id, vendor := range vendorMap {
|
||||||
|
if vendor.Name == vendorName {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建新供应商
|
||||||
|
newVendor := &Vendor{
|
||||||
|
Name: vendorName,
|
||||||
|
Status: 1,
|
||||||
|
Icon: getDefaultVendorIcon(vendorName),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := newVendor.Insert(); err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
vendorMap[newVendor.Id] = newVendor
|
||||||
|
return newVendor.Id
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取供应商默认图标
|
||||||
|
func getDefaultVendorIcon(vendorName string) string {
|
||||||
|
if icon, exists := defaultVendorIcons[vendorName]; exists {
|
||||||
|
return icon
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
14
model/pricing_refresh.go
Normal file
14
model/pricing_refresh.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
// RefreshPricing 强制立即重新计算与定价相关的缓存。
|
||||||
|
// 该方法用于需要最新数据的内部管理 API,
|
||||||
|
// 因此会绕过默认的 1 分钟延迟刷新。
|
||||||
|
func RefreshPricing() {
|
||||||
|
updatePricingLock.Lock()
|
||||||
|
defer updatePricingLock.Unlock()
|
||||||
|
|
||||||
|
modelSupportEndpointsLock.Lock()
|
||||||
|
defer modelSupportEndpointsLock.Unlock()
|
||||||
|
|
||||||
|
updatePricing()
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/logger"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -148,7 +149,7 @@ func Redeem(key string, userId int) (quota int, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.New("兑换失败," + err.Error())
|
return 0, errors.New("兑换失败," + err.Error())
|
||||||
}
|
}
|
||||||
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
|
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id))
|
||||||
return redemption.Quota, nil
|
return redemption.Quota, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ type SyncTaskQueryParams struct {
|
|||||||
UserIDs []int
|
UserIDs []int
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task {
|
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) *Task {
|
||||||
t := &Task{
|
t := &Task{
|
||||||
UserId: relayInfo.UserId,
|
UserId: relayInfo.UserId,
|
||||||
SubmitTime: time.Now().Unix(),
|
SubmitTime: time.Now().Unix(),
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
token.Status = common.TokenStatusExpired
|
token.Status = common.TokenStatusExpired
|
||||||
err := token.SelectUpdate()
|
err := token.SelectUpdate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update token status" + err.Error())
|
common.SysLog("failed to update token status" + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return token, errors.New("该令牌已过期")
|
return token, errors.New("该令牌已过期")
|
||||||
@@ -102,7 +102,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
token.Status = common.TokenStatusExhausted
|
token.Status = common.TokenStatusExhausted
|
||||||
err := token.SelectUpdate()
|
err := token.SelectUpdate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update token status" + err.Error())
|
common.SysLog("failed to update token status" + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
keyPrefix := key[:3]
|
keyPrefix := key[:3]
|
||||||
@@ -134,7 +134,7 @@ func GetTokenById(id int) (*Token, error) {
|
|||||||
if shouldUpdateRedis(true, err) {
|
if shouldUpdateRedis(true, err) {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
if err := cacheSetToken(token); err != nil {
|
if err := cacheSetToken(token); err != nil {
|
||||||
common.SysError("failed to update user status cache: " + err.Error())
|
common.SysLog("failed to update user status cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -147,7 +147,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
|
|||||||
if shouldUpdateRedis(fromDB, err) && token != nil {
|
if shouldUpdateRedis(fromDB, err) && token != nil {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
if err := cacheSetToken(*token); err != nil {
|
if err := cacheSetToken(*token); err != nil {
|
||||||
common.SysError("failed to update user status cache: " + err.Error())
|
common.SysLog("failed to update user status cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -178,7 +178,7 @@ func (token *Token) Update() (err error) {
|
|||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
err := cacheSetToken(*token)
|
err := cacheSetToken(*token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update token cache: " + err.Error())
|
common.SysLog("failed to update token cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -194,7 +194,7 @@ func (token *Token) SelectUpdate() (err error) {
|
|||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
err := cacheSetToken(*token)
|
err := cacheSetToken(*token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update token cache: " + err.Error())
|
common.SysLog("failed to update token cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -209,7 +209,7 @@ func (token *Token) Delete() (err error) {
|
|||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
err := cacheDeleteToken(token.Key)
|
err := cacheDeleteToken(token.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to delete token cache: " + err.Error())
|
common.SysLog("failed to delete token cache: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -269,7 +269,7 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) {
|
|||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
err := cacheIncrTokenQuota(key, int64(quota))
|
err := cacheIncrTokenQuota(key, int64(quota))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to increase token quota: " + err.Error())
|
common.SysLog("failed to increase token quota: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -299,7 +299,7 @@ func DecreaseTokenQuota(id int, key string, quota int) (err error) {
|
|||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
err := cacheDecrTokenQuota(key, int64(quota))
|
err := cacheDecrTokenQuota(key, int64(quota))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to decrease token quota: " + err.Error())
|
common.SysLog("failed to decrease token quota: " + err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/logger"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -94,7 +95,7 @@ func Recharge(referenceId string, customerId string) (err error) {
|
|||||||
return errors.New("充值失败," + err.Error())
|
return errors.New("充值失败," + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.FormatQuota(int(quota)), topUp.Amount))
|
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
322
model/twofa.go
Normal file
322
model/twofa.go
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"one-api/common"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrTwoFANotEnabled = errors.New("用户未启用2FA")
|
||||||
|
|
||||||
|
// TwoFA 用户2FA设置表
|
||||||
|
type TwoFA struct {
|
||||||
|
Id int `json:"id" gorm:"primaryKey"`
|
||||||
|
UserId int `json:"user_id" gorm:"unique;not null;index"`
|
||||||
|
Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥,不返回给前端
|
||||||
|
IsEnabled bool `json:"is_enabled"`
|
||||||
|
FailedAttempts int `json:"failed_attempts" gorm:"default:0"`
|
||||||
|
LockedUntil *time.Time `json:"locked_until,omitempty"`
|
||||||
|
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TwoFABackupCode 备用码使用记录表
|
||||||
|
type TwoFABackupCode struct {
|
||||||
|
Id int `json:"id" gorm:"primaryKey"`
|
||||||
|
UserId int `json:"user_id" gorm:"not null;index"`
|
||||||
|
CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希
|
||||||
|
IsUsed bool `json:"is_used"`
|
||||||
|
UsedAt *time.Time `json:"used_at,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTwoFAByUserId 根据用户ID获取2FA设置
|
||||||
|
func GetTwoFAByUserId(userId int) (*TwoFA, error) {
|
||||||
|
if userId == 0 {
|
||||||
|
return nil, errors.New("用户ID不能为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
var twoFA TwoFA
|
||||||
|
err := DB.Where("user_id = ?", userId).First(&twoFA).Error
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil // 返回nil表示未设置2FA
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &twoFA, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTwoFAEnabled 检查用户是否启用了2FA
|
||||||
|
func IsTwoFAEnabled(userId int) bool {
|
||||||
|
twoFA, err := GetTwoFAByUserId(userId)
|
||||||
|
if err != nil || twoFA == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return twoFA.IsEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTwoFA 创建2FA设置
|
||||||
|
func (t *TwoFA) Create() error {
|
||||||
|
// 检查用户是否已存在2FA设置
|
||||||
|
existing, err := GetTwoFAByUserId(t.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if existing != nil {
|
||||||
|
return errors.New("用户已存在2FA设置")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证用户存在
|
||||||
|
var user User
|
||||||
|
if err := DB.First(&user, t.UserId).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return errors.New("用户不存在")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return DB.Create(t).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新2FA设置
|
||||||
|
func (t *TwoFA) Update() error {
|
||||||
|
if t.Id == 0 {
|
||||||
|
return errors.New("2FA记录ID不能为空")
|
||||||
|
}
|
||||||
|
return DB.Save(t).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除2FA设置
|
||||||
|
func (t *TwoFA) Delete() error {
|
||||||
|
if t.Id == 0 {
|
||||||
|
return errors.New("2FA记录ID不能为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用事务确保原子性
|
||||||
|
return DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 同时删除相关的备用码记录(硬删除)
|
||||||
|
if err := tx.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 硬删除2FA记录
|
||||||
|
return tx.Unscoped().Delete(t).Error
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetFailedAttempts 重置失败尝试次数
|
||||||
|
func (t *TwoFA) ResetFailedAttempts() error {
|
||||||
|
t.FailedAttempts = 0
|
||||||
|
t.LockedUntil = nil
|
||||||
|
return t.Update()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementFailedAttempts 增加失败尝试次数
|
||||||
|
func (t *TwoFA) IncrementFailedAttempts() error {
|
||||||
|
t.FailedAttempts++
|
||||||
|
|
||||||
|
// 检查是否需要锁定
|
||||||
|
if t.FailedAttempts >= common.MaxFailAttempts {
|
||||||
|
lockUntil := time.Now().Add(time.Duration(common.LockoutDuration) * time.Second)
|
||||||
|
t.LockedUntil = &lockUntil
|
||||||
|
}
|
||||||
|
|
||||||
|
return t.Update()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLocked 检查账户是否被锁定
|
||||||
|
func (t *TwoFA) IsLocked() bool {
|
||||||
|
if t.LockedUntil == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return time.Now().Before(*t.LockedUntil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateBackupCodes 创建备用码
|
||||||
|
func CreateBackupCodes(userId int, codes []string) error {
|
||||||
|
return DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 先删除现有的备用码
|
||||||
|
if err := tx.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建新的备用码记录
|
||||||
|
for _, code := range codes {
|
||||||
|
hashedCode, err := common.HashBackupCode(code)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
backupCode := TwoFABackupCode{
|
||||||
|
UserId: userId,
|
||||||
|
CodeHash: hashedCode,
|
||||||
|
IsUsed: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Create(&backupCode).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateBackupCode 验证并使用备用码
|
||||||
|
func ValidateBackupCode(userId int, code string) (bool, error) {
|
||||||
|
if !common.ValidateBackupCode(code) {
|
||||||
|
return false, errors.New("验证码或备用码不正确")
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedCode := common.NormalizeBackupCode(code)
|
||||||
|
|
||||||
|
// 查找未使用的备用码
|
||||||
|
var backupCodes []TwoFABackupCode
|
||||||
|
if err := DB.Where("user_id = ? AND is_used = false", userId).Find(&backupCodes).Error; err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证备用码
|
||||||
|
for _, bc := range backupCodes {
|
||||||
|
if common.ValidatePasswordAndHash(normalizedCode, bc.CodeHash) {
|
||||||
|
// 标记为已使用
|
||||||
|
now := time.Now()
|
||||||
|
bc.IsUsed = true
|
||||||
|
bc.UsedAt = &now
|
||||||
|
|
||||||
|
if err := DB.Save(&bc).Error; err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnusedBackupCodeCount 获取未使用的备用码数量
|
||||||
|
func GetUnusedBackupCodeCount(userId int) (int, error) {
|
||||||
|
var count int64
|
||||||
|
err := DB.Model(&TwoFABackupCode{}).Where("user_id = ? AND is_used = false", userId).Count(&count).Error
|
||||||
|
return int(count), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableTwoFA 禁用用户的2FA
|
||||||
|
func DisableTwoFA(userId int) error {
|
||||||
|
twoFA, err := GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if twoFA == nil {
|
||||||
|
return ErrTwoFANotEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除2FA设置和备用码
|
||||||
|
return twoFA.Delete()
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableTwoFA 启用2FA
|
||||||
|
func (t *TwoFA) Enable() error {
|
||||||
|
t.IsEnabled = true
|
||||||
|
t.FailedAttempts = 0
|
||||||
|
t.LockedUntil = nil
|
||||||
|
return t.Update()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateTOTPAndUpdateUsage 验证TOTP并更新使用记录
|
||||||
|
func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
|
||||||
|
// 检查是否被锁定
|
||||||
|
if t.IsLocked() {
|
||||||
|
return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证TOTP码
|
||||||
|
if !common.ValidateTOTPCode(t.Secret, code) {
|
||||||
|
// 增加失败次数
|
||||||
|
if err := t.IncrementFailedAttempts(); err != nil {
|
||||||
|
common.SysLog("更新2FA失败次数失败: " + err.Error())
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证成功,重置失败次数并更新最后使用时间
|
||||||
|
now := time.Now()
|
||||||
|
t.FailedAttempts = 0
|
||||||
|
t.LockedUntil = nil
|
||||||
|
t.LastUsedAt = &now
|
||||||
|
|
||||||
|
if err := t.Update(); err != nil {
|
||||||
|
common.SysLog("更新2FA使用记录失败: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateBackupCodeAndUpdateUsage 验证备用码并更新使用记录
|
||||||
|
func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
|
||||||
|
// 检查是否被锁定
|
||||||
|
if t.IsLocked() {
|
||||||
|
return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证备用码
|
||||||
|
valid, err := ValidateBackupCode(t.UserId, code)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valid {
|
||||||
|
// 增加失败次数
|
||||||
|
if err := t.IncrementFailedAttempts(); err != nil {
|
||||||
|
common.SysLog("更新2FA失败次数失败: " + err.Error())
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证成功,重置失败次数并更新最后使用时间
|
||||||
|
now := time.Now()
|
||||||
|
t.FailedAttempts = 0
|
||||||
|
t.LockedUntil = nil
|
||||||
|
t.LastUsedAt = &now
|
||||||
|
|
||||||
|
if err := t.Update(); err != nil {
|
||||||
|
common.SysLog("更新2FA使用记录失败: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTwoFAStats 获取2FA统计信息(管理员使用)
|
||||||
|
func GetTwoFAStats() (map[string]interface{}, error) {
|
||||||
|
var totalUsers, enabledUsers int64
|
||||||
|
|
||||||
|
// 总用户数
|
||||||
|
if err := DB.Model(&User{}).Count(&totalUsers).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启用2FA的用户数
|
||||||
|
if err := DB.Model(&TwoFA{}).Where("is_enabled = true").Count(&enabledUsers).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
enabledRate := float64(0)
|
||||||
|
if totalUsers > 0 {
|
||||||
|
enabledRate = float64(enabledUsers) / float64(totalUsers) * 100
|
||||||
|
}
|
||||||
|
|
||||||
|
return map[string]interface{}{
|
||||||
|
"total_users": totalUsers,
|
||||||
|
"enabled_users": enabledUsers,
|
||||||
|
"enabled_rate": fmt.Sprintf("%.1f%%", enabledRate),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -21,12 +21,6 @@ type QuotaData struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateQuotaData() {
|
func UpdateQuotaData() {
|
||||||
// recover
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
for {
|
for {
|
||||||
if common.DataExportEnabled {
|
if common.DataExportEnabled {
|
||||||
common.SysLog("正在更新数据看板数据...")
|
common.SysLog("正在更新数据看板数据...")
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user