Compare commits

...

36 Commits

Author SHA1 Message Date
1808837298@qq.com
3a18c0ce9f fix: Correct model request configuration in Vertex Claude adaptor 2025-02-27 20:51:10 +08:00
1808837298@qq.com
929668bead feat: Refactor model configuration management with new config system
- Introduce a new configuration management approach for model-specific settings
- Update Gemini settings to use the new config system with more flexible management
- Add support for dynamic configuration updates in option handling
- Modify Claude and Vertex adaptors to use new configuration methods
- Enhance web interface to support namespaced configuration keys
2025-02-27 20:49:34 +08:00
1808837298@qq.com
06a78f9042 feat: Add Claude model configuration management #791 2025-02-27 20:49:21 +08:00
1808837298@qq.com
0f1c4c4ebe fix: Add pagination support to user search functionality 2025-02-27 16:55:02 +08:00
1808837298@qq.com
1bcf7a3c39 chore: Update Azure OpenAI API version and embedding model detection
- Enhance channel test to detect more embedding models
- Update Azure OpenAI default API version to 2024-12-01-preview
- Remove redundant default API version setting in channel edit
- Add user cache writing in channel test
2025-02-27 16:49:32 +08:00
1808837298@qq.com
5f0b3f6d6f fix: Improve AWS Claude adaptor request conversion error handling #796 2025-02-27 14:57:00 +08:00
1808837298@qq.com
19a318c943 init openrouter adaptor 2025-02-27 00:01:21 +08:00
1808837298@qq.com
13ab0f8e4f fix: gemini&claude tool call format #795 #766 2025-02-26 23:56:10 +08:00
1808837298@qq.com
6d8d40e67b fix: claude tool call format #795 #766 2025-02-26 23:40:16 +08:00
1808837298@qq.com
287caf8e38 feat: Add Jina reranking support for OpenAI adaptor 2025-02-26 21:46:06 +08:00
1808837298@qq.com
c802b3b41a fix: Update Gemini safety settings to use 'OFF' as default 2025-02-26 19:20:17 +08:00
1808837298@qq.com
ed4e1c2332 fix: Update Gemini safety settings category 2025-02-26 19:18:00 +08:00
1808837298@qq.com
e581ea33c2 fix: Update Gemini safety settings default value 2025-02-26 19:01:45 +08:00
1808837298@qq.com
bf80d71ddf feat: Add Gemini version settings configuration support (close #568) 2025-02-26 18:19:09 +08:00
1808837298@qq.com
e19b244e73 feat: Add Gemini safety settings configuration support (close #703) 2025-02-26 16:54:43 +08:00
1808837298@qq.com
f451268830 feat: Update Claude relay temperature setting 2025-02-25 22:01:05 +08:00
1808837298@qq.com
069f2672c1 refactor: Enhance user context and quota management
- Add new context keys for user-related information
- Modify user cache and authentication middleware to populate context
- Refactor quota and notification services to use context-based user data
- Remove redundant database queries by leveraging context information
- Update various components to use new context-based user retrieval methods
2025-02-25 20:56:16 +08:00
1808837298@qq.com
ccf13d445f feat: redis poolsize 2025-02-25 19:39:29 +08:00
1808837298@qq.com
da4d1861fe fix: Adjust Claude thinking mode request parameters 2025-02-25 16:52:45 +08:00
1808837298@qq.com
3de5b96cb4 docs: Update README 2025-02-25 16:31:42 +08:00
Calcium-Ion
5b9e275690 Merge pull request #788 from MartialBE/main
feat: Add Claude 3.7 Sonnet thinking mode support
2025-02-25 15:21:39 +08:00
1808837298@qq.com
607e3206b3 Merge branch 'main' into thinking
# Conflicts:
#	relay/channel/claude/dto.go
2025-02-25 15:21:22 +08:00
1808837298@qq.com
83feb492fb feat: Add support for Claude thinking parameter in request 2025-02-25 14:37:03 +08:00
MartialBE
4f212be45c feat: Add Claude 3.7 Sonnet thinking mode support 2025-02-25 14:10:43 +08:00
1808837298@qq.com
92918e3751 feat: Add Claude 3.7 Sonnet model to AWS channel mapping 2025-02-25 02:55:23 +08:00
1808837298@qq.com
de15551570 feat: Add support for Claude 3.7 Sonnet model 2025-02-25 02:51:31 +08:00
1808837298@qq.com
a81a28b7a5 feat: Support max_tokens parameter for Ollama channel #782 2025-02-24 17:35:49 +08:00
Calcium-Ion
dc36fdedc2 Merge pull request #781 from zeyugao/main
feat: Pass extra_body in OpenAI request to the backend
2025-02-24 16:29:48 +08:00
Calcium-Ion
3017882fa3 Merge pull request #783 from Calcium-Ion/rate-limit
feat: Add model request rate limiting functionality
2025-02-24 16:29:23 +08:00
1808837298@qq.com
e9ba392af8 feat: Add model rate limit settings in system configuration 2025-02-24 16:27:20 +08:00
1808837298@qq.com
83a37e4653 feat: Add model request rate limiting functionality 2025-02-24 16:20:55 +08:00
1808837298@qq.com
b6f95dca41 feat: Add support for different Dify bot types and request URLs 2025-02-24 14:18:30 +08:00
1808837298@qq.com
7ff4cebdbe feat: Enhance token counting and content parsing for messages 2025-02-24 14:18:15 +08:00
Elsa
af00f7b311 Pass extra_body to the backend 2025-02-24 10:52:55 +08:00
1808837298@qq.com
cc1d6e1c05 fix: Improve 429 error logging with detailed message 2025-02-23 21:26:31 +08:00
1808837298@qq.com
6c7a8c811c fix typo 2025-02-23 17:27:33 +08:00
66 changed files with 1852 additions and 428 deletions

View File

@@ -50,10 +50,6 @@
# CHANNEL_TEST_FREQUENCY=10
# 生成默认token
# GENERATE_DEFAULT_TOKEN=false
# Gemini 安全设置
# GEMINI_SAFETY_SETTING=BLOCK_NONE
# Gemini版本设置
# GEMINI_MODEL_MAP=gemini-1.0-pro:v1
# Cohere 安全设置
# COHERE_SAFETY_SETTING=NONE
# 是否统计图片token

View File

@@ -63,7 +63,8 @@
- Add suffix `-high` to set high reasoning effort (e.g., `o3-mini-high`)
- Add suffix `-medium` to set medium reasoning effort
- Add suffix `-low` to set low reasoning effort
17. 🔄 Thinking to content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default is `false`, when `true`, the `reasoning_conetnt` of the thinking content will be converted to `<think>` tags and concatenated to the content returned.
17. 🔄 Thinking to content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default is `false`, when `true`, the `reasoning_content` of the thinking content will be converted to `<think>` tags and concatenated to the content returned.
18. 🔄 Model rate limit, support setting total request limit and successful request limit in `System Settings->Rate Limit Settings`
## Model Support
This version additionally supports:

View File

@@ -66,10 +66,14 @@
15.**[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API支持Azure渠道
16. 支持使用路由/chat2link 进入聊天界面
17. 🧠 支持通过模型名称后缀设置 reasoning effort
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
18. 🔄 思考转内容,支持在 `渠道-编辑-渠道额外设置` 设置 `thinking_to_content` 选项,默认`false`,开启后会将思考内容`reasoning_conetnt`转换为`<think>`标签拼接到内容中返回。
1. OpenAI o系列模型
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
2. Claude 思考模型
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
18. 🔄 思考转内容,支持在 `渠道-编辑-渠道额外设置` 中设置 `thinking_to_content` 选项,默认`false`,开启后会将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回。
19. 🔄 模型限流,支持在 `系统设置-速率限制设置` 中设置模型限流,支持设置总请求数限制和成功请求数限制
## 模型支持
此版本额外支持以下模型:
@@ -90,7 +94,6 @@
- `GET_MEDIA_TOKEN`是否统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用"模型:版本"指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认为 `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`Gemini模型最大图片数量默认为 `16`,设置为 `-1` 则不限制。
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB默认为 `20`
@@ -99,6 +102,10 @@
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`
## 已废弃的环境变量
- ~~`GEMINI_MODEL_MAP`(已废弃)~~:改为到`设置-模型相关设置`中设置
- ~~`GEMINI_SAFETY_SETTING`(已废弃)~~:改为到`设置-模型相关设置`中设置
## 部署
> [!TIP]

View File

@@ -276,7 +276,7 @@ var ChannelBaseURLs = []string{
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"", //37
"https://api.dify.ai", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40

View File

@@ -83,92 +83,94 @@ var defaultModelRatio = map[string]float64{
"text-curie-001": 1,
//"text-davinci-002": 10,
//"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015
"tts-1-1106": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03
"tts-1-hd-1106": 15, // 1k characters -> $0.03
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5,
"claude-3-5-sonnet-20241022": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"glm-4-plus": 0.05 * RMB,
"glm-4-0520": 0.1 * RMB,
"glm-4-air": 0.001 * RMB,
"glm-4-airx": 0.01 * RMB,
"glm-4-long": 0.001 * RMB,
"glm-4-flash": 0,
"glm-4v-plus": 0.01 * RMB,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
"360gpt-pro": 0.8572, // ¥0.012 / 1k tokens
"360gpt2-pro": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015
"tts-1-1106": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03
"tts-1-hd-1106": 15, // 1k characters -> $0.03
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5,
"claude-3-5-sonnet-20241022": 1.5,
"claude-3-7-sonnet-20250219": 1.5,
"claude-3-7-sonnet-20250219-thinking": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"glm-4-plus": 0.05 * RMB,
"glm-4-0520": 0.1 * RMB,
"glm-4-air": 0.001 * RMB,
"glm-4-airx": 0.01 * RMB,
"glm-4-long": 0.001 * RMB,
"glm-4-flash": 0,
"glm-4v-plus": 0.01 * RMB,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // 0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
"360gpt-pro": 0.8572, // ¥0.012 / 1k tokens
"360gpt2-pro": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// https://platform.lingyiwanwu.com/docs#-计费单元
// 已经按照 7.2 来换算美元价格
"yi-34b-chat-0205": 0.18,

View File

@@ -32,6 +32,7 @@ func InitRedisClient() (err error) {
if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error())
}
opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10)
RDB = redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@@ -41,6 +42,10 @@ func InitRedisClient() (err error) {
if err != nil {
FatalLog("Redis ping test failed: " + err.Error())
}
if DebugEnabled {
SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr))
SysLog(fmt.Sprintf("Redis database: %d", opt.DB))
}
return err
}
@@ -53,13 +58,20 @@ func ParseRedisOption() *redis.Options {
}
func RedisSet(key string, value string, expiration time.Duration) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration))
}
ctx := context.Background()
return RDB.Set(ctx, key, value, expiration).Err()
}
func RedisGet(key string) (string, error) {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis GET: key=%s", key))
}
ctx := context.Background()
return RDB.Get(ctx, key).Result()
val, err := RDB.Get(ctx, key).Result()
return val, err
}
//func RedisExpire(key string, expiration time.Duration) error {
@@ -73,16 +85,25 @@ func RedisGet(key string) (string, error) {
//}
func RedisDel(key string) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis DEL: key=%s", key))
}
ctx := context.Background()
return RDB.Del(ctx, key).Err()
}
func RedisHDelObj(key string) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HDEL: key=%s", key))
}
ctx := context.Background()
return RDB.HDel(ctx, key).Err()
}
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration))
}
ctx := context.Background()
data := make(map[string]interface{})
@@ -130,6 +151,9 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
}
func RedisHGetObj(key string, obj interface{}) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key))
}
ctx := context.Background()
result, err := RDB.HGetAll(ctx, key).Result()
@@ -208,6 +232,9 @@ func RedisHGetObj(key string, obj interface{}) error {
// RedisIncr Add this function to handle atomic increments
func RedisIncr(key string, delta int64) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta))
}
// 检查键的剩余生存时间
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
@@ -238,6 +265,9 @@ func RedisIncr(key string, delta int64) error {
}
func RedisHIncrBy(key, field string, delta int64) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta))
}
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil && !errors.Is(err, redis.Nil) {
@@ -262,6 +292,9 @@ func RedisHIncrBy(key, field string, delta int64) error {
}
func RedisHSetField(key, field string, value interface{}) error {
if DebugEnabled {
SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value))
}
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil && !errors.Is(err, redis.Nil) {

View File

@@ -5,6 +5,7 @@ import (
"context"
crand "crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/pkg/errors"
"html/template"
@@ -213,6 +214,24 @@ func RandomSleep() {
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
}
func GetPointer[T any](v T) *T {
return &v
}
func Any2Type[T any](data any) (T, error) {
var zero T
bytes, err := json.Marshal(data)
if err != nil {
return zero, err
}
var res T
err = json.Unmarshal(bytes, &res)
if err != nil {
return zero, err
}
return res, nil
}
// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
func SaveTmpFile(filename string, data io.Reader) (string, error) {
f, err := os.CreateTemp(os.TempDir(), filename)

View File

@@ -2,4 +2,9 @@ package constant
const (
ContextKeyRequestStartTime = "request_start_time"
ContextKeyUserSetting = "user_setting"
ContextKeyUserQuota = "user_quota"
ContextKeyUserStatus = "user_status"
ContextKeyUserEmail = "user_email"
ContextKeyUserGroup = "user_group"
)

View File

@@ -1,10 +1,7 @@
package constant
import (
"fmt"
"one-api/common"
"os"
"strings"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
@@ -23,9 +20,9 @@ var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
var GeminiModelMap = map[string]string{
"gemini-1.0-pro": "v1",
}
//var GeminiModelMap = map[string]string{
// "gemini-1.0-pro": "v1",
//}
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
@@ -33,18 +30,18 @@ var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
return
}
for _, pair := range strings.Split(modelVersionMapStr, ",") {
parts := strings.Split(pair, ":")
if len(parts) == 2 {
GeminiModelMap[parts[0]] = parts[1]
} else {
common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
}
}
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
//if modelVersionMapStr == "" {
// return
//}
//for _, pair := range strings.Split(modelVersionMapStr, ",") {
// parts := strings.Split(pair, ":")
// if len(parts) == 2 {
// GeminiModelMap[parts[0]] = parts[1]
// } else {
// common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
// }
//}
}
// GenerateDefaultToken 是否生成初始令牌,默认关闭。

View File

@@ -48,7 +48,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if strings.Contains(strings.ToLower(testModel), "embedding") ||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型
testModel == "text-embedding-v1" ||
strings.Contains(testModel, "embed") ||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
}
@@ -84,6 +84,12 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
}
}
cache, err := model.GetUserCache(1)
if err != nil {
return err, nil
}
cache.WriteContext(c)
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)

View File

@@ -159,7 +159,7 @@ func UpdateMidjourneyTaskBulk() {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else {
if shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota)
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}

View File

@@ -85,6 +85,7 @@ func Relay(c *gin.Context) {
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)

View File

@@ -159,7 +159,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
err = model.IncreaseUserQuota(task.UserId, quota, false)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}

View File

@@ -210,7 +210,7 @@ func EpayNotify(c *gin.Context) {
}
//user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit), true)
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
return

View File

@@ -1,6 +1,9 @@
package dto
import "encoding/json"
import (
"encoding/json"
"strings"
)
type ResponseFormat struct {
Type string `json:"type,omitempty"`
@@ -15,49 +18,52 @@ type FormatJsonSchema struct {
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCall `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCallRequest `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
ExtraBody any `json:"extra_body,omitempty"`
}
type OpenAITools struct {
Type string `json:"type"`
Function OpenAIFunction `json:"function"`
type ToolCallRequest struct {
ID string `json:"id,omitempty"`
Type string `json:"type"`
Function FunctionRequest `json:"function"`
}
type OpenAIFunction struct {
type FunctionRequest struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Parameters any `json:"parameters,omitempty"`
Arguments string `json:"arguments,omitempty"`
}
type StreamOptions struct {
@@ -133,11 +139,11 @@ func (m *Message) SetPrefix(prefix bool) {
m.Prefix = &prefix
}
func (m *Message) ParseToolCalls() []ToolCall {
func (m *Message) ParseToolCalls() []ToolCallRequest {
if m.ToolCalls == nil {
return nil
}
var toolCalls []ToolCall
var toolCalls []ToolCallRequest
if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
return toolCalls
}
@@ -153,11 +159,24 @@ func (m *Message) StringContent() string {
if m.parsedStringContent != nil {
return *m.parsedStringContent
}
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
m.parsedStringContent = &stringContent
return stringContent
}
return string(m.Content)
contentStr := new(strings.Builder)
arrayContent := m.ParseContent()
for _, content := range arrayContent {
if content.Type == ContentTypeText {
contentStr.WriteString(content.Text)
}
}
stringContent = contentStr.String()
m.parsedStringContent = &stringContent
return stringContent
}
func (m *Message) SetStringContent(content string) {

View File

@@ -62,10 +62,10 @@ type ChatCompletionsStreamResponseChoice struct {
}
type ChatCompletionsStreamResponseChoiceDelta struct {
Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCallResponse `json:"tool_calls,omitempty"`
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
@@ -90,24 +90,24 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string)
c.ReasoningContent = &s
}
type ToolCall struct {
type ToolCallResponse struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type any `json:"type"`
Function FunctionCall `json:"function"`
Index *int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type any `json:"type"`
Function FunctionResponse `json:"function"`
}
func (c *ToolCall) SetIndex(i int) {
func (c *ToolCallResponse) SetIndex(i int) {
c.Index = &i
}
type FunctionCall struct {
type FunctionResponse struct {
Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"`
// call function with arguments in JSON format
Parameters any `json:"parameters,omitempty"` // request
Arguments string `json:"arguments,omitempty"`
Arguments string `json:"arguments"` // response
}
type ChatCompletionsStreamResponse struct {

View File

@@ -199,15 +199,19 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.IsUserEnabled(token.UserId, false)
userCache, err := model.GetUserCache(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
userEnabled := userCache.Status == common.UserStatusEnabled
if !userEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
userCache.WriteContext(c)
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_key", token.Key)

View File

@@ -32,7 +32,6 @@ func Distribute() func(c *gin.Context) {
return
}
}
userId := c.GetInt("id")
var channel *model.Channel
channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c)
@@ -40,7 +39,7 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup, _ := model.GetUserGroup(userId, false)
userGroup := c.GetString(constant.ContextKeyUserGroup)
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]

View File

@@ -0,0 +1,172 @@
package middleware
import (
"context"
"fmt"
"net/http"
"one-api/common"
"one-api/setting"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
const (
ModelRequestRateLimitCountMark = "MRRL"
ModelRequestRateLimitSuccessCountMark = "MRRLS"
)
// 检查Redis中的请求限制
func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
// 如果maxCount为0表示不限制
if maxCount == 0 {
return true, nil
}
// 获取当前计数
length, err := rdb.LLen(ctx, key).Result()
if err != nil {
return false, err
}
// 如果未达到限制,允许请求
if length < int64(maxCount) {
return true, nil
}
// 检查时间窗口
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
if err != nil {
return false, err
}
nowTimeStr := time.Now().Format(timeFormat)
nowTime, err := time.Parse(timeFormat, nowTimeStr)
if err != nil {
return false, err
}
// 如果在时间窗口内已达到限制,拒绝请求
subTime := nowTime.Sub(oldTime).Seconds()
if int64(subTime) < duration {
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
return false, nil
}
return true, nil
}
// 记录Redis请求
func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
// 如果maxCount为0不记录请求
if maxCount == 0 {
return
}
now := time.Now().Format(timeFormat)
rdb.LPush(ctx, key, now)
rdb.LTrim(ctx, key, 0, int64(maxCount-1))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
}
// Redis限流处理器
func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
ctx := context.Background()
rdb := common.RDB
// 1. 检查总请求数限制当totalMaxCount为0时会自动跳过
totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
if err != nil {
fmt.Println("检查总请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
// 2. 检查成功请求数限制
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
if err != nil {
fmt.Println("检查成功请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
return
}
// 3. 记录总请求当totalMaxCount为0时会自动跳过
recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
// 4. 处理请求
c.Next()
// 5. 如果请求成功,记录成功请求
if c.Writer.Status() < 400 {
recordRedisRequest(ctx, rdb, successKey, successMaxCount)
}
}
}
// 内存限流处理器
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id"))
totalKey := ModelRequestRateLimitCountMark + userId
successKey := ModelRequestRateLimitSuccessCountMark + userId
// 1. 检查总请求数限制当totalMaxCount为0时跳过
if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
// 2. 检查成功请求数限制
// 使用一个临时key来检查限制这样可以避免实际记录
checkKey := successKey + "_check"
if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
// 3. 处理请求
c.Next()
// 4. 如果请求成功,记录到实际的成功请求计数中
if c.Writer.Status() < 400 {
inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
}
}
}
// ModelRequestRateLimit 模型请求限流中间件
func ModelRequestRateLimit() func(c *gin.Context) {
// 如果未启用限流,直接放行
if !setting.ModelRequestRateLimitEnabled {
return defNext
}
// 计算限流参数
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
totalMaxCount := setting.ModelRequestRateLimitCount
successMaxCount := setting.ModelRequestRateLimitSuccessCount
// 根据存储类型选择限流处理器
if common.RedisEnabled {
return redisRateLimitHandler(duration, totalMaxCount, successMaxCount)
} else {
return memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)
}
}

