Compare commits

...

139 Commits

Author SHA1 Message Date
1808837298@qq.com
3a18c0ce9f fix: Correct model request configuration in Vertex Claude adaptor 2025-02-27 20:51:10 +08:00
1808837298@qq.com
929668bead feat: Refactor model configuration management with new config system
- Introduce a new configuration management approach for model-specific settings
- Update Gemini settings to use the new config system with more flexible management
- Add support for dynamic configuration updates in option handling
- Modify Claude and Vertex adaptors to use new configuration methods
- Enhance web interface to support namespaced configuration keys
2025-02-27 20:49:34 +08:00
1808837298@qq.com
06a78f9042 feat: Add Claude model configuration management #791 2025-02-27 20:49:21 +08:00
1808837298@qq.com
0f1c4c4ebe fix: Add pagination support to user search functionality 2025-02-27 16:55:02 +08:00
1808837298@qq.com
1bcf7a3c39 chore: Update Azure OpenAI API version and embedding model detection
- Enhance channel test to detect more embedding models
- Update Azure OpenAI default API version to 2024-12-01-preview
- Remove redundant default API version setting in channel edit
- Add user cache writing in channel test
2025-02-27 16:49:32 +08:00
1808837298@qq.com
5f0b3f6d6f fix: Improve AWS Claude adaptor request conversion error handling #796 2025-02-27 14:57:00 +08:00
1808837298@qq.com
19a318c943 init openrouter adaptor 2025-02-27 00:01:21 +08:00
1808837298@qq.com
13ab0f8e4f fix: gemini&claude tool call format #795 #766 2025-02-26 23:56:10 +08:00
1808837298@qq.com
6d8d40e67b fix: claude tool call format #795 #766 2025-02-26 23:40:16 +08:00
1808837298@qq.com
287caf8e38 feat: Add Jina reranking support for OpenAI adaptor 2025-02-26 21:46:06 +08:00
1808837298@qq.com
c802b3b41a fix: Update Gemini safety settings to use 'OFF' as default 2025-02-26 19:20:17 +08:00
1808837298@qq.com
ed4e1c2332 fix: Update Gemini safety settings category 2025-02-26 19:18:00 +08:00
1808837298@qq.com
e581ea33c2 fix: Update Gemini safety settings default value 2025-02-26 19:01:45 +08:00
1808837298@qq.com
bf80d71ddf feat: Add Gemini version settings configuration support (close #568) 2025-02-26 18:19:09 +08:00
1808837298@qq.com
e19b244e73 feat: Add Gemini safety settings configuration support (close #703) 2025-02-26 16:54:43 +08:00
1808837298@qq.com
f451268830 feat: Update Claude relay temperature setting 2025-02-25 22:01:05 +08:00
1808837298@qq.com
069f2672c1 refactor: Enhance user context and quota management
- Add new context keys for user-related information
- Modify user cache and authentication middleware to populate context
- Refactor quota and notification services to use context-based user data
- Remove redundant database queries by leveraging context information
- Update various components to use new context-based user retrieval methods
2025-02-25 20:56:16 +08:00
1808837298@qq.com
ccf13d445f feat: redis poolsize 2025-02-25 19:39:29 +08:00
1808837298@qq.com
da4d1861fe fix: Adjust Claude thinking mode request parameters 2025-02-25 16:52:45 +08:00
1808837298@qq.com
3de5b96cb4 docs: Update README 2025-02-25 16:31:42 +08:00
Calcium-Ion
5b9e275690 Merge pull request #788 from MartialBE/main
feat: Add Claude 3.7 Sonnet thinking mode support
2025-02-25 15:21:39 +08:00
1808837298@qq.com
607e3206b3 Merge branch 'main' into thinking
# Conflicts:
#	relay/channel/claude/dto.go
2025-02-25 15:21:22 +08:00
1808837298@qq.com
83feb492fb feat: Add support for Claude thinking parameter in request 2025-02-25 14:37:03 +08:00
MartialBE
4f212be45c feat: Add Claude 3.7 Sonnet thinking mode support 2025-02-25 14:10:43 +08:00
1808837298@qq.com
92918e3751 feat: Add Claude 3.7 Sonnet model to AWS channel mapping 2025-02-25 02:55:23 +08:00
1808837298@qq.com
de15551570 feat: Add support for Claude 3.7 Sonnet model 2025-02-25 02:51:31 +08:00
1808837298@qq.com
a81a28b7a5 feat: Support max_tokens parameter for Ollama channel #782 2025-02-24 17:35:49 +08:00
Calcium-Ion
dc36fdedc2 Merge pull request #781 from zeyugao/main
feat: Pass extra_body in OpenAI request to the backend
2025-02-24 16:29:48 +08:00
Calcium-Ion
3017882fa3 Merge pull request #783 from Calcium-Ion/rate-limit
feat: Add model request rate limiting functionality
2025-02-24 16:29:23 +08:00
1808837298@qq.com
e9ba392af8 feat: Add model rate limit settings in system configuration 2025-02-24 16:27:20 +08:00
1808837298@qq.com
83a37e4653 feat: Add model request rate limiting functionality 2025-02-24 16:20:55 +08:00
1808837298@qq.com
b6f95dca41 feat: Add support for different Dify bot types and request URLs 2025-02-24 14:18:30 +08:00
1808837298@qq.com
7ff4cebdbe feat: Enhance token counting and content parsing for messages 2025-02-24 14:18:15 +08:00
Elsa
af00f7b311 Pass extra_body to the backend 2025-02-24 10:52:55 +08:00
1808837298@qq.com
cc1d6e1c05 fix: Improve 429 error logging with detailed message 2025-02-23 21:26:31 +08:00
1808837298@qq.com
6c7a8c811c fix typo 2025-02-23 17:27:33 +08:00
1808837298@qq.com
d5ab7d2d34 feat: Add thinking-to-content option in channel extra settings #780 2025-02-23 17:13:08 +08:00
1808837298@qq.com
115a181db3 feat: Add thinking-to-content conversion for stream responses 2025-02-23 17:05:57 +08:00
1808837298@qq.com
88a2fec190 fix: mistral 2025-02-22 16:29:48 +08:00
1808837298@qq.com
27ea231d66 fix: fix image ratio calculation 2025-02-22 15:50:18 +08:00
Calcium-Ion
4b6101b3ea Merge pull request #778 from utopeadia/main
美化日志界面刷新图标
2025-02-22 15:21:28 +08:00
1808837298@qq.com
48926b8a5a fix: Ensure correct quota warning threshold type conversion 2025-02-22 15:19:55 +08:00
1808837298@qq.com
c44a32efe0 chore: update rerank.md 2025-02-22 15:13:26 +08:00
HowieWood
c541d6c97e 进一步美化刷新图标 2025-02-22 14:18:25 +08:00
HowieWood
7dfcd135da 优化日志刷新图标显示 2025-02-22 14:12:49 +08:00
1808837298@qq.com
7a13fab271 fix: ShouldDisableChannel 2025-02-22 02:02:03 +08:00
1808837298@qq.com
bf75b30870 fix: mistral adaptor (close #774) 2025-02-21 22:21:19 +08:00
1808837298@qq.com
6e7587ab46 feat: Add reasoning content support in OpenAI response handling 2025-02-21 18:52:51 +08:00
1808837298@qq.com
cc5066c510 refactor: Improve message content parsing with robust type handling 2025-02-21 18:27:43 +08:00
1808837298@qq.com
b9b69b01e5 refactor: Improve message content handling and quota error responses 2025-02-21 18:18:21 +08:00
1808837298@qq.com
1f4f9123aa refactor: Optimize sensitive word detection and text processing 2025-02-21 17:05:35 +08:00
1808837298@qq.com
9cc6385b0c feat: Enhance sensitive word detection with detailed logging 2025-02-21 16:57:30 +08:00
1808837298@qq.com
2d42145b66 refactor: Improve quota error messages with formatted quota display 2025-02-21 16:42:48 +08:00
1808837298@qq.com
94736407a0 feat: Add base URL input with localized tooltip for channel configuration 2025-02-21 16:17:59 +08:00
1808837298@qq.com
de859c3cc9 feat: Add localization for notification and webhook settings 2025-02-21 15:36:24 +08:00
Calcium-Ion
8dd4ce986c Merge pull request #775 from Calcium-Ion/model_mappping
refactor: Simplify model mapping and pricing logic across relay modules
2025-02-20 16:42:23 +08:00
1808837298@qq.com
06da65a9d0 refactor: Simplify model mapping and pricing logic across relay modules 2025-02-20 16:41:46 +08:00
1808837298@qq.com
60aac77c08 fix: Correct Ollama channel authentication header setting 2025-02-20 01:28:15 +08:00
Calcium-Ion
6e0046f73c Merge pull request #773 from wellcoming/patch-1
fix: Fix Ollama channel authentication
2025-02-20 01:26:12 +08:00
Coming
a13f4d6c56 fix: Fix Ollama channel authentication 2025-02-20 00:52:30 +08:00
CalciumIon
4ce12ea6e3 feat: Improve mobile text truncation and sidebar visibility 2025-02-19 23:25:42 +08:00
1808837298@qq.com
971aea09ee feat: Improve image handling for Ollama channels 2025-02-19 20:45:42 +08:00
1808837298@qq.com
a4b2b9c935 feat: Enhance Ollama channel support with additional request parameters #771 2025-02-19 19:58:34 +08:00
1808837298@qq.com
ae5875d4c7 fix: Remove redundant error handling in distributor and relay modules 2025-02-19 18:47:28 +08:00
1808837298@qq.com
5937d850d9 refactor: Replace manual goroutine creation with gopool.Go 2025-02-19 18:38:29 +08:00
Calcium-Ion
2b7435500c Merge pull request #770 from Calcium-Ion/refactor_notify
feat: Add user notification settings and multiple notification methods
2025-02-19 14:54:54 +07:00
1808837298@qq.com
90191b8d5b chore: update env name and README 2025-02-19 15:54:33 +08:00
1808837298@qq.com
585c19fc70 docs: Add proxy usage information note in SystemSetting component 2025-02-19 15:45:09 +08:00
1808837298@qq.com
4e871507cf feat: Implement comprehensive webhook notification system 2025-02-19 15:40:54 +08:00
1808837298@qq.com
b1847509a4 refactor: Optimize user caching and token retrieval methods 2025-02-19 15:12:26 +08:00
Calcium-Ion
63f3412394 Merge pull request #768 from lgphone/main
bugfix: 配置文件 .env.example 示例配置错误
2025-02-18 19:35:08 +07:00
lgphone
a13bea5ffa Update .env.example
修复示例配置中MySQL的DSN错误问题
2025-02-18 19:18:54 +08:00
Calcium-Ion
2e3b920a2c Merge pull request #763 from Sh1n3zZ/support-imagen-3.0-generate-002
feat: add Gemini Imagen image generation support
2025-02-18 15:32:32 +07:00
1808837298@qq.com
812c188ab1 fix: Extend temperature handling for OpenAI-like models
- Add support for suppressing temperature for o1 models
- Expand model prefix check to include 'o1' alongside 'o3' models
2025-02-18 16:00:56 +08:00
1808837298@qq.com
0907a078b4 refactor: Simplify root user notification and remove global email variable
- Remove global `RootUserEmail` variable
- Modify channel testing and user notification methods to use `GetRootUser()`
- Update user cache and notification service to use more consistent user base type
- Add new channel test notification type
2025-02-18 15:59:17 +08:00
1808837298@qq.com
56f6b2ab56 feat: Implement notification rate limiting mechanism
- Add in-memory and Redis-based notification rate limiting
- Create configurable hourly notification limits
- Implement notification limit checking for user notifications
- Add environment variables for customizing notification limits
2025-02-18 15:30:43 +08:00
1808837298@qq.com
9d9c461c48 refactor: Improve CompletionRatio handling with thread-safe access and initialization 2025-02-18 15:01:43 +08:00
1808837298@qq.com
3da1344897 feat: Add user notification settings with quota warning and multiple notification methods
- Implement user notification settings with email and webhook options
- Add new user settings for quota warning threshold and notification preferences
- Create backend API and database support for user notification configuration
- Enhance frontend personal settings with notification configuration UI
- Support custom notification email and webhook URL
- Add service layer for sending user notifications
2025-02-18 14:54:21 +08:00
Sh1n3zZ
61d2a2f92d feat: add Gemini Imagen image generation support 2025-02-18 01:41:58 +08:00
1808837298@qq.com
995b3a2403 Merge remote-tracking branch 'origin/main' 2025-02-17 18:15:13 +08:00
1808837298@qq.com
7b384cb933 feat: Add support for DeepSeek completions endpoint 2025-02-17 18:15:01 +08:00
Calcium-Ion
78f19d4690 Merge pull request #735 from jyc001/main
feat:Add Supoorts to FIM
2025-02-17 14:37:06 +07:00
1808837298@qq.com
3239c60535 refactor: Optimize channel testing and model menu generation (fix #761) 2025-02-15 19:12:28 +08:00
1808837298@qq.com
e6f4587f6f refactor: Improve channel property update mechanism (fix #761) 2025-02-15 15:30:55 +08:00
Calcium-Ion
814be84500 Merge pull request #759 from nightcoffee/patch-1
feat: add 火山引擎 support stream options
2025-02-15 14:22:04 +07:00
nightcoffee
e7e5a16767 feat: add 火山引擎 support stream options 2025-02-15 04:55:57 +08:00
1808837298@qq.com
6bf99f218c feat: Enhance VolcEngine channel support with bot model routing (fix #757) 2025-02-15 00:10:58 +08:00
1808837298@qq.com
bd4ce9cd91 fix: Improve OpenAI stream data parsing and handling 2025-02-14 23:52:25 +08:00
1808837298@qq.com
9edb9f7a71 feat: Add automatic channel disabling based on configurable keywords
- Introduce AutomaticDisableKeywords setting to dynamically control channel disabling
- Implement AC search for matching error messages against disable keywords
- Add frontend UI for configuring automatic disable keywords
- Update localization with new keyword-based channel disabling feature
- Refactor sensitive word and AC search logic to support multiple keyword lists
2025-02-13 16:39:17 +08:00
1808837298@qq.com
bc62d1bb81 refactor: Optimize log retrieval with separate channel name fetching (fix #751)
- Remove inline channel join in log queries
- Implement separate channel name lookup for logs
- Improve performance by fetching channel names in a single query
- Ensure channel names are correctly associated with logs
2025-02-12 19:19:13 +08:00
1808837298@qq.com
6b923ef728 feat: Add invite link banner for specific channel type 2025-02-12 17:48:48 +08:00
1808837298@qq.com
81591f20e0 refactor: Optimize Dockerfile for Go build process
- Use alpine-based Golang image for smaller build size
- Simplify Go build command by removing static linking flag
- Improve Docker multi-stage build configuration
2025-02-12 17:18:23 +08:00
1808837298@qq.com
2072376694 docs: Update README with detailed Docker deployment and update instructions 2025-02-12 16:54:53 +08:00
1808837298@qq.com
871d73ecc9 fix: Update BaseURL placeholder text and label in channel edit page 2025-02-12 15:39:18 +08:00
1808837298@qq.com
f5e3063f33 feat: Improve embedding request handling and support across channels
- Update EmbeddingRequest DTO to support more flexible input types
- Add input parsing method to handle various input formats
- Implement ConvertEmbeddingRequest for multiple channel adaptors
- Remove relayMode parameter from EmbeddingHelper
- Add input validation for embedding requests
- Simplify embedding request conversion for different channels
2025-02-12 14:39:36 +08:00
1808837298@qq.com
eceb6afcdd feat: Add Baidu Qianfan V2 channel support #725
- Update channel constants to include Baidu V2 channel
- Create new Baidu V2 adaptor for relay
- Add Baidu V2 models and channel configuration
- Update relay adaptor to support Baidu V2 channel
- Modify web channel constants to include Baidu V2 option
2025-02-12 00:07:02 +08:00
1808837298@qq.com
28c13e5a0f feat: Add support for VolcEngine (Doubao) channel #313 #734 2025-02-11 23:47:15 +08:00
Calcium-Ion
81d11e5d31 Merge pull request #714 from NitroRCr/main
feat:  添加 AIaW 的聊天链接
2025-02-11 22:17:49 +07:00
Calcium-Ion
88bdedd2c9 Merge pull request #723 from kuwork/main
Support for MokaAI M3E
2025-02-11 22:16:18 +07:00
1808837298@qq.com
cf0ff0371b fix: adjust max tokens configuration in test request builder
- Update max tokens default value to 10
2025-02-11 20:00:05 +08:00
1808837298@qq.com
1f527ffc50 feat: enhance OpenAI request and response DTOs
- Add `Prefix` and `ReasoningContent` fields to Message struct
- Add getter and setter methods for `Prefix`
- Make `ToolCall.ID` field optional (fix #749)
2025-02-11 19:54:54 +08:00
1808837298@qq.com
cad8a83260 chore: disable cgo 2025-02-11 18:51:27 +08:00
1808837298@qq.com
40d878e8a9 chore: disable cgo 2025-02-11 18:51:09 +08:00
1808837298@qq.com
3a2e22443f chore: replace sqlite lib with prue go lib 2025-02-11 18:34:34 +08:00
1808837298@qq.com
13d1b8203c chore: update CI 2025-02-11 18:23:20 +08:00
1808837298@qq.com
7fce084aa5 update CI 2025-02-11 17:44:54 +08:00
1808837298@qq.com
cb4d40c3c8 feat: enhance session store security and configuration
- Add 30-day max age for session cookies
- Enable HttpOnly flag
- Set SameSite to strict mode
2025-02-11 17:06:51 +08:00
1808837298@qq.com
bbc1550a9e fix: update session store configuration
- Change session cookie path from "/api" to "/"
- Remove HttpOnly flag
2025-02-11 15:53:15 +08:00
1808837298@qq.com
6acc37cf27 feat: configure session store options for API routes
- Set session cookie path to "/api"
- Disable secure flag for local development
- Enable HttpOnly flag for improved security
2025-02-11 15:45:24 +08:00
Calcium-Ion
0e89939a12 Merge pull request #746 from zjjxwhh/main
fix: always use modelMapping in channel test
2025-02-11 12:21:06 +07:00
1808837298@qq.com
1b4fe8600e chore: update CI 2025-02-11 13:14:38 +08:00
zjjxwhh
882c5970d9 fix: always use modelMapping in channel test 2025-02-10 22:39:56 +08:00
1808837298@qq.com
d10b47005c chore: update CI 2025-02-10 21:59:41 +08:00
1808837298@qq.com
8418dbe7c4 fix: replace context-based user ID with session-based retrieval #741
- Update user and wechat controllers to use sessions for user ID
- Modify ID retrieval to use `session.Get("id")` instead of `c.GetInt("id")`
- Cast session ID to int when creating user object
2025-02-10 20:52:33 +08:00
1808837298@qq.com
68c559c119 fix: CI #744 2025-02-10 20:39:04 +08:00
1808837298@qq.com
2c2d1da227 Merge remote-tracking branch 'origin/main' 2025-02-10 20:34:11 +08:00
1808837298@qq.com
39aacf5fb6 refactor: improve SSE response handling in Playground
- Simplify event listener logic for streaming responses
- Add null-safe checks for payload content
- Optimize message generation and completion flow
2025-02-10 20:24:14 +08:00
Calcium-Ion
ec50f665a7 Merge pull request #736 from xy3xy3/main
更正硅基流动的SenseVoiceSmall模型名字
2025-02-09 12:23:34 +07:00
Calcium-Ion
1a09b1aed6 Merge pull request #742 from HynoR/chore/ds
chore: 同步deepseek价格
2025-02-09 12:23:10 +07:00
HynoR
34fdac38bf chore: 同步deepseek价格 2025-02-09 12:35:37 +08:00
xy3
8910efb1da 更正硅基流动的SenseVoiceSmall模型名字 2025-02-08 11:54:08 +08:00
e.
206dbfa45e Merge pull request #2 from jyc001/dev
fix: correct JSON tags for `Prompt` and `Suffix` in `GeneralOpenAIReq…
2025-02-08 00:37:37 +08:00
e.
1eb72f2f22 fix: correct JSON tags for Prompt and Suffix in GeneralOpenAIRequest 2025-02-08 00:36:42 +08:00
e.
68bd7f70a4 Merge pull request #1 from jyc001/dev
Dev
2025-02-08 00:25:49 +08:00
e.
8082905184 feat: add Suffix to GeneralOpenAIRequest in order to support FIM 2025-02-08 00:25:08 +08:00
e.
ce4269955e feat add FIM support for siliconflow 2025-02-08 00:23:35 +08:00
1808837298@qq.com
70083ecd27 fix: channels model_mapping 2025-02-06 19:51:33 +08:00
1808837298@qq.com
f7a4016d53 fix: update logs table total count display
- Replace `logs.length` with `logCount` in pagination information
- Ensure accurate total log count is displayed in the logs table
2025-02-06 14:56:23 +08:00
Calcium-Ion
562c66330c Merge pull request #727 from HynoR/feat/autogemini
chore: 同步gemini模型
2025-02-06 13:43:13 +07:00
1808837298@qq.com
675e62d854 feat: modify channel model_mapping column type to TEXT
- Change `ModelMapping` column type from varchar(1024) to TEXT in channels table
- Add MySQL migration script to alter column type during database initialization
- Improve database schema flexibility for storing complex model mappings
2025-02-06 14:35:14 +08:00
HynoR
efdd6fb657 chore: sync gemini aistudio model 2025-02-06 13:32:19 +08:00
kuwork
89d48a6618 Merge branch 'main' into main 2025-02-04 22:52:37 +08:00
1808837298@qq.com
0f5c090ad6 feat: add SOCKS5 proxy authentication support
- Enhance `NewProxyHttpClient` to handle SOCKS5 proxy authentication
- Extract username and password from proxy URL for SOCKS5 proxy configuration
- Provide optional authentication for SOCKS5 proxy connections
2025-02-04 18:10:25 +08:00
1808837298@qq.com
a0fe527047 feat: add demo site configuration flag
- Introduce `DemoSiteEnabled` variable in operation settings
- Provide a configurable flag to enable/disable demo site functionality
2025-02-04 14:15:01 +08:00
1808837298@qq.com
187c336121 feat: add Azure default API version configuration
- Introduce `AZURE_DEFAULT_API_VERSION` environment variable
- Set default Azure API version to `2024-12-01-preview`
- Update README documentation for new environment configuration
- Modify Azure channel relay to use default API version when not specified
2025-02-03 22:38:23 +08:00
NitroRCr
324d127a88 feat: add chat link for AIaW 2025-01-25 11:57:54 +08:00
Jerry
7588c42b42 Fix M3E not working 2025-01-23 05:54:39 +08:00
Jerry
8a2d220cf4 fix : chanel test did not refresh 2025-01-22 13:16:06 +08:00
Jerry
126f04e08f Support for MokaAI M3E 2025-01-22 04:21:08 +08:00
151 changed files with 5102 additions and 2841 deletions

View File

@@ -10,9 +10,9 @@
# 数据库相关配置 # 数据库相关配置
# 数据库连接字符串 # 数据库连接字符串
# SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true # SQL_DSN=user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
# 日志数据库连接字符串 # 日志数据库连接字符串
# LOG_SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true # LOG_SQL_DSN=user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
# SQLite数据库路径 # SQLite数据库路径
# SQLITE_PATH=/path/to/sqlite.db # SQLITE_PATH=/path/to/sqlite.db
# 数据库最大空闲连接数 # 数据库最大空闲连接数
@@ -50,10 +50,6 @@
# CHANNEL_TEST_FREQUENCY=10 # CHANNEL_TEST_FREQUENCY=10
# 生成默认token # 生成默认token
# GENERATE_DEFAULT_TOKEN=false # GENERATE_DEFAULT_TOKEN=false
# Gemini 安全设置
# GEMINI_SAFETY_SETTING=BLOCK_NONE
# Gemini版本设置
# GEMINI_MODEL_MAP=gemini-1.0-pro:v1
# Cohere 安全设置 # Cohere 安全设置
# COHERE_SAFETY_SETTING=NONE # COHERE_SAFETY_SETTING=NONE
# 是否统计图片token # 是否统计图片token

View File

@@ -13,7 +13,7 @@ on:
jobs: jobs:
push_to_registries: push_to_registries:
name: Push Docker image to multiple registries name: Push Docker image to multiple registries
runs-on: self-hosted runs-on: ubuntu-latest
permissions: permissions:
packages: write packages: write
contents: read contents: read

View File

@@ -9,7 +9,7 @@ on:
- '!*-alpha*' - '!*-alpha*'
jobs: jobs:
release: release:
runs-on: ubuntu-latest runs-on: macos-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v3

View File

@@ -9,7 +9,7 @@ on:
- '!*-alpha*' - '!*-alpha*'
jobs: jobs:
release: release:
runs-on: ubuntu-latest runs-on: windows-latest
defaults: defaults:
run: run:
shell: bash shell: bash

View File

@@ -1,4 +1,4 @@
FROM oven/bun:latest as builder FROM oven/bun:latest AS builder
WORKDIR /build WORKDIR /build
COPY web/package.json . COPY web/package.json .
@@ -7,18 +7,20 @@ COPY ./web .
COPY ./VERSION . COPY ./VERSION .
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
FROM golang AS builder2 FROM golang:alpine AS builder2
ENV GO111MODULE=on \ ENV GO111MODULE=on \
CGO_ENABLED=1 \ CGO_ENABLED=0 \
GOOS=linux GOOS=linux
WORKDIR /build WORKDIR /build
ADD go.mod go.sum ./ ADD go.mod go.sum ./
RUN go mod download RUN go mod download
COPY . . COPY . .
COPY --from=builder /build/dist ./web/dist COPY --from=builder /build/dist ./web/dist
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-api
FROM alpine FROM alpine

View File

@@ -59,6 +59,12 @@
13. 🎵 Added [Suno API](https://github.com/Suno-API/Suno-API) interface support, [Integration Guide](Suno.md) 13. 🎵 Added [Suno API](https://github.com/Suno-API/Suno-API) interface support, [Integration Guide](Suno.md)
14. 🔄 Support for Rerank models, compatible with Cohere and Jina, can integrate with Dify, [Integration Guide](Rerank.md) 14. 🔄 Support for Rerank models, compatible with Cohere and Jina, can integrate with Dify, [Integration Guide](Rerank.md)
15.**[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - Support for OpenAI's Realtime API, including Azure channels 15.**[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - Support for OpenAI's Realtime API, including Azure channels
16. 🧠 Support for setting reasoning effort through model name suffix:
- Add suffix `-high` to set high reasoning effort (e.g., `o3-mini-high`)
- Add suffix `-medium` to set medium reasoning effort
- Add suffix `-low` to set low reasoning effort
17. 🔄 Thinking to content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default is `false`, when `true`, the `reasoning_content` of the thinking content will be converted to `<think>` tags and concatenated to the content returned.
18. 🔄 Model rate limit, support setting total request limit and successful request limit in `System Settings->Rate Limit Settings`
## Model Support ## Model Support
This version additionally supports: This version additionally supports:
@@ -84,15 +90,15 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
- `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable - `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20` - `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
- `CRYPTO_SECRET`: Encryption key for encrypting database content - `CRYPTO_SECRET`: Encryption key for encrypting database content
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10`
- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2`
## Deployment ## Deployment
> [!TIP] > [!TIP]
> Latest Docker image: `calciumion/new-api:latest` > Latest Docker image: `calciumion/new-api:latest`
> Default account: root, password: 123456 > Default account: root, password: 123456
> Update command:
> ```
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
> ```
### Multi-Server Deployment ### Multi-Server Deployment
- Must set `SESSION_SECRET` environment variable, otherwise login state will not be consistent across multiple servers. - Must set `SESSION_SECRET` environment variable, otherwise login state will not be consistent across multiple servers.
@@ -102,26 +108,58 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
- Local database (default): SQLite (Docker deployment must mount `/data` directory) - Local database (default): SQLite (Docker deployment must mount `/data` directory)
- Remote database: MySQL >= 5.7.8, PgSQL >= 9.6 - Remote database: MySQL >= 5.7.8, PgSQL >= 9.6
### Deployment with BT Panel
Install BT Panel (**version 9.2.0** or above) from [BT Panel Official Website](https://www.bt.cn/new/download.html), choose the stable version script to download and install.
After installation, log in to BT Panel and click Docker in the menu bar. First-time access will prompt to install Docker service. Click Install Now and follow the prompts to complete installation.
After installation, find **New-API** in the app store, click install, configure basic options to complete installation.
[Pictorial Guide](BT.md)
### Docker Deployment ### Docker Deployment
### Using Docker Compose (Recommended) ### Using Docker Compose (Recommended)
```shell ```shell
# Clone project # Clone project
git clone https://github.com/Calcium-Ion/new-api.git git clone https://github.com/Calcium-Ion/new-api.git
cd new-api cd new-api
# Edit docker-compose.yml as needed # Edit docker-compose.yml as needed
# nano docker-compose.yml
# vim docker-compose.yml
# Start # Start
docker-compose up -d docker-compose up -d
``` ```
#### Update Version
```shell
docker-compose pull
docker-compose up -d
```
### Direct Docker Image Usage ### Direct Docker Image Usage
```shell ```shell
# SQLite deployment: # SQLite deployment:
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
# MySQL deployment (add -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"), modify database connection parameters as needed # MySQL deployment (add -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"), modify database connection parameters as needed
# Example: # Example:
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
``` ```
#### Update Version
```shell
# Pull the latest image
docker pull calciumion/new-api:latest
# Stop and remove the old container
docker stop new-api
docker rm new-api
# Run the new container with the same parameters as before
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
Alternatively, you can use Watchtower for automatic updates (not recommended, may cause database incompatibility):
```shell
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
```
## Channel Retry ## Channel Retry
Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**. Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**.
First retry uses same priority, second retry uses next priority, and so on. First retry uses same priority, second retry uses next priority, and so on.

View File

@@ -66,9 +66,14 @@
15.**[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API支持Azure渠道 15.**[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API支持Azure渠道
16. 支持使用路由/chat2link 进入聊天界面 16. 支持使用路由/chat2link 进入聊天界面
17. 🧠 支持通过模型名称后缀设置 reasoning effort 17. 🧠 支持通过模型名称后缀设置 reasoning effort
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`) 1. OpenAI o系列模型
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`) - 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`) - 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
2. Claude 思考模型
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
18. 🔄 思考转内容,支持在 `渠道-编辑-渠道额外设置` 中设置 `thinking_to_content` 选项,默认`false`,开启后会将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回。
19. 🔄 模型限流,支持在 `系统设置-速率限制设置` 中设置模型限流,支持设置总请求数限制和成功请求数限制
## 模型支持 ## 模型支持
此版本额外支持以下模型: 此版本额外支持以下模型:
@@ -89,19 +94,23 @@
- `GET_MEDIA_TOKEN`是否统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。 - `GET_MEDIA_TOKEN`是否统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true` - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。 - `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用"模型:版本"指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认为 `NONE` - `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认为 `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`Gemini模型最大图片数量默认为 `16`,设置为 `-1` 则不限制。 - `GEMINI_VISION_MAX_IMAGE_NUM`Gemini模型最大图片数量默认为 `16`,设置为 `-1` 则不限制。
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB默认为 `20` - `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB默认为 `20`
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。 - `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
- `AZURE_DEFAULT_API_VERSION`Azure渠道默认API版本如果渠道设置中未指定API版本则使用此版本默认为 `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`
## 已废弃的环境变量
- ~~`GEMINI_MODEL_MAP`(已废弃)~~:改为到`设置-模型相关设置`中设置
- ~~`GEMINI_SAFETY_SETTING`(已废弃)~~:改为到`设置-模型相关设置`中设置
## 部署 ## 部署
> [!TIP] > [!TIP]
> 最新版Docker镜像`calciumion/new-api:latest` > 最新版Docker镜像`calciumion/new-api:latest`
> 默认账号root 密码123456 > 默认账号root 密码123456
> 更新指令:
> ```
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
> ```
### 多机部署 ### 多机部署
- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致。 - 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致。
@@ -118,25 +127,54 @@
[图文教程](BT.md) [图文教程](BT.md)
### 基于 Docker 进行部署 ### 基于 Docker 进行部署
> [!TIP]
> 默认管理员账号root 密码123456
### 使用 Docker Compose 部署(推荐) ### 使用 Docker Compose 部署(推荐)
```shell ```shell
# 下载项目 # 下载项目
git clone https://github.com/Calcium-Ion/new-api.git git clone https://github.com/Calcium-Ion/new-api.git
cd new-api cd new-api
# 按需编辑 docker-compose.yml # 按需编辑 docker-compose.yml
# nano docker-compose.yml
# vim docker-compose.yml
# 启动 # 启动
docker-compose up -d docker-compose up -d
``` ```
#### 更新版本
```shell
docker-compose pull
docker-compose up -d
```
### 直接使用 Docker 镜像 ### 直接使用 Docker 镜像
```shell ```shell
# 使用 SQLite 的部署命令: # 使用 SQLite 的部署命令:
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。 # 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。
# 例如: # 例如:
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
``` ```
#### 更新版本
```shell
# 拉取最新镜像
docker pull calciumion/new-api:latest
# 停止并删除旧容器
docker stop new-api
docker rm new-api
# 使用相同参数运行新容器
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
或者使用 Watchtower 自动更新(不推荐,可能会导致数据库不兼容):
```shell
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
```
## 渠道重试 ## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。 渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。 如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。

View File

@@ -13,7 +13,7 @@ Request:
```json ```json
{ {
"model": "rerank-multilingual-v3.0", "model": "jina-reranker-v2-base-multilingual",
"query": "What is the capital of the United States?", "query": "What is the capital of the United States?",
"top_n": 3, "top_n": 3,
"documents": [ "documents": [

View File

@@ -101,7 +101,7 @@ var PreConsumedQuota = 500
var RetryTimes = 0 var RetryTimes = 0
var RootUserEmail = "" //var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
@@ -231,8 +231,10 @@ const (
ChannelTypeVertexAi = 41 ChannelTypeVertexAi = 41
ChannelTypeMistral = 42 ChannelTypeMistral = 42
ChannelTypeDeepSeek = 43 ChannelTypeDeepSeek = 43
ChannelTypeMokaAI = 44
ChannelTypeDummy // this one is only for count, do not add any channel after this ChannelTypeVolcEngine = 45
ChannelTypeBaiduV2 = 46
ChannelTypeDummy // this one is only for count, do not add any channel after this
) )
@@ -274,11 +276,14 @@ var ChannelBaseURLs = []string{
"https://api.cohere.ai", //34 "https://api.cohere.ai", //34
"https://api.minimax.chat", //35 "https://api.minimax.chat", //35
"", //36 "", //36
"", //37 "https://api.dify.ai", //37
"https://api.jina.ai", //38 "https://api.jina.ai", //38
"https://api.cloudflare.com", //39 "https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40 "https://api.siliconflow.cn", //40
"", //41 "", //41
"https://api.mistral.ai", //42 "https://api.mistral.ai", //42
"https://api.deepseek.com", //43 "https://api.deepseek.com", //43
"https://api.moka.ai", //44
"https://ark.cn-beijing.volces.com", //45
"https://qianfan.baidubce.com", //46
} }

View File

@@ -3,5 +3,6 @@ package common
var UsingSQLite = false var UsingSQLite = false
var UsingPostgreSQL = false var UsingPostgreSQL = false
var UsingMySQL = false var UsingMySQL = false
var UsingClickHouse = false
var SQLitePath = "one-api.db?_busy_timeout=5000" var SQLitePath = "one-api.db?_busy_timeout=5000"

View File

@@ -1,22 +1,9 @@
package common package common
import ( import (
"fmt"
"runtime/debug"
"time" "time"
) )
func SafeGoroutine(f func()) {
go func() {
defer func() {
if r := recover(); r != nil {
SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
}
}()
f()
}()
}
func SafeSendBool(ch chan bool, value bool) (closed bool) { func SafeSendBool(ch chan bool, value bool) (closed bool) {
defer func() { defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed. // Recover from panic if one occured. A panic would mean the channel was closed.

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"log" "log"
@@ -80,9 +81,9 @@ func logHelper(ctx context.Context, level string, msg string) {
if logCount > maxLogCount && !setupLogWorking { if logCount > maxLogCount && !setupLogWorking {
logCount = 0 logCount = 0
setupLogWorking = true setupLogWorking = true
go func() { gopool.Go(func() {
SetupLogger() SetupLogger()
}() })
} }
} }
@@ -100,6 +101,14 @@ func LogQuota(quota int) string {
} }
} }
func FormatQuota(quota int) string {
if DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f", float64(quota)/QuotaPerUnit)
} else {
return fmt.Sprintf("%d", quota)
}
}
// LogJson 仅供测试使用 only for test // LogJson 仅供测试使用 only for test
func LogJson(ctx context.Context, msg string, obj any) { func LogJson(ctx context.Context, msg string, obj any) {
jsonStr, err := json.Marshal(obj) jsonStr, err := json.Marshal(obj)

View File

@@ -83,92 +83,94 @@ var defaultModelRatio = map[string]float64{
"text-curie-001": 1, "text-curie-001": 1,
//"text-davinci-002": 10, //"text-davinci-002": 10,
//"text-davinci-003": 10, //"text-davinci-003": 10,
"text-davinci-edit-001": 10, "text-davinci-edit-001": 10,
"code-davinci-edit-001": 10, "code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015 "tts-1": 7.5, // 1k characters -> $0.015
"tts-1-1106": 7.5, // 1k characters -> $0.015 "tts-1-1106": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03 "tts-1-hd": 15, // 1k characters -> $0.03
"tts-1-hd-1106": 15, // 1k characters -> $0.03 "tts-1-hd-1106": 15, // 1k characters -> $0.03
"davinci": 10, "davinci": 10,
"curie": 10, "curie": 10,
"babbage": 10, "babbage": 10,
"ada": 10, "ada": 10,
"text-embedding-3-small": 0.01, "text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065, "text-embedding-3-large": 0.065,
"text-embedding-ada-002": 0.05, "text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10, "text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1, "text-moderation-stable": 0.1,
"text-moderation-latest": 0.1, "text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens "claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens "claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens "claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens "claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens "claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens "claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5, "claude-3-5-sonnet-20240620": 1.5,
"claude-3-5-sonnet-20241022": 1.5, "claude-3-5-sonnet-20241022": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens "claude-3-7-sonnet-20250219": 1.5,
"ERNIE-4.0-8K": 0.120 * RMB, "claude-3-7-sonnet-20250219-thinking": 1.5,
"ERNIE-3.5-8K": 0.012 * RMB, "claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-3.5-8K-0205": 0.024 * RMB, "ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB, "ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB, "ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB, "ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB, "ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB, "ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB, "ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB, "ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB, "ERNIE-Lite-8K-0922": 0.008 * RMB,
"BLOOMZ-7B": 0.004 * RMB, "ERNIE-Lite-8K-0308": 0.003 * RMB,
"Embedding-V1": 0.002 * RMB, "ERNIE-Tiny-8K": 0.001 * RMB,
"bge-large-zh": 0.002 * RMB, "BLOOMZ-7B": 0.004 * RMB,
"bge-large-en": 0.002 * RMB, "Embedding-V1": 0.002 * RMB,
"tao-8k": 0.002 * RMB, "bge-large-zh": 0.002 * RMB,
"PaLM-2": 1, "bge-large-en": 0.002 * RMB,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "tao-8k": 0.002 * RMB,
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "PaLM-2": 1,
"gemini-1.0-pro-vision-001": 1, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-001": 1, "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens "gemini-1.0-pro-vision-001": 1,
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens "gemini-1.0-pro-001": 1,
"gemini-1.5-flash-latest": 1, "gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-exp-0827": 1, "gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.0-pro-latest": 1, "gemini-1.5-flash-latest": 1,
"gemini-1.0-pro-vision-latest": 1, "gemini-1.5-flash-exp-0827": 1,
"gemini-ultra": 1, "gemini-1.0-pro-latest": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "gemini-1.0-pro-vision-latest": 1,
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "gemini-ultra": 1,
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens "glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572, "glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-plus": 0.05 * RMB, "glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-4-0520": 0.1 * RMB, "glm-3-turbo": 0.3572,
"glm-4-air": 0.001 * RMB, "glm-4-plus": 0.05 * RMB,
"glm-4-airx": 0.01 * RMB, "glm-4-0520": 0.1 * RMB,
"glm-4-long": 0.001 * RMB, "glm-4-air": 0.001 * RMB,
"glm-4-flash": 0, "glm-4-airx": 0.01 * RMB,
"glm-4v-plus": 0.01 * RMB, "glm-4-long": 0.001 * RMB,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens "glm-4-flash": 0,
"qwen-plus": 10, // ¥0.14 / 1k tokens "glm-4v-plus": 0.01 * RMB,
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens "qwen-plus": 10, // ¥0.14 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858, "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "SparkDesk-v3.5": 1.2858, // 0.018 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens "SparkDesk-v4.0": 1.2858,
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-pro": 0.8572, // ¥0.012 / 1k tokens "360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt2-pro": 0.8572, // ¥0.012 / 1k tokens "360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "360gpt-pro": 0.8572, // ¥0.012 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens "360gpt2-pro": 0.8572, // ¥0.012 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// https://platform.lingyiwanwu.com/docs#-计费单元 // https://platform.lingyiwanwu.com/docs#-计费单元
// 已经按照 7.2 来换算美元价格 // 已经按照 7.2 来换算美元价格
"yi-34b-chat-0205": 0.18, "yi-34b-chat-0205": 0.18,
@@ -191,8 +193,9 @@ var defaultModelRatio = map[string]float64{
"command-r-plus": 1.5, "command-r-plus": 1.5,
"command-r-08-2024": 0.075, "command-r-08-2024": 0.075,
"command-r-plus-08-2024": 1.25, "command-r-plus-08-2024": 1.25,
"deepseek-chat": 0.07, "deepseek-chat": 0.27 / 2,
"deepseek-coder": 0.07, "deepseek-coder": 0.27 / 2,
"deepseek-reasoner": 0.55 / 2, // 0.55 / 1k tokens
// Perplexity online 模型对搜索额外收费,有需要应自行调整,此处不计入搜索费用 // Perplexity online 模型对搜索额外收费,有需要应自行调整,此处不计入搜索费用
"llama-3-sonar-small-32k-chat": 0.2 / 1000 * USD, "llama-3-sonar-small-32k-chat": 0.2 / 1000 * USD,
"llama-3-sonar-small-32k-online": 0.2 / 1000 * USD, "llama-3-sonar-small-32k-online": 0.2 / 1000 * USD,
@@ -232,7 +235,11 @@ var (
modelRatioMapMutex = sync.RWMutex{} modelRatioMapMutex = sync.RWMutex{}
) )
var CompletionRatio map[string]float64 = nil var (
CompletionRatio map[string]float64 = nil
CompletionRatioMutex = sync.RWMutex{}
)
var defaultCompletionRatio = map[string]float64{ var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2, "gpt-4-gizmo-*": 2,
"gpt-4o-gizmo-*": 3, "gpt-4o-gizmo-*": 3,
@@ -333,10 +340,17 @@ func GetDefaultModelRatioMap() map[string]float64 {
return defaultModelRatio return defaultModelRatio
} }
func CompletionRatio2JSONString() string { func GetCompletionRatioMap() map[string]float64 {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
if CompletionRatio == nil { if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio CompletionRatio = defaultCompletionRatio
} }
return CompletionRatio
}
func CompletionRatio2JSONString() string {
GetCompletionRatioMap()
jsonBytes, err := json.Marshal(CompletionRatio) jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil { if err != nil {
SysError("error marshalling completion ratio: " + err.Error()) SysError("error marshalling completion ratio: " + err.Error())
@@ -345,11 +359,15 @@ func CompletionRatio2JSONString() string {
} }
func UpdateCompletionRatioByJSONString(jsonStr string) error { func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
CompletionRatio = make(map[string]float64) CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio) return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
} }
func GetCompletionRatio(name string) float64 { func GetCompletionRatio(name string) float64 {
GetCompletionRatioMap()
if strings.Contains(name, "/") { if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok { if ratio, ok := CompletionRatio[name]; ok {
return ratio return ratio
@@ -418,11 +436,9 @@ func GetCompletionRatio(name string) float64 {
return 4 return 4
} }
} }
if strings.HasPrefix(lowercaseName, "deepseek") { // hint 只给官方上4倍率由于开源模型供应商自行定价不对其进行补全倍率进行强制对齐
if strings.HasSuffix(lowercaseName, "reasoner") || strings.HasSuffix(lowercaseName, "r1") { if lowercaseName == "deepseek-chat" || lowercaseName == "deepseek-reasoner" {
return 4 return 4
}
return 2
} }
if strings.HasPrefix(name, "ERNIE-Speed-") { if strings.HasPrefix(name, "ERNIE-Speed-") {
return 2 return 2
@@ -477,24 +493,3 @@ func GetAudioCompletionRatio(name string) float64 {
} }
return 2 return 2
} }
//func GetAudioPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.06
// }
// return 0.06
//}
//
//func GetAudioCompletionPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.24
// }
// return 0.24
//}
func GetCompletionRatioMap() map[string]float64 {
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
return CompletionRatio
}

View File

@@ -32,6 +32,7 @@ func InitRedisClient() (err error) {
if err != nil { if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error()) FatalLog("failed to parse Redis connection string: " + err.Error())
} }
opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10)
RDB = redis.NewClient(opt) RDB = redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@@ -41,6 +42,10 @@ func InitRedisClient() (err error) {
if err != nil { if err != nil {
FatalLog("Redis ping test failed: " + err.Error()) FatalLog("Redis ping test failed: " + err.Error())
} }
if DebugEnabled {
SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr))
SysLog(fmt.Sprintf("Redis database: %d", opt.DB))
}
return err return err
} }
@@ -53,13 +58,20 @@ func ParseRedisOption() *redis.Options {
} }
func RedisSet(key string, value string, expiration time.Duration) error { func RedisSet(key string, value string, expiration time.Duration) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration))
}
ctx := context.Background() ctx := context.Background()
return RDB.Set(ctx, key, value, expiration).Err() return RDB.Set(ctx, key, value, expiration).Err()
} }
func RedisGet(key string) (string, error) { func RedisGet(key string) (string, error) {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis GET: key=%s", key))
}
ctx := context.Background() ctx := context.Background()
return RDB.Get(ctx, key).Result() val, err := RDB.Get(ctx, key).Result()
return val, err
} }
//func RedisExpire(key string, expiration time.Duration) error { //func RedisExpire(key string, expiration time.Duration) error {
@@ -73,16 +85,25 @@ func RedisGet(key string) (string, error) {
//} //}
func RedisDel(key string) error { func RedisDel(key string) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis DEL: key=%s", key))
}
ctx := context.Background() ctx := context.Background()
return RDB.Del(ctx, key).Err() return RDB.Del(ctx, key).Err()
} }
func RedisHDelObj(key string) error { func RedisHDelObj(key string) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HDEL: key=%s", key))
}
ctx := context.Background() ctx := context.Background()
return RDB.HDel(ctx, key).Err() return RDB.HDel(ctx, key).Err()
} }
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error { func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration))
}
ctx := context.Background() ctx := context.Background()
data := make(map[string]interface{}) data := make(map[string]interface{})
@@ -130,6 +151,9 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
} }
func RedisHGetObj(key string, obj interface{}) error { func RedisHGetObj(key string, obj interface{}) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key))
}
ctx := context.Background() ctx := context.Background()
result, err := RDB.HGetAll(ctx, key).Result() result, err := RDB.HGetAll(ctx, key).Result()
@@ -208,6 +232,9 @@ func RedisHGetObj(key string, obj interface{}) error {
// RedisIncr Add this function to handle atomic increments // RedisIncr Add this function to handle atomic increments
func RedisIncr(key string, delta int64) error { func RedisIncr(key string, delta int64) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta))
}
// 检查键的剩余生存时间 // 检查键的剩余生存时间
ttlCmd := RDB.TTL(context.Background(), key) ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result() ttl, err := ttlCmd.Result()
@@ -238,6 +265,9 @@ func RedisIncr(key string, delta int64) error {
} }
func RedisHIncrBy(key, field string, delta int64) error { func RedisHIncrBy(key, field string, delta int64) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta))
}
ttlCmd := RDB.TTL(context.Background(), key) ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result() ttl, err := ttlCmd.Result()
if err != nil && !errors.Is(err, redis.Nil) { if err != nil && !errors.Is(err, redis.Nil) {
@@ -262,6 +292,9 @@ func RedisHIncrBy(key, field string, delta int64) error {
} }
func RedisHSetField(key, field string, value interface{}) error { func RedisHSetField(key, field string, value interface{}) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value))
}
ttlCmd := RDB.TTL(context.Background(), key) ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result() ttl, err := ttlCmd.Result()
if err != nil && !errors.Is(err, redis.Nil) { if err != nil && !errors.Is(err, redis.Nil) {

View File

@@ -5,6 +5,7 @@ import (
"context" "context"
crand "crypto/rand" crand "crypto/rand"
"encoding/base64" "encoding/base64"
"encoding/json"
"fmt" "fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
"html/template" "html/template"
@@ -213,6 +214,24 @@ func RandomSleep() {
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond) time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
} }
func GetPointer[T any](v T) *T {
return &v
}
func Any2Type[T any](data any) (T, error) {
var zero T
bytes, err := json.Marshal(data)
if err != nil {
return zero, err
}
var res T
err = json.Unmarshal(bytes, &res)
if err != nil {
return zero, err
}
return res, nil
}
// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string. // SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
func SaveTmpFile(filename string, data io.Reader) (string, error) { func SaveTmpFile(filename string, data io.Reader) (string, error) {
f, err := os.CreateTemp(os.TempDir(), filename) f, err := os.CreateTemp(os.TempDir(), filename)

View File

@@ -1,6 +1,7 @@
package constant package constant
var ( var (
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式 ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
ChanelSettingProxy = "proxy" // Proxy 代理 ChanelSettingProxy = "proxy" // Proxy 代理
ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
) )

View File

@@ -2,4 +2,9 @@ package constant
const ( const (
ContextKeyRequestStartTime = "request_start_time" ContextKeyRequestStartTime = "request_start_time"
ContextKeyUserSetting = "user_setting"
ContextKeyUserQuota = "user_quota"
ContextKeyUserStatus = "user_status"
ContextKeyUserEmail = "user_email"
ContextKeyUserGroup = "user_group"
) )

View File

@@ -1,10 +1,7 @@
package constant package constant
import ( import (
"fmt"
"one-api/common" "one-api/common"
"os"
"strings"
) )
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60) var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
@@ -21,26 +18,31 @@ var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STR
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var GeminiModelMap = map[string]string{ var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
"gemini-1.0-pro": "v1",
} //var GeminiModelMap = map[string]string{
// "gemini-1.0-pro": "v1",
//}
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
func InitEnv() { func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP")) //modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" { //if modelVersionMapStr == "" {
return // return
} //}
for _, pair := range strings.Split(modelVersionMapStr, ",") { //for _, pair := range strings.Split(modelVersionMapStr, ",") {
parts := strings.Split(pair, ":") // parts := strings.Split(pair, ":")
if len(parts) == 2 { // if len(parts) == 2 {
GeminiModelMap[parts[0]] = parts[1] // GeminiModelMap[parts[0]] = parts[1]
} else { // } else {
common.SysError(fmt.Sprintf("invalid model version map: %s", pair)) // common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
} // }
} //}
} }
// 是否生成初始令牌,默认关闭。 // GenerateDefaultToken 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false) var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

14
constant/user_setting.go Normal file
View File

@@ -0,0 +1,14 @@
package constant
var (
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
)
var (
NotifyTypeEmail = "email" // Email 邮件
NotifyTypeWebhook = "webhook" // Webhook
)

View File

@@ -41,9 +41,21 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
requestPath := "/v1/chat/completions"
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(testModel), "embedding") ||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型
strings.Contains(testModel, "embed") ||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
}
c.Request = &http.Request{ c.Request = &http.Request{
Method: "POST", Method: "POST",
URL: &url.URL{Path: "/v1/chat/completions"}, URL: &url.URL{Path: requestPath}, // 使用动态路径
Body: nil, Body: nil,
Header: make(http.Header), Header: make(http.Header),
} }
@@ -55,23 +67,29 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if len(channel.GetModels()) > 0 { if len(channel.GetModels()) > 0 {
testModel = channel.GetModels()[0] testModel = channel.GetModels()[0]
} else { } else {
testModel = "gpt-3.5-turbo" testModel = "gpt-4o-mini"
}
}
} else {
modelMapping := *channel.ModelMapping
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
} }
} }
} }
modelMapping := *channel.ModelMapping
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
}
}
cache, err := model.GetUserCache(1)
if err != nil {
return err, nil
}
cache.WriteContext(c)
c.Request.Header.Set("Authorization", "Bearer "+channel.Key) c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json") c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type) c.Set("channel", channel.Type)
@@ -88,7 +106,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
request := buildTestRequest(testModel) request := buildTestRequest(testModel)
meta.UpstreamModelName = testModel meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta))
adaptor.Init(meta) adaptor.Init(meta)
@@ -156,12 +174,21 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
Model: "", // this will be set later Model: "", // this will be set later
Stream: false, Stream: false,
} }
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(model), "embedding") ||
strings.HasPrefix(model, "m3e") || // m3e 系列模型
strings.Contains(model, "bge-") || // bge 系列模型
model == "text-embedding-v1" { // 其他 embedding 模型
// Embedding 请求
testRequest.Input = []string{"hello world"}
return testRequest
}
// 并非Embedding 模型
if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") { if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") {
testRequest.MaxCompletionTokens = 10 testRequest.MaxCompletionTokens = 10
} else if strings.HasPrefix(model, "gemini-2.0-flash-thinking") {
testRequest.MaxTokens = 10
} else { } else {
testRequest.MaxTokens = 1 testRequest.MaxTokens = 10
} }
content, _ := json.Marshal("hi") content, _ := json.Marshal("hi")
testMessage := dto.Message{ testMessage := dto.Message{
@@ -217,9 +244,7 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false var testAllChannelsRunning bool = false
func testAllChannels(notify bool) error { func testAllChannels(notify bool) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
testAllChannelsLock.Lock() testAllChannelsLock.Lock()
if testAllChannelsRunning { if testAllChannelsRunning {
testAllChannelsLock.Unlock() testAllChannelsLock.Unlock()
@@ -274,10 +299,7 @@ func testAllChannels(notify bool) error {
testAllChannelsRunning = false testAllChannelsRunning = false
testAllChannelsLock.Unlock() testAllChannelsLock.Unlock()
if notify { if notify {
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
} }
}) })
return nil return nil

View File

@@ -159,7 +159,7 @@ func UpdateMidjourneyTaskBulk() {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else { } else {
if shouldReturnQuota { if shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota) err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
if err != nil { if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error()) common.LogError(ctx, "fail to increase user quota: "+err.Error())
} }

View File

@@ -66,6 +66,7 @@ func GetStatus(c *gin.Context) {
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "", "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled, "mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats, "chats": setting.Chats,
"demo_site_enabled": setting.DemoSiteEnabled,
}, },
}) })
return return

View File

@@ -17,7 +17,7 @@ func GetPricing(c *gin.Context) {
} }
var group string var group string
if exists { if exists {
user, err := model.GetUserById(userId.(int), false) user, err := model.GetUserCache(userId.(int))
if err == nil { if err == nil {
group = user.Group group = user.Group
} }

View File

@@ -24,7 +24,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
var err *dto.OpenAIErrorWithStatusCode var err *dto.OpenAIErrorWithStatusCode
switch relayMode { switch relayMode {
case relayconstant.RelayModeImagesGenerations: case relayconstant.RelayModeImagesGenerations:
err = relay.ImageHelper(c, relayMode) err = relay.ImageHelper(c)
case relayconstant.RelayModeAudioSpeech: case relayconstant.RelayModeAudioSpeech:
fallthrough fallthrough
case relayconstant.RelayModeAudioTranslation: case relayconstant.RelayModeAudioTranslation:
@@ -33,6 +33,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
err = relay.AudioHelper(c) err = relay.AudioHelper(c)
case relayconstant.RelayModeRerank: case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode) err = relay.RerankHelper(c, relayMode)
case relayconstant.RelayModeEmbeddings:
err = relay.EmbeddingHelper(c)
default: default:
err = relay.TextHelper(c) err = relay.TextHelper(c)
} }
@@ -83,6 +85,7 @@ func Relay(c *gin.Context) {
if openaiErr != nil { if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests { if openaiErr.StatusCode == http.StatusTooManyRequests {
common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
} }
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId) openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)

View File

@@ -159,7 +159,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
} else { } else {
quota := task.Quota quota := task.Quota
if quota != 0 { if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota) err = model.IncreaseUserQuota(task.UserId, quota, false)
if err != nil { if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error()) common.LogError(ctx, "fail to increase user quota: "+err.Error())
} }

View File

@@ -210,7 +210,7 @@ func EpayNotify(c *gin.Context) {
} }
//user, _ := model.GetUserById(topUp.UserId, false) //user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000 //user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit)) err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit), true)
if err != nil { if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp) log.Printf("易支付回调更新用户失败: %v", topUp)
return return

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"one-api/setting" "one-api/setting"
@@ -471,7 +472,7 @@ func GetUserModels(c *gin.Context) {
if err != nil { if err != nil {
id = c.GetInt("id") id = c.GetInt("id")
} }
user, err := model.GetUserById(id, true) user, err := model.GetUserCache(id)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -846,9 +847,10 @@ func EmailBind(c *gin.Context) {
}) })
return return
} }
id := c.GetInt("id") session := sessions.Default(c)
id := session.Get("id")
user := model.User{ user := model.User{
Id: id, Id: id.(int),
} }
err := user.FillUserById() err := user.FillUserById()
if err != nil { if err != nil {
@@ -868,9 +870,6 @@ func EmailBind(c *gin.Context) {
}) })
return return
} }
if user.Role == common.RoleRootUser {
common.RootUserEmail = email
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
@@ -912,3 +911,115 @@ func TopUp(c *gin.Context) {
}) })
return return
} }
type UpdateUserSettingRequest struct {
QuotaWarningType string `json:"notify_type"`
QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
WebhookUrl string `json:"webhook_url,omitempty"`
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
}
func UpdateUserSetting(c *gin.Context) {
var req UpdateUserSettingRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的参数",
})
return
}
// 验证预警类型
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的预警类型",
})
return
}
// 验证预警阈值
if req.QuotaWarningThreshold <= 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "预警阈值必须大于0",
})
return
}
// 如果是webhook类型,验证webhook地址
if req.QuotaWarningType == constant.NotifyTypeWebhook {
if req.WebhookUrl == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Webhook地址不能为空",
})
return
}
// 验证URL格式
if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的Webhook地址",
})
return
}
}
// 如果是邮件类型,验证邮箱地址
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
// 验证邮箱格式
if !strings.Contains(req.NotificationEmail, "@") {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的邮箱地址",
})
return
}
}
userId := c.GetInt("id")
user, err := model.GetUserById(userId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
// 构建设置
settings := map[string]interface{}{
constant.UserSettingNotifyType: req.QuotaWarningType,
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
}
// 如果是webhook类型,添加webhook相关设置
if req.QuotaWarningType == constant.NotifyTypeWebhook {
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
if req.WebhookSecret != "" {
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
}
}
// 如果提供了通知邮箱,添加到设置中
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
}
// 更新用户设置
user.SetSetting(settings)
if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "更新设置失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "设置已更新",
})
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
@@ -142,9 +143,10 @@ func WeChatBind(c *gin.Context) {
}) })
return return
} }
id := c.GetInt("id") session := sessions.Default(c)
id := session.Get("id")
user := model.User{ user := model.User{
Id: id, Id: id.(int),
} }
err = user.FillUserById() err = user.FillUserById()
if err != nil { if err != nil {

View File

@@ -24,7 +24,7 @@ services:
- redis - redis
- mysql - mysql
healthcheck: healthcheck:
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 3 retries: 3

View File

@@ -10,6 +10,10 @@
- 用于配置网络代理 - 用于配置网络代理
- 类型为字符串,填写代理地址(例如 socks5 协议的代理地址) - 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
3. thinking_to_content
- 用于标识是否将思考内容`reasoning_conetnt`转换为`<think>`标签拼接到内容中返回
- 类型为布尔值,设置为 true 时启用思考内容转换
-------------------------------------------------------------- --------------------------------------------------------------
## JSON 格式示例 ## JSON 格式示例
@@ -19,6 +23,7 @@
```json ```json
{ {
"force_format": true, "force_format": true,
"thinking_to_content": true,
"proxy": "socks5://xxxxxxx" "proxy": "socks5://xxxxxxx"
} }
``` ```

57
dto/embedding.go Normal file
View File

@@ -0,0 +1,57 @@
package dto
type EmbeddingOptions struct {
Seed int `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input any `json:"input"`
EncodingFormat string `json:"encoding_format,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
User string `json:"user,omitempty"`
Seed float64 `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
}
func (r EmbeddingRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}
type EmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
type EmbeddingResponse struct {
Object string `json:"object"`
Data []EmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}

25
dto/notify.go Normal file
View File

@@ -0,0 +1,25 @@
package dto
type Notify struct {
Type string `json:"type"`
Title string `json:"title"`
Content string `json:"content"`
Values []interface{} `json:"values"`
}
const ContentValueParam = "{{value}}"
const (
NotifyTypeQuotaExceed = "quota_exceed"
NotifyTypeChannelUpdate = "channel_update"
NotifyTypeChannelTest = "channel_test"
)
func NewNotify(t string, title string, content string, values []interface{}) Notify {
return Notify{
Type: t,
Title: title,
Content: content,
Values: values,
}
}

View File

@@ -1,6 +1,9 @@
package dto package dto
import "encoding/json" import (
"encoding/json"
"strings"
)
type ResponseFormat struct { type ResponseFormat struct {
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
@@ -15,47 +18,52 @@ type FormatJsonSchema struct {
} }
type GeneralOpenAIRequest struct { type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"` Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"` Prefix any `json:"prefix,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"` Suffix any `json:"suffix,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"` Stream bool `json:"stream,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"` StreamOptions *StreamOptions `json:"stream_options,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
TopP float64 `json:"top_p,omitempty"` ReasoningEffort string `json:"reasoning_effort,omitempty"`
TopK int `json:"top_k,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
Stop any `json:"stop,omitempty"` TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"` TopK int `json:"top_k,omitempty"`
Input any `json:"input,omitempty"` Stop any `json:"stop,omitempty"`
Instruction string `json:"instruction,omitempty"` N int `json:"n,omitempty"`
Size string `json:"size,omitempty"` Input any `json:"input,omitempty"`
Functions any `json:"functions,omitempty"` Instruction string `json:"instruction,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` Size string `json:"size,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` Functions any `json:"functions,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
Seed float64 `json:"seed,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Tools []ToolCall `json:"tools,omitempty"` EncodingFormat any `json:"encoding_format,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` Seed float64 `json:"seed,omitempty"`
User string `json:"user,omitempty"` Tools []ToolCallRequest `json:"tools,omitempty"`
LogProbs bool `json:"logprobs,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"` User string `json:"user,omitempty"`
Dimensions int `json:"dimensions,omitempty"` LogProbs bool `json:"logprobs,omitempty"`
Modalities any `json:"modalities,omitempty"` TopLogProbs int `json:"top_logprobs,omitempty"`
Audio any `json:"audio,omitempty"` Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
ExtraBody any `json:"extra_body,omitempty"`
} }
type OpenAITools struct { type ToolCallRequest struct {
Type string `json:"type"` ID string `json:"id,omitempty"`
Function OpenAIFunction `json:"function"` Type string `json:"type"`
Function FunctionRequest `json:"function"`
} }
type OpenAIFunction struct { type FunctionRequest struct {
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Name string `json:"name"` Name string `json:"name"`
Parameters any `json:"parameters,omitempty"` Parameters any `json:"parameters,omitempty"`
Arguments string `json:"arguments,omitempty"`
} }
type StreamOptions struct { type StreamOptions struct {
@@ -86,16 +94,20 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
} }
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content json.RawMessage `json:"content"` Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"` Prefix *bool `json:"prefix,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"` ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
parsedContent []MediaContent
parsedStringContent *string
} }
type MediaContent struct { type MediaContent struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text"` Text string `json:"text,omitempty"`
ImageUrl any `json:"image_url,omitempty"` ImageUrl any `json:"image_url,omitempty"`
InputAudio any `json:"input_audio,omitempty"` InputAudio any `json:"input_audio,omitempty"`
} }
@@ -116,11 +128,22 @@ const (
ContentTypeInputAudio = "input_audio" ContentTypeInputAudio = "input_audio"
) )
func (m *Message) ParseToolCalls() []ToolCall { func (m *Message) GetPrefix() bool {
if m.Prefix == nil {
return false
}
return *m.Prefix
}
func (m *Message) SetPrefix(prefix bool) {
m.Prefix = &prefix
}
func (m *Message) ParseToolCalls() []ToolCallRequest {
if m.ToolCalls == nil { if m.ToolCalls == nil {
return nil return nil
} }
var toolCalls []ToolCall var toolCalls []ToolCallRequest
if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil { if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
return toolCalls return toolCalls
} }
@@ -133,88 +156,139 @@ func (m *Message) SetToolCalls(toolCalls any) {
} }
func (m *Message) StringContent() string { func (m *Message) StringContent() string {
if m.parsedStringContent != nil {
return *m.parsedStringContent
}
var stringContent string var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil { if err := json.Unmarshal(m.Content, &stringContent); err == nil {
m.parsedStringContent = &stringContent
return stringContent return stringContent
} }
return string(m.Content)
contentStr := new(strings.Builder)
arrayContent := m.ParseContent()
for _, content := range arrayContent {
if content.Type == ContentTypeText {
contentStr.WriteString(content.Text)
}
}
stringContent = contentStr.String()
m.parsedStringContent = &stringContent
return stringContent
} }
func (m *Message) SetStringContent(content string) { func (m *Message) SetStringContent(content string) {
jsonContent, _ := json.Marshal(content) jsonContent, _ := json.Marshal(content)
m.Content = jsonContent m.Content = jsonContent
m.parsedStringContent = &content
m.parsedContent = nil
}
func (m *Message) SetMediaContent(content []MediaContent) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
m.parsedContent = nil
m.parsedStringContent = nil
} }
func (m *Message) IsStringContent() bool { func (m *Message) IsStringContent() bool {
if m.parsedStringContent != nil {
return true
}
var stringContent string var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil { if err := json.Unmarshal(m.Content, &stringContent); err == nil {
m.parsedStringContent = &stringContent
return true return true
} }
return false return false
} }
func (m *Message) ParseContent() []MediaContent { func (m *Message) ParseContent() []MediaContent {
if m.parsedContent != nil {
return m.parsedContent
}
var contentList []MediaContent var contentList []MediaContent
// 先尝试解析为字符串
var stringContent string var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil { if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaContent{ contentList = []MediaContent{{
Type: ContentTypeText, Type: ContentTypeText,
Text: stringContent, Text: stringContent,
}) }}
m.parsedContent = contentList
return contentList return contentList
} }
var arrayContent []json.RawMessage
// 尝试解析为数组
var arrayContent []map[string]interface{}
if err := json.Unmarshal(m.Content, &arrayContent); err == nil { if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
for _, contentItem := range arrayContent { for _, contentItem := range arrayContent {
var contentMap map[string]any contentType, ok := contentItem["type"].(string)
if err := json.Unmarshal(contentItem, &contentMap); err != nil { if !ok {
continue continue
} }
switch contentMap["type"] {
switch contentType {
case ContentTypeText: case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok { if text, ok := contentItem["text"].(string); ok {
contentList = append(contentList, MediaContent{ contentList = append(contentList, MediaContent{
Type: ContentTypeText, Type: ContentTypeText,
Text: subStr, Text: text,
}) })
} }
case ContentTypeImageURL: case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok { imageUrl := contentItem["image_url"]
detail, ok := subObj["detail"] switch v := imageUrl.(type) {
if ok { case string:
subObj["detail"] = detail.(string)
} else {
subObj["detail"] = "high"
}
contentList = append(contentList, MediaContent{ contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL, Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{ ImageUrl: MessageImageUrl{
Url: subObj["url"].(string), Url: v,
Detail: subObj["detail"].(string),
},
})
} else if url, ok := contentMap["image_url"].(string); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Detail: "high", Detail: "high",
}, },
}) })
case map[string]interface{}:
url, ok1 := v["url"].(string)
detail, ok2 := v["detail"].(string)
if !ok2 {
detail = "high"
}
if ok1 {
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Detail: detail,
},
})
}
} }
case ContentTypeInputAudio: case ContentTypeInputAudio:
if subObj, ok := contentMap["input_audio"].(map[string]any); ok { if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
contentList = append(contentList, MediaContent{ data, ok1 := audioData["data"].(string)
Type: ContentTypeInputAudio, format, ok2 := audioData["format"].(string)
InputAudio: MessageInputAudio{ if ok1 && ok2 {
Data: subObj["data"].(string), contentList = append(contentList, MediaContent{
Format: subObj["format"].(string), Type: ContentTypeInputAudio,
}, InputAudio: MessageInputAudio{
}) Data: data,
Format: format,
},
})
}
} }
} }
} }
return contentList
} }
return nil
if len(contentList) > 0 {
m.parsedContent = contentList
}
return contentList
} }

