mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 17:44:41 +00:00
Compare commits
142 Commits
v0.4.8.4
...
v0.6.0-alp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7e46d4217d | ||
|
|
c25d4d8d23 | ||
|
|
b291fbff6b | ||
|
|
e68edf81f7 | ||
|
|
5ff16f9b2d | ||
|
|
f614cfa563 | ||
|
|
2048b451bf | ||
|
|
bd48f43410 | ||
|
|
c47d8a10f0 | ||
|
|
c0b9350785 | ||
|
|
229738cda9 | ||
|
|
39d95172e8 | ||
|
|
5059cbdb46 | ||
|
|
a981e10712 | ||
|
|
495bbcb621 | ||
|
|
20e34bec7e | ||
|
|
0033f5ba2e | ||
|
|
e52ac52e7b | ||
|
|
66682584a5 | ||
|
|
1a2bf8df1f | ||
|
|
1819c4d5f5 | ||
|
|
6f24dddcb2 | ||
|
|
8de29fbb83 | ||
|
|
f2163acf2b | ||
|
|
5259acfacd | ||
|
|
c433af284c | ||
|
|
3122b8a36a | ||
|
|
bbe7223a85 | ||
|
|
2af05c166c | ||
|
|
ecb5b5630c | ||
|
|
e1b9f164f9 | ||
|
|
69db1f1465 | ||
|
|
94549f9687 | ||
|
|
c7e1bab18a | ||
|
|
627f95b034 | ||
|
|
8b99eec440 | ||
|
|
49bfd2b719 | ||
|
|
434e9d7695 | ||
|
|
b2938ffe2c | ||
|
|
d9cf0885f1 | ||
|
|
3ed50787b3 | ||
|
|
97d948cdb1 | ||
|
|
5017fabbfa | ||
|
|
bd5c261b99 | ||
|
|
00c2d6c102 | ||
|
|
4a8bb625b8 | ||
|
|
db01994cd0 | ||
|
|
a0ca3effa7 | ||
|
|
5a10ebd384 | ||
|
|
68097c132d | ||
|
|
3352bacd35 | ||
|
|
7fcb14e25f | ||
|
|
867187ab4d | ||
|
|
3ad96d3b4e | ||
|
|
d9390ff4c3 | ||
|
|
8c209e2fb9 | ||
|
|
a9bfcb0daf | ||
|
|
bb848b2fe0 | ||
|
|
618908f6f8 | ||
|
|
1f4ebddcfa | ||
|
|
6d79d8993e | ||
|
|
7c03ad71de | ||
|
|
4f194f4e6a | ||
|
|
81137e0533 | ||
|
|
b9b66dda54 | ||
|
|
fd22948ead | ||
|
|
894dce7366 | ||
|
|
b95142bbac | ||
|
|
7f74a9664e | ||
|
|
a3739f67f7 | ||
|
|
b841ce006f | ||
|
|
e3f9ef1894 | ||
|
|
558e625a01 | ||
|
|
37a83ecc33 | ||
|
|
37bb34b4b0 | ||
|
|
8deab221f9 | ||
|
|
17e9f1a07d | ||
|
|
792754cee3 | ||
|
|
98b27a17a6 | ||
|
|
7855f83e2d | ||
|
|
cbdf26bf2c | ||
|
|
eb46b71a71 | ||
|
|
a42c3b6227 | ||
|
|
b00dd8b405 | ||
|
|
be228ccd2c | ||
|
|
b1be64bcf3 | ||
|
|
6ecfb81cbc | ||
|
|
14848ff789 | ||
|
|
47d3b515da | ||
|
|
760514c3e1 | ||
|
|
254c25c27a | ||
|
|
8731a32e56 | ||
|
|
7208a65e5d | ||
|
|
4084b18071 | ||
|
|
2ca0d7246d | ||
|
|
d042a1bd55 | ||
|
|
816e831a2e | ||
|
|
a3ceae4a86 | ||
|
|
eb163d9c94 | ||
|
|
a592a81bc2 | ||
|
|
bb300d199e | ||
|
|
7dbb6b017c | ||
|
|
ce1854847b | ||
|
|
2f9faba40d | ||
|
|
a5085014cc | ||
|
|
18d3706ff8 | ||
|
|
152950497e | ||
|
|
d6fd50e382 | ||
|
|
cfd3f6c073 | ||
|
|
45c56b5ded | ||
|
|
d306394f33 | ||
|
|
cdba87a7da | ||
|
|
ae5b874a6c | ||
|
|
d0bc8d17d1 | ||
|
|
4784ca7514 | ||
|
|
3a18c0ce9f | ||
|
|
929668bead | ||
|
|
06a78f9042 | ||
|
|
0f1c4c4ebe | ||
|
|
1bcf7a3c39 | ||
|
|
5f0b3f6d6f | ||
|
|
19a318c943 | ||
|
|
13ab0f8e4f | ||
|
|
6d8d40e67b | ||
|
|
287caf8e38 | ||
|
|
c802b3b41a | ||
|
|
ed4e1c2332 | ||
|
|
e581ea33c2 | ||
|
|
bf80d71ddf | ||
|
|
e19b244e73 | ||
|
|
f451268830 | ||
|
|
069f2672c1 | ||
|
|
ccf13d445f | ||
|
|
da4d1861fe | ||
|
|
3de5b96cb4 | ||
|
|
5b9e275690 | ||
|
|
607e3206b3 | ||
|
|
83feb492fb | ||
|
|
4f212be45c | ||
|
|
92918e3751 | ||
|
|
de15551570 | ||
|
|
a81a28b7a5 |
@@ -50,10 +50,6 @@
|
||||
# CHANNEL_TEST_FREQUENCY=10
|
||||
# 生成默认token
|
||||
# GENERATE_DEFAULT_TOKEN=false
|
||||
# Gemini 安全设置
|
||||
# GEMINI_SAFETY_SETTING=BLOCK_NONE
|
||||
# Gemini版本设置
|
||||
# GEMINI_MODEL_MAP=gemini-1.0-pro:v1
|
||||
# Cohere 安全设置
|
||||
# COHERE_SAFETY_SETTING=NONE
|
||||
# 是否统计图片token
|
||||
|
||||
12
README.en.md
12
README.en.md
@@ -65,10 +65,18 @@
|
||||
- Add suffix `-low` to set low reasoning effort
|
||||
17. 🔄 Thinking to content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default is `false`, when `true`, the `reasoning_content` of the thinking content will be converted to `<think>` tags and concatenated to the content returned.
|
||||
18. 🔄 Model rate limit, support setting total request limit and successful request limit in `System Settings->Rate Limit Settings`
|
||||
19. 💰 Cache billing support, when enabled can charge a configurable ratio for cache hits:
|
||||
1. Set `Prompt Cache Ratio` in `System Settings -> Operation Settings`
|
||||
2. Set `Prompt Cache Ratio` in channel settings, range 0-1 (e.g., 0.5 means 50% charge on cache hits)
|
||||
3. Supported channels:
|
||||
- [x] OpenAI
|
||||
- [x] Azure
|
||||
- [x] DeepSeek
|
||||
- [ ] Claude
|
||||
|
||||
## Model Support
|
||||
This version additionally supports:
|
||||
1. Third-party model **gps** (gpt-4-gizmo-*)
|
||||
1. Third-party model **gpts** (gpt-4-gizmo-*)
|
||||
2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [Integration Guide](Midjourney.md)
|
||||
3. Custom channels with full API URL support
|
||||
4. [Suno API](https://github.com/Suno-API/Suno-API) interface, [Integration Guide](Suno.md)
|
||||
@@ -162,7 +170,7 @@ docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtow
|
||||
|
||||
## Channel Retry
|
||||
Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**.
|
||||
First retry uses same priority, second retry uses next priority, and so on.
|
||||
If retry is enabled, the system will automatically use the next priority channel for the same request after a failed request.
|
||||
|
||||
### Cache Configuration
|
||||
1. `REDIS_CONN_STRING`: Use Redis as cache
|
||||
|
||||
231
README.md
231
README.md
@@ -7,7 +7,6 @@
|
||||
|
||||
# New API
|
||||
|
||||
|
||||
🍥新一代大模型网关与AI资产管理系统
|
||||
|
||||
<a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
@@ -41,181 +40,157 @@
|
||||
> - 本项目仅供个人学习使用,不保证稳定性,且不提供任何技术支持。
|
||||
> - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
||||
|
||||
## 📚 文档
|
||||
|
||||
详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
|
||||
|
||||
## ✨ 主要特性
|
||||
|
||||
1. 🎨 全新的UI界面(部分界面还待更新)
|
||||
2. 🌍 多语言支持(待完善)
|
||||
3. 🎨 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口支持,[对接文档](Midjourney.md)
|
||||
4. 💰 支持在线充值功能,可在系统设置中设置:
|
||||
- [x] 易支付
|
||||
5. 🔍 支持用key查询使用额度:
|
||||
- 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用
|
||||
New API提供了丰富的功能,详细特性请参考[维基百科-特性说明](https://docs.newapi.pro/wiki/features-introduction):
|
||||
|
||||
1. 🎨 全新的UI界面
|
||||
2. 🌍 多语言支持
|
||||
3. 🎨 支持[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](https://docs.newapi.pro/api/relay/image/midjourney)
|
||||
4. 💰 支持在线充值功能(易支付)
|
||||
5. 🔍 支持用key查询使用额度(配合[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
|
||||
6. 📑 分页支持选择每页显示数量
|
||||
7. 🔄 兼容原版One API的数据库,可直接使用原版数据库(one-api.db)
|
||||
8. 💵 支持模型按次数收费,可在 系统设置-运营设置 中设置
|
||||
9. ⚖️ 支持渠道**加权随机**
|
||||
7. 🔄 兼容原版One API的数据库
|
||||
8. 💵 支持模型按次数收费
|
||||
9. ⚖️ 支持渠道加权随机
|
||||
10. 📈 数据看板(控制台)
|
||||
11. 🔒 可设置令牌能调用的模型
|
||||
12. 🤖 支持Telegram授权登录:
|
||||
1. 系统设置-配置登录注册-允许通过Telegram登录
|
||||
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
|
||||
3. 选择你的bot,然后输入http(s)://你的网站地址/login
|
||||
4. Telegram Bot 名称是bot username 去掉@后的字符串
|
||||
13. 🎵 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口支持,[对接文档](Suno.md)
|
||||
14. 🔄 支持Rerank模型,目前兼容Cohere和Jina,可接入Dify,[对接文档](Rerank.md)
|
||||
15. ⚡ **[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API,支持Azure渠道
|
||||
16. 支持使用路由/chat2link 进入聊天界面
|
||||
17. 🧠 支持通过模型名称后缀设置 reasoning effort:
|
||||
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
|
||||
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
|
||||
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
|
||||
18. 🔄 思考转内容,支持在 `渠道-编辑-渠道额外设置` 中设置 `thinking_to_content` 选项,默认`false`,开启后会将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回。
|
||||
19. 🔄 模型限流,支持在 `系统设置-速率限制设置` 中设置模型限流,支持设置总请求数限制和成功请求数限制
|
||||
12. 🤖 支持Telegram授权登录
|
||||
13. 🎵 支持[Suno API](https://github.com/Suno-API/Suno-API)接口,[接口文档](https://docs.newapi.pro/api/suno-music)
|
||||
14. 🔄 支持Rerank模型(Cohere和Jina),[接口文档](https://docs.newapi.pro/api/jinaai-rerank)
|
||||
15. ⚡ 支持OpenAI Realtime API(包括Azure渠道),[接口文档](https://docs.newapi.pro/api/openai-realtime)
|
||||
16. ⚡ 支持Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat)
|
||||
17. 支持使用路由/chat2link进入聊天界面
|
||||
18. 🧠 支持通过模型名称后缀设置 reasoning effort:
|
||||
1. OpenAI o系列模型
|
||||
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
|
||||
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
|
||||
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
|
||||
2. Claude 思考模型
|
||||
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
|
||||
19. 🔄 思考转内容功能
|
||||
20. 🔄 模型限流功能
|
||||
20. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
|
||||
1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项
|
||||
2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费
|
||||
3. 支持的渠道:
|
||||
- [x] OpenAI
|
||||
- [x] Azure
|
||||
- [x] DeepSeek
|
||||
- [x] Claude
|
||||
|
||||
## 模型支持
|
||||
此版本额外支持以下模型:
|
||||
1. 第三方模型 **gps** (gpt-4-gizmo-*)
|
||||
2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
|
||||
|
||||
此版本支持多种模型,详情请参考[接口文档-中继接口](https://docs.newapi.pro/api):
|
||||
|
||||
1. 第三方模型 **gpts** (gpt-4-gizmo-*)
|
||||
2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[接口文档](https://docs.newapi.pro/api/midjourney-proxy-image)
|
||||
3. 自定义渠道,支持填入完整调用地址
|
||||
4. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
|
||||
5. Rerank模型,目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/),[对接文档](Rerank.md)
|
||||
6. Dify
|
||||
4. [Suno API](https://github.com/Suno-API/Suno-API)接口,[接口文档](https://docs.newapi.pro/api/suno-music)
|
||||
5. Rerank模型([Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)),[接口文档](https://docs.newapi.pro/api/jinaai-rerank)
|
||||
6. Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat)
|
||||
7. Dify
|
||||
|
||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
||||
## 环境变量配置
|
||||
|
||||
## 比原版One API多出的配置
|
||||
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`。
|
||||
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 60 秒。
|
||||
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`。
|
||||
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,请求上游返回流模式usage,默认为 `true`,建议开启,不影响客户端传入stream_options参数返回结果。
|
||||
- `GET_MEDIA_TOKEN`:是否统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
|
||||
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。
|
||||
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
|
||||
- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用"模型:版本"指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
|
||||
- `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认为 `NONE`。
|
||||
- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。
|
||||
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。
|
||||
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
|
||||
- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,如果渠道设置中未指定API版本,则使用此版本,默认为 `2024-12-01-preview`
|
||||
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`。
|
||||
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`。
|
||||
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
|
||||
|
||||
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
|
||||
- `STREAMING_TIMEOUT`:流式回复超时时间,默认60秒
|
||||
- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
|
||||
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
|
||||
- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
|
||||
- `GET_MEDIA_TOKEN_NOT_STREAM`:非流情况下是否统计图片token,默认 `true`
|
||||
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认 `true`
|
||||
- `COHERE_SAFETY_SETTING`:Cohere模型安全设置,可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认 `NONE`
|
||||
- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认 `16`
|
||||
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位MB,默认 `20`
|
||||
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容
|
||||
- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2024-12-01-preview`
|
||||
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
|
||||
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
|
||||
|
||||
## 部署
|
||||
|
||||
详细部署指南请参考[安装指南-部署方式](https://docs.newapi.pro/installation):
|
||||
|
||||
> [!TIP]
|
||||
> 最新版Docker镜像:`calciumion/new-api:latest`
|
||||
> 默认账号root 密码123456
|
||||
|
||||
### 多机部署
|
||||
- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致。
|
||||
- 如果公用Redis,必须设置 `CRYPTO_SECRET`,否则会导致多机部署时Redis内容无法获取。
|
||||
### 多机部署注意事项
|
||||
- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致
|
||||
- 如果公用Redis,必须设置 `CRYPTO_SECRET`,否则会导致多机部署时Redis内容无法获取
|
||||
|
||||
### 部署要求
|
||||
- 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机)
|
||||
- 远程数据库:MySQL 版本 >= 5.7.8,PgSQL 版本 >= 9.6
|
||||
- 本地数据库(默认):SQLite(Docker部署必须挂载`/data`目录)
|
||||
- 远程数据库:MySQL版本 >= 5.7.8,PgSQL版本 >= 9.6
|
||||
|
||||
### 使用宝塔面板Docker功能部署
|
||||
安装宝塔面板 (**9.2.0版本**及以上),前往 [宝塔面板](https://www.bt.cn/new/download.html) 官网,选择正式版的脚本下载安装
|
||||
安装后登录宝塔面板,在菜单栏中点击 Docker ,首次进入会提示安装 Docker 服务,点击立即安装,按提示完成安装
|
||||
安装完成后在应用商店中找到 **New-API** ,点击安装,配置基本选项 即可完成安装
|
||||
### 部署方式
|
||||
|
||||
#### 使用宝塔面板Docker功能部署
|
||||
安装宝塔面板(**9.2.0版本**及以上),在应用商店中找到**New-API**安装即可。
|
||||
[图文教程](BT.md)
|
||||
|
||||
### 基于 Docker 进行部署
|
||||
|
||||
> [!TIP]
|
||||
> 默认管理员账号root 密码123456
|
||||
|
||||
### 使用 Docker Compose 部署(推荐)
|
||||
#### 使用Docker Compose部署(推荐)
|
||||
```shell
|
||||
# 下载项目
|
||||
git clone https://github.com/Calcium-Ion/new-api.git
|
||||
cd new-api
|
||||
# 按需编辑 docker-compose.yml
|
||||
# nano docker-compose.yml
|
||||
# vim docker-compose.yml
|
||||
# 按需编辑docker-compose.yml
|
||||
# 启动
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
#### 更新版本
|
||||
#### 直接使用Docker镜像
|
||||
```shell
|
||||
docker-compose pull
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
### 直接使用 Docker 镜像
|
||||
```shell
|
||||
# 使用 SQLite 的部署命令:
|
||||
# 使用SQLite
|
||||
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
|
||||
|
||||
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。
|
||||
# 例如:
|
||||
# 使用MySQL
|
||||
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
|
||||
```
|
||||
|
||||
#### 更新版本
|
||||
```shell
|
||||
# 拉取最新镜像
|
||||
docker pull calciumion/new-api:latest
|
||||
# 停止并删除旧容器
|
||||
docker stop new-api
|
||||
docker rm new-api
|
||||
# 使用相同参数运行新容器
|
||||
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
|
||||
```
|
||||
## 渠道重试与缓存
|
||||
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
|
||||
|
||||
或者使用 Watchtower 自动更新(不推荐,可能会导致数据库不兼容):
|
||||
```shell
|
||||
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
|
||||
```
|
||||
|
||||
## 渠道重试
|
||||
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
|
||||
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
|
||||
### 缓存设置方法
|
||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
|
||||
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
|
||||
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
||||
+ 例子:`MEMORY_CACHE_ENABLED=true`
|
||||
### 为什么有的时候没有重试
|
||||
这些错误码不会重试:400,504,524
|
||||
### 我想让400也重试
|
||||
在`渠道->编辑`中,将`状态码复写`改为
|
||||
```json
|
||||
{
|
||||
"400": "500"
|
||||
}
|
||||
```
|
||||
可以实现400错误转为500错误,从而重试
|
||||
1. `REDIS_CONN_STRING`:设置Redis作为缓存
|
||||
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(设置了Redis则无需手动设置)
|
||||
|
||||
## Midjourney接口设置文档
|
||||
[对接文档](Midjourney.md)
|
||||
## 接口文档
|
||||
|
||||
## Suno接口设置文档
|
||||
[对接文档](Suno.md)
|
||||
详细接口文档请参考[接口文档](https://docs.newapi.pro/api):
|
||||
|
||||
## 界面截图
|
||||

|
||||
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
## 交流群
|
||||
<img src="https://github.com/user-attachments/assets/9ca0bc82-e057-4230-a28d-9f198fa022e3" width="200">
|
||||
- [聊天接口(Chat)](https://docs.newapi.pro/api/openai-chat)
|
||||
- [图像接口(Image)](https://docs.newapi.pro/api/openai-image)
|
||||
- [Midjourney接口](https://docs.newapi.pro/api/midjourney-proxy-image)
|
||||
- [音乐接口(Music)](https://docs.newapi.pro/api/relay/music)
|
||||
- [Suno接口](https://docs.newapi.pro/api/suno-music)
|
||||
- [重排序接口(Rerank)](https://docs.newapi.pro/api/jinaai-rerank)
|
||||
- [实时对话接口(Realtime)](https://docs.newapi.pro/api/openai-realtime)
|
||||
- [Claude聊天接口(messages)](https://docs.newapi.pro/api/anthropic-chat)
|
||||
|
||||
## 相关项目
|
||||
- [One API](https://github.com/songquanpeng/one-api):原版项目
|
||||
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy):Midjourney接口支持
|
||||
- [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代 AI 一站式 B/C 端解决方案
|
||||
- [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代AI一站式B/C端解决方案
|
||||
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool):用key查询使用额度
|
||||
|
||||
其他基于New API的项目:
|
||||
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版,并支持Claude格式
|
||||
- [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的闭源项目
|
||||
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
|
||||
- [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的前端美化版本
|
||||
|
||||
## 帮助支持
|
||||
|
||||
如有问题,请参考[帮助支持](https://docs.newapi.pro/support):
|
||||
- [社区交流](https://docs.newapi.pro/support/community-interaction)
|
||||
- [反馈问题](https://docs.newapi.pro/support/feedback-issues)
|
||||
- [常见问题](https://docs.newapi.pro/support/faq)
|
||||
|
||||
## 🌟 Star History
|
||||
|
||||
|
||||
@@ -15,8 +15,9 @@ var SystemName = "New API"
|
||||
var Footer = ""
|
||||
var Logo = ""
|
||||
var TopUpLink = ""
|
||||
var ChatLink = ""
|
||||
var ChatLink2 = ""
|
||||
|
||||
// var ChatLink = ""
|
||||
// var ChatLink2 = ""
|
||||
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
|
||||
var DisplayInCurrencyEnabled = true
|
||||
var DisplayTokenStatEnabled = true
|
||||
@@ -76,7 +77,6 @@ var SMTPToken = ""
|
||||
|
||||
var GitHubClientId = ""
|
||||
var GitHubClientSecret = ""
|
||||
|
||||
var LinuxDOClientId = ""
|
||||
var LinuxDOClientSecret = ""
|
||||
|
||||
@@ -234,6 +234,7 @@ const (
|
||||
ChannelTypeMokaAI = 44
|
||||
ChannelTypeVolcEngine = 45
|
||||
ChannelTypeBaiduV2 = 46
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
@@ -286,4 +287,5 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.moka.ai", //44
|
||||
"https://ark.cn-beijing.volces.com", //45
|
||||
"https://qianfan.baidubce.com", //46
|
||||
"", //47
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ var fieldReplacer = strings.NewReplacer(
|
||||
"\r", "\\r")
|
||||
|
||||
var dataReplacer = strings.NewReplacer(
|
||||
"\n", "\ndata:",
|
||||
"\n", "\n",
|
||||
"\r", "\\r")
|
||||
|
||||
type CustomEvent struct {
|
||||
|
||||
24
common/gopool.go
Normal file
24
common/gopool.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"math"
|
||||
)
|
||||
|
||||
var relayGoPool gopool.Pool
|
||||
|
||||
func init() {
|
||||
relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig())
|
||||
relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) {
|
||||
if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok {
|
||||
SafeSendBool(stopChan, true)
|
||||
}
|
||||
SysError(fmt.Sprintf("panic in gopool.RelayPool: %v", i))
|
||||
})
|
||||
}
|
||||
|
||||
func RelayCtxGo(ctx context.Context, f func()) {
|
||||
relayGoPool.CtxGo(ctx, f)
|
||||
}
|
||||
@@ -32,6 +32,7 @@ func InitRedisClient() (err error) {
|
||||
if err != nil {
|
||||
FatalLog("failed to parse Redis connection string: " + err.Error())
|
||||
}
|
||||
opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10)
|
||||
RDB = redis.NewClient(opt)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
@@ -41,6 +42,10 @@ func InitRedisClient() (err error) {
|
||||
if err != nil {
|
||||
FatalLog("Redis ping test failed: " + err.Error())
|
||||
}
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr))
|
||||
SysLog(fmt.Sprintf("Redis database: %d", opt.DB))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -53,13 +58,20 @@ func ParseRedisOption() *redis.Options {
|
||||
}
|
||||
|
||||
func RedisSet(key string, value string, expiration time.Duration) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration))
|
||||
}
|
||||
ctx := context.Background()
|
||||
return RDB.Set(ctx, key, value, expiration).Err()
|
||||
}
|
||||
|
||||
func RedisGet(key string) (string, error) {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis GET: key=%s", key))
|
||||
}
|
||||
ctx := context.Background()
|
||||
return RDB.Get(ctx, key).Result()
|
||||
val, err := RDB.Get(ctx, key).Result()
|
||||
return val, err
|
||||
}
|
||||
|
||||
//func RedisExpire(key string, expiration time.Duration) error {
|
||||
@@ -73,16 +85,25 @@ func RedisGet(key string) (string, error) {
|
||||
//}
|
||||
|
||||
func RedisDel(key string) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis DEL: key=%s", key))
|
||||
}
|
||||
ctx := context.Background()
|
||||
return RDB.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func RedisHDelObj(key string) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis HDEL: key=%s", key))
|
||||
}
|
||||
ctx := context.Background()
|
||||
return RDB.HDel(ctx, key).Err()
|
||||
}
|
||||
|
||||
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration))
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
data := make(map[string]interface{})
|
||||
@@ -130,6 +151,9 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
||||
}
|
||||
|
||||
func RedisHGetObj(key string, obj interface{}) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key))
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := RDB.HGetAll(ctx, key).Result()
|
||||
@@ -208,6 +232,9 @@ func RedisHGetObj(key string, obj interface{}) error {
|
||||
|
||||
// RedisIncr Add this function to handle atomic increments
|
||||
func RedisIncr(key string, delta int64) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta))
|
||||
}
|
||||
// 检查键的剩余生存时间
|
||||
ttlCmd := RDB.TTL(context.Background(), key)
|
||||
ttl, err := ttlCmd.Result()
|
||||
@@ -238,6 +265,9 @@ func RedisIncr(key string, delta int64) error {
|
||||
}
|
||||
|
||||
func RedisHIncrBy(key, field string, delta int64) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta))
|
||||
}
|
||||
ttlCmd := RDB.TTL(context.Background(), key)
|
||||
ttl, err := ttlCmd.Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
@@ -262,6 +292,9 @@ func RedisHIncrBy(key, field string, delta int64) error {
|
||||
}
|
||||
|
||||
func RedisHSetField(key, field string, value interface{}) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value))
|
||||
}
|
||||
ttlCmd := RDB.TTL(context.Background(), key)
|
||||
ttl, err := ttlCmd.Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"html/template"
|
||||
@@ -213,6 +214,24 @@ func RandomSleep() {
|
||||
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
|
||||
}
|
||||
|
||||
func GetPointer[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
func Any2Type[T any](data any) (T, error) {
|
||||
var zero T
|
||||
bytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
var res T
|
||||
err = json.Unmarshal(bytes, &res)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
|
||||
func SaveTmpFile(filename string, data io.Reader) (string, error) {
|
||||
f, err := os.CreateTemp(os.TempDir(), filename)
|
||||
|
||||
@@ -2,4 +2,9 @@ package constant
|
||||
|
||||
const (
|
||||
ContextKeyRequestStartTime = "request_start_time"
|
||||
ContextKeyUserSetting = "user_setting"
|
||||
ContextKeyUserQuota = "user_quota"
|
||||
ContextKeyUserStatus = "user_status"
|
||||
ContextKeyUserEmail = "user_email"
|
||||
ContextKeyUserGroup = "user_group"
|
||||
)
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
package constant
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
|
||||
@@ -23,9 +20,9 @@ var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
|
||||
var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
|
||||
|
||||
var GeminiModelMap = map[string]string{
|
||||
"gemini-1.0-pro": "v1",
|
||||
}
|
||||
//var GeminiModelMap = map[string]string{
|
||||
// "gemini-1.0-pro": "v1",
|
||||
//}
|
||||
|
||||
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
|
||||
@@ -33,18 +30,18 @@ var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||
|
||||
func InitEnv() {
|
||||
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
||||
if modelVersionMapStr == "" {
|
||||
return
|
||||
}
|
||||
for _, pair := range strings.Split(modelVersionMapStr, ",") {
|
||||
parts := strings.Split(pair, ":")
|
||||
if len(parts) == 2 {
|
||||
GeminiModelMap[parts[0]] = parts[1]
|
||||
} else {
|
||||
common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
|
||||
}
|
||||
}
|
||||
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
||||
//if modelVersionMapStr == "" {
|
||||
// return
|
||||
//}
|
||||
//for _, pair := range strings.Split(modelVersionMapStr, ",") {
|
||||
// parts := strings.Split(pair, ":")
|
||||
// if len(parts) == 2 {
|
||||
// GeminiModelMap[parts[0]] = parts[1]
|
||||
// } else {
|
||||
// common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"one-api/relay"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -48,7 +49,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
if strings.Contains(strings.ToLower(testModel), "embedding") ||
|
||||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
||||
strings.Contains(testModel, "bge-") || // bge 系列模型
|
||||
testModel == "text-embedding-v1" ||
|
||||
strings.Contains(testModel, "embed") ||
|
||||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
|
||||
requestPath = "/v1/embeddings" // 修改请求路径
|
||||
}
|
||||
@@ -72,26 +73,29 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
}
|
||||
}
|
||||
|
||||
modelMapping := *channel.ModelMapping
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[testModel] != "" {
|
||||
testModel = modelMap[testModel]
|
||||
}
|
||||
cache, err := model.GetUserCache(1)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
cache.WriteContext(c)
|
||||
|
||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
group, _ := model.GetUserGroup(1, false)
|
||||
c.Set("group", group)
|
||||
|
||||
middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||
|
||||
meta := relaycommon.GenRelayInfo(c)
|
||||
info := relaycommon.GenRelayInfo(c)
|
||||
|
||||
err = helper.ModelMappedHelper(c, info)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
testModel = info.UpstreamModelName
|
||||
|
||||
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
@@ -99,12 +103,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
}
|
||||
|
||||
request := buildTestRequest(testModel)
|
||||
meta.UpstreamModelName = testModel
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta))
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info))
|
||||
|
||||
adaptor.Init(meta)
|
||||
adaptor.Init(info)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
@@ -114,7 +117,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
c.Request.Body = io.NopCloser(requestBody)
|
||||
resp, err := adaptor.DoRequest(c, meta, requestBody)
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
@@ -122,11 +125,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(httpResp)
|
||||
err := service.RelayErrorHandler(httpResp, true)
|
||||
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
||||
}
|
||||
}
|
||||
usageA, respErr := adaptor.DoResponse(c, httpResp, meta)
|
||||
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
||||
if respErr != nil {
|
||||
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
||||
}
|
||||
@@ -139,26 +142,28 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
modelPrice, usePrice := common.GetModelPrice(testModel, false)
|
||||
modelRatio := common.GetModelRatio(testModel)
|
||||
completionRatio := common.GetCompletionRatio(testModel)
|
||||
ratio := modelRatio
|
||||
info.PromptTokens = usage.PromptTokens
|
||||
priceData, err := helper.ModelPriceHelper(c, info, usage.PromptTokens, int(request.MaxTokens))
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
quota := 0
|
||||
if !usePrice {
|
||||
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio))
|
||||
quota = int(math.Round(float64(quota) * ratio))
|
||||
if ratio != 0 && quota <= 0 {
|
||||
if !priceData.UsePrice {
|
||||
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
|
||||
quota = int(math.Round(float64(quota) * priceData.ModelRatio))
|
||||
if priceData.ModelRatio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
} else {
|
||||
quota = int(modelPrice * common.QuotaPerUnit)
|
||||
quota = int(priceData.ModelPrice * common.QuotaPerUnit)
|
||||
}
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试",
|
||||
quota, "模型测试", 0, quota, int(consumedTime), false, "default", other)
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
|
||||
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
}
|
||||
@@ -170,10 +175,10 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
}
|
||||
|
||||
// 先判断是否为 Embedding 模型
|
||||
if strings.Contains(strings.ToLower(model), "embedding") ||
|
||||
if strings.Contains(strings.ToLower(model), "embedding") || // 其他 embedding 模型
|
||||
strings.HasPrefix(model, "m3e") || // m3e 系列模型
|
||||
strings.Contains(model, "bge-") || // bge 系列模型
|
||||
model == "text-embedding-v1" { // 其他 embedding 模型
|
||||
strings.Contains(model, "bge-") {
|
||||
testRequest.Model = model
|
||||
// Embedding 请求
|
||||
testRequest.Input = []string{"hello world"}
|
||||
return testRequest
|
||||
@@ -181,6 +186,8 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
// 并非Embedding 模型
|
||||
if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") {
|
||||
testRequest.MaxCompletionTokens = 10
|
||||
} else if strings.Contains(model, "thinking") {
|
||||
testRequest.MaxTokens = 50
|
||||
} else {
|
||||
testRequest.MaxTokens = 10
|
||||
}
|
||||
|
||||
@@ -159,7 +159,7 @@ func UpdateMidjourneyTaskBulk() {
|
||||
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||
} else {
|
||||
if shouldReturnQuota {
|
||||
err = model.IncreaseUserQuota(task.UserId, task.Quota)
|
||||
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -33,40 +35,43 @@ func GetStatus(c *gin.Context) {
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"chat_link": common.ChatLink,
|
||||
"chat_link2": common.ChatLink2,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": setting.DemoSiteEnabled,
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||
},
|
||||
})
|
||||
return
|
||||
|
||||
@@ -216,6 +216,13 @@ func DashboardListModels(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func EnabledListModels(c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": model.GetEnabledModels(),
|
||||
})
|
||||
}
|
||||
|
||||
func RetrieveModel(c *gin.Context) {
|
||||
modelId := c.Param("model")
|
||||
if aiModel, ok := openAIModelsMap[modelId]; ok {
|
||||
|
||||
240
controller/oidc.go
Normal file
240
controller/oidc.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OidcResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type OidcUser struct {
|
||||
OpenID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
|
||||
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
|
||||
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oidcResponse OidcResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if oidcResponse.AccessToken == "" {
|
||||
common.SysError("OIDC 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysError("OIDC 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
var oidcUser OidcUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
||||
common.SysError("OIDC 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &oidcUser, nil
|
||||
}
|
||||
|
||||
func OidcAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
OidcBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
err := user.FillUserByOidcId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Email = oidcUser.Email
|
||||
if oidcUser.PreferredUsername != "" {
|
||||
user.Username = oidcUser.PreferredUsername
|
||||
} else {
|
||||
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if oidcUser.Name != "" {
|
||||
user.DisplayName = oidcUser.Name
|
||||
} else {
|
||||
user.DisplayName = "OIDC User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func OidcBind(c *gin.Context) {
|
||||
if !system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 OIDC 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.OidcId = oidcUser.OpenID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -51,6 +52,13 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "oidc.enabled":
|
||||
if option.Value == "true" && system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 OIDC 登录,请先填入 OIDC Client Id 以及 OIDC Client Secret!",
|
||||
})
|
||||
}
|
||||
case "LinuxDOOAuthEnabled":
|
||||
if option.Value == "true" && common.LinuxDOClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -2,9 +2,9 @@ package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
)
|
||||
|
||||
func GetPricing(c *gin.Context) {
|
||||
@@ -40,7 +40,7 @@ func GetPricing(c *gin.Context) {
|
||||
}
|
||||
|
||||
func ResetModelRatio(c *gin.Context) {
|
||||
defaultStr := common.DefaultModelRatio2JSONString()
|
||||
defaultStr := operation_setting.DefaultModelRatio2JSONString()
|
||||
err := model.UpdateOption("ModelRatio", defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
@@ -49,7 +49,7 @@ func ResetModelRatio(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
err = common.UpdateModelRatioByJSONString(defaultStr)
|
||||
err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"one-api/relay"
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
@@ -41,15 +42,6 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
return err
|
||||
}
|
||||
|
||||
func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
var err *dto.OpenAIErrorWithStatusCode
|
||||
switch relayMode {
|
||||
default:
|
||||
err = relay.TextHelper(c)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
@@ -110,7 +102,7 @@ func WssRelay(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
service.WssError(c, ws, openaiErr.Error)
|
||||
helper.WssError(c, ws, openaiErr.Error)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -152,7 +144,51 @@ func WssRelay(c *gin.Context) {
|
||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||
service.WssError(c, ws, openaiErr.Error)
|
||||
helper.WssError(c, ws, openaiErr.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func RelayClaude(c *gin.Context) {
|
||||
//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
var claudeErr *dto.ClaudeErrorWithStatusCode
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
break
|
||||
}
|
||||
|
||||
claudeErr = claudeRequest(c, channel)
|
||||
|
||||
if claudeErr == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
openaiErr := service.ClaudeErrorToOpenAIError(claudeErr)
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if claudeErr != nil {
|
||||
claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId)
|
||||
c.JSON(claudeErr.StatusCode, gin.H{
|
||||
"type": "error",
|
||||
"error": claudeErr.Error,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,6 +206,13 @@ func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *mode
|
||||
return relay.WssHelper(c, ws)
|
||||
}
|
||||
|
||||
func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relay.ClaudeHelper(c)
|
||||
}
|
||||
|
||||
func addUsedChannel(c *gin.Context, channelId int) {
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
|
||||
@@ -159,7 +159,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
} else {
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
err = model.IncreaseUserQuota(task.UserId, quota)
|
||||
err = model.IncreaseUserQuota(task.UserId, quota, false)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
|
||||
@@ -2,9 +2,6 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/Calcium-Ion/go-epay/epay"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"log"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
@@ -14,16 +11,21 @@ import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Calcium-Ion/go-epay/epay"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
type EpayRequest struct {
|
||||
Amount int `json:"amount"`
|
||||
Amount int64 `json:"amount"`
|
||||
PaymentMethod string `json:"payment_method"`
|
||||
TopUpCode string `json:"top_up_code"`
|
||||
}
|
||||
|
||||
type AmountRequest struct {
|
||||
Amount int `json:"amount"`
|
||||
Amount int64 `json:"amount"`
|
||||
TopUpCode string `json:"top_up_code"`
|
||||
}
|
||||
|
||||
@@ -41,25 +43,35 @@ func GetEpayClient() *epay.Client {
|
||||
return withUrl
|
||||
}
|
||||
|
||||
func getPayMoney(amount float64, group string) float64 {
|
||||
func getPayMoney(amount int64, group string) float64 {
|
||||
dAmount := decimal.NewFromInt(amount)
|
||||
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / common.QuotaPerUnit
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
dAmount = dAmount.Div(dQuotaPerUnit)
|
||||
}
|
||||
// 别问为什么用float64,问就是这么点钱没必要
|
||||
|
||||
topupGroupRatio := common.GetTopupGroupRatio(group)
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
payMoney := amount * setting.Price * topupGroupRatio
|
||||
return payMoney
|
||||
|
||||
dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
|
||||
dPrice := decimal.NewFromFloat(setting.Price)
|
||||
|
||||
payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio)
|
||||
|
||||
return payMoney.InexactFloat64()
|
||||
}
|
||||
|
||||
func getMinTopup() int {
|
||||
func getMinTopup() int64 {
|
||||
minTopup := setting.MinTopUp
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
minTopup = minTopup * int(common.QuotaPerUnit)
|
||||
dMinTopup := decimal.NewFromInt(int64(minTopup))
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
minTopup = int(dMinTopup.Mul(dQuotaPerUnit).IntPart())
|
||||
}
|
||||
return minTopup
|
||||
return int64(minTopup)
|
||||
}
|
||||
|
||||
func RequestEpay(c *gin.Context) {
|
||||
@@ -80,7 +92,7 @@ func RequestEpay(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
payMoney := getPayMoney(float64(req.Amount), group)
|
||||
payMoney := getPayMoney(req.Amount, group)
|
||||
if payMoney < 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
@@ -118,7 +130,9 @@ func RequestEpay(c *gin.Context) {
|
||||
}
|
||||
amount := req.Amount
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / int(common.QuotaPerUnit)
|
||||
dAmount := decimal.NewFromInt(int64(amount))
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
amount = dAmount.Div(dQuotaPerUnit).IntPart()
|
||||
}
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
@@ -210,13 +224,16 @@ func EpayNotify(c *gin.Context) {
|
||||
}
|
||||
//user, _ := model.GetUserById(topUp.UserId, false)
|
||||
//user.Quota += topUp.Amount * 500000
|
||||
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
|
||||
dAmount := decimal.NewFromInt(int64(topUp.Amount))
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
quotaToAdd := int(dAmount.Mul(dQuotaPerUnit).IntPart())
|
||||
err = model.IncreaseUserQuota(topUp.UserId, quotaToAdd, true)
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新用户失败: %v", topUp)
|
||||
return
|
||||
}
|
||||
log.Printf("易支付回调更新用户成功 %v", topUp)
|
||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money))
|
||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money))
|
||||
}
|
||||
} else {
|
||||
log.Printf("易支付异常回调: %v", verifyInfo)
|
||||
@@ -241,7 +258,7 @@ func RequestAmount(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
payMoney := getPayMoney(float64(req.Amount), group)
|
||||
payMoney := getPayMoney(req.Amount, group)
|
||||
if payMoney <= 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
|
||||
212
dto/claude.go
Normal file
212
dto/claude.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package dto
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type ClaudeMediaMessage struct {
|
||||
Type string `json:"type"`
|
||||
Text *string `json:"text,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Source *ClaudeMessageSource `json:"source,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
PartialJson *string `json:"partial_json,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
// tool_calls
|
||||
Id string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
ToolUseId string `json:"tool_use_id,omitempty"`
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) SetText(s string) {
|
||||
c.Text = &s
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) GetText() string {
|
||||
if c.Text == nil {
|
||||
return ""
|
||||
}
|
||||
return *c.Text
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) IsStringContent() bool {
|
||||
var content string
|
||||
return json.Unmarshal(c.Content, &content) == nil
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) GetStringContent() string {
|
||||
var content string
|
||||
if err := json.Unmarshal(c.Content, &content); err == nil {
|
||||
return content
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) SetContent(content any) {
|
||||
jsonContent, _ := json.Marshal(content)
|
||||
c.Content = jsonContent
|
||||
}
|
||||
|
||||
func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
|
||||
var mediaContent []ClaudeMediaMessage
|
||||
if err := json.Unmarshal(c.Content, &mediaContent); err == nil {
|
||||
return mediaContent
|
||||
}
|
||||
return make([]ClaudeMediaMessage, 0)
|
||||
}
|
||||
|
||||
type ClaudeMessageSource struct {
|
||||
Type string `json:"type"`
|
||||
MediaType string `json:"media_type"`
|
||||
Data any `json:"data"`
|
||||
}
|
||||
|
||||
type ClaudeMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) IsStringContent() bool {
|
||||
_, ok := c.Content.(string)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) GetStringContent() string {
|
||||
if c.IsStringContent() {
|
||||
return c.Content.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) SetStringContent(content string) {
|
||||
c.Content = content
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
|
||||
// map content to []ClaudeMediaMessage
|
||||
// parse to json
|
||||
jsonContent, _ := json.Marshal(c.Content)
|
||||
var contentList []ClaudeMediaMessage
|
||||
err := json.Unmarshal(jsonContent, &contentList)
|
||||
if err != nil {
|
||||
return make([]ClaudeMediaMessage, 0), err
|
||||
}
|
||||
return contentList, nil
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]interface{} `json:"input_schema"`
|
||||
}
|
||||
|
||||
type InputSchema struct {
|
||||
Type string `json:"type"`
|
||||
Properties any `json:"properties,omitempty"`
|
||||
Required any `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
System any `json:"system,omitempty"`
|
||||
Messages []ClaudeMessage `json:"messages,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Thinking *Thinking `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
type Thinking struct {
|
||||
Type string `json:"type"`
|
||||
BudgetTokens int `json:"budget_tokens"`
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) IsStringSystem() bool {
|
||||
_, ok := c.System.(string)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) GetStringSystem() string {
|
||||
if c.IsStringSystem() {
|
||||
return c.System.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) SetStringSystem(system string) {
|
||||
c.System = system
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
|
||||
// map content to []ClaudeMediaMessage
|
||||
// parse to json
|
||||
jsonContent, _ := json.Marshal(c.System)
|
||||
var contentList []ClaudeMediaMessage
|
||||
if err := json.Unmarshal(jsonContent, &contentList); err == nil {
|
||||
return contentList
|
||||
}
|
||||
return make([]ClaudeMediaMessage, 0)
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ClaudeErrorWithStatusCode struct {
|
||||
Error ClaudeError `json:"error"`
|
||||
StatusCode int `json:"status_code"`
|
||||
LocalError bool
|
||||
}
|
||||
|
||||
type ClaudeResponse struct {
|
||||
Id string `json:"id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content []ClaudeMediaMessage `json:"content,omitempty"`
|
||||
Completion string `json:"completion,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Error *ClaudeError `json:"error,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
||||
Delta *ClaudeMediaMessage `json:"delta,omitempty"`
|
||||
Message *ClaudeMediaMessage `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// set index
|
||||
func (c *ClaudeResponse) SetIndex(i int) {
|
||||
c.Index = &i
|
||||
}
|
||||
|
||||
// get index
|
||||
func (c *ClaudeResponse) GetIndex() int {
|
||||
if c.Index == nil {
|
||||
return 0
|
||||
}
|
||||
return *c.Index
|
||||
}
|
||||
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
@@ -18,50 +18,52 @@ type FormatJsonSchema struct {
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Prefix any `json:"prefix,omitempty"`
|
||||
Suffix any `json:"suffix,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
EncodingFormat any `json:"encoding_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools []ToolCall `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities any `json:"modalities,omitempty"`
|
||||
Audio any `json:"audio,omitempty"`
|
||||
ExtraBody any `json:"extra_body,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Prefix any `json:"prefix,omitempty"`
|
||||
Suffix any `json:"suffix,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
EncodingFormat any `json:"encoding_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools []ToolCallRequest `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities any `json:"modalities,omitempty"`
|
||||
Audio any `json:"audio,omitempty"`
|
||||
ExtraBody any `json:"extra_body,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAITools struct {
|
||||
Type string `json:"type"`
|
||||
Function OpenAIFunction `json:"function"`
|
||||
type ToolCallRequest struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Function FunctionRequest `json:"function"`
|
||||
}
|
||||
|
||||
type OpenAIFunction struct {
|
||||
type FunctionRequest struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Parameters any `json:"parameters,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type StreamOptions struct {
|
||||
@@ -97,6 +99,7 @@ type Message struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
Prefix *bool `json:"prefix,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Reasoning string `json:"reasoning,omitempty"`
|
||||
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
parsedContent []MediaContent
|
||||
@@ -137,11 +140,11 @@ func (m *Message) SetPrefix(prefix bool) {
|
||||
m.Prefix = &prefix
|
||||
}
|
||||
|
||||
func (m *Message) ParseToolCalls() []ToolCall {
|
||||
func (m *Message) ParseToolCalls() []ToolCallRequest {
|
||||
if m.ToolCalls == nil {
|
||||
return nil
|
||||
}
|
||||
var toolCalls []ToolCall
|
||||
var toolCalls []ToolCallRequest
|
||||
if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
@@ -62,10 +62,11 @@ type ChatCompletionsStreamResponseChoice struct {
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseChoiceDelta struct {
|
||||
Content *string `json:"content,omitempty"`
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
Content *string `json:"content,omitempty"`
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
Reasoning *string `json:"reasoning,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
ToolCalls []ToolCallResponse `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
|
||||
@@ -80,34 +81,38 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string {
|
||||
if c.ReasoningContent == nil {
|
||||
if c.ReasoningContent == nil && c.Reasoning == nil {
|
||||
return ""
|
||||
}
|
||||
return *c.ReasoningContent
|
||||
if c.ReasoningContent != nil {
|
||||
return *c.ReasoningContent
|
||||
}
|
||||
return *c.Reasoning
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
|
||||
c.ReasoningContent = &s
|
||||
c.Reasoning = &s
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
type ToolCallResponse struct {
|
||||
// Index is not nil only in chat completion chunk object
|
||||
Index *int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type any `json:"type"`
|
||||
Function FunctionCall `json:"function"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type any `json:"type"`
|
||||
Function FunctionResponse `json:"function"`
|
||||
}
|
||||
|
||||
func (c *ToolCall) SetIndex(i int) {
|
||||
func (c *ToolCallResponse) SetIndex(i int) {
|
||||
c.Index = &i
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
type FunctionResponse struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
// call function with arguments in JSON format
|
||||
Parameters any `json:"parameters,omitempty"` // request
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
Arguments string `json:"arguments"` // response
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponse struct {
|
||||
@@ -120,6 +125,20 @@ type ChatCompletionsStreamResponse struct {
|
||||
Usage *Usage `json:"usage"`
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
|
||||
if len(c.Choices) == 0 {
|
||||
return false
|
||||
}
|
||||
return len(c.Choices[0].Delta.ToolCalls) > 0
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse {
|
||||
if c.IsToolCall() {
|
||||
return &c.Choices[0].Delta.ToolCalls[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
|
||||
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
|
||||
copy(choices, c.Choices)
|
||||
@@ -161,6 +180,7 @@ type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
|
||||
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
|
||||
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
|
||||
}
|
||||
|
||||
@@ -44,10 +44,11 @@ type RealtimeUsage struct {
|
||||
}
|
||||
|
||||
type InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ImageTokens int `json:"image_tokens"`
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
CachedCreationTokens int
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ImageTokens int `json:"image_tokens"`
|
||||
}
|
||||
|
||||
type OutputTokenDetails struct {
|
||||
|
||||
14
go.mod
14
go.mod
@@ -11,6 +11,7 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
||||
github.com/bytedance/sonic v1.11.6
|
||||
github.com/gin-contrib/cors v1.7.2
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
@@ -22,15 +23,15 @@ require (
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/jinzhu/copier v0.4.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pkoukk/tiktoken-go v0.1.7
|
||||
github.com/samber/lo v1.39.0
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
golang.org/x/crypto v0.27.0
|
||||
github.com/shopspring/decimal v1.4.0
|
||||
golang.org/x/crypto v0.35.0
|
||||
golang.org/x/image v0.23.0
|
||||
golang.org/x/net v0.28.0
|
||||
golang.org/x/net v0.35.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
gorm.io/gorm v1.25.2
|
||||
@@ -42,7 +43,6 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
|
||||
github.com/aws/smithy-go v1.20.2 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
@@ -84,9 +84,9 @@ require (
|
||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||
golang.org/x/arch v0.12.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||
golang.org/x/sync v0.10.0 // indirect
|
||||
golang.org/x/sys v0.27.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
golang.org/x/text v0.22.0 // indirect
|
||||
google.golang.org/protobuf v1.34.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.22.5 // indirect
|
||||
|
||||
24
go.sum
24
go.sum
@@ -117,8 +117,6 @@ github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
|
||||
github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
|
||||
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
@@ -183,6 +181,8 @@ github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
|
||||
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
@@ -217,18 +217,18 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
|
||||
golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg=
|
||||
golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
|
||||
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
|
||||
golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
|
||||
golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
|
||||
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
|
||||
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
|
||||
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -239,14 +239,14 @@ golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
|
||||
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
|
||||
@@ -174,6 +174,14 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
c.Request.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
// 检查path包含/v1/messages
|
||||
if strings.Contains(c.Request.URL.Path, "/v1/messages") {
|
||||
// 从x-api-key中获取key
|
||||
key := c.Request.Header.Get("x-api-key")
|
||||
if key != "" {
|
||||
c.Request.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
}
|
||||
key := c.Request.Header.Get("Authorization")
|
||||
parts := make([]string, 0)
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
@@ -199,15 +207,19 @@ func TokenAuth() func(c *gin.Context) {
|
||||
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
userEnabled, err := model.IsUserEnabled(token.UserId, false)
|
||||
userCache, err := model.GetUserCache(token.UserId)
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
userEnabled := userCache.Status == common.UserStatusEnabled
|
||||
if !userEnabled {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
return
|
||||
}
|
||||
|
||||
userCache.WriteContext(c)
|
||||
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_key", token.Key)
|
||||
|
||||
@@ -32,7 +32,6 @@ func Distribute() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
var channel *model.Channel
|
||||
channelId, ok := c.Get("specific_channel_id")
|
||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||
@@ -40,7 +39,7 @@ func Distribute() func(c *gin.Context) {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
||||
return
|
||||
}
|
||||
userGroup, _ := model.GetUserGroup(userId, false)
|
||||
userGroup := c.GetString(constant.ContextKeyUserGroup)
|
||||
tokenGroup := c.GetString("token_group")
|
||||
if tokenGroup != "" {
|
||||
// check common.UserUsableGroups[userGroup]
|
||||
|
||||
@@ -51,7 +51,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max
|
||||
// 如果在时间窗口内已达到限制,拒绝请求
|
||||
subTime := nowTime.Sub(oldTime).Seconds()
|
||||
if int64(subTime) < duration {
|
||||
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
||||
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC
|
||||
now := time.Now().Format(timeFormat)
|
||||
rdb.LPush(ctx, key, now)
|
||||
rdb.LTrim(ctx, key, 0, int64(maxCount-1))
|
||||
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
||||
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
|
||||
}
|
||||
|
||||
// Redis限流处理器
|
||||
@@ -118,7 +118,7 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
|
||||
|
||||
// 内存限流处理器
|
||||
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
|
||||
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
|
||||
inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
userId := strconv.Itoa(c.GetInt("id"))
|
||||
@@ -153,20 +153,23 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int)
|
||||
|
||||
// ModelRequestRateLimit 模型请求限流中间件
|
||||
func ModelRequestRateLimit() func(c *gin.Context) {
|
||||
// 如果未启用限流,直接放行
|
||||
if !setting.ModelRequestRateLimitEnabled {
|
||||
return defNext
|
||||
}
|
||||
return func(c *gin.Context) {
|
||||
// 在每个请求时检查是否启用限流
|
||||
if !setting.ModelRequestRateLimitEnabled {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 计算限流参数
|
||||
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
|
||||
totalMaxCount := setting.ModelRequestRateLimitCount
|
||||
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
||||
// 计算限流参数
|
||||
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
|
||||
totalMaxCount := setting.ModelRequestRateLimitCount
|
||||
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
||||
|
||||
// 根据存储类型选择限流处理器
|
||||
if common.RedisEnabled {
|
||||
return redisRateLimitHandler(duration, totalMaxCount, successMaxCount)
|
||||
} else {
|
||||
return memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)
|
||||
// 根据存储类型选择并执行限流处理器
|
||||
if common.RedisEnabled {
|
||||
redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
|
||||
} else {
|
||||
memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ type Channel struct {
|
||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||
OtherInfo string `json:"other_info"`
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
Setting string `json:"setting" gorm:"type:text"`
|
||||
Setting *string `json:"setting" gorm:"type:text"`
|
||||
}
|
||||
|
||||
func (channel *Channel) GetModels() []string {
|
||||
@@ -290,35 +290,42 @@ func (channel *Channel) Delete() error {
|
||||
|
||||
var channelStatusLock sync.Mutex
|
||||
|
||||
func UpdateChannelStatusById(id int, status int, reason string) {
|
||||
func UpdateChannelStatusById(id int, status int, reason string) bool {
|
||||
if common.MemoryCacheEnabled {
|
||||
channelStatusLock.Lock()
|
||||
defer channelStatusLock.Unlock()
|
||||
|
||||
channelCache, _ := CacheGetChannel(id)
|
||||
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
||||
if channelCache != nil && channelCache.Status == status {
|
||||
channelStatusLock.Unlock()
|
||||
return
|
||||
return false
|
||||
}
|
||||
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
|
||||
if channelCache == nil && status != common.ChannelStatusEnabled {
|
||||
channelStatusLock.Unlock()
|
||||
return
|
||||
return false
|
||||
}
|
||||
CacheUpdateChannelStatus(id, status)
|
||||
channelStatusLock.Unlock()
|
||||
}
|
||||
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
|
||||
if err != nil {
|
||||
common.SysError("failed to update ability status: " + err.Error())
|
||||
return false
|
||||
}
|
||||
channel, err := GetChannelById(id, true)
|
||||
if err != nil {
|
||||
// find channel by id error, directly update status
|
||||
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update channel status: " + err.Error())
|
||||
result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status)
|
||||
if result.Error != nil {
|
||||
common.SysError("failed to update channel status: " + result.Error.Error())
|
||||
return false
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if channel.Status == status {
|
||||
return false
|
||||
}
|
||||
// find channel by id success, update status and other info
|
||||
info := channel.GetOtherInfo()
|
||||
info["status_reason"] = reason
|
||||
@@ -328,9 +335,10 @@ func UpdateChannelStatusById(id int, status int, reason string) {
|
||||
err = channel.Save()
|
||||
if err != nil {
|
||||
common.SysError("failed to update channel status: " + err.Error())
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func EnableChannelByTag(tag string) error {
|
||||
@@ -485,8 +493,8 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
|
||||
func (channel *Channel) GetSetting() map[string]interface{} {
|
||||
setting := make(map[string]interface{})
|
||||
if channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(channel.Setting), &setting)
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
}
|
||||
@@ -500,7 +508,7 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
return
|
||||
}
|
||||
channel.Setting = string(settingBytes)
|
||||
channel.Setting = common.GetPointer[string](string(settingBytes))
|
||||
}
|
||||
|
||||
func GetChannelsByIds(ids []int) ([]*Channel, error) {
|
||||
|
||||
13
model/log.go
13
model/log.go
@@ -1,13 +1,14 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -18,7 +19,7 @@ type Log struct {
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
|
||||
Type int `json:"type" gorm:"index:idx_created_at_type"`
|
||||
Content string `json:"content"`
|
||||
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
|
||||
Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
|
||||
TokenName string `json:"token_name" gorm:"index;default:''"`
|
||||
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
|
||||
Quota int `json:"quota" gorm:"default:0"`
|
||||
@@ -87,14 +88,14 @@ func RecordLog(userId int, logType int, content string) {
|
||||
}
|
||||
}
|
||||
|
||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int,
|
||||
func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
|
||||
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
|
||||
isStream bool, group string, other map[string]interface{}) {
|
||||
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||
if !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
username, _ := GetUsernameById(userId, false)
|
||||
username := c.GetString("username")
|
||||
otherStr := common.MapToJsonStr(other)
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
@@ -116,7 +117,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
if err != nil {
|
||||
common.LogError(ctx, "failed to record log: "+err.Error())
|
||||
common.LogError(c, "failed to record log: "+err.Error())
|
||||
}
|
||||
if common.DataExportEnabled {
|
||||
gopool.Go(func() {
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"one-api/common"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var groupCol string
|
||||
@@ -60,7 +61,7 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
||||
}()
|
||||
dsn := os.Getenv(envName)
|
||||
if dsn != "" {
|
||||
if strings.HasPrefix(dsn, "postgres://") {
|
||||
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
||||
// Use PostgreSQL
|
||||
common.SysLog("using PostgreSQL as database")
|
||||
common.UsingPostgreSQL = true
|
||||
|
||||
@@ -3,6 +3,8 @@ package model
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/setting"
|
||||
"one-api/setting/config"
|
||||
"one-api/setting/operation_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -23,6 +25,8 @@ func AllOption() ([]*Option, error) {
|
||||
func InitOptionMap() {
|
||||
common.OptionMapRWMutex.Lock()
|
||||
common.OptionMap = make(map[string]string)
|
||||
|
||||
// 添加原有的系统配置
|
||||
common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
|
||||
common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
|
||||
common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
|
||||
@@ -84,18 +88,19 @@ func InitOptionMap() {
|
||||
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
|
||||
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
|
||||
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
|
||||
common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
|
||||
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
|
||||
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
|
||||
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
|
||||
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
|
||||
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
|
||||
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
|
||||
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
|
||||
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
||||
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
|
||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||
common.OptionMap["ChatLink"] = common.ChatLink
|
||||
common.OptionMap["ChatLink2"] = common.ChatLink2
|
||||
//common.OptionMap["ChatLink"] = common.ChatLink
|
||||
//common.OptionMap["ChatLink2"] = common.ChatLink2
|
||||
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
|
||||
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
|
||||
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
|
||||
@@ -107,14 +112,20 @@ func InitOptionMap() {
|
||||
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
|
||||
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
|
||||
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
|
||||
common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled)
|
||||
common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled)
|
||||
common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled)
|
||||
common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
|
||||
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
|
||||
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
|
||||
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
|
||||
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
||||
common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
|
||||
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
|
||||
|
||||
// 自动添加所有注册的模型配置
|
||||
modelConfigs := config.GlobalConfig.ExportAllConfigs()
|
||||
for k, v := range modelConfigs {
|
||||
common.OptionMap[k] = v
|
||||
}
|
||||
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
loadOptionsFromDatabase()
|
||||
@@ -158,6 +169,13 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.OptionMapRWMutex.Lock()
|
||||
defer common.OptionMapRWMutex.Unlock()
|
||||
common.OptionMap[key] = value
|
||||
|
||||
// 检查是否是模型配置 - 使用更规范的方式处理
|
||||
if handleConfigUpdate(key, value) {
|
||||
return nil // 已由配置系统处理
|
||||
}
|
||||
|
||||
// 处理传统配置项...
|
||||
if strings.HasSuffix(key, "Permission") {
|
||||
intValue, _ := strconv.Atoi(value)
|
||||
switch key {
|
||||
@@ -227,14 +245,13 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "CheckSensitiveEnabled":
|
||||
setting.CheckSensitiveEnabled = boolValue
|
||||
case "DemoSiteEnabled":
|
||||
setting.DemoSiteEnabled = boolValue
|
||||
operation_setting.DemoSiteEnabled = boolValue
|
||||
case "SelfUseModeEnabled":
|
||||
operation_setting.SelfUseModeEnabled = boolValue
|
||||
case "CheckSensitiveOnPromptEnabled":
|
||||
setting.CheckSensitiveOnPromptEnabled = boolValue
|
||||
case "ModelRequestRateLimitEnabled":
|
||||
setting.ModelRequestRateLimitEnabled = boolValue
|
||||
|
||||
//case "CheckSensitiveOnCompletionEnabled":
|
||||
// constant.CheckSensitiveOnCompletionEnabled = boolValue
|
||||
case "StopOnSensitiveEnabled":
|
||||
setting.StopOnSensitiveEnabled = boolValue
|
||||
case "SMTPSSLEnabled":
|
||||
@@ -313,7 +330,7 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.QuotaForInvitee, _ = strconv.Atoi(value)
|
||||
case "QuotaRemindThreshold":
|
||||
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
|
||||
case "ShouldPreConsumedQuota":
|
||||
case "PreConsumedQuota":
|
||||
common.PreConsumedQuota, _ = strconv.Atoi(value)
|
||||
case "ModelRequestRateLimitCount":
|
||||
setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
|
||||
@@ -328,21 +345,23 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "DataExportDefaultTime":
|
||||
common.DataExportDefaultTime = value
|
||||
case "ModelRatio":
|
||||
err = common.UpdateModelRatioByJSONString(value)
|
||||
err = operation_setting.UpdateModelRatioByJSONString(value)
|
||||
case "GroupRatio":
|
||||
err = setting.UpdateGroupRatioByJSONString(value)
|
||||
case "UserUsableGroups":
|
||||
err = setting.UpdateUserUsableGroupsByJSONString(value)
|
||||
case "CompletionRatio":
|
||||
err = common.UpdateCompletionRatioByJSONString(value)
|
||||
err = operation_setting.UpdateCompletionRatioByJSONString(value)
|
||||
case "ModelPrice":
|
||||
err = common.UpdateModelPriceByJSONString(value)
|
||||
err = operation_setting.UpdateModelPriceByJSONString(value)
|
||||
case "CacheRatio":
|
||||
err = operation_setting.UpdateCacheRatioByJSONString(value)
|
||||
case "TopUpLink":
|
||||
common.TopUpLink = value
|
||||
case "ChatLink":
|
||||
common.ChatLink = value
|
||||
case "ChatLink2":
|
||||
common.ChatLink2 = value
|
||||
//case "ChatLink":
|
||||
// common.ChatLink = value
|
||||
//case "ChatLink2":
|
||||
// common.ChatLink2 = value
|
||||
case "ChannelDisableThreshold":
|
||||
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
|
||||
case "QuotaPerUnit":
|
||||
@@ -350,9 +369,34 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "SensitiveWords":
|
||||
setting.SensitiveWordsFromString(value)
|
||||
case "AutomaticDisableKeywords":
|
||||
setting.AutomaticDisableKeywordsFromString(value)
|
||||
operation_setting.AutomaticDisableKeywordsFromString(value)
|
||||
case "StreamCacheQueueLength":
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// handleConfigUpdate 处理分层配置更新,返回是否已处理
|
||||
func handleConfigUpdate(key, value string) bool {
|
||||
parts := strings.SplitN(key, ".", 2)
|
||||
if len(parts) != 2 {
|
||||
return false // 不是分层配置
|
||||
}
|
||||
|
||||
configName := parts[0]
|
||||
configKey := parts[1]
|
||||
|
||||
// 获取配置对象
|
||||
cfg := config.GlobalConfig.Get(configName)
|
||||
if cfg == nil {
|
||||
return false // 未注册的配置
|
||||
}
|
||||
|
||||
// 更新配置
|
||||
configMap := map[string]string{
|
||||
configKey: value,
|
||||
}
|
||||
config.UpdateConfigFromMap(cfg, configMap)
|
||||
|
||||
return true // 已处理
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/setting/operation_setting"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -64,13 +65,14 @@ func updatePricing() {
|
||||
ModelName: model,
|
||||
EnableGroup: groups,
|
||||
}
|
||||
modelPrice, findPrice := common.GetModelPrice(model, false)
|
||||
modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
|
||||
if findPrice {
|
||||
pricing.ModelPrice = modelPrice
|
||||
pricing.QuotaType = 1
|
||||
} else {
|
||||
pricing.ModelRatio = common.GetModelRatio(model)
|
||||
pricing.CompletionRatio = common.GetCompletionRatio(model)
|
||||
modelRatio, _ := operation_setting.GetModelRatio(model)
|
||||
pricing.ModelRatio = modelRatio
|
||||
pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
|
||||
pricing.QuotaType = 0
|
||||
}
|
||||
pricingMap = append(pricingMap, pricing)
|
||||
|
||||
@@ -3,7 +3,7 @@ package model
|
||||
type TopUp struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int `json:"amount"`
|
||||
Amount int64 `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -24,6 +23,7 @@ type User struct {
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
|
||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||
@@ -320,7 +320,7 @@ func (user *User) Insert(inviterId int) error {
|
||||
}
|
||||
if inviterId != 0 {
|
||||
if common.QuotaForInvitee > 0 {
|
||||
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
|
||||
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
|
||||
}
|
||||
if common.QuotaForInviter > 0 {
|
||||
@@ -442,6 +442,14 @@ func (user *User) FillUserByGitHubId() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByOidcId() error {
|
||||
if user.OidcId == "" {
|
||||
return errors.New("oidc id 为空!")
|
||||
}
|
||||
DB.Where(User{OidcId: user.OidcId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByWeChatId() error {
|
||||
if user.WeChatId == "" {
|
||||
return errors.New("WeChat id 为空!")
|
||||
@@ -473,6 +481,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
|
||||
return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsOidcIdAlreadyTaken(oidcId string) bool {
|
||||
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsTelegramIdAlreadyTaken(telegramId string) bool {
|
||||
return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
@@ -502,35 +514,35 @@ func IsAdmin(userId int) bool {
|
||||
return user.Role >= common.RoleAdminUser
|
||||
}
|
||||
|
||||
// IsUserEnabled checks user status from Redis first, falls back to DB if needed
|
||||
func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
|
||||
defer func() {
|
||||
// Update Redis cache asynchronously on successful DB read
|
||||
if shouldUpdateRedis(fromDB, err) {
|
||||
gopool.Go(func() {
|
||||
if err := updateUserStatusCache(id, status); err != nil {
|
||||
common.SysError("failed to update user status cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
if !fromDB && common.RedisEnabled {
|
||||
// Try Redis first
|
||||
status, err := getUserStatusCache(id)
|
||||
if err == nil {
|
||||
return status == common.UserStatusEnabled, nil
|
||||
}
|
||||
// Don't return error - fall through to DB
|
||||
}
|
||||
fromDB = true
|
||||
var user User
|
||||
err = DB.Where("id = ?", id).Select("status").Find(&user).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return user.Status == common.UserStatusEnabled, nil
|
||||
}
|
||||
//// IsUserEnabled checks user status from Redis first, falls back to DB if needed
|
||||
//func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
|
||||
// defer func() {
|
||||
// // Update Redis cache asynchronously on successful DB read
|
||||
// if shouldUpdateRedis(fromDB, err) {
|
||||
// gopool.Go(func() {
|
||||
// if err := updateUserStatusCache(id, status); err != nil {
|
||||
// common.SysError("failed to update user status cache: " + err.Error())
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }()
|
||||
// if !fromDB && common.RedisEnabled {
|
||||
// // Try Redis first
|
||||
// status, err := getUserStatusCache(id)
|
||||
// if err == nil {
|
||||
// return status == common.UserStatusEnabled, nil
|
||||
// }
|
||||
// // Don't return error - fall through to DB
|
||||
// }
|
||||
// fromDB = true
|
||||
// var user User
|
||||
// err = DB.Where("id = ?", id).Select("status").Find(&user).Error
|
||||
// if err != nil {
|
||||
// return false, err
|
||||
// }
|
||||
//
|
||||
// return user.Status == common.UserStatusEnabled, nil
|
||||
//}
|
||||
|
||||
func ValidateAccessToken(token string) (user *User) {
|
||||
if token == "" {
|
||||
@@ -639,7 +651,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
|
||||
return common.StrToMap(setting), nil
|
||||
}
|
||||
|
||||
func IncreaseUserQuota(id int, quota int) (err error) {
|
||||
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
@@ -649,7 +661,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
|
||||
common.SysError("failed to increase user quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
if common.BatchUpdateEnabled {
|
||||
if !db && common.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
|
||||
return nil
|
||||
}
|
||||
@@ -694,7 +706,7 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
|
||||
return nil
|
||||
}
|
||||
if delta > 0 {
|
||||
return IncreaseUserQuota(id, delta)
|
||||
return IncreaseUserQuota(id, delta, false)
|
||||
} else {
|
||||
return DecreaseUserQuota(id, -delta)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package model
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"time"
|
||||
@@ -21,6 +22,15 @@ type UserBase struct {
|
||||
Setting string `json:"setting"`
|
||||
}
|
||||
|
||||
func (user *UserBase) WriteContext(c *gin.Context) {
|
||||
c.Set(constant.ContextKeyUserGroup, user.Group)
|
||||
c.Set(constant.ContextKeyUserQuota, user.Quota)
|
||||
c.Set(constant.ContextKeyUserStatus, user.Status)
|
||||
c.Set(constant.ContextKeyUserEmail, user.Email)
|
||||
c.Set("username", user.Username)
|
||||
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
|
||||
}
|
||||
|
||||
func (user *UserBase) GetSetting() map[string]interface{} {
|
||||
if user.Setting == "" {
|
||||
return nil
|
||||
|
||||
@@ -13,7 +13,7 @@ type Adaptor interface {
|
||||
Init(info *relaycommon.RelayInfo)
|
||||
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
|
||||
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
||||
ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
||||
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
||||
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
|
||||
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
|
||||
@@ -22,6 +22,7 @@ type Adaptor interface {
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
||||
}
|
||||
|
||||
type TaskAdaptor interface {
|
||||
|
||||
@@ -16,6 +16,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
@@ -44,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
|
||||
}
|
||||
|
||||
func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, error, []byte) {
|
||||
url := fmt.Sprintf("/api/v1/tasks/%s", taskID)
|
||||
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID)
|
||||
|
||||
var aliResponse AliResponse
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
@@ -153,7 +154,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
helper.SetEventStreamHeaders(c)
|
||||
lastResponseText := ""
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"net/http"
|
||||
common2 "one-api/common"
|
||||
"one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/service"
|
||||
@@ -31,6 +32,9 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get request url failed: %w", err)
|
||||
}
|
||||
if common2.DebugEnabled {
|
||||
println("fullRequestURL:", fullRequestURL)
|
||||
}
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
@@ -130,7 +134,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo,
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
resp, err := doRequest(c, req, info.ToRelayInfo())
|
||||
resp, err := doRequest(c, req, info.RelayInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do request failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/claude"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting/model_setting"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -19,6 +20,10 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -38,19 +43,22 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
model_setting.GetClaudeSettings().WriteHeaders(info.OriginModelName, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
var claudeReq *claude.ClaudeRequest
|
||||
var claudeReq *dto.ClaudeRequest
|
||||
var err error
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
|
||||
|
||||
c.Set("request_model", request.Model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Set("request_model", claudeReq.Model)
|
||||
c.Set("converted_request", claudeReq)
|
||||
return claudeReq, err
|
||||
}
|
||||
@@ -64,7 +72,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,8 @@ var awsModelIDMap = map[string]string{
|
||||
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
}
|
||||
|
||||
var ChannelName = "aws"
|
||||
|
||||
@@ -1,24 +1,25 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/dto"
|
||||
)
|
||||
|
||||
type AwsClaudeRequest struct {
|
||||
// AnthropicVersion should be "bedrock-2023-05-31"
|
||||
AnthropicVersion string `json:"anthropic_version"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []claude.ClaudeMessage `json:"messages"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools []claude.Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
AnthropicVersion string `json:"anthropic_version"`
|
||||
System any `json:"system,omitempty"`
|
||||
Messages []dto.ClaudeMessage `json:"messages"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Thinking *dto.Thinking `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
|
||||
func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
|
||||
return &AwsClaudeRequest{
|
||||
AnthropicVersion: "bedrock-2023-05-31",
|
||||
System: req.System,
|
||||
@@ -30,5 +31,6 @@ func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
|
||||
StopSequences: req.StopSequences,
|
||||
Tools: req.Tools,
|
||||
ToolChoice: req.ToolChoice,
|
||||
Thinking: req.Thinking,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,9 +9,10 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
relaymodel "one-api/dto"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/claude"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -38,10 +39,10 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode {
|
||||
return &relaymodel.OpenAIErrorWithStatusCode{
|
||||
func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: relaymodel.OpenAIError{
|
||||
Error: dto.OpenAIError{
|
||||
Message: fmt.Sprintf("%s", err.Error()),
|
||||
},
|
||||
}
|
||||
@@ -55,7 +56,7 @@ func awsModelID(requestModel string) (string, error) {
|
||||
return requestModel, nil
|
||||
}
|
||||
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||
@@ -76,7 +77,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
if !ok {
|
||||
return wrapErr(errors.New("request not found")), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*claude.ClaudeRequest)
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
@@ -88,14 +89,14 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
|
||||
}
|
||||
|
||||
claudeResponse := new(claude.ClaudeResponse)
|
||||
claudeResponse := new(dto.ClaudeResponse)
|
||||
err = json.Unmarshal(awsResp.Body, claudeResponse)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
|
||||
}
|
||||
|
||||
openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
|
||||
usage := relaymodel.Usage{
|
||||
usage := dto.Usage{
|
||||
PromptTokens: claudeResponse.Usage.InputTokens,
|
||||
CompletionTokens: claudeResponse.Usage.OutputTokens,
|
||||
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
|
||||
@@ -106,7 +107,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||
@@ -127,7 +128,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
if !ok {
|
||||
return wrapErr(errors.New("request not found")), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*claude.ClaudeRequest)
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
@@ -143,11 +144,14 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
defer stream.Close()
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
var usage relaymodel.Usage
|
||||
var id string
|
||||
var model string
|
||||
claudeInfo := &claude.ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
isFirst := true
|
||||
createdTime := common.GetTimestamp()
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
event, ok := <-stream.Events()
|
||||
if !ok {
|
||||
@@ -160,33 +164,19 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
isFirst = false
|
||||
info.FirstResponseTime = time.Now()
|
||||
}
|
||||
claudeResp := new(claude.ClaudeResponse)
|
||||
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
|
||||
claudeResponse := new(dto.ClaudeResponse)
|
||||
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp)
|
||||
if claudeUsage != nil {
|
||||
usage.PromptTokens += claudeUsage.InputTokens
|
||||
usage.CompletionTokens += claudeUsage.OutputTokens
|
||||
}
|
||||
response := claude.StreamResponseClaude2OpenAI(requestMode, claudeResponse)
|
||||
|
||||
if response == nil {
|
||||
if !claude.FormatClaudeResponseInfo(RequestModeMessage, claudeResponse, response, claudeInfo) {
|
||||
return true
|
||||
}
|
||||
|
||||
if response.Id != "" {
|
||||
id = response.Id
|
||||
}
|
||||
if response.Model != "" {
|
||||
model = response.Model
|
||||
}
|
||||
response.Created = createdTime
|
||||
response.Id = id
|
||||
response.Model = model
|
||||
|
||||
jsonStr, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
@@ -202,19 +192,27 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
|
||||
if info.ShouldIncludeUsage {
|
||||
response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
|
||||
err := service.ObjectData(c, response)
|
||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
service.Done(c)
|
||||
helper.Done(c)
|
||||
if resp != nil {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
}
|
||||
return nil, &usage
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
@@ -16,6 +16,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -104,7 +110,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -138,7 +139,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
helper.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting/model_setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -21,6 +22,10 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -55,10 +60,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
anthropicVersion = "2023-06-01"
|
||||
}
|
||||
req.Set("anthropic-version", anthropicVersion)
|
||||
model_setting.GetClaudeSettings().WriteHeaders(info.OriginModelName, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ var ModelList = []string{
|
||||
"claude-3-5-haiku-20241022",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-3-7-sonnet-20250219-thinking",
|
||||
}
|
||||
|
||||
var ChannelName = "claude"
|
||||
|
||||
@@ -1,85 +1,95 @@
|
||||
package claude
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type ClaudeMediaMessage struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Source *ClaudeMessageSource `json:"source,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
PartialJson string `json:"partial_json,omitempty"`
|
||||
// tool_calls
|
||||
Id string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ToolUseId string `json:"tool_use_id,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeMessageSource struct {
|
||||
Type string `json:"type"`
|
||||
MediaType string `json:"media_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type ClaudeMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]interface{} `json:"input_schema"`
|
||||
}
|
||||
|
||||
type InputSchema struct {
|
||||
Type string `json:"type"`
|
||||
Properties any `json:"properties,omitempty"`
|
||||
Required any `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []ClaudeMessage `json:"messages,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ClaudeResponse struct {
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Content []ClaudeMediaMessage `json:"content"`
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Model string `json:"model"`
|
||||
Error ClaudeError `json:"error"`
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
Index int `json:"index"` // stream only
|
||||
ContentBlock *ClaudeMediaMessage `json:"content_block"`
|
||||
Delta *ClaudeMediaMessage `json:"delta"` // stream only
|
||||
Message *ClaudeResponse `json:"message"` // stream only: message_start
|
||||
}
|
||||
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
//
|
||||
//type ClaudeMetadata struct {
|
||||
// UserId string `json:"user_id"`
|
||||
//}
|
||||
//
|
||||
//type ClaudeMediaMessage struct {
|
||||
// Type string `json:"type"`
|
||||
// Text string `json:"text,omitempty"`
|
||||
// Source *ClaudeMessageSource `json:"source,omitempty"`
|
||||
// Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
// StopReason *string `json:"stop_reason,omitempty"`
|
||||
// PartialJson string `json:"partial_json,omitempty"`
|
||||
// Thinking string `json:"thinking,omitempty"`
|
||||
// Signature string `json:"signature,omitempty"`
|
||||
// Delta string `json:"delta,omitempty"`
|
||||
// // tool_calls
|
||||
// Id string `json:"id,omitempty"`
|
||||
// Name string `json:"name,omitempty"`
|
||||
// Input any `json:"input,omitempty"`
|
||||
// Content string `json:"content,omitempty"`
|
||||
// ToolUseId string `json:"tool_use_id,omitempty"`
|
||||
//}
|
||||
//
|
||||
//type ClaudeMessageSource struct {
|
||||
// Type string `json:"type"`
|
||||
// MediaType string `json:"media_type"`
|
||||
// Data string `json:"data"`
|
||||
//}
|
||||
//
|
||||
//type ClaudeMessage struct {
|
||||
// Role string `json:"role"`
|
||||
// Content any `json:"content"`
|
||||
//}
|
||||
//
|
||||
//type Tool struct {
|
||||
// Name string `json:"name"`
|
||||
// Description string `json:"description,omitempty"`
|
||||
// InputSchema map[string]interface{} `json:"input_schema"`
|
||||
//}
|
||||
//
|
||||
//type InputSchema struct {
|
||||
// Type string `json:"type"`
|
||||
// Properties any `json:"properties,omitempty"`
|
||||
// Required any `json:"required,omitempty"`
|
||||
//}
|
||||
//
|
||||
//type ClaudeRequest struct {
|
||||
// Model string `json:"model"`
|
||||
// Prompt string `json:"prompt,omitempty"`
|
||||
// System string `json:"system,omitempty"`
|
||||
// Messages []ClaudeMessage `json:"messages,omitempty"`
|
||||
// MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
// MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
|
||||
// StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
// Temperature *float64 `json:"temperature,omitempty"`
|
||||
// TopP float64 `json:"top_p,omitempty"`
|
||||
// TopK int `json:"top_k,omitempty"`
|
||||
// //ClaudeMetadata `json:"metadata,omitempty"`
|
||||
// Stream bool `json:"stream,omitempty"`
|
||||
// Tools any `json:"tools,omitempty"`
|
||||
// ToolChoice any `json:"tool_choice,omitempty"`
|
||||
// Thinking *Thinking `json:"thinking,omitempty"`
|
||||
//}
|
||||
//
|
||||
//type Thinking struct {
|
||||
// Type string `json:"type"`
|
||||
// BudgetTokens int `json:"budget_tokens"`
|
||||
//}
|
||||
//
|
||||
//type ClaudeError struct {
|
||||
// Type string `json:"type"`
|
||||
// Message string `json:"message"`
|
||||
//}
|
||||
//
|
||||
//type ClaudeResponse struct {
|
||||
// Id string `json:"id"`
|
||||
// Type string `json:"type"`
|
||||
// Content []ClaudeMediaMessage `json:"content"`
|
||||
// Completion string `json:"completion"`
|
||||
// StopReason string `json:"stop_reason"`
|
||||
// Model string `json:"model"`
|
||||
// Error ClaudeError `json:"error"`
|
||||
// Usage ClaudeUsage `json:"usage"`
|
||||
// Index int `json:"index"` // stream only
|
||||
// ContentBlock *ClaudeMediaMessage `json:"content_block"`
|
||||
// Delta *ClaudeMediaMessage `json:"delta"` // stream only
|
||||
// Message *ClaudeResponse `json:"message"` // stream only: message_start
|
||||
//}
|
||||
//
|
||||
//type ClaudeUsage struct {
|
||||
// InputTokens int `json:"input_tokens"`
|
||||
// OutputTokens int `json:"output_tokens"`
|
||||
//}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -28,9 +30,9 @@ func stopReasonClaude2OpenAI(reason string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
|
||||
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest {
|
||||
|
||||
claudeRequest := ClaudeRequest{
|
||||
claudeRequest := dto.ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
Prompt: "",
|
||||
StopSequences: nil,
|
||||
@@ -59,12 +61,12 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
|
||||
claudeTools := make([]Tool, 0, len(textRequest.Tools))
|
||||
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
|
||||
claudeTools := make([]dto.Tool, 0, len(textRequest.Tools))
|
||||
|
||||
for _, tool := range textRequest.Tools {
|
||||
if params, ok := tool.Function.Parameters.(map[string]any); ok {
|
||||
claudeTool := Tool{
|
||||
claudeTool := dto.Tool{
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
}
|
||||
@@ -82,7 +84,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
}
|
||||
}
|
||||
|
||||
claudeRequest := ClaudeRequest{
|
||||
claudeRequest := dto.ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
MaxTokens: textRequest.MaxTokens,
|
||||
StopSequences: nil,
|
||||
@@ -92,9 +94,31 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
Stream: textRequest.Stream,
|
||||
Tools: claudeTools,
|
||||
}
|
||||
|
||||
if claudeRequest.MaxTokens == 0 {
|
||||
claudeRequest.MaxTokens = 4096
|
||||
claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
|
||||
}
|
||||
|
||||
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
|
||||
strings.HasSuffix(textRequest.Model, "-thinking") {
|
||||
|
||||
// 因为BudgetTokens 必须大于1024
|
||||
if claudeRequest.MaxTokens < 1280 {
|
||||
claudeRequest.MaxTokens = 1280
|
||||
}
|
||||
|
||||
// BudgetTokens 为 max_tokens 的 80%
|
||||
claudeRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
|
||||
}
|
||||
// TODO: 临时处理
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||
claudeRequest.TopP = 0
|
||||
claudeRequest.Temperature = common.GetPointer[float64](1.0)
|
||||
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||
}
|
||||
|
||||
if textRequest.Stop != nil {
|
||||
// stop maybe string/array string, convert to array string
|
||||
switch textRequest.Stop.(type) {
|
||||
@@ -142,7 +166,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
lastMessage = fmtMessage
|
||||
}
|
||||
|
||||
claudeMessages := make([]ClaudeMessage, 0)
|
||||
claudeMessages := make([]dto.ClaudeMessage, 0)
|
||||
isFirstMessage := true
|
||||
for _, message := range formatMessages {
|
||||
if message.Role == "system" {
|
||||
@@ -163,63 +187,63 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
isFirstMessage = false
|
||||
if message.Role != "user" {
|
||||
// fix: first message is assistant, add user message
|
||||
claudeMessage := ClaudeMessage{
|
||||
claudeMessage := dto.ClaudeMessage{
|
||||
Role: "user",
|
||||
Content: []ClaudeMediaMessage{
|
||||
Content: []dto.ClaudeMediaMessage{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "...",
|
||||
Text: common.GetPointer[string]("..."),
|
||||
},
|
||||
},
|
||||
}
|
||||
claudeMessages = append(claudeMessages, claudeMessage)
|
||||
}
|
||||
}
|
||||
claudeMessage := ClaudeMessage{
|
||||
claudeMessage := dto.ClaudeMessage{
|
||||
Role: message.Role,
|
||||
}
|
||||
if message.Role == "tool" {
|
||||
if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
|
||||
lastMessage := claudeMessages[len(claudeMessages)-1]
|
||||
if content, ok := lastMessage.Content.(string); ok {
|
||||
lastMessage.Content = []ClaudeMediaMessage{
|
||||
lastMessage.Content = []dto.ClaudeMediaMessage{
|
||||
{
|
||||
Type: "text",
|
||||
Text: content,
|
||||
Text: common.GetPointer[string](content),
|
||||
},
|
||||
}
|
||||
}
|
||||
lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
|
||||
lastMessage.Content = append(lastMessage.Content.([]dto.ClaudeMediaMessage), dto.ClaudeMediaMessage{
|
||||
Type: "tool_result",
|
||||
ToolUseId: message.ToolCallId,
|
||||
Content: message.StringContent(),
|
||||
Content: message.Content,
|
||||
})
|
||||
claudeMessages[len(claudeMessages)-1] = lastMessage
|
||||
continue
|
||||
} else {
|
||||
claudeMessage.Role = "user"
|
||||
claudeMessage.Content = []ClaudeMediaMessage{
|
||||
claudeMessage.Content = []dto.ClaudeMediaMessage{
|
||||
{
|
||||
Type: "tool_result",
|
||||
ToolUseId: message.ToolCallId,
|
||||
Content: message.StringContent(),
|
||||
Content: message.Content,
|
||||
},
|
||||
}
|
||||
}
|
||||
} else if message.IsStringContent() && message.ToolCalls == nil {
|
||||
claudeMessage.Content = message.StringContent()
|
||||
} else {
|
||||
claudeMediaMessages := make([]ClaudeMediaMessage, 0)
|
||||
claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
|
||||
for _, mediaMessage := range message.ParseContent() {
|
||||
claudeMediaMessage := ClaudeMediaMessage{
|
||||
claudeMediaMessage := dto.ClaudeMediaMessage{
|
||||
Type: mediaMessage.Type,
|
||||
}
|
||||
if mediaMessage.Type == "text" {
|
||||
claudeMediaMessage.Text = mediaMessage.Text
|
||||
claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text)
|
||||
} else {
|
||||
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
|
||||
claudeMediaMessage.Type = "image"
|
||||
claudeMediaMessage.Source = &ClaudeMessageSource{
|
||||
claudeMediaMessage.Source = &dto.ClaudeMessageSource{
|
||||
Type: "base64",
|
||||
}
|
||||
// 判断是否是url
|
||||
@@ -249,7 +273,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
|
||||
continue
|
||||
}
|
||||
claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
|
||||
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
|
||||
Type: "tool_use",
|
||||
Id: toolCall.ID,
|
||||
Name: toolCall.Function.Name,
|
||||
@@ -267,13 +291,12 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
return &claudeRequest, nil
|
||||
}
|
||||
|
||||
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
|
||||
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
var claudeUsage *ClaudeUsage
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = claudeResponse.Model
|
||||
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
|
||||
tools := make([]dto.ToolCall, 0)
|
||||
tools := make([]dto.ToolCallResponse, 0)
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
if reqMode == RequestModeCompletion {
|
||||
choice.Delta.SetContentString(claudeResponse.Completion)
|
||||
@@ -285,35 +308,43 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
|
||||
if claudeResponse.Type == "message_start" {
|
||||
response.Id = claudeResponse.Message.Id
|
||||
response.Model = claudeResponse.Message.Model
|
||||
claudeUsage = &claudeResponse.Message.Usage
|
||||
//claudeUsage = &claudeResponse.Message.Usage
|
||||
choice.Delta.SetContentString("")
|
||||
choice.Delta.Role = "assistant"
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
if claudeResponse.ContentBlock != nil {
|
||||
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
|
||||
if claudeResponse.ContentBlock.Type == "tool_use" {
|
||||
tools = append(tools, dto.ToolCall{
|
||||
tools = append(tools, dto.ToolCallResponse{
|
||||
ID: claudeResponse.ContentBlock.Id,
|
||||
Type: "function",
|
||||
Function: dto.FunctionCall{
|
||||
Function: dto.FunctionResponse{
|
||||
Name: claudeResponse.ContentBlock.Name,
|
||||
Arguments: "",
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
if claudeResponse.Delta != nil {
|
||||
choice.Index = claudeResponse.Index
|
||||
choice.Delta.SetContentString(claudeResponse.Delta.Text)
|
||||
if claudeResponse.Delta.Type == "input_json_delta" {
|
||||
tools = append(tools, dto.ToolCall{
|
||||
Function: dto.FunctionCall{
|
||||
Arguments: claudeResponse.Delta.PartialJson,
|
||||
choice.Index = *claudeResponse.Index
|
||||
choice.Delta.Content = claudeResponse.Delta.Text
|
||||
switch claudeResponse.Delta.Type {
|
||||
case "input_json_delta":
|
||||
tools = append(tools, dto.ToolCallResponse{
|
||||
Function: dto.FunctionResponse{
|
||||
Arguments: *claudeResponse.Delta.PartialJson,
|
||||
},
|
||||
})
|
||||
case "signature_delta":
|
||||
// 加密的不处理
|
||||
signatureContent := "\n"
|
||||
choice.Delta.ReasoningContent = &signatureContent
|
||||
case "thinking_delta":
|
||||
thinkingContent := claudeResponse.Delta.Thinking
|
||||
choice.Delta.ReasoningContent = &thinkingContent
|
||||
}
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
@@ -321,26 +352,23 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
claudeUsage = &claudeResponse.Usage
|
||||
//claudeUsage = &claudeResponse.Usage
|
||||
} else if claudeResponse.Type == "message_stop" {
|
||||
return nil, nil
|
||||
return nil
|
||||
} else {
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if claudeUsage == nil {
|
||||
claudeUsage = &ClaudeUsage{}
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
|
||||
choice.Delta.ToolCalls = tools
|
||||
}
|
||||
response.Choices = append(response.Choices, choice)
|
||||
|
||||
return &response, claudeUsage
|
||||
return &response
|
||||
}
|
||||
|
||||
func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
|
||||
func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
|
||||
choices := make([]dto.OpenAITextResponseChoice, 0)
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
@@ -349,9 +377,11 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
||||
}
|
||||
var responseText string
|
||||
if len(claudeResponse.Content) > 0 {
|
||||
responseText = claudeResponse.Content[0].Text
|
||||
responseText = *claudeResponse.Content[0].Text
|
||||
}
|
||||
tools := make([]dto.ToolCall, 0)
|
||||
tools := make([]dto.ToolCallResponse, 0)
|
||||
thinkingContent := ""
|
||||
|
||||
if reqMode == RequestModeCompletion {
|
||||
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
@@ -367,16 +397,22 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
||||
} else {
|
||||
fullTextResponse.Id = claudeResponse.Id
|
||||
for _, message := range claudeResponse.Content {
|
||||
if message.Type == "tool_use" {
|
||||
switch message.Type {
|
||||
case "tool_use":
|
||||
args, _ := json.Marshal(message.Input)
|
||||
tools = append(tools, dto.ToolCall{
|
||||
tools = append(tools, dto.ToolCallResponse{
|
||||
ID: message.Id,
|
||||
Type: "function", // compatible with other OpenAI derivative applications
|
||||
Function: dto.FunctionCall{
|
||||
Function: dto.FunctionResponse{
|
||||
Name: message.Name,
|
||||
Arguments: string(args),
|
||||
},
|
||||
})
|
||||
case "thinking":
|
||||
// 加密的不管, 只输出明文的推理过程
|
||||
thinkingContent = message.Thinking
|
||||
case "text":
|
||||
responseText = *message.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -391,92 +427,155 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
||||
if len(tools) > 0 {
|
||||
choice.Message.SetToolCalls(tools)
|
||||
}
|
||||
choice.Message.ReasoningContent = thinkingContent
|
||||
fullTextResponse.Model = claudeResponse.Model
|
||||
choices = append(choices, choice)
|
||||
fullTextResponse.Choices = choices
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
var usage *dto.Usage
|
||||
usage = &dto.Usage{}
|
||||
responseText := ""
|
||||
createdTime := common.GetTimestamp()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
service.SetEventStreamHeaders(c)
|
||||
type ClaudeResponseInfo struct {
|
||||
ResponseId string
|
||||
Created int64
|
||||
Model string
|
||||
ResponseText strings.Builder
|
||||
Usage *dto.Usage
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
info.SetFirstResponseTime()
|
||||
if len(data) < 6 || !strings.HasPrefix(data, "data:") {
|
||||
continue
|
||||
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
claudeInfo.ResponseId = claudeResponse.Message.Id
|
||||
claudeInfo.Model = claudeResponse.Message.Model
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
if claudeResponse.Delta.Text != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
data = strings.TrimPrefix(data, "data:")
|
||||
data = strings.TrimSpace(data)
|
||||
var claudeResponse ClaudeResponse
|
||||
err := json.Unmarshal([]byte(data), &claudeResponse)
|
||||
}
|
||||
if oaiResponse != nil {
|
||||
oaiResponse.Id = claudeInfo.ResponseId
|
||||
oaiResponse.Created = claudeInfo.Created
|
||||
oaiResponse.Model = claudeInfo.Model
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
|
||||
if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
return toOpenAIStreamHandler(c, resp, info, requestMode)
|
||||
}
|
||||
|
||||
usage := &dto.Usage{}
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
if response == nil {
|
||||
continue
|
||||
return true
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
responseText += claudeResponse.Completion
|
||||
responseId = response.Id
|
||||
responseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
responseId = claudeResponse.Message.Id
|
||||
info.UpstreamModelName = claudeResponse.Message.Model
|
||||
usage.PromptTokens = claudeUsage.InputTokens
|
||||
usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
responseText += claudeResponse.Delta.Text
|
||||
responseText.WriteString(claudeResponse.Delta.GetText())
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
usage.CompletionTokens = claudeUsage.OutputTokens
|
||||
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
|
||||
} else {
|
||||
continue
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
//response.Id = responseId
|
||||
response.Id = responseId
|
||||
response.Created = createdTime
|
||||
response.Model = info.UpstreamModelName
|
||||
helper.ClaudeChunkData(c, claudeResponse, data)
|
||||
return true
|
||||
})
|
||||
|
||||
err = service.ObjectData(c, response)
|
||||
if requestMode == RequestModeCompletion {
|
||||
usage, _ = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 说明流模式建立失败,可能为官方出错
|
||||
if usage.PromptTokens == 0 {
|
||||
//usage.PromptTokens = info.PromptTokens
|
||||
}
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func toOpenAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
ResponseId: responseId,
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
|
||||
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
|
||||
return true
|
||||
}
|
||||
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.LogError(c, "send_stream_response_failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if usage.PromptTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
|
||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
if info.ShouldIncludeUsage {
|
||||
response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
|
||||
err := service.ObjectData(c, response)
|
||||
response := helper.GenerateFinalUsageResponse(responseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
service.Done(c)
|
||||
resp.Body.Close()
|
||||
return nil, usage
|
||||
helper.Done(c)
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
@@ -488,7 +587,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var claudeResponse ClaudeResponse
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err = json.Unmarshal(responseBody, &claudeResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -504,13 +603,12 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
if requestMode == RequestModeCompletion {
|
||||
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens = completionTokens
|
||||
usage.TotalTokens = info.PromptTokens + completionTokens
|
||||
@@ -518,14 +616,23 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
|
||||
usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
|
||||
usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
|
||||
usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
var responseData []byte
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
openaiResponse.Usage = usage
|
||||
responseData, err = json.Marshal(openaiResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
case relaycommon.RelayFormatClaude:
|
||||
responseData = responseBody
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
_, err = c.Writer.Write(responseData)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
@@ -17,6 +17,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
@@ -37,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -28,8 +29,8 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
service.SetEventStreamHeaders(c)
|
||||
id := service.GetResponseID(c)
|
||||
helper.SetEventStreamHeaders(c)
|
||||
id := helper.GetResponseID(c)
|
||||
var responseText string
|
||||
isFirst := true
|
||||
|
||||
@@ -57,7 +58,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
}
|
||||
response.Id = id
|
||||
response.Model = info.UpstreamModelName
|
||||
err = service.ObjectData(c, response)
|
||||
err = helper.ObjectData(c, response)
|
||||
if isFirst {
|
||||
isFirst = false
|
||||
info.FirstResponseTime = time.Now()
|
||||
@@ -72,13 +73,13 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
if info.ShouldIncludeUsage {
|
||||
response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := service.ObjectData(c, response)
|
||||
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
|
||||
}
|
||||
}
|
||||
service.Done(c)
|
||||
helper.Done(c)
|
||||
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
@@ -109,7 +110,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
response.Usage = *usage
|
||||
response.Id = service.GetResponseID(c)
|
||||
response.Id = helper.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -42,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return requestOpenAI2Cohere(*request), nil
|
||||
}
|
||||
|
||||
@@ -59,7 +65,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = cohereRerankHandler(c, resp, info)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -103,7 +104,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
helper.SetEventStreamHeaders(c)
|
||||
isFirst := true
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
|
||||
@@ -16,6 +16,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -44,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -23,6 +23,12 @@ type Adaptor struct {
|
||||
BotType int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -64,7 +70,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
@@ -66,7 +67,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
service.SetEventStreamHeaders(c)
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
@@ -92,7 +93,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
responseText += openaiResponse.Choices[0].Delta.GetContentString()
|
||||
}
|
||||
}
|
||||
err = service.ObjectData(c, openaiResponse)
|
||||
err = helper.ObjectData(c, openaiResponse)
|
||||
if err != nil {
|
||||
common.SysError(err.Error())
|
||||
}
|
||||
@@ -100,7 +101,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.SysError("error reading stream: " + err.Error())
|
||||
}
|
||||
service.Done(c)
|
||||
helper.Done(c)
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
|
||||
@@ -7,11 +7,11 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
|
||||
"strings"
|
||||
|
||||
@@ -21,6 +21,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -64,20 +70,18 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1beta"
|
||||
version, beta := constant.GeminiModelMap[info.UpstreamModelName]
|
||||
if !beta {
|
||||
if info.ApiVersion != "" {
|
||||
version = info.ApiVersion
|
||||
} else {
|
||||
version = "v1beta"
|
||||
}
|
||||
}
|
||||
version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
|
||||
return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
|
||||
}
|
||||
|
||||
action := "generateContent"
|
||||
if info.IsStream {
|
||||
action = "streamGenerateContent?alt=sse"
|
||||
@@ -91,7 +95,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -107,8 +111,37 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
if request.Input == nil {
|
||||
return nil, errors.New("input is required")
|
||||
}
|
||||
|
||||
inputs := request.ParseInput()
|
||||
if len(inputs) == 0 {
|
||||
return nil, errors.New("input is empty")
|
||||
}
|
||||
|
||||
// only process the first input
|
||||
geminiRequest := GeminiEmbeddingRequest{
|
||||
Content: GeminiChatContent{
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: inputs[0],
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// set specific parameters for different models
|
||||
// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
|
||||
switch info.UpstreamModelName {
|
||||
case "text-embedding-004":
|
||||
// except embedding-001 supports setting `OutputDimensionality`
|
||||
if request.Dimensions > 0 {
|
||||
geminiRequest.OutputDimensionality = request.Dimensions
|
||||
}
|
||||
}
|
||||
|
||||
return geminiRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
@@ -120,6 +153,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return GeminiImageHandler(c, resp, info)
|
||||
}
|
||||
|
||||
// check if the model is an embedding model
|
||||
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
|
||||
return GeminiEmbeddingHandler(c, resp, info)
|
||||
}
|
||||
|
||||
if info.IsStream {
|
||||
err, usage = GeminiChatStreamHandler(c, resp, info)
|
||||
} else {
|
||||
|
||||
@@ -18,6 +18,18 @@ var ModelList = []string{
|
||||
"gemini-2.0-flash-thinking-exp",
|
||||
// imagen models
|
||||
"imagen-3.0-generate-002",
|
||||
// embedding models
|
||||
"gemini-embedding-exp-03-07",
|
||||
"text-embedding-004",
|
||||
"embedding-001",
|
||||
}
|
||||
|
||||
var SafetySettingList = []string{
|
||||
"HARM_CATEGORY_HARASSMENT",
|
||||
"HARM_CATEGORY_HATE_SPEECH",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"HARM_CATEGORY_CIVIC_INTEGRITY",
|
||||
}
|
||||
|
||||
var ChannelName = "google gemini"
|
||||
|
||||
@@ -136,3 +136,19 @@ type GeminiImagePrediction struct {
|
||||
RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
|
||||
SafetyAttributes any `json:"safetyAttributes,omitempty"`
|
||||
}
|
||||
|
||||
// Embedding related structs
|
||||
type GeminiEmbeddingRequest struct {
|
||||
Content GeminiChatContent `json:"content"`
|
||||
TaskType string `json:"taskType,omitempty"`
|
||||
Title string `json:"title,omitempty"`
|
||||
OutputDimensionality int `json:"outputDimensionality,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiEmbeddingResponse struct {
|
||||
Embedding ContentEmbedding `json:"embedding"`
|
||||
}
|
||||
|
||||
type ContentEmbedding struct {
|
||||
Values []float64 `json:"values"`
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -10,7 +9,9 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
@@ -22,28 +23,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
SafetySettings: []GeminiChatSafetySettings{
|
||||
{
|
||||
Category: "HARM_CATEGORY_HARASSMENT",
|
||||
Threshold: common.GeminiSafetySetting,
|
||||
},
|
||||
{
|
||||
Category: "HARM_CATEGORY_HATE_SPEECH",
|
||||
Threshold: common.GeminiSafetySetting,
|
||||
},
|
||||
{
|
||||
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
Threshold: common.GeminiSafetySetting,
|
||||
},
|
||||
{
|
||||
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
Threshold: common.GeminiSafetySetting,
|
||||
},
|
||||
{
|
||||
Category: "HARM_CATEGORY_CIVIC_INTEGRITY",
|
||||
Threshold: common.GeminiSafetySetting,
|
||||
},
|
||||
},
|
||||
//SafetySettings: []GeminiChatSafetySettings{},
|
||||
GenerationConfig: GeminiChatGenerationConfig{
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
@@ -52,9 +32,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
},
|
||||
}
|
||||
|
||||
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||
for _, category := range SafetySettingList {
|
||||
safetySettings = append(safetySettings, GeminiChatSafetySettings{
|
||||
Category: category,
|
||||
Threshold: model_setting.GetGeminiSafetySetting(category),
|
||||
})
|
||||
}
|
||||
geminiRequest.SafetySettings = safetySettings
|
||||
|
||||
// openaiContent.FuncToToolCalls()
|
||||
if textRequest.Tools != nil {
|
||||
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
|
||||
functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
|
||||
googleSearch := false
|
||||
codeExecution := false
|
||||
for _, tool := range textRequest.Tools {
|
||||
@@ -349,7 +338,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
|
||||
return data
|
||||
}
|
||||
|
||||
func getToolCall(item *GeminiPart) *dto.ToolCall {
|
||||
func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
||||
var argsBytes []byte
|
||||
var err error
|
||||
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
|
||||
@@ -361,10 +350,10 @@ func getToolCall(item *GeminiPart) *dto.ToolCall {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &dto.ToolCall{
|
||||
return &dto.ToolCallResponse{
|
||||
ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
||||
Type: "function",
|
||||
Function: dto.FunctionCall{
|
||||
Function: dto.FunctionResponse{
|
||||
Arguments: string(argsBytes),
|
||||
Name: item.FunctionCall.FunctionName,
|
||||
},
|
||||
@@ -379,7 +368,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
content, _ := json.Marshal("")
|
||||
is_tool_call := false
|
||||
isToolCall := false
|
||||
for _, candidate := range response.Candidates {
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: int(candidate.Index),
|
||||
@@ -391,12 +380,12 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
}
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
var texts []string
|
||||
var tool_calls []dto.ToolCall
|
||||
var toolCalls []dto.ToolCallResponse
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
choice.FinishReason = constant.FinishReasonToolCalls
|
||||
if call := getToolCall(&part); call != nil {
|
||||
tool_calls = append(tool_calls, *call)
|
||||
if call := getResponseToolCall(&part); call != nil {
|
||||
toolCalls = append(toolCalls, *call)
|
||||
}
|
||||
} else {
|
||||
if part.ExecutableCode != nil {
|
||||
@@ -411,9 +400,9 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(tool_calls) > 0 {
|
||||
choice.Message.SetToolCalls(tool_calls)
|
||||
is_tool_call = true
|
||||
if len(toolCalls) > 0 {
|
||||
choice.Message.SetToolCalls(toolCalls)
|
||||
isToolCall = true
|
||||
}
|
||||
|
||||
choice.Message.SetStringContent(strings.Join(texts, "\n"))
|
||||
@@ -429,7 +418,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
choice.FinishReason = constant.FinishReasonContentFilter
|
||||
}
|
||||
}
|
||||
if is_tool_call {
|
||||
if isToolCall {
|
||||
choice.FinishReason = constant.FinishReasonToolCalls
|
||||
}
|
||||
|
||||
@@ -440,10 +429,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
|
||||
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
|
||||
is_stop := false
|
||||
isStop := false
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
|
||||
is_stop = true
|
||||
isStop = true
|
||||
candidate.FinishReason = nil
|
||||
}
|
||||
choice := dto.ChatCompletionsStreamResponseChoice{
|
||||
@@ -468,7 +457,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
isTools = true
|
||||
if call := getToolCall(&part); call != nil {
|
||||
if call := getResponseToolCall(&part); call != nil {
|
||||
call.SetIndex(len(choice.Delta.ToolCalls))
|
||||
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
|
||||
}
|
||||
@@ -493,9 +482,8 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "gemini"
|
||||
response.Choices = choices
|
||||
return &response, is_stop
|
||||
return &response, isStop
|
||||
}
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
@@ -503,27 +491,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createAt := common.GetTimestamp()
|
||||
var usage = &dto.Usage{}
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
service.SetEventStreamHeaders(c)
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
info.SetFirstResponseTime()
|
||||
data = strings.TrimSpace(data)
|
||||
if !strings.HasPrefix(data, "data: ") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "data: ")
|
||||
data = strings.TrimSuffix(data, "\"")
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
err := json.Unmarshal([]byte(data), &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
continue
|
||||
return false
|
||||
}
|
||||
|
||||
response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
@@ -532,15 +509,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
||||
}
|
||||
err = service.ObjectData(c, response)
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
if is_stop {
|
||||
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
|
||||
service.ObjectData(c, response)
|
||||
if isStop {
|
||||
response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
|
||||
helper.ObjectData(c, response)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
var response *dto.ChatCompletionsStreamResponse
|
||||
|
||||
@@ -549,14 +527,14 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
|
||||
|
||||
if info.ShouldIncludeUsage {
|
||||
response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
err := service.ObjectData(c, response)
|
||||
response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
service.Done(c)
|
||||
resp.Body.Close()
|
||||
helper.Done(c)
|
||||
//resp.Body.Close()
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
@@ -602,3 +580,52 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var geminiResponse GeminiEmbeddingResponse
|
||||
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// convert to openai format response
|
||||
openAIResponse := dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: []dto.OpenAIEmbeddingResponseItem{
|
||||
{
|
||||
Object: "embedding",
|
||||
Embedding: geminiResponse.Embedding.Values,
|
||||
Index: 0,
|
||||
},
|
||||
},
|
||||
Model: info.UpstreamModelName,
|
||||
}
|
||||
|
||||
// calculate usage
|
||||
// https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
|
||||
// Google has not yet clarified how embedding models will be billed
|
||||
// refer to openai billing method to use input tokens billing
|
||||
// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
openAIResponse.Usage = *usage.(*dto.Usage)
|
||||
|
||||
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -43,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -61,7 +67,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = jinaRerankHandler(c, resp)
|
||||
err, usage = JinaRerankHandler(c, resp)
|
||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
err, usage = jinaEmbeddingHandler(c, resp)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func JinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
|
||||
@@ -14,6 +14,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -37,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -16,6 +16,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -51,7 +57,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -73,13 +79,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = mokaEmbeddingHandler(c, resp)
|
||||
default:
|
||||
// err, usage = mokaHandler(c, resp)
|
||||
|
||||
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -43,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
@@ -3,21 +3,22 @@ package ollama
|
||||
import "one-api/dto"
|
||||
|
||||
type OllamaRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []dto.Message `json:"messages,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Topp float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
Tools []dto.ToolCall `json:"tools,omitempty"`
|
||||
ResponseFormat any `json:"response_format,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
Suffix any `json:"suffix,omitempty"`
|
||||
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []dto.Message `json:"messages,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Topp float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Tools []dto.ToolCallRequest `json:"tools,omitempty"`
|
||||
ResponseFormat any `json:"response_format,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
Suffix any `json:"suffix,omitempty"`
|
||||
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
|
||||
@@ -58,6 +58,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err
|
||||
TopK: request.TopK,
|
||||
Stop: Stop,
|
||||
Tools: request.Tools,
|
||||
MaxTokens: request.MaxTokens,
|
||||
ResponseFormat: request.ResponseFormat,
|
||||
FrequencyPenalty: request.FrequencyPenalty,
|
||||
PresencePenalty: request.PresencePenalty,
|
||||
|
||||
@@ -14,11 +14,14 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ai360"
|
||||
"one-api/relay/channel/jina"
|
||||
"one-api/relay/channel/lingyiwanwu"
|
||||
"one-api/relay/channel/minimax"
|
||||
"one-api/relay/channel/moonshot"
|
||||
"one-api/relay/channel/xinference"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -27,11 +30,30 @@ type Adaptor struct {
|
||||
ResponseFormat string
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
if !strings.HasPrefix(request.Model, "claude") {
|
||||
return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
||||
}
|
||||
aiRequest, err := service.ClaudeToOpenAIRequest(*request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if info.SupportStreamOptions {
|
||||
aiRequest.StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
}
|
||||
return a.ConvertOpenAIRequest(c, info, aiRequest)
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeRealtime {
|
||||
if strings.HasPrefix(info.BaseUrl, "https://") {
|
||||
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
||||
@@ -107,7 +129,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -146,7 +168,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
@@ -228,6 +250,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = jina.JinaRerankHandler(c, resp)
|
||||
default:
|
||||
if info.IsStream {
|
||||
err, usage = OaiStreamHandler(c, resp, info)
|
||||
@@ -248,6 +272,8 @@ func (a *Adaptor) GetModelList() []string {
|
||||
return lingyiwanwu.ModelList
|
||||
case common.ChannelTypeMiniMax:
|
||||
return minimax.ModelList
|
||||
case common.ChannelTypeXinference:
|
||||
return xinference.ModelList
|
||||
default:
|
||||
return ModelList
|
||||
}
|
||||
@@ -263,6 +289,8 @@ func (a *Adaptor) GetChannelName() string {
|
||||
return lingyiwanwu.ChannelName
|
||||
case common.ChannelTypeMiniMax:
|
||||
return minimax.ChannelName
|
||||
case common.ChannelTypeXinference:
|
||||
return xinference.ChannelName
|
||||
default:
|
||||
return ChannelName
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ var ModelList = []string{
|
||||
"chatgpt-4o-latest",
|
||||
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20",
|
||||
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
|
||||
"gpt-4.5-preview", "gpt-4.5-preview-2025-02-27",
|
||||
"o1-preview", "o1-preview-2024-09-12",
|
||||
"o1-mini", "o1-mini-2024-09-12",
|
||||
"o3-mini", "o3-mini-2025-01-31",
|
||||
|
||||
188
relay/channel/openai/helper.go
Normal file
188
relay/channel/openai/helper.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 辅助函数
|
||||
func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
info.SendResponseCount++
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||||
case relaycommon.RelayFormatClaude:
|
||||
return handleClaudeFormat(c, data, info)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
||||
for _, resp := range claudeResponses {
|
||||
helper.ClaudeData(c, *resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > *toolCount {
|
||||
*toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
return processCompletions(streamResp, streamItems, responseTextBuilder)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
var streamResponses []dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil {
|
||||
common.SysError("error processing stream response: " + err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 批量处理所有响应
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > *toolCount {
|
||||
*toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
|
||||
var streamResponses []dto.CompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||
continue
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 批量处理所有响应
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
|
||||
systemFingerprint *string, model *string, usage **dto.Usage,
|
||||
containStreamUsage *bool, info *relaycommon.RelayInfo,
|
||||
shouldSendLastResp *bool) error {
|
||||
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*responseId = lastStreamResponse.Id
|
||||
*createAt = lastStreamResponse.Created
|
||||
*systemFingerprint = lastStreamResponse.GetSystemFingerprint()
|
||||
*model = lastStreamResponse.Model
|
||||
|
||||
if service.ValidUsage(lastStreamResponse.Usage) {
|
||||
*containStreamUsage = true
|
||||
*usage = lastStreamResponse.Usage
|
||||
if !info.ShouldIncludeUsage {
|
||||
*shouldSendLastResp = false
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
|
||||
responseId string, createAt int64, model string, systemFingerprint string,
|
||||
usage *dto.Usage, containStreamUsage bool) {
|
||||
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
if info.ShouldIncludeUsage && !containStreamUsage {
|
||||
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
||||
response.SetSystemFingerprint(systemFingerprint)
|
||||
helper.ObjectData(c, response)
|
||||
}
|
||||
helper.Done(c)
|
||||
|
||||
case relaycommon.RelayFormatClaude:
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
streamResponse.Usage = usage
|
||||
}
|
||||
|
||||
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
||||
for _, resp := range claudeResponses {
|
||||
helper.ClaudeData(c, *resp)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -13,12 +12,10 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -32,7 +29,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
||||
}
|
||||
|
||||
if !forceFormat && !thinkToContent {
|
||||
return service.StringData(c, data)
|
||||
return helper.StringData(c, data)
|
||||
}
|
||||
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
@@ -41,44 +38,68 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
||||
}
|
||||
|
||||
if !thinkToContent {
|
||||
return service.ObjectData(c, lastStreamResponse)
|
||||
return helper.ObjectData(c, lastStreamResponse)
|
||||
}
|
||||
|
||||
hasThinkingContent := false
|
||||
hasContent := false
|
||||
var thinkingContent strings.Builder
|
||||
for _, choice := range lastStreamResponse.Choices {
|
||||
if len(choice.Delta.GetReasoningContent()) > 0 {
|
||||
hasThinkingContent = true
|
||||
thinkingContent.WriteString(choice.Delta.GetReasoningContent())
|
||||
}
|
||||
if len(choice.Delta.GetContentString()) > 0 {
|
||||
hasContent = true
|
||||
}
|
||||
}
|
||||
|
||||
// Handle think to content conversion
|
||||
if info.IsFirstResponse {
|
||||
response := lastStreamResponse.Copy()
|
||||
for i := range response.Choices {
|
||||
response.Choices[i].Delta.SetContentString("<think>\n")
|
||||
response.Choices[i].Delta.SetReasoningContent("")
|
||||
if info.ThinkingContentInfo.IsFirstThinkingContent {
|
||||
if hasThinkingContent {
|
||||
response := lastStreamResponse.Copy()
|
||||
for i := range response.Choices {
|
||||
// send `think` tag with thinking content
|
||||
response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
|
||||
response.Choices[i].Delta.ReasoningContent = nil
|
||||
response.Choices[i].Delta.Reasoning = nil
|
||||
}
|
||||
info.ThinkingContentInfo.IsFirstThinkingContent = false
|
||||
return helper.ObjectData(c, response)
|
||||
}
|
||||
service.ObjectData(c, response)
|
||||
}
|
||||
|
||||
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
|
||||
return service.ObjectData(c, lastStreamResponse)
|
||||
return helper.ObjectData(c, lastStreamResponse)
|
||||
}
|
||||
|
||||
// Process each choice
|
||||
for i, choice := range lastStreamResponse.Choices {
|
||||
// Handle transition from thinking to content
|
||||
if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse {
|
||||
if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent {
|
||||
response := lastStreamResponse.Copy()
|
||||
for j := range response.Choices {
|
||||
response.Choices[j].Delta.SetContentString("\n</think>")
|
||||
response.Choices[j].Delta.SetReasoningContent("")
|
||||
response.Choices[j].Delta.SetContentString("\n</think>\n")
|
||||
response.Choices[j].Delta.ReasoningContent = nil
|
||||
response.Choices[j].Delta.Reasoning = nil
|
||||
}
|
||||
info.SendLastReasoningResponse = true
|
||||
service.ObjectData(c, response)
|
||||
info.ThinkingContentInfo.SendLastThinkingContent = true
|
||||
helper.ObjectData(c, response)
|
||||
}
|
||||
|
||||
// Convert reasoning content to regular content
|
||||
if len(choice.Delta.GetReasoningContent()) > 0 {
|
||||
lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
|
||||
lastStreamResponse.Choices[i].Delta.SetReasoningContent("")
|
||||
lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
|
||||
lastStreamResponse.Choices[i].Delta.Reasoning = nil
|
||||
} else if !hasThinkingContent && !hasContent {
|
||||
// flush thinking content
|
||||
lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
|
||||
lastStreamResponse.Choices[i].Delta.Reasoning = nil
|
||||
}
|
||||
}
|
||||
|
||||
return service.ObjectData(c, lastStreamResponse)
|
||||
return helper.ObjectData(c, lastStreamResponse)
|
||||
}
|
||||
|
||||
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
@@ -108,64 +129,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
toolCount := 0
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
service.SetEventStreamHeaders(c)
|
||||
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
|
||||
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
|
||||
// twice timeout for o1 model
|
||||
streamingTimeout *= 2
|
||||
}
|
||||
ticker := time.NewTicker(streamingTimeout)
|
||||
defer ticker.Stop()
|
||||
|
||||
stopChan := make(chan bool)
|
||||
defer close(stopChan)
|
||||
var (
|
||||
lastStreamData string
|
||||
mu sync.Mutex
|
||||
)
|
||||
gopool.Go(func() {
|
||||
for scanner.Scan() {
|
||||
//info.SetFirstResponseTime()
|
||||
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
|
||||
data := scanner.Text()
|
||||
if common.DebugEnabled {
|
||||
println(data)
|
||||
}
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
mu.Lock()
|
||||
data = data[5:]
|
||||
data = strings.TrimSpace(data)
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
if lastStreamData != "" {
|
||||
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
if err != nil {
|
||||
common.LogError(c, "streaming error: "+err.Error())
|
||||
}
|
||||
info.SetFirstResponseTime()
|
||||
}
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
common.SafeSendBool(stopChan, true)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// 超时处理逻辑
|
||||
common.LogError(c, "streaming timeout")
|
||||
case <-stopChan:
|
||||
// 正常结束
|
||||
}
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
if lastStreamData != "" {
|
||||
err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
if err != nil {
|
||||
common.SysError("error handling stream format: " + err.Error())
|
||||
}
|
||||
info.SetFirstResponseTime()
|
||||
}
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
return true
|
||||
})
|
||||
|
||||
shouldSendLastResp := true
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
@@ -192,96 +172,24 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
}
|
||||
|
||||
// 计算token
|
||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
var streamResponses []dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||
if err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
||||
if err == nil {
|
||||
//if service.ValidUsage(streamResponse.Usage) {
|
||||
// usage = streamResponse.Usage
|
||||
//}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > toolCount {
|
||||
toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, streamResponse := range streamResponses {
|
||||
//if service.ValidUsage(streamResponse.Usage) {
|
||||
// usage = streamResponse.Usage
|
||||
// containStreamUsage = true
|
||||
//}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > toolCount {
|
||||
toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case relayconstant.RelayModeCompletions:
|
||||
var streamResponses []dto.CompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||
if err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
||||
if err == nil {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
// 处理token计算
|
||||
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
|
||||
common.SysError("error processing tokens: " + err.Error())
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
} else {
|
||||
if info.ChannelType == common.ChannelTypeDeepSeek {
|
||||
if usage.PromptCacheHitTokens != 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if info.ShouldIncludeUsage && !containStreamUsage {
|
||||
response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
||||
response.SetSystemFingerprint(systemFingerprint)
|
||||
service.ObjectData(c, response)
|
||||
}
|
||||
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
||||
|
||||
service.Done(c)
|
||||
|
||||
resp.Body.Close()
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
@@ -323,7 +231,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||
completionTokens := 0
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent, model)
|
||||
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, model)
|
||||
completionTokens += ctkm
|
||||
}
|
||||
simpleResponse.Usage = dto.Usage{
|
||||
@@ -512,7 +420,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||
|
||||
err = service.WssString(c, targetConn, string(message))
|
||||
err = helper.WssString(c, targetConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to target: %v", err)
|
||||
return
|
||||
@@ -618,7 +526,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
||||
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
||||
}
|
||||
|
||||
err = service.WssString(c, clientConn, string(message))
|
||||
err = helper.WssString(c, clientConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to client: %v", err)
|
||||
return
|
||||
|
||||
80
relay/channel/openrouter/adaptor.go
Normal file
80
relay/channel/openrouter/adaptor.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package openrouter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
req.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
|
||||
req.Set("X-Title", "New API")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
5
relay/channel/openrouter/constant.go
Normal file
5
relay/channel/openrouter/constant.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package openrouter
|
||||
|
||||
var ModelList = []string{}
|
||||
|
||||
var ChannelName = "openrouter"
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -54,7 +60,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
@@ -112,7 +113,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
|
||||
dataChan <- string(jsonResponse)
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
helper.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -57,7 +63,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -48,7 +54,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,12 @@ type Adaptor struct {
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -52,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -78,7 +84,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -91,7 +92,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
service.SetEventStreamHeaders(c)
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
@@ -112,7 +113,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
||||
responseText += response.Choices[0].Delta.GetContentString()
|
||||
}
|
||||
|
||||
err = service.ObjectData(c, response)
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError(err.Error())
|
||||
}
|
||||
@@ -122,7 +123,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
||||
common.SysError("error reading stream: " + err.Error())
|
||||
}
|
||||
|
||||
service.Done(c)
|
||||
helper.Done(c)
|
||||
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/copier"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
@@ -28,6 +27,8 @@ var claudeModelMap = map[string]string{
|
||||
"claude-3-opus-20240229": "claude-3-opus@20240229",
|
||||
"claude-3-haiku-20240307": "claude-3-haiku@20240307",
|
||||
"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
|
||||
"claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022",
|
||||
"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
|
||||
}
|
||||
|
||||
const anthropicVersion = "vertex-2023-10-16"
|
||||
@@ -37,6 +38,9 @@ type Adaptor struct {
|
||||
AccountCredentials Credentials
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -85,15 +89,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
} else {
|
||||
suffix = "rawPredict"
|
||||
}
|
||||
model := info.UpstreamModelName
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
info.UpstreamModelName = v
|
||||
model = v
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
info.UpstreamModelName,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
@@ -117,7 +122,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -126,13 +131,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vertexClaudeReq := &VertexAIClaudeRequest{
|
||||
AnthropicVersion: anthropicVersion,
|
||||
}
|
||||
if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil {
|
||||
return nil, errors.New("failed to copy claude request")
|
||||
}
|
||||
c.Set("request_model", request.Model)
|
||||
vertexClaudeReq := copyRequest(claudeReq, anthropicVersion)
|
||||
c.Set("request_model", claudeReq.Model)
|
||||
info.UpstreamModelName = claudeReq.Model
|
||||
return vertexClaudeReq, nil
|
||||
} else if a.RequestMode == RequestModeGemini {
|
||||
geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
|
||||
@@ -156,7 +157,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -1,17 +1,37 @@
|
||||
package vertex
|
||||
|
||||
import "one-api/relay/channel/claude"
|
||||
import (
|
||||
"one-api/dto"
|
||||
)
|
||||
|
||||
type VertexAIClaudeRequest struct {
|
||||
AnthropicVersion string `json:"anthropic_version"`
|
||||
Messages []claude.ClaudeMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Tools []claude.Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
AnthropicVersion string `json:"anthropic_version"`
|
||||
Messages []dto.ClaudeMessage `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Thinking *dto.Thinking `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest {
|
||||
return &VertexAIClaudeRequest{
|
||||
AnthropicVersion: version,
|
||||
System: req.System,
|
||||
Messages: req.Messages,
|
||||
MaxTokens: req.MaxTokens,
|
||||
Stream: req.Stream,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
TopK: req.TopK,
|
||||
StopSequences: req.StopSequences,
|
||||
Tools: req.Tools,
|
||||
ToolChoice: req.ToolChoice,
|
||||
Thinking: req.Thinking,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -50,7 +56,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
7
relay/channel/xinference/constant.go
Normal file
7
relay/channel/xinference/constant.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package xinference
|
||||
|
||||
var ModelList = []string{
|
||||
"bge-reranker-v2-m3",
|
||||
}
|
||||
|
||||
var ChannelName = "xinference"
|
||||
@@ -16,6 +16,12 @@ type Adaptor struct {
|
||||
request *dto.GeneralOpenAIRequest
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -55,7 +61,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
// xunfei's request is not http request, so we don't need to do anything here
|
||||
dummyResp := &http.Response{}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -132,7 +133,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||
}
|
||||
service.SetEventStreamHeaders(c)
|
||||
helper.SetEventStreamHeaders(c)
|
||||
var usage dto.Usage
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
|
||||
@@ -14,6 +14,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -42,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -61,7 +67,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -177,7 +178,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
helper.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -39,7 +45,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
@@ -58,7 +64,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -197,7 +198,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
helper.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
|
||||
163
relay/claude_handler.go
Normal file
163
relay/claude_handler.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
|
||||
textRequest = &dto.ClaudeRequest{}
|
||||
err = c.ShouldBindJSON(textRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
|
||||
return nil, errors.New("field messages is required")
|
||||
}
|
||||
if textRequest.Model == "" {
|
||||
return nil, errors.New("field model is required")
|
||||
}
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoClaude(c)
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateClaudeRequest(c)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "invalid_claude_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if textRequest.Stream {
|
||||
relayInfo.IsStream = true
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
textRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
|
||||
if openaiErr != nil {
|
||||
return service.OpenAIErrorToClaudeError(openaiErr)
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.ClaudeErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
var requestBody io.Reader
|
||||
|
||||
if textRequest.MaxTokens == 0 {
|
||||
textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
|
||||
}
|
||||
|
||||
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
|
||||
strings.HasSuffix(textRequest.Model, "-thinking") {
|
||||
if textRequest.Thinking == nil {
|
||||
// 因为BudgetTokens 必须大于1024
|
||||
if textRequest.MaxTokens < 1280 {
|
||||
textRequest.MaxTokens = 1280
|
||||
}
|
||||
|
||||
// BudgetTokens 为 max_tokens 的 80%
|
||||
textRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
|
||||
}
|
||||
// TODO: 临时处理
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||
textRequest.TopP = 0
|
||||
textRequest.Temperature = common.GetPointer[float64](1.0)
|
||||
}
|
||||
textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||
relayInfo.UpstreamModelName = textRequest.Model
|
||||
}
|
||||
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if common.DebugEnabled {
|
||||
println("requestBody: ", string(jsonData))
|
||||
}
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return service.OpenAIErrorToClaudeError(openaiErr)
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
//log.Printf("usage: %v", usage)
|
||||
if openaiErr != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return service.OpenAIErrorToClaudeError(openaiErr)
|
||||
}
|
||||
service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) {
|
||||
var promptTokens int
|
||||
var err error
|
||||
switch info.RelayMode {
|
||||
default:
|
||||
promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName)
|
||||
}
|
||||
info.PromptTokens = promptTokens
|
||||
return promptTokens, err
|
||||
}
|
||||
@@ -12,25 +12,45 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type ThinkingContentInfo struct {
|
||||
IsFirstThinkingContent bool
|
||||
SendLastThinkingContent bool
|
||||
}
|
||||
|
||||
const (
|
||||
LastMessageTypeText = "text"
|
||||
LastMessageTypeTools = "tools"
|
||||
)
|
||||
|
||||
type ClaudeConvertInfo struct {
|
||||
LastMessagesType string
|
||||
Index int
|
||||
}
|
||||
|
||||
const (
|
||||
RelayFormatOpenAI = "openai"
|
||||
RelayFormatClaude = "claude"
|
||||
)
|
||||
|
||||
type RelayInfo struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
Group string
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
IsFirstResponse bool
|
||||
SendLastReasoningResponse bool
|
||||
ApiType int
|
||||
IsStream bool
|
||||
IsPlayground bool
|
||||
UsePrice bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
OriginModelName string
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
Group string
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
isFirstResponse bool
|
||||
//SendLastReasoningResponse bool
|
||||
ApiType int
|
||||
IsStream bool
|
||||
IsPlayground bool
|
||||
UsePrice bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
OriginModelName string
|
||||
//RecodeModelName string
|
||||
RequestURLPath string
|
||||
ApiVersion string
|
||||
@@ -50,6 +70,13 @@ type RelayInfo struct {
|
||||
AudioUsage bool
|
||||
ReasoningEffort string
|
||||
ChannelSetting map[string]interface{}
|
||||
UserSetting map[string]interface{}
|
||||
UserEmail string
|
||||
UserQuota int
|
||||
RelayFormat string
|
||||
SendResponseCount int
|
||||
ThinkingContentInfo
|
||||
ClaudeConvertInfo
|
||||
}
|
||||
|
||||
// 定义支持流式选项的通道类型
|
||||
@@ -73,6 +100,16 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatClaude
|
||||
info.ShouldIncludeUsage = false
|
||||
info.ClaudeConvertInfo = ClaudeConvertInfo{
|
||||
LastMessagesType: LastMessageTypeText,
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
channelType := c.GetInt("channel_type")
|
||||
channelId := c.GetInt("channel_id")
|
||||
@@ -89,7 +126,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
||||
|
||||
info := &RelayInfo{
|
||||
IsFirstResponse: true,
|
||||
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
|
||||
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
|
||||
UserEmail: c.GetString(constant.ContextKeyUserEmail),
|
||||
isFirstResponse: true,
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
@@ -111,6 +151,11 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||
Organization: c.GetString("channel_organization"),
|
||||
ChannelSetting: channelSetting,
|
||||
RelayFormat: RelayFormatOpenAI,
|
||||
ThinkingContentInfo: ThinkingContentInfo{
|
||||
IsFirstThinkingContent: true,
|
||||
SendLastThinkingContent: false,
|
||||
},
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
||||
info.IsPlayground = true
|
||||
@@ -141,26 +186,14 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetFirstResponseTime() {
|
||||
if info.IsFirstResponse {
|
||||
if info.isFirstResponse {
|
||||
info.FirstResponseTime = time.Now()
|
||||
info.IsFirstResponse = false
|
||||
info.isFirstResponse = false
|
||||
}
|
||||
}
|
||||
|
||||
type TaskRelayInfo struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
UserId int
|
||||
Group string
|
||||
StartTime time.Time
|
||||
ApiType int
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
RequestURLPath string
|
||||
ApiKey string
|
||||
BaseUrl string
|
||||
|
||||
*RelayInfo
|
||||
Action string
|
||||
OriginTaskID string
|
||||
|
||||
@@ -168,48 +201,8 @@ type TaskRelayInfo struct {
|
||||
}
|
||||
|
||||
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
||||
channelType := c.GetInt("channel_type")
|
||||
channelId := c.GetInt("channel_id")
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
startTime := time.Now()
|
||||
|
||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
||||
|
||||
info := &TaskRelayInfo{
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
ChannelId: channelId,
|
||||
TokenId: tokenId,
|
||||
UserId: userId,
|
||||
Group: group,
|
||||
StartTime: startTime,
|
||||
ApiType: apiType,
|
||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||
}
|
||||
if info.BaseUrl == "" {
|
||||
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
||||
RelayInfo: GenRelayInfo(c),
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
|
||||
return &RelayInfo{
|
||||
ChannelType: info.ChannelType,
|
||||
ChannelId: info.ChannelId,
|
||||
TokenId: info.TokenId,
|
||||
UserId: info.UserId,
|
||||
Group: info.Group,
|
||||
StartTime: info.StartTime,
|
||||
ApiType: info.ApiType,
|
||||
RelayMode: info.RelayMode,
|
||||
UpstreamModelName: info.UpstreamModelName,
|
||||
RequestURLPath: info.RequestURLPath,
|
||||
ApiKey: info.ApiKey,
|
||||
BaseUrl: info.BaseUrl,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,6 +30,8 @@ const (
|
||||
APITypeMokaAI
|
||||
APITypeVolcEngine
|
||||
APITypeBaiduV2
|
||||
APITypeOpenRouter
|
||||
APITypeXinference
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
|
||||
@@ -86,6 +88,10 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType = APITypeVolcEngine
|
||||
case common.ChannelTypeBaiduV2:
|
||||
apiType = APITypeBaiduV2
|
||||
case common.ChannelTypeOpenRouter:
|
||||
apiType = APITypeOpenRouter
|
||||
case common.ChannelTypeXinference:
|
||||
apiType = APITypeXinference
|
||||
}
|
||||
if apiType == -1 {
|
||||
return APITypeOpenAI, false
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package helper
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -19,6 +19,30 @@ func SetEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
|
||||
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
|
||||
jsonData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
} else {
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
|
||||
}
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func StringData(c *gin.Context, str string) error {
|
||||
//str = strings.TrimPrefix(str, "data: ")
|
||||
//str = strings.TrimSuffix(str, "\r")
|
||||
@@ -1,31 +1,50 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
)
|
||||
|
||||
type PriceData struct {
|
||||
ModelPrice float64
|
||||
ModelRatio float64
|
||||
CompletionRatio float64
|
||||
CacheRatio float64
|
||||
GroupRatio float64
|
||||
UsePrice bool
|
||||
CacheCreationRatio float64
|
||||
ShouldPreConsumedQuota int
|
||||
}
|
||||
|
||||
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData {
|
||||
modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false)
|
||||
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
|
||||
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
|
||||
groupRatio := setting.GetGroupRatio(info.Group)
|
||||
var preConsumedQuota int
|
||||
var modelRatio float64
|
||||
var completionRatio float64
|
||||
var cacheRatio float64
|
||||
var cacheCreationRatio float64
|
||||
if !usePrice {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if maxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + maxTokens
|
||||
}
|
||||
modelRatio = common.GetModelRatio(info.OriginModelName)
|
||||
var success bool
|
||||
modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName)
|
||||
if !success {
|
||||
if info.UserId == 1 {
|
||||
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
|
||||
} else {
|
||||
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置;Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
|
||||
}
|
||||
}
|
||||
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
|
||||
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
|
||||
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
@@ -34,8 +53,11 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
return PriceData{
|
||||
ModelPrice: modelPrice,
|
||||
ModelRatio: modelRatio,
|
||||
CompletionRatio: completionRatio,
|
||||
GroupRatio: groupRatio,
|
||||
UsePrice: usePrice,
|
||||
CacheRatio: cacheRatio,
|
||||
CacheCreationRatio: cacheCreationRatio,
|
||||
ShouldPreConsumedQuota: preConsumedQuota,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
91
relay/helper/stream_scanner.go
Normal file
91
relay/helper/stream_scanner.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
|
||||
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
|
||||
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
|
||||
// twice timeout for thinking model
|
||||
streamingTimeout *= 2
|
||||
}
|
||||
|
||||
var (
|
||||
stopChan = make(chan bool, 2)
|
||||
scanner = bufio.NewScanner(resp.Body)
|
||||
ticker = time.NewTicker(streamingTimeout)
|
||||
)
|
||||
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
close(stopChan)
|
||||
}()
|
||||
|
||||
scanner.Split(bufio.ScanLines)
|
||||
SetEventStreamHeaders(c)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ctx = context.WithValue(ctx, "stop_chan", stopChan)
|
||||
common.RelayCtxGo(ctx, func() {
|
||||
for scanner.Scan() {
|
||||
ticker.Reset(streamingTimeout)
|
||||
data := scanner.Text()
|
||||
if common.DebugEnabled {
|
||||
println(data)
|
||||
}
|
||||
|
||||
if len(data) < 6 {
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
data = strings.TrimLeft(data, " ")
|
||||
data = strings.TrimSuffix(data, "\"")
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
info.SetFirstResponseTime()
|
||||
success := dataHandler(data)
|
||||
if !success {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if err != io.EOF {
|
||||
common.LogError(c, "scanner error: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
common.SafeSendBool(stopChan, true)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// 超时处理逻辑
|
||||
common.LogError(c, "streaming timeout")
|
||||
case <-stopChan:
|
||||
// 正常结束
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
@@ -75,12 +74,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
relayInfo.PromptTokens = promptTokens
|
||||
}
|
||||
|
||||
priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
|
||||
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
@@ -119,7 +117,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp)
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user