View File

@@ -1,8 +1,8 @@
package model
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
"os"
"strings"
@@ -87,14 +87,14 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int,
func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
}
username, _ := GetUsernameById(userId, false)
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
log := &Log{
UserId: userId,
@@ -116,7 +116,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
}
err := LOG_DB.Create(log).Error
if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error())
common.LogError(c, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
gopool.Go(func() {

View File

@@ -3,6 +3,7 @@ package model
import (
"one-api/common"
"one-api/setting"
"one-api/setting/config"
"strconv"
"strings"
"time"
@@ -23,6 +24,8 @@ func AllOption() ([]*Option, error) {
func InitOptionMap() {
common.OptionMapRWMutex.Lock()
common.OptionMap = make(map[string]string)
// 添加原有的系统配置
common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
@@ -85,6 +88,9 @@ func InitOptionMap() {
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -105,13 +111,19 @@ func InitOptionMap() {
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled)
common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
// 自动添加所有注册的模型配置
modelConfigs := config.GlobalConfig.ExportAllConfigs()
for k, v := range modelConfigs {
common.OptionMap[k] = v
}
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@@ -154,6 +166,13 @@ func updateOptionMap(key string, value string) (err error) {
common.OptionMapRWMutex.Lock()
defer common.OptionMapRWMutex.Unlock()
common.OptionMap[key] = value
// 检查是否是模型配置 - 使用更规范的方式处理
if handleConfigUpdate(key, value) {
return nil // 已由配置系统处理
}
// 处理传统配置项...
if strings.HasSuffix(key, "Permission") {
intValue, _ := strconv.Atoi(value)
switch key {
@@ -226,8 +245,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.DemoSiteEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":
setting.CheckSensitiveOnPromptEnabled = boolValue
//case "CheckSensitiveOnCompletionEnabled":
// constant.CheckSensitiveOnCompletionEnabled = boolValue
case "ModelRequestRateLimitEnabled":
setting.ModelRequestRateLimitEnabled = boolValue
case "StopOnSensitiveEnabled":
setting.StopOnSensitiveEnabled = boolValue
case "SMTPSSLEnabled":
@@ -308,6 +327,12 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "ShouldPreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "ModelRequestRateLimitCount":
setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
case "ModelRequestRateLimitDurationMinutes":
setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
case "ModelRequestRateLimitSuccessCount":
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval":
@@ -343,3 +368,28 @@ func updateOptionMap(key string, value string) (err error) {
}
return err
}
// handleConfigUpdate 处理分层配置更新,返回是否已处理
func handleConfigUpdate(key, value string) bool {
parts := strings.SplitN(key, ".", 2)
if len(parts) != 2 {
return false // 不是分层配置
}
configName := parts[0]
configKey := parts[1]
// 获取配置对象
cfg := config.GlobalConfig.Get(configName)
if cfg == nil {
return false // 未注册的配置
}
// 更新配置
configMap := map[string]string{
configKey: value,
}
config.UpdateConfigFromMap(cfg, configMap)
return true // 已处理
}

View File

@@ -320,7 +320,7 @@ func (user *User) Insert(inviterId int) error {
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
@@ -502,35 +502,35 @@ func IsAdmin(userId int) bool {
return user.Role >= common.RoleAdminUser
}
// IsUserEnabled checks user status from Redis first, falls back to DB if needed
func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserStatusCache(id, status); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
// Try Redis first
status, err := getUserStatusCache(id)
if err == nil {
return status == common.UserStatusEnabled, nil
}
// Don't return error - fall through to DB
}
fromDB = true
var user User
err = DB.Where("id = ?", id).Select("status").Find(&user).Error
if err != nil {
return false, err
}
return user.Status == common.UserStatusEnabled, nil
}
//// IsUserEnabled checks user status from Redis first, falls back to DB if needed
//func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
// defer func() {
// // Update Redis cache asynchronously on successful DB read
// if shouldUpdateRedis(fromDB, err) {
// gopool.Go(func() {
// if err := updateUserStatusCache(id, status); err != nil {
// common.SysError("failed to update user status cache: " + err.Error())
// }
// })
// }
// }()
// if !fromDB && common.RedisEnabled {
// // Try Redis first
// status, err := getUserStatusCache(id)
// if err == nil {
// return status == common.UserStatusEnabled, nil
// }
// // Don't return error - fall through to DB
// }
// fromDB = true
// var user User
// err = DB.Where("id = ?", id).Select("status").Find(&user).Error
// if err != nil {
// return false, err
// }
//
// return user.Status == common.UserStatusEnabled, nil
//}
func ValidateAccessToken(token string) (user *User) {
if token == "" {
@@ -639,7 +639,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
return common.StrToMap(setting), nil
}
func IncreaseUserQuota(id int, quota int) (err error) {
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@@ -649,7 +649,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
common.SysError("failed to increase user quota: " + err.Error())
}
})
if common.BatchUpdateEnabled {
if !db && common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil
}
@@ -694,7 +694,7 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
return nil
}
if delta > 0 {
return IncreaseUserQuota(id, delta)
return IncreaseUserQuota(id, delta, false)
} else {
return DecreaseUserQuota(id, -delta)
}

View File

@@ -3,6 +3,7 @@ package model
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/constant"
"time"
@@ -21,6 +22,15 @@ type UserBase struct {
Setting string `json:"setting"`
}
func (user *UserBase) WriteContext(c *gin.Context) {
c.Set(constant.ContextKeyUserGroup, user.Group)
c.Set(constant.ContextKeyUserQuota, user.Quota)
c.Set(constant.ContextKeyUserStatus, user.Status)
c.Set(constant.ContextKeyUserEmail, user.Email)
c.Set("username", user.Username)
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
}
func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil

View File

@@ -130,7 +130,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo,
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req, info.ToRelayInfo())
resp, err := doRequest(c, req, info.RelayInfo)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}

View File

@@ -8,6 +8,7 @@ import (
"one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
)
const (
@@ -38,6 +39,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
model_setting.GetClaudeSettings().WriteHeaders(req)
return nil
}
@@ -49,8 +51,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
var claudeReq *claude.ClaudeRequest
var err error
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
c.Set("request_model", request.Model)
if err != nil {
return nil, err
}
c.Set("request_model", claudeReq.Model)
c.Set("converted_request", claudeReq)
return claudeReq, err
}
@@ -64,7 +68,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return nil, nil
}

View File

@@ -9,7 +9,8 @@ var awsModelIDMap = map[string]string{
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
}
var ChannelName = "aws"

View File

@@ -16,6 +16,7 @@ type AwsClaudeRequest struct {
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []claude.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *claude.Thinking `json:"thinking,omitempty"`
}
func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
@@ -30,5 +31,6 @@ func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
StopSequences: req.StopSequences,
Tools: req.Tools,
ToolChoice: req.ToolChoice,
Thinking: req.Thinking,
}
}