View File

@@ -62,9 +62,10 @@ type ChatCompletionsStreamResponseChoice struct {
} }
type ChatCompletionsStreamResponseChoiceDelta struct { type ChatCompletionsStreamResponseChoiceDelta struct {
Content *string `json:"content,omitempty"` Content *string `json:"content,omitempty"`
Role string `json:"role,omitempty"` ReasoningContent *string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` Role string `json:"role,omitempty"`
ToolCalls []ToolCallResponse `json:"tool_calls,omitempty"`
} }
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) { func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
@@ -78,24 +79,35 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
return *c.Content return *c.Content
} }
type ToolCall struct { func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string {
// Index is not nil only in chat completion chunk object if c.ReasoningContent == nil {
Index *int `json:"index,omitempty"` return ""
ID string `json:"id"` }
Type any `json:"type"` return *c.ReasoningContent
Function FunctionCall `json:"function"`
} }
func (c *ToolCall) SetIndex(i int) { func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
c.ReasoningContent = &s
}
type ToolCallResponse struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type any `json:"type"`
Function FunctionResponse `json:"function"`
}
func (c *ToolCallResponse) SetIndex(i int) {
c.Index = &i c.Index = &i
} }
type FunctionCall struct { type FunctionResponse struct {
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
// call function with arguments in JSON format // call function with arguments in JSON format
Parameters any `json:"parameters,omitempty"` // request Parameters any `json:"parameters,omitempty"` // request
Arguments string `json:"arguments,omitempty"` Arguments string `json:"arguments"` // response
} }
type ChatCompletionsStreamResponse struct { type ChatCompletionsStreamResponse struct {
@@ -108,6 +120,20 @@ type ChatCompletionsStreamResponse struct {
Usage *Usage `json:"usage"` Usage *Usage `json:"usage"`
} }
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
copy(choices, c.Choices)
return &ChatCompletionsStreamResponse{
Id: c.Id,
Object: c.Object,
Created: c.Created,
Model: c.Model,
SystemFingerprint: c.SystemFingerprint,
Choices: choices,
Usage: c.Usage,
}
}
func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string { func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
if c.SystemFingerprint == nil { if c.SystemFingerprint == nil {
return "" return ""

16
go.mod
View File

@@ -16,6 +16,7 @@ require (
github.com/gin-contrib/sessions v0.0.5 github.com/gin-contrib/sessions v0.0.5
github.com/gin-contrib/static v0.0.1 github.com/gin-contrib/static v0.0.1
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
github.com/glebarez/sqlite v1.9.0
github.com/go-playground/validator/v10 v10.20.0 github.com/go-playground/validator/v10 v10.20.0
github.com/go-redis/redis/v8 v8.11.5 github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible
@@ -29,10 +30,10 @@ require (
github.com/shirou/gopsutil v3.21.11+incompatible github.com/shirou/gopsutil v3.21.11+incompatible
golang.org/x/crypto v0.27.0 golang.org/x/crypto v0.27.0
golang.org/x/image v0.23.0 golang.org/x/image v0.23.0
golang.org/x/net v0.28.0
gorm.io/driver/mysql v1.4.3 gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2 gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3 gorm.io/gorm v1.25.2
gorm.io/gorm v1.25.0
) )
require ( require (
@@ -48,12 +49,14 @@ require (
github.com/cloudwego/iasm v0.2.0 // indirect github.com/cloudwego/iasm v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect github.com/dlclark/regexp2 v1.11.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.6.0 // indirect github.com/google/go-cmp v0.6.0 // indirect
github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/context v1.1.1 // indirect
@@ -69,11 +72,11 @@ require (
github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.1 // indirect github.com/pelletier/go-toml/v2 v2.2.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
@@ -81,10 +84,13 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.12.0 // indirect golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/net v0.28.0 // indirect
golang.org/x/sync v0.10.0 // indirect golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.27.0 // indirect golang.org/x/sys v0.27.0 // indirect
golang.org/x/text v0.21.0 // indirect golang.org/x/text v0.21.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.22.5 // indirect
modernc.org/mathutil v1.5.0 // indirect
modernc.org/memory v1.5.0 // indirect
modernc.org/sqlite v1.23.1 // indirect
) )

32
go.sum
View File

@@ -40,6 +40,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
@@ -58,6 +60,10 @@ github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwv
github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk= github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.9.0 h1:Aj6bPA12ZEx5GbSF6XADmCkYXlljPNUY+Zf1EQxynXs=
github.com/glebarez/sqlite v1.9.0/go.mod h1:YBYCoyupOao60lzp1MVBLEjZfgkq0tdB1voAQ09K9zw=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
@@ -77,8 +83,9 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
@@ -90,6 +97,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
@@ -140,9 +149,6 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -167,6 +173,9 @@ github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQ
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
@@ -263,11 +272,16 @@ gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k=
gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

11
main.go
View File

@@ -119,9 +119,9 @@ func main() {
} }
if os.Getenv("ENABLE_PPROF") == "true" { if os.Getenv("ENABLE_PPROF") == "true" {
go func() { gopool.Go(func() {
log.Println(http.ListenAndServe("0.0.0.0:8005", nil)) log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
}() })
go common.Monitor() go common.Monitor()
common.SysLog("pprof enabled") common.SysLog("pprof enabled")
} }
@@ -145,6 +145,13 @@ func main() {
middleware.SetUpLogger(server) middleware.SetUpLogger(server)
// Initialize session store // Initialize session store
store := cookie.NewStore([]byte(common.SessionSecret)) store := cookie.NewStore([]byte(common.SessionSecret))
store.Options(sessions.Options{
Path: "/",
MaxAge: 2592000, // 30 days
HttpOnly: true,
Secure: false,
SameSite: http.SameSiteStrictMode,
})
server.Use(sessions.Sessions("session", store)) server.Use(sessions.Sessions("session", store))
router.SetRouter(server, buildFS, indexPage) router.SetRouter(server, buildFS, indexPage)

View File

@@ -199,15 +199,19 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return return
} }
userEnabled, err := model.IsUserEnabled(token.UserId, false) userCache, err := model.GetUserCache(token.UserId)
if err != nil { if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error()) abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return return
} }
userEnabled := userCache.Status == common.UserStatusEnabled
if !userEnabled { if !userEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁") abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return return
} }
userCache.WriteContext(c)
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("token_key", token.Key) c.Set("token_key", token.Key)

View File

@@ -32,7 +32,6 @@ func Distribute() func(c *gin.Context) {
return return
} }
} }
userId := c.GetInt("id")
var channel *model.Channel var channel *model.Channel
channelId, ok := c.Get("specific_channel_id") channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c) modelRequest, shouldSelectChannel, err := getModelRequest(c)
@@ -40,7 +39,7 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error()) abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return return
} }
userGroup, _ := model.GetUserGroup(userId, false) userGroup := c.GetString(constant.ContextKeyUserGroup)
tokenGroup := c.GetString("token_group") tokenGroup := c.GetString("token_group")
if tokenGroup != "" { if tokenGroup != "" {
// check common.UserUsableGroups[userGroup] // check common.UserUsableGroups[userGroup]
@@ -135,17 +134,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
midjourneyRequest := dto.MidjourneyRequest{} midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest) err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil { if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return nil, false, err return nil, false, err
} }
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil { if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return nil, false, fmt.Errorf(mjErr.Description) return nil, false, fmt.Errorf(mjErr.Description)
} }
if midjourneyModel == "" { if midjourneyModel == "" {
if !success { if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return nil, false, fmt.Errorf("无效的请求, 无法解析模型") return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
} else { } else {
// task fetch, task fetch by condition, notify // task fetch, task fetch by condition, notify
@@ -170,7 +166,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
err = common.UnmarshalBodyReusable(c, &modelRequest) err = common.UnmarshalBodyReusable(c, &modelRequest)
} }
if err != nil { if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, errors.New("无效的请求, " + err.Error()) return nil, false, errors.New("无效的请求, " + err.Error())
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") { if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
@@ -239,5 +234,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("plugin", channel.Other) c.Set("plugin", channel.Other)
case common.ChannelCloudflare: case common.ChannelCloudflare:
c.Set("api_version", channel.Other) c.Set("api_version", channel.Other)
case common.ChannelTypeMokaAI:
c.Set("api_version", channel.Other)
} }
} }

View File

@@ -0,0 +1,172 @@
package middleware
import (
"context"
"fmt"
"net/http"
"one-api/common"
"one-api/setting"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
const (
ModelRequestRateLimitCountMark = "MRRL"
ModelRequestRateLimitSuccessCountMark = "MRRLS"
)
// 检查Redis中的请求限制
func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
// 如果maxCount为0表示不限制
if maxCount == 0 {
return true, nil
}
// 获取当前计数
length, err := rdb.LLen(ctx, key).Result()
if err != nil {
return false, err
}
// 如果未达到限制,允许请求
if length < int64(maxCount) {
return true, nil
}
// 检查时间窗口
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
if err != nil {
return false, err
}
nowTimeStr := time.Now().Format(timeFormat)
nowTime, err := time.Parse(timeFormat, nowTimeStr)
if err != nil {
return false, err
}
// 如果在时间窗口内已达到限制,拒绝请求
subTime := nowTime.Sub(oldTime).Seconds()
if int64(subTime) < duration {
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
return false, nil
}
return true, nil
}
// 记录Redis请求
func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
// 如果maxCount为0不记录请求
if maxCount == 0 {
return
}
now := time.Now().Format(timeFormat)
rdb.LPush(ctx, key, now)
rdb.LTrim(ctx, key, 0, int64(maxCount-1))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
}
// Redis限流处理器
func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
ctx := context.Background()
rdb := common.RDB
// 1. 检查总请求数限制当totalMaxCount为0时会自动跳过
totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
if err != nil {
fmt.Println("检查总请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
// 2. 检查成功请求数限制
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
if err != nil {
fmt.Println("检查成功请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
return
}
// 3. 记录总请求当totalMaxCount为0时会自动跳过
recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
// 4. 处理请求
c.Next()
// 5. 如果请求成功,记录成功请求
if c.Writer.Status() < 400 {
recordRedisRequest(ctx, rdb, successKey, successMaxCount)
}
}
}
// 内存限流处理器
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
totalKey := ModelRequestRateLimitCountMark + userId
successKey := ModelRequestRateLimitSuccessCountMark + userId
// 1. 检查总请求数限制当totalMaxCount为0时跳过
if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
// 2. 检查成功请求数限制
// 使用一个临时key来检查限制这样可以避免实际记录
checkKey := successKey + "_check"
if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
// 3. 处理请求
c.Next()
// 4. 如果请求成功,记录到实际的成功请求计数中
if c.Writer.Status() < 400 {
inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
}
}
}
// ModelRequestRateLimit 模型请求限流中间件
func ModelRequestRateLimit() func(c *gin.Context) {
// 如果未启用限流,直接放行
if !setting.ModelRequestRateLimitEnabled {
return defNext
}
// 计算限流参数
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
totalMaxCount := setting.ModelRequestRateLimitCount
successMaxCount := setting.ModelRequestRateLimitSuccessCount
// 根据存储类型选择限流处理器
if common.RedisEnabled {
return redisRateLimitHandler(duration, totalMaxCount, successMaxCount)
} else {
return memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)
}
}

View File

@@ -11,106 +11,6 @@ import (
"time" "time"
) )
//func CacheGetUserGroup(id int) (group string, err error) {
// if !common.RedisEnabled {
// return GetUserGroup(id)
// }
// group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
// if err != nil {
// group, err = GetUserGroup(id)
// if err != nil {
// return "", err
// }
// err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
// if err != nil {
// common.SysError("Redis set user group error: " + err.Error())
// }
// }
// return group, err
//}
//
//func CacheGetUsername(id int) (username string, err error) {
// if !common.RedisEnabled {
// return GetUsernameById(id)
// }
// username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
// if err != nil {
// username, err = GetUsernameById(id)
// if err != nil {
// return "", err
// }
// err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
// if err != nil {
// common.SysError("Redis set user group error: " + err.Error())
// }
// }
// return username, err
//}
//
//func CacheGetUserQuota(id int) (quota int, err error) {
// if !common.RedisEnabled {
// return GetUserQuota(id)
// }
// quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
// if err != nil {
// quota, err = GetUserQuota(id)
// if err != nil {
// return 0, err
// }
// return quota, nil
// }
// quota, err = strconv.Atoi(quotaString)
// return quota, nil
//}
//
//func CacheUpdateUserQuota(id int) error {
// if !common.RedisEnabled {
// return nil
// }
// quota, err := GetUserQuota(id)
// if err != nil {
// return err
// }
// return cacheSetUserQuota(id, quota)
//}
//
//func cacheSetUserQuota(id int, quota int) error {
// err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second)
// return err
//}
//
//func CacheDecreaseUserQuota(id int, quota int) error {
// if !common.RedisEnabled {
// return nil
// }
// err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
// return err
//}
//
//func CacheIsUserEnabled(userId int) (bool, error) {
// if !common.RedisEnabled {
// return IsUserEnabled(userId)
// }
// enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
// if err == nil {
// return enabled == "1", nil
// }
//
// userEnabled, err := IsUserEnabled(userId)
// if err != nil {
// return false, err
// }
// enabled = "0"
// if userEnabled {
// enabled = "1"
// }
// err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(constant.UserId2StatusCacheSeconds)*time.Second)
// if err != nil {
// common.SysError("Redis set user enabled error: " + err.Error())
// }
// return userEnabled, err
//}
var group2model2channels map[string]map[string][]*Channel var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex var channelSyncLock sync.RWMutex

View File

@@ -28,7 +28,7 @@ type Channel struct {
Models string `json:"models"` Models string `json:"models"`
Group string `json:"group" gorm:"type:varchar(64);default:'default'"` Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping *string `json:"model_mapping" gorm:"type:text"`
//MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"` //MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"`
StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"` StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"` Priority *int64 `json:"priority" gorm:"bigint;default:0"`

View File

@@ -1,8 +1,8 @@
package model package model
import ( import (
"context"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"os" "os"
"strings" "strings"
@@ -87,14 +87,14 @@ func RecordLog(userId int, logType int, content string) {
} }
} }
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) { isStream bool, group string, other map[string]interface{}) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled { if !common.LogConsumeEnabled {
return return
} }
username, _ := GetUsernameById(userId, false) username := c.GetString("username")
otherStr := common.MapToJsonStr(other) otherStr := common.MapToJsonStr(other)
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
@@ -116,7 +116,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
} }
err := LOG_DB.Create(log).Error err := LOG_DB.Create(log).Error
if err != nil { if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error()) common.LogError(c, "failed to record log: "+err.Error())
} }
if common.DataExportEnabled { if common.DataExportEnabled {
gopool.Go(func() { gopool.Go(func() {
@@ -133,9 +133,6 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = LOG_DB.Where("logs.type = ?", logType) tx = LOG_DB.Where("logs.type = ?", logType)
} }
tx = tx.Joins("LEFT JOIN channels ON logs.channel_id = channels.id")
tx = tx.Select("logs.*, channels.name as channel_name")
if modelName != "" { if modelName != "" {
tx = tx.Where("logs.model_name like ?", modelName) tx = tx.Where("logs.model_name like ?", modelName)
} }
@@ -165,6 +162,30 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
channelIds := make([]int, 0)
channelMap := make(map[int]string)
for _, log := range logs {
if log.ChannelId != 0 {
channelIds = append(channelIds, log.ChannelId)
}
}
if len(channelIds) > 0 {
var channels []struct {
Id int `gorm:"column:id"`
Name string `gorm:"column:name"`
}
if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds).Find(&channels).Error; err != nil {
return logs, total, err
}
for _, channel := range channels {
channelMap[channel.Id] = channel.Name
}
for i := range logs {
logs[i].ChannelName = channelMap[logs[i].ChannelId]
}
}
return logs, total, err return logs, total, err
} }
@@ -176,9 +197,6 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
tx = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType) tx = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType)
} }
tx = tx.Joins("LEFT JOIN channels ON logs.channel_id = channels.id")
tx = tx.Select("logs.*, channels.name as channel_name")
if modelName != "" { if modelName != "" {
tx = tx.Where("logs.model_name like ?", modelName) tx = tx.Where("logs.model_name like ?", modelName)
} }
@@ -199,6 +217,10 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
return nil, 0, err return nil, 0, err
} }
err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
if err != nil {
return nil, 0, err
}
formatUserLogs(logs) formatUserLogs(logs)
return logs, total, err return logs, total, err
} }

View File

@@ -1,9 +1,9 @@
package model package model
import ( import (
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql" "gorm.io/driver/mysql"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"log" "log"
"one-api/common" "one-api/common"
@@ -119,12 +119,9 @@ func InitDB() (err error) {
if !common.IsMasterNode { if !common.IsMasterNode {
return nil return nil
} }
//if common.UsingMySQL { if common.UsingMySQL {
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded _, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded }
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
//}
common.SysLog("database migration started") common.SysLog("database migration started")
err = migrateDB() err = migrateDB()
return err return err

View File

@@ -3,6 +3,7 @@ package model
import ( import (
"one-api/common" "one-api/common"
"one-api/setting" "one-api/setting"
"one-api/setting/config"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -23,6 +24,8 @@ func AllOption() ([]*Option, error) {
func InitOptionMap() { func InitOptionMap() {
common.OptionMapRWMutex.Lock() common.OptionMapRWMutex.Lock()
common.OptionMap = make(map[string]string) common.OptionMap = make(map[string]string)
// 添加原有的系统配置
common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
@@ -84,7 +87,10 @@ func InitOptionMap() {
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -104,11 +110,19 @@ func InitOptionMap() {
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled) common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled) common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled) common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled)
common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled) common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled) common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
// 自动添加所有注册的模型配置
modelConfigs := config.GlobalConfig.ExportAllConfigs()
for k, v := range modelConfigs {
common.OptionMap[k] = v
}
common.OptionMapRWMutex.Unlock() common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase() loadOptionsFromDatabase()
@@ -152,6 +166,13 @@ func updateOptionMap(key string, value string) (err error) {
common.OptionMapRWMutex.Lock() common.OptionMapRWMutex.Lock()
defer common.OptionMapRWMutex.Unlock() defer common.OptionMapRWMutex.Unlock()
common.OptionMap[key] = value common.OptionMap[key] = value
// 检查是否是模型配置 - 使用更规范的方式处理
if handleConfigUpdate(key, value) {
return nil // 已由配置系统处理
}
// 处理传统配置项...
if strings.HasSuffix(key, "Permission") { if strings.HasSuffix(key, "Permission") {
intValue, _ := strconv.Atoi(value) intValue, _ := strconv.Atoi(value)
switch key { switch key {
@@ -220,10 +241,12 @@ func updateOptionMap(key string, value string) (err error) {
setting.MjActionCheckSuccessEnabled = boolValue setting.MjActionCheckSuccessEnabled = boolValue
case "CheckSensitiveEnabled": case "CheckSensitiveEnabled":
setting.CheckSensitiveEnabled = boolValue setting.CheckSensitiveEnabled = boolValue
case "DemoSiteEnabled":
setting.DemoSiteEnabled = boolValue
case "CheckSensitiveOnPromptEnabled": case "CheckSensitiveOnPromptEnabled":
setting.CheckSensitiveOnPromptEnabled = boolValue setting.CheckSensitiveOnPromptEnabled = boolValue
//case "CheckSensitiveOnCompletionEnabled": case "ModelRequestRateLimitEnabled":
// constant.CheckSensitiveOnCompletionEnabled = boolValue setting.ModelRequestRateLimitEnabled = boolValue
case "StopOnSensitiveEnabled": case "StopOnSensitiveEnabled":
setting.StopOnSensitiveEnabled = boolValue setting.StopOnSensitiveEnabled = boolValue
case "SMTPSSLEnabled": case "SMTPSSLEnabled":
@@ -302,8 +325,14 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaForInvitee, _ = strconv.Atoi(value) common.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold": case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value) common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "PreConsumedQuota": case "ShouldPreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value) common.PreConsumedQuota, _ = strconv.Atoi(value)
case "ModelRequestRateLimitCount":
setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
case "ModelRequestRateLimitDurationMinutes":
setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
case "ModelRequestRateLimitSuccessCount":
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
case "RetryTimes": case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value) common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval": case "DataExportInterval":
@@ -332,8 +361,35 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "SensitiveWords": case "SensitiveWords":
setting.SensitiveWordsFromString(value) setting.SensitiveWordsFromString(value)
case "AutomaticDisableKeywords":
setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength": case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value) setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
} }
return err return err
} }
// handleConfigUpdate 处理分层配置更新,返回是否已处理
func handleConfigUpdate(key, value string) bool {
parts := strings.SplitN(key, ".", 2)
if len(parts) != 2 {
return false // 不是分层配置
}
configName := parts[0]
configKey := parts[1]
// 获取配置对象
cfg := config.GlobalConfig.Get(configName)
if cfg == nil {
return false // 未注册的配置
}
// 更新配置
configMap := map[string]string{
configKey: value,
}
config.UpdateConfigFromMap(cfg, configMap)
return true // 已处理
}

View File

@@ -3,13 +3,11 @@ package model
import ( import (
"errors" "errors"
"fmt" "fmt"
"one-api/common"
"strings"
"github.com/bytedance/gopkg/util/gopool" "github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common"
relaycommon "one-api/relay/common"
"one-api/setting"
"strconv"
"strings"
) )
type Token struct { type Token struct {
@@ -322,80 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) {
).Error ).Error
return err return err
} }
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if relayInfo.IsPlayground {
return nil
}
//if relayInfo.TokenUnlimited {
// return nil
//}
token, err := GetTokenById(relayInfo.TokenId)
if err != nil {
return err
}
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
return errors.New("令牌额度不足")
}
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
if err != nil {
return err
}
return nil
}
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
if quota > 0 {
err = DecreaseUserQuota(relayInfo.UserId, quota)
} else {
err = IncreaseUserQuota(relayInfo.UserId, -quota)
}
if err != nil {
return err
}
if !relayInfo.IsPlayground {
if quota > 0 {
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
} else {
err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
}
if err != nil {
return err
}
}
if sendEmail {
if (quota + preConsumedQuota) != 0 {
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
if quotaTooLow || noMoreQuota {
go func() {
email, err := GetUserEmail(relayInfo.UserId)
if err != nil {
common.SysError("failed to fetch user email: " + err.Error())
}
prompt := "您的额度即将用尽"
if noMoreQuota {
prompt = "您的额度已用尽"
}
if email != "" {
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
err = common.SendEmail(prompt, email,
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil {
common.SysError("failed to send email" + err.Error())
}
common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota))
}
}()
}
}
}
return nil
}

View File

@@ -52,7 +52,7 @@ func cacheSetTokenField(key string, field string, value string) error {
func cacheGetTokenByKey(key string) (*Token, error) { func cacheGetTokenByKey(key string) (*Token, error) {
hmacKey := common.GenerateHMAC(key) hmacKey := common.GenerateHMAC(key)
if !common.RedisEnabled { if !common.RedisEnabled {
return nil, nil return nil, fmt.Errorf("redis is not enabled")
} }
var token Token var token Token
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token) err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)

View File

@@ -1,6 +1,7 @@
package model package model
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"one-api/common" "one-api/common"
@@ -38,6 +39,20 @@ type User struct {
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index"`
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"` LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
Setting string `json:"setting" gorm:"type:text;column:setting"`
}
func (user *User) ToBaseUser() *UserBase {
cache := &UserBase{
Id: user.Id,
Group: user.Group,
Quota: user.Quota,
Status: user.Status,
Username: user.Username,
Setting: user.Setting,
Email: user.Email,
}
return cache
} }
func (user *User) GetAccessToken() string { func (user *User) GetAccessToken() string {
@@ -51,6 +66,22 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token user.AccessToken = &token
} }
func (user *User) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
}
return common.StrToMap(user.Setting)
}
func (user *User) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
}
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
func CheckUserExistOrDeleted(username string, email string) (bool, error) { func CheckUserExistOrDeleted(username string, email string) (bool, error) {
var user User var user User
@@ -289,7 +320,7 @@ func (user *User) Insert(inviterId int) error {
} }
if inviterId != 0 { if inviterId != 0 {
if common.QuotaForInvitee > 0 { if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee) _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
} }
if common.QuotaForInviter > 0 { if common.QuotaForInviter > 0 {
@@ -315,8 +346,8 @@ func (user *User) Update(updatePassword bool) error {
return err return err
} }
// 更新缓存 // Update cache
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status) return updateUserCache(*user)
} }
func (user *User) Edit(updatePassword bool) error { func (user *User) Edit(updatePassword bool) error {
@@ -344,8 +375,8 @@ func (user *User) Edit(updatePassword bool) error {
return err return err
} }
// 更新缓存 // Update cache
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status) return updateUserCache(*user)
} }
func (user *User) Delete() error { func (user *User) Delete() error {
@@ -371,8 +402,8 @@ func (user *User) HardDelete() error {
// ValidateAndFill check password & user status // ValidateAndFill check password & user status
func (user *User) ValidateAndFill() (err error) { func (user *User) ValidateAndFill() (err error) {
// When querying with struct, GORM will only query with non-zero fields, // When querying with struct, GORM will only query with non-zero fields,
// that means if your fields value is 0, '', false or other zero values, // that means if your field's value is 0, '', false or other zero values,
// it wont be used to build query conditions // it won't be used to build query conditions
password := user.Password password := user.Password
username := strings.TrimSpace(user.Username) username := strings.TrimSpace(user.Username)
if username == "" || password == "" { if username == "" || password == "" {
@@ -471,35 +502,35 @@ func IsAdmin(userId int) bool {
return user.Role >= common.RoleAdminUser return user.Role >= common.RoleAdminUser
} }
// IsUserEnabled checks user status from Redis first, falls back to DB if needed //// IsUserEnabled checks user status from Redis first, falls back to DB if needed
func IsUserEnabled(id int, fromDB bool) (status bool, err error) { //func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
defer func() { // defer func() {
// Update Redis cache asynchronously on successful DB read // // Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) { // if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() { // gopool.Go(func() {
if err := updateUserStatusCache(id, status); err != nil { // if err := updateUserStatusCache(id, status); err != nil {
common.SysError("failed to update user status cache: " + err.Error()) // common.SysError("failed to update user status cache: " + err.Error())
} // }
}) // })
} // }
}() // }()
if !fromDB && common.RedisEnabled { // if !fromDB && common.RedisEnabled {
// Try Redis first // // Try Redis first
status, err := getUserStatusCache(id) // status, err := getUserStatusCache(id)
if err == nil { // if err == nil {
return status == common.UserStatusEnabled, nil // return status == common.UserStatusEnabled, nil
} // }
// Don't return error - fall through to DB // // Don't return error - fall through to DB
} // }
fromDB = true // fromDB = true
var user User // var user User
err = DB.Where("id = ?", id).Select("status").Find(&user).Error // err = DB.Where("id = ?", id).Select("status").Find(&user).Error
if err != nil { // if err != nil {
return false, err // return false, err
} // }
//
return user.Status == common.UserStatusEnabled, nil // return user.Status == common.UserStatusEnabled, nil
} //}
func ValidateAccessToken(token string) (user *User) { func ValidateAccessToken(token string) (user *User) {
if token == "" { if token == "" {
@@ -531,7 +562,6 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
return quota, nil return quota, nil
} }
// Don't return error - fall through to DB // Don't return error - fall through to DB
//common.SysError("failed to get user quota from cache: " + err.Error())
} }
fromDB = true fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
@@ -580,7 +610,36 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
return group, nil return group, nil
} }
func IncreaseUserQuota(id int, quota int) (err error) { // GetUserSetting gets setting from Redis first, falls back to DB if needed
func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
var setting string
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserSettingCache(id, setting); err != nil {
common.SysError("failed to update user setting cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
setting, err := getUserSettingCache(id)
if err == nil {
return setting, nil
}
// Don't return error - fall through to DB
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
if err != nil {
return map[string]interface{}{}, err
}
return common.StrToMap(setting), nil
}
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
@@ -590,7 +649,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
common.SysError("failed to increase user quota: " + err.Error()) common.SysError("failed to increase user quota: " + err.Error())
} }
}) })
if common.BatchUpdateEnabled { if !db && common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota) addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil return nil
} }
@@ -635,15 +694,20 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
return nil return nil
} }
if delta > 0 { if delta > 0 {
return IncreaseUserQuota(id, delta) return IncreaseUserQuota(id, delta, false)
} else { } else {
return DecreaseUserQuota(id, -delta) return DecreaseUserQuota(id, -delta)
} }
} }
func GetRootUserEmail() (email string) { //func GetRootUserEmail() (email string) {
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) // DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
return email // return email
//}
func GetRootUser() (user *User) {
DB.Where("role = ?", common.RoleRootUser).First(&user)
return user
} }
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
@@ -725,10 +789,10 @@ func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
return !errors.Is(err, gorm.ErrRecordNotFound) return !errors.Is(err, gorm.ErrRecordNotFound)
} }
func (u *User) FillUserByLinuxDOId() error { func (user *User) FillUserByLinuxDOId() error {
if u.LinuxDOId == "" { if user.LinuxDOId == "" {
return errors.New("linux do id is empty") return errors.New("linux do id is empty")
} }
err := DB.Where("linux_do_id = ?", u.LinuxDOId).First(u).Error err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
return err return err
} }

