mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-17 19:17:27 +00:00
Compare commits
886 Commits
refactor_e
...
refactor/s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e994abdd9 | ||
|
|
26a18346b2 | ||
|
|
99fcc354e3 | ||
|
|
456987a3d4 | ||
|
|
347c31f93c | ||
|
|
836ae7affe | ||
|
|
dd46322421 | ||
|
|
71f5dc987a | ||
|
|
6992fd2b66 | ||
|
|
92895ebe5a | ||
|
|
c0fb3bf95f | ||
|
|
abe31f216f | ||
|
|
44bc65691e | ||
|
|
7c27558de9 | ||
|
|
51ef19a3fb | ||
|
|
8e7301b79a | ||
|
|
ec98a21933 | ||
|
|
1dd59f5d08 | ||
|
|
ea084e775e | ||
|
|
41be436c04 | ||
|
|
b73b16e102 | ||
|
|
8f9960bcc7 | ||
|
|
3c70617060 | ||
|
|
3a98ae3f70 | ||
|
|
1894ddc786 | ||
|
|
f23be16e98 | ||
|
|
b882dfa8f6 | ||
|
|
d491cbd3d2 | ||
|
|
334ba555fc | ||
|
|
ba632d0b4d | ||
|
|
b5d3e87ea2 | ||
|
|
f22ea6e0a8 | ||
|
|
9f1ab16aa5 | ||
|
|
0dd475d2ff | ||
|
|
dd374cdd9b | ||
|
|
daf3ef9848 | ||
|
|
23ee0fc3b4 | ||
|
|
08638b18ce | ||
|
|
d331f0fb2a | ||
|
|
4b98fceb6e | ||
|
|
ef63416098 | ||
|
|
50a432180d | ||
|
|
2ea7634549 | ||
|
|
10da082412 | ||
|
|
31c8ead1d4 | ||
|
|
00f4594062 | ||
|
|
467e584359 | ||
|
|
f635fc3ae6 | ||
|
|
168ebb1cd4 | ||
|
|
b7bc609a7a | ||
|
|
046c8b27b6 | ||
|
|
4be61d00e4 | ||
|
|
4ac7d94026 | ||
|
|
9af71caf73 | ||
|
|
91e57a4c69 | ||
|
|
45a6a779e5 | ||
|
|
49c7a0dee5 | ||
|
|
956244c742 | ||
|
|
752dc11dd4 | ||
|
|
17be7c3b45 | ||
|
|
11cf70e60d | ||
|
|
dfa27f3412 | ||
|
|
e34b5def60 | ||
|
|
63f94e7669 | ||
|
|
18a385f817 | ||
|
|
8e95d338b5 | ||
|
|
f236785ed5 | ||
|
|
f3e220b196 | ||
|
|
33bf267ce8 | ||
|
|
05c2dde38f | ||
|
|
0ee5670be6 | ||
|
|
9790e2c4f6 | ||
|
|
4f760a8d40 | ||
|
|
8563eafc57 | ||
|
|
72d5b35d3f | ||
|
|
7d71f467d9 | ||
|
|
aea732ab92 | ||
|
|
da6f24a3d4 | ||
|
|
28ed42130c | ||
|
|
96215c9fd5 | ||
|
|
6628fd9181 | ||
|
|
a3b8a1998a | ||
|
|
6a34d365ec | ||
|
|
406a3e4dca | ||
|
|
c1d7ecdeec | ||
|
|
6451158680 | ||
|
|
0bd4b34046 | ||
|
|
f14b06ec3a | ||
|
|
6ed775be8f | ||
|
|
b712279b2a | ||
|
|
1bffe3081d | ||
|
|
cfebe80822 | ||
|
|
17e697af8f | ||
|
|
01b35bb667 | ||
|
|
d8410d2f11 | ||
|
|
e68eed3d40 | ||
|
|
04cc668430 | ||
|
|
5d76e16324 | ||
|
|
b6c547ae98 | ||
|
|
93adcd57d7 | ||
|
|
e813da59cc | ||
|
|
b25ac0bfb6 | ||
|
|
70c27bc662 | ||
|
|
db6a788e0d | ||
|
|
e3bc40f11b | ||
|
|
3e9be07db4 | ||
|
|
684caa3673 | ||
|
|
47aaa695b2 | ||
|
|
cda73a2ec5 | ||
|
|
27a0a447d0 | ||
|
|
fcdfd027cd | ||
|
|
3f9698bb47 | ||
|
|
041782c49e | ||
|
|
18077b6e87 | ||
|
|
c40a4f5444 | ||
|
|
028f0220dd | ||
|
|
a616aa3c89 | ||
|
|
1c12c73496 | ||
|
|
b29efbde52 | ||
|
|
b7527eb80e | ||
|
|
d05974fa3d | ||
|
|
a77a88308a | ||
|
|
e5a5d2de7c | ||
|
|
c0187d50ff | ||
|
|
3d0bf36981 | ||
|
|
e61c1dc738 | ||
|
|
91a627ddfc | ||
|
|
3064ff093a | ||
|
|
e2f736bd2d | ||
|
|
c6cf1b98f8 | ||
|
|
56fc3441da | ||
|
|
ebaaecb9d9 | ||
|
|
fa7ba4a390 | ||
|
|
29983e434f | ||
|
|
8c65264474 | ||
|
|
cd4b75f492 | ||
|
|
faad6bcd0c | ||
|
|
265a9ea78c | ||
|
|
aeab08099b | ||
|
|
d9f37d16f7 | ||
|
|
203abf4430 | ||
|
|
17024490e9 | ||
|
|
f7ae3621f4 | ||
|
|
5cbd9da3f5 | ||
|
|
daffba3641 | ||
|
|
860ab51434 | ||
|
|
1442666cc0 | ||
|
|
5ac9ebdebb | ||
|
|
a47a37d315 | ||
|
|
fbc19abd28 | ||
|
|
1f111a163a | ||
|
|
b601d8fd7c | ||
|
|
e98ca000f2 | ||
|
|
5351c28af8 | ||
|
|
e174861b96 | ||
|
|
247e029159 | ||
|
|
5cfc133413 | ||
|
|
c6f53e4cc8 | ||
|
|
c8acbdb363 | ||
|
|
3a3be21366 | ||
|
|
274da13a19 | ||
|
|
153994fe45 | ||
|
|
cdef6da9e9 | ||
|
|
9127449a7a | ||
|
|
8809c44443 | ||
|
|
d15718a87e | ||
|
|
da5aace109 | ||
|
|
6a87808612 | ||
|
|
105b86c660 | ||
|
|
b8b66c3900 | ||
|
|
bc5b9a5506 | ||
|
|
9c798dcd16 | ||
|
|
f5b8abc3f3 | ||
|
|
09cc127121 | ||
|
|
ac67d50616 | ||
|
|
86964bb426 | ||
|
|
c05dc07666 | ||
|
|
af94e11c7d | ||
|
|
0f86c4df9e | ||
|
|
5f0db18d3a | ||
|
|
919e6937ee | ||
|
|
64e23f02f7 | ||
|
|
fbe7f35a25 | ||
|
|
8cd0150a75 | ||
|
|
839aa401f0 | ||
|
|
4055777110 | ||
|
|
b3a99a2625 | ||
|
|
872f7a9648 | ||
|
|
b0c703935f | ||
|
|
621d2b0b6a | ||
|
|
e69520b7fb | ||
|
|
4b968d03a1 | ||
|
|
edc6679140 | ||
|
|
e732c58426 | ||
|
|
81e29aaa3d | ||
|
|
c5a1cbe755 | ||
|
|
35218609d9 | ||
|
|
7629ad553a | ||
|
|
7ddf3a112c | ||
|
|
034094c2d2 | ||
|
|
65ed6d9d5b | ||
|
|
4524f90ebd | ||
|
|
33dd326007 | ||
|
|
95b487c51e | ||
|
|
fcb03392d1 | ||
|
|
64a6168092 | ||
|
|
6a6edaa7cf | ||
|
|
a95d70cf93 | ||
|
|
3e01dc81ec | ||
|
|
e087c9fe9e | ||
|
|
33d601db82 | ||
|
|
eef73e3699 | ||
|
|
1cc07546cb | ||
|
|
e23f01f8d5 | ||
|
|
a3c2b28d6a | ||
|
|
289ed24899 | ||
|
|
98db907680 | ||
|
|
b1cc9050ff | ||
|
|
dc4f5750af | ||
|
|
d374a22b70 | ||
|
|
595ed6b40e | ||
|
|
c9f5b1de1a | ||
|
|
522f2d920b | ||
|
|
bef59929db | ||
|
|
b27b9a1098 | ||
|
|
70de3819e8 | ||
|
|
af18dec46b | ||
|
|
43efc2161a | ||
|
|
caaa988c87 | ||
|
|
ee6dd9179b | ||
|
|
f96a733430 | ||
|
|
de23ccd234 | ||
|
|
da516af837 | ||
|
|
7fbf9c4851 | ||
|
|
808f5c481e | ||
|
|
6dcf954bfe | ||
|
|
cb6fa7d46d | ||
|
|
1e3621833f | ||
|
|
eedb57b2c6 | ||
|
|
524f6d6af5 | ||
|
|
53f7a7993e | ||
|
|
abcb353793 | ||
|
|
d7c2a9f1b8 | ||
|
|
7969df3926 | ||
|
|
97c52a6991 | ||
|
|
a50288c186 | ||
|
|
f246c12959 | ||
|
|
5d7ab194e2 | ||
|
|
8a329f6522 | ||
|
|
4200edb983 | ||
|
|
93ce48aca8 | ||
|
|
df1ec4832c | ||
|
|
e3a38d27f5 | ||
|
|
754498a012 | ||
|
|
4226746675 | ||
|
|
94536be9be | ||
|
|
2c6a9245ee | ||
|
|
fc18a3c89e | ||
|
|
4f23e53002 | ||
|
|
005e9659e1 | ||
|
|
43c6bbb3ad | ||
|
|
def4d16c73 | ||
|
|
61ae19ac82 | ||
|
|
08add538a0 | ||
|
|
bd166b2f77 | ||
|
|
8b7384e47f | ||
|
|
60dc032cb8 | ||
|
|
d47190f1fd | ||
|
|
e581422810 | ||
|
|
ad151bb919 | ||
|
|
b5040e0182 | ||
|
|
c826d06d2c | ||
|
|
7c058bfee3 | ||
|
|
3133e91d8e | ||
|
|
b5e55c81d4 | ||
|
|
0837747428 | ||
|
|
518763cd08 | ||
|
|
2b862f65a2 | ||
|
|
cb53adef62 | ||
|
|
c3481f5a67 | ||
|
|
ba50b6fcc0 | ||
|
|
003246f113 | ||
|
|
13aee98d4a | ||
|
|
6c94573323 | ||
|
|
03a257bddb | ||
|
|
e02e1e8d4a | ||
|
|
57f1015197 | ||
|
|
974b93a8be | ||
|
|
652d71d799 | ||
|
|
f6d4c586eb | ||
|
|
adc7fbd424 | ||
|
|
cfc6bc8e5e | ||
|
|
da802ece3b | ||
|
|
1074f8acb1 | ||
|
|
a0e6a72b69 | ||
|
|
795cfd471a | ||
|
|
0a053ee633 | ||
|
|
85f81df2f8 | ||
|
|
94d9607447 | ||
|
|
2be4489d18 | ||
|
|
d34e4f1f28 | ||
|
|
11a81c25ef | ||
|
|
c18414cbe4 | ||
|
|
998305fd00 | ||
|
|
49ab1a3b38 | ||
|
|
c123ea3179 | ||
|
|
a6ad49dba0 | ||
|
|
3749be3e09 | ||
|
|
b67a42e0a8 | ||
|
|
9805d35a5d | ||
|
|
e3473e3c39 | ||
|
|
a1cab158ea | ||
|
|
9934cdc5bd | ||
|
|
c834694992 | ||
|
|
aa1f5c6e4e | ||
|
|
2d28fb3a73 | ||
|
|
206ed55db4 | ||
|
|
9b0913343c | ||
|
|
5696a62c27 | ||
|
|
11a7ac9b10 | ||
|
|
cbce487362 | ||
|
|
f8ca8d7cea | ||
|
|
732e5d2661 | ||
|
|
5d6fac69c4 | ||
|
|
5654d08086 | ||
|
|
73a7b33864 | ||
|
|
64a752a3b4 | ||
|
|
0ad918c21d | ||
|
|
5829bc69ca | ||
|
|
b591b4ebdf | ||
|
|
dd497d5bd8 | ||
|
|
f70cac54d1 | ||
|
|
f6a48434c1 | ||
|
|
c63b6b3ef8 | ||
|
|
28bd31a30b | ||
|
|
491013e27a | ||
|
|
0bb43aa464 | ||
|
|
0edc707657 | ||
|
|
68b7badb80 | ||
|
|
b57e97d2a1 | ||
|
|
eeb421513b | ||
|
|
ef1e380bbc | ||
|
|
2579b3c0ba | ||
|
|
d646a922ee | ||
|
|
8b2afcec90 | ||
|
|
726f1632b0 | ||
|
|
5a2dad2e16 | ||
|
|
039b00d695 | ||
|
|
1cb63063f7 | ||
|
|
5671503c28 | ||
|
|
50dafeaa0b | ||
|
|
1d4850e47a | ||
|
|
cc4f73dc7e | ||
|
|
067be3727e | ||
|
|
edeb4791c9 | ||
|
|
2f25e44e60 | ||
|
|
5fe1ce89ec | ||
|
|
03fc89da00 | ||
|
|
44e9b02b3f | ||
|
|
7f1f368065 | ||
|
|
89caccd4e0 | ||
|
|
6748b006b7 | ||
|
|
baf086d5b3 | ||
|
|
e2037ad756 | ||
|
|
d75e198304 | ||
|
|
223f0d0850 | ||
|
|
01cd279f9f | ||
|
|
3b26810c17 | ||
|
|
196e2a0abb | ||
|
|
63b9457b6c | ||
|
|
0082b87f61 | ||
|
|
936b1f8d09 | ||
|
|
c1a545ac23 | ||
|
|
33925dd313 | ||
|
|
f5abbeb353 | ||
|
|
c13683e982 | ||
|
|
17bab355e4 | ||
|
|
e77effaf8b | ||
|
|
2e39323782 | ||
|
|
8db5356caf | ||
|
|
fa2edd9d3f | ||
|
|
7997a04a68 | ||
|
|
2c30b4cf60 | ||
|
|
38c3349a6a | ||
|
|
41cb01bac9 | ||
|
|
c2ef4c8e54 | ||
|
|
a7cd44e536 | ||
|
|
dc12ec6dfd | ||
|
|
6eec8851eb | ||
|
|
39c966efdd | ||
|
|
981023154b | ||
|
|
14a9a99e2d | ||
|
|
e74c6f5de7 | ||
|
|
d3170310ff | ||
|
|
03cfc05afd | ||
|
|
fa686207ed | ||
|
|
e863be7ec3 | ||
|
|
2d4edb3eca | ||
|
|
ba5333a092 | ||
|
|
53fa7255ec | ||
|
|
dddf772f19 | ||
|
|
3768fc37da | ||
|
|
9d6d580cbd | ||
|
|
c3696cd857 | ||
|
|
c7498b768c | ||
|
|
da17bdb688 | ||
|
|
c87a741fc9 | ||
|
|
4ad8eefaec | ||
|
|
e64b13c925 | ||
|
|
b97a683bfd | ||
|
|
6ea19b0ae2 | ||
|
|
dd9d2a150d | ||
|
|
195be56c46 | ||
|
|
78662e8194 | ||
|
|
c8f7aa76e7 | ||
|
|
543e7b0b6b | ||
|
|
42d2394585 | ||
|
|
94bd44d0f2 | ||
|
|
92022360de | ||
|
|
7f462a084c | ||
|
|
b77d64bc9f | ||
|
|
d1d945eaa0 | ||
|
|
28fdb8af37 | ||
|
|
dbde044213 | ||
|
|
870132a5cb | ||
|
|
ffa898c52d | ||
|
|
cdf27d60be | ||
|
|
6cf84b118b | ||
|
|
1cc81deb69 | ||
|
|
1d578b73ce | ||
|
|
fdb6a3ce16 | ||
|
|
ca1f3c6e4c | ||
|
|
f942361f7b | ||
|
|
02fd80b703 | ||
|
|
d6b03d4760 | ||
|
|
cb75e25a1a | ||
|
|
9572e16dcb | ||
|
|
459fce196f | ||
|
|
ada434fb20 | ||
|
|
71ba3fa310 | ||
|
|
0727353afa | ||
|
|
fd2ff2a973 | ||
|
|
50f9195f2d | ||
|
|
a47fc5a76b | ||
|
|
72ffe61ad1 | ||
|
|
ea8cac7c10 | ||
|
|
8639699d49 | ||
|
|
f242220132 | ||
|
|
55dbdba636 | ||
|
|
03b670971b | ||
|
|
24860fdc05 | ||
|
|
229dd3a123 | ||
|
|
919eacd907 | ||
|
|
4cec55c9a4 | ||
|
|
aa8ec92976 | ||
|
|
44da9c9a28 | ||
|
|
c776a1edff | ||
|
|
a5cbef1a61 | ||
|
|
ae22ba593a | ||
|
|
4ad4ad7088 | ||
|
|
8bccda5649 | ||
|
|
2a804b6c02 | ||
|
|
3b61617cb1 | ||
|
|
ec28671aed | ||
|
|
c7c7229b8b | ||
|
|
2efc133997 | ||
|
|
df72ac1215 | ||
|
|
2fc0d7b2a7 | ||
|
|
3a9e394814 | ||
|
|
3d9d3da1ae | ||
|
|
8abd764eca | ||
|
|
7a31e481a6 | ||
|
|
b70d2655ed | ||
|
|
15cb2f1a9e | ||
|
|
2471367c92 | ||
|
|
962c40c1a7 | ||
|
|
f6c7828160 | ||
|
|
8b57da9a2b | ||
|
|
daa7a13505 | ||
|
|
cda4790219 | ||
|
|
c6bb1dcc0e | ||
|
|
f8e1b084cd | ||
|
|
7d869c9af1 | ||
|
|
1690b05629 | ||
|
|
563825492e | ||
|
|
eee37017e1 | ||
|
|
29ec328f46 | ||
|
|
b843bb8286 | ||
|
|
77975529fe | ||
|
|
9de65184ab | ||
|
|
4912b1e632 | ||
|
|
cf91cf1b14 | ||
|
|
d0fb54fbfe | ||
|
|
346b869d60 | ||
|
|
ac158e227e | ||
|
|
d96f846648 | ||
|
|
473f3b6f3e | ||
|
|
7f1a471751 | ||
|
|
bbac342f3a | ||
|
|
4b3702987f | ||
|
|
6341847203 | ||
|
|
4e75a9b3b3 | ||
|
|
26f44b8d4b | ||
|
|
8fba0017c7 | ||
|
|
7f4056abc9 | ||
|
|
0257918571 | ||
|
|
1d4e746c4f | ||
|
|
677a02c632 | ||
|
|
177b891905 | ||
|
|
c4dcc6df9c | ||
|
|
7ddd314015 | ||
|
|
ba7325c884 | ||
|
|
3c4b1ef127 | ||
|
|
18c630e5e4 | ||
|
|
0ea0a432bf | ||
|
|
8a964efbed | ||
|
|
865bb7aad8 | ||
|
|
d9c1fb5244 | ||
|
|
71c39c9893 | ||
|
|
38067f1ddc | ||
|
|
7cfeb6e87c | ||
|
|
0a231a8acc | ||
|
|
1cea7a0314 | ||
|
|
ed95a9f2b2 | ||
|
|
76d71a032a | ||
|
|
38bff1a0e0 | ||
|
|
0c0caad827 | ||
|
|
4445e5891f | ||
|
|
f46cefbd39 | ||
|
|
feef022303 | ||
|
|
6a80c18189 | ||
|
|
6616bb4048 | ||
|
|
ac5f51c3d5 | ||
|
|
587888a688 | ||
|
|
7370b4fbcd | ||
|
|
94506bee99 | ||
|
|
7c814a5fd9 | ||
|
|
24aa29598a | ||
|
|
d61a862fa2 | ||
|
|
e29c6b44c7 | ||
|
|
327a0ca323 | ||
|
|
a746309a8e | ||
|
|
d247f90571 | ||
|
|
edbe18b157 | ||
|
|
d951485431 | ||
|
|
306a1a3f57 | ||
|
|
2431de78fa | ||
|
|
49abd6aaf3 | ||
|
|
f3a1f98add | ||
|
|
1ccc728e5d | ||
|
|
11ee80d377 | ||
|
|
512850e83d | ||
|
|
0e9c3cde7c | ||
|
|
43263a3bc8 | ||
|
|
8cce3cc84a | ||
|
|
faaa5a2949 | ||
|
|
c00f5a17c8 | ||
|
|
9c079d04a8 | ||
|
|
c9d4cdc57e | ||
|
|
12b4e80d4b | ||
|
|
6e2a04f374 | ||
|
|
3feeca627c | ||
|
|
8357b15fec | ||
|
|
ecdd9d1ccb | ||
|
|
fc69f4f757 | ||
|
|
5e70274003 | ||
|
|
57b194c63f | ||
|
|
10b04416c1 | ||
|
|
9f6027325c | ||
|
|
b64c8ea56b | ||
|
|
e74d3f4a8f | ||
|
|
8a2aebf845 | ||
|
|
984c8ee477 | ||
|
|
398ae7156b | ||
|
|
d85eeabf11 | ||
|
|
6a62654759 | ||
|
|
c056a7ad7c | ||
|
|
c784a70277 | ||
|
|
e6c87907d5 | ||
|
|
71e9290142 | ||
|
|
74ec34da67 | ||
|
|
7188749cb3 | ||
|
|
c28add55db | ||
|
|
78f34a8245 | ||
|
|
97d6f10f15 | ||
|
|
afefc4caca | ||
|
|
6abbd036f8 | ||
|
|
ef0db0f914 | ||
|
|
e01986fdd4 | ||
|
|
a0c6ebe2d8 | ||
|
|
d2183af23f | ||
|
|
953f1bdc3c | ||
|
|
e2429f20f8 | ||
|
|
f0945da4fb | ||
|
|
8df3de9ae5 | ||
|
|
277cc1cac8 | ||
|
|
07a92293e4 | ||
|
|
9730b9ba2d | ||
|
|
508799c452 | ||
|
|
5e81ef4a44 | ||
|
|
eb42eb6f27 | ||
|
|
232612898b | ||
|
|
6a37efb871 | ||
|
|
af59b61f8a | ||
|
|
f995e31d04 | ||
|
|
9758a9e60d | ||
|
|
6f56696af2 | ||
|
|
345fbdf3d2 | ||
|
|
ce031f7d15 | ||
|
|
bd6b811183 | ||
|
|
196bafff03 | ||
|
|
82bf149ade | ||
|
|
f20b558e22 | ||
|
|
54447bf227 | ||
|
|
fc09051d8b | ||
|
|
1f5ef24ecd | ||
|
|
b1faf42529 | ||
|
|
6a85206e32 | ||
|
|
e3d3e697d3 | ||
|
|
db9b333930 | ||
|
|
f7b284ad73 | ||
|
|
e1970e8a66 | ||
|
|
0cd93d67ff | ||
|
|
6e806e21bd | ||
|
|
a8462c1b70 | ||
|
|
706ea8b649 | ||
|
|
95d46d1dfc | ||
|
|
010f27678d | ||
|
|
1c1e3386f8 | ||
|
|
d87117a2cf | ||
|
|
4ed92a94a1 | ||
|
|
821ea34a3c | ||
|
|
ecb3d01376 | ||
|
|
e322ed4f05 | ||
|
|
a385c8a6f8 | ||
|
|
bcf7e78665 | ||
|
|
4cc76f2deb | ||
|
|
0cb2bb2ea7 | ||
|
|
b41c24d653 | ||
|
|
c5d97597c4 | ||
|
|
fe9acb6c59 | ||
|
|
75548c449b | ||
|
|
bca78beb1b | ||
|
|
9110611489 | ||
|
|
a8a42cbfa8 | ||
|
|
19df2ac234 | ||
|
|
e7524c85c2 | ||
|
|
a4356727e9 | ||
|
|
f15a53fae4 | ||
|
|
8e3cf2eaab | ||
|
|
c51ec3135b | ||
|
|
2469c439b1 | ||
|
|
1297addfb1 | ||
|
|
d6cbf43373 | ||
|
|
0b1a1ca064 | ||
|
|
df647e7b42 | ||
|
|
52a9cee0e1 | ||
|
|
fe16d05fbb | ||
|
|
1430c05b6c | ||
|
|
b25841e50d | ||
|
|
34d45bb3b8 | ||
|
|
9b73696a98 | ||
|
|
aecdbfacf3 | ||
|
|
1c25e29999 | ||
|
|
5ceb898676 | ||
|
|
2fe3706ef0 | ||
|
|
1880164e29 | ||
|
|
b704fc9254 | ||
|
|
352da66bd1 | ||
|
|
8205ad2cd0 | ||
|
|
e417c269eb | ||
|
|
59a76b3970 | ||
|
|
53be79a00e | ||
|
|
c4b69b341a | ||
|
|
e162b9c169 | ||
|
|
77e3502028 | ||
|
|
ae0461692c | ||
|
|
13bdb80958 | ||
|
|
6f74e7b738 | ||
|
|
eaee89f77a | ||
|
|
756a8c50d6 | ||
|
|
a99dbc78c9 | ||
|
|
8a54512037 | ||
|
|
3f96bd9509 | ||
|
|
6d06cb8fb3 | ||
|
|
4247883173 | ||
|
|
bf491d6fe7 | ||
|
|
c15e753a0a | ||
|
|
902aee4e6b | ||
|
|
b964f755ec | ||
|
|
a044070e1d | ||
|
|
e0b859dbbe | ||
|
|
07b64ff1a4 | ||
|
|
7bc9192f3f | ||
|
|
057e551059 | ||
|
|
2f80c814aa | ||
|
|
136a029bb4 | ||
|
|
d4b32a403b | ||
|
|
722b187f83 | ||
|
|
0c5c5823bf | ||
|
|
f5a6b7d1f0 | ||
|
|
bcd236286c | ||
|
|
6c4ada5098 | ||
|
|
2402715492 | ||
|
|
f32cf02714 | ||
|
|
e224ee5498 | ||
|
|
90011aa0c9 | ||
|
|
d0589468c1 | ||
|
|
6ef5acbfe5 | ||
|
|
efe894cad6 | ||
|
|
2a366c176d | ||
|
|
8e280a6a24 | ||
|
|
f144518e0e | ||
|
|
fcc006ecd3 | ||
|
|
5fbadc6b21 | ||
|
|
7902570855 | ||
|
|
55898780f1 | ||
|
|
d16cb90c2f | ||
|
|
66dd514c56 | ||
|
|
ba40748118 | ||
|
|
3538cefe68 | ||
|
|
f77aef82d2 | ||
|
|
4d0037a40c | ||
|
|
fd7a4461cc | ||
|
|
8bc6ddbca8 | ||
|
|
7d50e432b5 | ||
|
|
6103888610 | ||
|
|
4d8189f21b | ||
|
|
cddb778577 | ||
|
|
fa506ec04f | ||
|
|
0eaeef5723 | ||
|
|
f87054895e | ||
|
|
d74a5bd507 | ||
|
|
b5d4535db6 | ||
|
|
4d7562fd79 | ||
|
|
5b869376ab | ||
|
|
19c522d9bc | ||
|
|
1d4ecad134 | ||
|
|
805464e406 | ||
|
|
c674c3561a | ||
|
|
7aa2972c3f | ||
|
|
986558fea7 | ||
|
|
818e34682c | ||
|
|
252fddf3de | ||
|
|
39079e7aff | ||
|
|
1fa4518bb9 | ||
|
|
1b739e87ae | ||
|
|
e944983567 | ||
|
|
4fccaf3284 | ||
|
|
0a79dc9ecc | ||
|
|
847a8c8c4d | ||
|
|
a1018c5823 | ||
|
|
323417182a | ||
|
|
f3bcf570f4 | ||
|
|
a3059597fb | ||
|
|
d19a6914f9 | ||
|
|
4313ede132 | ||
|
|
f3b7ac508d | ||
|
|
635bfd4aba | ||
|
|
38e72e1af7 | ||
|
|
26644bfd1e | ||
|
|
6a827fc7b9 | ||
|
|
3b3ae9c0dd | ||
|
|
301909e3e5 | ||
|
|
97a9c8627c | ||
|
|
56c1fbecea | ||
|
|
de9d18a2fe | ||
|
|
be16ad26b5 | ||
|
|
d762da9141 | ||
|
|
c05d6f7cdf | ||
|
|
7af3fb5ae4 | ||
|
|
3ac54b2178 | ||
|
|
42a26f076a | ||
|
|
3b67759730 | ||
|
|
5407a8345f | ||
|
|
3fe509757b | ||
|
|
952b679ca3 | ||
|
|
6799daacd1 | ||
|
|
fa02b5150c | ||
|
|
63a1904242 | ||
|
|
1e3450fdcb | ||
|
|
5541026b86 | ||
|
|
c36c920b34 | ||
|
|
514fea65c4 | ||
|
|
e269b3bfdd | ||
|
|
0862a9bfa7 | ||
|
|
f43c695527 | ||
|
|
ead43f081c | ||
|
|
4e2a3d61dc | ||
|
|
218ad6bbe0 | ||
|
|
b485f2e42e | ||
|
|
16e32c3f67 | ||
|
|
15f65bb558 | ||
|
|
b161d6831f | ||
|
|
969953039f | ||
|
|
f1506ed5da | ||
|
|
9a239d9e13 | ||
|
|
a5da09dfb9 | ||
|
|
6f81f2d143 | ||
|
|
0b877ca8a3 | ||
|
|
2911b9cd04 | ||
|
|
6b3f1ab0e4 | ||
|
|
2c15655b08 | ||
|
|
afa9c650fe | ||
|
|
28d8d82ded | ||
|
|
a100baf57f | ||
|
|
5621755655 | ||
|
|
d892bfc278 | ||
|
|
4369b18fbf | ||
|
|
fb9b5d31e8 | ||
|
|
3bf0748389 | ||
|
|
cf46b89814 | ||
|
|
3360b34af9 | ||
|
|
4558eb41fc | ||
|
|
bbc5584f80 | ||
|
|
8604c9f9d5 | ||
|
|
747e02ee0d | ||
|
|
8b0334309b | ||
|
|
48afa821e4 | ||
|
|
42a8d3e3dc | ||
|
|
a44fc51007 | ||
|
|
961bc874d2 | ||
|
|
b2b018ab93 | ||
|
|
77da33de4f | ||
|
|
06ad5e3f8c | ||
|
|
9326bf96fc | ||
|
|
bed73102b4 | ||
|
|
eb59f9c75d | ||
|
|
f3bd2ed472 | ||
|
|
456475d593 | ||
|
|
a36ce199ba | ||
|
|
b7c3ad0867 | ||
|
|
ea3545cc7e | ||
|
|
232ba46b16 | ||
|
|
5f011502d1 | ||
|
|
93b6f1066b | ||
|
|
52fe92ed7f | ||
|
|
0d005df463 | ||
|
|
e3ef3ace29 | ||
|
|
a203e98689 | ||
|
|
27f99a0f38 | ||
|
|
d1e48d02bd | ||
|
|
4f06a1df50 | ||
|
|
2d7ae1180f | ||
|
|
75b486b467 | ||
|
|
5b5f10fe93 | ||
|
|
5f654e76e2 | ||
|
|
aa8d112c58 | ||
|
|
e82dc0e841 | ||
|
|
dd741fc38a | ||
|
|
120e4ee92f | ||
|
|
9d2a56bff4 | ||
|
|
31d82a3169 | ||
|
|
d22ee5d451 | ||
|
|
203edaed50 | ||
|
|
93b5638a9c | ||
|
|
52a5e58f0c | ||
|
|
20607b0b5c | ||
|
|
6bebfe9e54 | ||
|
|
50b76f4466 | ||
|
|
23e4e25e9a | ||
|
|
5b83d478d6 | ||
|
|
dca38d01d6 | ||
|
|
0a434d3b3a | ||
|
|
7c4b83a430 | ||
|
|
b7f24b428b | ||
|
|
22a0ed0ee2 | ||
|
|
cf711d55a5 | ||
|
|
26ea562fdb | ||
|
|
efce0c6c57 | ||
|
|
a3768dae97 | ||
|
|
85efea3fb8 | ||
|
|
c820fda26d | ||
|
|
4740293640 | ||
|
|
8be8813cd8 | ||
|
|
8cc747ef22 | ||
|
|
d6ed2ab3e0 | ||
|
|
e8ae980104 | ||
|
|
cd8c23c0ab | ||
|
|
3568042cd9 | ||
|
|
7443129e18 | ||
|
|
4196a3db5a | ||
|
|
cd7594f623 | ||
|
|
b887db474e | ||
|
|
6c4242ad2a | ||
|
|
530af5e358 |
@@ -4,4 +4,5 @@
|
|||||||
.vscode
|
.vscode
|
||||||
.gitignore
|
.gitignore
|
||||||
Makefile
|
Makefile
|
||||||
docs
|
docs
|
||||||
|
.eslintcache
|
||||||
@@ -47,7 +47,7 @@
|
|||||||
# 所有请求超时时间,单位秒,默认为0,表示不限制
|
# 所有请求超时时间,单位秒,默认为0,表示不限制
|
||||||
# RELAY_TIMEOUT=0
|
# RELAY_TIMEOUT=0
|
||||||
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
|
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
|
||||||
# STREAMING_TIMEOUT=120
|
# STREAMING_TIMEOUT=300
|
||||||
|
|
||||||
# Gemini 识别图片 最大图片数量
|
# Gemini 识别图片 最大图片数量
|
||||||
# GEMINI_VISION_MAX_IMAGE_NUM=16
|
# GEMINI_VISION_MAX_IMAGE_NUM=16
|
||||||
@@ -56,8 +56,6 @@
|
|||||||
# SESSION_SECRET=random_string
|
# SESSION_SECRET=random_string
|
||||||
|
|
||||||
# 其他配置
|
# 其他配置
|
||||||
# 渠道测试频率(单位:秒)
|
|
||||||
# CHANNEL_TEST_FREQUENCY=10
|
|
||||||
# 生成默认token
|
# 生成默认token
|
||||||
# GENERATE_DEFAULT_TOKEN=false
|
# GENERATE_DEFAULT_TOKEN=false
|
||||||
# Cohere 安全设置
|
# Cohere 安全设置
|
||||||
|
|||||||
21
.github/workflows/pr-target-branch-check.yml
vendored
21
.github/workflows/pr-target-branch-check.yml
vendored
@@ -1,21 +0,0 @@
|
|||||||
name: Check PR Branching Strategy
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
types: [opened, synchronize, reopened, edited]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
check-branching-strategy:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Enforce branching strategy
|
|
||||||
run: |
|
|
||||||
if [[ "${{ github.base_ref }}" == "main" ]]; then
|
|
||||||
if [[ "${{ github.head_ref }}" != "alpha" ]]; then
|
|
||||||
echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
|
|
||||||
echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
echo "Branching strategy check passed."
|
|
||||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -10,4 +10,7 @@ web/dist
|
|||||||
.env
|
.env
|
||||||
one-api
|
one-api
|
||||||
.DS_Store
|
.DS_Store
|
||||||
tiktoken_cache
|
tiktoken_cache
|
||||||
|
.eslintcache
|
||||||
|
.cursor
|
||||||
|
*.mdc
|
||||||
@@ -2,6 +2,7 @@ FROM oven/bun:latest AS builder
|
|||||||
|
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
COPY web/package.json .
|
COPY web/package.json .
|
||||||
|
COPY web/bun.lock .
|
||||||
RUN bun install
|
RUN bun install
|
||||||
COPY ./web .
|
COPY ./web .
|
||||||
COPY ./VERSION .
|
COPY ./VERSION .
|
||||||
|
|||||||
240
LICENSE
240
LICENSE
@@ -1,201 +1,103 @@
|
|||||||
Apache License
|
# **New API 许可协议 (Licensing)**
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
本项目采用**基于使用场景的双重许可 (Usage-Based Dual Licensing)** 模式。
|
||||||
|
|
||||||
1. Definitions.
|
**核心原则:**
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
- **默认许可:** 本项目默认在 **GNU Affero 通用公共许可证 v3.0 (AGPLv3)** 下提供。任何用户在遵守 AGPLv3 条款和下述附加限制的前提下,均可免费使用。
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
- **商业许可:** 在特定商业场景下,或当您希望获得 AGPLv3 之外的权利时,**必须**获取**商业许可证 (Commercial License)**。
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
---
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
## **1. 开源许可证 (Open Source License): AGPLv3 - 适用于基础使用**
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
- 在遵守 **AGPLv3** 条款的前提下,您可以自由地使用、修改和分发 New API。AGPLv3 的完整文本可以访问 [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html) 获取。
|
||||||
exercising permissions granted by this License.
|
- **核心义务:** AGPLv3 的一个关键要求是,如果您修改了 New API 并通过网络提供服务 (SaaS),或者分发了修改后的版本,您必须以 AGPLv3 许可证向所有用户提供相应的**完整源代码**。
|
||||||
|
- **附加限制 (重要):** 在仅使用 AGPLv3 开源许可证的情况下,您**必须**完整保留项目代码中原有的品牌标识、LOGO 及版权声明信息。**禁止以任何形式修改、移除或遮盖**这些信息。如需移除,必须获取商业许可证。
|
||||||
|
- 使用前请务必仔细阅读并理解 AGPLv3 的所有条款及上述附加限制。
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
## **2. 商业许可证 (Commercial License) - 适用于高级场景及闭源需求**
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
在以下任一情况下,您**必须**联系我们获取并签署一份商业许可证,才能合法使用 New API:
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
- **场景一:移除品牌和版权信息**
|
||||||
Object form, made available under the License, as indicated by a
|
您希望在您的产品或服务中移除 New API 的 LOGO、UI界面中的版权声明或其他品牌标识。
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
- **场景二:规避 AGPLv3 开源义务**
|
||||||
form, that is based on (or derived from) the Work and for which the
|
您基于 New API 进行了修改,并希望:
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
- 通过网络提供服务(SaaS),但**不希望**向您的服务用户公开您修改后的源代码。
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
- 分发一个集成了 New API 的软件产品,但**不希望**以 AGPLv3 许可证发布您的产品或公开源代码。
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
- **场景三:企业政策与集成需求**
|
||||||
the original version of the Work and any modifications or additions
|
- 您所在公司的政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件。
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
- 您需要进行 OEM 集成,将 New API 作为您闭源商业产品的一部分进行再分发。
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
- **场景四:需要商业支持与保障**
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
您需要 AGPLv3 未提供的商业保障,如官方技术支持等。
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
**获取商业许可:**
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
请通过电子邮件 **support@quantumnous.com** 联系 New API 团队洽谈商业授权事宜。
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
## **3. 贡献 (Contributions)**
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
- 我们欢迎社区对 New API 的贡献。所有向本项目提交的贡献(例如通过 Pull Request)都将被视为在 **AGPLv3** 许可证下提供。
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
- 通过向本项目提交贡献,即表示您同意您的代码以 AGPLv3 许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。
|
||||||
modifications, and in Source or Object form, provided that You
|
- 您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 New API 版本中。
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
## **4. 其他条款 (Other Terms)**
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
- 关于商业许可证的具体条款、条件和价格,以双方签署的正式商业许可协议为准。
|
||||||
stating that You changed the files; and
|
- 项目维护者保留根据需要更新本许可政策的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
---
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
# **New API Licensing**
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
This project uses a **Usage-Based Dual Licensing** model.
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
**Core Principles:**
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
- **Default License:** This project is available by default under the **GNU Affero General Public License v3.0 (AGPLv3)**. Any user may use it free of charge, provided they comply with both the AGPLv3 terms and the additional restrictions listed below.
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
- **Commercial License:** For specific commercial scenarios, or if you require rights beyond those granted by AGPLv3, you **must** obtain a **Commercial License**.
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
---
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
## **1. Open Source License: AGPLv3 – For Basic Usage**
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
- Under the terms of the **AGPLv3**, you are free to use, modify, and distribute New API. The complete AGPLv3 license text can be viewed at [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html).
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
- **Core Obligation:** A key AGPLv3 requirement is that if you modify New API and provide it as a network service (SaaS), or distribute a modified version, you must make the **complete corresponding source code** available to all users under the AGPLv3 license.
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
- **Additional Restriction (Important):** When using only the AGPLv3 open-source license, you **must** retain all original branding, logos, and copyright statements within the project’s code. **You are strictly prohibited from modifying, removing, or concealing** any such information. If you wish to remove this, you must obtain a Commercial License.
|
||||||
or other liability obligations and/or rights consistent with this
|
- Please read and ensure that you fully understand all AGPLv3 terms and the above additional restriction before use.
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
## **2. Commercial License – For Advanced Scenarios & Closed Source Needs**
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
You **must** contact us to obtain and sign a Commercial License in any of the following scenarios in order to legally use New API:
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
- **Scenario 1: Removal of Branding and Copyright**
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
You wish to remove the New API logo, copyright statement, or other branding elements from your product or service.
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
- **Scenario 2: Avoidance of AGPLv3 Open Source Obligations**
|
||||||
|
You have modified New API and wish to:
|
||||||
|
- Offer it as a network service (SaaS) **without** disclosing your modifications' source code to your users.
|
||||||
|
- Distribute a software product integrated with New API **without** releasing your product under AGPLv3 or open-sourcing the code.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
- **Scenario 3: Enterprise Policy & Integration Needs**
|
||||||
you may not use this file except in compliance with the License.
|
- Your organization’s policies, client contracts, or project requirements prohibit the use of AGPLv3-licensed software.
|
||||||
You may obtain a copy of the License at
|
- You require OEM integration and need to redistribute New API as part of your closed-source commercial product.
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
- **Scenario 4: Commercial Support and Assurances**
|
||||||
|
You require commercial assurances not provided by AGPLv3, such as official technical support.
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
**Obtaining a Commercial License:**
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
Please contact the New API team via email at **support@quantumnous.com** to discuss commercial licensing.
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
## **3. Contributions**
|
||||||
limitations under the License.
|
|
||||||
|
- We welcome community contributions to New API. All contributions (e.g., via Pull Request) are deemed to be provided under the **AGPLv3** license.
|
||||||
|
- By submitting a contribution, you agree that your code is licensed to this project and all downstream users under the AGPLv3 license (regardless of whether those users ultimately operate under AGPLv3 or a Commercial License).
|
||||||
|
- You also acknowledge and agree that your contribution may be included in New API releases distributed under a Commercial License.
|
||||||
|
|
||||||
|
## **4. Other Terms**
|
||||||
|
|
||||||
|
- The specific terms, conditions, and pricing of the Commercial License are governed by the formal commercial license agreement executed by both parties.
|
||||||
|
- Project maintainers reserve the right to update this licensing policy as needed. Updates will be communicated via official project channels (e.g., repository, official website).
|
||||||
|
|||||||
24
README.en.md
24
README.en.md
@@ -40,6 +40,28 @@
|
|||||||
> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
|
> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
|
||||||
> - According to the [《Interim Measures for the Management of Generative Artificial Intelligence Services》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), please do not provide any unregistered generative AI services to the public in China.
|
> - According to the [《Interim Measures for the Management of Generative Artificial Intelligence Services》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), please do not provide any unregistered generative AI services to the public in China.
|
||||||
|
|
||||||
|
<h2>🤝 Trusted Partners</h2>
|
||||||
|
<p id="premium-sponsors"> </p>
|
||||||
|
<p align="center"><strong>No particular order</strong></p>
|
||||||
|
<p align="center">
|
||||||
|
<a href="https://www.cherry-ai.com/" target=_blank><img
|
||||||
|
src="./docs/images/cherry-studio.png" alt="Cherry Studio" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://bda.pku.edu.cn/" target=_blank><img
|
||||||
|
src="./docs/images/pku.png" alt="Peking University" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target=_blank><img
|
||||||
|
src="./docs/images/ucloud.png" alt="UCloud" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://www.aliyun.com/" target=_blank><img
|
||||||
|
src="./docs/images/aliyun.png" alt="Alibaba Cloud" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://io.net/" target=_blank><img
|
||||||
|
src="./docs/images/io-net.png" alt="IO.NET" height="120"
|
||||||
|
/></a>
|
||||||
|
</p>
|
||||||
|
<p> </p>
|
||||||
|
|
||||||
## 📚 Documentation
|
## 📚 Documentation
|
||||||
|
|
||||||
For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
|
For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
|
||||||
@@ -100,7 +122,7 @@ This version supports multiple models, please refer to [API Documentation-Relay
|
|||||||
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
|
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
|
||||||
|
|
||||||
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
|
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
|
||||||
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 120 seconds
|
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 300 seconds
|
||||||
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
|
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
|
||||||
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
|
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
|
||||||
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
|
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
|
||||||
|
|||||||
30
README.md
30
README.md
@@ -40,6 +40,28 @@
|
|||||||
> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
|
> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
|
||||||
> - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
> - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
||||||
|
|
||||||
|
<h2>🤝 我们信任的合作伙伴</h2>
|
||||||
|
<p id="premium-sponsors"> </p>
|
||||||
|
<p align="center"><strong>排名不分先后</strong></p>
|
||||||
|
<p align="center">
|
||||||
|
<a href="https://www.cherry-ai.com/" target=_blank><img
|
||||||
|
src="./docs/images/cherry-studio.png" alt="Cherry Studio" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://bda.pku.edu.cn/" target=_blank><img
|
||||||
|
src="./docs/images/pku.png" alt="北京大学" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target=_blank><img
|
||||||
|
src="./docs/images/ucloud.png" alt="UCloud 优刻得" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://www.aliyun.com/" target=_blank><img
|
||||||
|
src="./docs/images/aliyun.png" alt="阿里云" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://io.net/" target=_blank><img
|
||||||
|
src="./docs/images/io-net.png" alt="IO.NET" height="120"
|
||||||
|
/></a>
|
||||||
|
</p>
|
||||||
|
<p> </p>
|
||||||
|
|
||||||
## 📚 文档
|
## 📚 文档
|
||||||
|
|
||||||
详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
|
详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
|
||||||
@@ -74,7 +96,11 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
|
|||||||
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
|
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
|
||||||
16. 🔄 思考转内容功能
|
16. 🔄 思考转内容功能
|
||||||
17. 🔄 针对用户的模型限流功能
|
17. 🔄 针对用户的模型限流功能
|
||||||
18. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
|
18. 🔄 请求格式转换功能,支持以下三种格式转换:
|
||||||
|
1. OpenAI Chat Completions => Claude Messages
|
||||||
|
2. Clade Messages => OpenAI Chat Completions (可用于Claude Code调用第三方模型)
|
||||||
|
3. OpenAI Chat Completions => Gemini Chat
|
||||||
|
19. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
|
||||||
1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项
|
1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项
|
||||||
2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费
|
2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费
|
||||||
3. 支持的渠道:
|
3. 支持的渠道:
|
||||||
@@ -100,7 +126,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
|
|||||||
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
|
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
|
||||||
|
|
||||||
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
|
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
|
||||||
- `STREAMING_TIMEOUT`:流式回复超时时间,默认120秒
|
- `STREAMING_TIMEOUT`:流式回复超时时间,默认300秒
|
||||||
- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
|
- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
|
||||||
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
|
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
|
||||||
- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
|
- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
|
||||||
|
|||||||
@@ -63,6 +63,10 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
|||||||
apiType = constant.APITypeXai
|
apiType = constant.APITypeXai
|
||||||
case constant.ChannelTypeCoze:
|
case constant.ChannelTypeCoze:
|
||||||
apiType = constant.APITypeCoze
|
apiType = constant.APITypeCoze
|
||||||
|
case constant.ChannelTypeJimeng:
|
||||||
|
apiType = constant.APITypeJimeng
|
||||||
|
case constant.ChannelTypeMoonshot:
|
||||||
|
apiType = constant.APITypeMoonshot
|
||||||
}
|
}
|
||||||
if apiType == -1 {
|
if apiType == -1 {
|
||||||
return constant.APITypeOpenAI, false
|
return constant.APITypeOpenAI, false
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ var GitHubClientId = ""
|
|||||||
var GitHubClientSecret = ""
|
var GitHubClientSecret = ""
|
||||||
var LinuxDOClientId = ""
|
var LinuxDOClientId = ""
|
||||||
var LinuxDOClientSecret = ""
|
var LinuxDOClientSecret = ""
|
||||||
|
var LinuxDOMinimumTrustLevel = 0
|
||||||
|
|
||||||
var WeChatServerAddress = ""
|
var WeChatServerAddress = ""
|
||||||
var WeChatServerToken = ""
|
var WeChatServerToken = ""
|
||||||
@@ -193,3 +194,9 @@ const (
|
|||||||
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
||||||
ChannelStatusAutoDisabled = 3
|
ChannelStatusAutoDisabled = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TopUpStatusPending = "pending"
|
||||||
|
TopUpStatusSuccess = "success"
|
||||||
|
TopUpStatusExpired = "expired"
|
||||||
|
)
|
||||||
|
|||||||
19
common/copy.go
Normal file
19
common/copy.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/jinzhu/copier"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DeepCopy[T any](src *T) (*T, error) {
|
||||||
|
if src == nil {
|
||||||
|
return nil, fmt.Errorf("copy source cannot be nil")
|
||||||
|
}
|
||||||
|
var dst T
|
||||||
|
err := copier.CopyWithOption(&dst, src, copier.Option{DeepCopy: true, IgnoreEmpty: true})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &dst, nil
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type stringWriter interface {
|
type stringWriter interface {
|
||||||
@@ -52,6 +53,8 @@ type CustomEvent struct {
|
|||||||
Id string
|
Id string
|
||||||
Retry uint
|
Retry uint
|
||||||
Data interface{}
|
Data interface{}
|
||||||
|
|
||||||
|
Mutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func encode(writer io.Writer, event CustomEvent) error {
|
func encode(writer io.Writer, event CustomEvent) error {
|
||||||
@@ -73,6 +76,8 @@ func (r CustomEvent) Render(w http.ResponseWriter) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
|
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
|
||||||
|
r.Mutex.Lock()
|
||||||
|
defer r.Mutex.Unlock()
|
||||||
header := w.Header()
|
header := w.Header()
|
||||||
header["Content-Type"] = contentType
|
header["Content-Type"] = contentType
|
||||||
|
|
||||||
|
|||||||
@@ -12,4 +12,4 @@ var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
|
|||||||
var UsingMySQL = false
|
var UsingMySQL = false
|
||||||
var UsingClickHouse = false
|
var UsingClickHouse = false
|
||||||
|
|
||||||
var SQLitePath = "one-api.db?_busy_timeout=5000"
|
var SQLitePath = "one-api.db?_busy_timeout=30000"
|
||||||
32
common/endpoint_defaults.go
Normal file
32
common/endpoint_defaults.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "one-api/constant"
|
||||||
|
|
||||||
|
// EndpointInfo 描述单个端点的默认请求信息
|
||||||
|
// path: 上游路径
|
||||||
|
// method: HTTP 请求方式,例如 POST/GET
|
||||||
|
// 目前均为 POST,后续可扩展
|
||||||
|
//
|
||||||
|
// json 标签用于直接序列化到 API 输出
|
||||||
|
// 例如:{"path":"/v1/chat/completions","method":"POST"}
|
||||||
|
|
||||||
|
type EndpointInfo struct {
|
||||||
|
Path string `json:"path"`
|
||||||
|
Method string `json:"method"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultEndpointInfoMap 保存内置端点的默认 Path 与 Method
|
||||||
|
var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{
|
||||||
|
constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"},
|
||||||
|
constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"},
|
||||||
|
constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"},
|
||||||
|
constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
|
||||||
|
constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"},
|
||||||
|
constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在
|
||||||
|
func GetDefaultEndpointInfo(et constant.EndpointType) (EndpointInfo, bool) {
|
||||||
|
info, ok := defaultEndpointInfoMap[et]
|
||||||
|
return info, ok
|
||||||
|
}
|
||||||
@@ -2,11 +2,13 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
const KeyRequestBody = "key_request_body"
|
const KeyRequestBody = "key_request_body"
|
||||||
@@ -30,6 +32,9 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
//if DebugEnabled {
|
||||||
|
// println("UnmarshalBodyReusable request body:", string(requestBody))
|
||||||
|
//}
|
||||||
contentType := c.Request.Header.Get("Content-Type")
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
if strings.HasPrefix(contentType, "application/json") {
|
if strings.HasPrefix(contentType, "application/json") {
|
||||||
err = Unmarshal(requestBody, &v)
|
err = Unmarshal(requestBody, &v)
|
||||||
@@ -86,3 +91,25 @@ func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool)
|
|||||||
var t T
|
var t T
|
||||||
return t, false
|
return t, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ApiError(c *gin.Context, err error) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApiErrorMsg(c *gin.Context, msg string) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": msg,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApiSuccess(c *gin.Context, data any) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
34
common/hash.go
Normal file
34
common/hash.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Sha256Raw(data []byte) []byte {
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write(data)
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Sha1Raw(data []byte) []byte {
|
||||||
|
h := sha1.New()
|
||||||
|
h.Write(data)
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Sha1(data []byte) string {
|
||||||
|
return hex.EncodeToString(Sha1Raw(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
func HmacSha256Raw(message, key []byte) []byte {
|
||||||
|
h := hmac.New(sha256.New, key)
|
||||||
|
h.Write(message)
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func HmacSha256(message, key string) string {
|
||||||
|
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
|
||||||
|
}
|
||||||
@@ -101,7 +101,7 @@ func InitEnv() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func initConstantEnv() {
|
func initConstantEnv() {
|
||||||
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120)
|
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 300)
|
||||||
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||||
|
|||||||
22
common/ip.go
Normal file
22
common/ip.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
func IsPrivateIP(ip net.IP) bool {
|
||||||
|
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
private := []net.IPNet{
|
||||||
|
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)},
|
||||||
|
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)},
|
||||||
|
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, privateNet := range private {
|
||||||
|
if privateNet.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -20,3 +20,25 @@ func DecodeJson(reader *bytes.Reader, v any) error {
|
|||||||
func Marshal(v any) ([]byte, error) {
|
func Marshal(v any) ([]byte, error) {
|
||||||
return json.Marshal(v)
|
return json.Marshal(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetJsonType(data json.RawMessage) string {
|
||||||
|
data = bytes.TrimSpace(data)
|
||||||
|
if len(data) == 0 {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
firstChar := bytes.TrimSpace(data)[0]
|
||||||
|
switch firstChar {
|
||||||
|
case '{':
|
||||||
|
return "object"
|
||||||
|
case '[':
|
||||||
|
return "array"
|
||||||
|
case '"':
|
||||||
|
return "string"
|
||||||
|
case 't', 'f':
|
||||||
|
return "boolean"
|
||||||
|
case 'n':
|
||||||
|
return "null"
|
||||||
|
default:
|
||||||
|
return "number"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PageInfo struct {
|
type PageInfo struct {
|
||||||
Page int `json:"page"` // page num 页码
|
Page int `json:"page"` // page num 页码
|
||||||
PageSize int `json:"page_size"` // page size 页大小
|
PageSize int `json:"page_size"` // page size 页大小
|
||||||
StartTimestamp int64 `json:"start_timestamp"` // 秒级
|
|
||||||
EndTimestamp int64 `json:"end_timestamp"` // 秒级
|
|
||||||
|
|
||||||
Total int `json:"total"` // 总条数,后设置
|
Total int `json:"total"` // 总条数,后设置
|
||||||
Items any `json:"items"` // 数据,后设置
|
Items any `json:"items"` // 数据,后设置
|
||||||
@@ -39,11 +38,14 @@ func (p *PageInfo) SetItems(items any) {
|
|||||||
p.Items = items
|
p.Items = items
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPageQuery(c *gin.Context) (*PageInfo, error) {
|
func GetPageQuery(c *gin.Context) *PageInfo {
|
||||||
pageInfo := &PageInfo{}
|
pageInfo := &PageInfo{}
|
||||||
err := c.BindQuery(pageInfo)
|
// 手动获取并处理每个参数
|
||||||
if err != nil {
|
if page, err := strconv.Atoi(c.Query("p")); err == nil {
|
||||||
return nil, err
|
pageInfo.Page = page
|
||||||
|
}
|
||||||
|
if pageSize, err := strconv.Atoi(c.Query("page_size")); err == nil {
|
||||||
|
pageInfo.PageSize = pageSize
|
||||||
}
|
}
|
||||||
if pageInfo.Page < 1 {
|
if pageInfo.Page < 1 {
|
||||||
// 兼容
|
// 兼容
|
||||||
@@ -56,7 +58,25 @@ func GetPageQuery(c *gin.Context) (*PageInfo, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if pageInfo.PageSize == 0 {
|
if pageInfo.PageSize == 0 {
|
||||||
pageInfo.PageSize = ItemsPerPage
|
// 兼容
|
||||||
|
pageSize, _ := strconv.Atoi(c.Query("ps"))
|
||||||
|
if pageSize != 0 {
|
||||||
|
pageInfo.PageSize = pageSize
|
||||||
|
}
|
||||||
|
if pageInfo.PageSize == 0 {
|
||||||
|
pageSize, _ = strconv.Atoi(c.Query("size")) // token page
|
||||||
|
if pageSize != 0 {
|
||||||
|
pageInfo.PageSize = pageSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if pageInfo.PageSize == 0 {
|
||||||
|
pageInfo.PageSize = ItemsPerPage
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return pageInfo, nil
|
|
||||||
|
if pageInfo.PageSize > 100 {
|
||||||
|
pageInfo.PageSize = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
return pageInfo
|
||||||
}
|
}
|
||||||
|
|||||||
5
common/quota.go
Normal file
5
common/quota.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
func GetTrustQuota() int {
|
||||||
|
return int(10 * QuotaPerUnit)
|
||||||
|
}
|
||||||
327
common/ssrf_protection.go
Normal file
327
common/ssrf_protection.go
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSRFProtection SSRF防护配置
|
||||||
|
type SSRFProtection struct {
|
||||||
|
AllowPrivateIp bool
|
||||||
|
DomainFilterMode bool // true: 白名单, false: 黑名单
|
||||||
|
DomainList []string // domain format, e.g. example.com, *.example.com
|
||||||
|
IpFilterMode bool // true: 白名单, false: 黑名单
|
||||||
|
IpList []string // CIDR or single IP
|
||||||
|
AllowedPorts []int // 允许的端口范围
|
||||||
|
ApplyIPFilterForDomain bool // 对域名启用IP过滤
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultSSRFProtection 默认SSRF防护配置
|
||||||
|
var DefaultSSRFProtection = &SSRFProtection{
|
||||||
|
AllowPrivateIp: false,
|
||||||
|
DomainFilterMode: true,
|
||||||
|
DomainList: []string{},
|
||||||
|
IpFilterMode: true,
|
||||||
|
IpList: []string{},
|
||||||
|
AllowedPorts: []int{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPrivateIP 检查IP是否为私有地址
|
||||||
|
func isPrivateIP(ip net.IP) bool {
|
||||||
|
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查私有网段
|
||||||
|
private := []net.IPNet{
|
||||||
|
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
|
||||||
|
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
|
||||||
|
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
|
||||||
|
{IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
|
||||||
|
{IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地)
|
||||||
|
{IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播)
|
||||||
|
{IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, privateNet := range private {
|
||||||
|
if privateNet.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查IPv6私有地址
|
||||||
|
if ip.To4() == nil {
|
||||||
|
// IPv6 loopback
|
||||||
|
if ip.Equal(net.IPv6loopback) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// IPv6 link-local
|
||||||
|
if strings.HasPrefix(ip.String(), "fe80:") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// IPv6 unique local
|
||||||
|
if strings.HasPrefix(ip.String(), "fc") || strings.HasPrefix(ip.String(), "fd") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePortRanges 解析端口范围配置
|
||||||
|
// 支持格式: "80", "443", "8000-9000"
|
||||||
|
func parsePortRanges(portConfigs []string) ([]int, error) {
|
||||||
|
var ports []int
|
||||||
|
|
||||||
|
for _, config := range portConfigs {
|
||||||
|
config = strings.TrimSpace(config)
|
||||||
|
if config == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(config, "-") {
|
||||||
|
// 处理端口范围 "8000-9000"
|
||||||
|
parts := strings.Split(config, "-")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return nil, fmt.Errorf("invalid port range format: %s", config)
|
||||||
|
}
|
||||||
|
|
||||||
|
startPort, err := strconv.Atoi(strings.TrimSpace(parts[0]))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid start port in range %s: %v", config, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
endPort, err := strconv.Atoi(strings.TrimSpace(parts[1]))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid end port in range %s: %v", config, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if startPort > endPort {
|
||||||
|
return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", config)
|
||||||
|
}
|
||||||
|
|
||||||
|
if startPort < 1 || startPort > 65535 || endPort < 1 || endPort > 65535 {
|
||||||
|
return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加范围内的所有端口
|
||||||
|
for port := startPort; port <= endPort; port++ {
|
||||||
|
ports = append(ports, port)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 处理单个端口 "80"
|
||||||
|
port, err := strconv.Atoi(config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid port number: %s", config)
|
||||||
|
}
|
||||||
|
|
||||||
|
if port < 1 || port > 65535 {
|
||||||
|
return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
ports = append(ports, port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ports, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAllowedPort 检查端口是否被允许
|
||||||
|
func (p *SSRFProtection) isAllowedPort(port int) bool {
|
||||||
|
if len(p.AllowedPorts) == 0 {
|
||||||
|
return true // 如果没有配置端口限制,则允许所有端口
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, allowedPort := range p.AllowedPorts {
|
||||||
|
if port == allowedPort {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isDomainWhitelisted 检查域名是否在白名单中
|
||||||
|
func isDomainListed(domain string, list []string) bool {
|
||||||
|
if len(list) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
domain = strings.ToLower(domain)
|
||||||
|
for _, item := range list {
|
||||||
|
item = strings.ToLower(strings.TrimSpace(item))
|
||||||
|
if item == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 精确匹配
|
||||||
|
if domain == item {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// 通配符匹配 (*.example.com)
|
||||||
|
if strings.HasPrefix(item, "*.") {
|
||||||
|
suffix := strings.TrimPrefix(item, "*.")
|
||||||
|
if strings.HasSuffix(domain, "."+suffix) || domain == suffix {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SSRFProtection) isDomainAllowed(domain string) bool {
|
||||||
|
listed := isDomainListed(domain, p.DomainList)
|
||||||
|
if p.DomainFilterMode { // 白名单
|
||||||
|
return listed
|
||||||
|
}
|
||||||
|
// 黑名单
|
||||||
|
return !listed
|
||||||
|
}
|
||||||
|
|
||||||
|
// isIPWhitelisted 检查IP是否在白名单中
|
||||||
|
|
||||||
|
func isIPListed(ip net.IP, list []string) bool {
|
||||||
|
if len(list) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, whitelistCIDR := range list {
|
||||||
|
_, network, err := net.ParseCIDR(whitelistCIDR)
|
||||||
|
if err != nil {
|
||||||
|
// 尝试作为单个IP处理
|
||||||
|
if whitelistIP := net.ParseIP(whitelistCIDR); whitelistIP != nil {
|
||||||
|
if ip.Equal(whitelistIP) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if network.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsIPAccessAllowed 检查IP是否允许访问
|
||||||
|
func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool {
|
||||||
|
// 私有IP限制
|
||||||
|
if isPrivateIP(ip) && !p.AllowPrivateIp {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
listed := isIPListed(ip, p.IpList)
|
||||||
|
if p.IpFilterMode { // 白名单
|
||||||
|
return listed
|
||||||
|
}
|
||||||
|
// 黑名单
|
||||||
|
return !listed
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateURL 验证URL是否安全
|
||||||
|
func (p *SSRFProtection) ValidateURL(urlStr string) error {
|
||||||
|
// 解析URL
|
||||||
|
u, err := url.Parse(urlStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid URL format: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 只允许HTTP/HTTPS协议
|
||||||
|
if u.Scheme != "http" && u.Scheme != "https" {
|
||||||
|
return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析主机和端口
|
||||||
|
host, portStr, err := net.SplitHostPort(u.Host)
|
||||||
|
if err != nil {
|
||||||
|
// 没有端口,使用默认端口
|
||||||
|
host = u.Hostname()
|
||||||
|
if u.Scheme == "https" {
|
||||||
|
portStr = "443"
|
||||||
|
} else {
|
||||||
|
portStr = "80"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证端口
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid port: %s", portStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !p.isAllowedPort(port) {
|
||||||
|
return fmt.Errorf("port %d is not allowed", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果 host 是 IP,则跳过域名检查
|
||||||
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
|
if !p.IsIPAccessAllowed(ip) {
|
||||||
|
if isPrivateIP(ip) {
|
||||||
|
return fmt.Errorf("private IP address not allowed: %s", ip.String())
|
||||||
|
}
|
||||||
|
if p.IpFilterMode {
|
||||||
|
return fmt.Errorf("ip not in whitelist: %s", ip.String())
|
||||||
|
}
|
||||||
|
return fmt.Errorf("ip in blacklist: %s", ip.String())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先进行域名过滤
|
||||||
|
if !p.isDomainAllowed(host) {
|
||||||
|
if p.DomainFilterMode {
|
||||||
|
return fmt.Errorf("domain not in whitelist: %s", host)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("domain in blacklist: %s", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 若未启用对域名应用IP过滤,则到此通过
|
||||||
|
if !p.ApplyIPFilterForDomain {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析域名对应IP并检查
|
||||||
|
ips, err := net.LookupIP(host)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
|
||||||
|
}
|
||||||
|
for _, ip := range ips {
|
||||||
|
if !p.IsIPAccessAllowed(ip) {
|
||||||
|
if isPrivateIP(ip) && !p.AllowPrivateIp {
|
||||||
|
return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
|
||||||
|
}
|
||||||
|
if p.IpFilterMode {
|
||||||
|
return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String())
|
||||||
|
}
|
||||||
|
return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
|
||||||
|
func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error {
|
||||||
|
// 如果SSRF防护被禁用,直接返回成功
|
||||||
|
if !enableSSRFProtection {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析端口范围配置
|
||||||
|
allowedPortInts, err := parsePortRanges(allowedPorts)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("request reject - invalid port configuration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
protection := &SSRFProtection{
|
||||||
|
AllowPrivateIp: allowPrivateIp,
|
||||||
|
DomainFilterMode: domainFilterMode,
|
||||||
|
DomainList: domainList,
|
||||||
|
IpFilterMode: ipFilterMode,
|
||||||
|
IpList: ipList,
|
||||||
|
AllowedPorts: allowedPortInts,
|
||||||
|
ApplyIPFilterForDomain: applyIPFilterForDomain,
|
||||||
|
}
|
||||||
|
return protection.ValidateURL(urlStr)
|
||||||
|
}
|
||||||
140
common/str.go
140
common/str.go
@@ -4,7 +4,10 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -95,3 +98,140 @@ func GetJsonString(data any) string {
|
|||||||
b, _ := json.Marshal(data)
|
b, _ := json.Marshal(data)
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MaskEmail masks a user email to prevent PII leakage in logs
|
||||||
|
// Returns "***masked***" if email is empty, otherwise shows only the domain part
|
||||||
|
func MaskEmail(email string) string {
|
||||||
|
if email == "" {
|
||||||
|
return "***masked***"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the @ symbol
|
||||||
|
atIndex := strings.Index(email, "@")
|
||||||
|
if atIndex == -1 {
|
||||||
|
// No @ symbol found, return masked
|
||||||
|
return "***masked***"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return only the domain part with @ symbol
|
||||||
|
return "***@" + email[atIndex+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskHostTail returns the tail parts of a domain/host that should be preserved.
|
||||||
|
// It keeps 2 parts for likely country-code TLDs (e.g., co.uk, com.cn), otherwise keeps only the TLD.
|
||||||
|
func maskHostTail(parts []string) []string {
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
lastPart := parts[len(parts)-1]
|
||||||
|
secondLastPart := parts[len(parts)-2]
|
||||||
|
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
|
||||||
|
// Likely country code TLD like co.uk, com.cn
|
||||||
|
return []string{secondLastPart, lastPart}
|
||||||
|
}
|
||||||
|
return []string{lastPart}
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail.
|
||||||
|
// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk
|
||||||
|
func maskHostForURL(host string) string {
|
||||||
|
parts := strings.Split(host, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return "***"
|
||||||
|
}
|
||||||
|
tail := maskHostTail(parts)
|
||||||
|
return "***." + strings.Join(tail, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***.
|
||||||
|
// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk
|
||||||
|
func maskHostForPlainDomain(domain string) string {
|
||||||
|
parts := strings.Split(domain, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return domain
|
||||||
|
}
|
||||||
|
tail := maskHostTail(parts)
|
||||||
|
numStars := len(parts) - len(tail)
|
||||||
|
if numStars < 1 {
|
||||||
|
numStars = 1
|
||||||
|
}
|
||||||
|
stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".")
|
||||||
|
return stars + "." + strings.Join(tail, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
|
||||||
|
// Example:
|
||||||
|
// http://example.com -> http://***.com
|
||||||
|
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
|
||||||
|
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
|
||||||
|
// 192.168.1.1 -> ***.***.***.***
|
||||||
|
// openai.com -> ***.com
|
||||||
|
// www.openai.com -> ***.***.com
|
||||||
|
// api.openai.com -> ***.***.com
|
||||||
|
func MaskSensitiveInfo(str string) string {
|
||||||
|
// Mask URLs
|
||||||
|
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
||||||
|
str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
|
||||||
|
u, err := url.Parse(urlStr)
|
||||||
|
if err != nil {
|
||||||
|
return urlStr
|
||||||
|
}
|
||||||
|
|
||||||
|
host := u.Host
|
||||||
|
if host == "" {
|
||||||
|
return urlStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mask host with unified logic
|
||||||
|
maskedHost := maskHostForURL(host)
|
||||||
|
|
||||||
|
result := u.Scheme + "://" + maskedHost
|
||||||
|
|
||||||
|
// Mask path
|
||||||
|
if u.Path != "" && u.Path != "/" {
|
||||||
|
pathParts := strings.Split(strings.Trim(u.Path, "/"), "/")
|
||||||
|
maskedPathParts := make([]string, len(pathParts))
|
||||||
|
for i := range pathParts {
|
||||||
|
if pathParts[i] != "" {
|
||||||
|
maskedPathParts[i] = "***"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(maskedPathParts) > 0 {
|
||||||
|
result += "/" + strings.Join(maskedPathParts, "/")
|
||||||
|
}
|
||||||
|
} else if u.Path == "/" {
|
||||||
|
result += "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mask query parameters
|
||||||
|
if u.RawQuery != "" {
|
||||||
|
values, err := url.ParseQuery(u.RawQuery)
|
||||||
|
if err != nil {
|
||||||
|
// If can't parse query, just mask the whole query string
|
||||||
|
result += "?***"
|
||||||
|
} else {
|
||||||
|
maskedParams := make([]string, 0, len(values))
|
||||||
|
for key := range values {
|
||||||
|
maskedParams = append(maskedParams, key+"=***")
|
||||||
|
}
|
||||||
|
if len(maskedParams) > 0 {
|
||||||
|
result += "?" + strings.Join(maskedParams, "&")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
})
|
||||||
|
|
||||||
|
// Mask domain names without protocol (like openai.com, www.openai.com)
|
||||||
|
domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
|
||||||
|
str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
|
||||||
|
return maskHostForPlainDomain(domain)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Mask IP addresses
|
||||||
|
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
||||||
|
str = ipPattern.ReplaceAllString(str, "***.***.***.***")
|
||||||
|
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|||||||
24
common/sys_log.go
Normal file
24
common/sys_log.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SysLog(s string) {
|
||||||
|
t := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SysError(s string) {
|
||||||
|
t := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func FatalLog(v ...any) {
|
||||||
|
t := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
150
common/totp.go
Normal file
150
common/totp.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pquerna/otp"
|
||||||
|
"github.com/pquerna/otp/totp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// 备用码配置
|
||||||
|
BackupCodeLength = 8 // 备用码长度
|
||||||
|
BackupCodeCount = 4 // 生成备用码数量
|
||||||
|
|
||||||
|
// 限制配置
|
||||||
|
MaxFailAttempts = 5 // 最大失败尝试次数
|
||||||
|
LockoutDuration = 300 // 锁定时间(秒)
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateTOTPSecret 生成TOTP密钥和配置
|
||||||
|
func GenerateTOTPSecret(accountName string) (*otp.Key, error) {
|
||||||
|
issuer := Get2FAIssuer()
|
||||||
|
return totp.Generate(totp.GenerateOpts{
|
||||||
|
Issuer: issuer,
|
||||||
|
AccountName: accountName,
|
||||||
|
Period: 30,
|
||||||
|
Digits: otp.DigitsSix,
|
||||||
|
Algorithm: otp.AlgorithmSHA1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateTOTPCode 验证TOTP验证码
|
||||||
|
func ValidateTOTPCode(secret, code string) bool {
|
||||||
|
// 清理验证码格式
|
||||||
|
cleanCode := strings.ReplaceAll(code, " ", "")
|
||||||
|
if len(cleanCode) != 6 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证验证码
|
||||||
|
return totp.Validate(cleanCode, secret)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateBackupCodes 生成备用恢复码
|
||||||
|
func GenerateBackupCodes() ([]string, error) {
|
||||||
|
codes := make([]string, BackupCodeCount)
|
||||||
|
|
||||||
|
for i := 0; i < BackupCodeCount; i++ {
|
||||||
|
code, err := generateRandomBackupCode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
codes[i] = code
|
||||||
|
}
|
||||||
|
|
||||||
|
return codes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRandomBackupCode 生成单个备用码
|
||||||
|
func generateRandomBackupCode() (string, error) {
|
||||||
|
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
|
code := make([]byte, BackupCodeLength)
|
||||||
|
|
||||||
|
for i := range code {
|
||||||
|
randomBytes := make([]byte, 1)
|
||||||
|
_, err := rand.Read(randomBytes)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
code[i] = charset[int(randomBytes[0])%len(charset)]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 格式化为 XXXX-XXXX 格式
|
||||||
|
return fmt.Sprintf("%s-%s", string(code[:4]), string(code[4:])), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateBackupCode 验证备用码格式
|
||||||
|
func ValidateBackupCode(code string) bool {
|
||||||
|
// 移除所有分隔符并转为大写
|
||||||
|
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
|
||||||
|
if len(cleanCode) != BackupCodeLength {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查字符是否合法
|
||||||
|
for _, char := range cleanCode {
|
||||||
|
if !((char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9')) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeBackupCode 标准化备用码格式
|
||||||
|
func NormalizeBackupCode(code string) string {
|
||||||
|
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
|
||||||
|
if len(cleanCode) == BackupCodeLength {
|
||||||
|
return fmt.Sprintf("%s-%s", cleanCode[:4], cleanCode[4:])
|
||||||
|
}
|
||||||
|
return code
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashBackupCode 对备用码进行哈希
|
||||||
|
func HashBackupCode(code string) (string, error) {
|
||||||
|
normalizedCode := NormalizeBackupCode(code)
|
||||||
|
return Password2Hash(normalizedCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get2FAIssuer 获取2FA发行者名称
|
||||||
|
func Get2FAIssuer() string {
|
||||||
|
return SystemName
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEnvOrDefault 获取环境变量或默认值
|
||||||
|
func getEnvOrDefault(key, defaultValue string) string {
|
||||||
|
if value, exists := os.LookupEnv(key); exists {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateNumericCode 验证数字验证码格式
|
||||||
|
func ValidateNumericCode(code string) (string, error) {
|
||||||
|
// 移除空格
|
||||||
|
code = strings.ReplaceAll(code, " ", "")
|
||||||
|
|
||||||
|
if len(code) != 6 {
|
||||||
|
return "", fmt.Errorf("验证码必须是6位数字")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否为纯数字
|
||||||
|
if _, err := strconv.Atoi(code); err != nil {
|
||||||
|
return "", fmt.Errorf("验证码只能包含数字")
|
||||||
|
}
|
||||||
|
|
||||||
|
return code, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateQRCodeData 生成二维码数据
|
||||||
|
func GenerateQRCodeData(secret, username string) string {
|
||||||
|
issuer := Get2FAIssuer()
|
||||||
|
accountName := fmt.Sprintf("%s (%s)", username, issuer)
|
||||||
|
return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&digits=6&period=30",
|
||||||
|
issuer, accountName, secret, issuer)
|
||||||
|
}
|
||||||
@@ -123,8 +123,16 @@ func Interface2String(inter interface{}) string {
|
|||||||
return fmt.Sprintf("%d", inter.(int))
|
return fmt.Sprintf("%d", inter.(int))
|
||||||
case float64:
|
case float64:
|
||||||
return fmt.Sprintf("%f", inter.(float64))
|
return fmt.Sprintf("%f", inter.(float64))
|
||||||
|
case bool:
|
||||||
|
if inter.(bool) {
|
||||||
|
return "true"
|
||||||
|
} else {
|
||||||
|
return "false"
|
||||||
|
}
|
||||||
|
case nil:
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
return "Not Implemented"
|
return fmt.Sprintf("%v", inter)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnescapeHTML(x string) interface{} {
|
func UnescapeHTML(x string) interface{} {
|
||||||
@@ -257,32 +265,32 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to get audio duration")
|
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||||
}
|
}
|
||||||
durationStr := string(bytes.TrimSpace(output))
|
durationStr := string(bytes.TrimSpace(output))
|
||||||
if durationStr == "N/A" {
|
if durationStr == "N/A" {
|
||||||
// Create a temporary output file name
|
// Create a temporary output file name
|
||||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to create temporary file")
|
return 0, errors.Wrap(err, "failed to create temporary file")
|
||||||
}
|
}
|
||||||
tmpName := tmpFp.Name()
|
tmpName := tmpFp.Name()
|
||||||
// Close immediately so ffmpeg can open the file on Windows.
|
// Close immediately so ffmpeg can open the file on Windows.
|
||||||
_ = tmpFp.Close()
|
_ = tmpFp.Close()
|
||||||
defer os.Remove(tmpName)
|
defer os.Remove(tmpName)
|
||||||
|
|
||||||
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
||||||
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
||||||
if err := ffmpegCmd.Run(); err != nil {
|
if err := ffmpegCmd.Run(); err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recalculate the duration of the new file
|
// Recalculate the duration of the new file
|
||||||
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
||||||
output, err := c.Output()
|
output, err := c.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
||||||
}
|
}
|
||||||
durationStr = string(bytes.TrimSpace(output))
|
durationStr = string(bytes.TrimSpace(output))
|
||||||
}
|
}
|
||||||
return strconv.ParseFloat(durationStr, 64)
|
return strconv.ParseFloat(durationStr, 64)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -30,5 +30,7 @@ const (
|
|||||||
APITypeXinference
|
APITypeXinference
|
||||||
APITypeXai
|
APITypeXai
|
||||||
APITypeCoze
|
APITypeCoze
|
||||||
APITypeDummy // this one is only for count, do not add any channel after this
|
APITypeJimeng
|
||||||
|
APITypeMoonshot // this one is only for count, do not add any channel after this
|
||||||
|
APITypeDummy // this one is only for count, do not add any channel after this
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ const (
|
|||||||
ChannelTypeCoze = 49
|
ChannelTypeCoze = 49
|
||||||
ChannelTypeKling = 50
|
ChannelTypeKling = 50
|
||||||
ChannelTypeJimeng = 51
|
ChannelTypeJimeng = 51
|
||||||
|
ChannelTypeVidu = 52
|
||||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
|
||||||
)
|
)
|
||||||
@@ -106,4 +107,5 @@ var ChannelBaseURLs = []string{
|
|||||||
"https://api.coze.cn", //49
|
"https://api.coze.cn", //49
|
||||||
"https://api.klingai.com", //50
|
"https://api.klingai.com", //50
|
||||||
"https://visual.volcengineapi.com", //51
|
"https://visual.volcengineapi.com", //51
|
||||||
|
"https://api.vidu.cn", //52
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ package constant
|
|||||||
type ContextKey string
|
type ContextKey string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
ContextKeyTokenCountMeta ContextKey = "token_count_meta"
|
||||||
|
ContextKeyPromptTokens ContextKey = "prompt_tokens"
|
||||||
|
|
||||||
ContextKeyOriginalModel ContextKey = "original_model"
|
ContextKeyOriginalModel ContextKey = "original_model"
|
||||||
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
||||||
|
|
||||||
@@ -11,7 +14,6 @@ const (
|
|||||||
ContextKeyTokenKey ContextKey = "token_key"
|
ContextKeyTokenKey ContextKey = "token_key"
|
||||||
ContextKeyTokenId ContextKey = "token_id"
|
ContextKeyTokenId ContextKey = "token_id"
|
||||||
ContextKeyTokenGroup ContextKey = "token_group"
|
ContextKeyTokenGroup ContextKey = "token_group"
|
||||||
ContextKeyTokenAllowIps ContextKey = "allow_ips"
|
|
||||||
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
||||||
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
||||||
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||||
@@ -19,16 +21,20 @@ const (
|
|||||||
/* channel related keys */
|
/* channel related keys */
|
||||||
ContextKeyChannelId ContextKey = "channel_id"
|
ContextKeyChannelId ContextKey = "channel_id"
|
||||||
ContextKeyChannelName ContextKey = "channel_name"
|
ContextKeyChannelName ContextKey = "channel_name"
|
||||||
ContextKeyChannelCreateTime ContextKey = "channel_create_name"
|
ContextKeyChannelCreateTime ContextKey = "channel_create_time"
|
||||||
ContextKeyChannelBaseUrl ContextKey = "base_url"
|
ContextKeyChannelBaseUrl ContextKey = "base_url"
|
||||||
ContextKeyChannelType ContextKey = "channel_type"
|
ContextKeyChannelType ContextKey = "channel_type"
|
||||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||||
|
ContextKeyChannelOtherSetting ContextKey = "channel_other_setting"
|
||||||
ContextKeyChannelParamOverride ContextKey = "param_override"
|
ContextKeyChannelParamOverride ContextKey = "param_override"
|
||||||
|
ContextKeyChannelHeaderOverride ContextKey = "header_override"
|
||||||
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
||||||
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
||||||
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
||||||
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
|
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
|
||||||
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
|
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
|
||||||
|
ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
|
||||||
|
ContextKeyChannelKey ContextKey = "channel_key"
|
||||||
|
|
||||||
/* user related keys */
|
/* user related keys */
|
||||||
ContextKeyUserId ContextKey = "id"
|
ContextKeyUserId ContextKey = "id"
|
||||||
@@ -39,4 +45,6 @@ const (
|
|||||||
ContextKeyUserGroup ContextKey = "user_group"
|
ContextKeyUserGroup ContextKey = "user_group"
|
||||||
ContextKeyUsingGroup ContextKey = "group"
|
ContextKeyUsingGroup ContextKey = "group"
|
||||||
ContextKeyUserName ContextKey = "username"
|
ContextKeyUserName ContextKey = "username"
|
||||||
|
|
||||||
|
ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,16 +5,16 @@ type TaskPlatform string
|
|||||||
const (
|
const (
|
||||||
TaskPlatformSuno TaskPlatform = "suno"
|
TaskPlatformSuno TaskPlatform = "suno"
|
||||||
TaskPlatformMidjourney = "mj"
|
TaskPlatformMidjourney = "mj"
|
||||||
TaskPlatformKling TaskPlatform = "kling"
|
|
||||||
TaskPlatformJimeng TaskPlatform = "jimeng"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SunoActionMusic = "MUSIC"
|
SunoActionMusic = "MUSIC"
|
||||||
SunoActionLyrics = "LYRICS"
|
SunoActionLyrics = "LYRICS"
|
||||||
|
|
||||||
TaskActionGenerate = "generate"
|
TaskActionGenerate = "generate"
|
||||||
TaskActionTextGenerate = "textGenerate"
|
TaskActionTextGenerate = "textGenerate"
|
||||||
|
TaskActionFirstTailGenerate = "firstTailGenerate"
|
||||||
|
TaskActionReferenceGenerate = "referenceGenerate"
|
||||||
)
|
)
|
||||||
|
|
||||||
var SunoModel2Action = map[string]string{
|
var SunoModel2Action = map[string]string{
|
||||||
|
|||||||
@@ -4,17 +4,19 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/shopspring/decimal"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting/operation_setting"
|
||||||
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -133,7 +135,11 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
|
|||||||
for k := range headers {
|
for k := range headers {
|
||||||
req.Header.Add(k, headers.Get(k))
|
req.Header.Add(k, headers.Get(k))
|
||||||
}
|
}
|
||||||
res, err := service.GetHttpClient().Do(req)
|
client, err := service.NewProxyHttpClient(channel.GetSetting().Proxy)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -336,7 +342,7 @@ func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
|||||||
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
||||||
}
|
}
|
||||||
availableBalanceCny := response.Data.AvailableBalance
|
availableBalanceCny := response.Data.AvailableBalance
|
||||||
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
|
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(operation_setting.Price)).InexactFloat64()
|
||||||
channel.UpdateBalance(availableBalanceUsd)
|
channel.UpdateBalance(availableBalanceUsd)
|
||||||
return availableBalanceUsd, nil
|
return availableBalanceUsd, nil
|
||||||
}
|
}
|
||||||
@@ -409,26 +415,24 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
|||||||
func UpdateChannelBalance(c *gin.Context) {
|
func UpdateChannelBalance(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel, err := model.GetChannelById(id, true)
|
channel, err := model.CacheGetChannel(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": "多密钥渠道不支持余额查询",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
balance, err := updateChannelBalance(channel)
|
balance, err := updateChannelBalance(channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -436,7 +440,6 @@ func UpdateChannelBalance(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"balance": balance,
|
"balance": balance,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateAllChannelsBalance() error {
|
func updateAllChannelsBalance() error {
|
||||||
@@ -448,6 +451,9 @@ func updateAllChannelsBalance() error {
|
|||||||
if channel.Status != common.ChannelStatusEnabled {
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
|
continue // skip multi-key channels
|
||||||
|
}
|
||||||
// TODO: support Azure
|
// TODO: support Azure
|
||||||
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
|
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
|
||||||
// continue
|
// continue
|
||||||
@@ -458,7 +464,7 @@ func updateAllChannelsBalance() error {
|
|||||||
} else {
|
} else {
|
||||||
// err is nil & balance <= 0 means quota is used up
|
// err is nil & balance <= 0 means quota is used up
|
||||||
if balance <= 0 {
|
if balance <= 0 {
|
||||||
service.DisableChannel(channel.Id, channel.Name, "余额不足")
|
service.DisableChannel(*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, "", channel.GetAutoBan()), "余额不足")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
time.Sleep(common.RequestInterval)
|
time.Sleep(common.RequestInterval)
|
||||||
@@ -470,10 +476,7 @@ func UpdateAllChannelsBalance(c *gin.Context) {
|
|||||||
// TODO: make it async
|
// TODO: make it async
|
||||||
err := updateAllChannelsBalance()
|
err := updateAllChannelsBalance()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -17,8 +17,10 @@ import (
|
|||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/setting/operation_setting"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -30,22 +32,49 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, testModel string) (err error, newAPIError *types.NewAPIError) {
|
type testResult struct {
|
||||||
|
context *gin.Context
|
||||||
|
localErr error
|
||||||
|
newAPIError *types.NewAPIError
|
||||||
|
}
|
||||||
|
|
||||||
|
func testChannel(channel *model.Channel, testModel string) testResult {
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
if channel.Type == constant.ChannelTypeMidjourney {
|
if channel.Type == constant.ChannelTypeMidjourney {
|
||||||
return errors.New("midjourney channel test is not supported"), nil
|
return testResult{
|
||||||
|
localErr: errors.New("midjourney channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
||||||
return errors.New("midjourney plus channel test is not supported"), nil
|
return testResult{
|
||||||
|
localErr: errors.New("midjourney plus channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if channel.Type == constant.ChannelTypeSunoAPI {
|
if channel.Type == constant.ChannelTypeSunoAPI {
|
||||||
return errors.New("suno channel test is not supported"), nil
|
return testResult{
|
||||||
|
localErr: errors.New("suno channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if channel.Type == constant.ChannelTypeKling {
|
if channel.Type == constant.ChannelTypeKling {
|
||||||
return errors.New("kling channel test is not supported"), nil
|
return testResult{
|
||||||
|
localErr: errors.New("kling channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if channel.Type == constant.ChannelTypeJimeng {
|
if channel.Type == constant.ChannelTypeJimeng {
|
||||||
return errors.New("jimeng channel test is not supported"), nil
|
return testResult{
|
||||||
|
localErr: errors.New("jimeng channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if channel.Type == constant.ChannelTypeVidu {
|
||||||
|
return testResult{
|
||||||
|
localErr: errors.New("vidu channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
@@ -61,6 +90,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
|
|||||||
requestPath = "/v1/embeddings" // 修改请求路径
|
requestPath = "/v1/embeddings" // 修改请求路径
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// VolcEngine 图像生成模型
|
||||||
|
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
|
||||||
|
requestPath = "/v1/images/generations"
|
||||||
|
}
|
||||||
|
|
||||||
c.Request = &http.Request{
|
c.Request = &http.Request{
|
||||||
Method: "POST",
|
Method: "POST",
|
||||||
URL: &url.URL{Path: requestPath}, // 使用动态路径
|
URL: &url.URL{Path: requestPath}, // 使用动态路径
|
||||||
@@ -80,82 +114,198 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 重新检查模型类型并更新请求路径
|
||||||
|
if strings.Contains(strings.ToLower(testModel), "embedding") ||
|
||||||
|
strings.HasPrefix(testModel, "m3e") ||
|
||||||
|
strings.Contains(testModel, "bge-") ||
|
||||||
|
strings.Contains(testModel, "embed") ||
|
||||||
|
channel.Type == constant.ChannelTypeMokaAI {
|
||||||
|
requestPath = "/v1/embeddings"
|
||||||
|
c.Request.URL.Path = requestPath
|
||||||
|
}
|
||||||
|
|
||||||
|
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
|
||||||
|
requestPath = "/v1/images/generations"
|
||||||
|
c.Request.URL.Path = requestPath
|
||||||
|
}
|
||||||
|
|
||||||
cache, err := model.GetUserCache(1)
|
cache, err := model.GetUserCache(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return testResult{
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
cache.WriteContext(c)
|
cache.WriteContext(c)
|
||||||
|
|
||||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||||
c.Request.Header.Set("Content-Type", "application/json")
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
c.Set("channel", channel.Type)
|
c.Set("channel", channel.Type)
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
group, _ := model.GetUserGroup(1, false)
|
group, _ := model.GetUserGroup(1, false)
|
||||||
c.Set("group", group)
|
c.Set("group", group)
|
||||||
|
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||||
|
if newAPIError != nil {
|
||||||
info := relaycommon.GenRelayInfo(c)
|
return testResult{
|
||||||
|
context: c,
|
||||||
err = helper.ModelMappedHelper(c, info, nil)
|
localErr: newAPIError,
|
||||||
if err != nil {
|
newAPIError: newAPIError,
|
||||||
return err, types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
}
|
||||||
}
|
}
|
||||||
|
request := buildTestRequest(testModel)
|
||||||
|
|
||||||
|
// Determine relay format based on request path
|
||||||
|
relayFormat := types.RelayFormatOpenAI
|
||||||
|
if c.Request.URL.Path == "/v1/embeddings" {
|
||||||
|
relayFormat = types.RelayFormatEmbedding
|
||||||
|
}
|
||||||
|
if c.Request.URL.Path == "/v1/images/generations" {
|
||||||
|
relayFormat = types.RelayFormatOpenAIImage
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info.InitChannelMeta(c)
|
||||||
|
|
||||||
|
err = helper.ModelMappedHelper(c, info, request)
|
||||||
|
if err != nil {
|
||||||
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
testModel = info.UpstreamModelName
|
testModel = info.UpstreamModelName
|
||||||
|
request.Model = testModel
|
||||||
|
|
||||||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
|
||||||
|
newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
request := buildTestRequest(testModel)
|
//// 创建一个用于日志的 info 副本,移除 ApiKey
|
||||||
// 创建一个用于日志的 info 副本,移除 ApiKey
|
//logInfo := info
|
||||||
logInfo := *info
|
//logInfo.ApiKey = ""
|
||||||
logInfo.ApiKey = ""
|
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString()))
|
||||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
|
|
||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, types.NewError(err, types.ErrorCodeModelPriceError)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
adaptor.Init(info)
|
adaptor.Init(info)
|
||||||
|
|
||||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
var convertedRequest any
|
||||||
|
// 根据 RelayMode 选择正确的转换函数
|
||||||
|
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||||
|
// 创建一个 EmbeddingRequest
|
||||||
|
embeddingRequest := dto.EmbeddingRequest{
|
||||||
|
Input: request.Input,
|
||||||
|
Model: request.Model,
|
||||||
|
}
|
||||||
|
// 调用专门用于 Embedding 的转换函数
|
||||||
|
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
|
||||||
|
} else if info.RelayMode == relayconstant.RelayModeImagesGenerations {
|
||||||
|
// 创建一个 ImageRequest
|
||||||
|
prompt := "cat"
|
||||||
|
if request.Prompt != nil {
|
||||||
|
if promptStr, ok := request.Prompt.(string); ok && promptStr != "" {
|
||||||
|
prompt = promptStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
imageRequest := dto.ImageRequest{
|
||||||
|
Prompt: prompt,
|
||||||
|
Model: request.Model,
|
||||||
|
N: uint(request.N),
|
||||||
|
Size: request.Size,
|
||||||
|
}
|
||||||
|
// 调用专门用于图像生成的转换函数
|
||||||
|
convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest)
|
||||||
|
} else {
|
||||||
|
// 对其他所有请求类型(如 Chat),保持原有逻辑
|
||||||
|
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := json.Marshal(convertedRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, types.NewError(err, types.ErrorCodeJsonMarshalFailed)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
requestBody := bytes.NewBuffer(jsonData)
|
requestBody := bytes.NewBuffer(jsonData)
|
||||||
c.Request.Body = io.NopCloser(requestBody)
|
c.Request.Body = io.NopCloser(requestBody)
|
||||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, types.NewError(err, types.ErrorCodeDoRequestFailed)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
var httpResp *http.Response
|
var httpResp *http.Response
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
err := service.RelayErrorHandler(httpResp, true)
|
err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
|
||||||
return err, types.NewError(err, types.ErrorCodeBadResponse)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
||||||
if respErr != nil {
|
if respErr != nil {
|
||||||
return respErr, respErr
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: respErr,
|
||||||
|
newAPIError: respErr,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if usageA == nil {
|
if usageA == nil {
|
||||||
return errors.New("usage is nil"), types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: errors.New("usage is nil"),
|
||||||
|
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
usage := usageA.(*dto.Usage)
|
usage := usageA.(*dto.Usage)
|
||||||
result := w.Result()
|
result := w.Result()
|
||||||
respBody, err := io.ReadAll(result.Body)
|
respBody, err := io.ReadAll(result.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
info.PromptTokens = usage.PromptTokens
|
info.PromptTokens = usage.PromptTokens
|
||||||
|
|
||||||
@@ -183,12 +333,16 @@ func testChannel(channel *model.Channel, testModel string) (err error, newAPIErr
|
|||||||
Quota: quota,
|
Quota: quota,
|
||||||
Content: "模型测试",
|
Content: "模型测试",
|
||||||
UseTimeSeconds: int(consumedTime),
|
UseTimeSeconds: int(consumedTime),
|
||||||
IsStream: false,
|
IsStream: info.IsStream,
|
||||||
Group: info.UsingGroup,
|
Group: info.UsingGroup,
|
||||||
Other: other,
|
Other: other,
|
||||||
})
|
})
|
||||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||||
return nil, nil
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: nil,
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||||
@@ -203,7 +357,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
|||||||
strings.Contains(model, "bge-") {
|
strings.Contains(model, "bge-") {
|
||||||
testRequest.Model = model
|
testRequest.Model = model
|
||||||
// Embedding 请求
|
// Embedding 请求
|
||||||
testRequest.Input = []string{"hello world"}
|
testRequest.Input = []any{"hello world"} // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
|
||||||
return testRequest
|
return testRequest
|
||||||
}
|
}
|
||||||
// 并非Embedding 模型
|
// 并非Embedding 模型
|
||||||
@@ -231,31 +385,41 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
|||||||
func TestChannel(c *gin.Context) {
|
func TestChannel(c *gin.Context) {
|
||||||
channelId, err := strconv.Atoi(c.Param("id"))
|
channelId, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel, err := model.GetChannelById(channelId, true)
|
channel, err := model.CacheGetChannel(channelId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
channel, err = model.GetChannelById(channelId, true)
|
||||||
"success": false,
|
if err != nil {
|
||||||
"message": err.Error(),
|
common.ApiError(c, err)
|
||||||
})
|
return
|
||||||
return
|
}
|
||||||
}
|
}
|
||||||
|
//defer func() {
|
||||||
|
// if channel.ChannelInfo.IsMultiKey {
|
||||||
|
// go func() { _ = channel.SaveChannelInfo() }()
|
||||||
|
// }
|
||||||
|
//}()
|
||||||
testModel := c.Query("model")
|
testModel := c.Query("model")
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
_, newAPIError := testChannel(channel, testModel)
|
result := testChannel(channel, testModel)
|
||||||
|
if result.localErr != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": result.localErr.Error(),
|
||||||
|
"time": 0.0,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
go channel.UpdateResponseTime(milliseconds)
|
go channel.UpdateResponseTime(milliseconds)
|
||||||
consumedTime := float64(milliseconds) / 1000.0
|
consumedTime := float64(milliseconds) / 1000.0
|
||||||
if newAPIError != nil {
|
if result.newAPIError != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": newAPIError.Error(),
|
"message": result.newAPIError.Error(),
|
||||||
"time": consumedTime,
|
"time": consumedTime,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
@@ -280,9 +444,9 @@ func testAllChannels(notify bool) error {
|
|||||||
}
|
}
|
||||||
testAllChannelsRunning = true
|
testAllChannelsRunning = true
|
||||||
testAllChannelsLock.Unlock()
|
testAllChannelsLock.Unlock()
|
||||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
|
||||||
if err != nil {
|
if getChannelErr != nil {
|
||||||
return err
|
return getChannelErr
|
||||||
}
|
}
|
||||||
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
||||||
if disableThreshold == 0 {
|
if disableThreshold == 0 {
|
||||||
@@ -299,30 +463,34 @@ func testAllChannels(notify bool) error {
|
|||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
err, newAPIError := testChannel(channel, "")
|
result := testChannel(channel, "")
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
|
|
||||||
shouldBanChannel := false
|
shouldBanChannel := false
|
||||||
|
newAPIError := result.newAPIError
|
||||||
// request error disables the channel
|
// request error disables the channel
|
||||||
if err != nil {
|
if newAPIError != nil {
|
||||||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, newAPIError)
|
shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
|
||||||
}
|
}
|
||||||
|
|
||||||
if milliseconds > disableThreshold {
|
// 当错误检查通过,才检查响应时间
|
||||||
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
||||||
shouldBanChannel = true
|
if milliseconds > disableThreshold {
|
||||||
|
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||||
|
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
|
||||||
|
shouldBanChannel = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// disable channel
|
// disable channel
|
||||||
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
||||||
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// enable channel
|
// enable channel
|
||||||
if !isChannelEnabled && service.ShouldEnableChannel(err, newAPIError, channel.Status) {
|
if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
|
||||||
service.EnableChannel(channel.Id, channel.Name)
|
service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
channel.UpdateResponseTime(milliseconds)
|
channel.UpdateResponseTime(milliseconds)
|
||||||
@@ -339,10 +507,7 @@ func testAllChannels(notify bool) error {
|
|||||||
func TestAllChannels(c *gin.Context) {
|
func TestAllChannels(c *gin.Context) {
|
||||||
err := testAllChannels(true)
|
err := testAllChannels(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -352,15 +517,26 @@ func TestAllChannels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func AutomaticallyTestChannels(frequency int) {
|
var autoTestChannelsOnce sync.Once
|
||||||
if frequency <= 0 {
|
|
||||||
common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
|
func AutomaticallyTestChannels() {
|
||||||
return
|
autoTestChannelsOnce.Do(func() {
|
||||||
}
|
for {
|
||||||
for {
|
if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
|
||||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
time.Sleep(10 * time.Minute)
|
||||||
common.SysLog("testing all channels")
|
continue
|
||||||
_ = testAllChannels(false)
|
}
|
||||||
common.SysLog("channel test finished")
|
frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes
|
||||||
}
|
common.SysLog(fmt.Sprintf("automatically test channels with interval %d minutes", frequency))
|
||||||
|
for {
|
||||||
|
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||||
|
common.SysLog("automatically testing all channels")
|
||||||
|
_ = testAllChannels(false)
|
||||||
|
common.SysLog("automatically channel test finished")
|
||||||
|
if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -3,101 +3,102 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
|
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
|
||||||
func MigrateConsoleSetting(c *gin.Context) {
|
func MigrateConsoleSetting(c *gin.Context) {
|
||||||
// 读取全部 option
|
// 读取全部 option
|
||||||
opts, err := model.AllOption()
|
opts, err := model.AllOption()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 建立 map
|
// 建立 map
|
||||||
valMap := map[string]string{}
|
valMap := map[string]string{}
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
valMap[o.Key] = o.Value
|
valMap[o.Key] = o.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理 APIInfo
|
// 处理 APIInfo
|
||||||
if v := valMap["ApiInfo"]; v != "" {
|
if v := valMap["ApiInfo"]; v != "" {
|
||||||
var arr []map[string]interface{}
|
var arr []map[string]interface{}
|
||||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||||
if len(arr) > 50 {
|
if len(arr) > 50 {
|
||||||
arr = arr[:50]
|
arr = arr[:50]
|
||||||
}
|
}
|
||||||
bytes, _ := json.Marshal(arr)
|
bytes, _ := json.Marshal(arr)
|
||||||
model.UpdateOption("console_setting.api_info", string(bytes))
|
model.UpdateOption("console_setting.api_info", string(bytes))
|
||||||
}
|
}
|
||||||
model.UpdateOption("ApiInfo", "")
|
model.UpdateOption("ApiInfo", "")
|
||||||
}
|
}
|
||||||
// Announcements 直接搬
|
// Announcements 直接搬
|
||||||
if v := valMap["Announcements"]; v != "" {
|
if v := valMap["Announcements"]; v != "" {
|
||||||
model.UpdateOption("console_setting.announcements", v)
|
model.UpdateOption("console_setting.announcements", v)
|
||||||
model.UpdateOption("Announcements", "")
|
model.UpdateOption("Announcements", "")
|
||||||
}
|
}
|
||||||
// FAQ 转换
|
// FAQ 转换
|
||||||
if v := valMap["FAQ"]; v != "" {
|
if v := valMap["FAQ"]; v != "" {
|
||||||
var arr []map[string]interface{}
|
var arr []map[string]interface{}
|
||||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||||
out := []map[string]interface{}{}
|
out := []map[string]interface{}{}
|
||||||
for _, item := range arr {
|
for _, item := range arr {
|
||||||
q, _ := item["question"].(string)
|
q, _ := item["question"].(string)
|
||||||
if q == "" {
|
if q == "" {
|
||||||
q, _ = item["title"].(string)
|
q, _ = item["title"].(string)
|
||||||
}
|
}
|
||||||
a, _ := item["answer"].(string)
|
a, _ := item["answer"].(string)
|
||||||
if a == "" {
|
if a == "" {
|
||||||
a, _ = item["content"].(string)
|
a, _ = item["content"].(string)
|
||||||
}
|
}
|
||||||
if q != "" && a != "" {
|
if q != "" && a != "" {
|
||||||
out = append(out, map[string]interface{}{"question": q, "answer": a})
|
out = append(out, map[string]interface{}{"question": q, "answer": a})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(out) > 50 {
|
if len(out) > 50 {
|
||||||
out = out[:50]
|
out = out[:50]
|
||||||
}
|
}
|
||||||
bytes, _ := json.Marshal(out)
|
bytes, _ := json.Marshal(out)
|
||||||
model.UpdateOption("console_setting.faq", string(bytes))
|
model.UpdateOption("console_setting.faq", string(bytes))
|
||||||
}
|
}
|
||||||
model.UpdateOption("FAQ", "")
|
model.UpdateOption("FAQ", "")
|
||||||
}
|
}
|
||||||
// Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
|
// Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
|
||||||
url := valMap["UptimeKumaUrl"]
|
url := valMap["UptimeKumaUrl"]
|
||||||
slug := valMap["UptimeKumaSlug"]
|
slug := valMap["UptimeKumaSlug"]
|
||||||
if url != "" && slug != "" {
|
if url != "" && slug != "" {
|
||||||
// 仅当同时存在 URL 与 Slug 时才进行迁移
|
// 仅当同时存在 URL 与 Slug 时才进行迁移
|
||||||
groups := []map[string]interface{}{
|
groups := []map[string]interface{}{
|
||||||
{
|
{
|
||||||
"id": 1,
|
"id": 1,
|
||||||
"categoryName": "old",
|
"categoryName": "old",
|
||||||
"url": url,
|
"url": url,
|
||||||
"slug": slug,
|
"slug": slug,
|
||||||
"description": "",
|
"description": "",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
bytes, _ := json.Marshal(groups)
|
bytes, _ := json.Marshal(groups)
|
||||||
model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
|
model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
|
||||||
}
|
}
|
||||||
// 清空旧键内容
|
// 清空旧键内容
|
||||||
if url != "" {
|
if url != "" {
|
||||||
model.UpdateOption("UptimeKumaUrl", "")
|
model.UpdateOption("UptimeKumaUrl", "")
|
||||||
}
|
}
|
||||||
if slug != "" {
|
if slug != "" {
|
||||||
model.UpdateOption("UptimeKumaSlug", "")
|
model.UpdateOption("UptimeKumaSlug", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 删除旧键记录
|
// 删除旧键记录
|
||||||
oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
|
oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
|
||||||
model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
|
model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
|
||||||
|
|
||||||
// 重新加载 OptionMap
|
// 重新加载 OptionMap
|
||||||
model.InitOptionMap()
|
model.InitOptionMap()
|
||||||
common.SysLog("console setting migrated")
|
common.SysLog("console setting migrated")
|
||||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
|
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,13 +5,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-contrib/sessions"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GitHubOAuthResponse struct {
|
type GitHubOAuthResponse struct {
|
||||||
@@ -103,10 +104,7 @@ func GitHubOAuth(c *gin.Context) {
|
|||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
githubUser, err := getGitHubUserInfoByCode(code)
|
githubUser, err := getGitHubUserInfoByCode(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := model.User{
|
user := model.User{
|
||||||
@@ -185,10 +183,7 @@ func GitHubBind(c *gin.Context) {
|
|||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
githubUser, err := getGitHubUserInfoByCode(code)
|
githubUser, err := getGitHubUserInfoByCode(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := model.User{
|
user := model.User{
|
||||||
@@ -207,19 +202,13 @@ func GitHubBind(c *gin.Context) {
|
|||||||
user.Id = id.(int)
|
user.Id = id.(int)
|
||||||
err = user.FillUserById()
|
err = user.FillUserById()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.GitHubId = githubUser.Login
|
user.GitHubId = githubUser.Login
|
||||||
err = user.Update(false)
|
err = user.Update(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -239,10 +228,7 @@ func GenerateOAuthCode(c *gin.Context) {
|
|||||||
session.Set("oauth_state", state)
|
session.Set("oauth_state", state)
|
||||||
err := session.Save()
|
err := session.Save()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -38,10 +38,7 @@ func LinuxDoBind(c *gin.Context) {
|
|||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,20 +60,14 @@ func LinuxDoBind(c *gin.Context) {
|
|||||||
|
|
||||||
err = user.FillUserById()
|
err = user.FillUserById()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
|
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
|
||||||
err = user.Update(false)
|
err = user.Update(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -202,10 +193,7 @@ func LinuxdoOAuth(c *gin.Context) {
|
|||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -232,21 +220,29 @@ func LinuxdoOAuth(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if common.RegisterEnabled {
|
if common.RegisterEnabled {
|
||||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
|
||||||
user.DisplayName = linuxdoUser.Name
|
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||||
user.Role = common.RoleCommonUser
|
user.DisplayName = linuxdoUser.Name
|
||||||
user.Status = common.UserStatusEnabled
|
user.Role = common.RoleCommonUser
|
||||||
|
user.Status = common.UserStatusEnabled
|
||||||
|
|
||||||
affCode := session.Get("aff")
|
affCode := session.Get("aff")
|
||||||
inviterId := 0
|
inviterId := 0
|
||||||
if affCode != nil {
|
if affCode != nil {
|
||||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := user.Insert(inviterId); err != nil {
|
if err := user.Insert(inviterId); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,14 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func GetAllLogs(c *gin.Context) {
|
func GetAllLogs(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
|
||||||
if p < 1 {
|
|
||||||
p = 1
|
|
||||||
}
|
|
||||||
if pageSize < 0 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
logType, _ := strconv.Atoi(c.Query("type"))
|
logType, _ := strconv.Atoi(c.Query("type"))
|
||||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||||
@@ -26,38 +19,19 @@ func GetAllLogs(c *gin.Context) {
|
|||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||||
group := c.Query("group")
|
group := c.Query("group")
|
||||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel, group)
|
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
pageInfo.SetTotal(int(total))
|
||||||
"success": true,
|
pageInfo.SetItems(logs)
|
||||||
"message": "",
|
common.ApiSuccess(c, pageInfo)
|
||||||
"data": map[string]any{
|
return
|
||||||
"items": logs,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserLogs(c *gin.Context) {
|
func GetUserLogs(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
|
||||||
if p < 1 {
|
|
||||||
p = 1
|
|
||||||
}
|
|
||||||
if pageSize < 0 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
if pageSize > 100 {
|
|
||||||
pageSize = 100
|
|
||||||
}
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
logType, _ := strconv.Atoi(c.Query("type"))
|
logType, _ := strconv.Atoi(c.Query("type"))
|
||||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||||
@@ -65,24 +39,14 @@ func GetUserLogs(c *gin.Context) {
|
|||||||
tokenName := c.Query("token_name")
|
tokenName := c.Query("token_name")
|
||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
group := c.Query("group")
|
group := c.Query("group")
|
||||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize, group)
|
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
pageInfo.SetTotal(int(total))
|
||||||
"success": true,
|
pageInfo.SetItems(logs)
|
||||||
"message": "",
|
common.ApiSuccess(c, pageInfo)
|
||||||
"data": map[string]any{
|
|
||||||
"items": logs,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,10 +54,7 @@ func SearchAllLogs(c *gin.Context) {
|
|||||||
keyword := c.Query("keyword")
|
keyword := c.Query("keyword")
|
||||||
logs, err := model.SearchAllLogs(keyword)
|
logs, err := model.SearchAllLogs(keyword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -109,10 +70,7 @@ func SearchUserLogs(c *gin.Context) {
|
|||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
logs, err := model.SearchUserLogs(userId, keyword)
|
logs, err := model.SearchUserLogs(userId, keyword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -198,10 +156,7 @@ func DeleteHistoryLogs(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
|
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -5,16 +5,18 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"strconv"
|
"one-api/setting/system_setting"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UpdateMidjourneyTaskBulk() {
|
func UpdateMidjourneyTaskBulk() {
|
||||||
@@ -28,7 +30,7 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||||
taskChannelM := make(map[int][]string)
|
taskChannelM := make(map[int][]string)
|
||||||
taskM := make(map[string]*model.Midjourney)
|
taskM := make(map[string]*model.Midjourney)
|
||||||
nullTaskIds := make([]int, 0)
|
nullTaskIds := make([]int, 0)
|
||||||
@@ -47,9 +49,9 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
|
||||||
} else {
|
} else {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
|
logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(taskChannelM) == 0 {
|
if len(taskChannelM) == 0 {
|
||||||
@@ -57,20 +59,20 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for channelId, taskIds := range taskChannelM {
|
for channelId, taskIds := range taskChannelM {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||||
if len(taskIds) == 0 {
|
if len(taskIds) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
midjourneyChannel, err := model.CacheGetChannel(channelId)
|
midjourneyChannel, err := model.CacheGetChannel(channelId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
|
||||||
err := model.MjBulkUpdate(taskIds, map[string]any{
|
err := model.MjBulkUpdate(taskIds, map[string]any{
|
||||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||||
"status": "FAILURE",
|
"status": "FAILURE",
|
||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -81,7 +83,7 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
})
|
})
|
||||||
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
|
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 设置超时时间
|
// 设置超时时间
|
||||||
@@ -93,22 +95,22 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
||||||
resp, err := service.GetHttpClient().Do(req)
|
resp, err := service.GetHttpClient().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var responseItems []dto.MidjourneyDto
|
var responseItems []dto.MidjourneyDto
|
||||||
err = json.Unmarshal(responseBody, &responseItems)
|
err = json.Unmarshal(responseBody, &responseItems)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
@@ -145,9 +147,25 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
buttonStr, _ := json.Marshal(responseItem.Buttons)
|
buttonStr, _ := json.Marshal(responseItem.Buttons)
|
||||||
task.Buttons = string(buttonStr)
|
task.Buttons = string(buttonStr)
|
||||||
}
|
}
|
||||||
|
// 映射 VideoUrl
|
||||||
|
task.VideoUrl = responseItem.VideoUrl
|
||||||
|
|
||||||
|
// 映射 VideoUrls - 将数组序列化为 JSON 字符串
|
||||||
|
if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 {
|
||||||
|
videoUrlsStr, err := json.Marshal(responseItem.VideoUrls)
|
||||||
|
if err != nil {
|
||||||
|
logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err))
|
||||||
|
task.VideoUrls = "[]" // 失败时设置为空数组
|
||||||
|
} else {
|
||||||
|
task.VideoUrls = string(videoUrlsStr)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
task.VideoUrls = "" // 空值时清空字段
|
||||||
|
}
|
||||||
|
|
||||||
shouldReturnQuota := false
|
shouldReturnQuota := false
|
||||||
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
||||||
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||||
task.Progress = "100%"
|
task.Progress = "100%"
|
||||||
if task.Quota != 0 {
|
if task.Quota != 0 {
|
||||||
shouldReturnQuota = true
|
shouldReturnQuota = true
|
||||||
@@ -155,14 +173,14 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
}
|
}
|
||||||
err = task.Update()
|
err = task.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||||
} else {
|
} else {
|
||||||
if shouldReturnQuota {
|
if shouldReturnQuota {
|
||||||
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
|
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||||
}
|
}
|
||||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
|
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
|
||||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -208,19 +226,26 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
|
|||||||
if oldTask.Progress != "100%" && newTask.FailReason != "" {
|
if oldTask.Progress != "100%" && newTask.FailReason != "" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
// 检查 VideoUrl 是否需要更新
|
||||||
|
if oldTask.VideoUrl != newTask.VideoUrl {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// 检查 VideoUrls 是否需要更新
|
||||||
|
if newTask.VideoUrls != nil && len(newTask.VideoUrls) > 0 {
|
||||||
|
newVideoUrlsStr, _ := json.Marshal(newTask.VideoUrls)
|
||||||
|
if oldTask.VideoUrls != string(newVideoUrlsStr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
} else if oldTask.VideoUrls != "" {
|
||||||
|
// 如果新数据没有 VideoUrls 但旧数据有,需要更新(清空)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllMidjourney(c *gin.Context) {
|
func GetAllMidjourney(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
if p < 1 {
|
|
||||||
p = 1
|
|
||||||
}
|
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
|
||||||
if pageSize <= 0 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析其他查询参数
|
// 解析其他查询参数
|
||||||
queryParams := model.TaskQueryParams{
|
queryParams := model.TaskQueryParams{
|
||||||
@@ -230,36 +255,22 @@ func GetAllMidjourney(c *gin.Context) {
|
|||||||
EndTimestamp: c.Query("end_timestamp"),
|
EndTimestamp: c.Query("end_timestamp"),
|
||||||
}
|
}
|
||||||
|
|
||||||
items := model.GetAllTasks((p-1)*pageSize, pageSize, queryParams)
|
items := model.GetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||||
total := model.CountAllTasks(queryParams)
|
total := model.CountAllTasks(queryParams)
|
||||||
|
|
||||||
if setting.MjForwardUrlEnabled {
|
if setting.MjForwardUrlEnabled {
|
||||||
for i, midjourney := range items {
|
for i, midjourney := range items {
|
||||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||||
items[i] = midjourney
|
items[i] = midjourney
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
pageInfo.SetTotal(int(total))
|
||||||
"success": true,
|
pageInfo.SetItems(items)
|
||||||
"message": "",
|
common.ApiSuccess(c, pageInfo)
|
||||||
"data": gin.H{
|
|
||||||
"items": items,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserMidjourney(c *gin.Context) {
|
func GetUserMidjourney(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
if p < 1 {
|
|
||||||
p = 1
|
|
||||||
}
|
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
|
||||||
if pageSize <= 0 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
@@ -269,23 +280,16 @@ func GetUserMidjourney(c *gin.Context) {
|
|||||||
EndTimestamp: c.Query("end_timestamp"),
|
EndTimestamp: c.Query("end_timestamp"),
|
||||||
}
|
}
|
||||||
|
|
||||||
items := model.GetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
|
items := model.GetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||||
total := model.CountAllUserTask(userId, queryParams)
|
total := model.CountAllUserTask(userId, queryParams)
|
||||||
|
|
||||||
if setting.MjForwardUrlEnabled {
|
if setting.MjForwardUrlEnabled {
|
||||||
for i, midjourney := range items {
|
for i, midjourney := range items {
|
||||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||||
items[i] = midjourney
|
items[i] = midjourney
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
pageInfo.SetTotal(int(total))
|
||||||
"success": true,
|
pageInfo.SetItems(items)
|
||||||
"message": "",
|
common.ApiSuccess(c, pageInfo)
|
||||||
"data": gin.H{
|
|
||||||
"items": items,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,44 +39,47 @@ func TestStatus(c *gin.Context) {
|
|||||||
func GetStatus(c *gin.Context) {
|
func GetStatus(c *gin.Context) {
|
||||||
|
|
||||||
cs := console_setting.GetConsoleSetting()
|
cs := console_setting.GetConsoleSetting()
|
||||||
|
common.OptionMapRWMutex.RLock()
|
||||||
|
defer common.OptionMapRWMutex.RUnlock()
|
||||||
|
|
||||||
data := gin.H{
|
data := gin.H{
|
||||||
"version": common.Version,
|
"version": common.Version,
|
||||||
"start_time": common.StartTime,
|
"start_time": common.StartTime,
|
||||||
"email_verification": common.EmailVerificationEnabled,
|
"email_verification": common.EmailVerificationEnabled,
|
||||||
"github_oauth": common.GitHubOAuthEnabled,
|
"github_oauth": common.GitHubOAuthEnabled,
|
||||||
"github_client_id": common.GitHubClientId,
|
"github_client_id": common.GitHubClientId,
|
||||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||||
"linuxdo_client_id": common.LinuxDOClientId,
|
"linuxdo_client_id": common.LinuxDOClientId,
|
||||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
"linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,
|
||||||
"telegram_bot_name": common.TelegramBotName,
|
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||||
"system_name": common.SystemName,
|
"telegram_bot_name": common.TelegramBotName,
|
||||||
"logo": common.Logo,
|
"system_name": common.SystemName,
|
||||||
"footer_html": common.Footer,
|
"logo": common.Logo,
|
||||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
"footer_html": common.Footer,
|
||||||
"wechat_login": common.WeChatAuthEnabled,
|
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||||
"server_address": setting.ServerAddress,
|
"wechat_login": common.WeChatAuthEnabled,
|
||||||
"price": setting.Price,
|
"server_address": system_setting.ServerAddress,
|
||||||
"min_topup": setting.MinTopUp,
|
"turnstile_check": common.TurnstileCheckEnabled,
|
||||||
"turnstile_check": common.TurnstileCheckEnabled,
|
"turnstile_site_key": common.TurnstileSiteKey,
|
||||||
"turnstile_site_key": common.TurnstileSiteKey,
|
"top_up_link": common.TopUpLink,
|
||||||
"top_up_link": common.TopUpLink,
|
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
"quota_per_unit": common.QuotaPerUnit,
|
||||||
"quota_per_unit": common.QuotaPerUnit,
|
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
"enable_batch_update": common.BatchUpdateEnabled,
|
||||||
"enable_batch_update": common.BatchUpdateEnabled,
|
"enable_drawing": common.DrawingEnabled,
|
||||||
"enable_drawing": common.DrawingEnabled,
|
"enable_task": common.TaskEnabled,
|
||||||
"enable_task": common.TaskEnabled,
|
"enable_data_export": common.DataExportEnabled,
|
||||||
"enable_data_export": common.DataExportEnabled,
|
"data_export_default_time": common.DataExportDefaultTime,
|
||||||
"data_export_default_time": common.DataExportDefaultTime,
|
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
"chats": setting.Chats,
|
||||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||||
"chats": setting.Chats,
|
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
|
||||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
"usd_exchange_rate": operation_setting.USDExchangeRate,
|
||||||
"pay_methods": setting.PayMethods,
|
"price": operation_setting.Price,
|
||||||
|
"stripe_unit_price": setting.StripeUnitPrice,
|
||||||
|
|
||||||
// 面板启用开关
|
// 面板启用开关
|
||||||
"api_info_enabled": cs.ApiInfoEnabled,
|
"api_info_enabled": cs.ApiInfoEnabled,
|
||||||
@@ -84,6 +87,10 @@ func GetStatus(c *gin.Context) {
|
|||||||
"announcements_enabled": cs.AnnouncementsEnabled,
|
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||||
"faq_enabled": cs.FAQEnabled,
|
"faq_enabled": cs.FAQEnabled,
|
||||||
|
|
||||||
|
// 模块管理配置
|
||||||
|
"HeaderNavModules": common.OptionMap["HeaderNavModules"],
|
||||||
|
"SidebarModulesAdmin": common.OptionMap["SidebarModulesAdmin"],
|
||||||
|
|
||||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||||
@@ -214,10 +221,7 @@ func SendEmailVerification(c *gin.Context) {
|
|||||||
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
|
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
|
||||||
err := common.SendEmail(subject, email, content)
|
err := common.SendEmail(subject, email, content)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -245,7 +249,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
code := common.GenerateVerificationCode(0)
|
code := common.GenerateVerificationCode(0)
|
||||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code)
|
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", system_setting.ServerAddress, email, code)
|
||||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||||
@@ -253,10 +257,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
|||||||
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
|
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
|
||||||
err := common.SendEmail(subject, email, content)
|
err := common.SendEmail(subject, email, content)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -291,10 +292,7 @@ func ResetPassword(c *gin.Context) {
|
|||||||
password := common.GenerateVerificationCode(12)
|
password := common.GenerateVerificationCode(12)
|
||||||
err = model.ResetUserPasswordByEmail(req.Email, password)
|
err = model.ResetUserPasswordByEmail(req.Email, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
common.DeleteKey(req.Email, common.PasswordResetPurpose)
|
common.DeleteKey(req.Email, common.PasswordResetPurpose)
|
||||||
|
|||||||
27
controller/missing_models.go
Normal file
27
controller/missing_models.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetMissingModels returns the list of model names that are referenced by channels
|
||||||
|
// but do not have corresponding records in the models meta table.
|
||||||
|
// This helps administrators quickly discover models that need configuration.
|
||||||
|
func GetMissingModels(c *gin.Context) {
|
||||||
|
missing, err := model.GetMissingModels()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": missing,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"one-api/relay/channel/moonshot"
|
"one-api/relay/channel/moonshot"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/models/list
|
// https://platform.openai.com/docs/api-reference/models/list
|
||||||
@@ -92,7 +93,9 @@ func init() {
|
|||||||
if !success || apiType == constant.APITypeAIProxyLibrary {
|
if !success || apiType == constant.APITypeAIProxyLibrary {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
meta := &relaycommon.RelayInfo{ChannelType: i}
|
meta := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
ChannelType: i,
|
||||||
|
}}
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
adaptor.Init(meta)
|
adaptor.Init(meta)
|
||||||
channelId2Models[i] = adaptor.GetModelList()
|
channelId2Models[i] = adaptor.GetModelList()
|
||||||
@@ -102,7 +105,7 @@ func init() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListModels(c *gin.Context) {
|
func ListModels(c *gin.Context, modelType int) {
|
||||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||||
|
|
||||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||||
@@ -171,10 +174,42 @@ func ListModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
switch modelType {
|
||||||
"success": true,
|
case constant.ChannelTypeAnthropic:
|
||||||
"data": userOpenAiModels,
|
useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
|
||||||
})
|
for i, model := range userOpenAiModels {
|
||||||
|
useranthropicModels[i] = dto.AnthropicModel{
|
||||||
|
ID: model.Id,
|
||||||
|
CreatedAt: time.Unix(int64(model.Created), 0).UTC().Format(time.RFC3339),
|
||||||
|
DisplayName: model.Id,
|
||||||
|
Type: "model",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"data": useranthropicModels,
|
||||||
|
"first_id": useranthropicModels[0].ID,
|
||||||
|
"has_more": false,
|
||||||
|
"last_id": useranthropicModels[len(useranthropicModels)-1].ID,
|
||||||
|
})
|
||||||
|
case constant.ChannelTypeGemini:
|
||||||
|
userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels))
|
||||||
|
for i, model := range userOpenAiModels {
|
||||||
|
userGeminiModels[i] = dto.GeminiModel{
|
||||||
|
Name: model.Id,
|
||||||
|
DisplayName: model.Id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"models": userGeminiModels,
|
||||||
|
"nextPageToken": nil,
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": userOpenAiModels,
|
||||||
|
"object": "list",
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ChannelListModels(c *gin.Context) {
|
func ChannelListModels(c *gin.Context) {
|
||||||
@@ -198,10 +233,20 @@ func EnabledListModels(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func RetrieveModel(c *gin.Context) {
|
func RetrieveModel(c *gin.Context, modelType int) {
|
||||||
modelId := c.Param("model")
|
modelId := c.Param("model")
|
||||||
if aiModel, ok := openAIModelsMap[modelId]; ok {
|
if aiModel, ok := openAIModelsMap[modelId]; ok {
|
||||||
c.JSON(200, aiModel)
|
switch modelType {
|
||||||
|
case constant.ChannelTypeAnthropic:
|
||||||
|
c.JSON(200, dto.AnthropicModel{
|
||||||
|
ID: aiModel.Id,
|
||||||
|
CreatedAt: time.Unix(int64(aiModel.Created), 0).UTC().Format(time.RFC3339),
|
||||||
|
DisplayName: aiModel.Id,
|
||||||
|
Type: "model",
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
c.JSON(200, aiModel)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
openAIError := dto.OpenAIError{
|
openAIError := dto.OpenAIError{
|
||||||
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
||||||
|
|||||||
330
controller/model_meta.go
Normal file
330
controller/model_meta.go
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetAllModelsMeta 获取模型列表(分页)
|
||||||
|
func GetAllModelsMeta(c *gin.Context) {
|
||||||
|
|
||||||
|
pageInfo := common.GetPageQuery(c)
|
||||||
|
modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 批量填充附加字段,提升列表接口性能
|
||||||
|
enrichModels(modelsMeta)
|
||||||
|
var total int64
|
||||||
|
model.DB.Model(&model.Model{}).Count(&total)
|
||||||
|
|
||||||
|
// 统计供应商计数(全部数据,不受分页影响)
|
||||||
|
vendorCounts, _ := model.GetVendorModelCounts()
|
||||||
|
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
|
pageInfo.SetItems(modelsMeta)
|
||||||
|
common.ApiSuccess(c, gin.H{
|
||||||
|
"items": modelsMeta,
|
||||||
|
"total": total,
|
||||||
|
"page": pageInfo.GetPage(),
|
||||||
|
"page_size": pageInfo.GetPageSize(),
|
||||||
|
"vendor_counts": vendorCounts,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchModelsMeta 搜索模型列表
|
||||||
|
func SearchModelsMeta(c *gin.Context) {
|
||||||
|
|
||||||
|
keyword := c.Query("keyword")
|
||||||
|
vendor := c.Query("vendor")
|
||||||
|
pageInfo := common.GetPageQuery(c)
|
||||||
|
|
||||||
|
modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 批量填充附加字段,提升列表接口性能
|
||||||
|
enrichModels(modelsMeta)
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
|
pageInfo.SetItems(modelsMeta)
|
||||||
|
common.ApiSuccess(c, pageInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelMeta 根据 ID 获取单条模型信息
|
||||||
|
func GetModelMeta(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var m model.Model
|
||||||
|
if err := model.DB.First(&m, id).Error; err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
enrichModels([]*model.Model{&m})
|
||||||
|
common.ApiSuccess(c, &m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateModelMeta 新建模型
|
||||||
|
func CreateModelMeta(c *gin.Context) {
|
||||||
|
var m model.Model
|
||||||
|
if err := c.ShouldBindJSON(&m); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m.ModelName == "" {
|
||||||
|
common.ApiErrorMsg(c, "模型名称不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 名称冲突检查
|
||||||
|
if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "模型名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Insert(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model.RefreshPricing()
|
||||||
|
common.ApiSuccess(c, &m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateModelMeta 更新模型
|
||||||
|
func UpdateModelMeta(c *gin.Context) {
|
||||||
|
statusOnly := c.Query("status_only") == "true"
|
||||||
|
|
||||||
|
var m model.Model
|
||||||
|
if err := c.ShouldBindJSON(&m); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m.Id == 0 {
|
||||||
|
common.ApiErrorMsg(c, "缺少模型 ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if statusOnly {
|
||||||
|
// 只更新状态,防止误清空其他字段
|
||||||
|
if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 名称冲突检查
|
||||||
|
if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "模型名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Update(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
model.RefreshPricing()
|
||||||
|
common.ApiSuccess(c, &m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteModelMeta 删除模型
|
||||||
|
func DeleteModelMeta(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model.RefreshPricing()
|
||||||
|
common.ApiSuccess(c, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询
|
||||||
|
func enrichModels(models []*model.Model) {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1) 拆分精确与规则匹配
|
||||||
|
exactNames := make([]string, 0)
|
||||||
|
exactIdx := make(map[string][]int) // modelName -> indices in models
|
||||||
|
ruleIndices := make([]int, 0)
|
||||||
|
for i, m := range models {
|
||||||
|
if m == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if m.NameRule == model.NameRuleExact {
|
||||||
|
exactNames = append(exactNames, m.ModelName)
|
||||||
|
exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
|
||||||
|
} else {
|
||||||
|
ruleIndices = append(ruleIndices, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 批量查询精确模型的绑定渠道
|
||||||
|
channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)
|
||||||
|
|
||||||
|
// 3) 精确模型:端点从缓存、渠道批量映射、分组/计费类型从缓存
|
||||||
|
for name, indices := range exactIdx {
|
||||||
|
chs := channelsByModel[name]
|
||||||
|
for _, idx := range indices {
|
||||||
|
mm := models[idx]
|
||||||
|
if mm.Endpoints == "" {
|
||||||
|
eps := model.GetModelSupportEndpointTypes(mm.ModelName)
|
||||||
|
if b, err := json.Marshal(eps); err == nil {
|
||||||
|
mm.Endpoints = string(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mm.BoundChannels = chs
|
||||||
|
mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
|
||||||
|
mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ruleIndices) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 一次性读取定价缓存,内存匹配所有规则模型
|
||||||
|
pricings := model.GetPricing()
|
||||||
|
|
||||||
|
// 为全部规则模型收集匹配名集合、端点并集、分组并集、配额集合
|
||||||
|
matchedNamesByIdx := make(map[int][]string)
|
||||||
|
endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
|
||||||
|
groupSetByIdx := make(map[int]map[string]struct{})
|
||||||
|
quotaSetByIdx := make(map[int]map[int]struct{})
|
||||||
|
|
||||||
|
for _, p := range pricings {
|
||||||
|
for _, idx := range ruleIndices {
|
||||||
|
mm := models[idx]
|
||||||
|
var matched bool
|
||||||
|
switch mm.NameRule {
|
||||||
|
case model.NameRulePrefix:
|
||||||
|
matched = strings.HasPrefix(p.ModelName, mm.ModelName)
|
||||||
|
case model.NameRuleSuffix:
|
||||||
|
matched = strings.HasSuffix(p.ModelName, mm.ModelName)
|
||||||
|
case model.NameRuleContains:
|
||||||
|
matched = strings.Contains(p.ModelName, mm.ModelName)
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)
|
||||||
|
|
||||||
|
es := endpointSetByIdx[idx]
|
||||||
|
if es == nil {
|
||||||
|
es = make(map[constant.EndpointType]struct{})
|
||||||
|
endpointSetByIdx[idx] = es
|
||||||
|
}
|
||||||
|
for _, et := range p.SupportedEndpointTypes {
|
||||||
|
es[et] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
gs := groupSetByIdx[idx]
|
||||||
|
if gs == nil {
|
||||||
|
gs = make(map[string]struct{})
|
||||||
|
groupSetByIdx[idx] = gs
|
||||||
|
}
|
||||||
|
for _, g := range p.EnableGroup {
|
||||||
|
gs[g] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
qs := quotaSetByIdx[idx]
|
||||||
|
if qs == nil {
|
||||||
|
qs = make(map[int]struct{})
|
||||||
|
quotaSetByIdx[idx] = qs
|
||||||
|
}
|
||||||
|
qs[p.QuotaType] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5) 汇总所有匹配到的模型名称,批量查询一次渠道
|
||||||
|
allMatchedSet := make(map[string]struct{})
|
||||||
|
for _, names := range matchedNamesByIdx {
|
||||||
|
for _, n := range names {
|
||||||
|
allMatchedSet[n] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
allMatched := make([]string, 0, len(allMatchedSet))
|
||||||
|
for n := range allMatchedSet {
|
||||||
|
allMatched = append(allMatched, n)
|
||||||
|
}
|
||||||
|
matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)
|
||||||
|
|
||||||
|
// 6) 回填每个规则模型的并集信息
|
||||||
|
for _, idx := range ruleIndices {
|
||||||
|
mm := models[idx]
|
||||||
|
|
||||||
|
// 端点并集 -> 序列化
|
||||||
|
if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
|
||||||
|
eps := make([]constant.EndpointType, 0, len(es))
|
||||||
|
for et := range es {
|
||||||
|
eps = append(eps, et)
|
||||||
|
}
|
||||||
|
if b, err := json.Marshal(eps); err == nil {
|
||||||
|
mm.Endpoints = string(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 分组并集
|
||||||
|
if gs, ok := groupSetByIdx[idx]; ok {
|
||||||
|
groups := make([]string, 0, len(gs))
|
||||||
|
for g := range gs {
|
||||||
|
groups = append(groups, g)
|
||||||
|
}
|
||||||
|
mm.EnableGroups = groups
|
||||||
|
}
|
||||||
|
|
||||||
|
// 配额类型集合(保持去重并排序)
|
||||||
|
if qs, ok := quotaSetByIdx[idx]; ok {
|
||||||
|
arr := make([]int, 0, len(qs))
|
||||||
|
for k := range qs {
|
||||||
|
arr = append(arr, k)
|
||||||
|
}
|
||||||
|
sort.Ints(arr)
|
||||||
|
mm.QuotaTypes = arr
|
||||||
|
}
|
||||||
|
|
||||||
|
// 渠道并集
|
||||||
|
names := matchedNamesByIdx[idx]
|
||||||
|
channelSet := make(map[string]model.BoundChannel)
|
||||||
|
for _, n := range names {
|
||||||
|
for _, ch := range matchedChannelsByModel[n] {
|
||||||
|
key := ch.Name + "_" + strconv.Itoa(ch.Type)
|
||||||
|
channelSet[key] = ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(channelSet) > 0 {
|
||||||
|
chs := make([]model.BoundChannel, 0, len(channelSet))
|
||||||
|
for _, ch := range channelSet {
|
||||||
|
chs = append(chs, ch)
|
||||||
|
}
|
||||||
|
mm.BoundChannels = chs
|
||||||
|
}
|
||||||
|
|
||||||
|
// 匹配信息
|
||||||
|
mm.MatchedModels = names
|
||||||
|
mm.MatchedCount = len(names)
|
||||||
|
}
|
||||||
|
}
|
||||||
604
controller/model_sync.go
Normal file
604
controller/model_sync.go
Normal file
@@ -0,0 +1,604 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 上游地址
|
||||||
|
const (
|
||||||
|
upstreamModelsURL = "https://basellm.github.io/llm-metadata/api/newapi/models.json"
|
||||||
|
upstreamVendorsURL = "https://basellm.github.io/llm-metadata/api/newapi/vendors.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
func normalizeLocale(locale string) (string, bool) {
|
||||||
|
l := strings.ToLower(strings.TrimSpace(locale))
|
||||||
|
switch l {
|
||||||
|
case "en", "zh", "ja":
|
||||||
|
return l, true
|
||||||
|
default:
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUpstreamBase() string {
|
||||||
|
return common.GetEnvOrDefaultString("SYNC_UPSTREAM_BASE", "https://basellm.github.io/llm-metadata")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUpstreamURLs(locale string) (modelsURL, vendorsURL string) {
|
||||||
|
base := strings.TrimRight(getUpstreamBase(), "/")
|
||||||
|
if l, ok := normalizeLocale(locale); ok && l != "" {
|
||||||
|
return fmt.Sprintf("%s/api/i18n/%s/newapi/models.json", base, l),
|
||||||
|
fmt.Sprintf("%s/api/i18n/%s/newapi/vendors.json", base, l)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s/api/newapi/models.json", base), fmt.Sprintf("%s/api/newapi/vendors.json", base)
|
||||||
|
}
|
||||||
|
|
||||||
|
type upstreamEnvelope[T any] struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data []T `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type upstreamModel struct {
|
||||||
|
Description string `json:"description"`
|
||||||
|
Endpoints json.RawMessage `json:"endpoints"`
|
||||||
|
Icon string `json:"icon"`
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
NameRule int `json:"name_rule"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
Tags string `json:"tags"`
|
||||||
|
VendorName string `json:"vendor_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type upstreamVendor struct {
|
||||||
|
Description string `json:"description"`
|
||||||
|
Icon string `json:"icon"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
etagCache = make(map[string]string)
|
||||||
|
bodyCache = make(map[string][]byte)
|
||||||
|
cacheMutex sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
|
type overwriteField struct {
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
Fields []string `json:"fields"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type syncRequest struct {
|
||||||
|
Overwrite []overwriteField `json:"overwrite"`
|
||||||
|
Locale string `json:"locale"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPClient() *http.Client {
|
||||||
|
timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 10)
|
||||||
|
dialer := &net.Dialer{Timeout: time.Duration(timeoutSec) * time.Second}
|
||||||
|
transport := &http.Transport{
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: time.Duration(timeoutSec) * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
ResponseHeaderTimeout: time.Duration(timeoutSec) * time.Second,
|
||||||
|
}
|
||||||
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
host = addr
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(host, "github.io") {
|
||||||
|
if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, "tcp6", addr)
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
return &http.Client{Transport: transport}
|
||||||
|
}
|
||||||
|
|
||||||
|
var httpClient = newHTTPClient()
|
||||||
|
|
||||||
|
func fetchJSON[T any](ctx context.Context, url string, out *upstreamEnvelope[T]) error {
|
||||||
|
var lastErr error
|
||||||
|
attempts := common.GetEnvOrDefault("SYNC_HTTP_RETRY", 3)
|
||||||
|
if attempts < 1 {
|
||||||
|
attempts = 1
|
||||||
|
}
|
||||||
|
baseDelay := 200 * time.Millisecond
|
||||||
|
maxMB := common.GetEnvOrDefault("SYNC_HTTP_MAX_MB", 10)
|
||||||
|
maxBytes := int64(maxMB) << 20
|
||||||
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// ETag conditional request
|
||||||
|
cacheMutex.RLock()
|
||||||
|
if et := etagCache[url]; et != "" {
|
||||||
|
req.Header.Set("If-None-Match", et)
|
||||||
|
}
|
||||||
|
cacheMutex.RUnlock()
|
||||||
|
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
// backoff with jitter
|
||||||
|
sleep := baseDelay * time.Duration(1<<attempt)
|
||||||
|
jitter := time.Duration(rand.Intn(150)) * time.Millisecond
|
||||||
|
time.Sleep(sleep + jitter)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
func() {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
switch resp.StatusCode {
|
||||||
|
case http.StatusOK:
|
||||||
|
// read body into buffer for caching and flexible decode
|
||||||
|
limited := io.LimitReader(resp.Body, maxBytes)
|
||||||
|
buf, err := io.ReadAll(limited)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// cache body and ETag
|
||||||
|
cacheMutex.Lock()
|
||||||
|
if et := resp.Header.Get("ETag"); et != "" {
|
||||||
|
etagCache[url] = et
|
||||||
|
}
|
||||||
|
bodyCache[url] = buf
|
||||||
|
cacheMutex.Unlock()
|
||||||
|
|
||||||
|
// Try decode as envelope first
|
||||||
|
if err := json.Unmarshal(buf, out); err != nil {
|
||||||
|
// Try decode as pure array
|
||||||
|
var arr []T
|
||||||
|
if err2 := json.Unmarshal(buf, &arr); err2 != nil {
|
||||||
|
lastErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out.Success = true
|
||||||
|
out.Data = arr
|
||||||
|
out.Message = ""
|
||||||
|
} else {
|
||||||
|
if !out.Success && len(out.Data) == 0 && out.Message == "" {
|
||||||
|
out.Success = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lastErr = nil
|
||||||
|
case http.StatusNotModified:
|
||||||
|
// use cache
|
||||||
|
cacheMutex.RLock()
|
||||||
|
buf := bodyCache[url]
|
||||||
|
cacheMutex.RUnlock()
|
||||||
|
if len(buf) == 0 {
|
||||||
|
lastErr = errors.New("cache miss for 304 response")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(buf, out); err != nil {
|
||||||
|
var arr []T
|
||||||
|
if err2 := json.Unmarshal(buf, &arr); err2 != nil {
|
||||||
|
lastErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out.Success = true
|
||||||
|
out.Data = arr
|
||||||
|
out.Message = ""
|
||||||
|
} else {
|
||||||
|
if !out.Success && len(out.Data) == 0 && out.Message == "" {
|
||||||
|
out.Success = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lastErr = nil
|
||||||
|
default:
|
||||||
|
lastErr = errors.New(resp.Status)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if lastErr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sleep := baseDelay * time.Duration(1<<attempt)
|
||||||
|
jitter := time.Duration(rand.Intn(150)) * time.Millisecond
|
||||||
|
time.Sleep(sleep + jitter)
|
||||||
|
}
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureVendorID(vendorName string, vendorByName map[string]upstreamVendor, vendorIDCache map[string]int, createdVendors *int) int {
|
||||||
|
if vendorName == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if id, ok := vendorIDCache[vendorName]; ok {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
var existing model.Vendor
|
||||||
|
if err := model.DB.Where("name = ?", vendorName).First(&existing).Error; err == nil {
|
||||||
|
vendorIDCache[vendorName] = existing.Id
|
||||||
|
return existing.Id
|
||||||
|
}
|
||||||
|
uv := vendorByName[vendorName]
|
||||||
|
v := &model.Vendor{
|
||||||
|
Name: vendorName,
|
||||||
|
Description: uv.Description,
|
||||||
|
Icon: coalesce(uv.Icon, ""),
|
||||||
|
Status: chooseStatus(uv.Status, 1),
|
||||||
|
}
|
||||||
|
if err := v.Insert(); err == nil {
|
||||||
|
*createdVendors++
|
||||||
|
vendorIDCache[vendorName] = v.Id
|
||||||
|
return v.Id
|
||||||
|
}
|
||||||
|
vendorIDCache[vendorName] = 0
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncUpstreamModels 同步上游模型与供应商,仅对「未配置模型」生效
|
||||||
|
func SyncUpstreamModels(c *gin.Context) {
|
||||||
|
var req syncRequest
|
||||||
|
// 允许空体
|
||||||
|
_ = c.ShouldBindJSON(&req)
|
||||||
|
// 1) 获取未配置模型列表
|
||||||
|
missing, err := model.GetMissingModels()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(missing) == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true, "data": gin.H{
|
||||||
|
"created_models": 0,
|
||||||
|
"created_vendors": 0,
|
||||||
|
"skipped_models": []string{},
|
||||||
|
}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 拉取上游 vendors 与 models
|
||||||
|
timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 15)
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(timeoutSec)*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
modelsURL, vendorsURL := getUpstreamURLs(req.Locale)
|
||||||
|
var vendorsEnv upstreamEnvelope[upstreamVendor]
|
||||||
|
var modelsEnv upstreamEnvelope[upstreamModel]
|
||||||
|
var fetchErr error
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
// vendor 失败不拦截
|
||||||
|
_ = fetchJSON(ctx, vendorsURL, &vendorsEnv)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := fetchJSON(ctx, modelsURL, &modelsEnv); err != nil {
|
||||||
|
fetchErr = err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
if fetchErr != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + fetchErr.Error(), "locale": req.Locale, "source_urls": gin.H{"models_url": modelsURL, "vendors_url": vendorsURL}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 建立映射
|
||||||
|
vendorByName := make(map[string]upstreamVendor)
|
||||||
|
for _, v := range vendorsEnv.Data {
|
||||||
|
if v.Name != "" {
|
||||||
|
vendorByName[v.Name] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelByName := make(map[string]upstreamModel)
|
||||||
|
for _, m := range modelsEnv.Data {
|
||||||
|
if m.ModelName != "" {
|
||||||
|
modelByName[m.ModelName] = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) 执行同步:仅创建缺失模型;若上游缺失该模型则跳过
|
||||||
|
createdModels := 0
|
||||||
|
createdVendors := 0
|
||||||
|
updatedModels := 0
|
||||||
|
var skipped []string
|
||||||
|
var createdList []string
|
||||||
|
var updatedList []string
|
||||||
|
|
||||||
|
// 本地缓存:vendorName -> id
|
||||||
|
vendorIDCache := make(map[string]int)
|
||||||
|
|
||||||
|
for _, name := range missing {
|
||||||
|
up, ok := modelByName[name]
|
||||||
|
if !ok {
|
||||||
|
skipped = append(skipped, name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 若本地已存在且设置为不同步,则跳过(极端情况:缺失列表与本地状态不同步时)
|
||||||
|
var existing model.Model
|
||||||
|
if err := model.DB.Where("model_name = ?", name).First(&existing).Error; err == nil {
|
||||||
|
if existing.SyncOfficial == 0 {
|
||||||
|
skipped = append(skipped, name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保 vendor 存在
|
||||||
|
vendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
|
||||||
|
|
||||||
|
// 创建模型
|
||||||
|
mi := &model.Model{
|
||||||
|
ModelName: name,
|
||||||
|
Description: up.Description,
|
||||||
|
Icon: up.Icon,
|
||||||
|
Tags: up.Tags,
|
||||||
|
VendorID: vendorID,
|
||||||
|
Status: chooseStatus(up.Status, 1),
|
||||||
|
NameRule: up.NameRule,
|
||||||
|
}
|
||||||
|
if err := mi.Insert(); err == nil {
|
||||||
|
createdModels++
|
||||||
|
createdList = append(createdList, name)
|
||||||
|
} else {
|
||||||
|
skipped = append(skipped, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 处理可选覆盖(更新本地已有模型的差异字段)
|
||||||
|
if len(req.Overwrite) > 0 {
|
||||||
|
// vendorIDCache 已用于创建阶段,可复用
|
||||||
|
for _, ow := range req.Overwrite {
|
||||||
|
up, ok := modelByName[ow.ModelName]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var local model.Model
|
||||||
|
if err := model.DB.Where("model_name = ?", ow.ModelName).First(&local).Error; err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 跳过被禁用官方同步的模型
|
||||||
|
if local.SyncOfficial == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 映射 vendor
|
||||||
|
newVendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
|
||||||
|
|
||||||
|
// 应用字段覆盖(事务)
|
||||||
|
_ = model.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
needUpdate := false
|
||||||
|
if containsField(ow.Fields, "description") {
|
||||||
|
local.Description = up.Description
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "icon") {
|
||||||
|
local.Icon = up.Icon
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "tags") {
|
||||||
|
local.Tags = up.Tags
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "vendor") {
|
||||||
|
local.VendorID = newVendorID
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "name_rule") {
|
||||||
|
local.NameRule = up.NameRule
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "status") {
|
||||||
|
local.Status = chooseStatus(up.Status, local.Status)
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if !needUpdate {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := tx.Save(&local).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
updatedModels++
|
||||||
|
updatedList = append(updatedList, ow.ModelName)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": gin.H{
|
||||||
|
"created_models": createdModels,
|
||||||
|
"created_vendors": createdVendors,
|
||||||
|
"updated_models": updatedModels,
|
||||||
|
"skipped_models": skipped,
|
||||||
|
"created_list": createdList,
|
||||||
|
"updated_list": updatedList,
|
||||||
|
"source": gin.H{
|
||||||
|
"locale": req.Locale,
|
||||||
|
"models_url": modelsURL,
|
||||||
|
"vendors_url": vendorsURL,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsField(fields []string, key string) bool {
|
||||||
|
key = strings.ToLower(strings.TrimSpace(key))
|
||||||
|
for _, f := range fields {
|
||||||
|
if strings.ToLower(strings.TrimSpace(f)) == key {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func coalesce(a, b string) string {
|
||||||
|
if strings.TrimSpace(a) != "" {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func chooseStatus(primary, fallback int) int {
|
||||||
|
if primary == 0 && fallback != 0 {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
if primary != 0 {
|
||||||
|
return primary
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncUpstreamPreview 预览上游与本地的差异(仅用于弹窗选择)
|
||||||
|
func SyncUpstreamPreview(c *gin.Context) {
|
||||||
|
// 1) 拉取上游数据
|
||||||
|
timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 15)
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(timeoutSec)*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
locale := c.Query("locale")
|
||||||
|
modelsURL, vendorsURL := getUpstreamURLs(locale)
|
||||||
|
|
||||||
|
var vendorsEnv upstreamEnvelope[upstreamVendor]
|
||||||
|
var modelsEnv upstreamEnvelope[upstreamModel]
|
||||||
|
var fetchErr error
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = fetchJSON(ctx, vendorsURL, &vendorsEnv)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := fetchJSON(ctx, modelsURL, &modelsEnv); err != nil {
|
||||||
|
fetchErr = err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
if fetchErr != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + fetchErr.Error(), "locale": locale, "source_urls": gin.H{"models_url": modelsURL, "vendors_url": vendorsURL}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vendorByName := make(map[string]upstreamVendor)
|
||||||
|
for _, v := range vendorsEnv.Data {
|
||||||
|
if v.Name != "" {
|
||||||
|
vendorByName[v.Name] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelByName := make(map[string]upstreamModel)
|
||||||
|
upstreamNames := make([]string, 0, len(modelsEnv.Data))
|
||||||
|
for _, m := range modelsEnv.Data {
|
||||||
|
if m.ModelName != "" {
|
||||||
|
modelByName[m.ModelName] = m
|
||||||
|
upstreamNames = append(upstreamNames, m.ModelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 本地已有模型
|
||||||
|
var locals []model.Model
|
||||||
|
if len(upstreamNames) > 0 {
|
||||||
|
_ = model.DB.Where("model_name IN ? AND sync_official <> 0", upstreamNames).Find(&locals).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// 本地 vendor 名称映射
|
||||||
|
vendorIdSet := make(map[int]struct{})
|
||||||
|
for _, m := range locals {
|
||||||
|
if m.VendorID != 0 {
|
||||||
|
vendorIdSet[m.VendorID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
vendorIDs := make([]int, 0, len(vendorIdSet))
|
||||||
|
for id := range vendorIdSet {
|
||||||
|
vendorIDs = append(vendorIDs, id)
|
||||||
|
}
|
||||||
|
idToVendorName := make(map[int]string)
|
||||||
|
if len(vendorIDs) > 0 {
|
||||||
|
var dbVendors []model.Vendor
|
||||||
|
_ = model.DB.Where("id IN ?", vendorIDs).Find(&dbVendors).Error
|
||||||
|
for _, v := range dbVendors {
|
||||||
|
idToVendorName[v.Id] = v.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) 缺失且上游存在的模型
|
||||||
|
missingList, _ := model.GetMissingModels()
|
||||||
|
var missing []string
|
||||||
|
for _, name := range missingList {
|
||||||
|
if _, ok := modelByName[name]; ok {
|
||||||
|
missing = append(missing, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 计算冲突字段
|
||||||
|
type conflictField struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
Local interface{} `json:"local"`
|
||||||
|
Upstream interface{} `json:"upstream"`
|
||||||
|
}
|
||||||
|
type conflictItem struct {
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
Fields []conflictField `json:"fields"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var conflicts []conflictItem
|
||||||
|
for _, local := range locals {
|
||||||
|
up, ok := modelByName[local.ModelName]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fields := make([]conflictField, 0, 6)
|
||||||
|
if strings.TrimSpace(local.Description) != strings.TrimSpace(up.Description) {
|
||||||
|
fields = append(fields, conflictField{Field: "description", Local: local.Description, Upstream: up.Description})
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(local.Icon) != strings.TrimSpace(up.Icon) {
|
||||||
|
fields = append(fields, conflictField{Field: "icon", Local: local.Icon, Upstream: up.Icon})
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(local.Tags) != strings.TrimSpace(up.Tags) {
|
||||||
|
fields = append(fields, conflictField{Field: "tags", Local: local.Tags, Upstream: up.Tags})
|
||||||
|
}
|
||||||
|
// vendor 对比使用名称
|
||||||
|
localVendor := idToVendorName[local.VendorID]
|
||||||
|
if strings.TrimSpace(localVendor) != strings.TrimSpace(up.VendorName) {
|
||||||
|
fields = append(fields, conflictField{Field: "vendor", Local: localVendor, Upstream: up.VendorName})
|
||||||
|
}
|
||||||
|
if local.NameRule != up.NameRule {
|
||||||
|
fields = append(fields, conflictField{Field: "name_rule", Local: local.NameRule, Upstream: up.NameRule})
|
||||||
|
}
|
||||||
|
if local.Status != chooseStatus(up.Status, local.Status) {
|
||||||
|
fields = append(fields, conflictField{Field: "status", Local: local.Status, Upstream: up.Status})
|
||||||
|
}
|
||||||
|
if len(fields) > 0 {
|
||||||
|
conflicts = append(conflicts, conflictItem{ModelName: local.ModelName, Fields: fields})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": gin.H{
|
||||||
|
"missing": missing,
|
||||||
|
"conflicts": conflicts,
|
||||||
|
"source": gin.H{
|
||||||
|
"locale": locale,
|
||||||
|
"models_url": modelsURL,
|
||||||
|
"vendors_url": vendorsURL,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
|
||||||
"one-api/setting/system_setting"
|
"one-api/setting/system_setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -45,7 +44,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
|
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
|
||||||
values.Set("code", code)
|
values.Set("code", code)
|
||||||
values.Set("grant_type", "authorization_code")
|
values.Set("grant_type", "authorization_code")
|
||||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
|
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
|
||||||
formData := values.Encode()
|
formData := values.Encode()
|
||||||
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
|
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -69,7 +68,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if oidcResponse.AccessToken == "" {
|
if oidcResponse.AccessToken == "" {
|
||||||
common.SysError("OIDC 获取 Token 失败,请检查设置!")
|
common.SysLog("OIDC 获取 Token 失败,请检查设置!")
|
||||||
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,7 +84,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
}
|
}
|
||||||
defer res2.Body.Close()
|
defer res2.Body.Close()
|
||||||
if res2.StatusCode != http.StatusOK {
|
if res2.StatusCode != http.StatusOK {
|
||||||
common.SysError("OIDC 获取用户信息失败!请检查设置!")
|
common.SysLog("OIDC 获取用户信息失败!请检查设置!")
|
||||||
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,7 +94,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
||||||
common.SysError("OIDC 获取用户信息为空!请检查设置!")
|
common.SysLog("OIDC 获取用户信息为空!请检查设置!")
|
||||||
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
||||||
}
|
}
|
||||||
return &oidcUser, nil
|
return &oidcUser, nil
|
||||||
@@ -126,10 +125,7 @@ func OidcAuth(c *gin.Context) {
|
|||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
oidcUser, err := getOidcUserInfoByCode(code)
|
oidcUser, err := getOidcUserInfoByCode(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := model.User{
|
user := model.User{
|
||||||
@@ -195,10 +191,7 @@ func OidcBind(c *gin.Context) {
|
|||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
oidcUser, err := getOidcUserInfoByCode(code)
|
oidcUser, err := getOidcUserInfoByCode(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := model.User{
|
user := model.User{
|
||||||
@@ -217,19 +210,13 @@ func OidcBind(c *gin.Context) {
|
|||||||
user.Id = id.(int)
|
user.Id = id.(int)
|
||||||
err = user.FillUserById()
|
err = user.FillUserById()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.OidcId = oidcUser.OpenID
|
user.OidcId = oidcUser.OpenID
|
||||||
err = user.Update(false)
|
err = user.Update(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
@@ -35,8 +36,13 @@ func GetOptions(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OptionUpdateRequest struct {
|
||||||
|
Key string `json:"key"`
|
||||||
|
Value any `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
func UpdateOption(c *gin.Context) {
|
func UpdateOption(c *gin.Context) {
|
||||||
var option model.Option
|
var option OptionUpdateRequest
|
||||||
err := json.NewDecoder(c.Request.Body).Decode(&option)
|
err := json.NewDecoder(c.Request.Body).Decode(&option)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
@@ -45,6 +51,16 @@ func UpdateOption(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
switch option.Value.(type) {
|
||||||
|
case bool:
|
||||||
|
option.Value = common.Interface2String(option.Value.(bool))
|
||||||
|
case float64:
|
||||||
|
option.Value = common.Interface2String(option.Value.(float64))
|
||||||
|
case int:
|
||||||
|
option.Value = common.Interface2String(option.Value.(int))
|
||||||
|
default:
|
||||||
|
option.Value = fmt.Sprintf("%v", option.Value)
|
||||||
|
}
|
||||||
switch option.Key {
|
switch option.Key {
|
||||||
case "GitHubOAuthEnabled":
|
case "GitHubOAuthEnabled":
|
||||||
if option.Value == "true" && common.GitHubClientId == "" {
|
if option.Value == "true" && common.GitHubClientId == "" {
|
||||||
@@ -104,7 +120,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "GroupRatio":
|
case "GroupRatio":
|
||||||
err = ratio_setting.CheckGroupRatio(option.Value)
|
err = ratio_setting.CheckGroupRatio(option.Value.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -112,8 +128,35 @@ func UpdateOption(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
case "ImageRatio":
|
||||||
|
err = ratio_setting.UpdateImageRatioByJSONString(option.Value.(string))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "图片倍率设置失败: " + err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "AudioRatio":
|
||||||
|
err = ratio_setting.UpdateAudioRatioByJSONString(option.Value.(string))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "音频倍率设置失败: " + err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "AudioCompletionRatio":
|
||||||
|
err = ratio_setting.UpdateAudioCompletionRatioByJSONString(option.Value.(string))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "音频补全倍率设置失败: " + err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
case "ModelRequestRateLimitGroup":
|
case "ModelRequestRateLimitGroup":
|
||||||
err = setting.CheckModelRequestRateLimitGroup(option.Value)
|
err = setting.CheckModelRequestRateLimitGroup(option.Value.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -122,7 +165,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "console_setting.api_info":
|
case "console_setting.api_info":
|
||||||
err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
|
err = console_setting.ValidateConsoleSettings(option.Value.(string), "ApiInfo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -131,7 +174,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "console_setting.announcements":
|
case "console_setting.announcements":
|
||||||
err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
|
err = console_setting.ValidateConsoleSettings(option.Value.(string), "Announcements")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -140,7 +183,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "console_setting.faq":
|
case "console_setting.faq":
|
||||||
err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
|
err = console_setting.ValidateConsoleSettings(option.Value.(string), "FAQ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -149,7 +192,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "console_setting.uptime_kuma_groups":
|
case "console_setting.uptime_kuma_groups":
|
||||||
err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
|
err = console_setting.ValidateConsoleSettings(option.Value.(string), "UptimeKumaGroups")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -158,12 +201,9 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = model.UpdateOption(option.Key, option.Value)
|
err = model.UpdateOption(option.Key, option.Value.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -5,10 +5,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -28,57 +26,35 @@ func Playground(c *gin.Context) {
|
|||||||
|
|
||||||
useAccessToken := c.GetBool("use_access_token")
|
useAccessToken := c.GetBool("use_access_token")
|
||||||
if useAccessToken {
|
if useAccessToken {
|
||||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
|
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
playgroundRequest := &dto.PlayGroundRequest{}
|
group := c.GetString("group")
|
||||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
modelName := c.GetString("original_model")
|
||||||
if err != nil {
|
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if playgroundRequest.Model == "" {
|
|
||||||
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Set("original_model", playgroundRequest.Model)
|
|
||||||
group := playgroundRequest.Group
|
|
||||||
userGroup := c.GetString("group")
|
|
||||||
|
|
||||||
if group == "" {
|
|
||||||
group = userGroup
|
|
||||||
} else {
|
|
||||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
|
||||||
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Set("group", group)
|
|
||||||
}
|
|
||||||
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
//c.Set("token_name", "playground-"+group)
|
|
||||||
|
// Write user context to ensure acceptUnsetRatio is available
|
||||||
|
userCache, err := model.GetUserCache(userId)
|
||||||
|
if err != nil {
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userCache.WriteContext(c)
|
||||||
|
|
||||||
tempToken := &model.Token{
|
tempToken := &model.Token{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
Name: fmt.Sprintf("playground-%s", group),
|
Name: fmt.Sprintf("playground-%s", group),
|
||||||
Group: group,
|
Group: group,
|
||||||
}
|
}
|
||||||
_ = middleware.SetupContextForToken(c, tempToken)
|
_ = middleware.SetupContextForToken(c, tempToken)
|
||||||
_, err = getChannel(c, group, playgroundRequest.Model, 0)
|
_, newAPIError = getChannel(c, group, modelName, 0)
|
||||||
if err != nil {
|
if newAPIError != nil {
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||||
|
|
||||||
// Write user context to ensure acceptUnsetRatio is available
|
Relay(c, types.RelayFormatOpenAI)
|
||||||
userCache, err := model.GetUserCache(userId)
|
|
||||||
if err != nil {
|
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
userCache.WriteContext(c)
|
|
||||||
Relay(c)
|
|
||||||
}
|
}
|
||||||
|
|||||||
90
controller/prefill_group.go
Normal file
90
controller/prefill_group.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤
|
||||||
|
func GetPrefillGroups(c *gin.Context) {
|
||||||
|
groupType := c.Query("type")
|
||||||
|
groups, err := model.GetAllPrefillGroups(groupType)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, groups)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePrefillGroup 创建新的预填组
|
||||||
|
func CreatePrefillGroup(c *gin.Context) {
|
||||||
|
var g model.PrefillGroup
|
||||||
|
if err := c.ShouldBindJSON(&g); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if g.Name == "" || g.Type == "" {
|
||||||
|
common.ApiErrorMsg(c, "组名称和类型不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 创建前检查名称
|
||||||
|
if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "组名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.Insert(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, &g)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePrefillGroup 更新预填组
|
||||||
|
func UpdatePrefillGroup(c *gin.Context) {
|
||||||
|
var g model.PrefillGroup
|
||||||
|
if err := c.ShouldBindJSON(&g); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if g.Id == 0 {
|
||||||
|
common.ApiErrorMsg(c, "缺少组 ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 名称冲突检查
|
||||||
|
if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "组名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.Update(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, &g)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePrefillGroup 删除预填组
|
||||||
|
func DeletePrefillGroup(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := model.DeletePrefillGroupByID(id); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, nil)
|
||||||
|
}
|
||||||
@@ -39,10 +39,13 @@ func GetPricing(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": pricing,
|
"data": pricing,
|
||||||
"group_ratio": groupRatio,
|
"vendors": model.GetVendors(),
|
||||||
"usable_group": usableGroup,
|
"group_ratio": groupRatio,
|
||||||
|
"usable_group": usableGroup,
|
||||||
|
"supported_endpoint": model.GetSupportedEndpointMap(),
|
||||||
|
"auto_groups": setting.AutoGroups,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,24 +1,24 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/setting/ratio_setting"
|
"one-api/setting/ratio_setting"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetRatioConfig(c *gin.Context) {
|
func GetRatioConfig(c *gin.Context) {
|
||||||
if !ratio_setting.IsExposeRatioEnabled() {
|
if !ratio_setting.IsExposeRatioEnabled() {
|
||||||
c.JSON(http.StatusForbidden, gin.H{
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "倍率配置接口未启用",
|
"message": "倍率配置接口未启用",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": ratio_setting.GetExposedData(),
|
"data": ratio_setting.GetExposedData(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,474 +1,539 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"io"
|
||||||
"strings"
|
"net"
|
||||||
"sync"
|
"net/http"
|
||||||
"time"
|
"one-api/logger"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"one-api/common"
|
"one-api/dto"
|
||||||
"one-api/dto"
|
"one-api/model"
|
||||||
"one-api/model"
|
"one-api/setting/ratio_setting"
|
||||||
"one-api/setting/ratio_setting"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultTimeoutSeconds = 10
|
defaultTimeoutSeconds = 10
|
||||||
defaultEndpoint = "/api/ratio_config"
|
defaultEndpoint = "/api/ratio_config"
|
||||||
maxConcurrentFetches = 8
|
maxConcurrentFetches = 8
|
||||||
|
maxRatioConfigBytes = 10 << 20 // 10MB
|
||||||
|
floatEpsilon = 1e-9
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func nearlyEqual(a, b float64) bool {
|
||||||
|
if a > b {
|
||||||
|
return a-b < floatEpsilon
|
||||||
|
}
|
||||||
|
return b-a < floatEpsilon
|
||||||
|
}
|
||||||
|
|
||||||
|
func valuesEqual(a, b interface{}) bool {
|
||||||
|
af, aok := a.(float64)
|
||||||
|
bf, bok := b.(float64)
|
||||||
|
if aok && bok {
|
||||||
|
return nearlyEqual(af, bf)
|
||||||
|
}
|
||||||
|
return a == b
|
||||||
|
}
|
||||||
|
|
||||||
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||||
|
|
||||||
type upstreamResult struct {
|
type upstreamResult struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Data map[string]any `json:"data,omitempty"`
|
Data map[string]any `json:"data,omitempty"`
|
||||||
Err string `json:"err,omitempty"`
|
Err string `json:"err,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func FetchUpstreamRatios(c *gin.Context) {
|
func FetchUpstreamRatios(c *gin.Context) {
|
||||||
var req dto.UpstreamRequest
|
var req dto.UpstreamRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Timeout <= 0 {
|
if req.Timeout <= 0 {
|
||||||
req.Timeout = defaultTimeoutSeconds
|
req.Timeout = defaultTimeoutSeconds
|
||||||
}
|
}
|
||||||
|
|
||||||
var upstreams []dto.UpstreamDTO
|
var upstreams []dto.UpstreamDTO
|
||||||
|
|
||||||
if len(req.Upstreams) > 0 {
|
if len(req.Upstreams) > 0 {
|
||||||
for _, u := range req.Upstreams {
|
for _, u := range req.Upstreams {
|
||||||
if strings.HasPrefix(u.BaseURL, "http") {
|
if strings.HasPrefix(u.BaseURL, "http") {
|
||||||
if u.Endpoint == "" {
|
if u.Endpoint == "" {
|
||||||
u.Endpoint = defaultEndpoint
|
u.Endpoint = defaultEndpoint
|
||||||
}
|
}
|
||||||
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
|
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
|
||||||
upstreams = append(upstreams, u)
|
upstreams = append(upstreams, u)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if len(req.ChannelIDs) > 0 {
|
} else if len(req.ChannelIDs) > 0 {
|
||||||
intIds := make([]int, 0, len(req.ChannelIDs))
|
intIds := make([]int, 0, len(req.ChannelIDs))
|
||||||
for _, id64 := range req.ChannelIDs {
|
for _, id64 := range req.ChannelIDs {
|
||||||
intIds = append(intIds, int(id64))
|
intIds = append(intIds, int(id64))
|
||||||
}
|
}
|
||||||
dbChannels, err := model.GetChannelsByIds(intIds)
|
dbChannels, err := model.GetChannelsByIds(intIds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, ch := range dbChannels {
|
for _, ch := range dbChannels {
|
||||||
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
||||||
upstreams = append(upstreams, dto.UpstreamDTO{
|
upstreams = append(upstreams, dto.UpstreamDTO{
|
||||||
ID: ch.Id,
|
ID: ch.Id,
|
||||||
Name: ch.Name,
|
Name: ch.Name,
|
||||||
BaseURL: strings.TrimRight(base, "/"),
|
BaseURL: strings.TrimRight(base, "/"),
|
||||||
Endpoint: "",
|
Endpoint: "",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(upstreams) == 0 {
|
if len(upstreams) == 0 {
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
ch := make(chan upstreamResult, len(upstreams))
|
ch := make(chan upstreamResult, len(upstreams))
|
||||||
|
|
||||||
sem := make(chan struct{}, maxConcurrentFetches)
|
sem := make(chan struct{}, maxConcurrentFetches)
|
||||||
|
|
||||||
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||||
|
transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second}
|
||||||
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
host = addr
|
||||||
|
}
|
||||||
|
// 对 github.io 优先尝试 IPv4,失败则回退 IPv6
|
||||||
|
if strings.HasSuffix(host, "github.io") {
|
||||||
|
if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, "tcp6", addr)
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
client := &http.Client{Transport: transport}
|
||||||
|
|
||||||
for _, chn := range upstreams {
|
for _, chn := range upstreams {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(chItem dto.UpstreamDTO) {
|
go func(chItem dto.UpstreamDTO) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
sem <- struct{}{}
|
sem <- struct{}{}
|
||||||
defer func() { <-sem }()
|
defer func() { <-sem }()
|
||||||
|
|
||||||
endpoint := chItem.Endpoint
|
endpoint := chItem.Endpoint
|
||||||
if endpoint == "" {
|
var fullURL string
|
||||||
endpoint = defaultEndpoint
|
if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
|
||||||
} else if !strings.HasPrefix(endpoint, "/") {
|
fullURL = endpoint
|
||||||
endpoint = "/" + endpoint
|
} else {
|
||||||
}
|
if endpoint == "" {
|
||||||
fullURL := chItem.BaseURL + endpoint
|
endpoint = defaultEndpoint
|
||||||
|
} else if !strings.HasPrefix(endpoint, "/") {
|
||||||
|
endpoint = "/" + endpoint
|
||||||
|
}
|
||||||
|
fullURL = chItem.BaseURL + endpoint
|
||||||
|
}
|
||||||
|
|
||||||
uniqueName := chItem.Name
|
uniqueName := chItem.Name
|
||||||
if chItem.ID != 0 {
|
if chItem.ID != 0 {
|
||||||
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
|
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
||||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := client.Do(httpReq)
|
// 简单重试:最多 3 次,指数退避
|
||||||
if err != nil {
|
var resp *http.Response
|
||||||
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
var lastErr error
|
||||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
for attempt := 0; attempt < 3; attempt++ {
|
||||||
return
|
resp, lastErr = client.Do(httpReq)
|
||||||
}
|
if lastErr == nil {
|
||||||
defer resp.Body.Close()
|
break
|
||||||
if resp.StatusCode != http.StatusOK {
|
}
|
||||||
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
|
||||||
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
}
|
||||||
return
|
if lastErr != nil {
|
||||||
}
|
logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+lastErr.Error())
|
||||||
// 兼容两种上游接口格式:
|
ch <- upstreamResult{Name: uniqueName, Err: lastErr.Error()}
|
||||||
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
return
|
||||||
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
}
|
||||||
var body struct {
|
defer resp.Body.Close()
|
||||||
Success bool `json:"success"`
|
if resp.StatusCode != http.StatusOK {
|
||||||
Data json.RawMessage `json:"data"`
|
logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
||||||
Message string `json:"message"`
|
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
||||||
}
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
// Content-Type 和响应体大小校验
|
||||||
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
if ct := resp.Header.Get("Content-Type"); ct != "" && !strings.Contains(strings.ToLower(ct), "application/json") {
|
||||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct)
|
||||||
return
|
}
|
||||||
}
|
limited := io.LimitReader(resp.Body, maxRatioConfigBytes)
|
||||||
|
// 兼容两种上游接口格式:
|
||||||
|
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
||||||
|
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
||||||
|
var body struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data json.RawMessage `json:"data"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
if !body.Success {
|
if err := json.NewDecoder(limited).Decode(&body); err != nil {
|
||||||
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
|
logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||||
return
|
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||||
}
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 尝试按 type1 解析
|
if !body.Success {
|
||||||
var type1Data map[string]any
|
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
|
||||||
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
return
|
||||||
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
}
|
||||||
isType1 := false
|
|
||||||
for _, rt := range ratioTypes {
|
|
||||||
if _, ok := type1Data[rt]; ok {
|
|
||||||
isType1 = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if isType1 {
|
|
||||||
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
// 若 Data 为空,将继续按 type1 尝试解析(与多数静态 ratio_config 兼容)
|
||||||
var pricingItems []struct {
|
|
||||||
ModelName string `json:"model_name"`
|
|
||||||
QuotaType int `json:"quota_type"`
|
|
||||||
ModelRatio float64 `json:"model_ratio"`
|
|
||||||
ModelPrice float64 `json:"model_price"`
|
|
||||||
CompletionRatio float64 `json:"completion_ratio"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
|
|
||||||
common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
|
||||||
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
modelRatioMap := make(map[string]float64)
|
// 尝试按 type1 解析
|
||||||
completionRatioMap := make(map[string]float64)
|
var type1Data map[string]any
|
||||||
modelPriceMap := make(map[string]float64)
|
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
||||||
|
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
||||||
|
isType1 := false
|
||||||
|
for _, rt := range ratioTypes {
|
||||||
|
if _, ok := type1Data[rt]; ok {
|
||||||
|
isType1 = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isType1 {
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, item := range pricingItems {
|
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
||||||
if item.QuotaType == 1 {
|
var pricingItems []struct {
|
||||||
modelPriceMap[item.ModelName] = item.ModelPrice
|
ModelName string `json:"model_name"`
|
||||||
} else {
|
QuotaType int `json:"quota_type"`
|
||||||
modelRatioMap[item.ModelName] = item.ModelRatio
|
ModelRatio float64 `json:"model_ratio"`
|
||||||
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
ModelPrice float64 `json:"model_price"`
|
||||||
completionRatioMap[item.ModelName] = item.CompletionRatio
|
CompletionRatio float64 `json:"completion_ratio"`
|
||||||
}
|
}
|
||||||
}
|
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
|
||||||
|
logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
converted := make(map[string]any)
|
modelRatioMap := make(map[string]float64)
|
||||||
|
completionRatioMap := make(map[string]float64)
|
||||||
|
modelPriceMap := make(map[string]float64)
|
||||||
|
|
||||||
if len(modelRatioMap) > 0 {
|
for _, item := range pricingItems {
|
||||||
ratioAny := make(map[string]any, len(modelRatioMap))
|
if item.QuotaType == 1 {
|
||||||
for k, v := range modelRatioMap {
|
modelPriceMap[item.ModelName] = item.ModelPrice
|
||||||
ratioAny[k] = v
|
} else {
|
||||||
}
|
modelRatioMap[item.ModelName] = item.ModelRatio
|
||||||
converted["model_ratio"] = ratioAny
|
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
||||||
}
|
completionRatioMap[item.ModelName] = item.CompletionRatio
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(completionRatioMap) > 0 {
|
converted := make(map[string]any)
|
||||||
compAny := make(map[string]any, len(completionRatioMap))
|
|
||||||
for k, v := range completionRatioMap {
|
|
||||||
compAny[k] = v
|
|
||||||
}
|
|
||||||
converted["completion_ratio"] = compAny
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(modelPriceMap) > 0 {
|
if len(modelRatioMap) > 0 {
|
||||||
priceAny := make(map[string]any, len(modelPriceMap))
|
ratioAny := make(map[string]any, len(modelRatioMap))
|
||||||
for k, v := range modelPriceMap {
|
for k, v := range modelRatioMap {
|
||||||
priceAny[k] = v
|
ratioAny[k] = v
|
||||||
}
|
}
|
||||||
converted["model_price"] = priceAny
|
converted["model_ratio"] = ratioAny
|
||||||
}
|
}
|
||||||
|
|
||||||
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
if len(completionRatioMap) > 0 {
|
||||||
}(chn)
|
compAny := make(map[string]any, len(completionRatioMap))
|
||||||
}
|
for k, v := range completionRatioMap {
|
||||||
|
compAny[k] = v
|
||||||
|
}
|
||||||
|
converted["completion_ratio"] = compAny
|
||||||
|
}
|
||||||
|
|
||||||
wg.Wait()
|
if len(modelPriceMap) > 0 {
|
||||||
close(ch)
|
priceAny := make(map[string]any, len(modelPriceMap))
|
||||||
|
for k, v := range modelPriceMap {
|
||||||
|
priceAny[k] = v
|
||||||
|
}
|
||||||
|
converted["model_price"] = priceAny
|
||||||
|
}
|
||||||
|
|
||||||
localData := ratio_setting.GetExposedData()
|
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
||||||
|
}(chn)
|
||||||
|
}
|
||||||
|
|
||||||
var testResults []dto.TestResult
|
wg.Wait()
|
||||||
var successfulChannels []struct {
|
close(ch)
|
||||||
name string
|
|
||||||
data map[string]any
|
|
||||||
}
|
|
||||||
|
|
||||||
for r := range ch {
|
localData := ratio_setting.GetExposedData()
|
||||||
if r.Err != "" {
|
|
||||||
testResults = append(testResults, dto.TestResult{
|
|
||||||
Name: r.Name,
|
|
||||||
Status: "error",
|
|
||||||
Error: r.Err,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
testResults = append(testResults, dto.TestResult{
|
|
||||||
Name: r.Name,
|
|
||||||
Status: "success",
|
|
||||||
})
|
|
||||||
successfulChannels = append(successfulChannels, struct {
|
|
||||||
name string
|
|
||||||
data map[string]any
|
|
||||||
}{name: r.Name, data: r.Data})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
differences := buildDifferences(localData, successfulChannels)
|
var testResults []dto.TestResult
|
||||||
|
var successfulChannels []struct {
|
||||||
|
name string
|
||||||
|
data map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
for r := range ch {
|
||||||
"success": true,
|
if r.Err != "" {
|
||||||
"data": gin.H{
|
testResults = append(testResults, dto.TestResult{
|
||||||
"differences": differences,
|
Name: r.Name,
|
||||||
"test_results": testResults,
|
Status: "error",
|
||||||
},
|
Error: r.Err,
|
||||||
})
|
})
|
||||||
|
} else {
|
||||||
|
testResults = append(testResults, dto.TestResult{
|
||||||
|
Name: r.Name,
|
||||||
|
Status: "success",
|
||||||
|
})
|
||||||
|
successfulChannels = append(successfulChannels, struct {
|
||||||
|
name string
|
||||||
|
data map[string]any
|
||||||
|
}{name: r.Name, data: r.Data})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
differences := buildDifferences(localData, successfulChannels)
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": gin.H{
|
||||||
|
"differences": differences,
|
||||||
|
"test_results": testResults,
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildDifferences(localData map[string]any, successfulChannels []struct {
|
func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||||
name string
|
name string
|
||||||
data map[string]any
|
data map[string]any
|
||||||
}) map[string]map[string]dto.DifferenceItem {
|
}) map[string]map[string]dto.DifferenceItem {
|
||||||
differences := make(map[string]map[string]dto.DifferenceItem)
|
differences := make(map[string]map[string]dto.DifferenceItem)
|
||||||
|
|
||||||
allModels := make(map[string]struct{})
|
allModels := make(map[string]struct{})
|
||||||
|
|
||||||
for _, ratioType := range ratioTypes {
|
|
||||||
if localRatioAny, ok := localData[ratioType]; ok {
|
|
||||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
|
||||||
for modelName := range localRatio {
|
|
||||||
allModels[modelName] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, channel := range successfulChannels {
|
|
||||||
for _, ratioType := range ratioTypes {
|
|
||||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
|
||||||
for modelName := range upstreamRatio {
|
|
||||||
allModels[modelName] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
confidenceMap := make(map[string]map[string]bool)
|
for _, ratioType := range ratioTypes {
|
||||||
|
if localRatioAny, ok := localData[ratioType]; ok {
|
||||||
// 预处理阶段:检查pricing接口的可信度
|
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||||
for _, channel := range successfulChannels {
|
for modelName := range localRatio {
|
||||||
confidenceMap[channel.name] = make(map[string]bool)
|
allModels[modelName] = struct{}{}
|
||||||
|
}
|
||||||
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
}
|
||||||
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
}
|
||||||
|
}
|
||||||
if hasModelRatio && hasCompletionRatio {
|
|
||||||
// 遍历所有模型,检查是否满足不可信条件
|
|
||||||
for modelName := range allModels {
|
|
||||||
// 默认为可信
|
|
||||||
confidenceMap[channel.name][modelName] = true
|
|
||||||
|
|
||||||
// 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
|
|
||||||
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
|
||||||
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
|
||||||
// 转换为float64进行比较
|
|
||||||
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
|
||||||
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
|
||||||
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
|
||||||
confidenceMap[channel.name][modelName] = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// 如果不是从pricing接口获取的数据,则全部标记为可信
|
|
||||||
for modelName := range allModels {
|
|
||||||
confidenceMap[channel.name][modelName] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for modelName := range allModels {
|
for _, channel := range successfulChannels {
|
||||||
for _, ratioType := range ratioTypes {
|
for _, ratioType := range ratioTypes {
|
||||||
var localValue interface{} = nil
|
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||||
if localRatioAny, ok := localData[ratioType]; ok {
|
for modelName := range upstreamRatio {
|
||||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
allModels[modelName] = struct{}{}
|
||||||
if val, exists := localRatio[modelName]; exists {
|
}
|
||||||
localValue = val
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
upstreamValues := make(map[string]interface{})
|
confidenceMap := make(map[string]map[string]bool)
|
||||||
confidenceValues := make(map[string]bool)
|
|
||||||
hasUpstreamValue := false
|
|
||||||
hasDifference := false
|
|
||||||
|
|
||||||
for _, channel := range successfulChannels {
|
// 预处理阶段:检查pricing接口的可信度
|
||||||
var upstreamValue interface{} = nil
|
for _, channel := range successfulChannels {
|
||||||
|
confidenceMap[channel.name] = make(map[string]bool)
|
||||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
|
||||||
if val, exists := upstreamRatio[modelName]; exists {
|
|
||||||
upstreamValue = val
|
|
||||||
hasUpstreamValue = true
|
|
||||||
|
|
||||||
if localValue != nil && localValue != val {
|
|
||||||
hasDifference = true
|
|
||||||
} else if localValue == val {
|
|
||||||
upstreamValue = "same"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if upstreamValue == nil && localValue == nil {
|
|
||||||
upstreamValue = "same"
|
|
||||||
}
|
|
||||||
|
|
||||||
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
|
||||||
hasDifference = true
|
|
||||||
}
|
|
||||||
|
|
||||||
upstreamValues[channel.name] = upstreamValue
|
|
||||||
|
|
||||||
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
|
|
||||||
}
|
|
||||||
|
|
||||||
shouldInclude := false
|
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
||||||
|
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
||||||
if localValue != nil {
|
|
||||||
if hasDifference {
|
|
||||||
shouldInclude = true
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if hasUpstreamValue {
|
|
||||||
shouldInclude = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if shouldInclude {
|
if hasModelRatio && hasCompletionRatio {
|
||||||
if differences[modelName] == nil {
|
// 遍历所有模型,检查是否满足不可信条件
|
||||||
differences[modelName] = make(map[string]dto.DifferenceItem)
|
for modelName := range allModels {
|
||||||
}
|
// 默认为可信
|
||||||
differences[modelName][ratioType] = dto.DifferenceItem{
|
confidenceMap[channel.name][modelName] = true
|
||||||
Current: localValue,
|
|
||||||
Upstreams: upstreamValues,
|
|
||||||
Confidence: confidenceValues,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
channelHasDiff := make(map[string]bool)
|
// 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
|
||||||
for _, ratioMap := range differences {
|
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
||||||
for _, item := range ratioMap {
|
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
||||||
for chName, val := range item.Upstreams {
|
// 转换为float64进行比较
|
||||||
if val != nil && val != "same" {
|
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
||||||
channelHasDiff[chName] = true
|
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
||||||
}
|
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
||||||
}
|
confidenceMap[channel.name][modelName] = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 如果不是从pricing接口获取的数据,则全部标记为可信
|
||||||
|
for modelName := range allModels {
|
||||||
|
confidenceMap[channel.name][modelName] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for modelName, ratioMap := range differences {
|
for modelName := range allModels {
|
||||||
for ratioType, item := range ratioMap {
|
for _, ratioType := range ratioTypes {
|
||||||
for chName := range item.Upstreams {
|
var localValue interface{} = nil
|
||||||
if !channelHasDiff[chName] {
|
if localRatioAny, ok := localData[ratioType]; ok {
|
||||||
delete(item.Upstreams, chName)
|
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||||
delete(item.Confidence, chName)
|
if val, exists := localRatio[modelName]; exists {
|
||||||
}
|
localValue = val
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
allSame := true
|
upstreamValues := make(map[string]interface{})
|
||||||
for _, v := range item.Upstreams {
|
confidenceValues := make(map[string]bool)
|
||||||
if v != "same" {
|
hasUpstreamValue := false
|
||||||
allSame = false
|
hasDifference := false
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(item.Upstreams) == 0 || allSame {
|
|
||||||
delete(ratioMap, ratioType)
|
|
||||||
} else {
|
|
||||||
differences[modelName][ratioType] = item
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(ratioMap) == 0 {
|
for _, channel := range successfulChannels {
|
||||||
delete(differences, modelName)
|
var upstreamValue interface{} = nil
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return differences
|
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||||
|
if val, exists := upstreamRatio[modelName]; exists {
|
||||||
|
upstreamValue = val
|
||||||
|
hasUpstreamValue = true
|
||||||
|
|
||||||
|
if localValue != nil && !valuesEqual(localValue, val) {
|
||||||
|
hasDifference = true
|
||||||
|
} else if valuesEqual(localValue, val) {
|
||||||
|
upstreamValue = "same"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if upstreamValue == nil && localValue == nil {
|
||||||
|
upstreamValue = "same"
|
||||||
|
}
|
||||||
|
|
||||||
|
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
||||||
|
hasDifference = true
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamValues[channel.name] = upstreamValue
|
||||||
|
|
||||||
|
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldInclude := false
|
||||||
|
|
||||||
|
if localValue != nil {
|
||||||
|
if hasDifference {
|
||||||
|
shouldInclude = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if hasUpstreamValue {
|
||||||
|
shouldInclude = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldInclude {
|
||||||
|
if differences[modelName] == nil {
|
||||||
|
differences[modelName] = make(map[string]dto.DifferenceItem)
|
||||||
|
}
|
||||||
|
differences[modelName][ratioType] = dto.DifferenceItem{
|
||||||
|
Current: localValue,
|
||||||
|
Upstreams: upstreamValues,
|
||||||
|
Confidence: confidenceValues,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
channelHasDiff := make(map[string]bool)
|
||||||
|
for _, ratioMap := range differences {
|
||||||
|
for _, item := range ratioMap {
|
||||||
|
for chName, val := range item.Upstreams {
|
||||||
|
if val != nil && val != "same" {
|
||||||
|
channelHasDiff[chName] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for modelName, ratioMap := range differences {
|
||||||
|
for ratioType, item := range ratioMap {
|
||||||
|
for chName := range item.Upstreams {
|
||||||
|
if !channelHasDiff[chName] {
|
||||||
|
delete(item.Upstreams, chName)
|
||||||
|
delete(item.Confidence, chName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
allSame := true
|
||||||
|
for _, v := range item.Upstreams {
|
||||||
|
if v != "same" {
|
||||||
|
allSame = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(item.Upstreams) == 0 || allSame {
|
||||||
|
delete(ratioMap, ratioType)
|
||||||
|
} else {
|
||||||
|
differences[modelName][ratioType] = item
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ratioMap) == 0 {
|
||||||
|
delete(differences, modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return differences
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSyncableChannels(c *gin.Context) {
|
func GetSyncableChannels(c *gin.Context) {
|
||||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var syncableChannels []dto.SyncableChannel
|
var syncableChannels []dto.SyncableChannel
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
if channel.GetBaseURL() != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||||
ID: channel.Id,
|
ID: channel.Id,
|
||||||
Name: channel.Name,
|
Name: channel.Name,
|
||||||
BaseURL: channel.GetBaseURL(),
|
BaseURL: channel.GetBaseURL(),
|
||||||
Status: channel.Status,
|
Status: channel.Status,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||||
"success": true,
|
ID: -100,
|
||||||
"message": "",
|
Name: "官方倍率预设",
|
||||||
"data": syncableChannels,
|
BaseURL: "https://basellm.github.io",
|
||||||
})
|
Status: 1,
|
||||||
}
|
})
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": syncableChannels,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,91 +1,52 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"errors"
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAllRedemptions(c *gin.Context) {
|
func GetAllRedemptions(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
redemptions, total, err := model.GetAllRedemptions(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
if p < 0 {
|
|
||||||
p = 0
|
|
||||||
}
|
|
||||||
if pageSize < 1 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
redemptions, total, err := model.GetAllRedemptions((p-1)*pageSize, pageSize)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
pageInfo.SetTotal(int(total))
|
||||||
"success": true,
|
pageInfo.SetItems(redemptions)
|
||||||
"message": "",
|
common.ApiSuccess(c, pageInfo)
|
||||||
"data": gin.H{
|
|
||||||
"items": redemptions,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchRedemptions(c *gin.Context) {
|
func SearchRedemptions(c *gin.Context) {
|
||||||
keyword := c.Query("keyword")
|
keyword := c.Query("keyword")
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
redemptions, total, err := model.SearchRedemptions(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
if p < 0 {
|
|
||||||
p = 0
|
|
||||||
}
|
|
||||||
if pageSize < 1 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
redemptions, total, err := model.SearchRedemptions(keyword, (p-1)*pageSize, pageSize)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
pageInfo.SetTotal(int(total))
|
||||||
"success": true,
|
pageInfo.SetItems(redemptions)
|
||||||
"message": "",
|
common.ApiSuccess(c, pageInfo)
|
||||||
"data": gin.H{
|
|
||||||
"items": redemptions,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRedemption(c *gin.Context) {
|
func GetRedemption(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
redemption, err := model.GetRedemptionById(id)
|
redemption, err := model.GetRedemptionById(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -100,13 +61,10 @@ func AddRedemption(c *gin.Context) {
|
|||||||
redemption := model.Redemption{}
|
redemption := model.Redemption{}
|
||||||
err := c.ShouldBindJSON(&redemption)
|
err := c.ShouldBindJSON(&redemption)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
|
if utf8.RuneCountInString(redemption.Name) == 0 || utf8.RuneCountInString(redemption.Name) > 20 {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "兑换码名称长度必须在1-20之间",
|
"message": "兑换码名称长度必须在1-20之间",
|
||||||
@@ -165,10 +123,7 @@ func DeleteRedemption(c *gin.Context) {
|
|||||||
id, _ := strconv.Atoi(c.Param("id"))
|
id, _ := strconv.Atoi(c.Param("id"))
|
||||||
err := model.DeleteRedemptionById(id)
|
err := model.DeleteRedemptionById(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -183,18 +138,12 @@ func UpdateRedemption(c *gin.Context) {
|
|||||||
redemption := model.Redemption{}
|
redemption := model.Redemption{}
|
||||||
err := c.ShouldBindJSON(&redemption)
|
err := c.ShouldBindJSON(&redemption)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cleanRedemption, err := model.GetRedemptionById(redemption.Id)
|
cleanRedemption, err := model.GetRedemptionById(redemption.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if statusOnly == "" {
|
if statusOnly == "" {
|
||||||
@@ -212,10 +161,7 @@ func UpdateRedemption(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
err = cleanRedemption.Update()
|
err = cleanRedemption.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -229,16 +175,13 @@ func UpdateRedemption(c *gin.Context) {
|
|||||||
func DeleteInvalidRedemption(c *gin.Context) {
|
func DeleteInvalidRedemption(c *gin.Context) {
|
||||||
rows, err := model.DeleteInvalidRedemptions()
|
rows, err := model.DeleteInvalidRedemptions()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": rows,
|
"data": rows,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,115 +2,193 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
constant2 "one-api/constant"
|
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
|
||||||
var err *types.NewAPIError
|
var err *types.NewAPIError
|
||||||
switch relayMode {
|
switch info.RelayMode {
|
||||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||||
err = relay.ImageHelper(c)
|
err = relay.ImageHelper(c, info)
|
||||||
case relayconstant.RelayModeAudioSpeech:
|
case relayconstant.RelayModeAudioSpeech:
|
||||||
fallthrough
|
fallthrough
|
||||||
case relayconstant.RelayModeAudioTranslation:
|
case relayconstant.RelayModeAudioTranslation:
|
||||||
fallthrough
|
fallthrough
|
||||||
case relayconstant.RelayModeAudioTranscription:
|
case relayconstant.RelayModeAudioTranscription:
|
||||||
err = relay.AudioHelper(c)
|
err = relay.AudioHelper(c, info)
|
||||||
case relayconstant.RelayModeRerank:
|
case relayconstant.RelayModeRerank:
|
||||||
err = relay.RerankHelper(c, relayMode)
|
err = relay.RerankHelper(c, info)
|
||||||
case relayconstant.RelayModeEmbeddings:
|
case relayconstant.RelayModeEmbeddings:
|
||||||
err = relay.EmbeddingHelper(c)
|
err = relay.EmbeddingHelper(c, info)
|
||||||
case relayconstant.RelayModeResponses:
|
case relayconstant.RelayModeResponses:
|
||||||
err = relay.ResponsesHelper(c)
|
err = relay.ResponsesHelper(c, info)
|
||||||
case relayconstant.RelayModeGemini:
|
|
||||||
err = relay.GeminiHelper(c)
|
|
||||||
default:
|
default:
|
||||||
err = relay.TextHelper(c)
|
err = relay.TextHelper(c, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
if constant2.ErrorLogEnabled && err != nil {
|
|
||||||
// 保存错误日志到mysql中
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
tokenName := c.GetString("token_name")
|
|
||||||
modelName := c.GetString("original_model")
|
|
||||||
tokenId := c.GetInt("token_id")
|
|
||||||
userGroup := c.GetString("group")
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
other := make(map[string]interface{})
|
|
||||||
other["error_type"] = err.ErrorType
|
|
||||||
other["error_code"] = err.GetErrorCode()
|
|
||||||
other["status_code"] = err.StatusCode
|
|
||||||
other["channel_id"] = channelId
|
|
||||||
other["channel_name"] = c.GetString("channel_name")
|
|
||||||
other["channel_type"] = c.GetInt("channel_type")
|
|
||||||
|
|
||||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
|
||||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
var err *types.NewAPIError
|
||||||
|
if strings.Contains(c.Request.URL.Path, "embed") {
|
||||||
|
err = relay.GeminiEmbeddingHandler(c, info)
|
||||||
|
} else {
|
||||||
|
err = relay.GeminiHelper(c, info)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||||
|
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
group := c.GetString("group")
|
group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
||||||
var newAPIError *types.NewAPIError
|
|
||||||
|
var (
|
||||||
|
newAPIError *types.NewAPIError
|
||||||
|
ws *websocket.Conn
|
||||||
|
)
|
||||||
|
|
||||||
|
if relayFormat == types.RelayFormatOpenAIRealtime {
|
||||||
|
var err error
|
||||||
|
ws, err = upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if newAPIError != nil {
|
||||||
|
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||||
|
switch relayFormat {
|
||||||
|
case types.RelayFormatOpenAIRealtime:
|
||||||
|
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
||||||
|
case types.RelayFormatClaude:
|
||||||
|
c.JSON(newAPIError.StatusCode, gin.H{
|
||||||
|
"type": "error",
|
||||||
|
"error": newAPIError.ToClaudeError(),
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
c.JSON(newAPIError.StatusCode, gin.H{
|
||||||
|
"error": newAPIError.ToOpenAIError(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
request, err := helper.GetAndValidateRequest(c, relayFormat)
|
||||||
|
if err != nil {
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
|
||||||
|
if err != nil {
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := request.GetTokenCountMeta()
|
||||||
|
|
||||||
|
if setting.ShouldCheckPromptSensitive() {
|
||||||
|
contains, words := service.CheckSensitiveText(meta.CombineText)
|
||||||
|
if contains {
|
||||||
|
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := service.CountRequestToken(c, meta, relayInfo)
|
||||||
|
if err != nil {
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
relayInfo.SetPromptTokens(tokens)
|
||||||
|
|
||||||
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
|
||||||
|
if err != nil {
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
|
||||||
|
|
||||||
|
newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
|
if newAPIError != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
// Only return quota if downstream failed and quota was actually pre-consumed
|
||||||
|
if newAPIError != nil && relayInfo.FinalPreConsumedQuota != 0 {
|
||||||
|
service.ReturnPreConsumedQuota(c, relayInfo)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
for i := 0; i <= common.RetryTimes; i++ {
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, err.Error())
|
logger.LogError(c, err.Error())
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
newAPIError = err
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
newAPIError = relayRequest(c, relayMode, channel)
|
addUsedChannel(c, channel.Id)
|
||||||
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
|
||||||
if newAPIError == nil {
|
switch relayFormat {
|
||||||
return // 成功处理请求,直接返回
|
case types.RelayFormatOpenAIRealtime:
|
||||||
|
newAPIError = relay.WssHelper(c, relayInfo)
|
||||||
|
case types.RelayFormatClaude:
|
||||||
|
newAPIError = relay.ClaudeHelper(c, relayInfo)
|
||||||
|
case types.RelayFormatGemini:
|
||||||
|
newAPIError = geminiRelayHandler(c, relayInfo)
|
||||||
|
default:
|
||||||
|
newAPIError = relayHandler(c, relayInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
|
if newAPIError == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||||
|
|
||||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
if len(useChannel) > 1 {
|
if len(useChannel) > 1 {
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||||
common.LogInfo(c, retryLogStr)
|
logger.LogInfo(c, retryLogStr)
|
||||||
}
|
|
||||||
|
|
||||||
if newAPIError != nil {
|
|
||||||
if newAPIError.StatusCode == http.StatusTooManyRequests {
|
|
||||||
common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
|
|
||||||
newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
|
||||||
}
|
|
||||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
|
||||||
c.JSON(newAPIError.StatusCode, gin.H{
|
|
||||||
"error": newAPIError.ToOpenAIError(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,129 +199,13 @@ var upgrader = websocket.Upgrader{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func WssRelay(c *gin.Context) {
|
|
||||||
// 将 HTTP 连接升级为 WebSocket 连接
|
|
||||||
|
|
||||||
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
||||||
defer ws.Close()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
|
||||||
group := c.GetString("group")
|
|
||||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
|
||||||
originalModel := c.GetString("original_model")
|
|
||||||
var newAPIError *types.NewAPIError
|
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(c, err.Error())
|
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
newAPIError = wssRequest(c, ws, relayMode, channel)
|
|
||||||
|
|
||||||
if newAPIError == nil {
|
|
||||||
return // 成功处理请求,直接返回
|
|
||||||
}
|
|
||||||
|
|
||||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
|
|
||||||
|
|
||||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
|
||||||
if len(useChannel) > 1 {
|
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
|
||||||
common.LogInfo(c, retryLogStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if newAPIError != nil {
|
|
||||||
if newAPIError.StatusCode == http.StatusTooManyRequests {
|
|
||||||
newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
|
||||||
}
|
|
||||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
|
||||||
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func RelayClaude(c *gin.Context) {
|
|
||||||
//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
|
||||||
group := c.GetString("group")
|
|
||||||
originalModel := c.GetString("original_model")
|
|
||||||
var newAPIError *types.NewAPIError
|
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(c, err.Error())
|
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
newAPIError = claudeRequest(c, channel)
|
|
||||||
|
|
||||||
if newAPIError == nil {
|
|
||||||
return // 成功处理请求,直接返回
|
|
||||||
}
|
|
||||||
|
|
||||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
|
|
||||||
|
|
||||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
|
||||||
if len(useChannel) > 1 {
|
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
|
||||||
common.LogInfo(c, retryLogStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if newAPIError != nil {
|
|
||||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
|
||||||
c.JSON(newAPIError.StatusCode, gin.H{
|
|
||||||
"type": "error",
|
|
||||||
"error": newAPIError.ToClaudeError(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
|
|
||||||
addUsedChannel(c, channel.Id)
|
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
||||||
return relayHandler(c, relayMode)
|
|
||||||
}
|
|
||||||
|
|
||||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
|
|
||||||
addUsedChannel(c, channel.Id)
|
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
||||||
return relay.WssHelper(c, ws)
|
|
||||||
}
|
|
||||||
|
|
||||||
func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
|
|
||||||
addUsedChannel(c, channel.Id)
|
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
||||||
return relay.ClaudeHelper(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func addUsedChannel(c *gin.Context, channelId int) {
|
func addUsedChannel(c *gin.Context, channelId int) {
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||||
c.Set("use_channel", useChannel)
|
c.Set("use_channel", useChannel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
|
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
|
||||||
if retryCount == 0 {
|
if retryCount == 0 {
|
||||||
autoBan := c.GetBool("auto_ban")
|
autoBan := c.GetBool("auto_ban")
|
||||||
autoBanInt := 1
|
autoBanInt := 1
|
||||||
@@ -259,12 +221,15 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
|||||||
}
|
}
|
||||||
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if group == "auto" {
|
return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||||
return nil, errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error()))
|
}
|
||||||
}
|
if channel == nil {
|
||||||
return nil, errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error()))
|
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||||
|
}
|
||||||
|
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
if newAPIError != nil {
|
||||||
|
return nil, newAPIError
|
||||||
}
|
}
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
|
||||||
return channel, nil
|
return channel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -275,7 +240,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
|||||||
if types.IsChannelError(openaiErr) {
|
if types.IsChannelError(openaiErr) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if types.IsLocalError(openaiErr) {
|
if types.IsSkipRetryError(openaiErr) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if retryTimes <= 0 {
|
if retryTimes <= 0 {
|
||||||
@@ -298,10 +263,6 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||||
channelType := c.GetInt("channel_type")
|
|
||||||
if channelType == constant.ChannelTypeAnthropic {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if openaiErr.StatusCode == 408 {
|
if openaiErr.StatusCode == 408 {
|
||||||
@@ -314,45 +275,84 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *types.NewAPIError) {
|
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||||
|
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error()))
|
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||||
if service.ShouldDisableChannel(channelType, err) && autoBan {
|
gopool.Go(func() {
|
||||||
service.DisableChannel(channelId, channelName, err.Error())
|
service.DisableChannel(channelError, err.Error())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
|
||||||
|
// 保存错误日志到mysql中
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
tokenName := c.GetString("token_name")
|
||||||
|
modelName := c.GetString("original_model")
|
||||||
|
tokenId := c.GetInt("token_id")
|
||||||
|
userGroup := c.GetString("group")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
other := make(map[string]interface{})
|
||||||
|
other["error_type"] = err.GetErrorType()
|
||||||
|
other["error_code"] = err.GetErrorCode()
|
||||||
|
other["status_code"] = err.StatusCode
|
||||||
|
other["channel_id"] = channelId
|
||||||
|
other["channel_name"] = c.GetString("channel_name")
|
||||||
|
other["channel_type"] = c.GetInt("channel_type")
|
||||||
|
adminInfo := make(map[string]interface{})
|
||||||
|
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
||||||
|
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
||||||
|
if isMultiKey {
|
||||||
|
adminInfo["is_multi_key"] = true
|
||||||
|
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
||||||
|
}
|
||||||
|
other["admin_info"] = adminInfo
|
||||||
|
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourney(c *gin.Context) {
|
func RelayMidjourney(c *gin.Context) {
|
||||||
relayMode := c.GetInt("relay_mode")
|
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil)
|
||||||
var err *dto.MidjourneyResponse
|
|
||||||
switch relayMode {
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
|
"description": fmt.Sprintf("failed to generate relay info: %s", err.Error()),
|
||||||
|
"type": "upstream_error",
|
||||||
|
"code": 4,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var mjErr *dto.MidjourneyResponse
|
||||||
|
switch relayInfo.RelayMode {
|
||||||
case relayconstant.RelayModeMidjourneyNotify:
|
case relayconstant.RelayModeMidjourneyNotify:
|
||||||
err = relay.RelayMidjourneyNotify(c)
|
mjErr = relay.RelayMidjourneyNotify(c)
|
||||||
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
||||||
err = relay.RelayMidjourneyTask(c, relayMode)
|
mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode)
|
||||||
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
||||||
err = relay.RelayMidjourneyTaskImageSeed(c)
|
mjErr = relay.RelayMidjourneyTaskImageSeed(c)
|
||||||
case relayconstant.RelayModeSwapFace:
|
case relayconstant.RelayModeSwapFace:
|
||||||
err = relay.RelaySwapFace(c)
|
mjErr = relay.RelaySwapFace(c, relayInfo)
|
||||||
default:
|
default:
|
||||||
err = relay.RelayMidjourneySubmit(c, relayMode)
|
mjErr = relay.RelayMidjourneySubmit(c, relayInfo)
|
||||||
}
|
}
|
||||||
//err = relayMidjourneySubmit(c, relayMode)
|
//err = relayMidjourneySubmit(c, relayMode)
|
||||||
log.Println(err)
|
log.Println(mjErr)
|
||||||
if err != nil {
|
if mjErr != nil {
|
||||||
statusCode := http.StatusBadRequest
|
statusCode := http.StatusBadRequest
|
||||||
if err.Code == 30 {
|
if mjErr.Code == 30 {
|
||||||
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||||
statusCode = http.StatusTooManyRequests
|
statusCode = http.StatusTooManyRequests
|
||||||
}
|
}
|
||||||
c.JSON(statusCode, gin.H{
|
c.JSON(statusCode, gin.H{
|
||||||
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
|
"description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result),
|
||||||
"type": "upstream_error",
|
"type": "upstream_error",
|
||||||
"code": err.Code,
|
"code": mjErr.Code,
|
||||||
})
|
})
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -383,36 +383,39 @@ func RelayNotFound(c *gin.Context) {
|
|||||||
func RelayTask(c *gin.Context) {
|
func RelayTask(c *gin.Context) {
|
||||||
retryTimes := common.RetryTimes
|
retryTimes := common.RetryTimes
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
relayMode := c.GetInt("relay_mode")
|
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := c.GetString("original_model")
|
||||||
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
||||||
taskErr := taskRelayHandler(c, relayMode)
|
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
taskErr := taskRelayHandler(c, relayInfo)
|
||||||
if taskErr == nil {
|
if taskErr == nil {
|
||||||
retryTimes = 0
|
retryTimes = 0
|
||||||
}
|
}
|
||||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, newAPIError := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if newAPIError != nil {
|
||||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
||||||
taskErr = service.TaskErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
channelId = channel.Id
|
channelId = channel.Id
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||||
c.Set("use_channel", useChannel)
|
c.Set("use_channel", useChannel)
|
||||||
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||||
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
|
||||||
requestBody, err := common.GetRequestBody(c)
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
taskErr = taskRelayHandler(c, relayMode)
|
taskErr = taskRelayHandler(c, relayInfo)
|
||||||
}
|
}
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
if len(useChannel) > 1 {
|
if len(useChannel) > 1 {
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||||
common.LogInfo(c, retryLogStr)
|
logger.LogInfo(c, retryLogStr)
|
||||||
}
|
}
|
||||||
if taskErr != nil {
|
if taskErr != nil {
|
||||||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||||||
@@ -422,13 +425,13 @@ func RelayTask(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError {
|
||||||
var err *dto.TaskError
|
var err *dto.TaskError
|
||||||
switch relayMode {
|
switch relayInfo.RelayMode {
|
||||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
|
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
|
||||||
err = relay.RelayTaskFetch(c, relayMode)
|
err = relay.RelayTaskFetch(c, relayInfo.RelayMode)
|
||||||
default:
|
default:
|
||||||
err = relay.RelayTaskSubmit(c, relayMode)
|
err = relay.RelayTaskSubmit(c, relayInfo)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func GetSetup(c *gin.Context) {
|
|||||||
func PostSetup(c *gin.Context) {
|
func PostSetup(c *gin.Context) {
|
||||||
// Check if setup is already completed
|
// Check if setup is already completed
|
||||||
if constant.Setup {
|
if constant.Setup {
|
||||||
c.JSON(400, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "系统已经初始化完成",
|
"message": "系统已经初始化完成",
|
||||||
})
|
})
|
||||||
@@ -66,7 +66,7 @@ func PostSetup(c *gin.Context) {
|
|||||||
var req SetupRequest
|
var req SetupRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(400, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "请求参数有误",
|
"message": "请求参数有误",
|
||||||
})
|
})
|
||||||
@@ -77,7 +77,7 @@ func PostSetup(c *gin.Context) {
|
|||||||
if !rootExists {
|
if !rootExists {
|
||||||
// Validate username length: max 12 characters to align with model.User validation
|
// Validate username length: max 12 characters to align with model.User validation
|
||||||
if len(req.Username) > 12 {
|
if len(req.Username) > 12 {
|
||||||
c.JSON(400, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "用户名长度不能超过12个字符",
|
"message": "用户名长度不能超过12个字符",
|
||||||
})
|
})
|
||||||
@@ -85,7 +85,7 @@ func PostSetup(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
// Validate password
|
// Validate password
|
||||||
if req.Password != req.ConfirmPassword {
|
if req.Password != req.ConfirmPassword {
|
||||||
c.JSON(400, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "两次输入的密码不一致",
|
"message": "两次输入的密码不一致",
|
||||||
})
|
})
|
||||||
@@ -93,7 +93,7 @@ func PostSetup(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(req.Password) < 8 {
|
if len(req.Password) < 8 {
|
||||||
c.JSON(400, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "密码长度至少为8个字符",
|
"message": "密码长度至少为8个字符",
|
||||||
})
|
})
|
||||||
@@ -103,7 +103,7 @@ func PostSetup(c *gin.Context) {
|
|||||||
// Create root user
|
// Create root user
|
||||||
hashedPassword, err := common.Password2Hash(req.Password)
|
hashedPassword, err := common.Password2Hash(req.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(500, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "系统错误: " + err.Error(),
|
"message": "系统错误: " + err.Error(),
|
||||||
})
|
})
|
||||||
@@ -120,7 +120,7 @@ func PostSetup(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
err = model.DB.Create(&rootUser).Error
|
err = model.DB.Create(&rootUser).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(500, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "创建管理员账号失败: " + err.Error(),
|
"message": "创建管理员账号失败: " + err.Error(),
|
||||||
})
|
})
|
||||||
@@ -135,7 +135,7 @@ func PostSetup(c *gin.Context) {
|
|||||||
// Save operation modes to database for persistence
|
// Save operation modes to database for persistence
|
||||||
err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
|
err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(500, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "保存自用模式设置失败: " + err.Error(),
|
"message": "保存自用模式设置失败: " + err.Error(),
|
||||||
})
|
})
|
||||||
@@ -144,7 +144,7 @@ func PostSetup(c *gin.Context) {
|
|||||||
|
|
||||||
err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
|
err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(500, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "保存演示站点模式设置失败: " + err.Error(),
|
"message": "保存演示站点模式设置失败: " + err.Error(),
|
||||||
})
|
})
|
||||||
@@ -160,7 +160,7 @@ func PostSetup(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
err = model.DB.Create(&setup).Error
|
err = model.DB.Create(&setup).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(500, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "系统初始化失败: " + err.Error(),
|
"message": "系统初始化失败: " + err.Error(),
|
||||||
})
|
})
|
||||||
@@ -178,4 +178,4 @@ func boolToString(b bool) string {
|
|||||||
return "true"
|
return "true"
|
||||||
}
|
}
|
||||||
return "false"
|
return "false"
|
||||||
}
|
}
|
||||||
136
controller/swag_video.go
Normal file
136
controller/swag_video.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// VideoGenerations
|
||||||
|
// @Summary 生成视频
|
||||||
|
// @Description 调用视频生成接口生成视频
|
||||||
|
// @Description 支持多种视频生成服务:
|
||||||
|
// @Description - 可灵AI (Kling): https://app.klingai.com/cn/dev/document-api/apiReference/commonInfo
|
||||||
|
// @Description - 即梦 (Jimeng): https://www.volcengine.com/docs/85621/1538636
|
||||||
|
// @Tags Video
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||||
|
// @Param request body dto.VideoRequest true "视频生成请求参数"
|
||||||
|
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||||
|
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||||
|
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||||
|
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||||
|
// @Router /v1/video/generations [post]
|
||||||
|
func VideoGenerations(c *gin.Context) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoGenerationsTaskId
|
||||||
|
// @Summary 查询视频
|
||||||
|
// @Description 根据任务ID查询视频生成任务的状态和结果
|
||||||
|
// @Tags Video
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Security BearerAuth
|
||||||
|
// @Param task_id path string true "Task ID"
|
||||||
|
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||||
|
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||||
|
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||||
|
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||||
|
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||||
|
// @Router /v1/video/generations/{task_id} [get]
|
||||||
|
func VideoGenerationsTaskId(c *gin.Context) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// KlingText2VideoGenerations
|
||||||
|
// @Summary 可灵文生视频
|
||||||
|
// @Description 调用可灵AI文生视频接口,生成视频内容
|
||||||
|
// @Tags Video
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||||
|
// @Param request body KlingText2VideoRequest true "视频生成请求参数"
|
||||||
|
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||||
|
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||||
|
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||||
|
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||||
|
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||||
|
// @Router /kling/v1/videos/text2video [post]
|
||||||
|
func KlingText2VideoGenerations(c *gin.Context) {
|
||||||
|
}
|
||||||
|
|
||||||
|
type KlingText2VideoRequest struct {
|
||||||
|
ModelName string `json:"model_name,omitempty" example:"kling-v1"`
|
||||||
|
Prompt string `json:"prompt" binding:"required" example:"A cat playing piano in the garden"`
|
||||||
|
NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
|
||||||
|
CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
|
||||||
|
Mode string `json:"mode,omitempty" example:"std"`
|
||||||
|
CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
|
||||||
|
AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
|
||||||
|
Duration string `json:"duration,omitempty" example:"5"`
|
||||||
|
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
|
||||||
|
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-001"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KlingCameraControl struct {
|
||||||
|
Type string `json:"type,omitempty" example:"simple"`
|
||||||
|
Config *KlingCameraConfig `json:"config,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KlingCameraConfig struct {
|
||||||
|
Horizontal float64 `json:"horizontal,omitempty" example:"2.5"`
|
||||||
|
Vertical float64 `json:"vertical,omitempty" example:"0"`
|
||||||
|
Pan float64 `json:"pan,omitempty" example:"0"`
|
||||||
|
Tilt float64 `json:"tilt,omitempty" example:"0"`
|
||||||
|
Roll float64 `json:"roll,omitempty" example:"0"`
|
||||||
|
Zoom float64 `json:"zoom,omitempty" example:"0"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KlingImage2VideoGenerations
|
||||||
|
// @Summary 可灵官方-图生视频
|
||||||
|
// @Description 调用可灵AI图生视频接口,生成视频内容
|
||||||
|
// @Tags Video
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||||
|
// @Param request body KlingImage2VideoRequest true "图生视频请求参数"
|
||||||
|
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||||
|
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||||
|
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||||
|
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||||
|
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||||
|
// @Router /kling/v1/videos/image2video [post]
|
||||||
|
func KlingImage2VideoGenerations(c *gin.Context) {
|
||||||
|
}
|
||||||
|
|
||||||
|
type KlingImage2VideoRequest struct {
|
||||||
|
ModelName string `json:"model_name,omitempty" example:"kling-v2-master"`
|
||||||
|
Image string `json:"image" binding:"required" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"`
|
||||||
|
Prompt string `json:"prompt,omitempty" example:"A cat playing piano in the garden"`
|
||||||
|
NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
|
||||||
|
CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
|
||||||
|
Mode string `json:"mode,omitempty" example:"std"`
|
||||||
|
CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
|
||||||
|
AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
|
||||||
|
Duration string `json:"duration,omitempty" example:"5"`
|
||||||
|
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
|
||||||
|
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-002"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KlingImage2videoTaskId godoc
|
||||||
|
// @Summary 可灵任务查询--图生视频
|
||||||
|
// @Description Query the status and result of a Kling video generation task by task ID
|
||||||
|
// @Tags Origin
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param task_id path string true "Task ID"
|
||||||
|
// @Router /kling/v1/videos/image2video/{task_id} [get]
|
||||||
|
func KlingImage2videoTaskId(c *gin.Context) {}
|
||||||
|
|
||||||
|
// KlingText2videoTaskId godoc
|
||||||
|
// @Summary 可灵任务查询--文生视频
|
||||||
|
// @Description Query the status and result of a Kling text-to-video generation task by task ID
|
||||||
|
// @Tags Origin
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param task_id path string true "Task ID"
|
||||||
|
// @Router /kling/v1/videos/text2video/{task_id} [get]
|
||||||
|
func KlingText2videoTaskId(c *gin.Context) {}
|
||||||
@@ -5,18 +5,20 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/samber/lo"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/samber/lo"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UpdateTaskBulk() {
|
func UpdateTaskBulk() {
|
||||||
@@ -53,9 +55,9 @@ func UpdateTaskBulk() {
|
|||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
||||||
} else {
|
} else {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(taskChannelM) == 0 {
|
if len(taskChannelM) == 0 {
|
||||||
@@ -74,10 +76,10 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
|||||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||||
case constant.TaskPlatformSuno:
|
case constant.TaskPlatformSuno:
|
||||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||||
case constant.TaskPlatformKling, constant.TaskPlatformJimeng:
|
|
||||||
_ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM)
|
|
||||||
default:
|
default:
|
||||||
common.SysLog("未知平台")
|
if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
|
||||||
|
common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,14 +87,14 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM
|
|||||||
for channelId, taskIds := range taskChannelM {
|
for channelId, taskIds := range taskChannelM {
|
||||||
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
|
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
|
logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||||
if len(taskIds) == 0 {
|
if len(taskIds) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -105,7 +107,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -117,23 +119,23 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
"ids": taskIds,
|
"ids": taskIds,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
|
common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
|
common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
||||||
err = json.Unmarshal(responseBody, &responseItems)
|
err = json.Unmarshal(responseBody, &responseItems)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !responseItems.IsSuccess() {
|
if !responseItems.IsSuccess() {
|
||||||
@@ -153,19 +155,19 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
|
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
|
||||||
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
|
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
|
||||||
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
||||||
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
||||||
task.Progress = "100%"
|
task.Progress = "100%"
|
||||||
//err = model.CacheUpdateUserQuota(task.UserId) ?
|
//err = model.CacheUpdateUserQuota(task.UserId) ?
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
logger.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||||
} else {
|
} else {
|
||||||
quota := task.Quota
|
quota := task.Quota
|
||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
err = model.IncreaseUserQuota(task.UserId, quota, false)
|
err = model.IncreaseUserQuota(task.UserId, quota, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||||
}
|
}
|
||||||
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
|
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
|
||||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -177,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
|
|
||||||
err = task.Update()
|
err = task.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("UpdateMidjourneyTask task error: " + err.Error())
|
common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -225,14 +227,7 @@ func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetAllTask(c *gin.Context) {
|
func GetAllTask(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
if p < 1 {
|
|
||||||
p = 1
|
|
||||||
}
|
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
|
||||||
if pageSize <= 0 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
|
|
||||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||||
@@ -247,30 +242,15 @@ func GetAllTask(c *gin.Context) {
|
|||||||
ChannelID: c.Query("channel_id"),
|
ChannelID: c.Query("channel_id"),
|
||||||
}
|
}
|
||||||
|
|
||||||
items := model.TaskGetAllTasks((p-1)*pageSize, pageSize, queryParams)
|
items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||||
total := model.TaskCountAllTasks(queryParams)
|
total := model.TaskCountAllTasks(queryParams)
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
c.JSON(200, gin.H{
|
pageInfo.SetItems(items)
|
||||||
"success": true,
|
common.ApiSuccess(c, pageInfo)
|
||||||
"message": "",
|
|
||||||
"data": gin.H{
|
|
||||||
"items": items,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserTask(c *gin.Context) {
|
func GetUserTask(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
if p < 1 {
|
|
||||||
p = 1
|
|
||||||
}
|
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
|
||||||
if pageSize <= 0 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
@@ -286,17 +266,9 @@ func GetUserTask(c *gin.Context) {
|
|||||||
EndTimestamp: endTimestamp,
|
EndTimestamp: endTimestamp,
|
||||||
}
|
}
|
||||||
|
|
||||||
items := model.TaskGetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
|
items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||||
total := model.TaskCountAllUserTask(userId, queryParams)
|
total := model.TaskCountAllUserTask(userId, queryParams)
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
c.JSON(200, gin.H{
|
pageInfo.SetItems(items)
|
||||||
"success": true,
|
common.ApiSuccess(c, pageInfo)
|
||||||
"message": "",
|
|
||||||
"data": gin.H{
|
|
||||||
"items": items,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,27 +2,31 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||||
for channelId, taskIds := range taskChannelM {
|
for channelId, taskIds := range taskChannelM {
|
||||||
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||||
if len(taskIds) == 0 {
|
if len(taskIds) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -34,7 +38,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
|
|||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if errUpdate != nil {
|
if errUpdate != nil {
|
||||||
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||||
}
|
}
|
||||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -44,7 +48,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
|
|||||||
}
|
}
|
||||||
for _, taskId := range taskIds {
|
for _, taskId := range taskIds {
|
||||||
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -58,7 +62,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
|
|
||||||
task := taskM[taskId]
|
task := taskM[taskId]
|
||||||
if task == nil {
|
if task == nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||||
return fmt.Errorf("task %s not found", taskId)
|
return fmt.Errorf("task %s not found", taskId)
|
||||||
}
|
}
|
||||||
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||||
@@ -77,13 +81,21 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
taskResult, err := adaptor.ParseTaskResult(responseBody)
|
taskResult := &relaycommon.TaskInfo{}
|
||||||
if err != nil {
|
// try parse as New API response format
|
||||||
|
var responseItems dto.TaskResponse[model.Task]
|
||||||
|
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||||
|
t := responseItems.Data
|
||||||
|
taskResult.TaskID = t.TaskID
|
||||||
|
taskResult.Status = string(t.Status)
|
||||||
|
taskResult.Url = t.FailReason
|
||||||
|
taskResult.Progress = t.Progress
|
||||||
|
taskResult.Reason = t.FailReason
|
||||||
|
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||||
|
} else {
|
||||||
|
task.Data = redactVideoResponseBody(responseBody)
|
||||||
}
|
}
|
||||||
//if taskResult.Code != 0 {
|
|
||||||
// return fmt.Errorf("video task fetch failed for task %s", taskId)
|
|
||||||
//}
|
|
||||||
|
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
if taskResult.Status == "" {
|
if taskResult.Status == "" {
|
||||||
@@ -105,7 +117,9 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
if task.FinishTime == 0 {
|
if task.FinishTime == 0 {
|
||||||
task.FinishTime = now
|
task.FinishTime = now
|
||||||
}
|
}
|
||||||
task.FailReason = taskResult.Url
|
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
|
||||||
|
task.FailReason = taskResult.Url
|
||||||
|
}
|
||||||
case model.TaskStatusFailure:
|
case model.TaskStatusFailure:
|
||||||
task.Status = model.TaskStatusFailure
|
task.Status = model.TaskStatusFailure
|
||||||
task.Progress = "100%"
|
task.Progress = "100%"
|
||||||
@@ -113,13 +127,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
task.FinishTime = now
|
task.FinishTime = now
|
||||||
}
|
}
|
||||||
task.FailReason = taskResult.Reason
|
task.FailReason = taskResult.Reason
|
||||||
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||||
quota := task.Quota
|
quota := task.Quota
|
||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||||
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||||
}
|
}
|
||||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
|
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@@ -128,11 +142,43 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
if taskResult.Progress != "" {
|
if taskResult.Progress != "" {
|
||||||
task.Progress = taskResult.Progress
|
task.Progress = taskResult.Progress
|
||||||
}
|
}
|
||||||
|
|
||||||
task.Data = responseBody
|
|
||||||
if err := task.Update(); err != nil {
|
if err := task.Update(); err != nil {
|
||||||
common.SysError("UpdateVideoTask task error: " + err.Error())
|
common.SysLog("UpdateVideoTask task error: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func redactVideoResponseBody(body []byte) []byte {
|
||||||
|
var m map[string]any
|
||||||
|
if err := json.Unmarshal(body, &m); err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
resp, _ := m["response"].(map[string]any)
|
||||||
|
if resp != nil {
|
||||||
|
delete(resp, "bytesBase64Encoded")
|
||||||
|
if v, ok := resp["video"].(string); ok {
|
||||||
|
resp["video"] = truncateBase64(v)
|
||||||
|
}
|
||||||
|
if vs, ok := resp["videos"].([]any); ok {
|
||||||
|
for i := range vs {
|
||||||
|
if vm, ok := vs[i].(map[string]any); ok {
|
||||||
|
delete(vm, "bytesBase64Encoded")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateBase64(s string) string {
|
||||||
|
const maxKeep = 256
|
||||||
|
if len(s) <= maxKeep {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:maxKeep] + "..."
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,46 +1,27 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAllTokens(c *gin.Context) {
|
func GetAllTokens(c *gin.Context) {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
size, _ := strconv.Atoi(c.Query("size"))
|
tokens, err := model.GetAllUserTokens(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
if p < 1 {
|
|
||||||
p = 1
|
|
||||||
}
|
|
||||||
if size <= 0 {
|
|
||||||
size = common.ItemsPerPage
|
|
||||||
} else if size > 100 {
|
|
||||||
size = 100
|
|
||||||
}
|
|
||||||
tokens, err := model.GetAllUserTokens(userId, (p-1)*size, size)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Get total count for pagination
|
|
||||||
total, _ := model.CountUserTokens(userId)
|
total, _ := model.CountUserTokens(userId)
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
c.JSON(http.StatusOK, gin.H{
|
pageInfo.SetItems(tokens)
|
||||||
"success": true,
|
common.ApiSuccess(c, pageInfo)
|
||||||
"message": "",
|
|
||||||
"data": gin.H{
|
|
||||||
"items": tokens,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": size,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,10 +31,7 @@ func SearchTokens(c *gin.Context) {
|
|||||||
token := c.Query("token")
|
token := c.Query("token")
|
||||||
tokens, err := model.SearchUserTokens(userId, keyword, token)
|
tokens, err := model.SearchUserTokens(userId, keyword, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -68,18 +46,12 @@ func GetToken(c *gin.Context) {
|
|||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
token, err := model.GetTokenByIds(id, userId)
|
token, err := model.GetTokenByIds(id, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -95,10 +67,7 @@ func GetTokenStatus(c *gin.Context) {
|
|||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
token, err := model.GetTokenByIds(tokenId, userId)
|
token, err := model.GetTokenByIds(tokenId, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
expiredAt := token.ExpiredTime
|
expiredAt := token.ExpiredTime
|
||||||
@@ -114,9 +83,27 @@ func GetTokenStatus(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddToken(c *gin.Context) {
|
func GetTokenUsage(c *gin.Context) {
|
||||||
token := model.Token{}
|
authHeader := c.GetHeader("Authorization")
|
||||||
err := c.ShouldBindJSON(&token)
|
if authHeader == "" {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "No Authorization header",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(authHeader, " ")
|
||||||
|
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "Invalid Bearer token",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tokenKey := parts[1]
|
||||||
|
|
||||||
|
token, err := model.GetTokenByKey(strings.TrimPrefix(tokenKey, "sk-"), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -124,6 +111,36 @@ func AddToken(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expiredAt := token.ExpiredTime
|
||||||
|
if expiredAt == -1 {
|
||||||
|
expiredAt = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": true,
|
||||||
|
"message": "ok",
|
||||||
|
"data": gin.H{
|
||||||
|
"object": "token_usage",
|
||||||
|
"name": token.Name,
|
||||||
|
"total_granted": token.RemainQuota + token.UsedQuota,
|
||||||
|
"total_used": token.UsedQuota,
|
||||||
|
"total_available": token.RemainQuota,
|
||||||
|
"unlimited_quota": token.UnlimitedQuota,
|
||||||
|
"model_limits": token.GetModelLimitsMap(),
|
||||||
|
"model_limits_enabled": token.ModelLimitsEnabled,
|
||||||
|
"expires_at": expiredAt,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddToken(c *gin.Context) {
|
||||||
|
token := model.Token{}
|
||||||
|
err := c.ShouldBindJSON(&token)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
if len(token.Name) > 30 {
|
if len(token.Name) > 30 {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -137,7 +154,7 @@ func AddToken(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成令牌失败",
|
"message": "生成令牌失败",
|
||||||
})
|
})
|
||||||
common.SysError("failed to generate token key: " + err.Error())
|
common.SysLog("failed to generate token key: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cleanToken := model.Token{
|
cleanToken := model.Token{
|
||||||
@@ -156,10 +173,7 @@ func AddToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
err = cleanToken.Insert()
|
err = cleanToken.Insert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -174,10 +188,7 @@ func DeleteToken(c *gin.Context) {
|
|||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
err := model.DeleteTokenById(id, userId)
|
err := model.DeleteTokenById(id, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -193,10 +204,7 @@ func UpdateToken(c *gin.Context) {
|
|||||||
token := model.Token{}
|
token := model.Token{}
|
||||||
err := c.ShouldBindJSON(&token)
|
err := c.ShouldBindJSON(&token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(token.Name) > 30 {
|
if len(token.Name) > 30 {
|
||||||
@@ -208,10 +216,7 @@ func UpdateToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
cleanToken, err := model.GetTokenByIds(token.Id, userId)
|
cleanToken, err := model.GetTokenByIds(token.Id, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if token.Status == common.TokenStatusEnabled {
|
if token.Status == common.TokenStatusEnabled {
|
||||||
@@ -245,10 +250,7 @@ func UpdateToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
err = cleanToken.Update()
|
err = cleanToken.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -275,10 +277,7 @@ func DeleteTokenBatch(c *gin.Context) {
|
|||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
|
count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -5,9 +5,12 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/setting/operation_setting"
|
||||||
|
"one-api/setting/system_setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -18,6 +21,44 @@ import (
|
|||||||
"github.com/shopspring/decimal"
|
"github.com/shopspring/decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func GetTopUpInfo(c *gin.Context) {
|
||||||
|
// 获取支付方式
|
||||||
|
payMethods := operation_setting.PayMethods
|
||||||
|
|
||||||
|
// 如果启用了 Stripe 支付,添加到支付方法列表
|
||||||
|
if setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "" {
|
||||||
|
// 检查是否已经包含 Stripe
|
||||||
|
hasStripe := false
|
||||||
|
for _, method := range payMethods {
|
||||||
|
if method["type"] == "stripe" {
|
||||||
|
hasStripe = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasStripe {
|
||||||
|
stripeMethod := map[string]string{
|
||||||
|
"name": "Stripe",
|
||||||
|
"type": "stripe",
|
||||||
|
"color": "rgba(var(--semi-purple-5), 1)",
|
||||||
|
"min_topup": strconv.Itoa(setting.StripeMinTopUp),
|
||||||
|
}
|
||||||
|
payMethods = append(payMethods, stripeMethod)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data := gin.H{
|
||||||
|
"enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "",
|
||||||
|
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||||
|
"pay_methods": payMethods,
|
||||||
|
"min_topup": operation_setting.MinTopUp,
|
||||||
|
"stripe_min_topup": setting.StripeMinTopUp,
|
||||||
|
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
|
||||||
|
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
type EpayRequest struct {
|
type EpayRequest struct {
|
||||||
Amount int64 `json:"amount"`
|
Amount int64 `json:"amount"`
|
||||||
PaymentMethod string `json:"payment_method"`
|
PaymentMethod string `json:"payment_method"`
|
||||||
@@ -30,13 +71,13 @@ type AmountRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetEpayClient() *epay.Client {
|
func GetEpayClient() *epay.Client {
|
||||||
if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" {
|
if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
withUrl, err := epay.NewClient(&epay.Config{
|
withUrl, err := epay.NewClient(&epay.Config{
|
||||||
PartnerID: setting.EpayId,
|
PartnerID: operation_setting.EpayId,
|
||||||
Key: setting.EpayKey,
|
Key: operation_setting.EpayKey,
|
||||||
}, setting.PayAddress)
|
}, operation_setting.PayAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -57,15 +98,23 @@ func getPayMoney(amount int64, group string) float64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
|
dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
|
||||||
dPrice := decimal.NewFromFloat(setting.Price)
|
dPrice := decimal.NewFromFloat(operation_setting.Price)
|
||||||
|
// apply optional preset discount by the original request amount (if configured), default 1.0
|
||||||
|
discount := 1.0
|
||||||
|
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(amount)]; ok {
|
||||||
|
if ds > 0 {
|
||||||
|
discount = ds
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dDiscount := decimal.NewFromFloat(discount)
|
||||||
|
|
||||||
payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio)
|
payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio).Mul(dDiscount)
|
||||||
|
|
||||||
return payMoney.InexactFloat64()
|
return payMoney.InexactFloat64()
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMinTopup() int64 {
|
func getMinTopup() int64 {
|
||||||
minTopup := setting.MinTopUp
|
minTopup := operation_setting.MinTopUp
|
||||||
if !common.DisplayInCurrencyEnabled {
|
if !common.DisplayInCurrencyEnabled {
|
||||||
dMinTopup := decimal.NewFromInt(int64(minTopup))
|
dMinTopup := decimal.NewFromInt(int64(minTopup))
|
||||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||||
@@ -98,13 +147,13 @@ func RequestEpay(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !setting.ContainsPayMethod(req.PaymentMethod) {
|
if !operation_setting.ContainsPayMethod(req.PaymentMethod) {
|
||||||
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
callBackAddress := service.GetCallbackAddress()
|
callBackAddress := service.GetCallbackAddress()
|
||||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
returnUrl, _ := url.Parse(system_setting.ServerAddress + "/console/log")
|
||||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||||
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||||
@@ -231,7 +280,7 @@ func EpayNotify(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("易支付回调更新用户成功 %v", topUp)
|
log.Printf("易支付回调更新用户成功 %v", topUp)
|
||||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money))
|
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Printf("易支付异常回调: %v", verifyInfo)
|
log.Printf("易支付异常回调: %v", verifyInfo)
|
||||||
|
|||||||
285
controller/topup_stripe.go
Normal file
285
controller/topup_stripe.go
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/setting"
|
||||||
|
"one-api/setting/operation_setting"
|
||||||
|
"one-api/setting/system_setting"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stripe/stripe-go/v81"
|
||||||
|
"github.com/stripe/stripe-go/v81/checkout/session"
|
||||||
|
"github.com/stripe/stripe-go/v81/webhook"
|
||||||
|
"github.com/thanhpk/randstr"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
PaymentMethodStripe = "stripe"
|
||||||
|
)
|
||||||
|
|
||||||
|
var stripeAdaptor = &StripeAdaptor{}
|
||||||
|
|
||||||
|
type StripePayRequest struct {
|
||||||
|
Amount int64 `json:"amount"`
|
||||||
|
PaymentMethod string `json:"payment_method"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type StripeAdaptor struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
|
||||||
|
if req.Amount < getStripeMinTopup() {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := c.GetInt("id")
|
||||||
|
group, err := model.GetUserGroup(id, true)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payMoney := getStripePayMoney(float64(req.Amount), group)
|
||||||
|
if payMoney <= 0.01 {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
|
||||||
|
if req.PaymentMethod != PaymentMethodStripe {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Amount < getStripeMinTopup() {
|
||||||
|
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Amount > 10000 {
|
||||||
|
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.GetInt("id")
|
||||||
|
user, _ := model.GetUserById(id, false)
|
||||||
|
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
|
||||||
|
|
||||||
|
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4))
|
||||||
|
referenceId := "ref_" + common.Sha1([]byte(reference))
|
||||||
|
|
||||||
|
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("获取Stripe Checkout支付链接失败", err)
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topUp := &model.TopUp{
|
||||||
|
UserId: id,
|
||||||
|
Amount: req.Amount,
|
||||||
|
Money: chargedMoney,
|
||||||
|
TradeNo: referenceId,
|
||||||
|
CreateTime: time.Now().Unix(),
|
||||||
|
Status: common.TopUpStatusPending,
|
||||||
|
}
|
||||||
|
err = topUp.Insert()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"message": "success",
|
||||||
|
"data": gin.H{
|
||||||
|
"pay_link": payLink,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func RequestStripeAmount(c *gin.Context) {
|
||||||
|
var req StripePayRequest
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stripeAdaptor.RequestAmount(c, &req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RequestStripePay(c *gin.Context) {
|
||||||
|
var req StripePayRequest
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stripeAdaptor.RequestPay(c, &req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func StripeWebhook(c *gin.Context) {
|
||||||
|
payload, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
|
||||||
|
c.AbortWithStatus(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
signature := c.GetHeader("Stripe-Signature")
|
||||||
|
endpointSecret := setting.StripeWebhookSecret
|
||||||
|
event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
|
||||||
|
IgnoreAPIVersionMismatch: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Stripe Webhook验签失败: %v\n", err)
|
||||||
|
c.AbortWithStatus(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch event.Type {
|
||||||
|
case stripe.EventTypeCheckoutSessionCompleted:
|
||||||
|
sessionCompleted(event)
|
||||||
|
case stripe.EventTypeCheckoutSessionExpired:
|
||||||
|
sessionExpired(event)
|
||||||
|
default:
|
||||||
|
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sessionCompleted(event stripe.Event) {
|
||||||
|
customerId := event.GetObjectValue("customer")
|
||||||
|
referenceId := event.GetObjectValue("client_reference_id")
|
||||||
|
status := event.GetObjectValue("status")
|
||||||
|
if "complete" != status {
|
||||||
|
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := model.Recharge(referenceId, customerId)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err.Error(), referenceId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
|
||||||
|
currency := strings.ToUpper(event.GetObjectValue("currency"))
|
||||||
|
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sessionExpired(event stripe.Event) {
|
||||||
|
referenceId := event.GetObjectValue("client_reference_id")
|
||||||
|
status := event.GetObjectValue("status")
|
||||||
|
if "expired" != status {
|
||||||
|
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(referenceId) == 0 {
|
||||||
|
log.Println("未提供支付单号")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topUp := model.GetTopUpByTradeNo(referenceId)
|
||||||
|
if topUp == nil {
|
||||||
|
log.Println("充值订单不存在", referenceId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if topUp.Status != common.TopUpStatusPending {
|
||||||
|
log.Println("充值订单状态错误", referenceId)
|
||||||
|
}
|
||||||
|
|
||||||
|
topUp.Status = common.TopUpStatusExpired
|
||||||
|
err := topUp.Update()
|
||||||
|
if err != nil {
|
||||||
|
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("充值订单已过期", referenceId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
|
||||||
|
if !strings.HasPrefix(setting.StripeApiSecret, "sk_") && !strings.HasPrefix(setting.StripeApiSecret, "rk_") {
|
||||||
|
return "", fmt.Errorf("无效的Stripe API密钥")
|
||||||
|
}
|
||||||
|
|
||||||
|
stripe.Key = setting.StripeApiSecret
|
||||||
|
|
||||||
|
params := &stripe.CheckoutSessionParams{
|
||||||
|
ClientReferenceID: stripe.String(referenceId),
|
||||||
|
SuccessURL: stripe.String(system_setting.ServerAddress + "/console/log"),
|
||||||
|
CancelURL: stripe.String(system_setting.ServerAddress + "/topup"),
|
||||||
|
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||||
|
{
|
||||||
|
Price: stripe.String(setting.StripePriceId),
|
||||||
|
Quantity: stripe.Int64(amount),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if "" == customerId {
|
||||||
|
if "" != email {
|
||||||
|
params.CustomerEmail = stripe.String(email)
|
||||||
|
}
|
||||||
|
|
||||||
|
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
|
||||||
|
} else {
|
||||||
|
params.Customer = stripe.String(customerId)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := session.New(params)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.URL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetChargedAmount(count float64, user model.User) float64 {
|
||||||
|
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
|
||||||
|
if topUpGroupRatio == 0 {
|
||||||
|
topUpGroupRatio = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return count * topUpGroupRatio
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStripePayMoney(amount float64, group string) float64 {
|
||||||
|
originalAmount := amount
|
||||||
|
if !common.DisplayInCurrencyEnabled {
|
||||||
|
amount = amount / common.QuotaPerUnit
|
||||||
|
}
|
||||||
|
// Using float64 for monetary calculations is acceptable here due to the small amounts involved
|
||||||
|
topupGroupRatio := common.GetTopupGroupRatio(group)
|
||||||
|
if topupGroupRatio == 0 {
|
||||||
|
topupGroupRatio = 1
|
||||||
|
}
|
||||||
|
// apply optional preset discount by the original request amount (if configured), default 1.0
|
||||||
|
discount := 1.0
|
||||||
|
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(originalAmount)]; ok {
|
||||||
|
if ds > 0 {
|
||||||
|
discount = ds
|
||||||
|
}
|
||||||
|
}
|
||||||
|
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio * discount
|
||||||
|
return payMoney
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStripeMinTopup() int64 {
|
||||||
|
minTopup := setting.StripeMinTopUp
|
||||||
|
if !common.DisplayInCurrencyEnabled {
|
||||||
|
minTopup = minTopup * int(common.QuotaPerUnit)
|
||||||
|
}
|
||||||
|
return int64(minTopup)
|
||||||
|
}
|
||||||
553
controller/twofa.go
Normal file
553
controller/twofa.go
Normal file
@@ -0,0 +1,553 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Setup2FARequest 设置2FA请求结构
|
||||||
|
type Setup2FARequest struct {
|
||||||
|
Code string `json:"code" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify2FARequest 验证2FA请求结构
|
||||||
|
type Verify2FARequest struct {
|
||||||
|
Code string `json:"code" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup2FAResponse 设置2FA响应结构
|
||||||
|
type Setup2FAResponse struct {
|
||||||
|
Secret string `json:"secret"`
|
||||||
|
QRCodeData string `json:"qr_code_data"`
|
||||||
|
BackupCodes []string `json:"backup_codes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup2FA 初始化2FA设置
|
||||||
|
func Setup2FA(c *gin.Context) {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
// 检查用户是否已经启用2FA
|
||||||
|
existing, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if existing != nil && existing.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户已启用2FA,请先禁用后重新设置",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果存在已禁用的2FA记录,先删除它
|
||||||
|
if existing != nil && !existing.IsEnabled {
|
||||||
|
if err := existing.Delete(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
existing = nil // 重置为nil,后续将创建新记录
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取用户信息
|
||||||
|
user, err := model.GetUserById(userId, false)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成TOTP密钥
|
||||||
|
key, err := common.GenerateTOTPSecret(user.Username)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "生成2FA密钥失败",
|
||||||
|
})
|
||||||
|
common.SysLog("生成TOTP密钥失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成备用码
|
||||||
|
backupCodes, err := common.GenerateBackupCodes()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "生成备用码失败",
|
||||||
|
})
|
||||||
|
common.SysLog("生成备用码失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成二维码数据
|
||||||
|
qrCodeData := common.GenerateQRCodeData(key.Secret(), user.Username)
|
||||||
|
|
||||||
|
// 创建或更新2FA记录(暂未启用)
|
||||||
|
twoFA := &model.TwoFA{
|
||||||
|
UserId: userId,
|
||||||
|
Secret: key.Secret(),
|
||||||
|
IsEnabled: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
if existing != nil {
|
||||||
|
// 更新现有记录
|
||||||
|
twoFA.Id = existing.Id
|
||||||
|
err = twoFA.Update()
|
||||||
|
} else {
|
||||||
|
// 创建新记录
|
||||||
|
err = twoFA.Create()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建备用码记录
|
||||||
|
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "保存备用码失败",
|
||||||
|
})
|
||||||
|
common.SysLog("保存备用码失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
model.RecordLog(userId, model.LogTypeSystem, "开始设置两步验证")
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "2FA设置初始化成功,请使用认证器扫描二维码并输入验证码完成设置",
|
||||||
|
"data": Setup2FAResponse{
|
||||||
|
Secret: key.Secret(),
|
||||||
|
QRCodeData: qrCodeData,
|
||||||
|
BackupCodes: backupCodes,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable2FA 启用2FA
|
||||||
|
func Enable2FA(c *gin.Context) {
|
||||||
|
var req Setup2FARequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "参数错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
// 获取2FA记录
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA == nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "请先完成2FA初始化设置",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "2FA已经启用",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证TOTP验证码
|
||||||
|
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !common.ValidateTOTPCode(twoFA.Secret, cleanCode) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "验证码或备用码错误,请重试",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启用2FA
|
||||||
|
if err := twoFA.Enable(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
model.RecordLog(userId, model.LogTypeSystem, "成功启用两步验证")
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "两步验证启用成功",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable2FA 禁用2FA
|
||||||
|
func Disable2FA(c *gin.Context) {
|
||||||
|
var req Verify2FARequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "参数错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
// 获取2FA记录
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA == nil || !twoFA.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户未启用2FA",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证TOTP验证码或备用码
|
||||||
|
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||||
|
isValidTOTP := false
|
||||||
|
isValidBackup := false
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
// 尝试验证TOTP
|
||||||
|
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidTOTP {
|
||||||
|
// 尝试验证备用码
|
||||||
|
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidTOTP && !isValidBackup {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "验证码或备用码错误,请重试",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 禁用2FA
|
||||||
|
if err := model.DisableTwoFA(userId); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
model.RecordLog(userId, model.LogTypeSystem, "禁用两步验证")
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "两步验证已禁用",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get2FAStatus 获取用户2FA状态
|
||||||
|
func Get2FAStatus(c *gin.Context) {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
status := map[string]interface{}{
|
||||||
|
"enabled": false,
|
||||||
|
"locked": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
if twoFA != nil {
|
||||||
|
status["enabled"] = twoFA.IsEnabled
|
||||||
|
status["locked"] = twoFA.IsLocked()
|
||||||
|
if twoFA.IsEnabled {
|
||||||
|
// 获取剩余备用码数量
|
||||||
|
backupCount, err := model.GetUnusedBackupCodeCount(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog("获取备用码数量失败: " + err.Error())
|
||||||
|
} else {
|
||||||
|
status["backup_codes_remaining"] = backupCount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegenerateBackupCodes 重新生成备用码
|
||||||
|
func RegenerateBackupCodes(c *gin.Context) {
|
||||||
|
var req Verify2FARequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "参数错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
// 获取2FA记录
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA == nil || !twoFA.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户未启用2FA",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证TOTP验证码
|
||||||
|
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
valid, err := twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "验证码或备用码错误,请重试",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成新的备用码
|
||||||
|
backupCodes, err := common.GenerateBackupCodes()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "生成备用码失败",
|
||||||
|
})
|
||||||
|
common.SysLog("生成备用码失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存新的备用码
|
||||||
|
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "保存备用码失败",
|
||||||
|
})
|
||||||
|
common.SysLog("保存备用码失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
model.RecordLog(userId, model.LogTypeSystem, "重新生成两步验证备用码")
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "备用码重新生成成功",
|
||||||
|
"data": map[string]interface{}{
|
||||||
|
"backup_codes": backupCodes,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify2FALogin 登录时验证2FA
|
||||||
|
func Verify2FALogin(c *gin.Context) {
|
||||||
|
var req Verify2FARequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "参数错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从会话中获取pending用户信息
|
||||||
|
session := sessions.Default(c)
|
||||||
|
pendingUserId := session.Get("pending_user_id")
|
||||||
|
if pendingUserId == nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "会话已过期,请重新登录",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userId, ok := pendingUserId.(int)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "会话数据无效,请重新登录",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 获取用户信息
|
||||||
|
user, err := model.GetUserById(userId, false)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户不存在",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取2FA记录
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(user.Id)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA == nil || !twoFA.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户未启用2FA",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证TOTP验证码或备用码
|
||||||
|
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||||
|
isValidTOTP := false
|
||||||
|
isValidBackup := false
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
// 尝试验证TOTP
|
||||||
|
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidTOTP {
|
||||||
|
// 尝试验证备用码
|
||||||
|
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidTOTP && !isValidBackup {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "验证码或备用码错误,请重试",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2FA验证成功,清理pending会话信息并完成登录
|
||||||
|
session.Delete("pending_username")
|
||||||
|
session.Delete("pending_user_id")
|
||||||
|
session.Save()
|
||||||
|
|
||||||
|
setupLogin(user, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Admin2FAStats 管理员获取2FA统计信息
|
||||||
|
func Admin2FAStats(c *gin.Context) {
|
||||||
|
stats, err := model.GetTwoFAStats()
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": stats,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminDisable2FA 管理员强制禁用用户2FA
|
||||||
|
func AdminDisable2FA(c *gin.Context) {
|
||||||
|
userIdStr := c.Param("id")
|
||||||
|
userId, err := strconv.Atoi(userIdStr)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户ID格式错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查目标用户权限
|
||||||
|
targetUser, err := model.GetUserById(userId, false)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
myRole := c.GetInt("role")
|
||||||
|
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无权操作同级或更高级用户的2FA设置",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 禁用2FA
|
||||||
|
if err := model.DisableTwoFA(userId); err != nil {
|
||||||
|
if errors.Is(err, model.ErrTwoFANotEnabled) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户未启用2FA",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
adminId := c.GetInt("id")
|
||||||
|
model.RecordLog(userId, model.LogTypeManage,
|
||||||
|
fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "用户2FA已被强制禁用",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -31,7 +31,7 @@ type Monitor struct {
|
|||||||
|
|
||||||
type UptimeGroupResult struct {
|
type UptimeGroupResult struct {
|
||||||
CategoryName string `json:"categoryName"`
|
CategoryName string `json:"categoryName"`
|
||||||
Monitors []Monitor `json:"monitors"`
|
Monitors []Monitor `json:"monitors"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
|
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
|
||||||
@@ -57,29 +57,29 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st
|
|||||||
url, _ := groupConfig["url"].(string)
|
url, _ := groupConfig["url"].(string)
|
||||||
slug, _ := groupConfig["slug"].(string)
|
slug, _ := groupConfig["slug"].(string)
|
||||||
categoryName, _ := groupConfig["categoryName"].(string)
|
categoryName, _ := groupConfig["categoryName"].(string)
|
||||||
|
|
||||||
result := UptimeGroupResult{
|
result := UptimeGroupResult{
|
||||||
CategoryName: categoryName,
|
CategoryName: categoryName,
|
||||||
Monitors: []Monitor{},
|
Monitors: []Monitor{},
|
||||||
}
|
}
|
||||||
|
|
||||||
if url == "" || slug == "" {
|
if url == "" || slug == "" {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURL := strings.TrimSuffix(url, "/")
|
baseURL := strings.TrimSuffix(url, "/")
|
||||||
|
|
||||||
var statusData struct {
|
var statusData struct {
|
||||||
PublicGroupList []struct {
|
PublicGroupList []struct {
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
MonitorList []struct {
|
MonitorList []struct {
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
} `json:"monitorList"`
|
} `json:"monitorList"`
|
||||||
} `json:"publicGroupList"`
|
} `json:"publicGroupList"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var heartbeatData struct {
|
var heartbeatData struct {
|
||||||
HeartbeatList map[string][]struct {
|
HeartbeatList map[string][]struct {
|
||||||
Status int `json:"status"`
|
Status int `json:"status"`
|
||||||
@@ -88,11 +88,11 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st
|
|||||||
}
|
}
|
||||||
|
|
||||||
g, gCtx := errgroup.WithContext(ctx)
|
g, gCtx := errgroup.WithContext(ctx)
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
||||||
})
|
})
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
||||||
})
|
})
|
||||||
|
|
||||||
if g.Wait() != nil {
|
if g.Wait() != nil {
|
||||||
@@ -139,7 +139,7 @@ func GetUptimeKumaStatus(c *gin.Context) {
|
|||||||
|
|
||||||
client := &http.Client{Timeout: httpTimeout}
|
client := &http.Client{Timeout: httpTimeout}
|
||||||
results := make([]UptimeGroupResult, len(groups))
|
results := make([]UptimeGroupResult, len(groups))
|
||||||
|
|
||||||
g, gCtx := errgroup.WithContext(ctx)
|
g, gCtx := errgroup.WithContext(ctx)
|
||||||
for i, group := range groups {
|
for i, group := range groups {
|
||||||
i, group := i, group
|
i, group := i, group
|
||||||
@@ -148,7 +148,7 @@ func GetUptimeKumaStatus(c *gin.Context) {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
g.Wait()
|
g.Wait()
|
||||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
|
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAllQuotaDates(c *gin.Context) {
|
func GetAllQuotaDates(c *gin.Context) {
|
||||||
@@ -13,10 +15,7 @@ func GetAllQuotaDates(c *gin.Context) {
|
|||||||
username := c.Query("username")
|
username := c.Query("username")
|
||||||
dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
|
dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -41,10 +40,7 @@ func GetUserQuotaDates(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
|
dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -62,6 +63,32 @@ func Login(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否启用2FA
|
||||||
|
if model.IsTwoFAEnabled(user.Id) {
|
||||||
|
// 设置pending session,等待2FA验证
|
||||||
|
session := sessions.Default(c)
|
||||||
|
session.Set("pending_username", user.Username)
|
||||||
|
session.Set("pending_user_id", user.Id)
|
||||||
|
err := session.Save()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "无法保存会话信息,请重试",
|
||||||
|
"success": false,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "请输入两步验证码",
|
||||||
|
"success": true,
|
||||||
|
"data": map[string]interface{}{
|
||||||
|
"require_2fa": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
setupLogin(&user, c)
|
setupLogin(&user, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,7 +193,7 @@ func Register(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "数据库错误,请稍后重试",
|
"message": "数据库错误,请稍后重试",
|
||||||
})
|
})
|
||||||
common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
|
common.SysLog(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if exist {
|
if exist {
|
||||||
@@ -183,15 +210,13 @@ func Register(c *gin.Context) {
|
|||||||
Password: user.Password,
|
Password: user.Password,
|
||||||
DisplayName: user.Username,
|
DisplayName: user.Username,
|
||||||
InviterId: inviterId,
|
InviterId: inviterId,
|
||||||
|
Role: common.RoleCommonUser, // 明确设置角色为普通用户
|
||||||
}
|
}
|
||||||
if common.EmailVerificationEnabled {
|
if common.EmailVerificationEnabled {
|
||||||
cleanUser.Email = user.Email
|
cleanUser.Email = user.Email
|
||||||
}
|
}
|
||||||
if err := cleanUser.Insert(inviterId); err != nil {
|
if err := cleanUser.Insert(inviterId); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,7 +237,7 @@ func Register(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成默认令牌失败",
|
"message": "生成默认令牌失败",
|
||||||
})
|
})
|
||||||
common.SysError("failed to generate token key: " + err.Error())
|
common.SysLog("failed to generate token key: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 生成默认令牌
|
// 生成默认令牌
|
||||||
@@ -247,81 +272,45 @@ func Register(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetAllUsers(c *gin.Context) {
|
func GetAllUsers(c *gin.Context) {
|
||||||
pageInfo, err := common.GetPageQuery(c)
|
pageInfo := common.GetPageQuery(c)
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": "parse page query failed",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
users, total, err := model.GetAllUsers(pageInfo)
|
users, total, err := model.GetAllUsers(pageInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pageInfo.SetTotal(int(total))
|
pageInfo.SetTotal(int(total))
|
||||||
pageInfo.SetItems(users)
|
pageInfo.SetItems(users)
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": true,
|
common.ApiSuccess(c, pageInfo)
|
||||||
"message": "",
|
|
||||||
"data": pageInfo,
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchUsers(c *gin.Context) {
|
func SearchUsers(c *gin.Context) {
|
||||||
keyword := c.Query("keyword")
|
keyword := c.Query("keyword")
|
||||||
group := c.Query("group")
|
group := c.Query("group")
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
users, total, err := model.SearchUsers(keyword, group, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
if p < 1 {
|
|
||||||
p = 1
|
|
||||||
}
|
|
||||||
if pageSize < 0 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
startIdx := (p - 1) * pageSize
|
|
||||||
users, total, err := model.SearchUsers(keyword, group, startIdx, pageSize)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": true,
|
pageInfo.SetTotal(int(total))
|
||||||
"message": "",
|
pageInfo.SetItems(users)
|
||||||
"data": gin.H{
|
common.ApiSuccess(c, pageInfo)
|
||||||
"items": users,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUser(c *gin.Context) {
|
func GetUser(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user, err := model.GetUserById(id, false)
|
user, err := model.GetUserById(id, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
myRole := c.GetInt("role")
|
myRole := c.GetInt("role")
|
||||||
@@ -344,10 +333,7 @@ func GenerateAccessToken(c *gin.Context) {
|
|||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
user, err := model.GetUserById(id, true)
|
user, err := model.GetUserById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// get rand int 28-32
|
// get rand int 28-32
|
||||||
@@ -358,7 +344,7 @@ func GenerateAccessToken(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成失败",
|
"message": "生成失败",
|
||||||
})
|
})
|
||||||
common.SysError("failed to generate key: " + err.Error())
|
common.SysLog("failed to generate key: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.SetAccessToken(key)
|
user.SetAccessToken(key)
|
||||||
@@ -372,10 +358,7 @@ func GenerateAccessToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := user.Update(false); err != nil {
|
if err := user.Update(false); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -395,18 +378,12 @@ func TransferAffQuota(c *gin.Context) {
|
|||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
user, err := model.GetUserById(id, true)
|
user, err := model.GetUserById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tran := TransferAffQuotaRequest{}
|
tran := TransferAffQuotaRequest{}
|
||||||
if err := c.ShouldBindJSON(&tran); err != nil {
|
if err := c.ShouldBindJSON(&tran); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = user.TransferAffQuotaToQuota(tran.Quota)
|
err = user.TransferAffQuotaToQuota(tran.Quota)
|
||||||
@@ -427,10 +404,7 @@ func GetAffCode(c *gin.Context) {
|
|||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
user, err := model.GetUserById(id, true)
|
user, err := model.GetUserById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if user.AffCode == "" {
|
if user.AffCode == "" {
|
||||||
@@ -453,25 +427,143 @@ func GetAffCode(c *gin.Context) {
|
|||||||
|
|
||||||
func GetSelf(c *gin.Context) {
|
func GetSelf(c *gin.Context) {
|
||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
|
userRole := c.GetInt("role")
|
||||||
user, err := model.GetUserById(id, false)
|
user, err := model.GetUserById(id, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
||||||
user.Remark = ""
|
user.Remark = ""
|
||||||
|
|
||||||
|
// 计算用户权限信息
|
||||||
|
permissions := calculateUserPermissions(userRole)
|
||||||
|
|
||||||
|
// 获取用户设置并提取sidebar_modules
|
||||||
|
userSetting := user.GetSetting()
|
||||||
|
|
||||||
|
// 构建响应数据,包含用户信息和权限
|
||||||
|
responseData := map[string]interface{}{
|
||||||
|
"id": user.Id,
|
||||||
|
"username": user.Username,
|
||||||
|
"display_name": user.DisplayName,
|
||||||
|
"role": user.Role,
|
||||||
|
"status": user.Status,
|
||||||
|
"email": user.Email,
|
||||||
|
"group": user.Group,
|
||||||
|
"quota": user.Quota,
|
||||||
|
"used_quota": user.UsedQuota,
|
||||||
|
"request_count": user.RequestCount,
|
||||||
|
"aff_code": user.AffCode,
|
||||||
|
"aff_count": user.AffCount,
|
||||||
|
"aff_quota": user.AffQuota,
|
||||||
|
"aff_history_quota": user.AffHistoryQuota,
|
||||||
|
"inviter_id": user.InviterId,
|
||||||
|
"linux_do_id": user.LinuxDOId,
|
||||||
|
"setting": user.Setting,
|
||||||
|
"stripe_customer": user.StripeCustomer,
|
||||||
|
"sidebar_modules": userSetting.SidebarModules, // 正确提取sidebar_modules字段
|
||||||
|
"permissions": permissions, // 新增权限字段
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": user,
|
"data": responseData,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 计算用户权限的辅助函数
|
||||||
|
func calculateUserPermissions(userRole int) map[string]interface{} {
|
||||||
|
permissions := map[string]interface{}{}
|
||||||
|
|
||||||
|
// 根据用户角色计算权限
|
||||||
|
if userRole == common.RoleRootUser {
|
||||||
|
// 超级管理员不需要边栏设置功能
|
||||||
|
permissions["sidebar_settings"] = false
|
||||||
|
permissions["sidebar_modules"] = map[string]interface{}{}
|
||||||
|
} else if userRole == common.RoleAdminUser {
|
||||||
|
// 管理员可以设置边栏,但不包含系统设置功能
|
||||||
|
permissions["sidebar_settings"] = true
|
||||||
|
permissions["sidebar_modules"] = map[string]interface{}{
|
||||||
|
"admin": map[string]interface{}{
|
||||||
|
"setting": false, // 管理员不能访问系统设置
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 普通用户只能设置个人功能,不包含管理员区域
|
||||||
|
permissions["sidebar_settings"] = true
|
||||||
|
permissions["sidebar_modules"] = map[string]interface{}{
|
||||||
|
"admin": false, // 普通用户不能访问管理员区域
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return permissions
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据用户角色生成默认的边栏配置
|
||||||
|
func generateDefaultSidebarConfig(userRole int) string {
|
||||||
|
defaultConfig := map[string]interface{}{}
|
||||||
|
|
||||||
|
// 聊天区域 - 所有用户都可以访问
|
||||||
|
defaultConfig["chat"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"playground": true,
|
||||||
|
"chat": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 控制台区域 - 所有用户都可以访问
|
||||||
|
defaultConfig["console"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"detail": true,
|
||||||
|
"token": true,
|
||||||
|
"log": true,
|
||||||
|
"midjourney": true,
|
||||||
|
"task": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 个人中心区域 - 所有用户都可以访问
|
||||||
|
defaultConfig["personal"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"topup": true,
|
||||||
|
"personal": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 管理员区域 - 根据角色决定
|
||||||
|
if userRole == common.RoleAdminUser {
|
||||||
|
// 管理员可以访问管理员区域,但不能访问系统设置
|
||||||
|
defaultConfig["admin"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"channel": true,
|
||||||
|
"models": true,
|
||||||
|
"redemption": true,
|
||||||
|
"user": true,
|
||||||
|
"setting": false, // 管理员不能访问系统设置
|
||||||
|
}
|
||||||
|
} else if userRole == common.RoleRootUser {
|
||||||
|
// 超级管理员可以访问所有功能
|
||||||
|
defaultConfig["admin"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"channel": true,
|
||||||
|
"models": true,
|
||||||
|
"redemption": true,
|
||||||
|
"user": true,
|
||||||
|
"setting": true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 普通用户不包含admin区域
|
||||||
|
|
||||||
|
// 转换为JSON字符串
|
||||||
|
configBytes, err := json.Marshal(defaultConfig)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog("生成默认边栏配置失败: " + err.Error())
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(configBytes)
|
||||||
|
}
|
||||||
|
|
||||||
func GetUserModels(c *gin.Context) {
|
func GetUserModels(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -479,10 +571,7 @@ func GetUserModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
user, err := model.GetUserCache(id)
|
user, err := model.GetUserCache(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
groups := setting.GetUserUsableGroups(user.Group)
|
groups := setting.GetUserUsableGroups(user.Group)
|
||||||
@@ -524,10 +613,7 @@ func UpdateUser(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
originUser, err := model.GetUserById(updatedUser.Id, false)
|
originUser, err := model.GetUserById(updatedUser.Id, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
myRole := c.GetInt("role")
|
myRole := c.GetInt("role")
|
||||||
@@ -550,14 +636,11 @@ func UpdateUser(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
updatePassword := updatedUser.Password != ""
|
updatePassword := updatedUser.Password != ""
|
||||||
if err := updatedUser.Edit(updatePassword); err != nil {
|
if err := updatedUser.Edit(updatePassword); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if originUser.Quota != updatedUser.Quota {
|
if originUser.Quota != updatedUser.Quota {
|
||||||
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
|
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota)))
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
@@ -567,8 +650,8 @@ func UpdateUser(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateSelf(c *gin.Context) {
|
func UpdateSelf(c *gin.Context) {
|
||||||
var user model.User
|
var requestData map[string]interface{}
|
||||||
err := json.NewDecoder(c.Request.Body).Decode(&user)
|
err := json.NewDecoder(c.Request.Body).Decode(&requestData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -576,6 +659,60 @@ func UpdateSelf(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否是sidebar_modules更新请求
|
||||||
|
if sidebarModules, exists := requestData["sidebar_modules"]; exists {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
user, err := model.GetUserById(userId, false)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取当前用户设置
|
||||||
|
currentSetting := user.GetSetting()
|
||||||
|
|
||||||
|
// 更新sidebar_modules字段
|
||||||
|
if sidebarModulesStr, ok := sidebarModules.(string); ok {
|
||||||
|
currentSetting.SidebarModules = sidebarModulesStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存更新后的设置
|
||||||
|
user.SetSetting(currentSetting)
|
||||||
|
if err := user.Update(false); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "更新设置失败: " + err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "设置更新成功",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原有的用户信息更新逻辑
|
||||||
|
var user model.User
|
||||||
|
requestDataBytes, err := json.Marshal(requestData)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的参数",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(requestDataBytes, &user)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的参数",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if user.Password == "" {
|
if user.Password == "" {
|
||||||
user.Password = "$I_LOVE_U" // make Validator happy :)
|
user.Password = "$I_LOVE_U" // make Validator happy :)
|
||||||
}
|
}
|
||||||
@@ -599,17 +736,11 @@ func UpdateSelf(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id)
|
updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := cleanUser.Update(updatePassword); err != nil {
|
if err := cleanUser.Update(updatePassword); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -640,18 +771,12 @@ func checkUpdatePassword(originalPassword string, newPassword string, userId int
|
|||||||
func DeleteUser(c *gin.Context) {
|
func DeleteUser(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
originUser, err := model.GetUserById(id, false)
|
originUser, err := model.GetUserById(id, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
myRole := c.GetInt("role")
|
myRole := c.GetInt("role")
|
||||||
@@ -686,10 +811,7 @@ func DeleteSelf(c *gin.Context) {
|
|||||||
|
|
||||||
err := model.DeleteUserById(id)
|
err := model.DeleteUserById(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -733,12 +855,10 @@ func CreateUser(c *gin.Context) {
|
|||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Password: user.Password,
|
Password: user.Password,
|
||||||
DisplayName: user.DisplayName,
|
DisplayName: user.DisplayName,
|
||||||
|
Role: user.Role, // 保持管理员设置的角色
|
||||||
}
|
}
|
||||||
if err := cleanUser.Insert(0); err != nil {
|
if err := cleanUser.Insert(0); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -848,10 +968,7 @@ func ManageUser(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := user.Update(false); err != nil {
|
if err := user.Update(false); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
clearUser := model.User{
|
clearUser := model.User{
|
||||||
@@ -883,20 +1000,14 @@ func EmailBind(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
err := user.FillUserById()
|
err := user.FillUserById()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.Email = email
|
user.Email = email
|
||||||
// no need to check if this email already taken, because we have used verification code to check it
|
// no need to check if this email already taken, because we have used verification code to check it
|
||||||
err = user.Update(false)
|
err = user.Update(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -910,27 +1021,67 @@ type topUpRequest struct {
|
|||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var topUpLock = sync.Mutex{}
|
var topUpLocks sync.Map
|
||||||
|
var topUpCreateLock sync.Mutex
|
||||||
|
|
||||||
|
type topUpTryLock struct {
|
||||||
|
ch chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTopUpTryLock() *topUpTryLock {
|
||||||
|
return &topUpTryLock{ch: make(chan struct{}, 1)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *topUpTryLock) TryLock() bool {
|
||||||
|
select {
|
||||||
|
case l.ch <- struct{}{}:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *topUpTryLock) Unlock() {
|
||||||
|
select {
|
||||||
|
case <-l.ch:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTopUpLock(userID int) *topUpTryLock {
|
||||||
|
if v, ok := topUpLocks.Load(userID); ok {
|
||||||
|
return v.(*topUpTryLock)
|
||||||
|
}
|
||||||
|
topUpCreateLock.Lock()
|
||||||
|
defer topUpCreateLock.Unlock()
|
||||||
|
if v, ok := topUpLocks.Load(userID); ok {
|
||||||
|
return v.(*topUpTryLock)
|
||||||
|
}
|
||||||
|
l := newTopUpTryLock()
|
||||||
|
topUpLocks.Store(userID, l)
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
func TopUp(c *gin.Context) {
|
func TopUp(c *gin.Context) {
|
||||||
topUpLock.Lock()
|
id := c.GetInt("id")
|
||||||
defer topUpLock.Unlock()
|
lock := getTopUpLock(id)
|
||||||
req := topUpRequest{}
|
if !lock.TryLock() {
|
||||||
err := c.ShouldBindJSON(&req)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": "充值处理中,请稍后重试",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
id := c.GetInt("id")
|
defer lock.Unlock()
|
||||||
|
req := topUpRequest{}
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
quota, err := model.Redeem(req.Key, id)
|
quota, err := model.Redeem(req.Key, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -938,7 +1089,6 @@ func TopUp(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": quota,
|
"data": quota,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdateUserSettingRequest struct {
|
type UpdateUserSettingRequest struct {
|
||||||
@@ -947,6 +1097,7 @@ type UpdateUserSettingRequest struct {
|
|||||||
WebhookUrl string `json:"webhook_url,omitempty"`
|
WebhookUrl string `json:"webhook_url,omitempty"`
|
||||||
WebhookSecret string `json:"webhook_secret,omitempty"`
|
WebhookSecret string `json:"webhook_secret,omitempty"`
|
||||||
NotificationEmail string `json:"notification_email,omitempty"`
|
NotificationEmail string `json:"notification_email,omitempty"`
|
||||||
|
BarkUrl string `json:"bark_url,omitempty"`
|
||||||
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
||||||
RecordIpLog bool `json:"record_ip_log"`
|
RecordIpLog bool `json:"record_ip_log"`
|
||||||
}
|
}
|
||||||
@@ -962,7 +1113,7 @@ func UpdateUserSetting(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证预警类型
|
// 验证预警类型
|
||||||
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
|
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook && req.QuotaWarningType != dto.NotifyTypeBark {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无效的预警类型",
|
"message": "无效的预警类型",
|
||||||
@@ -1010,13 +1161,37 @@ func UpdateUserSetting(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果是Bark类型,验证Bark URL
|
||||||
|
if req.QuotaWarningType == dto.NotifyTypeBark {
|
||||||
|
if req.BarkUrl == "" {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "Bark推送URL不能为空",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 验证URL格式
|
||||||
|
if _, err := url.ParseRequestURI(req.BarkUrl); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的Bark推送URL",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 检查是否是HTTP或HTTPS
|
||||||
|
if !strings.HasPrefix(req.BarkUrl, "https://") && !strings.HasPrefix(req.BarkUrl, "http://") {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "Bark推送URL必须以http://或https://开头",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
user, err := model.GetUserById(userId, true)
|
user, err := model.GetUserById(userId, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1041,6 +1216,11 @@ func UpdateUserSetting(c *gin.Context) {
|
|||||||
settings.NotificationEmail = req.NotificationEmail
|
settings.NotificationEmail = req.NotificationEmail
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果是Bark类型,添加Bark URL到设置中
|
||||||
|
if req.QuotaWarningType == dto.NotifyTypeBark {
|
||||||
|
settings.BarkUrl = req.BarkUrl
|
||||||
|
}
|
||||||
|
|
||||||
// 更新用户设置
|
// 更新用户设置
|
||||||
user.SetSetting(settings)
|
user.SetSetting(settings)
|
||||||
if err := user.Update(false); err != nil {
|
if err := user.Update(false); err != nil {
|
||||||
|
|||||||
124
controller/vendor_meta.go
Normal file
124
controller/vendor_meta.go
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetAllVendors 获取供应商列表(分页)
|
||||||
|
func GetAllVendors(c *gin.Context) {
|
||||||
|
pageInfo := common.GetPageQuery(c)
|
||||||
|
vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var total int64
|
||||||
|
model.DB.Model(&model.Vendor{}).Count(&total)
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
|
pageInfo.SetItems(vendors)
|
||||||
|
common.ApiSuccess(c, pageInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchVendors 搜索供应商
|
||||||
|
func SearchVendors(c *gin.Context) {
|
||||||
|
keyword := c.Query("keyword")
|
||||||
|
pageInfo := common.GetPageQuery(c)
|
||||||
|
vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
|
pageInfo.SetItems(vendors)
|
||||||
|
common.ApiSuccess(c, pageInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVendorMeta 根据 ID 获取供应商
|
||||||
|
func GetVendorMeta(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
v, err := model.GetVendorByID(id)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateVendorMeta 新建供应商
|
||||||
|
func CreateVendorMeta(c *gin.Context) {
|
||||||
|
var v model.Vendor
|
||||||
|
if err := c.ShouldBindJSON(&v); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if v.Name == "" {
|
||||||
|
common.ApiErrorMsg(c, "供应商名称不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 创建前先检查名称
|
||||||
|
if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "供应商名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := v.Insert(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, &v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateVendorMeta 更新供应商
|
||||||
|
func UpdateVendorMeta(c *gin.Context) {
|
||||||
|
var v model.Vendor
|
||||||
|
if err := c.ShouldBindJSON(&v); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if v.Id == 0 {
|
||||||
|
common.ApiErrorMsg(c, "缺少供应商 ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 名称冲突检查
|
||||||
|
if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "供应商名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := v.Update(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, &v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteVendorMeta 删除供应商
|
||||||
|
func DeleteVendorMeta(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, nil)
|
||||||
|
}
|
||||||
@@ -4,13 +4,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-contrib/sessions"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type wechatLoginResponse struct {
|
type wechatLoginResponse struct {
|
||||||
@@ -150,19 +151,13 @@ func WeChatBind(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
err = user.FillUserById()
|
err = user.FillUserById()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.WeChatId = wechatId
|
user.WeChatId = wechatId
|
||||||
err = user.Update(false)
|
err = user.Update(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ services:
|
|||||||
- REDIS_CONN_STRING=redis://redis
|
- REDIS_CONN_STRING=redis://redis
|
||||||
- TZ=Asia/Shanghai
|
- TZ=Asia/Shanghai
|
||||||
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
||||||
# - STREAMING_TIMEOUT=120 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
|
# - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
|
||||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
||||||
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
||||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||||
|
|||||||
197
docs/api/web_api.md
Normal file
197
docs/api/web_api.md
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
# New API – Web 界面后端接口文档
|
||||||
|
|
||||||
|
> 本文档汇总了 **New API** 后端提供给前端 Web 界面的全部 REST 接口(不含 *Relay* 相关接口)。
|
||||||
|
>
|
||||||
|
> 接口前缀统一为 `https://<your-domain>`,以下仅列出 **路径**、**HTTP 方法**、**鉴权要求** 与 **功能简介**。
|
||||||
|
>
|
||||||
|
> 鉴权级别说明:
|
||||||
|
> * **公开** – 不需要登录即可调用
|
||||||
|
> * **用户** – 需携带用户 Token(`middleware.UserAuth`)
|
||||||
|
> * **管理员** – 需管理员 Token(`middleware.AdminAuth`)
|
||||||
|
> * **Root** – 仅限最高权限 Root 用户(`middleware.RootAuth`)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 初始化 / 系统状态
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/setup | 公开 | 获取系统初始化状态 |
|
||||||
|
| POST | /api/setup | 公开 | 完成首次安装向导 |
|
||||||
|
| GET | /api/status | 公开 | 获取运行状态摘要 |
|
||||||
|
| GET | /api/uptime/status | 公开 | Uptime-Kuma 兼容状态探针 |
|
||||||
|
| GET | /api/status/test | 管理员 | 测试后端与依赖组件是否正常 |
|
||||||
|
|
||||||
|
## 2. 公共信息
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/models | 用户 | 获取前端可用模型列表 |
|
||||||
|
| GET | /api/notice | 公开 | 获取公告栏内容 |
|
||||||
|
| GET | /api/about | 公开 | 关于页面信息 |
|
||||||
|
| GET | /api/home_page_content | 公开 | 首页自定义内容 |
|
||||||
|
| GET | /api/pricing | 可匿名/用户 | 价格与套餐信息 |
|
||||||
|
| GET | /api/ratio_config | 公开 | 模型倍率配置(仅公开字段) |
|
||||||
|
|
||||||
|
## 3. 邮件 / 身份验证
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/verification | 公开 (限流) | 发送邮箱验证邮件 |
|
||||||
|
| GET | /api/reset_password | 公开 (限流) | 发送重置密码邮件 |
|
||||||
|
| POST | /api/user/reset | 公开 | 提交重置密码请求 |
|
||||||
|
|
||||||
|
## 4. OAuth / 第三方登录
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 |
|
||||||
|
| GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 |
|
||||||
|
| GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 |
|
||||||
|
| GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 |
|
||||||
|
| GET | /api/oauth/wechat/bind | 公开 | 微信账户绑定 |
|
||||||
|
| GET | /api/oauth/email/bind | 公开 | 邮箱绑定 |
|
||||||
|
| GET | /api/oauth/telegram/login | 公开 | Telegram 登录 |
|
||||||
|
| GET | /api/oauth/telegram/bind | 公开 | Telegram 账户绑定 |
|
||||||
|
| GET | /api/oauth/state | 公开 | 获取随机 state(防 CSRF) |
|
||||||
|
|
||||||
|
## 5. 用户模块
|
||||||
|
### 5.1 账号注册/登录
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| POST | /api/user/register | 公开 | 注册新账号 |
|
||||||
|
| POST | /api/user/login | 公开 | 用户登录 |
|
||||||
|
| GET | /api/user/logout | 用户 | 退出登录 |
|
||||||
|
| GET | /api/user/epay/notify | 公开 | Epay 支付回调 |
|
||||||
|
| GET | /api/user/groups | 公开 | 列出所有分组(无鉴权版) |
|
||||||
|
|
||||||
|
### 5.2 用户自身操作 (需登录)
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/user/self/groups | 用户 | 获取自己所在分组 |
|
||||||
|
| GET | /api/user/self | 用户 | 获取个人资料 |
|
||||||
|
| GET | /api/user/models | 用户 | 获取模型可见性 |
|
||||||
|
| PUT | /api/user/self | 用户 | 修改个人资料 |
|
||||||
|
| DELETE | /api/user/self | 用户 | 注销账号 |
|
||||||
|
| GET | /api/user/token | 用户 | 生成用户级别 Access Token |
|
||||||
|
| GET | /api/user/aff | 用户 | 获取推广码信息 |
|
||||||
|
| POST | /api/user/topup | 用户 | 余额直充 |
|
||||||
|
| POST | /api/user/pay | 用户 | 提交支付订单 |
|
||||||
|
| POST | /api/user/amount | 用户 | 余额支付 |
|
||||||
|
| POST | /api/user/aff_transfer | 用户 | 推广额度转账 |
|
||||||
|
| PUT | /api/user/setting | 用户 | 更新用户设置 |
|
||||||
|
|
||||||
|
### 5.3 管理员用户管理
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/user/ | 管理员 | 获取全部用户列表 |
|
||||||
|
| GET | /api/user/search | 管理员 | 搜索用户 |
|
||||||
|
| GET | /api/user/:id | 管理员 | 获取单个用户信息 |
|
||||||
|
| POST | /api/user/ | 管理员 | 创建用户 |
|
||||||
|
| POST | /api/user/manage | 管理员 | 冻结/重置等管理操作 |
|
||||||
|
| PUT | /api/user/ | 管理员 | 更新用户 |
|
||||||
|
| DELETE | /api/user/:id | 管理员 | 删除用户 |
|
||||||
|
|
||||||
|
## 6. 站点选项 (Root)
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/option/ | Root | 获取全局配置 |
|
||||||
|
| PUT | /api/option/ | Root | 更新全局配置 |
|
||||||
|
| POST | /api/option/rest_model_ratio | Root | 重置模型倍率 |
|
||||||
|
| POST | /api/option/migrate_console_setting | Root | 迁移旧版控制台配置 |
|
||||||
|
|
||||||
|
## 7. 模型倍率同步 (Root)
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/ratio_sync/channels | Root | 获取可同步渠道列表 |
|
||||||
|
| POST | /api/ratio_sync/fetch | Root | 从上游拉取倍率 |
|
||||||
|
|
||||||
|
## 8. 渠道管理 (管理员)
|
||||||
|
| 方法 | 路径 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| GET | /api/channel/ | 获取渠道列表 |
|
||||||
|
| GET | /api/channel/search | 搜索渠道 |
|
||||||
|
| GET | /api/channel/models | 查询渠道模型能力 |
|
||||||
|
| GET | /api/channel/models_enabled | 查询启用模型能力 |
|
||||||
|
| GET | /api/channel/:id | 获取单个渠道 |
|
||||||
|
| GET | /api/channel/test | 批量测试渠道连通性 |
|
||||||
|
| GET | /api/channel/test/:id | 单个渠道测试 |
|
||||||
|
| GET | /api/channel/update_balance | 批量刷新余额 |
|
||||||
|
| GET | /api/channel/update_balance/:id | 单个刷新余额 |
|
||||||
|
| POST | /api/channel/ | 新增渠道 |
|
||||||
|
| PUT | /api/channel/ | 更新渠道 |
|
||||||
|
| DELETE | /api/channel/disabled | 删除已禁用渠道 |
|
||||||
|
| POST | /api/channel/tag/disabled | 批量禁用标签渠道 |
|
||||||
|
| POST | /api/channel/tag/enabled | 批量启用标签渠道 |
|
||||||
|
| PUT | /api/channel/tag | 编辑渠道标签 |
|
||||||
|
| DELETE | /api/channel/:id | 删除渠道 |
|
||||||
|
| POST | /api/channel/batch | 批量删除渠道 |
|
||||||
|
| POST | /api/channel/fix | 修复渠道能力表 |
|
||||||
|
| GET | /api/channel/fetch_models/:id | 拉取单渠道模型 |
|
||||||
|
| POST | /api/channel/fetch_models | 拉取全部渠道模型 |
|
||||||
|
| POST | /api/channel/batch/tag | 批量设置渠道标签 |
|
||||||
|
| GET | /api/channel/tag/models | 根据标签获取模型 |
|
||||||
|
| POST | /api/channel/copy/:id | 复制渠道 |
|
||||||
|
|
||||||
|
## 9. Token 管理
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/token/ | 用户 | 获取全部 Token |
|
||||||
|
| GET | /api/token/search | 用户 | 搜索 Token |
|
||||||
|
| GET | /api/token/:id | 用户 | 获取单个 Token |
|
||||||
|
| POST | /api/token/ | 用户 | 创建 Token |
|
||||||
|
| PUT | /api/token/ | 用户 | 更新 Token |
|
||||||
|
| DELETE | /api/token/:id | 用户 | 删除 Token |
|
||||||
|
| POST | /api/token/batch | 用户 | 批量删除 Token |
|
||||||
|
|
||||||
|
## 10. 兑换码管理 (管理员)
|
||||||
|
| 方法 | 路径 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| GET | /api/redemption/ | 获取兑换码列表 |
|
||||||
|
| GET | /api/redemption/search | 搜索兑换码 |
|
||||||
|
| GET | /api/redemption/:id | 获取单个兑换码 |
|
||||||
|
| POST | /api/redemption/ | 创建兑换码 |
|
||||||
|
| PUT | /api/redemption/ | 更新兑换码 |
|
||||||
|
| DELETE | /api/redemption/invalid | 删除无效兑换码 |
|
||||||
|
| DELETE | /api/redemption/:id | 删除兑换码 |
|
||||||
|
|
||||||
|
## 11. 日志
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/log/ | 管理员 | 获取全部日志 |
|
||||||
|
| DELETE | /api/log/ | 管理员 | 删除历史日志 |
|
||||||
|
| GET | /api/log/stat | 管理员 | 日志统计 |
|
||||||
|
| GET | /api/log/self/stat | 用户 | 我的日志统计 |
|
||||||
|
| GET | /api/log/search | 管理员 | 搜索全部日志 |
|
||||||
|
| GET | /api/log/self | 用户 | 获取我的日志 |
|
||||||
|
| GET | /api/log/self/search | 用户 | 搜索我的日志 |
|
||||||
|
| GET | /api/log/token | 公开 | 根据 Token 查询日志(支持 CORS) |
|
||||||
|
|
||||||
|
## 12. 数据统计
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/data/ | 管理员 | 全站用量按日期统计 |
|
||||||
|
| GET | /api/data/self | 用户 | 我的用量按日期统计 |
|
||||||
|
|
||||||
|
## 13. 分组
|
||||||
|
| GET | /api/group/ | 管理员 | 获取全部分组列表 |
|
||||||
|
|
||||||
|
## 14. Midjourney 任务
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/mj/self | 用户 | 获取自己的 MJ 任务 |
|
||||||
|
| GET | /api/mj/ | 管理员 | 获取全部 MJ 任务 |
|
||||||
|
|
||||||
|
## 15. 任务中心
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/task/self | 用户 | 获取我的任务 |
|
||||||
|
| GET | /api/task/ | 管理员 | 获取全部任务 |
|
||||||
|
|
||||||
|
## 16. 账户计费面板 (Dashboard)
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /dashboard/billing/subscription | 用户 Token | 获取订阅额度信息 |
|
||||||
|
| GET | /v1/dashboard/billing/subscription | 同上 | 兼容 OpenAI SDK 路径 |
|
||||||
|
| GET | /dashboard/billing/usage | 用户 Token | 获取使用量信息 |
|
||||||
|
| GET | /v1/dashboard/billing/usage | 同上 | 兼容 OpenAI SDK 路径 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
> **更新日期**:2025.07.17
|
||||||
BIN
docs/images/aliyun.png
Normal file
BIN
docs/images/aliyun.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.0 KiB |
BIN
docs/images/cherry-studio.png
Normal file
BIN
docs/images/cherry-studio.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 11 KiB |
BIN
docs/images/io-net.png
Normal file
BIN
docs/images/io-net.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.0 KiB |
BIN
docs/images/pku.png
Normal file
BIN
docs/images/pku.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 12 KiB |
BIN
docs/images/ucloud.png
Normal file
BIN
docs/images/ucloud.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 11 KiB |
24
dto/audio.go
24
dto/audio.go
@@ -1,5 +1,11 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
type AudioRequest struct {
|
type AudioRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Input string `json:"input"`
|
Input string `json:"input"`
|
||||||
@@ -8,6 +14,24 @@ type AudioRequest struct {
|
|||||||
ResponseFormat string `json:"response_format,omitempty"`
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
meta := &types.TokenCountMeta{
|
||||||
|
CombineText: r.Input,
|
||||||
|
TokenType: types.TokenTypeTextNumber,
|
||||||
|
}
|
||||||
|
return meta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AudioRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AudioRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type AudioResponse struct {
|
type AudioResponse struct {
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,22 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
type ChannelSettings struct {
|
type ChannelSettings struct {
|
||||||
ForceFormat bool `json:"force_format,omitempty"`
|
ForceFormat bool `json:"force_format,omitempty"`
|
||||||
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
|
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
|
||||||
Proxy string `json:"proxy"`
|
Proxy string `json:"proxy"`
|
||||||
|
PassThroughBodyEnabled bool `json:"pass_through_body_enabled,omitempty"`
|
||||||
|
SystemPrompt string `json:"system_prompt,omitempty"`
|
||||||
|
SystemPromptOverride bool `json:"system_prompt_override,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type VertexKeyType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
VertexKeyTypeJSON VertexKeyType = "json"
|
||||||
|
VertexKeyTypeAPIKey VertexKeyType = "api_key"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChannelOtherSettings struct {
|
||||||
|
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||||
|
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
|
||||||
}
|
}
|
||||||
|
|||||||
283
dto/claude.go
283
dto/claude.go
@@ -2,8 +2,12 @@ package dto
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ClaudeMetadata struct {
|
type ClaudeMetadata struct {
|
||||||
@@ -80,7 +84,7 @@ func (c *ClaudeMediaMessage) GetStringContent() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaudeMediaMessage) GetJsonRowString() string {
|
func (c *ClaudeMediaMessage) GetJsonRowString() string {
|
||||||
jsonContent, _ := json.Marshal(c)
|
jsonContent, _ := common.Marshal(c)
|
||||||
return string(jsonContent)
|
return string(jsonContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -159,6 +163,27 @@ type InputSchema struct {
|
|||||||
Required any `json:"required,omitempty"`
|
Required any `json:"required,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ClaudeWebSearchTool struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
MaxUses int `json:"max_uses,omitempty"`
|
||||||
|
UserLocation *ClaudeWebSearchUserLocation `json:"user_location,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClaudeWebSearchUserLocation struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Timezone string `json:"timezone,omitempty"`
|
||||||
|
Country string `json:"country,omitempty"`
|
||||||
|
Region string `json:"region,omitempty"`
|
||||||
|
City string `json:"city,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClaudeToolChoice struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type ClaudeRequest struct {
|
type ClaudeRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt,omitempty"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
@@ -177,6 +202,200 @@ type ClaudeRequest struct {
|
|||||||
Thinking *Thinking `json:"thinking,omitempty"`
|
Thinking *Thinking `json:"thinking,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var tokenCountMeta = types.TokenCountMeta{
|
||||||
|
TokenType: types.TokenTypeTokenizer,
|
||||||
|
MaxTokens: int(c.MaxTokens),
|
||||||
|
}
|
||||||
|
|
||||||
|
var texts = make([]string, 0)
|
||||||
|
var fileMeta = make([]*types.FileMeta, 0)
|
||||||
|
|
||||||
|
// system
|
||||||
|
if c.System != nil {
|
||||||
|
if c.IsStringSystem() {
|
||||||
|
sys := c.GetStringSystem()
|
||||||
|
if sys != "" {
|
||||||
|
texts = append(texts, sys)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
systemMedia := c.ParseSystem()
|
||||||
|
for _, media := range systemMedia {
|
||||||
|
switch media.Type {
|
||||||
|
case "text":
|
||||||
|
texts = append(texts, media.GetText())
|
||||||
|
case "image":
|
||||||
|
if media.Source != nil {
|
||||||
|
data := media.Source.Url
|
||||||
|
if data == "" {
|
||||||
|
data = common.Interface2String(media.Source.Data)
|
||||||
|
}
|
||||||
|
if data != "" {
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// messages
|
||||||
|
for _, message := range c.Messages {
|
||||||
|
tokenCountMeta.MessagesCount++
|
||||||
|
texts = append(texts, message.Role)
|
||||||
|
if message.IsStringContent() {
|
||||||
|
content := message.GetStringContent()
|
||||||
|
if content != "" {
|
||||||
|
texts = append(texts, content)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
content, _ := message.ParseContent()
|
||||||
|
for _, media := range content {
|
||||||
|
switch media.Type {
|
||||||
|
case "text":
|
||||||
|
texts = append(texts, media.GetText())
|
||||||
|
case "image":
|
||||||
|
if media.Source != nil {
|
||||||
|
data := media.Source.Url
|
||||||
|
if data == "" {
|
||||||
|
data = common.Interface2String(media.Source.Data)
|
||||||
|
}
|
||||||
|
if data != "" {
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "tool_use":
|
||||||
|
if media.Name != "" {
|
||||||
|
texts = append(texts, media.Name)
|
||||||
|
}
|
||||||
|
if media.Input != nil {
|
||||||
|
b, _ := common.Marshal(media.Input)
|
||||||
|
texts = append(texts, string(b))
|
||||||
|
}
|
||||||
|
case "tool_result":
|
||||||
|
if media.Content != nil {
|
||||||
|
b, _ := common.Marshal(media.Content)
|
||||||
|
texts = append(texts, string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tools
|
||||||
|
if c.Tools != nil {
|
||||||
|
tools := c.GetTools()
|
||||||
|
normalTools, webSearchTools := ProcessTools(tools)
|
||||||
|
if normalTools != nil {
|
||||||
|
for _, t := range normalTools {
|
||||||
|
tokenCountMeta.ToolsCount++
|
||||||
|
if t.Name != "" {
|
||||||
|
texts = append(texts, t.Name)
|
||||||
|
}
|
||||||
|
if t.Description != "" {
|
||||||
|
texts = append(texts, t.Description)
|
||||||
|
}
|
||||||
|
if t.InputSchema != nil {
|
||||||
|
b, _ := common.Marshal(t.InputSchema)
|
||||||
|
texts = append(texts, string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if webSearchTools != nil {
|
||||||
|
for _, t := range webSearchTools {
|
||||||
|
tokenCountMeta.ToolsCount++
|
||||||
|
if t.Name != "" {
|
||||||
|
texts = append(texts, t.Name)
|
||||||
|
}
|
||||||
|
if t.UserLocation != nil {
|
||||||
|
b, _ := common.Marshal(t.UserLocation)
|
||||||
|
texts = append(texts, string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenCountMeta.CombineText = strings.Join(texts, "\n")
|
||||||
|
tokenCountMeta.Files = fileMeta
|
||||||
|
return &tokenCountMeta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeRequest) IsStream(ctx *gin.Context) bool {
|
||||||
|
return c.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
c.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
|
||||||
|
for _, message := range c.Messages {
|
||||||
|
content, _ := message.ParseContent()
|
||||||
|
for _, mediaMessage := range content {
|
||||||
|
if mediaMessage.Id == toolCallId {
|
||||||
|
return mediaMessage.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTool 添加工具到请求中
|
||||||
|
func (c *ClaudeRequest) AddTool(tool any) {
|
||||||
|
if c.Tools == nil {
|
||||||
|
c.Tools = make([]any, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch tools := c.Tools.(type) {
|
||||||
|
case []any:
|
||||||
|
c.Tools = append(tools, tool)
|
||||||
|
default:
|
||||||
|
// 如果Tools不是[]any类型,重新初始化为[]any
|
||||||
|
c.Tools = []any{tool}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTools 获取工具列表
|
||||||
|
func (c *ClaudeRequest) GetTools() []any {
|
||||||
|
if c.Tools == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch tools := c.Tools.(type) {
|
||||||
|
case []any:
|
||||||
|
return tools
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessTools 处理工具列表,支持类型断言
|
||||||
|
func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) {
|
||||||
|
var normalTools []*Tool
|
||||||
|
var webSearchTools []*ClaudeWebSearchTool
|
||||||
|
|
||||||
|
for _, tool := range tools {
|
||||||
|
switch t := tool.(type) {
|
||||||
|
case *Tool:
|
||||||
|
normalTools = append(normalTools, t)
|
||||||
|
case *ClaudeWebSearchTool:
|
||||||
|
webSearchTools = append(webSearchTools, t)
|
||||||
|
case Tool:
|
||||||
|
normalTools = append(normalTools, &t)
|
||||||
|
case ClaudeWebSearchTool:
|
||||||
|
webSearchTools = append(webSearchTools, &t)
|
||||||
|
default:
|
||||||
|
// 未知类型,跳过
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return normalTools, webSearchTools
|
||||||
|
}
|
||||||
|
|
||||||
type Thinking struct {
|
type Thinking struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
BudgetTokens *int `json:"budget_tokens,omitempty"`
|
BudgetTokens *int `json:"budget_tokens,omitempty"`
|
||||||
@@ -210,14 +429,9 @@ func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
|
|||||||
return mediaContent
|
return mediaContent
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClaudeError struct {
|
|
||||||
Type string `json:"type,omitempty"`
|
|
||||||
Message string `json:"message,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ClaudeErrorWithStatusCode struct {
|
type ClaudeErrorWithStatusCode struct {
|
||||||
Error ClaudeError `json:"error"`
|
Error types.ClaudeError `json:"error"`
|
||||||
StatusCode int `json:"status_code"`
|
StatusCode int `json:"status_code"`
|
||||||
LocalError bool
|
LocalError bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,7 +443,7 @@ type ClaudeResponse struct {
|
|||||||
Completion string `json:"completion,omitempty"`
|
Completion string `json:"completion,omitempty"`
|
||||||
StopReason string `json:"stop_reason,omitempty"`
|
StopReason string `json:"stop_reason,omitempty"`
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
Error *types.ClaudeError `json:"error,omitempty"`
|
Error any `json:"error,omitempty"`
|
||||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||||
Index *int `json:"index,omitempty"`
|
Index *int `json:"index,omitempty"`
|
||||||
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
||||||
@@ -250,9 +464,50 @@ func (c *ClaudeResponse) GetIndex() int {
|
|||||||
return *c.Index
|
return *c.Index
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClaudeUsage struct {
|
// GetClaudeError 从动态错误类型中提取ClaudeError结构
|
||||||
InputTokens int `json:"input_tokens"`
|
func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
|
||||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
if c.Error == nil {
|
||||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
return nil
|
||||||
OutputTokens int `json:"output_tokens"`
|
}
|
||||||
|
|
||||||
|
switch err := c.Error.(type) {
|
||||||
|
case types.ClaudeError:
|
||||||
|
return &err
|
||||||
|
case *types.ClaudeError:
|
||||||
|
return err
|
||||||
|
case map[string]interface{}:
|
||||||
|
// 处理从JSON解析来的map结构
|
||||||
|
claudeErr := &types.ClaudeError{}
|
||||||
|
if errType, ok := err["type"].(string); ok {
|
||||||
|
claudeErr.Type = errType
|
||||||
|
}
|
||||||
|
if errMsg, ok := err["message"].(string); ok {
|
||||||
|
claudeErr.Message = errMsg
|
||||||
|
}
|
||||||
|
return claudeErr
|
||||||
|
case string:
|
||||||
|
// 处理简单字符串错误
|
||||||
|
return &types.ClaudeError{
|
||||||
|
Type: "upstream_error",
|
||||||
|
Message: err,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// 未知类型,尝试转换为字符串
|
||||||
|
return &types.ClaudeError{
|
||||||
|
Type: "unknown_upstream_error",
|
||||||
|
Message: fmt.Sprintf("unknown_error: %v", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClaudeUsage struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||||
|
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClaudeServerToolUse struct {
|
||||||
|
WebSearchRequests int `json:"web_search_requests"`
|
||||||
}
|
}
|
||||||
|
|||||||
29
dto/dalle.go
29
dto/dalle.go
@@ -1,29 +0,0 @@
|
|||||||
package dto
|
|
||||||
|
|
||||||
import "encoding/json"
|
|
||||||
|
|
||||||
type ImageRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Prompt string `json:"prompt" binding:"required"`
|
|
||||||
N int `json:"n,omitempty"`
|
|
||||||
Size string `json:"size,omitempty"`
|
|
||||||
Quality string `json:"quality,omitempty"`
|
|
||||||
ResponseFormat string `json:"response_format,omitempty"`
|
|
||||||
Style string `json:"style,omitempty"`
|
|
||||||
User string `json:"user,omitempty"`
|
|
||||||
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
|
|
||||||
Background string `json:"background,omitempty"`
|
|
||||||
Moderation string `json:"moderation,omitempty"`
|
|
||||||
OutputFormat string `json:"output_format,omitempty"`
|
|
||||||
Watermark *bool `json:"watermark,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ImageResponse struct {
|
|
||||||
Data []ImageData `json:"data"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
}
|
|
||||||
type ImageData struct {
|
|
||||||
Url string `json:"url"`
|
|
||||||
B64Json string `json:"b64_json"`
|
|
||||||
RevisedPrompt string `json:"revised_prompt"`
|
|
||||||
}
|
|
||||||
@@ -1,5 +1,12 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
type EmbeddingOptions struct {
|
type EmbeddingOptions struct {
|
||||||
Seed int `json:"seed,omitempty"`
|
Seed int `json:"seed,omitempty"`
|
||||||
Temperature *float64 `json:"temperature,omitempty"`
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
@@ -24,9 +31,32 @@ type EmbeddingRequest struct {
|
|||||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r EmbeddingRequest) ParseInput() []string {
|
func (r *EmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var texts = make([]string, 0)
|
||||||
|
|
||||||
|
inputs := r.ParseInput()
|
||||||
|
for _, input := range inputs {
|
||||||
|
texts = append(texts, input)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: strings.Join(texts, "\n"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *EmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *EmbeddingRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *EmbeddingRequest) ParseInput() []string {
|
||||||
if r.Input == nil {
|
if r.Input == nil {
|
||||||
return nil
|
return make([]string, 0)
|
||||||
}
|
}
|
||||||
var input []string
|
var input []string
|
||||||
switch r.Input.(type) {
|
switch r.Input.(type) {
|
||||||
|
|||||||
@@ -1,15 +1,117 @@
|
|||||||
package gemini
|
package dto
|
||||||
|
|
||||||
import "encoding/json"
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/logger"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
type GeminiChatRequest struct {
|
type GeminiChatRequest struct {
|
||||||
Contents []GeminiChatContent `json:"contents"`
|
Contents []GeminiChatContent `json:"contents"`
|
||||||
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
||||||
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||||
Tools []GeminiChatTool `json:"tools,omitempty"`
|
Tools json.RawMessage `json:"tools,omitempty"`
|
||||||
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
|
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var files []*types.FileMeta = make([]*types.FileMeta, 0)
|
||||||
|
|
||||||
|
var maxTokens int
|
||||||
|
|
||||||
|
if r.GenerationConfig.MaxOutputTokens > 0 {
|
||||||
|
maxTokens = int(r.GenerationConfig.MaxOutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
var inputTexts []string
|
||||||
|
for _, content := range r.Contents {
|
||||||
|
for _, part := range content.Parts {
|
||||||
|
if part.Text != "" {
|
||||||
|
inputTexts = append(inputTexts, part.Text)
|
||||||
|
}
|
||||||
|
if part.InlineData != nil && part.InlineData.Data != "" {
|
||||||
|
if strings.HasPrefix(part.InlineData.MimeType, "image/") {
|
||||||
|
files = append(files, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeImage,
|
||||||
|
OriginData: part.InlineData.Data,
|
||||||
|
})
|
||||||
|
} else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
|
||||||
|
files = append(files, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeAudio,
|
||||||
|
OriginData: part.InlineData.Data,
|
||||||
|
})
|
||||||
|
} else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
|
||||||
|
files = append(files, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeVideo,
|
||||||
|
OriginData: part.InlineData.Data,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
files = append(files, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeFile,
|
||||||
|
OriginData: part.InlineData.Data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputText := strings.Join(inputTexts, "\n")
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: inputText,
|
||||||
|
Files: files,
|
||||||
|
MaxTokens: maxTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
|
||||||
|
if c.Query("alt") == "sse" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) SetModelName(modelName string) {
|
||||||
|
// GeminiChatRequest does not have a model field, so this method does nothing.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
|
||||||
|
var tools []GeminiChatTool
|
||||||
|
if strings.HasSuffix(string(r.Tools), "[") {
|
||||||
|
// is array
|
||||||
|
if err := common.Unmarshal(r.Tools, &tools); err != nil {
|
||||||
|
logger.LogError(nil, "error_unmarshalling_tools: "+err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(string(r.Tools), "{") {
|
||||||
|
// is object
|
||||||
|
singleTool := GeminiChatTool{}
|
||||||
|
if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
|
||||||
|
logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tools = []GeminiChatTool{singleTool}
|
||||||
|
}
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
|
||||||
|
if len(tools) == 0 {
|
||||||
|
r.Tools = json.RawMessage("[]")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal the tools to JSON
|
||||||
|
data, err := common.Marshal(tools)
|
||||||
|
if err != nil {
|
||||||
|
logger.LogError(nil, "error_marshalling_tools: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.Tools = data
|
||||||
|
}
|
||||||
|
|
||||||
type GeminiThinkingConfig struct {
|
type GeminiThinkingConfig struct {
|
||||||
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||||
@@ -32,7 +134,7 @@ func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
|
|||||||
MimeTypeSnake string `json:"mime_type"`
|
MimeTypeSnake string `json:"mime_type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(data, &aux); err != nil {
|
if err := common.Unmarshal(data, &aux); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,7 +155,7 @@ type FunctionCall struct {
|
|||||||
Arguments any `json:"args"`
|
Arguments any `json:"args"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type FunctionResponse struct {
|
type GeminiFunctionResponse struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Response map[string]interface{} `json:"response"`
|
Response map[string]interface{} `json:"response"`
|
||||||
}
|
}
|
||||||
@@ -78,7 +180,7 @@ type GeminiPart struct {
|
|||||||
Thought bool `json:"thought,omitempty"`
|
Thought bool `json:"thought,omitempty"`
|
||||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||||
FileData *GeminiFileData `json:"fileData,omitempty"`
|
FileData *GeminiFileData `json:"fileData,omitempty"`
|
||||||
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
||||||
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
||||||
@@ -93,7 +195,7 @@ func (p *GeminiPart) UnmarshalJSON(data []byte) error {
|
|||||||
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
|
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(data, &aux); err != nil {
|
if err := common.Unmarshal(data, &aux); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,16 +309,76 @@ type GeminiImagePrediction struct {
|
|||||||
|
|
||||||
// Embedding related structs
|
// Embedding related structs
|
||||||
type GeminiEmbeddingRequest struct {
|
type GeminiEmbeddingRequest struct {
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
Content GeminiChatContent `json:"content"`
|
Content GeminiChatContent `json:"content"`
|
||||||
TaskType string `json:"taskType,omitempty"`
|
TaskType string `json:"taskType,omitempty"`
|
||||||
Title string `json:"title,omitempty"`
|
Title string `json:"title,omitempty"`
|
||||||
OutputDimensionality int `json:"outputDimensionality,omitempty"`
|
OutputDimensionality int `json:"outputDimensionality,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *GeminiEmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||||
|
// Gemini embedding requests are not streamed
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var inputTexts []string
|
||||||
|
for _, part := range r.Content.Parts {
|
||||||
|
if part.Text != "" {
|
||||||
|
inputTexts = append(inputTexts, part.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inputText := strings.Join(inputTexts, "\n")
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: inputText,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiEmbeddingRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiBatchEmbeddingRequest struct {
|
||||||
|
Requests []*GeminiEmbeddingRequest `json:"requests"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiBatchEmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||||
|
// Gemini batch embedding requests are not streamed
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiBatchEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var inputTexts []string
|
||||||
|
for _, request := range r.Requests {
|
||||||
|
meta := request.GetTokenCountMeta()
|
||||||
|
if meta != nil && meta.CombineText != "" {
|
||||||
|
inputTexts = append(inputTexts, meta.CombineText)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inputText := strings.Join(inputTexts, "\n")
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: inputText,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiBatchEmbeddingRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
for _, req := range r.Requests {
|
||||||
|
req.SetModelName(modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type GeminiEmbeddingResponse struct {
|
type GeminiEmbeddingResponse struct {
|
||||||
Embedding ContentEmbedding `json:"embedding"`
|
Embedding ContentEmbedding `json:"embedding"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GeminiBatchEmbeddingResponse struct {
|
||||||
|
Embeddings []*ContentEmbedding `json:"embeddings"`
|
||||||
|
}
|
||||||
|
|
||||||
type ContentEmbedding struct {
|
type ContentEmbedding struct {
|
||||||
Values []float64 `json:"values"`
|
Values []float64 `json:"values"`
|
||||||
}
|
}
|
||||||
172
dto/openai_image.go
Normal file
172
dto/openai_image.go
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ImageRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt" binding:"required"`
|
||||||
|
N uint `json:"n,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
Quality string `json:"quality,omitempty"`
|
||||||
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
|
Style json.RawMessage `json:"style,omitempty"`
|
||||||
|
User json.RawMessage `json:"user,omitempty"`
|
||||||
|
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
|
||||||
|
Background json.RawMessage `json:"background,omitempty"`
|
||||||
|
Moderation json.RawMessage `json:"moderation,omitempty"`
|
||||||
|
OutputFormat json.RawMessage `json:"output_format,omitempty"`
|
||||||
|
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
|
||||||
|
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
||||||
|
// Stream bool `json:"stream,omitempty"`
|
||||||
|
Watermark *bool `json:"watermark,omitempty"`
|
||||||
|
// 用匿名参数接收额外参数
|
||||||
|
Extra map[string]json.RawMessage `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ImageRequest) UnmarshalJSON(data []byte) error {
|
||||||
|
// 先解析成 map[string]interface{}
|
||||||
|
var rawMap map[string]json.RawMessage
|
||||||
|
if err := common.Unmarshal(data, &rawMap); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用 struct tag 获取所有已定义字段名
|
||||||
|
knownFields := GetJSONFieldNames(reflect.TypeOf(*i))
|
||||||
|
|
||||||
|
// 再正常解析已定义字段
|
||||||
|
type Alias ImageRequest
|
||||||
|
var known Alias
|
||||||
|
if err := common.Unmarshal(data, &known); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*i = ImageRequest(known)
|
||||||
|
|
||||||
|
// 提取多余字段
|
||||||
|
i.Extra = make(map[string]json.RawMessage)
|
||||||
|
for k, v := range rawMap {
|
||||||
|
if _, ok := knownFields[k]; !ok {
|
||||||
|
i.Extra[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 序列化时需要重新把字段平铺
|
||||||
|
func (r ImageRequest) MarshalJSON() ([]byte, error) {
|
||||||
|
// 将已定义字段转为 map
|
||||||
|
type Alias ImageRequest
|
||||||
|
alias := Alias(r)
|
||||||
|
base, err := common.Marshal(alias)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var baseMap map[string]json.RawMessage
|
||||||
|
if err := common.Unmarshal(base, &baseMap); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 合并 ExtraFields
|
||||||
|
for k, v := range r.Extra {
|
||||||
|
if _, exists := baseMap[k]; !exists {
|
||||||
|
baseMap[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(baseMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
|
||||||
|
fields := make(map[string]struct{})
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
field := t.Field(i)
|
||||||
|
|
||||||
|
// 跳过匿名字段(例如 ExtraFields)
|
||||||
|
if field.Anonymous {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tag := field.Tag.Get("json")
|
||||||
|
if tag == "-" || tag == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 取逗号前字段名(排除 omitempty 等)
|
||||||
|
name := tag
|
||||||
|
if commaIdx := indexComma(tag); commaIdx != -1 {
|
||||||
|
name = tag[:commaIdx]
|
||||||
|
}
|
||||||
|
fields[name] = struct{}{}
|
||||||
|
}
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexComma(s string) int {
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
if s[i] == ',' {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var sizeRatio = 1.0
|
||||||
|
var qualityRatio = 1.0
|
||||||
|
|
||||||
|
if strings.HasPrefix(i.Model, "dall-e") {
|
||||||
|
// Size
|
||||||
|
if i.Size == "256x256" {
|
||||||
|
sizeRatio = 0.4
|
||||||
|
} else if i.Size == "512x512" {
|
||||||
|
sizeRatio = 0.45
|
||||||
|
} else if i.Size == "1024x1024" {
|
||||||
|
sizeRatio = 1
|
||||||
|
} else if i.Size == "1024x1792" || i.Size == "1792x1024" {
|
||||||
|
sizeRatio = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if i.Model == "dall-e-3" && i.Quality == "hd" {
|
||||||
|
qualityRatio = 2.0
|
||||||
|
if i.Size == "1024x1792" || i.Size == "1792x1024" {
|
||||||
|
qualityRatio = 1.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// not support token count for dalle
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: i.Prompt,
|
||||||
|
MaxTokens: 1584,
|
||||||
|
ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ImageRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ImageRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
i.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageResponse struct {
|
||||||
|
Data []ImageData `json:"data"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Extra any `json:"extra,omitempty"`
|
||||||
|
}
|
||||||
|
type ImageData struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
B64Json string `json:"b64_json"`
|
||||||
|
RevisedPrompt string `json:"revised_prompt"`
|
||||||
|
}
|
||||||
@@ -2,20 +2,24 @@ package dto
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ResponseFormat struct {
|
type ResponseFormat struct {
|
||||||
Type string `json:"type,omitempty"`
|
Type string `json:"type,omitempty"`
|
||||||
JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
|
JsonSchema json.RawMessage `json:"json_schema,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type FormatJsonSchema struct {
|
type FormatJsonSchema struct {
|
||||||
Description string `json:"description,omitempty"`
|
Description string `json:"description,omitempty"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Schema any `json:"schema,omitempty"`
|
Schema any `json:"schema,omitempty"`
|
||||||
Strict any `json:"strict,omitempty"`
|
Strict json.RawMessage `json:"strict,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeneralOpenAIRequest struct {
|
type GeneralOpenAIRequest struct {
|
||||||
@@ -29,6 +33,7 @@ type GeneralOpenAIRequest struct {
|
|||||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||||
|
Verbosity json.RawMessage `json:"verbosity,omitempty"` // gpt-5
|
||||||
Temperature *float64 `json:"temperature,omitempty"`
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
TopK int `json:"top_k,omitempty"`
|
TopK int `json:"top_k,omitempty"`
|
||||||
@@ -52,15 +57,142 @@ type GeneralOpenAIRequest struct {
|
|||||||
Dimensions int `json:"dimensions,omitempty"`
|
Dimensions int `json:"dimensions,omitempty"`
|
||||||
Modalities json.RawMessage `json:"modalities,omitempty"`
|
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||||
Audio json.RawMessage `json:"audio,omitempty"`
|
Audio json.RawMessage `json:"audio,omitempty"`
|
||||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
// gemini
|
||||||
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
|
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
//xai
|
||||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
SearchParameters json.RawMessage `json:"search_parameters,omitempty"`
|
||||||
|
// claude
|
||||||
|
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||||
// OpenRouter Params
|
// OpenRouter Params
|
||||||
Usage json.RawMessage `json:"usage,omitempty"`
|
Usage json.RawMessage `json:"usage,omitempty"`
|
||||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||||
// Ali Qwen Params
|
// Ali Qwen Params
|
||||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||||
|
EnableThinking any `json:"enable_thinking,omitempty"`
|
||||||
|
// ollama Params
|
||||||
|
Think json.RawMessage `json:"think,omitempty"`
|
||||||
|
// baidu v2
|
||||||
|
WebSearch json.RawMessage `json:"web_search,omitempty"`
|
||||||
|
// doubao,zhipu_v4
|
||||||
|
THINKING json.RawMessage `json:"thinking,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var tokenCountMeta types.TokenCountMeta
|
||||||
|
var texts = make([]string, 0)
|
||||||
|
var fileMeta = make([]*types.FileMeta, 0)
|
||||||
|
|
||||||
|
if r.Prompt != nil {
|
||||||
|
switch v := r.Prompt.(type) {
|
||||||
|
case string:
|
||||||
|
texts = append(texts, v)
|
||||||
|
case []any:
|
||||||
|
for _, item := range v {
|
||||||
|
if str, ok := item.(string); ok {
|
||||||
|
texts = append(texts, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
texts = append(texts, fmt.Sprintf("%v", r.Prompt))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Input != nil {
|
||||||
|
inputs := r.ParseInput()
|
||||||
|
texts = append(texts, inputs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.MaxCompletionTokens > r.MaxTokens {
|
||||||
|
tokenCountMeta.MaxTokens = int(r.MaxCompletionTokens)
|
||||||
|
} else {
|
||||||
|
tokenCountMeta.MaxTokens = int(r.MaxTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, message := range r.Messages {
|
||||||
|
tokenCountMeta.MessagesCount++
|
||||||
|
texts = append(texts, message.Role)
|
||||||
|
if message.Content != nil {
|
||||||
|
if message.Name != nil {
|
||||||
|
tokenCountMeta.NameCount++
|
||||||
|
texts = append(texts, *message.Name)
|
||||||
|
}
|
||||||
|
arrayContent := message.ParseContent()
|
||||||
|
for _, m := range arrayContent {
|
||||||
|
if m.Type == ContentTypeImageURL {
|
||||||
|
imageUrl := m.GetImageMedia()
|
||||||
|
if imageUrl != nil {
|
||||||
|
if imageUrl.Url != "" {
|
||||||
|
meta := &types.FileMeta{
|
||||||
|
FileType: types.FileTypeImage,
|
||||||
|
}
|
||||||
|
meta.OriginData = imageUrl.Url
|
||||||
|
meta.Detail = imageUrl.Detail
|
||||||
|
fileMeta = append(fileMeta, meta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if m.Type == ContentTypeInputAudio {
|
||||||
|
inputAudio := m.GetInputAudio()
|
||||||
|
if inputAudio != nil {
|
||||||
|
meta := &types.FileMeta{
|
||||||
|
FileType: types.FileTypeAudio,
|
||||||
|
}
|
||||||
|
meta.OriginData = inputAudio.Data
|
||||||
|
fileMeta = append(fileMeta, meta)
|
||||||
|
}
|
||||||
|
} else if m.Type == ContentTypeFile {
|
||||||
|
file := m.GetFile()
|
||||||
|
if file != nil {
|
||||||
|
meta := &types.FileMeta{
|
||||||
|
FileType: types.FileTypeFile,
|
||||||
|
}
|
||||||
|
meta.OriginData = file.FileData
|
||||||
|
fileMeta = append(fileMeta, meta)
|
||||||
|
}
|
||||||
|
} else if m.Type == ContentTypeVideoUrl {
|
||||||
|
videoUrl := m.GetVideoUrl()
|
||||||
|
if videoUrl != nil && videoUrl.Url != "" {
|
||||||
|
meta := &types.FileMeta{
|
||||||
|
FileType: types.FileTypeVideo,
|
||||||
|
}
|
||||||
|
meta.OriginData = videoUrl.Url
|
||||||
|
fileMeta = append(fileMeta, meta)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
texts = append(texts, m.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Tools != nil {
|
||||||
|
openaiTools := r.Tools
|
||||||
|
for _, tool := range openaiTools {
|
||||||
|
tokenCountMeta.ToolsCount++
|
||||||
|
texts = append(texts, tool.Function.Name)
|
||||||
|
if tool.Function.Description != "" {
|
||||||
|
texts = append(texts, tool.Function.Description)
|
||||||
|
}
|
||||||
|
if tool.Function.Parameters != nil {
|
||||||
|
texts = append(texts, fmt.Sprintf("%v", tool.Function.Parameters))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//toolTokens := CountTokenInput(countStr, request.Model)
|
||||||
|
//tkm += 8
|
||||||
|
//tkm += toolTokens
|
||||||
|
}
|
||||||
|
tokenCountMeta.CombineText = strings.Join(texts, "\n")
|
||||||
|
tokenCountMeta.Files = fileMeta
|
||||||
|
return &tokenCountMeta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return r.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeneralOpenAIRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||||
@@ -70,6 +202,17 @@ func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
|
||||||
|
if strings.HasPrefix(r.Model, "o") {
|
||||||
|
if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
|
||||||
|
return "developer"
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(r.Model, "gpt-5") {
|
||||||
|
return "developer"
|
||||||
|
}
|
||||||
|
return "system"
|
||||||
|
}
|
||||||
|
|
||||||
type ToolCallRequest struct {
|
type ToolCallRequest struct {
|
||||||
ID string `json:"id,omitempty"`
|
ID string `json:"id,omitempty"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
@@ -87,8 +230,11 @@ type StreamOptions struct {
|
|||||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *GeneralOpenAIRequest) GetMaxTokens() int {
|
func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
|
||||||
return int(r.MaxTokens)
|
if r.MaxCompletionTokens != 0 {
|
||||||
|
return r.MaxCompletionTokens
|
||||||
|
}
|
||||||
|
return r.MaxTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *GeneralOpenAIRequest) ParseInput() []string {
|
func (r *GeneralOpenAIRequest) ParseInput() []string {
|
||||||
@@ -184,6 +330,21 @@ func (m *MediaContent) GetFile() *MessageFile {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MediaContent) GetVideoUrl() *MessageVideoUrl {
|
||||||
|
if m.VideoUrl != nil {
|
||||||
|
if _, ok := m.VideoUrl.(*MessageVideoUrl); ok {
|
||||||
|
return m.VideoUrl.(*MessageVideoUrl)
|
||||||
|
}
|
||||||
|
if itemMap, ok := m.VideoUrl.(map[string]any); ok {
|
||||||
|
out := &MessageVideoUrl{
|
||||||
|
Url: common.Interface2String(itemMap["url"]),
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type MessageImageUrl struct {
|
type MessageImageUrl struct {
|
||||||
Url string `json:"url"`
|
Url string `json:"url"`
|
||||||
Detail string `json:"detail"`
|
Detail string `json:"detail"`
|
||||||
@@ -215,6 +376,7 @@ const (
|
|||||||
ContentTypeInputAudio = "input_audio"
|
ContentTypeInputAudio = "input_audio"
|
||||||
ContentTypeFile = "file"
|
ContentTypeFile = "file"
|
||||||
ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别
|
ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别
|
||||||
|
//ContentTypeAudioUrl = "audio_url"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (m *Message) GetPrefix() bool {
|
func (m *Message) GetPrefix() bool {
|
||||||
@@ -602,26 +764,107 @@ type WebSearchOptions struct {
|
|||||||
UserLocation json.RawMessage `json:"user_location,omitempty"`
|
UserLocation json.RawMessage `json:"user_location,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/responses/create
|
||||||
type OpenAIResponsesRequest struct {
|
type OpenAIResponsesRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Input json.RawMessage `json:"input,omitempty"`
|
Input json.RawMessage `json:"input,omitempty"`
|
||||||
Include json.RawMessage `json:"include,omitempty"`
|
Include json.RawMessage `json:"include,omitempty"`
|
||||||
Instructions json.RawMessage `json:"instructions,omitempty"`
|
Instructions json.RawMessage `json:"instructions,omitempty"`
|
||||||
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
||||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"`
|
||||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||||
ServiceTier string `json:"service_tier,omitempty"`
|
ServiceTier string `json:"service_tier,omitempty"`
|
||||||
Store bool `json:"store,omitempty"`
|
Store json.RawMessage `json:"store,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
Text json.RawMessage `json:"text,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
Text json.RawMessage `json:"text,omitempty"`
|
||||||
Tools []ResponsesToolsCall `json:"tools,omitempty"`
|
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
|
||||||
Truncation string `json:"truncation,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
User string `json:"user,omitempty"`
|
Truncation string `json:"truncation,omitempty"`
|
||||||
|
User string `json:"user,omitempty"`
|
||||||
|
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
|
||||||
|
Prompt json.RawMessage `json:"prompt,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var fileMeta = make([]*types.FileMeta, 0)
|
||||||
|
var texts = make([]string, 0)
|
||||||
|
|
||||||
|
if r.Input != nil {
|
||||||
|
inputs := r.ParseInput()
|
||||||
|
for _, input := range inputs {
|
||||||
|
if input.Type == "input_image" {
|
||||||
|
if input.ImageUrl != "" {
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeImage,
|
||||||
|
OriginData: input.ImageUrl,
|
||||||
|
Detail: input.Detail,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else if input.Type == "input_file" {
|
||||||
|
if input.FileUrl != "" {
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeFile,
|
||||||
|
OriginData: input.FileUrl,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
texts = append(texts, input.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Instructions) > 0 {
|
||||||
|
texts = append(texts, string(r.Instructions))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Metadata) > 0 {
|
||||||
|
texts = append(texts, string(r.Metadata))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Text) > 0 {
|
||||||
|
texts = append(texts, string(r.Text))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.ToolChoice) > 0 {
|
||||||
|
texts = append(texts, string(r.ToolChoice))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Prompt) > 0 {
|
||||||
|
texts = append(texts, string(r.Prompt))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Tools) > 0 {
|
||||||
|
texts = append(texts, string(r.Tools))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: strings.Join(texts, "\n"),
|
||||||
|
Files: fileMeta,
|
||||||
|
MaxTokens: int(r.MaxOutputTokens),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return r.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OpenAIResponsesRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OpenAIResponsesRequest) GetToolsMap() []map[string]any {
|
||||||
|
var toolsMap []map[string]any
|
||||||
|
if len(r.Tools) > 0 {
|
||||||
|
_ = common.Unmarshal(r.Tools, &toolsMap)
|
||||||
|
}
|
||||||
|
return toolsMap
|
||||||
}
|
}
|
||||||
|
|
||||||
type Reasoning struct {
|
type Reasoning struct {
|
||||||
@@ -629,23 +872,88 @@ type Reasoning struct {
|
|||||||
Summary string `json:"summary,omitempty"`
|
Summary string `json:"summary,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ResponsesToolsCall struct {
|
type MediaInput struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
// Web Search
|
Text string `json:"text,omitempty"`
|
||||||
UserLocation json.RawMessage `json:"user_location,omitempty"`
|
FileUrl string `json:"file_url,omitempty"`
|
||||||
SearchContextSize string `json:"search_context_size,omitempty"`
|
ImageUrl string `json:"image_url,omitempty"`
|
||||||
// File Search
|
Detail string `json:"detail,omitempty"` // 仅 input_image 有效
|
||||||
VectorStoreIds []string `json:"vector_store_ids,omitempty"`
|
}
|
||||||
MaxNumResults uint `json:"max_num_results,omitempty"`
|
|
||||||
Filters json.RawMessage `json:"filters,omitempty"`
|
// ParseInput parses the Responses API `input` field into a normalized slice of MediaInput.
|
||||||
// Computer Use
|
// Reference implementation mirrors Message.ParseContent:
|
||||||
DisplayWidth uint `json:"display_width,omitempty"`
|
// - input can be a string, treated as an input_text item
|
||||||
DisplayHeight uint `json:"display_height,omitempty"`
|
// - input can be an array of objects with a `type` field
|
||||||
Environment string `json:"environment,omitempty"`
|
// supported types: input_text, input_image, input_file
|
||||||
// Function
|
func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
|
||||||
Name string `json:"name,omitempty"`
|
if r.Input == nil {
|
||||||
Description string `json:"description,omitempty"`
|
return nil
|
||||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
}
|
||||||
Function json.RawMessage `json:"function,omitempty"`
|
|
||||||
Container json.RawMessage `json:"container,omitempty"`
|
var inputs []MediaInput
|
||||||
|
|
||||||
|
// Try string first
|
||||||
|
// if str, ok := common.GetJsonType(r.Input); ok {
|
||||||
|
// inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
|
||||||
|
// return inputs
|
||||||
|
// }
|
||||||
|
if common.GetJsonType(r.Input) == "string" {
|
||||||
|
var str string
|
||||||
|
_ = common.Unmarshal(r.Input, &str)
|
||||||
|
inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
|
||||||
|
return inputs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try array of parts
|
||||||
|
if common.GetJsonType(r.Input) == "array" {
|
||||||
|
var array []any
|
||||||
|
_ = common.Unmarshal(r.Input, &array)
|
||||||
|
for _, itemAny := range array {
|
||||||
|
// Already parsed MediaInput
|
||||||
|
if media, ok := itemAny.(MediaInput); ok {
|
||||||
|
inputs = append(inputs, media)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Generic map
|
||||||
|
item, ok := itemAny.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
typeVal, ok := item["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch typeVal {
|
||||||
|
case "input_text":
|
||||||
|
text, _ := item["text"].(string)
|
||||||
|
inputs = append(inputs, MediaInput{Type: "input_text", Text: text})
|
||||||
|
case "input_image":
|
||||||
|
// image_url may be string or object with url field
|
||||||
|
var imageUrl string
|
||||||
|
switch v := item["image_url"].(type) {
|
||||||
|
case string:
|
||||||
|
imageUrl = v
|
||||||
|
case map[string]any:
|
||||||
|
if url, ok := v["url"].(string); ok {
|
||||||
|
imageUrl = url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
|
||||||
|
case "input_file":
|
||||||
|
// file_url may be string or object with url field
|
||||||
|
var fileUrl string
|
||||||
|
switch v := item["file_url"].(type) {
|
||||||
|
case string:
|
||||||
|
fileUrl = v
|
||||||
|
case map[string]any:
|
||||||
|
if url, ok := v["url"].(string); ok {
|
||||||
|
fileUrl = url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return inputs
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,12 +2,22 @@ package dto
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ResponsesOutputTypeImageGenerationCall = "image_generation_call"
|
||||||
|
)
|
||||||
|
|
||||||
type SimpleResponse struct {
|
type SimpleResponse struct {
|
||||||
Usage `json:"usage"`
|
Usage `json:"usage"`
|
||||||
Error *OpenAIError `json:"error"`
|
Error any `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||||
|
func (s *SimpleResponse) GetOpenAIError() *types.OpenAIError {
|
||||||
|
return GetOpenAIError(s.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextResponse struct {
|
type TextResponse struct {
|
||||||
@@ -31,10 +41,15 @@ type OpenAITextResponse struct {
|
|||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Created any `json:"created"`
|
Created any `json:"created"`
|
||||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||||
Error *types.OpenAIError `json:"error,omitempty"`
|
Error any `json:"error,omitempty"`
|
||||||
Usage `json:"usage"`
|
Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||||
|
func (o *OpenAITextResponse) GetOpenAIError() *types.OpenAIError {
|
||||||
|
return GetOpenAIError(o.Error)
|
||||||
|
}
|
||||||
|
|
||||||
type OpenAIEmbeddingResponseItem struct {
|
type OpenAIEmbeddingResponseItem struct {
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
@@ -48,6 +63,19 @@ type OpenAIEmbeddingResponse struct {
|
|||||||
Usage `json:"usage"`
|
Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FlexibleEmbeddingResponseItem struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
Embedding any `json:"embedding"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlexibleEmbeddingResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []FlexibleEmbeddingResponseItem `json:"data"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
type ChatCompletionsStreamResponseChoice struct {
|
type ChatCompletionsStreamResponseChoice struct {
|
||||||
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
|
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
|
||||||
Logprobs *any `json:"logprobs"`
|
Logprobs *any `json:"logprobs"`
|
||||||
@@ -86,7 +114,7 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string
|
|||||||
|
|
||||||
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
|
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
|
||||||
c.ReasoningContent = &s
|
c.ReasoningContent = &s
|
||||||
c.Reasoning = &s
|
//c.Reasoning = &s
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolCallResponse struct {
|
type ToolCallResponse struct {
|
||||||
@@ -119,6 +147,13 @@ type ChatCompletionsStreamResponse struct {
|
|||||||
Usage *Usage `json:"usage"`
|
Usage *Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ChatCompletionsStreamResponse) IsFinished() bool {
|
||||||
|
if len(c.Choices) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return c.Choices[0].FinishReason != nil && *c.Choices[0].FinishReason != ""
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
|
func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
|
||||||
if len(c.Choices) == 0 {
|
if len(c.Choices) == 0 {
|
||||||
return false
|
return false
|
||||||
@@ -133,6 +168,19 @@ func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ChatCompletionsStreamResponse) ClearToolCalls() {
|
||||||
|
if !c.IsToolCall() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for choiceIdx := range c.Choices {
|
||||||
|
for callIdx := range c.Choices[choiceIdx].Delta.ToolCalls {
|
||||||
|
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].ID = ""
|
||||||
|
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Type = nil
|
||||||
|
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Function.Name = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
|
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
|
||||||
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
|
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
|
||||||
copy(choices, c.Choices)
|
copy(choices, c.Choices)
|
||||||
@@ -182,7 +230,7 @@ type Usage struct {
|
|||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
|
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
|
||||||
// OpenRouter Params
|
// OpenRouter Params
|
||||||
Cost float64 `json:"cost,omitempty"`
|
Cost any `json:"cost,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type InputTokenDetails struct {
|
type InputTokenDetails struct {
|
||||||
@@ -200,28 +248,69 @@ type OutputTokenDetails struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIResponsesResponse struct {
|
type OpenAIResponsesResponse struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
CreatedAt int `json:"created_at"`
|
CreatedAt int `json:"created_at"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Error *types.OpenAIError `json:"error,omitempty"`
|
Error any `json:"error,omitempty"`
|
||||||
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
||||||
Instructions string `json:"instructions"`
|
Instructions string `json:"instructions"`
|
||||||
MaxOutputTokens int `json:"max_output_tokens"`
|
MaxOutputTokens int `json:"max_output_tokens"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Output []ResponsesOutput `json:"output"`
|
Output []ResponsesOutput `json:"output"`
|
||||||
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
||||||
PreviousResponseID string `json:"previous_response_id"`
|
PreviousResponseID string `json:"previous_response_id"`
|
||||||
Reasoning *Reasoning `json:"reasoning"`
|
Reasoning *Reasoning `json:"reasoning"`
|
||||||
Store bool `json:"store"`
|
Store bool `json:"store"`
|
||||||
Temperature float64 `json:"temperature"`
|
Temperature float64 `json:"temperature"`
|
||||||
ToolChoice string `json:"tool_choice"`
|
ToolChoice string `json:"tool_choice"`
|
||||||
Tools []ResponsesToolsCall `json:"tools"`
|
Tools []map[string]any `json:"tools"`
|
||||||
TopP float64 `json:"top_p"`
|
TopP float64 `json:"top_p"`
|
||||||
Truncation string `json:"truncation"`
|
Truncation string `json:"truncation"`
|
||||||
Usage *Usage `json:"usage"`
|
Usage *Usage `json:"usage"`
|
||||||
User json.RawMessage `json:"user"`
|
User json.RawMessage `json:"user"`
|
||||||
Metadata json.RawMessage `json:"metadata"`
|
Metadata json.RawMessage `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||||
|
func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
|
||||||
|
return GetOpenAIError(o.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OpenAIResponsesResponse) HasImageGenerationCall() bool {
|
||||||
|
if len(o.Output) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, output := range o.Output {
|
||||||
|
if output.Type == ResponsesOutputTypeImageGenerationCall {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OpenAIResponsesResponse) GetQuality() string {
|
||||||
|
if len(o.Output) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, output := range o.Output {
|
||||||
|
if output.Type == ResponsesOutputTypeImageGenerationCall {
|
||||||
|
return output.Quality
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OpenAIResponsesResponse) GetSize() string {
|
||||||
|
if len(o.Output) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, output := range o.Output {
|
||||||
|
if output.Type == ResponsesOutputTypeImageGenerationCall {
|
||||||
|
return output.Size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
type IncompleteDetails struct {
|
type IncompleteDetails struct {
|
||||||
@@ -234,6 +323,8 @@ type ResponsesOutput struct {
|
|||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content []ResponsesOutputContent `json:"content"`
|
Content []ResponsesOutputContent `json:"content"`
|
||||||
|
Quality string `json:"quality"`
|
||||||
|
Size string `json:"size"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ResponsesOutputContent struct {
|
type ResponsesOutputContent struct {
|
||||||
@@ -263,3 +354,45 @@ type ResponsesStreamResponse struct {
|
|||||||
Delta string `json:"delta,omitempty"`
|
Delta string `json:"delta,omitempty"`
|
||||||
Item *ResponsesOutput `json:"item,omitempty"`
|
Item *ResponsesOutput `json:"item,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||||
|
func GetOpenAIError(errorField any) *types.OpenAIError {
|
||||||
|
if errorField == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch err := errorField.(type) {
|
||||||
|
case types.OpenAIError:
|
||||||
|
return &err
|
||||||
|
case *types.OpenAIError:
|
||||||
|
return err
|
||||||
|
case map[string]interface{}:
|
||||||
|
// 处理从JSON解析来的map结构
|
||||||
|
openaiErr := &types.OpenAIError{}
|
||||||
|
if errType, ok := err["type"].(string); ok {
|
||||||
|
openaiErr.Type = errType
|
||||||
|
}
|
||||||
|
if errMsg, ok := err["message"].(string); ok {
|
||||||
|
openaiErr.Message = errMsg
|
||||||
|
}
|
||||||
|
if errParam, ok := err["param"].(string); ok {
|
||||||
|
openaiErr.Param = errParam
|
||||||
|
}
|
||||||
|
if errCode, ok := err["code"]; ok {
|
||||||
|
openaiErr.Code = errCode
|
||||||
|
}
|
||||||
|
return openaiErr
|
||||||
|
case string:
|
||||||
|
// 处理简单字符串错误
|
||||||
|
return &types.OpenAIError{
|
||||||
|
Type: "error",
|
||||||
|
Message: err,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// 未知类型,尝试转换为字符串
|
||||||
|
return &types.OpenAIError{
|
||||||
|
Type: "unknown_error",
|
||||||
|
Message: fmt.Sprintf("%v", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package dto
|
|||||||
|
|
||||||
import "one-api/constant"
|
import "one-api/constant"
|
||||||
|
|
||||||
|
// 这里不好动就不动了,本来想独立出来的(
|
||||||
type OpenAIModels struct {
|
type OpenAIModels struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
@@ -9,3 +10,26 @@ type OpenAIModels struct {
|
|||||||
OwnedBy string `json:"owned_by"`
|
OwnedBy string `json:"owned_by"`
|
||||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AnthropicModel struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
DisplayName string `json:"display_name"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiModel struct {
|
||||||
|
Name interface{} `json:"name"`
|
||||||
|
BaseModelId interface{} `json:"baseModelId"`
|
||||||
|
Version interface{} `json:"version"`
|
||||||
|
DisplayName interface{} `json:"displayName"`
|
||||||
|
Description interface{} `json:"description"`
|
||||||
|
InputTokenLimit interface{} `json:"inputTokenLimit"`
|
||||||
|
OutputTokenLimit interface{} `json:"outputTokenLimit"`
|
||||||
|
SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"`
|
||||||
|
Thinking interface{} `json:"thinking"`
|
||||||
|
Temperature interface{} `json:"temperature"`
|
||||||
|
MaxTemperature interface{} `json:"maxTemperature"`
|
||||||
|
TopP interface{} `json:"topP"`
|
||||||
|
TopK interface{} `json:"topK"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,23 +1,23 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
type UpstreamDTO struct {
|
type UpstreamDTO struct {
|
||||||
ID int `json:"id,omitempty"`
|
ID int `json:"id,omitempty"`
|
||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
BaseURL string `json:"base_url" binding:"required"`
|
BaseURL string `json:"base_url" binding:"required"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpstreamRequest struct {
|
type UpstreamRequest struct {
|
||||||
ChannelIDs []int64 `json:"channel_ids"`
|
ChannelIDs []int64 `json:"channel_ids"`
|
||||||
Upstreams []UpstreamDTO `json:"upstreams"`
|
Upstreams []UpstreamDTO `json:"upstreams"`
|
||||||
Timeout int `json:"timeout"`
|
Timeout int `json:"timeout"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestResult 上游测试连通性结果
|
// TestResult 上游测试连通性结果
|
||||||
type TestResult struct {
|
type TestResult struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Error string `json:"error,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// DifferenceItem 差异项
|
// DifferenceItem 差异项
|
||||||
@@ -25,14 +25,14 @@ type TestResult struct {
|
|||||||
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
|
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
|
||||||
|
|
||||||
type DifferenceItem struct {
|
type DifferenceItem struct {
|
||||||
Current interface{} `json:"current"`
|
Current interface{} `json:"current"`
|
||||||
Upstreams map[string]interface{} `json:"upstreams"`
|
Upstreams map[string]interface{} `json:"upstreams"`
|
||||||
Confidence map[string]bool `json:"confidence"`
|
Confidence map[string]bool `json:"confidence"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SyncableChannel struct {
|
type SyncableChannel struct {
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
BaseURL string `json:"base_url"`
|
BaseURL string `json:"base_url"`
|
||||||
Status int `json:"status"`
|
Status int `json:"status"`
|
||||||
}
|
}
|
||||||
|
|||||||
25
dto/request_common.go
Normal file
25
dto/request_common.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Request interface {
|
||||||
|
GetTokenCountMeta() *types.TokenCountMeta
|
||||||
|
IsStream(c *gin.Context) bool
|
||||||
|
SetModelName(modelName string)
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaseRequest struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BaseRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
TokenType: types.TokenTypeTokenizer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (b *BaseRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
func (b *BaseRequest) SetModelName(modelName string) {}
|
||||||
@@ -1,5 +1,12 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
type RerankRequest struct {
|
type RerankRequest struct {
|
||||||
Documents []any `json:"documents"`
|
Documents []any `json:"documents"`
|
||||||
Query string `json:"query"`
|
Query string `json:"query"`
|
||||||
@@ -10,6 +17,32 @@ type RerankRequest struct {
|
|||||||
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *RerankRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RerankRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var texts = make([]string, 0)
|
||||||
|
|
||||||
|
for _, document := range r.Documents {
|
||||||
|
texts = append(texts, fmt.Sprintf("%v", document))
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Query != "" {
|
||||||
|
texts = append(texts, r.Query)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: strings.Join(texts, "\n"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RerankRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *RerankRequest) GetReturnDocuments() bool {
|
func (r *RerankRequest) GetReturnDocuments() bool {
|
||||||
if r.ReturnDocuments == nil {
|
if r.ReturnDocuments == nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -6,11 +6,14 @@ type UserSetting struct {
|
|||||||
WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
|
WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
|
||||||
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
|
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
|
||||||
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
|
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
|
||||||
|
BarkUrl string `json:"bark_url,omitempty"` // BarkUrl Bark推送URL
|
||||||
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
||||||
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
|
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
|
||||||
|
SidebarModules string `json:"sidebar_modules,omitempty"` // SidebarModules 左侧边栏模块配置
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
NotifyTypeEmail = "email" // Email 邮件
|
NotifyTypeEmail = "email" // Email 邮件
|
||||||
NotifyTypeWebhook = "webhook" // Webhook
|
NotifyTypeWebhook = "webhook" // Webhook
|
||||||
|
NotifyTypeBark = "bark" // Bark 推送
|
||||||
)
|
)
|
||||||
|
|||||||
23
go.mod
23
go.mod
@@ -7,9 +7,10 @@ require (
|
|||||||
github.com/Calcium-Ion/go-epay v0.0.4
|
github.com/Calcium-Ion/go-epay v0.0.4
|
||||||
github.com/andybalholm/brotli v1.1.1
|
github.com/andybalholm/brotli v1.1.1
|
||||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
|
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
|
||||||
github.com/aws/aws-sdk-go-v2 v1.26.1
|
github.com/aws/aws-sdk-go-v2 v1.37.2
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0
|
||||||
|
github.com/aws/smithy-go v1.22.5
|
||||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
||||||
github.com/gin-contrib/cors v1.7.2
|
github.com/gin-contrib/cors v1.7.2
|
||||||
github.com/gin-contrib/gzip v0.0.6
|
github.com/gin-contrib/gzip v0.0.6
|
||||||
@@ -22,15 +23,22 @@ require (
|
|||||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/gorilla/websocket v1.5.0
|
github.com/gorilla/websocket v1.5.0
|
||||||
|
github.com/jinzhu/copier v0.4.0
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
|
github.com/pquerna/otp v1.5.0
|
||||||
github.com/samber/lo v1.39.0
|
github.com/samber/lo v1.39.0
|
||||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||||
github.com/shopspring/decimal v1.4.0
|
github.com/shopspring/decimal v1.4.0
|
||||||
|
github.com/stripe/stripe-go/v81 v81.4.0
|
||||||
|
github.com/thanhpk/randstr v1.0.6
|
||||||
|
github.com/tidwall/gjson v1.18.0
|
||||||
|
github.com/tidwall/sjson v1.2.5
|
||||||
github.com/tiktoken-go/tokenizer v0.6.2
|
github.com/tiktoken-go/tokenizer v0.6.2
|
||||||
golang.org/x/crypto v0.35.0
|
golang.org/x/crypto v0.35.0
|
||||||
golang.org/x/image v0.23.0
|
golang.org/x/image v0.23.0
|
||||||
golang.org/x/net v0.35.0
|
golang.org/x/net v0.35.0
|
||||||
|
golang.org/x/sync v0.11.0
|
||||||
gorm.io/driver/mysql v1.4.3
|
gorm.io/driver/mysql v1.4.3
|
||||||
gorm.io/driver/postgres v1.5.2
|
gorm.io/driver/postgres v1.5.2
|
||||||
gorm.io/gorm v1.25.2
|
gorm.io/gorm v1.25.2
|
||||||
@@ -38,10 +46,10 @@ require (
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
|
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect
|
||||||
github.com/aws/smithy-go v1.20.2 // indirect
|
github.com/boombuler/barcode v1.1.0 // indirect
|
||||||
github.com/bytedance/sonic v1.11.6 // indirect
|
github.com/bytedance/sonic v1.11.6 // indirect
|
||||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
@@ -77,6 +85,8 @@ require (
|
|||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
@@ -84,7 +94,6 @@ require (
|
|||||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||||
golang.org/x/arch v0.12.0 // indirect
|
golang.org/x/arch v0.12.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||||
golang.org/x/sync v0.11.0 // indirect
|
|
||||||
golang.org/x/sys v0.30.0 // indirect
|
golang.org/x/sys v0.30.0 // indirect
|
||||||
golang.org/x/text v0.22.0 // indirect
|
golang.org/x/text v0.22.0 // indirect
|
||||||
google.golang.org/protobuf v1.34.2 // indirect
|
google.golang.org/protobuf v1.34.2 // indirect
|
||||||
|
|||||||
46
go.sum
46
go.sum
@@ -6,20 +6,23 @@ github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+Kc
|
|||||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
||||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
||||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
|
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
|
github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
|
github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg=
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg=
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
|
||||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 h1:sPiRHLVUIIQcoVZTNwqQcdtjkqkPopyYmIX0M5ElRf4=
|
||||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2/go.mod h1:ik86P3sgV+Bk7c1tBFCwI3VxMoSEwl4YkRB9xn1s340=
|
||||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 h1:ZdzDAg075H6stMZtbD2o+PyB933M/f20e9WmCBC17wA=
|
||||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2/go.mod h1:eE1IIzXG9sdZCB0pNNpMpsYTLl4YdOQD3njiVN1e/E4=
|
||||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 h1:JzidOz4Hcn2RbP5fvIS1iAP+DcRv5VJtgixbEYDsI5g=
|
||||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0/go.mod h1:9A4/PJYlWjvjEzzoOLGQjkLt4bYK9fRWi7uz1GSsAcA=
|
||||||
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
|
github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw=
|
||||||
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
|
github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
|
||||||
|
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||||
|
github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo=
|
||||||
|
github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
|
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
|
||||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
|
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
|
||||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||||
@@ -117,6 +120,8 @@ github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
|
|||||||
github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
|
github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
|
||||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
|
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
|
||||||
|
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
@@ -169,6 +174,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
|||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
|
||||||
|
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
@@ -195,6 +202,19 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
|||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
|
||||||
|
github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
|
||||||
|
github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
|
||||||
|
github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
|
||||||
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||||
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
|
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||||
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
|
github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
|
||||||
github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
|
github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
|
||||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||||
@@ -224,6 +244,7 @@ golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSO
|
|||||||
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
|
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
|
||||||
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
|
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
|
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
@@ -232,6 +253,7 @@ golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
|||||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
|||||||
1041
i18n/zh-cn.json
1041
i18n/zh-cn.json
File diff suppressed because it is too large
Load Diff
@@ -1,23 +1,26 @@
|
|||||||
package common
|
package logger
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/bytedance/gopkg/util/gopool"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"one-api/common"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
loggerINFO = "INFO"
|
loggerINFO = "INFO"
|
||||||
loggerWarn = "WARN"
|
loggerWarn = "WARN"
|
||||||
loggerError = "ERR"
|
loggerError = "ERR"
|
||||||
|
loggerDebug = "DEBUG"
|
||||||
)
|
)
|
||||||
|
|
||||||
const maxLogCount = 1000000
|
const maxLogCount = 1000000
|
||||||
@@ -27,7 +30,10 @@ var setupLogLock sync.Mutex
|
|||||||
var setupLogWorking bool
|
var setupLogWorking bool
|
||||||
|
|
||||||
func SetupLogger() {
|
func SetupLogger() {
|
||||||
if *LogDir != "" {
|
defer func() {
|
||||||
|
setupLogWorking = false
|
||||||
|
}()
|
||||||
|
if *common.LogDir != "" {
|
||||||
ok := setupLogLock.TryLock()
|
ok := setupLogLock.TryLock()
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Println("setup log is already working")
|
log.Println("setup log is already working")
|
||||||
@@ -35,9 +41,8 @@ func SetupLogger() {
|
|||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
setupLogLock.Unlock()
|
setupLogLock.Unlock()
|
||||||
setupLogWorking = false
|
|
||||||
}()
|
}()
|
||||||
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
|
logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
|
||||||
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("failed to open log file")
|
log.Fatal("failed to open log file")
|
||||||
@@ -47,16 +52,6 @@ func SetupLogger() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SysLog(s string) {
|
|
||||||
t := time.Now()
|
|
||||||
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func SysError(s string) {
|
|
||||||
t := time.Now()
|
|
||||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogInfo(ctx context.Context, msg string) {
|
func LogInfo(ctx context.Context, msg string) {
|
||||||
logHelper(ctx, loggerINFO, msg)
|
logHelper(ctx, loggerINFO, msg)
|
||||||
}
|
}
|
||||||
@@ -69,12 +64,21 @@ func LogError(ctx context.Context, msg string) {
|
|||||||
logHelper(ctx, loggerError, msg)
|
logHelper(ctx, loggerError, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func LogDebug(ctx context.Context, msg string) {
|
||||||
|
if common.DebugEnabled {
|
||||||
|
logHelper(ctx, loggerDebug, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func logHelper(ctx context.Context, level string, msg string) {
|
func logHelper(ctx context.Context, level string, msg string) {
|
||||||
writer := gin.DefaultErrorWriter
|
writer := gin.DefaultErrorWriter
|
||||||
if level == loggerINFO {
|
if level == loggerINFO {
|
||||||
writer = gin.DefaultWriter
|
writer = gin.DefaultWriter
|
||||||
}
|
}
|
||||||
id := ctx.Value(RequestIdKey)
|
id := ctx.Value(common.RequestIdKey)
|
||||||
|
if id == nil {
|
||||||
|
id = "SYSTEM"
|
||||||
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
||||||
logCount++ // we don't need accurate count, so no lock here
|
logCount++ // we don't need accurate count, so no lock here
|
||||||
@@ -87,23 +91,17 @@ func logHelper(ctx context.Context, level string, msg string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func FatalLog(v ...any) {
|
|
||||||
t := time.Now()
|
|
||||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogQuota(quota int) string {
|
func LogQuota(quota int) string {
|
||||||
if DisplayInCurrencyEnabled {
|
if common.DisplayInCurrencyEnabled {
|
||||||
return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
|
return fmt.Sprintf("$%.6f 额度", float64(quota)/common.QuotaPerUnit)
|
||||||
} else {
|
} else {
|
||||||
return fmt.Sprintf("%d 点额度", quota)
|
return fmt.Sprintf("%d 点额度", quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func FormatQuota(quota int) string {
|
func FormatQuota(quota int) string {
|
||||||
if DisplayInCurrencyEnabled {
|
if common.DisplayInCurrencyEnabled {
|
||||||
return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
|
return fmt.Sprintf("$%.6f", float64(quota)/common.QuotaPerUnit)
|
||||||
} else {
|
} else {
|
||||||
return fmt.Sprintf("%d", quota)
|
return fmt.Sprintf("%d", quota)
|
||||||
}
|
}
|
||||||
23
main.go
23
main.go
@@ -8,6 +8,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/controller"
|
"one-api/controller"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/router"
|
"one-api/router"
|
||||||
@@ -60,13 +61,13 @@ func main() {
|
|||||||
}
|
}
|
||||||
if common.MemoryCacheEnabled {
|
if common.MemoryCacheEnabled {
|
||||||
common.SysLog("memory cache enabled")
|
common.SysLog("memory cache enabled")
|
||||||
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
common.SysLog(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
||||||
|
|
||||||
// Add panic recovery and retry for InitChannelCache
|
// Add panic recovery and retry for InitChannelCache
|
||||||
func() {
|
func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
|
common.SysLog(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
|
||||||
// Retry once
|
// Retry once
|
||||||
_, _, fixErr := model.FixAbility()
|
_, _, fixErr := model.FixAbility()
|
||||||
if fixErr != nil {
|
if fixErr != nil {
|
||||||
@@ -93,13 +94,9 @@ func main() {
|
|||||||
}
|
}
|
||||||
go controller.AutomaticallyUpdateChannels(frequency)
|
go controller.AutomaticallyUpdateChannels(frequency)
|
||||||
}
|
}
|
||||||
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
|
|
||||||
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
|
go controller.AutomaticallyTestChannels()
|
||||||
if err != nil {
|
|
||||||
common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
|
|
||||||
}
|
|
||||||
go controller.AutomaticallyTestChannels(frequency)
|
|
||||||
}
|
|
||||||
if common.IsMasterNode && constant.UpdateTask {
|
if common.IsMasterNode && constant.UpdateTask {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
controller.UpdateMidjourneyTaskBulk()
|
controller.UpdateMidjourneyTaskBulk()
|
||||||
@@ -125,7 +122,7 @@ func main() {
|
|||||||
// Initialize HTTP server
|
// Initialize HTTP server
|
||||||
server := gin.New()
|
server := gin.New()
|
||||||
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
|
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
|
||||||
common.SysError(fmt.Sprintf("panic detected: %v", err))
|
common.SysLog(fmt.Sprintf("panic detected: %v", err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
||||||
@@ -168,11 +165,11 @@ func InitResources() error {
|
|||||||
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
|
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
|
||||||
}
|
}
|
||||||
|
|
||||||
common.SetupLogger()
|
|
||||||
|
|
||||||
// 加载环境变量
|
// 加载环境变量
|
||||||
common.InitEnv()
|
common.InitEnv()
|
||||||
|
|
||||||
|
logger.SetupLogger()
|
||||||
|
|
||||||
// Initialize model settings
|
// Initialize model settings
|
||||||
ratio_setting.InitRatioSettings()
|
ratio_setting.InitRatioSettings()
|
||||||
|
|
||||||
@@ -207,4 +204,4 @@ func InitResources() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -4,7 +4,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
"one-api/setting"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -122,7 +125,20 @@ func authHelper(c *gin.Context, minRole int) {
|
|||||||
c.Set("role", role)
|
c.Set("role", role)
|
||||||
c.Set("id", id)
|
c.Set("id", id)
|
||||||
c.Set("group", session.Get("group"))
|
c.Set("group", session.Get("group"))
|
||||||
|
c.Set("user_group", session.Get("group"))
|
||||||
c.Set("use_access_token", useAccessToken)
|
c.Set("use_access_token", useAccessToken)
|
||||||
|
|
||||||
|
//userCache, err := model.GetUserCache(id.(int))
|
||||||
|
//if err != nil {
|
||||||
|
// c.JSON(http.StatusOK, gin.H{
|
||||||
|
// "success": false,
|
||||||
|
// "message": err.Error(),
|
||||||
|
// })
|
||||||
|
// c.Abort()
|
||||||
|
// return
|
||||||
|
//}
|
||||||
|
//userCache.WriteContext(c)
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,14 +194,15 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
// 检查path包含/v1/messages
|
// 检查path包含/v1/messages
|
||||||
if strings.Contains(c.Request.URL.Path, "/v1/messages") {
|
if strings.Contains(c.Request.URL.Path, "/v1/messages") {
|
||||||
// 从x-api-key中获取key
|
anthropicKey := c.Request.Header.Get("x-api-key")
|
||||||
key := c.Request.Header.Get("x-api-key")
|
if anthropicKey != "" {
|
||||||
if key != "" {
|
c.Request.Header.Set("Authorization", "Bearer "+anthropicKey)
|
||||||
c.Request.Header.Set("Authorization", "Bearer "+key)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// gemini api 从query中获取key
|
// gemini api 从query中获取key
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/v1beta/openai/models") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||||
skKey := c.Query("key")
|
skKey := c.Query("key")
|
||||||
if skKey != "" {
|
if skKey != "" {
|
||||||
c.Request.Header.Set("Authorization", "Bearer "+skKey)
|
c.Request.Header.Set("Authorization", "Bearer "+skKey)
|
||||||
@@ -221,6 +238,16 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allowIpsMap := token.GetIpLimitsMap()
|
||||||
|
if len(allowIpsMap) != 0 {
|
||||||
|
clientIp := c.ClientIP()
|
||||||
|
if _, ok := allowIpsMap[clientIp]; !ok {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
userCache, err := model.GetUserCache(token.UserId)
|
userCache, err := model.GetUserCache(token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||||
@@ -234,6 +261,25 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
|
|
||||||
userCache.WriteContext(c)
|
userCache.WriteContext(c)
|
||||||
|
|
||||||
|
userGroup := userCache.Group
|
||||||
|
tokenGroup := token.Group
|
||||||
|
if tokenGroup != "" {
|
||||||
|
// check common.UserUsableGroups[userGroup]
|
||||||
|
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// check group in common.GroupRatio
|
||||||
|
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
||||||
|
if tokenGroup != "auto" {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
userGroup = tokenGroup
|
||||||
|
}
|
||||||
|
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
||||||
|
|
||||||
err = SetupContextForToken(c, token, parts...)
|
err = SetupContextForToken(c, token, parts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@@ -260,7 +306,6 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
|
|||||||
} else {
|
} else {
|
||||||
c.Set("token_model_limit_enabled", false)
|
c.Set("token_model_limit_enabled", false)
|
||||||
}
|
}
|
||||||
c.Set("allow_ips", token.GetIpLimitsMap())
|
|
||||||
c.Set("token_group", token.Group)
|
c.Set("token_group", token.Group)
|
||||||
if len(parts) > 1 {
|
if len(parts) > 1 {
|
||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
|
|||||||
12
middleware/disable-cache.go
Normal file
12
middleware/disable-cache.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import "github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
func DisableCache() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
c.Header("Cache-Control", "no-store, no-cache, must-revalidate, private, max-age=0")
|
||||||
|
c.Header("Pragma", "no-cache")
|
||||||
|
c.Header("Expires", "0")
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"one-api/setting/ratio_setting"
|
"one-api/setting/ratio_setting"
|
||||||
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -26,14 +27,6 @@ type ModelRequest struct {
|
|||||||
|
|
||||||
func Distribute() func(c *gin.Context) {
|
func Distribute() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
|
|
||||||
if len(allowIpsMap) != 0 {
|
|
||||||
clientIp := c.ClientIP()
|
|
||||||
if _, ok := allowIpsMap[clientIp]; !ok {
|
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var channel *model.Channel
|
var channel *model.Channel
|
||||||
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
|
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
|
||||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||||
@@ -41,24 +34,6 @@ func Distribute() func(c *gin.Context) {
|
|||||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
|
||||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
|
||||||
if tokenGroup != "" {
|
|
||||||
// check common.UserUsableGroups[userGroup]
|
|
||||||
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// check group in common.GroupRatio
|
|
||||||
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
|
||||||
if tokenGroup != "auto" {
|
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
userGroup = tokenGroup
|
|
||||||
}
|
|
||||||
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
|
||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -80,44 +55,63 @@ func Distribute() func(c *gin.Context) {
|
|||||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||||
if modelLimitEnable {
|
if modelLimitEnable {
|
||||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||||
var tokenModelLimit map[string]bool
|
if !ok {
|
||||||
if ok {
|
|
||||||
tokenModelLimit = s.(map[string]bool)
|
|
||||||
} else {
|
|
||||||
tokenModelLimit = map[string]bool{}
|
|
||||||
}
|
|
||||||
if tokenModelLimit != nil {
|
|
||||||
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
|
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// token model limit is empty, all models are not allowed
|
// token model limit is empty, all models are not allowed
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
var tokenModelLimit map[string]bool
|
||||||
|
tokenModelLimit, ok = s.(map[string]bool)
|
||||||
|
if !ok {
|
||||||
|
tokenModelLimit = map[string]bool{}
|
||||||
|
}
|
||||||
|
matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) // match gpts & thinking-*
|
||||||
|
if _, ok := tokenModelLimit[matchName]; !ok {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if shouldSelectChannel {
|
if shouldSelectChannel {
|
||||||
|
if modelRequest.Model == "" {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "未指定模型名称,模型名称不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
var selectGroup string
|
var selectGroup string
|
||||||
|
userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
||||||
|
// check path is /pg/chat/completions
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
|
||||||
|
playgroundRequest := &dto.PlayGroundRequest{}
|
||||||
|
err = common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||||
|
if err != nil {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if playgroundRequest.Group != "" {
|
||||||
|
if !setting.GroupInUserUsableGroups(playgroundRequest.Group) && playgroundRequest.Group != userGroup {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, "无权访问该分组")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userGroup = playgroundRequest.Group
|
||||||
|
}
|
||||||
|
}
|
||||||
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
showGroup := userGroup
|
showGroup := userGroup
|
||||||
if userGroup == "auto" {
|
if userGroup == "auto" {
|
||||||
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
||||||
}
|
}
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
|
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(数据库一致性已被破坏,distributor): %s", showGroup, modelRequest.Model, err.Error())
|
||||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||||
if channel != nil {
|
//if channel != nil {
|
||||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
message = "数据库一致性已被破坏,请联系管理员"
|
// message = "数据库一致性已被破坏,请联系管理员"
|
||||||
}
|
//}
|
||||||
// 如果错误,而且渠道为空,说明是没有可用渠道
|
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound))
|
||||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if channel == nil {
|
if channel == nil {
|
||||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -172,24 +166,17 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
relayMode := relayconstant.RelayModeUnknown
|
||||||
var platform string
|
if c.Request.Method == http.MethodPost {
|
||||||
var relayMode int
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
if strings.HasPrefix(modelRequest.Model, "jimeng") {
|
relayMode = relayconstant.RelayModeVideoSubmit
|
||||||
platform = string(constant.TaskPlatformJimeng)
|
} else if c.Request.Method == http.MethodGet {
|
||||||
relayMode = relayconstant.Path2RelayJimeng(c.Request.Method, c.Request.URL.Path)
|
relayMode = relayconstant.RelayModeVideoFetchByID
|
||||||
if relayMode == relayconstant.RelayModeJimengFetchByID {
|
shouldSelectChannel = false
|
||||||
shouldSelectChannel = false
|
}
|
||||||
}
|
if _, ok := c.Get("relay_mode"); !ok {
|
||||||
} else {
|
c.Set("relay_mode", relayMode)
|
||||||
platform = string(constant.TaskPlatformKling)
|
|
||||||
relayMode = relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
|
|
||||||
if relayMode == relayconstant.RelayModeKlingFetchByID {
|
|
||||||
shouldSelectChannel = false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
c.Set("platform", platform)
|
|
||||||
c.Set("relay_mode", relayMode)
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||||
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
||||||
relayMode := relayconstant.RelayModeGemini
|
relayMode := relayconstant.RelayModeGemini
|
||||||
@@ -198,7 +185,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
modelRequest.Model = modelName
|
modelRequest.Model = modelName
|
||||||
}
|
}
|
||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -221,7 +208,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||||
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
|
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
||||||
modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
|
//modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
|
||||||
|
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||||
|
modelRequest.Model = c.PostForm("model")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||||
relayMode := relayconstant.RelayModeAudioSpeech
|
relayMode := relayconstant.RelayModeAudioSpeech
|
||||||
@@ -249,30 +239,43 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
return &modelRequest, shouldSelectChannel, nil
|
return &modelRequest, shouldSelectChannel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
|
||||||
c.Set("original_model", modelName) // for retry
|
c.Set("original_model", modelName) // for retry
|
||||||
if channel == nil {
|
if channel == nil {
|
||||||
return
|
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
|
common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings())
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
|
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, channel.GetHeaderOverride())
|
||||||
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
|
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
|
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
|
||||||
}
|
}
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
|
common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
|
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
|
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
|
||||||
|
|
||||||
|
key, index, newAPIError := channel.GetNextEnabledKey()
|
||||||
|
if newAPIError != nil {
|
||||||
|
return newAPIError
|
||||||
|
}
|
||||||
if channel.ChannelInfo.IsMultiKey {
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
|
||||||
|
} else {
|
||||||
|
// 必须设置为 false,否则在重试到单个 key 的时候会导致日志显示错误
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, false)
|
||||||
}
|
}
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||||
|
|
||||||
|
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, false)
|
||||||
|
|
||||||
// TODO: api_version统一
|
// TODO: api_version统一
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case constant.ChannelTypeAzure:
|
case constant.ChannelTypeAzure:
|
||||||
@@ -292,6 +295,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
case constant.ChannelTypeCoze:
|
case constant.ChannelTypeCoze:
|
||||||
c.Set("bot_id", channel.Other)
|
c.Set("bot_id", channel.Other)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
|
// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
|
||||||
|
|||||||
80
middleware/email-verification-rate-limit.go
Normal file
80
middleware/email-verification-rate-limit.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
EmailVerificationRateLimitMark = "EV"
|
||||||
|
EmailVerificationMaxRequests = 2 // 30秒内最多2次
|
||||||
|
EmailVerificationDuration = 30 // 30秒时间窗口
|
||||||
|
)
|
||||||
|
|
||||||
|
func redisEmailVerificationRateLimiter(c *gin.Context) {
|
||||||
|
ctx := context.Background()
|
||||||
|
rdb := common.RDB
|
||||||
|
key := "emailVerification:" + EmailVerificationRateLimitMark + ":" + c.ClientIP()
|
||||||
|
|
||||||
|
count, err := rdb.Incr(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
// fallback
|
||||||
|
memoryEmailVerificationRateLimiter(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第一次设置键时设置过期时间
|
||||||
|
if count == 1 {
|
||||||
|
_ = rdb.Expire(ctx, key, time.Duration(EmailVerificationDuration)*time.Second).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否超出限制
|
||||||
|
if count <= int64(EmailVerificationMaxRequests) {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取剩余等待时间
|
||||||
|
ttl, err := rdb.TTL(ctx, key).Result()
|
||||||
|
waitSeconds := int64(EmailVerificationDuration)
|
||||||
|
if err == nil && ttl > 0 {
|
||||||
|
waitSeconds = int64(ttl.Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": fmt.Sprintf("发送过于频繁,请等待 %d 秒后再试", waitSeconds),
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
|
||||||
|
func memoryEmailVerificationRateLimiter(c *gin.Context) {
|
||||||
|
key := EmailVerificationRateLimitMark + ":" + c.ClientIP()
|
||||||
|
|
||||||
|
if !inMemoryRateLimiter.Request(key, EmailVerificationMaxRequests, EmailVerificationDuration) {
|
||||||
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "发送过于频繁,请稍后再试",
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
func EmailVerificationRateLimit() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if common.RedisEnabled {
|
||||||
|
redisEmailVerificationRateLimiter(c)
|
||||||
|
} else {
|
||||||
|
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
|
||||||
|
memoryEmailVerificationRateLimiter(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
66
middleware/jimeng_adapter.go
Normal file
66
middleware/jimeng_adapter.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
|
)
|
||||||
|
|
||||||
|
func JimengRequestConvert() func(c *gin.Context) {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
action := c.Query("Action")
|
||||||
|
if action == "" {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "Action query parameter is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Jimeng official API request
|
||||||
|
var originalReq map[string]interface{}
|
||||||
|
if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model, _ := originalReq["req_key"].(string)
|
||||||
|
prompt, _ := originalReq["prompt"].(string)
|
||||||
|
|
||||||
|
unifiedReq := map[string]interface{}{
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"metadata": originalReq,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(unifiedReq)
|
||||||
|
if err != nil {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, "Failed to marshal request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update request body
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
|
||||||
|
c.Set(common.KeyRequestBody, jsonData)
|
||||||
|
|
||||||
|
if image, ok := originalReq["image"]; !ok || image == "" {
|
||||||
|
c.Set("action", constant.TaskActionTextGenerate)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.URL.Path = "/v1/video/generations"
|
||||||
|
|
||||||
|
if action == "CVSync2AsyncGetResult" {
|
||||||
|
taskId, ok := originalReq["task_id"].(string)
|
||||||
|
if !ok || taskId == "" {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "task_id is required for CVSync2AsyncGetResult")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Request.URL.Path = "/v1/video/generations/" + taskId
|
||||||
|
c.Request.Method = http.MethodGet
|
||||||
|
c.Set("task_id", taskId)
|
||||||
|
c.Set("relay_mode", relayconstant.RelayModeVideoFetchByID)
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -18,7 +18,11 @@ func KlingRequestConvert() func(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
model, _ := originalReq["model"].(string)
|
// Support both model_name and model fields
|
||||||
|
model, _ := originalReq["model_name"].(string)
|
||||||
|
if model == "" {
|
||||||
|
model, _ = originalReq["model"].(string)
|
||||||
|
}
|
||||||
prompt, _ := originalReq["prompt"].(string)
|
prompt, _ := originalReq["prompt"].(string)
|
||||||
|
|
||||||
unifiedReq := map[string]interface{}{
|
unifiedReq := map[string]interface{}{
|
||||||
@@ -36,7 +40,7 @@ func KlingRequestConvert() func(c *gin.Context) {
|
|||||||
// Rewrite request body and path
|
// Rewrite request body and path
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
|
||||||
c.Request.URL.Path = "/v1/video/generations"
|
c.Request.URL.Path = "/v1/video/generations"
|
||||||
if image := originalReq["image"]; image == "" {
|
if image, ok := originalReq["image"]; !ok || image == "" {
|
||||||
c.Set("action", constant.TaskActionTextGenerate)
|
c.Set("action", constant.TaskActionTextGenerate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
common.SysError(fmt.Sprintf("panic detected: %v", err))
|
common.SysLog(fmt.Sprintf("panic detected: %v", err))
|
||||||
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
common.SysLog(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
||||||
|
|||||||
@@ -18,12 +18,12 @@ func StatsMiddleware() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// 增加活跃连接数
|
// 增加活跃连接数
|
||||||
atomic.AddInt64(&globalStats.activeConnections, 1)
|
atomic.AddInt64(&globalStats.activeConnections, 1)
|
||||||
|
|
||||||
// 确保在请求结束时减少连接数
|
// 确保在请求结束时减少连接数
|
||||||
defer func() {
|
defer func() {
|
||||||
atomic.AddInt64(&globalStats.activeConnections, -1)
|
atomic.AddInt64(&globalStats.activeConnections, -1)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -38,4 +38,4 @@ func GetStats() StatsInfo {
|
|||||||
return StatsInfo{
|
return StatsInfo{
|
||||||
ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
|
ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user