Compare commits

...

187 Commits

Author SHA1 Message Date
CaIon
f796c3b216 fix: update Init method to correctly set RequestMode based on upstream model name prefixes 2025-05-23 01:34:53 +08:00
CaIon
c53a48cde5 feat: add panic recovery and retry mechanism for InitChannelCache; improve batch deletion of abilities in FixAbility 2025-05-23 01:26:52 +08:00
CaIon
9a59da16a5 feat: implement search functionality in ChannelsTable for improved channel filtering 2025-05-22 16:54:55 +08:00
CaIon
e18001299b feat: enhance Gemini response handling by adding reasoning content and updating JSON decoding method 2025-05-22 16:11:50 +08:00
CaIon
66bdfe180c feat: add Thought field to GeminiPart and update response handling in streamResponseGeminiChat2OpenAI 2025-05-22 15:52:23 +08:00
CaIon
1f9fc09989 feat: add OutputFormat field to ImageRequest for enhanced image processing options 2025-05-20 19:40:29 +08:00
CaIon
498d73f67c refactor: update JSON field names in GeminiChatRequest for consistency 2025-05-19 20:26:30 +08:00
IcedTangerine
0ca17d3e6d Merge pull request #1071 from feitianbubu/fixMjImageProxy
fix: proxy settings not applied when request MJ image url
2025-05-18 14:56:47 +08:00
skynono
9927e5d191 fix: proxy settings not applied when request MJ image url 2025-05-16 18:07:56 +08:00
Calcium-Ion
7171a69512 Merge pull request #1067 from QuantumNous/coze
Coze 渠道
2025-05-16 16:11:02 +08:00
creamlike1024
e379ee8f66 coze stream 2025-05-16 10:27:07 +08:00
creamlike1024
59aabb4311 add frontend display, more model 2025-05-15 20:00:59 +08:00
CaIon
4825404d37 feat: enhance image decoding logic to handle base64 file types and improve error handling 2025-05-15 14:51:33 +08:00
CaIon
ea04e6bcc5 fix: update model selection logic for image edits in distributor middleware 2025-05-14 17:01:50 +08:00
creamlike1024
108b67be6c use channel bot id 2025-05-13 22:23:38 +08:00
creamlike1024
29c95c598e cozeChatHelper 2025-05-13 22:01:12 +08:00
creamlike1024
b2499b0a7e DoRequest 2025-05-13 21:13:34 +08:00
IcedTangerine
12737fb7e5 Merge pull request #1063 from kingfs/fix/ali-completions-api
fix: ALI completions api path error
2025-05-13 17:51:35 +08:00
王永振
f17f38e569 fix: ALI completions api path error 2025-05-13 13:39:44 +08:00
creamlike1024
b2cad22952 add coze request 2025-05-13 12:52:22 +08:00
creamlike1024
e763124b69 Merge branch 'a37836323-add-dalle-fields' 2025-05-11 17:03:56 +08:00
creamlike1024
153012789d Merge branch 'add-dalle-fields' of github.com:a37836323/new-api into a37836323-add-dalle-fields 2025-05-11 17:03:27 +08:00
creamlike1024
d985563516 feat: add support for socks5h 2025-05-11 17:00:33 +08:00
CaIon
58dc7ad770 feat: add moderation and background fields to ImageRequest struct in dalle.go #1052 2025-05-10 15:52:41 +08:00
a37836323
28cdfc0a14 添加DALL-E图像生成请求中的Background和Moderation字段 2025-05-10 04:33:49 +08:00
CaIon
7b176015b8 feat: enhance OpenAI handler to support forced response formatting and add debug logging for request URLs 2025-05-09 18:57:06 +08:00
Calcium-Ion
cc2d9f539d Merge pull request #1046 from QuantumNous/workerHttpRequest
feat: add option to allow worker HTTP image requests
2025-05-09 18:31:25 +08:00
IcedTangerine
7f86bdf548 Merge pull request #1050 from feitianbubu/fixRatio
fix: correct formatting string in PriceData.ToSetting to handle Image…
2025-05-09 18:15:39 +08:00
creamlike1024
0d929800cf fix: GetRequestURL remove unnecessary case 2025-05-09 18:13:19 +08:00
creamlike1024
9ebfcaf6aa feat: change azure default api version to 2025-04-01-preview 2025-05-09 18:11:37 +08:00
skynono
40efa73a42 fix: correct formatting string in PriceData.ToSetting to handle ImageRatio as float instead of integer 2025-05-09 17:12:35 +08:00
creamlike1024
4a59b3ccd6 Merge branch '9Ninety-fix/sse_ping' 2025-05-09 13:57:26 +08:00
creamlike1024
ec61534256 feat: send SSE ping before get response 2025-05-09 13:57:00 +08:00
creamlike1024
2a218c1c89 Merge branch 'fix/sse_ping' of github.com:9Ninety/new-api into 9Ninety-fix/sse_ping 2025-05-09 12:28:05 +08:00
IcedTangerine
993cd6b624 Merge pull request #1045 from tbphp/feat_openrouter_balance
feat: update OpenRouter balance
2025-05-09 12:19:20 +08:00
creamlike1024
3d4bd76083 feat: add option to allow worker HTTP image requests 2025-05-09 02:00:42 +08:00
tbphp
7192437863 fix: 修改命名规范 2025-05-09 00:20:26 +08:00
tbphp
4bbcb00d13 feat: update OpenRouter balance 2025-05-09 00:15:44 +08:00
creamlike1024
9de24668d8 Merge branch 'tbphp-tbphp_model_request_rate_limit_for_group' 2025-05-08 23:20:08 +08:00
CaIon
7aa54a2cd7 feat: add AzureNoRemoveDotTime constant and update channel handling #1044
- Introduced a new constant `AzureNoRemoveDotTime` in `constant/azure.go` to manage model name formatting for channels created after May 10, 2025.
- Updated `distributor.go` to set `channel_create_time` in the context.
- Modified `adaptor.go` to conditionally remove dots from model names based on the channel creation time.
- Enhanced `relay_info.go` to include `ChannelCreateTime` in the `RelayInfo` struct.
- Updated English localization files to reflect changes in model name handling for new channels.
2025-05-08 23:19:40 +08:00
CaIon
a836e97315 fix: update OpenAI request handling to include 'o1-preview' model support #1029 2025-05-08 23:19:38 +08:00
creamlike1024
3373f5e0a0 fix: tool quota calculate 2025-05-08 23:19:37 +08:00
liusanp
d6e601b424 fix: xAi response 2025-05-08 23:19:35 +08:00
liusanp
8c3a559690 fix: xAi requestUrl 2025-05-08 23:19:34 +08:00
liusanp
c008d391df fix: quality, size or style are not supported by xAI API 2025-05-08 23:19:32 +08:00
creamlike1024
7c29844e4a Merge branch 'tbphp_model_request_rate_limit_for_group' of github.com:tbphp/new-api into tbphp-tbphp_model_request_rate_limit_for_group 2025-05-08 23:16:06 +08:00
CaIon
90d85a6f0a feat: add AzureNoRemoveDotTime constant and update channel handling #1044
- Introduced a new constant `AzureNoRemoveDotTime` in `constant/azure.go` to manage model name formatting for channels created after May 10, 2025.
- Updated `distributor.go` to set `channel_create_time` in the context.
- Modified `adaptor.go` to conditionally remove dots from model names based on the channel creation time.
- Enhanced `relay_info.go` to include `ChannelCreateTime` in the `RelayInfo` struct.
- Updated English localization files to reflect changes in model name handling for new channels.
2025-05-08 22:39:55 +08:00
CaIon
d40429ad93 fix: update OpenAI request handling to include 'o1-preview' model support #1029 2025-05-08 21:34:31 +08:00
Calcium-Ion
30806ef270 Merge pull request #1040 from QuantumNous/responses-quota
fix: tool quota calculate
2025-05-08 01:21:34 +08:00
9
02acc52fdb fix: ensure SSE ping packets are sent before upstream response
These changes ensures SSE ping packets are sent before receiving a response from the upstream. The previous implementation did not send ping packets until after the upstream response, rendering the feature ineffective.
2025-05-07 23:29:06 +08:00
IcedTangerine
3458476115 Merge pull request #1039 from liusanp/main
Fix grok-2-image request error
2025-05-07 22:06:51 +08:00
IcedTangerine
61c685ad79 Merge pull request #1032 from feitianbubu/upstream
fix: correct error messages for dall-e models size parameters
2025-05-07 20:56:36 +08:00
IcedTangerine
0121795a84 Merge pull request #1037 from LarchLiu/main
fix: gemini response json schema
2025-05-07 20:53:45 +08:00
creamlike1024
ae254f5368 fix: tool quota calculate 2025-05-07 19:33:32 +08:00
liusanp
562448b441 fix: xAi response 2025-05-07 18:59:27 +08:00
liusanp
04f7d89399 fix: xAi requestUrl 2025-05-07 18:32:59 +08:00
Alex Liu
0d456df588 fix: gemini response json schema 2025-05-07 18:08:56 +08:00
CaIon
dc3b453b05 fix: update ResponseChunkData to format data correctly without newline 2025-05-07 17:02:47 +08:00
CaIon
b19e1b8207 feat: add support for BaiduV2 channel in relay info 2025-05-07 16:30:32 +08:00
liusanp
97b5ca8099 fix: quality, size or style are not supported by xAI API 2025-05-07 16:17:22 +08:00
CaIon
4ecf5dde14 Merge remote-tracking branch 'origin/main' 2025-05-07 16:16:19 +08:00
joey
65ccfd0848 feat: support model mapping chain
#1033
2025-05-07 16:00:35 +08:00
skynono
2621b77f9a fix: correct error messages for dall-e models size parameters
(cherry picked from commit 149d06850c10cc6cdb3291164e3e46f99ca59abc)
2025-05-07 11:21:19 +08:00
Calcium-Ion
65a15dbc17 Merge pull request #1025 from QuantumNous/responses_buildin_tools
feat: implement OpenAI responses built-in tool tracking
2025-05-07 02:25:47 +08:00
creamlike1024
c0095d4521 feat: 添加 built in tools 计费前端显示 2025-05-07 01:08:20 +08:00
creamlike1024
5043075135 chore: move file search tool price to operation_setting 2025-05-06 23:57:22 +08:00
creamlike1024
10ef61eedb chore: move web search tool price to operation_setting 2025-05-06 23:25:16 +08:00
IcedTangerine
dc9e3b4139 Merge pull request #1026 from tbphp/tbphp_fix_redis_limit
fix: Redis limit ignoring max eq 0
2025-05-06 22:36:13 +08:00
creamlike1024
27e3aa828c Merge branch 'feitianbubu-upstream' 2025-05-06 22:31:39 +08:00
creamlike1024
d859e3fa64 fix: 修复未输入新密码时提示修改成功 2025-05-06 22:28:32 +08:00
creamlike1024
459c277c94 feat: 添加 built in tools 计费
- 增加非流的工具调用次数统计
- 添加 web search 和 file search 计费
2025-05-06 21:58:01 +08:00
CaIon
5639f1c2d8 feat: add support for DeepSeek channel in streamSupportedChannels 2025-05-06 18:41:01 +08:00
skynono
0cf4c59d22 feat: add original password verification when changing password 2025-05-06 14:28:27 +08:00
tbphp
3d243c3ee2 fix: 样式修复 2025-05-05 23:56:15 +08:00
tbphp
87188cd7d4 fix: 缩进修复还原 2025-05-05 23:53:05 +08:00
tbphp
bbab729619 fix: text 2025-05-05 23:48:15 +08:00
Apple\Apple
1c67dd3c31 📕docs: Update the content in README.en.md and the structure of the docs directory 2025-05-05 23:44:30 +08:00
tbphp
0be3678c9c fix: 请求完成数必须大于等于1 2025-05-05 23:41:43 +08:00
tbphp
1cb4d750e4 feat: 分组速率前端优化 2025-05-05 22:06:16 +08:00
tbphp
88ed83f419 feat: Modellimitgroup check 2025-05-05 20:00:06 +08:00
tbphp
1513ed7847 refactor: 调整代码,符合项目现有规范 2025-05-05 19:32:22 +08:00
tbphp
1e1d24d1b0 fix: rm debug file 2025-05-05 17:57:02 +08:00
tbphp
b7fd1e4a20 fix: Redis limit ignoring max eq 0 2025-05-05 12:55:48 +08:00
tbphp
7e7d6112ca feat: 优化代码,去除多余注释和修改 2025-05-05 11:34:57 +08:00
tbphp
6c3fb7777e feat: 增加分组速率功能 2025-05-05 07:31:54 +08:00
CaIon
18b3300ff1 feat: implement OpenAI responses handling and streaming support with built-in tool tracking 2025-05-05 00:40:16 +08:00
Calcium-Ion
bae57c05c1 Merge pull request #1024 from tbphp/fix-edituser-text
fix: EditUser text error
2025-05-04 18:30:32 +08:00
tbphp
3def2bbd30 fix: EditUser text error 2025-05-04 18:26:18 +08:00
CaIon
419a056fbf refactor: remove unnecessary call to helper.Done and adjust data rendering in ClaudeChunkData 2025-05-04 17:35:45 +08:00
Calcium-Ion
48af027903 Merge pull request #1020 from QuantumNous/v1responses
feat: support /v1/responses API
2025-05-04 17:13:39 +08:00
Calcium-Ion
9bf90c3baf Merge pull request #1012 from tbphp/vertex_thinking_support
feat: support thinking suffix for vertex gemini channel
2025-05-04 17:11:27 +08:00
CaIon
fe3232bf23 feat: enhance OaiResponsesStreamHandler to handle output text and improve response streaming 2025-05-04 17:09:37 +08:00
creamlike1024
1236fa8fe4 add OaiResponsesStreamHandler 2025-05-03 22:36:27 +08:00
CaIon
e097d5a538 feat: add video URL support in MediaContent and update token counting logic 2025-05-03 21:12:07 +08:00
creamlike1024
425feb88d8 feat: support /v1/responses API 2025-05-02 13:59:46 +08:00
CaIon
fd6838e690 feat: enable error logging configuration in docker-compose and application 2025-04-29 16:26:55 +08:00
CaIon
b64480b750 fix: gemini thinking tokens count #1014 2025-04-29 16:21:54 +08:00
CaIon
da6423de33 refactor: Reducing the lock duration to the minimum necessary time in CacheGetRandomSatisfiedChannel function 2025-04-29 15:57:21 +08:00
tbphp
efc9d200b1 feat: support thinking suffix for vertex gemini channel 2025-04-29 13:30:03 +08:00
CaIon
fe37718259 fix: update audio ratio logic for model names in GetAudioRatio function 2025-04-28 20:55:40 +08:00
IcedTangerine
c412fd9cde Merge pull request #1008 from JoeyLearnsToCode/feat-search-channel-by-url
feat: support searching channels by base url
2025-04-28 13:15:49 +08:00
creamlike1024
54f5b1a951 Merge branch 'wzxjohn-feature/wellknown' 2025-04-28 12:55:06 +08:00
JoeyLearnsToCode
a9b9d23586 feat: support searching channels by base url 2025-04-28 11:38:53 +08:00
wzxjohn
168226ba10 fix: remove custom header in oidc well known request 2025-04-28 11:25:04 +08:00
wzxjohn
1a8fd61a98 feat: support empty well known url 2025-04-28 11:25:04 +08:00
wzxjohn
2bd2d73d33 feat: improve log delete api 2025-04-28 11:25:04 +08:00
creamlike1024
62da481dc6 Merge branch 'error-logs' of github.com:zenghongtu/new-api into zenghongtu-error-logs 2025-04-28 11:06:32 +08:00
CaIon
4217358de7 feat: add image preview functionality and update model name instructions in EditChannel 2025-04-27 17:20:49 +08:00
CaIon
bb9f5a4a6d refactor: rename InitModelSettings to InitRatioSettings 2025-04-26 17:15:34 +08:00
CaIon
935acccca4 fix: update cacheRatioMap initialization in InitModelSettings function 2025-04-26 17:09:23 +08:00
CaIon
453a42fad9 feat: initialize cacheRatioMap in InitModelSettings function 2025-04-26 17:06:03 +08:00
CaIon
58101328c5 fix: handle optional user_group_ratio in LogsTable and render helper 2025-04-26 15:59:49 +08:00
CaIon
a03c615fa4 Merge remote-tracking branch 'new-api/main' into gpt-image
# Conflicts:
#	relay/relay-image.go
2025-04-26 15:54:08 +08:00
CaIon
487ef35c58 feat: support image edit model mapping
(cherry picked from commit 1a869d8ad77f262ee27675ec2deaf451b1743eb7)
2025-04-26 15:48:59 +08:00
xyfacai
f9f32a0158 feat: support /images/edit
(cherry picked from commit 1c0a1238787d490f02dd9269b616580a16604180)
2025-04-26 15:44:56 +08:00
IcedTangerine
ea10806cf9 Merge pull request #950 from datehoer/main
fix: update getAndValidImageRequest function in relay/relay-image.go to support grok-2-image model
2025-04-26 15:34:15 +08:00
IcedTangerine
1a9ebb54b2 Merge pull request #843 from IllTamer/pr
fix: the pricing available popover display anyway
2025-04-25 18:27:45 +08:00
IcedTangerine
6de3857150 Merge branch 'main' into pr 2025-04-25 18:27:11 +08:00
han shi
32cd890b6e feat: 增加sendcloud邮件服务器的支持 (#947)
* 增加sendcloud邮件服务器的支持

* 调整代码结构

* Used slince.Contains function

---------

Co-authored-by: shih <shih@knownsec.com>
2025-04-25 18:17:46 +08:00
creamlike1024
f968d77365 fix: remove apikey from test channel log, close #1000 2025-04-25 17:08:26 +08:00
CaIon
dc22f7d32f refactor: update deepseek beta api 2025-04-25 16:26:16 +08:00
creamlike1024
c2b33e3b23 fix: GetMaxUserId use Unscope, close #987 2025-04-25 16:13:11 +08:00
IcedTangerine
db3326deae Merge pull request #975 from asjfoajs/qn-main
[#969] Refactor: Optimize the request rate limiting for ModelRequestRateLimi…
2025-04-25 11:59:05 +08:00
CaIon
25ae077ac9 refactor: update claude media source handling 2025-04-24 15:59:43 +08:00
CaIon
aaa41a8074 refactor: update ClaudeMessageSource struct to include optional Url field and adjust media source handling in relay-claude #993 2025-04-24 00:39:09 +08:00
CaIon
26f5b954c5 f*** gemini 2025-04-19 18:07:51 +08:00
CaIon
79c6dd08c9 refactor: enhance SystemSetting submission logic and handle empty WorkerUrl 2025-04-19 00:20:25 +08:00
CaIon
17e8a3432a refactor: update GeminiThinkingConfig initialization 2025-04-18 23:13:28 +08:00
CaIon
790af65b2c refactor: remove unsupported 'exclusiveMinimum' field from cleanFunctionParameters 2025-04-18 22:40:05 +08:00
CaIon
6522147183 refactor: remove unsupported root-level fields from cleanFunctionParameters 2025-04-18 21:38:12 +08:00
CaIon
0755ac9991 refactor: streamline value assignment in SettingGeminiModel 2025-04-18 20:08:26 +08:00
CaIon
4c4dc6e8b4 feat: add gemini thinking suffix support #981 2025-04-18 19:36:18 +08:00
CaIon
1eebdc4773 refactor: remove reasoning field from GeneralOpenAIRequest struct 2025-04-17 17:11:42 +08:00
CaIon
9b6c898675 feat: add reasoning field to GeneralOpenAIReques 2025-04-17 17:09:46 +08:00
CaIon
ee4f27d01b refactor: simplify model prefix checks and update message role for o-series models 2025-04-17 16:50:52 +08:00
Apple\Apple
995c19a997 🐛fix: Fix the issue where new whitelist email domain names cannot be added in the system settings 2025-04-16 17:11:59 +08:00
霍雨佳
e385e347ea Refactor: Optimize the token bucket algorithm, specifically the New method in common/imiterlimiter.go.
Solution: Remove Redis ping. When printing exceptions, use SysLog to print and add additional logging information.
2025-04-16 16:36:07 +08:00
Apple\Apple
71d0d759da Merge pull request #927 from QuentinHsu/refactor-system-setting
# Conflicts:
#	web/src/App.js
#	web/src/components/ModelSetting.js
#	web/src/components/PersonalSetting.js
#	web/src/components/SystemSetting.js
#	web/src/pages/Channel/EditChannel.js
2025-04-16 16:27:11 +08:00
霍雨佳
eb75ff232f Refactor: Optimize the request rate limiting for ModelRequestRateLimitCount.
Reason: The original steps 1 and 3 in the redisRateLimitHandler method were not atomic, leading to poor precision under high concurrent requests. For example, with a rate limit set to 60, sending 200 concurrent requests would result in none being blocked, whereas theoretically around 140 should be intercepted.
Solution: I chose not to merge steps 1 and 3 into a single Lua script because a single atomic operation involving read, write, and delete operations could suffer from performance issues under high concurrency. Instead, I implemented a token bucket algorithm to optimize this, reducing the atomic operation to just read and write steps while significantly decreasing the memory footprint.
2025-04-16 10:33:43 +08:00
CaIon
272662089d refactor: remove unused mutex from RelayInfo struct 2025-04-15 23:06:32 +08:00
CaIon
214ca4db56 fix: claude parallel function calling 2025-04-15 04:52:33 +08:00
CaIon
473e8e0eaf feat: support gemini output text and inline images. (close #866) 2025-04-15 02:32:51 +08:00
CaIon
99efc1fbb6 fix: try to fix claude to openai format mcp #966 2025-04-15 01:16:06 +08:00
Calcium-Ion
d283f6b35f Merge pull request #967 from neotf/fix-01
fix: wrong field for Claude (OpenAI Upstream)
2025-04-15 00:05:41 +08:00
CaIon
2f3acd9d22 feat: 添加流模式下的SSE保活机制 #945 2025-04-14 19:40:23 +08:00
neotf
eee6dee599 fix: wrong systemStr for Claude (OpenAI Upstream) 2025-04-14 01:09:02 +08:00
CaIon
dcf7878772 fix: update model name handling in UI and localization 2025-04-12 17:44:29 +08:00
jasonzeng
97bc2b4474 feat: add error logging functionality to relay and update logs table for error type display 2025-04-12 00:43:34 +08:00
CaIon
ef8ae4db80 fix: xAI usage 2025-04-11 23:31:32 +08:00
CaIon
90576d0261 feat: enhance Claude to OpenAI request conversion with additional relay info support 2025-04-11 19:13:38 +08:00
CaIon
4b3e30e669 feat: 完善openai转claude支持 2025-04-11 18:28:50 +08:00
CaIon
75570af967 chore: update .gitignore and docker-compose.yml to include tiktoken_cache directory 2025-04-11 16:24:27 +08:00
CaIon
cca9c0479f feat: enhance file handling and logging in the application 2025-04-11 16:23:54 +08:00
CaIon
8a2332074f refactor: move maxFileSize variable inside GetFileBase64FromUrl function 2025-04-11 15:53:23 +08:00
CaIon
2ec4565601 feat: implement parameter cleaning for Gemini functions 2025-04-10 22:35:03 +08:00
CaIon
a4fb33957f feat: support zhipu_4v embeddings path 2025-04-10 20:53:51 +08:00
Calcium-Ion
909c5eb276 Merge pull request #959 from Praying/main
fix(relay): 优化数据流处理
2025-04-10 17:21:55 +08:00
CaIon
8723e3f239 feat: add xAI handling and response processing 2025-04-10 17:20:59 +08:00
quran
9328b907f2 fix(relay): 优化数据流处理
- 移除了 bufio 的无效使用
- 在 StreamScannerHandler 中增加了初始和最大缓冲区大小的常量设置
- 调整 StreamScannerHandler 中缓冲区大小,避免出现token too long报错
2025-04-10 16:56:16 +08:00
Calcium-Ion
8efa12b941 Merge pull request #953 from wkxu/main
fix: .env文件配置DEBUG=true等参数不起作用的fix
2025-04-10 16:14:11 +08:00
Calcium-Ion
7b997b3a2c Merge pull request #956 from HynoR/feat/xai
feat: add xAI channel
2025-04-10 16:13:48 +08:00
HynoR
700c05b826 feat: update adaptor methods and add new image model 2025-04-10 15:08:12 +08:00
HynoR
c5103237b0 feat: add xai grok-3-mini reasoning effort 2025-04-10 13:31:43 +08:00
HynoR
f500eb17a8 feat: add xai channel
feat: add xai channel

feat: add xai channel
2025-04-10 13:04:43 +08:00
wkxu
86f6bb7abe refactor: 把common/instants.go里的从Getenv获取的参数,放到init.go的LoadEnv函数里获取
把constant/env.go里的从Getenv获取的参数,放到env.go的InitEnv函数里获取。以避免.env文件配置参数不起作用的情况
2025-04-10 09:02:19 +08:00
Calcium-Ion
c4c1099ae5 Merge pull request #944 from lamcodes/main
Update: Gemini channel fetch_models
2025-04-10 00:09:54 +08:00
CaIon
c869455456 fix: Update model ratios for gemini-2.5-pro 2025-04-10 00:09:11 +08:00
CaIon
f89d8a0fe5 refactor: Remove duplicate model settings initialization in main function 2025-04-10 00:07:34 +08:00
CaIon
3d6d19903b refactor: Update localization keys for API address in English translations and adjust related UI labels 2025-04-09 22:22:19 +08:00
datehoer
c5f1a0c712 Add support for grok-2-image. Currently, grok-2-image doesn't support the size, quality, or style parameters. Set 'size'='empty' to use grok-2-image 2025-04-09 15:05:00 +08:00
zkp
524d4a65bf Update: Gemini channel fetch_models 2025-04-08 22:43:13 +08:00
CaIon
082218173a feat: Add CheckSetup function call in main to ensure proper initialization #942 2025-04-08 18:14:36 +08:00
Calcium-Ion
67cbbc2266 Merge pull request #930 from Yiffyi/main
fix: save OIDC settings
2025-04-08 17:39:42 +08:00
CaIon
79b35e385f Update MaxTokens for gemini model to 300 in test request 2025-04-08 17:37:25 +08:00
Calcium-Ion
03e8ab4126 Merge pull request #936 from lamcodes/main
fix: gemini test MaxTokens
2025-04-08 17:33:31 +08:00
Calcium-Ion
30f32c6a6d Set MaxTokens to 50 for gemini 2025-04-08 17:33:10 +08:00
CaIon
5813ca780f feat: Integrate SetupCheck component for improved setup validation in routing 2025-04-08 17:31:46 +08:00
CaIon
aa34c3035a feat: Initialize model settings and improve concurrency control in operation settings 2025-04-07 22:20:47 +08:00
CaIon
fb9f595044 feat: Add concurrency control to group ratio management with mutexes 2025-04-07 21:55:54 +08:00
zkp
f24de65626 fix: gemini test MaxTokens 2025-04-06 23:24:47 +08:00
Yiffyi Jia
e34dccbc65 fix: cannot save OIDC settings 2025-04-05 04:24:38 +00:00
CaIon
f6e8887482 Update model-ratio.go 2025-04-04 23:43:14 +08:00
CaIon
a29f4d88c5 Update model-ratio.go 2025-04-04 23:41:41 +08:00
QuentinHsu
09adc6f201 refactor(web): systemSetting component to enhance UI structure and add new configuration options
- Wrapped form sections in Card components for better visual separation
- Added new configuration options for payment settings, email domain whitelist, SMTP, OIDC, GitHub OAuth, Linux DO OAuth, WeChat, and Telegram
- Improved layout with responsive design using Row and Col components
- Updated button actions for saving settings in new sections
2025-04-04 17:46:34 +08:00
QuentinHsu
6b79b89dc0 style(web): format code 2025-04-04 17:37:27 +08:00
IllTamer
3223c7e181 feat & fix: fix the pricing available sort, set defaultSortOrder descend 2025-03-10 22:39:21 +08:00
IllTamer
ccfac06645 fix: the pricing available popover display anyway 2025-03-10 22:16:02 +08:00
186 changed files with 11410 additions and 5040 deletions

3
.gitignore vendored
View File

@@ -9,4 +9,5 @@ logs
web/dist
.env
one-api
.DS_Store
.DS_Store
tiktoken_cache

View File

@@ -1,10 +1,13 @@
<p align="right">
<a href="./README.md">中文</a> | <strong>English</strong>
</p>
<div align="center">
![new-api](/web/public/logo.png)
# New API
🍥 Next Generation LLM Gateway and AI Asset Management System
🍥 Next-Generation Large Model Gateway and AI Asset Management System
<a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
@@ -33,171 +36,155 @@
> This is an open-source project developed based on [One API](https://github.com/songquanpeng/one-api)
> [!IMPORTANT]
> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and relevant laws and regulations. Not to be used for illegal purposes.
> - This project is for personal learning only. Stability is not guaranteed, and no technical support is provided.
> - This project is for personal learning purposes only, with no guarantee of stability or technical support.
> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
> - According to the [《Interim Measures for the Management of Generative Artificial Intelligence Services》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), please do not provide any unregistered generative AI services to the public in China.
## 📚 Documentation
For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
## ✨ Key Features
1. 🎨 New UI interface (some interfaces pending update)
2. 🌍 Multi-language support (work in progress)
3. 🎨 Added [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface support, [Integration Guide](Midjourney.md)
4. 💰 Online recharge support, configurable in system settings:
- [x] EasyPay
5. 🔍 Query usage quota by key:
- Works with [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)
6. 📑 Configurable items per page in pagination
7. 🔄 Compatible with original One API database (one-api.db)
8. 💵 Support per-request model pricing, configurable in System Settings - Operation Settings
9. ⚖️ Support channel **weighted random** selection
10. 📈 Data dashboard (console)
11. 🔒 Configurable model access per token
12. 🤖 Telegram authorization login support:
1. System Settings - Configure Login Registration - Allow Telegram Login
2. Send /setdomain command to [@Botfather](https://t.me/botfather)
3. Select your bot, then enter http(s)://your-website/login
4. Telegram Bot name is the bot username without @
13. 🎵 Added [Suno API](https://github.com/Suno-API/Suno-API) interface support, [Integration Guide](Suno.md)
14. 🔄 Support for Rerank models, compatible with Cohere and Jina, can integrate with Dify, [Integration Guide](Rerank.md)
15.**[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - Support for OpenAI's Realtime API, including Azure channels
16. 🧠 Support for setting reasoning effort through model name suffix:
- Add suffix `-high` to set high reasoning effort (e.g., `o3-mini-high`)
- Add suffix `-medium` to set medium reasoning effort
- Add suffix `-low` to set low reasoning effort
17. 🔄 Thinking to content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default is `false`, when `true`, the `reasoning_content` of the thinking content will be converted to `<think>` tags and concatenated to the content returned.
18. 🔄 Model rate limit, support setting total request limit and successful request limit in `System Settings->Rate Limit Settings`
19. 💰 Cache billing support, when enabled can charge a configurable ratio for cache hits:
1. Set `Prompt Cache Ratio` in `System Settings -> Operation Settings`
2. Set `Prompt Cache Ratio` in channel settings, range 0-1 (e.g., 0.5 means 50% charge on cache hits)
New API offers a wide range of features, please refer to [Features Introduction](https://docs.newapi.pro/wiki/features-introduction) for details:
1. 🎨 Brand new UI interface
2. 🌍 Multi-language support
3. 💰 Online recharge functionality (YiPay)
4. 🔍 Support for querying usage quotas with keys (works with [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
5. 🔄 Compatible with the original One API database
6. 💵 Support for pay-per-use model pricing
7. ⚖️ Support for weighted random channel selection
8. 📈 Data dashboard (console)
9. 🔒 Token grouping and model restrictions
10. 🤖 Support for more authorization login methods (LinuxDO, Telegram, OIDC)
11. 🔄 Support for Rerank models (Cohere and Jina), [API Documentation](https://docs.newapi.pro/api/jinaai-rerank)
12. ⚡ Support for OpenAI Realtime API (including Azure channels), [API Documentation](https://docs.newapi.pro/api/openai-realtime)
13. ⚡ Support for Claude Messages format, [API Documentation](https://docs.newapi.pro/api/anthropic-chat)
14. Support for entering chat interface via /chat2link route
15. 🧠 Support for setting reasoning effort through model name suffixes:
1. OpenAI o-series models
- Add `-high` suffix for high reasoning effort (e.g.: `o3-mini-high`)
- Add `-medium` suffix for medium reasoning effort (e.g.: `o3-mini-medium`)
- Add `-low` suffix for low reasoning effort (e.g.: `o3-mini-low`)
2. Claude thinking models
- Add `-thinking` suffix to enable thinking mode (e.g.: `claude-3-7-sonnet-20250219-thinking`)
16. 🔄 Thinking-to-content functionality
17. 🔄 Model rate limiting for users
18. 💰 Cache billing support, which allows billing at a set ratio when cache is hit:
1. Set the `Prompt Cache Ratio` option in `System Settings-Operation Settings`
2. Set `Prompt Cache Ratio` in the channel, range 0-1, e.g., setting to 0.5 means billing at 50% when cache is hit
3. Supported channels:
- [x] OpenAI
- [x] Azure
- [x] Azure
- [x] DeepSeek
- [ ] Claude
- [x] Claude
## Model Support
This version additionally supports:
1. Third-party model **gpts** (gpt-4-gizmo-*)
2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [Integration Guide](Midjourney.md)
3. Custom channels with full API URL support
4. [Suno API](https://github.com/Suno-API/Suno-API) interface, [Integration Guide](Suno.md)
5. Rerank models, supporting [Cohere](https://cohere.ai/) and [Jina](https://jina.ai/), [Integration Guide](Rerank.md)
6. Dify
You can add custom models gpt-4-gizmo-* in channels. These are third-party models and cannot be called with official OpenAI keys.
This version supports multiple models, please refer to [API Documentation-Relay Interface](https://docs.newapi.pro/api) for details:
## Additional Configurations Beyond One API
- `GENERATE_DEFAULT_TOKEN`: Generate initial token for new users, default `false`
- `STREAMING_TIMEOUT`: Set streaming response timeout, default 60 seconds
- `DIFY_DEBUG`: Output workflow and node info to client for Dify channel, default `true`
- `FORCE_STREAM_OPTION`: Override client stream_options parameter, default `true`
- `GET_MEDIA_TOKEN`: Calculate image tokens, default `true`
- `GET_MEDIA_TOKEN_NOT_STREAM`: Calculate image tokens in non-stream mode, default `true`
- `UPDATE_TASK`: Update async tasks (Midjourney, Suno), default `true`
- `GEMINI_MODEL_MAP`: Specify Gemini model versions (v1/v1beta), format: "model:version", comma-separated
- `COHERE_SAFETY_SETTING`: Cohere model [safety settings](https://docs.cohere.com/docs/safety-modes#overview), options: `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
- `CRYPTO_SECRET`: Encryption key for encrypting database content
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10`
- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2`
1. Third-party models **gpts** (gpt-4-gizmo-*)
2. Third-party channel [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [API Documentation](https://docs.newapi.pro/api/midjourney-proxy-image)
3. Third-party channel [Suno API](https://github.com/Suno-API/Suno-API) interface, [API Documentation](https://docs.newapi.pro/api/suno-music)
4. Custom channels, supporting full call address input
5. Rerank models ([Cohere](https://cohere.ai/) and [Jina](https://jina.ai/)), [API Documentation](https://docs.newapi.pro/api/jinaai-rerank)
6. Claude Messages format, [API Documentation](https://docs.newapi.pro/api/anthropic-chat)
7. Dify, currently only supports chatflow
## Environment Variable Configuration
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 60 seconds
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
- `GET_MEDIA_TOKEN_NOT_STREAM`: Whether to count image tokens in non-streaming cases, default is `true`
- `UPDATE_TASK`: Whether to update asynchronous tasks (Midjourney, Suno), default is `true`
- `COHERE_SAFETY_SETTING`: Cohere model safety settings, options are `NONE`, `CONTEXTUAL`, `STRICT`, default is `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`: Maximum number of images for Gemini models, default is `16`
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default is `20`
- `CRYPTO_SECRET`: Encryption key used for encrypting database content
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes
- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2`
## Deployment
For detailed deployment guides, please refer to [Installation Guide-Deployment Methods](https://docs.newapi.pro/installation):
> [!TIP]
> Latest Docker image: `calciumion/new-api:latest`
> Default account: root, password: 123456
> Latest Docker image: `calciumion/new-api:latest`
### Multi-Server Deployment
- Must set `SESSION_SECRET` environment variable, otherwise login state will not be consistent across multiple servers.
- If using a public Redis, must set `CRYPTO_SECRET` environment variable, otherwise Redis content will not be able to be obtained in multi-server deployment.
### Multi-machine Deployment Considerations
- Environment variable `SESSION_SECRET` must be set, otherwise login status will be inconsistent across multiple machines
- If sharing Redis, `CRYPTO_SECRET` must be set, otherwise Redis content cannot be accessed across multiple machines
### Requirements
- Local database (default): SQLite (Docker deployment must mount `/data` directory)
- Remote database: MySQL >= 5.7.8, PgSQL >= 9.6
### Deployment Requirements
- Local database (default): SQLite (Docker deployment must mount the `/data` directory)
- Remote database: MySQL version >= 5.7.8, PgSQL version >= 9.6
### Deployment with BT Panel
Install BT Panel (**version 9.2.0** or above) from [BT Panel Official Website](https://www.bt.cn/new/download.html), choose the stable version script to download and install.
After installation, log in to BT Panel and click Docker in the menu bar. First-time access will prompt to install Docker service. Click Install Now and follow the prompts to complete installation.
After installation, find **New-API** in the app store, click install, configure basic options to complete installation.
[Pictorial Guide](BT.md)
### Deployment Methods
### Docker Deployment
#### Using BaoTa Panel Docker Feature
Install BaoTa Panel (version **9.2.0** or above), find **New-API** in the application store and install it.
[Tutorial with images](./docs/BT.md)
### Using Docker Compose (Recommended)
#### Using Docker Compose (Recommended)
```shell
# Clone project
# Download the project
git clone https://github.com/Calcium-Ion/new-api.git
cd new-api
# Edit docker-compose.yml as needed
# nano docker-compose.yml
# vim docker-compose.yml
# Start
docker-compose up -d
```
#### Update Version
#### Using Docker Image Directly
```shell
docker-compose pull
docker-compose up -d
```
### Direct Docker Image Usage
```shell
# SQLite deployment:
# Using SQLite
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
# MySQL deployment (add -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"), modify database connection parameters as needed
# Example:
# Using MySQL
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
#### Update Version
```shell
# Pull the latest image
docker pull calciumion/new-api:latest
# Stop and remove the old container
docker stop new-api
docker rm new-api
# Run the new container with the same parameters as before
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
## Channel Retry and Cache
Channel retry functionality has been implemented, you can set the number of retries in `Settings->Operation Settings->General Settings`. It is **recommended to enable caching**.
Alternatively, you can use Watchtower for automatic updates (not recommended, may cause database incompatibility):
```shell
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
```
### Cache Configuration Method
1. `REDIS_CONN_STRING`: Set Redis as cache
2. `MEMORY_CACHE_ENABLED`: Enable memory cache (no need to set manually if Redis is set)
## Channel Retry
Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**.
If retry is enabled, the system will automatically use the next priority channel for the same request after a failed request.
## API Documentation
### Cache Configuration
1. `REDIS_CONN_STRING`: Use Redis as cache
+ Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `MEMORY_CACHE_ENABLED`: Enable memory cache, default `false`
+ Example: `MEMORY_CACHE_ENABLED=true`
For detailed API documentation, please refer to [API Documentation](https://docs.newapi.pro/api):
### Why Some Errors Don't Retry
Error codes 400, 504, 524 won't retry
### To Enable Retry for 400
In `Channel->Edit`, set `Status Code Override` to:
```json
{
"400": "500"
}
```
## Integration Guides
- [Midjourney Integration](Midjourney.md)
- [Suno Integration](Suno.md)
- [Chat API](https://docs.newapi.pro/api/openai-chat)
- [Image API](https://docs.newapi.pro/api/openai-image)
- [Rerank API](https://docs.newapi.pro/api/jinaai-rerank)
- [Realtime API](https://docs.newapi.pro/api/openai-realtime)
- [Claude Chat API (messages)](https://docs.newapi.pro/api/anthropic-chat)
## Related Projects
- [One API](https://github.com/songquanpeng/one-api): Original project
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy): Midjourney interface support
- [chatnio](https://github.com/Deeptrain-Community/chatnio): Next-gen AI B/C solution
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool): Query usage quota by key
- [chatnio](https://github.com/Deeptrain-Community/chatnio): Next-generation AI one-stop B/C-end solution
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool): Query usage quota with key
Other projects based on New API:
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon): High-performance optimized version of New API
- [VoAPI](https://github.com/VoAPI/VoAPI): Frontend beautified version based on New API
## Help and Support
If you have any questions, please refer to [Help and Support](https://docs.newapi.pro/support):
- [Community Interaction](https://docs.newapi.pro/support/community-interaction)
- [Issue Feedback](https://docs.newapi.pro/support/feedback-issues)
- [FAQ](https://docs.newapi.pro/support/faq)
## 🌟 Star History
[![Star History Chart](https://api.star-history.com/svg?repos=Calcium-Ion/new-api&type=Date)](https://star-history.com/#Calcium-Ion/new-api&Date)
[![Star History Chart](https://api.star-history.com/svg?repos=Calcium-Ion/new-api&type=Date)](https://star-history.com/#Calcium-Ion/new-api&Date)

View File

@@ -107,7 +107,7 @@ New API提供了丰富的功能详细特性请参考[特性说明](https://do
- `GEMINI_VISION_MAX_IMAGE_NUM`Gemini模型最大图片数量默认 `16`
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小单位MB默认 `20`
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容
- `AZURE_DEFAULT_API_VERSION`Azure渠道默认API版本默认 `2024-12-01-preview`
- `AZURE_DEFAULT_API_VERSION`Azure渠道默认API版本默认 `2025-04-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
@@ -130,7 +130,7 @@ New API提供了丰富的功能详细特性请参考[特性说明](https://do
#### 使用宝塔面板Docker功能部署
安装宝塔面板(**9.2.0版本**及以上),在应用商店中找到**New-API**安装即可。
[图文教程](BT.md)
[图文教程](./docs/BT.md)
#### 使用Docker Compose部署推荐
```shell

View File

@@ -1,8 +1,8 @@
package common
import (
"os"
"strconv"
//"os"
//"strconv"
"sync"
"time"
@@ -62,9 +62,13 @@ var EmailDomainWhitelist = []string{
"yahoo.com",
"foxmail.com",
}
var EmailLoginAuthServerList = []string{
"smtp.sendcloud.net",
"smtp.azurecomm.net",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var DebugEnabled bool
var MemoryCacheEnabled bool
var LogConsumeEnabled = true
@@ -103,22 +107,22 @@ var RetryTimes = 0
//var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var IsMasterNode bool
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var requestInterval int
var RequestInterval time.Duration
var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second
var SyncFrequency int // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
var BatchUpdateInterval int
var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
var RelayTimeout int // unit is second
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var GeminiSafetySetting string
// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
var CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
var CohereSafetySetting string
const (
RequestIdKey = "X-Oneapi-Request-Id"
@@ -145,13 +149,13 @@ var (
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
GlobalApiRateLimitEnable bool
GlobalApiRateLimitNum int
GlobalApiRateLimitDuration int64
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
GlobalWebRateLimitEnable bool
GlobalWebRateLimitNum int
GlobalWebRateLimitDuration int64
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
@@ -235,6 +239,8 @@ const (
ChannelTypeVolcEngine = 45
ChannelTypeBaiduV2 = 46
ChannelTypeXinference = 47
ChannelTypeXai = 48
ChannelTypeCoze = 49
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -288,4 +294,6 @@ var ChannelBaseURLs = []string{
"https://ark.cn-beijing.volces.com", //45
"https://qianfan.baidubce.com", //46
"", //47
"https://api.x.ai", //48
"https://api.coze.cn", //49
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/base64"
"fmt"
"net/smtp"
"slices"
"strings"
"time"
)
@@ -79,7 +80,7 @@ func SendEmail(subject string, receiver string, content string) error {
if err != nil {
return err
}
} else if isOutlookServer(SMTPAccount) || SMTPServer == "smtp.azurecomm.net" {
} else if isOutlookServer(SMTPAccount) || slices.Contains(EmailLoginAuthServerList, SMTPServer) {
auth = LoginAuth(SMTPAccount, SMTPToken)
err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
} else {

View File

@@ -6,6 +6,8 @@ import (
"log"
"os"
"path/filepath"
"strconv"
"time"
)
var (
@@ -66,4 +68,31 @@ func LoadEnv() {
}
}
}
// Initialize variables from constants.go that were using environment variables
DebugEnabled = os.Getenv("DEBUG") == "true"
MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
// Parse requestInterval and set RequestInterval
requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
RequestInterval = time.Duration(requestInterval) * time.Second
// Initialize variables with GetEnvOrDefault
SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
// Initialize string variables with GetEnvOrDefaultString
GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
// Initialize rate limit variables
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
}

View File

@@ -12,3 +12,7 @@ func DecodeJson(data []byte, v any) error {
func DecodeJsonStr(data string, v any) error {
return DecodeJson(StringToByteSlice(data), v)
}
func EncodeJson(v any) ([]byte, error) {
return json.Marshal(v)
}

89
common/limiter/limiter.go Normal file
View File

@@ -0,0 +1,89 @@
package limiter
import (
"context"
_ "embed"
"fmt"
"github.com/go-redis/redis/v8"
"one-api/common"
"sync"
)
//go:embed lua/rate_limit.lua
var rateLimitScript string
type RedisLimiter struct {
client *redis.Client
limitScriptSHA string
}
var (
instance *RedisLimiter
once sync.Once
)
func New(ctx context.Context, r *redis.Client) *RedisLimiter {
once.Do(func() {
// 预加载脚本
limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
if err != nil {
common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
}
instance = &RedisLimiter{
client: r,
limitScriptSHA: limitSHA,
}
})
return instance
}
func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
// 默认配置
config := &Config{
Capacity: 10,
Rate: 1,
Requested: 1,
}
// 应用选项模式
for _, opt := range opts {
opt(config)
}
// 执行限流
result, err := rl.client.EvalSha(
ctx,
rl.limitScriptSHA,
[]string{key},
config.Requested,
config.Rate,
config.Capacity,
).Int()
if err != nil {
return false, fmt.Errorf("rate limit failed: %w", err)
}
return result == 1, nil
}
// Config 配置选项模式
type Config struct {
Capacity int64
Rate int64
Requested int64
}
type Option func(*Config)
func WithCapacity(c int64) Option {
return func(cfg *Config) { cfg.Capacity = c }
}
func WithRate(r int64) Option {
return func(cfg *Config) { cfg.Rate = r }
}
func WithRequested(n int64) Option {
return func(cfg *Config) { cfg.Requested = n }
}

View File

@@ -0,0 +1,44 @@
-- 令牌桶限流器
-- KEYS[1]: 限流器唯一标识
-- ARGV[1]: 请求令牌数 (通常为1)
-- ARGV[2]: 令牌生成速率 (每秒)
-- ARGV[3]: 桶容量
local key = KEYS[1]
local requested = tonumber(ARGV[1])
local rate = tonumber(ARGV[2])
local capacity = tonumber(ARGV[3])
-- 获取当前时间Redis服务器时间
local now = redis.call('TIME')
local nowInSeconds = tonumber(now[1])
-- 获取桶状态
local bucket = redis.call('HMGET', key, 'tokens', 'last_time')
local tokens = tonumber(bucket[1])
local last_time = tonumber(bucket[2])
-- 初始化桶(首次请求或过期)
if not tokens or not last_time then
tokens = capacity
last_time = nowInSeconds
else
-- 计算新增令牌
local elapsed = nowInSeconds - last_time
local add_tokens = elapsed * rate
tokens = math.min(capacity, tokens + add_tokens)
last_time = nowInSeconds
end
-- 判断是否允许请求
local allowed = false
if tokens >= requested then
tokens = tokens - requested
allowed = true
end
---- 更新桶状态并设置过期时间
redis.call('HMSET', key, 'tokens', tokens, 'last_time', last_time)
--redis.call('EXPIRE', key, math.ceil(capacity / rate) + 60) -- 适当延长过期时间
return allowed and 1 or 0

View File

@@ -7,7 +7,6 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/pkg/errors"
"html/template"
"io"
"log"
@@ -22,6 +21,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
)
func OpenBrowser(url string) {

5
constant/azure.go Normal file
View File

@@ -0,0 +1,5 @@
package constant
import "time"
var AzureNoRemoveDotTime = time.Date(2025, time.May, 10, 0, 0, 0, 0, time.UTC).Unix()

View File

@@ -4,32 +4,42 @@ import (
"one-api/common"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
var MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
// ForceStreamOption 覆盖请求参数强制返回usage信息
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
var StreamingTimeout int
var DifyDebug bool
var MaxFileDownloadMB int
var ForceStreamOption bool
var GetMediaToken bool
var GetMediaTokenNotStream bool
var UpdateTask bool
var AzureDefaultAPIVersion string
var GeminiVisionMaxImageNum int
var NotifyLimitCount int
var NotificationLimitDurationMinute int
var GenerateDefaultToken bool
var ErrorLogEnabled bool
//var GeminiModelMap = map[string]string{
// "gemini-1.0-pro": "v1",
//}
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
func InitEnv() {
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
// ForceStreamOption 覆盖请求参数强制返回usage信息
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
// 是否启用错误日志
ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
//if modelVersionMapStr == "" {
// return
@@ -43,6 +53,3 @@ func InitEnv() {
// }
//}
}
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

View File

@@ -108,6 +108,13 @@ type DeepSeekUsageResponse struct {
} `json:"balance_infos"`
}
type OpenRouterCreditResponse struct {
Data struct {
TotalCredits float64 `json:"total_credits"`
TotalUsage float64 `json:"total_usage"`
} `json:"data"`
}
// GetAuthHeader get auth header
func GetAuthHeader(token string) http.Header {
h := http.Header{}
@@ -281,6 +288,22 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
return response.TotalAvailable, nil
}
func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
url := "https://openrouter.ai/api/v1/credits"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := OpenRouterCreditResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
balance := response.Data.TotalCredits - response.Data.TotalUsage
channel.UpdateBalance(balance)
return balance, nil
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
@@ -307,6 +330,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return updateChannelSiliconFlowBalance(channel)
case common.ChannelTypeDeepSeek:
return updateChannelDeepSeekBalance(channel)
case common.ChannelTypeOpenRouter:
return updateChannelOpenRouterBalance(channel)
default:
return 0, errors.New("尚未实现")
}

View File

@@ -103,7 +103,10 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
}
request := buildTestRequest(testModel)
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info))
// 创建一个用于日志的 info 副本,移除 ApiKey
logInfo := *info
logInfo.ApiKey = ""
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
if err != nil {
@@ -186,12 +189,14 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
return testRequest
}
// 并非Embedding 模型
if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") {
if strings.HasPrefix(model, "o") {
testRequest.MaxCompletionTokens = 10
} else if strings.Contains(model, "thinking") {
if !strings.Contains(model, "claude") {
testRequest.MaxTokens = 50
}
} else if strings.Contains(model, "gemini") {
testRequest.MaxTokens = 300
} else {
testRequest.MaxTokens = 10
}

View File

@@ -119,6 +119,9 @@ func FetchUpstreamModels(c *gin.Context) {
baseURL = channel.GetBaseURL()
}
url := fmt.Sprintf("%s/v1/models", baseURL)
if channel.Type == common.ChannelTypeGemini {
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
}
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -139,7 +142,11 @@ func FetchUpstreamModels(c *gin.Context) {
var ids []string
for _, model := range result.Data {
ids = append(ids, model.ID)
id := model.ID
if channel.Type == common.ChannelTypeGemini {
id = strings.TrimPrefix(id, "models/")
}
ids = append(ids, id)
}
c.JSON(http.StatusOK, gin.H{

View File

@@ -196,7 +196,7 @@ func DeleteHistoryLogs(c *gin.Context) {
})
return
}
count, err := model.DeleteOldLog(targetTimestamp)
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -110,6 +110,15 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "ModelRequestRateLimitGroup":
err = setting.CheckModelRequestRateLimitGroup(option.Value)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
}
err = model.UpdateOption(option.Key, option.Value)

View File

@@ -4,12 +4,11 @@ import (
"bytes"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
"log"
"net/http"
"one-api/common"
constant2 "one-api/constant"
"one-api/dto"
"one-api/middleware"
"one-api/model"
@@ -19,12 +18,15 @@ import (
"one-api/relay/helper"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
case relayconstant.RelayModeImagesGenerations:
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
err = relay.ImageHelper(c)
case relayconstant.RelayModeAudioSpeech:
fallthrough
@@ -36,9 +38,31 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
err = relay.RerankHelper(c, relayMode)
case relayconstant.RelayModeEmbeddings:
err = relay.EmbeddingHelper(c)
case relayconstant.RelayModeResponses:
err = relay.ResponsesHelper(c)
default:
err = relay.TextHelper(c)
}
if constant2.ErrorLogEnabled && err != nil {
// 保存错误日志到mysql中
userId := c.GetInt("id")
tokenName := c.GetString("token_name")
modelName := c.GetString("original_model")
tokenId := c.GetInt("token_id")
userGroup := c.GetString("group")
channelId := c.GetInt("channel_id")
other := make(map[string]interface{})
other["error_type"] = err.Error.Type
other["error_code"] = err.Error.Code
other["status_code"] = err.StatusCode
other["channel_id"] = channelId
other["channel_name"] = c.GetString("channel_name")
other["channel_type"] = c.GetInt("channel_type")
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error.Message, tokenId, 0, false, userGroup, other)
}
return err
}

View File

@@ -592,7 +592,14 @@ func UpdateSelf(c *gin.Context) {
user.Password = "" // rollback to what it should be
cleanUser.Password = ""
}
updatePassword := user.Password != ""
updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if err := cleanUser.Update(updatePassword); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -608,6 +615,23 @@ func UpdateSelf(c *gin.Context) {
return
}
func checkUpdatePassword(originalPassword string, newPassword string, userId int) (updatePassword bool, err error) {
var currentUser *model.User
currentUser, err = model.GetUserById(userId, true)
if err != nil {
return
}
if !common.ValidatePasswordAndHash(originalPassword, currentUser.Password) {
err = fmt.Errorf("原密码错误")
return
}
if newPassword == "" {
return
}
updatePassword = true
return
}
func DeleteUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {

View File

@@ -15,6 +15,8 @@ services:
- SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
- REDIS_CONN_STRING=redis://redis
- TZ=Asia/Shanghai
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
# - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache请取消注释
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed

View File

@@ -1,3 +1,3 @@
密钥为环境变量SESSION_SECRET
![8285bba413e770fe9620f1bf9b40d44e](https://github.com/user-attachments/assets/7a6fc03e-c457-45e4-b8f9-184508fc26b0)
密钥为环境变量SESSION_SECRET
![8285bba413e770fe9620f1bf9b40d44e](https://github.com/user-attachments/assets/7a6fc03e-c457-45e4-b8f9-184508fc26b0)

View File

@@ -7,7 +7,7 @@ type ClaudeMetadata struct {
}
type ClaudeMediaMessage struct {
Type string `json:"type"`
Type string `json:"type,omitempty"`
Text *string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"`
@@ -50,6 +50,11 @@ func (c *ClaudeMediaMessage) GetStringContent() string {
return ""
}
func (c *ClaudeMediaMessage) GetJsonRowString() string {
jsonContent, _ := json.Marshal(c)
return string(jsonContent)
}
func (c *ClaudeMediaMessage) SetContent(content any) {
jsonContent, _ := json.Marshal(content)
c.Content = jsonContent
@@ -65,8 +70,9 @@ func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
type ClaudeMessageSource struct {
Type string `json:"type"`
MediaType string `json:"media_type"`
Data any `json:"data"`
MediaType string `json:"media_type,omitempty"`
Data any `json:"data,omitempty"`
Url string `json:"url,omitempty"`
}
type ClaudeMessage struct {

View File

@@ -1,14 +1,20 @@
package dto
import "encoding/json"
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
User string `json:"user,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
User string `json:"user,omitempty"`
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
}
type ImageResponse struct {

View File

@@ -18,39 +18,41 @@ type FormatJsonSchema struct {
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCallRequest `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
ExtraBody any `json:"extra_body,omitempty"`
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
//Reasoning json.RawMessage `json:"reasoning,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCallRequest `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
EnableThinking any `json:"enable_thinking,omitempty"` // ali
ExtraBody any `json:"extra_body,omitempty"`
}
type ToolCallRequest struct {
@@ -111,6 +113,8 @@ type MediaContent struct {
Text string `json:"text,omitempty"`
ImageUrl any `json:"image_url,omitempty"`
InputAudio any `json:"input_audio,omitempty"`
File any `json:"file,omitempty"`
VideoUrl any `json:"video_url,omitempty"`
}
func (m *MediaContent) GetImageMedia() *MessageImageUrl {
@@ -120,6 +124,20 @@ func (m *MediaContent) GetImageMedia() *MessageImageUrl {
return nil
}
func (m *MediaContent) GetInputAudio() *MessageInputAudio {
if m.InputAudio != nil {
return m.InputAudio.(*MessageInputAudio)
}
return nil
}
func (m *MediaContent) GetFile() *MessageFile {
if m.File != nil {
return m.File.(*MessageFile)
}
return nil
}
type MessageImageUrl struct {
Url string `json:"url"`
Detail string `json:"detail"`
@@ -135,10 +153,22 @@ type MessageInputAudio struct {
Format string `json:"format"`
}
type MessageFile struct {
FileName string `json:"filename,omitempty"`
FileData string `json:"file_data,omitempty"`
FileId string `json:"file_id,omitempty"`
}
type MessageVideoUrl struct {
Url string `json:"url"`
}
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
ContentTypeInputAudio = "input_audio"
ContentTypeFile = "file"
ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别
)
func (m *Message) GetPrefix() bool {
@@ -192,6 +222,12 @@ func (m *Message) StringContent() string {
return stringContent
}
func (m *Message) SetNullContent() {
m.Content = nil
m.parsedStringContent = nil
m.parsedContent = nil
}
func (m *Message) SetStringContent(content string) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
@@ -292,6 +328,39 @@ func (m *Message) ParseContent() []MediaContent {
})
}
}
case ContentTypeFile:
if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
fileId, ok3 := fileData["file_id"].(string)
if ok3 {
contentList = append(contentList, MediaContent{
Type: ContentTypeFile,
File: &MessageFile{
FileId: fileId,
},
})
} else {
fileName, ok1 := fileData["filename"].(string)
fileDataStr, ok2 := fileData["file_data"].(string)
if ok1 && ok2 {
contentList = append(contentList, MediaContent{
Type: ContentTypeFile,
File: &MessageFile{
FileName: fileName,
FileData: fileDataStr,
},
})
}
}
}
case ContentTypeVideoUrl:
if videoUrl, ok := contentItem["video_url"].(string); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeVideoUrl,
VideoUrl: &MessageVideoUrl{
Url: videoUrl,
},
})
}
}
}
}
@@ -301,3 +370,49 @@ func (m *Message) ParseContent() []MediaContent {
}
return contentList
}
type OpenAIResponsesRequest struct {
Model string `json:"model"`
Input json.RawMessage `json:"input,omitempty"`
Include json.RawMessage `json:"include,omitempty"`
Instructions json.RawMessage `json:"instructions,omitempty"`
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
Metadata json.RawMessage `json:"metadata,omitempty"`
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
PreviousResponseID string `json:"previous_response_id,omitempty"`
Reasoning *Reasoning `json:"reasoning,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
Store bool `json:"store,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Text json.RawMessage `json:"text,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
Tools []ResponsesToolsCall `json:"tools,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Truncation string `json:"truncation,omitempty"`
User string `json:"user,omitempty"`
}
type Reasoning struct {
Effort string `json:"effort,omitempty"`
Summary string `json:"summary,omitempty"`
}
type ResponsesToolsCall struct {
Type string `json:"type"`
// Web Search
UserLocation json.RawMessage `json:"user_location,omitempty"`
SearchContextSize string `json:"search_context_size,omitempty"`
// File Search
VectorStoreIds []string `json:"vector_store_ids,omitempty"`
MaxNumResults uint `json:"max_num_results,omitempty"`
Filters json.RawMessage `json:"filters,omitempty"`
// Computer Use
DisplayWidth uint `json:"display_width,omitempty"`
DisplayHeight uint `json:"display_height,omitempty"`
Environment string `json:"environment,omitempty"`
// Function
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
}

View File

@@ -1,5 +1,7 @@
package dto
import "encoding/json"
type SimpleResponse struct {
Usage `json:"usage"`
Error *OpenAIError `json:"error"`
@@ -166,10 +168,93 @@ type CompletionsStreamResponse struct {
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
}
type InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
CachedCreationTokens int `json:"-"`
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ImageTokens int `json:"image_tokens"`
}
type OutputTokenDetails struct {
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ReasoningTokens int `json:"reasoning_tokens"`
}
type OpenAIResponsesResponse struct {
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int `json:"created_at"`
Status string `json:"status"`
Error *OpenAIError `json:"error,omitempty"`
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
Instructions string `json:"instructions"`
MaxOutputTokens int `json:"max_output_tokens"`
Model string `json:"model"`
Output []ResponsesOutput `json:"output"`
ParallelToolCalls bool `json:"parallel_tool_calls"`
PreviousResponseID string `json:"previous_response_id"`
Reasoning *Reasoning `json:"reasoning"`
Store bool `json:"store"`
Temperature float64 `json:"temperature"`
ToolChoice string `json:"tool_choice"`
Tools []ResponsesToolsCall `json:"tools"`
TopP float64 `json:"top_p"`
Truncation string `json:"truncation"`
Usage *Usage `json:"usage"`
User json.RawMessage `json:"user"`
Metadata json.RawMessage `json:"metadata"`
}
type IncompleteDetails struct {
Reasoning string `json:"reasoning"`
}
type ResponsesOutput struct {
Type string `json:"type"`
ID string `json:"id"`
Status string `json:"status"`
Role string `json:"role"`
Content []ResponsesOutputContent `json:"content"`
}
type ResponsesOutputContent struct {
Type string `json:"type"`
Text string `json:"text"`
Annotations []interface{} `json:"annotations"`
}
const (
BuildInToolWebSearchPreview = "web_search_preview"
BuildInToolFileSearch = "file_search"
)
const (
BuildInCallWebSearchCall = "web_search_call"
)
const (
ResponsesOutputTypeItemAdded = "response.output_item.added"
ResponsesOutputTypeItemDone = "response.output_item.done"
)
// ResponsesStreamResponse 用于处理 /v1/responses 流式响应
type ResponsesStreamResponse struct {
Type string `json:"type"`
Response *OpenAIResponsesResponse `json:"response,omitempty"`
Delta string `json:"delta,omitempty"`
Item *ResponsesOutput `json:"item,omitempty"`
}

View File

@@ -43,19 +43,6 @@ type RealtimeUsage struct {
OutputTokenDetails OutputTokenDetails `json:"output_token_details"`
}
type InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
CachedCreationTokens int
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ImageTokens int `json:"image_tokens"`
}
type OutputTokenDetails struct {
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
}
type RealtimeSession struct {
Modalities []string `json:"modalities"`
Instructions string `json:"instructions"`

32
main.go
View File

@@ -12,6 +12,7 @@ import (
"one-api/model"
"one-api/router"
"one-api/service"
"one-api/setting/operation_setting"
"os"
"strconv"
@@ -33,7 +34,7 @@ var indexPage []byte
func main() {
err := godotenv.Load(".env")
if err != nil {
common.SysLog("Support for .env file is disabled")
common.SysLog("Support for .env file is disabled: " + err.Error())
}
common.LoadEnv()
@@ -51,6 +52,9 @@ func main() {
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
model.CheckSetup()
// Initialize SQL Database
err = model.InitLogDB()
if err != nil {
@@ -69,10 +73,15 @@ func main() {
common.FatalLog("failed to initialize Redis: " + err.Error())
}
// Initialize model settings
operation_setting.InitRatioSettings()
// Initialize constants
constant.InitEnv()
// Initialize options
model.InitOptionMap()
service.InitTokenEncoders()
if common.RedisEnabled {
// for compatibility with old versions
common.MemoryCacheEnabled = true
@@ -80,9 +89,22 @@ func main() {
if common.MemoryCacheEnabled {
common.SysLog("memory cache enabled")
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
model.InitChannelCache()
}
if common.MemoryCacheEnabled {
// Add panic recovery and retry for InitChannelCache
func() {
defer func() {
if r := recover(); r != nil {
common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
// Retry once
_, fixErr := model.FixAbility()
if fixErr != nil {
common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
}
}
}()
model.InitChannelCache()
}()
go model.SyncOptions(common.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency)
}
@@ -126,8 +148,6 @@ func main() {
common.SysLog("pprof enabled")
}
service.InitTokenEncoders()
// Initialize HTTP server
server := gin.New()
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {

View File

@@ -162,7 +162,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
@@ -184,6 +184,8 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
relayMode := relayconstant.RelayModeAudioSpeech
@@ -211,6 +213,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type)
c.Set("channel_create_time", channel.CreatedTime)
c.Set("channel_setting", channel.GetSetting())
c.Set("param_override", channel.GetParamOverride())
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
@@ -237,5 +240,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("api_version", channel.Other)
case common.ChannelTypeMokaAI:
c.Set("api_version", channel.Other)
case common.ChannelTypeCoze:
c.Set("bot_id", channel.Other)
}
}

View File

@@ -5,6 +5,8 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/common/limiter"
"one-api/constant"
"one-api/setting"
"strconv"
"time"
@@ -78,21 +80,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
ctx := context.Background()
rdb := common.RDB
// 1. 检查请求数限制当totalMaxCount为0时会自动跳过
totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
if err != nil {
fmt.Println("检查总请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
// 2. 检查成功请求数限制
// 1. 检查成功请求数限制
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
if err != nil {
fmt.Println("检查成功请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
@@ -103,8 +93,29 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
return
}
// 3. 记录总请求当totalMaxCount为0时会自动跳过
recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
//2.检查总请求数限制并记录总请求当totalMaxCount为0时会自动跳过,使用令牌桶限流器
if totalMaxCount > 0 {
totalKey := fmt.Sprintf("rateLimit:%s", userId)
// 初始化
tb := limiter.New(ctx, rdb)
allowed, err = tb.Allow(
ctx,
totalKey,
limiter.WithCapacity(int64(totalMaxCount)*duration),
limiter.WithRate(int64(totalMaxCount)),
limiter.WithRequested(duration),
)
if err != nil {
fmt.Println("检查总请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
}
// 4. 处理请求
c.Next()
@@ -165,6 +176,19 @@ func ModelRequestRateLimit() func(c *gin.Context) {
totalMaxCount := setting.ModelRequestRateLimitCount
successMaxCount := setting.ModelRequestRateLimitSuccessCount
// 获取分组
group := c.GetString("token_group")
if group == "" {
group = c.GetString(constant.ContextKeyUserGroup)
}
//获取分组的限流配置
groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
if found {
totalMaxCount = groupTotalCount
successMaxCount = groupSuccessCount
}
// 根据存储类型选择并执行限流处理器
if common.RedisEnabled {
redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)

View File

@@ -50,7 +50,7 @@ func getPriority(group string, model string, retry int) (int, error) {
err := DB.Model(&Ability{}).
Select("DISTINCT(priority)").
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
Order("priority DESC"). // 按优先级降序排序
Order("priority DESC"). // 按优先级降序排序
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
if err != nil {
@@ -261,12 +261,28 @@ func FixAbility() (int, error) {
common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
return 0, err
}
// Delete abilities of channels that are not in channel table
err = DB.Where("channel_id NOT IN (?)", channelIds).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete abilities of channels that are not in channel table failed: %s", err.Error()))
return 0, err
// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
if len(channelIds) > 0 {
// Process deletion in chunks to avoid "too many placeholders" error
for _, chunk := range lo.Chunk(channelIds, 100) {
err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
return 0, err
}
}
} else {
// If no channels exist, delete all abilities
err = DB.Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
return 0, err
}
common.SysLog("Delete all abilities successfully")
return 0, nil
}
common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
count += len(channelIds)
@@ -275,17 +291,26 @@ func FixAbility() (int, error) {
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return 0, err
return count, err
}
var channels []Channel
if len(abilityChannelIds) == 0 {
err = DB.Find(&channels).Error
} else {
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
}
if err != nil {
return 0, err
// Process query in chunks to avoid "too many placeholders" error
err = nil
for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
var channelsChunk []Channel
err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
if err != nil {
common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
return count, err
}
channels = append(channels, channelsChunk...)
}
}
for _, channel := range channels {
err := channel.UpdateAbilities(nil)
if err != nil {

View File

@@ -16,6 +16,9 @@ var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex
func InitChannelCache() {
if !common.MemoryCacheEnabled {
return
}
newChannelId2channel := make(map[int]*Channel)
var channels []*Channel
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
@@ -84,9 +87,11 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model, retry)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
channels := group2model2channels[group][model]
channelSyncLock.RUnlock()
if len(channels) == 0 {
return nil, errors.New("channel not found")
}

View File

@@ -46,6 +46,17 @@ func (channel *Channel) GetModels() []string {
return strings.Split(strings.Trim(channel.Models, ","), ",")
}
func (channel *Channel) GetGroups() []string {
if channel.Group == "" {
return []string{}
}
groups := strings.Split(strings.Trim(channel.Group, ","), ",")
for i, group := range groups {
groups[i] = strings.TrimSpace(group)
}
return groups
}
func (channel *Channel) GetOtherInfo() map[string]interface{} {
otherInfo := make(map[string]interface{})
if channel.OtherInfo != "" {
@@ -119,10 +130,15 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
// 如果是 PostgreSQL使用双引号
if common.UsingPostgreSQL {
keyCol = `"key"`
modelsCol = `"models"`
}
baseURLCol := "`base_url`"
// 如果是 PostgreSQL使用双引号
if common.UsingPostgreSQL {
baseURLCol = `"base_url"`
}
order := "priority desc"
if idSort {
order = "id desc"
@@ -142,11 +158,11 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
// sqlite, PostgreSQL
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
}
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%", "%,"+group+",%")
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
}
// 执行查询
@@ -450,6 +466,12 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
modelsCol = `"models"`
}
baseURLCol := "`base_url`"
// 如果是 PostgreSQL使用双引号
if common.UsingPostgreSQL {
baseURLCol = `"base_url"`
}
order := "priority desc"
if idSort {
order = "id desc"
@@ -469,11 +491,11 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
// sqlite, PostgreSQL
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
}
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%", "%,"+group+",%")
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
}
subQuery := baseQuery.Where(whereClause, args...).

View File

@@ -1,6 +1,7 @@
package model
import (
"context"
"fmt"
"one-api/common"
"os"
@@ -40,6 +41,7 @@ const (
LogTypeConsume
LogTypeManage
LogTypeSystem
LogTypeError
)
func formatUserLogs(logs []*Log) {
@@ -88,6 +90,35 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
log := &Log{
UserId: userId,
Username: username,
CreatedAt: common.GetTimestamp(),
Type: LogTypeError,
Content: content,
PromptTokens: 0,
CompletionTokens: 0,
TokenName: tokenName,
ModelName: modelName,
Quota: 0,
ChannelId: channelId,
TokenId: tokenId,
UseTime: useTimeSeconds,
IsStream: isStream,
Group: group,
Other: otherStr,
}
err := LOG_DB.Create(log).Error
if err != nil {
common.LogError(c, "failed to record log: "+err.Error())
}
}
func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
@@ -310,7 +341,25 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
return token
}
func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
func DeleteOldLog(ctx context.Context, targetTimestamp int64, limit int) (int64, error) {
var total int64 = 0
for {
if nil != ctx.Err() {
return total, ctx.Err()
}
result := LOG_DB.Where("created_at < ?", targetTimestamp).Limit(limit).Delete(&Log{})
if nil != result.Error {
return total, result.Error
}
total += result.RowsAffected
if result.RowsAffected < int64(limit) {
break
}
}
return total, nil
}

View File

@@ -56,7 +56,7 @@ func createRootAccountIfNeed() error {
return nil
}
func checkSetup() {
func CheckSetup() {
setup := GetSetup()
if setup == nil {
// No setup record exists, check if we have a root user
@@ -244,7 +244,6 @@ func migrateDB() error {
}
err = DB.AutoMigrate(&Setup{})
common.SysLog("database migrated")
checkSetup()
//err = createRootAccountIfNeed()
return err
}

View File

@@ -67,6 +67,7 @@ func InitOptionMap() {
common.OptionMap["ServerAddress"] = ""
common.OptionMap["WorkerUrl"] = setting.WorkerUrl
common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(setting.WorkerAllowHttpImageRequestEnabled)
common.OptionMap["PayAddress"] = ""
common.OptionMap["CustomCallbackAddress"] = ""
common.OptionMap["EpayId"] = ""
@@ -92,6 +93,7 @@ func InitOptionMap() {
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
@@ -256,6 +258,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.StopOnSensitiveEnabled = boolValue
case "SMTPSSLEnabled":
common.SMTPSSLEnabled = boolValue
case "WorkerAllowHttpImageRequestEnabled":
setting.WorkerAllowHttpImageRequestEnabled = boolValue
}
}
switch key {
@@ -338,6 +342,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
case "ModelRequestRateLimitSuccessCount":
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
case "ModelRequestRateLimitGroup":
err = setting.UpdateModelRequestRateLimitGroupByJSONString(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval":

View File

@@ -18,6 +18,7 @@ type User struct {
Id int `json:"id"`
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
OriginalPassword string `json:"original_password" gorm:"-:all"` // this field is only for Password change verification, don't save it to database!
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
@@ -108,7 +109,7 @@ func CheckUserExistOrDeleted(username string, email string) (bool, error) {
func GetMaxUserId() int {
var user User
DB.Last(&user)
DB.Unscoped().Last(&user)
return user.Id
}

View File

@@ -1,11 +1,12 @@
package channel
import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
relaycommon "one-api/relay/common"
"github.com/gin-gonic/gin"
)
type Adaptor interface {
@@ -18,6 +19,7 @@ type Adaptor interface {
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
GetModelList() []string

View File

@@ -3,7 +3,6 @@ package ali
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
@@ -11,6 +10,8 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -32,6 +33,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
case constant.RelayModeCompletions:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
default:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
}
@@ -79,6 +82,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -1,7 +1,12 @@
package ali
var ModelList = []string{
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
"qwen-turbo",
"qwen-plus",
"qwen-max",
"qwen-max-longcontext",
"qwq-32b",
"qwen3-235b-a22b",
"text-embedding-v1",
}

View File

@@ -1,16 +1,23 @@
package channel
import (
"context"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
"net/http"
common2 "one-api/common"
"one-api/relay/common"
"one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting/operation_setting"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
@@ -55,6 +62,9 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
if common2.DebugEnabled {
println("fullRequestURL:", fullRequestURL)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
@@ -105,7 +115,62 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
} else {
client = service.GetHttpClient()
}
// 流式请求 ping 保活
var stopPinger func()
generalSettings := operation_setting.GetGeneralSetting()
pingEnabled := generalSettings.PingIntervalEnabled
var pingerWg sync.WaitGroup
if info.IsStream {
helper.SetEventStreamHeaders(c)
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
var pingerCtx context.Context
pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
if pingEnabled {
pingerWg.Add(1)
gopool.Go(func() {
defer pingerWg.Done()
if pingInterval <= 0 {
pingInterval = helper.DefaultPingInterval
}
ticker := time.NewTicker(pingInterval)
defer ticker.Stop()
var pingMutex sync.Mutex
if common2.DebugEnabled {
println("SSE ping goroutine started")
}
for {
select {
case <-ticker.C:
pingMutex.Lock()
err2 := helper.PingData(c)
pingMutex.Unlock()
if err2 != nil {
common2.LogError(c, "SSE ping error: "+err.Error())
return
}
if common2.DebugEnabled {
println("SSE ping data sent.")
}
case <-pingerCtx.Done():
if common2.DebugEnabled {
println("SSE ping goroutine stopped.")
}
return
}
}
})
}
}
resp, err := client.Do(req)
// request结束后停止ping
if info.IsStream && pingEnabled {
stopPinger()
pingerWg.Wait()
}
if err != nil {
return nil, err
}

View File

@@ -2,13 +2,14 @@ package aws
import (
"errors"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
"github.com/gin-gonic/gin"
)
const (
@@ -74,6 +75,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return nil, nil
}

View File

@@ -3,7 +3,6 @@ package baidu
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
@@ -11,6 +10,8 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -130,6 +131,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return baiduEmbeddingRequest, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -3,13 +3,14 @@ package baidu_v2
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -60,6 +61,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -3,7 +3,6 @@ package claude
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
@@ -11,6 +10,8 @@ import (
relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
"strings"
"github.com/gin-gonic/gin"
)
const (
@@ -37,10 +38,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
} else {
if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") {
a.RequestMode = RequestModeCompletion
} else {
a.RequestMode = RequestModeMessage
}
}
@@ -84,6 +85,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -24,6 +24,8 @@ func stopReasonClaude2OpenAI(reason string) string {
return "stop"
case "max_tokens":
return "max_tokens"
case "tool_use":
return "tool_calls"
default:
return reason
}
@@ -298,6 +300,13 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
tools := make([]dto.ToolCallResponse, 0)
fcIdx := 0
if claudeResponse.Index != nil {
fcIdx = *claudeResponse.Index - 1
if fcIdx < 0 {
fcIdx = 0
}
}
var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion)
@@ -317,8 +326,9 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCallResponse{
ID: claudeResponse.ContentBlock.Id,
Type: "function",
Index: common.GetPointer(fcIdx),
ID: claudeResponse.ContentBlock.Id,
Type: "function",
Function: dto.FunctionResponse{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
@@ -330,11 +340,12 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
}
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil {
choice.Index = *claudeResponse.Index
choice.Delta.Content = claudeResponse.Delta.Text
switch claudeResponse.Delta.Type {
case "input_json_delta":
tools = append(tools, dto.ToolCallResponse{
Type: "function",
Index: common.GetPointer(fcIdx),
Function: dto.FunctionResponse{
Arguments: *claudeResponse.Delta.PartialJson,
},

View File

@@ -55,6 +55,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
}
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -3,13 +3,14 @@ package cohere
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -52,6 +53,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
return requestOpenAI2Cohere(*request), nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -0,0 +1,132 @@
package coze
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/common"
"time"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
// ConvertAudioRequest implements channel.Adaptor.
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return nil, errors.New("not implemented")
}
// ConvertClaudeRequest implements channel.Adaptor.
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *common.RelayInfo, request *dto.ClaudeRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertEmbeddingRequest implements channel.Adaptor.
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *common.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertImageRequest implements channel.Adaptor.
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *common.RelayInfo, request dto.ImageRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertOpenAIRequest implements channel.Adaptor.
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return convertCozeChatRequest(c, *request), nil
}
// ConvertOpenAIResponsesRequest implements channel.Adaptor.
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *common.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertRerankRequest implements channel.Adaptor.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
// DoRequest implements channel.Adaptor.
func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) {
if info.IsStream {
return channel.DoApiRequest(a, c, info, requestBody)
}
// 首先发送创建消息请求,成功后再发送获取消息请求
// 发送创建消息请求
resp, err := channel.DoApiRequest(a, c, info, requestBody)
if err != nil {
return nil, err
}
// 解析 resp
var cozeResponse CozeChatResponse
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
err = json.Unmarshal(respBody, &cozeResponse)
if cozeResponse.Code != 0 {
return nil, errors.New(cozeResponse.Msg)
}
c.Set("coze_conversation_id", cozeResponse.Data.ConversationId)
c.Set("coze_chat_id", cozeResponse.Data.Id)
// 轮询检查消息是否完成
for {
err, isComplete := checkIfChatComplete(a, c, info)
if err != nil {
return nil, err
} else {
if isComplete {
break
}
}
time.Sleep(time.Second * 1)
}
// 发送获取消息请求
return getChatDetail(a, c, info)
}
// DoResponse implements channel.Adaptor.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = cozeChatStreamHandler(c, resp, info)
} else {
err, usage = cozeChatHandler(c, resp, info)
}
return
}
// GetChannelName implements channel.Adaptor.
func (a *Adaptor) GetChannelName() string {
return ChannelName
}
// GetModelList implements channel.Adaptor.
func (a *Adaptor) GetModelList() []string {
return ModelList
}
// GetRequestURL implements channel.Adaptor.
func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil
}
// Init implements channel.Adaptor.
func (a *Adaptor) Init(info *common.RelayInfo) {
}
// SetupRequestHeader implements channel.Adaptor.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *common.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}

View File

@@ -0,0 +1,30 @@
package coze
var ModelList = []string{
"moonshot-v1-8k",
"moonshot-v1-32k",
"moonshot-v1-128k",
"Baichuan4",
"abab6.5s-chat-pro",
"glm-4-0520",
"qwen-max",
"deepseek-r1",
"deepseek-v3",
"deepseek-r1-distill-qwen-32b",
"deepseek-r1-distill-qwen-7b",
"step-1v-8k",
"step-1.5v-mini",
"Doubao-pro-32k",
"Doubao-pro-256k",
"Doubao-lite-128k",
"Doubao-lite-32k",
"Doubao-vision-lite-32k",
"Doubao-vision-pro-32k",
"Doubao-1.5-pro-vision-32k",
"Doubao-1.5-lite-32k",
"Doubao-1.5-pro-32k",
"Doubao-1.5-thinking-pro",
"Doubao-1.5-pro-256k",
}
var ChannelName = "coze"

78
relay/channel/coze/dto.go Normal file
View File

@@ -0,0 +1,78 @@
package coze
import "encoding/json"
type CozeError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type CozeEnterMessage struct {
Role string `json:"role"`
Type string `json:"type,omitempty"`
Content json.RawMessage `json:"content,omitempty"`
MetaData json.RawMessage `json:"meta_data,omitempty"`
ContentType string `json:"content_type,omitempty"`
}
type CozeChatRequest struct {
BotId string `json:"bot_id"`
UserId string `json:"user_id"`
AdditionalMessages []CozeEnterMessage `json:"additional_messages,omitempty"`
Stream bool `json:"stream,omitempty"`
CustomVariables json.RawMessage `json:"custom_variables,omitempty"`
AutoSaveHistory bool `json:"auto_save_history,omitempty"`
MetaData json.RawMessage `json:"meta_data,omitempty"`
ExtraParams json.RawMessage `json:"extra_params,omitempty"`
ShortcutCommand json.RawMessage `json:"shortcut_command,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
}
type CozeChatResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data CozeChatResponseData `json:"data"`
}
type CozeChatResponseData struct {
Id string `json:"id"`
ConversationId string `json:"conversation_id"`
BotId string `json:"bot_id"`
CreatedAt int64 `json:"created_at"`
LastError CozeError `json:"last_error"`
Status string `json:"status"`
Usage CozeChatUsage `json:"usage"`
}
type CozeChatUsage struct {
TokenCount int `json:"token_count"`
OutputCount int `json:"output_count"`
InputCount int `json:"input_count"`
}
type CozeChatDetailResponse struct {
Data []CozeChatV3MessageDetail `json:"data"`
Code int `json:"code"`
Msg string `json:"msg"`
Detail CozeResponseDetail `json:"detail"`
}
type CozeChatV3MessageDetail struct {
Id string `json:"id"`
Role string `json:"role"`
Type string `json:"type"`
BotId string `json:"bot_id"`
ChatId string `json:"chat_id"`
Content json.RawMessage `json:"content"`
MetaData json.RawMessage `json:"meta_data"`
CreatedAt int64 `json:"created_at"`
SectionId string `json:"section_id"`
UpdatedAt int64 `json:"updated_at"`
ContentType string `json:"content_type"`
ConversationId string `json:"conversation_id"`
ReasoningContent string `json:"reasoning_content"`
}
type CozeResponseDetail struct {
Logid string `json:"logid"`
}

View File

@@ -0,0 +1,300 @@
package coze
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *CozeChatRequest {
var messages []CozeEnterMessage
// 将 request的messages的role为user的content转换为CozeMessage
for _, message := range request.Messages {
if message.Role == "user" {
messages = append(messages, CozeEnterMessage{
Role: "user",
Content: message.Content,
// TODO: support more content type
ContentType: "text",
})
}
}
user := request.User
if user == "" {
user = helper.GetResponseID(c)
}
cozeRequest := &CozeChatRequest{
BotId: c.GetString("bot_id"),
UserId: user,
AdditionalMessages: messages,
Stream: request.Stream,
}
return cozeRequest
}
func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
// convert coze response to openai response
var response dto.TextResponse
var cozeResponse CozeChatDetailResponse
response.Model = info.UpstreamModelName
err = json.Unmarshal(responseBody, &cozeResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if cozeResponse.Code != 0 {
return service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil
}
// 从上下文获取 usage
var usage dto.Usage
usage.PromptTokens = c.GetInt("coze_input_count")
usage.CompletionTokens = c.GetInt("coze_output_count")
usage.TotalTokens = c.GetInt("coze_token_count")
response.Usage = usage
response.Id = helper.GetResponseID(c)
var responseContent json.RawMessage
for _, data := range cozeResponse.Data {
if data.Type == "answer" {
responseContent = data.Content
response.Created = data.CreatedAt
}
}
// 添加 response.Choices
response.Choices = []dto.OpenAITextResponseChoice{
{
Index: 0,
Message: dto.Message{Role: "assistant", Content: responseContent},
FinishReason: "stop",
},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return nil, &usage
}
func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c)
id := helper.GetResponseID(c)
var responseText string
var currentEvent string
var currentData string
var usage dto.Usage
for scanner.Scan() {
line := scanner.Text()
if line == "" {
if currentEvent != "" && currentData != "" {
// handle last event
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
currentEvent = ""
currentData = ""
}
continue
}
if strings.HasPrefix(line, "event:") {
currentEvent = strings.TrimSpace(line[6:])
continue
}
if strings.HasPrefix(line, "data:") {
currentData = strings.TrimSpace(line[5:])
continue
}
}
// Last event
if currentEvent != "" && currentData != "" {
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
}
if err := scanner.Err(); err != nil {
return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil
}
helper.Done(c)
if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
return nil, &usage
}
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
switch event {
case "conversation.chat.completed":
// 将 data 解析为 CozeChatResponseData
var chatData CozeChatResponseData
err := json.Unmarshal([]byte(data), &chatData)
if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
usage.PromptTokens = chatData.Usage.InputCount
usage.CompletionTokens = chatData.Usage.OutputCount
usage.TotalTokens = chatData.Usage.TokenCount
finishReason := "stop"
stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason)
helper.ObjectData(c, stopResponse)
case "conversation.message.delta":
// 将 data 解析为 CozeChatV3MessageDetail
var messageData CozeChatV3MessageDetail
err := json.Unmarshal([]byte(data), &messageData)
if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
var content string
err = json.Unmarshal(messageData.Content, &content)
if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
*responseText += content
openaiResponse := dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
}
choice := dto.ChatCompletionsStreamResponseChoice{
Index: 0,
}
choice.Delta.SetContentString(content)
openaiResponse.Choices = append(openaiResponse.Choices, choice)
helper.ObjectData(c, openaiResponse)
case "error":
var errorData CozeError
err := json.Unmarshal([]byte(data), &errorData)
if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
}
}
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
// 将 conversationId和chatId作为参数发送get请求
req, err := http.NewRequest("GET", requestURL, nil)
if err != nil {
return err, false
}
err = a.SetupRequestHeader(c, &req.Header, info)
if err != nil {
return err, false
}
resp, err := doRequest(req, info) // 调用 doRequest
if err != nil {
return err, false
}
if resp == nil { // 确保在 doRequest 失败时 resp 不为 nil 导致 panic
return fmt.Errorf("resp is nil"), false
}
defer resp.Body.Close() // 确保响应体被关闭
// 解析 resp 到 CozeChatResponse
var cozeResponse CozeChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("read response body failed: %w", err), false
}
err = json.Unmarshal(responseBody, &cozeResponse)
if err != nil {
return fmt.Errorf("unmarshal response body failed: %w", err), false
}
if cozeResponse.Data.Status == "completed" {
// 在上下文设置 usage
c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount)
c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount)
c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount)
return nil, true
} else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" {
return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false
} else {
return nil, false
}
}
func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl)
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
req, err := http.NewRequest("GET", requestURL, nil)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
err = a.SetupRequestHeader(c, &req.Header, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(req, info)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}
func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
var client *http.Client
var err error // 声明 err 变量
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
client, err = service.NewProxyHttpClient(proxyURL.(string))
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
} else {
client = service.GetHttpClient()
}
resp, err := client.Do(req)
if err != nil { // 增加对 client.Do(req) 返回错误的检查
return nil, fmt.Errorf("client.Do failed: %w", err)
}
// _ = resp.Body.Close()
return resp, nil
}

View File

@@ -3,7 +3,6 @@ package deepseek
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
@@ -11,6 +10,9 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -36,9 +38,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
fimBaseUrl := info.BaseUrl
if !strings.HasSuffix(info.BaseUrl, "/beta") {
fimBaseUrl += "/beta"
}
switch info.RelayMode {
case constant.RelayModeCompletions:
return fmt.Sprintf("%s/beta/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
default:
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
@@ -66,6 +72,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -3,12 +3,13 @@ package dify
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"github.com/gin-gonic/gin"
)
const (
@@ -86,6 +87,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -1,7 +1,6 @@
package dify
import (
"bufio"
"bytes"
"encoding/base64"
"encoding/json"
@@ -213,12 +212,8 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var responseText string
usage := &dto.Usage{}
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
var nodeToken int
helper.SetEventStreamHeaders(c)
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse)
@@ -247,13 +242,10 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
}
return true
})
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
}
helper.Done(c)
err := resp.Body.Close()
if err != nil {
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
common.SysError("close_response_body_failed: " + err.Error())
}
if usage.TotalTokens == 0 {

View File

@@ -12,7 +12,6 @@ import (
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting/model_setting"
"strings"
"github.com/gin-gonic/gin"
@@ -70,6 +69,16 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// suffix -thinking and -nothinking
if strings.HasSuffix(info.OriginModelName, "-thinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
}
}
version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
@@ -99,11 +108,13 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
ai, err := CovertGemini2OpenAI(*request)
geminiRequest, err := CovertGemini2OpenAI(*request, info)
if err != nil {
return nil, err
}
return ai, nil
return geminiRequest, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -144,6 +155,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return geminiRequest, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
@@ -165,6 +181,18 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
} else {
err, usage = GeminiChatHandler(c, resp, info)
}
//if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
// // 没有请求-thinking的情况下产生思考token则按照思考模型计费
// if !strings.HasSuffix(info.OriginModelName, "-thinking") &&
// !strings.HasSuffix(info.OriginModelName, "-nothinking") {
// thinkingModelName := info.OriginModelName + "-thinking"
// if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) {
// info.OriginModelName = thinkingModelName
// }
// }
//}
return
}

View File

@@ -16,6 +16,8 @@ var ModelList = []string{
"gemini-2.0-pro-exp",
// thinking exp
"gemini-2.0-flash-thinking-exp",
"gemini-2.5-pro-exp-03-25",
"gemini-2.5-pro-preview-03-25",
// imagen models
"imagen-3.0-generate-002",
// embedding models

View File

@@ -2,10 +2,19 @@ package gemini
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
Tools []GeminiChatTool `json:"tools,omitempty"`
SystemInstructions *GeminiChatContent `json:"system_instruction,omitempty"`
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
}
type GeminiThinkingConfig struct {
IncludeThoughts bool `json:"includeThoughts,omitempty"`
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
}
func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
c.ThinkingBudget = &budget
}
type GeminiInlineData struct {
@@ -45,6 +54,7 @@ type GeminiFileData struct {
type GeminiPart struct {
Text string `json:"text,omitempty"`
Thought bool `json:"thought,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
@@ -71,15 +81,17 @@ type GeminiChatTool struct {
}
type GeminiChatGenerationConfig struct {
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"`
Seed int64 `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"`
Seed int64 `json:"seed,omitempty"`
ResponseModalities []string `json:"responseModalities,omitempty"`
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
}
type GeminiChatCandidate struct {
@@ -108,6 +120,7 @@ type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
}
// Imagen related structs

View File

@@ -19,11 +19,10 @@ import (
)
// Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
//SafetySettings: []GeminiChatSafetySettings{},
GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
@@ -32,6 +31,30 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
}
if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
geminiRequest.GenerationConfig.ResponseModalities = []string{
"TEXT",
"IMAGE",
}
}
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
if strings.HasSuffix(info.OriginModelName, "-thinking") {
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
if budgetTokens == 0 || budgetTokens > 24576 {
budgetTokens = 24576
}
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(int(budgetTokens)),
IncludeThoughts: true,
}
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(0),
}
}
}
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
for _, category := range SafetySettingList {
safetySettings = append(safetySettings, GeminiChatSafetySettings{
@@ -56,6 +79,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
continue
}
if tool.Function.Parameters != nil {
params, ok := tool.Function.Parameters.(map[string]interface{})
if ok {
if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
@@ -65,6 +89,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}
}
}
// Clean the parameters before appending
cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
tool.Function.Parameters = cleanedParams
functions = append(functions, tool.Function)
}
if codeExecution {
@@ -86,11 +113,11 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
// json_data, _ := json.Marshal(geminiRequest.Tools)
// common.SysLog("tools_json: " + string(json_data))
} else if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTool{
{
FunctionDeclarations: textRequest.Functions,
},
}
//geminiRequest.Tools = []GeminiChatTool{
// {
// FunctionDeclarations: textRequest.Functions,
// },
//}
}
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
@@ -204,6 +231,34 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
})
}
} else if part.Type == dto.ContentTypeFile {
if part.GetFile().FileId != "" {
return nil, fmt.Errorf("only base64 file is supported in gemini")
}
format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
if err != nil {
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: format,
Data: base64String,
},
})
} else if part.Type == dto.ContentTypeInputAudio {
if part.GetInputAudio().Data == "" {
return nil, fmt.Errorf("only base64 audio is supported in gemini")
}
format, base64String, err := service.DecodeBase64FileData(part.GetInputAudio().Data)
if err != nil {
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: format,
Data: base64String,
},
})
}
}
@@ -229,6 +284,102 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
return &geminiRequest, nil
}
// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
func cleanFunctionParameters(params interface{}) interface{} {
if params == nil {
return nil
}
paramMap, ok := params.(map[string]interface{})
if !ok {
// Not a map, return as is (e.g., could be an array or primitive)
return params
}
// Create a copy to avoid modifying the original
cleanedMap := make(map[string]interface{})
for k, v := range paramMap {
cleanedMap[k] = v
}
// Remove unsupported root-level fields
delete(cleanedMap, "default")
delete(cleanedMap, "exclusiveMaximum")
delete(cleanedMap, "exclusiveMinimum")
delete(cleanedMap, "$schema")
delete(cleanedMap, "additionalProperties")
// Clean properties
if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
cleanedProps := make(map[string]interface{})
for propName, propValue := range props {
propMap, ok := propValue.(map[string]interface{})
if !ok {
cleanedProps[propName] = propValue // Keep non-map properties
continue
}
// Create a copy of the property map
cleanedPropMap := make(map[string]interface{})
for k, v := range propMap {
cleanedPropMap[k] = v
}
// Remove unsupported fields
delete(cleanedPropMap, "default")
delete(cleanedPropMap, "exclusiveMaximum")
delete(cleanedPropMap, "exclusiveMinimum")
delete(cleanedPropMap, "$schema")
delete(cleanedPropMap, "additionalProperties")
// Check and clean 'format' for string types
if propType, typeExists := cleanedPropMap["type"].(string); typeExists && propType == "string" {
if formatValue, formatExists := cleanedPropMap["format"].(string); formatExists {
if formatValue != "enum" && formatValue != "date-time" {
delete(cleanedPropMap, "format")
}
}
}
// Recursively clean nested properties within this property if it's an object/array
// Check the type before recursing
if propType, typeExists := cleanedPropMap["type"].(string); typeExists && (propType == "object" || propType == "array") {
cleanedProps[propName] = cleanFunctionParameters(cleanedPropMap)
} else {
cleanedProps[propName] = cleanedPropMap // Assign the cleaned map back if not recursing
}
}
cleanedMap["properties"] = cleanedProps
}
// Recursively clean items in arrays if needed (e.g., type: array, items: { ... })
if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
cleanedMap["items"] = cleanFunctionParameters(items)
}
// Also handle items if it's an array of schemas
if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
cleanedItemsArray := make([]interface{}, len(itemsArray))
for i, item := range itemsArray {
cleanedItemsArray[i] = cleanFunctionParameters(item)
}
cleanedMap["items"] = cleanedItemsArray
}
// Recursively clean other schema composition keywords if necessary
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
if nested, ok := cleanedMap[field].([]interface{}); ok {
cleanedNested := make([]interface{}, len(nested))
for i, item := range nested {
cleanedNested[i] = cleanFunctionParameters(item)
}
cleanedMap[field] = cleanedNested
}
}
return cleanedMap
}
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
if depth >= 5 {
return schema
@@ -240,6 +391,7 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac
}
// 删除所有的title字段
delete(v, "title")
delete(v, "$schema")
// 如果type不为object和array则直接返回
if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
return schema
@@ -387,6 +539,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
if call := getResponseToolCall(&part); call != nil {
toolCalls = append(toolCalls, *call)
}
} else if part.Thought {
choice.Message.ReasoningContent = part.Text
} else {
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
@@ -404,7 +558,6 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
choice.Message.SetToolCalls(toolCalls)
isToolCall = true
}
choice.Message.SetStringContent(strings.Join(texts, "\n"))
}
@@ -427,9 +580,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
isStop := false
hasImage := false
for _, candidate := range geminiResponse.Candidates {
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
isStop = true
@@ -443,6 +597,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
}
var texts []string
isTools := false
isThought := false
if candidate.FinishReason != nil {
// p := GeminiConvertFinishReason(*candidate.FinishReason)
switch *candidate.FinishReason {
@@ -455,12 +610,21 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
}
}
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
if part.InlineData != nil {
if strings.HasPrefix(part.InlineData.MimeType, "image") {
imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
texts = append(texts, imgText)
hasImage = true
}
} else if part.FunctionCall != nil {
isTools = true
if call := getResponseToolCall(&part); call != nil {
call.SetIndex(len(choice.Delta.ToolCalls))
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
}
} else if part.Thought {
isThought = true
texts = append(texts, part.Text)
} else {
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
@@ -473,7 +637,11 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
}
}
}
choice.Delta.SetContentString(strings.Join(texts, "\n"))
if isThought {
choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
} else {
choice.Delta.SetContentString(strings.Join(texts, "\n"))
}
if isTools {
choice.FinishReason = &constant.FinishReasonToolCalls
}
@@ -483,7 +651,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Choices = choices
return &response, isStop
return &response, isStop, hasImage
}
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -491,23 +659,28 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
var imageCount int
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse
err := json.Unmarshal([]byte(data), &geminiResponse)
err := common.DecodeJsonStr(data, &geminiResponse)
if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error())
return false
}
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
if hasImage {
imageCount++
}
response.Id = id
response.Created = createAt
response.Model = info.UpstreamModelName
// responseText += response.Choices[0].Delta.GetContentString()
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
}
err = helper.ObjectData(c, response)
if err != nil {
@@ -522,9 +695,14 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
var response *dto.ChatCompletionsStreamResponse
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
if imageCount != 0 {
if usage.CompletionTokens == 0 {
usage.CompletionTokens = imageCount * 258
}
}
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
if info.ShouldIncludeUsage {
response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
@@ -547,8 +725,11 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if common.DebugEnabled {
println(string(responseBody))
}
var geminiResponse GeminiChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
err = common.DecodeJson(responseBody, &geminiResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
@@ -570,6 +751,10 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
}
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {

View File

@@ -3,7 +3,6 @@ package jina
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
@@ -12,6 +11,8 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/common_handler"
"one-api/relay/constant"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -55,6 +56,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
return request, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -2,13 +2,14 @@ package mistral
import (
"errors"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -59,6 +60,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -3,7 +3,6 @@ package mokaai
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
@@ -11,6 +10,8 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -74,6 +75,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -2,7 +2,6 @@ package ollama
import (
"errors"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
@@ -10,6 +9,8 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -64,6 +65,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return requestOpenAI2Embeddings(request), nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"mime/multipart"
"net/http"
"net/textproto"
"one-api/common"
constant2 "one-api/constant"
"one-api/dto"
@@ -22,6 +23,7 @@ import (
"one-api/relay/common_handler"
"one-api/relay/constant"
"one-api/service"
"path/filepath"
"strings"
"github.com/gin-gonic/gin"
@@ -36,7 +38,7 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
if !strings.Contains(request.Model, "claude") {
return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
}
aiRequest, err := service.ClaudeToOpenAIRequest(*request)
aiRequest, err := service.ClaudeToOpenAIRequest(*request, info)
if err != nil {
return nil, err
}
@@ -87,7 +89,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := info.UpstreamModelName
model_ = strings.Replace(model_, ".", "", -1)
// 2025年5月10日后创建的渠道不移除.
if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
model_ = strings.Replace(model_, ".", "", -1)
}
// https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
if info.RelayMode == constant.RelayModeRealtime {
@@ -147,14 +152,12 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
request.StreamOptions = nil
}
if strings.HasPrefix(request.Model, "o1") || strings.HasPrefix(request.Model, "o3") {
if strings.HasPrefix(request.Model, "o") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
if strings.HasPrefix(request.Model, "o3") || strings.HasPrefix(request.Model, "o1") {
request.Temperature = nil
}
request.Temperature = nil
if strings.HasSuffix(request.Model, "-high") {
request.ReasoningEffort = "high"
request.Model = strings.TrimSuffix(request.Model, "-high")
@@ -167,11 +170,13 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
}
info.ReasoningEffort = request.ReasoningEffort
info.UpstreamModelName = request.Model
}
if request.Model == "o1" || request.Model == "o1-2024-12-17" || strings.HasPrefix(request.Model, "o3") {
//修改第一个Message的内容将system改为developer
if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
request.Messages[0].Role = "developer"
// o系列模型developer适配o1-mini除外
if !strings.HasPrefix(request.Model, "o1-mini") && !strings.HasPrefix(request.Model, "o1-preview") {
//修改第一个Message的内容将system改为developer
if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
request.Messages[0].Role = "developer"
}
}
}
@@ -236,11 +241,167 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
switch info.RelayMode {
case constant.RelayModeImagesEdits:
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
writer.WriteField("model", request.Model)
// 获取所有表单字段
formData := c.Request.PostForm
// 遍历表单字段并打印输出
for key, values := range formData {
if key == "model" {
continue
}
for _, value := range values {
writer.WriteField(key, value)
}
}
// Parse the multipart form to handle both single image and multiple images
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
return nil, errors.New("failed to parse multipart form")
}
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
// Check if "image" field exists in any form, including array notation
var imageFiles []*multipart.FileHeader
var exists bool
// First check for standard "image" field
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
// If not found, check for "image[]" field
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
// If still not found, iterate through all fields to find any that start with "image["
foundArrayImages := false
for fieldName, files := range c.Request.MultipartForm.File {
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
foundArrayImages = true
for _, file := range files {
imageFiles = append(imageFiles, file)
}
}
}
// If no image fields found at all
if !foundArrayImages && (len(imageFiles) == 0) {
return nil, errors.New("image is required")
}
}
}
// Process all image files
for i, fileHeader := range imageFiles {
file, err := fileHeader.Open()
if err != nil {
return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
}
defer file.Close()
// If multiple images, use image[] as the field name
fieldName := "image"
if len(imageFiles) > 1 {
fieldName = "image[]"
}
// Determine MIME type based on file extension
mimeType := detectImageMimeType(fileHeader.Filename)
// Create a form file with the appropriate content type
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
h.Set("Content-Type", mimeType)
part, err := writer.CreatePart(h)
if err != nil {
return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
}
if _, err := io.Copy(part, file); err != nil {
return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
}
}
// Handle mask file if present
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
maskFile, err := maskFiles[0].Open()
if err != nil {
return nil, errors.New("failed to open mask file")
}
defer maskFile.Close()
// Determine MIME type for mask file
mimeType := detectImageMimeType(maskFiles[0].Filename)
// Create a form file with the appropriate content type
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
h.Set("Content-Type", mimeType)
maskPart, err := writer.CreatePart(h)
if err != nil {
return nil, errors.New("create form file failed for mask")
}
if _, err := io.Copy(maskPart, maskFile); err != nil {
return nil, errors.New("copy mask file failed")
}
}
} else {
return nil, errors.New("no multipart form data found")
}
// 关闭 multipart 编写器以设置分界线
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return bytes.NewReader(requestBody.Bytes()), nil
default:
return request, nil
}
}
// detectImageMimeType determines the MIME type based on the file extension
func detectImageMimeType(filename string) string {
ext := strings.ToLower(filepath.Ext(filename))
switch ext {
case ".jpg", ".jpeg":
return "image/jpeg"
case ".png":
return "image/png"
case ".webp":
return "image/webp"
default:
// Try to detect from extension if possible
if strings.HasPrefix(ext, ".jp") {
return "image/jpeg"
}
// Default to png as a fallback
return "image/png"
}
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// 模型后缀转换 reasoning effort
if strings.HasSuffix(request.Model, "-high") {
request.Reasoning.Effort = "high"
request.Model = strings.TrimSuffix(request.Model, "-high")
} else if strings.HasSuffix(request.Model, "-low") {
request.Reasoning.Effort = "low"
request.Model = strings.TrimSuffix(request.Model, "-low")
} else if strings.HasSuffix(request.Model, "-medium") {
request.Reasoning.Effort = "medium"
request.Model = strings.TrimSuffix(request.Model, "-medium")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
if info.RelayMode == constant.RelayModeAudioTranscription ||
info.RelayMode == constant.RelayModeAudioTranslation ||
info.RelayMode == constant.RelayModeImagesEdits {
return channel.DoFormRequest(a, c, info, requestBody)
} else if info.RelayMode == constant.RelayModeRealtime {
return channel.DoWssRequest(a, c, info, requestBody)
@@ -259,10 +420,16 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
fallthrough
case constant.RelayModeAudioTranscription:
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
case constant.RelayModeImagesGenerations:
err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
err, usage = OpenaiHandlerWithUsage(c, resp, info)
case constant.RelayModeRerank:
err, usage = common_handler.RerankHandler(c, info, resp)
case constant.RelayModeResponses:
if info.IsStream {
err, usage = OaiResponsesStreamHandler(c, resp, info)
} else {
err, usage = OaiResponsesHandler(c, resp, info)
}
default:
if info.IsStream {
err, usage = OaiStreamHandler(c, resp, info)

View File

@@ -31,6 +31,9 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
return err
}
if streamResponse.Usage != nil {
info.ClaudeConvertInfo.Usage = streamResponse.Usage
}
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
@@ -38,12 +41,7 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
return nil
}
func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
@@ -78,7 +76,11 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
common.SysError("error processing stream response: " + err.Error())
}
}
@@ -170,15 +172,14 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
helper.Done(c)
case relaycommon.RelayFormatClaude:
info.ClaudeConvertInfo.Done = true
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
if !containStreamUsage {
streamResponse.Usage = usage
}
info.ClaudeConvertInfo.Usage = usage
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
@@ -186,3 +187,10 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
}
}
}
func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
if data == "" {
return
}
helper.ResponseChunkData(c, streamResponse, data)
}

