mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 15:46:44 +00:00
Compare commits
568 Commits
v0.8.3.0
...
v0.8.8.0-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fc09051d8b | ||
|
|
1f5ef24ecd | ||
|
|
b1faf42529 | ||
|
|
6a85206e32 | ||
|
|
e3d3e697d3 | ||
|
|
db9b333930 | ||
|
|
f7b284ad73 | ||
|
|
e1970e8a66 | ||
|
|
0cd93d67ff | ||
|
|
6e806e21bd | ||
|
|
a8462c1b70 | ||
|
|
706ea8b649 | ||
|
|
95d46d1dfc | ||
|
|
010f27678d | ||
|
|
d87117a2cf | ||
|
|
4ed92a94a1 | ||
|
|
821ea34a3c | ||
|
|
ecb3d01376 | ||
|
|
e322ed4f05 | ||
|
|
bcf7e78665 | ||
|
|
0cb2bb2ea7 | ||
|
|
c5d97597c4 | ||
|
|
a8a42cbfa8 | ||
|
|
19df2ac234 | ||
|
|
e7524c85c2 | ||
|
|
a4356727e9 | ||
|
|
f15a53fae4 | ||
|
|
8e3cf2eaab | ||
|
|
c51ec3135b | ||
|
|
2469c439b1 | ||
|
|
1297addfb1 | ||
|
|
d6cbf43373 | ||
|
|
df647e7b42 | ||
|
|
fe16d05fbb | ||
|
|
1430c05b6c | ||
|
|
b25841e50d | ||
|
|
b704fc9254 | ||
|
|
352da66bd1 | ||
|
|
8205ad2cd0 | ||
|
|
e162b9c169 | ||
|
|
77e3502028 | ||
|
|
ae0461692c | ||
|
|
13bdb80958 | ||
|
|
6f74e7b738 | ||
|
|
eaee89f77a | ||
|
|
756a8c50d6 | ||
|
|
e0b859dbbe | ||
|
|
07b64ff1a4 | ||
|
|
7bc9192f3f | ||
|
|
057e551059 | ||
|
|
2f80c814aa | ||
|
|
136a029bb4 | ||
|
|
d4b32a403b | ||
|
|
722b187f83 | ||
|
|
0c5c5823bf | ||
|
|
f5a6b7d1f0 | ||
|
|
bcd236286c | ||
|
|
6c4ada5098 | ||
|
|
2402715492 | ||
|
|
f32cf02714 | ||
|
|
e224ee5498 | ||
|
|
90011aa0c9 | ||
|
|
d0589468c1 | ||
|
|
6ef5acbfe5 | ||
|
|
efe894cad6 | ||
|
|
2a366c176d | ||
|
|
8e280a6a24 | ||
|
|
f144518e0e | ||
|
|
fcc006ecd3 | ||
|
|
5fbadc6b21 | ||
|
|
7902570855 | ||
|
|
55898780f1 | ||
|
|
d16cb90c2f | ||
|
|
66dd514c56 | ||
|
|
ba40748118 | ||
|
|
3538cefe68 | ||
|
|
f77aef82d2 | ||
|
|
4d0037a40c | ||
|
|
fd7a4461cc | ||
|
|
8bc6ddbca8 | ||
|
|
7d50e432b5 | ||
|
|
6103888610 | ||
|
|
4d8189f21b | ||
|
|
cddb778577 | ||
|
|
fa506ec04f | ||
|
|
0eaeef5723 | ||
|
|
f87054895e | ||
|
|
d74a5bd507 | ||
|
|
b5d4535db6 | ||
|
|
4d7562fd79 | ||
|
|
5b869376ab | ||
|
|
19c522d9bc | ||
|
|
1d4ecad134 | ||
|
|
805464e406 | ||
|
|
c674c3561a | ||
|
|
7aa2972c3f | ||
|
|
986558fea7 | ||
|
|
818e34682c | ||
|
|
252fddf3de | ||
|
|
39079e7aff | ||
|
|
1fa4518bb9 | ||
|
|
1b739e87ae | ||
|
|
e944983567 | ||
|
|
4fccaf3284 | ||
|
|
0a79dc9ecc | ||
|
|
847a8c8c4d | ||
|
|
a1018c5823 | ||
|
|
323417182a | ||
|
|
f3bcf570f4 | ||
|
|
a3059597fb | ||
|
|
d19a6914f9 | ||
|
|
4313ede132 | ||
|
|
635bfd4aba | ||
|
|
38e72e1af7 | ||
|
|
26644bfd1e | ||
|
|
6a827fc7b9 | ||
|
|
3b3ae9c0dd | ||
|
|
301909e3e5 | ||
|
|
97a9c8627c | ||
|
|
56c1fbecea | ||
|
|
de9d18a2fe | ||
|
|
be16ad26b5 | ||
|
|
d762da9141 | ||
|
|
c05d6f7cdf | ||
|
|
7af3fb5ae4 | ||
|
|
3ac54b2178 | ||
|
|
42a26f076a | ||
|
|
3b67759730 | ||
|
|
5407a8345f | ||
|
|
3fe509757b | ||
|
|
952b679ca3 | ||
|
|
6799daacd1 | ||
|
|
fa02b5150c | ||
|
|
63a1904242 | ||
|
|
1e3450fdcb | ||
|
|
5541026b86 | ||
|
|
c36c920b34 | ||
|
|
514fea65c4 | ||
|
|
e269b3bfdd | ||
|
|
0862a9bfa7 | ||
|
|
f43c695527 | ||
|
|
ead43f081c | ||
|
|
4e2a3d61dc | ||
|
|
218ad6bbe0 | ||
|
|
b485f2e42e | ||
|
|
16e32c3f67 | ||
|
|
15f65bb558 | ||
|
|
b161d6831f | ||
|
|
969953039f | ||
|
|
f1506ed5da | ||
|
|
9a239d9e13 | ||
|
|
a5da09dfb9 | ||
|
|
6f81f2d143 | ||
|
|
0b877ca8a3 | ||
|
|
2911b9cd04 | ||
|
|
6b3f1ab0e4 | ||
|
|
2c15655b08 | ||
|
|
afa9c650fe | ||
|
|
28d8d82ded | ||
|
|
a100baf57f | ||
|
|
5621755655 | ||
|
|
d892bfc278 | ||
|
|
4369b18fbf | ||
|
|
fb9b5d31e8 | ||
|
|
3bf0748389 | ||
|
|
cf46b89814 | ||
|
|
3360b34af9 | ||
|
|
4558eb41fc | ||
|
|
bbc5584f80 | ||
|
|
8604c9f9d5 | ||
|
|
747e02ee0d | ||
|
|
8b0334309b | ||
|
|
48afa821e4 | ||
|
|
42a8d3e3dc | ||
|
|
a44fc51007 | ||
|
|
961bc874d2 | ||
|
|
b2b018ab93 | ||
|
|
77da33de4f | ||
|
|
06ad5e3f8c | ||
|
|
9326bf96fc | ||
|
|
bed73102b4 | ||
|
|
eb59f9c75d | ||
|
|
f3bd2ed472 | ||
|
|
456475d593 | ||
|
|
a36ce199ba | ||
|
|
b7c3ad0867 | ||
|
|
ea3545cc7e | ||
|
|
232ba46b16 | ||
|
|
5f011502d1 | ||
|
|
93b6f1066b | ||
|
|
52fe92ed7f | ||
|
|
0d005df463 | ||
|
|
e3ef3ace29 | ||
|
|
a203e98689 | ||
|
|
27f99a0f38 | ||
|
|
d1e48d02bd | ||
|
|
4f06a1df50 | ||
|
|
2d7ae1180f | ||
|
|
75b486b467 | ||
|
|
5b5f10fe93 | ||
|
|
5f654e76e2 | ||
|
|
aa8d112c58 | ||
|
|
e82dc0e841 | ||
|
|
dd741fc38a | ||
|
|
120e4ee92f | ||
|
|
9d2a56bff4 | ||
|
|
31d82a3169 | ||
|
|
d22ee5d451 | ||
|
|
203edaed50 | ||
|
|
93b5638a9c | ||
|
|
52a5e58f0c | ||
|
|
20607b0b5c | ||
|
|
6bebfe9e54 | ||
|
|
50b76f4466 | ||
|
|
23e4e25e9a | ||
|
|
5b83d478d6 | ||
|
|
dca38d01d6 | ||
|
|
0a434d3b3a | ||
|
|
7c4b83a430 | ||
|
|
b7f24b428b | ||
|
|
22a0ed0ee2 | ||
|
|
cf711d55a5 | ||
|
|
26ea562fdb | ||
|
|
efce0c6c57 | ||
|
|
a3768dae97 | ||
|
|
85efea3fb8 | ||
|
|
c820fda26d | ||
|
|
4740293640 | ||
|
|
8be8813cd8 | ||
|
|
8cc747ef22 | ||
|
|
d6ed2ab3e0 | ||
|
|
e8ae980104 | ||
|
|
cd8c23c0ab | ||
|
|
3568042cd9 | ||
|
|
7443129e18 | ||
|
|
a9e03e6172 | ||
|
|
cb16bf552e | ||
|
|
98952198bb | ||
|
|
338e914a60 | ||
|
|
4196a3db5a | ||
|
|
78fb457765 | ||
|
|
8759ef012f | ||
|
|
f8d67a62a2 | ||
|
|
efb98854b2 | ||
|
|
7b29f429ee | ||
|
|
265c7d93a2 | ||
|
|
ce57ad3570 | ||
|
|
0e6b608f91 | ||
|
|
f1856fe4d2 | ||
|
|
870cdd5a56 | ||
|
|
9282f1d893 | ||
|
|
9546a47f2b | ||
|
|
f0f277dc2a | ||
|
|
b695e67154 | ||
|
|
fa2cd85007 | ||
|
|
8073cbd96a | ||
|
|
5eba2f1d61 | ||
|
|
5ec421d8e6 | ||
|
|
1e25bf700d | ||
|
|
30fb349d91 | ||
|
|
d40fb68500 | ||
|
|
3049ad47e5 | ||
|
|
8945a3a2dd | ||
|
|
d191eef657 | ||
|
|
6ac7878863 | ||
|
|
c0a23ffa62 | ||
|
|
7d691f362d | ||
|
|
bf577b8937 | ||
|
|
819290c9b8 | ||
|
|
22e8b46159 | ||
|
|
76b8cc1168 | ||
|
|
fce07325b9 | ||
|
|
123862d41c | ||
|
|
7e298f8ad1 | ||
|
|
34aca14858 | ||
|
|
6b1f94348a | ||
|
|
4322037639 | ||
|
|
ae11f88595 | ||
|
|
389a4c3e4c | ||
|
|
efb691e6c2 | ||
|
|
53e3b35437 | ||
|
|
eb265a55e1 | ||
|
|
950f7d214f | ||
|
|
6bd2316d9c | ||
|
|
660180ea1b | ||
|
|
efc8457770 | ||
|
|
9b8b982d8a | ||
|
|
e6949e611a | ||
|
|
cffade7210 | ||
|
|
6b9237f868 | ||
|
|
1f4cf07b63 | ||
|
|
59a1f4c900 | ||
|
|
0a04a76c71 | ||
|
|
9e6bc518cc | ||
|
|
bfb6fbbac9 | ||
|
|
9c08d8cf20 | ||
|
|
281054ff4c | ||
|
|
3002659f47 | ||
|
|
647f8d7958 | ||
|
|
5d289d38ba | ||
|
|
05ea0dd54f | ||
|
|
1dad04ec09 | ||
|
|
2171117c53 | ||
|
|
d389befc9e | ||
|
|
3ced5ff144 | ||
|
|
38d3ab5acf | ||
|
|
ab32e15a86 | ||
|
|
25e17b95d5 | ||
|
|
d07224e658 | ||
|
|
aa15d45a3d | ||
|
|
c6c68da0b5 | ||
|
|
1a0aac81df | ||
|
|
39cb45c11c | ||
|
|
05d9aa53ef | ||
|
|
86f374df58 | ||
|
|
6935260bf0 | ||
|
|
f0d888729b | ||
|
|
6d7d4292ef | ||
|
|
fcefac9dbe | ||
|
|
ad5f731b20 | ||
|
|
76da067d40 | ||
|
|
0689670698 | ||
|
|
5a6f32c392 | ||
|
|
d6276c4692 | ||
|
|
29a44eb7ae | ||
|
|
048a625181 | ||
|
|
64782027c4 | ||
|
|
277645db50 | ||
|
|
3f53e4f53e | ||
|
|
0c5d4ca0a7 | ||
|
|
44495b153a | ||
|
|
de6e551cdb | ||
|
|
aeb393e391 | ||
|
|
db1b11deaf | ||
|
|
5a5e8ce652 | ||
|
|
6c31151430 | ||
|
|
a8ba2eba33 | ||
|
|
c974b1053c | ||
|
|
1ab75b8a92 | ||
|
|
75e3959474 | ||
|
|
bc371778b6 | ||
|
|
cd2870aebc | ||
|
|
7c72545217 | ||
|
|
2591ca3d60 | ||
|
|
c28190316f | ||
|
|
ffc22b8dac | ||
|
|
5367015a31 | ||
|
|
75c71c397e | ||
|
|
6192aebe66 | ||
|
|
a85a594597 | ||
|
|
014c9450ba | ||
|
|
63640f65e8 | ||
|
|
fd040988a3 | ||
|
|
f7c3b043b5 | ||
|
|
93e7675bc3 | ||
|
|
d7c97d4d34 | ||
|
|
dce794dbf7 | ||
|
|
093d86040f | ||
|
|
39617bc8c6 | ||
|
|
7da224ba92 | ||
|
|
df862732df | ||
|
|
fd4447f60a | ||
|
|
ea79d59aa0 | ||
|
|
41b0cf406c | ||
|
|
ef32cc8e0a | ||
|
|
ee8956b0e9 | ||
|
|
5ad9f8d931 | ||
|
|
ea379e1d0e | ||
|
|
b842baf21f | ||
|
|
58c9c7d5dd | ||
|
|
384fadf227 | ||
|
|
e4def0625b | ||
|
|
44d20de251 | ||
|
|
7ea33c2ddf | ||
|
|
b43423bffc | ||
|
|
cf4700a35c | ||
|
|
6bb552128c | ||
|
|
50b4fc06f8 | ||
|
|
f7f1be9df2 | ||
|
|
59574dc80f | ||
|
|
7577ec1ac4 | ||
|
|
d487be0029 | ||
|
|
83a3872b97 | ||
|
|
1ad2f63f85 | ||
|
|
fcaa8317e4 | ||
|
|
ccda14255a | ||
|
|
8d66828229 | ||
|
|
4ebf9e35e1 | ||
|
|
2902d6c7c2 | ||
|
|
01ef1fe4e4 | ||
|
|
c3d2d07b68 | ||
|
|
18417bacb3 | ||
|
|
8ec18dd21b | ||
|
|
edaff1c689 | ||
|
|
9c3a13cb23 | ||
|
|
0b326e7af4 | ||
|
|
1a1ff836b5 | ||
|
|
34fed74f64 | ||
|
|
f89b29928c | ||
|
|
2c6d4460c3 | ||
|
|
7afd3f97ee | ||
|
|
0708452939 | ||
|
|
a9e5d99ea3 | ||
|
|
a56d9ea98b | ||
|
|
f5e80af0b3 | ||
|
|
a1a7ddbc83 | ||
|
|
8b209d8926 | ||
|
|
9344cab59a | ||
|
|
03468e05e4 | ||
|
|
11792ba1a4 | ||
|
|
5baaa06896 | ||
|
|
d3286893c4 | ||
|
|
b087b20bac | ||
|
|
6a5a839d4d | ||
|
|
5d8a0952b4 | ||
|
|
bd08ecc1e0 | ||
|
|
e4f61c1084 | ||
|
|
a38215478f | ||
|
|
c192d07a04 | ||
|
|
098880b796 | ||
|
|
150c506ece | ||
|
|
f978d8224e | ||
|
|
ab59887933 | ||
|
|
458472f3e2 | ||
|
|
a9f98c5d39 | ||
|
|
2b7dff2d94 | ||
|
|
58752d2dcf | ||
|
|
67546f4b2a | ||
|
|
8e9dae7b5f | ||
|
|
fb4ff63bad | ||
|
|
1fed1ee567 | ||
|
|
02571c20ff | ||
|
|
8201daa4b4 | ||
|
|
5b54624cd5 | ||
|
|
db737567fb | ||
|
|
5c3898d13e | ||
|
|
2fdb2be6d0 | ||
|
|
ab78efc815 | ||
|
|
fcf97d1796 | ||
|
|
e85cc6acbe | ||
|
|
da002e6ca9 | ||
|
|
070e7b6911 | ||
|
|
d5a3eb7d04 | ||
|
|
616e6953cc | ||
|
|
b7c77777a5 | ||
|
|
8a79de333a | ||
|
|
a87d4271d3 | ||
|
|
7975cdf3bf | ||
|
|
b2badad554 | ||
|
|
133d8c9f77 | ||
|
|
9708d645d3 | ||
|
|
0bca4d3efc | ||
|
|
7572e791f6 | ||
|
|
16c63b3be9 | ||
|
|
37fbcb7950 | ||
|
|
a180d13182 | ||
|
|
a6363a502a | ||
|
|
81bc096872 | ||
|
|
edcdb378fd | ||
|
|
4447e51588 | ||
|
|
fb8aac650f | ||
|
|
ba6b0637cc | ||
|
|
3502730dfc | ||
|
|
b95c5bb8f4 | ||
|
|
f35784aa97 | ||
|
|
3746482e8c | ||
|
|
5ed4b60b8f | ||
|
|
547da2da60 | ||
|
|
f88ed4dd5c | ||
|
|
87fc681df3 | ||
|
|
a39b2f5aa7 | ||
|
|
0b9b21eafd | ||
|
|
21f43b0dd8 | ||
|
|
3a7ba5725c | ||
|
|
2e4fa32d63 | ||
|
|
0199896d9a | ||
|
|
edd9049100 | ||
|
|
290c763901 | ||
|
|
226446a3b5 | ||
|
|
ab627db4be | ||
|
|
0f35d2368f | ||
|
|
3c276d13c4 | ||
|
|
b7c3328d43 | ||
|
|
4d8e63bd1a | ||
|
|
51757b83e1 | ||
|
|
87c260093a | ||
|
|
691a878aa2 | ||
|
|
b33d808bc1 | ||
|
|
4559f5b2d3 | ||
|
|
0b9c6ecb00 | ||
|
|
a7d87475af | ||
|
|
ba37750943 | ||
|
|
4fc85d27e9 | ||
|
|
246ca40aac | ||
|
|
59a6fa7274 | ||
|
|
7fa21ce95f | ||
|
|
6b7295bbdf | ||
|
|
4a8b7bfa37 | ||
|
|
b4b6bd46fe | ||
|
|
d5c96cb036 | ||
|
|
1294d286ee | ||
|
|
dc95d0d1e6 | ||
|
|
467439090d | ||
|
|
b77574dad5 | ||
|
|
296da5dbcc | ||
|
|
3ac02879de | ||
|
|
7403df7e9c | ||
|
|
617c8e8f4f | ||
|
|
a9160804a3 | ||
|
|
aa793088ed | ||
|
|
0089157b83 | ||
|
|
c48a398737 | ||
|
|
e735377218 | ||
|
|
d2b47969da | ||
|
|
af50660887 | ||
|
|
5adf1e272d | ||
|
|
abfb3f4006 | ||
|
|
5f05803643 | ||
|
|
ab0ba9f38c | ||
|
|
e1a93a1b82 | ||
|
|
e6e5f31921 | ||
|
|
8978dc7a8b | ||
|
|
d57e6425e5 | ||
|
|
b9b4b24961 | ||
|
|
4c05377c87 | ||
|
|
a9cdbce9de | ||
|
|
66403275b7 | ||
|
|
c554015526 | ||
|
|
35313ae0d6 | ||
|
|
6c359839cc | ||
|
|
be7e09b14d | ||
|
|
60b624a329 | ||
|
|
47531a6b93 | ||
|
|
0e05f725a4 | ||
|
|
034cc7f118 | ||
|
|
927cd07a3f | ||
|
|
070eba4b4c | ||
|
|
af9cc5ce11 | ||
|
|
f844772126 | ||
|
|
a8a2141626 | ||
|
|
0401f1e9ec | ||
|
|
358af20ad1 | ||
|
|
e455f06851 | ||
|
|
f191f981c4 | ||
|
|
9b659ed4f1 | ||
|
|
d39b52272e | ||
|
|
1ec2bbd533 | ||
|
|
a0ae6644ee | ||
|
|
1a7da8397b | ||
|
|
dcefd7dfb4 | ||
|
|
21edb75081 | ||
|
|
a28ab3628a | ||
|
|
856465ae59 | ||
|
|
3123d4bb9b | ||
|
|
dd21183261 | ||
|
|
d67d5d8006 | ||
|
|
c4f25a77d1 | ||
|
|
52763c09f2 | ||
|
|
ef4b0bc371 | ||
|
|
b887db474e | ||
|
|
e77555a04f | ||
|
|
4a313a5f93 | ||
|
|
3d9587f128 | ||
|
|
66778efcc5 | ||
|
|
6be78ff283 | ||
|
|
5281f2ba64 | ||
|
|
69420f713f | ||
|
|
bc322ddac4 |
@@ -4,4 +4,5 @@
|
||||
.vscode
|
||||
.gitignore
|
||||
Makefile
|
||||
docs
|
||||
docs
|
||||
.eslintcache
|
||||
12
.env.example
12
.env.example
@@ -7,6 +7,8 @@
|
||||
# 调试相关配置
|
||||
# 启用pprof
|
||||
# ENABLE_PPROF=true
|
||||
# 启用调试模式
|
||||
# DEBUG=true
|
||||
|
||||
# 数据库相关配置
|
||||
# 数据库连接字符串
|
||||
@@ -41,6 +43,14 @@
|
||||
# 更新任务启用
|
||||
# UPDATE_TASK=true
|
||||
|
||||
# 对话超时设置
|
||||
# 所有请求超时时间,单位秒,默认为0,表示不限制
|
||||
# RELAY_TIMEOUT=0
|
||||
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
|
||||
# STREAMING_TIMEOUT=120
|
||||
|
||||
# Gemini 识别图片 最大图片数量
|
||||
# GEMINI_VISION_MAX_IMAGE_NUM=16
|
||||
|
||||
# 会话密钥
|
||||
# SESSION_SECRET=random_string
|
||||
@@ -58,8 +68,6 @@
|
||||
# GET_MEDIA_TOKEN_NOT_STREAM=true
|
||||
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
|
||||
# DIFY_DEBUG=true
|
||||
# 设置流式一次回复的超时时间
|
||||
# STREAMING_TIMEOUT=90
|
||||
|
||||
|
||||
# 节点类型
|
||||
|
||||
19
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
19
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
### PR 类型
|
||||
|
||||
- [ ] Bug 修复
|
||||
- [ ] 新功能
|
||||
- [ ] 文档更新
|
||||
- [ ] 其他
|
||||
|
||||
### PR 是否包含破坏性更新?
|
||||
|
||||
- [ ] 是
|
||||
- [ ] 否
|
||||
|
||||
### PR 描述
|
||||
|
||||
**请在下方详细描述您的 PR,包括目的、实现细节等。**
|
||||
|
||||
### **重要提示**
|
||||
|
||||
**所有 PR 都必须提交到 `alpha` 分支。请确保您的 PR 目标分支是 `alpha`。**
|
||||
1
.github/workflows/macos-release.yml
vendored
1
.github/workflows/macos-release.yml
vendored
@@ -26,6 +26,7 @@ jobs:
|
||||
- name: Build Frontend
|
||||
env:
|
||||
CI: ""
|
||||
NODE_OPTIONS: "--max-old-space-size=4096"
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
|
||||
21
.github/workflows/pr-target-branch-check.yml
vendored
Normal file
21
.github/workflows/pr-target-branch-check.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: Check PR Branching Strategy
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, edited]
|
||||
|
||||
jobs:
|
||||
check-branching-strategy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Enforce branching strategy
|
||||
run: |
|
||||
if [[ "${{ github.base_ref }}" == "main" ]]; then
|
||||
if [[ "${{ github.head_ref }}" != "alpha" ]]; then
|
||||
echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
|
||||
exit 1
|
||||
fi
|
||||
elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
|
||||
echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
|
||||
exit 1
|
||||
fi
|
||||
echo "Branching strategy check passed."
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -10,4 +10,5 @@ web/dist
|
||||
.env
|
||||
one-api
|
||||
.DS_Store
|
||||
tiktoken_cache
|
||||
tiktoken_cache
|
||||
.eslintcache
|
||||
@@ -2,6 +2,7 @@ FROM oven/bun:latest AS builder
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/package.json .
|
||||
COPY web/bun.lock .
|
||||
RUN bun install
|
||||
COPY ./web .
|
||||
COPY ./VERSION .
|
||||
@@ -24,8 +25,7 @@ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-
|
||||
|
||||
FROM alpine
|
||||
|
||||
RUN apk update \
|
||||
&& apk upgrade \
|
||||
RUN apk upgrade --no-cache \
|
||||
&& apk add --no-cache ca-certificates tzdata ffmpeg \
|
||||
&& update-ca-certificates
|
||||
|
||||
|
||||
240
LICENSE
240
LICENSE
@@ -1,201 +1,103 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
# **New API 许可协议 (Licensing)**
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
本项目采用**基于使用场景的双重许可 (Usage-Based Dual Licensing)** 模式。
|
||||
|
||||
1. Definitions.
|
||||
**核心原则:**
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
- **默认许可:** 本项目默认在 **GNU Affero 通用公共许可证 v3.0 (AGPLv3)** 下提供。任何用户在遵守 AGPLv3 条款和下述附加限制的前提下,均可免费使用。
|
||||
- **商业许可:** 在特定商业场景下,或当您希望获得 AGPLv3 之外的权利时,**必须**获取**商业许可证 (Commercial License)**。
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
---
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
## **1. 开源许可证 (Open Source License): AGPLv3 - 适用于基础使用**
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
- 在遵守 **AGPLv3** 条款的前提下,您可以自由地使用、修改和分发 New API。AGPLv3 的完整文本可以访问 [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html) 获取。
|
||||
- **核心义务:** AGPLv3 的一个关键要求是,如果您修改了 New API 并通过网络提供服务 (SaaS),或者分发了修改后的版本,您必须以 AGPLv3 许可证向所有用户提供相应的**完整源代码**。
|
||||
- **附加限制 (重要):** 在仅使用 AGPLv3 开源许可证的情况下,您**必须**完整保留项目代码中原有的品牌标识、LOGO 及版权声明信息。**禁止以任何形式修改、移除或遮盖**这些信息。如需移除,必须获取商业许可证。
|
||||
- 使用前请务必仔细阅读并理解 AGPLv3 的所有条款及上述附加限制。
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
## **2. 商业许可证 (Commercial License) - 适用于高级场景及闭源需求**
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
在以下任一情况下,您**必须**联系我们获取并签署一份商业许可证,才能合法使用 New API:
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
- **场景一:移除品牌和版权信息**
|
||||
您希望在您的产品或服务中移除 New API 的 LOGO、UI界面中的版权声明或其他品牌标识。
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
- **场景二:规避 AGPLv3 开源义务**
|
||||
您基于 New API 进行了修改,并希望:
|
||||
- 通过网络提供服务(SaaS),但**不希望**向您的服务用户公开您修改后的源代码。
|
||||
- 分发一个集成了 New API 的软件产品,但**不希望**以 AGPLv3 许可证发布您的产品或公开源代码。
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
- **场景三:企业政策与集成需求**
|
||||
- 您所在公司的政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件。
|
||||
- 您需要进行 OEM 集成,将 New API 作为您闭源商业产品的一部分进行再分发。
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
- **场景四:需要商业支持与保障**
|
||||
您需要 AGPLv3 未提供的商业保障,如官方技术支持等。
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
**获取商业许可:**
|
||||
请通过电子邮件 **support@quantumnous.com** 联系 New API 团队洽谈商业授权事宜。
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
## **3. 贡献 (Contributions)**
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
- 我们欢迎社区对 New API 的贡献。所有向本项目提交的贡献(例如通过 Pull Request)都将被视为在 **AGPLv3** 许可证下提供。
|
||||
- 通过向本项目提交贡献,即表示您同意您的代码以 AGPLv3 许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。
|
||||
- 您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 New API 版本中。
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
## **4. 其他条款 (Other Terms)**
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
- 关于商业许可证的具体条款、条件和价格,以双方签署的正式商业许可协议为准。
|
||||
- 项目维护者保留根据需要更新本许可政策的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
---
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
# **New API Licensing**
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
This project uses a **Usage-Based Dual Licensing** model.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
**Core Principles:**
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
- **Default License:** This project is available by default under the **GNU Affero General Public License v3.0 (AGPLv3)**. Any user may use it free of charge, provided they comply with both the AGPLv3 terms and the additional restrictions listed below.
|
||||
- **Commercial License:** For specific commercial scenarios, or if you require rights beyond those granted by AGPLv3, you **must** obtain a **Commercial License**.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
---
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
## **1. Open Source License: AGPLv3 – For Basic Usage**
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
- Under the terms of the **AGPLv3**, you are free to use, modify, and distribute New API. The complete AGPLv3 license text can be viewed at [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html).
|
||||
- **Core Obligation:** A key AGPLv3 requirement is that if you modify New API and provide it as a network service (SaaS), or distribute a modified version, you must make the **complete corresponding source code** available to all users under the AGPLv3 license.
|
||||
- **Additional Restriction (Important):** When using only the AGPLv3 open-source license, you **must** retain all original branding, logos, and copyright statements within the project’s code. **You are strictly prohibited from modifying, removing, or concealing** any such information. If you wish to remove this, you must obtain a Commercial License.
|
||||
- Please read and ensure that you fully understand all AGPLv3 terms and the above additional restriction before use.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
## **2. Commercial License – For Advanced Scenarios & Closed Source Needs**
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
You **must** contact us to obtain and sign a Commercial License in any of the following scenarios in order to legally use New API:
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
- **Scenario 1: Removal of Branding and Copyright**
|
||||
You wish to remove the New API logo, copyright statement, or other branding elements from your product or service.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
- **Scenario 2: Avoidance of AGPLv3 Open Source Obligations**
|
||||
You have modified New API and wish to:
|
||||
- Offer it as a network service (SaaS) **without** disclosing your modifications' source code to your users.
|
||||
- Distribute a software product integrated with New API **without** releasing your product under AGPLv3 or open-sourcing the code.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
- **Scenario 3: Enterprise Policy & Integration Needs**
|
||||
- Your organization’s policies, client contracts, or project requirements prohibit the use of AGPLv3-licensed software.
|
||||
- You require OEM integration and need to redistribute New API as part of your closed-source commercial product.
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
- **Scenario 4: Commercial Support and Assurances**
|
||||
You require commercial assurances not provided by AGPLv3, such as official technical support.
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
**Obtaining a Commercial License:**
|
||||
Please contact the New API team via email at **support@quantumnous.com** to discuss commercial licensing.
|
||||
|
||||
## **3. Contributions**
|
||||
|
||||
- We welcome community contributions to New API. All contributions (e.g., via Pull Request) are deemed to be provided under the **AGPLv3** license.
|
||||
- By submitting a contribution, you agree that your code is licensed to this project and all downstream users under the AGPLv3 license (regardless of whether those users ultimately operate under AGPLv3 or a Commercial License).
|
||||
- You also acknowledge and agree that your contribution may be included in New API releases distributed under a Commercial License.
|
||||
|
||||
## **4. Other Terms**
|
||||
|
||||
- The specific terms, conditions, and pricing of the Commercial License are governed by the formal commercial license agreement executed by both parties.
|
||||
- Project maintainers reserve the right to update this licensing policy as needed. Updates will be communicated via official project channels (e.g., repository, official website).
|
||||
|
||||
20
README.en.md
20
README.en.md
@@ -100,7 +100,7 @@ This version supports multiple models, please refer to [API Documentation-Relay
|
||||
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
|
||||
|
||||
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
|
||||
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 60 seconds
|
||||
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 120 seconds
|
||||
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
|
||||
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
|
||||
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
|
||||
@@ -189,6 +189,24 @@ If you have any questions, please refer to [Help and Support](https://docs.newap
|
||||
- [Issue Feedback](https://docs.newapi.pro/support/feedback-issues)
|
||||
- [FAQ](https://docs.newapi.pro/support/faq)
|
||||
|
||||
## 🤝 Trusted Partners
|
||||
|
||||
<p align="center">
|
||||
<a href="https://www.cherry-ai.com/" target="_blank"><img
|
||||
src="./docs/images/cherry-studio.svg" alt="Cherry Studio" height="58"
|
||||
/></a>
|
||||
|
||||
<a href="https://bda.pku.edu.cn/" target="_blank"><img
|
||||
src="./docs/images/pku.png" alt="Peking University" height="58"
|
||||
/></a>
|
||||
|
||||
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target="_blank"><img
|
||||
src="./docs/images/ucloud.svg" alt="UCloud" height="58"
|
||||
/></a>
|
||||
</p>
|
||||
|
||||
<p align="center"><em>No particular order</em></p>
|
||||
|
||||
## 🌟 Star History
|
||||
|
||||
[](https://star-history.com/#Calcium-Ion/new-api&Date)
|
||||
|
||||
21
README.md
21
README.md
@@ -100,7 +100,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
|
||||
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
|
||||
|
||||
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
|
||||
- `STREAMING_TIMEOUT`:流式回复超时时间,默认60秒
|
||||
- `STREAMING_TIMEOUT`:流式回复超时时间,默认120秒
|
||||
- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
|
||||
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
|
||||
- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
|
||||
@@ -180,7 +180,6 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
||||
|
||||
其他基于New API的项目:
|
||||
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
|
||||
- [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的前端美化版本
|
||||
|
||||
## 帮助支持
|
||||
|
||||
@@ -189,6 +188,24 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
||||
- [反馈问题](https://docs.newapi.pro/support/feedback-issues)
|
||||
- [常见问题](https://docs.newapi.pro/support/faq)
|
||||
|
||||
## 🤝 我们信任的合作伙伴
|
||||
|
||||
<p align="center">
|
||||
<a href="https://www.cherry-ai.com/" target="_blank"><img
|
||||
src="./docs/images/cherry-studio.svg" alt="Cherry Studio" height="58"
|
||||
/></a>
|
||||
|
||||
<a href="https://bda.pku.edu.cn/" target="_blank"><img
|
||||
src="./docs/images/pku.png" alt="北京大学" height="58"
|
||||
/></a>
|
||||
|
||||
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target="_blank"><img
|
||||
src="./docs/images/ucloud.svg" alt="UCloud 优刻得" height="58"
|
||||
/></a>
|
||||
</p>
|
||||
|
||||
<p align="center"><em>排名不分先后</em></p>
|
||||
|
||||
## 🌟 Star History
|
||||
|
||||
[](https://star-history.com/#Calcium-Ion/new-api&Date)
|
||||
|
||||
73
common/api_type.go
Normal file
73
common/api_type.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package common
|
||||
|
||||
import "one-api/constant"
|
||||
|
||||
func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType := -1
|
||||
switch channelType {
|
||||
case constant.ChannelTypeOpenAI:
|
||||
apiType = constant.APITypeOpenAI
|
||||
case constant.ChannelTypeAnthropic:
|
||||
apiType = constant.APITypeAnthropic
|
||||
case constant.ChannelTypeBaidu:
|
||||
apiType = constant.APITypeBaidu
|
||||
case constant.ChannelTypePaLM:
|
||||
apiType = constant.APITypePaLM
|
||||
case constant.ChannelTypeZhipu:
|
||||
apiType = constant.APITypeZhipu
|
||||
case constant.ChannelTypeAli:
|
||||
apiType = constant.APITypeAli
|
||||
case constant.ChannelTypeXunfei:
|
||||
apiType = constant.APITypeXunfei
|
||||
case constant.ChannelTypeAIProxyLibrary:
|
||||
apiType = constant.APITypeAIProxyLibrary
|
||||
case constant.ChannelTypeTencent:
|
||||
apiType = constant.APITypeTencent
|
||||
case constant.ChannelTypeGemini:
|
||||
apiType = constant.APITypeGemini
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
apiType = constant.APITypeZhipuV4
|
||||
case constant.ChannelTypeOllama:
|
||||
apiType = constant.APITypeOllama
|
||||
case constant.ChannelTypePerplexity:
|
||||
apiType = constant.APITypePerplexity
|
||||
case constant.ChannelTypeAws:
|
||||
apiType = constant.APITypeAws
|
||||
case constant.ChannelTypeCohere:
|
||||
apiType = constant.APITypeCohere
|
||||
case constant.ChannelTypeDify:
|
||||
apiType = constant.APITypeDify
|
||||
case constant.ChannelTypeJina:
|
||||
apiType = constant.APITypeJina
|
||||
case constant.ChannelCloudflare:
|
||||
apiType = constant.APITypeCloudflare
|
||||
case constant.ChannelTypeSiliconFlow:
|
||||
apiType = constant.APITypeSiliconFlow
|
||||
case constant.ChannelTypeVertexAi:
|
||||
apiType = constant.APITypeVertexAi
|
||||
case constant.ChannelTypeMistral:
|
||||
apiType = constant.APITypeMistral
|
||||
case constant.ChannelTypeDeepSeek:
|
||||
apiType = constant.APITypeDeepSeek
|
||||
case constant.ChannelTypeMokaAI:
|
||||
apiType = constant.APITypeMokaAI
|
||||
case constant.ChannelTypeVolcEngine:
|
||||
apiType = constant.APITypeVolcEngine
|
||||
case constant.ChannelTypeBaiduV2:
|
||||
apiType = constant.APITypeBaiduV2
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
apiType = constant.APITypeOpenRouter
|
||||
case constant.ChannelTypeXinference:
|
||||
apiType = constant.APITypeXinference
|
||||
case constant.ChannelTypeXai:
|
||||
apiType = constant.APITypeXai
|
||||
case constant.ChannelTypeCoze:
|
||||
apiType = constant.APITypeCoze
|
||||
case constant.ChannelTypeJimeng:
|
||||
apiType = constant.APITypeJimeng
|
||||
}
|
||||
if apiType == -1 {
|
||||
return constant.APITypeOpenAI, false
|
||||
}
|
||||
return apiType, true
|
||||
}
|
||||
@@ -195,105 +195,7 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelTypeUnknown = 0
|
||||
ChannelTypeOpenAI = 1
|
||||
ChannelTypeMidjourney = 2
|
||||
ChannelTypeAzure = 3
|
||||
ChannelTypeOllama = 4
|
||||
ChannelTypeMidjourneyPlus = 5
|
||||
ChannelTypeOpenAIMax = 6
|
||||
ChannelTypeOhMyGPT = 7
|
||||
ChannelTypeCustom = 8
|
||||
ChannelTypeAILS = 9
|
||||
ChannelTypeAIProxy = 10
|
||||
ChannelTypePaLM = 11
|
||||
ChannelTypeAPI2GPT = 12
|
||||
ChannelTypeAIGC2D = 13
|
||||
ChannelTypeAnthropic = 14
|
||||
ChannelTypeBaidu = 15
|
||||
ChannelTypeZhipu = 16
|
||||
ChannelTypeAli = 17
|
||||
ChannelTypeXunfei = 18
|
||||
ChannelType360 = 19
|
||||
ChannelTypeOpenRouter = 20
|
||||
ChannelTypeAIProxyLibrary = 21
|
||||
ChannelTypeFastGPT = 22
|
||||
ChannelTypeTencent = 23
|
||||
ChannelTypeGemini = 24
|
||||
ChannelTypeMoonshot = 25
|
||||
ChannelTypeZhipu_v4 = 26
|
||||
ChannelTypePerplexity = 27
|
||||
ChannelTypeLingYiWanWu = 31
|
||||
ChannelTypeAws = 33
|
||||
ChannelTypeCohere = 34
|
||||
ChannelTypeMiniMax = 35
|
||||
ChannelTypeSunoAPI = 36
|
||||
ChannelTypeDify = 37
|
||||
ChannelTypeJina = 38
|
||||
ChannelCloudflare = 39
|
||||
ChannelTypeSiliconFlow = 40
|
||||
ChannelTypeVertexAi = 41
|
||||
ChannelTypeMistral = 42
|
||||
ChannelTypeDeepSeek = 43
|
||||
ChannelTypeMokaAI = 44
|
||||
ChannelTypeVolcEngine = 45
|
||||
ChannelTypeBaiduV2 = 46
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
TopUpStatusPending = "pending"
|
||||
TopUpStatusSuccess = "success"
|
||||
TopUpStatusExpired = "expired"
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://api.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.tencentcloudapi.com", //23
|
||||
"https://generativelanguage.googleapis.com", //24
|
||||
"https://api.moonshot.cn", //25
|
||||
"https://open.bigmodel.cn", //26
|
||||
"https://api.perplexity.ai", //27
|
||||
"", //28
|
||||
"", //29
|
||||
"", //30
|
||||
"https://api.lingyiwanwu.com", //31
|
||||
"", //32
|
||||
"", //33
|
||||
"https://api.cohere.ai", //34
|
||||
"https://api.minimax.chat", //35
|
||||
"", //36
|
||||
"https://api.dify.ai", //37
|
||||
"https://api.jina.ai", //38
|
||||
"https://api.cloudflare.com", //39
|
||||
"https://api.siliconflow.cn", //40
|
||||
"", //41
|
||||
"https://api.mistral.ai", //42
|
||||
"https://api.deepseek.com", //43
|
||||
"https://api.moka.ai", //44
|
||||
"https://ark.cn-beijing.volces.com", //45
|
||||
"https://qianfan.baidubce.com", //46
|
||||
"", //47
|
||||
"https://api.x.ai", //48
|
||||
"https://api.coze.cn", //49
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type stringWriter interface {
|
||||
@@ -52,6 +53,8 @@ type CustomEvent struct {
|
||||
Id string
|
||||
Retry uint
|
||||
Data interface{}
|
||||
|
||||
Mutex sync.Mutex
|
||||
}
|
||||
|
||||
func encode(writer io.Writer, event CustomEvent) error {
|
||||
@@ -73,6 +76,8 @@ func (r CustomEvent) Render(w http.ResponseWriter) error {
|
||||
}
|
||||
|
||||
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
|
||||
r.Mutex.Lock()
|
||||
defer r.Mutex.Unlock()
|
||||
header := w.Header()
|
||||
header["Content-Type"] = contentType
|
||||
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
package common
|
||||
|
||||
const (
|
||||
DatabaseTypeMySQL = "mysql"
|
||||
DatabaseTypeSQLite = "sqlite"
|
||||
DatabaseTypePostgreSQL = "postgres"
|
||||
)
|
||||
|
||||
var UsingSQLite = false
|
||||
var UsingPostgreSQL = false
|
||||
var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
|
||||
var UsingMySQL = false
|
||||
var UsingClickHouse = false
|
||||
|
||||
|
||||
41
common/endpoint_type.go
Normal file
41
common/endpoint_type.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package common
|
||||
|
||||
import "one-api/constant"
|
||||
|
||||
// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
|
||||
func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
|
||||
var endpointTypes []constant.EndpointType
|
||||
switch channelType {
|
||||
case constant.ChannelTypeJina:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
|
||||
//case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
|
||||
//case constant.ChannelTypeSunoAPI:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
|
||||
//case constant.ChannelTypeKling:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
|
||||
//case constant.ChannelTypeJimeng:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
|
||||
case constant.ChannelTypeAws:
|
||||
fallthrough
|
||||
case constant.ChannelTypeAnthropic:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
|
||||
case constant.ChannelTypeVertexAi:
|
||||
fallthrough
|
||||
case constant.ChannelTypeGemini:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
|
||||
case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||
default:
|
||||
if IsOpenAIResponseOnlyModel(modelName) {
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
|
||||
} else {
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||
}
|
||||
}
|
||||
if IsImageGenerationModel(modelName) {
|
||||
// add to first
|
||||
endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
|
||||
}
|
||||
return endpointTypes
|
||||
}
|
||||
@@ -2,10 +2,12 @@ package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const KeyRequestBody = "key_request_body"
|
||||
@@ -31,7 +33,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
}
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
err = Unmarshal(requestBody, &v)
|
||||
} else {
|
||||
// skip for now
|
||||
// TODO: someday non json request have variant model, we will need to implementation this
|
||||
@@ -43,3 +45,67 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
|
||||
c.Set(string(key), value)
|
||||
}
|
||||
|
||||
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
|
||||
return c.Get(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
|
||||
return c.GetString(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
|
||||
return c.GetInt(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
|
||||
return c.GetBool(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
|
||||
return c.GetStringSlice(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
|
||||
return c.GetStringMap(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
|
||||
return c.GetTime(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
|
||||
if value, ok := c.Get(string(key)); ok {
|
||||
if v, ok := value.(T); ok {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
var t T
|
||||
return t, false
|
||||
}
|
||||
|
||||
func ApiError(c *gin.Context, err error) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func ApiErrorMsg(c *gin.Context, msg string) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": msg,
|
||||
})
|
||||
}
|
||||
|
||||
func ApiSuccess(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
34
common/hash.go
Normal file
34
common/hash.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
func Sha256Raw(data []byte) []byte {
|
||||
h := sha256.New()
|
||||
h.Write(data)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func Sha1Raw(data []byte) []byte {
|
||||
h := sha1.New()
|
||||
h.Write(data)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func Sha1(data []byte) string {
|
||||
return hex.EncodeToString(Sha1Raw(data))
|
||||
}
|
||||
|
||||
func HmacSha256Raw(message, key []byte) []byte {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(message)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func HmacSha256(message, key string) string {
|
||||
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
|
||||
}
|
||||
57
common/http.go
Normal file
57
common/http.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CloseResponseBodyGracefully(httpResponse *http.Response) {
|
||||
if httpResponse == nil || httpResponse.Body == nil {
|
||||
return
|
||||
}
|
||||
err := httpResponse.Body.Close()
|
||||
if err != nil {
|
||||
SysError("failed to close response body: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
|
||||
if c.Writer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
body := io.NopCloser(bytes.NewBuffer(data))
|
||||
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
if src != nil {
|
||||
for k, v := range src.Header {
|
||||
// avoid setting Content-Length
|
||||
if k == "Content-Length" {
|
||||
continue
|
||||
}
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
}
|
||||
|
||||
// set Content-Length header manually BEFORE calling WriteHeader
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
|
||||
// Write header with status code (this sends the headers)
|
||||
if src != nil {
|
||||
c.Writer.WriteHeader(src.StatusCode)
|
||||
} else {
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
_, err := io.Copy(c.Writer, body)
|
||||
if err != nil {
|
||||
LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"one-api/constant"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
@@ -24,7 +25,7 @@ func printHelp() {
|
||||
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
||||
}
|
||||
|
||||
func LoadEnv() {
|
||||
func InitEnv() {
|
||||
flag.Parse()
|
||||
|
||||
if *PrintVersion {
|
||||
@@ -95,4 +96,25 @@ func LoadEnv() {
|
||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||
|
||||
initConstantEnv()
|
||||
}
|
||||
|
||||
func initConstantEnv() {
|
||||
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120)
|
||||
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||
constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
}
|
||||
|
||||
@@ -5,14 +5,18 @@ import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
func DecodeJson(data []byte, v any) error {
|
||||
return json.NewDecoder(bytes.NewReader(data)).Decode(v)
|
||||
func Unmarshal(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
func DecodeJsonStr(data string, v any) error {
|
||||
return DecodeJson(StringToByteSlice(data), v)
|
||||
func UnmarshalJsonStr(data string, v any) error {
|
||||
return json.Unmarshal(StringToByteSlice(data), v)
|
||||
}
|
||||
|
||||
func EncodeJson(v any) ([]byte, error) {
|
||||
func DecodeJson(reader *bytes.Reader, v any) error {
|
||||
return json.NewDecoder(reader).Decode(v)
|
||||
}
|
||||
|
||||
func Marshal(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
@@ -75,6 +75,9 @@ func logHelper(ctx context.Context, level string, msg string) {
|
||||
writer = gin.DefaultWriter
|
||||
}
|
||||
id := ctx.Value(RequestIdKey)
|
||||
if id == nil {
|
||||
id = "SYSTEM"
|
||||
}
|
||||
now := time.Now()
|
||||
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
||||
logCount++ // we don't need accurate count, so no lock here
|
||||
|
||||
42
common/model.go
Normal file
42
common/model.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package common
|
||||
|
||||
import "strings"
|
||||
|
||||
var (
|
||||
// OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
|
||||
OpenAIResponseOnlyModels = []string{
|
||||
"o3-pro",
|
||||
"o3-deep-research",
|
||||
"o4-mini-deep-research",
|
||||
}
|
||||
ImageGenerationModels = []string{
|
||||
"dall-e-3",
|
||||
"dall-e-2",
|
||||
"gpt-image-1",
|
||||
"prefix:imagen-",
|
||||
"flux-",
|
||||
"flux.1-",
|
||||
}
|
||||
)
|
||||
|
||||
func IsOpenAIResponseOnlyModel(modelName string) bool {
|
||||
for _, m := range OpenAIResponseOnlyModels {
|
||||
if strings.Contains(modelName, m) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsImageGenerationModel(modelName string) bool {
|
||||
modelName = strings.ToLower(modelName)
|
||||
for _, m := range ImageGenerationModels {
|
||||
if strings.Contains(modelName, m) {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
82
common/page_info.go
Normal file
82
common/page_info.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type PageInfo struct {
|
||||
Page int `json:"page"` // page num 页码
|
||||
PageSize int `json:"page_size"` // page size 页大小
|
||||
|
||||
Total int `json:"total"` // 总条数,后设置
|
||||
Items any `json:"items"` // 数据,后设置
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetStartIdx() int {
|
||||
return (p.Page - 1) * p.PageSize
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetEndIdx() int {
|
||||
return p.Page * p.PageSize
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetPageSize() int {
|
||||
return p.PageSize
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetPage() int {
|
||||
return p.Page
|
||||
}
|
||||
|
||||
func (p *PageInfo) SetTotal(total int) {
|
||||
p.Total = total
|
||||
}
|
||||
|
||||
func (p *PageInfo) SetItems(items any) {
|
||||
p.Items = items
|
||||
}
|
||||
|
||||
func GetPageQuery(c *gin.Context) *PageInfo {
|
||||
pageInfo := &PageInfo{}
|
||||
// 手动获取并处理每个参数
|
||||
if page, err := strconv.Atoi(c.Query("p")); err == nil {
|
||||
pageInfo.Page = page
|
||||
}
|
||||
if pageSize, err := strconv.Atoi(c.Query("page_size")); err == nil {
|
||||
pageInfo.PageSize = pageSize
|
||||
}
|
||||
if pageInfo.Page < 1 {
|
||||
// 兼容
|
||||
page, _ := strconv.Atoi(c.Query("p"))
|
||||
if page != 0 {
|
||||
pageInfo.Page = page
|
||||
} else {
|
||||
pageInfo.Page = 1
|
||||
}
|
||||
}
|
||||
|
||||
if pageInfo.PageSize == 0 {
|
||||
// 兼容
|
||||
pageSize, _ := strconv.Atoi(c.Query("ps"))
|
||||
if pageSize != 0 {
|
||||
pageInfo.PageSize = pageSize
|
||||
}
|
||||
if pageInfo.PageSize == 0 {
|
||||
pageSize, _ = strconv.Atoi(c.Query("size")) // token page
|
||||
if pageSize != 0 {
|
||||
pageInfo.PageSize = pageSize
|
||||
}
|
||||
}
|
||||
if pageInfo.PageSize == 0 {
|
||||
pageInfo.PageSize = ItemsPerPage
|
||||
}
|
||||
}
|
||||
|
||||
if pageInfo.PageSize > 100 {
|
||||
pageInfo.PageSize = 100
|
||||
}
|
||||
|
||||
return pageInfo
|
||||
}
|
||||
@@ -16,6 +16,10 @@ import (
|
||||
var RDB *redis.Client
|
||||
var RedisEnabled = true
|
||||
|
||||
func RedisKeyCacheSeconds() int {
|
||||
return SyncFrequency
|
||||
}
|
||||
|
||||
// InitRedisClient This function is called after init()
|
||||
func InitRedisClient() (err error) {
|
||||
if os.Getenv("REDIS_CONN_STRING") == "" {
|
||||
@@ -141,7 +145,11 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
||||
|
||||
txn := RDB.TxPipeline()
|
||||
txn.HSet(ctx, key, data)
|
||||
txn.Expire(ctx, key, expiration)
|
||||
|
||||
// 只有在 expiration 大于 0 时才设置过期时间
|
||||
if expiration > 0 {
|
||||
txn.Expire(ctx, key, expiration)
|
||||
}
|
||||
|
||||
_, err := txn.Exec(ctx)
|
||||
if err != nil {
|
||||
|
||||
132
common/str.go
132
common/str.go
@@ -1,9 +1,13 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
@@ -31,16 +35,30 @@ func MapToJsonStr(m map[string]interface{}) string {
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
func StrToMap(str string) map[string]interface{} {
|
||||
func StrToMap(str string) (map[string]interface{}, error) {
|
||||
m := make(map[string]interface{})
|
||||
err := json.Unmarshal([]byte(str), &m)
|
||||
err := Unmarshal([]byte(str), &m)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
return m
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func IsJsonStr(str string) bool {
|
||||
func StrToJsonArray(str string) ([]interface{}, error) {
|
||||
var js []interface{}
|
||||
err := json.Unmarshal([]byte(str), &js)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return js, nil
|
||||
}
|
||||
|
||||
func IsJsonArray(str string) bool {
|
||||
var js []interface{}
|
||||
return json.Unmarshal([]byte(str), &js) == nil
|
||||
}
|
||||
|
||||
func IsJsonObject(str string) bool {
|
||||
var js map[string]interface{}
|
||||
return json.Unmarshal([]byte(str), &js) == nil
|
||||
}
|
||||
@@ -68,3 +86,107 @@ func StringToByteSlice(s string) []byte {
|
||||
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
|
||||
return *(*[]byte)(unsafe.Pointer(&tmp2))
|
||||
}
|
||||
|
||||
func EncodeBase64(str string) string {
|
||||
return base64.StdEncoding.EncodeToString([]byte(str))
|
||||
}
|
||||
|
||||
func GetJsonString(data any) string {
|
||||
if data == nil {
|
||||
return ""
|
||||
}
|
||||
b, _ := json.Marshal(data)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string
|
||||
// Example:
|
||||
// http://example.com -> http://***.com
|
||||
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
|
||||
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
|
||||
// 192.168.1.1 -> ***.***.***.***
|
||||
func MaskSensitiveInfo(str string) string {
|
||||
// Mask URLs
|
||||
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
||||
str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return urlStr
|
||||
}
|
||||
|
||||
host := u.Host
|
||||
if host == "" {
|
||||
return urlStr
|
||||
}
|
||||
|
||||
// Split host by dots
|
||||
parts := strings.Split(host, ".")
|
||||
if len(parts) < 2 {
|
||||
// If less than 2 parts, just mask the whole host
|
||||
return u.Scheme + "://***" + u.Path
|
||||
}
|
||||
|
||||
// Keep the TLD (Top Level Domain) and mask the rest
|
||||
var maskedHost string
|
||||
if len(parts) == 2 {
|
||||
// example.com -> ***.com
|
||||
maskedHost = "***." + parts[len(parts)-1]
|
||||
} else {
|
||||
// Handle cases like sub.domain.co.uk or api.example.com
|
||||
// Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.)
|
||||
lastPart := parts[len(parts)-1]
|
||||
secondLastPart := parts[len(parts)-2]
|
||||
|
||||
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
|
||||
// Likely country code TLD like co.uk, com.cn
|
||||
maskedHost = "***." + secondLastPart + "." + lastPart
|
||||
} else {
|
||||
// Regular TLD like .com, .org
|
||||
maskedHost = "***." + lastPart
|
||||
}
|
||||
}
|
||||
|
||||
result := u.Scheme + "://" + maskedHost
|
||||
|
||||
// Mask path
|
||||
if u.Path != "" && u.Path != "/" {
|
||||
pathParts := strings.Split(strings.Trim(u.Path, "/"), "/")
|
||||
maskedPathParts := make([]string, len(pathParts))
|
||||
for i := range pathParts {
|
||||
if pathParts[i] != "" {
|
||||
maskedPathParts[i] = "***"
|
||||
}
|
||||
}
|
||||
if len(maskedPathParts) > 0 {
|
||||
result += "/" + strings.Join(maskedPathParts, "/")
|
||||
}
|
||||
} else if u.Path == "/" {
|
||||
result += "/"
|
||||
}
|
||||
|
||||
// Mask query parameters
|
||||
if u.RawQuery != "" {
|
||||
values, err := url.ParseQuery(u.RawQuery)
|
||||
if err != nil {
|
||||
// If can't parse query, just mask the whole query string
|
||||
result += "?***"
|
||||
} else {
|
||||
maskedParams := make([]string, 0, len(values))
|
||||
for key := range values {
|
||||
maskedParams = append(maskedParams, key+"=***")
|
||||
}
|
||||
if len(maskedParams) > 0 {
|
||||
result += "?" + strings.Join(maskedParams, "&")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
})
|
||||
|
||||
// Mask IP addresses
|
||||
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
||||
str = ipPattern.ReplaceAllString(str, "***.***.***.***")
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
@@ -249,13 +250,55 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
|
||||
}
|
||||
|
||||
// GetAudioDuration returns the duration of an audio file in seconds.
|
||||
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
|
||||
func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
|
||||
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
|
||||
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
|
||||
output, err := c.Output()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||
}
|
||||
durationStr := string(bytes.TrimSpace(output))
|
||||
if durationStr == "N/A" {
|
||||
// Create a temporary output file name
|
||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to create temporary file")
|
||||
}
|
||||
tmpName := tmpFp.Name()
|
||||
// Close immediately so ffmpeg can open the file on Windows.
|
||||
_ = tmpFp.Close()
|
||||
defer os.Remove(tmpName)
|
||||
|
||||
return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
|
||||
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
||||
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
||||
if err := ffmpegCmd.Run(); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
||||
}
|
||||
|
||||
// Recalculate the duration of the new file
|
||||
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
||||
output, err := c.Output()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
||||
}
|
||||
durationStr = string(bytes.TrimSpace(output))
|
||||
}
|
||||
return strconv.ParseFloat(durationStr, 64)
|
||||
}
|
||||
|
||||
// BuildURL concatenates base and endpoint, returns the complete url string
|
||||
func BuildURL(base string, endpoint string) string {
|
||||
u, err := url.Parse(base)
|
||||
if err != nil {
|
||||
return base + endpoint
|
||||
}
|
||||
end := endpoint
|
||||
if end == "" {
|
||||
end = "/"
|
||||
}
|
||||
ref, err := url.Parse(end)
|
||||
if err != nil {
|
||||
return base + endpoint
|
||||
}
|
||||
return u.ResolveReference(ref).String()
|
||||
}
|
||||
|
||||
26
constant/README.md
Normal file
26
constant/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# constant 包 (`/constant`)
|
||||
|
||||
该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
|
||||
|
||||
## 当前文件
|
||||
|
||||
| 文件 | 说明 |
|
||||
|----------------------|---------------------------------------------------------------------|
|
||||
| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 |
|
||||
| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 |
|
||||
| `channel_setting.go` | Channel 级别的设置键,如 `proxy`、`force_format` 等。 |
|
||||
| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量(请求时间、Token/Channel/User 相关信息等)。 |
|
||||
| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 |
|
||||
| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 |
|
||||
| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 |
|
||||
| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 |
|
||||
| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 |
|
||||
| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 |
|
||||
|
||||
## 使用约定
|
||||
|
||||
1. `constant` 包**只能被其他包引用**(import),**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**。
|
||||
2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
|
||||
3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
|
||||
|
||||
> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。
|
||||
35
constant/api_type.go
Normal file
35
constant/api_type.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
APITypeOpenAI = iota
|
||||
APITypeAnthropic
|
||||
APITypePaLM
|
||||
APITypeBaidu
|
||||
APITypeZhipu
|
||||
APITypeAli
|
||||
APITypeXunfei
|
||||
APITypeAIProxyLibrary
|
||||
APITypeTencent
|
||||
APITypeGemini
|
||||
APITypeZhipuV4
|
||||
APITypeOllama
|
||||
APITypePerplexity
|
||||
APITypeAws
|
||||
APITypeCohere
|
||||
APITypeDify
|
||||
APITypeJina
|
||||
APITypeCloudflare
|
||||
APITypeSiliconFlow
|
||||
APITypeVertexAi
|
||||
APITypeMistral
|
||||
APITypeDeepSeek
|
||||
APITypeMokaAI
|
||||
APITypeVolcEngine
|
||||
APITypeBaiduV2
|
||||
APITypeOpenRouter
|
||||
APITypeXinference
|
||||
APITypeXai
|
||||
APITypeCoze
|
||||
APITypeJimeng
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
@@ -1,14 +1,5 @@
|
||||
package constant
|
||||
|
||||
import "one-api/common"
|
||||
|
||||
var (
|
||||
TokenCacheSeconds = common.SyncFrequency
|
||||
UserId2GroupCacheSeconds = common.SyncFrequency
|
||||
UserId2QuotaCacheSeconds = common.SyncFrequency
|
||||
UserId2StatusCacheSeconds = common.SyncFrequency
|
||||
)
|
||||
|
||||
// Cache keys
|
||||
const (
|
||||
UserGroupKeyFmt = "user_group:%d"
|
||||
|
||||
111
constant/channel.go
Normal file
111
constant/channel.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
ChannelTypeUnknown = 0
|
||||
ChannelTypeOpenAI = 1
|
||||
ChannelTypeMidjourney = 2
|
||||
ChannelTypeAzure = 3
|
||||
ChannelTypeOllama = 4
|
||||
ChannelTypeMidjourneyPlus = 5
|
||||
ChannelTypeOpenAIMax = 6
|
||||
ChannelTypeOhMyGPT = 7
|
||||
ChannelTypeCustom = 8
|
||||
ChannelTypeAILS = 9
|
||||
ChannelTypeAIProxy = 10
|
||||
ChannelTypePaLM = 11
|
||||
ChannelTypeAPI2GPT = 12
|
||||
ChannelTypeAIGC2D = 13
|
||||
ChannelTypeAnthropic = 14
|
||||
ChannelTypeBaidu = 15
|
||||
ChannelTypeZhipu = 16
|
||||
ChannelTypeAli = 17
|
||||
ChannelTypeXunfei = 18
|
||||
ChannelType360 = 19
|
||||
ChannelTypeOpenRouter = 20
|
||||
ChannelTypeAIProxyLibrary = 21
|
||||
ChannelTypeFastGPT = 22
|
||||
ChannelTypeTencent = 23
|
||||
ChannelTypeGemini = 24
|
||||
ChannelTypeMoonshot = 25
|
||||
ChannelTypeZhipu_v4 = 26
|
||||
ChannelTypePerplexity = 27
|
||||
ChannelTypeLingYiWanWu = 31
|
||||
ChannelTypeAws = 33
|
||||
ChannelTypeCohere = 34
|
||||
ChannelTypeMiniMax = 35
|
||||
ChannelTypeSunoAPI = 36
|
||||
ChannelTypeDify = 37
|
||||
ChannelTypeJina = 38
|
||||
ChannelCloudflare = 39
|
||||
ChannelTypeSiliconFlow = 40
|
||||
ChannelTypeVertexAi = 41
|
||||
ChannelTypeMistral = 42
|
||||
ChannelTypeDeepSeek = 43
|
||||
ChannelTypeMokaAI = 44
|
||||
ChannelTypeVolcEngine = 45
|
||||
ChannelTypeBaiduV2 = 46
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeKling = 50
|
||||
ChannelTypeJimeng = 51
|
||||
ChannelTypeVidu = 52
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://api.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.tencentcloudapi.com", //23
|
||||
"https://generativelanguage.googleapis.com", //24
|
||||
"https://api.moonshot.cn", //25
|
||||
"https://open.bigmodel.cn", //26
|
||||
"https://api.perplexity.ai", //27
|
||||
"", //28
|
||||
"", //29
|
||||
"", //30
|
||||
"https://api.lingyiwanwu.com", //31
|
||||
"", //32
|
||||
"", //33
|
||||
"https://api.cohere.ai", //34
|
||||
"https://api.minimax.chat", //35
|
||||
"", //36
|
||||
"https://api.dify.ai", //37
|
||||
"https://api.jina.ai", //38
|
||||
"https://api.cloudflare.com", //39
|
||||
"https://api.siliconflow.cn", //40
|
||||
"", //41
|
||||
"https://api.mistral.ai", //42
|
||||
"https://api.deepseek.com", //43
|
||||
"https://api.moka.ai", //44
|
||||
"https://ark.cn-beijing.volces.com", //45
|
||||
"https://qianfan.baidubce.com", //46
|
||||
"", //47
|
||||
"https://api.x.ai", //48
|
||||
"https://api.coze.cn", //49
|
||||
"https://api.klingai.com", //50
|
||||
"https://visual.volcengineapi.com", //51
|
||||
"https://api.vidu.cn", //52
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package constant
|
||||
|
||||
var (
|
||||
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
|
||||
ChanelSettingProxy = "proxy" // Proxy 代理
|
||||
ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
|
||||
)
|
||||
@@ -1,10 +1,44 @@
|
||||
package constant
|
||||
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
ContextKeyRequestStartTime = "request_start_time"
|
||||
ContextKeyUserSetting = "user_setting"
|
||||
ContextKeyUserQuota = "user_quota"
|
||||
ContextKeyUserStatus = "user_status"
|
||||
ContextKeyUserEmail = "user_email"
|
||||
ContextKeyUserGroup = "user_group"
|
||||
ContextKeyOriginalModel ContextKey = "original_model"
|
||||
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
||||
|
||||
/* token related keys */
|
||||
ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota"
|
||||
ContextKeyTokenKey ContextKey = "token_key"
|
||||
ContextKeyTokenId ContextKey = "token_id"
|
||||
ContextKeyTokenGroup ContextKey = "token_group"
|
||||
ContextKeyTokenAllowIps ContextKey = "allow_ips"
|
||||
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
||||
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
||||
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||
|
||||
/* channel related keys */
|
||||
ContextKeyChannelId ContextKey = "channel_id"
|
||||
ContextKeyChannelName ContextKey = "channel_name"
|
||||
ContextKeyChannelCreateTime ContextKey = "channel_create_time"
|
||||
ContextKeyChannelBaseUrl ContextKey = "base_url"
|
||||
ContextKeyChannelType ContextKey = "channel_type"
|
||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||
ContextKeyChannelParamOverride ContextKey = "param_override"
|
||||
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
||||
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
||||
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
||||
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
|
||||
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
|
||||
ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
|
||||
ContextKeyChannelKey ContextKey = "channel_key"
|
||||
|
||||
/* user related keys */
|
||||
ContextKeyUserId ContextKey = "id"
|
||||
ContextKeyUserSetting ContextKey = "user_setting"
|
||||
ContextKeyUserQuota ContextKey = "user_quota"
|
||||
ContextKeyUserStatus ContextKey = "user_status"
|
||||
ContextKeyUserEmail ContextKey = "user_email"
|
||||
ContextKeyUserGroup ContextKey = "user_group"
|
||||
ContextKeyUsingGroup ContextKey = "group"
|
||||
ContextKeyUserName ContextKey = "username"
|
||||
)
|
||||
|
||||
16
constant/endpoint_type.go
Normal file
16
constant/endpoint_type.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package constant
|
||||
|
||||
type EndpointType string
|
||||
|
||||
const (
|
||||
EndpointTypeOpenAI EndpointType = "openai"
|
||||
EndpointTypeOpenAIResponse EndpointType = "openai-response"
|
||||
EndpointTypeAnthropic EndpointType = "anthropic"
|
||||
EndpointTypeGemini EndpointType = "gemini"
|
||||
EndpointTypeJinaRerank EndpointType = "jina-rerank"
|
||||
EndpointTypeImageGeneration EndpointType = "image-generation"
|
||||
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
|
||||
//EndpointTypeSuno EndpointType = "suno-proxy"
|
||||
//EndpointTypeKling EndpointType = "kling"
|
||||
//EndpointTypeJimeng EndpointType = "jimeng"
|
||||
)
|
||||
@@ -1,9 +1,5 @@
|
||||
package constant
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
var StreamingTimeout int
|
||||
var DifyDebug bool
|
||||
var MaxFileDownloadMB int
|
||||
@@ -17,39 +13,3 @@ var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
var GenerateDefaultToken bool
|
||||
var ErrorLogEnabled bool
|
||||
|
||||
//var GeminiModelMap = map[string]string{
|
||||
// "gemini-1.0-pro": "v1",
|
||||
//}
|
||||
|
||||
func InitEnv() {
|
||||
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
|
||||
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
|
||||
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
||||
//if modelVersionMapStr == "" {
|
||||
// return
|
||||
//}
|
||||
//for _, pair := range strings.Split(modelVersionMapStr, ",") {
|
||||
// parts := strings.Split(pair, ":")
|
||||
// if len(parts) == 2 {
|
||||
// GeminiModelMap[parts[0]] = parts[1]
|
||||
// } else {
|
||||
// common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ const (
|
||||
MjActionPan = "PAN"
|
||||
MjActionSwapFace = "SWAP_FACE"
|
||||
MjActionUpload = "UPLOAD"
|
||||
MjActionVideo = "VIDEO"
|
||||
MjActionEdits = "EDITS"
|
||||
)
|
||||
|
||||
var MidjourneyModel2Action = map[string]string{
|
||||
@@ -41,4 +43,6 @@ var MidjourneyModel2Action = map[string]string{
|
||||
"mj_pan": MjActionPan,
|
||||
"swap_face": MjActionSwapFace,
|
||||
"mj_upload": MjActionUpload,
|
||||
"mj_video": MjActionVideo,
|
||||
"mj_edits": MjActionEdits,
|
||||
}
|
||||
|
||||
8
constant/multi_key_mode.go
Normal file
8
constant/multi_key_mode.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package constant
|
||||
|
||||
type MultiKeyMode string
|
||||
|
||||
const (
|
||||
MultiKeyModeRandom MultiKeyMode = "random" // 随机
|
||||
MultiKeyModePolling MultiKeyMode = "polling" // 轮询
|
||||
)
|
||||
@@ -10,6 +10,9 @@ const (
|
||||
const (
|
||||
SunoActionMusic = "MUSIC"
|
||||
SunoActionLyrics = "LYRICS"
|
||||
|
||||
TaskActionGenerate = "generate"
|
||||
TaskActionTextGenerate = "textGenerate"
|
||||
)
|
||||
|
||||
var SunoModel2Action = map[string]string{
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
package constant
|
||||
|
||||
var (
|
||||
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
|
||||
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
|
||||
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
|
||||
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
|
||||
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
|
||||
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
||||
)
|
||||
|
||||
var (
|
||||
NotifyTypeEmail = "email" // Email 邮件
|
||||
NotifyTypeWebhook = "webhook" // Webhook
|
||||
)
|
||||
@@ -7,11 +7,16 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -304,34 +309,70 @@ func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.moonshot.cn/v1/users/me/balance"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
type MoonshotBalanceData struct {
|
||||
AvailableBalance float64 `json:"available_balance"`
|
||||
VoucherBalance float64 `json:"voucher_balance"`
|
||||
CashBalance float64 `json:"cash_balance"`
|
||||
}
|
||||
|
||||
type MoonshotBalanceResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data MoonshotBalanceData `json:"data"`
|
||||
Scode string `json:"scode"`
|
||||
Status bool `json:"status"`
|
||||
}
|
||||
|
||||
response := MoonshotBalanceResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !response.Status || response.Code != 0 {
|
||||
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
||||
}
|
||||
availableBalanceCny := response.Data.AvailableBalance
|
||||
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
|
||||
channel.UpdateBalance(availableBalanceUsd)
|
||||
return availableBalanceUsd, nil
|
||||
}
|
||||
|
||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() == "" {
|
||||
channel.BaseURL = &baseURL
|
||||
}
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeOpenAI:
|
||||
case constant.ChannelTypeOpenAI:
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
case common.ChannelTypeAzure:
|
||||
case constant.ChannelTypeAzure:
|
||||
return 0, errors.New("尚未实现")
|
||||
case common.ChannelTypeCustom:
|
||||
case constant.ChannelTypeCustom:
|
||||
baseURL = channel.GetBaseURL()
|
||||
//case common.ChannelTypeOpenAISB:
|
||||
// return updateChannelOpenAISBBalance(channel)
|
||||
case common.ChannelTypeAIProxy:
|
||||
case constant.ChannelTypeAIProxy:
|
||||
return updateChannelAIProxyBalance(channel)
|
||||
case common.ChannelTypeAPI2GPT:
|
||||
case constant.ChannelTypeAPI2GPT:
|
||||
return updateChannelAPI2GPTBalance(channel)
|
||||
case common.ChannelTypeAIGC2D:
|
||||
case constant.ChannelTypeAIGC2D:
|
||||
return updateChannelAIGC2DBalance(channel)
|
||||
case common.ChannelTypeSiliconFlow:
|
||||
case constant.ChannelTypeSiliconFlow:
|
||||
return updateChannelSiliconFlowBalance(channel)
|
||||
case common.ChannelTypeDeepSeek:
|
||||
case constant.ChannelTypeDeepSeek:
|
||||
return updateChannelDeepSeekBalance(channel)
|
||||
case common.ChannelTypeOpenRouter:
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
return updateChannelOpenRouterBalance(channel)
|
||||
case constant.ChannelTypeMoonshot:
|
||||
return updateChannelMoonshotBalance(channel)
|
||||
default:
|
||||
return 0, errors.New("尚未实现")
|
||||
}
|
||||
@@ -370,26 +411,24 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
func UpdateChannelBalance(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
channel, err := model.CacheGetChannel(id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
"message": "多密钥渠道不支持余额查询",
|
||||
})
|
||||
return
|
||||
}
|
||||
balance, err := updateChannelBalance(channel)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -397,7 +436,6 @@ func UpdateChannelBalance(c *gin.Context) {
|
||||
"message": "",
|
||||
"balance": balance,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func updateAllChannelsBalance() error {
|
||||
@@ -409,6 +447,9 @@ func updateAllChannelsBalance() error {
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
continue // skip multi-key channels
|
||||
}
|
||||
// TODO: support Azure
|
||||
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
|
||||
// continue
|
||||
@@ -419,7 +460,7 @@ func updateAllChannelsBalance() error {
|
||||
} else {
|
||||
// err is nil & balance <= 0 means quota is used up
|
||||
if balance <= 0 {
|
||||
service.DisableChannel(channel.Id, channel.Name, "余额不足")
|
||||
service.DisableChannel(*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, "", channel.GetAutoBan()), "余额不足")
|
||||
}
|
||||
}
|
||||
time.Sleep(common.RequestInterval)
|
||||
@@ -431,10 +472,7 @@ func UpdateAllChannelsBalance(c *gin.Context) {
|
||||
// TODO: make it async
|
||||
err := updateAllChannelsBalance()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -11,14 +11,16 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -29,16 +31,49 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||
type testResult struct {
|
||||
context *gin.Context
|
||||
localErr error
|
||||
newAPIError *types.NewAPIError
|
||||
}
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
tik := time.Now()
|
||||
if channel.Type == common.ChannelTypeMidjourney {
|
||||
return errors.New("midjourney channel test is not supported"), nil
|
||||
if channel.Type == constant.ChannelTypeMidjourney {
|
||||
return testResult{
|
||||
localErr: errors.New("midjourney channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == common.ChannelTypeMidjourneyPlus {
|
||||
return errors.New("midjourney plus channel test is not supported!!!"), nil
|
||||
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
||||
return testResult{
|
||||
localErr: errors.New("midjourney plus channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == common.ChannelTypeSunoAPI {
|
||||
return errors.New("suno channel test is not supported"), nil
|
||||
if channel.Type == constant.ChannelTypeSunoAPI {
|
||||
return testResult{
|
||||
localErr: errors.New("suno channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeKling {
|
||||
return testResult{
|
||||
localErr: errors.New("kling channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeJimeng {
|
||||
return testResult{
|
||||
localErr: errors.New("jimeng channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeVidu {
|
||||
return testResult{
|
||||
localErr: errors.New("vidu channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -50,7 +85,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
||||
strings.Contains(testModel, "bge-") || // bge 系列模型
|
||||
strings.Contains(testModel, "embed") ||
|
||||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
|
||||
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
|
||||
requestPath = "/v1/embeddings" // 修改请求路径
|
||||
}
|
||||
|
||||
@@ -75,31 +110,49 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
|
||||
cache, err := model.GetUserCache(1)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
localErr: err,
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
cache.WriteContext(c)
|
||||
|
||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
group, _ := model.GetUserGroup(1, false)
|
||||
c.Set("group", group)
|
||||
|
||||
middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||
if newAPIError != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: newAPIError,
|
||||
newAPIError: newAPIError,
|
||||
}
|
||||
}
|
||||
|
||||
info := relaycommon.GenRelayInfo(c)
|
||||
|
||||
err = helper.ModelMappedHelper(c, info)
|
||||
err = helper.ModelMappedHelper(c, info, nil)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
||||
}
|
||||
}
|
||||
testModel = info.UpstreamModelName
|
||||
|
||||
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
||||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
|
||||
newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
|
||||
}
|
||||
}
|
||||
|
||||
request := buildTestRequest(testModel)
|
||||
@@ -110,45 +163,91 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
|
||||
}
|
||||
}
|
||||
|
||||
adaptor.Init(info)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
var convertedRequest any
|
||||
// 根据 RelayMode 选择正确的转换函数
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||
// 创建一个 EmbeddingRequest
|
||||
embeddingRequest := dto.EmbeddingRequest{
|
||||
Input: request.Input,
|
||||
Model: request.Model,
|
||||
}
|
||||
// 调用专门用于 Embedding 的转换函数
|
||||
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
|
||||
} else {
|
||||
// 对其他所有请求类型(如 Chat),保持原有逻辑
|
||||
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
||||
}
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
|
||||
}
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
c.Request.Body = io.NopCloser(requestBody)
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(httpResp, true)
|
||||
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
}
|
||||
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
||||
if respErr != nil {
|
||||
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: respErr,
|
||||
newAPIError: respErr,
|
||||
}
|
||||
}
|
||||
if usageA == nil {
|
||||
return errors.New("usage is nil"), nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: errors.New("usage is nil"),
|
||||
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
usage := usageA.(*dto.Usage)
|
||||
result := w.Result()
|
||||
respBody, err := io.ReadAll(result.Body)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
info.PromptTokens = usage.PromptTokens
|
||||
|
||||
@@ -165,12 +264,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
|
||||
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
||||
ChannelId: channel.Id,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
ModelName: info.OriginModelName,
|
||||
TokenName: "模型测试",
|
||||
Quota: quota,
|
||||
Content: "模型测试",
|
||||
UseTimeSeconds: int(consumedTime),
|
||||
IsStream: false,
|
||||
Group: info.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: nil,
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
@@ -185,7 +299,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
strings.Contains(model, "bge-") {
|
||||
testRequest.Model = model
|
||||
// Embedding 请求
|
||||
testRequest.Input = []string{"hello world"}
|
||||
testRequest.Input = []any{"hello world"} // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
|
||||
return testRequest
|
||||
}
|
||||
// 并非Embedding 模型
|
||||
@@ -196,7 +310,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
testRequest.MaxTokens = 50
|
||||
}
|
||||
} else if strings.Contains(model, "gemini") {
|
||||
testRequest.MaxTokens = 300
|
||||
testRequest.MaxTokens = 3000
|
||||
} else {
|
||||
testRequest.MaxTokens = 10
|
||||
}
|
||||
@@ -213,31 +327,38 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
func TestChannel(c *gin.Context) {
|
||||
channelId, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
channel, err := model.GetChannelById(channelId, true)
|
||||
channel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
//defer func() {
|
||||
// if channel.ChannelInfo.IsMultiKey {
|
||||
// go func() { _ = channel.SaveChannelInfo() }()
|
||||
// }
|
||||
//}()
|
||||
testModel := c.Query("model")
|
||||
tik := time.Now()
|
||||
err, _ = testChannel(channel, testModel)
|
||||
result := testChannel(channel, testModel)
|
||||
if result.localErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": result.localErr.Error(),
|
||||
"time": 0.0,
|
||||
})
|
||||
return
|
||||
}
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
go channel.UpdateResponseTime(milliseconds)
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
if err != nil {
|
||||
if result.newAPIError != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
"message": result.newAPIError.Error(),
|
||||
"time": consumedTime,
|
||||
})
|
||||
return
|
||||
@@ -262,9 +383,9 @@ func testAllChannels(notify bool) error {
|
||||
}
|
||||
testAllChannelsRunning = true
|
||||
testAllChannelsLock.Unlock()
|
||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||
if err != nil {
|
||||
return err
|
||||
channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
|
||||
if getChannelErr != nil {
|
||||
return getChannelErr
|
||||
}
|
||||
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
||||
if disableThreshold == 0 {
|
||||
@@ -281,38 +402,40 @@ func testAllChannels(notify bool) error {
|
||||
for _, channel := range channels {
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
err, openaiWithStatusErr := testChannel(channel, "")
|
||||
result := testChannel(channel, "")
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
||||
shouldBanChannel := false
|
||||
|
||||
newAPIError := result.newAPIError
|
||||
// request error disables the channel
|
||||
if openaiWithStatusErr != nil {
|
||||
oaiErr := openaiWithStatusErr.Error
|
||||
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
|
||||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
|
||||
if newAPIError != nil {
|
||||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
|
||||
}
|
||||
|
||||
if milliseconds > disableThreshold {
|
||||
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
shouldBanChannel = true
|
||||
// 当错误检查通过,才检查响应时间
|
||||
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
||||
if milliseconds > disableThreshold {
|
||||
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
|
||||
shouldBanChannel = true
|
||||
}
|
||||
}
|
||||
|
||||
// disable channel
|
||||
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
||||
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||
go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
}
|
||||
|
||||
// enable channel
|
||||
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
|
||||
service.EnableChannel(channel.Id, channel.Name)
|
||||
if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
|
||||
service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
|
||||
}
|
||||
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
time.Sleep(common.RequestInterval)
|
||||
}
|
||||
|
||||
|
||||
if notify {
|
||||
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
|
||||
}
|
||||
@@ -323,10 +446,7 @@ func testAllChannels(notify bool) error {
|
||||
func TestAllChannels(c *gin.Context) {
|
||||
err := testAllChannels(true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -337,6 +457,10 @@ func TestAllChannels(c *gin.Context) {
|
||||
}
|
||||
|
||||
func AutomaticallyTestChannels(frequency int) {
|
||||
if frequency <= 0 {
|
||||
common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
|
||||
return
|
||||
}
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||
common.SysLog("testing all channels")
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -40,50 +41,112 @@ type OpenAIModelsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
func parseStatusFilter(statusParam string) int {
|
||||
switch strings.ToLower(statusParam) {
|
||||
case "enabled", "1":
|
||||
return common.ChannelStatusEnabled
|
||||
case "disabled", "0":
|
||||
return 0
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllChannels(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
channelData := make([]*model.Channel, 0)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
statusParam := c.Query("status")
|
||||
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
|
||||
statusFilter := parseStatusFilter(statusParam)
|
||||
// type filter
|
||||
typeStr := c.Query("type")
|
||||
typeFilter := -1
|
||||
if typeStr != "" {
|
||||
if t, err := strconv.Atoi(typeStr); err == nil {
|
||||
typeFilter = t
|
||||
}
|
||||
}
|
||||
|
||||
var total int64
|
||||
|
||||
if enableTagMode {
|
||||
tags, err := model.GetPaginatedTags(p*pageSize, pageSize)
|
||||
tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
for _, tag := range tags {
|
||||
if tag != nil && *tag != "" {
|
||||
tagChannel, err := model.GetChannelsByTag(*tag, idSort)
|
||||
if err == nil {
|
||||
channelData = append(channelData, tagChannel...)
|
||||
}
|
||||
if tag == nil || *tag == "" {
|
||||
continue
|
||||
}
|
||||
tagChannels, err := model.GetChannelsByTag(*tag, idSort)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
filtered := make([]*model.Channel, 0)
|
||||
for _, ch := range tagChannels {
|
||||
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if typeFilter >= 0 && ch.Type != typeFilter {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, ch)
|
||||
}
|
||||
channelData = append(channelData, filtered...)
|
||||
}
|
||||
total, _ = model.CountAllTags()
|
||||
} else {
|
||||
channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort)
|
||||
baseQuery := model.DB.Model(&model.Channel{})
|
||||
if typeFilter >= 0 {
|
||||
baseQuery = baseQuery.Where("type = ?", typeFilter)
|
||||
}
|
||||
if statusFilter == common.ChannelStatusEnabled {
|
||||
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||
} else if statusFilter == 0 {
|
||||
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||||
}
|
||||
|
||||
baseQuery.Count(&total)
|
||||
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
|
||||
err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
channelData = channels
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channelData,
|
||||
|
||||
countQuery := model.DB.Model(&model.Channel{})
|
||||
if statusFilter == common.ChannelStatusEnabled {
|
||||
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||
} else if statusFilter == 0 {
|
||||
countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||||
}
|
||||
var results []struct {
|
||||
Type int64
|
||||
Count int64
|
||||
}
|
||||
_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
|
||||
typeCounts := make(map[int64]int64)
|
||||
for _, r := range results {
|
||||
typeCounts[r.Type] = r.Count
|
||||
}
|
||||
common.ApiSuccess(c, gin.H{
|
||||
"items": channelData,
|
||||
"total": total,
|
||||
"page": pageInfo.GetPage(),
|
||||
"page_size": pageInfo.GetPageSize(),
|
||||
"type_counts": typeCounts,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -91,46 +154,30 @@ func GetAllChannels(c *gin.Context) {
|
||||
func FetchUpstreamModels(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
//if channel.Type != common.ChannelTypeOpenAI {
|
||||
// c.JSON(http.StatusOK, gin.H{
|
||||
// "success": false,
|
||||
// "message": "仅支持 OpenAI 类型渠道",
|
||||
// })
|
||||
// return
|
||||
//}
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeGemini:
|
||||
case constant.ChannelTypeGemini:
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
||||
case common.ChannelTypeAli:
|
||||
case constant.ChannelTypeAli:
|
||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||
}
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -146,7 +193,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
var ids []string
|
||||
for _, model := range result.Data {
|
||||
id := model.ID
|
||||
if channel.Type == common.ChannelTypeGemini {
|
||||
if channel.Type == constant.ChannelTypeGemini {
|
||||
id = strings.TrimPrefix(id, "models/")
|
||||
}
|
||||
ids = append(ids, id)
|
||||
@@ -160,18 +207,18 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
func FixChannelsAbilities(c *gin.Context) {
|
||||
count, err := model.FixAbility()
|
||||
success, fails, err := model.FixAbility()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": count,
|
||||
"data": gin.H{
|
||||
"success": success,
|
||||
"fails": fails,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -179,6 +226,8 @@ func SearchChannels(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
group := c.Query("group")
|
||||
modelKeyword := c.Query("model")
|
||||
statusParam := c.Query("status")
|
||||
statusFilter := parseStatusFilter(statusParam)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
channelData := make([]*model.Channel, 0)
|
||||
@@ -210,10 +259,74 @@ func SearchChannels(c *gin.Context) {
|
||||
}
|
||||
channelData = channels
|
||||
}
|
||||
|
||||
if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
|
||||
filtered := make([]*model.Channel, 0, len(channelData))
|
||||
for _, ch := range channelData {
|
||||
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, ch)
|
||||
}
|
||||
channelData = filtered
|
||||
}
|
||||
|
||||
// calculate type counts for search results
|
||||
typeCounts := make(map[int64]int64)
|
||||
for _, channel := range channelData {
|
||||
typeCounts[int64(channel.Type)]++
|
||||
}
|
||||
|
||||
typeParam := c.Query("type")
|
||||
typeFilter := -1
|
||||
if typeParam != "" {
|
||||
if tp, err := strconv.Atoi(typeParam); err == nil {
|
||||
typeFilter = tp
|
||||
}
|
||||
}
|
||||
|
||||
if typeFilter >= 0 {
|
||||
filtered := make([]*model.Channel, 0, len(channelData))
|
||||
for _, ch := range channelData {
|
||||
if ch.Type == typeFilter {
|
||||
filtered = append(filtered, ch)
|
||||
}
|
||||
}
|
||||
channelData = filtered
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
total := len(channelData)
|
||||
startIdx := (page - 1) * pageSize
|
||||
if startIdx > total {
|
||||
startIdx = total
|
||||
}
|
||||
endIdx := startIdx + pageSize
|
||||
if endIdx > total {
|
||||
endIdx = total
|
||||
}
|
||||
|
||||
pagedData := channelData[startIdx:endIdx]
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channelData,
|
||||
"data": gin.H{
|
||||
"items": pagedData,
|
||||
"total": total,
|
||||
"type_counts": typeCounts,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -221,18 +334,12 @@ func SearchChannels(c *gin.Context) {
|
||||
func GetChannel(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
channel, err := model.GetChannelById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -243,66 +350,167 @@ func GetChannel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func AddChannel(c *gin.Context) {
|
||||
channel := model.Channel{}
|
||||
err := c.ShouldBindJSON(&channel)
|
||||
// validateChannel 通用的渠道校验函数
|
||||
func validateChannel(channel *model.Channel, isAdd bool) error {
|
||||
// 校验 channel settings
|
||||
if err := channel.ValidateSettings(); err != nil {
|
||||
return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", err.Error())
|
||||
}
|
||||
|
||||
// 如果是添加操作,检查 channel 和 key 是否为空
|
||||
if isAdd {
|
||||
if channel == nil || channel.Key == "" {
|
||||
return fmt.Errorf("channel cannot be empty")
|
||||
}
|
||||
|
||||
// 检查模型名称长度是否超过 255
|
||||
for _, m := range channel.GetModels() {
|
||||
if len(m) > 255 {
|
||||
return fmt.Errorf("模型名称过长: %s", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// VertexAI 特殊校验
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
if channel.Other == "" {
|
||||
return fmt.Errorf("部署地区不能为空")
|
||||
}
|
||||
|
||||
regionMap, err := common.StrToMap(channel.Other)
|
||||
if err != nil {
|
||||
return fmt.Errorf("部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}")
|
||||
}
|
||||
|
||||
if regionMap["default"] == nil {
|
||||
return fmt.Errorf("部署地区必须包含default字段")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type AddChannelRequest struct {
|
||||
Mode string `json:"mode"`
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
Channel *model.Channel `json:"channel"`
|
||||
}
|
||||
|
||||
func getVertexArrayKeys(keys string) ([]string, error) {
|
||||
if keys == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var keyArray []interface{}
|
||||
err := common.Unmarshal([]byte(keys), &keyArray)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
|
||||
}
|
||||
cleanKeys := make([]string, 0, len(keyArray))
|
||||
for _, key := range keyArray {
|
||||
var keyStr string
|
||||
switch v := key.(type) {
|
||||
case string:
|
||||
keyStr = strings.TrimSpace(v)
|
||||
default:
|
||||
bytes, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
|
||||
}
|
||||
keyStr = string(bytes)
|
||||
}
|
||||
if keyStr != "" {
|
||||
cleanKeys = append(cleanKeys, keyStr)
|
||||
}
|
||||
}
|
||||
if len(cleanKeys) == 0 {
|
||||
return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
|
||||
}
|
||||
return cleanKeys, nil
|
||||
}
|
||||
|
||||
func AddChannel(c *gin.Context) {
|
||||
addChannelRequest := AddChannelRequest{}
|
||||
err := c.ShouldBindJSON(&addChannelRequest)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 使用统一的校验函数
|
||||
if err := validateChannel(addChannelRequest.Channel, true); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
channel.CreatedTime = common.GetTimestamp()
|
||||
keys := strings.Split(channel.Key, "\n")
|
||||
if channel.Type == common.ChannelTypeVertexAi {
|
||||
if channel.Other == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区不能为空",
|
||||
})
|
||||
return
|
||||
} else {
|
||||
if common.IsJsonStr(channel.Other) {
|
||||
// must have default
|
||||
regionMap := common.StrToMap(channel.Other)
|
||||
if regionMap["default"] == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须包含default字段",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
|
||||
keys := make([]string, 0)
|
||||
switch addChannelRequest.Mode {
|
||||
case "multi_to_single":
|
||||
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
|
||||
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
|
||||
addChannelRequest.Channel.Key = strings.Join(array, "\n")
|
||||
} else {
|
||||
cleanKeys := make([]string, 0)
|
||||
for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
key = strings.TrimSpace(key)
|
||||
cleanKeys = append(cleanKeys, key)
|
||||
}
|
||||
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
|
||||
addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
|
||||
}
|
||||
keys = []string{channel.Key}
|
||||
keys = []string{addChannelRequest.Channel.Key}
|
||||
case "batch":
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
// multi json
|
||||
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
keys = strings.Split(addChannelRequest.Channel.Key, "\n")
|
||||
}
|
||||
case "single":
|
||||
keys = []string{addChannelRequest.Channel.Key}
|
||||
default:
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "不支持的添加模式",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
channels := make([]model.Channel, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
localChannel := channel
|
||||
localChannel := addChannelRequest.Channel
|
||||
localChannel.Key = key
|
||||
// Validate the length of the model name
|
||||
models := strings.Split(localChannel.Models, ",")
|
||||
for _, model := range models {
|
||||
if len(model) > 255 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("模型名称过长: %s", model),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
channels = append(channels, localChannel)
|
||||
channels = append(channels, *localChannel)
|
||||
}
|
||||
err = model.BatchInsertChannels(channels)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -317,12 +525,10 @@ func DeleteChannel(c *gin.Context) {
|
||||
channel := model.Channel{Id: id}
|
||||
err := channel.Delete()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -333,12 +539,10 @@ func DeleteChannel(c *gin.Context) {
|
||||
func DeleteDisabledChannel(c *gin.Context) {
|
||||
rows, err := model.DeleteDisabledChannel()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -369,12 +573,10 @@ func DisableTagChannels(c *gin.Context) {
|
||||
}
|
||||
err = model.DisableChannelByTag(channelTag.Tag)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -394,12 +596,10 @@ func EnableTagChannels(c *gin.Context) {
|
||||
}
|
||||
err = model.EnableChannelByTag(channelTag.Tag)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -426,12 +626,10 @@ func EditTagChannels(c *gin.Context) {
|
||||
}
|
||||
err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -456,12 +654,10 @@ func DeleteChannelBatch(c *gin.Context) {
|
||||
}
|
||||
err = model.BatchDeleteChannels(channelBatch.Ids)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -470,9 +666,29 @@ func DeleteChannelBatch(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
type PatchChannel struct {
|
||||
model.Channel
|
||||
MultiKeyMode *string `json:"multi_key_mode"`
|
||||
}
|
||||
|
||||
func UpdateChannel(c *gin.Context) {
|
||||
channel := model.Channel{}
|
||||
channel := PatchChannel{}
|
||||
err := c.ShouldBindJSON(&channel)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 使用统一的校验函数
|
||||
if err := validateChannel(&channel.Channel, false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
|
||||
originChannel, err := model.GetChannelById(channel.Id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -480,35 +696,21 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if channel.Type == common.ChannelTypeVertexAi {
|
||||
if channel.Other == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区不能为空",
|
||||
})
|
||||
return
|
||||
} else {
|
||||
if common.IsJsonStr(channel.Other) {
|
||||
// must have default
|
||||
regionMap := common.StrToMap(channel.Other)
|
||||
if regionMap["default"] == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须包含default字段",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Always copy the original ChannelInfo so that fields like IsMultiKey and MultiKeySize are retained.
|
||||
channel.ChannelInfo = originChannel.ChannelInfo
|
||||
|
||||
// If the request explicitly specifies a new MultiKeyMode, apply it on top of the original info.
|
||||
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
|
||||
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
|
||||
}
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
channel.Key = ""
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -534,7 +736,7 @@ func FetchModels(c *gin.Context) {
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = common.ChannelBaseURLs[req.Type]
|
||||
baseURL = constant.ChannelBaseURLs[req.Type]
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
@@ -610,12 +812,10 @@ func BatchSetChannelTag(c *gin.Context) {
|
||||
}
|
||||
err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -664,3 +864,53 @@ func GetTagModels(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// CopyChannel handles cloning an existing channel with its key.
|
||||
// POST /api/channel/copy/:id
|
||||
// Optional query params:
|
||||
//
|
||||
// suffix - string appended to the original name (default "_复制")
|
||||
// reset_balance - bool, when true will reset balance & used_quota to 0 (default true)
|
||||
func CopyChannel(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"})
|
||||
return
|
||||
}
|
||||
|
||||
suffix := c.DefaultQuery("suffix", "_复制")
|
||||
resetBalance := true
|
||||
if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" {
|
||||
if v, err := strconv.ParseBool(rbStr); err == nil {
|
||||
resetBalance = v
|
||||
}
|
||||
}
|
||||
|
||||
// fetch original channel with key
|
||||
origin, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// clone channel
|
||||
clone := *origin // shallow copy is sufficient as we will overwrite primitives
|
||||
clone.Id = 0 // let DB auto-generate
|
||||
clone.CreatedTime = common.GetTimestamp()
|
||||
clone.Name = origin.Name + suffix
|
||||
clone.TestTime = 0
|
||||
clone.ResponseTime = 0
|
||||
if resetBalance {
|
||||
clone.Balance = 0
|
||||
clone.UsedQuota = 0
|
||||
}
|
||||
|
||||
// insert
|
||||
if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
// success
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
|
||||
}
|
||||
|
||||
103
controller/console_migrate.go
Normal file
103
controller/console_migrate.go
Normal file
@@ -0,0 +1,103 @@
|
||||
// 用于迁移检测的旧键,该文件下个版本会删除
|
||||
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
|
||||
func MigrateConsoleSetting(c *gin.Context) {
|
||||
// 读取全部 option
|
||||
opts, err := model.AllOption()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
// 建立 map
|
||||
valMap := map[string]string{}
|
||||
for _, o := range opts {
|
||||
valMap[o.Key] = o.Value
|
||||
}
|
||||
|
||||
// 处理 APIInfo
|
||||
if v := valMap["ApiInfo"]; v != "" {
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||
if len(arr) > 50 {
|
||||
arr = arr[:50]
|
||||
}
|
||||
bytes, _ := json.Marshal(arr)
|
||||
model.UpdateOption("console_setting.api_info", string(bytes))
|
||||
}
|
||||
model.UpdateOption("ApiInfo", "")
|
||||
}
|
||||
// Announcements 直接搬
|
||||
if v := valMap["Announcements"]; v != "" {
|
||||
model.UpdateOption("console_setting.announcements", v)
|
||||
model.UpdateOption("Announcements", "")
|
||||
}
|
||||
// FAQ 转换
|
||||
if v := valMap["FAQ"]; v != "" {
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||
out := []map[string]interface{}{}
|
||||
for _, item := range arr {
|
||||
q, _ := item["question"].(string)
|
||||
if q == "" {
|
||||
q, _ = item["title"].(string)
|
||||
}
|
||||
a, _ := item["answer"].(string)
|
||||
if a == "" {
|
||||
a, _ = item["content"].(string)
|
||||
}
|
||||
if q != "" && a != "" {
|
||||
out = append(out, map[string]interface{}{"question": q, "answer": a})
|
||||
}
|
||||
}
|
||||
if len(out) > 50 {
|
||||
out = out[:50]
|
||||
}
|
||||
bytes, _ := json.Marshal(out)
|
||||
model.UpdateOption("console_setting.faq", string(bytes))
|
||||
}
|
||||
model.UpdateOption("FAQ", "")
|
||||
}
|
||||
// Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
|
||||
url := valMap["UptimeKumaUrl"]
|
||||
slug := valMap["UptimeKumaSlug"]
|
||||
if url != "" && slug != "" {
|
||||
// 仅当同时存在 URL 与 Slug 时才进行迁移
|
||||
groups := []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"categoryName": "old",
|
||||
"url": url,
|
||||
"slug": slug,
|
||||
"description": "",
|
||||
},
|
||||
}
|
||||
bytes, _ := json.Marshal(groups)
|
||||
model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
|
||||
}
|
||||
// 清空旧键内容
|
||||
if url != "" {
|
||||
model.UpdateOption("UptimeKumaUrl", "")
|
||||
}
|
||||
if slug != "" {
|
||||
model.UpdateOption("UptimeKumaSlug", "")
|
||||
}
|
||||
|
||||
// 删除旧键记录
|
||||
oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
|
||||
model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
|
||||
|
||||
// 重新加载 OptionMap
|
||||
model.InitOptionMap()
|
||||
common.SysLog("console setting migrated")
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
|
||||
}
|
||||
@@ -5,13 +5,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GitHubOAuthResponse struct {
|
||||
@@ -103,10 +104,7 @@ func GitHubOAuth(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
@@ -185,10 +183,7 @@ func GitHubBind(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
@@ -207,19 +202,13 @@ func GitHubBind(c *gin.Context) {
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.GitHubId = githubUser.Login
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -239,10 +228,7 @@ func GenerateOAuthCode(c *gin.Context) {
|
||||
session.Set("oauth_state", state)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetGroups(c *gin.Context) {
|
||||
groupNames := make([]string, 0)
|
||||
for groupName, _ := range setting.GetGroupRatioCopy() {
|
||||
for groupName := range ratio_setting.GetGroupRatioCopy() {
|
||||
groupNames = append(groupNames, groupName)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -24,7 +26,7 @@ func GetUserGroups(c *gin.Context) {
|
||||
userGroup := ""
|
||||
userId := c.GetInt("id")
|
||||
userGroup, _ = model.GetUserGroup(userId, false)
|
||||
for groupName, ratio := range setting.GetGroupRatioCopy() {
|
||||
for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
|
||||
// UserUsableGroups contains the groups that the user can use
|
||||
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
||||
if desc, ok := userUsableGroups[groupName]; ok {
|
||||
@@ -34,6 +36,12 @@ func GetUserGroups(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if setting.GroupInUserUsableGroups("auto") {
|
||||
usableGroups["auto"] = map[string]interface{}{
|
||||
"ratio": "自动",
|
||||
"desc": setting.GetUsableGroupDescription("auto"),
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
|
||||
@@ -38,10 +38,7 @@ func LinuxDoBind(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -63,20 +60,14 @@ func LinuxDoBind(c *gin.Context) {
|
||||
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -202,10 +193,7 @@ func LinuxdoOAuth(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -10,14 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func GetAllLogs(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
@@ -26,38 +19,19 @@ func GetAllLogs(c *gin.Context) {
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
group := c.Query("group")
|
||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel, group)
|
||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": map[string]any{
|
||||
"items": logs,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(logs)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func GetUserLogs(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
userId := c.GetInt("id")
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
@@ -65,24 +39,14 @@ func GetUserLogs(c *gin.Context) {
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
group := c.Query("group")
|
||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize, group)
|
||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": map[string]any{
|
||||
"items": logs,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(logs)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -90,10 +54,7 @@ func SearchAllLogs(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
logs, err := model.SearchAllLogs(keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -109,10 +70,7 @@ func SearchUserLogs(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
logs, err := model.SearchUserLogs(userId, keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -198,10 +156,7 @@ func DeleteHistoryLogs(c *gin.Context) {
|
||||
}
|
||||
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -5,17 +5,16 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func UpdateMidjourneyTaskBulk() {
|
||||
@@ -214,10 +213,7 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
|
||||
}
|
||||
|
||||
func GetAllMidjourney(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
// 解析其他查询参数
|
||||
queryParams := model.TaskQueryParams{
|
||||
@@ -227,31 +223,24 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
EndTimestamp: c.Query("end_timestamp"),
|
||||
}
|
||||
|
||||
logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
items := model.GetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.CountAllTasks(queryParams)
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
for i, midjourney := range items {
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
func GetUserMidjourney(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
userId := c.GetInt("id")
|
||||
log.Printf("userId = %d \n", userId)
|
||||
|
||||
queryParams := model.TaskQueryParams{
|
||||
MjID: c.Query("mj_id"),
|
||||
@@ -259,19 +248,16 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
EndTimestamp: c.Query("end_timestamp"),
|
||||
}
|
||||
|
||||
logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
items := model.GetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.CountAllUserTask(userId, queryParams)
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
for i, midjourney := range items {
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
@@ -37,52 +38,77 @@ func TestStatus(c *gin.Context) {
|
||||
|
||||
func GetStatus(c *gin.Context) {
|
||||
|
||||
cs := console_setting.GetConsoleSetting()
|
||||
|
||||
data := gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"stripe_unit_price": setting.StripeUnitPrice,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
"usd_exchange_rate": setting.USDExchangeRate,
|
||||
|
||||
// 面板启用开关
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
|
||||
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||
"faq_enabled": cs.FAQEnabled,
|
||||
|
||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||
"setup": constant.Setup,
|
||||
}
|
||||
|
||||
// 根据启用状态注入可选内容
|
||||
if cs.ApiInfoEnabled {
|
||||
data["api_info"] = console_setting.GetApiInfo()
|
||||
}
|
||||
if cs.AnnouncementsEnabled {
|
||||
data["announcements"] = console_setting.GetAnnouncements()
|
||||
}
|
||||
if cs.FAQEnabled {
|
||||
data["faq"] = console_setting.GetFAQ()
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||
"setup": constant.Setup,
|
||||
"api_info": setting.GetApiInfo(),
|
||||
"announcements": setting.GetAnnouncements(),
|
||||
"faq": setting.GetFAQ(),
|
||||
},
|
||||
"data": data,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -192,10 +218,7 @@ func SendEmailVerification(c *gin.Context) {
|
||||
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
|
||||
err := common.SendEmail(subject, email, content)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -231,10 +254,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
||||
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
|
||||
err := common.SendEmail(subject, email, content)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -269,10 +289,7 @@ func ResetPassword(c *gin.Context) {
|
||||
password := common.GenerateVerificationCode(12)
|
||||
err = model.ResetUserPasswordByEmail(req.Email, password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.DeleteKey(req.Email, common.PasswordResetPurpose)
|
||||
|
||||
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
"one-api/relay/channel/minimax"
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
@@ -23,30 +24,10 @@ var openAIModels []dto.OpenAIModels
|
||||
var openAIModelsMap map[string]dto.OpenAIModels
|
||||
var channelId2Models map[int][]string
|
||||
|
||||
func getPermission() []dto.OpenAIModelPermission {
|
||||
var permission []dto.OpenAIModelPermission
|
||||
permission = append(permission, dto.OpenAIModelPermission{
|
||||
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
|
||||
Object: "model_permission",
|
||||
Created: 1626777600,
|
||||
AllowCreateEngine: true,
|
||||
AllowSampling: true,
|
||||
AllowLogprobs: true,
|
||||
AllowSearchIndices: false,
|
||||
AllowView: true,
|
||||
AllowFineTuning: false,
|
||||
Organization: "*",
|
||||
Group: nil,
|
||||
IsBlocking: false,
|
||||
})
|
||||
return permission
|
||||
}
|
||||
|
||||
func init() {
|
||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
permission := getPermission()
|
||||
for i := 0; i < relayconstant.APITypeDummy; i++ {
|
||||
if i == relayconstant.APITypeAIProxyLibrary {
|
||||
for i := 0; i < constant.APITypeDummy; i++ {
|
||||
if i == constant.APITypeAIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
adaptor := relay.GetAdaptor(i)
|
||||
@@ -54,69 +35,51 @@ func init() {
|
||||
modelNames := adaptor.GetModelList()
|
||||
for _, modelName := range modelNames {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: channelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: channelName,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, modelName := range ai360.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: ai360.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: ai360.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range moonshot.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: moonshot.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: moonshot.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range lingyiwanwu.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: lingyiwanwu.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: lingyiwanwu.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range minimax.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: minimax.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: minimax.ChannelName,
|
||||
})
|
||||
}
|
||||
for modelName, _ := range constant.MidjourneyModel2Action {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "midjourney",
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "midjourney",
|
||||
})
|
||||
}
|
||||
openAIModelsMap = make(map[string]dto.OpenAIModels)
|
||||
@@ -124,9 +87,9 @@ func init() {
|
||||
openAIModelsMap[aiModel.Id] = aiModel
|
||||
}
|
||||
channelId2Models = make(map[int][]string)
|
||||
for i := 1; i <= common.ChannelTypeDummy; i++ {
|
||||
apiType, success := relayconstant.ChannelType2APIType(i)
|
||||
if !success || apiType == relayconstant.APITypeAIProxyLibrary {
|
||||
for i := 1; i <= constant.ChannelTypeDummy; i++ {
|
||||
apiType, success := common.ChannelType2APIType(i)
|
||||
if !success || apiType == constant.APITypeAIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
meta := &relaycommon.RelayInfo{ChannelType: i}
|
||||
@@ -134,15 +97,17 @@ func init() {
|
||||
adaptor.Init(meta)
|
||||
channelId2Models[i] = adaptor.GetModelList()
|
||||
}
|
||||
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
|
||||
return m.Id
|
||||
})
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||
permission := getPermission()
|
||||
|
||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
if modelLimitEnable {
|
||||
s, ok := c.Get("token_model_limit")
|
||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||
var tokenModelLimit map[string]bool
|
||||
if ok {
|
||||
tokenModelLimit = s.(map[string]bool)
|
||||
@@ -150,23 +115,22 @@ func ListModels(c *gin.Context) {
|
||||
tokenModelLimit = map[string]bool{}
|
||||
}
|
||||
for allowModel, _ := range tokenModelLimit {
|
||||
if _, ok := openAIModelsMap[allowModel]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
|
||||
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: allowModel,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
Permission: permission,
|
||||
Root: allowModel,
|
||||
Parent: nil,
|
||||
Id: allowModel,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
userGroup, err := model.GetUserGroup(userId, true)
|
||||
userGroup, err := model.GetUserGroup(userId, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -175,23 +139,34 @@ func ListModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
group := userGroup
|
||||
tokenGroup := c.GetString("token_group")
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
models := model.GetGroupModels(group)
|
||||
for _, s := range models {
|
||||
if _, ok := openAIModelsMap[s]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
||||
var models []string
|
||||
if tokenGroup == "auto" {
|
||||
for _, autoGroup := range setting.AutoGroups {
|
||||
groupModels := model.GetGroupEnabledModels(autoGroup)
|
||||
for _, g := range groupModels {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
models = model.GetGroupEnabledModels(group)
|
||||
}
|
||||
for _, modelName := range models {
|
||||
if oaiModel, ok := openAIModelsMap[modelName]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: s,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
Permission: permission,
|
||||
Root: s,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,10 +126,7 @@ func OidcAuth(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
@@ -195,10 +192,7 @@ func OidcBind(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
@@ -217,19 +211,13 @@ func OidcBind(c *gin.Context) {
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.OidcId = oidcUser.OpenID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
|
||||
@@ -102,7 +104,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "GroupRatio":
|
||||
err = setting.CheckGroupRatio(option.Value)
|
||||
err = ratio_setting.CheckGroupRatio(option.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -119,8 +121,8 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "ApiInfo":
|
||||
err = setting.ValidateApiInfo(option.Value)
|
||||
case "console_setting.api_info":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -128,8 +130,8 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "Announcements":
|
||||
err = setting.ValidateConsoleSettings(option.Value, "Announcements")
|
||||
case "console_setting.announcements":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -137,8 +139,17 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "FAQ":
|
||||
err = setting.ValidateConsoleSettings(option.Value, "FAQ")
|
||||
case "console_setting.faq":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "console_setting.uptime_kuma_groups":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -149,10 +160,7 @@ func UpdateOption(c *gin.Context) {
|
||||
}
|
||||
err = model.UpdateOption(option.Key, option.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -3,44 +3,44 @@ package controller
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Playground(c *gin.Context) {
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
c.JSON(openaiErr.StatusCode, gin.H{
|
||||
"error": openaiErr.Error,
|
||||
if newAPIError != nil {
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"error": newAPIError.ToOpenAIError(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
useAccessToken := c.GetBool("use_access_token")
|
||||
if useAccessToken {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
|
||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
|
||||
playgroundRequest := &dto.PlayGroundRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||
if err != nil {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
|
||||
if playgroundRequest.Model == "" {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
|
||||
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
c.Set("original_model", playgroundRequest.Model)
|
||||
@@ -51,19 +51,34 @@ func Playground(c *gin.Context) {
|
||||
group = userGroup
|
||||
} else {
|
||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
|
||||
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
c.Set("group", group)
|
||||
}
|
||||
c.Set("token_name", "playground-"+group)
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// Write user context to ensure acceptUnsetRatio is available
|
||||
userCache, err := model.GetUserCache(userId)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
userCache.WriteContext(c)
|
||||
|
||||
tempToken := &model.Token{
|
||||
UserId: userId,
|
||||
Name: fmt.Sprintf("playground-%s", group),
|
||||
Group: group,
|
||||
}
|
||||
_ = middleware.SetupContextForToken(c, tempToken)
|
||||
_, newAPIError = getChannel(c, group, playgroundRequest.Model, 0)
|
||||
if newAPIError != nil {
|
||||
return
|
||||
}
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||
|
||||
Relay(c)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetPricing(c *gin.Context) {
|
||||
@@ -12,7 +13,7 @@ func GetPricing(c *gin.Context) {
|
||||
userId, exists := c.Get("id")
|
||||
usableGroup := map[string]string{}
|
||||
groupRatio := map[string]float64{}
|
||||
for s, f := range setting.GetGroupRatioCopy() {
|
||||
for s, f := range ratio_setting.GetGroupRatioCopy() {
|
||||
groupRatio[s] = f
|
||||
}
|
||||
var group string
|
||||
@@ -20,12 +21,18 @@ func GetPricing(c *gin.Context) {
|
||||
user, err := model.GetUserCache(userId.(int))
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
for g := range groupRatio {
|
||||
ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
|
||||
if ok {
|
||||
groupRatio[g] = ratio
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
usableGroup = setting.GetUserUsableGroups(group)
|
||||
// check groupRatio contains usableGroup
|
||||
for group := range setting.GetGroupRatioCopy() {
|
||||
for group := range ratio_setting.GetGroupRatioCopy() {
|
||||
if _, ok := usableGroup[group]; !ok {
|
||||
delete(groupRatio, group)
|
||||
}
|
||||
@@ -40,7 +47,7 @@ func GetPricing(c *gin.Context) {
|
||||
}
|
||||
|
||||
func ResetModelRatio(c *gin.Context) {
|
||||
defaultStr := operation_setting.DefaultModelRatio2JSONString()
|
||||
defaultStr := ratio_setting.DefaultModelRatio2JSONString()
|
||||
err := model.UpdateOption("ModelRatio", defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
@@ -49,7 +56,7 @@ func ResetModelRatio(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||
err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
|
||||
24
controller/ratio_config.go
Normal file
24
controller/ratio_config.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetRatioConfig(c *gin.Context) {
|
||||
if !ratio_setting.IsExposeRatioEnabled() {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "倍率配置接口未启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": ratio_setting.GetExposedData(),
|
||||
})
|
||||
}
|
||||
474
controller/ratio_sync.go
Normal file
474
controller/ratio_sync.go
Normal file
@@ -0,0 +1,474 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeoutSeconds = 10
|
||||
defaultEndpoint = "/api/ratio_config"
|
||||
maxConcurrentFetches = 8
|
||||
)
|
||||
|
||||
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||
|
||||
type upstreamResult struct {
|
||||
Name string `json:"name"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
Err string `json:"err,omitempty"`
|
||||
}
|
||||
|
||||
func FetchUpstreamRatios(c *gin.Context) {
|
||||
var req dto.UpstreamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Timeout <= 0 {
|
||||
req.Timeout = defaultTimeoutSeconds
|
||||
}
|
||||
|
||||
var upstreams []dto.UpstreamDTO
|
||||
|
||||
if len(req.Upstreams) > 0 {
|
||||
for _, u := range req.Upstreams {
|
||||
if strings.HasPrefix(u.BaseURL, "http") {
|
||||
if u.Endpoint == "" {
|
||||
u.Endpoint = defaultEndpoint
|
||||
}
|
||||
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
|
||||
upstreams = append(upstreams, u)
|
||||
}
|
||||
}
|
||||
} else if len(req.ChannelIDs) > 0 {
|
||||
intIds := make([]int, 0, len(req.ChannelIDs))
|
||||
for _, id64 := range req.ChannelIDs {
|
||||
intIds = append(intIds, int(id64))
|
||||
}
|
||||
dbChannels, err := model.GetChannelsByIds(intIds)
|
||||
if err != nil {
|
||||
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
||||
return
|
||||
}
|
||||
for _, ch := range dbChannels {
|
||||
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
||||
upstreams = append(upstreams, dto.UpstreamDTO{
|
||||
ID: ch.Id,
|
||||
Name: ch.Name,
|
||||
BaseURL: strings.TrimRight(base, "/"),
|
||||
Endpoint: "",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(upstreams) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
ch := make(chan upstreamResult, len(upstreams))
|
||||
|
||||
sem := make(chan struct{}, maxConcurrentFetches)
|
||||
|
||||
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
||||
|
||||
for _, chn := range upstreams {
|
||||
wg.Add(1)
|
||||
go func(chItem dto.UpstreamDTO) {
|
||||
defer wg.Done()
|
||||
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
|
||||
endpoint := chItem.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = defaultEndpoint
|
||||
} else if !strings.HasPrefix(endpoint, "/") {
|
||||
endpoint = "/" + endpoint
|
||||
}
|
||||
fullURL := chItem.BaseURL + endpoint
|
||||
|
||||
uniqueName := chItem.Name
|
||||
if chItem.ID != 0 {
|
||||
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||
if err != nil {
|
||||
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
||||
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
||||
return
|
||||
}
|
||||
// 兼容两种上游接口格式:
|
||||
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
||||
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
||||
var body struct {
|
||||
Success bool `json:"success"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
if !body.Success {
|
||||
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试按 type1 解析
|
||||
var type1Data map[string]any
|
||||
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
||||
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
||||
isType1 := false
|
||||
for _, rt := range ratioTypes {
|
||||
if _, ok := type1Data[rt]; ok {
|
||||
isType1 = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if isType1 {
|
||||
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
||||
var pricingItems []struct {
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
}
|
||||
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
|
||||
common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
|
||||
return
|
||||
}
|
||||
|
||||
modelRatioMap := make(map[string]float64)
|
||||
completionRatioMap := make(map[string]float64)
|
||||
modelPriceMap := make(map[string]float64)
|
||||
|
||||
for _, item := range pricingItems {
|
||||
if item.QuotaType == 1 {
|
||||
modelPriceMap[item.ModelName] = item.ModelPrice
|
||||
} else {
|
||||
modelRatioMap[item.ModelName] = item.ModelRatio
|
||||
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
||||
completionRatioMap[item.ModelName] = item.CompletionRatio
|
||||
}
|
||||
}
|
||||
|
||||
converted := make(map[string]any)
|
||||
|
||||
if len(modelRatioMap) > 0 {
|
||||
ratioAny := make(map[string]any, len(modelRatioMap))
|
||||
for k, v := range modelRatioMap {
|
||||
ratioAny[k] = v
|
||||
}
|
||||
converted["model_ratio"] = ratioAny
|
||||
}
|
||||
|
||||
if len(completionRatioMap) > 0 {
|
||||
compAny := make(map[string]any, len(completionRatioMap))
|
||||
for k, v := range completionRatioMap {
|
||||
compAny[k] = v
|
||||
}
|
||||
converted["completion_ratio"] = compAny
|
||||
}
|
||||
|
||||
if len(modelPriceMap) > 0 {
|
||||
priceAny := make(map[string]any, len(modelPriceMap))
|
||||
for k, v := range modelPriceMap {
|
||||
priceAny[k] = v
|
||||
}
|
||||
converted["model_price"] = priceAny
|
||||
}
|
||||
|
||||
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
||||
}(chn)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
|
||||
localData := ratio_setting.GetExposedData()
|
||||
|
||||
var testResults []dto.TestResult
|
||||
var successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}
|
||||
|
||||
for r := range ch {
|
||||
if r.Err != "" {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "error",
|
||||
Error: r.Err,
|
||||
})
|
||||
} else {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "success",
|
||||
})
|
||||
successfulChannels = append(successfulChannels, struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}{name: r.Name, data: r.Data})
|
||||
}
|
||||
}
|
||||
|
||||
differences := buildDifferences(localData, successfulChannels)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"differences": differences,
|
||||
"test_results": testResults,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}) map[string]map[string]dto.DifferenceItem {
|
||||
differences := make(map[string]map[string]dto.DifferenceItem)
|
||||
|
||||
allModels := make(map[string]struct{})
|
||||
|
||||
for _, ratioType := range ratioTypes {
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
for modelName := range localRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
for modelName := range upstreamRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
confidenceMap := make(map[string]map[string]bool)
|
||||
|
||||
// 预处理阶段:检查pricing接口的可信度
|
||||
for _, channel := range successfulChannels {
|
||||
confidenceMap[channel.name] = make(map[string]bool)
|
||||
|
||||
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
||||
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
||||
|
||||
if hasModelRatio && hasCompletionRatio {
|
||||
// 遍历所有模型,检查是否满足不可信条件
|
||||
for modelName := range allModels {
|
||||
// 默认为可信
|
||||
confidenceMap[channel.name][modelName] = true
|
||||
|
||||
// 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
|
||||
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
||||
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
||||
// 转换为float64进行比较
|
||||
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
||||
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
||||
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
||||
confidenceMap[channel.name][modelName] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 如果不是从pricing接口获取的数据,则全部标记为可信
|
||||
for modelName := range allModels {
|
||||
confidenceMap[channel.name][modelName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName := range allModels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
var localValue interface{} = nil
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
if val, exists := localRatio[modelName]; exists {
|
||||
localValue = val
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
upstreamValues := make(map[string]interface{})
|
||||
confidenceValues := make(map[string]bool)
|
||||
hasUpstreamValue := false
|
||||
hasDifference := false
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
var upstreamValue interface{} = nil
|
||||
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
if val, exists := upstreamRatio[modelName]; exists {
|
||||
upstreamValue = val
|
||||
hasUpstreamValue = true
|
||||
|
||||
if localValue != nil && localValue != val {
|
||||
hasDifference = true
|
||||
} else if localValue == val {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
}
|
||||
}
|
||||
if upstreamValue == nil && localValue == nil {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
|
||||
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
||||
hasDifference = true
|
||||
}
|
||||
|
||||
upstreamValues[channel.name] = upstreamValue
|
||||
|
||||
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
|
||||
}
|
||||
|
||||
shouldInclude := false
|
||||
|
||||
if localValue != nil {
|
||||
if hasDifference {
|
||||
shouldInclude = true
|
||||
}
|
||||
} else {
|
||||
if hasUpstreamValue {
|
||||
shouldInclude = true
|
||||
}
|
||||
}
|
||||
|
||||
if shouldInclude {
|
||||
if differences[modelName] == nil {
|
||||
differences[modelName] = make(map[string]dto.DifferenceItem)
|
||||
}
|
||||
differences[modelName][ratioType] = dto.DifferenceItem{
|
||||
Current: localValue,
|
||||
Upstreams: upstreamValues,
|
||||
Confidence: confidenceValues,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
channelHasDiff := make(map[string]bool)
|
||||
for _, ratioMap := range differences {
|
||||
for _, item := range ratioMap {
|
||||
for chName, val := range item.Upstreams {
|
||||
if val != nil && val != "same" {
|
||||
channelHasDiff[chName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName, ratioMap := range differences {
|
||||
for ratioType, item := range ratioMap {
|
||||
for chName := range item.Upstreams {
|
||||
if !channelHasDiff[chName] {
|
||||
delete(item.Upstreams, chName)
|
||||
delete(item.Confidence, chName)
|
||||
}
|
||||
}
|
||||
|
||||
allSame := true
|
||||
for _, v := range item.Upstreams {
|
||||
if v != "same" {
|
||||
allSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(item.Upstreams) == 0 || allSame {
|
||||
delete(ratioMap, ratioType)
|
||||
} else {
|
||||
differences[modelName][ratioType] = item
|
||||
}
|
||||
}
|
||||
|
||||
if len(ratioMap) == 0 {
|
||||
delete(differences, modelName)
|
||||
}
|
||||
}
|
||||
|
||||
return differences
|
||||
}
|
||||
|
||||
func GetSyncableChannels(c *gin.Context) {
|
||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var syncableChannels []dto.SyncableChannel
|
||||
for _, channel := range channels {
|
||||
if channel.GetBaseURL() != "" {
|
||||
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||
ID: channel.Id,
|
||||
Name: channel.Name,
|
||||
BaseURL: channel.GetBaseURL(),
|
||||
Status: channel.Status,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": syncableChannels,
|
||||
})
|
||||
}
|
||||
@@ -1,90 +1,52 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAllRedemptions(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
redemptions, total, err := model.GetAllRedemptions((p-1)*pageSize, pageSize)
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
redemptions, total, err := model.GetAllRedemptions(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": redemptions,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(redemptions)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func SearchRedemptions(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
redemptions, total, err := model.SearchRedemptions(keyword, (p-1)*pageSize, pageSize)
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
redemptions, total, err := model.SearchRedemptions(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": redemptions,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(redemptions)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func GetRedemption(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
redemption, err := model.GetRedemptionById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -99,13 +61,10 @@ func AddRedemption(c *gin.Context) {
|
||||
redemption := model.Redemption{}
|
||||
err := c.ShouldBindJSON(&redemption)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
|
||||
if utf8.RuneCountInString(redemption.Name) == 0 || utf8.RuneCountInString(redemption.Name) > 20 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "兑换码名称长度必须在1-20之间",
|
||||
@@ -126,6 +85,10 @@ func AddRedemption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
var keys []string
|
||||
for i := 0; i < redemption.Count; i++ {
|
||||
key := common.GetUUID()
|
||||
@@ -135,6 +98,7 @@ func AddRedemption(c *gin.Context) {
|
||||
Key: key,
|
||||
CreatedTime: common.GetTimestamp(),
|
||||
Quota: redemption.Quota,
|
||||
ExpiredTime: redemption.ExpiredTime,
|
||||
}
|
||||
err = cleanRedemption.Insert()
|
||||
if err != nil {
|
||||
@@ -159,10 +123,7 @@ func DeleteRedemption(c *gin.Context) {
|
||||
id, _ := strconv.Atoi(c.Param("id"))
|
||||
err := model.DeleteRedemptionById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -177,33 +138,30 @@ func UpdateRedemption(c *gin.Context) {
|
||||
redemption := model.Redemption{}
|
||||
err := c.ShouldBindJSON(&redemption)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
cleanRedemption, err := model.GetRedemptionById(redemption.Id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if statusOnly != "" {
|
||||
cleanRedemption.Status = redemption.Status
|
||||
} else {
|
||||
if statusOnly == "" {
|
||||
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
// If you add more fields, please also update redemption.Update()
|
||||
cleanRedemption.Name = redemption.Name
|
||||
cleanRedemption.Quota = redemption.Quota
|
||||
cleanRedemption.ExpiredTime = redemption.ExpiredTime
|
||||
}
|
||||
if statusOnly != "" {
|
||||
cleanRedemption.Status = redemption.Status
|
||||
}
|
||||
err = cleanRedemption.Update()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -213,3 +171,24 @@ func UpdateRedemption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteInvalidRedemption(c *gin.Context) {
|
||||
rows, err := model.DeleteInvalidRedemptions()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": rows,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func validateExpiredTime(expired int64) error {
|
||||
if expired != 0 && expired < common.GetTimestamp() {
|
||||
return errors.New("过期时间不能早于当前时间")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,23 +8,24 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
constant2 "one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
var err *dto.OpenAIErrorWithStatusCode
|
||||
func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
||||
var err *types.NewAPIError
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||
err = relay.ImageHelper(c)
|
||||
@@ -55,43 +56,42 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
userGroup := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
other := make(map[string]interface{})
|
||||
other["error_type"] = err.Error.Type
|
||||
other["error_code"] = err.Error.Code
|
||||
other["error_type"] = err.GetErrorType()
|
||||
other["error_code"] = err.GetErrorCode()
|
||||
other["status_code"] = err.StatusCode
|
||||
other["channel_id"] = channelId
|
||||
other["channel_name"] = c.GetString("channel_name")
|
||||
other["channel_type"] = c.GetInt("channel_type")
|
||||
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error.Message, tokenId, 0, false, userGroup, other)
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
openaiErr = relayRequest(c, relayMode, channel)
|
||||
newAPIError = relayRequest(c, relayMode, channel)
|
||||
|
||||
if openaiErr == nil {
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -101,14 +101,14 @@ func Relay(c *gin.Context) {
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if openaiErr != nil {
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
|
||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||
c.JSON(openaiErr.StatusCode, gin.H{
|
||||
"error": openaiErr.Error,
|
||||
if newAPIError != nil {
|
||||
//if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||
// common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
|
||||
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||
//}
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"error": newAPIError.ToOpenAIError(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -127,35 +127,34 @@ func WssRelay(c *gin.Context) {
|
||||
defer ws.Close()
|
||||
|
||||
if err != nil {
|
||||
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
helper.WssError(c, ws, openaiErr.Error)
|
||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
|
||||
return
|
||||
}
|
||||
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||
originalModel := c.GetString("original_model")
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
openaiErr = wssRequest(c, ws, relayMode, channel)
|
||||
newAPIError = wssRequest(c, ws, relayMode, channel)
|
||||
|
||||
if openaiErr == nil {
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -165,12 +164,12 @@ func WssRelay(c *gin.Context) {
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if openaiErr != nil {
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||
helper.WssError(c, ws, openaiErr.Error)
|
||||
if newAPIError != nil {
|
||||
//if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||
//}
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,27 +178,25 @@ func RelayClaude(c *gin.Context) {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
var claudeErr *dto.ClaudeErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
claudeErr = claudeRequest(c, channel)
|
||||
newAPIError = claudeRequest(c, channel)
|
||||
|
||||
if claudeErr == nil {
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
openaiErr := service.ClaudeErrorToOpenAIError(claudeErr)
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -209,30 +206,30 @@ func RelayClaude(c *gin.Context) {
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if claudeErr != nil {
|
||||
claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId)
|
||||
c.JSON(claudeErr.StatusCode, gin.H{
|
||||
if newAPIError != nil {
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"type": "error",
|
||||
"error": claudeErr.Error,
|
||||
"error": newAPIError.ToClaudeError(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relayHandler(c, relayMode)
|
||||
}
|
||||
|
||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relay.WssHelper(c, ws)
|
||||
}
|
||||
|
||||
func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode {
|
||||
func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
@@ -245,7 +242,7 @@ func addUsedChannel(c *gin.Context, channelId int) {
|
||||
c.Set("use_channel", useChannel)
|
||||
}
|
||||
|
||||
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
|
||||
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
|
||||
if retryCount == 0 {
|
||||
autoBan := c.GetBool("auto_ban")
|
||||
autoBanInt := 1
|
||||
@@ -259,19 +256,28 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
||||
AutoBan: &autoBanInt,
|
||||
}, nil
|
||||
}
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
|
||||
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
if newAPIError != nil {
|
||||
return nil, newAPIError
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
|
||||
func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
|
||||
if openaiErr == nil {
|
||||
return false
|
||||
}
|
||||
if openaiErr.LocalError {
|
||||
if types.IsChannelError(openaiErr) {
|
||||
return true
|
||||
}
|
||||
if types.IsSkipRetryError(openaiErr) {
|
||||
return false
|
||||
}
|
||||
if retryTimes <= 0 {
|
||||
@@ -295,7 +301,7 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
||||
}
|
||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||
channelType := c.GetInt("channel_type")
|
||||
if channelType == common.ChannelTypeAnthropic {
|
||||
if channelType == constant.ChannelTypeAnthropic {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -310,12 +316,12 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
||||
return true
|
||||
}
|
||||
|
||||
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
|
||||
if service.ShouldDisableChannel(channelType, err) && autoBan {
|
||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
service.DisableChannel(channelError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -388,9 +394,10 @@ func RelayTask(c *gin.Context) {
|
||||
retryTimes = 0
|
||||
}
|
||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||
channel, newAPIError := getChannel(c, group, originalModel, i)
|
||||
if newAPIError != nil {
|
||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
||||
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
||||
break
|
||||
}
|
||||
channelId = channel.Id
|
||||
@@ -398,9 +405,9 @@ func RelayTask(c *gin.Context) {
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
c.Set("use_channel", useChannel)
|
||||
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
|
||||
requestBody, err := common.GetRequestBody(c)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
taskErr = taskRelayHandler(c, relayMode)
|
||||
}
|
||||
@@ -420,7 +427,7 @@ func RelayTask(c *gin.Context) {
|
||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||||
var err *dto.TaskError
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
|
||||
err = relay.RelayTaskFetch(c, relayMode)
|
||||
default:
|
||||
err = relay.RelayTaskSubmit(c, relayMode)
|
||||
|
||||
@@ -75,6 +75,14 @@ func PostSetup(c *gin.Context) {
|
||||
|
||||
// If root doesn't exist, validate and create admin account
|
||||
if !rootExists {
|
||||
// Validate username length: max 12 characters to align with model.User validation
|
||||
if len(req.Username) > 12 {
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "用户名长度不能超过12个字符",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Validate password
|
||||
if req.Password != req.ConfirmPassword {
|
||||
c.JSON(400, gin.H{
|
||||
|
||||
136
controller/swag_video.go
Normal file
136
controller/swag_video.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// VideoGenerations
|
||||
// @Summary 生成视频
|
||||
// @Description 调用视频生成接口生成视频
|
||||
// @Description 支持多种视频生成服务:
|
||||
// @Description - 可灵AI (Kling): https://app.klingai.com/cn/dev/document-api/apiReference/commonInfo
|
||||
// @Description - 即梦 (Jimeng): https://www.volcengine.com/docs/85621/1538636
|
||||
// @Tags Video
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||
// @Param request body dto.VideoRequest true "视频生成请求参数"
|
||||
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||
// @Router /v1/video/generations [post]
|
||||
func VideoGenerations(c *gin.Context) {
|
||||
}
|
||||
|
||||
// VideoGenerationsTaskId
|
||||
// @Summary 查询视频
|
||||
// @Description 根据任务ID查询视频生成任务的状态和结果
|
||||
// @Tags Video
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param task_id path string true "Task ID"
|
||||
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||
// @Router /v1/video/generations/{task_id} [get]
|
||||
func VideoGenerationsTaskId(c *gin.Context) {
|
||||
}
|
||||
|
||||
// KlingText2VideoGenerations
|
||||
// @Summary 可灵文生视频
|
||||
// @Description 调用可灵AI文生视频接口,生成视频内容
|
||||
// @Tags Video
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||
// @Param request body KlingText2VideoRequest true "视频生成请求参数"
|
||||
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||
// @Router /kling/v1/videos/text2video [post]
|
||||
func KlingText2VideoGenerations(c *gin.Context) {
|
||||
}
|
||||
|
||||
type KlingText2VideoRequest struct {
|
||||
ModelName string `json:"model_name,omitempty" example:"kling-v1"`
|
||||
Prompt string `json:"prompt" binding:"required" example:"A cat playing piano in the garden"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
|
||||
CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
|
||||
Mode string `json:"mode,omitempty" example:"std"`
|
||||
CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
|
||||
Duration string `json:"duration,omitempty" example:"5"`
|
||||
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
|
||||
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-001"`
|
||||
}
|
||||
|
||||
type KlingCameraControl struct {
|
||||
Type string `json:"type,omitempty" example:"simple"`
|
||||
Config *KlingCameraConfig `json:"config,omitempty"`
|
||||
}
|
||||
|
||||
type KlingCameraConfig struct {
|
||||
Horizontal float64 `json:"horizontal,omitempty" example:"2.5"`
|
||||
Vertical float64 `json:"vertical,omitempty" example:"0"`
|
||||
Pan float64 `json:"pan,omitempty" example:"0"`
|
||||
Tilt float64 `json:"tilt,omitempty" example:"0"`
|
||||
Roll float64 `json:"roll,omitempty" example:"0"`
|
||||
Zoom float64 `json:"zoom,omitempty" example:"0"`
|
||||
}
|
||||
|
||||
// KlingImage2VideoGenerations
|
||||
// @Summary 可灵官方-图生视频
|
||||
// @Description 调用可灵AI图生视频接口,生成视频内容
|
||||
// @Tags Video
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||
// @Param request body KlingImage2VideoRequest true "图生视频请求参数"
|
||||
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||
// @Router /kling/v1/videos/image2video [post]
|
||||
func KlingImage2VideoGenerations(c *gin.Context) {
|
||||
}
|
||||
|
||||
type KlingImage2VideoRequest struct {
|
||||
ModelName string `json:"model_name,omitempty" example:"kling-v2-master"`
|
||||
Image string `json:"image" binding:"required" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"`
|
||||
Prompt string `json:"prompt,omitempty" example:"A cat playing piano in the garden"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
|
||||
CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
|
||||
Mode string `json:"mode,omitempty" example:"std"`
|
||||
CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
|
||||
Duration string `json:"duration,omitempty" example:"5"`
|
||||
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
|
||||
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-002"`
|
||||
}
|
||||
|
||||
// KlingImage2videoTaskId godoc
|
||||
// @Summary 可灵任务查询--图生视频
|
||||
// @Description Query the status and result of a Kling video generation task by task ID
|
||||
// @Tags Origin
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param task_id path string true "Task ID"
|
||||
// @Router /kling/v1/videos/image2video/{task_id} [get]
|
||||
func KlingImage2videoTaskId(c *gin.Context) {}
|
||||
|
||||
// KlingText2videoTaskId godoc
|
||||
// @Summary 可灵任务查询--文生视频
|
||||
// @Description Query the status and result of a Kling text-to-video generation task by task ID
|
||||
// @Tags Origin
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param task_id path string true "Task ID"
|
||||
// @Router /kling/v1/videos/text2video/{task_id} [get]
|
||||
func KlingText2videoTaskId(c *gin.Context) {}
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -17,6 +15,9 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
func UpdateTaskBulk() {
|
||||
@@ -75,7 +76,9 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
||||
case constant.TaskPlatformSuno:
|
||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
default:
|
||||
common.SysLog("未知平台")
|
||||
if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -223,10 +226,8 @@ func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool
|
||||
}
|
||||
|
||||
func GetAllTask(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
// 解析其他查询参数
|
||||
@@ -237,25 +238,18 @@ func GetAllTask(c *gin.Context) {
|
||||
Action: c.Query("action"),
|
||||
StartTimestamp: startTimestamp,
|
||||
EndTimestamp: endTimestamp,
|
||||
ChannelID: c.Query("channel_id"),
|
||||
}
|
||||
|
||||
logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
||||
if logs == nil {
|
||||
logs = make([]*model.Task, 0)
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.TaskCountAllTasks(queryParams)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
func GetUserTask(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
@@ -271,14 +265,9 @@ func GetUserTask(c *gin.Context) {
|
||||
EndTimestamp: endTimestamp,
|
||||
}
|
||||
|
||||
logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
||||
if logs == nil {
|
||||
logs = make([]*model.Task, 0)
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.TaskCountAllUserTask(userId, queryParams)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
147
controller/task_video.go
Normal file
147
controller/task_video.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"time"
|
||||
)
|
||||
|
||||
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if errUpdate != nil {
|
||||
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||
}
|
||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||
}
|
||||
adaptor := relay.GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("video adaptor not found")
|
||||
}
|
||||
for _, taskId := range taskIds {
|
||||
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
|
||||
task := taskM[taskId]
|
||||
if task == nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||
return fmt.Errorf("task %s not found", taskId)
|
||||
}
|
||||
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||
"task_id": taskId,
|
||||
"action": task.Action,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
//if resp.StatusCode != http.StatusOK {
|
||||
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
|
||||
//}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
taskResult := &relaycommon.TaskInfo{}
|
||||
// try parse as New API response format
|
||||
var responseItems dto.TaskResponse[model.Task]
|
||||
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||
t := responseItems.Data
|
||||
taskResult.TaskID = t.TaskID
|
||||
taskResult.Status = string(t.Status)
|
||||
taskResult.Url = t.FailReason
|
||||
taskResult.Progress = t.Progress
|
||||
taskResult.Reason = t.FailReason
|
||||
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
} else {
|
||||
task.Data = responseBody
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
if taskResult.Status == "" {
|
||||
return fmt.Errorf("task %s status is empty", taskId)
|
||||
}
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
case model.TaskStatusSubmitted:
|
||||
task.Progress = "10%"
|
||||
case model.TaskStatusQueued:
|
||||
task.Progress = "20%"
|
||||
case model.TaskStatusInProgress:
|
||||
task.Progress = "30%"
|
||||
if task.StartTime == 0 {
|
||||
task.StartTime = now
|
||||
}
|
||||
case model.TaskStatusSuccess:
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Url
|
||||
case model.TaskStatusFailure:
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
|
||||
}
|
||||
if taskResult.Progress != "" {
|
||||
task.Progress = taskResult.Progress
|
||||
}
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysError("UpdateVideoTask task error: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,38 +1,26 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAllTokens(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
size, _ := strconv.Atoi(c.Query("size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
if size <= 0 {
|
||||
size = common.ItemsPerPage
|
||||
} else if size > 100 {
|
||||
size = 100
|
||||
}
|
||||
tokens, err := model.GetAllUserTokens(userId, p*size, size)
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
tokens, err := model.GetAllUserTokens(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": tokens,
|
||||
})
|
||||
total, _ := model.CountUserTokens(userId)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(tokens)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -42,10 +30,7 @@ func SearchTokens(c *gin.Context) {
|
||||
token := c.Query("token")
|
||||
tokens, err := model.SearchUserTokens(userId, keyword, token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -60,18 +45,12 @@ func GetToken(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
userId := c.GetInt("id")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
token, err := model.GetTokenByIds(id, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -87,10 +66,7 @@ func GetTokenStatus(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
token, err := model.GetTokenByIds(tokenId, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
expiredAt := token.ExpiredTime
|
||||
@@ -110,10 +86,7 @@ func AddToken(c *gin.Context) {
|
||||
token := model.Token{}
|
||||
err := c.ShouldBindJSON(&token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(token.Name) > 30 {
|
||||
@@ -148,10 +121,7 @@ func AddToken(c *gin.Context) {
|
||||
}
|
||||
err = cleanToken.Insert()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -166,10 +136,7 @@ func DeleteToken(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
err := model.DeleteTokenById(id, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -185,10 +152,7 @@ func UpdateToken(c *gin.Context) {
|
||||
token := model.Token{}
|
||||
err := c.ShouldBindJSON(&token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(token.Name) > 30 {
|
||||
@@ -200,10 +164,7 @@ func UpdateToken(c *gin.Context) {
|
||||
}
|
||||
cleanToken, err := model.GetTokenByIds(token.Id, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if token.Status == common.TokenStatusEnabled {
|
||||
@@ -237,10 +198,7 @@ func UpdateToken(c *gin.Context) {
|
||||
}
|
||||
err = cleanToken.Update()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -250,3 +208,29 @@ func UpdateToken(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type TokenBatch struct {
|
||||
Ids []int `json:"ids"`
|
||||
}
|
||||
|
||||
func DeleteTokenBatch(c *gin.Context) {
|
||||
tokenBatch := TokenBatch{}
|
||||
if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": count,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -97,14 +97,12 @@ func RequestEpay(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
payType := "wxpay"
|
||||
if req.PaymentMethod == "zfb" {
|
||||
payType = "alipay"
|
||||
}
|
||||
if req.PaymentMethod == "wx" {
|
||||
req.PaymentMethod = "wxpay"
|
||||
payType = "wxpay"
|
||||
|
||||
if !setting.ContainsPayMethod(req.PaymentMethod) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
@@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||
Type: payType,
|
||||
Type: req.PaymentMethod,
|
||||
ServiceTradeNo: tradeNo,
|
||||
Name: fmt.Sprintf("TUC%d", req.Amount),
|
||||
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||
|
||||
275
controller/topup_stripe.go
Normal file
275
controller/topup_stripe.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stripe/stripe-go/v81"
|
||||
"github.com/stripe/stripe-go/v81/checkout/session"
|
||||
"github.com/stripe/stripe-go/v81/webhook"
|
||||
"github.com/thanhpk/randstr"
|
||||
)
|
||||
|
||||
const (
|
||||
PaymentMethodStripe = "stripe"
|
||||
)
|
||||
|
||||
var stripeAdaptor = &StripeAdaptor{}
|
||||
|
||||
type StripePayRequest struct {
|
||||
Amount int64 `json:"amount"`
|
||||
PaymentMethod string `json:"payment_method"`
|
||||
}
|
||||
|
||||
type StripeAdaptor struct {
|
||||
}
|
||||
|
||||
func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
|
||||
if req.Amount < getStripeMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
payMoney := getStripePayMoney(float64(req.Amount), group)
|
||||
if payMoney <= 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||
}
|
||||
|
||||
func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
|
||||
if req.PaymentMethod != PaymentMethodStripe {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
|
||||
return
|
||||
}
|
||||
if req.Amount < getStripeMinTopup() {
|
||||
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
|
||||
return
|
||||
}
|
||||
if req.Amount > 10000 {
|
||||
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
|
||||
|
||||
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4))
|
||||
referenceId := "ref_" + common.Sha1([]byte(reference))
|
||||
|
||||
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount)
|
||||
if err != nil {
|
||||
log.Println("获取Stripe Checkout支付链接失败", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Money: chargedMoney,
|
||||
TradeNo: referenceId,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"pay_link": payLink,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func RequestStripeAmount(c *gin.Context) {
|
||||
var req StripePayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
stripeAdaptor.RequestAmount(c, &req)
|
||||
}
|
||||
|
||||
func RequestStripePay(c *gin.Context) {
|
||||
var req StripePayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
stripeAdaptor.RequestPay(c, &req)
|
||||
}
|
||||
|
||||
func StripeWebhook(c *gin.Context) {
|
||||
payload, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
|
||||
c.AbortWithStatus(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
signature := c.GetHeader("Stripe-Signature")
|
||||
endpointSecret := setting.StripeWebhookSecret
|
||||
event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
|
||||
IgnoreAPIVersionMismatch: true,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Stripe Webhook验签失败: %v\n", err)
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case stripe.EventTypeCheckoutSessionCompleted:
|
||||
sessionCompleted(event)
|
||||
case stripe.EventTypeCheckoutSessionExpired:
|
||||
sessionExpired(event)
|
||||
default:
|
||||
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
|
||||
}
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
func sessionCompleted(event stripe.Event) {
|
||||
customerId := event.GetObjectValue("customer")
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
status := event.GetObjectValue("status")
|
||||
if "complete" != status {
|
||||
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
err := model.Recharge(referenceId, customerId)
|
||||
if err != nil {
|
||||
log.Println(err.Error(), referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
|
||||
currency := strings.ToUpper(event.GetObjectValue("currency"))
|
||||
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
|
||||
}
|
||||
|
||||
func sessionExpired(event stripe.Event) {
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
status := event.GetObjectValue("status")
|
||||
if "expired" != status {
|
||||
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
if len(referenceId) == 0 {
|
||||
log.Println("未提供支付单号")
|
||||
return
|
||||
}
|
||||
|
||||
topUp := model.GetTopUpByTradeNo(referenceId)
|
||||
if topUp == nil {
|
||||
log.Println("充值订单不存在", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
log.Println("充值订单状态错误", referenceId)
|
||||
}
|
||||
|
||||
topUp.Status = common.TopUpStatusExpired
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("充值订单已过期", referenceId)
|
||||
}
|
||||
|
||||
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
|
||||
if !strings.HasPrefix(setting.StripeApiSecret, "sk_") && !strings.HasPrefix(setting.StripeApiSecret, "rk_") {
|
||||
return "", fmt.Errorf("无效的Stripe API密钥")
|
||||
}
|
||||
|
||||
stripe.Key = setting.StripeApiSecret
|
||||
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
ClientReferenceID: stripe.String(referenceId),
|
||||
SuccessURL: stripe.String(setting.ServerAddress + "/log"),
|
||||
CancelURL: stripe.String(setting.ServerAddress + "/topup"),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(setting.StripePriceId),
|
||||
Quantity: stripe.Int64(amount),
|
||||
},
|
||||
},
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
|
||||
}
|
||||
|
||||
if "" == customerId {
|
||||
if "" != email {
|
||||
params.CustomerEmail = stripe.String(email)
|
||||
}
|
||||
|
||||
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
|
||||
} else {
|
||||
params.Customer = stripe.String(customerId)
|
||||
}
|
||||
|
||||
result, err := session.New(params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
func GetChargedAmount(count float64, user model.User) float64 {
|
||||
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
|
||||
if topUpGroupRatio == 0 {
|
||||
topUpGroupRatio = 1
|
||||
}
|
||||
|
||||
return count * topUpGroupRatio
|
||||
}
|
||||
|
||||
func getStripePayMoney(amount float64, group string) float64 {
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / common.QuotaPerUnit
|
||||
}
|
||||
// Using float64 for monetary calculations is acceptable here due to the small amounts involved
|
||||
topupGroupRatio := common.GetTopupGroupRatio(group)
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio
|
||||
return payMoney
|
||||
}
|
||||
|
||||
func getStripeMinTopup() int64 {
|
||||
minTopup := setting.StripeMinTopUp
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
minTopup = minTopup * int(common.QuotaPerUnit)
|
||||
}
|
||||
return int64(minTopup)
|
||||
}
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/setting/console_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -14,45 +14,25 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type UptimeKumaMonitor struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
const (
|
||||
requestTimeout = 30 * time.Second
|
||||
httpTimeout = 10 * time.Second
|
||||
uptimeKeySuffix = "_24"
|
||||
apiStatusPath = "/api/status-page/"
|
||||
apiHeartbeatPath = "/api/status-page/heartbeat/"
|
||||
)
|
||||
|
||||
type UptimeKumaGroup struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Weight int `json:"weight"`
|
||||
MonitorList []UptimeKumaMonitor `json:"monitorList"`
|
||||
}
|
||||
|
||||
type UptimeKumaHeartbeat struct {
|
||||
Status int `json:"status"`
|
||||
Time string `json:"time"`
|
||||
Msg string `json:"msg"`
|
||||
Ping *float64 `json:"ping"`
|
||||
}
|
||||
|
||||
type UptimeKumaStatusResponse struct {
|
||||
PublicGroupList []UptimeKumaGroup `json:"publicGroupList"`
|
||||
}
|
||||
|
||||
type UptimeKumaHeartbeatResponse struct {
|
||||
HeartbeatList map[string][]UptimeKumaHeartbeat `json:"heartbeatList"`
|
||||
UptimeList map[string]float64 `json:"uptimeList"`
|
||||
}
|
||||
|
||||
type MonitorStatus struct {
|
||||
type Monitor struct {
|
||||
Name string `json:"name"`
|
||||
Uptime float64 `json:"uptime"`
|
||||
Status int `json:"status"`
|
||||
Group string `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrUpstreamNon200 = errors.New("upstream non-200")
|
||||
ErrTimeout = errors.New("context deadline exceeded")
|
||||
)
|
||||
type UptimeGroupResult struct {
|
||||
CategoryName string `json:"categoryName"`
|
||||
Monitors []Monitor `json:"monitors"`
|
||||
}
|
||||
|
||||
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
@@ -62,108 +42,113 @@ func getAndDecode(ctx context.Context, client *http.Client, url string, dest int
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
||||
return ErrTimeout
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return ErrUpstreamNon200
|
||||
return errors.New("non-200 status")
|
||||
}
|
||||
|
||||
return json.NewDecoder(resp.Body).Decode(dest)
|
||||
}
|
||||
|
||||
func GetUptimeKumaStatus(c *gin.Context) {
|
||||
common.OptionMapRWMutex.RLock()
|
||||
uptimeKumaUrl := common.OptionMap["UptimeKumaUrl"]
|
||||
slug := common.OptionMap["UptimeKumaSlug"]
|
||||
common.OptionMapRWMutex.RUnlock()
|
||||
|
||||
if uptimeKumaUrl == "" || slug == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": []MonitorStatus{},
|
||||
})
|
||||
return
|
||||
func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[string]interface{}) UptimeGroupResult {
|
||||
url, _ := groupConfig["url"].(string)
|
||||
slug, _ := groupConfig["slug"].(string)
|
||||
categoryName, _ := groupConfig["categoryName"].(string)
|
||||
|
||||
result := UptimeGroupResult{
|
||||
CategoryName: categoryName,
|
||||
Monitors: []Monitor{},
|
||||
}
|
||||
|
||||
if url == "" || slug == "" {
|
||||
return result
|
||||
}
|
||||
|
||||
uptimeKumaUrl = strings.TrimSuffix(uptimeKumaUrl, "/")
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client := &http.Client{}
|
||||
|
||||
statusPageUrl := fmt.Sprintf("%s/api/status-page/%s", uptimeKumaUrl, slug)
|
||||
heartbeatUrl := fmt.Sprintf("%s/api/status-page/heartbeat/%s", uptimeKumaUrl, slug)
|
||||
|
||||
var (
|
||||
statusData UptimeKumaStatusResponse
|
||||
heartbeatData UptimeKumaHeartbeatResponse
|
||||
)
|
||||
baseURL := strings.TrimSuffix(url, "/")
|
||||
|
||||
var statusData struct {
|
||||
PublicGroupList []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
MonitorList []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
} `json:"monitorList"`
|
||||
} `json:"publicGroupList"`
|
||||
}
|
||||
|
||||
var heartbeatData struct {
|
||||
HeartbeatList map[string][]struct {
|
||||
Status int `json:"status"`
|
||||
} `json:"heartbeatList"`
|
||||
UptimeList map[string]float64 `json:"uptimeList"`
|
||||
}
|
||||
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, statusPageUrl, &statusData)
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
||||
})
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, heartbeatUrl, &heartbeatData)
|
||||
})
|
||||
if g.Wait() != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
switch err {
|
||||
case ErrUpstreamNon200:
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "上游接口出现问题",
|
||||
})
|
||||
case ErrTimeout:
|
||||
c.JSON(http.StatusRequestTimeout, gin.H{
|
||||
"success": false,
|
||||
"message": "请求上游接口超时",
|
||||
})
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
for _, pg := range statusData.PublicGroupList {
|
||||
if len(pg.MonitorList) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, m := range pg.MonitorList {
|
||||
monitor := Monitor{
|
||||
Name: m.Name,
|
||||
Group: pg.Name,
|
||||
}
|
||||
|
||||
monitorID := strconv.Itoa(m.ID)
|
||||
|
||||
if uptime, exists := heartbeatData.UptimeList[monitorID+uptimeKeySuffix]; exists {
|
||||
monitor.Uptime = uptime
|
||||
}
|
||||
|
||||
if heartbeats, exists := heartbeatData.HeartbeatList[monitorID]; exists && len(heartbeats) > 0 {
|
||||
monitor.Status = heartbeats[0].Status
|
||||
}
|
||||
|
||||
result.Monitors = append(result.Monitors, monitor)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func GetUptimeKumaStatus(c *gin.Context) {
|
||||
groups := console_setting.GetUptimeKumaGroups()
|
||||
if len(groups) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": []UptimeGroupResult{}})
|
||||
return
|
||||
}
|
||||
|
||||
var monitors []MonitorStatus
|
||||
for _, group := range statusData.PublicGroupList {
|
||||
for _, monitor := range group.MonitorList {
|
||||
monitorStatus := MonitorStatus{
|
||||
Name: monitor.Name,
|
||||
Uptime: 0.0,
|
||||
Status: 0,
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
|
||||
defer cancel()
|
||||
|
||||
uptimeKey := fmt.Sprintf("%d_24", monitor.ID)
|
||||
if uptime, exists := heartbeatData.UptimeList[uptimeKey]; exists {
|
||||
monitorStatus.Uptime = uptime
|
||||
}
|
||||
|
||||
heartbeatKey := fmt.Sprintf("%d", monitor.ID)
|
||||
if heartbeats, exists := heartbeatData.HeartbeatList[heartbeatKey]; exists && len(heartbeats) > 0 {
|
||||
latestHeartbeat := heartbeats[0]
|
||||
monitorStatus.Status = latestHeartbeat.Status
|
||||
}
|
||||
|
||||
monitors = append(monitors, monitorStatus)
|
||||
}
|
||||
client := &http.Client{Timeout: httpTimeout}
|
||||
results := make([]UptimeGroupResult, len(groups))
|
||||
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
for i, group := range groups {
|
||||
i, group := i, group
|
||||
g.Go(func() error {
|
||||
results[i] = fetchGroupData(gCtx, client, group)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": monitors,
|
||||
})
|
||||
|
||||
g.Wait()
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAllQuotaDates(c *gin.Context) {
|
||||
@@ -13,10 +15,7 @@ func GetAllQuotaDates(c *gin.Context) {
|
||||
username := c.Query("username")
|
||||
dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -41,10 +40,7 @@ func GetUserQuotaDates(c *gin.Context) {
|
||||
}
|
||||
dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
@@ -187,10 +188,7 @@ func Register(c *gin.Context) {
|
||||
cleanUser.Email = user.Email
|
||||
}
|
||||
if err := cleanUser.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -226,6 +224,9 @@ func Register(c *gin.Context) {
|
||||
UnlimitedQuota: true,
|
||||
ModelLimitsEnabled: false,
|
||||
}
|
||||
if setting.DefaultUseAutoGroup {
|
||||
token.Group = "auto"
|
||||
}
|
||||
if err := token.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -243,83 +244,45 @@ func Register(c *gin.Context) {
|
||||
}
|
||||
|
||||
func GetAllUsers(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
users, total, err := model.GetAllUsers((p-1)*pageSize, pageSize)
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
users, total, err := model.GetAllUsers(pageInfo)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": users,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(users)
|
||||
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func SearchUsers(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
group := c.Query("group")
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
startIdx := (p - 1) * pageSize
|
||||
users, total, err := model.SearchUsers(keyword, group, startIdx, pageSize)
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
users, total, err := model.SearchUsers(keyword, group, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": users,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(users)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func GetUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
@@ -342,10 +305,7 @@ func GenerateAccessToken(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// get rand int 28-32
|
||||
@@ -370,10 +330,7 @@ func GenerateAccessToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -393,18 +350,12 @@ func TransferAffQuota(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
tran := TransferAffQuotaRequest{}
|
||||
if err := c.ShouldBindJSON(&tran); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
err = user.TransferAffQuotaToQuota(tran.Quota)
|
||||
@@ -425,10 +376,7 @@ func GetAffCode(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if user.AffCode == "" {
|
||||
@@ -453,12 +401,12 @@ func GetSelf(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
||||
user.Remark = ""
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -474,16 +422,13 @@ func GetUserModels(c *gin.Context) {
|
||||
}
|
||||
user, err := model.GetUserCache(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
groups := setting.GetUserUsableGroups(user.Group)
|
||||
var models []string
|
||||
for group := range groups {
|
||||
for _, g := range model.GetGroupModels(group) {
|
||||
for _, g := range model.GetGroupEnabledModels(group) {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
}
|
||||
@@ -519,10 +464,7 @@ func UpdateUser(c *gin.Context) {
|
||||
}
|
||||
originUser, err := model.GetUserById(updatedUser.Id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
@@ -545,10 +487,7 @@ func UpdateUser(c *gin.Context) {
|
||||
}
|
||||
updatePassword := updatedUser.Password != ""
|
||||
if err := updatedUser.Edit(updatePassword); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if originUser.Quota != updatedUser.Quota {
|
||||
@@ -594,17 +533,11 @@ func UpdateSelf(c *gin.Context) {
|
||||
}
|
||||
updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if err := cleanUser.Update(updatePassword); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -635,18 +568,12 @@ func checkUpdatePassword(originalPassword string, newPassword string, userId int
|
||||
func DeleteUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
originUser, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
@@ -681,10 +608,7 @@ func DeleteSelf(c *gin.Context) {
|
||||
|
||||
err := model.DeleteUserById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -730,10 +654,7 @@ func CreateUser(c *gin.Context) {
|
||||
DisplayName: user.DisplayName,
|
||||
}
|
||||
if err := cleanUser.Insert(0); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -843,10 +764,7 @@ func ManageUser(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
clearUser := model.User{
|
||||
@@ -878,20 +796,14 @@ func EmailBind(c *gin.Context) {
|
||||
}
|
||||
err := user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.Email = email
|
||||
// no need to check if this email already taken, because we have used verification code to check it
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -913,19 +825,13 @@ func TopUp(c *gin.Context) {
|
||||
req := topUpRequest{}
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
quota, err := model.Redeem(req.Key, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -943,6 +849,7 @@ type UpdateUserSettingRequest struct {
|
||||
WebhookSecret string `json:"webhook_secret,omitempty"`
|
||||
NotificationEmail string `json:"notification_email,omitempty"`
|
||||
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
||||
RecordIpLog bool `json:"record_ip_log"`
|
||||
}
|
||||
|
||||
func UpdateUserSetting(c *gin.Context) {
|
||||
@@ -956,7 +863,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 验证预警类型
|
||||
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
|
||||
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的预警类型",
|
||||
@@ -974,7 +881,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 如果是webhook类型,验证webhook地址
|
||||
if req.QuotaWarningType == constant.NotifyTypeWebhook {
|
||||
if req.QuotaWarningType == dto.NotifyTypeWebhook {
|
||||
if req.WebhookUrl == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -993,7 +900,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 如果是邮件类型,验证邮箱地址
|
||||
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||
// 验证邮箱格式
|
||||
if !strings.Contains(req.NotificationEmail, "@") {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -1007,31 +914,29 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
user, err := model.GetUserById(userId, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 构建设置
|
||||
settings := map[string]interface{}{
|
||||
constant.UserSettingNotifyType: req.QuotaWarningType,
|
||||
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
|
||||
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
|
||||
settings := dto.UserSetting{
|
||||
NotifyType: req.QuotaWarningType,
|
||||
QuotaWarningThreshold: req.QuotaWarningThreshold,
|
||||
AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel,
|
||||
RecordIpLog: req.RecordIpLog,
|
||||
}
|
||||
|
||||
// 如果是webhook类型,添加webhook相关设置
|
||||
if req.QuotaWarningType == constant.NotifyTypeWebhook {
|
||||
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
|
||||
if req.QuotaWarningType == dto.NotifyTypeWebhook {
|
||||
settings.WebhookUrl = req.WebhookUrl
|
||||
if req.WebhookSecret != "" {
|
||||
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
|
||||
settings.WebhookSecret = req.WebhookSecret
|
||||
}
|
||||
}
|
||||
|
||||
// 如果提供了通知邮箱,添加到设置中
|
||||
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
|
||||
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||
settings.NotificationEmail = req.NotificationEmail
|
||||
}
|
||||
|
||||
// 更新用户设置
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type wechatLoginResponse struct {
|
||||
@@ -150,19 +151,13 @@ func WeChatBind(c *gin.Context) {
|
||||
}
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.WeChatId = wechatId
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -16,7 +16,7 @@ services:
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
- TZ=Asia/Shanghai
|
||||
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
||||
# - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache,请取消注释
|
||||
# - STREAMING_TIMEOUT=120 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
|
||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
||||
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||
|
||||
195
docs/api/web_api.md
Normal file
195
docs/api/web_api.md
Normal file
@@ -0,0 +1,195 @@
|
||||
# One API – Web 界面后端接口文档
|
||||
|
||||
> 本文档汇总了 **One API** 后端提供给前端 Web 界面的全部 REST 接口(不含 *Relay* 相关接口)。
|
||||
>
|
||||
> 接口前缀统一为 `https://<your-domain>`,以下仅列出 **路径**、**HTTP 方法**、**鉴权要求** 与 **功能简介**。
|
||||
>
|
||||
> 鉴权级别说明:
|
||||
> * **公开** – 不需要登录即可调用
|
||||
> * **用户** – 需携带用户 Token(`middleware.UserAuth`)
|
||||
> * **管理员** – 需管理员 Token(`middleware.AdminAuth`)
|
||||
> * **Root** – 仅限最高权限 Root 用户(`middleware.RootAuth`)
|
||||
|
||||
---
|
||||
|
||||
## 1. 初始化 / 系统状态
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/setup | 公开 | 获取系统初始化状态 |
|
||||
| POST | /api/setup | 公开 | 完成首次安装向导 |
|
||||
| GET | /api/status | 公开 | 获取运行状态摘要 |
|
||||
| GET | /api/uptime/status | 公开 | Uptime-Kuma 兼容状态探针 |
|
||||
| GET | /api/status/test | 管理员 | 测试后端与依赖组件是否正常 |
|
||||
|
||||
## 2. 公共信息
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/models | 用户 | 获取前端可用模型列表 |
|
||||
| GET | /api/notice | 公开 | 获取公告栏内容 |
|
||||
| GET | /api/about | 公开 | 关于页面信息 |
|
||||
| GET | /api/home_page_content | 公开 | 首页自定义内容 |
|
||||
| GET | /api/pricing | 可匿名/用户 | 价格与套餐信息 |
|
||||
| GET | /api/ratio_config | 公开 | 模型倍率配置(仅公开字段) |
|
||||
|
||||
## 3. 邮件 / 身份验证
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/verification | 公开 (限流) | 发送邮箱验证邮件 |
|
||||
| GET | /api/reset_password | 公开 (限流) | 发送重置密码邮件 |
|
||||
| POST | /api/user/reset | 公开 | 提交重置密码请求 |
|
||||
|
||||
## 4. OAuth / 第三方登录
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 |
|
||||
| GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 |
|
||||
| GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 |
|
||||
| GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 |
|
||||
| GET | /api/oauth/wechat/bind | 公开 | 微信账户绑定 |
|
||||
| GET | /api/oauth/email/bind | 公开 | 邮箱绑定 |
|
||||
| GET | /api/oauth/telegram/login | 公开 | Telegram 登录 |
|
||||
| GET | /api/oauth/telegram/bind | 公开 | Telegram 账户绑定 |
|
||||
| GET | /api/oauth/state | 公开 | 获取随机 state(防 CSRF) |
|
||||
|
||||
## 5. 用户模块
|
||||
### 5.1 账号注册/登录
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| POST | /api/user/register | 公开 | 注册新账号 |
|
||||
| POST | /api/user/login | 公开 | 用户登录 |
|
||||
| GET | /api/user/logout | 用户 | 退出登录 |
|
||||
| GET | /api/user/epay/notify | 公开 | Epay 支付回调 |
|
||||
| GET | /api/user/groups | 公开 | 列出所有分组(无鉴权版) |
|
||||
|
||||
### 5.2 用户自身操作 (需登录)
|
||||
| GET | /api/user/self/groups | 用户 | 获取自己所在分组 |
|
||||
| GET | /api/user/self | 用户 | 获取个人资料 |
|
||||
| GET | /api/user/models | 用户 | 获取模型可见性 |
|
||||
| PUT | /api/user/self | 用户 | 修改个人资料 |
|
||||
| DELETE | /api/user/self | 用户 | 注销账号 |
|
||||
| GET | /api/user/token | 用户 | 生成用户级别 Access Token |
|
||||
| GET | /api/user/aff | 用户 | 获取推广码信息 |
|
||||
| POST | /api/user/topup | 用户 | 余额直充 |
|
||||
| POST | /api/user/pay | 用户 | 提交支付订单 |
|
||||
| POST | /api/user/amount | 用户 | 余额支付 |
|
||||
| POST | /api/user/aff_transfer | 用户 | 推广额度转账 |
|
||||
| PUT | /api/user/setting | 用户 | 更新用户设置 |
|
||||
|
||||
### 5.3 管理员用户管理
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/user/ | 管理员 | 获取全部用户列表 |
|
||||
| GET | /api/user/search | 管理员 | 搜索用户 |
|
||||
| GET | /api/user/:id | 管理员 | 获取单个用户信息 |
|
||||
| POST | /api/user/ | 管理员 | 创建用户 |
|
||||
| POST | /api/user/manage | 管理员 | 冻结/重置等管理操作 |
|
||||
| PUT | /api/user/ | 管理员 | 更新用户 |
|
||||
| DELETE | /api/user/:id | 管理员 | 删除用户 |
|
||||
|
||||
## 6. 站点选项 (Root)
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/option/ | Root | 获取全局配置 |
|
||||
| PUT | /api/option/ | Root | 更新全局配置 |
|
||||
| POST | /api/option/rest_model_ratio | Root | 重置模型倍率 |
|
||||
| POST | /api/option/migrate_console_setting | Root | 迁移旧版控制台配置 |
|
||||
|
||||
## 7. 模型倍率同步 (Root)
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/ratio_sync/channels | Root | 获取可同步渠道列表 |
|
||||
| POST | /api/ratio_sync/fetch | Root | 从上游拉取倍率 |
|
||||
|
||||
## 8. 渠道管理 (管理员)
|
||||
| 方法 | 路径 | 说明 |
|
||||
|------|------|------|
|
||||
| GET | /api/channel/ | 获取渠道列表 |
|
||||
| GET | /api/channel/search | 搜索渠道 |
|
||||
| GET | /api/channel/models | 查询渠道模型能力 |
|
||||
| GET | /api/channel/models_enabled | 查询启用模型能力 |
|
||||
| GET | /api/channel/:id | 获取单个渠道 |
|
||||
| GET | /api/channel/test | 批量测试渠道连通性 |
|
||||
| GET | /api/channel/test/:id | 单个渠道测试 |
|
||||
| GET | /api/channel/update_balance | 批量刷新余额 |
|
||||
| GET | /api/channel/update_balance/:id | 单个刷新余额 |
|
||||
| POST | /api/channel/ | 新增渠道 |
|
||||
| PUT | /api/channel/ | 更新渠道 |
|
||||
| DELETE | /api/channel/disabled | 删除已禁用渠道 |
|
||||
| POST | /api/channel/tag/disabled | 批量禁用标签渠道 |
|
||||
| POST | /api/channel/tag/enabled | 批量启用标签渠道 |
|
||||
| PUT | /api/channel/tag | 编辑渠道标签 |
|
||||
| DELETE | /api/channel/:id | 删除渠道 |
|
||||
| POST | /api/channel/batch | 批量删除渠道 |
|
||||
| POST | /api/channel/fix | 修复渠道能力表 |
|
||||
| GET | /api/channel/fetch_models/:id | 拉取单渠道模型 |
|
||||
| POST | /api/channel/fetch_models | 拉取全部渠道模型 |
|
||||
| POST | /api/channel/batch/tag | 批量设置渠道标签 |
|
||||
| GET | /api/channel/tag/models | 根据标签获取模型 |
|
||||
| POST | /api/channel/copy/:id | 复制渠道 |
|
||||
|
||||
## 9. Token 管理
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/token/ | 用户 | 获取全部 Token |
|
||||
| GET | /api/token/search | 用户 | 搜索 Token |
|
||||
| GET | /api/token/:id | 用户 | 获取单个 Token |
|
||||
| POST | /api/token/ | 用户 | 创建 Token |
|
||||
| PUT | /api/token/ | 用户 | 更新 Token |
|
||||
| DELETE | /api/token/:id | 用户 | 删除 Token |
|
||||
| POST | /api/token/batch | 用户 | 批量删除 Token |
|
||||
|
||||
## 10. 兑换码管理 (管理员)
|
||||
| 方法 | 路径 | 说明 |
|
||||
|------|------|------|
|
||||
| GET | /api/redemption/ | 获取兑换码列表 |
|
||||
| GET | /api/redemption/search | 搜索兑换码 |
|
||||
| GET | /api/redemption/:id | 获取单个兑换码 |
|
||||
| POST | /api/redemption/ | 创建兑换码 |
|
||||
| PUT | /api/redemption/ | 更新兑换码 |
|
||||
| DELETE | /api/redemption/invalid | 删除无效兑换码 |
|
||||
| DELETE | /api/redemption/:id | 删除兑换码 |
|
||||
|
||||
## 11. 日志
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/log/ | 管理员 | 获取全部日志 |
|
||||
| DELETE | /api/log/ | 管理员 | 删除历史日志 |
|
||||
| GET | /api/log/stat | 管理员 | 日志统计 |
|
||||
| GET | /api/log/self/stat | 用户 | 我的日志统计 |
|
||||
| GET | /api/log/search | 管理员 | 搜索全部日志 |
|
||||
| GET | /api/log/self | 用户 | 获取我的日志 |
|
||||
| GET | /api/log/self/search | 用户 | 搜索我的日志 |
|
||||
| GET | /api/log/token | 公开 | 根据 Token 查询日志(支持 CORS) |
|
||||
|
||||
## 12. 数据统计
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/data/ | 管理员 | 全站用量按日期统计 |
|
||||
| GET | /api/data/self | 用户 | 我的用量按日期统计 |
|
||||
|
||||
## 13. 分组
|
||||
| GET | /api/group/ | 管理员 | 获取全部分组列表 |
|
||||
|
||||
## 14. Midjourney 任务
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/mj/self | 用户 | 获取自己的 MJ 任务 |
|
||||
| GET | /api/mj/ | 管理员 | 获取全部 MJ 任务 |
|
||||
|
||||
## 15. 任务中心
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/task/self | 用户 | 获取我的任务 |
|
||||
| GET | /api/task/ | 管理员 | 获取全部任务 |
|
||||
|
||||
## 16. 账户计费面板 (Dashboard)
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /dashboard/billing/subscription | 用户 Token | 获取订阅额度信息 |
|
||||
| GET | /v1/dashboard/billing/subscription | 同上 | 兼容 OpenAI SDK 路径 |
|
||||
| GET | /dashboard/billing/usage | 用户 Token | 获取使用量信息 |
|
||||
| GET | /v1/dashboard/billing/usage | 同上 | 兼容 OpenAI SDK 路径 |
|
||||
|
||||
---
|
||||
|
||||
> **更新日期**:2025.07.17
|
||||
55
docs/images/cherry-studio.svg
Normal file
55
docs/images/cherry-studio.svg
Normal file
@@ -0,0 +1,55 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg id="_图层_2" data-name="图层_2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 198.45 66.73">
|
||||
<defs>
|
||||
<style>
|
||||
.cls-1 {
|
||||
fill: #ea5e5d;
|
||||
}
|
||||
|
||||
.cls-2 {
|
||||
fill: #23af69;
|
||||
}
|
||||
|
||||
.cls-3 {
|
||||
fill: #ea5756;
|
||||
}
|
||||
</style>
|
||||
</defs>
|
||||
<g id="_图层_1-2" data-name="图层_1">
|
||||
<g>
|
||||
<g>
|
||||
<g>
|
||||
<path class="cls-1" d="M16.72,51.21c-4.45,0-8.64-1.78-11.81-5.01-3.17-3.23-4.91-7.51-4.91-12.04s1.74-8.81,4.91-12.04,7.36-5.01,11.81-5.01,8.71,1.82,11.82,4.99c2.32,2.36,2.32,6.2,0,8.56-2.32,2.36-6.08,2.36-8.4,0-.9-.92-2.15-1.45-3.43-1.45-2.63,0-4.85,2.26-4.85,4.94s2.22,4.94,4.85,4.94c1.28,0,2.52-.53,3.43-1.45,2.32-2.36,6.08-2.36,8.4,0,2.32,2.36,2.32,6.2,0,8.56-3.11,3.17-7.42,4.99-11.82,4.99Z"/>
|
||||
<path class="cls-1" d="M32.05,66.73c-4.45,0-8.64-1.78-11.81-5.01s-4.91-7.51-4.91-12.04,1.79-8.88,4.9-12.06c2.32-2.36,6.08-2.36,8.4,0,2.32,2.36,2.32,6.2,0,8.56-.9.92-1.42,2.19-1.42,3.49,0,2.68,2.22,4.94,4.85,4.94s4.85-2.26,4.85-4.94c0-.95-.23-2.31-1.32-3.43-3.13-3.19-4.92-7.6-4.92-12.09s1.74-8.81,4.91-12.04,7.36-5.01,11.81-5.01,8.64,1.78,11.81,5.01,4.91,7.51,4.91,12.04-1.79,8.88-4.9,12.06c-2.32,2.36-6.08,2.36-8.4,0-2.32-2.36-2.32-6.2,0-8.56.9-.92,1.42-2.19,1.42-3.49,0-2.68-2.22-4.94-4.85-4.94s-4.85,2.26-4.85,4.94c0,1.31.53,2.6,1.45,3.53,3.1,3.16,4.8,7.42,4.8,11.99s-1.74,8.81-4.91,12.04c-3.17,3.23-7.36,5.01-11.81,5.01Z"/>
|
||||
</g>
|
||||
<path class="cls-2" d="M32.05,19.09l-9.72-9.12c-1.5-1.4-1.57-3.75-.17-5.25,1.4-1.49,3.75-1.57,5.25-.17l3.89,3.65,5.53-6.83c1.29-1.59,3.63-1.84,5.22-.55,1.59,1.29,1.84,3.63.55,5.22l-10.56,13.05Z"/>
|
||||
</g>
|
||||
<g>
|
||||
<path class="cls-3" d="M93.93,24.6l.55-.39c.69-.4,1.17-.61,1.46-.61.63,0,1.3.57,2.03,1.7.44.71.67,1.27.67,1.7s-.14.78-.41,1.06c-.27.28-.59.54-.96.76-.36.22-.71.43-1.05.64-.33.2-1.02.47-2.05.79-1.03.32-2.03.49-2.99.49s-1.93-.13-2.91-.38c-.98-.25-1.99-.68-3.03-1.27-1.04-.6-1.98-1.32-2.81-2.18-.83-.86-1.51-1.96-2.05-3.31-.54-1.35-.8-2.81-.8-4.38s.26-3.01.79-4.29c.53-1.28,1.2-2.35,2.02-3.19.82-.84,1.75-1.54,2.81-2.11,1.98-1.09,3.97-1.64,5.98-1.64.95,0,1.92.15,2.9.44.98.29,1.72.59,2.23.9l.73.42c.36.22.65.4.85.55.53.42.79.91.79,1.44s-.21,1.1-.64,1.68c-.79,1.09-1.5,1.64-2.12,1.64-.36,0-.88-.22-1.55-.67-.85-.69-1.98-1.03-3.4-1.03-1.31,0-2.61.46-3.88,1.36-.61.44-1.11,1.07-1.52,1.88-.4.81-.61,1.72-.61,2.75s.2,1.94.61,2.75c.4.81.92,1.45,1.55,1.91,1.23.89,2.52,1.34,3.85,1.34.63,0,1.22-.08,1.77-.24.56-.16.96-.32,1.2-.49Z"/>
|
||||
<path class="cls-3" d="M114.38,9.07c.16-.3.43-.52.82-.64.38-.12.87-.18,1.46-.18s1.05.05,1.4.15c.34.1.61.22.79.36.18.14.32.34.42.61.1.34.15.87.15,1.58v16.84c0,.47-.02.81-.05,1.05-.03.23-.13.5-.29.8-.28.55-1.07.82-2.37.82-1.42,0-2.25-.37-2.49-1.12-.12-.34-.18-.87-.18-1.58v-6.16h-8.04v6.19c0,.47-.02.81-.05,1.05-.03.23-.13.5-.29.8-.28.55-1.07.82-2.37.82-1.42,0-2.25-.37-2.49-1.12-.12-.34-.18-.87-.18-1.58V10.92c0-.46.02-.81.05-1.05.03-.23.13-.5.29-.8.28-.55,1.07-.82,2.37-.82,1.42,0,2.25.37,2.52,1.12.1.34.15.87.15,1.58v6.19h8.04v-6.22c0-.46.02-.81.05-1.05.03-.23.13-.5.29-.8Z"/>
|
||||
<path class="cls-3" d="M127.21,25.1h9.34c.47,0,.81.02,1.05.05.23.03.5.13.8.29.55.28.82,1.07.82,2.37,0,1.42-.37,2.25-1.12,2.49-.34.12-.87.18-1.58.18h-12.01c-1.42,0-2.25-.38-2.49-1.15-.12-.32-.18-.84-.18-1.55V10.9c0-1.03.19-1.73.58-2.11.38-.37,1.11-.56,2.18-.56h11.95c.47,0,.81.02,1.05.05.23.03.5.13.8.29.55.28.82,1.07.82,2.37,0,1.42-.37,2.25-1.12,2.49-.34.12-.87.18-1.58.18h-9.31v3.06h6.01c.46,0,.81.02,1.05.05.23.03.5.13.8.29.55.28.82,1.07.82,2.37,0,1.42-.38,2.25-1.15,2.49-.34.12-.87.18-1.58.18h-5.95v3.06Z"/>
|
||||
<path class="cls-3" d="M196.96,8.79c.99.69,1.49,1.35,1.49,2,0,.38-.23.92-.7,1.61l-6.55,9.8v5.79c0,.47-.02.81-.05,1.05-.03.23-.13.5-.29.8-.16.3-.43.52-.82.64-.38.12-.9.18-1.55.18s-1.16-.06-1.55-.18c-.38-.12-.66-.34-.82-.65-.16-.31-.26-.59-.29-.82-.03-.23-.05-.59-.05-1.08v-5.73l-6.55-9.8c-.47-.69-.7-1.22-.7-1.61,0-.65.44-1.27,1.33-1.87.89-.6,1.53-.9,1.91-.9s.69.08.91.24c.34.22.71.64,1.09,1.24l4.7,7.52,4.7-7.52c.38-.61.72-1.01,1-1.2s.61-.29.99-.29.97.25,1.77.76Z"/>
|
||||
<g>
|
||||
<path class="cls-3" d="M81.93,56.63c-.53-.65-.79-1.23-.79-1.74s.43-1.2,1.3-2.05c.51-.49,1.04-.73,1.61-.73s1.36.51,2.37,1.52c.28.34.69.67,1.21.99.53.31,1.01.47,1.46.47,1.88,0,2.82-.77,2.82-2.31,0-.46-.26-.85-.77-1.17-.52-.31-1.16-.54-1.93-.68-.77-.14-1.6-.37-2.49-.68-.89-.31-1.72-.68-2.49-1.11-.77-.42-1.41-1.1-1.93-2.02-.52-.92-.77-2.03-.77-3.32,0-1.78.66-3.33,1.99-4.66s3.13-1.99,5.42-1.99c1.21,0,2.32.16,3.32.47,1,.31,1.69.63,2.08.96l.76.58c.63.59.94,1.08.94,1.49s-.24.96-.73,1.67c-.69,1.01-1.4,1.52-2.12,1.52-.42,0-.95-.2-1.58-.61-.06-.04-.18-.14-.35-.3-.17-.16-.33-.29-.47-.39-.42-.26-.97-.39-1.62-.39s-1.2.16-1.64.47c-.43.31-.65.75-.65,1.3s.26,1.01.77,1.35c.52.34,1.16.58,1.93.7.77.12,1.61.31,2.52.56.91.25,1.75.56,2.52.93.77.36,1.41,1,1.93,1.9.52.9.77,2.01.77,3.32s-.26,2.47-.79,3.47c-.53,1-1.21,1.77-2.06,2.32-1.64,1.07-3.39,1.61-5.25,1.61-.95,0-1.85-.12-2.7-.35-.85-.23-1.54-.52-2.06-.86-1.07-.65-1.82-1.27-2.24-1.88l-.27-.33Z"/>
|
||||
<path class="cls-3" d="M100.74,37.49h16.87c.65,0,1.12.08,1.43.23.3.15.51.39.61.71.1.32.15.75.15,1.27s-.05.95-.15,1.26c-.1.31-.27.53-.52.65-.36.18-.88.27-1.55.27h-5.79v15.26c0,.47-.02.81-.05,1.03s-.12.48-.27.77c-.15.29-.42.5-.8.62-.38.12-.89.18-1.52.18s-1.13-.06-1.5-.18c-.37-.12-.64-.33-.79-.62-.15-.29-.24-.56-.27-.79-.03-.23-.05-.58-.05-1.05v-15.23h-5.82c-.65,0-1.12-.08-1.43-.23-.3-.15-.51-.39-.61-.71-.1-.32-.15-.75-.15-1.27s.05-.95.15-1.26c.1-.31.27-.53.52-.65.36-.18.88-.27,1.55-.27Z"/>
|
||||
<path class="cls-3" d="M135.99,38.34c.2-.32.5-.55.88-.67.38-.12.86-.18,1.44-.18s1.04.05,1.38.15c.34.1.61.22.79.36.18.14.31.35.39.64.12.34.18.87.18,1.58v9.16c0,2.67-.83,5.1-2.49,7.28-.81,1.03-1.85,1.87-3.12,2.5s-2.68.96-4.23.96-2.95-.32-4.22-.97c-1.26-.65-2.29-1.5-3.08-2.55-1.64-2.14-2.46-4.57-2.46-7.28v-9.13c0-.49.02-.84.05-1.08.03-.23.13-.5.29-.8.16-.3.43-.52.82-.64.38-.12.9-.18,1.55-.18s1.16.06,1.55.18c.38.12.65.33.79.64.24.47.36,1.1.36,1.91v9.1c0,1.23.3,2.41.91,3.52.3.57.76,1.02,1.37,1.36.61.34,1.32.52,2.15.52,1.48,0,2.58-.55,3.31-1.64.73-1.09,1.09-2.36,1.09-3.79v-9.28c0-.79.1-1.34.3-1.67Z"/>
|
||||
<path class="cls-3" d="M146.18,37.49l5.61.03c2.93,0,5.51,1.06,7.74,3.17,2.22,2.11,3.34,4.71,3.34,7.8s-1.09,5.73-3.26,7.93c-2.17,2.2-4.81,3.31-7.9,3.31h-5.55c-1.23,0-2-.25-2.31-.76-.24-.42-.36-1.07-.36-1.94v-16.87c0-.49.02-.84.05-1.06s.13-.49.29-.79c.28-.55,1.07-.82,2.37-.82ZM151.79,54.35c1.46,0,2.77-.54,3.94-1.62,1.17-1.08,1.76-2.44,1.76-4.08s-.57-3.01-1.71-4.11c-1.14-1.1-2.48-1.65-4.02-1.65h-2.91v11.47h2.94Z"/>
|
||||
<path class="cls-3" d="M164.84,40.19c0-.46.02-.81.05-1.05.03-.23.13-.5.29-.8.28-.55,1.07-.82,2.37-.82,1.42,0,2.25.37,2.52,1.12.1.34.15.87.15,1.58v16.87c0,.49-.02.84-.05,1.06s-.13.49-.29.79c-.28.55-1.07.82-2.37.82-1.42,0-2.25-.38-2.49-1.15-.12-.32-.18-.84-.18-1.55v-16.87Z"/>
|
||||
<path class="cls-3" d="M183.07,37.24c2.99,0,5.59,1.08,7.8,3.25,2.2,2.16,3.31,4.85,3.31,8.05s-1.05,5.94-3.16,8.19c-2.1,2.26-4.69,3.38-7.77,3.38s-5.69-1.11-7.84-3.34c-2.15-2.22-3.23-4.87-3.23-7.95,0-1.68.3-3.25.91-4.72.61-1.47,1.42-2.7,2.43-3.69,1.01-.99,2.17-1.77,3.49-2.34,1.31-.57,2.67-.85,4.07-.85ZM177.55,48.68c0,1.8.58,3.26,1.74,4.38,1.16,1.12,2.46,1.68,3.9,1.68s2.73-.55,3.88-1.64c1.15-1.09,1.73-2.56,1.73-4.4s-.58-3.32-1.74-4.43c-1.16-1.11-2.46-1.67-3.9-1.67s-2.73.56-3.88,1.68c-1.15,1.12-1.73,2.58-1.73,4.38Z"/>
|
||||
</g>
|
||||
<g>
|
||||
<path class="cls-3" d="M176.92,11.06c-.03-.23-.13-.5-.29-.8-.28-.55-1.07-.82-2.37-.82h-6.55c-1.78,0-3.51.65-5.19,1.94-.81.63-1.48,1.48-2,2.55-.53,1.07-.79,2.27-.79,3.58,0,2.29.76,4.17,2.28,5.64-.44,1.07-1.13,2.66-2.06,4.76-.3.73-.45,1.25-.45,1.58,0,.77.63,1.42,1.88,1.94.65.28,1.17.43,1.56.43s.72-.1.97-.29c.25-.19.44-.39.56-.59.2-.38.99-2.21,2.37-5.49l.94.06h3.82v3.43c0,.47.02.81.05,1.05.03.23.13.5.29.8.28.55,1.07.82,2.37.82,1.42,0,2.25-.37,2.49-1.12.12-.34.18-.87.18-1.58V12.11c0-.46-.02-.81-.05-1.05ZM172.81,19.44c-.09.14-.48.77-1.24.91-.2.04-.37.03-.48.02-.02.14-.04.26-.06.38-.16.83-.38,1.05-.57,1.07-.29.05-.51-.35-.93-.9-.23.01-.46.02-.69.02-.51,0-1.01-.03-1.49-.09-.25-.03-.5-.07-.74-.11-1.18-.32-2.03-1.27-2.03-2.4v-1.37c0-1.13.86-2.08,2.03-2.4.24-.04.49-.08.74-.11.48-.06.98-.09,1.49-.09s1.01.03,1.49.09c.25.03.5.07.74.11.6.16,1.12.49,1.49.93.34.41.55.92.55,1.47v1.37c0,.23-.01.66-.29,1.1Z"/>
|
||||
<circle class="cls-2" cx="167.24" cy="17.67" r=".49"/>
|
||||
<circle class="cls-2" cx="168.88" cy="17.71" r=".49"/>
|
||||
<circle class="cls-2" cx="170.59" cy="17.71" r=".49"/>
|
||||
</g>
|
||||
<g>
|
||||
<path class="cls-3" d="M141.01,8.24c.03-.23.13-.5.29-.8.28-.55,1.07-.82,2.37-.82h6.55c1.78,0,3.51.65,5.19,1.94.81.63,1.48,1.48,2,2.55.53,1.07.79,2.27.79,3.58,0,2.29-.76,4.17-2.28,5.64.44,1.07,1.13,2.66,2.06,4.76.3.73.45,1.25.45,1.58,0,.77-.63,1.42-1.88,1.94-.65.28-1.17.43-1.56.43s-.72-.1-.97-.29c-.25-.19-.44-.39-.56-.59-.2-.38-.99-2.21-2.37-5.49l-.94.06h-3.82v3.43c0,.47-.02.81-.05,1.05-.03.23-.13.5-.29.8-.28.55-1.07.82-2.37.82-1.42,0-2.25-.37-2.49-1.12-.12-.34-.18-.87-.18-1.58V9.28c0-.46.02-.81.05-1.05ZM145.12,16.62c.09.14.48.77,1.24.91.2.04.37.03.48.02.02.14.04.26.06.38.16.83.38,1.05.57,1.07.29.05.51-.35.93-.9.23.01.46.02.69.02.51,0,1.01-.03,1.49-.09.25-.03.5-.07.74-.11,1.18-.32,2.03-1.27,2.03-2.4v-1.37c0-1.13-.86-2.08-2.03-2.4-.24-.04-.49-.08-.74-.11-.48-.06-.98-.09-1.49-.09s-1.01.03-1.49.09c-.25.03-.5.07-.74.11-.6.16-1.12.49-1.49.93-.34.41-.55.92-.55,1.47v1.37c0,.23.01.66.29,1.1Z"/>
|
||||
<circle class="cls-2" cx="150.69" cy="14.84" r=".49"/>
|
||||
<circle class="cls-2" cx="149.05" cy="14.89" r=".49"/>
|
||||
<circle class="cls-2" cx="147.35" cy="14.89" r=".49"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 9.5 KiB |
BIN
docs/images/pku.png
Normal file
BIN
docs/images/pku.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 50 KiB |
1
docs/images/ucloud.svg
Normal file
1
docs/images/ucloud.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 7.9 KiB |
9
dto/channel_settings.go
Normal file
9
dto/channel_settings.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package dto
|
||||
|
||||
type ChannelSettings struct {
|
||||
ForceFormat bool `json:"force_format,omitempty"`
|
||||
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
|
||||
Proxy string `json:"proxy"`
|
||||
PassThroughBodyEnabled bool `json:"pass_through_body_enabled,omitempty"`
|
||||
SystemPrompt string `json:"system_prompt,omitempty"`
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package dto
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
@@ -158,6 +159,27 @@ type InputSchema struct {
|
||||
Required any `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeWebSearchTool struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
MaxUses int `json:"max_uses,omitempty"`
|
||||
UserLocation *ClaudeWebSearchUserLocation `json:"user_location,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeWebSearchUserLocation struct {
|
||||
Type string `json:"type"`
|
||||
Timezone string `json:"timezone,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeToolChoice struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
@@ -176,9 +198,69 @@ type ClaudeRequest struct {
|
||||
Thinking *Thinking `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
// AddTool 添加工具到请求中
|
||||
func (c *ClaudeRequest) AddTool(tool any) {
|
||||
if c.Tools == nil {
|
||||
c.Tools = make([]any, 0)
|
||||
}
|
||||
|
||||
switch tools := c.Tools.(type) {
|
||||
case []any:
|
||||
c.Tools = append(tools, tool)
|
||||
default:
|
||||
// 如果Tools不是[]any类型,重新初始化为[]any
|
||||
c.Tools = []any{tool}
|
||||
}
|
||||
}
|
||||
|
||||
// GetTools 获取工具列表
|
||||
func (c *ClaudeRequest) GetTools() []any {
|
||||
if c.Tools == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch tools := c.Tools.(type) {
|
||||
case []any:
|
||||
return tools
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessTools 处理工具列表,支持类型断言
|
||||
func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) {
|
||||
var normalTools []*Tool
|
||||
var webSearchTools []*ClaudeWebSearchTool
|
||||
|
||||
for _, tool := range tools {
|
||||
switch t := tool.(type) {
|
||||
case *Tool:
|
||||
normalTools = append(normalTools, t)
|
||||
case *ClaudeWebSearchTool:
|
||||
webSearchTools = append(webSearchTools, t)
|
||||
case Tool:
|
||||
normalTools = append(normalTools, &t)
|
||||
case ClaudeWebSearchTool:
|
||||
webSearchTools = append(webSearchTools, &t)
|
||||
default:
|
||||
// 未知类型,跳过
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return normalTools, webSearchTools
|
||||
}
|
||||
|
||||
type Thinking struct {
|
||||
Type string `json:"type"`
|
||||
BudgetTokens int `json:"budget_tokens"`
|
||||
BudgetTokens *int `json:"budget_tokens,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Thinking) GetBudgetTokens() int {
|
||||
if c.BudgetTokens == nil {
|
||||
return 0
|
||||
}
|
||||
return *c.BudgetTokens
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) IsStringSystem() bool {
|
||||
@@ -221,7 +303,7 @@ type ClaudeResponse struct {
|
||||
Completion string `json:"completion,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Error *ClaudeError `json:"error,omitempty"`
|
||||
Error *types.ClaudeError `json:"error,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
||||
@@ -243,8 +325,13 @@ func (c *ClaudeResponse) GetIndex() int {
|
||||
}
|
||||
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use"`
|
||||
}
|
||||
|
||||
type ClaudeServerToolUse struct {
|
||||
WebSearchRequests int `json:"web_search_requests"`
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ type ImageRequest struct {
|
||||
Background string `json:"background,omitempty"`
|
||||
Moderation string `json:"moderation,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
|
||||
12
dto/error.go
12
dto/error.go
@@ -1,5 +1,7 @@
|
||||
package dto
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
type OpenAIError struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
@@ -14,11 +16,11 @@ type OpenAIErrorWithStatusCode struct {
|
||||
}
|
||||
|
||||
type GeneralErrorResponse struct {
|
||||
Error OpenAIError `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Error types.OpenAIError `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Header struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"header"`
|
||||
|
||||
@@ -57,6 +57,8 @@ type MidjourneyDto struct {
|
||||
StartTime int64 `json:"startTime"`
|
||||
FinishTime int64 `json:"finishTime"`
|
||||
ImageUrl string `json:"imageUrl"`
|
||||
VideoUrl string `json:"videoUrl"`
|
||||
VideoUrls []ImgUrls `json:"videoUrls"`
|
||||
Status string `json:"status"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"failReason"`
|
||||
@@ -65,6 +67,10 @@ type MidjourneyDto struct {
|
||||
Properties *Properties `json:"properties"`
|
||||
}
|
||||
|
||||
type ImgUrls struct {
|
||||
Url string `json:"url"`
|
||||
}
|
||||
|
||||
type MidjourneyStatus struct {
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
@@ -7,15 +7,15 @@ import (
|
||||
)
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
JsonSchema json.RawMessage `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
type FormatJsonSchema struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Schema any `json:"schema,omitempty"`
|
||||
Strict any `json:"strict,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Schema any `json:"schema,omitempty"`
|
||||
Strict json.RawMessage `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
@@ -53,19 +53,35 @@ type GeneralOpenAIRequest struct {
|
||||
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||
Audio json.RawMessage `json:"audio,omitempty"`
|
||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
||||
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
|
||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||
SearchParameters any `json:"search_parameters,omitempty"` //xai
|
||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||
// OpenRouter Params
|
||||
Usage json.RawMessage `json:"usage,omitempty"`
|
||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||
// Ali Qwen Params
|
||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||
// 用匿名参数接收额外参数,例如ollama的think参数在此接收
|
||||
Extra map[string]json.RawMessage `json:"-"`
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||
result := make(map[string]any)
|
||||
data, _ := common.EncodeJson(r)
|
||||
_ = common.DecodeJson(data, &result)
|
||||
data, _ := common.Marshal(r)
|
||||
_ = common.Unmarshal(data, &result)
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
|
||||
if strings.HasPrefix(r.Model, "o") {
|
||||
if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
|
||||
return "developer"
|
||||
}
|
||||
}
|
||||
return "system"
|
||||
}
|
||||
|
||||
type ToolCallRequest struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
@@ -598,26 +614,29 @@ type WebSearchOptions struct {
|
||||
UserLocation json.RawMessage `json:"user_location,omitempty"`
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/responses/create
|
||||
type OpenAIResponsesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
Include json.RawMessage `json:"include,omitempty"`
|
||||
Instructions json.RawMessage `json:"instructions,omitempty"`
|
||||
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Store bool `json:"store,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Tools []ResponsesToolsCall `json:"tools,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Truncation string `json:"truncation,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
Include json.RawMessage `json:"include,omitempty"`
|
||||
Instructions json.RawMessage `json:"instructions,omitempty"`
|
||||
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Store bool `json:"store,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Tools []map[string]any `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Truncation string `json:"truncation,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
|
||||
Prompt json.RawMessage `json:"prompt,omitempty"`
|
||||
}
|
||||
|
||||
type Reasoning struct {
|
||||
@@ -625,21 +644,23 @@ type Reasoning struct {
|
||||
Summary string `json:"summary,omitempty"`
|
||||
}
|
||||
|
||||
type ResponsesToolsCall struct {
|
||||
Type string `json:"type"`
|
||||
// Web Search
|
||||
UserLocation json.RawMessage `json:"user_location,omitempty"`
|
||||
SearchContextSize string `json:"search_context_size,omitempty"`
|
||||
// File Search
|
||||
VectorStoreIds []string `json:"vector_store_ids,omitempty"`
|
||||
MaxNumResults uint `json:"max_num_results,omitempty"`
|
||||
Filters json.RawMessage `json:"filters,omitempty"`
|
||||
// Computer Use
|
||||
DisplayWidth uint `json:"display_width,omitempty"`
|
||||
DisplayHeight uint `json:"display_height,omitempty"`
|
||||
Environment string `json:"environment,omitempty"`
|
||||
// Function
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
}
|
||||
//type ResponsesToolsCall struct {
|
||||
// Type string `json:"type"`
|
||||
// // Web Search
|
||||
// UserLocation json.RawMessage `json:"user_location,omitempty"`
|
||||
// SearchContextSize string `json:"search_context_size,omitempty"`
|
||||
// // File Search
|
||||
// VectorStoreIds []string `json:"vector_store_ids,omitempty"`
|
||||
// MaxNumResults uint `json:"max_num_results,omitempty"`
|
||||
// Filters json.RawMessage `json:"filters,omitempty"`
|
||||
// // Computer Use
|
||||
// DisplayWidth uint `json:"display_width,omitempty"`
|
||||
// DisplayHeight uint `json:"display_height,omitempty"`
|
||||
// Environment string `json:"environment,omitempty"`
|
||||
// // Function
|
||||
// Name string `json:"name,omitempty"`
|
||||
// Description string `json:"description,omitempty"`
|
||||
// Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
// Function json.RawMessage `json:"function,omitempty"`
|
||||
// Container json.RawMessage `json:"container,omitempty"`
|
||||
//}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package dto
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type SimpleResponse struct {
|
||||
Usage `json:"usage"`
|
||||
@@ -26,9 +29,9 @@ type OpenAITextResponse struct {
|
||||
Id string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Created any `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
@@ -45,6 +48,19 @@ type OpenAIEmbeddingResponse struct {
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type FlexibleEmbeddingResponseItem struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
Embedding any `json:"embedding"`
|
||||
}
|
||||
|
||||
type FlexibleEmbeddingResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []FlexibleEmbeddingResponseItem `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseChoice struct {
|
||||
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
|
||||
Logprobs *any `json:"logprobs"`
|
||||
@@ -178,6 +194,8 @@ type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
|
||||
// OpenRouter Params
|
||||
Cost any `json:"cost,omitempty"`
|
||||
}
|
||||
|
||||
type InputTokenDetails struct {
|
||||
@@ -195,28 +213,28 @@ type OutputTokenDetails struct {
|
||||
}
|
||||
|
||||
type OpenAIResponsesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
||||
Instructions string `json:"instructions"`
|
||||
MaxOutputTokens int `json:"max_output_tokens"`
|
||||
Model string `json:"model"`
|
||||
Output []ResponsesOutput `json:"output"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
||||
PreviousResponseID string `json:"previous_response_id"`
|
||||
Reasoning *Reasoning `json:"reasoning"`
|
||||
Store bool `json:"store"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
ToolChoice string `json:"tool_choice"`
|
||||
Tools []ResponsesToolsCall `json:"tools"`
|
||||
TopP float64 `json:"top_p"`
|
||||
Truncation string `json:"truncation"`
|
||||
Usage *Usage `json:"usage"`
|
||||
User json.RawMessage `json:"user"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
||||
Instructions string `json:"instructions"`
|
||||
MaxOutputTokens int `json:"max_output_tokens"`
|
||||
Model string `json:"model"`
|
||||
Output []ResponsesOutput `json:"output"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
||||
PreviousResponseID string `json:"previous_response_id"`
|
||||
Reasoning *Reasoning `json:"reasoning"`
|
||||
Store bool `json:"store"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
ToolChoice string `json:"tool_choice"`
|
||||
Tools []map[string]any `json:"tools"`
|
||||
TopP float64 `json:"top_p"`
|
||||
Truncation string `json:"truncation"`
|
||||
Usage *Usage `json:"usage"`
|
||||
User json.RawMessage `json:"user"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
}
|
||||
|
||||
type IncompleteDetails struct {
|
||||
|
||||
@@ -1,26 +1,11 @@
|
||||
package dto
|
||||
|
||||
type OpenAIModelPermission struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
AllowCreateEngine bool `json:"allow_create_engine"`
|
||||
AllowSampling bool `json:"allow_sampling"`
|
||||
AllowLogprobs bool `json:"allow_logprobs"`
|
||||
AllowSearchIndices bool `json:"allow_search_indices"`
|
||||
AllowView bool `json:"allow_view"`
|
||||
AllowFineTuning bool `json:"allow_fine_tuning"`
|
||||
Organization string `json:"organization"`
|
||||
Group *string `json:"group"`
|
||||
IsBlocking bool `json:"is_blocking"`
|
||||
}
|
||||
import "one-api/constant"
|
||||
|
||||
type OpenAIModels struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Permission []OpenAIModelPermission `json:"permission"`
|
||||
Root string `json:"root"`
|
||||
Parent *string `json:"parent"`
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||
}
|
||||
|
||||
38
dto/ratio_sync.go
Normal file
38
dto/ratio_sync.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package dto
|
||||
|
||||
type UpstreamDTO struct {
|
||||
ID int `json:"id,omitempty"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
type UpstreamRequest struct {
|
||||
ChannelIDs []int64 `json:"channel_ids"`
|
||||
Upstreams []UpstreamDTO `json:"upstreams"`
|
||||
Timeout int `json:"timeout"`
|
||||
}
|
||||
|
||||
// TestResult 上游测试连通性结果
|
||||
type TestResult struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// DifferenceItem 差异项
|
||||
// Current 为本地值,可能为 nil
|
||||
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
|
||||
|
||||
type DifferenceItem struct {
|
||||
Current interface{} `json:"current"`
|
||||
Upstreams map[string]interface{} `json:"upstreams"`
|
||||
Confidence map[string]bool `json:"confidence"`
|
||||
}
|
||||
|
||||
type SyncableChannel struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
package dto
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
const (
|
||||
RealtimeEventTypeError = "error"
|
||||
RealtimeEventTypeSessionUpdate = "session.update"
|
||||
@@ -23,12 +25,12 @@ type RealtimeEvent struct {
|
||||
EventId string `json:"event_id"`
|
||||
Type string `json:"type"`
|
||||
//PreviousItemId string `json:"previous_item_id"`
|
||||
Session *RealtimeSession `json:"session,omitempty"`
|
||||
Item *RealtimeItem `json:"item,omitempty"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Response *RealtimeResponse `json:"response,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
Session *RealtimeSession `json:"session,omitempty"`
|
||||
Item *RealtimeItem `json:"item,omitempty"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
Response *RealtimeResponse `json:"response,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
}
|
||||
|
||||
type RealtimeResponse struct {
|
||||
|
||||
@@ -4,7 +4,7 @@ type RerankRequest struct {
|
||||
Documents []any `json:"documents"`
|
||||
Query string `json:"query"`
|
||||
Model string `json:"model"`
|
||||
TopN int `json:"top_n"`
|
||||
TopN int `json:"top_n,omitempty"`
|
||||
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
|
||||
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
||||
|
||||
16
dto/user_settings.go
Normal file
16
dto/user_settings.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package dto
|
||||
|
||||
type UserSetting struct {
|
||||
NotifyType string `json:"notify_type,omitempty"` // QuotaWarningType 额度预警类型
|
||||
QuotaWarningThreshold float64 `json:"quota_warning_threshold,omitempty"` // QuotaWarningThreshold 额度预警阈值
|
||||
WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
|
||||
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
|
||||
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
|
||||
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
||||
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
|
||||
}
|
||||
|
||||
var (
|
||||
NotifyTypeEmail = "email" // Email 邮件
|
||||
NotifyTypeWebhook = "webhook" // Webhook
|
||||
)
|
||||
47
dto/video.go
Normal file
47
dto/video.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package dto
|
||||
|
||||
type VideoRequest struct {
|
||||
Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
|
||||
Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
|
||||
Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
|
||||
Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
|
||||
Width int `json:"width" example:"512"` // Video width
|
||||
Height int `json:"height" example:"512"` // Video height
|
||||
Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
|
||||
Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
|
||||
N int `json:"n,omitempty" example:"1"` // Number of videos to generate
|
||||
ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
|
||||
User string `json:"user,omitempty" example:"user-1234"` // User identifier
|
||||
Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
|
||||
}
|
||||
|
||||
// VideoResponse 视频生成提交任务后的响应
|
||||
type VideoResponse struct {
|
||||
TaskId string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// VideoTaskResponse 查询视频生成任务状态的响应
|
||||
type VideoTaskResponse struct {
|
||||
TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID
|
||||
Status string `json:"status" example:"succeeded"` // 任务状态
|
||||
Url string `json:"url,omitempty"` // 视频资源URL(成功时)
|
||||
Format string `json:"format,omitempty" example:"mp4"` // 视频格式
|
||||
Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据
|
||||
Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时)
|
||||
}
|
||||
|
||||
// VideoTaskMetadata 视频任务元数据
|
||||
type VideoTaskMetadata struct {
|
||||
Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长
|
||||
Fps int `json:"fps" example:"30"` // 实际帧率
|
||||
Width int `json:"width" example:"512"` // 实际宽度
|
||||
Height int `json:"height" example:"512"` // 实际高度
|
||||
Seed int `json:"seed" example:"20231234"` // 使用的随机种子
|
||||
}
|
||||
|
||||
// VideoTaskError 视频任务错误信息
|
||||
type VideoTaskError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
4
go.mod
4
go.mod
@@ -27,10 +27,13 @@ require (
|
||||
github.com/samber/lo v1.39.0
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/shopspring/decimal v1.4.0
|
||||
github.com/stripe/stripe-go/v81 v81.4.0
|
||||
github.com/thanhpk/randstr v1.0.6
|
||||
github.com/tiktoken-go/tokenizer v0.6.2
|
||||
golang.org/x/crypto v0.35.0
|
||||
golang.org/x/image v0.23.0
|
||||
golang.org/x/net v0.35.0
|
||||
golang.org/x/sync v0.11.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
gorm.io/gorm v1.25.2
|
||||
@@ -84,7 +87,6 @@ require (
|
||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||
golang.org/x/arch v0.12.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
golang.org/x/text v0.22.0 // indirect
|
||||
google.golang.org/protobuf v1.34.2 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -195,6 +195,10 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
|
||||
github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
|
||||
github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
|
||||
github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
|
||||
github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
|
||||
github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
@@ -224,6 +228,7 @@ golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSO
|
||||
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
|
||||
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -232,6 +237,7 @@ golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
1054
i18n/zh-cn.json
Normal file
1054
i18n/zh-cn.json
Normal file
File diff suppressed because it is too large
Load Diff
96
main.go
96
main.go
@@ -12,7 +12,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/router"
|
||||
"one-api/service"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
@@ -32,14 +32,13 @@ var buildFS embed.FS
|
||||
var indexPage []byte
|
||||
|
||||
func main() {
|
||||
err := godotenv.Load(".env")
|
||||
|
||||
err := InitResources()
|
||||
if err != nil {
|
||||
common.SysLog("Support for .env file is disabled: " + err.Error())
|
||||
common.FatalLog("failed to initialize resources: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
common.LoadEnv()
|
||||
|
||||
common.SetupLogger()
|
||||
common.SysLog("New API " + common.Version + " started")
|
||||
if os.Getenv("GIN_MODE") != "debug" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -47,19 +46,7 @@ func main() {
|
||||
if common.DebugEnabled {
|
||||
common.SysLog("running in debug mode")
|
||||
}
|
||||
// Initialize SQL Database
|
||||
err = model.InitDB()
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize database: " + err.Error())
|
||||
}
|
||||
|
||||
model.CheckSetup()
|
||||
|
||||
// Initialize SQL Database
|
||||
err = model.InitLogDB()
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize database: " + err.Error())
|
||||
}
|
||||
defer func() {
|
||||
err := model.CloseDB()
|
||||
if err != nil {
|
||||
@@ -67,21 +54,6 @@ func main() {
|
||||
}
|
||||
}()
|
||||
|
||||
// Initialize Redis
|
||||
err = common.InitRedisClient()
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize Redis: " + err.Error())
|
||||
}
|
||||
|
||||
// Initialize model settings
|
||||
operation_setting.InitRatioSettings()
|
||||
// Initialize constants
|
||||
constant.InitEnv()
|
||||
// Initialize options
|
||||
model.InitOptionMap()
|
||||
|
||||
service.InitTokenEncoders()
|
||||
|
||||
if common.RedisEnabled {
|
||||
// for compatibility with old versions
|
||||
common.MemoryCacheEnabled = true
|
||||
@@ -96,19 +68,21 @@ func main() {
|
||||
if r := recover(); r != nil {
|
||||
common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
|
||||
// Retry once
|
||||
_, fixErr := model.FixAbility()
|
||||
_, _, fixErr := model.FixAbility()
|
||||
if fixErr != nil {
|
||||
common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
|
||||
common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
|
||||
}
|
||||
}
|
||||
}()
|
||||
model.InitChannelCache()
|
||||
}()
|
||||
|
||||
go model.SyncOptions(common.SyncFrequency)
|
||||
go model.SyncChannelCache(common.SyncFrequency)
|
||||
}
|
||||
|
||||
// 热更新配置
|
||||
go model.SyncOptions(common.SyncFrequency)
|
||||
|
||||
// 数据看板
|
||||
go model.UpdateQuotaData()
|
||||
|
||||
@@ -184,3 +158,53 @@ func main() {
|
||||
common.FatalLog("failed to start HTTP server: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func InitResources() error {
|
||||
// Initialize resources here if needed
|
||||
// This is a placeholder function for future resource initialization
|
||||
err := godotenv.Load(".env")
|
||||
if err != nil {
|
||||
common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
|
||||
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
|
||||
}
|
||||
|
||||
// 加载环境变量
|
||||
common.InitEnv()
|
||||
|
||||
common.SetupLogger()
|
||||
|
||||
// Initialize model settings
|
||||
ratio_setting.InitRatioSettings()
|
||||
|
||||
service.InitHttpClient()
|
||||
|
||||
service.InitTokenEncoders()
|
||||
|
||||
// Initialize SQL Database
|
||||
err = model.InitDB()
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize database: " + err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
model.CheckSetup()
|
||||
|
||||
// Initialize options, should after model.InitDB()
|
||||
model.InitOptionMap()
|
||||
|
||||
// 初始化模型
|
||||
model.GetPricing()
|
||||
|
||||
// Initialize SQL Database
|
||||
err = model.InitLogDB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize Redis
|
||||
err = common.InitRedisClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
@@ -121,7 +122,20 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
c.Set("role", role)
|
||||
c.Set("id", id)
|
||||
c.Set("group", session.Get("group"))
|
||||
c.Set("user_group", session.Get("group"))
|
||||
c.Set("use_access_token", useAccessToken)
|
||||
|
||||
//userCache, err := model.GetUserCache(id.(int))
|
||||
//if err != nil {
|
||||
// c.JSON(http.StatusOK, gin.H{
|
||||
// "success": false,
|
||||
// "message": err.Error(),
|
||||
// })
|
||||
// c.Abort()
|
||||
// return
|
||||
//}
|
||||
//userCache.WriteContext(c)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
|
||||
@@ -184,7 +198,7 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
// gemini api 从query中获取key
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||
skKey := c.Query("key")
|
||||
if skKey != "" {
|
||||
c.Request.Header.Set("Authorization", "Bearer "+skKey)
|
||||
@@ -233,30 +247,41 @@ func TokenAuth() func(c *gin.Context) {
|
||||
|
||||
userCache.WriteContext(c)
|
||||
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_key", token.Key)
|
||||
c.Set("token_name", token.Name)
|
||||
c.Set("token_unlimited_quota", token.UnlimitedQuota)
|
||||
if !token.UnlimitedQuota {
|
||||
c.Set("token_quota", token.RemainQuota)
|
||||
}
|
||||
if token.ModelLimitsEnabled {
|
||||
c.Set("token_model_limit_enabled", true)
|
||||
c.Set("token_model_limit", token.GetModelLimitsMap())
|
||||
} else {
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
c.Set("allow_ips", token.GetIpLimitsMap())
|
||||
c.Set("token_group", token.Group)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
} else {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return
|
||||
}
|
||||
err = SetupContextForToken(c, token, parts...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
|
||||
if token == nil {
|
||||
return fmt.Errorf("token is nil")
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_key", token.Key)
|
||||
c.Set("token_name", token.Name)
|
||||
c.Set("token_unlimited_quota", token.UnlimitedQuota)
|
||||
if !token.UnlimitedQuota {
|
||||
c.Set("token_quota", token.RemainQuota)
|
||||
}
|
||||
if token.ModelLimitsEnabled {
|
||||
c.Set("token_model_limit_enabled", true)
|
||||
c.Set("token_model_limit", token.GetModelLimitsMap())
|
||||
} else {
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
c.Set("allow_ips", token.GetIpLimitsMap())
|
||||
c.Set("token_group", token.Group)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
} else {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return fmt.Errorf("普通用户不支持指定渠道")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -20,11 +22,12 @@ import (
|
||||
|
||||
type ModelRequest struct {
|
||||
Model string `json:"model"`
|
||||
Group string `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
allowIpsMap := c.GetStringMap("allow_ips")
|
||||
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
|
||||
if len(allowIpsMap) != 0 {
|
||||
clientIp := c.ClientIP()
|
||||
if _, ok := allowIpsMap[clientIp]; !ok {
|
||||
@@ -33,14 +36,14 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
var channel *model.Channel
|
||||
channelId, ok := c.Get("specific_channel_id")
|
||||
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
|
||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
||||
return
|
||||
}
|
||||
userGroup := c.GetString(constant.ContextKeyUserGroup)
|
||||
tokenGroup := c.GetString("token_group")
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if tokenGroup != "" {
|
||||
// check common.UserUsableGroups[userGroup]
|
||||
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||
@@ -48,13 +51,15 @@ func Distribute() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// check group in common.GroupRatio
|
||||
if !setting.ContainsGroupRatio(tokenGroup) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
||||
if tokenGroup != "auto" {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
}
|
||||
}
|
||||
userGroup = tokenGroup
|
||||
}
|
||||
c.Set("group", userGroup)
|
||||
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
@@ -73,9 +78,9 @@ func Distribute() func(c *gin.Context) {
|
||||
} else {
|
||||
// Select a channel for the user
|
||||
// check token model mapping
|
||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
if modelLimitEnable {
|
||||
s, ok := c.Get("token_model_limit")
|
||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||
var tokenModelLimit map[string]bool
|
||||
if ok {
|
||||
tokenModelLimit = s.(map[string]bool)
|
||||
@@ -95,25 +100,33 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
|
||||
if shouldSelectChannel {
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
|
||||
if modelRequest.Model == "" {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "未指定模型名称,模型名称不能为空")
|
||||
return
|
||||
}
|
||||
var selectGroup string
|
||||
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
showGroup := userGroup
|
||||
if userGroup == "auto" {
|
||||
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
||||
}
|
||||
// 如果错误,而且渠道为空,说明是没有可用渠道
|
||||
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(数据库一致性已被破坏,distributor): %s", showGroup, modelRequest.Model, err.Error())
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
//if channel != nil {
|
||||
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
// message = "数据库一致性已被破坏,请联系管理员"
|
||||
//}
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
|
||||
return
|
||||
}
|
||||
if channel == nil {
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||
c.Next()
|
||||
}
|
||||
@@ -162,7 +175,17 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
relayMode := relayconstant.RelayModeUnknown
|
||||
if c.Request.Method == http.MethodPost {
|
||||
relayMode = relayconstant.RelayModeVideoSubmit
|
||||
} else if c.Request.Method == http.MethodGet {
|
||||
relayMode = relayconstant.RelayModeVideoFetchByID
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
||||
relayMode := relayconstant.RelayModeGemini
|
||||
modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
|
||||
@@ -210,47 +233,67 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
|
||||
// playground chat completions
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
if err != nil {
|
||||
return nil, false, errors.New("无效的请求, " + err.Error())
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
|
||||
}
|
||||
return &modelRequest, shouldSelectChannel, nil
|
||||
}
|
||||
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
|
||||
c.Set("original_model", modelName) // for retry
|
||||
if channel == nil {
|
||||
return
|
||||
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
c.Set("channel_type", channel.Type)
|
||||
c.Set("channel_create_time", channel.CreatedTime)
|
||||
c.Set("channel_setting", channel.GetSetting())
|
||||
c.Set("param_override", channel.GetParamOverride())
|
||||
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
|
||||
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
|
||||
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
|
||||
}
|
||||
c.Set("auto_ban", channel.GetAutoBan())
|
||||
c.Set("model_mapping", channel.GetModelMapping())
|
||||
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
|
||||
|
||||
key, index, newAPIError := channel.GetNextEnabledKey()
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
|
||||
}
|
||||
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||
|
||||
// TODO: api_version统一
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
case constant.ChannelTypeAzure:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeVertexAi:
|
||||
case constant.ChannelTypeVertexAi:
|
||||
c.Set("region", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
case constant.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
case constant.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
case constant.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
case common.ChannelCloudflare:
|
||||
case constant.ChannelCloudflare:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeMokaAI:
|
||||
case constant.ChannelTypeMokaAI:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeCoze:
|
||||
case constant.ChannelTypeCoze:
|
||||
c.Set("bot_id", channel.Other)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
|
||||
|
||||
51
middleware/kling_adapter.go
Normal file
51
middleware/kling_adapter.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func KlingRequestConvert() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
var originalReq map[string]interface{}
|
||||
if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Support both model_name and model fields
|
||||
model, _ := originalReq["model_name"].(string)
|
||||
if model == "" {
|
||||
model, _ = originalReq["model"].(string)
|
||||
}
|
||||
prompt, _ := originalReq["prompt"].(string)
|
||||
|
||||
unifiedReq := map[string]interface{}{
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"metadata": originalReq,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(unifiedReq)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite request body and path
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
|
||||
c.Request.URL.Path = "/v1/video/generations"
|
||||
if image, ok := originalReq["image"]; !ok || image == "" {
|
||||
c.Set("action", constant.TaskActionTextGenerate)
|
||||
}
|
||||
|
||||
// We have to reset the request body for the next handlers
|
||||
c.Set(common.KeyRequestBody, jsonData)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -177,9 +177,9 @@ func ModelRequestRateLimit() func(c *gin.Context) {
|
||||
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
||||
|
||||
// 获取分组
|
||||
group := c.GetString("token_group")
|
||||
group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if group == "" {
|
||||
group = c.GetString(constant.ContextKeyUserGroup)
|
||||
group = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
}
|
||||
|
||||
//获取分组的限流配置
|
||||
|
||||
161
model/ability.go
161
model/ability.go
@@ -5,9 +5,11 @@ import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type Ability struct {
|
||||
@@ -20,10 +22,25 @@ type Ability struct {
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
}
|
||||
|
||||
func GetGroupModels(group string) []string {
|
||||
type AbilityWithChannel struct {
|
||||
Ability
|
||||
ChannelType int `json:"channel_type"`
|
||||
}
|
||||
|
||||
func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
|
||||
var abilities []AbilityWithChannel
|
||||
err := DB.Table("abilities").
|
||||
Select("abilities.*, channels.type as channel_type").
|
||||
Joins("left join channels on abilities.channel_id = channels.id").
|
||||
Where("abilities.enabled = ?", true).
|
||||
Scan(&abilities).Error
|
||||
return abilities, err
|
||||
}
|
||||
|
||||
func GetGroupEnabledModels(group string) []string {
|
||||
var models []string
|
||||
// Find distinct models
|
||||
DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
||||
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
||||
return models
|
||||
}
|
||||
|
||||
@@ -41,16 +58,12 @@ func GetAllEnableAbilities() []Ability {
|
||||
}
|
||||
|
||||
func getPriority(group string, model string, retry int) (int, error) {
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
trueVal = "true"
|
||||
}
|
||||
|
||||
var priorities []int
|
||||
err := DB.Model(&Ability{}).
|
||||
Select("DISTINCT(priority)").
|
||||
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
|
||||
Order("priority DESC"). // 按优先级降序排序
|
||||
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
|
||||
Order("priority DESC"). // 按优先级降序排序
|
||||
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
||||
|
||||
if err != nil {
|
||||
@@ -74,30 +87,29 @@ func getPriority(group string, model string, retry int) (int, error) {
|
||||
return priorityToUse, nil
|
||||
}
|
||||
|
||||
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
trueVal = "true"
|
||||
}
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||
func getChannelQuery(group string, model string, retry int) (*gorm.DB, error) {
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
|
||||
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
|
||||
if retry != 0 {
|
||||
priority, err := getPriority(group, model, retry)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
||||
return nil, err
|
||||
} else {
|
||||
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
|
||||
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
|
||||
}
|
||||
}
|
||||
|
||||
return channelQuery
|
||||
return channelQuery, nil
|
||||
}
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
var abilities []Ability
|
||||
|
||||
var err error = nil
|
||||
channelQuery := getChannelQuery(group, model, retry)
|
||||
channelQuery, err := getChannelQuery(group, model, retry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||
err = channelQuery.Order("weight DESC").Find(&abilities).Error
|
||||
} else {
|
||||
@@ -124,7 +136,7 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("channel not found")
|
||||
return nil, nil
|
||||
}
|
||||
err = DB.First(&channel, "id = ?", channel.Id).Error
|
||||
return &channel, err
|
||||
@@ -133,9 +145,15 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
func (channel *Channel) AddAbilities() error {
|
||||
models_ := strings.Split(channel.Models, ",")
|
||||
groups_ := strings.Split(channel.Group, ",")
|
||||
abilitySet := make(map[string]struct{})
|
||||
abilities := make([]Ability, 0, len(models_))
|
||||
for _, model := range models_ {
|
||||
for _, group := range groups_ {
|
||||
key := group + "|" + model
|
||||
if _, exists := abilitySet[key]; exists {
|
||||
continue
|
||||
}
|
||||
abilitySet[key] = struct{}{}
|
||||
ability := Ability{
|
||||
Group: group,
|
||||
Model: model,
|
||||
@@ -152,7 +170,7 @@ func (channel *Channel) AddAbilities() error {
|
||||
return nil
|
||||
}
|
||||
for _, chunk := range lo.Chunk(abilities, 50) {
|
||||
err := DB.Create(&chunk).Error
|
||||
err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -194,9 +212,15 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
|
||||
// Then add new abilities
|
||||
models_ := strings.Split(channel.Models, ",")
|
||||
groups_ := strings.Split(channel.Group, ",")
|
||||
abilitySet := make(map[string]struct{})
|
||||
abilities := make([]Ability, 0, len(models_))
|
||||
for _, model := range models_ {
|
||||
for _, group := range groups_ {
|
||||
key := group + "|" + model
|
||||
if _, exists := abilitySet[key]; exists {
|
||||
continue
|
||||
}
|
||||
abilitySet[key] = struct{}{}
|
||||
ability := Ability{
|
||||
Group: group,
|
||||
Model: model,
|
||||
@@ -212,7 +236,7 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
|
||||
|
||||
if len(abilities) > 0 {
|
||||
for _, chunk := range lo.Chunk(abilities, 50) {
|
||||
err = tx.Create(&chunk).Error
|
||||
err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
|
||||
if err != nil {
|
||||
if isNewTx {
|
||||
tx.Rollback()
|
||||
@@ -252,74 +276,45 @@ func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uin
|
||||
return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
|
||||
}
|
||||
|
||||
func FixAbility() (int, error) {
|
||||
var channelIds []int
|
||||
count := 0
|
||||
// Find all channel ids from channel table
|
||||
err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
|
||||
var fixLock = sync.Mutex{}
|
||||
|
||||
func FixAbility() (int, int, error) {
|
||||
lock := fixLock.TryLock()
|
||||
if !lock {
|
||||
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
|
||||
}
|
||||
defer fixLock.Unlock()
|
||||
var channels []*Channel
|
||||
// Find all channels
|
||||
err := DB.Model(&Channel{}).Find(&channels).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
|
||||
return 0, err
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
|
||||
if len(channelIds) > 0 {
|
||||
// Process deletion in chunks to avoid "too many placeholders" error
|
||||
for _, chunk := range lo.Chunk(channelIds, 100) {
|
||||
err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If no channels exist, delete all abilities
|
||||
err = DB.Delete(&Ability{}).Error
|
||||
if len(channels) == 0 {
|
||||
return 0, 0, nil
|
||||
}
|
||||
successCount := 0
|
||||
failCount := 0
|
||||
for _, chunk := range lo.Chunk(channels, 50) {
|
||||
ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
|
||||
// Delete all abilities of this channel
|
||||
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
|
||||
return 0, err
|
||||
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||
failCount += len(chunk)
|
||||
continue
|
||||
}
|
||||
common.SysLog("Delete all abilities successfully")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
|
||||
count += len(channelIds)
|
||||
|
||||
// Use channelIds to find channel not in abilities table
|
||||
var abilityChannelIds []int
|
||||
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
|
||||
return count, err
|
||||
}
|
||||
|
||||
var channels []Channel
|
||||
if len(abilityChannelIds) == 0 {
|
||||
err = DB.Find(&channels).Error
|
||||
} else {
|
||||
// Process query in chunks to avoid "too many placeholders" error
|
||||
err = nil
|
||||
for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
|
||||
var channelsChunk []Channel
|
||||
err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
|
||||
// Then add new abilities
|
||||
for _, channel := range chunk {
|
||||
err = channel.AddAbilities()
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
|
||||
return count, err
|
||||
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
|
||||
failCount++
|
||||
} else {
|
||||
successCount++
|
||||
}
|
||||
channels = append(channels, channelsChunk...)
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
err := channel.UpdateAbilities(nil)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
|
||||
count++
|
||||
}
|
||||
}
|
||||
InitChannelCache()
|
||||
return count, nil
|
||||
return successCount, failCount, nil
|
||||
}
|
||||
|
||||
166
model/cache.go
166
model/cache.go
@@ -1,166 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var channelsIDM map[int]*Channel
|
||||
var channelSyncLock sync.RWMutex
|
||||
|
||||
func InitChannelCache() {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return
|
||||
}
|
||||
newChannelId2channel := make(map[int]*Channel)
|
||||
var channels []*Channel
|
||||
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
|
||||
for _, channel := range channels {
|
||||
newChannelId2channel[channel.Id] = channel
|
||||
}
|
||||
var abilities []*Ability
|
||||
DB.Find(&abilities)
|
||||
groups := make(map[string]bool)
|
||||
for _, ability := range abilities {
|
||||
groups[ability.Group] = true
|
||||
}
|
||||
newGroup2model2channels := make(map[string]map[string][]*Channel)
|
||||
newChannelsIDM := make(map[int]*Channel)
|
||||
for group := range groups {
|
||||
newGroup2model2channels[group] = make(map[string][]*Channel)
|
||||
}
|
||||
for _, channel := range channels {
|
||||
newChannelsIDM[channel.Id] = channel
|
||||
groups := strings.Split(channel.Group, ",")
|
||||
for _, group := range groups {
|
||||
models := strings.Split(channel.Models, ",")
|
||||
for _, model := range models {
|
||||
if _, ok := newGroup2model2channels[group][model]; !ok {
|
||||
newGroup2model2channels[group][model] = make([]*Channel, 0)
|
||||
}
|
||||
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sort by priority
|
||||
for group, model2channels := range newGroup2model2channels {
|
||||
for model, channels := range model2channels {
|
||||
sort.Slice(channels, func(i, j int) bool {
|
||||
return channels[i].GetPriority() > channels[j].GetPriority()
|
||||
})
|
||||
newGroup2model2channels[group][model] = channels
|
||||
}
|
||||
}
|
||||
|
||||
channelSyncLock.Lock()
|
||||
group2model2channels = newGroup2model2channels
|
||||
channelsIDM = newChannelsIDM
|
||||
channelSyncLock.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
}
|
||||
|
||||
func SyncChannelCache(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Second)
|
||||
common.SysLog("syncing channels from database")
|
||||
InitChannelCache()
|
||||
}
|
||||
}
|
||||
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
||||
model = "gpt-4-gizmo-*"
|
||||
}
|
||||
if strings.HasPrefix(model, "gpt-4o-gizmo") {
|
||||
model = "gpt-4o-gizmo-*"
|
||||
}
|
||||
|
||||
// if memory cache is disabled, get channel directly from database
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model, retry)
|
||||
}
|
||||
|
||||
channelSyncLock.RLock()
|
||||
channels := group2model2channels[group][model]
|
||||
channelSyncLock.RUnlock()
|
||||
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
|
||||
uniquePriorities := make(map[int]bool)
|
||||
for _, channel := range channels {
|
||||
uniquePriorities[int(channel.GetPriority())] = true
|
||||
}
|
||||
var sortedUniquePriorities []int
|
||||
for priority := range uniquePriorities {
|
||||
sortedUniquePriorities = append(sortedUniquePriorities, priority)
|
||||
}
|
||||
sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
|
||||
|
||||
if retry >= len(uniquePriorities) {
|
||||
retry = len(uniquePriorities) - 1
|
||||
}
|
||||
targetPriority := int64(sortedUniquePriorities[retry])
|
||||
|
||||
// get the priority for the given retry number
|
||||
var targetChannels []*Channel
|
||||
for _, channel := range channels {
|
||||
if channel.GetPriority() == targetPriority {
|
||||
targetChannels = append(targetChannels, channel)
|
||||
}
|
||||
}
|
||||
|
||||
// 平滑系数
|
||||
smoothingFactor := 10
|
||||
// Calculate the total weight of all channels up to endIdx
|
||||
totalWeight := 0
|
||||
for _, channel := range targetChannels {
|
||||
totalWeight += channel.GetWeight() + smoothingFactor
|
||||
}
|
||||
// Generate a random value in the range [0, totalWeight)
|
||||
randomWeight := rand.Intn(totalWeight)
|
||||
|
||||
// Find a channel based on its weight
|
||||
for _, channel := range targetChannels {
|
||||
randomWeight -= channel.GetWeight() + smoothingFactor
|
||||
if randomWeight < 0 {
|
||||
return channel, nil
|
||||
}
|
||||
}
|
||||
// return null if no channel is not found
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
|
||||
func CacheGetChannel(id int) (*Channel, error) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetChannelById(id, true)
|
||||
}
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
|
||||
c, ok := channelsIDM[id]
|
||||
if !ok {
|
||||
return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func CacheUpdateChannelStatus(id int, status int) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return
|
||||
}
|
||||
channelSyncLock.Lock()
|
||||
defer channelSyncLock.Unlock()
|
||||
if channel, ok := channelsIDM[id]; ok {
|
||||
channel.Status = status
|
||||
}
|
||||
}
|
||||
421
model/channel.go
421
model/channel.go
@@ -1,8 +1,15 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -35,8 +42,138 @@ type Channel struct {
|
||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||
OtherInfo string `json:"other_info"`
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
Setting *string `json:"setting" gorm:"type:text"`
|
||||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||
// add after v0.8.5
|
||||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||
}
|
||||
|
||||
type ChannelInfo struct {
|
||||
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||||
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
|
||||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||||
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer interface
|
||||
func (c ChannelInfo) Value() (driver.Value, error) {
|
||||
return common.Marshal(&c)
|
||||
}
|
||||
|
||||
// Scan implements sql.Scanner interface
|
||||
func (c *ChannelInfo) Scan(value interface{}) error {
|
||||
bytesValue, _ := value.([]byte)
|
||||
return common.Unmarshal(bytesValue, c)
|
||||
}
|
||||
|
||||
func (channel *Channel) getKeys() []string {
|
||||
if channel.Key == "" {
|
||||
return []string{}
|
||||
}
|
||||
trimmed := strings.TrimSpace(channel.Key)
|
||||
// If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
|
||||
if strings.HasPrefix(trimmed, "[") {
|
||||
var arr []json.RawMessage
|
||||
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
res := make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
res[i] = string(v)
|
||||
}
|
||||
return res
|
||||
}
|
||||
}
|
||||
// Otherwise, fall back to splitting by newline
|
||||
keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
|
||||
return keys
|
||||
}
|
||||
|
||||
func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
|
||||
// If not in multi-key mode, return the original key string directly.
|
||||
if !channel.ChannelInfo.IsMultiKey {
|
||||
return channel.Key, 0, nil
|
||||
}
|
||||
|
||||
// Obtain all keys (split by \n)
|
||||
keys := channel.getKeys()
|
||||
if len(keys) == 0 {
|
||||
// No keys available, return error, should disable the channel
|
||||
return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
|
||||
}
|
||||
|
||||
statusList := channel.ChannelInfo.MultiKeyStatusList
|
||||
// helper to get key status, default to enabled when missing
|
||||
getStatus := func(idx int) int {
|
||||
if statusList == nil {
|
||||
return common.ChannelStatusEnabled
|
||||
}
|
||||
if status, ok := statusList[idx]; ok {
|
||||
return status
|
||||
}
|
||||
return common.ChannelStatusEnabled
|
||||
}
|
||||
|
||||
// Collect indexes of enabled keys
|
||||
enabledIdx := make([]int, 0, len(keys))
|
||||
for i := range keys {
|
||||
if getStatus(i) == common.ChannelStatusEnabled {
|
||||
enabledIdx = append(enabledIdx, i)
|
||||
}
|
||||
}
|
||||
// If no specific status list or none enabled, fall back to first key
|
||||
if len(enabledIdx) == 0 {
|
||||
return keys[0], 0, nil
|
||||
}
|
||||
|
||||
switch channel.ChannelInfo.MultiKeyMode {
|
||||
case constant.MultiKeyModeRandom:
|
||||
// Randomly pick one enabled key
|
||||
selectedIdx := enabledIdx[rand.Intn(len(enabledIdx))]
|
||||
return keys[selectedIdx], selectedIdx, nil
|
||||
case constant.MultiKeyModePolling:
|
||||
// Use channel-specific lock to ensure thread-safe polling
|
||||
lock := getChannelPollingLock(channel.Id)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
channelInfo, err := CacheGetChannelInfo(channel.Id)
|
||||
if err != nil {
|
||||
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
|
||||
defer func() {
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex))
|
||||
}
|
||||
if !common.MemoryCacheEnabled {
|
||||
_ = channel.SaveChannelInfo()
|
||||
} else {
|
||||
// CacheUpdateChannel(channel)
|
||||
}
|
||||
}()
|
||||
// Start from the saved polling index and look for the next enabled key
|
||||
start := channelInfo.MultiKeyPollingIndex
|
||||
if start < 0 || start >= len(keys) {
|
||||
start = 0
|
||||
}
|
||||
for i := 0; i < len(keys); i++ {
|
||||
idx := (start + i) % len(keys)
|
||||
if getStatus(idx) == common.ChannelStatusEnabled {
|
||||
// update polling index for next call (point to the next position)
|
||||
channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
|
||||
return keys[idx], idx, nil
|
||||
}
|
||||
}
|
||||
// Fallback – should not happen, but return first enabled key
|
||||
return keys[enabledIdx[0]], enabledIdx[0], nil
|
||||
default:
|
||||
// Unknown mode, default to first enabled key (or original key string)
|
||||
return keys[enabledIdx[0]], enabledIdx[0], nil
|
||||
}
|
||||
}
|
||||
|
||||
func (channel *Channel) SaveChannelInfo() error {
|
||||
return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error
|
||||
}
|
||||
|
||||
func (channel *Channel) GetModels() []string {
|
||||
@@ -60,7 +197,7 @@ func (channel *Channel) GetGroups() []string {
|
||||
func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
||||
otherInfo := make(map[string]interface{})
|
||||
if channel.OtherInfo != "" {
|
||||
err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||||
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal other info: " + err.Error())
|
||||
}
|
||||
@@ -145,7 +282,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
|
||||
}
|
||||
|
||||
// 构造基础查询
|
||||
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
|
||||
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||||
|
||||
// 构造WHERE子句
|
||||
var whereClause string
|
||||
@@ -153,15 +290,15 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
|
||||
if group != "" && group != "null" {
|
||||
var groupCondition string
|
||||
if common.UsingMySQL {
|
||||
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
|
||||
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
|
||||
} else {
|
||||
// sqlite, PostgreSQL
|
||||
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
|
||||
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
|
||||
}
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
||||
} else {
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
||||
}
|
||||
|
||||
@@ -174,14 +311,20 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
|
||||
}
|
||||
|
||||
func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||||
channel := Channel{Id: id}
|
||||
channel := &Channel{Id: id}
|
||||
var err error = nil
|
||||
if selectAll {
|
||||
err = DB.First(&channel, "id = ?", id).Error
|
||||
err = DB.First(channel, "id = ?", id).Error
|
||||
} else {
|
||||
err = DB.Omit("key").First(&channel, "id = ?", id).Error
|
||||
err = DB.Omit("key").First(channel, "id = ?", id).Error
|
||||
}
|
||||
return &channel, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
func BatchInsertChannels(channels []Channel) error {
|
||||
@@ -265,6 +408,44 @@ func (channel *Channel) Insert() error {
|
||||
}
|
||||
|
||||
func (channel *Channel) Update() error {
|
||||
// If this is a multi-key channel, recalculate MultiKeySize based on the current key list to avoid inconsistency after editing keys
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
var keyStr string
|
||||
if channel.Key != "" {
|
||||
keyStr = channel.Key
|
||||
} else {
|
||||
// If key is not provided, read the existing key from the database
|
||||
if existing, err := GetChannelById(channel.Id, true); err == nil {
|
||||
keyStr = existing.Key
|
||||
}
|
||||
}
|
||||
// Parse the key list (supports newline separation or JSON array)
|
||||
keys := []string{}
|
||||
if keyStr != "" {
|
||||
trimmed := strings.TrimSpace(keyStr)
|
||||
if strings.HasPrefix(trimmed, "[") {
|
||||
var arr []json.RawMessage
|
||||
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||||
keys = make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
keys[i] = string(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(keys) == 0 { // fallback to newline split
|
||||
keys = strings.Split(strings.Trim(keyStr, "\n"), "\n")
|
||||
}
|
||||
}
|
||||
channel.ChannelInfo.MultiKeySize = len(keys)
|
||||
// Clean up status data that exceeds the new key count to prevent index out of range
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
for idx := range channel.ChannelInfo.MultiKeyStatusList {
|
||||
if idx >= channel.ChannelInfo.MultiKeySize {
|
||||
delete(channel.ChannelInfo.MultiKeyStatusList, idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
var err error
|
||||
err = DB.Model(channel).Updates(channel).Error
|
||||
if err != nil {
|
||||
@@ -307,48 +488,125 @@ func (channel *Channel) Delete() error {
|
||||
|
||||
var channelStatusLock sync.Mutex
|
||||
|
||||
func UpdateChannelStatusById(id int, status int, reason string) bool {
|
||||
// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
|
||||
var channelPollingLocks sync.Map
|
||||
|
||||
// getChannelPollingLock returns or creates a mutex for the given channel ID
|
||||
func getChannelPollingLock(channelId int) *sync.Mutex {
|
||||
if lock, exists := channelPollingLocks.Load(channelId); exists {
|
||||
return lock.(*sync.Mutex)
|
||||
}
|
||||
// Create new lock for this channel
|
||||
newLock := &sync.Mutex{}
|
||||
actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock)
|
||||
return actual.(*sync.Mutex)
|
||||
}
|
||||
|
||||
// CleanupChannelPollingLocks removes locks for channels that no longer exist
|
||||
// This is optional and can be called periodically to prevent memory leaks
|
||||
func CleanupChannelPollingLocks() {
|
||||
var activeChannelIds []int
|
||||
DB.Model(&Channel{}).Pluck("id", &activeChannelIds)
|
||||
|
||||
activeChannelSet := make(map[int]bool)
|
||||
for _, id := range activeChannelIds {
|
||||
activeChannelSet[id] = true
|
||||
}
|
||||
|
||||
channelPollingLocks.Range(func(key, value interface{}) bool {
|
||||
channelId := key.(int)
|
||||
if !activeChannelSet[channelId] {
|
||||
channelPollingLocks.Delete(channelId)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
|
||||
keys := channel.getKeys()
|
||||
if len(keys) == 0 {
|
||||
channel.Status = status
|
||||
} else {
|
||||
var keyIndex int
|
||||
for i, key := range keys {
|
||||
if key == usingKey {
|
||||
keyIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||
}
|
||||
if status == common.ChannelStatusEnabled {
|
||||
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
||||
} else {
|
||||
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
|
||||
}
|
||||
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
|
||||
channel.Status = common.ChannelStatusAutoDisabled
|
||||
info := channel.GetOtherInfo()
|
||||
info["status_reason"] = "All keys are disabled"
|
||||
info["status_time"] = common.GetTimestamp()
|
||||
channel.SetOtherInfo(info)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
|
||||
println("UpdateChannelStatus called with channelId:", channelId, "usingKey:", usingKey, "status:", status, "reason:", reason)
|
||||
if common.MemoryCacheEnabled {
|
||||
channelStatusLock.Lock()
|
||||
defer channelStatusLock.Unlock()
|
||||
|
||||
channelCache, _ := CacheGetChannel(id)
|
||||
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
||||
if channelCache != nil && channelCache.Status == status {
|
||||
channelCache, _ := CacheGetChannel(channelId)
|
||||
if channelCache == nil {
|
||||
return false
|
||||
}
|
||||
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
|
||||
if channelCache == nil && status != common.ChannelStatusEnabled {
|
||||
return false
|
||||
if channelCache.ChannelInfo.IsMultiKey {
|
||||
// 如果是多Key模式,更新缓存中的状态
|
||||
handlerMultiKeyUpdate(channelCache, usingKey, status)
|
||||
//CacheUpdateChannel(channelCache)
|
||||
//return true
|
||||
} else {
|
||||
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
||||
if channelCache.Status == status {
|
||||
return false
|
||||
}
|
||||
CacheUpdateChannelStatus(channelId, status)
|
||||
}
|
||||
CacheUpdateChannelStatus(id, status)
|
||||
}
|
||||
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
|
||||
|
||||
shouldUpdateAbilities := false
|
||||
defer func() {
|
||||
if shouldUpdateAbilities {
|
||||
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
|
||||
if err != nil {
|
||||
common.SysError("failed to update ability status: " + err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
channel, err := GetChannelById(channelId, true)
|
||||
if err != nil {
|
||||
common.SysError("failed to update ability status: " + err.Error())
|
||||
return false
|
||||
}
|
||||
channel, err := GetChannelById(id, true)
|
||||
if err != nil {
|
||||
// find channel by id error, directly update status
|
||||
result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status)
|
||||
if result.Error != nil {
|
||||
common.SysError("failed to update channel status: " + result.Error.Error())
|
||||
return false
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if channel.Status == status {
|
||||
return false
|
||||
}
|
||||
// find channel by id success, update status and other info
|
||||
info := channel.GetOtherInfo()
|
||||
info["status_reason"] = reason
|
||||
info["status_time"] = common.GetTimestamp()
|
||||
channel.SetOtherInfo(info)
|
||||
channel.Status = status
|
||||
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
beforeStatus := channel.Status
|
||||
handlerMultiKeyUpdate(channel, usingKey, status)
|
||||
if beforeStatus != channel.Status {
|
||||
shouldUpdateAbilities = true
|
||||
}
|
||||
} else {
|
||||
info := channel.GetOtherInfo()
|
||||
info["status_reason"] = reason
|
||||
info["status_time"] = common.GetTimestamp()
|
||||
channel.SetOtherInfo(info)
|
||||
channel.Status = status
|
||||
shouldUpdateAbilities = true
|
||||
}
|
||||
err = channel.Save()
|
||||
if err != nil {
|
||||
common.SysError("failed to update channel status: " + err.Error())
|
||||
@@ -478,7 +736,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
}
|
||||
|
||||
// 构造基础查询
|
||||
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
|
||||
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||||
|
||||
// 构造WHERE子句
|
||||
var whereClause string
|
||||
@@ -486,15 +744,15 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
if group != "" && group != "null" {
|
||||
var groupCondition string
|
||||
if common.UsingMySQL {
|
||||
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
|
||||
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
|
||||
} else {
|
||||
// sqlite, PostgreSQL
|
||||
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
|
||||
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
|
||||
}
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
||||
} else {
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
||||
}
|
||||
|
||||
@@ -514,19 +772,32 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
return tags, nil
|
||||
}
|
||||
|
||||
func (channel *Channel) GetSetting() map[string]interface{} {
|
||||
setting := make(map[string]interface{})
|
||||
func (channel *Channel) ValidateSettings() error {
|
||||
channelParams := &dto.ChannelSettings{}
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
err := common.Unmarshal([]byte(*channel.Setting), channelParams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||||
setting := dto.ChannelSettings{}
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := common.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
channel.Setting = nil // 清空设置以避免后续错误
|
||||
_ = channel.Save() // 保存修改
|
||||
}
|
||||
}
|
||||
return setting
|
||||
}
|
||||
|
||||
func (channel *Channel) SetSetting(setting map[string]interface{}) {
|
||||
settingBytes, err := json.Marshal(setting)
|
||||
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||||
settingBytes, err := common.Marshal(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
return
|
||||
@@ -537,7 +808,7 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) {
|
||||
func (channel *Channel) GetParamOverride() map[string]interface{} {
|
||||
paramOverride := make(map[string]interface{})
|
||||
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
|
||||
err := json.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||||
err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal param override: " + err.Error())
|
||||
}
|
||||
@@ -583,3 +854,53 @@ func BatchSetChannelTag(ids []int, tag *string) error {
|
||||
// 提交事务
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// CountAllChannels returns total channels in DB
|
||||
func CountAllChannels() (int64, error) {
|
||||
var total int64
|
||||
err := DB.Model(&Channel{}).Count(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// CountAllTags returns number of non-empty distinct tags
|
||||
func CountAllTags() (int64, error) {
|
||||
var total int64
|
||||
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// Get channels of specified type with pagination
|
||||
func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||||
return channels, err
|
||||
}
|
||||
|
||||
// Count channels of specific type
|
||||
func CountChannelsByType(channelType int) (int64, error) {
|
||||
var count int64
|
||||
err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Return map[type]count for all channels
|
||||
func CountChannelsGroupByType() (map[int64]int64, error) {
|
||||
type result struct {
|
||||
Type int64 `gorm:"column:type"`
|
||||
Count int64 `gorm:"column:count"`
|
||||
}
|
||||
var results []result
|
||||
err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts := make(map[int64]int64)
|
||||
for _, r := range results {
|
||||
counts[r.Type] = r.Count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
259
model/channel_cache.go
Normal file
259
model/channel_cache.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/setting"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var group2model2channels map[string]map[string][]int // enabled channel
|
||||
var channelsIDM map[int]*Channel // all channels include disabled
|
||||
var channelSyncLock sync.RWMutex
|
||||
|
||||
func InitChannelCache() {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return
|
||||
}
|
||||
newChannelId2channel := make(map[int]*Channel)
|
||||
var channels []*Channel
|
||||
DB.Find(&channels)
|
||||
for _, channel := range channels {
|
||||
newChannelId2channel[channel.Id] = channel
|
||||
}
|
||||
var abilities []*Ability
|
||||
DB.Find(&abilities)
|
||||
groups := make(map[string]bool)
|
||||
for _, ability := range abilities {
|
||||
groups[ability.Group] = true
|
||||
}
|
||||
newGroup2model2channels := make(map[string]map[string][]int)
|
||||
for group := range groups {
|
||||
newGroup2model2channels[group] = make(map[string][]int)
|
||||
}
|
||||
for _, channel := range channels {
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
continue // skip disabled channels
|
||||
}
|
||||
groups := strings.Split(channel.Group, ",")
|
||||
for _, group := range groups {
|
||||
models := strings.Split(channel.Models, ",")
|
||||
for _, model := range models {
|
||||
if _, ok := newGroup2model2channels[group][model]; !ok {
|
||||
newGroup2model2channels[group][model] = make([]int, 0)
|
||||
}
|
||||
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sort by priority
|
||||
for group, model2channels := range newGroup2model2channels {
|
||||
for model, channels := range model2channels {
|
||||
sort.Slice(channels, func(i, j int) bool {
|
||||
return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority()
|
||||
})
|
||||
newGroup2model2channels[group][model] = channels
|
||||
}
|
||||
}
|
||||
|
||||
channelSyncLock.Lock()
|
||||
group2model2channels = newGroup2model2channels
|
||||
channelsIDM = newChannelId2channel
|
||||
channelSyncLock.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
}
|
||||
|
||||
func SyncChannelCache(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Second)
|
||||
common.SysLog("syncing channels from database")
|
||||
InitChannelCache()
|
||||
}
|
||||
}
|
||||
|
||||
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
|
||||
var channel *Channel
|
||||
var err error
|
||||
selectGroup := group
|
||||
if group == "auto" {
|
||||
if len(setting.AutoGroups) == 0 {
|
||||
return nil, selectGroup, errors.New("auto groups is not enabled")
|
||||
}
|
||||
for _, autoGroup := range setting.AutoGroups {
|
||||
if common.DebugEnabled {
|
||||
println("autoGroup:", autoGroup)
|
||||
}
|
||||
channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
|
||||
if channel == nil {
|
||||
continue
|
||||
} else {
|
||||
c.Set("auto_group", autoGroup)
|
||||
selectGroup = autoGroup
|
||||
if common.DebugEnabled {
|
||||
println("selectGroup:", selectGroup)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
channel, err = getRandomSatisfiedChannel(group, model, retry)
|
||||
if err != nil {
|
||||
return nil, group, err
|
||||
}
|
||||
}
|
||||
return channel, selectGroup, nil
|
||||
}
|
||||
|
||||
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
||||
model = "gpt-4-gizmo-*"
|
||||
}
|
||||
if strings.HasPrefix(model, "gpt-4o-gizmo") {
|
||||
model = "gpt-4o-gizmo-*"
|
||||
}
|
||||
|
||||
// if memory cache is disabled, get channel directly from database
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model, retry)
|
||||
}
|
||||
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
channels := group2model2channels[group][model]
|
||||
|
||||
if len(channels) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if len(channels) == 1 {
|
||||
if channel, ok := channelsIDM[channels[0]]; ok {
|
||||
return channel, nil
|
||||
}
|
||||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0])
|
||||
}
|
||||
|
||||
uniquePriorities := make(map[int]bool)
|
||||
for _, channelId := range channels {
|
||||
if channel, ok := channelsIDM[channelId]; ok {
|
||||
uniquePriorities[int(channel.GetPriority())] = true
|
||||
} else {
|
||||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
|
||||
}
|
||||
}
|
||||
var sortedUniquePriorities []int
|
||||
for priority := range uniquePriorities {
|
||||
sortedUniquePriorities = append(sortedUniquePriorities, priority)
|
||||
}
|
||||
sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
|
||||
|
||||
if retry >= len(uniquePriorities) {
|
||||
retry = len(uniquePriorities) - 1
|
||||
}
|
||||
targetPriority := int64(sortedUniquePriorities[retry])
|
||||
|
||||
// get the priority for the given retry number
|
||||
var targetChannels []*Channel
|
||||
for _, channelId := range channels {
|
||||
if channel, ok := channelsIDM[channelId]; ok {
|
||||
if channel.GetPriority() == targetPriority {
|
||||
targetChannels = append(targetChannels, channel)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
|
||||
}
|
||||
}
|
||||
|
||||
// 平滑系数
|
||||
smoothingFactor := 10
|
||||
// Calculate the total weight of all channels up to endIdx
|
||||
totalWeight := 0
|
||||
for _, channel := range targetChannels {
|
||||
totalWeight += channel.GetWeight() + smoothingFactor
|
||||
}
|
||||
// Generate a random value in the range [0, totalWeight)
|
||||
randomWeight := rand.Intn(totalWeight)
|
||||
|
||||
// Find a channel based on its weight
|
||||
for _, channel := range targetChannels {
|
||||
randomWeight -= channel.GetWeight() + smoothingFactor
|
||||
if randomWeight < 0 {
|
||||
return channel, nil
|
||||
}
|
||||
}
|
||||
// return null if no channel is not found
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
|
||||
func CacheGetChannel(id int) (*Channel, error) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetChannelById(id, true)
|
||||
}
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
|
||||
c, ok := channelsIDM[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||
}
|
||||
if c.Status != common.ChannelStatusEnabled {
|
||||
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
channel, err := GetChannelById(id, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &channel.ChannelInfo, nil
|
||||
}
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
|
||||
c, ok := channelsIDM[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||
}
|
||||
if c.Status != common.ChannelStatusEnabled {
|
||||
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
||||
}
|
||||
return &c.ChannelInfo, nil
|
||||
}
|
||||
|
||||
func CacheUpdateChannelStatus(id int, status int) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return
|
||||
}
|
||||
channelSyncLock.Lock()
|
||||
defer channelSyncLock.Unlock()
|
||||
if channel, ok := channelsIDM[id]; ok {
|
||||
channel.Status = status
|
||||
}
|
||||
}
|
||||
|
||||
func CacheUpdateChannel(channel *Channel) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return
|
||||
}
|
||||
channelSyncLock.Lock()
|
||||
defer channelSyncLock.Unlock()
|
||||
if channel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
|
||||
|
||||
println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
|
||||
channelsIDM[channel.Id] = channel
|
||||
println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
|
||||
}
|
||||
102
model/log.go
102
model/log.go
@@ -27,11 +27,12 @@ type Log struct {
|
||||
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
||||
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
||||
UseTime int `json:"use_time" gorm:"default:0"`
|
||||
IsStream bool `json:"is_stream" gorm:"default:false"`
|
||||
IsStream bool `json:"is_stream"`
|
||||
ChannelId int `json:"channel" gorm:"index"`
|
||||
ChannelName string `json:"channel_name" gorm:"->"`
|
||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||
Group string `json:"group" gorm:"index"`
|
||||
Ip string `json:"ip" gorm:"index;default:''"`
|
||||
Other string `json:"other"`
|
||||
}
|
||||
|
||||
@@ -48,7 +49,7 @@ func formatUserLogs(logs []*Log) {
|
||||
for i := range logs {
|
||||
logs[i].ChannelName = ""
|
||||
var otherMap map[string]interface{}
|
||||
otherMap = common.StrToMap(logs[i].Other)
|
||||
otherMap, _ = common.StrToMap(logs[i].Other)
|
||||
if otherMap != nil {
|
||||
// delete admin
|
||||
delete(otherMap, "admin_info")
|
||||
@@ -61,7 +62,7 @@ func formatUserLogs(logs []*Log) {
|
||||
func GetLogByKey(key string) (logs []*Log, err error) {
|
||||
if os.Getenv("LOG_SQL_DSN") != "" {
|
||||
var tk Token
|
||||
if err = DB.Model(&Token{}).Where(keyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
|
||||
if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
|
||||
@@ -95,6 +96,13 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
||||
common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
|
||||
username := c.GetString("username")
|
||||
otherStr := common.MapToJsonStr(other)
|
||||
// 判断是否需要记录 IP
|
||||
needRecordIp := false
|
||||
if settingMap, err := GetUserSetting(userId, false); err == nil {
|
||||
if settingMap.RecordIpLog {
|
||||
needRecordIp = true
|
||||
}
|
||||
}
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
Username: username,
|
||||
@@ -111,7 +119,13 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
||||
UseTime: useTimeSeconds,
|
||||
IsStream: isStream,
|
||||
Group: group,
|
||||
Other: otherStr,
|
||||
Ip: func() string {
|
||||
if needRecordIp {
|
||||
return c.ClientIP()
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
Other: otherStr,
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
if err != nil {
|
||||
@@ -119,32 +133,59 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
||||
}
|
||||
}
|
||||
|
||||
func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
|
||||
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
|
||||
isStream bool, group string, other map[string]interface{}) {
|
||||
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||
type RecordConsumeLogParams struct {
|
||||
ChannelId int `json:"channel_id"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
ModelName string `json:"model_name"`
|
||||
TokenName string `json:"token_name"`
|
||||
Quota int `json:"quota"`
|
||||
Content string `json:"content"`
|
||||
TokenId int `json:"token_id"`
|
||||
UserQuota int `json:"user_quota"`
|
||||
UseTimeSeconds int `json:"use_time_seconds"`
|
||||
IsStream bool `json:"is_stream"`
|
||||
Group string `json:"group"`
|
||||
Other map[string]interface{} `json:"other"`
|
||||
}
|
||||
|
||||
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
|
||||
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
|
||||
if !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
username := c.GetString("username")
|
||||
otherStr := common.MapToJsonStr(other)
|
||||
otherStr := common.MapToJsonStr(params.Other)
|
||||
// 判断是否需要记录 IP
|
||||
needRecordIp := false
|
||||
if settingMap, err := GetUserSetting(userId, false); err == nil {
|
||||
if settingMap.RecordIpLog {
|
||||
needRecordIp = true
|
||||
}
|
||||
}
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
Username: username,
|
||||
CreatedAt: common.GetTimestamp(),
|
||||
Type: LogTypeConsume,
|
||||
Content: content,
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TokenName: tokenName,
|
||||
ModelName: modelName,
|
||||
Quota: quota,
|
||||
ChannelId: channelId,
|
||||
TokenId: tokenId,
|
||||
UseTime: useTimeSeconds,
|
||||
IsStream: isStream,
|
||||
Group: group,
|
||||
Other: otherStr,
|
||||
Content: params.Content,
|
||||
PromptTokens: params.PromptTokens,
|
||||
CompletionTokens: params.CompletionTokens,
|
||||
TokenName: params.TokenName,
|
||||
ModelName: params.ModelName,
|
||||
Quota: params.Quota,
|
||||
ChannelId: params.ChannelId,
|
||||
TokenId: params.TokenId,
|
||||
UseTime: params.UseTimeSeconds,
|
||||
IsStream: params.IsStream,
|
||||
Group: params.Group,
|
||||
Ip: func() string {
|
||||
if needRecordIp {
|
||||
return c.ClientIP()
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
Other: otherStr,
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
if err != nil {
|
||||
@@ -152,7 +193,7 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
|
||||
}
|
||||
if common.DataExportEnabled {
|
||||
gopool.Go(func() {
|
||||
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
|
||||
LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), params.PromptTokens+params.CompletionTokens)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -184,7 +225,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
tx = tx.Where("logs.channel_id = ?", channel)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where("logs."+groupCol+" = ?", group)
|
||||
tx = tx.Where("logs."+logGroupCol+" = ?", group)
|
||||
}
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
if err != nil {
|
||||
@@ -195,13 +236,18 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
channelIds := make([]int, 0)
|
||||
channelIdsMap := make(map[int]struct{})
|
||||
channelMap := make(map[int]string)
|
||||
for _, log := range logs {
|
||||
if log.ChannelId != 0 {
|
||||
channelIds = append(channelIds, log.ChannelId)
|
||||
channelIdsMap[log.ChannelId] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
channelIds := make([]int, 0, len(channelIdsMap))
|
||||
for channelId := range channelIdsMap {
|
||||
channelIds = append(channelIds, channelId)
|
||||
}
|
||||
if len(channelIds) > 0 {
|
||||
var channels []struct {
|
||||
Id int `gorm:"column:id"`
|
||||
@@ -242,7 +288,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
||||
tx = tx.Where("logs.created_at <= ?", endTimestamp)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where("logs."+groupCol+" = ?", group)
|
||||
tx = tx.Where("logs."+logGroupCol+" = ?", group)
|
||||
}
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
if err != nil {
|
||||
@@ -303,8 +349,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where(groupCol+" = ?", group)
|
||||
rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group)
|
||||
tx = tx.Where(logGroupCol+" = ?", group)
|
||||
rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group)
|
||||
}
|
||||
|
||||
tx = tx.Where("type = ?", LogTypeConsume)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user