View File

@@ -1,206 +1,223 @@
package model package model
import ( import (
"encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"strconv"
"time" "time"
"github.com/bytedance/gopkg/util/gopool"
) )
// Change UserCache struct to userCache // UserBase struct remains the same as it represents the cached data structure
type userCache struct { type UserBase struct {
Id int `json:"id"` Id int `json:"id"`
Group string `json:"group"` Group string `json:"group"`
Email string `json:"email"`
Quota int `json:"quota"` Quota int `json:"quota"`
Status int `json:"status"` Status int `json:"status"`
Role int `json:"role"`
Username string `json:"username"` Username string `json:"username"`
Setting string `json:"setting"`
} }
// Rename all exported functions to private ones func (user *UserBase) WriteContext(c *gin.Context) {
// invalidateUserCache clears all user related cache c.Set(constant.ContextKeyUserGroup, user.Group)
c.Set(constant.ContextKeyUserQuota, user.Quota)
c.Set(constant.ContextKeyUserStatus, user.Status)
c.Set(constant.ContextKeyUserEmail, user.Email)
c.Set("username", user.Username)
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
}
func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
}
return common.StrToMap(user.Setting)
}
func (user *UserBase) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
}
// getUserCacheKey returns the key for user cache
func getUserCacheKey(userId int) string {
return fmt.Sprintf("user:%d", userId)
}
// invalidateUserCache clears user cache
func invalidateUserCache(userId int) error { func invalidateUserCache(userId int) error {
if !common.RedisEnabled { if !common.RedisEnabled {
return nil return nil
} }
return common.RedisHDelObj(getUserCacheKey(userId))
}
keys := []string{ // updateUserCache updates all user cache fields using hash
fmt.Sprintf(constant.UserGroupKeyFmt, userId), func updateUserCache(user User) error {
fmt.Sprintf(constant.UserQuotaKeyFmt, userId), if !common.RedisEnabled {
fmt.Sprintf(constant.UserEnabledKeyFmt, userId), return nil
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
} }
for _, key := range keys { return common.RedisHSetObj(
if err := common.RedisDel(key); err != nil { getUserCacheKey(user.Id),
return fmt.Errorf("failed to delete cache key %s: %w", key, err) user.ToBaseUser(),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// GetUserCache gets complete user cache from hash
func GetUserCache(userId int) (userCache *UserBase, err error) {
var user *User
var fromDB bool
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) && user != nil {
gopool.Go(func() {
if err := updateUserCache(*user); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
} }
} }()
return nil
}
// updateUserGroupCache updates user group cache // Try getting from Redis first
func updateUserGroupCache(userId int, group string) error { userCache, err = cacheGetUserBase(userId)
if !common.RedisEnabled { if err == nil {
return nil return userCache, nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
group,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserQuotaCache updates user quota cache
func updateUserQuotaCache(userId int, quota int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
fmt.Sprintf("%d", quota),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserStatusCache updates user status cache
func updateUserStatusCache(userId int, userEnabled bool) error {
if !common.RedisEnabled {
return nil
}
enabled := "0"
if userEnabled {
enabled = "1"
}
return common.RedisSet(
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
enabled,
time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
)
}
// updateUserNameCache updates username cache
func updateUserNameCache(userId int, username string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
username,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserCache updates all user cache fields
func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
if !common.RedisEnabled {
return nil
} }
if err := updateUserGroupCache(userId, userGroup); err != nil { // If Redis fails, get from DB
return fmt.Errorf("update group cache: %w", err) fromDB = true
} user, err = GetUserById(userId, false)
if err := updateUserQuotaCache(userId, quota); err != nil {
return fmt.Errorf("update quota cache: %w", err)
}
if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
return fmt.Errorf("update status cache: %w", err)
}
if err := updateUserNameCache(userId, username); err != nil {
return fmt.Errorf("update username cache: %w", err)
}
return nil
}
// getUserGroupCache gets user group from cache
func getUserGroupCache(userId int) (string, error) {
if !common.RedisEnabled {
return "", nil
}
return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId))
}
// getUserQuotaCache gets user quota from cache
func getUserQuotaCache(userId int) (int, error) {
if !common.RedisEnabled {
return 0, nil
}
quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
if err != nil { if err != nil {
return 0, err return nil, err // Return nil and error if DB lookup fails
} }
return strconv.Atoi(quotaStr)
// Create cache object from user data
userCache = &UserBase{
Id: user.Id,
Group: user.Group,
Quota: user.Quota,
Status: user.Status,
Username: user.Username,
Setting: user.Setting,
Email: user.Email,
}
return userCache, nil
} }
// getUserStatusCache gets user status from cache func cacheGetUserBase(userId int) (*UserBase, error) {
func getUserStatusCache(userId int) (int, error) {
if !common.RedisEnabled { if !common.RedisEnabled {
return 0, nil return nil, fmt.Errorf("redis is not enabled")
} }
statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId)) var userCache UserBase
// Try getting from Redis first
err := common.RedisHGetObj(getUserCacheKey(userId), &userCache)
if err != nil { if err != nil {
return 0, err return nil, err
} }
return strconv.Atoi(statusStr) return &userCache, nil
} }
// getUserNameCache gets username from cache // Add atomic quota operations using hash fields
func getUserNameCache(userId int) (string, error) {
if !common.RedisEnabled {
return "", nil
}
return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId))
}
// getUserCache gets complete user cache
func getUserCache(userId int) (*userCache, error) {
if !common.RedisEnabled {
return nil, nil
}
group, err := getUserGroupCache(userId)
if err != nil {
return nil, fmt.Errorf("get group cache: %w", err)
}
quota, err := getUserQuotaCache(userId)
if err != nil {
return nil, fmt.Errorf("get quota cache: %w", err)
}
status, err := getUserStatusCache(userId)
if err != nil {
return nil, fmt.Errorf("get status cache: %w", err)
}
username, err := getUserNameCache(userId)
if err != nil {
return nil, fmt.Errorf("get username cache: %w", err)
}
return &userCache{
Id: userId,
Group: group,
Quota: quota,
Status: status,
Username: username,
}, nil
}
// Add atomic quota operations
func cacheIncrUserQuota(userId int, delta int64) error { func cacheIncrUserQuota(userId int, delta int64) error {
if !common.RedisEnabled { if !common.RedisEnabled {
return nil return nil
} }
key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId) return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta)
return common.RedisIncr(key, delta)
} }
func cacheDecrUserQuota(userId int, delta int64) error { func cacheDecrUserQuota(userId int, delta int64) error {
return cacheIncrUserQuota(userId, -delta) return cacheIncrUserQuota(userId, -delta)
} }
// Helper functions to get individual fields if needed
func getUserGroupCache(userId int) (string, error) {
cache, err := GetUserCache(userId)
if err != nil {
return "", err
}
return cache.Group, nil
}
func getUserQuotaCache(userId int) (int, error) {
cache, err := GetUserCache(userId)
if err != nil {
return 0, err
}
return cache.Quota, nil
}
func getUserStatusCache(userId int) (int, error) {
cache, err := GetUserCache(userId)
if err != nil {
return 0, err
}
return cache.Status, nil
}
func getUserNameCache(userId int) (string, error) {
cache, err := GetUserCache(userId)
if err != nil {
return "", err
}
return cache.Username, nil
}
func getUserSettingCache(userId int) (map[string]interface{}, error) {
setting := make(map[string]interface{})
cache, err := GetUserCache(userId)
if err != nil {
return setting, err
}
return cache.GetSetting(), nil
}
// New functions for individual field updates
func updateUserStatusCache(userId int, status bool) error {
if !common.RedisEnabled {
return nil
}
statusInt := common.UserStatusEnabled
if !status {
statusInt = common.UserStatusDisabled
}
return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt))
}
func updateUserQuotaCache(userId int, quota int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota))
}
func updateUserGroupCache(userId int, group string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Group", group)
}
func updateUserNameCache(userId int, username string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Username", username)
}
func updateUserSettingCache(userId int, setting string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting)
}