View File

@@ -117,6 +117,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
model := info.UpstreamModelName
var responseTextBuilder strings.Builder
var toolCount int
var usage = &dto.Usage{}
var streamItems []string // store stream items
var forceFormat bool
@@ -130,8 +131,6 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
thinkToContent = think2Content
}
toolCount := 0
var (
lastStreamData string
)
@@ -142,7 +141,6 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if err != nil {
common.SysError("error handling stream format: " + err.Error())
}
info.SetFirstResponseTime()
}
lastStreamData = data
streamItems = append(streamItems, data)
@@ -170,8 +168,10 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
}
if shouldSendLastResp {
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
//err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
}
// 处理token计算
@@ -215,10 +215,35 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
StatusCode: resp.StatusCode,
}, nil
}
forceFormat := false
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
forceFormat = forceFmt
}
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: completionTokens,
TotalTokens: info.PromptTokens + completionTokens,
}
}
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
break
if forceFormat {
responseBody, err = json.Marshal(simpleResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
} else {
break
}
case relaycommon.RelayFormatClaude:
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
claudeRespStr, err := json.Marshal(claudeResp)
@@ -244,18 +269,6 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
common.SysError("error copying response body: " + err.Error())
}
resp.Body.Close()
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: completionTokens,
TotalTokens: info.PromptTokens + completionTokens,
}
}
return nil, &simpleResponse.Usage
}
@@ -595,3 +608,52 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
err := service.PreWssConsumeQuota(ctx, info, usage)
return err
}
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
// reset content length
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(responseBody)))
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var usageResp dto.SimpleResponse
err = json.Unmarshal(responseBody, &usageResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
}
// format
if usageResp.InputTokens > 0 {
usageResp.PromptTokens += usageResp.InputTokens
}
if usageResp.OutputTokens > 0 {
usageResp.CompletionTokens += usageResp.OutputTokens
}
if usageResp.InputTokensDetails != nil {
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
}
return nil, &usageResp.Usage
}

