Compare commits

...

123 Commits

Author SHA1 Message Date
1808837298@qq.com
49bfd2b719 feat: Enhance mobile UI responsiveness and layout for ChannelsTable and SiderBar 2025-03-10 19:01:56 +08:00
1808837298@qq.com
b2938ffe2c refactor: Improve mobile responsiveness and scrolling behavior in UI layout 2025-03-10 15:49:32 +08:00
1808837298@qq.com
d9cf0885f1 refactor: Enhance UI layout and styling with responsive design improvements 2025-03-10 03:25:02 +08:00
1808837298@qq.com
3ed50787b3 style: Enhance LogsTable header tags with improved styling and visual hierarchy 2025-03-10 00:34:24 +08:00
1808837298@qq.com
97d948cdb1 refactor: Make Channel Setting nullable and improve setting handling #836 2025-03-09 23:42:48 +08:00
1808837298@qq.com
5017fabbfa fix: Correct typo in group_ratio variable name in LogsTable 2025-03-09 21:24:19 +08:00
1808837298@qq.com
bd5c261b99 fix: Add optional chaining to prevent potential undefined errors in LogsTable #833 2025-03-09 21:23:33 +08:00
1808837298@qq.com
00c2d6c102 feat: Introduce configurable docs link and remove hardcoded chat links
- Added a new GeneralSetting struct to manage configurable docs link
- Removed hardcoded ChatLink and ChatLink2 variables across multiple files
- Updated frontend components to dynamically render docs link from status
- Simplified chat and link-related logic in various components
- Added a warning modal for quota per unit setting in operation settings
2025-03-09 18:31:16 +08:00
1808837298@qq.com
4a8bb625b8 fix: Refine embedding model detection in channel test 2025-03-09 15:03:07 +08:00
1808837298@qq.com
db01994cd0 refactor: Improve price rendering with clearer token and price calculations 2025-03-08 23:47:02 +08:00
Calcium-Ion
a0ca3effa7 Merge pull request #830 from Calcium-Ion/decimal
feat: Improve decimal precision for quota and payment calculationsDecimal
2025-03-08 22:01:15 +08:00
1808837298@qq.com
5a10ebd384 refactor: Update topup amount type from int to int64 for improved precision 2025-03-08 21:59:18 +08:00
1808837298@qq.com
68097c132d feat: Improve decimal precision for quota and payment calculations
- Added github.com/shopspring/decimal for precise floating-point calculations
- Refactored quota and payment calculations in multiple files to use decimal arithmetic
- Updated go.mod and go.sum to include decimal library
- Improved precision in topup, relay, and quota service calculations
- Added support for more OpenAI model variants in cache ratio settings
2025-03-08 21:55:50 +08:00
Calcium-Ion
3352bacd35 Merge pull request #828 from Calcium-Ion/ui
feat: Add column visibility settings for Channels and Logs tables
2025-03-08 19:55:28 +08:00
1808837298@qq.com
7fcb14e25f feat: Add column visibility settings for Channels and Logs tables
- Implemented dynamic column visibility for ChannelsTable and LogsTable
- Added localStorage persistence for column preferences
- Introduced column selector modal with select all/reset functionality
- Supported role-based default column visibility
- Added column settings button to table interfaces
2025-03-08 19:53:07 +08:00
1808837298@qq.com
867187ab4d refactor: Simplify chat menu items rendering in SiderBar 2025-03-08 19:06:49 +08:00
1808837298@qq.com
3ad96d3b4e feat: update readme and i18n 2025-03-08 18:13:44 +08:00
Calcium-Ion
d9390ff4c3 Merge pull request #826 from Calcium-Ion/cache
feat: Add prompt cache hit tokens support for DeepSeek channel #406
2025-03-08 16:52:19 +08:00
1808837298@qq.com
8c209e2fb9 fix: Adjust DeepSeek cache ratio to 0.1 2025-03-08 16:51:43 +08:00
1808837298@qq.com
a9bfcb0daf feat: Add prompt cache hit tokens support for DeepSeek channel #406 2025-03-08 16:50:53 +08:00
1808837298@qq.com
bb848b2fe0 refactor: Improve quota calculation precision using floating-point arithmetic 2025-03-08 16:44:08 +08:00
Calcium-Ion
618908f6f8 Merge pull request #821 from Calcium-Ion/cache
chore: Update terminology from "cache ratio" to "cache multiplier" in UI and add placeholder for default create cache ratio
2025-03-08 02:49:21 +08:00
1808837298@qq.com
1f4ebddcfa fix: Update default cache ratio from 0.5 to 1 2025-03-08 02:47:41 +08:00
1808837298@qq.com
6d79d8993e chore: Update terminology from "cache ratio" to "cache multiplier" in UI and add placeholder for default create cache ratio 2025-03-08 02:44:09 +08:00
Calcium-Ion
7c03ad71de Merge pull request #820 from Calcium-Ion/cache
feat: Implement cache token ratio for more precise token pricing
2025-03-08 01:31:44 +08:00
1808837298@qq.com
4f194f4e6a feat: Implement cache token ratio for more precise token pricing 2025-03-08 01:30:50 +08:00
1808837298@qq.com
81137e0533 refactor: Remove redundant user quota retrieval in audio relay 2025-03-07 19:59:00 +08:00
Calcium-Ion
b9b66dda54 Merge pull request #815 from Sh1n3zZ/openrouter-adapter
fix: adapting return format for openrouter think content (#793)
2025-03-07 19:25:20 +08:00
1808837298@qq.com
fd22948ead refactor: Reorganize sidebar navigation and add personal settings route 2025-03-07 17:22:37 +08:00
Sh1n3zZ
894dce7366 fix: possible incomplete return of the think field and incorrect occurrences of the reasoning field 2025-03-06 19:20:29 +08:00
Sh1n3zZ
b95142bbac fix: adapting return format for openrouter think content (#793) 2025-03-06 19:16:26 +08:00
1808837298@qq.com
7f74a9664e feat: Enhance channel status update with success tracking and dynamic notification #812 2025-03-06 17:46:03 +08:00
1808837298@qq.com
a3739f67f7 fix: Handle error in NotifyRootUser and log system errors #812 2025-03-06 17:25:39 +08:00
1808837298@qq.com
b841ce006f refactor: Improve model request rate limit middleware execution 2025-03-06 16:32:11 +08:00
1808837298@qq.com
e3f9ef1894 fix: error NotifyRootUser #812 2025-03-06 15:56:42 +08:00
1808837298@qq.com
558e625a01 fix: Prevent resource leaks by adding body close in stream handlers 2025-03-05 19:51:22 +08:00
1808837298@qq.com
37a83ecc33 refactor: Centralize stream handling and helper functions in relay package 2025-03-05 19:47:41 +08:00
1808837298@qq.com
37bb34b4b0 Update README.md 2025-03-05 16:55:17 +08:00
1808837298@qq.com
8deab221f9 fix: vertex claude 2025-03-05 16:43:40 +08:00
1808837298@qq.com
17e9f1a07d fix: #810 2025-03-05 16:39:42 +08:00
1808837298@qq.com
792754cee3 fix: #810 2025-03-05 16:34:08 +08:00
1808837298@qq.com
98b27a17a6 refactor: Extract operation-related settings into a separate package 2025-03-04 18:52:08 +08:00
1808837298@qq.com
7855f83e2d Update README.md 2025-03-04 18:50:05 +08:00
1808837298@qq.com
cbdf26bf2c feat: Add context-aware goroutine pool for safer concurrent operations 2025-03-04 18:42:34 +08:00
1808837298@qq.com
eb46b71a71 fix: Ignore EOF errors in OpenAI stream scanner 2025-03-04 17:35:41 +08:00
1808837298@qq.com
a42c3b6227 Merge remote-tracking branch 'origin/main' 2025-03-04 17:11:07 +08:00
1808837298@qq.com
b00dd8b405 fix: Handle scanner errors in OpenAI relay stream handler 2025-03-04 17:10:56 +08:00
Calcium-Ion
be228ccd2c Merge pull request #805 from PaperPlaneDeemo/main
Fix: fix typo in README
2025-03-04 16:27:15 +08:00
1808837298@qq.com
b1be64bcf3 fix: vertex claude 2025-03-03 20:06:08 +08:00
1808837298@qq.com
6ecfb81cbc feat: Improve image download and validation in GetImageFromUrl 2025-03-03 16:15:04 +08:00
Nekof
14848ff789 Merge branch 'Calcium-Ion:main' into main 2025-03-03 11:37:40 +08:00
“Deemo”
47d3b515da fix: Typo in README 2025-03-03 11:35:04 +08:00
1808837298@qq.com
760514c3e1 fix: channel test model mapped 2025-03-02 23:53:10 +08:00
1808837298@qq.com
254c25c27a feat: yanjingxia 2025-03-02 23:17:37 +08:00
1808837298@qq.com
8731a32e56 feat: Add model testing modal with search functionality in ChannelsTable
- Implement a new modal for selecting and testing models per channel
- Add search functionality to filter models by keyword
- Replace dropdown with direct button for model testing
- Introduce new state variables for managing model test modal
2025-03-02 19:53:35 +08:00
1808837298@qq.com
7208a65e5d refactor: Add index to Username column in Log model 2025-03-02 17:57:52 +08:00
1808837298@qq.com
4084b18071 refactor: Update rate limit configuration to use dynamic expiration duration 2025-03-02 17:34:39 +08:00
1808837298@qq.com
2ca0d7246d fix: Use channel group in model testing log record 2025-03-02 15:59:39 +08:00
1808837298@qq.com
d042a1bd55 refactor: Improve channel testing and model price handling 2025-03-02 15:47:12 +08:00
1808837298@qq.com
816e831a2e feat: Persist models expanded state in PersonalSetting component 2025-03-02 01:35:50 +08:00
1808837298@qq.com
a3ceae4a86 feat: Enhance update checking and system information display
- Add version and startup time display in OtherSetting component
- Implement robust GitHub release update checking mechanism
- Add error handling for update check process
- Update Modal component for displaying update information
- Add new translations for version and system information
2025-03-02 01:31:27 +08:00
1808837298@qq.com
eb163d9c94 feat: Add self-use mode and demo site mode indicators to HeaderBar 2025-03-02 00:46:54 +08:00
1808837298@qq.com
a592a81bc2 fix: Correct option map key for PreConsumedQuota 2025-03-01 22:37:14 +08:00
1808837298@qq.com
bb300d199e feat: Add translations for self-use mode and demo site mode settings 2025-03-01 21:15:59 +08:00
1808837298@qq.com
7dbb6b017c feat: Add self-use mode for model ratio and price configuration
- Introduce `SelfUseModeEnabled` setting to allow flexible model ratio configuration
- Update error handling to provide more informative messages when model ratios are not set
- Modify pricing and relay logic to support self-use mode
- Add UI toggle for enabling self-use mode in operation settings
- Implement fallback mechanism for model ratios when self-use mode is enabled
2025-03-01 21:13:48 +08:00
1808837298@qq.com
ce1854847b fix: Enhance error message for missing model ratio configuration 2025-03-01 17:02:31 +08:00
1808837298@qq.com
2f9faba40d fix: Improve error handling for model ratio and price validation #800 2025-03-01 15:27:32 +08:00
1808837298@qq.com
a5085014cc fix: Improve model ratio and price management
- Update error message for missing model ratio to be more user-friendly
- Modify ModelRatioNotSetEditor to filter models without price or ratio
- Enhance model data initialization with fallback values
2025-02-28 23:28:47 +08:00
1808837298@qq.com
18d3706ff8 feat: Add new model management features
- Implement `/api/channel/models_enabled` endpoint to retrieve enabled models
- Add `EnabledListModels` handler in controller
- Create new `ModelRatioNotSetEditor` component for managing unset model ratios
- Update router to include new models_enabled route
- Add internationalization support for new model management UI
- Include GPT-4.5 preview model in OpenAI model list
2025-02-28 21:13:30 +08:00
1808837298@qq.com
152950497e fix 2025-02-28 20:28:44 +08:00
1808837298@qq.com
d6fd50e382 feat: add new GPT-4.5 preview model ratios 2025-02-28 19:17:15 +08:00
1808837298@qq.com
cfd3f6c073 feat: Enhance Claude default max tokens configuration
- Replace ThinkingAdapterMaxTokens with a more flexible DefaultMaxTokens map
- Add support for model-specific default max tokens configuration
- Update relay and web interface to use the new configuration approach
- Implement a fallback mechanism for default max tokens
2025-02-28 17:53:08 +08:00
1808837298@qq.com
45c56b5ded feat: Implement model-specific headers configuration for Claude 2025-02-28 16:47:31 +08:00
1808837298@qq.com
d306394f33 fix: Simplify Claude settings value conversion logic 2025-02-27 22:26:21 +08:00
1808837298@qq.com
cdba87a7da fix: Prevent duplicate headers in Claude settings 2025-02-27 22:14:53 +08:00
1808837298@qq.com
ae5b874a6c refactor: Reorganize Claude MaxTokens configuration UI layout 2025-02-27 22:12:14 +08:00
1808837298@qq.com
d0bc8d17d1 feat: Enhance Claude MaxTokens configuration handling
- Update Claude relay to set default MaxTokens dynamically
- Modify web interface to clarify default MaxTokens input purpose
- Improve token configuration logic for thinking adapter models
2025-02-27 22:10:29 +08:00
1808837298@qq.com
4784ca7514 fix: Update Claude thinking adapter token percentage input guidance 2025-02-27 20:59:32 +08:00
1808837298@qq.com
3a18c0ce9f fix: Correct model request configuration in Vertex Claude adaptor 2025-02-27 20:51:10 +08:00
1808837298@qq.com
929668bead feat: Refactor model configuration management with new config system
- Introduce a new configuration management approach for model-specific settings
- Update Gemini settings to use the new config system with more flexible management
- Add support for dynamic configuration updates in option handling
- Modify Claude and Vertex adaptors to use new configuration methods
- Enhance web interface to support namespaced configuration keys
2025-02-27 20:49:34 +08:00
1808837298@qq.com
06a78f9042 feat: Add Claude model configuration management #791 2025-02-27 20:49:21 +08:00
1808837298@qq.com
0f1c4c4ebe fix: Add pagination support to user search functionality 2025-02-27 16:55:02 +08:00
1808837298@qq.com
1bcf7a3c39 chore: Update Azure OpenAI API version and embedding model detection
- Enhance channel test to detect more embedding models
- Update Azure OpenAI default API version to 2024-12-01-preview
- Remove redundant default API version setting in channel edit
- Add user cache writing in channel test
2025-02-27 16:49:32 +08:00
1808837298@qq.com
5f0b3f6d6f fix: Improve AWS Claude adaptor request conversion error handling #796 2025-02-27 14:57:00 +08:00
1808837298@qq.com
19a318c943 init openrouter adaptor 2025-02-27 00:01:21 +08:00
1808837298@qq.com
13ab0f8e4f fix: gemini&claude tool call format #795 #766 2025-02-26 23:56:10 +08:00
1808837298@qq.com
6d8d40e67b fix: claude tool call format #795 #766 2025-02-26 23:40:16 +08:00
1808837298@qq.com
287caf8e38 feat: Add Jina reranking support for OpenAI adaptor 2025-02-26 21:46:06 +08:00
1808837298@qq.com
c802b3b41a fix: Update Gemini safety settings to use 'OFF' as default 2025-02-26 19:20:17 +08:00
1808837298@qq.com
ed4e1c2332 fix: Update Gemini safety settings category 2025-02-26 19:18:00 +08:00
1808837298@qq.com
e581ea33c2 fix: Update Gemini safety settings default value 2025-02-26 19:01:45 +08:00
1808837298@qq.com
bf80d71ddf feat: Add Gemini version settings configuration support (close #568) 2025-02-26 18:19:09 +08:00
1808837298@qq.com
e19b244e73 feat: Add Gemini safety settings configuration support (close #703) 2025-02-26 16:54:43 +08:00
1808837298@qq.com
f451268830 feat: Update Claude relay temperature setting 2025-02-25 22:01:05 +08:00
1808837298@qq.com
069f2672c1 refactor: Enhance user context and quota management
- Add new context keys for user-related information
- Modify user cache and authentication middleware to populate context
- Refactor quota and notification services to use context-based user data
- Remove redundant database queries by leveraging context information
- Update various components to use new context-based user retrieval methods
2025-02-25 20:56:16 +08:00
1808837298@qq.com
ccf13d445f feat: redis poolsize 2025-02-25 19:39:29 +08:00
1808837298@qq.com
da4d1861fe fix: Adjust Claude thinking mode request parameters 2025-02-25 16:52:45 +08:00
1808837298@qq.com
3de5b96cb4 docs: Update README 2025-02-25 16:31:42 +08:00
Calcium-Ion
5b9e275690 Merge pull request #788 from MartialBE/main
feat: Add Claude 3.7 Sonnet thinking mode support
2025-02-25 15:21:39 +08:00
1808837298@qq.com
607e3206b3 Merge branch 'main' into thinking
# Conflicts:
#	relay/channel/claude/dto.go
2025-02-25 15:21:22 +08:00
1808837298@qq.com
83feb492fb feat: Add support for Claude thinking parameter in request 2025-02-25 14:37:03 +08:00
MartialBE
4f212be45c feat: Add Claude 3.7 Sonnet thinking mode support 2025-02-25 14:10:43 +08:00
1808837298@qq.com
92918e3751 feat: Add Claude 3.7 Sonnet model to AWS channel mapping 2025-02-25 02:55:23 +08:00
1808837298@qq.com
de15551570 feat: Add support for Claude 3.7 Sonnet model 2025-02-25 02:51:31 +08:00
1808837298@qq.com
a81a28b7a5 feat: Support max_tokens parameter for Ollama channel #782 2025-02-24 17:35:49 +08:00
Calcium-Ion
dc36fdedc2 Merge pull request #781 from zeyugao/main
feat: Pass extra_body in OpenAI request to the backend
2025-02-24 16:29:48 +08:00
Calcium-Ion
3017882fa3 Merge pull request #783 from Calcium-Ion/rate-limit
feat: Add model request rate limiting functionality
2025-02-24 16:29:23 +08:00
1808837298@qq.com
e9ba392af8 feat: Add model rate limit settings in system configuration 2025-02-24 16:27:20 +08:00
1808837298@qq.com
83a37e4653 feat: Add model request rate limiting functionality 2025-02-24 16:20:55 +08:00
1808837298@qq.com
b6f95dca41 feat: Add support for different Dify bot types and request URLs 2025-02-24 14:18:30 +08:00
1808837298@qq.com
7ff4cebdbe feat: Enhance token counting and content parsing for messages 2025-02-24 14:18:15 +08:00
Elsa
af00f7b311 Pass extra_body to the backend 2025-02-24 10:52:55 +08:00
1808837298@qq.com
cc1d6e1c05 fix: Improve 429 error logging with detailed message 2025-02-23 21:26:31 +08:00
1808837298@qq.com
6c7a8c811c fix typo 2025-02-23 17:27:33 +08:00
1808837298@qq.com
d5ab7d2d34 feat: Add thinking-to-content option in channel extra settings #780 2025-02-23 17:13:08 +08:00
1808837298@qq.com
115a181db3 feat: Add thinking-to-content conversion for stream responses 2025-02-23 17:05:57 +08:00
1808837298@qq.com
88a2fec190 fix: mistral 2025-02-22 16:29:48 +08:00
1808837298@qq.com
27ea231d66 fix: fix image ratio calculation 2025-02-22 15:50:18 +08:00
Calcium-Ion
4b6101b3ea Merge pull request #778 from utopeadia/main
美化日志界面刷新图标
2025-02-22 15:21:28 +08:00
1808837298@qq.com
48926b8a5a fix: Ensure correct quota warning threshold type conversion 2025-02-22 15:19:55 +08:00
1808837298@qq.com
c44a32efe0 chore: update rerank.md 2025-02-22 15:13:26 +08:00
HowieWood
c541d6c97e 进一步美化刷新图标 2025-02-22 14:18:25 +08:00
HowieWood
7dfcd135da 优化日志刷新图标显示 2025-02-22 14:12:49 +08:00
131 changed files with 5115 additions and 1323 deletions

View File

@@ -50,10 +50,6 @@
# CHANNEL_TEST_FREQUENCY=10
# 生成默认token
# GENERATE_DEFAULT_TOKEN=false
# Gemini 安全设置
# GEMINI_SAFETY_SETTING=BLOCK_NONE
# Gemini版本设置
# GEMINI_MODEL_MAP=gemini-1.0-pro:v1
# Cohere 安全设置
# COHERE_SAFETY_SETTING=NONE
# 是否统计图片token

View File

@@ -63,10 +63,20 @@
- Add suffix `-high` to set high reasoning effort (e.g., `o3-mini-high`)
- Add suffix `-medium` to set medium reasoning effort
- Add suffix `-low` to set low reasoning effort
17. 🔄 Thinking to content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default is `false`, when `true`, the `reasoning_content` of the thinking content will be converted to `<think>` tags and concatenated to the content returned.
18. 🔄 Model rate limit, support setting total request limit and successful request limit in `System Settings->Rate Limit Settings`
19. 💰 Cache billing support, when enabled can charge a configurable ratio for cache hits:
1. Set `Prompt Cache Ratio` in `System Settings -> Operation Settings`
2. Set `Prompt Cache Ratio` in channel settings, range 0-1 (e.g., 0.5 means 50% charge on cache hits)
3. Supported channels:
- [x] OpenAI
- [x] Azure
- [x] DeepSeek
- [ ] Claude
## Model Support
This version additionally supports:
1. Third-party model **gps** (gpt-4-gizmo-*)
1. Third-party model **gpts** (gpt-4-gizmo-*)
2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [Integration Guide](Midjourney.md)
3. Custom channels with full API URL support
4. [Suno API](https://github.com/Suno-API/Suno-API) interface, [Integration Guide](Suno.md)
@@ -160,7 +170,7 @@ docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtow
## Channel Retry
Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**.
First retry uses same priority, second retry uses next priority, and so on.
If retry is enabled, the system will automatically use the next priority channel for the same request after a failed request.
### Cache Configuration
1. `REDIS_CONN_STRING`: Use Redis as cache

View File

@@ -66,13 +66,26 @@
15.**[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API支持Azure渠道
16. 支持使用路由/chat2link 进入聊天界面
17. 🧠 支持通过模型名称后缀设置 reasoning effort
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
1. OpenAI o系列模型
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
2. Claude 思考模型
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
18. 🔄 思考转内容,支持在 `渠道-编辑-渠道额外设置` 中设置 `thinking_to_content` 选项,默认`false`,开启后会将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回。
19. 🔄 模型限流,支持在 `系统设置-速率限制设置` 中设置模型限流,支持设置总请求数限制和成功请求数限制
20. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
1.`系统设置-运营设置` 中设置 `提示缓存倍率` 选项
2. 在渠道中设置 `提示缓存倍率`,范围 0-1例如设置为 0.5 表示缓存命中时按照 50% 计费
3. 支持的渠道:
- [x] OpenAI
- [x] Azure
- [x] DeepSeek
- [ ] Claude
## 模型支持
此版本额外支持以下模型:
1. 第三方模型 **gps** gpt-4-gizmo-*
1. 第三方模型 **gpts** gpt-4-gizmo-*
2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
3. 自定义渠道,支持填入完整调用地址
4. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
@@ -89,7 +102,6 @@
- `GET_MEDIA_TOKEN`是否统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用"模型:版本"指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认为 `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`Gemini模型最大图片数量默认为 `16`,设置为 `-1` 则不限制。
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB默认为 `20`
@@ -98,6 +110,10 @@
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`
## 已废弃的环境变量
- ~~`GEMINI_MODEL_MAP`(已废弃)~~:改为到`设置-模型相关设置`中设置
- ~~`GEMINI_SAFETY_SETTING`(已废弃)~~:改为到`设置-模型相关设置`中设置
## 部署
> [!TIP]
@@ -169,7 +185,7 @@ docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtow
## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
如果开启了重试功能,重试使用下一个优先级,以此类推。
### 缓存设置方法
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
@@ -212,8 +228,8 @@ docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtow
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)用key查询使用额度
其他基于New API的项目
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon)New API高性能优化版并支持Claude格式
- [VoAPI](https://github.com/VoAPI/VoAPI)基于New API的闭源项目
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon)New API高性能优化版专注于高并发优化,并支持Claude格式
- [VoAPI](https://github.com/VoAPI/VoAPI)基于New API的前端美化版本,闭源免费
## 🌟 Star History

View File

@@ -13,7 +13,7 @@ Request:
```json
{
"model": "rerank-multilingual-v3.0",
"model": "jina-reranker-v2-base-multilingual",
"query": "What is the capital of the United States?",
"top_n": 3,
"documents": [

View File

@@ -15,8 +15,9 @@ var SystemName = "New API"
var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
var ChatLink2 = ""
// var ChatLink = ""
// var ChatLink2 = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
@@ -276,7 +277,7 @@ var ChannelBaseURLs = []string{
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"", //37
"https://api.dify.ai", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40

24
common/gopool.go Normal file
View File

@@ -0,0 +1,24 @@
package common
import (
"context"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"math"
)
var relayGoPool gopool.Pool
func init() {
relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig())
relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) {
if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok {
SafeSendBool(stopChan, true)
}
SysError(fmt.Sprintf("panic in gopool.RelayPool: %v", i))
})
}
func RelayCtxGo(ctx context.Context, f func()) {
relayGoPool.CtxGo(ctx, f)
}

View File

@@ -32,6 +32,7 @@ func InitRedisClient() (err error) {
if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error())
}
opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10)
RDB = redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@@ -41,6 +42,10 @@ func InitRedisClient() (err error) {
if err != nil {
FatalLog("Redis ping test failed: " + err.Error())
}
if DebugEnabled {
SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr))
SysLog(fmt.Sprintf("Redis database: %d", opt.DB))
}
return err
}
@@ -53,13 +58,20 @@ func ParseRedisOption() *redis.Options {
}
func RedisSet(key string, value string, expiration time.Duration) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration))
}
ctx := context.Background()
return RDB.Set(ctx, key, value, expiration).Err()
}
func RedisGet(key string) (string, error) {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis GET: key=%s", key))
}
ctx := context.Background()
return RDB.Get(ctx, key).Result()
val, err := RDB.Get(ctx, key).Result()
return val, err
}
//func RedisExpire(key string, expiration time.Duration) error {
@@ -73,16 +85,25 @@ func RedisGet(key string) (string, error) {
//}
func RedisDel(key string) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis DEL: key=%s", key))
}
ctx := context.Background()
return RDB.Del(ctx, key).Err()
}
func RedisHDelObj(key string) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HDEL: key=%s", key))
}
ctx := context.Background()
return RDB.HDel(ctx, key).Err()
}
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration))
}
ctx := context.Background()
data := make(map[string]interface{})
@@ -130,6 +151,9 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
}
func RedisHGetObj(key string, obj interface{}) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key))
}
ctx := context.Background()
result, err := RDB.HGetAll(ctx, key).Result()
@@ -208,6 +232,9 @@ func RedisHGetObj(key string, obj interface{}) error {
// RedisIncr Add this function to handle atomic increments
func RedisIncr(key string, delta int64) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta))
}
// 检查键的剩余生存时间
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
@@ -238,6 +265,9 @@ func RedisIncr(key string, delta int64) error {
}
func RedisHIncrBy(key, field string, delta int64) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta))
}
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil && !errors.Is(err, redis.Nil) {
@@ -262,6 +292,9 @@ func RedisHIncrBy(key, field string, delta int64) error {
}
func RedisHSetField(key, field string, value interface{}) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value))
}
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil && !errors.Is(err, redis.Nil) {

View File

@@ -5,6 +5,7 @@ import (
"context"
crand "crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/pkg/errors"
"html/template"
@@ -213,6 +214,24 @@ func RandomSleep() {
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
}
func GetPointer[T any](v T) *T {
return &v
}
func Any2Type[T any](data any) (T, error) {
var zero T
bytes, err := json.Marshal(data)
if err != nil {
return zero, err
}
var res T
err = json.Unmarshal(bytes, &res)
if err != nil {
return zero, err
}
return res, nil
}
// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
func SaveTmpFile(filename string, data io.Reader) (string, error) {
f, err := os.CreateTemp(os.TempDir(), filename)

View File

@@ -1,6 +1,7 @@
package constant
var (
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
ChanelSettingProxy = "proxy" // Proxy 代理
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
ChanelSettingProxy = "proxy" // Proxy 代理
ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
)

View File

@@ -2,4 +2,9 @@ package constant
const (
ContextKeyRequestStartTime = "request_start_time"
ContextKeyUserSetting = "user_setting"
ContextKeyUserQuota = "user_quota"
ContextKeyUserStatus = "user_status"
ContextKeyUserEmail = "user_email"
ContextKeyUserGroup = "user_group"
)

View File

@@ -1,10 +1,7 @@
package constant
import (
"fmt"
"one-api/common"
"os"
"strings"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
@@ -23,9 +20,9 @@ var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
var GeminiModelMap = map[string]string{
"gemini-1.0-pro": "v1",
}
//var GeminiModelMap = map[string]string{
// "gemini-1.0-pro": "v1",
//}
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
@@ -33,18 +30,18 @@ var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
return
}
for _, pair := range strings.Split(modelVersionMapStr, ",") {
parts := strings.Split(pair, ":")
if len(parts) == 2 {
GeminiModelMap[parts[0]] = parts[1]
} else {
common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
}
}
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
//if modelVersionMapStr == "" {
// return
//}
//for _, pair := range strings.Split(modelVersionMapStr, ",") {
// parts := strings.Split(pair, ":")
// if len(parts) == 2 {
// GeminiModelMap[parts[0]] = parts[1]
// } else {
// common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
// }
//}
}
// GenerateDefaultToken 是否生成初始令牌,默认关闭。

View File

@@ -17,6 +17,7 @@ import (
"one-api/relay"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"strconv"
"strings"
@@ -48,7 +49,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if strings.Contains(strings.ToLower(testModel), "embedding") ||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型
testModel == "text-embedding-v1" ||
strings.Contains(testModel, "embed") ||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
}
@@ -72,26 +73,29 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
}
}
modelMapping := *channel.ModelMapping
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
}
cache, err := model.GetUserCache(1)
if err != nil {
return err, nil
}
cache.WriteContext(c)
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
group, _ := model.GetUserGroup(1, false)
c.Set("group", group)
middleware.SetupContextForSelectedChannel(c, channel, testModel)
meta := relaycommon.GenRelayInfo(c)
info := relaycommon.GenRelayInfo(c)
err = helper.ModelMappedHelper(c, info)
if err != nil {
return err, nil
}
testModel = info.UpstreamModelName
apiType, _ := constant.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
@@ -99,12 +103,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
}
request := buildTestRequest(testModel)
meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta))
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info))
adaptor.Init(meta)
adaptor.Init(info)
convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
convertedRequest, err := adaptor.ConvertRequest(c, info, request)
if err != nil {
return err, nil
}
@@ -114,7 +117,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
}
requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, meta, requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return err, nil
}
@@ -126,7 +129,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
}
}
usageA, respErr := adaptor.DoResponse(c, httpResp, meta)
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), respErr
}
@@ -139,26 +142,28 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if err != nil {
return err, nil
}
modelPrice, usePrice := common.GetModelPrice(testModel, false)
modelRatio := common.GetModelRatio(testModel)
completionRatio := common.GetCompletionRatio(testModel)
ratio := modelRatio
info.PromptTokens = usage.PromptTokens
priceData, err := helper.ModelPriceHelper(c, info, usage.PromptTokens, int(request.MaxTokens))
if err != nil {
return err, nil
}
quota := 0
if !usePrice {
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
if !priceData.UsePrice {
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
quota = int(math.Round(float64(quota) * priceData.ModelRatio))
if priceData.ModelRatio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(modelPrice * common.QuotaPerUnit)
quota = int(priceData.ModelPrice * common.QuotaPerUnit)
}
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试",
quota, "模型测试", 0, quota, int(consumedTime), false, "default", other)
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}
@@ -170,10 +175,10 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
}
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(model), "embedding") ||
if strings.Contains(strings.ToLower(model), "embedding") || // 其他 embedding 模型
strings.HasPrefix(model, "m3e") || // m3e 系列模型
strings.Contains(model, "bge-") || // bge 系列模型
model == "text-embedding-v1" { // 其他 embedding 模型
strings.Contains(model, "bge-") {
testRequest.Model = model
// Embedding 请求
testRequest.Input = []string{"hello world"}
return testRequest
@@ -181,6 +186,8 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
// 并非Embedding 模型
if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") {
testRequest.MaxCompletionTokens = 10
} else if strings.Contains(model, "thinking") {
testRequest.MaxTokens = 50
} else {
testRequest.MaxTokens = 10
}

View File

@@ -159,7 +159,7 @@ func UpdateMidjourneyTaskBulk() {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else {
if shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota)
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}

View File

@@ -7,6 +7,7 @@ import (
"one-api/common"
"one-api/model"
"one-api/setting"
"one-api/setting/operation_setting"
"strings"
"github.com/gin-gonic/gin"
@@ -53,8 +54,7 @@ func GetStatus(c *gin.Context) {
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"chat_link": common.ChatLink,
"chat_link2": common.ChatLink2,
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
@@ -66,7 +66,8 @@ func GetStatus(c *gin.Context) {
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": setting.DemoSiteEnabled,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
},
})
return

View File

@@ -216,6 +216,13 @@ func DashboardListModels(c *gin.Context) {
})
}
func EnabledListModels(c *gin.Context) {
c.JSON(200, gin.H{
"success": true,
"data": model.GetEnabledModels(),
})
}
func RetrieveModel(c *gin.Context) {
modelId := c.Param("model")
if aiModel, ok := openAIModelsMap[modelId]; ok {

View File

@@ -2,9 +2,9 @@ package controller
import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/model"
"one-api/setting"
"one-api/setting/operation_setting"
)
func GetPricing(c *gin.Context) {
@@ -40,7 +40,7 @@ func GetPricing(c *gin.Context) {
}
func ResetModelRatio(c *gin.Context) {
defaultStr := common.DefaultModelRatio2JSONString()
defaultStr := operation_setting.DefaultModelRatio2JSONString()
err := model.UpdateOption("ModelRatio", defaultStr)
if err != nil {
c.JSON(200, gin.H{
@@ -49,7 +49,7 @@ func ResetModelRatio(c *gin.Context) {
})
return
}
err = common.UpdateModelRatioByJSONString(defaultStr)
err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
if err != nil {
c.JSON(200, gin.H{
"success": false,

View File

@@ -16,6 +16,7 @@ import (
"one-api/relay"
"one-api/relay/constant"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"strings"
)
@@ -41,15 +42,6 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
return err
}
func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
default:
err = relay.TextHelper(c)
}
return err
}
func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
@@ -85,6 +77,7 @@ func Relay(c *gin.Context) {
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
@@ -109,7 +102,7 @@ func WssRelay(c *gin.Context) {
if err != nil {
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
service.WssError(c, ws, openaiErr.Error)
helper.WssError(c, ws, openaiErr.Error)
return
}
@@ -151,7 +144,7 @@ func WssRelay(c *gin.Context) {
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
service.WssError(c, ws, openaiErr.Error)
helper.WssError(c, ws, openaiErr.Error)
}
}

View File

@@ -159,7 +159,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
err = model.IncreaseUserQuota(task.UserId, quota, false)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}

View File

@@ -2,9 +2,6 @@ package controller
import (
"fmt"
"github.com/Calcium-Ion/go-epay/epay"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"log"
"net/url"
"one-api/common"
@@ -14,16 +11,21 @@ import (
"strconv"
"sync"
"time"
"github.com/Calcium-Ion/go-epay/epay"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"github.com/shopspring/decimal"
)
type EpayRequest struct {
Amount int `json:"amount"`
Amount int64 `json:"amount"`
PaymentMethod string `json:"payment_method"`
TopUpCode string `json:"top_up_code"`
}
type AmountRequest struct {
Amount int `json:"amount"`
Amount int64 `json:"amount"`
TopUpCode string `json:"top_up_code"`
}
@@ -41,25 +43,35 @@ func GetEpayClient() *epay.Client {
return withUrl
}
func getPayMoney(amount float64, group string) float64 {
func getPayMoney(amount int64, group string) float64 {
dAmount := decimal.NewFromInt(amount)
if !common.DisplayInCurrencyEnabled {
amount = amount / common.QuotaPerUnit
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
dAmount = dAmount.Div(dQuotaPerUnit)
}
// 别问为什么用float64问就是这么点钱没必要
topupGroupRatio := common.GetTopupGroupRatio(group)
if topupGroupRatio == 0 {
topupGroupRatio = 1
}
payMoney := amount * setting.Price * topupGroupRatio
return payMoney
dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
dPrice := decimal.NewFromFloat(setting.Price)
payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio)
return payMoney.InexactFloat64()
}
func getMinTopup() int {
func getMinTopup() int64 {
minTopup := setting.MinTopUp
if !common.DisplayInCurrencyEnabled {
minTopup = minTopup * int(common.QuotaPerUnit)
dMinTopup := decimal.NewFromInt(int64(minTopup))
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
minTopup = int(dMinTopup.Mul(dQuotaPerUnit).IntPart())
}
return minTopup
return int64(minTopup)
}
func RequestEpay(c *gin.Context) {
@@ -80,7 +92,7 @@ func RequestEpay(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getPayMoney(float64(req.Amount), group)
payMoney := getPayMoney(req.Amount, group)
if payMoney < 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return
@@ -118,7 +130,9 @@ func RequestEpay(c *gin.Context) {
}
amount := req.Amount
if !common.DisplayInCurrencyEnabled {
amount = amount / int(common.QuotaPerUnit)
dAmount := decimal.NewFromInt(int64(amount))
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
amount = dAmount.Div(dQuotaPerUnit).IntPart()
}
topUp := &model.TopUp{
UserId: id,
@@ -210,13 +224,16 @@ func EpayNotify(c *gin.Context) {
}
//user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
dAmount := decimal.NewFromInt(int64(topUp.Amount))
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
quotaToAdd := int(dAmount.Mul(dQuotaPerUnit).IntPart())
err = model.IncreaseUserQuota(topUp.UserId, quotaToAdd, true)
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
return
}
log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money))
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", common.LogQuota(quotaToAdd), topUp.Money))
}
} else {
log.Printf("易支付异常回调: %v", verifyInfo)
@@ -241,7 +258,7 @@ func RequestAmount(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getPayMoney(float64(req.Amount), group)
payMoney := getPayMoney(req.Amount, group)
if payMoney <= 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return

View File

@@ -913,11 +913,11 @@ func TopUp(c *gin.Context) {
}
type UpdateUserSettingRequest struct {
QuotaWarningType string `json:"notify_type"`
QuotaWarningThreshold int `json:"quota_warning_threshold"`
WebhookUrl string `json:"webhook_url,omitempty"`
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
QuotaWarningType string `json:"notify_type"`
QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
WebhookUrl string `json:"webhook_url,omitempty"`
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
}
func UpdateUserSetting(c *gin.Context) {

View File

@@ -10,6 +10,10 @@
- 用于配置网络代理
- 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
3. thinking_to_content
- 用于标识是否将思考内容`reasoning_conetnt`转换为`<think>`标签拼接到内容中返回
- 类型为布尔值,设置为 true 时启用思考内容转换
--------------------------------------------------------------
## JSON 格式示例
@@ -19,6 +23,7 @@
```json
{
"force_format": true,
"thinking_to_content": true,
"proxy": "socks5://xxxxxxx"
}
```

View File

@@ -1,6 +1,9 @@
package dto
import "encoding/json"
import (
"encoding/json"
"strings"
)
type ResponseFormat struct {
Type string `json:"type,omitempty"`
@@ -15,49 +18,52 @@ type FormatJsonSchema struct {
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCall `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCallRequest `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
ExtraBody any `json:"extra_body,omitempty"`
}
type OpenAITools struct {
Type string `json:"type"`
Function OpenAIFunction `json:"function"`
type ToolCallRequest struct {
ID string `json:"id,omitempty"`
Type string `json:"type"`
Function FunctionRequest `json:"function"`
}
type OpenAIFunction struct {
type FunctionRequest struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Parameters any `json:"parameters,omitempty"`
Arguments string `json:"arguments,omitempty"`
}
type StreamOptions struct {
@@ -93,6 +99,7 @@ type Message struct {
Name *string `json:"name,omitempty"`
Prefix *bool `json:"prefix,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
Reasoning string `json:"reasoning,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
parsedContent []MediaContent
@@ -101,7 +108,7 @@ type Message struct {
type MediaContent struct {
Type string `json:"type"`
Text string `json:"text"`
Text string `json:"text,omitempty"`
ImageUrl any `json:"image_url,omitempty"`
InputAudio any `json:"input_audio,omitempty"`
}
@@ -133,11 +140,11 @@ func (m *Message) SetPrefix(prefix bool) {
m.Prefix = &prefix
}
func (m *Message) ParseToolCalls() []ToolCall {
func (m *Message) ParseToolCalls() []ToolCallRequest {
if m.ToolCalls == nil {
return nil
}
var toolCalls []ToolCall
var toolCalls []ToolCallRequest
if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
return toolCalls
}
@@ -153,11 +160,24 @@ func (m *Message) StringContent() string {
if m.parsedStringContent != nil {
return *m.parsedStringContent
}
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
m.parsedStringContent = &stringContent
return stringContent
}
return string(m.Content)
contentStr := new(strings.Builder)
arrayContent := m.ParseContent()
for _, content := range arrayContent {
if content.Type == ContentTypeText {
contentStr.WriteString(content.Text)
}
}
stringContent = contentStr.String()
m.parsedStringContent = &stringContent
return stringContent
}
func (m *Message) SetStringContent(content string) {

View File

@@ -62,10 +62,11 @@ type ChatCompletionsStreamResponseChoice struct {
}
type ChatCompletionsStreamResponseChoiceDelta struct {
Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
Reasoning *string `json:"reasoning,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCallResponse `json:"tool_calls,omitempty"`
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
@@ -80,30 +81,38 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
}
func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string {
if c.ReasoningContent == nil {
if c.ReasoningContent == nil && c.Reasoning == nil {
return ""
}
return *c.ReasoningContent
if c.ReasoningContent != nil {
return *c.ReasoningContent
}
return *c.Reasoning
}
type ToolCall struct {
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
c.ReasoningContent = &s
c.Reasoning = &s
}
type ToolCallResponse struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type any `json:"type"`
Function FunctionCall `json:"function"`
Index *int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type any `json:"type"`
Function FunctionResponse `json:"function"`
}
func (c *ToolCall) SetIndex(i int) {
func (c *ToolCallResponse) SetIndex(i int) {
c.Index = &i
}
type FunctionCall struct {
type FunctionResponse struct {
Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"`
// call function with arguments in JSON format
Parameters any `json:"parameters,omitempty"` // request
Arguments string `json:"arguments,omitempty"`
Arguments string `json:"arguments"` // response
}
type ChatCompletionsStreamResponse struct {
@@ -116,6 +125,20 @@ type ChatCompletionsStreamResponse struct {
Usage *Usage `json:"usage"`
}
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
copy(choices, c.Choices)
return &ChatCompletionsStreamResponse{
Id: c.Id,
Object: c.Object,
Created: c.Created,
Model: c.Model,
SystemFingerprint: c.SystemFingerprint,
Choices: choices,
Usage: c.Usage,
}
}
func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
if c.SystemFingerprint == nil {
return ""
@@ -143,6 +166,7 @@ type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
}

2
go.mod
View File

@@ -22,12 +22,12 @@ require (
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.4.0
github.com/joho/godotenv v1.5.1
github.com/pkg/errors v0.9.1
github.com/pkoukk/tiktoken-go v0.1.7
github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.4.0
golang.org/x/crypto v0.27.0
golang.org/x/image v0.23.0
golang.org/x/net v0.28.0

4
go.sum
View File

@@ -117,8 +117,6 @@ github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
@@ -183,6 +181,8 @@ github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=

View File

@@ -199,15 +199,19 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.IsUserEnabled(token.UserId, false)
userCache, err := model.GetUserCache(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
userEnabled := userCache.Status == common.UserStatusEnabled
if !userEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
userCache.WriteContext(c)
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_key", token.Key)

View File

@@ -32,7 +32,6 @@ func Distribute() func(c *gin.Context) {
return
}
}
userId := c.GetInt("id")
var channel *model.Channel
channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c)
@@ -40,7 +39,7 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup, _ := model.GetUserGroup(userId, false)
userGroup := c.GetString(constant.ContextKeyUserGroup)
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]

View File

@@ -0,0 +1,175 @@
package middleware
import (
"context"
"fmt"
"net/http"
"one-api/common"
"one-api/setting"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
const (
ModelRequestRateLimitCountMark = "MRRL"
ModelRequestRateLimitSuccessCountMark = "MRRLS"
)
// 检查Redis中的请求限制
func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
// 如果maxCount为0表示不限制
if maxCount == 0 {
return true, nil
}
// 获取当前计数
length, err := rdb.LLen(ctx, key).Result()
if err != nil {
return false, err
}
// 如果未达到限制,允许请求
if length < int64(maxCount) {
return true, nil
}
// 检查时间窗口
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
if err != nil {
return false, err
}
nowTimeStr := time.Now().Format(timeFormat)
nowTime, err := time.Parse(timeFormat, nowTimeStr)
if err != nil {
return false, err
}
// 如果在时间窗口内已达到限制,拒绝请求
subTime := nowTime.Sub(oldTime).Seconds()
if int64(subTime) < duration {
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
return false, nil
}
return true, nil
}
// 记录Redis请求
func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
// 如果maxCount为0不记录请求
if maxCount == 0 {
return
}
now := time.Now().Format(timeFormat)
rdb.LPush(ctx, key, now)
rdb.LTrim(ctx, key, 0, int64(maxCount-1))
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
}
// Redis限流处理器
func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
ctx := context.Background()
rdb := common.RDB
// 1. 检查总请求数限制当totalMaxCount为0时会自动跳过
totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
if err != nil {
fmt.Println("检查总请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
// 2. 检查成功请求数限制
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
if err != nil {
fmt.Println("检查成功请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
return
}
// 3. 记录总请求当totalMaxCount为0时会自动跳过
recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
// 4. 处理请求
c.Next()
// 5. 如果请求成功,记录成功请求
if c.Writer.Status() < 400 {
recordRedisRequest(ctx, rdb, successKey, successMaxCount)
}
}
}
// 内存限流处理器
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
totalKey := ModelRequestRateLimitCountMark + userId
successKey := ModelRequestRateLimitSuccessCountMark + userId
// 1. 检查总请求数限制当totalMaxCount为0时跳过
if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
// 2. 检查成功请求数限制
// 使用一个临时key来检查限制这样可以避免实际记录
checkKey := successKey + "_check"
if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
// 3. 处理请求
c.Next()
// 4. 如果请求成功,记录到实际的成功请求计数中
if c.Writer.Status() < 400 {
inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
}
}
}
// ModelRequestRateLimit 模型请求限流中间件
func ModelRequestRateLimit() func(c *gin.Context) {
return func(c *gin.Context) {
// 在每个请求时检查是否启用限流
if !setting.ModelRequestRateLimitEnabled {
c.Next()
return
}
// 计算限流参数
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
totalMaxCount := setting.ModelRequestRateLimitCount
successMaxCount := setting.ModelRequestRateLimitSuccessCount
// 根据存储类型选择并执行限流处理器
if common.RedisEnabled {
redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
} else {
memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
}
}
}

View File

@@ -35,7 +35,7 @@ type Channel struct {
AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"`
Tag *string `json:"tag" gorm:"index"`
Setting string `json:"setting" gorm:"type:text"`
Setting *string `json:"setting" gorm:"type:text"`
}
func (channel *Channel) GetModels() []string {
@@ -290,35 +290,42 @@ func (channel *Channel) Delete() error {
var channelStatusLock sync.Mutex
func UpdateChannelStatusById(id int, status int, reason string) {
func UpdateChannelStatusById(id int, status int, reason string) bool {
if common.MemoryCacheEnabled {
channelStatusLock.Lock()
defer channelStatusLock.Unlock()
channelCache, _ := CacheGetChannel(id)
// 如果缓存渠道存在,且状态已是目标状态,直接返回
if channelCache != nil && channelCache.Status == status {
channelStatusLock.Unlock()
return
return false
}
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
if channelCache == nil && status != common.ChannelStatusEnabled {
channelStatusLock.Unlock()
return
return false
}
CacheUpdateChannelStatus(id, status)
channelStatusLock.Unlock()
}
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil {
common.SysError("failed to update ability status: " + err.Error())
return false
}
channel, err := GetChannelById(id, true)
if err != nil {
// find channel by id error, directly update status
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status)
if result.Error != nil {
common.SysError("failed to update channel status: " + result.Error.Error())
return false
}
if result.RowsAffected == 0 {
return false
}
} else {
if channel.Status == status {
return false
}
// find channel by id success, update status and other info
info := channel.GetOtherInfo()
info["status_reason"] = reason
@@ -328,9 +335,10 @@ func UpdateChannelStatusById(id int, status int, reason string) {
err = channel.Save()
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
return false
}
}
return true
}
func EnableChannelByTag(tag string) error {
@@ -485,8 +493,8 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
func (channel *Channel) GetSetting() map[string]interface{} {
setting := make(map[string]interface{})
if channel.Setting != "" {
err := json.Unmarshal([]byte(channel.Setting), &setting)
if channel.Setting != nil && *channel.Setting != "" {
err := json.Unmarshal([]byte(*channel.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
}
@@ -500,7 +508,7 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) {
common.SysError("failed to marshal setting: " + err.Error())
return
}
channel.Setting = string(settingBytes)
channel.Setting = common.GetPointer[string](string(settingBytes))
}
func GetChannelsByIds(ids []int) ([]*Channel, error) {

View File

@@ -1,13 +1,14 @@
package model
import (
"context"
"fmt"
"one-api/common"
"os"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
@@ -18,7 +19,7 @@ type Log struct {
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"`
Content string `json:"content"`
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
TokenName string `json:"token_name" gorm:"index;default:''"`
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
Quota int `json:"quota" gorm:"default:0"`
@@ -87,14 +88,14 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int,
func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
}
username, _ := GetUsernameById(userId, false)
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
log := &Log{
UserId: userId,
@@ -116,7 +117,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
}
err := LOG_DB.Create(log).Error
if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error())
common.LogError(c, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
gopool.Go(func() {

View File

@@ -3,6 +3,8 @@ package model
import (
"one-api/common"
"one-api/setting"
"one-api/setting/config"
"one-api/setting/operation_setting"
"strconv"
"strings"
"time"
@@ -23,6 +25,8 @@ func AllOption() ([]*Option, error) {
func InitOptionMap() {
common.OptionMapRWMutex.Lock()
common.OptionMap = make(map[string]string)
// 添加原有的系统配置
common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
@@ -84,15 +88,19 @@ func InitOptionMap() {
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["ChatLink2"] = common.ChatLink2
//common.OptionMap["ChatLink"] = common.ChatLink
//common.OptionMap["ChatLink2"] = common.ChatLink2
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
@@ -104,13 +112,20 @@ func InitOptionMap() {
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled)
common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled)
common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled)
common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
// 自动添加所有注册的模型配置
modelConfigs := config.GlobalConfig.ExportAllConfigs()
for k, v := range modelConfigs {
common.OptionMap[k] = v
}
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -154,6 +169,13 @@ func updateOptionMap(key string, value string) (err error) {
common.OptionMapRWMutex.Lock()
defer common.OptionMapRWMutex.Unlock()
common.OptionMap[key] = value
// 检查是否是模型配置 - 使用更规范的方式处理
if handleConfigUpdate(key, value) {
return nil // 已由配置系统处理
}
// 处理传统配置项...
if strings.HasSuffix(key, "Permission") {
intValue, _ := strconv.Atoi(value)
switch key {
@@ -223,11 +245,13 @@ func updateOptionMap(key string, value string) (err error) {
case "CheckSensitiveEnabled":
setting.CheckSensitiveEnabled = boolValue
case "DemoSiteEnabled":
setting.DemoSiteEnabled = boolValue
operation_setting.DemoSiteEnabled = boolValue
case "SelfUseModeEnabled":
operation_setting.SelfUseModeEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":
setting.CheckSensitiveOnPromptEnabled = boolValue
//case "CheckSensitiveOnCompletionEnabled":
// constant.CheckSensitiveOnCompletionEnabled = boolValue
case "ModelRequestRateLimitEnabled":
setting.ModelRequestRateLimitEnabled = boolValue
case "StopOnSensitiveEnabled":
setting.StopOnSensitiveEnabled = boolValue
case "SMTPSSLEnabled":
@@ -306,8 +330,14 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "ShouldPreConsumedQuota":
case "PreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "ModelRequestRateLimitCount":
setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
case "ModelRequestRateLimitDurationMinutes":
setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
case "ModelRequestRateLimitSuccessCount":
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval":
@@ -315,21 +345,23 @@ func updateOptionMap(key string, value string) (err error) {
case "DataExportDefaultTime":
common.DataExportDefaultTime = value
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
err = operation_setting.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = setting.UpdateGroupRatioByJSONString(value)
case "UserUsableGroups":
err = setting.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value)
err = operation_setting.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":
err = common.UpdateModelPriceByJSONString(value)
err = operation_setting.UpdateModelPriceByJSONString(value)
case "CacheRatio":
err = operation_setting.UpdateCacheRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
case "ChatLink":
common.ChatLink = value
case "ChatLink2":
common.ChatLink2 = value
//case "ChatLink":
// common.ChatLink = value
//case "ChatLink2":
// common.ChatLink2 = value
case "ChannelDisableThreshold":
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit":
@@ -337,9 +369,34 @@ func updateOptionMap(key string, value string) (err error) {
case "SensitiveWords":
setting.SensitiveWordsFromString(value)
case "AutomaticDisableKeywords":
setting.AutomaticDisableKeywordsFromString(value)
operation_setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
}
return err
}
// handleConfigUpdate 处理分层配置更新,返回是否已处理
func handleConfigUpdate(key, value string) bool {
parts := strings.SplitN(key, ".", 2)
if len(parts) != 2 {
return false // 不是分层配置
}
configName := parts[0]
configKey := parts[1]
// 获取配置对象
cfg := config.GlobalConfig.Get(configName)
if cfg == nil {
return false // 未注册的配置
}
// 更新配置
configMap := map[string]string{
configKey: value,
}
config.UpdateConfigFromMap(cfg, configMap)
return true // 已处理
}

View File

@@ -2,6 +2,7 @@ package model
import (
"one-api/common"
"one-api/setting/operation_setting"
"sync"
"time"
)
@@ -64,13 +65,14 @@ func updatePricing() {
ModelName: model,
EnableGroup: groups,
}
modelPrice, findPrice := common.GetModelPrice(model, false)
modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
if findPrice {
pricing.ModelPrice = modelPrice
pricing.QuotaType = 1
} else {
pricing.ModelRatio = common.GetModelRatio(model)
pricing.CompletionRatio = common.GetCompletionRatio(model)
modelRatio, _ := operation_setting.GetModelRatio(model)
pricing.ModelRatio = modelRatio
pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
pricing.QuotaType = 0
}
pricingMap = append(pricingMap, pricing)

View File

@@ -3,7 +3,7 @@ package model
type TopUp struct {
Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"`
Amount int `json:"amount"`
Amount int64 `json:"amount"`
Money float64 `json:"money"`
TradeNo string `json:"trade_no"`
CreateTime int64 `json:"create_time"`

View File

@@ -320,7 +320,7 @@ func (user *User) Insert(inviterId int) error {
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
@@ -502,35 +502,35 @@ func IsAdmin(userId int) bool {
return user.Role >= common.RoleAdminUser
}
// IsUserEnabled checks user status from Redis first, falls back to DB if needed
func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserStatusCache(id, status); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
// Try Redis first
status, err := getUserStatusCache(id)
if err == nil {
return status == common.UserStatusEnabled, nil
}
// Don't return error - fall through to DB
}
fromDB = true
var user User
err = DB.Where("id = ?", id).Select("status").Find(&user).Error
if err != nil {
return false, err
}
return user.Status == common.UserStatusEnabled, nil
}
//// IsUserEnabled checks user status from Redis first, falls back to DB if needed
//func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
// defer func() {
// // Update Redis cache asynchronously on successful DB read
// if shouldUpdateRedis(fromDB, err) {
// gopool.Go(func() {
// if err := updateUserStatusCache(id, status); err != nil {
// common.SysError("failed to update user status cache: " + err.Error())
// }
// })
// }
// }()
// if !fromDB && common.RedisEnabled {
// // Try Redis first
// status, err := getUserStatusCache(id)
// if err == nil {
// return status == common.UserStatusEnabled, nil
// }
// // Don't return error - fall through to DB
// }
// fromDB = true
// var user User
// err = DB.Where("id = ?", id).Select("status").Find(&user).Error
// if err != nil {
// return false, err
// }
//
// return user.Status == common.UserStatusEnabled, nil
//}
func ValidateAccessToken(token string) (user *User) {
if token == "" {
@@ -639,7 +639,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
return common.StrToMap(setting), nil
}
func IncreaseUserQuota(id int, quota int) (err error) {
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@@ -649,7 +649,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
common.SysError("failed to increase user quota: " + err.Error())
}
})
if common.BatchUpdateEnabled {
if !db && common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil
}
@@ -694,7 +694,7 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
return nil
}
if delta > 0 {
return IncreaseUserQuota(id, delta)
return IncreaseUserQuota(id, delta, false)
} else {
return DecreaseUserQuota(id, -delta)
}

View File

@@ -3,6 +3,7 @@ package model
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/constant"
"time"
@@ -21,6 +22,15 @@ type UserBase struct {
Setting string `json:"setting"`
}
func (user *UserBase) WriteContext(c *gin.Context) {
c.Set(constant.ContextKeyUserGroup, user.Group)
c.Set(constant.ContextKeyUserQuota, user.Quota)
c.Set(constant.ContextKeyUserStatus, user.Status)
c.Set(constant.ContextKeyUserEmail, user.Email)
c.Set("username", user.Username)
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
}
func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strings"
)
@@ -153,7 +154,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {

View File

@@ -130,7 +130,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo,
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req, info.ToRelayInfo())
resp, err := doRequest(c, req, info.RelayInfo)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}

View File

@@ -8,6 +8,7 @@ import (
"one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
)
const (
@@ -38,6 +39,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
model_setting.GetClaudeSettings().WriteHeaders(info.OriginModelName, req)
return nil
}
@@ -49,8 +51,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
var claudeReq *claude.ClaudeRequest
var err error
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
c.Set("request_model", request.Model)
if err != nil {
return nil, err
}
c.Set("request_model", claudeReq.Model)
c.Set("converted_request", claudeReq)
return claudeReq, err
}
@@ -64,7 +68,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return nil, nil
}

View File

@@ -9,7 +9,8 @@ var awsModelIDMap = map[string]string{
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
}
var ChannelName = "aws"

View File

@@ -14,8 +14,9 @@ type AwsClaudeRequest struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []claude.Tool `json:"tools,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *claude.Thinking `json:"thinking,omitempty"`
}
func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
@@ -30,5 +31,6 @@ func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
StopSequences: req.StopSequences,
Tools: req.Tools,
ToolChoice: req.ToolChoice,
Thinking: req.Thinking,
}
}

View File

@@ -12,6 +12,7 @@ import (
relaymodel "one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
"time"
@@ -203,13 +204,13 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
})
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
err := service.ObjectData(c, response)
response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
helper.Done(c)
if resp != nil {
err = resp.Body.Close()
if err != nil {

View File

@@ -11,6 +11,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strings"
"sync"
@@ -138,7 +139,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:

View File

@@ -9,6 +9,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
"strings"
)
@@ -55,6 +56,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
anthropicVersion = "2023-06-01"
}
req.Set("anthropic-version", anthropicVersion)
model_setting.GetClaudeSettings().WriteHeaders(info.OriginModelName, req)
return nil
}

View File

@@ -11,6 +11,8 @@ var ModelList = []string{
"claude-3-5-haiku-20241022",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"claude-3-7-sonnet-20250219",
"claude-3-7-sonnet-20250219-thinking",
}
var ChannelName = "claude"

View File

@@ -11,6 +11,9 @@ type ClaudeMediaMessage struct {
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
PartialJson string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
Delta string `json:"delta,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
@@ -54,9 +57,15 @@ type ClaudeRequest struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *Thinking `json:"thinking,omitempty"`
}
type Thinking struct {
Type string `json:"type"`
BudgetTokens int `json:"budget_tokens"`
}
type ClaudeError struct {

View File

@@ -1,7 +1,6 @@
package claude
import (
"bufio"
"encoding/json"
"fmt"
"io"
@@ -9,7 +8,9 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
"strings"
"github.com/gin-gonic/gin"
@@ -92,9 +93,31 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
Stream: textRequest.Stream,
Tools: claudeTools,
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
}
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
strings.HasSuffix(textRequest.Model, "-thinking") {
// 因为BudgetTokens 必须大于1024
if claudeRequest.MaxTokens < 1280 {
claudeRequest.MaxTokens = 1280
}
// BudgetTokens 为 max_tokens 的 80%
claudeRequest.Thinking = &Thinking{
Type: "enabled",
BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
claudeRequest.TopP = 0
claudeRequest.Temperature = common.GetPointer[float64](1.0)
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
}
if textRequest.Stop != nil {
// stop maybe string/array string, convert to array string
switch textRequest.Stop.(type) {
@@ -273,7 +296,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
tools := make([]dto.ToolCall, 0)
tools := make([]dto.ToolCallResponse, 0)
var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion)
@@ -292,10 +315,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.ContentBlock != nil {
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCall{
tools = append(tools, dto.ToolCallResponse{
ID: claudeResponse.ContentBlock.Id,
Type: "function",
Function: dto.FunctionCall{
Function: dto.FunctionResponse{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
@@ -308,12 +331,20 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.Delta != nil {
choice.Index = claudeResponse.Index
choice.Delta.SetContentString(claudeResponse.Delta.Text)
if claudeResponse.Delta.Type == "input_json_delta" {
tools = append(tools, dto.ToolCall{
Function: dto.FunctionCall{
switch claudeResponse.Delta.Type {
case "input_json_delta":
tools = append(tools, dto.ToolCallResponse{
Function: dto.FunctionResponse{
Arguments: claudeResponse.Delta.PartialJson,
},
})
case "signature_delta":
// 加密的不处理
signatureContent := "\n"
choice.Delta.ReasoningContent = &signatureContent
case "thinking_delta":
thinkingContent := claudeResponse.Delta.Thinking
choice.Delta.ReasoningContent = &thinkingContent
}
}
} else if claudeResponse.Type == "message_delta" {
@@ -351,7 +382,9 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
tools := make([]dto.ToolCall, 0)
tools := make([]dto.ToolCallResponse, 0)
thinkingContent := ""
if reqMode == RequestModeCompletion {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := dto.OpenAITextResponseChoice{
@@ -367,16 +400,22 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
} else {
fullTextResponse.Id = claudeResponse.Id
for _, message := range claudeResponse.Content {
if message.Type == "tool_use" {
switch message.Type {
case "tool_use":
args, _ := json.Marshal(message.Input)
tools = append(tools, dto.ToolCall{
tools = append(tools, dto.ToolCallResponse{
ID: message.Id,
Type: "function", // compatible with other OpenAI derivative applications
Function: dto.FunctionCall{
Function: dto.FunctionResponse{
Name: message.Name,
Arguments: string(args),
},
})
case "thinking":
// 加密的不管, 只输出明文的推理过程
thinkingContent = message.Thinking
case "text":
responseText = message.Text
}
}
}
@@ -391,6 +430,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(tools) > 0 {
choice.Message.SetToolCalls(tools)
}
choice.Message.ReasoningContent = thinkingContent
fullTextResponse.Model = claudeResponse.Model
choices = append(choices, choice)
fullTextResponse.Choices = choices
@@ -403,28 +443,18 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
usage = &dto.Usage{}
responseText := ""
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
info.SetFirstResponseTime()
if len(data) < 6 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
data = strings.TrimSpace(data)
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
return true
}
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if response == nil {
continue
return true
}
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
@@ -441,9 +471,9 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {
return true
} else {
continue
return true
}
}
//response.Id = responseId
@@ -451,11 +481,12 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
response.Created = createdTime
response.Model = info.UpstreamModelName
err = service.ObjectData(c, response)
err = helper.ObjectData(c, response)
if err != nil {
common.LogError(c, "send_stream_response_failed: "+err.Error())
}
}
return true
})
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
@@ -468,14 +499,14 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
}
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
resp.Body.Close()
helper.Done(c)
//resp.Body.Close()
return nil, usage
}

View File

@@ -9,6 +9,7 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
"time"
@@ -28,8 +29,8 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
id := service.GetResponseID(c)
helper.SetEventStreamHeaders(c)
id := helper.GetResponseID(c)
var responseText string
isFirst := true
@@ -57,7 +58,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
}
response.Id = id
response.Model = info.UpstreamModelName
err = service.ObjectData(c, response)
err = helper.ObjectData(c, response)
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
@@ -72,13 +73,13 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
if err != nil {
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
}
}
service.Done(c)
helper.Done(c)
err := resp.Body.Close()
if err != nil {
@@ -109,7 +110,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage
response.Id = service.GetResponseID(c)
response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -10,6 +10,7 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
"time"
@@ -103,7 +104,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
isFirst := true
c.Stream(func(w io.Writer) bool {
select {

View File

@@ -9,9 +9,18 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"strings"
)
const (
BotTypeChatFlow = 1 // chatflow default
BotTypeAgent = 2
BotTypeWorkFlow = 3
BotTypeCompletion = 4
)
type Adaptor struct {
BotType int
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -25,10 +34,28 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "agent") {
a.BotType = BotTypeAgent
} else if strings.HasPrefix(info.UpstreamModelName, "workflow") {
a.BotType = BotTypeWorkFlow
} else if strings.HasPrefix(info.UpstreamModelName, "chat") {
a.BotType = BotTypeCompletion
} else {
a.BotType = BotTypeChatFlow
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
switch a.BotType {
case BotTypeWorkFlow:
return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil
case BotTypeCompletion:
return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil
case BotTypeAgent:
fallthrough
default:
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -53,7 +80,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -10,6 +10,7 @@ import (
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
)
@@ -66,7 +67,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
@@ -92,7 +93,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
responseText += openaiResponse.Choices[0].Delta.GetContentString()
}
}
err = service.ObjectData(c, openaiResponse)
err = helper.ObjectData(c, openaiResponse)
if err != nil {
common.SysError(err.Error())
}
@@ -100,7 +101,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
}
service.Done(c)
helper.Done(c)
err := resp.Body.Close()
if err != nil {
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -7,11 +7,11 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting/model_setting"
"strings"
@@ -64,15 +64,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1beta"
version, beta := constant.GeminiModelMap[info.UpstreamModelName]
if !beta {
if info.ApiVersion != "" {
version = info.ApiVersion
} else {
version = "v1beta"
}
}
version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil

View File

@@ -20,4 +20,12 @@ var ModelList = []string{
"imagen-3.0-generate-002",
}
var SafetySettingList = []string{
"HARM_CATEGORY_HARASSMENT",
"HARM_CATEGORY_HATE_SPEECH",
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
"HARM_CATEGORY_DANGEROUS_CONTENT",
"HARM_CATEGORY_CIVIC_INTEGRITY",
}
var ChannelName = "google gemini"

View File

@@ -1,7 +1,6 @@
package gemini
import (
"bufio"
"encoding/json"
"fmt"
"io"
@@ -10,7 +9,9 @@ import (
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
"strings"
"unicode/utf8"
@@ -22,28 +23,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
SafetySettings: []GeminiChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_HATE_SPEECH",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_CIVIC_INTEGRITY",
Threshold: common.GeminiSafetySetting,
},
},
//SafetySettings: []GeminiChatSafetySettings{},
GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
@@ -52,9 +32,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
}
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
for _, category := range SafetySettingList {
safetySettings = append(safetySettings, GeminiChatSafetySettings{
Category: category,
Threshold: model_setting.GetGeminiSafetySetting(category),
})
}
geminiRequest.SafetySettings = safetySettings
// openaiContent.FuncToToolCalls()
if textRequest.Tools != nil {
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
googleSearch := false
codeExecution := false
for _, tool := range textRequest.Tools {
@@ -349,7 +338,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
return data
}
func getToolCall(item *GeminiPart) *dto.ToolCall {
func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
var argsBytes []byte
var err error
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
@@ -361,10 +350,10 @@ func getToolCall(item *GeminiPart) *dto.ToolCall {
if err != nil {
return nil
}
return &dto.ToolCall{
return &dto.ToolCallResponse{
ID: fmt.Sprintf("call_%s", common.GetUUID()),
Type: "function",
Function: dto.FunctionCall{
Function: dto.FunctionResponse{
Arguments: string(argsBytes),
Name: item.FunctionCall.FunctionName,
},
@@ -379,7 +368,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
content, _ := json.Marshal("")
is_tool_call := false
isToolCall := false
for _, candidate := range response.Candidates {
choice := dto.OpenAITextResponseChoice{
Index: int(candidate.Index),
@@ -391,12 +380,12 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
}
if len(candidate.Content.Parts) > 0 {
var texts []string
var tool_calls []dto.ToolCall
var toolCalls []dto.ToolCallResponse
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls
if call := getToolCall(&part); call != nil {
tool_calls = append(tool_calls, *call)
if call := getResponseToolCall(&part); call != nil {
toolCalls = append(toolCalls, *call)
}
} else {
if part.ExecutableCode != nil {
@@ -411,9 +400,9 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
}
}
}
if len(tool_calls) > 0 {
choice.Message.SetToolCalls(tool_calls)
is_tool_call = true
if len(toolCalls) > 0 {
choice.Message.SetToolCalls(toolCalls)
isToolCall = true
}
choice.Message.SetStringContent(strings.Join(texts, "\n"))
@@ -429,7 +418,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
choice.FinishReason = constant.FinishReasonContentFilter
}
}
if is_tool_call {
if isToolCall {
choice.FinishReason = constant.FinishReasonToolCalls
}
@@ -440,10 +429,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
is_stop := false
isStop := false
for _, candidate := range geminiResponse.Candidates {
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
is_stop = true
isStop = true
candidate.FinishReason = nil
}
choice := dto.ChatCompletionsStreamResponseChoice{
@@ -468,7 +457,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
isTools = true
if call := getToolCall(&part); call != nil {
if call := getResponseToolCall(&part); call != nil {
call.SetIndex(len(choice.Delta.ToolCalls))
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
}
@@ -493,9 +482,8 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = choices
return &response, is_stop
return &response, isStop
}
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -503,27 +491,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
info.SetFirstResponseTime()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\"")
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse
err := json.Unmarshal([]byte(data), &geminiResponse)
if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error())
continue
return false
}
response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse)
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
response.Id = id
response.Created = createAt
response.Model = info.UpstreamModelName
@@ -532,15 +509,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
}
err = service.ObjectData(c, response)
err = helper.ObjectData(c, response)
if err != nil {
common.LogError(c, err.Error())
}
if is_stop {
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
service.ObjectData(c, response)
if isStop {
response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
helper.ObjectData(c, response)
}
}
return true
})
var response *dto.ChatCompletionsStreamResponse
@@ -549,14 +527,14 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
if info.ShouldIncludeUsage {
response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
resp.Body.Close()
helper.Done(c)
//resp.Body.Close()
return nil, usage
}

View File

@@ -61,7 +61,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp)
err, usage = JinaRerankHandler(c, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings {
err, usage = jinaEmbeddingHandler(c, resp)
}

View File

@@ -9,7 +9,7 @@ import (
"one-api/service"
)
func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func JinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -41,7 +41,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
return requestOpenAI2Mistral(request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {

View File

@@ -1,25 +1,21 @@
package mistral
import (
"encoding/json"
"one-api/dto"
)
func requestOpenAI2Mistral(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
if !message.IsStringContent() {
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
mediaMessage.ImageUrl = imageUrl.Url
mediaMessages[j] = mediaMessage
}
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
mediaMessage.ImageUrl = imageUrl.Url
mediaMessages[j] = mediaMessage
}
messageRaw, _ := json.Marshal(mediaMessages)
message.Content = messageRaw
}
message.SetMediaContent(mediaMessages)
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,

View File

@@ -3,21 +3,22 @@ package ollama
import "one-api/dto"
type OllamaRequest struct {
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"`
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Tools []dto.ToolCallRequest `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"`
}
type Options struct {

View File

@@ -58,6 +58,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err
TopK: request.TopK,
Stop: Stop,
Tools: request.Tools,
MaxTokens: request.MaxTokens,
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,

View File

@@ -14,6 +14,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/ai360"
"one-api/relay/channel/jina"
"one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
@@ -146,7 +147,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
return request, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
@@ -228,6 +229,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
case constant.RelayModeImagesGenerations:
err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeRerank:
err, usage = jina.JinaRerankHandler(c, resp)
default:
if info.IsStream {
err, usage = OaiStreamHandler(c, resp, info)

View File

@@ -11,6 +11,7 @@ var ModelList = []string{
"chatgpt-4o-latest",
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"gpt-4.5-preview", "gpt-4.5-preview-2025-02-27",
"o1-preview", "o1-preview-2024-09-12",
"o1-mini", "o1-mini-2024-09-12",
"o3-mini", "o3-mini-2025-01-31",

View File

@@ -1,14 +1,9 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"io"
"math"
"mime/multipart"
@@ -18,26 +13,94 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"os"
"strings"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
func sendStreamData(c *gin.Context, data string, forceFormat bool) error {
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
if data == "" {
return nil
}
if forceFormat {
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
return err
}
return service.ObjectData(c, lastStreamResponse)
if !forceFormat && !thinkToContent {
return helper.StringData(c, data)
}
return service.StringData(c, data)
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
return err
}
if !thinkToContent {
return helper.ObjectData(c, lastStreamResponse)
}
hasThinkingContent := false
hasContent := false
var thinkingContent strings.Builder
for _, choice := range lastStreamResponse.Choices {
if len(choice.Delta.GetReasoningContent()) > 0 {
hasThinkingContent = true
thinkingContent.WriteString(choice.Delta.GetReasoningContent())
}
if len(choice.Delta.GetContentString()) > 0 {
hasContent = true
}
}
// Handle think to content conversion
if info.ThinkingContentInfo.IsFirstThinkingContent {
if hasThinkingContent {
response := lastStreamResponse.Copy()
for i := range response.Choices {
// send `think` tag with thinking content
response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
response.Choices[i].Delta.ReasoningContent = nil
response.Choices[i].Delta.Reasoning = nil
}
info.ThinkingContentInfo.IsFirstThinkingContent = false
return helper.ObjectData(c, response)
}
}
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
return helper.ObjectData(c, lastStreamResponse)
}
// Process each choice
for i, choice := range lastStreamResponse.Choices {
// Handle transition from thinking to content
if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent {
response := lastStreamResponse.Copy()
for j := range response.Choices {
response.Choices[j].Delta.SetContentString("\n</think>\n")
response.Choices[j].Delta.ReasoningContent = nil
response.Choices[j].Delta.Reasoning = nil
}
info.ThinkingContentInfo.SendLastThinkingContent = true
helper.ObjectData(c, response)
}
// Convert reasoning content to regular content
if len(choice.Delta.GetReasoningContent()) > 0 {
lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
lastStreamResponse.Choices[i].Delta.Reasoning = nil
} else if !hasThinkingContent && !hasContent {
// flush thinking content
lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
lastStreamResponse.Choices[i].Delta.Reasoning = nil
}
}
return helper.ObjectData(c, lastStreamResponse)
}
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -56,68 +119,33 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
var usage = &dto.Usage{}
var streamItems []string // store stream items
var forceFormat bool
var thinkToContent bool
if info.ChannelType == common.ChannelTypeCustom {
if forceFmt, ok := info.ChannelSetting["force_format"].(bool); ok {
forceFormat = forceFmt
}
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
forceFormat = forceFmt
}
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
thinkToContent = think2Content
}
toolCount := 0
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
// twice timeout for o1 model
streamingTimeout *= 2
}
ticker := time.NewTicker(streamingTimeout)
defer ticker.Stop()
stopChan := make(chan bool)
defer close(stopChan)
var (
lastStreamData string
mu sync.Mutex
)
gopool.Go(func() {
for scanner.Scan() {
info.SetFirstResponseTime()
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" && data[:6] != "[DONE]" {
continue
}
mu.Lock()
data = data[5:]
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "[DONE]") {
if lastStreamData != "" {
err := sendStreamData(c, lastStreamData, forceFormat)
if err != nil {
common.LogError(c, "streaming error: "+err.Error())
}
}
lastStreamData = data
streamItems = append(streamItems, data)
}
mu.Unlock()
}
common.SafeSendBool(stopChan, true)
})
select {
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
case <-stopChan:
// 正常结束
}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
if lastStreamData != "" {
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
if err != nil {
common.LogError(c, "streaming error: "+err.Error())
}
}
lastStreamData = data
streamItems = append(streamItems, data)
return true
})
shouldSendLastResp := true
var lastStreamResponse dto.ChatCompletionsStreamResponse
@@ -141,7 +169,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
if shouldSendLastResp {
sendStreamData(c, lastStreamData, forceFormat)
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
}
// 计算token
@@ -162,7 +190,10 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
// handle both reasoning_content and reasoning
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
@@ -183,7 +214,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) // This will handle both reasoning_content and reasoning
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
@@ -223,17 +254,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if !containStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {
if info.ChannelType == common.ChannelTypeDeepSeek {
if usage.PromptCacheHitTokens != 0 {
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
}
}
if info.ShouldIncludeUsage && !containStreamUsage {
response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
service.ObjectData(c, response)
helper.ObjectData(c, response)
}
service.Done(c)
helper.Done(c)
resp.Body.Close()
//resp.Body.Close()
return nil, usage
}
@@ -275,7 +312,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent, model)
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, model)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
@@ -464,7 +501,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = service.WssString(c, targetConn, string(message))
err = helper.WssString(c, targetConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to target: %v", err)
return
@@ -570,7 +607,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
localUsage.OutputTokenDetails.AudioTokens += audioToken
}
err = service.WssString(c, clientConn, string(message))
err = helper.WssString(c, clientConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to client: %v", err)
return

View File

@@ -0,0 +1,74 @@
package openrouter
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
req.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
req.Set("X-Title", "New API")
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,5 @@
package openrouter
var ModelList = []string{}
var ChannelName = "openrouter"

View File

@@ -9,6 +9,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
)
@@ -112,7 +113,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
dataChan <- string(jsonResponse)
stopChan <- true
}()
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:

View File

@@ -14,6 +14,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strconv"
"strings"
@@ -91,7 +92,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
@@ -112,7 +113,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
responseText += response.Choices[0].Delta.GetContentString()
}
err = service.ObjectData(c, response)
err = helper.ObjectData(c, response)
if err != nil {
common.SysError(err.Error())
}
@@ -122,7 +123,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
common.SysError("error reading stream: " + err.Error())
}
service.Done(c)
helper.Done(c)
err := resp.Body.Close()
if err != nil {

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"io"
"net/http"
"one-api/dto"
@@ -28,6 +27,8 @@ var claudeModelMap = map[string]string{
"claude-3-opus-20240229": "claude-3-opus@20240229",
"claude-3-haiku-20240307": "claude-3-haiku@20240307",
"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
"claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022",
"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
}
const anthropicVersion = "vertex-2023-10-16"
@@ -85,15 +86,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} else {
suffix = "rawPredict"
}
model := info.UpstreamModelName
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
info.UpstreamModelName = v
model = v
}
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
region,
adc.ProjectID,
region,
info.UpstreamModelName,
model,
suffix,
), nil
} else if a.RequestMode == RequestModeLlama {
@@ -126,13 +128,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if err != nil {
return nil, err
}
vertexClaudeReq := &VertexAIClaudeRequest{
AnthropicVersion: anthropicVersion,
}
if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil {
return nil, errors.New("failed to copy claude request")
}
c.Set("request_model", request.Model)
vertexClaudeReq := copyRequest(claudeReq, anthropicVersion)
c.Set("request_model", claudeReq.Model)
info.UpstreamModelName = claudeReq.Model
return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini {
geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
@@ -156,7 +154,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -1,17 +1,37 @@
package vertex
import "one-api/relay/channel/claude"
import (
"one-api/relay/channel/claude"
)
type VertexAIClaudeRequest struct {
AnthropicVersion string `json:"anthropic_version"`
Messages []claude.ClaudeMessage `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
System any `json:"system,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []claude.Tool `json:"tools,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *claude.Thinking `json:"thinking,omitempty"`
}
func copyRequest(req *claude.ClaudeRequest, version string) *VertexAIClaudeRequest {
return &VertexAIClaudeRequest{
AnthropicVersion: version,
System: req.System,
Messages: req.Messages,
MaxTokens: req.MaxTokens,
Stream: req.Stream,
Temperature: req.Temperature,
TopP: req.TopP,
TopK: req.TopK,
StopSequences: req.StopSequences,
Tools: req.Tools,
ToolChoice: req.ToolChoice,
Thinking: req.Thinking,
}
}

View File

@@ -14,6 +14,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strings"
"time"
@@ -132,7 +133,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
if err != nil {
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
var usage dto.Usage
c.Stream(func(w io.Writer) bool {
select {

View File

@@ -10,6 +10,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strings"
"sync"
@@ -177,7 +178,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:

View File

@@ -10,6 +10,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strings"
"sync"
@@ -197,7 +198,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:

View File

@@ -12,6 +12,11 @@ import (
"github.com/gorilla/websocket"
)
type ThinkingContentInfo struct {
IsFirstThinkingContent bool
SendLastThinkingContent bool
}
type RelayInfo struct {
ChannelType int
ChannelId int
@@ -22,7 +27,8 @@ type RelayInfo struct {
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
setFirstResponse bool
isFirstResponse bool
//SendLastReasoningResponse bool
ApiType int
IsStream bool
IsPlayground bool
@@ -49,6 +55,10 @@ type RelayInfo struct {
AudioUsage bool
ReasoningEffort string
ChannelSetting map[string]interface{}
UserSetting map[string]interface{}
UserEmail string
UserQuota int
ThinkingContentInfo
}
// 定义支持流式选项的通道类型
@@ -88,6 +98,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &RelayInfo{
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
UserEmail: c.GetString(constant.ContextKeyUserEmail),
isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
@@ -109,6 +123,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting,
ThinkingContentInfo: ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
},
}
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true
@@ -139,26 +157,14 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
}
func (info *RelayInfo) SetFirstResponseTime() {
if !info.setFirstResponse {
if info.isFirstResponse {
info.FirstResponseTime = time.Now()
info.setFirstResponse = true
info.isFirstResponse = false
}
}
type TaskRelayInfo struct {
ChannelType int
ChannelId int
TokenId int
UserId int
Group string
StartTime time.Time
ApiType int
RelayMode int
UpstreamModelName string
RequestURLPath string
ApiKey string
BaseUrl string
*RelayInfo
Action string
OriginTaskID string
@@ -166,48 +172,8 @@ type TaskRelayInfo struct {
}
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
startTime := time.Now()
apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &TaskRelayInfo{
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
TokenId: tokenId,
UserId: userId,
Group: group,
StartTime: startTime,
ApiType: apiType,
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
}
if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType]
RelayInfo: GenRelayInfo(c),
}
return info
}
func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
return &RelayInfo{
ChannelType: info.ChannelType,
ChannelId: info.ChannelId,
TokenId: info.TokenId,
UserId: info.UserId,
Group: info.Group,
StartTime: info.StartTime,
ApiType: info.ApiType,
RelayMode: info.RelayMode,
UpstreamModelName: info.UpstreamModelName,
RequestURLPath: info.RequestURLPath,
ApiKey: info.ApiKey,
BaseUrl: info.BaseUrl,
}
}

View File

@@ -30,6 +30,7 @@ const (
APITypeMokaAI
APITypeVolcEngine
APITypeBaiduV2
APITypeOpenRouter
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -86,6 +87,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeVolcEngine
case common.ChannelTypeBaiduV2:
apiType = APITypeBaiduV2
case common.ChannelTypeOpenRouter:
apiType = APITypeOpenRouter
}
if apiType == -1 {
return APITypeOpenAI, false

View File

@@ -1,4 +1,4 @@
package service
package helper
import (
"encoding/json"

View File

@@ -1,31 +1,47 @@
package helper
import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
relaycommon "one-api/relay/common"
"one-api/setting"
"one-api/setting/operation_setting"
)
type PriceData struct {
ModelPrice float64
ModelRatio float64
CompletionRatio float64
CacheRatio float64
GroupRatio float64
UsePrice bool
ShouldPreConsumedQuota int
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData {
modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false)
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
groupRatio := setting.GetGroupRatio(info.Group)
var preConsumedQuota int
var modelRatio float64
var completionRatio float64
var cacheRatio float64
if !usePrice {
preConsumedTokens := common.PreConsumedQuota
if maxTokens != 0 {
preConsumedTokens = promptTokens + maxTokens
}
modelRatio = common.GetModelRatio(info.OriginModelName)
var success bool
modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName)
if !success {
if info.UserId == 1 {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置请设置或开始自用模式Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
} else {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
}
}
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
ratio := modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
@@ -34,8 +50,10 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
return PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
CompletionRatio: completionRatio,
GroupRatio: groupRatio,
UsePrice: usePrice,
CacheRatio: cacheRatio,
ShouldPreConsumedQuota: preConsumedQuota,
}
}, nil
}

View File

@@ -0,0 +1,91 @@
package helper
import (
"bufio"
"context"
"io"
"net/http"
"one-api/common"
"one-api/constant"
relaycommon "one-api/relay/common"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
if resp == nil {
return
}
defer resp.Body.Close()
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
// twice timeout for thinking model
streamingTimeout *= 2
}
var (
stopChan = make(chan bool, 2)
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
)
defer func() {
ticker.Stop()
close(stopChan)
}()
scanner.Split(bufio.ScanLines)
SetEventStreamHeaders(c)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = context.WithValue(ctx, "stop_chan", stopChan)
common.RelayCtxGo(ctx, func() {
for scanner.Scan() {
ticker.Reset(streamingTimeout)
data := scanner.Text()
if common.DebugEnabled {
println(data)
}
if len(data) < 6 {
continue
}
if data[:5] != "data:" && data[:6] != "[DONE]" {
continue
}
data = data[5:]
data = strings.TrimLeft(data, " ")
data = strings.TrimSuffix(data, "\"")
if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime()
success := dataHandler(data)
if !success {
break
}
}
}
if err := scanner.Err(); err != nil {
if err != io.EOF {
common.LogError(c, "scanner error: "+err.Error())
}
}
common.SafeSendBool(stopChan, true)
})
select {
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
case <-stopChan:
// 正常结束
}
}

View File

@@ -7,7 +7,6 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
@@ -75,12 +74,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo.PromptTokens = promptTokens
}
priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr

View File

@@ -86,7 +86,10 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
imageRequest.Model = relayInfo.UpstreamModelName
priceData := helper.ModelPriceHelper(c, relayInfo, 0, 0)
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
if !priceData.UsePrice {
// modelRatio 16 = modelPrice $0.04
// per 1 modelRatio = $0.04 / 16
@@ -115,8 +118,8 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
}
}
imageRatio := priceData.ModelPrice * sizeRatio * qualityRatio * float64(imageRequest.N)
quota := int(imageRatio * priceData.GroupRatio * common.QuotaPerUnit)
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
quota := int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
if userQuota-quota < 0 {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)

View File

@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -16,6 +15,7 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"one-api/setting/operation_setting"
"strconv"
"strings"
"time"
@@ -158,10 +158,10 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
}
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
modelPrice, success := common.GetModelPrice(modelName, true)
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
@@ -192,7 +192,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
if err != nil {
return &mjResp.Response
}
defer func(ctx context.Context) {
defer func() {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
@@ -208,14 +208,14 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
}()
midjResponse := &mjResp.Response
midjourneyTask := &model.Midjourney{
UserId: userId,
@@ -464,10 +464,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
modelName := service.CoverActionToModelName(midjRequest.Action)
modelPrice, success := common.GetModelPrice(modelName, true)
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
@@ -498,7 +498,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
midjResponse := &midjResponseWithStatus.Response
defer func(ctx context.Context) {
defer func() {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
@@ -510,14 +510,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
}()
// 文档https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功

View File

@@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"io"
"math"
"net/http"
@@ -21,6 +20,9 @@ import (
"strings"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/shopspring/decimal"
"github.com/gin-gonic/gin"
)
@@ -106,7 +108,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
c.Set("prompt_tokens", promptTokens)
}
priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
@@ -248,6 +253,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
}
relayInfo.UserQuota = userQuota
if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
if !relayInfo.TokenUnlimited {
@@ -267,7 +273,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
}
if preConsumedQuota > 0 {
err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
@@ -300,34 +306,55 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
CompletionTokens: 0,
TotalTokens: relayInfo.PromptTokens,
}
extraContent += " (可能是请求出错)"
extraContent += "(可能是请求出错)"
}
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
cacheTokens := usage.PromptTokensDetails.CachedTokens
completionTokens := usage.CompletionTokens
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName)
ratio := priceData.ModelRatio * priceData.GroupRatio
completionRatio := priceData.CompletionRatio
cacheRatio := priceData.CacheRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
quota := 0
// Convert values to decimal for precise calculation
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
dCompletionRatio := decimal.NewFromFloat(completionRatio)
dCacheRatio := decimal.NewFromFloat(cacheRatio)
dModelRatio := decimal.NewFromFloat(modelRatio)
dGroupRatio := decimal.NewFromFloat(groupRatio)
dModelPrice := decimal.NewFromFloat(modelPrice)
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
ratio := dModelRatio.Mul(dGroupRatio)
var quotaCalculateDecimal decimal.Decimal
if !priceData.UsePrice {
quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
nonCachedTokens := dPromptTokens.Sub(dCacheTokens)
cachedTokensWithRatio := dCacheTokens.Mul(dCacheRatio)
promptQuota := nonCachedTokens.Add(cachedTokensWithRatio)
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
quotaCalculateDecimal = decimal.NewFromInt(1)
}
} else {
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
}
quota := int(quotaCalculateDecimal.Round(0).IntPart())
totalTokens := promptTokens + completionTokens
var logContent string
if !usePrice {
if !priceData.UsePrice {
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
@@ -342,9 +369,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
} else {
//if sensitiveResp != nil {
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
//}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
@@ -368,11 +392,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
if extraContent != "" {
logContent += ", " + extraContent
}
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
//if quota != 0 {
//
//}
}

View File

@@ -18,6 +18,7 @@ import (
"one-api/relay/channel/mokaai"
"one-api/relay/channel/ollama"
"one-api/relay/channel/openai"
"one-api/relay/channel/openrouter"
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
@@ -83,6 +84,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &volcengine.Adaptor{}
case constant.APITypeBaiduV2:
return &baidu_v2.Adaptor{}
case constant.APITypeOpenRouter:
return &openrouter.Adaptor{}
}
return nil
}

View File

@@ -57,8 +57,10 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
promptToken := getEmbeddingPromptToken(*embeddingRequest)
relayInfo.PromptTokens = promptToken
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {

View File

@@ -50,8 +50,10 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
promptToken := getRerankPromptToken(*rerankRequest)
relayInfo.PromptTokens = promptToken
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {

View File

@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -17,6 +16,7 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"one-api/setting/operation_setting"
)
/*
@@ -38,9 +38,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
modelPrice, success := common.GetModelPrice(modelName, true)
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
if !success {
defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
@@ -109,11 +109,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
return
}
defer func(ctx context.Context) {
defer func() {
// release quota
if relayInfo.ConsumeQuota && taskErr == nil {
err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
@@ -123,13 +123,13 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
}
}(c.Request.Context())
}()
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
if taskErr != nil {

View File

@@ -11,6 +11,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting"
"one-api/setting/operation_setting"
)
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -39,7 +40,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
}
}
//relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
@@ -65,7 +66,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
//}
modelRatio = common.GetModelRatio(relayInfo.UpstreamModelName)
modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {

View File

@@ -84,6 +84,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.GET("/", controller.GetAllChannels)
channelRoute.GET("/search", controller.SearchChannels)
channelRoute.GET("/models", controller.ChannelListModels)
channelRoute.GET("/models_enabled", controller.EnabledListModels)
channelRoute.GET("/:id", controller.GetChannel)
channelRoute.GET("/test", controller.TestAllChannels)
channelRoute.GET("/test/:id", controller.TestChannel)

View File

@@ -24,6 +24,7 @@ func SetRelayRouter(router *gin.Engine) {
}
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth())
relayV1Router.Use(middleware.ModelRequestRateLimit())
{
// WebSocket 路由
wsRouter := relayV1Router.Group("")

View File

@@ -6,23 +6,31 @@ import (
"one-api/common"
"one-api/dto"
"one-api/model"
"one-api/setting"
"one-api/setting/operation_setting"
"strings"
)
func formatNotifyType(channelId int, status int) string {
return fmt.Sprintf("%s_%d_%d", dto.NotifyTypeChannelUpdate, channelId, status)
}
// disable & notify
func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用,原因:%s", channelName, channelId, reason)
NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
success := model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
if success {
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusAutoDisabled), subject, content)
}
}
func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
success := model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
if success {
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusEnabled), subject, content)
}
}
func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
@@ -67,7 +75,7 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
}
lowerMessage := strings.ToLower(err.Error.Message)
search, _ := AcSearch(lowerMessage, setting.AutomaticDisableKeywords, true)
search, _ := AcSearch(lowerMessage, operation_setting.AutomaticDisableKeywords, true)
if search {
return true
}

View File

@@ -7,7 +7,9 @@ import (
"fmt"
"image"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"strings"
"golang.org/x/image/webp"
@@ -23,7 +25,7 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
decodedData, err := base64.StdEncoding.DecodeString(base64String)
if err != nil {
fmt.Println("Error: Failed to decode base64 string")
return image.Config{}, "", "", err
return image.Config{}, "", "", fmt.Errorf("failed to decode base64 string: %s", err.Error())
}
// 创建一个bytes.Buffer用于存储解码后的数据
@@ -61,20 +63,51 @@ func DecodeBase64FileData(base64String string) (string, string, error) {
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
resp, err := DoDownloadRequest(url)
if err != nil {
return "", "", err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return "", "", fmt.Errorf("invalid content type: %s, required image/*", resp.Header.Get("Content-Type"))
return "", "", fmt.Errorf("failed to download image: %w", err)
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body)
if err != nil {
return
// Check HTTP status code
if resp.StatusCode != http.StatusOK {
return "", "", fmt.Errorf("failed to download image: HTTP %d", resp.StatusCode)
}
mimeType = resp.Header.Get("Content-Type")
contentType := resp.Header.Get("Content-Type")
if contentType != "application/octet-stream" && !strings.HasPrefix(contentType, "image/") {
return "", "", fmt.Errorf("invalid content type: %s, required image/*", contentType)
}
maxImageSize := int64(constant.MaxFileDownloadMB * 1024 * 1024)
// Check Content-Length if available
if resp.ContentLength > maxImageSize {
return "", "", fmt.Errorf("image size %d exceeds maximum allowed size of %d bytes", resp.ContentLength, maxImageSize)
}
// Use LimitReader to prevent reading oversized images
limitReader := io.LimitReader(resp.Body, maxImageSize)
buffer := &bytes.Buffer{}
written, err := io.Copy(buffer, limitReader)
if err != nil {
return "", "", fmt.Errorf("failed to read image data: %w", err)
}
if written >= maxImageSize {
return "", "", fmt.Errorf("image size exceeds maximum allowed size of %d bytes", maxImageSize)
}
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
return
mimeType = contentType
// Handle application/octet-stream type
if mimeType == "application/octet-stream" {
_, format, _, err := DecodeBase64ImageData(data)
if err != nil {
return "", "", err
}
mimeType = "image/" + format
}
return mimeType, data, nil
}
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
@@ -92,7 +125,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
mimeType := response.Header.Get("Content-Type")
if !strings.HasPrefix(mimeType, "image/") {
if mimeType != "application/octet-stream" && !strings.HasPrefix(mimeType, "image/") {
return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType)
}

View File

@@ -1,16 +1,20 @@
package service
import (
"github.com/gin-gonic/gin"
"one-api/dto"
relaycommon "one-api/relay/common"
"github.com/gin-gonic/gin"
)
func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio, modelPrice float64) map[string]interface{} {
func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
cacheTokens int, cacheRatio float64, modelPrice float64) map[string]interface{} {
other := make(map[string]interface{})
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
other["completion_ratio"] = completionRatio
other["cache_tokens"] = cacheTokens
other["cache_ratio"] = cacheRatio
other["model_price"] = modelPrice
other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli())
if relayInfo.ReasoningEffort != "" {
@@ -27,7 +31,7 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
}
func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice float64) map[string]interface{} {
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice)
info["ws"] = true
info["audio_input"] = usage.InputTokenDetails.AudioTokens
info["audio_output"] = usage.OutputTokenDetails.AudioTokens
@@ -39,7 +43,7 @@ func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, us
}
func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice float64) map[string]interface{} {
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice)
info["audio"] = true
info["audio_input"] = usage.PromptTokensDetails.AudioTokens
info["audio_output"] = usage.CompletionTokenDetails.AudioTokens

View File

@@ -3,8 +3,6 @@ package service
import (
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"math"
"one-api/common"
constant2 "one-api/constant"
"one-api/dto"
@@ -12,10 +10,14 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/setting"
"one-api/setting/operation_setting"
"strings"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
)
type TokenDetails struct {
@@ -35,24 +37,41 @@ type QuotaInfo struct {
func calculateAudioQuota(info QuotaInfo) int {
if info.UsePrice {
return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
modelPrice := decimal.NewFromFloat(info.ModelPrice)
quotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
groupRatio := decimal.NewFromFloat(info.GroupRatio)
quota := modelPrice.Mul(quotaPerUnit).Mul(groupRatio)
return int(quota.IntPart())
}
completionRatio := common.GetCompletionRatio(info.ModelName)
audioRatio := common.GetAudioRatio(info.ModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(info.ModelName)
ratio := info.GroupRatio * info.ModelRatio
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(info.ModelName))
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(info.ModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(info.ModelName))
quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio))
quota += int(math.Round(float64(info.InputDetails.AudioTokens)*audioRatio)) +
int(math.Round(float64(info.OutputDetails.AudioTokens)*audioRatio*audioCompletionRatio))
groupRatio := decimal.NewFromFloat(info.GroupRatio)
modelRatio := decimal.NewFromFloat(info.ModelRatio)
ratio := groupRatio.Mul(modelRatio)
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
inputTextTokens := decimal.NewFromInt(int64(info.InputDetails.TextTokens))
outputTextTokens := decimal.NewFromInt(int64(info.OutputDetails.TextTokens))
inputAudioTokens := decimal.NewFromInt(int64(info.InputDetails.AudioTokens))
outputAudioTokens := decimal.NewFromInt(int64(info.OutputDetails.AudioTokens))
quota := decimal.Zero
quota = quota.Add(inputTextTokens)
quota = quota.Add(outputTextTokens.Mul(completionRatio))
quota = quota.Add(inputAudioTokens.Mul(audioRatio))
quota = quota.Add(outputAudioTokens.Mul(audioRatio).Mul(audioCompletionRatio))
quota = quota.Mul(ratio)
// If ratio is not zero and quota is less than or equal to zero, set quota to 1
if !ratio.IsZero() && quota.LessThanOrEqual(decimal.Zero) {
quota = decimal.NewFromInt(1)
}
return quota
return int(quota.Round(0).IntPart())
}
func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
@@ -75,7 +94,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
audioInputTokens := usage.InputTokenDetails.AudioTokens
audioOutTokens := usage.OutputTokenDetails.AudioTokens
groupRatio := setting.GetGroupRatio(relayInfo.Group)
modelRatio := common.GetModelRatio(modelName)
modelRatio, _ := operation_setting.GetModelRatio(modelName)
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -122,9 +141,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
audioOutTokens := usage.OutputTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName)
audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(modelName))
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -146,7 +165,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
totalTokens := usage.TotalTokens
var logContent string
if !usePrice {
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}
@@ -168,7 +188,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
if extraContent != "" {
logContent += ", " + extraContent
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
@@ -184,9 +205,9 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName)
audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName)
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(relayInfo.OriginModelName))
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
@@ -213,7 +234,8 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
totalTokens := usage.TotalTokens
var logContent string
if !usePrice {
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}
@@ -242,7 +264,8 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
if extraContent != "" {
logContent += ", " + extraContent
}
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
@@ -276,7 +299,7 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
if quota > 0 {
err = model.DecreaseUserQuota(relayInfo.UserId, quota)
} else {
err = model.IncreaseUserQuota(relayInfo.UserId, -quota)
err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
}
if err != nil {
return err
@@ -295,20 +318,16 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
if sendEmail {
if (quota + preConsumedQuota) != 0 {
checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota)
checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota)
}
}
return nil
}
func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) {
gopool.Go(func() {
userCache, err := model.GetUserCache(userId)
if err != nil {
common.SysError("failed to get user cache: " + err.Error())
}
userSetting := userCache.GetSetting()
userSetting := relayInfo.UserSetting
threshold := common.QuotaRemindThreshold
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
threshold = int(userCustomThreshold.(float64))
@@ -317,16 +336,16 @@ func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
quotaTooLow := false
consumeQuota := quota + preConsumedQuota
if userCache.Quota-consumeQuota < threshold {
if relayInfo.UserQuota-consumeQuota < threshold {
quotaTooLow = true
}
if quotaTooLow {
prompt := "您的额度即将用尽"
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink}))
err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
if err != nil {
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error()))
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
}
}
})

View File

@@ -1,7 +1,6 @@
package service
import (
"encoding/json"
"errors"
"fmt"
"image"
@@ -11,6 +10,7 @@ import (
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/setting/operation_setting"
"strings"
"unicode/utf8"
@@ -33,7 +33,7 @@ func InitTokenEncoders() {
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
}
for model, _ := range common.GetDefaultModelRatioMap() {
for model, _ := range operation_setting.GetDefaultModelRatioMap() {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = cl100TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
@@ -78,6 +78,9 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
if text == "" {
return 0
}
return len(tokenEncoder.Encode(text, nil, nil))
}
@@ -167,12 +170,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
}
tkm += msgTokens
if request.Tools != nil {
toolsData, _ := json.Marshal(request.Tools)
var openaiTools []dto.OpenAITools
err := json.Unmarshal(toolsData, &openaiTools)
if err != nil {
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
}
openaiTools := request.Tools
countStr := ""
for _, tool := range openaiTools {
countStr = tool.Function.Name
@@ -282,30 +280,25 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role)
if len(message.Content) > 0 {
if message.IsStringContent() {
stringContent := message.StringContent()
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
} else {
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
}
}
}

View File

@@ -11,47 +11,48 @@ import (
func NotifyRootUser(t string, subject string, content string) {
user := model.GetRootUser().ToBaseUser()
_ = NotifyUser(user, dto.NewNotify(t, subject, content, nil))
err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
if err != nil {
common.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error()))
}
}
func NotifyUser(user *model.UserBase, data dto.Notify) error {
userSetting := user.GetSetting()
func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
notifyType, ok := userSetting[constant.UserSettingNotifyType]
if !ok {
notifyType = constant.NotifyTypeEmail
}
// Check notification limit
canSend, err := CheckNotificationLimit(user.Id, data.Type)
canSend, err := CheckNotificationLimit(userId, data.Type)
if err != nil {
common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
return err
}
if !canSend {
return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType)
return fmt.Errorf("notification limit exceeded for user %d with type %s", userId, notifyType)
}
switch notifyType {
case constant.NotifyTypeEmail:
userEmail := user.Email
// check setting email
if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
userEmail = settingEmail.(string)
}
if userEmail == "" {
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id))
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
return nil
}
return sendEmailNotify(userEmail, data)
case constant.NotifyTypeWebhook:
webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
if !ok {
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id))
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
return nil
}
webhookURLStr, ok := webhookURL.(string)
if !ok {
common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id))
common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId))
return nil
}

259
setting/config/config.go Normal file
View File

@@ -0,0 +1,259 @@
package config
import (
"encoding/json"
"one-api/common"
"reflect"
"strconv"
"strings"
"sync"
)
// ConfigManager 统一管理所有配置
type ConfigManager struct {
configs map[string]interface{}
mutex sync.RWMutex
}
var GlobalConfig = NewConfigManager()
func NewConfigManager() *ConfigManager {
return &ConfigManager{
configs: make(map[string]interface{}),
}
}
// Register 注册一个配置模块
func (cm *ConfigManager) Register(name string, config interface{}) {
cm.mutex.Lock()
defer cm.mutex.Unlock()
cm.configs[name] = config
}
// Get 获取指定配置模块
func (cm *ConfigManager) Get(name string) interface{} {
cm.mutex.RLock()
defer cm.mutex.RUnlock()
return cm.configs[name]
}
// LoadFromDB 从数据库加载配置
func (cm *ConfigManager) LoadFromDB(options map[string]string) error {
cm.mutex.Lock()
defer cm.mutex.Unlock()
for name, config := range cm.configs {
prefix := name + "."
configMap := make(map[string]string)
// 收集属于此配置的所有选项
for key, value := range options {
if strings.HasPrefix(key, prefix) {
configKey := strings.TrimPrefix(key, prefix)
configMap[configKey] = value
}
}
// 如果找到配置项,则更新配置
if len(configMap) > 0 {
if err := updateConfigFromMap(config, configMap); err != nil {
common.SysError("failed to update config " + name + ": " + err.Error())
continue
}
}
}
return nil
}
// SaveToDB 将配置保存到数据库
func (cm *ConfigManager) SaveToDB(updateFunc func(key, value string) error) error {
cm.mutex.RLock()
defer cm.mutex.RUnlock()
for name, config := range cm.configs {
configMap, err := configToMap(config)
if err != nil {
return err
}
for key, value := range configMap {
dbKey := name + "." + key
if err := updateFunc(dbKey, value); err != nil {
return err
}
}
}
return nil
}
// 辅助函数将配置对象转换为map
func configToMap(config interface{}) (map[string]string, error) {
result := make(map[string]string)
val := reflect.ValueOf(config)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return nil, nil
}
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
fieldType := typ.Field(i)
// 跳过未导出字段
if !fieldType.IsExported() {
continue
}
// 获取json标签作为键名
key := fieldType.Tag.Get("json")
if key == "" || key == "-" {
key = fieldType.Name
}
// 处理不同类型的字段
var strValue string
switch field.Kind() {
case reflect.String:
strValue = field.String()
case reflect.Bool:
strValue = strconv.FormatBool(field.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
strValue = strconv.FormatInt(field.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
strValue = strconv.FormatUint(field.Uint(), 10)
case reflect.Float32, reflect.Float64:
strValue = strconv.FormatFloat(field.Float(), 'f', -1, 64)
case reflect.Map, reflect.Slice, reflect.Struct:
// 复杂类型使用JSON序列化
bytes, err := json.Marshal(field.Interface())
if err != nil {
return nil, err
}
strValue = string(bytes)
default:
// 跳过不支持的类型
continue
}
result[key] = strValue
}
return result, nil
}
// 辅助函数从map更新配置对象
func updateConfigFromMap(config interface{}, configMap map[string]string) error {
val := reflect.ValueOf(config)
if val.Kind() != reflect.Ptr {
return nil
}
val = val.Elem()
if val.Kind() != reflect.Struct {
return nil
}
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
fieldType := typ.Field(i)
// 跳过未导出字段
if !fieldType.IsExported() {
continue
}
// 获取json标签作为键名
key := fieldType.Tag.Get("json")
if key == "" || key == "-" {
key = fieldType.Name
}
// 检查map中是否有对应的值
strValue, ok := configMap[key]
if !ok {
continue
}
// 根据字段类型设置值
if !field.CanSet() {
continue
}
switch field.Kind() {
case reflect.String:
field.SetString(strValue)
case reflect.Bool:
boolValue, err := strconv.ParseBool(strValue)
if err != nil {
continue
}
field.SetBool(boolValue)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, err := strconv.ParseInt(strValue, 10, 64)
if err != nil {
continue
}
field.SetInt(intValue)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
uintValue, err := strconv.ParseUint(strValue, 10, 64)
if err != nil {
continue
}
field.SetUint(uintValue)
case reflect.Float32, reflect.Float64:
floatValue, err := strconv.ParseFloat(strValue, 64)
if err != nil {
continue
}
field.SetFloat(floatValue)
case reflect.Map, reflect.Slice, reflect.Struct:
// 复杂类型使用JSON反序列化
err := json.Unmarshal([]byte(strValue), field.Addr().Interface())
if err != nil {
continue
}
}
}
return nil
}
// ConfigToMap 将配置对象转换为map导出函数
func ConfigToMap(config interface{}) (map[string]string, error) {
return configToMap(config)
}
// UpdateConfigFromMap 从map更新配置对象导出函数
func UpdateConfigFromMap(config interface{}, configMap map[string]string) error {
return updateConfigFromMap(config, configMap)
}
// ExportAllConfigs 导出所有已注册的配置为扁平结构
func (cm *ConfigManager) ExportAllConfigs() map[string]string {
cm.mutex.RLock()
defer cm.mutex.RUnlock()
result := make(map[string]string)
for name, cfg := range cm.configs {
configMap, err := ConfigToMap(cfg)
if err != nil {
continue
}
// 使用 "模块名.配置项" 的格式添加到结果中
for key, value := range configMap {
result[name+"."+key] = value
}
}
return result
}

View File

@@ -0,0 +1,65 @@
package model_setting
import (
"net/http"
"one-api/setting/config"
)
//var claudeHeadersSettings = map[string][]string{}
//
//var ClaudeThinkingAdapterEnabled = true
//var ClaudeThinkingAdapterMaxTokens = 8192
//var ClaudeThinkingAdapterBudgetTokensPercentage = 0.8
// ClaudeSettings 定义Claude模型的配置
type ClaudeSettings struct {
HeadersSettings map[string]map[string][]string `json:"model_headers_settings"`
DefaultMaxTokens map[string]int `json:"default_max_tokens"`
ThinkingAdapterEnabled bool `json:"thinking_adapter_enabled"`
ThinkingAdapterBudgetTokensPercentage float64 `json:"thinking_adapter_budget_tokens_percentage"`
}
// 默认配置
var defaultClaudeSettings = ClaudeSettings{
HeadersSettings: map[string]map[string][]string{},
ThinkingAdapterEnabled: true,
DefaultMaxTokens: map[string]int{
"default": 8192,
},
ThinkingAdapterBudgetTokensPercentage: 0.8,
}
// 全局实例
var claudeSettings = defaultClaudeSettings
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("claude", &claudeSettings)
}
// GetClaudeSettings 获取Claude配置
func GetClaudeSettings() *ClaudeSettings {
// check default max tokens must have default key
if _, ok := claudeSettings.DefaultMaxTokens["default"]; !ok {
claudeSettings.DefaultMaxTokens["default"] = 8192
}
return &claudeSettings
}
func (c *ClaudeSettings) WriteHeaders(originModel string, httpHeader *http.Header) {
if headers, ok := c.HeadersSettings[originModel]; ok {
for headerKey, headerValues := range headers {
httpHeader.Del(headerKey)
for _, headerValue := range headerValues {
httpHeader.Add(headerKey, headerValue)
}
}
}
}
func (c *ClaudeSettings) GetDefaultMaxTokens(model string) int {
if maxTokens, ok := c.DefaultMaxTokens[model]; ok {
return maxTokens
}
return c.DefaultMaxTokens["default"]
}

View File

@@ -0,0 +1,52 @@
package model_setting
import (
"one-api/setting/config"
)
// GeminiSettings 定义Gemini模型的配置
type GeminiSettings struct {
SafetySettings map[string]string `json:"safety_settings"`
VersionSettings map[string]string `json:"version_settings"`
}
// 默认配置
var defaultGeminiSettings = GeminiSettings{
SafetySettings: map[string]string{
"default": "OFF",
"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
},
VersionSettings: map[string]string{
"default": "v1beta",
"gemini-1.0-pro": "v1",
},
}
// 全局实例
var geminiSettings = defaultGeminiSettings
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("gemini", &geminiSettings)
}
// GetGeminiSettings 获取Gemini配置
func GetGeminiSettings() *GeminiSettings {
return &geminiSettings
}
// GetGeminiSafetySetting 获取安全设置
func GetGeminiSafetySetting(key string) string {
if value, ok := geminiSettings.SafetySettings[key]; ok {
return value
}
return geminiSettings.SafetySettings["default"]
}
// GetGeminiVersionSetting 获取版本设置
func GetGeminiVersionSetting(key string) string {
if value, ok := geminiSettings.VersionSettings[key]; ok {
return value
}
return geminiSettings.VersionSettings["default"]
}

View File

@@ -0,0 +1,84 @@
package operation_setting
import (
"encoding/json"
"one-api/common"
"sync"
)
var defaultCacheRatio = map[string]float64{
"gpt-4": 0.5,
"o1": 0.5,
"o1-2024-12-17": 0.5,
"o1-preview-2024-09-12": 0.5,
"o1-preview": 0.5,
"o1-mini-2024-09-12": 0.5,
"o1-mini": 0.5,
"gpt-4o-2024-11-20": 0.5,
"gpt-4o-2024-08-06": 0.5,
"gpt-4o": 0.5,
"gpt-4o-mini-2024-07-18": 0.5,
"gpt-4o-mini": 0.5,
"gpt-4o-realtime-preview": 0.5,
"gpt-4o-mini-realtime-preview": 0.5,
"deepseek-chat": 0.1,
"deepseek-reasoner": 0.1,
"deepseek-coder": 0.1,
}
var defaultCreateCacheRatio = map[string]float64{}
var cacheRatioMap map[string]float64
var cacheRatioMapMutex sync.RWMutex
// GetCacheRatioMap returns the cache ratio map
func GetCacheRatioMap() map[string]float64 {
cacheRatioMapMutex.Lock()
defer cacheRatioMapMutex.Unlock()
if cacheRatioMap == nil {
cacheRatioMap = defaultCacheRatio
}
return cacheRatioMap
}
// CacheRatio2JSONString converts the cache ratio map to a JSON string
func CacheRatio2JSONString() string {
GetCacheRatioMap()
jsonBytes, err := json.Marshal(cacheRatioMap)
if err != nil {
common.SysError("error marshalling cache ratio: " + err.Error())
}
return string(jsonBytes)
}
// UpdateCacheRatioByJSONString updates the cache ratio map from a JSON string
func UpdateCacheRatioByJSONString(jsonStr string) error {
cacheRatioMapMutex.Lock()
defer cacheRatioMapMutex.Unlock()
cacheRatioMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
}
// GetCacheRatio returns the cache ratio for a model
func GetCacheRatio(name string) (float64, bool) {
GetCacheRatioMap()
ratio, ok := cacheRatioMap[name]
if !ok {
return 1, false // Default to 0.5 if not found
}
return ratio, true
}
// DefaultCacheRatio2JSONString converts the default cache ratio map to a JSON string
func DefaultCacheRatio2JSONString() string {
jsonBytes, err := json.Marshal(defaultCacheRatio)
if err != nil {
common.SysError("error marshalling default cache ratio: " + err.Error())
}
return string(jsonBytes)
}
// GetDefaultCacheRatioMap returns the default cache ratio map
func GetDefaultCacheRatioMap() map[string]float64 {
return defaultCacheRatio
}

View File

@@ -0,0 +1,21 @@
package operation_setting
import "one-api/setting/config"
type GeneralSetting struct {
DocsLink string `json:"docs_link"`
}
// 默认配置
var generalSetting = GeneralSetting{
DocsLink: "https://docs.newapi.pro",
}
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("general_setting", &generalSetting)
}
func GetGeneralSetting() *GeneralSetting {
return &generalSetting
}

View File

@@ -1,7 +1,8 @@
package common
package operation_setting
import (
"encoding/json"
"one-api/common"
"strings"
"sync"
)
@@ -50,24 +51,26 @@ var defaultModelRatio = map[string]float64{
"gpt-4o-realtime-preview-2024-12-17": 2.5,
"gpt-4o-mini-realtime-preview": 0.3,
"gpt-4o-mini-realtime-preview-2024-12-17": 0.3,
"o1": 7.5,
"o1-2024-12-17": 7.5,
"o1-preview": 7.5,
"o1-preview-2024-09-12": 7.5,
"o1-mini": 0.55,
"o1-mini-2024-09-12": 0.55,
"o3-mini": 0.55,
"o3-mini-2025-01-31": 0.55,
"o3-mini-high": 0.55,
"o3-mini-2025-01-31-high": 0.55,
"o3-mini-low": 0.55,
"o3-mini-2025-01-31-low": 0.55,
"o3-mini-medium": 0.55,
"o3-mini-2025-01-31-medium": 0.55,
"gpt-4o-mini": 0.075,
"gpt-4o-mini-2024-07-18": 0.075,
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"o1": 7.5,
"o1-2024-12-17": 7.5,
"o1-preview": 7.5,
"o1-preview-2024-09-12": 7.5,
"o1-mini": 0.55,
"o1-mini-2024-09-12": 0.55,
"o3-mini": 0.55,
"o3-mini-2025-01-31": 0.55,
"o3-mini-high": 0.55,
"o3-mini-2025-01-31-high": 0.55,
"o3-mini-low": 0.55,
"o3-mini-2025-01-31-low": 0.55,
"o3-mini-medium": 0.55,
"o3-mini-2025-01-31-medium": 0.55,
"gpt-4o-mini": 0.075,
"gpt-4o-mini-2024-07-18": 0.075,
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-4.5-preview": 37.5,
"gpt-4.5-preview-2025-02-27": 37.5,
//"gpt-3.5-turbo-0301": 0.75, //deprecated
"gpt-3.5-turbo": 0.25,
"gpt-3.5-turbo-0613": 0.75,
@@ -83,92 +86,94 @@ var defaultModelRatio = map[string]float64{
"text-curie-001": 1,
//"text-davinci-002": 10,
//"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015
"tts-1-1106": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03
"tts-1-hd-1106": 15, // 1k characters -> $0.03
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5,
"claude-3-5-sonnet-20241022": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"glm-4-plus": 0.05 * RMB,
"glm-4-0520": 0.1 * RMB,
"glm-4-air": 0.001 * RMB,
"glm-4-airx": 0.01 * RMB,
"glm-4-long": 0.001 * RMB,
"glm-4-flash": 0,
"glm-4v-plus": 0.01 * RMB,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
"360gpt-pro": 0.8572, // ¥0.012 / 1k tokens
"360gpt2-pro": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015
"tts-1-1106": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03
"tts-1-hd-1106": 15, // 1k characters -> $0.03
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5,
"claude-3-5-sonnet-20241022": 1.5,
"claude-3-7-sonnet-20250219": 1.5,
"claude-3-7-sonnet-20250219-thinking": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"glm-4-plus": 0.05 * RMB,
"glm-4-0520": 0.1 * RMB,
"glm-4-air": 0.001 * RMB,
"glm-4-airx": 0.01 * RMB,
"glm-4-long": 0.001 * RMB,
"glm-4-flash": 0,
"glm-4v-plus": 0.01 * RMB,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // 0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
"360gpt-pro": 0.8572, // ¥0.012 / 1k tokens
"360gpt2-pro": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// https://platform.lingyiwanwu.com/docs#-计费单元
// 已经按照 7.2 来换算美元价格
"yi-34b-chat-0205": 0.18,
@@ -257,7 +262,7 @@ func ModelPrice2JSONString() string {
GetModelPriceMap()
jsonBytes, err := json.Marshal(modelPriceMap)
if err != nil {
SysError("error marshalling model price: " + err.Error())
common.SysError("error marshalling model price: " + err.Error())
}
return string(jsonBytes)
}
@@ -281,7 +286,7 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
price, ok := modelPriceMap[name]
if !ok {
if printErr {
SysError("model price not found: " + name)
common.SysError("model price not found: " + name)
}
return -1, false
}
@@ -301,7 +306,7 @@ func ModelRatio2JSONString() string {
GetModelRatioMap()
jsonBytes, err := json.Marshal(modelRatioMap)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
common.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -313,23 +318,22 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
}
func GetModelRatio(name string) float64 {
func GetModelRatio(name string) (float64, bool) {
GetModelRatioMap()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
ratio, ok := modelRatioMap[name]
if !ok {
SysError("model ratio not found: " + name)
return 30
return 37.5, SelfUseModeEnabled
}
return ratio
return ratio, true
}
func DefaultModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(defaultModelRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
common.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -351,7 +355,7 @@ func CompletionRatio2JSONString() string {
GetCompletionRatioMap()
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
SysError("error marshalling completion ratio: " + err.Error())
common.SysError("error marshalling completion ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -385,6 +389,9 @@ func GetCompletionRatio(name string) float64 {
}
return 4
}
if strings.HasPrefix(name, "gpt-4.5") {
return 2
}
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
return 3
}

View File

@@ -1,8 +1,9 @@
package setting
package operation_setting
import "strings"
var DemoSiteEnabled = false
var SelfUseModeEnabled = false
var AutomaticDisableKeywords = []string{
"Your credit balance is too low",

Some files were not shown because too many files have changed in this diff Show More