View File

@@ -15,6 +15,7 @@ type Adaptor interface {
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)

View File

@@ -49,9 +49,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
return baiduEmbeddingRequest, nil
default: default:
aliReq := requestOpenAI2Ali(*request) aliReq := requestOpenAI2Ali(*request)
return aliReq, nil return aliReq, nil
@@ -67,6 +64,10 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return embeddingRequestOpenAI2Ali(request), nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented") return nil, errors.New("not implemented")

View File

@@ -25,9 +25,12 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque
return &request return &request
} }
func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest { func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
if request.Model == "" {
request.Model = "text-embedding-v1"
}
return &AliEmbeddingRequest{ return &AliEmbeddingRequest{
Model: "text-embedding-v1", Model: request.Model,
Input: struct { Input: struct {
Texts []string `json:"texts"` Texts []string `json:"texts"`
}{ }{

View File

@@ -130,7 +130,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo,
if err != nil { if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err) return nil, fmt.Errorf("setup request header failed: %w", err)
} }
resp, err := doRequest(c, req, info.ToRelayInfo()) resp, err := doRequest(c, req, info.RelayInfo)
if err != nil { if err != nil {
return nil, fmt.Errorf("do request failed: %w", err) return nil, fmt.Errorf("do request failed: %w", err)
} }

View File

@@ -8,6 +8,7 @@ import (
"one-api/dto" "one-api/dto"
"one-api/relay/channel/claude" "one-api/relay/channel/claude"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
) )
const ( const (
@@ -38,6 +39,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
model_setting.GetClaudeSettings().WriteHeaders(req)
return nil return nil
} }
@@ -49,8 +51,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
var claudeReq *claude.ClaudeRequest var claudeReq *claude.ClaudeRequest
var err error var err error
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request) claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
if err != nil {
c.Set("request_model", request.Model) return nil, err
}
c.Set("request_model", claudeReq.Model)
c.Set("converted_request", claudeReq) c.Set("converted_request", claudeReq)
return claudeReq, err return claudeReq, err
} }
@@ -59,6 +63,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return nil, nil return nil, nil
} }