View File

@@ -0,0 +1,119 @@
package openai
import (
"bytes"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
// read response body
var responsesResponse dto.OpenAIResponsesResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = common.DecodeJson(responseBody, &responsesResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if responsesResponse.Error != nil {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: responsesResponse.Error.Message,
Type: "openai_error",
Code: responsesResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
// reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
// copy response body
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
common.SysError("error copying response body: " + err.Error())
}
resp.Body.Close()
// compute usage
usage := dto.Usage{}
usage.PromptTokens = responsesResponse.Usage.InputTokens
usage.CompletionTokens = responsesResponse.Usage.OutputTokens
usage.TotalTokens = responsesResponse.Usage.TotalTokens
// 解析 Tools 用量
for _, tool := range responsesResponse.Tools {
info.ResponsesUsageInfo.BuiltInTools[tool.Type].CallCount++
}
return nil, &usage
}
func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body")
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
}
var usage = &dto.Usage{}
var responseTextBuilder strings.Builder
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
// 检查当前数据是否包含 completed 状态和 usage 信息
var streamResponse dto.ResponsesStreamResponse
if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
sendResponsesStreamData(c, streamResponse, data)
switch streamResponse.Type {
case "response.completed":
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
case "response.output_text.delta":
// 处理输出文本
responseTextBuilder.WriteString(streamResponse.Delta)
case dto.ResponsesOutputTypeItemDone:
// 函数调用处理
if streamResponse.Item != nil {
switch streamResponse.Item.Type {
case dto.BuildInCallWebSearchCall:
info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview].CallCount++
}
}
}
}
return true
})
if usage.CompletionTokens == 0 {
// 计算输出文本的 token 数量
tempStr := responseTextBuilder.String()
if len(tempStr) > 0 {
// 非正常结束,使用输出文本的 token 数量
completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
usage.CompletionTokens = completionTokens
}
}
return nil, usage
}

