mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 23:34:13 +00:00
Compare commits
295 Commits
jsrt
...
v0.8.8.3.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
faaa5a2949 | ||
|
|
c00f5a17c8 | ||
|
|
9c079d04a8 | ||
|
|
c9d4cdc57e | ||
|
|
12b4e80d4b | ||
|
|
6e2a04f374 | ||
|
|
8357b15fec | ||
|
|
ecdd9d1ccb | ||
|
|
10b04416c1 | ||
|
|
398ae7156b | ||
|
|
d85eeabf11 | ||
|
|
c056a7ad7c | ||
|
|
c784a70277 | ||
|
|
e6c87907d5 | ||
|
|
71e9290142 | ||
|
|
74ec34da67 | ||
|
|
7188749cb3 | ||
|
|
c28add55db | ||
|
|
78f34a8245 | ||
|
|
97d6f10f15 | ||
|
|
afefc4caca | ||
|
|
6abbd036f8 | ||
|
|
ef0db0f914 | ||
|
|
e01986fdd4 | ||
|
|
a0c6ebe2d8 | ||
|
|
d2183af23f | ||
|
|
953f1bdc3c | ||
|
|
e2429f20f8 | ||
|
|
f0945da4fb | ||
|
|
8df3de9ae5 | ||
|
|
277cc1cac8 | ||
|
|
07a92293e4 | ||
|
|
f995e31d04 | ||
|
|
9758a9e60d | ||
|
|
6f56696af2 | ||
|
|
345fbdf3d2 | ||
|
|
ce031f7d15 | ||
|
|
bd6b811183 | ||
|
|
196bafff03 | ||
|
|
f20b558e22 | ||
|
|
54447bf227 | ||
|
|
fc09051d8b | ||
|
|
1f5ef24ecd | ||
|
|
b1faf42529 | ||
|
|
6a85206e32 | ||
|
|
e3d3e697d3 | ||
|
|
db9b333930 | ||
|
|
f7b284ad73 | ||
|
|
e1970e8a66 | ||
|
|
0cd93d67ff | ||
|
|
6e806e21bd | ||
|
|
a8462c1b70 | ||
|
|
706ea8b649 | ||
|
|
95d46d1dfc | ||
|
|
010f27678d | ||
|
|
d87117a2cf | ||
|
|
4ed92a94a1 | ||
|
|
821ea34a3c | ||
|
|
ecb3d01376 | ||
|
|
e322ed4f05 | ||
|
|
bcf7e78665 | ||
|
|
0cb2bb2ea7 | ||
|
|
c5d97597c4 | ||
|
|
fe9acb6c59 | ||
|
|
bca78beb1b | ||
|
|
a8a42cbfa8 | ||
|
|
19df2ac234 | ||
|
|
e7524c85c2 | ||
|
|
a4356727e9 | ||
|
|
f15a53fae4 | ||
|
|
8e3cf2eaab | ||
|
|
c51ec3135b | ||
|
|
2469c439b1 | ||
|
|
1297addfb1 | ||
|
|
d6cbf43373 | ||
|
|
df647e7b42 | ||
|
|
fe16d05fbb | ||
|
|
1430c05b6c | ||
|
|
b25841e50d | ||
|
|
b704fc9254 | ||
|
|
352da66bd1 | ||
|
|
8205ad2cd0 | ||
|
|
e162b9c169 | ||
|
|
77e3502028 | ||
|
|
ae0461692c | ||
|
|
13bdb80958 | ||
|
|
6f74e7b738 | ||
|
|
eaee89f77a | ||
|
|
756a8c50d6 | ||
|
|
e0b859dbbe | ||
|
|
07b64ff1a4 | ||
|
|
7bc9192f3f | ||
|
|
057e551059 | ||
|
|
2f80c814aa | ||
|
|
136a029bb4 | ||
|
|
d4b32a403b | ||
|
|
722b187f83 | ||
|
|
0c5c5823bf | ||
|
|
f5a6b7d1f0 | ||
|
|
bcd236286c | ||
|
|
6c4ada5098 | ||
|
|
2402715492 | ||
|
|
f32cf02714 | ||
|
|
e224ee5498 | ||
|
|
90011aa0c9 | ||
|
|
d0589468c1 | ||
|
|
6ef5acbfe5 | ||
|
|
efe894cad6 | ||
|
|
2a366c176d | ||
|
|
8e280a6a24 | ||
|
|
f144518e0e | ||
|
|
fcc006ecd3 | ||
|
|
5fbadc6b21 | ||
|
|
7902570855 | ||
|
|
55898780f1 | ||
|
|
d16cb90c2f | ||
|
|
66dd514c56 | ||
|
|
ba40748118 | ||
|
|
3538cefe68 | ||
|
|
f77aef82d2 | ||
|
|
4d0037a40c | ||
|
|
fd7a4461cc | ||
|
|
8bc6ddbca8 | ||
|
|
7d50e432b5 | ||
|
|
6103888610 | ||
|
|
4d8189f21b | ||
|
|
cddb778577 | ||
|
|
fa506ec04f | ||
|
|
0eaeef5723 | ||
|
|
f87054895e | ||
|
|
d74a5bd507 | ||
|
|
b5d4535db6 | ||
|
|
4d7562fd79 | ||
|
|
5b869376ab | ||
|
|
19c522d9bc | ||
|
|
1d4ecad134 | ||
|
|
805464e406 | ||
|
|
c674c3561a | ||
|
|
7aa2972c3f | ||
|
|
986558fea7 | ||
|
|
818e34682c | ||
|
|
252fddf3de | ||
|
|
39079e7aff | ||
|
|
1fa4518bb9 | ||
|
|
1b739e87ae | ||
|
|
e944983567 | ||
|
|
4fccaf3284 | ||
|
|
0a79dc9ecc | ||
|
|
847a8c8c4d | ||
|
|
a1018c5823 | ||
|
|
323417182a | ||
|
|
f3bcf570f4 | ||
|
|
a3059597fb | ||
|
|
d19a6914f9 | ||
|
|
4313ede132 | ||
|
|
635bfd4aba | ||
|
|
38e72e1af7 | ||
|
|
26644bfd1e | ||
|
|
6a827fc7b9 | ||
|
|
3b3ae9c0dd | ||
|
|
301909e3e5 | ||
|
|
97a9c8627c | ||
|
|
56c1fbecea | ||
|
|
de9d18a2fe | ||
|
|
be16ad26b5 | ||
|
|
d762da9141 | ||
|
|
c05d6f7cdf | ||
|
|
7af3fb5ae4 | ||
|
|
3ac54b2178 | ||
|
|
42a26f076a | ||
|
|
3b67759730 | ||
|
|
5407a8345f | ||
|
|
3fe509757b | ||
|
|
952b679ca3 | ||
|
|
6799daacd1 | ||
|
|
fa02b5150c | ||
|
|
63a1904242 | ||
|
|
1e3450fdcb | ||
|
|
5541026b86 | ||
|
|
c36c920b34 | ||
|
|
514fea65c4 | ||
|
|
e269b3bfdd | ||
|
|
0862a9bfa7 | ||
|
|
f43c695527 | ||
|
|
ead43f081c | ||
|
|
4e2a3d61dc | ||
|
|
218ad6bbe0 | ||
|
|
b485f2e42e | ||
|
|
16e32c3f67 | ||
|
|
15f65bb558 | ||
|
|
b161d6831f | ||
|
|
969953039f | ||
|
|
f1506ed5da | ||
|
|
9a239d9e13 | ||
|
|
a5da09dfb9 | ||
|
|
6f81f2d143 | ||
|
|
0b877ca8a3 | ||
|
|
2911b9cd04 | ||
|
|
6b3f1ab0e4 | ||
|
|
2c15655b08 | ||
|
|
afa9c650fe | ||
|
|
28d8d82ded | ||
|
|
a100baf57f | ||
|
|
5621755655 | ||
|
|
d892bfc278 | ||
|
|
4369b18fbf | ||
|
|
fb9b5d31e8 | ||
|
|
3bf0748389 | ||
|
|
cf46b89814 | ||
|
|
3360b34af9 | ||
|
|
4558eb41fc | ||
|
|
bbc5584f80 | ||
|
|
8604c9f9d5 | ||
|
|
747e02ee0d | ||
|
|
8b0334309b | ||
|
|
48afa821e4 | ||
|
|
42a8d3e3dc | ||
|
|
a44fc51007 | ||
|
|
961bc874d2 | ||
|
|
b2b018ab93 | ||
|
|
77da33de4f | ||
|
|
06ad5e3f8c | ||
|
|
9326bf96fc | ||
|
|
bed73102b4 | ||
|
|
eb59f9c75d | ||
|
|
f3bd2ed472 | ||
|
|
456475d593 | ||
|
|
a36ce199ba | ||
|
|
b7c3ad0867 | ||
|
|
ea3545cc7e | ||
|
|
232ba46b16 | ||
|
|
5f011502d1 | ||
|
|
93b6f1066b | ||
|
|
52fe92ed7f | ||
|
|
0d005df463 | ||
|
|
e3ef3ace29 | ||
|
|
a203e98689 | ||
|
|
27f99a0f38 | ||
|
|
d1e48d02bd | ||
|
|
4f06a1df50 | ||
|
|
2d7ae1180f | ||
|
|
75b486b467 | ||
|
|
5b5f10fe93 | ||
|
|
5f654e76e2 | ||
|
|
aa8d112c58 | ||
|
|
e82dc0e841 | ||
|
|
dd741fc38a | ||
|
|
120e4ee92f | ||
|
|
9d2a56bff4 | ||
|
|
31d82a3169 | ||
|
|
d22ee5d451 | ||
|
|
203edaed50 | ||
|
|
93b5638a9c | ||
|
|
52a5e58f0c | ||
|
|
20607b0b5c | ||
|
|
6bebfe9e54 | ||
|
|
50b76f4466 | ||
|
|
23e4e25e9a | ||
|
|
5b83d478d6 | ||
|
|
dca38d01d6 | ||
|
|
0a434d3b3a | ||
|
|
7c4b83a430 | ||
|
|
b7f24b428b | ||
|
|
22a0ed0ee2 | ||
|
|
cf711d55a5 | ||
|
|
26ea562fdb | ||
|
|
efce0c6c57 | ||
|
|
a3768dae97 | ||
|
|
85efea3fb8 | ||
|
|
c820fda26d | ||
|
|
4740293640 | ||
|
|
8be8813cd8 | ||
|
|
8cc747ef22 | ||
|
|
d6ed2ab3e0 | ||
|
|
e8ae980104 | ||
|
|
cd8c23c0ab | ||
|
|
3568042cd9 | ||
|
|
7443129e18 | ||
|
|
a9e03e6172 | ||
|
|
cb16bf552e | ||
|
|
98952198bb | ||
|
|
338e914a60 | ||
|
|
4196a3db5a | ||
|
|
0e6b608f91 | ||
|
|
f1856fe4d2 | ||
|
|
870cdd5a56 | ||
|
|
f0f277dc2a | ||
|
|
b695e67154 | ||
|
|
fa2cd85007 | ||
|
|
4a8b7bfa37 | ||
|
|
7403df7e9c | ||
|
|
617c8e8f4f | ||
|
|
aa793088ed | ||
|
|
0089157b83 | ||
|
|
b887db474e |
@@ -4,4 +4,5 @@
|
||||
.vscode
|
||||
.gitignore
|
||||
Makefile
|
||||
docs
|
||||
docs
|
||||
.eslintcache
|
||||
11
.env.example
11
.env.example
@@ -73,14 +73,3 @@
|
||||
# 节点类型
|
||||
# 如果是主节点则为master
|
||||
# NODE_TYPE=master
|
||||
|
||||
|
||||
# JavaScript 运行时配置
|
||||
# 是否启用(默认:false)
|
||||
# JS_RUNTIME_ENABLED=true
|
||||
# 最大虚拟机数量(默认:8)
|
||||
# JS_MAX_VM_COUNT=
|
||||
# 运行超时时间(单位:秒,默认:5)
|
||||
# JS_SCRIPT_TIMEOUT=
|
||||
# 脚本文件夹(默认:scripts/)
|
||||
# JS_SCRIPT_PATH=
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -10,4 +10,5 @@ web/dist
|
||||
.env
|
||||
one-api
|
||||
.DS_Store
|
||||
tiktoken_cache
|
||||
tiktoken_cache
|
||||
.eslintcache
|
||||
@@ -2,6 +2,7 @@ FROM oven/bun:latest AS builder
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/package.json .
|
||||
COPY web/bun.lock .
|
||||
RUN bun install
|
||||
COPY ./web .
|
||||
COPY ./VERSION .
|
||||
|
||||
240
LICENSE
240
LICENSE
@@ -1,201 +1,103 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
# **New API 许可协议 (Licensing)**
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
本项目采用**基于使用场景的双重许可 (Usage-Based Dual Licensing)** 模式。
|
||||
|
||||
1. Definitions.
|
||||
**核心原则:**
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
- **默认许可:** 本项目默认在 **GNU Affero 通用公共许可证 v3.0 (AGPLv3)** 下提供。任何用户在遵守 AGPLv3 条款和下述附加限制的前提下,均可免费使用。
|
||||
- **商业许可:** 在特定商业场景下,或当您希望获得 AGPLv3 之外的权利时,**必须**获取**商业许可证 (Commercial License)**。
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
---
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
## **1. 开源许可证 (Open Source License): AGPLv3 - 适用于基础使用**
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
- 在遵守 **AGPLv3** 条款的前提下,您可以自由地使用、修改和分发 New API。AGPLv3 的完整文本可以访问 [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html) 获取。
|
||||
- **核心义务:** AGPLv3 的一个关键要求是,如果您修改了 New API 并通过网络提供服务 (SaaS),或者分发了修改后的版本,您必须以 AGPLv3 许可证向所有用户提供相应的**完整源代码**。
|
||||
- **附加限制 (重要):** 在仅使用 AGPLv3 开源许可证的情况下,您**必须**完整保留项目代码中原有的品牌标识、LOGO 及版权声明信息。**禁止以任何形式修改、移除或遮盖**这些信息。如需移除,必须获取商业许可证。
|
||||
- 使用前请务必仔细阅读并理解 AGPLv3 的所有条款及上述附加限制。
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
## **2. 商业许可证 (Commercial License) - 适用于高级场景及闭源需求**
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
在以下任一情况下,您**必须**联系我们获取并签署一份商业许可证,才能合法使用 New API:
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
- **场景一:移除品牌和版权信息**
|
||||
您希望在您的产品或服务中移除 New API 的 LOGO、UI界面中的版权声明或其他品牌标识。
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
- **场景二:规避 AGPLv3 开源义务**
|
||||
您基于 New API 进行了修改,并希望:
|
||||
- 通过网络提供服务(SaaS),但**不希望**向您的服务用户公开您修改后的源代码。
|
||||
- 分发一个集成了 New API 的软件产品,但**不希望**以 AGPLv3 许可证发布您的产品或公开源代码。
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
- **场景三:企业政策与集成需求**
|
||||
- 您所在公司的政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件。
|
||||
- 您需要进行 OEM 集成,将 New API 作为您闭源商业产品的一部分进行再分发。
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
- **场景四:需要商业支持与保障**
|
||||
您需要 AGPLv3 未提供的商业保障,如官方技术支持等。
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
**获取商业许可:**
|
||||
请通过电子邮件 **support@quantumnous.com** 联系 New API 团队洽谈商业授权事宜。
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
## **3. 贡献 (Contributions)**
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
- 我们欢迎社区对 New API 的贡献。所有向本项目提交的贡献(例如通过 Pull Request)都将被视为在 **AGPLv3** 许可证下提供。
|
||||
- 通过向本项目提交贡献,即表示您同意您的代码以 AGPLv3 许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。
|
||||
- 您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 New API 版本中。
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
## **4. 其他条款 (Other Terms)**
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
- 关于商业许可证的具体条款、条件和价格,以双方签署的正式商业许可协议为准。
|
||||
- 项目维护者保留根据需要更新本许可政策的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
---
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
# **New API Licensing**
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
This project uses a **Usage-Based Dual Licensing** model.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
**Core Principles:**
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
- **Default License:** This project is available by default under the **GNU Affero General Public License v3.0 (AGPLv3)**. Any user may use it free of charge, provided they comply with both the AGPLv3 terms and the additional restrictions listed below.
|
||||
- **Commercial License:** For specific commercial scenarios, or if you require rights beyond those granted by AGPLv3, you **must** obtain a **Commercial License**.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
---
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
## **1. Open Source License: AGPLv3 – For Basic Usage**
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
- Under the terms of the **AGPLv3**, you are free to use, modify, and distribute New API. The complete AGPLv3 license text can be viewed at [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html).
|
||||
- **Core Obligation:** A key AGPLv3 requirement is that if you modify New API and provide it as a network service (SaaS), or distribute a modified version, you must make the **complete corresponding source code** available to all users under the AGPLv3 license.
|
||||
- **Additional Restriction (Important):** When using only the AGPLv3 open-source license, you **must** retain all original branding, logos, and copyright statements within the project’s code. **You are strictly prohibited from modifying, removing, or concealing** any such information. If you wish to remove this, you must obtain a Commercial License.
|
||||
- Please read and ensure that you fully understand all AGPLv3 terms and the above additional restriction before use.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
## **2. Commercial License – For Advanced Scenarios & Closed Source Needs**
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
You **must** contact us to obtain and sign a Commercial License in any of the following scenarios in order to legally use New API:
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
- **Scenario 1: Removal of Branding and Copyright**
|
||||
You wish to remove the New API logo, copyright statement, or other branding elements from your product or service.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
- **Scenario 2: Avoidance of AGPLv3 Open Source Obligations**
|
||||
You have modified New API and wish to:
|
||||
- Offer it as a network service (SaaS) **without** disclosing your modifications' source code to your users.
|
||||
- Distribute a software product integrated with New API **without** releasing your product under AGPLv3 or open-sourcing the code.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
- **Scenario 3: Enterprise Policy & Integration Needs**
|
||||
- Your organization’s policies, client contracts, or project requirements prohibit the use of AGPLv3-licensed software.
|
||||
- You require OEM integration and need to redistribute New API as part of your closed-source commercial product.
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
- **Scenario 4: Commercial Support and Assurances**
|
||||
You require commercial assurances not provided by AGPLv3, such as official technical support.
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
**Obtaining a Commercial License:**
|
||||
Please contact the New API team via email at **support@quantumnous.com** to discuss commercial licensing.
|
||||
|
||||
## **3. Contributions**
|
||||
|
||||
- We welcome community contributions to New API. All contributions (e.g., via Pull Request) are deemed to be provided under the **AGPLv3** license.
|
||||
- By submitting a contribution, you agree that your code is licensed to this project and all downstream users under the AGPLv3 license (regardless of whether those users ultimately operate under AGPLv3 or a Commercial License).
|
||||
- You also acknowledge and agree that your contribution may be included in New API releases distributed under a Commercial License.
|
||||
|
||||
## **4. Other Terms**
|
||||
|
||||
- The specific terms, conditions, and pricing of the Commercial License are governed by the formal commercial license agreement executed by both parties.
|
||||
- Project maintainers reserve the right to update this licensing policy as needed. Updates will be communicated via official project channels (e.g., repository, official website).
|
||||
|
||||
18
README.en.md
18
README.en.md
@@ -189,6 +189,24 @@ If you have any questions, please refer to [Help and Support](https://docs.newap
|
||||
- [Issue Feedback](https://docs.newapi.pro/support/feedback-issues)
|
||||
- [FAQ](https://docs.newapi.pro/support/faq)
|
||||
|
||||
## 🤝 Trusted Partners
|
||||
|
||||
<p align="center">
|
||||
<a href="https://www.cherry-ai.com/" target="_blank"><img
|
||||
src="./docs/images/cherry-studio.svg" alt="Cherry Studio" height="58"
|
||||
/></a>
|
||||
|
||||
<a href="https://bda.pku.edu.cn/" target="_blank"><img
|
||||
src="./docs/images/pku.png" alt="Peking University" height="58"
|
||||
/></a>
|
||||
|
||||
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target="_blank"><img
|
||||
src="./docs/images/ucloud.svg" alt="UCloud" height="58"
|
||||
/></a>
|
||||
</p>
|
||||
|
||||
<p align="center"><em>No particular order</em></p>
|
||||
|
||||
## 🌟 Star History
|
||||
|
||||
[](https://star-history.com/#Calcium-Ion/new-api&Date)
|
||||
|
||||
18
README.md
18
README.md
@@ -188,6 +188,24 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
||||
- [反馈问题](https://docs.newapi.pro/support/feedback-issues)
|
||||
- [常见问题](https://docs.newapi.pro/support/faq)
|
||||
|
||||
## 🤝 我们信任的合作伙伴
|
||||
|
||||
<p align="center">
|
||||
<a href="https://www.cherry-ai.com/" target="_blank"><img
|
||||
src="./docs/images/cherry-studio.svg" alt="Cherry Studio" height="58"
|
||||
/></a>
|
||||
|
||||
<a href="https://bda.pku.edu.cn/" target="_blank"><img
|
||||
src="./docs/images/pku.png" alt="北京大学" height="58"
|
||||
/></a>
|
||||
|
||||
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target="_blank"><img
|
||||
src="./docs/images/ucloud.svg" alt="UCloud 优刻得" height="58"
|
||||
/></a>
|
||||
</p>
|
||||
|
||||
<p align="center"><em>排名不分先后</em></p>
|
||||
|
||||
## 🌟 Star History
|
||||
|
||||
[](https://star-history.com/#Calcium-Ion/new-api&Date)
|
||||
|
||||
@@ -63,6 +63,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType = constant.APITypeXai
|
||||
case constant.ChannelTypeCoze:
|
||||
apiType = constant.APITypeCoze
|
||||
case constant.ChannelTypeJimeng:
|
||||
apiType = constant.APITypeJimeng
|
||||
}
|
||||
if apiType == -1 {
|
||||
return constant.APITypeOpenAI, false
|
||||
|
||||
@@ -193,3 +193,9 @@ const (
|
||||
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
||||
ChannelStatusAutoDisabled = 3
|
||||
)
|
||||
|
||||
const (
|
||||
TopUpStatusPending = "pending"
|
||||
TopUpStatusSuccess = "success"
|
||||
TopUpStatusExpired = "expired"
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type stringWriter interface {
|
||||
@@ -52,6 +53,8 @@ type CustomEvent struct {
|
||||
Id string
|
||||
Retry uint
|
||||
Data interface{}
|
||||
|
||||
Mutex sync.Mutex
|
||||
}
|
||||
|
||||
func encode(writer io.Writer, event CustomEvent) error {
|
||||
@@ -73,6 +76,8 @@ func (r CustomEvent) Render(w http.ResponseWriter) error {
|
||||
}
|
||||
|
||||
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
|
||||
r.Mutex.Lock()
|
||||
defer r.Mutex.Unlock()
|
||||
header := w.Header()
|
||||
header["Content-Type"] = contentType
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -32,7 +33,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
}
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = UnmarshalJson(requestBody, &v)
|
||||
err = Unmarshal(requestBody, &v)
|
||||
} else {
|
||||
// skip for now
|
||||
// TODO: someday non json request have variant model, we will need to implementation this
|
||||
@@ -86,3 +87,25 @@ func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool)
|
||||
var t T
|
||||
return t, false
|
||||
}
|
||||
|
||||
func ApiError(c *gin.Context, err error) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func ApiErrorMsg(c *gin.Context, msg string) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": msg,
|
||||
})
|
||||
}
|
||||
|
||||
func ApiSuccess(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
34
common/hash.go
Normal file
34
common/hash.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
func Sha256Raw(data []byte) []byte {
|
||||
h := sha256.New()
|
||||
h.Write(data)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func Sha1Raw(data []byte) []byte {
|
||||
h := sha1.New()
|
||||
h.Write(data)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func Sha1(data []byte) string {
|
||||
return hex.EncodeToString(Sha1Raw(data))
|
||||
}
|
||||
|
||||
func HmacSha256Raw(message, key []byte) []byte {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(message)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func HmacSha256(message, key string) string {
|
||||
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
func UnmarshalJson(data []byte, v any) error {
|
||||
func Unmarshal(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,6 @@ func DecodeJson(reader *bytes.Reader, v any) error {
|
||||
return json.NewDecoder(reader).Decode(v)
|
||||
}
|
||||
|
||||
func EncodeJson(v any) ([]byte, error) {
|
||||
func Marshal(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
@@ -75,6 +75,9 @@ func logHelper(ctx context.Context, level string, msg string) {
|
||||
writer = gin.DefaultWriter
|
||||
}
|
||||
id := ctx.Value(RequestIdKey)
|
||||
if id == nil {
|
||||
id = "SYSTEM"
|
||||
}
|
||||
now := time.Now()
|
||||
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
||||
logCount++ // we don't need accurate count, so no lock here
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type PageInfo struct {
|
||||
Page int `json:"page"` // page num 页码
|
||||
PageSize int `json:"page_size"` // page size 页大小
|
||||
StartTimestamp int64 `json:"start_timestamp"` // 秒级
|
||||
EndTimestamp int64 `json:"end_timestamp"` // 秒级
|
||||
Page int `json:"page"` // page num 页码
|
||||
PageSize int `json:"page_size"` // page size 页大小
|
||||
|
||||
Total int `json:"total"` // 总条数,后设置
|
||||
Items any `json:"items"` // 数据,后设置
|
||||
@@ -39,11 +38,14 @@ func (p *PageInfo) SetItems(items any) {
|
||||
p.Items = items
|
||||
}
|
||||
|
||||
func GetPageQuery(c *gin.Context) (*PageInfo, error) {
|
||||
func GetPageQuery(c *gin.Context) *PageInfo {
|
||||
pageInfo := &PageInfo{}
|
||||
err := c.BindQuery(pageInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// 手动获取并处理每个参数
|
||||
if page, err := strconv.Atoi(c.Query("p")); err == nil {
|
||||
pageInfo.Page = page
|
||||
}
|
||||
if pageSize, err := strconv.Atoi(c.Query("page_size")); err == nil {
|
||||
pageInfo.PageSize = pageSize
|
||||
}
|
||||
if pageInfo.Page < 1 {
|
||||
// 兼容
|
||||
@@ -56,7 +58,25 @@ func GetPageQuery(c *gin.Context) (*PageInfo, error) {
|
||||
}
|
||||
|
||||
if pageInfo.PageSize == 0 {
|
||||
pageInfo.PageSize = ItemsPerPage
|
||||
// 兼容
|
||||
pageSize, _ := strconv.Atoi(c.Query("ps"))
|
||||
if pageSize != 0 {
|
||||
pageInfo.PageSize = pageSize
|
||||
}
|
||||
if pageInfo.PageSize == 0 {
|
||||
pageSize, _ = strconv.Atoi(c.Query("size")) // token page
|
||||
if pageSize != 0 {
|
||||
pageInfo.PageSize = pageSize
|
||||
}
|
||||
}
|
||||
if pageInfo.PageSize == 0 {
|
||||
pageInfo.PageSize = ItemsPerPage
|
||||
}
|
||||
}
|
||||
return pageInfo, nil
|
||||
|
||||
if pageInfo.PageSize > 100 {
|
||||
pageInfo.PageSize = 100
|
||||
}
|
||||
|
||||
return pageInfo
|
||||
}
|
||||
|
||||
119
common/str.go
119
common/str.go
@@ -4,7 +4,10 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
@@ -32,16 +35,30 @@ func MapToJsonStr(m map[string]interface{}) string {
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
func StrToMap(str string) map[string]interface{} {
|
||||
func StrToMap(str string) (map[string]interface{}, error) {
|
||||
m := make(map[string]interface{})
|
||||
err := json.Unmarshal([]byte(str), &m)
|
||||
err := Unmarshal([]byte(str), &m)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
return m
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func IsJsonStr(str string) bool {
|
||||
func StrToJsonArray(str string) ([]interface{}, error) {
|
||||
var js []interface{}
|
||||
err := json.Unmarshal([]byte(str), &js)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return js, nil
|
||||
}
|
||||
|
||||
func IsJsonArray(str string) bool {
|
||||
var js []interface{}
|
||||
return json.Unmarshal([]byte(str), &js) == nil
|
||||
}
|
||||
|
||||
func IsJsonObject(str string) bool {
|
||||
var js map[string]interface{}
|
||||
return json.Unmarshal([]byte(str), &js) == nil
|
||||
}
|
||||
@@ -81,3 +98,95 @@ func GetJsonString(data any) string {
|
||||
b, _ := json.Marshal(data)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string
|
||||
// Example:
|
||||
// http://example.com -> http://***.com
|
||||
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
|
||||
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
|
||||
// 192.168.1.1 -> ***.***.***.***
|
||||
func MaskSensitiveInfo(str string) string {
|
||||
// Mask URLs
|
||||
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
||||
str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return urlStr
|
||||
}
|
||||
|
||||
host := u.Host
|
||||
if host == "" {
|
||||
return urlStr
|
||||
}
|
||||
|
||||
// Split host by dots
|
||||
parts := strings.Split(host, ".")
|
||||
if len(parts) < 2 {
|
||||
// If less than 2 parts, just mask the whole host
|
||||
return u.Scheme + "://***" + u.Path
|
||||
}
|
||||
|
||||
// Keep the TLD (Top Level Domain) and mask the rest
|
||||
var maskedHost string
|
||||
if len(parts) == 2 {
|
||||
// example.com -> ***.com
|
||||
maskedHost = "***." + parts[len(parts)-1]
|
||||
} else {
|
||||
// Handle cases like sub.domain.co.uk or api.example.com
|
||||
// Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.)
|
||||
lastPart := parts[len(parts)-1]
|
||||
secondLastPart := parts[len(parts)-2]
|
||||
|
||||
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
|
||||
// Likely country code TLD like co.uk, com.cn
|
||||
maskedHost = "***." + secondLastPart + "." + lastPart
|
||||
} else {
|
||||
// Regular TLD like .com, .org
|
||||
maskedHost = "***." + lastPart
|
||||
}
|
||||
}
|
||||
|
||||
result := u.Scheme + "://" + maskedHost
|
||||
|
||||
// Mask path
|
||||
if u.Path != "" && u.Path != "/" {
|
||||
pathParts := strings.Split(strings.Trim(u.Path, "/"), "/")
|
||||
maskedPathParts := make([]string, len(pathParts))
|
||||
for i := range pathParts {
|
||||
if pathParts[i] != "" {
|
||||
maskedPathParts[i] = "***"
|
||||
}
|
||||
}
|
||||
if len(maskedPathParts) > 0 {
|
||||
result += "/" + strings.Join(maskedPathParts, "/")
|
||||
}
|
||||
} else if u.Path == "/" {
|
||||
result += "/"
|
||||
}
|
||||
|
||||
// Mask query parameters
|
||||
if u.RawQuery != "" {
|
||||
values, err := url.ParseQuery(u.RawQuery)
|
||||
if err != nil {
|
||||
// If can't parse query, just mask the whole query string
|
||||
result += "?***"
|
||||
} else {
|
||||
maskedParams := make([]string, 0, len(values))
|
||||
for key := range values {
|
||||
maskedParams = append(maskedParams, key+"=***")
|
||||
}
|
||||
if len(maskedParams) > 0 {
|
||||
result += "?" + strings.Join(maskedParams, "&")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
})
|
||||
|
||||
// Mask IP addresses
|
||||
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
||||
str = ipPattern.ReplaceAllString(str, "***.***.***.***")
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// StructToMap 递归地把任意结构体 v 转成 map[string]any。
|
||||
// - 只处理导出字段;未导出字段会被跳过。
|
||||
// - 优先使用 `json:"name"` 里逗号前的部分作为键;如果是 "-" 则忽略该字段;若无 tag,则使用字段名。
|
||||
// - 对指针、切片、数组、嵌套结构体、map 做深度遍历,保持原始结构。
|
||||
func StructToMap(v any) (map[string]any, error) {
|
||||
val := reflect.ValueOf(v)
|
||||
if !val.IsValid() {
|
||||
return nil, fmt.Errorf("nil value")
|
||||
}
|
||||
for val.Kind() == reflect.Pointer {
|
||||
if val.IsNil() {
|
||||
return nil, fmt.Errorf("nil pointer")
|
||||
}
|
||||
val = val.Elem()
|
||||
}
|
||||
if val.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("expect struct, got %s", val.Kind())
|
||||
}
|
||||
|
||||
return structValueToMap(val), nil
|
||||
}
|
||||
|
||||
func structValueToMap(val reflect.Value) map[string]any {
|
||||
out := make(map[string]any, val.NumField())
|
||||
|
||||
typ := val.Type()
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
if f.PkgPath != "" { // 未导出字段
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析 json tag
|
||||
tag := f.Tag.Get("json")
|
||||
name, opts := parseTag(tag)
|
||||
if name == "-" {
|
||||
continue
|
||||
}
|
||||
if name == "" {
|
||||
name = f.Name
|
||||
}
|
||||
|
||||
fv := val.Field(i)
|
||||
out[name] = valueToAny(fv, opts.Contains("omitempty"))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// valueToAny 递归处理各种值类型。
|
||||
func valueToAny(v reflect.Value, omitEmpty bool) any {
|
||||
if !v.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
for v.Kind() == reflect.Pointer {
|
||||
if v.IsNil() {
|
||||
if omitEmpty {
|
||||
return nil
|
||||
}
|
||||
// 保持与 encoding/json 行为一致,nil 指针写成 null
|
||||
return nil
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
|
||||
case reflect.Struct:
|
||||
return structValueToMap(v)
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
l := v.Len()
|
||||
arr := make([]any, l)
|
||||
for i := 0; i < l; i++ {
|
||||
arr[i] = valueToAny(v.Index(i), false)
|
||||
}
|
||||
return arr
|
||||
|
||||
case reflect.Map:
|
||||
m := make(map[string]any, v.Len())
|
||||
iter := v.MapRange()
|
||||
for iter.Next() {
|
||||
k := iter.Key()
|
||||
// 只支持 string key,与 encoding/json 一致
|
||||
if k.Kind() == reflect.String {
|
||||
m[k.String()] = valueToAny(iter.Value(), false)
|
||||
}
|
||||
}
|
||||
return m
|
||||
|
||||
default:
|
||||
// 基本类型直接返回其接口值
|
||||
return v.Interface()
|
||||
}
|
||||
}
|
||||
|
||||
// tagOptions 用于判断是否包含 "omitempty"
|
||||
type tagOptions string
|
||||
|
||||
func (o tagOptions) Contains(opt string) bool {
|
||||
if len(o) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, s := range splitComma(string(o)) {
|
||||
if s == opt {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func parseTag(tag string) (string, tagOptions) {
|
||||
if idx := indexComma(tag); idx != -1 {
|
||||
return tag[:idx], tagOptions(tag[idx+1:])
|
||||
}
|
||||
return tag, tagOptions("")
|
||||
}
|
||||
|
||||
// 避免 strings.Split 额外分配
|
||||
func indexComma(s string) int {
|
||||
for i, r := range s {
|
||||
if r == ',' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func splitComma(s string) []string {
|
||||
var parts []string
|
||||
start := 0
|
||||
for i, r := range s {
|
||||
if r == ',' {
|
||||
parts = append(parts, s[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if start <= len(s) {
|
||||
parts = append(parts, s[start:])
|
||||
}
|
||||
return parts
|
||||
}
|
||||
150
common/totp.go
Normal file
150
common/totp.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pquerna/otp"
|
||||
"github.com/pquerna/otp/totp"
|
||||
)
|
||||
|
||||
const (
|
||||
// 备用码配置
|
||||
BackupCodeLength = 8 // 备用码长度
|
||||
BackupCodeCount = 4 // 生成备用码数量
|
||||
|
||||
// 限制配置
|
||||
MaxFailAttempts = 5 // 最大失败尝试次数
|
||||
LockoutDuration = 300 // 锁定时间(秒)
|
||||
)
|
||||
|
||||
// GenerateTOTPSecret 生成TOTP密钥和配置
|
||||
func GenerateTOTPSecret(accountName string) (*otp.Key, error) {
|
||||
issuer := Get2FAIssuer()
|
||||
return totp.Generate(totp.GenerateOpts{
|
||||
Issuer: issuer,
|
||||
AccountName: accountName,
|
||||
Period: 30,
|
||||
Digits: otp.DigitsSix,
|
||||
Algorithm: otp.AlgorithmSHA1,
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateTOTPCode 验证TOTP验证码
|
||||
func ValidateTOTPCode(secret, code string) bool {
|
||||
// 清理验证码格式
|
||||
cleanCode := strings.ReplaceAll(code, " ", "")
|
||||
if len(cleanCode) != 6 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
return totp.Validate(cleanCode, secret)
|
||||
}
|
||||
|
||||
// GenerateBackupCodes 生成备用恢复码
|
||||
func GenerateBackupCodes() ([]string, error) {
|
||||
codes := make([]string, BackupCodeCount)
|
||||
|
||||
for i := 0; i < BackupCodeCount; i++ {
|
||||
code, err := generateRandomBackupCode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
codes[i] = code
|
||||
}
|
||||
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
// generateRandomBackupCode 生成单个备用码
|
||||
func generateRandomBackupCode() (string, error) {
|
||||
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
code := make([]byte, BackupCodeLength)
|
||||
|
||||
for i := range code {
|
||||
randomBytes := make([]byte, 1)
|
||||
_, err := rand.Read(randomBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
code[i] = charset[int(randomBytes[0])%len(charset)]
|
||||
}
|
||||
|
||||
// 格式化为 XXXX-XXXX 格式
|
||||
return fmt.Sprintf("%s-%s", string(code[:4]), string(code[4:])), nil
|
||||
}
|
||||
|
||||
// ValidateBackupCode 验证备用码格式
|
||||
func ValidateBackupCode(code string) bool {
|
||||
// 移除所有分隔符并转为大写
|
||||
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
|
||||
if len(cleanCode) != BackupCodeLength {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查字符是否合法
|
||||
for _, char := range cleanCode {
|
||||
if !((char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9')) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// NormalizeBackupCode 标准化备用码格式
|
||||
func NormalizeBackupCode(code string) string {
|
||||
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
|
||||
if len(cleanCode) == BackupCodeLength {
|
||||
return fmt.Sprintf("%s-%s", cleanCode[:4], cleanCode[4:])
|
||||
}
|
||||
return code
|
||||
}
|
||||
|
||||
// HashBackupCode 对备用码进行哈希
|
||||
func HashBackupCode(code string) (string, error) {
|
||||
normalizedCode := NormalizeBackupCode(code)
|
||||
return Password2Hash(normalizedCode)
|
||||
}
|
||||
|
||||
// Get2FAIssuer 获取2FA发行者名称
|
||||
func Get2FAIssuer() string {
|
||||
return SystemName
|
||||
}
|
||||
|
||||
// getEnvOrDefault 获取环境变量或默认值
|
||||
func getEnvOrDefault(key, defaultValue string) string {
|
||||
if value, exists := os.LookupEnv(key); exists {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// ValidateNumericCode 验证数字验证码格式
|
||||
func ValidateNumericCode(code string) (string, error) {
|
||||
// 移除空格
|
||||
code = strings.ReplaceAll(code, " ", "")
|
||||
|
||||
if len(code) != 6 {
|
||||
return "", fmt.Errorf("验证码必须是6位数字")
|
||||
}
|
||||
|
||||
// 检查是否为纯数字
|
||||
if _, err := strconv.Atoi(code); err != nil {
|
||||
return "", fmt.Errorf("验证码只能包含数字")
|
||||
}
|
||||
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// GenerateQRCodeData 生成二维码数据
|
||||
func GenerateQRCodeData(secret, username string) string {
|
||||
issuer := Get2FAIssuer()
|
||||
accountName := fmt.Sprintf("%s (%s)", username, issuer)
|
||||
return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&digits=6&period=30",
|
||||
issuer, accountName, secret, issuer)
|
||||
}
|
||||
@@ -30,5 +30,6 @@ const (
|
||||
APITypeXinference
|
||||
APITypeXai
|
||||
APITypeCoze
|
||||
APITypeJimeng
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
|
||||
@@ -49,6 +49,7 @@ const (
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeKling = 50
|
||||
ChannelTypeJimeng = 51
|
||||
ChannelTypeVidu = 52
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
@@ -106,4 +107,5 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.coze.cn", //49
|
||||
"https://api.klingai.com", //50
|
||||
"https://visual.volcengineapi.com", //51
|
||||
"https://api.vidu.cn", //52
|
||||
}
|
||||
|
||||
@@ -17,11 +17,20 @@ const (
|
||||
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||
|
||||
/* channel related keys */
|
||||
ContextKeyBaseUrl ContextKey = "base_url"
|
||||
ContextKeyChannelType ContextKey = "channel_type"
|
||||
ContextKeyChannelId ContextKey = "channel_id"
|
||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||
ContextKeyParamOverride ContextKey = "param_override"
|
||||
ContextKeyChannelId ContextKey = "channel_id"
|
||||
ContextKeyChannelName ContextKey = "channel_name"
|
||||
ContextKeyChannelCreateTime ContextKey = "channel_create_time"
|
||||
ContextKeyChannelBaseUrl ContextKey = "base_url"
|
||||
ContextKeyChannelType ContextKey = "channel_type"
|
||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||
ContextKeyChannelParamOverride ContextKey = "param_override"
|
||||
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
||||
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
||||
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
||||
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
|
||||
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
|
||||
ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
|
||||
ContextKeyChannelKey ContextKey = "channel_key"
|
||||
|
||||
/* user related keys */
|
||||
ContextKeyUserId ContextKey = "id"
|
||||
|
||||
8
constant/multi_key_mode.go
Normal file
8
constant/multi_key_mode.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package constant
|
||||
|
||||
type MultiKeyMode string
|
||||
|
||||
const (
|
||||
MultiKeyModeRandom MultiKeyMode = "random" // 随机
|
||||
MultiKeyModePolling MultiKeyMode = "polling" // 轮询
|
||||
)
|
||||
@@ -5,8 +5,6 @@ type TaskPlatform string
|
||||
const (
|
||||
TaskPlatformSuno TaskPlatform = "suno"
|
||||
TaskPlatformMidjourney = "mj"
|
||||
TaskPlatformKling TaskPlatform = "kling"
|
||||
TaskPlatformJimeng TaskPlatform = "jimeng"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/shopspring/decimal"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -12,9 +11,12 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -409,26 +411,24 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
func UpdateChannelBalance(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
channel, err := model.CacheGetChannel(id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
"message": "多密钥渠道不支持余额查询",
|
||||
})
|
||||
return
|
||||
}
|
||||
balance, err := updateChannelBalance(channel)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -436,7 +436,6 @@ func UpdateChannelBalance(c *gin.Context) {
|
||||
"message": "",
|
||||
"balance": balance,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func updateAllChannelsBalance() error {
|
||||
@@ -448,6 +447,9 @@ func updateAllChannelsBalance() error {
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
continue // skip multi-key channels
|
||||
}
|
||||
// TODO: support Azure
|
||||
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
|
||||
// continue
|
||||
@@ -458,7 +460,7 @@ func updateAllChannelsBalance() error {
|
||||
} else {
|
||||
// err is nil & balance <= 0 means quota is used up
|
||||
if balance <= 0 {
|
||||
service.DisableChannel(channel.Id, channel.Name, "余额不足")
|
||||
service.DisableChannel(*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, "", channel.GetAutoBan()), "余额不足")
|
||||
}
|
||||
}
|
||||
time.Sleep(common.RequestInterval)
|
||||
@@ -470,10 +472,7 @@ func UpdateAllChannelsBalance(c *gin.Context) {
|
||||
// TODO: make it async
|
||||
err := updateAllChannelsBalance()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -17,8 +17,10 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -29,22 +31,49 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||
type testResult struct {
|
||||
context *gin.Context
|
||||
localErr error
|
||||
newAPIError *types.NewAPIError
|
||||
}
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
tik := time.Now()
|
||||
if channel.Type == constant.ChannelTypeMidjourney {
|
||||
return errors.New("midjourney channel test is not supported"), nil
|
||||
return testResult{
|
||||
localErr: errors.New("midjourney channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
||||
return errors.New("midjourney plus channel test is not supported"), nil
|
||||
return testResult{
|
||||
localErr: errors.New("midjourney plus channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeSunoAPI {
|
||||
return errors.New("suno channel test is not supported"), nil
|
||||
return testResult{
|
||||
localErr: errors.New("suno channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeKling {
|
||||
return errors.New("kling channel test is not supported"), nil
|
||||
return testResult{
|
||||
localErr: errors.New("kling channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeJimeng {
|
||||
return errors.New("jimeng channel test is not supported"), nil
|
||||
return testResult{
|
||||
localErr: errors.New("jimeng channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeVidu {
|
||||
return testResult{
|
||||
localErr: errors.New("vidu channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -81,31 +110,49 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
|
||||
cache, err := model.GetUserCache(1)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
localErr: err,
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
cache.WriteContext(c)
|
||||
|
||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
group, _ := model.GetUserGroup(1, false)
|
||||
c.Set("group", group)
|
||||
|
||||
middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||
if newAPIError != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: newAPIError,
|
||||
newAPIError: newAPIError,
|
||||
}
|
||||
}
|
||||
|
||||
info := relaycommon.GenRelayInfo(c)
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, nil)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
||||
}
|
||||
}
|
||||
testModel = info.UpstreamModelName
|
||||
|
||||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
|
||||
newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
|
||||
}
|
||||
}
|
||||
|
||||
request := buildTestRequest(testModel)
|
||||
@@ -116,45 +163,91 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
|
||||
}
|
||||
}
|
||||
|
||||
adaptor.Init(info)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
var convertedRequest any
|
||||
// 根据 RelayMode 选择正确的转换函数
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||
// 创建一个 EmbeddingRequest
|
||||
embeddingRequest := dto.EmbeddingRequest{
|
||||
Input: request.Input,
|
||||
Model: request.Model,
|
||||
}
|
||||
// 调用专门用于 Embedding 的转换函数
|
||||
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
|
||||
} else {
|
||||
// 对其他所有请求类型(如 Chat),保持原有逻辑
|
||||
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
||||
}
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
|
||||
}
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
c.Request.Body = io.NopCloser(requestBody)
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(httpResp, true)
|
||||
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
}
|
||||
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
||||
if respErr != nil {
|
||||
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: respErr,
|
||||
newAPIError: respErr,
|
||||
}
|
||||
}
|
||||
if usageA == nil {
|
||||
return errors.New("usage is nil"), nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: errors.New("usage is nil"),
|
||||
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
usage := usageA.(*dto.Usage)
|
||||
result := w.Result()
|
||||
respBody, err := io.ReadAll(result.Body)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
info.PromptTokens = usage.PromptTokens
|
||||
|
||||
@@ -187,7 +280,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
Other: other,
|
||||
})
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: nil,
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
@@ -202,7 +299,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
strings.Contains(model, "bge-") {
|
||||
testRequest.Model = model
|
||||
// Embedding 请求
|
||||
testRequest.Input = []string{"hello world"}
|
||||
testRequest.Input = []any{"hello world"} // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
|
||||
return testRequest
|
||||
}
|
||||
// 并非Embedding 模型
|
||||
@@ -230,31 +327,41 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
func TestChannel(c *gin.Context) {
|
||||
channelId, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
channel, err := model.GetChannelById(channelId, true)
|
||||
channel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
channel, err = model.GetChannelById(channelId, true)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
//defer func() {
|
||||
// if channel.ChannelInfo.IsMultiKey {
|
||||
// go func() { _ = channel.SaveChannelInfo() }()
|
||||
// }
|
||||
//}()
|
||||
testModel := c.Query("model")
|
||||
tik := time.Now()
|
||||
err, _ = testChannel(channel, testModel)
|
||||
result := testChannel(channel, testModel)
|
||||
if result.localErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": result.localErr.Error(),
|
||||
"time": 0.0,
|
||||
})
|
||||
return
|
||||
}
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
go channel.UpdateResponseTime(milliseconds)
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
if err != nil {
|
||||
if result.newAPIError != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
"message": result.newAPIError.Error(),
|
||||
"time": consumedTime,
|
||||
})
|
||||
return
|
||||
@@ -279,9 +386,9 @@ func testAllChannels(notify bool) error {
|
||||
}
|
||||
testAllChannelsRunning = true
|
||||
testAllChannelsLock.Unlock()
|
||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||
if err != nil {
|
||||
return err
|
||||
channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
|
||||
if getChannelErr != nil {
|
||||
return getChannelErr
|
||||
}
|
||||
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
||||
if disableThreshold == 0 {
|
||||
@@ -298,32 +405,34 @@ func testAllChannels(notify bool) error {
|
||||
for _, channel := range channels {
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
err, openaiWithStatusErr := testChannel(channel, "")
|
||||
result := testChannel(channel, "")
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
||||
shouldBanChannel := false
|
||||
|
||||
newAPIError := result.newAPIError
|
||||
// request error disables the channel
|
||||
if openaiWithStatusErr != nil {
|
||||
oaiErr := openaiWithStatusErr.Error
|
||||
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
|
||||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
|
||||
if newAPIError != nil {
|
||||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
|
||||
}
|
||||
|
||||
if milliseconds > disableThreshold {
|
||||
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
shouldBanChannel = true
|
||||
// 当错误检查通过,才检查响应时间
|
||||
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
||||
if milliseconds > disableThreshold {
|
||||
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
|
||||
shouldBanChannel = true
|
||||
}
|
||||
}
|
||||
|
||||
// disable channel
|
||||
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
||||
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||
go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
}
|
||||
|
||||
// enable channel
|
||||
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
|
||||
service.EnableChannel(channel.Id, channel.Name)
|
||||
if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
|
||||
service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
|
||||
}
|
||||
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
@@ -340,10 +449,7 @@ func testAllChannels(notify bool) error {
|
||||
func TestAllChannels(c *gin.Context) {
|
||||
err := testAllChannels(true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -354,6 +460,10 @@ func TestAllChannels(c *gin.Context) {
|
||||
}
|
||||
|
||||
func AutomaticallyTestChannels(frequency int) {
|
||||
if frequency <= 0 {
|
||||
common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
|
||||
return
|
||||
}
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||
common.SysLog("testing all channels")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,13 +5,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GitHubOAuthResponse struct {
|
||||
@@ -103,10 +104,7 @@ func GitHubOAuth(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
@@ -185,10 +183,7 @@ func GitHubBind(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
@@ -207,19 +202,13 @@ func GitHubBind(c *gin.Context) {
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.GitHubId = githubUser.Login
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -239,10 +228,7 @@ func GenerateOAuthCode(c *gin.Context) {
|
||||
session.Set("oauth_state", state)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -38,10 +38,7 @@ func LinuxDoBind(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -63,20 +60,14 @@ func LinuxDoBind(c *gin.Context) {
|
||||
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -202,10 +193,7 @@ func LinuxdoOAuth(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -10,14 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func GetAllLogs(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
@@ -26,38 +19,19 @@ func GetAllLogs(c *gin.Context) {
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
group := c.Query("group")
|
||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel, group)
|
||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": map[string]any{
|
||||
"items": logs,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(logs)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func GetUserLogs(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
userId := c.GetInt("id")
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
@@ -65,24 +39,14 @@ func GetUserLogs(c *gin.Context) {
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
group := c.Query("group")
|
||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize, group)
|
||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": map[string]any{
|
||||
"items": logs,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(logs)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -90,10 +54,7 @@ func SearchAllLogs(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
logs, err := model.SearchAllLogs(keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -109,10 +70,7 @@ func SearchUserLogs(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
logs, err := model.SearchUserLogs(userId, keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -198,10 +156,7 @@ func DeleteHistoryLogs(c *gin.Context) {
|
||||
}
|
||||
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -13,8 +12,9 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func UpdateMidjourneyTaskBulk() {
|
||||
@@ -213,14 +213,7 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
|
||||
}
|
||||
|
||||
func GetAllMidjourney(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
// 解析其他查询参数
|
||||
queryParams := model.TaskQueryParams{
|
||||
@@ -230,7 +223,7 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
EndTimestamp: c.Query("end_timestamp"),
|
||||
}
|
||||
|
||||
items := model.GetAllTasks((p-1)*pageSize, pageSize, queryParams)
|
||||
items := model.GetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.CountAllTasks(queryParams)
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
@@ -239,27 +232,13 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
func GetUserMidjourney(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
@@ -269,7 +248,7 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
EndTimestamp: c.Query("end_timestamp"),
|
||||
}
|
||||
|
||||
items := model.GetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
|
||||
items := model.GetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.CountAllUserTask(userId, queryParams)
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
@@ -278,14 +257,7 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/middleware"
|
||||
"one-api/middleware/jsrt"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
@@ -35,6 +33,7 @@ func TestStatus(c *gin.Context) {
|
||||
"message": "Server is running",
|
||||
"http_stats": httpStats,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetStatus(c *gin.Context) {
|
||||
@@ -58,7 +57,9 @@ func GetStatus(c *gin.Context) {
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"stripe_unit_price": setting.StripeUnitPrice,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
@@ -72,12 +73,14 @@ func GetStatus(c *gin.Context) {
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
"usd_exchange_rate": setting.USDExchangeRate,
|
||||
|
||||
// 面板启用开关
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
@@ -107,6 +110,7 @@ func GetStatus(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": data,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetNotice(c *gin.Context) {
|
||||
@@ -117,6 +121,7 @@ func GetNotice(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": common.OptionMap["Notice"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetAbout(c *gin.Context) {
|
||||
@@ -127,6 +132,7 @@ func GetAbout(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": common.OptionMap["About"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetMidjourney(c *gin.Context) {
|
||||
@@ -137,6 +143,7 @@ func GetMidjourney(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": common.OptionMap["Midjourney"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetHomePageContent(c *gin.Context) {
|
||||
@@ -147,6 +154,7 @@ func GetHomePageContent(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": common.OptionMap["HomePageContent"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SendEmailVerification(c *gin.Context) {
|
||||
@@ -169,7 +177,13 @@ func SendEmailVerification(c *gin.Context) {
|
||||
localPart := parts[0]
|
||||
domainPart := parts[1]
|
||||
if common.EmailDomainRestrictionEnabled {
|
||||
allowed := slices.Contains(common.EmailDomainWhitelist, domainPart)
|
||||
allowed := false
|
||||
for _, domain := range common.EmailDomainWhitelist {
|
||||
if domainPart == domain {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -204,16 +218,14 @@ func SendEmailVerification(c *gin.Context) {
|
||||
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
|
||||
err := common.SendEmail(subject, email, content)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SendPasswordResetEmail(c *gin.Context) {
|
||||
@@ -242,16 +254,14 @@ func SendPasswordResetEmail(c *gin.Context) {
|
||||
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
|
||||
err := common.SendEmail(subject, email, content)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type PasswordResetRequest struct {
|
||||
@@ -279,10 +289,7 @@ func ResetPassword(c *gin.Context) {
|
||||
password := common.GenerateVerificationCode(12)
|
||||
err = model.ResetUserPasswordByEmail(req.Email, password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.DeleteKey(req.Email, common.PasswordResetPurpose)
|
||||
@@ -291,13 +298,5 @@ func ResetPassword(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": password,
|
||||
})
|
||||
}
|
||||
|
||||
func ReloadJSScripts(c *gin.Context) {
|
||||
jsrt.ReloadJSScripts()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "JavaScript 脚本已重新加载",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -126,10 +126,7 @@ func OidcAuth(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
@@ -195,10 +192,7 @@ func OidcBind(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
@@ -217,19 +211,13 @@ func OidcBind(c *gin.Context) {
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.OidcId = oidcUser.OpenID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -160,10 +160,7 @@ func UpdateOption(c *gin.Context) {
|
||||
}
|
||||
err = model.UpdateOption(option.Key, option.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -3,45 +3,44 @@ package controller
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Playground(c *gin.Context) {
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
c.JSON(openaiErr.StatusCode, gin.H{
|
||||
"error": openaiErr.Error,
|
||||
if newAPIError != nil {
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"error": newAPIError.ToOpenAIError(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
useAccessToken := c.GetBool("use_access_token")
|
||||
if useAccessToken {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
|
||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
|
||||
playgroundRequest := &dto.PlayGroundRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||
if err != nil {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
|
||||
if playgroundRequest.Model == "" {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
|
||||
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
c.Set("original_model", playgroundRequest.Model)
|
||||
@@ -52,28 +51,34 @@ func Playground(c *gin.Context) {
|
||||
group = userGroup
|
||||
} else {
|
||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
|
||||
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
c.Set("group", group)
|
||||
}
|
||||
c.Set("token_name", "playground-"+group)
|
||||
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// Write user context to ensure acceptUnsetRatio is available
|
||||
userId := c.GetInt("id")
|
||||
userCache, err := model.GetUserCache(userId)
|
||||
if err != nil {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
userCache.WriteContext(c)
|
||||
|
||||
tempToken := &model.Token{
|
||||
UserId: userId,
|
||||
Name: fmt.Sprintf("playground-%s", group),
|
||||
Group: group,
|
||||
}
|
||||
_ = middleware.SetupContextForToken(c, tempToken)
|
||||
_, newAPIError = getChannel(c, group, playgroundRequest.Model, 0)
|
||||
if newAPIError != nil {
|
||||
return
|
||||
}
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||
|
||||
Relay(c)
|
||||
}
|
||||
|
||||
@@ -1,91 +1,52 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"errors"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAllRedemptions(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
redemptions, total, err := model.GetAllRedemptions((p-1)*pageSize, pageSize)
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
redemptions, total, err := model.GetAllRedemptions(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": redemptions,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(redemptions)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func SearchRedemptions(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
redemptions, total, err := model.SearchRedemptions(keyword, (p-1)*pageSize, pageSize)
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
redemptions, total, err := model.SearchRedemptions(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": redemptions,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(redemptions)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func GetRedemption(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
redemption, err := model.GetRedemptionById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -100,13 +61,10 @@ func AddRedemption(c *gin.Context) {
|
||||
redemption := model.Redemption{}
|
||||
err := c.ShouldBindJSON(&redemption)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
|
||||
if utf8.RuneCountInString(redemption.Name) == 0 || utf8.RuneCountInString(redemption.Name) > 20 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "兑换码名称长度必须在1-20之间",
|
||||
@@ -165,10 +123,7 @@ func DeleteRedemption(c *gin.Context) {
|
||||
id, _ := strconv.Atoi(c.Param("id"))
|
||||
err := model.DeleteRedemptionById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -183,18 +138,12 @@ func UpdateRedemption(c *gin.Context) {
|
||||
redemption := model.Redemption{}
|
||||
err := c.ShouldBindJSON(&redemption)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
cleanRedemption, err := model.GetRedemptionById(redemption.Id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if statusOnly == "" {
|
||||
@@ -212,10 +161,7 @@ func UpdateRedemption(c *gin.Context) {
|
||||
}
|
||||
err = cleanRedemption.Update()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -229,16 +175,13 @@ func UpdateRedemption(c *gin.Context) {
|
||||
func DeleteInvalidRedemption(c *gin.Context) {
|
||||
rows, err := model.DeleteInvalidRedemptions()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": rows,
|
||||
"data": rows,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -17,14 +17,15 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
var err *dto.OpenAIErrorWithStatusCode
|
||||
func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
||||
var err *types.NewAPIError
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||
err = relay.ImageHelper(c)
|
||||
@@ -46,7 +47,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
err = relay.TextHelper(c)
|
||||
}
|
||||
|
||||
if constant2.ErrorLogEnabled && err != nil {
|
||||
if constant2.ErrorLogEnabled && err != nil && types.IsRecordErrorLog(err) {
|
||||
// 保存错误日志到mysql中
|
||||
userId := c.GetInt("id")
|
||||
tokenName := c.GetString("token_name")
|
||||
@@ -55,14 +56,21 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
userGroup := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
other := make(map[string]interface{})
|
||||
other["error_type"] = err.Error.Type
|
||||
other["error_code"] = err.Error.Code
|
||||
other["error_type"] = err.GetErrorType()
|
||||
other["error_code"] = err.GetErrorCode()
|
||||
other["status_code"] = err.StatusCode
|
||||
other["channel_id"] = channelId
|
||||
other["channel_name"] = c.GetString("channel_name")
|
||||
other["channel_type"] = c.GetInt("channel_type")
|
||||
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error.Message, tokenId, 0, false, userGroup, other)
|
||||
adminInfo := make(map[string]interface{})
|
||||
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
||||
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
||||
if isMultiKey {
|
||||
adminInfo["is_multi_key"] = true
|
||||
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
||||
}
|
||||
other["admin_info"] = adminInfo
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||||
}
|
||||
|
||||
return err
|
||||
@@ -73,25 +81,25 @@ func Relay(c *gin.Context) {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
openaiErr = relayRequest(c, relayMode, channel)
|
||||
newAPIError = relayRequest(c, relayMode, channel)
|
||||
|
||||
if openaiErr == nil {
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -101,14 +109,14 @@ func Relay(c *gin.Context) {
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if openaiErr != nil {
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
|
||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||
c.JSON(openaiErr.StatusCode, gin.H{
|
||||
"error": openaiErr.Error,
|
||||
if newAPIError != nil {
|
||||
//if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||
// common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
|
||||
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||
//}
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"error": newAPIError.ToOpenAIError(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -127,8 +135,7 @@ func WssRelay(c *gin.Context) {
|
||||
defer ws.Close()
|
||||
|
||||
if err != nil {
|
||||
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
helper.WssError(c, ws, openaiErr.Error)
|
||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -137,25 +144,25 @@ func WssRelay(c *gin.Context) {
|
||||
group := c.GetString("group")
|
||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||
originalModel := c.GetString("original_model")
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
openaiErr = wssRequest(c, ws, relayMode, channel)
|
||||
newAPIError = wssRequest(c, ws, relayMode, channel)
|
||||
|
||||
if openaiErr == nil {
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -165,12 +172,12 @@ func WssRelay(c *gin.Context) {
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if openaiErr != nil {
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||
helper.WssError(c, ws, openaiErr.Error)
|
||||
if newAPIError != nil {
|
||||
//if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||
//}
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,27 +186,25 @@ func RelayClaude(c *gin.Context) {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
var claudeErr *dto.ClaudeErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
claudeErr = claudeRequest(c, channel)
|
||||
newAPIError = claudeRequest(c, channel)
|
||||
|
||||
if claudeErr == nil {
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
openaiErr := service.ClaudeErrorToOpenAIError(claudeErr)
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -209,30 +214,30 @@ func RelayClaude(c *gin.Context) {
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if claudeErr != nil {
|
||||
claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId)
|
||||
c.JSON(claudeErr.StatusCode, gin.H{
|
||||
if newAPIError != nil {
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"type": "error",
|
||||
"error": claudeErr.Error,
|
||||
"error": newAPIError.ToClaudeError(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relayHandler(c, relayMode)
|
||||
}
|
||||
|
||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relay.WssHelper(c, ws)
|
||||
}
|
||||
|
||||
func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode {
|
||||
func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
@@ -245,7 +250,7 @@ func addUsedChannel(c *gin.Context, channelId int) {
|
||||
c.Set("use_channel", useChannel)
|
||||
}
|
||||
|
||||
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
|
||||
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
|
||||
if retryCount == 0 {
|
||||
autoBan := c.GetBool("auto_ban")
|
||||
autoBanInt := 1
|
||||
@@ -259,19 +264,28 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
||||
AutoBan: &autoBanInt,
|
||||
}, nil
|
||||
}
|
||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
if newAPIError != nil {
|
||||
return nil, newAPIError
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
|
||||
func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
|
||||
if openaiErr == nil {
|
||||
return false
|
||||
}
|
||||
if openaiErr.LocalError {
|
||||
if types.IsChannelError(openaiErr) {
|
||||
return true
|
||||
}
|
||||
if types.IsSkipRetryError(openaiErr) {
|
||||
return false
|
||||
}
|
||||
if retryTimes <= 0 {
|
||||
@@ -310,12 +324,12 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
||||
return true
|
||||
}
|
||||
|
||||
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
|
||||
if service.ShouldDisableChannel(channelType, err) && autoBan {
|
||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
service.DisableChannel(channelError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -388,9 +402,10 @@ func RelayTask(c *gin.Context) {
|
||||
retryTimes = 0
|
||||
}
|
||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||
channel, newAPIError := getChannel(c, group, originalModel, i)
|
||||
if newAPIError != nil {
|
||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
||||
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
||||
break
|
||||
}
|
||||
channelId = channel.Id
|
||||
@@ -398,9 +413,9 @@ func RelayTask(c *gin.Context) {
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
c.Set("use_channel", useChannel)
|
||||
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
|
||||
requestBody, err := common.GetRequestBody(c)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
taskErr = taskRelayHandler(c, relayMode)
|
||||
}
|
||||
@@ -420,7 +435,7 @@ func RelayTask(c *gin.Context) {
|
||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||||
var err *dto.TaskError
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
|
||||
err = relay.RelayTaskFetch(c, relayMode)
|
||||
default:
|
||||
err = relay.RelayTaskSubmit(c, relayMode)
|
||||
|
||||
136
controller/swag_video.go
Normal file
136
controller/swag_video.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// VideoGenerations
|
||||
// @Summary 生成视频
|
||||
// @Description 调用视频生成接口生成视频
|
||||
// @Description 支持多种视频生成服务:
|
||||
// @Description - 可灵AI (Kling): https://app.klingai.com/cn/dev/document-api/apiReference/commonInfo
|
||||
// @Description - 即梦 (Jimeng): https://www.volcengine.com/docs/85621/1538636
|
||||
// @Tags Video
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||
// @Param request body dto.VideoRequest true "视频生成请求参数"
|
||||
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||
// @Router /v1/video/generations [post]
|
||||
func VideoGenerations(c *gin.Context) {
|
||||
}
|
||||
|
||||
// VideoGenerationsTaskId
|
||||
// @Summary 查询视频
|
||||
// @Description 根据任务ID查询视频生成任务的状态和结果
|
||||
// @Tags Video
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param task_id path string true "Task ID"
|
||||
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||
// @Router /v1/video/generations/{task_id} [get]
|
||||
func VideoGenerationsTaskId(c *gin.Context) {
|
||||
}
|
||||
|
||||
// KlingText2VideoGenerations
|
||||
// @Summary 可灵文生视频
|
||||
// @Description 调用可灵AI文生视频接口,生成视频内容
|
||||
// @Tags Video
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||
// @Param request body KlingText2VideoRequest true "视频生成请求参数"
|
||||
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||
// @Router /kling/v1/videos/text2video [post]
|
||||
func KlingText2VideoGenerations(c *gin.Context) {
|
||||
}
|
||||
|
||||
type KlingText2VideoRequest struct {
|
||||
ModelName string `json:"model_name,omitempty" example:"kling-v1"`
|
||||
Prompt string `json:"prompt" binding:"required" example:"A cat playing piano in the garden"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
|
||||
CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
|
||||
Mode string `json:"mode,omitempty" example:"std"`
|
||||
CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
|
||||
Duration string `json:"duration,omitempty" example:"5"`
|
||||
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
|
||||
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-001"`
|
||||
}
|
||||
|
||||
type KlingCameraControl struct {
|
||||
Type string `json:"type,omitempty" example:"simple"`
|
||||
Config *KlingCameraConfig `json:"config,omitempty"`
|
||||
}
|
||||
|
||||
type KlingCameraConfig struct {
|
||||
Horizontal float64 `json:"horizontal,omitempty" example:"2.5"`
|
||||
Vertical float64 `json:"vertical,omitempty" example:"0"`
|
||||
Pan float64 `json:"pan,omitempty" example:"0"`
|
||||
Tilt float64 `json:"tilt,omitempty" example:"0"`
|
||||
Roll float64 `json:"roll,omitempty" example:"0"`
|
||||
Zoom float64 `json:"zoom,omitempty" example:"0"`
|
||||
}
|
||||
|
||||
// KlingImage2VideoGenerations
|
||||
// @Summary 可灵官方-图生视频
|
||||
// @Description 调用可灵AI图生视频接口,生成视频内容
|
||||
// @Tags Video
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||
// @Param request body KlingImage2VideoRequest true "图生视频请求参数"
|
||||
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||
// @Router /kling/v1/videos/image2video [post]
|
||||
func KlingImage2VideoGenerations(c *gin.Context) {
|
||||
}
|
||||
|
||||
type KlingImage2VideoRequest struct {
|
||||
ModelName string `json:"model_name,omitempty" example:"kling-v2-master"`
|
||||
Image string `json:"image" binding:"required" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"`
|
||||
Prompt string `json:"prompt,omitempty" example:"A cat playing piano in the garden"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
|
||||
CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
|
||||
Mode string `json:"mode,omitempty" example:"std"`
|
||||
CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
|
||||
Duration string `json:"duration,omitempty" example:"5"`
|
||||
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
|
||||
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-002"`
|
||||
}
|
||||
|
||||
// KlingImage2videoTaskId godoc
|
||||
// @Summary 可灵任务查询--图生视频
|
||||
// @Description Query the status and result of a Kling video generation task by task ID
|
||||
// @Tags Origin
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param task_id path string true "Task ID"
|
||||
// @Router /kling/v1/videos/image2video/{task_id} [get]
|
||||
func KlingImage2videoTaskId(c *gin.Context) {}
|
||||
|
||||
// KlingText2videoTaskId godoc
|
||||
// @Summary 可灵任务查询--文生视频
|
||||
// @Description Query the status and result of a Kling text-to-video generation task by task ID
|
||||
// @Tags Origin
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param task_id path string true "Task ID"
|
||||
// @Router /kling/v1/videos/text2video/{task_id} [get]
|
||||
func KlingText2videoTaskId(c *gin.Context) {}
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -17,6 +15,9 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
func UpdateTaskBulk() {
|
||||
@@ -74,10 +75,10 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||
case constant.TaskPlatformSuno:
|
||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
case constant.TaskPlatformKling, constant.TaskPlatformJimeng:
|
||||
_ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM)
|
||||
default:
|
||||
common.SysLog("未知平台")
|
||||
if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,7 +123,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
|
||||
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
@@ -225,14 +226,7 @@ func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool
|
||||
}
|
||||
|
||||
func GetAllTask(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
@@ -247,30 +241,15 @@ func GetAllTask(c *gin.Context) {
|
||||
ChannelID: c.Query("channel_id"),
|
||||
}
|
||||
|
||||
items := model.TaskGetAllTasks((p-1)*pageSize, pageSize, queryParams)
|
||||
items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.TaskCountAllTasks(queryParams)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
func GetUserTask(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
@@ -286,17 +265,9 @@ func GetUserTask(c *gin.Context) {
|
||||
EndTimestamp: endTimestamp,
|
||||
}
|
||||
|
||||
items := model.TaskGetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
|
||||
items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.TaskCountAllUserTask(userId, queryParams)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
@@ -2,13 +2,16 @@ package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -77,13 +80,21 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
taskResult, err := adaptor.ParseTaskResult(responseBody)
|
||||
if err != nil {
|
||||
taskResult := &relaycommon.TaskInfo{}
|
||||
// try parse as New API response format
|
||||
var responseItems dto.TaskResponse[model.Task]
|
||||
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||
t := responseItems.Data
|
||||
taskResult.TaskID = t.TaskID
|
||||
taskResult.Status = string(t.Status)
|
||||
taskResult.Url = t.FailReason
|
||||
taskResult.Progress = t.Progress
|
||||
taskResult.Reason = t.FailReason
|
||||
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
} else {
|
||||
task.Data = responseBody
|
||||
}
|
||||
//if taskResult.Code != 0 {
|
||||
// return fmt.Errorf("video task fetch failed for task %s", taskId)
|
||||
//}
|
||||
|
||||
now := time.Now().Unix()
|
||||
if taskResult.Status == "" {
|
||||
@@ -128,8 +139,6 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
if taskResult.Progress != "" {
|
||||
task.Progress = taskResult.Progress
|
||||
}
|
||||
|
||||
task.Data = responseBody
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysError("UpdateVideoTask task error: " + err.Error())
|
||||
}
|
||||
|
||||
@@ -1,46 +1,26 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAllTokens(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
size, _ := strconv.Atoi(c.Query("size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if size <= 0 {
|
||||
size = common.ItemsPerPage
|
||||
} else if size > 100 {
|
||||
size = 100
|
||||
}
|
||||
tokens, err := model.GetAllUserTokens(userId, (p-1)*size, size)
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
tokens, err := model.GetAllUserTokens(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// Get total count for pagination
|
||||
total, _ := model.CountUserTokens(userId)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": tokens,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": size,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(tokens)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -50,10 +30,7 @@ func SearchTokens(c *gin.Context) {
|
||||
token := c.Query("token")
|
||||
tokens, err := model.SearchUserTokens(userId, keyword, token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -68,18 +45,12 @@ func GetToken(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
userId := c.GetInt("id")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
token, err := model.GetTokenByIds(id, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -95,10 +66,7 @@ func GetTokenStatus(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
token, err := model.GetTokenByIds(tokenId, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
expiredAt := token.ExpiredTime
|
||||
@@ -118,10 +86,7 @@ func AddToken(c *gin.Context) {
|
||||
token := model.Token{}
|
||||
err := c.ShouldBindJSON(&token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(token.Name) > 30 {
|
||||
@@ -156,10 +121,7 @@ func AddToken(c *gin.Context) {
|
||||
}
|
||||
err = cleanToken.Insert()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -174,10 +136,7 @@ func DeleteToken(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
err := model.DeleteTokenById(id, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -193,10 +152,7 @@ func UpdateToken(c *gin.Context) {
|
||||
token := model.Token{}
|
||||
err := c.ShouldBindJSON(&token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(token.Name) > 30 {
|
||||
@@ -208,10 +164,7 @@ func UpdateToken(c *gin.Context) {
|
||||
}
|
||||
cleanToken, err := model.GetTokenByIds(token.Id, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if token.Status == common.TokenStatusEnabled {
|
||||
@@ -245,10 +198,7 @@ func UpdateToken(c *gin.Context) {
|
||||
}
|
||||
err = cleanToken.Update()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -275,10 +225,7 @@ func DeleteTokenBatch(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
275
controller/topup_stripe.go
Normal file
275
controller/topup_stripe.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stripe/stripe-go/v81"
|
||||
"github.com/stripe/stripe-go/v81/checkout/session"
|
||||
"github.com/stripe/stripe-go/v81/webhook"
|
||||
"github.com/thanhpk/randstr"
|
||||
)
|
||||
|
||||
const (
|
||||
PaymentMethodStripe = "stripe"
|
||||
)
|
||||
|
||||
var stripeAdaptor = &StripeAdaptor{}
|
||||
|
||||
type StripePayRequest struct {
|
||||
Amount int64 `json:"amount"`
|
||||
PaymentMethod string `json:"payment_method"`
|
||||
}
|
||||
|
||||
type StripeAdaptor struct {
|
||||
}
|
||||
|
||||
func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
|
||||
if req.Amount < getStripeMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
payMoney := getStripePayMoney(float64(req.Amount), group)
|
||||
if payMoney <= 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||
}
|
||||
|
||||
func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
|
||||
if req.PaymentMethod != PaymentMethodStripe {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
|
||||
return
|
||||
}
|
||||
if req.Amount < getStripeMinTopup() {
|
||||
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
|
||||
return
|
||||
}
|
||||
if req.Amount > 10000 {
|
||||
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
|
||||
|
||||
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4))
|
||||
referenceId := "ref_" + common.Sha1([]byte(reference))
|
||||
|
||||
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount)
|
||||
if err != nil {
|
||||
log.Println("获取Stripe Checkout支付链接失败", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Money: chargedMoney,
|
||||
TradeNo: referenceId,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"pay_link": payLink,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func RequestStripeAmount(c *gin.Context) {
|
||||
var req StripePayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
stripeAdaptor.RequestAmount(c, &req)
|
||||
}
|
||||
|
||||
func RequestStripePay(c *gin.Context) {
|
||||
var req StripePayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
stripeAdaptor.RequestPay(c, &req)
|
||||
}
|
||||
|
||||
func StripeWebhook(c *gin.Context) {
|
||||
payload, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
|
||||
c.AbortWithStatus(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
signature := c.GetHeader("Stripe-Signature")
|
||||
endpointSecret := setting.StripeWebhookSecret
|
||||
event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
|
||||
IgnoreAPIVersionMismatch: true,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Stripe Webhook验签失败: %v\n", err)
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case stripe.EventTypeCheckoutSessionCompleted:
|
||||
sessionCompleted(event)
|
||||
case stripe.EventTypeCheckoutSessionExpired:
|
||||
sessionExpired(event)
|
||||
default:
|
||||
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
|
||||
}
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
func sessionCompleted(event stripe.Event) {
|
||||
customerId := event.GetObjectValue("customer")
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
status := event.GetObjectValue("status")
|
||||
if "complete" != status {
|
||||
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
err := model.Recharge(referenceId, customerId)
|
||||
if err != nil {
|
||||
log.Println(err.Error(), referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
|
||||
currency := strings.ToUpper(event.GetObjectValue("currency"))
|
||||
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
|
||||
}
|
||||
|
||||
func sessionExpired(event stripe.Event) {
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
status := event.GetObjectValue("status")
|
||||
if "expired" != status {
|
||||
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
if len(referenceId) == 0 {
|
||||
log.Println("未提供支付单号")
|
||||
return
|
||||
}
|
||||
|
||||
topUp := model.GetTopUpByTradeNo(referenceId)
|
||||
if topUp == nil {
|
||||
log.Println("充值订单不存在", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
log.Println("充值订单状态错误", referenceId)
|
||||
}
|
||||
|
||||
topUp.Status = common.TopUpStatusExpired
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("充值订单已过期", referenceId)
|
||||
}
|
||||
|
||||
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
|
||||
if !strings.HasPrefix(setting.StripeApiSecret, "sk_") && !strings.HasPrefix(setting.StripeApiSecret, "rk_") {
|
||||
return "", fmt.Errorf("无效的Stripe API密钥")
|
||||
}
|
||||
|
||||
stripe.Key = setting.StripeApiSecret
|
||||
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
ClientReferenceID: stripe.String(referenceId),
|
||||
SuccessURL: stripe.String(setting.ServerAddress + "/log"),
|
||||
CancelURL: stripe.String(setting.ServerAddress + "/topup"),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(setting.StripePriceId),
|
||||
Quantity: stripe.Int64(amount),
|
||||
},
|
||||
},
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
|
||||
}
|
||||
|
||||
if "" == customerId {
|
||||
if "" != email {
|
||||
params.CustomerEmail = stripe.String(email)
|
||||
}
|
||||
|
||||
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
|
||||
} else {
|
||||
params.Customer = stripe.String(customerId)
|
||||
}
|
||||
|
||||
result, err := session.New(params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
func GetChargedAmount(count float64, user model.User) float64 {
|
||||
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
|
||||
if topUpGroupRatio == 0 {
|
||||
topUpGroupRatio = 1
|
||||
}
|
||||
|
||||
return count * topUpGroupRatio
|
||||
}
|
||||
|
||||
func getStripePayMoney(amount float64, group string) float64 {
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / common.QuotaPerUnit
|
||||
}
|
||||
// Using float64 for monetary calculations is acceptable here due to the small amounts involved
|
||||
topupGroupRatio := common.GetTopupGroupRatio(group)
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio
|
||||
return payMoney
|
||||
}
|
||||
|
||||
func getStripeMinTopup() int64 {
|
||||
minTopup := setting.StripeMinTopUp
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
minTopup = minTopup * int(common.QuotaPerUnit)
|
||||
}
|
||||
return int64(minTopup)
|
||||
}
|
||||
553
controller/twofa.go
Normal file
553
controller/twofa.go
Normal file
@@ -0,0 +1,553 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Setup2FARequest 设置2FA请求结构
|
||||
type Setup2FARequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// Verify2FARequest 验证2FA请求结构
|
||||
type Verify2FARequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// Setup2FAResponse 设置2FA响应结构
|
||||
type Setup2FAResponse struct {
|
||||
Secret string `json:"secret"`
|
||||
QRCodeData string `json:"qr_code_data"`
|
||||
BackupCodes []string `json:"backup_codes"`
|
||||
}
|
||||
|
||||
// Setup2FA 初始化2FA设置
|
||||
func Setup2FA(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 检查用户是否已经启用2FA
|
||||
existing, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if existing != nil && existing.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已启用2FA,请先禁用后重新设置",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果存在已禁用的2FA记录,先删除它
|
||||
if existing != nil && !existing.IsEnabled {
|
||||
if err := existing.Delete(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
existing = nil // 重置为nil,后续将创建新记录
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 生成TOTP密钥
|
||||
key, err := common.GenerateTOTPSecret(user.Username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成2FA密钥失败",
|
||||
})
|
||||
common.SysError("生成TOTP密钥失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 生成备用码
|
||||
backupCodes, err := common.GenerateBackupCodes()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成备用码失败",
|
||||
})
|
||||
common.SysError("生成备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 生成二维码数据
|
||||
qrCodeData := common.GenerateQRCodeData(key.Secret(), user.Username)
|
||||
|
||||
// 创建或更新2FA记录(暂未启用)
|
||||
twoFA := &model.TwoFA{
|
||||
UserId: userId,
|
||||
Secret: key.Secret(),
|
||||
IsEnabled: false,
|
||||
}
|
||||
|
||||
if existing != nil {
|
||||
// 更新现有记录
|
||||
twoFA.Id = existing.Id
|
||||
err = twoFA.Update()
|
||||
} else {
|
||||
// 创建新记录
|
||||
err = twoFA.Create()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建备用码记录
|
||||
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "保存备用码失败",
|
||||
})
|
||||
common.SysError("保存备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "开始设置两步验证")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "2FA设置初始化成功,请使用认证器扫描二维码并输入验证码完成设置",
|
||||
"data": Setup2FAResponse{
|
||||
Secret: key.Secret(),
|
||||
QRCodeData: qrCodeData,
|
||||
BackupCodes: backupCodes,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Enable2FA 启用2FA
|
||||
func Enable2FA(c *gin.Context) {
|
||||
var req Setup2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "请先完成2FA初始化设置",
|
||||
})
|
||||
return
|
||||
}
|
||||
if twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "2FA已经启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !common.ValidateTOTPCode(twoFA.Secret, cleanCode) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 启用2FA
|
||||
if err := twoFA.Enable(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "成功启用两步验证")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "两步验证启用成功",
|
||||
})
|
||||
}
|
||||
|
||||
// Disable2FA 禁用2FA
|
||||
func Disable2FA(c *gin.Context) {
|
||||
var req Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码或备用码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
isValidTOTP := false
|
||||
isValidBackup := false
|
||||
|
||||
if err == nil {
|
||||
// 尝试验证TOTP
|
||||
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||
}
|
||||
|
||||
if !isValidTOTP {
|
||||
// 尝试验证备用码
|
||||
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !isValidTOTP && !isValidBackup {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 禁用2FA
|
||||
if err := model.DisableTwoFA(userId); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "禁用两步验证")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "两步验证已禁用",
|
||||
})
|
||||
}
|
||||
|
||||
// Get2FAStatus 获取用户2FA状态
|
||||
func Get2FAStatus(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
status := map[string]interface{}{
|
||||
"enabled": false,
|
||||
"locked": false,
|
||||
}
|
||||
|
||||
if twoFA != nil {
|
||||
status["enabled"] = twoFA.IsEnabled
|
||||
status["locked"] = twoFA.IsLocked()
|
||||
if twoFA.IsEnabled {
|
||||
// 获取剩余备用码数量
|
||||
backupCount, err := model.GetUnusedBackupCodeCount(userId)
|
||||
if err != nil {
|
||||
common.SysError("获取备用码数量失败: " + err.Error())
|
||||
} else {
|
||||
status["backup_codes_remaining"] = backupCount
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": status,
|
||||
})
|
||||
}
|
||||
|
||||
// RegenerateBackupCodes 重新生成备用码
|
||||
func RegenerateBackupCodes(c *gin.Context) {
|
||||
var req Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
valid, err := twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !valid {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成新的备用码
|
||||
backupCodes, err := common.GenerateBackupCodes()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成备用码失败",
|
||||
})
|
||||
common.SysError("生成备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 保存新的备用码
|
||||
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "保存备用码失败",
|
||||
})
|
||||
common.SysError("保存备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "重新生成两步验证备用码")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "备用码重新生成成功",
|
||||
"data": map[string]interface{}{
|
||||
"backup_codes": backupCodes,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Verify2FALogin 登录时验证2FA
|
||||
func Verify2FALogin(c *gin.Context) {
|
||||
var req Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 从会话中获取pending用户信息
|
||||
session := sessions.Default(c)
|
||||
pendingUserId := session.Get("pending_user_id")
|
||||
if pendingUserId == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "会话已过期,请重新登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
userId, ok := pendingUserId.(int)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "会话数据无效,请重新登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 获取用户信息
|
||||
user, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(user.Id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码或备用码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
isValidTOTP := false
|
||||
isValidBackup := false
|
||||
|
||||
if err == nil {
|
||||
// 尝试验证TOTP
|
||||
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||
}
|
||||
|
||||
if !isValidTOTP {
|
||||
// 尝试验证备用码
|
||||
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !isValidTOTP && !isValidBackup {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2FA验证成功,清理pending会话信息并完成登录
|
||||
session.Delete("pending_username")
|
||||
session.Delete("pending_user_id")
|
||||
session.Save()
|
||||
|
||||
setupLogin(user, c)
|
||||
}
|
||||
|
||||
// Admin2FAStats 管理员获取2FA统计信息
|
||||
func Admin2FAStats(c *gin.Context) {
|
||||
stats, err := model.GetTwoFAStats()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": stats,
|
||||
})
|
||||
}
|
||||
|
||||
// AdminDisable2FA 管理员强制禁用用户2FA
|
||||
func AdminDisable2FA(c *gin.Context) {
|
||||
userIdStr := c.Param("id")
|
||||
userId, err := strconv.Atoi(userIdStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户ID格式错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查目标用户权限
|
||||
targetUser, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权操作同级或更高级用户的2FA设置",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 禁用2FA
|
||||
if err := model.DisableTwoFA(userId); err != nil {
|
||||
if errors.Is(err, model.ErrTwoFANotEnabled) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
adminId := c.GetInt("id")
|
||||
model.RecordLog(userId, model.LogTypeManage,
|
||||
fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "用户2FA已被强制禁用",
|
||||
})
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAllQuotaDates(c *gin.Context) {
|
||||
@@ -13,10 +15,7 @@ func GetAllQuotaDates(c *gin.Context) {
|
||||
username := c.Query("username")
|
||||
dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -41,10 +40,7 @@ func GetUserQuotaDates(c *gin.Context) {
|
||||
}
|
||||
dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -62,6 +62,32 @@ func Login(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否启用2FA
|
||||
if model.IsTwoFAEnabled(user.Id) {
|
||||
// 设置pending session,等待2FA验证
|
||||
session := sessions.Default(c)
|
||||
session.Set("pending_username", user.Username)
|
||||
session.Set("pending_user_id", user.Id)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "无法保存会话信息,请重试",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "请输入两步验证码",
|
||||
"success": true,
|
||||
"data": map[string]interface{}{
|
||||
"require_2fa": true,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
@@ -188,10 +214,7 @@ func Register(c *gin.Context) {
|
||||
cleanUser.Email = user.Email
|
||||
}
|
||||
if err := cleanUser.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -247,81 +270,45 @@ func Register(c *gin.Context) {
|
||||
}
|
||||
|
||||
func GetAllUsers(c *gin.Context) {
|
||||
pageInfo, err := common.GetPageQuery(c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "parse page query failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
users, total, err := model.GetAllUsers(pageInfo)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(users)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": pageInfo,
|
||||
})
|
||||
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func SearchUsers(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
group := c.Query("group")
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
startIdx := (p - 1) * pageSize
|
||||
users, total, err := model.SearchUsers(keyword, group, startIdx, pageSize)
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
users, total, err := model.SearchUsers(keyword, group, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": users,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(users)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func GetUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
@@ -344,10 +331,7 @@ func GenerateAccessToken(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// get rand int 28-32
|
||||
@@ -372,10 +356,7 @@ func GenerateAccessToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -395,18 +376,12 @@ func TransferAffQuota(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
tran := TransferAffQuotaRequest{}
|
||||
if err := c.ShouldBindJSON(&tran); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
err = user.TransferAffQuotaToQuota(tran.Quota)
|
||||
@@ -427,10 +402,7 @@ func GetAffCode(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if user.AffCode == "" {
|
||||
@@ -455,10 +427,7 @@ func GetSelf(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
||||
@@ -479,10 +448,7 @@ func GetUserModels(c *gin.Context) {
|
||||
}
|
||||
user, err := model.GetUserCache(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
groups := setting.GetUserUsableGroups(user.Group)
|
||||
@@ -524,10 +490,7 @@ func UpdateUser(c *gin.Context) {
|
||||
}
|
||||
originUser, err := model.GetUserById(updatedUser.Id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
@@ -550,10 +513,7 @@ func UpdateUser(c *gin.Context) {
|
||||
}
|
||||
updatePassword := updatedUser.Password != ""
|
||||
if err := updatedUser.Edit(updatePassword); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if originUser.Quota != updatedUser.Quota {
|
||||
@@ -599,17 +559,11 @@ func UpdateSelf(c *gin.Context) {
|
||||
}
|
||||
updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if err := cleanUser.Update(updatePassword); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -640,18 +594,12 @@ func checkUpdatePassword(originalPassword string, newPassword string, userId int
|
||||
func DeleteUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
originUser, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
@@ -686,10 +634,7 @@ func DeleteSelf(c *gin.Context) {
|
||||
|
||||
err := model.DeleteUserById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -735,10 +680,7 @@ func CreateUser(c *gin.Context) {
|
||||
DisplayName: user.DisplayName,
|
||||
}
|
||||
if err := cleanUser.Insert(0); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -848,10 +790,7 @@ func ManageUser(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
clearUser := model.User{
|
||||
@@ -883,20 +822,14 @@ func EmailBind(c *gin.Context) {
|
||||
}
|
||||
err := user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.Email = email
|
||||
// no need to check if this email already taken, because we have used verification code to check it
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -918,19 +851,13 @@ func TopUp(c *gin.Context) {
|
||||
req := topUpRequest{}
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
quota, err := model.Redeem(req.Key, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -1013,10 +940,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
user, err := model.GetUserById(userId, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type wechatLoginResponse struct {
|
||||
@@ -150,19 +151,13 @@ func WeChatBind(c *gin.Context) {
|
||||
}
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.WeChatId = wechatId
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -11,7 +11,6 @@ services:
|
||||
volumes:
|
||||
- ./data:/data
|
||||
- ./logs:/app/logs
|
||||
- ${JS_SCRIPT_DIR:-./scripts}:/app/scripts
|
||||
environment:
|
||||
- SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
@@ -22,6 +21,7 @@ services:
|
||||
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||
# - FRONTEND_BASE_URL=https://openai.justsong.cn # Uncomment for multi-node deployment with front-end URL
|
||||
|
||||
depends_on:
|
||||
- redis
|
||||
- mysql
|
||||
|
||||
195
docs/api/web_api.md
Normal file
195
docs/api/web_api.md
Normal file
@@ -0,0 +1,195 @@
|
||||
# One API – Web 界面后端接口文档
|
||||
|
||||
> 本文档汇总了 **One API** 后端提供给前端 Web 界面的全部 REST 接口(不含 *Relay* 相关接口)。
|
||||
>
|
||||
> 接口前缀统一为 `https://<your-domain>`,以下仅列出 **路径**、**HTTP 方法**、**鉴权要求** 与 **功能简介**。
|
||||
>
|
||||
> 鉴权级别说明:
|
||||
> * **公开** – 不需要登录即可调用
|
||||
> * **用户** – 需携带用户 Token(`middleware.UserAuth`)
|
||||
> * **管理员** – 需管理员 Token(`middleware.AdminAuth`)
|
||||
> * **Root** – 仅限最高权限 Root 用户(`middleware.RootAuth`)
|
||||
|
||||
---
|
||||
|
||||
## 1. 初始化 / 系统状态
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/setup | 公开 | 获取系统初始化状态 |
|
||||
| POST | /api/setup | 公开 | 完成首次安装向导 |
|
||||
| GET | /api/status | 公开 | 获取运行状态摘要 |
|
||||
| GET | /api/uptime/status | 公开 | Uptime-Kuma 兼容状态探针 |
|
||||
| GET | /api/status/test | 管理员 | 测试后端与依赖组件是否正常 |
|
||||
|
||||
## 2. 公共信息
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/models | 用户 | 获取前端可用模型列表 |
|
||||
| GET | /api/notice | 公开 | 获取公告栏内容 |
|
||||
| GET | /api/about | 公开 | 关于页面信息 |
|
||||
| GET | /api/home_page_content | 公开 | 首页自定义内容 |
|
||||
| GET | /api/pricing | 可匿名/用户 | 价格与套餐信息 |
|
||||
| GET | /api/ratio_config | 公开 | 模型倍率配置(仅公开字段) |
|
||||
|
||||
## 3. 邮件 / 身份验证
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/verification | 公开 (限流) | 发送邮箱验证邮件 |
|
||||
| GET | /api/reset_password | 公开 (限流) | 发送重置密码邮件 |
|
||||
| POST | /api/user/reset | 公开 | 提交重置密码请求 |
|
||||
|
||||
## 4. OAuth / 第三方登录
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 |
|
||||
| GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 |
|
||||
| GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 |
|
||||
| GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 |
|
||||
| GET | /api/oauth/wechat/bind | 公开 | 微信账户绑定 |
|
||||
| GET | /api/oauth/email/bind | 公开 | 邮箱绑定 |
|
||||
| GET | /api/oauth/telegram/login | 公开 | Telegram 登录 |
|
||||
| GET | /api/oauth/telegram/bind | 公开 | Telegram 账户绑定 |
|
||||
| GET | /api/oauth/state | 公开 | 获取随机 state(防 CSRF) |
|
||||
|
||||
## 5. 用户模块
|
||||
### 5.1 账号注册/登录
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| POST | /api/user/register | 公开 | 注册新账号 |
|
||||
| POST | /api/user/login | 公开 | 用户登录 |
|
||||
| GET | /api/user/logout | 用户 | 退出登录 |
|
||||
| GET | /api/user/epay/notify | 公开 | Epay 支付回调 |
|
||||
| GET | /api/user/groups | 公开 | 列出所有分组(无鉴权版) |
|
||||
|
||||
### 5.2 用户自身操作 (需登录)
|
||||
| GET | /api/user/self/groups | 用户 | 获取自己所在分组 |
|
||||
| GET | /api/user/self | 用户 | 获取个人资料 |
|
||||
| GET | /api/user/models | 用户 | 获取模型可见性 |
|
||||
| PUT | /api/user/self | 用户 | 修改个人资料 |
|
||||
| DELETE | /api/user/self | 用户 | 注销账号 |
|
||||
| GET | /api/user/token | 用户 | 生成用户级别 Access Token |
|
||||
| GET | /api/user/aff | 用户 | 获取推广码信息 |
|
||||
| POST | /api/user/topup | 用户 | 余额直充 |
|
||||
| POST | /api/user/pay | 用户 | 提交支付订单 |
|
||||
| POST | /api/user/amount | 用户 | 余额支付 |
|
||||
| POST | /api/user/aff_transfer | 用户 | 推广额度转账 |
|
||||
| PUT | /api/user/setting | 用户 | 更新用户设置 |
|
||||
|
||||
### 5.3 管理员用户管理
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/user/ | 管理员 | 获取全部用户列表 |
|
||||
| GET | /api/user/search | 管理员 | 搜索用户 |
|
||||
| GET | /api/user/:id | 管理员 | 获取单个用户信息 |
|
||||
| POST | /api/user/ | 管理员 | 创建用户 |
|
||||
| POST | /api/user/manage | 管理员 | 冻结/重置等管理操作 |
|
||||
| PUT | /api/user/ | 管理员 | 更新用户 |
|
||||
| DELETE | /api/user/:id | 管理员 | 删除用户 |
|
||||
|
||||
## 6. 站点选项 (Root)
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/option/ | Root | 获取全局配置 |
|
||||
| PUT | /api/option/ | Root | 更新全局配置 |
|
||||
| POST | /api/option/rest_model_ratio | Root | 重置模型倍率 |
|
||||
| POST | /api/option/migrate_console_setting | Root | 迁移旧版控制台配置 |
|
||||
|
||||
## 7. 模型倍率同步 (Root)
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/ratio_sync/channels | Root | 获取可同步渠道列表 |
|
||||
| POST | /api/ratio_sync/fetch | Root | 从上游拉取倍率 |
|
||||
|
||||
## 8. 渠道管理 (管理员)
|
||||
| 方法 | 路径 | 说明 |
|
||||
|------|------|------|
|
||||
| GET | /api/channel/ | 获取渠道列表 |
|
||||
| GET | /api/channel/search | 搜索渠道 |
|
||||
| GET | /api/channel/models | 查询渠道模型能力 |
|
||||
| GET | /api/channel/models_enabled | 查询启用模型能力 |
|
||||
| GET | /api/channel/:id | 获取单个渠道 |
|
||||
| GET | /api/channel/test | 批量测试渠道连通性 |
|
||||
| GET | /api/channel/test/:id | 单个渠道测试 |
|
||||
| GET | /api/channel/update_balance | 批量刷新余额 |
|
||||
| GET | /api/channel/update_balance/:id | 单个刷新余额 |
|
||||
| POST | /api/channel/ | 新增渠道 |
|
||||
| PUT | /api/channel/ | 更新渠道 |
|
||||
| DELETE | /api/channel/disabled | 删除已禁用渠道 |
|
||||
| POST | /api/channel/tag/disabled | 批量禁用标签渠道 |
|
||||
| POST | /api/channel/tag/enabled | 批量启用标签渠道 |
|
||||
| PUT | /api/channel/tag | 编辑渠道标签 |
|
||||
| DELETE | /api/channel/:id | 删除渠道 |
|
||||
| POST | /api/channel/batch | 批量删除渠道 |
|
||||
| POST | /api/channel/fix | 修复渠道能力表 |
|
||||
| GET | /api/channel/fetch_models/:id | 拉取单渠道模型 |
|
||||
| POST | /api/channel/fetch_models | 拉取全部渠道模型 |
|
||||
| POST | /api/channel/batch/tag | 批量设置渠道标签 |
|
||||
| GET | /api/channel/tag/models | 根据标签获取模型 |
|
||||
| POST | /api/channel/copy/:id | 复制渠道 |
|
||||
|
||||
## 9. Token 管理
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/token/ | 用户 | 获取全部 Token |
|
||||
| GET | /api/token/search | 用户 | 搜索 Token |
|
||||
| GET | /api/token/:id | 用户 | 获取单个 Token |
|
||||
| POST | /api/token/ | 用户 | 创建 Token |
|
||||
| PUT | /api/token/ | 用户 | 更新 Token |
|
||||
| DELETE | /api/token/:id | 用户 | 删除 Token |
|
||||
| POST | /api/token/batch | 用户 | 批量删除 Token |
|
||||
|
||||
## 10. 兑换码管理 (管理员)
|
||||
| 方法 | 路径 | 说明 |
|
||||
|------|------|------|
|
||||
| GET | /api/redemption/ | 获取兑换码列表 |
|
||||
| GET | /api/redemption/search | 搜索兑换码 |
|
||||
| GET | /api/redemption/:id | 获取单个兑换码 |
|
||||
| POST | /api/redemption/ | 创建兑换码 |
|
||||
| PUT | /api/redemption/ | 更新兑换码 |
|
||||
| DELETE | /api/redemption/invalid | 删除无效兑换码 |
|
||||
| DELETE | /api/redemption/:id | 删除兑换码 |
|
||||
|
||||
## 11. 日志
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/log/ | 管理员 | 获取全部日志 |
|
||||
| DELETE | /api/log/ | 管理员 | 删除历史日志 |
|
||||
| GET | /api/log/stat | 管理员 | 日志统计 |
|
||||
| GET | /api/log/self/stat | 用户 | 我的日志统计 |
|
||||
| GET | /api/log/search | 管理员 | 搜索全部日志 |
|
||||
| GET | /api/log/self | 用户 | 获取我的日志 |
|
||||
| GET | /api/log/self/search | 用户 | 搜索我的日志 |
|
||||
| GET | /api/log/token | 公开 | 根据 Token 查询日志(支持 CORS) |
|
||||
|
||||
## 12. 数据统计
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/data/ | 管理员 | 全站用量按日期统计 |
|
||||
| GET | /api/data/self | 用户 | 我的用量按日期统计 |
|
||||
|
||||
## 13. 分组
|
||||
| GET | /api/group/ | 管理员 | 获取全部分组列表 |
|
||||
|
||||
## 14. Midjourney 任务
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/mj/self | 用户 | 获取自己的 MJ 任务 |
|
||||
| GET | /api/mj/ | 管理员 | 获取全部 MJ 任务 |
|
||||
|
||||
## 15. 任务中心
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/task/self | 用户 | 获取我的任务 |
|
||||
| GET | /api/task/ | 管理员 | 获取全部任务 |
|
||||
|
||||
## 16. 账户计费面板 (Dashboard)
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /dashboard/billing/subscription | 用户 Token | 获取订阅额度信息 |
|
||||
| GET | /v1/dashboard/billing/subscription | 同上 | 兼容 OpenAI SDK 路径 |
|
||||
| GET | /dashboard/billing/usage | 用户 Token | 获取使用量信息 |
|
||||
| GET | /v1/dashboard/billing/usage | 同上 | 兼容 OpenAI SDK 路径 |
|
||||
|
||||
---
|
||||
|
||||
> **更新日期**:2025.07.17
|
||||
55
docs/images/cherry-studio.svg
Normal file
55
docs/images/cherry-studio.svg
Normal file
@@ -0,0 +1,55 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg id="_图层_2" data-name="图层_2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 198.45 66.73">
|
||||
<defs>
|
||||
<style>
|
||||
.cls-1 {
|
||||
fill: #ea5e5d;
|
||||
}
|
||||
|
||||
.cls-2 {
|
||||
fill: #23af69;
|
||||
}
|
||||
|
||||
.cls-3 {
|
||||
fill: #ea5756;
|
||||
}
|
||||
</style>
|
||||
</defs>
|
||||
<g id="_图层_1-2" data-name="图层_1">
|
||||
<g>
|
||||
<g>
|
||||
<g>
|
||||
<path class="cls-1" d="M16.72,51.21c-4.45,0-8.64-1.78-11.81-5.01-3.17-3.23-4.91-7.51-4.91-12.04s1.74-8.81,4.91-12.04,7.36-5.01,11.81-5.01,8.71,1.82,11.82,4.99c2.32,2.36,2.32,6.2,0,8.56-2.32,2.36-6.08,2.36-8.4,0-.9-.92-2.15-1.45-3.43-1.45-2.63,0-4.85,2.26-4.85,4.94s2.22,4.94,4.85,4.94c1.28,0,2.52-.53,3.43-1.45,2.32-2.36,6.08-2.36,8.4,0,2.32,2.36,2.32,6.2,0,8.56-3.11,3.17-7.42,4.99-11.82,4.99Z"/>
|
||||
<path class="cls-1" d="M32.05,66.73c-4.45,0-8.64-1.78-11.81-5.01s-4.91-7.51-4.91-12.04,1.79-8.88,4.9-12.06c2.32-2.36,6.08-2.36,8.4,0,2.32,2.36,2.32,6.2,0,8.56-.9.92-1.42,2.19-1.42,3.49,0,2.68,2.22,4.94,4.85,4.94s4.85-2.26,4.85-4.94c0-.95-.23-2.31-1.32-3.43-3.13-3.19-4.92-7.6-4.92-12.09s1.74-8.81,4.91-12.04,7.36-5.01,11.81-5.01,8.64,1.78,11.81,5.01,4.91,7.51,4.91,12.04-1.79,8.88-4.9,12.06c-2.32,2.36-6.08,2.36-8.4,0-2.32-2.36-2.32-6.2,0-8.56.9-.92,1.42-2.19,1.42-3.49,0-2.68-2.22-4.94-4.85-4.94s-4.85,2.26-4.85,4.94c0,1.31.53,2.6,1.45,3.53,3.1,3.16,4.8,7.42,4.8,11.99s-1.74,8.81-4.91,12.04c-3.17,3.23-7.36,5.01-11.81,5.01Z"/>
|
||||
</g>
|
||||
<path class="cls-2" d="M32.05,19.09l-9.72-9.12c-1.5-1.4-1.57-3.75-.17-5.25,1.4-1.49,3.75-1.57,5.25-.17l3.89,3.65,5.53-6.83c1.29-1.59,3.63-1.84,5.22-.55,1.59,1.29,1.84,3.63.55,5.22l-10.56,13.05Z"/>
|
||||
</g>
|
||||
<g>
|
||||
<path class="cls-3" d="M93.93,24.6l.55-.39c.69-.4,1.17-.61,1.46-.61.63,0,1.3.57,2.03,1.7.44.71.67,1.27.67,1.7s-.14.78-.41,1.06c-.27.28-.59.54-.96.76-.36.22-.71.43-1.05.64-.33.2-1.02.47-2.05.79-1.03.32-2.03.49-2.99.49s-1.93-.13-2.91-.38c-.98-.25-1.99-.68-3.03-1.27-1.04-.6-1.98-1.32-2.81-2.18-.83-.86-1.51-1.96-2.05-3.31-.54-1.35-.8-2.81-.8-4.38s.26-3.01.79-4.29c.53-1.28,1.2-2.35,2.02-3.19.82-.84,1.75-1.54,2.81-2.11,1.98-1.09,3.97-1.64,5.98-1.64.95,0,1.92.15,2.9.44.98.29,1.72.59,2.23.9l.73.42c.36.22.65.4.85.55.53.42.79.91.79,1.44s-.21,1.1-.64,1.68c-.79,1.09-1.5,1.64-2.12,1.64-.36,0-.88-.22-1.55-.67-.85-.69-1.98-1.03-3.4-1.03-1.31,0-2.61.46-3.88,1.36-.61.44-1.11,1.07-1.52,1.88-.4.81-.61,1.72-.61,2.75s.2,1.94.61,2.75c.4.81.92,1.45,1.55,1.91,1.23.89,2.52,1.34,3.85,1.34.63,0,1.22-.08,1.77-.24.56-.16.96-.32,1.2-.49Z"/>
|
||||
<path class="cls-3" d="M114.38,9.07c.16-.3.43-.52.82-.64.38-.12.87-.18,1.46-.18s1.05.05,1.4.15c.34.1.61.22.79.36.18.14.32.34.42.61.1.34.15.87.15,1.58v16.84c0,.47-.02.81-.05,1.05-.03.23-.13.5-.29.8-.28.55-1.07.82-2.37.82-1.42,0-2.25-.37-2.49-1.12-.12-.34-.18-.87-.18-1.58v-6.16h-8.04v6.19c0,.47-.02.81-.05,1.05-.03.23-.13.5-.29.8-.28.55-1.07.82-2.37.82-1.42,0-2.25-.37-2.49-1.12-.12-.34-.18-.87-.18-1.58V10.92c0-.46.02-.81.05-1.05.03-.23.13-.5.29-.8.28-.55,1.07-.82,2.37-.82,1.42,0,2.25.37,2.52,1.12.1.34.15.87.15,1.58v6.19h8.04v-6.22c0-.46.02-.81.05-1.05.03-.23.13-.5.29-.8Z"/>
|
||||
<path class="cls-3" d="M127.21,25.1h9.34c.47,0,.81.02,1.05.05.23.03.5.13.8.29.55.28.82,1.07.82,2.37,0,1.42-.37,2.25-1.12,2.49-.34.12-.87.18-1.58.18h-12.01c-1.42,0-2.25-.38-2.49-1.15-.12-.32-.18-.84-.18-1.55V10.9c0-1.03.19-1.73.58-2.11.38-.37,1.11-.56,2.18-.56h11.95c.47,0,.81.02,1.05.05.23.03.5.13.8.29.55.28.82,1.07.82,2.37,0,1.42-.37,2.25-1.12,2.49-.34.12-.87.18-1.58.18h-9.31v3.06h6.01c.46,0,.81.02,1.05.05.23.03.5.13.8.29.55.28.82,1.07.82,2.37,0,1.42-.38,2.25-1.15,2.49-.34.12-.87.18-1.58.18h-5.95v3.06Z"/>
|
||||
<path class="cls-3" d="M196.96,8.79c.99.69,1.49,1.35,1.49,2,0,.38-.23.92-.7,1.61l-6.55,9.8v5.79c0,.47-.02.81-.05,1.05-.03.23-.13.5-.29.8-.16.3-.43.52-.82.64-.38.12-.9.18-1.55.18s-1.16-.06-1.55-.18c-.38-.12-.66-.34-.82-.65-.16-.31-.26-.59-.29-.82-.03-.23-.05-.59-.05-1.08v-5.73l-6.55-9.8c-.47-.69-.7-1.22-.7-1.61,0-.65.44-1.27,1.33-1.87.89-.6,1.53-.9,1.91-.9s.69.08.91.24c.34.22.71.64,1.09,1.24l4.7,7.52,4.7-7.52c.38-.61.72-1.01,1-1.2s.61-.29.99-.29.97.25,1.77.76Z"/>
|
||||
<g>
|
||||
<path class="cls-3" d="M81.93,56.63c-.53-.65-.79-1.23-.79-1.74s.43-1.2,1.3-2.05c.51-.49,1.04-.73,1.61-.73s1.36.51,2.37,1.52c.28.34.69.67,1.21.99.53.31,1.01.47,1.46.47,1.88,0,2.82-.77,2.82-2.31,0-.46-.26-.85-.77-1.17-.52-.31-1.16-.54-1.93-.68-.77-.14-1.6-.37-2.49-.68-.89-.31-1.72-.68-2.49-1.11-.77-.42-1.41-1.1-1.93-2.02-.52-.92-.77-2.03-.77-3.32,0-1.78.66-3.33,1.99-4.66s3.13-1.99,5.42-1.99c1.21,0,2.32.16,3.32.47,1,.31,1.69.63,2.08.96l.76.58c.63.59.94,1.08.94,1.49s-.24.96-.73,1.67c-.69,1.01-1.4,1.52-2.12,1.52-.42,0-.95-.2-1.58-.61-.06-.04-.18-.14-.35-.3-.17-.16-.33-.29-.47-.39-.42-.26-.97-.39-1.62-.39s-1.2.16-1.64.47c-.43.31-.65.75-.65,1.3s.26,1.01.77,1.35c.52.34,1.16.58,1.93.7.77.12,1.61.31,2.52.56.91.25,1.75.56,2.52.93.77.36,1.41,1,1.93,1.9.52.9.77,2.01.77,3.32s-.26,2.47-.79,3.47c-.53,1-1.21,1.77-2.06,2.32-1.64,1.07-3.39,1.61-5.25,1.61-.95,0-1.85-.12-2.7-.35-.85-.23-1.54-.52-2.06-.86-1.07-.65-1.82-1.27-2.24-1.88l-.27-.33Z"/>
|
||||
<path class="cls-3" d="M100.74,37.49h16.87c.65,0,1.12.08,1.43.23.3.15.51.39.61.71.1.32.15.75.15,1.27s-.05.95-.15,1.26c-.1.31-.27.53-.52.65-.36.18-.88.27-1.55.27h-5.79v15.26c0,.47-.02.81-.05,1.03s-.12.48-.27.77c-.15.29-.42.5-.8.62-.38.12-.89.18-1.52.18s-1.13-.06-1.5-.18c-.37-.12-.64-.33-.79-.62-.15-.29-.24-.56-.27-.79-.03-.23-.05-.58-.05-1.05v-15.23h-5.82c-.65,0-1.12-.08-1.43-.23-.3-.15-.51-.39-.61-.71-.1-.32-.15-.75-.15-1.27s.05-.95.15-1.26c.1-.31.27-.53.52-.65.36-.18.88-.27,1.55-.27Z"/>
|
||||
<path class="cls-3" d="M135.99,38.34c.2-.32.5-.55.88-.67.38-.12.86-.18,1.44-.18s1.04.05,1.38.15c.34.1.61.22.79.36.18.14.31.35.39.64.12.34.18.87.18,1.58v9.16c0,2.67-.83,5.1-2.49,7.28-.81,1.03-1.85,1.87-3.12,2.5s-2.68.96-4.23.96-2.95-.32-4.22-.97c-1.26-.65-2.29-1.5-3.08-2.55-1.64-2.14-2.46-4.57-2.46-7.28v-9.13c0-.49.02-.84.05-1.08.03-.23.13-.5.29-.8.16-.3.43-.52.82-.64.38-.12.9-.18,1.55-.18s1.16.06,1.55.18c.38.12.65.33.79.64.24.47.36,1.1.36,1.91v9.1c0,1.23.3,2.41.91,3.52.3.57.76,1.02,1.37,1.36.61.34,1.32.52,2.15.52,1.48,0,2.58-.55,3.31-1.64.73-1.09,1.09-2.36,1.09-3.79v-9.28c0-.79.1-1.34.3-1.67Z"/>
|
||||
<path class="cls-3" d="M146.18,37.49l5.61.03c2.93,0,5.51,1.06,7.74,3.17,2.22,2.11,3.34,4.71,3.34,7.8s-1.09,5.73-3.26,7.93c-2.17,2.2-4.81,3.31-7.9,3.31h-5.55c-1.23,0-2-.25-2.31-.76-.24-.42-.36-1.07-.36-1.94v-16.87c0-.49.02-.84.05-1.06s.13-.49.29-.79c.28-.55,1.07-.82,2.37-.82ZM151.79,54.35c1.46,0,2.77-.54,3.94-1.62,1.17-1.08,1.76-2.44,1.76-4.08s-.57-3.01-1.71-4.11c-1.14-1.1-2.48-1.65-4.02-1.65h-2.91v11.47h2.94Z"/>
|
||||
<path class="cls-3" d="M164.84,40.19c0-.46.02-.81.05-1.05.03-.23.13-.5.29-.8.28-.55,1.07-.82,2.37-.82,1.42,0,2.25.37,2.52,1.12.1.34.15.87.15,1.58v16.87c0,.49-.02.84-.05,1.06s-.13.49-.29.79c-.28.55-1.07.82-2.37.82-1.42,0-2.25-.38-2.49-1.15-.12-.32-.18-.84-.18-1.55v-16.87Z"/>
|
||||
<path class="cls-3" d="M183.07,37.24c2.99,0,5.59,1.08,7.8,3.25,2.2,2.16,3.31,4.85,3.31,8.05s-1.05,5.94-3.16,8.19c-2.1,2.26-4.69,3.38-7.77,3.38s-5.69-1.11-7.84-3.34c-2.15-2.22-3.23-4.87-3.23-7.95,0-1.68.3-3.25.91-4.72.61-1.47,1.42-2.7,2.43-3.69,1.01-.99,2.17-1.77,3.49-2.34,1.31-.57,2.67-.85,4.07-.85ZM177.55,48.68c0,1.8.58,3.26,1.74,4.38,1.16,1.12,2.46,1.68,3.9,1.68s2.73-.55,3.88-1.64c1.15-1.09,1.73-2.56,1.73-4.4s-.58-3.32-1.74-4.43c-1.16-1.11-2.46-1.67-3.9-1.67s-2.73.56-3.88,1.68c-1.15,1.12-1.73,2.58-1.73,4.38Z"/>
|
||||
</g>
|
||||
<g>
|
||||
<path class="cls-3" d="M176.92,11.06c-.03-.23-.13-.5-.29-.8-.28-.55-1.07-.82-2.37-.82h-6.55c-1.78,0-3.51.65-5.19,1.94-.81.63-1.48,1.48-2,2.55-.53,1.07-.79,2.27-.79,3.58,0,2.29.76,4.17,2.28,5.64-.44,1.07-1.13,2.66-2.06,4.76-.3.73-.45,1.25-.45,1.58,0,.77.63,1.42,1.88,1.94.65.28,1.17.43,1.56.43s.72-.1.97-.29c.25-.19.44-.39.56-.59.2-.38.99-2.21,2.37-5.49l.94.06h3.82v3.43c0,.47.02.81.05,1.05.03.23.13.5.29.8.28.55,1.07.82,2.37.82,1.42,0,2.25-.37,2.49-1.12.12-.34.18-.87.18-1.58V12.11c0-.46-.02-.81-.05-1.05ZM172.81,19.44c-.09.14-.48.77-1.24.91-.2.04-.37.03-.48.02-.02.14-.04.26-.06.38-.16.83-.38,1.05-.57,1.07-.29.05-.51-.35-.93-.9-.23.01-.46.02-.69.02-.51,0-1.01-.03-1.49-.09-.25-.03-.5-.07-.74-.11-1.18-.32-2.03-1.27-2.03-2.4v-1.37c0-1.13.86-2.08,2.03-2.4.24-.04.49-.08.74-.11.48-.06.98-.09,1.49-.09s1.01.03,1.49.09c.25.03.5.07.74.11.6.16,1.12.49,1.49.93.34.41.55.92.55,1.47v1.37c0,.23-.01.66-.29,1.1Z"/>
|
||||
<circle class="cls-2" cx="167.24" cy="17.67" r=".49"/>
|
||||
<circle class="cls-2" cx="168.88" cy="17.71" r=".49"/>
|
||||
<circle class="cls-2" cx="170.59" cy="17.71" r=".49"/>
|
||||
</g>
|
||||
<g>
|
||||
<path class="cls-3" d="M141.01,8.24c.03-.23.13-.5.29-.8.28-.55,1.07-.82,2.37-.82h6.55c1.78,0,3.51.65,5.19,1.94.81.63,1.48,1.48,2,2.55.53,1.07.79,2.27.79,3.58,0,2.29-.76,4.17-2.28,5.64.44,1.07,1.13,2.66,2.06,4.76.3.73.45,1.25.45,1.58,0,.77-.63,1.42-1.88,1.94-.65.28-1.17.43-1.56.43s-.72-.1-.97-.29c-.25-.19-.44-.39-.56-.59-.2-.38-.99-2.21-2.37-5.49l-.94.06h-3.82v3.43c0,.47-.02.81-.05,1.05-.03.23-.13.5-.29.8-.28.55-1.07.82-2.37.82-1.42,0-2.25-.37-2.49-1.12-.12-.34-.18-.87-.18-1.58V9.28c0-.46.02-.81.05-1.05ZM145.12,16.62c.09.14.48.77,1.24.91.2.04.37.03.48.02.02.14.04.26.06.38.16.83.38,1.05.57,1.07.29.05.51-.35.93-.9.23.01.46.02.69.02.51,0,1.01-.03,1.49-.09.25-.03.5-.07.74-.11,1.18-.32,2.03-1.27,2.03-2.4v-1.37c0-1.13-.86-2.08-2.03-2.4-.24-.04-.49-.08-.74-.11-.48-.06-.98-.09-1.49-.09s-1.01.03-1.49.09c-.25.03-.5.07-.74.11-.6.16-1.12.49-1.49.93-.34.41-.55.92-.55,1.47v1.37c0,.23.01.66.29,1.1Z"/>
|
||||
<circle class="cls-2" cx="150.69" cy="14.84" r=".49"/>
|
||||
<circle class="cls-2" cx="149.05" cy="14.89" r=".49"/>
|
||||
<circle class="cls-2" cx="147.35" cy="14.89" r=".49"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 9.5 KiB |
BIN
docs/images/pku.png
Normal file
BIN
docs/images/pku.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 50 KiB |
1
docs/images/ucloud.svg
Normal file
1
docs/images/ucloud.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 7.9 KiB |
@@ -1,7 +1,9 @@
|
||||
package dto
|
||||
|
||||
type ChannelSettings struct {
|
||||
ForceFormat bool `json:"force_format,omitempty"`
|
||||
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
|
||||
Proxy string `json:"proxy"`
|
||||
ForceFormat bool `json:"force_format,omitempty"`
|
||||
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
|
||||
Proxy string `json:"proxy"`
|
||||
PassThroughBodyEnabled bool `json:"pass_through_body_enabled,omitempty"`
|
||||
SystemPrompt string `json:"system_prompt,omitempty"`
|
||||
}
|
||||
|
||||
138
dto/claude.go
138
dto/claude.go
@@ -2,7 +2,9 @@ package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
@@ -158,6 +160,27 @@ type InputSchema struct {
|
||||
Required any `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeWebSearchTool struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
MaxUses int `json:"max_uses,omitempty"`
|
||||
UserLocation *ClaudeWebSearchUserLocation `json:"user_location,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeWebSearchUserLocation struct {
|
||||
Type string `json:"type"`
|
||||
Timezone string `json:"timezone,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeToolChoice struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
@@ -176,6 +199,59 @@ type ClaudeRequest struct {
|
||||
Thinking *Thinking `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
// AddTool 添加工具到请求中
|
||||
func (c *ClaudeRequest) AddTool(tool any) {
|
||||
if c.Tools == nil {
|
||||
c.Tools = make([]any, 0)
|
||||
}
|
||||
|
||||
switch tools := c.Tools.(type) {
|
||||
case []any:
|
||||
c.Tools = append(tools, tool)
|
||||
default:
|
||||
// 如果Tools不是[]any类型,重新初始化为[]any
|
||||
c.Tools = []any{tool}
|
||||
}
|
||||
}
|
||||
|
||||
// GetTools 获取工具列表
|
||||
func (c *ClaudeRequest) GetTools() []any {
|
||||
if c.Tools == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch tools := c.Tools.(type) {
|
||||
case []any:
|
||||
return tools
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessTools 处理工具列表,支持类型断言
|
||||
func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) {
|
||||
var normalTools []*Tool
|
||||
var webSearchTools []*ClaudeWebSearchTool
|
||||
|
||||
for _, tool := range tools {
|
||||
switch t := tool.(type) {
|
||||
case *Tool:
|
||||
normalTools = append(normalTools, t)
|
||||
case *ClaudeWebSearchTool:
|
||||
webSearchTools = append(webSearchTools, t)
|
||||
case Tool:
|
||||
normalTools = append(normalTools, &t)
|
||||
case ClaudeWebSearchTool:
|
||||
webSearchTools = append(webSearchTools, &t)
|
||||
default:
|
||||
// 未知类型,跳过
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return normalTools, webSearchTools
|
||||
}
|
||||
|
||||
type Thinking struct {
|
||||
Type string `json:"type"`
|
||||
BudgetTokens *int `json:"budget_tokens,omitempty"`
|
||||
@@ -209,14 +285,9 @@ func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
|
||||
return mediaContent
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeErrorWithStatusCode struct {
|
||||
Error ClaudeError `json:"error"`
|
||||
StatusCode int `json:"status_code"`
|
||||
Error types.ClaudeError `json:"error"`
|
||||
StatusCode int `json:"status_code"`
|
||||
LocalError bool
|
||||
}
|
||||
|
||||
@@ -228,7 +299,7 @@ type ClaudeResponse struct {
|
||||
Completion string `json:"completion,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Error *ClaudeError `json:"error,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
||||
@@ -249,9 +320,50 @@ func (c *ClaudeResponse) GetIndex() int {
|
||||
return *c.Index
|
||||
}
|
||||
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
// GetClaudeError 从动态错误类型中提取ClaudeError结构
|
||||
func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
|
||||
if c.Error == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch err := c.Error.(type) {
|
||||
case types.ClaudeError:
|
||||
return &err
|
||||
case *types.ClaudeError:
|
||||
return err
|
||||
case map[string]interface{}:
|
||||
// 处理从JSON解析来的map结构
|
||||
claudeErr := &types.ClaudeError{}
|
||||
if errType, ok := err["type"].(string); ok {
|
||||
claudeErr.Type = errType
|
||||
}
|
||||
if errMsg, ok := err["message"].(string); ok {
|
||||
claudeErr.Message = errMsg
|
||||
}
|
||||
return claudeErr
|
||||
case string:
|
||||
// 处理简单字符串错误
|
||||
return &types.ClaudeError{
|
||||
Type: "error",
|
||||
Message: err,
|
||||
}
|
||||
default:
|
||||
// 未知类型,尝试转换为字符串
|
||||
return &types.ClaudeError{
|
||||
Type: "unknown_error",
|
||||
Message: fmt.Sprintf("%v", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use"`
|
||||
}
|
||||
|
||||
type ClaudeServerToolUse struct {
|
||||
WebSearchRequests int `json:"web_search_requests"`
|
||||
}
|
||||
|
||||
12
dto/error.go
12
dto/error.go
@@ -1,5 +1,7 @@
|
||||
package dto
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
type OpenAIError struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
@@ -14,11 +16,11 @@ type OpenAIErrorWithStatusCode struct {
|
||||
}
|
||||
|
||||
type GeneralErrorResponse struct {
|
||||
Error OpenAIError `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Error types.OpenAIError `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Header struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"header"`
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package gemini
|
||||
package dto
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
@@ -32,7 +35,7 @@ func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
|
||||
MimeTypeSnake string `json:"mime_type"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -53,7 +56,7 @@ type FunctionCall struct {
|
||||
Arguments any `json:"args"`
|
||||
}
|
||||
|
||||
type FunctionResponse struct {
|
||||
type GeminiFunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response map[string]interface{} `json:"response"`
|
||||
}
|
||||
@@ -78,7 +81,7 @@ type GeminiPart struct {
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
FileData *GeminiFileData `json:"fileData,omitempty"`
|
||||
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
||||
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
||||
@@ -93,7 +96,7 @@ func (p *GeminiPart) UnmarshalJSON(data []byte) error {
|
||||
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -7,15 +7,15 @@ import (
|
||||
)
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
JsonSchema json.RawMessage `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
type FormatJsonSchema struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Schema any `json:"schema,omitempty"`
|
||||
Strict any `json:"strict,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Schema any `json:"schema,omitempty"`
|
||||
Strict json.RawMessage `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
@@ -55,21 +55,33 @@ type GeneralOpenAIRequest struct {
|
||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
||||
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
|
||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||
SearchParameters any `json:"search_parameters,omitempty"` //xai
|
||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||
// OpenRouter Params
|
||||
Usage json.RawMessage `json:"usage,omitempty"`
|
||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||
// Ali Qwen Params
|
||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||
// 用匿名参数接收额外参数,例如ollama的think参数在此接收
|
||||
Extra map[string]json.RawMessage `json:"-"`
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||
result := make(map[string]any)
|
||||
data, _ := common.EncodeJson(r)
|
||||
_ = common.UnmarshalJson(data, &result)
|
||||
data, _ := common.Marshal(r)
|
||||
_ = common.Unmarshal(data, &result)
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
|
||||
if strings.HasPrefix(r.Model, "o") {
|
||||
if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
|
||||
return "developer"
|
||||
}
|
||||
}
|
||||
return "system"
|
||||
}
|
||||
|
||||
type ToolCallRequest struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
@@ -602,26 +614,29 @@ type WebSearchOptions struct {
|
||||
UserLocation json.RawMessage `json:"user_location,omitempty"`
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/responses/create
|
||||
type OpenAIResponsesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
Include json.RawMessage `json:"include,omitempty"`
|
||||
Instructions json.RawMessage `json:"instructions,omitempty"`
|
||||
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Store bool `json:"store,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Tools []ResponsesToolsCall `json:"tools,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Truncation string `json:"truncation,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
Include json.RawMessage `json:"include,omitempty"`
|
||||
Instructions json.RawMessage `json:"instructions,omitempty"`
|
||||
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Store bool `json:"store,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Tools []map[string]any `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Truncation string `json:"truncation,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
|
||||
Prompt json.RawMessage `json:"prompt,omitempty"`
|
||||
}
|
||||
|
||||
type Reasoning struct {
|
||||
@@ -629,23 +644,23 @@ type Reasoning struct {
|
||||
Summary string `json:"summary,omitempty"`
|
||||
}
|
||||
|
||||
type ResponsesToolsCall struct {
|
||||
Type string `json:"type"`
|
||||
// Web Search
|
||||
UserLocation json.RawMessage `json:"user_location,omitempty"`
|
||||
SearchContextSize string `json:"search_context_size,omitempty"`
|
||||
// File Search
|
||||
VectorStoreIds []string `json:"vector_store_ids,omitempty"`
|
||||
MaxNumResults uint `json:"max_num_results,omitempty"`
|
||||
Filters json.RawMessage `json:"filters,omitempty"`
|
||||
// Computer Use
|
||||
DisplayWidth uint `json:"display_width,omitempty"`
|
||||
DisplayHeight uint `json:"display_height,omitempty"`
|
||||
Environment string `json:"environment,omitempty"`
|
||||
// Function
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
Function json.RawMessage `json:"function,omitempty"`
|
||||
Container json.RawMessage `json:"container,omitempty"`
|
||||
}
|
||||
//type ResponsesToolsCall struct {
|
||||
// Type string `json:"type"`
|
||||
// // Web Search
|
||||
// UserLocation json.RawMessage `json:"user_location,omitempty"`
|
||||
// SearchContextSize string `json:"search_context_size,omitempty"`
|
||||
// // File Search
|
||||
// VectorStoreIds []string `json:"vector_store_ids,omitempty"`
|
||||
// MaxNumResults uint `json:"max_num_results,omitempty"`
|
||||
// Filters json.RawMessage `json:"filters,omitempty"`
|
||||
// // Computer Use
|
||||
// DisplayWidth uint `json:"display_width,omitempty"`
|
||||
// DisplayHeight uint `json:"display_height,omitempty"`
|
||||
// Environment string `json:"environment,omitempty"`
|
||||
// // Function
|
||||
// Name string `json:"name,omitempty"`
|
||||
// Description string `json:"description,omitempty"`
|
||||
// Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
// Function json.RawMessage `json:"function,omitempty"`
|
||||
// Container json.RawMessage `json:"container,omitempty"`
|
||||
//}
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
package dto
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type SimpleResponse struct {
|
||||
Usage `json:"usage"`
|
||||
Error *OpenAIError `json:"error"`
|
||||
Error any `json:"error"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func (s *SimpleResponse) GetOpenAIError() *types.OpenAIError {
|
||||
return GetOpenAIError(s.Error)
|
||||
}
|
||||
|
||||
type TextResponse struct {
|
||||
@@ -28,10 +37,15 @@ type OpenAITextResponse struct {
|
||||
Object string `json:"object"`
|
||||
Created any `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func (o *OpenAITextResponse) GetOpenAIError() *types.OpenAIError {
|
||||
return GetOpenAIError(o.Error)
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponseItem struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
@@ -45,6 +59,19 @@ type OpenAIEmbeddingResponse struct {
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type FlexibleEmbeddingResponseItem struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
Embedding any `json:"embedding"`
|
||||
}
|
||||
|
||||
type FlexibleEmbeddingResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []FlexibleEmbeddingResponseItem `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseChoice struct {
|
||||
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
|
||||
Logprobs *any `json:"logprobs"`
|
||||
@@ -179,7 +206,7 @@ type Usage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
|
||||
// OpenRouter Params
|
||||
Cost float64 `json:"cost,omitempty"`
|
||||
Cost any `json:"cost,omitempty"`
|
||||
}
|
||||
|
||||
type InputTokenDetails struct {
|
||||
@@ -197,28 +224,33 @@ type OutputTokenDetails struct {
|
||||
}
|
||||
|
||||
type OpenAIResponsesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
||||
Instructions string `json:"instructions"`
|
||||
MaxOutputTokens int `json:"max_output_tokens"`
|
||||
Model string `json:"model"`
|
||||
Output []ResponsesOutput `json:"output"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
||||
PreviousResponseID string `json:"previous_response_id"`
|
||||
Reasoning *Reasoning `json:"reasoning"`
|
||||
Store bool `json:"store"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
ToolChoice string `json:"tool_choice"`
|
||||
Tools []ResponsesToolsCall `json:"tools"`
|
||||
TopP float64 `json:"top_p"`
|
||||
Truncation string `json:"truncation"`
|
||||
Usage *Usage `json:"usage"`
|
||||
User json.RawMessage `json:"user"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Error any `json:"error,omitempty"`
|
||||
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
||||
Instructions string `json:"instructions"`
|
||||
MaxOutputTokens int `json:"max_output_tokens"`
|
||||
Model string `json:"model"`
|
||||
Output []ResponsesOutput `json:"output"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
||||
PreviousResponseID string `json:"previous_response_id"`
|
||||
Reasoning *Reasoning `json:"reasoning"`
|
||||
Store bool `json:"store"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
ToolChoice string `json:"tool_choice"`
|
||||
Tools []map[string]any `json:"tools"`
|
||||
TopP float64 `json:"top_p"`
|
||||
Truncation string `json:"truncation"`
|
||||
Usage *Usage `json:"usage"`
|
||||
User json.RawMessage `json:"user"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
|
||||
return GetOpenAIError(o.Error)
|
||||
}
|
||||
|
||||
type IncompleteDetails struct {
|
||||
@@ -260,3 +292,45 @@ type ResponsesStreamResponse struct {
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Item *ResponsesOutput `json:"item,omitempty"`
|
||||
}
|
||||
|
||||
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||
func GetOpenAIError(errorField any) *types.OpenAIError {
|
||||
if errorField == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch err := errorField.(type) {
|
||||
case types.OpenAIError:
|
||||
return &err
|
||||
case *types.OpenAIError:
|
||||
return err
|
||||
case map[string]interface{}:
|
||||
// 处理从JSON解析来的map结构
|
||||
openaiErr := &types.OpenAIError{}
|
||||
if errType, ok := err["type"].(string); ok {
|
||||
openaiErr.Type = errType
|
||||
}
|
||||
if errMsg, ok := err["message"].(string); ok {
|
||||
openaiErr.Message = errMsg
|
||||
}
|
||||
if errParam, ok := err["param"].(string); ok {
|
||||
openaiErr.Param = errParam
|
||||
}
|
||||
if errCode, ok := err["code"]; ok {
|
||||
openaiErr.Code = errCode
|
||||
}
|
||||
return openaiErr
|
||||
case string:
|
||||
// 处理简单字符串错误
|
||||
return &types.OpenAIError{
|
||||
Type: "error",
|
||||
Message: err,
|
||||
}
|
||||
default:
|
||||
// 未知类型,尝试转换为字符串
|
||||
return &types.OpenAIError{
|
||||
Type: "unknown_error",
|
||||
Message: fmt.Sprintf("%v", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package dto
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
const (
|
||||
RealtimeEventTypeError = "error"
|
||||
RealtimeEventTypeSessionUpdate = "session.update"
|
||||
@@ -23,12 +25,12 @@ type RealtimeEvent struct {
|
||||
EventId string `json:"event_id"`
|
||||
Type string `json:"type"`
|
||||
//PreviousItemId string `json:"previous_item_id"`
|
||||
Session *RealtimeSession `json:"session,omitempty"`
|
||||
Item *RealtimeItem `json:"item,omitempty"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Response *RealtimeResponse `json:"response,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
Session *RealtimeSession `json:"session,omitempty"`
|
||||
Item *RealtimeItem `json:"item,omitempty"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
Response *RealtimeResponse `json:"response,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
}
|
||||
|
||||
type RealtimeResponse struct {
|
||||
|
||||
7
go.mod
7
go.mod
@@ -11,7 +11,6 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
||||
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994
|
||||
github.com/gin-contrib/cors v1.7.2
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
@@ -28,6 +27,8 @@ require (
|
||||
github.com/samber/lo v1.39.0
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/shopspring/decimal v1.4.0
|
||||
github.com/stripe/stripe-go/v81 v81.4.0
|
||||
github.com/thanhpk/randstr v1.0.6
|
||||
github.com/tiktoken-go/tokenizer v0.6.2
|
||||
golang.org/x/crypto v0.35.0
|
||||
golang.org/x/image v0.23.0
|
||||
@@ -44,6 +45,7 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
|
||||
github.com/aws/smithy-go v1.20.2 // indirect
|
||||
github.com/boombuler/barcode v1.1.0 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
@@ -58,11 +60,9 @@ require (
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
|
||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect
|
||||
github.com/gorilla/context v1.1.1 // indirect
|
||||
github.com/gorilla/securecookie v1.1.1 // indirect
|
||||
github.com/gorilla/sessions v1.2.1 // indirect
|
||||
@@ -80,6 +80,7 @@ require (
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
|
||||
github.com/pquerna/otp v1.5.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
|
||||
22
go.sum
22
go.sum
@@ -1,7 +1,5 @@
|
||||
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
|
||||
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
|
||||
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
|
||||
github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
|
||||
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
||||
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
|
||||
@@ -22,6 +20,10 @@ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76w
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
|
||||
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
|
||||
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo=
|
||||
github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
@@ -42,8 +44,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
|
||||
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994 h1:aQYWswi+hRL2zJqGacdCZx32XjKYV8ApXFGntw79XAM=
|
||||
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
@@ -87,8 +87,6 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
|
||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU=
|
||||
github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg=
|
||||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
@@ -103,8 +101,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U=
|
||||
github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
||||
@@ -175,6 +173,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
|
||||
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
@@ -201,6 +201,10 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
|
||||
github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
|
||||
github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
|
||||
github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
|
||||
github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
|
||||
github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
@@ -230,6 +234,7 @@ golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSO
|
||||
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
|
||||
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -238,6 +243,7 @@ golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -70,7 +70,7 @@
|
||||
"关于": "关于",
|
||||
"注销成功!": "注销成功!",
|
||||
"个人设置": "个人设置",
|
||||
"API令牌": "API令牌",
|
||||
"令牌管理": "令牌管理",
|
||||
"退出": "退出",
|
||||
"关闭侧边栏": "关闭侧边栏",
|
||||
"打开侧边栏": "打开侧边栏",
|
||||
@@ -585,6 +585,19 @@
|
||||
"渠道权重": "渠道权重",
|
||||
"渠道额外设置": "渠道额外设置",
|
||||
"此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:": "此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:",
|
||||
"强制格式化": "强制格式化",
|
||||
"强制将响应格式化为 OpenAI 标准格式(只适用于OpenAI渠道类型)": "强制将响应格式化为 OpenAI 标准格式(只适用于OpenAI渠道类型)",
|
||||
"思考内容转换": "思考内容转换",
|
||||
"将 reasoning_content 转换为 <think> 标签拼接到内容中": "将 reasoning_content 转换为 <think> 标签拼接到内容中",
|
||||
"透传请求体": "透传请求体",
|
||||
"启用请求体透传功能": "启用请求体透传功能",
|
||||
"代理地址": "代理地址",
|
||||
"例如: socks5://user:pass@host:port": "例如: socks5://user:pass@host:port",
|
||||
"用于配置网络代理": "用于配置网络代理",
|
||||
"用于配置网络代理,支持 socks5 协议": "用于配置网络代理,支持 socks5 协议",
|
||||
"系统提示词": "系统提示词",
|
||||
"输入系统提示词,用户的系统提示词将优先于此设置": "输入系统提示词,用户的系统提示词将优先于此设置",
|
||||
"用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置": "用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置",
|
||||
"参数覆盖": "参数覆盖",
|
||||
"此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:": "此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:",
|
||||
"请输入组织org-xxx": "请输入组织org-xxx",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
@@ -121,7 +122,20 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
c.Set("role", role)
|
||||
c.Set("id", id)
|
||||
c.Set("group", session.Get("group"))
|
||||
c.Set("user_group", session.Get("group"))
|
||||
c.Set("use_access_token", useAccessToken)
|
||||
|
||||
//userCache, err := model.GetUserCache(id.(int))
|
||||
//if err != nil {
|
||||
// c.JSON(http.StatusOK, gin.H{
|
||||
// "success": false,
|
||||
// "message": err.Error(),
|
||||
// })
|
||||
// c.Abort()
|
||||
// return
|
||||
//}
|
||||
//userCache.WriteContext(c)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
|
||||
@@ -233,30 +247,41 @@ func TokenAuth() func(c *gin.Context) {
|
||||
|
||||
userCache.WriteContext(c)
|
||||
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_key", token.Key)
|
||||
c.Set("token_name", token.Name)
|
||||
c.Set("token_unlimited_quota", token.UnlimitedQuota)
|
||||
if !token.UnlimitedQuota {
|
||||
c.Set("token_quota", token.RemainQuota)
|
||||
}
|
||||
if token.ModelLimitsEnabled {
|
||||
c.Set("token_model_limit_enabled", true)
|
||||
c.Set("token_model_limit", token.GetModelLimitsMap())
|
||||
} else {
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
c.Set("allow_ips", token.GetIpLimitsMap())
|
||||
c.Set("token_group", token.Group)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
} else {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return
|
||||
}
|
||||
err = SetupContextForToken(c, token, parts...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
|
||||
if token == nil {
|
||||
return fmt.Errorf("token is nil")
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_key", token.Key)
|
||||
c.Set("token_name", token.Name)
|
||||
c.Set("token_unlimited_quota", token.UnlimitedQuota)
|
||||
if !token.UnlimitedQuota {
|
||||
c.Set("token_quota", token.RemainQuota)
|
||||
}
|
||||
if token.ModelLimitsEnabled {
|
||||
c.Set("token_model_limit_enabled", true)
|
||||
c.Set("token_model_limit", token.GetModelLimitsMap())
|
||||
} else {
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
c.Set("allow_ips", token.GetIpLimitsMap())
|
||||
c.Set("token_group", token.Group)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
} else {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return fmt.Errorf("普通用户不支持指定渠道")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
|
||||
type ModelRequest struct {
|
||||
Model string `json:"model"`
|
||||
Group string `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
@@ -98,6 +100,10 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
|
||||
if shouldSelectChannel {
|
||||
if modelRequest.Model == "" {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "未指定模型名称,模型名称不能为空")
|
||||
return
|
||||
}
|
||||
var selectGroup string
|
||||
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
||||
if err != nil {
|
||||
@@ -105,18 +111,17 @@ func Distribute() func(c *gin.Context) {
|
||||
if userGroup == "auto" {
|
||||
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
||||
}
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
|
||||
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(数据库一致性已被破坏,distributor): %s", showGroup, modelRequest.Model, err.Error())
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
}
|
||||
// 如果错误,而且渠道为空,说明是没有可用渠道
|
||||
//if channel != nil {
|
||||
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
// message = "数据库一致性已被破坏,请联系管理员"
|
||||
//}
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
|
||||
return
|
||||
}
|
||||
if channel == nil {
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -172,22 +177,13 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
var platform string
|
||||
var relayMode int
|
||||
if strings.HasPrefix(modelRequest.Model, "jimeng") {
|
||||
platform = string(constant.TaskPlatformJimeng)
|
||||
relayMode = relayconstant.Path2RelayJimeng(c.Request.Method, c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeJimengFetchByID {
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
} else {
|
||||
platform = string(constant.TaskPlatformKling)
|
||||
relayMode = relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeKlingFetchByID {
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
relayMode := relayconstant.RelayModeUnknown
|
||||
if c.Request.Method == http.MethodPost {
|
||||
relayMode = relayconstant.RelayModeVideoSubmit
|
||||
} else if c.Request.Method == http.MethodGet {
|
||||
relayMode = relayconstant.RelayModeVideoFetchByID
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
c.Set("platform", platform)
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
||||
@@ -237,28 +233,50 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
|
||||
// playground chat completions
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
if err != nil {
|
||||
return nil, false, errors.New("无效的请求, " + err.Error())
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
|
||||
}
|
||||
return &modelRequest, shouldSelectChannel, nil
|
||||
}
|
||||
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
|
||||
c.Set("original_model", modelName) // for retry
|
||||
if channel == nil {
|
||||
return
|
||||
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
||||
c.Set("channel_create_time", channel.CreatedTime)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
||||
c.Set("param_override", channel.GetParamOverride())
|
||||
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
|
||||
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
|
||||
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
|
||||
}
|
||||
c.Set("auto_ban", channel.GetAutoBan())
|
||||
c.Set("model_mapping", channel.GetModelMapping())
|
||||
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
common.SetContextKey(c, constant.ContextKeyBaseUrl, channel.GetBaseURL())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
|
||||
|
||||
key, index, newAPIError := channel.GetNextEnabledKey()
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
|
||||
} else {
|
||||
// 必须设置为 false,否则在重试到单个 key 的时候会导致日志显示错误
|
||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, false)
|
||||
}
|
||||
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||
|
||||
// TODO: api_version统一
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeAzure:
|
||||
@@ -278,6 +296,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
case constant.ChannelTypeCoze:
|
||||
c.Set("bot_id", channel.Other)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
package jsrt
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Runtime 配置
|
||||
type JSRuntimeConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
MaxVMCount int `json:"max_vm_count"`
|
||||
ScriptTimeout time.Duration `json:"script_timeout"`
|
||||
ScriptDir string `json:"script_dir"`
|
||||
FetchTimeout time.Duration `json:"fetch_timeout"`
|
||||
}
|
||||
|
||||
var (
|
||||
jsConfig = JSRuntimeConfig{}
|
||||
)
|
||||
|
||||
const (
|
||||
defaultScriptDir = "scripts/"
|
||||
defaultScriptTimeout = 5 * time.Second
|
||||
defaultFetchTimeout = 10 * time.Second
|
||||
defaultMaxVMCount = 8
|
||||
)
|
||||
|
||||
func loadCfg() {
|
||||
if enabled := os.Getenv("JS_RUNTIME_ENABLED"); enabled != "" {
|
||||
jsConfig.Enabled = enabled == "true"
|
||||
}
|
||||
|
||||
if maxCount := os.Getenv("JS_MAX_VM_COUNT"); maxCount != "" {
|
||||
if count, err := strconv.Atoi(maxCount); err == nil && count > 0 {
|
||||
jsConfig.MaxVMCount = count
|
||||
}
|
||||
} else {
|
||||
jsConfig.MaxVMCount = defaultMaxVMCount
|
||||
}
|
||||
|
||||
if timeout := os.Getenv("JS_SCRIPT_TIMEOUT"); timeout != "" {
|
||||
if t, err := time.ParseDuration(timeout + "s"); err == nil && t > 0 {
|
||||
jsConfig.ScriptTimeout = t
|
||||
}
|
||||
} else {
|
||||
jsConfig.ScriptTimeout = defaultScriptTimeout
|
||||
}
|
||||
|
||||
if fetchTimeout := os.Getenv("JS_FETCH_TIMEOUT"); fetchTimeout != "" {
|
||||
if t, err := time.ParseDuration(fetchTimeout + "s"); err == nil && t > 0 {
|
||||
jsConfig.FetchTimeout = t
|
||||
}
|
||||
} else {
|
||||
jsConfig.FetchTimeout = defaultFetchTimeout
|
||||
}
|
||||
|
||||
jsConfig.ScriptDir = os.Getenv("JS_SCRIPT_DIR")
|
||||
if jsConfig.ScriptDir == "" {
|
||||
jsConfig.ScriptDir = defaultScriptDir
|
||||
}
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
package jsrt
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func dbQuery(db *gorm.DB, sql string, args ...any) []map[string]any {
|
||||
if db == nil {
|
||||
common.SysError("JS DB is nil")
|
||||
return nil
|
||||
}
|
||||
|
||||
rows, err := db.Raw(sql, args...).Rows()
|
||||
if err != nil {
|
||||
common.SysError("JS DB Query Error: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
common.SysError("JS DB Columns Error: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
results := make([]map[string]any, 0, 100)
|
||||
for rows.Next() {
|
||||
values := make([]any, len(columns))
|
||||
valuePtrs := make([]any, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
common.SysError("JS DB Scan Error: " + err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
row := make(map[string]any, len(columns))
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
if b, ok := val.([]byte); ok {
|
||||
row[col] = string(b)
|
||||
} else {
|
||||
row[col] = val
|
||||
}
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
func dbExec(db *gorm.DB, sql string, args ...any) map[string]any {
|
||||
if db == nil {
|
||||
return map[string]any{
|
||||
"rowsAffected": int64(0),
|
||||
"error": "database is nil",
|
||||
}
|
||||
}
|
||||
|
||||
result := db.Exec(sql, args...)
|
||||
return map[string]any{
|
||||
"rowsAffected": result.RowsAffected,
|
||||
"error": result.Error,
|
||||
}
|
||||
}
|
||||
@@ -1,137 +0,0 @@
|
||||
package jsrt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type JSFetchRequest struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body any `json:"body"`
|
||||
Timeout int `json:"timeout"`
|
||||
}
|
||||
|
||||
type JSFetchResponse struct {
|
||||
Status int `json:"status"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body string `json:"body"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) fetch(url string, options ...any) *JSFetchResponse {
|
||||
req := &JSFetchRequest{
|
||||
Method: "GET",
|
||||
URL: url,
|
||||
Headers: make(map[string]string),
|
||||
Timeout: int(jsConfig.FetchTimeout.Seconds()),
|
||||
}
|
||||
|
||||
// 解析选项
|
||||
if len(options) > 0 && options[0] != nil {
|
||||
if optMap, ok := options[0].(map[string]any); ok {
|
||||
if method, exists := optMap["method"]; exists {
|
||||
if methodStr, ok := method.(string); ok {
|
||||
req.Method = strings.ToUpper(methodStr)
|
||||
}
|
||||
}
|
||||
|
||||
if headers, exists := optMap["headers"]; exists {
|
||||
if headersMap, ok := headers.(map[string]any); ok {
|
||||
for k, v := range headersMap {
|
||||
if vStr, ok := v.(string); ok {
|
||||
req.Headers[k] = vStr
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if body, exists := optMap["body"]; exists {
|
||||
req.Body = body
|
||||
}
|
||||
|
||||
if timeout, exists := optMap["timeout"]; exists {
|
||||
if timeoutNum, ok := timeout.(float64); ok {
|
||||
req.Timeout = int(timeoutNum)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建HTTP请求
|
||||
var bodyReader io.Reader
|
||||
switch body := req.Body.(type) {
|
||||
case string:
|
||||
bodyReader = strings.NewReader(body)
|
||||
case []byte:
|
||||
bodyReader = bytes.NewReader(body)
|
||||
case nil:
|
||||
bodyReader = nil
|
||||
default:
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return &JSFetchResponse{
|
||||
Error: fmt.Sprintf("Failed to marshal body: %v", err),
|
||||
}
|
||||
}
|
||||
bodyReader = bytes.NewReader(bodyBytes)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest(req.Method, req.URL, bodyReader)
|
||||
if err != nil {
|
||||
return &JSFetchResponse{
|
||||
Error: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
for k, v := range req.Headers {
|
||||
httpReq.Header.Set(k, v)
|
||||
}
|
||||
|
||||
// 设置默认User-Agent
|
||||
if httpReq.Header.Get("User-Agent") == "" {
|
||||
httpReq.Header.Set("User-Agent", "JS-Runtime-Fetch/1.0")
|
||||
}
|
||||
|
||||
// 创建带超时的上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(req.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
httpReq = httpReq.WithContext(ctx)
|
||||
|
||||
// 执行请求
|
||||
resp, err := p.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return &JSFetchResponse{}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应体
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return &JSFetchResponse{
|
||||
Status: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
// 构建响应头
|
||||
headers := make(map[string]string)
|
||||
for k, v := range resp.Header {
|
||||
if len(v) > 0 {
|
||||
headers[k] = v[0]
|
||||
}
|
||||
}
|
||||
|
||||
return &JSFetchResponse{
|
||||
Status: resp.StatusCode,
|
||||
Headers: headers,
|
||||
Body: string(bodyBytes),
|
||||
}
|
||||
}
|
||||
@@ -1,570 +0,0 @@
|
||||
package jsrt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 池化
|
||||
type JSRuntimePool struct {
|
||||
pool chan *goja.Runtime
|
||||
maxSize int
|
||||
createFunc func() *goja.Runtime
|
||||
scripts string
|
||||
mu sync.RWMutex
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
var (
|
||||
jsRuntimePool *JSRuntimePool
|
||||
jsPoolOnce sync.Once
|
||||
)
|
||||
|
||||
func NewJSRuntimePool(maxSize int) *JSRuntimePool {
|
||||
// 创建HTTP客户端
|
||||
httpClient := &http.Client{
|
||||
Timeout: jsConfig.FetchTimeout,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: false,
|
||||
},
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
pool := &JSRuntimePool{
|
||||
pool: make(chan *goja.Runtime, maxSize),
|
||||
maxSize: maxSize,
|
||||
scripts: "",
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
pool.createFunc = func() *goja.Runtime {
|
||||
vm := goja.New()
|
||||
pool.setupGlobals(vm)
|
||||
pool.loadScripts(vm)
|
||||
return vm
|
||||
}
|
||||
|
||||
// 预创建VM
|
||||
preCreate := min(maxSize/2, 4)
|
||||
for range preCreate {
|
||||
select {
|
||||
case pool.pool <- pool.createFunc():
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) Get() *goja.Runtime {
|
||||
select {
|
||||
case vm := <-p.pool:
|
||||
return vm
|
||||
default:
|
||||
return p.createFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) Put(vm *goja.Runtime) {
|
||||
if vm == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case p.pool <- vm:
|
||||
default:
|
||||
// 池满,丢弃VM让GC回收
|
||||
}
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) setupGlobals(vm *goja.Runtime) {
|
||||
// console
|
||||
console := vm.NewObject()
|
||||
console.Set("log", func(args ...any) {
|
||||
var strs []string
|
||||
for _, arg := range args {
|
||||
strs = append(strs, fmt.Sprintf("%v", arg))
|
||||
}
|
||||
common.SysLog("JS: " + strings.Join(strs, " "))
|
||||
})
|
||||
console.Set("error", func(args ...any) {
|
||||
var strs []string
|
||||
for _, arg := range args {
|
||||
strs = append(strs, fmt.Sprintf("%v", arg))
|
||||
}
|
||||
common.SysError("JS: " + strings.Join(strs, " "))
|
||||
})
|
||||
console.Set("warn", func(args ...any) {
|
||||
var strs []string
|
||||
for _, arg := range args {
|
||||
strs = append(strs, fmt.Sprintf("%v", arg))
|
||||
}
|
||||
common.SysError("JS WARN: " + strings.Join(strs, " "))
|
||||
})
|
||||
vm.Set("console", console)
|
||||
|
||||
// JSON
|
||||
jsonObj := vm.NewObject()
|
||||
jsonObj.Set("parse", func(str string) any {
|
||||
var result any
|
||||
err := json.Unmarshal([]byte(str), &result)
|
||||
if err != nil {
|
||||
panic(vm.ToValue(err.Error()))
|
||||
}
|
||||
return result
|
||||
})
|
||||
jsonObj.Set("stringify", func(obj any) string {
|
||||
data, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
panic(vm.ToValue(err.Error()))
|
||||
}
|
||||
return string(data)
|
||||
})
|
||||
vm.Set("JSON", jsonObj)
|
||||
|
||||
// fetch 实现
|
||||
vm.Set("fetch", func(url string, options ...any) *JSFetchResponse {
|
||||
return p.fetch(url, options...)
|
||||
})
|
||||
|
||||
// 数据库
|
||||
setDB(vm, model.DB, "db")
|
||||
setDB(vm, model.LOG_DB, "logdb")
|
||||
|
||||
// 定时器
|
||||
vm.Set("setTimeout", func(fn func(), delay int) {
|
||||
go func() {
|
||||
time.Sleep(time.Duration(delay) * time.Millisecond)
|
||||
fn()
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) loadScripts(vm *goja.Runtime) {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
// 如果已经缓存了合并的脚本,直接使用
|
||||
if p.scripts != "" {
|
||||
if _, err := vm.RunString(p.scripts); err != nil {
|
||||
common.SysError("Failed to load cached scripts: " + err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 首次加载时,读取 scripts/ 文件夹中的所有脚本
|
||||
p.mu.RUnlock()
|
||||
p.mu.Lock()
|
||||
defer func() {
|
||||
p.mu.Unlock()
|
||||
p.mu.RLock()
|
||||
}()
|
||||
|
||||
if p.scripts != "" {
|
||||
if _, err := vm.RunString(p.scripts); err != nil {
|
||||
common.SysError("Failed to load cached scripts: " + err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 读取所有脚本文件
|
||||
var combinedScript strings.Builder
|
||||
scriptDir := jsConfig.ScriptDir
|
||||
|
||||
// 检查目录是否存在
|
||||
if _, err := os.Stat(scriptDir); os.IsNotExist(err) {
|
||||
common.SysLog("Scripts directory does not exist: " + scriptDir)
|
||||
return
|
||||
}
|
||||
|
||||
// 读取目录中的所有 .js 文件
|
||||
files, err := filepath.Glob(filepath.Join(scriptDir, "*.js"))
|
||||
if err != nil {
|
||||
common.SysError("Failed to read scripts directory: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
common.SysLog("No JavaScript files found in: " + scriptDir)
|
||||
return
|
||||
}
|
||||
|
||||
// 按文件名排序以确保加载顺序一致
|
||||
for _, file := range files {
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
common.SysError("Failed to read script file " + file + ": " + err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
// 添加文件注释和内容
|
||||
combinedScript.WriteString("// File: " + filepath.Base(file) + "\n")
|
||||
combinedScript.WriteString(string(content))
|
||||
combinedScript.WriteString("\n\n")
|
||||
|
||||
common.SysLog("Loaded script: " + filepath.Base(file))
|
||||
}
|
||||
|
||||
// 缓存合并后的脚本
|
||||
p.scripts = combinedScript.String()
|
||||
|
||||
// 执行脚本
|
||||
if p.scripts != "" {
|
||||
if _, err := vm.RunString(p.scripts); err != nil {
|
||||
common.SysError("Failed to load combined scripts: " + err.Error())
|
||||
} else {
|
||||
common.SysLog("Successfully loaded and combined all JavaScript files from: " + scriptDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) ReloadScripts() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// 清空缓存的脚本
|
||||
p.scripts = ""
|
||||
|
||||
// 清空VM池,强制重新创建
|
||||
for {
|
||||
select {
|
||||
case <-p.pool:
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
common.SysLog("JavaScript scripts reloaded")
|
||||
}
|
||||
|
||||
func initJSRuntimePool() *JSRuntimePool {
|
||||
jsPoolOnce.Do(func() {
|
||||
jsRuntimePool = NewJSRuntimePool(jsConfig.MaxVMCount)
|
||||
common.SysLog("JavaScript runtime pool initialized successfully")
|
||||
})
|
||||
return jsRuntimePool
|
||||
}
|
||||
|
||||
func validateGinContext(c *gin.Context) error {
|
||||
if c == nil {
|
||||
return fmt.Errorf("gin context is nil")
|
||||
}
|
||||
if c.Request == nil {
|
||||
return fmt.Errorf("gin context request is nil")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) executeWithTimeout(_ *goja.Runtime, fn func() (goja.Value, error)) (goja.Value, error) {
|
||||
type result struct {
|
||||
value goja.Value
|
||||
err error
|
||||
}
|
||||
|
||||
resultChan := make(chan result, 1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
resultChan <- result{err: fmt.Errorf("JS panic: %v", r)}
|
||||
}
|
||||
}()
|
||||
|
||||
value, err := fn()
|
||||
resultChan <- result{value: value, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-resultChan:
|
||||
return res.value, res.err
|
||||
case <-time.After(jsConfig.ScriptTimeout):
|
||||
return nil, fmt.Errorf("script execution timeout after %v", jsConfig.ScriptTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) PreProcessRequest(c *gin.Context) error {
|
||||
if err := validateGinContext(c); err != nil {
|
||||
common.SysError("JS PreProcess Validation Error: " + err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
vm := p.Get()
|
||||
defer p.Put(vm)
|
||||
|
||||
preProcessFunc := vm.Get("preProcessRequest")
|
||||
if preProcessFunc == nil || goja.IsUndefined(preProcessFunc) {
|
||||
return nil
|
||||
}
|
||||
|
||||
jsReq, err := common.StructToMap(createJSReq(c))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create JS context: %v", err)
|
||||
}
|
||||
|
||||
result, err := p.executeWithTimeout(vm, func() (goja.Value, error) {
|
||||
fn, ok := goja.AssertFunction(preProcessFunc)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("preProcessRequest is not a function")
|
||||
}
|
||||
return fn(goja.Undefined(), vm.ToValue(jsReq))
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
common.SysError("JS PreProcess Error: " + err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理返回结果
|
||||
if result != nil && !goja.IsUndefined(result) {
|
||||
resultObj := result.Export()
|
||||
if resultMap, ok := resultObj.(map[string]any); ok {
|
||||
// 是否修改请求
|
||||
if newBody, exists := resultMap["body"]; exists {
|
||||
switch v := newBody.(type) {
|
||||
case string:
|
||||
c.Request.Body = io.NopCloser(strings.NewReader(v))
|
||||
c.Request.ContentLength = int64(len(v))
|
||||
case []byte:
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(v))
|
||||
c.Request.ContentLength = int64(len(v))
|
||||
case map[string]any:
|
||||
bodyBytes, err := json.Marshal(v)
|
||||
if err == nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
c.Request.ContentLength = int64(len(bodyBytes))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
common.SysError("JS PreProcess JSON Marshal Error: " + err.Error())
|
||||
}
|
||||
default:
|
||||
common.SysError("JS PreProcess Unsupported Body Type: " + fmt.Sprintf("%T", newBody))
|
||||
}
|
||||
}
|
||||
|
||||
// 是否修改 headers
|
||||
if newHeaders, exists := resultMap["headers"]; exists {
|
||||
if headersMap, ok := newHeaders.(map[string]any); ok {
|
||||
for key, value := range headersMap {
|
||||
if valueStr, ok := value.(string); ok {
|
||||
c.Request.Header.Set(key, valueStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 是否阻止请求
|
||||
if block, exists := resultMap["block"]; exists {
|
||||
if blockBool, ok := block.(bool); ok && blockBool {
|
||||
status := http.StatusForbidden
|
||||
if statusCode, exists := resultMap["statusCode"]; exists {
|
||||
if statusInt, ok := statusCode.(float64); ok {
|
||||
status = int(statusInt)
|
||||
}
|
||||
}
|
||||
|
||||
message := "Request blocked by pre-process script"
|
||||
if msg, exists := resultMap["message"]; exists {
|
||||
if msgStr, ok := msg.(string); ok {
|
||||
message = msgStr
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(status, gin.H{"error": message})
|
||||
c.Abort()
|
||||
return fmt.Errorf("request blocked")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) PostProcessResponse(c *gin.Context, statusCode int, body []byte) (int, []byte, error) {
|
||||
if err := validateGinContext(c); err != nil {
|
||||
common.SysError("JS PostProcess Validation Error: " + err.Error())
|
||||
return statusCode, body, err
|
||||
}
|
||||
|
||||
vm := p.Get()
|
||||
defer p.Put(vm)
|
||||
|
||||
postProcessFunc := vm.Get("postProcessResponse")
|
||||
if postProcessFunc == nil || goja.IsUndefined(postProcessFunc) {
|
||||
return statusCode, body, nil
|
||||
}
|
||||
|
||||
jsReq, err := common.StructToMap(createJSReq(c))
|
||||
if err != nil {
|
||||
return statusCode, body, fmt.Errorf("failed to create JS context: %v", err)
|
||||
}
|
||||
|
||||
jsResp := &JSResponse{
|
||||
StatusCode: statusCode,
|
||||
Headers: make(map[string]string),
|
||||
Body: string(body),
|
||||
}
|
||||
|
||||
// 获取响应头
|
||||
if c.Writer != nil {
|
||||
for key, values := range c.Writer.Header() {
|
||||
if len(values) > 0 {
|
||||
jsResp.Headers[key] = values[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
jsResponse, err := common.StructToMap(jsResp)
|
||||
if err != nil {
|
||||
return statusCode, body, fmt.Errorf("failed to create JS response context: %v", err)
|
||||
}
|
||||
|
||||
result, err := p.executeWithTimeout(vm, func() (goja.Value, error) {
|
||||
fn, ok := goja.AssertFunction(postProcessFunc)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("postProcessResponse is not a function")
|
||||
}
|
||||
return fn(goja.Undefined(), vm.ToValue(jsReq), vm.ToValue(jsResponse))
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
common.SysError("JS PostProcess Error: " + err.Error())
|
||||
return statusCode, body, err
|
||||
}
|
||||
|
||||
// 处理返回
|
||||
if result != nil && !goja.IsUndefined(result) {
|
||||
resultObj := result.Export()
|
||||
if resultMap, ok := resultObj.(map[string]any); ok {
|
||||
if newStatusCode, exists := resultMap["statusCode"]; exists {
|
||||
if statusInt, ok := newStatusCode.(float64); ok {
|
||||
statusCode = int(statusInt)
|
||||
}
|
||||
}
|
||||
|
||||
if newBody, exists := resultMap["body"]; exists {
|
||||
if bodyStr, ok := newBody.(string); ok {
|
||||
body = []byte(bodyStr)
|
||||
}
|
||||
}
|
||||
|
||||
if newHeaders, exists := resultMap["headers"]; exists {
|
||||
if headersMap, ok := newHeaders.(map[string]any); ok {
|
||||
for key, value := range headersMap {
|
||||
if valueStr, ok := value.(string); ok {
|
||||
c.Header(key, valueStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return statusCode, body, nil
|
||||
}
|
||||
|
||||
func (p *JSRuntimePool) hasPostProcessFunction() bool {
|
||||
vm := p.Get()
|
||||
defer p.Put(vm)
|
||||
postProcessFunc := vm.Get("postProcessResponse")
|
||||
return postProcessFunc != nil && !goja.IsUndefined(postProcessFunc)
|
||||
}
|
||||
|
||||
func JSRuntimeMiddleware() *gin.HandlerFunc {
|
||||
loadCfg()
|
||||
if !jsConfig.Enabled {
|
||||
common.SysLog("JavaScript Runtime is disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
pool := initJSRuntimePool()
|
||||
var fn gin.HandlerFunc
|
||||
fn = func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
// 预处理
|
||||
if err := pool.PreProcessRequest(c); err != nil {
|
||||
common.SysError("JS Runtime PreProcess Error: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
if duration > time.Millisecond*100 {
|
||||
common.SysLog(fmt.Sprintf("JS Runtime PreProcess took %v", duration))
|
||||
}
|
||||
|
||||
// 后处理
|
||||
if pool.hasPostProcessFunction() {
|
||||
writer := newResponseWriter(c.Writer)
|
||||
c.Writer = writer
|
||||
|
||||
c.Next()
|
||||
|
||||
// 后处理响应
|
||||
if writer.body.Len() > 0 {
|
||||
start := time.Now()
|
||||
|
||||
statusCode, body, err := pool.PostProcessResponse(c, writer.statusCode, writer.body.Bytes())
|
||||
if err == nil {
|
||||
c.Writer = writer.ResponseWriter
|
||||
|
||||
for k, v := range writer.headerMap {
|
||||
for _, value := range v {
|
||||
c.Writer.Header().Add(k, value)
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(statusCode)
|
||||
|
||||
if len(body) >= 0 {
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
||||
c.Writer.Write(body)
|
||||
} else {
|
||||
c.Writer.Header().Del("Content-Length")
|
||||
c.Writer.Write(body)
|
||||
}
|
||||
} else {
|
||||
// 出错时回复原响应
|
||||
c.Writer = writer.ResponseWriter
|
||||
c.Status(writer.statusCode)
|
||||
|
||||
common.SysError(fmt.Sprintf("JS Runtime PostProcess Error: %v", err))
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
if duration > time.Millisecond*100 {
|
||||
common.SysLog(fmt.Sprintf("JS Runtime PostProcess took %v", duration))
|
||||
}
|
||||
} else {
|
||||
// 没有响应体时,恢复原始writer
|
||||
c.Writer = writer.ResponseWriter
|
||||
}
|
||||
} else {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
return &fn
|
||||
}
|
||||
|
||||
func ReloadJSScripts() {
|
||||
if jsRuntimePool != nil {
|
||||
jsRuntimePool.ReloadScripts()
|
||||
common.SysLog("JavaScript scripts reloaded")
|
||||
}
|
||||
}
|
||||
@@ -1,139 +0,0 @@
|
||||
package jsrt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 请求
|
||||
type JSReq struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body any `json:"body"`
|
||||
UserAgent string `json:"userAgent"`
|
||||
RemoteIP string `json:"remoteIP"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
}
|
||||
|
||||
type JSResponse struct {
|
||||
StatusCode int `json:"statusCode"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body string `json:"body"`
|
||||
}
|
||||
|
||||
type responseWriter struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
statusCode int
|
||||
headerMap http.Header
|
||||
written bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func createJSReq(c *gin.Context) *JSReq {
|
||||
var bodyBytes []byte
|
||||
if c.Request != nil && c.Request.Body != nil {
|
||||
bodyBytes, _ = io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
}
|
||||
|
||||
// headers map
|
||||
headers := make(map[string]string)
|
||||
if c.Request != nil && c.Request.Header != nil {
|
||||
for key, values := range c.Request.Header {
|
||||
if len(values) > 0 {
|
||||
headers[key] = values[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
method := ""
|
||||
url := ""
|
||||
userAgent := ""
|
||||
remoteIP := ""
|
||||
contentType := ""
|
||||
|
||||
if c.Request != nil {
|
||||
method = c.Request.Method
|
||||
if c.Request.URL != nil {
|
||||
url = c.Request.URL.String()
|
||||
}
|
||||
userAgent = c.Request.UserAgent()
|
||||
contentType = c.ContentType()
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
remoteIP = c.ClientIP()
|
||||
}
|
||||
|
||||
parsedBody := parseBodyByType(bodyBytes, contentType)
|
||||
|
||||
return &JSReq{
|
||||
Method: method,
|
||||
URL: url,
|
||||
Headers: headers,
|
||||
Body: parsedBody,
|
||||
UserAgent: userAgent,
|
||||
RemoteIP: remoteIP,
|
||||
Extra: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
func newResponseWriter(w gin.ResponseWriter) *responseWriter {
|
||||
return &responseWriter{
|
||||
ResponseWriter: w,
|
||||
body: &bytes.Buffer{},
|
||||
statusCode: 200,
|
||||
headerMap: make(http.Header),
|
||||
written: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Write(data []byte) (int, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if !w.written {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
return w.body.Write(data)
|
||||
}
|
||||
|
||||
func (w *responseWriter) WriteString(s string) (int, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if !w.written {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
return w.body.WriteString(s)
|
||||
}
|
||||
|
||||
func (w *responseWriter) WriteHeader(statusCode int) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.written {
|
||||
return
|
||||
}
|
||||
w.statusCode = statusCode
|
||||
w.written = true
|
||||
|
||||
maps.Copy(w.headerMap, w.ResponseWriter.Header())
|
||||
}
|
||||
|
||||
func (w *responseWriter) Header() http.Header {
|
||||
w.mu.RLock()
|
||||
defer w.mu.RUnlock()
|
||||
|
||||
if w.headerMap == nil {
|
||||
w.headerMap = make(http.Header)
|
||||
}
|
||||
return w.headerMap
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
package jsrt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func setDB(vm *goja.Runtime, db *gorm.DB, name string) {
|
||||
if db == nil {
|
||||
common.SysError("JS DB is nil")
|
||||
return
|
||||
}
|
||||
|
||||
obj := vm.NewObject()
|
||||
obj.Set("query", func(sql string, params ...any) []map[string]any {
|
||||
return dbQuery(db, sql, params...)
|
||||
})
|
||||
obj.Set("exec", func(sql string, params ...any) map[string]any {
|
||||
return dbExec(db, sql, params...)
|
||||
})
|
||||
if err := vm.Set(name, obj); err != nil {
|
||||
common.SysError("Failed to set JS DB: " + err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func parseBodyByType(bodyBytes []byte, contentType string) any {
|
||||
if len(bodyBytes) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
bodyStr := string(bodyBytes)
|
||||
contentLower := strings.ToLower(contentType)
|
||||
|
||||
switch {
|
||||
case strings.Contains(contentLower, "application/json"):
|
||||
var jsonObj any
|
||||
if err := json.Unmarshal(bodyBytes, &jsonObj); err == nil {
|
||||
return jsonObj
|
||||
}
|
||||
return bodyStr
|
||||
|
||||
case strings.Contains(contentLower, "application/x-www-form-urlencoded"):
|
||||
if values, err := url.ParseQuery(bodyStr); err == nil {
|
||||
result := make(map[string]string, len(values))
|
||||
for k, v := range values {
|
||||
if len(v) > 0 {
|
||||
result[k] = v[0]
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
return bodyStr
|
||||
|
||||
case strings.Contains(contentLower, "multipart/form-data"):
|
||||
return bodyBytes
|
||||
|
||||
case strings.Contains(contentLower, "text/"):
|
||||
return bodyStr
|
||||
|
||||
default:
|
||||
// 尝试JSON解析
|
||||
var jsonObj any
|
||||
if json.Unmarshal(bodyBytes, &jsonObj) == nil {
|
||||
return jsonObj
|
||||
}
|
||||
|
||||
// 尝试form解析
|
||||
if values, err := url.ParseQuery(bodyStr); err == nil && len(values) > 0 {
|
||||
result := make(map[string]string, len(values))
|
||||
for k, v := range values {
|
||||
if len(v) > 0 {
|
||||
result[k] = v[0]
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
return bodyStr
|
||||
}
|
||||
}
|
||||
@@ -18,7 +18,11 @@ func KlingRequestConvert() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
model, _ := originalReq["model"].(string)
|
||||
// Support both model_name and model fields
|
||||
model, _ := originalReq["model_name"].(string)
|
||||
if model == "" {
|
||||
model, _ = originalReq["model"].(string)
|
||||
}
|
||||
prompt, _ := originalReq["prompt"].(string)
|
||||
|
||||
unifiedReq := map[string]interface{}{
|
||||
@@ -36,7 +40,7 @@ func KlingRequestConvert() func(c *gin.Context) {
|
||||
// Rewrite request body and path
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
|
||||
c.Request.URL.Path = "/v1/video/generations"
|
||||
if image := originalReq["image"]; image == "" {
|
||||
if image, ok := originalReq["image"]; !ok || image == "" {
|
||||
c.Set("action", constant.TaskActionTextGenerate)
|
||||
}
|
||||
|
||||
|
||||
@@ -87,26 +87,29 @@ func getPriority(group string, model string, retry int) (int, error) {
|
||||
return priorityToUse, nil
|
||||
}
|
||||
|
||||
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
||||
func getChannelQuery(group string, model string, retry int) (*gorm.DB, error) {
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
|
||||
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
|
||||
if retry != 0 {
|
||||
priority, err := getPriority(group, model, retry)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
||||
return nil, err
|
||||
} else {
|
||||
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
|
||||
}
|
||||
}
|
||||
|
||||
return channelQuery
|
||||
return channelQuery, nil
|
||||
}
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
var abilities []Ability
|
||||
|
||||
var err error = nil
|
||||
channelQuery := getChannelQuery(group, model, retry)
|
||||
channelQuery, err := getChannelQuery(group, model, retry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||
err = channelQuery.Order("weight DESC").Find(&abilities).Error
|
||||
} else {
|
||||
@@ -133,7 +136,7 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("channel not found")
|
||||
return nil, nil
|
||||
}
|
||||
err = DB.First(&channel, "id = ?", channel.Id).Error
|
||||
return &channel, err
|
||||
@@ -281,6 +284,21 @@ func FixAbility() (int, int, error) {
|
||||
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
|
||||
}
|
||||
defer fixLock.Unlock()
|
||||
|
||||
// truncate abilities table
|
||||
if common.UsingSQLite {
|
||||
err := DB.Exec("DELETE FROM abilities").Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||
return 0, 0, err
|
||||
}
|
||||
} else {
|
||||
err := DB.Exec("TRUNCATE TABLE abilities").Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
|
||||
return 0, 0, err
|
||||
}
|
||||
}
|
||||
var channels []*Channel
|
||||
// Find all channels
|
||||
err := DB.Model(&Channel{}).Find(&channels).Error
|
||||
|
||||
351
model/channel.go
351
model/channel.go
@@ -1,9 +1,15 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -35,9 +41,148 @@ type Channel struct {
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||
OtherInfo string `json:"other_info"`
|
||||
Settings string `json:"settings"`
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
Setting *string `json:"setting" gorm:"type:text"`
|
||||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||
// add after v0.8.5
|
||||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||
|
||||
// cache info
|
||||
Keys []string `json:"-" gorm:"-"`
|
||||
}
|
||||
|
||||
type ChannelInfo struct {
|
||||
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||||
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
|
||||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||||
MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason,omitempty"` // key禁用原因列表,key index -> reason
|
||||
MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time,omitempty"` // key禁用时间列表,key index -> time
|
||||
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer interface
|
||||
func (c ChannelInfo) Value() (driver.Value, error) {
|
||||
return common.Marshal(&c)
|
||||
}
|
||||
|
||||
// Scan implements sql.Scanner interface
|
||||
func (c *ChannelInfo) Scan(value interface{}) error {
|
||||
bytesValue, _ := value.([]byte)
|
||||
return common.Unmarshal(bytesValue, c)
|
||||
}
|
||||
|
||||
func (channel *Channel) GetKeys() []string {
|
||||
if channel.Key == "" {
|
||||
return []string{}
|
||||
}
|
||||
if len(channel.Keys) > 0 {
|
||||
return channel.Keys
|
||||
}
|
||||
trimmed := strings.TrimSpace(channel.Key)
|
||||
// If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
|
||||
if strings.HasPrefix(trimmed, "[") {
|
||||
var arr []json.RawMessage
|
||||
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
res := make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
res[i] = string(v)
|
||||
}
|
||||
return res
|
||||
}
|
||||
}
|
||||
// Otherwise, fall back to splitting by newline
|
||||
keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
|
||||
return keys
|
||||
}
|
||||
|
||||
func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
|
||||
// If not in multi-key mode, return the original key string directly.
|
||||
if !channel.ChannelInfo.IsMultiKey {
|
||||
return channel.Key, 0, nil
|
||||
}
|
||||
|
||||
// Obtain all keys (split by \n)
|
||||
keys := channel.GetKeys()
|
||||
if len(keys) == 0 {
|
||||
// No keys available, return error, should disable the channel
|
||||
return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
|
||||
}
|
||||
|
||||
statusList := channel.ChannelInfo.MultiKeyStatusList
|
||||
// helper to get key status, default to enabled when missing
|
||||
getStatus := func(idx int) int {
|
||||
if statusList == nil {
|
||||
return common.ChannelStatusEnabled
|
||||
}
|
||||
if status, ok := statusList[idx]; ok {
|
||||
return status
|
||||
}
|
||||
return common.ChannelStatusEnabled
|
||||
}
|
||||
|
||||
// Collect indexes of enabled keys
|
||||
enabledIdx := make([]int, 0, len(keys))
|
||||
for i := range keys {
|
||||
if getStatus(i) == common.ChannelStatusEnabled {
|
||||
enabledIdx = append(enabledIdx, i)
|
||||
}
|
||||
}
|
||||
// If no specific status list or none enabled, fall back to first key
|
||||
if len(enabledIdx) == 0 {
|
||||
return keys[0], 0, nil
|
||||
}
|
||||
|
||||
switch channel.ChannelInfo.MultiKeyMode {
|
||||
case constant.MultiKeyModeRandom:
|
||||
// Randomly pick one enabled key
|
||||
selectedIdx := enabledIdx[rand.Intn(len(enabledIdx))]
|
||||
return keys[selectedIdx], selectedIdx, nil
|
||||
case constant.MultiKeyModePolling:
|
||||
// Use channel-specific lock to ensure thread-safe polling
|
||||
lock := getChannelPollingLock(channel.Id)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
channelInfo, err := CacheGetChannelInfo(channel.Id)
|
||||
if err != nil {
|
||||
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
|
||||
defer func() {
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex))
|
||||
}
|
||||
if !common.MemoryCacheEnabled {
|
||||
_ = channel.SaveChannelInfo()
|
||||
} else {
|
||||
// CacheUpdateChannel(channel)
|
||||
}
|
||||
}()
|
||||
// Start from the saved polling index and look for the next enabled key
|
||||
start := channelInfo.MultiKeyPollingIndex
|
||||
if start < 0 || start >= len(keys) {
|
||||
start = 0
|
||||
}
|
||||
for i := 0; i < len(keys); i++ {
|
||||
idx := (start + i) % len(keys)
|
||||
if getStatus(idx) == common.ChannelStatusEnabled {
|
||||
// update polling index for next call (point to the next position)
|
||||
channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
|
||||
return keys[idx], idx, nil
|
||||
}
|
||||
}
|
||||
// Fallback – should not happen, but return first enabled key
|
||||
return keys[enabledIdx[0]], enabledIdx[0], nil
|
||||
default:
|
||||
// Unknown mode, default to first enabled key (or original key string)
|
||||
return keys[enabledIdx[0]], enabledIdx[0], nil
|
||||
}
|
||||
}
|
||||
|
||||
func (channel *Channel) SaveChannelInfo() error {
|
||||
return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error
|
||||
}
|
||||
|
||||
func (channel *Channel) GetModels() []string {
|
||||
@@ -61,7 +206,7 @@ func (channel *Channel) GetGroups() []string {
|
||||
func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
||||
otherInfo := make(map[string]interface{})
|
||||
if channel.OtherInfo != "" {
|
||||
err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||||
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal other info: " + err.Error())
|
||||
}
|
||||
@@ -175,14 +320,20 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
|
||||
}
|
||||
|
||||
func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||||
channel := Channel{Id: id}
|
||||
channel := &Channel{Id: id}
|
||||
var err error = nil
|
||||
if selectAll {
|
||||
err = DB.First(&channel, "id = ?", id).Error
|
||||
err = DB.First(channel, "id = ?", id).Error
|
||||
} else {
|
||||
err = DB.Omit("key").First(&channel, "id = ?", id).Error
|
||||
err = DB.Omit("key").First(channel, "id = ?", id).Error
|
||||
}
|
||||
return &channel, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
func BatchInsertChannels(channels []Channel) error {
|
||||
@@ -266,6 +417,44 @@ func (channel *Channel) Insert() error {
|
||||
}
|
||||
|
||||
func (channel *Channel) Update() error {
|
||||
// If this is a multi-key channel, recalculate MultiKeySize based on the current key list to avoid inconsistency after editing keys
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
var keyStr string
|
||||
if channel.Key != "" {
|
||||
keyStr = channel.Key
|
||||
} else {
|
||||
// If key is not provided, read the existing key from the database
|
||||
if existing, err := GetChannelById(channel.Id, true); err == nil {
|
||||
keyStr = existing.Key
|
||||
}
|
||||
}
|
||||
// Parse the key list (supports newline separation or JSON array)
|
||||
keys := []string{}
|
||||
if keyStr != "" {
|
||||
trimmed := strings.TrimSpace(keyStr)
|
||||
if strings.HasPrefix(trimmed, "[") {
|
||||
var arr []json.RawMessage
|
||||
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
keys = make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
keys[i] = string(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(keys) == 0 { // fallback to newline split
|
||||
keys = strings.Split(strings.Trim(keyStr, "\n"), "\n")
|
||||
}
|
||||
}
|
||||
channel.ChannelInfo.MultiKeySize = len(keys)
|
||||
// Clean up status data that exceeds the new key count to prevent index out of range
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
for idx := range channel.ChannelInfo.MultiKeyStatusList {
|
||||
if idx >= channel.ChannelInfo.MultiKeySize {
|
||||
delete(channel.ChannelInfo.MultiKeyStatusList, idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
var err error
|
||||
err = DB.Model(channel).Updates(channel).Error
|
||||
if err != nil {
|
||||
@@ -308,48 +497,132 @@ func (channel *Channel) Delete() error {
|
||||
|
||||
var channelStatusLock sync.Mutex
|
||||
|
||||
func UpdateChannelStatusById(id int, status int, reason string) bool {
|
||||
// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
|
||||
var channelPollingLocks sync.Map
|
||||
|
||||
// getChannelPollingLock returns or creates a mutex for the given channel ID
|
||||
func getChannelPollingLock(channelId int) *sync.Mutex {
|
||||
if lock, exists := channelPollingLocks.Load(channelId); exists {
|
||||
return lock.(*sync.Mutex)
|
||||
}
|
||||
// Create new lock for this channel
|
||||
newLock := &sync.Mutex{}
|
||||
actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock)
|
||||
return actual.(*sync.Mutex)
|
||||
}
|
||||
|
||||
// CleanupChannelPollingLocks removes locks for channels that no longer exist
|
||||
// This is optional and can be called periodically to prevent memory leaks
|
||||
func CleanupChannelPollingLocks() {
|
||||
var activeChannelIds []int
|
||||
DB.Model(&Channel{}).Pluck("id", &activeChannelIds)
|
||||
|
||||
activeChannelSet := make(map[int]bool)
|
||||
for _, id := range activeChannelIds {
|
||||
activeChannelSet[id] = true
|
||||
}
|
||||
|
||||
channelPollingLocks.Range(func(key, value interface{}) bool {
|
||||
channelId := key.(int)
|
||||
if !activeChannelSet[channelId] {
|
||||
channelPollingLocks.Delete(channelId)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) {
|
||||
keys := channel.GetKeys()
|
||||
if len(keys) == 0 {
|
||||
channel.Status = status
|
||||
} else {
|
||||
var keyIndex int
|
||||
for i, key := range keys {
|
||||
if key == usingKey {
|
||||
keyIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||
}
|
||||
if status == common.ChannelStatusEnabled {
|
||||
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
||||
} else {
|
||||
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||
}
|
||||
channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
|
||||
channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
|
||||
}
|
||||
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
|
||||
channel.Status = common.ChannelStatusAutoDisabled
|
||||
info := channel.GetOtherInfo()
|
||||
info["status_reason"] = "All keys are disabled"
|
||||
info["status_time"] = common.GetTimestamp()
|
||||
channel.SetOtherInfo(info)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
|
||||
if common.MemoryCacheEnabled {
|
||||
channelStatusLock.Lock()
|
||||
defer channelStatusLock.Unlock()
|
||||
|
||||
channelCache, _ := CacheGetChannel(id)
|
||||
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
||||
if channelCache != nil && channelCache.Status == status {
|
||||
channelCache, _ := CacheGetChannel(channelId)
|
||||
if channelCache == nil {
|
||||
return false
|
||||
}
|
||||
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
|
||||
if channelCache == nil && status != common.ChannelStatusEnabled {
|
||||
return false
|
||||
if channelCache.ChannelInfo.IsMultiKey {
|
||||
// 如果是多Key模式,更新缓存中的状态
|
||||
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
|
||||
//CacheUpdateChannel(channelCache)
|
||||
//return true
|
||||
} else {
|
||||
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
||||
if channelCache.Status == status {
|
||||
return false
|
||||
}
|
||||
CacheUpdateChannelStatus(channelId, status)
|
||||
}
|
||||
CacheUpdateChannelStatus(id, status)
|
||||
}
|
||||
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
|
||||
|
||||
shouldUpdateAbilities := false
|
||||
defer func() {
|
||||
if shouldUpdateAbilities {
|
||||
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
|
||||
if err != nil {
|
||||
common.SysError("failed to update ability status: " + err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
channel, err := GetChannelById(channelId, true)
|
||||
if err != nil {
|
||||
common.SysError("failed to update ability status: " + err.Error())
|
||||
return false
|
||||
}
|
||||
channel, err := GetChannelById(id, true)
|
||||
if err != nil {
|
||||
// find channel by id error, directly update status
|
||||
result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status)
|
||||
if result.Error != nil {
|
||||
common.SysError("failed to update channel status: " + result.Error.Error())
|
||||
return false
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if channel.Status == status {
|
||||
return false
|
||||
}
|
||||
// find channel by id success, update status and other info
|
||||
info := channel.GetOtherInfo()
|
||||
info["status_reason"] = reason
|
||||
info["status_time"] = common.GetTimestamp()
|
||||
channel.SetOtherInfo(info)
|
||||
channel.Status = status
|
||||
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
beforeStatus := channel.Status
|
||||
handlerMultiKeyUpdate(channel, usingKey, status, reason)
|
||||
if beforeStatus != channel.Status {
|
||||
shouldUpdateAbilities = true
|
||||
}
|
||||
} else {
|
||||
info := channel.GetOtherInfo()
|
||||
info["status_reason"] = reason
|
||||
info["status_time"] = common.GetTimestamp()
|
||||
channel.SetOtherInfo(info)
|
||||
channel.Status = status
|
||||
shouldUpdateAbilities = true
|
||||
}
|
||||
err = channel.Save()
|
||||
if err != nil {
|
||||
common.SysError("failed to update channel status: " + err.Error())
|
||||
@@ -518,7 +791,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
func (channel *Channel) ValidateSettings() error {
|
||||
channelParams := &dto.ChannelSettings{}
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), channelParams)
|
||||
err := common.Unmarshal([]byte(*channel.Setting), channelParams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -529,16 +802,18 @@ func (channel *Channel) ValidateSettings() error {
|
||||
func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||||
setting := dto.ChannelSettings{}
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
err := common.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
channel.Setting = nil // 清空设置以避免后续错误
|
||||
_ = channel.Save() // 保存修改
|
||||
}
|
||||
}
|
||||
return setting
|
||||
}
|
||||
|
||||
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||||
settingBytes, err := json.Marshal(setting)
|
||||
settingBytes, err := common.Marshal(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
return
|
||||
@@ -549,7 +824,7 @@ func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||||
func (channel *Channel) GetParamOverride() map[string]interface{} {
|
||||
paramOverride := make(map[string]interface{})
|
||||
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
|
||||
err := json.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||||
err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal param override: " + err.Error())
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/setting"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -14,8 +15,8 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var channelsIDM map[int]*Channel
|
||||
var group2model2channels map[string]map[string][]int // enabled channel
|
||||
var channelsIDM map[int]*Channel // all channels include disabled
|
||||
var channelSyncLock sync.RWMutex
|
||||
|
||||
func InitChannelCache() {
|
||||
@@ -24,7 +25,7 @@ func InitChannelCache() {
|
||||
}
|
||||
newChannelId2channel := make(map[int]*Channel)
|
||||
var channels []*Channel
|
||||
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
|
||||
DB.Find(&channels)
|
||||
for _, channel := range channels {
|
||||
newChannelId2channel[channel.Id] = channel
|
||||
}
|
||||
@@ -34,21 +35,22 @@ func InitChannelCache() {
|
||||
for _, ability := range abilities {
|
||||
groups[ability.Group] = true
|
||||
}
|
||||
newGroup2model2channels := make(map[string]map[string][]*Channel)
|
||||
newChannelsIDM := make(map[int]*Channel)
|
||||
newGroup2model2channels := make(map[string]map[string][]int)
|
||||
for group := range groups {
|
||||
newGroup2model2channels[group] = make(map[string][]*Channel)
|
||||
newGroup2model2channels[group] = make(map[string][]int)
|
||||
}
|
||||
for _, channel := range channels {
|
||||
newChannelsIDM[channel.Id] = channel
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
continue // skip disabled channels
|
||||
}
|
||||
groups := strings.Split(channel.Group, ",")
|
||||
for _, group := range groups {
|
||||
models := strings.Split(channel.Models, ",")
|
||||
for _, model := range models {
|
||||
if _, ok := newGroup2model2channels[group][model]; !ok {
|
||||
newGroup2model2channels[group][model] = make([]*Channel, 0)
|
||||
newGroup2model2channels[group][model] = make([]int, 0)
|
||||
}
|
||||
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
|
||||
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -57,7 +59,7 @@ func InitChannelCache() {
|
||||
for group, model2channels := range newGroup2model2channels {
|
||||
for model, channels := range model2channels {
|
||||
sort.Slice(channels, func(i, j int) bool {
|
||||
return channels[i].GetPriority() > channels[j].GetPriority()
|
||||
return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority()
|
||||
})
|
||||
newGroup2model2channels[group][model] = channels
|
||||
}
|
||||
@@ -65,7 +67,21 @@ func InitChannelCache() {
|
||||
|
||||
channelSyncLock.Lock()
|
||||
group2model2channels = newGroup2model2channels
|
||||
channelsIDM = newChannelsIDM
|
||||
//channelsIDM = newChannelId2channel
|
||||
for i, channel := range newChannelId2channel {
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
channel.Keys = channel.GetKeys()
|
||||
if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||||
if oldChannel, ok := channelsIDM[i]; ok {
|
||||
// 存在旧的渠道,如果是多key且轮询,保留轮询索引信息
|
||||
if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||||
channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
channelsIDM = newChannelId2channel
|
||||
channelSyncLock.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
}
|
||||
@@ -108,9 +124,6 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string,
|
||||
return nil, group, err
|
||||
}
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, group, errors.New("channel not found")
|
||||
}
|
||||
return channel, selectGroup, nil
|
||||
}
|
||||
|
||||
@@ -128,16 +141,27 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
}
|
||||
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
channels := group2model2channels[group][model]
|
||||
channelSyncLock.RUnlock()
|
||||
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if len(channels) == 1 {
|
||||
if channel, ok := channelsIDM[channels[0]]; ok {
|
||||
return channel, nil
|
||||
}
|
||||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0])
|
||||
}
|
||||
|
||||
uniquePriorities := make(map[int]bool)
|
||||
for _, channel := range channels {
|
||||
uniquePriorities[int(channel.GetPriority())] = true
|
||||
for _, channelId := range channels {
|
||||
if channel, ok := channelsIDM[channelId]; ok {
|
||||
uniquePriorities[int(channel.GetPriority())] = true
|
||||
} else {
|
||||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
|
||||
}
|
||||
}
|
||||
var sortedUniquePriorities []int
|
||||
for priority := range uniquePriorities {
|
||||
@@ -152,9 +176,13 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
|
||||
// get the priority for the given retry number
|
||||
var targetChannels []*Channel
|
||||
for _, channel := range channels {
|
||||
if channel.GetPriority() == targetPriority {
|
||||
targetChannels = append(targetChannels, channel)
|
||||
for _, channelId := range channels {
|
||||
if channel, ok := channelsIDM[channelId]; ok {
|
||||
if channel.GetPriority() == targetPriority {
|
||||
targetChannels = append(targetChannels, channel)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,11 +216,29 @@ func CacheGetChannel(id int) (*Channel, error) {
|
||||
|
||||
c, ok := channelsIDM[id]
|
||||
if !ok {
|
||||
return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
|
||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
channel, err := GetChannelById(id, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &channel.ChannelInfo, nil
|
||||
}
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
|
||||
c, ok := channelsIDM[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||
}
|
||||
return &c.ChannelInfo, nil
|
||||
}
|
||||
|
||||
func CacheUpdateChannelStatus(id int, status int) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return
|
||||
@@ -202,4 +248,35 @@ func CacheUpdateChannelStatus(id int, status int) {
|
||||
if channel, ok := channelsIDM[id]; ok {
|
||||
channel.Status = status
|
||||
}
|
||||
if status != common.ChannelStatusEnabled {
|
||||
// delete the channel from group2model2channels
|
||||
for group, model2channels := range group2model2channels {
|
||||
for model, channels := range model2channels {
|
||||
for i, channelId := range channels {
|
||||
if channelId == id {
|
||||
// remove the channel from the slice
|
||||
group2model2channels[group][model] = append(channels[:i], channels[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CacheUpdateChannel(channel *Channel) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return
|
||||
}
|
||||
channelSyncLock.Lock()
|
||||
defer channelSyncLock.Unlock()
|
||||
if channel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
|
||||
|
||||
println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
|
||||
channelsIDM[channel.Id] = channel
|
||||
println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
|
||||
}
|
||||
@@ -27,7 +27,7 @@ type Log struct {
|
||||
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
||||
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
||||
UseTime int `json:"use_time" gorm:"default:0"`
|
||||
IsStream bool `json:"is_stream" gorm:"default:false"`
|
||||
IsStream bool `json:"is_stream"`
|
||||
ChannelId int `json:"channel" gorm:"index"`
|
||||
ChannelName string `json:"channel_name" gorm:"->"`
|
||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||
@@ -49,7 +49,7 @@ func formatUserLogs(logs []*Log) {
|
||||
for i := range logs {
|
||||
logs[i].ChannelName = ""
|
||||
var otherMap map[string]interface{}
|
||||
otherMap = common.StrToMap(logs[i].Other)
|
||||
otherMap, _ = common.StrToMap(logs[i].Other)
|
||||
if otherMap != nil {
|
||||
// delete admin
|
||||
delete(otherMap, "admin_info")
|
||||
|
||||
@@ -57,7 +57,7 @@ func initCol() {
|
||||
}
|
||||
}
|
||||
// log sql type and database type
|
||||
common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
||||
//common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
||||
}
|
||||
|
||||
var DB *gorm.DB
|
||||
@@ -225,12 +225,6 @@ func InitLogDB() (err error) {
|
||||
if !common.IsMasterNode {
|
||||
return nil
|
||||
}
|
||||
//if common.UsingMySQL {
|
||||
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
|
||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
|
||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
|
||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
|
||||
//}
|
||||
common.SysLog("database migration started")
|
||||
err = migrateLOGDB()
|
||||
return err
|
||||
@@ -257,6 +251,8 @@ func migrateDB() error {
|
||||
&QuotaData{},
|
||||
&Task{},
|
||||
&Setup{},
|
||||
&TwoFA{},
|
||||
&TwoFABackupCode{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -266,7 +262,6 @@ func migrateDB() error {
|
||||
|
||||
func migrateDBFast() error {
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 12) // Buffer size matches number of migrations
|
||||
|
||||
migrations := []struct {
|
||||
model interface{}
|
||||
@@ -284,7 +279,11 @@ func migrateDBFast() error {
|
||||
{&QuotaData{}, "QuotaData"},
|
||||
{&Task{}, "Task"},
|
||||
{&Setup{}, "Setup"},
|
||||
{&TwoFA{}, "TwoFA"},
|
||||
{&TwoFABackupCode{}, "TwoFABackupCode"},
|
||||
}
|
||||
// 动态计算migration数量,确保errChan缓冲区足够大
|
||||
errChan := make(chan error, len(migrations))
|
||||
|
||||
for _, m := range migrations {
|
||||
wg.Add(1)
|
||||
|
||||
@@ -74,7 +74,13 @@ func InitOptionMap() {
|
||||
common.OptionMap["EpayId"] = ""
|
||||
common.OptionMap["EpayKey"] = ""
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
|
||||
common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(setting.USDExchangeRate, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
|
||||
common.OptionMap["StripeMinTopUp"] = strconv.Itoa(setting.StripeMinTopUp)
|
||||
common.OptionMap["StripeApiSecret"] = setting.StripeApiSecret
|
||||
common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
|
||||
common.OptionMap["StripePriceId"] = setting.StripePriceId
|
||||
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(setting.StripeUnitPrice, 'f', -1, 64)
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
@@ -306,8 +312,20 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.EpayKey = value
|
||||
case "Price":
|
||||
setting.Price, _ = strconv.ParseFloat(value, 64)
|
||||
case "USDExchangeRate":
|
||||
setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
|
||||
case "MinTopUp":
|
||||
setting.MinTopUp, _ = strconv.Atoi(value)
|
||||
case "StripeApiSecret":
|
||||
setting.StripeApiSecret = value
|
||||
case "StripeWebhookSecret":
|
||||
setting.StripeWebhookSecret = value
|
||||
case "StripePriceId":
|
||||
setting.StripePriceId = value
|
||||
case "StripeUnitPrice":
|
||||
setting.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
|
||||
case "StripeMinTopUp":
|
||||
setting.StripeMinTopUp, _ = strconv.Atoi(value)
|
||||
case "TopupGroupRatio":
|
||||
err = common.UpdateTopupGroupRatioByJSONString(value)
|
||||
case "GitHubClientId":
|
||||
|
||||
@@ -116,7 +116,7 @@ func updatePricing() {
|
||||
pricing.ModelPrice = modelPrice
|
||||
pricing.QuotaType = 1
|
||||
} else {
|
||||
modelRatio, _ := ratio_setting.GetModelRatio(model)
|
||||
modelRatio, _, _ := ratio_setting.GetModelRatio(model)
|
||||
pricing.ModelRatio = modelRatio
|
||||
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
||||
pricing.QuotaType = 0
|
||||
|
||||
@@ -20,8 +20,8 @@ type Token struct {
|
||||
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
||||
RemainQuota int `json:"remain_quota" gorm:"default:0"`
|
||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
|
||||
UnlimitedQuota bool `json:"unlimited_quota"`
|
||||
ModelLimitsEnabled bool `json:"model_limits_enabled"`
|
||||
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
|
||||
AllowIps *string `json:"allow_ips" gorm:"default:''"`
|
||||
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
||||
|
||||
@@ -1,13 +1,22 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type TopUp struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
Status string `json:"status"`
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func (topUp *TopUp) Insert() error {
|
||||
@@ -41,3 +50,51 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
|
||||
}
|
||||
return topUp
|
||||
}
|
||||
|
||||
func Recharge(referenceId string, customerId string) (err error) {
|
||||
if referenceId == "" {
|
||||
return errors.New("未提供支付单号")
|
||||
}
|
||||
|
||||
var quota float64
|
||||
topUp := &TopUp{}
|
||||
|
||||
refCol := "`trade_no`"
|
||||
if common.UsingPostgreSQL {
|
||||
refCol = `"trade_no"`
|
||||
}
|
||||
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
|
||||
if err != nil {
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
return errors.New("充值订单状态错误")
|
||||
}
|
||||
|
||||
topUp.CompleteTime = common.GetTimestamp()
|
||||
topUp.Status = common.TopUpStatusSuccess
|
||||
err = tx.Save(topUp).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
quota = topUp.Money * common.QuotaPerUnit
|
||||
err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(map[string]interface{}{"stripe_customer": customerId, "quota": gorm.Expr("quota + ?", quota)}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return errors.New("充值失败," + err.Error())
|
||||
}
|
||||
|
||||
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.FormatQuota(int(quota)), topUp.Amount))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
322
model/twofa.go
Normal file
322
model/twofa.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var ErrTwoFANotEnabled = errors.New("用户未启用2FA")
|
||||
|
||||
// TwoFA 用户2FA设置表
|
||||
type TwoFA struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
UserId int `json:"user_id" gorm:"unique;not null;index"`
|
||||
Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥,不返回给前端
|
||||
IsEnabled bool `json:"is_enabled" gorm:"default:false"`
|
||||
FailedAttempts int `json:"failed_attempts" gorm:"default:0"`
|
||||
LockedUntil *time.Time `json:"locked_until,omitempty"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
}
|
||||
|
||||
// TwoFABackupCode 备用码使用记录表
|
||||
type TwoFABackupCode struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
UserId int `json:"user_id" gorm:"not null;index"`
|
||||
CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希
|
||||
IsUsed bool `json:"is_used" gorm:"default:false"`
|
||||
UsedAt *time.Time `json:"used_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
}
|
||||
|
||||
// GetTwoFAByUserId 根据用户ID获取2FA设置
|
||||
func GetTwoFAByUserId(userId int) (*TwoFA, error) {
|
||||
if userId == 0 {
|
||||
return nil, errors.New("用户ID不能为空")
|
||||
}
|
||||
|
||||
var twoFA TwoFA
|
||||
err := DB.Where("user_id = ?", userId).First(&twoFA).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil // 返回nil表示未设置2FA
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &twoFA, nil
|
||||
}
|
||||
|
||||
// IsTwoFAEnabled 检查用户是否启用了2FA
|
||||
func IsTwoFAEnabled(userId int) bool {
|
||||
twoFA, err := GetTwoFAByUserId(userId)
|
||||
if err != nil || twoFA == nil {
|
||||
return false
|
||||
}
|
||||
return twoFA.IsEnabled
|
||||
}
|
||||
|
||||
// CreateTwoFA 创建2FA设置
|
||||
func (t *TwoFA) Create() error {
|
||||
// 检查用户是否已存在2FA设置
|
||||
existing, err := GetTwoFAByUserId(t.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != nil {
|
||||
return errors.New("用户已存在2FA设置")
|
||||
}
|
||||
|
||||
// 验证用户存在
|
||||
var user User
|
||||
if err := DB.First(&user, t.UserId).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return DB.Create(t).Error
|
||||
}
|
||||
|
||||
// Update 更新2FA设置
|
||||
func (t *TwoFA) Update() error {
|
||||
if t.Id == 0 {
|
||||
return errors.New("2FA记录ID不能为空")
|
||||
}
|
||||
return DB.Save(t).Error
|
||||
}
|
||||
|
||||
// Delete 删除2FA设置
|
||||
func (t *TwoFA) Delete() error {
|
||||
if t.Id == 0 {
|
||||
return errors.New("2FA记录ID不能为空")
|
||||
}
|
||||
|
||||
// 使用事务确保原子性
|
||||
return DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 同时删除相关的备用码记录(硬删除)
|
||||
if err := tx.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 硬删除2FA记录
|
||||
return tx.Unscoped().Delete(t).Error
|
||||
})
|
||||
}
|
||||
|
||||
// ResetFailedAttempts 重置失败尝试次数
|
||||
func (t *TwoFA) ResetFailedAttempts() error {
|
||||
t.FailedAttempts = 0
|
||||
t.LockedUntil = nil
|
||||
return t.Update()
|
||||
}
|
||||
|
||||
// IncrementFailedAttempts 增加失败尝试次数
|
||||
func (t *TwoFA) IncrementFailedAttempts() error {
|
||||
t.FailedAttempts++
|
||||
|
||||
// 检查是否需要锁定
|
||||
if t.FailedAttempts >= common.MaxFailAttempts {
|
||||
lockUntil := time.Now().Add(time.Duration(common.LockoutDuration) * time.Second)
|
||||
t.LockedUntil = &lockUntil
|
||||
}
|
||||
|
||||
return t.Update()
|
||||
}
|
||||
|
||||
// IsLocked 检查账户是否被锁定
|
||||
func (t *TwoFA) IsLocked() bool {
|
||||
if t.LockedUntil == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(*t.LockedUntil)
|
||||
}
|
||||
|
||||
// CreateBackupCodes 创建备用码
|
||||
func CreateBackupCodes(userId int, codes []string) error {
|
||||
return DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 先删除现有的备用码
|
||||
if err := tx.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建新的备用码记录
|
||||
for _, code := range codes {
|
||||
hashedCode, err := common.HashBackupCode(code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backupCode := TwoFABackupCode{
|
||||
UserId: userId,
|
||||
CodeHash: hashedCode,
|
||||
IsUsed: false,
|
||||
}
|
||||
|
||||
if err := tx.Create(&backupCode).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateBackupCode 验证并使用备用码
|
||||
func ValidateBackupCode(userId int, code string) (bool, error) {
|
||||
if !common.ValidateBackupCode(code) {
|
||||
return false, errors.New("验证码或备用码不正确")
|
||||
}
|
||||
|
||||
normalizedCode := common.NormalizeBackupCode(code)
|
||||
|
||||
// 查找未使用的备用码
|
||||
var backupCodes []TwoFABackupCode
|
||||
if err := DB.Where("user_id = ? AND is_used = false", userId).Find(&backupCodes).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// 验证备用码
|
||||
for _, bc := range backupCodes {
|
||||
if common.ValidatePasswordAndHash(normalizedCode, bc.CodeHash) {
|
||||
// 标记为已使用
|
||||
now := time.Now()
|
||||
bc.IsUsed = true
|
||||
bc.UsedAt = &now
|
||||
|
||||
if err := DB.Save(&bc).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GetUnusedBackupCodeCount 获取未使用的备用码数量
|
||||
func GetUnusedBackupCodeCount(userId int) (int, error) {
|
||||
var count int64
|
||||
err := DB.Model(&TwoFABackupCode{}).Where("user_id = ? AND is_used = false", userId).Count(&count).Error
|
||||
return int(count), err
|
||||
}
|
||||
|
||||
// DisableTwoFA 禁用用户的2FA
|
||||
func DisableTwoFA(userId int) error {
|
||||
twoFA, err := GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if twoFA == nil {
|
||||
return ErrTwoFANotEnabled
|
||||
}
|
||||
|
||||
// 删除2FA设置和备用码
|
||||
return twoFA.Delete()
|
||||
}
|
||||
|
||||
// EnableTwoFA 启用2FA
|
||||
func (t *TwoFA) Enable() error {
|
||||
t.IsEnabled = true
|
||||
t.FailedAttempts = 0
|
||||
t.LockedUntil = nil
|
||||
return t.Update()
|
||||
}
|
||||
|
||||
// ValidateTOTPAndUpdateUsage 验证TOTP并更新使用记录
|
||||
func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
|
||||
// 检查是否被锁定
|
||||
if t.IsLocked() {
|
||||
return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
|
||||
// 验证TOTP码
|
||||
if !common.ValidateTOTPCode(t.Secret, code) {
|
||||
// 增加失败次数
|
||||
if err := t.IncrementFailedAttempts(); err != nil {
|
||||
common.SysError("更新2FA失败次数失败: " + err.Error())
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 验证成功,重置失败次数并更新最后使用时间
|
||||
now := time.Now()
|
||||
t.FailedAttempts = 0
|
||||
t.LockedUntil = nil
|
||||
t.LastUsedAt = &now
|
||||
|
||||
if err := t.Update(); err != nil {
|
||||
common.SysError("更新2FA使用记录失败: " + err.Error())
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ValidateBackupCodeAndUpdateUsage 验证备用码并更新使用记录
|
||||
func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
|
||||
// 检查是否被锁定
|
||||
if t.IsLocked() {
|
||||
return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
|
||||
// 验证备用码
|
||||
valid, err := ValidateBackupCode(t.UserId, code)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !valid {
|
||||
// 增加失败次数
|
||||
if err := t.IncrementFailedAttempts(); err != nil {
|
||||
common.SysError("更新2FA失败次数失败: " + err.Error())
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 验证成功,重置失败次数并更新最后使用时间
|
||||
now := time.Now()
|
||||
t.FailedAttempts = 0
|
||||
t.LockedUntil = nil
|
||||
t.LastUsedAt = &now
|
||||
|
||||
if err := t.Update(); err != nil {
|
||||
common.SysError("更新2FA使用记录失败: " + err.Error())
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetTwoFAStats 获取2FA统计信息(管理员使用)
|
||||
func GetTwoFAStats() (map[string]interface{}, error) {
|
||||
var totalUsers, enabledUsers int64
|
||||
|
||||
// 总用户数
|
||||
if err := DB.Model(&User{}).Count(&totalUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 启用2FA的用户数
|
||||
if err := DB.Model(&TwoFA{}).Where("is_enabled = true").Count(&enabledUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
enabledRate := float64(0)
|
||||
if totalUsers > 0 {
|
||||
enabledRate = float64(enabledUsers) / float64(totalUsers) * 100
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_users": totalUsers,
|
||||
"enabled_users": enabledUsers,
|
||||
"enabled_rate": fmt.Sprintf("%.1f%%", enabledRate),
|
||||
}, nil
|
||||
}
|
||||
@@ -43,6 +43,7 @@ type User struct {
|
||||
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
|
||||
Setting string `json:"setting" gorm:"type:text;column:setting"`
|
||||
Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
|
||||
StripeCustomer string `json:"stripe_customer" gorm:"type:varchar(64);column:stripe_customer;index"`
|
||||
}
|
||||
|
||||
func (user *User) ToBaseUser() *UserBase {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -36,7 +35,7 @@ func (user *UserBase) WriteContext(c *gin.Context) {
|
||||
func (user *UserBase) GetSetting() dto.UserSetting {
|
||||
setting := dto.UserSetting{}
|
||||
if user.Setting != "" {
|
||||
err := json.Unmarshal([]byte(user.Setting), &setting)
|
||||
err := common.Unmarshal([]byte(user.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package relay
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
@@ -12,7 +11,10 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
|
||||
@@ -54,13 +56,13 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
return audioRequest, nil
|
||||
}
|
||||
|
||||
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
|
||||
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
||||
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
promptTokens := 0
|
||||
@@ -73,7 +75,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
@@ -88,23 +90,23 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
@@ -112,18 +114,18 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -21,10 +22,11 @@ type Adaptor interface {
|
||||
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
||||
ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError)
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
||||
ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error)
|
||||
}
|
||||
|
||||
type TaskAdaptor interface {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -17,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -99,7 +105,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
@@ -109,9 +115,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
err, usage = RerankHandler(c, resp, info)
|
||||
default:
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -4,15 +4,17 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
|
||||
@@ -124,49 +126,46 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
|
||||
return &imageResponse
|
||||
}
|
||||
|
||||
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
responseFormat := c.GetString("response_format")
|
||||
|
||||
var aliTaskResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliTaskResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliTaskResponse.Message != "" {
|
||||
common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
|
||||
return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Output.Message,
|
||||
Type: "ali_error",
|
||||
Param: "",
|
||||
Code: aliResponse.Output.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Output.Message,
|
||||
Type: "ali_error",
|
||||
Param: "",
|
||||
Code: aliResponse.Output.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
|
||||
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, nil
|
||||
c.Writer.Write(jsonResponse)
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -31,29 +31,26 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
|
||||
}
|
||||
}
|
||||
|
||||
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var aliResponse AliRerankResponse
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
|
||||
usage := dto.Usage{
|
||||
@@ -68,14 +65,10 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
|
||||
jsonResponse, err := json.Marshal(rerankResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
@@ -8,9 +8,10 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -38,11 +39,11 @@ func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingReque
|
||||
}
|
||||
}
|
||||
|
||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var fullTextResponse dto.OpenAIEmbeddingResponse
|
||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var fullTextResponse dto.FlexibleEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
@@ -53,11 +54,11 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
|
||||
}
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
@@ -119,7 +120,7 @@ func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStre
|
||||
return &response
|
||||
}
|
||||
|
||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var usage dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
@@ -174,32 +175,29 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var aliResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
if aliResponse.Code != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: "ali_error",
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
fullTextResponse := responseAli2OpenAI(&aliResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
@@ -203,6 +203,9 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
|
||||
}
|
||||
}
|
||||
|
||||
func DoRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
|
||||
return doRequest(c, req, info)
|
||||
}
|
||||
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
|
||||
var client *http.Client
|
||||
var err error
|
||||
@@ -220,7 +223,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
helper.SetEventStreamHeaders(c)
|
||||
// 处理流式请求的 ping 保活
|
||||
generalSettings := operation_setting.GetGeneralSetting()
|
||||
if generalSettings.PingIntervalEnabled {
|
||||
if generalSettings.PingIntervalEnabled && !info.DisablePing {
|
||||
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
|
||||
stopPinger = startPingKeepAlive(c, pingInterval)
|
||||
// 使用defer确保在任何情况下都能停止ping goroutine
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"one-api/relay/channel/claude"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -21,6 +22,11 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", request)
|
||||
@@ -84,7 +90,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
|
||||
@@ -3,19 +3,22 @@ package aws
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/claude"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
)
|
||||
|
||||
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
|
||||
@@ -65,24 +68,21 @@ func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
|
||||
return modelPrefix + "." + awsModelId
|
||||
}
|
||||
|
||||
func awsModelID(requestModel string) (string, error) {
|
||||
func awsModelID(requestModel string) string {
|
||||
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
|
||||
return awsModelID, nil
|
||||
return awsModelID
|
||||
}
|
||||
|
||||
return requestModel, nil
|
||||
return requestModel
|
||||
}
|
||||
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
awsModelId, err := awsModelID(c.GetString("request_model"))
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||
}
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
@@ -98,42 +98,42 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
|
||||
claudeReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return wrapErr(errors.New("request not found")), nil
|
||||
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "marshal request")), nil
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
|
||||
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
claudeInfo := &claude.ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
|
||||
claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
|
||||
if handlerErr != nil {
|
||||
return handlerErr, nil
|
||||
}
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
awsModelId, err := awsModelID(c.GetString("request_model"))
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||
}
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
@@ -149,25 +149,25 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
|
||||
claudeReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return wrapErr(errors.New("request not found")), nil
|
||||
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "marshal request")), nil
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
|
||||
return types.NewError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
stream := awsResp.GetStream()
|
||||
defer stream.Close()
|
||||
|
||||
claudeInfo := &claude.ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
@@ -176,18 +176,18 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
|
||||
for event := range stream.Events() {
|
||||
switch v := event.(type) {
|
||||
case *types.ResponseStreamMemberChunk:
|
||||
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
|
||||
info.SetFirstResponseTime()
|
||||
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
|
||||
if respErr != nil {
|
||||
return respErr, nil
|
||||
}
|
||||
case *types.UnknownUnionMember:
|
||||
case *bedrockruntimeTypes.UnknownUnionMember:
|
||||
fmt.Println("unknown tag:", v.Tag)
|
||||
return wrapErr(errors.New("unknown response type")), nil
|
||||
return types.NewError(errors.New("unknown response type"), types.ErrorCodeInvalidRequest), nil
|
||||
default:
|
||||
fmt.Println("union is nil or unknown type")
|
||||
return wrapErr(errors.New("nil or unknown response type")), nil
|
||||
return types.NewError(errors.New("nil or unknown response type"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -17,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -140,15 +146,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = baiduStreamHandler(c, resp)
|
||||
err, usage = baiduStreamHandler(c, info, resp)
|
||||
} else {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = baiduEmbeddingHandler(c, resp)
|
||||
err, usage = baiduEmbeddingHandler(c, info, resp)
|
||||
default:
|
||||
err, usage = baiduHandler(c, resp)
|
||||
err, usage = baiduHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -1,21 +1,23 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
||||
@@ -110,92 +112,49 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAI
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var usage dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
helper.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if baiduResponse.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||
}
|
||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
usage := &dto.Usage{}
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
err := common.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
if baiduResponse.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||
}
|
||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("error sending stream response: " + err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
return nil, &usage
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var baiduResponse BaiduChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -203,32 +162,24 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var baiduResponse BaiduEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -17,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -42,15 +48,15 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
if len(keyParts) == 0 || keyParts[0] == "" {
|
||||
return errors.New("invalid API key: authorization token is required")
|
||||
}
|
||||
if len(keyParts) > 1 {
|
||||
if keyParts[1] != "" {
|
||||
req.Set("appid", keyParts[1])
|
||||
}
|
||||
}
|
||||
return errors.New("invalid API key: authorization token is required")
|
||||
}
|
||||
if len(keyParts) > 1 {
|
||||
if keyParts[1] != "" {
|
||||
req.Set("appid", keyParts[1])
|
||||
}
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+keyParts[0])
|
||||
return nil
|
||||
}
|
||||
@@ -92,11 +98,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -23,6 +24,11 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
@@ -94,7 +100,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
|
||||
@@ -12,11 +12,18 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
WebSearchMaxUsesLow = 1
|
||||
WebSearchMaxUsesMedium = 5
|
||||
WebSearchMaxUsesHigh = 10
|
||||
)
|
||||
|
||||
func stopReasonClaude2OpenAI(reason string) string {
|
||||
switch reason {
|
||||
case "stop_sequence":
|
||||
@@ -64,7 +71,7 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla
|
||||
}
|
||||
|
||||
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
|
||||
claudeTools := make([]dto.Tool, 0, len(textRequest.Tools))
|
||||
claudeTools := make([]any, 0, len(textRequest.Tools))
|
||||
|
||||
for _, tool := range textRequest.Tools {
|
||||
if params, ok := tool.Function.Parameters.(map[string]any); ok {
|
||||
@@ -84,10 +91,62 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
}
|
||||
claudeTool.InputSchema[s] = a
|
||||
}
|
||||
claudeTools = append(claudeTools, claudeTool)
|
||||
claudeTools = append(claudeTools, &claudeTool)
|
||||
}
|
||||
}
|
||||
|
||||
// Web search tool
|
||||
// https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/web-search-tool
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
webSearchTool := dto.ClaudeWebSearchTool{
|
||||
Type: "web_search_20250305",
|
||||
Name: "web_search",
|
||||
}
|
||||
|
||||
// 处理 user_location
|
||||
if textRequest.WebSearchOptions.UserLocation != nil {
|
||||
anthropicUserLocation := &dto.ClaudeWebSearchUserLocation{
|
||||
Type: "approximate", // 固定为 "approximate"
|
||||
}
|
||||
|
||||
// 解析 UserLocation JSON
|
||||
var userLocationMap map[string]interface{}
|
||||
if err := json.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
|
||||
// 检查是否有 approximate 字段
|
||||
if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok {
|
||||
if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" {
|
||||
anthropicUserLocation.Timezone = timezone
|
||||
}
|
||||
if country, ok := approximateData["country"].(string); ok && country != "" {
|
||||
anthropicUserLocation.Country = country
|
||||
}
|
||||
if region, ok := approximateData["region"].(string); ok && region != "" {
|
||||
anthropicUserLocation.Region = region
|
||||
}
|
||||
if city, ok := approximateData["city"].(string); ok && city != "" {
|
||||
anthropicUserLocation.City = city
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
webSearchTool.UserLocation = anthropicUserLocation
|
||||
}
|
||||
|
||||
// 处理 search_context_size 转换为 max_uses
|
||||
if textRequest.WebSearchOptions.SearchContextSize != "" {
|
||||
switch textRequest.WebSearchOptions.SearchContextSize {
|
||||
case "low":
|
||||
webSearchTool.MaxUses = WebSearchMaxUsesLow
|
||||
case "medium":
|
||||
webSearchTool.MaxUses = WebSearchMaxUsesMedium
|
||||
case "high":
|
||||
webSearchTool.MaxUses = WebSearchMaxUsesHigh
|
||||
}
|
||||
}
|
||||
|
||||
claudeTools = append(claudeTools, &webSearchTool)
|
||||
}
|
||||
|
||||
claudeRequest := dto.ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
MaxTokens: textRequest.MaxTokens,
|
||||
@@ -99,6 +158,14 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
Tools: claudeTools,
|
||||
}
|
||||
|
||||
// 处理 tool_choice 和 parallel_tool_calls
|
||||
if textRequest.ToolChoice != nil || textRequest.ParallelTooCalls != nil {
|
||||
claudeToolChoice := mapToolChoice(textRequest.ToolChoice, textRequest.ParallelTooCalls)
|
||||
if claudeToolChoice != nil {
|
||||
claudeRequest.ToolChoice = claudeToolChoice
|
||||
}
|
||||
}
|
||||
|
||||
if claudeRequest.MaxTokens == 0 {
|
||||
claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
|
||||
}
|
||||
@@ -123,9 +190,30 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||
}
|
||||
|
||||
if textRequest.ReasoningEffort != "" {
|
||||
switch textRequest.ReasoningEffort {
|
||||
case "low":
|
||||
claudeRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: common.GetPointer[int](1280),
|
||||
}
|
||||
case "medium":
|
||||
claudeRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: common.GetPointer[int](2048),
|
||||
}
|
||||
case "high":
|
||||
claudeRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: common.GetPointer[int](4096),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 指定了 reasoning 参数,覆盖 budgetTokens
|
||||
if textRequest.Reasoning != nil {
|
||||
var reasoning openrouter.RequestReasoning
|
||||
if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil {
|
||||
if err := common.Unmarshal(textRequest.Reasoning, &reasoning); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -517,22 +605,15 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
||||
return true
|
||||
}
|
||||
|
||||
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *types.NewAPIError {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.UnmarshalJsonStr(data, &claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Code: "stream_response_error",
|
||||
Type: claudeResponse.Error.Type,
|
||||
Message: claudeResponse.Error.Message,
|
||||
},
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
}
|
||||
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||
@@ -593,15 +674,15 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
||||
}
|
||||
}
|
||||
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
var err *dto.OpenAIErrorWithStatusCode
|
||||
var err *types.NewAPIError
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
|
||||
if err != nil {
|
||||
@@ -617,21 +698,14 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.UnmarshalJson(data, &claudeResponse)
|
||||
err := common.Unmarshal(data, &claudeResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: claudeResponse.Error.Message,
|
||||
Type: claudeResponse.Error.Type,
|
||||
Code: claudeResponse.Error.Type,
|
||||
},
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
}
|
||||
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
@@ -652,21 +726,25 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
openaiResponse.Usage = *claudeInfo.Usage
|
||||
responseData, err = json.Marshal(openaiResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
case relaycommon.RelayFormatClaude:
|
||||
responseData = data
|
||||
}
|
||||
|
||||
if claudeResponse.Usage.ServerToolUse != nil && claudeResponse.Usage.ServerToolUse.WebSearchRequests > 0 {
|
||||
c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, nil, responseData)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
@@ -674,7 +752,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println("responseBody: ", string(responseBody))
|
||||
@@ -685,3 +763,51 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
|
||||
}
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
func mapToolChoice(toolChoice any, parallelToolCalls *bool) *dto.ClaudeToolChoice {
|
||||
var claudeToolChoice *dto.ClaudeToolChoice
|
||||
|
||||
// 处理 tool_choice 字符串值
|
||||
if toolChoiceStr, ok := toolChoice.(string); ok {
|
||||
switch toolChoiceStr {
|
||||
case "auto":
|
||||
claudeToolChoice = &dto.ClaudeToolChoice{
|
||||
Type: "auto",
|
||||
}
|
||||
case "required":
|
||||
claudeToolChoice = &dto.ClaudeToolChoice{
|
||||
Type: "any",
|
||||
}
|
||||
case "none":
|
||||
claudeToolChoice = &dto.ClaudeToolChoice{
|
||||
Type: "none",
|
||||
}
|
||||
}
|
||||
} else if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
|
||||
// 处理 tool_choice 对象值
|
||||
if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
|
||||
if toolName, ok := function["name"].(string); ok {
|
||||
claudeToolChoice = &dto.ClaudeToolChoice{
|
||||
Type: "tool",
|
||||
Name: toolName,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 parallel_tool_calls
|
||||
if parallelToolCalls != nil {
|
||||
if claudeToolChoice == nil {
|
||||
// 如果没有 tool_choice,但有 parallel_tool_calls,创建默认的 auto 类型
|
||||
claudeToolChoice = &dto.ClaudeToolChoice{
|
||||
Type: "auto",
|
||||
}
|
||||
}
|
||||
|
||||
// 设置 disable_parallel_tool_use
|
||||
// 如果 parallel_tool_calls 为 true,则 disable_parallel_tool_use 为 false
|
||||
claudeToolChoice.DisableParallelToolUse = !*parallelToolCalls
|
||||
}
|
||||
|
||||
return claudeToolChoice
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -17,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -94,20 +100,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
fallthrough
|
||||
case constant.RelayModeChatCompletions:
|
||||
if info.IsStream {
|
||||
err, usage = cfStreamHandler(c, resp, info)
|
||||
err, usage = cfStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = cfHandler(c, resp, info)
|
||||
err, usage = cfHandler(c, info, resp)
|
||||
}
|
||||
case constant.RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case constant.RelayModeAudioTranscription:
|
||||
err, usage = cfSTTHandler(c, resp, info)
|
||||
err, usage = cfSTTHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package cloudflare
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -11,8 +10,11 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
|
||||
@@ -25,7 +27,7 @@ func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfReque
|
||||
}
|
||||
}
|
||||
|
||||
func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
@@ -86,16 +88,16 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var response dto.TextResponse
|
||||
err = json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
response.Model = info.UpstreamModelName
|
||||
var responseText string
|
||||
@@ -107,7 +109,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
||||
response.Id = helper.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -115,16 +117,16 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var cfResp CfAudioResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &cfResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
audioResp := &dto.AudioResponse{
|
||||
@@ -133,7 +135,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
||||
|
||||
jsonResponse, err := json.Marshal(audioResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -16,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -71,14 +77,14 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = cohereRerankHandler(c, resp, info)
|
||||
usage, err = cohereRerankHandler(c, resp, info)
|
||||
} else {
|
||||
if info.IsStream {
|
||||
err, usage = cohereStreamHandler(c, resp, info)
|
||||
usage, err = cohereStreamHandler(c, info, resp) // TODO: fix this
|
||||
} else {
|
||||
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
||||
usage, err = cohereHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -3,7 +3,6 @@ package cohere
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -11,8 +10,11 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
||||
@@ -76,7 +78,7 @@ func stopReasonCohere2OpenAI(reason string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseId := helper.GetResponseID(c)
|
||||
createdTime := common.GetTimestamp()
|
||||
usage := &dto.Usage{}
|
||||
@@ -164,20 +166,20 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
if usage.PromptTokens == 0 {
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
return nil, usage
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
createdTime := common.GetTimestamp()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var cohereResp CohereResponseResult
|
||||
err = json.Unmarshal(responseBody, &cohereResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
||||
@@ -188,7 +190,7 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
||||
openaiResp.Id = cohereResp.ResponseId
|
||||
openaiResp.Created = createdTime
|
||||
openaiResp.Object = "chat.completion"
|
||||
openaiResp.Model = modelName
|
||||
openaiResp.Model = info.UpstreamModelName
|
||||
openaiResp.Usage = usage
|
||||
|
||||
openaiResp.Choices = []dto.OpenAITextResponseChoice{
|
||||
@@ -201,24 +203,24 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
||||
|
||||
jsonResponse, err := json.Marshal(openaiResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var cohereResp CohereRerankResponseResult
|
||||
err = json.Unmarshal(responseBody, &cohereResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
|
||||
@@ -237,10 +239,10 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
|
||||
jsonResponse, err := json.Marshal(rerankResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user