Compare commits

...

90 Commits

Author SHA1 Message Date
Calcium-Ion
ea433b2ed6 Merge pull request #586 from Calcium-Ion/channel-tag
fix: tag channel copy
2024-11-30 19:52:58 +08:00
CalciumIon
bb0c504709 fix: tag channel copy 2024-11-30 19:52:36 +08:00
Calcium-Ion
48abfd055c Merge pull request #574 from Calcium-Ion/channel-tag
feat: 初步集成渠道标签分组功能
2024-11-30 17:45:16 +08:00
CalciumIon
6693072c49 feat: 完善标签编辑(优先级,权重) 2024-11-30 17:43:03 +08:00
CalciumIon
3053d94170 feat: 完善标签编辑 2024-11-30 16:57:58 +08:00
CalciumIon
1774be8536 fix: xAI missing finish_reason #572 2024-11-30 16:57:57 +08:00
Calcium-Ion
821f3a7522 Merge pull request #582 from prnake/patch-1
feat: add claude-3-5-haiku-20241022
2024-11-30 15:07:31 +08:00
CalciumIon
9c4d30602c feat: 完善标签编辑 2024-11-29 23:58:31 +08:00
CalciumIon
7b3394d863 chore: update default STREAMING_TIMEOUT 2024-11-28 23:59:10 +08:00
papersnake
999ba11363 feat: add claude-3-5-haiku-20241022 2024-11-27 13:33:37 +08:00
CalciumIon
7ebc1cfb60 fix: stt模型计费 2024-11-25 19:31:59 +08:00
Calcium-Ion
b71e33b095 Update README.md 2024-11-24 16:51:46 +08:00
CalciumIon
15842163be feat: support audio response_format #580 2024-11-24 16:44:27 +08:00
CalciumIon
e57788375e Update README.md 2024-11-23 16:58:57 +08:00
CalciumIon
78cac7085c Update README.md 2024-11-23 16:38:59 +08:00
Calcium-Ion
76f7474640 Merge pull request #579 from HynoR/main
Chore: support gpt-4o-2024-11-20
2024-11-23 16:34:41 +08:00
CalciumIon
0dd1953cd6 Update BT.md 2024-11-23 16:33:24 +08:00
Calcium-Ion
019361a762 Update BT.md 2024-11-23 16:28:06 +08:00
CalciumIon
9b9d73e725 Update README.md 2024-11-23 16:27:39 +08:00
CalciumIon
9e08709756 chore: Update docker-compose.yml 2024-11-23 16:26:30 +08:00
CalciumIon
05b5d6f255 Update README.md 2024-11-23 16:26:15 +08:00
HynoR
79b6c0a73e Chore: support gpt-4o-2024-11-20 2024-11-21 09:14:23 +08:00
Calcium-Ion
462c2cc1a1 Merge pull request #493 from xixingya/feature/bug-fix
ratio must gte 0
2024-11-19 18:34:24 +08:00
CalciumIon
e8286e479b chore: go.mod 2024-11-19 15:14:40 +08:00
CalciumIon
ed2ec69545 feat: 暂时禁用透传功能 2024-11-19 15:06:54 +08:00
CalciumIon
a167dd9a23 feat: 暂时禁用透传功能 2024-11-19 15:04:19 +08:00
CalciumIon
6e6e390f6f feat: 一键编辑标签下渠道重定向 2024-11-19 01:43:05 +08:00
CalciumIon
807385d3d1 fix: search channel #442 2024-11-19 01:39:27 +08:00
CalciumIon
0ce600ed49 feat: 渠道标签分组 2024-11-19 01:13:18 +08:00
CalciumIon
334a2424e9 fix: oauth aff 2024-11-18 18:53:55 +08:00
CalciumIon
7db703374c fix: oauth aff 2024-11-18 18:52:14 +08:00
Calcium-Ion
6a42ccf00e Merge pull request #569 from utopeadia/main
Modify the default gemini API to v1beta
2024-11-17 16:27:38 +08:00
Calcium-Ion
7aa7114bb9 Merge branch 'main' into main 2024-11-17 16:27:29 +08:00
Calcium-Ion
c3e6b2408e Merge pull request #570 from leezhuuuuu/main
增加对于gemini-exp-1114模型的支持,映射到v1beta
2024-11-17 16:26:36 +08:00
leezhuuuuu
4601932902 feat: add support for gemini-exp-1114 model / 添加 gemini-exp-1114 模型支持
# feat: add support for gemini-exp-1114 model / 添加 gemini-exp-1114 模型支持

## Changes / 更改内容
- Add gemini-exp-1114 to ModelList in constant.go
- Add gemini-exp-1114 to GeminiModelMap with v1beta API version
- 在 constant.go 的 ModelList 中添加 gemini-exp-1114 模型
- 在 GeminiModelMap 中添加 gemini-exp-1114 的 v1beta API 版本映射

## Testing / 测试情况
- [x] Tested gemini-exp-1114 model API calls / 已测试 gemini-exp-1114 模型的 API 调用
- [x] Verified existing models still work / 验证现有模型仍然正常工作
- [x] Confirmed v1beta API version works correctly / 确认 v1beta API 版本正常工作

## Related Issues / 相关问题
- Fix 404 error when calling gemini-exp-1114 model / 修复调用 gemini-exp-1114 模型时的 404 错误

## Implementation Details / 实现细节
- Use configuration-based approach instead of code modification / 使用基于配置的方式而不是修改代码
- Maintain clean separation of concerns / 保持关注点分离
- Keep backward compatibility / 保持向后兼容性

## Notes / 注意事项
- This PR follows the principle of minimal invasion / 本 PR 遵循最小侵入原则
- Configuration changes only / 仅包含配置更改
2024-11-16 21:52:37 +08:00
leezhuuuuu
5d96f7b2cc 增加对于gemini-exp-1114模型的支持,映射到v1beta
feat(gemini): add support for gemini-exp-1114 model

- Add gemini-exp-1114 to ModelList in constant.go
- Update GetRequestURL in adaptor.go to use v1beta API version for gemini-exp-1114
- Keep backward compatibility for other models

This change enables the use of the experimental gemini-exp-1114 model by correctly routing its requests to the v1beta API endpoint while maintaining existing functionality for other models.
2024-11-16 21:29:35 +08:00
HowieWood
8eb32e9b3f Modify the default gemini API to v1beta 2024-11-16 12:21:50 +00:00
CalciumIon
320e6ec5a4 fix: aws claude 2024-11-14 15:12:34 +08:00
Calcium-Ion
8baeece386 Merge pull request #564 from Licoy/main
优化页面组件大小规格一致
2024-11-12 22:39:34 +08:00
licoy
08023f6d96 feat: 增加GLOBAL_API_RATE_LIMIT_ENABLEGLOBAL_WEB_RATE_LIMIT_ENABLE环境变量,支持是否开启访问速率控制 2024-11-12 20:02:33 +08:00
licoy
fad29a8cc2 feat: 增加GLOBAL_API_RATE_LIMIT_DURATIONGLOBAL_WEB_RATE_LIMIT_DURATION环境变量,支持控制访问速率时间设置 2024-11-12 20:01:43 +08:00
licoy
67d09d68c6 feat: 优化数据管理操作栏均为顶部 2024-11-12 17:00:06 +08:00
licoy
cdc02f660b feat: 优化switch组件的大小规格与整体表单一致 2024-11-12 16:32:40 +08:00
licoy
674abe5ae2 feat: 统一运营设置页面的保存按钮大小规格 2024-11-12 16:30:51 +08:00
Calcium-Ion
0b0bcbab80 Merge pull request #563 from Licoy/main
封装OAuth2授权回调页面、修复独立日志数据库查询令牌日志时错误问题
2024-11-12 16:27:46 +08:00
licoy
450bea8f2c 修复独立日志数据库查询令牌日志时错误问题 2024-11-12 16:22:13 +08:00
licoy
bf75df8f04 优化设置页面的模块间距与部分数据获取提示 2024-11-12 16:17:55 +08:00
licoy
c6dae4b879 封装OAuth2授权回调页面 2024-11-12 16:11:38 +08:00
Calcium-Ion
a5abd40ff6 Merge pull request #505 from OiAnthony/f_dotenv
feat: 添加.env配置文件和初始化环境变量
2024-11-11 22:06:09 +08:00
CalciumIon
b012505ff4 chore: update .env.example 2024-11-11 22:05:29 +08:00
CalciumIon
c7c870d4c6 chore: update .env.example 2024-11-11 22:04:51 +08:00
CalciumIon
66fa020be8 feat: update LinuxDo icon 2024-11-11 17:29:54 +08:00
Calcium-Ion
6d47b2c5a1 Merge pull request #562 from seefs001/main
feat: integrate Linux DO OAuth authentication
2024-11-11 17:25:40 +08:00
CalciumIon
85b90e89e6 fix: LinuxDo OAuth 2024-11-11 17:24:57 +08:00
CalciumIon
e291bb02d0 feat: playground用户分组设为默认选项
(cherry picked from commit dd7e9afed43bca3807c4680d28b5cef97f3bf880)
2024-11-11 16:43:20 +08:00
CalciumIon
34998f7939 fix: 非root日志展开bug
(cherry picked from commit 23121a3caf74be60f178bfd5f898a77de02b6d35)
2024-11-11 16:34:36 +08:00
seefs001
046f859d92 feat: integrate Linux DO OAuth authentication 2024-11-10 23:56:22 +08:00
CalciumIon
8fc49f98d2 fix: returnPreConsumedQuota 2024-11-10 02:09:18 +08:00
CalciumIon
4131183378 feat: realtime扣费时检测令牌额度
(cherry picked from commit 91511b8b64fc0d28dbf657cb97e12b7d1e50070d)
2024-11-07 17:28:53 +08:00
CalciumIon
3b53a2a5ce feat: 完善audio倍率 2024-11-07 16:42:08 +08:00
CalciumIon
97fdcd8e8f feat: 完善audio计费 2024-11-07 16:12:09 +08:00
Calcium-Ion
be652fa3c2 Merge pull request #555 from utopeadia/main
Continue fixing Ollama embedding return issue
2024-11-06 21:13:06 +08:00
CalciumIon
cbf0688b80 feat: update model ratio 2024-11-06 19:33:50 +08:00
HowieWood
2ffa4268fc Continue fixing Ollama embedding return issue 2024-11-06 01:21:02 +00:00
Calcium-Ion
3037dfab5b Merge pull request #552 from utopeadia/main
Modify ollama embed return fields
2024-11-05 22:05:45 +08:00
CalciumIon
b40c2e1071 feat: 美化日志页面
(cherry picked from commit 90daa38d5bea7b158ebed9990f042f6bf8567eb3)
2024-11-05 20:45:01 +08:00
Xyfacai
afc1e92ed0 fix: log table unknown ws prop error 2024-11-05 20:20:19 +08:00
1808837298@qq.com
ee04dbd9dd feat: 日志详情完善
(cherry picked from commit ec79110c99e9b4c076c5f7b8285e535b9c5052db)
2024-11-05 20:19:58 +08:00
HowieWood
5253a0e7b2 Modify ollama embed return fields 2024-11-05 20:12:51 +08:00
CalciumIon
e5588fc1ee Update README.md 2024-11-05 19:48:03 +08:00
Calcium-Ion
a859ff5985 Merge pull request #551 from Calcium-Ion/realtime
feat: support openai realtime api
2024-11-05 19:45:43 +08:00
CalciumIon
0a80231e18 chore: 删除无用日志 2024-11-05 19:41:38 +08:00
CalciumIon
7b1ff41e4c fix: mistral adaptor 2024-11-05 19:32:51 +08:00
1808837298@qq.com
4e0c522cd0 fix: realtime计费
(cherry picked from commit fdfea8726c6d86d3844af1ac18d7b3df908f26a7)
2024-11-05 19:29:06 +08:00
1808837298@qq.com
f08f7ae940 fix: channel test
(cherry picked from commit 052bdab1c45b3a4ba5f079afc763f54e751b1cd7)
2024-11-05 19:28:58 +08:00
Xyfacai
be64408a25 fix(realtime): 修复ws 握手失败、计费问题
(cherry picked from commit 618dffc43fd5a5f4065944db87761f9ee18e44d3)
2024-11-05 19:28:46 +08:00
Xyfacai
d596699250 refactor: realtime log
(cherry picked from commit fd24dc467bfc360008b313220e607f0176ee7aa3)
2024-11-05 19:28:09 +08:00
Xyfacai
f0907bf60a fix: 部分情况缺少返回预扣
(cherry picked from commit 96373455521a38095706bd81c57f9a18557d9c2e)
2024-11-05 19:28:08 +08:00
1808837298@qq.com
e5c05d77b7 feat: realtime pre consume
(cherry picked from commit 273d154e1640bae26b7caedddf1685e9ff21ab74)
2024-11-05 19:28:06 +08:00
1808837298@qq.com
24b3ed50d7 feat: realtime pre consume
(cherry picked from commit d87917f8f6eb9d2e144a9f840d6d91767ea2eb69)
2024-11-05 19:28:03 +08:00
1808837298@qq.com
8de79382f0 feat: azure realtime
(cherry picked from commit 75ff3d98f06103dc2df1f8817bd3fcbf433e0f20)
2024-11-05 19:27:55 +08:00
1808837298@qq.com
74f9006b40 feat: realtime
(cherry picked from commit d4966246e68dbdcdab45ec5c5141362834d74425)
2024-11-05 19:27:47 +08:00
1808837298@qq.com
33af069fae feat: realtime
(cherry picked from commit a5529df3e1a4c08a120e8c05203a7d885b0fe8d8)
2024-11-05 19:24:14 +08:00
1808837298@qq.com
e3c85572d4 Update dto
(cherry picked from commit 030187ff75c64c40017cda2fa98ef2b3c01f0bd5)
2024-11-05 19:23:56 +08:00
lianghaoyuan
2e18d5f96c refactor(config): 调整配置文件,优化注释和变量命名 2024-09-25 17:03:06 +08:00
lianghaoyuan
84f40b63b2 feat: 添加.env配置文件和初始化环境变量 2024-09-24 11:39:02 +08:00
liuzhifei
e0f19e5ed7 ratio must gte 0 2024-09-20 18:33:17 +08:00
liuzhifei
3d33079de0 ratio must gte 0 2024-09-20 18:27:16 +08:00
liuzhifei
1d064a2e88 ratio must gte 0 2024-09-20 18:09:40 +08:00
liuzhifei
4eae3b2177 ratio must gte 0 2024-09-20 17:51:42 +08:00
110 changed files with 7480 additions and 4882 deletions

71
.env.example Normal file
View File

@@ -0,0 +1,71 @@
# 端口号
# PORT=3000
# 前端基础URL
# FRONTEND_BASE_URL=https://your-frontend-url.com
# 调试相关配置
# 启用pprof
# ENABLE_PPROF=true
# 数据库相关配置
# 数据库连接字符串
# SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
# 日志数据库连接字符串
# LOG_SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
# SQLite数据库路径
# SQLITE_PATH=/path/to/sqlite.db
# 数据库最大空闲连接数
# SQL_MAX_IDLE_CONNS=100
# 数据库最大打开连接数
# SQL_MAX_OPEN_CONNS=1000
# 数据库连接最大生命周期(秒)
# SQL_MAX_LIFETIME=60
# 缓存相关配置
# Redis连接字符串
# REDIS_CONN_STRING=redis://user:password@localhost:6379/0
# 同步频率(单位:秒)
# SYNC_FREQUENCY=60
# 内存缓存启用
# MEMORY_CACHE_ENABLED=true
# 渠道更新频率(单位:秒)
# CHANNEL_UPDATE_FREQUENCY=30
# 批量更新启用
# BATCH_UPDATE_ENABLED=true
# 批量更新间隔(单位:秒)
# BATCH_UPDATE_INTERVAL=5
# 任务和功能配置
# 更新任务启用
# UPDATE_TASK=true
# 会话密钥
# SESSION_SECRET=random_string
# 其他配置
# 渠道测试频率(单位:秒)
# CHANNEL_TEST_FREQUENCY=10
# 生成默认token
# GENERATE_DEFAULT_TOKEN=false
# Gemini 安全设置
# GEMINI_SAFETY_SETTING=BLOCK_NONE
# Gemini版本设置
# GEMINI_MODEL_MAP=gemini-1.0-pro:v1
# Cohere 安全设置
# COHERE_SAFETY_SETTING=NONE
# 是否统计图片token
# GET_MEDIA_TOKEN=true
# 是否在非流stream=false情况下统计图片token
# GET_MEDIA_TOKEN_NOT_STREAM=true
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
# DIFY_DEBUG=true
# 设置流式一次回复的超时时间
# STREAMING_TIMEOUT=90
# 节点类型
# 如果是主节点则为master
# NODE_TYPE=master

3
.gitignore vendored
View File

@@ -6,4 +6,5 @@ upload
build
*.db-journal
logs
web/dist
web/dist
.env

3
BT.md Normal file
View File