View File

@@ -3,13 +3,14 @@ package palm
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -60,6 +61,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -3,13 +3,14 @@ package perplexity
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -63,6 +64,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -3,7 +3,6 @@ package siliconflow
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
@@ -11,6 +10,8 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -58,6 +59,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
return request, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
@@ -74,13 +80,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
switch info.RelayMode {
case constant.RelayModeRerank:
err, usage = siliconflowRerankHandler(c, resp)
case constant.RelayModeChatCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info)
}
case constant.RelayModeCompletions:
fallthrough
case constant.RelayModeChatCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {

View File

@@ -3,7 +3,6 @@ package tencent
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -13,6 +12,8 @@ import (
"one-api/service"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -84,6 +85,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
@@ -13,7 +12,10 @@ import (
"one-api/relay/channel/gemini"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
"strings"
"github.com/gin-gonic/gin"
)
const (
@@ -77,6 +79,15 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
a.AccountCredentials = *adc
suffix := ""
if a.RequestMode == RequestModeGemini {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// suffix -thinking and -nothinking
if strings.HasSuffix(info.OriginModelName, "-thinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
}
}
if info.IsStream {
suffix = "streamGenerateContent?alt=sse"
} else {
@@ -143,7 +154,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
info.UpstreamModelName = claudeReq.Model
return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini {
geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info)
if err != nil {
return nil, err
}
@@ -164,6 +175,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -3,7 +3,6 @@ package volcengine
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
@@ -12,6 +11,8 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -71,6 +72,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return request, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -0,0 +1,118 @@
package xai
import (
"errors"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"strings"
"one-api/relay/constant"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
//panic("implement me")
return nil, errors.New("not available")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//not available
return nil, errors.New("not available")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
xaiRequest := ImageRequest{
Model: request.Model,
Prompt: request.Prompt,
N: request.N,
ResponseFormat: request.ResponseFormat,
}
return xaiRequest, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if strings.HasPrefix(request.Model, "grok-3-mini") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
if strings.HasSuffix(request.Model, "-high") {
request.ReasoningEffort = "high"
request.Model = strings.TrimSuffix(request.Model, "-high")
} else if strings.HasSuffix(request.Model, "-low") {
request.ReasoningEffort = "low"
request.Model = strings.TrimSuffix(request.Model, "-low")
} else if strings.HasSuffix(request.Model, "-medium") {
request.ReasoningEffort = "medium"
request.Model = strings.TrimSuffix(request.Model, "-medium")
}
info.ReasoningEffort = request.ReasoningEffort
info.UpstreamModelName = request.Model
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//not available
return nil, errors.New("not available")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
default:
if info.IsStream {
err, usage = xAIStreamHandler(c, resp, info)
} else {
err, usage = xAIHandler(c, resp, info)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,18 @@
package xai
var ModelList = []string{
// grok-3
"grok-3-beta", "grok-3-mini-beta",
// grok-3 mini
"grok-3-fast-beta", "grok-3-mini-fast-beta",
// extend grok-3-mini reasoning
"grok-3-mini-beta-high", "grok-3-mini-beta-low", "grok-3-mini-beta-medium",
"grok-3-mini-fast-beta-high", "grok-3-mini-fast-beta-low", "grok-3-mini-fast-beta-medium",
// image model
"grok-2-image",
// legacy models
"grok-2", "grok-2-vision",
"grok-beta", "grok-vision-beta",
}
var ChannelName = "xai"

27
relay/channel/xai/dto.go Normal file
View File

@@ -0,0 +1,27 @@
package xai
import "one-api/dto"
// ChatCompletionResponse represents the response from XAI chat completion API
type ChatCompletionResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []dto.ChatCompletionsStreamResponseChoice
Usage *dto.Usage `json:"usage"`
SystemFingerprint string `json:"system_fingerprint"`
}
// quality, size or style are not supported by xAI API at the moment.
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
// Size string `json:"size,omitempty"`
// Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
// Style string `json:"style,omitempty"`
// User string `json:"user,omitempty"`
// ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
}

119
relay/channel/xai/text.go Normal file
View File

@@ -0,0 +1,119 @@
package xai
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
)
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
if xAIResp == nil {
return nil
}
if xAIResp.Usage != nil {
xAIResp.Usage.CompletionTokens = usage.CompletionTokens
}
openAIResp := &dto.ChatCompletionsStreamResponse{
Id: xAIResp.Id,
Object: xAIResp.Object,
Created: xAIResp.Created,
Model: xAIResp.Model,
Choices: xAIResp.Choices,
Usage: xAIResp.Usage,
}
return openAIResp
}
func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
usage := &dto.Usage{}
var responseTextBuilder strings.Builder
var toolCount int
var containStreamUsage bool
helper.SetEventStreamHeaders(c)
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var xAIResp *dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &xAIResp)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
// 把 xAI 的usage转换为 OpenAI 的usage
if xAIResp.Usage != nil {
containStreamUsage = true
usage.PromptTokens = xAIResp.Usage.PromptTokens
usage.TotalTokens = xAIResp.Usage.TotalTokens
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
}
openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
err = helper.ObjectData(c, openaiResponse)
if err != nil {
common.SysError(err.Error())
}
return true
})
if !containStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
helper.Done(c)
err := resp.Body.Close()
if err != nil {
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
common.SysError("close_response_body_failed: " + err.Error())
}
return nil, usage
}
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
var response *dto.TextResponse
err = common.DecodeJson(responseBody, &response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return nil, nil
}
response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
// new body
encodeJson, err := common.EncodeJson(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return nil, nil
}
// set new body
resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &response.Usage
}