View File

@@ -9,6 +9,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
"strings"
)
@@ -55,6 +56,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
anthropicVersion = "2023-06-01"
}
req.Set("anthropic-version", anthropicVersion)
model_setting.GetClaudeSettings().WriteHeaders(req)
return nil
}

View File

@@ -11,6 +11,8 @@ var ModelList = []string{
"claude-3-5-haiku-20241022",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"claude-3-7-sonnet-20250219",
"claude-3-7-sonnet-20250219-thinking",
}
var ChannelName = "claude"

View File

@@ -11,6 +11,9 @@ type ClaudeMediaMessage struct {
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
PartialJson string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
Delta string `json:"delta,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
@@ -54,9 +57,15 @@ type ClaudeRequest struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *Thinking `json:"thinking,omitempty"`
}
type Thinking struct {
Type string `json:"type"`
BudgetTokens int `json:"budget_tokens"`
}
type ClaudeError struct {

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting/model_setting"
"strings"
"github.com/gin-gonic/gin"
@@ -92,6 +93,30 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
Stream: textRequest.Stream,
Tools: claudeTools,
}
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
strings.HasSuffix(textRequest.Model, "-thinking") {
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().ThinkingAdapterMaxTokens)
}
// 因为BudgetTokens 必须大于1024
if claudeRequest.MaxTokens < 1280 {
claudeRequest.MaxTokens = 1280
}
// BudgetTokens 为 max_tokens 的 80%
claudeRequest.Thinking = &Thinking{
Type: "enabled",
BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
claudeRequest.TopP = 0
claudeRequest.Temperature = common.GetPointer[float64](1.0)
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
@@ -273,7 +298,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
tools := make([]dto.ToolCall, 0)
tools := make([]dto.ToolCallResponse, 0)
var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion)
@@ -292,10 +317,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.ContentBlock != nil {
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCall{
tools = append(tools, dto.ToolCallResponse{
ID: claudeResponse.ContentBlock.Id,
Type: "function",
Function: dto.FunctionCall{
Function: dto.FunctionResponse{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
@@ -308,12 +333,20 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.Delta != nil {
choice.Index = claudeResponse.Index
choice.Delta.SetContentString(claudeResponse.Delta.Text)
if claudeResponse.Delta.Type == "input_json_delta" {
tools = append(tools, dto.ToolCall{
Function: dto.FunctionCall{
switch claudeResponse.Delta.Type {
case "input_json_delta":
tools = append(tools, dto.ToolCallResponse{
Function: dto.FunctionResponse{
Arguments: claudeResponse.Delta.PartialJson,
},
})
case "signature_delta":
// 加密的不处理
signatureContent := "\n"
choice.Delta.ReasoningContent = &signatureContent
case "thinking_delta":
thinkingContent := claudeResponse.Delta.Thinking
choice.Delta.ReasoningContent = &thinkingContent
}
}
} else if claudeResponse.Type == "message_delta" {
@@ -351,7 +384,9 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
tools := make([]dto.ToolCall, 0)
tools := make([]dto.ToolCallResponse, 0)
thinkingContent := ""
if reqMode == RequestModeCompletion {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := dto.OpenAITextResponseChoice{
@@ -367,16 +402,22 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
} else {
fullTextResponse.Id = claudeResponse.Id
for _, message := range claudeResponse.Content {
if message.Type == "tool_use" {
switch message.Type {
case "tool_use":
args, _ := json.Marshal(message.Input)
tools = append(tools, dto.ToolCall{
tools = append(tools, dto.ToolCallResponse{
ID: message.Id,
Type: "function", // compatible with other OpenAI derivative applications
Function: dto.FunctionCall{
Function: dto.FunctionResponse{
Name: message.Name,
Arguments: string(args),
},
})
case "thinking":
// 加密的不管, 只输出明文的推理过程
thinkingContent = message.Thinking
case "text":
responseText = message.Text
}
}
}
@@ -391,6 +432,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(tools) > 0 {
choice.Message.SetToolCalls(tools)
}
choice.Message.ReasoningContent = thinkingContent
fullTextResponse.Model = claudeResponse.Model
choices = append(choices, choice)
fullTextResponse.Choices = choices

View File

@@ -9,9 +9,18 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"strings"
)
const (
BotTypeChatFlow = 1 // chatflow default
BotTypeAgent = 2
BotTypeWorkFlow = 3
BotTypeCompletion = 4
)
type Adaptor struct {
BotType int
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -25,10 +34,28 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "agent") {
a.BotType = BotTypeAgent
} else if strings.HasPrefix(info.UpstreamModelName, "workflow") {
a.BotType = BotTypeWorkFlow
} else if strings.HasPrefix(info.UpstreamModelName, "chat") {
a.BotType = BotTypeCompletion
} else {
a.BotType = BotTypeChatFlow
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
switch a.BotType {
case BotTypeWorkFlow:
return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil
case BotTypeCompletion:
return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil
case BotTypeAgent:
fallthrough
default:
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -53,7 +80,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -7,11 +7,11 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting/model_setting"
"strings"
@@ -64,15 +64,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1beta"
version, beta := constant.GeminiModelMap[info.UpstreamModelName]
if !beta {
if info.ApiVersion != "" {
version = info.ApiVersion
} else {
version = "v1beta"
}
}
version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil

View File

@@ -20,4 +20,12 @@ var ModelList = []string{
"imagen-3.0-generate-002",
}
var SafetySettingList = []string{
"HARM_CATEGORY_HARASSMENT",
"HARM_CATEGORY_HATE_SPEECH",
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
"HARM_CATEGORY_DANGEROUS_CONTENT",
"HARM_CATEGORY_CIVIC_INTEGRITY",
}
var ChannelName = "google gemini"

View File

@@ -11,6 +11,7 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting/model_setting"
"strings"
"unicode/utf8"
@@ -22,28 +23,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
SafetySettings: []GeminiChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_HATE_SPEECH",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_CIVIC_INTEGRITY",
Threshold: common.GeminiSafetySetting,
},
},
//SafetySettings: []GeminiChatSafetySettings{},
GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
@@ -52,9 +32,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
}
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
for _, category := range SafetySettingList {
safetySettings = append(safetySettings, GeminiChatSafetySettings{
Category: category,
Threshold: model_setting.GetGeminiSafetySetting(category),
})
}
geminiRequest.SafetySettings = safetySettings
// openaiContent.FuncToToolCalls()
if textRequest.Tools != nil {
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
googleSearch := false
codeExecution := false
for _, tool := range textRequest.Tools {
@@ -349,7 +338,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
return data
}
func getToolCall(item *GeminiPart) *dto.ToolCall {
func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
var argsBytes []byte
var err error
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
@@ -361,10 +350,10 @@ func getToolCall(item *GeminiPart) *dto.ToolCall {
if err != nil {
return nil
}
return &dto.ToolCall{
return &dto.ToolCallResponse{
ID: fmt.Sprintf("call_%s", common.GetUUID()),
Type: "function",
Function: dto.FunctionCall{
Function: dto.FunctionResponse{
Arguments: string(argsBytes),
Name: item.FunctionCall.FunctionName,
},
@@ -379,7 +368,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
content, _ := json.Marshal("")
is_tool_call := false
isToolCall := false
for _, candidate := range response.Candidates {
choice := dto.OpenAITextResponseChoice{
Index: int(candidate.Index),
@@ -391,12 +380,12 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
}
if len(candidate.Content.Parts) > 0 {
var texts []string
var tool_calls []dto.ToolCall
var toolCalls []dto.ToolCallResponse
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls
if call := getToolCall(&part); call != nil {
tool_calls = append(tool_calls, *call)
if call := getResponseToolCall(&part); call != nil {
toolCalls = append(toolCalls, *call)
}
} else {
if part.ExecutableCode != nil {
@@ -411,9 +400,9 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
}
}
}
if len(tool_calls) > 0 {
choice.Message.SetToolCalls(tool_calls)
is_tool_call = true
if len(toolCalls) > 0 {
choice.Message.SetToolCalls(toolCalls)
isToolCall = true
}
choice.Message.SetStringContent(strings.Join(texts, "\n"))
@@ -429,7 +418,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
choice.FinishReason = constant.FinishReasonContentFilter
}
}
if is_tool_call {
if isToolCall {
choice.FinishReason = constant.FinishReasonToolCalls
}
@@ -468,7 +457,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
isTools = true
if call := getToolCall(&part); call != nil {
if call := getResponseToolCall(&part); call != nil {
call.SetIndex(len(choice.Delta.ToolCalls))
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
}

View File

@@ -61,7 +61,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp)
err, usage = JinaRerankHandler(c, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings {
err, usage = jinaEmbeddingHandler(c, resp)
}

View File

@@ -9,7 +9,7 @@ import (
"one-api/service"
)
func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func JinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -3,21 +3,22 @@ package ollama
import "one-api/dto"
type OllamaRequest struct {
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"`
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Tools []dto.ToolCallRequest `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"`
}
type Options struct {

View File

@@ -58,6 +58,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err
TopK: request.TopK,
Stop: Stop,
Tools: request.Tools,
MaxTokens: request.MaxTokens,
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,

View File

@@ -14,6 +14,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/ai360"
"one-api/relay/channel/jina"
"one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
@@ -146,7 +147,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
return request, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
@@ -228,6 +229,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
case constant.RelayModeImagesGenerations:
err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeRerank:
err, usage = jina.JinaRerankHandler(c, resp)
default:
if info.IsStream {
err, usage = OaiStreamHandler(c, resp, info)

View File

@@ -0,0 +1,74 @@
package openrouter
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
req.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
req.Set("X-Title", "New API")
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,5 @@
package openrouter
var ModelList = []string{}
var ChannelName = "openrouter"

View File

@@ -28,6 +28,7 @@ var claudeModelMap = map[string]string{
"claude-3-opus-20240229": "claude-3-opus@20240229",
"claude-3-haiku-20240307": "claude-3-haiku@20240307",
"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
}
const anthropicVersion = "vertex-2023-10-16"
@@ -132,7 +133,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil {
return nil, errors.New("failed to copy claude request")
}
c.Set("request_model", request.Model)
c.Set("request_model", claudeReq.Model)
return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini {
geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
@@ -156,7 +157,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -50,6 +50,9 @@ type RelayInfo struct {
AudioUsage bool
ReasoningEffort string
ChannelSetting map[string]interface{}
UserSetting map[string]interface{}
UserEmail string
UserQuota int
}
// 定义支持流式选项的通道类型
@@ -89,6 +92,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &RelayInfo{
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
UserEmail: c.GetString(constant.ContextKeyUserEmail),
IsFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
@@ -148,19 +154,7 @@ func (info *RelayInfo) SetFirstResponseTime() {
}
type TaskRelayInfo struct {
ChannelType int
ChannelId int
TokenId int
UserId int
Group string
StartTime time.Time
ApiType int
RelayMode int
UpstreamModelName string
RequestURLPath string
ApiKey string
BaseUrl string
*RelayInfo
Action string
OriginTaskID string
@@ -168,48 +162,8 @@ type TaskRelayInfo struct {
}
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
startTime := time.Now()
apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &TaskRelayInfo{
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
TokenId: tokenId,
UserId: userId,
Group: group,
StartTime: startTime,
ApiType: apiType,
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
}
if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType]
RelayInfo: GenRelayInfo(c),
}
return info
}
func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
return &RelayInfo{
ChannelType: info.ChannelType,
ChannelId: info.ChannelId,
TokenId: info.TokenId,
UserId: info.UserId,
Group: info.Group,
StartTime: info.StartTime,
ApiType: info.ApiType,
RelayMode: info.RelayMode,
UpstreamModelName: info.UpstreamModelName,
RequestURLPath: info.RequestURLPath,
ApiKey: info.ApiKey,
BaseUrl: info.BaseUrl,
}
}

View File

@@ -30,6 +30,7 @@ const (
APITypeMokaAI
APITypeVolcEngine
APITypeBaiduV2
APITypeOpenRouter
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -86,6 +87,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeVolcEngine
case common.ChannelTypeBaiduV2:
apiType = APITypeBaiduV2
case common.ChannelTypeOpenRouter:
apiType = APITypeOpenRouter
}
if apiType == -1 {
return APITypeOpenAI, false

View File

@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -192,7 +191,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
if err != nil {
return &mjResp.Response
}
defer func(ctx context.Context) {
defer func() {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
@@ -208,14 +207,14 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
}()
midjResponse := &mjResp.Response
midjourneyTask := &model.Midjourney{
UserId: userId,
@@ -498,7 +497,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
midjResponse := &midjResponseWithStatus.Response
defer func(ctx context.Context) {
defer func() {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
@@ -510,14 +509,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
}()
// 文档https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功

View File

@@ -248,6 +248,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
}
relayInfo.UserQuota = userQuota
if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
if !relayInfo.TokenUnlimited {
@@ -267,7 +268,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
}
if preConsumedQuota > 0 {
err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}

View File

@@ -18,6 +18,7 @@ import (
"one-api/relay/channel/mokaai"
"one-api/relay/channel/ollama"
"one-api/relay/channel/openai"
"one-api/relay/channel/openrouter"
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
@@ -83,6 +84,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &volcengine.Adaptor{}
case constant.APITypeBaiduV2:
return &baidu_v2.Adaptor{}
case constant.APITypeOpenRouter:
return &openrouter.Adaptor{}
}
return nil
}

View File

@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -109,11 +108,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
return
}
defer func(ctx context.Context) {
defer func() {
// release quota
if relayInfo.ConsumeQuota && taskErr == nil {
err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
@@ -123,13 +122,13 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
}
}(c.Request.Context())
}()
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
if taskErr != nil {

View File

@@ -24,6 +24,7 @@ func SetRelayRouter(router *gin.Engine) {
}
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth())
relayV1Router.Use(middleware.ModelRequestRateLimit())
{
// WebSocket 路由
wsRouter := relayV1Router.Group("")

View File

@@ -276,7 +276,7 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
if quota > 0 {
err = model.DecreaseUserQuota(relayInfo.UserId, quota)
} else {
err = model.IncreaseUserQuota(relayInfo.UserId, -quota)
err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
}
if err != nil {
return err
@@ -295,20 +295,16 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu
if sendEmail {
if (quota + preConsumedQuota) != 0 {
checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota)
checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota)
}
}
return nil
}
func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) {
gopool.Go(func() {
userCache, err := model.GetUserCache(userId)
if err != nil {
common.SysError("failed to get user cache: " + err.Error())
}
userSetting := userCache.GetSetting()
userSetting := relayInfo.UserSetting
threshold := common.QuotaRemindThreshold
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
threshold = int(userCustomThreshold.(float64))
@@ -317,16 +313,16 @@ func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
quotaTooLow := false
consumeQuota := quota + preConsumedQuota
if userCache.Quota-consumeQuota < threshold {
if relayInfo.UserQuota-consumeQuota < threshold {
quotaTooLow = true
}
if quotaTooLow {
prompt := "您的额度即将用尽"
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink}))
err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
if err != nil {
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error()))
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
}
}
})