View File

@@ -9,7 +9,8 @@ var awsModelIDMap = map[string]string{
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0", "claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0", "claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
} }
var ChannelName = "aws" var ChannelName = "aws"

View File

@@ -16,6 +16,7 @@ type AwsClaudeRequest struct {
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
Tools []claude.Tool `json:"tools,omitempty"` Tools []claude.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
Thinking *claude.Thinking `json:"thinking,omitempty"`
} }
func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest { func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
@@ -30,5 +31,6 @@ func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
StopSequences: req.StopSequences, StopSequences: req.StopSequences,
Tools: req.Tools, Tools: req.Tools,
ToolChoice: req.ToolChoice, ToolChoice: req.ToolChoice,
Thinking: req.Thinking,
} }
} }

View File

@@ -109,9 +109,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
return baiduEmbeddingRequest, nil
default: default:
baiduRequest := requestOpenAI2Baidu(*request) baiduRequest := requestOpenAI2Baidu(*request)
return baiduRequest, nil return baiduRequest, nil
@@ -122,6 +119,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(request)
return baiduEmbeddingRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -87,7 +87,7 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.Cha
return &response return &response
} }
func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest { func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
return &BaiduEmbeddingRequest{ return &BaiduEmbeddingRequest{
Input: request.ParseInput(), Input: request.ParseInput(),
} }

View File

@@ -0,0 +1,76 @@
package baidu_v2
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,29 @@
package baidu_v2
var ModelList = []string{
"ernie-4.0-8k-latest",
"ernie-4.0-8k-preview",
"ernie-4.0-8k",
"ernie-4.0-turbo-8k-latest",
"ernie-4.0-turbo-8k-preview",
"ernie-4.0-turbo-8k",
"ernie-4.0-turbo-128k",
"ernie-3.5-8k-preview",
"ernie-3.5-8k",
"ernie-3.5-128k",
"ernie-speed-8k",
"ernie-speed-128k",
"ernie-speed-pro-128k",
"ernie-lite-8k",
"ernie-lite-pro-128k",
"ernie-tiny-8k",
"ernie-char-8k",
"ernie-char-fiction-8k",
"ernie-novel-8k",
"deepseek-v3",
"deepseek-r1",
"deepseek-r1-distill-qwen-32b",
"deepseek-r1-distill-qwen-14b",
}
var ChannelName = "volcengine"

View File

@@ -9,6 +9,7 @@ import (
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
"strings" "strings"
) )
@@ -55,6 +56,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
anthropicVersion = "2023-06-01" anthropicVersion = "2023-06-01"
} }
req.Set("anthropic-version", anthropicVersion) req.Set("anthropic-version", anthropicVersion)
model_setting.GetClaudeSettings().WriteHeaders(req)
return nil return nil
} }
@@ -73,6 +75,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -11,6 +11,8 @@ var ModelList = []string{
"claude-3-5-haiku-20241022", "claude-3-5-haiku-20241022",
"claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022",
"claude-3-7-sonnet-20250219",
"claude-3-7-sonnet-20250219-thinking",
} }
var ChannelName = "claude" var ChannelName = "claude"