View File

@@ -2,7 +2,6 @@ package xunfei
import (
"errors"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
@@ -10,6 +9,8 @@ import (
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -61,6 +62,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
// xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{}

View File

@@ -3,12 +3,13 @@ package zhipu
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -71,6 +72,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = zhipuStreamHandler(c, resp)

View File

@@ -3,13 +3,15 @@ package zhipu_4v
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -35,7 +37,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl)
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return fmt.Sprintf("%s/embeddings", baseUrl), nil
default:
return fmt.Sprintf("%s/chat/completions", baseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -60,7 +68,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return request, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}

View File

@@ -1,17 +1,9 @@
package zhipu_4v
import (
"bufio"
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strings"
"sync"
"time"
@@ -119,163 +111,3 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
ToolChoice: request.ToolChoice,
}
}
//func responseZhipu2OpenAI(response *dto.OpenAITextResponse) *dto.OpenAITextResponse {
// fullTextResponse := dto.OpenAITextResponse{
// Id: response.Id,
// Object: "chat.completion",
// Created: common.GetTimestamp(),
// Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.TextResponseChoices)),
// Usage: response.Usage,
// }
// for i, choice := range response.TextResponseChoices {
// content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
// openaiChoice := dto.OpenAITextResponseChoice{
// Index: i,
// Message: dto.Message{
// Role: choice.Role,
// Content: content,
// },
// FinishReason: "",
// }
// if i == len(response.TextResponseChoices)-1 {
// openaiChoice.FinishReason = "stop"
// }
// fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
// }
// return &fullTextResponse
//}
func streamResponseZhipu2OpenAI(zhipuResponse *ZhipuV4StreamResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = zhipuResponse.Choices[0].Delta.Content
choice.Delta.Role = zhipuResponse.Choices[0].Delta.Role
choice.Delta.ToolCalls = zhipuResponse.Choices[0].Delta.ToolCalls
choice.Index = zhipuResponse.Choices[0].Index
choice.FinishReason = zhipuResponse.Choices[0].FinishReason
response := dto.ChatCompletionsStreamResponse{
Id: zhipuResponse.Id,
Object: "chat.completion.chunk",
Created: zhipuResponse.Created,
Model: "glm-4v",
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func lastStreamResponseZhipuV42OpenAI(zhipuResponse *ZhipuV4StreamResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
response := streamResponseZhipu2OpenAI(zhipuResponse)
return response, &zhipuResponse.Usage
}
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage *dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
dataChan <- data
}
stopChan <- true
}()
helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var streamResponse ZhipuV4StreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
}
var response *dto.ChatCompletionsStreamResponse
if strings.Contains(data, "prompt_tokens") {
response, usage = lastStreamResponseZhipuV42OpenAI(&streamResponse)
} else {
response = streamResponseZhipu2OpenAI(&streamResponse)
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage
}
func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var textResponse ZhipuV4Response
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the HTTPClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &textResponse.Usage
}