View File

@@ -1,7 +1,6 @@
package service
import (
"encoding/json"
"errors"
"fmt"
"image"
@@ -78,6 +77,9 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
if text == "" {
return 0
}
return len(tokenEncoder.Encode(text, nil, nil))
}
@@ -167,12 +169,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
}
tkm += msgTokens
if request.Tools != nil {
toolsData, _ := json.Marshal(request.Tools)
var openaiTools []dto.OpenAITools
err := json.Unmarshal(toolsData, &openaiTools)
if err != nil {
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
}
openaiTools := request.Tools
countStr := ""
for _, tool := range openaiTools {
countStr = tool.Function.Name
@@ -282,30 +279,25 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role)
if len(message.Content) > 0 {
if message.IsStringContent() {
stringContent := message.StringContent()
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
} else {
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
}
}
}

View File

@@ -11,47 +11,45 @@ import (
func NotifyRootUser(t string, subject string, content string) {
user := model.GetRootUser().ToBaseUser()
_ = NotifyUser(user, dto.NewNotify(t, subject, content, nil))
_ = NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
}
func NotifyUser(user *model.UserBase, data dto.Notify) error {
userSetting := user.GetSetting()
func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
notifyType, ok := userSetting[constant.UserSettingNotifyType]
if !ok {
notifyType = constant.NotifyTypeEmail
}
// Check notification limit
canSend, err := CheckNotificationLimit(user.Id, data.Type)
canSend, err := CheckNotificationLimit(userId, data.Type)
if err != nil {
common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
return err
}
if !canSend {
return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType)
return fmt.Errorf("notification limit exceeded for user %d with type %s", userId, notifyType)
}
switch notifyType {
case constant.NotifyTypeEmail:
userEmail := user.Email
// check setting email
if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
userEmail = settingEmail.(string)
}
if userEmail == "" {
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id))
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
return nil
}
return sendEmailNotify(userEmail, data)
case constant.NotifyTypeWebhook:
webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
if !ok {
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id))
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
return nil
}
webhookURLStr, ok := webhookURL.(string)
if !ok {
common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id))
common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId))
return nil
}