View File

@@ -11,6 +11,9 @@ type ClaudeMediaMessage struct {
Usage *ClaudeUsage `json:"usage,omitempty"` Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"` StopReason *string `json:"stop_reason,omitempty"`
PartialJson string `json:"partial_json,omitempty"` PartialJson string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
Delta string `json:"delta,omitempty"`
// tool_calls // tool_calls
Id string `json:"id,omitempty"` Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
@@ -54,9 +57,15 @@ type ClaudeRequest struct {
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"` //ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"` Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
Thinking *Thinking `json:"thinking,omitempty"`
}
type Thinking struct {
Type string `json:"type"`
BudgetTokens int `json:"budget_tokens"`
} }
type ClaudeError struct { type ClaudeError struct {

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"one-api/setting/model_setting"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -92,6 +93,30 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
Stream: textRequest.Stream, Stream: textRequest.Stream,
Tools: claudeTools, Tools: claudeTools,
} }
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
strings.HasSuffix(textRequest.Model, "-thinking") {
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().ThinkingAdapterMaxTokens)
}
// 因为BudgetTokens 必须大于1024
if claudeRequest.MaxTokens < 1280 {
claudeRequest.MaxTokens = 1280
}
// BudgetTokens 为 max_tokens 的 80%
claudeRequest.Thinking = &Thinking{
Type: "enabled",
BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
claudeRequest.TopP = 0
claudeRequest.Temperature = common.GetPointer[float64](1.0)
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
}
if claudeRequest.MaxTokens == 0 { if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096 claudeRequest.MaxTokens = 4096
} }
@@ -273,7 +298,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
tools := make([]dto.ToolCall, 0) tools := make([]dto.ToolCallResponse, 0)
var choice dto.ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion { if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion) choice.Delta.SetContentString(claudeResponse.Completion)
@@ -292,10 +317,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.ContentBlock != nil { if claudeResponse.ContentBlock != nil {
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text) //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
if claudeResponse.ContentBlock.Type == "tool_use" { if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCall{ tools = append(tools, dto.ToolCallResponse{
ID: claudeResponse.ContentBlock.Id, ID: claudeResponse.ContentBlock.Id,
Type: "function", Type: "function",
Function: dto.FunctionCall{ Function: dto.FunctionResponse{
Name: claudeResponse.ContentBlock.Name, Name: claudeResponse.ContentBlock.Name,
Arguments: "", Arguments: "",
}, },
@@ -308,12 +333,20 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.Delta != nil { if claudeResponse.Delta != nil {
choice.Index = claudeResponse.Index choice.Index = claudeResponse.Index
choice.Delta.SetContentString(claudeResponse.Delta.Text) choice.Delta.SetContentString(claudeResponse.Delta.Text)
if claudeResponse.Delta.Type == "input_json_delta" { switch claudeResponse.Delta.Type {
tools = append(tools, dto.ToolCall{ case "input_json_delta":
Function: dto.FunctionCall{ tools = append(tools, dto.ToolCallResponse{
Function: dto.FunctionResponse{
Arguments: claudeResponse.Delta.PartialJson, Arguments: claudeResponse.Delta.PartialJson,
}, },
}) })
case "signature_delta":
// 加密的不处理
signatureContent := "\n"
choice.Delta.ReasoningContent = &signatureContent
case "thinking_delta":
thinkingContent := claudeResponse.Delta.Thinking
choice.Delta.ReasoningContent = &thinkingContent
} }
} }
} else if claudeResponse.Type == "message_delta" { } else if claudeResponse.Type == "message_delta" {
@@ -351,7 +384,9 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(claudeResponse.Content) > 0 { if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text responseText = claudeResponse.Content[0].Text
} }
tools := make([]dto.ToolCall, 0) tools := make([]dto.ToolCallResponse, 0)
thinkingContent := ""
if reqMode == RequestModeCompletion { if reqMode == RequestModeCompletion {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " ")) content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := dto.OpenAITextResponseChoice{ choice := dto.OpenAITextResponseChoice{
@@ -367,16 +402,22 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
} else { } else {
fullTextResponse.Id = claudeResponse.Id fullTextResponse.Id = claudeResponse.Id
for _, message := range claudeResponse.Content { for _, message := range claudeResponse.Content {
if message.Type == "tool_use" { switch message.Type {
case "tool_use":
args, _ := json.Marshal(message.Input) args, _ := json.Marshal(message.Input)
tools = append(tools, dto.ToolCall{ tools = append(tools, dto.ToolCallResponse{
ID: message.Id, ID: message.Id,
Type: "function", // compatible with other OpenAI derivative applications Type: "function", // compatible with other OpenAI derivative applications
Function: dto.FunctionCall{ Function: dto.FunctionResponse{
Name: message.Name, Name: message.Name,
Arguments: string(args), Arguments: string(args),
}, },
}) })
case "thinking":
// 加密的不管, 只输出明文的推理过程
thinkingContent = message.Thinking
case "text":
responseText = message.Text
} }
} }
} }
@@ -391,6 +432,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(tools) > 0 { if len(tools) > 0 {
choice.Message.SetToolCalls(tools) choice.Message.SetToolCalls(tools)
} }
choice.Message.ReasoningContent = thinkingContent
fullTextResponse.Model = claudeResponse.Model fullTextResponse.Model = claudeResponse.Model
choices = append(choices, choice) choices = append(choices, choice)
fullTextResponse.Choices = choices fullTextResponse.Choices = choices

View File

@@ -4,13 +4,14 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
"github.com/gin-gonic/gin"
) )
type Adaptor struct { type Adaptor struct {
@@ -56,6 +57,10 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil return request, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
// 添加文件字段 // 添加文件字段
file, _, err := c.Request.FormFile("file") file, _, err := c.Request.FormFile("file")

View File

@@ -54,6 +54,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return requestConvertRerank2Cohere(request), nil return requestConvertRerank2Cohere(request), nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info) err, usage = cohereRerankHandler(c, resp, info)

View File

@@ -10,6 +10,7 @@ import (
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant"
) )
type Adaptor struct { type Adaptor struct {
@@ -29,7 +30,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil switch info.RelayMode {
case constant.RelayModeCompletions:
return fmt.Sprintf("%s/beta/completions", info.BaseUrl), nil
default:
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -49,6 +55,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -9,9 +9,18 @@ import (
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"strings"
)
const (
BotTypeChatFlow = 1 // chatflow default
BotTypeAgent = 2
BotTypeWorkFlow = 3
BotTypeCompletion = 4
) )
type Adaptor struct { type Adaptor struct {
BotType int
} }
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -25,10 +34,28 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "agent") {
a.BotType = BotTypeAgent
} else if strings.HasPrefix(info.UpstreamModelName, "workflow") {
a.BotType = BotTypeWorkFlow
} else if strings.HasPrefix(info.UpstreamModelName, "chat") {
a.BotType = BotTypeCompletion
} else {
a.BotType = BotTypeChatFlow
}
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil switch a.BotType {
case BotTypeWorkFlow:
return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil
case BotTypeCompletion:
return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil
case BotTypeAgent:
fallthrough
default:
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
}
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -48,6 +75,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -1,15 +1,21 @@
package gemini package gemini
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/constant" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting/model_setting"
"strings"
"github.com/gin-gonic/gin"
) )
type Adaptor struct { type Adaptor struct {
@@ -21,8 +27,36 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
} }
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
return nil, errors.New("not implemented") return nil, errors.New("not supported model for image generation")
}
// convert size to aspect ratio
aspectRatio := "1:1" // default aspect ratio
switch request.Size {
case "1024x1024":
aspectRatio = "1:1"
case "1024x1792":
aspectRatio = "9:16"
case "1792x1024":
aspectRatio = "16:9"
}
// build gemini imagen request
geminiRequest := GeminiImageRequest{
Instances: []GeminiImageInstance{
{
Prompt: request.Prompt,
},
},
Parameters: GeminiImageParameters{
SampleCount: request.N,
AspectRatio: aspectRatio,
PersonGeneration: "allow_adult", // default allow adult
},
}
return geminiRequest, nil
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -30,14 +64,10 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1beta" version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
version, beta := constant.GeminiModelMap[info.UpstreamModelName]
if !beta { if strings.HasPrefix(info.UpstreamModelName, "imagen") {
if info.ApiVersion != "" { return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
version = info.ApiVersion
} else {
version = "v1beta"
}
} }
action := "generateContent" action := "generateContent"
@@ -68,11 +98,20 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return GeminiImageHandler(c, resp, info)
}
if info.IsStream { if info.IsStream {
err, usage = GeminiChatStreamHandler(c, resp, info) err, usage = GeminiChatStreamHandler(c, resp, info)
} else { } else {
@@ -81,6 +120,60 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return return
} }
func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
}
_ = resp.Body.Close()
var geminiResponse GeminiImageResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
if len(geminiResponse.Predictions) == 0 {
return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest)
}
// convert to openai format response
openAIResponse := dto.ImageResponse{
Created: common.GetTimestamp(),
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
}
for _, prediction := range geminiResponse.Predictions {
if prediction.RaiFilteredReason != "" {
continue // skip filtered image
}
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage = &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
}
return usage, nil
}
func (a *Adaptor) GetModelList() []string { func (a *Adaptor) GetModelList() []string {
return ModelList return ModelList
} }

View File

@@ -3,17 +3,29 @@ package gemini
var ModelList = []string{ var ModelList = []string{
// stable version // stable version
"gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b",
"gemini-2.0-flash",
// latest version // latest version
"gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
// legacy version // preview version
"gemini-1.5-pro-exp-0827", "gemini-1.5-flash-exp-0827", "gemini-2.0-flash-lite-preview",
// exp // gemini exp
"gemini-exp-1114", "gemini-exp-1121", "gemini-exp-1206", "gemini-exp-1206",
// flash exp // flash exp
"gemini-2.0-flash-exp", "gemini-2.0-flash-exp",
// pro exp
"gemini-2.0-pro-exp",
// thinking exp // thinking exp
"gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp",
"gemini-2.0-flash-thinking-exp-1219", // imagen models
"imagen-3.0-generate-002",
}
var SafetySettingList = []string{
"HARM_CATEGORY_HARASSMENT",
"HARM_CATEGORY_HATE_SPEECH",
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
"HARM_CATEGORY_DANGEROUS_CONTENT",
"HARM_CATEGORY_CIVIC_INTEGRITY",
} }
var ChannelName = "google gemini" var ChannelName = "google gemini"

View File

@@ -109,3 +109,30 @@ type GeminiUsageMetadata struct {
CandidatesTokenCount int `json:"candidatesTokenCount"` CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"` TotalTokenCount int `json:"totalTokenCount"`
} }
// Imagen related structs
type GeminiImageRequest struct {
Instances []GeminiImageInstance `json:"instances"`
Parameters GeminiImageParameters `json:"parameters"`
}
type GeminiImageInstance struct {
Prompt string `json:"prompt"`
}
type GeminiImageParameters struct {
SampleCount int `json:"sampleCount,omitempty"`
AspectRatio string `json:"aspectRatio,omitempty"`
PersonGeneration string `json:"personGeneration,omitempty"`
}
type GeminiImageResponse struct {
Predictions []GeminiImagePrediction `json:"predictions"`
}
type GeminiImagePrediction struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
SafetyAttributes any `json:"safetyAttributes,omitempty"`
}