@@ -0,0 +1,3 @@
密钥为环境变量SESSION_SECRET
![8285bba413e770fe9620f1bf9b40d44e](https://github.com/user-attachments/assets/7a6fc03e-c457-45e4-b8f9-184508fc26b0)

View File

@@ -48,6 +48,7 @@
4. Telegram Bot 名称是bot username 去掉@后的字符串
13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md)
14. 支持Rerank模型目前仅兼容Cohere和Jina可接入Dify[对接文档](Rerank.md)
15. **[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API支持Azure渠道。
## 模型支持
此版本额外支持以下模型:
@@ -67,18 +68,25 @@
## 比原版One API多出的配置
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒。
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 60 秒。
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true`建议开启不影响客户端传入stream_options参数返回结果。
- `GET_MEDIA_TOKEN`是统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL``STRICT`,默认为 `NONE`
## 部署
### 部署要求
- 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机)
- 远程数据库MySQL 版本 >= 5.7.8PgSQL 版本 >= 9.6
### 使用宝塔面板Docker功能部署
安装宝塔面板 (**9.2.0版本**及以上),前往 [宝塔面板](https://www.bt.cn/new/download.html) 官网,选择正式版的脚本下载安装
安装后登录宝塔面板,在菜单栏中点击 Docker ,首次进入会提示安装 Docker 服务,点击立即安装,按提示完成安装
安装完成后在应用商店中找到 **New-API** ,点击安装,配置基本选项 即可完成安装
[图文教程](BT.md)
### 基于 Docker 进行部署
```shell
# 使用 SQLite 的部署命令:
@@ -87,16 +95,6 @@ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -
# 例如:
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
### 使用宝塔面板Docker功能部署
```shell
# 使用 SQLite 的部署命令:
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。
# 例如:
# 注意数据库要开启远程访问并且只允许服务器IP访问
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(宝塔的服务器地址:宝塔数据库端口)/宝塔数据库名称" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
# 注意数据库要开启远程访问并且只允许服务器IP访问
```
## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
@@ -134,7 +132,7 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e)
## 交流群
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="200">
<img src="https://github.com/user-attachments/assets/9ca0bc82-e057-4230-a28d-9f198fa022e3" width="200">
## 相关项目
- [One API](https://github.com/songquanpeng/one-api):原版项目

View File

@@ -41,6 +41,7 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var LinuxDOOAuthEnabled = false
var WeChatAuthEnabled = false
var TelegramOAuthEnabled = false
var TurnstileCheckEnabled = false
@@ -75,6 +76,9 @@ var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var LinuxDOClientId = ""
var LinuxDOClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
@@ -140,11 +144,13 @@ var (
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60

View File

@@ -2,6 +2,7 @@ package common
import (
"encoding/json"
"errors"
)
var GroupRatio = map[string]float64{
@@ -31,3 +32,17 @@ func GetGroupRatio(name string) float64 {
}
return ratio
}
func CheckGroupRatio(jsonStr string) error {
checkGroupRatio := make(map[string]float64)
err := json.Unmarshal([]byte(jsonStr), &checkGroupRatio)
if err != nil {
return err
}
for name, ratio := range checkGroupRatio {
if ratio < 0 {
return errors.New("group ratio must be not less than 0: " + name)
}
}
return nil
}

View File

@@ -33,17 +33,18 @@ var defaultModelRatio = map[string]float64{
"gpt-4-32k": 30,
//"gpt-4-32k-0314": 30, //deprecated
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"chatgpt-4o-latest": 2.5, // $0.01 / 1K tokens
"gpt-4o": 1.25, // $0.01 / 1K tokens
"gpt-4o-audio-preview": 1.25, // $0.0015 / 1K tokens
"gpt-4o-audio-preview-2024-10-01": 1.25, // $0.0015 / 1K tokens
"gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens
"gpt-4o-2024-05-13": 2.5,
"gpt-4-1106-preview": 5, // $10 / 1M tokens
"gpt-4-0125-preview": 5, // $10 / 1M tokens
"gpt-4-turbo-preview": 5, // $10 / 1M tokens
"gpt-4-vision-preview": 5, // $10 / 1M tokens
"gpt-4-1106-vision-preview": 5, // $10 / 1M tokens
"chatgpt-4o-latest": 2.5, // $5 / 1M tokens
"gpt-4o": 1.25, // $2.5 / 1M tokens
"gpt-4o-audio-preview": 1.25, // $2.5 / 1M tokens
"gpt-4o-audio-preview-2024-10-01": 1.25, // $2.5 / 1M tokens
"gpt-4o-2024-05-13": 2.5, // $5 / 1M tokens
"gpt-4o-2024-08-06": 1.25, // $2.5 / 1M tokens
"gpt-4o-2024-11-20": 1.25, // $2.5 / 1M tokens
"gpt-4o-realtime-preview": 2.5,
"o1-preview": 7.5,
"o1-preview-2024-09-12": 7.5,
@@ -54,6 +55,7 @@ var defaultModelRatio = map[string]float64{
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
//"gpt-3.5-turbo-0301": 0.75, //deprecated
"gpt-3.5-turbo": 0.25,
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
@@ -421,6 +423,36 @@ func GetCompletionRatio(name string) float64 {
return 1
}
func GetAudioRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4o-realtime") {
return 20
} else if strings.HasPrefix(name, "gpt-4o-audio") {
return 40
}
return 20
}
func GetAudioCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4o-realtime") {
return 2
}
return 2
}
//func GetAudioPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.06
// }
// return 0.06
//}
//
//func GetAudioCompletionPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.24
// }
// return 0.24
//}
func GetCompletionRatioMap() map[string]float64 {
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio

View File

@@ -7,7 +7,7 @@ import (
"strings"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
// ForceStreamOption 覆盖请求参数强制返回usage信息
@@ -20,16 +20,7 @@ var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STR
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var GeminiModelMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-pro-001": "v1beta",
"gemini-1.5-pro": "v1beta",
"gemini-1.5-pro-exp-0801": "v1beta",
"gemini-1.5-pro-exp-0827": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-1.5-flash-exp-0827": "v1beta",
"gemini-1.5-flash-001": "v1beta",
"gemini-1.5-flash": "v1beta",
"gemini-ultra": "v1beta",
"gemini-1.0-pro": "v1",
}
func InitEnv() {

View File

@@ -102,17 +102,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if err != nil {
return err, nil
}
if resp != nil && resp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(httpResp)
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
}
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
usageA, respErr := adaptor.DoResponse(c, httpResp, meta)
if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), respErr
}
if usage == nil {
if usageA == nil {
return errors.New("usage is nil"), nil
}
usage := usageA.(*dto.Usage)
result := w.Result()
respBody, err := io.ReadAll(result.Body)
if err != nil {

View File

@@ -57,10 +57,37 @@ func GetAllChannels(c *gin.Context) {
})
return
}
tags := make(map[string]bool)
channelData := make([]*model.Channel, 0, len(channels))
tagChannels := make([]*model.Channel, 0)
for _, channel := range channels {
channelTag := channel.GetTag()
if channelTag != "" && !tags[channelTag] {
tags[channelTag] = true
tagChannel, err := model.GetChannelsByTag(channelTag)
if err == nil {
tagChannels = append(tagChannels, tagChannel...)
}
} else {
channelData = append(channelData, channel)
}
}
for i, channel := range tagChannels {
find := false
for _, can := range channelData {
if channel.Id == can.Id {
find = true
break
}
}
if !find {
channelData = append(channelData, tagChannels[i])
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": channels,
"data": channelData,
})
return
}
@@ -144,8 +171,8 @@ func SearchChannels(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
modelKeyword := c.Query("model")
//idSort, _ := strconv.ParseBool(c.Query("id_sort"))
channels, err := model.SearchChannels(keyword, group, modelKeyword)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -279,6 +306,98 @@ func DeleteDisabledChannel(c *gin.Context) {
return
}
type ChannelTag struct {
Tag string `json:"tag"`
NewTag *string `json:"new_tag"`
Priority *int64 `json:"priority"`
Weight *uint `json:"weight"`
ModelMapping *string `json:"model_mapping"`
Models *string `json:"models"`
Groups *string `json:"groups"`
}
func DisableTagChannels(c *gin.Context) {
channelTag := ChannelTag{}
err := c.ShouldBindJSON(&channelTag)
if err != nil || channelTag.Tag == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
err = model.DisableChannelByTag(channelTag.Tag)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
func EnableTagChannels(c *gin.Context) {
channelTag := ChannelTag{}
err := c.ShouldBindJSON(&channelTag)
if err != nil || channelTag.Tag == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
err = model.EnableChannelByTag(channelTag.Tag)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
func EditTagChannels(c *gin.Context) {
channelTag := ChannelTag{}
err := c.ShouldBindJSON(&channelTag)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
if channelTag.Tag == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "tag不能为空",
})
return
}
err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
type ChannelBatch struct {
Ids []int `json:"ids"`
}

View File

@@ -142,8 +142,13 @@ func GitHubOAuth(c *gin.Context) {
user.Email = githubUser.Email
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
affCode := session.Get("aff")
inviterId := 0
if affCode != nil {
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
}
if err := user.Insert(0); err != nil {
if err := user.Insert(inviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@@ -227,6 +232,10 @@ func GitHubBind(c *gin.Context) {
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
state := common.GetRandomString(12)
affCode := c.Query("aff")
if affCode != "" {
session.Set("aff", affCode)
}
session.Set("oauth_state", state)
err := session.Save()
if err != nil {

271
controller/linuxdo.go Normal file
View File

@@ -0,0 +1,271 @@
package controller
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"time"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type LinuxdoUser struct {
Id int `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Active bool `json:"active"`
TrustLevel int `json:"trust_level"`
Silenced bool `json:"silenced"`
}
func LinuxDoBind(c *gin.Context) {
if !common.LinuxDOOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Linux DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
}
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 Linux DO 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
}
func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
if code == "" {
return nil, errors.New("invalid code")
}
// Get access token using Basic auth
tokenEndpoint := "https://connect.linux.do/oauth2/token"
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
// Get redirect URI from request
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", redirectURI)
req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", basicAuth)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{Timeout: 5 * time.Second}
res, err := client.Do(req)
if err != nil {
return nil, errors.New("failed to connect to Linux DO server")
}
defer res.Body.Close()
var tokenRes struct {
AccessToken string `json:"access_token"`
Message string `json:"message"`
}
if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
return nil, err
}
if tokenRes.AccessToken == "" {
return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
}
// Get user info
userEndpoint := "https://connect.linux.do/api/user"
req, err = http.NewRequest("GET", userEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
req.Header.Set("Accept", "application/json")
res2, err := client.Do(req)
if err != nil {
return nil, errors.New("failed to get user info from Linux DO")
}
defer res2.Body.Close()
var linuxdoUser LinuxdoUser
if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
return nil, err
}
if linuxdoUser.Id == 0 {
return nil, errors.New("invalid user info returned")
}
return &linuxdoUser, nil
}
func LinuxdoOAuth(c *gin.Context) {
session := sessions.Default(c)
errorCode := c.Query("error")
if errorCode != "" {
errorDescription := c.Query("error_description")
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": errorDescription,
})
return
}
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
LinuxDoBind(c)
return
}
if !common.LinuxDOOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Linux DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
}
// Check if user exists
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
err := user.FillUserByLinuxDOId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else {
if common.RegisterEnabled {
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = linuxdoUser.Name
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
affCode := session.Get("aff")
inviterId := 0
if affCode != nil {
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
}
if err := user.Insert(inviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}

View File

@@ -38,6 +38,8 @@ func GetStatus(c *gin.Context) {
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
"linuxdo_client_id": common.LinuxDOClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,

View File

@@ -50,6 +50,14 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "LinuxDOOAuthEnabled":
if option.Value == "true" && common.LinuxDOClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 LinuxDO OAuth请先填入 LinuxDO Client Id 以及 LinuxDO Client Secret",
})
return
}
case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{
@@ -74,6 +82,15 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "GroupRatio":
err = common.CheckGroupRatio(option.Value)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
}
err = model.UpdateOption(option.Key, option.Value)
if err != nil {

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
"log"
"net/http"
@@ -38,6 +39,15 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
return err
}
func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
default:
err = relay.TextHelper(c)
}
return err
}
func Playground(c *gin.Context) {
var openaiErr *dto.OpenAIErrorWithStatusCode
@@ -134,6 +144,67 @@ func Relay(c *gin.Context) {
}
}
var upgrader = websocket.Upgrader{
Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol则必须在此声明对应的 Protocol TODO add other protocol
CheckOrigin: func(r *http.Request) bool {
return true // 允许跨域
},
}
func WssRelay(c *gin.Context) {
// 将 HTTP 连接升级为 WebSocket 连接
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
defer ws.Close()
if err != nil {
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
service.WssError(c, ws, openaiErr.Error)
return
}
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
originalModel := c.GetString("original_model")
var openaiErr *dto.OpenAIErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
openaiErr = wssRequest(c, ws, relayMode, channel)
if openaiErr == nil {
return // 成功处理请求,直接返回
}
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
service.WssError(c, ws, openaiErr.Error)
}
}
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
@@ -141,6 +212,13 @@ func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.Op
return relayHandler(c, relayMode)
}
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relay.WssHelper(c, ws)
}
func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))

View File

@@ -3,7 +3,6 @@ version: '3.4'
services:
new-api:
image: calciumion/new-api:latest
# build: .
container_name: new-api
restart: always
command: --log-dir /app/logs
@@ -13,16 +12,17 @@ services:
- ./data:/data
- ./logs:/app/logs
environment:
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
- REDIS_CONN_STRING=redis://redis
- SESSION_SECRET=random_string # 修改为随机字符串
- TZ=Asia/Shanghai
# - NODE_TYPE=slave # 多机部署时从节点取消注释该行
# - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行
# - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
# - FRONTEND_BASE_URL=https://openai.justsong.cn # Uncomment for multi-node deployment with front-end URL
depends_on:
- redis
- mysql
healthcheck:
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
interval: 30s
@@ -33,3 +33,18 @@ services:
image: redis:latest
container_name: redis
restart: always
mysql:
image: mysql:8.2
container_name: mysql
restart: always
environment:
MYSQL_ROOT_PASSWORD: 123456 # Ensure this matches the password in SQL_DSN
MYSQL_DATABASE: new-api
volumes:
- mysql_data:/var/lib/mysql
ports:
- "3306:3306" # If you want to access MySQL from outside Docker, uncomment
volumes:
mysql_data:

View File

@@ -128,7 +128,9 @@ type CompletionsStreamResponse struct {
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
}

98
dto/realtime.go Normal file
View File