259
setting/config/config.go Normal file
View File

@@ -0,0 +1,259 @@
package config
import (
"encoding/json"
"one-api/common"
"reflect"
"strconv"
"strings"
"sync"
)
// ConfigManager 统一管理所有配置
type ConfigManager struct {
configs map[string]interface{}
mutex sync.RWMutex
}
var GlobalConfig = NewConfigManager()
func NewConfigManager() *ConfigManager {
return &ConfigManager{
configs: make(map[string]interface{}),
}
}
// Register 注册一个配置模块
func (cm *ConfigManager) Register(name string, config interface{}) {
cm.mutex.Lock()
defer cm.mutex.Unlock()
cm.configs[name] = config
}
// Get 获取指定配置模块
func (cm *ConfigManager) Get(name string) interface{} {
cm.mutex.RLock()
defer cm.mutex.RUnlock()
return cm.configs[name]
}
// LoadFromDB 从数据库加载配置
func (cm *ConfigManager) LoadFromDB(options map[string]string) error {
cm.mutex.Lock()
defer cm.mutex.Unlock()
for name, config := range cm.configs {
prefix := name + "."
configMap := make(map[string]string)
// 收集属于此配置的所有选项
for key, value := range options {
if strings.HasPrefix(key, prefix) {
configKey := strings.TrimPrefix(key, prefix)
configMap[configKey] = value
}
}
// 如果找到配置项,则更新配置
if len(configMap) > 0 {
if err := updateConfigFromMap(config, configMap); err != nil {
common.SysError("failed to update config " + name + ": " + err.Error())
continue
}
}
}
return nil
}
// SaveToDB 将配置保存到数据库
func (cm *ConfigManager) SaveToDB(updateFunc func(key, value string) error) error {
cm.mutex.RLock()
defer cm.mutex.RUnlock()
for name, config := range cm.configs {
configMap, err := configToMap(config)
if err != nil {
return err
}
for key, value := range configMap {
dbKey := name + "." + key
if err := updateFunc(dbKey, value); err != nil {
return err
}
}
}
return nil
}
// 辅助函数将配置对象转换为map
func configToMap(config interface{}) (map[string]string, error) {
result := make(map[string]string)
val := reflect.ValueOf(config)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return nil, nil
}
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
fieldType := typ.Field(i)
// 跳过未导出字段
if !fieldType.IsExported() {
continue
}
// 获取json标签作为键名
key := fieldType.Tag.Get("json")
if key == "" || key == "-" {
key = fieldType.Name
}
// 处理不同类型的字段
var strValue string
switch field.Kind() {
case reflect.String:
strValue = field.String()
case reflect.Bool:
strValue = strconv.FormatBool(field.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
strValue = strconv.FormatInt(field.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
strValue = strconv.FormatUint(field.Uint(), 10)
case reflect.Float32, reflect.Float64:
strValue = strconv.FormatFloat(field.Float(), 'f', -1, 64)
case reflect.Map, reflect.Slice, reflect.Struct:
// 复杂类型使用JSON序列化
bytes, err := json.Marshal(field.Interface())
if err != nil {
return nil, err
}
strValue = string(bytes)
default:
// 跳过不支持的类型
continue
}
result[key] = strValue
}
return result, nil
}
// 辅助函数从map更新配置对象
func updateConfigFromMap(config interface{}, configMap map[string]string) error {
val := reflect.ValueOf(config)
if val.Kind() != reflect.Ptr {
return nil
}
val = val.Elem()
if val.Kind() != reflect.Struct {
return nil
}
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
fieldType := typ.Field(i)
// 跳过未导出字段
if !fieldType.IsExported() {
continue
}
// 获取json标签作为键名
key := fieldType.Tag.Get("json")
if key == "" || key == "-" {
key = fieldType.Name
}
// 检查map中是否有对应的值
strValue, ok := configMap[key]
if !ok {
continue
}
// 根据字段类型设置值
if !field.CanSet() {
continue
}
switch field.Kind() {
case reflect.String:
field.SetString(strValue)
case reflect.Bool:
boolValue, err := strconv.ParseBool(strValue)
if err != nil {
continue
}
field.SetBool(boolValue)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, err := strconv.ParseInt(strValue, 10, 64)
if err != nil {
continue
}
field.SetInt(intValue)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
uintValue, err := strconv.ParseUint(strValue, 10, 64)
if err != nil {
continue
}
field.SetUint(uintValue)
case reflect.Float32, reflect.Float64:
floatValue, err := strconv.ParseFloat(strValue, 64)
if err != nil {
continue
}
field.SetFloat(floatValue)
case reflect.Map, reflect.Slice, reflect.Struct:
// 复杂类型使用JSON反序列化
err := json.Unmarshal([]byte(strValue), field.Addr().Interface())
if err != nil {
continue
}
}
}
return nil
}
// ConfigToMap 将配置对象转换为map导出函数
func ConfigToMap(config interface{}) (map[string]string, error) {
return configToMap(config)
}
// UpdateConfigFromMap 从map更新配置对象导出函数
func UpdateConfigFromMap(config interface{}, configMap map[string]string) error {
return updateConfigFromMap(config, configMap)
}
// ExportAllConfigs 导出所有已注册的配置为扁平结构
func (cm *ConfigManager) ExportAllConfigs() map[string]string {
cm.mutex.RLock()
defer cm.mutex.RUnlock()
result := make(map[string]string)
for name, cfg := range cm.configs {
configMap, err := ConfigToMap(cfg)
if err != nil {
continue
}
// 使用 "模块名.配置项" 的格式添加到结果中
for key, value := range configMap {
result[name+"."+key] = value
}
}
return result
}