View File

@@ -19,18 +19,24 @@ type ThinkingContentInfo struct {
}
const (
LastMessageTypeText = "text"
LastMessageTypeTools = "tools"
LastMessageTypeNone = "none"
LastMessageTypeText = "text"
LastMessageTypeTools = "tools"
LastMessageTypeThinking = "thinking"
)
type ClaudeConvertInfo struct {
LastMessagesType string
Index int
Usage *dto.Usage
FinishReason string
Done bool
}
const (
RelayFormatOpenAI = "openai"
RelayFormatClaude = "claude"
RelayFormatGemini = "gemini"
)
type RerankerInfo struct {
@@ -38,6 +44,16 @@ type RerankerInfo struct {
ReturnDocuments bool
}
type BuildInToolInfo struct {
ToolName string
CallCount int
SearchContextSize string
}
type ResponsesUsageInfo struct {
BuiltInTools map[string]*BuildInToolInfo
}
type RelayInfo struct {
ChannelType int
ChannelId int
@@ -82,9 +98,11 @@ type RelayInfo struct {
UserQuota int
RelayFormat string
SendResponseCount int
ChannelCreateTime int64
ThinkingContentInfo
ClaudeConvertInfo
*ClaudeConvertInfo
*RerankerInfo
*ResponsesUsageInfo
}
// 定义支持流式选项的通道类型
@@ -97,6 +115,9 @@ var streamSupportedChannels = map[int]bool{
common.ChannelTypeAzure: true,
common.ChannelTypeVolcEngine: true,
common.ChannelTypeOllama: true,
common.ChannelTypeXai: true,
common.ChannelTypeDeepSeek: true,
common.ChannelTypeBaiduV2: true,
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
@@ -112,8 +133,8 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatClaude
info.ShouldIncludeUsage = false
info.ClaudeConvertInfo = ClaudeConvertInfo{
LastMessagesType: LastMessageTypeText,
info.ClaudeConvertInfo = &ClaudeConvertInfo{
LastMessagesType: LastMessageTypeNone,
}
return info
}
@@ -128,6 +149,31 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
return info
}
func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeResponses
info.ResponsesUsageInfo = &ResponsesUsageInfo{
BuiltInTools: make(map[string]*BuildInToolInfo),
}
if len(req.Tools) > 0 {
for _, tool := range req.Tools {
info.ResponsesUsageInfo.BuiltInTools[tool.Type] = &BuildInToolInfo{
ToolName: tool.Type,
CallCount: 0,
}
switch tool.Type {
case dto.BuildInToolWebSearchPreview:
if tool.SearchContextSize == "" {
tool.SearchContextSize = "medium"
}
info.ResponsesUsageInfo.BuiltInTools[tool.Type].SearchContextSize = tool.SearchContextSize
}
}
}
info.IsStream = req.Stream
return info
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
@@ -164,14 +210,15 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
OriginModelName: c.GetString("original_model"),
UpstreamModelName: c.GetString("original_model"),
//RecodeModelName: c.GetString("original_model"),
IsModelMapped: false,
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting,
ParamOverride: paramOverride,
RelayFormat: RelayFormatOpenAI,
IsModelMapped: false,
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting,
ChannelCreateTime: c.GetInt64("channel_create_time"),
ParamOverride: paramOverride,
RelayFormat: RelayFormatOpenAI,
ThinkingContentInfo: ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
@@ -194,6 +241,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if streamSupportedChannels[info.ChannelType] {
info.SupportStreamOptions = true
}
// responses 模式不支持 StreamOptions
if relayconstant.RelayModeResponses == info.RelayMode {
info.SupportStreamOptions = false
}
return info
}
@@ -212,6 +263,10 @@ func (info *RelayInfo) SetFirstResponseTime() {
}
}
func (info *RelayInfo) HasSendResponse() bool {
return info.FirstResponseTime.After(info.StartTime)
}
type TaskRelayInfo struct {
*RelayInfo
Action string

View File

@@ -32,6 +32,8 @@ const (
APITypeBaiduV2
APITypeOpenRouter
APITypeXinference
APITypeXai
APITypeCoze
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -92,6 +94,10 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeOpenRouter
case common.ChannelTypeXinference:
apiType = APITypeXinference
case common.ChannelTypeXai:
apiType = APITypeXai
case common.ChannelTypeCoze:
apiType = APITypeCoze
}
if apiType == -1 {
return APITypeOpenAI, false

View File

@@ -12,6 +12,7 @@ const (
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeImagesEdits
RelayModeEdits
RelayModeMidjourneyImagine
@@ -39,6 +40,8 @@ const (
RelayModeRerank
RelayModeResponses
RelayModeRealtime
)
@@ -56,8 +59,12 @@ func Path2RelayMode(path string) int {
relayMode = RelayModeModerations
} else if strings.HasPrefix(path, "/v1/images/generations") {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(path, "/v1/images/edits") {
relayMode = RelayModeImagesEdits
} else if strings.HasPrefix(path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(path, "/v1/responses") {
relayMode = RelayModeResponses
} else if strings.HasPrefix(path, "/v1/audio/speech") {
relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {

View File

@@ -12,11 +12,19 @@ import (
)
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
// 检查是否已经设置过头部
if _, exists := c.Get("event_stream_headers_set"); exists {
return
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
// 设置标志,表示头部已经设置过
c.Set("event_stream_headers_set", true)
}
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
@@ -43,6 +51,14 @@ func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
}
}
func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
}
func StringData(c *gin.Context, str string) error {
//str = strings.TrimPrefix(str, "data: ")
//str = strings.TrimSuffix(str, "\r")
@@ -55,7 +71,20 @@ func StringData(c *gin.Context, str string) error {
return nil
}
func PingData(c *gin.Context) error {
c.Writer.Write([]byte(": PING\n\n"))
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
func ObjectData(c *gin.Context, object interface{}) error {
if object == nil {
return errors.New("object is nil")
}
jsonData, err := json.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)

View File

@@ -2,9 +2,11 @@ package helper
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"one-api/relay/common"
"github.com/gin-gonic/gin"
)
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
@@ -16,9 +18,36 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
if err != nil {
return fmt.Errorf("unmarshal_model_mapping_failed")
}
if modelMap[info.OriginModelName] != "" {
info.UpstreamModelName = modelMap[info.OriginModelName]
info.IsModelMapped = true
// 支持链式模型重定向,最终使用链尾的模型
currentModel := info.OriginModelName
visitedModels := map[string]bool{
currentModel: true,
}
for {
if mappedModel, exists := modelMap[currentModel]; exists && mappedModel != "" {
// 模型重定向循环检测,避免无限循环
if visitedModels[mappedModel] {
if mappedModel == currentModel {
if currentModel == info.OriginModelName {
info.IsModelMapped = false
return nil
} else {
info.IsModelMapped = true
break
}
}
return errors.New("model_mapping_contains_cycle")
}
visitedModels[mappedModel] = true
currentModel = mappedModel
info.IsModelMapped = true
} else {
break
}
}
if info.IsModelMapped {
info.UpstreamModelName = currentModel
}
}
return nil

View File

@@ -15,14 +15,15 @@ type PriceData struct {
ModelRatio float64
CompletionRatio float64
CacheRatio float64
CacheCreationRatio float64
ImageRatio float64
GroupRatio float64
UsePrice bool
CacheCreationRatio float64
ShouldPreConsumedQuota int
}
func (p PriceData) ToSetting() string {
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota)
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
@@ -32,6 +33,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
var modelRatio float64
var completionRatio float64
var cacheRatio float64
var imageRatio float64
var cacheCreationRatio float64
if !usePrice {
preConsumedTokens := common.PreConsumedQuota
@@ -49,16 +51,13 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
}
}
if !acceptUnsetRatio {
if info.UserId == 1 {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置请设置或开始自用模式Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
} else {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
}
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置请联系管理员设置或开始自用模式Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
}
}
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
ratio := modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
@@ -72,6 +71,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
GroupRatio: groupRatio,
UsePrice: usePrice,
CacheRatio: cacheRatio,
ImageRatio: imageRatio,
CacheCreationRatio: cacheCreationRatio,
ShouldPreConsumedQuota: preConsumedQuota,
}
@@ -82,3 +82,15 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
return priceData, nil
}
func ContainPriceOrRatio(modelName string) bool {
_, ok := operation_setting.GetModelPrice(modelName, false)
if ok {
return true
}
_, ok = operation_setting.GetModelRatio(modelName)
if ok {
return true
}
return false
}

View File

@@ -8,37 +8,63 @@ import (
"one-api/common"
"one-api/constant"
relaycommon "one-api/relay/common"
"one-api/setting/operation_setting"
"strings"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
)
const (
InitialScannerBufferSize = 1 << 20 // 1MB (1*1024*1024)
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
DefaultPingInterval = 10 * time.Second
)
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
if resp == nil {
if resp == nil || dataHandler == nil {
return
}
defer resp.Body.Close()
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
if strings.HasPrefix(info.UpstreamModelName, "o") {
// twice timeout for thinking model
streamingTimeout *= 2
}
var (
stopChan = make(chan bool, 2)
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
stopChan = make(chan bool, 2)
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
pingTicker *time.Ticker
writeMutex sync.Mutex // Mutex to protect concurrent writes
)
generalSettings := operation_setting.GetGeneralSetting()
pingEnabled := generalSettings.PingIntervalEnabled
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
if pingInterval <= 0 {
pingInterval = DefaultPingInterval
}
if pingEnabled {
pingTicker = time.NewTicker(pingInterval)
}
defer func() {
ticker.Stop()
if pingTicker != nil {
pingTicker.Stop()
}
close(stopChan)
}()
scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
scanner.Split(bufio.ScanLines)
SetEventStreamHeaders(c)
@@ -46,6 +72,34 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
defer cancel()
ctx = context.WithValue(ctx, "stop_chan", stopChan)
// Handle ping data sending
if pingEnabled && pingTicker != nil {
gopool.Go(func() {
for {
select {
case <-pingTicker.C:
writeMutex.Lock() // Lock before writing
err := PingData(c)
writeMutex.Unlock() // Unlock after writing
if err != nil {
common.LogError(c, "ping data error: "+err.Error())
common.SafeSendBool(stopChan, true)
return
}
if common.DebugEnabled {
println("ping data sent")
}
case <-ctx.Done():
if common.DebugEnabled {
println("ping data goroutine stopped")
}
return
}
}
})
}
common.RelayCtxGo(ctx, func() {
for scanner.Scan() {
ticker.Reset(streamingTimeout)
@@ -62,10 +116,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
}
data = data[5:]
data = strings.TrimLeft(data, " ")
data = strings.TrimSuffix(data, "\"")
data = strings.TrimSuffix(data, "\r")
if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime()
writeMutex.Lock() // Lock before writing
success := dataHandler(data)
writeMutex.Unlock() // Unlock after writing
if !success {
break
}
@@ -85,7 +141,9 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
common.SafeSendBool(stopChan, true)
case <-stopChan:
// 正常结束
common.LogInfo(c, "streaming finished")
}
}

View File

@@ -5,21 +5,83 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"strings"
"github.com/gin-gonic/gin"
)
func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) {
imageRequest := &dto.ImageRequest{}
switch info.RelayMode {
case relayconstant.RelayModeImagesEdits:
_, err := c.MultipartForm()
if err != nil {
return nil, err
}
formData := c.Request.PostForm
imageRequest.Prompt = formData.Get("prompt")
imageRequest.Model = formData.Get("model")
imageRequest.N = common.String2Int(formData.Get("n"))
imageRequest.Quality = formData.Get("quality")
imageRequest.Size = formData.Get("size")
if imageRequest.Model == "gpt-image-1" {
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
}
default:
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
}
} else if imageRequest.Model == "dall-e-3" {
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
}
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
// N should between 1 and 10
//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
//}
}
}
if imageRequest.Prompt == "" {
return nil, errors.New("prompt is required")
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
}
if strings.Contains(imageRequest.Size, "×") {
return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
@@ -39,6 +101,10 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
}
// x.ai grok-2-image not support size, quality or style
if imageRequest.Size == "empty" {
imageRequest.Size = ""
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
@@ -86,43 +152,59 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
imageRequest.Model = relayInfo.UpstreamModelName
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
var preConsumedQuota int
var quota int
var userQuota int
if !priceData.UsePrice {
// modelRatio 16 = modelPrice $0.04
// per 1 modelRatio = $0.04 / 16
priceData.ModelPrice = 0.0025 * priceData.ModelRatio
}
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
sizeRatio := 1.0
// Size
if imageRequest.Size == "256x256" {
sizeRatio = 0.4
} else if imageRequest.Size == "512x512" {
sizeRatio = 0.45
} else if imageRequest.Size == "1024x1024" {
sizeRatio = 1
} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
sizeRatio = 2
}
qualityRatio := 1.0
if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
qualityRatio = 2.0
if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
qualityRatio = 1.5
// priceData.ModelPrice = 0.0025 * priceData.ModelRatio
var openaiErr *dto.OpenAIErrorWithStatusCode
preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
}
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
quota := int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
} else {
sizeRatio := 1.0
// Size
if imageRequest.Size == "256x256" {
sizeRatio = 0.4
} else if imageRequest.Size == "512x512" {
sizeRatio = 0.45
} else if imageRequest.Size == "1024x1024" {
sizeRatio = 1
} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
sizeRatio = 2
}
if userQuota-quota < 0 {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
qualityRatio := 1.0
if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
qualityRatio = 2.0
if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
qualityRatio = 1.5
}
}
// reset model price
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota-quota < 0 {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
}
}
adaptor := GetAdaptor(relayInfo.ApiType)
@@ -137,12 +219,15 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
requestBody = convertedRequest.(io.Reader)
} else {
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
}
requestBody = bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
@@ -162,24 +247,25 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
}
}
_, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
usage := &dto.Usage{
PromptTokens: imageRequest.N,
TotalTokens: imageRequest.N,
if usage.(*dto.Usage).TotalTokens == 0 {
usage.(*dto.Usage).TotalTokens = imageRequest.N
}
if usage.(*dto.Usage).PromptTokens == 0 {
usage.(*dto.Usage).PromptTokens = imageRequest.N
}
quality := "standard"
if imageRequest.Quality == "hd" {
quality = "hd"
}
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent)
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, logContent)
return nil
}