@@ -0,0 +1,98 @@
package dto
const (
RealtimeEventTypeError = "error"
RealtimeEventTypeSessionUpdate = "session.update"
RealtimeEventTypeConversationCreate = "conversation.item.create"
RealtimeEventTypeResponseCreate = "response.create"
RealtimeEventInputAudioBufferAppend = "input_audio_buffer.append"
)
const (
RealtimeEventTypeResponseDone = "response.done"
RealtimeEventTypeSessionUpdated = "session.updated"
RealtimeEventTypeSessionCreated = "session.created"
RealtimeEventResponseAudioDelta = "response.audio.delta"
RealtimeEventResponseAudioTranscriptionDelta = "response.audio_transcript.delta"
RealtimeEventResponseFunctionCallArgumentsDelta = "response.function_call_arguments.delta"
RealtimeEventResponseFunctionCallArgumentsDone = "response.function_call_arguments.done"
RealtimeEventConversationItemCreated = "conversation.item.created"
)
type RealtimeEvent struct {
EventId string `json:"event_id"`
Type string `json:"type"`
//PreviousItemId string `json:"previous_item_id"`
Session *RealtimeSession `json:"session,omitempty"`
Item *RealtimeItem `json:"item,omitempty"`
Error *OpenAIError `json:"error,omitempty"`
Response *RealtimeResponse `json:"response,omitempty"`
Delta string `json:"delta,omitempty"`
Audio string `json:"audio,omitempty"`
}
type RealtimeResponse struct {
Usage *RealtimeUsage `json:"usage"`
}
type RealtimeUsage struct {
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokenDetails InputTokenDetails `json:"input_token_details"`
OutputTokenDetails OutputTokenDetails `json:"output_token_details"`
}
type InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ImageTokens int `json:"image_tokens"`
}
type OutputTokenDetails struct {
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
}
type RealtimeSession struct {
Modalities []string `json:"modalities"`
Instructions string `json:"instructions"`
Voice string `json:"voice"`
InputAudioFormat string `json:"input_audio_format"`
OutputAudioFormat string `json:"output_audio_format"`
InputAudioTranscription InputAudioTranscription `json:"input_audio_transcription"`
TurnDetection interface{} `json:"turn_detection"`
Tools []RealTimeTool `json:"tools"`
ToolChoice string `json:"tool_choice"`
Temperature float64 `json:"temperature"`
//MaxResponseOutputTokens int `json:"max_response_output_tokens"`
}
type InputAudioTranscription struct {
Model string `json:"model"`
}
type RealTimeTool struct {
Type string `json:"type"`
Name string `json:"name"`
Description string `json:"description"`
Parameters any `json:"parameters"`
}
type RealtimeItem struct {
Id string `json:"id"`
Type string `json:"type"`
Status string `json:"status"`
Role string `json:"role"`
Content []RealtimeContent `json:"content"`
Name *string `json:"name,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
CallId string `json:"call_id,omitempty"`
}
type RealtimeContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Audio string `json:"audio,omitempty"` // Base64-encoded audio bytes.
Transcript string `json:"transcript,omitempty"`
}

13
go.mod
View File

@@ -12,6 +12,7 @@ require (
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
github.com/bytedance/sonic v1.12.4
github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
@@ -23,6 +24,7 @@ require (
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.4.0
github.com/joho/godotenv v1.5.1
github.com/pkg/errors v0.9.1
github.com/pkoukk/tiktoken-go v0.1.7
github.com/samber/lo v1.39.0
@@ -41,9 +43,10 @@ require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/bytedance/sonic/loader v0.2.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
@@ -64,7 +67,7 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
@@ -78,11 +81,11 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/net v0.28.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.24.0 // indirect
golang.org/x/sys v0.27.0 // indirect
golang.org/x/text v0.17.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

34
go.sum
View File

@@ -20,14 +20,17 @@ github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/bytedance/sonic v1.12.4 h1:9Csb3c9ZJhfUWeMtpCDCq6BUoH5ogfDFLUgQ/jG+R0k=
github.com/bytedance/sonic v1.12.4/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/bytedance/sonic/loader v0.2.1 h1:1GgorWTqf12TA8mma4DDSbaQigE2wOgQo7iCjjJv3+E=
github.com/bytedance/sonic/loader v0.2.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -111,12 +114,15 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY=
github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
@@ -194,9 +200,8 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg=
golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
@@ -217,12 +222,11 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -260,4 +264,4 @@ gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=

File diff suppressed because it is too large Load Diff

17
main.go
View File

@@ -3,10 +3,6 @@ package main
import (
"embed"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"log"
"net/http"
"one-api/common"
@@ -19,6 +15,12 @@ import (
"os"
"strconv"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"github.com/joho/godotenv"
_ "net/http/pprof"
)
@@ -29,6 +31,11 @@ var buildFS embed.FS
var indexPage []byte
func main() {
err := godotenv.Load(".env")
if err != nil {
common.SysLog("Can't load .env file")
}
common.SetupLogger()
common.SysLog("New API " + common.Version + " started")
if os.Getenv("GIN_MODE") != "debug" {
@@ -38,7 +45,7 @@ func main() {
common.SysLog("running in debug mode")
}
// Initialize SQL Database
err := model.InitDB()
err = model.InitDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}

View File

@@ -155,8 +155,27 @@ func RootAuth() func(c *gin.Context) {
}
}
func WssAuth(c *gin.Context) {
}
func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) {
// 先检测是否为ws
if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
// Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
// read sk from Sec-WebSocket-Protocol
key := c.Request.Header.Get("Sec-WebSocket-Protocol")
parts := strings.Split(key, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "openai-insecure-api-key") {
key = strings.TrimPrefix(part, "openai-insecure-api-key.")
break
}
}
c.Request.Header.Set("Authorization", "Bearer "+key)
}
key := c.Request.Header.Get("Authorization")
parts := make([]string, 0)
key = strings.TrimPrefix(key, "Bearer ")

View File

@@ -170,6 +170,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, errors.New("无效的请求, " + err.Error())
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
modelRequest.Model = c.Query("model")
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"

View File

@@ -13,6 +13,10 @@ var timeFormat = "2006-01-02T15:04:05.000Z"
var inMemoryRateLimiter common.InMemoryRateLimiter
var defNext = func(c *gin.Context) {
c.Next()
}
func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
ctx := context.Background()
rdb := common.RDB
@@ -83,11 +87,17 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
}
func GlobalWebRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW")
if common.GlobalWebRateLimitEnable {
return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW")
}
return defNext
}
func GlobalAPIRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA")
if common.GlobalApiRateLimitEnable {
return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA")
}
return defNext
}
func CriticalRateLimit() func(c *gin.Context) {

View File

@@ -10,12 +10,13 @@ import (
)
type Ability struct {
Group string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
Weight uint `json:"weight" gorm:"default:0;index"`
Group string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
Weight uint `json:"weight" gorm:"default:0;index"`
Tag *string `json:"tag" gorm:"index"`
}
func GetGroupModels(group string) []string {
@@ -149,6 +150,7 @@ func (channel *Channel) AddAbilities() error {
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
Weight: uint(channel.GetWeight()),
Tag: channel.Tag,
}
abilities = append(abilities, ability)
}
@@ -190,6 +192,24 @@ func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
}
func UpdateAbilityStatusByTag(tag string, status bool) error {
return DB.Model(&Ability{}).Where("tag = ?", tag).Select("enabled").Update("enabled", status).Error
}
func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uint) error {
ability := Ability{}
if newTag != nil {
ability.Tag = newTag
}
if priority != nil {
ability.Priority = priority
}
if weight != nil {
ability.Weight = *weight
}
return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
}
func FixAbility() (int, error) {
var channelIds []int
count := 0

View File

@@ -32,6 +32,7 @@ type Channel struct {
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"`
Tag *string `json:"tag" gorm:"index"`
}
func (channel *Channel) GetModels() []string {
@@ -61,6 +62,17 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
channel.OtherInfo = string(otherInfoBytes)
}
func (channel *Channel) GetTag() string {
if channel.Tag == nil {
return ""
}
return *channel.Tag
}
func (channel *Channel) SetTag(tag string) {
channel.Tag = &tag
}
func (channel *Channel) GetAutoBan() bool {
if channel.AutoBan == nil {
return false
@@ -87,7 +99,13 @@ func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Chan
return channels, err
}
func SearchChannels(keyword string, group string, model string) ([]*Channel, error) {
func GetChannelsByTag(tag string) ([]*Channel, error) {
var channels []*Channel
err := DB.Where("tag = ?", tag).Find(&channels).Error
return channels, err
}
func SearchChannels(keyword string, group string, model string, idSort bool) ([]*Channel, error) {
var channels []*Channel
keyCol := "`key`"
groupCol := "`group`"
@@ -100,6 +118,11 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err
modelsCol = `"models"`
}
order := "priority desc"
if idSort {
order = "id desc"
}
// 构造基础查询
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
@@ -122,7 +145,7 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err
}
// 执行查询
err := baseQuery.Where(whereClause, args...).Order("priority desc").Find(&channels).Error
err := baseQuery.Where(whereClause, args...).Order(order).Find(&channels).Error
if err != nil {
return nil, err
}
@@ -288,6 +311,74 @@ func UpdateChannelStatusById(id int, status int, reason string) {
}
func EnableChannelByTag(tag string) error {
err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusEnabled).Error
if err != nil {
return err
}
err = UpdateAbilityStatusByTag(tag, true)
return err
}
func DisableChannelByTag(tag string) error {
err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusManuallyDisabled).Error
if err != nil {
return err
}
err = UpdateAbilityStatusByTag(tag, false)
return err
}
func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *string, group *string, priority *int64, weight *uint) error {
updateData := Channel{}
shouldReCreateAbilities := false
updatedTag := tag
// 如果 newTag 不为空且不等于 tag则更新 tag
if newTag != nil && *newTag != tag {
updateData.Tag = newTag
updatedTag = *newTag
}
if modelMapping != nil && *modelMapping != "" {
updateData.ModelMapping = modelMapping
}
if models != nil && *models != "" {
shouldReCreateAbilities = true
updateData.Models = *models
}
if group != nil && *group != "" {
shouldReCreateAbilities = true
updateData.Group = *group
}
if priority != nil {
updateData.Priority = priority
}
if weight != nil {
updateData.Weight = weight
}
err := DB.Model(&Channel{}).Where("tag = ?", tag).Updates(updateData).Error
if err != nil {
return err
}
if shouldReCreateAbilities {
channels, err := GetChannelsByTag(updatedTag)
if err == nil {
for _, channel := range channels {
err = channel.UpdateAbilities()
if err != nil {
common.SysError("failed to update abilities: " + err.Error())
}
}
}
} else {
err := UpdateAbilityByTag(tag, newTag, priority, weight)
if err != nil {
return err
}
}
return nil
}
func UpdateChannelUsedQuota(id int, quota int) {
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"one-api/common"
"os"
"strings"
"time"
@@ -39,7 +40,15 @@ const (
)
func GetLogByKey(key string) (logs []*Log, err error) {
err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
if os.Getenv("LOG_SQL_DSN") != "" {
var tk Token
if err = DB.Model(&Token{}).Where("`key`=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
return nil, err
}
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
} else {
err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
}
return logs, err
}

View File

@@ -31,6 +31,7 @@ func InitOptionMap() {
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
common.OptionMap["LinuxDOOAuthEnabled"] = strconv.FormatBool(common.LinuxDOOAuthEnabled)
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
@@ -175,6 +176,8 @@ func updateOptionMap(key string, value string) (err error) {
common.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue
case "LinuxDOOAuthEnabled":
common.LinuxDOOAuthEnabled = boolValue
case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue
case "TelegramOAuthEnabled":
@@ -267,6 +270,10 @@ func updateOptionMap(key string, value string) (err error) {
common.GitHubClientId = value
case "GitHubClientSecret":
common.GitHubClientSecret = value
case "LinuxDOClientId":
common.LinuxDOClientId = value
case "LinuxDOClientSecret":
common.LinuxDOClientSecret = value
case "Footer":
common.Footer = value
case "SystemName":

View File

@@ -36,6 +36,7 @@ type User struct {
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
DeletedAt gorm.DeletedAt `gorm:"index"`
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
}
func (user *User) GetAccessToken() string {
@@ -537,3 +538,17 @@ func GetUsernameById(id int) (username string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
return username, err
}
func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
var user User
err := DB.Unscoped().Where("linux_do_id = ?", linuxDOId).First(&user).Error
return !errors.Is(err, gorm.ErrRecordNotFound)
}
func (u *User) FillUserByLinuxDOId() error {
if u.LinuxDOId == "" {
return errors.New("linux do id is empty")
}
err := DB.Where("linux_do_id = ?", u.LinuxDOId).First(u).Error
return err
}

View File

@@ -12,13 +12,13 @@ type Adaptor interface {
// Init IsStream bool
Init(info *relaycommon.RelayInfo)
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
}

View File

@@ -32,14 +32,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
req.Set("Authorization", "Bearer "+info.ApiKey)
if info.IsStream {
req.Header.Set("X-DashScope-SSE", "enable")
req.Set("X-DashScope-SSE", "enable")
}
if c.GetString("plugin") != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
}
return nil
}
@@ -72,11 +72,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
err, usage = aliImageHandler(c, resp, info)

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
"net/http"
"one-api/relay/common"
@@ -11,14 +12,16 @@ import (
"one-api/service"
)
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
// multipart/form-data
} else if info.RelayMode == constant.RelayModeRealtime {
// websocket
} else {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
req.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Set("Accept", c.Request.Header.Get("Accept"))
if info.IsStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
req.Set("Accept", "text/event-stream")
}
}
}
@@ -32,7 +35,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
err = a.SetupRequestHeader(c, req, info)
err = a.SetupRequestHeader(c, &req.Header, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
@@ -55,7 +58,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
// set form data
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
err = a.SetupRequestHeader(c, req, info)
err = a.SetupRequestHeader(c, &req.Header, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
@@ -66,6 +69,27 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
return resp, nil
}
func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) {
fullRequestURL, err := a.GetRequestURL(info)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
targetHeader := http.Header{}
err = a.SetupRequestHeader(c, &targetHeader, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
if err != nil {
return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err)
}
// send request body
//all, err := io.ReadAll(requestBody)
//err = service.WssString(c, targetConn, string(all))
return targetConn, nil
}
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := service.GetHttpClient().Do(req)
if err != nil {

View File

@@ -37,7 +37,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
return nil
}
@@ -59,11 +59,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
} else {

View File

@@ -9,6 +9,7 @@ var awsModelIDMap = map[string]string{
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
}
var ChannelName = "aws"

View File

@@ -7,9 +7,9 @@ import (
type AwsClaudeRequest struct {
// AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"`
System string `json:"system"`
System string `json:"system,omitempty"`
Messages []claude.ClaudeMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
@@ -17,3 +17,18 @@ type AwsClaudeRequest struct {
Tools []claude.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
}
func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
return &AwsClaudeRequest{
AnthropicVersion: "bedrock-2023-05-31",
System: req.System,
Messages: req.Messages,
MaxTokens: req.MaxTokens,
Temperature: req.Temperature,
TopP: req.TopP,
TopK: req.TopK,
StopSequences: req.StopSequences,
Tools: req.Tools,
ToolChoice: req.ToolChoice,
}
}

View File

@@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"github.com/pkg/errors"
"io"
"net/http"
@@ -78,13 +77,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*claude.ClaudeRequest)
awsClaudeReq := &AwsClaudeRequest{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
}
awsClaudeReq := copyRequest(claudeReq)
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
@@ -136,12 +129,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
claudeReq := claudeReq_.(*claude.ClaudeRequest)
awsClaudeReq := &AwsClaudeRequest{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
}
awsClaudeReq := copyRequest(claudeReq)
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil

View File

@@ -98,9 +98,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
@@ -122,11 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = baiduStreamHandler(c, resp)
} else {

View File

@@ -47,14 +47,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-api-key", info.ApiKey)
req.Set("x-api-key", info.ApiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
req.Set("anthropic-version", anthropicVersion)
return nil
}
@@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
} else {

View File

@@ -509,7 +509,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
}, nil
}
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName)
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
}

View File

@@ -30,9 +30,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
@@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
}
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
@@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
fallthrough

View File

@@ -149,7 +149,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName)
usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage

View File

@@ -36,9 +36,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
@@ -46,7 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return requestOpenAI2Cohere(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
@@ -54,7 +54,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return requestConvertRerank2Cohere(request), nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info)
} else {

View File

@@ -31,9 +31,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
@@ -48,11 +48,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = difyStreamHandler(c, resp, info)
} else {

View File

@@ -108,7 +108,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
}
if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText)
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
return nil, usage

View File