View File

@@ -0,0 +1,49 @@
package model_setting
import (
"net/http"
"one-api/setting/config"
)
//var claudeHeadersSettings = map[string][]string{}
//
//var ClaudeThinkingAdapterEnabled = true
//var ClaudeThinkingAdapterMaxTokens = 8192
//var ClaudeThinkingAdapterBudgetTokensPercentage = 0.8
// ClaudeSettings 定义Claude模型的配置
type ClaudeSettings struct {
HeadersSettings map[string][]string `json:"headers_settings"`
ThinkingAdapterEnabled bool `json:"thinking_adapter_enabled"`
ThinkingAdapterMaxTokens int `json:"thinking_adapter_max_tokens"`
ThinkingAdapterBudgetTokensPercentage float64 `json:"thinking_adapter_budget_tokens_percentage"`
}
// 默认配置
var defaultClaudeSettings = ClaudeSettings{
HeadersSettings: map[string][]string{},
ThinkingAdapterEnabled: true,
ThinkingAdapterMaxTokens: 8192,
ThinkingAdapterBudgetTokensPercentage: 0.8,
}
// 全局实例
var claudeSettings = defaultClaudeSettings
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("claude", &claudeSettings)
}
// GetClaudeSettings 获取Claude配置
func GetClaudeSettings() *ClaudeSettings {
return &claudeSettings
}
func (c *ClaudeSettings) WriteHeaders(headers *http.Header) {
for key, values := range c.HeadersSettings {
for _, value := range values {
headers.Add(key, value)
}
}
}

View File

@@ -0,0 +1,52 @@
package model_setting
import (
"one-api/setting/config"
)
// GeminiSettings 定义Gemini模型的配置
type GeminiSettings struct {
SafetySettings map[string]string `json:"safety_settings"`
VersionSettings map[string]string `json:"version_settings"`
}
// 默认配置
var defaultGeminiSettings = GeminiSettings{
SafetySettings: map[string]string{
"default": "OFF",
"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
},
VersionSettings: map[string]string{
"default": "v1beta",
"gemini-1.0-pro": "v1",
},
}
// 全局实例
var geminiSettings = defaultGeminiSettings
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("gemini", &geminiSettings)
}
// GetGeminiSettings 获取Gemini配置
func GetGeminiSettings() *GeminiSettings {
return &geminiSettings
}
// GetGeminiSafetySetting 获取安全设置
func GetGeminiSafetySetting(key string) string {
if value, ok := geminiSettings.SafetySettings[key]; ok {
return value
}
return geminiSettings.SafetySettings["default"]
}
// GetGeminiVersionSetting 获取版本设置
func GetGeminiVersionSetting(key string) string {
if value, ok := geminiSettings.VersionSettings[key]; ok {
return value
}
return geminiSettings.VersionSettings["default"]
}

6
setting/rate_limit.go Normal file
View File

@@ -0,0 +1,6 @@
package setting
var ModelRequestRateLimitEnabled = false
var ModelRequestRateLimitDurationMinutes = 1
var ModelRequestRateLimitCount = 0
var ModelRequestRateLimitSuccessCount = 1000

View File