View File

@@ -32,7 +32,23 @@ func RelayMidjourneyImage(c *gin.Context) {
})
return
}
resp, err := http.Get(midjourneyTask.ImageUrl)
var httpClient *http.Client
if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil {
if proxy, ok := channel.GetSetting()["proxy"]; ok {
if proxyURL, ok := proxy.(string); ok && proxyURL != "" {
if httpClient, err = service.NewProxyHttpClient(proxyURL); err != nil {
c.JSON(400, gin.H{
"error": "proxy_url_invalid",
})
return
}
}
}
}
if httpClient == nil {
httpClient = service.GetHttpClient()
}
resp, err := httpClient.Get(midjourneyTask.ImageUrl)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "http_get_image_failed",

171
relay/relay-responses.go Normal file
View File

@@ -0,0 +1,171 @@
package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/setting/model_setting"
"strings"
"github.com/gin-gonic/gin"
)
func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
request := &dto.OpenAIResponsesRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
return nil, err
}
if request.Model == "" {
return nil, errors.New("model is required")
}
if len(request.Input) == 0 {
return nil, errors.New("input is required")
}
return request, nil
}
func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) {
sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input)
return sensitiveWords, err
}
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) {
inputTokens, err := service.CountTokenInput(req.Input, req.Model)
info.PromptTokens = inputTokens
return inputTokens, err
}
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
req, err := getAndValidateResponsesRequest(c)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest)
}
relayInfo := relaycommon.GenRelayInfoResponses(c, req)
if setting.ShouldCheckPromptSensitive() {
sensitiveWords, err := checkInputSensitive(req, relayInfo)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
}
}
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
}
req.Model = relayInfo.UpstreamModelName
if value, exists := c.Get("prompt_tokens"); exists {
promptTokens := value.(int)
relayInfo.SetPromptTokens(promptTokens)
} else {
promptTokens, err := getInputTokens(req, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
}
c.Set("prompt_tokens", promptTokens)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
// pre consume quota
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(relayInfo)
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_request_body_error", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_error", http.StatusBadRequest)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "marshal_request_error", http.StatusInternalServerError)
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
err = json.Unmarshal(jsonData, &reqMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError)
}
for key, value := range relayInfo.ParamOverride {
reqMap[key] = value
}
jsonData, err = json.Marshal(reqMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError)
}
}
if common.DebugEnabled {
println("requestBody: ", string(jsonData))
}
requestBody = bytes.NewBuffer(jsonData)
}
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
}
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
} else {
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
}
return nil
}