@@ -30,13 +30,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1beta"
version, beta := constant.GeminiModelMap[info.UpstreamModelName]
if !beta {
if info.ApiVersion != "" {
version = info.ApiVersion
} else {
version = "v1"
version = "v1beta"
}
}
@@ -47,9 +47,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-goog-api-key", info.ApiKey)
req.Set("x-goog-api-key", info.ApiKey)
return nil
}
@@ -64,11 +64,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = GeminiChatStreamHandler(c, resp, info)
} else {

View File

@@ -7,6 +7,7 @@ const (
var ModelList = []string{
"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra",
"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", "gemini-1.5-pro-exp-0827", "gemini-1.5-flash-exp-0827",
"gemini-exp-1114",
}
var ChannelName = "google gemini"

View File

@@ -37,9 +37,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", errors.New("invalid relay mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
@@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
@@ -55,7 +55,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings {

View File

@@ -31,9 +31,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
@@ -50,11 +50,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {

View File

@@ -37,7 +37,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
return nil
}
@@ -58,11 +58,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {

View File

@@ -37,5 +37,5 @@ type OllamaEmbeddingRequest struct {
type OllamaEmbeddingResponse struct {
Error string `json:"error,omitempty"`
Model string `json:"model"`
Embedding []float64 `json:"embedding,omitempty"`
Embedding [][]float64 `json:"embeddings,omitempty"`
}

View File

@@ -73,9 +73,10 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
if ollamaEmbeddingResponse.Error != "" {
return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil
}
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
data = append(data, dto.OpenAIEmbeddingResponseItem{
Embedding: ollamaEmbeddingResponse.Embedding,
Embedding: flattenedEmbeddings,
Object: "embedding",
})
usage := &dto.Usage{
@@ -120,3 +121,11 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
}
return nil, usage
}
func flattenEmbeddings(embeddings [][]float64) []float64 {
flattened := []float64{}
for _, row := range embeddings {
flattened = append(flattened, row...)
}
return flattened
}

View File

@@ -31,6 +31,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRealtime {
// trim https
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
baseUrl = strings.TrimPrefix(baseUrl, "http://")
baseUrl = "wss://" + baseUrl
info.BaseUrl = baseUrl
}
switch info.ChannelType {
case common.ChannelTypeAzure:
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
@@ -40,8 +47,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
model_ := info.UpstreamModelName
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
if info.RelayMode == constant.RelayModeRealtime {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion)
}
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
case common.ChannelTypeMiniMax:
return minimax.GetRequestURL(info)
@@ -54,16 +63,34 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, header)
if info.ChannelType == common.ChannelTypeAzure {
req.Header.Set("api-key", info.ApiKey)
header.Set("api-key", info.ApiKey)
return nil
}
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
req.Header.Set("OpenAI-Organization", info.Organization)
header.Set("OpenAI-Organization", info.Organization)
}
if info.RelayMode == constant.RelayModeRealtime {
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
if swp != "" {
items := []string{
"realtime",
"openai-insecure-api-key." + info.ApiKey,
"openai-beta.realtime-v1",
}
header.Set("Sec-WebSocket-Protocol", strings.Join(items, ","))
//req.Header.Set("Sec-WebSocket-Key", c.Request.Header.Get("Sec-WebSocket-Key"))
//req.Header.Set("Sec-Websocket-Extensions", c.Request.Header.Get("Sec-Websocket-Extensions"))
//req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
} else {
header.Set("openai-beta", "realtime=v1")
header.Set("Authorization", "Bearer "+info.ApiKey)
}
} else {
header.Set("Authorization", "Bearer "+info.ApiKey)
}
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
//if info.ChannelType == common.ChannelTypeOpenRouter {
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
// req.Header.Set("X-Title", "One API")
@@ -105,6 +132,19 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
writer.WriteField("model", request.Model)
// 获取所有表单字段
formData := c.Request.PostForm
// 遍历表单字段并打印输出
for key, values := range formData {
if key == "model" {
continue
}
for _, value := range values {
writer.WriteField(key, value)
}
}
// 添加文件字段
file, header, err := c.Request.FormFile("file")
if err != nil {
@@ -131,16 +171,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
return channel.DoFormRequest(a, c, info, requestBody)
} else if info.RelayMode == constant.RelayModeRealtime {
return channel.DoWssRequest(a, c, info, requestBody)
} else {
return channel.DoApiRequest(a, c, info, requestBody)
}
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeRealtime:
err, usage = OpenaiRealtimeHandler(c, info)
case constant.RelayModeAudioSpeech:
err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeAudioTranslation:

View File

@@ -9,7 +9,7 @@ var ModelList = []string{
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4-vision-preview",
"chatgpt-4o-latest",
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06",
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"o1-preview", "o1-preview-2024-09-12",
"o1-mini", "o1-mini-2024-09-12",

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
"net/http"
"one-api/common"
@@ -97,6 +98,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
shouldSendLastResp = false
}
}
for _, choice := range lastStreamResponse.Choices {
if choice.FinishReason != nil {
shouldSendLastResp = true
}
}
}
if shouldSendLastResp {
service.StringData(c, lastStreamData)
@@ -231,7 +237,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
@@ -278,7 +284,6 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var audioResp dto.AudioResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -287,11 +292,6 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &audioResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
@@ -324,7 +324,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}
@@ -373,3 +373,210 @@ func getTextFromJSON(body []byte) (string, error) {
}
return whisperResponse.Text, nil
}
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
info.IsStream = true
clientConn := info.ClientWs
targetConn := info.TargetWs
clientClosed := make(chan struct{})
targetClosed := make(chan struct{})
sendChan := make(chan []byte, 100)
receiveChan := make(chan []byte, 100)
errChan := make(chan error, 2)
usage := &dto.RealtimeUsage{}
localUsage := &dto.RealtimeUsage{}
sumUsage := &dto.RealtimeUsage{}
gopool.Go(func() {
for {
select {
case <-c.Done():
return
default:
_, message, err := clientConn.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
errChan <- fmt.Errorf("error reading from client: %v", err)
}
close(clientClosed)
return
}
realtimeEvent := &dto.RealtimeEvent{}
err = json.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
}
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
if realtimeEvent.Session != nil {
if realtimeEvent.Session.Tools != nil {
info.RealtimeTools = realtimeEvent.Session.Tools
}
}
}
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = service.WssString(c, targetConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to target: %v", err)
return
}
select {
case sendChan <- message:
default:
}
}
}
})
gopool.Go(func() {
for {
select {
case <-c.Done():
return
default:
_, message, err := targetConn.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
errChan <- fmt.Errorf("error reading from target: %v", err)
}
close(targetClosed)
return
}
info.SetFirstResponseTime()
realtimeEvent := &dto.RealtimeEvent{}
err = json.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
}
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
realtimeUsage := realtimeEvent.Response.Usage
if realtimeUsage != nil {
usage.TotalTokens += realtimeUsage.TotalTokens
usage.InputTokens += realtimeUsage.InputTokens
usage.OutputTokens += realtimeUsage.OutputTokens
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
err := preConsumeUsage(c, info, usage, sumUsage)
if err != nil {
errChan <- fmt.Errorf("error consume usage: %v", err)
return
}
// 本次计费完成,清除
usage = &dto.RealtimeUsage{}
localUsage = &dto.RealtimeUsage{}
} else {
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
info.IsFirstRequest = false
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = preConsumeUsage(c, info, localUsage, sumUsage)
if err != nil {
errChan <- fmt.Errorf("error consume usage: %v", err)
return
}
// 本次计费完成,清除
localUsage = &dto.RealtimeUsage{}
// print now usage
}
//common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session
if realtimeSession != nil {
// update audio format
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
}
} else {
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.OutputTokens += textToken + audioToken
localUsage.OutputTokenDetails.TextTokens += textToken
localUsage.OutputTokenDetails.AudioTokens += audioToken
}
err = service.WssString(c, clientConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to client: %v", err)
return
}
select {
case receiveChan <- message:
default:
}
}
}
})
select {
case <-clientClosed:
case <-targetClosed:
case err := <-errChan:
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
common.LogError(c, "realtime error: "+err.Error())
case <-c.Done():
}
if usage.TotalTokens != 0 {
_ = preConsumeUsage(c, info, usage, sumUsage)
}
if localUsage.TotalTokens != 0 {
_ = preConsumeUsage(c, info, localUsage, sumUsage)
}
// check usage total tokens, if 0, use local usage
return nil, sumUsage
}
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
totalUsage.TotalTokens += usage.TotalTokens
totalUsage.InputTokens += usage.InputTokens
totalUsage.OutputTokens += usage.OutputTokens
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
// clear usage
err := service.PreWssConsumeQuota(ctx, info, usage)
return err
}

View File

@@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-goog-api-key", info.ApiKey)
req.Set("x-goog-api-key", info.ApiKey)
return nil
}
@@ -49,11 +49,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = palmStreamHandler(c, resp)

View File

@@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model)
completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,

View File

@@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
@@ -52,11 +52,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {

View File

@@ -40,9 +40,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", errors.New("invalid relay mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
@@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
@@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeRerank:
err, usage = siliconflowRerankHandler(c, resp)

View File

@@ -43,12 +43,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", a.Sign)
req.Header.Set("X-TC-Action", a.Action)
req.Header.Set("X-TC-Version", a.Version)
req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
req.Set("Authorization", a.Sign)
req.Set("X-TC-Action", a.Action)
req.Set("X-TC-Version", a.Version)
req.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
return nil
}
@@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = tencentStreamHandler(c, resp)

View File

@@ -107,13 +107,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", errors.New("unsupported request mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
accessToken, err := getAccessToken(a, info)
if err != nil {
return err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Set("Authorization", "Bearer "+accessToken)
return nil
}
@@ -148,11 +148,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
switch a.RequestMode {
case RequestModeClaude:

View File

@@ -33,7 +33,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
return nil
}
@@ -50,14 +50,14 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
// xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{}
dummyResp.StatusCode = http.StatusOK
return dummyResp, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
splits := strings.Split(info.ApiKey, "|")
if len(splits) != 3 {
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)

View File

@@ -35,10 +35,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
token := getZhipuToken(info.ApiKey)
req.Header.Set("Authorization", token)
req.Set("Authorization", token)
return nil
}
@@ -56,11 +56,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = zhipuStreamHandler(c, resp)
} else {

View File

@@ -32,10 +32,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
token := getZhipuToken(info.ApiKey)
req.Header.Set("Authorization", token)
req.Set("Authorization", token)
return nil
}
@@ -53,11 +53,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {

View File

@@ -2,7 +2,9 @@ package common
import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"one-api/common"
"one-api/dto"
"one-api/relay/constant"
"strings"
"time"
@@ -21,6 +23,7 @@ type RelayInfo struct {
ApiType int
IsStream bool
IsPlayground bool
UsePrice bool
RelayMode int
UpstreamModelName string
OriginModelName string
@@ -32,6 +35,22 @@ type RelayInfo struct {
BaseUrl string
SupportStreamOptions bool
ShouldIncludeUsage bool
ClientWs *websocket.Conn
TargetWs *websocket.Conn
InputAudioFormat string
OutputAudioFormat string
RealtimeTools []dto.RealTimeTool
IsFirstRequest bool
AudioUsage bool
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
info := GenRelayInfo(c)
info.ClientWs = ws
info.InputAudioFormat = "pcm16"
info.OutputAudioFormat = "pcm16"
info.IsFirstRequest = true
return info
}
func GenRelayInfo(c *gin.Context) *RelayInfo {

View File

@@ -38,6 +38,8 @@ const (
RelayModeSunoSubmit
RelayModeRerank
RelayModeRealtime
)
func Path2RelayMode(path string) int {
@@ -64,6 +66,8 @@ func Path2RelayMode(path string) int {
relayMode = RelayModeAudioTranslation
} else if strings.HasPrefix(path, "/v1/rerank") {
relayMode = RelayModeRerank
} else if strings.HasPrefix(path, "/v1/realtime") {
relayMode = RelayModeRealtime
}
return relayMode
}

View File

@@ -33,12 +33,19 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
}
}
default:
if audioRequest.Model == "" {
audioRequest.Model = c.PostForm("model")
err = c.Request.ParseForm()
if err != nil {
return nil, err
}
formData := c.Request.PostForm
if audioRequest.Model == "" {
audioRequest.Model = formData.Get("model")
}
if audioRequest.Model == "" {
return nil, errors.New("model is required")
}
audioRequest.ResponseFormat = formData.Get("response_format")
if audioRequest.ResponseFormat == "" {
audioRequest.ResponseFormat = "json"
}
@@ -46,7 +53,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
return audioRequest, nil
}
func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
@@ -58,7 +65,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
promptTokens := 0
preConsumedTokens := common.PreConsumedQuota
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
}
@@ -92,6 +99,11 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
// map model name
modelMapping := c.GetString("model_mapping")
@@ -122,27 +134,27 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
if resp != nil {
if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp)
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
}
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
return nil
}

View File

@@ -149,22 +149,24 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
requestBody = bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
var httpResp *http.Response
if resp != nil {
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
openaiErr := service.RelayErrorHandler(resp)
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
openaiErr := service.RelayErrorHandler(httpResp)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
}
_, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
_, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/bytedance/sonic"
"io"
"math"
"net/http"
@@ -64,7 +65,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
return textRequest, nil
}
func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c)
@@ -76,7 +77,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
}
// map model name
isModelMapped := false
//isModelMapped := false
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
@@ -86,7 +87,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
isModelMapped = true
//isModelMapped = true
textRequest.Model = modelMap[textRequest.Model]
// set upstream model name
//isModelMapped = true
@@ -131,7 +132,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if openaiErr != nil {
return openaiErr
}
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
includeUsage := false
// 判断用户是否需要返回使用情况
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
@@ -161,49 +166,56 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
adaptor.Init(relayInfo)
var requestBody io.Reader
if relayInfo.ChannelType == common.ChannelTypeOpenAI && !isModelMapped {
body, err := common.GetRequestBody(c)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_request_body_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
//if relayInfo.ChannelType == common.ChannelTypeOpenAI && !isModelMapped {
// body, err := common.GetRequestBody(c)
// if err != nil {
// return service.OpenAIErrorWrapperLocal(err, "get_request_body_failed", http.StatusInternalServerError)
// }
// requestBody = bytes.NewBuffer(body)
//} else {
//
//}
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := sonic.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
if resp != nil {
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp)
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
}
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
if strings.HasPrefix(relayInfo.UpstreamModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
} else {
postConsumeQuota(c, relayInfo, textRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
}
return nil
}

View File

@@ -23,7 +23,7 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
return token
}
func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c)
var rerankRequest *dto.RerankRequest
@@ -79,6 +79,12 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
if openaiErr != nil {
return openaiErr
}
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
@@ -99,23 +105,24 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
var httpResp *http.Response
if resp != nil {
if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp)
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
}
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
return nil
}

159
relay/websocket.go Normal file
View File

@@ -0,0 +1,159 @@
package relay
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
)
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
// _, p, err := ws.ReadMessage()
// if err != nil {
// return nil, err
// }
// realtimeEvent := &dto.RealtimeEvent{}
// err = json.Unmarshal(p, realtimeEvent)
// if err != nil {
// return nil, err
// }
// // save the original request
// if realtimeEvent.Session == nil {
// return nil, errors.New("session object is nil")
// }
// c.Set("first_wss_request", p)
// return realtimeEvent, nil
//}
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfoWs(c, ws)
// get & validate textRequest 获取并验证文本请求
//realtimeEvent, err := getAndValidateWssRequest(c, ws)
//if err != nil {
// common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
// return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
//}
// map model name
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[relayInfo.OriginModelName] != "" {
relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName]
// set upstream model name
//isModelMapped = true
}
}
//relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest)
//if constant.ShouldCheckPromptSensitive() {
// err = checkRequestSensitive(textRequest, relayInfo)
// if err != nil {
// return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
// }
//}
//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
//// count messages token error 计算promptTokens错误
//if err != nil {
// return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
//}
//
if !getModelPriceSuccess {
preConsumedTokens := common.PreConsumedQuota
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
//}
modelRatio = common.GetModelRatio(relayInfo.UpstreamModelName)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
relayInfo.UsePrice = true
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(relayInfo)
//var requestBody io.Reader
//firstWssRequest, _ := c.Get("first_wss_request")
//requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, nil)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
if resp != nil {
relayInfo.TargetWs = resp.(*websocket.Conn)
defer relayInfo.TargetWs.Close()
}
usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
if openaiErr != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
return nil
}
//func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
// var promptTokens int
// var err error
// switch info.RelayMode {
// default:
// promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
// }
// info.PromptTokens = promptTokens
// return promptTokens, err
//}
//func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
// var err error
// switch info.RelayMode {
// case relayconstant.RelayModeChatCompletions:
// err = service.CheckSensitiveMessages(textRequest.Messages)
// case relayconstant.RelayModeCompletions:
// err = service.CheckSensitiveInput(textRequest.Prompt)
// case relayconstant.RelayModeModerations:
// err = service.CheckSensitiveInput(textRequest.Input)
// case relayconstant.RelayModeEmbeddings:
// err = service.CheckSensitiveInput(textRequest.Input)
// }
// return err
//}

View File

@@ -25,6 +25,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
@@ -90,6 +91,9 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.POST("/", controller.AddChannel)
channelRoute.PUT("/", controller.UpdateChannel)
channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
channelRoute.POST("/tag/disabled", controller.DisableTagChannels)
channelRoute.POST("/tag/enabled", controller.EnableTagChannels)
channelRoute.PUT("/tag", controller.EditTagChannels)
channelRoute.DELETE("/:id", controller.DeleteChannel)
channelRoute.POST("/batch", controller.DeleteChannelBatch)
channelRoute.POST("/fix", controller.FixChannelsAbilities)

View File

@@ -22,32 +22,41 @@ func SetRelayRouter(router *gin.Engine) {
playgroundRouter.POST("/chat/completions", controller.Playground)
}
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
relayV1Router.Use(middleware.TokenAuth())
{
relayV1Router.POST("/completions", controller.Relay)
relayV1Router.POST("/chat/completions", controller.Relay)
relayV1Router.POST("/edits", controller.Relay)
relayV1Router.POST("/images/generations", controller.Relay)
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
relayV1Router.POST("/embeddings", controller.Relay)
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.Relay)
relayV1Router.POST("/audio/translations", controller.Relay)
relayV1Router.POST("/audio/speech", controller.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
relayV1Router.GET("/files/:id", controller.RelayNotImplemented)
relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented)
relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented)
relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented)
relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented)
relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
relayV1Router.POST("/moderations", controller.Relay)
relayV1Router.POST("/rerank", controller.Relay)
// WebSocket 路由
wsRouter := relayV1Router.Group("")
wsRouter.Use(middleware.Distribute())
wsRouter.GET("/realtime", controller.WssRelay)
}
{
//http router
httpRouter := relayV1Router.Group("")
httpRouter.Use(middleware.Distribute())
httpRouter.POST("/completions", controller.Relay)
httpRouter.POST("/chat/completions", controller.Relay)
httpRouter.POST("/edits", controller.Relay)
httpRouter.POST("/images/generations", controller.Relay)
httpRouter.POST("/images/edits", controller.RelayNotImplemented)
httpRouter.POST("/images/variations", controller.RelayNotImplemented)
httpRouter.POST("/embeddings", controller.Relay)
httpRouter.POST("/engines/:model/embeddings", controller.Relay)
httpRouter.POST("/audio/transcriptions", controller.Relay)
httpRouter.POST("/audio/translations", controller.Relay)
httpRouter.POST("/audio/speech", controller.Relay)
httpRouter.GET("/files", controller.RelayNotImplemented)
httpRouter.POST("/files", controller.RelayNotImplemented)
httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)
httpRouter.GET("/files/:id", controller.RelayNotImplemented)
httpRouter.GET("/files/:id/content", controller.RelayNotImplemented)
httpRouter.POST("/fine-tunes", controller.RelayNotImplemented)
httpRouter.GET("/fine-tunes", controller.RelayNotImplemented)
httpRouter.GET("/fine-tunes/:id", controller.RelayNotImplemented)
httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
httpRouter.POST("/moderations", controller.Relay)
httpRouter.POST("/rerank", controller.Relay)
}
relayMjRouter := router.Group("/mj")

31
service/audio.go Normal file
View File

@@ -0,0 +1,31 @@
package service
import (
"encoding/base64"
"fmt"
)
func parseAudio(audioBase64 string, format string) (duration float64, err error) {
audioData, err := base64.StdEncoding.DecodeString(audioBase64)
if err != nil {
return 0, fmt.Errorf("base64 decode error: %v", err)
}
var samplesCount int
var sampleRate int
switch format {
case "pcm16":
samplesCount = len(audioData) / 2 // 16位 = 2字节每样本
sampleRate = 24000 // 24kHz
case "g711_ulaw", "g711_alaw":
samplesCount = len(audioData) // 8位 = 1字节每样本
sampleRate = 8000 // 8kHz
default:
samplesCount = len(audioData) // 8位 = 1字节每样本
sampleRate = 8000 // 8kHz
}
duration = float64(samplesCount) / float64(sampleRate)
return duration, nil
}

View File