@@ -0,0 +1,83 @@
import React, { useEffect, useState } from 'react';
import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
import { API, showError, showSuccess } from '../helpers';
import SettingsChats from '../pages/Setting/Operation/SettingsChats.js';
import { useTranslation } from 'react-i18next';
import SettingGeminiModel from '../pages/Setting/Model/SettingGeminiModel.js';
import SettingClaudeModel from '../pages/Setting/Model/SettingClaudeModel.js';
const ModelSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
'gemini.safety_settings': '',
'gemini.version_settings': '',
'claude.headers_settings': '',
'claude.thinking_adapter_enabled': true,
'claude.thinking_adapter_max_tokens': 8192,
'claude.thinking_adapter_budget_tokens_percentage': 0.8,
});
let [loading, setLoading] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (
item.key === 'gemini.safety_settings' ||
item.key === 'gemini.version_settings' ||
item.key === 'claude.headers_settings'
) {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
if (
item.key.endsWith('Enabled')
) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
newInputs[item.key] = item.value;
}
});
setInputs(newInputs);
} else {
showError(message);
}
};
async function onRefresh() {
try {
setLoading(true);
await getOptions();
// showSuccess('刷新成功');
} catch (error) {
showError('刷新失败');
} finally {
setLoading(false);
}
}
useEffect(() => {
onRefresh();
}, []);
return (
<>
<Spin spinning={loading} size='large'>
{/* Gemini */}
<Card style={{ marginTop: '10px' }}>
<SettingGeminiModel options={inputs} refresh={onRefresh} />
</Card>
{/* Claude */}
<Card style={{ marginTop: '10px' }}>
<SettingClaudeModel options={inputs} refresh={onRefresh} />
</Card>
</Spin>
</>
);
};
export default ModelSetting;

View File

@@ -0,0 +1,80 @@
import React, { useEffect, useState } from 'react';
import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
import SettingsGeneral from '../pages/Setting/Operation/SettingsGeneral.js';
import SettingsDrawing from '../pages/Setting/Operation/SettingsDrawing.js';
import SettingsSensitiveWords from '../pages/Setting/Operation/SettingsSensitiveWords.js';
import SettingsLog from '../pages/Setting/Operation/SettingsLog.js';
import SettingsDataDashboard from '../pages/Setting/Operation/SettingsDataDashboard.js';
import SettingsMonitoring from '../pages/Setting/Operation/SettingsMonitoring.js';
import SettingsCreditLimit from '../pages/Setting/Operation/SettingsCreditLimit.js';
import SettingsMagnification from '../pages/Setting/Operation/SettingsMagnification.js';
import ModelSettingsVisualEditor from '../pages/Setting/Operation/ModelSettingsVisualEditor.js';
import GroupRatioSettings from '../pages/Setting/Operation/GroupRatioSettings.js';
import ModelRatioSettings from '../pages/Setting/Operation/ModelRatioSettings.js';
import { API, showError, showSuccess } from '../helpers';
import SettingsChats from '../pages/Setting/Operation/SettingsChats.js';
import { useTranslation } from 'react-i18next';
import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimit.js';
const RateLimitSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
ModelRequestRateLimitEnabled: false,
ModelRequestRateLimitCount: 0,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1,
});
let [loading, setLoading] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (
item.key.endsWith('Enabled')
) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
newInputs[item.key] = item.value;
}
});
setInputs(newInputs);
} else {
showError(message);
}
};
async function onRefresh() {
try {
setLoading(true);
await getOptions();
// showSuccess('刷新成功');
} catch (error) {
showError('刷新失败');
} finally {
setLoading(false);
}
}
useEffect(() => {
onRefresh();
}, []);
return (
<>
<Spin spinning={loading} size='large'>
{/* AI请求速率限制 */}
<Card style={{ marginTop: '10px' }}>
<RequestRateLimit options={inputs} refresh={onRefresh} />
</Card>
</Spin>
</>
);
};
export default RateLimitSetting;

View File

@@ -376,7 +376,7 @@ const UsersTable = () => {
if (searchKeyword === '') {
await loadUsers(activePage, pageSize);
} else {
await searchUsers(searchKeyword, searchGroup);
await searchUsers(activePage, pageSize, searchKeyword, searchGroup);
}
};

View File

@@ -856,7 +856,7 @@
"IP黑名单": "IP blacklist",
"不允许的IP一行一个": "IPs not allowed, one per line",
"请选择该渠道所支持的模型": "Please select the model supported by this channel",
"次": "Second-rate",
"次": "times",
"达到限速报错内容": "Error content when the speed limit is reached",
"不填则使用默认报错": "If not filled in, the default error will be reported.",
"Midjouney 设置 (可选)": "Midjouney settings (optional)",
@@ -1074,10 +1074,9 @@
"删除所选通道": "Delete selected channels",
"标签聚合模式": "Enable tag mode",
"没有账户?": "No account? ",
"注意,模型部署名称必须和模型名称保持一致,因为 One API 会把请求体中的 model 参数替换为你的部署名称(模型名称中的点会被剔除)": "Note: The model deployment name must match the model name because One API will replace the model parameter in the request body with your deployment name (dots in the model name will be removed)",
"请输入 AZURE_OPENAI_ENDPOINT例如https://docs-test-001.openai.azure.com": "Please enter AZURE_OPENAI_ENDPOINT, e.g.: https://docs-test-001.openai.azure.com",
"默认 API 版本": "Default API Version",
"请输入默认 API 版本例如2023-06-01-preview,该配置可以被实际的请求查询参数所覆盖": "Please enter default API version, e.g.: 2023-06-01-preview. This configuration can be overridden by actual request query parameters",
"请输入默认 API 版本例如2024-12-01-preview": "Please enter default API version, e.g.: 2024-12-01-preview.",
"请为渠道命名": "Please name the channel",
"请选择可以使用该渠道的分组": "Please select groups that can use this channel",
"请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit Group ratios in system settings to add new groups:",
@@ -1271,5 +1270,15 @@
"留空则使用账号绑定的邮箱": "If left blank, the email address bound to the account will be used",
"代理站地址": "Base URL",
"对于官方渠道new-api已经内置地址除非是第三方代理站点或者Azure的特殊接入地址否则不需要填写": "For official channels, the new-api has a built-in address. Unless it is a third-party proxy site or a special Azure access address, there is no need to fill it in",
"渠道额外设置": "Channel extra settings"
}
"渠道额外设置": "Channel extra settings",
"模型请求速率限制": "Model request rate limit",
"启用用户模型请求速率限制(可能会影响高并发性能)": "Enable user model request rate limit (may affect high concurrency performance)",
"限制周期": "Limit period",
"用户每周期最多请求次数": "User max request times per period",
"用户每周期最多请求完成次数": "User max successful request times per period",
"包括失败请求的次数0代表不限制": "Including failed request times, 0 means no limit",
"频率限制的周期(分钟)": "Rate limit period (minutes)",
"只包括请求成功的次数": "Only include successful request times",
"保存模型速率限制": "Save model rate limit settings",
"速率限制设置": "Rate limit settings"
}

View File