View File

@@ -18,6 +18,7 @@ import (
"one-api/service"
"one-api/setting"
"one-api/setting/model_setting"
"one-api/setting/operation_setting"
"strings"
"time"
@@ -193,6 +194,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
@@ -331,12 +333,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
cacheTokens := usage.PromptTokensDetails.CachedTokens
imageTokens := usage.PromptTokensDetails.ImageTokens
completionTokens := usage.CompletionTokens
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
completionRatio := priceData.CompletionRatio
cacheRatio := priceData.CacheRatio
imageRatio := priceData.ImageRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
modelPrice := priceData.ModelPrice
@@ -344,9 +348,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
// Convert values to decimal for precise calculation
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
dImageTokens := decimal.NewFromInt(int64(imageTokens))
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
dCompletionRatio := decimal.NewFromFloat(completionRatio)
dCacheRatio := decimal.NewFromFloat(cacheRatio)
dImageRatio := decimal.NewFromFloat(imageRatio)
dModelRatio := decimal.NewFromFloat(modelRatio)
dGroupRatio := decimal.NewFromFloat(groupRatio)
dModelPrice := decimal.NewFromFloat(modelPrice)
@@ -354,11 +360,46 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
ratio := dModelRatio.Mul(dGroupRatio)
// openai web search 工具计费
var dWebSearchQuota decimal.Decimal
var webSearchPrice float64
if relayInfo.ResponsesUsageInfo != nil {
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
// 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000 * 分组倍率)
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, webSearchTool.SearchContextSize)
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s调用花费 $%s",
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
}
}
// file search tool 计费
var dFileSearchQuota decimal.Decimal
var fileSearchPrice float64
if relayInfo.ResponsesUsageInfo != nil {
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
fileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 $%s",
fileSearchTool.CallCount, dFileSearchQuota.String())
}
}
var quotaCalculateDecimal decimal.Decimal
if !priceData.UsePrice {
nonCachedTokens := dPromptTokens.Sub(dCacheTokens)
cachedTokensWithRatio := dCacheTokens.Mul(dCacheRatio)
promptQuota := nonCachedTokens.Add(cachedTokensWithRatio)
if imageTokens > 0 {
nonImageTokens := dPromptTokens.Sub(dImageTokens)
imageTokensWithRatio := dImageTokens.Mul(dImageRatio)
promptQuota = nonImageTokens.Add(imageTokensWithRatio)
}
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
@@ -369,6 +410,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
} else {
quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
}
// 添加 responses tools call 调用的配额
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
quota := int(quotaCalculateDecimal.Round(0).IntPart())
totalTokens := promptTokens + completionTokens
@@ -414,6 +458,25 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
logContent += ", " + extraContent
}
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
if imageTokens != 0 {
other["image"] = true
other["image_ratio"] = imageRatio
other["image_output"] = imageTokens
}
if !dWebSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
other["web_search"] = true
other["web_search_call_count"] = webSearchTool.CallCount
other["web_search_price"] = webSearchPrice
}
}
if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
other["file_search"] = true
other["file_search_call_count"] = fileSearchTool.CallCount
other["file_search_price"] = fileSearchPrice
}
}
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}

View File

@@ -10,6 +10,7 @@ import (
"one-api/relay/channel/claude"
"one-api/relay/channel/cloudflare"
"one-api/relay/channel/cohere"
"one-api/relay/channel/coze"
"one-api/relay/channel/deepseek"
"one-api/relay/channel/dify"
"one-api/relay/channel/gemini"
@@ -25,6 +26,7 @@ import (
"one-api/relay/channel/tencent"
"one-api/relay/channel/vertex"
"one-api/relay/channel/volcengine"
"one-api/relay/channel/xai"
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
"one-api/relay/channel/zhipu_4v"
@@ -85,6 +87,10 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &openai.Adaptor{}
case constant.APITypeXinference:
return &openai.Adaptor{}
case constant.APITypeXai:
return &xai.Adaptor{}
case constant.APITypeCoze:
return &coze.Adaptor{}
}
return nil
}

View File

@@ -1,10 +1,11 @@
package router
import (
"github.com/gin-gonic/gin"
"one-api/controller"
"one-api/middleware"
"one-api/relay"
"github.com/gin-gonic/gin"
)
func SetRelayRouter(router *gin.Engine) {
@@ -40,13 +41,14 @@ func SetRelayRouter(router *gin.Engine) {
httpRouter.POST("/chat/completions", controller.Relay)
httpRouter.POST("/edits", controller.Relay)
httpRouter.POST("/images/generations", controller.Relay)
httpRouter.POST("/images/edits", controller.RelayNotImplemented)
httpRouter.POST("/images/edits", controller.Relay)
httpRouter.POST("/images/variations", controller.RelayNotImplemented)
httpRouter.POST("/embeddings", controller.Relay)
httpRouter.POST("/engines/:model/embeddings", controller.Relay)
httpRouter.POST("/audio/transcriptions", controller.Relay)
httpRouter.POST("/audio/translations", controller.Relay)
httpRouter.POST("/audio/speech", controller.Relay)
httpRouter.POST("/responses", controller.Relay)
httpRouter.GET("/files", controller.RelayNotImplemented)
httpRouter.POST("/files", controller.RelayNotImplemented)
httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)

View File

@@ -24,7 +24,7 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
if !setting.EnableWorker() {
return nil, fmt.Errorf("worker not enabled")
}
if !strings.HasPrefix(req.URL, "https") {
if !setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") {
return nil, fmt.Errorf("only support https url")
}

View File

@@ -6,9 +6,10 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"strings"
)
func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIRequest, error) {
func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
openAIRequest := dto.GeneralOpenAIRequest{
Model: claudeRequest.Model,
MaxTokens: claudeRequest.MaxTokens,
@@ -17,6 +18,13 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
Stream: claudeRequest.Stream,
}
if claudeRequest.Thinking != nil {
if strings.HasSuffix(info.OriginModelName, "-thinking") &&
!strings.HasSuffix(claudeRequest.Model, "-thinking") {
openAIRequest.Model = openAIRequest.Model + "-thinking"
}
}
// Convert stop sequences
if len(claudeRequest.StopSequences) == 1 {
openAIRequest.Stop = claudeRequest.StopSequences[0]
@@ -45,7 +53,7 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
// Add system message if present
if claudeRequest.System != nil {
if claudeRequest.IsStringSystem() {
if claudeRequest.IsStringSystem() && claudeRequest.GetStringSystem() != "" {
openAIMessage := dto.Message{
Role: "system",
}
@@ -59,7 +67,9 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
Role: "system",
}
for _, system := range systems {
systemStr += system.Type
if system.Text != nil {
systemStr += *system.Text
}
}
openAIMessage.SetStringContent(systemStr)
openAIMessages = append(openAIMessages, openAIMessage)
@@ -122,23 +132,22 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
} else {
mediaContents := mediaMsg.ParseMediaContent()
if len(mediaContents) > 0 && mediaContents[0].Text != nil {
oaiToolMessage.SetStringContent(*mediaContents[0].Text)
}
encodeJson, _ := common.EncodeJson(mediaContents)
oaiToolMessage.SetStringContent(string(encodeJson))
}
openAIMessages = append(openAIMessages, oaiToolMessage)
}
}
if len(mediaMessages) > 0 {
openAIMessage.SetMediaContent(mediaMessages)
}
if len(toolCalls) > 0 {
openAIMessage.SetToolCalls(toolCalls)
}
if len(mediaMessages) > 0 && len(toolCalls) == 0 {
openAIMessage.SetMediaContent(mediaMessages)
}
}
if len(openAIMessage.ParseContent()) > 0 {
if len(openAIMessage.ParseContent()) > 0 || len(openAIMessage.ToolCalls) > 0 {
openAIMessages = append(openAIMessages, openAIMessage)
}
}
@@ -211,15 +220,15 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
resp.SetIndex(0)
claudeResponses = append(claudeResponses, resp)
} else {
resp := &dto.ClaudeResponse{
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "text",
Text: common.GetPointer[string](""),
},
}
resp.SetIndex(0)
claudeResponses = append(claudeResponses, resp)
//resp := &dto.ClaudeResponse{
// Type: "content_block_start",
// ContentBlock: &dto.ClaudeMediaMessage{
// Type: "text",
// Text: common.GetPointer[string](""),
// },
//}
//resp.SetIndex(0)
//claudeResponses = append(claudeResponses, resp)
}
return claudeResponses
}
@@ -232,16 +241,20 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
chosenChoice := openAIResponse.Choices[0]
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
// should be done
info.FinishReason = *chosenChoice.FinishReason
return claudeResponses
}
if info.Done {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
if openAIResponse.Usage != nil {
if info.ClaudeConvertInfo.Usage != nil {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_delta",
Usage: &dto.ClaudeUsage{
InputTokens: openAIResponse.Usage.PromptTokens,
OutputTokens: openAIResponse.Usage.CompletionTokens,
InputTokens: info.ClaudeConvertInfo.Usage.PromptTokens,
OutputTokens: info.ClaudeConvertInfo.Usage.CompletionTokens,
},
Delta: &dto.ClaudeMediaMessage{
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(*chosenChoice.FinishReason)),
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
},
})
}
@@ -250,10 +263,10 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
})
} else {
var claudeResponse dto.ClaudeResponse
claudeResponse.SetIndex(0)
var isEmpty bool
claudeResponse.Type = "content_block_delta"
if len(chosenChoice.Delta.ToolCalls) > 0 {
if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeText {
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
info.ClaudeConvertInfo.Index++
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
@@ -274,15 +287,57 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
}
} else {
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
// text delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "text_delta",
Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()),
reasoning := chosenChoice.Delta.GetReasoningContent()
textContent := chosenChoice.Delta.GetContentString()
if reasoning != "" || textContent != "" {
if reasoning != "" {
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
//info.ClaudeConvertInfo.Index++
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "thinking",
Thinking: "",
},
})
}
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
// text delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "thinking_delta",
Thinking: reasoning,
}
} else {
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
if info.LastMessagesType == relaycommon.LastMessageTypeThinking || info.LastMessagesType == relaycommon.LastMessageTypeTools {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
info.ClaudeConvertInfo.Index++
}
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "text",
Text: common.GetPointer[string](""),
},
})
}
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
// text delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "text_delta",
Text: common.GetPointer[string](textContent),
}
}
} else {
isEmpty = true
}
}
claudeResponse.Index = &info.ClaudeConvertInfo.Index
claudeResponses = append(claudeResponses, &claudeResponse)
if !isEmpty {
claudeResponses = append(claudeResponses, &claudeResponse)
}
}
}

View File

@@ -8,9 +8,9 @@ import (
"one-api/dto"
)
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
resp, err := DoDownloadRequest(url)
if err != nil {
return nil, err
@@ -22,7 +22,6 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
if err != nil {
return nil, err
}
// Check actual size after reading
if len(fileBytes) > maxFileSize {
return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)

View File

@@ -3,12 +3,13 @@ package service
import (
"context"
"fmt"
"golang.org/x/net/proxy"
"net"
"net/http"
"net/url"
"one-api/common"
"time"
"golang.org/x/net/proxy"
)
var httpClient *http.Client
@@ -55,7 +56,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
},
}, nil
case "socks5":
case "socks5", "socks5h":
// 获取认证信息
var auth *proxy.Auth
if parsedURL.User != nil {
@@ -69,6 +70,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
}
// 创建 SOCKS5 代理拨号器
// proxy.SOCKS5 使用 tcp 参数,所有 TCP 连接包括 DNS 查询都将通过代理进行。行为与 socks5h 相同
dialer, err := proxy.SOCKS5("tcp", parsedURL.Host, auth, proxy.Direct)
if err != nil {
return nil, err

Some files were not shown because too many files have changed in this diff Show More