@@ -2,6 +2,7 @@ package service
import (
"github.com/gin-gonic/gin"
"one-api/dto"
relaycommon "one-api/relay/common"
)
@@ -17,3 +18,27 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
other["admin_info"] = adminInfo
return other
}
func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice float64) map[string]interface{} {
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
info["ws"] = true
info["audio_input"] = usage.InputTokenDetails.AudioTokens
info["audio_output"] = usage.OutputTokenDetails.AudioTokens
info["text_input"] = usage.InputTokenDetails.TextTokens
info["text_output"] = usage.OutputTokenDetails.TextTokens
info["audio_ratio"] = audioRatio
info["audio_completion_ratio"] = audioCompletionRatio
return info
}
func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice float64) map[string]interface{} {
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
info["audio"] = true
info["audio_input"] = usage.PromptTokensDetails.AudioTokens
info["audio_output"] = usage.CompletionTokenDetails.AudioTokens
info["text_input"] = usage.PromptTokensDetails.TextTokens
info["text_output"] = usage.CompletionTokenDetails.TextTokens
info["audio_ratio"] = audioRatio
info["audio_completion_ratio"] = audioCompletionRatio
return info
}

212
service/quota.go Normal file
View File

@@ -0,0 +1,212 @@
package service
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"math"
"one-api/common"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
"strings"
"time"
)
func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
if relayInfo.UsePrice {
return nil
}
userQuota, err := model.GetUserQuota(relayInfo.UserId)
if err != nil {
return err
}
token, err := model.CacheGetTokenByKey(strings.TrimLeft(relayInfo.ApiKey, "sk-"))
if err != nil {
return err
}
modelName := relayInfo.UpstreamModelName
textInputTokens := usage.InputTokenDetails.TextTokens
textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens
audioOutTokens := usage.OutputTokenDetails.AudioTokens
completionRatio := common.GetCompletionRatio(modelName)
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
groupRatio := common.GetGroupRatio(relayInfo.Group)
modelRatio := common.GetModelRatio(modelName)
ratio := groupRatio * modelRatio
quota := textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
if userQuota < quota {
return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
}
if token.RemainQuota < quota {
return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
}
err = model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false)
if err != nil {
return err
}
common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
err = model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil {
return err
}
return nil
}
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64,
groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.InputTokenDetails.TextTokens
textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens
audioOutTokens := usage.OutputTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName)
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
quota := 0
if !usePrice {
quota = int(math.Round(float64(textInputTokens) + float64(textOutTokens)*completionRatio))
quota += int(math.Round(float64(audioInputTokens)*audioRatio + float64(audioOutTokens)*audioRatio*audioCompletionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
totalTokens := usage.TotalTokens
var logContent string
if !usePrice {
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}
// record all the consume log even if quota is 0
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
} else {
//if sensitiveResp != nil {
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
//}
//quotaDelta := quota - preConsumedQuota
//if quotaDelta != 0 {
// err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
// if err != nil {
// common.LogError(ctx, "error consuming token remain quota: "+err.Error())
// }
//}
//err := model.CacheUpdateUserQuota(relayInfo.UserId)
//if err != nil {
// common.LogError(ctx, "error update user quota cache: "+err.Error())
//}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
logModel := modelName
if extraContent != "" {
logContent += ", " + extraContent
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
}
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64,
groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.PromptTokensDetails.TextTokens
textOutTokens := usage.CompletionTokenDetails.TextTokens
audioInputTokens := usage.PromptTokensDetails.AudioTokens
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(relayInfo.UpstreamModelName)
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.UpstreamModelName)
quota := 0
if !usePrice {
quota = int(math.Round(float64(textInputTokens) + float64(textOutTokens)*completionRatio))
quota += int(math.Round(float64(audioInputTokens)*audioRatio + float64(audioOutTokens)*audioRatio*audioCompletionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
totalTokens := usage.TotalTokens
var logContent string
if !usePrice {
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}
// record all the consume log even if quota is 0
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.UpstreamModelName, preConsumedQuota))
} else {
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
err := model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
logModel := relayInfo.UpstreamModelName
if extraContent != "" {
logContent += ", " + extraContent
}
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
"one-api/common"
"one-api/dto"
@@ -42,11 +43,47 @@ func Done(c *gin.Context) {
_ = StringData(c, "[DONE]")
}
func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
if ws == nil {
common.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
return ws.WriteMessage(1, []byte(str))
}
func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
jsonData, err := json.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
if ws == nil {
common.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
return ws.WriteMessage(1, jsonData)
}
func WssError(c *gin.Context, ws *websocket.Conn, openaiError dto.OpenAIError) {
errorObj := &dto.RealtimeEvent{
Type: "error",
EventId: GetLocalRealtimeID(c),
Error: &openaiError,
}
_ = WssObject(c, ws, errorObj)
}
func GetResponseID(c *gin.Context) string {
logID := c.GetString("X-Oneapi-Request-Id")
logID := c.GetString(common.RequestIdKey)
return fmt.Sprintf("chatcmpl-%s", logID)
}
func GetLocalRealtimeID(c *gin.Context) string {
logID := c.GetString(common.RequestIdKey)
return fmt.Sprintf("evt_%s", logID)
}
func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,

View File

@@ -11,6 +11,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"strings"
"unicode/utf8"
)
@@ -191,6 +192,72 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
return tkm, nil
}
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
audioToken := 0
textToken := 0
switch request.Type {
case dto.RealtimeEventTypeSessionUpdate:
if request.Session != nil {
msgTokens, err := CountTextToken(request.Session.Instructions, model)
if err != nil {
return 0, 0, err
}
textToken += msgTokens
}
case dto.RealtimeEventResponseAudioDelta:
// count audio token
atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
if err != nil {
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
}
audioToken += atk
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
// count text token
tkm, err := CountTextToken(request.Delta, model)
if err != nil {
return 0, 0, fmt.Errorf("error counting text token: %v", err)
}
textToken += tkm
case dto.RealtimeEventInputAudioBufferAppend:
// count audio token
atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
if err != nil {
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
}
audioToken += atk
case dto.RealtimeEventConversationItemCreated:
if request.Item != nil {
switch request.Item.Type {
case "message":
for _, content := range request.Item.Content {
if content.Type == "input_text" {
tokens, err := CountTextToken(content.Text, model)
if err != nil {
return 0, 0, err
}
textToken += tokens
}
}
}
}
case dto.RealtimeEventTypeResponseDone:
// count tools token
if !info.IsFirstRequest {
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
for _, tool := range info.RealtimeTools {
toolTokens, err := CountTokenInput(tool, model)
if err != nil {
return 0, 0, err
}
textToken += 8
textToken += toolTokens
}
}
}
}
return textToken, audioToken, nil
}
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
//recover when panic
tokenEncoder := getTokenEncoder(model)
@@ -248,13 +315,13 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
func CountTokenInput(input any, model string) (int, error) {
switch v := input.(type) {
case string:
return CountTokenText(v, model)
return CountTextToken(v, model)
case []string:
text := ""
for _, s := range v {
text += s
}
return CountTokenText(text, model)
return CountTextToken(text, model)
}
return CountTokenInput(fmt.Sprintf("%v", input), model)
}
@@ -276,16 +343,44 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
return tokens
}
func CountAudioToken(text string, model string) (int, error) {
func CountTTSToken(text string, model string) (int, error) {
if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text), nil
} else {
return CountTokenText(text, model)
return CountTextToken(text, model)
}
}
// CountTokenText 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量
func CountTokenText(text string, model string) (int, error) {
func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
if audioBase64 == "" {
return 0, nil
}
duration, err := parseAudio(audioBase64, audioFormat)
if err != nil {
return 0, err
}
return int(duration / 60 * 100 / 0.06), nil
}
func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
if audioBase64 == "" {
return 0, nil
}
duration, err := parseAudio(audioBase64, audioFormat)
if err != nil {
return 0, err
}
return int(duration / 60 * 200 / 0.24), nil
}
//func CountAudioToken(sec float64, audioType string) {
// if audioType == "input" {
//
// }
//}
// CountTextToken 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量
func CountTextToken(text string, model string) (int, error) {
var err error
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text), err

View File