@@ -327,9 +327,6 @@ const EditChannel = (props) => {
localInputs.base_url.length - 1
);
}
if (localInputs.type === 3 && localInputs.other === '') {
localInputs.other = '2023-06-01-preview';
}
if (localInputs.type === 18 && localInputs.other === '') {
localInputs.other = 'v2.1';
}
@@ -494,7 +491,7 @@ const EditChannel = (props) => {
<Input
label={t('默认 API 版本')}
name="azure_other"
placeholder={t('请输入默认 API 版本例如2023-06-01-preview,该配置可以被实际的请求查询参数所覆盖')}
placeholder={t('请输入默认 API 版本例如2024-12-01-preview')}
onChange={(value) => {
handleInputChange('other', value);
}}

View File

@@ -0,0 +1,152 @@
import React, { useEffect, useState, useRef } from 'react';
import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
import {
compareObjects,
API,
showError,
showSuccess,
showWarning, verifyJSON
} from '../../../helpers';
import { useTranslation } from 'react-i18next';
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
const CLAUDE_HEADER = {
'anthropic-beta': ['output-128k-2025-02-19', 'token-efficient-tools-2025-02-19'],
};
export default function SettingClaudeModel(props) {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({
'claude.headers_settings': '',
'claude.thinking_adapter_enabled': true,
'claude.thinking_adapter_max_tokens': 8192,
'claude.thinking_adapter_budget_tokens_percentage': 0.8,
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
function onSubmit() {
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
return API.put('/api/option/', {
key: item.key,
value,
});
});
setLoading(true);
Promise.all(requestQueue)
.then((res) => {
if (requestQueue.length === 1) {
if (res.includes(undefined)) return;
} else if (requestQueue.length > 1) {
if (res.includes(undefined)) return showError(t('部分保存失败,请重试'));
}
showSuccess(t('保存成功'));
props.refresh();
})
.catch(() => {
showError(t('保存失败,请重试'));
})
.finally(() => {
setLoading(false);
});
}
useEffect(() => {
const currentInputs = {};
for (let key in props.options) {
if (Object.keys(inputs).includes(key)) {
currentInputs[key] = props.options[key];
}
}
setInputs(currentInputs);
setInputsRow(structuredClone(currentInputs));
refForm.current.setValues(currentInputs);
}, [props.options]);
return (
<>
<Spin spinning={loading}>
<Form
values={inputs}
getFormApi={(formAPI) => (refForm.current = formAPI)}
style={{ marginBottom: 15 }}
>
<Form.Section text={t('Claude设置')}>
<Row>
<Col span={16}>
<Form.TextArea
label={t('Claude请求头覆盖')}
field={'claude.headers_settings'}
placeholder={t('为一个 JSON 文本,例如:') + '\n' + JSON.stringify(CLAUDE_HEADER, null, 2)}
extraText={t('示例') + JSON.stringify(CLAUDE_HEADER, null, 2)}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
stopValidateWithError
rules={[
{
validator: (rule, value) => verifyJSON(value),
message: t('不是合法的 JSON 字符串')
}
]}
onChange={(value) => setInputs({ ...inputs, 'claude.headers_settings': value })}
/>
</Col>
</Row>
<Row>
<Col span={16}>
<Form.Switch
label={t('启用Claude思考适配-thinking后缀')}
field={'claude.thinking_adapter_enabled'}
onChange={(value) => setInputs({ ...inputs, 'claude.thinking_adapter_enabled': value })}
/>
</Col>
</Row>
<Row>
<Col span={16}>
{/*//展示MaxTokens和BudgetTokens的计算公式, 并展示实际数字*/}
<Text>
{t('Claude思考适配 BudgetTokens = MaxTokens * BudgetTokens 百分比')}
</Text>
</Col>
</Row>
<Row>
<Col span={8}>
<Form.InputNumber
label={t('思考适配 MaxTokens')}
field={'claude.thinking_adapter_max_tokens'}
initValue={''}
onChange={(value) => setInputs({ ...inputs, 'claude.thinking_adapter_max_tokens': value })}
/>
</Col>
<Col span={8}>
<Form.InputNumber
label={t('思考适配 BudgetTokens 百分比')}
field={'claude.thinking_adapter_budget_tokens_percentage'}
initValue={''}
suffix={t('%')}
onChange={(value) => setInputs({ ...inputs, 'claude.thinking_adapter_budget_tokens_percentage': value })}
/>
</Col>
</Row>
<Row>
<Button size='default' onClick={onSubmit}>
{t('保存')}
</Button>
</Row>
</Form.Section>
</Form>
</Spin>
</>
);
}

View File

@@ -0,0 +1,139 @@
import React, { useEffect, useState, useRef } from 'react';
import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
import {
compareObjects,
API,
showError,
showSuccess,
showWarning, verifyJSON
} from '../../../helpers';
import { useTranslation } from 'react-i18next';
const GEMINI_SETTING_EXAMPLE = {
'default': 'OFF',
'HARM_CATEGORY_CIVIC_INTEGRITY': 'BLOCK_NONE',
};
const GEMINI_VERSION_EXAMPLE = {
'default': 'v1beta',
};
export default function SettingGeminiModel(props) {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({
'gemini.safety_settings': '',
'gemini.version_settings': '',
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
function onSubmit() {
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
return API.put('/api/option/', {
key: item.key,
value,
});
});
setLoading(true);
Promise.all(requestQueue)
.then((res) => {
if (requestQueue.length === 1) {
if (res.includes(undefined)) return;
} else if (requestQueue.length > 1) {
if (res.includes(undefined)) return showError(t('部分保存失败,请重试'));
}
showSuccess(t('保存成功'));
props.refresh();
})
.catch(() => {
showError(t('保存失败,请重试'));
})
.finally(() => {
setLoading(false);
});
}
useEffect(() => {
const currentInputs = {};
for (let key in props.options) {
if (Object.keys(inputs).includes(key)) {
currentInputs[key] = props.options[key];
}
}
setInputs(currentInputs);
setInputsRow(structuredClone(currentInputs));
refForm.current.setValues(currentInputs);
}, [props.options]);
return (
<>
<Spin spinning={loading}>
<Form
values={inputs}
getFormApi={(formAPI) => (refForm.current = formAPI)}
style={{ marginBottom: 15 }}
>
<Form.Section text={t('Gemini设置')}>
<Row>
<Col span={16}>
<Form.TextArea
label={t('Gemini安全设置')}
placeholder={t('为一个 JSON 文本,例如:') + '\n' + JSON.stringify(GEMINI_SETTING_EXAMPLE, null, 2)}
field={'gemini.safety_settings'}
extraText={t('default为默认设置可单独设置每个分类的安全等级')}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
stopValidateWithError
rules={[
{
validator: (rule, value) => verifyJSON(value),
message: t('不是合法的 JSON 字符串')
}
]}
onChange={(value) => setInputs({ ...inputs, 'gemini.safety_settings': value })}
/>
</Col>
</Row>
<Row>
<Col span={16}>
<Form.TextArea
label={t('Gemini版本设置')}
placeholder={t('为一个 JSON 文本,例如:') + '\n' + JSON.stringify(GEMINI_VERSION_EXAMPLE, null, 2)}
field={'gemini.version_settings'}
extraText={t('default为默认设置可单独设置每个模型的版本')}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
stopValidateWithError
rules={[
{
validator: (rule, value) => verifyJSON(value),
message: t('不是合法的 JSON 字符串')
}
]}
onChange={(value) => setInputs({ ...inputs, 'gemini.version_settings': value })}
/>
</Col>
</Row>
<Row>
<Button size='default' onClick={onSubmit}>
{t('保存')}
</Button>
</Row>
</Form.Section>
</Form>
</Spin>
</>
);
}

View File

@@ -0,0 +1,159 @@
import React, { useEffect, useState, useRef } from 'react';
import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
import {
compareObjects,
API,
showError,
showSuccess,
showWarning,
} from '../../../helpers';
import { useTranslation } from 'react-i18next';
export default function RequestRateLimit(props) {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({
ModelRequestRateLimitEnabled: false,
ModelRequestRateLimitCount: -1,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
function onSubmit() {
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else {
value = inputs[item.key];
}
return API.put('/api/option/', {
key: item.key,
value,
});
});
setLoading(true);
Promise.all(requestQueue)
.then((res) => {
if (requestQueue.length === 1) {
if (res.includes(undefined)) return;
} else if (requestQueue.length > 1) {
if (res.includes(undefined)) return showError(t('部分保存失败,请重试'));
}
showSuccess(t('保存成功'));
props.refresh();
})
.catch(() => {
showError(t('保存失败,请重试'));
})
.finally(() => {
setLoading(false);
});
}
useEffect(() => {
const currentInputs = {};
for (let key in props.options) {
if (Object.keys(inputs).includes(key)) {
currentInputs[key] = props.options[key];
}
}
setInputs(currentInputs);
setInputsRow(structuredClone(currentInputs));
refForm.current.setValues(currentInputs);
}, [props.options]);
return (
<>
<Spin spinning={loading}>
<Form
values={inputs}
getFormApi={(formAPI) => (refForm.current = formAPI)}
style={{ marginBottom: 15 }}
>
<Form.Section text={t('模型请求速率限制')}>
<Row gutter={16}>
<Col span={8}>
<Form.Switch
field={'ModelRequestRateLimitEnabled'}
label={t('启用用户模型请求速率限制(可能会影响高并发性能)')}
size='default'
checkedText=''
uncheckedText=''
onChange={(value) => {
setInputs({
...inputs,
ModelRequestRateLimitEnabled: value,
});
}}
/>
</Col>
</Row>
<Row>
<Col span={8}>
<Form.InputNumber
label={t('限制周期')}
step={1}
min={0}
suffix={t('分钟')}
extraText={t('频率限制的周期(分钟)')}
field={'ModelRequestRateLimitDurationMinutes'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitDurationMinutes: String(value),
})
}
/>
</Col>
</Row>
<Row>
<Col span={8}>
<Form.InputNumber
label={t('用户每周期最多请求次数')}
step={1}
min={0}
suffix={t('次')}
extraText={t('包括失败请求的次数0代表不限制')}
field={'ModelRequestRateLimitCount'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitCount: String(value),
})
}
/>
</Col>
<Col span={8}>
<Form.InputNumber
label={t('用户每周期最多请求完成次数')}
step={1}
min={1}
suffix={t('次')}
extraText={t('只包括请求成功的次数')}
field={'ModelRequestRateLimitSuccessCount'}
onChange={(value) =>
setInputs({
...inputs,
ModelRequestRateLimitSuccessCount: String(value),
})
}
/>
</Col>
</Row>
<Row>
<Button size='default' onClick={onSubmit}>
{t('保存模型速率限制')}
</Button>
</Row>
</Form.Section>
</Form>
</Spin>
</>
);
}

View File

@@ -8,6 +8,8 @@ import { isRoot } from '../../helpers';
import OtherSetting from '../../components/OtherSetting';
import PersonalSetting from '../../components/PersonalSetting';
import OperationSetting from '../../components/OperationSetting';
import RateLimitSetting from '../../components/RateLimitSetting.js';
import ModelSetting from '../../components/ModelSetting.js';
const Setting = () => {
const { t } = useTranslation();
@@ -28,6 +30,16 @@ const Setting = () => {
content: <OperationSetting />,
itemKey: 'operation',
});
panes.push({
tab: t('速率限制设置'),
content: <RateLimitSetting />,
itemKey: 'ratelimit',
});
panes.push({
tab: t('模型相关设置'),
content: <ModelSetting />,
itemKey: 'models',
});
panes.push({
tab: t('系统设置'),
content: <SystemSetting />,