View File

@@ -11,6 +11,7 @@ import (
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"one-api/setting/model_setting"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
@@ -22,28 +23,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
geminiRequest := GeminiChatRequest{ geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
SafetySettings: []GeminiChatSafetySettings{ //SafetySettings: []GeminiChatSafetySettings{},
{
Category: "HARM_CATEGORY_HARASSMENT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_HATE_SPEECH",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_CIVIC_INTEGRITY",
Threshold: common.GeminiSafetySetting,
},
},
GenerationConfig: GeminiChatGenerationConfig{ GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature, Temperature: textRequest.Temperature,
TopP: textRequest.TopP, TopP: textRequest.TopP,
@@ -52,9 +32,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}, },
} }
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
for _, category := range SafetySettingList {
safetySettings = append(safetySettings, GeminiChatSafetySettings{
Category: category,
Threshold: model_setting.GetGeminiSafetySetting(category),
})
}
geminiRequest.SafetySettings = safetySettings
// openaiContent.FuncToToolCalls() // openaiContent.FuncToToolCalls()
if textRequest.Tools != nil { if textRequest.Tools != nil {
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools)) functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
googleSearch := false googleSearch := false
codeExecution := false codeExecution := false
for _, tool := range textRequest.Tools { for _, tool := range textRequest.Tools {
@@ -349,7 +338,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
return data return data
} }
func getToolCall(item *GeminiPart) *dto.ToolCall { func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
var argsBytes []byte var argsBytes []byte
var err error var err error
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok { if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
@@ -361,10 +350,10 @@ func getToolCall(item *GeminiPart) *dto.ToolCall {
if err != nil { if err != nil {
return nil return nil
} }
return &dto.ToolCall{ return &dto.ToolCallResponse{
ID: fmt.Sprintf("call_%s", common.GetUUID()), ID: fmt.Sprintf("call_%s", common.GetUUID()),
Type: "function", Type: "function",
Function: dto.FunctionCall{ Function: dto.FunctionResponse{
Arguments: string(argsBytes), Arguments: string(argsBytes),
Name: item.FunctionCall.FunctionName, Name: item.FunctionCall.FunctionName,
}, },
@@ -379,7 +368,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)), Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
} }
content, _ := json.Marshal("") content, _ := json.Marshal("")
is_tool_call := false isToolCall := false
for _, candidate := range response.Candidates { for _, candidate := range response.Candidates {
choice := dto.OpenAITextResponseChoice{ choice := dto.OpenAITextResponseChoice{
Index: int(candidate.Index), Index: int(candidate.Index),
@@ -391,12 +380,12 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
} }
if len(candidate.Content.Parts) > 0 { if len(candidate.Content.Parts) > 0 {
var texts []string var texts []string
var tool_calls []dto.ToolCall var toolCalls []dto.ToolCallResponse
for _, part := range candidate.Content.Parts { for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil { if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls choice.FinishReason = constant.FinishReasonToolCalls
if call := getToolCall(&part); call != nil { if call := getResponseToolCall(&part); call != nil {
tool_calls = append(tool_calls, *call) toolCalls = append(toolCalls, *call)
} }
} else { } else {
if part.ExecutableCode != nil { if part.ExecutableCode != nil {
@@ -411,9 +400,9 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
} }
} }
} }
if len(tool_calls) > 0 { if len(toolCalls) > 0 {
choice.Message.SetToolCalls(tool_calls) choice.Message.SetToolCalls(toolCalls)
is_tool_call = true isToolCall = true
} }
choice.Message.SetStringContent(strings.Join(texts, "\n")) choice.Message.SetStringContent(strings.Join(texts, "\n"))
@@ -429,7 +418,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
choice.FinishReason = constant.FinishReasonContentFilter choice.FinishReason = constant.FinishReasonContentFilter
} }
} }
if is_tool_call { if isToolCall {
choice.FinishReason = constant.FinishReasonToolCalls choice.FinishReason = constant.FinishReasonToolCalls
} }
@@ -468,7 +457,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
for _, part := range candidate.Content.Parts { for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil { if part.FunctionCall != nil {
isTools = true isTools = true
if call := getToolCall(&part); call != nil { if call := getResponseToolCall(&part); call != nil {
call.SetIndex(len(choice.Delta.ToolCalls)) call.SetIndex(len(choice.Delta.ToolCalls))
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call) choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
} }

View File

@@ -55,9 +55,13 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil return request, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp) err, usage = JinaRerankHandler(c, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings { } else if info.RelayMode == constant.RelayModeEmbeddings {
err, usage = jinaEmbeddingHandler(c, resp) err, usage = jinaEmbeddingHandler(c, resp)
} }

View File

@@ -9,7 +9,7 @@ import (
"one-api/service" "one-api/service"
) )
func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func JinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -41,15 +41,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
mistralReq := requestOpenAI2Mistral(*request) return requestOpenAI2Mistral(request), nil
//common.LogJson(c, "body", mistralReq)
return mistralReq, nil
} }
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -1,25 +1,21 @@
package mistral package mistral
import ( import (
"encoding/json"
"one-api/dto" "one-api/dto"
) )
func requestOpenAI2Mistral(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages)) messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages { for _, message := range request.Messages {
if !message.IsStringContent() { mediaMessages := message.ParseContent()
mediaMessages := message.ParseContent() for j, mediaMessage := range mediaMessages {
for j, mediaMessage := range mediaMessages { if mediaMessage.Type == dto.ContentTypeImageURL {
if mediaMessage.Type == dto.ContentTypeImageURL { imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) mediaMessage.ImageUrl = imageUrl.Url
mediaMessage.ImageUrl = imageUrl.Url mediaMessages[j] = mediaMessage
mediaMessages[j] = mediaMessage
}
} }
messageRaw, _ := json.Marshal(mediaMessages)
message.Content = messageRaw
} }
message.SetMediaContent(mediaMessages)
messages = append(messages, dto.Message{ messages = append(messages, dto.Message{
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.Content,

View File

@@ -0,0 +1,93 @@
package mokaai
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"strings"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return request, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix := "chat/"
if strings.HasPrefix(info.UpstreamModelName, "m3e") {
suffix = "embeddings"
}
fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix)
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Moka(*request)
return baiduEmbeddingRequest, nil
default:
return nil, errors.New("not implemented")
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
err, usage = mokaEmbeddingHandler(c, resp)
default:
// err, usage = mokaHandler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,9 @@
package mokaai
var ModelList = []string{
"m3e-large",
"m3e-base",
"m3e-small",
}
var ChannelName = "mokaai"

View File

@@ -0,0 +1,83 @@
package mokaai
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)
func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest {
var input []string // Change input to []string
switch v := request.Input.(type) {
case string:
input = []string{v} // Convert string to []string
case []string:
input = v // Already a []string, no conversion needed
case []interface{}:
for _, part := range v {
if str, ok := part.(string); ok {
input = append(input, str) // Append each string to the slice
}
}
}
return &dto.EmbeddingRequest{
Input: input,
Model: request.Model,
}
}
func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding",
Usage: response.Usage,
}
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var baiduResponse dto.EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
// if baiduResponse.ErrorMsg != "" {
// return &dto.OpenAIErrorWithStatusCode{
// Error: dto.OpenAIError{
// Type: "baidu_error",
// Param: "",
// },
// StatusCode: resp.StatusCode,
// }, nil
// }
fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -39,6 +39,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil return nil
} }
@@ -46,18 +47,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch info.RelayMode { return requestOpenAI2Ollama(*request)
case relayconstant.RelayModeEmbeddings:
return requestOpenAI2Embeddings(*request), nil
default:
return requestOpenAI2Ollama(*request), nil
}
} }
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return requestOpenAI2Embeddings(request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -3,18 +3,22 @@ package ollama
import "one-api/dto" import "one-api/dto"
type OllamaRequest struct { type OllamaRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"` Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"` Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"` Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"` Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
ResponseFormat any `json:"response_format,omitempty"` Tools []dto.ToolCallRequest `json:"tools,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` ResponseFormat any `json:"response_format,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"`
} }
type Options struct { type Options struct {
@@ -35,7 +39,7 @@ type OllamaEmbeddingRequest struct {
} }
type OllamaEmbeddingResponse struct { type OllamaEmbeddingResponse struct {
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
Model string `json:"model"` Model string `json:"model"`
Embedding [][]float64 `json:"embeddings,omitempty"` Embedding [][]float64 `json:"embeddings,omitempty"`
} }

View File

@@ -9,14 +9,36 @@ import (
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/service" "one-api/service"
"strings"
) )
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest { func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
messages := make([]dto.Message, 0, len(request.Messages)) messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages { for _, message := range request.Messages {
if !message.IsStringContent() {
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
// check if not base64
if strings.HasPrefix(imageUrl.Url, "http") {
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
if err != nil {
return nil, err
}
imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
}
mediaMessage.ImageUrl = imageUrl
mediaMessages[j] = mediaMessage
}
}
message.SetMediaContent(mediaMessages)
}
messages = append(messages, dto.Message{ messages = append(messages, dto.Message{
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.Content,
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
}) })
} }
str, ok := request.Stop.(string) str, ok := request.Stop.(string)
@@ -36,13 +58,17 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
TopK: request.TopK, TopK: request.TopK,
Stop: Stop, Stop: Stop,
Tools: request.Tools, Tools: request.Tools,
MaxTokens: request.MaxTokens,
ResponseFormat: request.ResponseFormat, ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty, FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty, PresencePenalty: request.PresencePenalty,
} Prompt: request.Prompt,
StreamOptions: request.StreamOptions,
Suffix: request.Suffix,
}, nil
} }
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest { func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{ return &OllamaEmbeddingRequest{
Model: request.Model, Model: request.Model,
Input: request.ParseInput(), Input: request.ParseInput(),
@@ -123,9 +149,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
} }
func flattenEmbeddings(embeddings [][]float64) []float64 { func flattenEmbeddings(embeddings [][]float64) []float64 {
flattened := []float64{} flattened := []float64{}
for _, row := range embeddings { for _, row := range embeddings {
flattened = append(flattened, row...) flattened = append(flattened, row...)
}
return flattened
} }
return flattened
}

View File

@@ -10,9 +10,11 @@ import (
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"one-api/common" "one-api/common"
constant2 "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/ai360" "one-api/relay/channel/ai360"
"one-api/relay/channel/jina"
"one-api/relay/channel/lingyiwanwu" "one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/minimax" "one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot" "one-api/relay/channel/moonshot"
@@ -44,16 +46,20 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
switch info.ChannelType { switch info.ChannelType {
case common.ChannelTypeAzure: case common.ChannelTypeAzure:
apiVersion := info.ApiVersion
if apiVersion == "" {
apiVersion = constant2.AzureDefaultAPIVersion
}
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(info.RequestURLPath, "?")[0] requestURL := strings.Split(info.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, info.ApiVersion) requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
task := strings.TrimPrefix(requestURL, "/v1/") task := strings.TrimPrefix(requestURL, "/v1/")
model_ := info.UpstreamModelName model_ := info.UpstreamModelName
model_ = strings.Replace(model_, ".", "", -1) model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67 // https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
if info.RelayMode == constant.RelayModeRealtime { if info.RelayMode == constant.RelayModeRealtime {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion) requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
} }
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
case common.ChannelTypeMiniMax: case common.ChannelTypeMiniMax:
@@ -114,7 +120,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
request.MaxCompletionTokens = request.MaxTokens request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0 request.MaxTokens = 0
} }
if strings.HasPrefix(request.Model, "o3") { if strings.HasPrefix(request.Model, "o3") || strings.HasPrefix(request.Model, "o1") {
request.Temperature = nil request.Temperature = nil
} }
if strings.HasSuffix(request.Model, "-high") { if strings.HasSuffix(request.Model, "-high") {
@@ -141,7 +147,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
} }
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented") return request, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
} }
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -219,6 +229,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat) err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
case constant.RelayModeImagesGenerations: case constant.RelayModeImagesGenerations:
err, usage = OpenaiTTSHandler(c, resp, info) err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeRerank:
err, usage = jina.JinaRerankHandler(c, resp)
default: default:
if info.IsStream { if info.IsStream {
err, usage = OaiStreamHandler(c, resp, info) err, usage = OaiStreamHandler(c, resp, info)

View File

@@ -5,7 +5,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/pkg/errors"
"io" "io"
"math" "math"
"mime/multipart" "mime/multipart"
@@ -24,21 +23,62 @@ import (
"github.com/bytedance/gopkg/util/gopool" "github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/pkg/errors"
) )
func sendStreamData(c *gin.Context, data string, forceFormat bool) error { func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
if data == "" { if data == "" {
return nil return nil
} }
if forceFormat { if !forceFormat && !thinkToContent {
var lastStreamResponse dto.ChatCompletionsStreamResponse return service.StringData(c, data)
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil { }
return err
} var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
return err
}
if !thinkToContent {
return service.ObjectData(c, lastStreamResponse) return service.ObjectData(c, lastStreamResponse)
} }
return service.StringData(c, data)
// Handle think to content conversion
if info.IsFirstResponse {
response := lastStreamResponse.Copy()
for i := range response.Choices {
response.Choices[i].Delta.SetContentString("<think>\n")
response.Choices[i].Delta.SetReasoningContent("")
}
service.ObjectData(c, response)
}
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
return service.ObjectData(c, lastStreamResponse)
}
// Process each choice
for i, choice := range lastStreamResponse.Choices {
// Handle transition from thinking to content
if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse {
response := lastStreamResponse.Copy()
for j := range response.Choices {
response.Choices[j].Delta.SetContentString("\n</think>")
response.Choices[j].Delta.SetReasoningContent("")
}
info.SendLastReasoningResponse = true
service.ObjectData(c, response)
}
// Convert reasoning content to regular content
if len(choice.Delta.GetReasoningContent()) > 0 {
lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
lastStreamResponse.Choices[i].Delta.SetReasoningContent("")
}
}
return service.ObjectData(c, lastStreamResponse)
} }
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -57,11 +97,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
var usage = &dto.Usage{} var usage = &dto.Usage{}
var streamItems []string // store stream items var streamItems []string // store stream items
var forceFormat bool var forceFormat bool
var thinkToContent bool
if info.ChannelType == common.ChannelTypeCustom { if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
if forceFmt, ok := info.ChannelSetting["force_format"].(bool); ok { forceFormat = forceFmt
forceFormat = forceFmt }
}
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
thinkToContent = think2Content
} }
toolCount := 0 toolCount := 0
@@ -85,23 +128,28 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
) )
gopool.Go(func() { gopool.Go(func() {
for scanner.Scan() { for scanner.Scan() {
info.SetFirstResponseTime() //info.SetFirstResponseTime()
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text() data := scanner.Text()
if common.DebugEnabled {
println(data)
}
if len(data) < 6 { // ignore blank line or wrong format if len(data) < 6 { // ignore blank line or wrong format
continue continue
} }
if data[:6] != "data: " && data[:6] != "[DONE]" { if data[:5] != "data:" && data[:6] != "[DONE]" {
continue continue
} }
mu.Lock() mu.Lock()
data = data[6:] data = data[5:]
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "[DONE]") { if !strings.HasPrefix(data, "[DONE]") {
if lastStreamData != "" { if lastStreamData != "" {
err := sendStreamData(c, lastStreamData, forceFormat) err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
if err != nil { if err != nil {
common.LogError(c, "streaming error: "+err.Error()) common.LogError(c, "streaming error: "+err.Error())
} }
info.SetFirstResponseTime()
} }
lastStreamData = data lastStreamData = data
streamItems = append(streamItems, data) streamItems = append(streamItems, data)
@@ -141,7 +189,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
} }
} }
if shouldSendLastResp { if shouldSendLastResp {
sendStreamData(c, lastStreamData, forceFormat) sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
} }
// 计算token // 计算token
@@ -162,6 +210,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
//} //}
for _, choice := range streamResponse.Choices { for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString()) responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil { if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount { if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls) toolCount = len(choice.Delta.ToolCalls)
@@ -182,6 +231,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
//} //}
for _, choice := range streamResponse.Choices { for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString()) responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil { if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount { if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls) toolCount = len(choice.Delta.ToolCalls)
@@ -273,7 +323,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0 completionTokens := 0
for _, choice := range simpleResponse.Choices { for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTextToken(string(choice.Message.Content), model) ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent, model)
completionTokens += ctkm completionTokens += ctkm
} }
simpleResponse.Usage = dto.Usage{ simpleResponse.Usage = dto.Usage{

View File

@@ -0,0 +1,74 @@
package openrouter
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
req.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
req.Set("X-Title", "New API")
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,5 @@
package openrouter
var ModelList = []string{}
var ChannelName = "openrouter"

View File

@@ -49,6 +49,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -52,6 +52,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions { } else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeCompletions {
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
} }
return "", errors.New("invalid relay mode") return "", errors.New("invalid relay mode")
} }
@@ -58,6 +60,10 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil return request, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeRerank: case constant.RelayModeRerank:
@@ -68,6 +74,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
} else { } else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }
case constant.RelayModeCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }

View File

@@ -40,7 +40,7 @@ var ModelList = []string{
"Pro/meta-llama/Meta-Llama-3-8B-Instruct", "Pro/meta-llama/Meta-Llama-3-8B-Instruct",
"Pro/mistralai/Mistral-7B-Instruct-v0.2", "Pro/mistralai/Mistral-7B-Instruct-v0.2",
"black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-schnell",
"iic/SenseVoiceSmall", "FunAudioLLM/SenseVoiceSmall",
"netease-youdao/bce-embedding-base_v1", "netease-youdao/bce-embedding-base_v1",
"BAAI/bge-m3", "BAAI/bge-m3",
"internlm/internlm2_5-20b-chat", "internlm/internlm2_5-20b-chat",

View File

@@ -73,6 +73,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -28,6 +28,7 @@ var claudeModelMap = map[string]string{
"claude-3-opus-20240229": "claude-3-opus@20240229", "claude-3-opus-20240229": "claude-3-opus@20240229",
"claude-3-haiku-20240307": "claude-3-haiku@20240307", "claude-3-haiku-20240307": "claude-3-haiku@20240307",
"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620", "claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
} }
const anthropicVersion = "vertex-2023-10-16" const anthropicVersion = "vertex-2023-10-16"
@@ -132,7 +133,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil { if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil {
return nil, errors.New("failed to copy claude request") return nil, errors.New("failed to copy claude request")
} }
c.Set("request_model", request.Model) c.Set("request_model", claudeReq.Model)
return vertexClaudeReq, nil return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini { } else if a.RequestMode == RequestModeGemini {
geminiRequest, err := gemini.CovertGemini2OpenAI(*request) geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
@@ -151,6 +152,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -0,0 +1,92 @@
package volcengine
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"strings"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
if strings.HasPrefix(info.UpstreamModelName, "bot") {
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil
}
return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
default:
}
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,13 @@
package volcengine
var ModelList = []string{
"Doubao-pro-128k",
"Doubao-pro-32k",
"Doubao-pro-4k",
"Doubao-lite-128k",
"Doubao-lite-32k",
"Doubao-lite-4k",
"Doubao-embedding",
}
var ChannelName = "volcengine"

View File

@@ -50,6 +50,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
// xunfei's request is not http request, so we don't need to do anything here // xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{} dummyResp := &http.Response{}

View File

@@ -56,6 +56,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -53,6 +53,12 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }

View File

@@ -90,8 +90,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
mediaMessages[j] = mediaMessage mediaMessages[j] = mediaMessage
} }
} }
messageRaw, _ := json.Marshal(mediaMessages) message.SetMediaContent(mediaMessages)
message.Content = messageRaw
} }
messages = append(messages, dto.Message{ messages = append(messages, dto.Message{
Role: message.Role, Role: message.Role,

View File

@@ -13,24 +13,25 @@ import (
) )
type RelayInfo struct { type RelayInfo struct {
ChannelType int ChannelType int
ChannelId int ChannelId int
TokenId int TokenId int
TokenKey string TokenKey string
UserId int UserId int
Group string Group string
TokenUnlimited bool TokenUnlimited bool
StartTime time.Time StartTime time.Time
FirstResponseTime time.Time FirstResponseTime time.Time
setFirstResponse bool IsFirstResponse bool
ApiType int SendLastReasoningResponse bool
IsStream bool ApiType int
IsPlayground bool IsStream bool
UsePrice bool IsPlayground bool
RelayMode int UsePrice bool
UpstreamModelName string RelayMode int
OriginModelName string UpstreamModelName string
RecodeModelName string OriginModelName string
//RecodeModelName string
RequestURLPath string RequestURLPath string
ApiVersion string ApiVersion string
PromptTokens int PromptTokens int
@@ -39,6 +40,7 @@ type RelayInfo struct {
BaseUrl string BaseUrl string
SupportStreamOptions bool SupportStreamOptions bool
ShouldIncludeUsage bool ShouldIncludeUsage bool
IsModelMapped bool
ClientWs *websocket.Conn ClientWs *websocket.Conn
TargetWs *websocket.Conn TargetWs *websocket.Conn
InputAudioFormat string InputAudioFormat string
@@ -48,6 +50,21 @@ type RelayInfo struct {
AudioUsage bool AudioUsage bool
ReasoningEffort string ReasoningEffort string
ChannelSetting map[string]interface{} ChannelSetting map[string]interface{}
UserSetting map[string]interface{}
UserEmail string
UserQuota int
}
// 定义支持流式选项的通道类型
var streamSupportedChannels = map[int]bool{
common.ChannelTypeOpenAI: true,
common.ChannelTypeAnthropic: true,
common.ChannelTypeAws: true,
common.ChannelTypeGemini: true,
common.ChannelCloudflare: true,
common.ChannelTypeAzure: true,
common.ChannelTypeVolcEngine: true,
common.ChannelTypeOllama: true,
} }
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
@@ -75,6 +92,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
apiType, _ := relayconstant.ChannelType2APIType(channelType) apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &RelayInfo{ info := &RelayInfo{
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
UserEmail: c.GetString(constant.ContextKeyUserEmail),
IsFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"), BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(), RequestURLPath: c.Request.URL.String(),
@@ -89,12 +110,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
FirstResponseTime: startTime.Add(-time.Second), FirstResponseTime: startTime.Add(-time.Second),
OriginModelName: c.GetString("original_model"), OriginModelName: c.GetString("original_model"),
UpstreamModelName: c.GetString("original_model"), UpstreamModelName: c.GetString("original_model"),
RecodeModelName: c.GetString("recode_model"), //RecodeModelName: c.GetString("original_model"),
ApiType: apiType, IsModelMapped: false,
ApiVersion: c.GetString("api_version"), ApiType: apiType,
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), ApiVersion: c.GetString("api_version"),
Organization: c.GetString("channel_organization"), ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
ChannelSetting: channelSetting, Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting,
} }
if strings.HasPrefix(c.Request.URL.Path, "/pg") { if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true info.IsPlayground = true
@@ -110,9 +132,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if info.ChannelType == common.ChannelTypeVertexAi { if info.ChannelType == common.ChannelTypeVertexAi {
info.ApiVersion = c.GetString("region") info.ApiVersion = c.GetString("region")
} }
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || if streamSupportedChannels[info.ChannelType] {
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure {
info.SupportStreamOptions = true info.SupportStreamOptions = true
} }
return info return info
@@ -127,26 +147,14 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
} }
func (info *RelayInfo) SetFirstResponseTime() { func (info *RelayInfo) SetFirstResponseTime() {
if !info.setFirstResponse { if info.IsFirstResponse {
info.FirstResponseTime = time.Now() info.FirstResponseTime = time.Now()
info.setFirstResponse = true info.IsFirstResponse = false
} }
} }
type TaskRelayInfo struct { type TaskRelayInfo struct {
ChannelType int *RelayInfo
ChannelId int
TokenId int
UserId int
Group string
StartTime time.Time
ApiType int
RelayMode int
UpstreamModelName string
RequestURLPath string
ApiKey string
BaseUrl string
Action string Action string
OriginTaskID string OriginTaskID string
@@ -154,48 +162,8 @@ type TaskRelayInfo struct {
} }
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo { func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
startTime := time.Now()
apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &TaskRelayInfo{ info := &TaskRelayInfo{
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), RelayInfo: GenRelayInfo(c),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
TokenId: tokenId,
UserId: userId,
Group: group,
StartTime: startTime,
ApiType: apiType,
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
}
if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType]
} }
return info return info
} }
func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
return &RelayInfo{
ChannelType: info.ChannelType,
ChannelId: info.ChannelId,
TokenId: info.TokenId,
UserId: info.UserId,
Group: info.Group,
StartTime: info.StartTime,
ApiType: info.ApiType,
RelayMode: info.RelayMode,
UpstreamModelName: info.UpstreamModelName,
RequestURLPath: info.RequestURLPath,
ApiKey: info.ApiKey,
BaseUrl: info.BaseUrl,
}
}

View File

@@ -27,7 +27,10 @@ const (
APITypeVertexAi APITypeVertexAi
APITypeMistral APITypeMistral
APITypeDeepSeek APITypeDeepSeek
APITypeMokaAI
APITypeVolcEngine
APITypeBaiduV2
APITypeOpenRouter
APITypeDummy // this one is only for count, do not add any channel after this APITypeDummy // this one is only for count, do not add any channel after this
) )
@@ -78,6 +81,14 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeMistral apiType = APITypeMistral
case common.ChannelTypeDeepSeek: case common.ChannelTypeDeepSeek:
apiType = APITypeDeepSeek apiType = APITypeDeepSeek
case common.ChannelTypeMokaAI:
apiType = APITypeMokaAI
case common.ChannelTypeVolcEngine:
apiType = APITypeVolcEngine
case common.ChannelTypeBaiduV2:
apiType = APITypeBaiduV2
case common.ChannelTypeOpenRouter:
apiType = APITypeOpenRouter
} }
if apiType == -1 { if apiType == -1 {
return APITypeOpenAI, false return APITypeOpenAI, false

Some files were not shown because too many files have changed in this diff Show More