@@ -19,7 +19,7 @@ import (
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
usage := &dto.Usage{}
usage.PromptTokens = promptTokens
ctkm, err := CountTokenText(responseText, modeName)
ctkm, err := CountTextToken(responseText, modeName)
usage.CompletionTokens = ctkm
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage, err

4789
web/pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -10,7 +10,6 @@ import Setting from './pages/Setting';
import EditUser from './pages/User/EditUser';
import { getLogo, getSystemName } from './helpers';
import PasswordResetForm from './components/PasswordResetForm';
import GitHubOAuth from './components/GitHubOAuth';
import PasswordResetConfirm from './components/PasswordResetConfirm';
import { UserContext } from './context/User';
import Channel from './pages/Channel';
@@ -26,6 +25,7 @@ import Midjourney from './pages/Midjourney';
import Pricing from './pages/Pricing/index.js';
import Task from "./pages/Task/index.js";
import Playground from './components/Playground.js';
import OAuth2Callback from "./components/OAuth2Callback.js";
const Home = lazy(() => import('./pages/Home'));
const Detail = lazy(() => import('./pages/Detail'));
@@ -177,7 +177,15 @@ function App() {
path='/oauth/github'
element={
<Suspense fallback={<Loading></Loading>}>
<GitHubOAuth />
<OAuth2Callback type='github'></OAuth2Callback>
</Suspense>
}
/>
<Route
path='/oauth/linuxdo'
element={
<Suspense fallback={<Loading></Loading>}>
<OAuth2Callback type='linuxdo'></OAuth2Callback>
</Suspense>
}
/>

File diff suppressed because it is too large Load Diff

View File

@@ -1,61 +0,0 @@
import React, { useContext, useEffect, useState } from 'react';
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
import { useNavigate, useSearchParams } from 'react-router-dom';
import { API, showError, showSuccess, updateAPI } from '../helpers';
import { UserContext } from '../context/User';
import { setUserData } from '../helpers/data.js';
const GitHubOAuth = () => {
const [searchParams, setSearchParams] = useSearchParams();
const [userState, userDispatch] = useContext(UserContext);
const [prompt, setPrompt] = useState('处理中...');
const [processing, setProcessing] = useState(true);
let navigate = useNavigate();
const sendCode = async (code, state, count) => {
const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
const { success, message, data } = res.data;
if (success) {
if (message === 'bind') {
showSuccess('绑定成功!');
navigate('/setting');
} else {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
setUserData(data);
updateAPI()
showSuccess('登录成功!');
navigate('/token');
}
} else {
showError(message);
if (count === 0) {
setPrompt(`操作失败,重定向至登录界面中...`);
navigate('/setting'); // in case this is failed to bind GitHub
return;
}
count++;
setPrompt(`出现错误,第 ${count} 次重试中...`);
await new Promise((resolve) => setTimeout(resolve, count * 2000));
await sendCode(code, state, count);
}
};
useEffect(() => {
let code = searchParams.get('code');
let state = searchParams.get('state');
sendCode(code, state, 0).then();
}, []);
return (
<Segment style={{ minHeight: '300px' }}>
<Dimmer active inverted>
<Loader size='large'>{prompt}</Loader>
</Dimmer>
</Segment>
);
};
export default GitHubOAuth;

View File

@@ -0,0 +1,27 @@
import React from 'react';
import { Icon } from '@douyinfe/semi-ui';
const LinuxDoIcon = (props) => {
function CustomIcon() {
return (
<svg
className='icon'
viewBox='0 0 24 24'
version='1.1'
xmlns='http://www.w3.org/2000/svg'
width='1em'
height='1em'
{...props}
>
<path
d='M19.7,17.6c-0.1-0.2-0.2-0.4-0.2-0.6c0-0.4-0.2-0.7-0.5-1c-0.1-0.1-0.3-0.2-0.4-0.2c0.6-1.8-0.3-3.6-1.3-4.9c0,0,0,0,0,0c-0.8-1.2-2-2.1-1.9-3.7c0-1.9,0.2-5.4-3.3-5.1C8.5,2.3,9.5,6,9.4,7.3c0,1.1-0.5,2.2-1.3,3.1c-0.2,0.2-0.4,0.5-0.5,0.7c-1,1.2-1.5,2.8-1.5,4.3c-0.2,0.2-0.4,0.4-0.5,0.6c-0.1,0.1-0.2,0.2-0.2,0.3c-0.1,0.1-0.3,0.2-0.5,0.3c-0.4,0.1-0.7,0.3-0.9,0.7c-0.1,0.3-0.2,0.7-0.1,1.1c0.1,0.2,0.1,0.4,0,0.7c-0.2,0.4-0.2,0.9,0,1.4c0.3,0.4,0.8,0.5,1.5,0.6c0.5,0,1.1,0.2,1.6,0.4l0,0c0.5,0.3,1.1,0.5,1.7,0.5c0.3,0,0.7-0.1,1-0.2c0.3-0.2,0.5-0.4,0.6-0.7c0.4,0,1-0.2,1.7-0.2c0.6,0,1.2,0.2,2,0.1c0,0.1,0,0.2,0.1,0.3c0.2,0.5,0.7,0.9,1.3,1c0.1,0,0.1,0,0.2,0c0.8-0.1,1.6-0.5,2.1-1.1l0,0c0.4-0.4,0.9-0.7,1.4-0.9c0.6-0.3,1-0.5,1.1-1C20.3,18.6,20.1,18.2,19.7,17.6z M12.8,4.8c0.6,0.1,1.1,0.6,1,1.2c0,0.3-0.1,0.6-0.3,0.9c0,0,0,0-0.1,0c-0.2-0.1-0.3-0.1-0.4-0.2c0.1-0.1,0.1-0.3,0.2-0.5c0-0.4-0.2-0.7-0.4-0.7c-0.3,0-0.5,0.3-0.5,0.7c0,0,0,0.1,0,0.1c-0.1-0.1-0.3-0.1-0.4-0.2c0,0,0-0.1,0-0.1C11.8,5.5,12.2,4.9,12.8,4.8z M12.5,6.8c0.1,0.1,0.3,0.2,0.4,0.2c0.1,0,0.3,0.1,0.4,0.2c0.2,0.1,0.4,0.2,0.4,0.5c0,0.3-0.3,0.6-0.9,0.8c-0.2,0.1-0.3,0.1-0.4,0.2c-0.3,0.2-0.6,0.3-1,0.3c-0.3,0-0.6-0.2-0.8-0.4c-0.1-0.1-0.2-0.2-0.4-0.3C10.1,8.2,9.9,8,9.8,7.7c0-0.1,0.1-0.2,0.2-0.3c0.3-0.2,0.4-0.3,0.5-0.4l0.1-0.1c0.2-0.3,0.6-0.5,1-0.5C11.9,6.5,12.2,6.6,12.5,6.8z M10.4,5c0.4,0,0.7,0.4,0.8,1.1c0,0.1,0,0.1,0,0.2c-0.1,0-0.3,0.1-0.4,0.2c0,0,0-0.1,0-0.2c0-0.3-0.2-0.6-0.4-0.5c-0.2,0-0.3,0.3-0.3,0.6c0,0.2,0.1,0.3,0.2,0.4l0,0c0,0-0.1,0.1-0.2,0.1C9.9,6.7,9.7,6.4,9.7,6.1C9.7,5.5,10,5,10.4,5z M9.4,21.1c-0.7,0.3-1.6,0.2-2.2-0.2c-0.6-0.3-1.1-0.4-1.8-0.4c-0.5-0.1-1-0.1-1.1-0.3c-0.1-0.2-0.1-0.5,0.1-1c0.1-0.3,0.1-0.6,0-0.9c-0.1-0.3-0.1-0.5,0-0.8C4.5,17.2,4.7,17.1,5,17c0.3-0.1,0.5-0.2,0.7-0.4c0.1-0.1,0.2-0.2,0.3-0.4c0.3-0.4,0.5-0.6,0.8-0.6c0.6,0.1,1.1,1,1.5,1.9c0.2,0.3,0.4,0.7,0.7,1c0.4,0.5,0.9,1.2,0.9,1.6C9.9,20.6,9.7,20.9,9.4,21.1z M14.3,18.9c0,0.1,0,0.1-0.1,0.2c-1.2,0.9-2.8,1-4.1,0.3c-0.2-0.3-0.4-0.6-0.6-0.9c0.9-0.1,0.7-1.3-1.2-2.5c-2-1.3-0.6-3.7,0.1-4.8c0.1-0.1,0.1,0-0.3,0.8c-0.3,0.6-0.9,2.1-0.1,3.2c0-0.8,0.2-1.6,0.5-2.4c0.7-1.3,1.2-2.8,1.5-4.3c0.1,0.1,0.1,0.1,0.2,0.1c0.1,0.1,0.2,0.2,0.3,0.2c0.2,0.3,0.6,0.4,0.9,0.4c0,0,0.1,0,0.1,0c0.4,0,0.8-0.1,1.1-0.4c0.1-0.1,0.2-0.2,0.4-0.2c0.3-0.1,0.6-0.3,0.9-0.6c0.4,1.3,0.8,2.5,1.4,3.6c0.4,0.8,0.7,1.6,0.9,2.5c0.3,0,0.7,0.1,1,0.3c0.8,0.4,1.1,0.7,1,1.2c-0.1,0-0.1,0-0.2,0c0-0.3-0.2-0.6-0.9-0.9c-0.7-0.3-1.3-0.3-1.5,0.4c-0.1,0-0.2,0.1-0.3,0.1c-0.8,0.4-0.8,1.5-0.9,2.6C14.5,18.2,14.4,18.5,14.3,18.9z M18.9,19.5c-0.6,0.2-1.1,0.6-1.5,1.1c-0.4,0.6-1.1,1-1.9,0.9c-0.4,0-0.8-0.3-0.9-0.7c-0.1-0.6-0.1-1.2,0.2-1.8c0.1-0.4,0.2-0.7,0.3-1.1c0.1-1.2,0.1-1.9,0.6-2.2h0c0,0.5,0.3,0.8,0.7,1c0.5,0,1-0.1,1.4-0.5c0.1,0,0.1,0,0.2,0c0.3,0,0.5,0,0.7,0.2c0.2,0.2,0.3,0.5,0.3,0.7c0,0.3,0.2,0.6,0.3,0.9c0.5,0.5,0.5,0.8,0.5,0.9C19.7,19.1,19.3,19.3,18.9,19.5z M9.9,7.5c-0.1,0-0.1,0-0.1,0.1c0,0,0,0.1,0.1,0.1c0,0,0,0,0,0c0.1,0,0.1,0.1,0.1,0.1c0.3,0.4,0.8,0.6,1.4,0.7c0.5-0.1,1-0.2,1.5-0.6c0.2-0.1,0.4-0.2,0.6-0.3c0.1,0,0.1-0.1,0.1-0.1c0-0.1,0-0.1-0.1-0.1l0,0c-0.2,0.1-0.5,0.2-0.7,0.3c-0.4,0.3-0.9,0.5-1.4,0.5c-0.5,0-0.9-0.3-1.2-0.6C10.1,7.6,10,7.5,9.9,7.5z'
fill='currentColor'
/>
</svg>
);
}
return <Icon svg={<CustomIcon />} />;
};
export default LinuxDoIcon;

View File

@@ -1,8 +1,15 @@
import React, { useContext, useEffect, useState } from 'react';
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../context/User';
import { API, getLogo, showError, showInfo, showSuccess, updateAPI } from '../helpers';
import { onGitHubOAuthClicked } from './utils';
import {
API,
getLogo,
showError,
showInfo,
showSuccess,
updateAPI,
} from '../helpers';
import { onGitHubOAuthClicked, onLinuxDOOAuthClicked } from './utils';
import Turnstile from 'react-turnstile';
import {
Button,
@@ -17,9 +24,10 @@ import Title from '@douyinfe/semi-ui/lib/es/typography/title';
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
import TelegramLoginButton from 'react-telegram-login';
import { IconGithubLogo } from '@douyinfe/semi-icons';
import { IconGithubLogo, IconAlarm } from '@douyinfe/semi-icons';
import WeChatIcon from './WeChatIcon';
import { setUserData } from '../helpers/data.js';
import LinuxDoIcon from './LinuxDoIcon.js';
const LoginForm = () => {
const [inputs, setInputs] = useState({
@@ -36,8 +44,15 @@ const LoginForm = () => {
const [turnstileToken, setTurnstileToken] = useState('');
let navigate = useNavigate();
const [status, setStatus] = useState({});
const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
const logo = getLogo();
let affCode = new URLSearchParams(window.location.search).get('aff');
if (affCode) {
localStorage.setItem('aff', affCode);
}
useEffect(() => {
if (searchParams.get('expired')) {
showError('未登录或登录已过期,请重新登录!');
@@ -53,7 +68,6 @@ const LoginForm = () => {
}
}, []);
const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
const onWeChatLoginClicked = () => {
setShowWeChatLoginModal(true);
@@ -72,7 +86,7 @@ const LoginForm = () => {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
setUserData(data);
updateAPI()
updateAPI();
navigate('/');
showSuccess('登录成功!');
setShowWeChatLoginModal(false);
@@ -103,7 +117,7 @@ const LoginForm = () => {
if (success) {
userDispatch({ type: 'login', payload: data });
setUserData(data);
updateAPI()
updateAPI();
showSuccess('登录成功!');
if (username === 'root' && password === '123456') {
Modal.error({
@@ -146,7 +160,7 @@ const LoginForm = () => {
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
setUserData(data);
updateAPI()
updateAPI();
navigate('/');
} else {
showError(message);
@@ -214,7 +228,8 @@ const LoginForm = () => {
</div>
{status.github_oauth ||
status.wechat_login ||
status.telegram_oauth ? (
status.telegram_oauth ||
status.linuxdo_oauth ? (
<>
<Divider margin='12px' align='center'>
第三方登录
@@ -237,6 +252,16 @@ const LoginForm = () => {
) : (
<></>
)}
{status.linuxdo_oauth ? (
<Button
icon={<LinuxDoIcon />}
onClick={() =>
onLinuxDOOAuthClicked(status.linuxdo_client_id)
}
/>
) : (
<></>
)}
{status.wechat_login ? (
<Button
type='primary'

View File

@@ -11,7 +11,7 @@ import {
import {
Avatar,
Button,
Button, Descriptions,
Form,
Layout,
Modal,
@@ -20,14 +20,15 @@ import {
Spin,
Table,
Tag,
Tooltip,
Tooltip
} from '@douyinfe/semi-ui';
import { ITEMS_PER_PAGE } from '../constants';
import {
renderAudioModelPrice,
renderModelPrice,
renderNumber,
renderQuota,
stringToColor,
stringToColor
} from '../helpers/render';
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
import { getLogOther } from '../helpers/other.js';
@@ -384,31 +385,31 @@ const LogsTable = () => {
</Paragraph>
);
}
let content = renderModelPrice(
record.prompt_tokens,
record.completion_tokens,
other.model_ratio,
other.model_price,
other.completion_ratio,
other.group_ratio,
);
// let content = renderModelPrice(
// record.prompt_tokens,
// record.completion_tokens,
// other.model_ratio,
// other.model_price,
// other.completion_ratio,
// other.group_ratio,
// );
return (
<Tooltip content={content}>
<Paragraph
ellipsis={{
rows: 2,
}}
style={{ maxWidth: 240 }}
ellipsis={{
rows: 2,
}}
style={{ maxWidth: 240 }}
>
{text}
调用消费
</Paragraph>
</Tooltip>
);
},
},
];
const [logs, setLogs] = useState([]);
const [expandData, setExpandData] = useState({});
const [showStat, setShowStat] = useState(false);
const [loading, setLoading] = useState(false);
const [loadingStat, setLoadingStat] = useState(false);
@@ -512,10 +513,89 @@ const LogsTable = () => {
};
const setLogsFormat = (logs) => {
let expandDatesLocal = {};
for (let i = 0; i < logs.length; i++) {
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
logs[i].key = '' + logs[i].id;
logs[i].key = i;
let other = getLogOther(logs[i].other);
let expandDataLocal = [];
if (isAdmin()) {
// let content = '渠道:' + logs[i].channel;
// if (other.admin_info !== undefined) {
// if (
// other.admin_info.use_channel !== null &&
// other.admin_info.use_channel !== undefined &&
// other.admin_info.use_channel !== ''
// ) {
// // channel id array
// let useChannel = other.admin_info.use_channel;
// let useChannelStr = useChannel.join('->');
// content = `渠道:${useChannelStr}`;
// }
// }
// expandDataLocal.push({
// key: '渠道重试',
// value: content,
// })
}
if (other?.ws || other?.audio) {
expandDataLocal.push({
key: '语音输入',
value: other.audio_input,
});
expandDataLocal.push({
key: '语音输出',
value: other.audio_output,
});
expandDataLocal.push({
key: '文字输入',
value: other.text_input,
});
expandDataLocal.push({
key: '文字输出',
value: other.text_output,
});
}
expandDataLocal.push({
key: '日志详情',
value: logs[i].content,
})
if (logs[i].type === 2) {
let content = '';
if (other?.ws || other?.audio) {
content = renderAudioModelPrice(
other.text_input,
other.text_output,
other.model_ratio,
other.model_price,
other.completion_ratio,
other.audio_input,
other.audio_output,
other?.audio_ratio,
other?.audio_completion_ratio,
other.group_ratio,
);
} else {
content = renderModelPrice(
logs[i].prompt_tokens,
logs[i].completion_tokens,
other.model_ratio,
other.model_price,
other.completion_ratio,
other.group_ratio,
);
}
expandDataLocal.push({
key: '计费过程',
value: content,
});
}
expandDatesLocal[logs[i].key] = expandDataLocal;
}
setExpandData(expandDatesLocal);
setLogs(logs);
};
@@ -588,6 +668,10 @@ const LogsTable = () => {
handleEyeClick();
}, []);
const expandRowRender = (record, index) => {
return <Descriptions data={expandData[record.key]} />;
};
return (
<>
<Layout>
@@ -683,10 +767,29 @@ const LogsTable = () => {
<Form.Section></Form.Section>
</>
</Form>
<div style={{marginTop:10}}>
<Select
defaultValue='0'
style={{ width: 120 }}
onChange={(value) => {
setLogType(parseInt(value));
loadLogs(0, pageSize, parseInt(value));
}}
>
<Select.Option value='0'>全部</Select.Option>
<Select.Option value='1'>充值</Select.Option>
<Select.Option value='2'>消费</Select.Option>
<Select.Option value='3'>管理</Select.Option>
<Select.Option value='4'>系统</Select.Option>
</Select>
</div>
<Table
style={{ marginTop: 5 }}
columns={columns}
expandedRowRender={expandRowRender}
expandRowByClick={true}
dataSource={logs}
rowKey="key"
pagination={{
currentPage: activePage,
pageSize: pageSize,
@@ -699,20 +802,6 @@ const LogsTable = () => {
onPageChange: handlePageChange,
}}
/>
<Select
defaultValue='0'
style={{ width: 120 }}
onChange={(value) => {
setLogType(parseInt(value));
loadLogs(0, pageSize, parseInt(value));
}}
>
<Select.Option value='0'>全部</Select.Option>
<Select.Option value='1'>充值</Select.Option>
<Select.Option value='2'>消费</Select.Option>
<Select.Option value='3'>管理</Select.Option>
<Select.Option value='4'>系统</Select.Option>
</Select>
</Layout>
</>
);

View File

@@ -0,0 +1,61 @@
import React, { useContext, useEffect, useState } from 'react';
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
import { useNavigate, useSearchParams } from 'react-router-dom';
import { API, showError, showSuccess, updateAPI } from '../helpers';
import { UserContext } from '../context/User';
import { setUserData } from '../helpers/data.js';
const OAuth2Callback = (props) => {
const [searchParams, setSearchParams] = useSearchParams();
const [userState, userDispatch] = useContext(UserContext);
const [prompt, setPrompt] = useState('处理中...');
const [processing, setProcessing] = useState(true);
let navigate = useNavigate();
const sendCode = async (code, state, count) => {
const res = await API.get(`/api/oauth/${props.type}?code=${code}&state=${state}`);
const { success, message, data } = res.data;
if (success) {
if (message === 'bind') {
showSuccess('绑定成功!');
navigate('/setting');
} else {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
setUserData(data);
updateAPI()
showSuccess('登录成功!');
navigate('/token');
}
} else {
showError(message);
if (count === 0) {
setPrompt(`操作失败,重定向至登录界面中...`);
navigate('/setting'); // in case this is failed to bind GitHub
return;
}
count++;
setPrompt(`出现错误,第 ${count} 次重试中...`);
await new Promise((resolve) => setTimeout(resolve, count * 2000));
await sendCode(code, state, count);
}
};
useEffect(() => {
let code = searchParams.get('code');
let state = searchParams.get('state');
sendCode(code, state, 0).then();
}, []);
return (
<Segment style={{ minHeight: '300px' }}>
<Dimmer active inverted>
<Loader size='large'>{prompt}</Loader>
</Dimmer>
</Segment>
);
};
export default OAuth2Callback;

View File

@@ -90,7 +90,7 @@ const OperationSetting = () => {
try {
setLoading(true);
await getOptions();
showSuccess('刷新成功');
// showSuccess('刷新成功');
} catch (error) {
showError('刷新失败');
} finally {

File diff suppressed because it is too large Load Diff

View File

@@ -84,6 +84,11 @@ const Playground = () => {
// handleInputChange('group', localGroupOptions[0].value);
if (localGroupOptions.length > 0) {
// set default group at first
localGroupOptions.unshift({
label: '用户分组',
value: '',
});
} else {
localGroupOptions = [{
label: '用户分组',

View File

@@ -10,7 +10,7 @@ import {
import { ITEMS_PER_PAGE } from '../constants';
import { renderQuota } from '../helpers/render';
import {
Button,
Button, Divider,
Form,
Modal,
Popconfirm,
@@ -391,6 +391,39 @@ const RedemptionsTable = () => {
onChange={handleKeywordChange}
/>
</Form>
<Divider style={{margin:'5px 0 15px 0'}}/>
<div>
<Button
theme='light'
type='primary'
style={{ marginRight: 8 }}
onClick={() => {
setEditingRedemption({
id: undefined,
});
setShowEdit(true);
}}
>
添加兑换码
</Button>
<Button
label='复制所选兑换码'
type='warning'
onClick={async () => {
if (selectedKeys.length === 0) {
showError('请至少选择一个兑换码!');
return;
}
let keys = '';
for (let i = 0; i < selectedKeys.length; i++) {
keys += selectedKeys[i].name + ' ' + selectedKeys[i].key + '\n';
}
await copyText(keys);
}}
>
复制所选兑换码到剪贴板
</Button>
</div>
<Table
style={{ marginTop: 20 }}
@@ -414,36 +447,6 @@ const RedemptionsTable = () => {
rowSelection={rowSelection}
onRow={handleRow}
></Table>
<Button
theme='light'
type='primary'
style={{ marginRight: 8 }}
onClick={() => {
setEditingRedemption({
id: undefined,
});
setShowEdit(true);
}}
>
添加兑换码
</Button>
<Button
label='复制所选兑换码'
type='warning'
onClick={async () => {
if (selectedKeys.length === 0) {
showError('请至少选择一个兑换码!');
return;
}
let keys = '';
for (let i = 0; i < selectedKeys.length; i++) {
keys += selectedKeys[i].name + ' ' + selectedKeys[i].key + '\n';
}
await copyText(keys);
}}
>
复制所选兑换码到剪贴板
</Button>
</>
);
};

View File

@@ -1,10 +1,16 @@
import React, { useEffect, useState } from 'react';
import { Link, useNavigate } from 'react-router-dom';
import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
import { API, getLogo, showError, showInfo, showSuccess, updateAPI } from '../helpers';
import Turnstile from 'react-turnstile';
import { Button, Card, Form, Layout } from '@douyinfe/semi-ui';
import { Button, Card, Divider, Form, Icon, Layout, Modal } from '@douyinfe/semi-ui';
import Title from '@douyinfe/semi-ui/lib/es/typography/title';
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
import { IconGithubLogo } from '@douyinfe/semi-icons';
import { onGitHubOAuthClicked, onLinuxDOOAuthClicked } from './utils.js';
import LinuxDoIcon from './LinuxDoIcon.js';
import WeChatIcon from './WeChatIcon.js';
import TelegramLoginButton from 'react-telegram-login/src';
import { setUserData } from '../helpers/data.js';
const RegisterForm = () => {
const [inputs, setInputs] = useState({
@@ -20,7 +26,11 @@ const RegisterForm = () => {
const [turnstileSiteKey, setTurnstileSiteKey] = useState('');
const [turnstileToken, setTurnstileToken] = useState('');
const [loading, setLoading] = useState(false);
const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
const [status, setStatus] = useState({});
let navigate = useNavigate();
const logo = getLogo();
let affCode = new URLSearchParams(window.location.search).get('aff');
if (affCode) {
localStorage.setItem('aff', affCode);
@@ -30,6 +40,7 @@ const RegisterForm = () => {
let status = localStorage.getItem('status');
if (status) {
status = JSON.parse(status);
setStatus(status);
setShowEmailVerification(status.email_verification);
if (status.turnstile_check) {
setTurnstileEnabled(true);
@@ -38,7 +49,32 @@ const RegisterForm = () => {
}
});
let navigate = useNavigate();
const onWeChatLoginClicked = () => {
setShowWeChatLoginModal(true);
};
const onSubmitWeChatVerificationCode = async () => {
if (turnstileEnabled && turnstileToken === '') {
showInfo('请稍后几秒重试Turnstile 正在检查用户环境!');
return;
}
const res = await API.get(
`/api/oauth/wechat?code=${inputs.wechat_verification_code}`,
);
const { success, message, data } = res.data;
if (success) {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
setUserData(data);
updateAPI();
navigate('/');
showSuccess('登录成功!');
setShowWeChatLoginModal(false);
} else {
showError(message);
}
};
function handleChange(name, value) {
setInputs((inputs) => ({ ...inputs, [name]: value }));
@@ -189,14 +225,127 @@ const RegisterForm = () => {
</Link>
</Text>
</div>
{status.github_oauth ||
status.wechat_login ||
status.telegram_oauth ||
status.linuxdo_oauth ? (
<>
<Divider margin='12px' align='center'>
第三方登录
</Divider>
<div
style={{
display: 'flex',
justifyContent: 'center',
marginTop: 20,
}}
>
{status.github_oauth ? (
<Button
type='primary'
icon={<IconGithubLogo />}
onClick={() =>
onGitHubOAuthClicked(status.github_client_id)
}
/>
) : (
<></>
)}
{status.linuxdo_oauth ? (
<Button
icon={<LinuxDoIcon />}
onClick={() =>
onLinuxDOOAuthClicked(status.linuxdo_client_id)
}
/>
) : (
<></>
)}
{status.wechat_login ? (
<Button
type='primary'
style={{ color: 'rgba(var(--semi-green-5), 1)' }}
icon={<Icon svg={<WeChatIcon />} />}
onClick={onWeChatLoginClicked}
/>
) : (
<></>
)}
</div>
{status.telegram_oauth ? (
<>
<div
style={{
display: 'flex',
justifyContent: 'center',
marginTop: 5,
}}
>
<TelegramLoginButton
dataOnauth={onTelegramLoginClicked}
botName={status.telegram_bot_name}
/>
</div>
</>
) : (
<></>
)}
</>
) : (
<></>
)}
</Card>
{turnstileEnabled ? (
<Turnstile
sitekey={turnstileSiteKey}
onVerify={(token) => {
setTurnstileToken(token);
<Modal
title='微信扫码登录'
visible={showWeChatLoginModal}
maskClosable={true}
onOk={onSubmitWeChatVerificationCode}
onCancel={() => setShowWeChatLoginModal(false)}
okText={'登录'}
size={'small'}
centered={true}
>
<div
style={{
display: 'flex',
alignItem: 'center',
flexDirection: 'column',
}}
/>
>
<img src={status.wechat_qrcode} />
</div>
<div style={{ textAlign: 'center' }}>
<p>
微信扫码关注公众号输入验证码获取验证码三分钟内有效
</p>
</div>
<Form size='large'>
<Form.Input
field={'wechat_verification_code'}
placeholder='验证码'
label={'验证码'}
value={inputs.wechat_verification_code}
onChange={(value) =>
handleChange('wechat_verification_code', value)
}
/>
</Form>
</Modal>
{turnstileEnabled ? (
<div
style={{
display: 'flex',
justifyContent: 'center',
marginTop: 20,
}}
>
<Turnstile
sitekey={turnstileSiteKey}
onVerify={(token) => {
setTurnstileToken(token);
}}
/>
</div>
) : (
<></>
)}

View File

@@ -53,6 +53,9 @@ const SystemSetting = () => {
TelegramOAuthEnabled: '',
TelegramBotToken: '',
TelegramBotName: '',
LinuxDOOAuthEnabled: '',
LinuxDOClientId: '',
LinuxDOClientSecret: '',
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
@@ -103,6 +106,7 @@ const SystemSetting = () => {
case 'PasswordRegisterEnabled':
case 'EmailVerificationEnabled':
case 'GitHubOAuthEnabled':
case 'LinuxDOOAuthEnabled':
case 'WeChatAuthEnabled':
case 'TelegramOAuthEnabled':
case 'TurnstileCheckEnabled':
@@ -163,7 +167,9 @@ const SystemSetting = () => {
name === 'EmailDomainWhitelist' ||
name === 'TopupGroupRatio' ||
name === 'TelegramBotToken' ||
name === 'TelegramBotName'
name === 'TelegramBotName' ||
name === 'LinuxDOClientId' ||
name === 'LinuxDOClientSecret'
) {
setInputs((inputs) => ({ ...inputs, [name]: value }));
} else {
@@ -182,7 +188,7 @@ const SystemSetting = () => {
if (inputs.WorkerValidKey !== '') {
await updateOption('WorkerValidKey', inputs.WorkerValidKey);
}
}
};
const submitPayAddress = async () => {
if (inputs.ServerAddress === '') {
@@ -320,6 +326,18 @@ const SystemSetting = () => {
}
};
const submitLinuxDOOAuth = async () => {
if (originInputs['LinuxDOClientId'] !== inputs.LinuxDOClientId) {
await updateOption('LinuxDOClientId', inputs.LinuxDOClientId);
}
if (
originInputs['LinuxDOClientSecret'] !== inputs.LinuxDOClientSecret &&
inputs.LinuxDOClientSecret !== ''
) {
await updateOption('LinuxDOClientSecret', inputs.LinuxDOClientSecret);
}
};
return (
<Grid columns={1}>
<Grid.Column>
@@ -340,7 +358,15 @@ const SystemSetting = () => {
更新服务器地址
</Form.Button>
<Header as='h3' inverted={isDark}>
代理设置支持 <a href='https://github.com/Calcium-Ion/new-api-worker' target='_blank' rel='noreferrer'>new-api-worker</a>
代理设置支持{' '}
<a
href='https://github.com/Calcium-Ion/new-api-worker'
target='_blank'
rel='noreferrer'
>
new-api-worker
</a>
</Header>
<Form.Group widths='equal'>
<Form.Input
@@ -358,9 +384,7 @@ const SystemSetting = () => {
onChange={handleInputChange}
/>
</Form.Group>
<Form.Button onClick={submitWorker}>
更新Worker设置
</Form.Button>
<Form.Button onClick={submitWorker}>更新Worker设置</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
支付设置当前仅支持易支付接口默认使用上方服务器地址作为回调地址
@@ -483,6 +507,12 @@ const SystemSetting = () => {
name='GitHubOAuthEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.LinuxDOOAuthEnabled === 'true'}
label='允许通过 LinuxDO 账户登录 & 注册'
name='LinuxDOOAuthEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.WeChatAuthEnabled === 'true'}
label='允许通过微信登录 & 注册'
@@ -781,6 +811,48 @@ const SystemSetting = () => {
<Form.Button onClick={submitTurnstile}>
保存 Turnstile 设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
配置 LinuxDO OAuth App
<Header.Subheader>
用以支持通过 LinuxDO 进行登录注册
<a
href='https://connect.linux.do/'
target='_blank'
rel='noreferrer'
>
点击此处
</a>
管理你的 LinuxDO OAuth App
</Header.Subheader>
</Header>
<Message>
Homepage URL <code>{inputs.ServerAddress}</code>
Authorization callback URL {' '}
<code>{`${inputs.ServerAddress}/oauth/linuxdo`}</code>
</Message>
<Form.Group widths={3}>
<Form.Input
label='LinuxDO Client ID'
name='LinuxDOClientId'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.LinuxDOClientId}
placeholder='输入你注册的 LinuxDO OAuth APP 的 ID'
/>
<Form.Input
label='LinuxDO Client Secret'
name='LinuxDOClientSecret'
onChange={handleInputChange}
type='password'
autoComplete='new-password'
value={inputs.LinuxDOClientSecret}
placeholder='敏感信息不会发送到前端显示'
/>
</Form.Group>
<Form.Button onClick={submitLinuxDOOAuth}>
保存 LinuxDO OAuth 设置
</Form.Button>
</Form>
</Grid.Column>
</Grid>

View File

@@ -10,7 +10,7 @@ import {
import { ITEMS_PER_PAGE } from '../constants';
import {renderGroup, renderQuota} from '../helpers/render';
import {
Button,
Button, Divider,
Dropdown,
Form,
Modal,
@@ -596,6 +596,40 @@ const TokensTable = () => {
查询
</Button>
</Form>
<Divider style={{margin:'15px 0'}}/>
<div>
<Button
theme='light'
type='primary'
style={{ marginRight: 8 }}
onClick={() => {
setEditingToken({
id: undefined,
});
setShowEdit(true);
}}
>
添加令牌
</Button>
<Button
label='复制所选令牌'
type='warning'
onClick={async () => {
if (selectedKeys.length === 0) {
showError('请至少选择一个令牌!');
return;
}
let keys = '';
for (let i = 0; i < selectedKeys.length; i++) {
keys +=
selectedKeys[i].name + ' sk-' + selectedKeys[i].key + '\n';
}
await copyText(keys);
}}
>
复制所选令牌到剪贴板
</Button>
</div>
<Table
style={{ marginTop: 20 }}
@@ -619,37 +653,6 @@ const TokensTable = () => {
rowSelection={rowSelection}
onRow={handleRow}
></Table>
<Button
theme='light'
type='primary'
style={{ marginRight: 8 }}
onClick={() => {
setEditingToken({
id: undefined,
});
setShowEdit(true);
}}
>
添加令牌
</Button>
<Button
label='复制所选令牌'
type='warning'
onClick={async () => {
if (selectedKeys.length === 0) {
showError('请至少选择一个令牌!');
return;
}
let keys = '';
for (let i = 0; i < selectedKeys.length; i++) {
keys +=
selectedKeys[i].name + ' sk-' + selectedKeys[i].key + '\n';
}
await copyText(keys);
}}
>
复制所选令牌到剪贴板
</Button>
</>
);
};

View File

@@ -476,10 +476,18 @@ const UsersTable = () => {
type='primary'
htmlType='submit'
className='btn-margin-right'
style={{ marginRight: 8 }}
>
查询
</Button>
<Button
theme='light'
type='primary'
onClick={() => {
setShowAddUser(true);
}}
>
添加用户
</Button>
</Space>
</div>
</Form>
@@ -496,16 +504,6 @@ const UsersTable = () => {
}}
loading={loading}
/>
<Button
theme='light'
type='primary'
style={{ marginRight: 8 }}
onClick={() => {
setShowAddUser(true);
}}
>
添加用户
</Button>
</>
);
};

View File

@@ -0,0 +1,21 @@
import { Input, Typography } from '@douyinfe/semi-ui';
import React from 'react';
const TextInput = ({ label, name, value, onChange, placeholder, type = 'text' }) => {
return (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>{label}</Typography.Text>
</div>
<Input
name={name}
placeholder={placeholder}
onChange={(value) => onChange(value)}
value={value}
autoComplete="new-password"
/>
</>
);
}
export default TextInput;

View File

@@ -0,0 +1,21 @@
import { Input, InputNumber, Typography } from '@douyinfe/semi-ui';
import React from 'react';
const TextNumberInput = ({ label, name, value, onChange, placeholder }) => {
return (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>{label}</Typography.Text>
</div>
<InputNumber
name={name}
placeholder={placeholder}
onChange={(value) => onChange(value)}
value={value}
autoComplete="new-password"
/>
</>
);
}
export default TextNumberInput;

View File

@@ -1,7 +1,12 @@
import { API, showError } from '../helpers';
export async function getOAuthState() {
const res = await API.get('/api/oauth/state');
let path = '/api/oauth/state';
let affCode = localStorage.getItem('aff');
if (affCode && affCode.length > 0) {
path += `?aff=${affCode}`;
}
const res = await API.get(path);
const { success, message, data } = res.data;
if (success) {
return data;
@@ -19,6 +24,14 @@ export async function onGitHubOAuthClicked(github_client_id) {
);
}
export async function onLinuxDOOAuthClicked(linuxdo_client_id) {
const state = await getOAuthState();
if (!state) return;
window.open(
`https://connect.linux.do/oauth2/authorize?response_type=code&client_id=${linuxdo_client_id}&state=${state}`,
);
}
let channelModels = undefined;
export async function loadChannelModels() {
const res = await API.get('/api/models');

View File

@@ -67,6 +67,8 @@ export function renderQuotaNumberWithDigit(num, digits = 2) {
}
export function renderNumberWithPoint(num) {
if (num === undefined)
return '';
num = num.toFixed(2);
if (num >= 100000) {
// Convert number to string to manipulate it
@@ -173,6 +175,59 @@ export function renderModelPrice(
}
}
export function renderAudioModelPrice(
inputTokens,
completionTokens,
modelRatio,
modelPrice = -1,
completionRatio,
audioInputTokens,
audioCompletionTokens,
audioRatio,
audioCompletionRatio,
groupRatio,
) {
// 1 ratio = $0.002 / 1K tokens
if (modelPrice !== -1) {
return '模型价格:$' + modelPrice + ' * 分组倍率:' + groupRatio + ' = $' + modelPrice * groupRatio;
} else {
if (completionRatio === undefined) {
completionRatio = 0;
}
// 这里的 *2 是因为 1倍率=0.002刀,请勿删除
let inputRatioPrice = modelRatio * 2.0;
let completionRatioPrice = modelRatio * 2.0 * completionRatio;
let price =
(inputTokens / 1000000) * inputRatioPrice * groupRatio +
(completionTokens / 1000000) * completionRatioPrice * groupRatio +
(audioInputTokens / 1000000) * inputRatioPrice * audioRatio * groupRatio +
(audioCompletionTokens / 1000000) * inputRatioPrice * audioRatio * audioCompletionRatio * groupRatio;
return (
<>
<article>
<p>提示${inputRatioPrice} * {groupRatio} = ${inputRatioPrice * groupRatio} / 1M tokens</p>
<p>补全${completionRatioPrice} * {groupRatio} = ${completionRatioPrice * groupRatio} / 1M tokens</p>
<p>音频提示${inputRatioPrice} * {groupRatio} * {audioRatio} = ${inputRatioPrice * audioRatio * groupRatio} / 1M tokens</p>
<p>音频补全${inputRatioPrice} * {groupRatio} * {audioRatio} * {audioCompletionRatio} = ${inputRatioPrice * audioRatio * audioCompletionRatio * groupRatio} / 1M tokens</p>
<p></p>
<p>
提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '}
{completionTokens} tokens / 1M tokens * ${completionRatioPrice} +
</p>
<p>
音频提示 {audioInputTokens} tokens / 1M tokens * ${inputRatioPrice} * {audioRatio} + 音频补全 {audioCompletionTokens} tokens / 1M tokens * ${inputRatioPrice} * {audioRatio} * {audioCompletionRatio}
</p>
<p>
文字 + 音频 * 分组 {groupRatio} =
${price.toFixed(6)}
</p>
<p>仅供参考以实际扣费为准</p>
</article>
</>
);
}
}
export function renderQuotaWithPrompt(quota, digits) {
let displayInCurrency = localStorage.getItem('display_in_currency');
displayInCurrency = displayInCurrency === 'true';

View File

@@ -6,7 +6,7 @@ import {
showError,
showInfo,
showSuccess,
verifyJSON,
verifyJSON
} from '../../helpers';
import { CHANNEL_OPTIONS } from '../../constants';
import Title from '@douyinfe/semi-ui/lib/es/typography/title';
@@ -21,28 +21,26 @@ import {
Select,
TextArea,
Checkbox,
Banner,
Banner
} from '@douyinfe/semi-ui';
import { Divider } from 'semantic-ui-react';
import { getChannelModels, loadChannelModels } from '../../components/utils.js';
import axios from 'axios';
const MODEL_MAPPING_EXAMPLE = {
'gpt-3.5-turbo-0301': 'gpt-3.5-turbo',
'gpt-4-0314': 'gpt-4',
'gpt-4-32k-0314': 'gpt-4-32k',
'gpt-3.5-turbo': 'gpt-3.5-turbo-0125'
};
const STATUS_CODE_MAPPING_EXAMPLE = {
400: '500',
400: '500'
};
const REGION_EXAMPLE = {
"default": "us-central1",
"claude-3-5-sonnet-20240620": "europe-west1"
}
'default': 'us-central1',
'claude-3-5-sonnet-20240620': 'europe-west1'
};
const fetchButtonTips = "1. 新建渠道时请求通过当前浏览器发出2. 编辑已有渠道,请求通过后端服务器发出"
const fetchButtonTips = '1. 新建渠道时请求通过当前浏览器发出2. 编辑已有渠道,请求通过后端服务器发出';
function type2secretPrompt(type) {
// inputs.type === 15 ? '按照如下格式输入APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
@@ -84,6 +82,9 @@ const EditChannel = (props) => {
auto_ban: 1,
test_model: '',
groups: ['default'],
priority: 0,
weight: 0,
tag: ''
};
const [batch, setBatch] = useState(false);
const [autoBan, setAutoBan] = useState(true);
@@ -108,7 +109,7 @@ const EditChannel = (props) => {
'mj_blend',
'mj_upscale',
'mj_describe',
'mj_uploads',
'mj_uploads'
];
break;
case 5:
@@ -128,13 +129,13 @@ const EditChannel = (props) => {
'mj_high_variation',
'mj_low_variation',
'mj_pan',
'mj_uploads',
'mj_uploads'
];
break;
case 36:
localModels = [
'suno_music',
'suno_lyrics',
'suno_lyrics'
];
break;
default:
@@ -171,7 +172,7 @@ const EditChannel = (props) => {
data.model_mapping = JSON.stringify(
JSON.parse(data.model_mapping),
null,
2,
2
);
}
setInputs(data);
@@ -190,70 +191,69 @@ const EditChannel = (props) => {
const fetchUpstreamModelList = async (name) => {
if (inputs["type"] !== 1) {
showError("仅支持 OpenAI 接口格式")
if (inputs['type'] !== 1) {
showError('仅支持 OpenAI 接口格式');
return;
}
setLoading(true)
const models = inputs["models"] || []
setLoading(true);
const models = inputs['models'] || [];
let err = false;
if (isEdit) {
const res = await API.get("/api/channel/fetch_models/" + channelId)
const res = await API.get('/api/channel/fetch_models/' + channelId);
if (res.data && res.data?.success) {
models.push(...res.data.data)
models.push(...res.data.data);
} else {
err = true
err = true;
}
} else {
if (!inputs?.["key"]) {
showError("请填写密钥")
err = true
if (!inputs?.['key']) {
showError('请填写密钥');
err = true;
} else {
try {
const host = new URL((inputs["base_url"] || "https://api.openai.com"))
const host = new URL((inputs['base_url'] || 'https://api.openai.com'));
const url = `https://${host.hostname}/v1/models`;
const key = inputs["key"];
const key = inputs['key'];
const res = await axios.get(url, {
headers: {
'Authorization': `Bearer ${key}`
}
})
});
if (res.data && res.data?.success) {
models.push(...res.data.data.map((model) => model.id))
models.push(...res.data.data.map((model) => model.id));
} else {
err = true
err = true;
}
}
catch (error) {
err = true
} catch (error) {
err = true;
}
}
}
if (!err) {
handleInputChange(name, Array.from(new Set(models)));
showSuccess("获取模型列表成功");
showSuccess('获取模型列表成功');
} else {
showError('获取模型列表失败');
}
setLoading(false);
}
};
const fetchModels = async () => {
try {
let res = await API.get(`/api/channel/models`);
let localModelOptions = res.data.data.map((model) => ({
label: model.id,
value: model.id,
value: model.id
}));
setOriginModelOptions(localModelOptions);
setFullModels(res.data.data.map((model) => model.id));
setBasicModels(
res.data.data
.filter((model) => {
return model.id.startsWith('gpt-3') || model.id.startsWith('text-');
return model.id.startsWith('gpt-') || model.id.startsWith('text-');
})
.map((model) => model.id),
.map((model) => model.id)
);
} catch (error) {
showError(error.message);
@@ -269,8 +269,8 @@ const EditChannel = (props) => {
setGroupOptions(
res.data.data.map((group) => ({
label: group,
value: group,
})),
value: group
}))
);
} catch (error) {
showError(error.message);
@@ -280,10 +280,10 @@ const EditChannel = (props) => {
useEffect(() => {
let localModelOptions = [...originModelOptions];
inputs.models.forEach((model) => {
if (!localModelOptions.find((option) => option.key === model)) {
if (!localModelOptions.find((option) => option.label === model)) {
localModelOptions.push({
label: model,
value: model,
value: model
});
}
});
@@ -320,7 +320,7 @@ const EditChannel = (props) => {
if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
localInputs.base_url = localInputs.base_url.slice(
0,
localInputs.base_url.length - 1,
localInputs.base_url.length - 1
);
}
if (localInputs.type === 3 && localInputs.other === '') {
@@ -341,7 +341,7 @@ const EditChannel = (props) => {
if (isEdit) {
res = await API.put(`/api/channel/`, {
...localInputs,
id: parseInt(channelId),
id: parseInt(channelId)
});
} else {
res = await API.post(`/api/channel/`, localInputs);
@@ -378,7 +378,7 @@ const EditChannel = (props) => {
// 添加到下拉选项
key: model,
text: model,
value: model,
value: model
});
} else if (model) {
showError('某些模型已存在!');
@@ -409,11 +409,11 @@ const EditChannel = (props) => {
footer={
<div style={{ display: 'flex', justifyContent: 'flex-end' }}>
<Space>
<Button theme='solid' size={'large'} onClick={submit}>
<Button theme="solid" size={'large'} onClick={submit}>
提交
</Button>
<Button
theme='solid'
theme="solid"
size={'large'}
type={'tertiary'}
onClick={handleCancel}
@@ -432,7 +432,7 @@ const EditChannel = (props) => {
<Typography.Text strong>类型</Typography.Text>
</div>
<Select
name='type'
name="type"
required
optionList={CHANNEL_OPTIONS}
value={inputs.type}
@@ -450,8 +450,8 @@ const EditChannel = (props) => {
因为 One API 会把请求体中的 model
参数替换为你的部署名称模型名称中的点会被剔除
<a
target='_blank'
href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'
target="_blank"
href="https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271"
>
图片演示
</a>
@@ -466,8 +466,8 @@ const EditChannel = (props) => {
</Typography.Text>
</div>
<Input
label='AZURE_OPENAI_ENDPOINT'
name='azure_base_url'
label="AZURE_OPENAI_ENDPOINT"
name="azure_base_url"
placeholder={
'请输入 AZURE_OPENAI_ENDPOINT例如https://docs-test-001.openai.azure.com'
}
@@ -475,14 +475,14 @@ const EditChannel = (props) => {
handleInputChange('base_url', value);
}}
value={inputs.base_url}
autoComplete='new-password'
autoComplete="new-password"
/>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>默认 API 版本</Typography.Text>
</div>
<Input
label='默认 API 版本'
name='azure_other'
label="默认 API 版本"
name="azure_other"
placeholder={
'请输入默认 API 版本例如2023-06-01-preview该配置可以被实际的请求查询参数所覆盖'
}
@@ -490,7 +490,7 @@ const EditChannel = (props) => {
handleInputChange('other', value);
}}
value={inputs.other}
autoComplete='new-password'
autoComplete="new-password"
/>
</>
)}
@@ -512,7 +512,7 @@ const EditChannel = (props) => {
</Typography.Text>
</div>
<Input
name='base_url'
name="base_url"
placeholder={
'请输入完整的URL例如https://api.openai.com/v1/chat/completions'
}
@@ -520,49 +520,84 @@ const EditChannel = (props) => {
handleInputChange('base_url', value);
}}
value={inputs.base_url}
autoComplete='new-password'
autoComplete="new-password"
/>
</>
)}
{inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>代理</Typography.Text>
</div>
<Input
label="代理"
name="base_url"
placeholder={'此项可选,用于通过代理站来进行 API 调用'}
onChange={(value) => {
handleInputChange('base_url', value);
}}
value={inputs.base_url}
autoComplete="new-password"
/>
</>
)}
{inputs.type === 22 && (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>私有部署地址</Typography.Text>
</div>
<Input
name="base_url"
placeholder={
'请输入私有部署地址格式为https://fastgpt.run/api/openapi'
}
onChange={(value) => {
handleInputChange('base_url', value);
}}
value={inputs.base_url}
autoComplete="new-password"
/>
</>
)}
{inputs.type === 36 && (
<>
<div style={{marginTop: 10}}>
<Typography.Text strong>
注意非Chat API请务必填写正确的API地址否则可能导致无法使用
</Typography.Text>
</div>
<Input
name='base_url'
placeholder={
'请输入到 /suno 前的路径通常就是域名例如https://api.example.com '
}
onChange={(value) => {
handleInputChange('base_url', value);
}}
value={inputs.base_url}
autoComplete='new-password'
/>
</>
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>
注意非Chat API请务必填写正确的API地址否则可能导致无法使用
</Typography.Text>
</div>
<Input
name="base_url"
placeholder={
'请输入到 /suno 前的路径通常就是域名例如https://api.example.com '
}
onChange={(value) => {
handleInputChange('base_url', value);
}}
value={inputs.base_url}
autoComplete="new-password"
/>
</>
)}
<div style={{marginTop: 10}}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>名称</Typography.Text>
</div>
<Input
required
name='name'
required
name="name"
placeholder={'请为渠道命名'}
onChange={(value) => {
handleInputChange('name', value);
}}
value={inputs.name}
autoComplete='new-password'
autoComplete="new-password"
/>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>分组</Typography.Text>
</div>
<Select
placeholder={'请选择可以使用该渠道的分组'}
name='groups'
name="groups"
required
multiple
selection
@@ -572,7 +607,7 @@ const EditChannel = (props) => {
handleInputChange('groups', value);
}}
value={inputs.groups}
autoComplete='new-password'
autoComplete="new-password"
optionList={groupOptions}
/>
{inputs.type === 18 && (
@@ -581,7 +616,7 @@ const EditChannel = (props) => {
<Typography.Text strong>模型版本</Typography.Text>
</div>
<Input
name='other'
name="other"
placeholder={
'请输入星火大模型版本注意是接口地址中的版本号例如v2.1'
}
@@ -589,7 +624,7 @@ const EditChannel = (props) => {
handleInputChange('other', value);
}}
value={inputs.other}
autoComplete='new-password'
autoComplete="new-password"
/>
</>
)}
@@ -599,7 +634,7 @@ const EditChannel = (props) => {
<Typography.Text strong>部署地区</Typography.Text>
</div>
<TextArea
name='other'
name="other"
placeholder={
'请输入部署地区例如us-central1\n支持使用模型映射格式\n' +
'{\n' +
@@ -612,18 +647,18 @@ const EditChannel = (props) => {
handleInputChange('other', value);
}}
value={inputs.other}
autoComplete='new-password'
autoComplete="new-password"
/>
<Typography.Text
style={{
color: 'rgba(var(--semi-blue-5), 1)',
userSelect: 'none',
cursor: 'pointer',
cursor: 'pointer'
}}
onClick={() => {
handleInputChange(
'other',
JSON.stringify(REGION_EXAMPLE, null, 2),
JSON.stringify(REGION_EXAMPLE, null, 2)
);
}}
>
@@ -637,14 +672,14 @@ const EditChannel = (props) => {
<Typography.Text strong>知识库 ID</Typography.Text>
</div>
<Input
label='知识库 ID'
name='other'
label="知识库 ID"
name="other"
placeholder={'请输入知识库 ID例如123456'}
onChange={(value) => {
handleInputChange('other', value);
}}
value={inputs.other}
autoComplete='new-password'
autoComplete="new-password"
/>
</>
)}
@@ -654,7 +689,7 @@ const EditChannel = (props) => {
<Typography.Text strong>Account ID</Typography.Text>
</div>
<Input
name='other'
name="other"
placeholder={
'请输入Account ID例如d6b5da8hk1awo8nap34ube6gh'
}
@@ -662,7 +697,7 @@ const EditChannel = (props) => {
handleInputChange('other', value);
}}
value={inputs.other}
autoComplete='new-password'
autoComplete="new-password"
/>
</>
)}
@@ -671,7 +706,7 @@ const EditChannel = (props) => {
</div>
<Select
placeholder={'请选择该渠道所支持的模型'}
name='models'
name="models"
required
multiple
selection
@@ -679,13 +714,13 @@ const EditChannel = (props) => {
handleInputChange('models', value);
}}
value={inputs.models}
autoComplete='new-password'
autoComplete="new-password"
optionList={modelOptions}
/>
<div style={{ lineHeight: '40px', marginBottom: '12px' }}>
<Space>
<Button
type='primary'
type="primary"
onClick={() => {
handleInputChange('models', basicModels);
}}
@@ -693,7 +728,7 @@ const EditChannel = (props) => {
填入相关模型
</Button>
<Button
type='secondary'
type="secondary"
onClick={() => {
handleInputChange('models', fullModels);
}}
@@ -702,7 +737,7 @@ const EditChannel = (props) => {
</Button>
<Tooltip content={fetchButtonTips}>
<Button
type='tertiary'
type="tertiary"
onClick={() => {
fetchUpstreamModelList('models');
}}
@@ -711,7 +746,7 @@ const EditChannel = (props) => {
</Button>
</Tooltip>
<Button
type='warning'
type="warning"
onClick={() => {
handleInputChange('models', []);
}}
@@ -721,11 +756,11 @@ const EditChannel = (props) => {
</Space>
<Input
addonAfter={
<Button type='primary' onClick={addCustomModels}>
<Button type="primary" onClick={addCustomModels}>
填入
</Button>
}
placeholder='输入自定义模型名称'
placeholder="输入自定义模型名称"
value={customModel}
onChange={(value) => {
setCustomModel(value.trim());
@@ -737,24 +772,24 @@ const EditChannel = (props) => {
</div>
<TextArea
placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
name='model_mapping'
name="model_mapping"
onChange={(value) => {
handleInputChange('model_mapping', value);
}}
autosize
value={inputs.model_mapping}
autoComplete='new-password'
autoComplete="new-password"
/>
<Typography.Text
style={{
color: 'rgba(var(--semi-blue-5), 1)',
userSelect: 'none',
cursor: 'pointer',
cursor: 'pointer'
}}
onClick={() => {
handleInputChange(
'model_mapping',
JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2),
JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)
);
}}
>
@@ -765,8 +800,8 @@ const EditChannel = (props) => {
</div>
{batch ? (
<TextArea
label='密钥'
name='key'
label="密钥"
name="key"
required
placeholder={'请输入密钥,一行一个'}
onChange={(value) => {
@@ -774,14 +809,14 @@ const EditChannel = (props) => {
}}
value={inputs.key}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
autoComplete="new-password"
/>
) : (
<>
{inputs.type === 41 ? (
<TextArea
label='鉴权json'
name='key'
label="鉴权json"
name="key"
required
placeholder={'{\n' +
' "type": "service_account",\n' +
@@ -801,23 +836,36 @@ const EditChannel = (props) => {
}}
autosize={{ minRows: 10 }}
value={inputs.key}
autoComplete='new-password'
autoComplete="new-password"
/>
) : (
<Input
label='密钥'
name='key'
label="密钥"
name="key"
required
placeholder={type2secretPrompt(inputs.type)}
onChange={(value) => {
handleInputChange('key', value);
}}
value={inputs.key}
autoComplete='new-password'
autoComplete="new-password"
/>
)
}
</>
</>
)}
{!isEdit && (
<div style={{ marginTop: 10, display: 'flex' }}>
<Space>
<Checkbox
checked={batch}
label="批量创建"
name="batch"
onChange={() => setBatch(!batch)}
/>
<Typography.Text strong>批量创建</Typography.Text>
</Space>
</div>
)}
{inputs.type === 1 && (
<>
@@ -825,9 +873,9 @@ const EditChannel = (props) => {
<Typography.Text strong>组织</Typography.Text>
</div>
<Input
label='组织,可选,不填则为默认组织'
name='openai_organization'
placeholder='请输入组织org-xxx'
label="组织,可选,不填则为默认组织"
name="openai_organization"
placeholder="请输入组织org-xxx"
onChange={(value) => {
handleInputChange('openai_organization', value);
}}
@@ -839,8 +887,8 @@ const EditChannel = (props) => {
<Typography.Text strong>默认测试模型</Typography.Text>
</div>
<Input
name='test_model'
placeholder='不填则为模型列表第一个'
name="test_model"
placeholder="不填则为模型列表第一个"
onChange={(value) => {
handleInputChange('test_model', value);
}}
@@ -849,7 +897,7 @@ const EditChannel = (props) => {
<div style={{ marginTop: 10, display: 'flex' }}>
<Space>
<Checkbox
name='auto_ban'
name="auto_ban"
checked={autoBan}
onChange={() => {
setAutoBan(!autoBan);
@@ -861,55 +909,6 @@ const EditChannel = (props) => {
</Typography.Text>
</Space>
</div>
{!isEdit && (
<div style={{ marginTop: 10, display: 'flex' }}>
<Space>
<Checkbox
checked={batch}
label='批量创建'
name='batch'
onChange={() => setBatch(!batch)}
/>
<Typography.Text strong>批量创建</Typography.Text>
</Space>
</div>
)}
{inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>代理</Typography.Text>
</div>
<Input
label='代理'
name='base_url'
placeholder={'此项可选,用于通过代理站来进行 API 调用'}
onChange={(value) => {
handleInputChange('base_url', value);
}}
value={inputs.base_url}
autoComplete='new-password'
/>
</>
)}
{inputs.type === 22 && (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>私有部署地址</Typography.Text>
</div>
<Input
name='base_url'
placeholder={
'请输入私有部署地址格式为https://fastgpt.run/api/openapi'
}
onChange={(value) => {
handleInputChange('base_url', value);
}}
value={inputs.base_url}
autoComplete='new-password'
/>
</>
)}
<div style={{ marginTop: 10 }}>
<Typography.Text strong>
状态码复写仅影响本地判断不修改返回到上游的状态码
@@ -917,43 +916,74 @@ const EditChannel = (props) => {
</div>
<TextArea
placeholder={`此项可选用于复写返回的状态码比如将claude渠道的400错误复写为500用于重试请勿滥用该功能例如\n${JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2)}`}
name='status_code_mapping'
name="status_code_mapping"
onChange={(value) => {
handleInputChange('status_code_mapping', value);
}}
autosize
value={inputs.status_code_mapping}
autoComplete='new-password'
autoComplete="new-password"
/>
<Typography.Text
style={{
color: 'rgba(var(--semi-blue-5), 1)',
userSelect: 'none',
cursor: 'pointer',
cursor: 'pointer'
}}
onClick={() => {
handleInputChange(
'status_code_mapping',
JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2),
JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2)
);
}}
>
填入模板
</Typography.Text>
{/*<div style={{ marginTop: 10 }}>*/}
{/* <Typography.Text strong>*/}
{/* 最大请求token0表示不限制*/}
{/* </Typography.Text>*/}
{/*</div>*/}
{/*<Input*/}
{/* label='最大请求token'*/}
{/* name='max_input_tokens'*/}
{/* placeholder='默认为0表示不限制'*/}
{/* onChange={(value) => {*/}
{/* handleInputChange('max_input_tokens', value);*/}
{/* }}*/}
{/* value={inputs.max_input_tokens}*/}
{/*/>*/}
<div style={{ marginTop: 10 }}>
<Typography.Text strong>
渠道标签
</Typography.Text>
</div>
<Input
label="渠道标签"
name="tag"
placeholder={'渠道标签'}
onChange={(value) => {
handleInputChange('tag', value);
}}
value={inputs.tag}
autoComplete="new-password"
/>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>
渠道优先级
</Typography.Text>
</div>
<Input
label="渠道优先级"
name="priority"
placeholder={'渠道优先级'}
onChange={(value) => {
handleInputChange('priority', parseInt(value));
}}
value={inputs.priority}
autoComplete="new-password"
/>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>
渠道权重
</Typography.Text>
</div>
<Input
label="渠道权重"
name="weight"
placeholder={'渠道权重'}
onChange={(value) => {
handleInputChange('weight', parseInt(value));
}}
value={inputs.weight}
autoComplete="new-password"
/>
</Spin>
</SideSheet>
</>

Some files were not shown because too many files have changed in this diff Show More